mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-12 01:10:00 +08:00
Compare commits
138 Commits
tools/goct
...
go1.16
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2ea0a843f8 | ||
|
|
9e0e01b2bc | ||
|
|
af50a80d01 | ||
|
|
703fb8d970 | ||
|
|
e964e530e1 | ||
|
|
52265087d1 | ||
|
|
b4c2677eb9 | ||
|
|
30296fb1ca | ||
|
|
356c80defd | ||
|
|
8c31525378 | ||
|
|
2cf09f3c36 | ||
|
|
d41e542c92 | ||
|
|
265a24ac6d | ||
|
|
7d88fc39dc | ||
|
|
6957b6a344 | ||
|
|
bca6a230c8 | ||
|
|
cc8413d683 | ||
|
|
3842283fa8 | ||
|
|
fe13a533f5 | ||
|
|
7a327ccda4 | ||
|
|
06e4507406 | ||
|
|
8794d5b753 | ||
|
|
9bfa63d995 | ||
|
|
a432b121fb | ||
|
|
b61c94bb66 | ||
|
|
93fcf899dc | ||
|
|
9f4b3bae92 | ||
|
|
805cb87d98 | ||
|
|
366131640e | ||
|
|
956884a3ff | ||
|
|
f571cb8af2 | ||
|
|
cc5acf3b90 | ||
|
|
e1aa665443 | ||
|
|
cd357d9484 | ||
|
|
6d4d7cbd6b | ||
|
|
c593b5b531 | ||
|
|
fd5b38b07c | ||
|
|
41efb48f55 | ||
|
|
0ef3626839 | ||
|
|
77a72b16e9 | ||
|
|
21566f1b7a | ||
|
|
b2646e228b | ||
|
|
588b883710 | ||
|
|
033910bbd8 | ||
|
|
530dd79e3f | ||
|
|
cd5263ac75 | ||
|
|
ea3302a468 | ||
|
|
abf15b373c | ||
|
|
a865e9ee29 | ||
|
|
f8292198cf | ||
|
|
016d965f56 | ||
|
|
95d7c73409 | ||
|
|
939ef2a181 | ||
|
|
f0b8dd45fe | ||
|
|
0ba9335b04 | ||
|
|
04f181f0b4 | ||
|
|
89f841c126 | ||
|
|
d785c8c377 | ||
|
|
687a1d15da | ||
|
|
aaa974e1ad | ||
|
|
2779568ccf | ||
|
|
f7d50ae626 | ||
|
|
33594ea350 | ||
|
|
ee2ec974c4 | ||
|
|
fd2f2f0f54 | ||
|
|
86a2429d7d | ||
|
|
e5fe5dcc50 | ||
|
|
b510e7c242 | ||
|
|
dfe92e709f | ||
|
|
cb649cf627 | ||
|
|
ce19a5ade6 | ||
|
|
6dc56de714 | ||
|
|
f3369f8e81 | ||
|
|
c9b05ae07e | ||
|
|
32a59dbc27 | ||
|
|
ba0dff2d61 | ||
|
|
10da5e0424 | ||
|
|
4bed34090f | ||
|
|
2bfecf9354 | ||
|
|
6d129e0264 | ||
|
|
a2df1bb164 | ||
|
|
5f02e623f5 | ||
|
|
963b52fb1b | ||
|
|
02265d0bfe | ||
|
|
2e57e91826 | ||
|
|
82c642d3f4 | ||
|
|
b2571883ca | ||
|
|
00ff50c2cc | ||
|
|
4d7fa08b0b | ||
|
|
367afb544c | ||
|
|
43b8c7f641 | ||
|
|
a2dcb0079a | ||
|
|
f9619328f2 | ||
|
|
bae061a67e | ||
|
|
0b176e17ac | ||
|
|
6340e24c17 | ||
|
|
74e0676617 | ||
|
|
0defb7522f | ||
|
|
0c786ca849 | ||
|
|
26c541b9cb | ||
|
|
ade6f9ee46 | ||
|
|
f4502171ea | ||
|
|
8157e2118d | ||
|
|
e52dace416 | ||
|
|
dc260f196a | ||
|
|
559726112c | ||
|
|
a5fcf24c04 | ||
|
|
fc9b3ffdc1 | ||
|
|
e71c505e94 | ||
|
|
21c49009c0 | ||
|
|
69d355eb4b | ||
|
|
83f88d177f | ||
|
|
641ebf1667 | ||
|
|
cf435bfcc1 | ||
|
|
28f1b15b8e | ||
|
|
42413dc294 | ||
|
|
ec7ac43948 | ||
|
|
deefc1a8eb | ||
|
|
036328f1ea | ||
|
|
85057a623d | ||
|
|
1c544a26be | ||
|
|
20a61ce43e | ||
|
|
dd294e8cd6 | ||
|
|
3e9d0161bc | ||
|
|
cf6c349118 | ||
|
|
c7a0ec428c | ||
|
|
ce1c02f4f9 | ||
|
|
c3756a8f1c | ||
|
|
f4fd735aee | ||
|
|
683d793719 | ||
|
|
affbcb5698 | ||
|
|
f0d1722bbd | ||
|
|
c4f8eca459 | ||
|
|
251c071418 | ||
|
|
6652c4e445 | ||
|
|
f73613dff0 | ||
|
|
7a75dce465 | ||
|
|
801f1adf71 |
@@ -1,3 +1,6 @@
|
||||
comment: false
|
||||
comment:
|
||||
layout: "flags, files"
|
||||
behavior: once
|
||||
require_changes: true
|
||||
ignore:
|
||||
- "tools"
|
||||
28
ROADMAP.md
28
ROADMAP.md
@@ -1,28 +0,0 @@
|
||||
# go-zero Roadmap
|
||||
|
||||
This document defines a high level roadmap for go-zero development and upcoming releases.
|
||||
Community and contributor involvement is vital for successfully implementing all desired items for each release.
|
||||
We hope that the items listed below will inspire further engagement from the community to keep go-zero progressing and shipping exciting and valuable features.
|
||||
|
||||
## 2021 Q2
|
||||
- [x] Support service discovery through K8S client api
|
||||
- [x] Log full sql statements for easier sql problem solving
|
||||
|
||||
## 2021 Q3
|
||||
- [x] Support `goctl model pg` to support PostgreSQL code generation
|
||||
- [x] Adapt builtin tracing mechanism to opentracing solutions
|
||||
|
||||
## 2021 Q4
|
||||
- [x] Support `username/password` authentication in ETCD
|
||||
- [x] Support `SSL/TLS` in ETCD
|
||||
- [x] Support `SSL/TLS` in `zRPC`
|
||||
- [x] Support `TLS` in redis connections
|
||||
- [x] Support `goctl bug` to report bugs conveniently
|
||||
|
||||
## 2022
|
||||
- [x] Support `context` in redis related methods for timeout and tracing
|
||||
- [x] Support `context` in sql related methods for timeout and tracing
|
||||
- [x] Support `context` in mongodb related methods for timeout and tracing
|
||||
- [x] Add `httpc.Do` with HTTP call governance, like circuit breaker etc.
|
||||
- [ ] Support `goctl doctor` command to report potential issues for given service
|
||||
- [ ] Support `goctl mock` command to start a mocking server with given `.api` file
|
||||
@@ -32,9 +32,11 @@ func NewECBEncrypter(b cipher.Block) cipher.BlockMode {
|
||||
return (*ecbEncrypter)(newECB(b))
|
||||
}
|
||||
|
||||
// BlockSize returns the mode's block size.
|
||||
func (x *ecbEncrypter) BlockSize() int { return x.blockSize }
|
||||
|
||||
// why we don't return error is because cipher.BlockMode doesn't allow this
|
||||
// CryptBlocks encrypts a number of blocks. The length of src must be a multiple of
|
||||
// the block size. Dst and src must overlap entirely or not at all.
|
||||
func (x *ecbEncrypter) CryptBlocks(dst, src []byte) {
|
||||
if len(src)%x.blockSize != 0 {
|
||||
logx.Error("crypto/cipher: input not full blocks")
|
||||
@@ -59,11 +61,13 @@ func NewECBDecrypter(b cipher.Block) cipher.BlockMode {
|
||||
return (*ecbDecrypter)(newECB(b))
|
||||
}
|
||||
|
||||
// BlockSize returns the mode's block size.
|
||||
func (x *ecbDecrypter) BlockSize() int {
|
||||
return x.blockSize
|
||||
}
|
||||
|
||||
// why we don't return error is because cipher.BlockMode doesn't allow this
|
||||
// CryptBlocks decrypts a number of blocks. The length of src must be a multiple of
|
||||
// the block size. Dst and src must overlap entirely or not at all.
|
||||
func (x *ecbDecrypter) CryptBlocks(dst, src []byte) {
|
||||
if len(src)%x.blockSize != 0 {
|
||||
logx.Error("crypto/cipher: input not full blocks")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
|
||||
@@ -10,7 +11,8 @@ import (
|
||||
func TestAesEcb(t *testing.T) {
|
||||
var (
|
||||
key = []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D")
|
||||
val = []byte("hello")
|
||||
val = []byte("helloworld")
|
||||
valLong = []byte("helloworldlong..")
|
||||
badKey1 = []byte("aaaaaaaaa")
|
||||
// more than 32 chars
|
||||
badKey2 = []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
|
||||
@@ -31,6 +33,39 @@ func TestAesEcb(t *testing.T) {
|
||||
src, err := EcbDecrypt(key, dst)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, val, src)
|
||||
block, err := aes.NewCipher(key)
|
||||
assert.NoError(t, err)
|
||||
encrypter := NewECBEncrypter(block)
|
||||
assert.Equal(t, 16, encrypter.BlockSize())
|
||||
decrypter := NewECBDecrypter(block)
|
||||
assert.Equal(t, 16, decrypter.BlockSize())
|
||||
|
||||
dst = make([]byte, 8)
|
||||
encrypter.CryptBlocks(dst, val)
|
||||
for _, b := range dst {
|
||||
assert.Equal(t, byte(0), b)
|
||||
}
|
||||
|
||||
dst = make([]byte, 8)
|
||||
encrypter.CryptBlocks(dst, valLong)
|
||||
for _, b := range dst {
|
||||
assert.Equal(t, byte(0), b)
|
||||
}
|
||||
|
||||
dst = make([]byte, 8)
|
||||
decrypter.CryptBlocks(dst, val)
|
||||
for _, b := range dst {
|
||||
assert.Equal(t, byte(0), b)
|
||||
}
|
||||
|
||||
dst = make([]byte, 8)
|
||||
decrypter.CryptBlocks(dst, valLong)
|
||||
for _, b := range dst {
|
||||
assert.Equal(t, byte(0), b)
|
||||
}
|
||||
|
||||
_, err = EcbEncryptBase64("cTR0N3dDKkYtSmFOZFJnVWpYbjJyNXU4eC9BP0QK", "aGVsbG93b3JsZGxvbmcuLgo=")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAesEcbBase64(t *testing.T) {
|
||||
|
||||
@@ -80,3 +80,17 @@ func TestKeyBytes(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, len(key.Bytes()) > 0)
|
||||
}
|
||||
|
||||
func TestDHOnErrors(t *testing.T) {
|
||||
key, err := GenerateKey()
|
||||
assert.Nil(t, err)
|
||||
assert.NotEmpty(t, key.Bytes())
|
||||
_, err = ComputeKey(key.PubKey, key.PriKey)
|
||||
assert.NoError(t, err)
|
||||
_, err = ComputeKey(nil, key.PriKey)
|
||||
assert.Error(t, err)
|
||||
_, err = ComputeKey(key.PubKey, nil)
|
||||
assert.Error(t, err)
|
||||
|
||||
assert.NotNil(t, NewPublicKey([]byte("")))
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import "sync"
|
||||
type Ring struct {
|
||||
elements []interface{}
|
||||
index int
|
||||
lock sync.Mutex
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRing returns a Ring object with the given size n.
|
||||
@@ -31,8 +31,8 @@ func (r *Ring) Add(v interface{}) {
|
||||
|
||||
// Take takes all items from r.
|
||||
func (r *Ring) Take() []interface{} {
|
||||
r.lock.Lock()
|
||||
defer r.lock.Unlock()
|
||||
r.lock.RLock()
|
||||
defer r.lock.RUnlock()
|
||||
|
||||
var size int
|
||||
var start int
|
||||
|
||||
@@ -213,23 +213,23 @@ func (s *Set) validate(i interface{}) {
|
||||
switch i.(type) {
|
||||
case int:
|
||||
if s.tp != intType {
|
||||
logx.Errorf("Error: element is int, but set contains elements with type %d", s.tp)
|
||||
logx.Errorf("element is int, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
case int64:
|
||||
if s.tp != int64Type {
|
||||
logx.Errorf("Error: element is int64, but set contains elements with type %d", s.tp)
|
||||
logx.Errorf("element is int64, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
case uint:
|
||||
if s.tp != uintType {
|
||||
logx.Errorf("Error: element is uint, but set contains elements with type %d", s.tp)
|
||||
logx.Errorf("element is uint, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
case uint64:
|
||||
if s.tp != uint64Type {
|
||||
logx.Errorf("Error: element is uint64, but set contains elements with type %d", s.tp)
|
||||
logx.Errorf("element is uint64, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
case string:
|
||||
if s.tp != stringType {
|
||||
logx.Errorf("Error: element is string, but set contains elements with type %d", s.tp)
|
||||
logx.Errorf("element is string, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,10 +69,11 @@ func NewTimingWheel(interval time.Duration, numSlots int, execute Execute) (*Tim
|
||||
interval, numSlots, execute)
|
||||
}
|
||||
|
||||
return newTimingWheelWithClock(interval, numSlots, execute, timex.NewTicker(interval))
|
||||
return NewTimingWheelWithTicker(interval, numSlots, execute, timex.NewTicker(interval))
|
||||
}
|
||||
|
||||
func newTimingWheelWithClock(interval time.Duration, numSlots int, execute Execute,
|
||||
// NewTimingWheelWithTicker returns a TimingWheel with the given ticker.
|
||||
func NewTimingWheelWithTicker(interval time.Duration, numSlots int, execute Execute,
|
||||
ticker timex.Ticker) (*TimingWheel, error) {
|
||||
tw := &TimingWheel{
|
||||
interval: interval,
|
||||
|
||||
@@ -26,7 +26,7 @@ func TestNewTimingWheel(t *testing.T) {
|
||||
|
||||
func TestTimingWheel_Drain(t *testing.T) {
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
|
||||
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
|
||||
}, ticker)
|
||||
tw.SetTimer("first", 3, testStep*4)
|
||||
tw.SetTimer("second", 5, testStep*7)
|
||||
@@ -62,7 +62,7 @@ func TestTimingWheel_Drain(t *testing.T) {
|
||||
func TestTimingWheel_SetTimerSoon(t *testing.T) {
|
||||
run := syncx.NewAtomicBool()
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
|
||||
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
|
||||
assert.True(t, run.CompareAndSwap(false, true))
|
||||
assert.Equal(t, "any", k)
|
||||
assert.Equal(t, 3, v.(int))
|
||||
@@ -78,7 +78,7 @@ func TestTimingWheel_SetTimerSoon(t *testing.T) {
|
||||
func TestTimingWheel_SetTimerTwice(t *testing.T) {
|
||||
run := syncx.NewAtomicBool()
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
|
||||
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
|
||||
assert.True(t, run.CompareAndSwap(false, true))
|
||||
assert.Equal(t, "any", k)
|
||||
assert.Equal(t, 5, v.(int))
|
||||
@@ -96,7 +96,7 @@ func TestTimingWheel_SetTimerTwice(t *testing.T) {
|
||||
|
||||
func TestTimingWheel_SetTimerWrongDelay(t *testing.T) {
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {}, ticker)
|
||||
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {}, ticker)
|
||||
defer tw.Stop()
|
||||
assert.NotPanics(t, func() {
|
||||
tw.SetTimer("any", 3, -testStep)
|
||||
@@ -105,7 +105,7 @@ func TestTimingWheel_SetTimerWrongDelay(t *testing.T) {
|
||||
|
||||
func TestTimingWheel_SetTimerAfterClose(t *testing.T) {
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {}, ticker)
|
||||
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {}, ticker)
|
||||
tw.Stop()
|
||||
assert.Equal(t, ErrClosed, tw.SetTimer("any", 3, testStep))
|
||||
}
|
||||
@@ -113,7 +113,7 @@ func TestTimingWheel_SetTimerAfterClose(t *testing.T) {
|
||||
func TestTimingWheel_MoveTimer(t *testing.T) {
|
||||
run := syncx.NewAtomicBool()
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 3, func(k, v interface{}) {
|
||||
tw, _ := NewTimingWheelWithTicker(testStep, 3, func(k, v interface{}) {
|
||||
assert.True(t, run.CompareAndSwap(false, true))
|
||||
assert.Equal(t, "any", k)
|
||||
assert.Equal(t, 3, v.(int))
|
||||
@@ -139,7 +139,7 @@ func TestTimingWheel_MoveTimer(t *testing.T) {
|
||||
func TestTimingWheel_MoveTimerSoon(t *testing.T) {
|
||||
run := syncx.NewAtomicBool()
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 3, func(k, v interface{}) {
|
||||
tw, _ := NewTimingWheelWithTicker(testStep, 3, func(k, v interface{}) {
|
||||
assert.True(t, run.CompareAndSwap(false, true))
|
||||
assert.Equal(t, "any", k)
|
||||
assert.Equal(t, 3, v.(int))
|
||||
@@ -155,7 +155,7 @@ func TestTimingWheel_MoveTimerSoon(t *testing.T) {
|
||||
func TestTimingWheel_MoveTimerEarlier(t *testing.T) {
|
||||
run := syncx.NewAtomicBool()
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
|
||||
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
|
||||
assert.True(t, run.CompareAndSwap(false, true))
|
||||
assert.Equal(t, "any", k)
|
||||
assert.Equal(t, 3, v.(int))
|
||||
@@ -173,7 +173,7 @@ func TestTimingWheel_MoveTimerEarlier(t *testing.T) {
|
||||
|
||||
func TestTimingWheel_RemoveTimer(t *testing.T) {
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {}, ticker)
|
||||
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {}, ticker)
|
||||
tw.SetTimer("any", 3, testStep)
|
||||
assert.NotPanics(t, func() {
|
||||
tw.RemoveTimer("any")
|
||||
@@ -236,7 +236,7 @@ func TestTimingWheel_SetTimer(t *testing.T) {
|
||||
}
|
||||
var actual int32
|
||||
done := make(chan lang.PlaceholderType)
|
||||
tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
|
||||
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
|
||||
assert.Equal(t, 1, key.(int))
|
||||
assert.Equal(t, 2, value.(int))
|
||||
actual = atomic.LoadInt32(&count)
|
||||
@@ -317,7 +317,7 @@ func TestTimingWheel_SetAndMoveThenStart(t *testing.T) {
|
||||
}
|
||||
var actual int32
|
||||
done := make(chan lang.PlaceholderType)
|
||||
tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
|
||||
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
|
||||
actual = atomic.LoadInt32(&count)
|
||||
close(done)
|
||||
}, ticker)
|
||||
@@ -405,7 +405,7 @@ func TestTimingWheel_SetAndMoveTwice(t *testing.T) {
|
||||
}
|
||||
var actual int32
|
||||
done := make(chan lang.PlaceholderType)
|
||||
tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
|
||||
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
|
||||
actual = atomic.LoadInt32(&count)
|
||||
close(done)
|
||||
}, ticker)
|
||||
@@ -486,7 +486,7 @@ func TestTimingWheel_ElapsedAndSet(t *testing.T) {
|
||||
}
|
||||
var actual int32
|
||||
done := make(chan lang.PlaceholderType)
|
||||
tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
|
||||
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
|
||||
actual = atomic.LoadInt32(&count)
|
||||
close(done)
|
||||
}, ticker)
|
||||
@@ -577,7 +577,7 @@ func TestTimingWheel_ElapsedAndSetThenMove(t *testing.T) {
|
||||
}
|
||||
var actual int32
|
||||
done := make(chan lang.PlaceholderType)
|
||||
tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
|
||||
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
|
||||
actual = atomic.LoadInt32(&count)
|
||||
close(done)
|
||||
}, ticker)
|
||||
@@ -612,7 +612,7 @@ func TestMoveAndRemoveTask(t *testing.T) {
|
||||
}
|
||||
}
|
||||
var keys []int
|
||||
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
|
||||
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
|
||||
assert.Equal(t, "any", k)
|
||||
assert.Equal(t, 3, v.(int))
|
||||
keys = append(keys, v.(int))
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/jsonx"
|
||||
@@ -12,13 +13,29 @@ import (
|
||||
"github.com/zeromicro/go-zero/internal/encoding"
|
||||
)
|
||||
|
||||
const distanceBetweenUpperAndLower = 32
|
||||
const jsonTagKey = "json"
|
||||
|
||||
var loaders = map[string]func([]byte, interface{}) error{
|
||||
".json": LoadFromJsonBytes,
|
||||
".toml": LoadFromTomlBytes,
|
||||
".yaml": LoadFromYamlBytes,
|
||||
".yml": LoadFromYamlBytes,
|
||||
var (
|
||||
fillDefaultUnmarshaler = mapping.NewUnmarshaler(jsonTagKey, mapping.WithDefault())
|
||||
loaders = map[string]func([]byte, interface{}) error{
|
||||
".json": LoadFromJsonBytes,
|
||||
".toml": LoadFromTomlBytes,
|
||||
".yaml": LoadFromYamlBytes,
|
||||
".yml": LoadFromYamlBytes,
|
||||
}
|
||||
)
|
||||
|
||||
// children and mapField should not be both filled.
|
||||
// named fields and map cannot be bound to the same field name.
|
||||
type fieldInfo struct {
|
||||
children map[string]*fieldInfo
|
||||
mapField *fieldInfo
|
||||
}
|
||||
|
||||
// FillDefault fills the default values for the given v,
|
||||
// and the premise is that the value of v must be guaranteed to be empty.
|
||||
func FillDefault(v interface{}) error {
|
||||
return fillDefaultUnmarshaler.Unmarshal(map[string]interface{}{}, v)
|
||||
}
|
||||
|
||||
// Load loads config into v from file, .json, .yaml and .yml are acceptable.
|
||||
@@ -53,12 +70,19 @@ func LoadConfig(file string, v interface{}, opts ...Option) error {
|
||||
|
||||
// LoadFromJsonBytes loads config into v from content json bytes.
|
||||
func LoadFromJsonBytes(content []byte, v interface{}) error {
|
||||
info, err := buildFieldsInfo(reflect.TypeOf(v))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var m map[string]interface{}
|
||||
if err := jsonx.Unmarshal(content, &m); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return mapping.UnmarshalJsonMap(toCamelCaseKeyMap(m), v, mapping.WithCanonicalKeyFunc(toCamelCase))
|
||||
lowerCaseKeyMap := toLowerCaseKeyMap(m, info)
|
||||
|
||||
return mapping.UnmarshalJsonMap(lowerCaseKeyMap, v, mapping.WithCanonicalKeyFunc(toLowerCase))
|
||||
}
|
||||
|
||||
// LoadConfigFromJsonBytes loads config into v from content json bytes.
|
||||
@@ -100,53 +124,166 @@ func MustLoad(path string, v interface{}, opts ...Option) {
|
||||
}
|
||||
}
|
||||
|
||||
func toCamelCase(s string) string {
|
||||
var buf strings.Builder
|
||||
buf.Grow(len(s))
|
||||
var capNext bool
|
||||
boundary := true
|
||||
for _, v := range s {
|
||||
isCap := v >= 'A' && v <= 'Z'
|
||||
isLow := v >= 'a' && v <= 'z'
|
||||
if boundary && (isCap || isLow) {
|
||||
if capNext {
|
||||
if isLow {
|
||||
v -= distanceBetweenUpperAndLower
|
||||
}
|
||||
} else {
|
||||
if isCap {
|
||||
v += distanceBetweenUpperAndLower
|
||||
}
|
||||
}
|
||||
boundary = false
|
||||
func addOrMergeFields(info *fieldInfo, key string, child *fieldInfo) error {
|
||||
if prev, ok := info.children[key]; ok {
|
||||
if child.mapField != nil {
|
||||
return newDupKeyError(key)
|
||||
}
|
||||
if isCap || isLow {
|
||||
buf.WriteRune(v)
|
||||
capNext = false
|
||||
} else if v == ' ' || v == '\t' {
|
||||
buf.WriteRune(v)
|
||||
capNext = false
|
||||
boundary = true
|
||||
} else if v == '_' {
|
||||
capNext = true
|
||||
boundary = true
|
||||
} else {
|
||||
buf.WriteRune(v)
|
||||
capNext = true
|
||||
|
||||
if err := mergeFields(prev, key, child.children); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
info.children[key] = child
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error {
|
||||
switch ft.Kind() {
|
||||
case reflect.Struct:
|
||||
fields, err := buildFieldsInfo(ft)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for k, v := range fields.children {
|
||||
if err = addOrMergeFields(info, k, v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case reflect.Map:
|
||||
elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, ok := info.children[lowerCaseName]; ok {
|
||||
return newDupKeyError(lowerCaseName)
|
||||
}
|
||||
|
||||
info.children[lowerCaseName] = &fieldInfo{
|
||||
children: make(map[string]*fieldInfo),
|
||||
mapField: elemField,
|
||||
}
|
||||
default:
|
||||
if _, ok := info.children[lowerCaseName]; ok {
|
||||
return newDupKeyError(lowerCaseName)
|
||||
}
|
||||
|
||||
info.children[lowerCaseName] = &fieldInfo{
|
||||
children: make(map[string]*fieldInfo),
|
||||
}
|
||||
}
|
||||
|
||||
return buf.String()
|
||||
return nil
|
||||
}
|
||||
|
||||
func toCamelCaseInterface(v interface{}) interface{} {
|
||||
func buildFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
|
||||
tp = mapping.Deref(tp)
|
||||
|
||||
switch tp.Kind() {
|
||||
case reflect.Struct:
|
||||
return buildStructFieldsInfo(tp)
|
||||
case reflect.Array, reflect.Slice:
|
||||
return buildFieldsInfo(mapping.Deref(tp.Elem()))
|
||||
case reflect.Chan, reflect.Func:
|
||||
return nil, fmt.Errorf("unsupported type: %s", tp.Kind())
|
||||
default:
|
||||
return &fieldInfo{
|
||||
children: make(map[string]*fieldInfo),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error {
|
||||
var finfo *fieldInfo
|
||||
var err error
|
||||
|
||||
switch ft.Kind() {
|
||||
case reflect.Struct:
|
||||
finfo, err = buildFieldsInfo(ft)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case reflect.Array, reflect.Slice:
|
||||
finfo, err = buildFieldsInfo(ft.Elem())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case reflect.Map:
|
||||
elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
finfo = &fieldInfo{
|
||||
children: make(map[string]*fieldInfo),
|
||||
mapField: elemInfo,
|
||||
}
|
||||
default:
|
||||
finfo, err = buildFieldsInfo(ft)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return addOrMergeFields(info, lowerCaseName, finfo)
|
||||
}
|
||||
|
||||
func buildStructFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
|
||||
info := &fieldInfo{
|
||||
children: make(map[string]*fieldInfo),
|
||||
}
|
||||
|
||||
for i := 0; i < tp.NumField(); i++ {
|
||||
field := tp.Field(i)
|
||||
name := field.Name
|
||||
lowerCaseName := toLowerCase(name)
|
||||
ft := mapping.Deref(field.Type)
|
||||
// flatten anonymous fields
|
||||
if field.Anonymous {
|
||||
if err := buildAnonymousFieldInfo(info, lowerCaseName, ft); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if err := buildNamedFieldInfo(info, lowerCaseName, ft); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
func mergeFields(prev *fieldInfo, key string, children map[string]*fieldInfo) error {
|
||||
if len(prev.children) == 0 || len(children) == 0 {
|
||||
return newDupKeyError(key)
|
||||
}
|
||||
|
||||
// merge fields
|
||||
for k, v := range children {
|
||||
if _, ok := prev.children[k]; ok {
|
||||
return newDupKeyError(k)
|
||||
}
|
||||
|
||||
prev.children[k] = v
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func toLowerCase(s string) string {
|
||||
return strings.ToLower(s)
|
||||
}
|
||||
|
||||
func toLowerCaseInterface(v interface{}, info *fieldInfo) interface{} {
|
||||
switch vv := v.(type) {
|
||||
case map[string]interface{}:
|
||||
return toCamelCaseKeyMap(vv)
|
||||
return toLowerCaseKeyMap(vv, info)
|
||||
case []interface{}:
|
||||
var arr []interface{}
|
||||
for _, vvv := range vv {
|
||||
arr = append(arr, toCamelCaseInterface(vvv))
|
||||
arr = append(arr, toLowerCaseInterface(vvv, info))
|
||||
}
|
||||
return arr
|
||||
default:
|
||||
@@ -154,11 +291,37 @@ func toCamelCaseInterface(v interface{}) interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func toCamelCaseKeyMap(m map[string]interface{}) map[string]interface{} {
|
||||
func toLowerCaseKeyMap(m map[string]interface{}, info *fieldInfo) map[string]interface{} {
|
||||
res := make(map[string]interface{})
|
||||
|
||||
for k, v := range m {
|
||||
res[toCamelCase(k)] = toCamelCaseInterface(v)
|
||||
ti, ok := info.children[k]
|
||||
if ok {
|
||||
res[k] = toLowerCaseInterface(v, ti)
|
||||
continue
|
||||
}
|
||||
|
||||
lk := toLowerCase(k)
|
||||
if ti, ok = info.children[lk]; ok {
|
||||
res[lk] = toLowerCaseInterface(v, ti)
|
||||
} else if info.mapField != nil {
|
||||
res[k] = toLowerCaseInterface(v, info.mapField)
|
||||
} else {
|
||||
res[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
type dupKeyError struct {
|
||||
key string
|
||||
}
|
||||
|
||||
func newDupKeyError(key string) dupKeyError {
|
||||
return dupKeyError{key: key}
|
||||
}
|
||||
|
||||
func (e dupKeyError) Error() string {
|
||||
return fmt.Sprintf("duplicated key %s", e.key)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/hash"
|
||||
)
|
||||
|
||||
var dupErr dupKeyError
|
||||
|
||||
func TestLoadConfig_notExists(t *testing.T) {
|
||||
assert.NotNil(t, Load("not_a_file", nil))
|
||||
}
|
||||
@@ -17,7 +19,7 @@ func TestLoadConfig_notRecogFile(t *testing.T) {
|
||||
filename, err := fs.TempFilenameWithText("hello")
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(filename)
|
||||
assert.NotNil(t, Load(filename, nil))
|
||||
assert.NotNil(t, LoadConfig(filename, nil))
|
||||
}
|
||||
|
||||
func TestConfigJson(t *testing.T) {
|
||||
@@ -64,7 +66,7 @@ func TestLoadFromJsonBytesArray(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
assert.NoError(t, LoadFromJsonBytes(input, &val))
|
||||
assert.NoError(t, LoadConfigFromJsonBytes(input, &val))
|
||||
var expect []string
|
||||
for _, user := range val.Users {
|
||||
expect = append(expect, user.Name)
|
||||
@@ -97,6 +99,30 @@ d = "abcd!@#$112"
|
||||
assert.Equal(t, "abcd!@#$112", val.D)
|
||||
}
|
||||
|
||||
func TestConfigOptional(t *testing.T) {
|
||||
text := `a = "foo"
|
||||
b = 1
|
||||
c = "FOO"
|
||||
d = "abcd"
|
||||
`
|
||||
tmpfile, err := createTempFile(".toml", text)
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(tmpfile)
|
||||
|
||||
var val struct {
|
||||
A string `json:"a"`
|
||||
B int `json:"b,optional"`
|
||||
C string `json:"c,optional=B"`
|
||||
D string `json:"d,optional=b"`
|
||||
}
|
||||
if assert.NoError(t, Load(tmpfile, &val)) {
|
||||
assert.Equal(t, "foo", val.A)
|
||||
assert.Equal(t, 1, val.B)
|
||||
assert.Equal(t, "FOO", val.C)
|
||||
assert.Equal(t, "abcd", val.D)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigJsonCanonical(t *testing.T) {
|
||||
text := []byte(`{"a": "foo", "B": "bar"}`)
|
||||
|
||||
@@ -148,7 +174,7 @@ B: bar`)
|
||||
A string
|
||||
B string
|
||||
}
|
||||
assert.NoError(t, LoadFromYamlBytes(text, &val1))
|
||||
assert.NoError(t, LoadConfigFromYamlBytes(text, &val1))
|
||||
assert.Equal(t, "foo", val1.A)
|
||||
assert.Equal(t, "bar", val1.B)
|
||||
assert.NoError(t, LoadFromYamlBytes(text, &val2))
|
||||
@@ -237,23 +263,23 @@ func TestToCamelCase(t *testing.T) {
|
||||
},
|
||||
{
|
||||
input: "hello_world",
|
||||
expect: "helloWorld",
|
||||
expect: "hello_world",
|
||||
},
|
||||
{
|
||||
input: "Hello_world",
|
||||
expect: "helloWorld",
|
||||
expect: "hello_world",
|
||||
},
|
||||
{
|
||||
input: "hello_World",
|
||||
expect: "helloWorld",
|
||||
expect: "hello_world",
|
||||
},
|
||||
{
|
||||
input: "helloWorld",
|
||||
expect: "helloWorld",
|
||||
expect: "helloworld",
|
||||
},
|
||||
{
|
||||
input: "HelloWorld",
|
||||
expect: "helloWorld",
|
||||
expect: "helloworld",
|
||||
},
|
||||
{
|
||||
input: "hello World",
|
||||
@@ -269,30 +295,34 @@ func TestToCamelCase(t *testing.T) {
|
||||
},
|
||||
{
|
||||
input: "Hello World foo_bar",
|
||||
expect: "hello world fooBar",
|
||||
expect: "hello world foo_bar",
|
||||
},
|
||||
{
|
||||
input: "Hello World foo_Bar",
|
||||
expect: "hello world fooBar",
|
||||
expect: "hello world foo_bar",
|
||||
},
|
||||
{
|
||||
input: "Hello World Foo_bar",
|
||||
expect: "hello world fooBar",
|
||||
expect: "hello world foo_bar",
|
||||
},
|
||||
{
|
||||
input: "Hello World Foo_Bar",
|
||||
expect: "hello world fooBar",
|
||||
expect: "hello world foo_bar",
|
||||
},
|
||||
{
|
||||
input: "Hello.World Foo_Bar",
|
||||
expect: "hello.world foo_bar",
|
||||
},
|
||||
{
|
||||
input: "你好 World Foo_Bar",
|
||||
expect: "你好 world fooBar",
|
||||
expect: "你好 world foo_bar",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.input, func(t *testing.T) {
|
||||
assert.Equal(t, test.expect, toCamelCase(test.input))
|
||||
assert.Equal(t, test.expect, toLowerCase(test.input))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -328,6 +358,670 @@ func TestLoadFromYamlBytes(t *testing.T) {
|
||||
assert.Equal(t, "foo", val.Layer1.Layer2.Layer3)
|
||||
}
|
||||
|
||||
func TestLoadFromYamlBytesTerm(t *testing.T) {
|
||||
input := []byte(`layer1:
|
||||
layer2:
|
||||
tls_conf: foo`)
|
||||
var val struct {
|
||||
Layer1 struct {
|
||||
Layer2 struct {
|
||||
Layer3 string `json:"tls_conf"`
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert.NoError(t, LoadFromYamlBytes(input, &val))
|
||||
assert.Equal(t, "foo", val.Layer1.Layer2.Layer3)
|
||||
}
|
||||
|
||||
func TestLoadFromYamlBytesLayers(t *testing.T) {
|
||||
input := []byte(`layer1:
|
||||
layer2:
|
||||
layer3: foo`)
|
||||
var val struct {
|
||||
Value string `json:"Layer1.Layer2.Layer3"`
|
||||
}
|
||||
|
||||
assert.NoError(t, LoadFromYamlBytes(input, &val))
|
||||
assert.Equal(t, "foo", val.Value)
|
||||
}
|
||||
|
||||
func TestLoadFromYamlItemOverlay(t *testing.T) {
|
||||
type (
|
||||
Redis struct {
|
||||
Host string
|
||||
Port int
|
||||
}
|
||||
|
||||
RedisKey struct {
|
||||
Redis
|
||||
Key string
|
||||
}
|
||||
|
||||
Server struct {
|
||||
Redis RedisKey
|
||||
}
|
||||
|
||||
TestConfig struct {
|
||||
Server
|
||||
Redis Redis
|
||||
}
|
||||
)
|
||||
|
||||
input := []byte(`Redis:
|
||||
Host: localhost
|
||||
Port: 6379
|
||||
Key: test
|
||||
`)
|
||||
|
||||
var c TestConfig
|
||||
assert.ErrorAs(t, LoadFromYamlBytes(input, &c), &dupErr)
|
||||
}
|
||||
|
||||
func TestLoadFromYamlItemOverlayReverse(t *testing.T) {
|
||||
type (
|
||||
Redis struct {
|
||||
Host string
|
||||
Port int
|
||||
}
|
||||
|
||||
RedisKey struct {
|
||||
Redis
|
||||
Key string
|
||||
}
|
||||
|
||||
Server struct {
|
||||
Redis Redis
|
||||
}
|
||||
|
||||
TestConfig struct {
|
||||
Redis RedisKey
|
||||
Server
|
||||
}
|
||||
)
|
||||
|
||||
input := []byte(`Redis:
|
||||
Host: localhost
|
||||
Port: 6379
|
||||
Key: test
|
||||
`)
|
||||
|
||||
var c TestConfig
|
||||
assert.ErrorAs(t, LoadFromYamlBytes(input, &c), &dupErr)
|
||||
}
|
||||
|
||||
func TestLoadFromYamlItemOverlayWithMap(t *testing.T) {
|
||||
type (
|
||||
Redis struct {
|
||||
Host string
|
||||
Port int
|
||||
}
|
||||
|
||||
RedisKey struct {
|
||||
Redis
|
||||
Key string
|
||||
}
|
||||
|
||||
Server struct {
|
||||
Redis RedisKey
|
||||
}
|
||||
|
||||
TestConfig struct {
|
||||
Server
|
||||
Redis map[string]interface{}
|
||||
}
|
||||
)
|
||||
|
||||
input := []byte(`Redis:
|
||||
Host: localhost
|
||||
Port: 6379
|
||||
Key: test
|
||||
`)
|
||||
|
||||
var c TestConfig
|
||||
assert.ErrorAs(t, LoadFromYamlBytes(input, &c), &dupErr)
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonBytesMap(t *testing.T) {
|
||||
input := []byte(`{"foo":{"/mtproto.RPCTos": "bff.bff","bar":"baz"}}`)
|
||||
|
||||
var val struct {
|
||||
Foo map[string]string
|
||||
}
|
||||
|
||||
assert.NoError(t, LoadFromJsonBytes(input, &val))
|
||||
assert.Equal(t, "bff.bff", val.Foo["/mtproto.RPCTos"])
|
||||
assert.Equal(t, "baz", val.Foo["bar"])
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonBytesMapWithSliceElements(t *testing.T) {
|
||||
input := []byte(`{"foo":{"/mtproto.RPCTos": ["bff.bff", "any"],"bar":["baz", "qux"]}}`)
|
||||
|
||||
var val struct {
|
||||
Foo map[string][]string
|
||||
}
|
||||
|
||||
assert.NoError(t, LoadFromJsonBytes(input, &val))
|
||||
assert.EqualValues(t, []string{"bff.bff", "any"}, val.Foo["/mtproto.RPCTos"])
|
||||
assert.EqualValues(t, []string{"baz", "qux"}, val.Foo["bar"])
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonBytesMapWithSliceOfStructs(t *testing.T) {
|
||||
input := []byte(`{"foo":{
|
||||
"/mtproto.RPCTos": [{"bar": "any"}],
|
||||
"bar":[{"bar": "qux"}, {"bar": "ever"}]}}`)
|
||||
|
||||
var val struct {
|
||||
Foo map[string][]struct {
|
||||
Bar string
|
||||
}
|
||||
}
|
||||
|
||||
assert.NoError(t, LoadFromJsonBytes(input, &val))
|
||||
assert.Equal(t, 1, len(val.Foo["/mtproto.RPCTos"]))
|
||||
assert.Equal(t, "any", val.Foo["/mtproto.RPCTos"][0].Bar)
|
||||
assert.Equal(t, 2, len(val.Foo["bar"]))
|
||||
assert.Equal(t, "qux", val.Foo["bar"][0].Bar)
|
||||
assert.Equal(t, "ever", val.Foo["bar"][1].Bar)
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonBytesWithAnonymousField(t *testing.T) {
|
||||
type (
|
||||
Int int
|
||||
|
||||
InnerConf struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
Conf struct {
|
||||
Int
|
||||
InnerConf
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
input = []byte(`{"Name": "hello", "int": 3}`)
|
||||
c Conf
|
||||
)
|
||||
assert.NoError(t, LoadFromJsonBytes(input, &c))
|
||||
assert.Equal(t, "hello", c.Name)
|
||||
assert.Equal(t, Int(3), c.Int)
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonBytesWithMapValueOfStruct(t *testing.T) {
|
||||
type (
|
||||
Value struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
Config struct {
|
||||
Items map[string]Value
|
||||
}
|
||||
)
|
||||
|
||||
var inputs = [][]byte{
|
||||
[]byte(`{"Items": {"Key":{"Name": "foo"}}}`),
|
||||
[]byte(`{"Items": {"Key":{"Name": "foo"}}}`),
|
||||
[]byte(`{"items": {"key":{"name": "foo"}}}`),
|
||||
[]byte(`{"items": {"key":{"name": "foo"}}}`),
|
||||
}
|
||||
for _, input := range inputs {
|
||||
var c Config
|
||||
if assert.NoError(t, LoadFromJsonBytes(input, &c)) {
|
||||
assert.Equal(t, 1, len(c.Items))
|
||||
for _, v := range c.Items {
|
||||
assert.Equal(t, "foo", v.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonBytesWithMapTypeValueOfStruct(t *testing.T) {
|
||||
type (
|
||||
Value struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
Map map[string]Value
|
||||
|
||||
Config struct {
|
||||
Map
|
||||
}
|
||||
)
|
||||
|
||||
var inputs = [][]byte{
|
||||
[]byte(`{"Map": {"Key":{"Name": "foo"}}}`),
|
||||
[]byte(`{"Map": {"Key":{"Name": "foo"}}}`),
|
||||
[]byte(`{"map": {"key":{"name": "foo"}}}`),
|
||||
[]byte(`{"map": {"key":{"name": "foo"}}}`),
|
||||
}
|
||||
for _, input := range inputs {
|
||||
var c Config
|
||||
if assert.NoError(t, LoadFromJsonBytes(input, &c)) {
|
||||
assert.Equal(t, 1, len(c.Map))
|
||||
for _, v := range c.Map {
|
||||
assert.Equal(t, "foo", v.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_FieldOverwrite(t *testing.T) {
|
||||
t.Run("normal", func(t *testing.T) {
|
||||
type Base struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
type St1 struct {
|
||||
Base
|
||||
Name2 string
|
||||
}
|
||||
|
||||
type St2 struct {
|
||||
Base
|
||||
Name2 string
|
||||
}
|
||||
|
||||
type St3 struct {
|
||||
*Base
|
||||
Name2 string
|
||||
}
|
||||
|
||||
type St4 struct {
|
||||
*Base
|
||||
Name2 *string
|
||||
}
|
||||
|
||||
validate := func(val interface{}) {
|
||||
input := []byte(`{"Name": "hello", "Name2": "world"}`)
|
||||
assert.NoError(t, LoadFromJsonBytes(input, val))
|
||||
}
|
||||
|
||||
validate(&St1{})
|
||||
validate(&St2{})
|
||||
validate(&St3{})
|
||||
validate(&St4{})
|
||||
})
|
||||
|
||||
t.Run("Inherit Override", func(t *testing.T) {
|
||||
type Base struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
type St1 struct {
|
||||
Base
|
||||
Name string
|
||||
}
|
||||
|
||||
type St2 struct {
|
||||
Base
|
||||
Name int
|
||||
}
|
||||
|
||||
type St3 struct {
|
||||
*Base
|
||||
Name int
|
||||
}
|
||||
|
||||
type St4 struct {
|
||||
*Base
|
||||
Name *string
|
||||
}
|
||||
|
||||
validate := func(val interface{}) {
|
||||
input := []byte(`{"Name": "hello"}`)
|
||||
err := LoadFromJsonBytes(input, val)
|
||||
assert.ErrorAs(t, err, &dupErr)
|
||||
assert.Equal(t, newDupKeyError("name").Error(), err.Error())
|
||||
}
|
||||
|
||||
validate(&St1{})
|
||||
validate(&St2{})
|
||||
validate(&St3{})
|
||||
validate(&St4{})
|
||||
})
|
||||
|
||||
t.Run("Inherit more", func(t *testing.T) {
|
||||
type Base1 struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
type St0 struct {
|
||||
Base1
|
||||
Name string
|
||||
}
|
||||
|
||||
type St1 struct {
|
||||
St0
|
||||
Name string
|
||||
}
|
||||
|
||||
type St2 struct {
|
||||
St0
|
||||
Name int
|
||||
}
|
||||
|
||||
type St3 struct {
|
||||
*St0
|
||||
Name int
|
||||
}
|
||||
|
||||
type St4 struct {
|
||||
*St0
|
||||
Name *int
|
||||
}
|
||||
|
||||
validate := func(val interface{}) {
|
||||
input := []byte(`{"Name": "hello"}`)
|
||||
err := LoadFromJsonBytes(input, val)
|
||||
assert.ErrorAs(t, err, &dupErr)
|
||||
assert.Equal(t, newDupKeyError("name").Error(), err.Error())
|
||||
}
|
||||
|
||||
validate(&St0{})
|
||||
validate(&St1{})
|
||||
validate(&St2{})
|
||||
validate(&St3{})
|
||||
validate(&St4{})
|
||||
})
|
||||
}
|
||||
|
||||
func TestFieldOverwriteComplicated(t *testing.T) {
|
||||
t.Run("double maps", func(t *testing.T) {
|
||||
type (
|
||||
Base1 struct {
|
||||
Values map[string]string
|
||||
}
|
||||
Base2 struct {
|
||||
Values map[string]string
|
||||
}
|
||||
Config struct {
|
||||
Base1
|
||||
Base2
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Values": {"Key": "Value"}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("merge children", func(t *testing.T) {
|
||||
type (
|
||||
Inner1 struct {
|
||||
Name string
|
||||
}
|
||||
Inner2 struct {
|
||||
Age int
|
||||
}
|
||||
Base1 struct {
|
||||
Inner Inner1
|
||||
}
|
||||
Base2 struct {
|
||||
Inner Inner2
|
||||
}
|
||||
Config struct {
|
||||
Base1
|
||||
Base2
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Inner": {"Name": "foo", "Age": 10}}`)
|
||||
if assert.NoError(t, LoadFromJsonBytes(input, &c)) {
|
||||
assert.Equal(t, "foo", c.Base1.Inner.Name)
|
||||
assert.Equal(t, 10, c.Base2.Inner.Age)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("overwritten maps", func(t *testing.T) {
|
||||
type (
|
||||
Inner struct {
|
||||
Map map[string]string
|
||||
}
|
||||
Config struct {
|
||||
Map map[string]string
|
||||
Inner
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Inner": {"Map": {"Key": "Value"}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten nested maps", func(t *testing.T) {
|
||||
type (
|
||||
Inner struct {
|
||||
Map map[string]string
|
||||
}
|
||||
Middle1 struct {
|
||||
Map map[string]string
|
||||
Inner
|
||||
}
|
||||
Middle2 struct {
|
||||
Map map[string]string
|
||||
Inner
|
||||
}
|
||||
Config struct {
|
||||
Middle1
|
||||
Middle2
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Middle1": {"Inner": {"Map": {"Key": "Value"}}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten outer/inner maps", func(t *testing.T) {
|
||||
type (
|
||||
Inner struct {
|
||||
Map map[string]string
|
||||
}
|
||||
Middle struct {
|
||||
Inner
|
||||
Map map[string]string
|
||||
}
|
||||
Config struct {
|
||||
Middle
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Middle": {"Inner": {"Map": {"Key": "Value"}}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten anonymous maps", func(t *testing.T) {
|
||||
type (
|
||||
Inner struct {
|
||||
Map map[string]string
|
||||
}
|
||||
Middle struct {
|
||||
Inner
|
||||
Map map[string]string
|
||||
}
|
||||
Elem map[string]Middle
|
||||
Config struct {
|
||||
Elem
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Elem": {"Key": {"Inner": {"Map": {"Key": "Value"}}}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten primitive and map", func(t *testing.T) {
|
||||
type (
|
||||
Inner struct {
|
||||
Value string
|
||||
}
|
||||
Elem map[string]Inner
|
||||
Named struct {
|
||||
Elem string
|
||||
}
|
||||
Config struct {
|
||||
Named
|
||||
Elem
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Elem": {"Key": {"Value": "Value"}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten map and slice", func(t *testing.T) {
|
||||
type (
|
||||
Inner struct {
|
||||
Value string
|
||||
}
|
||||
Elem []Inner
|
||||
Named struct {
|
||||
Elem string
|
||||
}
|
||||
Config struct {
|
||||
Named
|
||||
Elem
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Elem": {"Key": {"Value": "Value"}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten map and string", func(t *testing.T) {
|
||||
type (
|
||||
Elem string
|
||||
Named struct {
|
||||
Elem string
|
||||
}
|
||||
Config struct {
|
||||
Named
|
||||
Elem
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Elem": {"Key": {"Value": "Value"}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadNamedFieldOverwritten(t *testing.T) {
|
||||
t.Run("overwritten named struct", func(t *testing.T) {
|
||||
type (
|
||||
Elem string
|
||||
Named struct {
|
||||
Elem string
|
||||
}
|
||||
Base struct {
|
||||
Named
|
||||
Elem
|
||||
}
|
||||
Config struct {
|
||||
Val Base
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Val": {"Elem": {"Key": {"Value": "Value"}}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten named []struct", func(t *testing.T) {
|
||||
type (
|
||||
Elem string
|
||||
Named struct {
|
||||
Elem string
|
||||
}
|
||||
Base struct {
|
||||
Named
|
||||
Elem
|
||||
}
|
||||
Config struct {
|
||||
Vals []Base
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Vals": [{"Elem": {"Key": {"Value": "Value"}}}]}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten named map[string]struct", func(t *testing.T) {
|
||||
type (
|
||||
Elem string
|
||||
Named struct {
|
||||
Elem string
|
||||
}
|
||||
Base struct {
|
||||
Named
|
||||
Elem
|
||||
}
|
||||
Config struct {
|
||||
Vals map[string]Base
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Vals": {"Key": {"Elem": {"Key": {"Value": "Value"}}}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten named *struct", func(t *testing.T) {
|
||||
type (
|
||||
Elem string
|
||||
Named struct {
|
||||
Elem string
|
||||
}
|
||||
Base struct {
|
||||
Named
|
||||
Elem
|
||||
}
|
||||
Config struct {
|
||||
Vals *Base
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Vals": [{"Elem": {"Key": {"Value": "Value"}}}]}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten named struct", func(t *testing.T) {
|
||||
type (
|
||||
Named struct {
|
||||
Elem string
|
||||
}
|
||||
Base struct {
|
||||
Named
|
||||
Elem Named
|
||||
}
|
||||
Config struct {
|
||||
Val Base
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Val": {"Elem": "Value"}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten named struct", func(t *testing.T) {
|
||||
type Config struct {
|
||||
Val chan int
|
||||
}
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Val": 1}`)
|
||||
assert.Error(t, LoadFromJsonBytes(input, &c))
|
||||
})
|
||||
}
|
||||
|
||||
func createTempFile(ext, text string) (string, error) {
|
||||
tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext)
|
||||
if err != nil {
|
||||
@@ -345,3 +1039,55 @@ func createTempFile(ext, text string) (string, error) {
|
||||
|
||||
return filename, nil
|
||||
}
|
||||
|
||||
func TestFillDefaultUnmarshal(t *testing.T) {
|
||||
t.Run("nil", func(t *testing.T) {
|
||||
type St struct{}
|
||||
err := FillDefault(St{})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("not nil", func(t *testing.T) {
|
||||
type St struct{}
|
||||
err := FillDefault(&St{})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("default", func(t *testing.T) {
|
||||
type St struct {
|
||||
A string `json:",default=a"`
|
||||
B string
|
||||
}
|
||||
var st St
|
||||
err := FillDefault(&st)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, st.A, "a")
|
||||
})
|
||||
|
||||
t.Run("env", func(t *testing.T) {
|
||||
type St struct {
|
||||
A string `json:",default=a"`
|
||||
B string
|
||||
C string `json:",env=TEST_C"`
|
||||
}
|
||||
t.Setenv("TEST_C", "c")
|
||||
|
||||
var st St
|
||||
err := FillDefault(&st)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, st.A, "a")
|
||||
assert.Equal(t, st.C, "c")
|
||||
})
|
||||
|
||||
t.Run("has vaue", func(t *testing.T) {
|
||||
type St struct {
|
||||
A string `json:",default=a"`
|
||||
B string
|
||||
}
|
||||
var st = St{
|
||||
A: "b",
|
||||
}
|
||||
err := FillDefault(&st)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
// PropertyError represents a configuration error message.
|
||||
type PropertyError struct {
|
||||
error
|
||||
message string
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
```go
|
||||
type RestfulConf struct {
|
||||
ServiceName string `json:",env=SERVICE_NAME"` // read from env automatically
|
||||
Host string `json:",default=0.0.0.0"`
|
||||
Port int
|
||||
LogMode string `json:",options=[file,console]"`
|
||||
@@ -21,20 +22,20 @@ type RestfulConf struct {
|
||||
|
||||
```yaml
|
||||
# most fields are optional or have default values
|
||||
Port: 8080
|
||||
LogMode: console
|
||||
port: 8080
|
||||
logMode: console
|
||||
# you can use env settings
|
||||
MaxBytes: ${MAX_BYTES}
|
||||
maxBytes: ${MAX_BYTES}
|
||||
```
|
||||
|
||||
- toml example
|
||||
|
||||
```toml
|
||||
# most fields are optional or have default values
|
||||
Port = 8_080
|
||||
LogMode = "console"
|
||||
port = 8_080
|
||||
logMode = "console"
|
||||
# you can use env settings
|
||||
MaxBytes = "${MAX_BYTES}"
|
||||
maxBytes = "${MAX_BYTES}"
|
||||
```
|
||||
|
||||
3. Load the config from a file:
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package discov
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/discov/internal"
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
@@ -51,12 +53,7 @@ func NewPublisher(endpoints []string, key, value string, opts ...PubOption) *Pub
|
||||
|
||||
// KeepAlive keeps key:value alive.
|
||||
func (p *Publisher) KeepAlive() error {
|
||||
cli, err := internal.GetRegistry().GetConn(p.endpoints)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.lease, err = p.register(cli)
|
||||
cli, err := p.doRegister()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -83,6 +80,43 @@ func (p *Publisher) Stop() {
|
||||
p.quit.Close()
|
||||
}
|
||||
|
||||
func (p *Publisher) doKeepAlive() error {
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
select {
|
||||
case <-p.quit.Done():
|
||||
return nil
|
||||
default:
|
||||
cli, err := p.doRegister()
|
||||
if err != nil {
|
||||
logx.Errorf("etcd publisher doRegister: %s", err.Error())
|
||||
break
|
||||
}
|
||||
|
||||
if err := p.keepAliveAsync(cli); err != nil {
|
||||
logx.Errorf("etcd publisher keepAliveAsync: %s", err.Error())
|
||||
break
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Publisher) doRegister() (internal.EtcdClient, error) {
|
||||
cli, err := internal.GetRegistry().GetConn(p.endpoints)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.lease, err = p.register(cli)
|
||||
return cli, err
|
||||
}
|
||||
|
||||
func (p *Publisher) keepAliveAsync(cli internal.EtcdClient) error {
|
||||
ch, err := cli.KeepAlive(cli.Ctx(), p.lease)
|
||||
if err != nil {
|
||||
@@ -95,8 +129,8 @@ func (p *Publisher) keepAliveAsync(cli internal.EtcdClient) error {
|
||||
case _, ok := <-ch:
|
||||
if !ok {
|
||||
p.revoke(cli)
|
||||
if err := p.KeepAlive(); err != nil {
|
||||
logx.Errorf("KeepAlive: %s", err.Error())
|
||||
if err := p.doKeepAlive(); err != nil {
|
||||
logx.Errorf("etcd publisher KeepAlive: %s", err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -105,8 +139,8 @@ func (p *Publisher) keepAliveAsync(cli internal.EtcdClient) error {
|
||||
p.revoke(cli)
|
||||
select {
|
||||
case <-p.resumeChan:
|
||||
if err := p.KeepAlive(); err != nil {
|
||||
logx.Errorf("KeepAlive: %s", err.Error())
|
||||
if err := p.doKeepAlive(); err != nil {
|
||||
logx.Errorf("etcd publisher KeepAlive: %s", err.Error())
|
||||
}
|
||||
return
|
||||
case <-p.quit.Done():
|
||||
@@ -141,7 +175,7 @@ func (p *Publisher) register(client internal.EtcdClient) (clientv3.LeaseID, erro
|
||||
|
||||
func (p *Publisher) revoke(cli internal.EtcdClient) {
|
||||
if _, err := cli.Revoke(cli.Ctx(), p.lease); err != nil {
|
||||
logx.Error(err)
|
||||
logx.Errorf("etcd publisher revoke: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -53,10 +53,11 @@ func TestChunkExecutorFlushInterval(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestChunkExecutorEmpty(t *testing.T) {
|
||||
NewChunkExecutor(func(items []interface{}) {
|
||||
executor := NewChunkExecutor(func(items []interface{}) {
|
||||
assert.Fail(t, "should not called")
|
||||
}, WithChunkBytes(10), WithFlushInterval(time.Millisecond))
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
executor.Wait()
|
||||
}
|
||||
|
||||
func TestChunkExecutorFlush(t *testing.T) {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/timex"
|
||||
)
|
||||
|
||||
@@ -67,6 +68,7 @@ func TestPeriodicalExecutor_QuitGoroutine(t *testing.T) {
|
||||
ticker.Tick()
|
||||
ticker.Wait(time.Millisecond * idleRound)
|
||||
assert.Equal(t, routines, runtime.NumGoroutine())
|
||||
proc.Shutdown()
|
||||
}
|
||||
|
||||
func TestPeriodicalExecutor_Bulk(t *testing.T) {
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
package jsontype
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/globalsign/mgo/bson"
|
||||
)
|
||||
|
||||
// MilliTime represents time.Time that works better with mongodb.
|
||||
type MilliTime struct {
|
||||
time.Time
|
||||
}
|
||||
|
||||
// MarshalJSON marshals mt to json bytes.
|
||||
func (mt MilliTime) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(mt.Milli())
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals data into mt.
|
||||
func (mt *MilliTime) UnmarshalJSON(data []byte) error {
|
||||
var milli int64
|
||||
if err := json.Unmarshal(data, &milli); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mt.Time = time.Unix(0, milli*int64(time.Millisecond))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetBSON returns BSON base on mt.
|
||||
func (mt MilliTime) GetBSON() (interface{}, error) {
|
||||
return mt.Time, nil
|
||||
}
|
||||
|
||||
// SetBSON sets raw into mt.
|
||||
func (mt *MilliTime) SetBSON(raw bson.Raw) error {
|
||||
return raw.Unmarshal(&mt.Time)
|
||||
}
|
||||
|
||||
// Milli returns milliseconds for mt.
|
||||
func (mt MilliTime) Milli() int64 {
|
||||
return mt.UnixNano() / int64(time.Millisecond)
|
||||
}
|
||||
@@ -1,126 +0,0 @@
|
||||
package jsontype
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/globalsign/mgo/bson"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMilliTime_GetBSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tm time.Time
|
||||
}{
|
||||
{
|
||||
name: "now",
|
||||
tm: time.Now(),
|
||||
},
|
||||
{
|
||||
name: "future",
|
||||
tm: time.Now().Add(time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
got, err := MilliTime{test.tm}.GetBSON()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, test.tm, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMilliTime_MarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tm time.Time
|
||||
}{
|
||||
{
|
||||
name: "now",
|
||||
tm: time.Now(),
|
||||
},
|
||||
{
|
||||
name: "future",
|
||||
tm: time.Now().Add(time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
b, err := MilliTime{test.tm}.MarshalJSON()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, strconv.FormatInt(test.tm.UnixNano()/1e6, 10), string(b))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMilliTime_Milli(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tm time.Time
|
||||
}{
|
||||
{
|
||||
name: "now",
|
||||
tm: time.Now(),
|
||||
},
|
||||
{
|
||||
name: "future",
|
||||
tm: time.Now().Add(time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
n := MilliTime{test.tm}.Milli()
|
||||
assert.Equal(t, test.tm.UnixNano()/1e6, n)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMilliTime_UnmarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tm time.Time
|
||||
}{
|
||||
{
|
||||
name: "now",
|
||||
tm: time.Now(),
|
||||
},
|
||||
{
|
||||
name: "future",
|
||||
tm: time.Now().Add(time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var mt MilliTime
|
||||
s := strconv.FormatInt(test.tm.UnixNano()/1e6, 10)
|
||||
err := mt.UnmarshalJSON([]byte(s))
|
||||
assert.Nil(t, err)
|
||||
s1, err := mt.MarshalJSON()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, s, string(s1))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalWithError(t *testing.T) {
|
||||
var mt MilliTime
|
||||
assert.NotNil(t, mt.UnmarshalJSON([]byte("hello")))
|
||||
}
|
||||
|
||||
func TestSetBSON(t *testing.T) {
|
||||
data, err := bson.Marshal(time.Now())
|
||||
assert.Nil(t, err)
|
||||
|
||||
var raw bson.Raw
|
||||
assert.Nil(t, bson.Unmarshal(data, &raw))
|
||||
|
||||
var mt MilliTime
|
||||
assert.Nil(t, mt.SetBSON(raw))
|
||||
assert.NotNil(t, mt.SetBSON(bson.Raw{}))
|
||||
}
|
||||
@@ -29,7 +29,7 @@ func Repr(v interface{}) string {
|
||||
}
|
||||
|
||||
val := reflect.ValueOf(v)
|
||||
if val.Kind() == reflect.Ptr && !val.IsNil() {
|
||||
for val.Kind() == reflect.Ptr && !val.IsNil() {
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package lang
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -110,6 +113,28 @@ func TestRepr(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReprOfValue(t *testing.T) {
|
||||
t.Run("error", func(t *testing.T) {
|
||||
assert.Equal(t, "error", reprOfValue(reflect.ValueOf(errors.New("error"))))
|
||||
})
|
||||
|
||||
t.Run("stringer", func(t *testing.T) {
|
||||
assert.Equal(t, "1.23", reprOfValue(reflect.ValueOf(json.Number("1.23"))))
|
||||
})
|
||||
|
||||
t.Run("int", func(t *testing.T) {
|
||||
assert.Equal(t, "1", reprOfValue(reflect.ValueOf(1)))
|
||||
})
|
||||
|
||||
t.Run("int", func(t *testing.T) {
|
||||
assert.Equal(t, "1", reprOfValue(reflect.ValueOf("1")))
|
||||
})
|
||||
|
||||
t.Run("int", func(t *testing.T) {
|
||||
assert.Equal(t, "1", reprOfValue(reflect.ValueOf(uint(1))))
|
||||
})
|
||||
}
|
||||
|
||||
type mockStringable struct{}
|
||||
|
||||
func (m mockStringable) String() string {
|
||||
|
||||
@@ -27,6 +27,26 @@ func Close() error {
|
||||
return logx.Close()
|
||||
}
|
||||
|
||||
// Debug writes v into access log.
|
||||
func Debug(ctx context.Context, v ...interface{}) {
|
||||
getLogger(ctx).Debug(v...)
|
||||
}
|
||||
|
||||
// Debugf writes v with format into access log.
|
||||
func Debugf(ctx context.Context, format string, v ...interface{}) {
|
||||
getLogger(ctx).Debugf(format, v...)
|
||||
}
|
||||
|
||||
// Debugv writes v into access log with json content.
|
||||
func Debugv(ctx context.Context, v interface{}) {
|
||||
getLogger(ctx).Debugv(v)
|
||||
}
|
||||
|
||||
// Debugw writes msg along with fields into access log.
|
||||
func Debugw(ctx context.Context, msg string, fields ...LogField) {
|
||||
getLogger(ctx).Debugw(msg, fields...)
|
||||
}
|
||||
|
||||
// Error writes v into error log.
|
||||
func Error(ctx context.Context, v ...interface{}) {
|
||||
getLogger(ctx).Error(v...)
|
||||
|
||||
@@ -140,6 +140,54 @@ func TestInfow(t *testing.T) {
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestDebug(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
file, line := getFileLine()
|
||||
Debug(context.Background(), "foo")
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestDebugf(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
file, line := getFileLine()
|
||||
Debugf(context.Background(), "foo %s", "bar")
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestDebugv(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
file, line := getFileLine()
|
||||
Debugv(context.Background(), "foo")
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestDebugw(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
file, line := getFileLine()
|
||||
Debugw(context.Background(), "foo", Field("a", "b"))
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestMust(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
Must(nil)
|
||||
|
||||
@@ -2,15 +2,34 @@ package logx
|
||||
|
||||
// A LogConf is a logging config.
|
||||
type LogConf struct {
|
||||
ServiceName string `json:",optional"`
|
||||
Mode string `json:",default=console,options=[console,file,volume]"`
|
||||
Encoding string `json:",default=json,options=[json,plain]"`
|
||||
TimeFormat string `json:",optional"`
|
||||
Path string `json:",default=logs"`
|
||||
Level string `json:",default=info,options=[debug,info,error,severe]"`
|
||||
Compress bool `json:",optional"`
|
||||
KeepDays int `json:",optional"`
|
||||
StackCooldownMillis int `json:",default=100"`
|
||||
// ServiceName represents the service name.
|
||||
ServiceName string `json:",optional"`
|
||||
// Mode represents the logging mode, default is `console`.
|
||||
// console: log to console.
|
||||
// file: log to file.
|
||||
// volume: used in k8s, prepend the hostname to the log file name.
|
||||
Mode string `json:",default=console,options=[console,file,volume]"`
|
||||
// Encoding represents the encoding type, default is `json`.
|
||||
// json: json encoding.
|
||||
// plain: plain text encoding, typically used in development.
|
||||
Encoding string `json:",default=json,options=[json,plain]"`
|
||||
// TimeFormat represents the time format, default is `2006-01-02T15:04:05.000Z07:00`.
|
||||
TimeFormat string `json:",optional"`
|
||||
// Path represents the log file path, default is `logs`.
|
||||
Path string `json:",default=logs"`
|
||||
// Level represents the log level, default is `info`.
|
||||
Level string `json:",default=info,options=[debug,info,error,severe]"`
|
||||
// MaxContentLength represents the max content bytes, default is no limit.
|
||||
MaxContentLength uint32 `json:",optional"`
|
||||
// Compress represents whether to compress the log file, default is `false`.
|
||||
Compress bool `json:",optional"`
|
||||
// Stdout represents whether to log statistics, default is `true`.
|
||||
Stat bool `json:",default=true"`
|
||||
// KeepDays represents how many days the log files will be kept. Default to keep all files.
|
||||
// Only take effect when Mode is `file` or `volume`, both work when Rotation is `daily` or `size`.
|
||||
KeepDays int `json:",optional"`
|
||||
// StackCooldownMillis represents the cooldown time for stack logging, default is 100ms.
|
||||
StackCooldownMillis int `json:",default=100"`
|
||||
// MaxBackups represents how many backup log files will be kept. 0 means all files will be kept forever.
|
||||
// Only take effect when RotationRuleType is `size`.
|
||||
// Even thougth `MaxBackups` sets 0, log files will still be removed
|
||||
|
||||
@@ -20,6 +20,8 @@ var (
|
||||
timeFormat = "2006-01-02T15:04:05.000Z07:00"
|
||||
logLevel uint32
|
||||
encoding uint32 = jsonEncodingType
|
||||
// maxContentLength is used to truncate the log content, 0 for not truncating.
|
||||
maxContentLength uint32
|
||||
// use uint32 for atomic operations
|
||||
disableLog uint32
|
||||
disableStat uint32
|
||||
@@ -230,10 +232,16 @@ func SetUp(c LogConf) (err error) {
|
||||
setupOnce.Do(func() {
|
||||
setupLogLevel(c)
|
||||
|
||||
if !c.Stat {
|
||||
DisableStat()
|
||||
}
|
||||
|
||||
if len(c.TimeFormat) > 0 {
|
||||
timeFormat = c.TimeFormat
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&maxContentLength, c.MaxContentLength)
|
||||
|
||||
switch c.Encoding {
|
||||
case plainEncoding:
|
||||
atomic.StoreUint32(&encoding, plainEncodingType)
|
||||
|
||||
@@ -529,9 +529,9 @@ func TestSetLevel(t *testing.T) {
|
||||
|
||||
func TestSetLevelTwiceWithMode(t *testing.T) {
|
||||
testModes := []string{
|
||||
"mode",
|
||||
"console",
|
||||
"volumn",
|
||||
"mode",
|
||||
}
|
||||
w := new(mockWriter)
|
||||
old := writer.Swap(w)
|
||||
@@ -791,9 +791,12 @@ func doTestStructedLogConsole(t *testing.T, w *mockWriter, write func(...interfa
|
||||
func testSetLevelTwiceWithMode(t *testing.T, mode string, w *mockWriter) {
|
||||
writer.Store(nil)
|
||||
SetUp(LogConf{
|
||||
Mode: mode,
|
||||
Level: "error",
|
||||
Path: "/dev/null",
|
||||
Mode: mode,
|
||||
Level: "debug",
|
||||
Path: "/dev/null",
|
||||
Encoding: plainEncoding,
|
||||
Stat: false,
|
||||
TimeFormat: time.RFC3339,
|
||||
})
|
||||
SetUp(LogConf{
|
||||
Mode: mode,
|
||||
|
||||
@@ -16,13 +16,13 @@ const (
|
||||
const (
|
||||
jsonEncodingType = iota
|
||||
plainEncodingType
|
||||
|
||||
plainEncoding = "plain"
|
||||
plainEncodingSep = '\t'
|
||||
sizeRotationRule = "size"
|
||||
)
|
||||
|
||||
const (
|
||||
plainEncoding = "plain"
|
||||
plainEncodingSep = '\t'
|
||||
sizeRotationRule = "size"
|
||||
|
||||
accessFilename = "access.log"
|
||||
errorFilename = "error.log"
|
||||
severeFilename = "severe.log"
|
||||
@@ -53,6 +53,7 @@ const (
|
||||
spanKey = "span"
|
||||
timestampKey = "@timestamp"
|
||||
traceKey = "trace"
|
||||
truncatedKey = "truncated"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -60,4 +61,6 @@ var (
|
||||
ErrLogPathNotSet = errors.New("log path must be set")
|
||||
// ErrLogServiceNameNotSet is an error that indicates that the service name is not set.
|
||||
ErrLogServiceNameNotSet = errors.New("log service name must be set")
|
||||
|
||||
truncatedField = Field(truncatedKey, true)
|
||||
)
|
||||
|
||||
@@ -253,11 +253,11 @@ func (n nopWriter) Stack(_ interface{}) {
|
||||
func (n nopWriter) Stat(_ interface{}, _ ...LogField) {
|
||||
}
|
||||
|
||||
func buildFields(fields ...LogField) []string {
|
||||
func buildPlainFields(fields ...LogField) []string {
|
||||
var items []string
|
||||
|
||||
for _, field := range fields {
|
||||
items = append(items, fmt.Sprintf("%s=%v", field.Key, field.Value))
|
||||
items = append(items, fmt.Sprintf("%s=%+v", field.Key, field.Value))
|
||||
}
|
||||
|
||||
return items
|
||||
@@ -269,15 +269,29 @@ func combineGlobalFields(fields []LogField) []LogField {
|
||||
return fields
|
||||
}
|
||||
|
||||
return append(globals.([]LogField), fields...)
|
||||
gf := globals.([]LogField)
|
||||
ret := make([]LogField, 0, len(gf)+len(fields))
|
||||
ret = append(ret, gf...)
|
||||
ret = append(ret, fields...)
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func output(writer io.Writer, level string, val interface{}, fields ...LogField) {
|
||||
// only truncate string content, don't know how to truncate the values of other types.
|
||||
if v, ok := val.(string); ok {
|
||||
maxLen := atomic.LoadUint32(&maxContentLength)
|
||||
if maxLen > 0 && len(v) > int(maxLen) {
|
||||
val = v[:maxLen]
|
||||
fields = append(fields, truncatedField)
|
||||
}
|
||||
}
|
||||
|
||||
fields = combineGlobalFields(fields)
|
||||
|
||||
switch atomic.LoadUint32(&encoding) {
|
||||
case plainEncodingType:
|
||||
writePlainAny(writer, level, val, buildFields(fields...)...)
|
||||
writePlainAny(writer, level, val, buildPlainFields(fields...)...)
|
||||
default:
|
||||
entry := make(logEntry)
|
||||
for _, field := range fields {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -107,7 +108,7 @@ func TestNopWriter(t *testing.T) {
|
||||
w.Stack("foo")
|
||||
w.Stat("foo")
|
||||
w.Slow("foo")
|
||||
w.Close()
|
||||
_ = w.Close()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -157,9 +158,40 @@ func TestWritePlainAny(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func TestLogWithLimitContentLength(t *testing.T) {
|
||||
maxLen := atomic.LoadUint32(&maxContentLength)
|
||||
atomic.StoreUint32(&maxContentLength, 10)
|
||||
|
||||
t.Cleanup(func() {
|
||||
atomic.StoreUint32(&maxContentLength, maxLen)
|
||||
})
|
||||
|
||||
t.Run("alert", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
w := NewWriter(&buf)
|
||||
w.Info("1234567890")
|
||||
var v1 mockedEntry
|
||||
if err := json.Unmarshal(buf.Bytes(), &v1); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.Equal(t, "1234567890", v1.Content)
|
||||
assert.False(t, v1.Truncated)
|
||||
|
||||
buf.Reset()
|
||||
var v2 mockedEntry
|
||||
w.Info("12345678901")
|
||||
if err := json.Unmarshal(buf.Bytes(), &v2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.Equal(t, "1234567890", v2.Content)
|
||||
assert.True(t, v2.Truncated)
|
||||
})
|
||||
}
|
||||
|
||||
type mockedEntry struct {
|
||||
Level string `json:"level"`
|
||||
Content string `json:"content"`
|
||||
Level string `json:"level"`
|
||||
Content string `json:"content"`
|
||||
Truncated bool `json:"truncated"`
|
||||
}
|
||||
|
||||
type easyToCloseWriter struct{}
|
||||
|
||||
@@ -34,7 +34,7 @@ func getJsonUnmarshaler(opts ...UnmarshalOption) *Unmarshaler {
|
||||
}
|
||||
|
||||
func unmarshalJsonBytes(content []byte, v interface{}, unmarshaler *Unmarshaler) error {
|
||||
var m map[string]interface{}
|
||||
var m interface{}
|
||||
if err := jsonx.Unmarshal(content, &m); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -43,7 +43,7 @@ func unmarshalJsonBytes(content []byte, v interface{}, unmarshaler *Unmarshaler)
|
||||
}
|
||||
|
||||
func unmarshalJsonReader(reader io.Reader, v interface{}, unmarshaler *Unmarshaler) error {
|
||||
var m map[string]interface{}
|
||||
var m interface{}
|
||||
if err := jsonx.UnmarshalFromReader(reader, &m); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -856,8 +856,7 @@ func TestUnmarshalBytesError(t *testing.T) {
|
||||
}
|
||||
|
||||
err := UnmarshalJsonBytes([]byte(payload), &v)
|
||||
assert.NotNil(t, err)
|
||||
assert.True(t, strings.Contains(err.Error(), payload))
|
||||
assert.Equal(t, errTypeMismatch, err)
|
||||
}
|
||||
|
||||
func TestUnmarshalReaderError(t *testing.T) {
|
||||
@@ -867,9 +866,7 @@ func TestUnmarshalReaderError(t *testing.T) {
|
||||
Any string
|
||||
}
|
||||
|
||||
err := UnmarshalJsonReader(reader, &v)
|
||||
assert.NotNil(t, err)
|
||||
assert.True(t, strings.Contains(err.Error(), payload))
|
||||
assert.Equal(t, errTypeMismatch, UnmarshalJsonReader(reader, &v))
|
||||
}
|
||||
|
||||
func TestUnmarshalMap(t *testing.T) {
|
||||
@@ -920,3 +917,26 @@ func TestUnmarshalMap(t *testing.T) {
|
||||
assert.Equal(t, "foo", v.Any)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonArray(t *testing.T) {
|
||||
var v []struct {
|
||||
Name string `json:"name"`
|
||||
Age int `json:"age"`
|
||||
}
|
||||
|
||||
body := `[{"name":"kevin", "age": 18}]`
|
||||
assert.NoError(t, UnmarshalJsonBytes([]byte(body), &v))
|
||||
assert.Equal(t, 1, len(v))
|
||||
assert.Equal(t, "kevin", v[0].Name)
|
||||
assert.Equal(t, 18, v[0].Age)
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonBytesError(t *testing.T) {
|
||||
var v []struct {
|
||||
Name string `json:"name"`
|
||||
Age int `json:"age"`
|
||||
}
|
||||
|
||||
assert.Error(t, UnmarshalJsonBytes([]byte((``)), &v))
|
||||
assert.Error(t, UnmarshalJsonReader(strings.NewReader(``), &v))
|
||||
}
|
||||
|
||||
@@ -47,6 +47,7 @@ type (
|
||||
UnmarshalOption func(*unmarshalOptions)
|
||||
|
||||
unmarshalOptions struct {
|
||||
fillDefault bool
|
||||
fromString bool
|
||||
canonicalKey func(key string) string
|
||||
}
|
||||
@@ -71,8 +72,29 @@ func UnmarshalKey(m map[string]interface{}, v interface{}) error {
|
||||
}
|
||||
|
||||
// Unmarshal unmarshals m into v.
|
||||
func (u *Unmarshaler) Unmarshal(m map[string]interface{}, v interface{}) error {
|
||||
return u.UnmarshalValuer(mapValuer(m), v)
|
||||
func (u *Unmarshaler) Unmarshal(i interface{}, v interface{}) error {
|
||||
valueType := reflect.TypeOf(v)
|
||||
if valueType.Kind() != reflect.Ptr {
|
||||
return errValueNotSettable
|
||||
}
|
||||
|
||||
elemType := Deref(valueType)
|
||||
switch iv := i.(type) {
|
||||
case map[string]interface{}:
|
||||
if elemType.Kind() != reflect.Struct {
|
||||
return errTypeMismatch
|
||||
}
|
||||
|
||||
return u.UnmarshalValuer(mapValuer(iv), v)
|
||||
case []interface{}:
|
||||
if elemType.Kind() != reflect.Slice {
|
||||
return errTypeMismatch
|
||||
}
|
||||
|
||||
return u.fillSlice(elemType, reflect.ValueOf(v).Elem(), iv)
|
||||
default:
|
||||
return errUnsupportedType
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalValuer unmarshals m into v.
|
||||
@@ -127,7 +149,6 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map
|
||||
}
|
||||
|
||||
baseType := fieldType.Elem()
|
||||
baseKind := baseType.Kind()
|
||||
dereffedBaseType := Deref(baseType)
|
||||
dereffedBaseKind := dereffedBaseType.Kind()
|
||||
refValue := reflect.ValueOf(mapValue)
|
||||
@@ -156,11 +177,7 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map
|
||||
return err
|
||||
}
|
||||
|
||||
if baseKind == reflect.Ptr {
|
||||
conv.Index(i).Set(target)
|
||||
} else {
|
||||
conv.Index(i).Set(target.Elem())
|
||||
}
|
||||
SetValue(fieldType.Elem(), conv.Index(i), target.Elem())
|
||||
case reflect.Slice:
|
||||
if err := u.fillSlice(dereffedBaseType, conv.Index(i), ithValue); err != nil {
|
||||
return err
|
||||
@@ -214,9 +231,9 @@ func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int,
|
||||
ithVal := slice.Index(index)
|
||||
switch v := value.(type) {
|
||||
case fmt.Stringer:
|
||||
return setValue(baseKind, ithVal, v.String())
|
||||
return setValueFromString(baseKind, ithVal, v.String())
|
||||
case string:
|
||||
return setValue(baseKind, ithVal, v)
|
||||
return setValueFromString(baseKind, ithVal, v)
|
||||
case map[string]interface{}:
|
||||
return u.fillMap(ithVal.Type(), ithVal, value)
|
||||
default:
|
||||
@@ -230,7 +247,7 @@ func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int,
|
||||
|
||||
target := reflect.New(baseType).Elem()
|
||||
target.Set(reflect.ValueOf(value))
|
||||
ithVal.Set(target.Addr())
|
||||
SetValue(ithVal.Type(), ithVal, target)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -274,7 +291,6 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue inter
|
||||
|
||||
refValue := reflect.ValueOf(mapValue)
|
||||
targetValue := reflect.MakeMapWithSize(mapType, refValue.Len())
|
||||
fieldElemKind := elemType.Kind()
|
||||
dereffedElemType := Deref(elemType)
|
||||
dereffedElemKind := dereffedElemType.Kind()
|
||||
|
||||
@@ -301,11 +317,7 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue inter
|
||||
return emptyValue, err
|
||||
}
|
||||
|
||||
if fieldElemKind == reflect.Ptr {
|
||||
targetValue.SetMapIndex(key, target)
|
||||
} else {
|
||||
targetValue.SetMapIndex(key, target.Elem())
|
||||
}
|
||||
SetMapIndexValue(elemType, targetValue, key, target.Elem())
|
||||
case reflect.Map:
|
||||
keythMap, ok := keythData.(map[string]interface{})
|
||||
if !ok {
|
||||
@@ -334,7 +346,7 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue inter
|
||||
targetValue.SetMapIndex(key, reflect.ValueOf(v))
|
||||
case json.Number:
|
||||
target := reflect.New(dereffedElemType)
|
||||
if err := setValue(dereffedElemKind, target.Elem(), v.String()); err != nil {
|
||||
if err := setValueFromString(dereffedElemKind, target.Elem(), v.String()); err != nil {
|
||||
return emptyValue, err
|
||||
}
|
||||
|
||||
@@ -361,6 +373,26 @@ func (u *Unmarshaler) parseOptionsWithContext(field reflect.StructField, m Value
|
||||
return key, nil, nil
|
||||
}
|
||||
|
||||
if u.opts.canonicalKey != nil {
|
||||
key = u.opts.canonicalKey(key)
|
||||
|
||||
if len(options.OptionalDep) > 0 {
|
||||
// need to create a new fieldOption, because the original one is shared through cache.
|
||||
options = &fieldOptions{
|
||||
fieldOptionsWithContext: fieldOptionsWithContext{
|
||||
Inherit: options.Inherit,
|
||||
FromString: options.FromString,
|
||||
Optional: options.Optional,
|
||||
Options: options.Options,
|
||||
Default: options.Default,
|
||||
EnvVar: options.EnvVar,
|
||||
Range: options.Range,
|
||||
},
|
||||
OptionalDep: u.opts.canonicalKey(options.OptionalDep),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
optsWithContext, err := options.toOptionsWithContext(key, m, fullName)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
@@ -376,19 +408,51 @@ func (u *Unmarshaler) processAnonymousField(field reflect.StructField, value ref
|
||||
return err
|
||||
}
|
||||
|
||||
if _, hasValue := getValue(m, key); hasValue {
|
||||
return fmt.Errorf("fields of %s can't be wrapped inside, because it's anonymous", key)
|
||||
}
|
||||
|
||||
if options.optional() {
|
||||
return u.processAnonymousFieldOptional(field.Type, value, key, m, fullName)
|
||||
return u.processAnonymousFieldOptional(field, value, key, m, fullName)
|
||||
}
|
||||
|
||||
return u.processAnonymousFieldRequired(field.Type, value, m, fullName)
|
||||
return u.processAnonymousFieldRequired(field, value, m, fullName)
|
||||
}
|
||||
|
||||
func (u *Unmarshaler) processAnonymousFieldOptional(fieldType reflect.Type, value reflect.Value,
|
||||
func (u *Unmarshaler) processAnonymousFieldOptional(field reflect.StructField, value reflect.Value,
|
||||
key string, m valuerWithParent, fullName string) error {
|
||||
derefedFieldType := Deref(field.Type)
|
||||
|
||||
switch derefedFieldType.Kind() {
|
||||
case reflect.Struct:
|
||||
return u.processAnonymousStructFieldOptional(field.Type, value, key, m, fullName)
|
||||
default:
|
||||
return u.processNamedField(field, value, m, fullName)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *Unmarshaler) processAnonymousFieldRequired(field reflect.StructField, value reflect.Value,
|
||||
m valuerWithParent, fullName string) error {
|
||||
fieldType := field.Type
|
||||
maybeNewValue(fieldType, value)
|
||||
derefedFieldType := Deref(fieldType)
|
||||
indirectValue := reflect.Indirect(value)
|
||||
|
||||
switch derefedFieldType.Kind() {
|
||||
case reflect.Struct:
|
||||
for i := 0; i < derefedFieldType.NumField(); i++ {
|
||||
if err := u.processField(derefedFieldType.Field(i), indirectValue.Field(i),
|
||||
m, fullName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
default:
|
||||
if err := u.processNamedField(field, indirectValue, m, fullName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *Unmarshaler) processAnonymousStructFieldOptional(fieldType reflect.Type,
|
||||
value reflect.Value, key string, m valuerWithParent, fullName string) error {
|
||||
var filled bool
|
||||
var required int
|
||||
var requiredFilled int
|
||||
@@ -428,21 +492,6 @@ func (u *Unmarshaler) processAnonymousFieldOptional(fieldType reflect.Type, valu
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *Unmarshaler) processAnonymousFieldRequired(fieldType reflect.Type, value reflect.Value,
|
||||
m valuerWithParent, fullName string) error {
|
||||
maybeNewValue(fieldType, value)
|
||||
derefedFieldType := Deref(fieldType)
|
||||
indirectValue := reflect.Indirect(value)
|
||||
|
||||
for i := 0; i < derefedFieldType.NumField(); i++ {
|
||||
if err := u.processField(derefedFieldType.Field(i), indirectValue.Field(i), m, fullName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *Unmarshaler) processField(field reflect.StructField, value reflect.Value,
|
||||
m valuerWithParent, fullName string) error {
|
||||
if usingDifferentKeys(u.key, field) {
|
||||
@@ -481,7 +530,7 @@ func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value re
|
||||
case valueKind == reflect.String && typeKind == reflect.Slice:
|
||||
return u.fillSliceFromString(fieldType, value, mapValue)
|
||||
case valueKind == reflect.String && derefedFieldType == durationType:
|
||||
return fillDurationValue(fieldType.Kind(), value, mapValue.(string))
|
||||
return fillDurationValue(fieldType, value, mapValue.(string))
|
||||
default:
|
||||
return u.processFieldPrimitive(fieldType, value, mapValue, opts, fullName)
|
||||
}
|
||||
@@ -517,8 +566,8 @@ func (u *Unmarshaler) processFieldPrimitive(fieldType reflect.Type, value reflec
|
||||
|
||||
func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(fieldType reflect.Type, value reflect.Value,
|
||||
v json.Number, opts *fieldOptionsWithContext, fullName string) error {
|
||||
fieldKind := fieldType.Kind()
|
||||
typeKind := Deref(fieldType).Kind()
|
||||
baseType := Deref(fieldType)
|
||||
typeKind := baseType.Kind()
|
||||
|
||||
if err := validateJsonNumberRange(v, opts); err != nil {
|
||||
return err
|
||||
@@ -528,9 +577,7 @@ func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(fieldType reflect.Type
|
||||
return err
|
||||
}
|
||||
|
||||
if fieldKind == reflect.Ptr {
|
||||
value = value.Elem()
|
||||
}
|
||||
target := reflect.New(Deref(fieldType)).Elem()
|
||||
|
||||
switch typeKind {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
@@ -539,7 +586,7 @@ func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(fieldType reflect.Type
|
||||
return err
|
||||
}
|
||||
|
||||
value.SetInt(iValue)
|
||||
target.SetInt(iValue)
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
iValue, err := v.Int64()
|
||||
if err != nil {
|
||||
@@ -550,18 +597,20 @@ func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(fieldType reflect.Type
|
||||
return fmt.Errorf("unmarshal %q with bad value %q", fullName, v.String())
|
||||
}
|
||||
|
||||
value.SetUint(uint64(iValue))
|
||||
target.SetUint(uint64(iValue))
|
||||
case reflect.Float32, reflect.Float64:
|
||||
fValue, err := v.Float64()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
value.SetFloat(fValue)
|
||||
target.SetFloat(fValue)
|
||||
default:
|
||||
return newTypeMismatchError(fullName)
|
||||
}
|
||||
|
||||
SetValue(fieldType, value, target)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -574,7 +623,7 @@ func (u *Unmarshaler) processFieldStruct(fieldType reflect.Type, value reflect.V
|
||||
return err
|
||||
}
|
||||
|
||||
value.Set(target.Addr())
|
||||
SetValue(fieldType, value, target)
|
||||
} else if err := u.unmarshalWithFullName(m, value.Addr().Interface(), fullName); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -588,7 +637,13 @@ func (u *Unmarshaler) processFieldTextUnmarshaler(fieldType reflect.Type, value
|
||||
var ok bool
|
||||
|
||||
if fieldType.Kind() == reflect.Ptr {
|
||||
tval, ok = value.Interface().(encoding.TextUnmarshaler)
|
||||
if value.Elem().Kind() == reflect.Ptr {
|
||||
target := reflect.New(Deref(fieldType))
|
||||
SetValue(fieldType.Elem(), value, target)
|
||||
tval, ok = target.Interface().(encoding.TextUnmarshaler)
|
||||
} else {
|
||||
tval, ok = value.Interface().(encoding.TextUnmarshaler)
|
||||
}
|
||||
} else {
|
||||
tval, ok = value.Addr().Interface().(encoding.TextUnmarshaler)
|
||||
}
|
||||
@@ -621,7 +676,7 @@ func (u *Unmarshaler) processFieldWithEnvValue(fieldType reflect.Type, value ref
|
||||
value.SetBool(val)
|
||||
return nil
|
||||
case durationType.Kind():
|
||||
if err := fillDurationValue(fieldKind, value, envVal); err != nil {
|
||||
if err := fillDurationValue(fieldType, value, envVal); err != nil {
|
||||
return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
|
||||
}
|
||||
|
||||
@@ -656,7 +711,14 @@ func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect
|
||||
|
||||
valuer := createValuer(m, opts)
|
||||
mapValue, hasValue := getValue(valuer, canonicalKey)
|
||||
if !hasValue {
|
||||
|
||||
// When fillDefault is used, m is a null value, hasValue must be false, all priority judgments fillDefault,
|
||||
if u.opts.fillDefault {
|
||||
if !value.IsZero() {
|
||||
return fmt.Errorf("set the default value, %s must be zero", fullName)
|
||||
}
|
||||
return u.processNamedFieldWithoutValue(field.Type, value, opts, fullName)
|
||||
} else if !hasValue {
|
||||
return u.processNamedFieldWithoutValue(field.Type, value, opts, fullName)
|
||||
}
|
||||
|
||||
@@ -693,47 +755,64 @@ func (u *Unmarshaler) processNamedFieldWithValue(fieldType reflect.Type, value r
|
||||
return u.processFieldNotFromString(fieldType, value, vp, opts, fullName)
|
||||
default:
|
||||
if u.opts.fromString || opts.fromString() {
|
||||
valueKind := reflect.TypeOf(mapValue).Kind()
|
||||
if valueKind != reflect.String {
|
||||
return fmt.Errorf("error: the value in map is not string, but %s", valueKind)
|
||||
}
|
||||
|
||||
options := opts.options()
|
||||
if len(options) > 0 {
|
||||
if !stringx.Contains(options, mapValue.(string)) {
|
||||
return fmt.Errorf(`error: value "%s" for field "%s" is not defined in options "%v"`,
|
||||
mapValue, key, options)
|
||||
}
|
||||
}
|
||||
|
||||
return fillPrimitive(fieldType, value, mapValue, opts, fullName)
|
||||
return u.processNamedFieldWithValueFromString(fieldType, value, mapValue,
|
||||
key, opts, fullName)
|
||||
}
|
||||
|
||||
return u.processFieldNotFromString(fieldType, value, vp, opts, fullName)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *Unmarshaler) processNamedFieldWithValueFromString(fieldType reflect.Type, value reflect.Value,
|
||||
mapValue interface{}, key string, opts *fieldOptionsWithContext, fullName string) error {
|
||||
valueKind := reflect.TypeOf(mapValue).Kind()
|
||||
if valueKind != reflect.String {
|
||||
return fmt.Errorf("the value in map is not string, but %s", valueKind)
|
||||
}
|
||||
|
||||
options := opts.options()
|
||||
if len(options) > 0 {
|
||||
var checkValue string
|
||||
switch mt := mapValue.(type) {
|
||||
case string:
|
||||
checkValue = mt
|
||||
case fmt.Stringer:
|
||||
checkValue = mt.String()
|
||||
default:
|
||||
return fmt.Errorf("the value in map is not string or json.Number, but %s",
|
||||
valueKind.String())
|
||||
}
|
||||
|
||||
if !stringx.Contains(options, checkValue) {
|
||||
return fmt.Errorf(`value "%s" for field "%s" is not defined in options "%v"`,
|
||||
mapValue, key, options)
|
||||
}
|
||||
}
|
||||
|
||||
return fillPrimitive(fieldType, value, mapValue, opts, fullName)
|
||||
}
|
||||
|
||||
func (u *Unmarshaler) processNamedFieldWithoutValue(fieldType reflect.Type, value reflect.Value,
|
||||
opts *fieldOptionsWithContext, fullName string) error {
|
||||
derefedType := Deref(fieldType)
|
||||
fieldKind := derefedType.Kind()
|
||||
if defaultValue, ok := opts.getDefault(); ok {
|
||||
if fieldType.Kind() == reflect.Ptr {
|
||||
maybeNewValue(fieldType, value)
|
||||
value = value.Elem()
|
||||
}
|
||||
if derefedType == durationType {
|
||||
return fillDurationValue(fieldKind, value, defaultValue)
|
||||
return fillDurationValue(fieldType, value, defaultValue)
|
||||
}
|
||||
|
||||
switch fieldKind {
|
||||
case reflect.Array, reflect.Slice:
|
||||
return u.fillSliceWithDefault(derefedType, value, defaultValue)
|
||||
default:
|
||||
return setValue(fieldKind, value, defaultValue)
|
||||
return setValueFromString(fieldKind, value, defaultValue)
|
||||
}
|
||||
}
|
||||
|
||||
if u.opts.fillDefault {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch fieldKind {
|
||||
case reflect.Array, reflect.Map, reflect.Slice:
|
||||
if !opts.optional() {
|
||||
@@ -771,15 +850,27 @@ func (u *Unmarshaler) unmarshalWithFullName(m valuerWithParent, v interface{}, f
|
||||
return err
|
||||
}
|
||||
|
||||
rte := reflect.TypeOf(v).Elem()
|
||||
if rte.Kind() != reflect.Struct {
|
||||
valueType := reflect.TypeOf(v)
|
||||
baseType := Deref(valueType)
|
||||
if baseType.Kind() != reflect.Struct {
|
||||
return errValueNotStruct
|
||||
}
|
||||
|
||||
rve := rv.Elem()
|
||||
numFields := rte.NumField()
|
||||
valElem := rv.Elem()
|
||||
if valElem.Kind() == reflect.Ptr {
|
||||
target := reflect.New(baseType).Elem()
|
||||
SetValue(valueType.Elem(), valElem, target)
|
||||
valElem = target
|
||||
}
|
||||
|
||||
numFields := baseType.NumField()
|
||||
for i := 0; i < numFields; i++ {
|
||||
if err := u.processField(rte.Field(i), rve.Field(i), m, fullName); err != nil {
|
||||
field := baseType.Field(i)
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := u.processField(field, valElem.Field(i), m, fullName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -794,13 +885,20 @@ func WithStringValues() UnmarshalOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithCanonicalKeyFunc customizes an Unmarshaler with Canonical Key func
|
||||
// WithCanonicalKeyFunc customizes an Unmarshaler with Canonical Key func.
|
||||
func WithCanonicalKeyFunc(f func(string) string) UnmarshalOption {
|
||||
return func(opt *unmarshalOptions) {
|
||||
opt.canonicalKey = f
|
||||
}
|
||||
}
|
||||
|
||||
// WithDefault customizes an Unmarshaler with fill default values.
|
||||
func WithDefault() UnmarshalOption {
|
||||
return func(opt *unmarshalOptions) {
|
||||
opt.fillDefault = true
|
||||
}
|
||||
}
|
||||
|
||||
func createValuer(v valuerWithParent, opts *fieldOptionsWithContext) valuerWithParent {
|
||||
if opts.inherit() {
|
||||
return recursiveValuer{
|
||||
@@ -815,17 +913,13 @@ func createValuer(v valuerWithParent, opts *fieldOptionsWithContext) valuerWithP
|
||||
}
|
||||
}
|
||||
|
||||
func fillDurationValue(fieldKind reflect.Kind, value reflect.Value, dur string) error {
|
||||
func fillDurationValue(fieldType reflect.Type, value reflect.Value, dur string) error {
|
||||
d, err := time.ParseDuration(dur)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if fieldKind == reflect.Ptr {
|
||||
value.Elem().Set(reflect.ValueOf(d))
|
||||
} else {
|
||||
value.Set(reflect.ValueOf(d))
|
||||
}
|
||||
SetValue(fieldType, value, reflect.ValueOf(d))
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -841,7 +935,7 @@ func fillPrimitive(fieldType reflect.Type, value reflect.Value, mapValue interfa
|
||||
target := reflect.New(baseType).Elem()
|
||||
switch mapValue.(type) {
|
||||
case string, json.Number:
|
||||
value.Set(target.Addr())
|
||||
SetValue(fieldType, value, target)
|
||||
value = target
|
||||
}
|
||||
}
|
||||
@@ -853,7 +947,7 @@ func fillPrimitive(fieldType reflect.Type, value reflect.Value, mapValue interfa
|
||||
if err := validateJsonNumberRange(v, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
return setValue(baseType.Kind(), value, v.String())
|
||||
return setValueFromString(baseType.Kind(), value, v.String())
|
||||
default:
|
||||
return newTypeMismatchError(fullName)
|
||||
}
|
||||
@@ -873,7 +967,7 @@ func fillWithSameType(fieldType reflect.Type, value reflect.Value, mapValue inte
|
||||
baseType := Deref(fieldType)
|
||||
target := reflect.New(baseType).Elem()
|
||||
setSameKindValue(baseType, target, mapValue)
|
||||
value.Set(target.Addr())
|
||||
SetValue(fieldType, value, target)
|
||||
} else {
|
||||
setSameKindValue(fieldType, value, mapValue)
|
||||
}
|
||||
@@ -934,7 +1028,7 @@ func newInitError(name string) error {
|
||||
}
|
||||
|
||||
func newTypeMismatchError(name string) error {
|
||||
return fmt.Errorf("error: type mismatch for field %s", name)
|
||||
return fmt.Errorf("type mismatch for field %s", name)
|
||||
}
|
||||
|
||||
func readKeys(key string) []string {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -56,7 +56,7 @@ type (
|
||||
|
||||
// Deref dereferences a type, if pointer type, returns its element type.
|
||||
func Deref(t reflect.Type) reflect.Type {
|
||||
if t.Kind() == reflect.Ptr {
|
||||
for t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
|
||||
@@ -68,6 +68,16 @@ func Repr(v interface{}) string {
|
||||
return lang.Repr(v)
|
||||
}
|
||||
|
||||
// SetValue sets target to value, pointers are processed automatically.
|
||||
func SetValue(tp reflect.Type, value, target reflect.Value) {
|
||||
value.Set(convertTypeOfPtr(tp, target))
|
||||
}
|
||||
|
||||
// SetMapIndexValue sets target to value at key position, pointers are processed automatically.
|
||||
func SetMapIndexValue(tp reflect.Type, value, key, target reflect.Value) {
|
||||
value.SetMapIndex(key, convertTypeOfPtr(tp, target))
|
||||
}
|
||||
|
||||
// ValidatePtr validates v if it's a valid pointer.
|
||||
func ValidatePtr(v *reflect.Value) error {
|
||||
// sequence is very important, IsNil must be called after checking Kind() with reflect.Ptr,
|
||||
@@ -79,7 +89,7 @@ func ValidatePtr(v *reflect.Value) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertType(kind reflect.Kind, str string) (interface{}, error) {
|
||||
func convertTypeFromString(kind reflect.Kind, str string) (interface{}, error) {
|
||||
switch kind {
|
||||
case reflect.Bool:
|
||||
switch strings.ToLower(str) {
|
||||
@@ -118,6 +128,23 @@ func convertType(kind reflect.Kind, str string) (interface{}, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func convertTypeOfPtr(tp reflect.Type, target reflect.Value) reflect.Value {
|
||||
// keep the original value is a pointer
|
||||
if tp.Kind() == reflect.Ptr && target.CanAddr() {
|
||||
tp = tp.Elem()
|
||||
target = target.Addr()
|
||||
}
|
||||
|
||||
for tp.Kind() == reflect.Ptr {
|
||||
p := reflect.New(target.Type())
|
||||
p.Elem().Set(target)
|
||||
target = p
|
||||
tp = tp.Elem()
|
||||
}
|
||||
|
||||
return target
|
||||
}
|
||||
|
||||
func doParseKeyAndOptions(field reflect.StructField, value string) (string, *fieldOptions, error) {
|
||||
segments := parseSegments(value)
|
||||
key := strings.TrimSpace(segments[0])
|
||||
@@ -476,13 +503,13 @@ func setMatchedPrimitiveValue(kind reflect.Kind, value reflect.Value, v interfac
|
||||
return nil
|
||||
}
|
||||
|
||||
func setValue(kind reflect.Kind, value reflect.Value, str string) error {
|
||||
func setValueFromString(kind reflect.Kind, value reflect.Value, str string) error {
|
||||
if !value.CanSet() {
|
||||
return errValueNotSettable
|
||||
}
|
||||
|
||||
value = ensureValue(value)
|
||||
v, err := convertType(kind, str)
|
||||
v, err := convertTypeFromString(kind, str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -555,7 +582,7 @@ func validateAndSetValue(kind reflect.Kind, value reflect.Value, str string, opt
|
||||
return errValueNotSettable
|
||||
}
|
||||
|
||||
v, err := convertType(kind, str)
|
||||
v, err := convertTypeFromString(kind, str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -237,7 +237,7 @@ func TestValidatePtrWithZeroValue(t *testing.T) {
|
||||
|
||||
func TestSetValueNotSettable(t *testing.T) {
|
||||
var i int
|
||||
assert.NotNil(t, setValue(reflect.Int, reflect.ValueOf(i), "1"))
|
||||
assert.NotNil(t, setValueFromString(reflect.Int, reflect.ValueOf(i), "1"))
|
||||
}
|
||||
|
||||
func TestParseKeyAndOptionsErrors(t *testing.T) {
|
||||
@@ -290,7 +290,7 @@ func TestSetValueFormatErrors(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.kind.String(), func(t *testing.T) {
|
||||
err := setValue(test.kind, test.target, test.value)
|
||||
err := setValueFromString(test.kind, test.target, test.value)
|
||||
assert.NotEqual(t, errValueNotSettable, err)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
@@ -31,3 +31,27 @@ func TestMapValuerWithInherit_Value(t *testing.T) {
|
||||
assert.Equal(t, "localhost", m["host"])
|
||||
assert.Equal(t, 8080, m["port"])
|
||||
}
|
||||
|
||||
func TestRecursiveValuer_Value(t *testing.T) {
|
||||
input := map[string]interface{}{
|
||||
"component": map[string]interface{}{
|
||||
"name": "test",
|
||||
"foo": map[string]interface{}{
|
||||
"bar": "baz",
|
||||
},
|
||||
},
|
||||
"foo": "value",
|
||||
}
|
||||
valuer := recursiveValuer{
|
||||
current: mapValuer(input["component"].(map[string]interface{})),
|
||||
parent: simpleValuer{
|
||||
current: mapValuer(input),
|
||||
},
|
||||
}
|
||||
|
||||
val, ok := valuer.Value("foo")
|
||||
assert.True(t, ok)
|
||||
assert.EqualValues(t, map[string]interface{}{
|
||||
"bar": "baz",
|
||||
}, val)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/prometheus"
|
||||
)
|
||||
|
||||
@@ -17,6 +18,9 @@ func TestNewCounterVec(t *testing.T) {
|
||||
})
|
||||
defer counterVec.close()
|
||||
counterVecNil := NewCounterVec(nil)
|
||||
counterVec.Inc("path", "code")
|
||||
counterVec.Add(1, "path", "code")
|
||||
proc.Shutdown()
|
||||
assert.NotNil(t, counterVec)
|
||||
assert.Nil(t, counterVecNil)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
)
|
||||
|
||||
func TestNewGaugeVec(t *testing.T) {
|
||||
@@ -18,6 +19,8 @@ func TestNewGaugeVec(t *testing.T) {
|
||||
gaugeVecNil := NewGaugeVec(nil)
|
||||
assert.NotNil(t, gaugeVec)
|
||||
assert.Nil(t, gaugeVecNil)
|
||||
|
||||
proc.Shutdown()
|
||||
}
|
||||
|
||||
func TestGaugeInc(t *testing.T) {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
)
|
||||
|
||||
func TestNewHistogramVec(t *testing.T) {
|
||||
@@ -47,4 +48,6 @@ func TestHistogramObserve(t *testing.T) {
|
||||
|
||||
err := testutil.CollectAndCompare(hv.histogram, strings.NewReader(metadata+val))
|
||||
assert.Nil(t, err)
|
||||
|
||||
proc.Shutdown()
|
||||
}
|
||||
|
||||
@@ -15,5 +15,14 @@ func AddWrapUpListener(fn func()) func() {
|
||||
return fn
|
||||
}
|
||||
|
||||
// SetTimeToForceQuit does nothing on windows.
|
||||
func SetTimeToForceQuit(duration time.Duration) {
|
||||
}
|
||||
|
||||
// Shutdown does nothing on windows.
|
||||
func Shutdown() {
|
||||
}
|
||||
|
||||
// WrapUp does nothing on windows.
|
||||
func WrapUp() {
|
||||
}
|
||||
|
||||
@@ -43,6 +43,16 @@ func SetTimeToForceQuit(duration time.Duration) {
|
||||
delayTimeBeforeForceQuit = duration
|
||||
}
|
||||
|
||||
// Shutdown calls the registered shutdown listeners, only for test purpose.
|
||||
func Shutdown() {
|
||||
shutdownListeners.notifyListeners()
|
||||
}
|
||||
|
||||
// WrapUp wraps up the process, only for test purpose.
|
||||
func WrapUp() {
|
||||
wrapUpListeners.notifyListeners()
|
||||
}
|
||||
|
||||
func gracefulStop(signals chan os.Signal) {
|
||||
signal.Stop(signals)
|
||||
|
||||
|
||||
@@ -18,14 +18,14 @@ func TestShutdown(t *testing.T) {
|
||||
called := AddWrapUpListener(func() {
|
||||
val++
|
||||
})
|
||||
wrapUpListeners.notifyListeners()
|
||||
WrapUp()
|
||||
called()
|
||||
assert.Equal(t, 1, val)
|
||||
|
||||
called = AddShutdownListener(func() {
|
||||
val += 2
|
||||
})
|
||||
shutdownListeners.notifyListeners()
|
||||
Shutdown()
|
||||
called()
|
||||
assert.Equal(t, 3, val)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
|
||||
@@ -16,3 +17,15 @@ func TestServiceConf(t *testing.T) {
|
||||
}
|
||||
c.MustSetUp()
|
||||
}
|
||||
|
||||
func TestServiceConfWithMetricsUrl(t *testing.T) {
|
||||
c := ServiceConf{
|
||||
Name: "foo",
|
||||
Log: logx.LogConf{
|
||||
Mode: "volume",
|
||||
},
|
||||
Mode: "dev",
|
||||
MetricsUrl: "http://localhost:8080",
|
||||
}
|
||||
assert.NoError(t, c.SetUp())
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -55,6 +56,7 @@ func TestServiceGroup(t *testing.T) {
|
||||
}
|
||||
|
||||
group.Stop()
|
||||
proc.Shutdown()
|
||||
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
@@ -45,9 +45,13 @@ func RawFieldNames(in interface{}, postgresSql ...bool) []string {
|
||||
// `db:"id"`
|
||||
// `db:"id,type=char,length=16"`
|
||||
// `db:",type=char,length=16"`
|
||||
// `db:"-,type=char,length=16"`
|
||||
if strings.Contains(tagv, ",") {
|
||||
tagv = strings.TrimSpace(strings.Split(tagv, ",")[0])
|
||||
}
|
||||
if tagv == "-" {
|
||||
continue
|
||||
}
|
||||
if len(tagv) == 0 {
|
||||
tagv = fi.Name
|
||||
}
|
||||
|
||||
@@ -39,3 +39,33 @@ func TestFieldNamesWithTagOptions(t *testing.T) {
|
||||
assert.Equal(t, expected, out)
|
||||
})
|
||||
}
|
||||
|
||||
type mockedUserWithDashTag struct {
|
||||
ID string `db:"id" json:"id,omitempty"`
|
||||
UserName string `db:"user_name" json:"userName,omitempty"`
|
||||
Mobile string `db:"-" json:"mobile,omitempty"`
|
||||
}
|
||||
|
||||
func TestFieldNamesWithDashTag(t *testing.T) {
|
||||
t.Run("new", func(t *testing.T) {
|
||||
var u mockedUserWithDashTag
|
||||
out := RawFieldNames(&u)
|
||||
expected := []string{"`id`", "`user_name`"}
|
||||
assert.Equal(t, expected, out)
|
||||
})
|
||||
}
|
||||
|
||||
type mockedUserWithDashTagAndOptions struct {
|
||||
ID string `db:"id" json:"id,omitempty"`
|
||||
UserName string `db:"user_name,type=varchar,length=255" json:"userName,omitempty"`
|
||||
Mobile string `db:"-,type=varchar,length=255" json:"mobile,omitempty"`
|
||||
}
|
||||
|
||||
func TestFieldNamesWithDashTagAndOptions(t *testing.T) {
|
||||
t.Run("new", func(t *testing.T) {
|
||||
var u mockedUserWithDashTagAndOptions
|
||||
out := RawFieldNames(&u)
|
||||
expected := []string{"`id`", "`user_name`"}
|
||||
assert.Equal(t, expected, out)
|
||||
})
|
||||
}
|
||||
|
||||
5
core/stores/cache/cache.go
vendored
5
core/stores/cache/cache.go
vendored
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/zeromicro/go-zero/core/errorx"
|
||||
"github.com/zeromicro/go-zero/core/hash"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
)
|
||||
|
||||
@@ -62,12 +63,12 @@ func New(c ClusterConf, barrier syncx.SingleFlight, st *Stat, errNotFound error,
|
||||
}
|
||||
|
||||
if len(c) == 1 {
|
||||
return NewNode(c[0].NewRedis(), barrier, st, errNotFound, opts...)
|
||||
return NewNode(redis.MustNewRedis(c[0].RedisConf), barrier, st, errNotFound, opts...)
|
||||
}
|
||||
|
||||
dispatcher := hash.NewConsistentHash()
|
||||
for _, node := range c {
|
||||
cn := NewNode(node.NewRedis(), barrier, st, errNotFound, opts...)
|
||||
cn := NewNode(redis.MustNewRedis(node.RedisConf), barrier, st, errNotFound, opts...)
|
||||
dispatcher.AddWithWeight(cn, node.Weight)
|
||||
}
|
||||
|
||||
|
||||
119
core/stores/cache/cache_test.go
vendored
119
core/stores/cache/cache_test.go
vendored
@@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/errorx"
|
||||
"github.com/zeromicro/go-zero/core/hash"
|
||||
@@ -109,51 +110,85 @@ func (mc *mockedNode) TakeWithExpireCtx(ctx context.Context, val interface{}, ke
|
||||
}
|
||||
|
||||
func TestCache_SetDel(t *testing.T) {
|
||||
const total = 1000
|
||||
r1, clean1, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean1()
|
||||
r2, clean2, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean2()
|
||||
conf := ClusterConf{
|
||||
{
|
||||
RedisConf: redis.RedisConf{
|
||||
Host: r1.Addr,
|
||||
Type: redis.NodeType,
|
||||
t.Run("test set del", func(t *testing.T) {
|
||||
const total = 1000
|
||||
r1, clean1, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean1()
|
||||
r2, clean2, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean2()
|
||||
conf := ClusterConf{
|
||||
{
|
||||
RedisConf: redis.RedisConf{
|
||||
Host: r1.Addr,
|
||||
Type: redis.NodeType,
|
||||
},
|
||||
Weight: 100,
|
||||
},
|
||||
Weight: 100,
|
||||
},
|
||||
{
|
||||
RedisConf: redis.RedisConf{
|
||||
Host: r2.Addr,
|
||||
Type: redis.NodeType,
|
||||
{
|
||||
RedisConf: redis.RedisConf{
|
||||
Host: r2.Addr,
|
||||
Type: redis.NodeType,
|
||||
},
|
||||
Weight: 100,
|
||||
},
|
||||
Weight: 100,
|
||||
},
|
||||
}
|
||||
c := New(conf, syncx.NewSingleFlight(), NewStat("mock"), errPlaceholder)
|
||||
for i := 0; i < total; i++ {
|
||||
if i%2 == 0 {
|
||||
assert.Nil(t, c.Set(fmt.Sprintf("key/%d", i), i))
|
||||
} else {
|
||||
assert.Nil(t, c.SetWithExpire(fmt.Sprintf("key/%d", i), i, 0))
|
||||
}
|
||||
}
|
||||
for i := 0; i < total; i++ {
|
||||
var val int
|
||||
assert.Nil(t, c.Get(fmt.Sprintf("key/%d", i), &val))
|
||||
assert.Equal(t, i, val)
|
||||
}
|
||||
assert.Nil(t, c.Del())
|
||||
for i := 0; i < total; i++ {
|
||||
assert.Nil(t, c.Del(fmt.Sprintf("key/%d", i)))
|
||||
}
|
||||
for i := 0; i < total; i++ {
|
||||
var val int
|
||||
assert.True(t, c.IsNotFound(c.Get(fmt.Sprintf("key/%d", i), &val)))
|
||||
assert.Equal(t, 0, val)
|
||||
}
|
||||
c := New(conf, syncx.NewSingleFlight(), NewStat("mock"), errPlaceholder)
|
||||
for i := 0; i < total; i++ {
|
||||
if i%2 == 0 {
|
||||
assert.Nil(t, c.Set(fmt.Sprintf("key/%d", i), i))
|
||||
} else {
|
||||
assert.Nil(t, c.SetWithExpire(fmt.Sprintf("key/%d", i), i, 0))
|
||||
}
|
||||
}
|
||||
for i := 0; i < total; i++ {
|
||||
var val int
|
||||
assert.Nil(t, c.Get(fmt.Sprintf("key/%d", i), &val))
|
||||
assert.Equal(t, i, val)
|
||||
}
|
||||
assert.Nil(t, c.Del())
|
||||
for i := 0; i < total; i++ {
|
||||
assert.Nil(t, c.Del(fmt.Sprintf("key/%d", i)))
|
||||
}
|
||||
assert.Nil(t, c.Del("a", "b", "c"))
|
||||
for i := 0; i < total; i++ {
|
||||
var val int
|
||||
assert.True(t, c.IsNotFound(c.Get(fmt.Sprintf("key/%d", i), &val)))
|
||||
assert.Equal(t, 0, val)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("test set del error", func(t *testing.T) {
|
||||
r1, err := miniredis.Run()
|
||||
assert.NoError(t, err)
|
||||
defer r1.Close()
|
||||
|
||||
r2, err := miniredis.Run()
|
||||
assert.NoError(t, err)
|
||||
defer r2.Close()
|
||||
|
||||
conf := ClusterConf{
|
||||
{
|
||||
RedisConf: redis.RedisConf{
|
||||
Host: r1.Addr(),
|
||||
Type: redis.NodeType,
|
||||
},
|
||||
Weight: 100,
|
||||
},
|
||||
{
|
||||
RedisConf: redis.RedisConf{
|
||||
Host: r2.Addr(),
|
||||
Type: redis.NodeType,
|
||||
},
|
||||
Weight: 100,
|
||||
},
|
||||
}
|
||||
c := New(conf, syncx.NewSingleFlight(), NewStat("mock"), errPlaceholder)
|
||||
r1.SetError("mock error")
|
||||
r2.SetError("mock error")
|
||||
assert.NoError(t, c.Del("a", "b", "c"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestCache_OneNode(t *testing.T) {
|
||||
|
||||
3
core/stores/cache/cachenode.go
vendored
3
core/stores/cache/cachenode.go
vendored
@@ -277,5 +277,6 @@ func (c cacheNode) processCache(ctx context.Context, key, data string, v interfa
|
||||
|
||||
func (c cacheNode) setCacheWithNotFound(ctx context.Context, key string) error {
|
||||
seconds := int(math.Ceil(c.aroundDuration(c.notFoundExpiry).Seconds()))
|
||||
return c.rds.SetexCtx(ctx, key, notFoundPlaceholder, seconds)
|
||||
_, err := c.rds.SetnxExCtx(ctx, key, notFoundPlaceholder, seconds)
|
||||
return err
|
||||
}
|
||||
|
||||
114
core/stores/cache/cachenode_test.go
vendored
114
core/stores/cache/cachenode_test.go
vendored
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -11,12 +12,14 @@ import (
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/collection"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/mathx"
|
||||
"github.com/zeromicro/go-zero/core/stat"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis/redistest"
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
"github.com/zeromicro/go-zero/core/timex"
|
||||
)
|
||||
|
||||
var errTestNotFound = errors.New("not found")
|
||||
@@ -27,27 +30,54 @@ func init() {
|
||||
}
|
||||
|
||||
func TestCacheNode_DelCache(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
store.Type = redis.ClusterType
|
||||
defer clean()
|
||||
t.Run("del cache", func(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
store.Type = redis.ClusterType
|
||||
defer clean()
|
||||
|
||||
cn := cacheNode{
|
||||
rds: store,
|
||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
lock: new(sync.Mutex),
|
||||
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||
stat: NewStat("any"),
|
||||
errNotFound: errTestNotFound,
|
||||
}
|
||||
assert.Nil(t, cn.Del())
|
||||
assert.Nil(t, cn.Del([]string{}...))
|
||||
assert.Nil(t, cn.Del(make([]string, 0)...))
|
||||
cn.Set("first", "one")
|
||||
assert.Nil(t, cn.Del("first"))
|
||||
cn.Set("first", "one")
|
||||
cn.Set("second", "two")
|
||||
assert.Nil(t, cn.Del("first", "second"))
|
||||
cn := cacheNode{
|
||||
rds: store,
|
||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
lock: new(sync.Mutex),
|
||||
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||
stat: NewStat("any"),
|
||||
errNotFound: errTestNotFound,
|
||||
}
|
||||
assert.Nil(t, cn.Del())
|
||||
assert.Nil(t, cn.Del([]string{}...))
|
||||
assert.Nil(t, cn.Del(make([]string, 0)...))
|
||||
cn.Set("first", "one")
|
||||
assert.Nil(t, cn.Del("first"))
|
||||
cn.Set("first", "one")
|
||||
cn.Set("second", "two")
|
||||
assert.Nil(t, cn.Del("first", "second"))
|
||||
})
|
||||
|
||||
t.Run("del cache with errors", func(t *testing.T) {
|
||||
old := timingWheel
|
||||
ticker := timex.NewFakeTicker()
|
||||
var err error
|
||||
timingWheel, err = collection.NewTimingWheelWithTicker(
|
||||
time.Millisecond, timingWheelSlots, func(key, value interface{}) {
|
||||
clean(key, value)
|
||||
}, ticker)
|
||||
assert.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
timingWheel = old
|
||||
})
|
||||
|
||||
r, err := miniredis.Run()
|
||||
assert.NoError(t, err)
|
||||
defer r.Close()
|
||||
r.SetError("mock error")
|
||||
|
||||
node := NewNode(redis.New(r.Addr(), redis.Cluster()), syncx.NewSingleFlight(),
|
||||
NewStat("any"), errTestNotFound)
|
||||
assert.NoError(t, node.Del("foo", "bar"))
|
||||
ticker.Tick()
|
||||
runtime.Gosched()
|
||||
})
|
||||
}
|
||||
|
||||
func TestCacheNode_DelCacheWithErrors(t *testing.T) {
|
||||
@@ -125,6 +155,21 @@ func TestCacheNode_Take(t *testing.T) {
|
||||
assert.Equal(t, `"value"`, val)
|
||||
}
|
||||
|
||||
func TestCacheNode_TakeBadRedis(t *testing.T) {
|
||||
r, err := miniredis.Run()
|
||||
assert.NoError(t, err)
|
||||
defer r.Close()
|
||||
r.SetError("mock error")
|
||||
|
||||
cn := NewNode(redis.New(r.Addr()), syncx.NewSingleFlight(), NewStat("any"),
|
||||
errTestNotFound, WithExpiry(time.Second), WithNotFoundExpiry(time.Second))
|
||||
var str string
|
||||
assert.Error(t, cn.Take(&str, "any", func(v interface{}) error {
|
||||
*v.(*string) = "value"
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
func TestCacheNode_TakeNotFound(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
@@ -164,6 +209,35 @@ func TestCacheNode_TakeNotFound(t *testing.T) {
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCacheNode_TakeNotFoundButChangedByOthers(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.NoError(t, err)
|
||||
defer clean()
|
||||
|
||||
cn := cacheNode{
|
||||
rds: store,
|
||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
barrier: syncx.NewSingleFlight(),
|
||||
lock: new(sync.Mutex),
|
||||
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||
stat: NewStat("any"),
|
||||
errNotFound: errTestNotFound,
|
||||
}
|
||||
|
||||
var str string
|
||||
err = cn.Take(&str, "any", func(v interface{}) error {
|
||||
store.Set("any", "foo")
|
||||
return errTestNotFound
|
||||
})
|
||||
assert.True(t, cn.IsNotFound(err))
|
||||
|
||||
val, err := store.Get("any")
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, "foo", val)
|
||||
}
|
||||
assert.True(t, cn.IsNotFound(cn.Get("any", &str)))
|
||||
}
|
||||
|
||||
func TestCacheNode_TakeWithExpire(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
|
||||
16
core/stores/cache/cachestat.go
vendored
16
core/stores/cache/cachestat.go
vendored
@@ -5,6 +5,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/timex"
|
||||
)
|
||||
|
||||
const statInterval = time.Minute
|
||||
@@ -25,7 +26,13 @@ func NewStat(name string) *Stat {
|
||||
ret := &Stat{
|
||||
name: name,
|
||||
}
|
||||
go ret.statLoop()
|
||||
|
||||
go func() {
|
||||
ticker := timex.NewTicker(statInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
ret.statLoop(ticker)
|
||||
}()
|
||||
|
||||
return ret
|
||||
}
|
||||
@@ -50,11 +57,8 @@ func (s *Stat) IncrementDbFails() {
|
||||
atomic.AddUint64(&s.DbFails, 1)
|
||||
}
|
||||
|
||||
func (s *Stat) statLoop() {
|
||||
ticker := time.NewTicker(statInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
func (s *Stat) statLoop(ticker timex.Ticker) {
|
||||
for range ticker.Chan() {
|
||||
total := atomic.SwapUint64(&s.Total, 0)
|
||||
if total == 0 {
|
||||
continue
|
||||
|
||||
28
core/stores/cache/cachestat_test.go
vendored
Normal file
28
core/stores/cache/cachestat_test.go
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/timex"
|
||||
)
|
||||
|
||||
func TestCacheStat_statLoop(t *testing.T) {
|
||||
t.Run("stat loop total 0", func(t *testing.T) {
|
||||
var stat Stat
|
||||
ticker := timex.NewFakeTicker()
|
||||
go stat.statLoop(ticker)
|
||||
ticker.Tick()
|
||||
ticker.Tick()
|
||||
ticker.Stop()
|
||||
})
|
||||
|
||||
t.Run("stat loop total not 0", func(t *testing.T) {
|
||||
var stat Stat
|
||||
stat.IncrementTotal()
|
||||
ticker := timex.NewFakeTicker()
|
||||
go stat.statLoop(ticker)
|
||||
ticker.Tick()
|
||||
ticker.Tick()
|
||||
ticker.Stop()
|
||||
})
|
||||
}
|
||||
2
core/stores/cache/cleaner_test.go
vendored
2
core/stores/cache/cleaner_test.go
vendored
@@ -5,6 +5,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
)
|
||||
|
||||
func TestNextDelay(t *testing.T) {
|
||||
@@ -51,6 +52,7 @@ func TestNextDelay(t *testing.T) {
|
||||
next, ok := nextDelay(test.input)
|
||||
assert.Equal(t, test.ok, ok)
|
||||
assert.Equal(t, test.output, next)
|
||||
proc.Shutdown()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -164,7 +164,7 @@ func NewStore(c KvConf) Store {
|
||||
// because Store and redis.Redis has different methods.
|
||||
dispatcher := hash.NewConsistentHash()
|
||||
for _, node := range c {
|
||||
cn := node.NewRedis()
|
||||
cn := redis.MustNewRedis(node.RedisConf)
|
||||
dispatcher.AddWithWeight(cn, node.Weight)
|
||||
}
|
||||
|
||||
|
||||
@@ -26,9 +26,14 @@ type (
|
||||
)
|
||||
|
||||
// NewBulkInserter returns a BulkInserter.
|
||||
func NewBulkInserter(coll *mongo.Collection, interval ...time.Duration) *BulkInserter {
|
||||
func NewBulkInserter(coll Collection, interval ...time.Duration) (*BulkInserter, error) {
|
||||
cloneColl, err := coll.Clone()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inserter := &dbInserter{
|
||||
collection: coll,
|
||||
collection: cloneColl,
|
||||
}
|
||||
|
||||
duration := flushInterval
|
||||
@@ -39,7 +44,7 @@ func NewBulkInserter(coll *mongo.Collection, interval ...time.Duration) *BulkIns
|
||||
return &BulkInserter{
|
||||
executor: executors.NewPeriodicalExecutor(duration, inserter),
|
||||
inserter: inserter,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Flush flushes the inserter, writes all pending records.
|
||||
|
||||
@@ -15,7 +15,8 @@ func TestBulkInserter(t *testing.T) {
|
||||
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "ok", Value: 1}}...))
|
||||
bulk := NewBulkInserter(mt.Coll)
|
||||
bulk, err := NewBulkInserter(createModel(mt).Collection)
|
||||
assert.Equal(t, err, nil)
|
||||
bulk.SetResultHandler(func(result *mongo.InsertManyResult, err error) {
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, len(result.InsertedIDs))
|
||||
|
||||
@@ -3,15 +3,12 @@ package mon
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
mopt "go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
const defaultTimeout = time.Second
|
||||
|
||||
var clientManager = syncx.NewResourceManager()
|
||||
|
||||
// ClosableClient wraps *mongo.Client and provides a Close method.
|
||||
@@ -30,9 +27,20 @@ func Inject(key string, client *mongo.Client) {
|
||||
clientManager.Inject(key, &ClosableClient{client})
|
||||
}
|
||||
|
||||
func getClient(url string) (*mongo.Client, error) {
|
||||
func getClient(url string, opts ...Option) (*mongo.Client, error) {
|
||||
val, err := clientManager.GetResource(url, func() (io.Closer, error) {
|
||||
cli, err := mongo.Connect(context.Background(), mopt.Client().ApplyURI(url))
|
||||
o := mopt.Client().ApplyURI(url)
|
||||
opts = append([]Option{defaultTimeoutOption()}, opts...)
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
}
|
||||
|
||||
cli, err := mongo.Connect(context.Background(), o)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = cli.Ping(context.Background(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ func MustNewModel(uri, db, collection string, opts ...Option) *Model {
|
||||
|
||||
// NewModel returns a Model.
|
||||
func NewModel(uri, db, collection string, opts ...Option) (*Model, error) {
|
||||
cli, err := getClient(uri)
|
||||
cli, err := getClient(uri, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -4,14 +4,15 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
mopt "go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
const defaultTimeout = time.Second * 3
|
||||
|
||||
var slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold)
|
||||
|
||||
type (
|
||||
options struct {
|
||||
timeout time.Duration
|
||||
}
|
||||
options = mopt.ClientOptions
|
||||
|
||||
// Option defines the method to customize a mongo model.
|
||||
Option func(opts *options)
|
||||
@@ -22,8 +23,15 @@ func SetSlowThreshold(threshold time.Duration) {
|
||||
slowThreshold.Set(threshold)
|
||||
}
|
||||
|
||||
func defaultOptions() *options {
|
||||
return &options{
|
||||
timeout: defaultTimeout,
|
||||
func defaultTimeoutOption() Option {
|
||||
return func(opts *options) {
|
||||
opts.SetTimeout(defaultTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
// WithTimeout set the mon client operation timeout.
|
||||
func WithTimeout(timeout time.Duration) Option {
|
||||
return func(opts *options) {
|
||||
opts.SetTimeout(timeout)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
mopt "go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
func TestSetSlowThreshold(t *testing.T) {
|
||||
@@ -13,6 +14,14 @@ func TestSetSlowThreshold(t *testing.T) {
|
||||
assert.Equal(t, time.Second, slowThreshold.Load())
|
||||
}
|
||||
|
||||
func TestDefaultOptions(t *testing.T) {
|
||||
assert.Equal(t, defaultTimeout, defaultOptions().timeout)
|
||||
func Test_defaultTimeoutOption(t *testing.T) {
|
||||
opts := mopt.Client()
|
||||
defaultTimeoutOption()(opts)
|
||||
assert.Equal(t, defaultTimeout, *opts.Timeout)
|
||||
}
|
||||
|
||||
func TestWithTimeout(t *testing.T) {
|
||||
opts := mopt.Client()
|
||||
WithTimeout(time.Second)(opts)
|
||||
assert.Equal(t, time.Second, *opts.Timeout)
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
|
||||
"github.com/zeromicro/go-zero/core/trace"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
@@ -14,12 +13,10 @@ import (
|
||||
var mongoCmdAttributeKey = attribute.Key("mongo.cmd")
|
||||
|
||||
func startSpan(ctx context.Context, cmd string) (context.Context, oteltrace.Span) {
|
||||
tracer := otel.GetTracerProvider().Tracer(trace.TraceName)
|
||||
ctx, span := tracer.Start(ctx,
|
||||
spanName,
|
||||
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
|
||||
)
|
||||
tracer := trace.TracerFromContext(ctx)
|
||||
ctx, span := tracer.Start(ctx, spanName, oteltrace.WithSpanKind(oteltrace.SpanKindClient))
|
||||
span.SetAttributes(mongoCmdAttributeKey.String(cmd))
|
||||
|
||||
return ctx, span
|
||||
}
|
||||
|
||||
|
||||
@@ -1,92 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/zeromicro/go-zero/core/executors"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
|
||||
const (
|
||||
flushInterval = time.Second
|
||||
maxBulkRows = 1000
|
||||
)
|
||||
|
||||
type (
|
||||
// ResultHandler is a handler that used to handle results.
|
||||
ResultHandler func(*mgo.BulkResult, error)
|
||||
|
||||
// A BulkInserter is used to insert bulk of mongo records.
|
||||
BulkInserter struct {
|
||||
executor *executors.PeriodicalExecutor
|
||||
inserter *dbInserter
|
||||
}
|
||||
)
|
||||
|
||||
// NewBulkInserter returns a BulkInserter.
|
||||
func NewBulkInserter(session *mgo.Session, dbName string, collectionNamer func() string) *BulkInserter {
|
||||
inserter := &dbInserter{
|
||||
session: session,
|
||||
dbName: dbName,
|
||||
collectionNamer: collectionNamer,
|
||||
}
|
||||
|
||||
return &BulkInserter{
|
||||
executor: executors.NewPeriodicalExecutor(flushInterval, inserter),
|
||||
inserter: inserter,
|
||||
}
|
||||
}
|
||||
|
||||
// Flush flushes the inserter, writes all pending records.
|
||||
func (bi *BulkInserter) Flush() {
|
||||
bi.executor.Flush()
|
||||
}
|
||||
|
||||
// Insert inserts doc.
|
||||
func (bi *BulkInserter) Insert(doc interface{}) {
|
||||
bi.executor.Add(doc)
|
||||
}
|
||||
|
||||
// SetResultHandler sets the result handler.
|
||||
func (bi *BulkInserter) SetResultHandler(handler ResultHandler) {
|
||||
bi.executor.Sync(func() {
|
||||
bi.inserter.resultHandler = handler
|
||||
})
|
||||
}
|
||||
|
||||
type dbInserter struct {
|
||||
session *mgo.Session
|
||||
dbName string
|
||||
collectionNamer func() string
|
||||
documents []interface{}
|
||||
resultHandler ResultHandler
|
||||
}
|
||||
|
||||
func (in *dbInserter) AddTask(doc interface{}) bool {
|
||||
in.documents = append(in.documents, doc)
|
||||
return len(in.documents) >= maxBulkRows
|
||||
}
|
||||
|
||||
func (in *dbInserter) Execute(objs interface{}) {
|
||||
docs := objs.([]interface{})
|
||||
if len(docs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
bulk := in.session.DB(in.dbName).C(in.collectionNamer()).Bulk()
|
||||
bulk.Insert(docs...)
|
||||
bulk.Unordered()
|
||||
result, err := bulk.Run()
|
||||
if in.resultHandler != nil {
|
||||
in.resultHandler(result, err)
|
||||
} else if err != nil {
|
||||
logx.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (in *dbInserter) RemoveAll() interface{} {
|
||||
documents := in.documents
|
||||
in.documents = nil
|
||||
return documents
|
||||
}
|
||||
@@ -1,242 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/stores/mongo/internal"
|
||||
"github.com/zeromicro/go-zero/core/timex"
|
||||
)
|
||||
|
||||
const defaultSlowThreshold = time.Millisecond * 500
|
||||
|
||||
// ErrNotFound is an alias of mgo.ErrNotFound.
|
||||
var ErrNotFound = mgo.ErrNotFound
|
||||
|
||||
type (
|
||||
// Collection interface represents a mongo connection.
|
||||
Collection interface {
|
||||
Find(query interface{}) Query
|
||||
FindId(id interface{}) Query
|
||||
Insert(docs ...interface{}) error
|
||||
Pipe(pipeline interface{}) Pipe
|
||||
Remove(selector interface{}) error
|
||||
RemoveAll(selector interface{}) (*mgo.ChangeInfo, error)
|
||||
RemoveId(id interface{}) error
|
||||
Update(selector, update interface{}) error
|
||||
UpdateId(id, update interface{}) error
|
||||
Upsert(selector, update interface{}) (*mgo.ChangeInfo, error)
|
||||
}
|
||||
|
||||
decoratedCollection struct {
|
||||
name string
|
||||
collection internal.MgoCollection
|
||||
brk breaker.Breaker
|
||||
}
|
||||
|
||||
keepablePromise struct {
|
||||
promise breaker.Promise
|
||||
log func(error)
|
||||
}
|
||||
)
|
||||
|
||||
func newCollection(collection *mgo.Collection, brk breaker.Breaker) Collection {
|
||||
return &decoratedCollection{
|
||||
name: collection.FullName,
|
||||
collection: collection,
|
||||
brk: brk,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Find(query interface{}) Query {
|
||||
promise, err := c.brk.Allow()
|
||||
if err != nil {
|
||||
return rejectedQuery{}
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
return promisedQuery{
|
||||
Query: c.collection.Find(query),
|
||||
promise: keepablePromise{
|
||||
promise: promise,
|
||||
log: func(err error) {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("find", duration, err, query)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) FindId(id interface{}) Query {
|
||||
promise, err := c.brk.Allow()
|
||||
if err != nil {
|
||||
return rejectedQuery{}
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
return promisedQuery{
|
||||
Query: c.collection.FindId(id),
|
||||
promise: keepablePromise{
|
||||
promise: promise,
|
||||
log: func(err error) {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("findId", duration, err, id)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Insert(docs ...interface{}) (err error) {
|
||||
return c.brk.DoWithAcceptable(func() error {
|
||||
startTime := timex.Now()
|
||||
defer func() {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("insert", duration, err, docs...)
|
||||
}()
|
||||
|
||||
return c.collection.Insert(docs...)
|
||||
}, acceptable)
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Pipe(pipeline interface{}) Pipe {
|
||||
promise, err := c.brk.Allow()
|
||||
if err != nil {
|
||||
return rejectedPipe{}
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
return promisedPipe{
|
||||
Pipe: c.collection.Pipe(pipeline),
|
||||
promise: keepablePromise{
|
||||
promise: promise,
|
||||
log: func(err error) {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("pipe", duration, err, pipeline)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Remove(selector interface{}) (err error) {
|
||||
return c.brk.DoWithAcceptable(func() error {
|
||||
startTime := timex.Now()
|
||||
defer func() {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("remove", duration, err, selector)
|
||||
}()
|
||||
|
||||
return c.collection.Remove(selector)
|
||||
}, acceptable)
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) RemoveAll(selector interface{}) (info *mgo.ChangeInfo, err error) {
|
||||
err = c.brk.DoWithAcceptable(func() error {
|
||||
startTime := timex.Now()
|
||||
defer func() {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("removeAll", duration, err, selector)
|
||||
}()
|
||||
|
||||
info, err = c.collection.RemoveAll(selector)
|
||||
return err
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) RemoveId(id interface{}) (err error) {
|
||||
return c.brk.DoWithAcceptable(func() error {
|
||||
startTime := timex.Now()
|
||||
defer func() {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("removeId", duration, err, id)
|
||||
}()
|
||||
|
||||
return c.collection.RemoveId(id)
|
||||
}, acceptable)
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Update(selector, update interface{}) (err error) {
|
||||
return c.brk.DoWithAcceptable(func() error {
|
||||
startTime := timex.Now()
|
||||
defer func() {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("update", duration, err, selector, update)
|
||||
}()
|
||||
|
||||
return c.collection.Update(selector, update)
|
||||
}, acceptable)
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) UpdateId(id, update interface{}) (err error) {
|
||||
return c.brk.DoWithAcceptable(func() error {
|
||||
startTime := timex.Now()
|
||||
defer func() {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("updateId", duration, err, id, update)
|
||||
}()
|
||||
|
||||
return c.collection.UpdateId(id, update)
|
||||
}, acceptable)
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Upsert(selector, update interface{}) (info *mgo.ChangeInfo, err error) {
|
||||
err = c.brk.DoWithAcceptable(func() error {
|
||||
startTime := timex.Now()
|
||||
defer func() {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("upsert", duration, err, selector, update)
|
||||
}()
|
||||
|
||||
info, err = c.collection.Upsert(selector, update)
|
||||
return err
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) logDuration(method string, duration time.Duration, err error, docs ...interface{}) {
|
||||
content, e := json.Marshal(docs)
|
||||
if e != nil {
|
||||
logx.Error(err)
|
||||
} else if err != nil {
|
||||
if duration > slowThreshold.Load() {
|
||||
logx.WithDuration(duration).Slowf("[MONGO] mongo(%s) - slowcall - %s - fail(%s) - %s",
|
||||
c.name, method, err.Error(), string(content))
|
||||
} else {
|
||||
logx.WithDuration(duration).Infof("mongo(%s) - %s - fail(%s) - %s",
|
||||
c.name, method, err.Error(), string(content))
|
||||
}
|
||||
} else {
|
||||
if duration > slowThreshold.Load() {
|
||||
logx.WithDuration(duration).Slowf("[MONGO] mongo(%s) - slowcall - %s - ok - %s",
|
||||
c.name, method, string(content))
|
||||
} else {
|
||||
logx.WithDuration(duration).Infof("mongo(%s) - %s - ok - %s", c.name, method, string(content))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p keepablePromise) accept(err error) error {
|
||||
p.promise.Accept()
|
||||
p.log(err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p keepablePromise) keep(err error) error {
|
||||
if acceptable(err) {
|
||||
p.promise.Accept()
|
||||
} else {
|
||||
p.promise.Reject(err.Error())
|
||||
}
|
||||
|
||||
p.log(err)
|
||||
return err
|
||||
}
|
||||
|
||||
func acceptable(err error) bool {
|
||||
return err == nil || err == mgo.ErrNotFound
|
||||
}
|
||||
@@ -1,345 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/stores/mongo/internal"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
var errDummy = errors.New("dummy")
|
||||
|
||||
func TestKeepPromise_accept(t *testing.T) {
|
||||
p := new(mockPromise)
|
||||
kp := keepablePromise{
|
||||
promise: p,
|
||||
log: func(error) {},
|
||||
}
|
||||
assert.Nil(t, kp.accept(nil))
|
||||
assert.Equal(t, mgo.ErrNotFound, kp.accept(mgo.ErrNotFound))
|
||||
}
|
||||
|
||||
func TestKeepPromise_keep(t *testing.T) {
|
||||
tests := []struct {
|
||||
err error
|
||||
accepted bool
|
||||
reason string
|
||||
}{
|
||||
{
|
||||
err: nil,
|
||||
accepted: true,
|
||||
reason: "",
|
||||
},
|
||||
{
|
||||
err: mgo.ErrNotFound,
|
||||
accepted: true,
|
||||
reason: "",
|
||||
},
|
||||
{
|
||||
err: errors.New("any"),
|
||||
accepted: false,
|
||||
reason: "any",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
p := new(mockPromise)
|
||||
kp := keepablePromise{
|
||||
promise: p,
|
||||
log: func(error) {},
|
||||
}
|
||||
assert.Equal(t, test.err, kp.keep(test.err))
|
||||
assert.Equal(t, test.accepted, p.accepted)
|
||||
assert.Equal(t, test.reason, p.reason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCollection(t *testing.T) {
|
||||
col := newCollection(&mgo.Collection{
|
||||
Database: nil,
|
||||
Name: "foo",
|
||||
FullName: "bar",
|
||||
}, breaker.GetBreaker("localhost"))
|
||||
assert.Equal(t, "bar", col.(*decoratedCollection).name)
|
||||
}
|
||||
|
||||
func TestCollectionFind(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
var query mgo.Query
|
||||
col := internal.NewMockMgoCollection(ctrl)
|
||||
col.EXPECT().Find(gomock.Any()).Return(&query)
|
||||
c := decoratedCollection{
|
||||
collection: col,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
actual := c.Find(nil)
|
||||
switch v := actual.(type) {
|
||||
case promisedQuery:
|
||||
assert.Equal(t, &query, v.Query)
|
||||
assert.Equal(t, errDummy, v.promise.keep(errDummy))
|
||||
default:
|
||||
t.Fail()
|
||||
}
|
||||
c.brk = new(dropBreaker)
|
||||
actual = c.Find(nil)
|
||||
assert.Equal(t, rejectedQuery{}, actual)
|
||||
}
|
||||
|
||||
func TestCollectionFindId(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
var query mgo.Query
|
||||
col := internal.NewMockMgoCollection(ctrl)
|
||||
col.EXPECT().FindId(gomock.Any()).Return(&query)
|
||||
c := decoratedCollection{
|
||||
collection: col,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
actual := c.FindId(nil)
|
||||
switch v := actual.(type) {
|
||||
case promisedQuery:
|
||||
assert.Equal(t, &query, v.Query)
|
||||
assert.Equal(t, errDummy, v.promise.keep(errDummy))
|
||||
default:
|
||||
t.Fail()
|
||||
}
|
||||
c.brk = new(dropBreaker)
|
||||
actual = c.FindId(nil)
|
||||
assert.Equal(t, rejectedQuery{}, actual)
|
||||
}
|
||||
|
||||
func TestCollectionInsert(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
col := internal.NewMockMgoCollection(ctrl)
|
||||
col.EXPECT().Insert(nil, nil).Return(errDummy)
|
||||
c := decoratedCollection{
|
||||
collection: col,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
err := c.Insert(nil, nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
c.brk = new(dropBreaker)
|
||||
err = c.Insert(nil, nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollectionPipe(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
var pipe mgo.Pipe
|
||||
col := internal.NewMockMgoCollection(ctrl)
|
||||
col.EXPECT().Pipe(gomock.Any()).Return(&pipe)
|
||||
c := decoratedCollection{
|
||||
collection: col,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
actual := c.Pipe(nil)
|
||||
switch v := actual.(type) {
|
||||
case promisedPipe:
|
||||
assert.Equal(t, &pipe, v.Pipe)
|
||||
assert.Equal(t, errDummy, v.promise.keep(errDummy))
|
||||
default:
|
||||
t.Fail()
|
||||
}
|
||||
c.brk = new(dropBreaker)
|
||||
actual = c.Pipe(nil)
|
||||
assert.Equal(t, rejectedPipe{}, actual)
|
||||
}
|
||||
|
||||
func TestCollectionRemove(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
col := internal.NewMockMgoCollection(ctrl)
|
||||
col.EXPECT().Remove(gomock.Any()).Return(errDummy)
|
||||
c := decoratedCollection{
|
||||
collection: col,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
err := c.Remove(nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
c.brk = new(dropBreaker)
|
||||
err = c.Remove(nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollectionRemoveAll(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
col := internal.NewMockMgoCollection(ctrl)
|
||||
col.EXPECT().RemoveAll(gomock.Any()).Return(nil, errDummy)
|
||||
c := decoratedCollection{
|
||||
collection: col,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
_, err := c.RemoveAll(nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.RemoveAll(nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollectionRemoveId(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
col := internal.NewMockMgoCollection(ctrl)
|
||||
col.EXPECT().RemoveId(gomock.Any()).Return(errDummy)
|
||||
c := decoratedCollection{
|
||||
collection: col,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
err := c.RemoveId(nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
c.brk = new(dropBreaker)
|
||||
err = c.RemoveId(nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollectionUpdate(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
col := internal.NewMockMgoCollection(ctrl)
|
||||
col.EXPECT().Update(gomock.Any(), gomock.Any()).Return(errDummy)
|
||||
c := decoratedCollection{
|
||||
collection: col,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
err := c.Update(nil, nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
c.brk = new(dropBreaker)
|
||||
err = c.Update(nil, nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollectionUpdateId(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
col := internal.NewMockMgoCollection(ctrl)
|
||||
col.EXPECT().UpdateId(gomock.Any(), gomock.Any()).Return(errDummy)
|
||||
c := decoratedCollection{
|
||||
collection: col,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
err := c.UpdateId(nil, nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
c.brk = new(dropBreaker)
|
||||
err = c.UpdateId(nil, nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollectionUpsert(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
col := internal.NewMockMgoCollection(ctrl)
|
||||
col.EXPECT().Upsert(gomock.Any(), gomock.Any()).Return(nil, errDummy)
|
||||
c := decoratedCollection{
|
||||
collection: col,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
_, err := c.Upsert(nil, nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.Upsert(nil, nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func Test_logDuration(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
col := internal.NewMockMgoCollection(ctrl)
|
||||
c := decoratedCollection{
|
||||
collection: col,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
|
||||
var buf strings.Builder
|
||||
w := logx.NewWriter(&buf)
|
||||
o := logx.Reset()
|
||||
logx.SetWriter(w)
|
||||
|
||||
defer func() {
|
||||
logx.Reset()
|
||||
logx.SetWriter(o)
|
||||
}()
|
||||
|
||||
buf.Reset()
|
||||
c.logDuration("foo", time.Millisecond, nil, "bar")
|
||||
assert.Contains(t, buf.String(), "foo")
|
||||
assert.Contains(t, buf.String(), "bar")
|
||||
|
||||
buf.Reset()
|
||||
c.logDuration("foo", time.Millisecond, errors.New("bar"), make(chan int))
|
||||
assert.Contains(t, buf.String(), "bar")
|
||||
|
||||
buf.Reset()
|
||||
c.logDuration("foo", slowThreshold.Load()+time.Millisecond, errors.New("bar"))
|
||||
assert.Contains(t, buf.String(), "bar")
|
||||
assert.Contains(t, buf.String(), "slowcall")
|
||||
|
||||
buf.Reset()
|
||||
c.logDuration("foo", slowThreshold.Load()+time.Millisecond, nil)
|
||||
assert.Contains(t, buf.String(), "foo")
|
||||
assert.Contains(t, buf.String(), "slowcall")
|
||||
}
|
||||
|
||||
type mockPromise struct {
|
||||
accepted bool
|
||||
reason string
|
||||
}
|
||||
|
||||
func (p *mockPromise) Accept() {
|
||||
p.accepted = true
|
||||
}
|
||||
|
||||
func (p *mockPromise) Reject(reason string) {
|
||||
p.reason = reason
|
||||
}
|
||||
|
||||
type dropBreaker struct{}
|
||||
|
||||
func (d *dropBreaker) Name() string {
|
||||
return "dummy"
|
||||
}
|
||||
|
||||
func (d *dropBreaker) Allow() (breaker.Promise, error) {
|
||||
return nil, errDummy
|
||||
}
|
||||
|
||||
func (d *dropBreaker) Do(req func() error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *dropBreaker) DoWithAcceptable(req func() error, acceptable breaker.Acceptable) error {
|
||||
return errDummy
|
||||
}
|
||||
|
||||
func (d *dropBreaker) DoWithFallback(req func() error, fallback func(err error) error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *dropBreaker) DoWithFallbackAcceptable(req func() error, fallback func(err error) error,
|
||||
acceptable breaker.Acceptable) error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
//go:generate mockgen -package internal -destination collection_mock.go -source collection.go
|
||||
|
||||
package internal
|
||||
|
||||
import "github.com/globalsign/mgo"
|
||||
|
||||
// MgoCollection interface represents a mgo collection.
|
||||
type MgoCollection interface {
|
||||
Find(query interface{}) *mgo.Query
|
||||
FindId(id interface{}) *mgo.Query
|
||||
Insert(docs ...interface{}) error
|
||||
Pipe(pipeline interface{}) *mgo.Pipe
|
||||
Remove(selector interface{}) error
|
||||
RemoveAll(selector interface{}) (*mgo.ChangeInfo, error)
|
||||
RemoveId(id interface{}) error
|
||||
Update(selector, update interface{}) error
|
||||
UpdateId(id, update interface{}) error
|
||||
Upsert(selector, update interface{}) (*mgo.ChangeInfo, error)
|
||||
}
|
||||
@@ -1,181 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: collection.go
|
||||
|
||||
// Package internal is a generated GoMock package.
|
||||
package internal
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
mgo "github.com/globalsign/mgo"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockMgoCollection is a mock of MgoCollection interface
|
||||
type MockMgoCollection struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockMgoCollectionMockRecorder
|
||||
}
|
||||
|
||||
// MockMgoCollectionMockRecorder is the mock recorder for MockMgoCollection
|
||||
type MockMgoCollectionMockRecorder struct {
|
||||
mock *MockMgoCollection
|
||||
}
|
||||
|
||||
// NewMockMgoCollection creates a new mock instance
|
||||
func NewMockMgoCollection(ctrl *gomock.Controller) *MockMgoCollection {
|
||||
mock := &MockMgoCollection{ctrl: ctrl}
|
||||
mock.recorder = &MockMgoCollectionMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockMgoCollection) EXPECT() *MockMgoCollectionMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Find mocks base method
|
||||
func (m *MockMgoCollection) Find(query interface{}) *mgo.Query {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Find", query)
|
||||
ret0, _ := ret[0].(*mgo.Query)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Find indicates an expected call of Find
|
||||
func (mr *MockMgoCollectionMockRecorder) Find(query interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Find", reflect.TypeOf((*MockMgoCollection)(nil).Find), query)
|
||||
}
|
||||
|
||||
// FindId mocks base method
|
||||
func (m *MockMgoCollection) FindId(id interface{}) *mgo.Query {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FindId", id)
|
||||
ret0, _ := ret[0].(*mgo.Query)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FindId indicates an expected call of FindId
|
||||
func (mr *MockMgoCollectionMockRecorder) FindId(id interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindId", reflect.TypeOf((*MockMgoCollection)(nil).FindId), id)
|
||||
}
|
||||
|
||||
// Insert mocks base method
|
||||
func (m *MockMgoCollection) Insert(docs ...interface{}) error {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{}
|
||||
for _, a := range docs {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Insert", varargs...)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Insert indicates an expected call of Insert
|
||||
func (mr *MockMgoCollectionMockRecorder) Insert(docs ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockMgoCollection)(nil).Insert), docs...)
|
||||
}
|
||||
|
||||
// Pipe mocks base method
|
||||
func (m *MockMgoCollection) Pipe(pipeline interface{}) *mgo.Pipe {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Pipe", pipeline)
|
||||
ret0, _ := ret[0].(*mgo.Pipe)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Pipe indicates an expected call of Pipe
|
||||
func (mr *MockMgoCollectionMockRecorder) Pipe(pipeline interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Pipe", reflect.TypeOf((*MockMgoCollection)(nil).Pipe), pipeline)
|
||||
}
|
||||
|
||||
// Remove mocks base method
|
||||
func (m *MockMgoCollection) Remove(selector interface{}) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Remove", selector)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Remove indicates an expected call of Remove
|
||||
func (mr *MockMgoCollectionMockRecorder) Remove(selector interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockMgoCollection)(nil).Remove), selector)
|
||||
}
|
||||
|
||||
// RemoveAll mocks base method
|
||||
func (m *MockMgoCollection) RemoveAll(selector interface{}) (*mgo.ChangeInfo, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RemoveAll", selector)
|
||||
ret0, _ := ret[0].(*mgo.ChangeInfo)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// RemoveAll indicates an expected call of RemoveAll
|
||||
func (mr *MockMgoCollectionMockRecorder) RemoveAll(selector interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveAll", reflect.TypeOf((*MockMgoCollection)(nil).RemoveAll), selector)
|
||||
}
|
||||
|
||||
// RemoveId mocks base method
|
||||
func (m *MockMgoCollection) RemoveId(id interface{}) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RemoveId", id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RemoveId indicates an expected call of RemoveId
|
||||
func (mr *MockMgoCollectionMockRecorder) RemoveId(id interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveId", reflect.TypeOf((*MockMgoCollection)(nil).RemoveId), id)
|
||||
}
|
||||
|
||||
// Update mocks base method
|
||||
func (m *MockMgoCollection) Update(selector, update interface{}) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Update", selector, update)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Update indicates an expected call of Update
|
||||
func (mr *MockMgoCollectionMockRecorder) Update(selector, update interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockMgoCollection)(nil).Update), selector, update)
|
||||
}
|
||||
|
||||
// UpdateId mocks base method
|
||||
func (m *MockMgoCollection) UpdateId(id, update interface{}) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateId", id, update)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateId indicates an expected call of UpdateId
|
||||
func (mr *MockMgoCollectionMockRecorder) UpdateId(id, update interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateId", reflect.TypeOf((*MockMgoCollection)(nil).UpdateId), id, update)
|
||||
}
|
||||
|
||||
// Upsert mocks base method
|
||||
func (m *MockMgoCollection) Upsert(selector, update interface{}) (*mgo.ChangeInfo, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Upsert", selector, update)
|
||||
ret0, _ := ret[0].(*mgo.ChangeInfo)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Upsert indicates an expected call of Upsert
|
||||
func (mr *MockMgoCollectionMockRecorder) Upsert(selector, update interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upsert", reflect.TypeOf((*MockMgoCollection)(nil).Upsert), selector, update)
|
||||
}
|
||||
@@ -1,99 +0,0 @@
|
||||
//go:generate mockgen -package mongo -destination iter_mock.go -source iter.go Iter
|
||||
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"github.com/globalsign/mgo/bson"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
)
|
||||
|
||||
type (
|
||||
// Iter interface represents a mongo iter.
|
||||
Iter interface {
|
||||
All(result interface{}) error
|
||||
Close() error
|
||||
Done() bool
|
||||
Err() error
|
||||
For(result interface{}, f func() error) error
|
||||
Next(result interface{}) bool
|
||||
State() (int64, []bson.Raw)
|
||||
Timeout() bool
|
||||
}
|
||||
|
||||
// A ClosableIter is a closable mongo iter.
|
||||
ClosableIter struct {
|
||||
Iter
|
||||
Cleanup func()
|
||||
}
|
||||
|
||||
promisedIter struct {
|
||||
Iter
|
||||
promise keepablePromise
|
||||
}
|
||||
|
||||
rejectedIter struct{}
|
||||
)
|
||||
|
||||
func (i promisedIter) All(result interface{}) error {
|
||||
return i.promise.keep(i.Iter.All(result))
|
||||
}
|
||||
|
||||
func (i promisedIter) Close() error {
|
||||
return i.promise.keep(i.Iter.Close())
|
||||
}
|
||||
|
||||
func (i promisedIter) Err() error {
|
||||
return i.Iter.Err()
|
||||
}
|
||||
|
||||
func (i promisedIter) For(result interface{}, f func() error) error {
|
||||
var ferr error
|
||||
err := i.Iter.For(result, func() error {
|
||||
ferr = f()
|
||||
return ferr
|
||||
})
|
||||
if ferr == err {
|
||||
return i.promise.accept(err)
|
||||
}
|
||||
|
||||
return i.promise.keep(err)
|
||||
}
|
||||
|
||||
// Close closes a mongo iter.
|
||||
func (it *ClosableIter) Close() error {
|
||||
err := it.Iter.Close()
|
||||
it.Cleanup()
|
||||
return err
|
||||
}
|
||||
|
||||
func (i rejectedIter) All(result interface{}) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (i rejectedIter) Close() error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (i rejectedIter) Done() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (i rejectedIter) Err() error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (i rejectedIter) For(result interface{}, f func() error) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (i rejectedIter) Next(result interface{}) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (i rejectedIter) State() (int64, []bson.Raw) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (i rejectedIter) Timeout() bool {
|
||||
return false
|
||||
}
|
||||
@@ -1,148 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: iter.go
|
||||
|
||||
// Package mongo is a generated GoMock package.
|
||||
package mongo
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
bson "github.com/globalsign/mgo/bson"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockIter is a mock of Iter interface
|
||||
type MockIter struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockIterMockRecorder
|
||||
}
|
||||
|
||||
// MockIterMockRecorder is the mock recorder for MockIter
|
||||
type MockIterMockRecorder struct {
|
||||
mock *MockIter
|
||||
}
|
||||
|
||||
// NewMockIter creates a new mock instance
|
||||
func NewMockIter(ctrl *gomock.Controller) *MockIter {
|
||||
mock := &MockIter{ctrl: ctrl}
|
||||
mock.recorder = &MockIterMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockIter) EXPECT() *MockIterMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// All mocks base method
|
||||
func (m *MockIter) All(result interface{}) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "All", result)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// All indicates an expected call of All
|
||||
func (mr *MockIterMockRecorder) All(result interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "All", reflect.TypeOf((*MockIter)(nil).All), result)
|
||||
}
|
||||
|
||||
// Close mocks base method
|
||||
func (m *MockIter) Close() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Close")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close
|
||||
func (mr *MockIterMockRecorder) Close() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockIter)(nil).Close))
|
||||
}
|
||||
|
||||
// Done mocks base method
|
||||
func (m *MockIter) Done() bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Done")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Done indicates an expected call of Done
|
||||
func (mr *MockIterMockRecorder) Done() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Done", reflect.TypeOf((*MockIter)(nil).Done))
|
||||
}
|
||||
|
||||
// Err mocks base method
|
||||
func (m *MockIter) Err() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Err")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Err indicates an expected call of Err
|
||||
func (mr *MockIterMockRecorder) Err() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Err", reflect.TypeOf((*MockIter)(nil).Err))
|
||||
}
|
||||
|
||||
// For mocks base method
|
||||
func (m *MockIter) For(result interface{}, f func() error) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "For", result, f)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// For indicates an expected call of For
|
||||
func (mr *MockIterMockRecorder) For(result, f interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "For", reflect.TypeOf((*MockIter)(nil).For), result, f)
|
||||
}
|
||||
|
||||
// Next mocks base method
|
||||
func (m *MockIter) Next(result interface{}) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Next", result)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Next indicates an expected call of Next
|
||||
func (mr *MockIterMockRecorder) Next(result interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockIter)(nil).Next), result)
|
||||
}
|
||||
|
||||
// State mocks base method
|
||||
func (m *MockIter) State() (int64, []bson.Raw) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "State")
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].([]bson.Raw)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// State indicates an expected call of State
|
||||
func (mr *MockIterMockRecorder) State() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "State", reflect.TypeOf((*MockIter)(nil).State))
|
||||
}
|
||||
|
||||
// Timeout mocks base method
|
||||
func (m *MockIter) Timeout() bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Timeout")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Timeout indicates an expected call of Timeout
|
||||
func (mr *MockIterMockRecorder) Timeout() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Timeout", reflect.TypeOf((*MockIter)(nil).Timeout))
|
||||
}
|
||||
@@ -1,264 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
)
|
||||
|
||||
func TestClosableIter_Close(t *testing.T) {
|
||||
errs := []error{
|
||||
nil,
|
||||
mgo.ErrNotFound,
|
||||
}
|
||||
|
||||
for _, err := range errs {
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
cleaned := syncx.NewAtomicBool()
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().Close().Return(err)
|
||||
ci := ClosableIter{
|
||||
Iter: iter,
|
||||
Cleanup: func() {
|
||||
cleaned.Set(true)
|
||||
},
|
||||
}
|
||||
assert.Equal(t, err, ci.Close())
|
||||
assert.True(t, cleaned.True())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromisedIter_AllAndClose(t *testing.T) {
|
||||
tests := []struct {
|
||||
err error
|
||||
accepted bool
|
||||
reason string
|
||||
}{
|
||||
{
|
||||
err: nil,
|
||||
accepted: true,
|
||||
reason: "",
|
||||
},
|
||||
{
|
||||
err: mgo.ErrNotFound,
|
||||
accepted: true,
|
||||
reason: "",
|
||||
},
|
||||
{
|
||||
err: errors.New("any"),
|
||||
accepted: false,
|
||||
reason: "any",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().All(gomock.Any()).Return(test.err)
|
||||
promise := new(mockPromise)
|
||||
pi := promisedIter{
|
||||
Iter: iter,
|
||||
promise: keepablePromise{
|
||||
promise: promise,
|
||||
log: func(error) {},
|
||||
},
|
||||
}
|
||||
assert.Equal(t, test.err, pi.All(nil))
|
||||
assert.Equal(t, test.accepted, promise.accepted)
|
||||
assert.Equal(t, test.reason, promise.reason)
|
||||
})
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().Close().Return(test.err)
|
||||
promise := new(mockPromise)
|
||||
pi := promisedIter{
|
||||
Iter: iter,
|
||||
promise: keepablePromise{
|
||||
promise: promise,
|
||||
log: func(error) {},
|
||||
},
|
||||
}
|
||||
assert.Equal(t, test.err, pi.Close())
|
||||
assert.Equal(t, test.accepted, promise.accepted)
|
||||
assert.Equal(t, test.reason, promise.reason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromisedIter_Err(t *testing.T) {
|
||||
errs := []error{
|
||||
nil,
|
||||
mgo.ErrNotFound,
|
||||
}
|
||||
|
||||
for _, err := range errs {
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().Err().Return(err)
|
||||
promise := new(mockPromise)
|
||||
pi := promisedIter{
|
||||
Iter: iter,
|
||||
promise: keepablePromise{
|
||||
promise: promise,
|
||||
log: func(error) {},
|
||||
},
|
||||
}
|
||||
assert.Equal(t, err, pi.Err())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromisedIter_For(t *testing.T) {
|
||||
tests := []struct {
|
||||
err error
|
||||
accepted bool
|
||||
reason string
|
||||
}{
|
||||
{
|
||||
err: nil,
|
||||
accepted: true,
|
||||
reason: "",
|
||||
},
|
||||
{
|
||||
err: mgo.ErrNotFound,
|
||||
accepted: true,
|
||||
reason: "",
|
||||
},
|
||||
{
|
||||
err: errors.New("any"),
|
||||
accepted: false,
|
||||
reason: "any",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().For(gomock.Any(), gomock.Any()).Return(test.err)
|
||||
promise := new(mockPromise)
|
||||
pi := promisedIter{
|
||||
Iter: iter,
|
||||
promise: keepablePromise{
|
||||
promise: promise,
|
||||
log: func(error) {},
|
||||
},
|
||||
}
|
||||
assert.Equal(t, test.err, pi.For(nil, nil))
|
||||
assert.Equal(t, test.accepted, promise.accepted)
|
||||
assert.Equal(t, test.reason, promise.reason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRejectedIter_All(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedIter).All(nil))
|
||||
}
|
||||
|
||||
func TestRejectedIter_Close(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedIter).Close())
|
||||
}
|
||||
|
||||
func TestRejectedIter_Done(t *testing.T) {
|
||||
assert.False(t, new(rejectedIter).Done())
|
||||
}
|
||||
|
||||
func TestRejectedIter_Err(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedIter).Err())
|
||||
}
|
||||
|
||||
func TestRejectedIter_For(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedIter).For(nil, nil))
|
||||
}
|
||||
|
||||
func TestRejectedIter_Next(t *testing.T) {
|
||||
assert.False(t, new(rejectedIter).Next(nil))
|
||||
}
|
||||
|
||||
func TestRejectedIter_State(t *testing.T) {
|
||||
n, raw := new(rejectedIter).State()
|
||||
assert.Equal(t, int64(0), n)
|
||||
assert.Nil(t, raw)
|
||||
}
|
||||
|
||||
func TestRejectedIter_Timeout(t *testing.T) {
|
||||
assert.False(t, new(rejectedIter).Timeout())
|
||||
}
|
||||
|
||||
func TestIter_Done(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().Done().Return(true)
|
||||
ci := ClosableIter{
|
||||
Iter: iter,
|
||||
Cleanup: nil,
|
||||
}
|
||||
assert.True(t, ci.Done())
|
||||
}
|
||||
|
||||
func TestIter_Next(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().Next(gomock.Any()).Return(true)
|
||||
ci := ClosableIter{
|
||||
Iter: iter,
|
||||
Cleanup: nil,
|
||||
}
|
||||
assert.True(t, ci.Next(nil))
|
||||
}
|
||||
|
||||
func TestIter_State(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().State().Return(int64(1), nil)
|
||||
ci := ClosableIter{
|
||||
Iter: iter,
|
||||
Cleanup: nil,
|
||||
}
|
||||
n, raw := ci.State()
|
||||
assert.Equal(t, int64(1), n)
|
||||
assert.Nil(t, raw)
|
||||
}
|
||||
|
||||
func TestIter_Timeout(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().Timeout().Return(true)
|
||||
ci := ClosableIter{
|
||||
Iter: iter,
|
||||
Cleanup: nil,
|
||||
}
|
||||
assert.True(t, ci.Timeout())
|
||||
}
|
||||
@@ -1,177 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
)
|
||||
|
||||
// A Model is a mongo model.
|
||||
type Model struct {
|
||||
session *concurrentSession
|
||||
db *mgo.Database
|
||||
collection string
|
||||
brk breaker.Breaker
|
||||
opts []Option
|
||||
}
|
||||
|
||||
// MustNewModel returns a Model, exits on errors.
|
||||
func MustNewModel(url, collection string, opts ...Option) *Model {
|
||||
model, err := NewModel(url, collection, opts...)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return model
|
||||
}
|
||||
|
||||
// NewModel returns a Model.
|
||||
func NewModel(url, collection string, opts ...Option) (*Model, error) {
|
||||
session, err := getConcurrentSession(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Model{
|
||||
session: session,
|
||||
// If name is empty, the database name provided in the dialed URL is used instead
|
||||
db: session.DB(""),
|
||||
collection: collection,
|
||||
brk: breaker.GetBreaker(url),
|
||||
opts: opts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Find finds a record with given query.
|
||||
func (mm *Model) Find(query interface{}) (Query, error) {
|
||||
return mm.query(func(c Collection) Query {
|
||||
return c.Find(query)
|
||||
})
|
||||
}
|
||||
|
||||
// FindId finds a record with given id.
|
||||
func (mm *Model) FindId(id interface{}) (Query, error) {
|
||||
return mm.query(func(c Collection) Query {
|
||||
return c.FindId(id)
|
||||
})
|
||||
}
|
||||
|
||||
// GetCollection returns a Collection with given session.
|
||||
func (mm *Model) GetCollection(session *mgo.Session) Collection {
|
||||
return newCollection(mm.db.C(mm.collection).With(session), mm.brk)
|
||||
}
|
||||
|
||||
// Insert inserts docs into mm.
|
||||
func (mm *Model) Insert(docs ...interface{}) error {
|
||||
return mm.execute(func(c Collection) error {
|
||||
return c.Insert(docs...)
|
||||
})
|
||||
}
|
||||
|
||||
// Pipe returns a Pipe with given pipeline.
|
||||
func (mm *Model) Pipe(pipeline interface{}) (Pipe, error) {
|
||||
return mm.pipe(func(c Collection) Pipe {
|
||||
return c.Pipe(pipeline)
|
||||
})
|
||||
}
|
||||
|
||||
// PutSession returns the given session.
|
||||
func (mm *Model) PutSession(session *mgo.Session) {
|
||||
mm.session.putSession(session)
|
||||
}
|
||||
|
||||
// Remove removes the records with given selector.
|
||||
func (mm *Model) Remove(selector interface{}) error {
|
||||
return mm.execute(func(c Collection) error {
|
||||
return c.Remove(selector)
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveAll removes all with given selector and returns a mgo.ChangeInfo.
|
||||
func (mm *Model) RemoveAll(selector interface{}) (*mgo.ChangeInfo, error) {
|
||||
return mm.change(func(c Collection) (*mgo.ChangeInfo, error) {
|
||||
return c.RemoveAll(selector)
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveId removes a record with given id.
|
||||
func (mm *Model) RemoveId(id interface{}) error {
|
||||
return mm.execute(func(c Collection) error {
|
||||
return c.RemoveId(id)
|
||||
})
|
||||
}
|
||||
|
||||
// TakeSession gets a session.
|
||||
func (mm *Model) TakeSession() (*mgo.Session, error) {
|
||||
return mm.session.takeSession(mm.opts...)
|
||||
}
|
||||
|
||||
// Update updates a record with given selector.
|
||||
func (mm *Model) Update(selector, update interface{}) error {
|
||||
return mm.execute(func(c Collection) error {
|
||||
return c.Update(selector, update)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateId updates a record with given id.
|
||||
func (mm *Model) UpdateId(id, update interface{}) error {
|
||||
return mm.execute(func(c Collection) error {
|
||||
return c.UpdateId(id, update)
|
||||
})
|
||||
}
|
||||
|
||||
// Upsert upserts a record with given selector, and returns a mgo.ChangeInfo.
|
||||
func (mm *Model) Upsert(selector, update interface{}) (*mgo.ChangeInfo, error) {
|
||||
return mm.change(func(c Collection) (*mgo.ChangeInfo, error) {
|
||||
return c.Upsert(selector, update)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) change(fn func(c Collection) (*mgo.ChangeInfo, error)) (*mgo.ChangeInfo, error) {
|
||||
session, err := mm.TakeSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer mm.PutSession(session)
|
||||
|
||||
return fn(mm.GetCollection(session))
|
||||
}
|
||||
|
||||
func (mm *Model) execute(fn func(c Collection) error) error {
|
||||
session, err := mm.TakeSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer mm.PutSession(session)
|
||||
|
||||
return fn(mm.GetCollection(session))
|
||||
}
|
||||
|
||||
func (mm *Model) pipe(fn func(c Collection) Pipe) (Pipe, error) {
|
||||
session, err := mm.TakeSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer mm.PutSession(session)
|
||||
|
||||
return fn(mm.GetCollection(session)), nil
|
||||
}
|
||||
|
||||
func (mm *Model) query(fn func(c Collection) Query) (Query, error) {
|
||||
session, err := mm.TakeSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer mm.PutSession(session)
|
||||
|
||||
return fn(mm.GetCollection(session)), nil
|
||||
}
|
||||
|
||||
// WithTimeout customizes an operation with given timeout.
|
||||
func WithTimeout(timeout time.Duration) Option {
|
||||
return func(opts *options) {
|
||||
opts.timeout = timeout
|
||||
}
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestWithTimeout(t *testing.T) {
|
||||
o := defaultOptions()
|
||||
WithTimeout(time.Second)(o)
|
||||
assert.Equal(t, time.Second, o.timeout)
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
)
|
||||
|
||||
var slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold)
|
||||
|
||||
type (
|
||||
options struct {
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// Option defines the method to customize a mongo model.
|
||||
Option func(opts *options)
|
||||
)
|
||||
|
||||
// SetSlowThreshold sets the slow threshold.
|
||||
func SetSlowThreshold(threshold time.Duration) {
|
||||
slowThreshold.Set(threshold)
|
||||
}
|
||||
|
||||
func defaultOptions() *options {
|
||||
return &options{
|
||||
timeout: defaultTimeout,
|
||||
}
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSetSlowThreshold(t *testing.T) {
|
||||
assert.Equal(t, defaultSlowThreshold, slowThreshold.Load())
|
||||
SetSlowThreshold(time.Second)
|
||||
assert.Equal(t, time.Second, slowThreshold.Load())
|
||||
}
|
||||
@@ -1,100 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
)
|
||||
|
||||
type (
|
||||
// Pipe interface represents a mongo pipe.
|
||||
Pipe interface {
|
||||
All(result interface{}) error
|
||||
AllowDiskUse() Pipe
|
||||
Batch(n int) Pipe
|
||||
Collation(collation *mgo.Collation) Pipe
|
||||
Explain(result interface{}) error
|
||||
Iter() Iter
|
||||
One(result interface{}) error
|
||||
SetMaxTime(d time.Duration) Pipe
|
||||
}
|
||||
|
||||
promisedPipe struct {
|
||||
*mgo.Pipe
|
||||
promise keepablePromise
|
||||
}
|
||||
|
||||
rejectedPipe struct{}
|
||||
)
|
||||
|
||||
func (p promisedPipe) All(result interface{}) error {
|
||||
return p.promise.keep(p.Pipe.All(result))
|
||||
}
|
||||
|
||||
func (p promisedPipe) AllowDiskUse() Pipe {
|
||||
p.Pipe.AllowDiskUse()
|
||||
return p
|
||||
}
|
||||
|
||||
func (p promisedPipe) Batch(n int) Pipe {
|
||||
p.Pipe.Batch(n)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p promisedPipe) Collation(collation *mgo.Collation) Pipe {
|
||||
p.Pipe.Collation(collation)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p promisedPipe) Explain(result interface{}) error {
|
||||
return p.promise.keep(p.Pipe.Explain(result))
|
||||
}
|
||||
|
||||
func (p promisedPipe) Iter() Iter {
|
||||
return promisedIter{
|
||||
Iter: p.Pipe.Iter(),
|
||||
promise: p.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (p promisedPipe) One(result interface{}) error {
|
||||
return p.promise.keep(p.Pipe.One(result))
|
||||
}
|
||||
|
||||
func (p promisedPipe) SetMaxTime(d time.Duration) Pipe {
|
||||
p.Pipe.SetMaxTime(d)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p rejectedPipe) All(result interface{}) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (p rejectedPipe) AllowDiskUse() Pipe {
|
||||
return p
|
||||
}
|
||||
|
||||
func (p rejectedPipe) Batch(n int) Pipe {
|
||||
return p
|
||||
}
|
||||
|
||||
func (p rejectedPipe) Collation(collation *mgo.Collation) Pipe {
|
||||
return p
|
||||
}
|
||||
|
||||
func (p rejectedPipe) Explain(result interface{}) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (p rejectedPipe) Iter() Iter {
|
||||
return rejectedIter{}
|
||||
}
|
||||
|
||||
func (p rejectedPipe) One(result interface{}) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (p rejectedPipe) SetMaxTime(d time.Duration) Pipe {
|
||||
return p
|
||||
}
|
||||
@@ -1,44 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
)
|
||||
|
||||
func TestRejectedPipe_All(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedPipe).All(nil))
|
||||
}
|
||||
|
||||
func TestRejectedPipe_AllowDiskUse(t *testing.T) {
|
||||
var p rejectedPipe
|
||||
assert.Equal(t, p, p.AllowDiskUse())
|
||||
}
|
||||
|
||||
func TestRejectedPipe_Batch(t *testing.T) {
|
||||
var p rejectedPipe
|
||||
assert.Equal(t, p, p.Batch(1))
|
||||
}
|
||||
|
||||
func TestRejectedPipe_Collation(t *testing.T) {
|
||||
var p rejectedPipe
|
||||
assert.Equal(t, p, p.Collation(nil))
|
||||
}
|
||||
|
||||
func TestRejectedPipe_Explain(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedPipe).Explain(nil))
|
||||
}
|
||||
|
||||
func TestRejectedPipe_Iter(t *testing.T) {
|
||||
assert.EqualValues(t, rejectedIter{}, new(rejectedPipe).Iter())
|
||||
}
|
||||
|
||||
func TestRejectedPipe_One(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedPipe).One(nil))
|
||||
}
|
||||
|
||||
func TestRejectedPipe_SetMaxTime(t *testing.T) {
|
||||
var p rejectedPipe
|
||||
assert.Equal(t, p, p.SetMaxTime(0))
|
||||
}
|
||||
@@ -1,285 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
)
|
||||
|
||||
type (
|
||||
// Query interface represents a mongo query.
|
||||
Query interface {
|
||||
All(result interface{}) error
|
||||
Apply(change mgo.Change, result interface{}) (*mgo.ChangeInfo, error)
|
||||
Batch(n int) Query
|
||||
Collation(collation *mgo.Collation) Query
|
||||
Comment(comment string) Query
|
||||
Count() (int, error)
|
||||
Distinct(key string, result interface{}) error
|
||||
Explain(result interface{}) error
|
||||
For(result interface{}, f func() error) error
|
||||
Hint(indexKey ...string) Query
|
||||
Iter() Iter
|
||||
Limit(n int) Query
|
||||
LogReplay() Query
|
||||
MapReduce(job *mgo.MapReduce, result interface{}) (*mgo.MapReduceInfo, error)
|
||||
One(result interface{}) error
|
||||
Prefetch(p float64) Query
|
||||
Select(selector interface{}) Query
|
||||
SetMaxScan(n int) Query
|
||||
SetMaxTime(d time.Duration) Query
|
||||
Skip(n int) Query
|
||||
Snapshot() Query
|
||||
Sort(fields ...string) Query
|
||||
Tail(timeout time.Duration) Iter
|
||||
}
|
||||
|
||||
promisedQuery struct {
|
||||
*mgo.Query
|
||||
promise keepablePromise
|
||||
}
|
||||
|
||||
rejectedQuery struct{}
|
||||
)
|
||||
|
||||
func (q promisedQuery) All(result interface{}) error {
|
||||
return q.promise.keep(q.Query.All(result))
|
||||
}
|
||||
|
||||
func (q promisedQuery) Apply(change mgo.Change, result interface{}) (*mgo.ChangeInfo, error) {
|
||||
info, err := q.Query.Apply(change, result)
|
||||
return info, q.promise.keep(err)
|
||||
}
|
||||
|
||||
func (q promisedQuery) Batch(n int) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Batch(n),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Collation(collation *mgo.Collation) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Collation(collation),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Comment(comment string) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Comment(comment),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Count() (int, error) {
|
||||
v, err := q.Query.Count()
|
||||
return v, q.promise.keep(err)
|
||||
}
|
||||
|
||||
func (q promisedQuery) Distinct(key string, result interface{}) error {
|
||||
return q.promise.keep(q.Query.Distinct(key, result))
|
||||
}
|
||||
|
||||
func (q promisedQuery) Explain(result interface{}) error {
|
||||
return q.promise.keep(q.Query.Explain(result))
|
||||
}
|
||||
|
||||
func (q promisedQuery) For(result interface{}, f func() error) error {
|
||||
var ferr error
|
||||
err := q.Query.For(result, func() error {
|
||||
ferr = f()
|
||||
return ferr
|
||||
})
|
||||
if ferr == err {
|
||||
return q.promise.accept(err)
|
||||
}
|
||||
|
||||
return q.promise.keep(err)
|
||||
}
|
||||
|
||||
func (q promisedQuery) Hint(indexKey ...string) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Hint(indexKey...),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Iter() Iter {
|
||||
return promisedIter{
|
||||
Iter: q.Query.Iter(),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Limit(n int) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Limit(n),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) LogReplay() Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.LogReplay(),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) MapReduce(job *mgo.MapReduce, result interface{}) (*mgo.MapReduceInfo, error) {
|
||||
info, err := q.Query.MapReduce(job, result)
|
||||
return info, q.promise.keep(err)
|
||||
}
|
||||
|
||||
func (q promisedQuery) One(result interface{}) error {
|
||||
return q.promise.keep(q.Query.One(result))
|
||||
}
|
||||
|
||||
func (q promisedQuery) Prefetch(p float64) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Prefetch(p),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Select(selector interface{}) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Select(selector),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) SetMaxScan(n int) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.SetMaxScan(n),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) SetMaxTime(d time.Duration) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.SetMaxTime(d),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Skip(n int) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Skip(n),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Snapshot() Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Snapshot(),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Sort(fields ...string) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Sort(fields...),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Tail(timeout time.Duration) Iter {
|
||||
return promisedIter{
|
||||
Iter: q.Query.Tail(timeout),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q rejectedQuery) All(result interface{}) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Apply(change mgo.Change, result interface{}) (*mgo.ChangeInfo, error) {
|
||||
return nil, breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Batch(n int) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Collation(collation *mgo.Collation) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Comment(comment string) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Count() (int, error) {
|
||||
return 0, breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Distinct(key string, result interface{}) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Explain(result interface{}) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (q rejectedQuery) For(result interface{}, f func() error) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Hint(indexKey ...string) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Iter() Iter {
|
||||
return rejectedIter{}
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Limit(n int) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) LogReplay() Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) MapReduce(job *mgo.MapReduce, result interface{}) (*mgo.MapReduceInfo, error) {
|
||||
return nil, breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (q rejectedQuery) One(result interface{}) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Prefetch(p float64) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Select(selector interface{}) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) SetMaxScan(n int) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) SetMaxTime(d time.Duration) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Skip(n int) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Snapshot() Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Sort(fields ...string) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Tail(timeout time.Duration) Iter {
|
||||
return rejectedIter{}
|
||||
}
|
||||
@@ -1,120 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
)
|
||||
|
||||
func Test_rejectedQuery_All(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).All(nil))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Apply(t *testing.T) {
|
||||
info, err := new(rejectedQuery).Apply(mgo.Change{}, nil)
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, err)
|
||||
assert.Nil(t, info)
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Batch(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Batch(1))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Collation(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Collation(nil))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Comment(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Comment(""))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Count(t *testing.T) {
|
||||
n, err := new(rejectedQuery).Count()
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, err)
|
||||
assert.Equal(t, 0, n)
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Distinct(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).Distinct("", nil))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Explain(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).Explain(nil))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_For(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).For(nil, nil))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Hint(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Hint())
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Iter(t *testing.T) {
|
||||
assert.EqualValues(t, rejectedIter{}, new(rejectedQuery).Iter())
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Limit(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Limit(1))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_LogReplay(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.LogReplay())
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_MapReduce(t *testing.T) {
|
||||
info, err := new(rejectedQuery).MapReduce(nil, nil)
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, err)
|
||||
assert.Nil(t, info)
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_One(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).One(nil))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Prefetch(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Prefetch(1))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Select(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Select(nil))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_SetMaxScan(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.SetMaxScan(0))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_SetMaxTime(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.SetMaxTime(0))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Skip(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Skip(0))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Snapshot(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Snapshot())
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Sort(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Sort())
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Tail(t *testing.T) {
|
||||
assert.EqualValues(t, rejectedIter{}, new(rejectedQuery).Tail(0))
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultConcurrency = 50
|
||||
defaultTimeout = time.Second
|
||||
)
|
||||
|
||||
var sessionManager = syncx.NewResourceManager()
|
||||
|
||||
type concurrentSession struct {
|
||||
*mgo.Session
|
||||
limit syncx.TimeoutLimit
|
||||
}
|
||||
|
||||
func (cs *concurrentSession) Close() error {
|
||||
cs.Session.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func getConcurrentSession(url string) (*concurrentSession, error) {
|
||||
val, err := sessionManager.GetResource(url, func() (io.Closer, error) {
|
||||
mgoSession, err := mgo.Dial(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
concurrentSess := &concurrentSession{
|
||||
Session: mgoSession,
|
||||
limit: syncx.NewTimeoutLimit(defaultConcurrency),
|
||||
}
|
||||
|
||||
return concurrentSess, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return val.(*concurrentSession), nil
|
||||
}
|
||||
|
||||
func (cs *concurrentSession) putSession(session *mgo.Session) {
|
||||
if err := cs.limit.Return(); err != nil {
|
||||
logx.Error(err)
|
||||
}
|
||||
|
||||
// anyway, we need to close the session
|
||||
session.Close()
|
||||
}
|
||||
|
||||
func (cs *concurrentSession) takeSession(opts ...Option) (*mgo.Session, error) {
|
||||
o := defaultOptions()
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
}
|
||||
|
||||
if err := cs.limit.Borrow(o.timeout); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cs.Copy(), nil
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import "strings"
|
||||
|
||||
const mongoAddrSep = ","
|
||||
|
||||
// FormatAddr formats mongo hosts to a string.
|
||||
func FormatAddr(hosts []string) string {
|
||||
return strings.Join(hosts, mongoAddrSep)
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFormatAddrs(t *testing.T) {
|
||||
tests := []struct {
|
||||
addrs []string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
addrs: []string{"a", "b"},
|
||||
expect: "a,b",
|
||||
},
|
||||
{
|
||||
addrs: []string{"a", "b", "c"},
|
||||
expect: "a,b,c",
|
||||
},
|
||||
{
|
||||
addrs: []string{},
|
||||
expect: "",
|
||||
},
|
||||
{
|
||||
addrs: nil,
|
||||
expect: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
assert.Equal(t, test.expect, FormatAddr(test.addrs))
|
||||
}
|
||||
}
|
||||
@@ -1,199 +0,0 @@
|
||||
package mongoc
|
||||
|
||||
import (
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/zeromicro/go-zero/core/stores/cache"
|
||||
"github.com/zeromicro/go-zero/core/stores/mongo"
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNotFound is an alias of mgo.ErrNotFound.
|
||||
ErrNotFound = mgo.ErrNotFound
|
||||
|
||||
// can't use one SingleFlight per conn, because multiple conns may share the same cache key.
|
||||
singleFlight = syncx.NewSingleFlight()
|
||||
stats = cache.NewStat("mongoc")
|
||||
)
|
||||
|
||||
type (
|
||||
// QueryOption defines the method to customize a mongo query.
|
||||
QueryOption func(query mongo.Query) mongo.Query
|
||||
|
||||
// CachedCollection interface represents a mongo collection with cache.
|
||||
CachedCollection interface {
|
||||
Count(query interface{}) (int, error)
|
||||
DelCache(keys ...string) error
|
||||
FindAllNoCache(v, query interface{}, opts ...QueryOption) error
|
||||
FindOne(v interface{}, key string, query interface{}) error
|
||||
FindOneNoCache(v, query interface{}) error
|
||||
FindOneId(v interface{}, key string, id interface{}) error
|
||||
FindOneIdNoCache(v, id interface{}) error
|
||||
GetCache(key string, v interface{}) error
|
||||
Insert(docs ...interface{}) error
|
||||
Pipe(pipeline interface{}) mongo.Pipe
|
||||
Remove(selector interface{}, keys ...string) error
|
||||
RemoveNoCache(selector interface{}) error
|
||||
RemoveAll(selector interface{}, keys ...string) (*mgo.ChangeInfo, error)
|
||||
RemoveAllNoCache(selector interface{}) (*mgo.ChangeInfo, error)
|
||||
RemoveId(id interface{}, keys ...string) error
|
||||
RemoveIdNoCache(id interface{}) error
|
||||
SetCache(key string, v interface{}) error
|
||||
Update(selector, update interface{}, keys ...string) error
|
||||
UpdateNoCache(selector, update interface{}) error
|
||||
UpdateId(id, update interface{}, keys ...string) error
|
||||
UpdateIdNoCache(id, update interface{}) error
|
||||
Upsert(selector, update interface{}, keys ...string) (*mgo.ChangeInfo, error)
|
||||
UpsertNoCache(selector, update interface{}) (*mgo.ChangeInfo, error)
|
||||
}
|
||||
|
||||
cachedCollection struct {
|
||||
collection mongo.Collection
|
||||
cache cache.Cache
|
||||
}
|
||||
)
|
||||
|
||||
func newCollection(collection mongo.Collection, c cache.Cache) CachedCollection {
|
||||
return &cachedCollection{
|
||||
collection: collection,
|
||||
cache: c,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cachedCollection) Count(query interface{}) (int, error) {
|
||||
return c.collection.Find(query).Count()
|
||||
}
|
||||
|
||||
func (c *cachedCollection) DelCache(keys ...string) error {
|
||||
return c.cache.Del(keys...)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) FindAllNoCache(v, query interface{}, opts ...QueryOption) error {
|
||||
q := c.collection.Find(query)
|
||||
for _, opt := range opts {
|
||||
q = opt(q)
|
||||
}
|
||||
return q.All(v)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) FindOne(v interface{}, key string, query interface{}) error {
|
||||
return c.cache.Take(v, key, func(v interface{}) error {
|
||||
q := c.collection.Find(query)
|
||||
return q.One(v)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *cachedCollection) FindOneNoCache(v, query interface{}) error {
|
||||
q := c.collection.Find(query)
|
||||
return q.One(v)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) FindOneId(v interface{}, key string, id interface{}) error {
|
||||
return c.cache.Take(v, key, func(v interface{}) error {
|
||||
q := c.collection.FindId(id)
|
||||
return q.One(v)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *cachedCollection) FindOneIdNoCache(v, id interface{}) error {
|
||||
q := c.collection.FindId(id)
|
||||
return q.One(v)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) GetCache(key string, v interface{}) error {
|
||||
return c.cache.Get(key, v)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) Insert(docs ...interface{}) error {
|
||||
return c.collection.Insert(docs...)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) Pipe(pipeline interface{}) mongo.Pipe {
|
||||
return c.collection.Pipe(pipeline)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) Remove(selector interface{}, keys ...string) error {
|
||||
if err := c.RemoveNoCache(selector); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.DelCache(keys...)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) RemoveNoCache(selector interface{}) error {
|
||||
return c.collection.Remove(selector)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) RemoveAll(selector interface{}, keys ...string) (*mgo.ChangeInfo, error) {
|
||||
info, err := c.RemoveAllNoCache(selector)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := c.DelCache(keys...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
func (c *cachedCollection) RemoveAllNoCache(selector interface{}) (*mgo.ChangeInfo, error) {
|
||||
return c.collection.RemoveAll(selector)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) RemoveId(id interface{}, keys ...string) error {
|
||||
if err := c.RemoveIdNoCache(id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.DelCache(keys...)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) RemoveIdNoCache(id interface{}) error {
|
||||
return c.collection.RemoveId(id)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) SetCache(key string, v interface{}) error {
|
||||
return c.cache.Set(key, v)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) Update(selector, update interface{}, keys ...string) error {
|
||||
if err := c.UpdateNoCache(selector, update); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.DelCache(keys...)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) UpdateNoCache(selector, update interface{}) error {
|
||||
return c.collection.Update(selector, update)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) UpdateId(id, update interface{}, keys ...string) error {
|
||||
if err := c.UpdateIdNoCache(id, update); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.DelCache(keys...)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) UpdateIdNoCache(id, update interface{}) error {
|
||||
return c.collection.UpdateId(id, update)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) Upsert(selector, update interface{}, keys ...string) (*mgo.ChangeInfo, error) {
|
||||
info, err := c.UpsertNoCache(selector, update)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := c.DelCache(keys...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
func (c *cachedCollection) UpsertNoCache(selector, update interface{}) (*mgo.ChangeInfo, error) {
|
||||
return c.collection.Upsert(selector, update)
|
||||
}
|
||||
@@ -1,365 +0,0 @@
|
||||
package mongoc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/globalsign/mgo/bson"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/stat"
|
||||
"github.com/zeromicro/go-zero/core/stores/cache"
|
||||
"github.com/zeromicro/go-zero/core/stores/mongo"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis/redistest"
|
||||
)
|
||||
|
||||
const dummyCount = 10
|
||||
|
||||
func init() {
|
||||
stat.SetReporter(nil)
|
||||
}
|
||||
|
||||
func TestCollection_Count(t *testing.T) {
|
||||
resetStats()
|
||||
r, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean()
|
||||
|
||||
cach := cache.NewNode(r, singleFlight, stats, mgo.ErrNotFound)
|
||||
c := newCollection(dummyConn{}, cach)
|
||||
val, err := c.Count("any")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, dummyCount, val)
|
||||
|
||||
var value string
|
||||
assert.Nil(t, r.Set("any", `"foo"`))
|
||||
assert.Nil(t, c.GetCache("any", &value))
|
||||
assert.Equal(t, "foo", value)
|
||||
assert.Nil(t, c.DelCache("any"))
|
||||
|
||||
assert.Nil(t, c.SetCache("any", "bar"))
|
||||
assert.Nil(t, c.FindAllNoCache(&value, "any", func(query mongo.Query) mongo.Query {
|
||||
return query
|
||||
}))
|
||||
assert.Nil(t, c.FindOne(&value, "any", "foo"))
|
||||
assert.Equal(t, "bar", value)
|
||||
assert.Nil(t, c.DelCache("any"))
|
||||
c = newCollection(dummyConn{val: `"bar"`}, cach)
|
||||
assert.Nil(t, c.FindOne(&value, "any", "foo"))
|
||||
assert.Equal(t, "bar", value)
|
||||
assert.Nil(t, c.FindOneNoCache(&value, "foo"))
|
||||
assert.Equal(t, "bar", value)
|
||||
assert.Nil(t, c.FindOneId(&value, "anyone", "foo"))
|
||||
assert.Equal(t, "bar", value)
|
||||
assert.Nil(t, c.FindOneIdNoCache(&value, "foo"))
|
||||
assert.Equal(t, "bar", value)
|
||||
assert.Nil(t, c.Insert("foo"))
|
||||
assert.Nil(t, c.Pipe("foo"))
|
||||
assert.Nil(t, c.Remove("any"))
|
||||
assert.Nil(t, c.RemoveId("any"))
|
||||
_, err = c.RemoveAll("any")
|
||||
assert.Nil(t, err)
|
||||
assert.Nil(t, c.Update("foo", "bar"))
|
||||
assert.Nil(t, c.UpdateId("foo", "bar"))
|
||||
_, err = c.Upsert("foo", "bar")
|
||||
assert.Nil(t, err)
|
||||
|
||||
c = newCollection(dummyConn{
|
||||
val: `"bar"`,
|
||||
removeErr: errors.New("any"),
|
||||
}, cach)
|
||||
assert.NotNil(t, c.Remove("any"))
|
||||
_, err = c.RemoveAll("any", "bar")
|
||||
assert.NotNil(t, err)
|
||||
assert.NotNil(t, c.RemoveId("any"))
|
||||
|
||||
c = newCollection(dummyConn{
|
||||
val: `"bar"`,
|
||||
updateErr: errors.New("any"),
|
||||
}, cach)
|
||||
assert.NotNil(t, c.Update("foo", "bar"))
|
||||
assert.NotNil(t, c.UpdateId("foo", "bar"))
|
||||
_, err = c.Upsert("foo", "bar")
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestStat(t *testing.T) {
|
||||
resetStats()
|
||||
r, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean()
|
||||
|
||||
cach := cache.NewNode(r, singleFlight, stats, mgo.ErrNotFound)
|
||||
c := newCollection(dummyConn{}, cach).(*cachedCollection)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
var str string
|
||||
if err = c.cache.Take(&str, "name", func(v interface{}) error {
|
||||
*v.(*string) = "zero"
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, uint64(10), atomic.LoadUint64(&stats.Total))
|
||||
assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit))
|
||||
}
|
||||
|
||||
func TestStatCacheFails(t *testing.T) {
|
||||
resetStats()
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stdout)
|
||||
|
||||
r := redis.New("localhost:59999")
|
||||
cach := cache.NewNode(r, singleFlight, stats, mgo.ErrNotFound)
|
||||
c := newCollection(dummyConn{}, cach)
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
var str string
|
||||
err := c.FindOne(&str, "name", bson.M{})
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Total))
|
||||
assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.Hit))
|
||||
assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Miss))
|
||||
assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.DbFails))
|
||||
}
|
||||
|
||||
func TestStatDbFails(t *testing.T) {
|
||||
resetStats()
|
||||
r, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean()
|
||||
|
||||
cach := cache.NewNode(r, singleFlight, stats, mgo.ErrNotFound)
|
||||
c := newCollection(dummyConn{}, cach).(*cachedCollection)
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
var str string
|
||||
err := c.cache.Take(&str, "name", func(v interface{}) error {
|
||||
return errors.New("db failed")
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Total))
|
||||
assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.Hit))
|
||||
assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.DbFails))
|
||||
}
|
||||
|
||||
func TestStatFromMemory(t *testing.T) {
|
||||
resetStats()
|
||||
r, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean()
|
||||
|
||||
cach := cache.NewNode(r, singleFlight, stats, mgo.ErrNotFound)
|
||||
c := newCollection(dummyConn{}, cach).(*cachedCollection)
|
||||
|
||||
var all sync.WaitGroup
|
||||
var wait sync.WaitGroup
|
||||
all.Add(10)
|
||||
wait.Add(4)
|
||||
go func() {
|
||||
var str string
|
||||
if err := c.cache.Take(&str, "name", func(v interface{}) error {
|
||||
*v.(*string) = "zero"
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
wait.Wait()
|
||||
runtime.Gosched()
|
||||
all.Done()
|
||||
}()
|
||||
|
||||
for i := 0; i < 4; i++ {
|
||||
go func() {
|
||||
var str string
|
||||
wait.Done()
|
||||
if err := c.cache.Take(&str, "name", func(v interface{}) error {
|
||||
*v.(*string) = "zero"
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
all.Done()
|
||||
}()
|
||||
}
|
||||
for i := 0; i < 5; i++ {
|
||||
go func() {
|
||||
var str string
|
||||
if err := c.cache.Take(&str, "name", func(v interface{}) error {
|
||||
*v.(*string) = "zero"
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
all.Done()
|
||||
}()
|
||||
}
|
||||
all.Wait()
|
||||
|
||||
assert.Equal(t, uint64(10), atomic.LoadUint64(&stats.Total))
|
||||
assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit))
|
||||
}
|
||||
|
||||
func resetStats() {
|
||||
atomic.StoreUint64(&stats.Total, 0)
|
||||
atomic.StoreUint64(&stats.Hit, 0)
|
||||
atomic.StoreUint64(&stats.Miss, 0)
|
||||
atomic.StoreUint64(&stats.DbFails, 0)
|
||||
}
|
||||
|
||||
type dummyConn struct {
|
||||
val string
|
||||
removeErr error
|
||||
updateErr error
|
||||
}
|
||||
|
||||
func (c dummyConn) Find(query interface{}) mongo.Query {
|
||||
return dummyQuery{val: c.val}
|
||||
}
|
||||
|
||||
func (c dummyConn) FindId(id interface{}) mongo.Query {
|
||||
return dummyQuery{val: c.val}
|
||||
}
|
||||
|
||||
func (c dummyConn) Insert(docs ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c dummyConn) Remove(selector interface{}) error {
|
||||
return c.removeErr
|
||||
}
|
||||
|
||||
func (dummyConn) Pipe(pipeline interface{}) mongo.Pipe {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c dummyConn) RemoveAll(selector interface{}) (*mgo.ChangeInfo, error) {
|
||||
return nil, c.removeErr
|
||||
}
|
||||
|
||||
func (c dummyConn) RemoveId(id interface{}) error {
|
||||
return c.removeErr
|
||||
}
|
||||
|
||||
func (c dummyConn) Update(selector, update interface{}) error {
|
||||
return c.updateErr
|
||||
}
|
||||
|
||||
func (c dummyConn) UpdateId(id, update interface{}) error {
|
||||
return c.updateErr
|
||||
}
|
||||
|
||||
func (c dummyConn) Upsert(selector, update interface{}) (*mgo.ChangeInfo, error) {
|
||||
return nil, c.updateErr
|
||||
}
|
||||
|
||||
type dummyQuery struct {
|
||||
val string
|
||||
}
|
||||
|
||||
func (d dummyQuery) All(result interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummyQuery) Apply(change mgo.Change, result interface{}) (*mgo.ChangeInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d dummyQuery) Count() (int, error) {
|
||||
return dummyCount, nil
|
||||
}
|
||||
|
||||
func (d dummyQuery) Distinct(key string, result interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummyQuery) Explain(result interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummyQuery) For(result interface{}, f func() error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummyQuery) MapReduce(job *mgo.MapReduce, result interface{}) (*mgo.MapReduceInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d dummyQuery) One(result interface{}) error {
|
||||
return json.Unmarshal([]byte(d.val), result)
|
||||
}
|
||||
|
||||
func (d dummyQuery) Batch(n int) mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) Collation(collation *mgo.Collation) mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) Comment(comment string) mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) Hint(indexKey ...string) mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) Iter() mongo.Iter {
|
||||
return &mgo.Iter{}
|
||||
}
|
||||
|
||||
func (d dummyQuery) Limit(n int) mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) LogReplay() mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) Prefetch(p float64) mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) Select(selector interface{}) mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) SetMaxScan(n int) mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) SetMaxTime(duration time.Duration) mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) Skip(n int) mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) Snapshot() mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) Sort(fields ...string) mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) Tail(timeout time.Duration) mongo.Iter {
|
||||
return &mgo.Iter{}
|
||||
}
|
||||
@@ -1,273 +0,0 @@
|
||||
package mongoc
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/zeromicro/go-zero/core/stores/cache"
|
||||
"github.com/zeromicro/go-zero/core/stores/mongo"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
)
|
||||
|
||||
// A Model is a mongo model that built with cache capability.
|
||||
type Model struct {
|
||||
*mongo.Model
|
||||
cache cache.Cache
|
||||
generateCollection func(*mgo.Session) CachedCollection
|
||||
}
|
||||
|
||||
// MustNewNodeModel returns a Model with a cache node, exists on errors.
|
||||
func MustNewNodeModel(url, collection string, rds *redis.Redis, opts ...cache.Option) *Model {
|
||||
model, err := NewNodeModel(url, collection, rds, opts...)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return model
|
||||
}
|
||||
|
||||
// MustNewModel returns a Model with a cache cluster, exists on errors.
|
||||
func MustNewModel(url, collection string, c cache.CacheConf, opts ...cache.Option) *Model {
|
||||
model, err := NewModel(url, collection, c, opts...)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return model
|
||||
}
|
||||
|
||||
// NewModel returns a Model with a cache cluster.
|
||||
func NewModel(url, collection string, conf cache.CacheConf, opts ...cache.Option) (*Model, error) {
|
||||
c := cache.New(conf, singleFlight, stats, mgo.ErrNotFound, opts...)
|
||||
return NewModelWithCache(url, collection, c)
|
||||
}
|
||||
|
||||
// NewModelWithCache returns a Model with a custom cache.
|
||||
func NewModelWithCache(url, collection string, c cache.Cache) (*Model, error) {
|
||||
return createModel(url, collection, c, func(collection mongo.Collection) CachedCollection {
|
||||
return newCollection(collection, c)
|
||||
})
|
||||
}
|
||||
|
||||
// NewNodeModel returns a Model with a cache node.
|
||||
func NewNodeModel(url, collection string, rds *redis.Redis, opts ...cache.Option) (*Model, error) {
|
||||
c := cache.NewNode(rds, singleFlight, stats, mgo.ErrNotFound, opts...)
|
||||
return NewModelWithCache(url, collection, c)
|
||||
}
|
||||
|
||||
// Count returns the count of given query.
|
||||
func (mm *Model) Count(query interface{}) (int, error) {
|
||||
return mm.executeInt(func(c CachedCollection) (int, error) {
|
||||
return c.Count(query)
|
||||
})
|
||||
}
|
||||
|
||||
// DelCache deletes the cache with given keys.
|
||||
func (mm *Model) DelCache(keys ...string) error {
|
||||
return mm.cache.Del(keys...)
|
||||
}
|
||||
|
||||
// GetCache unmarshal the cache into v with given key.
|
||||
func (mm *Model) GetCache(key string, v interface{}) error {
|
||||
return mm.cache.Get(key, v)
|
||||
}
|
||||
|
||||
// GetCollection returns a cache collection.
|
||||
func (mm *Model) GetCollection(session *mgo.Session) CachedCollection {
|
||||
return mm.generateCollection(session)
|
||||
}
|
||||
|
||||
// FindAllNoCache finds all records without cache.
|
||||
func (mm *Model) FindAllNoCache(v, query interface{}, opts ...QueryOption) error {
|
||||
return mm.execute(func(c CachedCollection) error {
|
||||
return c.FindAllNoCache(v, query, opts...)
|
||||
})
|
||||
}
|
||||
|
||||
// FindOne unmarshals a record into v with given key and query.
|
||||
func (mm *Model) FindOne(v interface{}, key string, query interface{}) error {
|
||||
return mm.execute(func(c CachedCollection) error {
|
||||
return c.FindOne(v, key, query)
|
||||
})
|
||||
}
|
||||
|
||||
// FindOneNoCache unmarshals a record into v with query, without cache.
|
||||
func (mm *Model) FindOneNoCache(v, query interface{}) error {
|
||||
return mm.execute(func(c CachedCollection) error {
|
||||
return c.FindOneNoCache(v, query)
|
||||
})
|
||||
}
|
||||
|
||||
// FindOneId unmarshals a record into v with query.
|
||||
func (mm *Model) FindOneId(v interface{}, key string, id interface{}) error {
|
||||
return mm.execute(func(c CachedCollection) error {
|
||||
return c.FindOneId(v, key, id)
|
||||
})
|
||||
}
|
||||
|
||||
// FindOneIdNoCache unmarshals a record into v with query, without cache.
|
||||
func (mm *Model) FindOneIdNoCache(v, id interface{}) error {
|
||||
return mm.execute(func(c CachedCollection) error {
|
||||
return c.FindOneIdNoCache(v, id)
|
||||
})
|
||||
}
|
||||
|
||||
// Insert inserts docs.
|
||||
func (mm *Model) Insert(docs ...interface{}) error {
|
||||
return mm.execute(func(c CachedCollection) error {
|
||||
return c.Insert(docs...)
|
||||
})
|
||||
}
|
||||
|
||||
// Pipe returns a mongo pipe with given pipeline.
|
||||
func (mm *Model) Pipe(pipeline interface{}) (mongo.Pipe, error) {
|
||||
return mm.pipe(func(c CachedCollection) mongo.Pipe {
|
||||
return c.Pipe(pipeline)
|
||||
})
|
||||
}
|
||||
|
||||
// Remove removes a record with given selector, and remove it from cache with given keys.
|
||||
func (mm *Model) Remove(selector interface{}, keys ...string) error {
|
||||
return mm.execute(func(c CachedCollection) error {
|
||||
return c.Remove(selector, keys...)
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveNoCache removes a record with given selector.
|
||||
func (mm *Model) RemoveNoCache(selector interface{}) error {
|
||||
return mm.execute(func(c CachedCollection) error {
|
||||
return c.RemoveNoCache(selector)
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveAll removes all records with given selector, and removes cache with given keys.
|
||||
func (mm *Model) RemoveAll(selector interface{}, keys ...string) (*mgo.ChangeInfo, error) {
|
||||
return mm.change(func(c CachedCollection) (*mgo.ChangeInfo, error) {
|
||||
return c.RemoveAll(selector, keys...)
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveAllNoCache removes all records with given selector, and returns a mgo.ChangeInfo.
|
||||
func (mm *Model) RemoveAllNoCache(selector interface{}) (*mgo.ChangeInfo, error) {
|
||||
return mm.change(func(c CachedCollection) (*mgo.ChangeInfo, error) {
|
||||
return c.RemoveAllNoCache(selector)
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveId removes a record with given id, and removes cache with given keys.
|
||||
func (mm *Model) RemoveId(id interface{}, keys ...string) error {
|
||||
return mm.execute(func(c CachedCollection) error {
|
||||
return c.RemoveId(id, keys...)
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveIdNoCache removes a record with given id.
|
||||
func (mm *Model) RemoveIdNoCache(id interface{}) error {
|
||||
return mm.execute(func(c CachedCollection) error {
|
||||
return c.RemoveIdNoCache(id)
|
||||
})
|
||||
}
|
||||
|
||||
// SetCache sets the cache with given key and value.
|
||||
func (mm *Model) SetCache(key string, v interface{}) error {
|
||||
return mm.cache.Set(key, v)
|
||||
}
|
||||
|
||||
// Update updates the record with given selector, and delete cache with given keys.
|
||||
func (mm *Model) Update(selector, update interface{}, keys ...string) error {
|
||||
return mm.execute(func(c CachedCollection) error {
|
||||
return c.Update(selector, update, keys...)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateNoCache updates the record with given selector.
|
||||
func (mm *Model) UpdateNoCache(selector, update interface{}) error {
|
||||
return mm.execute(func(c CachedCollection) error {
|
||||
return c.UpdateNoCache(selector, update)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateId updates the record with given id, and delete cache with given keys.
|
||||
func (mm *Model) UpdateId(id, update interface{}, keys ...string) error {
|
||||
return mm.execute(func(c CachedCollection) error {
|
||||
return c.UpdateId(id, update, keys...)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateIdNoCache updates the record with given id.
|
||||
func (mm *Model) UpdateIdNoCache(id, update interface{}) error {
|
||||
return mm.execute(func(c CachedCollection) error {
|
||||
return c.UpdateIdNoCache(id, update)
|
||||
})
|
||||
}
|
||||
|
||||
// Upsert upserts a record with given selector, and delete cache with given keys.
|
||||
func (mm *Model) Upsert(selector, update interface{}, keys ...string) (*mgo.ChangeInfo, error) {
|
||||
return mm.change(func(c CachedCollection) (*mgo.ChangeInfo, error) {
|
||||
return c.Upsert(selector, update, keys...)
|
||||
})
|
||||
}
|
||||
|
||||
// UpsertNoCache upserts a record with given selector.
|
||||
func (mm *Model) UpsertNoCache(selector, update interface{}) (*mgo.ChangeInfo, error) {
|
||||
return mm.change(func(c CachedCollection) (*mgo.ChangeInfo, error) {
|
||||
return c.UpsertNoCache(selector, update)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) change(fn func(c CachedCollection) (*mgo.ChangeInfo, error)) (*mgo.ChangeInfo, error) {
|
||||
session, err := mm.TakeSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer mm.PutSession(session)
|
||||
|
||||
return fn(mm.GetCollection(session))
|
||||
}
|
||||
|
||||
func (mm *Model) execute(fn func(c CachedCollection) error) error {
|
||||
session, err := mm.TakeSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer mm.PutSession(session)
|
||||
|
||||
return fn(mm.GetCollection(session))
|
||||
}
|
||||
|
||||
func (mm *Model) executeInt(fn func(c CachedCollection) (int, error)) (int, error) {
|
||||
session, err := mm.TakeSession()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer mm.PutSession(session)
|
||||
|
||||
return fn(mm.GetCollection(session))
|
||||
}
|
||||
|
||||
func (mm *Model) pipe(fn func(c CachedCollection) mongo.Pipe) (mongo.Pipe, error) {
|
||||
session, err := mm.TakeSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer mm.PutSession(session)
|
||||
|
||||
return fn(mm.GetCollection(session)), nil
|
||||
}
|
||||
|
||||
func createModel(url, collection string, c cache.Cache,
|
||||
create func(mongo.Collection) CachedCollection) (*Model, error) {
|
||||
model, err := mongo.NewModel(url, collection)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Model{
|
||||
Model: model,
|
||||
cache: c,
|
||||
generateCollection: func(session *mgo.Session) CachedCollection {
|
||||
collection := model.GetCollection(session)
|
||||
return create(collection)
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
@@ -9,6 +9,8 @@ var (
|
||||
ErrEmptyType = errors.New("empty redis type")
|
||||
// ErrEmptyKey is an error that indicates no redis key is set.
|
||||
ErrEmptyKey = errors.New("empty redis key")
|
||||
// ErrPing is an error that indicates ping failed.
|
||||
ErrPing = errors.New("ping redis failed")
|
||||
)
|
||||
|
||||
type (
|
||||
@@ -28,6 +30,7 @@ type (
|
||||
)
|
||||
|
||||
// NewRedis returns a Redis.
|
||||
// Deprecated: use MustNewRedis or NewRedis instead.
|
||||
func (rc RedisConf) NewRedis() *Redis {
|
||||
var opts []Option
|
||||
if rc.Type == ClusterType {
|
||||
|
||||
@@ -40,6 +40,16 @@ func TestRedisConf(t *testing.T) {
|
||||
},
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
RedisConf: RedisConf{
|
||||
Host: "localhost:6379",
|
||||
Type: ClusterType,
|
||||
Pass: "pwd",
|
||||
Tls: true,
|
||||
},
|
||||
ok: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/mapping"
|
||||
"github.com/zeromicro/go-zero/core/timex"
|
||||
"github.com/zeromicro/go-zero/core/trace"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
@@ -25,15 +24,13 @@ const spanName = "redis"
|
||||
|
||||
var (
|
||||
startTimeKey = contextKey("startTime")
|
||||
durationHook = hook{tracer: otel.GetTracerProvider().Tracer(trace.TraceName)}
|
||||
durationHook = hook{}
|
||||
redisCmdsAttributeKey = attribute.Key("redis.cmds")
|
||||
)
|
||||
|
||||
type (
|
||||
contextKey string
|
||||
hook struct {
|
||||
tracer oteltrace.Tracer
|
||||
}
|
||||
hook struct{}
|
||||
)
|
||||
|
||||
func (h hook) BeforeProcess(ctx context.Context, cmd red.Cmder) (context.Context, error) {
|
||||
@@ -155,7 +152,9 @@ func logDuration(ctx context.Context, cmds []red.Cmder, duration time.Duration)
|
||||
}
|
||||
|
||||
func (h hook) startSpan(ctx context.Context, cmds ...red.Cmder) context.Context {
|
||||
ctx, span := h.tracer.Start(ctx,
|
||||
tracer := trace.TracerFromContext(ctx)
|
||||
|
||||
ctx, span := tracer.Start(ctx,
|
||||
spanName,
|
||||
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
|
||||
)
|
||||
|
||||
@@ -91,13 +91,16 @@ func TestHookProcessPipelineCase1(t *testing.T) {
|
||||
log.SetOutput(&buf)
|
||||
defer log.SetOutput(writer)
|
||||
|
||||
_, err := durationHook.BeforeProcessPipeline(context.Background(), []red.Cmder{})
|
||||
assert.NoError(t, err)
|
||||
ctx, err := durationHook.BeforeProcessPipeline(context.Background(), []red.Cmder{
|
||||
red.NewCmd(context.Background()),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name())
|
||||
|
||||
assert.Nil(t, durationHook.AfterProcessPipeline(ctx, []red.Cmder{
|
||||
assert.NoError(t, durationHook.AfterProcessPipeline(ctx, []red.Cmder{}))
|
||||
assert.NoError(t, durationHook.AfterProcessPipeline(ctx, []red.Cmder{
|
||||
red.NewCmd(context.Background()),
|
||||
}))
|
||||
assert.False(t, strings.Contains(buf.String(), "slow"))
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -42,6 +43,12 @@ type (
|
||||
Score int64
|
||||
}
|
||||
|
||||
// A FloatPair is a key/pair for float set used in redis zet.
|
||||
FloatPair struct {
|
||||
Key string
|
||||
Score float64
|
||||
}
|
||||
|
||||
// Redis defines a redis node/cluster. It is thread-safe.
|
||||
Redis struct {
|
||||
Addr string
|
||||
@@ -80,7 +87,46 @@ type (
|
||||
)
|
||||
|
||||
// New returns a Redis with given options.
|
||||
// Deprecated: use MustNewRedis or NewRedis instead.
|
||||
func New(addr string, opts ...Option) *Redis {
|
||||
return newRedis(addr, opts...)
|
||||
}
|
||||
|
||||
// MustNewRedis returns a Redis with given options.
|
||||
func MustNewRedis(conf RedisConf, opts ...Option) *Redis {
|
||||
rds, err := NewRedis(conf, opts...)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return rds
|
||||
}
|
||||
|
||||
// NewRedis returns a Redis with given options.
|
||||
func NewRedis(conf RedisConf, opts ...Option) (*Redis, error) {
|
||||
if err := conf.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if conf.Type == ClusterType {
|
||||
opts = append([]Option{Cluster()}, opts...)
|
||||
}
|
||||
if len(conf.Pass) > 0 {
|
||||
opts = append([]Option{WithPass(conf.Pass)}, opts...)
|
||||
}
|
||||
if conf.Tls {
|
||||
opts = append([]Option{WithTLS()}, opts...)
|
||||
}
|
||||
|
||||
rds := newRedis(conf.Host, opts...)
|
||||
if !rds.Ping() {
|
||||
return nil, ErrPing
|
||||
}
|
||||
|
||||
return rds, nil
|
||||
}
|
||||
|
||||
func newRedis(addr string, opts ...Option) *Redis {
|
||||
r := &Redis{
|
||||
Addr: addr,
|
||||
Type: NodeType,
|
||||
@@ -786,6 +832,28 @@ func (s *Redis) HincrbyCtx(ctx context.Context, key, field string, increment int
|
||||
return
|
||||
}
|
||||
|
||||
// HincrbyFloat is the implementation of redis hincrby command.
|
||||
func (s *Redis) HincrbyFloat(key, field string, increment float64) (float64, error) {
|
||||
return s.HincrbyFloatCtx(context.Background(), key, field, increment)
|
||||
}
|
||||
|
||||
// HincrbyFloatCtx is the implementation of redis hincrby command.
|
||||
func (s *Redis) HincrbyFloatCtx(ctx context.Context, key, field string, increment float64) (val float64, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
val, err = conn.HIncrByFloat(ctx, key, field, increment).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Hkeys is the implementation of redis hkeys command.
|
||||
func (s *Redis) Hkeys(key string) ([]string, error) {
|
||||
return s.HkeysCtx(context.Background(), key)
|
||||
@@ -997,6 +1065,26 @@ func (s *Redis) IncrbyCtx(ctx context.Context, key string, increment int64) (val
|
||||
return
|
||||
}
|
||||
|
||||
// IncrbyFloat is the implementation of redis incrby command.
|
||||
func (s *Redis) IncrbyFloat(key string, increment float64) (float64, error) {
|
||||
return s.IncrbyFloatCtx(context.Background(), key, increment)
|
||||
}
|
||||
|
||||
// IncrbyFloatCtx is the implementation of redis incrby command.
|
||||
func (s *Redis) IncrbyFloatCtx(ctx context.Context, key string, increment float64) (val float64, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val, err = conn.IncrByFloat(ctx, key, increment).Result()
|
||||
return err
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Keys is the implementation of redis keys command.
|
||||
func (s *Redis) Keys(pattern string) ([]string, error) {
|
||||
return s.KeysCtx(context.Background(), pattern)
|
||||
@@ -1849,17 +1937,17 @@ func (s *Redis) Zadd(key string, score int64, value string) (bool, error) {
|
||||
return s.ZaddCtx(context.Background(), key, score, value)
|
||||
}
|
||||
|
||||
// ZaddFloat is the implementation of redis zadd command.
|
||||
func (s *Redis) ZaddFloat(key string, score float64, value string) (bool, error) {
|
||||
return s.ZaddFloatCtx(context.Background(), key, score, value)
|
||||
}
|
||||
|
||||
// ZaddCtx is the implementation of redis zadd command.
|
||||
func (s *Redis) ZaddCtx(ctx context.Context, key string, score int64, value string) (
|
||||
val bool, err error) {
|
||||
return s.ZaddFloatCtx(ctx, key, float64(score), value)
|
||||
}
|
||||
|
||||
// ZaddFloat is the implementation of redis zadd command.
|
||||
func (s *Redis) ZaddFloat(key string, score float64, value string) (bool, error) {
|
||||
return s.ZaddFloatCtx(context.Background(), key, score, value)
|
||||
}
|
||||
|
||||
// ZaddFloatCtx is the implementation of redis zadd command.
|
||||
func (s *Redis) ZaddFloatCtx(ctx context.Context, key string, score float64, value string) (
|
||||
val bool, err error) {
|
||||
@@ -2017,6 +2105,47 @@ func (s *Redis) ZscoreCtx(ctx context.Context, key, value string) (val int64, er
|
||||
return
|
||||
}
|
||||
|
||||
// ZscoreByFloat is the implementation of redis zscore command score by float.
|
||||
func (s *Redis) ZscoreByFloat(key, value string) (float64, error) {
|
||||
return s.ZscoreByFloatCtx(context.Background(), key, value)
|
||||
}
|
||||
|
||||
// ZscoreByFloatCtx is the implementation of redis zscore command score by float.
|
||||
func (s *Redis) ZscoreByFloatCtx(ctx context.Context, key, value string) (val float64, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
val, err = conn.ZScore(ctx, key, value).Result()
|
||||
return err
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Zscan is the implementation of redis zscan command.
|
||||
func (s *Redis) Zscan(key string, cursor uint64, match string, count int64) (
|
||||
keys []string, cur uint64, err error) {
|
||||
return s.ZscanCtx(context.Background(), key, cursor, match, count)
|
||||
}
|
||||
|
||||
// ZscanCtx is the implementation of redis zscan command.
|
||||
func (s *Redis) ZscanCtx(ctx context.Context, key string, cursor uint64, match string, count int64) (
|
||||
keys []string, cur uint64, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
keys, cur, err = conn.ZScan(ctx, key, cursor, match, count).Result()
|
||||
return err
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Zrank is the implementation of redis zrank command.
|
||||
func (s *Redis) Zrank(key, field string) (int64, error) {
|
||||
return s.ZrankCtx(context.Background(), key, field)
|
||||
@@ -2162,13 +2291,52 @@ func (s *Redis) ZrangeWithScoresCtx(ctx context.Context, key string, start, stop
|
||||
return
|
||||
}
|
||||
|
||||
// ZrangeWithScoresByFloat is the implementation of redis zrange command with scores by float64.
|
||||
func (s *Redis) ZrangeWithScoresByFloat(key string, start, stop int64) ([]FloatPair, error) {
|
||||
return s.ZrangeWithScoresByFloatCtx(context.Background(), key, start, stop)
|
||||
}
|
||||
|
||||
// ZrangeWithScoresByFloatCtx is the implementation of redis zrange command with scores by float64.
|
||||
func (s *Redis) ZrangeWithScoresByFloatCtx(ctx context.Context, key string, start, stop int64) (
|
||||
val []FloatPair, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v, err := conn.ZRangeWithScores(ctx, key, start, stop).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val = toFloatPairs(v)
|
||||
return nil
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// ZRevRangeWithScores is the implementation of redis zrevrange command with scores.
|
||||
// Deprecated: use ZrevrangeWithScores instead.
|
||||
func (s *Redis) ZRevRangeWithScores(key string, start, stop int64) ([]Pair, error) {
|
||||
return s.ZRevRangeWithScoresCtx(context.Background(), key, start, stop)
|
||||
return s.ZrevrangeWithScoresCtx(context.Background(), key, start, stop)
|
||||
}
|
||||
|
||||
// ZrevrangeWithScores is the implementation of redis zrevrange command with scores.
|
||||
func (s *Redis) ZrevrangeWithScores(key string, start, stop int64) ([]Pair, error) {
|
||||
return s.ZrevrangeWithScoresCtx(context.Background(), key, start, stop)
|
||||
}
|
||||
|
||||
// ZRevRangeWithScoresCtx is the implementation of redis zrevrange command with scores.
|
||||
// Deprecated: use ZrevrangeWithScoresCtx instead.
|
||||
func (s *Redis) ZRevRangeWithScoresCtx(ctx context.Context, key string, start, stop int64) (
|
||||
val []Pair, err error) {
|
||||
return s.ZrevrangeWithScoresCtx(ctx, key, start, stop)
|
||||
}
|
||||
|
||||
// ZrevrangeWithScoresCtx is the implementation of redis zrevrange command with scores.
|
||||
func (s *Redis) ZrevrangeWithScoresCtx(ctx context.Context, key string, start, stop int64) (
|
||||
val []Pair, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
@@ -2188,6 +2356,45 @@ func (s *Redis) ZRevRangeWithScoresCtx(ctx context.Context, key string, start, s
|
||||
return
|
||||
}
|
||||
|
||||
// ZRevRangeWithScoresByFloat is the implementation of redis zrevrange command with scores by float.
|
||||
// Deprecated: use ZrevrangeWithScoresByFloat instead.
|
||||
func (s *Redis) ZRevRangeWithScoresByFloat(key string, start, stop int64) ([]FloatPair, error) {
|
||||
return s.ZrevrangeWithScoresByFloatCtx(context.Background(), key, start, stop)
|
||||
}
|
||||
|
||||
// ZrevrangeWithScoresByFloat is the implementation of redis zrevrange command with scores by float.
|
||||
func (s *Redis) ZrevrangeWithScoresByFloat(key string, start, stop int64) ([]FloatPair, error) {
|
||||
return s.ZrevrangeWithScoresByFloatCtx(context.Background(), key, start, stop)
|
||||
}
|
||||
|
||||
// ZRevRangeWithScoresByFloatCtx is the implementation of redis zrevrange command with scores by float.
|
||||
// Deprecated: use ZrevrangeWithScoresByFloatCtx instead.
|
||||
func (s *Redis) ZRevRangeWithScoresByFloatCtx(ctx context.Context, key string, start, stop int64) (
|
||||
val []FloatPair, err error) {
|
||||
return s.ZrevrangeWithScoresByFloatCtx(ctx, key, start, stop)
|
||||
}
|
||||
|
||||
// ZrevrangeWithScoresByFloatCtx is the implementation of redis zrevrange command with scores by float.
|
||||
func (s *Redis) ZrevrangeWithScoresByFloatCtx(ctx context.Context, key string, start, stop int64) (
|
||||
val []FloatPair, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v, err := conn.ZRevRangeWithScores(ctx, key, start, stop).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val = toFloatPairs(v)
|
||||
return nil
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// ZrangebyscoreWithScores is the implementation of redis zrangebyscore command with scores.
|
||||
func (s *Redis) ZrangebyscoreWithScores(key string, start, stop int64) ([]Pair, error) {
|
||||
return s.ZrangebyscoreWithScoresCtx(context.Background(), key, start, stop)
|
||||
@@ -2217,6 +2424,35 @@ func (s *Redis) ZrangebyscoreWithScoresCtx(ctx context.Context, key string, star
|
||||
return
|
||||
}
|
||||
|
||||
// ZrangebyscoreWithScoresByFloat is the implementation of redis zrangebyscore command with scores by float.
|
||||
func (s *Redis) ZrangebyscoreWithScoresByFloat(key string, start, stop float64) ([]FloatPair, error) {
|
||||
return s.ZrangebyscoreWithScoresByFloatCtx(context.Background(), key, start, stop)
|
||||
}
|
||||
|
||||
// ZrangebyscoreWithScoresByFloatCtx is the implementation of redis zrangebyscore command with scores by float.
|
||||
func (s *Redis) ZrangebyscoreWithScoresByFloatCtx(ctx context.Context, key string, start, stop float64) (
|
||||
val []FloatPair, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v, err := conn.ZRangeByScoreWithScores(ctx, key, &red.ZRangeBy{
|
||||
Min: strconv.FormatFloat(start, 'f', -1, 64),
|
||||
Max: strconv.FormatFloat(stop, 'f', -1, 64),
|
||||
}).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val = toFloatPairs(v)
|
||||
return nil
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// ZrangebyscoreWithScoresAndLimit is the implementation of redis zrangebyscore command
|
||||
// with scores and limit.
|
||||
func (s *Redis) ZrangebyscoreWithScoresAndLimit(key string, start, stop int64,
|
||||
@@ -2255,6 +2491,45 @@ func (s *Redis) ZrangebyscoreWithScoresAndLimitCtx(ctx context.Context, key stri
|
||||
return
|
||||
}
|
||||
|
||||
// ZrangebyscoreWithScoresByFloatAndLimit is the implementation of redis zrangebyscore command
|
||||
// with scores by float and limit.
|
||||
func (s *Redis) ZrangebyscoreWithScoresByFloatAndLimit(key string, start, stop float64,
|
||||
page, size int) ([]FloatPair, error) {
|
||||
return s.ZrangebyscoreWithScoresByFloatAndLimitCtx(context.Background(),
|
||||
key, start, stop, page, size)
|
||||
}
|
||||
|
||||
// ZrangebyscoreWithScoresByFloatAndLimitCtx is the implementation of redis zrangebyscore command
|
||||
// with scores by float and limit.
|
||||
func (s *Redis) ZrangebyscoreWithScoresByFloatAndLimitCtx(ctx context.Context, key string, start,
|
||||
stop float64, page, size int) (val []FloatPair, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
if size <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v, err := conn.ZRangeByScoreWithScores(ctx, key, &red.ZRangeBy{
|
||||
Min: strconv.FormatFloat(start, 'f', -1, 64),
|
||||
Max: strconv.FormatFloat(stop, 'f', -1, 64),
|
||||
Offset: int64(page * size),
|
||||
Count: int64(size),
|
||||
}).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val = toFloatPairs(v)
|
||||
return nil
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Zrevrange is the implementation of redis zrevrange command.
|
||||
func (s *Redis) Zrevrange(key string, start, stop int64) ([]string, error) {
|
||||
return s.ZrevrangeCtx(context.Background(), key, start, stop)
|
||||
@@ -2305,11 +2580,42 @@ func (s *Redis) ZrevrangebyscoreWithScoresCtx(ctx context.Context, key string, s
|
||||
return
|
||||
}
|
||||
|
||||
// ZrevrangebyscoreWithScoresByFloat is the implementation of redis zrevrangebyscore command with scores by float.
|
||||
func (s *Redis) ZrevrangebyscoreWithScoresByFloat(key string, start, stop float64) (
|
||||
[]FloatPair, error) {
|
||||
return s.ZrevrangebyscoreWithScoresByFloatCtx(context.Background(), key, start, stop)
|
||||
}
|
||||
|
||||
// ZrevrangebyscoreWithScoresByFloatCtx is the implementation of redis zrevrangebyscore command with scores by float.
|
||||
func (s *Redis) ZrevrangebyscoreWithScoresByFloatCtx(ctx context.Context, key string,
|
||||
start, stop float64) (val []FloatPair, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v, err := conn.ZRevRangeByScoreWithScores(ctx, key, &red.ZRangeBy{
|
||||
Min: strconv.FormatFloat(start, 'f', -1, 64),
|
||||
Max: strconv.FormatFloat(stop, 'f', -1, 64),
|
||||
}).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val = toFloatPairs(v)
|
||||
return nil
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// ZrevrangebyscoreWithScoresAndLimit is the implementation of redis zrevrangebyscore command
|
||||
// with scores and limit.
|
||||
func (s *Redis) ZrevrangebyscoreWithScoresAndLimit(key string, start, stop int64,
|
||||
page, size int) ([]Pair, error) {
|
||||
return s.ZrevrangebyscoreWithScoresAndLimitCtx(context.Background(), key, start, stop, page, size)
|
||||
return s.ZrevrangebyscoreWithScoresAndLimitCtx(context.Background(),
|
||||
key, start, stop, page, size)
|
||||
}
|
||||
|
||||
// ZrevrangebyscoreWithScoresAndLimitCtx is the implementation of redis zrevrangebyscore command
|
||||
@@ -2343,6 +2649,45 @@ func (s *Redis) ZrevrangebyscoreWithScoresAndLimitCtx(ctx context.Context, key s
|
||||
return
|
||||
}
|
||||
|
||||
// ZrevrangebyscoreWithScoresByFloatAndLimit is the implementation of redis zrevrangebyscore command
|
||||
// with scores by float and limit.
|
||||
func (s *Redis) ZrevrangebyscoreWithScoresByFloatAndLimit(key string, start, stop float64,
|
||||
page, size int) ([]FloatPair, error) {
|
||||
return s.ZrevrangebyscoreWithScoresByFloatAndLimitCtx(context.Background(),
|
||||
key, start, stop, page, size)
|
||||
}
|
||||
|
||||
// ZrevrangebyscoreWithScoresByFloatAndLimitCtx is the implementation of redis zrevrangebyscore command
|
||||
// with scores by float and limit.
|
||||
func (s *Redis) ZrevrangebyscoreWithScoresByFloatAndLimitCtx(ctx context.Context, key string,
|
||||
start, stop float64, page, size int) (val []FloatPair, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
if size <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v, err := conn.ZRevRangeByScoreWithScores(ctx, key, &red.ZRangeBy{
|
||||
Min: strconv.FormatFloat(start, 'f', -1, 64),
|
||||
Max: strconv.FormatFloat(stop, 'f', -1, 64),
|
||||
Offset: int64(page * size),
|
||||
Count: int64(size),
|
||||
}).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val = toFloatPairs(v)
|
||||
return nil
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Zrevrank is the implementation of redis zrevrank command.
|
||||
func (s *Redis) Zrevrank(key, field string) (int64, error) {
|
||||
return s.ZrevrankCtx(context.Background(), key, field)
|
||||
@@ -2444,19 +2789,43 @@ func toPairs(vals []red.Z) []Pair {
|
||||
return pairs
|
||||
}
|
||||
|
||||
func toStrings(vals []interface{}) []string {
|
||||
ret := make([]string, len(vals))
|
||||
func toFloatPairs(vals []red.Z) []FloatPair {
|
||||
pairs := make([]FloatPair, len(vals))
|
||||
|
||||
for i, val := range vals {
|
||||
if val == nil {
|
||||
ret[i] = ""
|
||||
} else {
|
||||
switch val := val.(type) {
|
||||
case string:
|
||||
ret[i] = val
|
||||
default:
|
||||
ret[i] = mapping.Repr(val)
|
||||
switch member := val.Member.(type) {
|
||||
case string:
|
||||
pairs[i] = FloatPair{
|
||||
Key: member,
|
||||
Score: val.Score,
|
||||
}
|
||||
default:
|
||||
pairs[i] = FloatPair{
|
||||
Key: mapping.Repr(val.Member),
|
||||
Score: val.Score,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return pairs
|
||||
}
|
||||
|
||||
func toStrings(vals []interface{}) []string {
|
||||
ret := make([]string, len(vals))
|
||||
|
||||
for i, val := range vals {
|
||||
if val == nil {
|
||||
ret[i] = ""
|
||||
continue
|
||||
}
|
||||
|
||||
switch val := val.(type) {
|
||||
case string:
|
||||
ret[i] = val
|
||||
default:
|
||||
ret[i] = mapping.Repr(val)
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -8,12 +8,39 @@ import (
|
||||
)
|
||||
|
||||
func TestBlockingNode(t *testing.T) {
|
||||
r, err := miniredis.Run()
|
||||
assert.Nil(t, err)
|
||||
node, err := CreateBlockingNode(New(r.Addr()))
|
||||
assert.Nil(t, err)
|
||||
node.Close()
|
||||
node, err = CreateBlockingNode(New(r.Addr(), Cluster()))
|
||||
assert.Nil(t, err)
|
||||
node.Close()
|
||||
t.Run("test blocking node", func(t *testing.T) {
|
||||
r, err := miniredis.Run()
|
||||
assert.NoError(t, err)
|
||||
defer r.Close()
|
||||
|
||||
node, err := CreateBlockingNode(New(r.Addr()))
|
||||
assert.NoError(t, err)
|
||||
node.Close()
|
||||
// close again to make sure it's safe
|
||||
assert.NotPanics(t, func() {
|
||||
node.Close()
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("test blocking node with cluster", func(t *testing.T) {
|
||||
r, err := miniredis.Run()
|
||||
assert.NoError(t, err)
|
||||
defer r.Close()
|
||||
|
||||
node, err := CreateBlockingNode(New(r.Addr(), Cluster(), WithTLS()))
|
||||
assert.NoError(t, err)
|
||||
node.Close()
|
||||
assert.NotPanics(t, func() {
|
||||
node.Close()
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("test blocking node with bad type", func(t *testing.T) {
|
||||
r, err := miniredis.Run()
|
||||
assert.NoError(t, err)
|
||||
defer r.Close()
|
||||
|
||||
_, err = CreateBlockingNode(New(r.Addr(), badType()))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -52,14 +52,11 @@ func TestSqlConn(t *testing.T) {
|
||||
}
|
||||
|
||||
func buildConn() (mock sqlmock.Sqlmock, err error) {
|
||||
_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
|
||||
connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
|
||||
var db *sql.DB
|
||||
var err error
|
||||
db, mock, err = sqlmock.New()
|
||||
return &pingedDB{
|
||||
DB: db,
|
||||
}, err
|
||||
return db, err
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package sqlx
|
||||
import (
|
||||
"database/sql"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
@@ -17,43 +16,29 @@ const (
|
||||
|
||||
var connManager = syncx.NewResourceManager()
|
||||
|
||||
type pingedDB struct {
|
||||
*sql.DB
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func getCachedSqlConn(driverName, server string) (*pingedDB, error) {
|
||||
func getCachedSqlConn(driverName, server string) (*sql.DB, error) {
|
||||
val, err := connManager.GetResource(server, func() (io.Closer, error) {
|
||||
conn, err := newDBConnection(driverName, server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pingedDB{
|
||||
DB: conn,
|
||||
}, nil
|
||||
return conn, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return val.(*pingedDB), nil
|
||||
return val.(*sql.DB), nil
|
||||
}
|
||||
|
||||
func getSqlConn(driverName, server string) (*sql.DB, error) {
|
||||
pdb, err := getCachedSqlConn(driverName, server)
|
||||
conn, err := getCachedSqlConn(driverName, server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pdb.once.Do(func() {
|
||||
err = pdb.Ping()
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pdb.DB, nil
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func newDBConnection(driverName, datasource string) (*sql.DB, error) {
|
||||
@@ -70,5 +55,10 @@ func newDBConnection(driverName, datasource string) (*sql.DB, error) {
|
||||
conn.SetMaxOpenConns(maxOpenConns)
|
||||
conn.SetConnMaxLifetime(maxLifetime)
|
||||
|
||||
if err := conn.Ping(); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/trace"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
@@ -14,11 +13,8 @@ import (
|
||||
var sqlAttributeKey = attribute.Key("sql.method")
|
||||
|
||||
func startSpan(ctx context.Context, method string) (context.Context, oteltrace.Span) {
|
||||
tracer := otel.GetTracerProvider().Tracer(trace.TraceName)
|
||||
start, span := tracer.Start(ctx,
|
||||
spanName,
|
||||
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
|
||||
)
|
||||
tracer := trace.TracerFromContext(ctx)
|
||||
start, span := tracer.Start(ctx, spanName, oteltrace.WithSpanKind(oteltrace.SpanKindClient))
|
||||
span.SetAttributes(sqlAttributeKey.String(method))
|
||||
|
||||
return start, span
|
||||
|
||||
@@ -66,7 +66,7 @@ func format(query string, args ...interface{}) (string, error) {
|
||||
switch ch {
|
||||
case '?':
|
||||
if argIndex >= numArgs {
|
||||
return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
|
||||
return "", fmt.Errorf("%d ? in sql, but less arguments provided", argIndex)
|
||||
}
|
||||
|
||||
writeValue(&b, args[argIndex])
|
||||
@@ -93,7 +93,7 @@ func format(query string, args ...interface{}) (string, error) {
|
||||
|
||||
index--
|
||||
if index < 0 || numArgs <= index {
|
||||
return "", fmt.Errorf("error: wrong index %d in sql", index)
|
||||
return "", fmt.Errorf("wrong index %d in sql", index)
|
||||
}
|
||||
|
||||
writeValue(&b, args[index])
|
||||
@@ -124,7 +124,7 @@ func format(query string, args ...interface{}) (string, error) {
|
||||
}
|
||||
|
||||
if argIndex < numArgs {
|
||||
return "", fmt.Errorf("error: %d arguments provided, not matching sql", argIndex)
|
||||
return "", fmt.Errorf("%d arguments provided, not matching sql", argIndex)
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
|
||||
@@ -14,7 +14,6 @@ func (n *node) add(word string) {
|
||||
}
|
||||
|
||||
nd := n
|
||||
var depth int
|
||||
for i, char := range chars {
|
||||
if nd.children == nil {
|
||||
child := new(node)
|
||||
@@ -23,7 +22,6 @@ func (n *node) add(word string) {
|
||||
nd = child
|
||||
} else if child, ok := nd.children[char]; ok {
|
||||
nd = child
|
||||
depth++
|
||||
} else {
|
||||
child := new(node)
|
||||
child.depth = i + 1
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
package stringx
|
||||
|
||||
import "strings"
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// replace more than once to avoid overlapped keywords after replace.
|
||||
// only try 2 times to avoid too many or infinite loops.
|
||||
const replaceTimes = 2
|
||||
|
||||
type (
|
||||
// Replacer interface wraps the Replace method.
|
||||
@@ -30,68 +37,48 @@ func NewReplacer(mapping map[string]string) Replacer {
|
||||
|
||||
// Replace replaces text with given substitutes.
|
||||
func (r *replacer) Replace(text string) string {
|
||||
var builder strings.Builder
|
||||
var start int
|
||||
chars := []rune(text)
|
||||
size := len(chars)
|
||||
|
||||
for start < size {
|
||||
cur := r.node
|
||||
|
||||
if start > 0 {
|
||||
builder.WriteString(string(chars[:start]))
|
||||
}
|
||||
|
||||
for i := start; i < size; i++ {
|
||||
child, ok := cur.children[chars[i]]
|
||||
if ok {
|
||||
cur = child
|
||||
} else if cur == r.node {
|
||||
builder.WriteRune(chars[i])
|
||||
// cur already points to root, set start only
|
||||
start = i + 1
|
||||
continue
|
||||
} else {
|
||||
curDepth := cur.depth
|
||||
cur = cur.fail
|
||||
child, ok = cur.children[chars[i]]
|
||||
if !ok {
|
||||
// write this path
|
||||
builder.WriteString(string(chars[i-curDepth : i+1]))
|
||||
// go to root
|
||||
cur = r.node
|
||||
start = i + 1
|
||||
continue
|
||||
}
|
||||
|
||||
failDepth := cur.depth
|
||||
// write path before jump
|
||||
builder.WriteString(string(chars[start : start+curDepth-failDepth]))
|
||||
start += curDepth - failDepth
|
||||
cur = child
|
||||
}
|
||||
|
||||
if cur.end {
|
||||
val := string(chars[i+1-cur.depth : i+1])
|
||||
builder.WriteString(r.mapping[val])
|
||||
builder.WriteString(string(chars[i+1:]))
|
||||
// only matching this path, all previous paths are done
|
||||
if start >= i+1-cur.depth && i+1 >= size {
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
chars = []rune(builder.String())
|
||||
size = len(chars)
|
||||
builder.Reset()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !cur.end {
|
||||
builder.WriteString(string(chars[start:]))
|
||||
return builder.String()
|
||||
for i := 0; i < replaceTimes; i++ {
|
||||
var replaced bool
|
||||
if text, replaced = r.doReplace(text); !replaced {
|
||||
return text
|
||||
}
|
||||
}
|
||||
|
||||
return string(chars)
|
||||
return text
|
||||
}
|
||||
|
||||
func (r *replacer) doReplace(text string) (string, bool) {
|
||||
chars := []rune(text)
|
||||
scopes := r.find(chars)
|
||||
if len(scopes) == 0 {
|
||||
return text, false
|
||||
}
|
||||
|
||||
sort.Slice(scopes, func(i, j int) bool {
|
||||
if scopes[i].start < scopes[j].start {
|
||||
return true
|
||||
}
|
||||
if scopes[i].start == scopes[j].start {
|
||||
return scopes[i].stop > scopes[j].stop
|
||||
}
|
||||
return false
|
||||
})
|
||||
|
||||
var buf strings.Builder
|
||||
var index int
|
||||
for i := 0; i < len(scopes); i++ {
|
||||
scp := &scopes[i]
|
||||
if scp.start < index {
|
||||
continue
|
||||
}
|
||||
|
||||
buf.WriteString(string(chars[index:scp.start]))
|
||||
buf.WriteString(r.mapping[string(chars[scp.start:scp.stop])])
|
||||
index = scp.stop
|
||||
}
|
||||
if index < len(chars) {
|
||||
buf.WriteString(string(chars[index:]))
|
||||
}
|
||||
|
||||
return buf.String(), true
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build go1.18
|
||||
// +build go1.18
|
||||
|
||||
package stringx
|
||||
|
||||
|
||||
@@ -15,6 +15,15 @@ func TestReplacer_Replace(t *testing.T) {
|
||||
assert.Equal(t, "零1234五", NewReplacer(mapping).Replace("零一二三四五"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceJumpMatch(t *testing.T) {
|
||||
mapping := map[string]string{
|
||||
"abcdeg": "ABCDEG",
|
||||
"cdef": "CDEF",
|
||||
"cde": "CDE",
|
||||
}
|
||||
assert.Equal(t, "abCDEF", NewReplacer(mapping).Replace("abcdef"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceOverlap(t *testing.T) {
|
||||
mapping := map[string]string{
|
||||
"3d": "34",
|
||||
@@ -44,6 +53,14 @@ func TestReplacer_ReplacePartialMatch(t *testing.T) {
|
||||
assert.Equal(t, "零一二三四五", NewReplacer(mapping).Replace("零一二三四五"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplacePartialMatchEnds(t *testing.T) {
|
||||
mapping := map[string]string{
|
||||
"二三四七": "2347",
|
||||
"三四": "34",
|
||||
}
|
||||
assert.Equal(t, "零一二34", NewReplacer(mapping).Replace("零一二三四"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceMultiMatches(t *testing.T) {
|
||||
mapping := map[string]string{
|
||||
"二三": "23",
|
||||
@@ -51,6 +68,54 @@ func TestReplacer_ReplaceMultiMatches(t *testing.T) {
|
||||
assert.Equal(t, "零一23四五一23四五", NewReplacer(mapping).Replace("零一二三四五一二三四五"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceLongestMatching(t *testing.T) {
|
||||
keywords := map[string]string{
|
||||
"日本": "japan",
|
||||
"日本的首都": "东京",
|
||||
}
|
||||
replacer := NewReplacer(keywords)
|
||||
assert.Equal(t, "东京在japan", replacer.Replace("日本的首都在日本"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceSuffixMatch(t *testing.T) {
|
||||
// case1
|
||||
{
|
||||
keywords := map[string]string{
|
||||
"abcde": "ABCDE",
|
||||
"bcde": "BCDE",
|
||||
"bcd": "BCD",
|
||||
}
|
||||
assert.Equal(t, "aBCDf", NewReplacer(keywords).Replace("abcdf"))
|
||||
}
|
||||
// case2
|
||||
{
|
||||
keywords := map[string]string{
|
||||
"abcde": "ABCDE",
|
||||
"bcde": "BCDE",
|
||||
"cde": "CDE",
|
||||
"c": "C",
|
||||
"cd": "CD",
|
||||
}
|
||||
assert.Equal(t, "abCDf", NewReplacer(keywords).Replace("abcdf"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceLongestOverlap(t *testing.T) {
|
||||
keywords := map[string]string{
|
||||
"456": "def",
|
||||
"abcd": "1234",
|
||||
}
|
||||
replacer := NewReplacer(keywords)
|
||||
assert.Equal(t, "123def7", replacer.Replace("abcd567"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceLongestLonger(t *testing.T) {
|
||||
mapping := map[string]string{
|
||||
"c": "3",
|
||||
}
|
||||
assert.Equal(t, "3d", NewReplacer(mapping).Replace("cd"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceJumpToFail(t *testing.T) {
|
||||
mapping := map[string]string{
|
||||
"bcdf": "1235",
|
||||
@@ -146,3 +211,21 @@ func TestFuzzReplacerCase2(t *testing.T) {
|
||||
t.Errorf("result: %s, match: %v", val, keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceLongestMatch(t *testing.T) {
|
||||
replacer := NewReplacer(map[string]string{
|
||||
"日本的首都": "东京",
|
||||
"日本": "本日",
|
||||
})
|
||||
assert.Equal(t, "东京是东京", replacer.Replace("日本的首都是东京"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceIndefinitely(t *testing.T) {
|
||||
mapping := map[string]string{
|
||||
"日本的首都": "东京",
|
||||
"东京": "日本的首都",
|
||||
}
|
||||
assert.NotPanics(t, func() {
|
||||
NewReplacer(mapping).Replace("日本的首都是东京")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package trace
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
@@ -10,17 +11,18 @@ import (
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/exporters/jaeger"
|
||||
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
|
||||
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
|
||||
"go.opentelemetry.io/otel/exporters/zipkin"
|
||||
"go.opentelemetry.io/otel/sdk/resource"
|
||||
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
const (
|
||||
kindJaeger = "jaeger"
|
||||
kindZipkin = "zipkin"
|
||||
kindGrpc = "grpc"
|
||||
kindJaeger = "jaeger"
|
||||
kindZipkin = "zipkin"
|
||||
kindOtlpGrpc = "otlpgrpc"
|
||||
kindOtlpHttp = "otlphttp"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -56,15 +58,31 @@ func createExporter(c Config) (sdktrace.SpanExporter, error) {
|
||||
// Just support jaeger and zipkin now, more for later
|
||||
switch c.Batcher {
|
||||
case kindJaeger:
|
||||
u, _ := url.Parse(c.Endpoint)
|
||||
if u.Scheme == "udp" {
|
||||
return jaeger.New(jaeger.WithAgentEndpoint(jaeger.WithAgentHost(u.Hostname()), jaeger.WithAgentPort(u.Port())))
|
||||
}
|
||||
return jaeger.New(jaeger.WithCollectorEndpoint(jaeger.WithEndpoint(c.Endpoint)))
|
||||
case kindZipkin:
|
||||
return zipkin.New(c.Endpoint)
|
||||
case kindGrpc:
|
||||
return otlptracegrpc.NewUnstarted(
|
||||
case kindOtlpGrpc:
|
||||
// Always treat trace exporter as optional component, so we use nonblock here,
|
||||
// otherwise this would slow down app start up even set a dial timeout here when
|
||||
// endpoint can not reach.
|
||||
// If the connection not dial success, the global otel ErrorHandler will catch error
|
||||
// when reporting data like other exporters.
|
||||
return otlptracegrpc.New(
|
||||
context.Background(),
|
||||
otlptracegrpc.WithInsecure(),
|
||||
otlptracegrpc.WithEndpoint(c.Endpoint),
|
||||
otlptracegrpc.WithDialOption(grpc.WithBlock()),
|
||||
), nil
|
||||
)
|
||||
case kindOtlpHttp:
|
||||
// Not support flexible configuration now.
|
||||
return otlptracehttp.New(
|
||||
context.Background(),
|
||||
otlptracehttp.WithInsecure(),
|
||||
otlptracehttp.WithEndpoint(c.Endpoint),
|
||||
)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown exporter: %s", c.Batcher)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,8 @@ func TestStartAgent(t *testing.T) {
|
||||
endpoint1 = "localhost:1234"
|
||||
endpoint2 = "remotehost:1234"
|
||||
endpoint3 = "localhost:1235"
|
||||
endpoint4 = "localhost:1236"
|
||||
endpoint5 = "udp://localhost:6831"
|
||||
)
|
||||
c1 := Config{
|
||||
Name: "foo",
|
||||
@@ -36,7 +38,17 @@ func TestStartAgent(t *testing.T) {
|
||||
c5 := Config{
|
||||
Name: "grpc",
|
||||
Endpoint: endpoint3,
|
||||
Batcher: "grpc",
|
||||
Batcher: kindOtlpGrpc,
|
||||
}
|
||||
c6 := Config{
|
||||
Name: "otlphttp",
|
||||
Endpoint: endpoint4,
|
||||
Batcher: kindOtlpHttp,
|
||||
}
|
||||
c7 := Config{
|
||||
Name: "UDP",
|
||||
Endpoint: endpoint5,
|
||||
Batcher: kindJaeger,
|
||||
}
|
||||
|
||||
StartAgent(c1)
|
||||
@@ -45,16 +57,20 @@ func TestStartAgent(t *testing.T) {
|
||||
StartAgent(c3)
|
||||
StartAgent(c4)
|
||||
StartAgent(c5)
|
||||
StartAgent(c6)
|
||||
StartAgent(c7)
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
// because remotehost cannot be resolved
|
||||
assert.Equal(t, 3, len(agents))
|
||||
assert.Equal(t, 5, len(agents))
|
||||
_, ok := agents[""]
|
||||
assert.True(t, ok)
|
||||
_, ok = agents[endpoint1]
|
||||
assert.True(t, ok)
|
||||
_, ok = agents[endpoint2]
|
||||
assert.False(t, ok)
|
||||
_, ok = agents[endpoint5]
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
@@ -8,5 +8,5 @@ type Config struct {
|
||||
Name string `json:",optional"`
|
||||
Endpoint string `json:",optional"`
|
||||
Sampler float64 `json:",default=1.0"`
|
||||
Batcher string `json:",default=jaeger,options=jaeger|zipkin|grpc"`
|
||||
Batcher string `json:",default=jaeger,options=jaeger|zipkin|otlpgrpc|otlphttp"`
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user