Compare commits

..

72 Commits

Author SHA1 Message Date
kevin
2ea0a843f8 chore: remove any keywords 2023-03-04 20:54:26 +08:00
Kevin Wan
9e0e01b2bc chore: add tests (#2960) 2023-03-04 20:38:50 +08:00
yangjinheng
af50a80d01 timeout writer add hijack 2023-03-04 20:38:45 +08:00
yangjinheng
703fb8d970 Update timeouthandler.go 2023-03-04 20:38:40 +08:00
MarkJoyMa
e964e530e1 x 2023-03-04 20:32:21 +08:00
MarkJoyMa
52265087d1 x 2023-03-04 20:32:16 +08:00
MarkJoyMa
b4c2677eb9 add ut 2023-03-04 20:32:10 +08:00
MarkJoyMa
30296fb1ca feat: conf add FillDefault func 2023-03-04 20:31:44 +08:00
zhoumingji
356c80defd Fix bug in dartgen: The property 'isEmpty' can't be unconditionally accessed because the receiver can be 'null' 2023-03-04 20:31:38 +08:00
zhoumingji
8c31525378 Fix bug in dartgen: Increase the processing logic when route.RequestType is empty 2023-03-04 20:31:30 +08:00
cui fliter
2cf09f3c36 fix functiom name
Signed-off-by: cui fliter <imcusg@gmail.com>
2023-03-04 20:31:20 +08:00
Kevin Wan
d41e542c92 feat: support grpc client keepalive config (#2950) 2023-03-04 20:30:31 +08:00
tanglihao
265a24ac6d fix code format style use const config.DefaultFormat 2023-03-04 20:30:21 +08:00
tanglihao
7d88fc39dc fix log name conflict 2023-03-04 20:30:16 +08:00
anqiansong
6957b6a344 format code 2023-03-04 20:30:10 +08:00
anqiansong
bca6a230c8 remove unused code 2023-03-04 20:30:04 +08:00
anqiansong
cc8413d683 remove unused code 2023-03-04 20:29:56 +08:00
anqiansong
3842283fa8 Fix #2879 2023-03-04 20:29:41 +08:00
qiying.wang
fe13a533f5 chore: remove redundant prefix of "error: " in error creation 2023-03-04 20:26:40 +08:00
qiying.wang
7a327ccda4 chore: add tests for logc debug 2023-03-04 20:25:52 +08:00
qiying.wang
06e4507406 feat: add debug log for logc 2023-03-04 20:25:27 +08:00
kevin
8794d5b753 chore: add comments 2023-03-04 20:25:21 +08:00
kevin
9bfa63d995 chore: add more tests 2023-03-04 20:25:15 +08:00
kevin
a432b121fb chore: add more tests 2023-03-04 20:25:07 +08:00
kevin
b61c94bb66 feat: check key overwritten 2023-03-04 20:24:33 +08:00
Kevin Wan
93fcf899dc fix: config map cannot handle case-insensitive keys. (#2932)
* fix: #2922

* chore: rename const

* feat: support anonymous map field

* feat: support anonymous map field
2023-03-04 20:23:53 +08:00
Kevin Wan
9f4b3bae92 fix: #2899, using autoscaling/v2beta2 instead of v2beta1 (#2900)
* fix: #2899, using autoscaling/v2 instead of v2beta1

* chore: change hpa definition

---------

Co-authored-by: kevin.wan <kevin.wan@yijinin.com>
2023-03-04 20:22:27 +08:00
Kevin Wan
805cb87d98 chore: refine rest validator (#2928)
* chore: refine rest validator

* chore: add more tests

* chore: reformat code

* chore: add comments
2023-03-04 20:22:10 +08:00
Qiying Wang
366131640e feat: add configurable validator for httpx.Parse (#2923)
Co-authored-by: qiying.wang <qiying.wang@highlight.mobi>
2023-03-04 20:22:05 +08:00
Kevin Wan
956884a3ff fix: timeout not working if greater than global rest timeout (#2926) 2023-03-04 20:21:59 +08:00
raymonder jin
f571cb8af2 del unnecessary blank 2023-03-04 20:21:54 +08:00
Kevin Wan
cc5acf3b90 chore: reformat code (#2925) 2023-03-04 20:21:49 +08:00
chenquan
e1aa665443 fix: fixed the bug that old trace instances may be fetched 2023-03-04 20:21:43 +08:00
xiandong
cd357d9484 rm parseErr when kindJaeger 2023-03-04 20:21:28 +08:00
xiandong
6d4d7cbd6b rm kindJaegerUdp 2023-03-04 20:21:18 +08:00
xiandong
c593b5b531 add parseEndpoint 2023-03-04 20:20:29 +08:00
xiandong
fd5b38b07c add parseEndpoint 2023-03-04 20:20:17 +08:00
xiandong
41efb48f55 add test for Endpoint of kindJaegerUdp 2023-03-04 20:19:40 +08:00
xiandong
0ef3626839 add test for Endpoint of kindJaegerUdp 2023-03-04 20:19:34 +08:00
xiandong
77a72b16e9 add kindJaegerUdp 2023-03-04 20:19:25 +08:00
Kevin Wan
21566f1b7a chore: reformat code (#2903) 2023-03-04 20:17:35 +08:00
anqiansong
b2646e228b feat: Add request.ts (#2901)
* Add request.ts

* Update comments

* Refactor request filename
2023-03-04 20:17:21 +08:00
cong
588b883710 refactor: simplify sqlx fail fast ping and simplify miniredis setup in test (#2897)
* chore(redistest): simplify miniredis setup in test

* refactor(sqlx): simplify sqlx fail fast ping

* chore: close connection if not available
2023-03-04 20:17:16 +08:00
Kevin Wan
033910bbd8 Update readme-cn.md 2023-03-04 20:17:11 +08:00
fondoger
530dd79e3f Fix bug in dart api gen: path parameter is not replaced 2023-03-04 20:17:05 +08:00
Kevin Wan
cd5263ac75 Update readme-cn.md 2023-03-04 20:16:58 +08:00
Kevin Wan
ea3302a468 fix: test failures (#2892)
Co-authored-by: kevin.wan <kevin.wan@yijinin.com>
2023-03-04 20:16:50 +08:00
fondoger
abf15b373c Fix Dart API generation bugs; Add ability to generate API for path parameters (#2887)
* Fix bug in dartgen: Import path should match the generated api filename

* Use Route.HandlerName as generated dart API function name

Reasons:
- There is bug when using url path name as function name, because it may have invalid characters such as ":"
- Switching to HandlerName aligns with other languages such as typescript generation

* [DartGen] Add ability to generate api for url path parameters such as /path/:param
2023-03-04 20:16:44 +08:00
Kevin Wan
a865e9ee29 refactor: simplify stringx.Replacer, and avoid potential infinite loops (#2877)
* simplify replace

* backup

* refactor: simplify stringx.Replacer

* chore: add comments and const

* chore: add more tests

* chore: rename variable
2023-03-04 20:16:37 +08:00
Kevin Wan
f8292198cf Update readme-cn.md 2023-03-04 20:15:38 +08:00
Kevin Wan
016d965f56 chore: refactor (#2875) 2023-03-04 20:15:30 +08:00
dahaihu
95d7c73409 fix Replacer suffix match, and add test case (#2867)
* fix: replace shoud replace the longest match

* feat: revert bytes.Buffer to strings.Builder

* fix: loop reset nextStart

* feat: add node longest match test

* feat: add replacer suffix match test case

* feat: multiple match

* fix: partial match ends

* fix: replace look back upon error

* feat: rm unnecessary branch

---------

Co-authored-by: hudahai <hscxrzs@gmail.com>
Co-authored-by: hushichang <hushichang@sensetime.com>
2023-03-04 20:15:25 +08:00
Kevin Wan
939ef2a181 chore: add more tests (#2873) 2023-03-04 20:15:18 +08:00
Kevin Wan
f0b8dd45fe fix: test failure (#2874) 2023-03-04 20:15:08 +08:00
Mikael
0ba9335b04 only unmashal public variables (#2872)
* only unmashal public variables

* only unmashal public variables

* only unmashal public variables

* only unmashal public variables
2023-03-04 20:15:01 +08:00
Kevin Wan
04f181f0b4 chore: add more tests (#2866)
* chore: add more tests

* chore: add more tests

* chore: fix test failure
2023-03-04 20:14:54 +08:00
hudahai
89f841c126 fix: loop reset nextStart 2023-03-04 20:14:48 +08:00
hudahai
d785c8c377 feat: revert bytes.Buffer to strings.Builder 2023-03-04 20:14:41 +08:00
hudahai
687a1d15da fix: replace shoud replace the longest match 2023-03-04 20:14:35 +08:00
Kevin Wan
aaa974e1ad Update readme-cn.md 2023-03-04 20:14:22 +08:00
Kevin Wan
2779568ccf fix: conf anonymous overlay problem (#2847) 2023-03-04 20:14:10 +08:00
Kevin Wan
f7d50ae626 Update readme-cn.md 2023-03-04 20:14:01 +08:00
Kevin Wan
33594ea350 Chore/rewire (#2836)
* fix: problem on name overlaping in config (#2820)

* chore: fix missing funcs on windows (#2825)

* chore: add more tests (#2812)

* chore: add more tests

* chore: add more tests

* chore: add more tests (#2814)

* chore: add more tests (#2815)

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* feat: upgrade go to v1.18 (#2817)

* feat: upgrade go to v1.18

* feat: upgrade go to v1.18

* chore: change interface{} to any (#2818)

* chore: change interface{} to any

* chore: update goctl version to 1.5.0

* chore: update goctl deps

* chore: update goctl interface{} to any (#2819)

* chore: update goctl interface{} to any

* chore: update goctl interface{} to any

* chore(deps): bump google.golang.org/grpc from 1.52.0 to 1.52.3 (#2823)

* support custom maxBytes in API file (#2822)

* feat: mapreduce generic version (#2827)

* feat: mapreduce generic version

* fix: gateway mr type issue

---------

Co-authored-by: kevin.wan <kevin.wan@yijinin.com>

* feat: add MustNewRedis (#2824)

* feat: add MustNewRedis

* feat: add MustNewRedis

* feat: add MustNewRedis

* x

* x

* fix ut

* x

* x

* x

* x

* x

* chore: improve codecov (#2828)

* feat: converge grpc interceptor processing (#2830)

* feat: converge grpc interceptor processing

* x

* x

* chore(deps): bump go.opentelemetry.io/otel/exporters/zipkin (#2831)

* chore(deps): bump go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp (#2833)

Bumps [go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp](https://github.com/open-telemetry/opentelemetry-go) from 1.11.2 to 1.12.0.
- [Release notes](https://github.com/open-telemetry/opentelemetry-go/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-go/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-go/compare/v1.11.2...v1.12.0)

---
updated-dependencies:
- dependency-name: go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* chore(deps): bump go.opentelemetry.io/otel/exporters/jaeger (#2832)

Bumps [go.opentelemetry.io/otel/exporters/jaeger](https://github.com/open-telemetry/opentelemetry-go) from 1.11.2 to 1.12.0.
- [Release notes](https://github.com/open-telemetry/opentelemetry-go/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-go/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-go/compare/v1.11.2...v1.12.0)

---
updated-dependencies:
- dependency-name: go.opentelemetry.io/otel/exporters/jaeger
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Xiaoju Jiang <44432198+jiang4869@users.noreply.github.com>
Co-authored-by: kevin.wan <kevin.wan@yijinin.com>
Co-authored-by: MarkJoyMa <64180138+MarkJoyMa@users.noreply.github.com>
2023-03-04 20:13:37 +08:00
MarkJoyMa
ee2ec974c4 feat: converge grpc interceptor processing (#2830)
* feat: converge grpc interceptor processing

* x

* x
2023-03-04 20:12:30 +08:00
Kevin Wan
fd2f2f0f54 chore: improve codecov (#2828) 2023-03-04 20:12:16 +08:00
MarkJoyMa
86a2429d7d feat: add MustNewRedis (#2824)
* feat: add MustNewRedis

* feat: add MustNewRedis

* feat: add MustNewRedis

* x

* x

* fix ut

* x

* x

* x

* x

* x
2023-03-04 20:12:05 +08:00
Xiaoju Jiang
e5fe5dcc50 support custom maxBytes in API file (#2822) 2023-03-04 20:11:55 +08:00
Kevin Wan
b510e7c242 chore: fix missing funcs on windows (#2825) 2023-03-04 20:11:46 +08:00
Kevin Wan
dfe92e709f fix: problem on name overlaping in config (#2820) 2023-03-04 20:11:18 +08:00
Kevin Wan
cb649cf627 chore: add more tests (#2815)
* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests
2023-03-04 20:11:03 +08:00
Kevin Wan
ce19a5ade6 chore: add more tests (#2814) 2023-03-04 20:10:57 +08:00
Kevin Wan
6dc56de714 chore: add more tests (#2812)
* chore: add more tests

* chore: add more tests
2023-03-04 20:09:03 +08:00
827 changed files with 12596 additions and 50010 deletions

View File

@@ -1,13 +1,6 @@
coverage:
status:
patch: true
project: false # disabled because project coverage is not stable
comment:
layout: "flags, files"
behavior: once
require_changes: true
ignore:
- "tools"
- "**/mock"
- "**/*_mock.go"
- "**/*test"
- "tools"

View File

@@ -1,7 +1 @@
**/.git
.dockerignore
Dockerfile
goctl
Makefile
readme.md
readme-cn.md

12
.github/FUNDING.yml vendored
View File

@@ -1,3 +1,13 @@
# These are supported funding model platforms
github: [zeromicro]
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
custom: # https://gitee.com/kevwan/static/raw/master/images/sponsor.jpg
ethereum: 0x5052b7f6B937B02563996D23feb69b38D06Ca150 | kevwan

View File

@@ -5,19 +5,7 @@
version: 2
updates:
- package-ecosystem: "docker" # Update image tags in Dockerfile
directory: "/"
schedule:
interval: "weekly"
- package-ecosystem: "github-actions" # Update GitHub Actions
directory: "/"
schedule:
interval: "weekly"
- package-ecosystem: "gomod" # See documentation for possible values
directory: "/" # Location of package manifests
schedule:
interval: "daily"
- package-ecosystem: "gomod" # See documentation for possible values
directory: "/tools/goctl" # Location of package manifests
schedule:
interval: "daily"

View File

@@ -35,11 +35,11 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v3
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v3
uses: github/codeql-action/init@v2
with:
languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file.
@@ -50,7 +50,7 @@ jobs:
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
# If this step fails, then you should remove it and run the build manually (see below)
- name: Autobuild
uses: github/codeql-action/autobuild@v3
uses: github/codeql-action/autobuild@v2
# Command-line programs to run using the OS shell.
# 📚 https://git.io/JvXDl
@@ -64,4 +64,4 @@ jobs:
# make release
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v3
uses: github/codeql-action/analyze@v2

View File

@@ -12,12 +12,12 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Check out code into the Go module directory
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: Set up Go 1.x
uses: actions/setup-go@v5
uses: actions/setup-go@v3
with:
go-version-file: go.mod
go-version: ^1.16
check-latest: true
cache: true
id: go
@@ -29,31 +29,27 @@ jobs:
- name: Lint
run: |
go vet -stdmethods=false $(go list ./...)
go mod tidy
if ! test -z "$(git status --porcelain)"; then
echo "Please run 'go mod tidy'"
exit 1
fi
go install mvdan.cc/gofumpt@latest
test -z "$(gofumpt -l -extra .)" || echo "Please run 'gofumpt -l -w -extra .'"
- name: Test
run: go test -race -coverprofile=coverage.txt -covermode=atomic ./...
- name: Codecov
uses: codecov/codecov-action@v5
uses: codecov/codecov-action@v3
test-win:
name: Windows
runs-on: windows-latest
steps:
- name: Checkout codebase
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: Set up Go 1.x
uses: actions/setup-go@v5
uses: actions/setup-go@v3
with:
# make sure Go version compatible with go-zero
go-version-file: go.mod
# use 1.16 to guarantee Go 1.16 compatibility
go-version: 1.16
check-latest: true
cache: true
@@ -61,5 +57,5 @@ jobs:
run: |
go mod verify
go mod download
go test ./...
go test -v -race ./...
cd tools/goctl && go build -v goctl.go

18
.github/workflows/issue-translator.yml vendored Normal file
View File

@@ -0,0 +1,18 @@
name: 'issue-translator'
on:
issue_comment:
types: [created]
issues:
types: [opened]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: usthe/issues-translate-action@v2.7
with:
IS_MODIFY_TITLE: true
# not require, default false, . Decide whether to modify the issue title
# if true, the robot account @Issues-translate-bot must have modification permissions, invite @Issues-translate-bot to your project or use your custom bot.
CUSTOM_BOT_NOTE: Bot detected the issue body's language is not English, translate it automatically. 👯👭🏻🧑‍🤝‍🧑👫🧑🏿‍🤝‍🧑🏻👩🏾‍🤝‍👨🏿👬🏿
# not require. Customize the translation robot prefix message.

View File

@@ -7,7 +7,7 @@ jobs:
close-issues:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@v9
- uses: actions/stale@v6
with:
days-before-issue-stale: 365
days-before-issue-close: 90

View File

@@ -16,13 +16,13 @@ jobs:
- goarch: "386"
goos: darwin
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v3
- uses: zeromicro/go-zero-release-action@master
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
goos: ${{ matrix.goos }}
goarch: ${{ matrix.goarch }}
goversion: "https://dl.google.com/go/go1.21.13.linux-amd64.tar.gz"
goversion: "https://dl.google.com/go/go1.17.5.linux-amd64.tar.gz"
project_path: "tools/goctl"
binary_name: "goctl"
extra_files: tools/goctl/readme.md tools/goctl/readme-cn.md
extra_files: tools/goctl/readme.md tools/goctl/readme-cn.md

View File

@@ -5,7 +5,7 @@ jobs:
name: runner / staticcheck
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v3
- uses: reviewdog/action-staticcheck@v1
with:
github_token: ${{ secrets.github_token }}
@@ -14,6 +14,6 @@ jobs:
# Report all results.
filter_mode: nofilter
# Exit with 1 when it find at least one finding.
fail_level: any
fail_on_error: true
# Set staticcheck flags
staticcheck_flags: -checks=inherit,-SA1019,-SA1029,-SA5008

View File

@@ -1,42 +0,0 @@
name: Release Version Check
on:
push:
tags:
- 'tools/goctl/v*'
workflow_dispatch:
jobs:
version-check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: '1.21'
- name: Extract tag version
id: get_version
run: |
# Extract version from tools/goctl/v* format
VERSION="${GITHUB_REF#refs/tags/tools/goctl/v}"
echo "VERSION=$VERSION" >> $GITHUB_ENV
echo "Extracted version: $VERSION"
- name: Check version in goctl source code
run: |
# Change to goctl directory
cd tools/goctl
# Check version in BuildVersion constant
VERSION_IN_CODE=$(grep -r "const BuildVersion =" . | grep -o '".*"' | tr -d '"')
echo "Version in code: $VERSION_IN_CODE"
echo "Expected version: $VERSION"
if [ "$VERSION_IN_CODE" != "$VERSION" ]; then
echo "Version mismatch: Version in code ($VERSION_IN_CODE) doesn't match tag version ($VERSION)"
exit 1
fi
echo "✅ Version check passed!"

6
.gitignore vendored
View File

@@ -11,14 +11,12 @@
!api
# ignore
**/.idea
**/.vscode
.idea
**/.DS_Store
**/logs
**/adhoc
**/coverage.txt
# for test purpose
**/adhoc
go.work
go.work.sum

View File

@@ -1,76 +1,102 @@
# 🚀 Contributing to go-zero
# Contributing
Welcome to the go-zero community! We're thrilled to have you here. Contributing to our project is a fantastic way to be a part of the go-zero journey. Let's make this guide exciting and fun!
Welcome to go-zero!
## 📜 Before You Dive In
- [Before you get started](#before-you-get-started)
- [Code of Conduct](#code-of-conduct)
- [Community Expectations](#community-expectations)
- [Getting started](#getting-started)
- [Your First Contribution](#your-first-contribution)
- [Find something to work on](#find-something-to-work-on)
- [Find a good first topic](#find-a-good-first-topic)
- [Work on an Issue](#work-on-an-issue)
- [File an Issue](#file-an-issue)
- [Contributor Workflow](#contributor-workflow)
- [Creating Pull Requests](#creating-pull-requests)
- [Code Review](#code-review)
- [Testing](#testing)
### 🤝 Code of Conduct
# Before you get started
Let's start on the right foot. Please take a moment to read and embrace our [Code of Conduct](/code-of-conduct.md). We're all about creating a welcoming and respectful environment.
## Code of Conduct
### 🌟 Community Expectations
Please make sure to read and observe our [Code of Conduct](/code-of-conduct.md).
At go-zero, we're like a close-knit family, and we believe in creating a healthy, friendly, and productive atmosphere. It's all about sharing knowledge and building amazing things together.
## Community Expectations
## 🚀 Getting Started
go-zero is a community project driven by its community which strives to promote a healthy, friendly and productive environment.
go-zero is a web and rpc framework written in Go. It's born to ensure the stability of the busy sites with resilient design. Builtin goctl greatly improves the development productivity.
Get your adventure rolling! Here's how to begin:
# Getting started
1. 🍴 **Fork the Repository**: Head over to the GitHub repository and fork it to your own space.
- Fork the repository on GitHub.
- Make your changes on your fork repository.
- Submit a PR.
2. 🛠️ **Make Your Magic**: Work your magic in your forked repository. Create new features, squash bugs, or improve documentation - it's your world to conquer!
3. 🚀 **Submit a PR (Pull Request)**: When you're ready to unveil your creation, submit a Pull Request. We can't wait to see your awesome work!
# Your First Contribution
## 🌟 Your First Contribution
We will help you to contribute in different areas like filing issues, developing features, fixing critical bugs and
getting your work reviewed and merged.
We're here to guide you on your quest to become a go-zero contributor. Whether you want to file issues, develop features, or tame some critical bugs, we've got you covered.
If you have questions about the development process,
feel free to [file an issue](https://github.com/zeromicro/go-zero/issues/new/choose).
If you have questions or need guidance at any stage, don't hesitate to [open an issue](https://github.com/zeromicro/go-zero/issues/new/choose).
## Find something to work on
## 🔍 Find Something to Work On
We are always in need of help, be it fixing documentation, reporting bugs or writing some code.
Look at places where you feel best coding practices aren't followed, code refactoring is needed or tests are missing.
Here is how you get started.
Ready to dive into the action? There are several ways to contribute:
### Find a good first topic
### 💼 Find a Good First Topic
[go-zero](https://github.com/zeromicro/go-zero) has beginner-friendly issues that provide a good first issue.
For example, [go-zero](https://github.com/zeromicro/go-zero) has
[help wanted](https://github.com/zeromicro/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) and
[good first issue](https://github.com/zeromicro/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)
labels for issues that should not need deep knowledge of the system.
We can help new contributors who wish to work on such issues.
Discover easy-entry issues labeled as [help wanted](https://github.com/zeromicro/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) or [good first issue](https://github.com/zeromicro/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22). These issues are perfect for newcomers and don't require deep knowledge of the system. We're here to assist you with these tasks.
Another good way to contribute is to find a documentation improvement, such as a missing/broken link.
Please see [Contributing](#contributing) below for the workflow.
### 🪄 Work on an Issue
#### Work on an issue
Once you've picked an issue that excites you, let us know by commenting on it. Our maintainers will assign it to you, and you can embark on your mission!
When you are willing to take on an issue, just reply on the issue. The maintainer will assign it to you.
### 📢 File an Issue
### File an Issue
Reporting an issue is just as valuable as code contributions. If you discover a problem, don't hesitate to [open an issue](https://github.com/zeromicro/go-zero/issues/new/choose). Be sure to follow our guidelines when submitting an issue.
While we encourage everyone to contribute code, it is also appreciated when someone reports an issue.
## 🎯 Contributor Workflow
Please follow the prompted submission guidelines while opening an issue.
Here's a rough guide to your contributor journey:
# Contributor Workflow
1. 🌱 Create a New Branch: Start by creating a topic branch, usually based on the 'master' branch. This is where your contribution will grow.
Please do not ever hesitate to ask a question or send a pull request.
2. 💡 Make Commits: Commit your work in logical units. Each commit should tell a story.
This is a rough outline of what a contributor's workflow looks like:
3. 🚀 Push Changes: Push the changes in your topic branch to your personal fork of the repository.
- Create a topic branch from where to base the contribution. This is usually master.
- Make commits of logical units.
- Push changes in a topic branch to a personal fork of the repository.
- Submit a pull request to [go-zero](https://github.com/zeromicro/go-zero).
4. 📦 Submit a Pull Request: When your creation is complete, submit a Pull Request to the [go-zero repository](https://github.com/zeromicro/go-zero).
## Creating Pull Requests
## 🌠 Creating Pull Requests
Pull requests are often called simply "PR".
go-zero generally follows the standard [github pull request](https://help.github.com/articles/about-pull-requests/) process.
To submit a proposed change, please develop the code/fix and add new test cases.
After that, run these local verifications before submitting pull request to predict the pass or
fail of continuous integration.
Pull Requests (PRs) are your way of making a grand entrance with your contribution. Here's how to do it:
* Format the code with `gofmt`
* Run the test with data race enabled `go test -race ./...`
- 💼 Format Your Code: Ensure your code is beautifully formatted with `gofmt`.
- 🏃 Run Tests: Verify that your changes pass all the tests, including data race tests. Run `go test -race ./...` for the ultimate validation.
## Code Review
## 👁️‍🗨️ Code Review
To make it easier for your PR to receive reviews, consider the reviewers will need you to:
Getting your PR reviewed is the final step before your contribution becomes part of go-zero's magical world. To make the process smooth, keep these things in mind:
* follow [good coding guidelines](https://github.com/golang/go/wiki/CodeReviewComments).
* write [good commit messages](https://chris.beams.io/posts/git-commit/).
* break large changes into a logical series of smaller patches which individually make easily understandable changes, and in aggregate solve a broader issue.
- 🧙‍♀️ Follow Good Coding Practices: Stick to [good coding guidelines](https://github.com/golang/go/wiki/CodeReviewComments).
- 📝 Write Awesome Commit Messages: Craft [impressive commit messages](https://chris.beams.io/posts/git-commit/) - they're like spells in the wizard's book!
- 🔍 Break It Down: For larger changes, consider breaking them into a series of smaller, logical patches. Each patch should make an understandable and meaningful improvement.
Congratulations on your contribution journey! We're thrilled to have you as part of our go-zero community. Let's make amazing things together! 🌟
Now, go out there and start your adventure! If you have any more magical ideas to enhance this guide, please share them. 🔥

View File

@@ -1,16 +0,0 @@
# Security Policy
## Supported Versions
We publish releases monthly.
| Version | Supported |
| ------- | ------------------ |
| >= 1.4.4 | :white_check_mark: |
| < 1.4.4 | :x: |
## Reporting a Vulnerability
https://github.com/zeromicro/go-zero/security/advisories
Accepted vulnerabilities are expected to be fixed within a month.

View File

@@ -1,127 +1,76 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, caste, color, religion, or sexual
identity and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
Examples of behavior that contributes to creating a positive environment
include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the overall
community
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior include:
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery, and sexual attention or advances of
any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email address,
without their explicit permission
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
## Our Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
[INSERT CONTACT METHOD].
All complaints will be reviewed and investigated promptly and fairly.
reported by contacting the project team at [INSERT EMAIL ADDRESS]. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series of
actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or permanent
ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within the
community.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.1, available at
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
Community Impact Guidelines were inspired by
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
[https://www.contributor-covenant.org/translations][translations].
For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq

View File

@@ -1,8 +1,6 @@
package bloom
import (
"context"
_ "embed"
"errors"
"strconv"
@@ -10,23 +8,28 @@ import (
"github.com/zeromicro/go-zero/core/stores/redis"
)
// for detailed error rate table, see http://pages.cs.wisc.edu/~cao/papers/summary-cache/node8.html
// maps as k in the error rate table
const maps = 14
var (
// ErrTooLargeOffset indicates the offset is too large in bitset.
ErrTooLargeOffset = errors.New("too large offset")
//go:embed setscript.lua
setLuaScript string
setScript = redis.NewScript(setLuaScript)
//go:embed testscript.lua
testLuaScript string
testScript = redis.NewScript(testLuaScript)
const (
// for detailed error rate table, see http://pages.cs.wisc.edu/~cao/papers/summary-cache/node8.html
// maps as k in the error rate table
maps = 14
setScript = `
for _, offset in ipairs(ARGV) do
redis.call("setbit", KEYS[1], offset, 1)
end
`
testScript = `
for _, offset in ipairs(ARGV) do
if tonumber(redis.call("getbit", KEYS[1], offset)) == 0 then
return false
end
end
return true
`
)
// ErrTooLargeOffset indicates the offset is too large in bitset.
var ErrTooLargeOffset = errors.New("too large offset")
type (
// A Filter is a bloom filter.
Filter struct {
@@ -35,8 +38,8 @@ type (
}
bitSetProvider interface {
check(ctx context.Context, offsets []uint) (bool, error)
set(ctx context.Context, offsets []uint) error
check([]uint) (bool, error)
set([]uint) error
}
)
@@ -55,24 +58,14 @@ func New(store *redis.Redis, key string, bits uint) *Filter {
// Add adds data into f.
func (f *Filter) Add(data []byte) error {
return f.AddCtx(context.Background(), data)
}
// AddCtx adds data into f with context.
func (f *Filter) AddCtx(ctx context.Context, data []byte) error {
locations := f.getLocations(data)
return f.bitSet.set(ctx, locations)
return f.bitSet.set(locations)
}
// Exists checks if data is in f.
func (f *Filter) Exists(data []byte) (bool, error) {
return f.ExistsCtx(context.Background(), data)
}
// ExistsCtx checks if data is in f with context.
func (f *Filter) ExistsCtx(ctx context.Context, data []byte) (bool, error) {
locations := f.getLocations(data)
isSet, err := f.bitSet.check(ctx, locations)
isSet, err := f.bitSet.check(locations)
if err != nil {
return false, err
}
@@ -105,7 +98,7 @@ func newRedisBitSet(store *redis.Redis, key string, bits uint) *redisBitSet {
}
func (r *redisBitSet) buildOffsetArgs(offsets []uint) ([]string, error) {
args := make([]string, 0, len(offsets))
var args []string
for _, offset := range offsets {
if offset >= r.bits {
@@ -118,14 +111,14 @@ func (r *redisBitSet) buildOffsetArgs(offsets []uint) ([]string, error) {
return args, nil
}
func (r *redisBitSet) check(ctx context.Context, offsets []uint) (bool, error) {
func (r *redisBitSet) check(offsets []uint) (bool, error) {
args, err := r.buildOffsetArgs(offsets)
if err != nil {
return false, err
}
resp, err := r.store.ScriptRunCtx(ctx, testScript, []string{r.key}, args)
if errors.Is(err, redis.Nil) {
resp, err := r.store.Eval(testScript, []string{r.key}, args)
if err == redis.Nil {
return false, nil
} else if err != nil {
return false, err
@@ -139,25 +132,23 @@ func (r *redisBitSet) check(ctx context.Context, offsets []uint) (bool, error) {
return exists == 1, nil
}
// del only use for testing.
func (r *redisBitSet) del() error {
_, err := r.store.Del(r.key)
return err
}
// expire only use for testing.
func (r *redisBitSet) expire(seconds int) error {
return r.store.Expire(r.key, seconds)
}
func (r *redisBitSet) set(ctx context.Context, offsets []uint) error {
func (r *redisBitSet) set(offsets []uint) error {
args, err := r.buildOffsetArgs(offsets)
if err != nil {
return err
}
_, err = r.store.ScriptRunCtx(ctx, setScript, []string{r.key}, args)
if errors.Is(err, redis.Nil) {
_, err = r.store.Eval(setScript, []string{r.key}, args)
if err == redis.Nil {
return nil
}

View File

@@ -1,31 +1,30 @@
package bloom
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/redis/redistest"
)
func TestRedisBitSet_New_Set_Test(t *testing.T) {
store := redistest.CreateRedis(t)
ctx := context.Background()
store, clean, err := redistest.CreateRedis()
assert.Nil(t, err)
defer clean()
bitSet := newRedisBitSet(store, "test_key", 1024)
isSetBefore, err := bitSet.check(ctx, []uint{0})
isSetBefore, err := bitSet.check([]uint{0})
if err != nil {
t.Fatal(err)
}
if isSetBefore {
t.Fatal("Bit should not be set")
}
err = bitSet.set(ctx, []uint{512})
err = bitSet.set([]uint{512})
if err != nil {
t.Fatal(err)
}
isSetAfter, err := bitSet.check(ctx, []uint{512})
isSetAfter, err := bitSet.check([]uint{512})
if err != nil {
t.Fatal(err)
}
@@ -43,7 +42,9 @@ func TestRedisBitSet_New_Set_Test(t *testing.T) {
}
func TestRedisBitSet_Add(t *testing.T) {
store := redistest.CreateRedis(t)
store, clean, err := redistest.CreateRedis()
assert.Nil(t, err)
defer clean()
filter := New(store, "test_key", 64)
assert.Nil(t, filter.Add([]byte("hello")))
@@ -52,51 +53,3 @@ func TestRedisBitSet_Add(t *testing.T) {
assert.Nil(t, err)
assert.True(t, ok)
}
func TestFilter_Exists(t *testing.T) {
store, clean := redistest.CreateRedisWithClean(t)
rbs := New(store, "test", 64)
_, err := rbs.Exists([]byte{0, 1, 2})
assert.NoError(t, err)
clean()
rbs = New(store, "test", 64)
_, err = rbs.Exists([]byte{0, 1, 2})
assert.Error(t, err)
}
func TestRedisBitSet_check(t *testing.T) {
store, clean := redistest.CreateRedisWithClean(t)
ctx := context.Background()
rbs := newRedisBitSet(store, "test", 0)
assert.Error(t, rbs.set(ctx, []uint{0, 1, 2}))
_, err := rbs.check(ctx, []uint{0, 1, 2})
assert.Error(t, err)
rbs = newRedisBitSet(store, "test", 64)
_, err = rbs.check(ctx, []uint{0, 1, 2})
assert.NoError(t, err)
clean()
rbs = newRedisBitSet(store, "test", 64)
_, err = rbs.check(ctx, []uint{0, 1, 2})
assert.Error(t, err)
}
func TestRedisBitSet_set(t *testing.T) {
logx.Disable()
store, clean := redistest.CreateRedisWithClean(t)
ctx := context.Background()
rbs := newRedisBitSet(store, "test", 0)
assert.Error(t, rbs.set(ctx, []uint{0, 1, 2}))
rbs = newRedisBitSet(store, "test", 64)
assert.NoError(t, rbs.set(ctx, []uint{0, 1, 2}))
clean()
rbs = newRedisBitSet(store, "test", 64)
assert.Error(t, rbs.set(ctx, []uint{0, 1, 2}))
}

View File

@@ -1,3 +0,0 @@
for _, offset in ipairs(ARGV) do
redis.call("setbit", KEYS[1], offset, 1)
end

View File

@@ -1,6 +0,0 @@
for _, offset in ipairs(ARGV) do
if tonumber(redis.call("getbit", KEYS[1], offset)) == 0 then
return false
end
end
return true

View File

@@ -1,7 +1,6 @@
package breaker
import (
"context"
"errors"
"fmt"
"strings"
@@ -32,53 +31,38 @@ type (
Name() string
// Allow checks if the request is allowed.
// If allowed, a promise will be returned,
// otherwise ErrServiceUnavailable will be returned as the error.
// The caller needs to call promise.Accept() on success,
// or call promise.Reject() on failure.
// If allowed, a promise will be returned, the caller needs to call promise.Accept()
// on success, or call promise.Reject() on failure.
// If not allow, ErrServiceUnavailable will be returned.
Allow() (Promise, error)
// AllowCtx checks if the request is allowed when ctx isn't done.
AllowCtx(ctx context.Context) (Promise, error)
// Do runs the given request if the Breaker accepts it.
// Do returns an error instantly if the Breaker rejects the request.
// If a panic occurs in the request, the Breaker handles it as an error
// and causes the same panic again.
Do(req func() error) error
// DoCtx runs the given request if the Breaker accepts it when ctx isn't done.
DoCtx(ctx context.Context, req func() error) error
// DoWithAcceptable runs the given request if the Breaker accepts it.
// DoWithAcceptable returns an error instantly if the Breaker rejects the request.
// If a panic occurs in the request, the Breaker handles it as an error
// and causes the same panic again.
// acceptable checks if it's a successful call, even if the error is not nil.
// acceptable checks if it's a successful call, even if the err is not nil.
DoWithAcceptable(req func() error, acceptable Acceptable) error
// DoWithAcceptableCtx runs the given request if the Breaker accepts it when ctx isn't done.
DoWithAcceptableCtx(ctx context.Context, req func() error, acceptable Acceptable) error
// DoWithFallback runs the given request if the Breaker accepts it.
// DoWithFallback runs the fallback if the Breaker rejects the request.
// If a panic occurs in the request, the Breaker handles it as an error
// and causes the same panic again.
DoWithFallback(req func() error, fallback Fallback) error
// DoWithFallbackCtx runs the given request if the Breaker accepts it when ctx isn't done.
DoWithFallbackCtx(ctx context.Context, req func() error, fallback Fallback) error
DoWithFallback(req func() error, fallback func(err error) error) error
// DoWithFallbackAcceptable runs the given request if the Breaker accepts it.
// DoWithFallbackAcceptable runs the fallback if the Breaker rejects the request.
// If a panic occurs in the request, the Breaker handles it as an error
// and causes the same panic again.
// acceptable checks if it's a successful call, even if the error is not nil.
DoWithFallbackAcceptable(req func() error, fallback Fallback, acceptable Acceptable) error
// DoWithFallbackAcceptableCtx runs the given request if the Breaker accepts it when ctx isn't done.
DoWithFallbackAcceptableCtx(ctx context.Context, req func() error, fallback Fallback,
acceptable Acceptable) error
// acceptable checks if it's a successful call, even if the err is not nil.
DoWithFallbackAcceptable(req func() error, fallback func(err error) error, acceptable Acceptable) error
}
// Fallback is the func to be called if the request is rejected.
Fallback func(err error) error
// Option defines the method to customize a Breaker.
Option func(breaker *circuitBreaker)
@@ -102,12 +86,12 @@ type (
internalThrottle interface {
allow() (internalPromise, error)
doReq(req func() error, fallback Fallback, acceptable Acceptable) error
doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error
}
throttle interface {
allow() (Promise, error)
doReq(req func() error, fallback Fallback, acceptable Acceptable) error
doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error
}
)
@@ -130,71 +114,23 @@ func (cb *circuitBreaker) Allow() (Promise, error) {
return cb.throttle.allow()
}
func (cb *circuitBreaker) AllowCtx(ctx context.Context) (Promise, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
return cb.Allow()
}
}
func (cb *circuitBreaker) Do(req func() error) error {
return cb.throttle.doReq(req, nil, defaultAcceptable)
}
func (cb *circuitBreaker) DoCtx(ctx context.Context, req func() error) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
return cb.Do(req)
}
}
func (cb *circuitBreaker) DoWithAcceptable(req func() error, acceptable Acceptable) error {
return cb.throttle.doReq(req, nil, acceptable)
}
func (cb *circuitBreaker) DoWithAcceptableCtx(ctx context.Context, req func() error,
acceptable Acceptable) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
return cb.DoWithAcceptable(req, acceptable)
}
}
func (cb *circuitBreaker) DoWithFallback(req func() error, fallback Fallback) error {
func (cb *circuitBreaker) DoWithFallback(req func() error, fallback func(err error) error) error {
return cb.throttle.doReq(req, fallback, defaultAcceptable)
}
func (cb *circuitBreaker) DoWithFallbackCtx(ctx context.Context, req func() error,
fallback Fallback) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
return cb.DoWithFallback(req, fallback)
}
}
func (cb *circuitBreaker) DoWithFallbackAcceptable(req func() error, fallback Fallback,
func (cb *circuitBreaker) DoWithFallbackAcceptable(req func() error, fallback func(err error) error,
acceptable Acceptable) error {
return cb.throttle.doReq(req, fallback, acceptable)
}
func (cb *circuitBreaker) DoWithFallbackAcceptableCtx(ctx context.Context, req func() error,
fallback Fallback, acceptable Acceptable) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
return cb.DoWithFallbackAcceptable(req, fallback, acceptable)
}
}
func (cb *circuitBreaker) Name() string {
return cb.name
}
@@ -232,7 +168,7 @@ func (lt loggedThrottle) allow() (Promise, error) {
}, lt.logError(err)
}
func (lt loggedThrottle) doReq(req func() error, fallback Fallback, acceptable Acceptable) error {
func (lt loggedThrottle) doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error {
return lt.logError(lt.internalThrottle.doReq(req, fallback, func(err error) bool {
accept := acceptable(err)
if !accept && err != nil {
@@ -243,7 +179,7 @@ func (lt loggedThrottle) doReq(req func() error, fallback Fallback, acceptable A
}
func (lt loggedThrottle) logError(err error) error {
if errors.Is(err, ErrServiceUnavailable) {
if err == ErrServiceUnavailable {
// if circuit open, not possible to have empty error window
stat.Report(fmt.Sprintf(
"proc(%s/%d), callee: %s, breaker is open and requests dropped\nlast errors:\n%s",
@@ -269,7 +205,7 @@ func (ew *errorWindow) add(reason string) {
}
func (ew *errorWindow) String() string {
reasons := make([]string, 0, ew.count)
var reasons []string
ew.lock.Lock()
// reverse order

View File

@@ -1,13 +1,11 @@
package breaker
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stat"
@@ -18,274 +16,10 @@ func init() {
}
func TestCircuitBreaker_Allow(t *testing.T) {
t.Run("allow", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
_, err := b.Allow()
assert.Nil(t, err)
})
t.Run("allow with ctx", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
_, err := b.AllowCtx(context.Background())
assert.Nil(t, err)
})
t.Run("allow with ctx timeout", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
defer cancel()
time.Sleep(time.Millisecond)
_, err := b.AllowCtx(ctx)
assert.ErrorIs(t, err, context.DeadlineExceeded)
})
t.Run("allow with ctx cancel", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
for i := 0; i < 100; i++ {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
cancel()
_, err := b.AllowCtx(ctx)
assert.ErrorIs(t, err, context.Canceled)
}
_, err := b.AllowCtx(context.Background())
assert.NoError(t, err)
})
}
func TestCircuitBreaker_Do(t *testing.T) {
t.Run("do", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
err := b.Do(func() error {
return nil
})
assert.Nil(t, err)
})
t.Run("do with ctx", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
err := b.DoCtx(context.Background(), func() error {
return nil
})
assert.Nil(t, err)
})
t.Run("do with ctx timeout", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
defer cancel()
time.Sleep(time.Millisecond)
err := b.DoCtx(ctx, func() error {
return nil
})
assert.ErrorIs(t, err, context.DeadlineExceeded)
})
t.Run("do with ctx cancel", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
for i := 0; i < 100; i++ {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
cancel()
err := b.DoCtx(ctx, func() error {
return nil
})
assert.ErrorIs(t, err, context.Canceled)
}
assert.NoError(t, b.DoCtx(context.Background(), func() error {
return nil
}))
})
}
func TestCircuitBreaker_DoWithAcceptable(t *testing.T) {
t.Run("doWithAcceptable", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
err := b.DoWithAcceptable(func() error {
return nil
}, func(err error) bool {
return true
})
assert.Nil(t, err)
})
t.Run("doWithAcceptable with ctx", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
err := b.DoWithAcceptableCtx(context.Background(), func() error {
return nil
}, func(err error) bool {
return true
})
assert.Nil(t, err)
})
t.Run("doWithAcceptable with ctx timeout", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
defer cancel()
time.Sleep(time.Millisecond)
err := b.DoWithAcceptableCtx(ctx, func() error {
return nil
}, func(err error) bool {
return true
})
assert.ErrorIs(t, err, context.DeadlineExceeded)
})
t.Run("doWithAcceptable with ctx cancel", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
for i := 0; i < 100; i++ {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
cancel()
err := b.DoWithAcceptableCtx(ctx, func() error {
return nil
}, func(err error) bool {
return true
})
assert.ErrorIs(t, err, context.Canceled)
}
assert.NoError(t, b.DoWithAcceptableCtx(context.Background(), func() error {
return nil
}, func(err error) bool {
return true
}))
})
}
func TestCircuitBreaker_DoWithFallback(t *testing.T) {
t.Run("doWithFallback", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
err := b.DoWithFallback(func() error {
return nil
}, func(err error) error {
return err
})
assert.Nil(t, err)
})
t.Run("doWithFallback with ctx", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
err := b.DoWithFallbackCtx(context.Background(), func() error {
return nil
}, func(err error) error {
return err
})
assert.Nil(t, err)
})
t.Run("doWithFallback with ctx timeout", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
defer cancel()
time.Sleep(time.Millisecond)
err := b.DoWithFallbackCtx(ctx, func() error {
return nil
}, func(err error) error {
return err
})
assert.ErrorIs(t, err, context.DeadlineExceeded)
})
t.Run("doWithFallback with ctx cancel", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
for i := 0; i < 100; i++ {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
cancel()
err := b.DoWithFallbackCtx(ctx, func() error {
return nil
}, func(err error) error {
return err
})
assert.ErrorIs(t, err, context.Canceled)
}
assert.NoError(t, b.DoWithFallbackCtx(context.Background(), func() error {
return nil
}, func(err error) error {
return err
}))
})
}
func TestCircuitBreaker_DoWithFallbackAcceptable(t *testing.T) {
t.Run("doWithFallbackAcceptable", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
err := b.DoWithFallbackAcceptable(func() error {
return nil
}, func(err error) error {
return err
}, func(err error) bool {
return true
})
assert.Nil(t, err)
})
t.Run("doWithFallbackAcceptable with ctx", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
err := b.DoWithFallbackAcceptableCtx(context.Background(), func() error {
return nil
}, func(err error) error {
return err
}, func(err error) bool {
return true
})
assert.Nil(t, err)
})
t.Run("doWithFallbackAcceptable with ctx timeout", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
defer cancel()
time.Sleep(time.Millisecond)
err := b.DoWithFallbackAcceptableCtx(ctx, func() error {
return nil
}, func(err error) error {
return err
}, func(err error) bool {
return true
})
assert.ErrorIs(t, err, context.DeadlineExceeded)
})
t.Run("doWithFallbackAcceptable with ctx cancel", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
for i := 0; i < 100; i++ {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
cancel()
err := b.DoWithFallbackAcceptableCtx(ctx, func() error {
return nil
}, func(err error) error {
return err
}, func(err error) bool {
return true
})
assert.ErrorIs(t, err, context.Canceled)
}
assert.NoError(t, b.DoWithFallbackAcceptableCtx(context.Background(), func() error {
return nil
}, func(err error) error {
return err
}, func(err error) bool {
return true
}))
})
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
_, err := b.Allow()
assert.Nil(t, err)
}
func TestLogReason(t *testing.T) {

View File

@@ -1,9 +1,6 @@
package breaker
import (
"context"
"sync"
)
import "sync"
var (
lock sync.RWMutex
@@ -17,13 +14,6 @@ func Do(name string, req func() error) error {
})
}
// DoCtx calls Breaker.DoCtx on the Breaker with given name.
func DoCtx(ctx context.Context, name string, req func() error) error {
return do(name, func(b Breaker) error {
return b.DoCtx(ctx, req)
})
}
// DoWithAcceptable calls Breaker.DoWithAcceptable on the Breaker with given name.
func DoWithAcceptable(name string, req func() error, acceptable Acceptable) error {
return do(name, func(b Breaker) error {
@@ -31,44 +21,21 @@ func DoWithAcceptable(name string, req func() error, acceptable Acceptable) erro
})
}
// DoWithAcceptableCtx calls Breaker.DoWithAcceptableCtx on the Breaker with given name.
func DoWithAcceptableCtx(ctx context.Context, name string, req func() error,
acceptable Acceptable) error {
return do(name, func(b Breaker) error {
return b.DoWithAcceptableCtx(ctx, req, acceptable)
})
}
// DoWithFallback calls Breaker.DoWithFallback on the Breaker with given name.
func DoWithFallback(name string, req func() error, fallback Fallback) error {
func DoWithFallback(name string, req func() error, fallback func(err error) error) error {
return do(name, func(b Breaker) error {
return b.DoWithFallback(req, fallback)
})
}
// DoWithFallbackCtx calls Breaker.DoWithFallbackCtx on the Breaker with given name.
func DoWithFallbackCtx(ctx context.Context, name string, req func() error, fallback Fallback) error {
return do(name, func(b Breaker) error {
return b.DoWithFallbackCtx(ctx, req, fallback)
})
}
// DoWithFallbackAcceptable calls Breaker.DoWithFallbackAcceptable on the Breaker with given name.
func DoWithFallbackAcceptable(name string, req func() error, fallback Fallback,
func DoWithFallbackAcceptable(name string, req func() error, fallback func(err error) error,
acceptable Acceptable) error {
return do(name, func(b Breaker) error {
return b.DoWithFallbackAcceptable(req, fallback, acceptable)
})
}
// DoWithFallbackAcceptableCtx calls Breaker.DoWithFallbackAcceptableCtx on the Breaker with given name.
func DoWithFallbackAcceptableCtx(ctx context.Context, name string, req func() error,
fallback Fallback, acceptable Acceptable) error {
return do(name, func(b Breaker) error {
return b.DoWithFallbackAcceptableCtx(ctx, req, fallback, acceptable)
})
}
// GetBreaker returns the Breaker with the given name.
func GetBreaker(name string) Breaker {
lock.RLock()
@@ -92,7 +59,7 @@ func GetBreaker(name string) Breaker {
// NoBreakerFor disables the circuit breaker for the given name.
func NoBreakerFor(name string) {
lock.Lock()
breakers[name] = NopBreaker()
breakers[name] = newNoOpBreaker()
lock.Unlock()
}

View File

@@ -1,7 +1,6 @@
package breaker
import (
"context"
"errors"
"fmt"
"testing"
@@ -23,9 +22,6 @@ func TestBreakersDo(t *testing.T) {
assert.Equal(t, errDummy, Do("any", func() error {
return errDummy
}))
assert.Equal(t, errDummy, DoCtx(context.Background(), "any", func() error {
return errDummy
}))
}
func TestBreakersDoWithAcceptable(t *testing.T) {
@@ -34,7 +30,7 @@ func TestBreakersDoWithAcceptable(t *testing.T) {
assert.Equal(t, errDummy, GetBreaker("anyone").DoWithAcceptable(func() error {
return errDummy
}, func(err error) bool {
return err == nil || errors.Is(err, errDummy)
return err == nil || err == errDummy
}))
}
verify(t, func() bool {
@@ -42,13 +38,6 @@ func TestBreakersDoWithAcceptable(t *testing.T) {
return nil
}) == nil
})
verify(t, func() bool {
return DoWithAcceptableCtx(context.Background(), "anyone", func() error {
return nil
}, func(err error) bool {
return true
}) == nil
})
for i := 0; i < 10000; i++ {
err := DoWithAcceptable("another", func() error {
@@ -56,12 +45,12 @@ func TestBreakersDoWithAcceptable(t *testing.T) {
}, func(err error) bool {
return err == nil
})
assert.True(t, errors.Is(err, errDummy) || errors.Is(err, ErrServiceUnavailable))
assert.True(t, err == errDummy || err == ErrServiceUnavailable)
}
verify(t, func() bool {
return errors.Is(Do("another", func() error {
return ErrServiceUnavailable == Do("another", func() error {
return nil
}), ErrServiceUnavailable)
})
})
}
@@ -86,24 +75,18 @@ func TestBreakersFallback(t *testing.T) {
}, func(err error) error {
return nil
})
assert.True(t, err == nil || errors.Is(err, errDummy))
err = DoWithFallbackCtx(context.Background(), "fallback", func() error {
return errDummy
}, func(err error) error {
return nil
})
assert.True(t, err == nil || errors.Is(err, errDummy))
assert.True(t, err == nil || err == errDummy)
}
verify(t, func() bool {
return errors.Is(Do("fallback", func() error {
return ErrServiceUnavailable == Do("fallback", func() error {
return nil
}), ErrServiceUnavailable)
})
})
}
func TestBreakersAcceptableFallback(t *testing.T) {
errDummy := errors.New("any")
for i := 0; i < 5000; i++ {
for i := 0; i < 10000; i++ {
err := DoWithFallbackAcceptable("acceptablefallback", func() error {
return errDummy
}, func(err error) error {
@@ -111,20 +94,12 @@ func TestBreakersAcceptableFallback(t *testing.T) {
}, func(err error) bool {
return err == nil
})
assert.True(t, err == nil || errors.Is(err, errDummy))
err = DoWithFallbackAcceptableCtx(context.Background(), "acceptablefallback", func() error {
return errDummy
}, func(err error) error {
return nil
}, func(err error) bool {
return err == nil
})
assert.True(t, err == nil || errors.Is(err, errDummy))
assert.True(t, err == nil || err == errDummy)
}
verify(t, func() bool {
return errors.Is(Do("acceptablefallback", func() error {
return ErrServiceUnavailable == Do("acceptablefallback", func() error {
return nil
}), ErrServiceUnavailable)
})
})
}
@@ -135,5 +110,5 @@ func verify(t *testing.T, fn func() bool) {
count++
}
}
assert.True(t, count >= 75, fmt.Sprintf("should be greater than 75, actual %d", count))
assert.True(t, count >= 80, fmt.Sprintf("should be greater than 80, actual %d", count))
}

View File

@@ -1,48 +0,0 @@
package breaker
const (
success = iota
fail
drop
)
// bucket defines the bucket that holds sum and num of additions.
type bucket struct {
Sum int64
Success int64
Failure int64
Drop int64
}
func (b *bucket) Add(v int64) {
switch v {
case fail:
b.fail()
case drop:
b.drop()
default:
b.succeed()
}
}
func (b *bucket) Reset() {
b.Sum = 0
b.Success = 0
b.Failure = 0
b.Drop = 0
}
func (b *bucket) drop() {
b.Sum++
b.Drop++
}
func (b *bucket) fail() {
b.Sum++
b.Failure++
}
func (b *bucket) succeed() {
b.Sum++
b.Success++
}

View File

@@ -1,43 +0,0 @@
package breaker
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestBucketAdd(t *testing.T) {
b := &bucket{}
// Test succeed
b.Add(0) // Using 0 for success
assert.Equal(t, int64(1), b.Sum, "Sum should be incremented")
assert.Equal(t, int64(1), b.Success, "Success should be incremented")
assert.Equal(t, int64(0), b.Failure, "Failure should not be incremented")
assert.Equal(t, int64(0), b.Drop, "Drop should not be incremented")
// Test failure
b.Add(fail)
assert.Equal(t, int64(2), b.Sum, "Sum should be incremented")
assert.Equal(t, int64(1), b.Failure, "Failure should be incremented")
assert.Equal(t, int64(0), b.Drop, "Drop should not be incremented")
// Test drop
b.Add(drop)
assert.Equal(t, int64(3), b.Sum, "Sum should be incremented")
assert.Equal(t, int64(1), b.Drop, "Drop should be incremented")
}
func TestBucketReset(t *testing.T) {
b := &bucket{
Sum: 3,
Success: 1,
Failure: 1,
Drop: 1,
}
b.Reset()
assert.Equal(t, int64(0), b.Sum, "Sum should be reset to 0")
assert.Equal(t, int64(0), b.Success, "Success should be reset to 0")
assert.Equal(t, int64(0), b.Failure, "Failure should be reset to 0")
assert.Equal(t, int64(0), b.Drop, "Drop should be reset to 0")
}

View File

@@ -1,87 +1,57 @@
package breaker
import (
"math"
"time"
"github.com/zeromicro/go-zero/core/collection"
"github.com/zeromicro/go-zero/core/mathx"
"github.com/zeromicro/go-zero/core/syncx"
"github.com/zeromicro/go-zero/core/timex"
)
const (
// 250ms for bucket duration
window = time.Second * 10
buckets = 40
forcePassDuration = time.Second
k = 1.5
minK = 1.1
protection = 5
window = time.Second * 10
buckets = 40
k = 1.5
protection = 5
)
// googleBreaker is a netflixBreaker pattern from google.
// see Client-Side Throttling section in https://landing.google.com/sre/sre-book/chapters/handling-overload/
type (
googleBreaker struct {
k float64
stat *collection.RollingWindow[int64, *bucket]
proba *mathx.Proba
lastPass *syncx.AtomicDuration
}
windowResult struct {
accepts int64
total int64
failingBuckets int64
workingBuckets int64
}
)
type googleBreaker struct {
k float64
stat *collection.RollingWindow
proba *mathx.Proba
}
func newGoogleBreaker() *googleBreaker {
bucketDuration := time.Duration(int64(window) / int64(buckets))
st := collection.NewRollingWindow[int64, *bucket](func() *bucket {
return new(bucket)
}, buckets, bucketDuration)
st := collection.NewRollingWindow(buckets, bucketDuration)
return &googleBreaker{
stat: st,
k: k,
proba: mathx.NewProba(),
lastPass: syncx.NewAtomicDuration(),
stat: st,
k: k,
proba: mathx.NewProba(),
}
}
func (b *googleBreaker) accept() error {
var w float64
history := b.history()
w = b.k - (b.k-minK)*float64(history.failingBuckets)/buckets
weightedAccepts := mathx.AtLeast(w, minK) * float64(history.accepts)
accepts, total := b.history()
weightedAccepts := b.k * float64(accepts)
// https://landing.google.com/sre/sre-book/chapters/handling-overload/#eq2101
// for better performance, no need to care about the negative ratio
dropRatio := (float64(history.total-protection) - weightedAccepts) / float64(history.total+1)
dropRatio := math.Max(0, (float64(total-protection)-weightedAccepts)/float64(total+1))
if dropRatio <= 0 {
return nil
}
lastPass := b.lastPass.Load()
if lastPass > 0 && timex.Since(lastPass) > forcePassDuration {
b.lastPass.Set(timex.Now())
return nil
}
dropRatio *= float64(buckets-history.workingBuckets) / buckets
if b.proba.TrueOnProba(dropRatio) {
return ErrServiceUnavailable
}
b.lastPass.Set(timex.Now())
return nil
}
func (b *googleBreaker) allow() (internalPromise, error) {
if err := b.accept(); err != nil {
b.markDrop()
return nil, err
}
@@ -90,9 +60,8 @@ func (b *googleBreaker) allow() (internalPromise, error) {
}, nil
}
func (b *googleBreaker) doReq(req func() error, fallback Fallback, acceptable Acceptable) error {
func (b *googleBreaker) doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error {
if err := b.accept(); err != nil {
b.markDrop()
if fallback != nil {
return fallback(err)
}
@@ -100,55 +69,38 @@ func (b *googleBreaker) doReq(req func() error, fallback Fallback, acceptable Ac
return err
}
var succ bool
defer func() {
// if req() panic, success is false, mark as failure
if succ {
b.markSuccess()
} else {
if e := recover(); e != nil {
b.markFailure()
panic(e)
}
}()
err := req()
if acceptable(err) {
succ = true
b.markSuccess()
} else {
b.markFailure()
}
return err
}
func (b *googleBreaker) markDrop() {
b.stat.Add(drop)
func (b *googleBreaker) markSuccess() {
b.stat.Add(1)
}
func (b *googleBreaker) markFailure() {
b.stat.Add(fail)
b.stat.Add(0)
}
func (b *googleBreaker) markSuccess() {
b.stat.Add(success)
}
func (b *googleBreaker) history() windowResult {
var result windowResult
b.stat.Reduce(func(b *bucket) {
result.accepts += b.Success
result.total += b.Sum
if b.Failure > 0 {
result.workingBuckets = 0
} else if b.Success > 0 {
result.workingBuckets++
}
if b.Success > 0 {
result.failingBuckets = 0
} else if b.Failure > 0 {
result.failingBuckets++
}
func (b *googleBreaker) history() (accepts, total int64) {
b.stat.Reduce(func(b *collection.Bucket) {
accepts += int64(b.Sum)
total += b.Count
})
return result
return
}
type googlePromise struct {

View File

@@ -10,7 +10,6 @@ import (
"github.com/zeromicro/go-zero/core/collection"
"github.com/zeromicro/go-zero/core/mathx"
"github.com/zeromicro/go-zero/core/stat"
"github.com/zeromicro/go-zero/core/syncx"
)
const (
@@ -23,14 +22,11 @@ func init() {
}
func getGoogleBreaker() *googleBreaker {
st := collection.NewRollingWindow[int64, *bucket](func() *bucket {
return new(bucket)
}, testBuckets, testInterval)
st := collection.NewRollingWindow(testBuckets, testInterval)
return &googleBreaker{
stat: st,
k: 5,
proba: mathx.NewProba(),
lastPass: syncx.NewAtomicDuration(),
stat: st,
k: 5,
proba: mathx.NewProba(),
}
}
@@ -67,33 +63,6 @@ func TestGoogleBreakerOpen(t *testing.T) {
})
}
func TestGoogleBreakerRecover(t *testing.T) {
st := collection.NewRollingWindow[int64, *bucket](func() *bucket {
return new(bucket)
}, testBuckets*2, testInterval)
b := &googleBreaker{
stat: st,
k: k,
proba: mathx.NewProba(),
lastPass: syncx.NewAtomicDuration(),
}
for i := 0; i < testBuckets; i++ {
for j := 0; j < 100; j++ {
b.stat.Add(1)
}
time.Sleep(testInterval)
}
for i := 0; i < testBuckets; i++ {
for j := 0; j < 100; j++ {
b.stat.Add(0)
}
time.Sleep(testInterval)
}
verify(t, func() bool {
return b.accept() == nil
})
}
func TestGoogleBreakerFallback(t *testing.T) {
b := getGoogleBreaker()
markSuccess(b, 1)
@@ -120,50 +89,13 @@ func TestGoogleBreakerReject(t *testing.T) {
}, nil, defaultAcceptable))
}
func TestGoogleBreakerMoreFallingBuckets(t *testing.T) {
t.Parallel()
t.Run("more falling buckets", func(t *testing.T) {
b := getGoogleBreaker()
func() {
stopChan := time.After(testInterval * 6)
for {
time.Sleep(time.Millisecond)
select {
case <-stopChan:
return
default:
assert.Error(t, b.doReq(func() error {
return errors.New("foo")
}, func(err error) error {
return err
}, func(err error) bool {
return err == nil
}))
}
}
}()
var count int
for i := 0; i < 100; i++ {
if errors.Is(b.doReq(func() error {
return ErrServiceUnavailable
}, nil, defaultAcceptable), ErrServiceUnavailable) {
count++
}
}
assert.True(t, count > 90)
})
}
func TestGoogleBreakerAcceptable(t *testing.T) {
b := getGoogleBreaker()
errAcceptable := errors.New("any")
assert.Equal(t, errAcceptable, b.doReq(func() error {
return errAcceptable
}, nil, func(err error) bool {
return errors.Is(err, errAcceptable)
return err == errAcceptable
}))
}
@@ -173,7 +105,7 @@ func TestGoogleBreakerNotAcceptable(t *testing.T) {
assert.Equal(t, errAcceptable, b.doReq(func() error {
return errAcceptable
}, nil, func(err error) bool {
return !errors.Is(err, errAcceptable)
return err != errAcceptable
}))
}
@@ -232,38 +164,41 @@ func TestGoogleBreakerSelfProtection(t *testing.T) {
}
func TestGoogleBreakerHistory(t *testing.T) {
var b *googleBreaker
var accepts, total int64
sleep := testInterval
t.Run("accepts == total", func(t *testing.T) {
b := getGoogleBreaker()
b = getGoogleBreaker()
markSuccessWithDuration(b, 10, sleep/2)
result := b.history()
assert.Equal(t, int64(10), result.accepts)
assert.Equal(t, int64(10), result.total)
accepts, total = b.history()
assert.Equal(t, int64(10), accepts)
assert.Equal(t, int64(10), total)
})
t.Run("fail == total", func(t *testing.T) {
b := getGoogleBreaker()
b = getGoogleBreaker()
markFailedWithDuration(b, 10, sleep/2)
result := b.history()
assert.Equal(t, int64(0), result.accepts)
assert.Equal(t, int64(10), result.total)
accepts, total = b.history()
assert.Equal(t, int64(0), accepts)
assert.Equal(t, int64(10), total)
})
t.Run("accepts = 1/2 * total, fail = 1/2 * total", func(t *testing.T) {
b := getGoogleBreaker()
b = getGoogleBreaker()
markFailedWithDuration(b, 5, sleep/2)
markSuccessWithDuration(b, 5, sleep/2)
result := b.history()
assert.Equal(t, int64(5), result.accepts)
assert.Equal(t, int64(10), result.total)
accepts, total = b.history()
assert.Equal(t, int64(5), accepts)
assert.Equal(t, int64(10), total)
})
t.Run("auto reset rolling counter", func(t *testing.T) {
b := getGoogleBreaker()
b = getGoogleBreaker()
time.Sleep(testInterval * testBuckets)
result := b.history()
assert.Equal(t, int64(0), result.accepts)
assert.Equal(t, int64(0), result.total)
accepts, total = b.history()
assert.Equal(t, int64(0), accepts)
assert.Equal(t, int64(0), total)
})
}
@@ -271,7 +206,7 @@ func BenchmarkGoogleBreakerAllow(b *testing.B) {
breaker := getGoogleBreaker()
b.ResetTimer()
for i := 0; i <= b.N; i++ {
_ = breaker.accept()
breaker.accept()
if i%2 == 0 {
breaker.markSuccess()
} else {
@@ -280,16 +215,6 @@ func BenchmarkGoogleBreakerAllow(b *testing.B) {
}
}
func BenchmarkGoogleBreakerDoReq(b *testing.B) {
breaker := getGoogleBreaker()
b.ResetTimer()
for i := 0; i <= b.N; i++ {
_ = breaker.doReq(func() error {
return nil
}, nil, defaultAcceptable)
}
}
func markSuccess(b *googleBreaker, count int) {
for i := 0; i < count; i++ {
p, err := b.allow()

View File

@@ -1,58 +1,35 @@
package breaker
import "context"
const noOpBreakerName = "nopBreaker"
const nopBreakerName = "nopBreaker"
type noOpBreaker struct{}
type nopBreaker struct{}
// NopBreaker returns a breaker that never trigger breaker circuit.
func NopBreaker() Breaker {
return nopBreaker{}
func newNoOpBreaker() Breaker {
return noOpBreaker{}
}
func (b nopBreaker) Name() string {
return nopBreakerName
func (b noOpBreaker) Name() string {
return noOpBreakerName
}
func (b nopBreaker) Allow() (Promise, error) {
func (b noOpBreaker) Allow() (Promise, error) {
return nopPromise{}, nil
}
func (b nopBreaker) AllowCtx(_ context.Context) (Promise, error) {
return nopPromise{}, nil
}
func (b nopBreaker) Do(req func() error) error {
func (b noOpBreaker) Do(req func() error) error {
return req()
}
func (b nopBreaker) DoCtx(_ context.Context, req func() error) error {
func (b noOpBreaker) DoWithAcceptable(req func() error, _ Acceptable) error {
return req()
}
func (b nopBreaker) DoWithAcceptable(req func() error, _ Acceptable) error {
func (b noOpBreaker) DoWithFallback(req func() error, _ func(err error) error) error {
return req()
}
func (b nopBreaker) DoWithAcceptableCtx(_ context.Context, req func() error, _ Acceptable) error {
return req()
}
func (b nopBreaker) DoWithFallback(req func() error, _ Fallback) error {
return req()
}
func (b nopBreaker) DoWithFallbackCtx(_ context.Context, req func() error, _ Fallback) error {
return req()
}
func (b nopBreaker) DoWithFallbackAcceptable(req func() error, _ Fallback, _ Acceptable) error {
return req()
}
func (b nopBreaker) DoWithFallbackAcceptableCtx(_ context.Context, req func() error,
_ Fallback, _ Acceptable) error {
func (b noOpBreaker) DoWithFallbackAcceptable(req func() error, _ func(err error) error,
_ Acceptable) error {
return req()
}

View File

@@ -1,7 +1,6 @@
package breaker
import (
"context"
"errors"
"testing"
@@ -9,11 +8,9 @@ import (
)
func TestNopBreaker(t *testing.T) {
b := NopBreaker()
assert.Equal(t, nopBreakerName, b.Name())
_, err := b.Allow()
assert.Nil(t, err)
p, err := b.AllowCtx(context.Background())
b := newNoOpBreaker()
assert.Equal(t, noOpBreakerName, b.Name())
p, err := b.Allow()
assert.Nil(t, err)
p.Accept()
for i := 0; i < 1000; i++ {
@@ -24,34 +21,18 @@ func TestNopBreaker(t *testing.T) {
assert.Nil(t, b.Do(func() error {
return nil
}))
assert.Nil(t, b.DoCtx(context.Background(), func() error {
return nil
}))
assert.Nil(t, b.DoWithAcceptable(func() error {
return nil
}, defaultAcceptable))
assert.Nil(t, b.DoWithAcceptableCtx(context.Background(), func() error {
return nil
}, defaultAcceptable))
errDummy := errors.New("any")
assert.Equal(t, errDummy, b.DoWithFallback(func() error {
return errDummy
}, func(err error) error {
return nil
}))
assert.Equal(t, errDummy, b.DoWithFallbackCtx(context.Background(), func() error {
return errDummy
}, func(err error) error {
return nil
}))
assert.Equal(t, errDummy, b.DoWithFallbackAcceptable(func() error {
return errDummy
}, func(err error) error {
return nil
}, defaultAcceptable))
assert.Equal(t, errDummy, b.DoWithFallbackAcceptableCtx(context.Background(), func() error {
return errDummy
}, func(err error) error {
return nil
}, defaultAcceptable))
}

View File

@@ -23,7 +23,7 @@ var (
zero = big.NewInt(0)
)
// DhKey defines the Diffie-Hellman key.
// DhKey defines the Diffie Hellman key.
type DhKey struct {
PriKey *big.Int
PubKey *big.Int
@@ -46,7 +46,7 @@ func ComputeKey(pubKey, priKey *big.Int) (*big.Int, error) {
return new(big.Int).Exp(pubKey, priKey, p), nil
}
// GenerateKey returns a Diffie-Hellman key.
// GenerateKey returns a Diffie Hellman key.
func GenerateKey() (*DhKey, error) {
var err error
var x *big.Int

View File

@@ -2,8 +2,6 @@ package codec
import (
"bytes"
"compress/gzip"
"errors"
"fmt"
"testing"
@@ -23,45 +21,3 @@ func TestGzip(t *testing.T) {
assert.True(t, len(bs) < buf.Len())
assert.Equal(t, buf.Bytes(), actual)
}
func TestGunzip(t *testing.T) {
tests := []struct {
name string
input []byte
expected []byte
expectedErr error
}{
{
name: "valid input",
input: func() []byte {
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
gz.Write([]byte("hello"))
gz.Close()
return buf.Bytes()
}(),
expected: []byte("hello"),
expectedErr: nil,
},
{
name: "invalid input",
input: []byte("invalid input"),
expected: nil,
expectedErr: gzip.ErrHeader,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result, err := Gunzip(test.input)
if !bytes.Equal(result, test.expected) {
t.Errorf("unexpected result: %v", result)
}
if !errors.Is(err, test.expectedErr) {
t.Errorf("unexpected error: %v", err)
}
})
}
}

View File

@@ -2,7 +2,6 @@ package codec
import (
"encoding/base64"
"os"
"testing"
"github.com/stretchr/testify/assert"
@@ -42,7 +41,6 @@ func TestCryption(t *testing.T) {
file, err := fs.TempFilenameWithText(priKey)
assert.Nil(t, err)
defer os.Remove(file)
dec, err := NewRsaDecrypter(file)
assert.Nil(t, err)
actual, err := dec.Decrypt(ret)

View File

@@ -30,7 +30,7 @@ type (
Cache struct {
name string
lock sync.Mutex
data map[string]any
data map[string]interface{}
expire time.Duration
timingWheel *TimingWheel
lruCache lru
@@ -43,7 +43,7 @@ type (
// NewCache returns a Cache with given expire.
func NewCache(expire time.Duration, opts ...CacheOption) (*Cache, error) {
cache := &Cache{
data: make(map[string]any),
data: make(map[string]interface{}),
expire: expire,
lruCache: emptyLruCache,
barrier: syncx.NewSingleFlight(),
@@ -59,7 +59,7 @@ func NewCache(expire time.Duration, opts ...CacheOption) (*Cache, error) {
}
cache.stats = newCacheStat(cache.name, cache.size)
timingWheel, err := NewTimingWheel(time.Second, slots, func(k, v any) {
timingWheel, err := NewTimingWheel(time.Second, slots, func(k, v interface{}) {
key, ok := k.(string)
if !ok {
return
@@ -85,7 +85,7 @@ func (c *Cache) Del(key string) {
}
// Get returns the item with the given key from c.
func (c *Cache) Get(key string) (any, bool) {
func (c *Cache) Get(key string) (interface{}, bool) {
value, ok := c.doGet(key)
if ok {
c.stats.IncrementHit()
@@ -97,12 +97,12 @@ func (c *Cache) Get(key string) (any, bool) {
}
// Set sets value into c with key.
func (c *Cache) Set(key string, value any) {
func (c *Cache) Set(key string, value interface{}) {
c.SetWithExpire(key, value, c.expire)
}
// SetWithExpire sets value into c with key and expire with the given value.
func (c *Cache) SetWithExpire(key string, value any, expire time.Duration) {
func (c *Cache) SetWithExpire(key string, value interface{}, expire time.Duration) {
c.lock.Lock()
_, ok := c.data[key]
c.data[key] = value
@@ -120,16 +120,16 @@ func (c *Cache) SetWithExpire(key string, value any, expire time.Duration) {
// Take returns the item with the given key.
// If the item is in c, return it directly.
// If not, use fetch method to get the item, set into c and return it.
func (c *Cache) Take(key string, fetch func() (any, error)) (any, error) {
func (c *Cache) Take(key string, fetch func() (interface{}, error)) (interface{}, error) {
if val, ok := c.doGet(key); ok {
c.stats.IncrementHit()
return val, nil
}
var fresh bool
val, err := c.barrier.Do(key, func() (any, error) {
// because O(1) on map search in memory, and fetch is an IO query,
// so we do double-check, cache might be taken by another call
val, err := c.barrier.Do(key, func() (interface{}, error) {
// because O(1) on map search in memory, and fetch is an IO query
// so we do double check, cache might be taken by another call
if val, ok := c.doGet(key); ok {
return val, nil
}
@@ -157,7 +157,7 @@ func (c *Cache) Take(key string, fetch func() (any, error)) (any, error) {
return val, nil
}
func (c *Cache) doGet(key string) (any, bool) {
func (c *Cache) doGet(key string) (interface{}, bool) {
c.lock.Lock()
defer c.lock.Unlock()

View File

@@ -52,7 +52,7 @@ func TestCacheTake(t *testing.T) {
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
cache.Take("first", func() (any, error) {
cache.Take("first", func() (interface{}, error) {
atomic.AddInt32(&count, 1)
time.Sleep(time.Millisecond * 100)
return "first element", nil
@@ -76,7 +76,7 @@ func TestCacheTakeExists(t *testing.T) {
wg.Add(1)
go func() {
cache.Set("first", "first element")
cache.Take("first", func() (any, error) {
cache.Take("first", func() (interface{}, error) {
atomic.AddInt32(&count, 1)
time.Sleep(time.Millisecond * 100)
return "first element", nil
@@ -99,7 +99,7 @@ func TestCacheTakeError(t *testing.T) {
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
_, err := cache.Take("first", func() (any, error) {
_, err := cache.Take("first", func() (interface{}, error) {
atomic.AddInt32(&count, 1)
time.Sleep(time.Millisecond * 100)
return "", errDummy

View File

@@ -5,7 +5,7 @@ import "sync"
// A Queue is a FIFO queue.
type Queue struct {
lock sync.Mutex
elements []any
elements []interface{}
size int
head int
tail int
@@ -15,7 +15,7 @@ type Queue struct {
// NewQueue returns a Queue object.
func NewQueue(size int) *Queue {
return &Queue{
elements: make([]any, size),
elements: make([]interface{}, size),
size: size,
}
}
@@ -30,12 +30,12 @@ func (q *Queue) Empty() bool {
}
// Put puts element into q at the last position.
func (q *Queue) Put(element any) {
func (q *Queue) Put(element interface{}) {
q.lock.Lock()
defer q.lock.Unlock()
if q.head == q.tail && q.count > 0 {
nodes := make([]any, len(q.elements)+q.size)
nodes := make([]interface{}, len(q.elements)+q.size)
copy(nodes, q.elements[q.head:])
copy(nodes[len(q.elements)-q.head:], q.elements[:q.head])
q.head = 0
@@ -49,7 +49,7 @@ func (q *Queue) Put(element any) {
}
// Take takes the first element out of q if not empty.
func (q *Queue) Take() (any, bool) {
func (q *Queue) Take() (interface{}, bool) {
q.lock.Lock()
defer q.lock.Unlock()

View File

@@ -4,7 +4,7 @@ import "sync"
// A Ring can be used as fixed size ring.
type Ring struct {
elements []any
elements []interface{}
index int
lock sync.RWMutex
}
@@ -16,44 +16,36 @@ func NewRing(n int) *Ring {
}
return &Ring{
elements: make([]any, n),
elements: make([]interface{}, n),
}
}
// Add adds v into r.
func (r *Ring) Add(v any) {
func (r *Ring) Add(v interface{}) {
r.lock.Lock()
defer r.lock.Unlock()
rlen := len(r.elements)
r.elements[r.index%rlen] = v
r.elements[r.index%len(r.elements)] = v
r.index++
// prevent ring index overflow
if r.index >= rlen<<1 {
r.index -= rlen
}
}
// Take takes all items from r.
func (r *Ring) Take() []any {
func (r *Ring) Take() []interface{} {
r.lock.RLock()
defer r.lock.RUnlock()
var size int
var start int
rlen := len(r.elements)
if r.index > rlen {
size = rlen
start = r.index % rlen
if r.index > len(r.elements) {
size = len(r.elements)
start = r.index % len(r.elements)
} else {
size = r.index
}
elements := make([]any, size)
elements := make([]interface{}, size)
for i := 0; i < size; i++ {
elements[i] = r.elements[(start+i)%rlen]
elements[i] = r.elements[(start+i)%len(r.elements)]
}
return elements

View File

@@ -19,7 +19,7 @@ func TestRingLess(t *testing.T) {
ring.Add(i)
}
elements := ring.Take()
assert.ElementsMatch(t, []any{0, 1, 2}, elements)
assert.ElementsMatch(t, []interface{}{0, 1, 2}, elements)
}
func TestRingMore(t *testing.T) {
@@ -28,7 +28,7 @@ func TestRingMore(t *testing.T) {
ring.Add(i)
}
elements := ring.Take()
assert.ElementsMatch(t, []any{6, 7, 8, 9, 10}, elements)
assert.ElementsMatch(t, []interface{}{6, 7, 8, 9, 10}, elements)
}
func TestRingAdd(t *testing.T) {

View File

@@ -4,28 +4,18 @@ import (
"sync"
"time"
"github.com/zeromicro/go-zero/core/mathx"
"github.com/zeromicro/go-zero/core/timex"
)
type (
// BucketInterface is the interface that defines the buckets.
BucketInterface[T Numerical] interface {
Add(v T)
Reset()
}
// Numerical is the interface that restricts the numerical type.
Numerical = mathx.Numerical
// RollingWindowOption let callers customize the RollingWindow.
RollingWindowOption[T Numerical, B BucketInterface[T]] func(rollingWindow *RollingWindow[T, B])
RollingWindowOption func(rollingWindow *RollingWindow)
// RollingWindow defines a rolling window to calculate the events in buckets with the time interval.
RollingWindow[T Numerical, B BucketInterface[T]] struct {
// RollingWindow defines a rolling window to calculate the events in buckets with time interval.
RollingWindow struct {
lock sync.RWMutex
size int
win *window[T, B]
win *window
interval time.Duration
offset int
ignoreCurrent bool
@@ -35,15 +25,14 @@ type (
// NewRollingWindow returns a RollingWindow that with size buckets and time interval,
// use opts to customize the RollingWindow.
func NewRollingWindow[T Numerical, B BucketInterface[T]](newBucket func() B, size int,
interval time.Duration, opts ...RollingWindowOption[T, B]) *RollingWindow[T, B] {
func NewRollingWindow(size int, interval time.Duration, opts ...RollingWindowOption) *RollingWindow {
if size < 1 {
panic("size must be greater than 0")
}
w := &RollingWindow[T, B]{
w := &RollingWindow{
size: size,
win: newWindow[T, B](newBucket, size),
win: newWindow(size),
interval: interval,
lastTime: timex.Now(),
}
@@ -54,7 +43,7 @@ func NewRollingWindow[T Numerical, B BucketInterface[T]](newBucket func() B, siz
}
// Add adds value to current bucket.
func (rw *RollingWindow[T, B]) Add(v T) {
func (rw *RollingWindow) Add(v float64) {
rw.lock.Lock()
defer rw.lock.Unlock()
rw.updateOffset()
@@ -62,13 +51,13 @@ func (rw *RollingWindow[T, B]) Add(v T) {
}
// Reduce runs fn on all buckets, ignore current bucket if ignoreCurrent was set.
func (rw *RollingWindow[T, B]) Reduce(fn func(b B)) {
func (rw *RollingWindow) Reduce(fn func(b *Bucket)) {
rw.lock.RLock()
defer rw.lock.RUnlock()
var diff int
span := rw.span()
// ignore the current bucket, because of partial data
// ignore current bucket, because of partial data
if span == 0 && rw.ignoreCurrent {
diff = rw.size - 1
} else {
@@ -80,7 +69,7 @@ func (rw *RollingWindow[T, B]) Reduce(fn func(b B)) {
}
}
func (rw *RollingWindow[T, B]) span() int {
func (rw *RollingWindow) span() int {
offset := int(timex.Since(rw.lastTime) / rw.interval)
if 0 <= offset && offset < rw.size {
return offset
@@ -89,7 +78,7 @@ func (rw *RollingWindow[T, B]) span() int {
return rw.size
}
func (rw *RollingWindow[T, B]) updateOffset() {
func (rw *RollingWindow) updateOffset() {
span := rw.span()
if span <= 0 {
return
@@ -108,54 +97,54 @@ func (rw *RollingWindow[T, B]) updateOffset() {
}
// Bucket defines the bucket that holds sum and num of additions.
type Bucket[T Numerical] struct {
Sum T
type Bucket struct {
Sum float64
Count int64
}
func (b *Bucket[T]) Add(v T) {
func (b *Bucket) add(v float64) {
b.Sum += v
b.Count++
}
func (b *Bucket[T]) Reset() {
func (b *Bucket) reset() {
b.Sum = 0
b.Count = 0
}
type window[T Numerical, B BucketInterface[T]] struct {
buckets []B
type window struct {
buckets []*Bucket
size int
}
func newWindow[T Numerical, B BucketInterface[T]](newBucket func() B, size int) *window[T, B] {
buckets := make([]B, size)
func newWindow(size int) *window {
buckets := make([]*Bucket, size)
for i := 0; i < size; i++ {
buckets[i] = newBucket()
buckets[i] = new(Bucket)
}
return &window[T, B]{
return &window{
buckets: buckets,
size: size,
}
}
func (w *window[T, B]) add(offset int, v T) {
w.buckets[offset%w.size].Add(v)
func (w *window) add(offset int, v float64) {
w.buckets[offset%w.size].add(v)
}
func (w *window[T, B]) reduce(start, count int, fn func(b B)) {
func (w *window) reduce(start, count int, fn func(b *Bucket)) {
for i := 0; i < count; i++ {
fn(w.buckets[(start+i)%w.size])
}
}
func (w *window[T, B]) resetBucket(offset int) {
w.buckets[offset%w.size].Reset()
func (w *window) resetBucket(offset int) {
w.buckets[offset%w.size].reset()
}
// IgnoreCurrentBucket lets the Reduce call ignore current bucket.
func IgnoreCurrentBucket[T Numerical, B BucketInterface[T]]() RollingWindowOption[T, B] {
return func(w *RollingWindow[T, B]) {
func IgnoreCurrentBucket() RollingWindowOption {
return func(w *RollingWindow) {
w.ignoreCurrent = true
}
}

View File

@@ -12,24 +12,18 @@ import (
const duration = time.Millisecond * 50
func TestNewRollingWindow(t *testing.T) {
assert.NotNil(t, NewRollingWindow[int64, *Bucket[int64]](func() *Bucket[int64] {
return new(Bucket[int64])
}, 10, time.Second))
assert.NotNil(t, NewRollingWindow(10, time.Second))
assert.Panics(t, func() {
NewRollingWindow[int64, *Bucket[int64]](func() *Bucket[int64] {
return new(Bucket[int64])
}, 0, time.Second)
NewRollingWindow(0, time.Second)
})
}
func TestRollingWindowAdd(t *testing.T) {
const size = 3
r := NewRollingWindow[float64, *Bucket[float64]](func() *Bucket[float64] {
return new(Bucket[float64])
}, size, duration)
r := NewRollingWindow(size, duration)
listBuckets := func() []float64 {
var buckets []float64
r.Reduce(func(b *Bucket[float64]) {
r.Reduce(func(b *Bucket) {
buckets = append(buckets, b.Sum)
})
return buckets
@@ -53,12 +47,10 @@ func TestRollingWindowAdd(t *testing.T) {
func TestRollingWindowReset(t *testing.T) {
const size = 3
r := NewRollingWindow[float64, *Bucket[float64]](func() *Bucket[float64] {
return new(Bucket[float64])
}, size, duration, IgnoreCurrentBucket[float64, *Bucket[float64]]())
r := NewRollingWindow(size, duration, IgnoreCurrentBucket())
listBuckets := func() []float64 {
var buckets []float64
r.Reduce(func(b *Bucket[float64]) {
r.Reduce(func(b *Bucket) {
buckets = append(buckets, b.Sum)
})
return buckets
@@ -80,19 +72,15 @@ func TestRollingWindowReset(t *testing.T) {
func TestRollingWindowReduce(t *testing.T) {
const size = 4
tests := []struct {
win *RollingWindow[float64, *Bucket[float64]]
win *RollingWindow
expect float64
}{
{
win: NewRollingWindow[float64, *Bucket[float64]](func() *Bucket[float64] {
return new(Bucket[float64])
}, size, duration),
win: NewRollingWindow(size, duration),
expect: 10,
},
{
win: NewRollingWindow[float64, *Bucket[float64]](func() *Bucket[float64] {
return new(Bucket[float64])
}, size, duration, IgnoreCurrentBucket[float64, *Bucket[float64]]()),
win: NewRollingWindow(size, duration, IgnoreCurrentBucket()),
expect: 4,
},
}
@@ -109,7 +97,7 @@ func TestRollingWindowReduce(t *testing.T) {
}
}
var result float64
r.Reduce(func(b *Bucket[float64]) {
r.Reduce(func(b *Bucket) {
result += b.Sum
})
assert.Equal(t, test.expect, result)
@@ -120,12 +108,10 @@ func TestRollingWindowReduce(t *testing.T) {
func TestRollingWindowBucketTimeBoundary(t *testing.T) {
const size = 3
interval := time.Millisecond * 30
r := NewRollingWindow[float64, *Bucket[float64]](func() *Bucket[float64] {
return new(Bucket[float64])
}, size, interval)
r := NewRollingWindow(size, interval)
listBuckets := func() []float64 {
var buckets []float64
r.Reduce(func(b *Bucket[float64]) {
r.Reduce(func(b *Bucket) {
buckets = append(buckets, b.Sum)
})
return buckets
@@ -152,9 +138,7 @@ func TestRollingWindowBucketTimeBoundary(t *testing.T) {
func TestRollingWindowDataRace(t *testing.T) {
const size = 3
r := NewRollingWindow[float64, *Bucket[float64]](func() *Bucket[float64] {
return new(Bucket[float64])
}, size, duration)
r := NewRollingWindow(size, duration)
stop := make(chan bool)
go func() {
for {
@@ -173,7 +157,7 @@ func TestRollingWindowDataRace(t *testing.T) {
case <-stop:
return
default:
r.Reduce(func(b *Bucket[float64]) {})
r.Reduce(func(b *Bucket) {})
}
}
}()

View File

@@ -14,23 +14,21 @@ type SafeMap struct {
lock sync.RWMutex
deletionOld int
deletionNew int
dirtyOld map[any]any
dirtyNew map[any]any
dirtyOld map[interface{}]interface{}
dirtyNew map[interface{}]interface{}
}
// NewSafeMap returns a SafeMap.
func NewSafeMap() *SafeMap {
return &SafeMap{
dirtyOld: make(map[any]any),
dirtyNew: make(map[any]any),
dirtyOld: make(map[interface{}]interface{}),
dirtyNew: make(map[interface{}]interface{}),
}
}
// Del deletes the value with the given key from m.
func (m *SafeMap) Del(key any) {
func (m *SafeMap) Del(key interface{}) {
m.lock.Lock()
defer m.lock.Unlock()
if _, ok := m.dirtyOld[key]; ok {
delete(m.dirtyOld, key)
m.deletionOld++
@@ -44,20 +42,21 @@ func (m *SafeMap) Del(key any) {
}
m.dirtyOld = m.dirtyNew
m.deletionOld = m.deletionNew
m.dirtyNew = make(map[any]any)
m.dirtyNew = make(map[interface{}]interface{})
m.deletionNew = 0
}
if m.deletionNew >= maxDeletion && len(m.dirtyNew) < copyThreshold {
for k, v := range m.dirtyNew {
m.dirtyOld[k] = v
}
m.dirtyNew = make(map[any]any)
m.dirtyNew = make(map[interface{}]interface{})
m.deletionNew = 0
}
m.lock.Unlock()
}
// Get gets the value with the given key from m.
func (m *SafeMap) Get(key any) (any, bool) {
func (m *SafeMap) Get(key interface{}) (interface{}, bool) {
m.lock.RLock()
defer m.lock.RUnlock()
@@ -71,7 +70,7 @@ func (m *SafeMap) Get(key any) (any, bool) {
// Range calls f sequentially for each key and value present in the map.
// If f returns false, range stops the iteration.
func (m *SafeMap) Range(f func(key, val any) bool) {
func (m *SafeMap) Range(f func(key, val interface{}) bool) {
m.lock.RLock()
defer m.lock.RUnlock()
@@ -88,10 +87,8 @@ func (m *SafeMap) Range(f func(key, val any) bool) {
}
// Set sets the value into m with the given key.
func (m *SafeMap) Set(key, value any) {
func (m *SafeMap) Set(key, value interface{}) {
m.lock.Lock()
defer m.lock.Unlock()
if m.deletionOld <= maxDeletion {
if _, ok := m.dirtyNew[key]; ok {
delete(m.dirtyNew, key)
@@ -105,6 +102,7 @@ func (m *SafeMap) Set(key, value any) {
}
m.dirtyNew[key] = value
}
m.lock.Unlock()
}
// Size returns the size of m.

View File

@@ -138,7 +138,7 @@ func TestSafeMap_Range(t *testing.T) {
}
var count int32
m.Range(func(k, v any) bool {
m.Range(func(k, v interface{}) bool {
atomic.AddInt32(&count, 1)
newMap.Set(k, v)
return true
@@ -147,65 +147,3 @@ func TestSafeMap_Range(t *testing.T) {
assert.Equal(t, m.dirtyNew, newMap.dirtyNew)
assert.Equal(t, m.dirtyOld, newMap.dirtyOld)
}
func TestSetManyTimes(t *testing.T) {
const iteration = maxDeletion * 2
m := NewSafeMap()
for i := 0; i < iteration; i++ {
m.Set(i, i)
if i%3 == 0 {
m.Del(i / 2)
}
}
var count int
m.Range(func(k, v any) bool {
count++
return count < maxDeletion/2
})
assert.Equal(t, maxDeletion/2, count)
for i := 0; i < iteration; i++ {
m.Set(i, i)
if i%3 == 0 {
m.Del(i / 2)
}
}
for i := 0; i < iteration; i++ {
m.Set(i, i)
if i%3 == 0 {
m.Del(i / 2)
}
}
for i := 0; i < iteration; i++ {
m.Set(i, i)
if i%3 == 0 {
m.Del(i / 2)
}
}
count = 0
m.Range(func(k, v any) bool {
count++
return count < maxDeletion
})
assert.Equal(t, maxDeletion, count)
}
func TestSetManyTimesNew(t *testing.T) {
m := NewSafeMap()
for i := 0; i < maxDeletion*3; i++ {
m.Set(i, i)
}
for i := 0; i < maxDeletion*2; i++ {
m.Del(i)
}
for i := 0; i < maxDeletion*3; i++ {
m.Set(i+maxDeletion*3, i+maxDeletion*3)
}
for i := 0; i < maxDeletion*2; i++ {
m.Del(i + maxDeletion*2)
}
for i := 0; i < maxDeletion-copyThreshold+1; i++ {
m.Del(i + maxDeletion*2)
}
assert.Equal(t, 0, len(m.dirtyNew))
}

View File

@@ -17,14 +17,14 @@ const (
// Set is not thread-safe, for concurrent use, make sure to use it with synchronization.
type Set struct {
data map[any]lang.PlaceholderType
data map[interface{}]lang.PlaceholderType
tp int
}
// NewSet returns a managed Set, can only put the values with the same type.
func NewSet() *Set {
return &Set{
data: make(map[any]lang.PlaceholderType),
data: make(map[interface{}]lang.PlaceholderType),
tp: untyped,
}
}
@@ -32,13 +32,13 @@ func NewSet() *Set {
// NewUnmanagedSet returns an unmanaged Set, which can put values with different types.
func NewUnmanagedSet() *Set {
return &Set{
data: make(map[any]lang.PlaceholderType),
data: make(map[interface{}]lang.PlaceholderType),
tp: unmanaged,
}
}
// Add adds i into s.
func (s *Set) Add(i ...any) {
func (s *Set) Add(i ...interface{}) {
for _, each := range i {
s.add(each)
}
@@ -80,7 +80,7 @@ func (s *Set) AddStr(ss ...string) {
}
// Contains checks if i is in s.
func (s *Set) Contains(i any) bool {
func (s *Set) Contains(i interface{}) bool {
if len(s.data) == 0 {
return false
}
@@ -91,8 +91,8 @@ func (s *Set) Contains(i any) bool {
}
// Keys returns the keys in s.
func (s *Set) Keys() []any {
var keys []any
func (s *Set) Keys() []interface{} {
var keys []interface{}
for key := range s.data {
keys = append(keys, key)
@@ -167,7 +167,7 @@ func (s *Set) KeysStr() []string {
}
// Remove removes i from s.
func (s *Set) Remove(i any) {
func (s *Set) Remove(i interface{}) {
s.validate(i)
delete(s.data, i)
}
@@ -177,7 +177,7 @@ func (s *Set) Count() int {
return len(s.data)
}
func (s *Set) add(i any) {
func (s *Set) add(i interface{}) {
switch s.tp {
case unmanaged:
// do nothing
@@ -189,7 +189,7 @@ func (s *Set) add(i any) {
s.data[i] = lang.Placeholder
}
func (s *Set) setType(i any) {
func (s *Set) setType(i interface{}) {
// s.tp can only be untyped here
switch i.(type) {
case int:
@@ -205,7 +205,7 @@ func (s *Set) setType(i any) {
}
}
func (s *Set) validate(i any) {
func (s *Set) validate(i interface{}) {
if s.tp == unmanaged {
return
}

View File

@@ -13,7 +13,7 @@ func init() {
}
func BenchmarkRawSet(b *testing.B) {
m := make(map[any]struct{})
m := make(map[interface{}]struct{})
for i := 0; i < b.N; i++ {
m[i] = struct{}{}
_ = m[i]
@@ -39,7 +39,7 @@ func BenchmarkSet(b *testing.B) {
func TestAdd(t *testing.T) {
// given
set := NewUnmanagedSet()
values := []any{1, 2, 3}
values := []interface{}{1, 2, 3}
// when
set.Add(values...)
@@ -135,7 +135,7 @@ func TestContainsUnmanagedWithoutElements(t *testing.T) {
func TestRemove(t *testing.T) {
// given
set := NewSet()
set.Add([]any{1, 2, 3}...)
set.Add([]interface{}{1, 2, 3}...)
// when
set.Remove(2)
@@ -147,7 +147,7 @@ func TestRemove(t *testing.T) {
func TestCount(t *testing.T) {
// given
set := NewSet()
set.Add([]any{1, 2, 3}...)
set.Add([]interface{}{1, 2, 3}...)
// then
assert.Equal(t, set.Count(), 3)
@@ -198,5 +198,5 @@ func TestSetType(t *testing.T) {
set.add(1)
set.add("2")
vals := set.Keys()
assert.ElementsMatch(t, []any{1, "2"}, vals)
assert.ElementsMatch(t, []interface{}{1, "2"}, vals)
}

View File

@@ -20,7 +20,7 @@ var (
type (
// Execute defines the method to execute the task.
Execute func(key, value any)
Execute func(key, value interface{})
// A TimingWheel is a timing wheel object to schedule tasks.
TimingWheel struct {
@@ -33,14 +33,14 @@ type (
execute Execute
setChannel chan timingEntry
moveChannel chan baseEntry
removeChannel chan any
drainChannel chan func(key, value any)
removeChannel chan interface{}
drainChannel chan func(key, value interface{})
stopChannel chan lang.PlaceholderType
}
timingEntry struct {
baseEntry
value any
value interface{}
circle int
diff int
removed bool
@@ -48,7 +48,7 @@ type (
baseEntry struct {
delay time.Duration
key any
key interface{}
}
positionEntry struct {
@@ -57,8 +57,8 @@ type (
}
timingTask struct {
key any
value any
key interface{}
value interface{}
}
)
@@ -85,8 +85,8 @@ func NewTimingWheelWithTicker(interval time.Duration, numSlots int, execute Exec
numSlots: numSlots,
setChannel: make(chan timingEntry),
moveChannel: make(chan baseEntry),
removeChannel: make(chan any),
drainChannel: make(chan func(key, value any)),
removeChannel: make(chan interface{}),
drainChannel: make(chan func(key, value interface{})),
stopChannel: make(chan lang.PlaceholderType),
}
@@ -97,7 +97,7 @@ func NewTimingWheelWithTicker(interval time.Duration, numSlots int, execute Exec
}
// Drain drains all items and executes them.
func (tw *TimingWheel) Drain(fn func(key, value any)) error {
func (tw *TimingWheel) Drain(fn func(key, value interface{})) error {
select {
case tw.drainChannel <- fn:
return nil
@@ -107,7 +107,7 @@ func (tw *TimingWheel) Drain(fn func(key, value any)) error {
}
// MoveTimer moves the task with the given key to the given delay.
func (tw *TimingWheel) MoveTimer(key any, delay time.Duration) error {
func (tw *TimingWheel) MoveTimer(key interface{}, delay time.Duration) error {
if delay <= 0 || key == nil {
return ErrArgument
}
@@ -124,7 +124,7 @@ func (tw *TimingWheel) MoveTimer(key any, delay time.Duration) error {
}
// RemoveTimer removes the task with the given key.
func (tw *TimingWheel) RemoveTimer(key any) error {
func (tw *TimingWheel) RemoveTimer(key interface{}) error {
if key == nil {
return ErrArgument
}
@@ -138,7 +138,7 @@ func (tw *TimingWheel) RemoveTimer(key any) error {
}
// SetTimer sets the task value with the given key to the delay.
func (tw *TimingWheel) SetTimer(key, value any, delay time.Duration) error {
func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) error {
if delay <= 0 || key == nil {
return ErrArgument
}
@@ -162,7 +162,7 @@ func (tw *TimingWheel) Stop() {
close(tw.stopChannel)
}
func (tw *TimingWheel) drainAll(fn func(key, value any)) {
func (tw *TimingWheel) drainAll(fn func(key, value interface{})) {
runner := threading.NewTaskRunner(drainWorkers)
for _, slot := range tw.slots {
for e := slot.Front(); e != nil; {
@@ -232,7 +232,7 @@ func (tw *TimingWheel) onTick() {
tw.scanAndRunTasks(l)
}
func (tw *TimingWheel) removeTask(key any) {
func (tw *TimingWheel) removeTask(key interface{}) {
val, ok := tw.timers.Get(key)
if !ok {
return

View File

@@ -20,13 +20,13 @@ const (
)
func TestNewTimingWheel(t *testing.T) {
_, err := NewTimingWheel(0, 10, func(key, value any) {})
_, err := NewTimingWheel(0, 10, func(key, value interface{}) {})
assert.NotNil(t, err)
}
func TestTimingWheel_Drain(t *testing.T) {
ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
}, ticker)
tw.SetTimer("first", 3, testStep*4)
tw.SetTimer("second", 5, testStep*7)
@@ -36,7 +36,7 @@ func TestTimingWheel_Drain(t *testing.T) {
var lock sync.Mutex
var wg sync.WaitGroup
wg.Add(3)
tw.Drain(func(key, value any) {
tw.Drain(func(key, value interface{}) {
lock.Lock()
defer lock.Unlock()
keys = append(keys, key.(string))
@@ -50,19 +50,19 @@ func TestTimingWheel_Drain(t *testing.T) {
assert.EqualValues(t, []string{"first", "second", "third"}, keys)
assert.EqualValues(t, []int{3, 5, 7}, vals)
var count int
tw.Drain(func(key, value any) {
tw.Drain(func(key, value interface{}) {
count++
})
time.Sleep(time.Millisecond * 100)
assert.Equal(t, 0, count)
tw.Stop()
assert.Equal(t, ErrClosed, tw.Drain(func(key, value any) {}))
assert.Equal(t, ErrClosed, tw.Drain(func(key, value interface{}) {}))
}
func TestTimingWheel_SetTimerSoon(t *testing.T) {
run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int))
@@ -78,7 +78,7 @@ func TestTimingWheel_SetTimerSoon(t *testing.T) {
func TestTimingWheel_SetTimerTwice(t *testing.T) {
run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k)
assert.Equal(t, 5, v.(int))
@@ -96,7 +96,7 @@ func TestTimingWheel_SetTimerTwice(t *testing.T) {
func TestTimingWheel_SetTimerWrongDelay(t *testing.T) {
ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {}, ticker)
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {}, ticker)
defer tw.Stop()
assert.NotPanics(t, func() {
tw.SetTimer("any", 3, -testStep)
@@ -105,7 +105,7 @@ func TestTimingWheel_SetTimerWrongDelay(t *testing.T) {
func TestTimingWheel_SetTimerAfterClose(t *testing.T) {
ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {}, ticker)
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {}, ticker)
tw.Stop()
assert.Equal(t, ErrClosed, tw.SetTimer("any", 3, testStep))
}
@@ -113,7 +113,7 @@ func TestTimingWheel_SetTimerAfterClose(t *testing.T) {
func TestTimingWheel_MoveTimer(t *testing.T) {
run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 3, func(k, v any) {
tw, _ := NewTimingWheelWithTicker(testStep, 3, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int))
@@ -139,7 +139,7 @@ func TestTimingWheel_MoveTimer(t *testing.T) {
func TestTimingWheel_MoveTimerSoon(t *testing.T) {
run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 3, func(k, v any) {
tw, _ := NewTimingWheelWithTicker(testStep, 3, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int))
@@ -155,7 +155,7 @@ func TestTimingWheel_MoveTimerSoon(t *testing.T) {
func TestTimingWheel_MoveTimerEarlier(t *testing.T) {
run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int))
@@ -173,7 +173,7 @@ func TestTimingWheel_MoveTimerEarlier(t *testing.T) {
func TestTimingWheel_RemoveTimer(t *testing.T) {
ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {}, ticker)
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {}, ticker)
tw.SetTimer("any", 3, testStep)
assert.NotPanics(t, func() {
tw.RemoveTimer("any")
@@ -236,7 +236,7 @@ func TestTimingWheel_SetTimer(t *testing.T) {
}
var actual int32
done := make(chan lang.PlaceholderType)
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value any) {
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
assert.Equal(t, 1, key.(int))
assert.Equal(t, 2, value.(int))
actual = atomic.LoadInt32(&count)
@@ -317,7 +317,7 @@ func TestTimingWheel_SetAndMoveThenStart(t *testing.T) {
}
var actual int32
done := make(chan lang.PlaceholderType)
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value any) {
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
actual = atomic.LoadInt32(&count)
close(done)
}, ticker)
@@ -405,7 +405,7 @@ func TestTimingWheel_SetAndMoveTwice(t *testing.T) {
}
var actual int32
done := make(chan lang.PlaceholderType)
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value any) {
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
actual = atomic.LoadInt32(&count)
close(done)
}, ticker)
@@ -486,7 +486,7 @@ func TestTimingWheel_ElapsedAndSet(t *testing.T) {
}
var actual int32
done := make(chan lang.PlaceholderType)
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value any) {
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
actual = atomic.LoadInt32(&count)
close(done)
}, ticker)
@@ -577,7 +577,7 @@ func TestTimingWheel_ElapsedAndSetThenMove(t *testing.T) {
}
var actual int32
done := make(chan lang.PlaceholderType)
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value any) {
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
actual = atomic.LoadInt32(&count)
close(done)
}, ticker)
@@ -612,7 +612,7 @@ func TestMoveAndRemoveTask(t *testing.T) {
}
}
var keys []int
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int))
keys = append(keys, v.(int))
@@ -632,7 +632,7 @@ func TestMoveAndRemoveTask(t *testing.T) {
func BenchmarkTimingWheel(b *testing.B) {
b.ReportAllocs()
tw, _ := NewTimingWheel(time.Second, 100, func(k, v any) {})
tw, _ := NewTimingWheel(time.Second, 100, func(k, v interface{}) {})
for i := 0; i < b.N; i++ {
tw.SetTimer(i, i, time.Second)
tw.SetTimer(b.N+i, b.N+i, time.Second)

View File

@@ -13,14 +13,11 @@ import (
"github.com/zeromicro/go-zero/internal/encoding"
)
const (
jsonTagKey = "json"
jsonTagSep = ','
)
const jsonTagKey = "json"
var (
fillDefaultUnmarshaler = mapping.NewUnmarshaler(jsonTagKey, mapping.WithDefault())
loaders = map[string]func([]byte, any) error{
loaders = map[string]func([]byte, interface{}) error{
".json": LoadFromJsonBytes,
".toml": LoadFromTomlBytes,
".yaml": LoadFromYamlBytes,
@@ -37,12 +34,12 @@ type fieldInfo struct {
// FillDefault fills the default values for the given v,
// and the premise is that the value of v must be guaranteed to be empty.
func FillDefault(v any) error {
return fillDefaultUnmarshaler.Unmarshal(map[string]any{}, v)
func FillDefault(v interface{}) error {
return fillDefaultUnmarshaler.Unmarshal(map[string]interface{}{}, v)
}
// Load loads config into v from file, .json, .yaml and .yml are acceptable.
func Load(file string, v any, opts ...Option) error {
func Load(file string, v interface{}, opts ...Option) error {
content, err := os.ReadFile(file)
if err != nil {
return err
@@ -62,49 +59,40 @@ func Load(file string, v any, opts ...Option) error {
return loader([]byte(os.ExpandEnv(string(content))), v)
}
if err = loader(content, v); err != nil {
return err
}
return validate(v)
return loader(content, v)
}
// LoadConfig loads config into v from file, .json, .yaml and .yml are acceptable.
// Deprecated: use Load instead.
func LoadConfig(file string, v any, opts ...Option) error {
func LoadConfig(file string, v interface{}, opts ...Option) error {
return Load(file, v, opts...)
}
// LoadFromJsonBytes loads config into v from content json bytes.
func LoadFromJsonBytes(content []byte, v any) error {
info, err := buildFieldsInfo(reflect.TypeOf(v), "")
func LoadFromJsonBytes(content []byte, v interface{}) error {
info, err := buildFieldsInfo(reflect.TypeOf(v))
if err != nil {
return err
}
var m map[string]any
if err = jsonx.Unmarshal(content, &m); err != nil {
var m map[string]interface{}
if err := jsonx.Unmarshal(content, &m); err != nil {
return err
}
lowerCaseKeyMap := toLowerCaseKeyMap(m, info)
if err = mapping.UnmarshalJsonMap(lowerCaseKeyMap, v,
mapping.WithCanonicalKeyFunc(toLowerCase)); err != nil {
return err
}
return validate(v)
return mapping.UnmarshalJsonMap(lowerCaseKeyMap, v, mapping.WithCanonicalKeyFunc(toLowerCase))
}
// LoadConfigFromJsonBytes loads config into v from content json bytes.
// Deprecated: use LoadFromJsonBytes instead.
func LoadConfigFromJsonBytes(content []byte, v any) error {
func LoadConfigFromJsonBytes(content []byte, v interface{}) error {
return LoadFromJsonBytes(content, v)
}
// LoadFromTomlBytes loads config into v from content toml bytes.
func LoadFromTomlBytes(content []byte, v any) error {
func LoadFromTomlBytes(content []byte, v interface{}) error {
b, err := encoding.TomlToJson(content)
if err != nil {
return err
@@ -114,7 +102,7 @@ func LoadFromTomlBytes(content []byte, v any) error {
}
// LoadFromYamlBytes loads config into v from content yaml bytes.
func LoadFromYamlBytes(content []byte, v any) error {
func LoadFromYamlBytes(content []byte, v interface{}) error {
b, err := encoding.YamlToJson(content)
if err != nil {
return err
@@ -125,24 +113,24 @@ func LoadFromYamlBytes(content []byte, v any) error {
// LoadConfigFromYamlBytes loads config into v from content yaml bytes.
// Deprecated: use LoadFromYamlBytes instead.
func LoadConfigFromYamlBytes(content []byte, v any) error {
func LoadConfigFromYamlBytes(content []byte, v interface{}) error {
return LoadFromYamlBytes(content, v)
}
// MustLoad loads config into v from path, exits on error.
func MustLoad(path string, v any, opts ...Option) {
func MustLoad(path string, v interface{}, opts ...Option) {
if err := Load(path, v, opts...); err != nil {
log.Fatalf("error: config file %s, %s", path, err.Error())
}
}
func addOrMergeFields(info *fieldInfo, key string, child *fieldInfo, fullName string) error {
func addOrMergeFields(info *fieldInfo, key string, child *fieldInfo) error {
if prev, ok := info.children[key]; ok {
if child.mapField != nil {
return newConflictKeyError(fullName)
return newDupKeyError(key)
}
if err := mergeFields(prev, child.children, fullName); err != nil {
if err := mergeFields(prev, key, child.children); err != nil {
return err
}
} else {
@@ -152,27 +140,27 @@ func addOrMergeFields(info *fieldInfo, key string, child *fieldInfo, fullName st
return nil
}
func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type, fullName string) error {
func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error {
switch ft.Kind() {
case reflect.Struct:
fields, err := buildFieldsInfo(ft, fullName)
fields, err := buildFieldsInfo(ft)
if err != nil {
return err
}
for k, v := range fields.children {
if err = addOrMergeFields(info, k, v, fullName); err != nil {
if err = addOrMergeFields(info, k, v); err != nil {
return err
}
}
case reflect.Map:
elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem()), fullName)
elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
if err != nil {
return err
}
if _, ok := info.children[lowerCaseName]; ok {
return newConflictKeyError(fullName)
return newDupKeyError(lowerCaseName)
}
info.children[lowerCaseName] = &fieldInfo{
@@ -181,7 +169,7 @@ func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.T
}
default:
if _, ok := info.children[lowerCaseName]; ok {
return newConflictKeyError(fullName)
return newDupKeyError(lowerCaseName)
}
info.children[lowerCaseName] = &fieldInfo{
@@ -192,16 +180,16 @@ func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.T
return nil
}
func buildFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error) {
func buildFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
tp = mapping.Deref(tp)
switch tp.Kind() {
case reflect.Struct:
return buildStructFieldsInfo(tp, fullName)
case reflect.Array, reflect.Slice, reflect.Map:
return buildFieldsInfo(mapping.Deref(tp.Elem()), fullName)
return buildStructFieldsInfo(tp)
case reflect.Array, reflect.Slice:
return buildFieldsInfo(mapping.Deref(tp.Elem()))
case reflect.Chan, reflect.Func:
return nil, fmt.Errorf("unsupported type: %s, fullName: %s", tp.Kind(), fullName)
return nil, fmt.Errorf("unsupported type: %s", tp.Kind())
default:
return &fieldInfo{
children: make(map[string]*fieldInfo),
@@ -209,23 +197,23 @@ func buildFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error) {
}
}
func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type, fullName string) error {
func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error {
var finfo *fieldInfo
var err error
switch ft.Kind() {
case reflect.Struct:
finfo, err = buildFieldsInfo(ft, fullName)
finfo, err = buildFieldsInfo(ft)
if err != nil {
return err
}
case reflect.Array, reflect.Slice:
finfo, err = buildFieldsInfo(ft.Elem(), fullName)
finfo, err = buildFieldsInfo(ft.Elem())
if err != nil {
return err
}
case reflect.Map:
elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem()), fullName)
elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
if err != nil {
return err
}
@@ -235,37 +223,31 @@ func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type,
mapField: elemInfo,
}
default:
finfo, err = buildFieldsInfo(ft, fullName)
finfo, err = buildFieldsInfo(ft)
if err != nil {
return err
}
}
return addOrMergeFields(info, lowerCaseName, finfo, fullName)
return addOrMergeFields(info, lowerCaseName, finfo)
}
func buildStructFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error) {
func buildStructFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
info := &fieldInfo{
children: make(map[string]*fieldInfo),
}
for i := 0; i < tp.NumField(); i++ {
field := tp.Field(i)
if !field.IsExported() {
continue
}
name := getTagName(field)
name := field.Name
lowerCaseName := toLowerCase(name)
ft := mapping.Deref(field.Type)
// flatten anonymous fields
if field.Anonymous {
if err := buildAnonymousFieldInfo(info, lowerCaseName, ft,
getFullName(fullName, lowerCaseName)); err != nil {
if err := buildAnonymousFieldInfo(info, lowerCaseName, ft); err != nil {
return nil, err
}
} else if err := buildNamedFieldInfo(info, lowerCaseName, ft,
getFullName(fullName, lowerCaseName)); err != nil {
} else if err := buildNamedFieldInfo(info, lowerCaseName, ft); err != nil {
return nil, err
}
}
@@ -273,32 +255,15 @@ func buildStructFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error)
return info, nil
}
// getTagName get the tag name of the given field, if no tag name, use file.Name.
// field.Name is returned on tags like `json:""` and `json:",optional"`.
func getTagName(field reflect.StructField) string {
if tag, ok := field.Tag.Lookup(jsonTagKey); ok {
if pos := strings.IndexByte(tag, jsonTagSep); pos >= 0 {
tag = tag[:pos]
}
tag = strings.TrimSpace(tag)
if len(tag) > 0 {
return tag
}
}
return field.Name
}
func mergeFields(prev *fieldInfo, children map[string]*fieldInfo, fullName string) error {
func mergeFields(prev *fieldInfo, key string, children map[string]*fieldInfo) error {
if len(prev.children) == 0 || len(children) == 0 {
return newConflictKeyError(fullName)
return newDupKeyError(key)
}
// merge fields
for k, v := range children {
if _, ok := prev.children[k]; ok {
return newConflictKeyError(fullName)
return newDupKeyError(k)
}
prev.children[k] = v
@@ -311,12 +276,12 @@ func toLowerCase(s string) string {
return strings.ToLower(s)
}
func toLowerCaseInterface(v any, info *fieldInfo) any {
func toLowerCaseInterface(v interface{}, info *fieldInfo) interface{} {
switch vv := v.(type) {
case map[string]any:
case map[string]interface{}:
return toLowerCaseKeyMap(vv, info)
case []any:
var arr []any
case []interface{}:
var arr []interface{}
for _, vvv := range vv {
arr = append(arr, toLowerCaseInterface(vvv, info))
}
@@ -326,8 +291,8 @@ func toLowerCaseInterface(v any, info *fieldInfo) any {
}
}
func toLowerCaseKeyMap(m map[string]any, info *fieldInfo) map[string]any {
res := make(map[string]any)
func toLowerCaseKeyMap(m map[string]interface{}, info *fieldInfo) map[string]interface{} {
res := make(map[string]interface{})
for k, v := range m {
ti, ok := info.children[k]
@@ -341,8 +306,6 @@ func toLowerCaseKeyMap(m map[string]any, info *fieldInfo) map[string]any {
res[lk] = toLowerCaseInterface(v, ti)
} else if info.mapField != nil {
res[k] = toLowerCaseInterface(v, info.mapField)
} else if vv, ok := v.(map[string]any); ok {
res[k] = toLowerCaseKeyMap(vv, info)
} else {
res[k] = v
}
@@ -351,22 +314,14 @@ func toLowerCaseKeyMap(m map[string]any, info *fieldInfo) map[string]any {
return res
}
type conflictKeyError struct {
type dupKeyError struct {
key string
}
func newConflictKeyError(key string) conflictKeyError {
return conflictKeyError{key: key}
func newDupKeyError(key string) dupKeyError {
return dupKeyError{key: key}
}
func (e conflictKeyError) Error() string {
return fmt.Sprintf("conflict key %s, pay attention to anonymous fields", e.key)
}
func getFullName(parent, child string) string {
if len(parent) == 0 {
return child
}
return strings.Join([]string{parent, child}, ".")
func (e dupKeyError) Error() string {
return fmt.Sprintf("duplicated key %s", e.key)
}

View File

@@ -1,9 +1,7 @@
package conf
import (
"errors"
"os"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
@@ -11,7 +9,7 @@ import (
"github.com/zeromicro/go-zero/core/hash"
)
var dupErr conflictKeyError
var dupErr dupKeyError
func TestLoadConfig_notExists(t *testing.T) {
assert.NotNil(t, Load("not_a_file", nil))
@@ -36,13 +34,14 @@ func TestConfigJson(t *testing.T) {
"c": "${FOO}",
"d": "abcd!@#$112"
}`
t.Setenv("FOO", "2")
for _, test := range tests {
test := test
t.Run(test, func(t *testing.T) {
tmpfile, err := createTempFile(t, test, text)
os.Setenv("FOO", "2")
defer os.Unsetenv("FOO")
tmpfile, err := createTempFile(test, text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
var val struct {
A string `json:"a"`
@@ -81,9 +80,11 @@ b = 1
c = "${FOO}"
d = "abcd!@#$112"
`
t.Setenv("FOO", "2")
tmpfile, err := createTempFile(t, ".toml", text)
os.Setenv("FOO", "2")
defer os.Unsetenv("FOO")
tmpfile, err := createTempFile(".toml", text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
var val struct {
A string `json:"a"`
@@ -104,8 +105,9 @@ b = 1
c = "FOO"
d = "abcd"
`
tmpfile, err := createTempFile(t, ".toml", text)
tmpfile, err := createTempFile(".toml", text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
var val struct {
A string `json:"a"`
@@ -121,23 +123,6 @@ d = "abcd"
}
}
func TestConfigWithLower(t *testing.T) {
text := `a = "foo"
b = 1
`
tmpfile, err := createTempFile(t, ".toml", text)
assert.Nil(t, err)
var val struct {
A string `json:"a"`
b int
}
if assert.NoError(t, Load(tmpfile, &val)) {
assert.Equal(t, "foo", val.A)
assert.Equal(t, 0, val.b)
}
}
func TestConfigJsonCanonical(t *testing.T) {
text := []byte(`{"a": "foo", "B": "bar"}`)
@@ -203,9 +188,11 @@ b = 1
c = "${FOO}"
d = "abcd!@#112"
`
t.Setenv("FOO", "2")
tmpfile, err := createTempFile(t, ".toml", text)
os.Setenv("FOO", "2")
defer os.Unsetenv("FOO")
tmpfile, err := createTempFile(".toml", text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
var val struct {
A string `json:"a"`
@@ -233,12 +220,14 @@ func TestConfigJsonEnv(t *testing.T) {
"c": "${FOO}",
"d": "abcd!@#$a12 3"
}`
t.Setenv("FOO", "2")
for _, test := range tests {
test := test
t.Run(test, func(t *testing.T) {
tmpfile, err := createTempFile(t, test, text)
os.Setenv("FOO", "2")
defer os.Unsetenv("FOO")
tmpfile, err := createTempFile(test, text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
var val struct {
A string `json:"a"`
@@ -643,7 +632,7 @@ func Test_FieldOverwrite(t *testing.T) {
Name2 *string
}
validate := func(val any) {
validate := func(val interface{}) {
input := []byte(`{"Name": "hello", "Name2": "world"}`)
assert.NoError(t, LoadFromJsonBytes(input, val))
}
@@ -679,11 +668,11 @@ func Test_FieldOverwrite(t *testing.T) {
Name *string
}
validate := func(val any) {
validate := func(val interface{}) {
input := []byte(`{"Name": "hello"}`)
err := LoadFromJsonBytes(input, val)
assert.ErrorAs(t, err, &dupErr)
assert.Equal(t, newConflictKeyError("name").Error(), err.Error())
assert.Equal(t, newDupKeyError("name").Error(), err.Error())
}
validate(&St1{})
@@ -722,11 +711,11 @@ func Test_FieldOverwrite(t *testing.T) {
Name *int
}
validate := func(val any) {
validate := func(val interface{}) {
input := []byte(`{"Name": "hello"}`)
err := LoadFromJsonBytes(input, val)
assert.ErrorAs(t, err, &dupErr)
assert.Error(t, err)
assert.Equal(t, newDupKeyError("name").Error(), err.Error())
}
validate(&St0{})
@@ -1033,22 +1022,22 @@ func TestLoadNamedFieldOverwritten(t *testing.T) {
})
}
func TestLoadLowerMemberShouldNotConflict(t *testing.T) {
type (
Redis struct {
db uint
}
func createTempFile(ext, text string) (string, error) {
tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext)
if err != nil {
return "", err
}
Config struct {
db uint
Redis
}
)
if err := os.WriteFile(tmpfile.Name(), []byte(text), os.ModeTemporary); err != nil {
return "", err
}
var c Config
assert.NoError(t, LoadFromJsonBytes([]byte(`{}`), &c))
assert.Zero(t, c.db)
assert.Zero(t, c.Redis.db)
filename := tmpfile.Name()
if err = tmpfile.Close(); err != nil {
return "", err
}
return filename, nil
}
func TestFillDefaultUnmarshal(t *testing.T) {
@@ -1090,7 +1079,7 @@ func TestFillDefaultUnmarshal(t *testing.T) {
assert.Equal(t, st.C, "c")
})
t.Run("has value", func(t *testing.T) {
t.Run("has vaue", func(t *testing.T) {
type St struct {
A string `json:",default=a"`
B string
@@ -1102,278 +1091,3 @@ func TestFillDefaultUnmarshal(t *testing.T) {
assert.Error(t, err)
})
}
func TestConfigWithJsonTag(t *testing.T) {
t.Run("map with value", func(t *testing.T) {
var input = []byte(`[Value]
[Value.first]
Email = "foo"
[Value.second]
Email = "bar"`)
type Value struct {
Email string
}
type Config struct {
ValueMap map[string]Value `json:"Value"`
}
var c Config
if assert.NoError(t, LoadFromTomlBytes(input, &c)) {
assert.Len(t, c.ValueMap, 2)
}
})
t.Run("map with ptr value", func(t *testing.T) {
var input = []byte(`[Value]
[Value.first]
Email = "foo"
[Value.second]
Email = "bar"`)
type Value struct {
Email string
}
type Config struct {
ValueMap map[string]*Value `json:"Value"`
}
var c Config
if assert.NoError(t, LoadFromTomlBytes(input, &c)) {
assert.Len(t, c.ValueMap, 2)
}
})
t.Run("map with optional", func(t *testing.T) {
var input = []byte(`[Value]
[Value.first]
Email = "foo"
[Value.second]
Email = "bar"`)
type Value struct {
Email string
}
type Config struct {
Value map[string]Value `json:",optional"`
}
var c Config
if assert.NoError(t, LoadFromTomlBytes(input, &c)) {
assert.Len(t, c.Value, 2)
}
})
t.Run("map with empty tag", func(t *testing.T) {
var input = []byte(`[Value]
[Value.first]
Email = "foo"
[Value.second]
Email = "bar"`)
type Value struct {
Email string
}
type Config struct {
Value map[string]Value `json:" "`
}
var c Config
if assert.NoError(t, LoadFromTomlBytes(input, &c)) {
assert.Len(t, c.Value, 2)
}
})
t.Run("multi layer map", func(t *testing.T) {
type Value struct {
User struct {
Name string
}
}
type Config struct {
Value map[string]map[string]Value
}
var input = []byte(`
[Value.first.User1.User]
Name = "foo"
[Value.second.User2.User]
Name = "bar"
`)
var c Config
if assert.NoError(t, LoadFromTomlBytes(input, &c)) {
assert.Len(t, c.Value, 2)
}
})
}
func Test_LoadBadConfig(t *testing.T) {
type Config struct {
Name string `json:"name,options=foo|bar"`
}
file, err := createTempFile(t, ".json", `{"name": "baz"}`)
assert.NoError(t, err)
var c Config
err = Load(file, &c)
assert.Error(t, err)
}
func Test_getFullName(t *testing.T) {
assert.Equal(t, "a.b", getFullName("a", "b"))
assert.Equal(t, "a", getFullName("", "a"))
}
func TestValidate(t *testing.T) {
t.Run("normal config", func(t *testing.T) {
var c mockConfig
err := LoadFromJsonBytes([]byte(`{"val": "hello", "number": 8}`), &c)
assert.NoError(t, err)
})
t.Run("error no int", func(t *testing.T) {
var c mockConfig
err := LoadFromJsonBytes([]byte(`{"val": "hello"}`), &c)
assert.Error(t, err)
})
t.Run("error no string", func(t *testing.T) {
var c mockConfig
err := LoadFromJsonBytes([]byte(`{"number": 8}`), &c)
assert.Error(t, err)
})
}
func Test_buildFieldsInfo(t *testing.T) {
type ParentSt struct {
Name string
M map[string]int
}
tests := []struct {
name string
t reflect.Type
ok bool
containsKey string
}{
{
name: "normal",
t: reflect.TypeOf(struct{ A string }{}),
ok: true,
},
{
name: "struct anonymous",
t: reflect.TypeOf(struct {
ParentSt
Name string
}{}),
ok: false,
containsKey: newConflictKeyError("name").Error(),
},
{
name: "struct ptr anonymous",
t: reflect.TypeOf(struct {
*ParentSt
Name string
}{}),
ok: false,
containsKey: newConflictKeyError("name").Error(),
},
{
name: "more struct anonymous",
t: reflect.TypeOf(struct {
Value struct {
ParentSt
Name string
}
}{}),
ok: false,
containsKey: newConflictKeyError("value.name").Error(),
},
{
name: "map anonymous",
t: reflect.TypeOf(struct {
ParentSt
M string
}{}),
ok: false,
containsKey: newConflictKeyError("m").Error(),
},
{
name: "map more anonymous",
t: reflect.TypeOf(struct {
Value struct {
ParentSt
M string
}
}{}),
ok: false,
containsKey: newConflictKeyError("value.m").Error(),
},
{
name: "struct slice anonymous",
t: reflect.TypeOf([]struct {
ParentSt
Name string
}{}),
ok: false,
containsKey: newConflictKeyError("name").Error(),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := buildFieldsInfo(tt.t, "")
if tt.ok {
assert.NoError(t, err)
} else {
assert.Error(t, err)
assert.Equal(t, err.Error(), tt.containsKey)
}
})
}
}
func createTempFile(t *testing.T, ext, text string) (string, error) {
tmpFile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext)
if err != nil {
return "", err
}
if err = os.WriteFile(tmpFile.Name(), []byte(text), os.ModeTemporary); err != nil {
return "", err
}
filename := tmpFile.Name()
if err = tmpFile.Close(); err != nil {
return "", err
}
t.Cleanup(func() {
_ = os.Remove(filename)
})
return filename, nil
}
type mockConfig struct {
Val string
Number int
}
func (m mockConfig) Validate() error {
if len(m.Val) == 0 {
return errors.New("val is empty")
}
if m.Number == 0 {
return errors.New("number is zero")
}
return nil
}

View File

@@ -45,7 +45,8 @@ func TestPropertiesEnv(t *testing.T) {
assert.Nil(t, err)
defer os.Remove(tmpfile)
t.Setenv("FOO", "2")
os.Setenv("FOO", "2")
defer os.Unsetenv("FOO")
props, err := LoadProperties(tmpfile, UseEnv())
assert.Nil(t, err)

View File

@@ -12,7 +12,7 @@ type RestfulConf struct {
MaxConns int `json:",default=10000"`
MaxBytes int64 `json:",default=1048576"`
Timeout time.Duration `json:",default=3s"`
CpuThreshold int64 `json:",default=900,range=[0:1000)"`
CpuThreshold int64 `json:",default=900,range=[0:1000]"`
}
```

View File

@@ -1,12 +0,0 @@
package conf
import "github.com/zeromicro/go-zero/core/validation"
// validate validates the value if it implements the Validator interface.
func validate(v any) error {
if val, ok := v.(validation.Validator); ok {
return val.Validate()
}
return nil
}

View File

@@ -1,81 +0,0 @@
package conf
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
type mockType int
func (m mockType) Validate() error {
if m < 10 {
return errors.New("invalid value")
}
return nil
}
type anotherMockType int
func Test_validate(t *testing.T) {
tests := []struct {
name string
v any
wantErr bool
}{
{
name: "invalid",
v: mockType(5),
wantErr: true,
},
{
name: "valid",
v: mockType(10),
wantErr: false,
},
{
name: "not validator",
v: anotherMockType(5),
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validate(tt.v)
assert.Equal(t, tt.wantErr, err != nil)
})
}
}
type mockVal struct {
}
func (m mockVal) Validate() error {
return errors.New("invalid value")
}
func Test_validateValPtr(t *testing.T) {
tests := []struct {
name string
v any
wantErr bool
}{
{
name: "invalid",
v: mockVal{},
},
{
name: "invalid value",
v: &mockVal{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Error(t, validate(tt.v))
})
}
}

View File

@@ -1,200 +0,0 @@
package configurator
import (
"errors"
"fmt"
"reflect"
"strings"
"sync"
"sync/atomic"
"github.com/zeromicro/go-zero/core/configcenter/subscriber"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/mapping"
"github.com/zeromicro/go-zero/core/threading"
)
var (
errEmptyConfig = errors.New("empty config value")
errMissingUnmarshalerType = errors.New("missing unmarshaler type")
)
// Configurator is the interface for configuration center.
type Configurator[T any] interface {
// GetConfig returns the subscription value.
GetConfig() (T, error)
// AddListener adds a listener to the subscriber.
AddListener(listener func())
}
type (
// Config is the configuration for Configurator.
Config struct {
// Type is the value type, yaml, json or toml.
Type string `json:",default=yaml,options=[yaml,json,toml]"`
// Log is the flag to control logging.
Log bool `json:",default=true"`
}
configCenter[T any] struct {
conf Config
unmarshaler LoaderFn
subscriber subscriber.Subscriber
listeners []func()
lock sync.Mutex
snapshot atomic.Value
}
value[T any] struct {
data string
marshalData T
err error
}
)
// Configurator is the interface for configuration center.
var _ Configurator[any] = (*configCenter[any])(nil)
// MustNewConfigCenter returns a Configurator, exits on errors.
func MustNewConfigCenter[T any](c Config, subscriber subscriber.Subscriber) Configurator[T] {
cc, err := NewConfigCenter[T](c, subscriber)
logx.Must(err)
return cc
}
// NewConfigCenter returns a Configurator.
func NewConfigCenter[T any](c Config, subscriber subscriber.Subscriber) (Configurator[T], error) {
unmarshaler, ok := Unmarshaler(strings.ToLower(c.Type))
if !ok {
return nil, fmt.Errorf("unknown format: %s", c.Type)
}
cc := &configCenter[T]{
conf: c,
unmarshaler: unmarshaler,
subscriber: subscriber,
}
if err := cc.loadConfig(); err != nil {
return nil, err
}
if err := cc.subscriber.AddListener(cc.onChange); err != nil {
return nil, err
}
if _, err := cc.GetConfig(); err != nil {
return nil, err
}
return cc, nil
}
// AddListener adds listener to s.
func (c *configCenter[T]) AddListener(listener func()) {
c.lock.Lock()
defer c.lock.Unlock()
c.listeners = append(c.listeners, listener)
}
// GetConfig return structured config.
func (c *configCenter[T]) GetConfig() (T, error) {
v := c.value()
if v == nil || len(v.data) == 0 {
var empty T
return empty, errEmptyConfig
}
return v.marshalData, v.err
}
// Value returns the subscription value.
func (c *configCenter[T]) Value() string {
v := c.value()
if v == nil {
return ""
}
return v.data
}
func (c *configCenter[T]) loadConfig() error {
v, err := c.subscriber.Value()
if err != nil {
if c.conf.Log {
logx.Errorf("ConfigCenter loads changed configuration, error: %v", err)
}
return err
}
if c.conf.Log {
logx.Infof("ConfigCenter loads changed configuration, content [%s]", v)
}
c.snapshot.Store(c.genValue(v))
return nil
}
func (c *configCenter[T]) onChange() {
if err := c.loadConfig(); err != nil {
return
}
c.lock.Lock()
listeners := make([]func(), len(c.listeners))
copy(listeners, c.listeners)
c.lock.Unlock()
for _, l := range listeners {
threading.GoSafe(l)
}
}
func (c *configCenter[T]) value() *value[T] {
content := c.snapshot.Load()
if content == nil {
return nil
}
return content.(*value[T])
}
func (c *configCenter[T]) genValue(data string) *value[T] {
v := &value[T]{
data: data,
}
if len(data) == 0 {
return v
}
t := reflect.TypeOf(v.marshalData)
// if the type is nil, it means that the user has not set the type of the configuration.
if t == nil {
v.err = errMissingUnmarshalerType
return v
}
t = mapping.Deref(t)
switch t.Kind() {
case reflect.Struct, reflect.Array, reflect.Slice:
if err := c.unmarshaler([]byte(data), &v.marshalData); err != nil {
v.err = err
if c.conf.Log {
logx.Errorf("ConfigCenter unmarshal configuration failed, err: %+v, content [%s]",
err.Error(), data)
}
}
case reflect.String:
if str, ok := any(data).(T); ok {
v.marshalData = str
} else {
v.err = errMissingUnmarshalerType
}
default:
if c.conf.Log {
logx.Errorf("ConfigCenter unmarshal configuration missing unmarshaler for type: %s, content [%s]",
t.Kind(), data)
}
v.err = errMissingUnmarshalerType
}
return v
}

View File

@@ -1,233 +0,0 @@
package configurator
import (
"errors"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestNewConfigCenter(t *testing.T) {
_, err := NewConfigCenter[any](Config{
Log: true,
}, &mockSubscriber{})
assert.Error(t, err)
_, err = NewConfigCenter[any](Config{
Type: "json",
Log: true,
}, &mockSubscriber{})
assert.Error(t, err)
}
func TestConfigCenter_GetConfig(t *testing.T) {
mock := &mockSubscriber{}
type Data struct {
Name string `json:"name"`
}
mock.v = `{"name": "go-zero"}`
c1, err := NewConfigCenter[Data](Config{
Type: "json",
Log: true,
}, mock)
assert.NoError(t, err)
data, err := c1.GetConfig()
assert.NoError(t, err)
assert.Equal(t, "go-zero", data.Name)
mock.v = `{"name": "111"}`
c2, err := NewConfigCenter[Data](Config{Type: "json"}, mock)
assert.NoError(t, err)
mock.v = `{}`
c3, err := NewConfigCenter[string](Config{
Type: "json",
Log: true,
}, mock)
assert.NoError(t, err)
_, err = c3.GetConfig()
assert.NoError(t, err)
data, err = c2.GetConfig()
assert.NoError(t, err)
mock.lisErr = errors.New("mock error")
_, err = NewConfigCenter[Data](Config{
Type: "json",
Log: true,
}, mock)
assert.Error(t, err)
}
func TestConfigCenter_onChange(t *testing.T) {
mock := &mockSubscriber{}
type Data struct {
Name string `json:"name"`
}
mock.v = `{"name": "go-zero"}`
c1, err := NewConfigCenter[Data](Config{Type: "json", Log: true}, mock)
assert.NoError(t, err)
data, err := c1.GetConfig()
assert.NoError(t, err)
assert.Equal(t, "go-zero", data.Name)
mock.v = `{"name": "go-zero2"}`
mock.change()
data, err = c1.GetConfig()
assert.NoError(t, err)
assert.Equal(t, "go-zero2", data.Name)
mock.valErr = errors.New("mock error")
_, err = NewConfigCenter[Data](Config{Type: "json", Log: false}, mock)
assert.Error(t, err)
}
func TestConfigCenter_Value(t *testing.T) {
mock := &mockSubscriber{}
mock.v = "1234"
c, err := NewConfigCenter[string](Config{
Type: "json",
Log: true,
}, mock)
assert.NoError(t, err)
cc := c.(*configCenter[string])
assert.Equal(t, cc.Value(), "1234")
mock.valErr = errors.New("mock error")
_, err = NewConfigCenter[any](Config{
Type: "json",
Log: true,
}, mock)
assert.Error(t, err)
}
func TestConfigCenter_AddListener(t *testing.T) {
mock := &mockSubscriber{}
mock.v = "1234"
c, err := NewConfigCenter[string](Config{
Type: "json",
Log: true,
}, mock)
assert.NoError(t, err)
cc := c.(*configCenter[string])
var a, b int
var mutex sync.Mutex
cc.AddListener(func() {
mutex.Lock()
a = 1
mutex.Unlock()
})
cc.AddListener(func() {
mutex.Lock()
b = 2
mutex.Unlock()
})
assert.Equal(t, 2, len(cc.listeners))
mock.change()
time.Sleep(time.Millisecond * 100)
mutex.Lock()
assert.Equal(t, 1, a)
assert.Equal(t, 2, b)
mutex.Unlock()
}
func TestConfigCenter_genValue(t *testing.T) {
t.Run("data is empty", func(t *testing.T) {
c := &configCenter[string]{
unmarshaler: registry.unmarshalers["json"],
conf: Config{Log: true},
}
v := c.genValue("")
assert.Equal(t, "", v.data)
})
t.Run("invalid template type", func(t *testing.T) {
c := &configCenter[any]{
unmarshaler: registry.unmarshalers["json"],
conf: Config{Log: true},
}
v := c.genValue("xxxx")
assert.Equal(t, errMissingUnmarshalerType, v.err)
})
t.Run("unsupported template type", func(t *testing.T) {
c := &configCenter[int]{
unmarshaler: registry.unmarshalers["json"],
conf: Config{Log: true},
}
v := c.genValue("1")
assert.Equal(t, errMissingUnmarshalerType, v.err)
})
t.Run("supported template string type", func(t *testing.T) {
c := &configCenter[string]{
unmarshaler: registry.unmarshalers["json"],
conf: Config{Log: true},
}
v := c.genValue("12345")
assert.NoError(t, v.err)
assert.Equal(t, "12345", v.data)
})
t.Run("unmarshal fail", func(t *testing.T) {
c := &configCenter[struct {
Name string `json:"name"`
}]{
unmarshaler: registry.unmarshalers["json"],
conf: Config{Log: true},
}
v := c.genValue(`{"name":"new name}`)
assert.Equal(t, `{"name":"new name}`, v.data)
assert.Error(t, v.err)
})
t.Run("success", func(t *testing.T) {
c := &configCenter[struct {
Name string `json:"name"`
}]{
unmarshaler: registry.unmarshalers["json"],
conf: Config{Log: true},
}
v := c.genValue(`{"name":"new name"}`)
assert.Equal(t, `{"name":"new name"}`, v.data)
assert.Equal(t, "new name", v.marshalData.Name)
assert.NoError(t, v.err)
})
}
type mockSubscriber struct {
v string
lisErr, valErr error
listener func()
}
func (m *mockSubscriber) AddListener(listener func()) error {
m.listener = listener
return m.lisErr
}
func (m *mockSubscriber) Value() (string, error) {
return m.v, m.valErr
}
func (m *mockSubscriber) change() {
if m.listener != nil {
m.listener()
}
}

View File

@@ -1,67 +0,0 @@
package subscriber
import (
"github.com/zeromicro/go-zero/core/discov"
"github.com/zeromicro/go-zero/core/logx"
)
type (
// etcdSubscriber is a subscriber that subscribes to etcd.
etcdSubscriber struct {
*discov.Subscriber
}
// EtcdConf is the configuration for etcd.
EtcdConf = discov.EtcdConf
)
// MustNewEtcdSubscriber returns an etcd Subscriber, exits on errors.
func MustNewEtcdSubscriber(conf EtcdConf) Subscriber {
s, err := NewEtcdSubscriber(conf)
logx.Must(err)
return s
}
// NewEtcdSubscriber returns an etcd Subscriber.
func NewEtcdSubscriber(conf EtcdConf) (Subscriber, error) {
opts := buildSubOptions(conf)
s, err := discov.NewSubscriber(conf.Hosts, conf.Key, opts...)
if err != nil {
return nil, err
}
return &etcdSubscriber{Subscriber: s}, nil
}
// buildSubOptions constructs the options for creating a new etcd subscriber.
func buildSubOptions(conf EtcdConf) []discov.SubOption {
opts := []discov.SubOption{
discov.WithExactMatch(),
}
if len(conf.User) > 0 {
opts = append(opts, discov.WithSubEtcdAccount(conf.User, conf.Pass))
}
if len(conf.CertFile) > 0 || len(conf.CertKeyFile) > 0 || len(conf.CACertFile) > 0 {
opts = append(opts, discov.WithSubEtcdTLS(conf.CertFile, conf.CertKeyFile,
conf.CACertFile, conf.InsecureSkipVerify))
}
return opts
}
// AddListener adds a listener to the subscriber.
func (s *etcdSubscriber) AddListener(listener func()) error {
s.Subscriber.AddListener(listener)
return nil
}
// Value returns the value of the subscriber.
func (s *etcdSubscriber) Value() (string, error) {
vs := s.Subscriber.Values()
if len(vs) > 0 {
return vs[len(vs)-1], nil
}
return "", nil
}

View File

@@ -1,9 +0,0 @@
package subscriber
// Subscriber is the interface for configcenter subscribers.
type Subscriber interface {
// AddListener adds a listener to the subscriber.
AddListener(listener func()) error
// Value returns the value of the subscriber.
Value() (string, error)
}

View File

@@ -1,41 +0,0 @@
package configurator
import (
"sync"
"github.com/zeromicro/go-zero/core/conf"
)
var registry = &unmarshalerRegistry{
unmarshalers: map[string]LoaderFn{
"json": conf.LoadFromJsonBytes,
"toml": conf.LoadFromTomlBytes,
"yaml": conf.LoadFromYamlBytes,
},
}
type (
// LoaderFn is the function type for loading configuration.
LoaderFn func([]byte, any) error
// unmarshalerRegistry is the registry for unmarshalers.
unmarshalerRegistry struct {
unmarshalers map[string]LoaderFn
mu sync.RWMutex
}
)
// RegisterUnmarshaler registers an unmarshaler.
func RegisterUnmarshaler(name string, fn LoaderFn) {
registry.mu.Lock()
defer registry.mu.Unlock()
registry.unmarshalers[name] = fn
}
// Unmarshaler returns the unmarshaler by name.
func Unmarshaler(name string) (LoaderFn, bool) {
registry.mu.RLock()
defer registry.mu.RUnlock()
fn, ok := registry.unmarshalers[name]
return fn, ok
}

View File

@@ -1,28 +0,0 @@
package configurator
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestRegisterUnmarshaler(t *testing.T) {
RegisterUnmarshaler("test", func(data []byte, v interface{}) error {
return nil
})
_, ok := Unmarshaler("test")
assert.True(t, ok)
_, ok = Unmarshaler("test2")
assert.False(t, ok)
_, ok = Unmarshaler("json")
assert.True(t, ok)
_, ok = Unmarshaler("toml")
assert.True(t, ok)
_, ok = Unmarshaler("yaml")
assert.True(t, ok)
}

View File

@@ -14,13 +14,13 @@ type contextValuer struct {
context.Context
}
func (cv contextValuer) Value(key string) (any, bool) {
func (cv contextValuer) Value(key string) (interface{}, bool) {
v := cv.Context.Value(key)
return v, v != nil
}
// For unmarshals ctx into v.
func For(ctx context.Context, v any) error {
func For(ctx context.Context, v interface{}) error {
return unmarshaler.UnmarshalValuer(contextValuer{
Context: ctx,
}, v)

View File

@@ -13,7 +13,6 @@ var (
type EtcdConf struct {
Hosts []string
Key string
ID int64 `json:",optional"`
User string `json:",optional"`
Pass string `json:",optional"`
CertFile string `json:",optional"`
@@ -27,11 +26,6 @@ func (c EtcdConf) HasAccount() bool {
return len(c.User) > 0 && len(c.Pass) > 0
}
// HasID returns if ID provided.
func (c EtcdConf) HasID() bool {
return c.ID > 0
}
// HasTLS returns if TLS CertFile/CertKeyFile/CACertFile are provided.
func (c EtcdConf) HasTLS() bool {
return len(c.CertFile) > 0 && len(c.CertKeyFile) > 0 && len(c.CACertFile) > 0

View File

@@ -80,90 +80,3 @@ func TestEtcdConf_HasAccount(t *testing.T) {
assert.Equal(t, test.hasAccount, test.EtcdConf.HasAccount())
}
}
func TestEtcdConf_HasID(t *testing.T) {
tests := []struct {
EtcdConf
hasServerID bool
}{
{
EtcdConf: EtcdConf{
Hosts: []string{"any"},
ID: -1,
},
hasServerID: false,
},
{
EtcdConf: EtcdConf{
Hosts: []string{"any"},
ID: 0,
},
hasServerID: false,
},
{
EtcdConf: EtcdConf{
Hosts: []string{"any"},
ID: 10000,
},
hasServerID: true,
},
}
for _, test := range tests {
assert.Equal(t, test.hasServerID, test.EtcdConf.HasID())
}
}
func TestEtcdConf_HasTLS(t *testing.T) {
tests := []struct {
name string
conf EtcdConf
want bool
}{
{
name: "empty config",
conf: EtcdConf{},
want: false,
},
{
name: "missing CertFile",
conf: EtcdConf{
CertKeyFile: "key",
CACertFile: "ca",
},
want: false,
},
{
name: "missing CertKeyFile",
conf: EtcdConf{
CertFile: "cert",
CACertFile: "ca",
},
want: false,
},
{
name: "missing CACertFile",
conf: EtcdConf{
CertFile: "cert",
CertKeyFile: "key",
},
want: false,
},
{
name: "valid config",
conf: EtcdConf{
CertFile: "cert",
CertKeyFile: "key",
CACertFile: "ca",
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.conf.HasTLS()
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -1,85 +1,12 @@
package internal
import (
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stringx"
)
const (
certContent = `-----BEGIN CERTIFICATE-----
MIIDazCCAlOgAwIBAgIUEg9GVO2oaPn+YSmiqmFIuAo10WIwDQYJKoZIhvcNAQEM
BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAgFw0yMzAzMTExMzIxMjNaGA8yMTIz
MDIxNTEzMjEyM1owRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx
ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcN
AQEBBQADggEPADCCAQoCggEBALplXlWsIf0O/IgnIplmiZHKGnxyfyufyE2FBRNk
OofRqbKuPH8GNqbkvZm7N29fwTDAQ+mViAggCkDht4hOzoWJMA7KYJt8JnTSWL48
M1lcrpc9DL2gszC/JF/FGvyANbBtLklkZPFBGdHUX14pjrT937wqPtm+SqUHSvRT
B7bmwmm2drRcmhpVm98LSlV7uQ2EgnJgsLjBPITKUejLmVLHfgX0RwQ2xIpX9pS4
FCe1BTacwl2gGp7Mje7y4Mfv3o0ArJW6Tuwbjx59ZXwb1KIP71b7bT04AVS8ZeYO
UMLKKuB5UR9x9Rn6cLXOTWBpcMVyzDgrAFLZjnE9LPUolZMCAwEAAaNRME8wHwYD
VR0jBBgwFoAUeW8w8pmhncbRgTsl48k4/7wnfx8wCQYDVR0TBAIwADALBgNVHQ8E
BAMCBPAwFAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBDAUAA4IBAQAI
y9xaoS88CLPBsX6mxfcTAFVfGNTRW9VN9Ng1cCnUR+YGoXGM/l+qP4f7p8ocdGwK
iYZErVTzXYIn+D27//wpY3klJk3gAnEUBT3QRkStBw7XnpbeZ2oPBK+cmDnCnZPS
BIF1wxPX7vIgaxs5Zsdqwk3qvZ4Djr2wP7LabNWTLSBKgQoUY45Liw6pffLwcGF9
UKlu54bvGze2SufISCR3ib+I+FLvqpvJhXToZWYb/pfI/HccuCL1oot1x8vx6DQy
U+TYxlZsKS5mdNxAX3dqEkEMsgEi+g/tzDPXJImfeCGGBhIOXLm8SRypiuGdEbc9
xkWYxRPegajuEZGvCqVs
-----END CERTIFICATE-----`
keyContent = `-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAumVeVawh/Q78iCcimWaJkcoafHJ/K5/ITYUFE2Q6h9Gpsq48
fwY2puS9mbs3b1/BMMBD6ZWICCAKQOG3iE7OhYkwDspgm3wmdNJYvjwzWVyulz0M
vaCzML8kX8Ua/IA1sG0uSWRk8UEZ0dRfXimOtP3fvCo+2b5KpQdK9FMHtubCabZ2
tFyaGlWb3wtKVXu5DYSCcmCwuME8hMpR6MuZUsd+BfRHBDbEilf2lLgUJ7UFNpzC
XaAansyN7vLgx+/ejQCslbpO7BuPHn1lfBvUog/vVvttPTgBVLxl5g5Qwsoq4HlR
H3H1Gfpwtc5NYGlwxXLMOCsAUtmOcT0s9SiVkwIDAQABAoIBAD5meTJNMgO55Kjg
ESExxpRcCIno+tHr5+6rvYtEXqPheOIsmmwb9Gfi4+Z3WpOaht5/Pz0Ppj6yGzyl
U//6AgGKb+BDuBvVcDpjwPnOxZIBCSHwejdxeQu0scSuA97MPS0XIAvJ5FEv7ijk
5Bht6SyGYURpECltHygoTNuGgGqmO+McCJRLE9L09lTBI6UQ/JQwWJqSr7wx6iPU
M1Ze/srIV+7cyEPu6i0DGjS1gSQKkX68Lqn1w6oE290O+OZvleO0gZ02fLDWCZke
aeD9+EU/Pw+rqm3H6o0szOFIpzhRp41FUdW9sybB3Yp3u7c/574E+04Z/e30LMKs
TCtE1QECgYEA3K7KIpw0NH2HXL5C3RHcLmr204xeBfS70riBQQuVUgYdmxak2ima
80RInskY8hRhSGTg0l+VYIH8cmjcUyqMSOELS5XfRH99r4QPiK8AguXg80T4VumY
W3Pf+zEC2ssgP/gYthV0g0Xj5m2QxktOF9tRw5nkg739ZR4dI9lm/iECgYEA2Dnf
uwEDGqHiQRF6/fh5BG/nGVMvrefkqx6WvTJQ3k/M/9WhxB+lr/8yH46TuS8N2b29
FoTf3Mr9T7pr/PWkOPzoY3P56nYbKU8xSwCim9xMzhBMzj8/N9ukJvXy27/VOz56
eQaKqnvdXNGtPJrIMDGHps2KKWlKLyAlapzjVTMCgYAA/W++tACv85g13EykfT4F
n0k4LbsGP9DP4zABQLIMyiY72eAncmRVjwrcW36XJ2xATOONTgx3gF3HjZzfaqNy
eD/6uNNllUTVEryXGmHgNHPL45VRnn6memCY2eFvZdXhM5W4y2PYaunY0MkDercA
+GTngbs6tBF88KOk04bYwQKBgFl68cRgsdkmnwwQYNaTKfmVGYzYaQXNzkqmWPko
xmCJo6tHzC7ubdG8iRCYHzfmahPuuj6EdGPZuSRyYFgJi5Ftz/nAN+84OxtIQ3zn
YWOgskQgaLh9YfsKsQ7Sf1NDOsnOnD5TX7UXl07fEpLe9vNCvAFiU8e5Y9LGudU5
4bYTAoGBAMdX3a3bXp4cZvXNBJ/QLVyxC6fP1Q4haCR1Od3m+T00Jth2IX2dk/fl
p6xiJT1av5JtYabv1dFKaXOS5s1kLGGuCCSKpkvFZm826aQ2AFm0XGqEQDLeei5b
A52Kpy/YJ+RkG4BTFtAooFq6DmA0cnoP6oPvG2h6XtDJwDTPInJb
-----END RSA PRIVATE KEY-----`
caContent = `-----BEGIN CERTIFICATE-----
MIIDbTCCAlWgAwIBAgIUBJvFoCowKich7MMfseJ+DYzzirowDQYJKoZIhvcNAQEM
BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAgFw0yMzAzMTExMzIxMDNaGA8yMTIz
MDIxNTEzMjEwM1owRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx
ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcN
AQEBBQADggEPADCCAQoCggEBAO4to2YMYj0bxgr2FCiweSTSFuPx33zSw2x/s9Wf
OR41bm2DFsyYT5f3sOIKlXZEdLmOKty2e3ho3yC0EyNpVHdykkkHT3aDI17quZax
kYi/URqqtl1Z08A22txolc04hAZisg2BypGi3vql81UW1t3zyloGnJoIAeXR9uca
ljP6Bk3bwsxoVBLi1JtHrO0hHLQaeHmKhAyrys06X0LRdn7Px48yRZlt6FaLSa8X
YiRM0G44bVy/h6BkoQjMYGwVmCVk6zjJ9U7ZPFqdnDMNxAfR+hjDnYodqdLDMTTR
1NPVrnEnNwFx0AMLvgt/ba/45vZCEAmSZnFXFAJJcM7ai9ECAwEAAaNTMFEwHQYD
VR0OBBYEFHlvMPKZoZ3G0YE7JePJOP+8J38fMB8GA1UdIwQYMBaAFHlvMPKZoZ3G
0YE7JePJOP+8J38fMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQEMBQADggEB
AMX8dNulADOo9uQgBMyFb9TVra7iY0zZjzv4GY5XY7scd52n6CnfAPvYBBDnTr/O
BgNp5jaujb4+9u/2qhV3f9n+/3WOb2CmPehBgVSzlXqHeQ9lshmgwZPeem2T+8Tm
Nnc/xQnsUfCFszUDxpkr55+aLVM22j02RWqcZ4q7TAaVYL+kdFVMc8FoqG/0ro6A
BjE/Qn0Nn7ciX1VUjDt8l+k7ummPJTmzdi6i6E4AwO9dzrGNgGJ4aWL8cC6xYcIX
goVIRTFeONXSDno/oPjWHpIPt7L15heMpKBHNuzPkKx2YVqPHE5QZxWfS+Lzgx+Q
E2oTTM0rYKOZ8p6000mhvKI=
-----END CERTIFICATE-----`
)
func TestAccount(t *testing.T) {
endpoints := []string{
"192.168.0.2:2379",
@@ -105,34 +32,3 @@ func TestAccount(t *testing.T) {
assert.Equal(t, username, account.User)
assert.Equal(t, anotherPassword, account.Pass)
}
func TestTLSMethods(t *testing.T) {
certFile := createTempFile(t, []byte(certContent))
defer os.Remove(certFile)
keyFile := createTempFile(t, []byte(keyContent))
defer os.Remove(keyFile)
caFile := createTempFile(t, []byte(caContent))
defer os.Remove(caFile)
assert.NoError(t, AddTLS([]string{"foo"}, certFile, keyFile, caFile, false))
cfg, ok := GetTLS([]string{"foo"})
assert.True(t, ok)
assert.NotNil(t, cfg)
assert.Error(t, AddTLS([]string{"bar"}, "bad-file", keyFile, caFile, false))
assert.Error(t, AddTLS([]string{"bar"}, certFile, keyFile, "bad-file", false))
}
func createTempFile(t *testing.T, body []byte) string {
tmpFile, err := os.CreateTemp(os.TempDir(), "go-unit-*.tmp")
if err != nil {
t.Fatal(err)
}
tmpFile.Close()
if err = os.WriteFile(tmpFile.Name(), body, os.ModePerm); err != nil {
t.Fatal(err)
}
return tmpFile.Name()
}

View File

@@ -81,7 +81,7 @@ func (mr *MockEtcdClientMockRecorder) Ctx() *gomock.Call {
// Get mocks base method
func (m *MockEtcdClient) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, key}
varargs := []interface{}{ctx, key}
for _, a := range opts {
varargs = append(varargs, a)
}
@@ -92,9 +92,9 @@ func (m *MockEtcdClient) Get(ctx context.Context, key string, opts ...clientv3.O
}
// Get indicates an expected call of Get
func (mr *MockEtcdClientMockRecorder) Get(ctx, key any, opts ...any) *gomock.Call {
func (mr *MockEtcdClientMockRecorder) Get(ctx, key interface{}, opts ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, key}, opts...)
varargs := append([]interface{}{ctx, key}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockEtcdClient)(nil).Get), varargs...)
}
@@ -108,7 +108,7 @@ func (m *MockEtcdClient) Grant(ctx context.Context, ttl int64) (*clientv3.LeaseG
}
// Grant indicates an expected call of Grant
func (mr *MockEtcdClientMockRecorder) Grant(ctx, ttl any) *gomock.Call {
func (mr *MockEtcdClientMockRecorder) Grant(ctx, ttl interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Grant", reflect.TypeOf((*MockEtcdClient)(nil).Grant), ctx, ttl)
}
@@ -123,7 +123,7 @@ func (m *MockEtcdClient) KeepAlive(ctx context.Context, id clientv3.LeaseID) (<-
}
// KeepAlive indicates an expected call of KeepAlive
func (mr *MockEtcdClientMockRecorder) KeepAlive(ctx, id any) *gomock.Call {
func (mr *MockEtcdClientMockRecorder) KeepAlive(ctx, id interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeepAlive", reflect.TypeOf((*MockEtcdClient)(nil).KeepAlive), ctx, id)
}
@@ -131,7 +131,7 @@ func (mr *MockEtcdClientMockRecorder) KeepAlive(ctx, id any) *gomock.Call {
// Put mocks base method
func (m *MockEtcdClient) Put(ctx context.Context, key, val string, opts ...clientv3.OpOption) (*clientv3.PutResponse, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, key, val}
varargs := []interface{}{ctx, key, val}
for _, a := range opts {
varargs = append(varargs, a)
}
@@ -142,9 +142,9 @@ func (m *MockEtcdClient) Put(ctx context.Context, key, val string, opts ...clien
}
// Put indicates an expected call of Put
func (mr *MockEtcdClientMockRecorder) Put(ctx, key, val any, opts ...any) *gomock.Call {
func (mr *MockEtcdClientMockRecorder) Put(ctx, key, val interface{}, opts ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, key, val}, opts...)
varargs := append([]interface{}{ctx, key, val}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockEtcdClient)(nil).Put), varargs...)
}
@@ -158,7 +158,7 @@ func (m *MockEtcdClient) Revoke(ctx context.Context, id clientv3.LeaseID) (*clie
}
// Revoke indicates an expected call of Revoke
func (mr *MockEtcdClientMockRecorder) Revoke(ctx, id any) *gomock.Call {
func (mr *MockEtcdClientMockRecorder) Revoke(ctx, id interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Revoke", reflect.TypeOf((*MockEtcdClient)(nil).Revoke), ctx, id)
}
@@ -166,7 +166,7 @@ func (mr *MockEtcdClientMockRecorder) Revoke(ctx, id any) *gomock.Call {
// Watch mocks base method
func (m *MockEtcdClient) Watch(ctx context.Context, key string, opts ...clientv3.OpOption) clientv3.WatchChan {
m.ctrl.T.Helper()
varargs := []any{ctx, key}
varargs := []interface{}{ctx, key}
for _, a := range opts {
varargs = append(varargs, a)
}
@@ -176,8 +176,8 @@ func (m *MockEtcdClient) Watch(ctx context.Context, key string, opts ...clientv3
}
// Watch indicates an expected call of Watch
func (mr *MockEtcdClientMockRecorder) Watch(ctx, key any, opts ...any) *gomock.Call {
func (mr *MockEtcdClientMockRecorder) Watch(ctx, key interface{}, opts ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, key}, opts...)
varargs := append([]interface{}{ctx, key}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockEtcdClient)(nil).Watch), varargs...)
}

View File

@@ -2,7 +2,6 @@ package internal
import (
"context"
"errors"
"fmt"
"io"
"sort"
@@ -10,30 +9,25 @@ import (
"sync"
"time"
"github.com/zeromicro/go-zero/core/contextx"
"github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/logc"
"github.com/zeromicro/go-zero/core/mathx"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/syncx"
"github.com/zeromicro/go-zero/core/threading"
"go.etcd.io/etcd/api/v3/v3rpc/rpctypes"
clientv3 "go.etcd.io/etcd/client/v3"
)
const coolDownDeviation = 0.05
var (
registry = Registry{
clusters: make(map[string]*cluster),
}
connManager = syncx.NewResourceManager()
coolDownUnstable = mathx.NewUnstable(coolDownDeviation)
errClosed = errors.New("etcd monitor chan has been closed")
connManager = syncx.NewResourceManager()
)
// A Registry is a registry that manages the etcd client connections.
type Registry struct {
clusters map[string]*cluster
lock sync.RWMutex
lock sync.Mutex
}
// GetRegistry returns a global Registry.
@@ -43,148 +37,60 @@ func GetRegistry() *Registry {
// GetConn returns an etcd client connection associated with given endpoints.
func (r *Registry) GetConn(endpoints []string) (EtcdClient, error) {
c, _ := r.getOrCreateCluster(endpoints)
c, _ := r.getCluster(endpoints)
return c.getClient()
}
// Monitor monitors the key on given etcd endpoints, notify with the given UpdateListener.
func (r *Registry) Monitor(endpoints []string, key string, exactMatch bool, l UpdateListener) error {
wkey := watchKey{
key: key,
exactMatch: exactMatch,
}
c, exists := r.getOrCreateCluster(endpoints)
func (r *Registry) Monitor(endpoints []string, key string, l UpdateListener) error {
c, exists := r.getCluster(endpoints)
// if exists, the existing values should be updated to the listener.
if exists {
c.lock.Lock()
watcher, ok := c.watchers[wkey]
if ok {
watcher.listeners = append(watcher.listeners, l)
}
c.lock.Unlock()
if ok {
kvs := c.getCurrent(wkey)
for _, kv := range kvs {
l.OnAdd(kv)
}
return nil
kvs := c.getCurrent(key)
for _, kv := range kvs {
l.OnAdd(kv)
}
}
return c.monitor(wkey, l)
return c.monitor(key, l)
}
func (r *Registry) Unmonitor(endpoints []string, key string, exactMatch bool, l UpdateListener) {
c, exists := r.getCluster(endpoints)
if !exists {
return
}
wkey := watchKey{
key: key,
exactMatch: exactMatch,
}
c.lock.Lock()
defer c.lock.Unlock()
watcher, ok := c.watchers[wkey]
if !ok {
return
}
for i, listener := range watcher.listeners {
if listener == l {
watcher.listeners = append(watcher.listeners[:i], watcher.listeners[i+1:]...)
break
}
}
if len(watcher.listeners) == 0 {
if watcher.cancel != nil {
watcher.cancel()
}
delete(c.watchers, wkey)
}
}
func (r *Registry) getCluster(endpoints []string) (*cluster, bool) {
func (r *Registry) getCluster(endpoints []string) (c *cluster, exists bool) {
clusterKey := getClusterKey(endpoints)
r.lock.RLock()
c, ok := r.clusters[clusterKey]
r.lock.RUnlock()
return c, ok
}
func (r *Registry) getOrCreateCluster(endpoints []string) (c *cluster, exists bool) {
c, exists = r.getCluster(endpoints)
r.lock.Lock()
defer r.lock.Unlock()
c, exists = r.clusters[clusterKey]
if !exists {
clusterKey := getClusterKey(endpoints)
r.lock.Lock()
defer r.lock.Unlock()
// double-check locking
c, exists = r.clusters[clusterKey]
if !exists {
c = newCluster(endpoints)
r.clusters[clusterKey] = c
}
c = newCluster(endpoints)
r.clusters[clusterKey] = c
}
return
}
type (
watchKey struct {
key string
exactMatch bool
}
watchValue struct {
listeners []UpdateListener
values map[string]string
cancel context.CancelFunc
}
cluster struct {
endpoints []string
key string
watchers map[watchKey]*watchValue
watchGroup *threading.RoutineGroup
done chan lang.PlaceholderType
lock sync.RWMutex
}
)
type cluster struct {
endpoints []string
key string
values map[string]map[string]string
listeners map[string][]UpdateListener
watchGroup *threading.RoutineGroup
done chan lang.PlaceholderType
lock sync.Mutex
}
func newCluster(endpoints []string) *cluster {
return &cluster{
endpoints: endpoints,
key: getClusterKey(endpoints),
watchers: make(map[watchKey]*watchValue),
values: make(map[string]map[string]string),
listeners: make(map[string][]UpdateListener),
watchGroup: threading.NewRoutineGroup(),
done: make(chan lang.PlaceholderType),
}
}
func (c *cluster) addListener(key watchKey, l UpdateListener) {
c.lock.Lock()
defer c.lock.Unlock()
watcher, ok := c.watchers[key]
if ok {
watcher.listeners = append(watcher.listeners, l)
return
}
val := newWatchValue()
val.listeners = []UpdateListener{l}
c.watchers[key] = val
func (c *cluster) context(cli EtcdClient) context.Context {
return contextx.ValueOnlyFrom(cli.Ctx())
}
func (c *cluster) getClient() (EtcdClient, error) {
@@ -198,17 +104,12 @@ func (c *cluster) getClient() (EtcdClient, error) {
return val.(EtcdClient), nil
}
func (c *cluster) getCurrent(key watchKey) []KV {
c.lock.RLock()
defer c.lock.RUnlock()
watcher, ok := c.watchers[key]
if !ok {
return nil
}
func (c *cluster) getCurrent(key string) []KV {
c.lock.Lock()
defer c.lock.Unlock()
var kvs []KV
for k, v := range watcher.values {
for k, v := range c.values[key] {
kvs = append(kvs, KV{
Key: k,
Val: v,
@@ -218,23 +119,42 @@ func (c *cluster) getCurrent(key watchKey) []KV {
return kvs
}
func (c *cluster) handleChanges(key watchKey, kvs []KV) {
func (c *cluster) handleChanges(key string, kvs []KV) {
var add []KV
var remove []KV
c.lock.Lock()
watcher, ok := c.watchers[key]
listeners := append([]UpdateListener(nil), c.listeners[key]...)
vals, ok := c.values[key]
if !ok {
c.lock.Unlock()
return
add = kvs
vals = make(map[string]string)
for _, kv := range kvs {
vals[kv.Key] = kv.Val
}
c.values[key] = vals
} else {
m := make(map[string]string)
for _, kv := range kvs {
m[kv.Key] = kv.Val
}
for k, v := range vals {
if val, ok := m[k]; !ok || v != val {
remove = append(remove, KV{
Key: k,
Val: v,
})
}
}
for k, v := range m {
if val, ok := vals[k]; !ok || v != val {
add = append(add, KV{
Key: k,
Val: v,
})
}
}
c.values[key] = m
}
listeners := append([]UpdateListener(nil), watcher.listeners...)
// watcher.values cannot be nil
vals := watcher.values
newVals := make(map[string]string, len(kvs)+len(vals))
for _, kv := range kvs {
newVals[kv.Key] = kv.Val
}
add, remove := calculateChanges(vals, newVals)
watcher.values = newVals
c.lock.Unlock()
for _, kv := range add {
@@ -249,22 +169,20 @@ func (c *cluster) handleChanges(key watchKey, kvs []KV) {
}
}
func (c *cluster) handleWatchEvents(ctx context.Context, key watchKey, events []*clientv3.Event) {
c.lock.RLock()
watcher, ok := c.watchers[key]
if !ok {
c.lock.RUnlock()
return
}
listeners := append([]UpdateListener(nil), watcher.listeners...)
c.lock.RUnlock()
func (c *cluster) handleWatchEvents(key string, events []*clientv3.Event) {
c.lock.Lock()
listeners := append([]UpdateListener(nil), c.listeners[key]...)
c.lock.Unlock()
for _, ev := range events {
switch ev.Type {
case clientv3.EventTypePut:
c.lock.Lock()
watcher.values[string(ev.Kv.Key)] = string(ev.Kv.Value)
if vals, ok := c.values[key]; ok {
vals[string(ev.Kv.Key)] = string(ev.Kv.Value)
} else {
c.values[key] = map[string]string{string(ev.Kv.Key): string(ev.Kv.Value)}
}
c.lock.Unlock()
for _, l := range listeners {
l.OnAdd(KV{
@@ -274,7 +192,9 @@ func (c *cluster) handleWatchEvents(ctx context.Context, key watchKey, events []
}
case clientv3.EventTypeDelete:
c.lock.Lock()
delete(watcher.values, string(ev.Kv.Key))
if vals, ok := c.values[key]; ok {
delete(vals, string(ev.Kv.Key))
}
c.lock.Unlock()
for _, l := range listeners {
l.OnDelete(KV{
@@ -283,29 +203,24 @@ func (c *cluster) handleWatchEvents(ctx context.Context, key watchKey, events []
})
}
default:
logc.Errorf(ctx, "Unknown event type: %v", ev.Type)
logx.Errorf("Unknown event type: %v", ev.Type)
}
}
}
func (c *cluster) load(cli EtcdClient, key watchKey) int64 {
func (c *cluster) load(cli EtcdClient, key string) int64 {
var resp *clientv3.GetResponse
for {
var err error
ctx, cancel := context.WithTimeout(cli.Ctx(), RequestTimeout)
if key.exactMatch {
resp, err = cli.Get(ctx, key.key)
} else {
resp, err = cli.Get(ctx, makeKeyPrefix(key.key), clientv3.WithPrefix())
}
ctx, cancel := context.WithTimeout(c.context(cli), RequestTimeout)
resp, err = cli.Get(ctx, makeKeyPrefix(key), clientv3.WithPrefix())
cancel()
if err == nil {
break
}
logc.Errorf(cli.Ctx(), "%s, key: %s, exactMatch: %t", err.Error(), key.key, key.exactMatch)
time.Sleep(coolDownUnstable.AroundDuration(coolDownInterval))
logx.Error(err)
time.Sleep(coolDownInterval)
}
var kvs []KV
@@ -321,13 +236,16 @@ func (c *cluster) load(cli EtcdClient, key watchKey) int64 {
return resp.Header.Revision
}
func (c *cluster) monitor(key watchKey, l UpdateListener) error {
func (c *cluster) monitor(key string, l UpdateListener) error {
c.lock.Lock()
c.listeners[key] = append(c.listeners[key], l)
c.lock.Unlock()
cli, err := c.getClient()
if err != nil {
return err
}
c.addListener(key, l)
rev := c.load(cli, key)
c.watchGroup.Run(func() {
c.watch(cli, key, rev)
@@ -349,22 +267,16 @@ func (c *cluster) newClient() (EtcdClient, error) {
func (c *cluster) reload(cli EtcdClient) {
c.lock.Lock()
// cancel the previous watches
close(c.done)
c.watchGroup.Wait()
var keys []watchKey
for wk, wval := range c.watchers {
keys = append(keys, wk)
if wval.cancel != nil {
wval.cancel()
}
}
c.done = make(chan lang.PlaceholderType)
c.watchGroup = threading.NewRoutineGroup()
var keys []string
for k := range c.listeners {
keys = append(keys, k)
}
c.lock.Unlock()
// start new watches
for _, key := range keys {
k := key
c.watchGroup.Run(func() {
@@ -374,80 +286,46 @@ func (c *cluster) reload(cli EtcdClient) {
}
}
func (c *cluster) watch(cli EtcdClient, key watchKey, rev int64) {
func (c *cluster) watch(cli EtcdClient, key string, rev int64) {
for {
err := c.watchStream(cli, key, rev)
if err == nil {
if c.watchStream(cli, key, rev) {
return
}
if rev != 0 && errors.Is(err, rpctypes.ErrCompacted) {
logc.Errorf(cli.Ctx(), "etcd watch stream has been compacted, try to reload, rev %d", rev)
rev = c.load(cli, key)
}
// log the error and retry
logc.Error(cli.Ctx(), err)
}
}
func (c *cluster) watchStream(cli EtcdClient, key watchKey, rev int64) error {
ctx, rch := c.setupWatch(cli, key, rev)
func (c *cluster) watchStream(cli EtcdClient, key string, rev int64) bool {
var rch clientv3.WatchChan
if rev != 0 {
rch = cli.Watch(clientv3.WithRequireLeader(c.context(cli)), makeKeyPrefix(key), clientv3.WithPrefix(),
clientv3.WithRev(rev+1))
} else {
rch = cli.Watch(clientv3.WithRequireLeader(c.context(cli)), makeKeyPrefix(key), clientv3.WithPrefix())
}
for {
select {
case wresp, ok := <-rch:
if !ok {
return errClosed
logx.Error("etcd monitor chan has been closed")
return false
}
if wresp.Canceled {
return fmt.Errorf("etcd monitor chan has been canceled, error: %w", wresp.Err())
logx.Errorf("etcd monitor chan has been canceled, error: %v", wresp.Err())
return false
}
if wresp.Err() != nil {
return fmt.Errorf("etcd monitor chan error: %w", wresp.Err())
logx.Error(fmt.Sprintf("etcd monitor chan error: %v", wresp.Err()))
return false
}
c.handleWatchEvents(ctx, key, wresp.Events)
case <-ctx.Done():
return nil
c.handleWatchEvents(key, wresp.Events)
case <-c.done:
return nil
return true
}
}
}
func (c *cluster) setupWatch(cli EtcdClient, key watchKey, rev int64) (context.Context, clientv3.WatchChan) {
var (
rch clientv3.WatchChan
ops []clientv3.OpOption
wkey = key.key
)
if !key.exactMatch {
wkey = makeKeyPrefix(key.key)
ops = append(ops, clientv3.WithPrefix())
}
if rev != 0 {
ops = append(ops, clientv3.WithRev(rev+1))
}
ctx, cancel := context.WithCancel(cli.Ctx())
if watcher, ok := c.watchers[key]; ok {
watcher.cancel = cancel
} else {
val := newWatchValue()
val.cancel = cancel
c.lock.Lock()
c.watchers[key] = val
c.lock.Unlock()
}
rch = cli.Watch(clientv3.WithRequireLeader(ctx), wkey, ops...)
return ctx, rch
}
func (c *cluster) watchConnState(cli EtcdClient) {
watcher := newStateWatcher()
watcher.addListener(func() {
@@ -459,11 +337,13 @@ func (c *cluster) watchConnState(cli EtcdClient) {
// DialClient dials an etcd cluster with given endpoints.
func DialClient(endpoints []string) (EtcdClient, error) {
cfg := clientv3.Config{
Endpoints: endpoints,
AutoSyncInterval: autoSyncInterval,
DialTimeout: DialTimeout,
RejectOldCluster: true,
PermitWithoutStream: true,
Endpoints: endpoints,
AutoSyncInterval: autoSyncInterval,
DialTimeout: DialTimeout,
DialKeepAliveTime: dialKeepAliveTime,
DialKeepAliveTimeout: DialTimeout,
RejectOldCluster: true,
PermitWithoutStream: true,
}
if account, ok := GetAccount(endpoints); ok {
cfg.Username = account.User
@@ -476,28 +356,6 @@ func DialClient(endpoints []string) (EtcdClient, error) {
return clientv3.New(cfg)
}
func calculateChanges(oldVals, newVals map[string]string) (add, remove []KV) {
for k, v := range newVals {
if val, ok := oldVals[k]; !ok || v != val {
add = append(add, KV{
Key: k,
Val: v,
})
}
}
for k, v := range oldVals {
if val, ok := newVals[k]; !ok || v != val {
remove = append(remove, KV{
Key: k,
Val: v,
})
}
}
return add, remove
}
func getClusterKey(endpoints []string) string {
sort.Strings(endpoints)
return strings.Join(endpoints, endpointsSeparator)
@@ -506,10 +364,3 @@ func getClusterKey(endpoints []string) string {
func makeKeyPrefix(key string) string {
return fmt.Sprintf("%s%c", key, Delimiter)
}
// NewClient returns a watchValue that make sure values are not nil.
func newWatchValue() *watchValue {
return &watchValue{
values: make(map[string]string),
}
}

View File

@@ -2,10 +2,8 @@ package internal
import (
"context"
"os"
"sync"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
@@ -13,11 +11,9 @@ import (
"github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stringx"
"github.com/zeromicro/go-zero/core/threading"
"go.etcd.io/etcd/api/v3/etcdserverpb"
"go.etcd.io/etcd/api/v3/mvccpb"
clientv3 "go.etcd.io/etcd/client/v3"
"go.etcd.io/etcd/client/v3/mock/mockserver"
)
var mockLock sync.Mutex
@@ -39,9 +35,9 @@ func setMockClient(cli EtcdClient) func() {
func TestGetCluster(t *testing.T) {
AddAccount([]string{"first"}, "foo", "bar")
c1, _ := GetRegistry().getOrCreateCluster([]string{"first"})
c2, _ := GetRegistry().getOrCreateCluster([]string{"second"})
c3, _ := GetRegistry().getOrCreateCluster([]string{"first"})
c1, _ := GetRegistry().getCluster([]string{"first"})
c2, _ := GetRegistry().getCluster([]string{"second"})
c3, _ := GetRegistry().getCluster([]string{"first"})
assert.Equal(t, c1, c3)
assert.NotEqual(t, c1, c2)
}
@@ -51,36 +47,6 @@ func TestGetClusterKey(t *testing.T) {
getClusterKey([]string{"remotehost:5678", "localhost:1234"}))
}
func TestUnmonitor(t *testing.T) {
t.Run("no listener", func(t *testing.T) {
reg := &Registry{
clusters: map[string]*cluster{},
}
assert.NotPanics(t, func() {
reg.Unmonitor([]string{"any"}, "any", false, nil)
})
})
t.Run("no value", func(t *testing.T) {
reg := &Registry{
clusters: map[string]*cluster{
"any": {
watchers: map[watchKey]*watchValue{
{
key: "any",
}: {
values: map[string]string{},
},
},
},
},
}
assert.NotPanics(t, func() {
reg.Unmonitor([]string{"any"}, "another", false, nil)
})
})
}
func TestCluster_HandleChanges(t *testing.T) {
ctrl := gomock.NewController(t)
l := NewMockUpdateListener(ctrl)
@@ -109,14 +75,8 @@ func TestCluster_HandleChanges(t *testing.T) {
Val: "4",
})
c := newCluster([]string{"any"})
key := watchKey{
key: "any",
exactMatch: false,
}
c.watchers[key] = &watchValue{
listeners: []UpdateListener{l},
}
c.handleChanges(key, []KV{
c.listeners["any"] = []UpdateListener{l}
c.handleChanges("any", []KV{
{
Key: "first",
Val: "1",
@@ -129,8 +89,8 @@ func TestCluster_HandleChanges(t *testing.T) {
assert.EqualValues(t, map[string]string{
"first": "1",
"second": "2",
}, c.watchers[key].values)
c.handleChanges(key, []KV{
}, c.values["any"])
c.handleChanges("any", []KV{
{
Key: "third",
Val: "3",
@@ -143,7 +103,7 @@ func TestCluster_HandleChanges(t *testing.T) {
assert.EqualValues(t, map[string]string{
"third": "3",
"fourth": "4",
}, c.watchers[key].values)
}, c.values["any"])
}
func TestCluster_Load(t *testing.T) {
@@ -163,11 +123,9 @@ func TestCluster_Load(t *testing.T) {
}, nil)
cli.EXPECT().Ctx().Return(context.Background())
c := &cluster{
watchers: make(map[watchKey]*watchValue),
values: make(map[string]map[string]string),
}
c.load(cli, watchKey{
key: "any",
})
c.load(cli, "any")
}
func TestCluster_Watch(t *testing.T) {
@@ -199,25 +157,20 @@ func TestCluster_Watch(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
c := &cluster{
watchers: make(map[watchKey]*watchValue),
}
key := watchKey{
key: "any",
listeners: make(map[string][]UpdateListener),
values: make(map[string]map[string]string),
}
listener := NewMockUpdateListener(ctrl)
c.watchers[key] = &watchValue{
listeners: []UpdateListener{listener},
values: make(map[string]string),
}
c.listeners["any"] = []UpdateListener{listener}
listener.EXPECT().OnAdd(gomock.Any()).Do(func(kv KV) {
assert.Equal(t, "hello", kv.Key)
assert.Equal(t, "world", kv.Val)
wg.Done()
}).MaxTimes(1)
listener.EXPECT().OnDelete(gomock.Any()).Do(func(_ any) {
listener.EXPECT().OnDelete(gomock.Any()).Do(func(_ interface{}) {
wg.Done()
}).MaxTimes(1)
go c.watch(cli, key, 0)
go c.watch(cli, "any", 0)
ch <- clientv3.WatchResponse{
Events: []*clientv3.Event{
{
@@ -255,111 +208,17 @@ func TestClusterWatch_RespFailures(t *testing.T) {
ch := make(chan clientv3.WatchResponse)
cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes()
cli.EXPECT().Ctx().Return(context.Background()).AnyTimes()
c := &cluster{
watchers: make(map[watchKey]*watchValue),
}
c := new(cluster)
c.done = make(chan lang.PlaceholderType)
go func() {
ch <- resp
close(c.done)
}()
key := watchKey{
key: "any",
}
c.watch(cli, key, 0)
c.watch(cli, "any", 0)
})
}
}
func TestCluster_getCurrent(t *testing.T) {
t.Run("no value", func(t *testing.T) {
c := &cluster{
watchers: map[watchKey]*watchValue{
{
key: "any",
}: {
values: map[string]string{},
},
},
}
assert.Nil(t, c.getCurrent(watchKey{
key: "another",
}))
})
}
func TestCluster_handleWatchEvents(t *testing.T) {
t.Run("no value", func(t *testing.T) {
c := &cluster{
watchers: map[watchKey]*watchValue{
{
key: "any",
}: {
values: map[string]string{},
},
},
}
assert.NotPanics(t, func() {
c.handleWatchEvents(context.Background(), watchKey{
key: "another",
}, nil)
})
})
}
func TestCluster_addListener(t *testing.T) {
t.Run("has listener", func(t *testing.T) {
c := &cluster{
watchers: map[watchKey]*watchValue{
{
key: "any",
}: {
listeners: make([]UpdateListener, 0),
},
},
}
assert.NotPanics(t, func() {
c.addListener(watchKey{
key: "any",
}, nil)
})
})
t.Run("no listener", func(t *testing.T) {
c := &cluster{
watchers: map[watchKey]*watchValue{
{
key: "any",
}: {
listeners: make([]UpdateListener, 0),
},
},
}
assert.NotPanics(t, func() {
c.addListener(watchKey{
key: "another",
}, nil)
})
})
}
func TestCluster_reload(t *testing.T) {
c := &cluster{
watchers: map[watchKey]*watchValue{},
watchGroup: threading.NewRoutineGroup(),
done: make(chan lang.PlaceholderType),
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
cli := NewMockEtcdClient(ctrl)
restore := setMockClient(cli)
defer restore()
assert.NotPanics(t, func() {
c.reload(cli)
})
}
func TestClusterWatch_CloseChan(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
@@ -369,17 +228,13 @@ func TestClusterWatch_CloseChan(t *testing.T) {
ch := make(chan clientv3.WatchResponse)
cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes()
cli.EXPECT().Ctx().Return(context.Background()).AnyTimes()
c := &cluster{
watchers: make(map[watchKey]*watchValue),
}
c := new(cluster)
c.done = make(chan lang.PlaceholderType)
go func() {
close(ch)
close(c.done)
}()
c.watch(cli, watchKey{
key: "any",
}, 0)
c.watch(cli, "any", 0)
}
func TestValueOnlyContext(t *testing.T) {
@@ -387,101 +242,3 @@ func TestValueOnlyContext(t *testing.T) {
ctx.Done()
assert.Nil(t, ctx.Err())
}
func TestDialClient(t *testing.T) {
svr, err := mockserver.StartMockServers(1)
assert.NoError(t, err)
svr.StartAt(0)
certFile := createTempFile(t, []byte(certContent))
defer os.Remove(certFile)
keyFile := createTempFile(t, []byte(keyContent))
defer os.Remove(keyFile)
caFile := createTempFile(t, []byte(caContent))
defer os.Remove(caFile)
endpoints := []string{svr.Servers[0].Address}
AddAccount(endpoints, "foo", "bar")
assert.NoError(t, AddTLS(endpoints, certFile, keyFile, caFile, false))
old := DialTimeout
DialTimeout = time.Millisecond
defer func() {
DialTimeout = old
}()
_, err = DialClient(endpoints)
assert.Error(t, err)
}
func TestRegistry_Monitor(t *testing.T) {
svr, err := mockserver.StartMockServers(1)
assert.NoError(t, err)
svr.StartAt(0)
endpoints := []string{svr.Servers[0].Address}
GetRegistry().lock.Lock()
GetRegistry().clusters = map[string]*cluster{
getClusterKey(endpoints): {
watchers: map[watchKey]*watchValue{
watchKey{
key: "foo",
exactMatch: true,
}: {
values: map[string]string{
"bar": "baz",
},
},
},
},
}
GetRegistry().lock.Unlock()
assert.Error(t, GetRegistry().Monitor(endpoints, "foo", false, new(mockListener)))
}
func TestRegistry_Unmonitor(t *testing.T) {
svr, err := mockserver.StartMockServers(1)
assert.NoError(t, err)
svr.StartAt(0)
_, cancel := context.WithCancel(context.Background())
endpoints := []string{svr.Servers[0].Address}
GetRegistry().lock.Lock()
GetRegistry().clusters = map[string]*cluster{
getClusterKey(endpoints): {
watchers: map[watchKey]*watchValue{
watchKey{
key: "foo",
exactMatch: true,
}: {
values: map[string]string{
"bar": "baz",
},
cancel: cancel,
},
},
},
}
GetRegistry().lock.Unlock()
l := new(mockListener)
assert.NoError(t, GetRegistry().Monitor(endpoints, "foo", true, l))
watchVals := GetRegistry().clusters[getClusterKey(endpoints)].watchers[watchKey{
key: "foo",
exactMatch: true,
}]
assert.Equal(t, 1, len(watchVals.listeners))
GetRegistry().Unmonitor(endpoints, "foo", true, l)
watchVals = GetRegistry().clusters[getClusterKey(endpoints)].watchers[watchKey{
key: "foo",
exactMatch: true,
}]
assert.Nil(t, watchVals)
}
type mockListener struct {
}
func (m *mockListener) OnAdd(_ KV) {
}
func (m *mockListener) OnDelete(_ KV) {
}

View File

@@ -58,7 +58,7 @@ func (m *MocketcdConn) WaitForStateChange(ctx context.Context, sourceState conne
}
// WaitForStateChange indicates an expected call of WaitForStateChange
func (mr *MocketcdConnMockRecorder) WaitForStateChange(ctx, sourceState any) *gomock.Call {
func (mr *MocketcdConnMockRecorder) WaitForStateChange(ctx, sourceState interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WaitForStateChange", reflect.TypeOf((*MocketcdConn)(nil).WaitForStateChange), ctx, sourceState)
}

View File

@@ -10,7 +10,6 @@ type (
}
// UpdateListener wraps the OnAdd and OnDelete methods.
// The implementation should be thread-safe and idempotent.
UpdateListener interface {
OnAdd(kv KV)
OnDelete(kv KV)

View File

@@ -40,7 +40,7 @@ func (m *MockUpdateListener) OnAdd(kv KV) {
}
// OnAdd indicates an expected call of OnAdd
func (mr *MockUpdateListenerMockRecorder) OnAdd(kv any) *gomock.Call {
func (mr *MockUpdateListenerMockRecorder) OnAdd(kv interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnAdd", reflect.TypeOf((*MockUpdateListener)(nil).OnAdd), kv)
}
@@ -52,7 +52,7 @@ func (m *MockUpdateListener) OnDelete(kv KV) {
}
// OnDelete indicates an expected call of OnDelete
func (mr *MockUpdateListenerMockRecorder) OnDelete(kv any) *gomock.Call {
func (mr *MockUpdateListenerMockRecorder) OnDelete(kv interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnDelete", reflect.TypeOf((*MockUpdateListener)(nil).OnDelete), kv)
}

View File

@@ -9,6 +9,7 @@ const (
autoSyncInterval = time.Minute
coolDownInterval = time.Second
dialTimeout = 5 * time.Second
dialKeepAliveTime = 5 * time.Second
requestTimeout = 3 * time.Second
endpointsSeparator = ","
)

View File

@@ -5,7 +5,6 @@ import (
"github.com/zeromicro/go-zero/core/discov/internal"
"github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/logc"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/proc"
"github.com/zeromicro/go-zero/core/syncx"
@@ -92,12 +91,12 @@ func (p *Publisher) doKeepAlive() error {
default:
cli, err := p.doRegister()
if err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher doRegister: %s", err.Error())
logx.Errorf("etcd publisher doRegister: %s", err.Error())
break
}
if err := p.keepAliveAsync(cli); err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher keepAliveAsync: %s", err.Error())
logx.Errorf("etcd publisher keepAliveAsync: %s", err.Error())
break
}
@@ -131,17 +130,17 @@ func (p *Publisher) keepAliveAsync(cli internal.EtcdClient) error {
if !ok {
p.revoke(cli)
if err := p.doKeepAlive(); err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher KeepAlive: %s", err.Error())
logx.Errorf("etcd publisher KeepAlive: %s", err.Error())
}
return
}
case <-p.pauseChan:
logc.Infof(cli.Ctx(), "paused etcd renew, key: %s, value: %s", p.key, p.value)
logx.Infof("paused etcd renew, key: %s, value: %s", p.key, p.value)
p.revoke(cli)
select {
case <-p.resumeChan:
if err := p.doKeepAlive(); err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher KeepAlive: %s", err.Error())
logx.Errorf("etcd publisher KeepAlive: %s", err.Error())
}
return
case <-p.quit.Done():
@@ -176,7 +175,7 @@ func (p *Publisher) register(client internal.EtcdClient) (clientv3.LeaseID, erro
func (p *Publisher) revoke(cli internal.EtcdClient) {
if _, err := cli.Revoke(cli.Ctx(), p.lease); err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher revoke: %s", err.Error())
logx.Errorf("etcd publisher revoke: %s", err.Error())
}
}

View File

@@ -1,10 +1,7 @@
package discov
import (
"context"
"errors"
"net"
"os"
"sync"
"testing"
"time"
@@ -16,83 +13,6 @@ import (
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stringx"
clientv3 "go.etcd.io/etcd/client/v3"
"golang.org/x/net/http2"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
)
const (
certContent = `-----BEGIN CERTIFICATE-----
MIIDazCCAlOgAwIBAgIUEg9GVO2oaPn+YSmiqmFIuAo10WIwDQYJKoZIhvcNAQEM
BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAgFw0yMzAzMTExMzIxMjNaGA8yMTIz
MDIxNTEzMjEyM1owRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx
ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcN
AQEBBQADggEPADCCAQoCggEBALplXlWsIf0O/IgnIplmiZHKGnxyfyufyE2FBRNk
OofRqbKuPH8GNqbkvZm7N29fwTDAQ+mViAggCkDht4hOzoWJMA7KYJt8JnTSWL48
M1lcrpc9DL2gszC/JF/FGvyANbBtLklkZPFBGdHUX14pjrT937wqPtm+SqUHSvRT
B7bmwmm2drRcmhpVm98LSlV7uQ2EgnJgsLjBPITKUejLmVLHfgX0RwQ2xIpX9pS4
FCe1BTacwl2gGp7Mje7y4Mfv3o0ArJW6Tuwbjx59ZXwb1KIP71b7bT04AVS8ZeYO
UMLKKuB5UR9x9Rn6cLXOTWBpcMVyzDgrAFLZjnE9LPUolZMCAwEAAaNRME8wHwYD
VR0jBBgwFoAUeW8w8pmhncbRgTsl48k4/7wnfx8wCQYDVR0TBAIwADALBgNVHQ8E
BAMCBPAwFAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBDAUAA4IBAQAI
y9xaoS88CLPBsX6mxfcTAFVfGNTRW9VN9Ng1cCnUR+YGoXGM/l+qP4f7p8ocdGwK
iYZErVTzXYIn+D27//wpY3klJk3gAnEUBT3QRkStBw7XnpbeZ2oPBK+cmDnCnZPS
BIF1wxPX7vIgaxs5Zsdqwk3qvZ4Djr2wP7LabNWTLSBKgQoUY45Liw6pffLwcGF9
UKlu54bvGze2SufISCR3ib+I+FLvqpvJhXToZWYb/pfI/HccuCL1oot1x8vx6DQy
U+TYxlZsKS5mdNxAX3dqEkEMsgEi+g/tzDPXJImfeCGGBhIOXLm8SRypiuGdEbc9
xkWYxRPegajuEZGvCqVs
-----END CERTIFICATE-----`
keyContent = `-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAumVeVawh/Q78iCcimWaJkcoafHJ/K5/ITYUFE2Q6h9Gpsq48
fwY2puS9mbs3b1/BMMBD6ZWICCAKQOG3iE7OhYkwDspgm3wmdNJYvjwzWVyulz0M
vaCzML8kX8Ua/IA1sG0uSWRk8UEZ0dRfXimOtP3fvCo+2b5KpQdK9FMHtubCabZ2
tFyaGlWb3wtKVXu5DYSCcmCwuME8hMpR6MuZUsd+BfRHBDbEilf2lLgUJ7UFNpzC
XaAansyN7vLgx+/ejQCslbpO7BuPHn1lfBvUog/vVvttPTgBVLxl5g5Qwsoq4HlR
H3H1Gfpwtc5NYGlwxXLMOCsAUtmOcT0s9SiVkwIDAQABAoIBAD5meTJNMgO55Kjg
ESExxpRcCIno+tHr5+6rvYtEXqPheOIsmmwb9Gfi4+Z3WpOaht5/Pz0Ppj6yGzyl
U//6AgGKb+BDuBvVcDpjwPnOxZIBCSHwejdxeQu0scSuA97MPS0XIAvJ5FEv7ijk
5Bht6SyGYURpECltHygoTNuGgGqmO+McCJRLE9L09lTBI6UQ/JQwWJqSr7wx6iPU
M1Ze/srIV+7cyEPu6i0DGjS1gSQKkX68Lqn1w6oE290O+OZvleO0gZ02fLDWCZke
aeD9+EU/Pw+rqm3H6o0szOFIpzhRp41FUdW9sybB3Yp3u7c/574E+04Z/e30LMKs
TCtE1QECgYEA3K7KIpw0NH2HXL5C3RHcLmr204xeBfS70riBQQuVUgYdmxak2ima
80RInskY8hRhSGTg0l+VYIH8cmjcUyqMSOELS5XfRH99r4QPiK8AguXg80T4VumY
W3Pf+zEC2ssgP/gYthV0g0Xj5m2QxktOF9tRw5nkg739ZR4dI9lm/iECgYEA2Dnf
uwEDGqHiQRF6/fh5BG/nGVMvrefkqx6WvTJQ3k/M/9WhxB+lr/8yH46TuS8N2b29
FoTf3Mr9T7pr/PWkOPzoY3P56nYbKU8xSwCim9xMzhBMzj8/N9ukJvXy27/VOz56
eQaKqnvdXNGtPJrIMDGHps2KKWlKLyAlapzjVTMCgYAA/W++tACv85g13EykfT4F
n0k4LbsGP9DP4zABQLIMyiY72eAncmRVjwrcW36XJ2xATOONTgx3gF3HjZzfaqNy
eD/6uNNllUTVEryXGmHgNHPL45VRnn6memCY2eFvZdXhM5W4y2PYaunY0MkDercA
+GTngbs6tBF88KOk04bYwQKBgFl68cRgsdkmnwwQYNaTKfmVGYzYaQXNzkqmWPko
xmCJo6tHzC7ubdG8iRCYHzfmahPuuj6EdGPZuSRyYFgJi5Ftz/nAN+84OxtIQ3zn
YWOgskQgaLh9YfsKsQ7Sf1NDOsnOnD5TX7UXl07fEpLe9vNCvAFiU8e5Y9LGudU5
4bYTAoGBAMdX3a3bXp4cZvXNBJ/QLVyxC6fP1Q4haCR1Od3m+T00Jth2IX2dk/fl
p6xiJT1av5JtYabv1dFKaXOS5s1kLGGuCCSKpkvFZm826aQ2AFm0XGqEQDLeei5b
A52Kpy/YJ+RkG4BTFtAooFq6DmA0cnoP6oPvG2h6XtDJwDTPInJb
-----END RSA PRIVATE KEY-----`
caContent = `-----BEGIN CERTIFICATE-----
MIIDbTCCAlWgAwIBAgIUBJvFoCowKich7MMfseJ+DYzzirowDQYJKoZIhvcNAQEM
BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAgFw0yMzAzMTExMzIxMDNaGA8yMTIz
MDIxNTEzMjEwM1owRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx
ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcN
AQEBBQADggEPADCCAQoCggEBAO4to2YMYj0bxgr2FCiweSTSFuPx33zSw2x/s9Wf
OR41bm2DFsyYT5f3sOIKlXZEdLmOKty2e3ho3yC0EyNpVHdykkkHT3aDI17quZax
kYi/URqqtl1Z08A22txolc04hAZisg2BypGi3vql81UW1t3zyloGnJoIAeXR9uca
ljP6Bk3bwsxoVBLi1JtHrO0hHLQaeHmKhAyrys06X0LRdn7Px48yRZlt6FaLSa8X
YiRM0G44bVy/h6BkoQjMYGwVmCVk6zjJ9U7ZPFqdnDMNxAfR+hjDnYodqdLDMTTR
1NPVrnEnNwFx0AMLvgt/ba/45vZCEAmSZnFXFAJJcM7ai9ECAwEAAaNTMFEwHQYD
VR0OBBYEFHlvMPKZoZ3G0YE7JePJOP+8J38fMB8GA1UdIwQYMBaAFHlvMPKZoZ3G
0YE7JePJOP+8J38fMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQEMBQADggEB
AMX8dNulADOo9uQgBMyFb9TVra7iY0zZjzv4GY5XY7scd52n6CnfAPvYBBDnTr/O
BgNp5jaujb4+9u/2qhV3f9n+/3WOb2CmPehBgVSzlXqHeQ9lshmgwZPeem2T+8Tm
Nnc/xQnsUfCFszUDxpkr55+aLVM22j02RWqcZ4q7TAaVYL+kdFVMc8FoqG/0ro6A
BjE/Qn0Nn7ciX1VUjDt8l+k7ummPJTmzdi6i6E4AwO9dzrGNgGJ4aWL8cC6xYcIX
goVIRTFeONXSDno/oPjWHpIPt7L15heMpKBHNuzPkKx2YVqPHE5QZxWfS+Lzgx+Q
E2oTTM0rYKOZ8p6000mhvKI=
-----END CERTIFICATE-----`
)
func init() {
@@ -117,7 +37,7 @@ func TestPublisher_register(t *testing.T) {
assert.Nil(t, err)
}
func TestPublisher_registerWithOptions(t *testing.T) {
func TestPublisher_registerWithId(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
const id = 2
@@ -129,15 +49,7 @@ func TestPublisher_registerWithOptions(t *testing.T) {
ID: 1,
}, nil)
cli.EXPECT().Put(gomock.Any(), makeEtcdKey("thekey", id), "thevalue", gomock.Any())
certFile := createTempFile(t, []byte(certContent))
defer os.Remove(certFile)
keyFile := createTempFile(t, []byte(keyContent))
defer os.Remove(keyFile)
caFile := createTempFile(t, []byte(caContent))
defer os.Remove(caFile)
pub := NewPublisher(nil, "thekey", "thevalue", WithId(id),
WithPubEtcdTLS(certFile, keyFile, caFile, true))
pub := NewPublisher(nil, "thekey", "thevalue", WithId(id))
_, err := pub.register(cli)
assert.Nil(t, err)
}
@@ -213,7 +125,7 @@ func TestPublisher_keepAliveAsyncQuit(t *testing.T) {
cli.EXPECT().KeepAlive(gomock.Any(), id)
var wg sync.WaitGroup
wg.Add(1)
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) {
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ interface{}) {
wg.Done()
})
pub := NewPublisher(nil, "thekey", "thevalue")
@@ -235,7 +147,7 @@ func TestPublisher_keepAliveAsyncPause(t *testing.T) {
pub := NewPublisher(nil, "thekey", "thevalue")
var wg sync.WaitGroup
wg.Add(1)
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) {
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ interface{}) {
pub.Stop()
wg.Done()
})
@@ -257,92 +169,3 @@ func TestPublisher_Resume(t *testing.T) {
}()
<-publisher.resumeChan
}
func TestPublisher_keepAliveAsync(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
const id clientv3.LeaseID = 1
conn := createMockConn(t)
defer conn.Close()
cli := internal.NewMockEtcdClient(ctrl)
cli.EXPECT().ActiveConnection().Return(conn).AnyTimes()
cli.EXPECT().Close()
defer cli.Close()
cli.ActiveConnection()
restore := setMockClient(cli)
defer restore()
cli.EXPECT().Ctx().AnyTimes()
cli.EXPECT().KeepAlive(gomock.Any(), id)
cli.EXPECT().Grant(gomock.Any(), timeToLive).Return(&clientv3.LeaseGrantResponse{
ID: 1,
}, nil)
cli.EXPECT().Put(gomock.Any(), makeEtcdKey("thekey", int64(id)), "thevalue", gomock.Any())
var wg sync.WaitGroup
wg.Add(1)
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) {
wg.Done()
})
pub := NewPublisher([]string{"the-endpoint"}, "thekey", "thevalue")
pub.lease = id
assert.Nil(t, pub.KeepAlive())
pub.Stop()
wg.Wait()
}
func createMockConn(t *testing.T) *grpc.ClientConn {
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error while listening. Err: %v", err)
}
defer lis.Close()
lisAddr := resolver.Address{Addr: lis.Addr().String()}
lisDone := make(chan struct{})
dialDone := make(chan struct{})
// 1st listener accepts the connection and then does nothing
go func() {
defer close(lisDone)
conn, err := lis.Accept()
if err != nil {
t.Errorf("Error while accepting. Err: %v", err)
return
}
framer := http2.NewFramer(conn, conn)
if err := framer.WriteSettings(http2.Setting{}); err != nil {
t.Errorf("Error while writing settings. Err: %v", err)
return
}
<-dialDone // Close conn only after dial returns.
}()
r := manual.NewBuilderWithScheme("whatever")
r.InitialState(resolver.State{Addresses: []resolver.Address{lisAddr}})
client, err := grpc.DialContext(context.Background(), r.Scheme()+":///test.server",
grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(r))
close(dialDone)
if err != nil {
t.Fatalf("Dial failed. Err: %v", err)
}
timeout := time.After(1 * time.Second)
select {
case <-timeout:
t.Fatal("timed out waiting for server to finish")
case <-lisDone:
}
return client
}
func createTempFile(t *testing.T, body []byte) string {
tmpFile, err := os.CreateTemp(os.TempDir(), "go-unit-*.tmp")
if err != nil {
t.Fatal(err)
}
tmpFile.Close()
if err = os.WriteFile(tmpFile.Name(), body, os.ModePerm); err != nil {
t.Fatal(err)
}
return tmpFile.Name()
}

View File

@@ -15,11 +15,9 @@ type (
// A Subscriber is used to subscribe the given key on an etcd cluster.
Subscriber struct {
endpoints []string
exclusive bool
key string
exactMatch bool
items *container
endpoints []string
exclusive bool
items *container
}
)
@@ -30,14 +28,13 @@ type (
func NewSubscriber(endpoints []string, key string, opts ...SubOption) (*Subscriber, error) {
sub := &Subscriber{
endpoints: endpoints,
key: key,
}
for _, opt := range opts {
opt(sub)
}
sub.items = newContainer(sub.exclusive)
if err := internal.GetRegistry().Monitor(endpoints, key, sub.exactMatch, sub.items); err != nil {
if err := internal.GetRegistry().Monitor(endpoints, key, sub.items); err != nil {
return nil, err
}
@@ -49,11 +46,6 @@ func (s *Subscriber) AddListener(listener func()) {
s.items.addListener(listener)
}
// Close closes the subscriber.
func (s *Subscriber) Close() {
internal.GetRegistry().Unmonitor(s.endpoints, s.key, s.exactMatch, s.items)
}
// Values returns all the subscription values.
func (s *Subscriber) Values() []string {
return s.items.getValues()
@@ -67,13 +59,6 @@ func Exclusive() SubOption {
}
}
// WithExactMatch turn off querying using key prefixes.
func WithExactMatch() SubOption {
return func(sub *Subscriber) {
sub.exactMatch = true
}
}
// WithSubEtcdAccount provides the etcd username/password.
func WithSubEtcdAccount(user, pass string) SubOption {
return func(sub *Subscriber) {

View File

@@ -225,28 +225,3 @@ func TestWithSubEtcdAccount(t *testing.T) {
assert.Equal(t, user, account.User)
assert.Equal(t, "bar", account.Pass)
}
func TestWithExactMatch(t *testing.T) {
sub := new(Subscriber)
WithExactMatch()(sub)
sub.items = newContainer(sub.exclusive)
var count int32
sub.AddListener(func() {
atomic.AddInt32(&count, 1)
})
sub.items.notifyChange()
assert.Empty(t, sub.Values())
assert.Equal(t, int32(1), atomic.LoadInt32(&count))
}
func TestSubscriberClose(t *testing.T) {
l := newContainer(false)
sub := &Subscriber{
endpoints: []string{"localhost:12379"},
key: "foo",
items: l,
}
assert.NotPanics(t, func() {
sub.Close()
})
}

View File

@@ -1,21 +1,18 @@
package errorx
import (
"errors"
"sync"
import "bytes"
type (
// A BatchError is an error that can hold multiple errors.
BatchError struct {
errs errorArray
}
errorArray []error
)
// BatchError is an error that can hold multiple errors.
type BatchError struct {
errs []error
lock sync.RWMutex
}
// Add adds one or more non-nil errors to the BatchError instance.
// Add adds errs to be, nil errors are ignored.
func (be *BatchError) Add(errs ...error) {
be.lock.Lock()
defer be.lock.Unlock()
for _, err := range errs {
if err != nil {
be.errs = append(be.errs, err)
@@ -23,20 +20,33 @@ func (be *BatchError) Add(errs ...error) {
}
}
// Err returns an error that represents all accumulated errors.
// It returns nil if there are no errors.
// Err returns an error that represents all errors.
func (be *BatchError) Err() error {
be.lock.RLock()
defer be.lock.RUnlock()
// If there are no non-nil errors, errors.Join(...) returns nil.
return errors.Join(be.errs...)
switch len(be.errs) {
case 0:
return nil
case 1:
return be.errs[0]
default:
return be.errs
}
}
// NotNil checks if there is at least one error inside the BatchError.
// NotNil checks if any error inside.
func (be *BatchError) NotNil() bool {
be.lock.RLock()
defer be.lock.RUnlock()
return len(be.errs) > 0
}
// Error returns a string that represents inside errors.
func (ea errorArray) Error() string {
var buf bytes.Buffer
for i := range ea {
if i > 0 {
buf.WriteByte('\n')
}
buf.WriteString(ea[i].Error())
}
return buf.String()
}

View File

@@ -3,7 +3,6 @@ package errorx
import (
"errors"
"fmt"
"sync"
"testing"
"github.com/stretchr/testify/assert"
@@ -34,7 +33,7 @@ func TestBatchErrorNilFromFunc(t *testing.T) {
func TestBatchErrorOneError(t *testing.T) {
var batch BatchError
batch.Add(errors.New(err1))
assert.NotNil(t, batch.Err())
assert.NotNil(t, batch)
assert.Equal(t, err1, batch.Err().Error())
assert.True(t, batch.NotNil())
}
@@ -43,105 +42,7 @@ func TestBatchErrorWithErrors(t *testing.T) {
var batch BatchError
batch.Add(errors.New(err1))
batch.Add(errors.New(err2))
assert.NotNil(t, batch.Err())
assert.NotNil(t, batch)
assert.Equal(t, fmt.Sprintf("%s\n%s", err1, err2), batch.Err().Error())
assert.True(t, batch.NotNil())
}
func TestBatchErrorConcurrentAdd(t *testing.T) {
const count = 10000
var batch BatchError
var wg sync.WaitGroup
wg.Add(count)
for i := 0; i < count; i++ {
go func() {
defer wg.Done()
batch.Add(errors.New(err1))
}()
}
wg.Wait()
assert.NotNil(t, batch.Err())
assert.Equal(t, count, len(batch.errs))
assert.True(t, batch.NotNil())
}
func TestBatchError_Unwrap(t *testing.T) {
t.Run("nil", func(t *testing.T) {
var be BatchError
assert.Nil(t, be.Err())
assert.True(t, errors.Is(be.Err(), nil))
})
t.Run("one error", func(t *testing.T) {
var errFoo = errors.New("foo")
var errBar = errors.New("bar")
var be BatchError
be.Add(errFoo)
assert.True(t, errors.Is(be.Err(), errFoo))
assert.False(t, errors.Is(be.Err(), errBar))
})
t.Run("two errors", func(t *testing.T) {
var errFoo = errors.New("foo")
var errBar = errors.New("bar")
var errBaz = errors.New("baz")
var be BatchError
be.Add(errFoo)
be.Add(errBar)
assert.True(t, errors.Is(be.Err(), errFoo))
assert.True(t, errors.Is(be.Err(), errBar))
assert.False(t, errors.Is(be.Err(), errBaz))
})
}
func TestBatchError_Add(t *testing.T) {
var be BatchError
// Test adding nil errors
be.Add(nil, nil)
assert.False(t, be.NotNil(), "Expected BatchError to be empty after adding nil errors")
// Test adding non-nil errors
err1 := errors.New("error 1")
err2 := errors.New("error 2")
be.Add(err1, err2)
assert.True(t, be.NotNil(), "Expected BatchError to be non-empty after adding errors")
// Test adding a mix of nil and non-nil errors
err3 := errors.New("error 3")
be.Add(nil, err3, nil)
assert.True(t, be.NotNil(), "Expected BatchError to be non-empty after adding a mix of nil and non-nil errors")
}
func TestBatchError_Err(t *testing.T) {
var be BatchError
// Test Err() on empty BatchError
assert.Nil(t, be.Err(), "Expected nil error for empty BatchError")
// Test Err() with multiple errors
err1 := errors.New("error 1")
err2 := errors.New("error 2")
be.Add(err1, err2)
combinedErr := be.Err()
assert.NotNil(t, combinedErr, "Expected nil error for BatchError with multiple errors")
// Check if the combined error contains both error messages
errString := combinedErr.Error()
assert.Truef(t, errors.Is(combinedErr, err1), "Combined error doesn't contain first error: %s", errString)
assert.Truef(t, errors.Is(combinedErr, err2), "Combined error doesn't contain second error: %s", errString)
}
func TestBatchError_NotNil(t *testing.T) {
var be BatchError
// Test NotNil() on empty BatchError
assert.Nil(t, be.Err(), "Expected nil error for empty BatchError")
// Test NotNil() after adding an error
be.Add(errors.New("test error"))
assert.NotNil(t, be.Err(), "Expected non-nil error after adding an error")
}

View File

@@ -1,14 +0,0 @@
package errorx
import "errors"
// In checks if the given err is one of errs.
func In(err error, errs ...error) bool {
for _, each := range errs {
if errors.Is(err, each) {
return true
}
}
return false
}

View File

@@ -1,70 +0,0 @@
package errorx
import (
"errors"
"testing"
)
func TestIn(t *testing.T) {
err1 := errors.New("error 1")
err2 := errors.New("error 2")
err3 := errors.New("error 3")
tests := []struct {
name string
err error
errs []error
want bool
}{
{
name: "Error matches one of the errors in the list",
err: err1,
errs: []error{err1, err2},
want: true,
},
{
name: "Error does not match any errors in the list",
err: err3,
errs: []error{err1, err2},
want: false,
},
{
name: "Empty error list",
err: err1,
errs: []error{},
want: false,
},
{
name: "Nil error with non-nil list",
err: nil,
errs: []error{err1, err2},
want: false,
},
{
name: "Non-nil error with nil in list",
err: err1,
errs: []error{nil, err2},
want: false,
},
{
name: "Error matches nil error in the list",
err: nil,
errs: []error{nil, err2},
want: true,
},
{
name: "Nil error with empty list",
err: nil,
errs: []error{},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := In(tt.err, tt.errs...); got != tt.want {
t.Errorf("In() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -12,7 +12,7 @@ func Wrap(err error, message string) error {
}
// Wrapf returns an error that wraps err with given format and args.
func Wrapf(err error, format string, args ...any) error {
func Wrapf(err error, format string, args ...interface{}) error {
if err == nil {
return nil
}

View File

@@ -42,7 +42,7 @@ func NewBulkExecutor(execute Execute, opts ...BulkOption) *BulkExecutor {
}
// Add adds task into be.
func (be *BulkExecutor) Add(task any) error {
func (be *BulkExecutor) Add(task interface{}) error {
be.executor.Add(task)
return nil
}
@@ -79,22 +79,22 @@ func newBulkOptions() bulkOptions {
}
type bulkContainer struct {
tasks []any
tasks []interface{}
execute Execute
maxTasks int
}
func (bc *bulkContainer) AddTask(task any) bool {
func (bc *bulkContainer) AddTask(task interface{}) bool {
bc.tasks = append(bc.tasks, task)
return len(bc.tasks) >= bc.maxTasks
}
func (bc *bulkContainer) Execute(tasks any) {
vals := tasks.([]any)
func (bc *bulkContainer) Execute(tasks interface{}) {
vals := tasks.([]interface{})
bc.execute(vals)
}
func (bc *bulkContainer) RemoveAll() any {
func (bc *bulkContainer) RemoveAll() interface{} {
tasks := bc.tasks
bc.tasks = nil
return tasks

View File

@@ -12,7 +12,7 @@ func TestBulkExecutor(t *testing.T) {
var values []int
var lock sync.Mutex
executor := NewBulkExecutor(func(items []any) {
executor := NewBulkExecutor(func(items []interface{}) {
lock.Lock()
values = append(values, len(items))
lock.Unlock()
@@ -40,7 +40,7 @@ func TestBulkExecutorFlushInterval(t *testing.T) {
var wait sync.WaitGroup
wait.Add(1)
executor := NewBulkExecutor(func(items []any) {
executor := NewBulkExecutor(func(items []interface{}) {
assert.Equal(t, size, len(items))
wait.Done()
}, WithBulkTasks(caches), WithBulkInterval(time.Millisecond*100))
@@ -53,7 +53,7 @@ func TestBulkExecutorFlushInterval(t *testing.T) {
}
func TestBulkExecutorEmpty(t *testing.T) {
NewBulkExecutor(func(items []any) {
NewBulkExecutor(func(items []interface{}) {
assert.Fail(t, "should not called")
}, WithBulkTasks(10), WithBulkInterval(time.Millisecond))
time.Sleep(time.Millisecond * 100)
@@ -67,7 +67,7 @@ func TestBulkExecutorFlush(t *testing.T) {
var wait sync.WaitGroup
wait.Add(1)
be := NewBulkExecutor(func(items []any) {
be := NewBulkExecutor(func(items []interface{}) {
assert.Equal(t, tasks, len(items))
wait.Done()
}, WithBulkTasks(caches), WithBulkInterval(time.Minute))
@@ -78,11 +78,11 @@ func TestBulkExecutorFlush(t *testing.T) {
wait.Wait()
}
func TestBulkExecutorFlushSlowTasks(t *testing.T) {
func TestBuldExecutorFlushSlowTasks(t *testing.T) {
const total = 1500
lock := new(sync.Mutex)
result := make([]any, 0, 10000)
exec := NewBulkExecutor(func(tasks []any) {
result := make([]interface{}, 0, 10000)
exec := NewBulkExecutor(func(tasks []interface{}) {
time.Sleep(time.Millisecond * 100)
lock.Lock()
defer lock.Unlock()
@@ -100,7 +100,7 @@ func TestBulkExecutorFlushSlowTasks(t *testing.T) {
func BenchmarkBulkExecutor(b *testing.B) {
b.ReportAllocs()
be := NewBulkExecutor(func(tasks []any) {
be := NewBulkExecutor(func(tasks []interface{}) {
time.Sleep(time.Millisecond * time.Duration(len(tasks)))
})
for i := 0; i < b.N; i++ {

View File

@@ -42,7 +42,7 @@ func NewChunkExecutor(execute Execute, opts ...ChunkOption) *ChunkExecutor {
}
// Add adds task with given chunk size into ce.
func (ce *ChunkExecutor) Add(task any, size int) error {
func (ce *ChunkExecutor) Add(task interface{}, size int) error {
ce.executor.Add(chunk{
val: task,
size: size,
@@ -82,25 +82,25 @@ func newChunkOptions() chunkOptions {
}
type chunkContainer struct {
tasks []any
tasks []interface{}
execute Execute
size int
maxChunkSize int
}
func (bc *chunkContainer) AddTask(task any) bool {
func (bc *chunkContainer) AddTask(task interface{}) bool {
ck := task.(chunk)
bc.tasks = append(bc.tasks, ck.val)
bc.size += ck.size
return bc.size >= bc.maxChunkSize
}
func (bc *chunkContainer) Execute(tasks any) {
vals := tasks.([]any)
func (bc *chunkContainer) Execute(tasks interface{}) {
vals := tasks.([]interface{})
bc.execute(vals)
}
func (bc *chunkContainer) RemoveAll() any {
func (bc *chunkContainer) RemoveAll() interface{} {
tasks := bc.tasks
bc.tasks = nil
bc.size = 0
@@ -108,6 +108,6 @@ func (bc *chunkContainer) RemoveAll() any {
}
type chunk struct {
val any
val interface{}
size int
}

View File

@@ -12,7 +12,7 @@ func TestChunkExecutor(t *testing.T) {
var values []int
var lock sync.Mutex
executor := NewChunkExecutor(func(items []any) {
executor := NewChunkExecutor(func(items []interface{}) {
lock.Lock()
values = append(values, len(items))
lock.Unlock()
@@ -40,7 +40,7 @@ func TestChunkExecutorFlushInterval(t *testing.T) {
var wait sync.WaitGroup
wait.Add(1)
executor := NewChunkExecutor(func(items []any) {
executor := NewChunkExecutor(func(items []interface{}) {
assert.Equal(t, size, len(items))
wait.Done()
}, WithChunkBytes(caches), WithFlushInterval(time.Millisecond*100))
@@ -53,7 +53,7 @@ func TestChunkExecutorFlushInterval(t *testing.T) {
}
func TestChunkExecutorEmpty(t *testing.T) {
executor := NewChunkExecutor(func(items []any) {
executor := NewChunkExecutor(func(items []interface{}) {
assert.Fail(t, "should not called")
}, WithChunkBytes(10), WithFlushInterval(time.Millisecond))
time.Sleep(time.Millisecond * 100)
@@ -68,7 +68,7 @@ func TestChunkExecutorFlush(t *testing.T) {
var wait sync.WaitGroup
wait.Add(1)
be := NewChunkExecutor(func(items []any) {
be := NewChunkExecutor(func(items []interface{}) {
assert.Equal(t, tasks, len(items))
wait.Done()
}, WithChunkBytes(caches), WithFlushInterval(time.Minute))
@@ -82,7 +82,7 @@ func TestChunkExecutorFlush(t *testing.T) {
func BenchmarkChunkExecutor(b *testing.B) {
b.ReportAllocs()
be := NewChunkExecutor(func(tasks []any) {
be := NewChunkExecutor(func(tasks []interface{}) {
time.Sleep(time.Millisecond * time.Duration(len(tasks)))
})
for i := 0; i < b.N; i++ {

View File

@@ -21,16 +21,16 @@ type (
TaskContainer interface {
// AddTask adds the task into the container.
// Returns true if the container needs to be flushed after the addition.
AddTask(task any) bool
AddTask(task interface{}) bool
// Execute handles the collected tasks by the container when flushing.
Execute(tasks any)
Execute(tasks interface{})
// RemoveAll removes the contained tasks, and return them.
RemoveAll() any
RemoveAll() interface{}
}
// A PeriodicalExecutor is an executor that periodically execute tasks.
PeriodicalExecutor struct {
commander chan any
commander chan interface{}
interval time.Duration
container TaskContainer
waitGroup sync.WaitGroup
@@ -48,7 +48,7 @@ type (
func NewPeriodicalExecutor(interval time.Duration, container TaskContainer) *PeriodicalExecutor {
executor := &PeriodicalExecutor{
// buffer 1 to let the caller go quickly
commander: make(chan any, 1),
commander: make(chan interface{}, 1),
interval: interval,
container: container,
confirmChan: make(chan lang.PlaceholderType),
@@ -64,7 +64,7 @@ func NewPeriodicalExecutor(interval time.Duration, container TaskContainer) *Per
}
// Add adds tasks into pe.
func (pe *PeriodicalExecutor) Add(task any) {
func (pe *PeriodicalExecutor) Add(task interface{}) {
if vals, ok := pe.addAndCheck(task); ok {
pe.commander <- vals
<-pe.confirmChan
@@ -74,14 +74,14 @@ func (pe *PeriodicalExecutor) Add(task any) {
// Flush forces pe to execute tasks.
func (pe *PeriodicalExecutor) Flush() bool {
pe.enterExecution()
return pe.executeTasks(func() any {
return pe.executeTasks(func() interface{} {
pe.lock.Lock()
defer pe.lock.Unlock()
return pe.container.RemoveAll()
}())
}
// Sync lets caller run fn thread-safe with pe, especially for the underlying container.
// Sync lets caller to run fn thread-safe with pe, especially for the underlying container.
func (pe *PeriodicalExecutor) Sync(fn func()) {
pe.lock.Lock()
defer pe.lock.Unlock()
@@ -96,7 +96,7 @@ func (pe *PeriodicalExecutor) Wait() {
})
}
func (pe *PeriodicalExecutor) addAndCheck(task any) (any, bool) {
func (pe *PeriodicalExecutor) addAndCheck(task interface{}) (interface{}, bool) {
pe.lock.Lock()
defer func() {
if !pe.guarded {
@@ -116,7 +116,7 @@ func (pe *PeriodicalExecutor) addAndCheck(task any) (any, bool) {
}
func (pe *PeriodicalExecutor) backgroundFlush() {
go func() {
threading.GoSafe(func() {
// flush before quit goroutine to avoid missing tasks
defer pe.Flush()
@@ -144,7 +144,7 @@ func (pe *PeriodicalExecutor) backgroundFlush() {
}
}
}
}()
})
}
func (pe *PeriodicalExecutor) doneExecution() {
@@ -157,20 +157,18 @@ func (pe *PeriodicalExecutor) enterExecution() {
})
}
func (pe *PeriodicalExecutor) executeTasks(tasks any) bool {
func (pe *PeriodicalExecutor) executeTasks(tasks interface{}) bool {
defer pe.doneExecution()
ok := pe.hasTasks(tasks)
if ok {
threading.RunSafe(func() {
pe.container.Execute(tasks)
})
pe.container.Execute(tasks)
}
return ok
}
func (pe *PeriodicalExecutor) hasTasks(tasks any) bool {
func (pe *PeriodicalExecutor) hasTasks(tasks interface{}) bool {
if tasks == nil {
return false
}

View File

@@ -17,22 +17,22 @@ const threshold = 10
type container struct {
interval time.Duration
tasks []int
execute func(tasks any)
execute func(tasks interface{})
}
func newContainer(interval time.Duration, execute func(tasks any)) *container {
func newContainer(interval time.Duration, execute func(tasks interface{})) *container {
return &container{
interval: interval,
execute: execute,
}
}
func (c *container) AddTask(task any) bool {
func (c *container) AddTask(task interface{}) bool {
c.tasks = append(c.tasks, task.(int))
return len(c.tasks) > threshold
}
func (c *container) Execute(tasks any) {
func (c *container) Execute(tasks interface{}) {
if c.execute != nil {
c.execute(tasks)
} else {
@@ -40,7 +40,7 @@ func (c *container) Execute(tasks any) {
}
}
func (c *container) RemoveAll() any {
func (c *container) RemoveAll() interface{} {
tasks := c.tasks
c.tasks = nil
return tasks
@@ -76,7 +76,7 @@ func TestPeriodicalExecutor_Bulk(t *testing.T) {
var vals []int
// avoid data race
var lock sync.Mutex
exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, func(tasks any) {
exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, func(tasks interface{}) {
t := tasks.([]int)
for _, each := range t {
lock.Lock()
@@ -108,83 +108,25 @@ func TestPeriodicalExecutor_Bulk(t *testing.T) {
lock.Unlock()
}
func TestPeriodicalExecutor_Panic(t *testing.T) {
// avoid data race
var lock sync.Mutex
ticker := timex.NewFakeTicker()
var (
executedTasks []int
expected []int
)
executor := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, func(tasks any) {
tt := tasks.([]int)
lock.Lock()
executedTasks = append(executedTasks, tt...)
lock.Unlock()
if tt[0] == 0 {
panic("test")
}
}))
executor.newTicker = func(duration time.Duration) timex.Ticker {
return ticker
}
for i := 0; i < 30; i++ {
executor.Add(i)
expected = append(expected, i)
}
ticker.Tick()
ticker.Tick()
time.Sleep(time.Millisecond)
lock.Lock()
assert.Equal(t, expected, executedTasks)
lock.Unlock()
}
func TestPeriodicalExecutor_FlushPanic(t *testing.T) {
var (
executedTasks []int
expected []int
lock sync.Mutex
)
executor := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, func(tasks any) {
tt := tasks.([]int)
lock.Lock()
executedTasks = append(executedTasks, tt...)
lock.Unlock()
if tt[0] == 0 {
panic("flush panic")
}
}))
for i := 0; i < 8; i++ {
executor.Add(i)
expected = append(expected, i)
}
executor.Flush()
lock.Lock()
assert.Equal(t, expected, executedTasks)
lock.Unlock()
}
func TestPeriodicalExecutor_Wait(t *testing.T) {
var lock sync.Mutex
executor := NewBulkExecutor(func(tasks []any) {
executer := NewBulkExecutor(func(tasks []interface{}) {
lock.Lock()
defer lock.Unlock()
time.Sleep(10 * time.Millisecond)
}, WithBulkTasks(1), WithBulkInterval(time.Second))
for i := 0; i < 10; i++ {
executor.Add(1)
executer.Add(1)
}
executor.Flush()
executor.Wait()
executer.Flush()
executer.Wait()
}
func TestPeriodicalExecutor_WaitFast(t *testing.T) {
const total = 3
var cnt int
var lock sync.Mutex
executor := NewBulkExecutor(func(tasks []any) {
executer := NewBulkExecutor(func(tasks []interface{}) {
defer func() {
cnt++
}()
@@ -193,15 +135,15 @@ func TestPeriodicalExecutor_WaitFast(t *testing.T) {
time.Sleep(10 * time.Millisecond)
}, WithBulkTasks(1), WithBulkInterval(10*time.Millisecond))
for i := 0; i < total; i++ {
executor.Add(2)
executer.Add(2)
}
executor.Flush()
executor.Wait()
executer.Flush()
executer.Wait()
assert.Equal(t, total, cnt)
}
func TestPeriodicalExecutor_Deadlock(t *testing.T) {
executor := NewBulkExecutor(func(tasks []any) {
executor := NewBulkExecutor(func(tasks []interface{}) {
}, WithBulkTasks(1), WithBulkInterval(time.Millisecond))
for i := 0; i < 1e5; i++ {
executor.Add(1)
@@ -209,7 +151,13 @@ func TestPeriodicalExecutor_Deadlock(t *testing.T) {
}
func TestPeriodicalExecutor_hasTasks(t *testing.T) {
ticker := timex.NewFakeTicker()
defer ticker.Stop()
exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, nil))
exec.newTicker = func(d time.Duration) timex.Ticker {
return ticker
}
assert.False(t, exec.hasTasks(nil))
assert.True(t, exec.hasTasks(1))
}

View File

@@ -5,4 +5,4 @@ import "time"
const defaultFlushInterval = time.Second
// Execute defines the method to execute tasks.
type Execute func(tasks []any)
type Execute func(tasks []interface{})

View File

@@ -35,7 +35,6 @@ func firstLine(file *os.File) (string, error) {
for {
buf := make([]byte, bufSize)
n, err := file.ReadAt(buf, offset)
if err != nil && err != io.EOF {
return "", err
}
@@ -46,10 +45,6 @@ func firstLine(file *os.File) (string, error) {
}
}
if err == io.EOF {
return string(append(first, buf[:n]...)), nil
}
first = append(first, buf[:n]...)
offset += bufSize
}
@@ -62,42 +57,30 @@ func lastLine(filename string, file *os.File) (string, error) {
}
var last []byte
bufLen := int64(bufSize)
offset := info.Size()
for offset > 0 {
if offset < bufLen {
bufLen = offset
for {
offset -= bufSize
if offset < 0 {
offset = 0
} else {
offset -= bufLen
}
buf := make([]byte, bufLen)
buf := make([]byte, bufSize)
n, err := file.ReadAt(buf, offset)
if err != nil && err != io.EOF {
return "", err
}
if n == 0 {
break
}
if buf[n-1] == '\n' {
buf = buf[:n-1]
n--
} else {
buf = buf[:n]
}
for i := n - 1; i >= 0; i-- {
if buf[i] == '\n' {
return string(append(buf[i+1:], last...)), nil
for n--; n >= 0; n-- {
if buf[n] == '\n' {
return string(append(buf[n+1:], last...)), nil
}
}
last = append(buf, last...)
}
return string(last), nil
}

View File

@@ -52,7 +52,6 @@ last line`
second line
last line
`
emptyContent = ``
)
func TestFirstLine(t *testing.T) {
@@ -75,31 +74,6 @@ func TestFirstLineShort(t *testing.T) {
assert.Equal(t, "first line", val)
}
func TestFirstLineError(t *testing.T) {
_, err := FirstLine("/tmp/does-not-exist")
assert.Error(t, err)
}
func TestFirstLineEmptyFile(t *testing.T) {
filename, err := fs.TempFilenameWithText(emptyContent)
assert.Nil(t, err)
defer os.Remove(filename)
val, err := FirstLine(filename)
assert.Nil(t, err)
assert.Equal(t, "", val)
}
func TestFirstLineWithoutNewline(t *testing.T) {
filename, err := fs.TempFilenameWithText(longLine)
assert.Nil(t, err)
defer os.Remove(filename)
val, err := FirstLine(filename)
assert.Nil(t, err)
assert.Equal(t, longLine, val)
}
func TestLastLine(t *testing.T) {
filename, err := fs.TempFilenameWithText(text)
assert.Nil(t, err)
@@ -120,16 +94,6 @@ func TestLastLineWithLastNewline(t *testing.T) {
assert.Equal(t, longLine, val)
}
func TestLastLineWithoutLastNewline(t *testing.T) {
filename, err := fs.TempFilenameWithText(longLine)
assert.Nil(t, err)
defer os.Remove(filename)
val, err := LastLine(filename)
assert.Nil(t, err)
assert.Equal(t, longLine, val)
}
func TestLastLineShort(t *testing.T) {
filename, err := fs.TempFilenameWithText(shortText)
assert.Nil(t, err)
@@ -149,72 +113,3 @@ func TestLastLineWithLastNewlineShort(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, "last line", val)
}
func TestLastLineError(t *testing.T) {
_, err := LastLine("/tmp/does-not-exist")
assert.Error(t, err)
}
func TestLastLineEmptyFile(t *testing.T) {
filename, err := fs.TempFilenameWithText(emptyContent)
assert.Nil(t, err)
defer os.Remove(filename)
val, err := LastLine(filename)
assert.Nil(t, err)
assert.Equal(t, "", val)
}
func TestFirstLineExactlyBufSize(t *testing.T) {
content := make([]byte, bufSize)
for i := range content {
content[i] = 'a'
}
content[bufSize-1] = '\n' // Ensure there is a newline at the edge
filename, err := fs.TempFilenameWithText(string(content))
assert.Nil(t, err)
defer os.Remove(filename)
val, err := FirstLine(filename)
assert.Nil(t, err)
assert.Equal(t, string(content[:bufSize-1]), val)
}
func TestLastLineExactlyBufSize(t *testing.T) {
content := make([]byte, bufSize)
for i := range content {
content[i] = 'a'
}
content[bufSize-1] = '\n' // Ensure there is a newline at the edge
filename, err := fs.TempFilenameWithText(string(content))
assert.Nil(t, err)
defer os.Remove(filename)
val, err := LastLine(filename)
assert.Nil(t, err)
assert.Equal(t, string(content[:bufSize-1]), val)
}
func TestFirstLineLargeFile(t *testing.T) {
content := text + text + text + "\n" + "extra"
filename, err := fs.TempFilenameWithText(content)
assert.Nil(t, err)
defer os.Remove(filename)
val, err := FirstLine(filename)
assert.Nil(t, err)
assert.Equal(t, "first line", val)
}
func TestLastLineLargeFile(t *testing.T) {
content := text + text + text + "\n" + "extra"
filename, err := fs.TempFilenameWithText(content)
assert.Nil(t, err)
defer os.Remove(filename)
val, err := LastLine(filename)
assert.Nil(t, err)
assert.Equal(t, "extra", val)
}

View File

@@ -5,7 +5,7 @@ import "gopkg.in/cheggaaa/pb.v1"
type (
// A Scanner is used to read lines.
Scanner interface {
// Scan checks if it has remaining to read.
// Scan checks if has remaining to read.
Scan() bool
// Text returns next line.
Text() string

View File

@@ -1,4 +1,5 @@
//go:build windows
// +build windows
package fs

View File

@@ -1,4 +1,5 @@
//go:build linux || darwin || freebsd
//go:build linux || darwin
// +build linux darwin
package fs

View File

@@ -11,29 +11,29 @@ import (
// The file is kept as open, the caller should close the file handle,
// and remove the file by name.
func TempFileWithText(text string) (*os.File, error) {
tmpFile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text)))
tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text)))
if err != nil {
return nil, err
}
if err := os.WriteFile(tmpFile.Name(), []byte(text), os.ModeTemporary); err != nil {
if err := os.WriteFile(tmpfile.Name(), []byte(text), os.ModeTemporary); err != nil {
return nil, err
}
return tmpFile, nil
return tmpfile, nil
}
// TempFilenameWithText creates the file with the given content,
// and returns the filename (full path).
// The caller should remove the file after use.
func TempFilenameWithText(text string) (string, error) {
tmpFile, err := TempFileWithText(text)
tmpfile, err := TempFileWithText(text)
if err != nil {
return "", err
}
filename := tmpFile.Name()
if err = tmpFile.Close(); err != nil {
filename := tmpfile.Name()
if err = tmpfile.Close(); err != nil {
return "", err
}

View File

@@ -1,9 +1,6 @@
package fx
import (
"github.com/zeromicro/go-zero/core/errorx"
"github.com/zeromicro/go-zero/core/threading"
)
import "github.com/zeromicro/go-zero/core/threading"
// Parallel runs fns parallelly and waits for done.
func Parallel(fns ...func()) {
@@ -13,20 +10,3 @@ func Parallel(fns ...func()) {
}
group.Wait()
}
func ParallelErr(fns ...func() error) error {
var be errorx.BatchError
group := threading.NewRoutineGroup()
for _, fn := range fns {
f := fn
group.RunSafe(func() {
if err := f(); err != nil {
be.Add(err)
}
})
}
group.Wait()
return be.Err()
}

View File

@@ -1,7 +1,6 @@
package fx
import (
"errors"
"sync/atomic"
"testing"
"time"
@@ -23,54 +22,3 @@ func TestParallel(t *testing.T) {
})
assert.Equal(t, int32(6), count)
}
func TestParallelErr(t *testing.T) {
var count int32
err := ParallelErr(
func() error {
time.Sleep(time.Millisecond * 100)
atomic.AddInt32(&count, 1)
return errors.New("failed to exec #1")
},
func() error {
time.Sleep(time.Millisecond * 100)
atomic.AddInt32(&count, 2)
return errors.New("failed to exec #2")
},
func() error {
time.Sleep(time.Millisecond * 100)
atomic.AddInt32(&count, 3)
return nil
},
)
assert.Equal(t, int32(6), count)
assert.Error(t, err)
assert.ErrorContains(t, err, "failed to exec #1", "failed to exec #2")
}
func TestParallelErrErrorNil(t *testing.T) {
var count int32
err := ParallelErr(
func() error {
time.Sleep(time.Millisecond * 100)
atomic.AddInt32(&count, 1)
return nil
},
func() error {
time.Sleep(time.Millisecond * 100)
atomic.AddInt32(&count, 2)
return nil
},
func() error {
time.Sleep(time.Millisecond * 100)
atomic.AddInt32(&count, 3)
return nil
},
)
assert.Equal(t, int32(6), count)
assert.NoError(t, err)
}

View File

@@ -1,12 +1,6 @@
package fx
import (
"context"
"errors"
"time"
"github.com/zeromicro/go-zero/core/errorx"
)
import "github.com/zeromicro/go-zero/core/errorx"
const defaultRetryTimes = 3
@@ -15,110 +9,36 @@ type (
RetryOption func(*retryOptions)
retryOptions struct {
times int
interval time.Duration
timeout time.Duration
ignoreErrors []error
times int
}
)
// DoWithRetry runs fn, and retries if failed. Default to retry 3 times.
// Note that if the fn function accesses global variables outside the function
// and performs modification operations, it is best to lock them,
// otherwise there may be data race issues
func DoWithRetry(fn func() error, opts ...RetryOption) error {
return retry(context.Background(), func(errChan chan error, retryCount int) {
errChan <- fn()
}, opts...)
}
// DoWithRetryCtx runs fn, and retries if failed. Default to retry 3 times.
// fn retryCount indicates the current number of retries, starting from 0
// Note that if the fn function accesses global variables outside the function
// and performs modification operations, it is best to lock them,
// otherwise there may be data race issues
func DoWithRetryCtx(ctx context.Context, fn func(ctx context.Context, retryCount int) error,
opts ...RetryOption) error {
return retry(ctx, func(errChan chan error, retryCount int) {
errChan <- fn(ctx, retryCount)
}, opts...)
}
func retry(ctx context.Context, fn func(errChan chan error, retryCount int), opts ...RetryOption) error {
options := newRetryOptions()
for _, opt := range opts {
opt(options)
}
var berr errorx.BatchError
var cancelFunc context.CancelFunc
if options.timeout > 0 {
ctx, cancelFunc = context.WithTimeout(ctx, options.timeout)
defer cancelFunc()
}
errChan := make(chan error, 1)
for i := 0; i < options.times; i++ {
go fn(errChan, i)
select {
case err := <-errChan:
if err != nil {
for _, ignoreErr := range options.ignoreErrors {
if errors.Is(err, ignoreErr) {
return nil
}
}
berr.Add(err)
} else {
return nil
}
case <-ctx.Done():
berr.Add(ctx.Err())
return berr.Err()
}
if options.interval > 0 {
select {
case <-ctx.Done():
berr.Add(ctx.Err())
return berr.Err()
case <-time.After(options.interval):
}
if err := fn(); err != nil {
berr.Add(err)
} else {
return nil
}
}
return berr.Err()
}
// WithIgnoreErrors Ignore the specified errors
func WithIgnoreErrors(ignoreErrors []error) RetryOption {
return func(options *retryOptions) {
options.ignoreErrors = ignoreErrors
}
}
// WithInterval customizes a DoWithRetry call with given interval.
func WithInterval(interval time.Duration) RetryOption {
return func(options *retryOptions) {
options.interval = interval
}
}
// WithRetry customizes a DoWithRetry call with given retry times.
// WithRetry customize a DoWithRetry call with given retry times.
func WithRetry(times int) RetryOption {
return func(options *retryOptions) {
options.times = times
}
}
// WithTimeout customizes a DoWithRetry call with given timeout.
func WithTimeout(timeout time.Duration) RetryOption {
return func(options *retryOptions) {
options.timeout = timeout
}
}
func newRetryOptions() *retryOptions {
return &retryOptions{
times: defaultRetryTimes,

View File

@@ -1,10 +1,8 @@
package fx
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
@@ -14,153 +12,31 @@ func TestRetry(t *testing.T) {
return errors.New("any")
}))
times1 := 0
var times int
assert.Nil(t, DoWithRetry(func() error {
times1++
if times1 == defaultRetryTimes {
times++
if times == defaultRetryTimes {
return nil
}
return errors.New("any")
}))
times2 := 0
times = 0
assert.NotNil(t, DoWithRetry(func() error {
times2++
if times2 == defaultRetryTimes+1 {
times++
if times == defaultRetryTimes+1 {
return nil
}
return errors.New("any")
}))
total := 2 * defaultRetryTimes
times3 := 0
times = 0
assert.Nil(t, DoWithRetry(func() error {
times3++
if times3 == total {
times++
if times == total {
return nil
}
return errors.New("any")
}, WithRetry(total)))
}
func TestRetryWithTimeout(t *testing.T) {
assert.Nil(t, DoWithRetry(func() error {
return nil
}, WithTimeout(time.Millisecond*500)))
times1 := 0
assert.Nil(t, DoWithRetry(func() error {
times1++
if times1 == 1 {
return errors.New("any ")
}
time.Sleep(time.Millisecond * 150)
return nil
}, WithTimeout(time.Millisecond*250)))
total := defaultRetryTimes
times2 := 0
assert.Nil(t, DoWithRetry(func() error {
times2++
if times2 == total {
return nil
}
time.Sleep(time.Millisecond * 50)
return errors.New("any")
}, WithTimeout(time.Millisecond*50*(time.Duration(total)+2))))
assert.NotNil(t, DoWithRetry(func() error {
return errors.New("any")
}, WithTimeout(time.Millisecond*250)))
}
func TestRetryWithInterval(t *testing.T) {
times1 := 0
assert.NotNil(t, DoWithRetry(func() error {
times1++
if times1 == 1 {
return errors.New("any")
}
time.Sleep(time.Millisecond * 150)
return nil
}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
times2 := 0
assert.NotNil(t, DoWithRetry(func() error {
times2++
if times2 == 2 {
return nil
}
time.Sleep(time.Millisecond * 150)
return errors.New("any ")
}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
}
func TestRetryWithWithIgnoreErrors(t *testing.T) {
ignoreErr1 := errors.New("ignore error1")
ignoreErr2 := errors.New("ignore error2")
ignoreErrs := []error{ignoreErr1, ignoreErr2}
assert.Nil(t, DoWithRetry(func() error {
return ignoreErr1
}, WithIgnoreErrors(ignoreErrs)))
assert.Nil(t, DoWithRetry(func() error {
return ignoreErr2
}, WithIgnoreErrors(ignoreErrs)))
assert.NotNil(t, DoWithRetry(func() error {
return errors.New("any")
}))
}
func TestRetryCtx(t *testing.T) {
t.Run("with timeout", func(t *testing.T) {
assert.NotNil(t, DoWithRetryCtx(context.Background(), func(ctx context.Context, retryCount int) error {
if retryCount == 0 {
return errors.New("any")
}
time.Sleep(time.Millisecond * 150)
return nil
}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
assert.NotNil(t, DoWithRetryCtx(context.Background(), func(ctx context.Context, retryCount int) error {
if retryCount == 1 {
return nil
}
time.Sleep(time.Millisecond * 150)
return errors.New("any ")
}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
})
t.Run("with deadline exceeded", func(t *testing.T) {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Millisecond*250))
defer cancel()
var times int
assert.Error(t, DoWithRetryCtx(ctx, func(ctx context.Context, retryCount int) error {
times++
time.Sleep(time.Millisecond * 150)
return errors.New("any")
}, WithInterval(time.Millisecond*150)))
assert.Equal(t, 1, times)
})
t.Run("with deadline not exceeded", func(t *testing.T) {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Millisecond*250))
defer cancel()
var times int
assert.NoError(t, DoWithRetryCtx(ctx, func(ctx context.Context, retryCount int) error {
times++
if times == defaultRetryTimes {
return nil
}
time.Sleep(time.Millisecond * 50)
return errors.New("any")
}))
assert.Equal(t, defaultRetryTimes, times)
})
}

View File

@@ -21,31 +21,31 @@ type (
}
// FilterFunc defines the method to filter a Stream.
FilterFunc func(item any) bool
FilterFunc func(item interface{}) bool
// ForAllFunc defines the method to handle all elements in a Stream.
ForAllFunc func(pipe <-chan any)
ForAllFunc func(pipe <-chan interface{})
// ForEachFunc defines the method to handle each element in a Stream.
ForEachFunc func(item any)
ForEachFunc func(item interface{})
// GenerateFunc defines the method to send elements into a Stream.
GenerateFunc func(source chan<- any)
GenerateFunc func(source chan<- interface{})
// KeyFunc defines the method to generate keys for the elements in a Stream.
KeyFunc func(item any) any
KeyFunc func(item interface{}) interface{}
// LessFunc defines the method to compare the elements in a Stream.
LessFunc func(a, b any) bool
LessFunc func(a, b interface{}) bool
// MapFunc defines the method to map each element to another object in a Stream.
MapFunc func(item any) any
MapFunc func(item interface{}) interface{}
// Option defines the method to customize a Stream.
Option func(opts *rxOptions)
// ParallelFunc defines the method to handle elements parallelly.
ParallelFunc func(item any)
ParallelFunc func(item interface{})
// ReduceFunc defines the method to reduce all the elements in a Stream.
ReduceFunc func(pipe <-chan any) (any, error)
ReduceFunc func(pipe <-chan interface{}) (interface{}, error)
// WalkFunc defines the method to walk through all the elements in a Stream.
WalkFunc func(item any, pipe chan<- any)
WalkFunc func(item interface{}, pipe chan<- interface{})
// A Stream is a stream that can be used to do stream processing.
Stream struct {
source <-chan any
source <-chan interface{}
}
)
@@ -56,7 +56,7 @@ func Concat(s Stream, others ...Stream) Stream {
// From constructs a Stream from the given GenerateFunc.
func From(generate GenerateFunc) Stream {
source := make(chan any)
source := make(chan interface{})
threading.GoSafe(func() {
defer close(source)
@@ -67,8 +67,8 @@ func From(generate GenerateFunc) Stream {
}
// Just converts the given arbitrary items to a Stream.
func Just(items ...any) Stream {
source := make(chan any, len(items))
func Just(items ...interface{}) Stream {
source := make(chan interface{}, len(items))
for _, item := range items {
source <- item
}
@@ -78,16 +78,16 @@ func Just(items ...any) Stream {
}
// Range converts the given channel to a Stream.
func Range(source <-chan any) Stream {
func Range(source <-chan interface{}) Stream {
return Stream{
source: source,
}
}
// AllMatch returns whether all elements of this stream match the provided predicate.
// AllMach returns whether all elements of this stream match the provided predicate.
// May not evaluate the predicate on all elements if not necessary for determining the result.
// If the stream is empty then true is returned and the predicate is not evaluated.
func (s Stream) AllMatch(predicate func(item any) bool) bool {
func (s Stream) AllMach(predicate func(item interface{}) bool) bool {
for item := range s.source {
if !predicate(item) {
// make sure the former goroutine not block, and current func returns fast.
@@ -99,10 +99,10 @@ func (s Stream) AllMatch(predicate func(item any) bool) bool {
return true
}
// AnyMatch returns whether any elements of this stream match the provided predicate.
// AnyMach returns whether any elements of this stream match the provided predicate.
// May not evaluate the predicate on all elements if not necessary for determining the result.
// If the stream is empty then false is returned and the predicate is not evaluated.
func (s Stream) AnyMatch(predicate func(item any) bool) bool {
func (s Stream) AnyMach(predicate func(item interface{}) bool) bool {
for item := range s.source {
if predicate(item) {
// make sure the former goroutine not block, and current func returns fast.
@@ -121,7 +121,7 @@ func (s Stream) Buffer(n int) Stream {
n = 0
}
source := make(chan any, n)
source := make(chan interface{}, n)
go func() {
for item := range s.source {
source <- item
@@ -134,7 +134,7 @@ func (s Stream) Buffer(n int) Stream {
// Concat returns a Stream that concatenated other streams
func (s Stream) Concat(others ...Stream) Stream {
source := make(chan any)
source := make(chan interface{})
go func() {
group := threading.NewRoutineGroup()
@@ -170,12 +170,12 @@ func (s Stream) Count() (count int) {
// Distinct removes the duplicated items base on the given KeyFunc.
func (s Stream) Distinct(fn KeyFunc) Stream {
source := make(chan any)
source := make(chan interface{})
threading.GoSafe(func() {
defer close(source)
keys := make(map[any]lang.PlaceholderType)
keys := make(map[interface{}]lang.PlaceholderType)
for item := range s.source {
key := fn(item)
if _, ok := keys[key]; !ok {
@@ -195,7 +195,7 @@ func (s Stream) Done() {
// Filter filters the items by the given FilterFunc.
func (s Stream) Filter(fn FilterFunc, opts ...Option) Stream {
return s.Walk(func(item any, pipe chan<- any) {
return s.Walk(func(item interface{}, pipe chan<- interface{}) {
if fn(item) {
pipe <- item
}
@@ -203,7 +203,7 @@ func (s Stream) Filter(fn FilterFunc, opts ...Option) Stream {
}
// First returns the first item, nil if no items.
func (s Stream) First() any {
func (s Stream) First() interface{} {
for item := range s.source {
// make sure the former goroutine not block, and current func returns fast.
go drain(s.source)
@@ -229,13 +229,13 @@ func (s Stream) ForEach(fn ForEachFunc) {
// Group groups the elements into different groups based on their keys.
func (s Stream) Group(fn KeyFunc) Stream {
groups := make(map[any][]any)
groups := make(map[interface{}][]interface{})
for item := range s.source {
key := fn(item)
groups[key] = append(groups[key], item)
}
source := make(chan any)
source := make(chan interface{})
go func() {
for _, group := range groups {
source <- group
@@ -252,7 +252,7 @@ func (s Stream) Head(n int64) Stream {
panic("n must be greater than 0")
}
source := make(chan any)
source := make(chan interface{})
go func() {
for item := range s.source {
@@ -279,7 +279,7 @@ func (s Stream) Head(n int64) Stream {
}
// Last returns the last item, or nil if no items.
func (s Stream) Last() (item any) {
func (s Stream) Last() (item interface{}) {
for item = range s.source {
}
return
@@ -287,53 +287,29 @@ func (s Stream) Last() (item any) {
// Map converts each item to another corresponding item, which means it's a 1:1 model.
func (s Stream) Map(fn MapFunc, opts ...Option) Stream {
return s.Walk(func(item any, pipe chan<- any) {
return s.Walk(func(item interface{}, pipe chan<- interface{}) {
pipe <- fn(item)
}, opts...)
}
// Max returns the maximum item from the underlying source.
func (s Stream) Max(less LessFunc) any {
var max any
for item := range s.source {
if max == nil || less(max, item) {
max = item
}
}
return max
}
// Merge merges all the items into a slice and generates a new stream.
func (s Stream) Merge() Stream {
var items []any
var items []interface{}
for item := range s.source {
items = append(items, item)
}
source := make(chan any, 1)
source := make(chan interface{}, 1)
source <- items
close(source)
return Range(source)
}
// Min returns the minimum item from the underlying source.
func (s Stream) Min(less LessFunc) any {
var min any
for item := range s.source {
if min == nil || less(item, min) {
min = item
}
}
return min
}
// NoneMatch returns whether all elements of this stream don't match the provided predicate.
// May not evaluate the predicate on all elements if not necessary for determining the result.
// If the stream is empty then true is returned and the predicate is not evaluated.
func (s Stream) NoneMatch(predicate func(item any) bool) bool {
func (s Stream) NoneMatch(predicate func(item interface{}) bool) bool {
for item := range s.source {
if predicate(item) {
// make sure the former goroutine not block, and current func returns fast.
@@ -347,19 +323,19 @@ func (s Stream) NoneMatch(predicate func(item any) bool) bool {
// Parallel applies the given ParallelFunc to each item concurrently with given number of workers.
func (s Stream) Parallel(fn ParallelFunc, opts ...Option) {
s.Walk(func(item any, pipe chan<- any) {
s.Walk(func(item interface{}, pipe chan<- interface{}) {
fn(item)
}, opts...).Done()
}
// Reduce is a utility method to let the caller deal with the underlying channel.
func (s Stream) Reduce(fn ReduceFunc) (any, error) {
// Reduce is an utility method to let the caller deal with the underlying channel.
func (s Stream) Reduce(fn ReduceFunc) (interface{}, error) {
return fn(s.source)
}
// Reverse reverses the elements in the stream.
func (s Stream) Reverse() Stream {
var items []any
var items []interface{}
for item := range s.source {
items = append(items, item)
}
@@ -381,7 +357,7 @@ func (s Stream) Skip(n int64) Stream {
return s
}
source := make(chan any)
source := make(chan interface{})
go func() {
for item := range s.source {
@@ -400,7 +376,7 @@ func (s Stream) Skip(n int64) Stream {
// Sort sorts the items from the underlying source.
func (s Stream) Sort(less LessFunc) Stream {
var items []any
var items []interface{}
for item := range s.source {
items = append(items, item)
}
@@ -418,9 +394,9 @@ func (s Stream) Split(n int) Stream {
panic("n should be greater than 0")
}
source := make(chan any)
source := make(chan interface{})
go func() {
var chunk []any
var chunk []interface{}
for item := range s.source {
chunk = append(chunk, item)
if len(chunk) == n {
@@ -443,7 +419,7 @@ func (s Stream) Tail(n int64) Stream {
panic("n should be greater than 0")
}
source := make(chan any)
source := make(chan interface{})
go func() {
ring := collection.NewRing(int(n))
@@ -470,7 +446,7 @@ func (s Stream) Walk(fn WalkFunc, opts ...Option) Stream {
}
func (s Stream) walkLimited(fn WalkFunc, option *rxOptions) Stream {
pipe := make(chan any, option.workers)
pipe := make(chan interface{}, option.workers)
go func() {
var wg sync.WaitGroup
@@ -501,7 +477,7 @@ func (s Stream) walkLimited(fn WalkFunc, option *rxOptions) Stream {
}
func (s Stream) walkUnlimited(fn WalkFunc, option *rxOptions) Stream {
pipe := make(chan any, option.workers)
pipe := make(chan interface{}, option.workers)
go func() {
var wg sync.WaitGroup
@@ -553,7 +529,7 @@ func buildOptions(opts ...Option) *rxOptions {
}
// drain drains the given channel.
func drain(channel <-chan any) {
func drain(channel <-chan interface{}) {
for range channel {
}
}

View File

@@ -23,7 +23,7 @@ func TestBuffer(t *testing.T) {
var count int32
var wait sync.WaitGroup
wait.Add(1)
From(func(source chan<- any) {
From(func(source chan<- interface{}) {
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
@@ -36,7 +36,7 @@ func TestBuffer(t *testing.T) {
return
}
}
}).Buffer(N).ForAll(func(pipe <-chan any) {
}).Buffer(N).ForAll(func(pipe <-chan interface{}) {
wait.Wait()
// why N+1, because take one more to wait for sending into the channel
assert.Equal(t, int32(N+1), atomic.LoadInt32(&count))
@@ -47,7 +47,7 @@ func TestBuffer(t *testing.T) {
func TestBufferNegative(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var result int
Just(1, 2, 3, 4).Buffer(-1).Reduce(func(pipe <-chan any) (any, error) {
Just(1, 2, 3, 4).Buffer(-1).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
for item := range pipe {
result += item.(int)
}
@@ -61,22 +61,22 @@ func TestCount(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
tests := []struct {
name string
elements []any
elements []interface{}
}{
{
name: "no elements with nil",
},
{
name: "no elements",
elements: []any{},
elements: []interface{}{},
},
{
name: "1 element",
elements: []any{1},
elements: []interface{}{1},
},
{
name: "multiple elements",
elements: []any{1, 2, 3},
elements: []interface{}{1, 2, 3},
},
}
@@ -92,7 +92,7 @@ func TestCount(t *testing.T) {
func TestDone(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var count int32
Just(1, 2, 3).Walk(func(item any, pipe chan<- any) {
Just(1, 2, 3).Walk(func(item interface{}, pipe chan<- interface{}) {
time.Sleep(time.Millisecond * 100)
atomic.AddInt32(&count, int32(item.(int)))
}).Done()
@@ -103,7 +103,7 @@ func TestDone(t *testing.T) {
func TestJust(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var result int
Just(1, 2, 3, 4).Reduce(func(pipe <-chan any) (any, error) {
Just(1, 2, 3, 4).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
for item := range pipe {
result += item.(int)
}
@@ -116,9 +116,9 @@ func TestJust(t *testing.T) {
func TestDistinct(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var result int
Just(4, 1, 3, 2, 3, 4).Distinct(func(item any) any {
Just(4, 1, 3, 2, 3, 4).Distinct(func(item interface{}) interface{} {
return item
}).Reduce(func(pipe <-chan any) (any, error) {
}).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
for item := range pipe {
result += item.(int)
}
@@ -131,9 +131,9 @@ func TestDistinct(t *testing.T) {
func TestFilter(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var result int
Just(1, 2, 3, 4).Filter(func(item any) bool {
Just(1, 2, 3, 4).Filter(func(item interface{}) bool {
return item.(int)%2 == 0
}).Reduce(func(pipe <-chan any) (any, error) {
}).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
for item := range pipe {
result += item.(int)
}
@@ -154,9 +154,9 @@ func TestFirst(t *testing.T) {
func TestForAll(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var result int
Just(1, 2, 3, 4).Filter(func(item any) bool {
Just(1, 2, 3, 4).Filter(func(item interface{}) bool {
return item.(int)%2 == 0
}).ForAll(func(pipe <-chan any) {
}).ForAll(func(pipe <-chan interface{}) {
for item := range pipe {
result += item.(int)
}
@@ -168,11 +168,11 @@ func TestForAll(t *testing.T) {
func TestGroup(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var groups [][]int
Just(10, 11, 20, 21).Group(func(item any) any {
Just(10, 11, 20, 21).Group(func(item interface{}) interface{} {
v := item.(int)
return v / 10
}).ForEach(func(item any) {
v := item.([]any)
}).ForEach(func(item interface{}) {
v := item.([]interface{})
var group []int
for _, each := range v {
group = append(group, each.(int))
@@ -191,7 +191,7 @@ func TestGroup(t *testing.T) {
func TestHead(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var result int
Just(1, 2, 3, 4).Head(2).Reduce(func(pipe <-chan any) (any, error) {
Just(1, 2, 3, 4).Head(2).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
for item := range pipe {
result += item.(int)
}
@@ -204,7 +204,7 @@ func TestHead(t *testing.T) {
func TestHeadZero(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
assert.Panics(t, func() {
Just(1, 2, 3, 4).Head(0).Reduce(func(pipe <-chan any) (any, error) {
Just(1, 2, 3, 4).Head(0).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
return nil, nil
})
})
@@ -214,7 +214,7 @@ func TestHeadZero(t *testing.T) {
func TestHeadMore(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var result int
Just(1, 2, 3, 4).Head(6).Reduce(func(pipe <-chan any) (any, error) {
Just(1, 2, 3, 4).Head(6).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
for item := range pipe {
result += item.(int)
}
@@ -245,14 +245,14 @@ func TestMap(t *testing.T) {
expect int
}{
{
mapper: func(item any) any {
mapper: func(item interface{}) interface{} {
v := item.(int)
return v * v
},
expect: 30,
},
{
mapper: func(item any) any {
mapper: func(item interface{}) interface{} {
v := item.(int)
if v%2 == 0 {
return 0
@@ -262,7 +262,7 @@ func TestMap(t *testing.T) {
expect: 10,
},
{
mapper: func(item any) any {
mapper: func(item interface{}) interface{} {
v := item.(int)
if v%2 == 0 {
panic(v)
@@ -283,12 +283,12 @@ func TestMap(t *testing.T) {
} else {
workers = runtime.NumCPU()
}
From(func(source chan<- any) {
From(func(source chan<- interface{}) {
for i := 1; i < 5; i++ {
source <- i
}
}).Map(test.mapper, WithWorkers(workers)).Reduce(
func(pipe <-chan any) (any, error) {
func(pipe <-chan interface{}) (interface{}, error) {
for item := range pipe {
result += item.(int)
}
@@ -303,8 +303,8 @@ func TestMap(t *testing.T) {
func TestMerge(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
Just(1, 2, 3, 4).Merge().ForEach(func(item any) {
assert.ElementsMatch(t, []any{1, 2, 3, 4}, item.([]any))
Just(1, 2, 3, 4).Merge().ForEach(func(item interface{}) {
assert.ElementsMatch(t, []interface{}{1, 2, 3, 4}, item.([]interface{}))
})
})
}
@@ -312,7 +312,7 @@ func TestMerge(t *testing.T) {
func TestParallelJust(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var count int32
Just(1, 2, 3).Parallel(func(item any) {
Just(1, 2, 3).Parallel(func(item interface{}) {
time.Sleep(time.Millisecond * 100)
atomic.AddInt32(&count, int32(item.(int)))
}, UnlimitedWorkers())
@@ -322,8 +322,8 @@ func TestParallelJust(t *testing.T) {
func TestReverse(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
Just(1, 2, 3, 4).Reverse().Merge().ForEach(func(item any) {
assert.ElementsMatch(t, []any{4, 3, 2, 1}, item.([]any))
Just(1, 2, 3, 4).Reverse().Merge().ForEach(func(item interface{}) {
assert.ElementsMatch(t, []interface{}{4, 3, 2, 1}, item.([]interface{}))
})
})
}
@@ -331,9 +331,9 @@ func TestReverse(t *testing.T) {
func TestSort(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var prev int
Just(5, 3, 7, 1, 9, 6, 4, 8, 2).Sort(func(a, b any) bool {
Just(5, 3, 7, 1, 9, 6, 4, 8, 2).Sort(func(a, b interface{}) bool {
return a.(int) < b.(int)
}).ForEach(func(item any) {
}).ForEach(func(item interface{}) {
next := item.(int)
assert.True(t, prev < next)
prev = next
@@ -346,12 +346,12 @@ func TestSplit(t *testing.T) {
assert.Panics(t, func() {
Just(1, 2, 3, 4, 5, 6, 7, 8, 9, 10).Split(0).Done()
})
var chunks [][]any
Just(1, 2, 3, 4, 5, 6, 7, 8, 9, 10).Split(4).ForEach(func(item any) {
chunk := item.([]any)
var chunks [][]interface{}
Just(1, 2, 3, 4, 5, 6, 7, 8, 9, 10).Split(4).ForEach(func(item interface{}) {
chunk := item.([]interface{})
chunks = append(chunks, chunk)
})
assert.EqualValues(t, [][]any{
assert.EqualValues(t, [][]interface{}{
{1, 2, 3, 4},
{5, 6, 7, 8},
{9, 10},
@@ -362,7 +362,7 @@ func TestSplit(t *testing.T) {
func TestTail(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var result int
Just(1, 2, 3, 4).Tail(2).Reduce(func(pipe <-chan any) (any, error) {
Just(1, 2, 3, 4).Tail(2).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
for item := range pipe {
result += item.(int)
}
@@ -375,7 +375,7 @@ func TestTail(t *testing.T) {
func TestTailZero(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
assert.Panics(t, func() {
Just(1, 2, 3, 4).Tail(0).Reduce(func(pipe <-chan any) (any, error) {
Just(1, 2, 3, 4).Tail(0).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
return nil, nil
})
})
@@ -385,11 +385,11 @@ func TestTailZero(t *testing.T) {
func TestWalk(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var result int
Just(1, 2, 3, 4, 5).Walk(func(item any, pipe chan<- any) {
Just(1, 2, 3, 4, 5).Walk(func(item interface{}, pipe chan<- interface{}) {
if item.(int)%2 != 0 {
pipe <- item
}
}, UnlimitedWorkers()).ForEach(func(item any) {
}, UnlimitedWorkers()).ForEach(func(item interface{}) {
result += item.(int)
})
assert.Equal(t, 9, result)
@@ -398,16 +398,16 @@ func TestWalk(t *testing.T) {
func TestStream_AnyMach(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
assetEqual(t, false, Just(1, 2, 3).AnyMatch(func(item any) bool {
assetEqual(t, false, Just(1, 2, 3).AnyMach(func(item interface{}) bool {
return item.(int) == 4
}))
assetEqual(t, false, Just(1, 2, 3).AnyMatch(func(item any) bool {
assetEqual(t, false, Just(1, 2, 3).AnyMach(func(item interface{}) bool {
return item.(int) == 0
}))
assetEqual(t, true, Just(1, 2, 3).AnyMatch(func(item any) bool {
assetEqual(t, true, Just(1, 2, 3).AnyMach(func(item interface{}) bool {
return item.(int) == 2
}))
assetEqual(t, true, Just(1, 2, 3).AnyMatch(func(item any) bool {
assetEqual(t, true, Just(1, 2, 3).AnyMach(func(item interface{}) bool {
return item.(int) == 2
}))
})
@@ -416,17 +416,17 @@ func TestStream_AnyMach(t *testing.T) {
func TestStream_AllMach(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
assetEqual(
t, true, Just(1, 2, 3).AllMatch(func(item any) bool {
t, true, Just(1, 2, 3).AllMach(func(item interface{}) bool {
return true
}),
)
assetEqual(
t, false, Just(1, 2, 3).AllMatch(func(item any) bool {
t, false, Just(1, 2, 3).AllMach(func(item interface{}) bool {
return false
}),
)
assetEqual(
t, false, Just(1, 2, 3).AllMatch(func(item any) bool {
t, false, Just(1, 2, 3).AllMach(func(item interface{}) bool {
return item.(int) == 1
}),
)
@@ -436,17 +436,17 @@ func TestStream_AllMach(t *testing.T) {
func TestStream_NoneMatch(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
assetEqual(
t, true, Just(1, 2, 3).NoneMatch(func(item any) bool {
t, true, Just(1, 2, 3).NoneMatch(func(item interface{}) bool {
return false
}),
)
assetEqual(
t, false, Just(1, 2, 3).NoneMatch(func(item any) bool {
t, false, Just(1, 2, 3).NoneMatch(func(item interface{}) bool {
return true
}),
)
assetEqual(
t, true, Just(1, 2, 3).NoneMatch(func(item any) bool {
t, true, Just(1, 2, 3).NoneMatch(func(item interface{}) bool {
return item.(int) == 4
}),
)
@@ -455,19 +455,19 @@ func TestStream_NoneMatch(t *testing.T) {
func TestConcat(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
a1 := []any{1, 2, 3}
a2 := []any{4, 5, 6}
a1 := []interface{}{1, 2, 3}
a2 := []interface{}{4, 5, 6}
s1 := Just(a1...)
s2 := Just(a2...)
stream := Concat(s1, s2)
var items []any
var items []interface{}
for item := range stream.source {
items = append(items, item)
}
sort.Slice(items, func(i, j int) bool {
return items[i].(int) < items[j].(int)
})
ints := make([]any, 0)
ints := make([]interface{}, 0)
ints = append(ints, a1...)
ints = append(ints, a2...)
assetEqual(t, ints, items)
@@ -479,7 +479,7 @@ func TestStream_Skip(t *testing.T) {
assetEqual(t, 3, Just(1, 2, 3, 4).Skip(1).Count())
assetEqual(t, 1, Just(1, 2, 3, 4).Skip(3).Count())
assetEqual(t, 4, Just(1, 2, 3, 4).Skip(0).Count())
equal(t, Just(1, 2, 3, 4).Skip(3), []any{4})
equal(t, Just(1, 2, 3, 4).Skip(3), []interface{}{4})
assert.Panics(t, func() {
Just(1, 2, 3, 4).Skip(-1)
})
@@ -489,104 +489,27 @@ func TestStream_Skip(t *testing.T) {
func TestStream_Concat(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
stream := Just(1).Concat(Just(2), Just(3))
var items []any
var items []interface{}
for item := range stream.source {
items = append(items, item)
}
sort.Slice(items, func(i, j int) bool {
return items[i].(int) < items[j].(int)
})
assetEqual(t, []any{1, 2, 3}, items)
assetEqual(t, []interface{}{1, 2, 3}, items)
just := Just(1)
equal(t, just.Concat(just), []any{1})
})
}
func TestStream_Max(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
tests := []struct {
name string
elements []any
max any
}{
{
name: "no elements with nil",
},
{
name: "no elements",
elements: []any{},
max: nil,
},
{
name: "1 element",
elements: []any{1},
max: 1,
},
{
name: "multiple elements",
elements: []any{1, 2, 9, 5, 8},
max: 9,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
val := Just(test.elements...).Max(func(a, b any) bool {
return a.(int) < b.(int)
})
assetEqual(t, test.max, val)
})
}
})
}
func TestStream_Min(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
tests := []struct {
name string
elements []any
min any
}{
{
name: "no elements with nil",
min: nil,
},
{
name: "no elements",
elements: []any{},
min: nil,
},
{
name: "1 element",
elements: []any{1},
min: 1,
},
{
name: "multiple elements",
elements: []any{-1, 1, 2, 9, 5, 8},
min: -1,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
val := Just(test.elements...).Min(func(a, b any) bool {
return a.(int) < b.(int)
})
assetEqual(t, test.min, val)
})
}
equal(t, just.Concat(just), []interface{}{1})
})
}
func BenchmarkParallelMapReduce(b *testing.B) {
b.ReportAllocs()
mapper := func(v any) any {
mapper := func(v interface{}) interface{} {
return v.(int64) * v.(int64)
}
reducer := func(input <-chan any) (any, error) {
reducer := func(input <-chan interface{}) (interface{}, error) {
var result int64
for v := range input {
result += v.(int64)
@@ -594,7 +517,7 @@ func BenchmarkParallelMapReduce(b *testing.B) {
return result, nil
}
b.ResetTimer()
From(func(input chan<- any) {
From(func(input chan<- interface{}) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
input <- int64(rand.Int())
@@ -606,10 +529,10 @@ func BenchmarkParallelMapReduce(b *testing.B) {
func BenchmarkMapReduce(b *testing.B) {
b.ReportAllocs()
mapper := func(v any) any {
mapper := func(v interface{}) interface{} {
return v.(int64) * v.(int64)
}
reducer := func(input <-chan any) (any, error) {
reducer := func(input <-chan interface{}) (interface{}, error) {
var result int64
for v := range input {
result += v.(int64)
@@ -617,21 +540,21 @@ func BenchmarkMapReduce(b *testing.B) {
return result, nil
}
b.ResetTimer()
From(func(input chan<- any) {
From(func(input chan<- interface{}) {
for i := 0; i < b.N; i++ {
input <- int64(rand.Int())
}
}).Map(mapper).Reduce(reducer)
}
func assetEqual(t *testing.T, except, data any) {
func assetEqual(t *testing.T, except, data interface{}) {
if !reflect.DeepEqual(except, data) {
t.Errorf(" %v, want %v", data, except)
}
}
func equal(t *testing.T, stream Stream, data []any) {
items := make([]any, 0)
func equal(t *testing.T, stream Stream, data []interface{}) {
items := make([]interface{}, 0)
for item := range stream.source {
items = append(items, item)
}

View File

@@ -29,7 +29,7 @@ func DoWithTimeout(fn func() error, timeout time.Duration, opts ...DoOption) err
// create channel with buffer size 1 to avoid goroutine leak
done := make(chan error, 1)
panicChan := make(chan any, 1)
panicChan := make(chan interface{}, 1)
go func() {
defer func() {
if p := recover(); p != nil {

View File

@@ -26,7 +26,7 @@ type (
hashFunc Func
replicas int
keys []uint64
ring map[uint64][]any
ring map[uint64][]interface{}
nodes map[string]lang.PlaceholderType
lock sync.RWMutex
}
@@ -50,21 +50,21 @@ func NewCustomConsistentHash(replicas int, fn Func) *ConsistentHash {
return &ConsistentHash{
hashFunc: fn,
replicas: replicas,
ring: make(map[uint64][]any),
ring: make(map[uint64][]interface{}),
nodes: make(map[string]lang.PlaceholderType),
}
}
// Add adds the node with the number of h.replicas,
// the later call will overwrite the replicas of the former calls.
func (h *ConsistentHash) Add(node any) {
func (h *ConsistentHash) Add(node interface{}) {
h.AddWithReplicas(node, h.replicas)
}
// AddWithReplicas adds the node with the number of replicas,
// replicas will be truncated to h.replicas if it's larger than h.replicas,
// the later call will overwrite the replicas of the former calls.
func (h *ConsistentHash) AddWithReplicas(node any, replicas int) {
func (h *ConsistentHash) AddWithReplicas(node interface{}, replicas int) {
h.Remove(node)
if replicas > h.replicas {
@@ -89,7 +89,7 @@ func (h *ConsistentHash) AddWithReplicas(node any, replicas int) {
// AddWithWeight adds the node with weight, the weight can be 1 to 100, indicates the percent,
// the later call will overwrite the replicas of the former calls.
func (h *ConsistentHash) AddWithWeight(node any, weight int) {
func (h *ConsistentHash) AddWithWeight(node interface{}, weight int) {
// don't need to make sure weight not larger than TopWeight,
// because AddWithReplicas makes sure replicas cannot be larger than h.replicas
replicas := h.replicas * weight / TopWeight
@@ -97,7 +97,7 @@ func (h *ConsistentHash) AddWithWeight(node any, weight int) {
}
// Get returns the corresponding node from h base on the given v.
func (h *ConsistentHash) Get(v any) (any, bool) {
func (h *ConsistentHash) Get(v interface{}) (interface{}, bool) {
h.lock.RLock()
defer h.lock.RUnlock()
@@ -124,7 +124,7 @@ func (h *ConsistentHash) Get(v any) (any, bool) {
}
// Remove removes the given node from h.
func (h *ConsistentHash) Remove(node any) {
func (h *ConsistentHash) Remove(node interface{}) {
nodeRepr := repr(node)
h.lock.Lock()
@@ -177,10 +177,10 @@ func (h *ConsistentHash) removeNode(nodeRepr string) {
delete(h.nodes, nodeRepr)
}
func innerRepr(node any) string {
func innerRepr(node interface{}) string {
return fmt.Sprintf("%d:%v", prime, node)
}
func repr(node any) string {
func repr(node interface{}) string {
return lang.Repr(node)
}

View File

@@ -42,7 +42,7 @@ func TestConsistentHash(t *testing.T) {
keys[key.(string)]++
}
mi := make(map[any]int, len(keys))
mi := make(map[interface{}]int, len(keys))
for k, v := range keys {
mi[k] = v
}

View File

@@ -16,7 +16,7 @@ func NewBufferPool(capability int) *BufferPool {
return &BufferPool{
capability: capability,
pool: &sync.Pool{
New: func() any {
New: func() interface{} {
return new(bytes.Buffer)
},
},
@@ -32,10 +32,6 @@ func (bp *BufferPool) Get() *bytes.Buffer {
// Put returns buf into bp.
func (bp *BufferPool) Put(buf *bytes.Buffer) {
if buf == nil {
return
}
if buf.Cap() < bp.capability {
bp.pool.Put(buf)
}

Some files were not shown because too many files have changed in this diff Show More