Compare commits

...

64 Commits

Author SHA1 Message Date
Kevin Wan
f3369f8e81 chore: update goctl version to 1.4.4 (#2811) 2023-01-21 21:45:25 +08:00
Kevin Wan
c9b05ae07e fix: mapping optional dep not canonicaled (#2807) 2023-01-20 23:57:49 +08:00
Kevin Wan
32a59dbc27 chore: refactor func name (#2804)
* chore: refactor func name

* chore: make plain log clearer
2023-01-18 17:20:45 +08:00
Kevin Wan
ba0dff2d61 chore: add more tests (#2803)
* chore: add more tests

* chore: add more tests
2023-01-18 13:15:41 +08:00
Kevin Wan
10da5e0424 chore: add more tests (#2801) 2023-01-17 21:55:36 +08:00
Kevin Wan
4bed34090f chore: add more tests (#2800) 2023-01-17 09:59:42 +08:00
Kevin Wan
2bfecf9354 chore: remove mgo related packages (#2799) 2023-01-16 23:13:59 +08:00
Kevin Wan
6d129e0264 chore: add more tests (#2797)
* chore: add more tests

* chore: add more tests

* chore: add more tests
2023-01-16 22:33:39 +08:00
foliet
a2df1bb164 fix: modify the generated update function and add return values for update and delete functions (#2793) 2023-01-15 22:11:08 +08:00
Kevin Wan
5f02e623f5 chore: add more tests (#2795)
* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests
2023-01-15 21:32:41 +08:00
Kevin Wan
963b52fb1b chore: add more tests (#2794) 2023-01-15 15:28:27 +08:00
Kevin Wan
02265d0bfe chore: add more tests (#2792)
* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests
2023-01-15 00:16:12 +08:00
Kevin Wan
2e57e91826 Update readme-cn.md 2023-01-13 23:03:18 +08:00
Ofey Chan
82c642d3f4 feat: expose NewTimingWheelWithClock (#2787) 2023-01-13 17:46:40 +08:00
Kevin Wan
b2571883ca chore: refactor (#2785)
* chore: refactor

* chore: refactor
2023-01-13 14:04:37 +08:00
Alonexy
00ff50c2cc add zset withsocre float (#2689)
* add zset withsocre float

* update

* add IncrbyFloat,HincrbyFloat

Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
2023-01-12 22:37:14 +08:00
Kevin Wan
4d7fa08b0b feat: support **struct in mapping (#2784)
* feat: support **struct in mapping

* chore: fix test failure
2023-01-12 20:45:32 +08:00
Kevin Wan
367afb544c feat: support ptr of ptr of ... in mapping (#2779)
* feat: support ptr of ptr of ... in mapping

* feat: support ptr of ptr of time.Duration in mapping

* feat: support ptr of ptr of json.Number in mapping

* chore: improve setting in mapping

* feat: support ptr of ptr encoding.TextUnmarshaler in mapping

* chore: add more tests

* fix: string ptr

* chore: update tests
2023-01-12 15:56:51 +08:00
cong
43b8c7f641 chore(trace): improve rest tracinghandler (#2783) 2023-01-12 12:50:57 +08:00
dependabot[bot]
a2dcb0079a chore(deps): bump github.com/jhump/protoreflect from 1.14.0 to 1.14.1 (#2782) 2023-01-12 09:41:03 +08:00
cong
f9619328f2 refactor(rest): use static config for trace ignore paths. (#2773) 2023-01-12 09:40:18 +08:00
Kevin Wan
bae061a67e chore: add tests (#2778) 2023-01-11 15:21:39 +08:00
Kevin Wan
0b176e17ac fix: #2576 (#2776) 2023-01-11 00:45:11 +08:00
Kevin Wan
6340e24c17 chore: add tests (#2774) 2023-01-09 23:48:31 +08:00
Kevin Wan
74e0676617 feat: add config to truncate long log content (#2767) 2023-01-09 09:39:30 +08:00
MarkJoyMa
0defb7522f feat: replace NewBetchInserter function name (#2769) 2023-01-09 09:38:57 +08:00
Kevin Wan
0c786ca849 chore: remove simple methods, inlined (#2768) 2023-01-09 00:55:13 +08:00
Kevin Wan
26c541b9cb feat: add middlewares config for zrpc (#2766)
* feat: add middlewares config for zrpc

* chore: add tests

* chore: improve codecov

* chore: improve codecov
2023-01-08 19:34:05 +08:00
Kevin Wan
ade6f9ee46 feat: add middlewares config for rest (#2765)
* feat: add middlewares config for rest

* chore: disable logs in tests

* chore: enable verbose in tests
2023-01-08 16:41:53 +08:00
Kevin Wan
f4502171ea Update readme-cn.md 2023-01-08 12:42:27 +08:00
chensy
8157e2118d fix: replace goctl ExactValidArgs to MatchAll (#2759)
Co-authored-by: chenjieping <chenjieping@kezaihui.com>
2023-01-07 17:07:40 +08:00
Kevin Wan
e52dace416 chore: refactor (#2764) 2023-01-07 14:13:44 +08:00
chen quan
dc260f196a refactor: simplify the code (#2763)
* refactor: simplify the code

* fix: fix data race

* refactor: simplify the code

* refactor: simplify the code
2023-01-07 13:32:56 +08:00
dependabot[bot]
559726112c chore(deps): bump github.com/alicebob/miniredis/v2 from 2.23.1 to 2.30.0 (#2762)
Bumps [github.com/alicebob/miniredis/v2](https://github.com/alicebob/miniredis) from 2.23.1 to 2.30.0.
- [Release notes](https://github.com/alicebob/miniredis/releases)
- [Changelog](https://github.com/alicebob/miniredis/blob/master/CHANGELOG.md)
- [Commits](https://github.com/alicebob/miniredis/compare/v2.23.1...v2.30.0)

---
updated-dependencies:
- dependency-name: github.com/alicebob/miniredis/v2
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-01-07 12:15:22 +08:00
MarkJoyMa
a5fcf24c04 feat: add batch inserter (#2755) 2023-01-06 23:30:50 +08:00
chen quan
fc9b3ffdc1 refactor: use opentelemetry's standard api to track http status code (#2760) 2023-01-06 23:27:54 +08:00
MarkJoyMa
e71c505e94 feat: add mongo options (#2753)
* feat: add mongo options

* feat: add mongo options

* feat: add mongo options

* feat: add mongo options

* feat: add mongo options

* feat: add mongo options
2023-01-05 22:14:50 +08:00
chen quan
21c49009c0 chore: remove unnecessary code (#2754) 2023-01-05 22:12:07 +08:00
#Suyghur
69d355eb4b feat(redis): add zscan command implementation (#2729) (#2751) 2023-01-04 13:44:17 +08:00
Kevin Wan
83f88d177f chore: improve codecov (#2752) 2023-01-04 13:42:20 +08:00
xiang
641ebf1667 feat: trace http.status_code (#2708)
* feat: trace http.status_code

* feat: implements http.Flusher & http.Hijacker for traceResponseWriter

* test: delete notTracingSpans after test

* feat: trace http.status_code

* feat: implements http.Flusher & http.Hijacker for traceResponseWriter

* test: delete notTracingSpans after test

* refactor: update trace handler span message

* fix: code conflict
2023-01-04 10:21:57 +08:00
Kevin Wan
cf435bfcc1 chore: remove roadmap file, not updating (#2749) 2023-01-03 23:38:14 +08:00
Kevin Wan
28f1b15b8e Update readme.md (#2748) 2023-01-03 23:14:39 +08:00
cong
42413dc294 feat(trace): support otlp http exporter (#2746)
* feat(trace): support otlp http exporter

* chore: use otlptracehttp v1.10.0 not upgrade grpc version prevent other modules break

* refactor(trace): rename exporter kind grpc to otlpgrpc.

BREAKING CHANGE: trace Config.Batcher should use otlpgrpc instead of grpc now.
2023-01-03 22:49:30 +08:00
Kevin Wan
ec7ac43948 chore: reorg imports (#2745)
* chore: reorg imports

* chore: format code
2023-01-03 22:26:45 +08:00
cong
deefc1a8eb fix(trace): grpc exporter should use nonblock option (#2744)
* fix(trace): grpc exporter should use nonblock option

* chore: sort imports
2023-01-03 18:15:09 +08:00
Kevin Wan
036328f1ea chore: update tests (#2741)
* chore: update tests

* chore: codecov on comments

* chore: codecov on comments
2023-01-03 18:02:35 +08:00
wojiukankan
85057a623d 🐛 debug grpc export (#2379) (#2719)
* 🐛 debug grpc export (#2379) 

#2379 Fixed the issue that the GRPC exporter did not establish an RPC link
原文使用的 otlptracegrpc.NewUnstarted创建的是一个未建立rpc连接的导出器,无法正常使用;改为otlptracegrpc.New才妥

* Update agent_test.go

修复单元测试失败
2023-01-03 17:04:35 +08:00
Xargin
1c544a26be use stat instead of disableStat (#2740) 2023-01-03 11:29:24 +08:00
chainlife
20a61ce43e logx conf add DisableStat (#2434)
Co-authored-by: sunsoft <sunsoft@qq.com>
2023-01-02 23:22:13 +08:00
Kevin Wan
dd294e8cd6 fix: #2700, timeout not enough for writing responses (#2738)
* fix: #2700, timeout not enough for writing responses

* fix: test fail

* chore: add comments
2023-01-02 13:51:15 +08:00
JackSon_tm.m
3e9d0161bc add ServeHTTP to Server/Engin for doing Httptest (#2704) 2023-01-02 00:24:58 +08:00
Kevin Wan
cf6c349118 fix: #2735 (#2736)
* fix: #2735

* chore: make error consistent
2023-01-01 12:21:53 +08:00
Kevin Wan
c7a0ec428c fix: key like TLSConfig not working (#2730)
* fix: key like TLSConfig not working

* fix: remove unnecessary code

* chore: rename variable
2022-12-29 14:50:53 +08:00
chowyu12
ce1c02f4f9 Feat: ignorecolums add sort (#2648)
* add go-grpc_opt and go_opt for grpc new command

* feat: remove log when disable log

* feat: add sort

Co-authored-by: zhouyy <zhouyy@ickey.cn>
2022-12-28 14:53:22 +08:00
Kevin Wan
c3756a8f1c fix: etcd publisher reconnecting problem (#2710)
* fix: etcd publisher reconnecting problem

* chore: fix wrong call
2022-12-27 20:03:03 +08:00
Archer
f4fd735aee Use read-write lock instead of mutex (#2727) 2022-12-26 15:00:47 +08:00
Archer
683d793719 RawFieldNames should ignore the field whose name is start with a dash (#2725) 2022-12-24 21:27:32 +08:00
Kevin Wan
affbcb5698 fix: camel cased key of map item in config (#2715)
* fix: camel cased key of map item in config

* fix: mapping anonymous problem

* fix: mapping anonymous problem

* chore: refactor

* chore: add more tests

* chore: refactor
2022-12-24 21:26:33 +08:00
Kevin Wan
f0d1722bbd chore: pass by value for config in dev server (#2712) 2022-12-24 11:41:23 +08:00
chowyu12
c4f8eca459 Feat update rootpkg (#2718)
* add go-grpc_opt and go_opt for grpc new command

* feat: remove log when disable log

* feat: remove repeat code

Co-authored-by: zhouyy <zhouyy@ickey.cn>
2022-12-23 23:57:56 +08:00
Kevin Wan
251c071418 Update readme.md 2022-12-22 23:21:41 +08:00
Kevin Wan
6652c4e445 Update readme-cn.md 2022-12-16 00:08:12 +08:00
Kevin Wan
f73613dff0 Update readme.md 2022-12-16 00:07:50 +08:00
131 changed files with 4296 additions and 5043 deletions

View File

@@ -1,3 +1,6 @@
comment: false
comment:
layout: "flags, files"
behavior: once
require_changes: true
ignore:
- "tools"

View File

@@ -1,28 +0,0 @@
# go-zero Roadmap
This document defines a high level roadmap for go-zero development and upcoming releases.
Community and contributor involvement is vital for successfully implementing all desired items for each release.
We hope that the items listed below will inspire further engagement from the community to keep go-zero progressing and shipping exciting and valuable features.
## 2021 Q2
- [x] Support service discovery through K8S client api
- [x] Log full sql statements for easier sql problem solving
## 2021 Q3
- [x] Support `goctl model pg` to support PostgreSQL code generation
- [x] Adapt builtin tracing mechanism to opentracing solutions
## 2021 Q4
- [x] Support `username/password` authentication in ETCD
- [x] Support `SSL/TLS` in ETCD
- [x] Support `SSL/TLS` in `zRPC`
- [x] Support `TLS` in redis connections
- [x] Support `goctl bug` to report bugs conveniently
## 2022
- [x] Support `context` in redis related methods for timeout and tracing
- [x] Support `context` in sql related methods for timeout and tracing
- [x] Support `context` in mongodb related methods for timeout and tracing
- [x] Add `httpc.Do` with HTTP call governance, like circuit breaker etc.
- [ ] Support `goctl doctor` command to report potential issues for given service
- [ ] Support `goctl mock` command to start a mocking server with given `.api` file

View File

@@ -32,9 +32,11 @@ func NewECBEncrypter(b cipher.Block) cipher.BlockMode {
return (*ecbEncrypter)(newECB(b))
}
// BlockSize returns the mode's block size.
func (x *ecbEncrypter) BlockSize() int { return x.blockSize }
// why we don't return error is because cipher.BlockMode doesn't allow this
// CryptBlocks encrypts a number of blocks. The length of src must be a multiple of
// the block size. Dst and src must overlap entirely or not at all.
func (x *ecbEncrypter) CryptBlocks(dst, src []byte) {
if len(src)%x.blockSize != 0 {
logx.Error("crypto/cipher: input not full blocks")
@@ -59,11 +61,13 @@ func NewECBDecrypter(b cipher.Block) cipher.BlockMode {
return (*ecbDecrypter)(newECB(b))
}
// BlockSize returns the mode's block size.
func (x *ecbDecrypter) BlockSize() int {
return x.blockSize
}
// why we don't return error is because cipher.BlockMode doesn't allow this
// CryptBlocks decrypts a number of blocks. The length of src must be a multiple of
// the block size. Dst and src must overlap entirely or not at all.
func (x *ecbDecrypter) CryptBlocks(dst, src []byte) {
if len(src)%x.blockSize != 0 {
logx.Error("crypto/cipher: input not full blocks")

View File

@@ -1,6 +1,7 @@
package codec
import (
"crypto/aes"
"encoding/base64"
"testing"
@@ -10,7 +11,8 @@ import (
func TestAesEcb(t *testing.T) {
var (
key = []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D")
val = []byte("hello")
val = []byte("helloworld")
valLong = []byte("helloworldlong..")
badKey1 = []byte("aaaaaaaaa")
// more than 32 chars
badKey2 = []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
@@ -31,6 +33,39 @@ func TestAesEcb(t *testing.T) {
src, err := EcbDecrypt(key, dst)
assert.Nil(t, err)
assert.Equal(t, val, src)
block, err := aes.NewCipher(key)
assert.NoError(t, err)
encrypter := NewECBEncrypter(block)
assert.Equal(t, 16, encrypter.BlockSize())
decrypter := NewECBDecrypter(block)
assert.Equal(t, 16, decrypter.BlockSize())
dst = make([]byte, 8)
encrypter.CryptBlocks(dst, val)
for _, b := range dst {
assert.Equal(t, byte(0), b)
}
dst = make([]byte, 8)
encrypter.CryptBlocks(dst, valLong)
for _, b := range dst {
assert.Equal(t, byte(0), b)
}
dst = make([]byte, 8)
decrypter.CryptBlocks(dst, val)
for _, b := range dst {
assert.Equal(t, byte(0), b)
}
dst = make([]byte, 8)
decrypter.CryptBlocks(dst, valLong)
for _, b := range dst {
assert.Equal(t, byte(0), b)
}
_, err = EcbEncryptBase64("cTR0N3dDKkYtSmFOZFJnVWpYbjJyNXU4eC9BP0QK", "aGVsbG93b3JsZGxvbmcuLgo=")
assert.Error(t, err)
}
func TestAesEcbBase64(t *testing.T) {

View File

@@ -80,3 +80,17 @@ func TestKeyBytes(t *testing.T) {
assert.Nil(t, err)
assert.True(t, len(key.Bytes()) > 0)
}
func TestDHOnErrors(t *testing.T) {
key, err := GenerateKey()
assert.Nil(t, err)
assert.NotEmpty(t, key.Bytes())
_, err = ComputeKey(key.PubKey, key.PriKey)
assert.NoError(t, err)
_, err = ComputeKey(nil, key.PriKey)
assert.Error(t, err)
_, err = ComputeKey(key.PubKey, nil)
assert.Error(t, err)
assert.NotNil(t, NewPublicKey([]byte("")))
}

View File

@@ -6,7 +6,7 @@ import "sync"
type Ring struct {
elements []interface{}
index int
lock sync.Mutex
lock sync.RWMutex
}
// NewRing returns a Ring object with the given size n.
@@ -31,8 +31,8 @@ func (r *Ring) Add(v interface{}) {
// Take takes all items from r.
func (r *Ring) Take() []interface{} {
r.lock.Lock()
defer r.lock.Unlock()
r.lock.RLock()
defer r.lock.RUnlock()
var size int
var start int

View File

@@ -69,10 +69,11 @@ func NewTimingWheel(interval time.Duration, numSlots int, execute Execute) (*Tim
interval, numSlots, execute)
}
return newTimingWheelWithClock(interval, numSlots, execute, timex.NewTicker(interval))
return NewTimingWheelWithTicker(interval, numSlots, execute, timex.NewTicker(interval))
}
func newTimingWheelWithClock(interval time.Duration, numSlots int, execute Execute,
// NewTimingWheelWithTicker returns a TimingWheel with the given ticker.
func NewTimingWheelWithTicker(interval time.Duration, numSlots int, execute Execute,
ticker timex.Ticker) (*TimingWheel, error) {
tw := &TimingWheel{
interval: interval,

View File

@@ -26,7 +26,7 @@ func TestNewTimingWheel(t *testing.T) {
func TestTimingWheel_Drain(t *testing.T) {
ticker := timex.NewFakeTicker()
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
}, ticker)
tw.SetTimer("first", 3, testStep*4)
tw.SetTimer("second", 5, testStep*7)
@@ -62,7 +62,7 @@ func TestTimingWheel_Drain(t *testing.T) {
func TestTimingWheel_SetTimerSoon(t *testing.T) {
run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker()
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int))
@@ -78,7 +78,7 @@ func TestTimingWheel_SetTimerSoon(t *testing.T) {
func TestTimingWheel_SetTimerTwice(t *testing.T) {
run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker()
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k)
assert.Equal(t, 5, v.(int))
@@ -96,7 +96,7 @@ func TestTimingWheel_SetTimerTwice(t *testing.T) {
func TestTimingWheel_SetTimerWrongDelay(t *testing.T) {
ticker := timex.NewFakeTicker()
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {}, ticker)
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {}, ticker)
defer tw.Stop()
assert.NotPanics(t, func() {
tw.SetTimer("any", 3, -testStep)
@@ -105,7 +105,7 @@ func TestTimingWheel_SetTimerWrongDelay(t *testing.T) {
func TestTimingWheel_SetTimerAfterClose(t *testing.T) {
ticker := timex.NewFakeTicker()
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {}, ticker)
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {}, ticker)
tw.Stop()
assert.Equal(t, ErrClosed, tw.SetTimer("any", 3, testStep))
}
@@ -113,7 +113,7 @@ func TestTimingWheel_SetTimerAfterClose(t *testing.T) {
func TestTimingWheel_MoveTimer(t *testing.T) {
run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker()
tw, _ := newTimingWheelWithClock(testStep, 3, func(k, v interface{}) {
tw, _ := NewTimingWheelWithTicker(testStep, 3, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int))
@@ -139,7 +139,7 @@ func TestTimingWheel_MoveTimer(t *testing.T) {
func TestTimingWheel_MoveTimerSoon(t *testing.T) {
run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker()
tw, _ := newTimingWheelWithClock(testStep, 3, func(k, v interface{}) {
tw, _ := NewTimingWheelWithTicker(testStep, 3, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int))
@@ -155,7 +155,7 @@ func TestTimingWheel_MoveTimerSoon(t *testing.T) {
func TestTimingWheel_MoveTimerEarlier(t *testing.T) {
run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker()
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int))
@@ -173,7 +173,7 @@ func TestTimingWheel_MoveTimerEarlier(t *testing.T) {
func TestTimingWheel_RemoveTimer(t *testing.T) {
ticker := timex.NewFakeTicker()
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {}, ticker)
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {}, ticker)
tw.SetTimer("any", 3, testStep)
assert.NotPanics(t, func() {
tw.RemoveTimer("any")
@@ -236,7 +236,7 @@ func TestTimingWheel_SetTimer(t *testing.T) {
}
var actual int32
done := make(chan lang.PlaceholderType)
tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
assert.Equal(t, 1, key.(int))
assert.Equal(t, 2, value.(int))
actual = atomic.LoadInt32(&count)
@@ -317,7 +317,7 @@ func TestTimingWheel_SetAndMoveThenStart(t *testing.T) {
}
var actual int32
done := make(chan lang.PlaceholderType)
tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
actual = atomic.LoadInt32(&count)
close(done)
}, ticker)
@@ -405,7 +405,7 @@ func TestTimingWheel_SetAndMoveTwice(t *testing.T) {
}
var actual int32
done := make(chan lang.PlaceholderType)
tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
actual = atomic.LoadInt32(&count)
close(done)
}, ticker)
@@ -486,7 +486,7 @@ func TestTimingWheel_ElapsedAndSet(t *testing.T) {
}
var actual int32
done := make(chan lang.PlaceholderType)
tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
actual = atomic.LoadInt32(&count)
close(done)
}, ticker)
@@ -577,7 +577,7 @@ func TestTimingWheel_ElapsedAndSetThenMove(t *testing.T) {
}
var actual int32
done := make(chan lang.PlaceholderType)
tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
actual = atomic.LoadInt32(&count)
close(done)
}, ticker)
@@ -612,7 +612,7 @@ func TestMoveAndRemoveTask(t *testing.T) {
}
}
var keys []int
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int))
keys = append(keys, v.(int))

View File

@@ -5,6 +5,7 @@ import (
"log"
"os"
"path"
"reflect"
"strings"
"github.com/zeromicro/go-zero/core/jsonx"
@@ -12,8 +13,6 @@ import (
"github.com/zeromicro/go-zero/internal/encoding"
)
const distanceBetweenUpperAndLower = 32
var loaders = map[string]func([]byte, interface{}) error{
".json": LoadFromJsonBytes,
".toml": LoadFromTomlBytes,
@@ -21,6 +20,12 @@ var loaders = map[string]func([]byte, interface{}) error{
".yml": LoadFromYamlBytes,
}
type fieldInfo struct {
name string
kind reflect.Kind
children map[string]fieldInfo
}
// Load loads config into v from file, .json, .yaml and .yml are acceptable.
func Load(file string, v interface{}, opts ...Option) error {
content, err := os.ReadFile(file)
@@ -58,7 +63,10 @@ func LoadFromJsonBytes(content []byte, v interface{}) error {
return err
}
return mapping.UnmarshalJsonMap(toCamelCaseKeyMap(m), v, mapping.WithCanonicalKeyFunc(toCamelCase))
finfo := buildFieldsInfo(reflect.TypeOf(v))
lowerCaseKeyMap := toLowerCaseKeyMap(m, finfo)
return mapping.UnmarshalJsonMap(lowerCaseKeyMap, v, mapping.WithCanonicalKeyFunc(toLowerCase))
}
// LoadConfigFromJsonBytes loads config into v from content json bytes.
@@ -100,53 +108,76 @@ func MustLoad(path string, v interface{}, opts ...Option) {
}
}
func toCamelCase(s string) string {
var buf strings.Builder
buf.Grow(len(s))
var capNext bool
boundary := true
for _, v := range s {
isCap := v >= 'A' && v <= 'Z'
isLow := v >= 'a' && v <= 'z'
if boundary && (isCap || isLow) {
if capNext {
if isLow {
v -= distanceBetweenUpperAndLower
func buildFieldsInfo(tp reflect.Type) map[string]fieldInfo {
tp = mapping.Deref(tp)
switch tp.Kind() {
case reflect.Struct:
return buildStructFieldsInfo(tp)
case reflect.Array, reflect.Slice:
return buildFieldsInfo(mapping.Deref(tp.Elem()))
default:
return nil
}
}
func buildStructFieldsInfo(tp reflect.Type) map[string]fieldInfo {
info := make(map[string]fieldInfo)
for i := 0; i < tp.NumField(); i++ {
field := tp.Field(i)
name := field.Name
lowerCaseName := toLowerCase(name)
ft := mapping.Deref(field.Type)
// flatten anonymous fields
if field.Anonymous {
if ft.Kind() == reflect.Struct {
fields := buildFieldsInfo(ft)
for k, v := range fields {
info[k] = v
}
} else {
if isCap {
v += distanceBetweenUpperAndLower
info[lowerCaseName] = fieldInfo{
name: name,
kind: ft.Kind(),
}
}
boundary = false
continue
}
if isCap || isLow {
buf.WriteRune(v)
capNext = false
} else if v == ' ' || v == '\t' {
buf.WriteRune(v)
capNext = false
boundary = true
} else if v == '_' {
capNext = true
boundary = true
} else {
buf.WriteRune(v)
capNext = true
var fields map[string]fieldInfo
switch ft.Kind() {
case reflect.Struct:
fields = buildFieldsInfo(ft)
case reflect.Array, reflect.Slice:
fields = buildFieldsInfo(ft.Elem())
case reflect.Map:
fields = buildFieldsInfo(ft.Elem())
}
info[lowerCaseName] = fieldInfo{
name: name,
kind: ft.Kind(),
children: fields,
}
}
return buf.String()
return info
}
func toCamelCaseInterface(v interface{}) interface{} {
func toLowerCase(s string) string {
return strings.ToLower(s)
}
func toLowerCaseInterface(v interface{}, info map[string]fieldInfo) interface{} {
switch vv := v.(type) {
case map[string]interface{}:
return toCamelCaseKeyMap(vv)
return toLowerCaseKeyMap(vv, info)
case []interface{}:
var arr []interface{}
for _, vvv := range vv {
arr = append(arr, toCamelCaseInterface(vvv))
arr = append(arr, toLowerCaseInterface(vvv, info))
}
return arr
default:
@@ -154,10 +185,22 @@ func toCamelCaseInterface(v interface{}) interface{} {
}
}
func toCamelCaseKeyMap(m map[string]interface{}) map[string]interface{} {
func toLowerCaseKeyMap(m map[string]interface{}, info map[string]fieldInfo) map[string]interface{} {
res := make(map[string]interface{})
for k, v := range m {
res[toCamelCase(k)] = toCamelCaseInterface(v)
ti, ok := info[k]
if ok {
res[k] = toLowerCaseInterface(v, ti.children)
continue
}
lk := toLowerCase(k)
if ti, ok = info[lk]; ok {
res[lk] = toLowerCaseInterface(v, ti.children)
} else {
res[k] = v
}
}
return res

View File

@@ -97,6 +97,30 @@ d = "abcd!@#$112"
assert.Equal(t, "abcd!@#$112", val.D)
}
func TestConfigOptional(t *testing.T) {
text := `a = "foo"
b = 1
c = "FOO"
d = "abcd"
`
tmpfile, err := createTempFile(".toml", text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
var val struct {
A string `json:"a"`
B int `json:"b,optional"`
C string `json:"c,optional=B"`
D string `json:"d,optional=b"`
}
if assert.NoError(t, Load(tmpfile, &val)) {
assert.Equal(t, "foo", val.A)
assert.Equal(t, 1, val.B)
assert.Equal(t, "FOO", val.C)
assert.Equal(t, "abcd", val.D)
}
}
func TestConfigJsonCanonical(t *testing.T) {
text := []byte(`{"a": "foo", "B": "bar"}`)
@@ -237,23 +261,23 @@ func TestToCamelCase(t *testing.T) {
},
{
input: "hello_world",
expect: "helloWorld",
expect: "hello_world",
},
{
input: "Hello_world",
expect: "helloWorld",
expect: "hello_world",
},
{
input: "hello_World",
expect: "helloWorld",
expect: "hello_world",
},
{
input: "helloWorld",
expect: "helloWorld",
expect: "helloworld",
},
{
input: "HelloWorld",
expect: "helloWorld",
expect: "helloworld",
},
{
input: "hello World",
@@ -269,30 +293,34 @@ func TestToCamelCase(t *testing.T) {
},
{
input: "Hello World foo_bar",
expect: "hello world fooBar",
expect: "hello world foo_bar",
},
{
input: "Hello World foo_Bar",
expect: "hello world fooBar",
expect: "hello world foo_bar",
},
{
input: "Hello World Foo_bar",
expect: "hello world fooBar",
expect: "hello world foo_bar",
},
{
input: "Hello World Foo_Bar",
expect: "hello world fooBar",
expect: "hello world foo_bar",
},
{
input: "Hello.World Foo_Bar",
expect: "hello.world foo_bar",
},
{
input: "你好 World Foo_Bar",
expect: "你好 world fooBar",
expect: "你好 world foo_bar",
},
}
for _, test := range tests {
test := test
t.Run(test.input, func(t *testing.T) {
assert.Equal(t, test.expect, toCamelCase(test.input))
assert.Equal(t, test.expect, toLowerCase(test.input))
})
}
}
@@ -328,6 +356,100 @@ func TestLoadFromYamlBytes(t *testing.T) {
assert.Equal(t, "foo", val.Layer1.Layer2.Layer3)
}
func TestLoadFromYamlBytesTerm(t *testing.T) {
input := []byte(`layer1:
layer2:
tls_conf: foo`)
var val struct {
Layer1 struct {
Layer2 struct {
Layer3 string `json:"tls_conf"`
}
}
}
assert.NoError(t, LoadFromYamlBytes(input, &val))
assert.Equal(t, "foo", val.Layer1.Layer2.Layer3)
}
func TestLoadFromYamlBytesLayers(t *testing.T) {
input := []byte(`layer1:
layer2:
layer3: foo`)
var val struct {
Value string `json:"Layer1.Layer2.Layer3"`
}
assert.NoError(t, LoadFromYamlBytes(input, &val))
assert.Equal(t, "foo", val.Value)
}
func TestUnmarshalJsonBytesMap(t *testing.T) {
input := []byte(`{"foo":{"/mtproto.RPCTos": "bff.bff","bar":"baz"}}`)
var val struct {
Foo map[string]string
}
assert.NoError(t, LoadFromJsonBytes(input, &val))
assert.Equal(t, "bff.bff", val.Foo["/mtproto.RPCTos"])
assert.Equal(t, "baz", val.Foo["bar"])
}
func TestUnmarshalJsonBytesMapWithSliceElements(t *testing.T) {
input := []byte(`{"foo":{"/mtproto.RPCTos": ["bff.bff", "any"],"bar":["baz", "qux"]}}`)
var val struct {
Foo map[string][]string
}
assert.NoError(t, LoadFromJsonBytes(input, &val))
assert.EqualValues(t, []string{"bff.bff", "any"}, val.Foo["/mtproto.RPCTos"])
assert.EqualValues(t, []string{"baz", "qux"}, val.Foo["bar"])
}
func TestUnmarshalJsonBytesMapWithSliceOfStructs(t *testing.T) {
input := []byte(`{"foo":{
"/mtproto.RPCTos": [{"bar": "any"}],
"bar":[{"bar": "qux"}, {"bar": "ever"}]}}`)
var val struct {
Foo map[string][]struct {
Bar string
}
}
assert.NoError(t, LoadFromJsonBytes(input, &val))
assert.Equal(t, 1, len(val.Foo["/mtproto.RPCTos"]))
assert.Equal(t, "any", val.Foo["/mtproto.RPCTos"][0].Bar)
assert.Equal(t, 2, len(val.Foo["bar"]))
assert.Equal(t, "qux", val.Foo["bar"][0].Bar)
assert.Equal(t, "ever", val.Foo["bar"][1].Bar)
}
func TestUnmarshalJsonBytesWithAnonymousField(t *testing.T) {
type (
Int int
InnerConf struct {
Name string
}
Conf struct {
Int
InnerConf
}
)
var (
input = []byte(`{"Name": "hello", "int": 3}`)
c Conf
)
assert.NoError(t, LoadFromJsonBytes(input, &c))
assert.Equal(t, "hello", c.Name)
assert.Equal(t, Int(3), c.Int)
}
func createTempFile(ext, text string) (string, error) {
tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext)
if err != nil {

View File

@@ -1,6 +1,8 @@
package discov
import (
"time"
"github.com/zeromicro/go-zero/core/discov/internal"
"github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/logx"
@@ -51,12 +53,7 @@ func NewPublisher(endpoints []string, key, value string, opts ...PubOption) *Pub
// KeepAlive keeps key:value alive.
func (p *Publisher) KeepAlive() error {
cli, err := internal.GetRegistry().GetConn(p.endpoints)
if err != nil {
return err
}
p.lease, err = p.register(cli)
cli, err := p.doRegister()
if err != nil {
return err
}
@@ -83,6 +80,43 @@ func (p *Publisher) Stop() {
p.quit.Close()
}
func (p *Publisher) doKeepAlive() error {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for range ticker.C {
select {
case <-p.quit.Done():
return nil
default:
cli, err := p.doRegister()
if err != nil {
logx.Errorf("etcd publisher doRegister: %s", err.Error())
break
}
if err := p.keepAliveAsync(cli); err != nil {
logx.Errorf("etcd publisher keepAliveAsync: %s", err.Error())
break
}
return nil
}
}
return nil
}
func (p *Publisher) doRegister() (internal.EtcdClient, error) {
cli, err := internal.GetRegistry().GetConn(p.endpoints)
if err != nil {
return nil, err
}
p.lease, err = p.register(cli)
return cli, err
}
func (p *Publisher) keepAliveAsync(cli internal.EtcdClient) error {
ch, err := cli.KeepAlive(cli.Ctx(), p.lease)
if err != nil {
@@ -95,8 +129,8 @@ func (p *Publisher) keepAliveAsync(cli internal.EtcdClient) error {
case _, ok := <-ch:
if !ok {
p.revoke(cli)
if err := p.KeepAlive(); err != nil {
logx.Errorf("KeepAlive: %s", err.Error())
if err := p.doKeepAlive(); err != nil {
logx.Errorf("etcd publisher KeepAlive: %s", err.Error())
}
return
}
@@ -105,8 +139,8 @@ func (p *Publisher) keepAliveAsync(cli internal.EtcdClient) error {
p.revoke(cli)
select {
case <-p.resumeChan:
if err := p.KeepAlive(); err != nil {
logx.Errorf("KeepAlive: %s", err.Error())
if err := p.doKeepAlive(); err != nil {
logx.Errorf("etcd publisher KeepAlive: %s", err.Error())
}
return
case <-p.quit.Done():
@@ -141,7 +175,7 @@ func (p *Publisher) register(client internal.EtcdClient) (clientv3.LeaseID, erro
func (p *Publisher) revoke(cli internal.EtcdClient) {
if _, err := cli.Revoke(cli.Ctx(), p.lease); err != nil {
logx.Error(err)
logx.Errorf("etcd publisher revoke: %s", err.Error())
}
}

View File

@@ -1,44 +0,0 @@
package jsontype
import (
"encoding/json"
"time"
"github.com/globalsign/mgo/bson"
)
// MilliTime represents time.Time that works better with mongodb.
type MilliTime struct {
time.Time
}
// MarshalJSON marshals mt to json bytes.
func (mt MilliTime) MarshalJSON() ([]byte, error) {
return json.Marshal(mt.Milli())
}
// UnmarshalJSON unmarshals data into mt.
func (mt *MilliTime) UnmarshalJSON(data []byte) error {
var milli int64
if err := json.Unmarshal(data, &milli); err != nil {
return err
}
mt.Time = time.Unix(0, milli*int64(time.Millisecond))
return nil
}
// GetBSON returns BSON base on mt.
func (mt MilliTime) GetBSON() (interface{}, error) {
return mt.Time, nil
}
// SetBSON sets raw into mt.
func (mt *MilliTime) SetBSON(raw bson.Raw) error {
return raw.Unmarshal(&mt.Time)
}
// Milli returns milliseconds for mt.
func (mt MilliTime) Milli() int64 {
return mt.UnixNano() / int64(time.Millisecond)
}

View File

@@ -1,126 +0,0 @@
package jsontype
import (
"strconv"
"testing"
"time"
"github.com/globalsign/mgo/bson"
"github.com/stretchr/testify/assert"
)
func TestMilliTime_GetBSON(t *testing.T) {
tests := []struct {
name string
tm time.Time
}{
{
name: "now",
tm: time.Now(),
},
{
name: "future",
tm: time.Now().Add(time.Hour),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
got, err := MilliTime{test.tm}.GetBSON()
assert.Nil(t, err)
assert.Equal(t, test.tm, got)
})
}
}
func TestMilliTime_MarshalJSON(t *testing.T) {
tests := []struct {
name string
tm time.Time
}{
{
name: "now",
tm: time.Now(),
},
{
name: "future",
tm: time.Now().Add(time.Hour),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
b, err := MilliTime{test.tm}.MarshalJSON()
assert.Nil(t, err)
assert.Equal(t, strconv.FormatInt(test.tm.UnixNano()/1e6, 10), string(b))
})
}
}
func TestMilliTime_Milli(t *testing.T) {
tests := []struct {
name string
tm time.Time
}{
{
name: "now",
tm: time.Now(),
},
{
name: "future",
tm: time.Now().Add(time.Hour),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
n := MilliTime{test.tm}.Milli()
assert.Equal(t, test.tm.UnixNano()/1e6, n)
})
}
}
func TestMilliTime_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
tm time.Time
}{
{
name: "now",
tm: time.Now(),
},
{
name: "future",
tm: time.Now().Add(time.Hour),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var mt MilliTime
s := strconv.FormatInt(test.tm.UnixNano()/1e6, 10)
err := mt.UnmarshalJSON([]byte(s))
assert.Nil(t, err)
s1, err := mt.MarshalJSON()
assert.Nil(t, err)
assert.Equal(t, s, string(s1))
})
}
}
func TestUnmarshalWithError(t *testing.T) {
var mt MilliTime
assert.NotNil(t, mt.UnmarshalJSON([]byte("hello")))
}
func TestSetBSON(t *testing.T) {
data, err := bson.Marshal(time.Now())
assert.Nil(t, err)
var raw bson.Raw
assert.Nil(t, bson.Unmarshal(data, &raw))
var mt MilliTime
assert.Nil(t, mt.SetBSON(raw))
assert.NotNil(t, mt.SetBSON(bson.Raw{}))
}

View File

@@ -29,7 +29,7 @@ func Repr(v interface{}) string {
}
val := reflect.ValueOf(v)
if val.Kind() == reflect.Ptr && !val.IsNil() {
for val.Kind() == reflect.Ptr && !val.IsNil() {
val = val.Elem()
}

View File

@@ -1,6 +1,9 @@
package lang
import (
"encoding/json"
"errors"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
@@ -110,6 +113,28 @@ func TestRepr(t *testing.T) {
}
}
func TestReprOfValue(t *testing.T) {
t.Run("error", func(t *testing.T) {
assert.Equal(t, "error", reprOfValue(reflect.ValueOf(errors.New("error"))))
})
t.Run("stringer", func(t *testing.T) {
assert.Equal(t, "1.23", reprOfValue(reflect.ValueOf(json.Number("1.23"))))
})
t.Run("int", func(t *testing.T) {
assert.Equal(t, "1", reprOfValue(reflect.ValueOf(1)))
})
t.Run("int", func(t *testing.T) {
assert.Equal(t, "1", reprOfValue(reflect.ValueOf("1")))
})
t.Run("int", func(t *testing.T) {
assert.Equal(t, "1", reprOfValue(reflect.ValueOf(uint(1))))
})
}
type mockStringable struct{}
func (m mockStringable) String() string {

View File

@@ -8,7 +8,9 @@ type LogConf struct {
TimeFormat string `json:",optional"`
Path string `json:",default=logs"`
Level string `json:",default=info,options=[debug,info,error,severe]"`
MaxContentLength uint32 `json:",optional"`
Compress bool `json:",optional"`
Stat bool `json:",default=true"`
KeepDays int `json:",optional"`
StackCooldownMillis int `json:",default=100"`
// MaxBackups represents how many backup log files will be kept. 0 means all files will be kept forever.

View File

@@ -20,6 +20,8 @@ var (
timeFormat = "2006-01-02T15:04:05.000Z07:00"
logLevel uint32
encoding uint32 = jsonEncodingType
// maxContentLength is used to truncate the log content, 0 for not truncating.
maxContentLength uint32
// use uint32 for atomic operations
disableLog uint32
disableStat uint32
@@ -230,10 +232,16 @@ func SetUp(c LogConf) (err error) {
setupOnce.Do(func() {
setupLogLevel(c)
if !c.Stat {
DisableStat()
}
if len(c.TimeFormat) > 0 {
timeFormat = c.TimeFormat
}
atomic.StoreUint32(&maxContentLength, c.MaxContentLength)
switch c.Encoding {
case plainEncoding:
atomic.StoreUint32(&encoding, plainEncodingType)

View File

@@ -529,9 +529,9 @@ func TestSetLevel(t *testing.T) {
func TestSetLevelTwiceWithMode(t *testing.T) {
testModes := []string{
"mode",
"console",
"volumn",
"mode",
}
w := new(mockWriter)
old := writer.Swap(w)
@@ -791,9 +791,12 @@ func doTestStructedLogConsole(t *testing.T, w *mockWriter, write func(...interfa
func testSetLevelTwiceWithMode(t *testing.T, mode string, w *mockWriter) {
writer.Store(nil)
SetUp(LogConf{
Mode: mode,
Level: "error",
Path: "/dev/null",
Mode: mode,
Level: "debug",
Path: "/dev/null",
Encoding: plainEncoding,
Stat: false,
TimeFormat: time.RFC3339,
})
SetUp(LogConf{
Mode: mode,

View File

@@ -16,13 +16,13 @@ const (
const (
jsonEncodingType = iota
plainEncodingType
plainEncoding = "plain"
plainEncodingSep = '\t'
sizeRotationRule = "size"
)
const (
plainEncoding = "plain"
plainEncodingSep = '\t'
sizeRotationRule = "size"
accessFilename = "access.log"
errorFilename = "error.log"
severeFilename = "severe.log"
@@ -53,6 +53,7 @@ const (
spanKey = "span"
timestampKey = "@timestamp"
traceKey = "trace"
truncatedKey = "truncated"
)
var (
@@ -60,4 +61,6 @@ var (
ErrLogPathNotSet = errors.New("log path must be set")
// ErrLogServiceNameNotSet is an error that indicates that the service name is not set.
ErrLogServiceNameNotSet = errors.New("log service name must be set")
truncatedField = Field(truncatedKey, true)
)

View File

@@ -253,11 +253,11 @@ func (n nopWriter) Stack(_ interface{}) {
func (n nopWriter) Stat(_ interface{}, _ ...LogField) {
}
func buildFields(fields ...LogField) []string {
func buildPlainFields(fields ...LogField) []string {
var items []string
for _, field := range fields {
items = append(items, fmt.Sprintf("%s=%v", field.Key, field.Value))
items = append(items, fmt.Sprintf("%s=%+v", field.Key, field.Value))
}
return items
@@ -269,15 +269,29 @@ func combineGlobalFields(fields []LogField) []LogField {
return fields
}
return append(globals.([]LogField), fields...)
gf := globals.([]LogField)
ret := make([]LogField, 0, len(gf)+len(fields))
ret = append(ret, gf...)
ret = append(ret, fields...)
return ret
}
func output(writer io.Writer, level string, val interface{}, fields ...LogField) {
// only truncate string content, don't know how to truncate the values of other types.
if v, ok := val.(string); ok {
maxLen := atomic.LoadUint32(&maxContentLength)
if maxLen > 0 && len(v) > int(maxLen) {
val = v[:maxLen]
fields = append(fields, truncatedField)
}
}
fields = combineGlobalFields(fields)
switch atomic.LoadUint32(&encoding) {
case plainEncodingType:
writePlainAny(writer, level, val, buildFields(fields...)...)
writePlainAny(writer, level, val, buildPlainFields(fields...)...)
default:
entry := make(logEntry)
for _, field := range fields {

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"log"
"sync/atomic"
"testing"
"github.com/stretchr/testify/assert"
@@ -157,9 +158,40 @@ func TestWritePlainAny(t *testing.T) {
}
func TestLogWithLimitContentLength(t *testing.T) {
maxLen := atomic.LoadUint32(&maxContentLength)
atomic.StoreUint32(&maxContentLength, 10)
t.Cleanup(func() {
atomic.StoreUint32(&maxContentLength, maxLen)
})
t.Run("alert", func(t *testing.T) {
var buf bytes.Buffer
w := NewWriter(&buf)
w.Info("1234567890")
var v1 mockedEntry
if err := json.Unmarshal(buf.Bytes(), &v1); err != nil {
t.Fatal(err)
}
assert.Equal(t, "1234567890", v1.Content)
assert.False(t, v1.Truncated)
buf.Reset()
var v2 mockedEntry
w.Info("12345678901")
if err := json.Unmarshal(buf.Bytes(), &v2); err != nil {
t.Fatal(err)
}
assert.Equal(t, "1234567890", v2.Content)
assert.True(t, v2.Truncated)
})
}
type mockedEntry struct {
Level string `json:"level"`
Content string `json:"content"`
Level string `json:"level"`
Content string `json:"content"`
Truncated bool `json:"truncated"`
}
type easyToCloseWriter struct{}

View File

@@ -34,7 +34,7 @@ func getJsonUnmarshaler(opts ...UnmarshalOption) *Unmarshaler {
}
func unmarshalJsonBytes(content []byte, v interface{}, unmarshaler *Unmarshaler) error {
var m map[string]interface{}
var m interface{}
if err := jsonx.Unmarshal(content, &m); err != nil {
return err
}
@@ -43,7 +43,7 @@ func unmarshalJsonBytes(content []byte, v interface{}, unmarshaler *Unmarshaler)
}
func unmarshalJsonReader(reader io.Reader, v interface{}, unmarshaler *Unmarshaler) error {
var m map[string]interface{}
var m interface{}
if err := jsonx.UnmarshalFromReader(reader, &m); err != nil {
return err
}

View File

@@ -856,8 +856,7 @@ func TestUnmarshalBytesError(t *testing.T) {
}
err := UnmarshalJsonBytes([]byte(payload), &v)
assert.NotNil(t, err)
assert.True(t, strings.Contains(err.Error(), payload))
assert.Equal(t, errTypeMismatch, err)
}
func TestUnmarshalReaderError(t *testing.T) {
@@ -867,9 +866,7 @@ func TestUnmarshalReaderError(t *testing.T) {
Any string
}
err := UnmarshalJsonReader(reader, &v)
assert.NotNil(t, err)
assert.True(t, strings.Contains(err.Error(), payload))
assert.Equal(t, errTypeMismatch, UnmarshalJsonReader(reader, &v))
}
func TestUnmarshalMap(t *testing.T) {
@@ -920,3 +917,26 @@ func TestUnmarshalMap(t *testing.T) {
assert.Equal(t, "foo", v.Any)
})
}
func TestUnmarshalJsonArray(t *testing.T) {
var v []struct {
Name string `json:"name"`
Age int `json:"age"`
}
body := `[{"name":"kevin", "age": 18}]`
assert.NoError(t, UnmarshalJsonBytes([]byte(body), &v))
assert.Equal(t, 1, len(v))
assert.Equal(t, "kevin", v[0].Name)
assert.Equal(t, 18, v[0].Age)
}
func TestUnmarshalJsonBytesError(t *testing.T) {
var v []struct {
Name string `json:"name"`
Age int `json:"age"`
}
assert.Error(t, UnmarshalJsonBytes([]byte((``)), &v))
assert.Error(t, UnmarshalJsonReader(strings.NewReader(``), &v))
}

View File

@@ -71,8 +71,29 @@ func UnmarshalKey(m map[string]interface{}, v interface{}) error {
}
// Unmarshal unmarshals m into v.
func (u *Unmarshaler) Unmarshal(m map[string]interface{}, v interface{}) error {
return u.UnmarshalValuer(mapValuer(m), v)
func (u *Unmarshaler) Unmarshal(i interface{}, v interface{}) error {
valueType := reflect.TypeOf(v)
if valueType.Kind() != reflect.Ptr {
return errValueNotSettable
}
elemType := Deref(valueType)
switch iv := i.(type) {
case map[string]interface{}:
if elemType.Kind() != reflect.Struct {
return errTypeMismatch
}
return u.UnmarshalValuer(mapValuer(iv), v)
case []interface{}:
if elemType.Kind() != reflect.Slice {
return errTypeMismatch
}
return u.fillSlice(elemType, reflect.ValueOf(v).Elem(), iv)
default:
return errUnsupportedType
}
}
// UnmarshalValuer unmarshals m into v.
@@ -127,7 +148,6 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map
}
baseType := fieldType.Elem()
baseKind := baseType.Kind()
dereffedBaseType := Deref(baseType)
dereffedBaseKind := dereffedBaseType.Kind()
refValue := reflect.ValueOf(mapValue)
@@ -156,11 +176,7 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map
return err
}
if baseKind == reflect.Ptr {
conv.Index(i).Set(target)
} else {
conv.Index(i).Set(target.Elem())
}
SetValue(fieldType.Elem(), conv.Index(i), target.Elem())
case reflect.Slice:
if err := u.fillSlice(dereffedBaseType, conv.Index(i), ithValue); err != nil {
return err
@@ -214,9 +230,9 @@ func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int,
ithVal := slice.Index(index)
switch v := value.(type) {
case fmt.Stringer:
return setValue(baseKind, ithVal, v.String())
return setValueFromString(baseKind, ithVal, v.String())
case string:
return setValue(baseKind, ithVal, v)
return setValueFromString(baseKind, ithVal, v)
case map[string]interface{}:
return u.fillMap(ithVal.Type(), ithVal, value)
default:
@@ -230,7 +246,7 @@ func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int,
target := reflect.New(baseType).Elem()
target.Set(reflect.ValueOf(value))
ithVal.Set(target.Addr())
SetValue(ithVal.Type(), ithVal, target)
return nil
}
@@ -274,7 +290,6 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue inter
refValue := reflect.ValueOf(mapValue)
targetValue := reflect.MakeMapWithSize(mapType, refValue.Len())
fieldElemKind := elemType.Kind()
dereffedElemType := Deref(elemType)
dereffedElemKind := dereffedElemType.Kind()
@@ -301,11 +316,7 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue inter
return emptyValue, err
}
if fieldElemKind == reflect.Ptr {
targetValue.SetMapIndex(key, target)
} else {
targetValue.SetMapIndex(key, target.Elem())
}
SetMapIndexValue(elemType, targetValue, key, target.Elem())
case reflect.Map:
keythMap, ok := keythData.(map[string]interface{})
if !ok {
@@ -334,7 +345,7 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue inter
targetValue.SetMapIndex(key, reflect.ValueOf(v))
case json.Number:
target := reflect.New(dereffedElemType)
if err := setValue(dereffedElemKind, target.Elem(), v.String()); err != nil {
if err := setValueFromString(dereffedElemKind, target.Elem(), v.String()); err != nil {
return emptyValue, err
}
@@ -361,6 +372,26 @@ func (u *Unmarshaler) parseOptionsWithContext(field reflect.StructField, m Value
return key, nil, nil
}
if u.opts.canonicalKey != nil {
key = u.opts.canonicalKey(key)
if len(options.OptionalDep) > 0 {
// need to create a new fieldOption, because the original one is shared through cache.
options = &fieldOptions{
fieldOptionsWithContext: fieldOptionsWithContext{
Inherit: options.Inherit,
FromString: options.FromString,
Optional: options.Optional,
Options: options.Options,
Default: options.Default,
EnvVar: options.EnvVar,
Range: options.Range,
},
OptionalDep: u.opts.canonicalKey(options.OptionalDep),
}
}
}
optsWithContext, err := options.toOptionsWithContext(key, m, fullName)
if err != nil {
return "", nil, err
@@ -376,19 +407,51 @@ func (u *Unmarshaler) processAnonymousField(field reflect.StructField, value ref
return err
}
if _, hasValue := getValue(m, key); hasValue {
return fmt.Errorf("fields of %s can't be wrapped inside, because it's anonymous", key)
}
if options.optional() {
return u.processAnonymousFieldOptional(field.Type, value, key, m, fullName)
return u.processAnonymousFieldOptional(field, value, key, m, fullName)
}
return u.processAnonymousFieldRequired(field.Type, value, m, fullName)
return u.processAnonymousFieldRequired(field, value, m, fullName)
}
func (u *Unmarshaler) processAnonymousFieldOptional(fieldType reflect.Type, value reflect.Value,
func (u *Unmarshaler) processAnonymousFieldOptional(field reflect.StructField, value reflect.Value,
key string, m valuerWithParent, fullName string) error {
derefedFieldType := Deref(field.Type)
switch derefedFieldType.Kind() {
case reflect.Struct:
return u.processAnonymousStructFieldOptional(field.Type, value, key, m, fullName)
default:
return u.processNamedField(field, value, m, fullName)
}
}
func (u *Unmarshaler) processAnonymousFieldRequired(field reflect.StructField, value reflect.Value,
m valuerWithParent, fullName string) error {
fieldType := field.Type
maybeNewValue(fieldType, value)
derefedFieldType := Deref(fieldType)
indirectValue := reflect.Indirect(value)
switch derefedFieldType.Kind() {
case reflect.Struct:
for i := 0; i < derefedFieldType.NumField(); i++ {
if err := u.processField(derefedFieldType.Field(i), indirectValue.Field(i),
m, fullName); err != nil {
return err
}
}
default:
if err := u.processNamedField(field, indirectValue, m, fullName); err != nil {
return err
}
}
return nil
}
func (u *Unmarshaler) processAnonymousStructFieldOptional(fieldType reflect.Type,
value reflect.Value, key string, m valuerWithParent, fullName string) error {
var filled bool
var required int
var requiredFilled int
@@ -428,21 +491,6 @@ func (u *Unmarshaler) processAnonymousFieldOptional(fieldType reflect.Type, valu
return nil
}
func (u *Unmarshaler) processAnonymousFieldRequired(fieldType reflect.Type, value reflect.Value,
m valuerWithParent, fullName string) error {
maybeNewValue(fieldType, value)
derefedFieldType := Deref(fieldType)
indirectValue := reflect.Indirect(value)
for i := 0; i < derefedFieldType.NumField(); i++ {
if err := u.processField(derefedFieldType.Field(i), indirectValue.Field(i), m, fullName); err != nil {
return err
}
}
return nil
}
func (u *Unmarshaler) processField(field reflect.StructField, value reflect.Value,
m valuerWithParent, fullName string) error {
if usingDifferentKeys(u.key, field) {
@@ -481,7 +529,7 @@ func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value re
case valueKind == reflect.String && typeKind == reflect.Slice:
return u.fillSliceFromString(fieldType, value, mapValue)
case valueKind == reflect.String && derefedFieldType == durationType:
return fillDurationValue(fieldType.Kind(), value, mapValue.(string))
return fillDurationValue(fieldType, value, mapValue.(string))
default:
return u.processFieldPrimitive(fieldType, value, mapValue, opts, fullName)
}
@@ -517,8 +565,8 @@ func (u *Unmarshaler) processFieldPrimitive(fieldType reflect.Type, value reflec
func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(fieldType reflect.Type, value reflect.Value,
v json.Number, opts *fieldOptionsWithContext, fullName string) error {
fieldKind := fieldType.Kind()
typeKind := Deref(fieldType).Kind()
baseType := Deref(fieldType)
typeKind := baseType.Kind()
if err := validateJsonNumberRange(v, opts); err != nil {
return err
@@ -528,9 +576,7 @@ func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(fieldType reflect.Type
return err
}
if fieldKind == reflect.Ptr {
value = value.Elem()
}
target := reflect.New(Deref(fieldType)).Elem()
switch typeKind {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
@@ -539,7 +585,7 @@ func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(fieldType reflect.Type
return err
}
value.SetInt(iValue)
target.SetInt(iValue)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
iValue, err := v.Int64()
if err != nil {
@@ -550,18 +596,20 @@ func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(fieldType reflect.Type
return fmt.Errorf("unmarshal %q with bad value %q", fullName, v.String())
}
value.SetUint(uint64(iValue))
target.SetUint(uint64(iValue))
case reflect.Float32, reflect.Float64:
fValue, err := v.Float64()
if err != nil {
return err
}
value.SetFloat(fValue)
target.SetFloat(fValue)
default:
return newTypeMismatchError(fullName)
}
SetValue(fieldType, value, target)
return nil
}
@@ -574,7 +622,7 @@ func (u *Unmarshaler) processFieldStruct(fieldType reflect.Type, value reflect.V
return err
}
value.Set(target.Addr())
SetValue(fieldType, value, target)
} else if err := u.unmarshalWithFullName(m, value.Addr().Interface(), fullName); err != nil {
return err
}
@@ -588,7 +636,13 @@ func (u *Unmarshaler) processFieldTextUnmarshaler(fieldType reflect.Type, value
var ok bool
if fieldType.Kind() == reflect.Ptr {
tval, ok = value.Interface().(encoding.TextUnmarshaler)
if value.Elem().Kind() == reflect.Ptr {
target := reflect.New(Deref(fieldType))
SetValue(fieldType.Elem(), value, target)
tval, ok = target.Interface().(encoding.TextUnmarshaler)
} else {
tval, ok = value.Interface().(encoding.TextUnmarshaler)
}
} else {
tval, ok = value.Addr().Interface().(encoding.TextUnmarshaler)
}
@@ -621,7 +675,7 @@ func (u *Unmarshaler) processFieldWithEnvValue(fieldType reflect.Type, value ref
value.SetBool(val)
return nil
case durationType.Kind():
if err := fillDurationValue(fieldKind, value, envVal); err != nil {
if err := fillDurationValue(fieldType, value, envVal); err != nil {
return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
}
@@ -693,44 +747,57 @@ func (u *Unmarshaler) processNamedFieldWithValue(fieldType reflect.Type, value r
return u.processFieldNotFromString(fieldType, value, vp, opts, fullName)
default:
if u.opts.fromString || opts.fromString() {
valueKind := reflect.TypeOf(mapValue).Kind()
if valueKind != reflect.String {
return fmt.Errorf("error: the value in map is not string, but %s", valueKind)
}
options := opts.options()
if len(options) > 0 {
if !stringx.Contains(options, mapValue.(string)) {
return fmt.Errorf(`error: value "%s" for field "%s" is not defined in options "%v"`,
mapValue, key, options)
}
}
return fillPrimitive(fieldType, value, mapValue, opts, fullName)
return u.processNamedFieldWithValueFromString(fieldType, value, mapValue,
key, opts, fullName)
}
return u.processFieldNotFromString(fieldType, value, vp, opts, fullName)
}
}
func (u *Unmarshaler) processNamedFieldWithValueFromString(fieldType reflect.Type, value reflect.Value,
mapValue interface{}, key string, opts *fieldOptionsWithContext, fullName string) error {
valueKind := reflect.TypeOf(mapValue).Kind()
if valueKind != reflect.String {
return fmt.Errorf("the value in map is not string, but %s", valueKind)
}
options := opts.options()
if len(options) > 0 {
var checkValue string
switch mt := mapValue.(type) {
case string:
checkValue = mt
case fmt.Stringer:
checkValue = mt.String()
default:
return fmt.Errorf("the value in map is not string or json.Number, but %s",
valueKind.String())
}
if !stringx.Contains(options, checkValue) {
return fmt.Errorf(`value "%s" for field "%s" is not defined in options "%v"`,
mapValue, key, options)
}
}
return fillPrimitive(fieldType, value, mapValue, opts, fullName)
}
func (u *Unmarshaler) processNamedFieldWithoutValue(fieldType reflect.Type, value reflect.Value,
opts *fieldOptionsWithContext, fullName string) error {
derefedType := Deref(fieldType)
fieldKind := derefedType.Kind()
if defaultValue, ok := opts.getDefault(); ok {
if fieldType.Kind() == reflect.Ptr {
maybeNewValue(fieldType, value)
value = value.Elem()
}
if derefedType == durationType {
return fillDurationValue(fieldKind, value, defaultValue)
return fillDurationValue(fieldType, value, defaultValue)
}
switch fieldKind {
case reflect.Array, reflect.Slice:
return u.fillSliceWithDefault(derefedType, value, defaultValue)
default:
return setValue(fieldKind, value, defaultValue)
return setValueFromString(fieldKind, value, defaultValue)
}
}
@@ -771,15 +838,22 @@ func (u *Unmarshaler) unmarshalWithFullName(m valuerWithParent, v interface{}, f
return err
}
rte := reflect.TypeOf(v).Elem()
if rte.Kind() != reflect.Struct {
valueType := reflect.TypeOf(v)
baseType := Deref(valueType)
if baseType.Kind() != reflect.Struct {
return errValueNotStruct
}
rve := rv.Elem()
numFields := rte.NumField()
valElem := rv.Elem()
if valElem.Kind() == reflect.Ptr {
target := reflect.New(baseType).Elem()
SetValue(valueType.Elem(), valElem, target)
valElem = target
}
numFields := baseType.NumField()
for i := 0; i < numFields; i++ {
if err := u.processField(rte.Field(i), rve.Field(i), m, fullName); err != nil {
if err := u.processField(baseType.Field(i), valElem.Field(i), m, fullName); err != nil {
return err
}
}
@@ -815,17 +889,13 @@ func createValuer(v valuerWithParent, opts *fieldOptionsWithContext) valuerWithP
}
}
func fillDurationValue(fieldKind reflect.Kind, value reflect.Value, dur string) error {
func fillDurationValue(fieldType reflect.Type, value reflect.Value, dur string) error {
d, err := time.ParseDuration(dur)
if err != nil {
return err
}
if fieldKind == reflect.Ptr {
value.Elem().Set(reflect.ValueOf(d))
} else {
value.Set(reflect.ValueOf(d))
}
SetValue(fieldType, value, reflect.ValueOf(d))
return nil
}
@@ -841,7 +911,7 @@ func fillPrimitive(fieldType reflect.Type, value reflect.Value, mapValue interfa
target := reflect.New(baseType).Elem()
switch mapValue.(type) {
case string, json.Number:
value.Set(target.Addr())
SetValue(fieldType, value, target)
value = target
}
}
@@ -853,7 +923,7 @@ func fillPrimitive(fieldType reflect.Type, value reflect.Value, mapValue interfa
if err := validateJsonNumberRange(v, opts); err != nil {
return err
}
return setValue(baseType.Kind(), value, v.String())
return setValueFromString(baseType.Kind(), value, v.String())
default:
return newTypeMismatchError(fullName)
}
@@ -873,7 +943,7 @@ func fillWithSameType(fieldType reflect.Type, value reflect.Value, mapValue inte
baseType := Deref(fieldType)
target := reflect.New(baseType).Elem()
setSameKindValue(baseType, target, mapValue)
value.Set(target.Addr())
SetValue(fieldType, value, target)
} else {
setSameKindValue(fieldType, value, mapValue)
}

File diff suppressed because it is too large Load Diff

View File

@@ -56,7 +56,7 @@ type (
// Deref dereferences a type, if pointer type, returns its element type.
func Deref(t reflect.Type) reflect.Type {
if t.Kind() == reflect.Ptr {
for t.Kind() == reflect.Ptr {
t = t.Elem()
}
@@ -68,6 +68,16 @@ func Repr(v interface{}) string {
return lang.Repr(v)
}
// SetValue sets target to value, pointers are processed automatically.
func SetValue(tp reflect.Type, value, target reflect.Value) {
value.Set(convertTypeOfPtr(tp, target))
}
// SetMapIndexValue sets target to value at key position, pointers are processed automatically.
func SetMapIndexValue(tp reflect.Type, value, key, target reflect.Value) {
value.SetMapIndex(key, convertTypeOfPtr(tp, target))
}
// ValidatePtr validates v if it's a valid pointer.
func ValidatePtr(v *reflect.Value) error {
// sequence is very important, IsNil must be called after checking Kind() with reflect.Ptr,
@@ -79,7 +89,7 @@ func ValidatePtr(v *reflect.Value) error {
return nil
}
func convertType(kind reflect.Kind, str string) (interface{}, error) {
func convertTypeFromString(kind reflect.Kind, str string) (interface{}, error) {
switch kind {
case reflect.Bool:
switch strings.ToLower(str) {
@@ -118,6 +128,23 @@ func convertType(kind reflect.Kind, str string) (interface{}, error) {
}
}
func convertTypeOfPtr(tp reflect.Type, target reflect.Value) reflect.Value {
// keep the original value is a pointer
if tp.Kind() == reflect.Ptr && target.CanAddr() {
tp = tp.Elem()
target = target.Addr()
}
for tp.Kind() == reflect.Ptr {
p := reflect.New(target.Type())
p.Elem().Set(target)
target = p
tp = tp.Elem()
}
return target
}
func doParseKeyAndOptions(field reflect.StructField, value string) (string, *fieldOptions, error) {
segments := parseSegments(value)
key := strings.TrimSpace(segments[0])
@@ -476,13 +503,13 @@ func setMatchedPrimitiveValue(kind reflect.Kind, value reflect.Value, v interfac
return nil
}
func setValue(kind reflect.Kind, value reflect.Value, str string) error {
func setValueFromString(kind reflect.Kind, value reflect.Value, str string) error {
if !value.CanSet() {
return errValueNotSettable
}
value = ensureValue(value)
v, err := convertType(kind, str)
v, err := convertTypeFromString(kind, str)
if err != nil {
return err
}
@@ -555,7 +582,7 @@ func validateAndSetValue(kind reflect.Kind, value reflect.Value, str string, opt
return errValueNotSettable
}
v, err := convertType(kind, str)
v, err := convertTypeFromString(kind, str)
if err != nil {
return err
}

View File

@@ -237,7 +237,7 @@ func TestValidatePtrWithZeroValue(t *testing.T) {
func TestSetValueNotSettable(t *testing.T) {
var i int
assert.NotNil(t, setValue(reflect.Int, reflect.ValueOf(i), "1"))
assert.NotNil(t, setValueFromString(reflect.Int, reflect.ValueOf(i), "1"))
}
func TestParseKeyAndOptionsErrors(t *testing.T) {
@@ -290,7 +290,7 @@ func TestSetValueFormatErrors(t *testing.T) {
for _, test := range tests {
t.Run(test.kind.String(), func(t *testing.T) {
err := setValue(test.kind, test.target, test.value)
err := setValueFromString(test.kind, test.target, test.value)
assert.NotEqual(t, errValueNotSettable, err)
assert.NotNil(t, err)
})

View File

@@ -31,3 +31,27 @@ func TestMapValuerWithInherit_Value(t *testing.T) {
assert.Equal(t, "localhost", m["host"])
assert.Equal(t, 8080, m["port"])
}
func TestRecursiveValuer_Value(t *testing.T) {
input := map[string]interface{}{
"component": map[string]interface{}{
"name": "test",
"foo": map[string]interface{}{
"bar": "baz",
},
},
"foo": "value",
}
valuer := recursiveValuer{
current: mapValuer(input["component"].(map[string]interface{})),
parent: simpleValuer{
current: mapValuer(input),
},
}
val, ok := valuer.Value("foo")
assert.True(t, ok)
assert.EqualValues(t, map[string]interface{}{
"bar": "baz",
}, val)
}

View File

@@ -45,9 +45,13 @@ func RawFieldNames(in interface{}, postgresSql ...bool) []string {
// `db:"id"`
// `db:"id,type=char,length=16"`
// `db:",type=char,length=16"`
// `db:"-,type=char,length=16"`
if strings.Contains(tagv, ",") {
tagv = strings.TrimSpace(strings.Split(tagv, ",")[0])
}
if tagv == "-" {
continue
}
if len(tagv) == 0 {
tagv = fi.Name
}

View File

@@ -39,3 +39,33 @@ func TestFieldNamesWithTagOptions(t *testing.T) {
assert.Equal(t, expected, out)
})
}
type mockedUserWithDashTag struct {
ID string `db:"id" json:"id,omitempty"`
UserName string `db:"user_name" json:"userName,omitempty"`
Mobile string `db:"-" json:"mobile,omitempty"`
}
func TestFieldNamesWithDashTag(t *testing.T) {
t.Run("new", func(t *testing.T) {
var u mockedUserWithDashTag
out := RawFieldNames(&u)
expected := []string{"`id`", "`user_name`"}
assert.Equal(t, expected, out)
})
}
type mockedUserWithDashTagAndOptions struct {
ID string `db:"id" json:"id,omitempty"`
UserName string `db:"user_name,type=varchar,length=255" json:"userName,omitempty"`
Mobile string `db:"-,type=varchar,length=255" json:"mobile,omitempty"`
}
func TestFieldNamesWithDashTagAndOptions(t *testing.T) {
t.Run("new", func(t *testing.T) {
var u mockedUserWithDashTagAndOptions
out := RawFieldNames(&u)
expected := []string{"`id`", "`user_name`"}
assert.Equal(t, expected, out)
})
}

View File

@@ -10,6 +10,7 @@ import (
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/errorx"
"github.com/zeromicro/go-zero/core/hash"
@@ -109,51 +110,85 @@ func (mc *mockedNode) TakeWithExpireCtx(ctx context.Context, val interface{}, ke
}
func TestCache_SetDel(t *testing.T) {
const total = 1000
r1, clean1, err := redistest.CreateRedis()
assert.Nil(t, err)
defer clean1()
r2, clean2, err := redistest.CreateRedis()
assert.Nil(t, err)
defer clean2()
conf := ClusterConf{
{
RedisConf: redis.RedisConf{
Host: r1.Addr,
Type: redis.NodeType,
t.Run("test set del", func(t *testing.T) {
const total = 1000
r1, clean1, err := redistest.CreateRedis()
assert.Nil(t, err)
defer clean1()
r2, clean2, err := redistest.CreateRedis()
assert.Nil(t, err)
defer clean2()
conf := ClusterConf{
{
RedisConf: redis.RedisConf{
Host: r1.Addr,
Type: redis.NodeType,
},
Weight: 100,
},
Weight: 100,
},
{
RedisConf: redis.RedisConf{
Host: r2.Addr,
Type: redis.NodeType,
{
RedisConf: redis.RedisConf{
Host: r2.Addr,
Type: redis.NodeType,
},
Weight: 100,
},
Weight: 100,
},
}
c := New(conf, syncx.NewSingleFlight(), NewStat("mock"), errPlaceholder)
for i := 0; i < total; i++ {
if i%2 == 0 {
assert.Nil(t, c.Set(fmt.Sprintf("key/%d", i), i))
} else {
assert.Nil(t, c.SetWithExpire(fmt.Sprintf("key/%d", i), i, 0))
}
}
for i := 0; i < total; i++ {
var val int
assert.Nil(t, c.Get(fmt.Sprintf("key/%d", i), &val))
assert.Equal(t, i, val)
}
assert.Nil(t, c.Del())
for i := 0; i < total; i++ {
assert.Nil(t, c.Del(fmt.Sprintf("key/%d", i)))
}
for i := 0; i < total; i++ {
var val int
assert.True(t, c.IsNotFound(c.Get(fmt.Sprintf("key/%d", i), &val)))
assert.Equal(t, 0, val)
}
c := New(conf, syncx.NewSingleFlight(), NewStat("mock"), errPlaceholder)
for i := 0; i < total; i++ {
if i%2 == 0 {
assert.Nil(t, c.Set(fmt.Sprintf("key/%d", i), i))
} else {
assert.Nil(t, c.SetWithExpire(fmt.Sprintf("key/%d", i), i, 0))
}
}
for i := 0; i < total; i++ {
var val int
assert.Nil(t, c.Get(fmt.Sprintf("key/%d", i), &val))
assert.Equal(t, i, val)
}
assert.Nil(t, c.Del())
for i := 0; i < total; i++ {
assert.Nil(t, c.Del(fmt.Sprintf("key/%d", i)))
}
assert.Nil(t, c.Del("a", "b", "c"))
for i := 0; i < total; i++ {
var val int
assert.True(t, c.IsNotFound(c.Get(fmt.Sprintf("key/%d", i), &val)))
assert.Equal(t, 0, val)
}
})
t.Run("test set del error", func(t *testing.T) {
r1, err := miniredis.Run()
assert.NoError(t, err)
defer r1.Close()
r1.SetError("mock error")
r2, err := miniredis.Run()
assert.NoError(t, err)
defer r2.Close()
r2.SetError("mock error")
conf := ClusterConf{
{
RedisConf: redis.RedisConf{
Host: r1.Addr(),
Type: redis.NodeType,
},
Weight: 100,
},
{
RedisConf: redis.RedisConf{
Host: r2.Addr(),
Type: redis.NodeType,
},
Weight: 100,
},
}
c := New(conf, syncx.NewSingleFlight(), NewStat("mock"), errPlaceholder)
assert.NoError(t, c.Del("a", "b", "c"))
})
}
func TestCache_OneNode(t *testing.T) {

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"math/rand"
"runtime"
"strconv"
"sync"
"testing"
@@ -11,12 +12,14 @@ import (
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/collection"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/mathx"
"github.com/zeromicro/go-zero/core/stat"
"github.com/zeromicro/go-zero/core/stores/redis"
"github.com/zeromicro/go-zero/core/stores/redis/redistest"
"github.com/zeromicro/go-zero/core/syncx"
"github.com/zeromicro/go-zero/core/timex"
)
var errTestNotFound = errors.New("not found")
@@ -27,27 +30,54 @@ func init() {
}
func TestCacheNode_DelCache(t *testing.T) {
store, clean, err := redistest.CreateRedis()
assert.Nil(t, err)
store.Type = redis.ClusterType
defer clean()
t.Run("del cache", func(t *testing.T) {
store, clean, err := redistest.CreateRedis()
assert.Nil(t, err)
store.Type = redis.ClusterType
defer clean()
cn := cacheNode{
rds: store,
r: rand.New(rand.NewSource(time.Now().UnixNano())),
lock: new(sync.Mutex),
unstableExpiry: mathx.NewUnstable(expiryDeviation),
stat: NewStat("any"),
errNotFound: errTestNotFound,
}
assert.Nil(t, cn.Del())
assert.Nil(t, cn.Del([]string{}...))
assert.Nil(t, cn.Del(make([]string, 0)...))
cn.Set("first", "one")
assert.Nil(t, cn.Del("first"))
cn.Set("first", "one")
cn.Set("second", "two")
assert.Nil(t, cn.Del("first", "second"))
cn := cacheNode{
rds: store,
r: rand.New(rand.NewSource(time.Now().UnixNano())),
lock: new(sync.Mutex),
unstableExpiry: mathx.NewUnstable(expiryDeviation),
stat: NewStat("any"),
errNotFound: errTestNotFound,
}
assert.Nil(t, cn.Del())
assert.Nil(t, cn.Del([]string{}...))
assert.Nil(t, cn.Del(make([]string, 0)...))
cn.Set("first", "one")
assert.Nil(t, cn.Del("first"))
cn.Set("first", "one")
cn.Set("second", "two")
assert.Nil(t, cn.Del("first", "second"))
})
t.Run("del cache with errors", func(t *testing.T) {
old := timingWheel
ticker := timex.NewFakeTicker()
var err error
timingWheel, err = collection.NewTimingWheelWithTicker(
time.Millisecond, timingWheelSlots, func(key, value interface{}) {
clean(key, value)
}, ticker)
assert.NoError(t, err)
t.Cleanup(func() {
timingWheel = old
})
r, err := miniredis.Run()
assert.NoError(t, err)
defer r.Close()
r.SetError("mock error")
node := NewNode(redis.New(r.Addr(), redis.Cluster()), syncx.NewSingleFlight(),
NewStat("any"), errTestNotFound)
assert.NoError(t, node.Del("foo", "bar"))
ticker.Tick()
runtime.Gosched()
})
}
func TestCacheNode_DelCacheWithErrors(t *testing.T) {
@@ -125,6 +155,21 @@ func TestCacheNode_Take(t *testing.T) {
assert.Equal(t, `"value"`, val)
}
func TestCacheNode_TakeBadRedis(t *testing.T) {
r, err := miniredis.Run()
assert.NoError(t, err)
defer r.Close()
r.SetError("mock error")
cn := NewNode(redis.New(r.Addr()), syncx.NewSingleFlight(), NewStat("any"),
errTestNotFound, WithExpiry(time.Second), WithNotFoundExpiry(time.Second))
var str string
assert.Error(t, cn.Take(&str, "any", func(v interface{}) error {
*v.(*string) = "value"
return nil
}))
}
func TestCacheNode_TakeNotFound(t *testing.T) {
store, clean, err := redistest.CreateRedis()
assert.Nil(t, err)

View File

@@ -5,6 +5,7 @@ import (
"time"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/timex"
)
const statInterval = time.Minute
@@ -25,7 +26,13 @@ func NewStat(name string) *Stat {
ret := &Stat{
name: name,
}
go ret.statLoop()
go func() {
ticker := timex.NewTicker(statInterval)
defer ticker.Stop()
ret.statLoop(ticker)
}()
return ret
}
@@ -50,11 +57,8 @@ func (s *Stat) IncrementDbFails() {
atomic.AddUint64(&s.DbFails, 1)
}
func (s *Stat) statLoop() {
ticker := time.NewTicker(statInterval)
defer ticker.Stop()
for range ticker.C {
func (s *Stat) statLoop(ticker timex.Ticker) {
for range ticker.Chan() {
total := atomic.SwapUint64(&s.Total, 0)
if total == 0 {
continue

28
core/stores/cache/cachestat_test.go vendored Normal file
View File

@@ -0,0 +1,28 @@
package cache
import (
"testing"
"github.com/zeromicro/go-zero/core/timex"
)
func TestCacheStat_statLoop(t *testing.T) {
t.Run("stat loop total 0", func(t *testing.T) {
var stat Stat
ticker := timex.NewFakeTicker()
go stat.statLoop(ticker)
ticker.Tick()
ticker.Tick()
ticker.Stop()
})
t.Run("stat loop total not 0", func(t *testing.T) {
var stat Stat
stat.IncrementTotal()
ticker := timex.NewFakeTicker()
go stat.statLoop(ticker)
ticker.Tick()
ticker.Tick()
ticker.Stop()
})
}

View File

@@ -26,9 +26,14 @@ type (
)
// NewBulkInserter returns a BulkInserter.
func NewBulkInserter(coll *mongo.Collection, interval ...time.Duration) *BulkInserter {
func NewBulkInserter(coll Collection, interval ...time.Duration) (*BulkInserter, error) {
cloneColl, err := coll.Clone()
if err != nil {
return nil, err
}
inserter := &dbInserter{
collection: coll,
collection: cloneColl,
}
duration := flushInterval
@@ -39,7 +44,7 @@ func NewBulkInserter(coll *mongo.Collection, interval ...time.Duration) *BulkIns
return &BulkInserter{
executor: executors.NewPeriodicalExecutor(duration, inserter),
inserter: inserter,
}
}, nil
}
// Flush flushes the inserter, writes all pending records.

View File

@@ -15,7 +15,8 @@ func TestBulkInserter(t *testing.T) {
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "ok", Value: 1}}...))
bulk := NewBulkInserter(mt.Coll)
bulk, err := NewBulkInserter(createModel(mt).Collection)
assert.Equal(t, err, nil)
bulk.SetResultHandler(func(result *mongo.InsertManyResult, err error) {
assert.Nil(t, err)
assert.Equal(t, 2, len(result.InsertedIDs))

View File

@@ -3,15 +3,12 @@ package mon
import (
"context"
"io"
"time"
"github.com/zeromicro/go-zero/core/syncx"
"go.mongodb.org/mongo-driver/mongo"
mopt "go.mongodb.org/mongo-driver/mongo/options"
)
const defaultTimeout = time.Second
var clientManager = syncx.NewResourceManager()
// ClosableClient wraps *mongo.Client and provides a Close method.
@@ -30,9 +27,20 @@ func Inject(key string, client *mongo.Client) {
clientManager.Inject(key, &ClosableClient{client})
}
func getClient(url string) (*mongo.Client, error) {
func getClient(url string, opts ...Option) (*mongo.Client, error) {
val, err := clientManager.GetResource(url, func() (io.Closer, error) {
cli, err := mongo.Connect(context.Background(), mopt.Client().ApplyURI(url))
o := mopt.Client().ApplyURI(url)
opts = append([]Option{defaultTimeoutOption()}, opts...)
for _, opt := range opts {
opt(o)
}
cli, err := mongo.Connect(context.Background(), o)
if err != nil {
return nil, err
}
err = cli.Ping(context.Background(), nil)
if err != nil {
return nil, err
}

View File

@@ -48,7 +48,7 @@ func MustNewModel(uri, db, collection string, opts ...Option) *Model {
// NewModel returns a Model.
func NewModel(uri, db, collection string, opts ...Option) (*Model, error) {
cli, err := getClient(uri)
cli, err := getClient(uri, opts...)
if err != nil {
return nil, err
}

View File

@@ -4,14 +4,15 @@ import (
"time"
"github.com/zeromicro/go-zero/core/syncx"
mopt "go.mongodb.org/mongo-driver/mongo/options"
)
const defaultTimeout = time.Second * 3
var slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold)
type (
options struct {
timeout time.Duration
}
options = mopt.ClientOptions
// Option defines the method to customize a mongo model.
Option func(opts *options)
@@ -22,8 +23,15 @@ func SetSlowThreshold(threshold time.Duration) {
slowThreshold.Set(threshold)
}
func defaultOptions() *options {
return &options{
timeout: defaultTimeout,
func defaultTimeoutOption() Option {
return func(opts *options) {
opts.SetTimeout(defaultTimeout)
}
}
// WithTimeout set the mon client operation timeout.
func WithTimeout(timeout time.Duration) Option {
return func(opts *options) {
opts.SetTimeout(timeout)
}
}

View File

@@ -5,6 +5,7 @@ import (
"time"
"github.com/stretchr/testify/assert"
mopt "go.mongodb.org/mongo-driver/mongo/options"
)
func TestSetSlowThreshold(t *testing.T) {
@@ -13,6 +14,14 @@ func TestSetSlowThreshold(t *testing.T) {
assert.Equal(t, time.Second, slowThreshold.Load())
}
func TestDefaultOptions(t *testing.T) {
assert.Equal(t, defaultTimeout, defaultOptions().timeout)
func Test_defaultTimeoutOption(t *testing.T) {
opts := mopt.Client()
defaultTimeoutOption()(opts)
assert.Equal(t, defaultTimeout, *opts.Timeout)
}
func TestWithTimeout(t *testing.T) {
opts := mopt.Client()
WithTimeout(time.Second)(opts)
assert.Equal(t, time.Second, *opts.Timeout)
}

View File

@@ -14,12 +14,13 @@ import (
var mongoCmdAttributeKey = attribute.Key("mongo.cmd")
func startSpan(ctx context.Context, cmd string) (context.Context, oteltrace.Span) {
tracer := otel.GetTracerProvider().Tracer(trace.TraceName)
tracer := otel.Tracer(trace.TraceName)
ctx, span := tracer.Start(ctx,
spanName,
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
)
span.SetAttributes(mongoCmdAttributeKey.String(cmd))
return ctx, span
}

View File

@@ -1,92 +0,0 @@
package mongo
import (
"time"
"github.com/globalsign/mgo"
"github.com/zeromicro/go-zero/core/executors"
"github.com/zeromicro/go-zero/core/logx"
)
const (
flushInterval = time.Second
maxBulkRows = 1000
)
type (
// ResultHandler is a handler that used to handle results.
ResultHandler func(*mgo.BulkResult, error)
// A BulkInserter is used to insert bulk of mongo records.
BulkInserter struct {
executor *executors.PeriodicalExecutor
inserter *dbInserter
}
)
// NewBulkInserter returns a BulkInserter.
func NewBulkInserter(session *mgo.Session, dbName string, collectionNamer func() string) *BulkInserter {
inserter := &dbInserter{
session: session,
dbName: dbName,
collectionNamer: collectionNamer,
}
return &BulkInserter{
executor: executors.NewPeriodicalExecutor(flushInterval, inserter),
inserter: inserter,
}
}
// Flush flushes the inserter, writes all pending records.
func (bi *BulkInserter) Flush() {
bi.executor.Flush()
}
// Insert inserts doc.
func (bi *BulkInserter) Insert(doc interface{}) {
bi.executor.Add(doc)
}
// SetResultHandler sets the result handler.
func (bi *BulkInserter) SetResultHandler(handler ResultHandler) {
bi.executor.Sync(func() {
bi.inserter.resultHandler = handler
})
}
type dbInserter struct {
session *mgo.Session
dbName string
collectionNamer func() string
documents []interface{}
resultHandler ResultHandler
}
func (in *dbInserter) AddTask(doc interface{}) bool {
in.documents = append(in.documents, doc)
return len(in.documents) >= maxBulkRows
}
func (in *dbInserter) Execute(objs interface{}) {
docs := objs.([]interface{})
if len(docs) == 0 {
return
}
bulk := in.session.DB(in.dbName).C(in.collectionNamer()).Bulk()
bulk.Insert(docs...)
bulk.Unordered()
result, err := bulk.Run()
if in.resultHandler != nil {
in.resultHandler(result, err)
} else if err != nil {
logx.Error(err)
}
}
func (in *dbInserter) RemoveAll() interface{} {
documents := in.documents
in.documents = nil
return documents
}

View File

@@ -1,242 +0,0 @@
package mongo
import (
"encoding/json"
"time"
"github.com/globalsign/mgo"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/mongo/internal"
"github.com/zeromicro/go-zero/core/timex"
)
const defaultSlowThreshold = time.Millisecond * 500
// ErrNotFound is an alias of mgo.ErrNotFound.
var ErrNotFound = mgo.ErrNotFound
type (
// Collection interface represents a mongo connection.
Collection interface {
Find(query interface{}) Query
FindId(id interface{}) Query
Insert(docs ...interface{}) error
Pipe(pipeline interface{}) Pipe
Remove(selector interface{}) error
RemoveAll(selector interface{}) (*mgo.ChangeInfo, error)
RemoveId(id interface{}) error
Update(selector, update interface{}) error
UpdateId(id, update interface{}) error
Upsert(selector, update interface{}) (*mgo.ChangeInfo, error)
}
decoratedCollection struct {
name string
collection internal.MgoCollection
brk breaker.Breaker
}
keepablePromise struct {
promise breaker.Promise
log func(error)
}
)
func newCollection(collection *mgo.Collection, brk breaker.Breaker) Collection {
return &decoratedCollection{
name: collection.FullName,
collection: collection,
brk: brk,
}
}
func (c *decoratedCollection) Find(query interface{}) Query {
promise, err := c.brk.Allow()
if err != nil {
return rejectedQuery{}
}
startTime := timex.Now()
return promisedQuery{
Query: c.collection.Find(query),
promise: keepablePromise{
promise: promise,
log: func(err error) {
duration := timex.Since(startTime)
c.logDuration("find", duration, err, query)
},
},
}
}
func (c *decoratedCollection) FindId(id interface{}) Query {
promise, err := c.brk.Allow()
if err != nil {
return rejectedQuery{}
}
startTime := timex.Now()
return promisedQuery{
Query: c.collection.FindId(id),
promise: keepablePromise{
promise: promise,
log: func(err error) {
duration := timex.Since(startTime)
c.logDuration("findId", duration, err, id)
},
},
}
}
func (c *decoratedCollection) Insert(docs ...interface{}) (err error) {
return c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
duration := timex.Since(startTime)
c.logDuration("insert", duration, err, docs...)
}()
return c.collection.Insert(docs...)
}, acceptable)
}
func (c *decoratedCollection) Pipe(pipeline interface{}) Pipe {
promise, err := c.brk.Allow()
if err != nil {
return rejectedPipe{}
}
startTime := timex.Now()
return promisedPipe{
Pipe: c.collection.Pipe(pipeline),
promise: keepablePromise{
promise: promise,
log: func(err error) {
duration := timex.Since(startTime)
c.logDuration("pipe", duration, err, pipeline)
},
},
}
}
func (c *decoratedCollection) Remove(selector interface{}) (err error) {
return c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
duration := timex.Since(startTime)
c.logDuration("remove", duration, err, selector)
}()
return c.collection.Remove(selector)
}, acceptable)
}
func (c *decoratedCollection) RemoveAll(selector interface{}) (info *mgo.ChangeInfo, err error) {
err = c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
duration := timex.Since(startTime)
c.logDuration("removeAll", duration, err, selector)
}()
info, err = c.collection.RemoveAll(selector)
return err
}, acceptable)
return
}
func (c *decoratedCollection) RemoveId(id interface{}) (err error) {
return c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
duration := timex.Since(startTime)
c.logDuration("removeId", duration, err, id)
}()
return c.collection.RemoveId(id)
}, acceptable)
}
func (c *decoratedCollection) Update(selector, update interface{}) (err error) {
return c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
duration := timex.Since(startTime)
c.logDuration("update", duration, err, selector, update)
}()
return c.collection.Update(selector, update)
}, acceptable)
}
func (c *decoratedCollection) UpdateId(id, update interface{}) (err error) {
return c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
duration := timex.Since(startTime)
c.logDuration("updateId", duration, err, id, update)
}()
return c.collection.UpdateId(id, update)
}, acceptable)
}
func (c *decoratedCollection) Upsert(selector, update interface{}) (info *mgo.ChangeInfo, err error) {
err = c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
duration := timex.Since(startTime)
c.logDuration("upsert", duration, err, selector, update)
}()
info, err = c.collection.Upsert(selector, update)
return err
}, acceptable)
return
}
func (c *decoratedCollection) logDuration(method string, duration time.Duration, err error, docs ...interface{}) {
content, e := json.Marshal(docs)
if e != nil {
logx.Error(err)
} else if err != nil {
if duration > slowThreshold.Load() {
logx.WithDuration(duration).Slowf("[MONGO] mongo(%s) - slowcall - %s - fail(%s) - %s",
c.name, method, err.Error(), string(content))
} else {
logx.WithDuration(duration).Infof("mongo(%s) - %s - fail(%s) - %s",
c.name, method, err.Error(), string(content))
}
} else {
if duration > slowThreshold.Load() {
logx.WithDuration(duration).Slowf("[MONGO] mongo(%s) - slowcall - %s - ok - %s",
c.name, method, string(content))
} else {
logx.WithDuration(duration).Infof("mongo(%s) - %s - ok - %s", c.name, method, string(content))
}
}
}
func (p keepablePromise) accept(err error) error {
p.promise.Accept()
p.log(err)
return err
}
func (p keepablePromise) keep(err error) error {
if acceptable(err) {
p.promise.Accept()
} else {
p.promise.Reject(err.Error())
}
p.log(err)
return err
}
func acceptable(err error) bool {
return err == nil || err == mgo.ErrNotFound
}

View File

@@ -1,345 +0,0 @@
package mongo
import (
"errors"
"strings"
"testing"
"time"
"github.com/globalsign/mgo"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/mongo/internal"
"github.com/zeromicro/go-zero/core/stringx"
)
var errDummy = errors.New("dummy")
func TestKeepPromise_accept(t *testing.T) {
p := new(mockPromise)
kp := keepablePromise{
promise: p,
log: func(error) {},
}
assert.Nil(t, kp.accept(nil))
assert.Equal(t, mgo.ErrNotFound, kp.accept(mgo.ErrNotFound))
}
func TestKeepPromise_keep(t *testing.T) {
tests := []struct {
err error
accepted bool
reason string
}{
{
err: nil,
accepted: true,
reason: "",
},
{
err: mgo.ErrNotFound,
accepted: true,
reason: "",
},
{
err: errors.New("any"),
accepted: false,
reason: "any",
},
}
for _, test := range tests {
t.Run(stringx.RandId(), func(t *testing.T) {
p := new(mockPromise)
kp := keepablePromise{
promise: p,
log: func(error) {},
}
assert.Equal(t, test.err, kp.keep(test.err))
assert.Equal(t, test.accepted, p.accepted)
assert.Equal(t, test.reason, p.reason)
})
}
}
func TestNewCollection(t *testing.T) {
col := newCollection(&mgo.Collection{
Database: nil,
Name: "foo",
FullName: "bar",
}, breaker.GetBreaker("localhost"))
assert.Equal(t, "bar", col.(*decoratedCollection).name)
}
func TestCollectionFind(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
var query mgo.Query
col := internal.NewMockMgoCollection(ctrl)
col.EXPECT().Find(gomock.Any()).Return(&query)
c := decoratedCollection{
collection: col,
brk: breaker.NewBreaker(),
}
actual := c.Find(nil)
switch v := actual.(type) {
case promisedQuery:
assert.Equal(t, &query, v.Query)
assert.Equal(t, errDummy, v.promise.keep(errDummy))
default:
t.Fail()
}
c.brk = new(dropBreaker)
actual = c.Find(nil)
assert.Equal(t, rejectedQuery{}, actual)
}
func TestCollectionFindId(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
var query mgo.Query
col := internal.NewMockMgoCollection(ctrl)
col.EXPECT().FindId(gomock.Any()).Return(&query)
c := decoratedCollection{
collection: col,
brk: breaker.NewBreaker(),
}
actual := c.FindId(nil)
switch v := actual.(type) {
case promisedQuery:
assert.Equal(t, &query, v.Query)
assert.Equal(t, errDummy, v.promise.keep(errDummy))
default:
t.Fail()
}
c.brk = new(dropBreaker)
actual = c.FindId(nil)
assert.Equal(t, rejectedQuery{}, actual)
}
func TestCollectionInsert(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
col := internal.NewMockMgoCollection(ctrl)
col.EXPECT().Insert(nil, nil).Return(errDummy)
c := decoratedCollection{
collection: col,
brk: breaker.NewBreaker(),
}
err := c.Insert(nil, nil)
assert.Equal(t, errDummy, err)
c.brk = new(dropBreaker)
err = c.Insert(nil, nil)
assert.Equal(t, errDummy, err)
}
func TestCollectionPipe(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
var pipe mgo.Pipe
col := internal.NewMockMgoCollection(ctrl)
col.EXPECT().Pipe(gomock.Any()).Return(&pipe)
c := decoratedCollection{
collection: col,
brk: breaker.NewBreaker(),
}
actual := c.Pipe(nil)
switch v := actual.(type) {
case promisedPipe:
assert.Equal(t, &pipe, v.Pipe)
assert.Equal(t, errDummy, v.promise.keep(errDummy))
default:
t.Fail()
}
c.brk = new(dropBreaker)
actual = c.Pipe(nil)
assert.Equal(t, rejectedPipe{}, actual)
}
func TestCollectionRemove(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
col := internal.NewMockMgoCollection(ctrl)
col.EXPECT().Remove(gomock.Any()).Return(errDummy)
c := decoratedCollection{
collection: col,
brk: breaker.NewBreaker(),
}
err := c.Remove(nil)
assert.Equal(t, errDummy, err)
c.brk = new(dropBreaker)
err = c.Remove(nil)
assert.Equal(t, errDummy, err)
}
func TestCollectionRemoveAll(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
col := internal.NewMockMgoCollection(ctrl)
col.EXPECT().RemoveAll(gomock.Any()).Return(nil, errDummy)
c := decoratedCollection{
collection: col,
brk: breaker.NewBreaker(),
}
_, err := c.RemoveAll(nil)
assert.Equal(t, errDummy, err)
c.brk = new(dropBreaker)
_, err = c.RemoveAll(nil)
assert.Equal(t, errDummy, err)
}
func TestCollectionRemoveId(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
col := internal.NewMockMgoCollection(ctrl)
col.EXPECT().RemoveId(gomock.Any()).Return(errDummy)
c := decoratedCollection{
collection: col,
brk: breaker.NewBreaker(),
}
err := c.RemoveId(nil)
assert.Equal(t, errDummy, err)
c.brk = new(dropBreaker)
err = c.RemoveId(nil)
assert.Equal(t, errDummy, err)
}
func TestCollectionUpdate(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
col := internal.NewMockMgoCollection(ctrl)
col.EXPECT().Update(gomock.Any(), gomock.Any()).Return(errDummy)
c := decoratedCollection{
collection: col,
brk: breaker.NewBreaker(),
}
err := c.Update(nil, nil)
assert.Equal(t, errDummy, err)
c.brk = new(dropBreaker)
err = c.Update(nil, nil)
assert.Equal(t, errDummy, err)
}
func TestCollectionUpdateId(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
col := internal.NewMockMgoCollection(ctrl)
col.EXPECT().UpdateId(gomock.Any(), gomock.Any()).Return(errDummy)
c := decoratedCollection{
collection: col,
brk: breaker.NewBreaker(),
}
err := c.UpdateId(nil, nil)
assert.Equal(t, errDummy, err)
c.brk = new(dropBreaker)
err = c.UpdateId(nil, nil)
assert.Equal(t, errDummy, err)
}
func TestCollectionUpsert(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
col := internal.NewMockMgoCollection(ctrl)
col.EXPECT().Upsert(gomock.Any(), gomock.Any()).Return(nil, errDummy)
c := decoratedCollection{
collection: col,
brk: breaker.NewBreaker(),
}
_, err := c.Upsert(nil, nil)
assert.Equal(t, errDummy, err)
c.brk = new(dropBreaker)
_, err = c.Upsert(nil, nil)
assert.Equal(t, errDummy, err)
}
func Test_logDuration(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
col := internal.NewMockMgoCollection(ctrl)
c := decoratedCollection{
collection: col,
brk: breaker.NewBreaker(),
}
var buf strings.Builder
w := logx.NewWriter(&buf)
o := logx.Reset()
logx.SetWriter(w)
defer func() {
logx.Reset()
logx.SetWriter(o)
}()
buf.Reset()
c.logDuration("foo", time.Millisecond, nil, "bar")
assert.Contains(t, buf.String(), "foo")
assert.Contains(t, buf.String(), "bar")
buf.Reset()
c.logDuration("foo", time.Millisecond, errors.New("bar"), make(chan int))
assert.Contains(t, buf.String(), "bar")
buf.Reset()
c.logDuration("foo", slowThreshold.Load()+time.Millisecond, errors.New("bar"))
assert.Contains(t, buf.String(), "bar")
assert.Contains(t, buf.String(), "slowcall")
buf.Reset()
c.logDuration("foo", slowThreshold.Load()+time.Millisecond, nil)
assert.Contains(t, buf.String(), "foo")
assert.Contains(t, buf.String(), "slowcall")
}
type mockPromise struct {
accepted bool
reason string
}
func (p *mockPromise) Accept() {
p.accepted = true
}
func (p *mockPromise) Reject(reason string) {
p.reason = reason
}
type dropBreaker struct{}
func (d *dropBreaker) Name() string {
return "dummy"
}
func (d *dropBreaker) Allow() (breaker.Promise, error) {
return nil, errDummy
}
func (d *dropBreaker) Do(req func() error) error {
return nil
}
func (d *dropBreaker) DoWithAcceptable(req func() error, acceptable breaker.Acceptable) error {
return errDummy
}
func (d *dropBreaker) DoWithFallback(req func() error, fallback func(err error) error) error {
return nil
}
func (d *dropBreaker) DoWithFallbackAcceptable(req func() error, fallback func(err error) error,
acceptable breaker.Acceptable) error {
return nil
}

View File

@@ -1,19 +0,0 @@
//go:generate mockgen -package internal -destination collection_mock.go -source collection.go
package internal
import "github.com/globalsign/mgo"
// MgoCollection interface represents a mgo collection.
type MgoCollection interface {
Find(query interface{}) *mgo.Query
FindId(id interface{}) *mgo.Query
Insert(docs ...interface{}) error
Pipe(pipeline interface{}) *mgo.Pipe
Remove(selector interface{}) error
RemoveAll(selector interface{}) (*mgo.ChangeInfo, error)
RemoveId(id interface{}) error
Update(selector, update interface{}) error
UpdateId(id, update interface{}) error
Upsert(selector, update interface{}) (*mgo.ChangeInfo, error)
}

View File

@@ -1,181 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: collection.go
// Package internal is a generated GoMock package.
package internal
import (
reflect "reflect"
mgo "github.com/globalsign/mgo"
gomock "github.com/golang/mock/gomock"
)
// MockMgoCollection is a mock of MgoCollection interface
type MockMgoCollection struct {
ctrl *gomock.Controller
recorder *MockMgoCollectionMockRecorder
}
// MockMgoCollectionMockRecorder is the mock recorder for MockMgoCollection
type MockMgoCollectionMockRecorder struct {
mock *MockMgoCollection
}
// NewMockMgoCollection creates a new mock instance
func NewMockMgoCollection(ctrl *gomock.Controller) *MockMgoCollection {
mock := &MockMgoCollection{ctrl: ctrl}
mock.recorder = &MockMgoCollectionMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockMgoCollection) EXPECT() *MockMgoCollectionMockRecorder {
return m.recorder
}
// Find mocks base method
func (m *MockMgoCollection) Find(query interface{}) *mgo.Query {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Find", query)
ret0, _ := ret[0].(*mgo.Query)
return ret0
}
// Find indicates an expected call of Find
func (mr *MockMgoCollectionMockRecorder) Find(query interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Find", reflect.TypeOf((*MockMgoCollection)(nil).Find), query)
}
// FindId mocks base method
func (m *MockMgoCollection) FindId(id interface{}) *mgo.Query {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FindId", id)
ret0, _ := ret[0].(*mgo.Query)
return ret0
}
// FindId indicates an expected call of FindId
func (mr *MockMgoCollectionMockRecorder) FindId(id interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindId", reflect.TypeOf((*MockMgoCollection)(nil).FindId), id)
}
// Insert mocks base method
func (m *MockMgoCollection) Insert(docs ...interface{}) error {
m.ctrl.T.Helper()
varargs := []interface{}{}
for _, a := range docs {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Insert", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// Insert indicates an expected call of Insert
func (mr *MockMgoCollectionMockRecorder) Insert(docs ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockMgoCollection)(nil).Insert), docs...)
}
// Pipe mocks base method
func (m *MockMgoCollection) Pipe(pipeline interface{}) *mgo.Pipe {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Pipe", pipeline)
ret0, _ := ret[0].(*mgo.Pipe)
return ret0
}
// Pipe indicates an expected call of Pipe
func (mr *MockMgoCollectionMockRecorder) Pipe(pipeline interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Pipe", reflect.TypeOf((*MockMgoCollection)(nil).Pipe), pipeline)
}
// Remove mocks base method
func (m *MockMgoCollection) Remove(selector interface{}) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Remove", selector)
ret0, _ := ret[0].(error)
return ret0
}
// Remove indicates an expected call of Remove
func (mr *MockMgoCollectionMockRecorder) Remove(selector interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockMgoCollection)(nil).Remove), selector)
}
// RemoveAll mocks base method
func (m *MockMgoCollection) RemoveAll(selector interface{}) (*mgo.ChangeInfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveAll", selector)
ret0, _ := ret[0].(*mgo.ChangeInfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RemoveAll indicates an expected call of RemoveAll
func (mr *MockMgoCollectionMockRecorder) RemoveAll(selector interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveAll", reflect.TypeOf((*MockMgoCollection)(nil).RemoveAll), selector)
}
// RemoveId mocks base method
func (m *MockMgoCollection) RemoveId(id interface{}) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveId", id)
ret0, _ := ret[0].(error)
return ret0
}
// RemoveId indicates an expected call of RemoveId
func (mr *MockMgoCollectionMockRecorder) RemoveId(id interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveId", reflect.TypeOf((*MockMgoCollection)(nil).RemoveId), id)
}
// Update mocks base method
func (m *MockMgoCollection) Update(selector, update interface{}) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Update", selector, update)
ret0, _ := ret[0].(error)
return ret0
}
// Update indicates an expected call of Update
func (mr *MockMgoCollectionMockRecorder) Update(selector, update interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockMgoCollection)(nil).Update), selector, update)
}
// UpdateId mocks base method
func (m *MockMgoCollection) UpdateId(id, update interface{}) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateId", id, update)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateId indicates an expected call of UpdateId
func (mr *MockMgoCollectionMockRecorder) UpdateId(id, update interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateId", reflect.TypeOf((*MockMgoCollection)(nil).UpdateId), id, update)
}
// Upsert mocks base method
func (m *MockMgoCollection) Upsert(selector, update interface{}) (*mgo.ChangeInfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Upsert", selector, update)
ret0, _ := ret[0].(*mgo.ChangeInfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Upsert indicates an expected call of Upsert
func (mr *MockMgoCollectionMockRecorder) Upsert(selector, update interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upsert", reflect.TypeOf((*MockMgoCollection)(nil).Upsert), selector, update)
}

View File

@@ -1,99 +0,0 @@
//go:generate mockgen -package mongo -destination iter_mock.go -source iter.go Iter
package mongo
import (
"github.com/globalsign/mgo/bson"
"github.com/zeromicro/go-zero/core/breaker"
)
type (
// Iter interface represents a mongo iter.
Iter interface {
All(result interface{}) error
Close() error
Done() bool
Err() error
For(result interface{}, f func() error) error
Next(result interface{}) bool
State() (int64, []bson.Raw)
Timeout() bool
}
// A ClosableIter is a closable mongo iter.
ClosableIter struct {
Iter
Cleanup func()
}
promisedIter struct {
Iter
promise keepablePromise
}
rejectedIter struct{}
)
func (i promisedIter) All(result interface{}) error {
return i.promise.keep(i.Iter.All(result))
}
func (i promisedIter) Close() error {
return i.promise.keep(i.Iter.Close())
}
func (i promisedIter) Err() error {
return i.Iter.Err()
}
func (i promisedIter) For(result interface{}, f func() error) error {
var ferr error
err := i.Iter.For(result, func() error {
ferr = f()
return ferr
})
if ferr == err {
return i.promise.accept(err)
}
return i.promise.keep(err)
}
// Close closes a mongo iter.
func (it *ClosableIter) Close() error {
err := it.Iter.Close()
it.Cleanup()
return err
}
func (i rejectedIter) All(result interface{}) error {
return breaker.ErrServiceUnavailable
}
func (i rejectedIter) Close() error {
return breaker.ErrServiceUnavailable
}
func (i rejectedIter) Done() bool {
return false
}
func (i rejectedIter) Err() error {
return breaker.ErrServiceUnavailable
}
func (i rejectedIter) For(result interface{}, f func() error) error {
return breaker.ErrServiceUnavailable
}
func (i rejectedIter) Next(result interface{}) bool {
return false
}
func (i rejectedIter) State() (int64, []bson.Raw) {
return 0, nil
}
func (i rejectedIter) Timeout() bool {
return false
}

View File

@@ -1,148 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: iter.go
// Package mongo is a generated GoMock package.
package mongo
import (
reflect "reflect"
bson "github.com/globalsign/mgo/bson"
gomock "github.com/golang/mock/gomock"
)
// MockIter is a mock of Iter interface
type MockIter struct {
ctrl *gomock.Controller
recorder *MockIterMockRecorder
}
// MockIterMockRecorder is the mock recorder for MockIter
type MockIterMockRecorder struct {
mock *MockIter
}
// NewMockIter creates a new mock instance
func NewMockIter(ctrl *gomock.Controller) *MockIter {
mock := &MockIter{ctrl: ctrl}
mock.recorder = &MockIterMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockIter) EXPECT() *MockIterMockRecorder {
return m.recorder
}
// All mocks base method
func (m *MockIter) All(result interface{}) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "All", result)
ret0, _ := ret[0].(error)
return ret0
}
// All indicates an expected call of All
func (mr *MockIterMockRecorder) All(result interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "All", reflect.TypeOf((*MockIter)(nil).All), result)
}
// Close mocks base method
func (m *MockIter) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close
func (mr *MockIterMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockIter)(nil).Close))
}
// Done mocks base method
func (m *MockIter) Done() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Done")
ret0, _ := ret[0].(bool)
return ret0
}
// Done indicates an expected call of Done
func (mr *MockIterMockRecorder) Done() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Done", reflect.TypeOf((*MockIter)(nil).Done))
}
// Err mocks base method
func (m *MockIter) Err() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Err")
ret0, _ := ret[0].(error)
return ret0
}
// Err indicates an expected call of Err
func (mr *MockIterMockRecorder) Err() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Err", reflect.TypeOf((*MockIter)(nil).Err))
}
// For mocks base method
func (m *MockIter) For(result interface{}, f func() error) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "For", result, f)
ret0, _ := ret[0].(error)
return ret0
}
// For indicates an expected call of For
func (mr *MockIterMockRecorder) For(result, f interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "For", reflect.TypeOf((*MockIter)(nil).For), result, f)
}
// Next mocks base method
func (m *MockIter) Next(result interface{}) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Next", result)
ret0, _ := ret[0].(bool)
return ret0
}
// Next indicates an expected call of Next
func (mr *MockIterMockRecorder) Next(result interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockIter)(nil).Next), result)
}
// State mocks base method
func (m *MockIter) State() (int64, []bson.Raw) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "State")
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].([]bson.Raw)
return ret0, ret1
}
// State indicates an expected call of State
func (mr *MockIterMockRecorder) State() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "State", reflect.TypeOf((*MockIter)(nil).State))
}
// Timeout mocks base method
func (m *MockIter) Timeout() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Timeout")
ret0, _ := ret[0].(bool)
return ret0
}
// Timeout indicates an expected call of Timeout
func (mr *MockIterMockRecorder) Timeout() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Timeout", reflect.TypeOf((*MockIter)(nil).Timeout))
}

View File

@@ -1,264 +0,0 @@
package mongo
import (
"errors"
"testing"
"github.com/globalsign/mgo"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/stringx"
"github.com/zeromicro/go-zero/core/syncx"
)
func TestClosableIter_Close(t *testing.T) {
errs := []error{
nil,
mgo.ErrNotFound,
}
for _, err := range errs {
t.Run(stringx.RandId(), func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
cleaned := syncx.NewAtomicBool()
iter := NewMockIter(ctrl)
iter.EXPECT().Close().Return(err)
ci := ClosableIter{
Iter: iter,
Cleanup: func() {
cleaned.Set(true)
},
}
assert.Equal(t, err, ci.Close())
assert.True(t, cleaned.True())
})
}
}
func TestPromisedIter_AllAndClose(t *testing.T) {
tests := []struct {
err error
accepted bool
reason string
}{
{
err: nil,
accepted: true,
reason: "",
},
{
err: mgo.ErrNotFound,
accepted: true,
reason: "",
},
{
err: errors.New("any"),
accepted: false,
reason: "any",
},
}
for _, test := range tests {
t.Run(stringx.RandId(), func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
iter := NewMockIter(ctrl)
iter.EXPECT().All(gomock.Any()).Return(test.err)
promise := new(mockPromise)
pi := promisedIter{
Iter: iter,
promise: keepablePromise{
promise: promise,
log: func(error) {},
},
}
assert.Equal(t, test.err, pi.All(nil))
assert.Equal(t, test.accepted, promise.accepted)
assert.Equal(t, test.reason, promise.reason)
})
}
for _, test := range tests {
t.Run(stringx.RandId(), func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
iter := NewMockIter(ctrl)
iter.EXPECT().Close().Return(test.err)
promise := new(mockPromise)
pi := promisedIter{
Iter: iter,
promise: keepablePromise{
promise: promise,
log: func(error) {},
},
}
assert.Equal(t, test.err, pi.Close())
assert.Equal(t, test.accepted, promise.accepted)
assert.Equal(t, test.reason, promise.reason)
})
}
}
func TestPromisedIter_Err(t *testing.T) {
errs := []error{
nil,
mgo.ErrNotFound,
}
for _, err := range errs {
t.Run(stringx.RandId(), func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
iter := NewMockIter(ctrl)
iter.EXPECT().Err().Return(err)
promise := new(mockPromise)
pi := promisedIter{
Iter: iter,
promise: keepablePromise{
promise: promise,
log: func(error) {},
},
}
assert.Equal(t, err, pi.Err())
})
}
}
func TestPromisedIter_For(t *testing.T) {
tests := []struct {
err error
accepted bool
reason string
}{
{
err: nil,
accepted: true,
reason: "",
},
{
err: mgo.ErrNotFound,
accepted: true,
reason: "",
},
{
err: errors.New("any"),
accepted: false,
reason: "any",
},
}
for _, test := range tests {
t.Run(stringx.RandId(), func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
iter := NewMockIter(ctrl)
iter.EXPECT().For(gomock.Any(), gomock.Any()).Return(test.err)
promise := new(mockPromise)
pi := promisedIter{
Iter: iter,
promise: keepablePromise{
promise: promise,
log: func(error) {},
},
}
assert.Equal(t, test.err, pi.For(nil, nil))
assert.Equal(t, test.accepted, promise.accepted)
assert.Equal(t, test.reason, promise.reason)
})
}
}
func TestRejectedIter_All(t *testing.T) {
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedIter).All(nil))
}
func TestRejectedIter_Close(t *testing.T) {
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedIter).Close())
}
func TestRejectedIter_Done(t *testing.T) {
assert.False(t, new(rejectedIter).Done())
}
func TestRejectedIter_Err(t *testing.T) {
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedIter).Err())
}
func TestRejectedIter_For(t *testing.T) {
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedIter).For(nil, nil))
}
func TestRejectedIter_Next(t *testing.T) {
assert.False(t, new(rejectedIter).Next(nil))
}
func TestRejectedIter_State(t *testing.T) {
n, raw := new(rejectedIter).State()
assert.Equal(t, int64(0), n)
assert.Nil(t, raw)
}
func TestRejectedIter_Timeout(t *testing.T) {
assert.False(t, new(rejectedIter).Timeout())
}
func TestIter_Done(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
iter := NewMockIter(ctrl)
iter.EXPECT().Done().Return(true)
ci := ClosableIter{
Iter: iter,
Cleanup: nil,
}
assert.True(t, ci.Done())
}
func TestIter_Next(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
iter := NewMockIter(ctrl)
iter.EXPECT().Next(gomock.Any()).Return(true)
ci := ClosableIter{
Iter: iter,
Cleanup: nil,
}
assert.True(t, ci.Next(nil))
}
func TestIter_State(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
iter := NewMockIter(ctrl)
iter.EXPECT().State().Return(int64(1), nil)
ci := ClosableIter{
Iter: iter,
Cleanup: nil,
}
n, raw := ci.State()
assert.Equal(t, int64(1), n)
assert.Nil(t, raw)
}
func TestIter_Timeout(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
iter := NewMockIter(ctrl)
iter.EXPECT().Timeout().Return(true)
ci := ClosableIter{
Iter: iter,
Cleanup: nil,
}
assert.True(t, ci.Timeout())
}

View File

@@ -1,177 +0,0 @@
package mongo
import (
"log"
"time"
"github.com/globalsign/mgo"
"github.com/zeromicro/go-zero/core/breaker"
)
// A Model is a mongo model.
type Model struct {
session *concurrentSession
db *mgo.Database
collection string
brk breaker.Breaker
opts []Option
}
// MustNewModel returns a Model, exits on errors.
func MustNewModel(url, collection string, opts ...Option) *Model {
model, err := NewModel(url, collection, opts...)
if err != nil {
log.Fatal(err)
}
return model
}
// NewModel returns a Model.
func NewModel(url, collection string, opts ...Option) (*Model, error) {
session, err := getConcurrentSession(url)
if err != nil {
return nil, err
}
return &Model{
session: session,
// If name is empty, the database name provided in the dialed URL is used instead
db: session.DB(""),
collection: collection,
brk: breaker.GetBreaker(url),
opts: opts,
}, nil
}
// Find finds a record with given query.
func (mm *Model) Find(query interface{}) (Query, error) {
return mm.query(func(c Collection) Query {
return c.Find(query)
})
}
// FindId finds a record with given id.
func (mm *Model) FindId(id interface{}) (Query, error) {
return mm.query(func(c Collection) Query {
return c.FindId(id)
})
}
// GetCollection returns a Collection with given session.
func (mm *Model) GetCollection(session *mgo.Session) Collection {
return newCollection(mm.db.C(mm.collection).With(session), mm.brk)
}
// Insert inserts docs into mm.
func (mm *Model) Insert(docs ...interface{}) error {
return mm.execute(func(c Collection) error {
return c.Insert(docs...)
})
}
// Pipe returns a Pipe with given pipeline.
func (mm *Model) Pipe(pipeline interface{}) (Pipe, error) {
return mm.pipe(func(c Collection) Pipe {
return c.Pipe(pipeline)
})
}
// PutSession returns the given session.
func (mm *Model) PutSession(session *mgo.Session) {
mm.session.putSession(session)
}
// Remove removes the records with given selector.
func (mm *Model) Remove(selector interface{}) error {
return mm.execute(func(c Collection) error {
return c.Remove(selector)
})
}
// RemoveAll removes all with given selector and returns a mgo.ChangeInfo.
func (mm *Model) RemoveAll(selector interface{}) (*mgo.ChangeInfo, error) {
return mm.change(func(c Collection) (*mgo.ChangeInfo, error) {
return c.RemoveAll(selector)
})
}
// RemoveId removes a record with given id.
func (mm *Model) RemoveId(id interface{}) error {
return mm.execute(func(c Collection) error {
return c.RemoveId(id)
})
}
// TakeSession gets a session.
func (mm *Model) TakeSession() (*mgo.Session, error) {
return mm.session.takeSession(mm.opts...)
}
// Update updates a record with given selector.
func (mm *Model) Update(selector, update interface{}) error {
return mm.execute(func(c Collection) error {
return c.Update(selector, update)
})
}
// UpdateId updates a record with given id.
func (mm *Model) UpdateId(id, update interface{}) error {
return mm.execute(func(c Collection) error {
return c.UpdateId(id, update)
})
}
// Upsert upserts a record with given selector, and returns a mgo.ChangeInfo.
func (mm *Model) Upsert(selector, update interface{}) (*mgo.ChangeInfo, error) {
return mm.change(func(c Collection) (*mgo.ChangeInfo, error) {
return c.Upsert(selector, update)
})
}
func (mm *Model) change(fn func(c Collection) (*mgo.ChangeInfo, error)) (*mgo.ChangeInfo, error) {
session, err := mm.TakeSession()
if err != nil {
return nil, err
}
defer mm.PutSession(session)
return fn(mm.GetCollection(session))
}
func (mm *Model) execute(fn func(c Collection) error) error {
session, err := mm.TakeSession()
if err != nil {
return err
}
defer mm.PutSession(session)
return fn(mm.GetCollection(session))
}
func (mm *Model) pipe(fn func(c Collection) Pipe) (Pipe, error) {
session, err := mm.TakeSession()
if err != nil {
return nil, err
}
defer mm.PutSession(session)
return fn(mm.GetCollection(session)), nil
}
func (mm *Model) query(fn func(c Collection) Query) (Query, error) {
session, err := mm.TakeSession()
if err != nil {
return nil, err
}
defer mm.PutSession(session)
return fn(mm.GetCollection(session)), nil
}
// WithTimeout customizes an operation with given timeout.
func WithTimeout(timeout time.Duration) Option {
return func(opts *options) {
opts.timeout = timeout
}
}

View File

@@ -1,14 +0,0 @@
package mongo
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestWithTimeout(t *testing.T) {
o := defaultOptions()
WithTimeout(time.Second)(o)
assert.Equal(t, time.Second, o.timeout)
}

View File

@@ -1,29 +0,0 @@
package mongo
import (
"time"
"github.com/zeromicro/go-zero/core/syncx"
)
var slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold)
type (
options struct {
timeout time.Duration
}
// Option defines the method to customize a mongo model.
Option func(opts *options)
)
// SetSlowThreshold sets the slow threshold.
func SetSlowThreshold(threshold time.Duration) {
slowThreshold.Set(threshold)
}
func defaultOptions() *options {
return &options{
timeout: defaultTimeout,
}
}

View File

@@ -1,14 +0,0 @@
package mongo
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestSetSlowThreshold(t *testing.T) {
assert.Equal(t, defaultSlowThreshold, slowThreshold.Load())
SetSlowThreshold(time.Second)
assert.Equal(t, time.Second, slowThreshold.Load())
}

View File

@@ -1,100 +0,0 @@
package mongo
import (
"time"
"github.com/globalsign/mgo"
"github.com/zeromicro/go-zero/core/breaker"
)
type (
// Pipe interface represents a mongo pipe.
Pipe interface {
All(result interface{}) error
AllowDiskUse() Pipe
Batch(n int) Pipe
Collation(collation *mgo.Collation) Pipe
Explain(result interface{}) error
Iter() Iter
One(result interface{}) error
SetMaxTime(d time.Duration) Pipe
}
promisedPipe struct {
*mgo.Pipe
promise keepablePromise
}
rejectedPipe struct{}
)
func (p promisedPipe) All(result interface{}) error {
return p.promise.keep(p.Pipe.All(result))
}
func (p promisedPipe) AllowDiskUse() Pipe {
p.Pipe.AllowDiskUse()
return p
}
func (p promisedPipe) Batch(n int) Pipe {
p.Pipe.Batch(n)
return p
}
func (p promisedPipe) Collation(collation *mgo.Collation) Pipe {
p.Pipe.Collation(collation)
return p
}
func (p promisedPipe) Explain(result interface{}) error {
return p.promise.keep(p.Pipe.Explain(result))
}
func (p promisedPipe) Iter() Iter {
return promisedIter{
Iter: p.Pipe.Iter(),
promise: p.promise,
}
}
func (p promisedPipe) One(result interface{}) error {
return p.promise.keep(p.Pipe.One(result))
}
func (p promisedPipe) SetMaxTime(d time.Duration) Pipe {
p.Pipe.SetMaxTime(d)
return p
}
func (p rejectedPipe) All(result interface{}) error {
return breaker.ErrServiceUnavailable
}
func (p rejectedPipe) AllowDiskUse() Pipe {
return p
}
func (p rejectedPipe) Batch(n int) Pipe {
return p
}
func (p rejectedPipe) Collation(collation *mgo.Collation) Pipe {
return p
}
func (p rejectedPipe) Explain(result interface{}) error {
return breaker.ErrServiceUnavailable
}
func (p rejectedPipe) Iter() Iter {
return rejectedIter{}
}
func (p rejectedPipe) One(result interface{}) error {
return breaker.ErrServiceUnavailable
}
func (p rejectedPipe) SetMaxTime(d time.Duration) Pipe {
return p
}

View File

@@ -1,44 +0,0 @@
package mongo
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/breaker"
)
func TestRejectedPipe_All(t *testing.T) {
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedPipe).All(nil))
}
func TestRejectedPipe_AllowDiskUse(t *testing.T) {
var p rejectedPipe
assert.Equal(t, p, p.AllowDiskUse())
}
func TestRejectedPipe_Batch(t *testing.T) {
var p rejectedPipe
assert.Equal(t, p, p.Batch(1))
}
func TestRejectedPipe_Collation(t *testing.T) {
var p rejectedPipe
assert.Equal(t, p, p.Collation(nil))
}
func TestRejectedPipe_Explain(t *testing.T) {
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedPipe).Explain(nil))
}
func TestRejectedPipe_Iter(t *testing.T) {
assert.EqualValues(t, rejectedIter{}, new(rejectedPipe).Iter())
}
func TestRejectedPipe_One(t *testing.T) {
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedPipe).One(nil))
}
func TestRejectedPipe_SetMaxTime(t *testing.T) {
var p rejectedPipe
assert.Equal(t, p, p.SetMaxTime(0))
}

View File

@@ -1,285 +0,0 @@
package mongo
import (
"time"
"github.com/globalsign/mgo"
"github.com/zeromicro/go-zero/core/breaker"
)
type (
// Query interface represents a mongo query.
Query interface {
All(result interface{}) error
Apply(change mgo.Change, result interface{}) (*mgo.ChangeInfo, error)
Batch(n int) Query
Collation(collation *mgo.Collation) Query
Comment(comment string) Query
Count() (int, error)
Distinct(key string, result interface{}) error
Explain(result interface{}) error
For(result interface{}, f func() error) error
Hint(indexKey ...string) Query
Iter() Iter
Limit(n int) Query
LogReplay() Query
MapReduce(job *mgo.MapReduce, result interface{}) (*mgo.MapReduceInfo, error)
One(result interface{}) error
Prefetch(p float64) Query
Select(selector interface{}) Query
SetMaxScan(n int) Query
SetMaxTime(d time.Duration) Query
Skip(n int) Query
Snapshot() Query
Sort(fields ...string) Query
Tail(timeout time.Duration) Iter
}
promisedQuery struct {
*mgo.Query
promise keepablePromise
}
rejectedQuery struct{}
)
func (q promisedQuery) All(result interface{}) error {
return q.promise.keep(q.Query.All(result))
}
func (q promisedQuery) Apply(change mgo.Change, result interface{}) (*mgo.ChangeInfo, error) {
info, err := q.Query.Apply(change, result)
return info, q.promise.keep(err)
}
func (q promisedQuery) Batch(n int) Query {
return promisedQuery{
Query: q.Query.Batch(n),
promise: q.promise,
}
}
func (q promisedQuery) Collation(collation *mgo.Collation) Query {
return promisedQuery{
Query: q.Query.Collation(collation),
promise: q.promise,
}
}
func (q promisedQuery) Comment(comment string) Query {
return promisedQuery{
Query: q.Query.Comment(comment),
promise: q.promise,
}
}
func (q promisedQuery) Count() (int, error) {
v, err := q.Query.Count()
return v, q.promise.keep(err)
}
func (q promisedQuery) Distinct(key string, result interface{}) error {
return q.promise.keep(q.Query.Distinct(key, result))
}
func (q promisedQuery) Explain(result interface{}) error {
return q.promise.keep(q.Query.Explain(result))
}
func (q promisedQuery) For(result interface{}, f func() error) error {
var ferr error
err := q.Query.For(result, func() error {
ferr = f()
return ferr
})
if ferr == err {
return q.promise.accept(err)
}
return q.promise.keep(err)
}
func (q promisedQuery) Hint(indexKey ...string) Query {
return promisedQuery{
Query: q.Query.Hint(indexKey...),
promise: q.promise,
}
}
func (q promisedQuery) Iter() Iter {
return promisedIter{
Iter: q.Query.Iter(),
promise: q.promise,
}
}
func (q promisedQuery) Limit(n int) Query {
return promisedQuery{
Query: q.Query.Limit(n),
promise: q.promise,
}
}
func (q promisedQuery) LogReplay() Query {
return promisedQuery{
Query: q.Query.LogReplay(),
promise: q.promise,
}
}
func (q promisedQuery) MapReduce(job *mgo.MapReduce, result interface{}) (*mgo.MapReduceInfo, error) {
info, err := q.Query.MapReduce(job, result)
return info, q.promise.keep(err)
}
func (q promisedQuery) One(result interface{}) error {
return q.promise.keep(q.Query.One(result))
}
func (q promisedQuery) Prefetch(p float64) Query {
return promisedQuery{
Query: q.Query.Prefetch(p),
promise: q.promise,
}
}
func (q promisedQuery) Select(selector interface{}) Query {
return promisedQuery{
Query: q.Query.Select(selector),
promise: q.promise,
}
}
func (q promisedQuery) SetMaxScan(n int) Query {
return promisedQuery{
Query: q.Query.SetMaxScan(n),
promise: q.promise,
}
}
func (q promisedQuery) SetMaxTime(d time.Duration) Query {
return promisedQuery{
Query: q.Query.SetMaxTime(d),
promise: q.promise,
}
}
func (q promisedQuery) Skip(n int) Query {
return promisedQuery{
Query: q.Query.Skip(n),
promise: q.promise,
}
}
func (q promisedQuery) Snapshot() Query {
return promisedQuery{
Query: q.Query.Snapshot(),
promise: q.promise,
}
}
func (q promisedQuery) Sort(fields ...string) Query {
return promisedQuery{
Query: q.Query.Sort(fields...),
promise: q.promise,
}
}
func (q promisedQuery) Tail(timeout time.Duration) Iter {
return promisedIter{
Iter: q.Query.Tail(timeout),
promise: q.promise,
}
}
func (q rejectedQuery) All(result interface{}) error {
return breaker.ErrServiceUnavailable
}
func (q rejectedQuery) Apply(change mgo.Change, result interface{}) (*mgo.ChangeInfo, error) {
return nil, breaker.ErrServiceUnavailable
}
func (q rejectedQuery) Batch(n int) Query {
return q
}
func (q rejectedQuery) Collation(collation *mgo.Collation) Query {
return q
}
func (q rejectedQuery) Comment(comment string) Query {
return q
}
func (q rejectedQuery) Count() (int, error) {
return 0, breaker.ErrServiceUnavailable
}
func (q rejectedQuery) Distinct(key string, result interface{}) error {
return breaker.ErrServiceUnavailable
}
func (q rejectedQuery) Explain(result interface{}) error {
return breaker.ErrServiceUnavailable
}
func (q rejectedQuery) For(result interface{}, f func() error) error {
return breaker.ErrServiceUnavailable
}
func (q rejectedQuery) Hint(indexKey ...string) Query {
return q
}
func (q rejectedQuery) Iter() Iter {
return rejectedIter{}
}
func (q rejectedQuery) Limit(n int) Query {
return q
}
func (q rejectedQuery) LogReplay() Query {
return q
}
func (q rejectedQuery) MapReduce(job *mgo.MapReduce, result interface{}) (*mgo.MapReduceInfo, error) {
return nil, breaker.ErrServiceUnavailable
}
func (q rejectedQuery) One(result interface{}) error {
return breaker.ErrServiceUnavailable
}
func (q rejectedQuery) Prefetch(p float64) Query {
return q
}
func (q rejectedQuery) Select(selector interface{}) Query {
return q
}
func (q rejectedQuery) SetMaxScan(n int) Query {
return q
}
func (q rejectedQuery) SetMaxTime(d time.Duration) Query {
return q
}
func (q rejectedQuery) Skip(n int) Query {
return q
}
func (q rejectedQuery) Snapshot() Query {
return q
}
func (q rejectedQuery) Sort(fields ...string) Query {
return q
}
func (q rejectedQuery) Tail(timeout time.Duration) Iter {
return rejectedIter{}
}

View File

@@ -1,120 +0,0 @@
package mongo
import (
"testing"
"github.com/globalsign/mgo"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/breaker"
)
func Test_rejectedQuery_All(t *testing.T) {
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).All(nil))
}
func Test_rejectedQuery_Apply(t *testing.T) {
info, err := new(rejectedQuery).Apply(mgo.Change{}, nil)
assert.Equal(t, breaker.ErrServiceUnavailable, err)
assert.Nil(t, info)
}
func Test_rejectedQuery_Batch(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.Batch(1))
}
func Test_rejectedQuery_Collation(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.Collation(nil))
}
func Test_rejectedQuery_Comment(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.Comment(""))
}
func Test_rejectedQuery_Count(t *testing.T) {
n, err := new(rejectedQuery).Count()
assert.Equal(t, breaker.ErrServiceUnavailable, err)
assert.Equal(t, 0, n)
}
func Test_rejectedQuery_Distinct(t *testing.T) {
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).Distinct("", nil))
}
func Test_rejectedQuery_Explain(t *testing.T) {
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).Explain(nil))
}
func Test_rejectedQuery_For(t *testing.T) {
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).For(nil, nil))
}
func Test_rejectedQuery_Hint(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.Hint())
}
func Test_rejectedQuery_Iter(t *testing.T) {
assert.EqualValues(t, rejectedIter{}, new(rejectedQuery).Iter())
}
func Test_rejectedQuery_Limit(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.Limit(1))
}
func Test_rejectedQuery_LogReplay(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.LogReplay())
}
func Test_rejectedQuery_MapReduce(t *testing.T) {
info, err := new(rejectedQuery).MapReduce(nil, nil)
assert.Equal(t, breaker.ErrServiceUnavailable, err)
assert.Nil(t, info)
}
func Test_rejectedQuery_One(t *testing.T) {
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).One(nil))
}
func Test_rejectedQuery_Prefetch(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.Prefetch(1))
}
func Test_rejectedQuery_Select(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.Select(nil))
}
func Test_rejectedQuery_SetMaxScan(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.SetMaxScan(0))
}
func Test_rejectedQuery_SetMaxTime(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.SetMaxTime(0))
}
func Test_rejectedQuery_Skip(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.Skip(0))
}
func Test_rejectedQuery_Snapshot(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.Snapshot())
}
func Test_rejectedQuery_Sort(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.Sort())
}
func Test_rejectedQuery_Tail(t *testing.T) {
assert.EqualValues(t, rejectedIter{}, new(rejectedQuery).Tail(0))
}

View File

@@ -1,70 +0,0 @@
package mongo
import (
"io"
"time"
"github.com/globalsign/mgo"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/syncx"
)
const (
defaultConcurrency = 50
defaultTimeout = time.Second
)
var sessionManager = syncx.NewResourceManager()
type concurrentSession struct {
*mgo.Session
limit syncx.TimeoutLimit
}
func (cs *concurrentSession) Close() error {
cs.Session.Close()
return nil
}
func getConcurrentSession(url string) (*concurrentSession, error) {
val, err := sessionManager.GetResource(url, func() (io.Closer, error) {
mgoSession, err := mgo.Dial(url)
if err != nil {
return nil, err
}
concurrentSess := &concurrentSession{
Session: mgoSession,
limit: syncx.NewTimeoutLimit(defaultConcurrency),
}
return concurrentSess, nil
})
if err != nil {
return nil, err
}
return val.(*concurrentSession), nil
}
func (cs *concurrentSession) putSession(session *mgo.Session) {
if err := cs.limit.Return(); err != nil {
logx.Error(err)
}
// anyway, we need to close the session
session.Close()
}
func (cs *concurrentSession) takeSession(opts ...Option) (*mgo.Session, error) {
o := defaultOptions()
for _, opt := range opts {
opt(o)
}
if err := cs.limit.Borrow(o.timeout); err != nil {
return nil, err
}
return cs.Copy(), nil
}

View File

@@ -1,10 +0,0 @@
package mongo
import "strings"
const mongoAddrSep = ","
// FormatAddr formats mongo hosts to a string.
func FormatAddr(hosts []string) string {
return strings.Join(hosts, mongoAddrSep)
}

View File

@@ -1,35 +0,0 @@
package mongo
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestFormatAddrs(t *testing.T) {
tests := []struct {
addrs []string
expect string
}{
{
addrs: []string{"a", "b"},
expect: "a,b",
},
{
addrs: []string{"a", "b", "c"},
expect: "a,b,c",
},
{
addrs: []string{},
expect: "",
},
{
addrs: nil,
expect: "",
},
}
for _, test := range tests {
assert.Equal(t, test.expect, FormatAddr(test.addrs))
}
}

View File

@@ -1,199 +0,0 @@
package mongoc
import (
"github.com/globalsign/mgo"
"github.com/zeromicro/go-zero/core/stores/cache"
"github.com/zeromicro/go-zero/core/stores/mongo"
"github.com/zeromicro/go-zero/core/syncx"
)
var (
// ErrNotFound is an alias of mgo.ErrNotFound.
ErrNotFound = mgo.ErrNotFound
// can't use one SingleFlight per conn, because multiple conns may share the same cache key.
singleFlight = syncx.NewSingleFlight()
stats = cache.NewStat("mongoc")
)
type (
// QueryOption defines the method to customize a mongo query.
QueryOption func(query mongo.Query) mongo.Query
// CachedCollection interface represents a mongo collection with cache.
CachedCollection interface {
Count(query interface{}) (int, error)
DelCache(keys ...string) error
FindAllNoCache(v, query interface{}, opts ...QueryOption) error
FindOne(v interface{}, key string, query interface{}) error
FindOneNoCache(v, query interface{}) error
FindOneId(v interface{}, key string, id interface{}) error
FindOneIdNoCache(v, id interface{}) error
GetCache(key string, v interface{}) error
Insert(docs ...interface{}) error
Pipe(pipeline interface{}) mongo.Pipe
Remove(selector interface{}, keys ...string) error
RemoveNoCache(selector interface{}) error
RemoveAll(selector interface{}, keys ...string) (*mgo.ChangeInfo, error)
RemoveAllNoCache(selector interface{}) (*mgo.ChangeInfo, error)
RemoveId(id interface{}, keys ...string) error
RemoveIdNoCache(id interface{}) error
SetCache(key string, v interface{}) error
Update(selector, update interface{}, keys ...string) error
UpdateNoCache(selector, update interface{}) error
UpdateId(id, update interface{}, keys ...string) error
UpdateIdNoCache(id, update interface{}) error
Upsert(selector, update interface{}, keys ...string) (*mgo.ChangeInfo, error)
UpsertNoCache(selector, update interface{}) (*mgo.ChangeInfo, error)
}
cachedCollection struct {
collection mongo.Collection
cache cache.Cache
}
)
func newCollection(collection mongo.Collection, c cache.Cache) CachedCollection {
return &cachedCollection{
collection: collection,
cache: c,
}
}
func (c *cachedCollection) Count(query interface{}) (int, error) {
return c.collection.Find(query).Count()
}
func (c *cachedCollection) DelCache(keys ...string) error {
return c.cache.Del(keys...)
}
func (c *cachedCollection) FindAllNoCache(v, query interface{}, opts ...QueryOption) error {
q := c.collection.Find(query)
for _, opt := range opts {
q = opt(q)
}
return q.All(v)
}
func (c *cachedCollection) FindOne(v interface{}, key string, query interface{}) error {
return c.cache.Take(v, key, func(v interface{}) error {
q := c.collection.Find(query)
return q.One(v)
})
}
func (c *cachedCollection) FindOneNoCache(v, query interface{}) error {
q := c.collection.Find(query)
return q.One(v)
}
func (c *cachedCollection) FindOneId(v interface{}, key string, id interface{}) error {
return c.cache.Take(v, key, func(v interface{}) error {
q := c.collection.FindId(id)
return q.One(v)
})
}
func (c *cachedCollection) FindOneIdNoCache(v, id interface{}) error {
q := c.collection.FindId(id)
return q.One(v)
}
func (c *cachedCollection) GetCache(key string, v interface{}) error {
return c.cache.Get(key, v)
}
func (c *cachedCollection) Insert(docs ...interface{}) error {
return c.collection.Insert(docs...)
}
func (c *cachedCollection) Pipe(pipeline interface{}) mongo.Pipe {
return c.collection.Pipe(pipeline)
}
func (c *cachedCollection) Remove(selector interface{}, keys ...string) error {
if err := c.RemoveNoCache(selector); err != nil {
return err
}
return c.DelCache(keys...)
}
func (c *cachedCollection) RemoveNoCache(selector interface{}) error {
return c.collection.Remove(selector)
}
func (c *cachedCollection) RemoveAll(selector interface{}, keys ...string) (*mgo.ChangeInfo, error) {
info, err := c.RemoveAllNoCache(selector)
if err != nil {
return nil, err
}
if err := c.DelCache(keys...); err != nil {
return nil, err
}
return info, nil
}
func (c *cachedCollection) RemoveAllNoCache(selector interface{}) (*mgo.ChangeInfo, error) {
return c.collection.RemoveAll(selector)
}
func (c *cachedCollection) RemoveId(id interface{}, keys ...string) error {
if err := c.RemoveIdNoCache(id); err != nil {
return err
}
return c.DelCache(keys...)
}
func (c *cachedCollection) RemoveIdNoCache(id interface{}) error {
return c.collection.RemoveId(id)
}
func (c *cachedCollection) SetCache(key string, v interface{}) error {
return c.cache.Set(key, v)
}
func (c *cachedCollection) Update(selector, update interface{}, keys ...string) error {
if err := c.UpdateNoCache(selector, update); err != nil {
return err
}
return c.DelCache(keys...)
}
func (c *cachedCollection) UpdateNoCache(selector, update interface{}) error {
return c.collection.Update(selector, update)
}
func (c *cachedCollection) UpdateId(id, update interface{}, keys ...string) error {
if err := c.UpdateIdNoCache(id, update); err != nil {
return err
}
return c.DelCache(keys...)
}
func (c *cachedCollection) UpdateIdNoCache(id, update interface{}) error {
return c.collection.UpdateId(id, update)
}
func (c *cachedCollection) Upsert(selector, update interface{}, keys ...string) (*mgo.ChangeInfo, error) {
info, err := c.UpsertNoCache(selector, update)
if err != nil {
return nil, err
}
if err := c.DelCache(keys...); err != nil {
return nil, err
}
return info, nil
}
func (c *cachedCollection) UpsertNoCache(selector, update interface{}) (*mgo.ChangeInfo, error) {
return c.collection.Upsert(selector, update)
}

View File

@@ -1,365 +0,0 @@
package mongoc
import (
"encoding/json"
"errors"
"io"
"log"
"os"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/globalsign/mgo"
"github.com/globalsign/mgo/bson"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stat"
"github.com/zeromicro/go-zero/core/stores/cache"
"github.com/zeromicro/go-zero/core/stores/mongo"
"github.com/zeromicro/go-zero/core/stores/redis"
"github.com/zeromicro/go-zero/core/stores/redis/redistest"
)
const dummyCount = 10
func init() {
stat.SetReporter(nil)
}
func TestCollection_Count(t *testing.T) {
resetStats()
r, clean, err := redistest.CreateRedis()
assert.Nil(t, err)
defer clean()
cach := cache.NewNode(r, singleFlight, stats, mgo.ErrNotFound)
c := newCollection(dummyConn{}, cach)
val, err := c.Count("any")
assert.Nil(t, err)
assert.Equal(t, dummyCount, val)
var value string
assert.Nil(t, r.Set("any", `"foo"`))
assert.Nil(t, c.GetCache("any", &value))
assert.Equal(t, "foo", value)
assert.Nil(t, c.DelCache("any"))
assert.Nil(t, c.SetCache("any", "bar"))
assert.Nil(t, c.FindAllNoCache(&value, "any", func(query mongo.Query) mongo.Query {
return query
}))
assert.Nil(t, c.FindOne(&value, "any", "foo"))
assert.Equal(t, "bar", value)
assert.Nil(t, c.DelCache("any"))
c = newCollection(dummyConn{val: `"bar"`}, cach)
assert.Nil(t, c.FindOne(&value, "any", "foo"))
assert.Equal(t, "bar", value)
assert.Nil(t, c.FindOneNoCache(&value, "foo"))
assert.Equal(t, "bar", value)
assert.Nil(t, c.FindOneId(&value, "anyone", "foo"))
assert.Equal(t, "bar", value)
assert.Nil(t, c.FindOneIdNoCache(&value, "foo"))
assert.Equal(t, "bar", value)
assert.Nil(t, c.Insert("foo"))
assert.Nil(t, c.Pipe("foo"))
assert.Nil(t, c.Remove("any"))
assert.Nil(t, c.RemoveId("any"))
_, err = c.RemoveAll("any")
assert.Nil(t, err)
assert.Nil(t, c.Update("foo", "bar"))
assert.Nil(t, c.UpdateId("foo", "bar"))
_, err = c.Upsert("foo", "bar")
assert.Nil(t, err)
c = newCollection(dummyConn{
val: `"bar"`,
removeErr: errors.New("any"),
}, cach)
assert.NotNil(t, c.Remove("any"))
_, err = c.RemoveAll("any", "bar")
assert.NotNil(t, err)
assert.NotNil(t, c.RemoveId("any"))
c = newCollection(dummyConn{
val: `"bar"`,
updateErr: errors.New("any"),
}, cach)
assert.NotNil(t, c.Update("foo", "bar"))
assert.NotNil(t, c.UpdateId("foo", "bar"))
_, err = c.Upsert("foo", "bar")
assert.NotNil(t, err)
}
func TestStat(t *testing.T) {
resetStats()
r, clean, err := redistest.CreateRedis()
assert.Nil(t, err)
defer clean()
cach := cache.NewNode(r, singleFlight, stats, mgo.ErrNotFound)
c := newCollection(dummyConn{}, cach).(*cachedCollection)
for i := 0; i < 10; i++ {
var str string
if err = c.cache.Take(&str, "name", func(v interface{}) error {
*v.(*string) = "zero"
return nil
}); err != nil {
t.Error(err)
}
}
assert.Equal(t, uint64(10), atomic.LoadUint64(&stats.Total))
assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit))
}
func TestStatCacheFails(t *testing.T) {
resetStats()
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stdout)
r := redis.New("localhost:59999")
cach := cache.NewNode(r, singleFlight, stats, mgo.ErrNotFound)
c := newCollection(dummyConn{}, cach)
for i := 0; i < 20; i++ {
var str string
err := c.FindOne(&str, "name", bson.M{})
assert.NotNil(t, err)
}
assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Total))
assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.Hit))
assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Miss))
assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.DbFails))
}
func TestStatDbFails(t *testing.T) {
resetStats()
r, clean, err := redistest.CreateRedis()
assert.Nil(t, err)
defer clean()
cach := cache.NewNode(r, singleFlight, stats, mgo.ErrNotFound)
c := newCollection(dummyConn{}, cach).(*cachedCollection)
for i := 0; i < 20; i++ {
var str string
err := c.cache.Take(&str, "name", func(v interface{}) error {
return errors.New("db failed")
})
assert.NotNil(t, err)
}
assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Total))
assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.Hit))
assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.DbFails))
}
func TestStatFromMemory(t *testing.T) {
resetStats()
r, clean, err := redistest.CreateRedis()
assert.Nil(t, err)
defer clean()
cach := cache.NewNode(r, singleFlight, stats, mgo.ErrNotFound)
c := newCollection(dummyConn{}, cach).(*cachedCollection)
var all sync.WaitGroup
var wait sync.WaitGroup
all.Add(10)
wait.Add(4)
go func() {
var str string
if err := c.cache.Take(&str, "name", func(v interface{}) error {
*v.(*string) = "zero"
return nil
}); err != nil {
t.Error(err)
}
wait.Wait()
runtime.Gosched()
all.Done()
}()
for i := 0; i < 4; i++ {
go func() {
var str string
wait.Done()
if err := c.cache.Take(&str, "name", func(v interface{}) error {
*v.(*string) = "zero"
return nil
}); err != nil {
t.Error(err)
}
all.Done()
}()
}
for i := 0; i < 5; i++ {
go func() {
var str string
if err := c.cache.Take(&str, "name", func(v interface{}) error {
*v.(*string) = "zero"
return nil
}); err != nil {
t.Error(err)
}
all.Done()
}()
}
all.Wait()
assert.Equal(t, uint64(10), atomic.LoadUint64(&stats.Total))
assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit))
}
func resetStats() {
atomic.StoreUint64(&stats.Total, 0)
atomic.StoreUint64(&stats.Hit, 0)
atomic.StoreUint64(&stats.Miss, 0)
atomic.StoreUint64(&stats.DbFails, 0)
}
type dummyConn struct {
val string
removeErr error
updateErr error
}
func (c dummyConn) Find(query interface{}) mongo.Query {
return dummyQuery{val: c.val}
}
func (c dummyConn) FindId(id interface{}) mongo.Query {
return dummyQuery{val: c.val}
}
func (c dummyConn) Insert(docs ...interface{}) error {
return nil
}
func (c dummyConn) Remove(selector interface{}) error {
return c.removeErr
}
func (dummyConn) Pipe(pipeline interface{}) mongo.Pipe {
return nil
}
func (c dummyConn) RemoveAll(selector interface{}) (*mgo.ChangeInfo, error) {
return nil, c.removeErr
}
func (c dummyConn) RemoveId(id interface{}) error {
return c.removeErr
}
func (c dummyConn) Update(selector, update interface{}) error {
return c.updateErr
}
func (c dummyConn) UpdateId(id, update interface{}) error {
return c.updateErr
}
func (c dummyConn) Upsert(selector, update interface{}) (*mgo.ChangeInfo, error) {
return nil, c.updateErr
}
type dummyQuery struct {
val string
}
func (d dummyQuery) All(result interface{}) error {
return nil
}
func (d dummyQuery) Apply(change mgo.Change, result interface{}) (*mgo.ChangeInfo, error) {
return nil, nil
}
func (d dummyQuery) Count() (int, error) {
return dummyCount, nil
}
func (d dummyQuery) Distinct(key string, result interface{}) error {
return nil
}
func (d dummyQuery) Explain(result interface{}) error {
return nil
}
func (d dummyQuery) For(result interface{}, f func() error) error {
return nil
}
func (d dummyQuery) MapReduce(job *mgo.MapReduce, result interface{}) (*mgo.MapReduceInfo, error) {
return nil, nil
}
func (d dummyQuery) One(result interface{}) error {
return json.Unmarshal([]byte(d.val), result)
}
func (d dummyQuery) Batch(n int) mongo.Query {
return d
}
func (d dummyQuery) Collation(collation *mgo.Collation) mongo.Query {
return d
}
func (d dummyQuery) Comment(comment string) mongo.Query {
return d
}
func (d dummyQuery) Hint(indexKey ...string) mongo.Query {
return d
}
func (d dummyQuery) Iter() mongo.Iter {
return &mgo.Iter{}
}
func (d dummyQuery) Limit(n int) mongo.Query {
return d
}
func (d dummyQuery) LogReplay() mongo.Query {
return d
}
func (d dummyQuery) Prefetch(p float64) mongo.Query {
return d
}
func (d dummyQuery) Select(selector interface{}) mongo.Query {
return d
}
func (d dummyQuery) SetMaxScan(n int) mongo.Query {
return d
}
func (d dummyQuery) SetMaxTime(duration time.Duration) mongo.Query {
return d
}
func (d dummyQuery) Skip(n int) mongo.Query {
return d
}
func (d dummyQuery) Snapshot() mongo.Query {
return d
}
func (d dummyQuery) Sort(fields ...string) mongo.Query {
return d
}
func (d dummyQuery) Tail(timeout time.Duration) mongo.Iter {
return &mgo.Iter{}
}

View File

@@ -1,273 +0,0 @@
package mongoc
import (
"log"
"github.com/globalsign/mgo"
"github.com/zeromicro/go-zero/core/stores/cache"
"github.com/zeromicro/go-zero/core/stores/mongo"
"github.com/zeromicro/go-zero/core/stores/redis"
)
// A Model is a mongo model that built with cache capability.
type Model struct {
*mongo.Model
cache cache.Cache
generateCollection func(*mgo.Session) CachedCollection
}
// MustNewNodeModel returns a Model with a cache node, exists on errors.
func MustNewNodeModel(url, collection string, rds *redis.Redis, opts ...cache.Option) *Model {
model, err := NewNodeModel(url, collection, rds, opts...)
if err != nil {
log.Fatal(err)
}
return model
}
// MustNewModel returns a Model with a cache cluster, exists on errors.
func MustNewModel(url, collection string, c cache.CacheConf, opts ...cache.Option) *Model {
model, err := NewModel(url, collection, c, opts...)
if err != nil {
log.Fatal(err)
}
return model
}
// NewModel returns a Model with a cache cluster.
func NewModel(url, collection string, conf cache.CacheConf, opts ...cache.Option) (*Model, error) {
c := cache.New(conf, singleFlight, stats, mgo.ErrNotFound, opts...)
return NewModelWithCache(url, collection, c)
}
// NewModelWithCache returns a Model with a custom cache.
func NewModelWithCache(url, collection string, c cache.Cache) (*Model, error) {
return createModel(url, collection, c, func(collection mongo.Collection) CachedCollection {
return newCollection(collection, c)
})
}
// NewNodeModel returns a Model with a cache node.
func NewNodeModel(url, collection string, rds *redis.Redis, opts ...cache.Option) (*Model, error) {
c := cache.NewNode(rds, singleFlight, stats, mgo.ErrNotFound, opts...)
return NewModelWithCache(url, collection, c)
}
// Count returns the count of given query.
func (mm *Model) Count(query interface{}) (int, error) {
return mm.executeInt(func(c CachedCollection) (int, error) {
return c.Count(query)
})
}
// DelCache deletes the cache with given keys.
func (mm *Model) DelCache(keys ...string) error {
return mm.cache.Del(keys...)
}
// GetCache unmarshal the cache into v with given key.
func (mm *Model) GetCache(key string, v interface{}) error {
return mm.cache.Get(key, v)
}
// GetCollection returns a cache collection.
func (mm *Model) GetCollection(session *mgo.Session) CachedCollection {
return mm.generateCollection(session)
}
// FindAllNoCache finds all records without cache.
func (mm *Model) FindAllNoCache(v, query interface{}, opts ...QueryOption) error {
return mm.execute(func(c CachedCollection) error {
return c.FindAllNoCache(v, query, opts...)
})
}
// FindOne unmarshals a record into v with given key and query.
func (mm *Model) FindOne(v interface{}, key string, query interface{}) error {
return mm.execute(func(c CachedCollection) error {
return c.FindOne(v, key, query)
})
}
// FindOneNoCache unmarshals a record into v with query, without cache.
func (mm *Model) FindOneNoCache(v, query interface{}) error {
return mm.execute(func(c CachedCollection) error {
return c.FindOneNoCache(v, query)
})
}
// FindOneId unmarshals a record into v with query.
func (mm *Model) FindOneId(v interface{}, key string, id interface{}) error {
return mm.execute(func(c CachedCollection) error {
return c.FindOneId(v, key, id)
})
}
// FindOneIdNoCache unmarshals a record into v with query, without cache.
func (mm *Model) FindOneIdNoCache(v, id interface{}) error {
return mm.execute(func(c CachedCollection) error {
return c.FindOneIdNoCache(v, id)
})
}
// Insert inserts docs.
func (mm *Model) Insert(docs ...interface{}) error {
return mm.execute(func(c CachedCollection) error {
return c.Insert(docs...)
})
}
// Pipe returns a mongo pipe with given pipeline.
func (mm *Model) Pipe(pipeline interface{}) (mongo.Pipe, error) {
return mm.pipe(func(c CachedCollection) mongo.Pipe {
return c.Pipe(pipeline)
})
}
// Remove removes a record with given selector, and remove it from cache with given keys.
func (mm *Model) Remove(selector interface{}, keys ...string) error {
return mm.execute(func(c CachedCollection) error {
return c.Remove(selector, keys...)
})
}
// RemoveNoCache removes a record with given selector.
func (mm *Model) RemoveNoCache(selector interface{}) error {
return mm.execute(func(c CachedCollection) error {
return c.RemoveNoCache(selector)
})
}
// RemoveAll removes all records with given selector, and removes cache with given keys.
func (mm *Model) RemoveAll(selector interface{}, keys ...string) (*mgo.ChangeInfo, error) {
return mm.change(func(c CachedCollection) (*mgo.ChangeInfo, error) {
return c.RemoveAll(selector, keys...)
})
}
// RemoveAllNoCache removes all records with given selector, and returns a mgo.ChangeInfo.
func (mm *Model) RemoveAllNoCache(selector interface{}) (*mgo.ChangeInfo, error) {
return mm.change(func(c CachedCollection) (*mgo.ChangeInfo, error) {
return c.RemoveAllNoCache(selector)
})
}
// RemoveId removes a record with given id, and removes cache with given keys.
func (mm *Model) RemoveId(id interface{}, keys ...string) error {
return mm.execute(func(c CachedCollection) error {
return c.RemoveId(id, keys...)
})
}
// RemoveIdNoCache removes a record with given id.
func (mm *Model) RemoveIdNoCache(id interface{}) error {
return mm.execute(func(c CachedCollection) error {
return c.RemoveIdNoCache(id)
})
}
// SetCache sets the cache with given key and value.
func (mm *Model) SetCache(key string, v interface{}) error {
return mm.cache.Set(key, v)
}
// Update updates the record with given selector, and delete cache with given keys.
func (mm *Model) Update(selector, update interface{}, keys ...string) error {
return mm.execute(func(c CachedCollection) error {
return c.Update(selector, update, keys...)
})
}
// UpdateNoCache updates the record with given selector.
func (mm *Model) UpdateNoCache(selector, update interface{}) error {
return mm.execute(func(c CachedCollection) error {
return c.UpdateNoCache(selector, update)
})
}
// UpdateId updates the record with given id, and delete cache with given keys.
func (mm *Model) UpdateId(id, update interface{}, keys ...string) error {
return mm.execute(func(c CachedCollection) error {
return c.UpdateId(id, update, keys...)
})
}
// UpdateIdNoCache updates the record with given id.
func (mm *Model) UpdateIdNoCache(id, update interface{}) error {
return mm.execute(func(c CachedCollection) error {
return c.UpdateIdNoCache(id, update)
})
}
// Upsert upserts a record with given selector, and delete cache with given keys.
func (mm *Model) Upsert(selector, update interface{}, keys ...string) (*mgo.ChangeInfo, error) {
return mm.change(func(c CachedCollection) (*mgo.ChangeInfo, error) {
return c.Upsert(selector, update, keys...)
})
}
// UpsertNoCache upserts a record with given selector.
func (mm *Model) UpsertNoCache(selector, update interface{}) (*mgo.ChangeInfo, error) {
return mm.change(func(c CachedCollection) (*mgo.ChangeInfo, error) {
return c.UpsertNoCache(selector, update)
})
}
func (mm *Model) change(fn func(c CachedCollection) (*mgo.ChangeInfo, error)) (*mgo.ChangeInfo, error) {
session, err := mm.TakeSession()
if err != nil {
return nil, err
}
defer mm.PutSession(session)
return fn(mm.GetCollection(session))
}
func (mm *Model) execute(fn func(c CachedCollection) error) error {
session, err := mm.TakeSession()
if err != nil {
return err
}
defer mm.PutSession(session)
return fn(mm.GetCollection(session))
}
func (mm *Model) executeInt(fn func(c CachedCollection) (int, error)) (int, error) {
session, err := mm.TakeSession()
if err != nil {
return 0, err
}
defer mm.PutSession(session)
return fn(mm.GetCollection(session))
}
func (mm *Model) pipe(fn func(c CachedCollection) mongo.Pipe) (mongo.Pipe, error) {
session, err := mm.TakeSession()
if err != nil {
return nil, err
}
defer mm.PutSession(session)
return fn(mm.GetCollection(session)), nil
}
func createModel(url, collection string, c cache.Cache,
create func(mongo.Collection) CachedCollection) (*Model, error) {
model, err := mongo.NewModel(url, collection)
if err != nil {
return nil, err
}
return &Model{
Model: model,
cache: c,
generateCollection: func(session *mgo.Session) CachedCollection {
collection := model.GetCollection(session)
return create(collection)
},
}, nil
}

View File

@@ -40,6 +40,16 @@ func TestRedisConf(t *testing.T) {
},
ok: true,
},
{
name: "ok",
RedisConf: RedisConf{
Host: "localhost:6379",
Type: ClusterType,
Pass: "pwd",
Tls: true,
},
ok: true,
},
}
for _, test := range tests {

View File

@@ -25,7 +25,7 @@ const spanName = "redis"
var (
startTimeKey = contextKey("startTime")
durationHook = hook{tracer: otel.GetTracerProvider().Tracer(trace.TraceName)}
durationHook = hook{tracer: otel.Tracer(trace.TraceName)}
redisCmdsAttributeKey = attribute.Key("redis.cmds")
)

View File

@@ -91,13 +91,16 @@ func TestHookProcessPipelineCase1(t *testing.T) {
log.SetOutput(&buf)
defer log.SetOutput(writer)
_, err := durationHook.BeforeProcessPipeline(context.Background(), []red.Cmder{})
assert.NoError(t, err)
ctx, err := durationHook.BeforeProcessPipeline(context.Background(), []red.Cmder{
red.NewCmd(context.Background()),
})
assert.NoError(t, err)
assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name())
assert.Nil(t, durationHook.AfterProcessPipeline(ctx, []red.Cmder{
assert.NoError(t, durationHook.AfterProcessPipeline(ctx, []red.Cmder{}))
assert.NoError(t, durationHook.AfterProcessPipeline(ctx, []red.Cmder{
red.NewCmd(context.Background()),
}))
assert.False(t, strings.Contains(buf.String(), "slow"))

View File

@@ -42,6 +42,12 @@ type (
Score int64
}
// A FloatPair is a key/pair for float set used in redis zet.
FloatPair struct {
Key string
Score float64
}
// Redis defines a redis node/cluster. It is thread-safe.
Redis struct {
Addr string
@@ -786,6 +792,28 @@ func (s *Redis) HincrbyCtx(ctx context.Context, key, field string, increment int
return
}
// HincrbyFloat is the implementation of redis hincrby command.
func (s *Redis) HincrbyFloat(key, field string, increment float64) (float64, error) {
return s.HincrbyFloatCtx(context.Background(), key, field, increment)
}
// HincrbyFloatCtx is the implementation of redis hincrby command.
func (s *Redis) HincrbyFloatCtx(ctx context.Context, key, field string, increment float64) (val float64, err error) {
err = s.brk.DoWithAcceptable(func() error {
conn, err := getRedis(s)
if err != nil {
return err
}
val, err = conn.HIncrByFloat(ctx, key, field, increment).Result()
if err != nil {
return err
}
return nil
}, acceptable)
return
}
// Hkeys is the implementation of redis hkeys command.
func (s *Redis) Hkeys(key string) ([]string, error) {
return s.HkeysCtx(context.Background(), key)
@@ -997,6 +1025,26 @@ func (s *Redis) IncrbyCtx(ctx context.Context, key string, increment int64) (val
return
}
// IncrbyFloat is the implementation of redis incrby command.
func (s *Redis) IncrbyFloat(key string, increment float64) (float64, error) {
return s.IncrbyFloatCtx(context.Background(), key, increment)
}
// IncrbyFloatCtx is the implementation of redis incrby command.
func (s *Redis) IncrbyFloatCtx(ctx context.Context, key string, increment float64) (val float64, err error) {
err = s.brk.DoWithAcceptable(func() error {
conn, err := getRedis(s)
if err != nil {
return err
}
val, err = conn.IncrByFloat(ctx, key, increment).Result()
return err
}, acceptable)
return
}
// Keys is the implementation of redis keys command.
func (s *Redis) Keys(pattern string) ([]string, error) {
return s.KeysCtx(context.Background(), pattern)
@@ -1849,17 +1897,17 @@ func (s *Redis) Zadd(key string, score int64, value string) (bool, error) {
return s.ZaddCtx(context.Background(), key, score, value)
}
// ZaddFloat is the implementation of redis zadd command.
func (s *Redis) ZaddFloat(key string, score float64, value string) (bool, error) {
return s.ZaddFloatCtx(context.Background(), key, score, value)
}
// ZaddCtx is the implementation of redis zadd command.
func (s *Redis) ZaddCtx(ctx context.Context, key string, score int64, value string) (
val bool, err error) {
return s.ZaddFloatCtx(ctx, key, float64(score), value)
}
// ZaddFloat is the implementation of redis zadd command.
func (s *Redis) ZaddFloat(key string, score float64, value string) (bool, error) {
return s.ZaddFloatCtx(context.Background(), key, score, value)
}
// ZaddFloatCtx is the implementation of redis zadd command.
func (s *Redis) ZaddFloatCtx(ctx context.Context, key string, score float64, value string) (
val bool, err error) {
@@ -2017,6 +2065,47 @@ func (s *Redis) ZscoreCtx(ctx context.Context, key, value string) (val int64, er
return
}
// ZscoreByFloat is the implementation of redis zscore command score by float.
func (s *Redis) ZscoreByFloat(key, value string) (float64, error) {
return s.ZscoreByFloatCtx(context.Background(), key, value)
}
// ZscoreByFloatCtx is the implementation of redis zscore command score by float.
func (s *Redis) ZscoreByFloatCtx(ctx context.Context, key, value string) (val float64, err error) {
err = s.brk.DoWithAcceptable(func() error {
conn, err := getRedis(s)
if err != nil {
return err
}
val, err = conn.ZScore(ctx, key, value).Result()
return err
}, acceptable)
return
}
// Zscan is the implementation of redis zscan command.
func (s *Redis) Zscan(key string, cursor uint64, match string, count int64) (
keys []string, cur uint64, err error) {
return s.ZscanCtx(context.Background(), key, cursor, match, count)
}
// ZscanCtx is the implementation of redis zscan command.
func (s *Redis) ZscanCtx(ctx context.Context, key string, cursor uint64, match string, count int64) (
keys []string, cur uint64, err error) {
err = s.brk.DoWithAcceptable(func() error {
conn, err := getRedis(s)
if err != nil {
return err
}
keys, cur, err = conn.ZScan(ctx, key, cursor, match, count).Result()
return err
}, acceptable)
return
}
// Zrank is the implementation of redis zrank command.
func (s *Redis) Zrank(key, field string) (int64, error) {
return s.ZrankCtx(context.Background(), key, field)
@@ -2162,13 +2251,52 @@ func (s *Redis) ZrangeWithScoresCtx(ctx context.Context, key string, start, stop
return
}
// ZrangeWithScoresByFloat is the implementation of redis zrange command with scores by float64.
func (s *Redis) ZrangeWithScoresByFloat(key string, start, stop int64) ([]FloatPair, error) {
return s.ZrangeWithScoresByFloatCtx(context.Background(), key, start, stop)
}
// ZrangeWithScoresByFloatCtx is the implementation of redis zrange command with scores by float64.
func (s *Redis) ZrangeWithScoresByFloatCtx(ctx context.Context, key string, start, stop int64) (
val []FloatPair, err error) {
err = s.brk.DoWithAcceptable(func() error {
conn, err := getRedis(s)
if err != nil {
return err
}
v, err := conn.ZRangeWithScores(ctx, key, start, stop).Result()
if err != nil {
return err
}
val = toFloatPairs(v)
return nil
}, acceptable)
return
}
// ZRevRangeWithScores is the implementation of redis zrevrange command with scores.
// Deprecated: use ZrevrangeWithScores instead.
func (s *Redis) ZRevRangeWithScores(key string, start, stop int64) ([]Pair, error) {
return s.ZRevRangeWithScoresCtx(context.Background(), key, start, stop)
return s.ZrevrangeWithScoresCtx(context.Background(), key, start, stop)
}
// ZrevrangeWithScores is the implementation of redis zrevrange command with scores.
func (s *Redis) ZrevrangeWithScores(key string, start, stop int64) ([]Pair, error) {
return s.ZrevrangeWithScoresCtx(context.Background(), key, start, stop)
}
// ZRevRangeWithScoresCtx is the implementation of redis zrevrange command with scores.
// Deprecated: use ZrevrangeWithScoresCtx instead.
func (s *Redis) ZRevRangeWithScoresCtx(ctx context.Context, key string, start, stop int64) (
val []Pair, err error) {
return s.ZrevrangeWithScoresCtx(ctx, key, start, stop)
}
// ZrevrangeWithScoresCtx is the implementation of redis zrevrange command with scores.
func (s *Redis) ZrevrangeWithScoresCtx(ctx context.Context, key string, start, stop int64) (
val []Pair, err error) {
err = s.brk.DoWithAcceptable(func() error {
conn, err := getRedis(s)
@@ -2188,6 +2316,45 @@ func (s *Redis) ZRevRangeWithScoresCtx(ctx context.Context, key string, start, s
return
}
// ZRevRangeWithScoresByFloat is the implementation of redis zrevrange command with scores by float.
// Deprecated: use ZrevrangeWithScoresByFloat instead.
func (s *Redis) ZRevRangeWithScoresByFloat(key string, start, stop int64) ([]FloatPair, error) {
return s.ZrevrangeWithScoresByFloatCtx(context.Background(), key, start, stop)
}
// ZrevrangeWithScoresByFloat is the implementation of redis zrevrange command with scores by float.
func (s *Redis) ZrevrangeWithScoresByFloat(key string, start, stop int64) ([]FloatPair, error) {
return s.ZrevrangeWithScoresByFloatCtx(context.Background(), key, start, stop)
}
// ZRevRangeWithScoresByFloatCtx is the implementation of redis zrevrange command with scores by float.
// Deprecated: use ZrevrangeWithScoresByFloatCtx instead.
func (s *Redis) ZRevRangeWithScoresByFloatCtx(ctx context.Context, key string, start, stop int64) (
val []FloatPair, err error) {
return s.ZrevrangeWithScoresByFloatCtx(ctx, key, start, stop)
}
// ZrevrangeWithScoresByFloatCtx is the implementation of redis zrevrange command with scores by float.
func (s *Redis) ZrevrangeWithScoresByFloatCtx(ctx context.Context, key string, start, stop int64) (
val []FloatPair, err error) {
err = s.brk.DoWithAcceptable(func() error {
conn, err := getRedis(s)
if err != nil {
return err
}
v, err := conn.ZRevRangeWithScores(ctx, key, start, stop).Result()
if err != nil {
return err
}
val = toFloatPairs(v)
return nil
}, acceptable)
return
}
// ZrangebyscoreWithScores is the implementation of redis zrangebyscore command with scores.
func (s *Redis) ZrangebyscoreWithScores(key string, start, stop int64) ([]Pair, error) {
return s.ZrangebyscoreWithScoresCtx(context.Background(), key, start, stop)
@@ -2217,6 +2384,35 @@ func (s *Redis) ZrangebyscoreWithScoresCtx(ctx context.Context, key string, star
return
}
// ZrangebyscoreWithScoresByFloat is the implementation of redis zrangebyscore command with scores by float.
func (s *Redis) ZrangebyscoreWithScoresByFloat(key string, start, stop float64) ([]FloatPair, error) {
return s.ZrangebyscoreWithScoresByFloatCtx(context.Background(), key, start, stop)
}
// ZrangebyscoreWithScoresByFloatCtx is the implementation of redis zrangebyscore command with scores by float.
func (s *Redis) ZrangebyscoreWithScoresByFloatCtx(ctx context.Context, key string, start, stop float64) (
val []FloatPair, err error) {
err = s.brk.DoWithAcceptable(func() error {
conn, err := getRedis(s)
if err != nil {
return err
}
v, err := conn.ZRangeByScoreWithScores(ctx, key, &red.ZRangeBy{
Min: strconv.FormatFloat(start, 'f', -1, 64),
Max: strconv.FormatFloat(stop, 'f', -1, 64),
}).Result()
if err != nil {
return err
}
val = toFloatPairs(v)
return nil
}, acceptable)
return
}
// ZrangebyscoreWithScoresAndLimit is the implementation of redis zrangebyscore command
// with scores and limit.
func (s *Redis) ZrangebyscoreWithScoresAndLimit(key string, start, stop int64,
@@ -2255,6 +2451,45 @@ func (s *Redis) ZrangebyscoreWithScoresAndLimitCtx(ctx context.Context, key stri
return
}
// ZrangebyscoreWithScoresByFloatAndLimit is the implementation of redis zrangebyscore command
// with scores by float and limit.
func (s *Redis) ZrangebyscoreWithScoresByFloatAndLimit(key string, start, stop float64,
page, size int) ([]FloatPair, error) {
return s.ZrangebyscoreWithScoresByFloatAndLimitCtx(context.Background(),
key, start, stop, page, size)
}
// ZrangebyscoreWithScoresByFloatAndLimitCtx is the implementation of redis zrangebyscore command
// with scores by float and limit.
func (s *Redis) ZrangebyscoreWithScoresByFloatAndLimitCtx(ctx context.Context, key string, start,
stop float64, page, size int) (val []FloatPair, err error) {
err = s.brk.DoWithAcceptable(func() error {
if size <= 0 {
return nil
}
conn, err := getRedis(s)
if err != nil {
return err
}
v, err := conn.ZRangeByScoreWithScores(ctx, key, &red.ZRangeBy{
Min: strconv.FormatFloat(start, 'f', -1, 64),
Max: strconv.FormatFloat(stop, 'f', -1, 64),
Offset: int64(page * size),
Count: int64(size),
}).Result()
if err != nil {
return err
}
val = toFloatPairs(v)
return nil
}, acceptable)
return
}
// Zrevrange is the implementation of redis zrevrange command.
func (s *Redis) Zrevrange(key string, start, stop int64) ([]string, error) {
return s.ZrevrangeCtx(context.Background(), key, start, stop)
@@ -2305,11 +2540,42 @@ func (s *Redis) ZrevrangebyscoreWithScoresCtx(ctx context.Context, key string, s
return
}
// ZrevrangebyscoreWithScoresByFloat is the implementation of redis zrevrangebyscore command with scores by float.
func (s *Redis) ZrevrangebyscoreWithScoresByFloat(key string, start, stop float64) (
[]FloatPair, error) {
return s.ZrevrangebyscoreWithScoresByFloatCtx(context.Background(), key, start, stop)
}
// ZrevrangebyscoreWithScoresByFloatCtx is the implementation of redis zrevrangebyscore command with scores by float.
func (s *Redis) ZrevrangebyscoreWithScoresByFloatCtx(ctx context.Context, key string,
start, stop float64) (val []FloatPair, err error) {
err = s.brk.DoWithAcceptable(func() error {
conn, err := getRedis(s)
if err != nil {
return err
}
v, err := conn.ZRevRangeByScoreWithScores(ctx, key, &red.ZRangeBy{
Min: strconv.FormatFloat(start, 'f', -1, 64),
Max: strconv.FormatFloat(stop, 'f', -1, 64),
}).Result()
if err != nil {
return err
}
val = toFloatPairs(v)
return nil
}, acceptable)
return
}
// ZrevrangebyscoreWithScoresAndLimit is the implementation of redis zrevrangebyscore command
// with scores and limit.
func (s *Redis) ZrevrangebyscoreWithScoresAndLimit(key string, start, stop int64,
page, size int) ([]Pair, error) {
return s.ZrevrangebyscoreWithScoresAndLimitCtx(context.Background(), key, start, stop, page, size)
return s.ZrevrangebyscoreWithScoresAndLimitCtx(context.Background(),
key, start, stop, page, size)
}
// ZrevrangebyscoreWithScoresAndLimitCtx is the implementation of redis zrevrangebyscore command
@@ -2343,6 +2609,45 @@ func (s *Redis) ZrevrangebyscoreWithScoresAndLimitCtx(ctx context.Context, key s
return
}
// ZrevrangebyscoreWithScoresByFloatAndLimit is the implementation of redis zrevrangebyscore command
// with scores by float and limit.
func (s *Redis) ZrevrangebyscoreWithScoresByFloatAndLimit(key string, start, stop float64,
page, size int) ([]FloatPair, error) {
return s.ZrevrangebyscoreWithScoresByFloatAndLimitCtx(context.Background(),
key, start, stop, page, size)
}
// ZrevrangebyscoreWithScoresByFloatAndLimitCtx is the implementation of redis zrevrangebyscore command
// with scores by float and limit.
func (s *Redis) ZrevrangebyscoreWithScoresByFloatAndLimitCtx(ctx context.Context, key string,
start, stop float64, page, size int) (val []FloatPair, err error) {
err = s.brk.DoWithAcceptable(func() error {
if size <= 0 {
return nil
}
conn, err := getRedis(s)
if err != nil {
return err
}
v, err := conn.ZRevRangeByScoreWithScores(ctx, key, &red.ZRangeBy{
Min: strconv.FormatFloat(start, 'f', -1, 64),
Max: strconv.FormatFloat(stop, 'f', -1, 64),
Offset: int64(page * size),
Count: int64(size),
}).Result()
if err != nil {
return err
}
val = toFloatPairs(v)
return nil
}, acceptable)
return
}
// Zrevrank is the implementation of redis zrevrank command.
func (s *Redis) Zrevrank(key, field string) (int64, error) {
return s.ZrevrankCtx(context.Background(), key, field)
@@ -2444,19 +2749,43 @@ func toPairs(vals []red.Z) []Pair {
return pairs
}
func toStrings(vals []interface{}) []string {
ret := make([]string, len(vals))
func toFloatPairs(vals []red.Z) []FloatPair {
pairs := make([]FloatPair, len(vals))
for i, val := range vals {
if val == nil {
ret[i] = ""
} else {
switch val := val.(type) {
case string:
ret[i] = val
default:
ret[i] = mapping.Repr(val)
switch member := val.Member.(type) {
case string:
pairs[i] = FloatPair{
Key: member,
Score: val.Score,
}
default:
pairs[i] = FloatPair{
Key: mapping.Repr(val.Member),
Score: val.Score,
}
}
}
return pairs
}
func toStrings(vals []interface{}) []string {
ret := make([]string, len(vals))
for i, val := range vals {
if val == nil {
ret[i] = ""
continue
}
switch val := val.(type) {
case string:
ret[i] = val
default:
ret[i] = mapping.Repr(val)
}
}
return ret
}

File diff suppressed because it is too large Load Diff

View File

@@ -8,12 +8,39 @@ import (
)
func TestBlockingNode(t *testing.T) {
r, err := miniredis.Run()
assert.Nil(t, err)
node, err := CreateBlockingNode(New(r.Addr()))
assert.Nil(t, err)
node.Close()
node, err = CreateBlockingNode(New(r.Addr(), Cluster()))
assert.Nil(t, err)
node.Close()
t.Run("test blocking node", func(t *testing.T) {
r, err := miniredis.Run()
assert.NoError(t, err)
defer r.Close()
node, err := CreateBlockingNode(New(r.Addr()))
assert.NoError(t, err)
node.Close()
// close again to make sure it's safe
assert.NotPanics(t, func() {
node.Close()
})
})
t.Run("test blocking node with cluster", func(t *testing.T) {
r, err := miniredis.Run()
assert.NoError(t, err)
defer r.Close()
node, err := CreateBlockingNode(New(r.Addr(), Cluster(), WithTLS()))
assert.NoError(t, err)
node.Close()
assert.NotPanics(t, func() {
node.Close()
})
})
t.Run("test blocking node with bad type", func(t *testing.T) {
r, err := miniredis.Run()
assert.NoError(t, err)
defer r.Close()
_, err = CreateBlockingNode(New(r.Addr(), badType()))
assert.Error(t, err)
})
}

View File

@@ -14,7 +14,7 @@ import (
var sqlAttributeKey = attribute.Key("sql.method")
func startSpan(ctx context.Context, method string) (context.Context, oteltrace.Span) {
tracer := otel.GetTracerProvider().Tracer(trace.TraceName)
tracer := otel.Tracer(trace.TraceName)
start, span := tracer.Start(ctx,
spanName,
oteltrace.WithSpanKind(oteltrace.SpanKindClient),

View File

@@ -10,17 +10,18 @@ import (
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/exporters/jaeger"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
"go.opentelemetry.io/otel/exporters/zipkin"
"go.opentelemetry.io/otel/sdk/resource"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
"google.golang.org/grpc"
)
const (
kindJaeger = "jaeger"
kindZipkin = "zipkin"
kindGrpc = "grpc"
kindJaeger = "jaeger"
kindZipkin = "zipkin"
kindOtlpGrpc = "otlpgrpc"
kindOtlpHttp = "otlphttp"
)
var (
@@ -59,12 +60,24 @@ func createExporter(c Config) (sdktrace.SpanExporter, error) {
return jaeger.New(jaeger.WithCollectorEndpoint(jaeger.WithEndpoint(c.Endpoint)))
case kindZipkin:
return zipkin.New(c.Endpoint)
case kindGrpc:
return otlptracegrpc.NewUnstarted(
case kindOtlpGrpc:
// Always treat trace exporter as optional component, so we use nonblock here,
// otherwise this would slow down app start up even set a dial timeout here when
// endpoint can not reach.
// If the connection not dial success, the global otel ErrorHandler will catch error
// when reporting data like other exporters.
return otlptracegrpc.New(
context.Background(),
otlptracegrpc.WithInsecure(),
otlptracegrpc.WithEndpoint(c.Endpoint),
otlptracegrpc.WithDialOption(grpc.WithBlock()),
), nil
)
case kindOtlpHttp:
// Not support flexible configuration now.
return otlptracehttp.New(
context.Background(),
otlptracehttp.WithInsecure(),
otlptracehttp.WithEndpoint(c.Endpoint),
)
default:
return nil, fmt.Errorf("unknown exporter: %s", c.Batcher)
}

View File

@@ -14,6 +14,7 @@ func TestStartAgent(t *testing.T) {
endpoint1 = "localhost:1234"
endpoint2 = "remotehost:1234"
endpoint3 = "localhost:1235"
endpoint4 = "localhost:1236"
)
c1 := Config{
Name: "foo",
@@ -36,7 +37,12 @@ func TestStartAgent(t *testing.T) {
c5 := Config{
Name: "grpc",
Endpoint: endpoint3,
Batcher: "grpc",
Batcher: kindOtlpGrpc,
}
c6 := Config{
Name: "otlphttp",
Endpoint: endpoint4,
Batcher: kindOtlpHttp,
}
StartAgent(c1)
@@ -45,12 +51,13 @@ func TestStartAgent(t *testing.T) {
StartAgent(c3)
StartAgent(c4)
StartAgent(c5)
StartAgent(c6)
lock.Lock()
defer lock.Unlock()
// because remotehost cannot be resolved
assert.Equal(t, 3, len(agents))
assert.Equal(t, 4, len(agents))
_, ok := agents[""]
assert.True(t, ok)
_, ok = agents[endpoint1]

View File

@@ -8,5 +8,5 @@ type Config struct {
Name string `json:",optional"`
Endpoint string `json:",optional"`
Sampler float64 `json:",default=1.0"`
Batcher string `json:",default=jaeger,options=jaeger|zipkin|grpc"`
Batcher string `json:",default=jaeger,options=jaeger|zipkin|otlpgrpc|otlphttp"`
}

8
go.mod
View File

@@ -5,17 +5,17 @@ go 1.16
require (
github.com/ClickHouse/clickhouse-go/v2 v2.0.14
github.com/DATA-DOG/go-sqlmock v1.5.0
github.com/alicebob/miniredis/v2 v2.23.1
github.com/alicebob/miniredis/v2 v2.30.0
github.com/fatih/color v1.13.0
github.com/felixge/fgprof v0.9.3
github.com/fullstorydev/grpcurl v1.8.7
github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8
github.com/go-redis/redis/v8 v8.11.5
github.com/go-sql-driver/mysql v1.7.0
github.com/golang-jwt/jwt/v4 v4.4.3
github.com/golang/mock v1.6.0
github.com/golang/protobuf v1.5.2
github.com/google/uuid v1.3.0
github.com/jhump/protoreflect v1.14.0
github.com/jhump/protoreflect v1.14.1
github.com/lib/pq v1.10.7
github.com/olekukonko/tablewriter v0.0.5
github.com/pelletier/go-toml/v2 v2.0.6
@@ -28,6 +28,7 @@ require (
go.opentelemetry.io/otel v1.10.0
go.opentelemetry.io/otel/exporters/jaeger v1.10.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.10.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.10.0
go.opentelemetry.io/otel/exporters/zipkin v1.10.0
go.opentelemetry.io/otel/sdk v1.10.0
go.opentelemetry.io/otel/trace v1.10.0
@@ -48,7 +49,6 @@ require (
)
require (
github.com/felixge/fgprof v0.9.3
github.com/google/gofuzz v1.2.0 // indirect
github.com/mattn/go-runewidth v0.0.13 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect

12
go.sum
View File

@@ -389,8 +389,8 @@ github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRF
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho=
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk=
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
github.com/alicebob/miniredis/v2 v2.23.1 h1:jR6wZggBxwWygeXcdNyguCOCIjPsZyNUNlAkTx2fu0U=
github.com/alicebob/miniredis/v2 v2.23.1/go.mod h1:84TWKZlxYkfgMucPBf5SOQBYJceZeQRFIaQgNMiCX6Q=
github.com/alicebob/miniredis/v2 v2.30.0 h1:uA3uhDbCxfO9+DI/DuGeAMr9qI+noVWwGPNTFuKID5M=
github.com/alicebob/miniredis/v2 v2.30.0/go.mod h1:84TWKZlxYkfgMucPBf5SOQBYJceZeQRFIaQgNMiCX6Q=
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY=
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
@@ -465,8 +465,6 @@ github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4
github.com/fullstorydev/grpcurl v1.8.7 h1:xJWosq3BQovQ4QrdPO72OrPiWuGgEsxY8ldYsJbPrqI=
github.com/fullstorydev/grpcurl v1.8.7/go.mod h1:pVtM4qe3CMoLaIzYS8uvTuDj2jVYmXqMUkZeijnXp/E=
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8 h1:DujepqpGd1hyOd7aW59XpK7Qymp8iy83xq74fLr21is=
github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q=
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
@@ -640,8 +638,8 @@ github.com/jhump/gopoet v0.1.0/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+
github.com/jhump/goprotoc v0.5.0/go.mod h1:VrbvcYrQOrTi3i0Vf+m+oqQWk9l72mjkJCYo7UvLHRQ=
github.com/jhump/protoreflect v1.11.0/go.mod h1:U7aMIjN0NWq9swDP7xDdoMfRHb35uiuTd3Z9nFXJf5E=
github.com/jhump/protoreflect v1.12.0/go.mod h1:JytZfP5d0r8pVNLZvai7U/MCuTWITgrI4tTg7puQFKI=
github.com/jhump/protoreflect v1.14.0 h1:MBbQK392K3u8NTLbKOCIi3XdI+y+c6yt5oMq0X3xviw=
github.com/jhump/protoreflect v1.14.0/go.mod h1:JytZfP5d0r8pVNLZvai7U/MCuTWITgrI4tTg7puQFKI=
github.com/jhump/protoreflect v1.14.1 h1:N88q7JkxTHWFEqReuTsYH1dPIwXxA0ITNQp7avLY10s=
github.com/jhump/protoreflect v1.14.1/go.mod h1:JytZfP5d0r8pVNLZvai7U/MCuTWITgrI4tTg7puQFKI=
github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks=
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
@@ -860,6 +858,8 @@ go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.10.0 h1:pDDYmo0QadUPal5fwXo
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.10.0/go.mod h1:Krqnjl22jUJ0HgMzw5eveuCvFDXY4nSYb4F8t5gdrag=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.10.0 h1:KtiUEhQmj/Pa874bVYKGNVdq8NPKiacPbaRRtgXi+t4=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.10.0/go.mod h1:OfUCyyIiDvNXHWpcWgbF+MWvqPZiNa3YDEnivcnYsV0=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.10.0 h1:S8DedULB3gp93Rh+9Z+7NTEv+6Id/KYS7LDyipZ9iCE=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.10.0/go.mod h1:5WV40MLWwvWlGP7Xm8g3pMcg0pKOUY609qxJn8y7LmM=
go.opentelemetry.io/otel/exporters/zipkin v1.10.0 h1:HcPAFsFpEBKF+G5NIOA+gBsxifd3Ej+wb+KsdBLa15E=
go.opentelemetry.io/otel/exporters/zipkin v1.10.0/go.mod h1:HdfvgwcOoCB0+zzrTHycW6btjK0zNpkz2oTGO815SCI=
go.opentelemetry.io/otel/sdk v1.10.0 h1:jZ6K7sVn04kk/3DNUdJ4mqRlGDiXAVuIG+MMENpTNdY=

View File

@@ -19,13 +19,13 @@ var once sync.Once
// Server is inner http server, expose some useful observability information of app.
// For example health check, metrics and pprof.
type Server struct {
config *Config
config Config
server *http.ServeMux
routes []string
}
// NewServer returns a new inner http Server.
func NewServer(config *Config) *Server {
func NewServer(config Config) *Server {
return &Server{
config: config,
server: http.NewServeMux(),
@@ -76,7 +76,7 @@ func (s *Server) StartAsync() {
func StartAgent(c Config) {
once.Do(func() {
if c.Enabled {
s := NewServer(&c)
s := NewServer(c)
s.StartAsync()
}
})

View File

@@ -23,7 +23,7 @@
>
> `GOPROXY=https://goproxy.cn/,direct go install github.com/zeromicro/go-zero/tools/goctl@latest`
>
> `goctl migrate —verbose —version v1.4.0`
> `goctl migrate —verbose —version v1.4.3`
## 0. go-zero 介绍
@@ -293,6 +293,9 @@ go-zero 已被许多公司用于生产部署,接入场景如在线教育、电
>78. ZeroCMF
>79. 安徽寻梦投资发展集团
>80. 广州腾思信息科技有限公司
>81. 广州机智云物联网科技有限公司
>82. 厦门亿联网络技术股份有限公司
>83. 北京麦芽田网络科技有限公司
如果贵公司也已使用 go-zero欢迎在 [登记地址](https://github.com/zeromicro/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。

View File

@@ -111,7 +111,7 @@ go install github.com/zeromicro/go-zero/tools/goctl@latest
```
```shell
goctl migrate —verbose —version v1.4.0
goctl migrate —verbose —version v1.4.3
```
## Quick Start
@@ -154,7 +154,7 @@ goctl migrate —verbose —version v1.4.0
```go
type (
Request {
Name string `path:"name,options=you|me"` // parameters are auto validated
Name string `path:"name,options=[you,me]"` // parameters are auto validated
}
Response {
@@ -270,8 +270,6 @@ go-zero enlisted in the [CNCF Cloud Native Landscape](https://landscape.cncf.io/
If you like or are using this project to learn or start your solution, please give it a star. Thanks!
[![Star History Chart](https://api.star-history.com/svg?repos=zeromicro/go-zero&type=Date)](#go-zero)
## Buy me a coffee
<a href="https://www.buymeacoffee.com/kevwan" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me A Coffee" style="height: 60px !important;width: 217px !important;" ></a>

View File

@@ -7,6 +7,21 @@ import (
)
type (
// MiddlewaresConf is the config of middlewares.
MiddlewaresConf struct {
Trace bool `json:",default=true"`
Log bool `json:",default=true"`
Prometheus bool `json:",default=true"`
MaxConns bool `json:",default=true"`
Breaker bool `json:",default=true"`
Shedding bool `json:",default=true"`
Timeout bool `json:",default=true"`
Recover bool `json:",default=true"`
Metrics bool `json:",default=true"`
MaxBytes bool `json:",default=true"`
Gunzip bool `json:",default=true"`
}
// A PrivateKeyConf is a private key config.
PrivateKeyConf struct {
Fingerprint string
@@ -40,5 +55,9 @@ type (
Timeout int64 `json:",default=3000"`
CpuThreshold int64 `json:",default=900,range=[0:1000]"`
Signature SignatureConf `json:",optional"`
// There are default values for all the items in Middlewares.
Middlewares MiddlewaresConf
// TraceIgnorePaths is paths blacklist for trace middleware.
TraceIgnorePaths []string `json:",optional"`
}
)

View File

@@ -88,19 +88,7 @@ func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *sta
route Route, verifier func(chain.Chain) chain.Chain) error {
chn := ng.chain
if chn == nil {
chn = chain.New(
handler.TracingHandler(ng.conf.Name, route.Path),
ng.getLogHandler(),
handler.PrometheusHandler(route.Path),
handler.MaxConns(ng.conf.MaxConns),
handler.BreakerHandler(route.Method, route.Path, metrics),
handler.SheddingHandler(ng.getShedder(fr.priority), metrics),
handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)),
handler.RecoverHandler,
handler.MetricHandler(metrics),
handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)),
handler.GunzipHandler,
)
chn = ng.buildChainWithNativeMiddlewares(fr, route, metrics)
}
chn = ng.appendAuthHandler(fr, chn, verifier)
@@ -125,6 +113,49 @@ func (ng *engine) bindRoutes(router httpx.Router) error {
return nil
}
func (ng *engine) buildChainWithNativeMiddlewares(fr featuredRoutes, route Route,
metrics *stat.Metrics) chain.Chain {
chn := chain.New()
if ng.conf.Middlewares.Trace {
chn = chn.Append(handler.TraceHandler(ng.conf.Name,
route.Path,
handler.WithTraceIgnorePaths(ng.conf.TraceIgnorePaths)))
}
if ng.conf.Middlewares.Log {
chn = chn.Append(ng.getLogHandler())
}
if ng.conf.Middlewares.Prometheus {
chn = chn.Append(handler.PrometheusHandler(route.Path))
}
if ng.conf.Middlewares.MaxConns {
chn = chn.Append(handler.MaxConnsHandler(ng.conf.MaxConns))
}
if ng.conf.Middlewares.Breaker {
chn = chn.Append(handler.BreakerHandler(route.Method, route.Path, metrics))
}
if ng.conf.Middlewares.Shedding {
chn = chn.Append(handler.SheddingHandler(ng.getShedder(fr.priority), metrics))
}
if ng.conf.Middlewares.Timeout {
chn = chn.Append(handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)))
}
if ng.conf.Middlewares.Recover {
chn = chn.Append(handler.RecoverHandler)
}
if ng.conf.Middlewares.Metrics {
chn = chn.Append(handler.MetricHandler(metrics))
}
if ng.conf.Middlewares.MaxBytes {
chn = chn.Append(handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)))
}
if ng.conf.Middlewares.Gunzip {
chn = chn.Append(handler.GunzipHandler)
}
return chn
}
func (ng *engine) checkedMaxBytes(bytes int64) int64 {
if bytes > 0 {
return bytes
@@ -173,7 +204,9 @@ func (ng *engine) getShedder(priority bool) load.Shedder {
func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
chn := chain.New(
handler.TracingHandler(ng.conf.Name, ""),
handler.TraceHandler(ng.conf.Name,
"",
handler.WithTraceIgnorePaths(ng.conf.TraceIgnorePaths)),
ng.getLogHandler(),
)
@@ -287,10 +320,9 @@ func (ng *engine) withTimeout() internal.StartOption {
// without this timeout setting, the server will time out and respond 503 Service Unavailable,
// which triggers the circuit breaker.
svr.ReadTimeout = 4 * time.Duration(timeout) * time.Millisecond / 5
// factor 0.9, to avoid clients not reading the response
// without this timeout setting, the server will time out and respond 503 Service Unavailable,
// which triggers the circuit breaker.
svr.WriteTimeout = 9 * time.Duration(timeout) * time.Millisecond / 10
// factor 1.1, to avoid servers don't have enough time to write responses.
// setting the factor less than 1.0 may lead clients not receiving the responses.
svr.WriteTimeout = 11 * time.Duration(timeout) * time.Millisecond / 10
}
}
}

View File

@@ -18,10 +18,14 @@ func TestNewEngine(t *testing.T) {
yamls := []string{
`Name: foo
Port: 54321
Middlewares:
Log: false
`,
`Name: foo
Port: 54321
CpuThreshold: 500
Middlewares:
Log: false
`,
`Name: foo
Port: 54321
@@ -323,7 +327,7 @@ func TestEngine_withTimeout(t *testing.T) {
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*9/10, svr.WriteTimeout)
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*11/10, svr.WriteTimeout)
assert.Equal(t, time.Duration(0), svr.IdleTimeout)
})
}

View File

@@ -22,6 +22,7 @@ import (
"github.com/zeromicro/go-zero/core/utils"
"github.com/zeromicro/go-zero/rest/httpx"
"github.com/zeromicro/go-zero/rest/internal"
"github.com/zeromicro/go-zero/rest/internal/response"
)
const (
@@ -31,66 +32,30 @@ const (
var slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold)
type loggedResponseWriter struct {
w http.ResponseWriter
r *http.Request
code int
}
func (w *loggedResponseWriter) Flush() {
if flusher, ok := w.w.(http.Flusher); ok {
flusher.Flush()
}
}
func (w *loggedResponseWriter) Header() http.Header {
return w.w.Header()
}
// Hijack implements the http.Hijacker interface.
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
func (w *loggedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hijacked, ok := w.w.(http.Hijacker); ok {
return hijacked.Hijack()
}
return nil, nil, errors.New("server doesn't support hijacking")
}
func (w *loggedResponseWriter) Write(bytes []byte) (int, error) {
return w.w.Write(bytes)
}
func (w *loggedResponseWriter) WriteHeader(code int) {
w.w.WriteHeader(code)
w.code = code
}
// LogHandler returns a middleware that logs http request and response.
func LogHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
timer := utils.NewElapsedTimer()
logs := new(internal.LogCollector)
lrw := loggedResponseWriter{
w: w,
r: r,
code: http.StatusOK,
lrw := response.WithCodeResponseWriter{
Writer: w,
Code: http.StatusOK,
}
var dup io.ReadCloser
r.Body, dup = iox.DupReadCloser(r.Body)
next.ServeHTTP(&lrw, r.WithContext(context.WithValue(r.Context(), internal.LogContext, logs)))
r.Body = dup
logBrief(r, lrw.code, timer, logs)
logBrief(r, lrw.Code, timer, logs)
})
}
type detailLoggedResponseWriter struct {
writer *loggedResponseWriter
writer *response.WithCodeResponseWriter
buf *bytes.Buffer
}
func newDetailLoggedResponseWriter(writer *loggedResponseWriter, buf *bytes.Buffer) *detailLoggedResponseWriter {
func newDetailLoggedResponseWriter(writer *response.WithCodeResponseWriter, buf *bytes.Buffer) *detailLoggedResponseWriter {
return &detailLoggedResponseWriter{
writer: writer,
buf: buf,
@@ -108,7 +73,7 @@ func (w *detailLoggedResponseWriter) Header() http.Header {
// Hijack implements the http.Hijacker interface.
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
func (w *detailLoggedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hijacked, ok := w.writer.w.(http.Hijacker); ok {
if hijacked, ok := w.writer.Writer.(http.Hijacker); ok {
return hijacked.Hijack()
}
@@ -129,10 +94,9 @@ func DetailedLogHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
timer := utils.NewElapsedTimer()
var buf bytes.Buffer
lrw := newDetailLoggedResponseWriter(&loggedResponseWriter{
w: w,
r: r,
code: http.StatusOK,
lrw := newDetailLoggedResponseWriter(&response.WithCodeResponseWriter{
Writer: w,
Code: http.StatusOK,
}, &buf)
var dup io.ReadCloser
@@ -203,7 +167,7 @@ func logDetails(r *http.Request, response *detailLoggedResponseWriter, timer *ut
logs *internal.LogCollector) {
var buf bytes.Buffer
duration := timer.Duration()
code := response.writer.code
code := response.writer.Code
logger := logx.WithContext(r.Context())
buf.WriteString(fmt.Sprintf("[HTTP] %s - %d - %s - %s\n=> %s\n",
r.Method, code, r.RemoteAddr, timex.ReprOfDuration(duration), dumpRequest(r)))

View File

@@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/rest/internal"
"github.com/zeromicro/go-zero/rest/internal/response"
)
func init() {
@@ -54,7 +55,7 @@ func TestLogHandlerVeryLong(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "http://localhost", &buf)
handler := LogHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Context().Value(internal.LogContext).(*internal.LogCollector).Append("anything")
io.Copy(io.Discard, r.Body)
_, _ = io.Copy(io.Discard, r.Body)
w.Header().Set("X-Test", "test")
w.WriteHeader(http.StatusServiceUnavailable)
_, err := w.Write([]byte("content"))
@@ -89,45 +90,26 @@ func TestLogHandlerSlow(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.Code)
}
}
func TestLogHandler_Hijack(t *testing.T) {
resp := httptest.NewRecorder()
writer := &loggedResponseWriter{
w: resp,
}
assert.NotPanics(t, func() {
writer.Hijack()
})
writer = &loggedResponseWriter{
w: mockedHijackable{resp},
}
assert.NotPanics(t, func() {
writer.Hijack()
})
}
func TestDetailedLogHandler_Hijack(t *testing.T) {
resp := httptest.NewRecorder()
writer := &detailLoggedResponseWriter{
writer: &loggedResponseWriter{
w: resp,
writer: &response.WithCodeResponseWriter{
Writer: resp,
},
}
assert.NotPanics(t, func() {
writer.Hijack()
_, _, _ = writer.Hijack()
})
writer = &detailLoggedResponseWriter{
writer: &loggedResponseWriter{
w: mockedHijackable{resp},
writer: &response.WithCodeResponseWriter{
Writer: mockedHijackable{resp},
},
}
assert.NotPanics(t, func() {
writer.Hijack()
_, _, _ = writer.Hijack()
})
}
func TestSetSlowThreshold(t *testing.T) {
assert.Equal(t, defaultSlowThreshold, slowThreshold.Load())
SetSlowThreshold(time.Second)

View File

@@ -8,8 +8,8 @@ import (
"github.com/zeromicro/go-zero/rest/internal"
)
// MaxConns returns a middleware that limit the concurrent connections.
func MaxConns(n int) func(http.Handler) http.Handler {
// MaxConnsHandler returns a middleware that limit the concurrent connections.
func MaxConnsHandler(n int) func(http.Handler) http.Handler {
if n <= 0 {
return func(next http.Handler) http.Handler {
return next

View File

@@ -24,7 +24,7 @@ func TestMaxConnsHandler(t *testing.T) {
done := make(chan lang.PlaceholderType)
defer close(done)
maxConns := MaxConns(conns)
maxConns := MaxConnsHandler(conns)
handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
waitGroup.Done()
<-done
@@ -54,7 +54,7 @@ func TestWithoutMaxConnsHandler(t *testing.T) {
done := make(chan lang.PlaceholderType)
defer close(done)
maxConns := MaxConns(0)
maxConns := MaxConnsHandler(0)
handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
val := r.Header.Get(key)
if val == value {

View File

@@ -0,0 +1,78 @@
package handler
import (
"net/http"
"github.com/zeromicro/go-zero/core/collection"
"github.com/zeromicro/go-zero/core/trace"
"github.com/zeromicro/go-zero/rest/internal/response"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/propagation"
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
oteltrace "go.opentelemetry.io/otel/trace"
)
type (
// TraceOption defines the method to customize an traceOptions.
TraceOption func(options *traceOptions)
// traceOptions is TraceHandler options.
traceOptions struct {
traceIgnorePaths []string
}
)
// TraceHandler return a middleware that process the opentelemetry.
func TraceHandler(serviceName, path string, opts ...TraceOption) func(http.Handler) http.Handler {
var options traceOptions
for _, opt := range opts {
opt(&options)
}
ignorePaths := collection.NewSet()
ignorePaths.AddStr(options.traceIgnorePaths...)
return func(next http.Handler) http.Handler {
tracer := otel.Tracer(trace.TraceName)
propagator := otel.GetTextMapPropagator()
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
spanName := path
if len(spanName) == 0 {
spanName = r.URL.Path
}
if ignorePaths.Contains(spanName) {
next.ServeHTTP(w, r)
return
}
ctx := propagator.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
spanCtx, span := tracer.Start(
ctx,
spanName,
oteltrace.WithSpanKind(oteltrace.SpanKindServer),
oteltrace.WithAttributes(semconv.HTTPServerAttributesFromHTTPRequest(
serviceName, spanName, r)...),
)
defer span.End()
// convenient for tracking error messages
propagator.Inject(spanCtx, propagation.HeaderCarrier(w.Header()))
trw := &response.WithCodeResponseWriter{Writer: w, Code: http.StatusOK}
next.ServeHTTP(trw, r.WithContext(spanCtx))
span.SetAttributes(semconv.HTTPAttributesFromHTTPStatusCode(trw.Code)...)
span.SetStatus(semconv.SpanStatusFromHTTPStatusCodeAndSpanKind(
trw.Code, oteltrace.SpanKindServer))
})
}
}
// WithTraceIgnorePaths specifies the traceIgnorePaths option for TraceHandler.
func WithTraceIgnorePaths(traceIgnorePaths []string) TraceOption {
return func(options *traceOptions) {
options.traceIgnorePaths = append(options.traceIgnorePaths, traceIgnorePaths...)
}
}

View File

@@ -2,8 +2,10 @@ package handler
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
@@ -25,7 +27,7 @@ func TestOtelHandler(t *testing.T) {
for _, test := range []string{"", "bar"} {
t.Run(test, func(t *testing.T) {
h := chain.New(TracingHandler("foo", test)).Then(
h := chain.New(TraceHandler("foo", test)).Then(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
span := trace.SpanFromContext(r.Context())
assert.True(t, span.SpanContext().IsValid())
@@ -59,12 +61,11 @@ func TestDontTracingSpan(t *testing.T) {
Batcher: "jaeger",
Sampler: 1.0,
})
DontTraceSpan("bar")
defer ztrace.StopAgent()
for _, test := range []string{"", "bar", "foo"} {
t.Run(test, func(t *testing.T) {
h := chain.New(TracingHandler("foo", test)).Then(
h := chain.New(TraceHandler("foo", test, WithTraceIgnorePaths([]string{"bar"}))).Then(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
span := trace.SpanFromContext(r.Context())
spanCtx := span.SpanContext()
@@ -97,3 +98,50 @@ func TestDontTracingSpan(t *testing.T) {
})
}
}
func TestTraceResponseWriter(t *testing.T) {
ztrace.StartAgent(ztrace.Config{
Name: "go-zero-test",
Endpoint: "http://localhost:14268/api/traces",
Batcher: "jaeger",
Sampler: 1.0,
})
defer ztrace.StopAgent()
for _, test := range []int{0, 200, 300, 400, 401, 500, 503} {
t.Run(strconv.Itoa(test), func(t *testing.T) {
h := chain.New(TraceHandler("foo", "bar")).Then(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
span := trace.SpanFromContext(r.Context())
spanCtx := span.SpanContext()
assert.True(t, span.IsRecording())
assert.True(t, spanCtx.IsValid())
if test != 0 {
w.WriteHeader(test)
}
w.Write([]byte("hello"))
}))
ts := httptest.NewServer(h)
defer ts.Close()
client := ts.Client()
err := func(ctx context.Context) error {
ctx, span := otel.Tracer("httptrace/client").Start(ctx, "test")
defer span.End()
req, _ := http.NewRequest("GET", ts.URL, nil)
otel.GetTextMapPropagator().Inject(ctx, propagation.HeaderCarrier(req.Header))
res, err := client.Do(req)
assert.Nil(t, err)
resBody := make([]byte, 5)
_, err = res.Body.Read(resBody)
assert.Equal(t, io.EOF, err)
assert.Equal(t, []byte("hello"), resBody, "response body fail")
return res.Body.Close()
}(context.Background())
assert.Nil(t, err)
})
}
}

View File

@@ -1,54 +0,0 @@
package handler
import (
"net/http"
"sync"
"github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/trace"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/propagation"
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
oteltrace "go.opentelemetry.io/otel/trace"
)
var notTracingSpans sync.Map
// DontTraceSpan disable tracing for the specified span name.
func DontTraceSpan(spanName string) {
notTracingSpans.Store(spanName, lang.Placeholder)
}
// TracingHandler return a middleware that process the opentelemetry.
func TracingHandler(serviceName, path string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
propagator := otel.GetTextMapPropagator()
tracer := otel.GetTracerProvider().Tracer(trace.TraceName)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
spanName := path
if len(spanName) == 0 {
spanName = r.URL.Path
}
if _, ok := notTracingSpans.Load(spanName); ok {
next.ServeHTTP(w, r)
return
}
ctx := propagator.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
spanCtx, span := tracer.Start(
ctx,
spanName,
oteltrace.WithSpanKind(oteltrace.SpanKindServer),
oteltrace.WithAttributes(semconv.HTTPServerAttributesFromHTTPRequest(
serviceName, spanName, r)...),
)
defer span.End()
// convenient for tracking error messages
propagator.Inject(spanCtx, propagation.HeaderCarrier(w.Header()))
next.ServeHTTP(w, r.WithContext(spanCtx))
})
}
}

View File

@@ -7,7 +7,6 @@ import (
"fmt"
"io"
"net/http"
"net/http/httptrace"
nurl "net/url"
"strings"
@@ -157,7 +156,7 @@ func fillPath(u *nurl.URL, val map[string]interface{}) error {
}
func request(r *http.Request, cli client) (*http.Response, error) {
tracer := otel.GetTracerProvider().Tracer(trace.TraceName)
tracer := otel.Tracer(trace.TraceName)
propagator := otel.GetTextMapPropagator()
spanName := r.URL.Path
@@ -176,11 +175,6 @@ func request(r *http.Request, cli client) (*http.Response, error) {
respHandlers[i] = h
}
clientTrace := httptrace.ContextClientTrace(ctx)
if clientTrace != nil {
ctx = httptrace.WithClientTrace(ctx, clientTrace)
}
r = r.WithContext(ctx)
propagator.Inject(ctx, propagation.HeaderCarrier(r.Header))
@@ -196,7 +190,7 @@ func request(r *http.Request, cli client) (*http.Response, error) {
}
span.SetAttributes(semconv.HTTPAttributesFromHTTPStatusCode(resp.StatusCode)...)
span.SetStatus(semconv.SpanStatusFromHTTPStatusCode(resp.StatusCode))
span.SetStatus(semconv.SpanStatusFromHTTPStatusCodeAndSpanKind(resp.StatusCode, oteltrace.SpanKindClient))
return resp, err
}

View File

@@ -5,6 +5,7 @@ import (
"net/http"
"net/http/httptest"
"net/http/httptrace"
"strings"
"testing"
"github.com/stretchr/testify/assert"
@@ -205,12 +206,14 @@ func TestDo_WithClientHttpTrace(t *testing.T) {
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
defer svr.Close()
enter := false
_, err := Do(httptrace.WithClientTrace(context.Background(),
&httptrace.ClientTrace{
DNSStart: func(info httptrace.DNSStartInfo) {
assert.Equal(t, "localhost", info.Host)
GetConn: func(hostPort string) {
assert.Equal(t, "127.0.0.1", strings.Split(hostPort, ":")[0])
enter = true
},
}), http.MethodGet, svr.URL, nil)
assert.Nil(t, err)
assert.True(t, enter)
}

View File

@@ -223,6 +223,22 @@ func TestParseJsonBody(t *testing.T) {
assert.Equal(t, "", v.Name)
assert.Equal(t, 0, v.Age)
})
t.Run("array body", func(t *testing.T) {
var v []struct {
Name string `json:"name"`
Age int `json:"age"`
}
body := `[{"name":"kevin", "age": 18}]`
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
r.Header.Set(ContentType, header.JsonContentType)
assert.NoError(t, ParseJsonBody(r, &v))
assert.Equal(t, 1, len(v))
assert.Equal(t, "kevin", v[0].Name)
assert.Equal(t, 18, v[0].Age)
})
}
func TestParseRequired(t *testing.T) {

View File

@@ -90,6 +90,25 @@ func (s *Server) Routes() []Route {
return routes
}
// ServeHTTP is for test purpose, allow developer to do a unit test with
// all defined router without starting an HTTP Server.
//
// For example:
//
// server := MustNewServer(...)
// server.addRoute(...) // router a
// server.addRoute(...) // router b
// server.addRoute(...) // router c
//
// r, _ := http.NewRequest(...)
// w := httptest.NewRecorder(...)
// server.ServeHTTP(w, r)
// // verify the response
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.ngin.bindRoutes(s.router)
s.router.ServeHTTP(w, r)
}
// Start starts the Server.
// Graceful shutdown is enabled by default.
// Use proc.SetTimeToForceQuit to customize the graceful shutdown period.

View File

@@ -535,3 +535,91 @@ func TestServer_WithCors(t *testing.T) {
cr.ServeHTTP(httptest.NewRecorder(), req)
assert.Equal(t, int32(0), atomic.LoadInt32(&called))
}
func TestServer_ServeHTTP(t *testing.T) {
const configYaml = `
Name: foo
Port: 54321
`
var cnf RestConf
assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
svr, err := NewServer(cnf)
assert.Nil(t, err)
svr.AddRoutes([]Route{
{
Method: http.MethodGet,
Path: "/foo",
Handler: func(writer http.ResponseWriter, request *http.Request) {
_, _ = writer.Write([]byte("succeed"))
writer.WriteHeader(http.StatusOK)
},
},
{
Method: http.MethodGet,
Path: "/bar",
Handler: func(writer http.ResponseWriter, request *http.Request) {
_, _ = writer.Write([]byte("succeed"))
writer.WriteHeader(http.StatusOK)
},
},
{
Method: http.MethodGet,
Path: "/user/:name",
Handler: func(writer http.ResponseWriter, request *http.Request) {
var userInfo struct {
Name string `path:"name"`
}
err := httpx.Parse(request, &userInfo)
if err != nil {
_, _ = writer.Write([]byte("failed"))
writer.WriteHeader(http.StatusBadRequest)
return
}
_, _ = writer.Write([]byte("succeed"))
writer.WriteHeader(http.StatusOK)
},
},
})
testCase := []struct {
name string
path string
code int
}{
{
name: "URI : /foo",
path: "/foo",
code: http.StatusOK,
},
{
name: "URI : /bar",
path: "/bar",
code: http.StatusOK,
},
{
name: "URI : undefined path",
path: "/test",
code: http.StatusNotFound,
},
{
name: "URI : /user/:name",
path: "/user/abc",
code: http.StatusOK,
},
}
for _, test := range testCase {
t.Run(test.name, func(t *testing.T) {
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", test.path, nil)
svr.ServeHTTP(w, req)
assert.Equal(t, test.code, w.Code)
})
}
}

View File

@@ -51,7 +51,7 @@ var (
Use: "new",
Short: "Fast create api service",
Example: "goctl api new [options] service-name",
Args: cobra.ExactValidArgs(1),
Args: cobra.MatchAll(cobra.ExactArgs(1), cobra.OnlyValidArgs),
RunE: func(cmd *cobra.Command, args []string) error {
return new.CreateServiceCommand(args)
},

View File

@@ -8,7 +8,6 @@ import (
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
"github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/pkg/golang"
"github.com/zeromicro/go-zero/tools/goctl/util"
"github.com/zeromicro/go-zero/tools/goctl/util/format"
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
@@ -41,14 +40,10 @@ func genHandler(dir, rootPkg string, cfg *config.Config, group spec.Group, route
handler = strings.Title(handler)
logicName = pkgName
}
parentPkg, err := golang.GetParentPackage(dir)
if err != nil {
return err
}
return doGenToFile(dir, handler, cfg, group, route, handlerInfo{
PkgName: pkgName,
ImportPackages: genHandlerImports(group, route, parentPkg),
ImportPackages: genHandlerImports(group, route, rootPkg),
HandlerName: handler,
RequestType: util.Title(route.RequestTypeName()),
LogicName: logicName,

View File

@@ -2,6 +2,7 @@ package spec_test
import (
"fmt"
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
)

View File

@@ -1,9 +1,10 @@
package tsgen
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
"testing"
)
func TestGenTsType(t *testing.T) {

View File

@@ -9,7 +9,7 @@ import (
var rootCmd = &cobra.Command{
Use: "compare",
Short: "Compare the goctl commands generated results between urfave and cobra",
Args: cobra.ExactValidArgs(1),
Args: cobra.MatchAll(cobra.ExactArgs(1), cobra.OnlyValidArgs),
Run: func(cmd *cobra.Command, args []string) {
dir := args[0]
testdata.MustRun(dir)

View File

@@ -6,7 +6,7 @@ package server
import (
"context"
"github.com/zeromicro/go-zero/tools/goctl/example/rpc/hello/internal/logic/greet"
greetlogic "github.com/zeromicro/go-zero/tools/goctl/example/rpc/hello/internal/logic/greet"
"github.com/zeromicro/go-zero/tools/goctl/example/rpc/hello/internal/svc"
"github.com/zeromicro/go-zero/tools/goctl/example/rpc/hello/pb/hello"
)

View File

@@ -7,10 +7,11 @@
package hello
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
)
const (

Some files were not shown because too many files have changed in this diff Show More