Compare commits

..

72 Commits

Author SHA1 Message Date
kevin
2ea0a843f8 chore: remove any keywords 2023-03-04 20:54:26 +08:00
Kevin Wan
9e0e01b2bc chore: add tests (#2960) 2023-03-04 20:38:50 +08:00
yangjinheng
af50a80d01 timeout writer add hijack 2023-03-04 20:38:45 +08:00
yangjinheng
703fb8d970 Update timeouthandler.go 2023-03-04 20:38:40 +08:00
MarkJoyMa
e964e530e1 x 2023-03-04 20:32:21 +08:00
MarkJoyMa
52265087d1 x 2023-03-04 20:32:16 +08:00
MarkJoyMa
b4c2677eb9 add ut 2023-03-04 20:32:10 +08:00
MarkJoyMa
30296fb1ca feat: conf add FillDefault func 2023-03-04 20:31:44 +08:00
zhoumingji
356c80defd Fix bug in dartgen: The property 'isEmpty' can't be unconditionally accessed because the receiver can be 'null' 2023-03-04 20:31:38 +08:00
zhoumingji
8c31525378 Fix bug in dartgen: Increase the processing logic when route.RequestType is empty 2023-03-04 20:31:30 +08:00
cui fliter
2cf09f3c36 fix functiom name
Signed-off-by: cui fliter <imcusg@gmail.com>
2023-03-04 20:31:20 +08:00
Kevin Wan
d41e542c92 feat: support grpc client keepalive config (#2950) 2023-03-04 20:30:31 +08:00
tanglihao
265a24ac6d fix code format style use const config.DefaultFormat 2023-03-04 20:30:21 +08:00
tanglihao
7d88fc39dc fix log name conflict 2023-03-04 20:30:16 +08:00
anqiansong
6957b6a344 format code 2023-03-04 20:30:10 +08:00
anqiansong
bca6a230c8 remove unused code 2023-03-04 20:30:04 +08:00
anqiansong
cc8413d683 remove unused code 2023-03-04 20:29:56 +08:00
anqiansong
3842283fa8 Fix #2879 2023-03-04 20:29:41 +08:00
qiying.wang
fe13a533f5 chore: remove redundant prefix of "error: " in error creation 2023-03-04 20:26:40 +08:00
qiying.wang
7a327ccda4 chore: add tests for logc debug 2023-03-04 20:25:52 +08:00
qiying.wang
06e4507406 feat: add debug log for logc 2023-03-04 20:25:27 +08:00
kevin
8794d5b753 chore: add comments 2023-03-04 20:25:21 +08:00
kevin
9bfa63d995 chore: add more tests 2023-03-04 20:25:15 +08:00
kevin
a432b121fb chore: add more tests 2023-03-04 20:25:07 +08:00
kevin
b61c94bb66 feat: check key overwritten 2023-03-04 20:24:33 +08:00
Kevin Wan
93fcf899dc fix: config map cannot handle case-insensitive keys. (#2932)
* fix: #2922

* chore: rename const

* feat: support anonymous map field

* feat: support anonymous map field
2023-03-04 20:23:53 +08:00
Kevin Wan
9f4b3bae92 fix: #2899, using autoscaling/v2beta2 instead of v2beta1 (#2900)
* fix: #2899, using autoscaling/v2 instead of v2beta1

* chore: change hpa definition

---------

Co-authored-by: kevin.wan <kevin.wan@yijinin.com>
2023-03-04 20:22:27 +08:00
Kevin Wan
805cb87d98 chore: refine rest validator (#2928)
* chore: refine rest validator

* chore: add more tests

* chore: reformat code

* chore: add comments
2023-03-04 20:22:10 +08:00
Qiying Wang
366131640e feat: add configurable validator for httpx.Parse (#2923)
Co-authored-by: qiying.wang <qiying.wang@highlight.mobi>
2023-03-04 20:22:05 +08:00
Kevin Wan
956884a3ff fix: timeout not working if greater than global rest timeout (#2926) 2023-03-04 20:21:59 +08:00
raymonder jin
f571cb8af2 del unnecessary blank 2023-03-04 20:21:54 +08:00
Kevin Wan
cc5acf3b90 chore: reformat code (#2925) 2023-03-04 20:21:49 +08:00
chenquan
e1aa665443 fix: fixed the bug that old trace instances may be fetched 2023-03-04 20:21:43 +08:00
xiandong
cd357d9484 rm parseErr when kindJaeger 2023-03-04 20:21:28 +08:00
xiandong
6d4d7cbd6b rm kindJaegerUdp 2023-03-04 20:21:18 +08:00
xiandong
c593b5b531 add parseEndpoint 2023-03-04 20:20:29 +08:00
xiandong
fd5b38b07c add parseEndpoint 2023-03-04 20:20:17 +08:00
xiandong
41efb48f55 add test for Endpoint of kindJaegerUdp 2023-03-04 20:19:40 +08:00
xiandong
0ef3626839 add test for Endpoint of kindJaegerUdp 2023-03-04 20:19:34 +08:00
xiandong
77a72b16e9 add kindJaegerUdp 2023-03-04 20:19:25 +08:00
Kevin Wan
21566f1b7a chore: reformat code (#2903) 2023-03-04 20:17:35 +08:00
anqiansong
b2646e228b feat: Add request.ts (#2901)
* Add request.ts

* Update comments

* Refactor request filename
2023-03-04 20:17:21 +08:00
cong
588b883710 refactor: simplify sqlx fail fast ping and simplify miniredis setup in test (#2897)
* chore(redistest): simplify miniredis setup in test

* refactor(sqlx): simplify sqlx fail fast ping

* chore: close connection if not available
2023-03-04 20:17:16 +08:00
Kevin Wan
033910bbd8 Update readme-cn.md 2023-03-04 20:17:11 +08:00
fondoger
530dd79e3f Fix bug in dart api gen: path parameter is not replaced 2023-03-04 20:17:05 +08:00
Kevin Wan
cd5263ac75 Update readme-cn.md 2023-03-04 20:16:58 +08:00
Kevin Wan
ea3302a468 fix: test failures (#2892)
Co-authored-by: kevin.wan <kevin.wan@yijinin.com>
2023-03-04 20:16:50 +08:00
fondoger
abf15b373c Fix Dart API generation bugs; Add ability to generate API for path parameters (#2887)
* Fix bug in dartgen: Import path should match the generated api filename

* Use Route.HandlerName as generated dart API function name

Reasons:
- There is bug when using url path name as function name, because it may have invalid characters such as ":"
- Switching to HandlerName aligns with other languages such as typescript generation

* [DartGen] Add ability to generate api for url path parameters such as /path/:param
2023-03-04 20:16:44 +08:00
Kevin Wan
a865e9ee29 refactor: simplify stringx.Replacer, and avoid potential infinite loops (#2877)
* simplify replace

* backup

* refactor: simplify stringx.Replacer

* chore: add comments and const

* chore: add more tests

* chore: rename variable
2023-03-04 20:16:37 +08:00
Kevin Wan
f8292198cf Update readme-cn.md 2023-03-04 20:15:38 +08:00
Kevin Wan
016d965f56 chore: refactor (#2875) 2023-03-04 20:15:30 +08:00
dahaihu
95d7c73409 fix Replacer suffix match, and add test case (#2867)
* fix: replace shoud replace the longest match

* feat: revert bytes.Buffer to strings.Builder

* fix: loop reset nextStart

* feat: add node longest match test

* feat: add replacer suffix match test case

* feat: multiple match

* fix: partial match ends

* fix: replace look back upon error

* feat: rm unnecessary branch

---------

Co-authored-by: hudahai <hscxrzs@gmail.com>
Co-authored-by: hushichang <hushichang@sensetime.com>
2023-03-04 20:15:25 +08:00
Kevin Wan
939ef2a181 chore: add more tests (#2873) 2023-03-04 20:15:18 +08:00
Kevin Wan
f0b8dd45fe fix: test failure (#2874) 2023-03-04 20:15:08 +08:00
Mikael
0ba9335b04 only unmashal public variables (#2872)
* only unmashal public variables

* only unmashal public variables

* only unmashal public variables

* only unmashal public variables
2023-03-04 20:15:01 +08:00
Kevin Wan
04f181f0b4 chore: add more tests (#2866)
* chore: add more tests

* chore: add more tests

* chore: fix test failure
2023-03-04 20:14:54 +08:00
hudahai
89f841c126 fix: loop reset nextStart 2023-03-04 20:14:48 +08:00
hudahai
d785c8c377 feat: revert bytes.Buffer to strings.Builder 2023-03-04 20:14:41 +08:00
hudahai
687a1d15da fix: replace shoud replace the longest match 2023-03-04 20:14:35 +08:00
Kevin Wan
aaa974e1ad Update readme-cn.md 2023-03-04 20:14:22 +08:00
Kevin Wan
2779568ccf fix: conf anonymous overlay problem (#2847) 2023-03-04 20:14:10 +08:00
Kevin Wan
f7d50ae626 Update readme-cn.md 2023-03-04 20:14:01 +08:00
Kevin Wan
33594ea350 Chore/rewire (#2836)
* fix: problem on name overlaping in config (#2820)

* chore: fix missing funcs on windows (#2825)

* chore: add more tests (#2812)

* chore: add more tests

* chore: add more tests

* chore: add more tests (#2814)

* chore: add more tests (#2815)

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* feat: upgrade go to v1.18 (#2817)

* feat: upgrade go to v1.18

* feat: upgrade go to v1.18

* chore: change interface{} to any (#2818)

* chore: change interface{} to any

* chore: update goctl version to 1.5.0

* chore: update goctl deps

* chore: update goctl interface{} to any (#2819)

* chore: update goctl interface{} to any

* chore: update goctl interface{} to any

* chore(deps): bump google.golang.org/grpc from 1.52.0 to 1.52.3 (#2823)

* support custom maxBytes in API file (#2822)

* feat: mapreduce generic version (#2827)

* feat: mapreduce generic version

* fix: gateway mr type issue

---------

Co-authored-by: kevin.wan <kevin.wan@yijinin.com>

* feat: add MustNewRedis (#2824)

* feat: add MustNewRedis

* feat: add MustNewRedis

* feat: add MustNewRedis

* x

* x

* fix ut

* x

* x

* x

* x

* x

* chore: improve codecov (#2828)

* feat: converge grpc interceptor processing (#2830)

* feat: converge grpc interceptor processing

* x

* x

* chore(deps): bump go.opentelemetry.io/otel/exporters/zipkin (#2831)

* chore(deps): bump go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp (#2833)

Bumps [go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp](https://github.com/open-telemetry/opentelemetry-go) from 1.11.2 to 1.12.0.
- [Release notes](https://github.com/open-telemetry/opentelemetry-go/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-go/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-go/compare/v1.11.2...v1.12.0)

---
updated-dependencies:
- dependency-name: go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* chore(deps): bump go.opentelemetry.io/otel/exporters/jaeger (#2832)

Bumps [go.opentelemetry.io/otel/exporters/jaeger](https://github.com/open-telemetry/opentelemetry-go) from 1.11.2 to 1.12.0.
- [Release notes](https://github.com/open-telemetry/opentelemetry-go/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-go/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-go/compare/v1.11.2...v1.12.0)

---
updated-dependencies:
- dependency-name: go.opentelemetry.io/otel/exporters/jaeger
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Xiaoju Jiang <44432198+jiang4869@users.noreply.github.com>
Co-authored-by: kevin.wan <kevin.wan@yijinin.com>
Co-authored-by: MarkJoyMa <64180138+MarkJoyMa@users.noreply.github.com>
2023-03-04 20:13:37 +08:00
MarkJoyMa
ee2ec974c4 feat: converge grpc interceptor processing (#2830)
* feat: converge grpc interceptor processing

* x

* x
2023-03-04 20:12:30 +08:00
Kevin Wan
fd2f2f0f54 chore: improve codecov (#2828) 2023-03-04 20:12:16 +08:00
MarkJoyMa
86a2429d7d feat: add MustNewRedis (#2824)
* feat: add MustNewRedis

* feat: add MustNewRedis

* feat: add MustNewRedis

* x

* x

* fix ut

* x

* x

* x

* x

* x
2023-03-04 20:12:05 +08:00
Xiaoju Jiang
e5fe5dcc50 support custom maxBytes in API file (#2822) 2023-03-04 20:11:55 +08:00
Kevin Wan
b510e7c242 chore: fix missing funcs on windows (#2825) 2023-03-04 20:11:46 +08:00
Kevin Wan
dfe92e709f fix: problem on name overlaping in config (#2820) 2023-03-04 20:11:18 +08:00
Kevin Wan
cb649cf627 chore: add more tests (#2815)
* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests
2023-03-04 20:11:03 +08:00
Kevin Wan
ce19a5ade6 chore: add more tests (#2814) 2023-03-04 20:10:57 +08:00
Kevin Wan
6dc56de714 chore: add more tests (#2812)
* chore: add more tests

* chore: add more tests
2023-03-04 20:09:03 +08:00
911 changed files with 15091 additions and 68414 deletions

6
.codecov.yml Normal file
View File

@@ -0,0 +1,6 @@
comment:
layout: "flags, files"
behavior: once
require_changes: true
ignore:
- "tools"

View File

@@ -1,7 +1 @@
**/.git
.dockerignore
Dockerfile
goctl
Makefile
readme.md
readme-cn.md

12
.github/FUNDING.yml vendored
View File

@@ -1,3 +1,13 @@
# These are supported funding model platforms
github: [zeromicro]
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
custom: # https://gitee.com/kevwan/static/raw/master/images/sponsor.jpg
ethereum: 0x5052b7f6B937B02563996D23feb69b38D06Ca150 | kevwan

View File

@@ -1,344 +0,0 @@
# GitHub Copilot Instructions for go-zero
This document provides guidelines for GitHub Copilot when assisting with development in the go-zero project.
## Project Overview
go-zero is a web and RPC framework with lots of built-in engineering practices designed to ensure the stability of busy services with resilience design. It has been serving sites with tens of millions of users for years.
### Key Architecture Components
- **REST API framework** (`rest/`) - HTTP service framework with middleware chain support
- **RPC framework** (`zrpc/`) - gRPC-based RPC framework with etcd service discovery and p2c_ewma load balancing
- **Gateway** (`gateway/`) - API gateway supporting both HTTP and gRPC upstreams with proto-based routing
- **MCP Server** (`mcp/`) - Model Context Protocol server for AI agent integration via SSE
- **Core utilities** (`core/`) - Production-grade components:
- Resilience: circuit breakers (`breaker/`), rate limiters (`limit/`), adaptive load shedding (`load/`)
- Storage: SQL with cache (`stores/sqlc/`), Redis (`stores/redis/`), MongoDB (`stores/mongo/`)
- Concurrency: MapReduce (`mr/`), worker pools (`executors/`), sync primitives (`syncx/`)
- Observability: metrics (`metric/`), tracing (`trace/`), structured logging (`logx/`)
- **Code generation tool** (`tools/goctl/`) - CLI tool for generating Go code from `.api` and `.proto` files
## Coding Standards and Conventions
### Code Style
1. **Follow Go conventions**: Use `gofmt` for formatting, follow effective Go practices
2. **Package naming**: Use lowercase, single-word package names when possible
3. **Error handling**: Always handle errors explicitly, use `errorx.BatchError` for multiple errors
4. **Context propagation**: Always pass `context.Context` as the first parameter for functions that may block
5. **Configuration structures**: Use struct tags with JSON annotations, defaults, and validation
**Pattern**: All service configs embed `service.ServiceConf` for common fields (Name, Log, Mode, Telemetry)
```go
type Config struct {
service.ServiceConf // Always embed for services
Host string `json:",default=0.0.0.0"`
Port int // Required field (no default)
Timeout int64 `json:",default=3000"` // Timeouts in milliseconds
Optional string `json:",optional"` // Optional field
Mode string `json:",default=pro,options=dev|test|rt|pre|pro"` // Validated options
}
```
**Service modes**: `dev`/`test`/`rt` disable load shedding and stats; `pre`/`pro` enable all resilience features
### Interface Design
1. **Small interfaces**: Follow Go's preference for small, focused interfaces
2. **Context methods**: Provide both context and non-context versions of methods
3. **Options pattern**: Use functional options for complex configuration
Example:
```go
func (c *Client) Get(key string, val any) error {
return c.GetCtx(context.Background(), key, val)
}
func (c *Client) GetCtx(ctx context.Context, key string, val any) error {
// implementation
}
```
### Testing Patterns
1. **Test file naming**: Use `*_test.go` suffix
2. **Test function naming**: Use `TestFunctionName` pattern
3. **Use testify/assert**: Prefer `assert` package for assertions
4. **Table-driven tests**: Use table-driven tests for multiple scenarios
5. **Mock interfaces**: Use `go.uber.org/mock` for mocking
6. **Test helpers**: Use `redistest`, `mongtest` helpers for database testing
Example test pattern:
```go
func TestSomething(t *testing.T) {
tests := []struct {
name string
input string
expected string
wantErr bool
}{
{"valid case", "input", "output", false},
{"error case", "bad", "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := SomeFunction(tt.input)
if tt.wantErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}
```
## Framework-Specific Guidelines
### REST API Development
1. **API Definition**: Use `.api` files to define REST APIs with goctl codegen
2. **Handler pattern**: Separate business logic into logic packages (handlers call logic layer)
3. **Middleware chain**: Middlewares wrap via `chain.Chain` interface - use `Append()` or `Prepend()` to control order
- Built-in middlewares (all in `rest/handler/`): tracing, logging, metrics, recovery, breaker, shedding, timeout, maxconns, maxbytes, gunzip
- Custom middleware: `func(http.Handler) http.Handler` - call `next.ServeHTTP(w, r)` to continue chain
4. **Response handling**: Use `httpx.WriteJson(w, code, v)` for JSON responses
5. **Error handling**: Use `httpx.Error(w, err)` or `httpx.ErrorCtx(ctx, w, err)` for HTTP error responses
6. **Route registration**: Routes defined with `Method`, `Path`, and `Handler` - wildcards use `:param` syntax
### RPC Development
1. **Protocol Buffers**: Use protobuf for service definitions, generate code with goctl
2. **Service discovery**: Use etcd for dynamic service registration/discovery, or direct endpoints for static routing
3. **Load balancing**: Default is `p2c_ewma` (power of 2 choices with EWMA), configurable via `BalancerName`
4. **Client configuration**: Support `Etcd`, `Endpoints`, or `Target` - use `BuildTarget()` to construct connection string
5. **Interceptors**: Implement gRPC interceptors for cross-cutting concerns (auth, logging, metrics)
6. **Health checks**: gRPC health checks enabled by default (`Health: true`)
### Database Operations
1. **SQL operations**: Use `sqlx.SqlConn` interface - methods always end with `Ctx` for context support
2. **Caching pattern**: `stores/sqlc` provides `CachedConn` for automatic cache-aside pattern
- `QueryRowCtx`: Query with cache key, auto-populate on cache miss
- `ExecCtx`: Execute and delete cache keys
3. **Transactions**: Use `sqlx.SqlConn.TransactCtx()` to get transaction session
4. **Connection pooling**: Managed automatically (64 max idle/open, 1min lifetime)
5. **Test helpers**: Use `redistest.CreateRedis(t)` for Redis, SQL mocks for DB testing
Example cache pattern:
```go
err := c.QueryRowCtx(ctx, &dest, key, func(ctx context.Context, conn sqlx.SqlConn) error {
return conn.QueryRowCtx(ctx, &dest, query, args...)
})
```
### Configuration Management
1. **YAML configuration**: Use YAML for configuration files
2. **Environment variables**: Support environment variable overrides
3. **Validation**: Include proper validation for configuration parameters
4. **Sensible defaults**: Provide reasonable default values
## Error Handling Best Practices
1. **Wrap errors**: Use `fmt.Errorf` with `%w` verb to wrap errors
2. **Custom errors**: Define custom error types when needed
3. **Error logging**: Log errors appropriately with context
4. **Graceful degradation**: Implement fallback mechanisms
## Performance Considerations
1. **Resource pools**: Use connection pools and worker pools
2. **Circuit breakers**: Implement circuit breaker patterns for external calls
3. **Rate limiting**: Apply rate limiting to protect services
4. **Load shedding**: Implement adaptive load shedding
5. **Metrics**: Add appropriate metrics and monitoring
## Security Guidelines
1. **Input validation**: Validate all input parameters
2. **SQL injection prevention**: Use parameterized queries
3. **Authentication**: Implement proper JWT token handling
4. **HTTPS**: Support TLS/HTTPS configurations
5. **CORS**: Configure CORS appropriately for web APIs
## Documentation Standards
1. **Package documentation**: Include package-level documentation
2. **Function documentation**: Document exported functions with examples
3. **API documentation**: Maintain API documentation in sync
4. **README updates**: Update README for significant changes
## GitHub Issue Management
### Understanding and Categorizing Issues
When analyzing GitHub issues, consider these common categories:
1. **Bug Reports**: Stack traces, version info, reproduction steps
2. **Feature Requests**: Use case, proposed solution, alternatives
3. **Questions**: Usage, configuration, or architecture
4. **Documentation Issues**: Missing, unclear, or incorrect docs
5. **Performance Issues**: Benchmarks, profiling data, resource usage
### Issue Analysis Checklist
- Identify affected component (REST, RPC, Gateway, MCP, Core utilities, goctl)
- Check versions (go-zero, Go)
- Look for reproduction steps or code examples
- Review code snippets, logs, or stack traces
- Check if related to resilience features (breaker, load shedding, rate limiting)
- Determine production impact
### Responding to Issues
Be helpful and professional. Ask clarifying questions when needed. Reference relevant documentation and code files. Provide code examples following project conventions. Suggest workarounds when applicable.
### Chinese to English Translation
go-zero has an international user base. When encountering issues or comments written in Chinese, translate them to English to ensure all contributors can participate in discussions.
#### Translation Guidelines
1. **Update issue titles**: Edit the issue title to include English translation only
2. **Translate comments in place**: Add a comment with the English translation, followed by the original Chinese text
3. **Keep original Chinese**: After translating, include the original Chinese text in a blockquote for verification
4. **Encourage English communication**: Politely suggest users write in English for better collaboration
5. **Maintain technical accuracy**: Preserve technical terms, component names, and code exactly
6. **Translate naturally**: Avoid literal word-by-word translation; use idiomatic English
7. **Preserve formatting**: Keep markdown formatting, code blocks, and links intact
8. **Keep URLs unchanged**: Don't translate URLs or file paths
#### Common Technical Terms (Chinese → English)
- 框架 → **Framework** | 中间件 → **Middleware** | 负载均衡 → **Load Balancing**
- 熔断器 → **Circuit Breaker** | 限流 → **Rate Limiting** | 降载/过载保护 → **Load Shedding**
- 服务发现 → **Service Discovery** | 配置 → **Configuration** | 弹性/容错 → **Resilience** | 微服务 → **Microservices**
#### Translation Example
**Original Chinese Title:** `goctl 执行环境问题`
**Updated Title:** `goctl Execution Environment Issue`
**Original Chinese Comment:** `我在项目中遇到熔断器配置问题`
**Translation in Comment:**
```markdown
I encountered a circuit breaker configuration issue in my project.
> Original (原文): 我在项目中遇到熔断器配置问题
```
### Common Issue Patterns and Solutions
#### Configuration Issues
- Check `service.ServiceConf` embedding and struct tags
- Verify YAML syntax, defaults, and validation rules
- Reference: [rest/config.go](rest/config.go), [zrpc/config.go](zrpc/config.go)
#### Code Generation (goctl) Issues
- Verify `.api` or `.proto` file syntax and goctl version
- Reference: `tools/goctl/` directory
#### RPC Connection Issues
- Check etcd configuration, service discovery, and endpoints
- Verify load balancing settings (p2c_ewma)
#### Database/Cache Issues
- Verify `sqlx.SqlConn` usage with context
- Check cache key generation, invalidation, and connection pools
- Use test helpers (`redistest`, `mongtest`)
#### Performance Issues
- Check if load shedding is enabled (mode: `pre`/`pro`)
- Review circuit breaker thresholds, rate limiting, and context timeouts
### Referencing Codebase
When explaining issues, reference specific files and patterns:
- REST API: `rest/`, `rest/handler/`, `rest/httpx/`
- RPC: `zrpc/`, `zrpc/internal/`
- Core utilities: `core/breaker/`, `core/limit/`, `core/load/`, etc.
- Gateway: `gateway/`
- MCP: `mcp/`
- Code generation: `tools/goctl/`
- Examples: `adhoc/` directory contains various examples
### Encouraging Best Practices
When responding to issues, gently guide users toward:
- Proper error handling with context
- Using resilience features (breakers, rate limiters)
- Following testing patterns with table-driven tests
- Implementing proper resource cleanup
- Reading existing documentation in `docs/` and `readme.md`
## Common Patterns to Follow
### Service Configuration
```go
type ServiceConf struct {
Name string
Log logx.LogConf
Mode string `json:",default=pro,options=[dev,test,pre,pro]"`
// ... other common fields
}
```
### Middleware Implementation
```go
func SomeMiddleware() rest.Middleware {
return func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Pre-processing
next.ServeHTTP(w, r)
// Post-processing
}
}
}
```
### Resource Management
Always implement proper resource cleanup using defer and context cancellation.
## Build and Test Commands
- Build: `go build ./...`
- Test: `go test ./...`
- Test with race detection: `go test -race ./...`
- Format: `gofmt -w .`
- Code generation:
- REST API: `goctl api go -api *.api -dir .`
- RPC: `goctl rpc protoc *.proto --go_out=. --go-grpc_out=. --zrpc_out=.`
- Model from SQL: `goctl model mysql datasource -url="user:pass@tcp(host:port)/db" -table="*" -dir="./model"`
## Critical Architecture Patterns
### Resilience Design Philosophy
go-zero implements defense-in-depth with multiple protection layers:
1. **Circuit Breaker** (`core/breaker`): Google SRE breaker - tracks success/failure, opens on error threshold
2. **Adaptive Load Shedding** (`core/load`): CPU-based auto-rejection when system overloaded (disabled in dev/test/rt modes)
3. **Rate Limiting** (`core/limit`): Token bucket (Redis-based) and period limiters
4. **Timeout Control**: Cascading timeouts via context - set at multiple levels (client, server, handler)
### Middleware Chain Architecture
`rest/chain` provides middleware composition:
```go
// Middleware signature
type Middleware func(http.Handler) http.Handler
// Chain operations
chain := chain.New(m1, m2)
chain.Append(m3) // Adds to end: m1 -> m2 -> m3
chain.Prepend(m0) // Adds to start: m0 -> m1 -> m2 -> m3
handler := chain.Then(finalHandler)
```
### Concurrency Patterns
- **MapReduce** (`core/mr`): Parallel processing with worker pools - use for batch operations
- **Executors** (`core/executors`): Bulk/period executors for batching operations
- **SingleFlight** (`core/syncx`): Deduplicates concurrent identical requests
Remember to run tests and ensure all checks pass before submitting changes. The project emphasizes high quality, performance, and reliability, so these should be primary considerations in all development work.

View File

@@ -5,19 +5,7 @@
version: 2
updates:
- package-ecosystem: "docker" # Update image tags in Dockerfile
directory: "/"
schedule:
interval: "weekly"
- package-ecosystem: "github-actions" # Update GitHub Actions
directory: "/"
schedule:
interval: "weekly"
- package-ecosystem: "gomod" # See documentation for possible values
directory: "/" # Location of package manifests
schedule:
interval: "daily"
- package-ecosystem: "gomod" # See documentation for possible values
directory: "/tools/goctl" # Location of package manifests
schedule:
interval: "daily"

View File

@@ -35,11 +35,11 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v3
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v4
uses: github/codeql-action/init@v2
with:
languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file.
@@ -50,7 +50,7 @@ jobs:
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
# If this step fails, then you should remove it and run the build manually (see below)
- name: Autobuild
uses: github/codeql-action/autobuild@v4
uses: github/codeql-action/autobuild@v2
# Command-line programs to run using the OS shell.
# 📚 https://git.io/JvXDl
@@ -64,4 +64,4 @@ jobs:
# make release
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v4
uses: github/codeql-action/analyze@v2

View File

@@ -12,12 +12,12 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Check out code into the Go module directory
uses: actions/checkout@v6
uses: actions/checkout@v3
- name: Set up Go 1.x
uses: actions/setup-go@v6
uses: actions/setup-go@v3
with:
go-version-file: go.mod
go-version: ^1.16
check-latest: true
cache: true
id: go
@@ -29,36 +29,27 @@ jobs:
- name: Lint
run: |
go vet -stdmethods=false $(go list ./...)
go mod tidy
if ! test -z "$(git status --porcelain)"; then
echo "Please run 'go mod tidy'"
exit 1
fi
go install mvdan.cc/gofumpt@latest
test -z "$(gofumpt -l -extra .)" || echo "Please run 'gofumpt -l -w -extra .'"
- name: Test
run: go test -race -coverprofile=coverage.txt -covermode=atomic ./...
- name: Codecov
uses: codecov/codecov-action@v5
with:
files: ./coverage.txt
flags: unittests
name: codecov-umbrella
fail_ci_if_error: false
uses: codecov/codecov-action@v3
test-win:
name: Windows
runs-on: windows-latest
steps:
- name: Checkout codebase
uses: actions/checkout@v6
uses: actions/checkout@v3
- name: Set up Go 1.x
uses: actions/setup-go@v6
uses: actions/setup-go@v3
with:
# make sure Go version compatible with go-zero
go-version-file: go.mod
# use 1.16 to guarantee Go 1.16 compatibility
go-version: 1.16
check-latest: true
cache: true
@@ -66,5 +57,5 @@ jobs:
run: |
go mod verify
go mod download
go test ./...
go test -v -race ./...
cd tools/goctl && go build -v goctl.go

18
.github/workflows/issue-translator.yml vendored Normal file
View File

@@ -0,0 +1,18 @@
name: 'issue-translator'
on:
issue_comment:
types: [created]
issues:
types: [opened]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: usthe/issues-translate-action@v2.7
with:
IS_MODIFY_TITLE: true
# not require, default false, . Decide whether to modify the issue title
# if true, the robot account @Issues-translate-bot must have modification permissions, invite @Issues-translate-bot to your project or use your custom bot.
CUSTOM_BOT_NOTE: Bot detected the issue body's language is not English, translate it automatically. 👯👭🏻🧑‍🤝‍🧑👫🧑🏿‍🤝‍🧑🏻👩🏾‍🤝‍👨🏿👬🏿
# not require. Customize the translation robot prefix message.

View File

@@ -7,7 +7,7 @@ jobs:
close-issues:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@v10
- uses: actions/stale@v6
with:
days-before-issue-stale: 365
days-before-issue-close: 90

View File

@@ -16,13 +16,13 @@ jobs:
- goarch: "386"
goos: darwin
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v3
- uses: zeromicro/go-zero-release-action@master
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
goos: ${{ matrix.goos }}
goarch: ${{ matrix.goarch }}
goversion: "https://dl.google.com/go/go1.21.13.linux-amd64.tar.gz"
goversion: "https://dl.google.com/go/go1.17.5.linux-amd64.tar.gz"
project_path: "tools/goctl"
binary_name: "goctl"
extra_files: tools/goctl/readme.md tools/goctl/readme-cn.md
extra_files: tools/goctl/readme.md tools/goctl/readme-cn.md

View File

@@ -5,12 +5,7 @@ jobs:
name: runner / staticcheck
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
with:
go-version-file: go.mod
check-latest: true
cache: true
- uses: actions/checkout@v3
- uses: reviewdog/action-staticcheck@v1
with:
github_token: ${{ secrets.github_token }}
@@ -19,6 +14,6 @@ jobs:
# Report all results.
filter_mode: nofilter
# Exit with 1 when it find at least one finding.
fail_level: any
fail_on_error: true
# Set staticcheck flags
staticcheck_flags: -checks=inherit,-SA1019,-SA1029,-SA5008

View File

@@ -1,42 +0,0 @@
name: Release Version Check
on:
push:
tags:
- 'tools/goctl/v*'
workflow_dispatch:
jobs:
version-check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- name: Set up Go
uses: actions/setup-go@v6
with:
go-version: '1.21'
- name: Extract tag version
id: get_version
run: |
# Extract version from tools/goctl/v* format
VERSION="${GITHUB_REF#refs/tags/tools/goctl/v}"
echo "VERSION=$VERSION" >> $GITHUB_ENV
echo "Extracted version: $VERSION"
- name: Check version in goctl source code
run: |
# Change to goctl directory
cd tools/goctl
# Check version in BuildVersion constant
VERSION_IN_CODE=$(grep -r "const BuildVersion =" . | grep -o '".*"' | tr -d '"')
echo "Version in code: $VERSION_IN_CODE"
echo "Expected version: $VERSION"
if [ "$VERSION_IN_CODE" != "$VERSION" ]; then
echo "Version mismatch: Version in code ($VERSION_IN_CODE) doesn't match tag version ($VERSION)"
exit 1
fi
echo "✅ Version check passed!"

7
.gitignore vendored
View File

@@ -11,15 +11,12 @@
!api
# ignore
**/.idea
**/.vscode
.idea
**/.DS_Store
**/logs
**/adhoc
**/coverage.txt
**/WARP.md
# for test purpose
**/adhoc
go.work
go.work.sum

View File

@@ -1,76 +1,102 @@
# 🚀 Contributing to go-zero
# Contributing
Welcome to the go-zero community! We're thrilled to have you here. Contributing to our project is a fantastic way to be a part of the go-zero journey. Let's make this guide exciting and fun!
Welcome to go-zero!
## 📜 Before You Dive In
- [Before you get started](#before-you-get-started)
- [Code of Conduct](#code-of-conduct)
- [Community Expectations](#community-expectations)
- [Getting started](#getting-started)
- [Your First Contribution](#your-first-contribution)
- [Find something to work on](#find-something-to-work-on)
- [Find a good first topic](#find-a-good-first-topic)
- [Work on an Issue](#work-on-an-issue)
- [File an Issue](#file-an-issue)
- [Contributor Workflow](#contributor-workflow)
- [Creating Pull Requests](#creating-pull-requests)
- [Code Review](#code-review)
- [Testing](#testing)
### 🤝 Code of Conduct
# Before you get started
Let's start on the right foot. Please take a moment to read and embrace our [Code of Conduct](/code-of-conduct.md). We're all about creating a welcoming and respectful environment.
## Code of Conduct
### 🌟 Community Expectations
Please make sure to read and observe our [Code of Conduct](/code-of-conduct.md).
At go-zero, we're like a close-knit family, and we believe in creating a healthy, friendly, and productive atmosphere. It's all about sharing knowledge and building amazing things together.
## Community Expectations
## 🚀 Getting Started
go-zero is a community project driven by its community which strives to promote a healthy, friendly and productive environment.
go-zero is a web and rpc framework written in Go. It's born to ensure the stability of the busy sites with resilient design. Builtin goctl greatly improves the development productivity.
Get your adventure rolling! Here's how to begin:
# Getting started
1. 🍴 **Fork the Repository**: Head over to the GitHub repository and fork it to your own space.
- Fork the repository on GitHub.
- Make your changes on your fork repository.
- Submit a PR.
2. 🛠️ **Make Your Magic**: Work your magic in your forked repository. Create new features, squash bugs, or improve documentation - it's your world to conquer!
3. 🚀 **Submit a PR (Pull Request)**: When you're ready to unveil your creation, submit a Pull Request. We can't wait to see your awesome work!
# Your First Contribution
## 🌟 Your First Contribution
We will help you to contribute in different areas like filing issues, developing features, fixing critical bugs and
getting your work reviewed and merged.
We're here to guide you on your quest to become a go-zero contributor. Whether you want to file issues, develop features, or tame some critical bugs, we've got you covered.
If you have questions about the development process,
feel free to [file an issue](https://github.com/zeromicro/go-zero/issues/new/choose).
If you have questions or need guidance at any stage, don't hesitate to [open an issue](https://github.com/zeromicro/go-zero/issues/new/choose).
## Find something to work on
## 🔍 Find Something to Work On
We are always in need of help, be it fixing documentation, reporting bugs or writing some code.
Look at places where you feel best coding practices aren't followed, code refactoring is needed or tests are missing.
Here is how you get started.
Ready to dive into the action? There are several ways to contribute:
### Find a good first topic
### 💼 Find a Good First Topic
[go-zero](https://github.com/zeromicro/go-zero) has beginner-friendly issues that provide a good first issue.
For example, [go-zero](https://github.com/zeromicro/go-zero) has
[help wanted](https://github.com/zeromicro/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) and
[good first issue](https://github.com/zeromicro/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)
labels for issues that should not need deep knowledge of the system.
We can help new contributors who wish to work on such issues.
Discover easy-entry issues labeled as [help wanted](https://github.com/zeromicro/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) or [good first issue](https://github.com/zeromicro/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22). These issues are perfect for newcomers and don't require deep knowledge of the system. We're here to assist you with these tasks.
Another good way to contribute is to find a documentation improvement, such as a missing/broken link.
Please see [Contributing](#contributing) below for the workflow.
### 🪄 Work on an Issue
#### Work on an issue
Once you've picked an issue that excites you, let us know by commenting on it. Our maintainers will assign it to you, and you can embark on your mission!
When you are willing to take on an issue, just reply on the issue. The maintainer will assign it to you.
### 📢 File an Issue
### File an Issue
Reporting an issue is just as valuable as code contributions. If you discover a problem, don't hesitate to [open an issue](https://github.com/zeromicro/go-zero/issues/new/choose). Be sure to follow our guidelines when submitting an issue.
While we encourage everyone to contribute code, it is also appreciated when someone reports an issue.
## 🎯 Contributor Workflow
Please follow the prompted submission guidelines while opening an issue.
Here's a rough guide to your contributor journey:
# Contributor Workflow
1. 🌱 Create a New Branch: Start by creating a topic branch, usually based on the 'master' branch. This is where your contribution will grow.
Please do not ever hesitate to ask a question or send a pull request.
2. 💡 Make Commits: Commit your work in logical units. Each commit should tell a story.
This is a rough outline of what a contributor's workflow looks like:
3. 🚀 Push Changes: Push the changes in your topic branch to your personal fork of the repository.
- Create a topic branch from where to base the contribution. This is usually master.
- Make commits of logical units.
- Push changes in a topic branch to a personal fork of the repository.
- Submit a pull request to [go-zero](https://github.com/zeromicro/go-zero).
4. 📦 Submit a Pull Request: When your creation is complete, submit a Pull Request to the [go-zero repository](https://github.com/zeromicro/go-zero).
## Creating Pull Requests
## 🌠 Creating Pull Requests
Pull requests are often called simply "PR".
go-zero generally follows the standard [github pull request](https://help.github.com/articles/about-pull-requests/) process.
To submit a proposed change, please develop the code/fix and add new test cases.
After that, run these local verifications before submitting pull request to predict the pass or
fail of continuous integration.
Pull Requests (PRs) are your way of making a grand entrance with your contribution. Here's how to do it:
* Format the code with `gofmt`
* Run the test with data race enabled `go test -race ./...`
- 💼 Format Your Code: Ensure your code is beautifully formatted with `gofmt`.
- 🏃 Run Tests: Verify that your changes pass all the tests, including data race tests. Run `go test -race ./...` for the ultimate validation.
## Code Review
## 👁️‍🗨️ Code Review
To make it easier for your PR to receive reviews, consider the reviewers will need you to:
Getting your PR reviewed is the final step before your contribution becomes part of go-zero's magical world. To make the process smooth, keep these things in mind:
* follow [good coding guidelines](https://github.com/golang/go/wiki/CodeReviewComments).
* write [good commit messages](https://chris.beams.io/posts/git-commit/).
* break large changes into a logical series of smaller patches which individually make easily understandable changes, and in aggregate solve a broader issue.
- 🧙‍♀️ Follow Good Coding Practices: Stick to [good coding guidelines](https://github.com/golang/go/wiki/CodeReviewComments).
- 📝 Write Awesome Commit Messages: Craft [impressive commit messages](https://chris.beams.io/posts/git-commit/) - they're like spells in the wizard's book!
- 🔍 Break It Down: For larger changes, consider breaking them into a series of smaller, logical patches. Each patch should make an understandable and meaningful improvement.
Congratulations on your contribution journey! We're thrilled to have you as part of our go-zero community. Let's make amazing things together! 🌟
Now, go out there and start your adventure! If you have any more magical ideas to enhance this guide, please share them. 🔥

View File

@@ -1,16 +0,0 @@
# Security Policy
## Supported Versions
We publish releases monthly.
| Version | Supported |
| ------- | ------------------ |
| >= 1.4.4 | :white_check_mark: |
| < 1.4.4 | :x: |
## Reporting a Vulnerability
https://github.com/zeromicro/go-zero/security/advisories
Accepted vulnerabilities are expected to be fixed within a month.

View File

@@ -1,127 +1,76 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, caste, color, religion, or sexual
identity and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
Examples of behavior that contributes to creating a positive environment
include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the overall
community
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior include:
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery, and sexual attention or advances of
any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email address,
without their explicit permission
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
## Our Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
[INSERT CONTACT METHOD].
All complaints will be reviewed and investigated promptly and fairly.
reported by contacting the project team at [INSERT EMAIL ADDRESS]. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series of
actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or permanent
ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within the
community.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.1, available at
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
Community Impact Guidelines were inspired by
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
[https://www.contributor-covenant.org/translations][translations].
For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq

View File

@@ -1,8 +1,6 @@
package bloom
import (
"context"
_ "embed"
"errors"
"strconv"
@@ -10,23 +8,28 @@ import (
"github.com/zeromicro/go-zero/core/stores/redis"
)
// for detailed error rate table, see http://pages.cs.wisc.edu/~cao/papers/summary-cache/node8.html
// maps as k in the error rate table
const maps = 14
var (
// ErrTooLargeOffset indicates the offset is too large in bitset.
ErrTooLargeOffset = errors.New("too large offset")
//go:embed setscript.lua
setLuaScript string
setScript = redis.NewScript(setLuaScript)
//go:embed testscript.lua
testLuaScript string
testScript = redis.NewScript(testLuaScript)
const (
// for detailed error rate table, see http://pages.cs.wisc.edu/~cao/papers/summary-cache/node8.html
// maps as k in the error rate table
maps = 14
setScript = `
for _, offset in ipairs(ARGV) do
redis.call("setbit", KEYS[1], offset, 1)
end
`
testScript = `
for _, offset in ipairs(ARGV) do
if tonumber(redis.call("getbit", KEYS[1], offset)) == 0 then
return false
end
end
return true
`
)
// ErrTooLargeOffset indicates the offset is too large in bitset.
var ErrTooLargeOffset = errors.New("too large offset")
type (
// A Filter is a bloom filter.
Filter struct {
@@ -35,12 +38,12 @@ type (
}
bitSetProvider interface {
check(ctx context.Context, offsets []uint) (bool, error)
set(ctx context.Context, offsets []uint) error
check([]uint) (bool, error)
set([]uint) error
}
)
// New creates a Filter, store is the backed redis, key is the key for the bloom filter,
// New create a Filter, store is the backed redis, key is the key for the bloom filter,
// bits is how many bits will be used, maps is how many hashes for each addition.
// best practices:
// elements - means how many actual elements
@@ -55,24 +58,14 @@ func New(store *redis.Redis, key string, bits uint) *Filter {
// Add adds data into f.
func (f *Filter) Add(data []byte) error {
return f.AddCtx(context.Background(), data)
}
// AddCtx adds data into f with context.
func (f *Filter) AddCtx(ctx context.Context, data []byte) error {
locations := f.getLocations(data)
return f.bitSet.set(ctx, locations)
return f.bitSet.set(locations)
}
// Exists checks if data is in f.
func (f *Filter) Exists(data []byte) (bool, error) {
return f.ExistsCtx(context.Background(), data)
}
// ExistsCtx checks if data is in f with context.
func (f *Filter) ExistsCtx(ctx context.Context, data []byte) (bool, error) {
locations := f.getLocations(data)
isSet, err := f.bitSet.check(ctx, locations)
isSet, err := f.bitSet.check(locations)
if err != nil {
return false, err
}
@@ -105,7 +98,7 @@ func newRedisBitSet(store *redis.Redis, key string, bits uint) *redisBitSet {
}
func (r *redisBitSet) buildOffsetArgs(offsets []uint) ([]string, error) {
args := make([]string, 0, len(offsets))
var args []string
for _, offset := range offsets {
if offset >= r.bits {
@@ -118,14 +111,14 @@ func (r *redisBitSet) buildOffsetArgs(offsets []uint) ([]string, error) {
return args, nil
}
func (r *redisBitSet) check(ctx context.Context, offsets []uint) (bool, error) {
func (r *redisBitSet) check(offsets []uint) (bool, error) {
args, err := r.buildOffsetArgs(offsets)
if err != nil {
return false, err
}
resp, err := r.store.ScriptRunCtx(ctx, testScript, []string{r.key}, args)
if errors.Is(err, redis.Nil) {
resp, err := r.store.Eval(testScript, []string{r.key}, args)
if err == redis.Nil {
return false, nil
} else if err != nil {
return false, err
@@ -139,25 +132,23 @@ func (r *redisBitSet) check(ctx context.Context, offsets []uint) (bool, error) {
return exists == 1, nil
}
// del only use for testing.
func (r *redisBitSet) del() error {
_, err := r.store.Del(r.key)
return err
}
// expire only use for testing.
func (r *redisBitSet) expire(seconds int) error {
return r.store.Expire(r.key, seconds)
}
func (r *redisBitSet) set(ctx context.Context, offsets []uint) error {
func (r *redisBitSet) set(offsets []uint) error {
args, err := r.buildOffsetArgs(offsets)
if err != nil {
return err
}
_, err = r.store.ScriptRunCtx(ctx, setScript, []string{r.key}, args)
if errors.Is(err, redis.Nil) {
_, err = r.store.Eval(setScript, []string{r.key}, args)
if err == redis.Nil {
return nil
}

View File

@@ -1,31 +1,30 @@
package bloom
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/redis/redistest"
)
func TestRedisBitSet_New_Set_Test(t *testing.T) {
store := redistest.CreateRedis(t)
ctx := context.Background()
store, clean, err := redistest.CreateRedis()
assert.Nil(t, err)
defer clean()
bitSet := newRedisBitSet(store, "test_key", 1024)
isSetBefore, err := bitSet.check(ctx, []uint{0})
isSetBefore, err := bitSet.check([]uint{0})
if err != nil {
t.Fatal(err)
}
if isSetBefore {
t.Fatal("Bit should not be set")
}
err = bitSet.set(ctx, []uint{512})
err = bitSet.set([]uint{512})
if err != nil {
t.Fatal(err)
}
isSetAfter, err := bitSet.check(ctx, []uint{512})
isSetAfter, err := bitSet.check([]uint{512})
if err != nil {
t.Fatal(err)
}
@@ -43,7 +42,9 @@ func TestRedisBitSet_New_Set_Test(t *testing.T) {
}
func TestRedisBitSet_Add(t *testing.T) {
store := redistest.CreateRedis(t)
store, clean, err := redistest.CreateRedis()
assert.Nil(t, err)
defer clean()
filter := New(store, "test_key", 64)
assert.Nil(t, filter.Add([]byte("hello")))
@@ -52,51 +53,3 @@ func TestRedisBitSet_Add(t *testing.T) {
assert.Nil(t, err)
assert.True(t, ok)
}
func TestFilter_Exists(t *testing.T) {
store, clean := redistest.CreateRedisWithClean(t)
rbs := New(store, "test", 64)
_, err := rbs.Exists([]byte{0, 1, 2})
assert.NoError(t, err)
clean()
rbs = New(store, "test", 64)
_, err = rbs.Exists([]byte{0, 1, 2})
assert.Error(t, err)
}
func TestRedisBitSet_check(t *testing.T) {
store, clean := redistest.CreateRedisWithClean(t)
ctx := context.Background()
rbs := newRedisBitSet(store, "test", 0)
assert.Error(t, rbs.set(ctx, []uint{0, 1, 2}))
_, err := rbs.check(ctx, []uint{0, 1, 2})
assert.Error(t, err)
rbs = newRedisBitSet(store, "test", 64)
_, err = rbs.check(ctx, []uint{0, 1, 2})
assert.NoError(t, err)
clean()
rbs = newRedisBitSet(store, "test", 64)
_, err = rbs.check(ctx, []uint{0, 1, 2})
assert.Error(t, err)
}
func TestRedisBitSet_set(t *testing.T) {
logx.Disable()
store, clean := redistest.CreateRedisWithClean(t)
ctx := context.Background()
rbs := newRedisBitSet(store, "test", 0)
assert.Error(t, rbs.set(ctx, []uint{0, 1, 2}))
rbs = newRedisBitSet(store, "test", 64)
assert.NoError(t, rbs.set(ctx, []uint{0, 1, 2}))
clean()
rbs = newRedisBitSet(store, "test", 64)
assert.Error(t, rbs.set(ctx, []uint{0, 1, 2}))
}

View File

@@ -1,3 +0,0 @@
for _, offset in ipairs(ARGV) do
redis.call("setbit", KEYS[1], offset, 1)
end

View File

@@ -1,6 +0,0 @@
for _, offset in ipairs(ARGV) do
if tonumber(redis.call("getbit", KEYS[1], offset)) == 0 then
return false
end
end
return true

View File

@@ -1,19 +1,22 @@
package breaker
import (
"context"
"errors"
"fmt"
"strings"
"sync"
"time"
"github.com/zeromicro/go-zero/core/mathx"
"github.com/zeromicro/go-zero/core/proc"
"github.com/zeromicro/go-zero/core/stat"
"github.com/zeromicro/go-zero/core/stringx"
)
const numHistoryReasons = 5
const (
numHistoryReasons = 5
timeFormat = "15:04:05"
)
// ErrServiceUnavailable is returned when the Breaker state is open.
var ErrServiceUnavailable = errors.New("circuit breaker is open")
@@ -28,53 +31,38 @@ type (
Name() string
// Allow checks if the request is allowed.
// If allowed, a promise will be returned,
// otherwise ErrServiceUnavailable will be returned as the error.
// The caller needs to call promise.Accept() on success,
// or call promise.Reject() on failure.
// If allowed, a promise will be returned, the caller needs to call promise.Accept()
// on success, or call promise.Reject() on failure.
// If not allow, ErrServiceUnavailable will be returned.
Allow() (Promise, error)
// AllowCtx checks if the request is allowed when ctx isn't done.
AllowCtx(ctx context.Context) (Promise, error)
// Do runs the given request if the Breaker accepts it.
// Do returns an error instantly if the Breaker rejects the request.
// If a panic occurs in the request, the Breaker handles it as an error
// and causes the same panic again.
Do(req func() error) error
// DoCtx runs the given request if the Breaker accepts it when ctx isn't done.
DoCtx(ctx context.Context, req func() error) error
// DoWithAcceptable runs the given request if the Breaker accepts it.
// DoWithAcceptable returns an error instantly if the Breaker rejects the request.
// If a panic occurs in the request, the Breaker handles it as an error
// and causes the same panic again.
// acceptable checks if it's a successful call, even if the error is not nil.
// acceptable checks if it's a successful call, even if the err is not nil.
DoWithAcceptable(req func() error, acceptable Acceptable) error
// DoWithAcceptableCtx runs the given request if the Breaker accepts it when ctx isn't done.
DoWithAcceptableCtx(ctx context.Context, req func() error, acceptable Acceptable) error
// DoWithFallback runs the given request if the Breaker accepts it.
// DoWithFallback runs the fallback if the Breaker rejects the request.
// If a panic occurs in the request, the Breaker handles it as an error
// and causes the same panic again.
DoWithFallback(req func() error, fallback Fallback) error
// DoWithFallbackCtx runs the given request if the Breaker accepts it when ctx isn't done.
DoWithFallbackCtx(ctx context.Context, req func() error, fallback Fallback) error
DoWithFallback(req func() error, fallback func(err error) error) error
// DoWithFallbackAcceptable runs the given request if the Breaker accepts it.
// DoWithFallbackAcceptable runs the fallback if the Breaker rejects the request.
// If a panic occurs in the request, the Breaker handles it as an error
// and causes the same panic again.
// acceptable checks if it's a successful call, even if the error is not nil.
DoWithFallbackAcceptable(req func() error, fallback Fallback, acceptable Acceptable) error
// DoWithFallbackAcceptableCtx runs the given request if the Breaker accepts it when ctx isn't done.
DoWithFallbackAcceptableCtx(ctx context.Context, req func() error, fallback Fallback,
acceptable Acceptable) error
// acceptable checks if it's a successful call, even if the err is not nil.
DoWithFallbackAcceptable(req func() error, fallback func(err error) error, acceptable Acceptable) error
}
// Fallback is the func to be called if the request is rejected.
Fallback func(err error) error
// Option defines the method to customize a Breaker.
Option func(breaker *circuitBreaker)
@@ -98,12 +86,12 @@ type (
internalThrottle interface {
allow() (internalPromise, error)
doReq(req func() error, fallback Fallback, acceptable Acceptable) error
doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error
}
throttle interface {
allow() (Promise, error)
doReq(req func() error, fallback Fallback, acceptable Acceptable) error
doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error
}
)
@@ -126,71 +114,23 @@ func (cb *circuitBreaker) Allow() (Promise, error) {
return cb.throttle.allow()
}
func (cb *circuitBreaker) AllowCtx(ctx context.Context) (Promise, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
return cb.Allow()
}
}
func (cb *circuitBreaker) Do(req func() error) error {
return cb.throttle.doReq(req, nil, defaultAcceptable)
}
func (cb *circuitBreaker) DoCtx(ctx context.Context, req func() error) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
return cb.Do(req)
}
}
func (cb *circuitBreaker) DoWithAcceptable(req func() error, acceptable Acceptable) error {
return cb.throttle.doReq(req, nil, acceptable)
}
func (cb *circuitBreaker) DoWithAcceptableCtx(ctx context.Context, req func() error,
acceptable Acceptable) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
return cb.DoWithAcceptable(req, acceptable)
}
}
func (cb *circuitBreaker) DoWithFallback(req func() error, fallback Fallback) error {
func (cb *circuitBreaker) DoWithFallback(req func() error, fallback func(err error) error) error {
return cb.throttle.doReq(req, fallback, defaultAcceptable)
}
func (cb *circuitBreaker) DoWithFallbackCtx(ctx context.Context, req func() error,
fallback Fallback) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
return cb.DoWithFallback(req, fallback)
}
}
func (cb *circuitBreaker) DoWithFallbackAcceptable(req func() error, fallback Fallback,
func (cb *circuitBreaker) DoWithFallbackAcceptable(req func() error, fallback func(err error) error,
acceptable Acceptable) error {
return cb.throttle.doReq(req, fallback, acceptable)
}
func (cb *circuitBreaker) DoWithFallbackAcceptableCtx(ctx context.Context, req func() error,
fallback Fallback, acceptable Acceptable) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
return cb.DoWithFallbackAcceptable(req, fallback, acceptable)
}
}
func (cb *circuitBreaker) Name() string {
return cb.name
}
@@ -228,7 +168,7 @@ func (lt loggedThrottle) allow() (Promise, error) {
}, lt.logError(err)
}
func (lt loggedThrottle) doReq(req func() error, fallback Fallback, acceptable Acceptable) error {
func (lt loggedThrottle) doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error {
return lt.logError(lt.internalThrottle.doReq(req, fallback, func(err error) bool {
accept := acceptable(err)
if !accept && err != nil {
@@ -239,7 +179,7 @@ func (lt loggedThrottle) doReq(req func() error, fallback Fallback, acceptable A
}
func (lt loggedThrottle) logError(err error) error {
if errors.Is(err, ErrServiceUnavailable) {
if err == ErrServiceUnavailable {
// if circuit open, not possible to have empty error window
stat.Report(fmt.Sprintf(
"proc(%s/%d), callee: %s, breaker is open and requests dropped\nlast errors:\n%s",
@@ -258,14 +198,14 @@ type errorWindow struct {
func (ew *errorWindow) add(reason string) {
ew.lock.Lock()
ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(time.TimeOnly), reason)
ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(timeFormat), reason)
ew.index = (ew.index + 1) % numHistoryReasons
ew.count = min(ew.count+1, numHistoryReasons)
ew.count = mathx.MinInt(ew.count+1, numHistoryReasons)
ew.lock.Unlock()
}
func (ew *errorWindow) String() string {
reasons := make([]string, 0, ew.count)
var reasons []string
ew.lock.Lock()
// reverse order

View File

@@ -1,13 +1,11 @@
package breaker
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stat"
@@ -18,274 +16,10 @@ func init() {
}
func TestCircuitBreaker_Allow(t *testing.T) {
t.Run("allow", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
_, err := b.Allow()
assert.Nil(t, err)
})
t.Run("allow with ctx", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
_, err := b.AllowCtx(context.Background())
assert.Nil(t, err)
})
t.Run("allow with ctx timeout", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
defer cancel()
time.Sleep(time.Millisecond)
_, err := b.AllowCtx(ctx)
assert.ErrorIs(t, err, context.DeadlineExceeded)
})
t.Run("allow with ctx cancel", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
for i := 0; i < 100; i++ {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
cancel()
_, err := b.AllowCtx(ctx)
assert.ErrorIs(t, err, context.Canceled)
}
_, err := b.AllowCtx(context.Background())
assert.NoError(t, err)
})
}
func TestCircuitBreaker_Do(t *testing.T) {
t.Run("do", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
err := b.Do(func() error {
return nil
})
assert.Nil(t, err)
})
t.Run("do with ctx", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
err := b.DoCtx(context.Background(), func() error {
return nil
})
assert.Nil(t, err)
})
t.Run("do with ctx timeout", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
defer cancel()
time.Sleep(time.Millisecond)
err := b.DoCtx(ctx, func() error {
return nil
})
assert.ErrorIs(t, err, context.DeadlineExceeded)
})
t.Run("do with ctx cancel", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
for i := 0; i < 100; i++ {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
cancel()
err := b.DoCtx(ctx, func() error {
return nil
})
assert.ErrorIs(t, err, context.Canceled)
}
assert.NoError(t, b.DoCtx(context.Background(), func() error {
return nil
}))
})
}
func TestCircuitBreaker_DoWithAcceptable(t *testing.T) {
t.Run("doWithAcceptable", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
err := b.DoWithAcceptable(func() error {
return nil
}, func(err error) bool {
return true
})
assert.Nil(t, err)
})
t.Run("doWithAcceptable with ctx", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
err := b.DoWithAcceptableCtx(context.Background(), func() error {
return nil
}, func(err error) bool {
return true
})
assert.Nil(t, err)
})
t.Run("doWithAcceptable with ctx timeout", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
defer cancel()
time.Sleep(time.Millisecond)
err := b.DoWithAcceptableCtx(ctx, func() error {
return nil
}, func(err error) bool {
return true
})
assert.ErrorIs(t, err, context.DeadlineExceeded)
})
t.Run("doWithAcceptable with ctx cancel", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
for i := 0; i < 100; i++ {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
cancel()
err := b.DoWithAcceptableCtx(ctx, func() error {
return nil
}, func(err error) bool {
return true
})
assert.ErrorIs(t, err, context.Canceled)
}
assert.NoError(t, b.DoWithAcceptableCtx(context.Background(), func() error {
return nil
}, func(err error) bool {
return true
}))
})
}
func TestCircuitBreaker_DoWithFallback(t *testing.T) {
t.Run("doWithFallback", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
err := b.DoWithFallback(func() error {
return nil
}, func(err error) error {
return err
})
assert.Nil(t, err)
})
t.Run("doWithFallback with ctx", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
err := b.DoWithFallbackCtx(context.Background(), func() error {
return nil
}, func(err error) error {
return err
})
assert.Nil(t, err)
})
t.Run("doWithFallback with ctx timeout", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
defer cancel()
time.Sleep(time.Millisecond)
err := b.DoWithFallbackCtx(ctx, func() error {
return nil
}, func(err error) error {
return err
})
assert.ErrorIs(t, err, context.DeadlineExceeded)
})
t.Run("doWithFallback with ctx cancel", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
for i := 0; i < 100; i++ {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
cancel()
err := b.DoWithFallbackCtx(ctx, func() error {
return nil
}, func(err error) error {
return err
})
assert.ErrorIs(t, err, context.Canceled)
}
assert.NoError(t, b.DoWithFallbackCtx(context.Background(), func() error {
return nil
}, func(err error) error {
return err
}))
})
}
func TestCircuitBreaker_DoWithFallbackAcceptable(t *testing.T) {
t.Run("doWithFallbackAcceptable", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
err := b.DoWithFallbackAcceptable(func() error {
return nil
}, func(err error) error {
return err
}, func(err error) bool {
return true
})
assert.Nil(t, err)
})
t.Run("doWithFallbackAcceptable with ctx", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
err := b.DoWithFallbackAcceptableCtx(context.Background(), func() error {
return nil
}, func(err error) error {
return err
}, func(err error) bool {
return true
})
assert.Nil(t, err)
})
t.Run("doWithFallbackAcceptable with ctx timeout", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
defer cancel()
time.Sleep(time.Millisecond)
err := b.DoWithFallbackAcceptableCtx(ctx, func() error {
return nil
}, func(err error) error {
return err
}, func(err error) bool {
return true
})
assert.ErrorIs(t, err, context.DeadlineExceeded)
})
t.Run("doWithFallbackAcceptable with ctx cancel", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
for i := 0; i < 100; i++ {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
cancel()
err := b.DoWithFallbackAcceptableCtx(ctx, func() error {
return nil
}, func(err error) error {
return err
}, func(err error) bool {
return true
})
assert.ErrorIs(t, err, context.Canceled)
}
assert.NoError(t, b.DoWithFallbackAcceptableCtx(context.Background(), func() error {
return nil
}, func(err error) error {
return err
}, func(err error) bool {
return true
}))
})
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
_, err := b.Allow()
assert.Nil(t, err)
}
func TestLogReason(t *testing.T) {

View File

@@ -1,9 +1,6 @@
package breaker
import (
"context"
"sync"
)
import "sync"
var (
lock sync.RWMutex
@@ -17,13 +14,6 @@ func Do(name string, req func() error) error {
})
}
// DoCtx calls Breaker.DoCtx on the Breaker with given name.
func DoCtx(ctx context.Context, name string, req func() error) error {
return do(name, func(b Breaker) error {
return b.DoCtx(ctx, req)
})
}
// DoWithAcceptable calls Breaker.DoWithAcceptable on the Breaker with given name.
func DoWithAcceptable(name string, req func() error, acceptable Acceptable) error {
return do(name, func(b Breaker) error {
@@ -31,44 +21,21 @@ func DoWithAcceptable(name string, req func() error, acceptable Acceptable) erro
})
}
// DoWithAcceptableCtx calls Breaker.DoWithAcceptableCtx on the Breaker with given name.
func DoWithAcceptableCtx(ctx context.Context, name string, req func() error,
acceptable Acceptable) error {
return do(name, func(b Breaker) error {
return b.DoWithAcceptableCtx(ctx, req, acceptable)
})
}
// DoWithFallback calls Breaker.DoWithFallback on the Breaker with given name.
func DoWithFallback(name string, req func() error, fallback Fallback) error {
func DoWithFallback(name string, req func() error, fallback func(err error) error) error {
return do(name, func(b Breaker) error {
return b.DoWithFallback(req, fallback)
})
}
// DoWithFallbackCtx calls Breaker.DoWithFallbackCtx on the Breaker with given name.
func DoWithFallbackCtx(ctx context.Context, name string, req func() error, fallback Fallback) error {
return do(name, func(b Breaker) error {
return b.DoWithFallbackCtx(ctx, req, fallback)
})
}
// DoWithFallbackAcceptable calls Breaker.DoWithFallbackAcceptable on the Breaker with given name.
func DoWithFallbackAcceptable(name string, req func() error, fallback Fallback,
func DoWithFallbackAcceptable(name string, req func() error, fallback func(err error) error,
acceptable Acceptable) error {
return do(name, func(b Breaker) error {
return b.DoWithFallbackAcceptable(req, fallback, acceptable)
})
}
// DoWithFallbackAcceptableCtx calls Breaker.DoWithFallbackAcceptableCtx on the Breaker with given name.
func DoWithFallbackAcceptableCtx(ctx context.Context, name string, req func() error,
fallback Fallback, acceptable Acceptable) error {
return do(name, func(b Breaker) error {
return b.DoWithFallbackAcceptableCtx(ctx, req, fallback, acceptable)
})
}
// GetBreaker returns the Breaker with the given name.
func GetBreaker(name string) Breaker {
lock.RLock()
@@ -92,7 +59,7 @@ func GetBreaker(name string) Breaker {
// NoBreakerFor disables the circuit breaker for the given name.
func NoBreakerFor(name string) {
lock.Lock()
breakers[name] = NopBreaker()
breakers[name] = newNoOpBreaker()
lock.Unlock()
}

View File

@@ -1,7 +1,6 @@
package breaker
import (
"context"
"errors"
"fmt"
"testing"
@@ -23,9 +22,6 @@ func TestBreakersDo(t *testing.T) {
assert.Equal(t, errDummy, Do("any", func() error {
return errDummy
}))
assert.Equal(t, errDummy, DoCtx(context.Background(), "any", func() error {
return errDummy
}))
}
func TestBreakersDoWithAcceptable(t *testing.T) {
@@ -34,7 +30,7 @@ func TestBreakersDoWithAcceptable(t *testing.T) {
assert.Equal(t, errDummy, GetBreaker("anyone").DoWithAcceptable(func() error {
return errDummy
}, func(err error) bool {
return err == nil || errors.Is(err, errDummy)
return err == nil || err == errDummy
}))
}
verify(t, func() bool {
@@ -42,13 +38,6 @@ func TestBreakersDoWithAcceptable(t *testing.T) {
return nil
}) == nil
})
verify(t, func() bool {
return DoWithAcceptableCtx(context.Background(), "anyone", func() error {
return nil
}, func(err error) bool {
return true
}) == nil
})
for i := 0; i < 10000; i++ {
err := DoWithAcceptable("another", func() error {
@@ -56,12 +45,12 @@ func TestBreakersDoWithAcceptable(t *testing.T) {
}, func(err error) bool {
return err == nil
})
assert.True(t, errors.Is(err, errDummy) || errors.Is(err, ErrServiceUnavailable))
assert.True(t, err == errDummy || err == ErrServiceUnavailable)
}
verify(t, func() bool {
return errors.Is(Do("another", func() error {
return ErrServiceUnavailable == Do("another", func() error {
return nil
}), ErrServiceUnavailable)
})
})
}
@@ -86,24 +75,18 @@ func TestBreakersFallback(t *testing.T) {
}, func(err error) error {
return nil
})
assert.True(t, err == nil || errors.Is(err, errDummy))
err = DoWithFallbackCtx(context.Background(), "fallback", func() error {
return errDummy
}, func(err error) error {
return nil
})
assert.True(t, err == nil || errors.Is(err, errDummy))
assert.True(t, err == nil || err == errDummy)
}
verify(t, func() bool {
return errors.Is(Do("fallback", func() error {
return ErrServiceUnavailable == Do("fallback", func() error {
return nil
}), ErrServiceUnavailable)
})
})
}
func TestBreakersAcceptableFallback(t *testing.T) {
errDummy := errors.New("any")
for i := 0; i < 5000; i++ {
for i := 0; i < 10000; i++ {
err := DoWithFallbackAcceptable("acceptablefallback", func() error {
return errDummy
}, func(err error) error {
@@ -111,20 +94,12 @@ func TestBreakersAcceptableFallback(t *testing.T) {
}, func(err error) bool {
return err == nil
})
assert.True(t, err == nil || errors.Is(err, errDummy))
err = DoWithFallbackAcceptableCtx(context.Background(), "acceptablefallback", func() error {
return errDummy
}, func(err error) error {
return nil
}, func(err error) bool {
return err == nil
})
assert.True(t, err == nil || errors.Is(err, errDummy))
assert.True(t, err == nil || err == errDummy)
}
verify(t, func() bool {
return errors.Is(Do("acceptablefallback", func() error {
return ErrServiceUnavailable == Do("acceptablefallback", func() error {
return nil
}), ErrServiceUnavailable)
})
})
}
@@ -135,5 +110,5 @@ func verify(t *testing.T, fn func() bool) {
count++
}
}
assert.True(t, count >= 75, fmt.Sprintf("should be greater than 75, actual %d", count))
assert.True(t, count >= 80, fmt.Sprintf("should be greater than 80, actual %d", count))
}

View File

@@ -1,48 +0,0 @@
package breaker
const (
success = iota
fail
drop
)
// bucket defines the bucket that holds sum and num of additions.
type bucket struct {
Sum int64
Success int64
Failure int64
Drop int64
}
func (b *bucket) Add(v int64) {
switch v {
case fail:
b.fail()
case drop:
b.drop()
default:
b.succeed()
}
}
func (b *bucket) Reset() {
b.Sum = 0
b.Success = 0
b.Failure = 0
b.Drop = 0
}
func (b *bucket) drop() {
b.Sum++
b.Drop++
}
func (b *bucket) fail() {
b.Sum++
b.Failure++
}
func (b *bucket) succeed() {
b.Sum++
b.Success++
}

View File

@@ -1,43 +0,0 @@
package breaker
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestBucketAdd(t *testing.T) {
b := &bucket{}
// Test succeed
b.Add(0) // Using 0 for success
assert.Equal(t, int64(1), b.Sum, "Sum should be incremented")
assert.Equal(t, int64(1), b.Success, "Success should be incremented")
assert.Equal(t, int64(0), b.Failure, "Failure should not be incremented")
assert.Equal(t, int64(0), b.Drop, "Drop should not be incremented")
// Test failure
b.Add(fail)
assert.Equal(t, int64(2), b.Sum, "Sum should be incremented")
assert.Equal(t, int64(1), b.Failure, "Failure should be incremented")
assert.Equal(t, int64(0), b.Drop, "Drop should not be incremented")
// Test drop
b.Add(drop)
assert.Equal(t, int64(3), b.Sum, "Sum should be incremented")
assert.Equal(t, int64(1), b.Drop, "Drop should be incremented")
}
func TestBucketReset(t *testing.T) {
b := &bucket{
Sum: 3,
Success: 1,
Failure: 1,
Drop: 1,
}
b.Reset()
assert.Equal(t, int64(0), b.Sum, "Sum should be reset to 0")
assert.Equal(t, int64(0), b.Success, "Success should be reset to 0")
assert.Equal(t, int64(0), b.Failure, "Failure should be reset to 0")
assert.Equal(t, int64(0), b.Drop, "Drop should be reset to 0")
}

View File

@@ -1,87 +1,57 @@
package breaker
import (
"math"
"time"
"github.com/zeromicro/go-zero/core/collection"
"github.com/zeromicro/go-zero/core/mathx"
"github.com/zeromicro/go-zero/core/syncx"
"github.com/zeromicro/go-zero/core/timex"
)
const (
// 250ms for bucket duration
window = time.Second * 10
buckets = 40
forcePassDuration = time.Second
k = 1.5
minK = 1.1
protection = 5
window = time.Second * 10
buckets = 40
k = 1.5
protection = 5
)
// googleBreaker is a netflixBreaker pattern from google.
// see Client-Side Throttling section in https://landing.google.com/sre/sre-book/chapters/handling-overload/
type (
googleBreaker struct {
k float64
stat *collection.RollingWindow[int64, *bucket]
proba *mathx.Proba
lastPass *syncx.AtomicDuration
}
windowResult struct {
accepts int64
total int64
failingBuckets int64
workingBuckets int64
}
)
type googleBreaker struct {
k float64
stat *collection.RollingWindow
proba *mathx.Proba
}
func newGoogleBreaker() *googleBreaker {
bucketDuration := time.Duration(int64(window) / int64(buckets))
st := collection.NewRollingWindow[int64, *bucket](func() *bucket {
return new(bucket)
}, buckets, bucketDuration)
st := collection.NewRollingWindow(buckets, bucketDuration)
return &googleBreaker{
stat: st,
k: k,
proba: mathx.NewProba(),
lastPass: syncx.NewAtomicDuration(),
stat: st,
k: k,
proba: mathx.NewProba(),
}
}
func (b *googleBreaker) accept() error {
var w float64
history := b.history()
w = b.k - (b.k-minK)*float64(history.failingBuckets)/buckets
weightedAccepts := mathx.AtLeast(w, minK) * float64(history.accepts)
accepts, total := b.history()
weightedAccepts := b.k * float64(accepts)
// https://landing.google.com/sre/sre-book/chapters/handling-overload/#eq2101
// for better performance, no need to care about the negative ratio
dropRatio := (float64(history.total-protection) - weightedAccepts) / float64(history.total+1)
dropRatio := math.Max(0, (float64(total-protection)-weightedAccepts)/float64(total+1))
if dropRatio <= 0 {
return nil
}
lastPass := b.lastPass.Load()
if lastPass > 0 && timex.Since(lastPass) > forcePassDuration {
b.lastPass.Set(timex.Now())
return nil
}
dropRatio *= float64(buckets-history.workingBuckets) / buckets
if b.proba.TrueOnProba(dropRatio) {
return ErrServiceUnavailable
}
b.lastPass.Set(timex.Now())
return nil
}
func (b *googleBreaker) allow() (internalPromise, error) {
if err := b.accept(); err != nil {
b.markDrop()
return nil, err
}
@@ -90,9 +60,8 @@ func (b *googleBreaker) allow() (internalPromise, error) {
}, nil
}
func (b *googleBreaker) doReq(req func() error, fallback Fallback, acceptable Acceptable) error {
func (b *googleBreaker) doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error {
if err := b.accept(); err != nil {
b.markDrop()
if fallback != nil {
return fallback(err)
}
@@ -100,55 +69,38 @@ func (b *googleBreaker) doReq(req func() error, fallback Fallback, acceptable Ac
return err
}
var succ bool
defer func() {
// if req() panic, success is false, mark as failure
if succ {
b.markSuccess()
} else {
if e := recover(); e != nil {
b.markFailure()
panic(e)
}
}()
err := req()
if acceptable(err) {
succ = true
b.markSuccess()
} else {
b.markFailure()
}
return err
}
func (b *googleBreaker) markDrop() {
b.stat.Add(drop)
func (b *googleBreaker) markSuccess() {
b.stat.Add(1)
}
func (b *googleBreaker) markFailure() {
b.stat.Add(fail)
b.stat.Add(0)
}
func (b *googleBreaker) markSuccess() {
b.stat.Add(success)
}
func (b *googleBreaker) history() windowResult {
var result windowResult
b.stat.Reduce(func(b *bucket) {
result.accepts += b.Success
result.total += b.Sum
if b.Failure > 0 {
result.workingBuckets = 0
} else if b.Success > 0 {
result.workingBuckets++
}
if b.Success > 0 {
result.failingBuckets = 0
} else if b.Failure > 0 {
result.failingBuckets++
}
func (b *googleBreaker) history() (accepts, total int64) {
b.stat.Reduce(func(b *collection.Bucket) {
accepts += int64(b.Sum)
total += b.Count
})
return result
return
}
type googlePromise struct {

View File

@@ -10,7 +10,6 @@ import (
"github.com/zeromicro/go-zero/core/collection"
"github.com/zeromicro/go-zero/core/mathx"
"github.com/zeromicro/go-zero/core/stat"
"github.com/zeromicro/go-zero/core/syncx"
)
const (
@@ -23,14 +22,11 @@ func init() {
}
func getGoogleBreaker() *googleBreaker {
st := collection.NewRollingWindow[int64, *bucket](func() *bucket {
return new(bucket)
}, testBuckets, testInterval)
st := collection.NewRollingWindow(testBuckets, testInterval)
return &googleBreaker{
stat: st,
k: 5,
proba: mathx.NewProba(),
lastPass: syncx.NewAtomicDuration(),
stat: st,
k: 5,
proba: mathx.NewProba(),
}
}
@@ -67,33 +63,6 @@ func TestGoogleBreakerOpen(t *testing.T) {
})
}
func TestGoogleBreakerRecover(t *testing.T) {
st := collection.NewRollingWindow[int64, *bucket](func() *bucket {
return new(bucket)
}, testBuckets*2, testInterval)
b := &googleBreaker{
stat: st,
k: k,
proba: mathx.NewProba(),
lastPass: syncx.NewAtomicDuration(),
}
for i := 0; i < testBuckets; i++ {
for j := 0; j < 100; j++ {
b.stat.Add(1)
}
time.Sleep(testInterval)
}
for i := 0; i < testBuckets; i++ {
for j := 0; j < 100; j++ {
b.stat.Add(0)
}
time.Sleep(testInterval)
}
verify(t, func() bool {
return b.accept() == nil
})
}
func TestGoogleBreakerFallback(t *testing.T) {
b := getGoogleBreaker()
markSuccess(b, 1)
@@ -120,50 +89,13 @@ func TestGoogleBreakerReject(t *testing.T) {
}, nil, defaultAcceptable))
}
func TestGoogleBreakerMoreFallingBuckets(t *testing.T) {
t.Parallel()
t.Run("more falling buckets", func(t *testing.T) {
b := getGoogleBreaker()
func() {
stopChan := time.After(testInterval * 6)
for {
time.Sleep(time.Millisecond)
select {
case <-stopChan:
return
default:
assert.Error(t, b.doReq(func() error {
return errors.New("foo")
}, func(err error) error {
return err
}, func(err error) bool {
return err == nil
}))
}
}
}()
var count int
for i := 0; i < 100; i++ {
if errors.Is(b.doReq(func() error {
return ErrServiceUnavailable
}, nil, defaultAcceptable), ErrServiceUnavailable) {
count++
}
}
assert.True(t, count > 90)
})
}
func TestGoogleBreakerAcceptable(t *testing.T) {
b := getGoogleBreaker()
errAcceptable := errors.New("any")
assert.Equal(t, errAcceptable, b.doReq(func() error {
return errAcceptable
}, nil, func(err error) bool {
return errors.Is(err, errAcceptable)
return err == errAcceptable
}))
}
@@ -173,7 +105,7 @@ func TestGoogleBreakerNotAcceptable(t *testing.T) {
assert.Equal(t, errAcceptable, b.doReq(func() error {
return errAcceptable
}, nil, func(err error) bool {
return !errors.Is(err, errAcceptable)
return err != errAcceptable
}))
}
@@ -232,38 +164,41 @@ func TestGoogleBreakerSelfProtection(t *testing.T) {
}
func TestGoogleBreakerHistory(t *testing.T) {
var b *googleBreaker
var accepts, total int64
sleep := testInterval
t.Run("accepts == total", func(t *testing.T) {
b := getGoogleBreaker()
b = getGoogleBreaker()
markSuccessWithDuration(b, 10, sleep/2)
result := b.history()
assert.Equal(t, int64(10), result.accepts)
assert.Equal(t, int64(10), result.total)
accepts, total = b.history()
assert.Equal(t, int64(10), accepts)
assert.Equal(t, int64(10), total)
})
t.Run("fail == total", func(t *testing.T) {
b := getGoogleBreaker()
b = getGoogleBreaker()
markFailedWithDuration(b, 10, sleep/2)
result := b.history()
assert.Equal(t, int64(0), result.accepts)
assert.Equal(t, int64(10), result.total)
accepts, total = b.history()
assert.Equal(t, int64(0), accepts)
assert.Equal(t, int64(10), total)
})
t.Run("accepts = 1/2 * total, fail = 1/2 * total", func(t *testing.T) {
b := getGoogleBreaker()
b = getGoogleBreaker()
markFailedWithDuration(b, 5, sleep/2)
markSuccessWithDuration(b, 5, sleep/2)
result := b.history()
assert.Equal(t, int64(5), result.accepts)
assert.Equal(t, int64(10), result.total)
accepts, total = b.history()
assert.Equal(t, int64(5), accepts)
assert.Equal(t, int64(10), total)
})
t.Run("auto reset rolling counter", func(t *testing.T) {
b := getGoogleBreaker()
b = getGoogleBreaker()
time.Sleep(testInterval * testBuckets)
result := b.history()
assert.Equal(t, int64(0), result.accepts)
assert.Equal(t, int64(0), result.total)
accepts, total = b.history()
assert.Equal(t, int64(0), accepts)
assert.Equal(t, int64(0), total)
})
}
@@ -271,7 +206,7 @@ func BenchmarkGoogleBreakerAllow(b *testing.B) {
breaker := getGoogleBreaker()
b.ResetTimer()
for i := 0; i <= b.N; i++ {
_ = breaker.accept()
breaker.accept()
if i%2 == 0 {
breaker.markSuccess()
} else {
@@ -280,16 +215,6 @@ func BenchmarkGoogleBreakerAllow(b *testing.B) {
}
}
func BenchmarkGoogleBreakerDoReq(b *testing.B) {
breaker := getGoogleBreaker()
b.ResetTimer()
for i := 0; i <= b.N; i++ {
_ = breaker.doReq(func() error {
return nil
}, nil, defaultAcceptable)
}
}
func markSuccess(b *googleBreaker, count int) {
for i := 0; i < count; i++ {
p, err := b.allow()

View File

@@ -1,58 +1,35 @@
package breaker
import "context"
const noOpBreakerName = "nopBreaker"
const nopBreakerName = "nopBreaker"
type noOpBreaker struct{}
type nopBreaker struct{}
// NopBreaker returns a breaker that never trigger breaker circuit.
func NopBreaker() Breaker {
return nopBreaker{}
func newNoOpBreaker() Breaker {
return noOpBreaker{}
}
func (b nopBreaker) Name() string {
return nopBreakerName
func (b noOpBreaker) Name() string {
return noOpBreakerName
}
func (b nopBreaker) Allow() (Promise, error) {
func (b noOpBreaker) Allow() (Promise, error) {
return nopPromise{}, nil
}
func (b nopBreaker) AllowCtx(_ context.Context) (Promise, error) {
return nopPromise{}, nil
}
func (b nopBreaker) Do(req func() error) error {
func (b noOpBreaker) Do(req func() error) error {
return req()
}
func (b nopBreaker) DoCtx(_ context.Context, req func() error) error {
func (b noOpBreaker) DoWithAcceptable(req func() error, _ Acceptable) error {
return req()
}
func (b nopBreaker) DoWithAcceptable(req func() error, _ Acceptable) error {
func (b noOpBreaker) DoWithFallback(req func() error, _ func(err error) error) error {
return req()
}
func (b nopBreaker) DoWithAcceptableCtx(_ context.Context, req func() error, _ Acceptable) error {
return req()
}
func (b nopBreaker) DoWithFallback(req func() error, _ Fallback) error {
return req()
}
func (b nopBreaker) DoWithFallbackCtx(_ context.Context, req func() error, _ Fallback) error {
return req()
}
func (b nopBreaker) DoWithFallbackAcceptable(req func() error, _ Fallback, _ Acceptable) error {
return req()
}
func (b nopBreaker) DoWithFallbackAcceptableCtx(_ context.Context, req func() error,
_ Fallback, _ Acceptable) error {
func (b noOpBreaker) DoWithFallbackAcceptable(req func() error, _ func(err error) error,
_ Acceptable) error {
return req()
}

View File

@@ -1,7 +1,6 @@
package breaker
import (
"context"
"errors"
"testing"
@@ -9,11 +8,9 @@ import (
)
func TestNopBreaker(t *testing.T) {
b := NopBreaker()
assert.Equal(t, nopBreakerName, b.Name())
_, err := b.Allow()
assert.Nil(t, err)
p, err := b.AllowCtx(context.Background())
b := newNoOpBreaker()
assert.Equal(t, noOpBreakerName, b.Name())
p, err := b.Allow()
assert.Nil(t, err)
p.Accept()
for i := 0; i < 1000; i++ {
@@ -24,34 +21,18 @@ func TestNopBreaker(t *testing.T) {
assert.Nil(t, b.Do(func() error {
return nil
}))
assert.Nil(t, b.DoCtx(context.Background(), func() error {
return nil
}))
assert.Nil(t, b.DoWithAcceptable(func() error {
return nil
}, defaultAcceptable))
assert.Nil(t, b.DoWithAcceptableCtx(context.Background(), func() error {
return nil
}, defaultAcceptable))
errDummy := errors.New("any")
assert.Equal(t, errDummy, b.DoWithFallback(func() error {
return errDummy
}, func(err error) error {
return nil
}))
assert.Equal(t, errDummy, b.DoWithFallbackCtx(context.Background(), func() error {
return errDummy
}, func(err error) error {
return nil
}))
assert.Equal(t, errDummy, b.DoWithFallbackAcceptable(func() error {
return errDummy
}, func(err error) error {
return nil
}, defaultAcceptable))
assert.Equal(t, errDummy, b.DoWithFallbackAcceptableCtx(context.Background(), func() error {
return errDummy
}, func(err error) error {
return nil
}, defaultAcceptable))
}

View File

@@ -23,7 +23,7 @@ var (
zero = big.NewInt(0)
)
// DhKey defines the Diffie-Hellman key.
// DhKey defines the Diffie Hellman key.
type DhKey struct {
PriKey *big.Int
PubKey *big.Int
@@ -46,7 +46,7 @@ func ComputeKey(pubKey, priKey *big.Int) (*big.Int, error) {
return new(big.Int).Exp(pubKey, priKey, p), nil
}
// GenerateKey returns a Diffie-Hellman key.
// GenerateKey returns a Diffie Hellman key.
func GenerateKey() (*DhKey, error) {
var err error
var x *big.Int

View File

@@ -2,8 +2,6 @@ package codec
import (
"bytes"
"compress/gzip"
"errors"
"fmt"
"testing"
@@ -23,45 +21,3 @@ func TestGzip(t *testing.T) {
assert.True(t, len(bs) < buf.Len())
assert.Equal(t, buf.Bytes(), actual)
}
func TestGunzip(t *testing.T) {
tests := []struct {
name string
input []byte
expected []byte
expectedErr error
}{
{
name: "valid input",
input: func() []byte {
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
gz.Write([]byte("hello"))
gz.Close()
return buf.Bytes()
}(),
expected: []byte("hello"),
expectedErr: nil,
},
{
name: "invalid input",
input: []byte("invalid input"),
expected: nil,
expectedErr: gzip.ErrHeader,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result, err := Gunzip(test.input)
if !bytes.Equal(result, test.expected) {
t.Errorf("unexpected result: %v", result)
}
if !errors.Is(err, test.expectedErr) {
t.Errorf("unexpected error: %v", err)
}
})
}
}

View File

@@ -2,7 +2,6 @@ package codec
import (
"encoding/base64"
"os"
"testing"
"github.com/stretchr/testify/assert"
@@ -42,7 +41,6 @@ func TestCryption(t *testing.T) {
file, err := fs.TempFilenameWithText(priKey)
assert.Nil(t, err)
defer os.Remove(file)
dec, err := NewRsaDecrypter(file)
assert.Nil(t, err)
actual, err := dec.Decrypt(ret)

View File

@@ -30,7 +30,7 @@ type (
Cache struct {
name string
lock sync.Mutex
data map[string]any
data map[string]interface{}
expire time.Duration
timingWheel *TimingWheel
lruCache lru
@@ -43,7 +43,7 @@ type (
// NewCache returns a Cache with given expire.
func NewCache(expire time.Duration, opts ...CacheOption) (*Cache, error) {
cache := &Cache{
data: make(map[string]any),
data: make(map[string]interface{}),
expire: expire,
lruCache: emptyLruCache,
barrier: syncx.NewSingleFlight(),
@@ -59,7 +59,7 @@ func NewCache(expire time.Duration, opts ...CacheOption) (*Cache, error) {
}
cache.stats = newCacheStat(cache.name, cache.size)
timingWheel, err := NewTimingWheel(time.Second, slots, func(k, v any) {
timingWheel, err := NewTimingWheel(time.Second, slots, func(k, v interface{}) {
key, ok := k.(string)
if !ok {
return
@@ -81,15 +81,11 @@ func (c *Cache) Del(key string) {
delete(c.data, key)
c.lruCache.remove(key)
c.lock.Unlock()
// RemoveTimer is called outside the lock to avoid performance impact from this
// potentially time-consuming operation. Data integrity is maintained by lruCache,
// which will eventually evict any remaining entries when capacity is exceeded.
c.timingWheel.RemoveTimer(key)
}
// Get returns the item with the given key from c.
func (c *Cache) Get(key string) (any, bool) {
func (c *Cache) Get(key string) (interface{}, bool) {
value, ok := c.doGet(key)
if ok {
c.stats.IncrementHit()
@@ -101,12 +97,12 @@ func (c *Cache) Get(key string) (any, bool) {
}
// Set sets value into c with key.
func (c *Cache) Set(key string, value any) {
func (c *Cache) Set(key string, value interface{}) {
c.SetWithExpire(key, value, c.expire)
}
// SetWithExpire sets value into c with key and expire with the given value.
func (c *Cache) SetWithExpire(key string, value any, expire time.Duration) {
func (c *Cache) SetWithExpire(key string, value interface{}, expire time.Duration) {
c.lock.Lock()
_, ok := c.data[key]
c.data[key] = value
@@ -124,16 +120,16 @@ func (c *Cache) SetWithExpire(key string, value any, expire time.Duration) {
// Take returns the item with the given key.
// If the item is in c, return it directly.
// If not, use fetch method to get the item, set into c and return it.
func (c *Cache) Take(key string, fetch func() (any, error)) (any, error) {
func (c *Cache) Take(key string, fetch func() (interface{}, error)) (interface{}, error) {
if val, ok := c.doGet(key); ok {
c.stats.IncrementHit()
return val, nil
}
var fresh bool
val, err := c.barrier.Do(key, func() (any, error) {
// because O(1) on map search in memory, and fetch is an IO query,
// so we do double-check, cache might be taken by another call
val, err := c.barrier.Do(key, func() (interface{}, error) {
// because O(1) on map search in memory, and fetch is an IO query
// so we do double check, cache might be taken by another call
if val, ok := c.doGet(key); ok {
return val, nil
}
@@ -161,7 +157,7 @@ func (c *Cache) Take(key string, fetch func() (any, error)) (any, error) {
return val, nil
}
func (c *Cache) doGet(key string) (any, bool) {
func (c *Cache) doGet(key string) (interface{}, bool) {
c.lock.Lock()
defer c.lock.Unlock()

View File

@@ -52,7 +52,7 @@ func TestCacheTake(t *testing.T) {
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
cache.Take("first", func() (any, error) {
cache.Take("first", func() (interface{}, error) {
atomic.AddInt32(&count, 1)
time.Sleep(time.Millisecond * 100)
return "first element", nil
@@ -76,7 +76,7 @@ func TestCacheTakeExists(t *testing.T) {
wg.Add(1)
go func() {
cache.Set("first", "first element")
cache.Take("first", func() (any, error) {
cache.Take("first", func() (interface{}, error) {
atomic.AddInt32(&count, 1)
time.Sleep(time.Millisecond * 100)
return "first element", nil
@@ -99,7 +99,7 @@ func TestCacheTakeError(t *testing.T) {
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
_, err := cache.Take("first", func() (any, error) {
_, err := cache.Take("first", func() (interface{}, error) {
atomic.AddInt32(&count, 1)
time.Sleep(time.Millisecond * 100)
return "", errDummy

View File

@@ -5,7 +5,7 @@ import "sync"
// A Queue is a FIFO queue.
type Queue struct {
lock sync.Mutex
elements []any
elements []interface{}
size int
head int
tail int
@@ -15,7 +15,7 @@ type Queue struct {
// NewQueue returns a Queue object.
func NewQueue(size int) *Queue {
return &Queue{
elements: make([]any, size),
elements: make([]interface{}, size),
size: size,
}
}
@@ -30,12 +30,12 @@ func (q *Queue) Empty() bool {
}
// Put puts element into q at the last position.
func (q *Queue) Put(element any) {
func (q *Queue) Put(element interface{}) {
q.lock.Lock()
defer q.lock.Unlock()
if q.head == q.tail && q.count > 0 {
nodes := make([]any, len(q.elements)+q.size)
nodes := make([]interface{}, len(q.elements)+q.size)
copy(nodes, q.elements[q.head:])
copy(nodes[len(q.elements)-q.head:], q.elements[:q.head])
q.head = 0
@@ -49,7 +49,7 @@ func (q *Queue) Put(element any) {
}
// Take takes the first element out of q if not empty.
func (q *Queue) Take() (any, bool) {
func (q *Queue) Take() (interface{}, bool) {
q.lock.Lock()
defer q.lock.Unlock()

View File

@@ -4,7 +4,7 @@ import "sync"
// A Ring can be used as fixed size ring.
type Ring struct {
elements []any
elements []interface{}
index int
lock sync.RWMutex
}
@@ -16,44 +16,36 @@ func NewRing(n int) *Ring {
}
return &Ring{
elements: make([]any, n),
elements: make([]interface{}, n),
}
}
// Add adds v into r.
func (r *Ring) Add(v any) {
func (r *Ring) Add(v interface{}) {
r.lock.Lock()
defer r.lock.Unlock()
rlen := len(r.elements)
r.elements[r.index%rlen] = v
r.elements[r.index%len(r.elements)] = v
r.index++
// prevent ring index overflow
if r.index >= rlen<<1 {
r.index -= rlen
}
}
// Take takes all items from r.
func (r *Ring) Take() []any {
func (r *Ring) Take() []interface{} {
r.lock.RLock()
defer r.lock.RUnlock()
var size int
var start int
rlen := len(r.elements)
if r.index > rlen {
size = rlen
start = r.index % rlen
if r.index > len(r.elements) {
size = len(r.elements)
start = r.index % len(r.elements)
} else {
size = r.index
}
elements := make([]any, size)
elements := make([]interface{}, size)
for i := 0; i < size; i++ {
elements[i] = r.elements[(start+i)%rlen]
elements[i] = r.elements[(start+i)%len(r.elements)]
}
return elements

View File

@@ -19,7 +19,7 @@ func TestRingLess(t *testing.T) {
ring.Add(i)
}
elements := ring.Take()
assert.ElementsMatch(t, []any{0, 1, 2}, elements)
assert.ElementsMatch(t, []interface{}{0, 1, 2}, elements)
}
func TestRingMore(t *testing.T) {
@@ -28,7 +28,7 @@ func TestRingMore(t *testing.T) {
ring.Add(i)
}
elements := ring.Take()
assert.ElementsMatch(t, []any{6, 7, 8, 9, 10}, elements)
assert.ElementsMatch(t, []interface{}{6, 7, 8, 9, 10}, elements)
}
func TestRingAdd(t *testing.T) {

View File

@@ -4,28 +4,18 @@ import (
"sync"
"time"
"github.com/zeromicro/go-zero/core/mathx"
"github.com/zeromicro/go-zero/core/timex"
)
type (
// BucketInterface is the interface that defines the buckets.
BucketInterface[T Numerical] interface {
Add(v T)
Reset()
}
// Numerical is the interface that restricts the numerical type.
Numerical = mathx.Numerical
// RollingWindowOption let callers customize the RollingWindow.
RollingWindowOption[T Numerical, B BucketInterface[T]] func(rollingWindow *RollingWindow[T, B])
RollingWindowOption func(rollingWindow *RollingWindow)
// RollingWindow defines a rolling window to calculate the events in buckets with the time interval.
RollingWindow[T Numerical, B BucketInterface[T]] struct {
// RollingWindow defines a rolling window to calculate the events in buckets with time interval.
RollingWindow struct {
lock sync.RWMutex
size int
win *window[T, B]
win *window
interval time.Duration
offset int
ignoreCurrent bool
@@ -35,15 +25,14 @@ type (
// NewRollingWindow returns a RollingWindow that with size buckets and time interval,
// use opts to customize the RollingWindow.
func NewRollingWindow[T Numerical, B BucketInterface[T]](newBucket func() B, size int,
interval time.Duration, opts ...RollingWindowOption[T, B]) *RollingWindow[T, B] {
func NewRollingWindow(size int, interval time.Duration, opts ...RollingWindowOption) *RollingWindow {
if size < 1 {
panic("size must be greater than 0")
}
w := &RollingWindow[T, B]{
w := &RollingWindow{
size: size,
win: newWindow[T, B](newBucket, size),
win: newWindow(size),
interval: interval,
lastTime: timex.Now(),
}
@@ -54,7 +43,7 @@ func NewRollingWindow[T Numerical, B BucketInterface[T]](newBucket func() B, siz
}
// Add adds value to current bucket.
func (rw *RollingWindow[T, B]) Add(v T) {
func (rw *RollingWindow) Add(v float64) {
rw.lock.Lock()
defer rw.lock.Unlock()
rw.updateOffset()
@@ -62,13 +51,13 @@ func (rw *RollingWindow[T, B]) Add(v T) {
}
// Reduce runs fn on all buckets, ignore current bucket if ignoreCurrent was set.
func (rw *RollingWindow[T, B]) Reduce(fn func(b B)) {
func (rw *RollingWindow) Reduce(fn func(b *Bucket)) {
rw.lock.RLock()
defer rw.lock.RUnlock()
var diff int
span := rw.span()
// ignore the current bucket, because of partial data
// ignore current bucket, because of partial data
if span == 0 && rw.ignoreCurrent {
diff = rw.size - 1
} else {
@@ -80,7 +69,7 @@ func (rw *RollingWindow[T, B]) Reduce(fn func(b B)) {
}
}
func (rw *RollingWindow[T, B]) span() int {
func (rw *RollingWindow) span() int {
offset := int(timex.Since(rw.lastTime) / rw.interval)
if 0 <= offset && offset < rw.size {
return offset
@@ -89,7 +78,7 @@ func (rw *RollingWindow[T, B]) span() int {
return rw.size
}
func (rw *RollingWindow[T, B]) updateOffset() {
func (rw *RollingWindow) updateOffset() {
span := rw.span()
if span <= 0 {
return
@@ -108,54 +97,54 @@ func (rw *RollingWindow[T, B]) updateOffset() {
}
// Bucket defines the bucket that holds sum and num of additions.
type Bucket[T Numerical] struct {
Sum T
type Bucket struct {
Sum float64
Count int64
}
func (b *Bucket[T]) Add(v T) {
func (b *Bucket) add(v float64) {
b.Sum += v
b.Count++
}
func (b *Bucket[T]) Reset() {
func (b *Bucket) reset() {
b.Sum = 0
b.Count = 0
}
type window[T Numerical, B BucketInterface[T]] struct {
buckets []B
type window struct {
buckets []*Bucket
size int
}
func newWindow[T Numerical, B BucketInterface[T]](newBucket func() B, size int) *window[T, B] {
buckets := make([]B, size)
func newWindow(size int) *window {
buckets := make([]*Bucket, size)
for i := 0; i < size; i++ {
buckets[i] = newBucket()
buckets[i] = new(Bucket)
}
return &window[T, B]{
return &window{
buckets: buckets,
size: size,
}
}
func (w *window[T, B]) add(offset int, v T) {
w.buckets[offset%w.size].Add(v)
func (w *window) add(offset int, v float64) {
w.buckets[offset%w.size].add(v)
}
func (w *window[T, B]) reduce(start, count int, fn func(b B)) {
func (w *window) reduce(start, count int, fn func(b *Bucket)) {
for i := 0; i < count; i++ {
fn(w.buckets[(start+i)%w.size])
}
}
func (w *window[T, B]) resetBucket(offset int) {
w.buckets[offset%w.size].Reset()
func (w *window) resetBucket(offset int) {
w.buckets[offset%w.size].reset()
}
// IgnoreCurrentBucket lets the Reduce call ignore current bucket.
func IgnoreCurrentBucket[T Numerical, B BucketInterface[T]]() RollingWindowOption[T, B] {
return func(w *RollingWindow[T, B]) {
func IgnoreCurrentBucket() RollingWindowOption {
return func(w *RollingWindow) {
w.ignoreCurrent = true
}
}

View File

@@ -12,24 +12,18 @@ import (
const duration = time.Millisecond * 50
func TestNewRollingWindow(t *testing.T) {
assert.NotNil(t, NewRollingWindow[int64, *Bucket[int64]](func() *Bucket[int64] {
return new(Bucket[int64])
}, 10, time.Second))
assert.NotNil(t, NewRollingWindow(10, time.Second))
assert.Panics(t, func() {
NewRollingWindow[int64, *Bucket[int64]](func() *Bucket[int64] {
return new(Bucket[int64])
}, 0, time.Second)
NewRollingWindow(0, time.Second)
})
}
func TestRollingWindowAdd(t *testing.T) {
const size = 3
r := NewRollingWindow[float64, *Bucket[float64]](func() *Bucket[float64] {
return new(Bucket[float64])
}, size, duration)
r := NewRollingWindow(size, duration)
listBuckets := func() []float64 {
var buckets []float64
r.Reduce(func(b *Bucket[float64]) {
r.Reduce(func(b *Bucket) {
buckets = append(buckets, b.Sum)
})
return buckets
@@ -53,12 +47,10 @@ func TestRollingWindowAdd(t *testing.T) {
func TestRollingWindowReset(t *testing.T) {
const size = 3
r := NewRollingWindow[float64, *Bucket[float64]](func() *Bucket[float64] {
return new(Bucket[float64])
}, size, duration, IgnoreCurrentBucket[float64, *Bucket[float64]]())
r := NewRollingWindow(size, duration, IgnoreCurrentBucket())
listBuckets := func() []float64 {
var buckets []float64
r.Reduce(func(b *Bucket[float64]) {
r.Reduce(func(b *Bucket) {
buckets = append(buckets, b.Sum)
})
return buckets
@@ -80,19 +72,15 @@ func TestRollingWindowReset(t *testing.T) {
func TestRollingWindowReduce(t *testing.T) {
const size = 4
tests := []struct {
win *RollingWindow[float64, *Bucket[float64]]
win *RollingWindow
expect float64
}{
{
win: NewRollingWindow[float64, *Bucket[float64]](func() *Bucket[float64] {
return new(Bucket[float64])
}, size, duration),
win: NewRollingWindow(size, duration),
expect: 10,
},
{
win: NewRollingWindow[float64, *Bucket[float64]](func() *Bucket[float64] {
return new(Bucket[float64])
}, size, duration, IgnoreCurrentBucket[float64, *Bucket[float64]]()),
win: NewRollingWindow(size, duration, IgnoreCurrentBucket()),
expect: 4,
},
}
@@ -109,7 +97,7 @@ func TestRollingWindowReduce(t *testing.T) {
}
}
var result float64
r.Reduce(func(b *Bucket[float64]) {
r.Reduce(func(b *Bucket) {
result += b.Sum
})
assert.Equal(t, test.expect, result)
@@ -120,12 +108,10 @@ func TestRollingWindowReduce(t *testing.T) {
func TestRollingWindowBucketTimeBoundary(t *testing.T) {
const size = 3
interval := time.Millisecond * 30
r := NewRollingWindow[float64, *Bucket[float64]](func() *Bucket[float64] {
return new(Bucket[float64])
}, size, interval)
r := NewRollingWindow(size, interval)
listBuckets := func() []float64 {
var buckets []float64
r.Reduce(func(b *Bucket[float64]) {
r.Reduce(func(b *Bucket) {
buckets = append(buckets, b.Sum)
})
return buckets
@@ -152,9 +138,7 @@ func TestRollingWindowBucketTimeBoundary(t *testing.T) {
func TestRollingWindowDataRace(t *testing.T) {
const size = 3
r := NewRollingWindow[float64, *Bucket[float64]](func() *Bucket[float64] {
return new(Bucket[float64])
}, size, duration)
r := NewRollingWindow(size, duration)
stop := make(chan bool)
go func() {
for {
@@ -173,7 +157,7 @@ func TestRollingWindowDataRace(t *testing.T) {
case <-stop:
return
default:
r.Reduce(func(b *Bucket[float64]) {})
r.Reduce(func(b *Bucket) {})
}
}
}()

View File

@@ -14,23 +14,21 @@ type SafeMap struct {
lock sync.RWMutex
deletionOld int
deletionNew int
dirtyOld map[any]any
dirtyNew map[any]any
dirtyOld map[interface{}]interface{}
dirtyNew map[interface{}]interface{}
}
// NewSafeMap returns a SafeMap.
func NewSafeMap() *SafeMap {
return &SafeMap{
dirtyOld: make(map[any]any),
dirtyNew: make(map[any]any),
dirtyOld: make(map[interface{}]interface{}),
dirtyNew: make(map[interface{}]interface{}),
}
}
// Del deletes the value with the given key from m.
func (m *SafeMap) Del(key any) {
func (m *SafeMap) Del(key interface{}) {
m.lock.Lock()
defer m.lock.Unlock()
if _, ok := m.dirtyOld[key]; ok {
delete(m.dirtyOld, key)
m.deletionOld++
@@ -44,20 +42,21 @@ func (m *SafeMap) Del(key any) {
}
m.dirtyOld = m.dirtyNew
m.deletionOld = m.deletionNew
m.dirtyNew = make(map[any]any)
m.dirtyNew = make(map[interface{}]interface{})
m.deletionNew = 0
}
if m.deletionNew >= maxDeletion && len(m.dirtyNew) < copyThreshold {
for k, v := range m.dirtyNew {
m.dirtyOld[k] = v
}
m.dirtyNew = make(map[any]any)
m.dirtyNew = make(map[interface{}]interface{})
m.deletionNew = 0
}
m.lock.Unlock()
}
// Get gets the value with the given key from m.
func (m *SafeMap) Get(key any) (any, bool) {
func (m *SafeMap) Get(key interface{}) (interface{}, bool) {
m.lock.RLock()
defer m.lock.RUnlock()
@@ -71,7 +70,7 @@ func (m *SafeMap) Get(key any) (any, bool) {
// Range calls f sequentially for each key and value present in the map.
// If f returns false, range stops the iteration.
func (m *SafeMap) Range(f func(key, val any) bool) {
func (m *SafeMap) Range(f func(key, val interface{}) bool) {
m.lock.RLock()
defer m.lock.RUnlock()
@@ -88,10 +87,8 @@ func (m *SafeMap) Range(f func(key, val any) bool) {
}
// Set sets the value into m with the given key.
func (m *SafeMap) Set(key, value any) {
func (m *SafeMap) Set(key, value interface{}) {
m.lock.Lock()
defer m.lock.Unlock()
if m.deletionOld <= maxDeletion {
if _, ok := m.dirtyNew[key]; ok {
delete(m.dirtyNew, key)
@@ -105,6 +102,7 @@ func (m *SafeMap) Set(key, value any) {
}
m.dirtyNew[key] = value
}
m.lock.Unlock()
}
// Size returns the size of m.

View File

@@ -138,7 +138,7 @@ func TestSafeMap_Range(t *testing.T) {
}
var count int32
m.Range(func(k, v any) bool {
m.Range(func(k, v interface{}) bool {
atomic.AddInt32(&count, 1)
newMap.Set(k, v)
return true
@@ -147,65 +147,3 @@ func TestSafeMap_Range(t *testing.T) {
assert.Equal(t, m.dirtyNew, newMap.dirtyNew)
assert.Equal(t, m.dirtyOld, newMap.dirtyOld)
}
func TestSetManyTimes(t *testing.T) {
const iteration = maxDeletion * 2
m := NewSafeMap()
for i := 0; i < iteration; i++ {
m.Set(i, i)
if i%3 == 0 {
m.Del(i / 2)
}
}
var count int
m.Range(func(k, v any) bool {
count++
return count < maxDeletion/2
})
assert.Equal(t, maxDeletion/2, count)
for i := 0; i < iteration; i++ {
m.Set(i, i)
if i%3 == 0 {
m.Del(i / 2)
}
}
for i := 0; i < iteration; i++ {
m.Set(i, i)
if i%3 == 0 {
m.Del(i / 2)
}
}
for i := 0; i < iteration; i++ {
m.Set(i, i)
if i%3 == 0 {
m.Del(i / 2)
}
}
count = 0
m.Range(func(k, v any) bool {
count++
return count < maxDeletion
})
assert.Equal(t, maxDeletion, count)
}
func TestSetManyTimesNew(t *testing.T) {
m := NewSafeMap()
for i := 0; i < maxDeletion*3; i++ {
m.Set(i, i)
}
for i := 0; i < maxDeletion*2; i++ {
m.Del(i)
}
for i := 0; i < maxDeletion*3; i++ {
m.Set(i+maxDeletion*3, i+maxDeletion*3)
}
for i := 0; i < maxDeletion*2; i++ {
m.Del(i + maxDeletion*2)
}
for i := 0; i < maxDeletion-copyThreshold+1; i++ {
m.Del(i + maxDeletion*2)
}
assert.Equal(t, 0, len(m.dirtyNew))
}

View File

@@ -1,53 +1,235 @@
package collection
import "github.com/zeromicro/go-zero/core/lang"
import (
"github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/logx"
)
// Set is a type-safe generic set collection.
// It's not thread-safe, use with synchronization for concurrent access.
type Set[T comparable] struct {
data map[T]lang.PlaceholderType
const (
unmanaged = iota
untyped
intType
int64Type
uintType
uint64Type
stringType
)
// Set is not thread-safe, for concurrent use, make sure to use it with synchronization.
type Set struct {
data map[interface{}]lang.PlaceholderType
tp int
}
// NewSet returns a new type-safe set.
func NewSet[T comparable]() *Set[T] {
return &Set[T]{
data: make(map[T]lang.PlaceholderType),
// NewSet returns a managed Set, can only put the values with the same type.
func NewSet() *Set {
return &Set{
data: make(map[interface{}]lang.PlaceholderType),
tp: untyped,
}
}
// Add adds items to the set. Duplicates are automatically ignored.
func (s *Set[T]) Add(items ...T) {
for _, item := range items {
s.data[item] = lang.Placeholder
// NewUnmanagedSet returns an unmanaged Set, which can put values with different types.
func NewUnmanagedSet() *Set {
return &Set{
data: make(map[interface{}]lang.PlaceholderType),
tp: unmanaged,
}
}
// Clear removes all items from the set.
func (s *Set[T]) Clear() {
clear(s.data)
// Add adds i into s.
func (s *Set) Add(i ...interface{}) {
for _, each := range i {
s.add(each)
}
}
// Contains checks if an item exists in the set.
func (s *Set[T]) Contains(item T) bool {
_, ok := s.data[item]
// AddInt adds int values ii into s.
func (s *Set) AddInt(ii ...int) {
for _, each := range ii {
s.add(each)
}
}
// AddInt64 adds int64 values ii into s.
func (s *Set) AddInt64(ii ...int64) {
for _, each := range ii {
s.add(each)
}
}
// AddUint adds uint values ii into s.
func (s *Set) AddUint(ii ...uint) {
for _, each := range ii {
s.add(each)
}
}
// AddUint64 adds uint64 values ii into s.
func (s *Set) AddUint64(ii ...uint64) {
for _, each := range ii {
s.add(each)
}
}
// AddStr adds string values ss into s.
func (s *Set) AddStr(ss ...string) {
for _, each := range ss {
s.add(each)
}
}
// Contains checks if i is in s.
func (s *Set) Contains(i interface{}) bool {
if len(s.data) == 0 {
return false
}
s.validate(i)
_, ok := s.data[i]
return ok
}
// Count returns the number of items in the set.
func (s *Set[T]) Count() int {
return len(s.data)
}
// Keys returns the keys in s.
func (s *Set) Keys() []interface{} {
var keys []interface{}
// Keys returns all elements in the set as a slice.
func (s *Set[T]) Keys() []T {
keys := make([]T, 0, len(s.data))
for key := range s.data {
keys = append(keys, key)
}
return keys
}
// Remove removes an item from the set.
func (s *Set[T]) Remove(item T) {
delete(s.data, item)
// KeysInt returns the int keys in s.
func (s *Set) KeysInt() []int {
var keys []int
for key := range s.data {
if intKey, ok := key.(int); ok {
keys = append(keys, intKey)
}
}
return keys
}
// KeysInt64 returns int64 keys in s.
func (s *Set) KeysInt64() []int64 {
var keys []int64
for key := range s.data {
if intKey, ok := key.(int64); ok {
keys = append(keys, intKey)
}
}
return keys
}
// KeysUint returns uint keys in s.
func (s *Set) KeysUint() []uint {
var keys []uint
for key := range s.data {
if intKey, ok := key.(uint); ok {
keys = append(keys, intKey)
}
}
return keys
}
// KeysUint64 returns uint64 keys in s.
func (s *Set) KeysUint64() []uint64 {
var keys []uint64
for key := range s.data {
if intKey, ok := key.(uint64); ok {
keys = append(keys, intKey)
}
}
return keys
}
// KeysStr returns string keys in s.
func (s *Set) KeysStr() []string {
var keys []string
for key := range s.data {
if strKey, ok := key.(string); ok {
keys = append(keys, strKey)
}
}
return keys
}
// Remove removes i from s.
func (s *Set) Remove(i interface{}) {
s.validate(i)
delete(s.data, i)
}
// Count returns the number of items in s.
func (s *Set) Count() int {
return len(s.data)
}
func (s *Set) add(i interface{}) {
switch s.tp {
case unmanaged:
// do nothing
case untyped:
s.setType(i)
default:
s.validate(i)
}
s.data[i] = lang.Placeholder
}
func (s *Set) setType(i interface{}) {
// s.tp can only be untyped here
switch i.(type) {
case int:
s.tp = intType
case int64:
s.tp = int64Type
case uint:
s.tp = uintType
case uint64:
s.tp = uint64Type
case string:
s.tp = stringType
}
}
func (s *Set) validate(i interface{}) {
if s.tp == unmanaged {
return
}
switch i.(type) {
case int:
if s.tp != intType {
logx.Errorf("element is int, but set contains elements with type %d", s.tp)
}
case int64:
if s.tp != int64Type {
logx.Errorf("element is int64, but set contains elements with type %d", s.tp)
}
case uint:
if s.tp != uintType {
logx.Errorf("element is uint, but set contains elements with type %d", s.tp)
}
case uint64:
if s.tp != uint64Type {
logx.Errorf("element is uint64, but set contains elements with type %d", s.tp)
}
case string:
if s.tp != stringType {
logx.Errorf("element is string, but set contains elements with type %d", s.tp)
}
}
}

View File

@@ -12,117 +12,34 @@ func init() {
logx.Disable()
}
// Set functionality tests
func TestTypedSetInt(t *testing.T) {
set := NewSet[int]()
values := []int{1, 2, 3, 2, 1} // Contains duplicates
// Test adding
set.Add(values...)
assert.Equal(t, 3, set.Count()) // Should only have 3 elements after deduplication
// Test contains
assert.True(t, set.Contains(1))
assert.True(t, set.Contains(2))
assert.True(t, set.Contains(3))
assert.False(t, set.Contains(4))
// Test getting all keys
keys := set.Keys()
sort.Ints(keys)
assert.EqualValues(t, []int{1, 2, 3}, keys)
// Test removal
set.Remove(2)
assert.False(t, set.Contains(2))
assert.Equal(t, 2, set.Count())
}
func TestTypedSetStringOps(t *testing.T) {
set := NewSet[string]()
values := []string{"a", "b", "c", "b", "a"}
set.Add(values...)
assert.Equal(t, 3, set.Count())
assert.True(t, set.Contains("a"))
assert.True(t, set.Contains("b"))
assert.True(t, set.Contains("c"))
assert.False(t, set.Contains("d"))
keys := set.Keys()
sort.Strings(keys)
assert.EqualValues(t, []string{"a", "b", "c"}, keys)
}
func TestTypedSetClear(t *testing.T) {
set := NewSet[int]()
set.Add(1, 2, 3)
assert.Equal(t, 3, set.Count())
set.Clear()
assert.Equal(t, 0, set.Count())
assert.False(t, set.Contains(1))
}
func TestTypedSetEmpty(t *testing.T) {
set := NewSet[int]()
assert.Equal(t, 0, set.Count())
assert.False(t, set.Contains(1))
assert.Empty(t, set.Keys())
}
func TestTypedSetMultipleTypes(t *testing.T) {
// Test different typed generic sets
intSet := NewSet[int]()
int64Set := NewSet[int64]()
uintSet := NewSet[uint]()
uint64Set := NewSet[uint64]()
stringSet := NewSet[string]()
intSet.Add(1, 2, 3)
int64Set.Add(1, 2, 3)
uintSet.Add(1, 2, 3)
uint64Set.Add(1, 2, 3)
stringSet.Add("1", "2", "3")
assert.Equal(t, 3, intSet.Count())
assert.Equal(t, 3, int64Set.Count())
assert.Equal(t, 3, uintSet.Count())
assert.Equal(t, 3, uint64Set.Count())
assert.Equal(t, 3, stringSet.Count())
}
// Set benchmarks
func BenchmarkTypedIntSet(b *testing.B) {
s := NewSet[int]()
for i := 0; i < b.N; i++ {
s.Add(i)
_ = s.Contains(i)
}
}
func BenchmarkTypedStringSet(b *testing.B) {
s := NewSet[string]()
for i := 0; i < b.N; i++ {
s.Add(string(rune(i)))
_ = s.Contains(string(rune(i)))
}
}
// Legacy tests remain unchanged for backward compatibility
func BenchmarkRawSet(b *testing.B) {
m := make(map[any]struct{})
m := make(map[interface{}]struct{})
for i := 0; i < b.N; i++ {
m[i] = struct{}{}
_ = m[i]
}
}
func BenchmarkUnmanagedSet(b *testing.B) {
s := NewUnmanagedSet()
for i := 0; i < b.N; i++ {
s.Add(i)
_ = s.Contains(i)
}
}
func BenchmarkSet(b *testing.B) {
s := NewSet()
for i := 0; i < b.N; i++ {
s.AddInt(i)
_ = s.Contains(i)
}
}
func TestAdd(t *testing.T) {
// given
set := NewSet[int]()
values := []int{1, 2, 3}
set := NewUnmanagedSet()
values := []interface{}{1, 2, 3}
// when
set.Add(values...)
@@ -134,74 +51,82 @@ func TestAdd(t *testing.T) {
func TestAddInt(t *testing.T) {
// given
set := NewSet[int]()
set := NewSet()
values := []int{1, 2, 3}
// when
set.Add(values...)
set.AddInt(values...)
// then
assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3))
keys := set.Keys()
keys := set.KeysInt()
sort.Ints(keys)
assert.EqualValues(t, values, keys)
}
func TestAddInt64(t *testing.T) {
// given
set := NewSet[int64]()
set := NewSet()
values := []int64{1, 2, 3}
// when
set.Add(values...)
set.AddInt64(values...)
// then
assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3))
assert.Equal(t, len(values), len(set.Keys()))
assert.True(t, set.Contains(int64(1)) && set.Contains(int64(2)) && set.Contains(int64(3)))
assert.Equal(t, len(values), len(set.KeysInt64()))
}
func TestAddUint(t *testing.T) {
// given
set := NewSet[uint]()
set := NewSet()
values := []uint{1, 2, 3}
// when
set.Add(values...)
set.AddUint(values...)
// then
assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3))
assert.Equal(t, len(values), len(set.Keys()))
assert.True(t, set.Contains(uint(1)) && set.Contains(uint(2)) && set.Contains(uint(3)))
assert.Equal(t, len(values), len(set.KeysUint()))
}
func TestAddUint64(t *testing.T) {
// given
set := NewSet[uint64]()
set := NewSet()
values := []uint64{1, 2, 3}
// when
set.Add(values...)
set.AddUint64(values...)
// then
assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3))
assert.Equal(t, len(values), len(set.Keys()))
assert.True(t, set.Contains(uint64(1)) && set.Contains(uint64(2)) && set.Contains(uint64(3)))
assert.Equal(t, len(values), len(set.KeysUint64()))
}
func TestAddStr(t *testing.T) {
// given
set := NewSet[string]()
set := NewSet()
values := []string{"1", "2", "3"}
// when
set.Add(values...)
set.AddStr(values...)
// then
assert.True(t, set.Contains("1") && set.Contains("2") && set.Contains("3"))
assert.Equal(t, len(values), len(set.Keys()))
assert.Equal(t, len(values), len(set.KeysStr()))
}
func TestContainsWithoutElements(t *testing.T) {
// given
set := NewSet[int]()
set := NewSet()
// then
assert.False(t, set.Contains(1))
}
func TestContainsUnmanagedWithoutElements(t *testing.T) {
// given
set := NewUnmanagedSet()
// then
assert.False(t, set.Contains(1))
@@ -209,8 +134,8 @@ func TestContainsWithoutElements(t *testing.T) {
func TestRemove(t *testing.T) {
// given
set := NewSet[int]()
set.Add([]int{1, 2, 3}...)
set := NewSet()
set.Add([]interface{}{1, 2, 3}...)
// when
set.Remove(2)
@@ -221,9 +146,57 @@ func TestRemove(t *testing.T) {
func TestCount(t *testing.T) {
// given
set := NewSet[int]()
set.Add([]int{1, 2, 3}...)
set := NewSet()
set.Add([]interface{}{1, 2, 3}...)
// then
assert.Equal(t, set.Count(), 3)
}
func TestKeysIntMismatch(t *testing.T) {
set := NewSet()
set.add(int64(1))
set.add(2)
vals := set.KeysInt()
assert.EqualValues(t, []int{2}, vals)
}
func TestKeysInt64Mismatch(t *testing.T) {
set := NewSet()
set.add(1)
set.add(int64(2))
vals := set.KeysInt64()
assert.EqualValues(t, []int64{2}, vals)
}
func TestKeysUintMismatch(t *testing.T) {
set := NewSet()
set.add(1)
set.add(uint(2))
vals := set.KeysUint()
assert.EqualValues(t, []uint{2}, vals)
}
func TestKeysUint64Mismatch(t *testing.T) {
set := NewSet()
set.add(1)
set.add(uint64(2))
vals := set.KeysUint64()
assert.EqualValues(t, []uint64{2}, vals)
}
func TestKeysStrMismatch(t *testing.T) {
set := NewSet()
set.add(1)
set.add("2")
vals := set.KeysStr()
assert.EqualValues(t, []string{"2"}, vals)
}
func TestSetType(t *testing.T) {
set := NewUnmanagedSet()
set.add(1)
set.add("2")
vals := set.Keys()
assert.ElementsMatch(t, []interface{}{1, "2"}, vals)
}

View File

@@ -20,7 +20,7 @@ var (
type (
// Execute defines the method to execute the task.
Execute func(key, value any)
Execute func(key, value interface{})
// A TimingWheel is a timing wheel object to schedule tasks.
TimingWheel struct {
@@ -33,14 +33,14 @@ type (
execute Execute
setChannel chan timingEntry
moveChannel chan baseEntry
removeChannel chan any
drainChannel chan func(key, value any)
removeChannel chan interface{}
drainChannel chan func(key, value interface{})
stopChannel chan lang.PlaceholderType
}
timingEntry struct {
baseEntry
value any
value interface{}
circle int
diff int
removed bool
@@ -48,7 +48,7 @@ type (
baseEntry struct {
delay time.Duration
key any
key interface{}
}
positionEntry struct {
@@ -57,8 +57,8 @@ type (
}
timingTask struct {
key any
value any
key interface{}
value interface{}
}
)
@@ -85,8 +85,8 @@ func NewTimingWheelWithTicker(interval time.Duration, numSlots int, execute Exec
numSlots: numSlots,
setChannel: make(chan timingEntry),
moveChannel: make(chan baseEntry),
removeChannel: make(chan any),
drainChannel: make(chan func(key, value any)),
removeChannel: make(chan interface{}),
drainChannel: make(chan func(key, value interface{})),
stopChannel: make(chan lang.PlaceholderType),
}
@@ -97,7 +97,7 @@ func NewTimingWheelWithTicker(interval time.Duration, numSlots int, execute Exec
}
// Drain drains all items and executes them.
func (tw *TimingWheel) Drain(fn func(key, value any)) error {
func (tw *TimingWheel) Drain(fn func(key, value interface{})) error {
select {
case tw.drainChannel <- fn:
return nil
@@ -107,7 +107,7 @@ func (tw *TimingWheel) Drain(fn func(key, value any)) error {
}
// MoveTimer moves the task with the given key to the given delay.
func (tw *TimingWheel) MoveTimer(key any, delay time.Duration) error {
func (tw *TimingWheel) MoveTimer(key interface{}, delay time.Duration) error {
if delay <= 0 || key == nil {
return ErrArgument
}
@@ -124,7 +124,7 @@ func (tw *TimingWheel) MoveTimer(key any, delay time.Duration) error {
}
// RemoveTimer removes the task with the given key.
func (tw *TimingWheel) RemoveTimer(key any) error {
func (tw *TimingWheel) RemoveTimer(key interface{}) error {
if key == nil {
return ErrArgument
}
@@ -138,7 +138,7 @@ func (tw *TimingWheel) RemoveTimer(key any) error {
}
// SetTimer sets the task value with the given key to the delay.
func (tw *TimingWheel) SetTimer(key, value any, delay time.Duration) error {
func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) error {
if delay <= 0 || key == nil {
return ErrArgument
}
@@ -162,9 +162,8 @@ func (tw *TimingWheel) Stop() {
close(tw.stopChannel)
}
func (tw *TimingWheel) drainAll(fn func(key, value any)) {
func (tw *TimingWheel) drainAll(fn func(key, value interface{})) {
runner := threading.NewTaskRunner(drainWorkers)
for _, slot := range tw.slots {
for e := slot.Front(); e != nil; {
task := e.Value.(*timingEntry)
@@ -178,8 +177,6 @@ func (tw *TimingWheel) drainAll(fn func(key, value any)) {
}
}
}
runner.Wait()
}
func (tw *TimingWheel) getPositionAndCircle(d time.Duration) (pos, circle int) {
@@ -235,7 +232,7 @@ func (tw *TimingWheel) onTick() {
tw.scanAndRunTasks(l)
}
func (tw *TimingWheel) removeTask(key any) {
func (tw *TimingWheel) removeTask(key interface{}) {
val, ok := tw.timers.Get(key)
if !ok {
return

View File

@@ -20,13 +20,13 @@ const (
)
func TestNewTimingWheel(t *testing.T) {
_, err := NewTimingWheel(0, 10, func(key, value any) {})
_, err := NewTimingWheel(0, 10, func(key, value interface{}) {})
assert.NotNil(t, err)
}
func TestTimingWheel_Drain(t *testing.T) {
ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
}, ticker)
tw.SetTimer("first", 3, testStep*4)
tw.SetTimer("second", 5, testStep*7)
@@ -36,7 +36,7 @@ func TestTimingWheel_Drain(t *testing.T) {
var lock sync.Mutex
var wg sync.WaitGroup
wg.Add(3)
tw.Drain(func(key, value any) {
tw.Drain(func(key, value interface{}) {
lock.Lock()
defer lock.Unlock()
keys = append(keys, key.(string))
@@ -50,19 +50,19 @@ func TestTimingWheel_Drain(t *testing.T) {
assert.EqualValues(t, []string{"first", "second", "third"}, keys)
assert.EqualValues(t, []int{3, 5, 7}, vals)
var count int
tw.Drain(func(key, value any) {
tw.Drain(func(key, value interface{}) {
count++
})
time.Sleep(time.Millisecond * 100)
assert.Equal(t, 0, count)
tw.Stop()
assert.Equal(t, ErrClosed, tw.Drain(func(key, value any) {}))
assert.Equal(t, ErrClosed, tw.Drain(func(key, value interface{}) {}))
}
func TestTimingWheel_SetTimerSoon(t *testing.T) {
run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int))
@@ -78,7 +78,7 @@ func TestTimingWheel_SetTimerSoon(t *testing.T) {
func TestTimingWheel_SetTimerTwice(t *testing.T) {
run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k)
assert.Equal(t, 5, v.(int))
@@ -96,7 +96,7 @@ func TestTimingWheel_SetTimerTwice(t *testing.T) {
func TestTimingWheel_SetTimerWrongDelay(t *testing.T) {
ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {}, ticker)
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {}, ticker)
defer tw.Stop()
assert.NotPanics(t, func() {
tw.SetTimer("any", 3, -testStep)
@@ -105,7 +105,7 @@ func TestTimingWheel_SetTimerWrongDelay(t *testing.T) {
func TestTimingWheel_SetTimerAfterClose(t *testing.T) {
ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {}, ticker)
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {}, ticker)
tw.Stop()
assert.Equal(t, ErrClosed, tw.SetTimer("any", 3, testStep))
}
@@ -113,7 +113,7 @@ func TestTimingWheel_SetTimerAfterClose(t *testing.T) {
func TestTimingWheel_MoveTimer(t *testing.T) {
run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 3, func(k, v any) {
tw, _ := NewTimingWheelWithTicker(testStep, 3, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int))
@@ -139,7 +139,7 @@ func TestTimingWheel_MoveTimer(t *testing.T) {
func TestTimingWheel_MoveTimerSoon(t *testing.T) {
run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 3, func(k, v any) {
tw, _ := NewTimingWheelWithTicker(testStep, 3, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int))
@@ -155,7 +155,7 @@ func TestTimingWheel_MoveTimerSoon(t *testing.T) {
func TestTimingWheel_MoveTimerEarlier(t *testing.T) {
run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int))
@@ -173,7 +173,7 @@ func TestTimingWheel_MoveTimerEarlier(t *testing.T) {
func TestTimingWheel_RemoveTimer(t *testing.T) {
ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {}, ticker)
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {}, ticker)
tw.SetTimer("any", 3, testStep)
assert.NotPanics(t, func() {
tw.RemoveTimer("any")
@@ -236,7 +236,7 @@ func TestTimingWheel_SetTimer(t *testing.T) {
}
var actual int32
done := make(chan lang.PlaceholderType)
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value any) {
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
assert.Equal(t, 1, key.(int))
assert.Equal(t, 2, value.(int))
actual = atomic.LoadInt32(&count)
@@ -317,7 +317,7 @@ func TestTimingWheel_SetAndMoveThenStart(t *testing.T) {
}
var actual int32
done := make(chan lang.PlaceholderType)
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value any) {
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
actual = atomic.LoadInt32(&count)
close(done)
}, ticker)
@@ -405,7 +405,7 @@ func TestTimingWheel_SetAndMoveTwice(t *testing.T) {
}
var actual int32
done := make(chan lang.PlaceholderType)
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value any) {
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
actual = atomic.LoadInt32(&count)
close(done)
}, ticker)
@@ -486,7 +486,7 @@ func TestTimingWheel_ElapsedAndSet(t *testing.T) {
}
var actual int32
done := make(chan lang.PlaceholderType)
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value any) {
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
actual = atomic.LoadInt32(&count)
close(done)
}, ticker)
@@ -577,7 +577,7 @@ func TestTimingWheel_ElapsedAndSetThenMove(t *testing.T) {
}
var actual int32
done := make(chan lang.PlaceholderType)
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value any) {
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
actual = atomic.LoadInt32(&count)
close(done)
}, ticker)
@@ -612,7 +612,7 @@ func TestMoveAndRemoveTask(t *testing.T) {
}
}
var keys []int
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int))
keys = append(keys, v.(int))
@@ -629,161 +629,10 @@ func TestMoveAndRemoveTask(t *testing.T) {
assert.Equal(t, 0, len(keys))
}
// TestTimingWheel_DrainClosureBug tests the closure capture bug in drainAll
// Issue: https://github.com/zeromicro/go-zero/issues/5314
func TestTimingWheel_DrainClosureBug(t *testing.T) {
ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {}, ticker)
defer tw.Stop()
// Set multiple timers with different values
for i := 0; i < 10; i++ {
tw.SetTimer(i, i*10, testStep*5)
}
// Give time for timers to be set
time.Sleep(time.Millisecond * 100)
var mu sync.Mutex
received := make(map[int]int)
var wg sync.WaitGroup
wg.Add(10)
tw.Drain(func(key, value any) {
mu.Lock()
defer mu.Unlock()
k := key.(int)
v := value.(int)
received[k] = v
wg.Done()
})
wg.Wait()
// Check if all values match their keys
for k, v := range received {
expected := k * 10
assert.Equal(t, expected, v, "key %d should have value %d, got %d", k, expected, v)
}
}
// TestTimingWheel_RunTasksClosureBug tests the closure capture bug in runTasks
// Issue: https://github.com/zeromicro/go-zero/issues/5314
func TestTimingWheel_RunTasksClosureBug(t *testing.T) {
ticker := timex.NewFakeTicker()
var mu sync.Mutex
executed := make(map[int]int)
var wg sync.WaitGroup
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
mu.Lock()
defer mu.Unlock()
key := k.(int)
val := v.(int)
executed[key] = val
wg.Done()
}, ticker)
defer tw.Stop()
// Set multiple timers that should fire in the same tick
count := 10
wg.Add(count)
for i := 0; i < count; i++ {
tw.SetTimer(i, i*10, testStep)
}
// Advance ticker to trigger tasks
ticker.Tick()
// Wait for execution with timeout
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
// Success
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for tasks to execute")
}
// Verify all tasks executed with correct values
assert.Equal(t, count, len(executed), "should have executed all tasks")
for k, v := range executed {
expected := k * 10
assert.Equal(t, expected, v, "key %d should have value %d, got %d", k, expected, v)
}
}
// TestTimingWheel_RunTasksRaceCondition tests for race conditions in runTasks
// This test specifically targets the loop variable capture bug
func TestTimingWheel_RunTasksRaceCondition(t *testing.T) {
// Run multiple times to increase likelihood of catching the bug
for attempt := 0; attempt < 10; attempt++ {
t.Run("", func(t *testing.T) {
ticker := timex.NewFakeTicker()
var mu sync.Mutex
keyValues := make(map[int][]int)
var wg sync.WaitGroup
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
// Add small delay to increase chance of race
time.Sleep(time.Microsecond)
mu.Lock()
defer mu.Unlock()
key := k.(int)
val := v.(int)
keyValues[key] = append(keyValues[key], val)
wg.Done()
}, ticker)
defer tw.Stop()
// Set many timers rapidly to increase chance of race
count := 50
wg.Add(count)
for i := 0; i < count; i++ {
tw.SetTimer(i, i*100, testStep)
}
ticker.Tick()
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for tasks")
}
// Check for duplicates or wrong values
wrongCount := 0
for key, values := range keyValues {
assert.Equal(t, 1, len(values), "key %d should only execute once, got %v", key, values)
if len(values) > 0 {
expected := key * 100
if values[0] != expected {
wrongCount++
t.Logf("BUG DETECTED: key %d should have value %d, got %d", key, expected, values[0])
}
}
}
if wrongCount > 0 {
t.Errorf("Found %d tasks with wrong values due to closure bug", wrongCount)
}
})
}
}
func BenchmarkTimingWheel(b *testing.B) {
b.ReportAllocs()
tw, _ := NewTimingWheel(time.Second, 100, func(k, v any) {})
tw, _ := NewTimingWheel(time.Second, 100, func(k, v interface{}) {})
for i := 0; i < b.N; i++ {
tw.SetTimer(i, i, time.Second)
tw.SetTimer(b.N+i, b.N+i, time.Second)

View File

@@ -13,14 +13,11 @@ import (
"github.com/zeromicro/go-zero/internal/encoding"
)
const (
jsonTagKey = "json"
jsonTagSep = ','
)
const jsonTagKey = "json"
var (
fillDefaultUnmarshaler = mapping.NewUnmarshaler(jsonTagKey, mapping.WithDefault())
loaders = map[string]func([]byte, any) error{
loaders = map[string]func([]byte, interface{}) error{
".json": LoadFromJsonBytes,
".toml": LoadFromTomlBytes,
".yaml": LoadFromYamlBytes,
@@ -37,12 +34,12 @@ type fieldInfo struct {
// FillDefault fills the default values for the given v,
// and the premise is that the value of v must be guaranteed to be empty.
func FillDefault(v any) error {
return fillDefaultUnmarshaler.Unmarshal(map[string]any{}, v)
func FillDefault(v interface{}) error {
return fillDefaultUnmarshaler.Unmarshal(map[string]interface{}{}, v)
}
// Load loads config into v from file, .json, .yaml and .yml are acceptable.
func Load(file string, v any, opts ...Option) error {
func Load(file string, v interface{}, opts ...Option) error {
content, err := os.ReadFile(file)
if err != nil {
return err
@@ -67,40 +64,35 @@ func Load(file string, v any, opts ...Option) error {
// LoadConfig loads config into v from file, .json, .yaml and .yml are acceptable.
// Deprecated: use Load instead.
func LoadConfig(file string, v any, opts ...Option) error {
func LoadConfig(file string, v interface{}, opts ...Option) error {
return Load(file, v, opts...)
}
// LoadFromJsonBytes loads config into v from content json bytes.
func LoadFromJsonBytes(content []byte, v any) error {
info, err := buildFieldsInfo(reflect.TypeOf(v), "")
func LoadFromJsonBytes(content []byte, v interface{}) error {
info, err := buildFieldsInfo(reflect.TypeOf(v))
if err != nil {
return err
}
var m map[string]any
if err = jsonx.Unmarshal(content, &m); err != nil {
var m map[string]interface{}
if err := jsonx.Unmarshal(content, &m); err != nil {
return err
}
lowerCaseKeyMap := toLowerCaseKeyMap(m, info)
if err = mapping.UnmarshalJsonMap(lowerCaseKeyMap, v,
mapping.WithCanonicalKeyFunc(toLowerCase)); err != nil {
return err
}
return validate(v)
return mapping.UnmarshalJsonMap(lowerCaseKeyMap, v, mapping.WithCanonicalKeyFunc(toLowerCase))
}
// LoadConfigFromJsonBytes loads config into v from content json bytes.
// Deprecated: use LoadFromJsonBytes instead.
func LoadConfigFromJsonBytes(content []byte, v any) error {
func LoadConfigFromJsonBytes(content []byte, v interface{}) error {
return LoadFromJsonBytes(content, v)
}
// LoadFromTomlBytes loads config into v from content toml bytes.
func LoadFromTomlBytes(content []byte, v any) error {
func LoadFromTomlBytes(content []byte, v interface{}) error {
b, err := encoding.TomlToJson(content)
if err != nil {
return err
@@ -110,7 +102,7 @@ func LoadFromTomlBytes(content []byte, v any) error {
}
// LoadFromYamlBytes loads config into v from content yaml bytes.
func LoadFromYamlBytes(content []byte, v any) error {
func LoadFromYamlBytes(content []byte, v interface{}) error {
b, err := encoding.YamlToJson(content)
if err != nil {
return err
@@ -121,24 +113,24 @@ func LoadFromYamlBytes(content []byte, v any) error {
// LoadConfigFromYamlBytes loads config into v from content yaml bytes.
// Deprecated: use LoadFromYamlBytes instead.
func LoadConfigFromYamlBytes(content []byte, v any) error {
func LoadConfigFromYamlBytes(content []byte, v interface{}) error {
return LoadFromYamlBytes(content, v)
}
// MustLoad loads config into v from path, exits on error.
func MustLoad(path string, v any, opts ...Option) {
func MustLoad(path string, v interface{}, opts ...Option) {
if err := Load(path, v, opts...); err != nil {
log.Fatalf("error: config file %s, %s", path, err.Error())
}
}
func addOrMergeFields(info *fieldInfo, key string, child *fieldInfo, fullName string) error {
func addOrMergeFields(info *fieldInfo, key string, child *fieldInfo) error {
if prev, ok := info.children[key]; ok {
if child.mapField != nil {
return newConflictKeyError(fullName)
return newDupKeyError(key)
}
if err := mergeFields(prev, child.children, fullName); err != nil {
if err := mergeFields(prev, key, child.children); err != nil {
return err
}
} else {
@@ -148,27 +140,27 @@ func addOrMergeFields(info *fieldInfo, key string, child *fieldInfo, fullName st
return nil
}
func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type, fullName string) error {
func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error {
switch ft.Kind() {
case reflect.Struct:
fields, err := buildFieldsInfo(ft, fullName)
fields, err := buildFieldsInfo(ft)
if err != nil {
return err
}
for k, v := range fields.children {
if err = addOrMergeFields(info, k, v, fullName); err != nil {
if err = addOrMergeFields(info, k, v); err != nil {
return err
}
}
case reflect.Map:
elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem()), fullName)
elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
if err != nil {
return err
}
if _, ok := info.children[lowerCaseName]; ok {
return newConflictKeyError(fullName)
return newDupKeyError(lowerCaseName)
}
info.children[lowerCaseName] = &fieldInfo{
@@ -177,7 +169,7 @@ func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.T
}
default:
if _, ok := info.children[lowerCaseName]; ok {
return newConflictKeyError(fullName)
return newDupKeyError(lowerCaseName)
}
info.children[lowerCaseName] = &fieldInfo{
@@ -188,16 +180,16 @@ func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.T
return nil
}
func buildFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error) {
func buildFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
tp = mapping.Deref(tp)
switch tp.Kind() {
case reflect.Struct:
return buildStructFieldsInfo(tp, fullName)
case reflect.Array, reflect.Slice, reflect.Map:
return buildFieldsInfo(mapping.Deref(tp.Elem()), fullName)
return buildStructFieldsInfo(tp)
case reflect.Array, reflect.Slice:
return buildFieldsInfo(mapping.Deref(tp.Elem()))
case reflect.Chan, reflect.Func:
return nil, fmt.Errorf("unsupported type: %s, fullName: %s", tp.Kind(), fullName)
return nil, fmt.Errorf("unsupported type: %s", tp.Kind())
default:
return &fieldInfo{
children: make(map[string]*fieldInfo),
@@ -205,23 +197,23 @@ func buildFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error) {
}
}
func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type, fullName string) error {
func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error {
var finfo *fieldInfo
var err error
switch ft.Kind() {
case reflect.Struct:
finfo, err = buildFieldsInfo(ft, fullName)
finfo, err = buildFieldsInfo(ft)
if err != nil {
return err
}
case reflect.Array, reflect.Slice:
finfo, err = buildFieldsInfo(ft.Elem(), fullName)
finfo, err = buildFieldsInfo(ft.Elem())
if err != nil {
return err
}
case reflect.Map:
elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem()), fullName)
elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
if err != nil {
return err
}
@@ -231,37 +223,31 @@ func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type,
mapField: elemInfo,
}
default:
finfo, err = buildFieldsInfo(ft, fullName)
finfo, err = buildFieldsInfo(ft)
if err != nil {
return err
}
}
return addOrMergeFields(info, lowerCaseName, finfo, fullName)
return addOrMergeFields(info, lowerCaseName, finfo)
}
func buildStructFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error) {
func buildStructFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
info := &fieldInfo{
children: make(map[string]*fieldInfo),
}
for i := 0; i < tp.NumField(); i++ {
field := tp.Field(i)
if !field.IsExported() {
continue
}
name := getTagName(field)
name := field.Name
lowerCaseName := toLowerCase(name)
ft := mapping.Deref(field.Type)
// flatten anonymous fields
if field.Anonymous {
if err := buildAnonymousFieldInfo(info, lowerCaseName, ft,
getFullName(fullName, lowerCaseName)); err != nil {
if err := buildAnonymousFieldInfo(info, lowerCaseName, ft); err != nil {
return nil, err
}
} else if err := buildNamedFieldInfo(info, lowerCaseName, ft,
getFullName(fullName, lowerCaseName)); err != nil {
} else if err := buildNamedFieldInfo(info, lowerCaseName, ft); err != nil {
return nil, err
}
}
@@ -269,32 +255,15 @@ func buildStructFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error)
return info, nil
}
// getTagName get the tag name of the given field, if no tag name, use file.Name.
// field.Name is returned on tags like `json:""` and `json:",optional"`.
func getTagName(field reflect.StructField) string {
if tag, ok := field.Tag.Lookup(jsonTagKey); ok {
if pos := strings.IndexByte(tag, jsonTagSep); pos >= 0 {
tag = tag[:pos]
}
tag = strings.TrimSpace(tag)
if len(tag) > 0 {
return tag
}
}
return field.Name
}
func mergeFields(prev *fieldInfo, children map[string]*fieldInfo, fullName string) error {
func mergeFields(prev *fieldInfo, key string, children map[string]*fieldInfo) error {
if len(prev.children) == 0 || len(children) == 0 {
return newConflictKeyError(fullName)
return newDupKeyError(key)
}
// merge fields
for k, v := range children {
if _, ok := prev.children[k]; ok {
return newConflictKeyError(fullName)
return newDupKeyError(k)
}
prev.children[k] = v
@@ -307,12 +276,12 @@ func toLowerCase(s string) string {
return strings.ToLower(s)
}
func toLowerCaseInterface(v any, info *fieldInfo) any {
func toLowerCaseInterface(v interface{}, info *fieldInfo) interface{} {
switch vv := v.(type) {
case map[string]any:
case map[string]interface{}:
return toLowerCaseKeyMap(vv, info)
case []any:
arr := make([]any, 0, len(vv))
case []interface{}:
var arr []interface{}
for _, vvv := range vv {
arr = append(arr, toLowerCaseInterface(vvv, info))
}
@@ -322,8 +291,8 @@ func toLowerCaseInterface(v any, info *fieldInfo) any {
}
}
func toLowerCaseKeyMap(m map[string]any, info *fieldInfo) map[string]any {
res := make(map[string]any)
func toLowerCaseKeyMap(m map[string]interface{}, info *fieldInfo) map[string]interface{} {
res := make(map[string]interface{})
for k, v := range m {
ti, ok := info.children[k]
@@ -337,8 +306,6 @@ func toLowerCaseKeyMap(m map[string]any, info *fieldInfo) map[string]any {
res[lk] = toLowerCaseInterface(v, ti)
} else if info.mapField != nil {
res[k] = toLowerCaseInterface(v, info.mapField)
} else if vv, ok := v.(map[string]any); ok {
res[k] = toLowerCaseKeyMap(vv, info)
} else {
res[k] = v
}
@@ -347,22 +314,14 @@ func toLowerCaseKeyMap(m map[string]any, info *fieldInfo) map[string]any {
return res
}
type conflictKeyError struct {
type dupKeyError struct {
key string
}
func newConflictKeyError(key string) conflictKeyError {
return conflictKeyError{key: key}
func newDupKeyError(key string) dupKeyError {
return dupKeyError{key: key}
}
func (e conflictKeyError) Error() string {
return fmt.Sprintf("conflict key %s, pay attention to anonymous fields", e.key)
}
func getFullName(parent, child string) string {
if len(parent) == 0 {
return child
}
return parent + "." + child
func (e dupKeyError) Error() string {
return fmt.Sprintf("duplicated key %s", e.key)
}

View File

@@ -1,9 +1,7 @@
package conf
import (
"errors"
"os"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
@@ -11,7 +9,7 @@ import (
"github.com/zeromicro/go-zero/core/hash"
)
var dupErr conflictKeyError
var dupErr dupKeyError
func TestLoadConfig_notExists(t *testing.T) {
assert.NotNil(t, Load("not_a_file", nil))
@@ -36,13 +34,14 @@ func TestConfigJson(t *testing.T) {
"c": "${FOO}",
"d": "abcd!@#$112"
}`
t.Setenv("FOO", "2")
for _, test := range tests {
test := test
t.Run(test, func(t *testing.T) {
tmpfile, err := createTempFile(t, test, text)
os.Setenv("FOO", "2")
defer os.Unsetenv("FOO")
tmpfile, err := createTempFile(test, text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
var val struct {
A string `json:"a"`
@@ -81,9 +80,11 @@ b = 1
c = "${FOO}"
d = "abcd!@#$112"
`
t.Setenv("FOO", "2")
tmpfile, err := createTempFile(t, ".toml", text)
os.Setenv("FOO", "2")
defer os.Unsetenv("FOO")
tmpfile, err := createTempFile(".toml", text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
var val struct {
A string `json:"a"`
@@ -104,8 +105,9 @@ b = 1
c = "FOO"
d = "abcd"
`
tmpfile, err := createTempFile(t, ".toml", text)
tmpfile, err := createTempFile(".toml", text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
var val struct {
A string `json:"a"`
@@ -121,23 +123,6 @@ d = "abcd"
}
}
func TestConfigWithLower(t *testing.T) {
text := `a = "foo"
b = 1
`
tmpfile, err := createTempFile(t, ".toml", text)
assert.Nil(t, err)
var val struct {
A string `json:"a"`
b int
}
if assert.NoError(t, Load(tmpfile, &val)) {
assert.Equal(t, "foo", val.A)
assert.Equal(t, 0, val.b)
}
}
func TestConfigJsonCanonical(t *testing.T) {
text := []byte(`{"a": "foo", "B": "bar"}`)
@@ -203,9 +188,11 @@ b = 1
c = "${FOO}"
d = "abcd!@#112"
`
t.Setenv("FOO", "2")
tmpfile, err := createTempFile(t, ".toml", text)
os.Setenv("FOO", "2")
defer os.Unsetenv("FOO")
tmpfile, err := createTempFile(".toml", text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
var val struct {
A string `json:"a"`
@@ -233,12 +220,14 @@ func TestConfigJsonEnv(t *testing.T) {
"c": "${FOO}",
"d": "abcd!@#$a12 3"
}`
t.Setenv("FOO", "2")
for _, test := range tests {
test := test
t.Run(test, func(t *testing.T) {
tmpfile, err := createTempFile(t, test, text)
os.Setenv("FOO", "2")
defer os.Unsetenv("FOO")
tmpfile, err := createTempFile(test, text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
var val struct {
A string `json:"a"`
@@ -643,7 +632,7 @@ func Test_FieldOverwrite(t *testing.T) {
Name2 *string
}
validate := func(val any) {
validate := func(val interface{}) {
input := []byte(`{"Name": "hello", "Name2": "world"}`)
assert.NoError(t, LoadFromJsonBytes(input, val))
}
@@ -679,11 +668,11 @@ func Test_FieldOverwrite(t *testing.T) {
Name *string
}
validate := func(val any) {
validate := func(val interface{}) {
input := []byte(`{"Name": "hello"}`)
err := LoadFromJsonBytes(input, val)
assert.ErrorAs(t, err, &dupErr)
assert.Equal(t, newConflictKeyError("name").Error(), err.Error())
assert.Equal(t, newDupKeyError("name").Error(), err.Error())
}
validate(&St1{})
@@ -722,11 +711,11 @@ func Test_FieldOverwrite(t *testing.T) {
Name *int
}
validate := func(val any) {
validate := func(val interface{}) {
input := []byte(`{"Name": "hello"}`)
err := LoadFromJsonBytes(input, val)
assert.ErrorAs(t, err, &dupErr)
assert.Error(t, err)
assert.Equal(t, newDupKeyError("name").Error(), err.Error())
}
validate(&St0{})
@@ -1033,22 +1022,22 @@ func TestLoadNamedFieldOverwritten(t *testing.T) {
})
}
func TestLoadLowerMemberShouldNotConflict(t *testing.T) {
type (
Redis struct {
db uint
}
func createTempFile(ext, text string) (string, error) {
tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext)
if err != nil {
return "", err
}
Config struct {
db uint
Redis
}
)
if err := os.WriteFile(tmpfile.Name(), []byte(text), os.ModeTemporary); err != nil {
return "", err
}
var c Config
assert.NoError(t, LoadFromJsonBytes([]byte(`{}`), &c))
assert.Zero(t, c.db)
assert.Zero(t, c.Redis.db)
filename := tmpfile.Name()
if err = tmpfile.Close(); err != nil {
return "", err
}
return filename, nil
}
func TestFillDefaultUnmarshal(t *testing.T) {
@@ -1090,7 +1079,7 @@ func TestFillDefaultUnmarshal(t *testing.T) {
assert.Equal(t, st.C, "c")
})
t.Run("has value", func(t *testing.T) {
t.Run("has vaue", func(t *testing.T) {
type St struct {
A string `json:",default=a"`
B string
@@ -1102,517 +1091,3 @@ func TestFillDefaultUnmarshal(t *testing.T) {
assert.Error(t, err)
})
}
func TestConfigWithJsonTag(t *testing.T) {
t.Run("map with value", func(t *testing.T) {
var input = []byte(`[Value]
[Value.first]
Email = "foo"
[Value.second]
Email = "bar"`)
type Value struct {
Email string
}
type Config struct {
ValueMap map[string]Value `json:"Value"`
}
var c Config
if assert.NoError(t, LoadFromTomlBytes(input, &c)) {
assert.Len(t, c.ValueMap, 2)
}
})
t.Run("map with ptr value", func(t *testing.T) {
var input = []byte(`[Value]
[Value.first]
Email = "foo"
[Value.second]
Email = "bar"`)
type Value struct {
Email string
}
type Config struct {
ValueMap map[string]*Value `json:"Value"`
}
var c Config
if assert.NoError(t, LoadFromTomlBytes(input, &c)) {
assert.Len(t, c.ValueMap, 2)
}
})
t.Run("map with optional", func(t *testing.T) {
var input = []byte(`[Value]
[Value.first]
Email = "foo"
[Value.second]
Email = "bar"`)
type Value struct {
Email string
}
type Config struct {
Value map[string]Value `json:",optional"`
}
var c Config
if assert.NoError(t, LoadFromTomlBytes(input, &c)) {
assert.Len(t, c.Value, 2)
}
})
t.Run("map with empty tag", func(t *testing.T) {
var input = []byte(`[Value]
[Value.first]
Email = "foo"
[Value.second]
Email = "bar"`)
type Value struct {
Email string
}
type Config struct {
Value map[string]Value `json:" "`
}
var c Config
if assert.NoError(t, LoadFromTomlBytes(input, &c)) {
assert.Len(t, c.Value, 2)
}
})
t.Run("multi layer map", func(t *testing.T) {
type Value struct {
User struct {
Name string
}
}
type Config struct {
Value map[string]map[string]Value
}
var input = []byte(`
[Value.first.User1.User]
Name = "foo"
[Value.second.User2.User]
Name = "bar"
`)
var c Config
if assert.NoError(t, LoadFromTomlBytes(input, &c)) {
assert.Len(t, c.Value, 2)
}
})
}
func Test_LoadBadConfig(t *testing.T) {
type Config struct {
Name string `json:"name,options=foo|bar"`
}
file, err := createTempFile(t, ".json", `{"name": "baz"}`)
assert.NoError(t, err)
var c Config
err = Load(file, &c)
assert.Error(t, err)
}
func Test_getFullName(t *testing.T) {
assert.Equal(t, "a.b", getFullName("a", "b"))
assert.Equal(t, "a", getFullName("", "a"))
}
func TestValidate(t *testing.T) {
t.Run("normal config", func(t *testing.T) {
var c mockConfig
err := LoadFromJsonBytes([]byte(`{"val": "hello", "number": 8}`), &c)
assert.NoError(t, err)
})
t.Run("error no int", func(t *testing.T) {
var c mockConfig
err := LoadFromJsonBytes([]byte(`{"val": "hello"}`), &c)
assert.Error(t, err)
})
t.Run("error no string", func(t *testing.T) {
var c mockConfig
err := LoadFromJsonBytes([]byte(`{"number": 8}`), &c)
assert.Error(t, err)
})
}
func Test_buildFieldsInfo(t *testing.T) {
type ParentSt struct {
Name string
M map[string]int
}
tests := []struct {
name string
t reflect.Type
ok bool
containsKey string
}{
{
name: "normal",
t: reflect.TypeOf(struct{ A string }{}),
ok: true,
},
{
name: "struct anonymous",
t: reflect.TypeOf(struct {
ParentSt
Name string
}{}),
ok: false,
containsKey: newConflictKeyError("name").Error(),
},
{
name: "struct ptr anonymous",
t: reflect.TypeOf(struct {
*ParentSt
Name string
}{}),
ok: false,
containsKey: newConflictKeyError("name").Error(),
},
{
name: "more struct anonymous",
t: reflect.TypeOf(struct {
Value struct {
ParentSt
Name string
}
}{}),
ok: false,
containsKey: newConflictKeyError("value.name").Error(),
},
{
name: "map anonymous",
t: reflect.TypeOf(struct {
ParentSt
M string
}{}),
ok: false,
containsKey: newConflictKeyError("m").Error(),
},
{
name: "map more anonymous",
t: reflect.TypeOf(struct {
Value struct {
ParentSt
M string
}
}{}),
ok: false,
containsKey: newConflictKeyError("value.m").Error(),
},
{
name: "struct slice anonymous",
t: reflect.TypeOf([]struct {
ParentSt
Name string
}{}),
ok: false,
containsKey: newConflictKeyError("name").Error(),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := buildFieldsInfo(tt.t, "")
if tt.ok {
assert.NoError(t, err)
} else {
assert.Error(t, err)
assert.Equal(t, err.Error(), tt.containsKey)
}
})
}
}
func createTempFile(t *testing.T, ext, text string) (string, error) {
tmpFile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext)
if err != nil {
return "", err
}
if err = os.WriteFile(tmpFile.Name(), []byte(text), os.ModeTemporary); err != nil {
return "", err
}
filename := tmpFile.Name()
if err = tmpFile.Close(); err != nil {
return "", err
}
t.Cleanup(func() {
_ = os.Remove(filename)
})
return filename, nil
}
type mockConfig struct {
Val string
Number int
}
func (m mockConfig) Validate() error {
if len(m.Val) == 0 {
return errors.New("val is empty")
}
if m.Number == 0 {
return errors.New("number is zero")
}
return nil
}
func TestGetFullName(t *testing.T) {
tests := []struct {
parent string
child string
want string
}{
{"", "child", "child"},
{"parent", "child", "parent.child"},
{"a.b", "c", "a.b.c"},
{"root", "nested.field", "root.nested.field"},
}
for _, tt := range tests {
t.Run(tt.parent+"."+tt.child, func(t *testing.T) {
got := getFullName(tt.parent, tt.child)
assert.Equal(t, tt.want, got)
})
}
}
// validatorConfig is a test config that implements Validate() for testing validation behavior
type validatorConfig struct {
Value int `json:"value"`
}
func (v *validatorConfig) Validate() error {
if v.Value < 10 {
return errors.New("value must be >= 10")
}
return nil
}
// TestLoadValidation_WithoutEnv tests that validation is called correctly in normal loading path
func TestLoadValidation_WithoutEnv(t *testing.T) {
tests := []struct {
name string
extension string
content string
wantErr bool
errMsg string
}{
{
name: "json valid value",
extension: ".json",
content: `{"value": 15}`,
wantErr: false,
},
{
name: "json invalid value",
extension: ".json",
content: `{"value": 5}`,
wantErr: true,
errMsg: "value must be >= 10",
},
{
name: "yaml valid value",
extension: ".yaml",
content: "value: 20\n",
wantErr: false,
},
{
name: "yaml invalid value",
extension: ".yaml",
content: "value: 3\n",
wantErr: true,
errMsg: "value must be >= 10",
},
{
name: "toml valid value",
extension: ".toml",
content: "value = 100\n",
wantErr: false,
},
{
name: "toml invalid value",
extension: ".toml",
content: "value = 1\n",
wantErr: true,
errMsg: "value must be >= 10",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpfile, err := createTempFile(t, tt.extension, tt.content)
assert.Nil(t, err)
var cfg validatorConfig
err = Load(tmpfile, &cfg)
if tt.wantErr {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errMsg)
} else {
assert.NoError(t, err)
}
})
}
}
// TestLoadValidation_WithEnv tests that validation is called correctly with UseEnv() option
func TestLoadValidation_WithEnv(t *testing.T) {
tests := []struct {
name string
extension string
content string
envValue string
wantErr bool
errMsg string
}{
{
name: "json valid value with env",
extension: ".json",
content: `{"value": ${TEST_VALUE}}`,
envValue: "25",
wantErr: false,
},
{
name: "json invalid value with env",
extension: ".json",
content: `{"value": ${TEST_VALUE}}`,
envValue: "7",
wantErr: true,
errMsg: "value must be >= 10",
},
{
name: "yaml valid value with env",
extension: ".yaml",
content: "value: ${TEST_VALUE}\n",
envValue: "50",
wantErr: false,
},
{
name: "yaml invalid value with env",
extension: ".yaml",
content: "value: ${TEST_VALUE}\n",
envValue: "2",
wantErr: true,
errMsg: "value must be >= 10",
},
{
name: "toml valid value with env",
extension: ".toml",
content: "value = ${TEST_VALUE}\n",
envValue: "99",
wantErr: false,
},
{
name: "toml invalid value with env",
extension: ".toml",
content: "value = ${TEST_VALUE}\n",
envValue: "8",
wantErr: true,
errMsg: "value must be >= 10",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("TEST_VALUE", tt.envValue)
tmpfile, err := createTempFile(t, tt.extension, tt.content)
assert.Nil(t, err)
var cfg validatorConfig
err = Load(tmpfile, &cfg, UseEnv())
if tt.wantErr {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errMsg)
} else {
assert.NoError(t, err)
}
})
}
}
// TestLoadValidation_Consistency verifies validation behavior is consistent between paths
func TestLoadValidation_Consistency(t *testing.T) {
// Test that both paths (with and without UseEnv) produce the same validation results
const validValue = 15
formats := []struct {
ext string
invalid string
valid string
}{
{".json", `{"value": 5}`, `{"value": 15}`},
{".yaml", "value: 5\n", "value: 15\n"},
{".toml", "value = 5\n", "value = 15\n"},
}
for _, format := range formats {
t.Run("invalid_"+format.ext, func(t *testing.T) {
// Test without UseEnv()
tmpfile1, err := createTempFile(t, format.ext, format.invalid)
assert.Nil(t, err)
var cfg1 validatorConfig
err1 := Load(tmpfile1, &cfg1)
// Test with UseEnv()
tmpfile2, err := createTempFile(t, format.ext, format.invalid)
assert.Nil(t, err)
var cfg2 validatorConfig
err2 := Load(tmpfile2, &cfg2, UseEnv())
// Both should fail validation
assert.Error(t, err1, "validation should fail without UseEnv()")
assert.Error(t, err2, "validation should fail with UseEnv()")
assert.Contains(t, err1.Error(), "value must be >= 10")
assert.Contains(t, err2.Error(), "value must be >= 10")
})
t.Run("valid_"+format.ext, func(t *testing.T) {
// Test without UseEnv()
tmpfile1, err := createTempFile(t, format.ext, format.valid)
assert.Nil(t, err)
var cfg1 validatorConfig
err1 := Load(tmpfile1, &cfg1)
// Test with UseEnv()
tmpfile2, err := createTempFile(t, format.ext, format.valid)
assert.Nil(t, err)
var cfg2 validatorConfig
err2 := Load(tmpfile2, &cfg2, UseEnv())
// Both should pass validation
assert.NoError(t, err1, "validation should pass without UseEnv()")
assert.NoError(t, err2, "validation should pass with UseEnv()")
assert.Equal(t, validValue, cfg1.Value)
assert.Equal(t, validValue, cfg2.Value)
})
}
}

View File

@@ -45,7 +45,7 @@ func LoadProperties(filename string, opts ...Option) (Properties, error) {
raw := make(map[string]string)
for i := range lines {
pair := strings.SplitN(lines[i], "=", 2)
pair := strings.Split(lines[i], "=")
if len(pair) != 2 {
// invalid property format
return nil, &PropertyError{

View File

@@ -45,7 +45,8 @@ func TestPropertiesEnv(t *testing.T) {
assert.Nil(t, err)
defer os.Remove(tmpfile)
t.Setenv("FOO", "2")
os.Setenv("FOO", "2")
defer os.Unsetenv("FOO")
props, err := LoadProperties(tmpfile, UseEnv())
assert.Nil(t, err)
@@ -92,70 +93,3 @@ func TestLoadBadFile(t *testing.T) {
_, err := LoadProperties("nosuchfile")
assert.NotNil(t, err)
}
func TestProperties_valueWithEqualSymbols(t *testing.T) {
text := `# test with equal symbols in value
db.url=postgres://localhost:5432/db?param=value
math.equation=a=b=c
base64.data=SGVsbG8=World=Test=
url.with.params=http://example.com?foo=bar&baz=qux
empty.value=
key.with.space = value = with = equals`
tmpfile, err := fs.TempFilenameWithText(text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
props, err := LoadProperties(tmpfile)
assert.Nil(t, err)
assert.Equal(t, "postgres://localhost:5432/db?param=value", props.GetString("db.url"))
assert.Equal(t, "a=b=c", props.GetString("math.equation"))
assert.Equal(t, "SGVsbG8=World=Test=", props.GetString("base64.data"))
assert.Equal(t, "http://example.com?foo=bar&baz=qux", props.GetString("url.with.params"))
assert.Equal(t, "", props.GetString("empty.value"))
assert.Equal(t, "value = with = equals", props.GetString("key.with.space"))
}
func TestProperties_edgeCases(t *testing.T) {
tests := []struct {
name string
content string
wantErr bool
errMsg string
}{
{
name: "no equal sign",
content: "invalid line without equal",
wantErr: true,
},
{
name: "only equal sign",
content: "=",
wantErr: false, // "=" 会被解析为空 key 和空 valuelen(pair) == 2是合法的
},
{
name: "empty key",
content: "=value",
wantErr: false, // 空 key 也会被 trim但 len(pair) == 2 所以不会报错
},
{
name: "equal at end",
content: "key.name=",
wantErr: false, // 空 value 是合法的
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpfile, err := fs.TempFilenameWithText(tt.content)
assert.Nil(t, err)
defer os.Remove(tmpfile)
_, err = LoadProperties(tmpfile)
if tt.wantErr {
assert.NotNil(t, err, "expected error for case: %s", tt.name)
} else {
assert.Nil(t, err, "unexpected error for case: %s", tt.name)
}
})
}
}

View File

@@ -12,7 +12,7 @@ type RestfulConf struct {
MaxConns int `json:",default=10000"`
MaxBytes int64 `json:",default=1048576"`
Timeout time.Duration `json:",default=3s"`
CpuThreshold int64 `json:",default=900,range=[0:1000)"`
CpuThreshold int64 `json:",default=900,range=[0:1000]"`
}
```

View File

@@ -1,12 +0,0 @@
package conf
import "github.com/zeromicro/go-zero/core/validation"
// validate validates the value if it implements the Validator interface.
func validate(v any) error {
if val, ok := v.(validation.Validator); ok {
return val.Validate()
}
return nil
}

View File

@@ -1,81 +0,0 @@
package conf
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
type mockType int
func (m mockType) Validate() error {
if m < 10 {
return errors.New("invalid value")
}
return nil
}
type anotherMockType int
func Test_validate(t *testing.T) {
tests := []struct {
name string
v any
wantErr bool
}{
{
name: "invalid",
v: mockType(5),
wantErr: true,
},
{
name: "valid",
v: mockType(10),
wantErr: false,
},
{
name: "not validator",
v: anotherMockType(5),
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validate(tt.v)
assert.Equal(t, tt.wantErr, err != nil)
})
}
}
type mockVal struct {
}
func (m mockVal) Validate() error {
return errors.New("invalid value")
}
func Test_validateValPtr(t *testing.T) {
tests := []struct {
name string
v any
wantErr bool
}{
{
name: "invalid",
v: mockVal{},
},
{
name: "invalid value",
v: &mockVal{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Error(t, validate(tt.v))
})
}
}

View File

@@ -1,200 +0,0 @@
package configurator
import (
"errors"
"fmt"
"reflect"
"strings"
"sync"
"sync/atomic"
"github.com/zeromicro/go-zero/core/configcenter/subscriber"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/mapping"
"github.com/zeromicro/go-zero/core/threading"
)
var (
errEmptyConfig = errors.New("empty config value")
errMissingUnmarshalerType = errors.New("missing unmarshaler type")
)
// Configurator is the interface for configuration center.
type Configurator[T any] interface {
// GetConfig returns the subscription value.
GetConfig() (T, error)
// AddListener adds a listener to the subscriber.
AddListener(listener func())
}
type (
// Config is the configuration for Configurator.
Config struct {
// Type is the value type, yaml, json or toml.
Type string `json:",default=yaml,options=[yaml,json,toml]"`
// Log is the flag to control logging.
Log bool `json:",default=true"`
}
configCenter[T any] struct {
conf Config
unmarshaler LoaderFn
subscriber subscriber.Subscriber
listeners []func()
lock sync.Mutex
snapshot atomic.Value
}
value[T any] struct {
data string
marshalData T
err error
}
)
// Configurator is the interface for configuration center.
var _ Configurator[any] = (*configCenter[any])(nil)
// MustNewConfigCenter returns a Configurator, exits on errors.
func MustNewConfigCenter[T any](c Config, subscriber subscriber.Subscriber) Configurator[T] {
cc, err := NewConfigCenter[T](c, subscriber)
logx.Must(err)
return cc
}
// NewConfigCenter returns a Configurator.
func NewConfigCenter[T any](c Config, subscriber subscriber.Subscriber) (Configurator[T], error) {
unmarshaler, ok := Unmarshaler(strings.ToLower(c.Type))
if !ok {
return nil, fmt.Errorf("unknown format: %s", c.Type)
}
cc := &configCenter[T]{
conf: c,
unmarshaler: unmarshaler,
subscriber: subscriber,
}
if err := cc.loadConfig(); err != nil {
return nil, err
}
if err := cc.subscriber.AddListener(cc.onChange); err != nil {
return nil, err
}
if _, err := cc.GetConfig(); err != nil {
return nil, err
}
return cc, nil
}
// AddListener adds listener to s.
func (c *configCenter[T]) AddListener(listener func()) {
c.lock.Lock()
defer c.lock.Unlock()
c.listeners = append(c.listeners, listener)
}
// GetConfig return structured config.
func (c *configCenter[T]) GetConfig() (T, error) {
v := c.value()
if v == nil || len(v.data) == 0 {
var empty T
return empty, errEmptyConfig
}
return v.marshalData, v.err
}
// Value returns the subscription value.
func (c *configCenter[T]) Value() string {
v := c.value()
if v == nil {
return ""
}
return v.data
}
func (c *configCenter[T]) loadConfig() error {
v, err := c.subscriber.Value()
if err != nil {
if c.conf.Log {
logx.Errorf("ConfigCenter loads changed configuration, error: %v", err)
}
return err
}
if c.conf.Log {
logx.Infof("ConfigCenter loads changed configuration, content [%s]", v)
}
c.snapshot.Store(c.genValue(v))
return nil
}
func (c *configCenter[T]) onChange() {
if err := c.loadConfig(); err != nil {
return
}
c.lock.Lock()
listeners := make([]func(), len(c.listeners))
copy(listeners, c.listeners)
c.lock.Unlock()
for _, l := range listeners {
threading.GoSafe(l)
}
}
func (c *configCenter[T]) value() *value[T] {
content := c.snapshot.Load()
if content == nil {
return nil
}
return content.(*value[T])
}
func (c *configCenter[T]) genValue(data string) *value[T] {
v := &value[T]{
data: data,
}
if len(data) == 0 {
return v
}
t := reflect.TypeOf(v.marshalData)
// if the type is nil, it means that the user has not set the type of the configuration.
if t == nil {
v.err = errMissingUnmarshalerType
return v
}
t = mapping.Deref(t)
switch t.Kind() {
case reflect.Struct, reflect.Array, reflect.Slice:
if err := c.unmarshaler([]byte(data), &v.marshalData); err != nil {
v.err = err
if c.conf.Log {
logx.Errorf("ConfigCenter unmarshal configuration failed, err: %+v, content [%s]",
err.Error(), data)
}
}
case reflect.String:
if str, ok := any(data).(T); ok {
v.marshalData = str
} else {
v.err = errMissingUnmarshalerType
}
default:
if c.conf.Log {
logx.Errorf("ConfigCenter unmarshal configuration missing unmarshaler for type: %s, content [%s]",
t.Kind(), data)
}
v.err = errMissingUnmarshalerType
}
return v
}

View File

@@ -1,233 +0,0 @@
package configurator
import (
"errors"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestNewConfigCenter(t *testing.T) {
_, err := NewConfigCenter[any](Config{
Log: true,
}, &mockSubscriber{})
assert.Error(t, err)
_, err = NewConfigCenter[any](Config{
Type: "json",
Log: true,
}, &mockSubscriber{})
assert.Error(t, err)
}
func TestConfigCenter_GetConfig(t *testing.T) {
mock := &mockSubscriber{}
type Data struct {
Name string `json:"name"`
}
mock.v = `{"name": "go-zero"}`
c1, err := NewConfigCenter[Data](Config{
Type: "json",
Log: true,
}, mock)
assert.NoError(t, err)
data, err := c1.GetConfig()
assert.NoError(t, err)
assert.Equal(t, "go-zero", data.Name)
mock.v = `{"name": "111"}`
c2, err := NewConfigCenter[Data](Config{Type: "json"}, mock)
assert.NoError(t, err)
mock.v = `{}`
c3, err := NewConfigCenter[string](Config{
Type: "json",
Log: true,
}, mock)
assert.NoError(t, err)
_, err = c3.GetConfig()
assert.NoError(t, err)
data, err = c2.GetConfig()
assert.NoError(t, err)
mock.lisErr = errors.New("mock error")
_, err = NewConfigCenter[Data](Config{
Type: "json",
Log: true,
}, mock)
assert.Error(t, err)
}
func TestConfigCenter_onChange(t *testing.T) {
mock := &mockSubscriber{}
type Data struct {
Name string `json:"name"`
}
mock.v = `{"name": "go-zero"}`
c1, err := NewConfigCenter[Data](Config{Type: "json", Log: true}, mock)
assert.NoError(t, err)
data, err := c1.GetConfig()
assert.NoError(t, err)
assert.Equal(t, "go-zero", data.Name)
mock.v = `{"name": "go-zero2"}`
mock.change()
data, err = c1.GetConfig()
assert.NoError(t, err)
assert.Equal(t, "go-zero2", data.Name)
mock.valErr = errors.New("mock error")
_, err = NewConfigCenter[Data](Config{Type: "json", Log: false}, mock)
assert.Error(t, err)
}
func TestConfigCenter_Value(t *testing.T) {
mock := &mockSubscriber{}
mock.v = "1234"
c, err := NewConfigCenter[string](Config{
Type: "json",
Log: true,
}, mock)
assert.NoError(t, err)
cc := c.(*configCenter[string])
assert.Equal(t, cc.Value(), "1234")
mock.valErr = errors.New("mock error")
_, err = NewConfigCenter[any](Config{
Type: "json",
Log: true,
}, mock)
assert.Error(t, err)
}
func TestConfigCenter_AddListener(t *testing.T) {
mock := &mockSubscriber{}
mock.v = "1234"
c, err := NewConfigCenter[string](Config{
Type: "json",
Log: true,
}, mock)
assert.NoError(t, err)
cc := c.(*configCenter[string])
var a, b int
var mutex sync.Mutex
cc.AddListener(func() {
mutex.Lock()
a = 1
mutex.Unlock()
})
cc.AddListener(func() {
mutex.Lock()
b = 2
mutex.Unlock()
})
assert.Equal(t, 2, len(cc.listeners))
mock.change()
time.Sleep(time.Millisecond * 100)
mutex.Lock()
assert.Equal(t, 1, a)
assert.Equal(t, 2, b)
mutex.Unlock()
}
func TestConfigCenter_genValue(t *testing.T) {
t.Run("data is empty", func(t *testing.T) {
c := &configCenter[string]{
unmarshaler: registry.unmarshalers["json"],
conf: Config{Log: true},
}
v := c.genValue("")
assert.Equal(t, "", v.data)
})
t.Run("invalid template type", func(t *testing.T) {
c := &configCenter[any]{
unmarshaler: registry.unmarshalers["json"],
conf: Config{Log: true},
}
v := c.genValue("xxxx")
assert.Equal(t, errMissingUnmarshalerType, v.err)
})
t.Run("unsupported template type", func(t *testing.T) {
c := &configCenter[int]{
unmarshaler: registry.unmarshalers["json"],
conf: Config{Log: true},
}
v := c.genValue("1")
assert.Equal(t, errMissingUnmarshalerType, v.err)
})
t.Run("supported template string type", func(t *testing.T) {
c := &configCenter[string]{
unmarshaler: registry.unmarshalers["json"],
conf: Config{Log: true},
}
v := c.genValue("12345")
assert.NoError(t, v.err)
assert.Equal(t, "12345", v.data)
})
t.Run("unmarshal fail", func(t *testing.T) {
c := &configCenter[struct {
Name string `json:"name"`
}]{
unmarshaler: registry.unmarshalers["json"],
conf: Config{Log: true},
}
v := c.genValue(`{"name":"new name}`)
assert.Equal(t, `{"name":"new name}`, v.data)
assert.Error(t, v.err)
})
t.Run("success", func(t *testing.T) {
c := &configCenter[struct {
Name string `json:"name"`
}]{
unmarshaler: registry.unmarshalers["json"],
conf: Config{Log: true},
}
v := c.genValue(`{"name":"new name"}`)
assert.Equal(t, `{"name":"new name"}`, v.data)
assert.Equal(t, "new name", v.marshalData.Name)
assert.NoError(t, v.err)
})
}
type mockSubscriber struct {
v string
lisErr, valErr error
listener func()
}
func (m *mockSubscriber) AddListener(listener func()) error {
m.listener = listener
return m.lisErr
}
func (m *mockSubscriber) Value() (string, error) {
return m.v, m.valErr
}
func (m *mockSubscriber) change() {
if m.listener != nil {
m.listener()
}
}

View File

@@ -1,115 +0,0 @@
package subscriber
import (
"sync"
"sync/atomic"
"github.com/zeromicro/go-zero/core/discov"
"github.com/zeromicro/go-zero/core/logx"
)
type (
// etcdSubscriber is a subscriber that subscribes to etcd.
etcdSubscriber struct {
*discov.Subscriber
}
// EtcdConf is the configuration for etcd.
EtcdConf = discov.EtcdConf
)
// MustNewEtcdSubscriber returns an etcd Subscriber, exits on errors.
func MustNewEtcdSubscriber(conf EtcdConf) Subscriber {
s, err := NewEtcdSubscriber(conf)
logx.Must(err)
return s
}
// NewEtcdSubscriber returns an etcd Subscriber.
func NewEtcdSubscriber(conf EtcdConf) (Subscriber, error) {
opts := buildSubOptions(conf)
s, err := discov.NewSubscriber(conf.Hosts, conf.Key, opts...)
if err != nil {
return nil, err
}
return &etcdSubscriber{Subscriber: s}, nil
}
// buildSubOptions constructs the options for creating a new etcd subscriber.
func buildSubOptions(conf EtcdConf) []discov.SubOption {
opts := []discov.SubOption{
discov.WithExactMatch(),
discov.WithContainer(newContainer()),
}
if len(conf.User) > 0 {
opts = append(opts, discov.WithSubEtcdAccount(conf.User, conf.Pass))
}
if len(conf.CertFile) > 0 || len(conf.CertKeyFile) > 0 || len(conf.CACertFile) > 0 {
opts = append(opts, discov.WithSubEtcdTLS(conf.CertFile, conf.CertKeyFile,
conf.CACertFile, conf.InsecureSkipVerify))
}
return opts
}
// AddListener adds a listener to the subscriber.
func (s *etcdSubscriber) AddListener(listener func()) error {
s.Subscriber.AddListener(listener)
return nil
}
// Value returns the value of the subscriber.
func (s *etcdSubscriber) Value() (string, error) {
vs := s.Subscriber.Values()
if len(vs) > 0 {
return vs[len(vs)-1], nil
}
return "", nil
}
type container struct {
value atomic.Value
listeners []func()
lock sync.Mutex
}
func newContainer() *container {
return &container{}
}
func (c *container) OnAdd(kv discov.KV) {
c.value.Store([]string{kv.Val})
c.notifyChange()
}
func (c *container) OnDelete(_ discov.KV) {
c.value.Store([]string(nil))
c.notifyChange()
}
func (c *container) AddListener(listener func()) {
c.lock.Lock()
c.listeners = append(c.listeners, listener)
c.lock.Unlock()
}
func (c *container) GetValues() []string {
if vals, ok := c.value.Load().([]string); ok {
return vals
}
return []string(nil)
}
func (c *container) notifyChange() {
c.lock.Lock()
listeners := append(([]func())(nil), c.listeners...)
c.lock.Unlock()
for _, listener := range listeners {
listener()
}
}

View File

@@ -1,186 +0,0 @@
package subscriber
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/discov"
)
const (
actionAdd = iota
actionDel
)
func TestConfigCenterContainer(t *testing.T) {
type action struct {
act int
key string
val string
}
tests := []struct {
name string
do []action
expect []string
}{
{
name: "add one",
do: []action{
{
act: actionAdd,
key: "first",
val: "a",
},
},
expect: []string{
"a",
},
},
{
name: "add two",
do: []action{
{
act: actionAdd,
key: "first",
val: "a",
},
{
act: actionAdd,
key: "second",
val: "b",
},
},
expect: []string{
"b",
},
},
{
name: "add two, delete one",
do: []action{
{
act: actionAdd,
key: "first",
val: "a",
},
{
act: actionAdd,
key: "second",
val: "b",
},
{
act: actionDel,
key: "first",
},
},
expect: []string(nil),
},
{
name: "add two, delete two",
do: []action{
{
act: actionAdd,
key: "first",
val: "a",
},
{
act: actionAdd,
key: "second",
val: "b",
},
{
act: actionDel,
key: "first",
},
{
act: actionDel,
key: "second",
},
},
expect: []string(nil),
},
{
name: "add two, dup values",
do: []action{
{
act: actionAdd,
key: "first",
val: "a",
},
{
act: actionAdd,
key: "second",
val: "b",
},
{
act: actionAdd,
key: "third",
val: "a",
},
},
expect: []string{"a"},
},
{
name: "add three, dup values, delete two, add one",
do: []action{
{
act: actionAdd,
key: "first",
val: "a",
},
{
act: actionAdd,
key: "second",
val: "b",
},
{
act: actionAdd,
key: "third",
val: "a",
},
{
act: actionDel,
key: "first",
},
{
act: actionDel,
key: "second",
},
{
act: actionAdd,
key: "forth",
val: "c",
},
},
expect: []string{"c"},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var changed bool
c := newContainer()
c.AddListener(func() {
changed = true
})
assert.Nil(t, c.GetValues())
assert.False(t, changed)
for _, order := range test.do {
if order.act == actionAdd {
c.OnAdd(discov.KV{
Key: order.key,
Val: order.val,
})
} else {
c.OnDelete(discov.KV{
Key: order.key,
Val: order.val,
})
}
}
assert.True(t, changed)
assert.ElementsMatch(t, test.expect, c.GetValues())
})
}
}

View File

@@ -1,9 +0,0 @@
package subscriber
// Subscriber is the interface for configcenter subscribers.
type Subscriber interface {
// AddListener adds a listener to the subscriber.
AddListener(listener func()) error
// Value returns the value of the subscriber.
Value() (string, error)
}

View File

@@ -1,41 +0,0 @@
package configurator
import (
"sync"
"github.com/zeromicro/go-zero/core/conf"
)
var registry = &unmarshalerRegistry{
unmarshalers: map[string]LoaderFn{
"json": conf.LoadFromJsonBytes,
"toml": conf.LoadFromTomlBytes,
"yaml": conf.LoadFromYamlBytes,
},
}
type (
// LoaderFn is the function type for loading configuration.
LoaderFn func([]byte, any) error
// unmarshalerRegistry is the registry for unmarshalers.
unmarshalerRegistry struct {
unmarshalers map[string]LoaderFn
mu sync.RWMutex
}
)
// RegisterUnmarshaler registers an unmarshaler.
func RegisterUnmarshaler(name string, fn LoaderFn) {
registry.mu.Lock()
defer registry.mu.Unlock()
registry.unmarshalers[name] = fn
}
// Unmarshaler returns the unmarshaler by name.
func Unmarshaler(name string) (LoaderFn, bool) {
registry.mu.RLock()
defer registry.mu.RUnlock()
fn, ok := registry.unmarshalers[name]
return fn, ok
}

View File

@@ -1,28 +0,0 @@
package configurator
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestRegisterUnmarshaler(t *testing.T) {
RegisterUnmarshaler("test", func(data []byte, v interface{}) error {
return nil
})
_, ok := Unmarshaler("test")
assert.True(t, ok)
_, ok = Unmarshaler("test2")
assert.False(t, ok)
_, ok = Unmarshaler("json")
assert.True(t, ok)
_, ok = Unmarshaler("toml")
assert.True(t, ok)
_, ok = Unmarshaler("yaml")
assert.True(t, ok)
}

View File

@@ -14,13 +14,13 @@ type contextValuer struct {
context.Context
}
func (cv contextValuer) Value(key string) (any, bool) {
func (cv contextValuer) Value(key string) (interface{}, bool) {
v := cv.Context.Value(key)
return v, v != nil
}
// For unmarshals ctx into v.
func For(ctx context.Context, v any) error {
func For(ctx context.Context, v interface{}) error {
return unmarshaler.UnmarshalValuer(contextValuer{
Context: ctx,
}, v)

View File

@@ -13,7 +13,6 @@ var (
type EtcdConf struct {
Hosts []string
Key string
ID int64 `json:",optional"`
User string `json:",optional"`
Pass string `json:",optional"`
CertFile string `json:",optional"`
@@ -27,11 +26,6 @@ func (c EtcdConf) HasAccount() bool {
return len(c.User) > 0 && len(c.Pass) > 0
}
// HasID returns if ID provided.
func (c EtcdConf) HasID() bool {
return c.ID > 0
}
// HasTLS returns if TLS CertFile/CertKeyFile/CACertFile are provided.
func (c EtcdConf) HasTLS() bool {
return len(c.CertFile) > 0 && len(c.CertKeyFile) > 0 && len(c.CACertFile) > 0

View File

@@ -80,90 +80,3 @@ func TestEtcdConf_HasAccount(t *testing.T) {
assert.Equal(t, test.hasAccount, test.EtcdConf.HasAccount())
}
}
func TestEtcdConf_HasID(t *testing.T) {
tests := []struct {
EtcdConf
hasServerID bool
}{
{
EtcdConf: EtcdConf{
Hosts: []string{"any"},
ID: -1,
},
hasServerID: false,
},
{
EtcdConf: EtcdConf{
Hosts: []string{"any"},
ID: 0,
},
hasServerID: false,
},
{
EtcdConf: EtcdConf{
Hosts: []string{"any"},
ID: 10000,
},
hasServerID: true,
},
}
for _, test := range tests {
assert.Equal(t, test.hasServerID, test.EtcdConf.HasID())
}
}
func TestEtcdConf_HasTLS(t *testing.T) {
tests := []struct {
name string
conf EtcdConf
want bool
}{
{
name: "empty config",
conf: EtcdConf{},
want: false,
},
{
name: "missing CertFile",
conf: EtcdConf{
CertKeyFile: "key",
CACertFile: "ca",
},
want: false,
},
{
name: "missing CertKeyFile",
conf: EtcdConf{
CertFile: "cert",
CACertFile: "ca",
},
want: false,
},
{
name: "missing CACertFile",
conf: EtcdConf{
CertFile: "cert",
CertKeyFile: "key",
},
want: false,
},
{
name: "valid config",
conf: EtcdConf{
CertFile: "cert",
CertKeyFile: "key",
CACertFile: "ca",
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.conf.HasTLS()
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -1,85 +1,12 @@
package internal
import (
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stringx"
)
const (
certContent = `-----BEGIN CERTIFICATE-----
MIIDazCCAlOgAwIBAgIUEg9GVO2oaPn+YSmiqmFIuAo10WIwDQYJKoZIhvcNAQEM
BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAgFw0yMzAzMTExMzIxMjNaGA8yMTIz
MDIxNTEzMjEyM1owRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx
ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcN
AQEBBQADggEPADCCAQoCggEBALplXlWsIf0O/IgnIplmiZHKGnxyfyufyE2FBRNk
OofRqbKuPH8GNqbkvZm7N29fwTDAQ+mViAggCkDht4hOzoWJMA7KYJt8JnTSWL48
M1lcrpc9DL2gszC/JF/FGvyANbBtLklkZPFBGdHUX14pjrT937wqPtm+SqUHSvRT
B7bmwmm2drRcmhpVm98LSlV7uQ2EgnJgsLjBPITKUejLmVLHfgX0RwQ2xIpX9pS4
FCe1BTacwl2gGp7Mje7y4Mfv3o0ArJW6Tuwbjx59ZXwb1KIP71b7bT04AVS8ZeYO
UMLKKuB5UR9x9Rn6cLXOTWBpcMVyzDgrAFLZjnE9LPUolZMCAwEAAaNRME8wHwYD
VR0jBBgwFoAUeW8w8pmhncbRgTsl48k4/7wnfx8wCQYDVR0TBAIwADALBgNVHQ8E
BAMCBPAwFAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBDAUAA4IBAQAI
y9xaoS88CLPBsX6mxfcTAFVfGNTRW9VN9Ng1cCnUR+YGoXGM/l+qP4f7p8ocdGwK
iYZErVTzXYIn+D27//wpY3klJk3gAnEUBT3QRkStBw7XnpbeZ2oPBK+cmDnCnZPS
BIF1wxPX7vIgaxs5Zsdqwk3qvZ4Djr2wP7LabNWTLSBKgQoUY45Liw6pffLwcGF9
UKlu54bvGze2SufISCR3ib+I+FLvqpvJhXToZWYb/pfI/HccuCL1oot1x8vx6DQy
U+TYxlZsKS5mdNxAX3dqEkEMsgEi+g/tzDPXJImfeCGGBhIOXLm8SRypiuGdEbc9
xkWYxRPegajuEZGvCqVs
-----END CERTIFICATE-----`
keyContent = `-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAumVeVawh/Q78iCcimWaJkcoafHJ/K5/ITYUFE2Q6h9Gpsq48
fwY2puS9mbs3b1/BMMBD6ZWICCAKQOG3iE7OhYkwDspgm3wmdNJYvjwzWVyulz0M
vaCzML8kX8Ua/IA1sG0uSWRk8UEZ0dRfXimOtP3fvCo+2b5KpQdK9FMHtubCabZ2
tFyaGlWb3wtKVXu5DYSCcmCwuME8hMpR6MuZUsd+BfRHBDbEilf2lLgUJ7UFNpzC
XaAansyN7vLgx+/ejQCslbpO7BuPHn1lfBvUog/vVvttPTgBVLxl5g5Qwsoq4HlR
H3H1Gfpwtc5NYGlwxXLMOCsAUtmOcT0s9SiVkwIDAQABAoIBAD5meTJNMgO55Kjg
ESExxpRcCIno+tHr5+6rvYtEXqPheOIsmmwb9Gfi4+Z3WpOaht5/Pz0Ppj6yGzyl
U//6AgGKb+BDuBvVcDpjwPnOxZIBCSHwejdxeQu0scSuA97MPS0XIAvJ5FEv7ijk
5Bht6SyGYURpECltHygoTNuGgGqmO+McCJRLE9L09lTBI6UQ/JQwWJqSr7wx6iPU
M1Ze/srIV+7cyEPu6i0DGjS1gSQKkX68Lqn1w6oE290O+OZvleO0gZ02fLDWCZke
aeD9+EU/Pw+rqm3H6o0szOFIpzhRp41FUdW9sybB3Yp3u7c/574E+04Z/e30LMKs
TCtE1QECgYEA3K7KIpw0NH2HXL5C3RHcLmr204xeBfS70riBQQuVUgYdmxak2ima
80RInskY8hRhSGTg0l+VYIH8cmjcUyqMSOELS5XfRH99r4QPiK8AguXg80T4VumY
W3Pf+zEC2ssgP/gYthV0g0Xj5m2QxktOF9tRw5nkg739ZR4dI9lm/iECgYEA2Dnf
uwEDGqHiQRF6/fh5BG/nGVMvrefkqx6WvTJQ3k/M/9WhxB+lr/8yH46TuS8N2b29
FoTf3Mr9T7pr/PWkOPzoY3P56nYbKU8xSwCim9xMzhBMzj8/N9ukJvXy27/VOz56
eQaKqnvdXNGtPJrIMDGHps2KKWlKLyAlapzjVTMCgYAA/W++tACv85g13EykfT4F
n0k4LbsGP9DP4zABQLIMyiY72eAncmRVjwrcW36XJ2xATOONTgx3gF3HjZzfaqNy
eD/6uNNllUTVEryXGmHgNHPL45VRnn6memCY2eFvZdXhM5W4y2PYaunY0MkDercA
+GTngbs6tBF88KOk04bYwQKBgFl68cRgsdkmnwwQYNaTKfmVGYzYaQXNzkqmWPko
xmCJo6tHzC7ubdG8iRCYHzfmahPuuj6EdGPZuSRyYFgJi5Ftz/nAN+84OxtIQ3zn
YWOgskQgaLh9YfsKsQ7Sf1NDOsnOnD5TX7UXl07fEpLe9vNCvAFiU8e5Y9LGudU5
4bYTAoGBAMdX3a3bXp4cZvXNBJ/QLVyxC6fP1Q4haCR1Od3m+T00Jth2IX2dk/fl
p6xiJT1av5JtYabv1dFKaXOS5s1kLGGuCCSKpkvFZm826aQ2AFm0XGqEQDLeei5b
A52Kpy/YJ+RkG4BTFtAooFq6DmA0cnoP6oPvG2h6XtDJwDTPInJb
-----END RSA PRIVATE KEY-----`
caContent = `-----BEGIN CERTIFICATE-----
MIIDbTCCAlWgAwIBAgIUBJvFoCowKich7MMfseJ+DYzzirowDQYJKoZIhvcNAQEM
BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAgFw0yMzAzMTExMzIxMDNaGA8yMTIz
MDIxNTEzMjEwM1owRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx
ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcN
AQEBBQADggEPADCCAQoCggEBAO4to2YMYj0bxgr2FCiweSTSFuPx33zSw2x/s9Wf
OR41bm2DFsyYT5f3sOIKlXZEdLmOKty2e3ho3yC0EyNpVHdykkkHT3aDI17quZax
kYi/URqqtl1Z08A22txolc04hAZisg2BypGi3vql81UW1t3zyloGnJoIAeXR9uca
ljP6Bk3bwsxoVBLi1JtHrO0hHLQaeHmKhAyrys06X0LRdn7Px48yRZlt6FaLSa8X
YiRM0G44bVy/h6BkoQjMYGwVmCVk6zjJ9U7ZPFqdnDMNxAfR+hjDnYodqdLDMTTR
1NPVrnEnNwFx0AMLvgt/ba/45vZCEAmSZnFXFAJJcM7ai9ECAwEAAaNTMFEwHQYD
VR0OBBYEFHlvMPKZoZ3G0YE7JePJOP+8J38fMB8GA1UdIwQYMBaAFHlvMPKZoZ3G
0YE7JePJOP+8J38fMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQEMBQADggEB
AMX8dNulADOo9uQgBMyFb9TVra7iY0zZjzv4GY5XY7scd52n6CnfAPvYBBDnTr/O
BgNp5jaujb4+9u/2qhV3f9n+/3WOb2CmPehBgVSzlXqHeQ9lshmgwZPeem2T+8Tm
Nnc/xQnsUfCFszUDxpkr55+aLVM22j02RWqcZ4q7TAaVYL+kdFVMc8FoqG/0ro6A
BjE/Qn0Nn7ciX1VUjDt8l+k7ummPJTmzdi6i6E4AwO9dzrGNgGJ4aWL8cC6xYcIX
goVIRTFeONXSDno/oPjWHpIPt7L15heMpKBHNuzPkKx2YVqPHE5QZxWfS+Lzgx+Q
E2oTTM0rYKOZ8p6000mhvKI=
-----END CERTIFICATE-----`
)
func TestAccount(t *testing.T) {
endpoints := []string{
"192.168.0.2:2379",
@@ -105,34 +32,3 @@ func TestAccount(t *testing.T) {
assert.Equal(t, username, account.User)
assert.Equal(t, anotherPassword, account.Pass)
}
func TestTLSMethods(t *testing.T) {
certFile := createTempFile(t, []byte(certContent))
defer os.Remove(certFile)
keyFile := createTempFile(t, []byte(keyContent))
defer os.Remove(keyFile)
caFile := createTempFile(t, []byte(caContent))
defer os.Remove(caFile)
assert.NoError(t, AddTLS([]string{"foo"}, certFile, keyFile, caFile, false))
cfg, ok := GetTLS([]string{"foo"})
assert.True(t, ok)
assert.NotNil(t, cfg)
assert.Error(t, AddTLS([]string{"bar"}, "bad-file", keyFile, caFile, false))
assert.Error(t, AddTLS([]string{"bar"}, certFile, keyFile, "bad-file", false))
}
func createTempFile(t *testing.T, body []byte) string {
tmpFile, err := os.CreateTemp(os.TempDir(), "go-unit-*.tmp")
if err != nil {
t.Fatal(err)
}
tmpFile.Close()
if err = os.WriteFile(tmpFile.Name(), body, os.ModePerm); err != nil {
t.Fatal(err)
}
return tmpFile.Name()
}

View File

@@ -1,10 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: etcdclient.go
//
// Generated by this command:
//
// mockgen -package internal -destination etcdclient_mock.go -source etcdclient.go EtcdClient
//
// Package internal is a generated GoMock package.
package internal
@@ -13,36 +8,35 @@ import (
context "context"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
clientv3 "go.etcd.io/etcd/client/v3"
gomock "go.uber.org/mock/gomock"
grpc "google.golang.org/grpc"
)
// MockEtcdClient is a mock of EtcdClient interface.
// MockEtcdClient is a mock of EtcdClient interface
type MockEtcdClient struct {
ctrl *gomock.Controller
recorder *MockEtcdClientMockRecorder
isgomock struct{}
}
// MockEtcdClientMockRecorder is the mock recorder for MockEtcdClient.
// MockEtcdClientMockRecorder is the mock recorder for MockEtcdClient
type MockEtcdClientMockRecorder struct {
mock *MockEtcdClient
}
// NewMockEtcdClient creates a new mock instance.
// NewMockEtcdClient creates a new mock instance
func NewMockEtcdClient(ctrl *gomock.Controller) *MockEtcdClient {
mock := &MockEtcdClient{ctrl: ctrl}
mock.recorder = &MockEtcdClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockEtcdClient) EXPECT() *MockEtcdClientMockRecorder {
return m.recorder
}
// ActiveConnection mocks base method.
// ActiveConnection mocks base method
func (m *MockEtcdClient) ActiveConnection() *grpc.ClientConn {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ActiveConnection")
@@ -50,13 +44,13 @@ func (m *MockEtcdClient) ActiveConnection() *grpc.ClientConn {
return ret0
}
// ActiveConnection indicates an expected call of ActiveConnection.
// ActiveConnection indicates an expected call of ActiveConnection
func (mr *MockEtcdClientMockRecorder) ActiveConnection() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActiveConnection", reflect.TypeOf((*MockEtcdClient)(nil).ActiveConnection))
}
// Close mocks base method.
// Close mocks base method
func (m *MockEtcdClient) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
@@ -64,13 +58,13 @@ func (m *MockEtcdClient) Close() error {
return ret0
}
// Close indicates an expected call of Close.
// Close indicates an expected call of Close
func (mr *MockEtcdClientMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockEtcdClient)(nil).Close))
}
// Ctx mocks base method.
// Ctx mocks base method
func (m *MockEtcdClient) Ctx() context.Context {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Ctx")
@@ -78,16 +72,16 @@ func (m *MockEtcdClient) Ctx() context.Context {
return ret0
}
// Ctx indicates an expected call of Ctx.
// Ctx indicates an expected call of Ctx
func (mr *MockEtcdClientMockRecorder) Ctx() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ctx", reflect.TypeOf((*MockEtcdClient)(nil).Ctx))
}
// Get mocks base method.
// Get mocks base method
func (m *MockEtcdClient) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, key}
varargs := []interface{}{ctx, key}
for _, a := range opts {
varargs = append(varargs, a)
}
@@ -97,14 +91,14 @@ func (m *MockEtcdClient) Get(ctx context.Context, key string, opts ...clientv3.O
return ret0, ret1
}
// Get indicates an expected call of Get.
func (mr *MockEtcdClientMockRecorder) Get(ctx, key any, opts ...any) *gomock.Call {
// Get indicates an expected call of Get
func (mr *MockEtcdClientMockRecorder) Get(ctx, key interface{}, opts ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, key}, opts...)
varargs := append([]interface{}{ctx, key}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockEtcdClient)(nil).Get), varargs...)
}
// Grant mocks base method.
// Grant mocks base method
func (m *MockEtcdClient) Grant(ctx context.Context, ttl int64) (*clientv3.LeaseGrantResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Grant", ctx, ttl)
@@ -113,13 +107,13 @@ func (m *MockEtcdClient) Grant(ctx context.Context, ttl int64) (*clientv3.LeaseG
return ret0, ret1
}
// Grant indicates an expected call of Grant.
func (mr *MockEtcdClientMockRecorder) Grant(ctx, ttl any) *gomock.Call {
// Grant indicates an expected call of Grant
func (mr *MockEtcdClientMockRecorder) Grant(ctx, ttl interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Grant", reflect.TypeOf((*MockEtcdClient)(nil).Grant), ctx, ttl)
}
// KeepAlive mocks base method.
// KeepAlive mocks base method
func (m *MockEtcdClient) KeepAlive(ctx context.Context, id clientv3.LeaseID) (<-chan *clientv3.LeaseKeepAliveResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "KeepAlive", ctx, id)
@@ -128,16 +122,16 @@ func (m *MockEtcdClient) KeepAlive(ctx context.Context, id clientv3.LeaseID) (<-
return ret0, ret1
}
// KeepAlive indicates an expected call of KeepAlive.
func (mr *MockEtcdClientMockRecorder) KeepAlive(ctx, id any) *gomock.Call {
// KeepAlive indicates an expected call of KeepAlive
func (mr *MockEtcdClientMockRecorder) KeepAlive(ctx, id interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeepAlive", reflect.TypeOf((*MockEtcdClient)(nil).KeepAlive), ctx, id)
}
// Put mocks base method.
// Put mocks base method
func (m *MockEtcdClient) Put(ctx context.Context, key, val string, opts ...clientv3.OpOption) (*clientv3.PutResponse, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, key, val}
varargs := []interface{}{ctx, key, val}
for _, a := range opts {
varargs = append(varargs, a)
}
@@ -147,14 +141,14 @@ func (m *MockEtcdClient) Put(ctx context.Context, key, val string, opts ...clien
return ret0, ret1
}
// Put indicates an expected call of Put.
func (mr *MockEtcdClientMockRecorder) Put(ctx, key, val any, opts ...any) *gomock.Call {
// Put indicates an expected call of Put
func (mr *MockEtcdClientMockRecorder) Put(ctx, key, val interface{}, opts ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, key, val}, opts...)
varargs := append([]interface{}{ctx, key, val}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockEtcdClient)(nil).Put), varargs...)
}
// Revoke mocks base method.
// Revoke mocks base method
func (m *MockEtcdClient) Revoke(ctx context.Context, id clientv3.LeaseID) (*clientv3.LeaseRevokeResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Revoke", ctx, id)
@@ -163,16 +157,16 @@ func (m *MockEtcdClient) Revoke(ctx context.Context, id clientv3.LeaseID) (*clie
return ret0, ret1
}
// Revoke indicates an expected call of Revoke.
func (mr *MockEtcdClientMockRecorder) Revoke(ctx, id any) *gomock.Call {
// Revoke indicates an expected call of Revoke
func (mr *MockEtcdClientMockRecorder) Revoke(ctx, id interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Revoke", reflect.TypeOf((*MockEtcdClient)(nil).Revoke), ctx, id)
}
// Watch mocks base method.
// Watch mocks base method
func (m *MockEtcdClient) Watch(ctx context.Context, key string, opts ...clientv3.OpOption) clientv3.WatchChan {
m.ctrl.T.Helper()
varargs := []any{ctx, key}
varargs := []interface{}{ctx, key}
for _, a := range opts {
varargs = append(varargs, a)
}
@@ -181,9 +175,9 @@ func (m *MockEtcdClient) Watch(ctx context.Context, key string, opts ...clientv3
return ret0
}
// Watch indicates an expected call of Watch.
func (mr *MockEtcdClientMockRecorder) Watch(ctx, key any, opts ...any) *gomock.Call {
// Watch indicates an expected call of Watch
func (mr *MockEtcdClientMockRecorder) Watch(ctx, key interface{}, opts ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, key}, opts...)
varargs := append([]interface{}{ctx, key}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockEtcdClient)(nil).Watch), varargs...)
}

View File

@@ -2,7 +2,6 @@ package internal
import (
"context"
"errors"
"fmt"
"io"
"sort"
@@ -10,30 +9,25 @@ import (
"sync"
"time"
"github.com/zeromicro/go-zero/core/contextx"
"github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/logc"
"github.com/zeromicro/go-zero/core/mathx"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/syncx"
"github.com/zeromicro/go-zero/core/threading"
"go.etcd.io/etcd/api/v3/v3rpc/rpctypes"
clientv3 "go.etcd.io/etcd/client/v3"
)
const coolDownDeviation = 0.05
var (
registry = Registry{
clusters: make(map[string]*cluster),
}
connManager = syncx.NewResourceManager()
coolDownUnstable = mathx.NewUnstable(coolDownDeviation)
errClosed = errors.New("etcd monitor chan has been closed")
connManager = syncx.NewResourceManager()
)
// A Registry is a registry that manages the etcd client connections.
type Registry struct {
clusters map[string]*cluster
lock sync.RWMutex
lock sync.Mutex
}
// GetRegistry returns a global Registry.
@@ -43,148 +37,60 @@ func GetRegistry() *Registry {
// GetConn returns an etcd client connection associated with given endpoints.
func (r *Registry) GetConn(endpoints []string) (EtcdClient, error) {
c, _ := r.getOrCreateCluster(endpoints)
c, _ := r.getCluster(endpoints)
return c.getClient()
}
// Monitor monitors the key on given etcd endpoints, notify with the given UpdateListener.
func (r *Registry) Monitor(endpoints []string, key string, exactMatch bool, l UpdateListener) error {
wkey := watchKey{
key: key,
exactMatch: exactMatch,
}
c, exists := r.getOrCreateCluster(endpoints)
func (r *Registry) Monitor(endpoints []string, key string, l UpdateListener) error {
c, exists := r.getCluster(endpoints)
// if exists, the existing values should be updated to the listener.
if exists {
c.lock.Lock()
watcher, ok := c.watchers[wkey]
if ok {
watcher.listeners = append(watcher.listeners, l)
}
c.lock.Unlock()
if ok {
kvs := c.getCurrent(wkey)
for _, kv := range kvs {
l.OnAdd(kv)
}
return nil
kvs := c.getCurrent(key)
for _, kv := range kvs {
l.OnAdd(kv)
}
}
return c.monitor(wkey, l)
return c.monitor(key, l)
}
func (r *Registry) Unmonitor(endpoints []string, key string, exactMatch bool, l UpdateListener) {
c, exists := r.getCluster(endpoints)
if !exists {
return
}
wkey := watchKey{
key: key,
exactMatch: exactMatch,
}
c.lock.Lock()
defer c.lock.Unlock()
watcher, ok := c.watchers[wkey]
if !ok {
return
}
for i, listener := range watcher.listeners {
if listener == l {
watcher.listeners = append(watcher.listeners[:i], watcher.listeners[i+1:]...)
break
}
}
if len(watcher.listeners) == 0 {
if watcher.cancel != nil {
watcher.cancel()
}
delete(c.watchers, wkey)
}
}
func (r *Registry) getCluster(endpoints []string) (*cluster, bool) {
func (r *Registry) getCluster(endpoints []string) (c *cluster, exists bool) {
clusterKey := getClusterKey(endpoints)
r.lock.RLock()
c, ok := r.clusters[clusterKey]
r.lock.RUnlock()
return c, ok
}
func (r *Registry) getOrCreateCluster(endpoints []string) (c *cluster, exists bool) {
c, exists = r.getCluster(endpoints)
r.lock.Lock()
defer r.lock.Unlock()
c, exists = r.clusters[clusterKey]
if !exists {
clusterKey := getClusterKey(endpoints)
r.lock.Lock()
defer r.lock.Unlock()
// double-check locking
c, exists = r.clusters[clusterKey]
if !exists {
c = newCluster(endpoints)
r.clusters[clusterKey] = c
}
c = newCluster(endpoints)
r.clusters[clusterKey] = c
}
return
}
type (
watchKey struct {
key string
exactMatch bool
}
watchValue struct {
listeners []UpdateListener
values map[string]string
cancel context.CancelFunc
}
cluster struct {
endpoints []string
key string
watchers map[watchKey]*watchValue
watchGroup *threading.RoutineGroup
done chan lang.PlaceholderType
lock sync.RWMutex
}
)
type cluster struct {
endpoints []string
key string
values map[string]map[string]string
listeners map[string][]UpdateListener
watchGroup *threading.RoutineGroup
done chan lang.PlaceholderType
lock sync.Mutex
}
func newCluster(endpoints []string) *cluster {
return &cluster{
endpoints: endpoints,
key: getClusterKey(endpoints),
watchers: make(map[watchKey]*watchValue),
values: make(map[string]map[string]string),
listeners: make(map[string][]UpdateListener),
watchGroup: threading.NewRoutineGroup(),
done: make(chan lang.PlaceholderType),
}
}
func (c *cluster) addListener(key watchKey, l UpdateListener) {
c.lock.Lock()
defer c.lock.Unlock()
watcher, ok := c.watchers[key]
if ok {
watcher.listeners = append(watcher.listeners, l)
return
}
val := newWatchValue()
val.listeners = []UpdateListener{l}
c.watchers[key] = val
func (c *cluster) context(cli EtcdClient) context.Context {
return contextx.ValueOnlyFrom(cli.Ctx())
}
func (c *cluster) getClient() (EtcdClient, error) {
@@ -198,17 +104,12 @@ func (c *cluster) getClient() (EtcdClient, error) {
return val.(EtcdClient), nil
}
func (c *cluster) getCurrent(key watchKey) []KV {
c.lock.RLock()
defer c.lock.RUnlock()
func (c *cluster) getCurrent(key string) []KV {
c.lock.Lock()
defer c.lock.Unlock()
watcher, ok := c.watchers[key]
if !ok {
return nil
}
kvs := make([]KV, 0, len(watcher.values))
for k, v := range watcher.values {
var kvs []KV
for k, v := range c.values[key] {
kvs = append(kvs, KV{
Key: k,
Val: v,
@@ -218,23 +119,42 @@ func (c *cluster) getCurrent(key watchKey) []KV {
return kvs
}
func (c *cluster) handleChanges(key watchKey, kvs []KV) {
func (c *cluster) handleChanges(key string, kvs []KV) {
var add []KV
var remove []KV
c.lock.Lock()
watcher, ok := c.watchers[key]
listeners := append([]UpdateListener(nil), c.listeners[key]...)
vals, ok := c.values[key]
if !ok {
c.lock.Unlock()
return
add = kvs
vals = make(map[string]string)
for _, kv := range kvs {
vals[kv.Key] = kv.Val
}
c.values[key] = vals
} else {
m := make(map[string]string)
for _, kv := range kvs {
m[kv.Key] = kv.Val
}
for k, v := range vals {
if val, ok := m[k]; !ok || v != val {
remove = append(remove, KV{
Key: k,
Val: v,
})
}
}
for k, v := range m {
if val, ok := vals[k]; !ok || v != val {
add = append(add, KV{
Key: k,
Val: v,
})
}
}
c.values[key] = m
}
listeners := append([]UpdateListener(nil), watcher.listeners...)
// watcher.values cannot be nil
vals := watcher.values
newVals := make(map[string]string, len(kvs)+len(vals))
for _, kv := range kvs {
newVals[kv.Key] = kv.Val
}
add, remove := calculateChanges(vals, newVals)
watcher.values = newVals
c.lock.Unlock()
for _, kv := range add {
@@ -249,22 +169,20 @@ func (c *cluster) handleChanges(key watchKey, kvs []KV) {
}
}
func (c *cluster) handleWatchEvents(ctx context.Context, key watchKey, events []*clientv3.Event) {
c.lock.RLock()
watcher, ok := c.watchers[key]
if !ok {
c.lock.RUnlock()
return
}
listeners := append([]UpdateListener(nil), watcher.listeners...)
c.lock.RUnlock()
func (c *cluster) handleWatchEvents(key string, events []*clientv3.Event) {
c.lock.Lock()
listeners := append([]UpdateListener(nil), c.listeners[key]...)
c.lock.Unlock()
for _, ev := range events {
switch ev.Type {
case clientv3.EventTypePut:
c.lock.Lock()
watcher.values[string(ev.Kv.Key)] = string(ev.Kv.Value)
if vals, ok := c.values[key]; ok {
vals[string(ev.Kv.Key)] = string(ev.Kv.Value)
} else {
c.values[key] = map[string]string{string(ev.Kv.Key): string(ev.Kv.Value)}
}
c.lock.Unlock()
for _, l := range listeners {
l.OnAdd(KV{
@@ -274,7 +192,9 @@ func (c *cluster) handleWatchEvents(ctx context.Context, key watchKey, events []
}
case clientv3.EventTypeDelete:
c.lock.Lock()
delete(watcher.values, string(ev.Kv.Key))
if vals, ok := c.values[key]; ok {
delete(vals, string(ev.Kv.Key))
}
c.lock.Unlock()
for _, l := range listeners {
l.OnDelete(KV{
@@ -283,32 +203,27 @@ func (c *cluster) handleWatchEvents(ctx context.Context, key watchKey, events []
})
}
default:
logc.Errorf(ctx, "Unknown event type: %v", ev.Type)
logx.Errorf("Unknown event type: %v", ev.Type)
}
}
}
func (c *cluster) load(cli EtcdClient, key watchKey) int64 {
func (c *cluster) load(cli EtcdClient, key string) int64 {
var resp *clientv3.GetResponse
for {
var err error
ctx, cancel := context.WithTimeout(cli.Ctx(), RequestTimeout)
if key.exactMatch {
resp, err = cli.Get(ctx, key.key)
} else {
resp, err = cli.Get(ctx, makeKeyPrefix(key.key), clientv3.WithPrefix())
}
ctx, cancel := context.WithTimeout(c.context(cli), RequestTimeout)
resp, err = cli.Get(ctx, makeKeyPrefix(key), clientv3.WithPrefix())
cancel()
if err == nil {
break
}
logc.Errorf(cli.Ctx(), "%s, key: %s, exactMatch: %t", err.Error(), key.key, key.exactMatch)
time.Sleep(coolDownUnstable.AroundDuration(coolDownInterval))
logx.Error(err)
time.Sleep(coolDownInterval)
}
kvs := make([]KV, 0, len(resp.Kvs))
var kvs []KV
for _, ev := range resp.Kvs {
kvs = append(kvs, KV{
Key: string(ev.Key),
@@ -321,13 +236,16 @@ func (c *cluster) load(cli EtcdClient, key watchKey) int64 {
return resp.Header.Revision
}
func (c *cluster) monitor(key watchKey, l UpdateListener) error {
func (c *cluster) monitor(key string, l UpdateListener) error {
c.lock.Lock()
c.listeners[key] = append(c.listeners[key], l)
c.lock.Unlock()
cli, err := c.getClient()
if err != nil {
return err
}
c.addListener(key, l)
rev := c.load(cli, key)
c.watchGroup.Run(func() {
c.watch(cli, key, rev)
@@ -349,22 +267,16 @@ func (c *cluster) newClient() (EtcdClient, error) {
func (c *cluster) reload(cli EtcdClient) {
c.lock.Lock()
// cancel the previous watches
close(c.done)
c.watchGroup.Wait()
keys := make([]watchKey, 0, len(c.watchers))
for wk, wval := range c.watchers {
keys = append(keys, wk)
if wval.cancel != nil {
wval.cancel()
}
}
c.done = make(chan lang.PlaceholderType)
c.watchGroup = threading.NewRoutineGroup()
var keys []string
for k := range c.listeners {
keys = append(keys, k)
}
c.lock.Unlock()
// start new watches
for _, key := range keys {
k := key
c.watchGroup.Run(func() {
@@ -374,81 +286,46 @@ func (c *cluster) reload(cli EtcdClient) {
}
}
func (c *cluster) watch(cli EtcdClient, key watchKey, rev int64) {
func (c *cluster) watch(cli EtcdClient, key string, rev int64) {
for {
err := c.watchStream(cli, key, rev)
if err == nil {
if c.watchStream(cli, key, rev) {
return
}
if rev != 0 && errors.Is(err, rpctypes.ErrCompacted) {
logc.Errorf(cli.Ctx(), "etcd watch stream has been compacted, try to reload, rev %d", rev)
rev = c.load(cli, key)
}
// log the error and retry with cooldown to prevent CPU/disk exhaustion
logc.Error(cli.Ctx(), err)
time.Sleep(coolDownUnstable.AroundDuration(coolDownInterval))
}
}
func (c *cluster) watchStream(cli EtcdClient, key watchKey, rev int64) error {
ctx, rch := c.setupWatch(cli, key, rev)
func (c *cluster) watchStream(cli EtcdClient, key string, rev int64) bool {
var rch clientv3.WatchChan
if rev != 0 {
rch = cli.Watch(clientv3.WithRequireLeader(c.context(cli)), makeKeyPrefix(key), clientv3.WithPrefix(),
clientv3.WithRev(rev+1))
} else {
rch = cli.Watch(clientv3.WithRequireLeader(c.context(cli)), makeKeyPrefix(key), clientv3.WithPrefix())
}
for {
select {
case wresp, ok := <-rch:
if !ok {
return errClosed
logx.Error("etcd monitor chan has been closed")
return false
}
if wresp.Canceled {
return fmt.Errorf("etcd monitor chan has been canceled, error: %w", wresp.Err())
logx.Errorf("etcd monitor chan has been canceled, error: %v", wresp.Err())
return false
}
if wresp.Err() != nil {
return fmt.Errorf("etcd monitor chan error: %w", wresp.Err())
logx.Error(fmt.Sprintf("etcd monitor chan error: %v", wresp.Err()))
return false
}
c.handleWatchEvents(ctx, key, wresp.Events)
case <-ctx.Done():
return nil
c.handleWatchEvents(key, wresp.Events)
case <-c.done:
return nil
return true
}
}
}
func (c *cluster) setupWatch(cli EtcdClient, key watchKey, rev int64) (context.Context, clientv3.WatchChan) {
var (
rch clientv3.WatchChan
ops []clientv3.OpOption
wkey = key.key
)
if !key.exactMatch {
wkey = makeKeyPrefix(key.key)
ops = append(ops, clientv3.WithPrefix())
}
if rev != 0 {
ops = append(ops, clientv3.WithRev(rev+1))
}
ctx, cancel := context.WithCancel(cli.Ctx())
c.lock.Lock()
if watcher, ok := c.watchers[key]; ok {
watcher.cancel = cancel
} else {
val := newWatchValue()
val.cancel = cancel
c.watchers[key] = val
}
c.lock.Unlock()
rch = cli.Watch(clientv3.WithRequireLeader(ctx), wkey, ops...)
return ctx, rch
}
func (c *cluster) watchConnState(cli EtcdClient) {
watcher := newStateWatcher()
watcher.addListener(func() {
@@ -460,11 +337,13 @@ func (c *cluster) watchConnState(cli EtcdClient) {
// DialClient dials an etcd cluster with given endpoints.
func DialClient(endpoints []string) (EtcdClient, error) {
cfg := clientv3.Config{
Endpoints: endpoints,
AutoSyncInterval: autoSyncInterval,
DialTimeout: DialTimeout,
RejectOldCluster: true,
PermitWithoutStream: true,
Endpoints: endpoints,
AutoSyncInterval: autoSyncInterval,
DialTimeout: DialTimeout,
DialKeepAliveTime: dialKeepAliveTime,
DialKeepAliveTimeout: DialTimeout,
RejectOldCluster: true,
PermitWithoutStream: true,
}
if account, ok := GetAccount(endpoints); ok {
cfg.Username = account.User
@@ -477,28 +356,6 @@ func DialClient(endpoints []string) (EtcdClient, error) {
return clientv3.New(cfg)
}
func calculateChanges(oldVals, newVals map[string]string) (add, remove []KV) {
for k, v := range newVals {
if val, ok := oldVals[k]; !ok || v != val {
add = append(add, KV{
Key: k,
Val: v,
})
}
}
for k, v := range oldVals {
if val, ok := newVals[k]; !ok || v != val {
remove = append(remove, KV{
Key: k,
Val: v,
})
}
}
return add, remove
}
func getClusterKey(endpoints []string) string {
sort.Strings(endpoints)
return strings.Join(endpoints, endpointsSeparator)
@@ -507,10 +364,3 @@ func getClusterKey(endpoints []string) string {
func makeKeyPrefix(key string) string {
return fmt.Sprintf("%s%c", key, Delimiter)
}
// NewClient returns a watchValue that make sure values are not nil.
func newWatchValue() *watchValue {
return &watchValue{
values: make(map[string]string),
}
}

View File

@@ -2,22 +2,18 @@ package internal
import (
"context"
"os"
"sync"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/contextx"
"github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stringx"
"github.com/zeromicro/go-zero/core/threading"
"go.etcd.io/etcd/api/v3/etcdserverpb"
"go.etcd.io/etcd/api/v3/mvccpb"
clientv3 "go.etcd.io/etcd/client/v3"
"go.etcd.io/etcd/client/v3/mock/mockserver"
"go.uber.org/mock/gomock"
)
var mockLock sync.Mutex
@@ -39,9 +35,9 @@ func setMockClient(cli EtcdClient) func() {
func TestGetCluster(t *testing.T) {
AddAccount([]string{"first"}, "foo", "bar")
c1, _ := GetRegistry().getOrCreateCluster([]string{"first"})
c2, _ := GetRegistry().getOrCreateCluster([]string{"second"})
c3, _ := GetRegistry().getOrCreateCluster([]string{"first"})
c1, _ := GetRegistry().getCluster([]string{"first"})
c2, _ := GetRegistry().getCluster([]string{"second"})
c3, _ := GetRegistry().getCluster([]string{"first"})
assert.Equal(t, c1, c3)
assert.NotEqual(t, c1, c2)
}
@@ -51,36 +47,6 @@ func TestGetClusterKey(t *testing.T) {
getClusterKey([]string{"remotehost:5678", "localhost:1234"}))
}
func TestUnmonitor(t *testing.T) {
t.Run("no listener", func(t *testing.T) {
reg := &Registry{
clusters: map[string]*cluster{},
}
assert.NotPanics(t, func() {
reg.Unmonitor([]string{"any"}, "any", false, nil)
})
})
t.Run("no value", func(t *testing.T) {
reg := &Registry{
clusters: map[string]*cluster{
"any": {
watchers: map[watchKey]*watchValue{
{
key: "any",
}: {
values: map[string]string{},
},
},
},
},
}
assert.NotPanics(t, func() {
reg.Unmonitor([]string{"any"}, "another", false, nil)
})
})
}
func TestCluster_HandleChanges(t *testing.T) {
ctrl := gomock.NewController(t)
l := NewMockUpdateListener(ctrl)
@@ -109,14 +75,8 @@ func TestCluster_HandleChanges(t *testing.T) {
Val: "4",
})
c := newCluster([]string{"any"})
key := watchKey{
key: "any",
exactMatch: false,
}
c.watchers[key] = &watchValue{
listeners: []UpdateListener{l},
}
c.handleChanges(key, []KV{
c.listeners["any"] = []UpdateListener{l}
c.handleChanges("any", []KV{
{
Key: "first",
Val: "1",
@@ -129,8 +89,8 @@ func TestCluster_HandleChanges(t *testing.T) {
assert.EqualValues(t, map[string]string{
"first": "1",
"second": "2",
}, c.watchers[key].values)
c.handleChanges(key, []KV{
}, c.values["any"])
c.handleChanges("any", []KV{
{
Key: "third",
Val: "3",
@@ -143,7 +103,7 @@ func TestCluster_HandleChanges(t *testing.T) {
assert.EqualValues(t, map[string]string{
"third": "3",
"fourth": "4",
}, c.watchers[key].values)
}, c.values["any"])
}
func TestCluster_Load(t *testing.T) {
@@ -163,11 +123,9 @@ func TestCluster_Load(t *testing.T) {
}, nil)
cli.EXPECT().Ctx().Return(context.Background())
c := &cluster{
watchers: make(map[watchKey]*watchValue),
values: make(map[string]map[string]string),
}
c.load(cli, watchKey{
key: "any",
})
c.load(cli, "any")
}
func TestCluster_Watch(t *testing.T) {
@@ -199,25 +157,20 @@ func TestCluster_Watch(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
c := &cluster{
watchers: make(map[watchKey]*watchValue),
}
key := watchKey{
key: "any",
listeners: make(map[string][]UpdateListener),
values: make(map[string]map[string]string),
}
listener := NewMockUpdateListener(ctrl)
c.watchers[key] = &watchValue{
listeners: []UpdateListener{listener},
values: make(map[string]string),
}
c.listeners["any"] = []UpdateListener{listener}
listener.EXPECT().OnAdd(gomock.Any()).Do(func(kv KV) {
assert.Equal(t, "hello", kv.Key)
assert.Equal(t, "world", kv.Val)
wg.Done()
}).MaxTimes(1)
listener.EXPECT().OnDelete(gomock.Any()).Do(func(_ any) {
listener.EXPECT().OnDelete(gomock.Any()).Do(func(_ interface{}) {
wg.Done()
}).MaxTimes(1)
go c.watch(cli, key, 0)
go c.watch(cli, "any", 0)
ch <- clientv3.WatchResponse{
Events: []*clientv3.Event{
{
@@ -255,111 +208,17 @@ func TestClusterWatch_RespFailures(t *testing.T) {
ch := make(chan clientv3.WatchResponse)
cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes()
cli.EXPECT().Ctx().Return(context.Background()).AnyTimes()
c := &cluster{
watchers: make(map[watchKey]*watchValue),
}
c := new(cluster)
c.done = make(chan lang.PlaceholderType)
go func() {
ch <- resp
close(c.done)
}()
key := watchKey{
key: "any",
}
c.watch(cli, key, 0)
c.watch(cli, "any", 0)
})
}
}
func TestCluster_getCurrent(t *testing.T) {
t.Run("no value", func(t *testing.T) {
c := &cluster{
watchers: map[watchKey]*watchValue{
{
key: "any",
}: {
values: map[string]string{},
},
},
}
assert.Nil(t, c.getCurrent(watchKey{
key: "another",
}))
})
}
func TestCluster_handleWatchEvents(t *testing.T) {
t.Run("no value", func(t *testing.T) {
c := &cluster{
watchers: map[watchKey]*watchValue{
{
key: "any",
}: {
values: map[string]string{},
},
},
}
assert.NotPanics(t, func() {
c.handleWatchEvents(context.Background(), watchKey{
key: "another",
}, nil)
})
})
}
func TestCluster_addListener(t *testing.T) {
t.Run("has listener", func(t *testing.T) {
c := &cluster{
watchers: map[watchKey]*watchValue{
{
key: "any",
}: {
listeners: make([]UpdateListener, 0),
},
},
}
assert.NotPanics(t, func() {
c.addListener(watchKey{
key: "any",
}, nil)
})
})
t.Run("no listener", func(t *testing.T) {
c := &cluster{
watchers: map[watchKey]*watchValue{
{
key: "any",
}: {
listeners: make([]UpdateListener, 0),
},
},
}
assert.NotPanics(t, func() {
c.addListener(watchKey{
key: "another",
}, nil)
})
})
}
func TestCluster_reload(t *testing.T) {
c := &cluster{
watchers: map[watchKey]*watchValue{},
watchGroup: threading.NewRoutineGroup(),
done: make(chan lang.PlaceholderType),
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
cli := NewMockEtcdClient(ctrl)
restore := setMockClient(cli)
defer restore()
assert.NotPanics(t, func() {
c.reload(cli)
})
}
func TestClusterWatch_CloseChan(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
@@ -369,17 +228,13 @@ func TestClusterWatch_CloseChan(t *testing.T) {
ch := make(chan clientv3.WatchResponse)
cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes()
cli.EXPECT().Ctx().Return(context.Background()).AnyTimes()
c := &cluster{
watchers: make(map[watchKey]*watchValue),
}
c := new(cluster)
c.done = make(chan lang.PlaceholderType)
go func() {
close(ch)
close(c.done)
}()
c.watch(cli, watchKey{
key: "any",
}, 0)
c.watch(cli, "any", 0)
}
func TestValueOnlyContext(t *testing.T) {
@@ -387,167 +242,3 @@ func TestValueOnlyContext(t *testing.T) {
ctx.Done()
assert.Nil(t, ctx.Err())
}
func TestDialClient(t *testing.T) {
svr, err := mockserver.StartMockServers(1)
assert.NoError(t, err)
svr.StartAt(0)
certFile := createTempFile(t, []byte(certContent))
defer os.Remove(certFile)
keyFile := createTempFile(t, []byte(keyContent))
defer os.Remove(keyFile)
caFile := createTempFile(t, []byte(caContent))
defer os.Remove(caFile)
endpoints := []string{svr.Servers[0].Address}
AddAccount(endpoints, "foo", "bar")
assert.NoError(t, AddTLS(endpoints, certFile, keyFile, caFile, false))
old := DialTimeout
DialTimeout = time.Millisecond
defer func() {
DialTimeout = old
}()
_, err = DialClient(endpoints)
assert.Error(t, err)
}
func TestRegistry_Monitor(t *testing.T) {
svr, err := mockserver.StartMockServers(1)
assert.NoError(t, err)
svr.StartAt(0)
endpoints := []string{svr.Servers[0].Address}
GetRegistry().lock.Lock()
GetRegistry().clusters = map[string]*cluster{
getClusterKey(endpoints): {
watchers: map[watchKey]*watchValue{
{
key: "foo",
exactMatch: true,
}: {
values: map[string]string{
"bar": "baz",
},
},
},
},
}
GetRegistry().lock.Unlock()
assert.Error(t, GetRegistry().Monitor(endpoints, "foo", false, new(mockListener)))
}
func TestRegistry_Unmonitor(t *testing.T) {
svr, err := mockserver.StartMockServers(1)
assert.NoError(t, err)
svr.StartAt(0)
_, cancel := context.WithCancel(context.Background())
endpoints := []string{svr.Servers[0].Address}
GetRegistry().lock.Lock()
GetRegistry().clusters = map[string]*cluster{
getClusterKey(endpoints): {
watchers: map[watchKey]*watchValue{
{
key: "foo",
exactMatch: true,
}: {
values: map[string]string{
"bar": "baz",
},
cancel: cancel,
},
},
},
}
GetRegistry().lock.Unlock()
l := new(mockListener)
assert.NoError(t, GetRegistry().Monitor(endpoints, "foo", true, l))
watchVals := GetRegistry().clusters[getClusterKey(endpoints)].watchers[watchKey{
key: "foo",
exactMatch: true,
}]
assert.Equal(t, 1, len(watchVals.listeners))
GetRegistry().Unmonitor(endpoints, "foo", true, l)
watchVals = GetRegistry().clusters[getClusterKey(endpoints)].watchers[watchKey{
key: "foo",
exactMatch: true,
}]
assert.Nil(t, watchVals)
}
// TestCluster_ConcurrentMonitor tests the race condition fix in setupWatch
// This test specifically covers the scenario from issue #5394 where:
// - addListener() writes to the watchers map (with lock)
// - setupWatch() reads from the watchers map (now with lock after fix)
// Running with -race flag will detect any race conditions
func TestCluster_ConcurrentMonitor(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
cli := NewMockEtcdClient(ctrl)
cli.EXPECT().Ctx().Return(context.Background()).AnyTimes()
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return(make(chan clientv3.WatchResponse)).AnyTimes()
c := &cluster{
endpoints: []string{"localhost:2379"},
key: "test-cluster",
watchers: make(map[watchKey]*watchValue),
watchGroup: threading.NewRoutineGroup(),
done: make(chan lang.PlaceholderType),
lock: sync.RWMutex{},
}
// Spawn multiple concurrent operations that simulate the race condition:
// - Some goroutines call addListener (write to map)
// - Some goroutines call setupWatch (read from map)
var wg sync.WaitGroup
numGoroutines := 20
wg.Add(numGoroutines)
keys := []watchKey{
{key: "key-0", exactMatch: false},
{key: "key-1", exactMatch: false},
{key: "key-2", exactMatch: false},
}
for i := 0; i < numGoroutines; i++ {
idx := i
go func() {
defer wg.Done()
key := keys[idx%len(keys)]
if idx%2 == 0 {
// Half the goroutines add listeners (write operation)
c.addListener(key, &mockListener{})
} else {
// Half the goroutines setup watches (read operation)
_, _ = c.setupWatch(cli, key, 0)
}
}()
}
// Wait for all goroutines to complete
wg.Wait()
// Verify that watchers were correctly added
c.lock.RLock()
assert.True(t, len(c.watchers) > 0, "watchers should be added")
for _, watcher := range c.watchers {
assert.NotNil(t, watcher, "watcher should not be nil")
}
c.lock.RUnlock()
// Clean up
close(c.done)
}
type mockListener struct {
}
func (m *mockListener) OnAdd(_ KV) {
}
func (m *mockListener) OnDelete(_ KV) {
}

View File

@@ -1,10 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: statewatcher.go
//
// Generated by this command:
//
// mockgen -package internal -destination statewatcher_mock.go -source statewatcher.go etcdConn
//
// Package internal is a generated GoMock package.
package internal
@@ -13,35 +8,34 @@ import (
context "context"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
gomock "github.com/golang/mock/gomock"
connectivity "google.golang.org/grpc/connectivity"
)
// MocketcdConn is a mock of etcdConn interface.
// MocketcdConn is a mock of etcdConn interface
type MocketcdConn struct {
ctrl *gomock.Controller
recorder *MocketcdConnMockRecorder
isgomock struct{}
}
// MocketcdConnMockRecorder is the mock recorder for MocketcdConn.
// MocketcdConnMockRecorder is the mock recorder for MocketcdConn
type MocketcdConnMockRecorder struct {
mock *MocketcdConn
}
// NewMocketcdConn creates a new mock instance.
// NewMocketcdConn creates a new mock instance
func NewMocketcdConn(ctrl *gomock.Controller) *MocketcdConn {
mock := &MocketcdConn{ctrl: ctrl}
mock.recorder = &MocketcdConnMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MocketcdConn) EXPECT() *MocketcdConnMockRecorder {
return m.recorder
}
// GetState mocks base method.
// GetState mocks base method
func (m *MocketcdConn) GetState() connectivity.State {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetState")
@@ -49,13 +43,13 @@ func (m *MocketcdConn) GetState() connectivity.State {
return ret0
}
// GetState indicates an expected call of GetState.
// GetState indicates an expected call of GetState
func (mr *MocketcdConnMockRecorder) GetState() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetState", reflect.TypeOf((*MocketcdConn)(nil).GetState))
}
// WaitForStateChange mocks base method.
// WaitForStateChange mocks base method
func (m *MocketcdConn) WaitForStateChange(ctx context.Context, sourceState connectivity.State) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "WaitForStateChange", ctx, sourceState)
@@ -63,8 +57,8 @@ func (m *MocketcdConn) WaitForStateChange(ctx context.Context, sourceState conne
return ret0
}
// WaitForStateChange indicates an expected call of WaitForStateChange.
func (mr *MocketcdConnMockRecorder) WaitForStateChange(ctx, sourceState any) *gomock.Call {
// WaitForStateChange indicates an expected call of WaitForStateChange
func (mr *MocketcdConnMockRecorder) WaitForStateChange(ctx, sourceState interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WaitForStateChange", reflect.TypeOf((*MocketcdConn)(nil).WaitForStateChange), ctx, sourceState)
}

View File

@@ -4,7 +4,7 @@ import (
"sync"
"testing"
"go.uber.org/mock/gomock"
"github.com/golang/mock/gomock"
"google.golang.org/grpc/connectivity"
)

View File

@@ -10,7 +10,6 @@ type (
}
// UpdateListener wraps the OnAdd and OnDelete methods.
// The implementation should be thread-safe and idempotent.
UpdateListener interface {
OnAdd(kv KV)
OnDelete(kv KV)

View File

@@ -1,10 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: updatelistener.go
//
// Generated by this command:
//
// mockgen -package internal -destination updatelistener_mock.go -source updatelistener.go UpdateListener
//
// Package internal is a generated GoMock package.
package internal
@@ -12,53 +7,52 @@ package internal
import (
reflect "reflect"
gomock "go.uber.org/mock/gomock"
gomock "github.com/golang/mock/gomock"
)
// MockUpdateListener is a mock of UpdateListener interface.
// MockUpdateListener is a mock of UpdateListener interface
type MockUpdateListener struct {
ctrl *gomock.Controller
recorder *MockUpdateListenerMockRecorder
isgomock struct{}
}
// MockUpdateListenerMockRecorder is the mock recorder for MockUpdateListener.
// MockUpdateListenerMockRecorder is the mock recorder for MockUpdateListener
type MockUpdateListenerMockRecorder struct {
mock *MockUpdateListener
}
// NewMockUpdateListener creates a new mock instance.
// NewMockUpdateListener creates a new mock instance
func NewMockUpdateListener(ctrl *gomock.Controller) *MockUpdateListener {
mock := &MockUpdateListener{ctrl: ctrl}
mock.recorder = &MockUpdateListenerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockUpdateListener) EXPECT() *MockUpdateListenerMockRecorder {
return m.recorder
}
// OnAdd mocks base method.
// OnAdd mocks base method
func (m *MockUpdateListener) OnAdd(kv KV) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnAdd", kv)
}
// OnAdd indicates an expected call of OnAdd.
func (mr *MockUpdateListenerMockRecorder) OnAdd(kv any) *gomock.Call {
// OnAdd indicates an expected call of OnAdd
func (mr *MockUpdateListenerMockRecorder) OnAdd(kv interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnAdd", reflect.TypeOf((*MockUpdateListener)(nil).OnAdd), kv)
}
// OnDelete mocks base method.
// OnDelete mocks base method
func (m *MockUpdateListener) OnDelete(kv KV) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnDelete", kv)
}
// OnDelete indicates an expected call of OnDelete.
func (mr *MockUpdateListenerMockRecorder) OnDelete(kv any) *gomock.Call {
// OnDelete indicates an expected call of OnDelete
func (mr *MockUpdateListenerMockRecorder) OnDelete(kv interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnDelete", reflect.TypeOf((*MockUpdateListener)(nil).OnDelete), kv)
}

View File

@@ -9,6 +9,7 @@ const (
autoSyncInterval = time.Minute
coolDownInterval = time.Second
dialTimeout = 5 * time.Second
dialKeepAliveTime = 5 * time.Second
requestTimeout = 3 * time.Second
endpointsSeparator = ","
)

View File

@@ -5,7 +5,6 @@ import (
"github.com/zeromicro/go-zero/core/discov/internal"
"github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/logc"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/proc"
"github.com/zeromicro/go-zero/core/syncx"
@@ -92,12 +91,12 @@ func (p *Publisher) doKeepAlive() error {
default:
cli, err := p.doRegister()
if err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher doRegister: %v", err)
logx.Errorf("etcd publisher doRegister: %s", err.Error())
break
}
if err := p.keepAliveAsync(cli); err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher keepAliveAsync: %v", err)
logx.Errorf("etcd publisher keepAliveAsync: %s", err.Error())
break
}
@@ -125,48 +124,23 @@ func (p *Publisher) keepAliveAsync(cli internal.EtcdClient) error {
}
threading.GoSafe(func() {
wch := cli.Watch(cli.Ctx(), p.fullKey, clientv3.WithFilterPut())
for {
select {
case _, ok := <-ch:
if !ok {
p.revoke(cli)
if err := p.doKeepAlive(); err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher KeepAlive: %v", err)
logx.Errorf("etcd publisher KeepAlive: %s", err.Error())
}
return
}
case c := <-wch:
if c.Err() != nil {
logc.Errorf(cli.Ctx(), "etcd publisher watch: %v", c.Err())
if err := p.doKeepAlive(); err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher KeepAlive: %v", err)
}
return
}
for _, evt := range c.Events {
if evt.Type == clientv3.EventTypeDelete {
logc.Infof(cli.Ctx(), "etcd publisher watch: %s, event: %v",
evt.Kv.Key, evt.Type)
_, err := cli.Put(cli.Ctx(), p.fullKey, p.value, clientv3.WithLease(p.lease))
if err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher re-put key: %v", err)
} else {
logc.Infof(cli.Ctx(), "etcd publisher re-put key: %s, value: %s",
p.fullKey, p.value)
}
}
}
case <-p.pauseChan:
logc.Infof(cli.Ctx(), "paused etcd renew, key: %s, value: %s", p.key, p.value)
logx.Infof("paused etcd renew, key: %s, value: %s", p.key, p.value)
p.revoke(cli)
select {
case <-p.resumeChan:
if err := p.doKeepAlive(); err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher KeepAlive: %v", err)
logx.Errorf("etcd publisher KeepAlive: %s", err.Error())
}
return
case <-p.quit.Done():
@@ -201,7 +175,7 @@ func (p *Publisher) register(client internal.EtcdClient) (clientv3.LeaseID, erro
func (p *Publisher) revoke(cli internal.EtcdClient) {
if _, err := cli.Revoke(cli.Ctx(), p.lease); err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher revoke: %v", err)
logx.Errorf("etcd publisher revoke: %s", err.Error())
}
}

View File

@@ -1,99 +1,18 @@
package discov
import (
"context"
"errors"
"net"
"os"
"sync"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/discov/internal"
"github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stringx"
"go.etcd.io/etcd/api/v3/mvccpb"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/mock/gomock"
"golang.org/x/net/http2"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
)
const (
certContent = `-----BEGIN CERTIFICATE-----
MIIDazCCAlOgAwIBAgIUEg9GVO2oaPn+YSmiqmFIuAo10WIwDQYJKoZIhvcNAQEM
BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAgFw0yMzAzMTExMzIxMjNaGA8yMTIz
MDIxNTEzMjEyM1owRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx
ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcN
AQEBBQADggEPADCCAQoCggEBALplXlWsIf0O/IgnIplmiZHKGnxyfyufyE2FBRNk
OofRqbKuPH8GNqbkvZm7N29fwTDAQ+mViAggCkDht4hOzoWJMA7KYJt8JnTSWL48
M1lcrpc9DL2gszC/JF/FGvyANbBtLklkZPFBGdHUX14pjrT937wqPtm+SqUHSvRT
B7bmwmm2drRcmhpVm98LSlV7uQ2EgnJgsLjBPITKUejLmVLHfgX0RwQ2xIpX9pS4
FCe1BTacwl2gGp7Mje7y4Mfv3o0ArJW6Tuwbjx59ZXwb1KIP71b7bT04AVS8ZeYO
UMLKKuB5UR9x9Rn6cLXOTWBpcMVyzDgrAFLZjnE9LPUolZMCAwEAAaNRME8wHwYD
VR0jBBgwFoAUeW8w8pmhncbRgTsl48k4/7wnfx8wCQYDVR0TBAIwADALBgNVHQ8E
BAMCBPAwFAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBDAUAA4IBAQAI
y9xaoS88CLPBsX6mxfcTAFVfGNTRW9VN9Ng1cCnUR+YGoXGM/l+qP4f7p8ocdGwK
iYZErVTzXYIn+D27//wpY3klJk3gAnEUBT3QRkStBw7XnpbeZ2oPBK+cmDnCnZPS
BIF1wxPX7vIgaxs5Zsdqwk3qvZ4Djr2wP7LabNWTLSBKgQoUY45Liw6pffLwcGF9
UKlu54bvGze2SufISCR3ib+I+FLvqpvJhXToZWYb/pfI/HccuCL1oot1x8vx6DQy
U+TYxlZsKS5mdNxAX3dqEkEMsgEi+g/tzDPXJImfeCGGBhIOXLm8SRypiuGdEbc9
xkWYxRPegajuEZGvCqVs
-----END CERTIFICATE-----`
keyContent = `-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAumVeVawh/Q78iCcimWaJkcoafHJ/K5/ITYUFE2Q6h9Gpsq48
fwY2puS9mbs3b1/BMMBD6ZWICCAKQOG3iE7OhYkwDspgm3wmdNJYvjwzWVyulz0M
vaCzML8kX8Ua/IA1sG0uSWRk8UEZ0dRfXimOtP3fvCo+2b5KpQdK9FMHtubCabZ2
tFyaGlWb3wtKVXu5DYSCcmCwuME8hMpR6MuZUsd+BfRHBDbEilf2lLgUJ7UFNpzC
XaAansyN7vLgx+/ejQCslbpO7BuPHn1lfBvUog/vVvttPTgBVLxl5g5Qwsoq4HlR
H3H1Gfpwtc5NYGlwxXLMOCsAUtmOcT0s9SiVkwIDAQABAoIBAD5meTJNMgO55Kjg
ESExxpRcCIno+tHr5+6rvYtEXqPheOIsmmwb9Gfi4+Z3WpOaht5/Pz0Ppj6yGzyl
U//6AgGKb+BDuBvVcDpjwPnOxZIBCSHwejdxeQu0scSuA97MPS0XIAvJ5FEv7ijk
5Bht6SyGYURpECltHygoTNuGgGqmO+McCJRLE9L09lTBI6UQ/JQwWJqSr7wx6iPU
M1Ze/srIV+7cyEPu6i0DGjS1gSQKkX68Lqn1w6oE290O+OZvleO0gZ02fLDWCZke
aeD9+EU/Pw+rqm3H6o0szOFIpzhRp41FUdW9sybB3Yp3u7c/574E+04Z/e30LMKs
TCtE1QECgYEA3K7KIpw0NH2HXL5C3RHcLmr204xeBfS70riBQQuVUgYdmxak2ima
80RInskY8hRhSGTg0l+VYIH8cmjcUyqMSOELS5XfRH99r4QPiK8AguXg80T4VumY
W3Pf+zEC2ssgP/gYthV0g0Xj5m2QxktOF9tRw5nkg739ZR4dI9lm/iECgYEA2Dnf
uwEDGqHiQRF6/fh5BG/nGVMvrefkqx6WvTJQ3k/M/9WhxB+lr/8yH46TuS8N2b29
FoTf3Mr9T7pr/PWkOPzoY3P56nYbKU8xSwCim9xMzhBMzj8/N9ukJvXy27/VOz56
eQaKqnvdXNGtPJrIMDGHps2KKWlKLyAlapzjVTMCgYAA/W++tACv85g13EykfT4F
n0k4LbsGP9DP4zABQLIMyiY72eAncmRVjwrcW36XJ2xATOONTgx3gF3HjZzfaqNy
eD/6uNNllUTVEryXGmHgNHPL45VRnn6memCY2eFvZdXhM5W4y2PYaunY0MkDercA
+GTngbs6tBF88KOk04bYwQKBgFl68cRgsdkmnwwQYNaTKfmVGYzYaQXNzkqmWPko
xmCJo6tHzC7ubdG8iRCYHzfmahPuuj6EdGPZuSRyYFgJi5Ftz/nAN+84OxtIQ3zn
YWOgskQgaLh9YfsKsQ7Sf1NDOsnOnD5TX7UXl07fEpLe9vNCvAFiU8e5Y9LGudU5
4bYTAoGBAMdX3a3bXp4cZvXNBJ/QLVyxC6fP1Q4haCR1Od3m+T00Jth2IX2dk/fl
p6xiJT1av5JtYabv1dFKaXOS5s1kLGGuCCSKpkvFZm826aQ2AFm0XGqEQDLeei5b
A52Kpy/YJ+RkG4BTFtAooFq6DmA0cnoP6oPvG2h6XtDJwDTPInJb
-----END RSA PRIVATE KEY-----`
caContent = `-----BEGIN CERTIFICATE-----
MIIDbTCCAlWgAwIBAgIUBJvFoCowKich7MMfseJ+DYzzirowDQYJKoZIhvcNAQEM
BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAgFw0yMzAzMTExMzIxMDNaGA8yMTIz
MDIxNTEzMjEwM1owRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx
ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcN
AQEBBQADggEPADCCAQoCggEBAO4to2YMYj0bxgr2FCiweSTSFuPx33zSw2x/s9Wf
OR41bm2DFsyYT5f3sOIKlXZEdLmOKty2e3ho3yC0EyNpVHdykkkHT3aDI17quZax
kYi/URqqtl1Z08A22txolc04hAZisg2BypGi3vql81UW1t3zyloGnJoIAeXR9uca
ljP6Bk3bwsxoVBLi1JtHrO0hHLQaeHmKhAyrys06X0LRdn7Px48yRZlt6FaLSa8X
YiRM0G44bVy/h6BkoQjMYGwVmCVk6zjJ9U7ZPFqdnDMNxAfR+hjDnYodqdLDMTTR
1NPVrnEnNwFx0AMLvgt/ba/45vZCEAmSZnFXFAJJcM7ai9ECAwEAAaNTMFEwHQYD
VR0OBBYEFHlvMPKZoZ3G0YE7JePJOP+8J38fMB8GA1UdIwQYMBaAFHlvMPKZoZ3G
0YE7JePJOP+8J38fMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQEMBQADggEB
AMX8dNulADOo9uQgBMyFb9TVra7iY0zZjzv4GY5XY7scd52n6CnfAPvYBBDnTr/O
BgNp5jaujb4+9u/2qhV3f9n+/3WOb2CmPehBgVSzlXqHeQ9lshmgwZPeem2T+8Tm
Nnc/xQnsUfCFszUDxpkr55+aLVM22j02RWqcZ4q7TAaVYL+kdFVMc8FoqG/0ro6A
BjE/Qn0Nn7ciX1VUjDt8l+k7ummPJTmzdi6i6E4AwO9dzrGNgGJ4aWL8cC6xYcIX
goVIRTFeONXSDno/oPjWHpIPt7L15heMpKBHNuzPkKx2YVqPHE5QZxWfS+Lzgx+Q
E2oTTM0rYKOZ8p6000mhvKI=
-----END CERTIFICATE-----`
)
func init() {
@@ -118,7 +37,7 @@ func TestPublisher_register(t *testing.T) {
assert.Nil(t, err)
}
func TestPublisher_registerWithOptions(t *testing.T) {
func TestPublisher_registerWithId(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
const id = 2
@@ -130,15 +49,7 @@ func TestPublisher_registerWithOptions(t *testing.T) {
ID: 1,
}, nil)
cli.EXPECT().Put(gomock.Any(), makeEtcdKey("thekey", id), "thevalue", gomock.Any())
certFile := createTempFile(t, []byte(certContent))
defer os.Remove(certFile)
keyFile := createTempFile(t, []byte(keyContent))
defer os.Remove(keyFile)
caFile := createTempFile(t, []byte(caContent))
defer os.Remove(caFile)
pub := NewPublisher(nil, "thekey", "thevalue", WithId(id),
WithPubEtcdTLS(certFile, keyFile, caFile, true))
pub := NewPublisher(nil, "thekey", "thevalue", WithId(id))
_, err := pub.register(cli)
assert.Nil(t, err)
}
@@ -212,12 +123,9 @@ func TestPublisher_keepAliveAsyncQuit(t *testing.T) {
defer restore()
cli.EXPECT().Ctx().AnyTimes()
cli.EXPECT().KeepAlive(gomock.Any(), id)
// Add Watch expectation for the new watch mechanism
watchChan := make(<-chan clientv3.WatchResponse)
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return(watchChan)
var wg sync.WaitGroup
wg.Add(1)
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) {
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ interface{}) {
wg.Done()
})
pub := NewPublisher(nil, "thekey", "thevalue")
@@ -236,13 +144,10 @@ func TestPublisher_keepAliveAsyncPause(t *testing.T) {
defer restore()
cli.EXPECT().Ctx().AnyTimes()
cli.EXPECT().KeepAlive(gomock.Any(), id)
// Add Watch expectation for the new watch mechanism
watchChan := make(<-chan clientv3.WatchResponse)
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return(watchChan)
pub := NewPublisher(nil, "thekey", "thevalue")
var wg sync.WaitGroup
wg.Add(1)
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) {
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ interface{}) {
pub.Stop()
wg.Done()
})
@@ -252,112 +157,6 @@ func TestPublisher_keepAliveAsyncPause(t *testing.T) {
wg.Wait()
}
// Test case for key deletion and re-registration (covers lines 148-155)
func TestPublisher_keepAliveAsyncKeyDeletion(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
const id clientv3.LeaseID = 1
cli := internal.NewMockEtcdClient(ctrl)
restore := setMockClient(cli)
defer restore()
cli.EXPECT().Ctx().AnyTimes()
cli.EXPECT().KeepAlive(gomock.Any(), id)
// Create a watch channel that will send a delete event
watchChan := make(chan clientv3.WatchResponse, 1)
watchResp := clientv3.WatchResponse{
Events: []*clientv3.Event{{
Type: clientv3.EventTypeDelete,
Kv: &mvccpb.KeyValue{
Key: []byte("thekey"),
},
}},
}
watchChan <- watchResp
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return((<-chan clientv3.WatchResponse)(watchChan))
var wg sync.WaitGroup
wg.Add(1) // Only wait for Revoke call
// Use a channel to signal when Put has been called
putCalled := make(chan struct{})
// Expect the re-put operation when key is deleted
cli.EXPECT().Put(gomock.Any(), "thekey", "thevalue", gomock.Any()).Do(func(_, _, _, _ any) {
close(putCalled) // Signal that Put has been called
}).Return(nil, nil)
// Expect revoke when Stop is called
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) {
wg.Done()
})
pub := NewPublisher(nil, "thekey", "thevalue")
pub.lease = id
pub.fullKey = "thekey"
assert.Nil(t, pub.keepAliveAsync(cli))
// Wait for Put to be called, then stop
<-putCalled
pub.Stop()
wg.Wait()
}
// Test case for key deletion with re-put error (covers error branch in lines 151-152)
func TestPublisher_keepAliveAsyncKeyDeletionPutError(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
const id clientv3.LeaseID = 1
cli := internal.NewMockEtcdClient(ctrl)
restore := setMockClient(cli)
defer restore()
cli.EXPECT().Ctx().AnyTimes()
cli.EXPECT().KeepAlive(gomock.Any(), id)
// Create a watch channel that will send a delete event
watchChan := make(chan clientv3.WatchResponse, 1)
watchResp := clientv3.WatchResponse{
Events: []*clientv3.Event{{
Type: clientv3.EventTypeDelete,
Kv: &mvccpb.KeyValue{
Key: []byte("thekey"),
},
}},
}
watchChan <- watchResp
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return((<-chan clientv3.WatchResponse)(watchChan))
var wg sync.WaitGroup
wg.Add(1) // Only wait for Revoke call
// Use a channel to signal when Put has been called
putCalled := make(chan struct{})
// Expect the re-put operation to fail
cli.EXPECT().Put(gomock.Any(), "thekey", "thevalue", gomock.Any()).Do(func(_, _, _, _ any) {
close(putCalled) // Signal that Put has been called
}).Return(nil, errors.New("put error"))
// Expect revoke when Stop is called
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) {
wg.Done()
})
pub := NewPublisher(nil, "thekey", "thevalue")
pub.lease = id
pub.fullKey = "thekey"
assert.Nil(t, pub.keepAliveAsync(cli))
// Wait for Put to be called, then stop
<-putCalled
pub.Stop()
wg.Wait()
}
func TestPublisher_Resume(t *testing.T) {
publisher := new(Publisher)
publisher.resumeChan = make(chan lang.PlaceholderType)
@@ -370,95 +169,3 @@ func TestPublisher_Resume(t *testing.T) {
}()
<-publisher.resumeChan
}
func TestPublisher_keepAliveAsync(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
const id clientv3.LeaseID = 1
conn := createMockConn(t)
defer conn.Close()
cli := internal.NewMockEtcdClient(ctrl)
cli.EXPECT().ActiveConnection().Return(conn).AnyTimes()
cli.EXPECT().Close()
defer cli.Close()
cli.ActiveConnection()
restore := setMockClient(cli)
defer restore()
cli.EXPECT().Ctx().AnyTimes()
cli.EXPECT().KeepAlive(gomock.Any(), id)
// Add Watch expectation for the new watch mechanism
watchChan := make(<-chan clientv3.WatchResponse)
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return(watchChan)
cli.EXPECT().Grant(gomock.Any(), timeToLive).Return(&clientv3.LeaseGrantResponse{
ID: 1,
}, nil)
cli.EXPECT().Put(gomock.Any(), makeEtcdKey("thekey", int64(id)), "thevalue", gomock.Any())
var wg sync.WaitGroup
wg.Add(1)
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) {
wg.Done()
})
pub := NewPublisher([]string{"the-endpoint"}, "thekey", "thevalue")
pub.lease = id
assert.Nil(t, pub.KeepAlive())
pub.Stop()
wg.Wait()
}
func createMockConn(t *testing.T) *grpc.ClientConn {
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error while listening. Err: %v", err)
}
defer lis.Close()
lisAddr := resolver.Address{Addr: lis.Addr().String()}
lisDone := make(chan struct{})
dialDone := make(chan struct{})
// 1st listener accepts the connection and then does nothing
go func() {
defer close(lisDone)
conn, err := lis.Accept()
if err != nil {
t.Errorf("Error while accepting. Err: %v", err)
return
}
framer := http2.NewFramer(conn, conn)
if err := framer.WriteSettings(http2.Setting{}); err != nil {
t.Errorf("Error while writing settings. Err: %v", err)
return
}
<-dialDone // Close conn only after dial returns.
}()
r := manual.NewBuilderWithScheme("whatever")
r.InitialState(resolver.State{Addresses: []resolver.Address{lisAddr}})
client, err := grpc.DialContext(context.Background(), r.Scheme()+":///test.server",
grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(r))
close(dialDone)
if err != nil {
t.Fatalf("Dial failed. Err: %v", err)
}
timeout := time.After(1 * time.Second)
select {
case <-timeout:
t.Fatal("timed out waiting for server to finish")
case <-lisDone:
}
return client
}
func createTempFile(t *testing.T, body []byte) string {
tmpFile, err := os.CreateTemp(os.TempDir(), "go-unit-*.tmp")
if err != nil {
t.Fatal(err)
}
tmpFile.Close()
if err = os.WriteFile(tmpFile.Name(), body, os.ModePerm); err != nil {
t.Fatal(err)
}
return tmpFile.Name()
}

View File

@@ -15,13 +15,10 @@ type (
// A Subscriber is used to subscribe the given key on an etcd cluster.
Subscriber struct {
endpoints []string
exclusive bool
key string
exactMatch bool
items Container
endpoints []string
exclusive bool
items *container
}
KV = internal.KV
)
// NewSubscriber returns a Subscriber.
@@ -31,16 +28,13 @@ type (
func NewSubscriber(endpoints []string, key string, opts ...SubOption) (*Subscriber, error) {
sub := &Subscriber{
endpoints: endpoints,
key: key,
}
for _, opt := range opts {
opt(sub)
}
if sub.items == nil {
sub.items = newContainer(sub.exclusive)
}
sub.items = newContainer(sub.exclusive)
if err := internal.GetRegistry().Monitor(endpoints, key, sub.exactMatch, sub.items); err != nil {
if err := internal.GetRegistry().Monitor(endpoints, key, sub.items); err != nil {
return nil, err
}
@@ -49,17 +43,12 @@ func NewSubscriber(endpoints []string, key string, opts ...SubOption) (*Subscrib
// AddListener adds listener to s.
func (s *Subscriber) AddListener(listener func()) {
s.items.AddListener(listener)
}
// Close closes the subscriber.
func (s *Subscriber) Close() {
internal.GetRegistry().Unmonitor(s.endpoints, s.key, s.exactMatch, s.items)
s.items.addListener(listener)
}
// Values returns all the subscription values.
func (s *Subscriber) Values() []string {
return s.items.GetValues()
return s.items.getValues()
}
// Exclusive means that key value can only be 1:1,
@@ -70,13 +59,6 @@ func Exclusive() SubOption {
}
}
// WithExactMatch turn off querying using key prefixes.
func WithExactMatch() SubOption {
return func(sub *Subscriber) {
sub.exactMatch = true
}
}
// WithSubEtcdAccount provides the etcd username/password.
func WithSubEtcdAccount(user, pass string) SubOption {
return func(sub *Subscriber) {
@@ -91,32 +73,16 @@ func WithSubEtcdTLS(certFile, certKeyFile, caFile string, insecureSkipVerify boo
}
}
// WithContainer provides a custom container to the subscriber.
func WithContainer(container Container) SubOption {
return func(sub *Subscriber) {
sub.items = container
}
type container struct {
exclusive bool
values map[string][]string
mapping map[string]string
snapshot atomic.Value
dirty *syncx.AtomicBool
listeners []func()
lock sync.Mutex
}
type (
Container interface {
OnAdd(kv internal.KV)
OnDelete(kv internal.KV)
AddListener(listener func())
GetValues() []string
}
container struct {
exclusive bool
values map[string][]string
mapping map[string]string
snapshot atomic.Value
dirty *syncx.AtomicBool
listeners []func()
lock sync.Mutex
}
)
func newContainer(exclusive bool) *container {
return &container{
exclusive: exclusive,
@@ -160,7 +126,7 @@ func (c *container) addKv(key, value string) ([]string, bool) {
return nil, false
}
func (c *container) AddListener(listener func()) {
func (c *container) addListener(listener func()) {
c.lock.Lock()
c.listeners = append(c.listeners, listener)
c.lock.Unlock()
@@ -189,7 +155,7 @@ func (c *container) doRemoveKey(key string) {
}
}
func (c *container) GetValues() []string {
func (c *container) getValues() []string {
if !c.dirty.True() {
return c.snapshot.Load().([]string)
}

View File

@@ -171,10 +171,10 @@ func TestContainer(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
var changed bool
c := newContainer(exclusive)
c.AddListener(func() {
c.addListener(func() {
changed = true
})
assert.Nil(t, c.GetValues())
assert.Nil(t, c.getValues())
assert.False(t, changed)
for _, order := range test.do {
@@ -193,9 +193,9 @@ func TestContainer(t *testing.T) {
assert.True(t, changed)
assert.True(t, c.dirty.True())
assert.ElementsMatch(t, test.expect, c.GetValues())
assert.ElementsMatch(t, test.expect, c.getValues())
assert.False(t, c.dirty.True())
assert.ElementsMatch(t, test.expect, c.GetValues())
assert.ElementsMatch(t, test.expect, c.getValues())
})
}
}
@@ -204,14 +204,12 @@ func TestContainer(t *testing.T) {
func TestSubscriber(t *testing.T) {
sub := new(Subscriber)
Exclusive()(sub)
c := newContainer(sub.exclusive)
WithContainer(c)(sub)
sub.items = c
sub.items = newContainer(sub.exclusive)
var count int32
sub.AddListener(func() {
atomic.AddInt32(&count, 1)
})
c.notifyChange()
sub.items.notifyChange()
assert.Empty(t, sub.Values())
assert.Equal(t, int32(1), atomic.LoadInt32(&count))
}
@@ -227,29 +225,3 @@ func TestWithSubEtcdAccount(t *testing.T) {
assert.Equal(t, user, account.User)
assert.Equal(t, "bar", account.Pass)
}
func TestWithExactMatch(t *testing.T) {
sub := new(Subscriber)
WithExactMatch()(sub)
c := newContainer(sub.exclusive)
sub.items = c
var count int32
sub.AddListener(func() {
atomic.AddInt32(&count, 1)
})
c.notifyChange()
assert.Empty(t, sub.Values())
assert.Equal(t, int32(1), atomic.LoadInt32(&count))
}
func TestSubscriberClose(t *testing.T) {
l := newContainer(false)
sub := &Subscriber{
endpoints: []string{"localhost:12379"},
key: "foo",
items: l,
}
assert.NotPanics(t, func() {
sub.Close()
})
}

View File

@@ -1,21 +1,18 @@
package errorx
import (
"errors"
"sync"
import "bytes"
type (
// A BatchError is an error that can hold multiple errors.
BatchError struct {
errs errorArray
}
errorArray []error
)
// BatchError is an error that can hold multiple errors.
type BatchError struct {
errs []error
lock sync.RWMutex
}
// Add adds one or more non-nil errors to the BatchError instance.
// Add adds errs to be, nil errors are ignored.
func (be *BatchError) Add(errs ...error) {
be.lock.Lock()
defer be.lock.Unlock()
for _, err := range errs {
if err != nil {
be.errs = append(be.errs, err)
@@ -23,20 +20,33 @@ func (be *BatchError) Add(errs ...error) {
}
}
// Err returns an error that represents all accumulated errors.
// It returns nil if there are no errors.
// Err returns an error that represents all errors.
func (be *BatchError) Err() error {
be.lock.RLock()
defer be.lock.RUnlock()
// If there are no non-nil errors, errors.Join(...) returns nil.
return errors.Join(be.errs...)
switch len(be.errs) {
case 0:
return nil
case 1:
return be.errs[0]
default:
return be.errs
}
}
// NotNil checks if there is at least one error inside the BatchError.
// NotNil checks if any error inside.
func (be *BatchError) NotNil() bool {
be.lock.RLock()
defer be.lock.RUnlock()
return len(be.errs) > 0
}
// Error returns a string that represents inside errors.
func (ea errorArray) Error() string {
var buf bytes.Buffer
for i := range ea {
if i > 0 {
buf.WriteByte('\n')
}
buf.WriteString(ea[i].Error())
}
return buf.String()
}

View File

@@ -3,7 +3,6 @@ package errorx
import (
"errors"
"fmt"
"sync"
"testing"
"github.com/stretchr/testify/assert"
@@ -34,7 +33,7 @@ func TestBatchErrorNilFromFunc(t *testing.T) {
func TestBatchErrorOneError(t *testing.T) {
var batch BatchError
batch.Add(errors.New(err1))
assert.NotNil(t, batch.Err())
assert.NotNil(t, batch)
assert.Equal(t, err1, batch.Err().Error())
assert.True(t, batch.NotNil())
}
@@ -43,105 +42,7 @@ func TestBatchErrorWithErrors(t *testing.T) {
var batch BatchError
batch.Add(errors.New(err1))
batch.Add(errors.New(err2))
assert.NotNil(t, batch.Err())
assert.NotNil(t, batch)
assert.Equal(t, fmt.Sprintf("%s\n%s", err1, err2), batch.Err().Error())
assert.True(t, batch.NotNil())
}
func TestBatchErrorConcurrentAdd(t *testing.T) {
const count = 10000
var batch BatchError
var wg sync.WaitGroup
wg.Add(count)
for i := 0; i < count; i++ {
go func() {
defer wg.Done()
batch.Add(errors.New(err1))
}()
}
wg.Wait()
assert.NotNil(t, batch.Err())
assert.Equal(t, count, len(batch.errs))
assert.True(t, batch.NotNil())
}
func TestBatchError_Unwrap(t *testing.T) {
t.Run("nil", func(t *testing.T) {
var be BatchError
assert.Nil(t, be.Err())
assert.True(t, errors.Is(be.Err(), nil))
})
t.Run("one error", func(t *testing.T) {
var errFoo = errors.New("foo")
var errBar = errors.New("bar")
var be BatchError
be.Add(errFoo)
assert.True(t, errors.Is(be.Err(), errFoo))
assert.False(t, errors.Is(be.Err(), errBar))
})
t.Run("two errors", func(t *testing.T) {
var errFoo = errors.New("foo")
var errBar = errors.New("bar")
var errBaz = errors.New("baz")
var be BatchError
be.Add(errFoo)
be.Add(errBar)
assert.True(t, errors.Is(be.Err(), errFoo))
assert.True(t, errors.Is(be.Err(), errBar))
assert.False(t, errors.Is(be.Err(), errBaz))
})
}
func TestBatchError_Add(t *testing.T) {
var be BatchError
// Test adding nil errors
be.Add(nil, nil)
assert.False(t, be.NotNil(), "Expected BatchError to be empty after adding nil errors")
// Test adding non-nil errors
err1 := errors.New("error 1")
err2 := errors.New("error 2")
be.Add(err1, err2)
assert.True(t, be.NotNil(), "Expected BatchError to be non-empty after adding errors")
// Test adding a mix of nil and non-nil errors
err3 := errors.New("error 3")
be.Add(nil, err3, nil)
assert.True(t, be.NotNil(), "Expected BatchError to be non-empty after adding a mix of nil and non-nil errors")
}
func TestBatchError_Err(t *testing.T) {
var be BatchError
// Test Err() on empty BatchError
assert.Nil(t, be.Err(), "Expected nil error for empty BatchError")
// Test Err() with multiple errors
err1 := errors.New("error 1")
err2 := errors.New("error 2")
be.Add(err1, err2)
combinedErr := be.Err()
assert.NotNil(t, combinedErr, "Expected nil error for BatchError with multiple errors")
// Check if the combined error contains both error messages
errString := combinedErr.Error()
assert.Truef(t, errors.Is(combinedErr, err1), "Combined error doesn't contain first error: %s", errString)
assert.Truef(t, errors.Is(combinedErr, err2), "Combined error doesn't contain second error: %s", errString)
}
func TestBatchError_NotNil(t *testing.T) {
var be BatchError
// Test NotNil() on empty BatchError
assert.Nil(t, be.Err(), "Expected nil error for empty BatchError")
// Test NotNil() after adding an error
be.Add(errors.New("test error"))
assert.NotNil(t, be.Err(), "Expected non-nil error after adding an error")
}

View File

@@ -1,14 +0,0 @@
package errorx
import "errors"
// In checks if the given err is one of errs.
func In(err error, errs ...error) bool {
for _, each := range errs {
if errors.Is(err, each) {
return true
}
}
return false
}

View File

@@ -1,70 +0,0 @@
package errorx
import (
"errors"
"testing"
)
func TestIn(t *testing.T) {
err1 := errors.New("error 1")
err2 := errors.New("error 2")
err3 := errors.New("error 3")
tests := []struct {
name string
err error
errs []error
want bool
}{
{
name: "Error matches one of the errors in the list",
err: err1,
errs: []error{err1, err2},
want: true,
},
{
name: "Error does not match any errors in the list",
err: err3,
errs: []error{err1, err2},
want: false,
},
{
name: "Empty error list",
err: err1,
errs: []error{},
want: false,
},
{
name: "Nil error with non-nil list",
err: nil,
errs: []error{err1, err2},
want: false,
},
{
name: "Non-nil error with nil in list",
err: err1,
errs: []error{nil, err2},
want: false,
},
{
name: "Error matches nil error in the list",
err: nil,
errs: []error{nil, err2},
want: true,
},
{
name: "Nil error with empty list",
err: nil,
errs: []error{},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := In(tt.err, tt.errs...); got != tt.want {
t.Errorf("In() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -12,7 +12,7 @@ func Wrap(err error, message string) error {
}
// Wrapf returns an error that wraps err with given format and args.
func Wrapf(err error, format string, args ...any) error {
func Wrapf(err error, format string, args ...interface{}) error {
if err == nil {
return nil
}

View File

@@ -42,7 +42,7 @@ func NewBulkExecutor(execute Execute, opts ...BulkOption) *BulkExecutor {
}
// Add adds task into be.
func (be *BulkExecutor) Add(task any) error {
func (be *BulkExecutor) Add(task interface{}) error {
be.executor.Add(task)
return nil
}
@@ -79,22 +79,22 @@ func newBulkOptions() bulkOptions {
}
type bulkContainer struct {
tasks []any
tasks []interface{}
execute Execute
maxTasks int
}
func (bc *bulkContainer) AddTask(task any) bool {
func (bc *bulkContainer) AddTask(task interface{}) bool {
bc.tasks = append(bc.tasks, task)
return len(bc.tasks) >= bc.maxTasks
}
func (bc *bulkContainer) Execute(tasks any) {
vals := tasks.([]any)
func (bc *bulkContainer) Execute(tasks interface{}) {
vals := tasks.([]interface{})
bc.execute(vals)
}
func (bc *bulkContainer) RemoveAll() any {
func (bc *bulkContainer) RemoveAll() interface{} {
tasks := bc.tasks
bc.tasks = nil
return tasks

View File

@@ -12,7 +12,7 @@ func TestBulkExecutor(t *testing.T) {
var values []int
var lock sync.Mutex
executor := NewBulkExecutor(func(items []any) {
executor := NewBulkExecutor(func(items []interface{}) {
lock.Lock()
values = append(values, len(items))
lock.Unlock()
@@ -40,7 +40,7 @@ func TestBulkExecutorFlushInterval(t *testing.T) {
var wait sync.WaitGroup
wait.Add(1)
executor := NewBulkExecutor(func(items []any) {
executor := NewBulkExecutor(func(items []interface{}) {
assert.Equal(t, size, len(items))
wait.Done()
}, WithBulkTasks(caches), WithBulkInterval(time.Millisecond*100))
@@ -53,7 +53,7 @@ func TestBulkExecutorFlushInterval(t *testing.T) {
}
func TestBulkExecutorEmpty(t *testing.T) {
NewBulkExecutor(func(items []any) {
NewBulkExecutor(func(items []interface{}) {
assert.Fail(t, "should not called")
}, WithBulkTasks(10), WithBulkInterval(time.Millisecond))
time.Sleep(time.Millisecond * 100)
@@ -67,7 +67,7 @@ func TestBulkExecutorFlush(t *testing.T) {
var wait sync.WaitGroup
wait.Add(1)
be := NewBulkExecutor(func(items []any) {
be := NewBulkExecutor(func(items []interface{}) {
assert.Equal(t, tasks, len(items))
wait.Done()
}, WithBulkTasks(caches), WithBulkInterval(time.Minute))
@@ -78,11 +78,11 @@ func TestBulkExecutorFlush(t *testing.T) {
wait.Wait()
}
func TestBulkExecutorFlushSlowTasks(t *testing.T) {
func TestBuldExecutorFlushSlowTasks(t *testing.T) {
const total = 1500
lock := new(sync.Mutex)
result := make([]any, 0, 10000)
exec := NewBulkExecutor(func(tasks []any) {
result := make([]interface{}, 0, 10000)
exec := NewBulkExecutor(func(tasks []interface{}) {
time.Sleep(time.Millisecond * 100)
lock.Lock()
defer lock.Unlock()
@@ -100,7 +100,7 @@ func TestBulkExecutorFlushSlowTasks(t *testing.T) {
func BenchmarkBulkExecutor(b *testing.B) {
b.ReportAllocs()
be := NewBulkExecutor(func(tasks []any) {
be := NewBulkExecutor(func(tasks []interface{}) {
time.Sleep(time.Millisecond * time.Duration(len(tasks)))
})
for i := 0; i < b.N; i++ {

View File

@@ -42,7 +42,7 @@ func NewChunkExecutor(execute Execute, opts ...ChunkOption) *ChunkExecutor {
}
// Add adds task with given chunk size into ce.
func (ce *ChunkExecutor) Add(task any, size int) error {
func (ce *ChunkExecutor) Add(task interface{}, size int) error {
ce.executor.Add(chunk{
val: task,
size: size,
@@ -82,25 +82,25 @@ func newChunkOptions() chunkOptions {
}
type chunkContainer struct {
tasks []any
tasks []interface{}
execute Execute
size int
maxChunkSize int
}
func (bc *chunkContainer) AddTask(task any) bool {
func (bc *chunkContainer) AddTask(task interface{}) bool {
ck := task.(chunk)
bc.tasks = append(bc.tasks, ck.val)
bc.size += ck.size
return bc.size >= bc.maxChunkSize
}
func (bc *chunkContainer) Execute(tasks any) {
vals := tasks.([]any)
func (bc *chunkContainer) Execute(tasks interface{}) {
vals := tasks.([]interface{})
bc.execute(vals)
}
func (bc *chunkContainer) RemoveAll() any {
func (bc *chunkContainer) RemoveAll() interface{} {
tasks := bc.tasks
bc.tasks = nil
bc.size = 0
@@ -108,6 +108,6 @@ func (bc *chunkContainer) RemoveAll() any {
}
type chunk struct {
val any
val interface{}
size int
}

View File

@@ -12,7 +12,7 @@ func TestChunkExecutor(t *testing.T) {
var values []int
var lock sync.Mutex
executor := NewChunkExecutor(func(items []any) {
executor := NewChunkExecutor(func(items []interface{}) {
lock.Lock()
values = append(values, len(items))
lock.Unlock()
@@ -40,7 +40,7 @@ func TestChunkExecutorFlushInterval(t *testing.T) {
var wait sync.WaitGroup
wait.Add(1)
executor := NewChunkExecutor(func(items []any) {
executor := NewChunkExecutor(func(items []interface{}) {
assert.Equal(t, size, len(items))
wait.Done()
}, WithChunkBytes(caches), WithFlushInterval(time.Millisecond*100))
@@ -53,7 +53,7 @@ func TestChunkExecutorFlushInterval(t *testing.T) {
}
func TestChunkExecutorEmpty(t *testing.T) {
executor := NewChunkExecutor(func(items []any) {
executor := NewChunkExecutor(func(items []interface{}) {
assert.Fail(t, "should not called")
}, WithChunkBytes(10), WithFlushInterval(time.Millisecond))
time.Sleep(time.Millisecond * 100)
@@ -68,7 +68,7 @@ func TestChunkExecutorFlush(t *testing.T) {
var wait sync.WaitGroup
wait.Add(1)
be := NewChunkExecutor(func(items []any) {
be := NewChunkExecutor(func(items []interface{}) {
assert.Equal(t, tasks, len(items))
wait.Done()
}, WithChunkBytes(caches), WithFlushInterval(time.Minute))
@@ -82,7 +82,7 @@ func TestChunkExecutorFlush(t *testing.T) {
func BenchmarkChunkExecutor(b *testing.B) {
b.ReportAllocs()
be := NewChunkExecutor(func(tasks []any) {
be := NewChunkExecutor(func(tasks []interface{}) {
time.Sleep(time.Millisecond * time.Duration(len(tasks)))
})
for i := 0; i < b.N; i++ {

View File

@@ -21,16 +21,16 @@ type (
TaskContainer interface {
// AddTask adds the task into the container.
// Returns true if the container needs to be flushed after the addition.
AddTask(task any) bool
AddTask(task interface{}) bool
// Execute handles the collected tasks by the container when flushing.
Execute(tasks any)
Execute(tasks interface{})
// RemoveAll removes the contained tasks, and return them.
RemoveAll() any
RemoveAll() interface{}
}
// A PeriodicalExecutor is an executor that periodically execute tasks.
PeriodicalExecutor struct {
commander chan any
commander chan interface{}
interval time.Duration
container TaskContainer
waitGroup sync.WaitGroup
@@ -48,7 +48,7 @@ type (
func NewPeriodicalExecutor(interval time.Duration, container TaskContainer) *PeriodicalExecutor {
executor := &PeriodicalExecutor{
// buffer 1 to let the caller go quickly
commander: make(chan any, 1),
commander: make(chan interface{}, 1),
interval: interval,
container: container,
confirmChan: make(chan lang.PlaceholderType),
@@ -64,7 +64,7 @@ func NewPeriodicalExecutor(interval time.Duration, container TaskContainer) *Per
}
// Add adds tasks into pe.
func (pe *PeriodicalExecutor) Add(task any) {
func (pe *PeriodicalExecutor) Add(task interface{}) {
if vals, ok := pe.addAndCheck(task); ok {
pe.commander <- vals
<-pe.confirmChan
@@ -74,14 +74,14 @@ func (pe *PeriodicalExecutor) Add(task any) {
// Flush forces pe to execute tasks.
func (pe *PeriodicalExecutor) Flush() bool {
pe.enterExecution()
return pe.executeTasks(func() any {
return pe.executeTasks(func() interface{} {
pe.lock.Lock()
defer pe.lock.Unlock()
return pe.container.RemoveAll()
}())
}
// Sync lets caller run fn thread-safe with pe, especially for the underlying container.
// Sync lets caller to run fn thread-safe with pe, especially for the underlying container.
func (pe *PeriodicalExecutor) Sync(fn func()) {
pe.lock.Lock()
defer pe.lock.Unlock()
@@ -96,7 +96,7 @@ func (pe *PeriodicalExecutor) Wait() {
})
}
func (pe *PeriodicalExecutor) addAndCheck(task any) (any, bool) {
func (pe *PeriodicalExecutor) addAndCheck(task interface{}) (interface{}, bool) {
pe.lock.Lock()
defer func() {
if !pe.guarded {
@@ -116,7 +116,7 @@ func (pe *PeriodicalExecutor) addAndCheck(task any) (any, bool) {
}
func (pe *PeriodicalExecutor) backgroundFlush() {
go func() {
threading.GoSafe(func() {
// flush before quit goroutine to avoid missing tasks
defer pe.Flush()
@@ -144,7 +144,7 @@ func (pe *PeriodicalExecutor) backgroundFlush() {
}
}
}
}()
})
}
func (pe *PeriodicalExecutor) doneExecution() {
@@ -157,20 +157,18 @@ func (pe *PeriodicalExecutor) enterExecution() {
})
}
func (pe *PeriodicalExecutor) executeTasks(tasks any) bool {
func (pe *PeriodicalExecutor) executeTasks(tasks interface{}) bool {
defer pe.doneExecution()
ok := pe.hasTasks(tasks)
if ok {
threading.RunSafe(func() {
pe.container.Execute(tasks)
})
pe.container.Execute(tasks)
}
return ok
}
func (pe *PeriodicalExecutor) hasTasks(tasks any) bool {
func (pe *PeriodicalExecutor) hasTasks(tasks interface{}) bool {
if tasks == nil {
return false
}

View File

@@ -17,22 +17,22 @@ const threshold = 10
type container struct {
interval time.Duration
tasks []int
execute func(tasks any)
execute func(tasks interface{})
}
func newContainer(interval time.Duration, execute func(tasks any)) *container {
func newContainer(interval time.Duration, execute func(tasks interface{})) *container {
return &container{
interval: interval,
execute: execute,
}
}
func (c *container) AddTask(task any) bool {
func (c *container) AddTask(task interface{}) bool {
c.tasks = append(c.tasks, task.(int))
return len(c.tasks) > threshold
}
func (c *container) Execute(tasks any) {
func (c *container) Execute(tasks interface{}) {
if c.execute != nil {
c.execute(tasks)
} else {
@@ -40,7 +40,7 @@ func (c *container) Execute(tasks any) {
}
}
func (c *container) RemoveAll() any {
func (c *container) RemoveAll() interface{} {
tasks := c.tasks
c.tasks = nil
return tasks
@@ -76,7 +76,7 @@ func TestPeriodicalExecutor_Bulk(t *testing.T) {
var vals []int
// avoid data race
var lock sync.Mutex
exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, func(tasks any) {
exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, func(tasks interface{}) {
t := tasks.([]int)
for _, each := range t {
lock.Lock()
@@ -108,83 +108,25 @@ func TestPeriodicalExecutor_Bulk(t *testing.T) {
lock.Unlock()
}
func TestPeriodicalExecutor_Panic(t *testing.T) {
// avoid data race
var lock sync.Mutex
ticker := timex.NewFakeTicker()
var (
executedTasks []int
expected []int
)
executor := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, func(tasks any) {
tt := tasks.([]int)
lock.Lock()
executedTasks = append(executedTasks, tt...)
lock.Unlock()
if tt[0] == 0 {
panic("test")
}
}))
executor.newTicker = func(duration time.Duration) timex.Ticker {
return ticker
}
for i := 0; i < 30; i++ {
executor.Add(i)
expected = append(expected, i)
}
ticker.Tick()
ticker.Tick()
time.Sleep(time.Millisecond)
lock.Lock()
assert.Equal(t, expected, executedTasks)
lock.Unlock()
}
func TestPeriodicalExecutor_FlushPanic(t *testing.T) {
var (
executedTasks []int
expected []int
lock sync.Mutex
)
executor := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, func(tasks any) {
tt := tasks.([]int)
lock.Lock()
executedTasks = append(executedTasks, tt...)
lock.Unlock()
if tt[0] == 0 {
panic("flush panic")
}
}))
for i := 0; i < 8; i++ {
executor.Add(i)
expected = append(expected, i)
}
executor.Flush()
lock.Lock()
assert.Equal(t, expected, executedTasks)
lock.Unlock()
}
func TestPeriodicalExecutor_Wait(t *testing.T) {
var lock sync.Mutex
executor := NewBulkExecutor(func(tasks []any) {
executer := NewBulkExecutor(func(tasks []interface{}) {
lock.Lock()
defer lock.Unlock()
time.Sleep(10 * time.Millisecond)
}, WithBulkTasks(1), WithBulkInterval(time.Second))
for i := 0; i < 10; i++ {
executor.Add(1)
executer.Add(1)
}
executor.Flush()
executor.Wait()
executer.Flush()
executer.Wait()
}
func TestPeriodicalExecutor_WaitFast(t *testing.T) {
const total = 3
var cnt int
var lock sync.Mutex
executor := NewBulkExecutor(func(tasks []any) {
executer := NewBulkExecutor(func(tasks []interface{}) {
defer func() {
cnt++
}()
@@ -193,15 +135,15 @@ func TestPeriodicalExecutor_WaitFast(t *testing.T) {
time.Sleep(10 * time.Millisecond)
}, WithBulkTasks(1), WithBulkInterval(10*time.Millisecond))
for i := 0; i < total; i++ {
executor.Add(2)
executer.Add(2)
}
executor.Flush()
executor.Wait()
executer.Flush()
executer.Wait()
assert.Equal(t, total, cnt)
}
func TestPeriodicalExecutor_Deadlock(t *testing.T) {
executor := NewBulkExecutor(func(tasks []any) {
executor := NewBulkExecutor(func(tasks []interface{}) {
}, WithBulkTasks(1), WithBulkInterval(time.Millisecond))
for i := 0; i < 1e5; i++ {
executor.Add(1)
@@ -209,7 +151,13 @@ func TestPeriodicalExecutor_Deadlock(t *testing.T) {
}
func TestPeriodicalExecutor_hasTasks(t *testing.T) {
ticker := timex.NewFakeTicker()
defer ticker.Stop()
exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, nil))
exec.newTicker = func(d time.Duration) timex.Ticker {
return ticker
}
assert.False(t, exec.hasTasks(nil))
assert.True(t, exec.hasTasks(1))
}

View File

@@ -5,4 +5,4 @@ import "time"
const defaultFlushInterval = time.Second
// Execute defines the method to execute tasks.
type Execute func(tasks []any)
type Execute func(tasks []interface{})

View File

@@ -35,7 +35,6 @@ func firstLine(file *os.File) (string, error) {
for {
buf := make([]byte, bufSize)
n, err := file.ReadAt(buf, offset)
if err != nil && err != io.EOF {
return "", err
}
@@ -46,10 +45,6 @@ func firstLine(file *os.File) (string, error) {
}
}
if err == io.EOF {
return string(append(first, buf[:n]...)), nil
}
first = append(first, buf[:n]...)
offset += bufSize
}
@@ -62,42 +57,30 @@ func lastLine(filename string, file *os.File) (string, error) {
}
var last []byte
bufLen := int64(bufSize)
offset := info.Size()
for offset > 0 {
if offset < bufLen {
bufLen = offset
for {
offset -= bufSize
if offset < 0 {
offset = 0
} else {
offset -= bufLen
}
buf := make([]byte, bufLen)
buf := make([]byte, bufSize)
n, err := file.ReadAt(buf, offset)
if err != nil && err != io.EOF {
return "", err
}
if n == 0 {
break
}
if buf[n-1] == '\n' {
buf = buf[:n-1]
n--
} else {
buf = buf[:n]
}
for i := n - 1; i >= 0; i-- {
if buf[i] == '\n' {
return string(append(buf[i+1:], last...)), nil
for n--; n >= 0; n-- {
if buf[n] == '\n' {
return string(append(buf[n+1:], last...)), nil
}
}
last = append(buf, last...)
}
return string(last), nil
}

View File

@@ -52,7 +52,6 @@ last line`
second line
last line
`
emptyContent = ``
)
func TestFirstLine(t *testing.T) {
@@ -75,31 +74,6 @@ func TestFirstLineShort(t *testing.T) {
assert.Equal(t, "first line", val)
}
func TestFirstLineError(t *testing.T) {
_, err := FirstLine("/tmp/does-not-exist")
assert.Error(t, err)
}
func TestFirstLineEmptyFile(t *testing.T) {
filename, err := fs.TempFilenameWithText(emptyContent)
assert.Nil(t, err)
defer os.Remove(filename)
val, err := FirstLine(filename)
assert.Nil(t, err)
assert.Equal(t, "", val)
}
func TestFirstLineWithoutNewline(t *testing.T) {
filename, err := fs.TempFilenameWithText(longLine)
assert.Nil(t, err)
defer os.Remove(filename)
val, err := FirstLine(filename)
assert.Nil(t, err)
assert.Equal(t, longLine, val)
}
func TestLastLine(t *testing.T) {
filename, err := fs.TempFilenameWithText(text)
assert.Nil(t, err)
@@ -120,16 +94,6 @@ func TestLastLineWithLastNewline(t *testing.T) {
assert.Equal(t, longLine, val)
}
func TestLastLineWithoutLastNewline(t *testing.T) {
filename, err := fs.TempFilenameWithText(longLine)
assert.Nil(t, err)
defer os.Remove(filename)
val, err := LastLine(filename)
assert.Nil(t, err)
assert.Equal(t, longLine, val)
}
func TestLastLineShort(t *testing.T) {
filename, err := fs.TempFilenameWithText(shortText)
assert.Nil(t, err)
@@ -149,72 +113,3 @@ func TestLastLineWithLastNewlineShort(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, "last line", val)
}
func TestLastLineError(t *testing.T) {
_, err := LastLine("/tmp/does-not-exist")
assert.Error(t, err)
}
func TestLastLineEmptyFile(t *testing.T) {
filename, err := fs.TempFilenameWithText(emptyContent)
assert.Nil(t, err)
defer os.Remove(filename)
val, err := LastLine(filename)
assert.Nil(t, err)
assert.Equal(t, "", val)
}
func TestFirstLineExactlyBufSize(t *testing.T) {
content := make([]byte, bufSize)
for i := range content {
content[i] = 'a'
}
content[bufSize-1] = '\n' // Ensure there is a newline at the edge
filename, err := fs.TempFilenameWithText(string(content))
assert.Nil(t, err)
defer os.Remove(filename)
val, err := FirstLine(filename)
assert.Nil(t, err)
assert.Equal(t, string(content[:bufSize-1]), val)
}
func TestLastLineExactlyBufSize(t *testing.T) {
content := make([]byte, bufSize)
for i := range content {
content[i] = 'a'
}
content[bufSize-1] = '\n' // Ensure there is a newline at the edge
filename, err := fs.TempFilenameWithText(string(content))
assert.Nil(t, err)
defer os.Remove(filename)
val, err := LastLine(filename)
assert.Nil(t, err)
assert.Equal(t, string(content[:bufSize-1]), val)
}
func TestFirstLineLargeFile(t *testing.T) {
content := text + text + text + "\n" + "extra"
filename, err := fs.TempFilenameWithText(content)
assert.Nil(t, err)
defer os.Remove(filename)
val, err := FirstLine(filename)
assert.Nil(t, err)
assert.Equal(t, "first line", val)
}
func TestLastLineLargeFile(t *testing.T) {
content := text + text + text + "\n" + "extra"
filename, err := fs.TempFilenameWithText(content)
assert.Nil(t, err)
defer os.Remove(filename)
val, err := LastLine(filename)
assert.Nil(t, err)
assert.Equal(t, "extra", val)
}

View File

@@ -5,7 +5,7 @@ import "gopkg.in/cheggaaa/pb.v1"
type (
// A Scanner is used to read lines.
Scanner interface {
// Scan checks if it has remaining to read.
// Scan checks if has remaining to read.
Scan() bool
// Text returns next line.
Text() string

View File

@@ -1,4 +1,5 @@
//go:build windows
// +build windows
package fs

View File

@@ -1,4 +1,5 @@
//go:build linux || darwin || freebsd
//go:build linux || darwin
// +build linux darwin
package fs

View File

@@ -11,29 +11,29 @@ import (
// The file is kept as open, the caller should close the file handle,
// and remove the file by name.
func TempFileWithText(text string) (*os.File, error) {
tmpFile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text)))
tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text)))
if err != nil {
return nil, err
}
if err := os.WriteFile(tmpFile.Name(), []byte(text), os.ModeTemporary); err != nil {
if err := os.WriteFile(tmpfile.Name(), []byte(text), os.ModeTemporary); err != nil {
return nil, err
}
return tmpFile, nil
return tmpfile, nil
}
// TempFilenameWithText creates the file with the given content,
// and returns the filename (full path).
// The caller should remove the file after use.
func TempFilenameWithText(text string) (string, error) {
tmpFile, err := TempFileWithText(text)
tmpfile, err := TempFileWithText(text)
if err != nil {
return "", err
}
filename := tmpFile.Name()
if err = tmpFile.Close(); err != nil {
filename := tmpfile.Name()
if err = tmpfile.Close(); err != nil {
return "", err
}

View File

@@ -1,9 +1,6 @@
package fx
import (
"github.com/zeromicro/go-zero/core/errorx"
"github.com/zeromicro/go-zero/core/threading"
)
import "github.com/zeromicro/go-zero/core/threading"
// Parallel runs fns parallelly and waits for done.
func Parallel(fns ...func()) {
@@ -13,20 +10,3 @@ func Parallel(fns ...func()) {
}
group.Wait()
}
func ParallelErr(fns ...func() error) error {
var be errorx.BatchError
group := threading.NewRoutineGroup()
for _, fn := range fns {
f := fn
group.RunSafe(func() {
if err := f(); err != nil {
be.Add(err)
}
})
}
group.Wait()
return be.Err()
}

View File

@@ -1,7 +1,6 @@
package fx
import (
"errors"
"sync/atomic"
"testing"
"time"
@@ -23,54 +22,3 @@ func TestParallel(t *testing.T) {
})
assert.Equal(t, int32(6), count)
}
func TestParallelErr(t *testing.T) {
var count int32
err := ParallelErr(
func() error {
time.Sleep(time.Millisecond * 100)
atomic.AddInt32(&count, 1)
return errors.New("failed to exec #1")
},
func() error {
time.Sleep(time.Millisecond * 100)
atomic.AddInt32(&count, 2)
return errors.New("failed to exec #2")
},
func() error {
time.Sleep(time.Millisecond * 100)
atomic.AddInt32(&count, 3)
return nil
},
)
assert.Equal(t, int32(6), count)
assert.Error(t, err)
assert.ErrorContains(t, err, "failed to exec #1", "failed to exec #2")
}
func TestParallelErrErrorNil(t *testing.T) {
var count int32
err := ParallelErr(
func() error {
time.Sleep(time.Millisecond * 100)
atomic.AddInt32(&count, 1)
return nil
},
func() error {
time.Sleep(time.Millisecond * 100)
atomic.AddInt32(&count, 2)
return nil
},
func() error {
time.Sleep(time.Millisecond * 100)
atomic.AddInt32(&count, 3)
return nil
},
)
assert.Equal(t, int32(6), count)
assert.NoError(t, err)
}

View File

@@ -1,12 +1,6 @@
package fx
import (
"context"
"errors"
"time"
"github.com/zeromicro/go-zero/core/errorx"
)
import "github.com/zeromicro/go-zero/core/errorx"
const defaultRetryTimes = 3
@@ -15,110 +9,36 @@ type (
RetryOption func(*retryOptions)
retryOptions struct {
times int
interval time.Duration
timeout time.Duration
ignoreErrors []error
times int
}
)
// DoWithRetry runs fn, and retries if failed. Default to retry 3 times.
// Note that if the fn function accesses global variables outside the function
// and performs modification operations, it is best to lock them,
// otherwise there may be data race issues
func DoWithRetry(fn func() error, opts ...RetryOption) error {
return retry(context.Background(), func(errChan chan error, retryCount int) {
errChan <- fn()
}, opts...)
}
// DoWithRetryCtx runs fn, and retries if failed. Default to retry 3 times.
// fn retryCount indicates the current number of retries, starting from 0
// Note that if the fn function accesses global variables outside the function
// and performs modification operations, it is best to lock them,
// otherwise there may be data race issues
func DoWithRetryCtx(ctx context.Context, fn func(ctx context.Context, retryCount int) error,
opts ...RetryOption) error {
return retry(ctx, func(errChan chan error, retryCount int) {
errChan <- fn(ctx, retryCount)
}, opts...)
}
func retry(ctx context.Context, fn func(errChan chan error, retryCount int), opts ...RetryOption) error {
options := newRetryOptions()
for _, opt := range opts {
opt(options)
}
var berr errorx.BatchError
var cancelFunc context.CancelFunc
if options.timeout > 0 {
ctx, cancelFunc = context.WithTimeout(ctx, options.timeout)
defer cancelFunc()
}
errChan := make(chan error, 1)
for i := 0; i < options.times; i++ {
go fn(errChan, i)
select {
case err := <-errChan:
if err != nil {
for _, ignoreErr := range options.ignoreErrors {
if errors.Is(err, ignoreErr) {
return nil
}
}
berr.Add(err)
} else {
return nil
}
case <-ctx.Done():
berr.Add(ctx.Err())
return berr.Err()
}
if options.interval > 0 {
select {
case <-ctx.Done():
berr.Add(ctx.Err())
return berr.Err()
case <-time.After(options.interval):
}
if err := fn(); err != nil {
berr.Add(err)
} else {
return nil
}
}
return berr.Err()
}
// WithIgnoreErrors Ignore the specified errors
func WithIgnoreErrors(ignoreErrors []error) RetryOption {
return func(options *retryOptions) {
options.ignoreErrors = ignoreErrors
}
}
// WithInterval customizes a DoWithRetry call with given interval.
func WithInterval(interval time.Duration) RetryOption {
return func(options *retryOptions) {
options.interval = interval
}
}
// WithRetry customizes a DoWithRetry call with given retry times.
// WithRetry customize a DoWithRetry call with given retry times.
func WithRetry(times int) RetryOption {
return func(options *retryOptions) {
options.times = times
}
}
// WithTimeout customizes a DoWithRetry call with given timeout.
func WithTimeout(timeout time.Duration) RetryOption {
return func(options *retryOptions) {
options.timeout = timeout
}
}
func newRetryOptions() *retryOptions {
return &retryOptions{
times: defaultRetryTimes,

View File

@@ -1,10 +1,8 @@
package fx
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
@@ -14,153 +12,31 @@ func TestRetry(t *testing.T) {
return errors.New("any")
}))
times1 := 0
var times int
assert.Nil(t, DoWithRetry(func() error {
times1++
if times1 == defaultRetryTimes {
times++
if times == defaultRetryTimes {
return nil
}
return errors.New("any")
}))
times2 := 0
times = 0
assert.NotNil(t, DoWithRetry(func() error {
times2++
if times2 == defaultRetryTimes+1 {
times++
if times == defaultRetryTimes+1 {
return nil
}
return errors.New("any")
}))
total := 2 * defaultRetryTimes
times3 := 0
times = 0
assert.Nil(t, DoWithRetry(func() error {
times3++
if times3 == total {
times++
if times == total {
return nil
}
return errors.New("any")
}, WithRetry(total)))
}
func TestRetryWithTimeout(t *testing.T) {
assert.Nil(t, DoWithRetry(func() error {
return nil
}, WithTimeout(time.Millisecond*500)))
times1 := 0
assert.Nil(t, DoWithRetry(func() error {
times1++
if times1 == 1 {
return errors.New("any ")
}
time.Sleep(time.Millisecond * 150)
return nil
}, WithTimeout(time.Millisecond*250)))
total := defaultRetryTimes
times2 := 0
assert.Nil(t, DoWithRetry(func() error {
times2++
if times2 == total {
return nil
}
time.Sleep(time.Millisecond * 50)
return errors.New("any")
}, WithTimeout(time.Millisecond*50*(time.Duration(total)+2))))
assert.NotNil(t, DoWithRetry(func() error {
return errors.New("any")
}, WithTimeout(time.Millisecond*250)))
}
func TestRetryWithInterval(t *testing.T) {
times1 := 0
assert.NotNil(t, DoWithRetry(func() error {
times1++
if times1 == 1 {
return errors.New("any")
}
time.Sleep(time.Millisecond * 150)
return nil
}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
times2 := 0
assert.NotNil(t, DoWithRetry(func() error {
times2++
if times2 == 2 {
return nil
}
time.Sleep(time.Millisecond * 150)
return errors.New("any ")
}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
}
func TestRetryWithWithIgnoreErrors(t *testing.T) {
ignoreErr1 := errors.New("ignore error1")
ignoreErr2 := errors.New("ignore error2")
ignoreErrs := []error{ignoreErr1, ignoreErr2}
assert.Nil(t, DoWithRetry(func() error {
return ignoreErr1
}, WithIgnoreErrors(ignoreErrs)))
assert.Nil(t, DoWithRetry(func() error {
return ignoreErr2
}, WithIgnoreErrors(ignoreErrs)))
assert.NotNil(t, DoWithRetry(func() error {
return errors.New("any")
}))
}
func TestRetryCtx(t *testing.T) {
t.Run("with timeout", func(t *testing.T) {
assert.NotNil(t, DoWithRetryCtx(context.Background(), func(ctx context.Context, retryCount int) error {
if retryCount == 0 {
return errors.New("any")
}
time.Sleep(time.Millisecond * 150)
return nil
}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
assert.NotNil(t, DoWithRetryCtx(context.Background(), func(ctx context.Context, retryCount int) error {
if retryCount == 1 {
return nil
}
time.Sleep(time.Millisecond * 150)
return errors.New("any ")
}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
})
t.Run("with deadline exceeded", func(t *testing.T) {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Millisecond*250))
defer cancel()
var times int
assert.Error(t, DoWithRetryCtx(ctx, func(ctx context.Context, retryCount int) error {
times++
time.Sleep(time.Millisecond * 150)
return errors.New("any")
}, WithInterval(time.Millisecond*150)))
assert.Equal(t, 1, times)
})
t.Run("with deadline not exceeded", func(t *testing.T) {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Millisecond*250))
defer cancel()
var times int
assert.NoError(t, DoWithRetryCtx(ctx, func(ctx context.Context, retryCount int) error {
times++
if times == defaultRetryTimes {
return nil
}
time.Sleep(time.Millisecond * 50)
return errors.New("any")
}))
assert.Equal(t, defaultRetryTimes, times)
})
}

View File

@@ -21,31 +21,31 @@ type (
}
// FilterFunc defines the method to filter a Stream.
FilterFunc func(item any) bool
FilterFunc func(item interface{}) bool
// ForAllFunc defines the method to handle all elements in a Stream.
ForAllFunc func(pipe <-chan any)
ForAllFunc func(pipe <-chan interface{})
// ForEachFunc defines the method to handle each element in a Stream.
ForEachFunc func(item any)
ForEachFunc func(item interface{})
// GenerateFunc defines the method to send elements into a Stream.
GenerateFunc func(source chan<- any)
GenerateFunc func(source chan<- interface{})
// KeyFunc defines the method to generate keys for the elements in a Stream.
KeyFunc func(item any) any
KeyFunc func(item interface{}) interface{}
// LessFunc defines the method to compare the elements in a Stream.
LessFunc func(a, b any) bool
LessFunc func(a, b interface{}) bool
// MapFunc defines the method to map each element to another object in a Stream.
MapFunc func(item any) any
MapFunc func(item interface{}) interface{}
// Option defines the method to customize a Stream.
Option func(opts *rxOptions)
// ParallelFunc defines the method to handle elements parallelly.
ParallelFunc func(item any)
ParallelFunc func(item interface{})
// ReduceFunc defines the method to reduce all the elements in a Stream.
ReduceFunc func(pipe <-chan any) (any, error)
ReduceFunc func(pipe <-chan interface{}) (interface{}, error)
// WalkFunc defines the method to walk through all the elements in a Stream.
WalkFunc func(item any, pipe chan<- any)
WalkFunc func(item interface{}, pipe chan<- interface{})
// A Stream is a stream that can be used to do stream processing.
Stream struct {
source <-chan any
source <-chan interface{}
}
)
@@ -56,7 +56,7 @@ func Concat(s Stream, others ...Stream) Stream {
// From constructs a Stream from the given GenerateFunc.
func From(generate GenerateFunc) Stream {
source := make(chan any)
source := make(chan interface{})
threading.GoSafe(func() {
defer close(source)
@@ -67,8 +67,8 @@ func From(generate GenerateFunc) Stream {
}
// Just converts the given arbitrary items to a Stream.
func Just(items ...any) Stream {
source := make(chan any, len(items))
func Just(items ...interface{}) Stream {
source := make(chan interface{}, len(items))
for _, item := range items {
source <- item
}
@@ -78,16 +78,16 @@ func Just(items ...any) Stream {
}
// Range converts the given channel to a Stream.
func Range(source <-chan any) Stream {
func Range(source <-chan interface{}) Stream {
return Stream{
source: source,
}
}
// AllMatch returns whether all elements of this stream match the provided predicate.
// AllMach returns whether all elements of this stream match the provided predicate.
// May not evaluate the predicate on all elements if not necessary for determining the result.
// If the stream is empty then true is returned and the predicate is not evaluated.
func (s Stream) AllMatch(predicate func(item any) bool) bool {
func (s Stream) AllMach(predicate func(item interface{}) bool) bool {
for item := range s.source {
if !predicate(item) {
// make sure the former goroutine not block, and current func returns fast.
@@ -99,10 +99,10 @@ func (s Stream) AllMatch(predicate func(item any) bool) bool {
return true
}
// AnyMatch returns whether any elements of this stream match the provided predicate.
// AnyMach returns whether any elements of this stream match the provided predicate.
// May not evaluate the predicate on all elements if not necessary for determining the result.
// If the stream is empty then false is returned and the predicate is not evaluated.
func (s Stream) AnyMatch(predicate func(item any) bool) bool {
func (s Stream) AnyMach(predicate func(item interface{}) bool) bool {
for item := range s.source {
if predicate(item) {
// make sure the former goroutine not block, and current func returns fast.
@@ -121,7 +121,7 @@ func (s Stream) Buffer(n int) Stream {
n = 0
}
source := make(chan any, n)
source := make(chan interface{}, n)
go func() {
for item := range s.source {
source <- item
@@ -134,7 +134,7 @@ func (s Stream) Buffer(n int) Stream {
// Concat returns a Stream that concatenated other streams
func (s Stream) Concat(others ...Stream) Stream {
source := make(chan any)
source := make(chan interface{})
go func() {
group := threading.NewRoutineGroup()
@@ -168,14 +168,14 @@ func (s Stream) Count() (count int) {
return
}
// Distinct removes the duplicated items based on the given KeyFunc.
// Distinct removes the duplicated items base on the given KeyFunc.
func (s Stream) Distinct(fn KeyFunc) Stream {
source := make(chan any)
source := make(chan interface{})
threading.GoSafe(func() {
defer close(source)
keys := make(map[any]lang.PlaceholderType)
keys := make(map[interface{}]lang.PlaceholderType)
for item := range s.source {
key := fn(item)
if _, ok := keys[key]; !ok {
@@ -195,7 +195,7 @@ func (s Stream) Done() {
// Filter filters the items by the given FilterFunc.
func (s Stream) Filter(fn FilterFunc, opts ...Option) Stream {
return s.Walk(func(item any, pipe chan<- any) {
return s.Walk(func(item interface{}, pipe chan<- interface{}) {
if fn(item) {
pipe <- item
}
@@ -203,7 +203,7 @@ func (s Stream) Filter(fn FilterFunc, opts ...Option) Stream {
}
// First returns the first item, nil if no items.
func (s Stream) First() any {
func (s Stream) First() interface{} {
for item := range s.source {
// make sure the former goroutine not block, and current func returns fast.
go drain(s.source)
@@ -229,13 +229,13 @@ func (s Stream) ForEach(fn ForEachFunc) {
// Group groups the elements into different groups based on their keys.
func (s Stream) Group(fn KeyFunc) Stream {
groups := make(map[any][]any)
groups := make(map[interface{}][]interface{})
for item := range s.source {
key := fn(item)
groups[key] = append(groups[key], item)
}
source := make(chan any)
source := make(chan interface{})
go func() {
for _, group := range groups {
source <- group
@@ -252,7 +252,7 @@ func (s Stream) Head(n int64) Stream {
panic("n must be greater than 0")
}
source := make(chan any)
source := make(chan interface{})
go func() {
for item := range s.source {
@@ -279,7 +279,7 @@ func (s Stream) Head(n int64) Stream {
}
// Last returns the last item, or nil if no items.
func (s Stream) Last() (item any) {
func (s Stream) Last() (item interface{}) {
for item = range s.source {
}
return
@@ -287,53 +287,29 @@ func (s Stream) Last() (item any) {
// Map converts each item to another corresponding item, which means it's a 1:1 model.
func (s Stream) Map(fn MapFunc, opts ...Option) Stream {
return s.Walk(func(item any, pipe chan<- any) {
return s.Walk(func(item interface{}, pipe chan<- interface{}) {
pipe <- fn(item)
}, opts...)
}
// Max returns the maximum item from the underlying source.
func (s Stream) Max(less LessFunc) any {
var max any
for item := range s.source {
if max == nil || less(max, item) {
max = item
}
}
return max
}
// Merge merges all the items into a slice and generates a new stream.
func (s Stream) Merge() Stream {
var items []any
var items []interface{}
for item := range s.source {
items = append(items, item)
}
source := make(chan any, 1)
source := make(chan interface{}, 1)
source <- items
close(source)
return Range(source)
}
// Min returns the minimum item from the underlying source.
func (s Stream) Min(less LessFunc) any {
var min any
for item := range s.source {
if min == nil || less(item, min) {
min = item
}
}
return min
}
// NoneMatch returns whether all elements of this stream don't match the provided predicate.
// May not evaluate the predicate on all elements if not necessary for determining the result.
// If the stream is empty then true is returned and the predicate is not evaluated.
func (s Stream) NoneMatch(predicate func(item any) bool) bool {
func (s Stream) NoneMatch(predicate func(item interface{}) bool) bool {
for item := range s.source {
if predicate(item) {
// make sure the former goroutine not block, and current func returns fast.
@@ -347,19 +323,19 @@ func (s Stream) NoneMatch(predicate func(item any) bool) bool {
// Parallel applies the given ParallelFunc to each item concurrently with given number of workers.
func (s Stream) Parallel(fn ParallelFunc, opts ...Option) {
s.Walk(func(item any, pipe chan<- any) {
s.Walk(func(item interface{}, pipe chan<- interface{}) {
fn(item)
}, opts...).Done()
}
// Reduce is a utility method to let the caller deal with the underlying channel.
func (s Stream) Reduce(fn ReduceFunc) (any, error) {
// Reduce is an utility method to let the caller deal with the underlying channel.
func (s Stream) Reduce(fn ReduceFunc) (interface{}, error) {
return fn(s.source)
}
// Reverse reverses the elements in the stream.
func (s Stream) Reverse() Stream {
var items []any
var items []interface{}
for item := range s.source {
items = append(items, item)
}
@@ -381,7 +357,7 @@ func (s Stream) Skip(n int64) Stream {
return s
}
source := make(chan any)
source := make(chan interface{})
go func() {
for item := range s.source {
@@ -400,7 +376,7 @@ func (s Stream) Skip(n int64) Stream {
// Sort sorts the items from the underlying source.
func (s Stream) Sort(less LessFunc) Stream {
var items []any
var items []interface{}
for item := range s.source {
items = append(items, item)
}
@@ -418,9 +394,9 @@ func (s Stream) Split(n int) Stream {
panic("n should be greater than 0")
}
source := make(chan any)
source := make(chan interface{})
go func() {
var chunk []any
var chunk []interface{}
for item := range s.source {
chunk = append(chunk, item)
if len(chunk) == n {
@@ -443,7 +419,7 @@ func (s Stream) Tail(n int64) Stream {
panic("n should be greater than 0")
}
source := make(chan any)
source := make(chan interface{})
go func() {
ring := collection.NewRing(int(n))
@@ -459,7 +435,7 @@ func (s Stream) Tail(n int64) Stream {
return Range(source)
}
// Walk lets the callers handle each item, the caller may write zero, one or more items based on the given item.
// Walk lets the callers handle each item, the caller may write zero, one or more items base on the given item.
func (s Stream) Walk(fn WalkFunc, opts ...Option) Stream {
option := buildOptions(opts...)
if option.unlimitedWorkers {
@@ -470,7 +446,7 @@ func (s Stream) Walk(fn WalkFunc, opts ...Option) Stream {
}
func (s Stream) walkLimited(fn WalkFunc, option *rxOptions) Stream {
pipe := make(chan any, option.workers)
pipe := make(chan interface{}, option.workers)
go func() {
var wg sync.WaitGroup
@@ -501,7 +477,7 @@ func (s Stream) walkLimited(fn WalkFunc, option *rxOptions) Stream {
}
func (s Stream) walkUnlimited(fn WalkFunc, option *rxOptions) Stream {
pipe := make(chan any, option.workers)
pipe := make(chan interface{}, option.workers)
go func() {
var wg sync.WaitGroup
@@ -553,7 +529,7 @@ func buildOptions(opts ...Option) *rxOptions {
}
// drain drains the given channel.
func drain(channel <-chan any) {
func drain(channel <-chan interface{}) {
for range channel {
}
}

View File

@@ -1,6 +1,8 @@
package fx
import (
"io"
"log"
"math/rand"
"reflect"
"runtime"
@@ -11,7 +13,6 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx/logtest"
"github.com/zeromicro/go-zero/core/stringx"
"go.uber.org/goleak"
)
@@ -22,7 +23,7 @@ func TestBuffer(t *testing.T) {
var count int32
var wait sync.WaitGroup
wait.Add(1)
From(func(source chan<- any) {
From(func(source chan<- interface{}) {
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
@@ -35,7 +36,7 @@ func TestBuffer(t *testing.T) {
return
}
}
}).Buffer(N).ForAll(func(pipe <-chan any) {
}).Buffer(N).ForAll(func(pipe <-chan interface{}) {
wait.Wait()
// why N+1, because take one more to wait for sending into the channel
assert.Equal(t, int32(N+1), atomic.LoadInt32(&count))
@@ -46,7 +47,7 @@ func TestBuffer(t *testing.T) {
func TestBufferNegative(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var result int
Just(1, 2, 3, 4).Buffer(-1).Reduce(func(pipe <-chan any) (any, error) {
Just(1, 2, 3, 4).Buffer(-1).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
for item := range pipe {
result += item.(int)
}
@@ -60,22 +61,22 @@ func TestCount(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
tests := []struct {
name string
elements []any
elements []interface{}
}{
{
name: "no elements with nil",
},
{
name: "no elements",
elements: []any{},
elements: []interface{}{},
},
{
name: "1 element",
elements: []any{1},
elements: []interface{}{1},
},
{
name: "multiple elements",
elements: []any{1, 2, 3},
elements: []interface{}{1, 2, 3},
},
}
@@ -91,7 +92,7 @@ func TestCount(t *testing.T) {
func TestDone(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var count int32
Just(1, 2, 3).Walk(func(item any, pipe chan<- any) {
Just(1, 2, 3).Walk(func(item interface{}, pipe chan<- interface{}) {
time.Sleep(time.Millisecond * 100)
atomic.AddInt32(&count, int32(item.(int)))
}).Done()
@@ -102,7 +103,7 @@ func TestDone(t *testing.T) {
func TestJust(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var result int
Just(1, 2, 3, 4).Reduce(func(pipe <-chan any) (any, error) {
Just(1, 2, 3, 4).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
for item := range pipe {
result += item.(int)
}
@@ -115,9 +116,9 @@ func TestJust(t *testing.T) {
func TestDistinct(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var result int
Just(4, 1, 3, 2, 3, 4).Distinct(func(item any) any {
Just(4, 1, 3, 2, 3, 4).Distinct(func(item interface{}) interface{} {
return item
}).Reduce(func(pipe <-chan any) (any, error) {
}).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
for item := range pipe {
result += item.(int)
}
@@ -130,9 +131,9 @@ func TestDistinct(t *testing.T) {
func TestFilter(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var result int
Just(1, 2, 3, 4).Filter(func(item any) bool {
Just(1, 2, 3, 4).Filter(func(item interface{}) bool {
return item.(int)%2 == 0
}).Reduce(func(pipe <-chan any) (any, error) {
}).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
for item := range pipe {
result += item.(int)
}
@@ -153,9 +154,9 @@ func TestFirst(t *testing.T) {
func TestForAll(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var result int
Just(1, 2, 3, 4).Filter(func(item any) bool {
Just(1, 2, 3, 4).Filter(func(item interface{}) bool {
return item.(int)%2 == 0
}).ForAll(func(pipe <-chan any) {
}).ForAll(func(pipe <-chan interface{}) {
for item := range pipe {
result += item.(int)
}
@@ -167,11 +168,11 @@ func TestForAll(t *testing.T) {
func TestGroup(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var groups [][]int
Just(10, 11, 20, 21).Group(func(item any) any {
Just(10, 11, 20, 21).Group(func(item interface{}) interface{} {
v := item.(int)
return v / 10
}).ForEach(func(item any) {
v := item.([]any)
}).ForEach(func(item interface{}) {
v := item.([]interface{})
var group []int
for _, each := range v {
group = append(group, each.(int))
@@ -190,7 +191,7 @@ func TestGroup(t *testing.T) {
func TestHead(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var result int
Just(1, 2, 3, 4).Head(2).Reduce(func(pipe <-chan any) (any, error) {
Just(1, 2, 3, 4).Head(2).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
for item := range pipe {
result += item.(int)
}
@@ -203,7 +204,7 @@ func TestHead(t *testing.T) {
func TestHeadZero(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
assert.Panics(t, func() {
Just(1, 2, 3, 4).Head(0).Reduce(func(pipe <-chan any) (any, error) {
Just(1, 2, 3, 4).Head(0).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
return nil, nil
})
})
@@ -213,7 +214,7 @@ func TestHeadZero(t *testing.T) {
func TestHeadMore(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var result int
Just(1, 2, 3, 4).Head(6).Reduce(func(pipe <-chan any) (any, error) {
Just(1, 2, 3, 4).Head(6).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
for item := range pipe {
result += item.(int)
}
@@ -237,21 +238,21 @@ func TestLast(t *testing.T) {
func TestMap(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
logtest.Discard(t)
log.SetOutput(io.Discard)
tests := []struct {
mapper MapFunc
expect int
}{
{
mapper: func(item any) any {
mapper: func(item interface{}) interface{} {
v := item.(int)
return v * v
},
expect: 30,
},
{
mapper: func(item any) any {
mapper: func(item interface{}) interface{} {
v := item.(int)
if v%2 == 0 {
return 0
@@ -261,7 +262,7 @@ func TestMap(t *testing.T) {
expect: 10,
},
{
mapper: func(item any) any {
mapper: func(item interface{}) interface{} {
v := item.(int)
if v%2 == 0 {
panic(v)
@@ -282,12 +283,12 @@ func TestMap(t *testing.T) {
} else {
workers = runtime.NumCPU()
}
From(func(source chan<- any) {
From(func(source chan<- interface{}) {
for i := 1; i < 5; i++ {
source <- i
}
}).Map(test.mapper, WithWorkers(workers)).Reduce(
func(pipe <-chan any) (any, error) {
func(pipe <-chan interface{}) (interface{}, error) {
for item := range pipe {
result += item.(int)
}
@@ -302,8 +303,8 @@ func TestMap(t *testing.T) {
func TestMerge(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
Just(1, 2, 3, 4).Merge().ForEach(func(item any) {
assert.ElementsMatch(t, []any{1, 2, 3, 4}, item.([]any))
Just(1, 2, 3, 4).Merge().ForEach(func(item interface{}) {
assert.ElementsMatch(t, []interface{}{1, 2, 3, 4}, item.([]interface{}))
})
})
}
@@ -311,7 +312,7 @@ func TestMerge(t *testing.T) {
func TestParallelJust(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var count int32
Just(1, 2, 3).Parallel(func(item any) {
Just(1, 2, 3).Parallel(func(item interface{}) {
time.Sleep(time.Millisecond * 100)
atomic.AddInt32(&count, int32(item.(int)))
}, UnlimitedWorkers())
@@ -321,8 +322,8 @@ func TestParallelJust(t *testing.T) {
func TestReverse(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
Just(1, 2, 3, 4).Reverse().Merge().ForEach(func(item any) {
assert.ElementsMatch(t, []any{4, 3, 2, 1}, item.([]any))
Just(1, 2, 3, 4).Reverse().Merge().ForEach(func(item interface{}) {
assert.ElementsMatch(t, []interface{}{4, 3, 2, 1}, item.([]interface{}))
})
})
}
@@ -330,9 +331,9 @@ func TestReverse(t *testing.T) {
func TestSort(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var prev int
Just(5, 3, 7, 1, 9, 6, 4, 8, 2).Sort(func(a, b any) bool {
Just(5, 3, 7, 1, 9, 6, 4, 8, 2).Sort(func(a, b interface{}) bool {
return a.(int) < b.(int)
}).ForEach(func(item any) {
}).ForEach(func(item interface{}) {
next := item.(int)
assert.True(t, prev < next)
prev = next
@@ -345,12 +346,12 @@ func TestSplit(t *testing.T) {
assert.Panics(t, func() {
Just(1, 2, 3, 4, 5, 6, 7, 8, 9, 10).Split(0).Done()
})
var chunks [][]any
Just(1, 2, 3, 4, 5, 6, 7, 8, 9, 10).Split(4).ForEach(func(item any) {
chunk := item.([]any)
var chunks [][]interface{}
Just(1, 2, 3, 4, 5, 6, 7, 8, 9, 10).Split(4).ForEach(func(item interface{}) {
chunk := item.([]interface{})
chunks = append(chunks, chunk)
})
assert.EqualValues(t, [][]any{
assert.EqualValues(t, [][]interface{}{
{1, 2, 3, 4},
{5, 6, 7, 8},
{9, 10},
@@ -361,7 +362,7 @@ func TestSplit(t *testing.T) {
func TestTail(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var result int
Just(1, 2, 3, 4).Tail(2).Reduce(func(pipe <-chan any) (any, error) {
Just(1, 2, 3, 4).Tail(2).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
for item := range pipe {
result += item.(int)
}
@@ -374,7 +375,7 @@ func TestTail(t *testing.T) {
func TestTailZero(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
assert.Panics(t, func() {
Just(1, 2, 3, 4).Tail(0).Reduce(func(pipe <-chan any) (any, error) {
Just(1, 2, 3, 4).Tail(0).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
return nil, nil
})
})
@@ -384,11 +385,11 @@ func TestTailZero(t *testing.T) {
func TestWalk(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
var result int
Just(1, 2, 3, 4, 5).Walk(func(item any, pipe chan<- any) {
Just(1, 2, 3, 4, 5).Walk(func(item interface{}, pipe chan<- interface{}) {
if item.(int)%2 != 0 {
pipe <- item
}
}, UnlimitedWorkers()).ForEach(func(item any) {
}, UnlimitedWorkers()).ForEach(func(item interface{}) {
result += item.(int)
})
assert.Equal(t, 9, result)
@@ -397,16 +398,16 @@ func TestWalk(t *testing.T) {
func TestStream_AnyMach(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
assetEqual(t, false, Just(1, 2, 3).AnyMatch(func(item any) bool {
assetEqual(t, false, Just(1, 2, 3).AnyMach(func(item interface{}) bool {
return item.(int) == 4
}))
assetEqual(t, false, Just(1, 2, 3).AnyMatch(func(item any) bool {
assetEqual(t, false, Just(1, 2, 3).AnyMach(func(item interface{}) bool {
return item.(int) == 0
}))
assetEqual(t, true, Just(1, 2, 3).AnyMatch(func(item any) bool {
assetEqual(t, true, Just(1, 2, 3).AnyMach(func(item interface{}) bool {
return item.(int) == 2
}))
assetEqual(t, true, Just(1, 2, 3).AnyMatch(func(item any) bool {
assetEqual(t, true, Just(1, 2, 3).AnyMach(func(item interface{}) bool {
return item.(int) == 2
}))
})
@@ -415,17 +416,17 @@ func TestStream_AnyMach(t *testing.T) {
func TestStream_AllMach(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
assetEqual(
t, true, Just(1, 2, 3).AllMatch(func(item any) bool {
t, true, Just(1, 2, 3).AllMach(func(item interface{}) bool {
return true
}),
)
assetEqual(
t, false, Just(1, 2, 3).AllMatch(func(item any) bool {
t, false, Just(1, 2, 3).AllMach(func(item interface{}) bool {
return false
}),
)
assetEqual(
t, false, Just(1, 2, 3).AllMatch(func(item any) bool {
t, false, Just(1, 2, 3).AllMach(func(item interface{}) bool {
return item.(int) == 1
}),
)
@@ -435,17 +436,17 @@ func TestStream_AllMach(t *testing.T) {
func TestStream_NoneMatch(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
assetEqual(
t, true, Just(1, 2, 3).NoneMatch(func(item any) bool {
t, true, Just(1, 2, 3).NoneMatch(func(item interface{}) bool {
return false
}),
)
assetEqual(
t, false, Just(1, 2, 3).NoneMatch(func(item any) bool {
t, false, Just(1, 2, 3).NoneMatch(func(item interface{}) bool {
return true
}),
)
assetEqual(
t, true, Just(1, 2, 3).NoneMatch(func(item any) bool {
t, true, Just(1, 2, 3).NoneMatch(func(item interface{}) bool {
return item.(int) == 4
}),
)
@@ -454,19 +455,19 @@ func TestStream_NoneMatch(t *testing.T) {
func TestConcat(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
a1 := []any{1, 2, 3}
a2 := []any{4, 5, 6}
a1 := []interface{}{1, 2, 3}
a2 := []interface{}{4, 5, 6}
s1 := Just(a1...)
s2 := Just(a2...)
stream := Concat(s1, s2)
var items []any
var items []interface{}
for item := range stream.source {
items = append(items, item)
}
sort.Slice(items, func(i, j int) bool {
return items[i].(int) < items[j].(int)
})
ints := make([]any, 0)
ints := make([]interface{}, 0)
ints = append(ints, a1...)
ints = append(ints, a2...)
assetEqual(t, ints, items)
@@ -478,7 +479,7 @@ func TestStream_Skip(t *testing.T) {
assetEqual(t, 3, Just(1, 2, 3, 4).Skip(1).Count())
assetEqual(t, 1, Just(1, 2, 3, 4).Skip(3).Count())
assetEqual(t, 4, Just(1, 2, 3, 4).Skip(0).Count())
equal(t, Just(1, 2, 3, 4).Skip(3), []any{4})
equal(t, Just(1, 2, 3, 4).Skip(3), []interface{}{4})
assert.Panics(t, func() {
Just(1, 2, 3, 4).Skip(-1)
})
@@ -488,104 +489,27 @@ func TestStream_Skip(t *testing.T) {
func TestStream_Concat(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
stream := Just(1).Concat(Just(2), Just(3))
var items []any
var items []interface{}
for item := range stream.source {
items = append(items, item)
}
sort.Slice(items, func(i, j int) bool {
return items[i].(int) < items[j].(int)
})
assetEqual(t, []any{1, 2, 3}, items)
assetEqual(t, []interface{}{1, 2, 3}, items)
just := Just(1)
equal(t, just.Concat(just), []any{1})
})
}
func TestStream_Max(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
tests := []struct {
name string
elements []any
max any
}{
{
name: "no elements with nil",
},
{
name: "no elements",
elements: []any{},
max: nil,
},
{
name: "1 element",
elements: []any{1},
max: 1,
},
{
name: "multiple elements",
elements: []any{1, 2, 9, 5, 8},
max: 9,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
val := Just(test.elements...).Max(func(a, b any) bool {
return a.(int) < b.(int)
})
assetEqual(t, test.max, val)
})
}
})
}
func TestStream_Min(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
tests := []struct {
name string
elements []any
min any
}{
{
name: "no elements with nil",
min: nil,
},
{
name: "no elements",
elements: []any{},
min: nil,
},
{
name: "1 element",
elements: []any{1},
min: 1,
},
{
name: "multiple elements",
elements: []any{-1, 1, 2, 9, 5, 8},
min: -1,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
val := Just(test.elements...).Min(func(a, b any) bool {
return a.(int) < b.(int)
})
assetEqual(t, test.min, val)
})
}
equal(t, just.Concat(just), []interface{}{1})
})
}
func BenchmarkParallelMapReduce(b *testing.B) {
b.ReportAllocs()
mapper := func(v any) any {
mapper := func(v interface{}) interface{} {
return v.(int64) * v.(int64)
}
reducer := func(input <-chan any) (any, error) {
reducer := func(input <-chan interface{}) (interface{}, error) {
var result int64
for v := range input {
result += v.(int64)
@@ -593,7 +517,7 @@ func BenchmarkParallelMapReduce(b *testing.B) {
return result, nil
}
b.ResetTimer()
From(func(input chan<- any) {
From(func(input chan<- interface{}) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
input <- int64(rand.Int())
@@ -605,10 +529,10 @@ func BenchmarkParallelMapReduce(b *testing.B) {
func BenchmarkMapReduce(b *testing.B) {
b.ReportAllocs()
mapper := func(v any) any {
mapper := func(v interface{}) interface{} {
return v.(int64) * v.(int64)
}
reducer := func(input <-chan any) (any, error) {
reducer := func(input <-chan interface{}) (interface{}, error) {
var result int64
for v := range input {
result += v.(int64)
@@ -616,21 +540,21 @@ func BenchmarkMapReduce(b *testing.B) {
return result, nil
}
b.ResetTimer()
From(func(input chan<- any) {
From(func(input chan<- interface{}) {
for i := 0; i < b.N; i++ {
input <- int64(rand.Int())
}
}).Map(mapper).Reduce(reducer)
}
func assetEqual(t *testing.T, except, data any) {
func assetEqual(t *testing.T, except, data interface{}) {
if !reflect.DeepEqual(except, data) {
t.Errorf(" %v, want %v", data, except)
}
}
func equal(t *testing.T, stream Stream, data []any) {
items := make([]any, 0)
func equal(t *testing.T, stream Stream, data []interface{}) {
items := make([]interface{}, 0)
for item := range stream.source {
items = append(items, item)
}

Some files were not shown because too many files have changed in this diff Show More