Compare commits

..

72 Commits

Author SHA1 Message Date
kevin
2ea0a843f8 chore: remove any keywords 2023-03-04 20:54:26 +08:00
Kevin Wan
9e0e01b2bc chore: add tests (#2960) 2023-03-04 20:38:50 +08:00
yangjinheng
af50a80d01 timeout writer add hijack 2023-03-04 20:38:45 +08:00
yangjinheng
703fb8d970 Update timeouthandler.go 2023-03-04 20:38:40 +08:00
MarkJoyMa
e964e530e1 x 2023-03-04 20:32:21 +08:00
MarkJoyMa
52265087d1 x 2023-03-04 20:32:16 +08:00
MarkJoyMa
b4c2677eb9 add ut 2023-03-04 20:32:10 +08:00
MarkJoyMa
30296fb1ca feat: conf add FillDefault func 2023-03-04 20:31:44 +08:00
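"conf add FillDefault func" suggests filling zero-valued config fields from declared defaults. A hedged reflection-based sketch of the idea — the `def` tag name and the function below are illustrative assumptions, not go-zero's actual API:

```go
package main

import (
	"fmt"
	"reflect"
	"strconv"
)

// fillDefault sets string and int fields that are still at their zero
// value from a `def` struct tag. Illustrative only; a real config
// library handles more kinds, nesting, and validation.
func fillDefault(v any) error {
	rv := reflect.ValueOf(v).Elem()
	rt := rv.Type()
	for i := 0; i < rv.NumField(); i++ {
		def, ok := rt.Field(i).Tag.Lookup("def")
		if !ok || !rv.Field(i).IsZero() {
			continue // no default declared, or value explicitly set
		}
		switch rv.Field(i).Kind() {
		case reflect.String:
			rv.Field(i).SetString(def)
		case reflect.Int:
			n, err := strconv.Atoi(def)
			if err != nil {
				return err
			}
			rv.Field(i).SetInt(int64(n))
		}
	}
	return nil
}

type ServerConf struct {
	Host string `def:"0.0.0.0"`
	Port int    `def:"8080"`
}

func main() {
	c := ServerConf{Host: "localhost"} // Port left zero, so its default applies
	if err := fillDefault(&c); err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", c) // explicit Host kept, default Port filled in
}
```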
zhoumingji
356c80defd Fix bug in dartgen: The property 'isEmpty' can't be unconditionally accessed because the receiver can be 'null' 2023-03-04 20:31:38 +08:00
zhoumingji
8c31525378 Fix bug in dartgen: Increase the processing logic when route.RequestType is empty 2023-03-04 20:31:30 +08:00
cui fliter
2cf09f3c36 fix functiom name
Signed-off-by: cui fliter <imcusg@gmail.com>
2023-03-04 20:31:20 +08:00
Kevin Wan
d41e542c92 feat: support grpc client keepalive config (#2950) 2023-03-04 20:30:31 +08:00
tanglihao
265a24ac6d fix code format style use const config.DefaultFormat 2023-03-04 20:30:21 +08:00
tanglihao
7d88fc39dc fix log name conflict 2023-03-04 20:30:16 +08:00
anqiansong
6957b6a344 format code 2023-03-04 20:30:10 +08:00
anqiansong
bca6a230c8 remove unused code 2023-03-04 20:30:04 +08:00
anqiansong
cc8413d683 remove unused code 2023-03-04 20:29:56 +08:00
anqiansong
3842283fa8 Fix #2879 2023-03-04 20:29:41 +08:00
qiying.wang
fe13a533f5 chore: remove redundant prefix of "error: " in error creation 2023-03-04 20:26:40 +08:00
qiying.wang
7a327ccda4 chore: add tests for logc debug 2023-03-04 20:25:52 +08:00
qiying.wang
06e4507406 feat: add debug log for logc 2023-03-04 20:25:27 +08:00
kevin
8794d5b753 chore: add comments 2023-03-04 20:25:21 +08:00
kevin
9bfa63d995 chore: add more tests 2023-03-04 20:25:15 +08:00
kevin
a432b121fb chore: add more tests 2023-03-04 20:25:07 +08:00
kevin
b61c94bb66 feat: check key overwritten 2023-03-04 20:24:33 +08:00
Kevin Wan
93fcf899dc fix: config map cannot handle case-insensitive keys. (#2932)
* fix: #2922

* chore: rename const

* feat: support anonymous map field

* feat: support anonymous map field
2023-03-04 20:23:53 +08:00
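PR #2932 fixed the config map's handling of case-insensitive keys, and the related "check key overwritten" commit guards against two keys that collide after case folding silently overwriting each other. A simplified sketch of such a collision check (go-zero's actual key canonicalization may differ):

```go
package main

import (
	"fmt"
	"strings"
)

// foldKeys lowercases map keys, reporting an error when two distinct
// keys collide after folding instead of silently dropping one value.
func foldKeys(m map[string]any) (map[string]any, error) {
	out := make(map[string]any, len(m))
	seen := make(map[string]string, len(m))
	for k, v := range m {
		lk := strings.ToLower(k)
		if prev, ok := seen[lk]; ok {
			return nil, fmt.Errorf("key %q overwrites %q (case-insensitive)", k, prev)
		}
		seen[lk] = k
		out[lk] = v
	}
	return out, nil
}

func main() {
	// "Host" and "host" fold to the same key: flagged, not overwritten.
	if _, err := foldKeys(map[string]any{"Host": "a", "host": "b"}); err != nil {
		fmt.Println(err)
	}
}
```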
Kevin Wan
9f4b3bae92 fix: #2899, using autoscaling/v2beta2 instead of v2beta1 (#2900)
* fix: #2899, using autoscaling/v2 instead of v2beta1

* chore: change hpa definition

---------

Co-authored-by: kevin.wan <kevin.wan@yijinin.com>
2023-03-04 20:22:27 +08:00
Kevin Wan
805cb87d98 chore: refine rest validator (#2928)
* chore: refine rest validator

* chore: add more tests

* chore: reformat code

* chore: add comments
2023-03-04 20:22:10 +08:00
Qiying Wang
366131640e feat: add configurable validator for httpx.Parse (#2923)
Co-authored-by: qiying.wang <qiying.wang@highlight.mobi>
2023-03-04 20:22:05 +08:00
Kevin Wan
956884a3ff fix: timeout not working if greater than global rest timeout (#2926) 2023-03-04 20:21:59 +08:00
raymonder jin
f571cb8af2 del unnecessary blank 2023-03-04 20:21:54 +08:00
Kevin Wan
cc5acf3b90 chore: reformat code (#2925) 2023-03-04 20:21:49 +08:00
chenquan
e1aa665443 fix: fixed the bug that old trace instances may be fetched 2023-03-04 20:21:43 +08:00
xiandong
cd357d9484 rm parseErr when kindJaeger 2023-03-04 20:21:28 +08:00
xiandong
6d4d7cbd6b rm kindJaegerUdp 2023-03-04 20:21:18 +08:00
xiandong
c593b5b531 add parseEndpoint 2023-03-04 20:20:29 +08:00
xiandong
fd5b38b07c add parseEndpoint 2023-03-04 20:20:17 +08:00
xiandong
41efb48f55 add test for Endpoint of kindJaegerUdp 2023-03-04 20:19:40 +08:00
xiandong
0ef3626839 add test for Endpoint of kindJaegerUdp 2023-03-04 20:19:34 +08:00
xiandong
77a72b16e9 add kindJaegerUdp 2023-03-04 20:19:25 +08:00
Kevin Wan
21566f1b7a chore: reformat code (#2903) 2023-03-04 20:17:35 +08:00
anqiansong
b2646e228b feat: Add request.ts (#2901)
* Add request.ts

* Update comments

* Refactor request filename
2023-03-04 20:17:21 +08:00
cong
588b883710 refactor: simplify sqlx fail fast ping and simplify miniredis setup in test (#2897)
* chore(redistest): simplify miniredis setup in test

* refactor(sqlx): simplify sqlx fail fast ping

* chore: close connection if not available
2023-03-04 20:17:16 +08:00
Kevin Wan
033910bbd8 Update readme-cn.md 2023-03-04 20:17:11 +08:00
fondoger
530dd79e3f Fix bug in dart api gen: path parameter is not replaced 2023-03-04 20:17:05 +08:00
Kevin Wan
cd5263ac75 Update readme-cn.md 2023-03-04 20:16:58 +08:00
Kevin Wan
ea3302a468 fix: test failures (#2892)
Co-authored-by: kevin.wan <kevin.wan@yijinin.com>
2023-03-04 20:16:50 +08:00
fondoger
abf15b373c Fix Dart API generation bugs; Add ability to generate API for path parameters (#2887)
* Fix bug in dartgen: Import path should match the generated api filename

* Use Route.HandlerName as generated dart API function name

Reasons:
- There is bug when using url path name as function name, because it may have invalid characters such as ":"
- Switching to HandlerName aligns with other languages such as typescript generation

* [DartGen] Add ability to generate api for url path parameters such as /path/:param
2023-03-04 20:16:44 +08:00
Kevin Wan
a865e9ee29 refactor: simplify stringx.Replacer, and avoid potential infinite loops (#2877)
* simplify replace

* backup

* refactor: simplify stringx.Replacer

* chore: add comments and const

* chore: add more tests

* chore: rename variable
2023-03-04 20:16:37 +08:00
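The stringx.Replacer commits (#2877, #2867 below) revolve around two invariants: always prefer the longest matching pattern at each position, and never re-scan replaced output in a way that can loop forever. A naive single-pass sketch of longest-match replacement — go-zero's real implementation uses a trie; this linear scan is only for illustration:

```go
package main

import (
	"fmt"
	"strings"
)

// replaceLongest scans left to right and, at each position, applies the
// longest mapping key that matches, then advances past the replacement
// so output is never re-scanned (ruling out infinite loops).
func replaceLongest(s string, mapping map[string]string) string {
	var b strings.Builder
	for i := 0; i < len(s); {
		best := ""
		for k := range mapping {
			if len(k) > len(best) && strings.HasPrefix(s[i:], k) {
				best = k
			}
		}
		if best == "" {
			b.WriteByte(s[i]) // no pattern starts here; copy one byte
			i++
			continue
		}
		b.WriteString(mapping[best])
		i += len(best)
	}
	return b.String()
}

func main() {
	m := map[string]string{"日": "sun", "日本": "Japan"}
	// The longer key "日本" must win over its prefix "日".
	fmt.Println(replaceLongest("日本的日", m))
}
```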
Kevin Wan
f8292198cf Update readme-cn.md 2023-03-04 20:15:38 +08:00
Kevin Wan
016d965f56 chore: refactor (#2875) 2023-03-04 20:15:30 +08:00
dahaihu
95d7c73409 fix Replacer suffix match, and add test case (#2867)
* fix: replace shoud replace the longest match

* feat: revert bytes.Buffer to strings.Builder

* fix: loop reset nextStart

* feat: add node longest match test

* feat: add replacer suffix match test case

* feat: multiple match

* fix: partial match ends

* fix: replace look back upon error

* feat: rm unnecessary branch

---------

Co-authored-by: hudahai <hscxrzs@gmail.com>
Co-authored-by: hushichang <hushichang@sensetime.com>
2023-03-04 20:15:25 +08:00
Kevin Wan
939ef2a181 chore: add more tests (#2873) 2023-03-04 20:15:18 +08:00
Kevin Wan
f0b8dd45fe fix: test failure (#2874) 2023-03-04 20:15:08 +08:00
Mikael
0ba9335b04 only unmashal public variables (#2872)
* only unmashal public variables

* only unmashal public variables

* only unmashal public variables

* only unmashal public variables
2023-03-04 20:15:01 +08:00
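"only unmashal public variables" (#2872) reflects a Go constraint: `reflect` cannot set unexported struct fields, so an unmarshaler must skip them rather than fail. A sketch of the guard using a hypothetical helper, not the library's own code:

```go
package main

import (
	"fmt"
	"reflect"
)

type conf struct {
	Name   string // exported: an unmarshaler may set this
	secret string // unexported: reflect cannot set it, so it must be skipped
}

// settableFields returns the names of fields an unmarshaler may write.
func settableFields(v any) []string {
	rt := reflect.TypeOf(v)
	var out []string
	for i := 0; i < rt.NumField(); i++ {
		if rt.Field(i).IsExported() { // false for lowercase field names
			out = append(out, rt.Field(i).Name)
		}
	}
	return out
}

func main() {
	fmt.Println(settableFields(conf{})) // only the exported field is listed
}
```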
Kevin Wan
04f181f0b4 chore: add more tests (#2866)
* chore: add more tests

* chore: add more tests

* chore: fix test failure
2023-03-04 20:14:54 +08:00
hudahai
89f841c126 fix: loop reset nextStart 2023-03-04 20:14:48 +08:00
hudahai
d785c8c377 feat: revert bytes.Buffer to strings.Builder 2023-03-04 20:14:41 +08:00
hudahai
687a1d15da fix: replace shoud replace the longest match 2023-03-04 20:14:35 +08:00
Kevin Wan
aaa974e1ad Update readme-cn.md 2023-03-04 20:14:22 +08:00
Kevin Wan
2779568ccf fix: conf anonymous overlay problem (#2847) 2023-03-04 20:14:10 +08:00
Kevin Wan
f7d50ae626 Update readme-cn.md 2023-03-04 20:14:01 +08:00
Kevin Wan
33594ea350 Chore/rewire (#2836)
* fix: problem on name overlaping in config (#2820)

* chore: fix missing funcs on windows (#2825)

* chore: add more tests (#2812)

* chore: add more tests

* chore: add more tests

* chore: add more tests (#2814)

* chore: add more tests (#2815)

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* feat: upgrade go to v1.18 (#2817)

* feat: upgrade go to v1.18

* feat: upgrade go to v1.18

* chore: change interface{} to any (#2818)

* chore: change interface{} to any

* chore: update goctl version to 1.5.0

* chore: update goctl deps

* chore: update goctl interface{} to any (#2819)

* chore: update goctl interface{} to any

* chore: update goctl interface{} to any

* chore(deps): bump google.golang.org/grpc from 1.52.0 to 1.52.3 (#2823)

* support custom maxBytes in API file (#2822)

* feat: mapreduce generic version (#2827)

* feat: mapreduce generic version

* fix: gateway mr type issue

---------

Co-authored-by: kevin.wan <kevin.wan@yijinin.com>

* feat: add MustNewRedis (#2824)

* feat: add MustNewRedis

* feat: add MustNewRedis

* feat: add MustNewRedis

* x

* x

* fix ut

* x

* x

* x

* x

* x

* chore: improve codecov (#2828)

* feat: converge grpc interceptor processing (#2830)

* feat: converge grpc interceptor processing

* x

* x

* chore(deps): bump go.opentelemetry.io/otel/exporters/zipkin (#2831)

* chore(deps): bump go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp (#2833)

Bumps [go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp](https://github.com/open-telemetry/opentelemetry-go) from 1.11.2 to 1.12.0.
- [Release notes](https://github.com/open-telemetry/opentelemetry-go/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-go/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-go/compare/v1.11.2...v1.12.0)

---
updated-dependencies:
- dependency-name: go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* chore(deps): bump go.opentelemetry.io/otel/exporters/jaeger (#2832)

Bumps [go.opentelemetry.io/otel/exporters/jaeger](https://github.com/open-telemetry/opentelemetry-go) from 1.11.2 to 1.12.0.
- [Release notes](https://github.com/open-telemetry/opentelemetry-go/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-go/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-go/compare/v1.11.2...v1.12.0)

---
updated-dependencies:
- dependency-name: go.opentelemetry.io/otel/exporters/jaeger
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Xiaoju Jiang <44432198+jiang4869@users.noreply.github.com>
Co-authored-by: kevin.wan <kevin.wan@yijinin.com>
Co-authored-by: MarkJoyMa <64180138+MarkJoyMa@users.noreply.github.com>
2023-03-04 20:13:37 +08:00
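Among the squashed changes above, "mapreduce generic version" (#2827) moves the mr package from `interface{}` to type parameters, enabled by the Go 1.18 upgrade in the same batch. A toy generic parallel map in that spirit — a sketch of the pattern, not the mr package's API:

```go
package main

import (
	"fmt"
	"sync"
)

// pmap applies fn to every item concurrently and preserves input order -
// the basic shape that generics let a mapreduce library express without
// interface{} casts at the call site.
func pmap[T, U any](in []T, fn func(T) U) []U {
	out := make([]U, len(in))
	var wg sync.WaitGroup
	for i, v := range in {
		wg.Add(1)
		go func(i int, v T) {
			defer wg.Done()
			out[i] = fn(v) // each goroutine writes its own slot
		}(i, v)
	}
	wg.Wait()
	return out
}

func main() {
	fmt.Println(pmap([]int{1, 2, 3}, func(n int) int { return n * n })) // [1 4 9]
}
```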
MarkJoyMa
ee2ec974c4 feat: converge grpc interceptor processing (#2830)
* feat: converge grpc interceptor processing

* x

* x
2023-03-04 20:12:30 +08:00
Kevin Wan
fd2f2f0f54 chore: improve codecov (#2828) 2023-03-04 20:12:16 +08:00
MarkJoyMa
86a2429d7d feat: add MustNewRedis (#2824)
* feat: add MustNewRedis

* feat: add MustNewRedis

* feat: add MustNewRedis

* x

* x

* fix ut

* x

* x

* x

* x

* x
2023-03-04 20:12:05 +08:00
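`MustNewRedis` (#2824) follows Go's `Must*` convention: return the value or panic, for wiring mandatory dependencies at startup where an error cannot be handled. A generic sketch of the pattern — the real `MustNewRedis` is specific to go-zero's redis client, and `newClient` below is a stand-in:

```go
package main

import (
	"errors"
	"fmt"
)

// must converts a (value, error) constructor result into value-or-panic,
// the shape a Must* constructor exposes for required startup dependencies.
func must[T any](v T, err error) T {
	if err != nil {
		panic(err)
	}
	return v
}

// newClient is a hypothetical fallible constructor used for illustration.
func newClient(addr string) (string, error) {
	if addr == "" {
		return "", errors.New("empty addr")
	}
	return "client(" + addr + ")", nil
}

func main() {
	c := must(newClient("127.0.0.1:6379"))
	fmt.Println(c)
	defer func() { fmt.Println("recovered:", recover()) }()
	must(newClient("")) // panics; recovered by the deferred handler above
}
```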
Xiaoju Jiang
e5fe5dcc50 support custom maxBytes in API file (#2822) 2023-03-04 20:11:55 +08:00
Kevin Wan
b510e7c242 chore: fix missing funcs on windows (#2825) 2023-03-04 20:11:46 +08:00
Kevin Wan
dfe92e709f fix: problem on name overlaping in config (#2820) 2023-03-04 20:11:18 +08:00
Kevin Wan
cb649cf627 chore: add more tests (#2815)
* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests
2023-03-04 20:11:03 +08:00
Kevin Wan
ce19a5ade6 chore: add more tests (#2814) 2023-03-04 20:10:57 +08:00
Kevin Wan
6dc56de714 chore: add more tests (#2812)
* chore: add more tests

* chore: add more tests
2023-03-04 20:09:03 +08:00
667 changed files with 10493 additions and 31306 deletions


@@ -1,13 +1,6 @@
-coverage:
-  status:
-    patch: true
-    project: false # disabled because project coverage is not stable
 comment:
   layout: "flags, files"
   behavior: once
   require_changes: true
 ignore:
   - "tools"
-  - "**/mock"
-  - "**/*_mock.go"
-  - "**/*test"


@@ -1,7 +1 @@
 **/.git
-.dockerignore
-Dockerfile
-goctl
-Makefile
-readme.md
-readme-cn.md

.github/FUNDING.yml vendored

@@ -10,4 +10,4 @@ liberapay: # Replace with a single Liberapay username
 issuehunt: # Replace with a single IssueHunt username
 otechie: # Replace with a single Otechie username
 custom: # https://gitee.com/kevwan/static/raw/master/images/sponsor.jpg
-ethereum: # 0x5052b7f6B937B02563996D23feb69b38D06Ca150 | kevwan
+ethereum: 0x5052b7f6B937B02563996D23feb69b38D06Ca150 | kevwan


@@ -5,19 +5,7 @@
 version: 2
 updates:
-  - package-ecosystem: "docker" # Update image tags in Dockerfile
-    directory: "/"
-    schedule:
-      interval: "weekly"
-  - package-ecosystem: "github-actions" # Update GitHub Actions
-    directory: "/"
-    schedule:
-      interval: "weekly"
   - package-ecosystem: "gomod" # See documentation for possible values
     directory: "/" # Location of package manifests
     schedule:
       interval: "daily"
-  - package-ecosystem: "gomod" # See documentation for possible values
-    directory: "/tools/goctl" # Location of package manifests
-    schedule:
-      interval: "daily"


@@ -35,11 +35,11 @@ jobs:
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@v3
       # Initializes the CodeQL tools for scanning.
       - name: Initialize CodeQL
-        uses: github/codeql-action/init@v3
+        uses: github/codeql-action/init@v2
         with:
           languages: ${{ matrix.language }}
         # If you wish to specify custom queries, you can do so here or in a config file.
@@ -50,7 +50,7 @@
       # Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
       # If this step fails, then you should remove it and run the build manually (see below)
       - name: Autobuild
-        uses: github/codeql-action/autobuild@v3
+        uses: github/codeql-action/autobuild@v2
       # Command-line programs to run using the OS shell.
       # 📚 https://git.io/JvXDl
@@ -64,4 +64,4 @@
       #   make release
       - name: Perform CodeQL Analysis
-        uses: github/codeql-action/analyze@v3
+        uses: github/codeql-action/analyze@v2


@@ -12,12 +12,12 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Check out code into the Go module directory
-        uses: actions/checkout@v4
+        uses: actions/checkout@v3
       - name: Set up Go 1.x
-        uses: actions/setup-go@v5
+        uses: actions/setup-go@v3
         with:
-          go-version: '1.19'
+          go-version: ^1.16
          check-latest: true
          cache: true
        id: go
@@ -29,31 +29,27 @@ jobs:
       - name: Lint
         run: |
           go vet -stdmethods=false $(go list ./...)
-          go mod tidy
-          if ! test -z "$(git status --porcelain)"; then
-            echo "Please run 'go mod tidy'"
-            exit 1
-          fi
+          go install mvdan.cc/gofumpt@latest
+          test -z "$(gofumpt -l -extra .)" || echo "Please run 'gofumpt -l -w -extra .'"
       - name: Test
         run: go test -race -coverprofile=coverage.txt -covermode=atomic ./...
       - name: Codecov
-        uses: codecov/codecov-action@v4
+        uses: codecov/codecov-action@v3
   test-win:
     name: Windows
     runs-on: windows-latest
     steps:
       - name: Checkout codebase
-        uses: actions/checkout@v4
+        uses: actions/checkout@v3
       - name: Set up Go 1.x
-        uses: actions/setup-go@v5
+        uses: actions/setup-go@v3
         with:
-          # use 1.19 to guarantee Go 1.19 compatibility
-          go-version: '1.19'
+          # use 1.16 to guarantee Go 1.16 compatibility
+          go-version: 1.16
          check-latest: true
          cache: true
@@ -61,5 +57,5 @@ jobs:
         run: |
           go mod verify
           go mod download
-          go test ./...
+          go test -v -race ./...
           cd tools/goctl && go build -v goctl.go


@@ -7,7 +7,7 @@ jobs:
   close-issues:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/stale@v9
+      - uses: actions/stale@v6
        with:
          days-before-issue-stale: 365
          days-before-issue-close: 90


@@ -16,13 +16,13 @@ jobs:
         - goarch: "386"
           goos: darwin
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v3
       - uses: zeromicro/go-zero-release-action@master
         with:
           github_token: ${{ secrets.GITHUB_TOKEN }}
           goos: ${{ matrix.goos }}
           goarch: ${{ matrix.goarch }}
-          goversion: "https://dl.google.com/go/go1.19.13.linux-amd64.tar.gz"
+          goversion: "https://dl.google.com/go/go1.17.5.linux-amd64.tar.gz"
           project_path: "tools/goctl"
           binary_name: "goctl"
           extra_files: tools/goctl/readme.md tools/goctl/readme-cn.md


@@ -5,7 +5,7 @@ jobs:
     name: runner / staticcheck
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v3
       - uses: reviewdog/action-staticcheck@v1
        with:
          github_token: ${{ secrets.github_token }}

.gitignore vendored

@@ -11,14 +11,12 @@
 !api
 # ignore
-**/.idea
-**/.vscode
+.idea
 **/.DS_Store
 **/logs
-**/adhoc
-**/coverage.txt
 # for test purpose
+**/adhoc
 go.work
 go.work.sum


@@ -1,76 +1,102 @@
-# 🚀 Contributing to go-zero
+# Contributing
-Welcome to the go-zero community! We're thrilled to have you here. Contributing to our project is a fantastic way to be a part of the go-zero journey. Let's make this guide exciting and fun!
+Welcome to go-zero!
-## 📜 Before You Dive In
+- [Before you get started](#before-you-get-started)
+  - [Code of Conduct](#code-of-conduct)
+  - [Community Expectations](#community-expectations)
+- [Getting started](#getting-started)
+- [Your First Contribution](#your-first-contribution)
+  - [Find something to work on](#find-something-to-work-on)
+    - [Find a good first topic](#find-a-good-first-topic)
+      - [Work on an Issue](#work-on-an-issue)
+  - [File an Issue](#file-an-issue)
+- [Contributor Workflow](#contributor-workflow)
+  - [Creating Pull Requests](#creating-pull-requests)
+  - [Code Review](#code-review)
+  - [Testing](#testing)
-### 🤝 Code of Conduct
+# Before you get started
-Let's start on the right foot. Please take a moment to read and embrace our [Code of Conduct](/code-of-conduct.md). We're all about creating a welcoming and respectful environment.
+## Code of Conduct
-### 🌟 Community Expectations
+Please make sure to read and observe our [Code of Conduct](/code-of-conduct.md).
-At go-zero, we're like a close-knit family, and we believe in creating a healthy, friendly, and productive atmosphere. It's all about sharing knowledge and building amazing things together.
+## Community Expectations
-## 🚀 Getting Started
+go-zero is a community project driven by its community which strives to promote a healthy, friendly and productive environment.
+go-zero is a web and rpc framework written in Go. It's born to ensure the stability of the busy sites with resilient design. Builtin goctl greatly improves the development productivity.
-Get your adventure rolling! Here's how to begin:
+# Getting started
-1. 🍴 **Fork the Repository**: Head over to the GitHub repository and fork it to your own space.
-2. 🛠️ **Make Your Magic**: Work your magic in your forked repository. Create new features, squash bugs, or improve documentation - it's your world to conquer!
-3. 🚀 **Submit a PR (Pull Request)**: When you're ready to unveil your creation, submit a Pull Request. We can't wait to see your awesome work!
+- Fork the repository on GitHub.
+- Make your changes on your fork repository.
+- Submit a PR.
-## 🌟 Your First Contribution
+# Your First Contribution
-We're here to guide you on your quest to become a go-zero contributor. Whether you want to file issues, develop features, or tame some critical bugs, we've got you covered.
+We will help you to contribute in different areas like filing issues, developing features, fixing critical bugs and
+getting your work reviewed and merged.
-If you have questions or need guidance at any stage, don't hesitate to [open an issue](https://github.com/zeromicro/go-zero/issues/new/choose).
+If you have questions about the development process,
+feel free to [file an issue](https://github.com/zeromicro/go-zero/issues/new/choose).
-## 🔍 Find Something to Work On
+## Find something to work on
-Ready to dive into the action? There are several ways to contribute:
+We are always in need of help, be it fixing documentation, reporting bugs or writing some code.
+Look at places where you feel best coding practices aren't followed, code refactoring is needed or tests are missing.
+Here is how you get started.
-### 💼 Find a Good First Topic
+### Find a good first topic
-Discover easy-entry issues labeled as [help wanted](https://github.com/zeromicro/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) or [good first issue](https://github.com/zeromicro/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22). These issues are perfect for newcomers and don't require deep knowledge of the system. We're here to assist you with these tasks.
+[go-zero](https://github.com/zeromicro/go-zero) has beginner-friendly issues that provide a good first issue.
+For example, [go-zero](https://github.com/zeromicro/go-zero) has
+[help wanted](https://github.com/zeromicro/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) and
+[good first issue](https://github.com/zeromicro/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)
+labels for issues that should not need deep knowledge of the system.
+We can help new contributors who wish to work on such issues.
+Another good way to contribute is to find a documentation improvement, such as a missing/broken link.
+Please see [Contributing](#contributing) below for the workflow.
-### 🪄 Work on an Issue
+#### Work on an issue
-Once you've picked an issue that excites you, let us know by commenting on it. Our maintainers will assign it to you, and you can embark on your mission!
+When you are willing to take on an issue, just reply on the issue. The maintainer will assign it to you.
-### 📢 File an Issue
+### File an Issue
-Reporting an issue is just as valuable as code contributions. If you discover a problem, don't hesitate to [open an issue](https://github.com/zeromicro/go-zero/issues/new/choose). Be sure to follow our guidelines when submitting an issue.
+While we encourage everyone to contribute code, it is also appreciated when someone reports an issue.
+Please follow the prompted submission guidelines while opening an issue.
-## 🎯 Contributor Workflow
+# Contributor Workflow
-Here's a rough guide to your contributor journey:
+Please do not ever hesitate to ask a question or send a pull request.
-1. 🌱 Create a New Branch: Start by creating a topic branch, usually based on the 'master' branch. This is where your contribution will grow.
-2. 💡 Make Commits: Commit your work in logical units. Each commit should tell a story.
-3. 🚀 Push Changes: Push the changes in your topic branch to your personal fork of the repository.
-4. 📦 Submit a Pull Request: When your creation is complete, submit a Pull Request to the [go-zero repository](https://github.com/zeromicro/go-zero).
+This is a rough outline of what a contributor's workflow looks like:
+- Create a topic branch from where to base the contribution. This is usually master.
+- Make commits of logical units.
+- Push changes in a topic branch to a personal fork of the repository.
+- Submit a pull request to [go-zero](https://github.com/zeromicro/go-zero).
-## 🌠 Creating Pull Requests
+## Creating Pull Requests
-Pull Requests (PRs) are your way of making a grand entrance with your contribution. Here's how to do it:
+Pull requests are often called simply "PR".
+go-zero generally follows the standard [github pull request](https://help.github.com/articles/about-pull-requests/) process.
+To submit a proposed change, please develop the code/fix and add new test cases.
+After that, run these local verifications before submitting pull request to predict the pass or
+fail of continuous integration.
-- 💼 Format Your Code: Ensure your code is beautifully formatted with `gofmt`.
-- 🏃 Run Tests: Verify that your changes pass all the tests, including data race tests. Run `go test -race ./...` for the ultimate validation.
+* Format the code with `gofmt`
+* Run the test with data race enabled `go test -race ./...`
-## 👁️‍🗨️ Code Review
+## Code Review
-Getting your PR reviewed is the final step before your contribution becomes part of go-zero's magical world. To make the process smooth, keep these things in mind:
+To make it easier for your PR to receive reviews, consider the reviewers will need you to:
-- 🧙‍♀️ Follow Good Coding Practices: Stick to [good coding guidelines](https://github.com/golang/go/wiki/CodeReviewComments).
-- 📝 Write Awesome Commit Messages: Craft [impressive commit messages](https://chris.beams.io/posts/git-commit/) - they're like spells in the wizard's book!
-- 🔍 Break It Down: For larger changes, consider breaking them into a series of smaller, logical patches. Each patch should make an understandable and meaningful improvement.
+* follow [good coding guidelines](https://github.com/golang/go/wiki/CodeReviewComments).
+* write [good commit messages](https://chris.beams.io/posts/git-commit/).
+* break large changes into a logical series of smaller patches which individually make easily understandable changes, and in aggregate solve a broader issue.
-Congratulations on your contribution journey! We're thrilled to have you as part of our go-zero community. Let's make amazing things together! 🌟
-Now, go out there and start your adventure! If you have any more magical ideas to enhance this guide, please share them. 🔥


@@ -1,16 +0,0 @@
-# Security Policy
-## Supported Versions
-We publish releases monthly.
-| Version | Supported |
-| ------- | ------------------ |
-| >= 1.4.4 | :white_check_mark: |
-| < 1.4.4 | :x: |
-## Reporting a Vulnerability
-https://github.com/zeromicro/go-zero/security/advisories
-Accepted vulnerabilities are expected to be fixed within a month.


@@ -1,127 +1,76 @@
 # Contributor Covenant Code of Conduct
 ## Our Pledge
-We as members, contributors, and leaders pledge to make participation in our
-community a harassment-free experience for everyone, regardless of age, body
-size, visible or invisible disability, ethnicity, sex characteristics, gender
-identity and expression, level of experience, education, socio-economic status,
-nationality, personal appearance, race, caste, color, religion, or sexual
-identity and orientation.
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to make participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
-We pledge to act and interact in ways that contribute to an open, welcoming,
-diverse, inclusive, and healthy community.
 ## Our Standards
-Examples of behavior that contributes to a positive environment for our
-community include:
+Examples of behavior that contributes to creating a positive environment
+include:
-* Demonstrating empathy and kindness toward other people
-* Being respectful of differing opinions, viewpoints, and experiences
-* Giving and gracefully accepting constructive feedback
-* Accepting responsibility and apologizing to those affected by our mistakes,
-  and learning from the experience
-* Focusing on what is best not just for us as individuals, but for the overall
-  community
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
-Examples of unacceptable behavior include:
+Examples of unacceptable behavior by participants include:
-* The use of sexualized language or imagery, and sexual attention or advances of
-  any kind
-* Trolling, insulting or derogatory comments, and personal or political attacks
+* The use of sexualized language or imagery and unwelcome sexual attention or
+  advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
 * Public or private harassment
-* Publishing others' private information, such as a physical or email address,
-  without their explicit permission
+* Publishing others' private information, such as a physical or electronic
+  address, without explicit permission
 * Other conduct which could reasonably be considered inappropriate in a
   professional setting
-## Enforcement Responsibilities
+## Our Responsibilities
-Community leaders are responsible for clarifying and enforcing our standards of
-acceptable behavior and will take appropriate and fair corrective action in
-response to any behavior that they deem inappropriate, threatening, offensive,
-or harmful.
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
-Community leaders have the right and responsibility to remove, edit, or reject
-comments, commits, code, wiki edits, issues, and other contributions that are
-not aligned to this Code of Conduct, and will communicate reasons for moderation
-decisions when appropriate.
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
 ## Scope
-This Code of Conduct applies within all community spaces, and also applies when
-an individual is officially representing the community in public spaces.
-Examples of representing our community include using an official e-mail address,
-posting via an official social media account, or acting as an appointed
-representative at an online or offline event.
+This Code of Conduct applies within all project spaces, and it also applies when
+an individual is representing the project or its community in public spaces.
+Examples of representing a project or community include using an official
+project e-mail address, posting via an official social media account, or acting
+as an appointed representative at an online or offline event. Representation of
+a project may be further defined and clarified by project maintainers.
 ## Enforcement
 Instances of abusive, harassing, or otherwise unacceptable behavior may be
-reported to the community leaders responsible for enforcement at
-[INSERT CONTACT METHOD].
-All complaints will be reviewed and investigated promptly and fairly.
+reported by contacting the project team at [INSERT EMAIL ADDRESS]. All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
-All community leaders are obligated to respect the privacy and security of the
-reporter of any incident.
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
-## Enforcement Guidelines
-Community leaders will follow these Community Impact Guidelines in determining
-the consequences for any action they deem in violation of this Code of Conduct:
-### 1. Correction
-**Community Impact**: Use of inappropriate language or other behavior deemed
-unprofessional or unwelcome in the community.
-**Consequence**: A private, written warning from community leaders, providing
-clarity around the nature of the violation and an explanation of why the
-behavior was inappropriate. A public apology may be requested.
-### 2. Warning
-**Community Impact**: A violation through a single incident or series of
-actions.
-**Consequence**: A warning with consequences for continued behavior. No
-interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or permanent
ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within the
community.
## Attribution ## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
version 2.1, available at available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
Community Impact Guidelines were inspired by [homepage]: https://www.contributor-covenant.org
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].
For answers to common questions about this code of conduct, see the FAQ at For answers to common questions about this code of conduct, see
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at https://www.contributor-covenant.org/faq
[https://www.contributor-covenant.org/translations][translations].


@@ -1,7 +1,6 @@
 package bloom

 import (
-	"context"
 	"errors"
 	"strconv"
@@ -9,29 +8,28 @@ import (
 	"github.com/zeromicro/go-zero/core/stores/redis"
 )

-// for detailed error rate table, see http://pages.cs.wisc.edu/~cao/papers/summary-cache/node8.html
-// maps as k in the error rate table
-const maps = 14
-
-var (
-	// ErrTooLargeOffset indicates the offset is too large in bitset.
-	ErrTooLargeOffset = errors.New("too large offset")
-
-	setScript = redis.NewScript(`
+const (
+	// for detailed error rate table, see http://pages.cs.wisc.edu/~cao/papers/summary-cache/node8.html
+	// maps as k in the error rate table
+	maps = 14
+
+	setScript = `
 for _, offset in ipairs(ARGV) do
 	redis.call("setbit", KEYS[1], offset, 1)
 end
-`)
-	testScript = redis.NewScript(`
+`
+	testScript = `
 for _, offset in ipairs(ARGV) do
 	if tonumber(redis.call("getbit", KEYS[1], offset)) == 0 then
 		return false
 	end
 end
 return true
-`)
+`
 )
+
+// ErrTooLargeOffset indicates the offset is too large in bitset.
+var ErrTooLargeOffset = errors.New("too large offset")
@@ -40,8 +38,8 @@ type (
 	}

 	bitSetProvider interface {
-		check(ctx context.Context, offsets []uint) (bool, error)
-		set(ctx context.Context, offsets []uint) error
+		check([]uint) (bool, error)
+		set([]uint) error
 	}
 )
@@ -60,24 +58,14 @@ func New(store *redis.Redis, key string, bits uint) *Filter {

 // Add adds data into f.
 func (f *Filter) Add(data []byte) error {
-	return f.AddCtx(context.Background(), data)
-}
-
-// AddCtx adds data into f with context.
-func (f *Filter) AddCtx(ctx context.Context, data []byte) error {
 	locations := f.getLocations(data)
-	return f.bitSet.set(ctx, locations)
+	return f.bitSet.set(locations)
 }

 // Exists checks if data is in f.
 func (f *Filter) Exists(data []byte) (bool, error) {
-	return f.ExistsCtx(context.Background(), data)
-}
-
-// ExistsCtx checks if data is in f with context.
-func (f *Filter) ExistsCtx(ctx context.Context, data []byte) (bool, error) {
 	locations := f.getLocations(data)
-	isSet, err := f.bitSet.check(ctx, locations)
+	isSet, err := f.bitSet.check(locations)
 	if err != nil {
 		return false, err
 	}
@@ -123,14 +111,14 @@ func (r *redisBitSet) buildOffsetArgs(offsets []uint) ([]string, error) {
 	return args, nil
 }

-func (r *redisBitSet) check(ctx context.Context, offsets []uint) (bool, error) {
+func (r *redisBitSet) check(offsets []uint) (bool, error) {
 	args, err := r.buildOffsetArgs(offsets)
 	if err != nil {
 		return false, err
 	}

-	resp, err := r.store.ScriptRunCtx(ctx, testScript, []string{r.key}, args)
-	if errors.Is(err, redis.Nil) {
+	resp, err := r.store.Eval(testScript, []string{r.key}, args)
+	if err == redis.Nil {
 		return false, nil
 	} else if err != nil {
 		return false, err
@@ -144,25 +132,23 @@ func (r *redisBitSet) check(ctx context.Context, offsets []uint) (bool, error) {
 	return exists == 1, nil
 }

-// del only use for testing.
 func (r *redisBitSet) del() error {
 	_, err := r.store.Del(r.key)
 	return err
 }

-// expire only use for testing.
 func (r *redisBitSet) expire(seconds int) error {
 	return r.store.Expire(r.key, seconds)
 }

-func (r *redisBitSet) set(ctx context.Context, offsets []uint) error {
+func (r *redisBitSet) set(offsets []uint) error {
 	args, err := r.buildOffsetArgs(offsets)
 	if err != nil {
 		return err
 	}

-	_, err = r.store.ScriptRunCtx(ctx, setScript, []string{r.key}, args)
-	if errors.Is(err, redis.Nil) {
+	_, err = r.store.Eval(setScript, []string{r.key}, args)
+	if err == redis.Nil {
 		return nil
 	}


@@ -1,31 +1,30 @@
 package bloom

 import (
-	"context"
 	"testing"

 	"github.com/stretchr/testify/assert"
-	"github.com/zeromicro/go-zero/core/logx"
 	"github.com/zeromicro/go-zero/core/stores/redis/redistest"
 )

 func TestRedisBitSet_New_Set_Test(t *testing.T) {
-	store := redistest.CreateRedis(t)
-	ctx := context.Background()
+	store, clean, err := redistest.CreateRedis()
+	assert.Nil(t, err)
+	defer clean()

 	bitSet := newRedisBitSet(store, "test_key", 1024)
-	isSetBefore, err := bitSet.check(ctx, []uint{0})
+	isSetBefore, err := bitSet.check([]uint{0})
 	if err != nil {
 		t.Fatal(err)
 	}
 	if isSetBefore {
 		t.Fatal("Bit should not be set")
 	}
-	err = bitSet.set(ctx, []uint{512})
+	err = bitSet.set([]uint{512})
 	if err != nil {
 		t.Fatal(err)
 	}
-	isSetAfter, err := bitSet.check(ctx, []uint{512})
+	isSetAfter, err := bitSet.check([]uint{512})
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -43,7 +42,9 @@ func TestRedisBitSet_New_Set_Test(t *testing.T) {
 }

 func TestRedisBitSet_Add(t *testing.T) {
-	store := redistest.CreateRedis(t)
+	store, clean, err := redistest.CreateRedis()
+	assert.Nil(t, err)
+	defer clean()

 	filter := New(store, "test_key", 64)
 	assert.Nil(t, filter.Add([]byte("hello")))
@@ -52,51 +53,3 @@ func TestRedisBitSet_Add(t *testing.T) {
 	assert.Nil(t, err)
 	assert.True(t, ok)
 }
-
-func TestFilter_Exists(t *testing.T) {
-	store, clean := redistest.CreateRedisWithClean(t)
-	rbs := New(store, "test", 64)
-	_, err := rbs.Exists([]byte{0, 1, 2})
-	assert.NoError(t, err)
-
-	clean()
-	rbs = New(store, "test", 64)
-	_, err = rbs.Exists([]byte{0, 1, 2})
-	assert.Error(t, err)
-}
-
-func TestRedisBitSet_check(t *testing.T) {
-	store, clean := redistest.CreateRedisWithClean(t)
-	ctx := context.Background()
-
-	rbs := newRedisBitSet(store, "test", 0)
-	assert.Error(t, rbs.set(ctx, []uint{0, 1, 2}))
-	_, err := rbs.check(ctx, []uint{0, 1, 2})
-	assert.Error(t, err)
-
-	rbs = newRedisBitSet(store, "test", 64)
-	_, err = rbs.check(ctx, []uint{0, 1, 2})
-	assert.NoError(t, err)
-
-	clean()
-	rbs = newRedisBitSet(store, "test", 64)
-	_, err = rbs.check(ctx, []uint{0, 1, 2})
-	assert.Error(t, err)
-}
-
-func TestRedisBitSet_set(t *testing.T) {
-	logx.Disable()
-	store, clean := redistest.CreateRedisWithClean(t)
-	ctx := context.Background()
-
-	rbs := newRedisBitSet(store, "test", 0)
-	assert.Error(t, rbs.set(ctx, []uint{0, 1, 2}))
-
-	rbs = newRedisBitSet(store, "test", 64)
-	assert.NoError(t, rbs.set(ctx, []uint{0, 1, 2}))
-
-	clean()
-	rbs = newRedisBitSet(store, "test", 64)
-	assert.Error(t, rbs.set(ctx, []uint{0, 1, 2}))
-}


@@ -31,10 +31,9 @@ type (
 		Name() string

 		// Allow checks if the request is allowed.
-		// If allowed, a promise will be returned,
-		// otherwise ErrServiceUnavailable will be returned as the error.
-		// The caller needs to call promise.Accept() on success,
-		// or call promise.Reject() on failure.
+		// If allowed, a promise will be returned, the caller needs to call promise.Accept()
+		// on success, or call promise.Reject() on failure.
+		// If not allow, ErrServiceUnavailable will be returned.
 		Allow() (Promise, error)

 		// Do runs the given request if the Breaker accepts it.
@@ -47,26 +46,23 @@ type (
 		// DoWithAcceptable returns an error instantly if the Breaker rejects the request.
 		// If a panic occurs in the request, the Breaker handles it as an error
 		// and causes the same panic again.
-		// acceptable checks if it's a successful call, even if the error is not nil.
+		// acceptable checks if it's a successful call, even if the err is not nil.
 		DoWithAcceptable(req func() error, acceptable Acceptable) error

 		// DoWithFallback runs the given request if the Breaker accepts it.
 		// DoWithFallback runs the fallback if the Breaker rejects the request.
 		// If a panic occurs in the request, the Breaker handles it as an error
 		// and causes the same panic again.
-		DoWithFallback(req func() error, fallback Fallback) error
+		DoWithFallback(req func() error, fallback func(err error) error) error

 		// DoWithFallbackAcceptable runs the given request if the Breaker accepts it.
 		// DoWithFallbackAcceptable runs the fallback if the Breaker rejects the request.
 		// If a panic occurs in the request, the Breaker handles it as an error
 		// and causes the same panic again.
-		// acceptable checks if it's a successful call, even if the error is not nil.
-		DoWithFallbackAcceptable(req func() error, fallback Fallback, acceptable Acceptable) error
+		// acceptable checks if it's a successful call, even if the err is not nil.
+		DoWithFallbackAcceptable(req func() error, fallback func(err error) error, acceptable Acceptable) error
 	}

-	// Fallback is the func to be called if the request is rejected.
-	Fallback func(err error) error
-
 	// Option defines the method to customize a Breaker.
 	Option func(breaker *circuitBreaker)
@@ -90,12 +86,12 @@ type (
 	internalThrottle interface {
 		allow() (internalPromise, error)
-		doReq(req func() error, fallback Fallback, acceptable Acceptable) error
+		doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error
 	}

 	throttle interface {
 		allow() (Promise, error)
-		doReq(req func() error, fallback Fallback, acceptable Acceptable) error
+		doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error
 	}
 )
@@ -126,11 +122,11 @@ func (cb *circuitBreaker) DoWithAcceptable(req func() error, acceptable Acceptab
 	return cb.throttle.doReq(req, nil, acceptable)
 }

-func (cb *circuitBreaker) DoWithFallback(req func() error, fallback Fallback) error {
+func (cb *circuitBreaker) DoWithFallback(req func() error, fallback func(err error) error) error {
 	return cb.throttle.doReq(req, fallback, defaultAcceptable)
 }

-func (cb *circuitBreaker) DoWithFallbackAcceptable(req func() error, fallback Fallback,
+func (cb *circuitBreaker) DoWithFallbackAcceptable(req func() error, fallback func(err error) error,
 	acceptable Acceptable) error {
 	return cb.throttle.doReq(req, fallback, acceptable)
 }
@@ -172,7 +168,7 @@ func (lt loggedThrottle) allow() (Promise, error) {
 	}, lt.logError(err)
 }

-func (lt loggedThrottle) doReq(req func() error, fallback Fallback, acceptable Acceptable) error {
+func (lt loggedThrottle) doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error {
 	return lt.logError(lt.internalThrottle.doReq(req, fallback, func(err error) bool {
 		accept := acceptable(err)
 		if !accept && err != nil {
@@ -183,7 +179,7 @@ func (lt loggedThrottle) doReq(req func() error, fallback Fallback, acceptable A
 }

 func (lt loggedThrottle) logError(err error) error {
-	if errors.Is(err, ErrServiceUnavailable) {
+	if err == ErrServiceUnavailable {
 		// if circuit open, not possible to have empty error window
 		stat.Report(fmt.Sprintf(
 			"proc(%s/%d), callee: %s, breaker is open and requests dropped\nlast errors:\n%s",


@@ -22,14 +22,14 @@ func DoWithAcceptable(name string, req func() error, acceptable Acceptable) erro
 }

 // DoWithFallback calls Breaker.DoWithFallback on the Breaker with given name.
-func DoWithFallback(name string, req func() error, fallback Fallback) error {
+func DoWithFallback(name string, req func() error, fallback func(err error) error) error {
 	return do(name, func(b Breaker) error {
 		return b.DoWithFallback(req, fallback)
 	})
 }

 // DoWithFallbackAcceptable calls Breaker.DoWithFallbackAcceptable on the Breaker with given name.
-func DoWithFallbackAcceptable(name string, req func() error, fallback Fallback,
+func DoWithFallbackAcceptable(name string, req func() error, fallback func(err error) error,
 	acceptable Acceptable) error {
 	return do(name, func(b Breaker) error {
 		return b.DoWithFallbackAcceptable(req, fallback, acceptable)
@@ -59,7 +59,7 @@ func GetBreaker(name string) Breaker {
 // NoBreakerFor disables the circuit breaker for the given name.
 func NoBreakerFor(name string) {
 	lock.Lock()
-	breakers[name] = NopBreaker()
+	breakers[name] = newNoOpBreaker()
 	lock.Unlock()
 }
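`GetBreaker` and `NoBreakerFor` above operate on a package-level, name-keyed registry guarded by a lock. A reduced sketch of that lazy get-or-create pattern — the `namedBreaker` type and the one-method `Breaker` interface here are placeholders for the sketch; the real interface carries the `Do*`/`Allow` methods shown earlier:

```go
package main

import (
	"fmt"
	"sync"
)

// Breaker is trimmed to a name for this sketch.
type Breaker interface{ Name() string }

type namedBreaker struct{ name string }

func (b namedBreaker) Name() string { return b.name }

var (
	lock     sync.Mutex
	breakers = make(map[string]Breaker)
)

// GetBreaker returns the breaker registered under name, creating it on first use.
func GetBreaker(name string) Breaker {
	lock.Lock()
	defer lock.Unlock()
	if b, ok := breakers[name]; ok {
		return b
	}
	b := namedBreaker{name: name}
	breakers[name] = b
	return b
}

func main() {
	a := GetBreaker("payments")
	b := GetBreaker("payments")
	fmt.Println(a == b) // the same name always yields the same breaker
}
```

`NoBreakerFor` is the same idea in reverse: it overwrites the map entry with a no-op breaker under the same lock.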


@@ -30,7 +30,7 @@ func TestBreakersDoWithAcceptable(t *testing.T) {
 		assert.Equal(t, errDummy, GetBreaker("anyone").DoWithAcceptable(func() error {
 			return errDummy
 		}, func(err error) bool {
-			return err == nil || errors.Is(err, errDummy)
+			return err == nil || err == errDummy
 		}))
 	}
 	verify(t, func() bool {
@@ -45,12 +45,12 @@ func TestBreakersDoWithAcceptable(t *testing.T) {
 		}, func(err error) bool {
 			return err == nil
 		})
-		assert.True(t, errors.Is(err, errDummy) || errors.Is(err, ErrServiceUnavailable))
+		assert.True(t, err == errDummy || err == ErrServiceUnavailable)
 	}
 	verify(t, func() bool {
-		return errors.Is(Do("another", func() error {
+		return ErrServiceUnavailable == Do("another", func() error {
 			return nil
-		}), ErrServiceUnavailable)
+		})
 	})
 }
@@ -75,12 +75,12 @@ func TestBreakersFallback(t *testing.T) {
 		}, func(err error) error {
 			return nil
 		})
-		assert.True(t, err == nil || errors.Is(err, errDummy))
+		assert.True(t, err == nil || err == errDummy)
 	}
 	verify(t, func() bool {
-		return errors.Is(Do("fallback", func() error {
+		return ErrServiceUnavailable == Do("fallback", func() error {
 			return nil
-		}), ErrServiceUnavailable)
+		})
 	})
 }
@@ -94,12 +94,12 @@ func TestBreakersAcceptableFallback(t *testing.T) {
 		}, func(err error) bool {
 			return err == nil
 		})
-		assert.True(t, err == nil || errors.Is(err, errDummy))
+		assert.True(t, err == nil || err == errDummy)
 	}
 	verify(t, func() bool {
-		return errors.Is(Do("acceptablefallback", func() error {
+		return ErrServiceUnavailable == Do("acceptablefallback", func() error {
 			return nil
-		}), ErrServiceUnavailable)
+		})
 	})
 }


@@ -1,6 +1,7 @@
 package breaker

 import (
+	"math"
 	"time"

 	"github.com/zeromicro/go-zero/core/collection"
@@ -37,8 +38,7 @@ func (b *googleBreaker) accept() error {
 	accepts, total := b.history()
 	weightedAccepts := b.k * float64(accepts)
 	// https://landing.google.com/sre/sre-book/chapters/handling-overload/#eq2101
-	// for better performance, no need to care about negative ratio
-	dropRatio := (float64(total-protection) - weightedAccepts) / float64(total+1)
+	dropRatio := math.Max(0, (float64(total-protection)-weightedAccepts)/float64(total+1))
 	if dropRatio <= 0 {
 		return nil
 	}
@@ -52,7 +52,6 @@ func (b *googleBreaker) accept() error {

 func (b *googleBreaker) allow() (internalPromise, error) {
 	if err := b.accept(); err != nil {
-		b.markFailure()
 		return nil, err
 	}
@@ -61,9 +60,8 @@ func (b *googleBreaker) allow() (internalPromise, error) {
 	}, nil
 }

-func (b *googleBreaker) doReq(req func() error, fallback Fallback, acceptable Acceptable) error {
+func (b *googleBreaker) doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error {
 	if err := b.accept(); err != nil {
-		b.markFailure()
 		if fallback != nil {
 			return fallback(err)
 		}
@@ -71,19 +69,18 @@ func (b *googleBreaker) doReq(req func() error, fallback Fallback, acceptable Ac
 		return err
 	}

-	var success bool
 	defer func() {
-		// if req() panic, success is false, mark as failure
-		if success {
-			b.markSuccess()
-		} else {
+		if e := recover(); e != nil {
 			b.markFailure()
+			panic(e)
 		}
 	}()

 	err := req()
 	if acceptable(err) {
-		success = true
+		b.markSuccess()
+	} else {
+		b.markFailure()
 	}

 	return err


@@ -95,7 +95,7 @@ func TestGoogleBreakerAcceptable(t *testing.T) {
 	assert.Equal(t, errAcceptable, b.doReq(func() error {
 		return errAcceptable
 	}, nil, func(err error) bool {
-		return errors.Is(err, errAcceptable)
+		return err == errAcceptable
 	}))
 }
@@ -105,7 +105,7 @@ func TestGoogleBreakerNotAcceptable(t *testing.T) {
 	assert.Equal(t, errAcceptable, b.doReq(func() error {
 		return errAcceptable
 	}, nil, func(err error) bool {
-		return !errors.Is(err, errAcceptable)
+		return err != errAcceptable
 	}))
 }
@@ -206,7 +206,7 @@ func BenchmarkGoogleBreakerAllow(b *testing.B) {
 	breaker := getGoogleBreaker()
 	b.ResetTimer()
 	for i := 0; i <= b.N; i++ {
-		_ = breaker.accept()
+		breaker.accept()
 		if i%2 == 0 {
 			breaker.markSuccess()
 		} else {
@@ -215,16 +215,6 @@ func BenchmarkGoogleBreakerAllow(b *testing.B) {
 	}
 }

-func BenchmarkGoogleBreakerDoReq(b *testing.B) {
-	breaker := getGoogleBreaker()
-	b.ResetTimer()
-	for i := 0; i <= b.N; i++ {
-		_ = breaker.doReq(func() error {
-			return nil
-		}, nil, defaultAcceptable)
-	}
-}
-
 func markSuccess(b *googleBreaker, count int) {
 	for i := 0; i < count; i++ {
 		p, err := b.allow()


@@ -1,35 +1,35 @@
 package breaker

-const nopBreakerName = "nopBreaker"
+const noOpBreakerName = "nopBreaker"

-type nopBreaker struct{}
+type noOpBreaker struct{}

-// NopBreaker returns a breaker that never trigger breaker circuit.
-func NopBreaker() Breaker {
-	return nopBreaker{}
+func newNoOpBreaker() Breaker {
+	return noOpBreaker{}
 }

-func (b nopBreaker) Name() string {
-	return nopBreakerName
+func (b noOpBreaker) Name() string {
+	return noOpBreakerName
 }

-func (b nopBreaker) Allow() (Promise, error) {
+func (b noOpBreaker) Allow() (Promise, error) {
 	return nopPromise{}, nil
 }

-func (b nopBreaker) Do(req func() error) error {
+func (b noOpBreaker) Do(req func() error) error {
 	return req()
 }

-func (b nopBreaker) DoWithAcceptable(req func() error, _ Acceptable) error {
+func (b noOpBreaker) DoWithAcceptable(req func() error, _ Acceptable) error {
 	return req()
 }

-func (b nopBreaker) DoWithFallback(req func() error, _ Fallback) error {
+func (b noOpBreaker) DoWithFallback(req func() error, _ func(err error) error) error {
 	return req()
 }

-func (b nopBreaker) DoWithFallbackAcceptable(req func() error, _ Fallback, _ Acceptable) error {
+func (b noOpBreaker) DoWithFallbackAcceptable(req func() error, _ func(err error) error,
+	_ Acceptable) error {
 	return req()
 }


@@ -8,8 +8,8 @@ import (
 )

 func TestNopBreaker(t *testing.T) {
-	b := NopBreaker()
-	assert.Equal(t, nopBreakerName, b.Name())
+	b := newNoOpBreaker()
+	assert.Equal(t, noOpBreakerName, b.Name())
 	p, err := b.Allow()
 	assert.Nil(t, err)
 	p.Accept()


@@ -23,7 +23,7 @@ var (
 	zero = big.NewInt(0)
 )

-// DhKey defines the Diffie-Hellman key.
+// DhKey defines the Diffie Hellman key.
 type DhKey struct {
 	PriKey *big.Int
 	PubKey *big.Int
@@ -46,7 +46,7 @@ func ComputeKey(pubKey, priKey *big.Int) (*big.Int, error) {
 	return new(big.Int).Exp(pubKey, priKey, p), nil
 }

-// GenerateKey returns a Diffie-Hellman key.
+// GenerateKey returns a Diffie Hellman key.
 func GenerateKey() (*DhKey, error) {
 	var err error
 	var x *big.Int


@@ -2,8 +2,6 @@ package codec

 import (
 	"bytes"
-	"compress/gzip"
-	"errors"
 	"fmt"
 	"testing"
@@ -23,45 +21,3 @@ func TestGzip(t *testing.T) {
 	assert.True(t, len(bs) < buf.Len())
 	assert.Equal(t, buf.Bytes(), actual)
 }
-
-func TestGunzip(t *testing.T) {
-	tests := []struct {
-		name        string
-		input       []byte
-		expected    []byte
-		expectedErr error
-	}{
-		{
-			name: "valid input",
-			input: func() []byte {
-				var buf bytes.Buffer
-				gz := gzip.NewWriter(&buf)
-				gz.Write([]byte("hello"))
-				gz.Close()
-				return buf.Bytes()
-			}(),
-			expected:    []byte("hello"),
-			expectedErr: nil,
-		},
-		{
-			name:        "invalid input",
-			input:       []byte("invalid input"),
-			expected:    nil,
-			expectedErr: gzip.ErrHeader,
-		},
-	}
-
-	for _, test := range tests {
-		t.Run(test.name, func(t *testing.T) {
-			result, err := Gunzip(test.input)
-
-			if !bytes.Equal(result, test.expected) {
-				t.Errorf("unexpected result: %v", result)
-			}
-			if !errors.Is(err, test.expectedErr) {
-				t.Errorf("unexpected error: %v", err)
-			}
-		})
-	}
-}


@@ -2,7 +2,6 @@ package codec

 import (
 	"encoding/base64"
-	"os"
 	"testing"

 	"github.com/stretchr/testify/assert"
@@ -42,7 +41,6 @@ func TestCryption(t *testing.T) {
 	file, err := fs.TempFilenameWithText(priKey)
 	assert.Nil(t, err)
-	defer os.Remove(file)
 	dec, err := NewRsaDecrypter(file)
 	assert.Nil(t, err)
 	actual, err := dec.Decrypt(ret)


@@ -30,7 +30,7 @@ type (
Cache struct { Cache struct {
name string name string
lock sync.Mutex lock sync.Mutex
data map[string]any data map[string]interface{}
expire time.Duration expire time.Duration
timingWheel *TimingWheel timingWheel *TimingWheel
lruCache lru lruCache lru
@@ -43,7 +43,7 @@ type (
// NewCache returns a Cache with given expire. // NewCache returns a Cache with given expire.
func NewCache(expire time.Duration, opts ...CacheOption) (*Cache, error) { func NewCache(expire time.Duration, opts ...CacheOption) (*Cache, error) {
cache := &Cache{ cache := &Cache{
data: make(map[string]any), data: make(map[string]interface{}),
expire: expire, expire: expire,
lruCache: emptyLruCache, lruCache: emptyLruCache,
barrier: syncx.NewSingleFlight(), barrier: syncx.NewSingleFlight(),
@@ -59,7 +59,7 @@ func NewCache(expire time.Duration, opts ...CacheOption) (*Cache, error) {
} }
cache.stats = newCacheStat(cache.name, cache.size) cache.stats = newCacheStat(cache.name, cache.size)
timingWheel, err := NewTimingWheel(time.Second, slots, func(k, v any) { timingWheel, err := NewTimingWheel(time.Second, slots, func(k, v interface{}) {
key, ok := k.(string) key, ok := k.(string)
if !ok { if !ok {
return return
@@ -85,7 +85,7 @@ func (c *Cache) Del(key string) {
} }
// Get returns the item with the given key from c. // Get returns the item with the given key from c.
func (c *Cache) Get(key string) (any, bool) { func (c *Cache) Get(key string) (interface{}, bool) {
value, ok := c.doGet(key) value, ok := c.doGet(key)
 	if ok {
 		c.stats.IncrementHit()
@@ -97,12 +97,12 @@ func (c *Cache) Get(key string) (any, bool) {
 }

 // Set sets value into c with key.
-func (c *Cache) Set(key string, value any) {
+func (c *Cache) Set(key string, value interface{}) {
 	c.SetWithExpire(key, value, c.expire)
 }

 // SetWithExpire sets value into c with key and expire with the given value.
-func (c *Cache) SetWithExpire(key string, value any, expire time.Duration) {
+func (c *Cache) SetWithExpire(key string, value interface{}, expire time.Duration) {
 	c.lock.Lock()
 	_, ok := c.data[key]
 	c.data[key] = value
@@ -120,16 +120,16 @@ func (c *Cache) SetWithExpire(key string, value any, expire time.Duration) {
 // Take returns the item with the given key.
 // If the item is in c, return it directly.
 // If not, use fetch method to get the item, set into c and return it.
-func (c *Cache) Take(key string, fetch func() (any, error)) (any, error) {
+func (c *Cache) Take(key string, fetch func() (interface{}, error)) (interface{}, error) {
 	if val, ok := c.doGet(key); ok {
 		c.stats.IncrementHit()
 		return val, nil
 	}

 	var fresh bool
-	val, err := c.barrier.Do(key, func() (any, error) {
-		// because O(1) on map search in memory, and fetch is an IO query,
-		// so we do double-check, cache might be taken by another call
+	val, err := c.barrier.Do(key, func() (interface{}, error) {
+		// because O(1) on map search in memory, and fetch is an IO query
+		// so we do double check, cache might be taken by another call
 		if val, ok := c.doGet(key); ok {
 			return val, nil
 		}
@@ -157,7 +157,7 @@ func (c *Cache) Take(key string, fetch func() (any, error)) (any, error) {
 	return val, nil
 }

-func (c *Cache) doGet(key string) (any, bool) {
+func (c *Cache) doGet(key string) (interface{}, bool) {
 	c.lock.Lock()
 	defer c.lock.Unlock()
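Aside: the double check inside `Take` above is the load-once pattern — look the key up, and look it up again inside the barrier before calling `fetch`, because another caller may have filled the cache while this one waited. A minimal single-lock sketch of the idea (not go-zero's actual code, which uses a per-key `SingleFlight` barrier so unrelated keys never block each other; all names below are made up):

```go
package main

import (
	"fmt"
	"sync"
)

// loadOnce sketches the double-check pattern from Cache.Take:
// fast-path read, then a second check under the loader's lock,
// so concurrent callers for the same key run fetch only once.
type loadOnce struct {
	mu    sync.Mutex
	data  map[string]int
	calls int // counts how many times fetch actually ran
}

func (c *loadOnce) take(key string, fetch func() (int, error)) (int, error) {
	c.mu.Lock() // simplified: one global lock instead of a per-key barrier
	defer c.mu.Unlock()

	// double check: another goroutine may have filled the cache
	// while we were waiting for the lock
	if v, ok := c.data[key]; ok {
		return v, nil
	}

	v, err := fetch()
	if err != nil {
		return 0, err
	}
	c.calls++
	c.data[key] = v
	return v, nil
}

func main() {
	c := &loadOnce{data: make(map[string]int)}
	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			c.take("first", func() (int, error) { return 42, nil })
		}()
	}
	wg.Wait()
	fmt.Println(c.calls) // 1: fetch ran once despite 100 concurrent callers
}
```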

@@ -52,7 +52,7 @@ func TestCacheTake(t *testing.T) {
 	for i := 0; i < 100; i++ {
 		wg.Add(1)
 		go func() {
-			cache.Take("first", func() (any, error) {
+			cache.Take("first", func() (interface{}, error) {
 				atomic.AddInt32(&count, 1)
 				time.Sleep(time.Millisecond * 100)
 				return "first element", nil
@@ -76,7 +76,7 @@ func TestCacheTakeExists(t *testing.T) {
 		wg.Add(1)
 		go func() {
 			cache.Set("first", "first element")
-			cache.Take("first", func() (any, error) {
+			cache.Take("first", func() (interface{}, error) {
 				atomic.AddInt32(&count, 1)
 				time.Sleep(time.Millisecond * 100)
 				return "first element", nil
@@ -99,7 +99,7 @@ func TestCacheTakeError(t *testing.T) {
 	for i := 0; i < 100; i++ {
 		wg.Add(1)
 		go func() {
-			_, err := cache.Take("first", func() (any, error) {
+			_, err := cache.Take("first", func() (interface{}, error) {
 				atomic.AddInt32(&count, 1)
 				time.Sleep(time.Millisecond * 100)
 				return "", errDummy

@@ -5,7 +5,7 @@ import "sync"
 // A Queue is a FIFO queue.
 type Queue struct {
 	lock     sync.Mutex
-	elements []any
+	elements []interface{}
 	size     int
 	head     int
 	tail     int
@@ -15,7 +15,7 @@ type Queue struct {
 // NewQueue returns a Queue object.
 func NewQueue(size int) *Queue {
 	return &Queue{
-		elements: make([]any, size),
+		elements: make([]interface{}, size),
 		size:     size,
 	}
 }
@@ -30,12 +30,12 @@ func (q *Queue) Empty() bool {
 }

 // Put puts element into q at the last position.
-func (q *Queue) Put(element any) {
+func (q *Queue) Put(element interface{}) {
 	q.lock.Lock()
 	defer q.lock.Unlock()

 	if q.head == q.tail && q.count > 0 {
-		nodes := make([]any, len(q.elements)+q.size)
+		nodes := make([]interface{}, len(q.elements)+q.size)
 		copy(nodes, q.elements[q.head:])
 		copy(nodes[len(q.elements)-q.head:], q.elements[:q.head])
 		q.head = 0
@@ -49,7 +49,7 @@ func (q *Queue) Put(element any) {
 }

 // Take takes the first element out of q if not empty.
-func (q *Queue) Take() (any, bool) {
+func (q *Queue) Take() (interface{}, bool) {
 	q.lock.Lock()
 	defer q.lock.Unlock()
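Aside: the grow branch in `Put` above unwraps the ring before enlarging it — the segment from `head` to the end of the backing slice is copied first, then the wrapped-around prefix, so FIFO order survives the resize. A standalone sketch of just that copy step (the `grow` helper is hypothetical, specialized to `int` for clarity):

```go
package main

import "fmt"

// grow mirrors the unwrap-and-grow step from Queue.Put: copy the tail
// segment (head..end) to the front of a bigger slice, then the wrapped
// prefix (0..head) right after it, so logical FIFO order is preserved
// and the new head is index 0.
func grow(elements []int, head, size int) []int {
	nodes := make([]int, len(elements)+size)
	copy(nodes, elements[head:])                      // tail segment first
	copy(nodes[len(elements)-head:], elements[:head]) // then the wrapped prefix
	return nodes
}

func main() {
	// ring of 4 with head at index 2: logical order is 30, 40, 10, 20
	elements := []int{10, 20, 30, 40}
	fmt.Println(grow(elements, 2, 4)) // [30 40 10 20 0 0 0 0]
}
```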

@@ -4,7 +4,7 @@ import "sync"
 // A Ring can be used as fixed size ring.
 type Ring struct {
-	elements []any
+	elements []interface{}
 	index    int
 	lock     sync.RWMutex
 }
@@ -16,44 +16,36 @@ func NewRing(n int) *Ring {
 	}

 	return &Ring{
-		elements: make([]any, n),
+		elements: make([]interface{}, n),
 	}
 }

 // Add adds v into r.
-func (r *Ring) Add(v any) {
+func (r *Ring) Add(v interface{}) {
 	r.lock.Lock()
 	defer r.lock.Unlock()

-	rlen := len(r.elements)
-	r.elements[r.index%rlen] = v
+	r.elements[r.index%len(r.elements)] = v
 	r.index++
-
-	// prevent ring index overflow
-	if r.index >= rlen<<1 {
-		r.index -= rlen
-	}
 }

 // Take takes all items from r.
-func (r *Ring) Take() []any {
+func (r *Ring) Take() []interface{} {
 	r.lock.RLock()
 	defer r.lock.RUnlock()

 	var size int
 	var start int
-	rlen := len(r.elements)
-	if r.index > rlen {
-		size = rlen
-		start = r.index % rlen
+	if r.index > len(r.elements) {
+		size = len(r.elements)
+		start = r.index % len(r.elements)
 	} else {
 		size = r.index
 	}

-	elements := make([]any, size)
+	elements := make([]interface{}, size)
 	for i := 0; i < size; i++ {
-		elements[i] = r.elements[(start+i)%rlen]
+		elements[i] = r.elements[(start+i)%len(r.elements)]
 	}

 	return elements
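Aside: the `-` lines in `Add` above carry an index-overflow guard that this compare removes. Once the writer has lapped the ring, the index is folded back into `[n, 2n)`, so it can keep growing forever without overflowing `int`, while `index > n` still tells `Take` the ring is full. A sketch of just the fold (the `normalize` helper is made up):

```go
package main

import "fmt"

// normalize folds a ring write index back into [n, 2n) once it has
// wrapped past one full lap, matching "if r.index >= rlen<<1 { r.index -= rlen }"
// in the diff above: the index stays bounded, yet index > n still
// signals that the ring holds n elements.
func normalize(index, n int) int {
	if index >= n<<1 {
		index -= n
	}
	return index
}

func main() {
	n := 5
	index := 0
	for i := 0; i < 17; i++ { // 17 adds into a 5-slot ring
		index++
		index = normalize(index, n)
	}
	fmt.Println(index, index > n) // 7 true: bounded, and still reads as full
}
```

Without the guard, `r.index` would overflow only after ~2^63 adds in practice, which is presumably why the simpler version on the right was considered acceptable.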

@@ -19,7 +19,7 @@ func TestRingLess(t *testing.T) {
 		ring.Add(i)
 	}
 	elements := ring.Take()
-	assert.ElementsMatch(t, []any{0, 1, 2}, elements)
+	assert.ElementsMatch(t, []interface{}{0, 1, 2}, elements)
 }

 func TestRingMore(t *testing.T) {
@@ -28,7 +28,7 @@ func TestRingMore(t *testing.T) {
 		ring.Add(i)
 	}
 	elements := ring.Take()
-	assert.ElementsMatch(t, []any{6, 7, 8, 9, 10}, elements)
+	assert.ElementsMatch(t, []interface{}{6, 7, 8, 9, 10}, elements)
 }

 func TestRingAdd(t *testing.T) {

@@ -14,23 +14,21 @@ type SafeMap struct {
 	lock        sync.RWMutex
 	deletionOld int
 	deletionNew int
-	dirtyOld    map[any]any
-	dirtyNew    map[any]any
+	dirtyOld    map[interface{}]interface{}
+	dirtyNew    map[interface{}]interface{}
 }

 // NewSafeMap returns a SafeMap.
 func NewSafeMap() *SafeMap {
 	return &SafeMap{
-		dirtyOld: make(map[any]any),
-		dirtyNew: make(map[any]any),
+		dirtyOld: make(map[interface{}]interface{}),
+		dirtyNew: make(map[interface{}]interface{}),
 	}
 }

 // Del deletes the value with the given key from m.
-func (m *SafeMap) Del(key any) {
+func (m *SafeMap) Del(key interface{}) {
 	m.lock.Lock()
-	defer m.lock.Unlock()
 	if _, ok := m.dirtyOld[key]; ok {
 		delete(m.dirtyOld, key)
 		m.deletionOld++
@@ -44,20 +42,21 @@ func (m *SafeMap) Del(key any) {
 		}
 		m.dirtyOld = m.dirtyNew
 		m.deletionOld = m.deletionNew
-		m.dirtyNew = make(map[any]any)
+		m.dirtyNew = make(map[interface{}]interface{})
 		m.deletionNew = 0
 	}
 	if m.deletionNew >= maxDeletion && len(m.dirtyNew) < copyThreshold {
 		for k, v := range m.dirtyNew {
 			m.dirtyOld[k] = v
 		}
-		m.dirtyNew = make(map[any]any)
+		m.dirtyNew = make(map[interface{}]interface{})
 		m.deletionNew = 0
 	}
+	m.lock.Unlock()
 }

 // Get gets the value with the given key from m.
-func (m *SafeMap) Get(key any) (any, bool) {
+func (m *SafeMap) Get(key interface{}) (interface{}, bool) {
 	m.lock.RLock()
 	defer m.lock.RUnlock()
@@ -71,7 +70,7 @@ func (m *SafeMap) Get(key any) (any, bool) {
 // Range calls f sequentially for each key and value present in the map.
 // If f returns false, range stops the iteration.
-func (m *SafeMap) Range(f func(key, val any) bool) {
+func (m *SafeMap) Range(f func(key, val interface{}) bool) {
 	m.lock.RLock()
 	defer m.lock.RUnlock()
@@ -88,10 +87,8 @@ func (m *SafeMap) Range(f func(key, val any) bool) {
 }

 // Set sets the value into m with the given key.
-func (m *SafeMap) Set(key, value any) {
+func (m *SafeMap) Set(key, value interface{}) {
 	m.lock.Lock()
-	defer m.lock.Unlock()
 	if m.deletionOld <= maxDeletion {
 		if _, ok := m.dirtyNew[key]; ok {
 			delete(m.dirtyNew, key)
@@ -105,6 +102,7 @@ func (m *SafeMap) Set(key, value any) {
 		}
 		m.dirtyNew[key] = value
 	}
+	m.lock.Unlock()
 }

 // Size returns the size of m.
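Aside: `SafeMap` exists because Go maps never release their buckets after deletes. The structure counts deletions and, once a generation is bloated past a threshold, migrates live entries into a fresh map and drops the old one. A loose two-generation sketch with made-up, tiny thresholds (the real rotation conditions, `maxDeletion` and `copyThreshold`, differ, and this sketch skips locking):

```go
package main

import "fmt"

const maxDeletion = 4 // made-up tiny threshold; go-zero's is far larger

// twoGenMap sketches SafeMap's idea: writes go to dirtyOld until it
// has absorbed too many deletions, then new writes go to dirtyNew, and
// once dirtyOld is empty it is dropped and dirtyNew is promoted — so the
// bloated bucket array becomes garbage instead of lingering forever.
type twoGenMap struct {
	dirtyOld, dirtyNew map[int]int
	deletions          int
}

func newTwoGenMap() *twoGenMap {
	return &twoGenMap{dirtyOld: make(map[int]int), dirtyNew: make(map[int]int)}
}

func (m *twoGenMap) set(k, v int) {
	if m.deletions <= maxDeletion {
		m.dirtyOld[k] = v
		return
	}
	// old generation has absorbed too many deletes: write to the fresh one
	m.dirtyNew[k] = v
}

func (m *twoGenMap) del(k int) {
	if _, ok := m.dirtyOld[k]; ok {
		delete(m.dirtyOld, k)
		m.deletions++
	}
	delete(m.dirtyNew, k)
	if m.deletions > maxDeletion && len(m.dirtyOld) == 0 {
		// rotate: drop the bloated map, promote the fresh one
		m.dirtyOld, m.dirtyNew = m.dirtyNew, make(map[int]int)
		m.deletions = 0
	}
}

func (m *twoGenMap) get(k int) (int, bool) {
	if v, ok := m.dirtyOld[k]; ok {
		return v, true
	}
	v, ok := m.dirtyNew[k]
	return v, ok
}

func main() {
	m := newTwoGenMap()
	for i := 0; i < 10; i++ {
		m.set(i, i)
	}
	for i := 0; i < 10; i++ {
		m.del(i) // enough deletions to trigger a rotation
	}
	m.set(1, 100)
	v, ok := m.get(1)
	fmt.Println(v, ok) // 100 true: reads see both generations
}
```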

@@ -138,7 +138,7 @@ func TestSafeMap_Range(t *testing.T) {
 	}

 	var count int32
-	m.Range(func(k, v any) bool {
+	m.Range(func(k, v interface{}) bool {
 		atomic.AddInt32(&count, 1)
 		newMap.Set(k, v)
 		return true
@@ -147,65 +147,3 @@ func TestSafeMap_Range(t *testing.T) {
 	assert.Equal(t, m.dirtyNew, newMap.dirtyNew)
 	assert.Equal(t, m.dirtyOld, newMap.dirtyOld)
 }
-
-func TestSetManyTimes(t *testing.T) {
-	const iteration = maxDeletion * 2
-	m := NewSafeMap()
-	for i := 0; i < iteration; i++ {
-		m.Set(i, i)
-		if i%3 == 0 {
-			m.Del(i / 2)
-		}
-	}
-	var count int
-	m.Range(func(k, v any) bool {
-		count++
-		return count < maxDeletion/2
-	})
-	assert.Equal(t, maxDeletion/2, count)
-	for i := 0; i < iteration; i++ {
-		m.Set(i, i)
-		if i%3 == 0 {
-			m.Del(i / 2)
-		}
-	}
-	for i := 0; i < iteration; i++ {
-		m.Set(i, i)
-		if i%3 == 0 {
-			m.Del(i / 2)
-		}
-	}
-	for i := 0; i < iteration; i++ {
-		m.Set(i, i)
-		if i%3 == 0 {
-			m.Del(i / 2)
-		}
-	}
-	count = 0
-	m.Range(func(k, v any) bool {
-		count++
-		return count < maxDeletion
-	})
-	assert.Equal(t, maxDeletion, count)
-}
-
-func TestSetManyTimesNew(t *testing.T) {
-	m := NewSafeMap()
-	for i := 0; i < maxDeletion*3; i++ {
-		m.Set(i, i)
-	}
-	for i := 0; i < maxDeletion*2; i++ {
-		m.Del(i)
-	}
-	for i := 0; i < maxDeletion*3; i++ {
-		m.Set(i+maxDeletion*3, i+maxDeletion*3)
-	}
-	for i := 0; i < maxDeletion*2; i++ {
-		m.Del(i + maxDeletion*2)
-	}
-	for i := 0; i < maxDeletion-copyThreshold+1; i++ {
-		m.Del(i + maxDeletion*2)
-	}
-	assert.Equal(t, 0, len(m.dirtyNew))
-}

@@ -17,14 +17,14 @@ const (
 // Set is not thread-safe, for concurrent use, make sure to use it with synchronization.
 type Set struct {
-	data map[any]lang.PlaceholderType
+	data map[interface{}]lang.PlaceholderType
 	tp   int
 }

 // NewSet returns a managed Set, can only put the values with the same type.
 func NewSet() *Set {
 	return &Set{
-		data: make(map[any]lang.PlaceholderType),
+		data: make(map[interface{}]lang.PlaceholderType),
 		tp:   untyped,
 	}
 }
@@ -32,13 +32,13 @@ func NewSet() *Set {
 // NewUnmanagedSet returns an unmanaged Set, which can put values with different types.
 func NewUnmanagedSet() *Set {
 	return &Set{
-		data: make(map[any]lang.PlaceholderType),
+		data: make(map[interface{}]lang.PlaceholderType),
 		tp:   unmanaged,
 	}
 }

 // Add adds i into s.
-func (s *Set) Add(i ...any) {
+func (s *Set) Add(i ...interface{}) {
 	for _, each := range i {
 		s.add(each)
 	}
@@ -80,7 +80,7 @@ func (s *Set) AddStr(ss ...string) {
 }

 // Contains checks if i is in s.
-func (s *Set) Contains(i any) bool {
+func (s *Set) Contains(i interface{}) bool {
 	if len(s.data) == 0 {
 		return false
 	}
@@ -91,8 +91,8 @@ func (s *Set) Contains(i any) bool {
 }

 // Keys returns the keys in s.
-func (s *Set) Keys() []any {
-	var keys []any
+func (s *Set) Keys() []interface{} {
+	var keys []interface{}

 	for key := range s.data {
 		keys = append(keys, key)
@@ -167,7 +167,7 @@ func (s *Set) KeysStr() []string {
 }

 // Remove removes i from s.
-func (s *Set) Remove(i any) {
+func (s *Set) Remove(i interface{}) {
 	s.validate(i)
 	delete(s.data, i)
 }
@@ -177,7 +177,7 @@ func (s *Set) Count() int {
 	return len(s.data)
 }

-func (s *Set) add(i any) {
+func (s *Set) add(i interface{}) {
 	switch s.tp {
 	case unmanaged:
 		// do nothing
@@ -189,7 +189,7 @@ func (s *Set) add(i any) {
 	s.data[i] = lang.Placeholder
 }

-func (s *Set) setType(i any) {
+func (s *Set) setType(i interface{}) {
 	// s.tp can only be untyped here
 	switch i.(type) {
 	case int:
@@ -205,7 +205,7 @@ func (s *Set) setType(i any) {
 	}
 }

-func (s *Set) validate(i any) {
+func (s *Set) validate(i interface{}) {
 	if s.tp == unmanaged {
 		return
 	}

@@ -13,7 +13,7 @@ func init() {
 }

 func BenchmarkRawSet(b *testing.B) {
-	m := make(map[any]struct{})
+	m := make(map[interface{}]struct{})
 	for i := 0; i < b.N; i++ {
 		m[i] = struct{}{}
 		_ = m[i]
@@ -39,7 +39,7 @@ func BenchmarkSet(b *testing.B) {
 func TestAdd(t *testing.T) {
 	// given
 	set := NewUnmanagedSet()
-	values := []any{1, 2, 3}
+	values := []interface{}{1, 2, 3}

 	// when
 	set.Add(values...)
@@ -135,7 +135,7 @@ func TestContainsUnmanagedWithoutElements(t *testing.T) {
 func TestRemove(t *testing.T) {
 	// given
 	set := NewSet()
-	set.Add([]any{1, 2, 3}...)
+	set.Add([]interface{}{1, 2, 3}...)

 	// when
 	set.Remove(2)
@@ -147,7 +147,7 @@ func TestRemove(t *testing.T) {
 func TestCount(t *testing.T) {
 	// given
 	set := NewSet()
-	set.Add([]any{1, 2, 3}...)
+	set.Add([]interface{}{1, 2, 3}...)

 	// then
 	assert.Equal(t, set.Count(), 3)
@@ -198,5 +198,5 @@ func TestSetType(t *testing.T) {
 	set.add(1)
 	set.add("2")
 	vals := set.Keys()
-	assert.ElementsMatch(t, []any{1, "2"}, vals)
+	assert.ElementsMatch(t, []interface{}{1, "2"}, vals)
 }

@@ -20,7 +20,7 @@ var (
 type (
 	// Execute defines the method to execute the task.
-	Execute func(key, value any)
+	Execute func(key, value interface{})

 	// A TimingWheel is a timing wheel object to schedule tasks.
 	TimingWheel struct {
@@ -33,14 +33,14 @@ type (
 		execute       Execute
 		setChannel    chan timingEntry
 		moveChannel   chan baseEntry
-		removeChannel chan any
-		drainChannel  chan func(key, value any)
+		removeChannel chan interface{}
+		drainChannel  chan func(key, value interface{})
 		stopChannel   chan lang.PlaceholderType
 	}

 	timingEntry struct {
 		baseEntry
-		value   any
+		value   interface{}
 		circle  int
 		diff    int
 		removed bool
@@ -48,7 +48,7 @@ type (
 	baseEntry struct {
 		delay time.Duration
-		key   any
+		key   interface{}
 	}

 	positionEntry struct {
@@ -57,8 +57,8 @@ type (
 	}

 	timingTask struct {
-		key   any
-		value any
+		key   interface{}
+		value interface{}
 	}
 )
@@ -85,8 +85,8 @@ func NewTimingWheelWithTicker(interval time.Duration, numSlots int, execute Exec
 		numSlots:      numSlots,
 		setChannel:    make(chan timingEntry),
 		moveChannel:   make(chan baseEntry),
-		removeChannel: make(chan any),
-		drainChannel:  make(chan func(key, value any)),
+		removeChannel: make(chan interface{}),
+		drainChannel:  make(chan func(key, value interface{})),
 		stopChannel:   make(chan lang.PlaceholderType),
 	}
@@ -97,7 +97,7 @@ func NewTimingWheelWithTicker(interval time.Duration, numSlots int, execute Exec
 }

 // Drain drains all items and executes them.
-func (tw *TimingWheel) Drain(fn func(key, value any)) error {
+func (tw *TimingWheel) Drain(fn func(key, value interface{})) error {
 	select {
 	case tw.drainChannel <- fn:
 		return nil
@@ -107,7 +107,7 @@ func (tw *TimingWheel) Drain(fn func(key, value any)) error {
 }

 // MoveTimer moves the task with the given key to the given delay.
-func (tw *TimingWheel) MoveTimer(key any, delay time.Duration) error {
+func (tw *TimingWheel) MoveTimer(key interface{}, delay time.Duration) error {
 	if delay <= 0 || key == nil {
 		return ErrArgument
 	}
@@ -124,7 +124,7 @@ func (tw *TimingWheel) MoveTimer(key any, delay time.Duration) error {
 }

 // RemoveTimer removes the task with the given key.
-func (tw *TimingWheel) RemoveTimer(key any) error {
+func (tw *TimingWheel) RemoveTimer(key interface{}) error {
 	if key == nil {
 		return ErrArgument
 	}
@@ -138,7 +138,7 @@ func (tw *TimingWheel) RemoveTimer(key any) error {
 }

 // SetTimer sets the task value with the given key to the delay.
-func (tw *TimingWheel) SetTimer(key, value any, delay time.Duration) error {
+func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) error {
 	if delay <= 0 || key == nil {
 		return ErrArgument
 	}
@@ -162,7 +162,7 @@ func (tw *TimingWheel) Stop() {
 	close(tw.stopChannel)
 }

-func (tw *TimingWheel) drainAll(fn func(key, value any)) {
+func (tw *TimingWheel) drainAll(fn func(key, value interface{})) {
 	runner := threading.NewTaskRunner(drainWorkers)
 	for _, slot := range tw.slots {
 		for e := slot.Front(); e != nil; {
@@ -232,7 +232,7 @@ func (tw *TimingWheel) onTick() {
 	tw.scanAndRunTasks(l)
 }

-func (tw *TimingWheel) removeTask(key any) {
+func (tw *TimingWheel) removeTask(key interface{}) {
 	val, ok := tw.timers.Get(key)
 	if !ok {
 		return
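Aside: every public method above (`Drain`, `MoveTimer`, `RemoveTimer`, `SetTimer`) follows the same select-on-stop pattern — hand the request to the runner goroutine through a channel, and return `ErrClosed` instead of blocking forever once `Stop` has closed `stopChannel`. A stripped-down sketch of that pattern only (all types and names below are invented, not go-zero's):

```go
package main

import (
	"errors"
	"fmt"
)

// errClosed stands in for go-zero's ErrClosed.
var errClosed = errors.New("closed")

// wheel keeps only the two channels needed to show the pattern.
type wheel struct {
	setCh  chan int
	stopCh chan struct{}
}

// setTimer hands the request to the runner goroutine, or fails fast
// once the wheel has been stopped.
func (w *wheel) setTimer(v int) error {
	select {
	case w.setCh <- v:
		return nil
	case <-w.stopCh:
		return errClosed
	}
}

func main() {
	w := &wheel{setCh: make(chan int), stopCh: make(chan struct{})}
	done := make(chan struct{})
	go func() { // the runner loop: consume requests until stopped
		defer close(done)
		for {
			select {
			case v := <-w.setCh:
				_ = v // the real wheel would place the task into a slot here
			case <-w.stopCh:
				return
			}
		}
	}()

	fmt.Println(w.setTimer(3)) // accepted while the runner is alive
	close(w.stopCh)            // Stop()
	<-done                     // wait for the runner to exit
	fmt.Println(w.setTimer(4)) // rejected with errClosed
}
```

The upside of this design is that the runner goroutine owns all wheel state, so slots and timers need no locks; callers only ever touch channels.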

@@ -20,13 +20,13 @@ const (
 )

 func TestNewTimingWheel(t *testing.T) {
-	_, err := NewTimingWheel(0, 10, func(key, value any) {})
+	_, err := NewTimingWheel(0, 10, func(key, value interface{}) {})
 	assert.NotNil(t, err)
 }

 func TestTimingWheel_Drain(t *testing.T) {
 	ticker := timex.NewFakeTicker()
-	tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
+	tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
 	}, ticker)
 	tw.SetTimer("first", 3, testStep*4)
 	tw.SetTimer("second", 5, testStep*7)
@@ -36,7 +36,7 @@ func TestTimingWheel_Drain(t *testing.T) {
 	var lock sync.Mutex
 	var wg sync.WaitGroup
 	wg.Add(3)
-	tw.Drain(func(key, value any) {
+	tw.Drain(func(key, value interface{}) {
 		lock.Lock()
 		defer lock.Unlock()
 		keys = append(keys, key.(string))
@@ -50,19 +50,19 @@ func TestTimingWheel_Drain(t *testing.T) {
 	assert.EqualValues(t, []string{"first", "second", "third"}, keys)
 	assert.EqualValues(t, []int{3, 5, 7}, vals)
 	var count int
-	tw.Drain(func(key, value any) {
+	tw.Drain(func(key, value interface{}) {
 		count++
 	})
 	time.Sleep(time.Millisecond * 100)
 	assert.Equal(t, 0, count)
 	tw.Stop()
-	assert.Equal(t, ErrClosed, tw.Drain(func(key, value any) {}))
+	assert.Equal(t, ErrClosed, tw.Drain(func(key, value interface{}) {}))
 }

 func TestTimingWheel_SetTimerSoon(t *testing.T) {
 	run := syncx.NewAtomicBool()
 	ticker := timex.NewFakeTicker()
-	tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
+	tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
 		assert.True(t, run.CompareAndSwap(false, true))
 		assert.Equal(t, "any", k)
 		assert.Equal(t, 3, v.(int))
@@ -78,7 +78,7 @@ func TestTimingWheel_SetTimerSoon(t *testing.T) {
 func TestTimingWheel_SetTimerTwice(t *testing.T) {
 	run := syncx.NewAtomicBool()
 	ticker := timex.NewFakeTicker()
-	tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
+	tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
 		assert.True(t, run.CompareAndSwap(false, true))
 		assert.Equal(t, "any", k)
 		assert.Equal(t, 5, v.(int))
@@ -96,7 +96,7 @@ func TestTimingWheel_SetTimerTwice(t *testing.T) {
 func TestTimingWheel_SetTimerWrongDelay(t *testing.T) {
 	ticker := timex.NewFakeTicker()
-	tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {}, ticker)
+	tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {}, ticker)
 	defer tw.Stop()
 	assert.NotPanics(t, func() {
 		tw.SetTimer("any", 3, -testStep)
@@ -105,7 +105,7 @@ func TestTimingWheel_SetTimerWrongDelay(t *testing.T) {
 func TestTimingWheel_SetTimerAfterClose(t *testing.T) {
 	ticker := timex.NewFakeTicker()
-	tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {}, ticker)
+	tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {}, ticker)
 	tw.Stop()
 	assert.Equal(t, ErrClosed, tw.SetTimer("any", 3, testStep))
 }
@@ -113,7 +113,7 @@ func TestTimingWheel_SetTimerAfterClose(t *testing.T) {
 func TestTimingWheel_MoveTimer(t *testing.T) {
 	run := syncx.NewAtomicBool()
 	ticker := timex.NewFakeTicker()
-	tw, _ := NewTimingWheelWithTicker(testStep, 3, func(k, v any) {
+	tw, _ := NewTimingWheelWithTicker(testStep, 3, func(k, v interface{}) {
 		assert.True(t, run.CompareAndSwap(false, true))
 		assert.Equal(t, "any", k)
 		assert.Equal(t, 3, v.(int))
@@ -139,7 +139,7 @@ func TestTimingWheel_MoveTimer(t *testing.T) {
 func TestTimingWheel_MoveTimerSoon(t *testing.T) {
 	run := syncx.NewAtomicBool()
 	ticker := timex.NewFakeTicker()
-	tw, _ := NewTimingWheelWithTicker(testStep, 3, func(k, v any) {
+	tw, _ := NewTimingWheelWithTicker(testStep, 3, func(k, v interface{}) {
 		assert.True(t, run.CompareAndSwap(false, true))
 		assert.Equal(t, "any", k)
 		assert.Equal(t, 3, v.(int))
@@ -155,7 +155,7 @@ func TestTimingWheel_MoveTimerSoon(t *testing.T) {
 func TestTimingWheel_MoveTimerEarlier(t *testing.T) {
 	run := syncx.NewAtomicBool()
 	ticker := timex.NewFakeTicker()
-	tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
+	tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
 		assert.True(t, run.CompareAndSwap(false, true))
 		assert.Equal(t, "any", k)
 		assert.Equal(t, 3, v.(int))
@@ -173,7 +173,7 @@ func TestTimingWheel_MoveTimerEarlier(t *testing.T) {
 func TestTimingWheel_RemoveTimer(t *testing.T) {
 	ticker := timex.NewFakeTicker()
-	tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {}, ticker)
+	tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {}, ticker)
 	tw.SetTimer("any", 3, testStep)
 	assert.NotPanics(t, func() {
 		tw.RemoveTimer("any")
@@ -236,7 +236,7 @@ func TestTimingWheel_SetTimer(t *testing.T) {
 			}
 			var actual int32
 			done := make(chan lang.PlaceholderType)
-			tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value any) {
+			tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
 				assert.Equal(t, 1, key.(int))
 				assert.Equal(t, 2, value.(int))
 				actual = atomic.LoadInt32(&count)
@@ -317,7 +317,7 @@ func TestTimingWheel_SetAndMoveThenStart(t *testing.T) {
 			}
 			var actual int32
 			done := make(chan lang.PlaceholderType)
-			tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value any) {
+			tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
 				actual = atomic.LoadInt32(&count)
 				close(done)
 			}, ticker)
@@ -405,7 +405,7 @@ func TestTimingWheel_SetAndMoveTwice(t *testing.T) {
 			}
 			var actual int32
 			done := make(chan lang.PlaceholderType)
-			tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value any) {
+			tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
 				actual = atomic.LoadInt32(&count)
 				close(done)
 			}, ticker)
@@ -486,7 +486,7 @@ func TestTimingWheel_ElapsedAndSet(t *testing.T) {
 			}
 			var actual int32
 			done := make(chan lang.PlaceholderType)
-			tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value any) {
+			tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
 				actual = atomic.LoadInt32(&count)
 				close(done)
 			}, ticker)
@@ -577,7 +577,7 @@ func TestTimingWheel_ElapsedAndSetThenMove(t *testing.T) {
 			}
 			var actual int32
 			done := make(chan lang.PlaceholderType)
-			tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value any) {
+			tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
 				actual = atomic.LoadInt32(&count)
 				close(done)
 			}, ticker)
@@ -612,7 +612,7 @@ func TestMoveAndRemoveTask(t *testing.T) {
 		}
 	}
 	var keys []int
-	tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
+	tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
 		assert.Equal(t, "any", k)
 		assert.Equal(t, 3, v.(int))
 		keys = append(keys, v.(int))
@@ -632,7 +632,7 @@ func TestMoveAndRemoveTask(t *testing.T) {
 func BenchmarkTimingWheel(b *testing.B) {
 	b.ReportAllocs()
-	tw, _ := NewTimingWheel(time.Second, 100, func(k, v any) {})
+	tw, _ := NewTimingWheel(time.Second, 100, func(k, v interface{}) {})
 	for i := 0; i < b.N; i++ {
 		tw.SetTimer(i, i, time.Second)
 		tw.SetTimer(b.N+i, b.N+i, time.Second)

@@ -13,14 +13,11 @@ import (
"github.com/zeromicro/go-zero/internal/encoding" "github.com/zeromicro/go-zero/internal/encoding"
) )
const ( const jsonTagKey = "json"
jsonTagKey = "json"
jsonTagSep = ','
)
var ( var (
fillDefaultUnmarshaler = mapping.NewUnmarshaler(jsonTagKey, mapping.WithDefault()) fillDefaultUnmarshaler = mapping.NewUnmarshaler(jsonTagKey, mapping.WithDefault())
loaders = map[string]func([]byte, any) error{ loaders = map[string]func([]byte, interface{}) error{
".json": LoadFromJsonBytes, ".json": LoadFromJsonBytes,
".toml": LoadFromTomlBytes, ".toml": LoadFromTomlBytes,
".yaml": LoadFromYamlBytes, ".yaml": LoadFromYamlBytes,
@@ -37,12 +34,12 @@ type fieldInfo struct {
// FillDefault fills the default values for the given v, // FillDefault fills the default values for the given v,
// and the premise is that the value of v must be guaranteed to be empty. // and the premise is that the value of v must be guaranteed to be empty.
func FillDefault(v any) error { func FillDefault(v interface{}) error {
return fillDefaultUnmarshaler.Unmarshal(map[string]any{}, v) return fillDefaultUnmarshaler.Unmarshal(map[string]interface{}{}, v)
} }
// Load loads config into v from file, .json, .yaml and .yml are acceptable. // Load loads config into v from file, .json, .yaml and .yml are acceptable.
func Load(file string, v any, opts ...Option) error { func Load(file string, v interface{}, opts ...Option) error {
content, err := os.ReadFile(file) content, err := os.ReadFile(file)
if err != nil { if err != nil {
return err return err
@@ -67,19 +64,19 @@ func Load(file string, v any, opts ...Option) error {
// LoadConfig loads config into v from file, .json, .yaml and .yml are acceptable. // LoadConfig loads config into v from file, .json, .yaml and .yml are acceptable.
// Deprecated: use Load instead. // Deprecated: use Load instead.
func LoadConfig(file string, v any, opts ...Option) error { func LoadConfig(file string, v interface{}, opts ...Option) error {
return Load(file, v, opts...) return Load(file, v, opts...)
} }
// LoadFromJsonBytes loads config into v from content json bytes. // LoadFromJsonBytes loads config into v from content json bytes.
func LoadFromJsonBytes(content []byte, v any) error { func LoadFromJsonBytes(content []byte, v interface{}) error {
info, err := buildFieldsInfo(reflect.TypeOf(v), "") info, err := buildFieldsInfo(reflect.TypeOf(v))
if err != nil { if err != nil {
return err return err
} }
var m map[string]any var m map[string]interface{}
if err = jsonx.Unmarshal(content, &m); err != nil { if err := jsonx.Unmarshal(content, &m); err != nil {
return err return err
} }
@@ -90,12 +87,12 @@ func LoadFromJsonBytes(content []byte, v any) error {
// LoadConfigFromJsonBytes loads config into v from content json bytes.
// Deprecated: use LoadFromJsonBytes instead.
-func LoadConfigFromJsonBytes(content []byte, v any) error {
+func LoadConfigFromJsonBytes(content []byte, v interface{}) error {
return LoadFromJsonBytes(content, v)
}
// LoadFromTomlBytes loads config into v from content toml bytes.
-func LoadFromTomlBytes(content []byte, v any) error {
+func LoadFromTomlBytes(content []byte, v interface{}) error {
b, err := encoding.TomlToJson(content)
if err != nil {
return err
@@ -105,7 +102,7 @@ func LoadFromTomlBytes(content []byte, v any) error {
}
// LoadFromYamlBytes loads config into v from content yaml bytes.
-func LoadFromYamlBytes(content []byte, v any) error {
+func LoadFromYamlBytes(content []byte, v interface{}) error {
b, err := encoding.YamlToJson(content)
if err != nil {
return err
@@ -116,24 +113,24 @@ func LoadFromYamlBytes(content []byte, v any) error {
// LoadConfigFromYamlBytes loads config into v from content yaml bytes.
// Deprecated: use LoadFromYamlBytes instead.
-func LoadConfigFromYamlBytes(content []byte, v any) error {
+func LoadConfigFromYamlBytes(content []byte, v interface{}) error {
return LoadFromYamlBytes(content, v)
}
// MustLoad loads config into v from path, exits on error.
-func MustLoad(path string, v any, opts ...Option) {
+func MustLoad(path string, v interface{}, opts ...Option) {
if err := Load(path, v, opts...); err != nil {
log.Fatalf("error: config file %s, %s", path, err.Error())
}
}
-func addOrMergeFields(info *fieldInfo, key string, child *fieldInfo, fullName string) error {
+func addOrMergeFields(info *fieldInfo, key string, child *fieldInfo) error {
if prev, ok := info.children[key]; ok {
if child.mapField != nil {
-return newConflictKeyError(fullName)
+return newDupKeyError(key)
}
-if err := mergeFields(prev, child.children, fullName); err != nil {
+if err := mergeFields(prev, key, child.children); err != nil {
return err
}
} else {
@@ -143,27 +140,27 @@ func addOrMergeFields(info *fieldInfo, key string, child *fieldInfo, fullName st
return nil
}
-func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type, fullName string) error {
+func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error {
switch ft.Kind() {
case reflect.Struct:
-fields, err := buildFieldsInfo(ft, fullName)
+fields, err := buildFieldsInfo(ft)
if err != nil {
return err
}
for k, v := range fields.children {
-if err = addOrMergeFields(info, k, v, fullName); err != nil {
+if err = addOrMergeFields(info, k, v); err != nil {
return err
}
}
case reflect.Map:
-elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem()), fullName)
+elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
if err != nil {
return err
}
if _, ok := info.children[lowerCaseName]; ok {
-return newConflictKeyError(fullName)
+return newDupKeyError(lowerCaseName)
}
info.children[lowerCaseName] = &fieldInfo{
@@ -172,7 +169,7 @@ func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.T
}
default:
if _, ok := info.children[lowerCaseName]; ok {
-return newConflictKeyError(fullName)
+return newDupKeyError(lowerCaseName)
}
info.children[lowerCaseName] = &fieldInfo{
@@ -183,14 +180,14 @@ func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.T
return nil
}
-func buildFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error) {
+func buildFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
tp = mapping.Deref(tp)
switch tp.Kind() {
case reflect.Struct:
-return buildStructFieldsInfo(tp, fullName)
+return buildStructFieldsInfo(tp)
case reflect.Array, reflect.Slice:
-return buildFieldsInfo(mapping.Deref(tp.Elem()), fullName)
+return buildFieldsInfo(mapping.Deref(tp.Elem()))
case reflect.Chan, reflect.Func:
return nil, fmt.Errorf("unsupported type: %s", tp.Kind())
default:
@@ -200,23 +197,23 @@ func buildFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error) {
}
}
-func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type, fullName string) error {
+func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error {
var finfo *fieldInfo
var err error
switch ft.Kind() {
case reflect.Struct:
-finfo, err = buildFieldsInfo(ft, fullName)
+finfo, err = buildFieldsInfo(ft)
if err != nil {
return err
}
case reflect.Array, reflect.Slice:
-finfo, err = buildFieldsInfo(ft.Elem(), fullName)
+finfo, err = buildFieldsInfo(ft.Elem())
if err != nil {
return err
}
case reflect.Map:
-elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem()), fullName)
+elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
if err != nil {
return err
}
@@ -226,37 +223,31 @@ func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type,
mapField: elemInfo,
}
default:
-finfo, err = buildFieldsInfo(ft, fullName)
+finfo, err = buildFieldsInfo(ft)
if err != nil {
return err
}
}
-return addOrMergeFields(info, lowerCaseName, finfo, fullName)
+return addOrMergeFields(info, lowerCaseName, finfo)
}
-func buildStructFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error) {
+func buildStructFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
info := &fieldInfo{
children: make(map[string]*fieldInfo),
}
for i := 0; i < tp.NumField(); i++ {
field := tp.Field(i)
-if !field.IsExported() {
-continue
-}
-name := getTagName(field)
+name := field.Name
lowerCaseName := toLowerCase(name)
ft := mapping.Deref(field.Type)
// flatten anonymous fields
if field.Anonymous {
-if err := buildAnonymousFieldInfo(info, lowerCaseName, ft,
-getFullName(fullName, lowerCaseName)); err != nil {
+if err := buildAnonymousFieldInfo(info, lowerCaseName, ft); err != nil {
return nil, err
}
-} else if err := buildNamedFieldInfo(info, lowerCaseName, ft,
-getFullName(fullName, lowerCaseName)); err != nil {
+} else if err := buildNamedFieldInfo(info, lowerCaseName, ft); err != nil {
return nil, err
}
}
@@ -264,32 +255,15 @@ func buildStructFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error)
return info, nil
}
-// getTagName gets the tag name of the given field; if there is no tag name, field.Name is used.
-// field.Name is returned on tags like `json:""` and `json:",optional"`.
-func getTagName(field reflect.StructField) string {
-if tag, ok := field.Tag.Lookup(jsonTagKey); ok {
-if pos := strings.IndexByte(tag, jsonTagSep); pos >= 0 {
-tag = tag[:pos]
-}
-tag = strings.TrimSpace(tag)
-if len(tag) > 0 {
-return tag
-}
-}
-return field.Name
-}
-func mergeFields(prev *fieldInfo, children map[string]*fieldInfo, fullName string) error {
+func mergeFields(prev *fieldInfo, key string, children map[string]*fieldInfo) error {
if len(prev.children) == 0 || len(children) == 0 {
-return newConflictKeyError(fullName)
+return newDupKeyError(key)
}
// merge fields
for k, v := range children {
if _, ok := prev.children[k]; ok {
-return newConflictKeyError(fullName)
+return newDupKeyError(k)
}
prev.children[k] = v
@@ -302,12 +276,12 @@ func toLowerCase(s string) string {
return strings.ToLower(s)
}
-func toLowerCaseInterface(v any, info *fieldInfo) any {
+func toLowerCaseInterface(v interface{}, info *fieldInfo) interface{} {
switch vv := v.(type) {
-case map[string]any:
+case map[string]interface{}:
return toLowerCaseKeyMap(vv, info)
-case []any:
+case []interface{}:
-var arr []any
+var arr []interface{}
for _, vvv := range vv {
arr = append(arr, toLowerCaseInterface(vvv, info))
}
@@ -317,8 +291,8 @@ func toLowerCaseInterface(v any, info *fieldInfo) any {
}
}
-func toLowerCaseKeyMap(m map[string]any, info *fieldInfo) map[string]any {
+func toLowerCaseKeyMap(m map[string]interface{}, info *fieldInfo) map[string]interface{} {
-res := make(map[string]any)
+res := make(map[string]interface{})
for k, v := range m {
ti, ok := info.children[k]
@@ -340,22 +314,14 @@ func toLowerCaseKeyMap(m map[string]any, info *fieldInfo) map[string]any {
return res
}
-type conflictKeyError struct {
+type dupKeyError struct {
key string
}
-func newConflictKeyError(key string) conflictKeyError {
+func newDupKeyError(key string) dupKeyError {
-return conflictKeyError{key: key}
+return dupKeyError{key: key}
}
-func (e conflictKeyError) Error() string {
+func (e dupKeyError) Error() string {
-return fmt.Sprintf("conflict key %s, pay attention to anonymous fields", e.key)
+return fmt.Sprintf("duplicated key %s", e.key)
}
func getFullName(parent, child string) string {
if len(parent) == 0 {
return child
}
return strings.Join([]string{parent, child}, ".")
} }
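The duplicate-key detection in `mergeFields`/`addOrMergeFields` above rejects any child key that already exists on the node being merged into. A stdlib-only sketch of that rule (hypothetical `fieldNode`/`merge` names; the real `fieldInfo` also tracks map fields and, on the other side of this diff, dotted full names):

```go
package main

import (
	"fmt"
	"strings"
)

// fieldNode is a stripped-down stand-in for the fieldInfo tree in the diff.
type fieldNode struct {
	children map[string]*fieldNode
}

// merge mirrors the rule shown above: merging an empty side, or a child key
// that already exists in prev, is reported as a duplicated key.
func merge(prev *fieldNode, key string, children map[string]*fieldNode) error {
	if len(prev.children) == 0 || len(children) == 0 {
		return fmt.Errorf("duplicated key %s", key)
	}
	for k, v := range children {
		if _, ok := prev.children[k]; ok {
			return fmt.Errorf("duplicated key %s", k)
		}
		prev.children[k] = v
	}
	return nil
}

func main() {
	// keys are lowercased before merging, as toLowerCase does in the diff
	prev := &fieldNode{children: map[string]*fieldNode{"name": nil}}
	err := merge(prev, "outer", map[string]*fieldNode{strings.ToLower("Name"): nil})
	fmt.Println(err) // duplicated key name
}
```

This is why an anonymous (embedded) struct sharing a field name with its parent fails to load: flattening the embedded fields makes both land on the same lowercase key.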

View File

@@ -2,7 +2,6 @@ package conf
import (
"os"
-"reflect"
"testing"
"github.com/stretchr/testify/assert"
@@ -10,7 +9,7 @@ import (
"github.com/zeromicro/go-zero/core/hash"
)
-var dupErr conflictKeyError
+var dupErr dupKeyError
func TestLoadConfig_notExists(t *testing.T) {
assert.NotNil(t, Load("not_a_file", nil))
@@ -35,11 +34,11 @@ func TestConfigJson(t *testing.T) {
"c": "${FOO}",
"d": "abcd!@#$112"
}`
-t.Setenv("FOO", "2")
for _, test := range tests {
test := test
t.Run(test, func(t *testing.T) {
+os.Setenv("FOO", "2")
+defer os.Unsetenv("FOO")
tmpfile, err := createTempFile(test, text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
@@ -81,7 +80,8 @@ b = 1
c = "${FOO}"
d = "abcd!@#$112"
`
-t.Setenv("FOO", "2")
+os.Setenv("FOO", "2")
+defer os.Unsetenv("FOO")
tmpfile, err := createTempFile(".toml", text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
@@ -123,24 +123,6 @@ d = "abcd"
}
}
func TestConfigWithLower(t *testing.T) {
text := `a = "foo"
b = 1
`
tmpfile, err := createTempFile(".toml", text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
var val struct {
A string `json:"a"`
b int
}
if assert.NoError(t, Load(tmpfile, &val)) {
assert.Equal(t, "foo", val.A)
assert.Equal(t, 0, val.b)
}
}
func TestConfigJsonCanonical(t *testing.T) {
text := []byte(`{"a": "foo", "B": "bar"}`)
@@ -206,7 +188,8 @@ b = 1
c = "${FOO}"
d = "abcd!@#112"
`
-t.Setenv("FOO", "2")
+os.Setenv("FOO", "2")
+defer os.Unsetenv("FOO")
tmpfile, err := createTempFile(".toml", text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
@@ -237,10 +220,11 @@ func TestConfigJsonEnv(t *testing.T) {
"c": "${FOO}",
"d": "abcd!@#$a12 3"
}`
-t.Setenv("FOO", "2")
for _, test := range tests {
test := test
t.Run(test, func(t *testing.T) {
+os.Setenv("FOO", "2")
+defer os.Unsetenv("FOO")
tmpfile, err := createTempFile(test, text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
@@ -648,7 +632,7 @@ func Test_FieldOverwrite(t *testing.T) {
Name2 *string
}
-validate := func(val any) {
+validate := func(val interface{}) {
input := []byte(`{"Name": "hello", "Name2": "world"}`)
assert.NoError(t, LoadFromJsonBytes(input, val))
}
@@ -684,11 +668,11 @@ func Test_FieldOverwrite(t *testing.T) {
Name *string
}
-validate := func(val any) {
+validate := func(val interface{}) {
input := []byte(`{"Name": "hello"}`)
err := LoadFromJsonBytes(input, val)
assert.ErrorAs(t, err, &dupErr)
-assert.Equal(t, newConflictKeyError("name").Error(), err.Error())
+assert.Equal(t, newDupKeyError("name").Error(), err.Error())
}
validate(&St1{})
@@ -727,11 +711,11 @@ func Test_FieldOverwrite(t *testing.T) {
Name *int
}
-validate := func(val any) {
+validate := func(val interface{}) {
input := []byte(`{"Name": "hello"}`)
err := LoadFromJsonBytes(input, val)
assert.ErrorAs(t, err, &dupErr)
-assert.Error(t, err)
+assert.Equal(t, newDupKeyError("name").Error(), err.Error())
}
validate(&St0{})
@@ -1038,22 +1022,22 @@ func TestLoadNamedFieldOverwritten(t *testing.T) {
})
}
-func TestLoadLowerMemberShouldNotConflict(t *testing.T) {
-type (
-Redis struct {
-db uint
-}
-Config struct {
-db uint
-Redis
-}
-)
-var c Config
-assert.NoError(t, LoadFromJsonBytes([]byte(`{}`), &c))
-assert.Zero(t, c.db)
-assert.Zero(t, c.Redis.db)
-}
+func createTempFile(ext, text string) (string, error) {
+tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext)
+if err != nil {
+return "", err
+}
+if err := os.WriteFile(tmpfile.Name(), []byte(text), os.ModeTemporary); err != nil {
+return "", err
+}
+filename := tmpfile.Name()
+if err = tmpfile.Close(); err != nil {
+return "", err
+}
+return filename, nil
+}
func TestFillDefaultUnmarshal(t *testing.T) {
@@ -1095,7 +1079,7 @@ func TestFillDefaultUnmarshal(t *testing.T) {
assert.Equal(t, st.C, "c")
})
-t.Run("has value", func(t *testing.T) {
+t.Run("has vaue", func(t *testing.T) {
type St struct {
A string `json:",default=a"`
B string
@@ -1107,201 +1091,3 @@ func TestFillDefaultUnmarshal(t *testing.T) {
assert.Error(t, err)
})
}
func TestConfigWithJsonTag(t *testing.T) {
t.Run("map with value", func(t *testing.T) {
var input = []byte(`[Value]
[Value.first]
Email = "foo"
[Value.second]
Email = "bar"`)
type Value struct {
Email string
}
type Config struct {
ValueMap map[string]Value `json:"Value"`
}
var c Config
if assert.NoError(t, LoadFromTomlBytes(input, &c)) {
assert.Len(t, c.ValueMap, 2)
}
})
t.Run("map with ptr value", func(t *testing.T) {
var input = []byte(`[Value]
[Value.first]
Email = "foo"
[Value.second]
Email = "bar"`)
type Value struct {
Email string
}
type Config struct {
ValueMap map[string]*Value `json:"Value"`
}
var c Config
if assert.NoError(t, LoadFromTomlBytes(input, &c)) {
assert.Len(t, c.ValueMap, 2)
}
})
t.Run("map with optional", func(t *testing.T) {
var input = []byte(`[Value]
[Value.first]
Email = "foo"
[Value.second]
Email = "bar"`)
type Value struct {
Email string
}
type Config struct {
Value map[string]Value `json:",optional"`
}
var c Config
if assert.NoError(t, LoadFromTomlBytes(input, &c)) {
assert.Len(t, c.Value, 2)
}
})
t.Run("map with empty tag", func(t *testing.T) {
var input = []byte(`[Value]
[Value.first]
Email = "foo"
[Value.second]
Email = "bar"`)
type Value struct {
Email string
}
type Config struct {
Value map[string]Value `json:" "`
}
var c Config
if assert.NoError(t, LoadFromTomlBytes(input, &c)) {
assert.Len(t, c.Value, 2)
}
})
}
func Test_getFullName(t *testing.T) {
assert.Equal(t, "a.b", getFullName("a", "b"))
assert.Equal(t, "a", getFullName("", "a"))
}
func Test_buildFieldsInfo(t *testing.T) {
type ParentSt struct {
Name string
M map[string]int
}
tests := []struct {
name string
t reflect.Type
ok bool
containsKey string
}{
{
name: "normal",
t: reflect.TypeOf(struct{ A string }{}),
ok: true,
},
{
name: "struct anonymous",
t: reflect.TypeOf(struct {
ParentSt
Name string
}{}),
ok: false,
containsKey: newConflictKeyError("name").Error(),
},
{
name: "struct ptr anonymous",
t: reflect.TypeOf(struct {
*ParentSt
Name string
}{}),
ok: false,
containsKey: newConflictKeyError("name").Error(),
},
{
name: "more struct anonymous",
t: reflect.TypeOf(struct {
Value struct {
ParentSt
Name string
}
}{}),
ok: false,
containsKey: newConflictKeyError("value.name").Error(),
},
{
name: "map anonymous",
t: reflect.TypeOf(struct {
ParentSt
M string
}{}),
ok: false,
containsKey: newConflictKeyError("m").Error(),
},
{
name: "map more anonymous",
t: reflect.TypeOf(struct {
Value struct {
ParentSt
M string
}
}{}),
ok: false,
containsKey: newConflictKeyError("value.m").Error(),
},
{
name: "struct slice anonymous",
t: reflect.TypeOf([]struct {
ParentSt
Name string
}{}),
ok: false,
containsKey: newConflictKeyError("name").Error(),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := buildFieldsInfo(tt.t, "")
if tt.ok {
assert.NoError(t, err)
} else {
assert.Error(t, err)
assert.Equal(t, err.Error(), tt.containsKey)
}
})
}
}
func createTempFile(ext, text string) (string, error) {
tmpFile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext)
if err != nil {
return "", err
}
if err := os.WriteFile(tmpFile.Name(), []byte(text), os.ModeTemporary); err != nil {
return "", err
}
filename := tmpFile.Name()
if err = tmpFile.Close(); err != nil {
return "", err
}
return filename, nil
}

View File

@@ -45,7 +45,8 @@ func TestPropertiesEnv(t *testing.T) {
assert.Nil(t, err)
defer os.Remove(tmpfile)
-t.Setenv("FOO", "2")
+os.Setenv("FOO", "2")
+defer os.Unsetenv("FOO")
props, err := LoadProperties(tmpfile, UseEnv())
assert.Nil(t, err)

View File

@@ -12,7 +12,7 @@ type RestfulConf struct {
MaxConns int `json:",default=10000"`
MaxBytes int64 `json:",default=1048576"`
Timeout time.Duration `json:",default=3s"`
-CpuThreshold int64 `json:",default=900,range=[0:1000)"`
+CpuThreshold int64 `json:",default=900,range=[0:1000]"`
}
```

View File

@@ -14,13 +14,13 @@ type contextValuer struct {
context.Context
}
-func (cv contextValuer) Value(key string) (any, bool) {
+func (cv contextValuer) Value(key string) (interface{}, bool) {
v := cv.Context.Value(key)
return v, v != nil
}
// For unmarshals ctx into v.
-func For(ctx context.Context, v any) error {
+func For(ctx context.Context, v interface{}) error {
return unmarshaler.UnmarshalValuer(contextValuer{
Context: ctx,
}, v)

View File

@@ -13,7 +13,6 @@ var (
type EtcdConf struct {
Hosts []string
Key string
-ID int64 `json:",optional"`
User string `json:",optional"`
Pass string `json:",optional"`
CertFile string `json:",optional"`
@@ -27,11 +26,6 @@ func (c EtcdConf) HasAccount() bool {
return len(c.User) > 0 && len(c.Pass) > 0
}
// HasID returns if ID provided.
func (c EtcdConf) HasID() bool {
return c.ID > 0
}
// HasTLS returns if TLS CertFile/CertKeyFile/CACertFile are provided.
func (c EtcdConf) HasTLS() bool {
return len(c.CertFile) > 0 && len(c.CertKeyFile) > 0 && len(c.CACertFile) > 0
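`HasTLS` treats a partial TLS configuration as "no TLS": all three files must be set. A minimal stdlib sketch of that semantics (hypothetical `etcdTLS`/`hasTLS` names, mirroring only the fields shown above):

```go
package main

import "fmt"

// etcdTLS mirrors the TLS-related fields of EtcdConf above.
type etcdTLS struct {
	CertFile, CertKeyFile, CACertFile string
}

// hasTLS reports whether all three TLS files are provided; any missing
// file disables TLS entirely, matching HasTLS in the diff.
func (c etcdTLS) hasTLS() bool {
	return len(c.CertFile) > 0 && len(c.CertKeyFile) > 0 && len(c.CACertFile) > 0
}

func main() {
	fmt.Println(etcdTLS{CertFile: "cert", CertKeyFile: "key"}.hasTLS())
	fmt.Println(etcdTLS{CertFile: "cert", CertKeyFile: "key", CACertFile: "ca"}.hasTLS())
}
```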

View File

@@ -80,90 +80,3 @@ func TestEtcdConf_HasAccount(t *testing.T) {
assert.Equal(t, test.hasAccount, test.EtcdConf.HasAccount())
}
}
func TestEtcdConf_HasID(t *testing.T) {
tests := []struct {
EtcdConf
hasServerID bool
}{
{
EtcdConf: EtcdConf{
Hosts: []string{"any"},
ID: -1,
},
hasServerID: false,
},
{
EtcdConf: EtcdConf{
Hosts: []string{"any"},
ID: 0,
},
hasServerID: false,
},
{
EtcdConf: EtcdConf{
Hosts: []string{"any"},
ID: 10000,
},
hasServerID: true,
},
}
for _, test := range tests {
assert.Equal(t, test.hasServerID, test.EtcdConf.HasID())
}
}
func TestEtcdConf_HasTLS(t *testing.T) {
tests := []struct {
name string
conf EtcdConf
want bool
}{
{
name: "empty config",
conf: EtcdConf{},
want: false,
},
{
name: "missing CertFile",
conf: EtcdConf{
CertKeyFile: "key",
CACertFile: "ca",
},
want: false,
},
{
name: "missing CertKeyFile",
conf: EtcdConf{
CertFile: "cert",
CACertFile: "ca",
},
want: false,
},
{
name: "missing CACertFile",
conf: EtcdConf{
CertFile: "cert",
CertKeyFile: "key",
},
want: false,
},
{
name: "valid config",
conf: EtcdConf{
CertFile: "cert",
CertKeyFile: "key",
CACertFile: "ca",
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.conf.HasTLS()
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -1,85 +1,12 @@
package internal
import (
-"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stringx"
)
const (
certContent = `-----BEGIN CERTIFICATE-----
MIIDazCCAlOgAwIBAgIUEg9GVO2oaPn+YSmiqmFIuAo10WIwDQYJKoZIhvcNAQEM
BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAgFw0yMzAzMTExMzIxMjNaGA8yMTIz
MDIxNTEzMjEyM1owRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx
ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcN
AQEBBQADggEPADCCAQoCggEBALplXlWsIf0O/IgnIplmiZHKGnxyfyufyE2FBRNk
OofRqbKuPH8GNqbkvZm7N29fwTDAQ+mViAggCkDht4hOzoWJMA7KYJt8JnTSWL48
M1lcrpc9DL2gszC/JF/FGvyANbBtLklkZPFBGdHUX14pjrT937wqPtm+SqUHSvRT
B7bmwmm2drRcmhpVm98LSlV7uQ2EgnJgsLjBPITKUejLmVLHfgX0RwQ2xIpX9pS4
FCe1BTacwl2gGp7Mje7y4Mfv3o0ArJW6Tuwbjx59ZXwb1KIP71b7bT04AVS8ZeYO
UMLKKuB5UR9x9Rn6cLXOTWBpcMVyzDgrAFLZjnE9LPUolZMCAwEAAaNRME8wHwYD
VR0jBBgwFoAUeW8w8pmhncbRgTsl48k4/7wnfx8wCQYDVR0TBAIwADALBgNVHQ8E
BAMCBPAwFAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBDAUAA4IBAQAI
y9xaoS88CLPBsX6mxfcTAFVfGNTRW9VN9Ng1cCnUR+YGoXGM/l+qP4f7p8ocdGwK
iYZErVTzXYIn+D27//wpY3klJk3gAnEUBT3QRkStBw7XnpbeZ2oPBK+cmDnCnZPS
BIF1wxPX7vIgaxs5Zsdqwk3qvZ4Djr2wP7LabNWTLSBKgQoUY45Liw6pffLwcGF9
UKlu54bvGze2SufISCR3ib+I+FLvqpvJhXToZWYb/pfI/HccuCL1oot1x8vx6DQy
U+TYxlZsKS5mdNxAX3dqEkEMsgEi+g/tzDPXJImfeCGGBhIOXLm8SRypiuGdEbc9
xkWYxRPegajuEZGvCqVs
-----END CERTIFICATE-----`
keyContent = `-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAumVeVawh/Q78iCcimWaJkcoafHJ/K5/ITYUFE2Q6h9Gpsq48
fwY2puS9mbs3b1/BMMBD6ZWICCAKQOG3iE7OhYkwDspgm3wmdNJYvjwzWVyulz0M
vaCzML8kX8Ua/IA1sG0uSWRk8UEZ0dRfXimOtP3fvCo+2b5KpQdK9FMHtubCabZ2
tFyaGlWb3wtKVXu5DYSCcmCwuME8hMpR6MuZUsd+BfRHBDbEilf2lLgUJ7UFNpzC
XaAansyN7vLgx+/ejQCslbpO7BuPHn1lfBvUog/vVvttPTgBVLxl5g5Qwsoq4HlR
H3H1Gfpwtc5NYGlwxXLMOCsAUtmOcT0s9SiVkwIDAQABAoIBAD5meTJNMgO55Kjg
ESExxpRcCIno+tHr5+6rvYtEXqPheOIsmmwb9Gfi4+Z3WpOaht5/Pz0Ppj6yGzyl
U//6AgGKb+BDuBvVcDpjwPnOxZIBCSHwejdxeQu0scSuA97MPS0XIAvJ5FEv7ijk
5Bht6SyGYURpECltHygoTNuGgGqmO+McCJRLE9L09lTBI6UQ/JQwWJqSr7wx6iPU
M1Ze/srIV+7cyEPu6i0DGjS1gSQKkX68Lqn1w6oE290O+OZvleO0gZ02fLDWCZke
aeD9+EU/Pw+rqm3H6o0szOFIpzhRp41FUdW9sybB3Yp3u7c/574E+04Z/e30LMKs
TCtE1QECgYEA3K7KIpw0NH2HXL5C3RHcLmr204xeBfS70riBQQuVUgYdmxak2ima
80RInskY8hRhSGTg0l+VYIH8cmjcUyqMSOELS5XfRH99r4QPiK8AguXg80T4VumY
W3Pf+zEC2ssgP/gYthV0g0Xj5m2QxktOF9tRw5nkg739ZR4dI9lm/iECgYEA2Dnf
uwEDGqHiQRF6/fh5BG/nGVMvrefkqx6WvTJQ3k/M/9WhxB+lr/8yH46TuS8N2b29
FoTf3Mr9T7pr/PWkOPzoY3P56nYbKU8xSwCim9xMzhBMzj8/N9ukJvXy27/VOz56
eQaKqnvdXNGtPJrIMDGHps2KKWlKLyAlapzjVTMCgYAA/W++tACv85g13EykfT4F
n0k4LbsGP9DP4zABQLIMyiY72eAncmRVjwrcW36XJ2xATOONTgx3gF3HjZzfaqNy
eD/6uNNllUTVEryXGmHgNHPL45VRnn6memCY2eFvZdXhM5W4y2PYaunY0MkDercA
+GTngbs6tBF88KOk04bYwQKBgFl68cRgsdkmnwwQYNaTKfmVGYzYaQXNzkqmWPko
xmCJo6tHzC7ubdG8iRCYHzfmahPuuj6EdGPZuSRyYFgJi5Ftz/nAN+84OxtIQ3zn
YWOgskQgaLh9YfsKsQ7Sf1NDOsnOnD5TX7UXl07fEpLe9vNCvAFiU8e5Y9LGudU5
4bYTAoGBAMdX3a3bXp4cZvXNBJ/QLVyxC6fP1Q4haCR1Od3m+T00Jth2IX2dk/fl
p6xiJT1av5JtYabv1dFKaXOS5s1kLGGuCCSKpkvFZm826aQ2AFm0XGqEQDLeei5b
A52Kpy/YJ+RkG4BTFtAooFq6DmA0cnoP6oPvG2h6XtDJwDTPInJb
-----END RSA PRIVATE KEY-----`
caContent = `-----BEGIN CERTIFICATE-----
MIIDbTCCAlWgAwIBAgIUBJvFoCowKich7MMfseJ+DYzzirowDQYJKoZIhvcNAQEM
BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAgFw0yMzAzMTExMzIxMDNaGA8yMTIz
MDIxNTEzMjEwM1owRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx
ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcN
AQEBBQADggEPADCCAQoCggEBAO4to2YMYj0bxgr2FCiweSTSFuPx33zSw2x/s9Wf
OR41bm2DFsyYT5f3sOIKlXZEdLmOKty2e3ho3yC0EyNpVHdykkkHT3aDI17quZax
kYi/URqqtl1Z08A22txolc04hAZisg2BypGi3vql81UW1t3zyloGnJoIAeXR9uca
ljP6Bk3bwsxoVBLi1JtHrO0hHLQaeHmKhAyrys06X0LRdn7Px48yRZlt6FaLSa8X
YiRM0G44bVy/h6BkoQjMYGwVmCVk6zjJ9U7ZPFqdnDMNxAfR+hjDnYodqdLDMTTR
1NPVrnEnNwFx0AMLvgt/ba/45vZCEAmSZnFXFAJJcM7ai9ECAwEAAaNTMFEwHQYD
VR0OBBYEFHlvMPKZoZ3G0YE7JePJOP+8J38fMB8GA1UdIwQYMBaAFHlvMPKZoZ3G
0YE7JePJOP+8J38fMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQEMBQADggEB
AMX8dNulADOo9uQgBMyFb9TVra7iY0zZjzv4GY5XY7scd52n6CnfAPvYBBDnTr/O
BgNp5jaujb4+9u/2qhV3f9n+/3WOb2CmPehBgVSzlXqHeQ9lshmgwZPeem2T+8Tm
Nnc/xQnsUfCFszUDxpkr55+aLVM22j02RWqcZ4q7TAaVYL+kdFVMc8FoqG/0ro6A
BjE/Qn0Nn7ciX1VUjDt8l+k7ummPJTmzdi6i6E4AwO9dzrGNgGJ4aWL8cC6xYcIX
goVIRTFeONXSDno/oPjWHpIPt7L15heMpKBHNuzPkKx2YVqPHE5QZxWfS+Lzgx+Q
E2oTTM0rYKOZ8p6000mhvKI=
-----END CERTIFICATE-----`
)
func TestAccount(t *testing.T) {
endpoints := []string{
"192.168.0.2:2379",
@@ -105,34 +32,3 @@ func TestAccount(t *testing.T) {
assert.Equal(t, username, account.User)
assert.Equal(t, anotherPassword, account.Pass)
}
func TestTLSMethods(t *testing.T) {
certFile := createTempFile(t, []byte(certContent))
defer os.Remove(certFile)
keyFile := createTempFile(t, []byte(keyContent))
defer os.Remove(keyFile)
caFile := createTempFile(t, []byte(caContent))
defer os.Remove(caFile)
assert.NoError(t, AddTLS([]string{"foo"}, certFile, keyFile, caFile, false))
cfg, ok := GetTLS([]string{"foo"})
assert.True(t, ok)
assert.NotNil(t, cfg)
assert.Error(t, AddTLS([]string{"bar"}, "bad-file", keyFile, caFile, false))
assert.Error(t, AddTLS([]string{"bar"}, certFile, keyFile, "bad-file", false))
}
func createTempFile(t *testing.T, body []byte) string {
tmpFile, err := os.CreateTemp(os.TempDir(), "go-unit-*.tmp")
if err != nil {
t.Fatal(err)
}
tmpFile.Close()
if err = os.WriteFile(tmpFile.Name(), body, os.ModePerm); err != nil {
t.Fatal(err)
}
return tmpFile.Name()
}

View File

@@ -81,7 +81,7 @@ func (mr *MockEtcdClientMockRecorder) Ctx() *gomock.Call {
// Get mocks base method
func (m *MockEtcdClient) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) {
m.ctrl.T.Helper()
-varargs := []any{ctx, key}
+varargs := []interface{}{ctx, key}
for _, a := range opts {
varargs = append(varargs, a)
}
@@ -92,9 +92,9 @@ func (m *MockEtcdClient) Get(ctx context.Context, key string, opts ...clientv3.O
}
// Get indicates an expected call of Get
-func (mr *MockEtcdClientMockRecorder) Get(ctx, key any, opts ...any) *gomock.Call {
+func (mr *MockEtcdClientMockRecorder) Get(ctx, key interface{}, opts ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
-varargs := append([]any{ctx, key}, opts...)
+varargs := append([]interface{}{ctx, key}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockEtcdClient)(nil).Get), varargs...)
}
@@ -108,7 +108,7 @@ func (m *MockEtcdClient) Grant(ctx context.Context, ttl int64) (*clientv3.LeaseG
}
// Grant indicates an expected call of Grant
-func (mr *MockEtcdClientMockRecorder) Grant(ctx, ttl any) *gomock.Call {
+func (mr *MockEtcdClientMockRecorder) Grant(ctx, ttl interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Grant", reflect.TypeOf((*MockEtcdClient)(nil).Grant), ctx, ttl)
}
@@ -123,7 +123,7 @@ func (m *MockEtcdClient) KeepAlive(ctx context.Context, id clientv3.LeaseID) (<-
}
// KeepAlive indicates an expected call of KeepAlive
-func (mr *MockEtcdClientMockRecorder) KeepAlive(ctx, id any) *gomock.Call {
+func (mr *MockEtcdClientMockRecorder) KeepAlive(ctx, id interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeepAlive", reflect.TypeOf((*MockEtcdClient)(nil).KeepAlive), ctx, id)
}
@@ -131,7 +131,7 @@ func (mr *MockEtcdClientMockRecorder) KeepAlive(ctx, id any) *gomock.Call {
// Put mocks base method
func (m *MockEtcdClient) Put(ctx context.Context, key, val string, opts ...clientv3.OpOption) (*clientv3.PutResponse, error) {
m.ctrl.T.Helper()
-varargs := []any{ctx, key, val}
+varargs := []interface{}{ctx, key, val}
for _, a := range opts {
varargs = append(varargs, a)
}
@@ -142,9 +142,9 @@ func (m *MockEtcdClient) Put(ctx context.Context, key, val string, opts ...clien
}
// Put indicates an expected call of Put
-func (mr *MockEtcdClientMockRecorder) Put(ctx, key, val any, opts ...any) *gomock.Call {
+func (mr *MockEtcdClientMockRecorder) Put(ctx, key, val interface{}, opts ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
-varargs := append([]any{ctx, key, val}, opts...)
+varargs := append([]interface{}{ctx, key, val}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockEtcdClient)(nil).Put), varargs...)
}
@@ -158,7 +158,7 @@ func (m *MockEtcdClient) Revoke(ctx context.Context, id clientv3.LeaseID) (*clie
} }
// Revoke indicates an expected call of Revoke // Revoke indicates an expected call of Revoke
func (mr *MockEtcdClientMockRecorder) Revoke(ctx, id any) *gomock.Call { func (mr *MockEtcdClientMockRecorder) Revoke(ctx, id interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Revoke", reflect.TypeOf((*MockEtcdClient)(nil).Revoke), ctx, id) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Revoke", reflect.TypeOf((*MockEtcdClient)(nil).Revoke), ctx, id)
} }
@@ -166,7 +166,7 @@ func (mr *MockEtcdClientMockRecorder) Revoke(ctx, id any) *gomock.Call {
// Watch mocks base method // Watch mocks base method
func (m *MockEtcdClient) Watch(ctx context.Context, key string, opts ...clientv3.OpOption) clientv3.WatchChan { func (m *MockEtcdClient) Watch(ctx context.Context, key string, opts ...clientv3.OpOption) clientv3.WatchChan {
m.ctrl.T.Helper() m.ctrl.T.Helper()
varargs := []any{ctx, key} varargs := []interface{}{ctx, key}
for _, a := range opts { for _, a := range opts {
varargs = append(varargs, a) varargs = append(varargs, a)
} }
@@ -176,8 +176,8 @@ func (m *MockEtcdClient) Watch(ctx context.Context, key string, opts ...clientv3
} }
// Watch indicates an expected call of Watch // Watch indicates an expected call of Watch
func (mr *MockEtcdClientMockRecorder) Watch(ctx, key any, opts ...any) *gomock.Call { func (mr *MockEtcdClientMockRecorder) Watch(ctx, key interface{}, opts ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, key}, opts...) varargs := append([]interface{}{ctx, key}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockEtcdClient)(nil).Watch), varargs...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockEtcdClient)(nil).Watch), varargs...)
} }
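The mock diffs above replace the `any` keyword with `interface{}` throughout (the "chore: remove any keywords" commit). Since Go 1.18 the two spellings are exact aliases, so the rewrite changes nothing at runtime; it only keeps the generated mocks buildable on pre-1.18 toolchains. A minimal sketch (the `describe` helper is illustrative, not from the repo):

```go
package main

import "fmt"

// describe reports the dynamic type of its argument. Because `any` is an
// alias for `interface{}` since Go 1.18, either spelling works here; the
// diff prefers `interface{}` so the code also compiles on older toolchains.
func describe(v interface{}) string {
	return fmt.Sprintf("%T", v)
}

func main() {
	var a any = "hello"         // post-1.18 spelling
	var i interface{} = "hello" // pre-1.18 spelling
	fmt.Println(describe(a), describe(i)) // same dynamic type: string string
}
```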

View File

@@ -2,7 +2,6 @@ package internal
 import (
 	"context"
-	"errors"
 	"fmt"
 	"io"
 	"sort"
@@ -15,7 +14,6 @@ import (
 	"github.com/zeromicro/go-zero/core/logx"
 	"github.com/zeromicro/go-zero/core/syncx"
 	"github.com/zeromicro/go-zero/core/threading"
-	"go.etcd.io/etcd/api/v3/v3rpc/rpctypes"
 	clientv3 "go.etcd.io/etcd/client/v3"
 )
@@ -24,7 +22,6 @@ var (
 		clusters: make(map[string]*cluster),
 	}
 	connManager = syncx.NewResourceManager()
-	errClosed   = errors.New("etcd monitor chan has been closed")
 )
 // A Registry is a registry that manages the etcd client connections.
@@ -222,7 +219,7 @@ func (c *cluster) load(cli EtcdClient, key string) int64 {
 			break
 		}
-		logx.Errorf("%s, key is %s", err.Error(), key)
+		logx.Error(err)
 		time.Sleep(coolDownInterval)
 	}
@@ -291,47 +288,40 @@ func (c *cluster) reload(cli EtcdClient) {
 func (c *cluster) watch(cli EtcdClient, key string, rev int64) {
 	for {
-		err := c.watchStream(cli, key, rev)
-		if err == nil {
+		if c.watchStream(cli, key, rev) {
 			return
 		}
-		if rev != 0 && errors.Is(err, rpctypes.ErrCompacted) {
-			logx.Errorf("etcd watch stream has been compacted, try to reload, rev %d", rev)
-			rev = c.load(cli, key)
-		}
-		// log the error and retry
-		logx.Error(err)
 	}
 }
-func (c *cluster) watchStream(cli EtcdClient, key string, rev int64) error {
+func (c *cluster) watchStream(cli EtcdClient, key string, rev int64) bool {
 	var rch clientv3.WatchChan
 	if rev != 0 {
-		rch = cli.Watch(clientv3.WithRequireLeader(c.context(cli)), makeKeyPrefix(key),
-			clientv3.WithPrefix(), clientv3.WithRev(rev+1))
+		rch = cli.Watch(clientv3.WithRequireLeader(c.context(cli)), makeKeyPrefix(key), clientv3.WithPrefix(),
+			clientv3.WithRev(rev+1))
 	} else {
-		rch = cli.Watch(clientv3.WithRequireLeader(c.context(cli)), makeKeyPrefix(key),
-			clientv3.WithPrefix())
+		rch = cli.Watch(clientv3.WithRequireLeader(c.context(cli)), makeKeyPrefix(key), clientv3.WithPrefix())
 	}
 	for {
 		select {
 		case wresp, ok := <-rch:
 			if !ok {
-				return errClosed
+				logx.Error("etcd monitor chan has been closed")
+				return false
 			}
 			if wresp.Canceled {
-				return fmt.Errorf("etcd monitor chan has been canceled, error: %w", wresp.Err())
+				logx.Errorf("etcd monitor chan has been canceled, error: %v", wresp.Err())
+				return false
 			}
 			if wresp.Err() != nil {
-				return fmt.Errorf("etcd monitor chan error: %w", wresp.Err())
+				logx.Error(fmt.Sprintf("etcd monitor chan error: %v", wresp.Err()))
+				return false
 			}
 			c.handleWatchEvents(key, wresp.Events)
 		case <-c.done:
-			return nil
+			return true
 		}
 	}
 }
@@ -347,11 +337,13 @@ func (c *cluster) watchConnState(cli EtcdClient) {
 // DialClient dials an etcd cluster with given endpoints.
 func DialClient(endpoints []string) (EtcdClient, error) {
 	cfg := clientv3.Config{
-		Endpoints:           endpoints,
-		AutoSyncInterval:    autoSyncInterval,
-		DialTimeout:         DialTimeout,
-		RejectOldCluster:    true,
-		PermitWithoutStream: true,
+		Endpoints:            endpoints,
+		AutoSyncInterval:     autoSyncInterval,
+		DialTimeout:          DialTimeout,
+		DialKeepAliveTime:    dialKeepAliveTime,
+		DialKeepAliveTimeout: DialTimeout,
+		RejectOldCluster:     true,
+		PermitWithoutStream:  true,
 	}
 	if account, ok := GetAccount(endpoints); ok {
 		cfg.Username = account.User
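The base side of the watch hunk above keeps `watchStream` returning an `error` so that `watch` can distinguish a compacted revision (reload state, then re-watch from the fresh revision) from an ordinary stream failure (log and retry). A standalone sketch of that retry pattern, with the etcd client stubbed out by a hypothetical `watcher` type and `errCompacted` standing in for `rpctypes.ErrCompacted`:

```go
package main

import (
	"errors"
	"fmt"
)

// errCompacted stands in for etcd's rpctypes.ErrCompacted: the requested
// revision has been compacted away, so the watcher must load a fresh
// revision before watching again.
var errCompacted = errors.New("etcd: required revision has been compacted")

// watcher is a stand-in for the cluster type in the diff.
type watcher struct {
	attempts int
}

func (w *watcher) watchStream(rev int64) error {
	w.attempts++
	if rev < 5 {
		return errCompacted // stale revision: caller should reload
	}
	return nil // nil means the watcher was asked to stop
}

func (w *watcher) load() int64 { return 5 } // pretend the latest revision is 5

// watch retries watchStream until it returns nil, reloading the revision
// whenever the error says the requested revision was compacted.
func (w *watcher) watch(rev int64) {
	for {
		err := w.watchStream(rev)
		if err == nil {
			return
		}
		if rev != 0 && errors.Is(err, errCompacted) {
			rev = w.load() // re-read current state and revision, then re-watch
		}
	}
}

func main() {
	w := &watcher{}
	w.watch(1)
	fmt.Println(w.attempts) // 2: one compacted attempt, one clean retry
}
```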

View File

@@ -2,10 +2,8 @@ package internal
 import (
 	"context"
-	"os"
 	"sync"
 	"testing"
-	"time"
 	"github.com/golang/mock/gomock"
 	"github.com/stretchr/testify/assert"
@@ -16,7 +14,6 @@ import (
 	"go.etcd.io/etcd/api/v3/etcdserverpb"
 	"go.etcd.io/etcd/api/v3/mvccpb"
 	clientv3 "go.etcd.io/etcd/client/v3"
-	"go.etcd.io/etcd/client/v3/mock/mockserver"
 )
 var mockLock sync.Mutex
@@ -170,7 +167,7 @@ func TestCluster_Watch(t *testing.T) {
 				assert.Equal(t, "world", kv.Val)
 				wg.Done()
 			}).MaxTimes(1)
-			listener.EXPECT().OnDelete(gomock.Any()).Do(func(_ any) {
+			listener.EXPECT().OnDelete(gomock.Any()).Do(func(_ interface{}) {
 				wg.Done()
 			}).MaxTimes(1)
 			go c.watch(cli, "any", 0)
@@ -245,58 +242,3 @@ func TestValueOnlyContext(t *testing.T) {
 	ctx.Done()
 	assert.Nil(t, ctx.Err())
 }
-
-func TestDialClient(t *testing.T) {
-	svr, err := mockserver.StartMockServers(1)
-	assert.NoError(t, err)
-	svr.StartAt(0)
-	certFile := createTempFile(t, []byte(certContent))
-	defer os.Remove(certFile)
-	keyFile := createTempFile(t, []byte(keyContent))
-	defer os.Remove(keyFile)
-	caFile := createTempFile(t, []byte(caContent))
-	defer os.Remove(caFile)
-	endpoints := []string{svr.Servers[0].Address}
-	AddAccount(endpoints, "foo", "bar")
-	assert.NoError(t, AddTLS(endpoints, certFile, keyFile, caFile, false))
-	old := DialTimeout
-	DialTimeout = time.Millisecond
-	defer func() {
-		DialTimeout = old
-	}()
-	_, err = DialClient(endpoints)
-	assert.Error(t, err)
-}
-
-func TestRegistry_Monitor(t *testing.T) {
-	svr, err := mockserver.StartMockServers(1)
-	assert.NoError(t, err)
-	svr.StartAt(0)
-	endpoints := []string{svr.Servers[0].Address}
-	GetRegistry().lock.Lock()
-	GetRegistry().clusters = map[string]*cluster{
-		getClusterKey(endpoints): {
-			listeners: map[string][]UpdateListener{},
-			values: map[string]map[string]string{
-				"foo": {
-					"bar": "baz",
-				},
-			},
-		},
-	}
-	GetRegistry().lock.Unlock()
-	assert.Error(t, GetRegistry().Monitor(endpoints, "foo", new(mockListener)))
-}
-
-type mockListener struct {
-}
-
-func (m *mockListener) OnAdd(_ KV) {
-}
-
-func (m *mockListener) OnDelete(_ KV) {
-}

View File

@@ -58,7 +58,7 @@ func (m *MocketcdConn) WaitForStateChange(ctx context.Context, sourceState conne
 }
 // WaitForStateChange indicates an expected call of WaitForStateChange
-func (mr *MocketcdConnMockRecorder) WaitForStateChange(ctx, sourceState any) *gomock.Call {
+func (mr *MocketcdConnMockRecorder) WaitForStateChange(ctx, sourceState interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WaitForStateChange", reflect.TypeOf((*MocketcdConn)(nil).WaitForStateChange), ctx, sourceState)
 }

View File

@@ -40,7 +40,7 @@ func (m *MockUpdateListener) OnAdd(kv KV) {
 }
 // OnAdd indicates an expected call of OnAdd
-func (mr *MockUpdateListenerMockRecorder) OnAdd(kv any) *gomock.Call {
+func (mr *MockUpdateListenerMockRecorder) OnAdd(kv interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnAdd", reflect.TypeOf((*MockUpdateListener)(nil).OnAdd), kv)
 }
@@ -52,7 +52,7 @@ func (m *MockUpdateListener) OnDelete(kv KV) {
 }
 // OnDelete indicates an expected call of OnDelete
-func (mr *MockUpdateListenerMockRecorder) OnDelete(kv any) *gomock.Call {
+func (mr *MockUpdateListenerMockRecorder) OnDelete(kv interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnDelete", reflect.TypeOf((*MockUpdateListener)(nil).OnDelete), kv)
 }

View File

@@ -9,6 +9,7 @@ const (
 	autoSyncInterval   = time.Minute
 	coolDownInterval   = time.Second
 	dialTimeout        = 5 * time.Second
+	dialKeepAliveTime  = 5 * time.Second
 	requestTimeout     = 3 * time.Second
 	endpointsSeparator = ","
 )

View File

@@ -1,10 +1,7 @@
 package discov
 import (
-	"context"
 	"errors"
-	"net"
-	"os"
 	"sync"
 	"testing"
 	"time"
@@ -16,83 +13,6 @@ import (
 	"github.com/zeromicro/go-zero/core/logx"
 	"github.com/zeromicro/go-zero/core/stringx"
 	clientv3 "go.etcd.io/etcd/client/v3"
-	"golang.org/x/net/http2"
-	"google.golang.org/grpc"
-	"google.golang.org/grpc/credentials/insecure"
-	"google.golang.org/grpc/resolver"
-	"google.golang.org/grpc/resolver/manual"
-)
-
-const (
-	certContent = `-----BEGIN CERTIFICATE-----
-MIIDazCCAlOgAwIBAgIUEg9GVO2oaPn+YSmiqmFIuAo10WIwDQYJKoZIhvcNAQEM
-BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
-GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAgFw0yMzAzMTExMzIxMjNaGA8yMTIz
-MDIxNTEzMjEyM1owRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx
-ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcN
-AQEBBQADggEPADCCAQoCggEBALplXlWsIf0O/IgnIplmiZHKGnxyfyufyE2FBRNk
-OofRqbKuPH8GNqbkvZm7N29fwTDAQ+mViAggCkDht4hOzoWJMA7KYJt8JnTSWL48
-M1lcrpc9DL2gszC/JF/FGvyANbBtLklkZPFBGdHUX14pjrT937wqPtm+SqUHSvRT
-B7bmwmm2drRcmhpVm98LSlV7uQ2EgnJgsLjBPITKUejLmVLHfgX0RwQ2xIpX9pS4
-FCe1BTacwl2gGp7Mje7y4Mfv3o0ArJW6Tuwbjx59ZXwb1KIP71b7bT04AVS8ZeYO
-UMLKKuB5UR9x9Rn6cLXOTWBpcMVyzDgrAFLZjnE9LPUolZMCAwEAAaNRME8wHwYD
-VR0jBBgwFoAUeW8w8pmhncbRgTsl48k4/7wnfx8wCQYDVR0TBAIwADALBgNVHQ8E
-BAMCBPAwFAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBDAUAA4IBAQAI
-y9xaoS88CLPBsX6mxfcTAFVfGNTRW9VN9Ng1cCnUR+YGoXGM/l+qP4f7p8ocdGwK
-iYZErVTzXYIn+D27//wpY3klJk3gAnEUBT3QRkStBw7XnpbeZ2oPBK+cmDnCnZPS
-BIF1wxPX7vIgaxs5Zsdqwk3qvZ4Djr2wP7LabNWTLSBKgQoUY45Liw6pffLwcGF9
-UKlu54bvGze2SufISCR3ib+I+FLvqpvJhXToZWYb/pfI/HccuCL1oot1x8vx6DQy
-U+TYxlZsKS5mdNxAX3dqEkEMsgEi+g/tzDPXJImfeCGGBhIOXLm8SRypiuGdEbc9
-xkWYxRPegajuEZGvCqVs
------END CERTIFICATE-----`
-	keyContent = `-----BEGIN RSA PRIVATE KEY-----
-MIIEowIBAAKCAQEAumVeVawh/Q78iCcimWaJkcoafHJ/K5/ITYUFE2Q6h9Gpsq48
-fwY2puS9mbs3b1/BMMBD6ZWICCAKQOG3iE7OhYkwDspgm3wmdNJYvjwzWVyulz0M
-vaCzML8kX8Ua/IA1sG0uSWRk8UEZ0dRfXimOtP3fvCo+2b5KpQdK9FMHtubCabZ2
-tFyaGlWb3wtKVXu5DYSCcmCwuME8hMpR6MuZUsd+BfRHBDbEilf2lLgUJ7UFNpzC
-XaAansyN7vLgx+/ejQCslbpO7BuPHn1lfBvUog/vVvttPTgBVLxl5g5Qwsoq4HlR
-H3H1Gfpwtc5NYGlwxXLMOCsAUtmOcT0s9SiVkwIDAQABAoIBAD5meTJNMgO55Kjg
-ESExxpRcCIno+tHr5+6rvYtEXqPheOIsmmwb9Gfi4+Z3WpOaht5/Pz0Ppj6yGzyl
-U//6AgGKb+BDuBvVcDpjwPnOxZIBCSHwejdxeQu0scSuA97MPS0XIAvJ5FEv7ijk
-5Bht6SyGYURpECltHygoTNuGgGqmO+McCJRLE9L09lTBI6UQ/JQwWJqSr7wx6iPU
-M1Ze/srIV+7cyEPu6i0DGjS1gSQKkX68Lqn1w6oE290O+OZvleO0gZ02fLDWCZke
-aeD9+EU/Pw+rqm3H6o0szOFIpzhRp41FUdW9sybB3Yp3u7c/574E+04Z/e30LMKs
-TCtE1QECgYEA3K7KIpw0NH2HXL5C3RHcLmr204xeBfS70riBQQuVUgYdmxak2ima
-80RInskY8hRhSGTg0l+VYIH8cmjcUyqMSOELS5XfRH99r4QPiK8AguXg80T4VumY
-W3Pf+zEC2ssgP/gYthV0g0Xj5m2QxktOF9tRw5nkg739ZR4dI9lm/iECgYEA2Dnf
-uwEDGqHiQRF6/fh5BG/nGVMvrefkqx6WvTJQ3k/M/9WhxB+lr/8yH46TuS8N2b29
-FoTf3Mr9T7pr/PWkOPzoY3P56nYbKU8xSwCim9xMzhBMzj8/N9ukJvXy27/VOz56
-eQaKqnvdXNGtPJrIMDGHps2KKWlKLyAlapzjVTMCgYAA/W++tACv85g13EykfT4F
-n0k4LbsGP9DP4zABQLIMyiY72eAncmRVjwrcW36XJ2xATOONTgx3gF3HjZzfaqNy
-eD/6uNNllUTVEryXGmHgNHPL45VRnn6memCY2eFvZdXhM5W4y2PYaunY0MkDercA
-+GTngbs6tBF88KOk04bYwQKBgFl68cRgsdkmnwwQYNaTKfmVGYzYaQXNzkqmWPko
-xmCJo6tHzC7ubdG8iRCYHzfmahPuuj6EdGPZuSRyYFgJi5Ftz/nAN+84OxtIQ3zn
-YWOgskQgaLh9YfsKsQ7Sf1NDOsnOnD5TX7UXl07fEpLe9vNCvAFiU8e5Y9LGudU5
-4bYTAoGBAMdX3a3bXp4cZvXNBJ/QLVyxC6fP1Q4haCR1Od3m+T00Jth2IX2dk/fl
-p6xiJT1av5JtYabv1dFKaXOS5s1kLGGuCCSKpkvFZm826aQ2AFm0XGqEQDLeei5b
-A52Kpy/YJ+RkG4BTFtAooFq6DmA0cnoP6oPvG2h6XtDJwDTPInJb
------END RSA PRIVATE KEY-----`
-	caContent = `-----BEGIN CERTIFICATE-----
-MIIDbTCCAlWgAwIBAgIUBJvFoCowKich7MMfseJ+DYzzirowDQYJKoZIhvcNAQEM
-BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
-GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAgFw0yMzAzMTExMzIxMDNaGA8yMTIz
-MDIxNTEzMjEwM1owRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx
-ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcN
-AQEBBQADggEPADCCAQoCggEBAO4to2YMYj0bxgr2FCiweSTSFuPx33zSw2x/s9Wf
-OR41bm2DFsyYT5f3sOIKlXZEdLmOKty2e3ho3yC0EyNpVHdykkkHT3aDI17quZax
-kYi/URqqtl1Z08A22txolc04hAZisg2BypGi3vql81UW1t3zyloGnJoIAeXR9uca
-ljP6Bk3bwsxoVBLi1JtHrO0hHLQaeHmKhAyrys06X0LRdn7Px48yRZlt6FaLSa8X
-YiRM0G44bVy/h6BkoQjMYGwVmCVk6zjJ9U7ZPFqdnDMNxAfR+hjDnYodqdLDMTTR
-1NPVrnEnNwFx0AMLvgt/ba/45vZCEAmSZnFXFAJJcM7ai9ECAwEAAaNTMFEwHQYD
-VR0OBBYEFHlvMPKZoZ3G0YE7JePJOP+8J38fMB8GA1UdIwQYMBaAFHlvMPKZoZ3G
-0YE7JePJOP+8J38fMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQEMBQADggEB
-AMX8dNulADOo9uQgBMyFb9TVra7iY0zZjzv4GY5XY7scd52n6CnfAPvYBBDnTr/O
-BgNp5jaujb4+9u/2qhV3f9n+/3WOb2CmPehBgVSzlXqHeQ9lshmgwZPeem2T+8Tm
-Nnc/xQnsUfCFszUDxpkr55+aLVM22j02RWqcZ4q7TAaVYL+kdFVMc8FoqG/0ro6A
-BjE/Qn0Nn7ciX1VUjDt8l+k7ummPJTmzdi6i6E4AwO9dzrGNgGJ4aWL8cC6xYcIX
-goVIRTFeONXSDno/oPjWHpIPt7L15heMpKBHNuzPkKx2YVqPHE5QZxWfS+Lzgx+Q
-E2oTTM0rYKOZ8p6000mhvKI=
------END CERTIFICATE-----`
 )
 func init() {
@@ -117,7 +37,7 @@ func TestPublisher_register(t *testing.T) {
 	assert.Nil(t, err)
 }
-func TestPublisher_registerWithOptions(t *testing.T) {
+func TestPublisher_registerWithId(t *testing.T) {
 	ctrl := gomock.NewController(t)
 	defer ctrl.Finish()
 	const id = 2
@@ -129,15 +49,7 @@ func TestPublisher_registerWithOptions(t *testing.T) {
 		ID: 1,
 	}, nil)
 	cli.EXPECT().Put(gomock.Any(), makeEtcdKey("thekey", id), "thevalue", gomock.Any())
-	certFile := createTempFile(t, []byte(certContent))
-	defer os.Remove(certFile)
-	keyFile := createTempFile(t, []byte(keyContent))
-	defer os.Remove(keyFile)
-	caFile := createTempFile(t, []byte(caContent))
-	defer os.Remove(caFile)
-	pub := NewPublisher(nil, "thekey", "thevalue", WithId(id),
-		WithPubEtcdTLS(certFile, keyFile, caFile, true))
+	pub := NewPublisher(nil, "thekey", "thevalue", WithId(id))
 	_, err := pub.register(cli)
 	assert.Nil(t, err)
 }
@@ -213,7 +125,7 @@ func TestPublisher_keepAliveAsyncQuit(t *testing.T) {
 	cli.EXPECT().KeepAlive(gomock.Any(), id)
 	var wg sync.WaitGroup
 	wg.Add(1)
-	cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) {
+	cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ interface{}) {
 		wg.Done()
 	})
 	pub := NewPublisher(nil, "thekey", "thevalue")
@@ -235,7 +147,7 @@ func TestPublisher_keepAliveAsyncPause(t *testing.T) {
 	pub := NewPublisher(nil, "thekey", "thevalue")
 	var wg sync.WaitGroup
 	wg.Add(1)
-	cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) {
+	cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ interface{}) {
 		pub.Stop()
 		wg.Done()
 	})
@@ -257,92 +169,3 @@ func TestPublisher_Resume(t *testing.T) {
 	}()
 	<-publisher.resumeChan
 }
-
-func TestPublisher_keepAliveAsync(t *testing.T) {
-	ctrl := gomock.NewController(t)
-	defer ctrl.Finish()
-	const id clientv3.LeaseID = 1
-	conn := createMockConn(t)
-	defer conn.Close()
-	cli := internal.NewMockEtcdClient(ctrl)
-	cli.EXPECT().ActiveConnection().Return(conn).AnyTimes()
-	cli.EXPECT().Close()
-	defer cli.Close()
-	cli.ActiveConnection()
-	restore := setMockClient(cli)
-	defer restore()
-	cli.EXPECT().Ctx().AnyTimes()
-	cli.EXPECT().KeepAlive(gomock.Any(), id)
-	cli.EXPECT().Grant(gomock.Any(), timeToLive).Return(&clientv3.LeaseGrantResponse{
-		ID: 1,
-	}, nil)
-	cli.EXPECT().Put(gomock.Any(), makeEtcdKey("thekey", int64(id)), "thevalue", gomock.Any())
-	var wg sync.WaitGroup
-	wg.Add(1)
-	cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) {
-		wg.Done()
-	})
-	pub := NewPublisher([]string{"the-endpoint"}, "thekey", "thevalue")
-	pub.lease = id
-	assert.Nil(t, pub.KeepAlive())
-	pub.Stop()
-	wg.Wait()
-}
-
-func createMockConn(t *testing.T) *grpc.ClientConn {
-	lis, err := net.Listen("tcp", "localhost:0")
-	if err != nil {
-		t.Fatalf("Error while listening. Err: %v", err)
-	}
-	defer lis.Close()
-	lisAddr := resolver.Address{Addr: lis.Addr().String()}
-	lisDone := make(chan struct{})
-	dialDone := make(chan struct{})
-	// 1st listener accepts the connection and then does nothing
-	go func() {
-		defer close(lisDone)
-		conn, err := lis.Accept()
-		if err != nil {
-			t.Errorf("Error while accepting. Err: %v", err)
-			return
-		}
-		framer := http2.NewFramer(conn, conn)
-		if err := framer.WriteSettings(http2.Setting{}); err != nil {
-			t.Errorf("Error while writing settings. Err: %v", err)
-			return
-		}
-		<-dialDone // Close conn only after dial returns.
-	}()
-	r := manual.NewBuilderWithScheme("whatever")
-	r.InitialState(resolver.State{Addresses: []resolver.Address{lisAddr}})
-	client, err := grpc.DialContext(context.Background(), r.Scheme()+":///test.server",
-		grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(r))
-	close(dialDone)
-	if err != nil {
-		t.Fatalf("Dial failed. Err: %v", err)
-	}
-	timeout := time.After(1 * time.Second)
-	select {
-	case <-timeout:
-		t.Fatal("timed out waiting for server to finish")
-	case <-lisDone:
-	}
-	return client
-}
-
-func createTempFile(t *testing.T, body []byte) string {
-	tmpFile, err := os.CreateTemp(os.TempDir(), "go-unit-*.tmp")
-	if err != nil {
-		t.Fatal(err)
-	}
-	tmpFile.Close()
-	if err = os.WriteFile(tmpFile.Name(), body, os.ModePerm); err != nil {
-		t.Fatal(err)
-	}
-	return tmpFile.Name()
-}

View File

@@ -1,15 +1,11 @@
 package errorx
-import (
-	"bytes"
-	"sync"
-)
+import "bytes"
 type (
 	// A BatchError is an error that can hold multiple errors.
 	BatchError struct {
 		errs errorArray
-		lock sync.Mutex
 	}
 	errorArray []error
@@ -17,9 +13,6 @@ type (
 // Add adds errs to be, nil errors are ignored.
 func (be *BatchError) Add(errs ...error) {
-	be.lock.Lock()
-	defer be.lock.Unlock()
-
 	for _, err := range errs {
 		if err != nil {
 			be.errs = append(be.errs, err)
@@ -29,9 +22,6 @@ func (be *BatchError) Add(errs ...error) {
 // Err returns an error that represents all errors.
 func (be *BatchError) Err() error {
-	be.lock.Lock()
-	defer be.lock.Unlock()
-
 	switch len(be.errs) {
 	case 0:
 		return nil
@@ -44,9 +34,6 @@ func (be *BatchError) Err() error {
 // NotNil checks if any error inside.
 func (be *BatchError) NotNil() bool {
-	be.lock.Lock()
-	defer be.lock.Unlock()
-
 	return len(be.errs) > 0
 }

View File

@@ -3,7 +3,6 @@ package errorx
 import (
 	"errors"
 	"fmt"
-	"sync"
 	"testing"
 	"github.com/stretchr/testify/assert"
@@ -34,7 +33,7 @@ func TestBatchErrorNilFromFunc(t *testing.T) {
 func TestBatchErrorOneError(t *testing.T) {
 	var batch BatchError
 	batch.Add(errors.New(err1))
-	assert.NotNil(t, batch.Err())
+	assert.NotNil(t, batch)
 	assert.Equal(t, err1, batch.Err().Error())
 	assert.True(t, batch.NotNil())
 }
@@ -43,26 +42,7 @@ func TestBatchErrorWithErrors(t *testing.T) {
 	var batch BatchError
 	batch.Add(errors.New(err1))
 	batch.Add(errors.New(err2))
-	assert.NotNil(t, batch.Err())
+	assert.NotNil(t, batch)
 	assert.Equal(t, fmt.Sprintf("%s\n%s", err1, err2), batch.Err().Error())
 	assert.True(t, batch.NotNil())
 }
-
-func TestBatchErrorConcurrentAdd(t *testing.T) {
-	const count = 10000
-	var batch BatchError
-	var wg sync.WaitGroup
-	wg.Add(count)
-	for i := 0; i < count; i++ {
-		go func() {
-			defer wg.Done()
-			batch.Add(errors.New(err1))
-		}()
-	}
-	wg.Wait()
-	assert.NotNil(t, batch.Err())
-	assert.Equal(t, count, len(batch.errs))
-	assert.True(t, batch.NotNil())
-}

View File

@@ -12,7 +12,7 @@ func Wrap(err error, message string) error {
 }
 // Wrapf returns an error that wraps err with given format and args.
-func Wrapf(err error, format string, args ...any) error {
+func Wrapf(err error, format string, args ...interface{}) error {
 	if err == nil {
 		return nil
 	}

View File

@@ -42,7 +42,7 @@ func NewBulkExecutor(execute Execute, opts ...BulkOption) *BulkExecutor {
 }
 // Add adds task into be.
-func (be *BulkExecutor) Add(task any) error {
+func (be *BulkExecutor) Add(task interface{}) error {
 	be.executor.Add(task)
 	return nil
 }
@@ -79,22 +79,22 @@ func newBulkOptions() bulkOptions {
 }
 type bulkContainer struct {
-	tasks    []any
+	tasks    []interface{}
 	execute  Execute
 	maxTasks int
 }
-func (bc *bulkContainer) AddTask(task any) bool {
+func (bc *bulkContainer) AddTask(task interface{}) bool {
 	bc.tasks = append(bc.tasks, task)
 	return len(bc.tasks) >= bc.maxTasks
 }
-func (bc *bulkContainer) Execute(tasks any) {
-	vals := tasks.([]any)
+func (bc *bulkContainer) Execute(tasks interface{}) {
+	vals := tasks.([]interface{})
 	bc.execute(vals)
 }
-func (bc *bulkContainer) RemoveAll() any {
+func (bc *bulkContainer) RemoveAll() interface{} {
 	tasks := bc.tasks
 	bc.tasks = nil
 	return tasks
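The container protocol in the hunk above is the heart of the executor: `AddTask` reports when the buffer reaches `maxTasks` so the caller knows to flush, and `RemoveAll` hands the buffered tasks over while resetting the container. A minimal standalone version of that count-based container:

```go
package main

import "fmt"

// bulkContainer mirrors the container in the diff, minus the execute
// callback: AddTask returns true once the buffer holds maxTasks items,
// and RemoveAll drains the buffer so the next batch starts empty.
type bulkContainer struct {
	tasks    []interface{}
	maxTasks int
}

func (bc *bulkContainer) AddTask(task interface{}) bool {
	bc.tasks = append(bc.tasks, task)
	return len(bc.tasks) >= bc.maxTasks
}

func (bc *bulkContainer) RemoveAll() interface{} {
	tasks := bc.tasks
	bc.tasks = nil
	return tasks
}

func main() {
	bc := &bulkContainer{maxTasks: 3}
	fmt.Println(bc.AddTask(1)) // false
	fmt.Println(bc.AddTask(2)) // false
	fmt.Println(bc.AddTask(3)) // true: time to flush
	flushed := bc.RemoveAll().([]interface{})
	fmt.Println(len(flushed), len(bc.tasks)) // 3 0
}
```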

View File

@@ -12,7 +12,7 @@ func TestBulkExecutor(t *testing.T) {
 	var values []int
 	var lock sync.Mutex
-	executor := NewBulkExecutor(func(items []any) {
+	executor := NewBulkExecutor(func(items []interface{}) {
 		lock.Lock()
 		values = append(values, len(items))
 		lock.Unlock()
@@ -40,7 +40,7 @@ func TestBulkExecutorFlushInterval(t *testing.T) {
 	var wait sync.WaitGroup
 	wait.Add(1)
-	executor := NewBulkExecutor(func(items []any) {
+	executor := NewBulkExecutor(func(items []interface{}) {
 		assert.Equal(t, size, len(items))
 		wait.Done()
 	}, WithBulkTasks(caches), WithBulkInterval(time.Millisecond*100))
@@ -53,7 +53,7 @@ func TestBulkExecutorFlushInterval(t *testing.T) {
 }
 func TestBulkExecutorEmpty(t *testing.T) {
-	NewBulkExecutor(func(items []any) {
+	NewBulkExecutor(func(items []interface{}) {
 		assert.Fail(t, "should not called")
 	}, WithBulkTasks(10), WithBulkInterval(time.Millisecond))
 	time.Sleep(time.Millisecond * 100)
@@ -67,7 +67,7 @@ func TestBulkExecutorFlush(t *testing.T) {
 	var wait sync.WaitGroup
 	wait.Add(1)
-	be := NewBulkExecutor(func(items []any) {
+	be := NewBulkExecutor(func(items []interface{}) {
 		assert.Equal(t, tasks, len(items))
 		wait.Done()
 	}, WithBulkTasks(caches), WithBulkInterval(time.Minute))
@@ -78,11 +78,11 @@ func TestBulkExecutorFlush(t *testing.T) {
 	wait.Wait()
 }
-func TestBulkExecutorFlushSlowTasks(t *testing.T) {
+func TestBuldExecutorFlushSlowTasks(t *testing.T) {
 	const total = 1500
 	lock := new(sync.Mutex)
-	result := make([]any, 0, 10000)
-	exec := NewBulkExecutor(func(tasks []any) {
+	result := make([]interface{}, 0, 10000)
+	exec := NewBulkExecutor(func(tasks []interface{}) {
 		time.Sleep(time.Millisecond * 100)
 		lock.Lock()
 		defer lock.Unlock()
@@ -100,7 +100,7 @@ func TestBulkExecutorFlushSlowTasks(t *testing.T) {
 func BenchmarkBulkExecutor(b *testing.B) {
 	b.ReportAllocs()
-	be := NewBulkExecutor(func(tasks []any) {
+	be := NewBulkExecutor(func(tasks []interface{}) {
 		time.Sleep(time.Millisecond * time.Duration(len(tasks)))
 	})
 	for i := 0; i < b.N; i++ {

View File

@@ -42,7 +42,7 @@ func NewChunkExecutor(execute Execute, opts ...ChunkOption) *ChunkExecutor {
 }
 // Add adds task with given chunk size into ce.
-func (ce *ChunkExecutor) Add(task any, size int) error {
+func (ce *ChunkExecutor) Add(task interface{}, size int) error {
 	ce.executor.Add(chunk{
 		val:  task,
 		size: size,
@@ -82,25 +82,25 @@ func newChunkOptions() chunkOptions {
 }
 type chunkContainer struct {
-	tasks        []any
+	tasks        []interface{}
 	execute      Execute
 	size         int
 	maxChunkSize int
 }
-func (bc *chunkContainer) AddTask(task any) bool {
+func (bc *chunkContainer) AddTask(task interface{}) bool {
 	ck := task.(chunk)
 	bc.tasks = append(bc.tasks, ck.val)
 	bc.size += ck.size
 	return bc.size >= bc.maxChunkSize
 }
-func (bc *chunkContainer) Execute(tasks any) {
-	vals := tasks.([]any)
+func (bc *chunkContainer) Execute(tasks interface{}) {
+	vals := tasks.([]interface{})
 	bc.execute(vals)
 }
-func (bc *chunkContainer) RemoveAll() any {
+func (bc *chunkContainer) RemoveAll() interface{} {
 	tasks := bc.tasks
 	bc.tasks = nil
 	bc.size = 0
@@ -108,6 +108,6 @@ func (bc *chunkContainer) RemoveAll() any {
 }
 type chunk struct {
-	val  any
+	val  interface{}
 	size int
 }
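Unlike the bulk container, which counts tasks, the chunk container above accumulates a per-task size (typically bytes) and requests a flush once the total crosses `maxChunkSize`. A minimal standalone version of that size-based container:

```go
package main

import "fmt"

// chunk pairs a task with its size contribution, as in the diff.
type chunk struct {
	val  interface{}
	size int
}

// chunkContainer mirrors the container in the diff, minus the execute
// callback: AddTask returns true once the accumulated size reaches the
// threshold, and RemoveAll drains the buffer and resets the size counter.
type chunkContainer struct {
	tasks        []interface{}
	size         int
	maxChunkSize int
}

func (cc *chunkContainer) AddTask(task interface{}) bool {
	ck := task.(chunk)
	cc.tasks = append(cc.tasks, ck.val)
	cc.size += ck.size
	return cc.size >= cc.maxChunkSize
}

func (cc *chunkContainer) RemoveAll() interface{} {
	tasks := cc.tasks
	cc.tasks = nil
	cc.size = 0
	return tasks
}

func main() {
	cc := &chunkContainer{maxChunkSize: 100}
	fmt.Println(cc.AddTask(chunk{val: "a", size: 40}))        // false
	fmt.Println(cc.AddTask(chunk{val: "b", size: 70}))        // true: 110 >= 100
	fmt.Println(len(cc.RemoveAll().([]interface{})), cc.size) // 2 0
}
```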

View File

@@ -12,7 +12,7 @@ func TestChunkExecutor(t *testing.T) {
 	var values []int
 	var lock sync.Mutex
-	executor := NewChunkExecutor(func(items []any) {
+	executor := NewChunkExecutor(func(items []interface{}) {
 		lock.Lock()
 		values = append(values, len(items))
 		lock.Unlock()
@@ -40,7 +40,7 @@ func TestChunkExecutorFlushInterval(t *testing.T) {
 	var wait sync.WaitGroup
 	wait.Add(1)
-	executor := NewChunkExecutor(func(items []any) {
+	executor := NewChunkExecutor(func(items []interface{}) {
 		assert.Equal(t, size, len(items))
 		wait.Done()
 	}, WithChunkBytes(caches), WithFlushInterval(time.Millisecond*100))
@@ -53,7 +53,7 @@ func TestChunkExecutorFlushInterval(t *testing.T) {
 }
 func TestChunkExecutorEmpty(t *testing.T) {
-	executor := NewChunkExecutor(func(items []any) {
+	executor := NewChunkExecutor(func(items []interface{}) {
 		assert.Fail(t, "should not called")
 	}, WithChunkBytes(10), WithFlushInterval(time.Millisecond))
 	time.Sleep(time.Millisecond * 100)
@@ -68,7 +68,7 @@ func TestChunkExecutorFlush(t *testing.T) {
 	var wait sync.WaitGroup
 	wait.Add(1)
-	be := NewChunkExecutor(func(items []any) {
+	be := NewChunkExecutor(func(items []interface{}) {
 		assert.Equal(t, tasks, len(items))
 		wait.Done()
 	}, WithChunkBytes(caches), WithFlushInterval(time.Minute))
@@ -82,7 +82,7 @@ func TestChunkExecutorFlush(t *testing.T) {
 func BenchmarkChunkExecutor(b *testing.B) {
 	b.ReportAllocs()
-	be := NewChunkExecutor(func(tasks []any) {
+	be := NewChunkExecutor(func(tasks []interface{}) {
 		time.Sleep(time.Millisecond * time.Duration(len(tasks)))
 	})
 	for i := 0; i < b.N; i++ {

View File

@@ -21,16 +21,16 @@ type (
     TaskContainer interface {
         // AddTask adds the task into the container.
         // Returns true if the container needs to be flushed after the addition.
-        AddTask(task any) bool
+        AddTask(task interface{}) bool
         // Execute handles the collected tasks by the container when flushing.
-        Execute(tasks any)
+        Execute(tasks interface{})
         // RemoveAll removes the contained tasks, and return them.
-        RemoveAll() any
+        RemoveAll() interface{}
     }

     // A PeriodicalExecutor is an executor that periodically execute tasks.
     PeriodicalExecutor struct {
-        commander chan any
+        commander chan interface{}
         interval  time.Duration
         container TaskContainer
         waitGroup sync.WaitGroup
@@ -48,7 +48,7 @@ type (
 func NewPeriodicalExecutor(interval time.Duration, container TaskContainer) *PeriodicalExecutor {
     executor := &PeriodicalExecutor{
         // buffer 1 to let the caller go quickly
-        commander:   make(chan any, 1),
+        commander:   make(chan interface{}, 1),
         interval:    interval,
         container:   container,
         confirmChan: make(chan lang.PlaceholderType),
@@ -64,7 +64,7 @@ func NewPeriodicalExecutor(interval time.Duration, container TaskContainer) *Per
 }

 // Add adds tasks into pe.
-func (pe *PeriodicalExecutor) Add(task any) {
+func (pe *PeriodicalExecutor) Add(task interface{}) {
     if vals, ok := pe.addAndCheck(task); ok {
         pe.commander <- vals
         <-pe.confirmChan
@@ -74,14 +74,14 @@ func (pe *PeriodicalExecutor) Add(task any) {
 // Flush forces pe to execute tasks.
 func (pe *PeriodicalExecutor) Flush() bool {
     pe.enterExecution()
-    return pe.executeTasks(func() any {
+    return pe.executeTasks(func() interface{} {
         pe.lock.Lock()
         defer pe.lock.Unlock()
         return pe.container.RemoveAll()
     }())
 }

-// Sync lets caller run fn thread-safe with pe, especially for the underlying container.
+// Sync lets caller to run fn thread-safe with pe, especially for the underlying container.
 func (pe *PeriodicalExecutor) Sync(fn func()) {
     pe.lock.Lock()
     defer pe.lock.Unlock()
@@ -96,7 +96,7 @@ func (pe *PeriodicalExecutor) Wait() {
     })
 }

-func (pe *PeriodicalExecutor) addAndCheck(task any) (any, bool) {
+func (pe *PeriodicalExecutor) addAndCheck(task interface{}) (interface{}, bool) {
     pe.lock.Lock()
     defer func() {
         if !pe.guarded {
@@ -116,7 +116,7 @@ func (pe *PeriodicalExecutor) addAndCheck(task any) (any, bool) {
 }

 func (pe *PeriodicalExecutor) backgroundFlush() {
-    go func() {
+    threading.GoSafe(func() {
         // flush before quit goroutine to avoid missing tasks
         defer pe.Flush()
@@ -144,7 +144,7 @@ func (pe *PeriodicalExecutor) backgroundFlush() {
             }
         }
     }
-    }()
+    })
 }

 func (pe *PeriodicalExecutor) doneExecution() {
@@ -157,20 +157,18 @@ func (pe *PeriodicalExecutor) enterExecution() {
     })
 }

-func (pe *PeriodicalExecutor) executeTasks(tasks any) bool {
+func (pe *PeriodicalExecutor) executeTasks(tasks interface{}) bool {
     defer pe.doneExecution()

     ok := pe.hasTasks(tasks)
     if ok {
-        threading.RunSafe(func() {
-            pe.container.Execute(tasks)
-        })
+        pe.container.Execute(tasks)
     }

     return ok
 }

-func (pe *PeriodicalExecutor) hasTasks(tasks any) bool {
+func (pe *PeriodicalExecutor) hasTasks(tasks interface{}) bool {
     if tasks == nil {
         return false
     }


@@ -17,22 +17,22 @@ const threshold = 10
 type container struct {
     interval time.Duration
     tasks    []int
-    execute  func(tasks any)
+    execute  func(tasks interface{})
 }

-func newContainer(interval time.Duration, execute func(tasks any)) *container {
+func newContainer(interval time.Duration, execute func(tasks interface{})) *container {
     return &container{
         interval: interval,
         execute:  execute,
     }
 }

-func (c *container) AddTask(task any) bool {
+func (c *container) AddTask(task interface{}) bool {
     c.tasks = append(c.tasks, task.(int))
     return len(c.tasks) > threshold
 }

-func (c *container) Execute(tasks any) {
+func (c *container) Execute(tasks interface{}) {
     if c.execute != nil {
         c.execute(tasks)
     } else {
@@ -40,7 +40,7 @@ func (c *container) Execute(tasks any) {
     }
 }

-func (c *container) RemoveAll() any {
+func (c *container) RemoveAll() interface{} {
     tasks := c.tasks
     c.tasks = nil
     return tasks
@@ -76,7 +76,7 @@ func TestPeriodicalExecutor_Bulk(t *testing.T) {
     var vals []int
     // avoid data race
     var lock sync.Mutex
-    exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, func(tasks any) {
+    exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, func(tasks interface{}) {
         t := tasks.([]int)
         for _, each := range t {
             lock.Lock()
@@ -108,83 +108,25 @@ func TestPeriodicalExecutor_Bulk(t *testing.T) {
     lock.Unlock()
 }

-func TestPeriodicalExecutor_Panic(t *testing.T) {
-    // avoid data race
-    var lock sync.Mutex
-    ticker := timex.NewFakeTicker()
-
-    var (
-        executedTasks []int
-        expected      []int
-    )
-    executor := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, func(tasks any) {
-        tt := tasks.([]int)
-        lock.Lock()
-        executedTasks = append(executedTasks, tt...)
-        lock.Unlock()
-        if tt[0] == 0 {
-            panic("test")
-        }
-    }))
-    executor.newTicker = func(duration time.Duration) timex.Ticker {
-        return ticker
-    }
-    for i := 0; i < 30; i++ {
-        executor.Add(i)
-        expected = append(expected, i)
-    }
-    ticker.Tick()
-    ticker.Tick()
-    time.Sleep(time.Millisecond)
-    lock.Lock()
-    assert.Equal(t, expected, executedTasks)
-    lock.Unlock()
-}
-
-func TestPeriodicalExecutor_FlushPanic(t *testing.T) {
-    var (
-        executedTasks []int
-        expected      []int
-        lock          sync.Mutex
-    )
-    executor := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, func(tasks any) {
-        tt := tasks.([]int)
-        lock.Lock()
-        executedTasks = append(executedTasks, tt...)
-        lock.Unlock()
-        if tt[0] == 0 {
-            panic("flush panic")
-        }
-    }))
-    for i := 0; i < 8; i++ {
-        executor.Add(i)
-        expected = append(expected, i)
-    }
-    executor.Flush()
-    lock.Lock()
-    assert.Equal(t, expected, executedTasks)
-    lock.Unlock()
-}
-
 func TestPeriodicalExecutor_Wait(t *testing.T) {
     var lock sync.Mutex
-    executor := NewBulkExecutor(func(tasks []any) {
+    executer := NewBulkExecutor(func(tasks []interface{}) {
         lock.Lock()
         defer lock.Unlock()
         time.Sleep(10 * time.Millisecond)
     }, WithBulkTasks(1), WithBulkInterval(time.Second))
     for i := 0; i < 10; i++ {
-        executor.Add(1)
+        executer.Add(1)
     }
-    executor.Flush()
-    executor.Wait()
+    executer.Flush()
+    executer.Wait()
 }

 func TestPeriodicalExecutor_WaitFast(t *testing.T) {
     const total = 3
     var cnt int
     var lock sync.Mutex
-    executor := NewBulkExecutor(func(tasks []any) {
+    executer := NewBulkExecutor(func(tasks []interface{}) {
         defer func() {
             cnt++
         }()
@@ -193,15 +135,15 @@ func TestPeriodicalExecutor_WaitFast(t *testing.T) {
         time.Sleep(10 * time.Millisecond)
     }, WithBulkTasks(1), WithBulkInterval(10*time.Millisecond))
     for i := 0; i < total; i++ {
-        executor.Add(2)
+        executer.Add(2)
     }
-    executor.Flush()
-    executor.Wait()
+    executer.Flush()
+    executer.Wait()
     assert.Equal(t, total, cnt)
 }

 func TestPeriodicalExecutor_Deadlock(t *testing.T) {
-    executor := NewBulkExecutor(func(tasks []any) {
+    executor := NewBulkExecutor(func(tasks []interface{}) {
     }, WithBulkTasks(1), WithBulkInterval(time.Millisecond))
     for i := 0; i < 1e5; i++ {
         executor.Add(1)
@@ -209,7 +151,13 @@ func TestPeriodicalExecutor_Deadlock(t *testing.T) {
 }

 func TestPeriodicalExecutor_hasTasks(t *testing.T) {
+    ticker := timex.NewFakeTicker()
+    defer ticker.Stop()
     exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, nil))
+    exec.newTicker = func(d time.Duration) timex.Ticker {
+        return ticker
+    }
     assert.False(t, exec.hasTasks(nil))
     assert.True(t, exec.hasTasks(1))
 }


@@ -5,4 +5,4 @@ import "time"
 const defaultFlushInterval = time.Second

 // Execute defines the method to execute tasks.
-type Execute func(tasks []any)
+type Execute func(tasks []interface{})


@@ -74,11 +74,6 @@ func TestFirstLineShort(t *testing.T) {
     assert.Equal(t, "first line", val)
 }

-func TestFirstLineError(t *testing.T) {
-    _, err := FirstLine("/tmp/does-not-exist")
-    assert.Error(t, err)
-}
-
 func TestLastLine(t *testing.T) {
     filename, err := fs.TempFilenameWithText(text)
     assert.Nil(t, err)
@@ -118,8 +113,3 @@ func TestLastLineWithLastNewlineShort(t *testing.T) {
     assert.Nil(t, err)
     assert.Equal(t, "last line", val)
 }
-
-func TestLastLineError(t *testing.T) {
-    _, err := LastLine("/tmp/does-not-exist")
-    assert.Error(t, err)
-}


@@ -5,7 +5,7 @@ import "gopkg.in/cheggaaa/pb.v1"
 type (
     // A Scanner is used to read lines.
     Scanner interface {
-        // Scan checks if it has remaining to read.
+        // Scan checks if has remaining to read.
         Scan() bool
         // Text returns next line.
         Text() string


@@ -1,4 +1,5 @@
 //go:build windows
+// +build windows

 package fs


@@ -1,4 +1,5 @@
 //go:build linux || darwin
+// +build linux darwin

 package fs
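Both hunks above add the legacy `// +build` line next to the Go 1.17 `//go:build` line so these files still compile on pre-1.17 toolchains. The two constraints must express the same condition, and a blank line must separate them from the `package` clause, e.g.:

```go
//go:build linux || darwin
// +build linux darwin

package fs
```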


@@ -11,29 +11,29 @@ import (
 // The file is kept as open, the caller should close the file handle,
 // and remove the file by name.
 func TempFileWithText(text string) (*os.File, error) {
-    tmpFile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text)))
+    tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text)))
     if err != nil {
         return nil, err
     }

-    if err := os.WriteFile(tmpFile.Name(), []byte(text), os.ModeTemporary); err != nil {
+    if err := os.WriteFile(tmpfile.Name(), []byte(text), os.ModeTemporary); err != nil {
         return nil, err
     }

-    return tmpFile, nil
+    return tmpfile, nil
 }

 // TempFilenameWithText creates the file with the given content,
 // and returns the filename (full path).
 // The caller should remove the file after use.
 func TempFilenameWithText(text string) (string, error) {
-    tmpFile, err := TempFileWithText(text)
+    tmpfile, err := TempFileWithText(text)
     if err != nil {
         return "", err
     }

-    filename := tmpFile.Name()
-    if err = tmpFile.Close(); err != nil {
+    filename := tmpfile.Name()
+    if err = tmpfile.Close(); err != nil {
         return "", err
     }


@@ -1,12 +1,6 @@
 package fx

-import (
-    "context"
-    "errors"
-    "time"
-
-    "github.com/zeromicro/go-zero/core/errorx"
-)
+import "github.com/zeromicro/go-zero/core/errorx"

 const defaultRetryTimes = 3
@@ -15,110 +9,36 @@ type (
     RetryOption func(*retryOptions)

     retryOptions struct {
         times int
-        interval     time.Duration
-        timeout      time.Duration
-        ignoreErrors []error
     }
 )

 // DoWithRetry runs fn, and retries if failed. Default to retry 3 times.
-// Note that if the fn function accesses global variables outside the function
-// and performs modification operations, it is best to lock them,
-// otherwise there may be data race issues
 func DoWithRetry(fn func() error, opts ...RetryOption) error {
-    return retry(context.Background(), func(errChan chan error, retryCount int) {
-        errChan <- fn()
-    }, opts...)
-}
-
-// DoWithRetryCtx runs fn, and retries if failed. Default to retry 3 times.
-// fn retryCount indicates the current number of retries, starting from 0
-// Note that if the fn function accesses global variables outside the function
-// and performs modification operations, it is best to lock them,
-// otherwise there may be data race issues
-func DoWithRetryCtx(ctx context.Context, fn func(ctx context.Context, retryCount int) error,
-    opts ...RetryOption) error {
-    return retry(ctx, func(errChan chan error, retryCount int) {
-        errChan <- fn(ctx, retryCount)
-    }, opts...)
-}
-
-func retry(ctx context.Context, fn func(errChan chan error, retryCount int), opts ...RetryOption) error {
     options := newRetryOptions()
     for _, opt := range opts {
         opt(options)
     }

     var berr errorx.BatchError
-    var cancelFunc context.CancelFunc
-    if options.timeout > 0 {
-        ctx, cancelFunc = context.WithTimeout(ctx, options.timeout)
-        defer cancelFunc()
-    }
-
-    errChan := make(chan error, 1)
     for i := 0; i < options.times; i++ {
-        go fn(errChan, i)
-
-        select {
-        case err := <-errChan:
-            if err != nil {
-                for _, ignoreErr := range options.ignoreErrors {
-                    if errors.Is(err, ignoreErr) {
-                        return nil
-                    }
-                }
-                berr.Add(err)
-            } else {
-                return nil
-            }
-        case <-ctx.Done():
-            berr.Add(ctx.Err())
-            return berr.Err()
-        }
-
-        if options.interval > 0 {
-            select {
-            case <-ctx.Done():
-                berr.Add(ctx.Err())
-                return berr.Err()
-            case <-time.After(options.interval):
-            }
-        }
+        if err := fn(); err != nil {
+            berr.Add(err)
+        } else {
+            return nil
        }
     }

     return berr.Err()
 }

-// WithIgnoreErrors Ignore the specified errors
-func WithIgnoreErrors(ignoreErrors []error) RetryOption {
-    return func(options *retryOptions) {
-        options.ignoreErrors = ignoreErrors
-    }
-}
-
-// WithInterval customizes a DoWithRetry call with given interval.
-func WithInterval(interval time.Duration) RetryOption {
-    return func(options *retryOptions) {
-        options.interval = interval
-    }
-}
-
-// WithRetry customizes a DoWithRetry call with given retry times.
+// WithRetry customize a DoWithRetry call with given retry times.
 func WithRetry(times int) RetryOption {
     return func(options *retryOptions) {
         options.times = times
     }
 }

-// WithTimeout customizes a DoWithRetry call with given timeout.
-func WithTimeout(timeout time.Duration) RetryOption {
-    return func(options *retryOptions) {
-        options.timeout = timeout
-    }
-}
-
 func newRetryOptions() *retryOptions {
     return &retryOptions{
         times: defaultRetryTimes,

@@ -1,10 +1,8 @@
 package fx

 import (
-    "context"
     "errors"
     "testing"
-    "time"

     "github.com/stretchr/testify/assert"
 )
@@ -14,153 +12,31 @@ func TestRetry(t *testing.T) {
         return errors.New("any")
     }))

-    times1 := 0
+    var times int
     assert.Nil(t, DoWithRetry(func() error {
-        times1++
-        if times1 == defaultRetryTimes {
+        times++
+        if times == defaultRetryTimes {
             return nil
         }
         return errors.New("any")
     }))

-    times2 := 0
+    times = 0
     assert.NotNil(t, DoWithRetry(func() error {
-        times2++
-        if times2 == defaultRetryTimes+1 {
+        times++
+        if times == defaultRetryTimes+1 {
             return nil
         }
         return errors.New("any")
     }))

     total := 2 * defaultRetryTimes
-    times3 := 0
+    times = 0
     assert.Nil(t, DoWithRetry(func() error {
-        times3++
-        if times3 == total {
+        times++
+        if times == total {
             return nil
         }
         return errors.New("any")
     }, WithRetry(total)))
 }
-
-func TestRetryWithTimeout(t *testing.T) {
-    assert.Nil(t, DoWithRetry(func() error {
-        return nil
-    }, WithTimeout(time.Millisecond*500)))
-
-    times1 := 0
-    assert.Nil(t, DoWithRetry(func() error {
-        times1++
-        if times1 == 1 {
-            return errors.New("any ")
-        }
-        time.Sleep(time.Millisecond * 150)
-        return nil
-    }, WithTimeout(time.Millisecond*250)))
-
-    total := defaultRetryTimes
-    times2 := 0
-    assert.Nil(t, DoWithRetry(func() error {
-        times2++
-        if times2 == total {
-            return nil
-        }
-        time.Sleep(time.Millisecond * 50)
-        return errors.New("any")
-    }, WithTimeout(time.Millisecond*50*(time.Duration(total)+2))))
-
-    assert.NotNil(t, DoWithRetry(func() error {
-        return errors.New("any")
-    }, WithTimeout(time.Millisecond*250)))
-}
-
-func TestRetryWithInterval(t *testing.T) {
-    times1 := 0
-    assert.NotNil(t, DoWithRetry(func() error {
-        times1++
-        if times1 == 1 {
-            return errors.New("any")
-        }
-        time.Sleep(time.Millisecond * 150)
-        return nil
-    }, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
-
-    times2 := 0
-    assert.NotNil(t, DoWithRetry(func() error {
-        times2++
-        if times2 == 2 {
-            return nil
-        }
-        time.Sleep(time.Millisecond * 150)
-        return errors.New("any ")
-    }, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
-}
-
-func TestRetryWithWithIgnoreErrors(t *testing.T) {
-    ignoreErr1 := errors.New("ignore error1")
-    ignoreErr2 := errors.New("ignore error2")
-    ignoreErrs := []error{ignoreErr1, ignoreErr2}
-
-    assert.Nil(t, DoWithRetry(func() error {
-        return ignoreErr1
-    }, WithIgnoreErrors(ignoreErrs)))
-
-    assert.Nil(t, DoWithRetry(func() error {
-        return ignoreErr2
-    }, WithIgnoreErrors(ignoreErrs)))
-
-    assert.NotNil(t, DoWithRetry(func() error {
-        return errors.New("any")
-    }))
-}
-
-func TestRetryCtx(t *testing.T) {
-    t.Run("with timeout", func(t *testing.T) {
-        assert.NotNil(t, DoWithRetryCtx(context.Background(), func(ctx context.Context, retryCount int) error {
-            if retryCount == 0 {
-                return errors.New("any")
-            }
-            time.Sleep(time.Millisecond * 150)
-            return nil
-        }, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
-
-        assert.NotNil(t, DoWithRetryCtx(context.Background(), func(ctx context.Context, retryCount int) error {
-            if retryCount == 1 {
-                return nil
-            }
-            time.Sleep(time.Millisecond * 150)
-            return errors.New("any ")
-        }, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
-    })
-
-    t.Run("with deadline exceeded", func(t *testing.T) {
-        ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Millisecond*250))
-        defer cancel()
-        var times int
-        assert.Error(t, DoWithRetryCtx(ctx, func(ctx context.Context, retryCount int) error {
-            times++
-            time.Sleep(time.Millisecond * 150)
-            return errors.New("any")
-        }, WithInterval(time.Millisecond*150)))
-        assert.Equal(t, 1, times)
-    })
-
-    t.Run("with deadline not exceeded", func(t *testing.T) {
-        ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Millisecond*250))
-        defer cancel()
-        var times int
-        assert.NoError(t, DoWithRetryCtx(ctx, func(ctx context.Context, retryCount int) error {
-            times++
-            if times == defaultRetryTimes {
-                return nil
-            }
-            time.Sleep(time.Millisecond * 50)
-            return errors.New("any")
-        }))
-        assert.Equal(t, defaultRetryTimes, times)
-    })
-}


@@ -21,31 +21,31 @@ type (
     }

     // FilterFunc defines the method to filter a Stream.
-    FilterFunc func(item any) bool
+    FilterFunc func(item interface{}) bool
     // ForAllFunc defines the method to handle all elements in a Stream.
-    ForAllFunc func(pipe <-chan any)
+    ForAllFunc func(pipe <-chan interface{})
     // ForEachFunc defines the method to handle each element in a Stream.
-    ForEachFunc func(item any)
+    ForEachFunc func(item interface{})
     // GenerateFunc defines the method to send elements into a Stream.
-    GenerateFunc func(source chan<- any)
+    GenerateFunc func(source chan<- interface{})
     // KeyFunc defines the method to generate keys for the elements in a Stream.
-    KeyFunc func(item any) any
+    KeyFunc func(item interface{}) interface{}
     // LessFunc defines the method to compare the elements in a Stream.
-    LessFunc func(a, b any) bool
+    LessFunc func(a, b interface{}) bool
     // MapFunc defines the method to map each element to another object in a Stream.
-    MapFunc func(item any) any
+    MapFunc func(item interface{}) interface{}
     // Option defines the method to customize a Stream.
     Option func(opts *rxOptions)
     // ParallelFunc defines the method to handle elements parallelly.
-    ParallelFunc func(item any)
+    ParallelFunc func(item interface{})
     // ReduceFunc defines the method to reduce all the elements in a Stream.
-    ReduceFunc func(pipe <-chan any) (any, error)
+    ReduceFunc func(pipe <-chan interface{}) (interface{}, error)
     // WalkFunc defines the method to walk through all the elements in a Stream.
-    WalkFunc func(item any, pipe chan<- any)
+    WalkFunc func(item interface{}, pipe chan<- interface{})

     // A Stream is a stream that can be used to do stream processing.
     Stream struct {
-        source <-chan any
+        source <-chan interface{}
     }
 )
@@ -56,7 +56,7 @@ func Concat(s Stream, others ...Stream) Stream {

 // From constructs a Stream from the given GenerateFunc.
 func From(generate GenerateFunc) Stream {
-    source := make(chan any)
+    source := make(chan interface{})
     threading.GoSafe(func() {
         defer close(source)
@@ -67,8 +67,8 @@ func From(generate GenerateFunc) Stream {
 }

 // Just converts the given arbitrary items to a Stream.
-func Just(items ...any) Stream {
-    source := make(chan any, len(items))
+func Just(items ...interface{}) Stream {
+    source := make(chan interface{}, len(items))
     for _, item := range items {
         source <- item
     }
@@ -78,7 +78,7 @@ func Just(items ...any) Stream {
 }

 // Range converts the given channel to a Stream.
-func Range(source <-chan any) Stream {
+func Range(source <-chan interface{}) Stream {
     return Stream{
         source: source,
     }
@@ -87,7 +87,7 @@ func Range(source <-chan any) Stream {
 // AllMach returns whether all elements of this stream match the provided predicate.
 // May not evaluate the predicate on all elements if not necessary for determining the result.
 // If the stream is empty then true is returned and the predicate is not evaluated.
-func (s Stream) AllMach(predicate func(item any) bool) bool {
+func (s Stream) AllMach(predicate func(item interface{}) bool) bool {
     for item := range s.source {
         if !predicate(item) {
             // make sure the former goroutine not block, and current func returns fast.
@@ -102,7 +102,7 @@ func (s Stream) AllMach(predicate func(item any) bool) bool {
 // AnyMach returns whether any elements of this stream match the provided predicate.
 // May not evaluate the predicate on all elements if not necessary for determining the result.
 // If the stream is empty then false is returned and the predicate is not evaluated.
-func (s Stream) AnyMach(predicate func(item any) bool) bool {
+func (s Stream) AnyMach(predicate func(item interface{}) bool) bool {
     for item := range s.source {
         if predicate(item) {
             // make sure the former goroutine not block, and current func returns fast.
@@ -121,7 +121,7 @@ func (s Stream) Buffer(n int) Stream {
         n = 0
     }

-    source := make(chan any, n)
+    source := make(chan interface{}, n)
     go func() {
         for item := range s.source {
             source <- item
@@ -134,7 +134,7 @@ func (s Stream) Buffer(n int) Stream {

 // Concat returns a Stream that concatenated other streams
 func (s Stream) Concat(others ...Stream) Stream {
-    source := make(chan any)
+    source := make(chan interface{})
     go func() {
         group := threading.NewRoutineGroup()
@@ -170,12 +170,12 @@ func (s Stream) Count() (count int) {

 // Distinct removes the duplicated items base on the given KeyFunc.
 func (s Stream) Distinct(fn KeyFunc) Stream {
-    source := make(chan any)
+    source := make(chan interface{})
     threading.GoSafe(func() {
         defer close(source)

-        keys := make(map[any]lang.PlaceholderType)
+        keys := make(map[interface{}]lang.PlaceholderType)
         for item := range s.source {
             key := fn(item)
             if _, ok := keys[key]; !ok {
@@ -195,7 +195,7 @@ func (s Stream) Done() {

 // Filter filters the items by the given FilterFunc.
 func (s Stream) Filter(fn FilterFunc, opts ...Option) Stream {
-    return s.Walk(func(item any, pipe chan<- any) {
+    return s.Walk(func(item interface{}, pipe chan<- interface{}) {
         if fn(item) {
             pipe <- item
         }
@@ -203,7 +203,7 @@ func (s Stream) Filter(fn FilterFunc, opts ...Option) Stream {
 }

 // First returns the first item, nil if no items.
-func (s Stream) First() any {
+func (s Stream) First() interface{} {
     for item := range s.source {
         // make sure the former goroutine not block, and current func returns fast.
         go drain(s.source)
@@ -229,13 +229,13 @@ func (s Stream) ForEach(fn ForEachFunc) {

 // Group groups the elements into different groups based on their keys.
 func (s Stream) Group(fn KeyFunc) Stream {
-    groups := make(map[any][]any)
+    groups := make(map[interface{}][]interface{})
     for item := range s.source {
         key := fn(item)
         groups[key] = append(groups[key], item)
     }

-    source := make(chan any)
+    source := make(chan interface{})
     go func() {
         for _, group := range groups {
             source <- group
@@ -252,7 +252,7 @@ func (s Stream) Head(n int64) Stream {
         panic("n must be greater than 0")
     }

-    source := make(chan any)
+    source := make(chan interface{})
     go func() {
         for item := range s.source {
@@ -279,7 +279,7 @@ func (s Stream) Head(n int64) Stream {
 }

 // Last returns the last item, or nil if no items.
-func (s Stream) Last() (item any) {
+func (s Stream) Last() (item interface{}) {
     for item = range s.source {
     }
     return
@@ -287,53 +287,29 @@ func (s Stream) Last() (item any) {

 // Map converts each item to another corresponding item, which means it's a 1:1 model.
 func (s Stream) Map(fn MapFunc, opts ...Option) Stream {
-    return s.Walk(func(item any, pipe chan<- any) {
+    return s.Walk(func(item interface{}, pipe chan<- interface{}) {
         pipe <- fn(item)
     }, opts...)
 }

-// Max returns the maximum item from the underlying source.
-func (s Stream) Max(less LessFunc) any {
-    var max any
-    for item := range s.source {
-        if max == nil || less(max, item) {
-            max = item
-        }
-    }
-
-    return max
-}
-
 // Merge merges all the items into a slice and generates a new stream.
 func (s Stream) Merge() Stream {
-    var items []any
+    var items []interface{}
     for item := range s.source {
         items = append(items, item)
     }

-    source := make(chan any, 1)
+    source := make(chan interface{}, 1)
     source <- items
     close(source)

     return Range(source)
 }

-// Min returns the minimum item from the underlying source.
-func (s Stream) Min(less LessFunc) any {
-    var min any
-    for item := range s.source {
-        if min == nil || less(item, min) {
-            min = item
-        }
-    }
-
-    return min
-}
-
 // NoneMatch returns whether all elements of this stream don't match the provided predicate.
 // May not evaluate the predicate on all elements if not necessary for determining the result.
 // If the stream is empty then true is returned and the predicate is not evaluated.
-func (s Stream) NoneMatch(predicate func(item any) bool) bool {
+func (s Stream) NoneMatch(predicate func(item interface{}) bool) bool {
     for item := range s.source {
         if predicate(item) {
             // make sure the former goroutine not block, and current func returns fast.
@@ -347,19 +323,19 @@ func (s Stream) NoneMatch(predicate func(item any) bool) bool {

 // Parallel applies the given ParallelFunc to each item concurrently with given number of workers.
 func (s Stream) Parallel(fn ParallelFunc, opts ...Option) {
-    s.Walk(func(item any, pipe chan<- any) {
+    s.Walk(func(item interface{}, pipe chan<- interface{}) {
         fn(item)
     }, opts...).Done()
 }

-// Reduce is a utility method to let the caller deal with the underlying channel.
-func (s Stream) Reduce(fn ReduceFunc) (any, error) {
+// Reduce is an utility method to let the caller deal with the underlying channel.
+func (s Stream) Reduce(fn ReduceFunc) (interface{}, error) {
     return fn(s.source)
 }

 // Reverse reverses the elements in the stream.
 func (s Stream) Reverse() Stream {
-    var items []any
+    var items []interface{}
     for item := range s.source {
         items = append(items, item)
     }
@@ -381,7 +357,7 @@ func (s Stream) Skip(n int64) Stream {
         return s
     }

-    source := make(chan any)
+    source := make(chan interface{})
     go func() {
         for item := range s.source {
@@ -400,7 +376,7 @@ func (s Stream) Skip(n int64) Stream {

 // Sort sorts the items from the underlying source.
 func (s Stream) Sort(less LessFunc) Stream {
-    var items []any
+    var items []interface{}
     for item := range s.source {
         items = append(items, item)
     }
@@ -418,9 +394,9 @@ func (s Stream) Split(n int) Stream {
         panic("n should be greater than 0")
     }

-    source := make(chan any)
+    source := make(chan interface{})
     go func() {
-        var chunk []any
+        var chunk []interface{}
for item := range s.source { for item := range s.source {
chunk = append(chunk, item) chunk = append(chunk, item)
if len(chunk) == n { if len(chunk) == n {
@@ -443,7 +419,7 @@ func (s Stream) Tail(n int64) Stream {
panic("n should be greater than 0") panic("n should be greater than 0")
} }
source := make(chan any) source := make(chan interface{})
go func() { go func() {
ring := collection.NewRing(int(n)) ring := collection.NewRing(int(n))
@@ -470,7 +446,7 @@ func (s Stream) Walk(fn WalkFunc, opts ...Option) Stream {
} }
func (s Stream) walkLimited(fn WalkFunc, option *rxOptions) Stream { func (s Stream) walkLimited(fn WalkFunc, option *rxOptions) Stream {
pipe := make(chan any, option.workers) pipe := make(chan interface{}, option.workers)
go func() { go func() {
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -501,7 +477,7 @@ func (s Stream) walkLimited(fn WalkFunc, option *rxOptions) Stream {
} }
func (s Stream) walkUnlimited(fn WalkFunc, option *rxOptions) Stream { func (s Stream) walkUnlimited(fn WalkFunc, option *rxOptions) Stream {
pipe := make(chan any, option.workers) pipe := make(chan interface{}, option.workers)
go func() { go func() {
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -553,7 +529,7 @@ func buildOptions(opts ...Option) *rxOptions {
} }
// drain drains the given channel. // drain drains the given channel.
func drain(channel <-chan any) { func drain(channel <-chan interface{}) {
for range channel { for range channel {
} }
} }

View File

@@ -23,7 +23,7 @@ func TestBuffer(t *testing.T) {
 		var count int32
 		var wait sync.WaitGroup
 		wait.Add(1)
-		From(func(source chan<- any) {
+		From(func(source chan<- interface{}) {
 			ticker := time.NewTicker(10 * time.Millisecond)
 			defer ticker.Stop()
@@ -36,7 +36,7 @@ func TestBuffer(t *testing.T) {
 					return
 				}
 			}
-		}).Buffer(N).ForAll(func(pipe <-chan any) {
+		}).Buffer(N).ForAll(func(pipe <-chan interface{}) {
 			wait.Wait()
 			// why N+1, because take one more to wait for sending into the channel
 			assert.Equal(t, int32(N+1), atomic.LoadInt32(&count))
@@ -47,7 +47,7 @@ func TestBuffer(t *testing.T) {
 func TestBufferNegative(t *testing.T) {
 	runCheckedTest(t, func(t *testing.T) {
 		var result int
-		Just(1, 2, 3, 4).Buffer(-1).Reduce(func(pipe <-chan any) (any, error) {
+		Just(1, 2, 3, 4).Buffer(-1).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
 			for item := range pipe {
 				result += item.(int)
 			}
@@ -61,22 +61,22 @@ func TestCount(t *testing.T) {
 	runCheckedTest(t, func(t *testing.T) {
 		tests := []struct {
 			name     string
-			elements []any
+			elements []interface{}
 		}{
 			{
 				name: "no elements with nil",
 			},
 			{
 				name:     "no elements",
-				elements: []any{},
+				elements: []interface{}{},
 			},
 			{
 				name:     "1 element",
-				elements: []any{1},
+				elements: []interface{}{1},
 			},
 			{
 				name:     "multiple elements",
-				elements: []any{1, 2, 3},
+				elements: []interface{}{1, 2, 3},
 			},
 		}
@@ -92,7 +92,7 @@ func TestCount(t *testing.T) {
 func TestDone(t *testing.T) {
 	runCheckedTest(t, func(t *testing.T) {
 		var count int32
-		Just(1, 2, 3).Walk(func(item any, pipe chan<- any) {
+		Just(1, 2, 3).Walk(func(item interface{}, pipe chan<- interface{}) {
 			time.Sleep(time.Millisecond * 100)
 			atomic.AddInt32(&count, int32(item.(int)))
 		}).Done()
@@ -103,7 +103,7 @@ func TestDone(t *testing.T) {
 func TestJust(t *testing.T) {
 	runCheckedTest(t, func(t *testing.T) {
 		var result int
-		Just(1, 2, 3, 4).Reduce(func(pipe <-chan any) (any, error) {
+		Just(1, 2, 3, 4).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
 			for item := range pipe {
 				result += item.(int)
 			}
@@ -116,9 +116,9 @@ func TestJust(t *testing.T) {
 func TestDistinct(t *testing.T) {
 	runCheckedTest(t, func(t *testing.T) {
 		var result int
-		Just(4, 1, 3, 2, 3, 4).Distinct(func(item any) any {
+		Just(4, 1, 3, 2, 3, 4).Distinct(func(item interface{}) interface{} {
 			return item
-		}).Reduce(func(pipe <-chan any) (any, error) {
+		}).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
 			for item := range pipe {
 				result += item.(int)
 			}
@@ -131,9 +131,9 @@ func TestDistinct(t *testing.T) {
 func TestFilter(t *testing.T) {
 	runCheckedTest(t, func(t *testing.T) {
 		var result int
-		Just(1, 2, 3, 4).Filter(func(item any) bool {
+		Just(1, 2, 3, 4).Filter(func(item interface{}) bool {
 			return item.(int)%2 == 0
-		}).Reduce(func(pipe <-chan any) (any, error) {
+		}).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
 			for item := range pipe {
 				result += item.(int)
 			}
@@ -154,9 +154,9 @@ func TestFirst(t *testing.T) {
 func TestForAll(t *testing.T) {
 	runCheckedTest(t, func(t *testing.T) {
 		var result int
-		Just(1, 2, 3, 4).Filter(func(item any) bool {
+		Just(1, 2, 3, 4).Filter(func(item interface{}) bool {
 			return item.(int)%2 == 0
-		}).ForAll(func(pipe <-chan any) {
+		}).ForAll(func(pipe <-chan interface{}) {
 			for item := range pipe {
 				result += item.(int)
 			}
@@ -168,11 +168,11 @@ func TestForAll(t *testing.T) {
 func TestGroup(t *testing.T) {
 	runCheckedTest(t, func(t *testing.T) {
 		var groups [][]int
-		Just(10, 11, 20, 21).Group(func(item any) any {
+		Just(10, 11, 20, 21).Group(func(item interface{}) interface{} {
 			v := item.(int)
 			return v / 10
-		}).ForEach(func(item any) {
-			v := item.([]any)
+		}).ForEach(func(item interface{}) {
+			v := item.([]interface{})
 			var group []int
 			for _, each := range v {
 				group = append(group, each.(int))
@@ -191,7 +191,7 @@ func TestGroup(t *testing.T) {
 func TestHead(t *testing.T) {
 	runCheckedTest(t, func(t *testing.T) {
 		var result int
-		Just(1, 2, 3, 4).Head(2).Reduce(func(pipe <-chan any) (any, error) {
+		Just(1, 2, 3, 4).Head(2).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
 			for item := range pipe {
 				result += item.(int)
 			}
@@ -204,7 +204,7 @@ func TestHead(t *testing.T) {
 func TestHeadZero(t *testing.T) {
 	runCheckedTest(t, func(t *testing.T) {
 		assert.Panics(t, func() {
-			Just(1, 2, 3, 4).Head(0).Reduce(func(pipe <-chan any) (any, error) {
+			Just(1, 2, 3, 4).Head(0).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
 				return nil, nil
 			})
 		})
@@ -214,7 +214,7 @@ func TestHeadZero(t *testing.T) {
 func TestHeadMore(t *testing.T) {
 	runCheckedTest(t, func(t *testing.T) {
 		var result int
-		Just(1, 2, 3, 4).Head(6).Reduce(func(pipe <-chan any) (any, error) {
+		Just(1, 2, 3, 4).Head(6).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
 			for item := range pipe {
 				result += item.(int)
 			}
@@ -245,14 +245,14 @@ func TestMap(t *testing.T) {
 			expect int
 		}{
 			{
-				mapper: func(item any) any {
+				mapper: func(item interface{}) interface{} {
 					v := item.(int)
 					return v * v
 				},
 				expect: 30,
 			},
 			{
-				mapper: func(item any) any {
+				mapper: func(item interface{}) interface{} {
 					v := item.(int)
 					if v%2 == 0 {
 						return 0
@@ -262,7 +262,7 @@ func TestMap(t *testing.T) {
 				expect: 10,
 			},
 			{
-				mapper: func(item any) any {
+				mapper: func(item interface{}) interface{} {
 					v := item.(int)
 					if v%2 == 0 {
 						panic(v)
@@ -283,12 +283,12 @@ func TestMap(t *testing.T) {
 			} else {
 				workers = runtime.NumCPU()
 			}
-			From(func(source chan<- any) {
+			From(func(source chan<- interface{}) {
 				for i := 1; i < 5; i++ {
 					source <- i
 				}
 			}).Map(test.mapper, WithWorkers(workers)).Reduce(
-				func(pipe <-chan any) (any, error) {
+				func(pipe <-chan interface{}) (interface{}, error) {
 					for item := range pipe {
 						result += item.(int)
 					}
@@ -303,8 +303,8 @@ func TestMap(t *testing.T) {
 func TestMerge(t *testing.T) {
 	runCheckedTest(t, func(t *testing.T) {
-		Just(1, 2, 3, 4).Merge().ForEach(func(item any) {
-			assert.ElementsMatch(t, []any{1, 2, 3, 4}, item.([]any))
+		Just(1, 2, 3, 4).Merge().ForEach(func(item interface{}) {
+			assert.ElementsMatch(t, []interface{}{1, 2, 3, 4}, item.([]interface{}))
 		})
 	})
 }
@@ -312,7 +312,7 @@ func TestMerge(t *testing.T) {
 func TestParallelJust(t *testing.T) {
 	runCheckedTest(t, func(t *testing.T) {
 		var count int32
-		Just(1, 2, 3).Parallel(func(item any) {
+		Just(1, 2, 3).Parallel(func(item interface{}) {
 			time.Sleep(time.Millisecond * 100)
 			atomic.AddInt32(&count, int32(item.(int)))
 		}, UnlimitedWorkers())
@@ -322,8 +322,8 @@ func TestParallelJust(t *testing.T) {
 func TestReverse(t *testing.T) {
 	runCheckedTest(t, func(t *testing.T) {
-		Just(1, 2, 3, 4).Reverse().Merge().ForEach(func(item any) {
-			assert.ElementsMatch(t, []any{4, 3, 2, 1}, item.([]any))
+		Just(1, 2, 3, 4).Reverse().Merge().ForEach(func(item interface{}) {
+			assert.ElementsMatch(t, []interface{}{4, 3, 2, 1}, item.([]interface{}))
 		})
 	})
 }
@@ -331,9 +331,9 @@ func TestReverse(t *testing.T) {
 func TestSort(t *testing.T) {
 	runCheckedTest(t, func(t *testing.T) {
 		var prev int
-		Just(5, 3, 7, 1, 9, 6, 4, 8, 2).Sort(func(a, b any) bool {
+		Just(5, 3, 7, 1, 9, 6, 4, 8, 2).Sort(func(a, b interface{}) bool {
 			return a.(int) < b.(int)
-		}).ForEach(func(item any) {
+		}).ForEach(func(item interface{}) {
 			next := item.(int)
 			assert.True(t, prev < next)
 			prev = next
@@ -346,12 +346,12 @@ func TestSplit(t *testing.T) {
 		assert.Panics(t, func() {
 			Just(1, 2, 3, 4, 5, 6, 7, 8, 9, 10).Split(0).Done()
 		})
-		var chunks [][]any
-		Just(1, 2, 3, 4, 5, 6, 7, 8, 9, 10).Split(4).ForEach(func(item any) {
-			chunk := item.([]any)
+		var chunks [][]interface{}
+		Just(1, 2, 3, 4, 5, 6, 7, 8, 9, 10).Split(4).ForEach(func(item interface{}) {
+			chunk := item.([]interface{})
 			chunks = append(chunks, chunk)
 		})
-		assert.EqualValues(t, [][]any{
+		assert.EqualValues(t, [][]interface{}{
 			{1, 2, 3, 4},
 			{5, 6, 7, 8},
 			{9, 10},
@@ -362,7 +362,7 @@ func TestSplit(t *testing.T) {
 func TestTail(t *testing.T) {
 	runCheckedTest(t, func(t *testing.T) {
 		var result int
-		Just(1, 2, 3, 4).Tail(2).Reduce(func(pipe <-chan any) (any, error) {
+		Just(1, 2, 3, 4).Tail(2).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
 			for item := range pipe {
 				result += item.(int)
 			}
@@ -375,7 +375,7 @@ func TestTail(t *testing.T) {
 func TestTailZero(t *testing.T) {
 	runCheckedTest(t, func(t *testing.T) {
 		assert.Panics(t, func() {
-			Just(1, 2, 3, 4).Tail(0).Reduce(func(pipe <-chan any) (any, error) {
+			Just(1, 2, 3, 4).Tail(0).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
 				return nil, nil
 			})
 		})
@@ -385,11 +385,11 @@ func TestTailZero(t *testing.T) {
 func TestWalk(t *testing.T) {
 	runCheckedTest(t, func(t *testing.T) {
 		var result int
-		Just(1, 2, 3, 4, 5).Walk(func(item any, pipe chan<- any) {
+		Just(1, 2, 3, 4, 5).Walk(func(item interface{}, pipe chan<- interface{}) {
 			if item.(int)%2 != 0 {
 				pipe <- item
 			}
-		}, UnlimitedWorkers()).ForEach(func(item any) {
+		}, UnlimitedWorkers()).ForEach(func(item interface{}) {
 			result += item.(int)
 		})
 		assert.Equal(t, 9, result)
@@ -398,16 +398,16 @@ func TestWalk(t *testing.T) {
 func TestStream_AnyMach(t *testing.T) {
 	runCheckedTest(t, func(t *testing.T) {
-		assetEqual(t, false, Just(1, 2, 3).AnyMach(func(item any) bool {
+		assetEqual(t, false, Just(1, 2, 3).AnyMach(func(item interface{}) bool {
 			return item.(int) == 4
 		}))
-		assetEqual(t, false, Just(1, 2, 3).AnyMach(func(item any) bool {
+		assetEqual(t, false, Just(1, 2, 3).AnyMach(func(item interface{}) bool {
 			return item.(int) == 0
 		}))
-		assetEqual(t, true, Just(1, 2, 3).AnyMach(func(item any) bool {
+		assetEqual(t, true, Just(1, 2, 3).AnyMach(func(item interface{}) bool {
 			return item.(int) == 2
 		}))
-		assetEqual(t, true, Just(1, 2, 3).AnyMach(func(item any) bool {
+		assetEqual(t, true, Just(1, 2, 3).AnyMach(func(item interface{}) bool {
 			return item.(int) == 2
 		}))
 	})
@@ -416,17 +416,17 @@ func TestStream_AnyMach(t *testing.T) {
 func TestStream_AllMach(t *testing.T) {
 	runCheckedTest(t, func(t *testing.T) {
 		assetEqual(
-			t, true, Just(1, 2, 3).AllMach(func(item any) bool {
+			t, true, Just(1, 2, 3).AllMach(func(item interface{}) bool {
 				return true
 			}),
 		)
 		assetEqual(
-			t, false, Just(1, 2, 3).AllMach(func(item any) bool {
+			t, false, Just(1, 2, 3).AllMach(func(item interface{}) bool {
 				return false
 			}),
 		)
 		assetEqual(
-			t, false, Just(1, 2, 3).AllMach(func(item any) bool {
+			t, false, Just(1, 2, 3).AllMach(func(item interface{}) bool {
 				return item.(int) == 1
 			}),
 		)
@@ -436,17 +436,17 @@ func TestStream_AllMach(t *testing.T) {
 func TestStream_NoneMatch(t *testing.T) {
 	runCheckedTest(t, func(t *testing.T) {
 		assetEqual(
-			t, true, Just(1, 2, 3).NoneMatch(func(item any) bool {
+			t, true, Just(1, 2, 3).NoneMatch(func(item interface{}) bool {
 				return false
 			}),
 		)
 		assetEqual(
-			t, false, Just(1, 2, 3).NoneMatch(func(item any) bool {
+			t, false, Just(1, 2, 3).NoneMatch(func(item interface{}) bool {
 				return true
 			}),
 		)
 		assetEqual(
-			t, true, Just(1, 2, 3).NoneMatch(func(item any) bool {
+			t, true, Just(1, 2, 3).NoneMatch(func(item interface{}) bool {
 				return item.(int) == 4
 			}),
 		)
@@ -455,19 +455,19 @@ func TestStream_NoneMatch(t *testing.T) {
 func TestConcat(t *testing.T) {
 	runCheckedTest(t, func(t *testing.T) {
-		a1 := []any{1, 2, 3}
-		a2 := []any{4, 5, 6}
+		a1 := []interface{}{1, 2, 3}
+		a2 := []interface{}{4, 5, 6}
 		s1 := Just(a1...)
 		s2 := Just(a2...)
 		stream := Concat(s1, s2)
-		var items []any
+		var items []interface{}
 		for item := range stream.source {
 			items = append(items, item)
 		}
 		sort.Slice(items, func(i, j int) bool {
 			return items[i].(int) < items[j].(int)
 		})
-		ints := make([]any, 0)
+		ints := make([]interface{}, 0)
 		ints = append(ints, a1...)
 		ints = append(ints, a2...)
 		assetEqual(t, ints, items)
@@ -479,7 +479,7 @@ func TestStream_Skip(t *testing.T) {
 		assetEqual(t, 3, Just(1, 2, 3, 4).Skip(1).Count())
 		assetEqual(t, 1, Just(1, 2, 3, 4).Skip(3).Count())
 		assetEqual(t, 4, Just(1, 2, 3, 4).Skip(0).Count())
-		equal(t, Just(1, 2, 3, 4).Skip(3), []any{4})
+		equal(t, Just(1, 2, 3, 4).Skip(3), []interface{}{4})
 		assert.Panics(t, func() {
 			Just(1, 2, 3, 4).Skip(-1)
 		})
@@ -489,104 +489,27 @@ func TestStream_Skip(t *testing.T) {
 func TestStream_Concat(t *testing.T) {
 	runCheckedTest(t, func(t *testing.T) {
 		stream := Just(1).Concat(Just(2), Just(3))
-		var items []any
+		var items []interface{}
 		for item := range stream.source {
 			items = append(items, item)
 		}
 		sort.Slice(items, func(i, j int) bool {
 			return items[i].(int) < items[j].(int)
 		})
-		assetEqual(t, []any{1, 2, 3}, items)
+		assetEqual(t, []interface{}{1, 2, 3}, items)
 		just := Just(1)
-		equal(t, just.Concat(just), []any{1})
-	})
-}
-
-func TestStream_Max(t *testing.T) {
-	runCheckedTest(t, func(t *testing.T) {
-		tests := []struct {
-			name     string
-			elements []any
-			max      any
-		}{
-			{
-				name: "no elements with nil",
-			},
-			{
-				name:     "no elements",
-				elements: []any{},
-				max:      nil,
-			},
-			{
-				name:     "1 element",
-				elements: []any{1},
-				max:      1,
-			},
-			{
-				name:     "multiple elements",
-				elements: []any{1, 2, 9, 5, 8},
-				max:      9,
-			},
-		}
-
-		for _, test := range tests {
-			t.Run(test.name, func(t *testing.T) {
-				val := Just(test.elements...).Max(func(a, b any) bool {
-					return a.(int) < b.(int)
-				})
-				assetEqual(t, test.max, val)
-			})
-		}
-	})
-}
-
-func TestStream_Min(t *testing.T) {
-	runCheckedTest(t, func(t *testing.T) {
-		tests := []struct {
-			name     string
-			elements []any
-			min      any
-		}{
-			{
-				name: "no elements with nil",
-				min:  nil,
-			},
-			{
-				name:     "no elements",
-				elements: []any{},
-				min:      nil,
-			},
-			{
-				name:     "1 element",
-				elements: []any{1},
-				min:      1,
-			},
-			{
-				name:     "multiple elements",
-				elements: []any{-1, 1, 2, 9, 5, 8},
-				min:      -1,
-			},
-		}
-
-		for _, test := range tests {
-			t.Run(test.name, func(t *testing.T) {
-				val := Just(test.elements...).Min(func(a, b any) bool {
-					return a.(int) < b.(int)
-				})
-				assetEqual(t, test.min, val)
-			})
-		}
+		equal(t, just.Concat(just), []interface{}{1})
 	})
 }

 func BenchmarkParallelMapReduce(b *testing.B) {
 	b.ReportAllocs()

-	mapper := func(v any) any {
+	mapper := func(v interface{}) interface{} {
 		return v.(int64) * v.(int64)
 	}
-	reducer := func(input <-chan any) (any, error) {
+	reducer := func(input <-chan interface{}) (interface{}, error) {
 		var result int64
 		for v := range input {
 			result += v.(int64)
@@ -594,7 +517,7 @@ func BenchmarkParallelMapReduce(b *testing.B) {
 		return result, nil
 	}

 	b.ResetTimer()
-	From(func(input chan<- any) {
+	From(func(input chan<- interface{}) {
 		b.RunParallel(func(pb *testing.PB) {
 			for pb.Next() {
 				input <- int64(rand.Int())
@@ -606,10 +529,10 @@ func BenchmarkParallelMapReduce(b *testing.B) {
 func BenchmarkMapReduce(b *testing.B) {
 	b.ReportAllocs()

-	mapper := func(v any) any {
+	mapper := func(v interface{}) interface{} {
 		return v.(int64) * v.(int64)
 	}
-	reducer := func(input <-chan any) (any, error) {
+	reducer := func(input <-chan interface{}) (interface{}, error) {
 		var result int64
 		for v := range input {
 			result += v.(int64)
@@ -617,21 +540,21 @@ func BenchmarkMapReduce(b *testing.B) {
 		return result, nil
 	}

 	b.ResetTimer()
-	From(func(input chan<- any) {
+	From(func(input chan<- interface{}) {
 		for i := 0; i < b.N; i++ {
 			input <- int64(rand.Int())
 		}
 	}).Map(mapper).Reduce(reducer)
 }

-func assetEqual(t *testing.T, except, data any) {
+func assetEqual(t *testing.T, except, data interface{}) {
 	if !reflect.DeepEqual(except, data) {
 		t.Errorf(" %v, want %v", data, except)
 	}
 }

-func equal(t *testing.T, stream Stream, data []any) {
-	items := make([]any, 0)
+func equal(t *testing.T, stream Stream, data []interface{}) {
+	items := make([]interface{}, 0)
 	for item := range stream.source {
 		items = append(items, item)
 	}
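The tests above all drive the same channel-based Reduce pattern: a producer fills a channel, and a caller-supplied reducer drains it. A dependency-free sketch of that pattern (the `reduce` helper and the `sum` reducer are illustrative, not go-zero APIs):

```go
package main

import "fmt"

// reduce hands the whole source channel to a caller-supplied reducer,
// mirroring the Stream.Reduce signature in the diff above.
func reduce(source <-chan interface{},
	fn func(<-chan interface{}) (interface{}, error)) (interface{}, error) {
	return fn(source)
}

func main() {
	source := make(chan interface{}, 4)
	for _, v := range []int{1, 2, 3, 4} {
		source <- v
	}
	close(source) // the reducer's range loop ends when the channel closes

	result, _ := reduce(source, func(pipe <-chan interface{}) (interface{}, error) {
		sum := 0
		for item := range pipe {
			sum += item.(int)
		}
		return sum, nil
	})
	fmt.Println(result) // 10
}
```

Closing the source channel is what terminates the reducer; the tests rely on `Just`/`From` doing that close internally.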

View File

@@ -29,7 +29,7 @@ func DoWithTimeout(fn func() error, timeout time.Duration, opts ...DoOption) err
 	// create channel with buffer size 1 to avoid goroutine leak
 	done := make(chan error, 1)
-	panicChan := make(chan any, 1)
+	panicChan := make(chan interface{}, 1)
 	go func() {
 		defer func() {
 			if p := recover(); p != nil {

View File

@@ -26,7 +26,7 @@ type (
 		hashFunc Func
 		replicas int
 		keys     []uint64
-		ring     map[uint64][]any
+		ring     map[uint64][]interface{}
 		nodes    map[string]lang.PlaceholderType
 		lock     sync.RWMutex
 	}
@@ -50,21 +50,21 @@ func NewCustomConsistentHash(replicas int, fn Func) *ConsistentHash {
 	return &ConsistentHash{
 		hashFunc: fn,
 		replicas: replicas,
-		ring:     make(map[uint64][]any),
+		ring:     make(map[uint64][]interface{}),
 		nodes:    make(map[string]lang.PlaceholderType),
 	}
 }

 // Add adds the node with the number of h.replicas,
 // the later call will overwrite the replicas of the former calls.
-func (h *ConsistentHash) Add(node any) {
+func (h *ConsistentHash) Add(node interface{}) {
 	h.AddWithReplicas(node, h.replicas)
 }

 // AddWithReplicas adds the node with the number of replicas,
 // replicas will be truncated to h.replicas if it's larger than h.replicas,
 // the later call will overwrite the replicas of the former calls.
-func (h *ConsistentHash) AddWithReplicas(node any, replicas int) {
+func (h *ConsistentHash) AddWithReplicas(node interface{}, replicas int) {
 	h.Remove(node)

 	if replicas > h.replicas {
@@ -89,7 +89,7 @@ func (h *ConsistentHash) AddWithReplicas(node any, replicas int) {

 // AddWithWeight adds the node with weight, the weight can be 1 to 100, indicates the percent,
 // the later call will overwrite the replicas of the former calls.
-func (h *ConsistentHash) AddWithWeight(node any, weight int) {
+func (h *ConsistentHash) AddWithWeight(node interface{}, weight int) {
 	// don't need to make sure weight not larger than TopWeight,
 	// because AddWithReplicas makes sure replicas cannot be larger than h.replicas
 	replicas := h.replicas * weight / TopWeight
@@ -97,7 +97,7 @@ func (h *ConsistentHash) AddWithWeight(node any, weight int) {
 }

 // Get returns the corresponding node from h base on the given v.
-func (h *ConsistentHash) Get(v any) (any, bool) {
+func (h *ConsistentHash) Get(v interface{}) (interface{}, bool) {
 	h.lock.RLock()
 	defer h.lock.RUnlock()
@@ -124,7 +124,7 @@ func (h *ConsistentHash) Get(v any) (any, bool) {
 }

 // Remove removes the given node from h.
-func (h *ConsistentHash) Remove(node any) {
+func (h *ConsistentHash) Remove(node interface{}) {
 	nodeRepr := repr(node)

 	h.lock.Lock()
@@ -177,10 +177,10 @@ func (h *ConsistentHash) removeNode(nodeRepr string) {
 	delete(h.nodes, nodeRepr)
 }

-func innerRepr(node any) string {
+func innerRepr(node interface{}) string {
 	return fmt.Sprintf("%d:%v", prime, node)
 }

-func repr(node any) string {
+func repr(node interface{}) string {
 	return lang.Repr(node)
 }
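The `keys`/`ring` fields above implement a classic consistent-hash ring: each node contributes many virtual-node hashes to a sorted slice, and a key resolves to the first virtual node clockwise from its own hash. A stripped-down sketch of that lookup (the FNV hash and `ring` type here are illustrative; go-zero's hash function and locking are richer):

```go
package main

import (
	"fmt"
	"hash/fnv"
	"sort"
)

type ring struct {
	keys  []uint64          // sorted virtual-node hashes
	nodes map[uint64]string // virtual-node hash -> real node
}

func hashOf(s string) uint64 {
	h := fnv.New64a()
	h.Write([]byte(s))
	return h.Sum64()
}

// add inserts `replicas` virtual nodes for one real node,
// keeping the key slice sorted for binary search.
func (r *ring) add(node string, replicas int) {
	for i := 0; i < replicas; i++ {
		k := hashOf(fmt.Sprintf("%s#%d", node, i))
		r.keys = append(r.keys, k)
		r.nodes[k] = node
	}
	sort.Slice(r.keys, func(i, j int) bool { return r.keys[i] < r.keys[j] })
}

// get finds the first virtual node at or after the key's hash,
// wrapping to index 0 -- the same clockwise search Get performs above.
func (r *ring) get(key string) string {
	h := hashOf(key)
	idx := sort.Search(len(r.keys), func(i int) bool {
		return r.keys[i] >= h
	}) % len(r.keys)
	return r.nodes[r.keys[idx]]
}

func main() {
	r := &ring{nodes: make(map[uint64]string)}
	r.add("node-a", 100)
	r.add("node-b", 100)
	fmt.Println(r.get("some-key") != "") // true
}
```

The virtual nodes are what spread load evenly; with only one hash per real node, a removal would dump that node's entire key range onto a single neighbor.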

View File

@@ -42,7 +42,7 @@ func TestConsistentHash(t *testing.T) {
 		keys[key.(string)]++
 	}

-	mi := make(map[any]int, len(keys))
+	mi := make(map[interface{}]int, len(keys))
 	for k, v := range keys {
 		mi[k] = v
 	}

View File

@@ -16,7 +16,7 @@ func NewBufferPool(capability int) *BufferPool {
 	return &BufferPool{
 		capability: capability,
 		pool: &sync.Pool{
-			New: func() any {
+			New: func() interface{} {
 				return new(bytes.Buffer)
 			},
 		},
@@ -32,10 +32,6 @@ func (bp *BufferPool) Get() *bytes.Buffer {

 // Put returns buf into bp.
 func (bp *BufferPool) Put(buf *bytes.Buffer) {
-	if buf == nil {
-		return
-	}
-
 	if buf.Cap() < bp.capability {
 		bp.pool.Put(buf)
 	}

View File

@@ -13,26 +13,3 @@ func TestBufferPool(t *testing.T) {
 	pool.Put(bytes.NewBuffer(make([]byte, 0, 2*capacity)))
 	assert.True(t, pool.Get().Cap() <= capacity)
 }
-
-func TestBufferPool_Put(t *testing.T) {
-	t.Run("with nil buf", func(t *testing.T) {
-		pool := NewBufferPool(1024)
-		pool.Put(nil)
-		val := pool.Get()
-		assert.IsType(t, new(bytes.Buffer), val)
-	})
-
-	t.Run("with less-cap buf", func(t *testing.T) {
-		pool := NewBufferPool(1024)
-		pool.Put(bytes.NewBuffer(make([]byte, 0, 512)))
-		val := pool.Get()
-		assert.IsType(t, new(bytes.Buffer), val)
-	})
-
-	t.Run("with more-cap buf", func(t *testing.T) {
-		pool := NewBufferPool(1024)
-		pool.Put(bytes.NewBuffer(make([]byte, 0, 1024<<1)))
-		val := pool.Get()
-		assert.IsType(t, new(bytes.Buffer), val)
-	})
-}

View File

@@ -1,12 +0,0 @@
package iox
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestNopCloser(t *testing.T) {
closer := NopCloser(nil)
assert.NoError(t, closer.Close())
}

View File

@@ -2,7 +2,7 @@ package iox

 import "os"

-// RedirectInOut redirects stdin to r, stdout to w, and callers need to call restore afterward.
+// RedirectInOut redirects stdin to r, stdout to w, and callers need to call restore afterwards.
 func RedirectInOut() (restore func(), err error) {
 	var r, w *os.File
 	r, w, err = os.Pipe()

View File

@@ -35,16 +35,6 @@ func KeepSpace() TextReadOption {
 	}
 }

-// LimitDupReadCloser returns two io.ReadCloser that read from the first will be written to the second.
-// But the second io.ReadCloser is limited to up to n bytes.
-// The first returned reader needs to be read first, because the content
-// read from it will be written to the underlying buffer of the second reader.
-func LimitDupReadCloser(reader io.ReadCloser, n int64) (io.ReadCloser, io.ReadCloser) {
-	var buf bytes.Buffer
-	tee := LimitTeeReader(reader, &buf, n)
-	return io.NopCloser(tee), io.NopCloser(&buf)
-}
-
 // ReadBytes reads exactly the bytes with the length of len(buf)
 func ReadBytes(reader io.Reader, buf []byte) error {
 	var got int

View File

@@ -40,22 +40,17 @@ b`,
 	for _, test := range tests {
 		t.Run(test.input, func(t *testing.T) {
-			tmpFile, err := fs.TempFilenameWithText(test.input)
+			tmpfile, err := fs.TempFilenameWithText(test.input)
 			assert.Nil(t, err)
-			defer os.Remove(tmpFile)
+			defer os.Remove(tmpfile)
-			content, err := ReadText(tmpFile)
+			content, err := ReadText(tmpfile)
 			assert.Nil(t, err)
 			assert.Equal(t, test.expect, content)
 		})
 	}
 }
-func TestReadTextError(t *testing.T) {
-	_, err := ReadText("not-exist")
-	assert.NotNil(t, err)
-}
 func TestReadTextLines(t *testing.T) {
 	text := `1
@@ -64,9 +59,9 @@ func TestReadTextLines(t *testing.T) {
 #a
 3`
-	tmpFile, err := fs.TempFilenameWithText(text)
+	tmpfile, err := fs.TempFilenameWithText(text)
 	assert.Nil(t, err)
-	defer os.Remove(tmpFile)
+	defer os.Remove(tmpfile)
 	tests := []struct {
 		options []TextReadOption
@@ -92,18 +87,13 @@ func TestReadTextLines(t *testing.T) {
 	for _, test := range tests {
 		t.Run(stringx.Rand(), func(t *testing.T) {
-			lines, err := ReadTextLines(tmpFile, test.options...)
+			lines, err := ReadTextLines(tmpfile, test.options...)
 			assert.Nil(t, err)
 			assert.Equal(t, test.expectLines, len(lines))
 		})
 	}
 }
-func TestReadTextLinesError(t *testing.T) {
-	_, err := ReadTextLines("not-exist")
-	assert.NotNil(t, err)
-}
 func TestDupReadCloser(t *testing.T) {
 	input := "hello"
 	reader := io.NopCloser(bytes.NewBufferString(input))
@@ -118,29 +108,6 @@ func TestDupReadCloser(t *testing.T) {
 	verify(r2)
 }
-func TestLimitDupReadCloser(t *testing.T) {
-	input := "hello world"
-	limitBytes := int64(4)
-	reader := io.NopCloser(bytes.NewBufferString(input))
-	r1, r2 := LimitDupReadCloser(reader, limitBytes)
-	verify := func(r io.Reader) {
-		output, err := io.ReadAll(r)
-		assert.Nil(t, err)
-		assert.Equal(t, input, string(output))
-	}
-	verifyLimit := func(r io.Reader, limit int64) {
-		output, err := io.ReadAll(r)
-		if limit < int64(len(input)) {
-			input = input[:limit]
-		}
-		assert.Nil(t, err)
-		assert.Equal(t, input, string(output))
-	}
-	verify(r1)
-	verifyLimit(r2, limitBytes)
-}
 func TestReadBytes(t *testing.T) {
 	reader := io.NopCloser(bytes.NewBufferString("helloworld"))
 	buf := make([]byte, 5)

View File

@@ -1,35 +0,0 @@
-package iox
-import "io"
-// LimitTeeReader returns a Reader that writes up to n bytes to w what it reads from r.
-// First n bytes reads from r performed through it are matched with
-// corresponding writes to w. There is no internal buffering -
-// the write must complete before the first n bytes read completes.
-// Any error encountered while writing is reported as a read error.
-func LimitTeeReader(r io.Reader, w io.Writer, n int64) io.Reader {
-	return &limitTeeReader{r, w, n}
-}
-type limitTeeReader struct {
-	r io.Reader
-	w io.Writer
-	n int64 // limit bytes remaining
-}
-func (t *limitTeeReader) Read(p []byte) (n int, err error) {
-	n, err = t.r.Read(p)
-	if n > 0 && t.n > 0 {
-		limit := int64(n)
-		if limit > t.n {
-			limit = t.n
-		}
-		if n, err := t.w.Write(p[:limit]); err != nil {
-			return n, err
-		}
-		t.n -= limit
-	}
-	return
-}

View File

@@ -1,40 +0,0 @@
-package iox
-import (
-	"bytes"
-	"io"
-	"testing"
-	"github.com/stretchr/testify/assert"
-)
-func TestLimitTeeReader(t *testing.T) {
-	limit := int64(4)
-	src := []byte("hello, world")
-	dst := make([]byte, len(src))
-	rb := bytes.NewBuffer(src)
-	wb := new(bytes.Buffer)
-	r := LimitTeeReader(rb, wb, limit)
-	if n, err := io.ReadFull(r, dst); err != nil || n != len(src) {
-		t.Fatalf("ReadFull(r, dst) = %d, %v; want %d, nil", n, err, len(src))
-	}
-	if !bytes.Equal(dst, src) {
-		t.Errorf("bytes read = %q want %q", dst, src)
-	}
-	if !bytes.Equal(wb.Bytes(), src[:limit]) {
-		t.Errorf("bytes written = %q want %q", wb.Bytes(), src)
-	}
-	n, err := r.Read(dst)
-	assert.Equal(t, 0, n)
-	assert.Equal(t, io.EOF, err)
-	rb = bytes.NewBuffer(src)
-	pr, pw := io.Pipe()
-	if assert.NoError(t, pr.Close()) {
-		r = LimitTeeReader(rb, pw, limit)
-		n, err := io.ReadFull(r, dst)
-		assert.Equal(t, 0, n)
-		assert.Equal(t, io.ErrClosedPipe, err)
-	}
-}

View File

@@ -2,14 +2,13 @@ package iox
 import (
 	"bytes"
-	"errors"
 	"io"
 	"os"
 )
 const bufSize = 32 * 1024
-// CountLines returns the number of lines in the file.
+// CountLines returns the number of lines in file.
 func CountLines(file string) (int, error) {
 	f, err := os.Open(file)
 	if err != nil {
@@ -27,7 +26,7 @@ func CountLines(file string) (int, error) {
 		count += bytes.Count(buf[:c], lineSep)
 		switch {
-		case errors.Is(err, io.EOF):
+		case err == io.EOF:
 			if noEol {
 				count++
 			}

View File

@@ -24,8 +24,3 @@ func TestCountLines(t *testing.T) {
 	assert.Nil(t, err)
 	assert.Equal(t, 4, lines)
 }
-func TestCountLinesError(t *testing.T) {
-	_, err := CountLines("not-exist")
-	assert.NotNil(t, err)
-}

View File

@@ -2,12 +2,11 @@ package iox
 import (
 	"bufio"
-	"errors"
 	"io"
 	"strings"
 )
-// A TextLineScanner is a scanner that can scan lines from the given reader.
+// A TextLineScanner is a scanner that can scan lines from given reader.
 type TextLineScanner struct {
 	reader *bufio.Reader
 	hasNext bool
@@ -15,7 +14,7 @@ type TextLineScanner struct {
 	err error
 }
-// NewTextLineScanner returns a TextLineScanner with the given reader.
+// NewTextLineScanner returns a TextLineScanner with given reader.
 func NewTextLineScanner(reader io.Reader) *TextLineScanner {
 	return &TextLineScanner{
 		reader: bufio.NewReader(reader),
@@ -31,7 +30,7 @@ func (scanner *TextLineScanner) Scan() bool {
 	line, err := scanner.reader.ReadString('\n')
 	scanner.line = strings.TrimRight(line, "\n")
-	if errors.Is(err, io.EOF) {
+	if err == io.EOF {
 		scanner.hasNext = false
 		return true
 	} else if err != nil {

View File

@@ -3,7 +3,6 @@ package iox
 import (
 	"strings"
 	"testing"
-	"testing/iotest"
 	"github.com/stretchr/testify/assert"
 )
@@ -23,10 +22,3 @@ func TestScanner(t *testing.T) {
 	}
 	assert.EqualValues(t, []string{"1", "2", "3", "4"}, lines)
 }
-func TestBadScanner(t *testing.T) {
-	scanner := NewTextLineScanner(iotest.ErrReader(iotest.ErrTimeout))
-	assert.False(t, scanner.Scan())
-	_, err := scanner.Line()
-	assert.ErrorIs(t, err, iotest.ErrTimeout)
-}

View File

@@ -9,12 +9,12 @@ import (
 )
 // Marshal marshals v into json bytes.
-func Marshal(v any) ([]byte, error) {
+func Marshal(v interface{}) ([]byte, error) {
 	return json.Marshal(v)
 }
 // MarshalToString marshals v into a string.
-func MarshalToString(v any) (string, error) {
+func MarshalToString(v interface{}) (string, error) {
 	data, err := Marshal(v)
 	if err != nil {
 		return "", err
@@ -24,7 +24,7 @@ func MarshalToString(v any) (string, error) {
 }
 // Unmarshal unmarshals data bytes into v.
-func Unmarshal(data []byte, v any) error {
+func Unmarshal(data []byte, v interface{}) error {
 	decoder := json.NewDecoder(bytes.NewReader(data))
 	if err := unmarshalUseNumber(decoder, v); err != nil {
 		return formatError(string(data), err)
@@ -34,7 +34,7 @@ func Unmarshal(data []byte, v any) error {
 }
 // UnmarshalFromString unmarshals v from str.
-func UnmarshalFromString(str string, v any) error {
+func UnmarshalFromString(str string, v interface{}) error {
 	decoder := json.NewDecoder(strings.NewReader(str))
 	if err := unmarshalUseNumber(decoder, v); err != nil {
 		return formatError(str, err)
@@ -44,7 +44,7 @@ func UnmarshalFromString(str string, v any) error {
 }
 // UnmarshalFromReader unmarshals v from reader.
-func UnmarshalFromReader(reader io.Reader, v any) error {
+func UnmarshalFromReader(reader io.Reader, v interface{}) error {
 	var buf strings.Builder
 	teeReader := io.TeeReader(reader, &buf)
 	decoder := json.NewDecoder(teeReader)
@@ -55,7 +55,7 @@ func UnmarshalFromReader(reader io.Reader, v any) error {
 	return nil
 }
-func unmarshalUseNumber(decoder *json.Decoder, v any) error {
+func unmarshalUseNumber(decoder *json.Decoder, v interface{}) error {
 	decoder.UseNumber()
 	return decoder.Decode(v)
 }

View File

@@ -11,13 +11,13 @@ var Placeholder PlaceholderType
 type (
 	// AnyType can be used to hold any type.
-	AnyType = any
+	AnyType = interface{}
 	// PlaceholderType represents a placeholder type.
 	PlaceholderType = struct{}
 )
 // Repr returns the string representation of v.
-func Repr(v any) string {
+func Repr(v interface{}) string {
 	if v == nil {
 		return ""
 	}

View File

@@ -23,7 +23,7 @@ func TestRepr(t *testing.T) {
 		u64 uint64 = 8
 	)
 	tests := []struct {
-		v any
+		v interface{}
 		expect string
 	}{
 		{

View File

@@ -9,6 +9,21 @@ import (
 	"github.com/zeromicro/go-zero/core/stores/redis"
 )
+// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
+const periodScript = `local limit = tonumber(ARGV[1])
+local window = tonumber(ARGV[2])
+local current = redis.call("INCRBY", KEYS[1], 1)
+if current == 1 then
+    redis.call("expire", KEYS[1], window)
+end
+if current < limit then
+    return 1
+elseif current == limit then
+    return 2
+else
+    return 0
+end`
 const (
 	// Unknown means not initialized state.
 	Unknown = iota
@@ -24,25 +39,8 @@ const (
 	internalHitQuota = 2
 )
-var (
-	// ErrUnknownCode is an error that represents unknown status code.
-	ErrUnknownCode = errors.New("unknown status code")
-	// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
-	periodScript = redis.NewScript(`local limit = tonumber(ARGV[1])
-local window = tonumber(ARGV[2])
-local current = redis.call("INCRBY", KEYS[1], 1)
-if current == 1 then
-    redis.call("expire", KEYS[1], window)
-end
-if current < limit then
-    return 1
-elseif current == limit then
-    return 2
-else
-    return 0
-end`)
-)
+// ErrUnknownCode is an error that represents unknown status code.
+var ErrUnknownCode = errors.New("unknown status code")
 type (
 	// PeriodOption defines the method to customize a PeriodLimit.
@@ -82,7 +80,7 @@ func (h *PeriodLimit) Take(key string) (int, error) {
 // TakeCtx requests a permit with context, it returns the permit state.
 func (h *PeriodLimit) TakeCtx(ctx context.Context, key string) (int, error) {
-	resp, err := h.limitStore.ScriptRunCtx(ctx, periodScript, []string{h.keyPrefix + key}, []string{
+	resp, err := h.limitStore.EvalCtx(ctx, periodScript, []string{h.keyPrefix + key}, []string{
 		strconv.Itoa(h.quota),
 		strconv.Itoa(h.calcExpireSeconds()),
 	})

View File

@@ -33,7 +33,9 @@ func TestPeriodLimit_RedisUnavailable(t *testing.T) {
 }
 func testPeriodLimit(t *testing.T, opts ...PeriodOption) {
-	store := redistest.CreateRedis(t)
+	store, clean, err := redistest.CreateRedis()
+	assert.Nil(t, err)
+	defer clean()
 	const (
 		seconds = 1

View File

@@ -15,15 +15,10 @@ import (
 )
 const (
-	tokenFormat = "{%s}.tokens"
-	timestampFormat = "{%s}.ts"
-	pingInterval = time.Millisecond * 100
-)
-// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
-// KEYS[1] as tokens_key
-// KEYS[2] as timestamp_key
-var script = redis.NewScript(`local rate = tonumber(ARGV[1])
+	// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
+	// KEYS[1] as tokens_key
+	// KEYS[2] as timestamp_key
+	script = `local rate = tonumber(ARGV[1])
 local capacity = tonumber(ARGV[2])
 local now = tonumber(ARGV[3])
 local requested = tonumber(ARGV[4])
@@ -50,7 +45,11 @@ end
 redis.call("setex", KEYS[1], ttl, new_tokens)
 redis.call("setex", KEYS[2], ttl, now)
-return allowed`)
+return allowed`
+	tokenFormat = "{%s}.tokens"
+	timestampFormat = "{%s}.ts"
+	pingInterval = time.Millisecond * 100
+)
 // A TokenLimiter controls how frequently events are allowed to happen with in one second.
 type TokenLimiter struct {
@@ -111,7 +110,7 @@ func (lim *TokenLimiter) reserveN(ctx context.Context, now time.Time, n int) boo
 		return lim.rescueLimiter.AllowN(now, n)
 	}
-	resp, err := lim.store.ScriptRunCtx(ctx,
+	resp, err := lim.store.EvalCtx(ctx,
 		script,
 		[]string{
 			lim.tokenKey,
@@ -125,7 +124,7 @@ func (lim *TokenLimiter) reserveN(ctx context.Context, now time.Time, n int) boo
 		})
 	// redis allowed == false
 	// Lua boolean false -> r Nil bulk reply
-	if errors.Is(err, redis.Nil) {
+	if err == redis.Nil {
 		return false
 	}
 	if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {

View File

@@ -70,7 +70,9 @@ func TestTokenLimit_Rescue(t *testing.T) {
 }
 func TestTokenLimit_Take(t *testing.T) {
-	store := redistest.CreateRedis(t)
+	store, clean, err := redistest.CreateRedis()
+	assert.Nil(t, err)
+	defer clean()
 	const (
 		total = 100
@@ -90,7 +92,9 @@ func TestTokenLimit_Take(t *testing.T) {
 }
 func TestTokenLimit_TakeBurst(t *testing.T) {
-	store := redistest.CreateRedis(t)
+	store, clean, err := redistest.CreateRedis()
+	assert.Nil(t, err)
+	defer clean()
 	const (
 		total = 100

View File

@@ -9,7 +9,6 @@ import (
 	"github.com/zeromicro/go-zero/core/collection"
 	"github.com/zeromicro/go-zero/core/logx"
-	"github.com/zeromicro/go-zero/core/mathx"
 	"github.com/zeromicro/go-zero/core/stat"
 	"github.com/zeromicro/go-zero/core/syncx"
 	"github.com/zeromicro/go-zero/core/timex"
@@ -22,11 +21,8 @@ const (
 	defaultCpuThreshold = 900
 	defaultMinRt = float64(time.Second / time.Millisecond)
 	// moving average hyperparameter beta for calculating requests on the fly
 	flyingBeta = 0.9
 	coolOffDuration = time.Second
-	cpuMax = 1000 // millicpu
-	millisecondsPerSecond = 1000
-	overloadFactorLowerBound = 0.1
 )
 var (
@@ -70,7 +66,7 @@ type (
 	adaptiveShedder struct {
 		cpuThreshold int64
-		windowScale float64
+		windows int64
 		flying int64
 		avgFlying float64
 		avgFlyingLock syncx.SpinLock
@@ -109,7 +105,7 @@ func NewAdaptiveShedder(opts ...ShedderOption) Shedder {
 	bucketDuration := options.window / time.Duration(options.buckets)
 	return &adaptiveShedder{
 		cpuThreshold: options.cpuThreshold,
-		windowScale: float64(time.Second) / float64(bucketDuration) / millisecondsPerSecond,
+		windows: int64(time.Second / bucketDuration),
 		overloadTime: syncx.NewAtomicDuration(),
 		droppedRecently: syncx.NewAtomicBool(),
 		passCounter: collection.NewRollingWindow(options.buckets, bucketDuration,
@@ -138,10 +134,10 @@ func (as *adaptiveShedder) Allow() (Promise, error) {
 func (as *adaptiveShedder) addFlying(delta int64) {
 	flying := atomic.AddInt64(&as.flying, delta)
 	// update avgFlying when the request is finished.
-	// this strategy makes avgFlying have a little bit of lag against flying, and smoother.
+	// this strategy makes avgFlying have a little bit lag against flying, and smoother.
 	// when the flying requests increase rapidly, avgFlying increase slower, accept more requests.
-	// when the flying requests drop rapidly, avgFlying drop slower, accept fewer requests.
-	// it makes the service to serve as many requests as possible.
+	// when the flying requests drop rapidly, avgFlying drop slower, accept less requests.
+	// it makes the service to serve as more requests as possible.
 	if delta < 0 {
 		as.avgFlyingLock.Lock()
 		as.avgFlying = as.avgFlying*flyingBeta + float64(flying)*(1-flyingBeta)
@@ -153,17 +149,16 @@ func (as *adaptiveShedder) highThru() bool {
 	as.avgFlyingLock.Lock()
 	avgFlying := as.avgFlying
 	as.avgFlyingLock.Unlock()
-	maxFlight := as.maxFlight() * as.overloadFactor()
-	return avgFlying > maxFlight && float64(atomic.LoadInt64(&as.flying)) > maxFlight
+	maxFlight := as.maxFlight()
+	return int64(avgFlying) > maxFlight && atomic.LoadInt64(&as.flying) > maxFlight
 }
-func (as *adaptiveShedder) maxFlight() float64 {
+func (as *adaptiveShedder) maxFlight() int64 {
 	// windows = buckets per second
 	// maxQPS = maxPASS * windows
 	// minRT = min average response time in milliseconds
-	// allowedFlying = maxQPS * minRT / milliseconds_per_second
-	maxFlight := float64(as.maxPass()) * as.minRt() * as.windowScale
-	return mathx.AtLeast(maxFlight, 1)
+	// maxQPS * minRT / milliseconds_per_second
+	return int64(math.Max(1, float64(as.maxPass()*as.windows)*(as.minRt()/1e3)))
 }
 func (as *adaptiveShedder) maxPass() int64 {
@@ -179,8 +174,6 @@ func (as *adaptiveShedder) maxPass() int64 {
 }
 func (as *adaptiveShedder) minRt() float64 {
-	// if no requests in previous windows, return defaultMinRt,
-	// its a reasonable large value to avoid dropping requests.
 	result := defaultMinRt
 	as.rtCounter.Reduce(func(b *collection.Bucket) {
@@ -197,13 +190,6 @@ func (as *adaptiveShedder) minRt() float64 {
 	return result
 }
-func (as *adaptiveShedder) overloadFactor() float64 {
-	// as.cpuThreshold must be less than cpuMax
-	factor := (cpuMax - float64(stat.CpuUsage())) / (cpuMax - float64(as.cpuThreshold))
-	// at least accept 10% of acceptable requests, even cpu is highly overloaded.
-	return mathx.Between(factor, overloadFactorLowerBound, 1)
-}
 func (as *adaptiveShedder) shouldDrop() bool {
 	if as.systemOverloaded() || as.stillHot() {
 		if as.highThru() {
@@ -250,14 +236,14 @@ func (as *adaptiveShedder) systemOverloaded() bool {
 		return true
 	}
-// WithBuckets customizes the Shedder with the given number of buckets.
+// WithBuckets customizes the Shedder with given number of buckets.
 func WithBuckets(buckets int) ShedderOption {
 	return func(opts *shedderOptions) {
 		opts.buckets = buckets
 	}
 }
-// WithCpuThreshold customizes the Shedder with the given cpu threshold.
+// WithCpuThreshold customizes the Shedder with given cpu threshold.
 func WithCpuThreshold(threshold int64) ShedderOption {
 	return func(opts *shedderOptions) {
 		opts.cpuThreshold = threshold

View File

@@ -19,7 +19,6 @@ import (
 const (
 	buckets = 10
 	bucketDuration = time.Millisecond * 50
-	windowFactor = 0.01
 )
 func init() {
@@ -115,10 +114,10 @@ func TestAdaptiveShedderMaxFlight(t *testing.T) {
 	shedder := &adaptiveShedder{
 		passCounter: passCounter,
 		rtCounter: rtCounter,
-		windowScale: windowFactor,
+		windows: buckets,
 		droppedRecently: syncx.NewAtomicBool(),
 	}
-	assert.Equal(t, float64(54), shedder.maxFlight())
+	assert.Equal(t, int64(54), shedder.maxFlight())
 }
 func TestAdaptiveShedderShouldDrop(t *testing.T) {
@@ -137,7 +136,7 @@ func TestAdaptiveShedderShouldDrop(t *testing.T) {
 	shedder := &adaptiveShedder{
 		passCounter: passCounter,
 		rtCounter: rtCounter,
-		windowScale: windowFactor,
+		windows: buckets,
 		overloadTime: syncx.NewAtomicDuration(),
 		droppedRecently: syncx.NewAtomicBool(),
 	}
@@ -150,8 +149,7 @@ func TestAdaptiveShedderShouldDrop(t *testing.T) {
 	// cpu >= 800, inflight > maxPass
 	shedder.avgFlying = 80
-	// because of the overloadFactor, so we need to make sure maxFlight is greater than flying
-	shedder.flying = int64(shedder.maxFlight()*shedder.overloadFactor()) - 5
+	shedder.flying = 50
 	assert.False(t, shedder.shouldDrop())
 	// cpu >= 800, inflight > maxPass
@@ -192,7 +190,7 @@ func TestAdaptiveShedderStillHot(t *testing.T) {
 	shedder := &adaptiveShedder{
 		passCounter: passCounter,
 		rtCounter: rtCounter,
-		windowScale: windowFactor,
+		windows: buckets,
 		overloadTime: syncx.NewAtomicDuration(),
 		droppedRecently: syncx.ForAtomicBool(true),
 	}
@@ -241,30 +239,6 @@ func BenchmarkAdaptiveShedder_Allow(b *testing.B) {
 	b.Run("low load", bench)
 }
-func BenchmarkMaxFlight(b *testing.B) {
-	passCounter := newRollingWindow()
-	rtCounter := newRollingWindow()
-	for i := 0; i < 10; i++ {
-		if i > 0 {
-			time.Sleep(bucketDuration)
-		}
-		passCounter.Add(float64((i + 1) * 100))
-		for j := i*10 + 1; j <= i*10+10; j++ {
-			rtCounter.Add(float64(j))
-		}
-	}
-	shedder := &adaptiveShedder{
-		passCounter: passCounter,
-		rtCounter: rtCounter,
-		windowScale: windowFactor,
-		droppedRecently: syncx.NewAtomicBool(),
-	}
-	for i := 0; i < b.N; i++ {
-		_ = shedder.maxFlight()
-	}
-}
 func newRollingWindow() *collection.RollingWindow {
 	return collection.NewRollingWindow(buckets, bucketDuration, collection.IgnoreCurrentBucket())
 }

View File

@@ -6,7 +6,7 @@ import (
 	"github.com/zeromicro/go-zero/core/syncx"
 )
-// A ShedderGroup is a manager to manage key-based shedders.
+// A ShedderGroup is a manager to manage key based shedders.
 type ShedderGroup struct {
 	options []ShedderOption
 	manager *syncx.ResourceManager

View File

@@ -42,53 +42,53 @@ func Debugv(ctx context.Context, v interface{}) {
 	getLogger(ctx).Debugv(v)
 }
-// Debugw writes msg along with fields into the access log.
+// Debugw writes msg along with fields into access log.
 func Debugw(ctx context.Context, msg string, fields ...LogField) {
 	getLogger(ctx).Debugw(msg, fields...)
 }
 // Error writes v into error log.
-func Error(ctx context.Context, v ...any) {
+func Error(ctx context.Context, v ...interface{}) {
 	getLogger(ctx).Error(v...)
 }
 // Errorf writes v with format into error log.
-func Errorf(ctx context.Context, format string, v ...any) {
+func Errorf(ctx context.Context, format string, v ...interface{}) {
 	getLogger(ctx).Errorf(fmt.Errorf(format, v...).Error())
 }
 // Errorv writes v into error log with json content.
 // No call stack attached, because not elegant to pack the messages.
-func Errorv(ctx context.Context, v any) {
+func Errorv(ctx context.Context, v interface{}) {
 	getLogger(ctx).Errorv(v)
 }
-// Errorw writes msg along with fields into the error log.
+// Errorw writes msg along with fields into error log.
 func Errorw(ctx context.Context, msg string, fields ...LogField) {
 	getLogger(ctx).Errorw(msg, fields...)
 }
 // Field returns a LogField for the given key and value.
-func Field(key string, value any) LogField {
+func Field(key string, value interface{}) LogField {
 	return logx.Field(key, value)
 }
 // Info writes v into access log.
-func Info(ctx context.Context, v ...any) {
+func Info(ctx context.Context, v ...interface{}) {
 	getLogger(ctx).Info(v...)
 }
 // Infof writes v with format into access log.
-func Infof(ctx context.Context, format string, v ...any) {
+func Infof(ctx context.Context, format string, v ...interface{}) {
 	getLogger(ctx).Infof(format, v...)
 }
 // Infov writes v into access log with json content.
-func Infov(ctx context.Context, v any) {
+func Infov(ctx context.Context, v interface{}) {
 	getLogger(ctx).Infov(v)
 }
-// Infow writes msg along with fields into the access log.
+// Infow writes msg along with fields into access log.
 func Infow(ctx context.Context, msg string, fields ...LogField) {
 	getLogger(ctx).Infow(msg, fields...)
 }
@@ -108,27 +108,26 @@ func SetLevel(level uint32) {
 	logx.SetLevel(level)
 }
-// SetUp sets up the logx.
-// If already set up, return nil.
-// We allow SetUp to be called multiple times, because, for example,
+// SetUp sets up the logx. If already set up, just return nil.
+// we allow SetUp to be called multiple times, because for example
 // we need to allow different service frameworks to initialize logx respectively.
-// The same logic for SetUp
+// the same logic for SetUp
 func SetUp(c LogConf) error {
 	return logx.SetUp(c)
 }
 // Slow writes v into slow log.
-func Slow(ctx context.Context, v ...any) {
+func Slow(ctx context.Context, v ...interface{}) {
 	getLogger(ctx).Slow(v...)
 }
 // Slowf writes v with format into slow log.
-func Slowf(ctx context.Context, format string, v ...any) {
+func Slowf(ctx context.Context, format string, v ...interface{}) {
 	getLogger(ctx).Slowf(format, v...)
 }
 // Slowv writes v into slow log with json content.
-func Slowv(ctx context.Context, v any) {
+func Slowv(ctx context.Context, v interface{}) {
 	getLogger(ctx).Slowv(v)
 }

View File

@@ -1,6 +1,7 @@
 package logc
 import (
+	"bytes"
 	"context"
 	"encoding/json"
 	"fmt"
@@ -10,11 +11,14 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/zeromicro/go-zero/core/logx"
-	"github.com/zeromicro/go-zero/core/logx/logtest"
 )
 func TestAddGlobalFields(t *testing.T) {
-	buf := logtest.NewCollector(t)
+	var buf bytes.Buffer
+	writer := logx.NewWriter(&buf)
+	old := logx.Reset()
+	logx.SetWriter(writer)
+	defer logx.SetWriter(old)
 	Info(context.Background(), "hello")
 	buf.Reset()
@@ -22,7 +26,7 @@ func TestAddGlobalFields(t *testing.T) {
 	AddGlobalFields(Field("a", "1"), Field("b", "2"))
 	AddGlobalFields(Field("c", "3"))
 	Info(context.Background(), "world")
-	var m map[string]any
+	var m map[string]interface{}
 	assert.NoError(t, json.Unmarshal(buf.Bytes(), &m))
 	assert.Equal(t, "1", m["a"])
 	assert.Equal(t, "2", m["b"])
@@ -30,90 +34,155 @@ func TestAddGlobalFields(t *testing.T) {
} }
func TestAlert(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
+	buf := logtest.NewCollector(t)
	Alert(context.Background(), "foo")
	assert.True(t, strings.Contains(buf.String(), "foo"), buf.String())
}

func TestError(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
+	buf := logtest.NewCollector(t)
	file, line := getFileLine()
	Error(context.Background(), "foo")
	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}

func TestErrorf(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
+	buf := logtest.NewCollector(t)
	file, line := getFileLine()
	Errorf(context.Background(), "foo %s", "bar")
	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}

func TestErrorv(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
+	buf := logtest.NewCollector(t)
	file, line := getFileLine()
	Errorv(context.Background(), "foo")
	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}

func TestErrorw(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
+	buf := logtest.NewCollector(t)
	file, line := getFileLine()
	Errorw(context.Background(), "foo", Field("a", "b"))
	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}

func TestInfo(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
+	buf := logtest.NewCollector(t)
	file, line := getFileLine()
	Info(context.Background(), "foo")
	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}

func TestInfof(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
+	buf := logtest.NewCollector(t)
	file, line := getFileLine()
	Infof(context.Background(), "foo %s", "bar")
	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}

func TestInfov(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
+	buf := logtest.NewCollector(t)
	file, line := getFileLine()
	Infov(context.Background(), "foo")
	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}

func TestInfow(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
+	buf := logtest.NewCollector(t)
	file, line := getFileLine()
	Infow(context.Background(), "foo", Field("a", "b"))
	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}

func TestDebug(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
+	buf := logtest.NewCollector(t)
	file, line := getFileLine()
	Debug(context.Background(), "foo")
	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}

func TestDebugf(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
+	buf := logtest.NewCollector(t)
	file, line := getFileLine()
	Debugf(context.Background(), "foo %s", "bar")
	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}

func TestDebugv(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
+	buf := logtest.NewCollector(t)
	file, line := getFileLine()
	Debugv(context.Background(), "foo")
	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}

func TestDebugw(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
+	buf := logtest.NewCollector(t)
	file, line := getFileLine()
	Debugw(context.Background(), "foo", Field("a", "b"))
	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
@@ -135,28 +204,48 @@ func TestMisc(t *testing.T) {
}

func TestSlow(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
+	buf := logtest.NewCollector(t)
	file, line := getFileLine()
	Slow(context.Background(), "foo")
	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())
}

func TestSlowf(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
+	buf := logtest.NewCollector(t)
	file, line := getFileLine()
	Slowf(context.Background(), "foo %s", "bar")
	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())
}

func TestSlowv(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
+	buf := logtest.NewCollector(t)
	file, line := getFileLine()
	Slowv(context.Background(), "foo")
	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())
}

func TestSloww(t *testing.T) {
-	var buf strings.Builder
-	writer := logx.NewWriter(&buf)
-	old := logx.Reset()
-	logx.SetWriter(writer)
-	defer logx.SetWriter(old)
+	buf := logtest.NewCollector(t)
	file, line := getFileLine()
	Sloww(context.Background(), "foo", Field("a", "b"))
	assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())


@@ -23,7 +23,7 @@ type LogConf struct {
	MaxContentLength uint32 `json:",optional"`
	// Compress represents whether to compress the log file, default is `false`.
	Compress bool `json:",optional"`
-	// Stdout represents whether to log statistics, default is `true`.
+	// Stat represents whether to log statistics, default is `true`.
	Stat bool `json:",default=true"`
	// KeepDays represents how many days the log files will be kept. Default to keep all files.
	// Only take effect when Mode is `file` or `volume`, both work when Rotation is `daily` or `size`.
@@ -32,13 +32,13 @@ type LogConf struct {
	StackCooldownMillis int `json:",default=100"`
	// MaxBackups represents how many backup log files will be kept. 0 means all files will be kept forever.
	// Only take effect when RotationRuleType is `size`.
-	// Even thougth `MaxBackups` sets 0, log files will still be removed
+	// Even though `MaxBackups` sets 0, log files will still be removed
	// if the `KeepDays` limitation is reached.
	MaxBackups int `json:",default=0"`
	// MaxSize represents how much space the writing log file takes up. 0 means no limit. The unit is `MB`.
	// Only take effect when RotationRuleType is `size`
-	// RotationRuleType represents the type of log rotation rule. Default is `daily`.
+	// Rotation represents the type of log rotation rule. Default is `daily`.
	// daily: daily rotation.
	// size: size limited rotation.
	Rotation string `json:",default=daily,options=[daily,size]"`


@@ -25,7 +25,7 @@ func TestAddGlobalFields(t *testing.T) {
	AddGlobalFields(Field("a", "1"), Field("b", "2"))
	AddGlobalFields(Field("c", "3"))
	Info("world")
-	var m map[string]interface{}
+	var m map[string]any
	assert.NoError(t, json.Unmarshal(buf.Bytes(), &m))
	assert.Equal(t, "1", m["a"])
	assert.Equal(t, "2", m["b"])
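The test expects fields from both `AddGlobalFields` calls to land in a single JSON log line. A sketch of that merge, assuming a later field with an existing key overwrites the earlier one (consistent with the key-overwrite check added alongside these tests; `field`, `addGlobalFields`, and `logLine` are illustrative names, not the logx implementation):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// field is a hypothetical key/value pair attached to every log line.
type field struct {
	key   string
	value string
}

var globalFields []field

// addGlobalFields appends fields, replacing any existing field with the same
// key so repeated calls do not duplicate keys in the JSON output.
func addGlobalFields(fields ...field) {
	for _, f := range fields {
		replaced := false
		for i := range globalFields {
			if globalFields[i].key == f.key {
				globalFields[i] = f
				replaced = true
				break
			}
		}
		if !replaced {
			globalFields = append(globalFields, f)
		}
	}
}

// logLine renders one JSON log entry with the global fields merged in.
func logLine(msg string) string {
	m := map[string]string{"content": msg}
	for _, f := range globalFields {
		m[f.key] = f.value
	}
	b, _ := json.Marshal(m)
	return string(b)
}

func main() {
	addGlobalFields(field{"a", "1"}, field{"b", "2"})
	addGlobalFields(field{"c", "3"})
	fmt.Println(logLine("world")) // → {"a":"1","b":"2","c":"3","content":"world"}
}
```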


@@ -1,40 +0,0 @@
-package logx
-
-import (
-	"io"
-	"os"
-)
-
-var fileSys realFileSystem
-
-type (
-	fileSystem interface {
-		Close(closer io.Closer) error
-		Copy(writer io.Writer, reader io.Reader) (int64, error)
-		Create(name string) (*os.File, error)
-		Open(name string) (*os.File, error)
-		Remove(name string) error
-	}
-
-	realFileSystem struct{}
-)
-
-func (fs realFileSystem) Close(closer io.Closer) error {
-	return closer.Close()
-}
-
-func (fs realFileSystem) Copy(writer io.Writer, reader io.Reader) (int64, error) {
-	return io.Copy(writer, reader)
-}
-
-func (fs realFileSystem) Create(name string) (*os.File, error) {
-	return os.Create(name)
-}
-
-func (fs realFileSystem) Open(name string) (*os.File, error) {
-	return os.Open(name)
-}
-
-func (fs realFileSystem) Remove(name string) error {
-	return os.Remove(name)
-}
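The file deleted above existed as a testing seam: code that went through the `fileSystem` interface instead of calling `os` directly could be exercised against a fake. A sketch of how such an interface is typically faked in tests (`fakeFS` and `cleanup` are illustrative, not repository code):

```go
package main

import (
	"errors"
	"fmt"
)

// fileSystem is the narrow slice of OS behavior the code under test needs.
type fileSystem interface {
	Remove(name string) error
}

// fakeFS records removals and can be told to fail on one name, letting a
// test hit the error path without touching the real disk.
type fakeFS struct {
	removed []string
	failOn  string
}

func (f *fakeFS) Remove(name string) error {
	if name == f.failOn {
		return errors.New("remove failed")
	}
	f.removed = append(f.removed, name)
	return nil
}

// cleanup deletes the given files through fs and returns the names it could
// not remove, so the caller can report them.
func cleanup(fs fileSystem, names []string) []string {
	var failed []string
	for _, n := range names {
		if err := fs.Remove(n); err != nil {
			failed = append(failed, n)
		}
	}
	return failed
}

func main() {
	fs := &fakeFS{failOn: "b.log"}
	failed := cleanup(fs, []string{"a.log", "b.log", "c.log"})
	fmt.Println(fs.removed, failed) // → [a.log c.log] [b.log]
}
```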

Some files were not shown because too many files have changed in this diff.