mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-11 00:40:00 +08:00
Compare commits
414 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5b74b9ab7b | ||
|
|
4a67261b7b | ||
|
|
22bdae0787 | ||
|
|
e8675d6a9a | ||
|
|
e441c44975 | ||
|
|
3f91a79a2b | ||
|
|
8c47c01739 | ||
|
|
f59a1cb0de | ||
|
|
d44ff6ddc8 | ||
|
|
6ffa9cabec | ||
|
|
0069721586 | ||
|
|
ba9c275853 | ||
|
|
9a6447ab5c | ||
|
|
004995f06a | ||
|
|
c12c82b2f6 | ||
|
|
85d770d340 | ||
|
|
8cd7f7a2d8 | ||
|
|
db3101361b | ||
|
|
eb2302b71e | ||
|
|
04ed637366 | ||
|
|
567087a715 | ||
|
|
4d2e64a417 | ||
|
|
b01831b4c5 | ||
|
|
d1a014955c | ||
|
|
ec802e25a6 | ||
|
|
8a2e09dfd1 | ||
|
|
220d438fe7 | ||
|
|
2cd96146fa | ||
|
|
7e96317fad | ||
|
|
70728ce2e2 | ||
|
|
6a72a735d4 | ||
|
|
b139a82c2e | ||
|
|
bdddf1f30c | ||
|
|
9b74b7e09e | ||
|
|
4d5ed2c45d | ||
|
|
a2310bf9d7 | ||
|
|
be846eba01 | ||
|
|
b20f0e3d60 | ||
|
|
e2bb65d43c | ||
|
|
94e2f5bd12 | ||
|
|
173f76acf9 | ||
|
|
6e1af75635 | ||
|
|
84ff755e61 | ||
|
|
4b9d23aef5 | ||
|
|
97b9aebe99 | ||
|
|
8e7e5695eb | ||
|
|
4b4751e76c | ||
|
|
fcec494ea8 | ||
|
|
42117c2dcc | ||
|
|
4b631f3785 | ||
|
|
f29c8612e8 | ||
|
|
35ba024103 | ||
|
|
52df1c532a | ||
|
|
39729f3756 | ||
|
|
5c9ea81db2 | ||
|
|
b284664de4 | ||
|
|
1b76885040 | ||
|
|
eef217522b | ||
|
|
6bd0d169d5 | ||
|
|
3d291328d8 | ||
|
|
858f8ca82e | ||
|
|
4ff3975c5a | ||
|
|
7b23f73268 | ||
|
|
918a7be698 | ||
|
|
0a724447cd | ||
|
|
9e425893a7 | ||
|
|
4de13b6cc8 | ||
|
|
c6f75532fa | ||
|
|
fdf4ccf057 | ||
|
|
b333ed245b | ||
|
|
8f1576df36 | ||
|
|
72dd970969 | ||
|
|
29b65e12c1 | ||
|
|
577a611dc3 | ||
|
|
75941aedd4 | ||
|
|
c7065171d7 | ||
|
|
052de3b552 | ||
|
|
866613af8c | ||
|
|
3d4f6a5e16 | ||
|
|
d1d47d02d5 | ||
|
|
d6c876860b | ||
|
|
98423ca948 | ||
|
|
4e52d77ad8 | ||
|
|
1fc2cfb859 | ||
|
|
942cdae41d | ||
|
|
e9c3607bc6 | ||
|
|
d1603e9166 | ||
|
|
e30317e9c4 | ||
|
|
568f9ce007 | ||
|
|
dcb309065a | ||
|
|
bf8e17a686 | ||
|
|
b2ebbfce62 | ||
|
|
2b10a6a223 | ||
|
|
80c320b46e | ||
|
|
bea9d150a1 | ||
|
|
3f756a2cbf | ||
|
|
bbe5bbb0c0 | ||
|
|
5ad2278a69 | ||
|
|
77763fe748 | ||
|
|
538c4fb5c7 | ||
|
|
315fb2fe0a | ||
|
|
e382887eb8 | ||
|
|
cf21cb2b0b | ||
|
|
61e8894c31 | ||
|
|
7a6c3c8129 | ||
|
|
875fec3e1a | ||
|
|
60128c2100 | ||
|
|
ce6d0e3ea7 | ||
|
|
fa85c84af3 | ||
|
|
440884105e | ||
|
|
271f10598f | ||
|
|
cf55a88ce3 | ||
|
|
c1c786b14a | ||
|
|
988fb9d9bf | ||
|
|
d212c81bca | ||
|
|
bc43df2641 | ||
|
|
351b8cb37b | ||
|
|
0d681a2e29 | ||
|
|
5ea027c5de | ||
|
|
5de6112dcd | ||
|
|
4fb51723b7 | ||
|
|
06502d1115 | ||
|
|
3854d6dd00 | ||
|
|
895854913a | ||
|
|
ef753b8857 | ||
|
|
9c16fede73 | ||
|
|
ce11adb5e4 | ||
|
|
894e8b1218 | ||
|
|
2ec7e432dd | ||
|
|
870e8352c1 | ||
|
|
de42f27e03 | ||
|
|
955b8016aa | ||
|
|
d728a3b2d9 | ||
|
|
0c205a71fc | ||
|
|
a8c0199d96 | ||
|
|
032a266ec4 | ||
|
|
40b75fbb9b | ||
|
|
afad55045b | ||
|
|
5f54f06ee5 | ||
|
|
20f56ae1d0 | ||
|
|
73d6fcfccd | ||
|
|
20d20ef861 | ||
|
|
a37422b504 | ||
|
|
a81d898408 | ||
|
|
a5d42e20d5 | ||
|
|
4bdb07f225 | ||
|
|
3e6ec9b83d | ||
|
|
f0a3d213dc | ||
|
|
94562ded74 | ||
|
|
d68cf4920c | ||
|
|
31b749ab67 | ||
|
|
3834319278 | ||
|
|
1c9d339361 | ||
|
|
b7f601c912 | ||
|
|
1ebbc6f0c7 | ||
|
|
b41b1b00df | ||
|
|
f36e5fed35 | ||
|
|
2583673c8b | ||
|
|
00e67b9d20 | ||
|
|
9fd1f29845 | ||
|
|
130e1ba963 | ||
|
|
a2b98dbcf7 | ||
|
|
b46d507a1d | ||
|
|
3152581d0d | ||
|
|
46e466f037 | ||
|
|
151b3d1085 | ||
|
|
ea53fe41de | ||
|
|
d9df08b079 | ||
|
|
569c00ad09 | ||
|
|
9da76fbf04 | ||
|
|
b69db5e09d | ||
|
|
ee6b7cee79 | ||
|
|
d150248c52 | ||
|
|
610a7345dc | ||
|
|
b0b31f3993 | ||
|
|
82a937d517 | ||
|
|
93c11a7eb7 | ||
|
|
63ec989376 | ||
|
|
bf75027889 | ||
|
|
d505fae979 | ||
|
|
25f37ca750 | ||
|
|
0be63c3625 | ||
|
|
b011a072c7 | ||
|
|
3c9b6335fb | ||
|
|
bf6ef5f033 | ||
|
|
ff890628b0 | ||
|
|
cc79e3d842 | ||
|
|
f11b78ced9 | ||
|
|
1d2b0d7ab8 | ||
|
|
da987e1270 | ||
|
|
12e03c8843 | ||
|
|
8cf4f95bd7 | ||
|
|
ba0febf308 | ||
|
|
c9ff6a10d3 | ||
|
|
a71e56de52 | ||
|
|
bae8d4f4c8 | ||
|
|
8c6266f338 | ||
|
|
95d5b81f44 | ||
|
|
bca7bbc142 | ||
|
|
df9a52664b | ||
|
|
937cf0db96 | ||
|
|
75cebb65f8 | ||
|
|
410f56e73a | ||
|
|
017909a3ab | ||
|
|
0d31e6c375 | ||
|
|
0ba86b1849 | ||
|
|
4cacc4d9d3 | ||
|
|
a99c14da4a | ||
|
|
985582264a | ||
|
|
8364e341e1 | ||
|
|
0f2b589d4d | ||
|
|
19fec36d24 | ||
|
|
f037bf344d | ||
|
|
d99cf35b07 | ||
|
|
f459f1b5ff | ||
|
|
0140fd417b | ||
|
|
7969e0ca38 | ||
|
|
91c885b5b0 | ||
|
|
d4cccca387 | ||
|
|
4b2095ed03 | ||
|
|
1229eeb2d2 | ||
|
|
9142b146c5 | ||
|
|
8a1b2d5aed | ||
|
|
da5d39e6ca | ||
|
|
68c5a17c67 | ||
|
|
b53f9f5f2d | ||
|
|
36d57626b6 | ||
|
|
4e36ba832f | ||
|
|
a44954a771 | ||
|
|
f3edd4b880 | ||
|
|
2de3e397ff | ||
|
|
a435eb56f2 | ||
|
|
d80761c147 | ||
|
|
e7bd0d8b60 | ||
|
|
b109b3ef4c | ||
|
|
e3c371ac89 | ||
|
|
15eb6f4f6d | ||
|
|
4d3681b71c | ||
|
|
a682bda0bb | ||
|
|
45b27ad93a | ||
|
|
292a8302a1 | ||
|
|
91ab1f6d2b | ||
|
|
5048c350ae | ||
|
|
94edc32f3e | ||
|
|
ec989b2e2a | ||
|
|
82fe802e81 | ||
|
|
072d68f897 | ||
|
|
2e91ba5811 | ||
|
|
5564c43197 | ||
|
|
e55158b0f7 | ||
|
|
69aa7fe346 | ||
|
|
c3820a95c1 | ||
|
|
493f3bad0f | ||
|
|
eb0d5ad3a4 | ||
|
|
14192050ae | ||
|
|
9193e771e3 | ||
|
|
808b4e496a | ||
|
|
e416d01f8d | ||
|
|
789c5de873 | ||
|
|
52078a0c14 | ||
|
|
7ef13116a0 | ||
|
|
6b8053410a | ||
|
|
81c6928445 | ||
|
|
761c2dd716 | ||
|
|
aeceb3cfbe | ||
|
|
15ea07aad1 | ||
|
|
98bebbc74f | ||
|
|
eafd11d949 | ||
|
|
b251ce346e | ||
|
|
812140ba36 | ||
|
|
44735e949c | ||
|
|
bf313c3c56 | ||
|
|
94e7753262 | ||
|
|
9c478626d2 | ||
|
|
801c283478 | ||
|
|
2a54faf997 | ||
|
|
ecd98f3653 | ||
|
|
61641581eb | ||
|
|
6f2730d5ae | ||
|
|
0eff777b62 | ||
|
|
cafbf535f7 | ||
|
|
6edfce63e3 | ||
|
|
cdb0098b18 | ||
|
|
620c7f9693 | ||
|
|
dba444a382 | ||
|
|
b24fb3ebf7 | ||
|
|
967f0926eb | ||
|
|
e68c683df9 | ||
|
|
247985a065 | ||
|
|
80573af0d8 | ||
|
|
c0394b631a | ||
|
|
68d1aba377 | ||
|
|
3315e60272 | ||
|
|
327ef73700 | ||
|
|
eb11521655 | ||
|
|
4c37545e55 | ||
|
|
2f47c1fba4 | ||
|
|
16d54d0ace | ||
|
|
9925bcbf99 | ||
|
|
38a5ecb796 | ||
|
|
af78fc7c5f | ||
|
|
790302b486 | ||
|
|
6a0672b801 | ||
|
|
560c61612c | ||
|
|
6a988dc4a9 | ||
|
|
15842c3c7a | ||
|
|
f2914a74df | ||
|
|
f113d512e8 | ||
|
|
7a4818da59 | ||
|
|
48d0709ca6 | ||
|
|
f747585518 | ||
|
|
507ff96546 | ||
|
|
651eabb4c6 | ||
|
|
e6b4372056 | ||
|
|
24073969a1 | ||
|
|
ca797ed22c | ||
|
|
e347d3f8f8 | ||
|
|
396393b336 | ||
|
|
1f0531b254 | ||
|
|
77fb271a06 | ||
|
|
af7cf79963 | ||
|
|
7926d396d7 | ||
|
|
080cd3df84 | ||
|
|
c4e1a6a2d8 | ||
|
|
4e71e95e44 | ||
|
|
84db9bcd15 | ||
|
|
b28f79ac11 | ||
|
|
e134e77b2b | ||
|
|
f669d84ce8 | ||
|
|
9213b8ac27 | ||
|
|
ae09d0e56d | ||
|
|
0bc4206d08 | ||
|
|
39ce17bfd2 | ||
|
|
d415ba39e2 | ||
|
|
c71829c8de | ||
|
|
a32f6d7642 | ||
|
|
64e8c94198 | ||
|
|
7d05a4bc93 | ||
|
|
44504e8df7 | ||
|
|
114311e51b | ||
|
|
4307ce45fc | ||
|
|
37b54d1fc7 | ||
|
|
00e0db5def | ||
|
|
cbcacf31c1 | ||
|
|
238c92aaa9 | ||
|
|
520d2a2075 | ||
|
|
1023800b02 | ||
|
|
030c859171 | ||
|
|
e6d1b47a43 | ||
|
|
6138f85470 | ||
|
|
bf883101d7 | ||
|
|
33011c7ed1 | ||
|
|
17d98f69e0 | ||
|
|
b650c8c425 | ||
|
|
3d931d7030 | ||
|
|
68da9ed51a | ||
|
|
b25c45b352 | ||
|
|
f05234a967 | ||
|
|
12071d17b4 | ||
|
|
11c47d23df | ||
|
|
024f285f86 | ||
|
|
fa4674611a | ||
|
|
730c3c5246 | ||
|
|
2c9310ac3a | ||
|
|
74ba0bcd50 | ||
|
|
5f4190b6c6 | ||
|
|
e1787b4ccb | ||
|
|
4ac8b492ef | ||
|
|
cdd068575c | ||
|
|
e89e2d8a75 | ||
|
|
acd2b94bd9 | ||
|
|
6a0c8047f4 | ||
|
|
cfe03ea9e1 | ||
|
|
48d21ef8ad | ||
|
|
28a001c5f9 | ||
|
|
22a41cacc7 | ||
|
|
fcc246933c | ||
|
|
5c3679ffe7 | ||
|
|
eaa01ccb9f | ||
|
|
b8206fb46a | ||
|
|
1c3876810e | ||
|
|
1d9159ea39 | ||
|
|
2159d112c3 | ||
|
|
f57874a51f | ||
|
|
8625864d43 | ||
|
|
8f9ba3ec11 | ||
|
|
a1d9bc08f0 | ||
|
|
b9d7f1cc77 | ||
|
|
6700910f64 | ||
|
|
9c4ed394a7 | ||
|
|
fd07a9c6e4 | ||
|
|
b8c239630c | ||
|
|
672ea55736 | ||
|
|
f7097866bf | ||
|
|
796b2bd1b0 | ||
|
|
e1e5fb2071 | ||
|
|
89ecb50005 | ||
|
|
dbed1ea042 | ||
|
|
ad291daf78 | ||
|
|
13746a3706 | ||
|
|
f03b13f632 | ||
|
|
f6f64b1286 | ||
|
|
300a415f5d | ||
|
|
c5de546f8a | ||
|
|
cad243905f | ||
|
|
7c8f41d577 | ||
|
|
cbd118d55f | ||
|
|
9d2a1b8b0a | ||
|
|
f6ada979aa | ||
|
|
53a74759a5 | ||
|
|
1940f7bd58 | ||
|
|
18cb3141ba | ||
|
|
f822c9a94f | ||
|
|
1a3dc75874 |
13
.codecov.yml
13
.codecov.yml
@@ -1,13 +0,0 @@
|
||||
coverage:
|
||||
status:
|
||||
patch: true
|
||||
project: false # disabled because project coverage is not stable
|
||||
comment:
|
||||
layout: "flags, files"
|
||||
behavior: once
|
||||
require_changes: true
|
||||
ignore:
|
||||
- "tools"
|
||||
- "**/mock"
|
||||
- "**/*_mock.go"
|
||||
- "**/*test"
|
||||
344
.github/copilot-instructions.md
vendored
Normal file
344
.github/copilot-instructions.md
vendored
Normal file
@@ -0,0 +1,344 @@
|
||||
# GitHub Copilot Instructions for go-zero
|
||||
|
||||
This document provides guidelines for GitHub Copilot when assisting with development in the go-zero project.
|
||||
|
||||
## Project Overview
|
||||
|
||||
go-zero is a web and RPC framework with lots of built-in engineering practices designed to ensure the stability of busy services with resilience design. It has been serving sites with tens of millions of users for years.
|
||||
|
||||
### Key Architecture Components
|
||||
|
||||
- **REST API framework** (`rest/`) - HTTP service framework with middleware chain support
|
||||
- **RPC framework** (`zrpc/`) - gRPC-based RPC framework with etcd service discovery and p2c_ewma load balancing
|
||||
- **Gateway** (`gateway/`) - API gateway supporting both HTTP and gRPC upstreams with proto-based routing
|
||||
- **MCP Server** (`mcp/`) - Model Context Protocol server for AI agent integration via SSE
|
||||
- **Core utilities** (`core/`) - Production-grade components:
|
||||
- Resilience: circuit breakers (`breaker/`), rate limiters (`limit/`), adaptive load shedding (`load/`)
|
||||
- Storage: SQL with cache (`stores/sqlc/`), Redis (`stores/redis/`), MongoDB (`stores/mongo/`)
|
||||
- Concurrency: MapReduce (`mr/`), worker pools (`executors/`), sync primitives (`syncx/`)
|
||||
- Observability: metrics (`metric/`), tracing (`trace/`), structured logging (`logx/`)
|
||||
- **Code generation tool** (`tools/goctl/`) - CLI tool for generating Go code from `.api` and `.proto` files
|
||||
|
||||
## Coding Standards and Conventions
|
||||
|
||||
### Code Style
|
||||
|
||||
1. **Follow Go conventions**: Use `gofmt` for formatting, follow effective Go practices
|
||||
2. **Package naming**: Use lowercase, single-word package names when possible
|
||||
3. **Error handling**: Always handle errors explicitly, use `errorx.BatchError` for multiple errors
|
||||
4. **Context propagation**: Always pass `context.Context` as the first parameter for functions that may block
|
||||
5. **Configuration structures**: Use struct tags with JSON annotations, defaults, and validation
|
||||
|
||||
**Pattern**: All service configs embed `service.ServiceConf` for common fields (Name, Log, Mode, Telemetry)
|
||||
```go
|
||||
type Config struct {
|
||||
service.ServiceConf // Always embed for services
|
||||
Host string `json:",default=0.0.0.0"`
|
||||
Port int // Required field (no default)
|
||||
Timeout int64 `json:",default=3000"` // Timeouts in milliseconds
|
||||
Optional string `json:",optional"` // Optional field
|
||||
Mode string `json:",default=pro,options=dev|test|rt|pre|pro"` // Validated options
|
||||
}
|
||||
```
|
||||
|
||||
**Service modes**: `dev`/`test`/`rt` disable load shedding and stats; `pre`/`pro` enable all resilience features
|
||||
|
||||
### Interface Design
|
||||
|
||||
1. **Small interfaces**: Follow Go's preference for small, focused interfaces
|
||||
2. **Context methods**: Provide both context and non-context versions of methods
|
||||
3. **Options pattern**: Use functional options for complex configuration
|
||||
|
||||
Example:
|
||||
```go
|
||||
func (c *Client) Get(key string, val any) error {
|
||||
return c.GetCtx(context.Background(), key, val)
|
||||
}
|
||||
|
||||
func (c *Client) GetCtx(ctx context.Context, key string, val any) error {
|
||||
// implementation
|
||||
}
|
||||
```
|
||||
|
||||
### Testing Patterns
|
||||
|
||||
1. **Test file naming**: Use `*_test.go` suffix
|
||||
2. **Test function naming**: Use `TestFunctionName` pattern
|
||||
3. **Use testify/assert**: Prefer `assert` package for assertions
|
||||
4. **Table-driven tests**: Use table-driven tests for multiple scenarios
|
||||
5. **Mock interfaces**: Use `go.uber.org/mock` for mocking
|
||||
6. **Test helpers**: Use `redistest`, `mongtest` helpers for database testing
|
||||
|
||||
Example test pattern:
|
||||
```go
|
||||
func TestSomething(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
wantErr bool
|
||||
}{
|
||||
{"valid case", "input", "output", false},
|
||||
{"error case", "bad", "", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := SomeFunction(tt.input)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Framework-Specific Guidelines
|
||||
|
||||
### REST API Development
|
||||
|
||||
1. **API Definition**: Use `.api` files to define REST APIs with goctl codegen
|
||||
2. **Handler pattern**: Separate business logic into logic packages (handlers call logic layer)
|
||||
3. **Middleware chain**: Middlewares wrap via `chain.Chain` interface - use `Append()` or `Prepend()` to control order
|
||||
- Built-in middlewares (all in `rest/handler/`): tracing, logging, metrics, recovery, breaker, shedding, timeout, maxconns, maxbytes, gunzip
|
||||
- Custom middleware: `func(http.Handler) http.Handler` - call `next.ServeHTTP(w, r)` to continue chain
|
||||
4. **Response handling**: Use `httpx.WriteJson(w, code, v)` for JSON responses
|
||||
5. **Error handling**: Use `httpx.Error(w, err)` or `httpx.ErrorCtx(ctx, w, err)` for HTTP error responses
|
||||
6. **Route registration**: Routes defined with `Method`, `Path`, and `Handler` - wildcards use `:param` syntax
|
||||
|
||||
### RPC Development
|
||||
|
||||
1. **Protocol Buffers**: Use protobuf for service definitions, generate code with goctl
|
||||
2. **Service discovery**: Use etcd for dynamic service registration/discovery, or direct endpoints for static routing
|
||||
3. **Load balancing**: Default is `p2c_ewma` (power of 2 choices with EWMA), configurable via `BalancerName`
|
||||
4. **Client configuration**: Support `Etcd`, `Endpoints`, or `Target` - use `BuildTarget()` to construct connection string
|
||||
5. **Interceptors**: Implement gRPC interceptors for cross-cutting concerns (auth, logging, metrics)
|
||||
6. **Health checks**: gRPC health checks enabled by default (`Health: true`)
|
||||
|
||||
### Database Operations
|
||||
|
||||
1. **SQL operations**: Use `sqlx.SqlConn` interface - methods always end with `Ctx` for context support
|
||||
2. **Caching pattern**: `stores/sqlc` provides `CachedConn` for automatic cache-aside pattern
|
||||
- `QueryRowCtx`: Query with cache key, auto-populate on cache miss
|
||||
- `ExecCtx`: Execute and delete cache keys
|
||||
3. **Transactions**: Use `sqlx.SqlConn.TransactCtx()` to get transaction session
|
||||
4. **Connection pooling**: Managed automatically (64 max idle/open, 1min lifetime)
|
||||
5. **Test helpers**: Use `redistest.CreateRedis(t)` for Redis, SQL mocks for DB testing
|
||||
|
||||
Example cache pattern:
|
||||
```go
|
||||
err := c.QueryRowCtx(ctx, &dest, key, func(ctx context.Context, conn sqlx.SqlConn) error {
|
||||
return conn.QueryRowCtx(ctx, &dest, query, args...)
|
||||
})
|
||||
```
|
||||
|
||||
### Configuration Management
|
||||
|
||||
1. **YAML configuration**: Use YAML for configuration files
|
||||
2. **Environment variables**: Support environment variable overrides
|
||||
3. **Validation**: Include proper validation for configuration parameters
|
||||
4. **Sensible defaults**: Provide reasonable default values
|
||||
|
||||
## Error Handling Best Practices
|
||||
|
||||
1. **Wrap errors**: Use `fmt.Errorf` with `%w` verb to wrap errors
|
||||
2. **Custom errors**: Define custom error types when needed
|
||||
3. **Error logging**: Log errors appropriately with context
|
||||
4. **Graceful degradation**: Implement fallback mechanisms
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
1. **Resource pools**: Use connection pools and worker pools
|
||||
2. **Circuit breakers**: Implement circuit breaker patterns for external calls
|
||||
3. **Rate limiting**: Apply rate limiting to protect services
|
||||
4. **Load shedding**: Implement adaptive load shedding
|
||||
5. **Metrics**: Add appropriate metrics and monitoring
|
||||
|
||||
## Security Guidelines
|
||||
|
||||
1. **Input validation**: Validate all input parameters
|
||||
2. **SQL injection prevention**: Use parameterized queries
|
||||
3. **Authentication**: Implement proper JWT token handling
|
||||
4. **HTTPS**: Support TLS/HTTPS configurations
|
||||
5. **CORS**: Configure CORS appropriately for web APIs
|
||||
|
||||
## Documentation Standards
|
||||
|
||||
1. **Package documentation**: Include package-level documentation
|
||||
2. **Function documentation**: Document exported functions with examples
|
||||
3. **API documentation**: Maintain API documentation in sync
|
||||
4. **README updates**: Update README for significant changes
|
||||
|
||||
## GitHub Issue Management
|
||||
|
||||
### Understanding and Categorizing Issues
|
||||
|
||||
When analyzing GitHub issues, consider these common categories:
|
||||
|
||||
1. **Bug Reports**: Stack traces, version info, reproduction steps
|
||||
2. **Feature Requests**: Use case, proposed solution, alternatives
|
||||
3. **Questions**: Usage, configuration, or architecture
|
||||
4. **Documentation Issues**: Missing, unclear, or incorrect docs
|
||||
5. **Performance Issues**: Benchmarks, profiling data, resource usage
|
||||
|
||||
### Issue Analysis Checklist
|
||||
|
||||
- Identify affected component (REST, RPC, Gateway, MCP, Core utilities, goctl)
|
||||
- Check versions (go-zero, Go)
|
||||
- Look for reproduction steps or code examples
|
||||
- Review code snippets, logs, or stack traces
|
||||
- Check if related to resilience features (breaker, load shedding, rate limiting)
|
||||
- Determine production impact
|
||||
|
||||
### Responding to Issues
|
||||
|
||||
Be helpful and professional. Ask clarifying questions when needed. Reference relevant documentation and code files. Provide code examples following project conventions. Suggest workarounds when applicable.
|
||||
|
||||
### Chinese to English Translation
|
||||
|
||||
go-zero has an international user base. When encountering issues or comments written in Chinese, translate them to English to ensure all contributors can participate in discussions.
|
||||
|
||||
#### Translation Guidelines
|
||||
|
||||
1. **Update issue titles**: Edit the issue title to include English translation only
|
||||
2. **Translate comments in place**: Add a comment with the English translation, followed by the original Chinese text
|
||||
3. **Keep original Chinese**: After translating, include the original Chinese text in a blockquote for verification
|
||||
4. **Encourage English communication**: Politely suggest users write in English for better collaboration
|
||||
5. **Maintain technical accuracy**: Preserve technical terms, component names, and code exactly
|
||||
6. **Translate naturally**: Avoid literal word-by-word translation; use idiomatic English
|
||||
7. **Preserve formatting**: Keep markdown formatting, code blocks, and links intact
|
||||
8. **Keep URLs unchanged**: Don't translate URLs or file paths
|
||||
|
||||
#### Common Technical Terms (Chinese → English)
|
||||
|
||||
- 框架 → **Framework** | 中间件 → **Middleware** | 负载均衡 → **Load Balancing**
|
||||
- 熔断器 → **Circuit Breaker** | 限流 → **Rate Limiting** | 降载/过载保护 → **Load Shedding**
|
||||
- 服务发现 → **Service Discovery** | 配置 → **Configuration** | 弹性/容错 → **Resilience** | 微服务 → **Microservices**
|
||||
|
||||
#### Translation Example
|
||||
|
||||
**Original Chinese Title:** `goctl 执行环境问题`
|
||||
**Updated Title:** `goctl Execution Environment Issue`
|
||||
|
||||
**Original Chinese Comment:** `我在项目中遇到熔断器配置问题`
|
||||
**Translation in Comment:**
|
||||
```markdown
|
||||
I encountered a circuit breaker configuration issue in my project.
|
||||
|
||||
> Original (原文): 我在项目中遇到熔断器配置问题
|
||||
```
|
||||
|
||||
### Common Issue Patterns and Solutions
|
||||
|
||||
#### Configuration Issues
|
||||
- Check `service.ServiceConf` embedding and struct tags
|
||||
- Verify YAML syntax, defaults, and validation rules
|
||||
- Reference: [rest/config.go](rest/config.go), [zrpc/config.go](zrpc/config.go)
|
||||
|
||||
#### Code Generation (goctl) Issues
|
||||
- Verify `.api` or `.proto` file syntax and goctl version
|
||||
- Reference: `tools/goctl/` directory
|
||||
|
||||
#### RPC Connection Issues
|
||||
- Check etcd configuration, service discovery, and endpoints
|
||||
- Verify load balancing settings (p2c_ewma)
|
||||
|
||||
#### Database/Cache Issues
|
||||
- Verify `sqlx.SqlConn` usage with context
|
||||
- Check cache key generation, invalidation, and connection pools
|
||||
- Use test helpers (`redistest`, `mongtest`)
|
||||
|
||||
#### Performance Issues
|
||||
- Check if load shedding is enabled (mode: `pre`/`pro`)
|
||||
- Review circuit breaker thresholds, rate limiting, and context timeouts
|
||||
|
||||
### Referencing Codebase
|
||||
|
||||
When explaining issues, reference specific files and patterns:
|
||||
- REST API: `rest/`, `rest/handler/`, `rest/httpx/`
|
||||
- RPC: `zrpc/`, `zrpc/internal/`
|
||||
- Core utilities: `core/breaker/`, `core/limit/`, `core/load/`, etc.
|
||||
- Gateway: `gateway/`
|
||||
- MCP: `mcp/`
|
||||
- Code generation: `tools/goctl/`
|
||||
- Examples: `adhoc/` directory contains various examples
|
||||
|
||||
### Encouraging Best Practices
|
||||
|
||||
When responding to issues, gently guide users toward:
|
||||
- Proper error handling with context
|
||||
- Using resilience features (breakers, rate limiters)
|
||||
- Following testing patterns with table-driven tests
|
||||
- Implementing proper resource cleanup
|
||||
- Reading existing documentation in `docs/` and `readme.md`
|
||||
|
||||
## Common Patterns to Follow
|
||||
|
||||
### Service Configuration
|
||||
```go
|
||||
type ServiceConf struct {
|
||||
Name string
|
||||
Log logx.LogConf
|
||||
Mode string `json:",default=pro,options=[dev,test,pre,pro]"`
|
||||
// ... other common fields
|
||||
}
|
||||
```
|
||||
|
||||
### Middleware Implementation
|
||||
```go
|
||||
func SomeMiddleware() rest.Middleware {
|
||||
return func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Pre-processing
|
||||
next.ServeHTTP(w, r)
|
||||
// Post-processing
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Resource Management
|
||||
Always implement proper resource cleanup using defer and context cancellation.
|
||||
|
||||
## Build and Test Commands
|
||||
|
||||
- Build: `go build ./...`
|
||||
- Test: `go test ./...`
|
||||
- Test with race detection: `go test -race ./...`
|
||||
- Format: `gofmt -w .`
|
||||
- Code generation:
|
||||
- REST API: `goctl api go -api *.api -dir .`
|
||||
- RPC: `goctl rpc protoc *.proto --go_out=. --go-grpc_out=. --zrpc_out=.`
|
||||
- Model from SQL: `goctl model mysql datasource -url="user:pass@tcp(host:port)/db" -table="*" -dir="./model"`
|
||||
|
||||
## Critical Architecture Patterns
|
||||
|
||||
### Resilience Design Philosophy
|
||||
go-zero implements defense-in-depth with multiple protection layers:
|
||||
1. **Circuit Breaker** (`core/breaker`): Google SRE breaker - tracks success/failure, opens on error threshold
|
||||
2. **Adaptive Load Shedding** (`core/load`): CPU-based auto-rejection when system overloaded (disabled in dev/test/rt modes)
|
||||
3. **Rate Limiting** (`core/limit`): Token bucket (Redis-based) and period limiters
|
||||
4. **Timeout Control**: Cascading timeouts via context - set at multiple levels (client, server, handler)
|
||||
|
||||
### Middleware Chain Architecture
|
||||
`rest/chain` provides middleware composition:
|
||||
```go
|
||||
// Middleware signature
|
||||
type Middleware func(http.Handler) http.Handler
|
||||
|
||||
// Chain operations
|
||||
chain := chain.New(m1, m2)
|
||||
chain.Append(m3) // Adds to end: m1 -> m2 -> m3
|
||||
chain.Prepend(m0) // Adds to start: m0 -> m1 -> m2 -> m3
|
||||
handler := chain.Then(finalHandler)
|
||||
```
|
||||
|
||||
### Concurrency Patterns
|
||||
- **MapReduce** (`core/mr`): Parallel processing with worker pools - use for batch operations
|
||||
- **Executors** (`core/executors`): Bulk/period executors for batching operations
|
||||
- **SingleFlight** (`core/syncx`): Deduplicates concurrent identical requests
|
||||
|
||||
Remember to run tests and ensure all checks pass before submitting changes. The project emphasizes high quality, performance, and reliability, so these should be primary considerations in all development work.
|
||||
8
.github/workflows/codeql-analysis.yml
vendored
8
.github/workflows/codeql-analysis.yml
vendored
@@ -35,11 +35,11 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
# Initializes the CodeQL tools for scanning.
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v3
|
||||
uses: github/codeql-action/init@v4
|
||||
with:
|
||||
languages: ${{ matrix.language }}
|
||||
# If you wish to specify custom queries, you can do so here or in a config file.
|
||||
@@ -50,7 +50,7 @@ jobs:
|
||||
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
|
||||
# If this step fails, then you should remove it and run the build manually (see below)
|
||||
- name: Autobuild
|
||||
uses: github/codeql-action/autobuild@v3
|
||||
uses: github/codeql-action/autobuild@v4
|
||||
|
||||
# ℹ️ Command-line programs to run using the OS shell.
|
||||
# 📚 https://git.io/JvXDl
|
||||
@@ -64,4 +64,4 @@ jobs:
|
||||
# make release
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v3
|
||||
uses: github/codeql-action/analyze@v4
|
||||
|
||||
15
.github/workflows/go.yml
vendored
15
.github/workflows/go.yml
vendored
@@ -12,10 +12,10 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
check-latest: true
|
||||
@@ -40,17 +40,22 @@ jobs:
|
||||
run: go test -race -coverprofile=coverage.txt -covermode=atomic ./...
|
||||
|
||||
- name: Codecov
|
||||
uses: codecov/codecov-action@v4
|
||||
uses: codecov/codecov-action@v6
|
||||
with:
|
||||
files: ./coverage.txt
|
||||
flags: unittests
|
||||
name: codecov-umbrella
|
||||
fail_ci_if_error: false
|
||||
|
||||
test-win:
|
||||
name: Windows
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout codebase
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
# make sure Go version compatible with go-zero
|
||||
go-version-file: go.mod
|
||||
|
||||
18
.github/workflows/issue-translator.yml
vendored
18
.github/workflows/issue-translator.yml
vendored
@@ -1,18 +0,0 @@
|
||||
name: 'issue-translator'
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
issues:
|
||||
types: [opened]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: usthe/issues-translate-action@v2.7
|
||||
with:
|
||||
IS_MODIFY_TITLE: true
|
||||
# not require, default false, . Decide whether to modify the issue title
|
||||
# if true, the robot account @Issues-translate-bot must have modification permissions, invite @Issues-translate-bot to your project or use your custom bot.
|
||||
CUSTOM_BOT_NOTE: Bot detected the issue body's language is not English, translate it automatically. 👯👭🏻🧑🤝🧑👫🧑🏿🤝🧑🏻👩🏾🤝👨🏿👬🏿
|
||||
# not require. Customize the translation robot prefix message.
|
||||
2
.github/workflows/issues.yml
vendored
2
.github/workflows/issues.yml
vendored
@@ -7,7 +7,7 @@ jobs:
|
||||
close-issues:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@v9
|
||||
- uses: actions/stale@v10
|
||||
with:
|
||||
days-before-issue-stale: 365
|
||||
days-before-issue-close: 90
|
||||
|
||||
6
.github/workflows/release.yaml
vendored
6
.github/workflows/release.yaml
vendored
@@ -16,13 +16,13 @@ jobs:
|
||||
- goarch: "386"
|
||||
goos: darwin
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
- uses: zeromicro/go-zero-release-action@master
|
||||
with:
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
goos: ${{ matrix.goos }}
|
||||
goarch: ${{ matrix.goarch }}
|
||||
goversion: "https://dl.google.com/go/go1.20.14.linux-amd64.tar.gz"
|
||||
goversion: "https://dl.google.com/go/go1.21.13.linux-amd64.tar.gz"
|
||||
project_path: "tools/goctl"
|
||||
binary_name: "goctl"
|
||||
extra_files: tools/goctl/readme.md tools/goctl/readme-cn.md
|
||||
extra_files: tools/goctl/readme.md tools/goctl/readme-cn.md
|
||||
|
||||
9
.github/workflows/reviewdog.yml
vendored
9
.github/workflows/reviewdog.yml
vendored
@@ -5,7 +5,12 @@ jobs:
|
||||
name: runner / staticcheck
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
check-latest: true
|
||||
cache: true
|
||||
- uses: reviewdog/action-staticcheck@v1
|
||||
with:
|
||||
github_token: ${{ secrets.github_token }}
|
||||
@@ -14,6 +19,6 @@ jobs:
|
||||
# Report all results.
|
||||
filter_mode: nofilter
|
||||
# Exit with 1 when it find at least one finding.
|
||||
fail_on_error: true
|
||||
fail_level: any
|
||||
# Set staticcheck flags
|
||||
staticcheck_flags: -checks=inherit,-SA1019,-SA1029,-SA5008
|
||||
|
||||
42
.github/workflows/version-check.yml
vendored
Normal file
42
.github/workflows/version-check.yml
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
name: Release Version Check
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'tools/goctl/v*'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
version-check:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.21'
|
||||
|
||||
- name: Extract tag version
|
||||
id: get_version
|
||||
run: |
|
||||
# Extract version from tools/goctl/v* format
|
||||
VERSION="${GITHUB_REF#refs/tags/tools/goctl/v}"
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
echo "Extracted version: $VERSION"
|
||||
|
||||
- name: Check version in goctl source code
|
||||
run: |
|
||||
# Change to goctl directory
|
||||
cd tools/goctl
|
||||
|
||||
# Check version in BuildVersion constant
|
||||
VERSION_IN_CODE=$(grep -r "const BuildVersion =" . | grep -o '".*"' | tr -d '"')
|
||||
echo "Version in code: $VERSION_IN_CODE"
|
||||
echo "Expected version: $VERSION"
|
||||
|
||||
if [ "$VERSION_IN_CODE" != "$VERSION" ]; then
|
||||
echo "Version mismatch: Version in code ($VERSION_IN_CODE) doesn't match tag version ($VERSION)"
|
||||
exit 1
|
||||
fi
|
||||
echo "✅ Version check passed!"
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -17,6 +17,7 @@
|
||||
**/logs
|
||||
**/adhoc
|
||||
**/coverage.txt
|
||||
**/WARP.md
|
||||
|
||||
# for test purpose
|
||||
go.work
|
||||
|
||||
@@ -40,7 +40,7 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
// New create a Filter, store is the backed redis, key is the key for the bloom filter,
|
||||
// New creates a Filter, store is the backed redis, key is the key for the bloom filter,
|
||||
// bits is how many bits will be used, maps is how many hashes for each addition.
|
||||
// best practices:
|
||||
// elements - means how many actual elements
|
||||
|
||||
@@ -8,16 +8,12 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/mathx"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/stat"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
const (
|
||||
numHistoryReasons = 5
|
||||
timeFormat = "15:04:05"
|
||||
)
|
||||
const numHistoryReasons = 5
|
||||
|
||||
// ErrServiceUnavailable is returned when the Breaker state is open.
|
||||
var ErrServiceUnavailable = errors.New("circuit breaker is open")
|
||||
@@ -262,9 +258,9 @@ type errorWindow struct {
|
||||
|
||||
func (ew *errorWindow) add(reason string) {
|
||||
ew.lock.Lock()
|
||||
ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(timeFormat), reason)
|
||||
ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(time.TimeOnly), reason)
|
||||
ew.index = (ew.index + 1) % numHistoryReasons
|
||||
ew.count = mathx.MinInt(ew.count+1, numHistoryReasons)
|
||||
ew.count = min(ew.count+1, numHistoryReasons)
|
||||
ew.lock.Unlock()
|
||||
}
|
||||
|
||||
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
func TestNopBreaker(t *testing.T) {
|
||||
b := NopBreaker()
|
||||
assert.Equal(t, nopBreakerName, b.Name())
|
||||
p, err := b.Allow()
|
||||
_, err := b.Allow()
|
||||
assert.Nil(t, err)
|
||||
p, err = b.AllowCtx(context.Background())
|
||||
p, err := b.AllowCtx(context.Background())
|
||||
assert.Nil(t, err)
|
||||
p.Accept()
|
||||
for i := 0; i < 1000; i++ {
|
||||
|
||||
@@ -6,8 +6,6 @@ import (
|
||||
"crypto/cipher"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
|
||||
// ErrPaddingSize indicates bad padding size.
|
||||
@@ -27,7 +25,8 @@ func newECB(b cipher.Block) *ecb {
|
||||
|
||||
type ecbEncrypter ecb
|
||||
|
||||
// NewECBEncrypter returns an ECB encrypter.
|
||||
// Deprecated: NewECBEncrypter returns an ECB encrypter.
|
||||
// ECB mode is insecure for multi-block data. Use AES-GCM instead.
|
||||
func NewECBEncrypter(b cipher.Block) cipher.BlockMode {
|
||||
return (*ecbEncrypter)(newECB(b))
|
||||
}
|
||||
@@ -39,12 +38,10 @@ func (x *ecbEncrypter) BlockSize() int { return x.blockSize }
|
||||
// the block size. Dst and src must overlap entirely or not at all.
|
||||
func (x *ecbEncrypter) CryptBlocks(dst, src []byte) {
|
||||
if len(src)%x.blockSize != 0 {
|
||||
logx.Error("crypto/cipher: input not full blocks")
|
||||
return
|
||||
panic("crypto/cipher: input not full blocks")
|
||||
}
|
||||
if len(dst) < len(src) {
|
||||
logx.Error("crypto/cipher: output smaller than input")
|
||||
return
|
||||
panic("crypto/cipher: output smaller than input")
|
||||
}
|
||||
|
||||
for len(src) > 0 {
|
||||
@@ -56,7 +53,8 @@ func (x *ecbEncrypter) CryptBlocks(dst, src []byte) {
|
||||
|
||||
type ecbDecrypter ecb
|
||||
|
||||
// NewECBDecrypter returns an ECB decrypter.
|
||||
// Deprecated: NewECBDecrypter returns an ECB decrypter.
|
||||
// ECB mode is insecure for multi-block data. Use AES-GCM instead.
|
||||
func NewECBDecrypter(b cipher.Block) cipher.BlockMode {
|
||||
return (*ecbDecrypter)(newECB(b))
|
||||
}
|
||||
@@ -70,12 +68,10 @@ func (x *ecbDecrypter) BlockSize() int {
|
||||
// the block size. Dst and src must overlap entirely or not at all.
|
||||
func (x *ecbDecrypter) CryptBlocks(dst, src []byte) {
|
||||
if len(src)%x.blockSize != 0 {
|
||||
logx.Error("crypto/cipher: input not full blocks")
|
||||
return
|
||||
panic("crypto/cipher: input not full blocks")
|
||||
}
|
||||
if len(dst) < len(src) {
|
||||
logx.Error("crypto/cipher: output smaller than input")
|
||||
return
|
||||
panic("crypto/cipher: output smaller than input")
|
||||
}
|
||||
|
||||
for len(src) > 0 {
|
||||
@@ -85,14 +81,18 @@ func (x *ecbDecrypter) CryptBlocks(dst, src []byte) {
|
||||
}
|
||||
}
|
||||
|
||||
// EcbDecrypt decrypts src with the given key.
|
||||
// Deprecated: EcbDecrypt decrypts src with the given key.
|
||||
// ECB mode is insecure for multi-block data. Use AES-GCM instead.
|
||||
func EcbDecrypt(key, src []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
logx.Errorf("Decrypt key error: % x", key)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(src)%block.BlockSize() != 0 {
|
||||
return nil, ErrPaddingSize
|
||||
}
|
||||
|
||||
decrypter := NewECBDecrypter(block)
|
||||
decrypted := make([]byte, len(src))
|
||||
decrypter.CryptBlocks(decrypted, src)
|
||||
@@ -100,8 +100,9 @@ func EcbDecrypt(key, src []byte) ([]byte, error) {
|
||||
return pkcs5Unpadding(decrypted, decrypter.BlockSize())
|
||||
}
|
||||
|
||||
// EcbDecryptBase64 decrypts base64 encoded src with the given base64 encoded key.
|
||||
// Deprecated: EcbDecryptBase64 decrypts base64 encoded src with the given base64 encoded key.
|
||||
// The returned string is also base64 encoded.
|
||||
// ECB mode is insecure for multi-block data. Use AES-GCM instead.
|
||||
func EcbDecryptBase64(key, src string) (string, error) {
|
||||
keyBytes, err := getKeyBytes(key)
|
||||
if err != nil {
|
||||
@@ -121,11 +122,11 @@ func EcbDecryptBase64(key, src string) (string, error) {
|
||||
return base64.StdEncoding.EncodeToString(decryptedBytes), nil
|
||||
}
|
||||
|
||||
// EcbEncrypt encrypts src with the given key.
|
||||
// Deprecated: EcbEncrypt encrypts src with the given key.
|
||||
// ECB mode is insecure for multi-block data. Use AES-GCM instead.
|
||||
func EcbEncrypt(key, src []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
logx.Errorf("Encrypt key error: % x", key)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -137,8 +138,9 @@ func EcbEncrypt(key, src []byte) ([]byte, error) {
|
||||
return crypted, nil
|
||||
}
|
||||
|
||||
// EcbEncryptBase64 encrypts base64 encoded src with the given base64 encoded key.
|
||||
// Deprecated: EcbEncryptBase64 encrypts base64 encoded src with the given base64 encoded key.
|
||||
// The returned string is also base64 encoded.
|
||||
// ECB mode is insecure for multi-block data. Use AES-GCM instead.
|
||||
func EcbEncryptBase64(key, src string) (string, error) {
|
||||
keyBytes, err := getKeyBytes(key)
|
||||
if err != nil {
|
||||
@@ -179,10 +181,20 @@ func pkcs5Padding(ciphertext []byte, blockSize int) []byte {
|
||||
|
||||
func pkcs5Unpadding(src []byte, blockSize int) ([]byte, error) {
|
||||
length := len(src)
|
||||
unpadding := int(src[length-1])
|
||||
if unpadding >= length || unpadding > blockSize {
|
||||
if length == 0 {
|
||||
return nil, ErrPaddingSize
|
||||
}
|
||||
|
||||
unpadding := int(src[length-1])
|
||||
if unpadding < 1 || unpadding > blockSize || unpadding > length {
|
||||
return nil, ErrPaddingSize
|
||||
}
|
||||
|
||||
for _, b := range src[length-unpadding:] {
|
||||
if int(b) != unpadding {
|
||||
return nil, ErrPaddingSize
|
||||
}
|
||||
}
|
||||
|
||||
return src[:length-unpadding], nil
|
||||
}
|
||||
|
||||
@@ -28,8 +28,8 @@ func TestAesEcb(t *testing.T) {
|
||||
_, err = EcbDecrypt(badKey2, dst)
|
||||
assert.NotNil(t, err)
|
||||
_, err = EcbDecrypt(key, val)
|
||||
// not enough block, just nil
|
||||
assert.Nil(t, err)
|
||||
// not a multiple of block size
|
||||
assert.NotNil(t, err)
|
||||
src, err := EcbDecrypt(key, dst)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, val, src)
|
||||
@@ -41,33 +41,28 @@ func TestAesEcb(t *testing.T) {
|
||||
assert.Equal(t, 16, decrypter.BlockSize())
|
||||
|
||||
dst = make([]byte, 8)
|
||||
encrypter.CryptBlocks(dst, val)
|
||||
for _, b := range dst {
|
||||
assert.Equal(t, byte(0), b)
|
||||
}
|
||||
assert.Panics(t, func() {
|
||||
encrypter.CryptBlocks(dst, val)
|
||||
})
|
||||
|
||||
dst = make([]byte, 8)
|
||||
encrypter.CryptBlocks(dst, valLong)
|
||||
for _, b := range dst {
|
||||
assert.Equal(t, byte(0), b)
|
||||
}
|
||||
assert.Panics(t, func() {
|
||||
encrypter.CryptBlocks(dst, valLong)
|
||||
})
|
||||
|
||||
dst = make([]byte, 8)
|
||||
decrypter.CryptBlocks(dst, val)
|
||||
for _, b := range dst {
|
||||
assert.Equal(t, byte(0), b)
|
||||
}
|
||||
assert.Panics(t, func() {
|
||||
decrypter.CryptBlocks(dst, val)
|
||||
})
|
||||
|
||||
dst = make([]byte, 8)
|
||||
decrypter.CryptBlocks(dst, valLong)
|
||||
for _, b := range dst {
|
||||
assert.Equal(t, byte(0), b)
|
||||
}
|
||||
assert.Panics(t, func() {
|
||||
decrypter.CryptBlocks(dst, valLong)
|
||||
})
|
||||
|
||||
_, err = EcbEncryptBase64("cTR0N3dDKkYtSmFOZFJnVWpYbjJyNXU4eC9BP0QK", "aGVsbG93b3JsZGxvbmcuLgo=")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAesEcbBase64(t *testing.T) {
|
||||
const (
|
||||
val = "hello"
|
||||
@@ -98,3 +93,44 @@ func TestAesEcbBase64(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, val, string(b))
|
||||
}
|
||||
|
||||
func TestPkcs5UnpaddingEmptyInput(t *testing.T) {
|
||||
_, err := pkcs5Unpadding([]byte{}, 16)
|
||||
assert.Equal(t, ErrPaddingSize, err)
|
||||
}
|
||||
|
||||
func TestPkcs5UnpaddingMalformedPadding(t *testing.T) {
|
||||
// Valid PKCS5 padding of 3: last 3 bytes should all be 0x03
|
||||
// Here we corrupt one padding byte
|
||||
malformed := []byte{0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
||||
0x41, 0x41, 0x41, 0x41, 0x41, 0x02, 0x03, 0x03}
|
||||
_, err := pkcs5Unpadding(malformed, 16)
|
||||
assert.Equal(t, ErrPaddingSize, err)
|
||||
|
||||
// All padding bytes correct
|
||||
valid := []byte{0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
||||
0x41, 0x41, 0x41, 0x41, 0x41, 0x03, 0x03, 0x03}
|
||||
result, err := pkcs5Unpadding(valid, 16)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, valid[:13], result)
|
||||
}
|
||||
|
||||
func TestPkcs5UnpaddingInvalidPaddingValue(t *testing.T) {
|
||||
// padding value = 0 (< 1)
|
||||
_, err := pkcs5Unpadding([]byte{0x41, 0x00}, 16)
|
||||
assert.Equal(t, ErrPaddingSize, err)
|
||||
|
||||
// padding value > blockSize
|
||||
_, err = pkcs5Unpadding([]byte{0x41, 0x41, 0x41, 0x41, 17}, 4)
|
||||
assert.Equal(t, ErrPaddingSize, err)
|
||||
|
||||
// padding value > length
|
||||
_, err = pkcs5Unpadding([]byte{0x41, 0x03}, 16)
|
||||
assert.Equal(t, ErrPaddingSize, err)
|
||||
}
|
||||
|
||||
func TestEcbDecryptEmptyInput(t *testing.T) {
|
||||
key := []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D")
|
||||
_, err := EcbDecrypt(key, []byte{})
|
||||
assert.Equal(t, ErrPaddingSize, err)
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ func ComputeKey(pubKey, priKey *big.Int) (*big.Int, error) {
|
||||
return nil, ErrInvalidPubKey
|
||||
}
|
||||
|
||||
if pubKey.Sign() <= 0 && p.Cmp(pubKey) <= 0 {
|
||||
if pubKey.Sign() <= 0 || p.Cmp(pubKey) <= 0 {
|
||||
return nil, ErrPubKeyOutOfBound
|
||||
}
|
||||
|
||||
|
||||
@@ -94,3 +94,32 @@ func TestDHOnErrors(t *testing.T) {
|
||||
|
||||
assert.NotNil(t, NewPublicKey([]byte("")))
|
||||
}
|
||||
|
||||
func TestDHPubKeyBoundary(t *testing.T) {
|
||||
key, err := GenerateKey()
|
||||
assert.Nil(t, err)
|
||||
|
||||
// pubKey = 0 should be rejected
|
||||
_, err = ComputeKey(big.NewInt(0), key.PriKey)
|
||||
assert.ErrorIs(t, err, ErrPubKeyOutOfBound)
|
||||
|
||||
// pubKey = -1 should be rejected
|
||||
_, err = ComputeKey(big.NewInt(-1), key.PriKey)
|
||||
assert.ErrorIs(t, err, ErrPubKeyOutOfBound)
|
||||
|
||||
// pubKey = p should be rejected
|
||||
_, err = ComputeKey(new(big.Int).Set(p), key.PriKey)
|
||||
assert.ErrorIs(t, err, ErrPubKeyOutOfBound)
|
||||
|
||||
// pubKey = p+1 should be rejected
|
||||
_, err = ComputeKey(new(big.Int).Add(p, big.NewInt(1)), key.PriKey)
|
||||
assert.ErrorIs(t, err, ErrPubKeyOutOfBound)
|
||||
|
||||
// pubKey = 1 should be accepted
|
||||
_, err = ComputeKey(big.NewInt(1), key.PriKey)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// pubKey = p-1 should be accepted
|
||||
_, err = ComputeKey(new(big.Int).Sub(p, big.NewInt(1)), key.PriKey)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package codec
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
@@ -46,7 +47,9 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
// NewRsaDecrypter returns a RsaDecrypter with the given file.
|
||||
// Deprecated: NewRsaDecrypter returns a RsaDecrypter with the given file.
|
||||
// PKCS#1 v1.5 padding is vulnerable to padding oracle attacks.
|
||||
// Use NewRsaOAEPDecrypter instead.
|
||||
func NewRsaDecrypter(file string) (RsaDecrypter, error) {
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
@@ -90,7 +93,9 @@ func (r *rsaDecrypter) DecryptBase64(input string) ([]byte, error) {
|
||||
return r.Decrypt(base64Decoded)
|
||||
}
|
||||
|
||||
// NewRsaEncrypter returns a RsaEncrypter with the given key.
|
||||
// Deprecated: NewRsaEncrypter returns a RsaEncrypter with the given key.
|
||||
// PKCS#1 v1.5 padding is vulnerable to padding oracle attacks.
|
||||
// Use NewRsaOAEPEncrypter instead.
|
||||
func NewRsaEncrypter(key []byte) (RsaEncrypter, error) {
|
||||
block, _ := pem.Decode(key)
|
||||
if block == nil {
|
||||
@@ -154,3 +159,90 @@ func rsaDecryptBlock(privateKey *rsa.PrivateKey, block []byte) ([]byte, error) {
|
||||
func rsaEncryptBlock(publicKey *rsa.PublicKey, msg []byte) ([]byte, error) {
|
||||
return rsa.EncryptPKCS1v15(rand.Reader, publicKey, msg)
|
||||
}
|
||||
|
||||
// NewRsaOAEPDecrypter returns a RsaDecrypter using OAEP with SHA-256.
|
||||
func NewRsaOAEPDecrypter(file string) (RsaDecrypter, error) {
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
block, _ := pem.Decode(content)
|
||||
if block == nil {
|
||||
return nil, ErrPrivateKey
|
||||
}
|
||||
|
||||
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &rsaOAEPDecrypter{
|
||||
rsaBase: rsaBase{
|
||||
bytesLimit: privateKey.N.BitLen() >> 3,
|
||||
},
|
||||
privateKey: privateKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewRsaOAEPEncrypter returns a RsaEncrypter using OAEP with SHA-256.
|
||||
func NewRsaOAEPEncrypter(key []byte) (RsaEncrypter, error) {
|
||||
block, _ := pem.Decode(key)
|
||||
if block == nil {
|
||||
return nil, ErrPublicKey
|
||||
}
|
||||
|
||||
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch pubKey := pub.(type) {
|
||||
case *rsa.PublicKey:
|
||||
// OAEP overhead: 2*hash_size + 2
|
||||
hashSize := sha256.New().Size()
|
||||
return &rsaOAEPEncrypter{
|
||||
rsaBase: rsaBase{
|
||||
bytesLimit: (pubKey.N.BitLen() >> 3) - 2*hashSize - 2,
|
||||
},
|
||||
publicKey: pubKey,
|
||||
}, nil
|
||||
default:
|
||||
return nil, ErrNotRsaKey
|
||||
}
|
||||
}
|
||||
|
||||
type rsaOAEPDecrypter struct {
|
||||
rsaBase
|
||||
privateKey *rsa.PrivateKey
|
||||
}
|
||||
|
||||
func (r *rsaOAEPDecrypter) Decrypt(input []byte) ([]byte, error) {
|
||||
return r.crypt(input, func(block []byte) ([]byte, error) {
|
||||
return rsa.DecryptOAEP(sha256.New(), rand.Reader, r.privateKey, block, nil)
|
||||
})
|
||||
}
|
||||
|
||||
func (r *rsaOAEPDecrypter) DecryptBase64(input string) ([]byte, error) {
|
||||
if len(input) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
base64Decoded, err := base64.StdEncoding.DecodeString(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return r.Decrypt(base64Decoded)
|
||||
}
|
||||
|
||||
type rsaOAEPEncrypter struct {
|
||||
rsaBase
|
||||
publicKey *rsa.PublicKey
|
||||
}
|
||||
|
||||
func (r *rsaOAEPEncrypter) Encrypt(input []byte) ([]byte, error) {
|
||||
return r.crypt(input, func(block []byte) ([]byte, error) {
|
||||
return rsa.EncryptOAEP(sha256.New(), rand.Reader, r.publicKey, block, nil)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
@@ -58,3 +63,78 @@ func TestBadPubKey(t *testing.T) {
|
||||
_, err := NewRsaEncrypter([]byte("foo"))
|
||||
assert.Equal(t, ErrPublicKey, err)
|
||||
}
|
||||
|
||||
func TestOAEPCryption(t *testing.T) {
|
||||
enc, err := NewRsaOAEPEncrypter([]byte(pubKey))
|
||||
assert.Nil(t, err)
|
||||
ret, err := enc.Encrypt([]byte(testBody))
|
||||
assert.Nil(t, err)
|
||||
|
||||
file, err := fs.TempFilenameWithText(priKey)
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(file)
|
||||
dec, err := NewRsaOAEPDecrypter(file)
|
||||
assert.Nil(t, err)
|
||||
actual, err := dec.Decrypt(ret)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, testBody, string(actual))
|
||||
|
||||
actual, err = dec.DecryptBase64(base64.StdEncoding.EncodeToString(ret))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, testBody, string(actual))
|
||||
|
||||
// empty input
|
||||
actual, err = dec.DecryptBase64("")
|
||||
assert.Nil(t, err)
|
||||
assert.Nil(t, actual)
|
||||
}
|
||||
|
||||
func TestOAEPBadKeys(t *testing.T) {
|
||||
_, err := NewRsaOAEPEncrypter([]byte("bad"))
|
||||
assert.Equal(t, ErrPublicKey, err)
|
||||
|
||||
_, err = NewRsaOAEPDecrypter("nonexistent")
|
||||
assert.Error(t, err)
|
||||
|
||||
// valid PEM but invalid private key content
|
||||
badPem, err := fs.TempFilenameWithText("-----BEGIN RSA PRIVATE KEY-----\nYmFk\n-----END RSA PRIVATE KEY-----")
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(badPem)
|
||||
_, err = NewRsaOAEPDecrypter(badPem)
|
||||
assert.Error(t, err)
|
||||
|
||||
// not PEM content at all
|
||||
notPem, err := fs.TempFilenameWithText("not a pem file")
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(notPem)
|
||||
_, err = NewRsaOAEPDecrypter(notPem)
|
||||
assert.Equal(t, ErrPrivateKey, err)
|
||||
}
|
||||
|
||||
func TestOAEPEncrypterParseError(t *testing.T) {
|
||||
// valid PEM block but invalid public key content
|
||||
badPub := []byte("-----BEGIN PUBLIC KEY-----\nYmFk\n-----END PUBLIC KEY-----")
|
||||
_, err := NewRsaOAEPEncrypter(badPub)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestOAEPEncrypterNonRsaKey(t *testing.T) {
|
||||
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
assert.Nil(t, err)
|
||||
derBytes, err := x509.MarshalPKIXPublicKey(&ecKey.PublicKey)
|
||||
assert.Nil(t, err)
|
||||
ecPem := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: derBytes})
|
||||
_, err = NewRsaOAEPEncrypter(ecPem)
|
||||
assert.Equal(t, ErrNotRsaKey, err)
|
||||
}
|
||||
|
||||
func TestOAEPDecryptBase64Error(t *testing.T) {
|
||||
file, err := fs.TempFilenameWithText(priKey)
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(file)
|
||||
dec, err := NewRsaOAEPDecrypter(file)
|
||||
assert.Nil(t, err)
|
||||
|
||||
_, err = dec.DecryptBase64("not-valid-base64!!!")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -81,6 +81,10 @@ func (c *Cache) Del(key string) {
|
||||
delete(c.data, key)
|
||||
c.lruCache.remove(key)
|
||||
c.lock.Unlock()
|
||||
|
||||
// RemoveTimer is called outside the lock to avoid performance impact from this
|
||||
// potentially time-consuming operation. Data integrity is maintained by lruCache,
|
||||
// which will eventually evict any remaining entries when capacity is exceeded.
|
||||
c.timingWheel.RemoveTimer(key)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,235 +1,53 @@
|
||||
package collection
|
||||
|
||||
import (
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
import "github.com/zeromicro/go-zero/core/lang"
|
||||
|
||||
const (
|
||||
unmanaged = iota
|
||||
untyped
|
||||
intType
|
||||
int64Type
|
||||
uintType
|
||||
uint64Type
|
||||
stringType
|
||||
)
|
||||
|
||||
// Set is not thread-safe, for concurrent use, make sure to use it with synchronization.
|
||||
type Set struct {
|
||||
data map[any]lang.PlaceholderType
|
||||
tp int
|
||||
// Set is a type-safe generic set collection.
|
||||
// It's not thread-safe, use with synchronization for concurrent access.
|
||||
type Set[T comparable] struct {
|
||||
data map[T]lang.PlaceholderType
|
||||
}
|
||||
|
||||
// NewSet returns a managed Set, can only put the values with the same type.
|
||||
func NewSet() *Set {
|
||||
return &Set{
|
||||
data: make(map[any]lang.PlaceholderType),
|
||||
tp: untyped,
|
||||
// NewSet returns a new type-safe set.
|
||||
func NewSet[T comparable]() *Set[T] {
|
||||
return &Set[T]{
|
||||
data: make(map[T]lang.PlaceholderType),
|
||||
}
|
||||
}
|
||||
|
||||
// NewUnmanagedSet returns an unmanaged Set, which can put values with different types.
|
||||
func NewUnmanagedSet() *Set {
|
||||
return &Set{
|
||||
data: make(map[any]lang.PlaceholderType),
|
||||
tp: unmanaged,
|
||||
// Add adds items to the set. Duplicates are automatically ignored.
|
||||
func (s *Set[T]) Add(items ...T) {
|
||||
for _, item := range items {
|
||||
s.data[item] = lang.Placeholder
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds i into s.
|
||||
func (s *Set) Add(i ...any) {
|
||||
for _, each := range i {
|
||||
s.add(each)
|
||||
}
|
||||
// Clear removes all items from the set.
|
||||
func (s *Set[T]) Clear() {
|
||||
clear(s.data)
|
||||
}
|
||||
|
||||
// AddInt adds int values ii into s.
|
||||
func (s *Set) AddInt(ii ...int) {
|
||||
for _, each := range ii {
|
||||
s.add(each)
|
||||
}
|
||||
}
|
||||
|
||||
// AddInt64 adds int64 values ii into s.
|
||||
func (s *Set) AddInt64(ii ...int64) {
|
||||
for _, each := range ii {
|
||||
s.add(each)
|
||||
}
|
||||
}
|
||||
|
||||
// AddUint adds uint values ii into s.
|
||||
func (s *Set) AddUint(ii ...uint) {
|
||||
for _, each := range ii {
|
||||
s.add(each)
|
||||
}
|
||||
}
|
||||
|
||||
// AddUint64 adds uint64 values ii into s.
|
||||
func (s *Set) AddUint64(ii ...uint64) {
|
||||
for _, each := range ii {
|
||||
s.add(each)
|
||||
}
|
||||
}
|
||||
|
||||
// AddStr adds string values ss into s.
|
||||
func (s *Set) AddStr(ss ...string) {
|
||||
for _, each := range ss {
|
||||
s.add(each)
|
||||
}
|
||||
}
|
||||
|
||||
// Contains checks if i is in s.
|
||||
func (s *Set) Contains(i any) bool {
|
||||
if len(s.data) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
s.validate(i)
|
||||
_, ok := s.data[i]
|
||||
// Contains checks if an item exists in the set.
|
||||
func (s *Set[T]) Contains(item T) bool {
|
||||
_, ok := s.data[item]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Keys returns the keys in s.
|
||||
func (s *Set) Keys() []any {
|
||||
var keys []any
|
||||
|
||||
for key := range s.data {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
// KeysInt returns the int keys in s.
|
||||
func (s *Set) KeysInt() []int {
|
||||
var keys []int
|
||||
|
||||
for key := range s.data {
|
||||
if intKey, ok := key.(int); ok {
|
||||
keys = append(keys, intKey)
|
||||
}
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
// KeysInt64 returns int64 keys in s.
|
||||
func (s *Set) KeysInt64() []int64 {
|
||||
var keys []int64
|
||||
|
||||
for key := range s.data {
|
||||
if intKey, ok := key.(int64); ok {
|
||||
keys = append(keys, intKey)
|
||||
}
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
// KeysUint returns uint keys in s.
|
||||
func (s *Set) KeysUint() []uint {
|
||||
var keys []uint
|
||||
|
||||
for key := range s.data {
|
||||
if intKey, ok := key.(uint); ok {
|
||||
keys = append(keys, intKey)
|
||||
}
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
// KeysUint64 returns uint64 keys in s.
|
||||
func (s *Set) KeysUint64() []uint64 {
|
||||
var keys []uint64
|
||||
|
||||
for key := range s.data {
|
||||
if intKey, ok := key.(uint64); ok {
|
||||
keys = append(keys, intKey)
|
||||
}
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
// KeysStr returns string keys in s.
|
||||
func (s *Set) KeysStr() []string {
|
||||
var keys []string
|
||||
|
||||
for key := range s.data {
|
||||
if strKey, ok := key.(string); ok {
|
||||
keys = append(keys, strKey)
|
||||
}
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
// Remove removes i from s.
|
||||
func (s *Set) Remove(i any) {
|
||||
s.validate(i)
|
||||
delete(s.data, i)
|
||||
}
|
||||
|
||||
// Count returns the number of items in s.
|
||||
func (s *Set) Count() int {
|
||||
// Count returns the number of items in the set.
|
||||
func (s *Set[T]) Count() int {
|
||||
return len(s.data)
|
||||
}
|
||||
|
||||
func (s *Set) add(i any) {
|
||||
switch s.tp {
|
||||
case unmanaged:
|
||||
// do nothing
|
||||
case untyped:
|
||||
s.setType(i)
|
||||
default:
|
||||
s.validate(i)
|
||||
// Keys returns all elements in the set as a slice.
|
||||
func (s *Set[T]) Keys() []T {
|
||||
keys := make([]T, 0, len(s.data))
|
||||
for key := range s.data {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
s.data[i] = lang.Placeholder
|
||||
return keys
|
||||
}
|
||||
|
||||
func (s *Set) setType(i any) {
|
||||
// s.tp can only be untyped here
|
||||
switch i.(type) {
|
||||
case int:
|
||||
s.tp = intType
|
||||
case int64:
|
||||
s.tp = int64Type
|
||||
case uint:
|
||||
s.tp = uintType
|
||||
case uint64:
|
||||
s.tp = uint64Type
|
||||
case string:
|
||||
s.tp = stringType
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Set) validate(i any) {
|
||||
if s.tp == unmanaged {
|
||||
return
|
||||
}
|
||||
|
||||
switch i.(type) {
|
||||
case int:
|
||||
if s.tp != intType {
|
||||
logx.Errorf("element is int, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
case int64:
|
||||
if s.tp != int64Type {
|
||||
logx.Errorf("element is int64, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
case uint:
|
||||
if s.tp != uintType {
|
||||
logx.Errorf("element is uint, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
case uint64:
|
||||
if s.tp != uint64Type {
|
||||
logx.Errorf("element is uint64, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
case string:
|
||||
if s.tp != stringType {
|
||||
logx.Errorf("element is string, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
}
|
||||
// Remove removes an item from the set.
|
||||
func (s *Set[T]) Remove(item T) {
|
||||
delete(s.data, item)
|
||||
}
|
||||
|
||||
@@ -12,6 +12,105 @@ func init() {
|
||||
logx.Disable()
|
||||
}
|
||||
|
||||
// Set functionality tests
|
||||
func TestTypedSetInt(t *testing.T) {
|
||||
set := NewSet[int]()
|
||||
values := []int{1, 2, 3, 2, 1} // Contains duplicates
|
||||
|
||||
// Test adding
|
||||
set.Add(values...)
|
||||
assert.Equal(t, 3, set.Count()) // Should only have 3 elements after deduplication
|
||||
|
||||
// Test contains
|
||||
assert.True(t, set.Contains(1))
|
||||
assert.True(t, set.Contains(2))
|
||||
assert.True(t, set.Contains(3))
|
||||
assert.False(t, set.Contains(4))
|
||||
|
||||
// Test getting all keys
|
||||
keys := set.Keys()
|
||||
sort.Ints(keys)
|
||||
assert.EqualValues(t, []int{1, 2, 3}, keys)
|
||||
|
||||
// Test removal
|
||||
set.Remove(2)
|
||||
assert.False(t, set.Contains(2))
|
||||
assert.Equal(t, 2, set.Count())
|
||||
}
|
||||
|
||||
func TestTypedSetStringOps(t *testing.T) {
|
||||
set := NewSet[string]()
|
||||
values := []string{"a", "b", "c", "b", "a"}
|
||||
|
||||
set.Add(values...)
|
||||
assert.Equal(t, 3, set.Count())
|
||||
|
||||
assert.True(t, set.Contains("a"))
|
||||
assert.True(t, set.Contains("b"))
|
||||
assert.True(t, set.Contains("c"))
|
||||
assert.False(t, set.Contains("d"))
|
||||
|
||||
keys := set.Keys()
|
||||
sort.Strings(keys)
|
||||
assert.EqualValues(t, []string{"a", "b", "c"}, keys)
|
||||
}
|
||||
|
||||
func TestTypedSetClear(t *testing.T) {
|
||||
set := NewSet[int]()
|
||||
set.Add(1, 2, 3)
|
||||
assert.Equal(t, 3, set.Count())
|
||||
|
||||
set.Clear()
|
||||
assert.Equal(t, 0, set.Count())
|
||||
assert.False(t, set.Contains(1))
|
||||
}
|
||||
|
||||
func TestTypedSetEmpty(t *testing.T) {
|
||||
set := NewSet[int]()
|
||||
assert.Equal(t, 0, set.Count())
|
||||
assert.False(t, set.Contains(1))
|
||||
assert.Empty(t, set.Keys())
|
||||
}
|
||||
|
||||
func TestTypedSetMultipleTypes(t *testing.T) {
|
||||
// Test different typed generic sets
|
||||
intSet := NewSet[int]()
|
||||
int64Set := NewSet[int64]()
|
||||
uintSet := NewSet[uint]()
|
||||
uint64Set := NewSet[uint64]()
|
||||
stringSet := NewSet[string]()
|
||||
|
||||
intSet.Add(1, 2, 3)
|
||||
int64Set.Add(1, 2, 3)
|
||||
uintSet.Add(1, 2, 3)
|
||||
uint64Set.Add(1, 2, 3)
|
||||
stringSet.Add("1", "2", "3")
|
||||
|
||||
assert.Equal(t, 3, intSet.Count())
|
||||
assert.Equal(t, 3, int64Set.Count())
|
||||
assert.Equal(t, 3, uintSet.Count())
|
||||
assert.Equal(t, 3, uint64Set.Count())
|
||||
assert.Equal(t, 3, stringSet.Count())
|
||||
}
|
||||
|
||||
// Set benchmarks
|
||||
func BenchmarkTypedIntSet(b *testing.B) {
|
||||
s := NewSet[int]()
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.Add(i)
|
||||
_ = s.Contains(i)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTypedStringSet(b *testing.B) {
|
||||
s := NewSet[string]()
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.Add(string(rune(i)))
|
||||
_ = s.Contains(string(rune(i)))
|
||||
}
|
||||
}
|
||||
|
||||
// Legacy tests remain unchanged for backward compatibility
|
||||
func BenchmarkRawSet(b *testing.B) {
|
||||
m := make(map[any]struct{})
|
||||
for i := 0; i < b.N; i++ {
|
||||
@@ -20,26 +119,10 @@ func BenchmarkRawSet(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUnmanagedSet(b *testing.B) {
|
||||
s := NewUnmanagedSet()
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.Add(i)
|
||||
_ = s.Contains(i)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSet(b *testing.B) {
|
||||
s := NewSet()
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.AddInt(i)
|
||||
_ = s.Contains(i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdd(t *testing.T) {
|
||||
// given
|
||||
set := NewUnmanagedSet()
|
||||
values := []any{1, 2, 3}
|
||||
set := NewSet[int]()
|
||||
values := []int{1, 2, 3}
|
||||
|
||||
// when
|
||||
set.Add(values...)
|
||||
@@ -51,82 +134,74 @@ func TestAdd(t *testing.T) {
|
||||
|
||||
func TestAddInt(t *testing.T) {
|
||||
// given
|
||||
set := NewSet()
|
||||
set := NewSet[int]()
|
||||
values := []int{1, 2, 3}
|
||||
|
||||
// when
|
||||
set.AddInt(values...)
|
||||
set.Add(values...)
|
||||
|
||||
// then
|
||||
assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3))
|
||||
keys := set.KeysInt()
|
||||
keys := set.Keys()
|
||||
sort.Ints(keys)
|
||||
assert.EqualValues(t, values, keys)
|
||||
}
|
||||
|
||||
func TestAddInt64(t *testing.T) {
|
||||
// given
|
||||
set := NewSet()
|
||||
set := NewSet[int64]()
|
||||
values := []int64{1, 2, 3}
|
||||
|
||||
// when
|
||||
set.AddInt64(values...)
|
||||
set.Add(values...)
|
||||
|
||||
// then
|
||||
assert.True(t, set.Contains(int64(1)) && set.Contains(int64(2)) && set.Contains(int64(3)))
|
||||
assert.Equal(t, len(values), len(set.KeysInt64()))
|
||||
assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3))
|
||||
assert.Equal(t, len(values), len(set.Keys()))
|
||||
}
|
||||
|
||||
func TestAddUint(t *testing.T) {
|
||||
// given
|
||||
set := NewSet()
|
||||
set := NewSet[uint]()
|
||||
values := []uint{1, 2, 3}
|
||||
|
||||
// when
|
||||
set.AddUint(values...)
|
||||
set.Add(values...)
|
||||
|
||||
// then
|
||||
assert.True(t, set.Contains(uint(1)) && set.Contains(uint(2)) && set.Contains(uint(3)))
|
||||
assert.Equal(t, len(values), len(set.KeysUint()))
|
||||
assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3))
|
||||
assert.Equal(t, len(values), len(set.Keys()))
|
||||
}
|
||||
|
||||
func TestAddUint64(t *testing.T) {
|
||||
// given
|
||||
set := NewSet()
|
||||
set := NewSet[uint64]()
|
||||
values := []uint64{1, 2, 3}
|
||||
|
||||
// when
|
||||
set.AddUint64(values...)
|
||||
set.Add(values...)
|
||||
|
||||
// then
|
||||
assert.True(t, set.Contains(uint64(1)) && set.Contains(uint64(2)) && set.Contains(uint64(3)))
|
||||
assert.Equal(t, len(values), len(set.KeysUint64()))
|
||||
assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3))
|
||||
assert.Equal(t, len(values), len(set.Keys()))
|
||||
}
|
||||
|
||||
func TestAddStr(t *testing.T) {
|
||||
// given
|
||||
set := NewSet()
|
||||
set := NewSet[string]()
|
||||
values := []string{"1", "2", "3"}
|
||||
|
||||
// when
|
||||
set.AddStr(values...)
|
||||
set.Add(values...)
|
||||
|
||||
// then
|
||||
assert.True(t, set.Contains("1") && set.Contains("2") && set.Contains("3"))
|
||||
assert.Equal(t, len(values), len(set.KeysStr()))
|
||||
assert.Equal(t, len(values), len(set.Keys()))
|
||||
}
|
||||
|
||||
func TestContainsWithoutElements(t *testing.T) {
|
||||
// given
|
||||
set := NewSet()
|
||||
|
||||
// then
|
||||
assert.False(t, set.Contains(1))
|
||||
}
|
||||
|
||||
func TestContainsUnmanagedWithoutElements(t *testing.T) {
|
||||
// given
|
||||
set := NewUnmanagedSet()
|
||||
set := NewSet[int]()
|
||||
|
||||
// then
|
||||
assert.False(t, set.Contains(1))
|
||||
@@ -134,8 +209,8 @@ func TestContainsUnmanagedWithoutElements(t *testing.T) {
|
||||
|
||||
func TestRemove(t *testing.T) {
|
||||
// given
|
||||
set := NewSet()
|
||||
set.Add([]any{1, 2, 3}...)
|
||||
set := NewSet[int]()
|
||||
set.Add([]int{1, 2, 3}...)
|
||||
|
||||
// when
|
||||
set.Remove(2)
|
||||
@@ -146,57 +221,9 @@ func TestRemove(t *testing.T) {
|
||||
|
||||
func TestCount(t *testing.T) {
|
||||
// given
|
||||
set := NewSet()
|
||||
set.Add([]any{1, 2, 3}...)
|
||||
set := NewSet[int]()
|
||||
set.Add([]int{1, 2, 3}...)
|
||||
|
||||
// then
|
||||
assert.Equal(t, set.Count(), 3)
|
||||
}
|
||||
|
||||
func TestKeysIntMismatch(t *testing.T) {
|
||||
set := NewSet()
|
||||
set.add(int64(1))
|
||||
set.add(2)
|
||||
vals := set.KeysInt()
|
||||
assert.EqualValues(t, []int{2}, vals)
|
||||
}
|
||||
|
||||
func TestKeysInt64Mismatch(t *testing.T) {
|
||||
set := NewSet()
|
||||
set.add(1)
|
||||
set.add(int64(2))
|
||||
vals := set.KeysInt64()
|
||||
assert.EqualValues(t, []int64{2}, vals)
|
||||
}
|
||||
|
||||
func TestKeysUintMismatch(t *testing.T) {
|
||||
set := NewSet()
|
||||
set.add(1)
|
||||
set.add(uint(2))
|
||||
vals := set.KeysUint()
|
||||
assert.EqualValues(t, []uint{2}, vals)
|
||||
}
|
||||
|
||||
func TestKeysUint64Mismatch(t *testing.T) {
|
||||
set := NewSet()
|
||||
set.add(1)
|
||||
set.add(uint64(2))
|
||||
vals := set.KeysUint64()
|
||||
assert.EqualValues(t, []uint64{2}, vals)
|
||||
}
|
||||
|
||||
func TestKeysStrMismatch(t *testing.T) {
|
||||
set := NewSet()
|
||||
set.add(1)
|
||||
set.add("2")
|
||||
vals := set.KeysStr()
|
||||
assert.EqualValues(t, []string{"2"}, vals)
|
||||
}
|
||||
|
||||
func TestSetType(t *testing.T) {
|
||||
set := NewUnmanagedSet()
|
||||
set.add(1)
|
||||
set.add("2")
|
||||
vals := set.Keys()
|
||||
assert.ElementsMatch(t, []any{1, "2"}, vals)
|
||||
}
|
||||
|
||||
@@ -164,6 +164,7 @@ func (tw *TimingWheel) Stop() {
|
||||
|
||||
func (tw *TimingWheel) drainAll(fn func(key, value any)) {
|
||||
runner := threading.NewTaskRunner(drainWorkers)
|
||||
|
||||
for _, slot := range tw.slots {
|
||||
for e := slot.Front(); e != nil; {
|
||||
task := e.Value.(*timingEntry)
|
||||
@@ -177,6 +178,8 @@ func (tw *TimingWheel) drainAll(fn func(key, value any)) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
runner.Wait()
|
||||
}
|
||||
|
||||
func (tw *TimingWheel) getPositionAndCircle(d time.Duration) (pos, circle int) {
|
||||
|
||||
@@ -629,6 +629,157 @@ func TestMoveAndRemoveTask(t *testing.T) {
|
||||
assert.Equal(t, 0, len(keys))
|
||||
}
|
||||
|
||||
// TestTimingWheel_DrainClosureBug tests the closure capture bug in drainAll
|
||||
// Issue: https://github.com/zeromicro/go-zero/issues/5314
|
||||
func TestTimingWheel_DrainClosureBug(t *testing.T) {
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {}, ticker)
|
||||
defer tw.Stop()
|
||||
|
||||
// Set multiple timers with different values
|
||||
for i := 0; i < 10; i++ {
|
||||
tw.SetTimer(i, i*10, testStep*5)
|
||||
}
|
||||
|
||||
// Give time for timers to be set
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
|
||||
var mu sync.Mutex
|
||||
received := make(map[int]int)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(10)
|
||||
|
||||
tw.Drain(func(key, value any) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
k := key.(int)
|
||||
v := value.(int)
|
||||
received[k] = v
|
||||
wg.Done()
|
||||
})
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Check if all values match their keys
|
||||
for k, v := range received {
|
||||
expected := k * 10
|
||||
assert.Equal(t, expected, v, "key %d should have value %d, got %d", k, expected, v)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTimingWheel_RunTasksClosureBug tests the closure capture bug in runTasks
|
||||
// Issue: https://github.com/zeromicro/go-zero/issues/5314
|
||||
func TestTimingWheel_RunTasksClosureBug(t *testing.T) {
|
||||
ticker := timex.NewFakeTicker()
|
||||
var mu sync.Mutex
|
||||
executed := make(map[int]int)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
key := k.(int)
|
||||
val := v.(int)
|
||||
executed[key] = val
|
||||
wg.Done()
|
||||
}, ticker)
|
||||
defer tw.Stop()
|
||||
|
||||
// Set multiple timers that should fire in the same tick
|
||||
count := 10
|
||||
wg.Add(count)
|
||||
for i := 0; i < count; i++ {
|
||||
tw.SetTimer(i, i*10, testStep)
|
||||
}
|
||||
|
||||
// Advance ticker to trigger tasks
|
||||
ticker.Tick()
|
||||
|
||||
// Wait for execution with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Success
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for tasks to execute")
|
||||
}
|
||||
|
||||
// Verify all tasks executed with correct values
|
||||
assert.Equal(t, count, len(executed), "should have executed all tasks")
|
||||
for k, v := range executed {
|
||||
expected := k * 10
|
||||
assert.Equal(t, expected, v, "key %d should have value %d, got %d", k, expected, v)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTimingWheel_RunTasksRaceCondition tests for race conditions in runTasks
|
||||
// This test specifically targets the loop variable capture bug
|
||||
func TestTimingWheel_RunTasksRaceCondition(t *testing.T) {
|
||||
// Run multiple times to increase likelihood of catching the bug
|
||||
for attempt := 0; attempt < 10; attempt++ {
|
||||
t.Run("", func(t *testing.T) {
|
||||
ticker := timex.NewFakeTicker()
|
||||
var mu sync.Mutex
|
||||
keyValues := make(map[int][]int)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
|
||||
// Add small delay to increase chance of race
|
||||
time.Sleep(time.Microsecond)
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
key := k.(int)
|
||||
val := v.(int)
|
||||
keyValues[key] = append(keyValues[key], val)
|
||||
wg.Done()
|
||||
}, ticker)
|
||||
defer tw.Stop()
|
||||
|
||||
// Set many timers rapidly to increase chance of race
|
||||
count := 50
|
||||
wg.Add(count)
|
||||
for i := 0; i < count; i++ {
|
||||
tw.SetTimer(i, i*100, testStep)
|
||||
}
|
||||
|
||||
ticker.Tick()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for tasks")
|
||||
}
|
||||
|
||||
// Check for duplicates or wrong values
|
||||
wrongCount := 0
|
||||
for key, values := range keyValues {
|
||||
assert.Equal(t, 1, len(values), "key %d should only execute once, got %v", key, values)
|
||||
if len(values) > 0 {
|
||||
expected := key * 100
|
||||
if values[0] != expected {
|
||||
wrongCount++
|
||||
t.Logf("BUG DETECTED: key %d should have value %d, got %d", key, expected, values[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
if wrongCount > 0 {
|
||||
t.Errorf("Found %d tasks with wrong values due to closure bug", wrongCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTimingWheel(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
|
||||
@@ -21,10 +21,11 @@ const (
|
||||
var (
|
||||
fillDefaultUnmarshaler = mapping.NewUnmarshaler(jsonTagKey, mapping.WithDefault())
|
||||
loaders = map[string]func([]byte, any) error{
|
||||
".json": LoadFromJsonBytes,
|
||||
".toml": LoadFromTomlBytes,
|
||||
".yaml": LoadFromYamlBytes,
|
||||
".yml": LoadFromYamlBytes,
|
||||
".json": LoadFromJsonBytes,
|
||||
".json5": LoadFromJson5Bytes,
|
||||
".toml": LoadFromTomlBytes,
|
||||
".yaml": LoadFromYamlBytes,
|
||||
".yml": LoadFromYamlBytes,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -41,7 +42,7 @@ func FillDefault(v any) error {
|
||||
return fillDefaultUnmarshaler.Unmarshal(map[string]any{}, v)
|
||||
}
|
||||
|
||||
// Load loads config into v from file, .json, .yaml and .yml are acceptable.
|
||||
// Load loads config into v from file, .json, .json5, .toml, .yaml and .yml are acceptable.
|
||||
func Load(file string, v any, opts ...Option) error {
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
@@ -65,7 +66,7 @@ func Load(file string, v any, opts ...Option) error {
|
||||
return loader(content, v)
|
||||
}
|
||||
|
||||
// LoadConfig loads config into v from file, .json, .yaml and .yml are acceptable.
|
||||
// LoadConfig loads config into v from file, .json, .json5, .toml, .yaml and .yml are acceptable.
|
||||
// Deprecated: use Load instead.
|
||||
func LoadConfig(file string, v any, opts ...Option) error {
|
||||
return Load(file, v, opts...)
|
||||
@@ -85,7 +86,12 @@ func LoadFromJsonBytes(content []byte, v any) error {
|
||||
|
||||
lowerCaseKeyMap := toLowerCaseKeyMap(m, info)
|
||||
|
||||
return mapping.UnmarshalJsonMap(lowerCaseKeyMap, v, mapping.WithCanonicalKeyFunc(toLowerCase))
|
||||
if err = mapping.UnmarshalJsonMap(lowerCaseKeyMap, v,
|
||||
mapping.WithCanonicalKeyFunc(toLowerCase)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return validate(v)
|
||||
}
|
||||
|
||||
// LoadConfigFromJsonBytes loads config into v from content json bytes.
|
||||
@@ -114,6 +120,16 @@ func LoadFromYamlBytes(content []byte, v any) error {
|
||||
return LoadFromJsonBytes(b, v)
|
||||
}
|
||||
|
||||
// LoadFromJson5Bytes loads config into v from content json5 bytes.
|
||||
func LoadFromJson5Bytes(content []byte, v any) error {
|
||||
b, err := encoding.Json5ToJson(content)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return LoadFromJsonBytes(b, v)
|
||||
}
|
||||
|
||||
// LoadConfigFromYamlBytes loads config into v from content yaml bytes.
|
||||
// Deprecated: use LoadFromYamlBytes instead.
|
||||
func LoadConfigFromYamlBytes(content []byte, v any) error {
|
||||
@@ -192,7 +208,7 @@ func buildFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error) {
|
||||
case reflect.Array, reflect.Slice, reflect.Map:
|
||||
return buildFieldsInfo(mapping.Deref(tp.Elem()), fullName)
|
||||
case reflect.Chan, reflect.Func:
|
||||
return nil, fmt.Errorf("unsupported type: %s", tp.Kind())
|
||||
return nil, fmt.Errorf("unsupported type: %s, fullName: %s", tp.Kind(), fullName)
|
||||
default:
|
||||
return &fieldInfo{
|
||||
children: make(map[string]*fieldInfo),
|
||||
@@ -307,7 +323,7 @@ func toLowerCaseInterface(v any, info *fieldInfo) any {
|
||||
case map[string]any:
|
||||
return toLowerCaseKeyMap(vv, info)
|
||||
case []any:
|
||||
var arr []any
|
||||
arr := make([]any, 0, len(vv))
|
||||
for _, vvv := range vv {
|
||||
arr = append(arr, toLowerCaseInterface(vvv, info))
|
||||
}
|
||||
@@ -359,5 +375,5 @@ func getFullName(parent, child string) string {
|
||||
return child
|
||||
}
|
||||
|
||||
return strings.Join([]string{parent, child}, ".")
|
||||
return parent + "." + child
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package conf
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
@@ -40,9 +41,8 @@ func TestConfigJson(t *testing.T) {
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test, func(t *testing.T) {
|
||||
tmpfile, err := createTempFile(test, text)
|
||||
tmpfile, err := createTempFile(t, test, text)
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(tmpfile)
|
||||
|
||||
var val struct {
|
||||
A string `json:"a"`
|
||||
@@ -75,6 +75,160 @@ func TestLoadFromJsonBytesArray(t *testing.T) {
|
||||
assert.EqualValues(t, []string{"foo", "bar"}, expect)
|
||||
}
|
||||
|
||||
func TestConfigJson5(t *testing.T) {
|
||||
// JSON5 with comments, trailing commas, and unquoted keys
|
||||
text := `{
|
||||
// This is a comment
|
||||
a: 'foo', // single quotes
|
||||
b: 1,
|
||||
c: "${FOO}",
|
||||
d: "abcd!@#$112", // trailing comma
|
||||
}`
|
||||
t.Setenv("FOO", "2")
|
||||
|
||||
tmpfile, err := createTempFile(t, ".json5", text)
|
||||
assert.Nil(t, err)
|
||||
|
||||
var val struct {
|
||||
A string `json:"a"`
|
||||
B int `json:"b"`
|
||||
C string `json:"c"`
|
||||
D string `json:"d"`
|
||||
}
|
||||
MustLoad(tmpfile, &val)
|
||||
assert.Equal(t, "foo", val.A)
|
||||
assert.Equal(t, 1, val.B)
|
||||
assert.Equal(t, "${FOO}", val.C)
|
||||
assert.Equal(t, "abcd!@#$112", val.D)
|
||||
}
|
||||
|
||||
func TestConfigJsonStandardParser(t *testing.T) {
|
||||
// Standard JSON uses standard JSON parser (not JSON5) for backward compatibility
|
||||
text := `{
|
||||
"a": "foo",
|
||||
"b": 1,
|
||||
"c": "${FOO}",
|
||||
"d": "abcd!@#$112"
|
||||
}`
|
||||
t.Setenv("FOO", "2")
|
||||
|
||||
tmpfile, err := createTempFile(t, ".json", text)
|
||||
assert.Nil(t, err)
|
||||
|
||||
var val struct {
|
||||
A string `json:"a"`
|
||||
B int `json:"b"`
|
||||
C string `json:"c"`
|
||||
D string `json:"d"`
|
||||
}
|
||||
MustLoad(tmpfile, &val)
|
||||
assert.Equal(t, "foo", val.A)
|
||||
assert.Equal(t, 1, val.B)
|
||||
assert.Equal(t, "${FOO}", val.C)
|
||||
assert.Equal(t, "abcd!@#$112", val.D)
|
||||
}
|
||||
|
||||
func TestConfigJsonLargeIntegers(t *testing.T) {
|
||||
// Test that .json files preserve large integer precision (backward compatibility)
|
||||
text := `{
|
||||
"id": 1234567890123456789,
|
||||
"timestamp": 9223372036854775807
|
||||
}`
|
||||
|
||||
tmpfile, err := createTempFile(t, ".json", text)
|
||||
assert.Nil(t, err)
|
||||
|
||||
var val struct {
|
||||
ID int64 `json:"id"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
MustLoad(tmpfile, &val)
|
||||
assert.Equal(t, int64(1234567890123456789), val.ID)
|
||||
assert.Equal(t, int64(9223372036854775807), val.Timestamp)
|
||||
}
|
||||
|
||||
func TestConfigJson5Env(t *testing.T) {
|
||||
text := `{
|
||||
// Comment with env variable
|
||||
a: "foo",
|
||||
b: 1,
|
||||
c: "${FOO}",
|
||||
d: "abcd!@#$a12 3",
|
||||
}`
|
||||
t.Setenv("FOO", "2")
|
||||
|
||||
tmpfile, err := createTempFile(t, ".json5", text)
|
||||
assert.Nil(t, err)
|
||||
|
||||
var val struct {
|
||||
A string `json:"a"`
|
||||
B int `json:"b"`
|
||||
C string `json:"c"`
|
||||
D string `json:"d"`
|
||||
}
|
||||
MustLoad(tmpfile, &val, UseEnv())
|
||||
assert.Equal(t, "foo", val.A)
|
||||
assert.Equal(t, 1, val.B)
|
||||
assert.Equal(t, "2", val.C)
|
||||
assert.Equal(t, "abcd!@# 3", val.D)
|
||||
}
|
||||
|
||||
func TestLoadFromJson5Bytes(t *testing.T) {
|
||||
// Test JSON5 features: comments, trailing commas, single quotes, unquoted keys
|
||||
input := []byte(`{
|
||||
// This is a comment
|
||||
users: [
|
||||
{name: 'foo'}, // trailing comma
|
||||
{Name: "bar"},
|
||||
],
|
||||
}`)
|
||||
var val struct {
|
||||
Users []struct {
|
||||
Name string
|
||||
}
|
||||
}
|
||||
|
||||
assert.NoError(t, LoadFromJson5Bytes(input, &val))
|
||||
var expect []string
|
||||
for _, user := range val.Users {
|
||||
expect = append(expect, user.Name)
|
||||
}
|
||||
assert.EqualValues(t, []string{"foo", "bar"}, expect)
|
||||
}
|
||||
|
||||
func TestLoadFromJson5BytesError(t *testing.T) {
|
||||
// Invalid JSON5 syntax
|
||||
input := []byte(`{a: foo}`) // unquoted string value (invalid)
|
||||
var val struct {
|
||||
A string
|
||||
}
|
||||
|
||||
assert.Error(t, LoadFromJson5Bytes(input, &val))
|
||||
}
|
||||
|
||||
func TestConfigJson5LargeIntegersLimitation(t *testing.T) {
|
||||
// Document that JSON5 has precision limitations for large integers (>2^53)
|
||||
// due to JavaScript number semantics. Users should use .json for configs with large IDs.
|
||||
text := `{
|
||||
// JSON5 converts numbers to float64, which loses precision for large integers
|
||||
id: 1234567890123456789
|
||||
}`
|
||||
|
||||
tmpfile, err := createTempFile(t, ".json5", text)
|
||||
assert.Nil(t, err)
|
||||
|
||||
var val struct {
|
||||
ID int64 `json:"id"`
|
||||
}
|
||||
|
||||
// This will load; depending on the JSON5 implementation, large integers may lose precision.
|
||||
// This test documents that behavior without requiring loss of precision as an invariant.
|
||||
err = Load(tmpfile, &val)
|
||||
assert.NoError(t, err)
|
||||
|
||||
t.Logf("loaded JSON5 large integer id=%d (original 1234567890123456789)", val.ID)
|
||||
}
|
||||
|
||||
func TestConfigToml(t *testing.T) {
|
||||
text := `a = "foo"
|
||||
b = 1
|
||||
@@ -82,9 +236,8 @@ c = "${FOO}"
|
||||
d = "abcd!@#$112"
|
||||
`
|
||||
t.Setenv("FOO", "2")
|
||||
tmpfile, err := createTempFile(".toml", text)
|
||||
tmpfile, err := createTempFile(t, ".toml", text)
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(tmpfile)
|
||||
|
||||
var val struct {
|
||||
A string `json:"a"`
|
||||
@@ -105,9 +258,8 @@ b = 1
|
||||
c = "FOO"
|
||||
d = "abcd"
|
||||
`
|
||||
tmpfile, err := createTempFile(".toml", text)
|
||||
tmpfile, err := createTempFile(t, ".toml", text)
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(tmpfile)
|
||||
|
||||
var val struct {
|
||||
A string `json:"a"`
|
||||
@@ -127,9 +279,8 @@ func TestConfigWithLower(t *testing.T) {
|
||||
text := `a = "foo"
|
||||
b = 1
|
||||
`
|
||||
tmpfile, err := createTempFile(".toml", text)
|
||||
tmpfile, err := createTempFile(t, ".toml", text)
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(tmpfile)
|
||||
|
||||
var val struct {
|
||||
A string `json:"a"`
|
||||
@@ -207,9 +358,8 @@ c = "${FOO}"
|
||||
d = "abcd!@#112"
|
||||
`
|
||||
t.Setenv("FOO", "2")
|
||||
tmpfile, err := createTempFile(".toml", text)
|
||||
tmpfile, err := createTempFile(t, ".toml", text)
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(tmpfile)
|
||||
|
||||
var val struct {
|
||||
A string `json:"a"`
|
||||
@@ -241,9 +391,8 @@ func TestConfigJsonEnv(t *testing.T) {
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test, func(t *testing.T) {
|
||||
tmpfile, err := createTempFile(test, text)
|
||||
tmpfile, err := createTempFile(t, test, text)
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(tmpfile)
|
||||
|
||||
var val struct {
|
||||
A string `json:"a"`
|
||||
@@ -1217,11 +1366,44 @@ Name = "bar"
|
||||
})
|
||||
}
|
||||
|
||||
func Test_LoadBadConfig(t *testing.T) {
|
||||
type Config struct {
|
||||
Name string `json:"name,options=foo|bar"`
|
||||
}
|
||||
|
||||
file, err := createTempFile(t, ".json", `{"name": "baz"}`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var c Config
|
||||
err = Load(file, &c)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func Test_getFullName(t *testing.T) {
|
||||
assert.Equal(t, "a.b", getFullName("a", "b"))
|
||||
assert.Equal(t, "a", getFullName("", "a"))
|
||||
}
|
||||
|
||||
func TestValidate(t *testing.T) {
|
||||
t.Run("normal config", func(t *testing.T) {
|
||||
var c mockConfig
|
||||
err := LoadFromJsonBytes([]byte(`{"val": "hello", "number": 8}`), &c)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("error no int", func(t *testing.T) {
|
||||
var c mockConfig
|
||||
err := LoadFromJsonBytes([]byte(`{"val": "hello"}`), &c)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("error no string", func(t *testing.T) {
|
||||
var c mockConfig
|
||||
err := LoadFromJsonBytes([]byte(`{"number": 8}`), &c)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_buildFieldsInfo(t *testing.T) {
|
||||
type ParentSt struct {
|
||||
Name string
|
||||
@@ -1311,13 +1493,13 @@ func Test_buildFieldsInfo(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func createTempFile(ext, text string) (string, error) {
|
||||
func createTempFile(t *testing.T, ext, text string) (string, error) {
|
||||
tmpFile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := os.WriteFile(tmpFile.Name(), []byte(text), os.ModeTemporary); err != nil {
|
||||
if err = os.WriteFile(tmpFile.Name(), []byte(text), os.ModeTemporary); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -1326,5 +1508,265 @@ func createTempFile(ext, text string) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = os.Remove(filename)
|
||||
})
|
||||
|
||||
return filename, nil
|
||||
}
|
||||
|
||||
type mockConfig struct {
|
||||
Val string
|
||||
Number int
|
||||
}
|
||||
|
||||
func (m mockConfig) Validate() error {
|
||||
if len(m.Val) == 0 {
|
||||
return errors.New("val is empty")
|
||||
}
|
||||
|
||||
if m.Number == 0 {
|
||||
return errors.New("number is zero")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestGetFullName(t *testing.T) {
|
||||
tests := []struct {
|
||||
parent string
|
||||
child string
|
||||
want string
|
||||
}{
|
||||
{"", "child", "child"},
|
||||
{"parent", "child", "parent.child"},
|
||||
{"a.b", "c", "a.b.c"},
|
||||
{"root", "nested.field", "root.nested.field"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.parent+"."+tt.child, func(t *testing.T) {
|
||||
got := getFullName(tt.parent, tt.child)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// validatorConfig is a test config that implements Validate() for testing validation behavior
|
||||
type validatorConfig struct {
|
||||
Value int `json:"value"`
|
||||
}
|
||||
|
||||
func (v *validatorConfig) Validate() error {
|
||||
if v.Value < 10 {
|
||||
return errors.New("value must be >= 10")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestLoadValidation_WithoutEnv tests that validation is called correctly in normal loading path
|
||||
func TestLoadValidation_WithoutEnv(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
extension string
|
||||
content string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "json valid value",
|
||||
extension: ".json",
|
||||
content: `{"value": 15}`,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "json invalid value",
|
||||
extension: ".json",
|
||||
content: `{"value": 5}`,
|
||||
wantErr: true,
|
||||
errMsg: "value must be >= 10",
|
||||
},
|
||||
{
|
||||
name: "yaml valid value",
|
||||
extension: ".yaml",
|
||||
content: "value: 20\n",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "yaml invalid value",
|
||||
extension: ".yaml",
|
||||
content: "value: 3\n",
|
||||
wantErr: true,
|
||||
errMsg: "value must be >= 10",
|
||||
},
|
||||
{
|
||||
name: "toml valid value",
|
||||
extension: ".toml",
|
||||
content: "value = 100\n",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "toml invalid value",
|
||||
extension: ".toml",
|
||||
content: "value = 1\n",
|
||||
wantErr: true,
|
||||
errMsg: "value must be >= 10",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpfile, err := createTempFile(t, tt.extension, tt.content)
|
||||
assert.Nil(t, err)
|
||||
|
||||
var cfg validatorConfig
|
||||
err = Load(tmpfile, &cfg)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadValidation_WithEnv tests that validation is called correctly with UseEnv() option
|
||||
func TestLoadValidation_WithEnv(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
extension string
|
||||
content string
|
||||
envValue string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "json valid value with env",
|
||||
extension: ".json",
|
||||
content: `{"value": ${TEST_VALUE}}`,
|
||||
envValue: "25",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "json invalid value with env",
|
||||
extension: ".json",
|
||||
content: `{"value": ${TEST_VALUE}}`,
|
||||
envValue: "7",
|
||||
wantErr: true,
|
||||
errMsg: "value must be >= 10",
|
||||
},
|
||||
{
|
||||
name: "yaml valid value with env",
|
||||
extension: ".yaml",
|
||||
content: "value: ${TEST_VALUE}\n",
|
||||
envValue: "50",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "yaml invalid value with env",
|
||||
extension: ".yaml",
|
||||
content: "value: ${TEST_VALUE}\n",
|
||||
envValue: "2",
|
||||
wantErr: true,
|
||||
errMsg: "value must be >= 10",
|
||||
},
|
||||
{
|
||||
name: "toml valid value with env",
|
||||
extension: ".toml",
|
||||
content: "value = ${TEST_VALUE}\n",
|
||||
envValue: "99",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "toml invalid value with env",
|
||||
extension: ".toml",
|
||||
content: "value = ${TEST_VALUE}\n",
|
||||
envValue: "8",
|
||||
wantErr: true,
|
||||
errMsg: "value must be >= 10",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Setenv("TEST_VALUE", tt.envValue)
|
||||
|
||||
tmpfile, err := createTempFile(t, tt.extension, tt.content)
|
||||
assert.Nil(t, err)
|
||||
|
||||
var cfg validatorConfig
|
||||
err = Load(tmpfile, &cfg, UseEnv())
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadValidation_Consistency verifies validation behavior is consistent between paths
|
||||
func TestLoadValidation_Consistency(t *testing.T) {
|
||||
// Test that both paths (with and without UseEnv) produce the same validation results
|
||||
const validValue = 15
|
||||
|
||||
formats := []struct {
|
||||
ext string
|
||||
invalid string
|
||||
valid string
|
||||
}{
|
||||
{".json", `{"value": 5}`, `{"value": 15}`},
|
||||
{".yaml", "value: 5\n", "value: 15\n"},
|
||||
{".toml", "value = 5\n", "value = 15\n"},
|
||||
}
|
||||
|
||||
for _, format := range formats {
|
||||
t.Run("invalid_"+format.ext, func(t *testing.T) {
|
||||
// Test without UseEnv()
|
||||
tmpfile1, err := createTempFile(t, format.ext, format.invalid)
|
||||
assert.Nil(t, err)
|
||||
|
||||
var cfg1 validatorConfig
|
||||
err1 := Load(tmpfile1, &cfg1)
|
||||
|
||||
// Test with UseEnv()
|
||||
tmpfile2, err := createTempFile(t, format.ext, format.invalid)
|
||||
assert.Nil(t, err)
|
||||
|
||||
var cfg2 validatorConfig
|
||||
err2 := Load(tmpfile2, &cfg2, UseEnv())
|
||||
|
||||
// Both should fail validation
|
||||
assert.Error(t, err1, "validation should fail without UseEnv()")
|
||||
assert.Error(t, err2, "validation should fail with UseEnv()")
|
||||
assert.Contains(t, err1.Error(), "value must be >= 10")
|
||||
assert.Contains(t, err2.Error(), "value must be >= 10")
|
||||
})
|
||||
|
||||
t.Run("valid_"+format.ext, func(t *testing.T) {
|
||||
// Test without UseEnv()
|
||||
tmpfile1, err := createTempFile(t, format.ext, format.valid)
|
||||
assert.Nil(t, err)
|
||||
|
||||
var cfg1 validatorConfig
|
||||
err1 := Load(tmpfile1, &cfg1)
|
||||
|
||||
// Test with UseEnv()
|
||||
tmpfile2, err := createTempFile(t, format.ext, format.valid)
|
||||
assert.Nil(t, err)
|
||||
|
||||
var cfg2 validatorConfig
|
||||
err2 := Load(tmpfile2, &cfg2, UseEnv())
|
||||
|
||||
// Both should pass validation
|
||||
assert.NoError(t, err1, "validation should pass without UseEnv()")
|
||||
assert.NoError(t, err2, "validation should pass with UseEnv()")
|
||||
assert.Equal(t, validValue, cfg1.Value)
|
||||
assert.Equal(t, validValue, cfg2.Value)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ func LoadProperties(filename string, opts ...Option) (Properties, error) {
|
||||
|
||||
raw := make(map[string]string)
|
||||
for i := range lines {
|
||||
pair := strings.Split(lines[i], "=")
|
||||
pair := strings.SplitN(lines[i], "=", 2)
|
||||
if len(pair) != 2 {
|
||||
// invalid property format
|
||||
return nil, &PropertyError{
|
||||
|
||||
@@ -92,3 +92,70 @@ func TestLoadBadFile(t *testing.T) {
|
||||
_, err := LoadProperties("nosuchfile")
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestProperties_valueWithEqualSymbols(t *testing.T) {
|
||||
text := `# test with equal symbols in value
|
||||
db.url=postgres://localhost:5432/db?param=value
|
||||
math.equation=a=b=c
|
||||
base64.data=SGVsbG8=World=Test=
|
||||
url.with.params=http://example.com?foo=bar&baz=qux
|
||||
empty.value=
|
||||
key.with.space = value = with = equals`
|
||||
tmpfile, err := fs.TempFilenameWithText(text)
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(tmpfile)
|
||||
|
||||
props, err := LoadProperties(tmpfile)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "postgres://localhost:5432/db?param=value", props.GetString("db.url"))
|
||||
assert.Equal(t, "a=b=c", props.GetString("math.equation"))
|
||||
assert.Equal(t, "SGVsbG8=World=Test=", props.GetString("base64.data"))
|
||||
assert.Equal(t, "http://example.com?foo=bar&baz=qux", props.GetString("url.with.params"))
|
||||
assert.Equal(t, "", props.GetString("empty.value"))
|
||||
assert.Equal(t, "value = with = equals", props.GetString("key.with.space"))
|
||||
}
|
||||
|
||||
func TestProperties_edgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "no equal sign",
|
||||
content: "invalid line without equal",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "only equal sign",
|
||||
content: "=",
|
||||
wantErr: false, // "=" 会被解析为空 key 和空 value,len(pair) == 2,是合法的
|
||||
},
|
||||
{
|
||||
name: "empty key",
|
||||
content: "=value",
|
||||
wantErr: false, // 空 key 也会被 trim,但 len(pair) == 2 所以不会报错
|
||||
},
|
||||
{
|
||||
name: "equal at end",
|
||||
content: "key.name=",
|
||||
wantErr: false, // 空 value 是合法的
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpfile, err := fs.TempFilenameWithText(tt.content)
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(tmpfile)
|
||||
|
||||
_, err = LoadProperties(tmpfile)
|
||||
if tt.wantErr {
|
||||
assert.NotNil(t, err, "expected error for case: %s", tt.name)
|
||||
} else {
|
||||
assert.Nil(t, err, "unexpected error for case: %s", tt.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
12
core/conf/validate.go
Normal file
12
core/conf/validate.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package conf
|
||||
|
||||
import "github.com/zeromicro/go-zero/core/validation"
|
||||
|
||||
// validate validates the value if it implements the Validator interface.
|
||||
func validate(v any) error {
|
||||
if val, ok := v.(validation.Validator); ok {
|
||||
return val.Validate()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
81
core/conf/validate_test.go
Normal file
81
core/conf/validate_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package conf
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockType int
|
||||
|
||||
func (m mockType) Validate() error {
|
||||
if m < 10 {
|
||||
return errors.New("invalid value")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type anotherMockType int
|
||||
|
||||
func Test_validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
v any
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "invalid",
|
||||
v: mockType(5),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
v: mockType(10),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "not validator",
|
||||
v: anotherMockType(5),
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validate(tt.v)
|
||||
assert.Equal(t, tt.wantErr, err != nil)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockVal struct {
|
||||
}
|
||||
|
||||
func (m mockVal) Validate() error {
|
||||
return errors.New("invalid value")
|
||||
}
|
||||
|
||||
func Test_validateValPtr(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
v any
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "invalid",
|
||||
v: mockVal{},
|
||||
},
|
||||
{
|
||||
name: "invalid value",
|
||||
v: &mockVal{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Error(t, validate(tt.v))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,9 @@
|
||||
package subscriber
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/discov"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
@@ -37,6 +40,7 @@ func NewEtcdSubscriber(conf EtcdConf) (Subscriber, error) {
|
||||
func buildSubOptions(conf EtcdConf) []discov.SubOption {
|
||||
opts := []discov.SubOption{
|
||||
discov.WithExactMatch(),
|
||||
discov.WithContainer(newContainer()),
|
||||
}
|
||||
|
||||
if len(conf.User) > 0 {
|
||||
@@ -65,3 +69,47 @@ func (s *etcdSubscriber) Value() (string, error) {
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
type container struct {
|
||||
value atomic.Value
|
||||
listeners []func()
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func newContainer() *container {
|
||||
return &container{}
|
||||
}
|
||||
|
||||
func (c *container) OnAdd(kv discov.KV) {
|
||||
c.value.Store([]string{kv.Val})
|
||||
c.notifyChange()
|
||||
}
|
||||
|
||||
func (c *container) OnDelete(_ discov.KV) {
|
||||
c.value.Store([]string(nil))
|
||||
c.notifyChange()
|
||||
}
|
||||
|
||||
func (c *container) AddListener(listener func()) {
|
||||
c.lock.Lock()
|
||||
c.listeners = append(c.listeners, listener)
|
||||
c.lock.Unlock()
|
||||
}
|
||||
|
||||
func (c *container) GetValues() []string {
|
||||
if vals, ok := c.value.Load().([]string); ok {
|
||||
return vals
|
||||
}
|
||||
|
||||
return []string(nil)
|
||||
}
|
||||
|
||||
func (c *container) notifyChange() {
|
||||
c.lock.Lock()
|
||||
listeners := append(([]func())(nil), c.listeners...)
|
||||
c.lock.Unlock()
|
||||
|
||||
for _, listener := range listeners {
|
||||
listener()
|
||||
}
|
||||
}
|
||||
|
||||
186
core/configcenter/subscriber/etcd_test.go
Normal file
186
core/configcenter/subscriber/etcd_test.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package subscriber
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/discov"
|
||||
)
|
||||
|
||||
const (
|
||||
actionAdd = iota
|
||||
actionDel
|
||||
)
|
||||
|
||||
func TestConfigCenterContainer(t *testing.T) {
|
||||
type action struct {
|
||||
act int
|
||||
key string
|
||||
val string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
do []action
|
||||
expect []string
|
||||
}{
|
||||
{
|
||||
name: "add one",
|
||||
do: []action{
|
||||
{
|
||||
act: actionAdd,
|
||||
key: "first",
|
||||
val: "a",
|
||||
},
|
||||
},
|
||||
expect: []string{
|
||||
"a",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "add two",
|
||||
do: []action{
|
||||
{
|
||||
act: actionAdd,
|
||||
key: "first",
|
||||
val: "a",
|
||||
},
|
||||
{
|
||||
act: actionAdd,
|
||||
key: "second",
|
||||
val: "b",
|
||||
},
|
||||
},
|
||||
expect: []string{
|
||||
"b",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "add two, delete one",
|
||||
do: []action{
|
||||
{
|
||||
act: actionAdd,
|
||||
key: "first",
|
||||
val: "a",
|
||||
},
|
||||
{
|
||||
act: actionAdd,
|
||||
key: "second",
|
||||
val: "b",
|
||||
},
|
||||
{
|
||||
act: actionDel,
|
||||
key: "first",
|
||||
},
|
||||
},
|
||||
expect: []string(nil),
|
||||
},
|
||||
{
|
||||
name: "add two, delete two",
|
||||
do: []action{
|
||||
{
|
||||
act: actionAdd,
|
||||
key: "first",
|
||||
val: "a",
|
||||
},
|
||||
{
|
||||
act: actionAdd,
|
||||
key: "second",
|
||||
val: "b",
|
||||
},
|
||||
{
|
||||
act: actionDel,
|
||||
key: "first",
|
||||
},
|
||||
{
|
||||
act: actionDel,
|
||||
key: "second",
|
||||
},
|
||||
},
|
||||
expect: []string(nil),
|
||||
},
|
||||
{
|
||||
name: "add two, dup values",
|
||||
do: []action{
|
||||
{
|
||||
act: actionAdd,
|
||||
key: "first",
|
||||
val: "a",
|
||||
},
|
||||
{
|
||||
act: actionAdd,
|
||||
key: "second",
|
||||
val: "b",
|
||||
},
|
||||
{
|
||||
act: actionAdd,
|
||||
key: "third",
|
||||
val: "a",
|
||||
},
|
||||
},
|
||||
expect: []string{"a"},
|
||||
},
|
||||
{
|
||||
name: "add three, dup values, delete two, add one",
|
||||
do: []action{
|
||||
{
|
||||
act: actionAdd,
|
||||
key: "first",
|
||||
val: "a",
|
||||
},
|
||||
{
|
||||
act: actionAdd,
|
||||
key: "second",
|
||||
val: "b",
|
||||
},
|
||||
{
|
||||
act: actionAdd,
|
||||
key: "third",
|
||||
val: "a",
|
||||
},
|
||||
{
|
||||
act: actionDel,
|
||||
key: "first",
|
||||
},
|
||||
{
|
||||
act: actionDel,
|
||||
key: "second",
|
||||
},
|
||||
{
|
||||
act: actionAdd,
|
||||
key: "forth",
|
||||
val: "c",
|
||||
},
|
||||
},
|
||||
expect: []string{"c"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var changed bool
|
||||
c := newContainer()
|
||||
c.AddListener(func() {
|
||||
changed = true
|
||||
})
|
||||
assert.Nil(t, c.GetValues())
|
||||
assert.False(t, changed)
|
||||
|
||||
for _, order := range test.do {
|
||||
if order.act == actionAdd {
|
||||
c.OnAdd(discov.KV{
|
||||
Key: order.key,
|
||||
Val: order.val,
|
||||
})
|
||||
} else {
|
||||
c.OnDelete(discov.KV{
|
||||
Key: order.key,
|
||||
Val: order.val,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, changed)
|
||||
assert.ElementsMatch(t, test.expect, c.GetValues())
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,10 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: etcdclient.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -package internal -destination etcdclient_mock.go -source etcdclient.go EtcdClient
|
||||
//
|
||||
|
||||
// Package internal is a generated GoMock package.
|
||||
package internal
|
||||
@@ -8,35 +13,36 @@ import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
grpc "google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// MockEtcdClient is a mock of EtcdClient interface
|
||||
// MockEtcdClient is a mock of EtcdClient interface.
|
||||
type MockEtcdClient struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockEtcdClientMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockEtcdClientMockRecorder is the mock recorder for MockEtcdClient
|
||||
// MockEtcdClientMockRecorder is the mock recorder for MockEtcdClient.
|
||||
type MockEtcdClientMockRecorder struct {
|
||||
mock *MockEtcdClient
|
||||
}
|
||||
|
||||
// NewMockEtcdClient creates a new mock instance
|
||||
// NewMockEtcdClient creates a new mock instance.
|
||||
func NewMockEtcdClient(ctrl *gomock.Controller) *MockEtcdClient {
|
||||
mock := &MockEtcdClient{ctrl: ctrl}
|
||||
mock.recorder = &MockEtcdClientMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockEtcdClient) EXPECT() *MockEtcdClientMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// ActiveConnection mocks base method
|
||||
// ActiveConnection mocks base method.
|
||||
func (m *MockEtcdClient) ActiveConnection() *grpc.ClientConn {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ActiveConnection")
|
||||
@@ -44,13 +50,13 @@ func (m *MockEtcdClient) ActiveConnection() *grpc.ClientConn {
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ActiveConnection indicates an expected call of ActiveConnection
|
||||
// ActiveConnection indicates an expected call of ActiveConnection.
|
||||
func (mr *MockEtcdClientMockRecorder) ActiveConnection() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActiveConnection", reflect.TypeOf((*MockEtcdClient)(nil).ActiveConnection))
|
||||
}
|
||||
|
||||
// Close mocks base method
|
||||
// Close mocks base method.
|
||||
func (m *MockEtcdClient) Close() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Close")
|
||||
@@ -58,13 +64,13 @@ func (m *MockEtcdClient) Close() error {
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close
|
||||
// Close indicates an expected call of Close.
|
||||
func (mr *MockEtcdClientMockRecorder) Close() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockEtcdClient)(nil).Close))
|
||||
}
|
||||
|
||||
// Ctx mocks base method
|
||||
// Ctx mocks base method.
|
||||
func (m *MockEtcdClient) Ctx() context.Context {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Ctx")
|
||||
@@ -72,13 +78,13 @@ func (m *MockEtcdClient) Ctx() context.Context {
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Ctx indicates an expected call of Ctx
|
||||
// Ctx indicates an expected call of Ctx.
|
||||
func (mr *MockEtcdClientMockRecorder) Ctx() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ctx", reflect.TypeOf((*MockEtcdClient)(nil).Ctx))
|
||||
}
|
||||
|
||||
// Get mocks base method
|
||||
// Get mocks base method.
|
||||
func (m *MockEtcdClient) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, key}
|
||||
@@ -91,14 +97,14 @@ func (m *MockEtcdClient) Get(ctx context.Context, key string, opts ...clientv3.O
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Get indicates an expected call of Get
|
||||
// Get indicates an expected call of Get.
|
||||
func (mr *MockEtcdClientMockRecorder) Get(ctx, key any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, key}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockEtcdClient)(nil).Get), varargs...)
|
||||
}
|
||||
|
||||
// Grant mocks base method
|
||||
// Grant mocks base method.
|
||||
func (m *MockEtcdClient) Grant(ctx context.Context, ttl int64) (*clientv3.LeaseGrantResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Grant", ctx, ttl)
|
||||
@@ -107,13 +113,13 @@ func (m *MockEtcdClient) Grant(ctx context.Context, ttl int64) (*clientv3.LeaseG
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Grant indicates an expected call of Grant
|
||||
// Grant indicates an expected call of Grant.
|
||||
func (mr *MockEtcdClientMockRecorder) Grant(ctx, ttl any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Grant", reflect.TypeOf((*MockEtcdClient)(nil).Grant), ctx, ttl)
|
||||
}
|
||||
|
||||
// KeepAlive mocks base method
|
||||
// KeepAlive mocks base method.
|
||||
func (m *MockEtcdClient) KeepAlive(ctx context.Context, id clientv3.LeaseID) (<-chan *clientv3.LeaseKeepAliveResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "KeepAlive", ctx, id)
|
||||
@@ -122,13 +128,13 @@ func (m *MockEtcdClient) KeepAlive(ctx context.Context, id clientv3.LeaseID) (<-
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// KeepAlive indicates an expected call of KeepAlive
|
||||
// KeepAlive indicates an expected call of KeepAlive.
|
||||
func (mr *MockEtcdClientMockRecorder) KeepAlive(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeepAlive", reflect.TypeOf((*MockEtcdClient)(nil).KeepAlive), ctx, id)
|
||||
}
|
||||
|
||||
// Put mocks base method
|
||||
// Put mocks base method.
|
||||
func (m *MockEtcdClient) Put(ctx context.Context, key, val string, opts ...clientv3.OpOption) (*clientv3.PutResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, key, val}
|
||||
@@ -141,14 +147,14 @@ func (m *MockEtcdClient) Put(ctx context.Context, key, val string, opts ...clien
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Put indicates an expected call of Put
|
||||
// Put indicates an expected call of Put.
|
||||
func (mr *MockEtcdClientMockRecorder) Put(ctx, key, val any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, key, val}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockEtcdClient)(nil).Put), varargs...)
|
||||
}
|
||||
|
||||
// Revoke mocks base method
|
||||
// Revoke mocks base method.
|
||||
func (m *MockEtcdClient) Revoke(ctx context.Context, id clientv3.LeaseID) (*clientv3.LeaseRevokeResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Revoke", ctx, id)
|
||||
@@ -157,13 +163,13 @@ func (m *MockEtcdClient) Revoke(ctx context.Context, id clientv3.LeaseID) (*clie
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Revoke indicates an expected call of Revoke
|
||||
// Revoke indicates an expected call of Revoke.
|
||||
func (mr *MockEtcdClientMockRecorder) Revoke(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Revoke", reflect.TypeOf((*MockEtcdClient)(nil).Revoke), ctx, id)
|
||||
}
|
||||
|
||||
// Watch mocks base method
|
||||
// Watch mocks base method.
|
||||
func (m *MockEtcdClient) Watch(ctx context.Context, key string, opts ...clientv3.OpOption) clientv3.WatchChan {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, key}
|
||||
@@ -175,7 +181,7 @@ func (m *MockEtcdClient) Watch(ctx context.Context, key string, opts ...clientv3
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Watch indicates an expected call of Watch
|
||||
// Watch indicates an expected call of Watch.
|
||||
func (mr *MockEtcdClientMockRecorder) Watch(ctx, key any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, key}, opts...)
|
||||
|
||||
@@ -10,22 +10,24 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.etcd.io/etcd/api/v3/v3rpc/rpctypes"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/contextx"
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/logc"
|
||||
"github.com/zeromicro/go-zero/core/mathx"
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
"github.com/zeromicro/go-zero/core/threading"
|
||||
"go.etcd.io/etcd/api/v3/v3rpc/rpctypes"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
)
|
||||
|
||||
const coolDownDeviation = 0.05
|
||||
|
||||
var (
|
||||
registry = Registry{
|
||||
clusters: make(map[string]*cluster),
|
||||
}
|
||||
connManager = syncx.NewResourceManager()
|
||||
errClosed = errors.New("etcd monitor chan has been closed")
|
||||
connManager = syncx.NewResourceManager()
|
||||
coolDownUnstable = mathx.NewUnstable(coolDownDeviation)
|
||||
errClosed = errors.New("etcd monitor chan has been closed")
|
||||
)
|
||||
|
||||
// A Registry is a registry that manages the etcd client connections.
|
||||
@@ -41,33 +43,92 @@ func GetRegistry() *Registry {
|
||||
|
||||
// GetConn returns an etcd client connection associated with given endpoints.
|
||||
func (r *Registry) GetConn(endpoints []string) (EtcdClient, error) {
|
||||
c, _ := r.getCluster(endpoints)
|
||||
c, _ := r.getOrCreateCluster(endpoints)
|
||||
return c.getClient()
|
||||
}
|
||||
|
||||
// Monitor monitors the key on given etcd endpoints, notify with the given UpdateListener.
|
||||
func (r *Registry) Monitor(endpoints []string, key string, l UpdateListener, exactMatch bool) error {
|
||||
c, exists := r.getCluster(endpoints)
|
||||
func (r *Registry) Monitor(endpoints []string, key string, exactMatch bool, l UpdateListener) error {
|
||||
wkey := watchKey{
|
||||
key: key,
|
||||
exactMatch: exactMatch,
|
||||
}
|
||||
|
||||
c, exists := r.getOrCreateCluster(endpoints)
|
||||
// if exists, the existing values should be updated to the listener.
|
||||
if exists {
|
||||
kvs := c.getCurrent(key)
|
||||
for _, kv := range kvs {
|
||||
l.OnAdd(kv)
|
||||
c.lock.Lock()
|
||||
watcher, ok := c.watchers[wkey]
|
||||
if ok {
|
||||
watcher.listeners = append(watcher.listeners, l)
|
||||
}
|
||||
c.lock.Unlock()
|
||||
|
||||
if ok {
|
||||
kvs := c.getCurrent(wkey)
|
||||
for _, kv := range kvs {
|
||||
l.OnAdd(kv)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return c.monitor(key, l, exactMatch)
|
||||
return c.monitor(wkey, l)
|
||||
}
|
||||
|
||||
func (r *Registry) getCluster(endpoints []string) (c *cluster, exists bool) {
|
||||
func (r *Registry) Unmonitor(endpoints []string, key string, exactMatch bool, l UpdateListener) {
|
||||
c, exists := r.getCluster(endpoints)
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
wkey := watchKey{
|
||||
key: key,
|
||||
exactMatch: exactMatch,
|
||||
}
|
||||
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
watcher, ok := c.watchers[wkey]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
for i, listener := range watcher.listeners {
|
||||
if listener == l {
|
||||
watcher.listeners = append(watcher.listeners[:i], watcher.listeners[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(watcher.listeners) == 0 {
|
||||
if watcher.cancel != nil {
|
||||
watcher.cancel()
|
||||
}
|
||||
delete(c.watchers, wkey)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Registry) getCluster(endpoints []string) (*cluster, bool) {
|
||||
clusterKey := getClusterKey(endpoints)
|
||||
|
||||
r.lock.RLock()
|
||||
c, exists = r.clusters[clusterKey]
|
||||
c, ok := r.clusters[clusterKey]
|
||||
r.lock.RUnlock()
|
||||
|
||||
return c, ok
|
||||
}
|
||||
|
||||
func (r *Registry) getOrCreateCluster(endpoints []string) (c *cluster, exists bool) {
|
||||
c, exists = r.getCluster(endpoints)
|
||||
if !exists {
|
||||
clusterKey := getClusterKey(endpoints)
|
||||
|
||||
r.lock.Lock()
|
||||
defer r.lock.Unlock()
|
||||
|
||||
// double-check locking
|
||||
c, exists = r.clusters[clusterKey]
|
||||
if !exists {
|
||||
@@ -79,30 +140,51 @@ func (r *Registry) getCluster(endpoints []string) (c *cluster, exists bool) {
|
||||
return
|
||||
}
|
||||
|
||||
type cluster struct {
|
||||
endpoints []string
|
||||
key string
|
||||
values map[string]map[string]string
|
||||
listeners map[string][]UpdateListener
|
||||
watchGroup *threading.RoutineGroup
|
||||
done chan lang.PlaceholderType
|
||||
lock sync.RWMutex
|
||||
exactMatch bool
|
||||
}
|
||||
type (
|
||||
watchKey struct {
|
||||
key string
|
||||
exactMatch bool
|
||||
}
|
||||
|
||||
watchValue struct {
|
||||
listeners []UpdateListener
|
||||
values map[string]string
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
cluster struct {
|
||||
endpoints []string
|
||||
key string
|
||||
watchers map[watchKey]*watchValue
|
||||
watchGroup *threading.RoutineGroup
|
||||
done chan lang.PlaceholderType
|
||||
lock sync.RWMutex
|
||||
}
|
||||
)
|
||||
|
||||
func newCluster(endpoints []string) *cluster {
|
||||
return &cluster{
|
||||
endpoints: endpoints,
|
||||
key: getClusterKey(endpoints),
|
||||
values: make(map[string]map[string]string),
|
||||
listeners: make(map[string][]UpdateListener),
|
||||
watchers: make(map[watchKey]*watchValue),
|
||||
watchGroup: threading.NewRoutineGroup(),
|
||||
done: make(chan lang.PlaceholderType),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cluster) context(cli EtcdClient) context.Context {
|
||||
return contextx.ValueOnlyFrom(cli.Ctx())
|
||||
func (c *cluster) addListener(key watchKey, l UpdateListener) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
watcher, ok := c.watchers[key]
|
||||
if ok {
|
||||
watcher.listeners = append(watcher.listeners, l)
|
||||
return
|
||||
}
|
||||
|
||||
val := newWatchValue()
|
||||
val.listeners = []UpdateListener{l}
|
||||
c.watchers[key] = val
|
||||
}
|
||||
|
||||
func (c *cluster) getClient() (EtcdClient, error) {
|
||||
@@ -116,12 +198,17 @@ func (c *cluster) getClient() (EtcdClient, error) {
|
||||
return val.(EtcdClient), nil
|
||||
}
|
||||
|
||||
func (c *cluster) getCurrent(key string) []KV {
|
||||
func (c *cluster) getCurrent(key watchKey) []KV {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
|
||||
var kvs []KV
|
||||
for k, v := range c.values[key] {
|
||||
watcher, ok := c.watchers[key]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
kvs := make([]KV, 0, len(watcher.values))
|
||||
for k, v := range watcher.values {
|
||||
kvs = append(kvs, KV{
|
||||
Key: k,
|
||||
Val: v,
|
||||
@@ -131,43 +218,23 @@ func (c *cluster) getCurrent(key string) []KV {
|
||||
return kvs
|
||||
}
|
||||
|
||||
func (c *cluster) handleChanges(key string, kvs []KV) {
|
||||
var add []KV
|
||||
var remove []KV
|
||||
|
||||
func (c *cluster) handleChanges(key watchKey, kvs []KV) {
|
||||
c.lock.Lock()
|
||||
listeners := append([]UpdateListener(nil), c.listeners[key]...)
|
||||
vals, ok := c.values[key]
|
||||
watcher, ok := c.watchers[key]
|
||||
if !ok {
|
||||
add = kvs
|
||||
vals = make(map[string]string)
|
||||
for _, kv := range kvs {
|
||||
vals[kv.Key] = kv.Val
|
||||
}
|
||||
c.values[key] = vals
|
||||
} else {
|
||||
m := make(map[string]string)
|
||||
for _, kv := range kvs {
|
||||
m[kv.Key] = kv.Val
|
||||
}
|
||||
for k, v := range vals {
|
||||
if val, ok := m[k]; !ok || v != val {
|
||||
remove = append(remove, KV{
|
||||
Key: k,
|
||||
Val: v,
|
||||
})
|
||||
}
|
||||
}
|
||||
for k, v := range m {
|
||||
if val, ok := vals[k]; !ok || v != val {
|
||||
add = append(add, KV{
|
||||
Key: k,
|
||||
Val: v,
|
||||
})
|
||||
}
|
||||
}
|
||||
c.values[key] = m
|
||||
c.lock.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
listeners := append([]UpdateListener(nil), watcher.listeners...)
|
||||
// watcher.values cannot be nil
|
||||
vals := watcher.values
|
||||
newVals := make(map[string]string, len(kvs)+len(vals))
|
||||
for _, kv := range kvs {
|
||||
newVals[kv.Key] = kv.Val
|
||||
}
|
||||
add, remove := calculateChanges(vals, newVals)
|
||||
watcher.values = newVals
|
||||
c.lock.Unlock()
|
||||
|
||||
for _, kv := range add {
|
||||
@@ -182,20 +249,22 @@ func (c *cluster) handleChanges(key string, kvs []KV) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cluster) handleWatchEvents(key string, events []*clientv3.Event) {
|
||||
func (c *cluster) handleWatchEvents(ctx context.Context, key watchKey, events []*clientv3.Event) {
|
||||
c.lock.RLock()
|
||||
listeners := append([]UpdateListener(nil), c.listeners[key]...)
|
||||
watcher, ok := c.watchers[key]
|
||||
if !ok {
|
||||
c.lock.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
listeners := append([]UpdateListener(nil), watcher.listeners...)
|
||||
c.lock.RUnlock()
|
||||
|
||||
for _, ev := range events {
|
||||
switch ev.Type {
|
||||
case clientv3.EventTypePut:
|
||||
c.lock.Lock()
|
||||
if vals, ok := c.values[key]; ok {
|
||||
vals[string(ev.Kv.Key)] = string(ev.Kv.Value)
|
||||
} else {
|
||||
c.values[key] = map[string]string{string(ev.Kv.Key): string(ev.Kv.Value)}
|
||||
}
|
||||
watcher.values[string(ev.Kv.Key)] = string(ev.Kv.Value)
|
||||
c.lock.Unlock()
|
||||
for _, l := range listeners {
|
||||
l.OnAdd(KV{
|
||||
@@ -205,9 +274,7 @@ func (c *cluster) handleWatchEvents(key string, events []*clientv3.Event) {
|
||||
}
|
||||
case clientv3.EventTypeDelete:
|
||||
c.lock.Lock()
|
||||
if vals, ok := c.values[key]; ok {
|
||||
delete(vals, string(ev.Kv.Key))
|
||||
}
|
||||
delete(watcher.values, string(ev.Kv.Key))
|
||||
c.lock.Unlock()
|
||||
for _, l := range listeners {
|
||||
l.OnDelete(KV{
|
||||
@@ -216,20 +283,20 @@ func (c *cluster) handleWatchEvents(key string, events []*clientv3.Event) {
|
||||
})
|
||||
}
|
||||
default:
|
||||
logx.Errorf("Unknown event type: %v", ev.Type)
|
||||
logc.Errorf(ctx, "Unknown event type: %v", ev.Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cluster) load(cli EtcdClient, key string) int64 {
|
||||
func (c *cluster) load(cli EtcdClient, key watchKey) int64 {
|
||||
var resp *clientv3.GetResponse
|
||||
for {
|
||||
var err error
|
||||
ctx, cancel := context.WithTimeout(c.context(cli), RequestTimeout)
|
||||
if c.exactMatch {
|
||||
resp, err = cli.Get(ctx, key)
|
||||
ctx, cancel := context.WithTimeout(cli.Ctx(), RequestTimeout)
|
||||
if key.exactMatch {
|
||||
resp, err = cli.Get(ctx, key.key)
|
||||
} else {
|
||||
resp, err = cli.Get(ctx, makeKeyPrefix(key), clientv3.WithPrefix())
|
||||
resp, err = cli.Get(ctx, makeKeyPrefix(key.key), clientv3.WithPrefix())
|
||||
}
|
||||
|
||||
cancel()
|
||||
@@ -237,11 +304,11 @@ func (c *cluster) load(cli EtcdClient, key string) int64 {
|
||||
break
|
||||
}
|
||||
|
||||
logx.Errorf("%s, key is %s", err.Error(), key)
|
||||
time.Sleep(coolDownInterval)
|
||||
logc.Errorf(cli.Ctx(), "%s, key: %s, exactMatch: %t", err.Error(), key.key, key.exactMatch)
|
||||
time.Sleep(coolDownUnstable.AroundDuration(coolDownInterval))
|
||||
}
|
||||
|
||||
var kvs []KV
|
||||
kvs := make([]KV, 0, len(resp.Kvs))
|
||||
for _, ev := range resp.Kvs {
|
||||
kvs = append(kvs, KV{
|
||||
Key: string(ev.Key),
|
||||
@@ -254,17 +321,13 @@ func (c *cluster) load(cli EtcdClient, key string) int64 {
|
||||
return resp.Header.Revision
|
||||
}
|
||||
|
||||
func (c *cluster) monitor(key string, l UpdateListener, exactMatch bool) error {
|
||||
c.lock.Lock()
|
||||
c.listeners[key] = append(c.listeners[key], l)
|
||||
c.exactMatch = exactMatch
|
||||
c.lock.Unlock()
|
||||
|
||||
func (c *cluster) monitor(key watchKey, l UpdateListener) error {
|
||||
cli, err := c.getClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.addListener(key, l)
|
||||
rev := c.load(cli, key)
|
||||
c.watchGroup.Run(func() {
|
||||
c.watch(cli, key, rev)
|
||||
@@ -286,16 +349,22 @@ func (c *cluster) newClient() (EtcdClient, error) {
|
||||
|
||||
func (c *cluster) reload(cli EtcdClient) {
|
||||
c.lock.Lock()
|
||||
// cancel the previous watches
|
||||
close(c.done)
|
||||
c.watchGroup.Wait()
|
||||
keys := make([]watchKey, 0, len(c.watchers))
|
||||
for wk, wval := range c.watchers {
|
||||
keys = append(keys, wk)
|
||||
if wval.cancel != nil {
|
||||
wval.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
c.done = make(chan lang.PlaceholderType)
|
||||
c.watchGroup = threading.NewRoutineGroup()
|
||||
var keys []string
|
||||
for k := range c.listeners {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
c.lock.Unlock()
|
||||
|
||||
// start new watches
|
||||
for _, key := range keys {
|
||||
k := key
|
||||
c.watchGroup.Run(func() {
|
||||
@@ -305,7 +374,7 @@ func (c *cluster) reload(cli EtcdClient) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cluster) watch(cli EtcdClient, key string, rev int64) {
|
||||
func (c *cluster) watch(cli EtcdClient, key watchKey, rev int64) {
|
||||
for {
|
||||
err := c.watchStream(cli, key, rev)
|
||||
if err == nil {
|
||||
@@ -313,30 +382,18 @@ func (c *cluster) watch(cli EtcdClient, key string, rev int64) {
|
||||
}
|
||||
|
||||
if rev != 0 && errors.Is(err, rpctypes.ErrCompacted) {
|
||||
logx.Errorf("etcd watch stream has been compacted, try to reload, rev %d", rev)
|
||||
logc.Errorf(cli.Ctx(), "etcd watch stream has been compacted, try to reload, rev %d", rev)
|
||||
rev = c.load(cli, key)
|
||||
}
|
||||
|
||||
// log the error and retry
|
||||
logx.Error(err)
|
||||
// log the error and retry with cooldown to prevent CPU/disk exhaustion
|
||||
logc.Error(cli.Ctx(), err)
|
||||
time.Sleep(coolDownUnstable.AroundDuration(coolDownInterval))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cluster) watchStream(cli EtcdClient, key string, rev int64) error {
|
||||
var (
|
||||
rch clientv3.WatchChan
|
||||
ops []clientv3.OpOption
|
||||
watchKey = key
|
||||
)
|
||||
if !c.exactMatch {
|
||||
watchKey = makeKeyPrefix(key)
|
||||
ops = append(ops, clientv3.WithPrefix())
|
||||
}
|
||||
if rev != 0 {
|
||||
ops = append(ops, clientv3.WithRev(rev+1))
|
||||
}
|
||||
|
||||
rch = cli.Watch(clientv3.WithRequireLeader(c.context(cli)), watchKey, ops...)
|
||||
func (c *cluster) watchStream(cli EtcdClient, key watchKey, rev int64) error {
|
||||
ctx, rch := c.setupWatch(cli, key, rev)
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -351,13 +408,47 @@ func (c *cluster) watchStream(cli EtcdClient, key string, rev int64) error {
|
||||
return fmt.Errorf("etcd monitor chan error: %w", wresp.Err())
|
||||
}
|
||||
|
||||
c.handleWatchEvents(key, wresp.Events)
|
||||
c.handleWatchEvents(ctx, key, wresp.Events)
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-c.done:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cluster) setupWatch(cli EtcdClient, key watchKey, rev int64) (context.Context, clientv3.WatchChan) {
|
||||
var (
|
||||
rch clientv3.WatchChan
|
||||
ops []clientv3.OpOption
|
||||
wkey = key.key
|
||||
)
|
||||
|
||||
if !key.exactMatch {
|
||||
wkey = makeKeyPrefix(key.key)
|
||||
ops = append(ops, clientv3.WithPrefix())
|
||||
}
|
||||
if rev != 0 {
|
||||
ops = append(ops, clientv3.WithRev(rev+1))
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(cli.Ctx())
|
||||
|
||||
c.lock.Lock()
|
||||
if watcher, ok := c.watchers[key]; ok {
|
||||
watcher.cancel = cancel
|
||||
} else {
|
||||
val := newWatchValue()
|
||||
val.cancel = cancel
|
||||
c.watchers[key] = val
|
||||
}
|
||||
c.lock.Unlock()
|
||||
|
||||
rch = cli.Watch(clientv3.WithRequireLeader(ctx), wkey, ops...)
|
||||
|
||||
return ctx, rch
|
||||
}
|
||||
|
||||
func (c *cluster) watchConnState(cli EtcdClient) {
|
||||
watcher := newStateWatcher()
|
||||
watcher.addListener(func() {
|
||||
@@ -386,6 +477,28 @@ func DialClient(endpoints []string) (EtcdClient, error) {
|
||||
return clientv3.New(cfg)
|
||||
}
|
||||
|
||||
func calculateChanges(oldVals, newVals map[string]string) (add, remove []KV) {
|
||||
for k, v := range newVals {
|
||||
if val, ok := oldVals[k]; !ok || v != val {
|
||||
add = append(add, KV{
|
||||
Key: k,
|
||||
Val: v,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range oldVals {
|
||||
if val, ok := newVals[k]; !ok || v != val {
|
||||
remove = append(remove, KV{
|
||||
Key: k,
|
||||
Val: v,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return add, remove
|
||||
}
|
||||
|
||||
func getClusterKey(endpoints []string) string {
|
||||
sort.Strings(endpoints)
|
||||
return strings.Join(endpoints, endpointsSeparator)
|
||||
@@ -394,3 +507,10 @@ func getClusterKey(endpoints []string) string {
|
||||
func makeKeyPrefix(key string) string {
|
||||
return fmt.Sprintf("%s%c", key, Delimiter)
|
||||
}
|
||||
|
||||
// NewClient returns a watchValue that make sure values are not nil.
|
||||
func newWatchValue() *watchValue {
|
||||
return &watchValue{
|
||||
values: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,16 +7,17 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/contextx"
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
"github.com/zeromicro/go-zero/core/threading"
|
||||
"go.etcd.io/etcd/api/v3/etcdserverpb"
|
||||
"go.etcd.io/etcd/api/v3/mvccpb"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
"go.etcd.io/etcd/client/v3/mock/mockserver"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
var mockLock sync.Mutex
|
||||
@@ -38,9 +39,9 @@ func setMockClient(cli EtcdClient) func() {
|
||||
|
||||
func TestGetCluster(t *testing.T) {
|
||||
AddAccount([]string{"first"}, "foo", "bar")
|
||||
c1, _ := GetRegistry().getCluster([]string{"first"})
|
||||
c2, _ := GetRegistry().getCluster([]string{"second"})
|
||||
c3, _ := GetRegistry().getCluster([]string{"first"})
|
||||
c1, _ := GetRegistry().getOrCreateCluster([]string{"first"})
|
||||
c2, _ := GetRegistry().getOrCreateCluster([]string{"second"})
|
||||
c3, _ := GetRegistry().getOrCreateCluster([]string{"first"})
|
||||
assert.Equal(t, c1, c3)
|
||||
assert.NotEqual(t, c1, c2)
|
||||
}
|
||||
@@ -50,6 +51,36 @@ func TestGetClusterKey(t *testing.T) {
|
||||
getClusterKey([]string{"remotehost:5678", "localhost:1234"}))
|
||||
}
|
||||
|
||||
func TestUnmonitor(t *testing.T) {
|
||||
t.Run("no listener", func(t *testing.T) {
|
||||
reg := &Registry{
|
||||
clusters: map[string]*cluster{},
|
||||
}
|
||||
assert.NotPanics(t, func() {
|
||||
reg.Unmonitor([]string{"any"}, "any", false, nil)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("no value", func(t *testing.T) {
|
||||
reg := &Registry{
|
||||
clusters: map[string]*cluster{
|
||||
"any": {
|
||||
watchers: map[watchKey]*watchValue{
|
||||
{
|
||||
key: "any",
|
||||
}: {
|
||||
values: map[string]string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
assert.NotPanics(t, func() {
|
||||
reg.Unmonitor([]string{"any"}, "another", false, nil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestCluster_HandleChanges(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
l := NewMockUpdateListener(ctrl)
|
||||
@@ -78,8 +109,14 @@ func TestCluster_HandleChanges(t *testing.T) {
|
||||
Val: "4",
|
||||
})
|
||||
c := newCluster([]string{"any"})
|
||||
c.listeners["any"] = []UpdateListener{l}
|
||||
c.handleChanges("any", []KV{
|
||||
key := watchKey{
|
||||
key: "any",
|
||||
exactMatch: false,
|
||||
}
|
||||
c.watchers[key] = &watchValue{
|
||||
listeners: []UpdateListener{l},
|
||||
}
|
||||
c.handleChanges(key, []KV{
|
||||
{
|
||||
Key: "first",
|
||||
Val: "1",
|
||||
@@ -92,8 +129,8 @@ func TestCluster_HandleChanges(t *testing.T) {
|
||||
assert.EqualValues(t, map[string]string{
|
||||
"first": "1",
|
||||
"second": "2",
|
||||
}, c.values["any"])
|
||||
c.handleChanges("any", []KV{
|
||||
}, c.watchers[key].values)
|
||||
c.handleChanges(key, []KV{
|
||||
{
|
||||
Key: "third",
|
||||
Val: "3",
|
||||
@@ -106,7 +143,7 @@ func TestCluster_HandleChanges(t *testing.T) {
|
||||
assert.EqualValues(t, map[string]string{
|
||||
"third": "3",
|
||||
"fourth": "4",
|
||||
}, c.values["any"])
|
||||
}, c.watchers[key].values)
|
||||
}
|
||||
|
||||
func TestCluster_Load(t *testing.T) {
|
||||
@@ -126,9 +163,11 @@ func TestCluster_Load(t *testing.T) {
|
||||
}, nil)
|
||||
cli.EXPECT().Ctx().Return(context.Background())
|
||||
c := &cluster{
|
||||
values: make(map[string]map[string]string),
|
||||
watchers: make(map[watchKey]*watchValue),
|
||||
}
|
||||
c.load(cli, "any")
|
||||
c.load(cli, watchKey{
|
||||
key: "any",
|
||||
})
|
||||
}
|
||||
|
||||
func TestCluster_Watch(t *testing.T) {
|
||||
@@ -160,11 +199,16 @@ func TestCluster_Watch(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
c := &cluster{
|
||||
listeners: make(map[string][]UpdateListener),
|
||||
values: make(map[string]map[string]string),
|
||||
watchers: make(map[watchKey]*watchValue),
|
||||
}
|
||||
key := watchKey{
|
||||
key: "any",
|
||||
}
|
||||
listener := NewMockUpdateListener(ctrl)
|
||||
c.listeners["any"] = []UpdateListener{listener}
|
||||
c.watchers[key] = &watchValue{
|
||||
listeners: []UpdateListener{listener},
|
||||
values: make(map[string]string),
|
||||
}
|
||||
listener.EXPECT().OnAdd(gomock.Any()).Do(func(kv KV) {
|
||||
assert.Equal(t, "hello", kv.Key)
|
||||
assert.Equal(t, "world", kv.Val)
|
||||
@@ -173,7 +217,7 @@ func TestCluster_Watch(t *testing.T) {
|
||||
listener.EXPECT().OnDelete(gomock.Any()).Do(func(_ any) {
|
||||
wg.Done()
|
||||
}).MaxTimes(1)
|
||||
go c.watch(cli, "any", 0)
|
||||
go c.watch(cli, key, 0)
|
||||
ch <- clientv3.WatchResponse{
|
||||
Events: []*clientv3.Event{
|
||||
{
|
||||
@@ -211,17 +255,111 @@ func TestClusterWatch_RespFailures(t *testing.T) {
|
||||
ch := make(chan clientv3.WatchResponse)
|
||||
cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes()
|
||||
cli.EXPECT().Ctx().Return(context.Background()).AnyTimes()
|
||||
c := new(cluster)
|
||||
c := &cluster{
|
||||
watchers: make(map[watchKey]*watchValue),
|
||||
}
|
||||
c.done = make(chan lang.PlaceholderType)
|
||||
go func() {
|
||||
ch <- resp
|
||||
close(c.done)
|
||||
}()
|
||||
c.watch(cli, "any", 0)
|
||||
key := watchKey{
|
||||
key: "any",
|
||||
}
|
||||
c.watch(cli, key, 0)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCluster_getCurrent(t *testing.T) {
|
||||
t.Run("no value", func(t *testing.T) {
|
||||
c := &cluster{
|
||||
watchers: map[watchKey]*watchValue{
|
||||
{
|
||||
key: "any",
|
||||
}: {
|
||||
values: map[string]string{},
|
||||
},
|
||||
},
|
||||
}
|
||||
assert.Nil(t, c.getCurrent(watchKey{
|
||||
key: "another",
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
func TestCluster_handleWatchEvents(t *testing.T) {
|
||||
t.Run("no value", func(t *testing.T) {
|
||||
c := &cluster{
|
||||
watchers: map[watchKey]*watchValue{
|
||||
{
|
||||
key: "any",
|
||||
}: {
|
||||
values: map[string]string{},
|
||||
},
|
||||
},
|
||||
}
|
||||
assert.NotPanics(t, func() {
|
||||
c.handleWatchEvents(context.Background(), watchKey{
|
||||
key: "another",
|
||||
}, nil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestCluster_addListener(t *testing.T) {
|
||||
t.Run("has listener", func(t *testing.T) {
|
||||
c := &cluster{
|
||||
watchers: map[watchKey]*watchValue{
|
||||
{
|
||||
key: "any",
|
||||
}: {
|
||||
listeners: make([]UpdateListener, 0),
|
||||
},
|
||||
},
|
||||
}
|
||||
assert.NotPanics(t, func() {
|
||||
c.addListener(watchKey{
|
||||
key: "any",
|
||||
}, nil)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("no listener", func(t *testing.T) {
|
||||
c := &cluster{
|
||||
watchers: map[watchKey]*watchValue{
|
||||
{
|
||||
key: "any",
|
||||
}: {
|
||||
listeners: make([]UpdateListener, 0),
|
||||
},
|
||||
},
|
||||
}
|
||||
assert.NotPanics(t, func() {
|
||||
c.addListener(watchKey{
|
||||
key: "another",
|
||||
}, nil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestCluster_reload(t *testing.T) {
|
||||
c := &cluster{
|
||||
watchers: map[watchKey]*watchValue{},
|
||||
watchGroup: threading.NewRoutineGroup(),
|
||||
done: make(chan lang.PlaceholderType),
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
cli := NewMockEtcdClient(ctrl)
|
||||
restore := setMockClient(cli)
|
||||
defer restore()
|
||||
assert.NotPanics(t, func() {
|
||||
c.reload(cli)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClusterWatch_CloseChan(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
@@ -231,13 +369,17 @@ func TestClusterWatch_CloseChan(t *testing.T) {
|
||||
ch := make(chan clientv3.WatchResponse)
|
||||
cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes()
|
||||
cli.EXPECT().Ctx().Return(context.Background()).AnyTimes()
|
||||
c := new(cluster)
|
||||
c := &cluster{
|
||||
watchers: make(map[watchKey]*watchValue),
|
||||
}
|
||||
c.done = make(chan lang.PlaceholderType)
|
||||
go func() {
|
||||
close(ch)
|
||||
close(c.done)
|
||||
}()
|
||||
c.watch(cli, "any", 0)
|
||||
c.watch(cli, watchKey{
|
||||
key: "any",
|
||||
}, 0)
|
||||
}
|
||||
|
||||
func TestValueOnlyContext(t *testing.T) {
|
||||
@@ -280,16 +422,125 @@ func TestRegistry_Monitor(t *testing.T) {
|
||||
GetRegistry().lock.Lock()
|
||||
GetRegistry().clusters = map[string]*cluster{
|
||||
getClusterKey(endpoints): {
|
||||
listeners: map[string][]UpdateListener{},
|
||||
values: map[string]map[string]string{
|
||||
"foo": {
|
||||
"bar": "baz",
|
||||
watchers: map[watchKey]*watchValue{
|
||||
{
|
||||
key: "foo",
|
||||
exactMatch: true,
|
||||
}: {
|
||||
values: map[string]string{
|
||||
"bar": "baz",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
GetRegistry().lock.Unlock()
|
||||
assert.Error(t, GetRegistry().Monitor(endpoints, "foo", new(mockListener), false))
|
||||
assert.Error(t, GetRegistry().Monitor(endpoints, "foo", false, new(mockListener)))
|
||||
}
|
||||
|
||||
func TestRegistry_Unmonitor(t *testing.T) {
|
||||
svr, err := mockserver.StartMockServers(1)
|
||||
assert.NoError(t, err)
|
||||
svr.StartAt(0)
|
||||
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
endpoints := []string{svr.Servers[0].Address}
|
||||
GetRegistry().lock.Lock()
|
||||
GetRegistry().clusters = map[string]*cluster{
|
||||
getClusterKey(endpoints): {
|
||||
watchers: map[watchKey]*watchValue{
|
||||
{
|
||||
key: "foo",
|
||||
exactMatch: true,
|
||||
}: {
|
||||
values: map[string]string{
|
||||
"bar": "baz",
|
||||
},
|
||||
cancel: cancel,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
GetRegistry().lock.Unlock()
|
||||
l := new(mockListener)
|
||||
assert.NoError(t, GetRegistry().Monitor(endpoints, "foo", true, l))
|
||||
watchVals := GetRegistry().clusters[getClusterKey(endpoints)].watchers[watchKey{
|
||||
key: "foo",
|
||||
exactMatch: true,
|
||||
}]
|
||||
assert.Equal(t, 1, len(watchVals.listeners))
|
||||
GetRegistry().Unmonitor(endpoints, "foo", true, l)
|
||||
watchVals = GetRegistry().clusters[getClusterKey(endpoints)].watchers[watchKey{
|
||||
key: "foo",
|
||||
exactMatch: true,
|
||||
}]
|
||||
assert.Nil(t, watchVals)
|
||||
}
|
||||
|
||||
// TestCluster_ConcurrentMonitor tests the race condition fix in setupWatch
|
||||
// This test specifically covers the scenario from issue #5394 where:
|
||||
// - addListener() writes to the watchers map (with lock)
|
||||
// - setupWatch() reads from the watchers map (now with lock after fix)
|
||||
// Running with -race flag will detect any race conditions
|
||||
func TestCluster_ConcurrentMonitor(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
cli := NewMockEtcdClient(ctrl)
|
||||
cli.EXPECT().Ctx().Return(context.Background()).AnyTimes()
|
||||
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return(make(chan clientv3.WatchResponse)).AnyTimes()
|
||||
|
||||
c := &cluster{
|
||||
endpoints: []string{"localhost:2379"},
|
||||
key: "test-cluster",
|
||||
watchers: make(map[watchKey]*watchValue),
|
||||
watchGroup: threading.NewRoutineGroup(),
|
||||
done: make(chan lang.PlaceholderType),
|
||||
lock: sync.RWMutex{},
|
||||
}
|
||||
|
||||
// Spawn multiple concurrent operations that simulate the race condition:
|
||||
// - Some goroutines call addListener (write to map)
|
||||
// - Some goroutines call setupWatch (read from map)
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 20
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
keys := []watchKey{
|
||||
{key: "key-0", exactMatch: false},
|
||||
{key: "key-1", exactMatch: false},
|
||||
{key: "key-2", exactMatch: false},
|
||||
}
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
idx := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
key := keys[idx%len(keys)]
|
||||
|
||||
if idx%2 == 0 {
|
||||
// Half the goroutines add listeners (write operation)
|
||||
c.addListener(key, &mockListener{})
|
||||
} else {
|
||||
// Half the goroutines setup watches (read operation)
|
||||
_, _ = c.setupWatch(cli, key, 0)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
wg.Wait()
|
||||
|
||||
// Verify that watchers were correctly added
|
||||
c.lock.RLock()
|
||||
assert.True(t, len(c.watchers) > 0, "watchers should be added")
|
||||
for _, watcher := range c.watchers {
|
||||
assert.NotNil(t, watcher, "watcher should not be nil")
|
||||
}
|
||||
c.lock.RUnlock()
|
||||
|
||||
// Clean up
|
||||
close(c.done)
|
||||
}
|
||||
|
||||
type mockListener struct {
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: statewatcher.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -package internal -destination statewatcher_mock.go -source statewatcher.go etcdConn
|
||||
//
|
||||
|
||||
// Package internal is a generated GoMock package.
|
||||
package internal
|
||||
@@ -8,34 +13,35 @@ import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
connectivity "google.golang.org/grpc/connectivity"
|
||||
)
|
||||
|
||||
// MocketcdConn is a mock of etcdConn interface
|
||||
// MocketcdConn is a mock of etcdConn interface.
|
||||
type MocketcdConn struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MocketcdConnMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MocketcdConnMockRecorder is the mock recorder for MocketcdConn
|
||||
// MocketcdConnMockRecorder is the mock recorder for MocketcdConn.
|
||||
type MocketcdConnMockRecorder struct {
|
||||
mock *MocketcdConn
|
||||
}
|
||||
|
||||
// NewMocketcdConn creates a new mock instance
|
||||
// NewMocketcdConn creates a new mock instance.
|
||||
func NewMocketcdConn(ctrl *gomock.Controller) *MocketcdConn {
|
||||
mock := &MocketcdConn{ctrl: ctrl}
|
||||
mock.recorder = &MocketcdConnMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MocketcdConn) EXPECT() *MocketcdConnMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetState mocks base method
|
||||
// GetState mocks base method.
|
||||
func (m *MocketcdConn) GetState() connectivity.State {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetState")
|
||||
@@ -43,13 +49,13 @@ func (m *MocketcdConn) GetState() connectivity.State {
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetState indicates an expected call of GetState
|
||||
// GetState indicates an expected call of GetState.
|
||||
func (mr *MocketcdConnMockRecorder) GetState() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetState", reflect.TypeOf((*MocketcdConn)(nil).GetState))
|
||||
}
|
||||
|
||||
// WaitForStateChange mocks base method
|
||||
// WaitForStateChange mocks base method.
|
||||
func (m *MocketcdConn) WaitForStateChange(ctx context.Context, sourceState connectivity.State) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "WaitForStateChange", ctx, sourceState)
|
||||
@@ -57,7 +63,7 @@ func (m *MocketcdConn) WaitForStateChange(ctx context.Context, sourceState conne
|
||||
return ret0
|
||||
}
|
||||
|
||||
// WaitForStateChange indicates an expected call of WaitForStateChange
|
||||
// WaitForStateChange indicates an expected call of WaitForStateChange.
|
||||
func (mr *MocketcdConnMockRecorder) WaitForStateChange(ctx, sourceState any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WaitForStateChange", reflect.TypeOf((*MocketcdConn)(nil).WaitForStateChange), ctx, sourceState)
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"go.uber.org/mock/gomock"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
)
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ type (
|
||||
}
|
||||
|
||||
// UpdateListener wraps the OnAdd and OnDelete methods.
|
||||
// The implementation should be thread-safe and idempotent.
|
||||
UpdateListener interface {
|
||||
OnAdd(kv KV)
|
||||
OnDelete(kv KV)
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: updatelistener.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -package internal -destination updatelistener_mock.go -source updatelistener.go UpdateListener
|
||||
//
|
||||
|
||||
// Package internal is a generated GoMock package.
|
||||
package internal
|
||||
@@ -7,51 +12,52 @@ package internal
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockUpdateListener is a mock of UpdateListener interface
|
||||
// MockUpdateListener is a mock of UpdateListener interface.
|
||||
type MockUpdateListener struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockUpdateListenerMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockUpdateListenerMockRecorder is the mock recorder for MockUpdateListener
|
||||
// MockUpdateListenerMockRecorder is the mock recorder for MockUpdateListener.
|
||||
type MockUpdateListenerMockRecorder struct {
|
||||
mock *MockUpdateListener
|
||||
}
|
||||
|
||||
// NewMockUpdateListener creates a new mock instance
|
||||
// NewMockUpdateListener creates a new mock instance.
|
||||
func NewMockUpdateListener(ctrl *gomock.Controller) *MockUpdateListener {
|
||||
mock := &MockUpdateListener{ctrl: ctrl}
|
||||
mock.recorder = &MockUpdateListenerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockUpdateListener) EXPECT() *MockUpdateListenerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// OnAdd mocks base method
|
||||
// OnAdd mocks base method.
|
||||
func (m *MockUpdateListener) OnAdd(kv KV) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "OnAdd", kv)
|
||||
}
|
||||
|
||||
// OnAdd indicates an expected call of OnAdd
|
||||
// OnAdd indicates an expected call of OnAdd.
|
||||
func (mr *MockUpdateListenerMockRecorder) OnAdd(kv any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnAdd", reflect.TypeOf((*MockUpdateListener)(nil).OnAdd), kv)
|
||||
}
|
||||
|
||||
// OnDelete mocks base method
|
||||
// OnDelete mocks base method.
|
||||
func (m *MockUpdateListener) OnDelete(kv KV) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "OnDelete", kv)
|
||||
}
|
||||
|
||||
// OnDelete indicates an expected call of OnDelete
|
||||
// OnDelete indicates an expected call of OnDelete.
|
||||
func (mr *MockUpdateListenerMockRecorder) OnDelete(kv any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnDelete", reflect.TypeOf((*MockUpdateListener)(nil).OnDelete), kv)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/zeromicro/go-zero/core/discov/internal"
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
"github.com/zeromicro/go-zero/core/logc"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
@@ -91,12 +92,12 @@ func (p *Publisher) doKeepAlive() error {
|
||||
default:
|
||||
cli, err := p.doRegister()
|
||||
if err != nil {
|
||||
logx.Errorf("etcd publisher doRegister: %s", err.Error())
|
||||
logc.Errorf(cli.Ctx(), "etcd publisher doRegister: %v", err)
|
||||
break
|
||||
}
|
||||
|
||||
if err := p.keepAliveAsync(cli); err != nil {
|
||||
logx.Errorf("etcd publisher keepAliveAsync: %s", err.Error())
|
||||
logc.Errorf(cli.Ctx(), "etcd publisher keepAliveAsync: %v", err)
|
||||
break
|
||||
}
|
||||
|
||||
@@ -124,23 +125,48 @@ func (p *Publisher) keepAliveAsync(cli internal.EtcdClient) error {
|
||||
}
|
||||
|
||||
threading.GoSafe(func() {
|
||||
wch := cli.Watch(cli.Ctx(), p.fullKey, clientv3.WithFilterPut())
|
||||
|
||||
for {
|
||||
select {
|
||||
case _, ok := <-ch:
|
||||
if !ok {
|
||||
p.revoke(cli)
|
||||
if err := p.doKeepAlive(); err != nil {
|
||||
logx.Errorf("etcd publisher KeepAlive: %s", err.Error())
|
||||
logc.Errorf(cli.Ctx(), "etcd publisher KeepAlive: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
case c := <-wch:
|
||||
if c.Err() != nil {
|
||||
logc.Errorf(cli.Ctx(), "etcd publisher watch: %v", c.Err())
|
||||
if err := p.doKeepAlive(); err != nil {
|
||||
logc.Errorf(cli.Ctx(), "etcd publisher KeepAlive: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for _, evt := range c.Events {
|
||||
if evt.Type == clientv3.EventTypeDelete {
|
||||
logc.Infof(cli.Ctx(), "etcd publisher watch: %s, event: %v",
|
||||
evt.Kv.Key, evt.Type)
|
||||
_, err := cli.Put(cli.Ctx(), p.fullKey, p.value, clientv3.WithLease(p.lease))
|
||||
if err != nil {
|
||||
logc.Errorf(cli.Ctx(), "etcd publisher re-put key: %v", err)
|
||||
} else {
|
||||
logc.Infof(cli.Ctx(), "etcd publisher re-put key: %s, value: %s",
|
||||
p.fullKey, p.value)
|
||||
}
|
||||
}
|
||||
}
|
||||
case <-p.pauseChan:
|
||||
logx.Infof("paused etcd renew, key: %s, value: %s", p.key, p.value)
|
||||
logc.Infof(cli.Ctx(), "paused etcd renew, key: %s, value: %s", p.key, p.value)
|
||||
p.revoke(cli)
|
||||
select {
|
||||
case <-p.resumeChan:
|
||||
if err := p.doKeepAlive(); err != nil {
|
||||
logx.Errorf("etcd publisher KeepAlive: %s", err.Error())
|
||||
logc.Errorf(cli.Ctx(), "etcd publisher KeepAlive: %v", err)
|
||||
}
|
||||
return
|
||||
case <-p.quit.Done():
|
||||
@@ -175,7 +201,7 @@ func (p *Publisher) register(client internal.EtcdClient) (clientv3.LeaseID, erro
|
||||
|
||||
func (p *Publisher) revoke(cli internal.EtcdClient) {
|
||||
if _, err := cli.Revoke(cli.Ctx(), p.lease); err != nil {
|
||||
logx.Errorf("etcd publisher revoke: %s", err.Error())
|
||||
logc.Errorf(cli.Ctx(), "etcd publisher revoke: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,13 +9,14 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/discov/internal"
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
"go.etcd.io/etcd/api/v3/mvccpb"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/net/http2"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
@@ -211,6 +212,9 @@ func TestPublisher_keepAliveAsyncQuit(t *testing.T) {
|
||||
defer restore()
|
||||
cli.EXPECT().Ctx().AnyTimes()
|
||||
cli.EXPECT().KeepAlive(gomock.Any(), id)
|
||||
// Add Watch expectation for the new watch mechanism
|
||||
watchChan := make(<-chan clientv3.WatchResponse)
|
||||
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return(watchChan)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) {
|
||||
@@ -232,6 +236,9 @@ func TestPublisher_keepAliveAsyncPause(t *testing.T) {
|
||||
defer restore()
|
||||
cli.EXPECT().Ctx().AnyTimes()
|
||||
cli.EXPECT().KeepAlive(gomock.Any(), id)
|
||||
// Add Watch expectation for the new watch mechanism
|
||||
watchChan := make(<-chan clientv3.WatchResponse)
|
||||
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return(watchChan)
|
||||
pub := NewPublisher(nil, "thekey", "thevalue")
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
@@ -245,6 +252,112 @@ func TestPublisher_keepAliveAsyncPause(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Test case for key deletion and re-registration (covers lines 148-155)
|
||||
func TestPublisher_keepAliveAsyncKeyDeletion(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
const id clientv3.LeaseID = 1
|
||||
cli := internal.NewMockEtcdClient(ctrl)
|
||||
restore := setMockClient(cli)
|
||||
defer restore()
|
||||
cli.EXPECT().Ctx().AnyTimes()
|
||||
cli.EXPECT().KeepAlive(gomock.Any(), id)
|
||||
|
||||
// Create a watch channel that will send a delete event
|
||||
watchChan := make(chan clientv3.WatchResponse, 1)
|
||||
watchResp := clientv3.WatchResponse{
|
||||
Events: []*clientv3.Event{{
|
||||
Type: clientv3.EventTypeDelete,
|
||||
Kv: &mvccpb.KeyValue{
|
||||
Key: []byte("thekey"),
|
||||
},
|
||||
}},
|
||||
}
|
||||
watchChan <- watchResp
|
||||
|
||||
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return((<-chan clientv3.WatchResponse)(watchChan))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1) // Only wait for Revoke call
|
||||
|
||||
// Use a channel to signal when Put has been called
|
||||
putCalled := make(chan struct{})
|
||||
|
||||
// Expect the re-put operation when key is deleted
|
||||
cli.EXPECT().Put(gomock.Any(), "thekey", "thevalue", gomock.Any()).Do(func(_, _, _, _ any) {
|
||||
close(putCalled) // Signal that Put has been called
|
||||
}).Return(nil, nil)
|
||||
|
||||
// Expect revoke when Stop is called
|
||||
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) {
|
||||
wg.Done()
|
||||
})
|
||||
|
||||
pub := NewPublisher(nil, "thekey", "thevalue")
|
||||
pub.lease = id
|
||||
pub.fullKey = "thekey"
|
||||
|
||||
assert.Nil(t, pub.keepAliveAsync(cli))
|
||||
|
||||
// Wait for Put to be called, then stop
|
||||
<-putCalled
|
||||
pub.Stop()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Test case for key deletion with re-put error (covers error branch in lines 151-152)
|
||||
func TestPublisher_keepAliveAsyncKeyDeletionPutError(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
const id clientv3.LeaseID = 1
|
||||
cli := internal.NewMockEtcdClient(ctrl)
|
||||
restore := setMockClient(cli)
|
||||
defer restore()
|
||||
cli.EXPECT().Ctx().AnyTimes()
|
||||
cli.EXPECT().KeepAlive(gomock.Any(), id)
|
||||
|
||||
// Create a watch channel that will send a delete event
|
||||
watchChan := make(chan clientv3.WatchResponse, 1)
|
||||
watchResp := clientv3.WatchResponse{
|
||||
Events: []*clientv3.Event{{
|
||||
Type: clientv3.EventTypeDelete,
|
||||
Kv: &mvccpb.KeyValue{
|
||||
Key: []byte("thekey"),
|
||||
},
|
||||
}},
|
||||
}
|
||||
watchChan <- watchResp
|
||||
|
||||
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return((<-chan clientv3.WatchResponse)(watchChan))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1) // Only wait for Revoke call
|
||||
|
||||
// Use a channel to signal when Put has been called
|
||||
putCalled := make(chan struct{})
|
||||
|
||||
// Expect the re-put operation to fail
|
||||
cli.EXPECT().Put(gomock.Any(), "thekey", "thevalue", gomock.Any()).Do(func(_, _, _, _ any) {
|
||||
close(putCalled) // Signal that Put has been called
|
||||
}).Return(nil, errors.New("put error"))
|
||||
|
||||
// Expect revoke when Stop is called
|
||||
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) {
|
||||
wg.Done()
|
||||
})
|
||||
|
||||
pub := NewPublisher(nil, "thekey", "thevalue")
|
||||
pub.lease = id
|
||||
pub.fullKey = "thekey"
|
||||
|
||||
assert.Nil(t, pub.keepAliveAsync(cli))
|
||||
|
||||
// Wait for Put to be called, then stop
|
||||
<-putCalled
|
||||
pub.Stop()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestPublisher_Resume(t *testing.T) {
|
||||
publisher := new(Publisher)
|
||||
publisher.resumeChan = make(chan lang.PlaceholderType)
|
||||
@@ -273,6 +386,9 @@ func TestPublisher_keepAliveAsync(t *testing.T) {
|
||||
defer restore()
|
||||
cli.EXPECT().Ctx().AnyTimes()
|
||||
cli.EXPECT().KeepAlive(gomock.Any(), id)
|
||||
// Add Watch expectation for the new watch mechanism
|
||||
watchChan := make(<-chan clientv3.WatchResponse)
|
||||
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return(watchChan)
|
||||
cli.EXPECT().Grant(gomock.Any(), timeToLive).Return(&clientv3.LeaseGrantResponse{
|
||||
ID: 1,
|
||||
}, nil)
|
||||
|
||||
@@ -17,9 +17,11 @@ type (
|
||||
Subscriber struct {
|
||||
endpoints []string
|
||||
exclusive bool
|
||||
key string
|
||||
exactMatch bool
|
||||
items *container
|
||||
items Container
|
||||
}
|
||||
KV = internal.KV
|
||||
)
|
||||
|
||||
// NewSubscriber returns a Subscriber.
|
||||
@@ -29,13 +31,16 @@ type (
|
||||
func NewSubscriber(endpoints []string, key string, opts ...SubOption) (*Subscriber, error) {
|
||||
sub := &Subscriber{
|
||||
endpoints: endpoints,
|
||||
key: key,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(sub)
|
||||
}
|
||||
sub.items = newContainer(sub.exclusive)
|
||||
if sub.items == nil {
|
||||
sub.items = newContainer(sub.exclusive)
|
||||
}
|
||||
|
||||
if err := internal.GetRegistry().Monitor(endpoints, key, sub.items, sub.exactMatch); err != nil {
|
||||
if err := internal.GetRegistry().Monitor(endpoints, key, sub.exactMatch, sub.items); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -44,12 +49,17 @@ func NewSubscriber(endpoints []string, key string, opts ...SubOption) (*Subscrib
|
||||
|
||||
// AddListener adds listener to s.
|
||||
func (s *Subscriber) AddListener(listener func()) {
|
||||
s.items.addListener(listener)
|
||||
s.items.AddListener(listener)
|
||||
}
|
||||
|
||||
// Close closes the subscriber.
|
||||
func (s *Subscriber) Close() {
|
||||
internal.GetRegistry().Unmonitor(s.endpoints, s.key, s.exactMatch, s.items)
|
||||
}
|
||||
|
||||
// Values returns all the subscription values.
|
||||
func (s *Subscriber) Values() []string {
|
||||
return s.items.getValues()
|
||||
return s.items.GetValues()
|
||||
}
|
||||
|
||||
// Exclusive means that key value can only be 1:1,
|
||||
@@ -81,16 +91,32 @@ func WithSubEtcdTLS(certFile, certKeyFile, caFile string, insecureSkipVerify boo
|
||||
}
|
||||
}
|
||||
|
||||
type container struct {
|
||||
exclusive bool
|
||||
values map[string][]string
|
||||
mapping map[string]string
|
||||
snapshot atomic.Value
|
||||
dirty *syncx.AtomicBool
|
||||
listeners []func()
|
||||
lock sync.Mutex
|
||||
// WithContainer provides a custom container to the subscriber.
|
||||
func WithContainer(container Container) SubOption {
|
||||
return func(sub *Subscriber) {
|
||||
sub.items = container
|
||||
}
|
||||
}
|
||||
|
||||
type (
|
||||
Container interface {
|
||||
OnAdd(kv internal.KV)
|
||||
OnDelete(kv internal.KV)
|
||||
AddListener(listener func())
|
||||
GetValues() []string
|
||||
}
|
||||
|
||||
container struct {
|
||||
exclusive bool
|
||||
values map[string][]string
|
||||
mapping map[string]string
|
||||
snapshot atomic.Value
|
||||
dirty *syncx.AtomicBool
|
||||
listeners []func()
|
||||
lock sync.Mutex
|
||||
}
|
||||
)
|
||||
|
||||
func newContainer(exclusive bool) *container {
|
||||
return &container{
|
||||
exclusive: exclusive,
|
||||
@@ -134,7 +160,7 @@ func (c *container) addKv(key, value string) ([]string, bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (c *container) addListener(listener func()) {
|
||||
func (c *container) AddListener(listener func()) {
|
||||
c.lock.Lock()
|
||||
c.listeners = append(c.listeners, listener)
|
||||
c.lock.Unlock()
|
||||
@@ -163,7 +189,7 @@ func (c *container) doRemoveKey(key string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *container) getValues() []string {
|
||||
func (c *container) GetValues() []string {
|
||||
if !c.dirty.True() {
|
||||
return c.snapshot.Load().([]string)
|
||||
}
|
||||
|
||||
@@ -171,10 +171,10 @@ func TestContainer(t *testing.T) {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var changed bool
|
||||
c := newContainer(exclusive)
|
||||
c.addListener(func() {
|
||||
c.AddListener(func() {
|
||||
changed = true
|
||||
})
|
||||
assert.Nil(t, c.getValues())
|
||||
assert.Nil(t, c.GetValues())
|
||||
assert.False(t, changed)
|
||||
|
||||
for _, order := range test.do {
|
||||
@@ -193,9 +193,9 @@ func TestContainer(t *testing.T) {
|
||||
|
||||
assert.True(t, changed)
|
||||
assert.True(t, c.dirty.True())
|
||||
assert.ElementsMatch(t, test.expect, c.getValues())
|
||||
assert.ElementsMatch(t, test.expect, c.GetValues())
|
||||
assert.False(t, c.dirty.True())
|
||||
assert.ElementsMatch(t, test.expect, c.getValues())
|
||||
assert.ElementsMatch(t, test.expect, c.GetValues())
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -204,12 +204,14 @@ func TestContainer(t *testing.T) {
|
||||
func TestSubscriber(t *testing.T) {
|
||||
sub := new(Subscriber)
|
||||
Exclusive()(sub)
|
||||
sub.items = newContainer(sub.exclusive)
|
||||
c := newContainer(sub.exclusive)
|
||||
WithContainer(c)(sub)
|
||||
sub.items = c
|
||||
var count int32
|
||||
sub.AddListener(func() {
|
||||
atomic.AddInt32(&count, 1)
|
||||
})
|
||||
sub.items.notifyChange()
|
||||
c.notifyChange()
|
||||
assert.Empty(t, sub.Values())
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&count))
|
||||
}
|
||||
@@ -225,3 +227,29 @@ func TestWithSubEtcdAccount(t *testing.T) {
|
||||
assert.Equal(t, user, account.User)
|
||||
assert.Equal(t, "bar", account.Pass)
|
||||
}
|
||||
|
||||
func TestWithExactMatch(t *testing.T) {
|
||||
sub := new(Subscriber)
|
||||
WithExactMatch()(sub)
|
||||
c := newContainer(sub.exclusive)
|
||||
sub.items = c
|
||||
var count int32
|
||||
sub.AddListener(func() {
|
||||
atomic.AddInt32(&count, 1)
|
||||
})
|
||||
c.notifyChange()
|
||||
assert.Empty(t, sub.Values())
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&count))
|
||||
}
|
||||
|
||||
func TestSubscriberClose(t *testing.T) {
|
||||
l := newContainer(false)
|
||||
sub := &Subscriber{
|
||||
endpoints: []string{"localhost:12379"},
|
||||
key: "foo",
|
||||
items: l,
|
||||
}
|
||||
assert.NotPanics(t, func() {
|
||||
sub.Close()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build linux || darwin
|
||||
//go:build linux || darwin || freebsd
|
||||
|
||||
package fs
|
||||
|
||||
|
||||
@@ -168,7 +168,7 @@ func (s Stream) Count() (count int) {
|
||||
return
|
||||
}
|
||||
|
||||
// Distinct removes the duplicated items base on the given KeyFunc.
|
||||
// Distinct removes the duplicated items based on the given KeyFunc.
|
||||
func (s Stream) Distinct(fn KeyFunc) Stream {
|
||||
source := make(chan any)
|
||||
|
||||
@@ -459,7 +459,7 @@ func (s Stream) Tail(n int64) Stream {
|
||||
return Range(source)
|
||||
}
|
||||
|
||||
// Walk lets the callers handle each item, the caller may write zero, one or more items base on the given item.
|
||||
// Walk lets the callers handle each item, the caller may write zero, one or more items based on the given item.
|
||||
func (s Stream) Walk(fn WalkFunc, opts ...Option) Stream {
|
||||
option := buildOptions(opts...)
|
||||
if option.unlimitedWorkers {
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package fx
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"runtime"
|
||||
@@ -13,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/logx/logtest"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
"go.uber.org/goleak"
|
||||
)
|
||||
@@ -238,7 +237,7 @@ func TestLast(t *testing.T) {
|
||||
|
||||
func TestMap(t *testing.T) {
|
||||
runCheckedTest(t, func(t *testing.T) {
|
||||
log.SetOutput(io.Discard)
|
||||
logtest.Discard(t)
|
||||
|
||||
tests := []struct {
|
||||
mapper MapFunc
|
||||
|
||||
@@ -96,7 +96,7 @@ func (h *ConsistentHash) AddWithWeight(node any, weight int) {
|
||||
h.AddWithReplicas(node, replicas)
|
||||
}
|
||||
|
||||
// Get returns the corresponding node from h base on the given v.
|
||||
// Get returns the corresponding node from h based on the given v.
|
||||
func (h *ConsistentHash) Get(v any) (any, bool) {
|
||||
h.lock.RLock()
|
||||
defer h.lock.RUnlock()
|
||||
|
||||
@@ -86,21 +86,16 @@ func TestConsistentHashIncrementalTransfer(t *testing.T) {
|
||||
|
||||
func TestConsistentHashTransferOnFailure(t *testing.T) {
|
||||
index := 41
|
||||
keys, newKeys := getKeysBeforeAndAfterFailure(t, "localhost:", index)
|
||||
var transferred int
|
||||
for k, v := range newKeys {
|
||||
if v != keys[k] {
|
||||
transferred++
|
||||
}
|
||||
}
|
||||
|
||||
ratio := float32(transferred) / float32(requestSize)
|
||||
assert.True(t, ratio < 2.5/float32(keySize), fmt.Sprintf("%d: %f", index, ratio))
|
||||
ratioNotExists := getTransferRatioOnFailure(t, index)
|
||||
assert.True(t, ratioNotExists == 0, fmt.Sprintf("%d: %f", index, ratioNotExists))
|
||||
index = 13
|
||||
ratio := getTransferRatioOnFailure(t, index)
|
||||
assert.True(t, ratio < 2.5/keySize, fmt.Sprintf("%d: %f", index, ratio))
|
||||
}
|
||||
|
||||
func TestConsistentHashLeastTransferOnFailure(t *testing.T) {
|
||||
prefix := "localhost:"
|
||||
index := 41
|
||||
index := 13
|
||||
keys, newKeys := getKeysBeforeAndAfterFailure(t, prefix, index)
|
||||
for k, v := range keys {
|
||||
newV := newKeys[k]
|
||||
@@ -164,6 +159,17 @@ func getKeysBeforeAndAfterFailure(t *testing.T, prefix string, index int) (map[i
|
||||
return keys, newKeys
|
||||
}
|
||||
|
||||
func getTransferRatioOnFailure(t *testing.T, index int) float32 {
|
||||
keys, newKeys := getKeysBeforeAndAfterFailure(t, "localhost:", index)
|
||||
var transferred int
|
||||
for k, v := range newKeys {
|
||||
if v != keys[k] {
|
||||
transferred++
|
||||
}
|
||||
}
|
||||
return float32(transferred) / float32(requestSize)
|
||||
}
|
||||
|
||||
type mockNode struct {
|
||||
addr string
|
||||
id int
|
||||
|
||||
@@ -2,7 +2,7 @@ package hash
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"encoding/hex"
|
||||
|
||||
"github.com/spaolacci/murmur3"
|
||||
)
|
||||
@@ -20,6 +20,7 @@ func Md5(data []byte) []byte {
|
||||
}
|
||||
|
||||
// Md5Hex returns the md5 hex string of data.
|
||||
// This function is optimized for better performance than fmt.Sprintf.
|
||||
func Md5Hex(data []byte) string {
|
||||
return fmt.Sprintf("%x", Md5(data))
|
||||
return hex.EncodeToString(Md5(data))
|
||||
}
|
||||
|
||||
@@ -25,6 +25,29 @@ func TestMd5Hex(t *testing.T) {
|
||||
assert.Equal(t, md5Digest, actual)
|
||||
}
|
||||
|
||||
func TestHash(t *testing.T) {
|
||||
result := Hash([]byte(text))
|
||||
assert.NotEqual(t, uint64(0), result)
|
||||
}
|
||||
|
||||
func TestHash_Deterministic(t *testing.T) {
|
||||
data := []byte("consistent-hash-test")
|
||||
first := Hash(data)
|
||||
second := Hash(data)
|
||||
assert.Equal(t, first, second)
|
||||
}
|
||||
|
||||
func TestHash_Empty(t *testing.T) {
|
||||
// Hash should not panic on empty input.
|
||||
result := Hash([]byte{})
|
||||
_ = result
|
||||
}
|
||||
|
||||
func TestMd5Hex_Empty(t *testing.T) {
|
||||
result := Md5Hex([]byte{})
|
||||
assert.Equal(t, 32, len(result))
|
||||
}
|
||||
|
||||
func BenchmarkHashFnv(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
h := fnv.New32()
|
||||
|
||||
@@ -8,9 +8,25 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Marshal marshals v into json bytes.
|
||||
// Marshal marshals v into json bytes, without escaping HTML and removes the trailing newline.
|
||||
func Marshal(v any) ([]byte, error) {
|
||||
return json.Marshal(v)
|
||||
// why not use json.Marshal? https://github.com/golang/go/issues/28453
|
||||
// it changes the behavior of json.Marshal, like & -> \u0026, < -> \u003c, > -> \u003e
|
||||
// which is not what we want in API responses
|
||||
var buf bytes.Buffer
|
||||
enc := json.NewEncoder(&buf)
|
||||
enc.SetEscapeHTML(false)
|
||||
if err := enc.Encode(v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bs := buf.Bytes()
|
||||
// Remove trailing newline added by json.Encoder.Encode
|
||||
if len(bs) > 0 && bs[len(bs)-1] == '\n' {
|
||||
bs = bs[:len(bs)-1]
|
||||
}
|
||||
|
||||
return bs, nil
|
||||
}
|
||||
|
||||
// MarshalToString marshals v into a string.
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package jsonx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -101,3 +102,105 @@ func TestUnmarshalFromReaderError(t *testing.T) {
|
||||
err := UnmarshalFromReader(strings.NewReader(s), &v)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func Test_doMarshalJson(t *testing.T) {
|
||||
type args struct {
|
||||
v any
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []byte
|
||||
wantErr assert.ErrorAssertionFunc
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
args: args{nil},
|
||||
want: []byte("null"),
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "string",
|
||||
args: args{"hello"},
|
||||
want: []byte(`"hello"`),
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "int",
|
||||
args: args{42},
|
||||
want: []byte("42"),
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "bool",
|
||||
args: args{true},
|
||||
want: []byte("true"),
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "struct",
|
||||
args: args{
|
||||
struct {
|
||||
Name string `json:"name"`
|
||||
}{Name: "test"},
|
||||
},
|
||||
want: []byte(`{"name":"test"}`),
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "slice",
|
||||
args: args{[]int{1, 2, 3}},
|
||||
want: []byte("[1,2,3]"),
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "map",
|
||||
args: args{map[string]int{"a": 1, "b": 2}},
|
||||
want: []byte(`{"a":1,"b":2}`),
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "unmarshalable type",
|
||||
args: args{complex(1, 2)},
|
||||
want: nil,
|
||||
wantErr: assert.Error,
|
||||
},
|
||||
{
|
||||
name: "channel type",
|
||||
args: args{make(chan int)},
|
||||
want: nil,
|
||||
wantErr: assert.Error,
|
||||
},
|
||||
{
|
||||
name: "url with query params",
|
||||
args: args{"https://example.com/api?name=test&age=25"},
|
||||
want: []byte(`"https://example.com/api?name=test&age=25"`),
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "url with encoded query params",
|
||||
args: args{"https://example.com/api?data=hello%20world&special=%26%3D"},
|
||||
want: []byte(`"https://example.com/api?data=hello%20world&special=%26%3D"`),
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "url with multiple query params",
|
||||
args: args{"http://localhost:8080/users?page=1&limit=10&sort=name&order=asc"},
|
||||
want: []byte(`"http://localhost:8080/users?page=1&limit=10&sort=name&order=asc"`),
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Marshal(tt.args.v)
|
||||
if !tt.wantErr(t, err, fmt.Sprintf("Marshal(%v)", tt.args.v)) {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equalf(t, string(tt.want), string(got), "Marshal(%v)", tt.args.v)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,6 +37,13 @@ func Debugf(ctx context.Context, format string, v ...interface{}) {
|
||||
getLogger(ctx).Debugf(format, v...)
|
||||
}
|
||||
|
||||
// Debugfn writes fn result into access log.
|
||||
// This is useful when the function is expensive to compute,
|
||||
// and we want to log it only when necessary.
|
||||
func Debugfn(ctx context.Context, fn func() any) {
|
||||
getLogger(ctx).Debugfn(fn)
|
||||
}
|
||||
|
||||
// Debugv writes v into access log with json content.
|
||||
func Debugv(ctx context.Context, v interface{}) {
|
||||
getLogger(ctx).Debugv(v)
|
||||
@@ -57,6 +64,13 @@ func Errorf(ctx context.Context, format string, v ...any) {
|
||||
getLogger(ctx).Errorf(fmt.Errorf(format, v...).Error())
|
||||
}
|
||||
|
||||
// Errorfn writes fn result into error log.
|
||||
// This is useful when the function is expensive to compute,
|
||||
// and we want to log it only when necessary.
|
||||
func Errorfn(ctx context.Context, fn func() any) {
|
||||
getLogger(ctx).Errorfn(fn)
|
||||
}
|
||||
|
||||
// Errorv writes v into error log with json content.
|
||||
// No call stack attached, because not elegant to pack the messages.
|
||||
func Errorv(ctx context.Context, v any) {
|
||||
@@ -83,6 +97,13 @@ func Infof(ctx context.Context, format string, v ...any) {
|
||||
getLogger(ctx).Infof(format, v...)
|
||||
}
|
||||
|
||||
// Infofn writes fn result into access log.
|
||||
// This is useful when the function is expensive to compute,
|
||||
// and we want to log it only when necessary.
|
||||
func Infofn(ctx context.Context, fn func() any) {
|
||||
getLogger(ctx).Infofn(fn)
|
||||
}
|
||||
|
||||
// Infov writes v into access log with json content.
|
||||
func Infov(ctx context.Context, v any) {
|
||||
getLogger(ctx).Infov(v)
|
||||
@@ -127,6 +148,13 @@ func Slowf(ctx context.Context, format string, v ...any) {
|
||||
getLogger(ctx).Slowf(format, v...)
|
||||
}
|
||||
|
||||
// Slowfn writes fn result into slow log.
|
||||
// This is useful when the function is expensive to compute,
|
||||
// and we want to log it only when necessary.
|
||||
func Slowfn(ctx context.Context, fn func() any) {
|
||||
getLogger(ctx).Slowfn(fn)
|
||||
}
|
||||
|
||||
// Slowv writes v into slow log with json content.
|
||||
func Slowv(ctx context.Context, v any) {
|
||||
getLogger(ctx).Slowv(v)
|
||||
|
||||
@@ -49,6 +49,15 @@ func TestErrorf(t *testing.T) {
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestErrorfn(t *testing.T) {
|
||||
buf := logtest.NewCollector(t)
|
||||
file, line := getFileLine()
|
||||
Errorfn(context.Background(), func() any {
|
||||
return fmt.Sprintf("foo %s", "bar")
|
||||
})
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestErrorv(t *testing.T) {
|
||||
buf := logtest.NewCollector(t)
|
||||
file, line := getFileLine()
|
||||
@@ -77,6 +86,15 @@ func TestInfof(t *testing.T) {
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestInfofn(t *testing.T) {
|
||||
buf := logtest.NewCollector(t)
|
||||
file, line := getFileLine()
|
||||
Infofn(context.Background(), func() any {
|
||||
return fmt.Sprintf("foo %s", "bar")
|
||||
})
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestInfov(t *testing.T) {
|
||||
buf := logtest.NewCollector(t)
|
||||
file, line := getFileLine()
|
||||
@@ -105,6 +123,15 @@ func TestDebugf(t *testing.T) {
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestDebugfn(t *testing.T) {
|
||||
buf := logtest.NewCollector(t)
|
||||
file, line := getFileLine()
|
||||
Debugfn(context.Background(), func() any {
|
||||
return fmt.Sprintf("foo %s", "bar")
|
||||
})
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestDebugv(t *testing.T) {
|
||||
buf := logtest.NewCollector(t)
|
||||
file, line := getFileLine()
|
||||
@@ -148,6 +175,15 @@ func TestSlowf(t *testing.T) {
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())
|
||||
}
|
||||
|
||||
func TestSlowfn(t *testing.T) {
|
||||
buf := logtest.NewCollector(t)
|
||||
file, line := getFileLine()
|
||||
Slowfn(context.Background(), func() any {
|
||||
return fmt.Sprintf("foo %s", "bar")
|
||||
})
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())
|
||||
}
|
||||
|
||||
func TestSlowv(t *testing.T) {
|
||||
buf := logtest.NewCollector(t)
|
||||
file, line := getFileLine()
|
||||
|
||||
@@ -1,47 +1,70 @@
|
||||
package logx
|
||||
|
||||
// A LogConf is a logging config.
|
||||
type LogConf struct {
|
||||
// ServiceName represents the service name.
|
||||
ServiceName string `json:",optional"`
|
||||
// Mode represents the logging mode, default is `console`.
|
||||
// console: log to console.
|
||||
// file: log to file.
|
||||
// volume: used in k8s, prepend the hostname to the log file name.
|
||||
Mode string `json:",default=console,options=[console,file,volume]"`
|
||||
// Encoding represents the encoding type, default is `json`.
|
||||
// json: json encoding.
|
||||
// plain: plain text encoding, typically used in development.
|
||||
Encoding string `json:",default=json,options=[json,plain]"`
|
||||
// TimeFormat represents the time format, default is `2006-01-02T15:04:05.000Z07:00`.
|
||||
TimeFormat string `json:",optional"`
|
||||
// Path represents the log file path, default is `logs`.
|
||||
Path string `json:",default=logs"`
|
||||
// Level represents the log level, default is `info`.
|
||||
Level string `json:",default=info,options=[debug,info,error,severe]"`
|
||||
// MaxContentLength represents the max content bytes, default is no limit.
|
||||
MaxContentLength uint32 `json:",optional"`
|
||||
// Compress represents whether to compress the log file, default is `false`.
|
||||
Compress bool `json:",optional"`
|
||||
// Stat represents whether to log statistics, default is `true`.
|
||||
Stat bool `json:",default=true"`
|
||||
// KeepDays represents how many days the log files will be kept. Default to keep all files.
|
||||
// Only take effect when Mode is `file` or `volume`, both work when Rotation is `daily` or `size`.
|
||||
KeepDays int `json:",optional"`
|
||||
// StackCooldownMillis represents the cooldown time for stack logging, default is 100ms.
|
||||
StackCooldownMillis int `json:",default=100"`
|
||||
// MaxBackups represents how many backup log files will be kept. 0 means all files will be kept forever.
|
||||
// Only take effect when RotationRuleType is `size`.
|
||||
// Even though `MaxBackups` sets 0, log files will still be removed
|
||||
// if the `KeepDays` limitation is reached.
|
||||
MaxBackups int `json:",default=0"`
|
||||
// MaxSize represents how much space the writing log file takes up. 0 means no limit. The unit is `MB`.
|
||||
// Only take effect when RotationRuleType is `size`
|
||||
MaxSize int `json:",default=0"`
|
||||
// Rotation represents the type of log rotation rule. Default is `daily`.
|
||||
// daily: daily rotation.
|
||||
// size: size limited rotation.
|
||||
Rotation string `json:",default=daily,options=[daily,size]"`
|
||||
// FileTimeFormat represents the time format for file name, default is `2006-01-02T15:04:05.000Z07:00`.
|
||||
FileTimeFormat string `json:",optional"`
|
||||
}
|
||||
type (
|
||||
// A LogConf is a logging config.
|
||||
LogConf struct {
|
||||
// ServiceName represents the service name.
|
||||
ServiceName string `json:",optional"`
|
||||
// Mode represents the logging mode, default is `console`.
|
||||
// console: log to console.
|
||||
// file: log to file.
|
||||
// volume: used in k8s, prepend the hostname to the log file name.
|
||||
Mode string `json:",default=console,options=[console,file,volume]"`
|
||||
// Encoding represents the encoding type, default is `json`.
|
||||
// json: json encoding.
|
||||
// plain: plain text encoding, typically used in development.
|
||||
Encoding string `json:",default=json,options=[json,plain]"`
|
||||
// TimeFormat represents the time format, default is `2006-01-02T15:04:05.000Z07:00`.
|
||||
TimeFormat string `json:",optional"`
|
||||
// Path represents the log file path, default is `logs`.
|
||||
Path string `json:",default=logs"`
|
||||
// Level represents the log level, default is `info`.
|
||||
Level string `json:",default=info,options=[debug,info,error,severe]"`
|
||||
// MaxContentLength represents the max content bytes, default is no limit.
|
||||
MaxContentLength uint32 `json:",optional"`
|
||||
// Compress represents whether to compress the log file, default is `false`.
|
||||
Compress bool `json:",optional"`
|
||||
// Stat represents whether to log statistics, default is `true`.
|
||||
Stat bool `json:",default=true"`
|
||||
// KeepDays represents how many days the log files will be kept. Default to keep all files.
|
||||
// Only take effect when Mode is `file` or `volume`, both work when Rotation is `daily` or `size`.
|
||||
KeepDays int `json:",optional"`
|
||||
// StackCooldownMillis represents the cooldown time for stack logging, default is 100ms.
|
||||
StackCooldownMillis int `json:",default=100"`
|
||||
// MaxBackups represents how many backup log files will be kept. 0 means all files will be kept forever.
|
||||
// Only take effect when RotationRuleType is `size`.
|
||||
// Even though `MaxBackups` sets 0, log files will still be removed
|
||||
// if the `KeepDays` limitation is reached.
|
||||
MaxBackups int `json:",default=0"`
|
||||
// MaxSize represents how much space the writing log file takes up. 0 means no limit. The unit is `MB`.
|
||||
// Only take effect when RotationRuleType is `size`
|
||||
MaxSize int `json:",default=0"`
|
||||
// Rotation represents the type of log rotation rule. Default is `daily`.
|
||||
// daily: daily rotation.
|
||||
// size: size limited rotation.
|
||||
Rotation string `json:",default=daily,options=[daily,size]"`
|
||||
// FileTimeFormat represents the time format for file name, default is `2006-01-02T15:04:05.000Z07:00`.
|
||||
FileTimeFormat string `json:",optional"`
|
||||
// FieldKeys represents the field keys.
|
||||
FieldKeys fieldKeyConf `json:",optional"`
|
||||
}
|
||||
|
||||
fieldKeyConf struct {
|
||||
// CallerKey represents the caller key.
|
||||
CallerKey string `json:",default=caller"`
|
||||
// ContentKey represents the content key.
|
||||
ContentKey string `json:",default=content"`
|
||||
// DurationKey represents the duration key.
|
||||
DurationKey string `json:",default=duration"`
|
||||
// LevelKey represents the level key.
|
||||
LevelKey string `json:",default=level"`
|
||||
// SpanKey represents the span key.
|
||||
SpanKey string `json:",default=span"`
|
||||
// TimestampKey represents the timestamp key.
|
||||
TimestampKey string `json:",default=@timestamp"`
|
||||
// TraceKey represents the trace key.
|
||||
TraceKey string `json:",default=trace"`
|
||||
// TruncatedKey represents the truncated key.
|
||||
TruncatedKey string `json:",default=truncated"`
|
||||
}
|
||||
)
|
||||
|
||||
@@ -7,12 +7,11 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
fieldsContextKey contextKey
|
||||
globalFields atomic.Value
|
||||
globalFieldsLock sync.Mutex
|
||||
)
|
||||
|
||||
type contextKey struct{}
|
||||
type fieldsKey struct{}
|
||||
|
||||
// AddGlobalFields adds global fields.
|
||||
func AddGlobalFields(fields ...LogField) {
|
||||
@@ -29,16 +28,16 @@ func AddGlobalFields(fields ...LogField) {
|
||||
|
||||
// ContextWithFields returns a new context with the given fields.
|
||||
func ContextWithFields(ctx context.Context, fields ...LogField) context.Context {
|
||||
if val := ctx.Value(fieldsContextKey); val != nil {
|
||||
if val := ctx.Value(fieldsKey{}); val != nil {
|
||||
if arr, ok := val.([]LogField); ok {
|
||||
allFields := make([]LogField, 0, len(arr)+len(fields))
|
||||
allFields = append(allFields, arr...)
|
||||
allFields = append(allFields, fields...)
|
||||
return context.WithValue(ctx, fieldsContextKey, allFields)
|
||||
return context.WithValue(ctx, fieldsKey{}, allFields)
|
||||
}
|
||||
}
|
||||
|
||||
return context.WithValue(ctx, fieldsContextKey, fields)
|
||||
return context.WithValue(ctx, fieldsKey{}, fields)
|
||||
}
|
||||
|
||||
// WithFields returns a new logger with the given fields.
|
||||
|
||||
@@ -34,7 +34,7 @@ func TestAddGlobalFields(t *testing.T) {
|
||||
|
||||
func TestContextWithFields(t *testing.T) {
|
||||
ctx := ContextWithFields(context.Background(), Field("a", 1), Field("b", 2))
|
||||
vals := ctx.Value(fieldsContextKey)
|
||||
vals := ctx.Value(fieldsKey{})
|
||||
assert.NotNil(t, vals)
|
||||
fields, ok := vals.([]LogField)
|
||||
assert.True(t, ok)
|
||||
@@ -43,7 +43,7 @@ func TestContextWithFields(t *testing.T) {
|
||||
|
||||
func TestWithFields(t *testing.T) {
|
||||
ctx := WithFields(context.Background(), Field("a", 1), Field("b", 2))
|
||||
vals := ctx.Value(fieldsContextKey)
|
||||
vals := ctx.Value(fieldsKey{})
|
||||
assert.NotNil(t, vals)
|
||||
fields, ok := vals.([]LogField)
|
||||
assert.True(t, ok)
|
||||
@@ -55,7 +55,7 @@ func TestWithFieldsAppend(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), dummyKey, "dummy")
|
||||
ctx = ContextWithFields(ctx, Field("a", 1), Field("b", 2))
|
||||
ctx = ContextWithFields(ctx, Field("c", 3), Field("d", 4))
|
||||
vals := ctx.Value(fieldsContextKey)
|
||||
vals := ctx.Value(fieldsKey{})
|
||||
assert.NotNil(t, vals)
|
||||
fields, ok := vals.([]LogField)
|
||||
assert.True(t, ok)
|
||||
@@ -80,8 +80,8 @@ func TestWithFieldsAppendCopy(t *testing.T) {
|
||||
ctxa := ContextWithFields(ctx, af)
|
||||
ctxb := ContextWithFields(ctx, bf)
|
||||
|
||||
assert.EqualValues(t, af, ctxa.Value(fieldsContextKey).([]LogField)[count])
|
||||
assert.EqualValues(t, bf, ctxb.Value(fieldsContextKey).([]LogField)[count])
|
||||
assert.EqualValues(t, af, ctxa.Value(fieldsKey{}).([]LogField)[count])
|
||||
assert.EqualValues(t, bf, ctxb.Value(fieldsKey{}).([]LogField)[count])
|
||||
}
|
||||
|
||||
func BenchmarkAtomicValue(b *testing.B) {
|
||||
|
||||
@@ -11,6 +11,8 @@ type Logger interface {
|
||||
Debug(...any)
|
||||
// Debugf logs a message at debug level.
|
||||
Debugf(string, ...any)
|
||||
// Debugfn logs a message at debug level.
|
||||
Debugfn(func() any)
|
||||
// Debugv logs a message at debug level.
|
||||
Debugv(any)
|
||||
// Debugw logs a message at debug level.
|
||||
@@ -19,6 +21,8 @@ type Logger interface {
|
||||
Error(...any)
|
||||
// Errorf logs a message at error level.
|
||||
Errorf(string, ...any)
|
||||
// Errorfn logs a message at error level.
|
||||
Errorfn(func() any)
|
||||
// Errorv logs a message at error level.
|
||||
Errorv(any)
|
||||
// Errorw logs a message at error level.
|
||||
@@ -27,6 +31,8 @@ type Logger interface {
|
||||
Info(...any)
|
||||
// Infof logs a message at info level.
|
||||
Infof(string, ...any)
|
||||
// Infofn logs a message at info level.
|
||||
Infofn(func() any)
|
||||
// Infov logs a message at info level.
|
||||
Infov(any)
|
||||
// Infow logs a message at info level.
|
||||
@@ -35,6 +41,8 @@ type Logger interface {
|
||||
Slow(...any)
|
||||
// Slowf logs a message at slow level.
|
||||
Slowf(string, ...any)
|
||||
// Slowfn logs a message at slow level.
|
||||
Slowfn(func() any)
|
||||
// Slowv logs a message at slow level.
|
||||
Slowv(any)
|
||||
// Sloww logs a message at slow level.
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/sysx"
|
||||
)
|
||||
@@ -100,6 +99,14 @@ func Debugf(format string, v ...any) {
|
||||
}
|
||||
}
|
||||
|
||||
// Debugfn writes function result into access log if debug level enabled.
|
||||
// This is useful when the function is expensive to call and debug level disabled.
|
||||
func Debugfn(fn func() any) {
|
||||
if shallLog(DebugLevel) {
|
||||
writeDebug(fn())
|
||||
}
|
||||
}
|
||||
|
||||
// Debugv writes v into access log with json content.
|
||||
func Debugv(v any) {
|
||||
if shallLog(DebugLevel) {
|
||||
@@ -139,6 +146,13 @@ func Errorf(format string, v ...any) {
|
||||
}
|
||||
}
|
||||
|
||||
// Errorfn writes function result into error log.
|
||||
func Errorfn(fn func() any) {
|
||||
if shallLog(ErrorLevel) {
|
||||
writeError(fn())
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorStack writes v along with call stack into error log.
|
||||
func ErrorStack(v ...any) {
|
||||
if shallLog(ErrorLevel) {
|
||||
@@ -172,39 +186,9 @@ func Errorw(msg string, fields ...LogField) {
|
||||
|
||||
// Field returns a LogField for the given key and value.
|
||||
func Field(key string, value any) LogField {
|
||||
switch val := value.(type) {
|
||||
case error:
|
||||
return LogField{Key: key, Value: encodeError(val)}
|
||||
case []error:
|
||||
var errs []string
|
||||
for _, err := range val {
|
||||
errs = append(errs, encodeError(err))
|
||||
}
|
||||
return LogField{Key: key, Value: errs}
|
||||
case time.Duration:
|
||||
return LogField{Key: key, Value: fmt.Sprint(val)}
|
||||
case []time.Duration:
|
||||
var durs []string
|
||||
for _, dur := range val {
|
||||
durs = append(durs, fmt.Sprint(dur))
|
||||
}
|
||||
return LogField{Key: key, Value: durs}
|
||||
case []time.Time:
|
||||
var times []string
|
||||
for _, t := range val {
|
||||
times = append(times, fmt.Sprint(t))
|
||||
}
|
||||
return LogField{Key: key, Value: times}
|
||||
case fmt.Stringer:
|
||||
return LogField{Key: key, Value: encodeStringer(val)}
|
||||
case []fmt.Stringer:
|
||||
var strs []string
|
||||
for _, str := range val {
|
||||
strs = append(strs, encodeStringer(str))
|
||||
}
|
||||
return LogField{Key: key, Value: strs}
|
||||
default:
|
||||
return LogField{Key: key, Value: val}
|
||||
return LogField{
|
||||
Key: key,
|
||||
Value: value,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,6 +206,14 @@ func Infof(format string, v ...any) {
|
||||
}
|
||||
}
|
||||
|
||||
// Infofn writes function result into access log.
|
||||
// This is useful when the function is expensive to call and info level disabled.
|
||||
func Infofn(fn func() any) {
|
||||
if shallLog(InfoLevel) {
|
||||
writeInfo(fn())
|
||||
}
|
||||
}
|
||||
|
||||
// Infov writes v into access log with json content.
|
||||
func Infov(v any) {
|
||||
if shallLog(InfoLevel) {
|
||||
@@ -284,7 +276,8 @@ func SetUp(c LogConf) (err error) {
|
||||
// Because multiple services in one process might call SetUp respectively.
|
||||
// Need to wait for the first caller to complete the execution.
|
||||
setupOnce.Do(func() {
|
||||
setupLogLevel(c)
|
||||
setupLogLevel(c.Level)
|
||||
setupFieldKeys(c.FieldKeys)
|
||||
|
||||
if !c.Stat {
|
||||
DisableStat()
|
||||
@@ -348,6 +341,14 @@ func Slowf(format string, v ...any) {
|
||||
}
|
||||
}
|
||||
|
||||
// Slowfn writes function result into slow log.
|
||||
// This is useful when the function is expensive to call and slow level disabled.
|
||||
func Slowfn(fn func() any) {
|
||||
if shallLog(ErrorLevel) {
|
||||
writeSlow(fn())
|
||||
}
|
||||
}
|
||||
|
||||
// Slowv writes v into slow log with json content.
|
||||
func Slowv(v any) {
|
||||
if shallLog(ErrorLevel) {
|
||||
@@ -480,8 +481,35 @@ func handleOptions(opts []LogOption) {
|
||||
}
|
||||
}
|
||||
|
||||
func setupLogLevel(c LogConf) {
|
||||
switch c.Level {
|
||||
func setupFieldKeys(c fieldKeyConf) {
|
||||
if len(c.CallerKey) > 0 {
|
||||
callerKey = c.CallerKey
|
||||
}
|
||||
if len(c.ContentKey) > 0 {
|
||||
contentKey = c.ContentKey
|
||||
}
|
||||
if len(c.DurationKey) > 0 {
|
||||
durationKey = c.DurationKey
|
||||
}
|
||||
if len(c.LevelKey) > 0 {
|
||||
levelKey = c.LevelKey
|
||||
}
|
||||
if len(c.SpanKey) > 0 {
|
||||
spanKey = c.SpanKey
|
||||
}
|
||||
if len(c.TimestampKey) > 0 {
|
||||
timestampKey = c.TimestampKey
|
||||
}
|
||||
if len(c.TraceKey) > 0 {
|
||||
traceKey = c.TraceKey
|
||||
}
|
||||
if len(c.TruncatedKey) > 0 {
|
||||
truncatedKey = c.TruncatedKey
|
||||
}
|
||||
}
|
||||
|
||||
func setupLogLevel(level string) {
|
||||
switch level {
|
||||
case levelDebug:
|
||||
SetLevel(DebugLevel)
|
||||
case levelInfo:
|
||||
@@ -529,7 +557,7 @@ func shallLogStat() bool {
|
||||
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||
// The caller should check shallLog before calling this function.
|
||||
func writeDebug(val any, fields ...LogField) {
|
||||
getWriter().Debug(val, addCaller(fields...)...)
|
||||
getWriter().Debug(val, mergeGlobalFields(addCaller(fields...))...)
|
||||
}
|
||||
|
||||
// writeError writes v into the error log.
|
||||
@@ -537,7 +565,7 @@ func writeDebug(val any, fields ...LogField) {
|
||||
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||
// The caller should check shallLog before calling this function.
|
||||
func writeError(val any, fields ...LogField) {
|
||||
getWriter().Error(val, addCaller(fields...)...)
|
||||
getWriter().Error(val, mergeGlobalFields(addCaller(fields...))...)
|
||||
}
|
||||
|
||||
// writeInfo writes v into info log.
|
||||
@@ -545,7 +573,7 @@ func writeError(val any, fields ...LogField) {
|
||||
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||
// The caller should check shallLog before calling this function.
|
||||
func writeInfo(val any, fields ...LogField) {
|
||||
getWriter().Info(val, addCaller(fields...)...)
|
||||
getWriter().Info(val, mergeGlobalFields(addCaller(fields...))...)
|
||||
}
|
||||
|
||||
// writeSevere writes v into severe log.
|
||||
@@ -561,7 +589,7 @@ func writeSevere(msg string) {
|
||||
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||
// The caller should check shallLog before calling this function.
|
||||
func writeSlow(val any, fields ...LogField) {
|
||||
getWriter().Slow(val, addCaller(fields...)...)
|
||||
getWriter().Slow(val, mergeGlobalFields(addCaller(fields...))...)
|
||||
}
|
||||
|
||||
// writeStack writes v into stack log.
|
||||
@@ -577,5 +605,5 @@ func writeStack(msg string) {
|
||||
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||
// The caller should check shallLog before calling this function.
|
||||
func writeStat(msg string) {
|
||||
getWriter().Stat(msg, addCaller()...)
|
||||
getWriter().Stat(msg, mergeGlobalFields(addCaller())...)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package logx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -16,6 +17,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/sdk/trace"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -244,7 +247,33 @@ func TestStructedLogDebugf(t *testing.T) {
|
||||
defer writer.Store(old)
|
||||
|
||||
doTestStructedLog(t, levelDebug, w, func(v ...any) {
|
||||
Debugf(fmt.Sprint(v...))
|
||||
Debugf("%s", fmt.Sprint(v...))
|
||||
})
|
||||
}
|
||||
|
||||
func TestStructedLogDebugfn(t *testing.T) {
|
||||
t.Run("debugfn with output", func(t *testing.T) {
|
||||
w := new(mockWriter)
|
||||
old := writer.Swap(w)
|
||||
defer writer.Store(old)
|
||||
|
||||
doTestStructedLog(t, levelDebug, w, func(v ...any) {
|
||||
Debugfn(func() any {
|
||||
return fmt.Sprint(v...)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("debugfn without output", func(t *testing.T) {
|
||||
w := new(mockWriter)
|
||||
old := writer.Swap(w)
|
||||
defer writer.Store(old)
|
||||
|
||||
doTestStructedLogEmpty(t, w, InfoLevel, func(v ...any) {
|
||||
Debugfn(func() any {
|
||||
return fmt.Sprint(v...)
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -288,6 +317,32 @@ func TestStructedLogErrorf(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestStructedLogErrorfn(t *testing.T) {
|
||||
t.Run("errorfn with output", func(t *testing.T) {
|
||||
w := new(mockWriter)
|
||||
old := writer.Swap(w)
|
||||
defer writer.Store(old)
|
||||
|
||||
doTestStructedLog(t, levelError, w, func(v ...any) {
|
||||
Errorfn(func() any {
|
||||
return fmt.Sprint(v...)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("errorfn without output", func(t *testing.T) {
|
||||
w := new(mockWriter)
|
||||
old := writer.Swap(w)
|
||||
defer writer.Store(old)
|
||||
|
||||
doTestStructedLogEmpty(t, w, SevereLevel, func(v ...any) {
|
||||
Errorfn(func() any {
|
||||
return fmt.Sprint(v...)
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestStructedLogErrorv(t *testing.T) {
|
||||
w := new(mockWriter)
|
||||
old := writer.Swap(w)
|
||||
@@ -328,6 +383,32 @@ func TestStructedLogInfof(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestStructedInfofn(t *testing.T) {
|
||||
t.Run("infofn with output", func(t *testing.T) {
|
||||
w := new(mockWriter)
|
||||
old := writer.Swap(w)
|
||||
defer writer.Store(old)
|
||||
|
||||
doTestStructedLog(t, levelInfo, w, func(v ...any) {
|
||||
Infofn(func() any {
|
||||
return fmt.Sprint(v...)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("infofn without output", func(t *testing.T) {
|
||||
w := new(mockWriter)
|
||||
old := writer.Swap(w)
|
||||
defer writer.Store(old)
|
||||
|
||||
doTestStructedLogEmpty(t, w, ErrorLevel, func(v ...any) {
|
||||
Infofn(func() any {
|
||||
return fmt.Sprint(v...)
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestStructedLogInfov(t *testing.T) {
|
||||
w := new(mockWriter)
|
||||
old := writer.Swap(w)
|
||||
@@ -451,6 +532,17 @@ func TestStructedLogInfoConsoleText(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestInfofnWithErrorLevel(t *testing.T) {
|
||||
called := false
|
||||
SetLevel(ErrorLevel)
|
||||
defer SetLevel(DebugLevel)
|
||||
Infofn(func() any {
|
||||
called = true
|
||||
return "info log"
|
||||
})
|
||||
assert.False(t, called)
|
||||
}
|
||||
|
||||
func TestStructedLogSlow(t *testing.T) {
|
||||
w := new(mockWriter)
|
||||
old := writer.Swap(w)
|
||||
@@ -467,7 +559,33 @@ func TestStructedLogSlowf(t *testing.T) {
|
||||
defer writer.Store(old)
|
||||
|
||||
doTestStructedLog(t, levelSlow, w, func(v ...any) {
|
||||
Slowf(fmt.Sprint(v...))
|
||||
Slowf("%s", fmt.Sprint(v...))
|
||||
})
|
||||
}
|
||||
|
||||
func TestStructedLogSlowfn(t *testing.T) {
|
||||
t.Run("slowfn with output", func(t *testing.T) {
|
||||
w := new(mockWriter)
|
||||
old := writer.Swap(w)
|
||||
defer writer.Store(old)
|
||||
|
||||
doTestStructedLog(t, levelSlow, w, func(v ...any) {
|
||||
Slowfn(func() any {
|
||||
return fmt.Sprint(v...)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("slowfn without output", func(t *testing.T) {
|
||||
w := new(mockWriter)
|
||||
old := writer.Swap(w)
|
||||
defer writer.Store(old)
|
||||
|
||||
doTestStructedLogEmpty(t, w, SevereLevel, func(v ...any) {
|
||||
Slowfn(func() any {
|
||||
return fmt.Sprint(v...)
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -507,7 +625,7 @@ func TestStructedLogStatf(t *testing.T) {
|
||||
defer writer.Store(old)
|
||||
|
||||
doTestStructedLog(t, levelStat, w, func(v ...any) {
|
||||
Statf(fmt.Sprint(v...))
|
||||
Statf("%s", fmt.Sprint(v...))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -527,7 +645,7 @@ func TestStructedLogSeveref(t *testing.T) {
|
||||
defer writer.Store(old)
|
||||
|
||||
doTestStructedLog(t, levelSevere, w, func(v ...any) {
|
||||
Severef(fmt.Sprint(v...))
|
||||
Severef("%s", fmt.Sprint(v...))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -661,15 +779,9 @@ func TestSetup(t *testing.T) {
|
||||
MaxBackups: 3,
|
||||
MaxSize: 1024 * 1024,
|
||||
}))
|
||||
setupLogLevel(LogConf{
|
||||
Level: levelInfo,
|
||||
})
|
||||
setupLogLevel(LogConf{
|
||||
Level: levelError,
|
||||
})
|
||||
setupLogLevel(LogConf{
|
||||
Level: levelSevere,
|
||||
})
|
||||
setupLogLevel(levelInfo)
|
||||
setupLogLevel(levelError)
|
||||
setupLogLevel(levelSevere)
|
||||
_, err := createOutput("")
|
||||
assert.NotNil(t, err)
|
||||
Disable()
|
||||
@@ -741,6 +853,95 @@ func TestWithKeepDays(t *testing.T) {
|
||||
assert.Equal(t, 1, opt.keepDays)
|
||||
}
|
||||
|
||||
func TestWithField_LogLevel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
level uint32
|
||||
fn func(string, ...LogField)
|
||||
count int32
|
||||
}{
|
||||
{
|
||||
name: "debug/info",
|
||||
level: DebugLevel,
|
||||
fn: Infow,
|
||||
count: 1,
|
||||
},
|
||||
{
|
||||
name: "info/error",
|
||||
level: InfoLevel,
|
||||
fn: Errorw,
|
||||
count: 1,
|
||||
},
|
||||
{
|
||||
name: "info/info",
|
||||
level: InfoLevel,
|
||||
fn: Infow,
|
||||
count: 1,
|
||||
},
|
||||
{
|
||||
name: "info/severe",
|
||||
level: InfoLevel,
|
||||
fn: Errorw,
|
||||
count: 1,
|
||||
},
|
||||
{
|
||||
name: "error/info",
|
||||
level: ErrorLevel,
|
||||
fn: Infow,
|
||||
count: 0,
|
||||
},
|
||||
{
|
||||
name: "error/debug",
|
||||
level: ErrorLevel,
|
||||
fn: Debugw,
|
||||
count: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
olevel := atomic.LoadUint32(&logLevel)
|
||||
SetLevel(tt.level)
|
||||
defer SetLevel(olevel)
|
||||
|
||||
var val countingStringer
|
||||
tt.fn("hello there", Field("foo", &val))
|
||||
assert.Equal(t, tt.count, val.Count())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithField_LogLevelWithContext(t *testing.T) {
|
||||
t.Run("context more than once with info/info", func(t *testing.T) {
|
||||
olevel := atomic.LoadUint32(&logLevel)
|
||||
SetLevel(InfoLevel)
|
||||
defer SetLevel(olevel)
|
||||
|
||||
var val countingStringer
|
||||
ctx := ContextWithFields(context.Background(), Field("foo", &val))
|
||||
logger := WithContext(ctx)
|
||||
logger.Info("hello there")
|
||||
logger.Info("hello there")
|
||||
logger.Info("hello there")
|
||||
assert.True(t, val.Count() > 0)
|
||||
})
|
||||
|
||||
t.Run("context more than once with error/info", func(t *testing.T) {
|
||||
olevel := atomic.LoadUint32(&logLevel)
|
||||
SetLevel(ErrorLevel)
|
||||
defer SetLevel(olevel)
|
||||
|
||||
var val countingStringer
|
||||
ctx := ContextWithFields(context.Background(), Field("foo", &val))
|
||||
logger := WithContext(ctx)
|
||||
logger.Info("hello there")
|
||||
logger.Info("hello there")
|
||||
logger.Info("hello there")
|
||||
assert.Equal(t, int32(0), val.Count())
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkCopyByteSliceAppend(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
var buf []byte
|
||||
@@ -847,6 +1048,16 @@ func doTestStructedLogConsole(t *testing.T, w *mockWriter, write func(...any)) {
|
||||
assert.True(t, strings.Contains(w.String(), message))
|
||||
}
|
||||
|
||||
func doTestStructedLogEmpty(t *testing.T, w *mockWriter, level uint32, write func(...any)) {
|
||||
olevel := atomic.LoadUint32(&logLevel)
|
||||
SetLevel(level)
|
||||
defer SetLevel(olevel)
|
||||
|
||||
const message = "hello there"
|
||||
write(message)
|
||||
assert.Empty(t, w.String())
|
||||
}
|
||||
|
||||
func testSetLevelTwiceWithMode(t *testing.T, mode string, w *mockWriter) {
|
||||
writer.Store(nil)
|
||||
SetUp(LogConf{
|
||||
@@ -929,3 +1140,79 @@ type panicStringer struct {
|
||||
func (s panicStringer) String() string {
|
||||
panic("panic")
|
||||
}
|
||||
|
||||
type countingStringer struct {
|
||||
count int32
|
||||
}
|
||||
|
||||
func (s *countingStringer) Count() int32 {
|
||||
return atomic.LoadInt32(&s.count)
|
||||
}
|
||||
|
||||
func (s *countingStringer) String() string {
|
||||
atomic.AddInt32(&s.count, 1)
|
||||
return "countingStringer"
|
||||
}
|
||||
|
||||
func TestLogKey(t *testing.T) {
|
||||
setupOnce = sync.Once{}
|
||||
MustSetup(LogConf{
|
||||
ServiceName: "any",
|
||||
Mode: "console",
|
||||
Encoding: "json",
|
||||
TimeFormat: timeFormat,
|
||||
FieldKeys: fieldKeyConf{
|
||||
CallerKey: "_caller",
|
||||
ContentKey: "_content",
|
||||
DurationKey: "_duration",
|
||||
LevelKey: "_level",
|
||||
SpanKey: "_span",
|
||||
TimestampKey: "_timestamp",
|
||||
TraceKey: "_trace",
|
||||
TruncatedKey: "_truncated",
|
||||
},
|
||||
})
|
||||
|
||||
t.Cleanup(func() {
|
||||
setupFieldKeys(fieldKeyConf{
|
||||
CallerKey: defaultCallerKey,
|
||||
ContentKey: defaultContentKey,
|
||||
DurationKey: defaultDurationKey,
|
||||
LevelKey: defaultLevelKey,
|
||||
SpanKey: defaultSpanKey,
|
||||
TimestampKey: defaultTimestampKey,
|
||||
TraceKey: defaultTraceKey,
|
||||
TruncatedKey: defaultTruncatedKey,
|
||||
})
|
||||
})
|
||||
|
||||
const message = "hello there"
|
||||
w := new(mockWriter)
|
||||
old := writer.Swap(w)
|
||||
defer writer.Store(old)
|
||||
|
||||
otp := otel.GetTracerProvider()
|
||||
tp := trace.NewTracerProvider(trace.WithSampler(trace.AlwaysSample()))
|
||||
otel.SetTracerProvider(tp)
|
||||
defer otel.SetTracerProvider(otp)
|
||||
|
||||
ctx, span := tp.Tracer("trace-id").Start(context.Background(), "span-id")
|
||||
defer span.End()
|
||||
|
||||
WithContext(ctx).WithDuration(time.Second).Info(message)
|
||||
now := time.Now()
|
||||
|
||||
var m map[string]string
|
||||
if err := json.Unmarshal([]byte(w.String()), &m); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.Equal(t, "info", m["_level"])
|
||||
assert.Equal(t, message, m["_content"])
|
||||
assert.Equal(t, "1000.0ms", m["_duration"])
|
||||
assert.Regexp(t, `logx/logs_test.go:\d+`, m["_caller"])
|
||||
assert.NotEmpty(t, m["_trace"])
|
||||
assert.NotEmpty(t, m["_span"])
|
||||
parsedTime, err := time.Parse(timeFormat, m["_timestamp"])
|
||||
assert.True(t, err == nil)
|
||||
assert.Equal(t, now.Minute(), parsedTime.Minute())
|
||||
}
|
||||
|
||||
@@ -52,6 +52,12 @@ func (l *richLogger) Debugf(format string, v ...any) {
|
||||
}
|
||||
}
|
||||
|
||||
func (l *richLogger) Debugfn(fn func() any) {
|
||||
if shallLog(DebugLevel) {
|
||||
l.debug(fn())
|
||||
}
|
||||
}
|
||||
|
||||
func (l *richLogger) Debugv(v any) {
|
||||
if shallLog(DebugLevel) {
|
||||
l.debug(v)
|
||||
@@ -76,6 +82,12 @@ func (l *richLogger) Errorf(format string, v ...any) {
|
||||
}
|
||||
}
|
||||
|
||||
func (l *richLogger) Errorfn(fn func() any) {
|
||||
if shallLog(ErrorLevel) {
|
||||
l.err(fn())
|
||||
}
|
||||
}
|
||||
|
||||
func (l *richLogger) Errorv(v any) {
|
||||
if shallLog(ErrorLevel) {
|
||||
l.err(v)
|
||||
@@ -100,6 +112,12 @@ func (l *richLogger) Infof(format string, v ...any) {
|
||||
}
|
||||
}
|
||||
|
||||
func (l *richLogger) Infofn(fn func() any) {
|
||||
if shallLog(InfoLevel) {
|
||||
l.info(fn())
|
||||
}
|
||||
}
|
||||
|
||||
func (l *richLogger) Infov(v any) {
|
||||
if shallLog(InfoLevel) {
|
||||
l.info(v)
|
||||
@@ -124,6 +142,12 @@ func (l *richLogger) Slowf(format string, v ...any) {
|
||||
}
|
||||
}
|
||||
|
||||
func (l *richLogger) Slowfn(fn func() any) {
|
||||
if shallLog(ErrorLevel) {
|
||||
l.slow(fn())
|
||||
}
|
||||
}
|
||||
|
||||
func (l *richLogger) Slowv(v any) {
|
||||
if shallLog(ErrorLevel) {
|
||||
l.slow(v)
|
||||
@@ -182,7 +206,9 @@ func (l *richLogger) WithFields(fields ...LogField) Logger {
|
||||
|
||||
func (l *richLogger) buildFields(fields ...LogField) []LogField {
|
||||
fields = append(l.fields, fields...)
|
||||
// caller field should always appear together with global fields
|
||||
fields = append(fields, Field(callerKey, getCaller(callerDepth+l.callerSkip)))
|
||||
fields = mergeGlobalFields(fields)
|
||||
|
||||
if l.ctx == nil {
|
||||
return fields
|
||||
@@ -198,7 +224,7 @@ func (l *richLogger) buildFields(fields ...LogField) []LogField {
|
||||
fields = append(fields, Field(spanKey, spanID))
|
||||
}
|
||||
|
||||
val := l.ctx.Value(fieldsContextKey)
|
||||
val := l.ctx.Value(fieldsKey{})
|
||||
if val != nil {
|
||||
if arr, ok := val.([]LogField); ok {
|
||||
fields = append(fields, arr...)
|
||||
|
||||
@@ -63,6 +63,11 @@ func TestTraceDebug(t *testing.T) {
|
||||
l.WithDuration(time.Second).Debugf(testlog)
|
||||
validate(t, w.String(), true, true)
|
||||
w.Reset()
|
||||
l.WithDuration(time.Second).Debugfn(func() any {
|
||||
return testlog
|
||||
})
|
||||
validate(t, w.String(), true, true)
|
||||
w.Reset()
|
||||
l.WithDuration(time.Second).Debugv(testlog)
|
||||
validate(t, w.String(), true, true)
|
||||
w.Reset()
|
||||
@@ -103,6 +108,11 @@ func TestTraceError(t *testing.T) {
|
||||
l.WithDuration(time.Second).Errorf(testlog)
|
||||
validate(t, w.String(), true, true)
|
||||
w.Reset()
|
||||
l.WithDuration(time.Second).Errorfn(func() any {
|
||||
return testlog
|
||||
})
|
||||
validate(t, w.String(), true, true)
|
||||
w.Reset()
|
||||
l.WithDuration(time.Second).Errorv(testlog)
|
||||
validate(t, w.String(), true, true)
|
||||
w.Reset()
|
||||
@@ -140,6 +150,11 @@ func TestTraceInfo(t *testing.T) {
|
||||
l.WithDuration(time.Second).Infof(testlog)
|
||||
validate(t, w.String(), true, true)
|
||||
w.Reset()
|
||||
l.WithDuration(time.Second).Infofn(func() any {
|
||||
return testlog
|
||||
})
|
||||
validate(t, w.String(), true, true)
|
||||
w.Reset()
|
||||
l.WithDuration(time.Second).Infov(testlog)
|
||||
validate(t, w.String(), true, true)
|
||||
w.Reset()
|
||||
@@ -213,6 +228,11 @@ func TestTraceSlow(t *testing.T) {
|
||||
l.WithDuration(time.Second).Slowf(testlog)
|
||||
validate(t, w.String(), true, true)
|
||||
w.Reset()
|
||||
l.WithDuration(time.Second).Slowfn(func() any {
|
||||
return testlog
|
||||
})
|
||||
validate(t, w.String(), true, true)
|
||||
w.Reset()
|
||||
l.WithDuration(time.Second).Slowv(testlog)
|
||||
validate(t, w.String(), true, true)
|
||||
w.Reset()
|
||||
@@ -403,3 +423,49 @@ type mockValue struct {
|
||||
Foo string `json:"foo"`
|
||||
Content any `json:"content"`
|
||||
}
|
||||
|
||||
type testJson struct {
|
||||
Name string `json:"name"`
|
||||
Age int `json:"age"`
|
||||
Score float64 `json:"score"`
|
||||
}
|
||||
|
||||
func (t testJson) MarshalJSON() ([]byte, error) {
|
||||
type testJsonImpl testJson
|
||||
return json.Marshal(testJsonImpl(t))
|
||||
}
|
||||
|
||||
func (t testJson) String() string {
|
||||
return fmt.Sprintf("%s %d %f", t.Name, t.Age, t.Score)
|
||||
}
|
||||
|
||||
func TestLogWithJson(t *testing.T) {
|
||||
w := new(mockWriter)
|
||||
old := writer.Swap(w)
|
||||
writer.lock.RLock()
|
||||
defer func() {
|
||||
writer.lock.RUnlock()
|
||||
writer.Store(old)
|
||||
}()
|
||||
|
||||
l := WithContext(context.Background()).WithFields(Field("bar", testJson{
|
||||
Name: "foo",
|
||||
Age: 1,
|
||||
Score: 1.0,
|
||||
}))
|
||||
l.Info(testlog)
|
||||
|
||||
type mockValue2 struct {
|
||||
mockValue
|
||||
Bar testJson `json:"bar"`
|
||||
}
|
||||
|
||||
var val mockValue2
|
||||
err := json.Unmarshal([]byte(w.String()), &val)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, testlog, val.Content)
|
||||
assert.Equal(t, "foo", val.Bar.Name)
|
||||
assert.Equal(t, 1, val.Bar.Age)
|
||||
assert.Equal(t, 1.0, val.Bar.Score)
|
||||
}
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
dateFormat = "2006-01-02"
|
||||
hoursPerDay = 24
|
||||
bufferSize = 100
|
||||
defaultDirMode = 0o755
|
||||
@@ -67,7 +66,7 @@ type (
|
||||
gzip bool
|
||||
}
|
||||
|
||||
// SizeLimitRotateRule a rotation rule that make the log file rotated base on size
|
||||
// SizeLimitRotateRule a rotation rule that makes the log file rotated based on size
|
||||
SizeLimitRotateRule struct {
|
||||
DailyRotateRule
|
||||
maxSize int64
|
||||
@@ -116,7 +115,7 @@ func (r *DailyRotateRule) OutdatedFiles() []string {
|
||||
}
|
||||
|
||||
var buf strings.Builder
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay*r.days)).Format(dateFormat)
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay*r.days)).Format(time.DateOnly)
|
||||
buf.WriteString(r.filename)
|
||||
buf.WriteString(r.delimiter)
|
||||
buf.WriteString(boundary)
|
||||
@@ -212,7 +211,7 @@ func (r *SizeLimitRotateRule) OutdatedFiles() []string {
|
||||
}
|
||||
}
|
||||
|
||||
var result []string
|
||||
result := make([]string, 0, len(outdated))
|
||||
for k := range outdated {
|
||||
result = append(result, k)
|
||||
}
|
||||
@@ -425,7 +424,7 @@ func compressLogFile(file string) {
|
||||
}
|
||||
|
||||
func getNowDate() string {
|
||||
return time.Now().Format(dateFormat)
|
||||
return time.Now().Format(time.DateOnly)
|
||||
}
|
||||
|
||||
func getNowDateInRFC3339Format() string {
|
||||
|
||||
@@ -52,7 +52,7 @@ func TestDailyRotateRuleOutdatedFiles(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("temp files", func(t *testing.T) {
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
_ = f1.Close()
|
||||
@@ -73,7 +73,7 @@ func TestDailyRotateRuleOutdatedFiles(t *testing.T) {
|
||||
|
||||
func TestDailyRotateRuleShallRotate(t *testing.T) {
|
||||
var rule DailyRotateRule
|
||||
rule.rotatedTime = time.Now().Add(time.Hour * 24).Format(dateFormat)
|
||||
rule.rotatedTime = time.Now().Add(time.Hour * 24).Format(time.DateOnly)
|
||||
assert.True(t, rule.ShallRotate(0))
|
||||
}
|
||||
|
||||
@@ -117,12 +117,12 @@ func TestSizeLimitRotateRuleOutdatedFiles(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("temp files", func(t *testing.T) {
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
|
||||
assert.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
@@ -144,12 +144,12 @@ func TestSizeLimitRotateRuleOutdatedFiles(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("no backups", func(t *testing.T) {
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
|
||||
assert.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
@@ -319,7 +319,7 @@ func TestRotateLoggerWrite(t *testing.T) {
|
||||
}
|
||||
// the following write calls cannot be changed to Write, because of DATA RACE.
|
||||
logger.write([]byte(`foo`))
|
||||
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(dateFormat)
|
||||
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(time.DateOnly)
|
||||
logger.write([]byte(`bar`))
|
||||
logger.Close()
|
||||
logger.write([]byte(`baz`))
|
||||
@@ -447,7 +447,7 @@ func TestRotateLoggerWithSizeLimitRotateRuleWrite(t *testing.T) {
|
||||
}
|
||||
// the following write calls cannot be changed to Write, because of DATA RACE.
|
||||
logger.write([]byte(`foo`))
|
||||
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(dateFormat)
|
||||
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(time.DateOnly)
|
||||
logger.write([]byte(`bar`))
|
||||
logger.Close()
|
||||
logger.write([]byte(`baz`))
|
||||
|
||||
21
core/logx/sensitive.go
Normal file
21
core/logx/sensitive.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package logx
|
||||
|
||||
// Sensitive is an interface that defines a method for masking sensitive information in logs.
|
||||
// It is typically implemented by types that contain sensitive data,
|
||||
// such as passwords or personal information.
|
||||
// Infov, Errorv, Debugv, and Slowv methods will call this method to mask sensitive data.
|
||||
// The values in LogField will also be masked if they implement the Sensitive interface.
|
||||
type Sensitive interface {
|
||||
// MaskSensitive masks sensitive information in the log.
|
||||
MaskSensitive() any
|
||||
}
|
||||
|
||||
// maskSensitive returns the value returned by MaskSensitive method,
|
||||
// if the value implements Sensitive interface.
|
||||
func maskSensitive(v any) any {
|
||||
if s, ok := v.(Sensitive); ok {
|
||||
return s.MaskSensitive()
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
50
core/logx/sensitive_test.go
Normal file
50
core/logx/sensitive_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package logx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const maskedContent = "******"
|
||||
|
||||
type User struct {
|
||||
Name string
|
||||
Pass string
|
||||
}
|
||||
|
||||
func (u User) MaskSensitive() any {
|
||||
return User{
|
||||
Name: u.Name,
|
||||
Pass: maskedContent,
|
||||
}
|
||||
}
|
||||
|
||||
type NonSensitiveUser struct {
|
||||
Name string
|
||||
Pass string
|
||||
}
|
||||
|
||||
func TestMaskSensitive(t *testing.T) {
|
||||
t.Run("sensitive", func(t *testing.T) {
|
||||
user := User{
|
||||
Name: "kevin",
|
||||
Pass: "123",
|
||||
}
|
||||
|
||||
mu := maskSensitive(user)
|
||||
assert.Equal(t, user.Name, mu.(User).Name)
|
||||
assert.Equal(t, maskedContent, mu.(User).Pass)
|
||||
})
|
||||
|
||||
t.Run("non-sensitive", func(t *testing.T) {
|
||||
user := NonSensitiveUser{
|
||||
Name: "kevin",
|
||||
Pass: "123",
|
||||
}
|
||||
|
||||
mu := maskSensitive(user)
|
||||
assert.Equal(t, user.Name, mu.(NonSensitiveUser).Name)
|
||||
assert.Equal(t, user.Pass, mu.(NonSensitiveUser).Pass)
|
||||
})
|
||||
}
|
||||
@@ -53,14 +53,14 @@ const (
|
||||
)
|
||||
|
||||
const (
|
||||
callerKey = "caller"
|
||||
contentKey = "content"
|
||||
durationKey = "duration"
|
||||
levelKey = "level"
|
||||
spanKey = "span"
|
||||
timestampKey = "@timestamp"
|
||||
traceKey = "trace"
|
||||
truncatedKey = "truncated"
|
||||
defaultCallerKey = "caller"
|
||||
defaultContentKey = "content"
|
||||
defaultDurationKey = "duration"
|
||||
defaultLevelKey = "level"
|
||||
defaultSpanKey = "span"
|
||||
defaultTimestampKey = "@timestamp"
|
||||
defaultTraceKey = "trace"
|
||||
defaultTruncatedKey = "truncated"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -73,3 +73,14 @@ var (
|
||||
|
||||
truncatedField = Field(truncatedKey, true)
|
||||
)
|
||||
|
||||
var (
|
||||
callerKey = defaultCallerKey
|
||||
contentKey = defaultContentKey
|
||||
durationKey = defaultDurationKey
|
||||
levelKey = defaultLevelKey
|
||||
spanKey = defaultSpanKey
|
||||
timestampKey = defaultTimestampKey
|
||||
traceKey = defaultTraceKey
|
||||
truncatedKey = defaultTruncatedKey
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
fatihcolor "github.com/fatih/color"
|
||||
"github.com/zeromicro/go-zero/core/color"
|
||||
@@ -17,15 +18,27 @@ import (
|
||||
)
|
||||
|
||||
type (
|
||||
// Writer is the interface for writing logs.
|
||||
// It's designed to let users customize their own log writer,
|
||||
// such as writing logs to a kafka, a database, or using third-party loggers.
|
||||
Writer interface {
|
||||
// Alert sends an alert message, if your writer implemented alerting functionality.
|
||||
Alert(v any)
|
||||
// Close closes the writer.
|
||||
Close() error
|
||||
// Debug logs a message at debug level.
|
||||
Debug(v any, fields ...LogField)
|
||||
// Error logs a message at error level.
|
||||
Error(v any, fields ...LogField)
|
||||
// Info logs a message at info level.
|
||||
Info(v any, fields ...LogField)
|
||||
// Severe logs a message at severe level.
|
||||
Severe(v any)
|
||||
// Slow logs a message at slow level.
|
||||
Slow(v any, fields ...LogField)
|
||||
// Stack logs a message at error level.
|
||||
Stack(v any)
|
||||
// Stat logs a message at stat level.
|
||||
Stat(v any, fields ...LogField)
|
||||
}
|
||||
|
||||
@@ -199,7 +212,6 @@ func newFileWriter(c LogConf) (Writer, error) {
|
||||
statFile := path.Join(c.Path, statFilename)
|
||||
|
||||
handleOptions(opts)
|
||||
setupLogLevel(c)
|
||||
|
||||
if infoLog, err = createOutput(accessFile); err != nil {
|
||||
return nil, err
|
||||
@@ -324,20 +336,6 @@ func buildPlainFields(fields logEntry) []string {
|
||||
return items
|
||||
}
|
||||
|
||||
func combineGlobalFields(fields []LogField) []LogField {
|
||||
globals := globalFields.Load()
|
||||
if globals == nil {
|
||||
return fields
|
||||
}
|
||||
|
||||
gf := globals.([]LogField)
|
||||
ret := make([]LogField, 0, len(gf)+len(fields))
|
||||
ret = append(ret, gf...)
|
||||
ret = append(ret, fields...)
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func marshalJson(t interface{}) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
encoder := json.NewEncoder(&buf)
|
||||
@@ -352,21 +350,40 @@ func marshalJson(t interface{}) ([]byte, error) {
|
||||
return buf.Bytes(), err
|
||||
}
|
||||
|
||||
func mergeGlobalFields(fields []LogField) []LogField {
|
||||
globals := globalFields.Load()
|
||||
if globals == nil {
|
||||
return fields
|
||||
}
|
||||
|
||||
gf := globals.([]LogField)
|
||||
ret := make([]LogField, 0, len(gf)+len(fields))
|
||||
ret = append(ret, gf...)
|
||||
ret = append(ret, fields...)
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func output(writer io.Writer, level string, val any, fields ...LogField) {
|
||||
// only truncate string content, don't know how to truncate the values of other types.
|
||||
if v, ok := val.(string); ok {
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
// only truncate string content, don't know how to truncate the values of other types.
|
||||
maxLen := atomic.LoadUint32(&maxContentLength)
|
||||
if maxLen > 0 && len(v) > int(maxLen) {
|
||||
val = v[:maxLen]
|
||||
fields = append(fields, truncatedField)
|
||||
}
|
||||
case Sensitive:
|
||||
val = v.MaskSensitive()
|
||||
}
|
||||
|
||||
fields = combineGlobalFields(fields)
|
||||
// +3 for timestamp, level and content
|
||||
entry := make(logEntry, len(fields)+3)
|
||||
for _, field := range fields {
|
||||
entry[field.Key] = field.Value
|
||||
// mask sensitive data before processing types,
|
||||
// in case field.Value is a sensitive type and also implemented fmt.Stringer.
|
||||
mval := maskSensitive(field.Value)
|
||||
entry[field.Key] = processFieldValue(mval)
|
||||
}
|
||||
|
||||
switch atomic.LoadUint32(&encoding) {
|
||||
@@ -381,6 +398,45 @@ func output(writer io.Writer, level string, val any, fields ...LogField) {
|
||||
}
|
||||
}
|
||||
|
||||
func processFieldValue(value any) any {
|
||||
switch val := value.(type) {
|
||||
case error:
|
||||
return encodeError(val)
|
||||
case []error:
|
||||
var errs []string
|
||||
for _, err := range val {
|
||||
errs = append(errs, encodeError(err))
|
||||
}
|
||||
return errs
|
||||
case time.Duration:
|
||||
return fmt.Sprint(val)
|
||||
case []time.Duration:
|
||||
var durs []string
|
||||
for _, dur := range val {
|
||||
durs = append(durs, fmt.Sprint(dur))
|
||||
}
|
||||
return durs
|
||||
case []time.Time:
|
||||
var times []string
|
||||
for _, t := range val {
|
||||
times = append(times, fmt.Sprint(t))
|
||||
}
|
||||
return times
|
||||
case json.Marshaler:
|
||||
return val
|
||||
case fmt.Stringer:
|
||||
return encodeStringer(val)
|
||||
case []fmt.Stringer:
|
||||
var strs []string
|
||||
for _, str := range val {
|
||||
strs = append(strs, encodeStringer(str))
|
||||
}
|
||||
return strs
|
||||
default:
|
||||
return val
|
||||
}
|
||||
}
|
||||
|
||||
func wrapLevelWithColor(level string) string {
|
||||
var colour color.Color
|
||||
switch level {
|
||||
@@ -388,6 +444,8 @@ func wrapLevelWithColor(level string) string {
|
||||
colour = color.FgRed
|
||||
case levelError:
|
||||
colour = color.FgRed
|
||||
case levelSevere:
|
||||
colour = color.FgRed
|
||||
case levelFatal:
|
||||
colour = color.FgRed
|
||||
case levelInfo:
|
||||
|
||||
@@ -225,6 +225,48 @@ func TestWritePlainDuplicate(t *testing.T) {
|
||||
assert.Contains(t, buf.String(), "second=c")
|
||||
}
|
||||
|
||||
func TestLogWithSensitive(t *testing.T) {
|
||||
old := atomic.SwapUint32(&encoding, plainEncodingType)
|
||||
t.Cleanup(func() {
|
||||
atomic.StoreUint32(&encoding, old)
|
||||
})
|
||||
|
||||
t.Run("sensitive", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
output(&buf, levelInfo, User{
|
||||
Name: "kevin",
|
||||
Pass: "123",
|
||||
}, LogField{
|
||||
Key: "first",
|
||||
Value: "a",
|
||||
}, LogField{
|
||||
Key: "first",
|
||||
Value: "b",
|
||||
})
|
||||
assert.Contains(t, buf.String(), maskedContent)
|
||||
assert.NotContains(t, buf.String(), "first=a")
|
||||
assert.Contains(t, buf.String(), "first=b")
|
||||
})
|
||||
|
||||
t.Run("sensitive fields", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
output(&buf, levelInfo, "foo", LogField{
|
||||
Key: "first",
|
||||
Value: User{
|
||||
Name: "kevin",
|
||||
Pass: "123",
|
||||
},
|
||||
}, LogField{
|
||||
Key: "second",
|
||||
Value: "b",
|
||||
})
|
||||
assert.Contains(t, buf.String(), "foo")
|
||||
assert.Contains(t, buf.String(), "first")
|
||||
assert.Contains(t, buf.String(), maskedContent)
|
||||
assert.Contains(t, buf.String(), "second=b")
|
||||
})
|
||||
}
|
||||
|
||||
func TestLogWithLimitContentLength(t *testing.T) {
|
||||
maxLen := atomic.LoadUint32(&maxContentLength)
|
||||
atomic.StoreUint32(&maxContentLength, 10)
|
||||
|
||||
@@ -3,6 +3,7 @@ package mapping
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -13,6 +14,15 @@ const (
|
||||
|
||||
// Marshal marshals the given val and returns the map that contains the fields.
|
||||
// optional=another is not implemented, and it's hard to implement and not commonly used.
|
||||
// support anonymous field, e.g.:
|
||||
//
|
||||
// type Foo struct {
|
||||
// Token string `header:"token"`
|
||||
// }
|
||||
// type FooB struct {
|
||||
// Foo
|
||||
// Bar string `json:"bar"`
|
||||
// }
|
||||
func Marshal(val any) (map[string]map[string]any, error) {
|
||||
ret := make(map[string]map[string]any)
|
||||
tp := reflect.TypeOf(val)
|
||||
@@ -44,6 +54,16 @@ func getTag(field reflect.StructField) (string, bool) {
|
||||
return strings.TrimSpace(tag), false
|
||||
}
|
||||
|
||||
func insertValue(collector map[string]map[string]any, tag string, key string, val any) {
|
||||
if m, ok := collector[tag]; ok {
|
||||
m[key] = val
|
||||
} else {
|
||||
collector[tag] = map[string]any{
|
||||
key: val,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func processMember(field reflect.StructField, value reflect.Value,
|
||||
collector map[string]map[string]any) error {
|
||||
var key string
|
||||
@@ -69,15 +89,20 @@ func processMember(field reflect.StructField, value reflect.Value,
|
||||
val = fmt.Sprint(val)
|
||||
}
|
||||
|
||||
m, ok := collector[tag]
|
||||
if ok {
|
||||
m[key] = val
|
||||
} else {
|
||||
m = map[string]any{
|
||||
key: val,
|
||||
if field.Anonymous {
|
||||
anonCollector, err := Marshal(val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for anonTag, anonMap := range anonCollector {
|
||||
for anonKey, anonVal := range anonMap {
|
||||
insertValue(collector, anonTag, anonKey, anonVal)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
insertValue(collector, tag, key, val)
|
||||
}
|
||||
collector[tag] = m
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -118,7 +143,7 @@ func validateOptional(field reflect.StructField, value reflect.Value) error {
|
||||
if value.IsNil() {
|
||||
return fmt.Errorf("field %q is nil", field.Name)
|
||||
}
|
||||
case reflect.Array, reflect.Slice, reflect.Map:
|
||||
case reflect.Slice, reflect.Map:
|
||||
if value.IsNil() || value.Len() == 0 {
|
||||
return fmt.Errorf("field %q is empty", field.Name)
|
||||
}
|
||||
@@ -128,15 +153,8 @@ func validateOptional(field reflect.StructField, value reflect.Value) error {
|
||||
}
|
||||
|
||||
func validateOptions(value reflect.Value, opt *fieldOptions) error {
|
||||
var found bool
|
||||
val := fmt.Sprint(value.Interface())
|
||||
for i := range opt.Options {
|
||||
if opt.Options[i] == val {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
if !slices.Contains(opt.Options, val) {
|
||||
return fmt.Errorf("field %q not in options", val)
|
||||
}
|
||||
|
||||
|
||||
@@ -27,6 +27,124 @@ func TestMarshal(t *testing.T) {
|
||||
assert.True(t, m[emptyTag]["Anonymous"].(bool))
|
||||
}
|
||||
|
||||
func TestMarshal_Anonymous(t *testing.T) {
|
||||
t.Run("anonymous", func(t *testing.T) {
|
||||
type BaseHeader struct {
|
||||
Token string `header:"token"`
|
||||
}
|
||||
v := struct {
|
||||
Name string `json:"name"`
|
||||
Address string `json:"address,options=[beijing,shanghai]"`
|
||||
Age int `json:"age"`
|
||||
BaseHeader
|
||||
}{
|
||||
Name: "kevin",
|
||||
Address: "shanghai",
|
||||
Age: 20,
|
||||
BaseHeader: BaseHeader{
|
||||
Token: "token_xxx",
|
||||
},
|
||||
}
|
||||
m, err := Marshal(v)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "kevin", m["json"]["name"])
|
||||
assert.Equal(t, "shanghai", m["json"]["address"])
|
||||
assert.Equal(t, 20, m["json"]["age"].(int))
|
||||
assert.Equal(t, "token_xxx", m["header"]["token"])
|
||||
|
||||
v1 := struct {
|
||||
Name string `json:"name"`
|
||||
Address string `json:"address,options=[beijing,shanghai]"`
|
||||
Age int `json:"age"`
|
||||
BaseHeader
|
||||
}{
|
||||
Name: "kevin",
|
||||
Address: "shanghai",
|
||||
Age: 20,
|
||||
}
|
||||
m1, err1 := Marshal(v1)
|
||||
assert.Nil(t, err1)
|
||||
assert.Equal(t, "kevin", m1["json"]["name"])
|
||||
assert.Equal(t, "shanghai", m1["json"]["address"])
|
||||
assert.Equal(t, 20, m1["json"]["age"].(int))
|
||||
|
||||
type AnotherHeader struct {
|
||||
Version string `header:"version"`
|
||||
}
|
||||
v2 := struct {
|
||||
Name string `json:"name"`
|
||||
Address string `json:"address,options=[beijing,shanghai]"`
|
||||
Age int `json:"age"`
|
||||
BaseHeader
|
||||
AnotherHeader
|
||||
}{
|
||||
Name: "kevin",
|
||||
Address: "shanghai",
|
||||
Age: 20,
|
||||
BaseHeader: BaseHeader{
|
||||
Token: "token_xxx",
|
||||
},
|
||||
AnotherHeader: AnotherHeader{
|
||||
Version: "v1.0",
|
||||
},
|
||||
}
|
||||
m2, err2 := Marshal(v2)
|
||||
assert.Nil(t, err2)
|
||||
assert.Equal(t, "kevin", m2["json"]["name"])
|
||||
assert.Equal(t, "shanghai", m2["json"]["address"])
|
||||
assert.Equal(t, 20, m2["json"]["age"].(int))
|
||||
assert.Equal(t, "token_xxx", m2["header"]["token"])
|
||||
assert.Equal(t, "v1.0", m2["header"]["version"])
|
||||
|
||||
type PointerHeader struct {
|
||||
Ref *string `header:"ref"`
|
||||
}
|
||||
ref := "reference"
|
||||
v3 := struct {
|
||||
Name string `json:"name"`
|
||||
Address string `json:"address,options=[beijing,shanghai]"`
|
||||
Age int `json:"age"`
|
||||
PointerHeader
|
||||
}{
|
||||
Name: "kevin",
|
||||
Address: "shanghai",
|
||||
Age: 20,
|
||||
PointerHeader: PointerHeader{
|
||||
Ref: &ref,
|
||||
},
|
||||
}
|
||||
m3, err3 := Marshal(v3)
|
||||
assert.Nil(t, err3)
|
||||
assert.Equal(t, "kevin", m3["json"]["name"])
|
||||
assert.Equal(t, "shanghai", m3["json"]["address"])
|
||||
assert.Equal(t, 20, m3["json"]["age"].(int))
|
||||
assert.Equal(t, "reference", *m3["header"]["ref"].(*string))
|
||||
})
|
||||
|
||||
t.Run("bad anonymous", func(t *testing.T) {
|
||||
type BaseHeader struct {
|
||||
Token string `json:"token,options=[a,b]"`
|
||||
}
|
||||
|
||||
v := struct {
|
||||
Name string `json:"name"`
|
||||
Address string `json:"address,options=[beijing,shanghai]"`
|
||||
Age int `json:"age"`
|
||||
BaseHeader
|
||||
}{
|
||||
Name: "kevin",
|
||||
Address: "shanghai",
|
||||
Age: 20,
|
||||
BaseHeader: BaseHeader{
|
||||
Token: "c",
|
||||
},
|
||||
}
|
||||
|
||||
_, err := Marshal(v)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMarshal_Ptr(t *testing.T) {
|
||||
v := &struct {
|
||||
Name string `path:"name"`
|
||||
@@ -344,3 +462,15 @@ func TestMarshal_FromString(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "10", m["json"]["age"].(string))
|
||||
}
|
||||
|
||||
func TestMarshal_Array(t *testing.T) {
|
||||
v := struct {
|
||||
H [1]int `json:"h,string"`
|
||||
}{
|
||||
H: [1]int{1},
|
||||
}
|
||||
|
||||
m, err := Marshal(v)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "[1]", m["json"]["h"].(string))
|
||||
}
|
||||
|
||||
@@ -2,10 +2,12 @@ package mapping
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -14,7 +16,6 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/jsonx"
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -50,6 +51,7 @@ type (
|
||||
|
||||
unmarshalOptions struct {
|
||||
fillDefault bool
|
||||
fromArray bool
|
||||
fromString bool
|
||||
opaqueKeys bool
|
||||
canonicalKey func(key string) string
|
||||
@@ -79,40 +81,11 @@ func (u *Unmarshaler) Unmarshal(i, v any) error {
|
||||
return u.unmarshal(i, v, "")
|
||||
}
|
||||
|
||||
func (u *Unmarshaler) unmarshal(i, v any, fullName string) error {
|
||||
valueType := reflect.TypeOf(v)
|
||||
if valueType.Kind() != reflect.Ptr {
|
||||
return errValueNotSettable
|
||||
}
|
||||
|
||||
elemType := Deref(valueType)
|
||||
switch iv := i.(type) {
|
||||
case map[string]any:
|
||||
if elemType.Kind() != reflect.Struct {
|
||||
return errTypeMismatch
|
||||
}
|
||||
|
||||
return u.unmarshalValuer(mapValuer(iv), v, fullName)
|
||||
case []any:
|
||||
if elemType.Kind() != reflect.Slice {
|
||||
return errTypeMismatch
|
||||
}
|
||||
|
||||
return u.fillSlice(elemType, reflect.ValueOf(v).Elem(), iv, fullName)
|
||||
default:
|
||||
return errUnsupportedType
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalValuer unmarshals m into v.
|
||||
func (u *Unmarshaler) UnmarshalValuer(m Valuer, v any) error {
|
||||
return u.unmarshalValuer(simpleValuer{current: m}, v, "")
|
||||
}
|
||||
|
||||
func (u *Unmarshaler) unmarshalValuer(m Valuer, v any, fullName string) error {
|
||||
return u.unmarshalWithFullName(simpleValuer{current: m}, v, fullName)
|
||||
}
|
||||
|
||||
func (u *Unmarshaler) fillMap(fieldType reflect.Type, value reflect.Value,
|
||||
mapValue any, fullName string) error {
|
||||
if !value.CanSet() {
|
||||
@@ -172,13 +145,14 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value,
|
||||
baseType := fieldType.Elem()
|
||||
dereffedBaseType := Deref(baseType)
|
||||
dereffedBaseKind := dereffedBaseType.Kind()
|
||||
conv := reflect.MakeSlice(reflect.SliceOf(baseType), refValue.Len(), refValue.Cap())
|
||||
if refValue.Len() == 0 {
|
||||
value.Set(conv)
|
||||
value.Set(reflect.MakeSlice(reflect.SliceOf(baseType), 0, 0))
|
||||
return nil
|
||||
}
|
||||
|
||||
var valid bool
|
||||
conv := reflect.MakeSlice(reflect.SliceOf(baseType), refValue.Len(), refValue.Cap())
|
||||
|
||||
for i := 0; i < refValue.Len(); i++ {
|
||||
ithValue := refValue.Index(i).Interface()
|
||||
if ithValue == nil {
|
||||
@@ -190,17 +164,9 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value,
|
||||
|
||||
switch dereffedBaseKind {
|
||||
case reflect.Struct:
|
||||
target := reflect.New(dereffedBaseType)
|
||||
val, ok := ithValue.(map[string]any)
|
||||
if !ok {
|
||||
return errTypeMismatch
|
||||
}
|
||||
|
||||
if err := u.unmarshal(val, target.Interface(), sliceFullName); err != nil {
|
||||
if err := u.fillStructElement(baseType, conv.Index(i), ithValue, sliceFullName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
SetValue(fieldType.Elem(), conv.Index(i), target.Elem())
|
||||
case reflect.Slice:
|
||||
if err := u.fillSlice(dereffedBaseType, conv.Index(i), ithValue, sliceFullName); err != nil {
|
||||
return err
|
||||
@@ -235,7 +201,7 @@ func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect.
|
||||
return errUnsupportedType
|
||||
}
|
||||
|
||||
baseFieldType := Deref(fieldType.Elem())
|
||||
baseFieldType := fieldType.Elem()
|
||||
baseFieldKind := baseFieldType.Kind()
|
||||
conv := reflect.MakeSlice(reflect.SliceOf(baseFieldType), len(slice), cap(slice))
|
||||
|
||||
@@ -256,29 +222,39 @@ func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int,
|
||||
}
|
||||
|
||||
ithVal := slice.Index(index)
|
||||
ithValType := ithVal.Type()
|
||||
|
||||
switch v := value.(type) {
|
||||
case fmt.Stringer:
|
||||
return setValueFromString(baseKind, ithVal, v.String())
|
||||
case string:
|
||||
return setValueFromString(baseKind, ithVal, v)
|
||||
case map[string]any:
|
||||
return u.fillMap(ithVal.Type(), ithVal, value, fullName)
|
||||
// deref to handle both pointer and non-pointer types.
|
||||
switch Deref(ithValType).Kind() {
|
||||
case reflect.Struct:
|
||||
return u.fillStructElement(ithValType, ithVal, v, fullName)
|
||||
case reflect.Map:
|
||||
return u.fillMap(ithValType, ithVal, value, fullName)
|
||||
default:
|
||||
return errTypeMismatch
|
||||
}
|
||||
default:
|
||||
// don't need to consider the difference between int, int8, int16, int32, int64,
|
||||
// uint, uint8, uint16, uint32, uint64, because they're handled as json.Number.
|
||||
if ithVal.Kind() == reflect.Ptr {
|
||||
baseType := Deref(ithVal.Type())
|
||||
baseType := Deref(ithValType)
|
||||
if !reflect.TypeOf(value).AssignableTo(baseType) {
|
||||
return errTypeMismatch
|
||||
}
|
||||
|
||||
target := reflect.New(baseType).Elem()
|
||||
target.Set(reflect.ValueOf(value))
|
||||
SetValue(ithVal.Type(), ithVal, target)
|
||||
SetValue(ithValType, ithVal, target)
|
||||
return nil
|
||||
}
|
||||
|
||||
if !reflect.TypeOf(value).AssignableTo(ithVal.Type()) {
|
||||
if !reflect.TypeOf(value).AssignableTo(ithValType) {
|
||||
return errTypeMismatch
|
||||
}
|
||||
|
||||
@@ -309,6 +285,23 @@ func (u *Unmarshaler) fillSliceWithDefault(derefedType reflect.Type, value refle
|
||||
return u.fillSlice(derefedType, value, slice, fullName)
|
||||
}
|
||||
|
||||
func (u *Unmarshaler) fillStructElement(baseType reflect.Type, target reflect.Value,
|
||||
value any, fullName string) error {
|
||||
val, ok := value.(map[string]any)
|
||||
if !ok {
|
||||
return errTypeMismatch
|
||||
}
|
||||
|
||||
// use Deref(baseType) to get the base type in case the type is a pointer type.
|
||||
ptr := reflect.New(Deref(baseType))
|
||||
if err := u.unmarshal(val, ptr.Interface(), fullName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
SetValue(baseType, target, ptr.Elem())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *Unmarshaler) fillUnmarshalerStruct(fieldType reflect.Type,
|
||||
value reflect.Value, targetValue string) error {
|
||||
if !value.CanSet() {
|
||||
@@ -611,11 +604,37 @@ func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value re
|
||||
case valueKind == reflect.String && typeKind == reflect.Map:
|
||||
return u.fillMapFromString(value, mapValue)
|
||||
case valueKind == reflect.String && typeKind == reflect.Slice:
|
||||
// try to find out if it's a byte slice,
|
||||
// more details https://pkg.go.dev/encoding/json#Marshal
|
||||
// array and slice values encode as JSON arrays,
|
||||
// except that []byte encodes as a base64-encoded string,
|
||||
// and a nil slice encoded as the null JSON value.
|
||||
// https://stackoverflow.com/questions/34089750/marshal-byte-to-json-giving-a-strange-string
|
||||
if fieldType.Elem().Kind() == reflect.Uint8 {
|
||||
// check whether string type, because the kind of some other types can be string
|
||||
if strVal, ok := mapValue.(string); ok {
|
||||
if decodedBytes, err := base64.StdEncoding.DecodeString(strVal); err == nil {
|
||||
value.Set(reflect.ValueOf(decodedBytes))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return u.fillSliceFromString(fieldType, value, mapValue, fullName)
|
||||
case valueKind == reflect.String && derefedFieldType == durationType:
|
||||
return fillDurationValue(fieldType, value, mapValue.(string))
|
||||
v, err := convertToString(mapValue, fullName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return fillDurationValue(fieldType, value, v)
|
||||
case valueKind == reflect.String && typeKind == reflect.Struct && u.implementsUnmarshaler(fieldType):
|
||||
return u.fillUnmarshalerStruct(fieldType, value, mapValue.(string))
|
||||
v, err := convertToString(mapValue, fullName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return u.fillUnmarshalerStruct(fieldType, value, v)
|
||||
default:
|
||||
return u.processFieldPrimitive(fieldType, value, mapValue, opts, fullName)
|
||||
}
|
||||
@@ -746,24 +765,26 @@ func (u *Unmarshaler) processFieldWithEnvValue(fieldType reflect.Type, value ref
|
||||
return err
|
||||
}
|
||||
|
||||
fieldKind := fieldType.Kind()
|
||||
switch fieldKind {
|
||||
case reflect.Bool:
|
||||
derefType := Deref(fieldType)
|
||||
derefKind := derefType.Kind()
|
||||
switch {
|
||||
case derefKind == reflect.String:
|
||||
SetValue(fieldType, value, toReflectValue(derefType, envVal))
|
||||
return nil
|
||||
case derefKind == reflect.Bool:
|
||||
val, err := strconv.ParseBool(envVal)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
|
||||
}
|
||||
|
||||
value.SetBool(val)
|
||||
SetValue(fieldType, value, toReflectValue(derefType, val))
|
||||
return nil
|
||||
case durationType.Kind():
|
||||
case derefType == durationType:
|
||||
// time.Duration is a special case, its derefKind is reflect.Int64.
|
||||
if err := fillDurationValue(fieldType, value, envVal); err != nil {
|
||||
return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
case reflect.String:
|
||||
value.SetString(envVal)
|
||||
return nil
|
||||
default:
|
||||
return u.processFieldPrimitiveWithJSONNumber(fieldType, value, json.Number(envVal), opts, fullName)
|
||||
@@ -811,6 +832,19 @@ func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect
|
||||
return u.processNamedFieldWithoutValue(field.Type, value, opts, fullName)
|
||||
}
|
||||
|
||||
if u.opts.fromArray {
|
||||
fieldKind := field.Type.Kind()
|
||||
if fieldKind != reflect.Slice && fieldKind != reflect.Array {
|
||||
valueKind := reflect.TypeOf(mapValue).Kind()
|
||||
if valueKind == reflect.Slice || valueKind == reflect.Array {
|
||||
val := reflect.ValueOf(mapValue)
|
||||
if val.Len() > 0 {
|
||||
mapValue = val.Index(0).Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return u.processNamedFieldWithValue(field.Type, value, valueWithParent{
|
||||
value: mapValue,
|
||||
parent: valuer,
|
||||
@@ -872,7 +906,7 @@ func (u *Unmarshaler) processNamedFieldWithValueFromString(fieldType reflect.Typ
|
||||
valueKind.String())
|
||||
}
|
||||
|
||||
if !stringx.Contains(options, checkValue) {
|
||||
if !slices.Contains(options, checkValue) {
|
||||
return fmt.Errorf(`value "%s" for field %q is not defined in options "%v"`,
|
||||
mapValue, key, options)
|
||||
}
|
||||
@@ -938,6 +972,35 @@ func (u *Unmarshaler) processNamedFieldWithoutValue(fieldType reflect.Type, valu
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *Unmarshaler) unmarshal(i, v any, fullName string) error {
|
||||
valueType := reflect.TypeOf(v)
|
||||
if valueType.Kind() != reflect.Ptr {
|
||||
return errValueNotSettable
|
||||
}
|
||||
|
||||
elemType := Deref(valueType)
|
||||
switch iv := i.(type) {
|
||||
case map[string]any:
|
||||
if elemType.Kind() != reflect.Struct {
|
||||
return errTypeMismatch
|
||||
}
|
||||
|
||||
return u.unmarshalValuer(mapValuer(iv), v, fullName)
|
||||
case []any:
|
||||
if elemType.Kind() != reflect.Slice {
|
||||
return errTypeMismatch
|
||||
}
|
||||
|
||||
return u.fillSlice(elemType, reflect.ValueOf(v).Elem(), iv, fullName)
|
||||
default:
|
||||
return errUnsupportedType
|
||||
}
|
||||
}
|
||||
|
||||
func (u *Unmarshaler) unmarshalValuer(m Valuer, v any, fullName string) error {
|
||||
return u.unmarshalWithFullName(simpleValuer{current: m}, v, fullName)
|
||||
}
|
||||
|
||||
func (u *Unmarshaler) unmarshalWithFullName(m valuerWithParent, v any, fullName string) error {
|
||||
rv := reflect.ValueOf(v)
|
||||
if err := ValidatePtr(rv); err != nil {
|
||||
@@ -990,6 +1053,16 @@ func WithDefault() UnmarshalOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithFromArray customizes an Unmarshaler with converting array values to non-array types.
|
||||
// For example, if the field type is []string, and the value is [hello],
|
||||
// the field type can be `string`, instead of `[]string`.
|
||||
// Typically, this option is used for unmarshaling from form values.
|
||||
func WithFromArray() UnmarshalOption {
|
||||
return func(opt *unmarshalOptions) {
|
||||
opt.fromArray = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithOpaqueKeys customizes an Unmarshaler with opaque keys.
|
||||
// Opaque keys are keys that are not processed by the unmarshaler.
|
||||
func WithOpaqueKeys() UnmarshalOption {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/jsonx"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
@@ -202,6 +203,20 @@ func TestUnmarshalDuration(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalDurationUnexpectedError(t *testing.T) {
|
||||
type inner struct {
|
||||
Duration time.Duration `key:"duration"`
|
||||
}
|
||||
content := "{\"duration\": 1}"
|
||||
var m = map[string]any{}
|
||||
err := jsonx.Unmarshal([]byte(content), &m)
|
||||
assert.NoError(t, err)
|
||||
var in inner
|
||||
err = UnmarshalKey(m, &in)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "expect string")
|
||||
}
|
||||
|
||||
func TestUnmarshalDurationDefault(t *testing.T) {
|
||||
type inner struct {
|
||||
Int int `key:"int"`
|
||||
@@ -351,7 +366,7 @@ func TestUnmarshalIntSliceOfPtr(t *testing.T) {
|
||||
assert.Error(t, UnmarshalKey(m, &in))
|
||||
})
|
||||
|
||||
t.Run("int slice with nil", func(t *testing.T) {
|
||||
t.Run("int slice with nil element", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Ints []int `key:"ints"`
|
||||
}
|
||||
@@ -365,6 +380,21 @@ func TestUnmarshalIntSliceOfPtr(t *testing.T) {
|
||||
assert.Empty(t, in.Ints)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("int slice with nil", func(t *testing.T) {
|
||||
type inner struct {
|
||||
Ints []int `key:"ints"`
|
||||
}
|
||||
|
||||
m := map[string]any{
|
||||
"ints": []any(nil),
|
||||
}
|
||||
|
||||
var in inner
|
||||
if assert.NoError(t, UnmarshalKey(m, &in)) {
|
||||
assert.Empty(t, in.Ints)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalIntWithDefault(t *testing.T) {
|
||||
@@ -1374,20 +1404,80 @@ func TestUnmarshalWithFloatPtr(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUnmarshalIntSlice(t *testing.T) {
|
||||
var v struct {
|
||||
Ages []int `key:"ages"`
|
||||
Slice []int `key:"slice"`
|
||||
}
|
||||
m := map[string]any{
|
||||
"ages": []int{1, 2},
|
||||
"slice": []any{},
|
||||
}
|
||||
t.Run("int slice from int", func(t *testing.T) {
|
||||
var v struct {
|
||||
Ages []int `key:"ages"`
|
||||
Slice []int `key:"slice"`
|
||||
}
|
||||
m := map[string]any{
|
||||
"ages": []int{1, 2},
|
||||
"slice": []any{},
|
||||
}
|
||||
|
||||
ast := assert.New(t)
|
||||
if ast.NoError(UnmarshalKey(m, &v)) {
|
||||
ast.ElementsMatch([]int{1, 2}, v.Ages)
|
||||
ast.Equal([]int{}, v.Slice)
|
||||
}
|
||||
ast := assert.New(t)
|
||||
if ast.NoError(UnmarshalKey(m, &v)) {
|
||||
ast.ElementsMatch([]int{1, 2}, v.Ages)
|
||||
ast.Equal([]int{}, v.Slice)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("int slice from one int", func(t *testing.T) {
|
||||
var v struct {
|
||||
Ages []int `key:"ages"`
|
||||
}
|
||||
m := map[string]any{
|
||||
"ages": []int{2},
|
||||
}
|
||||
|
||||
ast := assert.New(t)
|
||||
unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray())
|
||||
if ast.NoError(unmarshaler.Unmarshal(m, &v)) {
|
||||
ast.ElementsMatch([]int{2}, v.Ages)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("int slice from one int string", func(t *testing.T) {
|
||||
var v struct {
|
||||
Ages []int `key:"ages"`
|
||||
}
|
||||
m := map[string]any{
|
||||
"ages": []string{"2"},
|
||||
}
|
||||
|
||||
ast := assert.New(t)
|
||||
unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray())
|
||||
if ast.NoError(unmarshaler.Unmarshal(m, &v)) {
|
||||
ast.ElementsMatch([]int{2}, v.Ages)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("int slice from one json.Number", func(t *testing.T) {
|
||||
var v struct {
|
||||
Ages []int `key:"ages"`
|
||||
}
|
||||
m := map[string]any{
|
||||
"ages": []json.Number{"2"},
|
||||
}
|
||||
|
||||
ast := assert.New(t)
|
||||
unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray())
|
||||
if ast.NoError(unmarshaler.Unmarshal(m, &v)) {
|
||||
ast.ElementsMatch([]int{2}, v.Ages)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("int slice from one int strings", func(t *testing.T) {
|
||||
var v struct {
|
||||
Ages []int `key:"ages"`
|
||||
}
|
||||
m := map[string]any{
|
||||
"ages": []string{"1,2"},
|
||||
}
|
||||
|
||||
ast := assert.New(t)
|
||||
unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray())
|
||||
ast.Error(unmarshaler.Unmarshal(m, &v))
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalString(t *testing.T) {
|
||||
@@ -1442,6 +1532,51 @@ func TestUnmarshalStringSliceFromString(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("slice from empty string", func(t *testing.T) {
|
||||
var v struct {
|
||||
Names []string `key:"names"`
|
||||
}
|
||||
m := map[string]any{
|
||||
"names": []string{""},
|
||||
}
|
||||
|
||||
ast := assert.New(t)
|
||||
unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray())
|
||||
if ast.NoError(unmarshaler.Unmarshal(m, &v)) {
|
||||
ast.ElementsMatch([]string{""}, v.Names)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("slice from empty and valid string", func(t *testing.T) {
|
||||
var v struct {
|
||||
Names []string `key:"names"`
|
||||
}
|
||||
m := map[string]any{
|
||||
"names": []string{","},
|
||||
}
|
||||
|
||||
ast := assert.New(t)
|
||||
unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray())
|
||||
if ast.NoError(unmarshaler.Unmarshal(m, &v)) {
|
||||
ast.ElementsMatch([]string{","}, v.Names)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("slice from valid strings with comma", func(t *testing.T) {
|
||||
var v struct {
|
||||
Names []string `key:"names"`
|
||||
}
|
||||
m := map[string]any{
|
||||
"names": []string{"aa,bb"},
|
||||
}
|
||||
|
||||
ast := assert.New(t)
|
||||
unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray())
|
||||
if ast.NoError(unmarshaler.Unmarshal(m, &v)) {
|
||||
ast.ElementsMatch([]string{"aa,bb"}, v.Names)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("slice from string with slice error", func(t *testing.T) {
|
||||
var v struct {
|
||||
Names []int `key:"names"`
|
||||
@@ -4544,6 +4679,23 @@ func TestUnmarshal_EnvInt(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshal_EnvInt64(t *testing.T) {
|
||||
type Value struct {
|
||||
Age int64 `key:"age,env=TEST_NAME_INT64"`
|
||||
}
|
||||
|
||||
const (
|
||||
envName = "TEST_NAME_INT64"
|
||||
envVal = "88"
|
||||
)
|
||||
t.Setenv(envName, envVal)
|
||||
|
||||
var v Value
|
||||
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
|
||||
assert.Equal(t, int64(88), v.Age)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshal_EnvIntOverwrite(t *testing.T) {
|
||||
type Value struct {
|
||||
Age int `key:"age,env=TEST_NAME_INT"`
|
||||
@@ -4649,20 +4801,33 @@ func TestUnmarshal_EnvBoolBad(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUnmarshal_EnvDuration(t *testing.T) {
|
||||
type Value struct {
|
||||
Duration time.Duration `key:"duration,env=TEST_NAME_DURATION"`
|
||||
}
|
||||
|
||||
const (
|
||||
envName = "TEST_NAME_DURATION"
|
||||
envVal = "1s"
|
||||
)
|
||||
t.Setenv(envName, envVal)
|
||||
|
||||
var v Value
|
||||
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
|
||||
assert.Equal(t, time.Second, v.Duration)
|
||||
}
|
||||
t.Run("valid duration", func(t *testing.T) {
|
||||
type Value struct {
|
||||
Duration time.Duration `key:"duration,env=TEST_NAME_DURATION"`
|
||||
}
|
||||
|
||||
var v Value
|
||||
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
|
||||
assert.Equal(t, time.Second, v.Duration)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ptr of duration", func(t *testing.T) {
|
||||
type Value struct {
|
||||
Duration *time.Duration `key:"duration,env=TEST_NAME_DURATION"`
|
||||
}
|
||||
|
||||
var v Value
|
||||
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
|
||||
assert.Equal(t, time.Second, *v.Duration)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshal_EnvDurationBadValue(t *testing.T) {
|
||||
@@ -4761,14 +4926,28 @@ func TestUnmarshal_EnvWithOptionsWrongValueString(t *testing.T) {
|
||||
|
||||
func TestUnmarshalJsonReaderMultiArray(t *testing.T) {
|
||||
t.Run("reader multi array", func(t *testing.T) {
|
||||
var res struct {
|
||||
type testRes struct {
|
||||
A string `json:"a"`
|
||||
B [][]string `json:"b"`
|
||||
C []byte `json:"c"`
|
||||
}
|
||||
payload := `{"a": "133", "b": [["add", "cccd"], ["eeee"]]}`
|
||||
|
||||
var res testRes
|
||||
marshal := testRes{
|
||||
A: "133",
|
||||
B: [][]string{
|
||||
{"add", "cccd"},
|
||||
{"eeee"},
|
||||
},
|
||||
C: []byte("11122344wsss"),
|
||||
}
|
||||
bytes, err := jsonx.Marshal(marshal)
|
||||
assert.NoError(t, err)
|
||||
payload := string(bytes)
|
||||
reader := strings.NewReader(payload)
|
||||
if assert.NoError(t, UnmarshalJsonReader(reader, &res)) {
|
||||
assert.Equal(t, 2, len(res.B))
|
||||
assert.Equal(t, string(marshal.C), string(res.C))
|
||||
}
|
||||
})
|
||||
|
||||
@@ -5639,6 +5818,62 @@ func TestUnmarshalFromStringSliceForTypeMismatch(t *testing.T) {
|
||||
}, &v))
|
||||
}
|
||||
|
||||
func TestUnmarshalWithFromArray(t *testing.T) {
|
||||
t.Run("array", func(t *testing.T) {
|
||||
var v struct {
|
||||
Value []string `key:"value"`
|
||||
}
|
||||
unmarshaler := NewUnmarshaler("key", WithFromArray())
|
||||
if assert.NoError(t, unmarshaler.Unmarshal(map[string]any{
|
||||
"value": []string{"foo", "bar"},
|
||||
}, &v)) {
|
||||
assert.ElementsMatch(t, []string{"foo", "bar"}, v.Value)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("not array", func(t *testing.T) {
|
||||
var v struct {
|
||||
Value string `key:"value"`
|
||||
}
|
||||
unmarshaler := NewUnmarshaler("key", WithFromArray())
|
||||
if assert.NoError(t, unmarshaler.Unmarshal(map[string]any{
|
||||
"value": []string{"foo"},
|
||||
}, &v)) {
|
||||
assert.Equal(t, "foo", v.Value)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("not array and empty", func(t *testing.T) {
|
||||
var v struct {
|
||||
Value string `key:"value"`
|
||||
}
|
||||
unmarshaler := NewUnmarshaler("key", WithFromArray())
|
||||
if assert.NoError(t, unmarshaler.Unmarshal(map[string]any{
|
||||
"value": []string{""},
|
||||
}, &v)) {
|
||||
assert.Empty(t, v.Value)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("not array and no value", func(t *testing.T) {
|
||||
var v struct {
|
||||
Value string `key:"value"`
|
||||
}
|
||||
unmarshaler := NewUnmarshaler("key", WithFromArray())
|
||||
assert.Error(t, unmarshaler.Unmarshal(map[string]any{}, &v))
|
||||
})
|
||||
|
||||
t.Run("not array and no value and optional", func(t *testing.T) {
|
||||
var v struct {
|
||||
Value string `key:"value,optional"`
|
||||
}
|
||||
unmarshaler := NewUnmarshaler("key", WithFromArray())
|
||||
if assert.NoError(t, unmarshaler.Unmarshal(map[string]any{}, &v)) {
|
||||
assert.Empty(t, v.Value)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalWithOpaqueKeys(t *testing.T) {
|
||||
var v struct {
|
||||
Opaque string `key:"opaque.key"`
|
||||
@@ -5804,6 +6039,147 @@ func TestUnmarshal_Unmarshaler(t *testing.T) {
|
||||
}, &v))
|
||||
assert.Nil(t, v.Foo)
|
||||
})
|
||||
|
||||
t.Run("json.Number", func(t *testing.T) {
|
||||
v := struct {
|
||||
Foo *mockUnmarshaler `json:"name"`
|
||||
}{}
|
||||
m := map[string]any{
|
||||
"name": json.Number("123"),
|
||||
}
|
||||
assert.Error(t, UnmarshalJsonMap(m, &v))
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseJsonStringValue(t *testing.T) {
|
||||
t.Run("string", func(t *testing.T) {
|
||||
type GoodsInfo struct {
|
||||
Sku int64 `json:"sku,optional"`
|
||||
}
|
||||
|
||||
type GetReq struct {
|
||||
GoodsList []*GoodsInfo `json:"goods_list"`
|
||||
}
|
||||
|
||||
input := map[string]any{"goods_list": "[{\"sku\":11},{\"sku\":22}]"}
|
||||
var v GetReq
|
||||
assert.NotPanics(t, func() {
|
||||
assert.NoError(t, UnmarshalJsonMap(input, &v))
|
||||
assert.Equal(t, 2, len(v.GoodsList))
|
||||
assert.ElementsMatch(t, []int64{11, 22}, []int64{v.GoodsList[0].Sku, v.GoodsList[1].Sku})
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("string with invalid type", func(t *testing.T) {
|
||||
type GetReq struct {
|
||||
GoodsList []*int `json:"goods_list"`
|
||||
}
|
||||
|
||||
input := map[string]any{"goods_list": "[{\"sku\":11},{\"sku\":22}]"}
|
||||
var v GetReq
|
||||
assert.NotPanics(t, func() {
|
||||
assert.Error(t, UnmarshalJsonMap(input, &v))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// issue #5033, string type
|
||||
func TestUnmarshalFromEnvString(t *testing.T) {
|
||||
t.Setenv("STRING_ENV", "dev")
|
||||
|
||||
t.Run("by value", func(t *testing.T) {
|
||||
type (
|
||||
Env string
|
||||
Config struct {
|
||||
Env Env `json:",env=STRING_ENV,default=prod"`
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
if assert.NoError(t, UnmarshalJsonMap(map[string]any{}, &c)) {
|
||||
assert.Equal(t, Env("dev"), c.Env)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("by ptr", func(t *testing.T) {
|
||||
type (
|
||||
Env string
|
||||
Config struct {
|
||||
Env *Env `json:",env=STRING_ENV,default=prod"`
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
if assert.NoError(t, UnmarshalJsonMap(map[string]any{}, &c)) {
|
||||
assert.Equal(t, Env("dev"), *c.Env)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// issue #5033, bool type
|
||||
func TestUnmarshalFromEnvBool(t *testing.T) {
|
||||
t.Setenv("BOOL_ENV", "true")
|
||||
|
||||
t.Run("by value", func(t *testing.T) {
|
||||
type (
|
||||
Env bool
|
||||
Config struct {
|
||||
Env Env `json:",env=BOOL_ENV,default=false"`
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
if assert.NoError(t, UnmarshalJsonMap(map[string]any{}, &c)) {
|
||||
assert.Equal(t, Env(true), c.Env)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("by ptr", func(t *testing.T) {
|
||||
type (
|
||||
Env bool
|
||||
Config struct {
|
||||
Env *Env `json:",env=BOOL_ENV,default=false"`
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
if assert.NoError(t, UnmarshalJsonMap(map[string]any{}, &c)) {
|
||||
assert.Equal(t, Env(true), *c.Env)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// issue #5033, customized int type
|
||||
func TestUnmarshalFromEnvInt(t *testing.T) {
|
||||
t.Setenv("INT_ENV", "2")
|
||||
|
||||
t.Run("by value", func(t *testing.T) {
|
||||
type (
|
||||
Env int
|
||||
Config struct {
|
||||
Env Env `json:",env=INT_ENV,default=0"`
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
if assert.NoError(t, UnmarshalJsonMap(map[string]any{}, &c)) {
|
||||
assert.Equal(t, Env(2), c.Env)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("by ptr", func(t *testing.T) {
|
||||
type (
|
||||
Env int
|
||||
Config struct {
|
||||
Env *Env `json:",env=INT_ENV,default=0"`
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
if assert.NoError(t, UnmarshalJsonMap(map[string]any{}, &c)) {
|
||||
assert.Equal(t, Env(2), *c.Env)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkDefaultValue(b *testing.B) {
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
package mapping
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -91,17 +92,25 @@ func ValidatePtr(v reflect.Value) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertToString(val any, fullName string) (string, error) {
|
||||
v, ok := val.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("expect string for field %s, but got type %T", fullName, val)
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func convertTypeFromString(kind reflect.Kind, str string) (any, error) {
|
||||
switch kind {
|
||||
case reflect.Bool:
|
||||
switch strings.ToLower(str) {
|
||||
case "1", "true":
|
||||
if str == "1" || strings.EqualFold(str, "true") {
|
||||
return true, nil
|
||||
case "0", "false":
|
||||
return false, nil
|
||||
default:
|
||||
return false, errTypeMismatch
|
||||
}
|
||||
if str == "0" || strings.EqualFold(str, "false") {
|
||||
return false, nil
|
||||
}
|
||||
return false, errTypeMismatch
|
||||
case reflect.Int:
|
||||
return strconv.ParseInt(str, 10, intSize)
|
||||
case reflect.Int8:
|
||||
@@ -269,7 +278,7 @@ func parseKeyAndOptions(tagName string, field reflect.StructField) (string, *fie
|
||||
cache, ok := optionsCache[value]
|
||||
cacheLock.RUnlock()
|
||||
if ok {
|
||||
return stringx.TakeOne(cache.key, field.Name), cache.options, cache.err
|
||||
return cmp.Or(cache.key, field.Name), cache.options, cache.err
|
||||
}
|
||||
|
||||
key, options, err := doParseKeyAndOptions(field, value)
|
||||
@@ -281,7 +290,7 @@ func parseKeyAndOptions(tagName string, field reflect.StructField) (string, *fie
|
||||
}
|
||||
cacheLock.Unlock()
|
||||
|
||||
return stringx.TakeOne(key, field.Name), options, err
|
||||
return cmp.Or(key, field.Name), options, err
|
||||
}
|
||||
|
||||
// support below notations:
|
||||
@@ -573,6 +582,10 @@ func toFloat64(v any) (float64, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func toReflectValue(tp reflect.Type, v any) reflect.Value {
|
||||
return reflect.ValueOf(v).Convert(Deref(tp))
|
||||
}
|
||||
|
||||
func usingDifferentKeys(key string, field reflect.StructField) bool {
|
||||
if len(field.Tag) > 0 {
|
||||
if _, ok := field.Tag.Lookup(key); !ok {
|
||||
@@ -634,11 +647,11 @@ func validateValueInOptions(val any, options []string) error {
|
||||
if len(options) > 0 {
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
if !stringx.Contains(options, v) {
|
||||
if !slices.Contains(options, v) {
|
||||
return fmt.Errorf(`error: value %q is not defined in options "%v"`, v, options)
|
||||
}
|
||||
default:
|
||||
if !stringx.Contains(options, Repr(v)) {
|
||||
if !slices.Contains(options, Repr(v)) {
|
||||
return fmt.Errorf(`error: value "%v" is not defined in options "%v"`, val, options)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -334,3 +334,43 @@ func TestValidateValueRange(t *testing.T) {
|
||||
func TestSetMatchedPrimitiveValue(t *testing.T) {
|
||||
assert.Error(t, setMatchedPrimitiveValue(reflect.Func, reflect.ValueOf(2), "1"))
|
||||
}
|
||||
|
||||
func TestConvertTypeFromString_Bool(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want bool
|
||||
wantErr bool
|
||||
}{
|
||||
// true cases
|
||||
{name: "1", input: "1", want: true, wantErr: false},
|
||||
{name: "true lowercase", input: "true", want: true, wantErr: false},
|
||||
{name: "True mixed", input: "True", want: true, wantErr: false},
|
||||
{name: "TRUE uppercase", input: "TRUE", want: true, wantErr: false},
|
||||
{name: "TrUe mixed", input: "TrUe", want: true, wantErr: false},
|
||||
// false cases
|
||||
{name: "0", input: "0", want: false, wantErr: false},
|
||||
{name: "false lowercase", input: "false", want: false, wantErr: false},
|
||||
{name: "False mixed", input: "False", want: false, wantErr: false},
|
||||
{name: "FALSE uppercase", input: "FALSE", want: false, wantErr: false},
|
||||
{name: "FaLsE mixed", input: "FaLsE", want: false, wantErr: false},
|
||||
// error cases
|
||||
{name: "invalid yes", input: "yes", want: false, wantErr: true},
|
||||
{name: "invalid no", input: "no", want: false, wantErr: true},
|
||||
{name: "invalid empty", input: "", want: false, wantErr: true},
|
||||
{name: "invalid 2", input: "2", want: false, wantErr: true},
|
||||
{name: "invalid truee", input: "truee", want: false, wantErr: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := convertTypeFromString(reflect.Bool, tt.input)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,3 +29,10 @@ func TestCalcDiffEntropy(t *testing.T) {
|
||||
}
|
||||
assert.True(t, CalcEntropy(m) < .99)
|
||||
}
|
||||
|
||||
func TestCalcEntropySingleItem(t *testing.T) {
|
||||
m := map[any]int{
|
||||
"only": 42,
|
||||
}
|
||||
assert.Equal(t, float64(1), CalcEntropy(m))
|
||||
}
|
||||
|
||||
@@ -1,19 +1,13 @@
|
||||
package mathx
|
||||
|
||||
// MaxInt returns the larger one of a and b.
|
||||
// Deprecated: use builtin max instead.
|
||||
func MaxInt(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
|
||||
return b
|
||||
return max(a, b)
|
||||
}
|
||||
|
||||
// MinInt returns the smaller one of a and b.
|
||||
// Deprecated: use builtin min instead.
|
||||
func MinInt(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
|
||||
return b
|
||||
return min(a, b)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package mathx
|
||||
|
||||
// Numerical is a constraint that permits any numeric type.
|
||||
type Numerical interface {
|
||||
~int | ~int8 | ~int16 | ~int32 | ~int64 |
|
||||
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// An Unstable is used to generate random value around the mean value base on given deviation.
|
||||
// An Unstable is used to generate random value around the mean value based on given deviation.
|
||||
type Unstable struct {
|
||||
deviation float64
|
||||
r *rand.Rand
|
||||
|
||||
@@ -3,6 +3,9 @@ package mr
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
@@ -142,89 +145,6 @@ func MapReduceChan[T, U, V any](source <-chan T, mapper MapperFunc[T, U], reduce
|
||||
return mapReduceWithPanicChan(source, panicChan, mapper, reducer, opts...)
|
||||
}
|
||||
|
||||
// mapReduceWithPanicChan maps all elements from source, and reduce the output elements with given reducer.
|
||||
func mapReduceWithPanicChan[T, U, V any](source <-chan T, panicChan *onceChan, mapper MapperFunc[T, U],
|
||||
reducer ReducerFunc[U, V], opts ...Option) (val V, err error) {
|
||||
options := buildOptions(opts...)
|
||||
// output is used to write the final result
|
||||
output := make(chan V)
|
||||
defer func() {
|
||||
// reducer can only write once, if more, panic
|
||||
for range output {
|
||||
panic("more than one element written in reducer")
|
||||
}
|
||||
}()
|
||||
|
||||
// collector is used to collect data from mapper, and consume in reducer
|
||||
collector := make(chan U, options.workers)
|
||||
// if done is closed, all mappers and reducer should stop processing
|
||||
done := make(chan struct{})
|
||||
writer := newGuardedWriter(options.ctx, output, done)
|
||||
var closeOnce sync.Once
|
||||
// use atomic type to avoid data race
|
||||
var retErr errorx.AtomicError
|
||||
finish := func() {
|
||||
closeOnce.Do(func() {
|
||||
close(done)
|
||||
close(output)
|
||||
})
|
||||
}
|
||||
cancel := once(func(err error) {
|
||||
if err != nil {
|
||||
retErr.Set(err)
|
||||
} else {
|
||||
retErr.Set(ErrCancelWithNil)
|
||||
}
|
||||
|
||||
drain(source)
|
||||
finish()
|
||||
})
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
drain(collector)
|
||||
if r := recover(); r != nil {
|
||||
panicChan.write(r)
|
||||
}
|
||||
finish()
|
||||
}()
|
||||
|
||||
reducer(collector, writer, cancel)
|
||||
}()
|
||||
|
||||
go executeMappers(mapperContext[T, U]{
|
||||
ctx: options.ctx,
|
||||
mapper: func(item T, w Writer[U]) {
|
||||
mapper(item, w, cancel)
|
||||
},
|
||||
source: source,
|
||||
panicChan: panicChan,
|
||||
collector: collector,
|
||||
doneChan: done,
|
||||
workers: options.workers,
|
||||
})
|
||||
|
||||
select {
|
||||
case <-options.ctx.Done():
|
||||
cancel(context.DeadlineExceeded)
|
||||
err = context.DeadlineExceeded
|
||||
case v := <-panicChan.channel:
|
||||
// drain output here, otherwise for loop panic in defer
|
||||
drain(output)
|
||||
panic(v)
|
||||
case v, ok := <-output:
|
||||
if e := retErr.Load(); e != nil {
|
||||
err = e
|
||||
} else if ok {
|
||||
val = v
|
||||
} else {
|
||||
err = ErrReduceNoOutput
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// MapReduceVoid maps all elements generated from given generate,
|
||||
// and reduce the output elements with given reducer.
|
||||
func MapReduceVoid[T, U any](generate GenerateFunc[T], mapper MapperFunc[T, U],
|
||||
@@ -266,12 +186,16 @@ func buildOptions(opts ...Option) *mapReduceOptions {
|
||||
return options
|
||||
}
|
||||
|
||||
func buildPanicInfo(r any, stack []byte) string {
|
||||
return fmt.Sprintf("%+v\n\n%s", r, strings.TrimSpace(string(stack)))
|
||||
}
|
||||
|
||||
func buildSource[T any](generate GenerateFunc[T], panicChan *onceChan) chan T {
|
||||
source := make(chan T)
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
panicChan.write(r)
|
||||
panicChan.write(buildPanicInfo(r, debug.Stack()))
|
||||
}
|
||||
close(source)
|
||||
}()
|
||||
@@ -318,7 +242,7 @@ func executeMappers[T, U any](mCtx mapperContext[T, U]) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
atomic.AddInt32(&failed, 1)
|
||||
mCtx.panicChan.write(r)
|
||||
mCtx.panicChan.write(buildPanicInfo(r, debug.Stack()))
|
||||
}
|
||||
wg.Done()
|
||||
<-pool
|
||||
@@ -330,6 +254,89 @@ func executeMappers[T, U any](mCtx mapperContext[T, U]) {
|
||||
}
|
||||
}
|
||||
|
||||
// mapReduceWithPanicChan maps all elements from source, and reduce the output elements with given reducer.
|
||||
func mapReduceWithPanicChan[T, U, V any](source <-chan T, panicChan *onceChan, mapper MapperFunc[T, U],
|
||||
reducer ReducerFunc[U, V], opts ...Option) (val V, err error) {
|
||||
options := buildOptions(opts...)
|
||||
// output is used to write the final result
|
||||
output := make(chan V)
|
||||
defer func() {
|
||||
// reducer can only write once, if more, panic
|
||||
for range output {
|
||||
panic("more than one element written in reducer")
|
||||
}
|
||||
}()
|
||||
|
||||
// collector is used to collect data from mapper, and consume in reducer
|
||||
collector := make(chan U, options.workers)
|
||||
// if done is closed, all mappers and reducer should stop processing
|
||||
done := make(chan struct{})
|
||||
writer := newGuardedWriter(options.ctx, output, done)
|
||||
var closeOnce sync.Once
|
||||
// use atomic type to avoid data race
|
||||
var retErr errorx.AtomicError
|
||||
finish := func() {
|
||||
closeOnce.Do(func() {
|
||||
close(done)
|
||||
close(output)
|
||||
})
|
||||
}
|
||||
cancel := once(func(err error) {
|
||||
if err != nil {
|
||||
retErr.Set(err)
|
||||
} else {
|
||||
retErr.Set(ErrCancelWithNil)
|
||||
}
|
||||
|
||||
drain(source)
|
||||
finish()
|
||||
})
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
drain(collector)
|
||||
if r := recover(); r != nil {
|
||||
panicChan.write(buildPanicInfo(r, debug.Stack()))
|
||||
}
|
||||
finish()
|
||||
}()
|
||||
|
||||
reducer(collector, writer, cancel)
|
||||
}()
|
||||
|
||||
go executeMappers(mapperContext[T, U]{
|
||||
ctx: options.ctx,
|
||||
mapper: func(item T, w Writer[U]) {
|
||||
mapper(item, w, cancel)
|
||||
},
|
||||
source: source,
|
||||
panicChan: panicChan,
|
||||
collector: collector,
|
||||
doneChan: done,
|
||||
workers: options.workers,
|
||||
})
|
||||
|
||||
select {
|
||||
case <-options.ctx.Done():
|
||||
cancel(context.DeadlineExceeded)
|
||||
err = context.DeadlineExceeded
|
||||
case v := <-panicChan.channel:
|
||||
// drain output here, otherwise for loop panic in defer
|
||||
drain(output)
|
||||
panic(v)
|
||||
case v, ok := <-output:
|
||||
if e := retErr.Load(); e != nil {
|
||||
err = e
|
||||
} else if ok {
|
||||
val = v
|
||||
} else {
|
||||
err = ErrReduceNoOutput
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func newOptions() *mapReduceOptions {
|
||||
return &mapReduceOptions{
|
||||
ctx: context.Background(),
|
||||
|
||||
@@ -3,8 +3,7 @@ package mr
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
@@ -16,9 +15,6 @@ import (
|
||||
|
||||
var errDummy = errors.New("dummy")
|
||||
|
||||
func init() {
|
||||
log.SetOutput(io.Discard)
|
||||
}
|
||||
|
||||
func TestFinish(t *testing.T) {
|
||||
defer goleak.VerifyNone(t)
|
||||
@@ -39,6 +35,36 @@ func TestFinish(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestFinishWithPartialErrors(t *testing.T) {
|
||||
defer goleak.VerifyNone(t)
|
||||
|
||||
errDummy := errors.New("dummy")
|
||||
|
||||
t.Run("one error", func(t *testing.T) {
|
||||
err := Finish(func() error {
|
||||
return errDummy
|
||||
}, func() error {
|
||||
return nil
|
||||
}, func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.Equal(t, errDummy, err)
|
||||
})
|
||||
|
||||
t.Run("two errors", func(t *testing.T) {
|
||||
err := Finish(func() error {
|
||||
return errDummy
|
||||
}, func() error {
|
||||
return errDummy
|
||||
}, func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.Equal(t, errDummy, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFinishNone(t *testing.T) {
|
||||
defer goleak.VerifyNone(t)
|
||||
|
||||
@@ -118,11 +144,28 @@ func TestForEach(t *testing.T) {
|
||||
|
||||
assert.Equal(t, tasks/2, int(count))
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("all", func(t *testing.T) {
|
||||
defer goleak.VerifyNone(t)
|
||||
func TestPanics(t *testing.T) {
|
||||
defer goleak.VerifyNone(t)
|
||||
|
||||
const tasks = 1000
|
||||
verify := func(t *testing.T, r any) {
|
||||
panicStr := fmt.Sprintf("%v", r)
|
||||
assert.Contains(t, panicStr, "foo")
|
||||
assert.Contains(t, panicStr, "goroutine")
|
||||
assert.Contains(t, panicStr, "runtime/debug.Stack")
|
||||
panic(r)
|
||||
}
|
||||
|
||||
t.Run("ForEach run panics", func(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
verify(t, r)
|
||||
}
|
||||
}()
|
||||
|
||||
assert.PanicsWithValue(t, "foo", func() {
|
||||
ForEach(func(source chan<- int) {
|
||||
for i := 0; i < tasks; i++ {
|
||||
source <- i
|
||||
@@ -132,28 +175,31 @@ func TestForEach(t *testing.T) {
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestGeneratePanic(t *testing.T) {
|
||||
defer goleak.VerifyNone(t)
|
||||
t.Run("ForEach generate panics", func(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
verify(t, r)
|
||||
}
|
||||
}()
|
||||
|
||||
t.Run("all", func(t *testing.T) {
|
||||
assert.PanicsWithValue(t, "foo", func() {
|
||||
ForEach(func(source chan<- int) {
|
||||
panic("foo")
|
||||
}, func(item int) {
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestMapperPanic(t *testing.T) {
|
||||
defer goleak.VerifyNone(t)
|
||||
|
||||
const tasks = 1000
|
||||
var run int32
|
||||
t.Run("all", func(t *testing.T) {
|
||||
assert.PanicsWithValue(t, "foo", func() {
|
||||
t.Run("Mapper panics", func(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
verify(t, r)
|
||||
}
|
||||
}()
|
||||
|
||||
_, _ = MapReduce(func(source chan<- int) {
|
||||
for i := 0; i < tasks; i++ {
|
||||
source <- i
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build linux || darwin
|
||||
//go:build linux || darwin || freebsd
|
||||
|
||||
package proc
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build linux || darwin
|
||||
//go:build linux || darwin || freebsd
|
||||
|
||||
package proc
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build linux || darwin
|
||||
//go:build linux || darwin || freebsd
|
||||
|
||||
package proc
|
||||
|
||||
|
||||
@@ -4,6 +4,9 @@ package proc
|
||||
|
||||
import "time"
|
||||
|
||||
// ShutdownConf is empty on windows.
|
||||
type ShutdownConf struct{}
|
||||
|
||||
// AddShutdownListener returns fn itself on windows, lets callers call fn on their own.
|
||||
func AddShutdownListener(fn func()) func() {
|
||||
return fn
|
||||
@@ -18,6 +21,10 @@ func AddWrapUpListener(fn func()) func() {
|
||||
func SetTimeToForceQuit(duration time.Duration) {
|
||||
}
|
||||
|
||||
// Setup does nothing on windows.
|
||||
func Setup(conf ShutdownConf) {
|
||||
}
|
||||
|
||||
// Shutdown does nothing on windows.
|
||||
func Shutdown() {
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build linux || darwin
|
||||
//go:build linux || darwin || freebsd
|
||||
|
||||
package proc
|
||||
|
||||
@@ -14,17 +14,29 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
wrapUpTime = time.Second
|
||||
// why we use 5500 milliseconds is because most of our queue are blocking mode with 5 seconds
|
||||
waitTime = 5500 * time.Millisecond
|
||||
// defaultWrapUpTime is the default time to wait before calling wrap up listeners.
|
||||
defaultWrapUpTime = time.Second
|
||||
// defaultWaitTime is the default time to wait before force quitting.
|
||||
// why we use 5500 milliseconds is because most of our queues are blocking mode with 5 seconds
|
||||
defaultWaitTime = 5500 * time.Millisecond
|
||||
)
|
||||
|
||||
var (
|
||||
wrapUpListeners = new(listenerManager)
|
||||
shutdownListeners = new(listenerManager)
|
||||
delayTimeBeforeForceQuit = waitTime
|
||||
wrapUpListeners = new(listenerManager)
|
||||
shutdownListeners = new(listenerManager)
|
||||
wrapUpTime = defaultWrapUpTime
|
||||
waitTime = defaultWaitTime
|
||||
shutdownLock sync.Mutex
|
||||
)
|
||||
|
||||
// ShutdownConf defines the shutdown configuration for the process.
|
||||
type ShutdownConf struct {
|
||||
// WrapUpTime is the time to wait before calling shutdown listeners.
|
||||
WrapUpTime time.Duration `json:",default=1s"`
|
||||
// WaitTime is the time to wait before force quitting.
|
||||
WaitTime time.Duration `json:",default=5.5s"`
|
||||
}
|
||||
|
||||
// AddShutdownListener adds fn as a shutdown listener.
|
||||
// The returned func can be used to wait for fn getting called.
|
||||
func AddShutdownListener(fn func()) (waitForCalled func()) {
|
||||
@@ -39,7 +51,21 @@ func AddWrapUpListener(fn func()) (waitForCalled func()) {
|
||||
|
||||
// SetTimeToForceQuit sets the waiting time before force quitting.
|
||||
func SetTimeToForceQuit(duration time.Duration) {
|
||||
delayTimeBeforeForceQuit = duration
|
||||
shutdownLock.Lock()
|
||||
defer shutdownLock.Unlock()
|
||||
waitTime = duration
|
||||
}
|
||||
|
||||
func Setup(conf ShutdownConf) {
|
||||
shutdownLock.Lock()
|
||||
defer shutdownLock.Unlock()
|
||||
|
||||
if conf.WrapUpTime > 0 {
|
||||
wrapUpTime = conf.WrapUpTime
|
||||
}
|
||||
if conf.WaitTime > 0 {
|
||||
waitTime = conf.WaitTime
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown calls the registered shutdown listeners, only for test purpose.
|
||||
@@ -61,8 +87,12 @@ func gracefulStop(signals chan os.Signal, sig syscall.Signal) {
|
||||
time.Sleep(wrapUpTime)
|
||||
go shutdownListeners.notifyListeners()
|
||||
|
||||
time.Sleep(delayTimeBeforeForceQuit - wrapUpTime)
|
||||
logx.Infof("Still alive after %v, going to force kill the process...", delayTimeBeforeForceQuit)
|
||||
shutdownLock.Lock()
|
||||
remainingTime := waitTime - wrapUpTime
|
||||
shutdownLock.Unlock()
|
||||
|
||||
time.Sleep(remainingTime)
|
||||
logx.Infof("Still alive after %v, going to force kill the process...", waitTime)
|
||||
_ = syscall.Kill(syscall.Getpid(), sig)
|
||||
}
|
||||
|
||||
@@ -82,6 +112,9 @@ func (lm *listenerManager) addListener(fn func()) (waitForCalled func()) {
|
||||
})
|
||||
lm.lock.Unlock()
|
||||
|
||||
// we can return lm.waitGroup.Wait directly,
|
||||
// but we want to make the returned func more readable.
|
||||
// creating an extra closure would be negligible in practice.
|
||||
return func() {
|
||||
lm.waitGroup.Wait()
|
||||
}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
//go:build linux || darwin
|
||||
//go:build linux || darwin || freebsd
|
||||
|
||||
package proc
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -10,8 +11,12 @@ import (
|
||||
)
|
||||
|
||||
func TestShutdown(t *testing.T) {
|
||||
t.Cleanup(restoreSettings)
|
||||
|
||||
SetTimeToForceQuit(time.Hour)
|
||||
assert.Equal(t, time.Hour, delayTimeBeforeForceQuit)
|
||||
shutdownLock.Lock()
|
||||
assert.Equal(t, time.Hour, waitTime)
|
||||
shutdownLock.Unlock()
|
||||
|
||||
var val int
|
||||
called := AddWrapUpListener(func() {
|
||||
@@ -29,7 +34,53 @@ func TestShutdown(t *testing.T) {
|
||||
assert.Equal(t, 3, val)
|
||||
}
|
||||
|
||||
func TestShutdownWithMultipleServices(t *testing.T) {
|
||||
t.Cleanup(restoreSettings)
|
||||
|
||||
SetTimeToForceQuit(time.Hour)
|
||||
shutdownLock.Lock()
|
||||
assert.Equal(t, time.Hour, waitTime)
|
||||
shutdownLock.Unlock()
|
||||
|
||||
var val int32
|
||||
called1 := AddShutdownListener(func() {
|
||||
atomic.AddInt32(&val, 1)
|
||||
})
|
||||
called2 := AddShutdownListener(func() {
|
||||
atomic.AddInt32(&val, 2)
|
||||
})
|
||||
Shutdown()
|
||||
called1()
|
||||
called2()
|
||||
|
||||
assert.Equal(t, int32(3), atomic.LoadInt32(&val))
|
||||
}
|
||||
|
||||
func TestWrapUpWithMultipleServices(t *testing.T) {
|
||||
t.Cleanup(restoreSettings)
|
||||
|
||||
SetTimeToForceQuit(time.Hour)
|
||||
shutdownLock.Lock()
|
||||
assert.Equal(t, time.Hour, waitTime)
|
||||
shutdownLock.Unlock()
|
||||
|
||||
var val int32
|
||||
called1 := AddWrapUpListener(func() {
|
||||
atomic.AddInt32(&val, 1)
|
||||
})
|
||||
called2 := AddWrapUpListener(func() {
|
||||
atomic.AddInt32(&val, 2)
|
||||
})
|
||||
WrapUp()
|
||||
called1()
|
||||
called2()
|
||||
|
||||
assert.Equal(t, int32(3), atomic.LoadInt32(&val))
|
||||
}
|
||||
|
||||
func TestNotifyMoreThanOnce(t *testing.T) {
|
||||
t.Cleanup(restoreSettings)
|
||||
|
||||
ch := make(chan struct{}, 1)
|
||||
|
||||
go func() {
|
||||
@@ -58,3 +109,38 @@ func TestNotifyMoreThanOnce(t *testing.T) {
|
||||
t.Fatal("timeout, check error logs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
t.Run("valid time", func(t *testing.T) {
|
||||
defer restoreSettings()
|
||||
|
||||
Setup(ShutdownConf{
|
||||
WrapUpTime: time.Second * 2,
|
||||
WaitTime: time.Second * 30,
|
||||
})
|
||||
|
||||
shutdownLock.Lock()
|
||||
assert.Equal(t, time.Second*2, wrapUpTime)
|
||||
assert.Equal(t, time.Second*30, waitTime)
|
||||
shutdownLock.Unlock()
|
||||
})
|
||||
|
||||
t.Run("valid time", func(t *testing.T) {
|
||||
defer restoreSettings()
|
||||
|
||||
Setup(ShutdownConf{})
|
||||
|
||||
shutdownLock.Lock()
|
||||
assert.Equal(t, defaultWrapUpTime, wrapUpTime)
|
||||
assert.Equal(t, defaultWaitTime, waitTime)
|
||||
shutdownLock.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
func restoreSettings() {
|
||||
shutdownLock.Lock()
|
||||
defer shutdownLock.Unlock()
|
||||
|
||||
wrapUpTime = defaultWrapUpTime
|
||||
waitTime = defaultWaitTime
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build linux || darwin
|
||||
//go:build linux || darwin || freebsd
|
||||
|
||||
package proc
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build linux || darwin
|
||||
//go:build linux || darwin || freebsd
|
||||
|
||||
package proc
|
||||
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
package prof
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strconv"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/olekukonko/tablewriter"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/threading"
|
||||
)
|
||||
@@ -28,46 +27,15 @@ type (
|
||||
|
||||
const flushInterval = 5 * time.Minute
|
||||
|
||||
var (
|
||||
pc = &profileCenter{
|
||||
slots: make(map[string]*profileSlot),
|
||||
}
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
func report(name string, duration time.Duration) {
|
||||
updated := func() bool {
|
||||
pc.lock.RLock()
|
||||
defer pc.lock.RUnlock()
|
||||
|
||||
slot, ok := pc.slots[name]
|
||||
if ok {
|
||||
atomic.AddInt64(&slot.lifecount, 1)
|
||||
atomic.AddInt64(&slot.lastcount, 1)
|
||||
atomic.AddInt64(&slot.lifecycle, int64(duration))
|
||||
atomic.AddInt64(&slot.lastcycle, int64(duration))
|
||||
}
|
||||
return ok
|
||||
}()
|
||||
|
||||
if !updated {
|
||||
func() {
|
||||
pc.lock.Lock()
|
||||
defer pc.lock.Unlock()
|
||||
|
||||
pc.slots[name] = &profileSlot{
|
||||
lifecount: 1,
|
||||
lastcount: 1,
|
||||
lifecycle: int64(duration),
|
||||
lastcycle: int64(duration),
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
once.Do(flushRepeatly)
|
||||
var pc = &profileCenter{
|
||||
slots: make(map[string]*profileSlot),
|
||||
}
|
||||
|
||||
func flushRepeatly() {
|
||||
func init() {
|
||||
flushRepeatedly()
|
||||
}
|
||||
|
||||
func flushRepeatedly() {
|
||||
threading.GoSafe(func() {
|
||||
for {
|
||||
time.Sleep(flushInterval)
|
||||
@@ -76,42 +44,64 @@ func flushRepeatly() {
|
||||
})
|
||||
}
|
||||
|
||||
func report(name string, duration time.Duration) {
|
||||
slot := loadOrStoreSlot(name, duration)
|
||||
|
||||
atomic.AddInt64(&slot.lifecount, 1)
|
||||
atomic.AddInt64(&slot.lastcount, 1)
|
||||
atomic.AddInt64(&slot.lifecycle, int64(duration))
|
||||
atomic.AddInt64(&slot.lastcycle, int64(duration))
|
||||
}
|
||||
|
||||
func loadOrStoreSlot(name string, duration time.Duration) *profileSlot {
|
||||
pc.lock.RLock()
|
||||
slot, ok := pc.slots[name]
|
||||
pc.lock.RUnlock()
|
||||
|
||||
if ok {
|
||||
return slot
|
||||
}
|
||||
|
||||
pc.lock.Lock()
|
||||
defer pc.lock.Unlock()
|
||||
|
||||
// double-check
|
||||
if slot, ok = pc.slots[name]; ok {
|
||||
return slot
|
||||
}
|
||||
|
||||
slot = &profileSlot{}
|
||||
pc.slots[name] = slot
|
||||
return slot
|
||||
}
|
||||
|
||||
func generateReport() string {
|
||||
var buffer bytes.Buffer
|
||||
buffer.WriteString("Profiling report\n")
|
||||
var data [][]string
|
||||
var builder strings.Builder
|
||||
builder.WriteString("Profiling report\n")
|
||||
builder.WriteString("QUEUE,LIFECOUNT,LIFECYCLE,LASTCOUNT,LASTCYCLE\n")
|
||||
|
||||
calcFn := func(total, count int64) string {
|
||||
if count == 0 {
|
||||
return "-"
|
||||
}
|
||||
|
||||
return (time.Duration(total) / time.Duration(count)).String()
|
||||
}
|
||||
|
||||
func() {
|
||||
pc.lock.Lock()
|
||||
defer pc.lock.Unlock()
|
||||
pc.lock.Lock()
|
||||
for key, slot := range pc.slots {
|
||||
builder.WriteString(fmt.Sprintf("%s,%d,%s,%d,%s\n",
|
||||
key,
|
||||
slot.lifecount,
|
||||
calcFn(slot.lifecycle, slot.lifecount),
|
||||
slot.lastcount,
|
||||
calcFn(slot.lastcycle, slot.lastcount),
|
||||
))
|
||||
|
||||
for key, slot := range pc.slots {
|
||||
data = append(data, []string{
|
||||
key,
|
||||
strconv.FormatInt(slot.lifecount, 10),
|
||||
calcFn(slot.lifecycle, slot.lifecount),
|
||||
strconv.FormatInt(slot.lastcount, 10),
|
||||
calcFn(slot.lastcycle, slot.lastcount),
|
||||
})
|
||||
// reset last cycle stats
|
||||
atomic.StoreInt64(&slot.lastcount, 0)
|
||||
atomic.StoreInt64(&slot.lastcycle, 0)
|
||||
}
|
||||
pc.lock.Unlock()
|
||||
|
||||
// reset the data for last cycle
|
||||
slot.lastcount = 0
|
||||
slot.lastcycle = 0
|
||||
}
|
||||
}()
|
||||
|
||||
table := tablewriter.NewWriter(&buffer)
|
||||
table.SetHeader([]string{"QUEUE", "LIFECOUNT", "LIFECYCLE", "LASTCOUNT", "LASTCYCLE"})
|
||||
table.SetBorder(false)
|
||||
table.AppendBulk(data)
|
||||
table.Render()
|
||||
|
||||
return buffer.String()
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
)
|
||||
|
||||
func TestReport(t *testing.T) {
|
||||
once.Do(func() {})
|
||||
assert.NotContains(t, generateReport(), "foo")
|
||||
report("foo", time.Second)
|
||||
assert.Contains(t, generateReport(), "foo")
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"runtime/metrics"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -28,10 +30,29 @@ func displayStatsWithWriter(writer io.Writer, interval ...time.Duration) {
|
||||
ticker := time.NewTicker(duration)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
var (
|
||||
alloc, totalAlloc, sys uint64
|
||||
samples = []metrics.Sample{
|
||||
{Name: "/memory/classes/heap/objects:bytes"},
|
||||
{Name: "/gc/heap/allocs:bytes"},
|
||||
{Name: "/memory/classes/total:bytes"},
|
||||
}
|
||||
)
|
||||
metrics.Read(samples)
|
||||
|
||||
if samples[0].Value.Kind() == metrics.KindUint64 {
|
||||
alloc = samples[0].Value.Uint64()
|
||||
}
|
||||
if samples[1].Value.Kind() == metrics.KindUint64 {
|
||||
totalAlloc = samples[1].Value.Uint64()
|
||||
}
|
||||
if samples[2].Value.Kind() == metrics.KindUint64 {
|
||||
sys = samples[2].Value.Uint64()
|
||||
}
|
||||
var stats debug.GCStats
|
||||
debug.ReadGCStats(&stats)
|
||||
fmt.Fprintf(writer, "Goroutines: %d, Alloc: %vm, TotalAlloc: %vm, Sys: %vm, NumGC: %v\n",
|
||||
runtime.NumGoroutine(), m.Alloc/mega, m.TotalAlloc/mega, m.Sys/mega, m.NumGC)
|
||||
runtime.NumGoroutine(), alloc/mega, totalAlloc/mega, sys/mega, stats.NumGC)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/stat"
|
||||
"github.com/zeromicro/go-zero/core/trace"
|
||||
"github.com/zeromicro/go-zero/internal/devserver"
|
||||
"github.com/zeromicro/go-zero/internal/profiling"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -37,6 +38,9 @@ type (
|
||||
Prometheus prometheus.Config `json:",optional"`
|
||||
Telemetry trace.Config `json:",optional"`
|
||||
DevServer DevServerConfig `json:",optional"`
|
||||
Shutdown proc.ShutdownConf `json:",optional"`
|
||||
// Profiling is the configuration for continuous profiling.
|
||||
Profiling profiling.Config `json:",optional"`
|
||||
}
|
||||
)
|
||||
|
||||
@@ -61,6 +65,7 @@ func (sc ServiceConf) SetUp() error {
|
||||
sc.Telemetry.Name = sc.Name
|
||||
}
|
||||
trace.StartAgent(sc.Telemetry)
|
||||
proc.Setup(sc.Shutdown)
|
||||
proc.AddShutdownListener(func() {
|
||||
trace.StopAgent()
|
||||
})
|
||||
@@ -68,7 +73,9 @@ func (sc ServiceConf) SetUp() error {
|
||||
if len(sc.MetricsUrl) > 0 {
|
||||
stat.SetReportWriter(stat.NewRemoteWriter(sc.MetricsUrl))
|
||||
}
|
||||
|
||||
devserver.StartAgent(sc.DevServer)
|
||||
profiling.Start(sc.Profiling)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
"github.com/zeromicro/go-zero/core/threading"
|
||||
)
|
||||
|
||||
@@ -35,7 +36,7 @@ type (
|
||||
// NewServiceGroup returns a ServiceGroup.
|
||||
func NewServiceGroup() *ServiceGroup {
|
||||
sg := new(ServiceGroup)
|
||||
sg.stopOnce = syncx.Once(sg.doStop)
|
||||
sg.stopOnce = sync.OnceFunc(sg.doStop)
|
||||
return sg
|
||||
}
|
||||
|
||||
@@ -76,9 +77,14 @@ func (sg *ServiceGroup) doStart() {
|
||||
}
|
||||
|
||||
func (sg *ServiceGroup) doStop() {
|
||||
group := threading.NewRoutineGroup()
|
||||
for _, service := range sg.services {
|
||||
service.Stop()
|
||||
// new variable to avoid closure problems, can be removed after go 1.22
|
||||
// see https://golang.org/doc/faq#closures_and_goroutines
|
||||
service := service
|
||||
group.Run(service.Stop)
|
||||
}
|
||||
group.Wait()
|
||||
}
|
||||
|
||||
// WithStart wraps a start func as a Service.
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
const (
|
||||
clusterNameKey = "CLUSTER_NAME"
|
||||
testEnv = "test.v"
|
||||
timeFormat = "2006-01-02 15:04:05"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -45,7 +44,7 @@ func Report(msg string) {
|
||||
if fn != nil {
|
||||
reported := lessExecutor.DoOrDiscard(func() {
|
||||
var builder strings.Builder
|
||||
builder.WriteString(fmt.Sprintln(time.Now().Format(timeFormat)))
|
||||
builder.WriteString(fmt.Sprintln(time.Now().Format(time.DateTime)))
|
||||
if len(clusterName) > 0 {
|
||||
builder.WriteString(fmt.Sprintf("cluster: %s\n", clusterName))
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
package stat
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"runtime/metrics"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -56,8 +57,28 @@ func bToMb(b uint64) float32 {
|
||||
}
|
||||
|
||||
func printUsage() {
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
var (
|
||||
alloc, totalAlloc, sys uint64
|
||||
samples = []metrics.Sample{
|
||||
{Name: "/memory/classes/heap/objects:bytes"},
|
||||
{Name: "/gc/heap/allocs:bytes"},
|
||||
{Name: "/memory/classes/total:bytes"},
|
||||
}
|
||||
stats debug.GCStats
|
||||
)
|
||||
metrics.Read(samples)
|
||||
|
||||
if samples[0].Value.Kind() == metrics.KindUint64 {
|
||||
alloc = samples[0].Value.Uint64()
|
||||
}
|
||||
if samples[1].Value.Kind() == metrics.KindUint64 {
|
||||
totalAlloc = samples[1].Value.Uint64()
|
||||
}
|
||||
if samples[2].Value.Kind() == metrics.KindUint64 {
|
||||
sys = samples[2].Value.Uint64()
|
||||
}
|
||||
debug.ReadGCStats(&stats)
|
||||
|
||||
logx.Statf("CPU: %dm, MEMORY: Alloc=%.1fMi, TotalAlloc=%.1fMi, Sys=%.1fMi, NumGC=%d",
|
||||
CpuUsage(), bToMb(m.Alloc), bToMb(m.TotalAlloc), bToMb(m.Sys), m.NumGC)
|
||||
CpuUsage(), bToMb(alloc), bToMb(totalAlloc), bToMb(sys), stats.NumGC)
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
//go:generate mockgen -package mon -destination collectioninserter_mock.go -source bulkinserter.go collectionInserter
|
||||
package mon
|
||||
|
||||
import (
|
||||
@@ -6,7 +7,8 @@ import (
|
||||
|
||||
"github.com/zeromicro/go-zero/core/executors"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo/options"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -27,10 +29,7 @@ type (
|
||||
|
||||
// NewBulkInserter returns a BulkInserter.
|
||||
func NewBulkInserter(coll Collection, interval ...time.Duration) (*BulkInserter, error) {
|
||||
cloneColl, err := coll.Clone()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cloneColl := coll.Clone()
|
||||
|
||||
inserter := &dbInserter{
|
||||
collection: cloneColl,
|
||||
@@ -64,8 +63,16 @@ func (bi *BulkInserter) SetResultHandler(handler ResultHandler) {
|
||||
})
|
||||
}
|
||||
|
||||
type collectionInserter interface {
|
||||
InsertMany(
|
||||
ctx context.Context,
|
||||
documents interface{},
|
||||
opts ...options.Lister[options.InsertManyOptions],
|
||||
) (*mongo.InsertManyResult, error)
|
||||
}
|
||||
|
||||
type dbInserter struct {
|
||||
collection *mongo.Collection
|
||||
collection collectionInserter
|
||||
documents []any
|
||||
resultHandler ResultHandler
|
||||
}
|
||||
|
||||
@@ -1,26 +1,131 @@
|
||||
package mon
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
func TestBulkInserter(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "ok", Value: 1}}...))
|
||||
bulk, err := NewBulkInserter(createModel(mt).Collection)
|
||||
assert.Equal(t, err, nil)
|
||||
bulk.SetResultHandler(func(result *mongo.InsertManyResult, err error) {
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, len(result.InsertedIDs))
|
||||
})
|
||||
bulk.Insert(bson.D{{Key: "foo", Value: "bar"}})
|
||||
bulk.Insert(bson.D{{Key: "foo", Value: "baz"}})
|
||||
bulk.Flush()
|
||||
func TestBulkInserter_InsertAndFlush(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockCollection(ctrl)
|
||||
mockCollection.EXPECT().Clone().Return(&mongo.Collection{})
|
||||
bulkInserter, err := NewBulkInserter(mockCollection, time.Second)
|
||||
assert.NoError(t, err)
|
||||
bulkInserter.SetResultHandler(func(result *mongo.InsertManyResult, err error) {
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, len(result.InsertedIDs))
|
||||
})
|
||||
doc := map[string]interface{}{"name": "test"}
|
||||
bulkInserter.Insert(doc)
|
||||
bulkInserter.Flush()
|
||||
}
|
||||
|
||||
func TestBulkInserter_SetResultHandler(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockCollection(ctrl)
|
||||
mockCollection.EXPECT().Clone().Return(nil)
|
||||
bulkInserter, err := NewBulkInserter(mockCollection)
|
||||
assert.NoError(t, err)
|
||||
mockHandler := func(result *mongo.InsertManyResult, err error) {}
|
||||
bulkInserter.SetResultHandler(mockHandler)
|
||||
}
|
||||
|
||||
func TestDbInserter_RemoveAll(t *testing.T) {
|
||||
inserter := &dbInserter{}
|
||||
inserter.documents = []interface{}{}
|
||||
docs := inserter.RemoveAll()
|
||||
assert.NotNil(t, docs)
|
||||
assert.Empty(t, inserter.documents)
|
||||
}
|
||||
|
||||
func Test_dbInserter_Execute(t *testing.T) {
|
||||
type fields struct {
|
||||
collection collectionInserter
|
||||
documents []any
|
||||
resultHandler ResultHandler
|
||||
}
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockcollectionInserter(ctrl)
|
||||
type args struct {
|
||||
objs any
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
mock func()
|
||||
}{
|
||||
{
|
||||
name: "empty doc",
|
||||
fields: fields{
|
||||
collection: nil,
|
||||
documents: nil,
|
||||
resultHandler: nil,
|
||||
},
|
||||
args: args{
|
||||
objs: make([]any, 0),
|
||||
},
|
||||
mock: func() {},
|
||||
},
|
||||
{
|
||||
name: "result handler",
|
||||
fields: fields{
|
||||
collection: mockCollection,
|
||||
resultHandler: func(result *mongo.InsertManyResult, err error) {
|
||||
assert.NotNil(t, err)
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
objs: make([]any, 1),
|
||||
},
|
||||
mock: func() {
|
||||
mockCollection.EXPECT().InsertMany(gomock.Any(), gomock.Any()).Return(&mongo.InsertManyResult{}, errors.New("error"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "normal error handler",
|
||||
fields: fields{
|
||||
collection: mockCollection,
|
||||
resultHandler: nil,
|
||||
},
|
||||
args: args{
|
||||
objs: make([]any, 1),
|
||||
},
|
||||
mock: func() {
|
||||
mockCollection.EXPECT().InsertMany(gomock.Any(), gomock.Any()).Return(&mongo.InsertManyResult{}, errors.New("error"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no error",
|
||||
fields: fields{
|
||||
collection: mockCollection,
|
||||
resultHandler: nil,
|
||||
},
|
||||
args: args{
|
||||
objs: make([]any, 1),
|
||||
},
|
||||
mock: func() {
|
||||
mockCollection.EXPECT().InsertMany(gomock.Any(), gomock.Any()).Return(&mongo.InsertManyResult{}, nil)
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.mock()
|
||||
in := &dbInserter{
|
||||
collection: tt.fields.collection,
|
||||
documents: tt.fields.documents,
|
||||
resultHandler: tt.fields.resultHandler,
|
||||
}
|
||||
in.Execute(tt.args.objs)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
"io"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
mopt "go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo/options"
|
||||
)
|
||||
|
||||
var clientManager = syncx.NewResourceManager()
|
||||
@@ -29,13 +29,13 @@ func Inject(key string, client *mongo.Client) {
|
||||
|
||||
func getClient(url string, opts ...Option) (*mongo.Client, error) {
|
||||
val, err := clientManager.GetResource(url, func() (io.Closer, error) {
|
||||
o := mopt.Client().ApplyURI(url)
|
||||
o := options.Client().ApplyURI(url)
|
||||
opts = append([]Option{defaultTimeoutOption()}, opts...)
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
}
|
||||
|
||||
cli, err := mongo.Connect(context.Background(), o)
|
||||
cli, err := mongo.Connect(o)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -4,19 +4,13 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo"
|
||||
)
|
||||
|
||||
func init() {
|
||||
_ = mtest.Setup()
|
||||
}
|
||||
|
||||
func TestClientManger_getClient(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
Inject(mtest.ClusterURI(), mt.Client)
|
||||
cli, err := getClient(mtest.ClusterURI())
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, mt.Client, cli)
|
||||
})
|
||||
c := &mongo.Client{}
|
||||
Inject("foo", c)
|
||||
cli, err := getClient("foo")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, c, cli)
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user