mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-12 01:10:00 +08:00
Compare commits
184 Commits
v1.5.0
...
tools/goct
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
50e29e2075 | ||
|
|
452c9dbcaf | ||
|
|
3564e36a35 | ||
|
|
e479e47634 | ||
|
|
ad921a6419 | ||
|
|
44c8d6f269 | ||
|
|
8a4cc4f98d | ||
|
|
e751736516 | ||
|
|
032f2419a2 | ||
|
|
84adc054bc | ||
|
|
b92e706ce1 | ||
|
|
1b5946346e | ||
|
|
28d3905731 | ||
|
|
3726851c7f | ||
|
|
2f2ddd373b | ||
|
|
8d48e34eed | ||
|
|
32f78668db | ||
|
|
cd0f3726ed | ||
|
|
0217044900 | ||
|
|
8b4382dcec | ||
|
|
fa33329a44 | ||
|
|
d76a39ac26 | ||
|
|
76a7a17e57 | ||
|
|
4a2a8d9e45 | ||
|
|
ef26b39b4c | ||
|
|
3ca40001b4 | ||
|
|
278ae3d26a | ||
|
|
fa1d6d50a8 | ||
|
|
0f4973be06 | ||
|
|
a9aac7e420 | ||
|
|
925cf8d3d1 | ||
|
|
99ce24e2ab | ||
|
|
701bb31ed2 | ||
|
|
55e2c7ee83 | ||
|
|
90839965fa | ||
|
|
f7228e9af1 | ||
|
|
f95adae3c1 | ||
|
|
bff5b81ad9 | ||
|
|
f0bdfb928f | ||
|
|
e4a1b7bb39 | ||
|
|
b6906b5d21 | ||
|
|
116da96178 | ||
|
|
9fa98c2bd3 | ||
|
|
b1c4c4736f | ||
|
|
ef410e8083 | ||
|
|
c22bc1c8ea | ||
|
|
1853428011 | ||
|
|
3637e10815 | ||
|
|
93124329ac | ||
|
|
851a72f1cc | ||
|
|
a93c24ce84 | ||
|
|
9f42eda9ff | ||
|
|
8762a3b7ba | ||
|
|
2684a157ff | ||
|
|
63368d8b0c | ||
|
|
4f13fe8188 | ||
|
|
9fc7874336 | ||
|
|
e6518521eb | ||
|
|
8f5a0a2de7 | ||
|
|
774e8d1d08 | ||
|
|
8ad0668612 | ||
|
|
8a043d2443 | ||
|
|
0e2ee97a02 | ||
|
|
42300a7d83 | ||
|
|
fe97fab274 | ||
|
|
f93e752f98 | ||
|
|
3a66fc038f | ||
|
|
b028ed058d | ||
|
|
1fd0c3992b | ||
|
|
1aebb3e5e4 | ||
|
|
8ffe4c01d1 | ||
|
|
a31256b327 | ||
|
|
14caf5c799 | ||
|
|
c0f8a58ed7 | ||
|
|
3189ec7be6 | ||
|
|
f51e9f0ea7 | ||
|
|
ba9d510cdb | ||
|
|
8c9b619199 | ||
|
|
49f73265b9 | ||
|
|
7568674b2b | ||
|
|
3da740b7fc | ||
|
|
ce4eb6ed61 | ||
|
|
9970ff55cd | ||
|
|
d10740f871 | ||
|
|
027193dc99 | ||
|
|
de1e0f2410 | ||
|
|
062073ce58 | ||
|
|
e20b02f311 | ||
|
|
02357d2616 | ||
|
|
489d69f779 | ||
|
|
117611a170 | ||
|
|
0a46ad7ac1 | ||
|
|
bf905eaff3 | ||
|
|
88cb35e3d5 | ||
|
|
078825b4eb | ||
|
|
bbfce6abe9 | ||
|
|
0d11ce03a8 | ||
|
|
757ed19dc5 | ||
|
|
c5fd074aac | ||
|
|
8fa0bd1f1c | ||
|
|
ede19a89ec | ||
|
|
73664b92f0 | ||
|
|
8d9c2fa22a | ||
|
|
22fad4bb9c | ||
|
|
189e9bd9da | ||
|
|
98c9b5928a | ||
|
|
e13fd62d38 | ||
|
|
ffacae89eb | ||
|
|
49135fe25e | ||
|
|
2e6402f4b5 | ||
|
|
07f03ebd0c | ||
|
|
92f2676afc | ||
|
|
1807305e6d | ||
|
|
38a97d4531 | ||
|
|
b9f98ecc4a | ||
|
|
1dc222f4b2 | ||
|
|
a79b8de24d | ||
|
|
5da8a93c75 | ||
|
|
b49fc81618 | ||
|
|
6a692453dc | ||
|
|
8d0cceb80c | ||
|
|
e06abf4f6f | ||
|
|
ee555a85da | ||
|
|
1904af2323 | ||
|
|
95b85336d6 | ||
|
|
ca4ce7bce8 | ||
|
|
9065eb90d9 | ||
|
|
50bc361430 | ||
|
|
455a6c8f97 | ||
|
|
04434646eb | ||
|
|
992a56e90b | ||
|
|
ed4d5e5813 | ||
|
|
fe85e7cb42 | ||
|
|
9c6b516bb8 | ||
|
|
2e9063a9a1 | ||
|
|
c3648be533 | ||
|
|
0ab06f62ca | ||
|
|
6170d7b790 | ||
|
|
18d163c4f7 | ||
|
|
a561048d59 | ||
|
|
7a647ca40c | ||
|
|
3f6f14f976 | ||
|
|
a78d57bebd | ||
|
|
74452eb7b5 | ||
|
|
a9e364a01a | ||
|
|
29c2e20b41 | ||
|
|
42c146bcbd | ||
|
|
b61e364458 | ||
|
|
18a4dcb79f | ||
|
|
60a13f1e53 | ||
|
|
3e093bf34e | ||
|
|
211b9498ef | ||
|
|
cca45be3c5 | ||
|
|
e735915d89 | ||
|
|
f77e2c9cfa | ||
|
|
544aa7c432 | ||
|
|
4cef2b412c | ||
|
|
123c61ad12 | ||
|
|
fbf129d535 | ||
|
|
c8a17a97be | ||
|
|
3a493cd6a6 | ||
|
|
7a0c04bc21 | ||
|
|
3c9fe0b381 | ||
|
|
f8b2dc8c9f | ||
|
|
37cb00d789 | ||
|
|
e3e7bc736b | ||
|
|
fafbee24b8 | ||
|
|
8ec29d29ce | ||
|
|
cb7f3e8a17 | ||
|
|
03391b48ca | ||
|
|
d0dedb0624 | ||
|
|
e136deb3a7 | ||
|
|
a2592a17e9 | ||
|
|
05abf4a2ff | ||
|
|
d40000d4b9 | ||
|
|
4620924105 | ||
|
|
a05fe7bf0a | ||
|
|
dd347e96b0 | ||
|
|
a972f400c6 | ||
|
|
fb7664a764 | ||
|
|
7d5d7d9085 | ||
|
|
9911c11e9c | ||
|
|
0d5a68869d | ||
|
|
d9d79e930d |
@@ -3,4 +3,7 @@ comment:
|
||||
behavior: once
|
||||
require_changes: true
|
||||
ignore:
|
||||
- "tools"
|
||||
- "tools"
|
||||
- "**/mock"
|
||||
- "**/*_mock.go"
|
||||
- "**/*test"
|
||||
|
||||
2
.github/FUNDING.yml
vendored
2
.github/FUNDING.yml
vendored
@@ -10,4 +10,4 @@ liberapay: # Replace with a single Liberapay username
|
||||
issuehunt: # Replace with a single IssueHunt username
|
||||
otechie: # Replace with a single Otechie username
|
||||
custom: # https://gitee.com/kevwan/static/raw/master/images/sponsor.jpg
|
||||
ethereum: 0x5052b7f6B937B02563996D23feb69b38D06Ca150 | kevwan
|
||||
ethereum: # 0x5052b7f6B937B02563996D23feb69b38D06Ca150 | kevwan
|
||||
|
||||
12
.github/workflows/go.yml
vendored
12
.github/workflows/go.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: ^1.18
|
||||
go-version: 1.18
|
||||
check-latest: true
|
||||
cache: true
|
||||
id: go
|
||||
@@ -29,8 +29,12 @@ jobs:
|
||||
- name: Lint
|
||||
run: |
|
||||
go vet -stdmethods=false $(go list ./...)
|
||||
go install mvdan.cc/gofumpt@latest
|
||||
test -z "$(gofumpt -l -extra .)" || echo "Please run 'gofumpt -l -w -extra .'"
|
||||
|
||||
go mod tidy
|
||||
if ! test -z "$(git status --porcelain)"; then
|
||||
echo "Please run 'go mod tidy'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Test
|
||||
run: go test -race -coverprofile=coverage.txt -covermode=atomic ./...
|
||||
@@ -57,5 +61,5 @@ jobs:
|
||||
run: |
|
||||
go mod verify
|
||||
go mod download
|
||||
go test -v -race ./...
|
||||
go test ./...
|
||||
cd tools/goctl && go build -v goctl.go
|
||||
|
||||
2
.github/workflows/release.yaml
vendored
2
.github/workflows/release.yaml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
goos: ${{ matrix.goos }}
|
||||
goarch: ${{ matrix.goarch }}
|
||||
goversion: "https://dl.google.com/go/go1.17.5.linux-amd64.tar.gz"
|
||||
goversion: "https://dl.google.com/go/go1.18.10.linux-amd64.tar.gz"
|
||||
project_path: "tools/goctl"
|
||||
binary_name: "goctl"
|
||||
extra_files: tools/goctl/readme.md tools/goctl/readme-cn.md
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -14,9 +14,10 @@
|
||||
**/.idea
|
||||
**/.DS_Store
|
||||
**/logs
|
||||
**/adhoc
|
||||
**/coverage.txt
|
||||
|
||||
# for test purpose
|
||||
**/adhoc
|
||||
go.work
|
||||
go.work.sum
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package bloom
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strconv"
|
||||
|
||||
@@ -8,28 +9,29 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
)
|
||||
|
||||
const (
|
||||
// for detailed error rate table, see http://pages.cs.wisc.edu/~cao/papers/summary-cache/node8.html
|
||||
// maps as k in the error rate table
|
||||
maps = 14
|
||||
setScript = `
|
||||
// for detailed error rate table, see http://pages.cs.wisc.edu/~cao/papers/summary-cache/node8.html
|
||||
// maps as k in the error rate table
|
||||
const maps = 14
|
||||
|
||||
var (
|
||||
// ErrTooLargeOffset indicates the offset is too large in bitset.
|
||||
ErrTooLargeOffset = errors.New("too large offset")
|
||||
|
||||
setScript = redis.NewScript(`
|
||||
for _, offset in ipairs(ARGV) do
|
||||
redis.call("setbit", KEYS[1], offset, 1)
|
||||
end
|
||||
`
|
||||
testScript = `
|
||||
`)
|
||||
testScript = redis.NewScript(`
|
||||
for _, offset in ipairs(ARGV) do
|
||||
if tonumber(redis.call("getbit", KEYS[1], offset)) == 0 then
|
||||
return false
|
||||
end
|
||||
end
|
||||
return true
|
||||
`
|
||||
`)
|
||||
)
|
||||
|
||||
// ErrTooLargeOffset indicates the offset is too large in bitset.
|
||||
var ErrTooLargeOffset = errors.New("too large offset")
|
||||
|
||||
type (
|
||||
// A Filter is a bloom filter.
|
||||
Filter struct {
|
||||
@@ -38,8 +40,8 @@ type (
|
||||
}
|
||||
|
||||
bitSetProvider interface {
|
||||
check([]uint) (bool, error)
|
||||
set([]uint) error
|
||||
check(ctx context.Context, offsets []uint) (bool, error)
|
||||
set(ctx context.Context, offsets []uint) error
|
||||
}
|
||||
)
|
||||
|
||||
@@ -58,14 +60,24 @@ func New(store *redis.Redis, key string, bits uint) *Filter {
|
||||
|
||||
// Add adds data into f.
|
||||
func (f *Filter) Add(data []byte) error {
|
||||
return f.AddCtx(context.Background(), data)
|
||||
}
|
||||
|
||||
// AddCtx adds data into f with context.
|
||||
func (f *Filter) AddCtx(ctx context.Context, data []byte) error {
|
||||
locations := f.getLocations(data)
|
||||
return f.bitSet.set(locations)
|
||||
return f.bitSet.set(ctx, locations)
|
||||
}
|
||||
|
||||
// Exists checks if data is in f.
|
||||
func (f *Filter) Exists(data []byte) (bool, error) {
|
||||
return f.ExistsCtx(context.Background(), data)
|
||||
}
|
||||
|
||||
// ExistsCtx checks if data is in f with context.
|
||||
func (f *Filter) ExistsCtx(ctx context.Context, data []byte) (bool, error) {
|
||||
locations := f.getLocations(data)
|
||||
isSet, err := f.bitSet.check(locations)
|
||||
isSet, err := f.bitSet.check(ctx, locations)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -111,13 +123,13 @@ func (r *redisBitSet) buildOffsetArgs(offsets []uint) ([]string, error) {
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func (r *redisBitSet) check(offsets []uint) (bool, error) {
|
||||
func (r *redisBitSet) check(ctx context.Context, offsets []uint) (bool, error) {
|
||||
args, err := r.buildOffsetArgs(offsets)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
resp, err := r.store.Eval(testScript, []string{r.key}, args)
|
||||
resp, err := r.store.ScriptRunCtx(ctx, testScript, []string{r.key}, args)
|
||||
if err == redis.Nil {
|
||||
return false, nil
|
||||
} else if err != nil {
|
||||
@@ -132,22 +144,24 @@ func (r *redisBitSet) check(offsets []uint) (bool, error) {
|
||||
return exists == 1, nil
|
||||
}
|
||||
|
||||
// del only use for testing.
|
||||
func (r *redisBitSet) del() error {
|
||||
_, err := r.store.Del(r.key)
|
||||
return err
|
||||
}
|
||||
|
||||
// expire only use for testing.
|
||||
func (r *redisBitSet) expire(seconds int) error {
|
||||
return r.store.Expire(r.key, seconds)
|
||||
}
|
||||
|
||||
func (r *redisBitSet) set(offsets []uint) error {
|
||||
func (r *redisBitSet) set(ctx context.Context, offsets []uint) error {
|
||||
args, err := r.buildOffsetArgs(offsets)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = r.store.Eval(setScript, []string{r.key}, args)
|
||||
_, err = r.store.ScriptRunCtx(ctx, setScript, []string{r.key}, args)
|
||||
if err == redis.Nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,30 +1,31 @@
|
||||
package bloom
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis/redistest"
|
||||
)
|
||||
|
||||
func TestRedisBitSet_New_Set_Test(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean()
|
||||
store := redistest.CreateRedis(t)
|
||||
ctx := context.Background()
|
||||
|
||||
bitSet := newRedisBitSet(store, "test_key", 1024)
|
||||
isSetBefore, err := bitSet.check([]uint{0})
|
||||
isSetBefore, err := bitSet.check(ctx, []uint{0})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if isSetBefore {
|
||||
t.Fatal("Bit should not be set")
|
||||
}
|
||||
err = bitSet.set([]uint{512})
|
||||
err = bitSet.set(ctx, []uint{512})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
isSetAfter, err := bitSet.check([]uint{512})
|
||||
isSetAfter, err := bitSet.check(ctx, []uint{512})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -42,9 +43,7 @@ func TestRedisBitSet_New_Set_Test(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRedisBitSet_Add(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean()
|
||||
store := redistest.CreateRedis(t)
|
||||
|
||||
filter := New(store, "test_key", 64)
|
||||
assert.Nil(t, filter.Add([]byte("hello")))
|
||||
@@ -53,3 +52,51 @@ func TestRedisBitSet_Add(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
func TestFilter_Exists(t *testing.T) {
|
||||
store, clean := redistest.CreateRedisWithClean(t)
|
||||
|
||||
rbs := New(store, "test", 64)
|
||||
_, err := rbs.Exists([]byte{0, 1, 2})
|
||||
assert.NoError(t, err)
|
||||
|
||||
clean()
|
||||
rbs = New(store, "test", 64)
|
||||
_, err = rbs.Exists([]byte{0, 1, 2})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRedisBitSet_check(t *testing.T) {
|
||||
store, clean := redistest.CreateRedisWithClean(t)
|
||||
ctx := context.Background()
|
||||
|
||||
rbs := newRedisBitSet(store, "test", 0)
|
||||
assert.Error(t, rbs.set(ctx, []uint{0, 1, 2}))
|
||||
_, err := rbs.check(ctx, []uint{0, 1, 2})
|
||||
assert.Error(t, err)
|
||||
|
||||
rbs = newRedisBitSet(store, "test", 64)
|
||||
_, err = rbs.check(ctx, []uint{0, 1, 2})
|
||||
assert.NoError(t, err)
|
||||
|
||||
clean()
|
||||
rbs = newRedisBitSet(store, "test", 64)
|
||||
_, err = rbs.check(ctx, []uint{0, 1, 2})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRedisBitSet_set(t *testing.T) {
|
||||
logx.Disable()
|
||||
store, clean := redistest.CreateRedisWithClean(t)
|
||||
ctx := context.Background()
|
||||
|
||||
rbs := newRedisBitSet(store, "test", 0)
|
||||
assert.Error(t, rbs.set(ctx, []uint{0, 1, 2}))
|
||||
|
||||
rbs = newRedisBitSet(store, "test", 64)
|
||||
assert.NoError(t, rbs.set(ctx, []uint{0, 1, 2}))
|
||||
|
||||
clean()
|
||||
rbs = newRedisBitSet(store, "test", 64)
|
||||
assert.Error(t, rbs.set(ctx, []uint{0, 1, 2}))
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package codec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
@@ -21,3 +23,45 @@ func TestGzip(t *testing.T) {
|
||||
assert.True(t, len(bs) < buf.Len())
|
||||
assert.Equal(t, buf.Bytes(), actual)
|
||||
}
|
||||
|
||||
func TestGunzip(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
expected []byte
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "valid input",
|
||||
input: func() []byte {
|
||||
var buf bytes.Buffer
|
||||
gz := gzip.NewWriter(&buf)
|
||||
gz.Write([]byte("hello"))
|
||||
gz.Close()
|
||||
return buf.Bytes()
|
||||
}(),
|
||||
expected: []byte("hello"),
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid input",
|
||||
input: []byte("invalid input"),
|
||||
expected: nil,
|
||||
expectedErr: gzip.ErrHeader,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result, err := Gunzip(test.input)
|
||||
|
||||
if !bytes.Equal(result, test.expected) {
|
||||
t.Errorf("unexpected result: %v", result)
|
||||
}
|
||||
|
||||
if !errors.Is(err, test.expectedErr) {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package codec
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -41,6 +42,7 @@ func TestCryption(t *testing.T) {
|
||||
|
||||
file, err := fs.TempFilenameWithText(priKey)
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(file)
|
||||
dec, err := NewRsaDecrypter(file)
|
||||
assert.Nil(t, err)
|
||||
actual, err := dec.Decrypt(ret)
|
||||
|
||||
@@ -13,7 +13,10 @@ import (
|
||||
"github.com/zeromicro/go-zero/internal/encoding"
|
||||
)
|
||||
|
||||
const jsonTagKey = "json"
|
||||
const (
|
||||
jsonTagKey = "json"
|
||||
jsonTagSep = ','
|
||||
)
|
||||
|
||||
var (
|
||||
fillDefaultUnmarshaler = mapping.NewUnmarshaler(jsonTagKey, mapping.WithDefault())
|
||||
@@ -70,13 +73,13 @@ func LoadConfig(file string, v any, opts ...Option) error {
|
||||
|
||||
// LoadFromJsonBytes loads config into v from content json bytes.
|
||||
func LoadFromJsonBytes(content []byte, v any) error {
|
||||
info, err := buildFieldsInfo(reflect.TypeOf(v))
|
||||
info, err := buildFieldsInfo(reflect.TypeOf(v), "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var m map[string]any
|
||||
if err := jsonx.Unmarshal(content, &m); err != nil {
|
||||
if err = jsonx.Unmarshal(content, &m); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -124,13 +127,13 @@ func MustLoad(path string, v any, opts ...Option) {
|
||||
}
|
||||
}
|
||||
|
||||
func addOrMergeFields(info *fieldInfo, key string, child *fieldInfo) error {
|
||||
func addOrMergeFields(info *fieldInfo, key string, child *fieldInfo, fullName string) error {
|
||||
if prev, ok := info.children[key]; ok {
|
||||
if child.mapField != nil {
|
||||
return newDupKeyError(key)
|
||||
return newConflictKeyError(fullName)
|
||||
}
|
||||
|
||||
if err := mergeFields(prev, key, child.children); err != nil {
|
||||
if err := mergeFields(prev, key, child.children, fullName); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
@@ -140,27 +143,27 @@ func addOrMergeFields(info *fieldInfo, key string, child *fieldInfo) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error {
|
||||
func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type, fullName string) error {
|
||||
switch ft.Kind() {
|
||||
case reflect.Struct:
|
||||
fields, err := buildFieldsInfo(ft)
|
||||
fields, err := buildFieldsInfo(ft, fullName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for k, v := range fields.children {
|
||||
if err = addOrMergeFields(info, k, v); err != nil {
|
||||
if err = addOrMergeFields(info, k, v, fullName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case reflect.Map:
|
||||
elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
|
||||
elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem()), fullName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, ok := info.children[lowerCaseName]; ok {
|
||||
return newDupKeyError(lowerCaseName)
|
||||
return newConflictKeyError(fullName)
|
||||
}
|
||||
|
||||
info.children[lowerCaseName] = &fieldInfo{
|
||||
@@ -169,7 +172,7 @@ func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.T
|
||||
}
|
||||
default:
|
||||
if _, ok := info.children[lowerCaseName]; ok {
|
||||
return newDupKeyError(lowerCaseName)
|
||||
return newConflictKeyError(fullName)
|
||||
}
|
||||
|
||||
info.children[lowerCaseName] = &fieldInfo{
|
||||
@@ -180,14 +183,14 @@ func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.T
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
|
||||
func buildFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error) {
|
||||
tp = mapping.Deref(tp)
|
||||
|
||||
switch tp.Kind() {
|
||||
case reflect.Struct:
|
||||
return buildStructFieldsInfo(tp)
|
||||
return buildStructFieldsInfo(tp, fullName)
|
||||
case reflect.Array, reflect.Slice:
|
||||
return buildFieldsInfo(mapping.Deref(tp.Elem()))
|
||||
return buildFieldsInfo(mapping.Deref(tp.Elem()), fullName)
|
||||
case reflect.Chan, reflect.Func:
|
||||
return nil, fmt.Errorf("unsupported type: %s", tp.Kind())
|
||||
default:
|
||||
@@ -197,23 +200,23 @@ func buildFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error {
|
||||
func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type, fullName string) error {
|
||||
var finfo *fieldInfo
|
||||
var err error
|
||||
|
||||
switch ft.Kind() {
|
||||
case reflect.Struct:
|
||||
finfo, err = buildFieldsInfo(ft)
|
||||
finfo, err = buildFieldsInfo(ft, fullName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case reflect.Array, reflect.Slice:
|
||||
finfo, err = buildFieldsInfo(ft.Elem())
|
||||
finfo, err = buildFieldsInfo(ft.Elem(), fullName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case reflect.Map:
|
||||
elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
|
||||
elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem()), fullName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -223,31 +226,37 @@ func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type)
|
||||
mapField: elemInfo,
|
||||
}
|
||||
default:
|
||||
finfo, err = buildFieldsInfo(ft)
|
||||
finfo, err = buildFieldsInfo(ft, fullName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return addOrMergeFields(info, lowerCaseName, finfo)
|
||||
return addOrMergeFields(info, lowerCaseName, finfo, fullName)
|
||||
}
|
||||
|
||||
func buildStructFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
|
||||
func buildStructFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error) {
|
||||
info := &fieldInfo{
|
||||
children: make(map[string]*fieldInfo),
|
||||
}
|
||||
|
||||
for i := 0; i < tp.NumField(); i++ {
|
||||
field := tp.Field(i)
|
||||
name := field.Name
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
name := getTagName(field)
|
||||
lowerCaseName := toLowerCase(name)
|
||||
ft := mapping.Deref(field.Type)
|
||||
// flatten anonymous fields
|
||||
if field.Anonymous {
|
||||
if err := buildAnonymousFieldInfo(info, lowerCaseName, ft); err != nil {
|
||||
if err := buildAnonymousFieldInfo(info, lowerCaseName, ft,
|
||||
getFullName(fullName, lowerCaseName)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if err := buildNamedFieldInfo(info, lowerCaseName, ft); err != nil {
|
||||
} else if err := buildNamedFieldInfo(info, lowerCaseName, ft,
|
||||
getFullName(fullName, lowerCaseName)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -255,15 +264,32 @@ func buildStructFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
|
||||
return info, nil
|
||||
}
|
||||
|
||||
func mergeFields(prev *fieldInfo, key string, children map[string]*fieldInfo) error {
|
||||
// getTagName get the tag name of the given field, if no tag name, use file.Name.
|
||||
// field.Name is returned on tags like `json:""` and `json:",optional"`.
|
||||
func getTagName(field reflect.StructField) string {
|
||||
if tag, ok := field.Tag.Lookup(jsonTagKey); ok {
|
||||
if pos := strings.IndexByte(tag, jsonTagSep); pos >= 0 {
|
||||
tag = tag[:pos]
|
||||
}
|
||||
|
||||
tag = strings.TrimSpace(tag)
|
||||
if len(tag) > 0 {
|
||||
return tag
|
||||
}
|
||||
}
|
||||
|
||||
return field.Name
|
||||
}
|
||||
|
||||
func mergeFields(prev *fieldInfo, key string, children map[string]*fieldInfo, fullName string) error {
|
||||
if len(prev.children) == 0 || len(children) == 0 {
|
||||
return newDupKeyError(key)
|
||||
return newConflictKeyError(fullName)
|
||||
}
|
||||
|
||||
// merge fields
|
||||
for k, v := range children {
|
||||
if _, ok := prev.children[k]; ok {
|
||||
return newDupKeyError(k)
|
||||
return newConflictKeyError(fullName)
|
||||
}
|
||||
|
||||
prev.children[k] = v
|
||||
@@ -314,14 +340,22 @@ func toLowerCaseKeyMap(m map[string]any, info *fieldInfo) map[string]any {
|
||||
return res
|
||||
}
|
||||
|
||||
type dupKeyError struct {
|
||||
type conflictKeyError struct {
|
||||
key string
|
||||
}
|
||||
|
||||
func newDupKeyError(key string) dupKeyError {
|
||||
return dupKeyError{key: key}
|
||||
func newConflictKeyError(key string) conflictKeyError {
|
||||
return conflictKeyError{key: key}
|
||||
}
|
||||
|
||||
func (e dupKeyError) Error() string {
|
||||
return fmt.Sprintf("duplicated key %s", e.key)
|
||||
func (e conflictKeyError) Error() string {
|
||||
return fmt.Sprintf("conflict key %s, pay attention to anonymous fields", e.key)
|
||||
}
|
||||
|
||||
func getFullName(parent, child string) string {
|
||||
if len(parent) == 0 {
|
||||
return child
|
||||
}
|
||||
|
||||
return strings.Join([]string{parent, child}, ".")
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package conf
|
||||
|
||||
import (
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -9,7 +10,7 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/hash"
|
||||
)
|
||||
|
||||
var dupErr dupKeyError
|
||||
var dupErr conflictKeyError
|
||||
|
||||
func TestLoadConfig_notExists(t *testing.T) {
|
||||
assert.NotNil(t, Load("not_a_file", nil))
|
||||
@@ -34,11 +35,11 @@ func TestConfigJson(t *testing.T) {
|
||||
"c": "${FOO}",
|
||||
"d": "abcd!@#$112"
|
||||
}`
|
||||
t.Setenv("FOO", "2")
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test, func(t *testing.T) {
|
||||
os.Setenv("FOO", "2")
|
||||
defer os.Unsetenv("FOO")
|
||||
tmpfile, err := createTempFile(test, text)
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(tmpfile)
|
||||
@@ -80,8 +81,7 @@ b = 1
|
||||
c = "${FOO}"
|
||||
d = "abcd!@#$112"
|
||||
`
|
||||
os.Setenv("FOO", "2")
|
||||
defer os.Unsetenv("FOO")
|
||||
t.Setenv("FOO", "2")
|
||||
tmpfile, err := createTempFile(".toml", text)
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(tmpfile)
|
||||
@@ -123,6 +123,24 @@ d = "abcd"
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigWithLower(t *testing.T) {
|
||||
text := `a = "foo"
|
||||
b = 1
|
||||
`
|
||||
tmpfile, err := createTempFile(".toml", text)
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(tmpfile)
|
||||
|
||||
var val struct {
|
||||
A string `json:"a"`
|
||||
b int
|
||||
}
|
||||
if assert.NoError(t, Load(tmpfile, &val)) {
|
||||
assert.Equal(t, "foo", val.A)
|
||||
assert.Equal(t, 0, val.b)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigJsonCanonical(t *testing.T) {
|
||||
text := []byte(`{"a": "foo", "B": "bar"}`)
|
||||
|
||||
@@ -188,8 +206,7 @@ b = 1
|
||||
c = "${FOO}"
|
||||
d = "abcd!@#112"
|
||||
`
|
||||
os.Setenv("FOO", "2")
|
||||
defer os.Unsetenv("FOO")
|
||||
t.Setenv("FOO", "2")
|
||||
tmpfile, err := createTempFile(".toml", text)
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(tmpfile)
|
||||
@@ -220,11 +237,10 @@ func TestConfigJsonEnv(t *testing.T) {
|
||||
"c": "${FOO}",
|
||||
"d": "abcd!@#$a12 3"
|
||||
}`
|
||||
t.Setenv("FOO", "2")
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test, func(t *testing.T) {
|
||||
os.Setenv("FOO", "2")
|
||||
defer os.Unsetenv("FOO")
|
||||
tmpfile, err := createTempFile(test, text)
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(tmpfile)
|
||||
@@ -672,7 +688,7 @@ func Test_FieldOverwrite(t *testing.T) {
|
||||
input := []byte(`{"Name": "hello"}`)
|
||||
err := LoadFromJsonBytes(input, val)
|
||||
assert.ErrorAs(t, err, &dupErr)
|
||||
assert.Equal(t, newDupKeyError("name").Error(), err.Error())
|
||||
assert.Equal(t, newConflictKeyError("name").Error(), err.Error())
|
||||
}
|
||||
|
||||
validate(&St1{})
|
||||
@@ -715,7 +731,7 @@ func Test_FieldOverwrite(t *testing.T) {
|
||||
input := []byte(`{"Name": "hello"}`)
|
||||
err := LoadFromJsonBytes(input, val)
|
||||
assert.ErrorAs(t, err, &dupErr)
|
||||
assert.Equal(t, newDupKeyError("name").Error(), err.Error())
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
validate(&St0{})
|
||||
@@ -1022,22 +1038,22 @@ func TestLoadNamedFieldOverwritten(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func createTempFile(ext, text string) (string, error) {
|
||||
tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
func TestLoadLowerMemberShouldNotConflict(t *testing.T) {
|
||||
type (
|
||||
Redis struct {
|
||||
db uint
|
||||
}
|
||||
|
||||
if err := os.WriteFile(tmpfile.Name(), []byte(text), os.ModeTemporary); err != nil {
|
||||
return "", err
|
||||
}
|
||||
Config struct {
|
||||
db uint
|
||||
Redis
|
||||
}
|
||||
)
|
||||
|
||||
filename := tmpfile.Name()
|
||||
if err = tmpfile.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filename, nil
|
||||
var c Config
|
||||
assert.NoError(t, LoadFromJsonBytes([]byte(`{}`), &c))
|
||||
assert.Zero(t, c.db)
|
||||
assert.Zero(t, c.Redis.db)
|
||||
}
|
||||
|
||||
func TestFillDefaultUnmarshal(t *testing.T) {
|
||||
@@ -1079,7 +1095,7 @@ func TestFillDefaultUnmarshal(t *testing.T) {
|
||||
assert.Equal(t, st.C, "c")
|
||||
})
|
||||
|
||||
t.Run("has vaue", func(t *testing.T) {
|
||||
t.Run("has value", func(t *testing.T) {
|
||||
type St struct {
|
||||
A string `json:",default=a"`
|
||||
B string
|
||||
@@ -1091,3 +1107,201 @@ func TestFillDefaultUnmarshal(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfigWithJsonTag(t *testing.T) {
|
||||
t.Run("map with value", func(t *testing.T) {
|
||||
var input = []byte(`[Value]
|
||||
[Value.first]
|
||||
Email = "foo"
|
||||
[Value.second]
|
||||
Email = "bar"`)
|
||||
|
||||
type Value struct {
|
||||
Email string
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
ValueMap map[string]Value `json:"Value"`
|
||||
}
|
||||
|
||||
var c Config
|
||||
if assert.NoError(t, LoadFromTomlBytes(input, &c)) {
|
||||
assert.Len(t, c.ValueMap, 2)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("map with ptr value", func(t *testing.T) {
|
||||
var input = []byte(`[Value]
|
||||
[Value.first]
|
||||
Email = "foo"
|
||||
[Value.second]
|
||||
Email = "bar"`)
|
||||
|
||||
type Value struct {
|
||||
Email string
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
ValueMap map[string]*Value `json:"Value"`
|
||||
}
|
||||
|
||||
var c Config
|
||||
if assert.NoError(t, LoadFromTomlBytes(input, &c)) {
|
||||
assert.Len(t, c.ValueMap, 2)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("map with optional", func(t *testing.T) {
|
||||
var input = []byte(`[Value]
|
||||
[Value.first]
|
||||
Email = "foo"
|
||||
[Value.second]
|
||||
Email = "bar"`)
|
||||
|
||||
type Value struct {
|
||||
Email string
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Value map[string]Value `json:",optional"`
|
||||
}
|
||||
|
||||
var c Config
|
||||
if assert.NoError(t, LoadFromTomlBytes(input, &c)) {
|
||||
assert.Len(t, c.Value, 2)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("map with empty tag", func(t *testing.T) {
|
||||
var input = []byte(`[Value]
|
||||
[Value.first]
|
||||
Email = "foo"
|
||||
[Value.second]
|
||||
Email = "bar"`)
|
||||
|
||||
type Value struct {
|
||||
Email string
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Value map[string]Value `json:" "`
|
||||
}
|
||||
|
||||
var c Config
|
||||
if assert.NoError(t, LoadFromTomlBytes(input, &c)) {
|
||||
assert.Len(t, c.Value, 2)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_getFullName(t *testing.T) {
|
||||
assert.Equal(t, "a.b", getFullName("a", "b"))
|
||||
assert.Equal(t, "a", getFullName("", "a"))
|
||||
}
|
||||
|
||||
func Test_buildFieldsInfo(t *testing.T) {
|
||||
type ParentSt struct {
|
||||
Name string
|
||||
M map[string]int
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
t reflect.Type
|
||||
ok bool
|
||||
containsKey string
|
||||
}{
|
||||
{
|
||||
name: "normal",
|
||||
t: reflect.TypeOf(struct{ A string }{}),
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "struct anonymous",
|
||||
t: reflect.TypeOf(struct {
|
||||
ParentSt
|
||||
Name string
|
||||
}{}),
|
||||
ok: false,
|
||||
containsKey: newConflictKeyError("name").Error(),
|
||||
},
|
||||
{
|
||||
name: "struct ptr anonymous",
|
||||
t: reflect.TypeOf(struct {
|
||||
*ParentSt
|
||||
Name string
|
||||
}{}),
|
||||
ok: false,
|
||||
containsKey: newConflictKeyError("name").Error(),
|
||||
},
|
||||
{
|
||||
name: "more struct anonymous",
|
||||
t: reflect.TypeOf(struct {
|
||||
Value struct {
|
||||
ParentSt
|
||||
Name string
|
||||
}
|
||||
}{}),
|
||||
ok: false,
|
||||
containsKey: newConflictKeyError("value.name").Error(),
|
||||
},
|
||||
{
|
||||
name: "map anonymous",
|
||||
t: reflect.TypeOf(struct {
|
||||
ParentSt
|
||||
M string
|
||||
}{}),
|
||||
ok: false,
|
||||
containsKey: newConflictKeyError("m").Error(),
|
||||
},
|
||||
{
|
||||
name: "map more anonymous",
|
||||
t: reflect.TypeOf(struct {
|
||||
Value struct {
|
||||
ParentSt
|
||||
M string
|
||||
}
|
||||
}{}),
|
||||
ok: false,
|
||||
containsKey: newConflictKeyError("value.m").Error(),
|
||||
},
|
||||
{
|
||||
name: "struct slice anonymous",
|
||||
t: reflect.TypeOf([]struct {
|
||||
ParentSt
|
||||
Name string
|
||||
}{}),
|
||||
ok: false,
|
||||
containsKey: newConflictKeyError("name").Error(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := buildFieldsInfo(tt.t, "")
|
||||
if tt.ok {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, err.Error(), tt.containsKey)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createTempFile(ext, text string) (string, error) {
|
||||
tmpFile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := os.WriteFile(tmpFile.Name(), []byte(text), os.ModeTemporary); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
filename := tmpFile.Name()
|
||||
if err = tmpFile.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filename, nil
|
||||
}
|
||||
|
||||
@@ -45,8 +45,7 @@ func TestPropertiesEnv(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(tmpfile)
|
||||
|
||||
os.Setenv("FOO", "2")
|
||||
defer os.Unsetenv("FOO")
|
||||
t.Setenv("FOO", "2")
|
||||
|
||||
props, err := LoadProperties(tmpfile, UseEnv())
|
||||
assert.Nil(t, err)
|
||||
|
||||
@@ -13,6 +13,7 @@ var (
|
||||
type EtcdConf struct {
|
||||
Hosts []string
|
||||
Key string
|
||||
ID int64 `json:",optional"`
|
||||
User string `json:",optional"`
|
||||
Pass string `json:",optional"`
|
||||
CertFile string `json:",optional"`
|
||||
@@ -26,6 +27,11 @@ func (c EtcdConf) HasAccount() bool {
|
||||
return len(c.User) > 0 && len(c.Pass) > 0
|
||||
}
|
||||
|
||||
// HasID returns if ID provided.
|
||||
func (c EtcdConf) HasID() bool {
|
||||
return c.ID > 0
|
||||
}
|
||||
|
||||
// HasTLS returns if TLS CertFile/CertKeyFile/CACertFile are provided.
|
||||
func (c EtcdConf) HasTLS() bool {
|
||||
return len(c.CertFile) > 0 && len(c.CertKeyFile) > 0 && len(c.CACertFile) > 0
|
||||
|
||||
@@ -80,3 +80,90 @@ func TestEtcdConf_HasAccount(t *testing.T) {
|
||||
assert.Equal(t, test.hasAccount, test.EtcdConf.HasAccount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestEtcdConf_HasID(t *testing.T) {
|
||||
tests := []struct {
|
||||
EtcdConf
|
||||
hasServerID bool
|
||||
}{
|
||||
{
|
||||
EtcdConf: EtcdConf{
|
||||
Hosts: []string{"any"},
|
||||
ID: -1,
|
||||
},
|
||||
hasServerID: false,
|
||||
},
|
||||
{
|
||||
EtcdConf: EtcdConf{
|
||||
Hosts: []string{"any"},
|
||||
ID: 0,
|
||||
},
|
||||
hasServerID: false,
|
||||
},
|
||||
{
|
||||
EtcdConf: EtcdConf{
|
||||
Hosts: []string{"any"},
|
||||
ID: 10000,
|
||||
},
|
||||
hasServerID: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
assert.Equal(t, test.hasServerID, test.EtcdConf.HasID())
|
||||
}
|
||||
}
|
||||
|
||||
func TestEtcdConf_HasTLS(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
conf EtcdConf
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "empty config",
|
||||
conf: EtcdConf{},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "missing CertFile",
|
||||
conf: EtcdConf{
|
||||
CertKeyFile: "key",
|
||||
CACertFile: "ca",
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "missing CertKeyFile",
|
||||
conf: EtcdConf{
|
||||
CertFile: "cert",
|
||||
CACertFile: "ca",
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "missing CACertFile",
|
||||
conf: EtcdConf{
|
||||
CertFile: "cert",
|
||||
CertKeyFile: "key",
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "valid config",
|
||||
conf: EtcdConf{
|
||||
CertFile: "cert",
|
||||
CertKeyFile: "key",
|
||||
CACertFile: "ca",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.conf.HasTLS()
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,85 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
const (
|
||||
certContent = `-----BEGIN CERTIFICATE-----
|
||||
MIIDazCCAlOgAwIBAgIUEg9GVO2oaPn+YSmiqmFIuAo10WIwDQYJKoZIhvcNAQEM
|
||||
BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
|
||||
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAgFw0yMzAzMTExMzIxMjNaGA8yMTIz
|
||||
MDIxNTEzMjEyM1owRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx
|
||||
ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcN
|
||||
AQEBBQADggEPADCCAQoCggEBALplXlWsIf0O/IgnIplmiZHKGnxyfyufyE2FBRNk
|
||||
OofRqbKuPH8GNqbkvZm7N29fwTDAQ+mViAggCkDht4hOzoWJMA7KYJt8JnTSWL48
|
||||
M1lcrpc9DL2gszC/JF/FGvyANbBtLklkZPFBGdHUX14pjrT937wqPtm+SqUHSvRT
|
||||
B7bmwmm2drRcmhpVm98LSlV7uQ2EgnJgsLjBPITKUejLmVLHfgX0RwQ2xIpX9pS4
|
||||
FCe1BTacwl2gGp7Mje7y4Mfv3o0ArJW6Tuwbjx59ZXwb1KIP71b7bT04AVS8ZeYO
|
||||
UMLKKuB5UR9x9Rn6cLXOTWBpcMVyzDgrAFLZjnE9LPUolZMCAwEAAaNRME8wHwYD
|
||||
VR0jBBgwFoAUeW8w8pmhncbRgTsl48k4/7wnfx8wCQYDVR0TBAIwADALBgNVHQ8E
|
||||
BAMCBPAwFAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBDAUAA4IBAQAI
|
||||
y9xaoS88CLPBsX6mxfcTAFVfGNTRW9VN9Ng1cCnUR+YGoXGM/l+qP4f7p8ocdGwK
|
||||
iYZErVTzXYIn+D27//wpY3klJk3gAnEUBT3QRkStBw7XnpbeZ2oPBK+cmDnCnZPS
|
||||
BIF1wxPX7vIgaxs5Zsdqwk3qvZ4Djr2wP7LabNWTLSBKgQoUY45Liw6pffLwcGF9
|
||||
UKlu54bvGze2SufISCR3ib+I+FLvqpvJhXToZWYb/pfI/HccuCL1oot1x8vx6DQy
|
||||
U+TYxlZsKS5mdNxAX3dqEkEMsgEi+g/tzDPXJImfeCGGBhIOXLm8SRypiuGdEbc9
|
||||
xkWYxRPegajuEZGvCqVs
|
||||
-----END CERTIFICATE-----`
|
||||
keyContent = `-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEowIBAAKCAQEAumVeVawh/Q78iCcimWaJkcoafHJ/K5/ITYUFE2Q6h9Gpsq48
|
||||
fwY2puS9mbs3b1/BMMBD6ZWICCAKQOG3iE7OhYkwDspgm3wmdNJYvjwzWVyulz0M
|
||||
vaCzML8kX8Ua/IA1sG0uSWRk8UEZ0dRfXimOtP3fvCo+2b5KpQdK9FMHtubCabZ2
|
||||
tFyaGlWb3wtKVXu5DYSCcmCwuME8hMpR6MuZUsd+BfRHBDbEilf2lLgUJ7UFNpzC
|
||||
XaAansyN7vLgx+/ejQCslbpO7BuPHn1lfBvUog/vVvttPTgBVLxl5g5Qwsoq4HlR
|
||||
H3H1Gfpwtc5NYGlwxXLMOCsAUtmOcT0s9SiVkwIDAQABAoIBAD5meTJNMgO55Kjg
|
||||
ESExxpRcCIno+tHr5+6rvYtEXqPheOIsmmwb9Gfi4+Z3WpOaht5/Pz0Ppj6yGzyl
|
||||
U//6AgGKb+BDuBvVcDpjwPnOxZIBCSHwejdxeQu0scSuA97MPS0XIAvJ5FEv7ijk
|
||||
5Bht6SyGYURpECltHygoTNuGgGqmO+McCJRLE9L09lTBI6UQ/JQwWJqSr7wx6iPU
|
||||
M1Ze/srIV+7cyEPu6i0DGjS1gSQKkX68Lqn1w6oE290O+OZvleO0gZ02fLDWCZke
|
||||
aeD9+EU/Pw+rqm3H6o0szOFIpzhRp41FUdW9sybB3Yp3u7c/574E+04Z/e30LMKs
|
||||
TCtE1QECgYEA3K7KIpw0NH2HXL5C3RHcLmr204xeBfS70riBQQuVUgYdmxak2ima
|
||||
80RInskY8hRhSGTg0l+VYIH8cmjcUyqMSOELS5XfRH99r4QPiK8AguXg80T4VumY
|
||||
W3Pf+zEC2ssgP/gYthV0g0Xj5m2QxktOF9tRw5nkg739ZR4dI9lm/iECgYEA2Dnf
|
||||
uwEDGqHiQRF6/fh5BG/nGVMvrefkqx6WvTJQ3k/M/9WhxB+lr/8yH46TuS8N2b29
|
||||
FoTf3Mr9T7pr/PWkOPzoY3P56nYbKU8xSwCim9xMzhBMzj8/N9ukJvXy27/VOz56
|
||||
eQaKqnvdXNGtPJrIMDGHps2KKWlKLyAlapzjVTMCgYAA/W++tACv85g13EykfT4F
|
||||
n0k4LbsGP9DP4zABQLIMyiY72eAncmRVjwrcW36XJ2xATOONTgx3gF3HjZzfaqNy
|
||||
eD/6uNNllUTVEryXGmHgNHPL45VRnn6memCY2eFvZdXhM5W4y2PYaunY0MkDercA
|
||||
+GTngbs6tBF88KOk04bYwQKBgFl68cRgsdkmnwwQYNaTKfmVGYzYaQXNzkqmWPko
|
||||
xmCJo6tHzC7ubdG8iRCYHzfmahPuuj6EdGPZuSRyYFgJi5Ftz/nAN+84OxtIQ3zn
|
||||
YWOgskQgaLh9YfsKsQ7Sf1NDOsnOnD5TX7UXl07fEpLe9vNCvAFiU8e5Y9LGudU5
|
||||
4bYTAoGBAMdX3a3bXp4cZvXNBJ/QLVyxC6fP1Q4haCR1Od3m+T00Jth2IX2dk/fl
|
||||
p6xiJT1av5JtYabv1dFKaXOS5s1kLGGuCCSKpkvFZm826aQ2AFm0XGqEQDLeei5b
|
||||
A52Kpy/YJ+RkG4BTFtAooFq6DmA0cnoP6oPvG2h6XtDJwDTPInJb
|
||||
-----END RSA PRIVATE KEY-----`
|
||||
caContent = `-----BEGIN CERTIFICATE-----
|
||||
MIIDbTCCAlWgAwIBAgIUBJvFoCowKich7MMfseJ+DYzzirowDQYJKoZIhvcNAQEM
|
||||
BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
|
||||
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAgFw0yMzAzMTExMzIxMDNaGA8yMTIz
|
||||
MDIxNTEzMjEwM1owRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx
|
||||
ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcN
|
||||
AQEBBQADggEPADCCAQoCggEBAO4to2YMYj0bxgr2FCiweSTSFuPx33zSw2x/s9Wf
|
||||
OR41bm2DFsyYT5f3sOIKlXZEdLmOKty2e3ho3yC0EyNpVHdykkkHT3aDI17quZax
|
||||
kYi/URqqtl1Z08A22txolc04hAZisg2BypGi3vql81UW1t3zyloGnJoIAeXR9uca
|
||||
ljP6Bk3bwsxoVBLi1JtHrO0hHLQaeHmKhAyrys06X0LRdn7Px48yRZlt6FaLSa8X
|
||||
YiRM0G44bVy/h6BkoQjMYGwVmCVk6zjJ9U7ZPFqdnDMNxAfR+hjDnYodqdLDMTTR
|
||||
1NPVrnEnNwFx0AMLvgt/ba/45vZCEAmSZnFXFAJJcM7ai9ECAwEAAaNTMFEwHQYD
|
||||
VR0OBBYEFHlvMPKZoZ3G0YE7JePJOP+8J38fMB8GA1UdIwQYMBaAFHlvMPKZoZ3G
|
||||
0YE7JePJOP+8J38fMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQEMBQADggEB
|
||||
AMX8dNulADOo9uQgBMyFb9TVra7iY0zZjzv4GY5XY7scd52n6CnfAPvYBBDnTr/O
|
||||
BgNp5jaujb4+9u/2qhV3f9n+/3WOb2CmPehBgVSzlXqHeQ9lshmgwZPeem2T+8Tm
|
||||
Nnc/xQnsUfCFszUDxpkr55+aLVM22j02RWqcZ4q7TAaVYL+kdFVMc8FoqG/0ro6A
|
||||
BjE/Qn0Nn7ciX1VUjDt8l+k7ummPJTmzdi6i6E4AwO9dzrGNgGJ4aWL8cC6xYcIX
|
||||
goVIRTFeONXSDno/oPjWHpIPt7L15heMpKBHNuzPkKx2YVqPHE5QZxWfS+Lzgx+Q
|
||||
E2oTTM0rYKOZ8p6000mhvKI=
|
||||
-----END CERTIFICATE-----`
|
||||
)
|
||||
|
||||
func TestAccount(t *testing.T) {
|
||||
endpoints := []string{
|
||||
"192.168.0.2:2379",
|
||||
@@ -32,3 +105,34 @@ func TestAccount(t *testing.T) {
|
||||
assert.Equal(t, username, account.User)
|
||||
assert.Equal(t, anotherPassword, account.Pass)
|
||||
}
|
||||
|
||||
func TestTLSMethods(t *testing.T) {
|
||||
certFile := createTempFile(t, []byte(certContent))
|
||||
defer os.Remove(certFile)
|
||||
keyFile := createTempFile(t, []byte(keyContent))
|
||||
defer os.Remove(keyFile)
|
||||
caFile := createTempFile(t, []byte(caContent))
|
||||
defer os.Remove(caFile)
|
||||
|
||||
assert.NoError(t, AddTLS([]string{"foo"}, certFile, keyFile, caFile, false))
|
||||
cfg, ok := GetTLS([]string{"foo"})
|
||||
assert.True(t, ok)
|
||||
assert.NotNil(t, cfg)
|
||||
|
||||
assert.Error(t, AddTLS([]string{"bar"}, "bad-file", keyFile, caFile, false))
|
||||
assert.Error(t, AddTLS([]string{"bar"}, certFile, keyFile, "bad-file", false))
|
||||
}
|
||||
|
||||
func createTempFile(t *testing.T, body []byte) string {
|
||||
tmpFile, err := os.CreateTemp(os.TempDir(), "go-unit-*.tmp")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tmpFile.Close()
|
||||
if err = os.WriteFile(tmpFile.Name(), body, os.ModePerm); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return tmpFile.Name()
|
||||
}
|
||||
|
||||
@@ -337,13 +337,11 @@ func (c *cluster) watchConnState(cli EtcdClient) {
|
||||
// DialClient dials an etcd cluster with given endpoints.
|
||||
func DialClient(endpoints []string) (EtcdClient, error) {
|
||||
cfg := clientv3.Config{
|
||||
Endpoints: endpoints,
|
||||
AutoSyncInterval: autoSyncInterval,
|
||||
DialTimeout: DialTimeout,
|
||||
DialKeepAliveTime: dialKeepAliveTime,
|
||||
DialKeepAliveTimeout: DialTimeout,
|
||||
RejectOldCluster: true,
|
||||
PermitWithoutStream: true,
|
||||
Endpoints: endpoints,
|
||||
AutoSyncInterval: autoSyncInterval,
|
||||
DialTimeout: DialTimeout,
|
||||
RejectOldCluster: true,
|
||||
PermitWithoutStream: true,
|
||||
}
|
||||
if account, ok := GetAccount(endpoints); ok {
|
||||
cfg.Username = account.User
|
||||
|
||||
@@ -2,8 +2,10 @@ package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -14,6 +16,7 @@ import (
|
||||
"go.etcd.io/etcd/api/v3/etcdserverpb"
|
||||
"go.etcd.io/etcd/api/v3/mvccpb"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
"go.etcd.io/etcd/client/v3/mock/mockserver"
|
||||
)
|
||||
|
||||
var mockLock sync.Mutex
|
||||
@@ -242,3 +245,58 @@ func TestValueOnlyContext(t *testing.T) {
|
||||
ctx.Done()
|
||||
assert.Nil(t, ctx.Err())
|
||||
}
|
||||
|
||||
func TestDialClient(t *testing.T) {
|
||||
svr, err := mockserver.StartMockServers(1)
|
||||
assert.NoError(t, err)
|
||||
svr.StartAt(0)
|
||||
|
||||
certFile := createTempFile(t, []byte(certContent))
|
||||
defer os.Remove(certFile)
|
||||
keyFile := createTempFile(t, []byte(keyContent))
|
||||
defer os.Remove(keyFile)
|
||||
caFile := createTempFile(t, []byte(caContent))
|
||||
defer os.Remove(caFile)
|
||||
|
||||
endpoints := []string{svr.Servers[0].Address}
|
||||
AddAccount(endpoints, "foo", "bar")
|
||||
assert.NoError(t, AddTLS(endpoints, certFile, keyFile, caFile, false))
|
||||
|
||||
old := DialTimeout
|
||||
DialTimeout = time.Millisecond
|
||||
defer func() {
|
||||
DialTimeout = old
|
||||
}()
|
||||
_, err = DialClient(endpoints)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRegistry_Monitor(t *testing.T) {
|
||||
svr, err := mockserver.StartMockServers(1)
|
||||
assert.NoError(t, err)
|
||||
svr.StartAt(0)
|
||||
|
||||
endpoints := []string{svr.Servers[0].Address}
|
||||
GetRegistry().lock.Lock()
|
||||
GetRegistry().clusters = map[string]*cluster{
|
||||
getClusterKey(endpoints): {
|
||||
listeners: map[string][]UpdateListener{},
|
||||
values: map[string]map[string]string{
|
||||
"foo": {
|
||||
"bar": "baz",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
GetRegistry().lock.Unlock()
|
||||
assert.Error(t, GetRegistry().Monitor(endpoints, "foo", new(mockListener)))
|
||||
}
|
||||
|
||||
type mockListener struct {
|
||||
}
|
||||
|
||||
func (m *mockListener) OnAdd(_ KV) {
|
||||
}
|
||||
|
||||
func (m *mockListener) OnDelete(_ KV) {
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ const (
|
||||
autoSyncInterval = time.Minute
|
||||
coolDownInterval = time.Second
|
||||
dialTimeout = 5 * time.Second
|
||||
dialKeepAliveTime = 5 * time.Second
|
||||
requestTimeout = 3 * time.Second
|
||||
endpointsSeparator = ","
|
||||
)
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package discov
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -13,6 +16,83 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
"golang.org/x/net/http2"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/resolver"
|
||||
"google.golang.org/grpc/resolver/manual"
|
||||
)
|
||||
|
||||
const (
|
||||
certContent = `-----BEGIN CERTIFICATE-----
|
||||
MIIDazCCAlOgAwIBAgIUEg9GVO2oaPn+YSmiqmFIuAo10WIwDQYJKoZIhvcNAQEM
|
||||
BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
|
||||
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAgFw0yMzAzMTExMzIxMjNaGA8yMTIz
|
||||
MDIxNTEzMjEyM1owRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx
|
||||
ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcN
|
||||
AQEBBQADggEPADCCAQoCggEBALplXlWsIf0O/IgnIplmiZHKGnxyfyufyE2FBRNk
|
||||
OofRqbKuPH8GNqbkvZm7N29fwTDAQ+mViAggCkDht4hOzoWJMA7KYJt8JnTSWL48
|
||||
M1lcrpc9DL2gszC/JF/FGvyANbBtLklkZPFBGdHUX14pjrT937wqPtm+SqUHSvRT
|
||||
B7bmwmm2drRcmhpVm98LSlV7uQ2EgnJgsLjBPITKUejLmVLHfgX0RwQ2xIpX9pS4
|
||||
FCe1BTacwl2gGp7Mje7y4Mfv3o0ArJW6Tuwbjx59ZXwb1KIP71b7bT04AVS8ZeYO
|
||||
UMLKKuB5UR9x9Rn6cLXOTWBpcMVyzDgrAFLZjnE9LPUolZMCAwEAAaNRME8wHwYD
|
||||
VR0jBBgwFoAUeW8w8pmhncbRgTsl48k4/7wnfx8wCQYDVR0TBAIwADALBgNVHQ8E
|
||||
BAMCBPAwFAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBDAUAA4IBAQAI
|
||||
y9xaoS88CLPBsX6mxfcTAFVfGNTRW9VN9Ng1cCnUR+YGoXGM/l+qP4f7p8ocdGwK
|
||||
iYZErVTzXYIn+D27//wpY3klJk3gAnEUBT3QRkStBw7XnpbeZ2oPBK+cmDnCnZPS
|
||||
BIF1wxPX7vIgaxs5Zsdqwk3qvZ4Djr2wP7LabNWTLSBKgQoUY45Liw6pffLwcGF9
|
||||
UKlu54bvGze2SufISCR3ib+I+FLvqpvJhXToZWYb/pfI/HccuCL1oot1x8vx6DQy
|
||||
U+TYxlZsKS5mdNxAX3dqEkEMsgEi+g/tzDPXJImfeCGGBhIOXLm8SRypiuGdEbc9
|
||||
xkWYxRPegajuEZGvCqVs
|
||||
-----END CERTIFICATE-----`
|
||||
keyContent = `-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEowIBAAKCAQEAumVeVawh/Q78iCcimWaJkcoafHJ/K5/ITYUFE2Q6h9Gpsq48
|
||||
fwY2puS9mbs3b1/BMMBD6ZWICCAKQOG3iE7OhYkwDspgm3wmdNJYvjwzWVyulz0M
|
||||
vaCzML8kX8Ua/IA1sG0uSWRk8UEZ0dRfXimOtP3fvCo+2b5KpQdK9FMHtubCabZ2
|
||||
tFyaGlWb3wtKVXu5DYSCcmCwuME8hMpR6MuZUsd+BfRHBDbEilf2lLgUJ7UFNpzC
|
||||
XaAansyN7vLgx+/ejQCslbpO7BuPHn1lfBvUog/vVvttPTgBVLxl5g5Qwsoq4HlR
|
||||
H3H1Gfpwtc5NYGlwxXLMOCsAUtmOcT0s9SiVkwIDAQABAoIBAD5meTJNMgO55Kjg
|
||||
ESExxpRcCIno+tHr5+6rvYtEXqPheOIsmmwb9Gfi4+Z3WpOaht5/Pz0Ppj6yGzyl
|
||||
U//6AgGKb+BDuBvVcDpjwPnOxZIBCSHwejdxeQu0scSuA97MPS0XIAvJ5FEv7ijk
|
||||
5Bht6SyGYURpECltHygoTNuGgGqmO+McCJRLE9L09lTBI6UQ/JQwWJqSr7wx6iPU
|
||||
M1Ze/srIV+7cyEPu6i0DGjS1gSQKkX68Lqn1w6oE290O+OZvleO0gZ02fLDWCZke
|
||||
aeD9+EU/Pw+rqm3H6o0szOFIpzhRp41FUdW9sybB3Yp3u7c/574E+04Z/e30LMKs
|
||||
TCtE1QECgYEA3K7KIpw0NH2HXL5C3RHcLmr204xeBfS70riBQQuVUgYdmxak2ima
|
||||
80RInskY8hRhSGTg0l+VYIH8cmjcUyqMSOELS5XfRH99r4QPiK8AguXg80T4VumY
|
||||
W3Pf+zEC2ssgP/gYthV0g0Xj5m2QxktOF9tRw5nkg739ZR4dI9lm/iECgYEA2Dnf
|
||||
uwEDGqHiQRF6/fh5BG/nGVMvrefkqx6WvTJQ3k/M/9WhxB+lr/8yH46TuS8N2b29
|
||||
FoTf3Mr9T7pr/PWkOPzoY3P56nYbKU8xSwCim9xMzhBMzj8/N9ukJvXy27/VOz56
|
||||
eQaKqnvdXNGtPJrIMDGHps2KKWlKLyAlapzjVTMCgYAA/W++tACv85g13EykfT4F
|
||||
n0k4LbsGP9DP4zABQLIMyiY72eAncmRVjwrcW36XJ2xATOONTgx3gF3HjZzfaqNy
|
||||
eD/6uNNllUTVEryXGmHgNHPL45VRnn6memCY2eFvZdXhM5W4y2PYaunY0MkDercA
|
||||
+GTngbs6tBF88KOk04bYwQKBgFl68cRgsdkmnwwQYNaTKfmVGYzYaQXNzkqmWPko
|
||||
xmCJo6tHzC7ubdG8iRCYHzfmahPuuj6EdGPZuSRyYFgJi5Ftz/nAN+84OxtIQ3zn
|
||||
YWOgskQgaLh9YfsKsQ7Sf1NDOsnOnD5TX7UXl07fEpLe9vNCvAFiU8e5Y9LGudU5
|
||||
4bYTAoGBAMdX3a3bXp4cZvXNBJ/QLVyxC6fP1Q4haCR1Od3m+T00Jth2IX2dk/fl
|
||||
p6xiJT1av5JtYabv1dFKaXOS5s1kLGGuCCSKpkvFZm826aQ2AFm0XGqEQDLeei5b
|
||||
A52Kpy/YJ+RkG4BTFtAooFq6DmA0cnoP6oPvG2h6XtDJwDTPInJb
|
||||
-----END RSA PRIVATE KEY-----`
|
||||
caContent = `-----BEGIN CERTIFICATE-----
|
||||
MIIDbTCCAlWgAwIBAgIUBJvFoCowKich7MMfseJ+DYzzirowDQYJKoZIhvcNAQEM
|
||||
BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
|
||||
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAgFw0yMzAzMTExMzIxMDNaGA8yMTIz
|
||||
MDIxNTEzMjEwM1owRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx
|
||||
ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcN
|
||||
AQEBBQADggEPADCCAQoCggEBAO4to2YMYj0bxgr2FCiweSTSFuPx33zSw2x/s9Wf
|
||||
OR41bm2DFsyYT5f3sOIKlXZEdLmOKty2e3ho3yC0EyNpVHdykkkHT3aDI17quZax
|
||||
kYi/URqqtl1Z08A22txolc04hAZisg2BypGi3vql81UW1t3zyloGnJoIAeXR9uca
|
||||
ljP6Bk3bwsxoVBLi1JtHrO0hHLQaeHmKhAyrys06X0LRdn7Px48yRZlt6FaLSa8X
|
||||
YiRM0G44bVy/h6BkoQjMYGwVmCVk6zjJ9U7ZPFqdnDMNxAfR+hjDnYodqdLDMTTR
|
||||
1NPVrnEnNwFx0AMLvgt/ba/45vZCEAmSZnFXFAJJcM7ai9ECAwEAAaNTMFEwHQYD
|
||||
VR0OBBYEFHlvMPKZoZ3G0YE7JePJOP+8J38fMB8GA1UdIwQYMBaAFHlvMPKZoZ3G
|
||||
0YE7JePJOP+8J38fMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQEMBQADggEB
|
||||
AMX8dNulADOo9uQgBMyFb9TVra7iY0zZjzv4GY5XY7scd52n6CnfAPvYBBDnTr/O
|
||||
BgNp5jaujb4+9u/2qhV3f9n+/3WOb2CmPehBgVSzlXqHeQ9lshmgwZPeem2T+8Tm
|
||||
Nnc/xQnsUfCFszUDxpkr55+aLVM22j02RWqcZ4q7TAaVYL+kdFVMc8FoqG/0ro6A
|
||||
BjE/Qn0Nn7ciX1VUjDt8l+k7ummPJTmzdi6i6E4AwO9dzrGNgGJ4aWL8cC6xYcIX
|
||||
goVIRTFeONXSDno/oPjWHpIPt7L15heMpKBHNuzPkKx2YVqPHE5QZxWfS+Lzgx+Q
|
||||
E2oTTM0rYKOZ8p6000mhvKI=
|
||||
-----END CERTIFICATE-----`
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -37,7 +117,7 @@ func TestPublisher_register(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestPublisher_registerWithId(t *testing.T) {
|
||||
func TestPublisher_registerWithOptions(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
const id = 2
|
||||
@@ -49,7 +129,15 @@ func TestPublisher_registerWithId(t *testing.T) {
|
||||
ID: 1,
|
||||
}, nil)
|
||||
cli.EXPECT().Put(gomock.Any(), makeEtcdKey("thekey", id), "thevalue", gomock.Any())
|
||||
pub := NewPublisher(nil, "thekey", "thevalue", WithId(id))
|
||||
|
||||
certFile := createTempFile(t, []byte(certContent))
|
||||
defer os.Remove(certFile)
|
||||
keyFile := createTempFile(t, []byte(keyContent))
|
||||
defer os.Remove(keyFile)
|
||||
caFile := createTempFile(t, []byte(caContent))
|
||||
defer os.Remove(caFile)
|
||||
pub := NewPublisher(nil, "thekey", "thevalue", WithId(id),
|
||||
WithPubEtcdTLS(certFile, keyFile, caFile, true))
|
||||
_, err := pub.register(cli)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
@@ -169,3 +257,92 @@ func TestPublisher_Resume(t *testing.T) {
|
||||
}()
|
||||
<-publisher.resumeChan
|
||||
}
|
||||
|
||||
func TestPublisher_keepAliveAsync(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
const id clientv3.LeaseID = 1
|
||||
conn := createMockConn(t)
|
||||
defer conn.Close()
|
||||
cli := internal.NewMockEtcdClient(ctrl)
|
||||
cli.EXPECT().ActiveConnection().Return(conn).AnyTimes()
|
||||
cli.EXPECT().Close()
|
||||
defer cli.Close()
|
||||
cli.ActiveConnection()
|
||||
restore := setMockClient(cli)
|
||||
defer restore()
|
||||
cli.EXPECT().Ctx().AnyTimes()
|
||||
cli.EXPECT().KeepAlive(gomock.Any(), id)
|
||||
cli.EXPECT().Grant(gomock.Any(), timeToLive).Return(&clientv3.LeaseGrantResponse{
|
||||
ID: 1,
|
||||
}, nil)
|
||||
cli.EXPECT().Put(gomock.Any(), makeEtcdKey("thekey", int64(id)), "thevalue", gomock.Any())
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) {
|
||||
wg.Done()
|
||||
})
|
||||
pub := NewPublisher([]string{"the-endpoint"}, "thekey", "thevalue")
|
||||
pub.lease = id
|
||||
assert.Nil(t, pub.KeepAlive())
|
||||
pub.Stop()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func createMockConn(t *testing.T) *grpc.ClientConn {
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Error while listening. Err: %v", err)
|
||||
}
|
||||
defer lis.Close()
|
||||
lisAddr := resolver.Address{Addr: lis.Addr().String()}
|
||||
lisDone := make(chan struct{})
|
||||
dialDone := make(chan struct{})
|
||||
// 1st listener accepts the connection and then does nothing
|
||||
go func() {
|
||||
defer close(lisDone)
|
||||
conn, err := lis.Accept()
|
||||
if err != nil {
|
||||
t.Errorf("Error while accepting. Err: %v", err)
|
||||
return
|
||||
}
|
||||
framer := http2.NewFramer(conn, conn)
|
||||
if err := framer.WriteSettings(http2.Setting{}); err != nil {
|
||||
t.Errorf("Error while writing settings. Err: %v", err)
|
||||
return
|
||||
}
|
||||
<-dialDone // Close conn only after dial returns.
|
||||
}()
|
||||
|
||||
r := manual.NewBuilderWithScheme("whatever")
|
||||
r.InitialState(resolver.State{Addresses: []resolver.Address{lisAddr}})
|
||||
client, err := grpc.DialContext(context.Background(), r.Scheme()+":///test.server",
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(r))
|
||||
close(dialDone)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial failed. Err: %v", err)
|
||||
}
|
||||
|
||||
timeout := time.After(1 * time.Second)
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatal("timed out waiting for server to finish")
|
||||
case <-lisDone:
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func createTempFile(t *testing.T, body []byte) string {
|
||||
tmpFile, err := os.CreateTemp(os.TempDir(), "go-unit-*.tmp")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tmpFile.Close()
|
||||
if err = os.WriteFile(tmpFile.Name(), body, os.ModePerm); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return tmpFile.Name()
|
||||
}
|
||||
|
||||
@@ -78,7 +78,7 @@ func TestBulkExecutorFlush(t *testing.T) {
|
||||
wait.Wait()
|
||||
}
|
||||
|
||||
func TestBuldExecutorFlushSlowTasks(t *testing.T) {
|
||||
func TestBulkExecutorFlushSlowTasks(t *testing.T) {
|
||||
const total = 1500
|
||||
lock := new(sync.Mutex)
|
||||
result := make([]any, 0, 10000)
|
||||
|
||||
@@ -81,7 +81,7 @@ func (pe *PeriodicalExecutor) Flush() bool {
|
||||
}())
|
||||
}
|
||||
|
||||
// Sync lets caller to run fn thread-safe with pe, especially for the underlying container.
|
||||
// Sync lets caller run fn thread-safe with pe, especially for the underlying container.
|
||||
func (pe *PeriodicalExecutor) Sync(fn func()) {
|
||||
pe.lock.Lock()
|
||||
defer pe.lock.Unlock()
|
||||
@@ -116,7 +116,7 @@ func (pe *PeriodicalExecutor) addAndCheck(task any) (any, bool) {
|
||||
}
|
||||
|
||||
func (pe *PeriodicalExecutor) backgroundFlush() {
|
||||
threading.GoSafe(func() {
|
||||
go func() {
|
||||
// flush before quit goroutine to avoid missing tasks
|
||||
defer pe.Flush()
|
||||
|
||||
@@ -144,7 +144,7 @@ func (pe *PeriodicalExecutor) backgroundFlush() {
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}()
|
||||
}
|
||||
|
||||
func (pe *PeriodicalExecutor) doneExecution() {
|
||||
@@ -162,7 +162,9 @@ func (pe *PeriodicalExecutor) executeTasks(tasks any) bool {
|
||||
|
||||
ok := pe.hasTasks(tasks)
|
||||
if ok {
|
||||
pe.container.Execute(tasks)
|
||||
threading.RunSafe(func() {
|
||||
pe.container.Execute(tasks)
|
||||
})
|
||||
}
|
||||
|
||||
return ok
|
||||
|
||||
@@ -108,25 +108,83 @@ func TestPeriodicalExecutor_Bulk(t *testing.T) {
|
||||
lock.Unlock()
|
||||
}
|
||||
|
||||
func TestPeriodicalExecutor_Panic(t *testing.T) {
|
||||
// avoid data race
|
||||
var lock sync.Mutex
|
||||
ticker := timex.NewFakeTicker()
|
||||
|
||||
var (
|
||||
executedTasks []int
|
||||
expected []int
|
||||
)
|
||||
executor := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, func(tasks any) {
|
||||
tt := tasks.([]int)
|
||||
lock.Lock()
|
||||
executedTasks = append(executedTasks, tt...)
|
||||
lock.Unlock()
|
||||
if tt[0] == 0 {
|
||||
panic("test")
|
||||
}
|
||||
}))
|
||||
executor.newTicker = func(duration time.Duration) timex.Ticker {
|
||||
return ticker
|
||||
}
|
||||
for i := 0; i < 30; i++ {
|
||||
executor.Add(i)
|
||||
expected = append(expected, i)
|
||||
}
|
||||
ticker.Tick()
|
||||
ticker.Tick()
|
||||
time.Sleep(time.Millisecond)
|
||||
lock.Lock()
|
||||
assert.Equal(t, expected, executedTasks)
|
||||
lock.Unlock()
|
||||
}
|
||||
|
||||
func TestPeriodicalExecutor_FlushPanic(t *testing.T) {
|
||||
var (
|
||||
executedTasks []int
|
||||
expected []int
|
||||
lock sync.Mutex
|
||||
)
|
||||
executor := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, func(tasks any) {
|
||||
tt := tasks.([]int)
|
||||
lock.Lock()
|
||||
executedTasks = append(executedTasks, tt...)
|
||||
lock.Unlock()
|
||||
if tt[0] == 0 {
|
||||
panic("flush panic")
|
||||
}
|
||||
}))
|
||||
for i := 0; i < 8; i++ {
|
||||
executor.Add(i)
|
||||
expected = append(expected, i)
|
||||
}
|
||||
executor.Flush()
|
||||
lock.Lock()
|
||||
assert.Equal(t, expected, executedTasks)
|
||||
lock.Unlock()
|
||||
}
|
||||
|
||||
func TestPeriodicalExecutor_Wait(t *testing.T) {
|
||||
var lock sync.Mutex
|
||||
executer := NewBulkExecutor(func(tasks []any) {
|
||||
executor := NewBulkExecutor(func(tasks []any) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}, WithBulkTasks(1), WithBulkInterval(time.Second))
|
||||
for i := 0; i < 10; i++ {
|
||||
executer.Add(1)
|
||||
executor.Add(1)
|
||||
}
|
||||
executer.Flush()
|
||||
executer.Wait()
|
||||
executor.Flush()
|
||||
executor.Wait()
|
||||
}
|
||||
|
||||
func TestPeriodicalExecutor_WaitFast(t *testing.T) {
|
||||
const total = 3
|
||||
var cnt int
|
||||
var lock sync.Mutex
|
||||
executer := NewBulkExecutor(func(tasks []any) {
|
||||
executor := NewBulkExecutor(func(tasks []any) {
|
||||
defer func() {
|
||||
cnt++
|
||||
}()
|
||||
@@ -135,10 +193,10 @@ func TestPeriodicalExecutor_WaitFast(t *testing.T) {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}, WithBulkTasks(1), WithBulkInterval(10*time.Millisecond))
|
||||
for i := 0; i < total; i++ {
|
||||
executer.Add(2)
|
||||
executor.Add(2)
|
||||
}
|
||||
executer.Flush()
|
||||
executer.Wait()
|
||||
executor.Flush()
|
||||
executor.Wait()
|
||||
assert.Equal(t, total, cnt)
|
||||
}
|
||||
|
||||
@@ -151,13 +209,7 @@ func TestPeriodicalExecutor_Deadlock(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPeriodicalExecutor_hasTasks(t *testing.T) {
|
||||
ticker := timex.NewFakeTicker()
|
||||
defer ticker.Stop()
|
||||
|
||||
exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, nil))
|
||||
exec.newTicker = func(d time.Duration) timex.Ticker {
|
||||
return ticker
|
||||
}
|
||||
assert.False(t, exec.hasTasks(nil))
|
||||
assert.True(t, exec.hasTasks(1))
|
||||
}
|
||||
|
||||
@@ -74,6 +74,11 @@ func TestFirstLineShort(t *testing.T) {
|
||||
assert.Equal(t, "first line", val)
|
||||
}
|
||||
|
||||
func TestFirstLineError(t *testing.T) {
|
||||
_, err := FirstLine("/tmp/does-not-exist")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestLastLine(t *testing.T) {
|
||||
filename, err := fs.TempFilenameWithText(text)
|
||||
assert.Nil(t, err)
|
||||
@@ -113,3 +118,8 @@ func TestLastLineWithLastNewlineShort(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "last line", val)
|
||||
}
|
||||
|
||||
func TestLastLineError(t *testing.T) {
|
||||
_, err := LastLine("/tmp/does-not-exist")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package fs
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build linux || darwin
|
||||
// +build linux darwin
|
||||
|
||||
package fs
|
||||
|
||||
|
||||
@@ -11,29 +11,29 @@ import (
|
||||
// The file is kept as open, the caller should close the file handle,
|
||||
// and remove the file by name.
|
||||
func TempFileWithText(text string) (*os.File, error) {
|
||||
tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text)))
|
||||
tmpFile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := os.WriteFile(tmpfile.Name(), []byte(text), os.ModeTemporary); err != nil {
|
||||
if err := os.WriteFile(tmpFile.Name(), []byte(text), os.ModeTemporary); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tmpfile, nil
|
||||
return tmpFile, nil
|
||||
}
|
||||
|
||||
// TempFilenameWithText creates the file with the given content,
|
||||
// and returns the filename (full path).
|
||||
// The caller should remove the file after use.
|
||||
func TempFilenameWithText(text string) (string, error) {
|
||||
tmpfile, err := TempFileWithText(text)
|
||||
tmpFile, err := TempFileWithText(text)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
filename := tmpfile.Name()
|
||||
if err = tmpfile.Close(); err != nil {
|
||||
filename := tmpFile.Name()
|
||||
if err = tmpFile.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
|
||||
@@ -1,31 +1,87 @@
|
||||
package fx
|
||||
|
||||
import "github.com/zeromicro/go-zero/core/errorx"
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/errorx"
|
||||
)
|
||||
|
||||
const defaultRetryTimes = 3
|
||||
|
||||
var errTimeout = errors.New("retry timeout")
|
||||
|
||||
type (
|
||||
// RetryOption defines the method to customize DoWithRetry.
|
||||
RetryOption func(*retryOptions)
|
||||
|
||||
retryOptions struct {
|
||||
times int
|
||||
times int
|
||||
interval time.Duration
|
||||
timeout time.Duration
|
||||
}
|
||||
)
|
||||
|
||||
// DoWithRetry runs fn, and retries if failed. Default to retry 3 times.
|
||||
// Note that if the fn function accesses global variables outside the function
|
||||
// and performs modification operations, it is best to lock them,
|
||||
// otherwise there may be data race issues
|
||||
func DoWithRetry(fn func() error, opts ...RetryOption) error {
|
||||
return retry(func(errChan chan error, retryCount int) {
|
||||
errChan <- fn()
|
||||
}, opts...)
|
||||
}
|
||||
|
||||
// DoWithRetryCtx runs fn, and retries if failed. Default to retry 3 times.
|
||||
// fn retryCount indicates the current number of retries, starting from 0
|
||||
// Note that if the fn function accesses global variables outside the function
|
||||
// and performs modification operations, it is best to lock them,
|
||||
// otherwise there may be data race issues
|
||||
func DoWithRetryCtx(ctx context.Context, fn func(ctx context.Context, retryCount int) error,
|
||||
opts ...RetryOption) error {
|
||||
return retry(func(errChan chan error, retryCount int) {
|
||||
errChan <- fn(ctx, retryCount)
|
||||
}, opts...)
|
||||
}
|
||||
|
||||
func retry(fn func(errChan chan error, retryCount int), opts ...RetryOption) error {
|
||||
options := newRetryOptions()
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
var berr errorx.BatchError
|
||||
var cancelFunc context.CancelFunc
|
||||
ctx := context.Background()
|
||||
if options.timeout > 0 {
|
||||
ctx, cancelFunc = context.WithTimeout(ctx, options.timeout)
|
||||
defer cancelFunc()
|
||||
}
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
for i := 0; i < options.times; i++ {
|
||||
if err := fn(); err != nil {
|
||||
berr.Add(err)
|
||||
} else {
|
||||
return nil
|
||||
go fn(errChan, i)
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
if err != nil {
|
||||
berr.Add(err)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
case <-ctx.Done():
|
||||
berr.Add(errTimeout)
|
||||
return berr.Err()
|
||||
}
|
||||
|
||||
if options.interval > 0 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
berr.Add(errTimeout)
|
||||
return berr.Err()
|
||||
case <-time.After(options.interval):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,6 +95,18 @@ func WithRetry(times int) RetryOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithInterval(interval time.Duration) RetryOption {
|
||||
return func(options *retryOptions) {
|
||||
options.interval = interval
|
||||
}
|
||||
}
|
||||
|
||||
func WithTimeout(timeout time.Duration) RetryOption {
|
||||
return func(options *retryOptions) {
|
||||
options.timeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
func newRetryOptions() *retryOptions {
|
||||
return &retryOptions{
|
||||
times: defaultRetryTimes,
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package fx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
@@ -12,31 +14,103 @@ func TestRetry(t *testing.T) {
|
||||
return errors.New("any")
|
||||
}))
|
||||
|
||||
var times int
|
||||
times1 := 0
|
||||
assert.Nil(t, DoWithRetry(func() error {
|
||||
times++
|
||||
if times == defaultRetryTimes {
|
||||
times1++
|
||||
if times1 == defaultRetryTimes {
|
||||
return nil
|
||||
}
|
||||
return errors.New("any")
|
||||
}))
|
||||
|
||||
times = 0
|
||||
times2 := 0
|
||||
assert.NotNil(t, DoWithRetry(func() error {
|
||||
times++
|
||||
if times == defaultRetryTimes+1 {
|
||||
times2++
|
||||
if times2 == defaultRetryTimes+1 {
|
||||
return nil
|
||||
}
|
||||
return errors.New("any")
|
||||
}))
|
||||
|
||||
total := 2 * defaultRetryTimes
|
||||
times = 0
|
||||
times3 := 0
|
||||
assert.Nil(t, DoWithRetry(func() error {
|
||||
times++
|
||||
if times == total {
|
||||
times3++
|
||||
if times3 == total {
|
||||
return nil
|
||||
}
|
||||
return errors.New("any")
|
||||
}, WithRetry(total)))
|
||||
}
|
||||
|
||||
func TestRetryWithTimeout(t *testing.T) {
|
||||
assert.Nil(t, DoWithRetry(func() error {
|
||||
return nil
|
||||
}, WithTimeout(time.Millisecond*500)))
|
||||
|
||||
times1 := 0
|
||||
assert.Nil(t, DoWithRetry(func() error {
|
||||
times1++
|
||||
if times1 == 1 {
|
||||
return errors.New("any ")
|
||||
}
|
||||
time.Sleep(time.Millisecond * 150)
|
||||
return nil
|
||||
}, WithTimeout(time.Millisecond*250)))
|
||||
|
||||
total := defaultRetryTimes
|
||||
times2 := 0
|
||||
assert.Nil(t, DoWithRetry(func() error {
|
||||
times2++
|
||||
if times2 == total {
|
||||
return nil
|
||||
}
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
return errors.New("any")
|
||||
}, WithTimeout(time.Millisecond*50*(time.Duration(total)+2))))
|
||||
|
||||
assert.NotNil(t, DoWithRetry(func() error {
|
||||
return errors.New("any")
|
||||
}, WithTimeout(time.Millisecond*250)))
|
||||
}
|
||||
|
||||
func TestRetryWithInterval(t *testing.T) {
|
||||
times1 := 0
|
||||
assert.NotNil(t, DoWithRetry(func() error {
|
||||
times1++
|
||||
if times1 == 1 {
|
||||
return errors.New("any")
|
||||
}
|
||||
time.Sleep(time.Millisecond * 150)
|
||||
return nil
|
||||
}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
|
||||
|
||||
times2 := 0
|
||||
assert.NotNil(t, DoWithRetry(func() error {
|
||||
times2++
|
||||
if times2 == 2 {
|
||||
return nil
|
||||
}
|
||||
time.Sleep(time.Millisecond * 150)
|
||||
return errors.New("any ")
|
||||
}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
|
||||
|
||||
}
|
||||
|
||||
func TestRetryCtx(t *testing.T) {
|
||||
assert.NotNil(t, DoWithRetryCtx(context.Background(), func(ctx context.Context, retryCount int) error {
|
||||
if retryCount == 0 {
|
||||
return errors.New("any")
|
||||
}
|
||||
time.Sleep(time.Millisecond * 150)
|
||||
return nil
|
||||
}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
|
||||
|
||||
assert.NotNil(t, DoWithRetryCtx(context.Background(), func(ctx context.Context, retryCount int) error {
|
||||
if retryCount == 1 {
|
||||
return nil
|
||||
}
|
||||
time.Sleep(time.Millisecond * 150)
|
||||
return errors.New("any ")
|
||||
}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
|
||||
}
|
||||
|
||||
@@ -292,6 +292,18 @@ func (s Stream) Map(fn MapFunc, opts ...Option) Stream {
|
||||
}, opts...)
|
||||
}
|
||||
|
||||
// Max returns the maximum item from the underlying source.
|
||||
func (s Stream) Max(less LessFunc) any {
|
||||
var max any
|
||||
for item := range s.source {
|
||||
if max == nil || less(max, item) {
|
||||
max = item
|
||||
}
|
||||
}
|
||||
|
||||
return max
|
||||
}
|
||||
|
||||
// Merge merges all the items into a slice and generates a new stream.
|
||||
func (s Stream) Merge() Stream {
|
||||
var items []any
|
||||
@@ -306,6 +318,18 @@ func (s Stream) Merge() Stream {
|
||||
return Range(source)
|
||||
}
|
||||
|
||||
// Min returns the minimum item from the underlying source.
|
||||
func (s Stream) Min(less LessFunc) any {
|
||||
var min any
|
||||
for item := range s.source {
|
||||
if min == nil || less(item, min) {
|
||||
min = item
|
||||
}
|
||||
}
|
||||
|
||||
return min
|
||||
}
|
||||
|
||||
// NoneMatch returns whether all elements of this stream don't match the provided predicate.
|
||||
// May not evaluate the predicate on all elements if not necessary for determining the result.
|
||||
// If the stream is empty then true is returned and the predicate is not evaluated.
|
||||
|
||||
@@ -503,6 +503,83 @@ func TestStream_Concat(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestStream_Max(t *testing.T) {
|
||||
runCheckedTest(t, func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
elements []any
|
||||
max any
|
||||
}{
|
||||
{
|
||||
name: "no elements with nil",
|
||||
},
|
||||
{
|
||||
name: "no elements",
|
||||
elements: []any{},
|
||||
max: nil,
|
||||
},
|
||||
{
|
||||
name: "1 element",
|
||||
elements: []any{1},
|
||||
max: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple elements",
|
||||
elements: []any{1, 2, 9, 5, 8},
|
||||
max: 9,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
val := Just(test.elements...).Max(func(a, b any) bool {
|
||||
return a.(int) < b.(int)
|
||||
})
|
||||
assetEqual(t, test.max, val)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestStream_Min(t *testing.T) {
|
||||
runCheckedTest(t, func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
elements []any
|
||||
min any
|
||||
}{
|
||||
{
|
||||
name: "no elements with nil",
|
||||
min: nil,
|
||||
},
|
||||
{
|
||||
name: "no elements",
|
||||
elements: []any{},
|
||||
min: nil,
|
||||
},
|
||||
{
|
||||
name: "1 element",
|
||||
elements: []any{1},
|
||||
min: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple elements",
|
||||
elements: []any{-1, 1, 2, 9, 5, 8},
|
||||
min: -1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
val := Just(test.elements...).Min(func(a, b any) bool {
|
||||
return a.(int) < b.(int)
|
||||
})
|
||||
assetEqual(t, test.min, val)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkParallelMapReduce(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
|
||||
@@ -40,11 +40,11 @@ b`,
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.input, func(t *testing.T) {
|
||||
tmpfile, err := fs.TempFilenameWithText(test.input)
|
||||
tmpFile, err := fs.TempFilenameWithText(test.input)
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(tmpfile)
|
||||
defer os.Remove(tmpFile)
|
||||
|
||||
content, err := ReadText(tmpfile)
|
||||
content, err := ReadText(tmpFile)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, test.expect, content)
|
||||
})
|
||||
@@ -59,9 +59,9 @@ func TestReadTextLines(t *testing.T) {
|
||||
#a
|
||||
3`
|
||||
|
||||
tmpfile, err := fs.TempFilenameWithText(text)
|
||||
tmpFile, err := fs.TempFilenameWithText(text)
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(tmpfile)
|
||||
defer os.Remove(tmpFile)
|
||||
|
||||
tests := []struct {
|
||||
options []TextReadOption
|
||||
@@ -87,7 +87,7 @@ func TestReadTextLines(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(stringx.Rand(), func(t *testing.T) {
|
||||
lines, err := ReadTextLines(tmpfile, test.options...)
|
||||
lines, err := ReadTextLines(tmpFile, test.options...)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, test.expectLines, len(lines))
|
||||
})
|
||||
|
||||
@@ -9,21 +9,6 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
)
|
||||
|
||||
// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
|
||||
const periodScript = `local limit = tonumber(ARGV[1])
|
||||
local window = tonumber(ARGV[2])
|
||||
local current = redis.call("INCRBY", KEYS[1], 1)
|
||||
if current == 1 then
|
||||
redis.call("expire", KEYS[1], window)
|
||||
end
|
||||
if current < limit then
|
||||
return 1
|
||||
elseif current == limit then
|
||||
return 2
|
||||
else
|
||||
return 0
|
||||
end`
|
||||
|
||||
const (
|
||||
// Unknown means not initialized state.
|
||||
Unknown = iota
|
||||
@@ -39,8 +24,25 @@ const (
|
||||
internalHitQuota = 2
|
||||
)
|
||||
|
||||
// ErrUnknownCode is an error that represents unknown status code.
|
||||
var ErrUnknownCode = errors.New("unknown status code")
|
||||
var (
|
||||
// ErrUnknownCode is an error that represents unknown status code.
|
||||
ErrUnknownCode = errors.New("unknown status code")
|
||||
|
||||
// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
|
||||
periodScript = redis.NewScript(`local limit = tonumber(ARGV[1])
|
||||
local window = tonumber(ARGV[2])
|
||||
local current = redis.call("INCRBY", KEYS[1], 1)
|
||||
if current == 1 then
|
||||
redis.call("expire", KEYS[1], window)
|
||||
end
|
||||
if current < limit then
|
||||
return 1
|
||||
elseif current == limit then
|
||||
return 2
|
||||
else
|
||||
return 0
|
||||
end`)
|
||||
)
|
||||
|
||||
type (
|
||||
// PeriodOption defines the method to customize a PeriodLimit.
|
||||
@@ -80,7 +82,7 @@ func (h *PeriodLimit) Take(key string) (int, error) {
|
||||
|
||||
// TakeCtx requests a permit with context, it returns the permit state.
|
||||
func (h *PeriodLimit) TakeCtx(ctx context.Context, key string) (int, error) {
|
||||
resp, err := h.limitStore.EvalCtx(ctx, periodScript, []string{h.keyPrefix + key}, []string{
|
||||
resp, err := h.limitStore.ScriptRunCtx(ctx, periodScript, []string{h.keyPrefix + key}, []string{
|
||||
strconv.Itoa(h.quota),
|
||||
strconv.Itoa(h.calcExpireSeconds()),
|
||||
})
|
||||
|
||||
@@ -33,9 +33,7 @@ func TestPeriodLimit_RedisUnavailable(t *testing.T) {
|
||||
}
|
||||
|
||||
func testPeriodLimit(t *testing.T, opts ...PeriodOption) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean()
|
||||
store := redistest.CreateRedis(t)
|
||||
|
||||
const (
|
||||
seconds = 1
|
||||
|
||||
@@ -15,10 +15,15 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
|
||||
// KEYS[1] as tokens_key
|
||||
// KEYS[2] as timestamp_key
|
||||
script = `local rate = tonumber(ARGV[1])
|
||||
tokenFormat = "{%s}.tokens"
|
||||
timestampFormat = "{%s}.ts"
|
||||
pingInterval = time.Millisecond * 100
|
||||
)
|
||||
|
||||
// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
|
||||
// KEYS[1] as tokens_key
|
||||
// KEYS[2] as timestamp_key
|
||||
var script = redis.NewScript(`local rate = tonumber(ARGV[1])
|
||||
local capacity = tonumber(ARGV[2])
|
||||
local now = tonumber(ARGV[3])
|
||||
local requested = tonumber(ARGV[4])
|
||||
@@ -45,11 +50,7 @@ end
|
||||
redis.call("setex", KEYS[1], ttl, new_tokens)
|
||||
redis.call("setex", KEYS[2], ttl, now)
|
||||
|
||||
return allowed`
|
||||
tokenFormat = "{%s}.tokens"
|
||||
timestampFormat = "{%s}.ts"
|
||||
pingInterval = time.Millisecond * 100
|
||||
)
|
||||
return allowed`)
|
||||
|
||||
// A TokenLimiter controls how frequently events are allowed to happen with in one second.
|
||||
type TokenLimiter struct {
|
||||
@@ -110,7 +111,7 @@ func (lim *TokenLimiter) reserveN(ctx context.Context, now time.Time, n int) boo
|
||||
return lim.rescueLimiter.AllowN(now, n)
|
||||
}
|
||||
|
||||
resp, err := lim.store.EvalCtx(ctx,
|
||||
resp, err := lim.store.ScriptRunCtx(ctx,
|
||||
script,
|
||||
[]string{
|
||||
lim.tokenKey,
|
||||
|
||||
@@ -70,9 +70,7 @@ func TestTokenLimit_Rescue(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTokenLimit_Take(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean()
|
||||
store := redistest.CreateRedis(t)
|
||||
|
||||
const (
|
||||
total = 100
|
||||
@@ -92,9 +90,7 @@ func TestTokenLimit_Take(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTokenLimit_TakeBurst(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean()
|
||||
store := redistest.CreateRedis(t)
|
||||
|
||||
const (
|
||||
total = 100
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package logc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -11,14 +10,11 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/logx/logtest"
|
||||
)
|
||||
|
||||
func TestAddGlobalFields(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
buf := logtest.NewCollector(t)
|
||||
|
||||
Info(context.Background(), "hello")
|
||||
buf.Reset()
|
||||
@@ -34,155 +30,90 @@ func TestAddGlobalFields(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAlert(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
Alert(context.Background(), "foo")
|
||||
assert.True(t, strings.Contains(buf.String(), "foo"), buf.String())
|
||||
}
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
file, line := getFileLine()
|
||||
Error(context.Background(), "foo")
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestErrorf(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
file, line := getFileLine()
|
||||
Errorf(context.Background(), "foo %s", "bar")
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestErrorv(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
file, line := getFileLine()
|
||||
Errorv(context.Background(), "foo")
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestErrorw(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
file, line := getFileLine()
|
||||
Errorw(context.Background(), "foo", Field("a", "b"))
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestInfo(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
file, line := getFileLine()
|
||||
Info(context.Background(), "foo")
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestInfof(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
file, line := getFileLine()
|
||||
Infof(context.Background(), "foo %s", "bar")
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestInfov(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
file, line := getFileLine()
|
||||
Infov(context.Background(), "foo")
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestInfow(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
file, line := getFileLine()
|
||||
Infow(context.Background(), "foo", Field("a", "b"))
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestDebug(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
file, line := getFileLine()
|
||||
Debug(context.Background(), "foo")
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestDebugf(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
file, line := getFileLine()
|
||||
Debugf(context.Background(), "foo %s", "bar")
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestDebugv(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
file, line := getFileLine()
|
||||
Debugv(context.Background(), "foo")
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestDebugw(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
file, line := getFileLine()
|
||||
Debugw(context.Background(), "foo", Field("a", "b"))
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
@@ -204,48 +135,28 @@ func TestMisc(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSlow(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
file, line := getFileLine()
|
||||
Slow(context.Background(), "foo")
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())
|
||||
}
|
||||
|
||||
func TestSlowf(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
file, line := getFileLine()
|
||||
Slowf(context.Background(), "foo %s", "bar")
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())
|
||||
}
|
||||
|
||||
func TestSlowv(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
file, line := getFileLine()
|
||||
Slowv(context.Background(), "foo")
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())
|
||||
}
|
||||
|
||||
func TestSloww(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
file, line := getFileLine()
|
||||
Sloww(context.Background(), "foo", Field("a", "b"))
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())
|
||||
|
||||
@@ -23,7 +23,7 @@ type LogConf struct {
|
||||
MaxContentLength uint32 `json:",optional"`
|
||||
// Compress represents whether to compress the log file, default is `false`.
|
||||
Compress bool `json:",optional"`
|
||||
// Stdout represents whether to log statistics, default is `true`.
|
||||
// Stat represents whether to log statistics, default is `true`.
|
||||
Stat bool `json:",default=true"`
|
||||
// KeepDays represents how many days the log files will be kept. Default to keep all files.
|
||||
// Only take effect when Mode is `file` or `volume`, both work when Rotation is `daily` or `size`.
|
||||
@@ -32,13 +32,13 @@ type LogConf struct {
|
||||
StackCooldownMillis int `json:",default=100"`
|
||||
// MaxBackups represents how many backup log files will be kept. 0 means all files will be kept forever.
|
||||
// Only take effect when RotationRuleType is `size`.
|
||||
// Even thougth `MaxBackups` sets 0, log files will still be removed
|
||||
// Even though `MaxBackups` sets 0, log files will still be removed
|
||||
// if the `KeepDays` limitation is reached.
|
||||
MaxBackups int `json:",default=0"`
|
||||
// MaxSize represents how much space the writing log file takes up. 0 means no limit. The unit is `MB`.
|
||||
// Only take effect when RotationRuleType is `size`
|
||||
MaxSize int `json:",default=0"`
|
||||
// RotationRuleType represents the type of log rotation rule. Default is `daily`.
|
||||
// Rotation represents the type of log rotation rule. Default is `daily`.
|
||||
// daily: daily rotation.
|
||||
// size: size limited rotation.
|
||||
Rotation string `json:",default=daily,options=[daily,size]"`
|
||||
|
||||
@@ -197,7 +197,12 @@ func Must(err error) {
|
||||
msg := err.Error()
|
||||
log.Print(msg)
|
||||
getWriter().Severe(msg)
|
||||
os.Exit(1)
|
||||
|
||||
if ExitOnFatal.True() {
|
||||
os.Exit(1)
|
||||
} else {
|
||||
panic(msg)
|
||||
}
|
||||
}
|
||||
|
||||
// MustSetup sets up logging with given config c. It exits on error.
|
||||
@@ -353,14 +358,16 @@ func createOutput(path string) (io.WriteCloser, error) {
|
||||
return nil, ErrLogPathNotSet
|
||||
}
|
||||
|
||||
var rule RotateRule
|
||||
switch options.rotationRule {
|
||||
case sizeRotationRule:
|
||||
return NewLogger(path, NewSizeLimitRotateRule(path, backupFileDelimiter, options.keepDays,
|
||||
options.maxSize, options.maxBackups, options.gzipEnabled), options.gzipEnabled)
|
||||
rule = NewSizeLimitRotateRule(path, backupFileDelimiter, options.keepDays, options.maxSize,
|
||||
options.maxBackups, options.gzipEnabled)
|
||||
default:
|
||||
return NewLogger(path, DefaultRotateRule(path, backupFileDelimiter, options.keepDays,
|
||||
options.gzipEnabled), options.gzipEnabled)
|
||||
rule = DefaultRotateRule(path, backupFileDelimiter, options.keepDays, options.gzipEnabled)
|
||||
}
|
||||
|
||||
return NewLogger(path, rule, options.gzipEnabled)
|
||||
}
|
||||
|
||||
func getWriter() Writer {
|
||||
|
||||
@@ -24,6 +24,10 @@ var (
|
||||
_ Writer = (*mockWriter)(nil)
|
||||
)
|
||||
|
||||
func init() {
|
||||
ExitOnFatal.Set(false)
|
||||
}
|
||||
|
||||
type mockWriter struct {
|
||||
lock sync.Mutex
|
||||
builder strings.Builder
|
||||
@@ -208,6 +212,12 @@ func TestFileLineConsoleMode(t *testing.T) {
|
||||
assert.True(t, w.Contains(fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestMust(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
Must(errors.New("foo"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestStructedLogAlert(t *testing.T) {
|
||||
w := new(mockWriter)
|
||||
old := writer.Swap(w)
|
||||
@@ -574,26 +584,38 @@ func TestSetup(t *testing.T) {
|
||||
atomic.StoreUint32(&encoding, jsonEncodingType)
|
||||
}()
|
||||
|
||||
setupOnce = sync.Once{}
|
||||
MustSetup(LogConf{
|
||||
ServiceName: "any",
|
||||
Mode: "console",
|
||||
Encoding: "json",
|
||||
TimeFormat: timeFormat,
|
||||
})
|
||||
setupOnce = sync.Once{}
|
||||
MustSetup(LogConf{
|
||||
ServiceName: "any",
|
||||
Mode: "console",
|
||||
TimeFormat: timeFormat,
|
||||
})
|
||||
setupOnce = sync.Once{}
|
||||
MustSetup(LogConf{
|
||||
ServiceName: "any",
|
||||
Mode: "file",
|
||||
Path: os.TempDir(),
|
||||
})
|
||||
setupOnce = sync.Once{}
|
||||
MustSetup(LogConf{
|
||||
ServiceName: "any",
|
||||
Mode: "volume",
|
||||
Path: os.TempDir(),
|
||||
})
|
||||
setupOnce = sync.Once{}
|
||||
MustSetup(LogConf{
|
||||
ServiceName: "any",
|
||||
Mode: "console",
|
||||
TimeFormat: timeFormat,
|
||||
})
|
||||
setupOnce = sync.Once{}
|
||||
MustSetup(LogConf{
|
||||
ServiceName: "any",
|
||||
Mode: "console",
|
||||
|
||||
84
core/logx/logtest/logtest.go
Normal file
84
core/logx/logtest/logtest.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package logtest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
|
||||
type Buffer struct {
|
||||
buf *bytes.Buffer
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func Discard(t *testing.T) {
|
||||
prev := logx.Reset()
|
||||
logx.SetWriter(logx.NewWriter(io.Discard))
|
||||
|
||||
t.Cleanup(func() {
|
||||
logx.SetWriter(prev)
|
||||
})
|
||||
}
|
||||
|
||||
func NewCollector(t *testing.T) *Buffer {
|
||||
var buf bytes.Buffer
|
||||
writer := logx.NewWriter(&buf)
|
||||
prev := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
|
||||
t.Cleanup(func() {
|
||||
logx.SetWriter(prev)
|
||||
})
|
||||
|
||||
return &Buffer{
|
||||
buf: &buf,
|
||||
t: t,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Buffer) Bytes() []byte {
|
||||
return b.buf.Bytes()
|
||||
}
|
||||
|
||||
func (b *Buffer) Content() string {
|
||||
var m map[string]interface{}
|
||||
if err := json.Unmarshal(b.buf.Bytes(), &m); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
content, ok := m["content"]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch val := content.(type) {
|
||||
case string:
|
||||
return val
|
||||
default:
|
||||
// err is impossible to be not nil, unmarshaled from b.buf.Bytes()
|
||||
bs, _ := json.Marshal(content)
|
||||
return string(bs)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Buffer) Reset() {
|
||||
b.buf.Reset()
|
||||
}
|
||||
|
||||
func (b *Buffer) String() string {
|
||||
return b.buf.String()
|
||||
}
|
||||
|
||||
func PanicOnFatal(t *testing.T) {
|
||||
ok := logx.ExitOnFatal.CompareAndSwap(true, false)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
logx.ExitOnFatal.CompareAndSwap(false, true)
|
||||
})
|
||||
}
|
||||
44
core/logx/logtest/logtest_test.go
Normal file
44
core/logx/logtest/logtest_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package logtest
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
|
||||
func TestCollector(t *testing.T) {
|
||||
const input = "hello"
|
||||
c := NewCollector(t)
|
||||
logx.Info(input)
|
||||
assert.Equal(t, input, c.Content())
|
||||
assert.Contains(t, c.String(), input)
|
||||
c.Reset()
|
||||
assert.Empty(t, c.Bytes())
|
||||
}
|
||||
|
||||
func TestPanicOnFatal(t *testing.T) {
|
||||
const input = "hello"
|
||||
Discard(t)
|
||||
logx.Info(input)
|
||||
|
||||
PanicOnFatal(t)
|
||||
PanicOnFatal(t)
|
||||
assert.Panics(t, func() {
|
||||
logx.Must(errors.New("foo"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestCollectorContent(t *testing.T) {
|
||||
const input = "hello"
|
||||
c := NewCollector(t)
|
||||
c.buf.WriteString(input)
|
||||
assert.Empty(t, c.Content())
|
||||
c.Reset()
|
||||
c.buf.WriteString(`{}`)
|
||||
assert.Empty(t, c.Content())
|
||||
c.Reset()
|
||||
c.buf.WriteString(`{"content":1}`)
|
||||
assert.Equal(t, "1", c.Content())
|
||||
}
|
||||
@@ -65,7 +65,7 @@ func (l *richLogger) Errorf(format string, v ...any) {
|
||||
}
|
||||
|
||||
func (l *richLogger) Errorv(v any) {
|
||||
l.err(fmt.Sprint(v))
|
||||
l.err(v)
|
||||
}
|
||||
|
||||
func (l *richLogger) Errorw(msg string, fields ...LogField) {
|
||||
|
||||
@@ -66,6 +66,9 @@ func TestTraceDebug(t *testing.T) {
|
||||
l.WithDuration(time.Second).Debugv(testlog)
|
||||
validate(t, w.String(), true, true)
|
||||
w.Reset()
|
||||
l.WithDuration(time.Second).Debugv(testobj)
|
||||
validateContentType(t, w.String(), map[string]any{}, true, true)
|
||||
w.Reset()
|
||||
l.WithDuration(time.Second).Debugw(testlog, Field("foo", "bar"))
|
||||
validate(t, w.String(), true, true)
|
||||
assert.True(t, strings.Contains(w.String(), "foo"), w.String())
|
||||
@@ -103,6 +106,9 @@ func TestTraceError(t *testing.T) {
|
||||
l.WithDuration(time.Second).Errorv(testlog)
|
||||
validate(t, w.String(), true, true)
|
||||
w.Reset()
|
||||
l.WithDuration(time.Second).Errorv(testobj)
|
||||
validateContentType(t, w.String(), map[string]any{}, true, true)
|
||||
w.Reset()
|
||||
l.WithDuration(time.Second).Errorw(testlog, Field("basket", "ball"))
|
||||
validate(t, w.String(), true, true)
|
||||
assert.True(t, strings.Contains(w.String(), "basket"), w.String())
|
||||
@@ -137,6 +143,9 @@ func TestTraceInfo(t *testing.T) {
|
||||
l.WithDuration(time.Second).Infov(testlog)
|
||||
validate(t, w.String(), true, true)
|
||||
w.Reset()
|
||||
l.WithDuration(time.Second).Infov(testobj)
|
||||
validateContentType(t, w.String(), map[string]any{}, true, true)
|
||||
w.Reset()
|
||||
l.WithDuration(time.Second).Infow(testlog, Field("basket", "ball"))
|
||||
validate(t, w.String(), true, true)
|
||||
assert.True(t, strings.Contains(w.String(), "basket"), w.String())
|
||||
@@ -173,6 +182,9 @@ func TestTraceInfoConsole(t *testing.T) {
|
||||
w.Reset()
|
||||
l.WithDuration(time.Second).Infov(testlog)
|
||||
validate(t, w.String(), true, true)
|
||||
w.Reset()
|
||||
l.WithDuration(time.Second).Infov(testobj)
|
||||
validateContentType(t, w.String(), map[string]any{}, true, true)
|
||||
}
|
||||
|
||||
func TestTraceSlow(t *testing.T) {
|
||||
@@ -204,6 +216,9 @@ func TestTraceSlow(t *testing.T) {
|
||||
l.WithDuration(time.Second).Slowv(testlog)
|
||||
validate(t, w.String(), true, true)
|
||||
w.Reset()
|
||||
l.WithDuration(time.Second).Slowv(testobj)
|
||||
validateContentType(t, w.String(), map[string]any{}, true, true)
|
||||
w.Reset()
|
||||
l.WithDuration(time.Second).Sloww(testlog, Field("basket", "ball"))
|
||||
validate(t, w.String(), true, true)
|
||||
assert.True(t, strings.Contains(w.String(), "basket"), w.String())
|
||||
@@ -311,8 +326,32 @@ func validate(t *testing.T, body string, expectedTrace, expectedSpan bool) {
|
||||
assert.Equal(t, expectedSpan, len(val.Span) > 0, body)
|
||||
}
|
||||
|
||||
type mockValue struct {
|
||||
Trace string `json:"trace"`
|
||||
Span string `json:"span"`
|
||||
Foo string `json:"foo"`
|
||||
func validateContentType(t *testing.T, body string, expectedType any, expectedTrace, expectedSpan bool) {
|
||||
var val mockValue
|
||||
dec := json.NewDecoder(strings.NewReader(body))
|
||||
|
||||
for {
|
||||
var doc mockValue
|
||||
err := dec.Decode(&doc)
|
||||
if err == io.EOF {
|
||||
// all done
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
val = doc
|
||||
}
|
||||
|
||||
assert.IsType(t, expectedType, val.Content, body)
|
||||
assert.Equal(t, expectedTrace, len(val.Trace) > 0, body)
|
||||
assert.Equal(t, expectedSpan, len(val.Span) > 0, body)
|
||||
}
|
||||
|
||||
type mockValue struct {
|
||||
Trace string `json:"trace"`
|
||||
Span string `json:"span"`
|
||||
Foo string `json:"foo"`
|
||||
Content any `json:"content"`
|
||||
}
|
||||
|
||||
@@ -237,7 +237,7 @@ func NewLogger(filename string, rule RotateRule, compress bool) (*RotateLogger,
|
||||
rule: rule,
|
||||
compress: compress,
|
||||
}
|
||||
if err := l.init(); err != nil {
|
||||
if err := l.initialize(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -281,7 +281,7 @@ func (l *RotateLogger) getBackupFilename() string {
|
||||
return l.backup
|
||||
}
|
||||
|
||||
func (l *RotateLogger) init() error {
|
||||
func (l *RotateLogger) initialize() error {
|
||||
l.backup = l.rule.BackupFileName()
|
||||
|
||||
if fileInfo, err := os.Stat(l.filename); err != nil {
|
||||
|
||||
@@ -2,6 +2,7 @@ package logx
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"testing"
|
||||
@@ -13,18 +14,58 @@ import (
|
||||
)
|
||||
|
||||
func TestDailyRotateRuleMarkRotated(t *testing.T) {
|
||||
var rule DailyRotateRule
|
||||
rule.MarkRotated()
|
||||
assert.Equal(t, getNowDate(), rule.rotatedTime)
|
||||
t.Run("daily rule", func(t *testing.T) {
|
||||
var rule DailyRotateRule
|
||||
rule.MarkRotated()
|
||||
assert.Equal(t, getNowDate(), rule.rotatedTime)
|
||||
})
|
||||
|
||||
t.Run("daily rule", func(t *testing.T) {
|
||||
rule := DefaultRotateRule("test", "-", 1, false)
|
||||
_, ok := rule.(*DailyRotateRule)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDailyRotateRuleOutdatedFiles(t *testing.T) {
|
||||
var rule DailyRotateRule
|
||||
assert.Empty(t, rule.OutdatedFiles())
|
||||
rule.days = 1
|
||||
assert.Empty(t, rule.OutdatedFiles())
|
||||
rule.gzip = true
|
||||
assert.Empty(t, rule.OutdatedFiles())
|
||||
t.Run("no files", func(t *testing.T) {
|
||||
var rule DailyRotateRule
|
||||
assert.Empty(t, rule.OutdatedFiles())
|
||||
rule.days = 1
|
||||
assert.Empty(t, rule.OutdatedFiles())
|
||||
rule.gzip = true
|
||||
assert.Empty(t, rule.OutdatedFiles())
|
||||
})
|
||||
|
||||
t.Run("bad files", func(t *testing.T) {
|
||||
rule := DailyRotateRule{
|
||||
filename: "[a-z",
|
||||
}
|
||||
assert.Empty(t, rule.OutdatedFiles())
|
||||
rule.days = 1
|
||||
assert.Empty(t, rule.OutdatedFiles())
|
||||
rule.gzip = true
|
||||
assert.Empty(t, rule.OutdatedFiles())
|
||||
})
|
||||
|
||||
t.Run("temp files", func(t *testing.T) {
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
_ = f1.Close()
|
||||
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
_ = f2.Close()
|
||||
t.Cleanup(func() {
|
||||
_ = os.Remove(f1.Name())
|
||||
_ = os.Remove(f2.Name())
|
||||
})
|
||||
rule := DailyRotateRule{
|
||||
filename: path.Join(os.TempDir(), "go-zero-test-"),
|
||||
days: 1,
|
||||
}
|
||||
assert.NotEmpty(t, rule.OutdatedFiles())
|
||||
})
|
||||
}
|
||||
|
||||
func TestDailyRotateRuleShallRotate(t *testing.T) {
|
||||
@@ -34,20 +75,101 @@ func TestDailyRotateRuleShallRotate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSizeLimitRotateRuleMarkRotated(t *testing.T) {
|
||||
var rule SizeLimitRotateRule
|
||||
rule.MarkRotated()
|
||||
assert.Equal(t, getNowDateInRFC3339Format(), rule.rotatedTime)
|
||||
t.Run("size limit rule", func(t *testing.T) {
|
||||
var rule SizeLimitRotateRule
|
||||
rule.MarkRotated()
|
||||
assert.Equal(t, getNowDateInRFC3339Format(), rule.rotatedTime)
|
||||
})
|
||||
|
||||
t.Run("size limit rule", func(t *testing.T) {
|
||||
rule := NewSizeLimitRotateRule("foo", "-", 1, 1, 1, false)
|
||||
rule.MarkRotated()
|
||||
assert.Equal(t, getNowDateInRFC3339Format(), rule.(*SizeLimitRotateRule).rotatedTime)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSizeLimitRotateRuleOutdatedFiles(t *testing.T) {
|
||||
var rule SizeLimitRotateRule
|
||||
assert.Empty(t, rule.OutdatedFiles())
|
||||
rule.days = 1
|
||||
assert.Empty(t, rule.OutdatedFiles())
|
||||
rule.gzip = true
|
||||
assert.Empty(t, rule.OutdatedFiles())
|
||||
rule.maxBackups = 0
|
||||
assert.Empty(t, rule.OutdatedFiles())
|
||||
t.Run("no files", func(t *testing.T) {
|
||||
var rule SizeLimitRotateRule
|
||||
assert.Empty(t, rule.OutdatedFiles())
|
||||
rule.days = 1
|
||||
assert.Empty(t, rule.OutdatedFiles())
|
||||
rule.gzip = true
|
||||
assert.Empty(t, rule.OutdatedFiles())
|
||||
rule.maxBackups = 0
|
||||
assert.Empty(t, rule.OutdatedFiles())
|
||||
})
|
||||
|
||||
t.Run("bad files", func(t *testing.T) {
|
||||
rule := SizeLimitRotateRule{
|
||||
DailyRotateRule: DailyRotateRule{
|
||||
filename: "[a-z",
|
||||
},
|
||||
}
|
||||
assert.Empty(t, rule.OutdatedFiles())
|
||||
rule.days = 1
|
||||
assert.Empty(t, rule.OutdatedFiles())
|
||||
rule.gzip = true
|
||||
assert.Empty(t, rule.OutdatedFiles())
|
||||
})
|
||||
|
||||
t.Run("temp files", func(t *testing.T) {
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
|
||||
assert.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = f1.Close()
|
||||
_ = os.Remove(f1.Name())
|
||||
_ = f2.Close()
|
||||
_ = os.Remove(f2.Name())
|
||||
_ = f3.Close()
|
||||
_ = os.Remove(f3.Name())
|
||||
})
|
||||
rule := SizeLimitRotateRule{
|
||||
DailyRotateRule: DailyRotateRule{
|
||||
filename: path.Join(os.TempDir(), "go-zero-test-"),
|
||||
days: 1,
|
||||
},
|
||||
maxBackups: 3,
|
||||
}
|
||||
assert.NotEmpty(t, rule.OutdatedFiles())
|
||||
})
|
||||
|
||||
t.Run("no backups", func(t *testing.T) {
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
|
||||
assert.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = f1.Close()
|
||||
_ = os.Remove(f1.Name())
|
||||
_ = f2.Close()
|
||||
_ = os.Remove(f2.Name())
|
||||
_ = f3.Close()
|
||||
_ = os.Remove(f3.Name())
|
||||
})
|
||||
rule := SizeLimitRotateRule{
|
||||
DailyRotateRule: DailyRotateRule{
|
||||
filename: path.Join(os.TempDir(), "go-zero-test-"),
|
||||
days: 1,
|
||||
},
|
||||
}
|
||||
assert.NotEmpty(t, rule.OutdatedFiles())
|
||||
|
||||
logger := new(RotateLogger)
|
||||
logger.rule = &rule
|
||||
logger.maybeDeleteOutdatedFiles()
|
||||
assert.Empty(t, rule.OutdatedFiles())
|
||||
})
|
||||
}
|
||||
|
||||
func TestSizeLimitRotateRuleShallRotate(t *testing.T) {
|
||||
@@ -61,14 +183,26 @@ func TestSizeLimitRotateRuleShallRotate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRotateLoggerClose(t *testing.T) {
|
||||
filename, err := fs.TempFilenameWithText("foo")
|
||||
assert.Nil(t, err)
|
||||
if len(filename) > 0 {
|
||||
defer os.Remove(filename)
|
||||
}
|
||||
logger, err := NewLogger(filename, new(DailyRotateRule), false)
|
||||
assert.Nil(t, err)
|
||||
assert.Nil(t, logger.Close())
|
||||
t.Run("close", func(t *testing.T) {
|
||||
filename, err := fs.TempFilenameWithText("foo")
|
||||
assert.Nil(t, err)
|
||||
if len(filename) > 0 {
|
||||
defer os.Remove(filename)
|
||||
}
|
||||
logger, err := NewLogger(filename, new(DailyRotateRule), false)
|
||||
assert.Nil(t, err)
|
||||
_, err = logger.Write([]byte("foo"))
|
||||
assert.Nil(t, err)
|
||||
assert.Nil(t, logger.Close())
|
||||
})
|
||||
|
||||
t.Run("close and write", func(t *testing.T) {
|
||||
logger := new(RotateLogger)
|
||||
logger.done = make(chan struct{})
|
||||
close(logger.done)
|
||||
_, err := logger.Write([]byte("foo"))
|
||||
assert.ErrorIs(t, err, ErrLogFileClosed)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRotateLoggerGetBackupFilename(t *testing.T) {
|
||||
@@ -179,7 +313,7 @@ func TestRotateLoggerWithSizeLimitRotateRuleClose(t *testing.T) {
|
||||
}
|
||||
logger, err := NewLogger(filename, new(SizeLimitRotateRule), false)
|
||||
assert.Nil(t, err)
|
||||
assert.Nil(t, logger.Close())
|
||||
_ = logger.Close()
|
||||
}
|
||||
|
||||
func TestRotateLoggerGetBackupWithSizeLimitRotateRuleFilename(t *testing.T) {
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
|
||||
const testlog = "Stay hungry, stay foolish."
|
||||
|
||||
var testobj = map[string]any{"foo": "bar"}
|
||||
|
||||
func TestCollectSysLog(t *testing.T) {
|
||||
CollectSysLog()
|
||||
content := getContent(captureOutput(func() {
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package logx
|
||||
|
||||
import "errors"
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
)
|
||||
|
||||
const (
|
||||
// DebugLevel logs everything
|
||||
@@ -61,6 +65,8 @@ var (
|
||||
ErrLogPathNotSet = errors.New("log path must be set")
|
||||
// ErrLogServiceNameNotSet is an error that indicates that the service name is not set.
|
||||
ErrLogServiceNameNotSet = errors.New("log service name must be set")
|
||||
// ExitOnFatal defines whether to exit on fatal errors, defined here to make it easier to test.
|
||||
ExitOnFatal = syncx.ForAtomicBool(true)
|
||||
|
||||
truncatedField = Field(truncatedKey, true)
|
||||
)
|
||||
|
||||
@@ -97,6 +97,15 @@ func TestConsoleWriter(t *testing.T) {
|
||||
w.(*concreteWriter).statLog = easyToCloseWriter{}
|
||||
}
|
||||
|
||||
func TestNewFileWriter(t *testing.T) {
|
||||
t.Run("access", func(t *testing.T) {
|
||||
_, err := newFileWriter(LogConf{
|
||||
Path: "/not-exists",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNopWriter(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
var w nopWriter
|
||||
|
||||
@@ -148,14 +148,17 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map
|
||||
return errValueNotSettable
|
||||
}
|
||||
|
||||
baseType := fieldType.Elem()
|
||||
dereffedBaseType := Deref(baseType)
|
||||
dereffedBaseKind := dereffedBaseType.Kind()
|
||||
refValue := reflect.ValueOf(mapValue)
|
||||
if refValue.Kind() != reflect.Slice {
|
||||
return errTypeMismatch
|
||||
}
|
||||
if refValue.IsNil() {
|
||||
return nil
|
||||
}
|
||||
|
||||
baseType := fieldType.Elem()
|
||||
dereffedBaseType := Deref(baseType)
|
||||
dereffedBaseKind := dereffedBaseType.Kind()
|
||||
conv := reflect.MakeSlice(reflect.SliceOf(baseType), refValue.Len(), refValue.Cap())
|
||||
if refValue.Len() == 0 {
|
||||
value.Set(conv)
|
||||
@@ -289,6 +292,10 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue any)
|
||||
return reflect.ValueOf(mapValue), nil
|
||||
}
|
||||
|
||||
if keyType != valueType.Key() {
|
||||
return emptyValue, errTypeMismatch
|
||||
}
|
||||
|
||||
refValue := reflect.ValueOf(mapValue)
|
||||
targetValue := reflect.MakeMapWithSize(mapType, refValue.Len())
|
||||
dereffedElemType := Deref(elemType)
|
||||
@@ -343,7 +350,12 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue any)
|
||||
return emptyValue, errTypeMismatch
|
||||
}
|
||||
|
||||
targetValue.SetMapIndex(key, reflect.ValueOf(v))
|
||||
val := reflect.ValueOf(v)
|
||||
if !val.Type().AssignableTo(dereffedElemType) {
|
||||
return emptyValue, errTypeMismatch
|
||||
}
|
||||
|
||||
targetValue.SetMapIndex(key, val)
|
||||
case json.Number:
|
||||
target := reflect.New(dereffedElemType)
|
||||
if err := setValueFromString(dereffedElemKind, target.Elem(), v.String()); err != nil {
|
||||
@@ -486,7 +498,7 @@ func (u *Unmarshaler) processAnonymousStructFieldOptional(fieldType reflect.Type
|
||||
}
|
||||
|
||||
if filled && required != requiredFilled {
|
||||
return fmt.Errorf("%s is not fully set", key)
|
||||
return fmt.Errorf("%q is not fully set", key)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -509,8 +521,8 @@ func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value re
|
||||
vp valueWithParent, opts *fieldOptionsWithContext, fullName string) error {
|
||||
derefedFieldType := Deref(fieldType)
|
||||
typeKind := derefedFieldType.Kind()
|
||||
valueKind := reflect.TypeOf(vp.value).Kind()
|
||||
mapValue := vp.value
|
||||
valueKind := reflect.TypeOf(mapValue).Kind()
|
||||
|
||||
switch {
|
||||
case valueKind == reflect.Map && typeKind == reflect.Struct:
|
||||
@@ -523,6 +535,8 @@ func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value re
|
||||
current: mapValuer(mv),
|
||||
parent: vp.parent,
|
||||
}, fullName)
|
||||
case typeKind == reflect.Slice && valueKind == reflect.Slice:
|
||||
return u.fillSlice(fieldType, value, mapValue)
|
||||
case valueKind == reflect.Map && typeKind == reflect.Map:
|
||||
return u.fillMap(fieldType, value, mapValue)
|
||||
case valueKind == reflect.String && typeKind == reflect.Map:
|
||||
@@ -541,23 +555,16 @@ func (u *Unmarshaler) processFieldPrimitive(fieldType reflect.Type, value reflec
|
||||
typeKind := Deref(fieldType).Kind()
|
||||
valueKind := reflect.TypeOf(mapValue).Kind()
|
||||
|
||||
switch {
|
||||
case typeKind == reflect.Slice && valueKind == reflect.Slice:
|
||||
return u.fillSlice(fieldType, value, mapValue)
|
||||
case typeKind == reflect.Map && valueKind == reflect.Map:
|
||||
return u.fillMap(fieldType, value, mapValue)
|
||||
switch v := mapValue.(type) {
|
||||
case json.Number:
|
||||
return u.processFieldPrimitiveWithJSONNumber(fieldType, value, v, opts, fullName)
|
||||
default:
|
||||
switch v := mapValue.(type) {
|
||||
case json.Number:
|
||||
return u.processFieldPrimitiveWithJSONNumber(fieldType, value, v, opts, fullName)
|
||||
default:
|
||||
if typeKind == valueKind {
|
||||
if err := validateValueInOptions(mapValue, opts.options()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return fillWithSameType(fieldType, value, mapValue, opts)
|
||||
if typeKind == valueKind {
|
||||
if err := validateValueInOptions(mapValue, opts.options()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return fillWithSameType(fieldType, value, mapValue, opts)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -691,6 +698,10 @@ func (u *Unmarshaler) processFieldWithEnvValue(fieldType reflect.Type, value ref
|
||||
|
||||
func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect.Value,
|
||||
m valuerWithParent, fullName string) error {
|
||||
if !field.IsExported() {
|
||||
return nil
|
||||
}
|
||||
|
||||
key, opts, err := u.parseOptionsWithContext(field, m, fullName)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -715,7 +726,7 @@ func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect
|
||||
// When fillDefault is used, m is a null value, hasValue must be false, all priority judgments fillDefault.
|
||||
if u.opts.fillDefault {
|
||||
if !value.IsZero() {
|
||||
return fmt.Errorf("set the default value, %s must be zero", fullName)
|
||||
return fmt.Errorf("set the default value, %q must be zero", fullName)
|
||||
}
|
||||
return u.processNamedFieldWithoutValue(field.Type, value, opts, fullName)
|
||||
} else if !hasValue {
|
||||
@@ -736,11 +747,11 @@ func (u *Unmarshaler) processNamedFieldWithValue(fieldType reflect.Type, value r
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("field %s mustn't be nil", key)
|
||||
return fmt.Errorf("field %q mustn't be nil", key)
|
||||
}
|
||||
|
||||
if !value.CanSet() {
|
||||
return fmt.Errorf("field %s is not settable", key)
|
||||
return fmt.Errorf("field %q is not settable", key)
|
||||
}
|
||||
|
||||
maybeNewValue(fieldType, value)
|
||||
@@ -784,7 +795,7 @@ func (u *Unmarshaler) processNamedFieldWithValueFromString(fieldType reflect.Typ
|
||||
}
|
||||
|
||||
if !stringx.Contains(options, checkValue) {
|
||||
return fmt.Errorf(`value "%s" for field "%s" is not defined in options "%v"`,
|
||||
return fmt.Errorf(`value "%s" for field %q is not defined in options "%v"`,
|
||||
mapValue, key, options)
|
||||
}
|
||||
}
|
||||
@@ -810,6 +821,11 @@ func (u *Unmarshaler) processNamedFieldWithoutValue(fieldType reflect.Type, valu
|
||||
}
|
||||
|
||||
if u.opts.fillDefault {
|
||||
if fieldType.Kind() != reflect.Ptr && fieldKind == reflect.Struct {
|
||||
return u.processFieldNotFromString(fieldType, value, valueWithParent{
|
||||
value: emptyMap,
|
||||
}, opts, fullName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -865,12 +881,14 @@ func (u *Unmarshaler) unmarshalWithFullName(m valuerWithParent, v any, fullName
|
||||
|
||||
numFields := baseType.NumField()
|
||||
for i := 0; i < numFields; i++ {
|
||||
field := baseType.Field(i)
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
typeField := baseType.Field(i)
|
||||
valueField := valElem.Field(i)
|
||||
if err := u.processField(typeField, valueField, m, fullName); err != nil {
|
||||
if len(fullName) > 0 {
|
||||
err = fmt.Errorf("%w, fullName: %s, field: %s, type: %s",
|
||||
err, fullName, typeField.Name, valueField.Type().Name())
|
||||
}
|
||||
|
||||
if err := u.processField(field, valElem.Field(i), m, fullName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -1024,11 +1042,11 @@ func join(elem ...string) string {
|
||||
}
|
||||
|
||||
func newInitError(name string) error {
|
||||
return fmt.Errorf("field %s is not set", name)
|
||||
return fmt.Errorf("field %q is not set", name)
|
||||
}
|
||||
|
||||
func newTypeMismatchError(name string) error {
|
||||
return fmt.Errorf("type mismatch for field %s", name)
|
||||
return fmt.Errorf("type mismatch for field %q", name)
|
||||
}
|
||||
|
||||
func readKeys(key string) []string {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -255,7 +255,7 @@ func parseGroupedSegments(val string) []string {
|
||||
|
||||
// don't modify returned fieldOptions, it's cached and shared among different calls.
|
||||
func parseKeyAndOptions(tagName string, field reflect.StructField) (string, *fieldOptions, error) {
|
||||
value := field.Tag.Get(tagName)
|
||||
value := strings.TrimSpace(field.Tag.Get(tagName))
|
||||
if len(value) == 0 {
|
||||
return field.Name, nil, nil
|
||||
}
|
||||
@@ -370,7 +370,7 @@ func parseOption(fieldOpts *fieldOptions, fieldName, option string) error {
|
||||
fieldOpts.Optional = true
|
||||
fieldOpts.OptionalDep = segs[1]
|
||||
default:
|
||||
return fmt.Errorf("field %s has wrong optional", fieldName)
|
||||
return fmt.Errorf("field %q has wrong optional", fieldName)
|
||||
}
|
||||
case option == optionalOption:
|
||||
fieldOpts.Optional = true
|
||||
@@ -429,7 +429,7 @@ func parseOptions(val string) []string {
|
||||
func parseProperty(field, tag, val string) (string, error) {
|
||||
segs := strings.Split(val, equalToken)
|
||||
if len(segs) != 2 {
|
||||
return "", fmt.Errorf("field %s has wrong %s", field, tag)
|
||||
return "", fmt.Errorf("field %q has wrong tag value %q", field, tag)
|
||||
}
|
||||
|
||||
return strings.TrimSpace(segs[1]), nil
|
||||
@@ -628,7 +628,7 @@ func validateValueInOptions(val any, options []string) error {
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
if !stringx.Contains(options, v) {
|
||||
return fmt.Errorf(`error: value "%s" is not defined in options "%v"`, v, options)
|
||||
return fmt.Errorf(`error: value %q is not defined in options "%v"`, v, options)
|
||||
}
|
||||
default:
|
||||
if !stringx.Contains(options, Repr(v)) {
|
||||
|
||||
@@ -144,6 +144,10 @@ func TestParseSegments(t *testing.T) {
|
||||
input: "",
|
||||
expect: []string{},
|
||||
},
|
||||
{
|
||||
input: " ",
|
||||
expect: []string{},
|
||||
},
|
||||
{
|
||||
input: ",",
|
||||
expect: []string{""},
|
||||
|
||||
@@ -34,7 +34,7 @@ type (
|
||||
recursiveValuer node
|
||||
)
|
||||
|
||||
// Value gets the value assciated with the given key from mv.
|
||||
// Value gets the value associated with the given key from mv.
|
||||
func (mv mapValuer) Value(key string) (any, bool) {
|
||||
v, ok := mv[key]
|
||||
return v, ok
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build fuzz
|
||||
// +build fuzz
|
||||
|
||||
package mr
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ package mr
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"log"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
var errDummy = errors.New("dummy")
|
||||
|
||||
func init() {
|
||||
log.SetOutput(ioutil.Discard)
|
||||
log.SetOutput(io.Discard)
|
||||
}
|
||||
|
||||
func TestFinish(t *testing.T) {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package proc
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -21,13 +20,11 @@ func TestEnvInt(t *testing.T) {
|
||||
val, ok := EnvInt("any")
|
||||
assert.Equal(t, 0, val)
|
||||
assert.False(t, ok)
|
||||
err := os.Setenv("anyInt", "10")
|
||||
assert.Nil(t, err)
|
||||
t.Setenv("anyInt", "10")
|
||||
val, ok = EnvInt("anyInt")
|
||||
assert.Equal(t, 10, val)
|
||||
assert.True(t, ok)
|
||||
err = os.Setenv("anyString", "a")
|
||||
assert.Nil(t, err)
|
||||
t.Setenv("anyString", "a")
|
||||
val, ok = EnvInt("anyString")
|
||||
assert.Equal(t, 0, val)
|
||||
assert.False(t, ok)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package proc
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build linux || darwin
|
||||
// +build linux darwin
|
||||
|
||||
package proc
|
||||
|
||||
|
||||
@@ -5,19 +5,11 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/logx/logtest"
|
||||
)
|
||||
|
||||
func TestDumpGoroutines(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
w := logx.NewWriter(&buf)
|
||||
o := logx.Reset()
|
||||
logx.SetWriter(w)
|
||||
defer func() {
|
||||
logx.Reset()
|
||||
logx.SetWriter(o)
|
||||
}()
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
dumpGoroutines()
|
||||
assert.True(t, strings.Contains(buf.String(), ".dump"))
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package proc
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build linux || darwin
|
||||
// +build linux darwin
|
||||
|
||||
package proc
|
||||
|
||||
|
||||
@@ -5,25 +5,16 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/logx/logtest"
|
||||
)
|
||||
|
||||
func TestProfile(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
w := logx.NewWriter(&buf)
|
||||
o := logx.Reset()
|
||||
logx.SetWriter(w)
|
||||
|
||||
defer func() {
|
||||
logx.Reset()
|
||||
logx.SetWriter(o)
|
||||
}()
|
||||
|
||||
c := logtest.NewCollector(t)
|
||||
profiler := StartProfile()
|
||||
// start again should not work
|
||||
assert.NotNil(t, StartProfile())
|
||||
profiler.Stop()
|
||||
// stop twice
|
||||
profiler.Stop()
|
||||
assert.True(t, strings.Contains(buf.String(), ".pprof"))
|
||||
assert.True(t, strings.Contains(c.String(), ".pprof"))
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package proc
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build linux || darwin
|
||||
// +build linux darwin
|
||||
|
||||
package proc
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build linux || darwin
|
||||
// +build linux darwin
|
||||
|
||||
package proc
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package proc
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build linux || darwin
|
||||
// +build linux darwin
|
||||
|
||||
package proc
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@ package prof
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"runtime"
|
||||
"time"
|
||||
)
|
||||
@@ -13,6 +15,10 @@ const (
|
||||
|
||||
// DisplayStats prints the goroutine, memory, GC stats with given interval, default to 5 seconds.
|
||||
func DisplayStats(interval ...time.Duration) {
|
||||
displayStatsWithWriter(os.Stdout, interval...)
|
||||
}
|
||||
|
||||
func displayStatsWithWriter(writer io.Writer, interval ...time.Duration) {
|
||||
duration := defaultInterval
|
||||
for _, val := range interval {
|
||||
duration = val
|
||||
@@ -24,7 +30,7 @@ func DisplayStats(interval ...time.Duration) {
|
||||
for range ticker.C {
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
fmt.Printf("Goroutines: %d, Alloc: %vm, TotalAlloc: %vm, Sys: %vm, NumGC: %v\n",
|
||||
fmt.Fprintf(writer, "Goroutines: %d, Alloc: %vm, TotalAlloc: %vm, Sys: %vm, NumGC: %v\n",
|
||||
runtime.NumGoroutine(), m.Alloc/mega, m.TotalAlloc/mega, m.Sys/mega, m.NumGC)
|
||||
}
|
||||
}()
|
||||
|
||||
36
core/prof/runtime_test.go
Normal file
36
core/prof/runtime_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package prof
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDisplayStats(t *testing.T) {
|
||||
writer := &threadSafeBuffer{
|
||||
buf: strings.Builder{},
|
||||
}
|
||||
displayStatsWithWriter(writer, time.Millisecond*10)
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
assert.Contains(t, writer.String(), "Goroutines: ")
|
||||
}
|
||||
|
||||
type threadSafeBuffer struct {
|
||||
buf strings.Builder
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func (b *threadSafeBuffer) String() string {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
return b.buf.String()
|
||||
}
|
||||
|
||||
func (b *threadSafeBuffer) Write(p []byte) (n int, err error) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
return b.buf.Write(p)
|
||||
}
|
||||
@@ -21,6 +21,11 @@ func Enabled() bool {
|
||||
return enabled.True()
|
||||
}
|
||||
|
||||
// Enable enables prometheus.
|
||||
func Enable() {
|
||||
enabled.Set(true)
|
||||
}
|
||||
|
||||
// StartAgent starts a prometheus agent.
|
||||
func StartAgent(c Config) {
|
||||
if len(c.Host) == 0 {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package queue
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
@@ -37,10 +38,82 @@ func TestQueue(t *testing.T) {
|
||||
assert.Equal(t, int32(rounds), atomic.LoadInt32(&consumer.count))
|
||||
}
|
||||
|
||||
func TestQueue_Broadcast(t *testing.T) {
|
||||
producer := newMockedProducer(rounds)
|
||||
consumer := newMockedConsumer()
|
||||
consumer.wait.Add(consumers)
|
||||
q := NewQueue(func() (Producer, error) {
|
||||
return producer, nil
|
||||
}, func() (Consumer, error) {
|
||||
return consumer, nil
|
||||
})
|
||||
q.AddListener(new(mockedListener))
|
||||
q.SetName("mockqueue")
|
||||
q.SetNumConsumer(consumers)
|
||||
q.SetNumProducer(1)
|
||||
q.Broadcast("message")
|
||||
go func() {
|
||||
producer.wait.Wait()
|
||||
q.Stop()
|
||||
}()
|
||||
q.Start()
|
||||
consumer.wait.Wait()
|
||||
assert.Equal(t, int32(rounds), atomic.LoadInt32(&consumer.count))
|
||||
assert.Equal(t, int32(consumers), atomic.LoadInt32(&consumer.events))
|
||||
}
|
||||
|
||||
func TestQueue_PauseResume(t *testing.T) {
|
||||
producer := newMockedProducer(rounds)
|
||||
consumer := newMockedConsumer()
|
||||
consumer.wait.Add(consumers)
|
||||
q := NewQueue(func() (Producer, error) {
|
||||
return producer, nil
|
||||
}, func() (Consumer, error) {
|
||||
return consumer, nil
|
||||
})
|
||||
q.AddListener(new(mockedListener))
|
||||
q.SetName("mockqueue")
|
||||
q.SetNumConsumer(consumers)
|
||||
q.SetNumProducer(1)
|
||||
go func() {
|
||||
producer.wait.Wait()
|
||||
q.Stop()
|
||||
}()
|
||||
q.Start()
|
||||
producer.listener.OnProducerPause()
|
||||
assert.Equal(t, int32(0), atomic.LoadInt32(&q.active))
|
||||
producer.listener.OnProducerResume()
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&q.active))
|
||||
assert.Equal(t, int32(rounds), atomic.LoadInt32(&consumer.count))
|
||||
}
|
||||
|
||||
func TestQueue_ConsumeError(t *testing.T) {
|
||||
producer := newMockedProducer(rounds)
|
||||
consumer := newMockedConsumer()
|
||||
consumer.consumeErr = errors.New("consume error")
|
||||
consumer.wait.Add(consumers)
|
||||
q := NewQueue(func() (Producer, error) {
|
||||
return producer, nil
|
||||
}, func() (Consumer, error) {
|
||||
return consumer, nil
|
||||
})
|
||||
q.AddListener(new(mockedListener))
|
||||
q.SetName("mockqueue")
|
||||
q.SetNumConsumer(consumers)
|
||||
q.SetNumProducer(1)
|
||||
go func() {
|
||||
producer.wait.Wait()
|
||||
q.Stop()
|
||||
}()
|
||||
q.Start()
|
||||
assert.Equal(t, int32(rounds), atomic.LoadInt32(&consumer.count))
|
||||
}
|
||||
|
||||
type mockedConsumer struct {
|
||||
count int32
|
||||
events int32
|
||||
wait sync.WaitGroup
|
||||
count int32
|
||||
events int32
|
||||
consumeErr error
|
||||
wait sync.WaitGroup
|
||||
}
|
||||
|
||||
func newMockedConsumer() *mockedConsumer {
|
||||
@@ -49,7 +122,7 @@ func newMockedConsumer() *mockedConsumer {
|
||||
|
||||
func (c *mockedConsumer) Consume(string) error {
|
||||
atomic.AddInt32(&c.count, 1)
|
||||
return nil
|
||||
return c.consumeErr
|
||||
}
|
||||
|
||||
func (c *mockedConsumer) OnEvent(any) {
|
||||
@@ -59,9 +132,10 @@ func (c *mockedConsumer) OnEvent(any) {
|
||||
}
|
||||
|
||||
type mockedProducer struct {
|
||||
total int32
|
||||
count int32
|
||||
wait sync.WaitGroup
|
||||
total int32
|
||||
count int32
|
||||
listener ProduceListener
|
||||
wait sync.WaitGroup
|
||||
}
|
||||
|
||||
func newMockedProducer(total int32) *mockedProducer {
|
||||
@@ -72,6 +146,7 @@ func newMockedProducer(total int32) *mockedProducer {
|
||||
}
|
||||
|
||||
func (p *mockedProducer) AddListener(listener ProduceListener) {
|
||||
p.listener = listener
|
||||
}
|
||||
|
||||
func (p *mockedProducer) Produce() (string, bool) {
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
package rescue
|
||||
|
||||
import "github.com/zeromicro/go-zero/core/logx"
|
||||
import (
|
||||
"context"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
|
||||
// Recover is used with defer to do cleanup on panics.
|
||||
// Use it like:
|
||||
@@ -15,3 +20,14 @@ func Recover(cleanups ...func()) {
|
||||
logx.ErrorStack(p)
|
||||
}
|
||||
}
|
||||
|
||||
// RecoverCtx is used with defer to do cleanup on panics.
|
||||
func RecoverCtx(ctx context.Context, cleanups ...func()) {
|
||||
for _, cleanup := range cleanups {
|
||||
cleanup()
|
||||
}
|
||||
|
||||
if p := recover(); p != nil {
|
||||
logx.WithContext(ctx).Errorf("%+v\n%s", p, debug.Stack())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package rescue
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
@@ -25,3 +26,17 @@ func TestRescue(t *testing.T) {
|
||||
})
|
||||
assert.Equal(t, int32(5), atomic.LoadInt32(&count))
|
||||
}
|
||||
|
||||
func TestRescueCtx(t *testing.T) {
|
||||
var count int32
|
||||
assert.NotPanics(t, func() {
|
||||
defer RecoverCtx(context.Background(), func() {
|
||||
atomic.AddInt32(&count, 2)
|
||||
}, func() {
|
||||
atomic.AddInt32(&count, 3)
|
||||
})
|
||||
|
||||
panic("hello")
|
||||
})
|
||||
assert.Equal(t, int32(5), atomic.LoadInt32(&count))
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build debug
|
||||
// +build debug
|
||||
|
||||
package search
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/load"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
@@ -39,9 +37,7 @@ type ServiceConf struct {
|
||||
|
||||
// MustSetUp sets up the service, exits on error.
|
||||
func (sc ServiceConf) MustSetUp() {
|
||||
if err := sc.SetUp(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
logx.Must(sc.SetUp())
|
||||
}
|
||||
|
||||
// SetUp sets up the service.
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build !linux
|
||||
// +build !linux
|
||||
|
||||
package stat
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package stat
|
||||
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package stat
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
@@ -13,8 +11,7 @@ import (
|
||||
)
|
||||
|
||||
func TestReport(t *testing.T) {
|
||||
os.Setenv(clusterNameKey, "test-cluster")
|
||||
defer os.Unsetenv(clusterNameKey)
|
||||
t.Setenv(clusterNameKey, "test-cluster")
|
||||
|
||||
var count int32
|
||||
SetReporter(func(s string) {
|
||||
|
||||
@@ -3,6 +3,7 @@ package internal
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
@@ -278,13 +279,12 @@ func runningInUserNS() bool {
|
||||
var a, b, c int64
|
||||
fmt.Sscanf(line, "%d %d %d", &a, &b, &c)
|
||||
|
||||
/*
|
||||
* We assume we are in the initial user namespace if we have a full
|
||||
* range - 4294967295 uids starting at uid 0.
|
||||
*/
|
||||
if a == 0 && b == 0 && c == 4294967295 {
|
||||
// We assume we are in the initial user namespace if we have a full
|
||||
// range - 4294967295 uids starting at uid 0.
|
||||
if a == 0 && b == 0 && c == math.MaxUint32 {
|
||||
return
|
||||
}
|
||||
|
||||
inUserNS = true
|
||||
})
|
||||
|
||||
|
||||
27
core/stat/internal/cgroup_linux_test.go
Normal file
27
core/stat/internal/cgroup_linux_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRunningInUserNS(t *testing.T) {
|
||||
// should be false in docker
|
||||
assert.False(t, runningInUserNS())
|
||||
}
|
||||
|
||||
func TestCgroupV1(t *testing.T) {
|
||||
if isCgroup2UnifiedMode() {
|
||||
cg, err := currentCgroupV1()
|
||||
assert.NoError(t, err)
|
||||
_, err = cg.cpus()
|
||||
assert.Error(t, err)
|
||||
_, err = cg.cpuPeriodUs()
|
||||
assert.Error(t, err)
|
||||
_, err = cg.cpuQuotaUs()
|
||||
assert.Error(t, err)
|
||||
_, err = cg.usageAllCpus()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build !linux
|
||||
// +build !linux
|
||||
|
||||
package internal
|
||||
|
||||
|
||||
@@ -141,7 +141,7 @@ func (c *metricsContainer) Execute(v any) {
|
||||
report.Median = float32(medianTask.Duration) / float32(time.Millisecond)
|
||||
tenPercent := fiftyPercent / 5
|
||||
if tenPercent > 0 {
|
||||
top10pTasks := topK(tasks, tenPercent)
|
||||
top10pTasks := topK(top50pTasks, tenPercent)
|
||||
task90th := top10pTasks[0]
|
||||
report.Top90th = float32(task90th.Duration) / float32(time.Millisecond)
|
||||
onePercent := tenPercent / 10
|
||||
@@ -163,7 +163,7 @@ func (c *metricsContainer) Execute(v any) {
|
||||
report.Top99p9th = mostDuration
|
||||
}
|
||||
} else {
|
||||
mostDuration := getTopDuration(tasks)
|
||||
mostDuration := getTopDuration(top50pTasks)
|
||||
report.Top90th = mostDuration
|
||||
report.Top99th = mostDuration
|
||||
report.Top99p9th = mostDuration
|
||||
|
||||
@@ -6,11 +6,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
|
||||
func TestMetrics(t *testing.T) {
|
||||
logx.Disable()
|
||||
DisableLog()
|
||||
defer logEnabled.Set(true)
|
||||
|
||||
|
||||
59
core/stat/usage_test.go
Normal file
59
core/stat/usage_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package stat
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/logx/logtest"
|
||||
)
|
||||
|
||||
func TestBToMb(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
bytes uint64
|
||||
expected float32
|
||||
}{
|
||||
{
|
||||
name: "Test 1: Convert 0 bytes to MB",
|
||||
bytes: 0,
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "Test 2: Convert 1048576 bytes to MB",
|
||||
bytes: 1048576,
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "Test 3: Convert 2097152 bytes to MB",
|
||||
bytes: 2097152,
|
||||
expected: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := bToMb(test.bytes)
|
||||
assert.Equal(t, test.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrintUsage(t *testing.T) {
|
||||
c := logtest.NewCollector(t)
|
||||
|
||||
printUsage()
|
||||
|
||||
output := c.String()
|
||||
assert.Contains(t, output, "CPU:")
|
||||
assert.Contains(t, output, "MEMORY:")
|
||||
assert.Contains(t, output, "Alloc=")
|
||||
assert.Contains(t, output, "TotalAlloc=")
|
||||
assert.Contains(t, output, "Sys=")
|
||||
assert.Contains(t, output, "NumGC=")
|
||||
|
||||
lines := strings.Split(output, "\n")
|
||||
assert.Len(t, lines, 2)
|
||||
fields := strings.Split(lines[0], ", ")
|
||||
assert.Len(t, fields, 5)
|
||||
}
|
||||
@@ -69,3 +69,62 @@ func TestFieldNamesWithDashTagAndOptions(t *testing.T) {
|
||||
assert.Equal(t, expected, out)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostgreSqlJoin(t *testing.T) {
|
||||
// Test with empty input array
|
||||
var input []string
|
||||
var expectedOutput string
|
||||
assert.Equal(t, expectedOutput, PostgreSqlJoin(input))
|
||||
|
||||
// Test with single element input array
|
||||
input = []string{"foo"}
|
||||
expectedOutput = "foo = $2"
|
||||
assert.Equal(t, expectedOutput, PostgreSqlJoin(input))
|
||||
|
||||
// Test with multiple elements input array
|
||||
input = []string{"foo", "bar", "baz"}
|
||||
expectedOutput = "foo = $2, bar = $3, baz = $4"
|
||||
assert.Equal(t, expectedOutput, PostgreSqlJoin(input))
|
||||
}
|
||||
|
||||
type testStruct struct {
|
||||
Foo string `db:"foo"`
|
||||
Bar int `db:"bar"`
|
||||
Baz bool `db:"-"`
|
||||
}
|
||||
|
||||
func TestRawFieldNames(t *testing.T) {
|
||||
// Test with a struct without tags
|
||||
in := struct {
|
||||
Foo string
|
||||
Bar int
|
||||
}{}
|
||||
expectedOutput := []string{"`Foo`", "`Bar`"}
|
||||
assert.ElementsMatch(t, expectedOutput, RawFieldNames(in))
|
||||
|
||||
// Test pg without db tag
|
||||
expectedOutput = []string{"Foo", "Bar"}
|
||||
assert.ElementsMatch(t, expectedOutput, RawFieldNames(in, true))
|
||||
|
||||
// Test with a struct with tags
|
||||
input := testStruct{}
|
||||
expectedOutput = []string{"`foo`", "`bar`"}
|
||||
assert.ElementsMatch(t, expectedOutput, RawFieldNames(input))
|
||||
|
||||
// Test with nil input (pointer)
|
||||
var nilInput *testStruct
|
||||
assert.Panics(t, func() {
|
||||
RawFieldNames(nilInput)
|
||||
}, "RawFieldNames should panic with nil input")
|
||||
|
||||
// Test with non-struct input
|
||||
inputInt := 42
|
||||
assert.Panics(t, func() {
|
||||
RawFieldNames(inputInt)
|
||||
}, "RawFieldNames should panic with non-struct input")
|
||||
|
||||
// Test with PostgreSQL flag
|
||||
input = testStruct{}
|
||||
expectedOutput = []string{"foo", "bar"}
|
||||
assert.ElementsMatch(t, expectedOutput, RawFieldNames(input, true))
|
||||
}
|
||||
|
||||
12
core/stores/cache/cache_test.go
vendored
12
core/stores/cache/cache_test.go
vendored
@@ -112,12 +112,8 @@ func (mc *mockedNode) TakeWithExpireCtx(ctx context.Context, val any, key string
|
||||
func TestCache_SetDel(t *testing.T) {
|
||||
t.Run("test set del", func(t *testing.T) {
|
||||
const total = 1000
|
||||
r1, clean1, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean1()
|
||||
r2, clean2, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean2()
|
||||
r1 := redistest.CreateRedis(t)
|
||||
r2 := redistest.CreateRedis(t)
|
||||
conf := ClusterConf{
|
||||
{
|
||||
RedisConf: redis.RedisConf{
|
||||
@@ -193,9 +189,7 @@ func TestCache_SetDel(t *testing.T) {
|
||||
|
||||
func TestCache_OneNode(t *testing.T) {
|
||||
const total = 1000
|
||||
r, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean()
|
||||
r := redistest.CreateRedis(t)
|
||||
conf := ClusterConf{
|
||||
{
|
||||
RedisConf: redis.RedisConf{
|
||||
|
||||
174
core/stores/cache/cachenode_test.go
vendored
174
core/stores/cache/cachenode_test.go
vendored
@@ -1,6 +1,3 @@
|
||||
//go:build !race
|
||||
|
||||
// Disable data race detection is because of the timingWheel in cacheNode.
|
||||
package cache
|
||||
|
||||
import (
|
||||
@@ -34,10 +31,10 @@ func init() {
|
||||
|
||||
func TestCacheNode_DelCache(t *testing.T) {
|
||||
t.Run("del cache", func(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
store.Type = redis.ClusterType
|
||||
defer clean()
|
||||
r, err := miniredis.Run()
|
||||
assert.NoError(t, err)
|
||||
defer r.Close()
|
||||
store := redis.New(r.Addr(), redis.Cluster())
|
||||
|
||||
cn := cacheNode{
|
||||
rds: store,
|
||||
@@ -58,16 +55,16 @@ func TestCacheNode_DelCache(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("del cache with errors", func(t *testing.T) {
|
||||
old := timingWheel
|
||||
old := timingWheel.Load()
|
||||
ticker := timex.NewFakeTicker()
|
||||
var err error
|
||||
timingWheel, err = collection.NewTimingWheelWithTicker(
|
||||
tw, err := collection.NewTimingWheelWithTicker(
|
||||
time.Millisecond, timingWheelSlots, func(key, value any) {
|
||||
clean(key, value)
|
||||
}, ticker)
|
||||
timingWheel.Store(tw)
|
||||
assert.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
timingWheel = old
|
||||
timingWheel.Store(old)
|
||||
})
|
||||
|
||||
r, err := miniredis.Run()
|
||||
@@ -84,9 +81,7 @@ func TestCacheNode_DelCache(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCacheNode_DelCacheWithErrors(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean()
|
||||
store := redistest.CreateRedis(t)
|
||||
store.Type = redis.ClusterType
|
||||
|
||||
cn := cacheNode{
|
||||
@@ -122,9 +117,7 @@ func TestCacheNode_InvalidCache(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCacheNode_SetWithExpire(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean()
|
||||
store := redistest.CreateRedis(t)
|
||||
|
||||
cn := cacheNode{
|
||||
rds: store,
|
||||
@@ -139,14 +132,12 @@ func TestCacheNode_SetWithExpire(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCacheNode_Take(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean()
|
||||
store := redistest.CreateRedis(t)
|
||||
|
||||
cn := NewNode(store, syncx.NewSingleFlight(), NewStat("any"), errTestNotFound,
|
||||
WithExpiry(time.Second), WithNotFoundExpiry(time.Second))
|
||||
var str string
|
||||
err = cn.Take(&str, "any", func(v any) error {
|
||||
err := cn.Take(&str, "any", func(v any) error {
|
||||
*v.(*string) = "value"
|
||||
return nil
|
||||
})
|
||||
@@ -174,48 +165,103 @@ func TestCacheNode_TakeBadRedis(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCacheNode_TakeNotFound(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean()
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
store := redistest.CreateRedis(t)
|
||||
|
||||
cn := cacheNode{
|
||||
rds: store,
|
||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
barrier: syncx.NewSingleFlight(),
|
||||
lock: new(sync.Mutex),
|
||||
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||
stat: NewStat("any"),
|
||||
errNotFound: errTestNotFound,
|
||||
}
|
||||
var str string
|
||||
err = cn.Take(&str, "any", func(v any) error {
|
||||
return errTestNotFound
|
||||
})
|
||||
assert.True(t, cn.IsNotFound(err))
|
||||
assert.True(t, cn.IsNotFound(cn.Get("any", &str)))
|
||||
val, err := store.Get("any")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, `*`, val)
|
||||
cn := cacheNode{
|
||||
rds: store,
|
||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
barrier: syncx.NewSingleFlight(),
|
||||
lock: new(sync.Mutex),
|
||||
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||
stat: NewStat("any"),
|
||||
errNotFound: errTestNotFound,
|
||||
}
|
||||
var str string
|
||||
err := cn.Take(&str, "any", func(v any) error {
|
||||
return errTestNotFound
|
||||
})
|
||||
assert.True(t, cn.IsNotFound(err))
|
||||
assert.True(t, cn.IsNotFound(cn.Get("any", &str)))
|
||||
val, err := store.Get("any")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, `*`, val)
|
||||
|
||||
store.Set("any", "*")
|
||||
err = cn.Take(&str, "any", func(v any) error {
|
||||
return nil
|
||||
})
|
||||
assert.True(t, cn.IsNotFound(err))
|
||||
assert.True(t, cn.IsNotFound(cn.Get("any", &str)))
|
||||
store.Set("any", "*")
|
||||
err = cn.Take(&str, "any", func(v any) error {
|
||||
return nil
|
||||
})
|
||||
assert.True(t, cn.IsNotFound(err))
|
||||
assert.True(t, cn.IsNotFound(cn.Get("any", &str)))
|
||||
|
||||
store.Del("any")
|
||||
errDummy := errors.New("dummy")
|
||||
err = cn.Take(&str, "any", func(v any) error {
|
||||
return errDummy
|
||||
store.Del("any")
|
||||
errDummy := errors.New("dummy")
|
||||
err = cn.Take(&str, "any", func(v any) error {
|
||||
return errDummy
|
||||
})
|
||||
assert.Equal(t, errDummy, err)
|
||||
})
|
||||
|
||||
t.Run("not found with redis error", func(t *testing.T) {
|
||||
r, err := miniredis.Run()
|
||||
assert.NoError(t, err)
|
||||
defer r.Close()
|
||||
store, err := redis.NewRedis(redis.RedisConf{
|
||||
Host: r.Addr(),
|
||||
Type: redis.NodeType,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
cn := cacheNode{
|
||||
rds: store,
|
||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
barrier: syncx.NewSingleFlight(),
|
||||
lock: new(sync.Mutex),
|
||||
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||
stat: NewStat("any"),
|
||||
errNotFound: errTestNotFound,
|
||||
}
|
||||
var str string
|
||||
err = cn.Take(&str, "any", func(v any) error {
|
||||
r.SetError("mock error")
|
||||
return errTestNotFound
|
||||
})
|
||||
assert.True(t, cn.IsNotFound(err))
|
||||
})
|
||||
}
|
||||
|
||||
func TestCacheNode_TakeCtxWithRedisError(t *testing.T) {
|
||||
t.Run("not found with redis error", func(t *testing.T) {
|
||||
r, err := miniredis.Run()
|
||||
assert.NoError(t, err)
|
||||
defer r.Close()
|
||||
store, err := redis.NewRedis(redis.RedisConf{
|
||||
Host: r.Addr(),
|
||||
Type: redis.NodeType,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
cn := cacheNode{
|
||||
rds: store,
|
||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
barrier: syncx.NewSingleFlight(),
|
||||
lock: new(sync.Mutex),
|
||||
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||
stat: NewStat("any"),
|
||||
errNotFound: errTestNotFound,
|
||||
}
|
||||
var str string
|
||||
err = cn.Take(&str, "any", func(v any) error {
|
||||
str = "foo"
|
||||
r.SetError("mock error")
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCacheNode_TakeNotFoundButChangedByOthers(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.NoError(t, err)
|
||||
defer clean()
|
||||
store := redistest.CreateRedis(t)
|
||||
|
||||
cn := cacheNode{
|
||||
rds: store,
|
||||
@@ -228,7 +274,7 @@ func TestCacheNode_TakeNotFoundButChangedByOthers(t *testing.T) {
|
||||
}
|
||||
|
||||
var str string
|
||||
err = cn.Take(&str, "any", func(v any) error {
|
||||
err := cn.Take(&str, "any", func(v any) error {
|
||||
store.Set("any", "foo")
|
||||
return errTestNotFound
|
||||
})
|
||||
@@ -242,9 +288,7 @@ func TestCacheNode_TakeNotFoundButChangedByOthers(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCacheNode_TakeWithExpire(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean()
|
||||
store := redistest.CreateRedis(t)
|
||||
|
||||
cn := cacheNode{
|
||||
rds: store,
|
||||
@@ -256,7 +300,7 @@ func TestCacheNode_TakeWithExpire(t *testing.T) {
|
||||
errNotFound: errors.New("any"),
|
||||
}
|
||||
var str string
|
||||
err = cn.TakeWithExpire(&str, "any", func(v any, expire time.Duration) error {
|
||||
err := cn.TakeWithExpire(&str, "any", func(v any, expire time.Duration) error {
|
||||
*v.(*string) = "value"
|
||||
return nil
|
||||
})
|
||||
@@ -269,9 +313,7 @@ func TestCacheNode_TakeWithExpire(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCacheNode_String(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean()
|
||||
store := redistest.CreateRedis(t)
|
||||
|
||||
cn := cacheNode{
|
||||
rds: store,
|
||||
@@ -286,9 +328,7 @@ func TestCacheNode_String(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCacheValueWithBigInt(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean()
|
||||
store := redistest.CreateRedis(t)
|
||||
|
||||
cn := cacheNode{
|
||||
rds: store,
|
||||
|
||||
28
core/stores/cache/cacheopt_test.go
vendored
Normal file
28
core/stores/cache/cacheopt_test.go
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCacheOptions(t *testing.T) {
|
||||
t.Run("default options", func(t *testing.T) {
|
||||
o := newOptions()
|
||||
assert.Equal(t, defaultExpiry, o.Expiry)
|
||||
assert.Equal(t, defaultNotFoundExpiry, o.NotFoundExpiry)
|
||||
})
|
||||
|
||||
t.Run("with expiry", func(t *testing.T) {
|
||||
o := newOptions(WithExpiry(time.Second))
|
||||
assert.Equal(t, time.Second, o.Expiry)
|
||||
assert.Equal(t, defaultNotFoundExpiry, o.NotFoundExpiry)
|
||||
})
|
||||
|
||||
t.Run("with not found expiry", func(t *testing.T) {
|
||||
o := newOptions(WithNotFoundExpiry(time.Second))
|
||||
assert.Equal(t, defaultExpiry, o.Expiry)
|
||||
assert.Equal(t, time.Second, o.NotFoundExpiry)
|
||||
})
|
||||
}
|
||||
24
core/stores/cache/cleaner.go
vendored
24
core/stores/cache/cleaner.go
vendored
@@ -2,6 +2,7 @@ package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/collection"
|
||||
@@ -19,7 +20,8 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
timingWheel *collection.TimingWheel
|
||||
// use atomic to avoid data race in unit tests
|
||||
timingWheel atomic.Value
|
||||
taskRunner = threading.NewTaskRunner(cleanWorkers)
|
||||
)
|
||||
|
||||
@@ -30,22 +32,27 @@ type delayTask struct {
|
||||
}
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
timingWheel, err = collection.NewTimingWheel(time.Second, timingWheelSlots, clean)
|
||||
tw, err := collection.NewTimingWheel(time.Second, timingWheelSlots, clean)
|
||||
logx.Must(err)
|
||||
timingWheel.Store(tw)
|
||||
|
||||
proc.AddShutdownListener(func() {
|
||||
timingWheel.Drain(clean)
|
||||
if err := tw.Drain(clean); err != nil {
|
||||
logx.Errorf("failed to drain timing wheel: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// AddCleanTask adds a clean task on given keys.
|
||||
func AddCleanTask(task func() error, keys ...string) {
|
||||
timingWheel.SetTimer(stringx.Randn(taskKeyLen), delayTask{
|
||||
tw := timingWheel.Load().(*collection.TimingWheel)
|
||||
if err := tw.SetTimer(stringx.Randn(taskKeyLen), delayTask{
|
||||
delay: time.Second,
|
||||
task: task,
|
||||
keys: keys,
|
||||
}, time.Second)
|
||||
}, time.Second); err != nil {
|
||||
logx.Errorf("failed to set timer for keys: %q, error: %v", formatKeys(keys), err)
|
||||
}
|
||||
}
|
||||
|
||||
func clean(key, value any) {
|
||||
@@ -59,7 +66,10 @@ func clean(key, value any) {
|
||||
next, ok := nextDelay(dt.delay)
|
||||
if ok {
|
||||
dt.delay = next
|
||||
timingWheel.SetTimer(key, dt, next)
|
||||
tw := timingWheel.Load().(*collection.TimingWheel)
|
||||
if err = tw.SetTimer(key, dt, next); err != nil {
|
||||
logx.Errorf("failed to set timer for key: %s, error: %v", key, err)
|
||||
}
|
||||
} else {
|
||||
msg := fmt.Sprintf("retried but failed to clear cache with keys: %q, error: %v",
|
||||
formatKeys(dt.keys), err)
|
||||
|
||||
14
core/stores/cache/cleaner_test.go
vendored
14
core/stores/cache/cleaner_test.go
vendored
@@ -5,7 +5,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/collection"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/timex"
|
||||
)
|
||||
|
||||
func TestNextDelay(t *testing.T) {
|
||||
@@ -49,6 +51,18 @@ func TestNextDelay(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
old := timingWheel.Load()
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, err := collection.NewTimingWheelWithTicker(
|
||||
time.Millisecond, timingWheelSlots, func(key, value any) {
|
||||
clean(key, value)
|
||||
}, ticker)
|
||||
timingWheel.Store(tw)
|
||||
assert.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
timingWheel.Store(old)
|
||||
})
|
||||
|
||||
next, ok := nextDelay(test.input)
|
||||
assert.Equal(t, test.ok, ok)
|
||||
assert.Equal(t, test.output, next)
|
||||
|
||||
@@ -3,12 +3,11 @@ package mon
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/logx/logtest"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
"github.com/zeromicro/go-zero/core/timex"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
@@ -573,15 +572,7 @@ func TestDecoratedCollection_LogDuration(t *testing.T) {
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
|
||||
var buf strings.Builder
|
||||
w := logx.NewWriter(&buf)
|
||||
o := logx.Reset()
|
||||
logx.SetWriter(w)
|
||||
|
||||
defer func() {
|
||||
logx.Reset()
|
||||
logx.SetWriter(o)
|
||||
}()
|
||||
buf := logtest.NewCollector(t)
|
||||
|
||||
buf.Reset()
|
||||
c.logDuration(context.Background(), "foo", timex.Now(), nil, "bar")
|
||||
|
||||
@@ -2,10 +2,10 @@ package mon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/timex"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
mopt "go.mongodb.org/mongo-driver/mongo/options"
|
||||
@@ -39,10 +39,7 @@ type (
|
||||
// MustNewModel returns a Model, exits on errors.
|
||||
func MustNewModel(uri, db, collection string, opts ...Option) *Model {
|
||||
model, err := NewModel(uri, db, collection, opts...)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
logx.Must(err)
|
||||
return model
|
||||
}
|
||||
|
||||
|
||||
@@ -3,12 +3,11 @@ package mon
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/logx/logtest"
|
||||
)
|
||||
|
||||
func TestFormatAddrs(t *testing.T) {
|
||||
@@ -40,15 +39,7 @@ func TestFormatAddrs(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_logDuration(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
w := logx.NewWriter(&buf)
|
||||
o := logx.Reset()
|
||||
logx.SetWriter(w)
|
||||
|
||||
defer func() {
|
||||
logx.Reset()
|
||||
logx.SetWriter(o)
|
||||
}()
|
||||
buf := logtest.NewCollector(t)
|
||||
|
||||
buf.Reset()
|
||||
logDuration(context.Background(), "foo", "bar", time.Millisecond, nil)
|
||||
|
||||
@@ -2,8 +2,8 @@ package monc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/stores/cache"
|
||||
"github.com/zeromicro/go-zero/core/stores/mon"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
@@ -30,20 +30,14 @@ type Model struct {
|
||||
// MustNewModel returns a Model with a cache cluster, exists on errors.
|
||||
func MustNewModel(uri, db, collection string, c cache.CacheConf, opts ...cache.Option) *Model {
|
||||
model, err := NewModel(uri, db, collection, c, opts...)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
logx.Must(err)
|
||||
return model
|
||||
}
|
||||
|
||||
// MustNewNodeModel returns a Model with a cache node, exists on errors.
|
||||
func MustNewNodeModel(uri, db, collection string, rds *redis.Redis, opts ...cache.Option) *Model {
|
||||
model, err := NewNodeModel(uri, db, collection, rds, opts...)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
logx.Must(err)
|
||||
return model
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,8 @@ package postgres
|
||||
|
||||
import (
|
||||
// imports the driver, don't remove this comment, golint requires.
|
||||
_ "github.com/jackc/pgx/v5"
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/stores/sqlx"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package redis
|
||||
|
||||
import "errors"
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrEmptyHost is an error that indicates no redis host is set.
|
||||
@@ -9,23 +12,24 @@ var (
|
||||
ErrEmptyType = errors.New("empty redis type")
|
||||
// ErrEmptyKey is an error that indicates no redis key is set.
|
||||
ErrEmptyKey = errors.New("empty redis key")
|
||||
// ErrPing is an error that indicates ping failed.
|
||||
ErrPing = errors.New("ping redis failed")
|
||||
)
|
||||
|
||||
type (
|
||||
// A RedisConf is a redis config.
|
||||
RedisConf struct {
|
||||
Host string
|
||||
Type string `json:",default=node,options=node|cluster"`
|
||||
Pass string `json:",optional"`
|
||||
Tls bool `json:",optional"`
|
||||
Host string
|
||||
Type string `json:",default=node,options=node|cluster"`
|
||||
Pass string `json:",optional"`
|
||||
Tls bool `json:",optional"`
|
||||
NonBlock bool `json:",default=true"`
|
||||
// PingTimeout is the timeout for ping redis.
|
||||
PingTimeout time.Duration `json:",default=1s"`
|
||||
}
|
||||
|
||||
// A RedisKeyConf is a redis config with key.
|
||||
RedisKeyConf struct {
|
||||
RedisConf
|
||||
Key string `json:",optional"`
|
||||
Key string
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ func (h hook) AfterProcess(ctx context.Context, cmd red.Cmder) error {
|
||||
logDuration(ctx, []red.Cmder{cmd}, duration)
|
||||
}
|
||||
|
||||
metricReqDur.Observe(int64(duration/time.Millisecond), cmd.Name())
|
||||
metricReqDur.Observe(duration.Milliseconds(), cmd.Name())
|
||||
if msg := formatError(err); len(msg) > 0 {
|
||||
metricReqErr.Inc(cmd.Name(), msg)
|
||||
}
|
||||
@@ -103,7 +103,7 @@ func (h hook) AfterProcessPipeline(ctx context.Context, cmds []red.Cmder) error
|
||||
logDuration(ctx, cmds, duration)
|
||||
}
|
||||
|
||||
metricReqDur.Observe(int64(duration/time.Millisecond), "Pipeline")
|
||||
metricReqDur.Observe(duration.Milliseconds(), "Pipeline")
|
||||
if msg := formatError(batchError.Err()); len(msg) > 0 {
|
||||
metricReqErr.Inc("Pipeline", msg)
|
||||
}
|
||||
|
||||
@@ -2,14 +2,18 @@ package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
red "github.com/go-redis/redis/v8"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
"github.com/zeromicro/go-zero/core/logx/logtest"
|
||||
ztrace "github.com/zeromicro/go-zero/core/trace"
|
||||
tracesdk "go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
@@ -47,8 +51,7 @@ func TestHookProcessCase2(t *testing.T) {
|
||||
})
|
||||
defer ztrace.StopAgent()
|
||||
|
||||
w, restore := injectLog()
|
||||
defer restore()
|
||||
w := logtest.NewCollector(t)
|
||||
|
||||
ctx, err := durationHook.BeforeProcess(context.Background(), red.NewCmd(context.Background()))
|
||||
if err != nil {
|
||||
@@ -115,8 +118,7 @@ func TestHookProcessPipelineCase2(t *testing.T) {
|
||||
})
|
||||
defer ztrace.StopAgent()
|
||||
|
||||
w, restore := injectLog()
|
||||
defer restore()
|
||||
w := logtest.NewCollector(t)
|
||||
|
||||
ctx, err := durationHook.BeforeProcessPipeline(context.Background(), []red.Cmder{
|
||||
red.NewCmd(context.Background()),
|
||||
@@ -135,8 +137,7 @@ func TestHookProcessPipelineCase2(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHookProcessPipelineCase3(t *testing.T) {
|
||||
w, restore := injectLog()
|
||||
defer restore()
|
||||
w := logtest.NewCollector(t)
|
||||
|
||||
assert.Nil(t, durationHook.AfterProcessPipeline(context.Background(), []red.Cmder{
|
||||
red.NewCmd(context.Background()),
|
||||
@@ -145,8 +146,7 @@ func TestHookProcessPipelineCase3(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHookProcessPipelineCase4(t *testing.T) {
|
||||
w, restore := injectLog()
|
||||
defer restore()
|
||||
w := logtest.NewCollector(t)
|
||||
|
||||
ctx := context.WithValue(context.Background(), startTimeKey, "foo")
|
||||
assert.Nil(t, durationHook.AfterProcessPipeline(ctx, []red.Cmder{
|
||||
@@ -169,8 +169,7 @@ func TestHookProcessPipelineCase5(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLogDuration(t *testing.T) {
|
||||
w, restore := injectLog()
|
||||
defer restore()
|
||||
w := logtest.NewCollector(t)
|
||||
|
||||
logDuration(context.Background(), []red.Cmder{
|
||||
red.NewCmd(context.Background(), "get", "foo"),
|
||||
@@ -184,14 +183,39 @@ func TestLogDuration(t *testing.T) {
|
||||
assert.True(t, strings.Contains(w.String(), `get foo\nset bar 0`))
|
||||
}
|
||||
|
||||
func injectLog() (r *strings.Builder, restore func()) {
|
||||
var buf strings.Builder
|
||||
w := logx.NewWriter(&buf)
|
||||
o := logx.Reset()
|
||||
logx.SetWriter(w)
|
||||
|
||||
return &buf, func() {
|
||||
logx.Reset()
|
||||
logx.SetWriter(o)
|
||||
func TestFormatError(t *testing.T) {
|
||||
// Test case: err is OpError
|
||||
err := &net.OpError{
|
||||
Err: mockOpError{},
|
||||
}
|
||||
assert.Equal(t, "timeout", formatError(err))
|
||||
|
||||
// Test case: err is nil
|
||||
assert.Equal(t, "", formatError(nil))
|
||||
|
||||
// Test case: err is red.Nil
|
||||
assert.Equal(t, "", formatError(red.Nil))
|
||||
|
||||
// Test case: err is io.EOF
|
||||
assert.Equal(t, "eof", formatError(io.EOF))
|
||||
|
||||
// Test case: err is context.DeadlineExceeded
|
||||
assert.Equal(t, "context deadline", formatError(context.DeadlineExceeded))
|
||||
|
||||
// Test case: err is breaker.ErrServiceUnavailable
|
||||
assert.Equal(t, "breaker", formatError(breaker.ErrServiceUnavailable))
|
||||
|
||||
// Test case: err is unknown
|
||||
assert.Equal(t, "unexpected error", formatError(errors.New("some error")))
|
||||
}
|
||||
|
||||
type mockOpError struct {
|
||||
}
|
||||
|
||||
func (mockOpError) Error() string {
|
||||
return "mock error"
|
||||
}
|
||||
|
||||
func (mockOpError) Timeout() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -4,12 +4,13 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
red "github.com/go-redis/redis/v8"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
"github.com/zeromicro/go-zero/core/errorx"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/mapping"
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
)
|
||||
@@ -25,6 +26,7 @@ const (
|
||||
blockingQueryTimeout = 5 * time.Second
|
||||
readWriteTimeout = 2 * time.Second
|
||||
defaultSlowThreshold = time.Millisecond * 100
|
||||
defaultPingTimeout = time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -51,11 +53,12 @@ type (
|
||||
|
||||
// Redis defines a redis node/cluster. It is thread-safe.
|
||||
Redis struct {
|
||||
Addr string
|
||||
Type string
|
||||
Pass string
|
||||
tls bool
|
||||
brk breaker.Breaker
|
||||
Addr string
|
||||
Type string
|
||||
Pass string
|
||||
tls bool
|
||||
brk breaker.Breaker
|
||||
hooks []red.Hook
|
||||
}
|
||||
|
||||
// RedisNode interface represents a redis node.
|
||||
@@ -84,24 +87,23 @@ type (
|
||||
FloatCmd = red.FloatCmd
|
||||
// StringCmd is an alias of redis.StringCmd.
|
||||
StringCmd = red.StringCmd
|
||||
// Script is an alias of redis.Script.
|
||||
Script = red.Script
|
||||
)
|
||||
|
||||
// MustNewRedis returns a Redis with given options.
|
||||
func MustNewRedis(conf RedisConf, opts ...Option) *Redis {
|
||||
rds, err := NewRedis(conf, opts...)
|
||||
logx.Must(err)
|
||||
return rds
|
||||
}
|
||||
|
||||
// New returns a Redis with given options.
|
||||
// Deprecated: use MustNewRedis or NewRedis instead.
|
||||
func New(addr string, opts ...Option) *Redis {
|
||||
return newRedis(addr, opts...)
|
||||
}
|
||||
|
||||
// MustNewRedis returns a Redis with given options.
|
||||
func MustNewRedis(conf RedisConf, opts ...Option) *Redis {
|
||||
rds, err := NewRedis(conf, opts...)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return rds
|
||||
}
|
||||
|
||||
// NewRedis returns a Redis with given options.
|
||||
func NewRedis(conf RedisConf, opts ...Option) (*Redis, error) {
|
||||
if err := conf.Validate(); err != nil {
|
||||
@@ -119,8 +121,10 @@ func NewRedis(conf RedisConf, opts ...Option) (*Redis, error) {
|
||||
}
|
||||
|
||||
rds := newRedis(conf.Host, opts...)
|
||||
if !rds.Ping() {
|
||||
return nil, ErrPing
|
||||
if !conf.NonBlock {
|
||||
if err := rds.checkConnection(conf.PingTimeout); err != nil {
|
||||
return nil, errorx.Wrap(err, fmt.Sprintf("redis connect error, addr: %s", conf.Host))
|
||||
}
|
||||
}
|
||||
|
||||
return rds, nil
|
||||
@@ -140,6 +144,11 @@ func newRedis(addr string, opts ...Option) *Redis {
|
||||
return r
|
||||
}
|
||||
|
||||
// NewScript returns a new Script instance.
|
||||
func NewScript(script string) *Script {
|
||||
return red.NewScript(script)
|
||||
}
|
||||
|
||||
// BitCount is redis bitcount command implementation.
|
||||
func (s *Redis) BitCount(key string, start, end int64) (int64, error) {
|
||||
return s.BitCountCtx(context.Background(), key, start, end)
|
||||
@@ -832,12 +841,12 @@ func (s *Redis) HincrbyCtx(ctx context.Context, key, field string, increment int
|
||||
return
|
||||
}
|
||||
|
||||
// HincrbyFloat is the implementation of redis hincrby command.
|
||||
// HincrbyFloat is the implementation of redis hincrbyfloat command.
|
||||
func (s *Redis) HincrbyFloat(key, field string, increment float64) (float64, error) {
|
||||
return s.HincrbyFloatCtx(context.Background(), key, field, increment)
|
||||
}
|
||||
|
||||
// HincrbyFloatCtx is the implementation of redis hincrby command.
|
||||
// HincrbyFloatCtx is the implementation of redis hincrbyfloat command.
|
||||
func (s *Redis) HincrbyFloatCtx(ctx context.Context, key, field string, increment float64) (val float64, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
@@ -1065,12 +1074,12 @@ func (s *Redis) IncrbyCtx(ctx context.Context, key string, increment int64) (val
|
||||
return
|
||||
}
|
||||
|
||||
// IncrbyFloat is the implementation of redis incrby command.
|
||||
// IncrbyFloat is the implementation of redis hincrbyfloat command.
|
||||
func (s *Redis) IncrbyFloat(key string, increment float64) (float64, error) {
|
||||
return s.IncrbyFloatCtx(context.Background(), key, increment)
|
||||
}
|
||||
|
||||
// IncrbyFloatCtx is the implementation of redis incrby command.
|
||||
// IncrbyFloatCtx is the implementation of redis hincrbyfloat command.
|
||||
func (s *Redis) IncrbyFloatCtx(ctx context.Context, key string, increment float64) (val float64, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
@@ -1170,6 +1179,26 @@ func (s *Redis) LpopCtx(ctx context.Context, key string) (val string, err error)
|
||||
return
|
||||
}
|
||||
|
||||
// LpopCount is the implementation of redis lpopCount command.
|
||||
func (s *Redis) LpopCount(key string, count int) ([]string, error) {
|
||||
return s.LpopCountCtx(context.Background(), key, count)
|
||||
}
|
||||
|
||||
// LpopCountCtx is the implementation of redis lpopCount command.
|
||||
func (s *Redis) LpopCountCtx(ctx context.Context, key string, count int) (val []string, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val, err = conn.LPopCount(ctx, key, count).Result()
|
||||
return err
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Lpush is the implementation of redis lpush command.
|
||||
func (s *Redis) Lpush(key string, values ...any) (int, error) {
|
||||
return s.LpushCtx(context.Background(), key, values...)
|
||||
@@ -1432,6 +1461,26 @@ func (s *Redis) RpopCtx(ctx context.Context, key string) (val string, err error)
|
||||
return
|
||||
}
|
||||
|
||||
// RpopCount is the implementation of redis rpopCount command.
|
||||
func (s *Redis) RpopCount(key string, count int) ([]string, error) {
|
||||
return s.RpopCountCtx(context.Background(), key, count)
|
||||
}
|
||||
|
||||
// RpopCountCtx is the implementation of redis rpopCount command.
|
||||
func (s *Redis) RpopCountCtx(ctx context.Context, key string, count int) (val []string, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val, err = conn.RPopCount(ctx, key, count).Result()
|
||||
return err
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Rpush is the implementation of redis rpush command.
|
||||
func (s *Redis) Rpush(key string, values ...any) (int, error) {
|
||||
return s.RpushCtx(context.Background(), key, values...)
|
||||
@@ -1585,6 +1634,25 @@ func (s *Redis) ScriptLoadCtx(ctx context.Context, script string) (string, error
|
||||
return conn.ScriptLoad(ctx, script).Result()
|
||||
}
|
||||
|
||||
// ScriptRun is the implementation of *redis.Script run command.
|
||||
func (s *Redis) ScriptRun(script *Script, keys []string, args ...any) (any, error) {
|
||||
return s.ScriptRunCtx(context.Background(), script, keys, args...)
|
||||
}
|
||||
|
||||
// ScriptRunCtx is the implementation of *redis.Script run command.
|
||||
func (s *Redis) ScriptRunCtx(ctx context.Context, script *Script, keys []string, args ...any) (val any, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val, err = script.Run(ctx, conn, keys, args...).Result()
|
||||
return err
|
||||
}, acceptable)
|
||||
return
|
||||
}
|
||||
|
||||
// Set is the implementation of redis set command.
|
||||
func (s *Redis) Set(key, value string) error {
|
||||
return s.SetCtx(context.Background(), key, value)
|
||||
@@ -2729,6 +2797,23 @@ func (s *Redis) ZunionstoreCtx(ctx context.Context, dest string, store *ZStore)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Redis) checkConnection(pingTimeout time.Duration) error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
timeout := defaultPingTimeout
|
||||
if pingTimeout > 0 {
|
||||
timeout = pingTimeout
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
return conn.Ping(ctx).Err()
|
||||
}
|
||||
|
||||
// Cluster customizes the given Redis as a cluster.
|
||||
func Cluster() Option {
|
||||
return func(r *Redis) {
|
||||
@@ -2755,6 +2840,14 @@ func WithTLS() Option {
|
||||
}
|
||||
}
|
||||
|
||||
// withHook customizes the given Redis with given hook, only for private use now,
|
||||
// maybe expose later.
|
||||
func withHook(hook red.Hook) Option {
|
||||
return func(r *Redis) {
|
||||
r.hooks = append(r.hooks, hook)
|
||||
}
|
||||
}
|
||||
|
||||
func acceptable(err error) bool {
|
||||
return err == nil || err == red.Nil || err == context.Canceled
|
||||
}
|
||||
|
||||
@@ -16,6 +16,25 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
type myHook struct {
|
||||
red.Hook
|
||||
includePing bool
|
||||
}
|
||||
|
||||
var _ red.Hook = myHook{}
|
||||
|
||||
func (m myHook) BeforeProcess(ctx context.Context, cmd red.Cmder) (context.Context, error) {
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
func (m myHook) AfterProcess(ctx context.Context, cmd red.Cmder) error {
|
||||
// skip ping cmd
|
||||
if cmd.Name() == "ping" && !m.includePing {
|
||||
return nil
|
||||
}
|
||||
return errors.New("hook error")
|
||||
}
|
||||
|
||||
func TestNewRedis(t *testing.T) {
|
||||
r1, err := miniredis.Run()
|
||||
assert.NoError(t, err)
|
||||
@@ -126,6 +145,31 @@ func TestNewRedis(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedis_NonBlock(t *testing.T) {
|
||||
logx.Disable()
|
||||
|
||||
t.Run("nonBlock true", func(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
// use hook to simulate redis ping error
|
||||
_, err := NewRedis(RedisConf{
|
||||
Host: s.Addr(),
|
||||
NonBlock: true,
|
||||
Type: NodeType,
|
||||
}, withHook(myHook{includePing: true}))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("nonBlock false", func(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
_, err := NewRedis(RedisConf{
|
||||
Host: s.Addr(),
|
||||
NonBlock: false,
|
||||
Type: NodeType,
|
||||
}, withHook(myHook{includePing: true}))
|
||||
assert.ErrorContains(t, err, "redis connect error")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Decr(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
_, err := New(client.Addr, badType()).Decr("a")
|
||||
@@ -196,6 +240,24 @@ func TestRedis_Eval(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_ScriptRun(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
sc := NewScript(`redis.call("EXISTS", KEYS[1])`)
|
||||
sc2 := NewScript(`return redis.call("EXISTS", KEYS[1])`)
|
||||
_, err := New(client.Addr, badType()).ScriptRun(sc, []string{"notexist"})
|
||||
assert.NotNil(t, err)
|
||||
_, err = client.ScriptRun(sc, []string{"notexist"})
|
||||
assert.Equal(t, Nil, err)
|
||||
err = client.Set("key1", "value1")
|
||||
assert.Nil(t, err)
|
||||
_, err = client.ScriptRun(sc, []string{"key1"})
|
||||
assert.Equal(t, Nil, err)
|
||||
val, err := client.ScriptRun(sc2, []string{"key1"})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), val)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_GeoHash(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
_, err := client.GeoHash("parent", "child1", "child2")
|
||||
@@ -507,6 +569,14 @@ func TestRedis_List(t *testing.T) {
|
||||
vals, err = client.Lrange("key", 0, 10)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"value2", "value3"}, vals)
|
||||
vals, err = client.LpopCount("key", 2)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"value2", "value3"}, vals)
|
||||
_, err = client.Lpush("key", "value1", "value2")
|
||||
assert.Nil(t, err)
|
||||
vals, err = client.RpopCount("key", 4)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"value1", "value2"}, vals)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -523,6 +593,34 @@ func TestRedis_List(t *testing.T) {
|
||||
|
||||
_, err = client.Rpush("key", "value3", "value4")
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = client.LpopCount("key", 2)
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = client.RpopCount("key", 2)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
})
|
||||
t.Run("list redis type error", func(t *testing.T) {
|
||||
runOnRedisWithError(t, func(client *Redis) {
|
||||
client.Type = "nil"
|
||||
_, err := client.Llen("key")
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = client.Lpush("key", "value1", "value2")
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = client.Lrem("key", 2, "value1")
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = client.Rpush("key", "value3", "value4")
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = client.LpopCount("key", 2)
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = client.RpopCount("key", 2)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user