Compare commits

...

211 Commits

Author SHA1 Message Date
dependabot[bot]
c38830377e chore(deps): bump the go_modules group across 1 directory with 6 updates
Bumps the go_modules group with 5 updates in the / directory:

| Package | From | To |
| --- | --- | --- |
| [github.com/jackc/pgx/v5](https://github.com/jackc/pgx) | `5.8.0` | `5.9.2` |
| [github.com/modelcontextprotocol/go-sdk](https://github.com/modelcontextprotocol/go-sdk) | `1.4.0` | `1.4.1` |
| [go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp](https://github.com/open-telemetry/opentelemetry-go) | `1.40.0` | `1.43.0` |
| [filippo.io/edwards25519](https://github.com/FiloSottile/edwards25519) | `1.1.0` | `1.1.1` |
| [github.com/go-jose/go-jose/v4](https://github.com/go-jose/go-jose) | `4.1.3` | `4.1.4` |



Updates `github.com/jackc/pgx/v5` from 5.8.0 to 5.9.2
- [Changelog](https://github.com/jackc/pgx/blob/master/CHANGELOG.md)
- [Commits](https://github.com/jackc/pgx/compare/v5.8.0...v5.9.2)

Updates `github.com/modelcontextprotocol/go-sdk` from 1.4.0 to 1.4.1
- [Release notes](https://github.com/modelcontextprotocol/go-sdk/releases)
- [Commits](https://github.com/modelcontextprotocol/go-sdk/compare/v1.4.0...v1.4.1)

Updates `go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp` from 1.40.0 to 1.43.0
- [Release notes](https://github.com/open-telemetry/opentelemetry-go/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-go/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-go/compare/v1.40.0...v1.43.0)

Updates `go.opentelemetry.io/otel/sdk` from 1.40.0 to 1.43.0
- [Release notes](https://github.com/open-telemetry/opentelemetry-go/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-go/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-go/compare/v1.40.0...v1.43.0)

Updates `filippo.io/edwards25519` from 1.1.0 to 1.1.1
- [Commits](https://github.com/FiloSottile/edwards25519/compare/v1.1.0...v1.1.1)

Updates `github.com/go-jose/go-jose/v4` from 4.1.3 to 4.1.4
- [Release notes](https://github.com/go-jose/go-jose/releases)
- [Commits](https://github.com/go-jose/go-jose/compare/v4.1.3...v4.1.4)

---
updated-dependencies:
- dependency-name: github.com/jackc/pgx/v5
  dependency-version: 5.9.2
  dependency-type: direct:production
  dependency-group: go_modules
- dependency-name: github.com/modelcontextprotocol/go-sdk
  dependency-version: 1.4.1
  dependency-type: direct:production
  dependency-group: go_modules
- dependency-name: go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp
  dependency-version: 1.43.0
  dependency-type: direct:production
  dependency-group: go_modules
- dependency-name: go.opentelemetry.io/otel/sdk
  dependency-version: 1.43.0
  dependency-type: direct:production
  dependency-group: go_modules
- dependency-name: filippo.io/edwards25519
  dependency-version: 1.1.1
  dependency-type: indirect
  dependency-group: go_modules
- dependency-name: github.com/go-jose/go-jose/v4
  dependency-version: 4.1.4
  dependency-type: indirect
  dependency-group: go_modules
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-05-11 15:57:30 +00:00
dependabot[bot]
3738be1945 chore(deps): bump go.mongodb.org/mongo-driver/v2 from 2.5.0 to 2.6.0 (#5558)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-11 23:17:45 +08:00
Kevin Wan
5b74b9ab7b feat(mcp): add opt-in request metadata bridge for tool handlers (#5550) 2026-04-25 17:11:04 +08:00
Kevin Wan
4a67261b7b fix(discov): move etcd hosts from URI authority to path for Go 1.26 compatibility (#5548) 2026-04-25 10:48:28 +08:00
dependabot[bot]
22bdae0787 chore(deps): bump codecov/codecov-action from 5 to 6 (#5521) 2026-04-11 15:41:58 +08:00
dependabot[bot]
e8675d6a9a chore(deps): bump google.golang.org/grpc from 1.79.3 to 1.80.0 (#5523) 2026-04-11 15:06:08 +08:00
dependabot[bot]
e441c44975 chore(deps): bump google.golang.org/grpc from 1.79.3 to 1.80.0 in /tools/goctl (#5524) 2026-04-11 10:46:11 +08:00
Kevin Wan
3f91a79a2b chore: update goctl version to v1.10.1 and bump go-zero dependency (#5518) 2026-03-28 23:01:48 +08:00
Name
8c47c01739 fix(rest/httpc): reject request body for HEAD method in buildRequest (#5457)
Co-authored-by: 1911860538 <alxps1911@gmail.com>
2026-03-28 14:16:53 +00:00
dependabot[bot]
f59a1cb0de chore(deps): bump github.com/grafana/pyroscope-go from 1.2.7 to 1.2.8 (#5513)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-28 08:49:57 +08:00
dependabot[bot]
d44ff6ddc8 chore(deps): bump github.com/pelletier/go-toml/v2 from 2.2.4 to 2.3.0 (#5512)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-28 08:35:12 +08:00
Kevin Wan
6ffa9cabec chore: reorder Eval/EvalCtx after Do/DoCtx in redis.go for consistent method ordering (#5502) 2026-03-22 20:39:30 +08:00
Ran丶
0069721586 feat(redis): add Do/DoCtx for generic command execution #5417 (#5442) 2026-03-22 12:26:53 +00:00
Kevin Wan
ba9c275853 chore: upgrade Go version to 1.24 and update dependencies (#5499) 2026-03-22 18:47:43 +08:00
fyyang
9a6447ab5c feat: goctl model Add a new method hasField (#5484) 2026-03-22 06:26:56 +00:00
kesonan
004995f06a feat(goctl/rpc): support external proto imports with cross-package ty… (#5472)
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-03-22 04:01:20 +00:00
Amshith Nair
c12c82b2f6 test(mathx,stringx): add missing edge case tests for CalcEntropy and … (#5471) 2026-03-22 03:24:02 +00:00
Name
85d770d340 perf(core/stringx): replace manual char filter with strings.Map (#5453)
Co-authored-by: 1911860538 <alxps1911@gmail.com>
2026-03-22 03:04:12 +00:00
Name
8cd7f7a2d8 refactor(core): replace TakeOne usage with cmp.Or (#5461)
Co-authored-by: 1911860538 <alxps1911@gmail.com>
2026-03-22 02:50:16 +00:00
Amshith Nair
db3101361b docs(mathx): add godoc comment to Numerical type constraint (#5470) 2026-03-21 15:25:37 +00:00
kesonan
eb2302b71e fix(swagger): add example field to path/form/header parameters (#5497)
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-03-21 15:18:55 +00:00
Amshith Nair
04ed637366 test(hash): add unit tests for Hash, Hash determinism, and Md5Hex edg… (#5469) 2026-03-15 15:02:57 +00:00
Kevin Wan
567087a715 test(goctl): add regression test for per-service type alias filtering (#5481) (#5483)
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-03-15 22:51:27 +08:00
kesonan
4d2e64a417 fix #5481 (#5482) 2026-03-15 14:14:12 +00:00
kesonan
b01831b4c5 (goctl)fix file copy permission missed (#5475) 2026-03-15 13:55:27 +00:00
Kevin Wan
d1a014955c fix: critical security fixes in core/codec (S0) (#5479) 2026-03-15 16:40:15 +08:00
Kevin Wan
ec802e25a6 feat: add JSON5 configuration support (#5433)
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-03-14 21:19:46 +08:00
dependabot[bot]
8a2e09dfd1 chore(deps): bump github.com/alicebob/miniredis/v2 from 2.36.1 to 2.37.0 (#5444)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-28 20:29:53 +08:00
dependabot[bot]
220d438fe7 chore(deps): bump github.com/modelcontextprotocol/go-sdk from 1.3.0 to 1.3.1 (#5435) 2026-02-21 13:04:51 +08:00
dependabot[bot]
2cd96146fa chore(deps): bump github.com/redis/go-redis/v9 from 9.17.3 to 9.18.0 (#5432)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-17 20:21:10 +08:00
Kevin Wan
7e96317fad chore: update goctl version (#5431) 2026-02-15 20:25:43 +08:00
Kevin Wan
70728ce2e2 chore: update go-zero version (#5430) 2026-02-15 19:29:30 +08:00
dependabot[bot]
6a72a735d4 chore(deps): bump github.com/modelcontextprotocol/go-sdk from 1.2.0 to 1.3.0 (#5413) 2026-02-12 22:42:08 +08:00
Kevin Wan
b139a82c2e fix: resolve data race in service discovery map access (#5408) 2026-02-06 23:16:05 +08:00
Kevin Wan
bdddf1f30c feat(gateway): export WithDialer option for custom gRPC client configuration (#5406) 2026-02-06 21:50:50 +08:00
dependabot[bot]
9b74b7e09e chore(deps): bump github.com/emicklei/proto from 1.14.2 to 1.14.3 in /tools/goctl (#5403)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-06 21:01:11 +08:00
RUGBEN
4d5ed2c45d fix(conf): support equal signs in property values (#5392)
Co-authored-by: liaogs <liaoguoshun@qq.com>
2026-02-01 04:29:16 +00:00
dependabot[bot]
a2310bf9d7 chore(deps): bump go.mongodb.org/mongo-driver/v2 from 2.4.2 to 2.5.0 (#5393) 2026-01-31 07:18:09 +08:00
dependabot[bot]
be846eba01 chore(deps): bump github.com/redis/go-redis/v9 from 9.17.2 to 9.17.3 (#5390) 2026-01-29 12:32:11 +08:00
Kevin Wan
b20f0e3d60 test(conf): add comprehensive validation tests for Load function (#5388) 2026-01-25 07:21:05 +08:00
RUGBEN
e2bb65d43c fix(conf): Remove redundant validation (#5372)
Co-authored-by: liaogs <liaoguoshun@qq.com>
2026-01-24 15:46:40 +00:00
mk0walsk
94e2f5bd12 Refactor routes and harden AddTool (#5375) 2026-01-24 12:13:35 +00:00
Kevin Wan
173f76acf9 feat: add cmdline argument to control whether generate package name from proto filename (#5387) 2026-01-24 19:47:14 +08:00
godLei6
6e1af75635 rpc service use proto.Package.Name by support multi proto file (#5378) 2026-01-24 08:44:44 +00:00
dependabot[bot]
84ff755e61 chore(deps): bump github.com/alicebob/miniredis/v2 from 2.36.0 to 2.36.1 (#5386)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-24 15:06:06 +08:00
dependabot[bot]
4b9d23aef5 chore(deps): bump github.com/alicebob/miniredis/v2 from 2.35.0 to 2.36.0 (#5381) 2026-01-23 21:55:24 +08:00
dependabot[bot]
97b9aebe99 chore(deps): bump go.mongodb.org/mongo-driver/v2 from 2.4.1 to 2.4.2 (#5385) 2026-01-23 07:44:19 +08:00
Kevin Wan
8e7e5695eb feat(mcp): migrate to official go-sdk with simplified API (#5362) 2025-12-26 00:21:45 +08:00
Kevin Wan
4b4751e76c chore: remove jaeger exporter due to official deprecation (#5361)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-12-25 23:03:14 +08:00
Kevin Wan
fcec494ea8 fix: ignore context cancel on triggering breaker of httpc (#5360)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-12-25 21:39:45 +08:00
Kevin Wan
42117c2dcc feat: upgrade go to version 1.23 (#5359) 2025-12-25 21:08:36 +08:00
dependabot[bot]
4b631f3785 chore(deps): bump github.com/zeromicro/go-zero from 1.9.3 to 1.9.4 in /tools/goctl (#5356)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-25 14:15:09 +08:00
Qiu shao
f29c8612e8 fix(zrpc): fix slow threshold priority in stat interceptor (#5310)
Co-authored-by: qiuwenhao <qiushaotest@qq.com>
2025-12-23 14:45:33 +00:00
Kevin Wan
35ba024103 chore: refactor code (#5352) 2025-12-23 22:29:52 +08:00
kesonan
52df1c532a Fix the issue of incorrect values notified in the configuration center (#5348) 2025-12-23 22:06:04 +08:00
Kevin Wan
39729f3756 fix(discov): add retry cooldown to prevent CPU/disk exhaustion on auth errors (#5347) 2025-12-20 22:04:38 +08:00
Kevin Wan
5c9ea81db2 docs: simplify README files while preserving structure (#5338) 2025-12-13 13:01:35 +08:00
Qiu shao
b284664de4 perf(mapping): use strings.EqualFold to optimize bool parsing (#5324)
Co-authored-by: qiuwenhao <qiushaotest@qq.com>
2025-12-12 15:24:10 +00:00
Ran丶
1b76885040 feat(redis): add redis command for getex (#5323) 2025-12-12 15:18:46 +00:00
Kevin Wan
eef217522b chore: simplify readme (#5334) 2025-12-12 22:32:45 +08:00
Kevin Wan
6bd0d169d5 docs: add AI-Native Development section to README (#5333) 2025-12-12 22:28:47 +08:00
soasurs
3d291328d8 feat(zrpc): migrate kube resolver from Endpoints to EndpointSlice API (#4987)
Signed-off-by: soasurs <soasurs@gmail.com>
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
2025-12-11 23:09:08 +08:00
dependabot[bot]
858f8ca82e chore(deps): bump go.mongodb.org/mongo-driver/v2 from 2.4.0 to 2.4.1 (#5329)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-11 20:33:49 +08:00
Qiu shao
4ff3975c5a perf(config): optimize getfullname (#5328)
Co-authored-by: qiuwenhao <qiushaotest@qq.com>
2025-12-10 14:46:26 +00:00
Kevin Wan
7b23f73268 fix(timingwheel): add missing Wait() call and improve code clarity (#5315)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-12-07 11:37:56 +08:00
Kevin Wan
918a7be698 docs: enhance copilot instructions with detailed architecture patterns (#5313)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-12-06 23:39:28 +08:00
dependabot[bot]
0a724447cd chore(deps): bump github.com/spf13/cobra from 1.10.1 to 1.10.2 in /tools/goctl (#5312) 2025-12-05 09:30:53 +08:00
Gregor Fischer
9e425893a7 Fix typos and grammar in comments (#5308) 2025-12-03 14:32:49 +00:00
dependabot[bot]
4de13b6cc8 chore(deps): bump github.com/redis/go-redis/v9 from 9.17.1 to 9.17.2 (#5307)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-03 22:25:56 +08:00
Qiu shao
c6f75532fa Fix/logx test log mismatchtime schedule num (#5305)
Co-authored-by: qiuwenhao <qiushaotest@qq.com>
2025-11-30 13:44:53 +00:00
dependabot[bot]
fdf4ccf057 chore(deps): bump github.com/zeromicro/go-zero from 1.9.2 to 1.9.3 in /tools/goctl (#5289) 2025-11-29 22:16:18 +08:00
dependabot[bot]
b333ed245b chore(deps): bump github.com/redis/go-redis/v9 from 9.17.0 to 9.17.1 (#5301) 2025-11-28 17:02:14 +08:00
dependabot[bot]
8f1576df36 chore(deps): bump actions/checkout from 5 to 6 (#5297) 2025-11-25 23:05:20 +08:00
Gregor Fischer
72dd970969 Fix Grammar and Typo in Comments (#5284) 2025-11-20 21:26:50 +08:00
dependabot[bot]
29b65e12c1 chore(deps): bump github.com/redis/go-redis/v9 from 9.16.0 to 9.17.0 (#5285)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-20 20:57:32 +08:00
Yuntsy
577a611dc3 fix(logx): add missing color for levelSevere in wrapLevelWithColor (#5281) 2025-11-19 14:46:28 +00:00
Kevin Wan
75941aedd4 refactor: simplify getValueInterface function (#5280) 2025-11-16 20:36:49 +08:00
lerity-yao
c7065171d7 fix(orm): properly handle zero value scanning for pointer destinations (#5270) 2025-11-16 11:59:13 +00:00
Gregor Fischer
052de3b552 chore: fix grammar and typos in comments (#5279) 2025-11-16 11:27:17 +00:00
Kevin Wan
866613af8c Update readme-cn.md (#5266) 2025-11-11 22:24:17 +08:00
Kevin Wan
3d4f6a5e16 Add company to the user list (#5264) 2025-11-01 22:00:35 +08:00
dependabot[bot]
d1d47d02d5 chore(deps): bump go.mongodb.org/mongo-driver/v2 from 2.3.1 to 2.4.0 (#5262) 2025-10-29 09:53:57 +08:00
Kevin Wan
d6c876860b feat(zrpc): change NonBlock default to true following gRPC best practices (#5259) 2025-10-26 12:56:34 +00:00
Kevin Wan
98423ca948 fix(goctl): use rest.Serverless for generated integration tests (#5258) 2025-10-25 23:14:54 +08:00
Kevin Wan
4e52d77ad8 fix(trace): use sync.Once to prevent multiple trace initialization (#5244)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-10-25 20:10:15 +08:00
Kevin Wan
1fc2cfb859 fix: gateway trace headers 5248 (#5256)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-10-25 12:16:31 +08:00
Gregor Fischer
942cdae41d Fix typos in comments and messages (#5254) 2025-10-25 01:28:40 +00:00
dependabot[bot]
e9c3607bc6 chore(deps): bump github.com/redis/go-redis/v9 from 9.14.1 to 9.16.0 (#5255)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-24 21:19:49 +08:00
dependabot[bot]
d1603e9166 chore(deps): bump github.com/redis/go-redis/v9 from 9.14.0 to 9.14.1 (#5251)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-22 21:01:51 +08:00
zhoushuguang
e30317e9c4 feat: consistent hash balancer support (#5246)
Co-authored-by: 周曙光 <zsg@zhoushuguangdeMacBook-Pro.local>
2025-10-19 14:10:30 +00:00
stemlaud
568f9ce007 chore: remove extra spaces in the comment (#5245)
Signed-off-by: stemlaud <stemlaud@outlook.com>
2025-10-19 13:42:10 +00:00
dependabot[bot]
dcb309065a chore(deps): bump github/codeql-action from 3 to 4 (#5243)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-13 23:09:38 +00:00
Kevin Wan
bf8e17a686 test: add unit tests for goctl docker command (PR #4343) (#5241) 2025-10-13 21:55:25 +08:00
Jack001
b2ebbfce62 fix: ensure Dockerfile includes etc directory and correct CMD based on config (#4343)
Co-authored-by: 白少杰macpro <harrellharris68491@gmail.com>
2025-10-13 13:42:15 +00:00
Kevin Wan
2b10a6a223 fix: support PUT, PATCH, DELETE methods for request body definitions in swagger (#5239) 2025-10-12 18:24:11 +08:00
Kevin Wan
80c320b46e chore: remove unused code (#5238) 2025-10-12 11:55:57 +08:00
Kevin Wan
bea9d150a1 fix(goctl): restore API summaries in swagger generation (#5237) 2025-10-12 11:38:58 +08:00
Kevin Wan
3f756a2cbf chore: update goctl version (#5236) 2025-10-11 18:01:18 +08:00
Kevin Wan
bbe5bbb0c0 chore: update go-redis for the retracted versions (#5235) 2025-10-11 16:20:23 +08:00
dependabot[bot]
5ad2278a69 chore(deps): bump go.mongodb.org/mongo-driver/v2 from 2.3.0 to 2.3.1 (#5230)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-10 21:32:04 +08:00
ZH1995
77763fe748 Fix api doc url in readme-cn.md (#5233) 2025-10-10 13:21:46 +00:00
Kevin Wan
538c4fb5c7 fix: issue #5154, not using sse template files (#5220) 2025-10-08 11:56:52 +00:00
Kevin Wan
315fb2fe0a Add company to the list in readme-cn.md (#5222) 2025-10-07 21:04:09 +08:00
Copilot
e382887eb8 docs: Add comprehensive documentation for blocking Redis operations (XReadGroup, Blpop) (#5221)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: kevwan <1918356+kevwan@users.noreply.github.com>
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
2025-10-07 16:46:10 +08:00
Kevin Wan
cf21cb2b0b chore: refactor to remove duplicated code (#5216) 2025-10-06 22:24:44 +08:00
Copilot
61e8894c31 Fix swagger generation: info block and server tags not included (#5215)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: kevwan <1918356+kevwan@users.noreply.github.com>
2025-10-06 22:02:42 +08:00
Copilot
7a6c3c8129 Fix swagger path generation: remove trailing slash for root routes with prefix (#5212) 2025-10-05 12:12:03 +08:00
Rizky Ikwan
875fec3e1a chore: fix typos (#5210) 2025-10-04 02:56:07 +00:00
Kevin Wan
60128c2100 chore: update goctl version (#5205) 2025-10-02 22:48:57 +08:00
Copilot
ce6d0e3ea7 fix(goctl/swagger): correct $ref placement in array definitions when useDefinitions is enabled (#5199)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: kevwan <1918356+kevwan@users.noreply.github.com>
2025-10-02 22:11:18 +08:00
Kevin Wan
fa85c84af3 chore: code refactoring (#5204) 2025-10-02 21:48:03 +08:00
Remember
440884105e feat(handler): add sseSlowThreshold (#5196) 2025-10-02 13:34:44 +00:00
Copilot
271f10598f Add complete test scaffolding support with --test flag for API projects (#5176)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: kevwan <1918356+kevwan@users.noreply.github.com>
2025-09-27 21:13:13 +08:00
Remember
cf55a88ce3 fix(rest): change SSE SetWriteDeadline error log to debug level (#5162) 2025-09-27 12:48:35 +00:00
dependabot[bot]
c1c786b14a chore(deps): bump github.com/redis/go-redis/v9 from 9.14.0 to 9.15.0 (#5193)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-09-27 11:51:15 +08:00
Remember
988fb9d9bf fix: SSE handler blocking (#5181)
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-09-26 13:53:42 +00:00
Copilot
d212c81bca Add GitHub Copilot instructions for go-zero project (#5178)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: kevwan <1918356+kevwan@users.noreply.github.com>
2025-09-20 13:40:43 +08:00
Kevin Wan
bc43df2641 optimize: mapreduce panic stacktrace (#5168) 2025-09-14 19:33:09 +08:00
dependabot[bot]
351b8cb37b chore(deps): bump github.com/redis/go-redis/v9 from 9.13.0 to 9.14.0 (#5169)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-09-11 22:07:13 +08:00
wanwu
0d681a2e29 opt: optimization of machine performance data reading (#5174)
Co-authored-by: sam.yang <sam.yang@yijinin.com>
2025-09-11 13:56:07 +00:00
dependabot[bot]
5ea027c5de chore(deps): bump actions/setup-go from 5 to 6 (#5156)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-09-09 21:07:44 +08:00
dependabot[bot]
5de6112dcd chore(deps): bump actions/stale from 9 to 10 (#5157)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-09-09 18:55:45 +08:00
Kevin Wan
4fb51723b7 Add company to the list (#5153) 2025-09-07 22:53:49 +08:00
me-cs
06502d1115 update:optimize slice find and Unquote func (#5108) 2025-09-07 00:41:45 +00:00
kesonan
3854d6dd00 fix array type generation error (#5142) 2025-09-04 13:41:15 +00:00
dependabot[bot]
895854913a chore(deps): bump github.com/spf13/pflag from 1.0.7 to 1.0.10 in /tools/goctl (#5141)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-09-04 21:35:50 +08:00
dependabot[bot]
ef753b8857 chore(deps): bump github.com/spf13/cobra from 1.9.1 to 1.10.1 in /tools/goctl (#5147)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-09-04 21:28:40 +08:00
dependabot[bot]
9c16fede73 chore(deps): bump github.com/redis/go-redis/v9 from 9.12.1 to 9.13.0 (#5149)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-09-04 21:17:33 +08:00
Kevin Wan
ce11adb5e4 feat: add code generation headers in safe to edit files (#5136) 2025-09-01 21:27:30 +08:00
Kevin Wan
894e8b1218 chore: update goctl version (#5138) 2025-08-31 23:37:00 +08:00
Kevin Wan
2ec7e432dd chore: refactor (#5137) 2025-08-31 17:35:52 +08:00
guonaihong
870e8352c1 fix:issue-5110 (#5113) 2025-08-31 09:17:34 +00:00
Qiying Wang
de42f27e03 feat: prefer json.Marshaler over fmt.Stringer for JSON log output whe… (#5117) 2025-08-31 09:06:25 +00:00
Kevin Wan
955b8016aa feat: support goctl --module to set go module (#5135) 2025-08-31 16:40:49 +08:00
dependabot[bot]
d728a3b2d9 chore(deps): bump github.com/stretchr/testify from 1.11.0 to 1.11.1 in /tools/goctl (#5124)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-08-31 10:51:06 +08:00
dependabot[bot]
0c205a71fc chore(deps): bump github.com/gookit/color from 1.5.4 to 1.6.0 in /tools/goctl (#5132)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-08-31 10:42:06 +08:00
dependabot[bot]
a8c0199d96 chore(deps): bump github.com/grafana/pyroscope-go from 1.2.4 to 1.2.7 (#5121) 2025-08-28 08:47:40 +08:00
dependabot[bot]
032a266ec4 chore(deps): bump github.com/stretchr/testify from 1.11.0 to 1.11.1 (#5125) 2025-08-28 08:46:29 +08:00
dependabot[bot]
40b75fbb9b chore(deps): bump github.com/stretchr/testify from 1.10.0 to 1.11.0 (#5120) 2025-08-27 00:28:34 +08:00
dependabot[bot]
afad55045b chore(deps): bump github.com/stretchr/testify from 1.10.0 to 1.11.0 in /tools/goctl (#5119) 2025-08-26 21:09:57 +08:00
Kevin Wan
5f54f06ee5 chore: refactor field keys in logx (#5104) 2025-08-20 20:48:47 +08:00
Qiying Wang
20f56ae1d0 feat: support customize of log keys (#5103) 2025-08-20 12:11:45 +00:00
geekeryy
73d6fcfccd feat: Support projectPkg template variables in config, handler, logic, main, and svc template files (#4939) 2025-08-19 12:29:41 +00:00
Kevin Wan
20d20ef861 fix: github release workflow (#5096) 2025-08-17 23:04:32 +08:00
Kevin Wan
a37422b504 fix: release workflows (#5095) 2025-08-17 17:50:38 +08:00
Kevin Wan
a81d898408 chore: update go-zero version (#5093)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-08-17 17:00:15 +08:00
kesonan
a5d42e20d5 Update goctl version to 1.9.0-alpha (#5090) 2025-08-15 12:08:02 +00:00
kesonan
4bdb07f225 (goctl)feature: supported sse generation (#5082)
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-08-15 01:49:50 +00:00
dependabot[bot]
3e6ec9b83d chore(deps): bump go.mongodb.org/mongo-driver/v2 from 2.2.3 to 2.3.0 (#5087) 2025-08-15 05:08:22 +08:00
me-cs
f0a3d213dc update:simplify slice lookup by using slices.Contains() (#5084)
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-08-14 15:33:25 +00:00
Kevin Wan
94562ded74 chore: fix compile errors (#5085)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-08-14 23:09:35 +08:00
guonaihong
d68cf4920c fix: 5076 (#5083)
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-08-14 22:32:09 +08:00
dependabot[bot]
31b749ab67 chore(deps): bump actions/checkout from 4 to 5 (#5080) 2025-08-12 15:15:25 +08:00
dependabot[bot]
3834319278 chore(deps): bump github.com/redis/go-redis/v9 from 9.12.0 to 9.12.1 (#5078) 2025-08-12 12:27:33 +08:00
Kevin Wan
1c9d339361 chore: use uber gomock instead of golang/mock (#5075) 2025-08-10 23:01:08 +08:00
Kevin Wan
b7f601c912 feat: support sse in api files (#5074)
Signed-off-by: Kevin Wan <wanjunfeng@gmail.com>
2025-08-10 22:17:08 +08:00
Kevin Wan
1ebbc6f0c7 chore: refactor mon/monc (#5073) 2025-08-09 23:51:44 +08:00
me-cs
b41b1b00df update:github.com/mongodb/mongo-go-driver v2.0 Migration (#4687) 2025-08-09 13:21:53 +00:00
hoshi
f36e5fed35 fear(model);add uuid:varchar to p2m (#5022)
Co-authored-by: hoshi <zheng.hao1@outlook.com>
2025-08-09 06:24:33 +00:00
Kevin Wan
2583673c8b chore: always ignore unknown fields for gateway requests (#5072)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-08-09 13:07:17 +08:00
guonaihong
00e67b9d20 feat: add option to ignore unknown fields in gateway request parsing (#5058) 2025-08-09 04:40:07 +00:00
Kevin Wan
9fd1f29845 chore: refactor (#5071) 2025-08-09 12:27:43 +08:00
Ioannis Pinakoulakis
130e1ba963 perf: pre-allocate all known length arrays to avoid re-scaling (#5029)
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
2025-08-08 16:03:25 +00:00
Kevin Wan
a2b98dbcf7 chore: refactor (#5068) 2025-08-08 23:30:03 +08:00
sunhao1296
b46d507a1d fix: resolve concurrent get may lead to empty result in ImmutableResource (#5065)
Co-authored-by: hsun <hsun@apac.freewheel.com>
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-08-08 23:00:49 +08:00
Kevin Wan
3152581d0d optimize: logging with fields (#5066) 2025-08-08 12:23:10 +00:00
dependabot[bot]
46e466f037 chore(deps): bump github.com/redis/go-redis/v9 from 9.11.0 to 9.12.0 (#5056)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-08-06 01:18:27 +00:00
Kevin Wan
151b3d1085 chore: fix codecov warning (#5055)
Signed-off-by: Kevin Wan <wanjunfeng@gmail.com>
2025-08-05 22:55:26 +08:00
Kevin Wan
ea53fe41de chore: fix codecov problem (#5053) 2025-08-05 21:37:22 +08:00
queryfast
d9df08b079 chore: fix some minor issues in comments (#5051)
Signed-off-by: queryfast <queryfast@outlook.com>
2025-08-05 04:13:44 +00:00
Kevin Wan
569c00ad09 chore: refactor redis stream (#5048)
Signed-off-by: Kevin Wan <wanjunfeng@gmail.com>
2025-08-03 14:58:09 +08:00
jk2K
9da76fbf04 feat: redis support consumer groups (#4912)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-08-02 13:36:08 +00:00
Kevin Wan
b69db5e09d chore: refactor etcd discov (#5046) 2025-08-01 23:12:49 +08:00
guonaihong
ee6b7cee79 fix: Issue with etcd key disappearing and unable to auto-re-register (#4960) 2025-08-01 22:54:07 +08:00
Kevin Wan
d150248c52 chore: refactor jsonx.Marshal (#5045)
Signed-off-by: Kevin Wan <wanjunfeng@gmail.com>
2025-08-01 22:05:13 +08:00
Kevin Wan
610a7345dc chore: refactor gateway middlewares (#5042) 2025-08-01 00:09:57 +08:00
Yi Deng
b0b31f3993 feat(gateway): add custom middleware support with onion model (#5035) 2025-07-31 13:33:36 +00:00
Kevin Wan
82a937d517 chore: refactor unit tests (#5041) 2025-07-31 20:10:45 +08:00
Joe Bird
93c11a7eb7 fix(httpx): Resolve HTML escaping issue during JSON serialization (#5032) 2025-07-31 11:53:40 +00:00
Kevin Wan
63ec989376 fix: large memory usage on detail logging post requests (#5039) 2025-07-31 19:09:32 +08:00
Kevin Wan
bf75027889 chore: add more tests (#5038) 2025-07-30 19:45:35 +08:00
Kevin Wan
d505fae979 fix: unmarshal problem on env vars for type env string (#5037)
Signed-off-by: kevin <wanjunfeng@gmail.com>
Signed-off-by: Kevin Wan <wanjunfeng@gmail.com>
2025-07-30 18:09:25 +08:00
Kevin Wan
25f37ca750 chore: add unit test for WithCodeResponseWriter (#5028)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-07-25 21:45:47 +08:00
csbzy
0be63c3625 Fix SSE timeout will affected by http.Server 's WriteTimeout (#5024)
Co-authored-by: csbzy <chenshaobo65@mail.com>
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-07-23 23:04:45 +08:00
dependabot[bot]
b011a072c7 chore(deps): bump github.com/spf13/pflag from 1.0.6 to 1.0.7 in /tools/goctl (#5021) 2025-07-23 09:53:42 +08:00
Kevin Wan
3c9b6335fb chore: refactor set in collection package (#5016) 2025-07-18 21:25:40 +08:00
Qiu shao
bf6ef5f033 feat: add generic TypedSet with 2x performance boost and compile-time (#4888)
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-07-18 12:57:43 +00:00
Kevin Wan
ff890628b0 feat: optimize ignore fields in orm (#5015) 2025-07-18 20:36:28 +08:00
wiki
cc79e3d842 feat(sqlx): add field tag (-) skip logic in unwrapFields (#5010)
Co-authored-by: wukun30 <wukun30@meituan.com>
2025-07-18 11:58:38 +00:00
Kevin Wan
f11b78ced9 feat: support masking sensitive data in logx (#5003)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-07-18 19:51:22 +08:00
dependabot[bot]
1d2b0d7ab8 chore(deps): bump github.com/grafana/pyroscope-go from 1.2.3 to 1.2.4 (#5014)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-07-18 10:41:26 +08:00
dependabot[bot]
da987e1270 chore(deps): bump github.com/grafana/pyroscope-go from 1.2.2 to 1.2.3 (#5004)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-07-15 11:51:22 +08:00
Kevin Wan
12e03c8843 chore: update goctl version (#5002) 2025-07-13 11:18:51 +08:00
Twilikiss
8cf4f95bd7 Fix POST JSON parameter determination logic (goctl api swagger) & Add some unit test. (#4997)
Co-authored-by: XiaobinChen <xiaobin.chen@corerain.com>
2025-07-13 03:07:35 +00:00
anlynn
ba0febf308 fix: Fix PostgreSQL numeric type mapping in goctl model generation (#4992)
Co-authored-by: 李安琳 <anlynn@gmail.com>
2025-07-13 02:40:16 +00:00
Kevin Wan
c9ff6a10d3 feat: support serverless in rest (#5001)
Signed-off-by: kevin <wanjunfeng@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-07-13 00:00:52 +08:00
Kevin Wan
a71e56de52 fix: context key error in sql read write mode (#5000) 2025-07-12 06:58:08 +08:00
Kevin Wan
bae8d4f4c8 chore: refactoring sql read write mode (#4990)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-07-11 01:05:55 +08:00
zhoushuguang
8c6266f338 sql read write support (#4976)
Co-authored-by: light.zhou <light.zhou@bkyo.io>
2025-07-09 16:04:56 +00:00
Kevin Wan
95d5b81f44 chore: optimize pr 4979 (#4988)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-07-09 23:55:24 +08:00
geekeryy
bca7bbc142 fix: correct duration type comparison in environment variable processing (#4979) 2025-07-09 15:22:27 +00:00
Kevin Wan
df9a52664b fix issue #4986 2025-07-08 13:58:48 +00:00
Kevin Wan
937cf0db96 Update readme-cn.md (#4983) 2025-07-04 11:02:49 +08:00
Kevin Wan
75cebb65f8 fix: timeout 0s not working (#4932)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-07-01 17:01:24 +08:00
dependabot[bot]
410f56e73a chore(deps): bump github.com/redis/go-redis/v9 from 9.10.0 to 9.11.0 (#4969)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-06-25 18:35:01 +08:00
dependabot[bot]
017909a3ab chore(deps): bump github.com/emicklei/proto from 1.14.1 to 1.14.2 in /tools/goctl (#4961)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-06-19 15:40:18 +08:00
kesonan
0d31e6c375 (goctl): fix #4943 (#4953) 2025-06-14 15:36:30 +00:00
Kevin Wan
0ba86b1849 chore: add more tests (#4949) 2025-06-13 22:10:08 +08:00
wanwu
4cacc4d9d3 fix: the time.Duration type panics due to numerical values (#4944)
Co-authored-by: sam.yang <sam.yang@yijinin.com>
2025-06-12 15:11:07 +00:00
Eric
a99c14da4a fix: typo of the logic of CpuThreshold in comments (#4942)
Co-authored-by: zhouyy <zhouyy@ickey.cn>
2025-06-12 08:28:44 +00:00
Kevin Wan
985582264a chore: fix warnings (#4940) 2025-06-12 00:04:29 +08:00
388 changed files with 21431 additions and 10611 deletions

View File

@@ -1,13 +0,0 @@
coverage:
status:
patch: true
project: false # disabled because project coverage is not stable
comment:
layout: "flags, files"
behavior: once
require_changes: true
ignore:
- "tools"
- "**/mock"
- "**/*_mock.go"
- "**/*test"

344
.github/copilot-instructions.md vendored Normal file
View File

@@ -0,0 +1,344 @@
# GitHub Copilot Instructions for go-zero
This document provides guidelines for GitHub Copilot when assisting with development in the go-zero project.
## Project Overview
go-zero is a web and RPC framework with lots of built-in engineering practices designed to ensure the stability of busy services with resilience design. It has been serving sites with tens of millions of users for years.
### Key Architecture Components
- **REST API framework** (`rest/`) - HTTP service framework with middleware chain support
- **RPC framework** (`zrpc/`) - gRPC-based RPC framework with etcd service discovery and p2c_ewma load balancing
- **Gateway** (`gateway/`) - API gateway supporting both HTTP and gRPC upstreams with proto-based routing
- **MCP Server** (`mcp/`) - Model Context Protocol server for AI agent integration via SSE
- **Core utilities** (`core/`) - Production-grade components:
- Resilience: circuit breakers (`breaker/`), rate limiters (`limit/`), adaptive load shedding (`load/`)
- Storage: SQL with cache (`stores/sqlc/`), Redis (`stores/redis/`), MongoDB (`stores/mongo/`)
- Concurrency: MapReduce (`mr/`), worker pools (`executors/`), sync primitives (`syncx/`)
- Observability: metrics (`metric/`), tracing (`trace/`), structured logging (`logx/`)
- **Code generation tool** (`tools/goctl/`) - CLI tool for generating Go code from `.api` and `.proto` files
## Coding Standards and Conventions
### Code Style
1. **Follow Go conventions**: Use `gofmt` for formatting, follow effective Go practices
2. **Package naming**: Use lowercase, single-word package names when possible
3. **Error handling**: Always handle errors explicitly, use `errorx.BatchError` for multiple errors
4. **Context propagation**: Always pass `context.Context` as the first parameter for functions that may block
5. **Configuration structures**: Use struct tags with JSON annotations, defaults, and validation
**Pattern**: All service configs embed `service.ServiceConf` for common fields (Name, Log, Mode, Telemetry)
```go
type Config struct {
service.ServiceConf // Always embed for services
Host string `json:",default=0.0.0.0"`
Port int // Required field (no default)
Timeout int64 `json:",default=3000"` // Timeouts in milliseconds
Optional string `json:",optional"` // Optional field
Mode string `json:",default=pro,options=dev|test|rt|pre|pro"` // Validated options
}
```
**Service modes**: `dev`/`test`/`rt` disable load shedding and stats; `pre`/`pro` enable all resilience features
### Interface Design
1. **Small interfaces**: Follow Go's preference for small, focused interfaces
2. **Context methods**: Provide both context and non-context versions of methods
3. **Options pattern**: Use functional options for complex configuration
Example:
```go
func (c *Client) Get(key string, val any) error {
return c.GetCtx(context.Background(), key, val)
}
func (c *Client) GetCtx(ctx context.Context, key string, val any) error {
// implementation
}
```
### Testing Patterns
1. **Test file naming**: Use `*_test.go` suffix
2. **Test function naming**: Use `TestFunctionName` pattern
3. **Use testify/assert**: Prefer `assert` package for assertions
4. **Table-driven tests**: Use table-driven tests for multiple scenarios
5. **Mock interfaces**: Use `go.uber.org/mock` for mocking
6. **Test helpers**: Use `redistest`, `mongtest` helpers for database testing
Example test pattern:
```go
func TestSomething(t *testing.T) {
tests := []struct {
name string
input string
expected string
wantErr bool
}{
{"valid case", "input", "output", false},
{"error case", "bad", "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := SomeFunction(tt.input)
if tt.wantErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}
```
## Framework-Specific Guidelines
### REST API Development
1. **API Definition**: Use `.api` files to define REST APIs with goctl codegen
2. **Handler pattern**: Separate business logic into logic packages (handlers call logic layer)
3. **Middleware chain**: Middlewares wrap via `chain.Chain` interface - use `Append()` or `Prepend()` to control order
- Built-in middlewares (all in `rest/handler/`): tracing, logging, metrics, recovery, breaker, shedding, timeout, maxconns, maxbytes, gunzip
- Custom middleware: `func(http.Handler) http.Handler` - call `next.ServeHTTP(w, r)` to continue chain
4. **Response handling**: Use `httpx.WriteJson(w, code, v)` for JSON responses
5. **Error handling**: Use `httpx.Error(w, err)` or `httpx.ErrorCtx(ctx, w, err)` for HTTP error responses
6. **Route registration**: Routes defined with `Method`, `Path`, and `Handler` - wildcards use `:param` syntax
### RPC Development
1. **Protocol Buffers**: Use protobuf for service definitions, generate code with goctl
2. **Service discovery**: Use etcd for dynamic service registration/discovery, or direct endpoints for static routing
3. **Load balancing**: Default is `p2c_ewma` (power of 2 choices with EWMA), configurable via `BalancerName`
4. **Client configuration**: Support `Etcd`, `Endpoints`, or `Target` - use `BuildTarget()` to construct connection string
5. **Interceptors**: Implement gRPC interceptors for cross-cutting concerns (auth, logging, metrics)
6. **Health checks**: gRPC health checks enabled by default (`Health: true`)
### Database Operations
1. **SQL operations**: Use `sqlx.SqlConn` interface - methods always end with `Ctx` for context support
2. **Caching pattern**: `stores/sqlc` provides `CachedConn` for automatic cache-aside pattern
- `QueryRowCtx`: Query with cache key, auto-populate on cache miss
- `ExecCtx`: Execute and delete cache keys
3. **Transactions**: Use `sqlx.SqlConn.TransactCtx()` to get transaction session
4. **Connection pooling**: Managed automatically (64 max idle/open, 1min lifetime)
5. **Test helpers**: Use `redistest.CreateRedis(t)` for Redis, SQL mocks for DB testing
Example cache pattern:
```go
err := c.QueryRowCtx(ctx, &dest, key, func(ctx context.Context, conn sqlx.SqlConn) error {
return conn.QueryRowCtx(ctx, &dest, query, args...)
})
```
### Configuration Management
1. **YAML configuration**: Use YAML for configuration files
2. **Environment variables**: Support environment variable overrides
3. **Validation**: Include proper validation for configuration parameters
4. **Sensible defaults**: Provide reasonable default values
## Error Handling Best Practices
1. **Wrap errors**: Use `fmt.Errorf` with `%w` verb to wrap errors
2. **Custom errors**: Define custom error types when needed
3. **Error logging**: Log errors appropriately with context
4. **Graceful degradation**: Implement fallback mechanisms
## Performance Considerations
1. **Resource pools**: Use connection pools and worker pools
2. **Circuit breakers**: Implement circuit breaker patterns for external calls
3. **Rate limiting**: Apply rate limiting to protect services
4. **Load shedding**: Implement adaptive load shedding
5. **Metrics**: Add appropriate metrics and monitoring
## Security Guidelines
1. **Input validation**: Validate all input parameters
2. **SQL injection prevention**: Use parameterized queries
3. **Authentication**: Implement proper JWT token handling
4. **HTTPS**: Support TLS/HTTPS configurations
5. **CORS**: Configure CORS appropriately for web APIs
## Documentation Standards
1. **Package documentation**: Include package-level documentation
2. **Function documentation**: Document exported functions with examples
3. **API documentation**: Maintain API documentation in sync
4. **README updates**: Update README for significant changes
## GitHub Issue Management
### Understanding and Categorizing Issues
When analyzing GitHub issues, consider these common categories:
1. **Bug Reports**: Stack traces, version info, reproduction steps
2. **Feature Requests**: Use case, proposed solution, alternatives
3. **Questions**: Usage, configuration, or architecture
4. **Documentation Issues**: Missing, unclear, or incorrect docs
5. **Performance Issues**: Benchmarks, profiling data, resource usage
### Issue Analysis Checklist
- Identify affected component (REST, RPC, Gateway, MCP, Core utilities, goctl)
- Check versions (go-zero, Go)
- Look for reproduction steps or code examples
- Review code snippets, logs, or stack traces
- Check if related to resilience features (breaker, load shedding, rate limiting)
- Determine production impact
### Responding to Issues
Be helpful and professional. Ask clarifying questions when needed. Reference relevant documentation and code files. Provide code examples following project conventions. Suggest workarounds when applicable.
### Chinese to English Translation
go-zero has an international user base. When encountering issues or comments written in Chinese, translate them to English to ensure all contributors can participate in discussions.
#### Translation Guidelines
1. **Update issue titles**: Edit the issue title to include English translation only
2. **Translate comments in place**: Add a comment with the English translation, followed by the original Chinese text
3. **Keep original Chinese**: After translating, include the original Chinese text in a blockquote for verification
4. **Encourage English communication**: Politely suggest users write in English for better collaboration
5. **Maintain technical accuracy**: Preserve technical terms, component names, and code exactly
6. **Translate naturally**: Avoid literal word-by-word translation; use idiomatic English
7. **Preserve formatting**: Keep markdown formatting, code blocks, and links intact
8. **Keep URLs unchanged**: Don't translate URLs or file paths
#### Common Technical Terms (Chinese → English)
- 框架 → **Framework** | 中间件 → **Middleware** | 负载均衡 → **Load Balancing**
- 熔断器 → **Circuit Breaker** | 限流 → **Rate Limiting** | 降载/过载保护 → **Load Shedding**
- 服务发现 → **Service Discovery** | 配置 → **Configuration** | 弹性/容错 → **Resilience** | 微服务 → **Microservices**
#### Translation Example
**Original Chinese Title:** `goctl 执行环境问题`
**Updated Title:** `goctl Execution Environment Issue`
**Original Chinese Comment:** `我在项目中遇到熔断器配置问题`
**Translation in Comment:**
```markdown
I encountered a circuit breaker configuration issue in my project.
> Original (原文): 我在项目中遇到熔断器配置问题
```
### Common Issue Patterns and Solutions
#### Configuration Issues
- Check `service.ServiceConf` embedding and struct tags
- Verify YAML syntax, defaults, and validation rules
- Reference: [rest/config.go](rest/config.go), [zrpc/config.go](zrpc/config.go)
#### Code Generation (goctl) Issues
- Verify `.api` or `.proto` file syntax and goctl version
- Reference: `tools/goctl/` directory
#### RPC Connection Issues
- Check etcd configuration, service discovery, and endpoints
- Verify load balancing settings (p2c_ewma)
#### Database/Cache Issues
- Verify `sqlx.SqlConn` usage with context
- Check cache key generation, invalidation, and connection pools
- Use test helpers (`redistest`, `mongtest`)
#### Performance Issues
- Check if load shedding is enabled (mode: `pre`/`pro`)
- Review circuit breaker thresholds, rate limiting, and context timeouts
### Referencing Codebase
When explaining issues, reference specific files and patterns:
- REST API: `rest/`, `rest/handler/`, `rest/httpx/`
- RPC: `zrpc/`, `zrpc/internal/`
- Core utilities: `core/breaker/`, `core/limit/`, `core/load/`, etc.
- Gateway: `gateway/`
- MCP: `mcp/`
- Code generation: `tools/goctl/`
- Examples: `adhoc/` directory contains various examples
### Encouraging Best Practices
When responding to issues, gently guide users toward:
- Proper error handling with context
- Using resilience features (breakers, rate limiters)
- Following testing patterns with table-driven tests
- Implementing proper resource cleanup
- Reading existing documentation in `docs/` and `readme.md`
## Common Patterns to Follow
### Service Configuration
```go
type ServiceConf struct {
Name string
Log logx.LogConf
Mode string `json:",default=pro,options=[dev,test,pre,pro]"`
// ... other common fields
}
```
### Middleware Implementation
```go
func SomeMiddleware() rest.Middleware {
return func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Pre-processing
next.ServeHTTP(w, r)
// Post-processing
}
}
}
```
### Resource Management
Always implement proper resource cleanup using defer and context cancellation.
## Build and Test Commands
- Build: `go build ./...`
- Test: `go test ./...`
- Test with race detection: `go test -race ./...`
- Format: `gofmt -w .`
- Code generation:
- REST API: `goctl api go -api *.api -dir .`
- RPC: `goctl rpc protoc *.proto --go_out=. --go-grpc_out=. --zrpc_out=.`
- Model from SQL: `goctl model mysql datasource -url="user:pass@tcp(host:port)/db" -table="*" -dir="./model"`
## Critical Architecture Patterns
### Resilience Design Philosophy
go-zero implements defense-in-depth with multiple protection layers:
1. **Circuit Breaker** (`core/breaker`): Google SRE breaker - tracks success/failure, opens on error threshold
2. **Adaptive Load Shedding** (`core/load`): CPU-based auto-rejection when system overloaded (disabled in dev/test/rt modes)
3. **Rate Limiting** (`core/limit`): Token bucket (Redis-based) and period limiters
4. **Timeout Control**: Cascading timeouts via context - set at multiple levels (client, server, handler)
### Middleware Chain Architecture
`rest/chain` provides middleware composition:
```go
// Middleware signature
type Middleware func(http.Handler) http.Handler
// Chain operations
chain := chain.New(m1, m2)
chain.Append(m3) // Adds to end: m1 -> m2 -> m3
chain.Prepend(m0) // Adds to start: m0 -> m1 -> m2 -> m3
handler := chain.Then(finalHandler)
```
### Concurrency Patterns
- **MapReduce** (`core/mr`): Parallel processing with worker pools - use for batch operations
- **Executors** (`core/executors`): Bulk/period executors for batching operations
- **SingleFlight** (`core/syncx`): Deduplicates concurrent identical requests
Remember to run tests and ensure all checks pass before submitting changes. The project emphasizes high quality, performance, and reliability, so these should be primary considerations in all development work.

View File

@@ -35,11 +35,11 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v6
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v3
uses: github/codeql-action/init@v4
with:
languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file.
@@ -50,7 +50,7 @@ jobs:
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
# If this step fails, then you should remove it and run the build manually (see below)
- name: Autobuild
uses: github/codeql-action/autobuild@v3
uses: github/codeql-action/autobuild@v4
# Command-line programs to run using the OS shell.
# 📚 https://git.io/JvXDl
@@ -64,4 +64,4 @@ jobs:
# make release
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v3
uses: github/codeql-action/analyze@v4

View File

@@ -12,10 +12,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Check out code into the Go module directory
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Set up Go 1.x
uses: actions/setup-go@v5
uses: actions/setup-go@v6
with:
go-version-file: go.mod
check-latest: true
@@ -40,17 +40,22 @@ jobs:
run: go test -race -coverprofile=coverage.txt -covermode=atomic ./...
- name: Codecov
uses: codecov/codecov-action@v5
uses: codecov/codecov-action@v6
with:
files: ./coverage.txt
flags: unittests
name: codecov-umbrella
fail_ci_if_error: false
test-win:
name: Windows
runs-on: windows-latest
steps:
- name: Checkout codebase
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Set up Go 1.x
uses: actions/setup-go@v5
uses: actions/setup-go@v6
with:
# make sure Go version compatible with go-zero
go-version-file: go.mod

View File

@@ -7,7 +7,7 @@ jobs:
close-issues:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@v9
- uses: actions/stale@v10
with:
days-before-issue-stale: 365
days-before-issue-close: 90

View File

@@ -16,7 +16,7 @@ jobs:
- goarch: "386"
goos: darwin
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
- uses: zeromicro/go-zero-release-action@master
with:
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -5,7 +5,12 @@ jobs:
name: runner / staticcheck
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
- uses: actions/setup-go@v6
with:
go-version-file: go.mod
check-latest: true
cache: true
- uses: reviewdog/action-staticcheck@v1
with:
github_token: ${{ secrets.github_token }}

View File

@@ -10,10 +10,10 @@ jobs:
version-check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
- name: Set up Go
uses: actions/setup-go@v5
uses: actions/setup-go@v6
with:
go-version: '1.21'

1
.gitignore vendored
View File

@@ -17,6 +17,7 @@
**/logs
**/adhoc
**/coverage.txt
**/WARP.md
# for test purpose
go.work

View File

@@ -40,7 +40,7 @@ type (
}
)
// New create a Filter, store is the backed redis, key is the key for the bloom filter,
// New creates a Filter, store is the backed redis, key is the key for the bloom filter,
// bits is how many bits will be used, maps is how many hashes for each addition.
// best practices:
// elements - means how many actual elements

View File

@@ -6,8 +6,6 @@ import (
"crypto/cipher"
"encoding/base64"
"errors"
"github.com/zeromicro/go-zero/core/logx"
)
// ErrPaddingSize indicates bad padding size.
@@ -27,7 +25,8 @@ func newECB(b cipher.Block) *ecb {
type ecbEncrypter ecb
// NewECBEncrypter returns an ECB encrypter.
// Deprecated: NewECBEncrypter returns an ECB encrypter.
// ECB mode is insecure for multi-block data. Use AES-GCM instead.
func NewECBEncrypter(b cipher.Block) cipher.BlockMode {
return (*ecbEncrypter)(newECB(b))
}
@@ -39,12 +38,10 @@ func (x *ecbEncrypter) BlockSize() int { return x.blockSize }
// the block size. Dst and src must overlap entirely or not at all.
func (x *ecbEncrypter) CryptBlocks(dst, src []byte) {
if len(src)%x.blockSize != 0 {
logx.Error("crypto/cipher: input not full blocks")
return
panic("crypto/cipher: input not full blocks")
}
if len(dst) < len(src) {
logx.Error("crypto/cipher: output smaller than input")
return
panic("crypto/cipher: output smaller than input")
}
for len(src) > 0 {
@@ -56,7 +53,8 @@ func (x *ecbEncrypter) CryptBlocks(dst, src []byte) {
type ecbDecrypter ecb
// NewECBDecrypter returns an ECB decrypter.
// Deprecated: NewECBDecrypter returns an ECB decrypter.
// ECB mode is insecure for multi-block data. Use AES-GCM instead.
func NewECBDecrypter(b cipher.Block) cipher.BlockMode {
return (*ecbDecrypter)(newECB(b))
}
@@ -70,12 +68,10 @@ func (x *ecbDecrypter) BlockSize() int {
// the block size. Dst and src must overlap entirely or not at all.
func (x *ecbDecrypter) CryptBlocks(dst, src []byte) {
if len(src)%x.blockSize != 0 {
logx.Error("crypto/cipher: input not full blocks")
return
panic("crypto/cipher: input not full blocks")
}
if len(dst) < len(src) {
logx.Error("crypto/cipher: output smaller than input")
return
panic("crypto/cipher: output smaller than input")
}
for len(src) > 0 {
@@ -85,14 +81,18 @@ func (x *ecbDecrypter) CryptBlocks(dst, src []byte) {
}
}
// EcbDecrypt decrypts src with the given key.
// Deprecated: EcbDecrypt decrypts src with the given key.
// ECB mode is insecure for multi-block data. Use AES-GCM instead.
func EcbDecrypt(key, src []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
logx.Errorf("Decrypt key error: % x", key)
return nil, err
}
if len(src)%block.BlockSize() != 0 {
return nil, ErrPaddingSize
}
decrypter := NewECBDecrypter(block)
decrypted := make([]byte, len(src))
decrypter.CryptBlocks(decrypted, src)
@@ -100,8 +100,9 @@ func EcbDecrypt(key, src []byte) ([]byte, error) {
return pkcs5Unpadding(decrypted, decrypter.BlockSize())
}
// EcbDecryptBase64 decrypts base64 encoded src with the given base64 encoded key.
// Deprecated: EcbDecryptBase64 decrypts base64 encoded src with the given base64 encoded key.
// The returned string is also base64 encoded.
// ECB mode is insecure for multi-block data. Use AES-GCM instead.
func EcbDecryptBase64(key, src string) (string, error) {
keyBytes, err := getKeyBytes(key)
if err != nil {
@@ -121,11 +122,11 @@ func EcbDecryptBase64(key, src string) (string, error) {
return base64.StdEncoding.EncodeToString(decryptedBytes), nil
}
// EcbEncrypt encrypts src with the given key.
// Deprecated: EcbEncrypt encrypts src with the given key.
// ECB mode is insecure for multi-block data. Use AES-GCM instead.
func EcbEncrypt(key, src []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
logx.Errorf("Encrypt key error: % x", key)
return nil, err
}
@@ -137,8 +138,9 @@ func EcbEncrypt(key, src []byte) ([]byte, error) {
return crypted, nil
}
// EcbEncryptBase64 encrypts base64 encoded src with the given base64 encoded key.
// Deprecated: EcbEncryptBase64 encrypts base64 encoded src with the given base64 encoded key.
// The returned string is also base64 encoded.
// ECB mode is insecure for multi-block data. Use AES-GCM instead.
func EcbEncryptBase64(key, src string) (string, error) {
keyBytes, err := getKeyBytes(key)
if err != nil {
@@ -179,10 +181,20 @@ func pkcs5Padding(ciphertext []byte, blockSize int) []byte {
func pkcs5Unpadding(src []byte, blockSize int) ([]byte, error) {
length := len(src)
unpadding := int(src[length-1])
if unpadding >= length || unpadding > blockSize {
if length == 0 {
return nil, ErrPaddingSize
}
unpadding := int(src[length-1])
if unpadding < 1 || unpadding > blockSize || unpadding > length {
return nil, ErrPaddingSize
}
for _, b := range src[length-unpadding:] {
if int(b) != unpadding {
return nil, ErrPaddingSize
}
}
return src[:length-unpadding], nil
}

View File

@@ -28,8 +28,8 @@ func TestAesEcb(t *testing.T) {
_, err = EcbDecrypt(badKey2, dst)
assert.NotNil(t, err)
_, err = EcbDecrypt(key, val)
// not enough block, just nil
assert.Nil(t, err)
// not a multiple of block size
assert.NotNil(t, err)
src, err := EcbDecrypt(key, dst)
assert.Nil(t, err)
assert.Equal(t, val, src)
@@ -41,33 +41,28 @@ func TestAesEcb(t *testing.T) {
assert.Equal(t, 16, decrypter.BlockSize())
dst = make([]byte, 8)
encrypter.CryptBlocks(dst, val)
for _, b := range dst {
assert.Equal(t, byte(0), b)
}
assert.Panics(t, func() {
encrypter.CryptBlocks(dst, val)
})
dst = make([]byte, 8)
encrypter.CryptBlocks(dst, valLong)
for _, b := range dst {
assert.Equal(t, byte(0), b)
}
assert.Panics(t, func() {
encrypter.CryptBlocks(dst, valLong)
})
dst = make([]byte, 8)
decrypter.CryptBlocks(dst, val)
for _, b := range dst {
assert.Equal(t, byte(0), b)
}
assert.Panics(t, func() {
decrypter.CryptBlocks(dst, val)
})
dst = make([]byte, 8)
decrypter.CryptBlocks(dst, valLong)
for _, b := range dst {
assert.Equal(t, byte(0), b)
}
assert.Panics(t, func() {
decrypter.CryptBlocks(dst, valLong)
})
_, err = EcbEncryptBase64("cTR0N3dDKkYtSmFOZFJnVWpYbjJyNXU4eC9BP0QK", "aGVsbG93b3JsZGxvbmcuLgo=")
assert.Error(t, err)
}
func TestAesEcbBase64(t *testing.T) {
const (
val = "hello"
@@ -98,3 +93,44 @@ func TestAesEcbBase64(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, val, string(b))
}
func TestPkcs5UnpaddingEmptyInput(t *testing.T) {
_, err := pkcs5Unpadding([]byte{}, 16)
assert.Equal(t, ErrPaddingSize, err)
}
func TestPkcs5UnpaddingMalformedPadding(t *testing.T) {
// Valid PKCS5 padding of 3: last 3 bytes should all be 0x03
// Here we corrupt one padding byte
malformed := []byte{0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
0x41, 0x41, 0x41, 0x41, 0x41, 0x02, 0x03, 0x03}
_, err := pkcs5Unpadding(malformed, 16)
assert.Equal(t, ErrPaddingSize, err)
// All padding bytes correct
valid := []byte{0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
0x41, 0x41, 0x41, 0x41, 0x41, 0x03, 0x03, 0x03}
result, err := pkcs5Unpadding(valid, 16)
assert.NoError(t, err)
assert.Equal(t, valid[:13], result)
}
func TestPkcs5UnpaddingInvalidPaddingValue(t *testing.T) {
// padding value = 0 (< 1)
_, err := pkcs5Unpadding([]byte{0x41, 0x00}, 16)
assert.Equal(t, ErrPaddingSize, err)
// padding value > blockSize
_, err = pkcs5Unpadding([]byte{0x41, 0x41, 0x41, 0x41, 17}, 4)
assert.Equal(t, ErrPaddingSize, err)
// padding value > length
_, err = pkcs5Unpadding([]byte{0x41, 0x03}, 16)
assert.Equal(t, ErrPaddingSize, err)
}
func TestEcbDecryptEmptyInput(t *testing.T) {
key := []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D")
_, err := EcbDecrypt(key, []byte{})
assert.Equal(t, ErrPaddingSize, err)
}

View File

@@ -35,7 +35,7 @@ func ComputeKey(pubKey, priKey *big.Int) (*big.Int, error) {
return nil, ErrInvalidPubKey
}
if pubKey.Sign() <= 0 && p.Cmp(pubKey) <= 0 {
if pubKey.Sign() <= 0 || p.Cmp(pubKey) <= 0 {
return nil, ErrPubKeyOutOfBound
}

View File

@@ -94,3 +94,32 @@ func TestDHOnErrors(t *testing.T) {
assert.NotNil(t, NewPublicKey([]byte("")))
}
func TestDHPubKeyBoundary(t *testing.T) {
key, err := GenerateKey()
assert.Nil(t, err)
// pubKey = 0 should be rejected
_, err = ComputeKey(big.NewInt(0), key.PriKey)
assert.ErrorIs(t, err, ErrPubKeyOutOfBound)
// pubKey = -1 should be rejected
_, err = ComputeKey(big.NewInt(-1), key.PriKey)
assert.ErrorIs(t, err, ErrPubKeyOutOfBound)
// pubKey = p should be rejected
_, err = ComputeKey(new(big.Int).Set(p), key.PriKey)
assert.ErrorIs(t, err, ErrPubKeyOutOfBound)
// pubKey = p+1 should be rejected
_, err = ComputeKey(new(big.Int).Add(p, big.NewInt(1)), key.PriKey)
assert.ErrorIs(t, err, ErrPubKeyOutOfBound)
// pubKey = 1 should be accepted
_, err = ComputeKey(big.NewInt(1), key.PriKey)
assert.NoError(t, err)
// pubKey = p-1 should be accepted
_, err = ComputeKey(new(big.Int).Sub(p, big.NewInt(1)), key.PriKey)
assert.NoError(t, err)
}

View File

@@ -3,6 +3,7 @@ package codec
import (
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/pem"
@@ -46,7 +47,9 @@ type (
}
)
// NewRsaDecrypter returns a RsaDecrypter with the given file.
// Deprecated: NewRsaDecrypter returns a RsaDecrypter with the given file.
// PKCS#1 v1.5 padding is vulnerable to padding oracle attacks.
// Use NewRsaOAEPDecrypter instead.
func NewRsaDecrypter(file string) (RsaDecrypter, error) {
content, err := os.ReadFile(file)
if err != nil {
@@ -90,7 +93,9 @@ func (r *rsaDecrypter) DecryptBase64(input string) ([]byte, error) {
return r.Decrypt(base64Decoded)
}
// NewRsaEncrypter returns a RsaEncrypter with the given key.
// Deprecated: NewRsaEncrypter returns a RsaEncrypter with the given key.
// PKCS#1 v1.5 padding is vulnerable to padding oracle attacks.
// Use NewRsaOAEPEncrypter instead.
func NewRsaEncrypter(key []byte) (RsaEncrypter, error) {
block, _ := pem.Decode(key)
if block == nil {
@@ -154,3 +159,90 @@ func rsaDecryptBlock(privateKey *rsa.PrivateKey, block []byte) ([]byte, error) {
func rsaEncryptBlock(publicKey *rsa.PublicKey, msg []byte) ([]byte, error) {
return rsa.EncryptPKCS1v15(rand.Reader, publicKey, msg)
}
// NewRsaOAEPDecrypter returns a RsaDecrypter using OAEP with SHA-256.
func NewRsaOAEPDecrypter(file string) (RsaDecrypter, error) {
content, err := os.ReadFile(file)
if err != nil {
return nil, err
}
block, _ := pem.Decode(content)
if block == nil {
return nil, ErrPrivateKey
}
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, err
}
return &rsaOAEPDecrypter{
rsaBase: rsaBase{
bytesLimit: privateKey.N.BitLen() >> 3,
},
privateKey: privateKey,
}, nil
}
// NewRsaOAEPEncrypter returns a RsaEncrypter using OAEP with SHA-256.
func NewRsaOAEPEncrypter(key []byte) (RsaEncrypter, error) {
block, _ := pem.Decode(key)
if block == nil {
return nil, ErrPublicKey
}
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, err
}
switch pubKey := pub.(type) {
case *rsa.PublicKey:
// OAEP overhead: 2*hash_size + 2
hashSize := sha256.New().Size()
return &rsaOAEPEncrypter{
rsaBase: rsaBase{
bytesLimit: (pubKey.N.BitLen() >> 3) - 2*hashSize - 2,
},
publicKey: pubKey,
}, nil
default:
return nil, ErrNotRsaKey
}
}
type rsaOAEPDecrypter struct {
rsaBase
privateKey *rsa.PrivateKey
}
func (r *rsaOAEPDecrypter) Decrypt(input []byte) ([]byte, error) {
return r.crypt(input, func(block []byte) ([]byte, error) {
return rsa.DecryptOAEP(sha256.New(), rand.Reader, r.privateKey, block, nil)
})
}
func (r *rsaOAEPDecrypter) DecryptBase64(input string) ([]byte, error) {
if len(input) == 0 {
return nil, nil
}
base64Decoded, err := base64.StdEncoding.DecodeString(input)
if err != nil {
return nil, err
}
return r.Decrypt(base64Decoded)
}
type rsaOAEPEncrypter struct {
rsaBase
publicKey *rsa.PublicKey
}
func (r *rsaOAEPEncrypter) Encrypt(input []byte) ([]byte, error) {
return r.crypt(input, func(block []byte) ([]byte, error) {
return rsa.EncryptOAEP(sha256.New(), rand.Reader, r.publicKey, block, nil)
})
}

View File

@@ -1,7 +1,12 @@
package codec
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"os"
"testing"
@@ -58,3 +63,78 @@ func TestBadPubKey(t *testing.T) {
_, err := NewRsaEncrypter([]byte("foo"))
assert.Equal(t, ErrPublicKey, err)
}
func TestOAEPCryption(t *testing.T) {
enc, err := NewRsaOAEPEncrypter([]byte(pubKey))
assert.Nil(t, err)
ret, err := enc.Encrypt([]byte(testBody))
assert.Nil(t, err)
file, err := fs.TempFilenameWithText(priKey)
assert.Nil(t, err)
defer os.Remove(file)
dec, err := NewRsaOAEPDecrypter(file)
assert.Nil(t, err)
actual, err := dec.Decrypt(ret)
assert.Nil(t, err)
assert.Equal(t, testBody, string(actual))
actual, err = dec.DecryptBase64(base64.StdEncoding.EncodeToString(ret))
assert.Nil(t, err)
assert.Equal(t, testBody, string(actual))
// empty input
actual, err = dec.DecryptBase64("")
assert.Nil(t, err)
assert.Nil(t, actual)
}
func TestOAEPBadKeys(t *testing.T) {
_, err := NewRsaOAEPEncrypter([]byte("bad"))
assert.Equal(t, ErrPublicKey, err)
_, err = NewRsaOAEPDecrypter("nonexistent")
assert.Error(t, err)
// valid PEM but invalid private key content
badPem, err := fs.TempFilenameWithText("-----BEGIN RSA PRIVATE KEY-----\nYmFk\n-----END RSA PRIVATE KEY-----")
assert.Nil(t, err)
defer os.Remove(badPem)
_, err = NewRsaOAEPDecrypter(badPem)
assert.Error(t, err)
// not PEM content at all
notPem, err := fs.TempFilenameWithText("not a pem file")
assert.Nil(t, err)
defer os.Remove(notPem)
_, err = NewRsaOAEPDecrypter(notPem)
assert.Equal(t, ErrPrivateKey, err)
}
func TestOAEPEncrypterParseError(t *testing.T) {
// valid PEM block but invalid public key content
badPub := []byte("-----BEGIN PUBLIC KEY-----\nYmFk\n-----END PUBLIC KEY-----")
_, err := NewRsaOAEPEncrypter(badPub)
assert.Error(t, err)
}
func TestOAEPEncrypterNonRsaKey(t *testing.T) {
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
assert.Nil(t, err)
derBytes, err := x509.MarshalPKIXPublicKey(&ecKey.PublicKey)
assert.Nil(t, err)
ecPem := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: derBytes})
_, err = NewRsaOAEPEncrypter(ecPem)
assert.Equal(t, ErrNotRsaKey, err)
}
func TestOAEPDecryptBase64Error(t *testing.T) {
file, err := fs.TempFilenameWithText(priKey)
assert.Nil(t, err)
defer os.Remove(file)
dec, err := NewRsaOAEPDecrypter(file)
assert.Nil(t, err)
_, err = dec.DecryptBase64("not-valid-base64!!!")
assert.Error(t, err)
}

View File

@@ -81,6 +81,10 @@ func (c *Cache) Del(key string) {
delete(c.data, key)
c.lruCache.remove(key)
c.lock.Unlock()
// RemoveTimer is called outside the lock to avoid performance impact from this
// potentially time-consuming operation. Data integrity is maintained by lruCache,
// which will eventually evict any remaining entries when capacity is exceeded.
c.timingWheel.RemoveTimer(key)
}

View File

@@ -1,235 +1,53 @@
package collection
import (
"github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/logx"
)
import "github.com/zeromicro/go-zero/core/lang"
const (
unmanaged = iota
untyped
intType
int64Type
uintType
uint64Type
stringType
)
// Set is not thread-safe, for concurrent use, make sure to use it with synchronization.
type Set struct {
data map[any]lang.PlaceholderType
tp int
// Set is a type-safe generic set collection.
// It's not thread-safe, use with synchronization for concurrent access.
type Set[T comparable] struct {
data map[T]lang.PlaceholderType
}
// NewSet returns a managed Set, can only put the values with the same type.
func NewSet() *Set {
return &Set{
data: make(map[any]lang.PlaceholderType),
tp: untyped,
// NewSet returns a new type-safe set.
func NewSet[T comparable]() *Set[T] {
return &Set[T]{
data: make(map[T]lang.PlaceholderType),
}
}
// NewUnmanagedSet returns an unmanaged Set, which can put values with different types.
func NewUnmanagedSet() *Set {
return &Set{
data: make(map[any]lang.PlaceholderType),
tp: unmanaged,
// Add adds items to the set. Duplicates are automatically ignored.
func (s *Set[T]) Add(items ...T) {
for _, item := range items {
s.data[item] = lang.Placeholder
}
}
// Add adds i into s.
func (s *Set) Add(i ...any) {
for _, each := range i {
s.add(each)
}
// Clear removes all items from the set.
func (s *Set[T]) Clear() {
clear(s.data)
}
// AddInt adds int values ii into s.
func (s *Set) AddInt(ii ...int) {
for _, each := range ii {
s.add(each)
}
}
// AddInt64 adds int64 values ii into s.
func (s *Set) AddInt64(ii ...int64) {
for _, each := range ii {
s.add(each)
}
}
// AddUint adds uint values ii into s.
func (s *Set) AddUint(ii ...uint) {
for _, each := range ii {
s.add(each)
}
}
// AddUint64 adds uint64 values ii into s.
func (s *Set) AddUint64(ii ...uint64) {
for _, each := range ii {
s.add(each)
}
}
// AddStr adds string values ss into s.
func (s *Set) AddStr(ss ...string) {
for _, each := range ss {
s.add(each)
}
}
// Contains checks if i is in s.
func (s *Set) Contains(i any) bool {
if len(s.data) == 0 {
return false
}
s.validate(i)
_, ok := s.data[i]
// Contains checks if an item exists in the set.
func (s *Set[T]) Contains(item T) bool {
_, ok := s.data[item]
return ok
}
// Keys returns the keys in s.
func (s *Set) Keys() []any {
var keys []any
for key := range s.data {
keys = append(keys, key)
}
return keys
}
// KeysInt returns the int keys in s.
func (s *Set) KeysInt() []int {
var keys []int
for key := range s.data {
if intKey, ok := key.(int); ok {
keys = append(keys, intKey)
}
}
return keys
}
// KeysInt64 returns int64 keys in s.
func (s *Set) KeysInt64() []int64 {
var keys []int64
for key := range s.data {
if intKey, ok := key.(int64); ok {
keys = append(keys, intKey)
}
}
return keys
}
// KeysUint returns uint keys in s.
func (s *Set) KeysUint() []uint {
var keys []uint
for key := range s.data {
if intKey, ok := key.(uint); ok {
keys = append(keys, intKey)
}
}
return keys
}
// KeysUint64 returns uint64 keys in s.
func (s *Set) KeysUint64() []uint64 {
var keys []uint64
for key := range s.data {
if intKey, ok := key.(uint64); ok {
keys = append(keys, intKey)
}
}
return keys
}
// KeysStr returns string keys in s.
func (s *Set) KeysStr() []string {
var keys []string
for key := range s.data {
if strKey, ok := key.(string); ok {
keys = append(keys, strKey)
}
}
return keys
}
// Remove removes i from s.
func (s *Set) Remove(i any) {
s.validate(i)
delete(s.data, i)
}
// Count returns the number of items in s.
func (s *Set) Count() int {
// Count returns the number of items in the set.
func (s *Set[T]) Count() int {
return len(s.data)
}
func (s *Set) add(i any) {
switch s.tp {
case unmanaged:
// do nothing
case untyped:
s.setType(i)
default:
s.validate(i)
// Keys returns all elements in the set as a slice.
func (s *Set[T]) Keys() []T {
keys := make([]T, 0, len(s.data))
for key := range s.data {
keys = append(keys, key)
}
s.data[i] = lang.Placeholder
return keys
}
func (s *Set) setType(i any) {
// s.tp can only be untyped here
switch i.(type) {
case int:
s.tp = intType
case int64:
s.tp = int64Type
case uint:
s.tp = uintType
case uint64:
s.tp = uint64Type
case string:
s.tp = stringType
}
}
func (s *Set) validate(i any) {
if s.tp == unmanaged {
return
}
switch i.(type) {
case int:
if s.tp != intType {
logx.Errorf("element is int, but set contains elements with type %d", s.tp)
}
case int64:
if s.tp != int64Type {
logx.Errorf("element is int64, but set contains elements with type %d", s.tp)
}
case uint:
if s.tp != uintType {
logx.Errorf("element is uint, but set contains elements with type %d", s.tp)
}
case uint64:
if s.tp != uint64Type {
logx.Errorf("element is uint64, but set contains elements with type %d", s.tp)
}
case string:
if s.tp != stringType {
logx.Errorf("element is string, but set contains elements with type %d", s.tp)
}
}
// Remove removes an item from the set.
func (s *Set[T]) Remove(item T) {
delete(s.data, item)
}

View File

@@ -12,6 +12,105 @@ func init() {
logx.Disable()
}
// Set functionality tests
func TestTypedSetInt(t *testing.T) {
set := NewSet[int]()
values := []int{1, 2, 3, 2, 1} // Contains duplicates
// Test adding
set.Add(values...)
assert.Equal(t, 3, set.Count()) // Should only have 3 elements after deduplication
// Test contains
assert.True(t, set.Contains(1))
assert.True(t, set.Contains(2))
assert.True(t, set.Contains(3))
assert.False(t, set.Contains(4))
// Test getting all keys
keys := set.Keys()
sort.Ints(keys)
assert.EqualValues(t, []int{1, 2, 3}, keys)
// Test removal
set.Remove(2)
assert.False(t, set.Contains(2))
assert.Equal(t, 2, set.Count())
}
func TestTypedSetStringOps(t *testing.T) {
set := NewSet[string]()
values := []string{"a", "b", "c", "b", "a"}
set.Add(values...)
assert.Equal(t, 3, set.Count())
assert.True(t, set.Contains("a"))
assert.True(t, set.Contains("b"))
assert.True(t, set.Contains("c"))
assert.False(t, set.Contains("d"))
keys := set.Keys()
sort.Strings(keys)
assert.EqualValues(t, []string{"a", "b", "c"}, keys)
}
func TestTypedSetClear(t *testing.T) {
set := NewSet[int]()
set.Add(1, 2, 3)
assert.Equal(t, 3, set.Count())
set.Clear()
assert.Equal(t, 0, set.Count())
assert.False(t, set.Contains(1))
}
func TestTypedSetEmpty(t *testing.T) {
set := NewSet[int]()
assert.Equal(t, 0, set.Count())
assert.False(t, set.Contains(1))
assert.Empty(t, set.Keys())
}
func TestTypedSetMultipleTypes(t *testing.T) {
// Test different typed generic sets
intSet := NewSet[int]()
int64Set := NewSet[int64]()
uintSet := NewSet[uint]()
uint64Set := NewSet[uint64]()
stringSet := NewSet[string]()
intSet.Add(1, 2, 3)
int64Set.Add(1, 2, 3)
uintSet.Add(1, 2, 3)
uint64Set.Add(1, 2, 3)
stringSet.Add("1", "2", "3")
assert.Equal(t, 3, intSet.Count())
assert.Equal(t, 3, int64Set.Count())
assert.Equal(t, 3, uintSet.Count())
assert.Equal(t, 3, uint64Set.Count())
assert.Equal(t, 3, stringSet.Count())
}
// Set benchmarks
func BenchmarkTypedIntSet(b *testing.B) {
s := NewSet[int]()
for i := 0; i < b.N; i++ {
s.Add(i)
_ = s.Contains(i)
}
}
func BenchmarkTypedStringSet(b *testing.B) {
s := NewSet[string]()
for i := 0; i < b.N; i++ {
s.Add(string(rune(i)))
_ = s.Contains(string(rune(i)))
}
}
// Legacy tests remain unchanged for backward compatibility
func BenchmarkRawSet(b *testing.B) {
m := make(map[any]struct{})
for i := 0; i < b.N; i++ {
@@ -20,26 +119,10 @@ func BenchmarkRawSet(b *testing.B) {
}
}
func BenchmarkUnmanagedSet(b *testing.B) {
s := NewUnmanagedSet()
for i := 0; i < b.N; i++ {
s.Add(i)
_ = s.Contains(i)
}
}
func BenchmarkSet(b *testing.B) {
s := NewSet()
for i := 0; i < b.N; i++ {
s.AddInt(i)
_ = s.Contains(i)
}
}
func TestAdd(t *testing.T) {
// given
set := NewUnmanagedSet()
values := []any{1, 2, 3}
set := NewSet[int]()
values := []int{1, 2, 3}
// when
set.Add(values...)
@@ -51,82 +134,74 @@ func TestAdd(t *testing.T) {
func TestAddInt(t *testing.T) {
// given
set := NewSet()
set := NewSet[int]()
values := []int{1, 2, 3}
// when
set.AddInt(values...)
set.Add(values...)
// then
assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3))
keys := set.KeysInt()
keys := set.Keys()
sort.Ints(keys)
assert.EqualValues(t, values, keys)
}
func TestAddInt64(t *testing.T) {
// given
set := NewSet()
set := NewSet[int64]()
values := []int64{1, 2, 3}
// when
set.AddInt64(values...)
set.Add(values...)
// then
assert.True(t, set.Contains(int64(1)) && set.Contains(int64(2)) && set.Contains(int64(3)))
assert.Equal(t, len(values), len(set.KeysInt64()))
assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3))
assert.Equal(t, len(values), len(set.Keys()))
}
func TestAddUint(t *testing.T) {
// given
set := NewSet()
set := NewSet[uint]()
values := []uint{1, 2, 3}
// when
set.AddUint(values...)
set.Add(values...)
// then
assert.True(t, set.Contains(uint(1)) && set.Contains(uint(2)) && set.Contains(uint(3)))
assert.Equal(t, len(values), len(set.KeysUint()))
assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3))
assert.Equal(t, len(values), len(set.Keys()))
}
func TestAddUint64(t *testing.T) {
// given
set := NewSet()
set := NewSet[uint64]()
values := []uint64{1, 2, 3}
// when
set.AddUint64(values...)
set.Add(values...)
// then
assert.True(t, set.Contains(uint64(1)) && set.Contains(uint64(2)) && set.Contains(uint64(3)))
assert.Equal(t, len(values), len(set.KeysUint64()))
assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3))
assert.Equal(t, len(values), len(set.Keys()))
}
func TestAddStr(t *testing.T) {
// given
set := NewSet()
set := NewSet[string]()
values := []string{"1", "2", "3"}
// when
set.AddStr(values...)
set.Add(values...)
// then
assert.True(t, set.Contains("1") && set.Contains("2") && set.Contains("3"))
assert.Equal(t, len(values), len(set.KeysStr()))
assert.Equal(t, len(values), len(set.Keys()))
}
func TestContainsWithoutElements(t *testing.T) {
// given
set := NewSet()
// then
assert.False(t, set.Contains(1))
}
func TestContainsUnmanagedWithoutElements(t *testing.T) {
// given
set := NewUnmanagedSet()
set := NewSet[int]()
// then
assert.False(t, set.Contains(1))
@@ -134,8 +209,8 @@ func TestContainsUnmanagedWithoutElements(t *testing.T) {
func TestRemove(t *testing.T) {
// given
set := NewSet()
set.Add([]any{1, 2, 3}...)
set := NewSet[int]()
set.Add([]int{1, 2, 3}...)
// when
set.Remove(2)
@@ -146,57 +221,9 @@ func TestRemove(t *testing.T) {
func TestCount(t *testing.T) {
// given
set := NewSet()
set.Add([]any{1, 2, 3}...)
set := NewSet[int]()
set.Add([]int{1, 2, 3}...)
// then
assert.Equal(t, set.Count(), 3)
}
func TestKeysIntMismatch(t *testing.T) {
set := NewSet()
set.add(int64(1))
set.add(2)
vals := set.KeysInt()
assert.EqualValues(t, []int{2}, vals)
}
func TestKeysInt64Mismatch(t *testing.T) {
set := NewSet()
set.add(1)
set.add(int64(2))
vals := set.KeysInt64()
assert.EqualValues(t, []int64{2}, vals)
}
func TestKeysUintMismatch(t *testing.T) {
set := NewSet()
set.add(1)
set.add(uint(2))
vals := set.KeysUint()
assert.EqualValues(t, []uint{2}, vals)
}
func TestKeysUint64Mismatch(t *testing.T) {
set := NewSet()
set.add(1)
set.add(uint64(2))
vals := set.KeysUint64()
assert.EqualValues(t, []uint64{2}, vals)
}
func TestKeysStrMismatch(t *testing.T) {
set := NewSet()
set.add(1)
set.add("2")
vals := set.KeysStr()
assert.EqualValues(t, []string{"2"}, vals)
}
func TestSetType(t *testing.T) {
set := NewUnmanagedSet()
set.add(1)
set.add("2")
vals := set.Keys()
assert.ElementsMatch(t, []any{1, "2"}, vals)
}

View File

@@ -164,6 +164,7 @@ func (tw *TimingWheel) Stop() {
func (tw *TimingWheel) drainAll(fn func(key, value any)) {
runner := threading.NewTaskRunner(drainWorkers)
for _, slot := range tw.slots {
for e := slot.Front(); e != nil; {
task := e.Value.(*timingEntry)
@@ -177,6 +178,8 @@ func (tw *TimingWheel) drainAll(fn func(key, value any)) {
}
}
}
runner.Wait()
}
func (tw *TimingWheel) getPositionAndCircle(d time.Duration) (pos, circle int) {

View File

@@ -629,6 +629,157 @@ func TestMoveAndRemoveTask(t *testing.T) {
assert.Equal(t, 0, len(keys))
}
// TestTimingWheel_DrainClosureBug tests the closure capture bug in drainAll
// Issue: https://github.com/zeromicro/go-zero/issues/5314
func TestTimingWheel_DrainClosureBug(t *testing.T) {
ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {}, ticker)
defer tw.Stop()
// Set multiple timers with different values
for i := 0; i < 10; i++ {
tw.SetTimer(i, i*10, testStep*5)
}
// Give time for timers to be set
time.Sleep(time.Millisecond * 100)
var mu sync.Mutex
received := make(map[int]int)
var wg sync.WaitGroup
wg.Add(10)
tw.Drain(func(key, value any) {
mu.Lock()
defer mu.Unlock()
k := key.(int)
v := value.(int)
received[k] = v
wg.Done()
})
wg.Wait()
// Check if all values match their keys
for k, v := range received {
expected := k * 10
assert.Equal(t, expected, v, "key %d should have value %d, got %d", k, expected, v)
}
}
// TestTimingWheel_RunTasksClosureBug tests the closure capture bug in runTasks
// Issue: https://github.com/zeromicro/go-zero/issues/5314
func TestTimingWheel_RunTasksClosureBug(t *testing.T) {
ticker := timex.NewFakeTicker()
var mu sync.Mutex
executed := make(map[int]int)
var wg sync.WaitGroup
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
mu.Lock()
defer mu.Unlock()
key := k.(int)
val := v.(int)
executed[key] = val
wg.Done()
}, ticker)
defer tw.Stop()
// Set multiple timers that should fire in the same tick
count := 10
wg.Add(count)
for i := 0; i < count; i++ {
tw.SetTimer(i, i*10, testStep)
}
// Advance ticker to trigger tasks
ticker.Tick()
// Wait for execution with timeout
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
// Success
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for tasks to execute")
}
// Verify all tasks executed with correct values
assert.Equal(t, count, len(executed), "should have executed all tasks")
for k, v := range executed {
expected := k * 10
assert.Equal(t, expected, v, "key %d should have value %d, got %d", k, expected, v)
}
}
// TestTimingWheel_RunTasksRaceCondition tests for race conditions in runTasks
// This test specifically targets the loop variable capture bug
func TestTimingWheel_RunTasksRaceCondition(t *testing.T) {
// Run multiple times to increase likelihood of catching the bug
for attempt := 0; attempt < 10; attempt++ {
t.Run("", func(t *testing.T) {
ticker := timex.NewFakeTicker()
var mu sync.Mutex
keyValues := make(map[int][]int)
var wg sync.WaitGroup
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
// Add small delay to increase chance of race
time.Sleep(time.Microsecond)
mu.Lock()
defer mu.Unlock()
key := k.(int)
val := v.(int)
keyValues[key] = append(keyValues[key], val)
wg.Done()
}, ticker)
defer tw.Stop()
// Set many timers rapidly to increase chance of race
count := 50
wg.Add(count)
for i := 0; i < count; i++ {
tw.SetTimer(i, i*100, testStep)
}
ticker.Tick()
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for tasks")
}
// Check for duplicates or wrong values
wrongCount := 0
for key, values := range keyValues {
assert.Equal(t, 1, len(values), "key %d should only execute once, got %v", key, values)
if len(values) > 0 {
expected := key * 100
if values[0] != expected {
wrongCount++
t.Logf("BUG DETECTED: key %d should have value %d, got %d", key, expected, values[0])
}
}
}
if wrongCount > 0 {
t.Errorf("Found %d tasks with wrong values due to closure bug", wrongCount)
}
})
}
}
func BenchmarkTimingWheel(b *testing.B) {
b.ReportAllocs()

View File

@@ -21,10 +21,11 @@ const (
var (
fillDefaultUnmarshaler = mapping.NewUnmarshaler(jsonTagKey, mapping.WithDefault())
loaders = map[string]func([]byte, any) error{
".json": LoadFromJsonBytes,
".toml": LoadFromTomlBytes,
".yaml": LoadFromYamlBytes,
".yml": LoadFromYamlBytes,
".json": LoadFromJsonBytes,
".json5": LoadFromJson5Bytes,
".toml": LoadFromTomlBytes,
".yaml": LoadFromYamlBytes,
".yml": LoadFromYamlBytes,
}
)
@@ -41,7 +42,7 @@ func FillDefault(v any) error {
return fillDefaultUnmarshaler.Unmarshal(map[string]any{}, v)
}
// Load loads config into v from file, .json, .yaml and .yml are acceptable.
// Load loads config into v from file, .json, .json5, .toml, .yaml and .yml are acceptable.
func Load(file string, v any, opts ...Option) error {
content, err := os.ReadFile(file)
if err != nil {
@@ -62,14 +63,10 @@ func Load(file string, v any, opts ...Option) error {
return loader([]byte(os.ExpandEnv(string(content))), v)
}
if err = loader(content, v); err != nil {
return err
}
return validate(v)
return loader(content, v)
}
// LoadConfig loads config into v from file, .json, .yaml and .yml are acceptable.
// LoadConfig loads config into v from file, .json, .json5, .toml, .yaml and .yml are acceptable.
// Deprecated: use Load instead.
func LoadConfig(file string, v any, opts ...Option) error {
return Load(file, v, opts...)
@@ -123,6 +120,16 @@ func LoadFromYamlBytes(content []byte, v any) error {
return LoadFromJsonBytes(b, v)
}
// LoadFromJson5Bytes loads config into v from content json5 bytes.
func LoadFromJson5Bytes(content []byte, v any) error {
b, err := encoding.Json5ToJson(content)
if err != nil {
return err
}
return LoadFromJsonBytes(b, v)
}
// LoadConfigFromYamlBytes loads config into v from content yaml bytes.
// Deprecated: use LoadFromYamlBytes instead.
func LoadConfigFromYamlBytes(content []byte, v any) error {
@@ -316,7 +323,7 @@ func toLowerCaseInterface(v any, info *fieldInfo) any {
case map[string]any:
return toLowerCaseKeyMap(vv, info)
case []any:
var arr []any
arr := make([]any, 0, len(vv))
for _, vvv := range vv {
arr = append(arr, toLowerCaseInterface(vvv, info))
}
@@ -368,5 +375,5 @@ func getFullName(parent, child string) string {
return child
}
return strings.Join([]string{parent, child}, ".")
return parent + "." + child
}

View File

@@ -75,6 +75,160 @@ func TestLoadFromJsonBytesArray(t *testing.T) {
assert.EqualValues(t, []string{"foo", "bar"}, expect)
}
func TestConfigJson5(t *testing.T) {
// JSON5 with comments, trailing commas, and unquoted keys
text := `{
// This is a comment
a: 'foo', // single quotes
b: 1,
c: "${FOO}",
d: "abcd!@#$112", // trailing comma
}`
t.Setenv("FOO", "2")
tmpfile, err := createTempFile(t, ".json5", text)
assert.Nil(t, err)
var val struct {
A string `json:"a"`
B int `json:"b"`
C string `json:"c"`
D string `json:"d"`
}
MustLoad(tmpfile, &val)
assert.Equal(t, "foo", val.A)
assert.Equal(t, 1, val.B)
assert.Equal(t, "${FOO}", val.C)
assert.Equal(t, "abcd!@#$112", val.D)
}
func TestConfigJsonStandardParser(t *testing.T) {
// Standard JSON uses standard JSON parser (not JSON5) for backward compatibility
text := `{
"a": "foo",
"b": 1,
"c": "${FOO}",
"d": "abcd!@#$112"
}`
t.Setenv("FOO", "2")
tmpfile, err := createTempFile(t, ".json", text)
assert.Nil(t, err)
var val struct {
A string `json:"a"`
B int `json:"b"`
C string `json:"c"`
D string `json:"d"`
}
MustLoad(tmpfile, &val)
assert.Equal(t, "foo", val.A)
assert.Equal(t, 1, val.B)
assert.Equal(t, "${FOO}", val.C)
assert.Equal(t, "abcd!@#$112", val.D)
}
func TestConfigJsonLargeIntegers(t *testing.T) {
// Test that .json files preserve large integer precision (backward compatibility)
text := `{
"id": 1234567890123456789,
"timestamp": 9223372036854775807
}`
tmpfile, err := createTempFile(t, ".json", text)
assert.Nil(t, err)
var val struct {
ID int64 `json:"id"`
Timestamp int64 `json:"timestamp"`
}
MustLoad(tmpfile, &val)
assert.Equal(t, int64(1234567890123456789), val.ID)
assert.Equal(t, int64(9223372036854775807), val.Timestamp)
}
func TestConfigJson5Env(t *testing.T) {
text := `{
// Comment with env variable
a: "foo",
b: 1,
c: "${FOO}",
d: "abcd!@#$a12 3",
}`
t.Setenv("FOO", "2")
tmpfile, err := createTempFile(t, ".json5", text)
assert.Nil(t, err)
var val struct {
A string `json:"a"`
B int `json:"b"`
C string `json:"c"`
D string `json:"d"`
}
MustLoad(tmpfile, &val, UseEnv())
assert.Equal(t, "foo", val.A)
assert.Equal(t, 1, val.B)
assert.Equal(t, "2", val.C)
assert.Equal(t, "abcd!@# 3", val.D)
}
func TestLoadFromJson5Bytes(t *testing.T) {
// Test JSON5 features: comments, trailing commas, single quotes, unquoted keys
input := []byte(`{
// This is a comment
users: [
{name: 'foo'}, // trailing comma
{Name: "bar"},
],
}`)
var val struct {
Users []struct {
Name string
}
}
assert.NoError(t, LoadFromJson5Bytes(input, &val))
var expect []string
for _, user := range val.Users {
expect = append(expect, user.Name)
}
assert.EqualValues(t, []string{"foo", "bar"}, expect)
}
func TestLoadFromJson5BytesError(t *testing.T) {
// Invalid JSON5 syntax
input := []byte(`{a: foo}`) // unquoted string value (invalid)
var val struct {
A string
}
assert.Error(t, LoadFromJson5Bytes(input, &val))
}
func TestConfigJson5LargeIntegersLimitation(t *testing.T) {
// Document that JSON5 has precision limitations for large integers (>2^53)
// due to JavaScript number semantics. Users should use .json for configs with large IDs.
text := `{
// JSON5 converts numbers to float64, which loses precision for large integers
id: 1234567890123456789
}`
tmpfile, err := createTempFile(t, ".json5", text)
assert.Nil(t, err)
var val struct {
ID int64 `json:"id"`
}
// This will load; depending on the JSON5 implementation, large integers may lose precision.
// This test documents that behavior without requiring loss of precision as an invariant.
err = Load(tmpfile, &val)
assert.NoError(t, err)
t.Logf("loaded JSON5 large integer id=%d (original 1234567890123456789)", val.ID)
}
func TestConfigToml(t *testing.T) {
text := `a = "foo"
b = 1
@@ -1377,3 +1531,242 @@ func (m mockConfig) Validate() error {
return nil
}
func TestGetFullName(t *testing.T) {
tests := []struct {
parent string
child string
want string
}{
{"", "child", "child"},
{"parent", "child", "parent.child"},
{"a.b", "c", "a.b.c"},
{"root", "nested.field", "root.nested.field"},
}
for _, tt := range tests {
t.Run(tt.parent+"."+tt.child, func(t *testing.T) {
got := getFullName(tt.parent, tt.child)
assert.Equal(t, tt.want, got)
})
}
}
// validatorConfig is a test config that implements Validate() for testing validation behavior
type validatorConfig struct {
Value int `json:"value"`
}
func (v *validatorConfig) Validate() error {
if v.Value < 10 {
return errors.New("value must be >= 10")
}
return nil
}
// TestLoadValidation_WithoutEnv tests that validation is called correctly in normal loading path
func TestLoadValidation_WithoutEnv(t *testing.T) {
tests := []struct {
name string
extension string
content string
wantErr bool
errMsg string
}{
{
name: "json valid value",
extension: ".json",
content: `{"value": 15}`,
wantErr: false,
},
{
name: "json invalid value",
extension: ".json",
content: `{"value": 5}`,
wantErr: true,
errMsg: "value must be >= 10",
},
{
name: "yaml valid value",
extension: ".yaml",
content: "value: 20\n",
wantErr: false,
},
{
name: "yaml invalid value",
extension: ".yaml",
content: "value: 3\n",
wantErr: true,
errMsg: "value must be >= 10",
},
{
name: "toml valid value",
extension: ".toml",
content: "value = 100\n",
wantErr: false,
},
{
name: "toml invalid value",
extension: ".toml",
content: "value = 1\n",
wantErr: true,
errMsg: "value must be >= 10",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpfile, err := createTempFile(t, tt.extension, tt.content)
assert.Nil(t, err)
var cfg validatorConfig
err = Load(tmpfile, &cfg)
if tt.wantErr {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errMsg)
} else {
assert.NoError(t, err)
}
})
}
}
// TestLoadValidation_WithEnv tests that validation is called correctly with UseEnv() option
func TestLoadValidation_WithEnv(t *testing.T) {
tests := []struct {
name string
extension string
content string
envValue string
wantErr bool
errMsg string
}{
{
name: "json valid value with env",
extension: ".json",
content: `{"value": ${TEST_VALUE}}`,
envValue: "25",
wantErr: false,
},
{
name: "json invalid value with env",
extension: ".json",
content: `{"value": ${TEST_VALUE}}`,
envValue: "7",
wantErr: true,
errMsg: "value must be >= 10",
},
{
name: "yaml valid value with env",
extension: ".yaml",
content: "value: ${TEST_VALUE}\n",
envValue: "50",
wantErr: false,
},
{
name: "yaml invalid value with env",
extension: ".yaml",
content: "value: ${TEST_VALUE}\n",
envValue: "2",
wantErr: true,
errMsg: "value must be >= 10",
},
{
name: "toml valid value with env",
extension: ".toml",
content: "value = ${TEST_VALUE}\n",
envValue: "99",
wantErr: false,
},
{
name: "toml invalid value with env",
extension: ".toml",
content: "value = ${TEST_VALUE}\n",
envValue: "8",
wantErr: true,
errMsg: "value must be >= 10",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("TEST_VALUE", tt.envValue)
tmpfile, err := createTempFile(t, tt.extension, tt.content)
assert.Nil(t, err)
var cfg validatorConfig
err = Load(tmpfile, &cfg, UseEnv())
if tt.wantErr {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errMsg)
} else {
assert.NoError(t, err)
}
})
}
}
// TestLoadValidation_Consistency verifies validation behavior is consistent between paths
func TestLoadValidation_Consistency(t *testing.T) {
// Test that both paths (with and without UseEnv) produce the same validation results
const validValue = 15
formats := []struct {
ext string
invalid string
valid string
}{
{".json", `{"value": 5}`, `{"value": 15}`},
{".yaml", "value: 5\n", "value: 15\n"},
{".toml", "value = 5\n", "value = 15\n"},
}
for _, format := range formats {
t.Run("invalid_"+format.ext, func(t *testing.T) {
// Test without UseEnv()
tmpfile1, err := createTempFile(t, format.ext, format.invalid)
assert.Nil(t, err)
var cfg1 validatorConfig
err1 := Load(tmpfile1, &cfg1)
// Test with UseEnv()
tmpfile2, err := createTempFile(t, format.ext, format.invalid)
assert.Nil(t, err)
var cfg2 validatorConfig
err2 := Load(tmpfile2, &cfg2, UseEnv())
// Both should fail validation
assert.Error(t, err1, "validation should fail without UseEnv()")
assert.Error(t, err2, "validation should fail with UseEnv()")
assert.Contains(t, err1.Error(), "value must be >= 10")
assert.Contains(t, err2.Error(), "value must be >= 10")
})
t.Run("valid_"+format.ext, func(t *testing.T) {
// Test without UseEnv()
tmpfile1, err := createTempFile(t, format.ext, format.valid)
assert.Nil(t, err)
var cfg1 validatorConfig
err1 := Load(tmpfile1, &cfg1)
// Test with UseEnv()
tmpfile2, err := createTempFile(t, format.ext, format.valid)
assert.Nil(t, err)
var cfg2 validatorConfig
err2 := Load(tmpfile2, &cfg2, UseEnv())
// Both should pass validation
assert.NoError(t, err1, "validation should pass without UseEnv()")
assert.NoError(t, err2, "validation should pass with UseEnv()")
assert.Equal(t, validValue, cfg1.Value)
assert.Equal(t, validValue, cfg2.Value)
})
}
}

View File

@@ -45,7 +45,7 @@ func LoadProperties(filename string, opts ...Option) (Properties, error) {
raw := make(map[string]string)
for i := range lines {
pair := strings.Split(lines[i], "=")
pair := strings.SplitN(lines[i], "=", 2)
if len(pair) != 2 {
// invalid property format
return nil, &PropertyError{

View File

@@ -92,3 +92,70 @@ func TestLoadBadFile(t *testing.T) {
_, err := LoadProperties("nosuchfile")
assert.NotNil(t, err)
}
func TestProperties_valueWithEqualSymbols(t *testing.T) {
text := `# test with equal symbols in value
db.url=postgres://localhost:5432/db?param=value
math.equation=a=b=c
base64.data=SGVsbG8=World=Test=
url.with.params=http://example.com?foo=bar&baz=qux
empty.value=
key.with.space = value = with = equals`
tmpfile, err := fs.TempFilenameWithText(text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
props, err := LoadProperties(tmpfile)
assert.Nil(t, err)
assert.Equal(t, "postgres://localhost:5432/db?param=value", props.GetString("db.url"))
assert.Equal(t, "a=b=c", props.GetString("math.equation"))
assert.Equal(t, "SGVsbG8=World=Test=", props.GetString("base64.data"))
assert.Equal(t, "http://example.com?foo=bar&baz=qux", props.GetString("url.with.params"))
assert.Equal(t, "", props.GetString("empty.value"))
assert.Equal(t, "value = with = equals", props.GetString("key.with.space"))
}
func TestProperties_edgeCases(t *testing.T) {
tests := []struct {
name string
content string
wantErr bool
errMsg string
}{
{
name: "no equal sign",
content: "invalid line without equal",
wantErr: true,
},
{
name: "only equal sign",
content: "=",
wantErr: false, // "=" 会被解析为空 key 和空 valuelen(pair) == 2是合法的
},
{
name: "empty key",
content: "=value",
wantErr: false, // 空 key 也会被 trim但 len(pair) == 2 所以不会报错
},
{
name: "equal at end",
content: "key.name=",
wantErr: false, // 空 value 是合法的
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpfile, err := fs.TempFilenameWithText(tt.content)
assert.Nil(t, err)
defer os.Remove(tmpfile)
_, err = LoadProperties(tmpfile)
if tt.wantErr {
assert.NotNil(t, err, "expected error for case: %s", tt.name)
} else {
assert.Nil(t, err, "unexpected error for case: %s", tt.name)
}
})
}
}

View File

@@ -1,6 +1,9 @@
package subscriber
import (
"sync"
"sync/atomic"
"github.com/zeromicro/go-zero/core/discov"
"github.com/zeromicro/go-zero/core/logx"
)
@@ -37,6 +40,7 @@ func NewEtcdSubscriber(conf EtcdConf) (Subscriber, error) {
func buildSubOptions(conf EtcdConf) []discov.SubOption {
opts := []discov.SubOption{
discov.WithExactMatch(),
discov.WithContainer(newContainer()),
}
if len(conf.User) > 0 {
@@ -65,3 +69,47 @@ func (s *etcdSubscriber) Value() (string, error) {
return "", nil
}
type container struct {
value atomic.Value
listeners []func()
lock sync.Mutex
}
func newContainer() *container {
return &container{}
}
func (c *container) OnAdd(kv discov.KV) {
c.value.Store([]string{kv.Val})
c.notifyChange()
}
func (c *container) OnDelete(_ discov.KV) {
c.value.Store([]string(nil))
c.notifyChange()
}
func (c *container) AddListener(listener func()) {
c.lock.Lock()
c.listeners = append(c.listeners, listener)
c.lock.Unlock()
}
func (c *container) GetValues() []string {
if vals, ok := c.value.Load().([]string); ok {
return vals
}
return []string(nil)
}
func (c *container) notifyChange() {
c.lock.Lock()
listeners := append(([]func())(nil), c.listeners...)
c.lock.Unlock()
for _, listener := range listeners {
listener()
}
}

View File

@@ -0,0 +1,186 @@
package subscriber
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/discov"
)
const (
actionAdd = iota
actionDel
)
func TestConfigCenterContainer(t *testing.T) {
type action struct {
act int
key string
val string
}
tests := []struct {
name string
do []action
expect []string
}{
{
name: "add one",
do: []action{
{
act: actionAdd,
key: "first",
val: "a",
},
},
expect: []string{
"a",
},
},
{
name: "add two",
do: []action{
{
act: actionAdd,
key: "first",
val: "a",
},
{
act: actionAdd,
key: "second",
val: "b",
},
},
expect: []string{
"b",
},
},
{
name: "add two, delete one",
do: []action{
{
act: actionAdd,
key: "first",
val: "a",
},
{
act: actionAdd,
key: "second",
val: "b",
},
{
act: actionDel,
key: "first",
},
},
expect: []string(nil),
},
{
name: "add two, delete two",
do: []action{
{
act: actionAdd,
key: "first",
val: "a",
},
{
act: actionAdd,
key: "second",
val: "b",
},
{
act: actionDel,
key: "first",
},
{
act: actionDel,
key: "second",
},
},
expect: []string(nil),
},
{
name: "add two, dup values",
do: []action{
{
act: actionAdd,
key: "first",
val: "a",
},
{
act: actionAdd,
key: "second",
val: "b",
},
{
act: actionAdd,
key: "third",
val: "a",
},
},
expect: []string{"a"},
},
{
name: "add three, dup values, delete two, add one",
do: []action{
{
act: actionAdd,
key: "first",
val: "a",
},
{
act: actionAdd,
key: "second",
val: "b",
},
{
act: actionAdd,
key: "third",
val: "a",
},
{
act: actionDel,
key: "first",
},
{
act: actionDel,
key: "second",
},
{
act: actionAdd,
key: "forth",
val: "c",
},
},
expect: []string{"c"},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var changed bool
c := newContainer()
c.AddListener(func() {
changed = true
})
assert.Nil(t, c.GetValues())
assert.False(t, changed)
for _, order := range test.do {
if order.act == actionAdd {
c.OnAdd(discov.KV{
Key: order.key,
Val: order.val,
})
} else {
c.OnDelete(discov.KV{
Key: order.key,
Val: order.val,
})
}
}
assert.True(t, changed)
assert.ElementsMatch(t, test.expect, c.GetValues())
})
}
}

View File

@@ -1,5 +1,10 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: etcdclient.go
//
// Generated by this command:
//
// mockgen -package internal -destination etcdclient_mock.go -source etcdclient.go EtcdClient
//
// Package internal is a generated GoMock package.
package internal
@@ -8,35 +13,36 @@ import (
context "context"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
clientv3 "go.etcd.io/etcd/client/v3"
gomock "go.uber.org/mock/gomock"
grpc "google.golang.org/grpc"
)
// MockEtcdClient is a mock of EtcdClient interface
// MockEtcdClient is a mock of EtcdClient interface.
type MockEtcdClient struct {
ctrl *gomock.Controller
recorder *MockEtcdClientMockRecorder
isgomock struct{}
}
// MockEtcdClientMockRecorder is the mock recorder for MockEtcdClient
// MockEtcdClientMockRecorder is the mock recorder for MockEtcdClient.
type MockEtcdClientMockRecorder struct {
mock *MockEtcdClient
}
// NewMockEtcdClient creates a new mock instance
// NewMockEtcdClient creates a new mock instance.
func NewMockEtcdClient(ctrl *gomock.Controller) *MockEtcdClient {
mock := &MockEtcdClient{ctrl: ctrl}
mock.recorder = &MockEtcdClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockEtcdClient) EXPECT() *MockEtcdClientMockRecorder {
return m.recorder
}
// ActiveConnection mocks base method
// ActiveConnection mocks base method.
func (m *MockEtcdClient) ActiveConnection() *grpc.ClientConn {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ActiveConnection")
@@ -44,13 +50,13 @@ func (m *MockEtcdClient) ActiveConnection() *grpc.ClientConn {
return ret0
}
// ActiveConnection indicates an expected call of ActiveConnection
// ActiveConnection indicates an expected call of ActiveConnection.
func (mr *MockEtcdClientMockRecorder) ActiveConnection() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActiveConnection", reflect.TypeOf((*MockEtcdClient)(nil).ActiveConnection))
}
// Close mocks base method
// Close mocks base method.
func (m *MockEtcdClient) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
@@ -58,13 +64,13 @@ func (m *MockEtcdClient) Close() error {
return ret0
}
// Close indicates an expected call of Close
// Close indicates an expected call of Close.
func (mr *MockEtcdClientMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockEtcdClient)(nil).Close))
}
// Ctx mocks base method
// Ctx mocks base method.
func (m *MockEtcdClient) Ctx() context.Context {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Ctx")
@@ -72,13 +78,13 @@ func (m *MockEtcdClient) Ctx() context.Context {
return ret0
}
// Ctx indicates an expected call of Ctx
// Ctx indicates an expected call of Ctx.
func (mr *MockEtcdClientMockRecorder) Ctx() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ctx", reflect.TypeOf((*MockEtcdClient)(nil).Ctx))
}
// Get mocks base method
// Get mocks base method.
func (m *MockEtcdClient) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, key}
@@ -91,14 +97,14 @@ func (m *MockEtcdClient) Get(ctx context.Context, key string, opts ...clientv3.O
return ret0, ret1
}
// Get indicates an expected call of Get
// Get indicates an expected call of Get.
func (mr *MockEtcdClientMockRecorder) Get(ctx, key any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, key}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockEtcdClient)(nil).Get), varargs...)
}
// Grant mocks base method
// Grant mocks base method.
func (m *MockEtcdClient) Grant(ctx context.Context, ttl int64) (*clientv3.LeaseGrantResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Grant", ctx, ttl)
@@ -107,13 +113,13 @@ func (m *MockEtcdClient) Grant(ctx context.Context, ttl int64) (*clientv3.LeaseG
return ret0, ret1
}
// Grant indicates an expected call of Grant
// Grant indicates an expected call of Grant.
func (mr *MockEtcdClientMockRecorder) Grant(ctx, ttl any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Grant", reflect.TypeOf((*MockEtcdClient)(nil).Grant), ctx, ttl)
}
// KeepAlive mocks base method
// KeepAlive mocks base method.
func (m *MockEtcdClient) KeepAlive(ctx context.Context, id clientv3.LeaseID) (<-chan *clientv3.LeaseKeepAliveResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "KeepAlive", ctx, id)
@@ -122,13 +128,13 @@ func (m *MockEtcdClient) KeepAlive(ctx context.Context, id clientv3.LeaseID) (<-
return ret0, ret1
}
// KeepAlive indicates an expected call of KeepAlive
// KeepAlive indicates an expected call of KeepAlive.
func (mr *MockEtcdClientMockRecorder) KeepAlive(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeepAlive", reflect.TypeOf((*MockEtcdClient)(nil).KeepAlive), ctx, id)
}
// Put mocks base method
// Put mocks base method.
func (m *MockEtcdClient) Put(ctx context.Context, key, val string, opts ...clientv3.OpOption) (*clientv3.PutResponse, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, key, val}
@@ -141,14 +147,14 @@ func (m *MockEtcdClient) Put(ctx context.Context, key, val string, opts ...clien
return ret0, ret1
}
// Put indicates an expected call of Put
// Put indicates an expected call of Put.
func (mr *MockEtcdClientMockRecorder) Put(ctx, key, val any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, key, val}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockEtcdClient)(nil).Put), varargs...)
}
// Revoke mocks base method
// Revoke mocks base method.
func (m *MockEtcdClient) Revoke(ctx context.Context, id clientv3.LeaseID) (*clientv3.LeaseRevokeResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Revoke", ctx, id)
@@ -157,13 +163,13 @@ func (m *MockEtcdClient) Revoke(ctx context.Context, id clientv3.LeaseID) (*clie
return ret0, ret1
}
// Revoke indicates an expected call of Revoke
// Revoke indicates an expected call of Revoke.
func (mr *MockEtcdClientMockRecorder) Revoke(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Revoke", reflect.TypeOf((*MockEtcdClient)(nil).Revoke), ctx, id)
}
// Watch mocks base method
// Watch mocks base method.
func (m *MockEtcdClient) Watch(ctx context.Context, key string, opts ...clientv3.OpOption) clientv3.WatchChan {
m.ctrl.T.Helper()
varargs := []any{ctx, key}
@@ -175,7 +181,7 @@ func (m *MockEtcdClient) Watch(ctx context.Context, key string, opts ...clientv3
return ret0
}
// Watch indicates an expected call of Watch
// Watch indicates an expected call of Watch.
func (mr *MockEtcdClientMockRecorder) Watch(ctx, key any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, key}, opts...)

View File

@@ -207,7 +207,7 @@ func (c *cluster) getCurrent(key watchKey) []KV {
return nil
}
var kvs []KV
kvs := make([]KV, 0, len(watcher.values))
for k, v := range watcher.values {
kvs = append(kvs, KV{
Key: k,
@@ -308,7 +308,7 @@ func (c *cluster) load(cli EtcdClient, key watchKey) int64 {
time.Sleep(coolDownUnstable.AroundDuration(coolDownInterval))
}
var kvs []KV
kvs := make([]KV, 0, len(resp.Kvs))
for _, ev := range resp.Kvs {
kvs = append(kvs, KV{
Key: string(ev.Key),
@@ -352,7 +352,7 @@ func (c *cluster) reload(cli EtcdClient) {
// cancel the previous watches
close(c.done)
c.watchGroup.Wait()
var keys []watchKey
keys := make([]watchKey, 0, len(c.watchers))
for wk, wval := range c.watchers {
keys = append(keys, wk)
if wval.cancel != nil {
@@ -386,8 +386,9 @@ func (c *cluster) watch(cli EtcdClient, key watchKey, rev int64) {
rev = c.load(cli, key)
}
// log the error and retry
// log the error and retry with cooldown to prevent CPU/disk exhaustion
logc.Error(cli.Ctx(), err)
time.Sleep(coolDownUnstable.AroundDuration(coolDownInterval))
}
}
@@ -432,16 +433,16 @@ func (c *cluster) setupWatch(cli EtcdClient, key watchKey, rev int64) (context.C
}
ctx, cancel := context.WithCancel(cli.Ctx())
c.lock.Lock()
if watcher, ok := c.watchers[key]; ok {
watcher.cancel = cancel
} else {
val := newWatchValue()
val.cancel = cancel
c.lock.Lock()
c.watchers[key] = val
c.lock.Unlock()
}
c.lock.Unlock()
rch = cli.Watch(clientv3.WithRequireLeader(ctx), wkey, ops...)

View File

@@ -7,7 +7,6 @@ import (
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/contextx"
"github.com/zeromicro/go-zero/core/lang"
@@ -18,6 +17,7 @@ import (
"go.etcd.io/etcd/api/v3/mvccpb"
clientv3 "go.etcd.io/etcd/client/v3"
"go.etcd.io/etcd/client/v3/mock/mockserver"
"go.uber.org/mock/gomock"
)
var mockLock sync.Mutex
@@ -423,7 +423,7 @@ func TestRegistry_Monitor(t *testing.T) {
GetRegistry().clusters = map[string]*cluster{
getClusterKey(endpoints): {
watchers: map[watchKey]*watchValue{
watchKey{
{
key: "foo",
exactMatch: true,
}: {
@@ -449,7 +449,7 @@ func TestRegistry_Unmonitor(t *testing.T) {
GetRegistry().clusters = map[string]*cluster{
getClusterKey(endpoints): {
watchers: map[watchKey]*watchValue{
watchKey{
{
key: "foo",
exactMatch: true,
}: {
@@ -477,6 +477,72 @@ func TestRegistry_Unmonitor(t *testing.T) {
assert.Nil(t, watchVals)
}
// TestCluster_ConcurrentMonitor tests the race condition fix in setupWatch
// This test specifically covers the scenario from issue #5394 where:
// - addListener() writes to the watchers map (with lock)
// - setupWatch() reads from the watchers map (now with lock after fix)
// Running with -race flag will detect any race conditions
func TestCluster_ConcurrentMonitor(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
cli := NewMockEtcdClient(ctrl)
cli.EXPECT().Ctx().Return(context.Background()).AnyTimes()
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return(make(chan clientv3.WatchResponse)).AnyTimes()
c := &cluster{
endpoints: []string{"localhost:2379"},
key: "test-cluster",
watchers: make(map[watchKey]*watchValue),
watchGroup: threading.NewRoutineGroup(),
done: make(chan lang.PlaceholderType),
lock: sync.RWMutex{},
}
// Spawn multiple concurrent operations that simulate the race condition:
// - Some goroutines call addListener (write to map)
// - Some goroutines call setupWatch (read from map)
var wg sync.WaitGroup
numGoroutines := 20
wg.Add(numGoroutines)
keys := []watchKey{
{key: "key-0", exactMatch: false},
{key: "key-1", exactMatch: false},
{key: "key-2", exactMatch: false},
}
for i := 0; i < numGoroutines; i++ {
idx := i
go func() {
defer wg.Done()
key := keys[idx%len(keys)]
if idx%2 == 0 {
// Half the goroutines add listeners (write operation)
c.addListener(key, &mockListener{})
} else {
// Half the goroutines setup watches (read operation)
_, _ = c.setupWatch(cli, key, 0)
}
}()
}
// Wait for all goroutines to complete
wg.Wait()
// Verify that watchers were correctly added
c.lock.RLock()
assert.True(t, len(c.watchers) > 0, "watchers should be added")
for _, watcher := range c.watchers {
assert.NotNil(t, watcher, "watcher should not be nil")
}
c.lock.RUnlock()
// Clean up
close(c.done)
}
type mockListener struct {
}

View File

@@ -1,5 +1,10 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: statewatcher.go
//
// Generated by this command:
//
// mockgen -package internal -destination statewatcher_mock.go -source statewatcher.go etcdConn
//
// Package internal is a generated GoMock package.
package internal
@@ -8,34 +13,35 @@ import (
context "context"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
gomock "go.uber.org/mock/gomock"
connectivity "google.golang.org/grpc/connectivity"
)
// MocketcdConn is a mock of etcdConn interface
// MocketcdConn is a mock of etcdConn interface.
type MocketcdConn struct {
ctrl *gomock.Controller
recorder *MocketcdConnMockRecorder
isgomock struct{}
}
// MocketcdConnMockRecorder is the mock recorder for MocketcdConn
// MocketcdConnMockRecorder is the mock recorder for MocketcdConn.
type MocketcdConnMockRecorder struct {
mock *MocketcdConn
}
// NewMocketcdConn creates a new mock instance
// NewMocketcdConn creates a new mock instance.
func NewMocketcdConn(ctrl *gomock.Controller) *MocketcdConn {
mock := &MocketcdConn{ctrl: ctrl}
mock.recorder = &MocketcdConnMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MocketcdConn) EXPECT() *MocketcdConnMockRecorder {
return m.recorder
}
// GetState mocks base method
// GetState mocks base method.
func (m *MocketcdConn) GetState() connectivity.State {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetState")
@@ -43,13 +49,13 @@ func (m *MocketcdConn) GetState() connectivity.State {
return ret0
}
// GetState indicates an expected call of GetState
// GetState indicates an expected call of GetState.
func (mr *MocketcdConnMockRecorder) GetState() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetState", reflect.TypeOf((*MocketcdConn)(nil).GetState))
}
// WaitForStateChange mocks base method
// WaitForStateChange mocks base method.
func (m *MocketcdConn) WaitForStateChange(ctx context.Context, sourceState connectivity.State) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "WaitForStateChange", ctx, sourceState)
@@ -57,7 +63,7 @@ func (m *MocketcdConn) WaitForStateChange(ctx context.Context, sourceState conne
return ret0
}
// WaitForStateChange indicates an expected call of WaitForStateChange
// WaitForStateChange indicates an expected call of WaitForStateChange.
func (mr *MocketcdConnMockRecorder) WaitForStateChange(ctx, sourceState any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WaitForStateChange", reflect.TypeOf((*MocketcdConn)(nil).WaitForStateChange), ctx, sourceState)

View File

@@ -4,7 +4,7 @@ import (
"sync"
"testing"
"github.com/golang/mock/gomock"
"go.uber.org/mock/gomock"
"google.golang.org/grpc/connectivity"
)

View File

@@ -1,5 +1,10 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: updatelistener.go
//
// Generated by this command:
//
// mockgen -package internal -destination updatelistener_mock.go -source updatelistener.go UpdateListener
//
// Package internal is a generated GoMock package.
package internal
@@ -7,51 +12,52 @@ package internal
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
gomock "go.uber.org/mock/gomock"
)
// MockUpdateListener is a mock of UpdateListener interface
// MockUpdateListener is a mock of UpdateListener interface.
type MockUpdateListener struct {
ctrl *gomock.Controller
recorder *MockUpdateListenerMockRecorder
isgomock struct{}
}
// MockUpdateListenerMockRecorder is the mock recorder for MockUpdateListener
// MockUpdateListenerMockRecorder is the mock recorder for MockUpdateListener.
type MockUpdateListenerMockRecorder struct {
mock *MockUpdateListener
}
// NewMockUpdateListener creates a new mock instance
// NewMockUpdateListener creates a new mock instance.
func NewMockUpdateListener(ctrl *gomock.Controller) *MockUpdateListener {
mock := &MockUpdateListener{ctrl: ctrl}
mock.recorder = &MockUpdateListenerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockUpdateListener) EXPECT() *MockUpdateListenerMockRecorder {
return m.recorder
}
// OnAdd mocks base method
// OnAdd mocks base method.
func (m *MockUpdateListener) OnAdd(kv KV) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnAdd", kv)
}
// OnAdd indicates an expected call of OnAdd
// OnAdd indicates an expected call of OnAdd.
func (mr *MockUpdateListenerMockRecorder) OnAdd(kv any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnAdd", reflect.TypeOf((*MockUpdateListener)(nil).OnAdd), kv)
}
// OnDelete mocks base method
// OnDelete mocks base method.
func (m *MockUpdateListener) OnDelete(kv KV) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnDelete", kv)
}
// OnDelete indicates an expected call of OnDelete
// OnDelete indicates an expected call of OnDelete.
func (mr *MockUpdateListenerMockRecorder) OnDelete(kv any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnDelete", reflect.TypeOf((*MockUpdateListener)(nil).OnDelete), kv)

View File

@@ -92,12 +92,12 @@ func (p *Publisher) doKeepAlive() error {
default:
cli, err := p.doRegister()
if err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher doRegister: %s", err.Error())
logc.Errorf(cli.Ctx(), "etcd publisher doRegister: %v", err)
break
}
if err := p.keepAliveAsync(cli); err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher keepAliveAsync: %s", err.Error())
logc.Errorf(cli.Ctx(), "etcd publisher keepAliveAsync: %v", err)
break
}
@@ -125,23 +125,48 @@ func (p *Publisher) keepAliveAsync(cli internal.EtcdClient) error {
}
threading.GoSafe(func() {
wch := cli.Watch(cli.Ctx(), p.fullKey, clientv3.WithFilterPut())
for {
select {
case _, ok := <-ch:
if !ok {
p.revoke(cli)
if err := p.doKeepAlive(); err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher KeepAlive: %s", err.Error())
logc.Errorf(cli.Ctx(), "etcd publisher KeepAlive: %v", err)
}
return
}
case c := <-wch:
if c.Err() != nil {
logc.Errorf(cli.Ctx(), "etcd publisher watch: %v", c.Err())
if err := p.doKeepAlive(); err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher KeepAlive: %v", err)
}
return
}
for _, evt := range c.Events {
if evt.Type == clientv3.EventTypeDelete {
logc.Infof(cli.Ctx(), "etcd publisher watch: %s, event: %v",
evt.Kv.Key, evt.Type)
_, err := cli.Put(cli.Ctx(), p.fullKey, p.value, clientv3.WithLease(p.lease))
if err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher re-put key: %v", err)
} else {
logc.Infof(cli.Ctx(), "etcd publisher re-put key: %s, value: %s",
p.fullKey, p.value)
}
}
}
case <-p.pauseChan:
logc.Infof(cli.Ctx(), "paused etcd renew, key: %s, value: %s", p.key, p.value)
p.revoke(cli)
select {
case <-p.resumeChan:
if err := p.doKeepAlive(); err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher KeepAlive: %s", err.Error())
logc.Errorf(cli.Ctx(), "etcd publisher KeepAlive: %v", err)
}
return
case <-p.quit.Done():
@@ -176,7 +201,7 @@ func (p *Publisher) register(client internal.EtcdClient) (clientv3.LeaseID, erro
func (p *Publisher) revoke(cli internal.EtcdClient) {
if _, err := cli.Revoke(cli.Ctx(), p.lease); err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher revoke: %s", err.Error())
logc.Errorf(cli.Ctx(), "etcd publisher revoke: %v", err)
}
}

View File

@@ -9,13 +9,14 @@ import (
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/discov/internal"
"github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stringx"
"go.etcd.io/etcd/api/v3/mvccpb"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/mock/gomock"
"golang.org/x/net/http2"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
@@ -211,6 +212,9 @@ func TestPublisher_keepAliveAsyncQuit(t *testing.T) {
defer restore()
cli.EXPECT().Ctx().AnyTimes()
cli.EXPECT().KeepAlive(gomock.Any(), id)
// Add Watch expectation for the new watch mechanism
watchChan := make(<-chan clientv3.WatchResponse)
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return(watchChan)
var wg sync.WaitGroup
wg.Add(1)
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) {
@@ -232,6 +236,9 @@ func TestPublisher_keepAliveAsyncPause(t *testing.T) {
defer restore()
cli.EXPECT().Ctx().AnyTimes()
cli.EXPECT().KeepAlive(gomock.Any(), id)
// Add Watch expectation for the new watch mechanism
watchChan := make(<-chan clientv3.WatchResponse)
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return(watchChan)
pub := NewPublisher(nil, "thekey", "thevalue")
var wg sync.WaitGroup
wg.Add(1)
@@ -245,6 +252,112 @@ func TestPublisher_keepAliveAsyncPause(t *testing.T) {
wg.Wait()
}
// Test case for key deletion and re-registration (covers lines 148-155)
func TestPublisher_keepAliveAsyncKeyDeletion(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
const id clientv3.LeaseID = 1
cli := internal.NewMockEtcdClient(ctrl)
restore := setMockClient(cli)
defer restore()
cli.EXPECT().Ctx().AnyTimes()
cli.EXPECT().KeepAlive(gomock.Any(), id)
// Create a watch channel that will send a delete event
watchChan := make(chan clientv3.WatchResponse, 1)
watchResp := clientv3.WatchResponse{
Events: []*clientv3.Event{{
Type: clientv3.EventTypeDelete,
Kv: &mvccpb.KeyValue{
Key: []byte("thekey"),
},
}},
}
watchChan <- watchResp
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return((<-chan clientv3.WatchResponse)(watchChan))
var wg sync.WaitGroup
wg.Add(1) // Only wait for Revoke call
// Use a channel to signal when Put has been called
putCalled := make(chan struct{})
// Expect the re-put operation when key is deleted
cli.EXPECT().Put(gomock.Any(), "thekey", "thevalue", gomock.Any()).Do(func(_, _, _, _ any) {
close(putCalled) // Signal that Put has been called
}).Return(nil, nil)
// Expect revoke when Stop is called
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) {
wg.Done()
})
pub := NewPublisher(nil, "thekey", "thevalue")
pub.lease = id
pub.fullKey = "thekey"
assert.Nil(t, pub.keepAliveAsync(cli))
// Wait for Put to be called, then stop
<-putCalled
pub.Stop()
wg.Wait()
}
// Test case for key deletion with re-put error (covers error branch in lines 151-152)
func TestPublisher_keepAliveAsyncKeyDeletionPutError(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
const id clientv3.LeaseID = 1
cli := internal.NewMockEtcdClient(ctrl)
restore := setMockClient(cli)
defer restore()
cli.EXPECT().Ctx().AnyTimes()
cli.EXPECT().KeepAlive(gomock.Any(), id)
// Create a watch channel that will send a delete event
watchChan := make(chan clientv3.WatchResponse, 1)
watchResp := clientv3.WatchResponse{
Events: []*clientv3.Event{{
Type: clientv3.EventTypeDelete,
Kv: &mvccpb.KeyValue{
Key: []byte("thekey"),
},
}},
}
watchChan <- watchResp
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return((<-chan clientv3.WatchResponse)(watchChan))
var wg sync.WaitGroup
wg.Add(1) // Only wait for Revoke call
// Use a channel to signal when Put has been called
putCalled := make(chan struct{})
// Expect the re-put operation to fail
cli.EXPECT().Put(gomock.Any(), "thekey", "thevalue", gomock.Any()).Do(func(_, _, _, _ any) {
close(putCalled) // Signal that Put has been called
}).Return(nil, errors.New("put error"))
// Expect revoke when Stop is called
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) {
wg.Done()
})
pub := NewPublisher(nil, "thekey", "thevalue")
pub.lease = id
pub.fullKey = "thekey"
assert.Nil(t, pub.keepAliveAsync(cli))
// Wait for Put to be called, then stop
<-putCalled
pub.Stop()
wg.Wait()
}
func TestPublisher_Resume(t *testing.T) {
publisher := new(Publisher)
publisher.resumeChan = make(chan lang.PlaceholderType)
@@ -273,6 +386,9 @@ func TestPublisher_keepAliveAsync(t *testing.T) {
defer restore()
cli.EXPECT().Ctx().AnyTimes()
cli.EXPECT().KeepAlive(gomock.Any(), id)
// Add Watch expectation for the new watch mechanism
watchChan := make(<-chan clientv3.WatchResponse)
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return(watchChan)
cli.EXPECT().Grant(gomock.Any(), timeToLive).Return(&clientv3.LeaseGrantResponse{
ID: 1,
}, nil)

View File

@@ -19,8 +19,9 @@ type (
exclusive bool
key string
exactMatch bool
items *container
items Container
}
KV = internal.KV
)
// NewSubscriber returns a Subscriber.
@@ -35,7 +36,9 @@ func NewSubscriber(endpoints []string, key string, opts ...SubOption) (*Subscrib
for _, opt := range opts {
opt(sub)
}
sub.items = newContainer(sub.exclusive)
if sub.items == nil {
sub.items = newContainer(sub.exclusive)
}
if err := internal.GetRegistry().Monitor(endpoints, key, sub.exactMatch, sub.items); err != nil {
return nil, err
@@ -46,7 +49,7 @@ func NewSubscriber(endpoints []string, key string, opts ...SubOption) (*Subscrib
// AddListener adds listener to s.
func (s *Subscriber) AddListener(listener func()) {
s.items.addListener(listener)
s.items.AddListener(listener)
}
// Close closes the subscriber.
@@ -56,7 +59,7 @@ func (s *Subscriber) Close() {
// Values returns all the subscription values.
func (s *Subscriber) Values() []string {
return s.items.getValues()
return s.items.GetValues()
}
// Exclusive means that key value can only be 1:1,
@@ -88,16 +91,32 @@ func WithSubEtcdTLS(certFile, certKeyFile, caFile string, insecureSkipVerify boo
}
}
type container struct {
exclusive bool
values map[string][]string
mapping map[string]string
snapshot atomic.Value
dirty *syncx.AtomicBool
listeners []func()
lock sync.Mutex
// WithContainer provides a custom container to the subscriber.
func WithContainer(container Container) SubOption {
return func(sub *Subscriber) {
sub.items = container
}
}
type (
Container interface {
OnAdd(kv internal.KV)
OnDelete(kv internal.KV)
AddListener(listener func())
GetValues() []string
}
container struct {
exclusive bool
values map[string][]string
mapping map[string]string
snapshot atomic.Value
dirty *syncx.AtomicBool
listeners []func()
lock sync.Mutex
}
)
func newContainer(exclusive bool) *container {
return &container{
exclusive: exclusive,
@@ -141,7 +160,7 @@ func (c *container) addKv(key, value string) ([]string, bool) {
return nil, false
}
func (c *container) addListener(listener func()) {
func (c *container) AddListener(listener func()) {
c.lock.Lock()
c.listeners = append(c.listeners, listener)
c.lock.Unlock()
@@ -170,7 +189,7 @@ func (c *container) doRemoveKey(key string) {
}
}
func (c *container) getValues() []string {
func (c *container) GetValues() []string {
if !c.dirty.True() {
return c.snapshot.Load().([]string)
}

View File

@@ -171,10 +171,10 @@ func TestContainer(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
var changed bool
c := newContainer(exclusive)
c.addListener(func() {
c.AddListener(func() {
changed = true
})
assert.Nil(t, c.getValues())
assert.Nil(t, c.GetValues())
assert.False(t, changed)
for _, order := range test.do {
@@ -193,9 +193,9 @@ func TestContainer(t *testing.T) {
assert.True(t, changed)
assert.True(t, c.dirty.True())
assert.ElementsMatch(t, test.expect, c.getValues())
assert.ElementsMatch(t, test.expect, c.GetValues())
assert.False(t, c.dirty.True())
assert.ElementsMatch(t, test.expect, c.getValues())
assert.ElementsMatch(t, test.expect, c.GetValues())
})
}
}
@@ -204,12 +204,14 @@ func TestContainer(t *testing.T) {
func TestSubscriber(t *testing.T) {
sub := new(Subscriber)
Exclusive()(sub)
sub.items = newContainer(sub.exclusive)
c := newContainer(sub.exclusive)
WithContainer(c)(sub)
sub.items = c
var count int32
sub.AddListener(func() {
atomic.AddInt32(&count, 1)
})
sub.items.notifyChange()
c.notifyChange()
assert.Empty(t, sub.Values())
assert.Equal(t, int32(1), atomic.LoadInt32(&count))
}
@@ -229,12 +231,13 @@ func TestWithSubEtcdAccount(t *testing.T) {
func TestWithExactMatch(t *testing.T) {
sub := new(Subscriber)
WithExactMatch()(sub)
sub.items = newContainer(sub.exclusive)
c := newContainer(sub.exclusive)
sub.items = c
var count int32
sub.AddListener(func() {
atomic.AddInt32(&count, 1)
})
sub.items.notifyChange()
c.notifyChange()
assert.Empty(t, sub.Values())
assert.Equal(t, int32(1), atomic.LoadInt32(&count))
}

View File

@@ -168,7 +168,7 @@ func (s Stream) Count() (count int) {
return
}
// Distinct removes the duplicated items base on the given KeyFunc.
// Distinct removes the duplicated items based on the given KeyFunc.
func (s Stream) Distinct(fn KeyFunc) Stream {
source := make(chan any)
@@ -459,7 +459,7 @@ func (s Stream) Tail(n int64) Stream {
return Range(source)
}
// Walk lets the callers handle each item, the caller may write zero, one or more items base on the given item.
// Walk lets the callers handle each item, the caller may write zero, one or more items based on the given item.
func (s Stream) Walk(fn WalkFunc, opts ...Option) Stream {
option := buildOptions(opts...)
if option.unlimitedWorkers {

View File

@@ -1,8 +1,6 @@
package fx
import (
"io"
"log"
"math/rand"
"reflect"
"runtime"
@@ -13,6 +11,7 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx/logtest"
"github.com/zeromicro/go-zero/core/stringx"
"go.uber.org/goleak"
)
@@ -238,7 +237,7 @@ func TestLast(t *testing.T) {
func TestMap(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
log.SetOutput(io.Discard)
logtest.Discard(t)
tests := []struct {
mapper MapFunc

View File

@@ -96,7 +96,7 @@ func (h *ConsistentHash) AddWithWeight(node any, weight int) {
h.AddWithReplicas(node, replicas)
}
// Get returns the corresponding node from h base on the given v.
// Get returns the corresponding node from h based on the given v.
func (h *ConsistentHash) Get(v any) (any, bool) {
h.lock.RLock()
defer h.lock.RUnlock()

View File

@@ -25,6 +25,29 @@ func TestMd5Hex(t *testing.T) {
assert.Equal(t, md5Digest, actual)
}
func TestHash(t *testing.T) {
result := Hash([]byte(text))
assert.NotEqual(t, uint64(0), result)
}
func TestHash_Deterministic(t *testing.T) {
data := []byte("consistent-hash-test")
first := Hash(data)
second := Hash(data)
assert.Equal(t, first, second)
}
func TestHash_Empty(t *testing.T) {
// Hash should not panic on empty input.
result := Hash([]byte{})
_ = result
}
func TestMd5Hex_Empty(t *testing.T) {
result := Md5Hex([]byte{})
assert.Equal(t, 32, len(result))
}
func BenchmarkHashFnv(b *testing.B) {
for i := 0; i < b.N; i++ {
h := fnv.New32()

View File

@@ -8,9 +8,25 @@ import (
"strings"
)
// Marshal marshals v into json bytes.
// Marshal marshals v into json bytes, without escaping HTML and removes the trailing newline.
func Marshal(v any) ([]byte, error) {
return json.Marshal(v)
// why not use json.Marshal? https://github.com/golang/go/issues/28453
// it changes the behavior of json.Marshal, like & -> \u0026, < -> \u003c, > -> \u003e
// which is not what we want in API responses
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.SetEscapeHTML(false)
if err := enc.Encode(v); err != nil {
return nil, err
}
bs := buf.Bytes()
// Remove trailing newline added by json.Encoder.Encode
if len(bs) > 0 && bs[len(bs)-1] == '\n' {
bs = bs[:len(bs)-1]
}
return bs, nil
}
// MarshalToString marshals v into a string.

View File

@@ -1,6 +1,7 @@
package jsonx
import (
"fmt"
"strings"
"testing"
@@ -101,3 +102,105 @@ func TestUnmarshalFromReaderError(t *testing.T) {
err := UnmarshalFromReader(strings.NewReader(s), &v)
assert.NotNil(t, err)
}
func Test_doMarshalJson(t *testing.T) {
type args struct {
v any
}
tests := []struct {
name string
args args
want []byte
wantErr assert.ErrorAssertionFunc
}{
{
name: "nil",
args: args{nil},
want: []byte("null"),
wantErr: assert.NoError,
},
{
name: "string",
args: args{"hello"},
want: []byte(`"hello"`),
wantErr: assert.NoError,
},
{
name: "int",
args: args{42},
want: []byte("42"),
wantErr: assert.NoError,
},
{
name: "bool",
args: args{true},
want: []byte("true"),
wantErr: assert.NoError,
},
{
name: "struct",
args: args{
struct {
Name string `json:"name"`
}{Name: "test"},
},
want: []byte(`{"name":"test"}`),
wantErr: assert.NoError,
},
{
name: "slice",
args: args{[]int{1, 2, 3}},
want: []byte("[1,2,3]"),
wantErr: assert.NoError,
},
{
name: "map",
args: args{map[string]int{"a": 1, "b": 2}},
want: []byte(`{"a":1,"b":2}`),
wantErr: assert.NoError,
},
{
name: "unmarshalable type",
args: args{complex(1, 2)},
want: nil,
wantErr: assert.Error,
},
{
name: "channel type",
args: args{make(chan int)},
want: nil,
wantErr: assert.Error,
},
{
name: "url with query params",
args: args{"https://example.com/api?name=test&age=25"},
want: []byte(`"https://example.com/api?name=test&age=25"`),
wantErr: assert.NoError,
},
{
name: "url with encoded query params",
args: args{"https://example.com/api?data=hello%20world&special=%26%3D"},
want: []byte(`"https://example.com/api?data=hello%20world&special=%26%3D"`),
wantErr: assert.NoError,
},
{
name: "url with multiple query params",
args: args{"http://localhost:8080/users?page=1&limit=10&sort=name&order=asc"},
want: []byte(`"http://localhost:8080/users?page=1&limit=10&sort=name&order=asc"`),
wantErr: assert.NoError,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
got, err := Marshal(tt.args.v)
if !tt.wantErr(t, err, fmt.Sprintf("Marshal(%v)", tt.args.v)) {
return
}
assert.Equalf(t, string(tt.want), string(got), "Marshal(%v)", tt.args.v)
})
}
}

View File

@@ -1,47 +1,70 @@
package logx
// A LogConf is a logging config.
type LogConf struct {
// ServiceName represents the service name.
ServiceName string `json:",optional"`
// Mode represents the logging mode, default is `console`.
// console: log to console.
// file: log to file.
// volume: used in k8s, prepend the hostname to the log file name.
Mode string `json:",default=console,options=[console,file,volume]"`
// Encoding represents the encoding type, default is `json`.
// json: json encoding.
// plain: plain text encoding, typically used in development.
Encoding string `json:",default=json,options=[json,plain]"`
// TimeFormat represents the time format, default is `2006-01-02T15:04:05.000Z07:00`.
TimeFormat string `json:",optional"`
// Path represents the log file path, default is `logs`.
Path string `json:",default=logs"`
// Level represents the log level, default is `info`.
Level string `json:",default=info,options=[debug,info,error,severe]"`
// MaxContentLength represents the max content bytes, default is no limit.
MaxContentLength uint32 `json:",optional"`
// Compress represents whether to compress the log file, default is `false`.
Compress bool `json:",optional"`
// Stat represents whether to log statistics, default is `true`.
Stat bool `json:",default=true"`
// KeepDays represents how many days the log files will be kept. Default to keep all files.
// Only take effect when Mode is `file` or `volume`, both work when Rotation is `daily` or `size`.
KeepDays int `json:",optional"`
// StackCooldownMillis represents the cooldown time for stack logging, default is 100ms.
StackCooldownMillis int `json:",default=100"`
// MaxBackups represents how many backup log files will be kept. 0 means all files will be kept forever.
// Only take effect when RotationRuleType is `size`.
// Even though `MaxBackups` sets 0, log files will still be removed
// if the `KeepDays` limitation is reached.
MaxBackups int `json:",default=0"`
// MaxSize represents how much space the writing log file takes up. 0 means no limit. The unit is `MB`.
// Only take effect when RotationRuleType is `size`
MaxSize int `json:",default=0"`
// Rotation represents the type of log rotation rule. Default is `daily`.
// daily: daily rotation.
// size: size limited rotation.
Rotation string `json:",default=daily,options=[daily,size]"`
// FileTimeFormat represents the time format for file name, default is `2006-01-02T15:04:05.000Z07:00`.
FileTimeFormat string `json:",optional"`
}
type (
// A LogConf is a logging config.
LogConf struct {
// ServiceName represents the service name.
ServiceName string `json:",optional"`
// Mode represents the logging mode, default is `console`.
// console: log to console.
// file: log to file.
// volume: used in k8s, prepend the hostname to the log file name.
Mode string `json:",default=console,options=[console,file,volume]"`
// Encoding represents the encoding type, default is `json`.
// json: json encoding.
// plain: plain text encoding, typically used in development.
Encoding string `json:",default=json,options=[json,plain]"`
// TimeFormat represents the time format, default is `2006-01-02T15:04:05.000Z07:00`.
TimeFormat string `json:",optional"`
// Path represents the log file path, default is `logs`.
Path string `json:",default=logs"`
// Level represents the log level, default is `info`.
Level string `json:",default=info,options=[debug,info,error,severe]"`
// MaxContentLength represents the max content bytes, default is no limit.
MaxContentLength uint32 `json:",optional"`
// Compress represents whether to compress the log file, default is `false`.
Compress bool `json:",optional"`
// Stat represents whether to log statistics, default is `true`.
Stat bool `json:",default=true"`
// KeepDays represents how many days the log files will be kept. Default to keep all files.
// Only take effect when Mode is `file` or `volume`, both work when Rotation is `daily` or `size`.
KeepDays int `json:",optional"`
// StackCooldownMillis represents the cooldown time for stack logging, default is 100ms.
StackCooldownMillis int `json:",default=100"`
// MaxBackups represents how many backup log files will be kept. 0 means all files will be kept forever.
// Only take effect when RotationRuleType is `size`.
// Even though `MaxBackups` sets 0, log files will still be removed
// if the `KeepDays` limitation is reached.
MaxBackups int `json:",default=0"`
// MaxSize represents how much space the writing log file takes up. 0 means no limit. The unit is `MB`.
// Only take effect when RotationRuleType is `size`
MaxSize int `json:",default=0"`
// Rotation represents the type of log rotation rule. Default is `daily`.
// daily: daily rotation.
// size: size limited rotation.
Rotation string `json:",default=daily,options=[daily,size]"`
// FileTimeFormat represents the time format for file name, default is `2006-01-02T15:04:05.000Z07:00`.
FileTimeFormat string `json:",optional"`
// FieldKeys represents the field keys.
FieldKeys fieldKeyConf `json:",optional"`
}
fieldKeyConf struct {
// CallerKey represents the caller key.
CallerKey string `json:",default=caller"`
// ContentKey represents the content key.
ContentKey string `json:",default=content"`
// DurationKey represents the duration key.
DurationKey string `json:",default=duration"`
// LevelKey represents the level key.
LevelKey string `json:",default=level"`
// SpanKey represents the span key.
SpanKey string `json:",default=span"`
// TimestampKey represents the timestamp key.
TimestampKey string `json:",default=@timestamp"`
// TraceKey represents the trace key.
TraceKey string `json:",default=trace"`
// TruncatedKey represents the truncated key.
TruncatedKey string `json:",default=truncated"`
}
)

View File

@@ -7,12 +7,11 @@ import (
)
var (
fieldsContextKey contextKey
globalFields atomic.Value
globalFieldsLock sync.Mutex
)
type contextKey struct{}
type fieldsKey struct{}
// AddGlobalFields adds global fields.
func AddGlobalFields(fields ...LogField) {
@@ -29,16 +28,16 @@ func AddGlobalFields(fields ...LogField) {
// ContextWithFields returns a new context with the given fields.
func ContextWithFields(ctx context.Context, fields ...LogField) context.Context {
if val := ctx.Value(fieldsContextKey); val != nil {
if val := ctx.Value(fieldsKey{}); val != nil {
if arr, ok := val.([]LogField); ok {
allFields := make([]LogField, 0, len(arr)+len(fields))
allFields = append(allFields, arr...)
allFields = append(allFields, fields...)
return context.WithValue(ctx, fieldsContextKey, allFields)
return context.WithValue(ctx, fieldsKey{}, allFields)
}
}
return context.WithValue(ctx, fieldsContextKey, fields)
return context.WithValue(ctx, fieldsKey{}, fields)
}
// WithFields returns a new logger with the given fields.

View File

@@ -34,7 +34,7 @@ func TestAddGlobalFields(t *testing.T) {
func TestContextWithFields(t *testing.T) {
ctx := ContextWithFields(context.Background(), Field("a", 1), Field("b", 2))
vals := ctx.Value(fieldsContextKey)
vals := ctx.Value(fieldsKey{})
assert.NotNil(t, vals)
fields, ok := vals.([]LogField)
assert.True(t, ok)
@@ -43,7 +43,7 @@ func TestContextWithFields(t *testing.T) {
func TestWithFields(t *testing.T) {
ctx := WithFields(context.Background(), Field("a", 1), Field("b", 2))
vals := ctx.Value(fieldsContextKey)
vals := ctx.Value(fieldsKey{})
assert.NotNil(t, vals)
fields, ok := vals.([]LogField)
assert.True(t, ok)
@@ -55,7 +55,7 @@ func TestWithFieldsAppend(t *testing.T) {
ctx := context.WithValue(context.Background(), dummyKey, "dummy")
ctx = ContextWithFields(ctx, Field("a", 1), Field("b", 2))
ctx = ContextWithFields(ctx, Field("c", 3), Field("d", 4))
vals := ctx.Value(fieldsContextKey)
vals := ctx.Value(fieldsKey{})
assert.NotNil(t, vals)
fields, ok := vals.([]LogField)
assert.True(t, ok)
@@ -80,8 +80,8 @@ func TestWithFieldsAppendCopy(t *testing.T) {
ctxa := ContextWithFields(ctx, af)
ctxb := ContextWithFields(ctx, bf)
assert.EqualValues(t, af, ctxa.Value(fieldsContextKey).([]LogField)[count])
assert.EqualValues(t, bf, ctxb.Value(fieldsContextKey).([]LogField)[count])
assert.EqualValues(t, af, ctxa.Value(fieldsKey{}).([]LogField)[count])
assert.EqualValues(t, bf, ctxb.Value(fieldsKey{}).([]LogField)[count])
}
func BenchmarkAtomicValue(b *testing.B) {

View File

@@ -10,7 +10,6 @@ import (
"runtime/debug"
"sync"
"sync/atomic"
"time"
"github.com/zeromicro/go-zero/core/sysx"
)
@@ -187,39 +186,9 @@ func Errorw(msg string, fields ...LogField) {
// Field returns a LogField for the given key and value.
func Field(key string, value any) LogField {
switch val := value.(type) {
case error:
return LogField{Key: key, Value: encodeError(val)}
case []error:
var errs []string
for _, err := range val {
errs = append(errs, encodeError(err))
}
return LogField{Key: key, Value: errs}
case time.Duration:
return LogField{Key: key, Value: fmt.Sprint(val)}
case []time.Duration:
var durs []string
for _, dur := range val {
durs = append(durs, fmt.Sprint(dur))
}
return LogField{Key: key, Value: durs}
case []time.Time:
var times []string
for _, t := range val {
times = append(times, fmt.Sprint(t))
}
return LogField{Key: key, Value: times}
case fmt.Stringer:
return LogField{Key: key, Value: encodeStringer(val)}
case []fmt.Stringer:
var strs []string
for _, str := range val {
strs = append(strs, encodeStringer(str))
}
return LogField{Key: key, Value: strs}
default:
return LogField{Key: key, Value: val}
return LogField{
Key: key,
Value: value,
}
}
@@ -307,7 +276,8 @@ func SetUp(c LogConf) (err error) {
// Because multiple services in one process might call SetUp respectively.
// Need to wait for the first caller to complete the execution.
setupOnce.Do(func() {
setupLogLevel(c)
setupLogLevel(c.Level)
setupFieldKeys(c.FieldKeys)
if !c.Stat {
DisableStat()
@@ -511,8 +481,35 @@ func handleOptions(opts []LogOption) {
}
}
func setupLogLevel(c LogConf) {
switch c.Level {
func setupFieldKeys(c fieldKeyConf) {
if len(c.CallerKey) > 0 {
callerKey = c.CallerKey
}
if len(c.ContentKey) > 0 {
contentKey = c.ContentKey
}
if len(c.DurationKey) > 0 {
durationKey = c.DurationKey
}
if len(c.LevelKey) > 0 {
levelKey = c.LevelKey
}
if len(c.SpanKey) > 0 {
spanKey = c.SpanKey
}
if len(c.TimestampKey) > 0 {
timestampKey = c.TimestampKey
}
if len(c.TraceKey) > 0 {
traceKey = c.TraceKey
}
if len(c.TruncatedKey) > 0 {
truncatedKey = c.TruncatedKey
}
}
func setupLogLevel(level string) {
switch level {
case levelDebug:
SetLevel(DebugLevel)
case levelInfo:

View File

@@ -1,6 +1,7 @@
package logx
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -16,6 +17,8 @@ import (
"time"
"github.com/stretchr/testify/assert"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/sdk/trace"
)
var (
@@ -244,7 +247,7 @@ func TestStructedLogDebugf(t *testing.T) {
defer writer.Store(old)
doTestStructedLog(t, levelDebug, w, func(v ...any) {
Debugf(fmt.Sprint(v...))
Debugf("%s", fmt.Sprint(v...))
})
}
@@ -556,7 +559,7 @@ func TestStructedLogSlowf(t *testing.T) {
defer writer.Store(old)
doTestStructedLog(t, levelSlow, w, func(v ...any) {
Slowf(fmt.Sprint(v...))
Slowf("%s", fmt.Sprint(v...))
})
}
@@ -622,7 +625,7 @@ func TestStructedLogStatf(t *testing.T) {
defer writer.Store(old)
doTestStructedLog(t, levelStat, w, func(v ...any) {
Statf(fmt.Sprint(v...))
Statf("%s", fmt.Sprint(v...))
})
}
@@ -642,7 +645,7 @@ func TestStructedLogSeveref(t *testing.T) {
defer writer.Store(old)
doTestStructedLog(t, levelSevere, w, func(v ...any) {
Severef(fmt.Sprint(v...))
Severef("%s", fmt.Sprint(v...))
})
}
@@ -776,15 +779,9 @@ func TestSetup(t *testing.T) {
MaxBackups: 3,
MaxSize: 1024 * 1024,
}))
setupLogLevel(LogConf{
Level: levelInfo,
})
setupLogLevel(LogConf{
Level: levelError,
})
setupLogLevel(LogConf{
Level: levelSevere,
})
setupLogLevel(levelInfo)
setupLogLevel(levelError)
setupLogLevel(levelSevere)
_, err := createOutput("")
assert.NotNil(t, err)
Disable()
@@ -856,6 +853,95 @@ func TestWithKeepDays(t *testing.T) {
assert.Equal(t, 1, opt.keepDays)
}
func TestWithField_LogLevel(t *testing.T) {
tests := []struct {
name string
level uint32
fn func(string, ...LogField)
count int32
}{
{
name: "debug/info",
level: DebugLevel,
fn: Infow,
count: 1,
},
{
name: "info/error",
level: InfoLevel,
fn: Errorw,
count: 1,
},
{
name: "info/info",
level: InfoLevel,
fn: Infow,
count: 1,
},
{
name: "info/severe",
level: InfoLevel,
fn: Errorw,
count: 1,
},
{
name: "error/info",
level: ErrorLevel,
fn: Infow,
count: 0,
},
{
name: "error/debug",
level: ErrorLevel,
fn: Debugw,
count: 0,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
olevel := atomic.LoadUint32(&logLevel)
SetLevel(tt.level)
defer SetLevel(olevel)
var val countingStringer
tt.fn("hello there", Field("foo", &val))
assert.Equal(t, tt.count, val.Count())
})
}
}
func TestWithField_LogLevelWithContext(t *testing.T) {
t.Run("context more than once with info/info", func(t *testing.T) {
olevel := atomic.LoadUint32(&logLevel)
SetLevel(InfoLevel)
defer SetLevel(olevel)
var val countingStringer
ctx := ContextWithFields(context.Background(), Field("foo", &val))
logger := WithContext(ctx)
logger.Info("hello there")
logger.Info("hello there")
logger.Info("hello there")
assert.True(t, val.Count() > 0)
})
t.Run("context more than once with error/info", func(t *testing.T) {
olevel := atomic.LoadUint32(&logLevel)
SetLevel(ErrorLevel)
defer SetLevel(olevel)
var val countingStringer
ctx := ContextWithFields(context.Background(), Field("foo", &val))
logger := WithContext(ctx)
logger.Info("hello there")
logger.Info("hello there")
logger.Info("hello there")
assert.Equal(t, int32(0), val.Count())
})
}
func BenchmarkCopyByteSliceAppend(b *testing.B) {
for i := 0; i < b.N; i++ {
var buf []byte
@@ -1054,3 +1140,79 @@ type panicStringer struct {
func (s panicStringer) String() string {
panic("panic")
}
type countingStringer struct {
count int32
}
func (s *countingStringer) Count() int32 {
return atomic.LoadInt32(&s.count)
}
func (s *countingStringer) String() string {
atomic.AddInt32(&s.count, 1)
return "countingStringer"
}
func TestLogKey(t *testing.T) {
setupOnce = sync.Once{}
MustSetup(LogConf{
ServiceName: "any",
Mode: "console",
Encoding: "json",
TimeFormat: timeFormat,
FieldKeys: fieldKeyConf{
CallerKey: "_caller",
ContentKey: "_content",
DurationKey: "_duration",
LevelKey: "_level",
SpanKey: "_span",
TimestampKey: "_timestamp",
TraceKey: "_trace",
TruncatedKey: "_truncated",
},
})
t.Cleanup(func() {
setupFieldKeys(fieldKeyConf{
CallerKey: defaultCallerKey,
ContentKey: defaultContentKey,
DurationKey: defaultDurationKey,
LevelKey: defaultLevelKey,
SpanKey: defaultSpanKey,
TimestampKey: defaultTimestampKey,
TraceKey: defaultTraceKey,
TruncatedKey: defaultTruncatedKey,
})
})
const message = "hello there"
w := new(mockWriter)
old := writer.Swap(w)
defer writer.Store(old)
otp := otel.GetTracerProvider()
tp := trace.NewTracerProvider(trace.WithSampler(trace.AlwaysSample()))
otel.SetTracerProvider(tp)
defer otel.SetTracerProvider(otp)
ctx, span := tp.Tracer("trace-id").Start(context.Background(), "span-id")
defer span.End()
WithContext(ctx).WithDuration(time.Second).Info(message)
now := time.Now()
var m map[string]string
if err := json.Unmarshal([]byte(w.String()), &m); err != nil {
t.Error(err)
}
assert.Equal(t, "info", m["_level"])
assert.Equal(t, message, m["_content"])
assert.Equal(t, "1000.0ms", m["_duration"])
assert.Regexp(t, `logx/logs_test.go:\d+`, m["_caller"])
assert.NotEmpty(t, m["_trace"])
assert.NotEmpty(t, m["_span"])
parsedTime, err := time.Parse(timeFormat, m["_timestamp"])
assert.True(t, err == nil)
assert.Equal(t, now.Minute(), parsedTime.Minute())
}

View File

@@ -224,7 +224,7 @@ func (l *richLogger) buildFields(fields ...LogField) []LogField {
fields = append(fields, Field(spanKey, spanID))
}
val := l.ctx.Value(fieldsContextKey)
val := l.ctx.Value(fieldsKey{})
if val != nil {
if arr, ok := val.([]LogField); ok {
fields = append(fields, arr...)

View File

@@ -423,3 +423,49 @@ type mockValue struct {
Foo string `json:"foo"`
Content any `json:"content"`
}
type testJson struct {
Name string `json:"name"`
Age int `json:"age"`
Score float64 `json:"score"`
}
func (t testJson) MarshalJSON() ([]byte, error) {
type testJsonImpl testJson
return json.Marshal(testJsonImpl(t))
}
func (t testJson) String() string {
return fmt.Sprintf("%s %d %f", t.Name, t.Age, t.Score)
}
func TestLogWithJson(t *testing.T) {
w := new(mockWriter)
old := writer.Swap(w)
writer.lock.RLock()
defer func() {
writer.lock.RUnlock()
writer.Store(old)
}()
l := WithContext(context.Background()).WithFields(Field("bar", testJson{
Name: "foo",
Age: 1,
Score: 1.0,
}))
l.Info(testlog)
type mockValue2 struct {
mockValue
Bar testJson `json:"bar"`
}
var val mockValue2
err := json.Unmarshal([]byte(w.String()), &val)
assert.NoError(t, err)
assert.Equal(t, testlog, val.Content)
assert.Equal(t, "foo", val.Bar.Name)
assert.Equal(t, 1, val.Bar.Age)
assert.Equal(t, 1.0, val.Bar.Score)
}

View File

@@ -66,7 +66,7 @@ type (
gzip bool
}
// SizeLimitRotateRule a rotation rule that make the log file rotated base on size
// SizeLimitRotateRule a rotation rule that makes the log file rotated based on size
SizeLimitRotateRule struct {
DailyRotateRule
maxSize int64
@@ -211,7 +211,7 @@ func (r *SizeLimitRotateRule) OutdatedFiles() []string {
}
}
var result []string
result := make([]string, 0, len(outdated))
for k := range outdated {
result = append(result, k)
}

21
core/logx/sensitive.go Normal file
View File

@@ -0,0 +1,21 @@
package logx
// Sensitive is an interface that defines a method for masking sensitive information in logs.
// It is typically implemented by types that contain sensitive data,
// such as passwords or personal information.
// Infov, Errorv, Debugv, and Slowv methods will call this method to mask sensitive data.
// The values in LogField will also be masked if they implement the Sensitive interface.
type Sensitive interface {
// MaskSensitive masks sensitive information in the log.
MaskSensitive() any
}
// maskSensitive returns the value returned by MaskSensitive method,
// if the value implements Sensitive interface.
func maskSensitive(v any) any {
if s, ok := v.(Sensitive); ok {
return s.MaskSensitive()
}
return v
}

View File

@@ -0,0 +1,50 @@
package logx
import (
"testing"
"github.com/stretchr/testify/assert"
)
const maskedContent = "******"
type User struct {
Name string
Pass string
}
func (u User) MaskSensitive() any {
return User{
Name: u.Name,
Pass: maskedContent,
}
}
type NonSensitiveUser struct {
Name string
Pass string
}
func TestMaskSensitive(t *testing.T) {
t.Run("sensitive", func(t *testing.T) {
user := User{
Name: "kevin",
Pass: "123",
}
mu := maskSensitive(user)
assert.Equal(t, user.Name, mu.(User).Name)
assert.Equal(t, maskedContent, mu.(User).Pass)
})
t.Run("non-sensitive", func(t *testing.T) {
user := NonSensitiveUser{
Name: "kevin",
Pass: "123",
}
mu := maskSensitive(user)
assert.Equal(t, user.Name, mu.(NonSensitiveUser).Name)
assert.Equal(t, user.Pass, mu.(NonSensitiveUser).Pass)
})
}

View File

@@ -53,14 +53,14 @@ const (
)
const (
callerKey = "caller"
contentKey = "content"
durationKey = "duration"
levelKey = "level"
spanKey = "span"
timestampKey = "@timestamp"
traceKey = "trace"
truncatedKey = "truncated"
defaultCallerKey = "caller"
defaultContentKey = "content"
defaultDurationKey = "duration"
defaultLevelKey = "level"
defaultSpanKey = "span"
defaultTimestampKey = "@timestamp"
defaultTraceKey = "trace"
defaultTruncatedKey = "truncated"
)
var (
@@ -73,3 +73,14 @@ var (
truncatedField = Field(truncatedKey, true)
)
var (
callerKey = defaultCallerKey
contentKey = defaultContentKey
durationKey = defaultDurationKey
levelKey = defaultLevelKey
spanKey = defaultSpanKey
timestampKey = defaultTimestampKey
traceKey = defaultTraceKey
truncatedKey = defaultTruncatedKey
)

View File

@@ -10,6 +10,7 @@ import (
"runtime/debug"
"sync"
"sync/atomic"
"time"
fatihcolor "github.com/fatih/color"
"github.com/zeromicro/go-zero/core/color"
@@ -211,7 +212,6 @@ func newFileWriter(c LogConf) (Writer, error) {
statFile := path.Join(c.Path, statFilename)
handleOptions(opts)
setupLogLevel(c)
if infoLog, err = createOutput(accessFile); err != nil {
return nil, err
@@ -365,19 +365,25 @@ func mergeGlobalFields(fields []LogField) []LogField {
}
func output(writer io.Writer, level string, val any, fields ...LogField) {
// only truncate string content, don't know how to truncate the values of other types.
if v, ok := val.(string); ok {
switch v := val.(type) {
case string:
// only truncate string content, don't know how to truncate the values of other types.
maxLen := atomic.LoadUint32(&maxContentLength)
if maxLen > 0 && len(v) > int(maxLen) {
val = v[:maxLen]
fields = append(fields, truncatedField)
}
case Sensitive:
val = v.MaskSensitive()
}
// +3 for timestamp, level and content
entry := make(logEntry, len(fields)+3)
for _, field := range fields {
entry[field.Key] = field.Value
// mask sensitive data before processing types,
// in case field.Value is a sensitive type and also implemented fmt.Stringer.
mval := maskSensitive(field.Value)
entry[field.Key] = processFieldValue(mval)
}
switch atomic.LoadUint32(&encoding) {
@@ -392,6 +398,45 @@ func output(writer io.Writer, level string, val any, fields ...LogField) {
}
}
func processFieldValue(value any) any {
switch val := value.(type) {
case error:
return encodeError(val)
case []error:
var errs []string
for _, err := range val {
errs = append(errs, encodeError(err))
}
return errs
case time.Duration:
return fmt.Sprint(val)
case []time.Duration:
var durs []string
for _, dur := range val {
durs = append(durs, fmt.Sprint(dur))
}
return durs
case []time.Time:
var times []string
for _, t := range val {
times = append(times, fmt.Sprint(t))
}
return times
case json.Marshaler:
return val
case fmt.Stringer:
return encodeStringer(val)
case []fmt.Stringer:
var strs []string
for _, str := range val {
strs = append(strs, encodeStringer(str))
}
return strs
default:
return val
}
}
func wrapLevelWithColor(level string) string {
var colour color.Color
switch level {
@@ -399,6 +444,8 @@ func wrapLevelWithColor(level string) string {
colour = color.FgRed
case levelError:
colour = color.FgRed
case levelSevere:
colour = color.FgRed
case levelFatal:
colour = color.FgRed
case levelInfo:

View File

@@ -225,6 +225,48 @@ func TestWritePlainDuplicate(t *testing.T) {
assert.Contains(t, buf.String(), "second=c")
}
func TestLogWithSensitive(t *testing.T) {
old := atomic.SwapUint32(&encoding, plainEncodingType)
t.Cleanup(func() {
atomic.StoreUint32(&encoding, old)
})
t.Run("sensitive", func(t *testing.T) {
var buf bytes.Buffer
output(&buf, levelInfo, User{
Name: "kevin",
Pass: "123",
}, LogField{
Key: "first",
Value: "a",
}, LogField{
Key: "first",
Value: "b",
})
assert.Contains(t, buf.String(), maskedContent)
assert.NotContains(t, buf.String(), "first=a")
assert.Contains(t, buf.String(), "first=b")
})
t.Run("sensitive fields", func(t *testing.T) {
var buf bytes.Buffer
output(&buf, levelInfo, "foo", LogField{
Key: "first",
Value: User{
Name: "kevin",
Pass: "123",
},
}, LogField{
Key: "second",
Value: "b",
})
assert.Contains(t, buf.String(), "foo")
assert.Contains(t, buf.String(), "first")
assert.Contains(t, buf.String(), maskedContent)
assert.Contains(t, buf.String(), "second=b")
})
}
func TestLogWithLimitContentLength(t *testing.T) {
maxLen := atomic.LoadUint32(&maxContentLength)
atomic.StoreUint32(&maxContentLength, 10)

View File

@@ -3,6 +3,7 @@ package mapping
import (
"fmt"
"reflect"
"slices"
"strings"
)
@@ -152,15 +153,8 @@ func validateOptional(field reflect.StructField, value reflect.Value) error {
}
func validateOptions(value reflect.Value, opt *fieldOptions) error {
var found bool
val := fmt.Sprint(value.Interface())
for i := range opt.Options {
if opt.Options[i] == val {
found = true
break
}
}
if !found {
if !slices.Contains(opt.Options, val) {
return fmt.Errorf("field %q not in options", val)
}

View File

@@ -622,9 +622,19 @@ func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value re
return u.fillSliceFromString(fieldType, value, mapValue, fullName)
case valueKind == reflect.String && derefedFieldType == durationType:
return fillDurationValue(fieldType, value, mapValue.(string))
v, err := convertToString(mapValue, fullName)
if err != nil {
return err
}
return fillDurationValue(fieldType, value, v)
case valueKind == reflect.String && typeKind == reflect.Struct && u.implementsUnmarshaler(fieldType):
return u.fillUnmarshalerStruct(fieldType, value, mapValue.(string))
v, err := convertToString(mapValue, fullName)
if err != nil {
return err
}
return u.fillUnmarshalerStruct(fieldType, value, v)
default:
return u.processFieldPrimitive(fieldType, value, mapValue, opts, fullName)
}
@@ -755,24 +765,26 @@ func (u *Unmarshaler) processFieldWithEnvValue(fieldType reflect.Type, value ref
return err
}
fieldKind := fieldType.Kind()
switch fieldKind {
case reflect.Bool:
derefType := Deref(fieldType)
derefKind := derefType.Kind()
switch {
case derefKind == reflect.String:
SetValue(fieldType, value, toReflectValue(derefType, envVal))
return nil
case derefKind == reflect.Bool:
val, err := strconv.ParseBool(envVal)
if err != nil {
return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
}
value.SetBool(val)
SetValue(fieldType, value, toReflectValue(derefType, val))
return nil
case durationType.Kind():
case derefType == durationType:
// time.Duration is a special case, its derefKind is reflect.Int64.
if err := fillDurationValue(fieldType, value, envVal); err != nil {
return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
}
return nil
case reflect.String:
value.SetString(envVal)
return nil
default:
return u.processFieldPrimitiveWithJSONNumber(fieldType, value, json.Number(envVal), opts, fullName)

View File

@@ -203,6 +203,20 @@ func TestUnmarshalDuration(t *testing.T) {
}
}
func TestUnmarshalDurationUnexpectedError(t *testing.T) {
type inner struct {
Duration time.Duration `key:"duration"`
}
content := "{\"duration\": 1}"
var m = map[string]any{}
err := jsonx.Unmarshal([]byte(content), &m)
assert.NoError(t, err)
var in inner
err = UnmarshalKey(m, &in)
assert.Error(t, err)
assert.Contains(t, err.Error(), "expect string")
}
func TestUnmarshalDurationDefault(t *testing.T) {
type inner struct {
Int int `key:"int"`
@@ -4665,6 +4679,23 @@ func TestUnmarshal_EnvInt(t *testing.T) {
}
}
func TestUnmarshal_EnvInt64(t *testing.T) {
type Value struct {
Age int64 `key:"age,env=TEST_NAME_INT64"`
}
const (
envName = "TEST_NAME_INT64"
envVal = "88"
)
t.Setenv(envName, envVal)
var v Value
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
assert.Equal(t, int64(88), v.Age)
}
}
func TestUnmarshal_EnvIntOverwrite(t *testing.T) {
type Value struct {
Age int `key:"age,env=TEST_NAME_INT"`
@@ -4770,20 +4801,33 @@ func TestUnmarshal_EnvBoolBad(t *testing.T) {
}
func TestUnmarshal_EnvDuration(t *testing.T) {
type Value struct {
Duration time.Duration `key:"duration,env=TEST_NAME_DURATION"`
}
const (
envName = "TEST_NAME_DURATION"
envVal = "1s"
)
t.Setenv(envName, envVal)
var v Value
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
assert.Equal(t, time.Second, v.Duration)
}
t.Run("valid duration", func(t *testing.T) {
type Value struct {
Duration time.Duration `key:"duration,env=TEST_NAME_DURATION"`
}
var v Value
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
assert.Equal(t, time.Second, v.Duration)
}
})
t.Run("ptr of duration", func(t *testing.T) {
type Value struct {
Duration *time.Duration `key:"duration,env=TEST_NAME_DURATION"`
}
var v Value
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
assert.Equal(t, time.Second, *v.Duration)
}
})
}
func TestUnmarshal_EnvDurationBadValue(t *testing.T) {
@@ -5995,6 +6039,16 @@ func TestUnmarshal_Unmarshaler(t *testing.T) {
}, &v))
assert.Nil(t, v.Foo)
})
t.Run("json.Number", func(t *testing.T) {
v := struct {
Foo *mockUnmarshaler `json:"name"`
}{}
m := map[string]any{
"name": json.Number("123"),
}
assert.Error(t, UnmarshalJsonMap(m, &v))
})
}
func TestParseJsonStringValue(t *testing.T) {
@@ -6029,6 +6083,105 @@ func TestParseJsonStringValue(t *testing.T) {
})
}
// issue #5033, string type
func TestUnmarshalFromEnvString(t *testing.T) {
t.Setenv("STRING_ENV", "dev")
t.Run("by value", func(t *testing.T) {
type (
Env string
Config struct {
Env Env `json:",env=STRING_ENV,default=prod"`
}
)
var c Config
if assert.NoError(t, UnmarshalJsonMap(map[string]any{}, &c)) {
assert.Equal(t, Env("dev"), c.Env)
}
})
t.Run("by ptr", func(t *testing.T) {
type (
Env string
Config struct {
Env *Env `json:",env=STRING_ENV,default=prod"`
}
)
var c Config
if assert.NoError(t, UnmarshalJsonMap(map[string]any{}, &c)) {
assert.Equal(t, Env("dev"), *c.Env)
}
})
}
// issue #5033, bool type
func TestUnmarshalFromEnvBool(t *testing.T) {
t.Setenv("BOOL_ENV", "true")
t.Run("by value", func(t *testing.T) {
type (
Env bool
Config struct {
Env Env `json:",env=BOOL_ENV,default=false"`
}
)
var c Config
if assert.NoError(t, UnmarshalJsonMap(map[string]any{}, &c)) {
assert.Equal(t, Env(true), c.Env)
}
})
t.Run("by ptr", func(t *testing.T) {
type (
Env bool
Config struct {
Env *Env `json:",env=BOOL_ENV,default=false"`
}
)
var c Config
if assert.NoError(t, UnmarshalJsonMap(map[string]any{}, &c)) {
assert.Equal(t, Env(true), *c.Env)
}
})
}
// issue #5033, customized int type
func TestUnmarshalFromEnvInt(t *testing.T) {
t.Setenv("INT_ENV", "2")
t.Run("by value", func(t *testing.T) {
type (
Env int
Config struct {
Env Env `json:",env=INT_ENV,default=0"`
}
)
var c Config
if assert.NoError(t, UnmarshalJsonMap(map[string]any{}, &c)) {
assert.Equal(t, Env(2), c.Env)
}
})
t.Run("by ptr", func(t *testing.T) {
type (
Env int
Config struct {
Env *Env `json:",env=INT_ENV,default=0"`
}
)
var c Config
if assert.NoError(t, UnmarshalJsonMap(map[string]any{}, &c)) {
assert.Equal(t, Env(2), *c.Env)
}
})
}
func BenchmarkDefaultValue(b *testing.B) {
for i := 0; i < b.N; i++ {
var a struct {

View File

@@ -1,6 +1,7 @@
package mapping
import (
"cmp"
"encoding/json"
"errors"
"fmt"
@@ -12,7 +13,6 @@ import (
"sync"
"github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/stringx"
)
const (
@@ -92,17 +92,25 @@ func ValidatePtr(v reflect.Value) error {
return nil
}
func convertToString(val any, fullName string) (string, error) {
v, ok := val.(string)
if !ok {
return "", fmt.Errorf("expect string for field %s, but got type %T", fullName, val)
}
return v, nil
}
func convertTypeFromString(kind reflect.Kind, str string) (any, error) {
switch kind {
case reflect.Bool:
switch strings.ToLower(str) {
case "1", "true":
if str == "1" || strings.EqualFold(str, "true") {
return true, nil
case "0", "false":
return false, nil
default:
return false, errTypeMismatch
}
if str == "0" || strings.EqualFold(str, "false") {
return false, nil
}
return false, errTypeMismatch
case reflect.Int:
return strconv.ParseInt(str, 10, intSize)
case reflect.Int8:
@@ -270,7 +278,7 @@ func parseKeyAndOptions(tagName string, field reflect.StructField) (string, *fie
cache, ok := optionsCache[value]
cacheLock.RUnlock()
if ok {
return stringx.TakeOne(cache.key, field.Name), cache.options, cache.err
return cmp.Or(cache.key, field.Name), cache.options, cache.err
}
key, options, err := doParseKeyAndOptions(field, value)
@@ -282,7 +290,7 @@ func parseKeyAndOptions(tagName string, field reflect.StructField) (string, *fie
}
cacheLock.Unlock()
return stringx.TakeOne(key, field.Name), options, err
return cmp.Or(key, field.Name), options, err
}
// support below notations:
@@ -574,6 +582,10 @@ func toFloat64(v any) (float64, bool) {
}
}
func toReflectValue(tp reflect.Type, v any) reflect.Value {
return reflect.ValueOf(v).Convert(Deref(tp))
}
func usingDifferentKeys(key string, field reflect.StructField) bool {
if len(field.Tag) > 0 {
if _, ok := field.Tag.Lookup(key); !ok {

View File

@@ -334,3 +334,43 @@ func TestValidateValueRange(t *testing.T) {
func TestSetMatchedPrimitiveValue(t *testing.T) {
assert.Error(t, setMatchedPrimitiveValue(reflect.Func, reflect.ValueOf(2), "1"))
}
func TestConvertTypeFromString_Bool(t *testing.T) {
tests := []struct {
name string
input string
want bool
wantErr bool
}{
// true cases
{name: "1", input: "1", want: true, wantErr: false},
{name: "true lowercase", input: "true", want: true, wantErr: false},
{name: "True mixed", input: "True", want: true, wantErr: false},
{name: "TRUE uppercase", input: "TRUE", want: true, wantErr: false},
{name: "TrUe mixed", input: "TrUe", want: true, wantErr: false},
// false cases
{name: "0", input: "0", want: false, wantErr: false},
{name: "false lowercase", input: "false", want: false, wantErr: false},
{name: "False mixed", input: "False", want: false, wantErr: false},
{name: "FALSE uppercase", input: "FALSE", want: false, wantErr: false},
{name: "FaLsE mixed", input: "FaLsE", want: false, wantErr: false},
// error cases
{name: "invalid yes", input: "yes", want: false, wantErr: true},
{name: "invalid no", input: "no", want: false, wantErr: true},
{name: "invalid empty", input: "", want: false, wantErr: true},
{name: "invalid 2", input: "2", want: false, wantErr: true},
{name: "invalid truee", input: "truee", want: false, wantErr: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := convertTypeFromString(reflect.Bool, tt.input)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
}
})
}
}

View File

@@ -29,3 +29,10 @@ func TestCalcDiffEntropy(t *testing.T) {
}
assert.True(t, CalcEntropy(m) < .99)
}
func TestCalcEntropySingleItem(t *testing.T) {
m := map[any]int{
"only": 42,
}
assert.Equal(t, float64(1), CalcEntropy(m))
}

View File

@@ -1,5 +1,6 @@
package mathx
// Numerical is a constraint that permits any numeric type.
type Numerical interface {
~int | ~int8 | ~int16 | ~int32 | ~int64 |
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |

View File

@@ -6,7 +6,7 @@ import (
"time"
)
// An Unstable is used to generate random value around the mean value base on given deviation.
// An Unstable is used to generate random value around the mean value based on given deviation.
type Unstable struct {
deviation float64
r *rand.Rand

View File

@@ -3,6 +3,9 @@ package mr
import (
"context"
"errors"
"fmt"
"runtime/debug"
"strings"
"sync"
"sync/atomic"
@@ -183,12 +186,16 @@ func buildOptions(opts ...Option) *mapReduceOptions {
return options
}
func buildPanicInfo(r any, stack []byte) string {
return fmt.Sprintf("%+v\n\n%s", r, strings.TrimSpace(string(stack)))
}
func buildSource[T any](generate GenerateFunc[T], panicChan *onceChan) chan T {
source := make(chan T)
go func() {
defer func() {
if r := recover(); r != nil {
panicChan.write(r)
panicChan.write(buildPanicInfo(r, debug.Stack()))
}
close(source)
}()
@@ -235,7 +242,7 @@ func executeMappers[T, U any](mCtx mapperContext[T, U]) {
defer func() {
if r := recover(); r != nil {
atomic.AddInt32(&failed, 1)
mCtx.panicChan.write(r)
mCtx.panicChan.write(buildPanicInfo(r, debug.Stack()))
}
wg.Done()
<-pool
@@ -289,7 +296,7 @@ func mapReduceWithPanicChan[T, U, V any](source <-chan T, panicChan *onceChan, m
defer func() {
drain(collector)
if r := recover(); r != nil {
panicChan.write(r)
panicChan.write(buildPanicInfo(r, debug.Stack()))
}
finish()
}()

View File

@@ -3,8 +3,7 @@ package mr
import (
"context"
"errors"
"io"
"log"
"fmt"
"runtime"
"sync/atomic"
"testing"
@@ -16,9 +15,6 @@ import (
var errDummy = errors.New("dummy")
func init() {
log.SetOutput(io.Discard)
}
func TestFinish(t *testing.T) {
defer goleak.VerifyNone(t)
@@ -39,6 +35,36 @@ func TestFinish(t *testing.T) {
assert.Nil(t, err)
}
func TestFinishWithPartialErrors(t *testing.T) {
defer goleak.VerifyNone(t)
errDummy := errors.New("dummy")
t.Run("one error", func(t *testing.T) {
err := Finish(func() error {
return errDummy
}, func() error {
return nil
}, func() error {
return nil
})
assert.Equal(t, errDummy, err)
})
t.Run("two errors", func(t *testing.T) {
err := Finish(func() error {
return errDummy
}, func() error {
return errDummy
}, func() error {
return nil
})
assert.Equal(t, errDummy, err)
})
}
func TestFinishNone(t *testing.T) {
defer goleak.VerifyNone(t)
@@ -118,11 +144,28 @@ func TestForEach(t *testing.T) {
assert.Equal(t, tasks/2, int(count))
})
}
t.Run("all", func(t *testing.T) {
defer goleak.VerifyNone(t)
func TestPanics(t *testing.T) {
defer goleak.VerifyNone(t)
const tasks = 1000
verify := func(t *testing.T, r any) {
panicStr := fmt.Sprintf("%v", r)
assert.Contains(t, panicStr, "foo")
assert.Contains(t, panicStr, "goroutine")
assert.Contains(t, panicStr, "runtime/debug.Stack")
panic(r)
}
t.Run("ForEach run panics", func(t *testing.T) {
assert.Panics(t, func() {
defer func() {
if r := recover(); r != nil {
verify(t, r)
}
}()
assert.PanicsWithValue(t, "foo", func() {
ForEach(func(source chan<- int) {
for i := 0; i < tasks; i++ {
source <- i
@@ -132,28 +175,31 @@ func TestForEach(t *testing.T) {
})
})
})
}
func TestGeneratePanic(t *testing.T) {
defer goleak.VerifyNone(t)
t.Run("ForEach generate panics", func(t *testing.T) {
assert.Panics(t, func() {
defer func() {
if r := recover(); r != nil {
verify(t, r)
}
}()
t.Run("all", func(t *testing.T) {
assert.PanicsWithValue(t, "foo", func() {
ForEach(func(source chan<- int) {
panic("foo")
}, func(item int) {
})
})
})
}
func TestMapperPanic(t *testing.T) {
defer goleak.VerifyNone(t)
const tasks = 1000
var run int32
t.Run("all", func(t *testing.T) {
assert.PanicsWithValue(t, "foo", func() {
t.Run("Mapper panics", func(t *testing.T) {
assert.Panics(t, func() {
defer func() {
if r := recover(); r != nil {
verify(t, r)
}
}()
_, _ = MapReduce(func(source chan<- int) {
for i := 0; i < tasks; i++ {
source <- i

View File

@@ -5,6 +5,8 @@ import (
"io"
"os"
"runtime"
"runtime/debug"
"runtime/metrics"
"time"
)
@@ -28,10 +30,29 @@ func displayStatsWithWriter(writer io.Writer, interval ...time.Duration) {
ticker := time.NewTicker(duration)
defer ticker.Stop()
for range ticker.C {
var m runtime.MemStats
runtime.ReadMemStats(&m)
var (
alloc, totalAlloc, sys uint64
samples = []metrics.Sample{
{Name: "/memory/classes/heap/objects:bytes"},
{Name: "/gc/heap/allocs:bytes"},
{Name: "/memory/classes/total:bytes"},
}
)
metrics.Read(samples)
if samples[0].Value.Kind() == metrics.KindUint64 {
alloc = samples[0].Value.Uint64()
}
if samples[1].Value.Kind() == metrics.KindUint64 {
totalAlloc = samples[1].Value.Uint64()
}
if samples[2].Value.Kind() == metrics.KindUint64 {
sys = samples[2].Value.Uint64()
}
var stats debug.GCStats
debug.ReadGCStats(&stats)
fmt.Fprintf(writer, "Goroutines: %d, Alloc: %vm, TotalAlloc: %vm, Sys: %vm, NumGC: %v\n",
runtime.NumGoroutine(), m.Alloc/mega, m.TotalAlloc/mega, m.Sys/mega, m.NumGC)
runtime.NumGoroutine(), alloc/mega, totalAlloc/mega, sys/mega, stats.NumGC)
}
}()
}

View File

@@ -1,7 +1,8 @@
package stat
import (
"runtime"
"runtime/debug"
"runtime/metrics"
"sync/atomic"
"time"
@@ -56,8 +57,28 @@ func bToMb(b uint64) float32 {
}
func printUsage() {
var m runtime.MemStats
runtime.ReadMemStats(&m)
var (
alloc, totalAlloc, sys uint64
samples = []metrics.Sample{
{Name: "/memory/classes/heap/objects:bytes"},
{Name: "/gc/heap/allocs:bytes"},
{Name: "/memory/classes/total:bytes"},
}
stats debug.GCStats
)
metrics.Read(samples)
if samples[0].Value.Kind() == metrics.KindUint64 {
alloc = samples[0].Value.Uint64()
}
if samples[1].Value.Kind() == metrics.KindUint64 {
totalAlloc = samples[1].Value.Uint64()
}
if samples[2].Value.Kind() == metrics.KindUint64 {
sys = samples[2].Value.Uint64()
}
debug.ReadGCStats(&stats)
logx.Statf("CPU: %dm, MEMORY: Alloc=%.1fMi, TotalAlloc=%.1fMi, Sys=%.1fMi, NumGC=%d",
CpuUsage(), bToMb(m.Alloc), bToMb(m.TotalAlloc), bToMb(m.Sys), m.NumGC)
CpuUsage(), bToMb(alloc), bToMb(totalAlloc), bToMb(sys), stats.NumGC)
}

View File

@@ -1,3 +1,4 @@
//go:generate mockgen -package mon -destination collectioninserter_mock.go -source bulkinserter.go collectionInserter
package mon
import (
@@ -6,7 +7,8 @@ import (
"github.com/zeromicro/go-zero/core/executors"
"github.com/zeromicro/go-zero/core/logx"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.mongodb.org/mongo-driver/v2/mongo/options"
)
const (
@@ -27,10 +29,7 @@ type (
// NewBulkInserter returns a BulkInserter.
func NewBulkInserter(coll Collection, interval ...time.Duration) (*BulkInserter, error) {
cloneColl, err := coll.Clone()
if err != nil {
return nil, err
}
cloneColl := coll.Clone()
inserter := &dbInserter{
collection: cloneColl,
@@ -64,8 +63,16 @@ func (bi *BulkInserter) SetResultHandler(handler ResultHandler) {
})
}
type collectionInserter interface {
InsertMany(
ctx context.Context,
documents interface{},
opts ...options.Lister[options.InsertManyOptions],
) (*mongo.InsertManyResult, error)
}
type dbInserter struct {
collection *mongo.Collection
collection collectionInserter
documents []any
resultHandler ResultHandler
}

View File

@@ -1,26 +1,131 @@
package mon
import (
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.uber.org/mock/gomock"
)
func TestBulkInserter(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "ok", Value: 1}}...))
bulk, err := NewBulkInserter(createModel(mt).Collection)
assert.Equal(t, err, nil)
bulk.SetResultHandler(func(result *mongo.InsertManyResult, err error) {
assert.Nil(t, err)
assert.Equal(t, 2, len(result.InsertedIDs))
})
bulk.Insert(bson.D{{Key: "foo", Value: "bar"}})
bulk.Insert(bson.D{{Key: "foo", Value: "baz"}})
bulk.Flush()
func TestBulkInserter_InsertAndFlush(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockCollection(ctrl)
mockCollection.EXPECT().Clone().Return(&mongo.Collection{})
bulkInserter, err := NewBulkInserter(mockCollection, time.Second)
assert.NoError(t, err)
bulkInserter.SetResultHandler(func(result *mongo.InsertManyResult, err error) {
assert.Nil(t, err)
assert.Equal(t, 2, len(result.InsertedIDs))
})
doc := map[string]interface{}{"name": "test"}
bulkInserter.Insert(doc)
bulkInserter.Flush()
}
func TestBulkInserter_SetResultHandler(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockCollection(ctrl)
mockCollection.EXPECT().Clone().Return(nil)
bulkInserter, err := NewBulkInserter(mockCollection)
assert.NoError(t, err)
mockHandler := func(result *mongo.InsertManyResult, err error) {}
bulkInserter.SetResultHandler(mockHandler)
}
func TestDbInserter_RemoveAll(t *testing.T) {
inserter := &dbInserter{}
inserter.documents = []interface{}{}
docs := inserter.RemoveAll()
assert.NotNil(t, docs)
assert.Empty(t, inserter.documents)
}
func Test_dbInserter_Execute(t *testing.T) {
type fields struct {
collection collectionInserter
documents []any
resultHandler ResultHandler
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockcollectionInserter(ctrl)
type args struct {
objs any
}
tests := []struct {
name string
fields fields
args args
mock func()
}{
{
name: "empty doc",
fields: fields{
collection: nil,
documents: nil,
resultHandler: nil,
},
args: args{
objs: make([]any, 0),
},
mock: func() {},
},
{
name: "result handler",
fields: fields{
collection: mockCollection,
resultHandler: func(result *mongo.InsertManyResult, err error) {
assert.NotNil(t, err)
},
},
args: args{
objs: make([]any, 1),
},
mock: func() {
mockCollection.EXPECT().InsertMany(gomock.Any(), gomock.Any()).Return(&mongo.InsertManyResult{}, errors.New("error"))
},
},
{
name: "normal error handler",
fields: fields{
collection: mockCollection,
resultHandler: nil,
},
args: args{
objs: make([]any, 1),
},
mock: func() {
mockCollection.EXPECT().InsertMany(gomock.Any(), gomock.Any()).Return(&mongo.InsertManyResult{}, errors.New("error"))
},
},
{
name: "no error",
fields: fields{
collection: mockCollection,
resultHandler: nil,
},
args: args{
objs: make([]any, 1),
},
mock: func() {
mockCollection.EXPECT().InsertMany(gomock.Any(), gomock.Any()).Return(&mongo.InsertManyResult{}, nil)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.mock()
in := &dbInserter{
collection: tt.fields.collection,
documents: tt.fields.documents,
resultHandler: tt.fields.resultHandler,
}
in.Execute(tt.args.objs)
})
}
}

View File

@@ -5,8 +5,8 @@ import (
"io"
"github.com/zeromicro/go-zero/core/syncx"
"go.mongodb.org/mongo-driver/mongo"
mopt "go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.mongodb.org/mongo-driver/v2/mongo/options"
)
var clientManager = syncx.NewResourceManager()
@@ -29,13 +29,13 @@ func Inject(key string, client *mongo.Client) {
func getClient(url string, opts ...Option) (*mongo.Client, error) {
val, err := clientManager.GetResource(url, func() (io.Closer, error) {
o := mopt.Client().ApplyURI(url)
o := options.Client().ApplyURI(url)
opts = append([]Option{defaultTimeoutOption()}, opts...)
for _, opt := range opts {
opt(o)
}
cli, err := mongo.Connect(context.Background(), o)
cli, err := mongo.Connect(o)
if err != nil {
return nil, err
}

View File

@@ -4,19 +4,13 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
"go.mongodb.org/mongo-driver/v2/mongo"
)
func init() {
_ = mtest.Setup()
}
func TestClientManger_getClient(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
Inject(mtest.ClusterURI(), mt.Client)
cli, err := getClient(mtest.ClusterURI())
assert.Nil(t, err)
assert.Equal(t, mt.Client, cli)
})
c := &mongo.Client{}
Inject("foo", c)
cli, err := getClient("foo")
assert.Nil(t, err)
assert.Equal(t, c, cli)
}

View File

@@ -1,3 +1,4 @@
//go:generate mockgen -package mon -destination collection_mock.go -source collection.go Collection,monCollection
package mon
import (
@@ -8,9 +9,9 @@ import (
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/errorx"
"github.com/zeromicro/go-zero/core/timex"
"go.mongodb.org/mongo-driver/mongo"
mopt "go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
)
const (
@@ -47,79 +48,79 @@ type (
// Collection defines a MongoDB collection.
Collection interface {
// Aggregate executes an aggregation pipeline.
Aggregate(ctx context.Context, pipeline any, opts ...*mopt.AggregateOptions) (
Aggregate(ctx context.Context, pipeline any, opts ...options.Lister[options.AggregateOptions]) (
*mongo.Cursor, error)
// BulkWrite performs a bulk write operation.
BulkWrite(ctx context.Context, models []mongo.WriteModel, opts ...*mopt.BulkWriteOptions) (
BulkWrite(ctx context.Context, models []mongo.WriteModel, opts ...options.Lister[options.BulkWriteOptions]) (
*mongo.BulkWriteResult, error)
// Clone creates a copy of this collection with the same settings.
Clone(opts ...*mopt.CollectionOptions) (*mongo.Collection, error)
Clone(opts ...options.Lister[options.CollectionOptions]) *mongo.Collection
// CountDocuments returns the number of documents in the collection that match the filter.
CountDocuments(ctx context.Context, filter any, opts ...*mopt.CountOptions) (int64, error)
CountDocuments(ctx context.Context, filter any, opts ...options.Lister[options.CountOptions]) (int64, error)
// Database returns the database that this collection is a part of.
Database() *mongo.Database
// DeleteMany deletes documents from the collection that match the filter.
DeleteMany(ctx context.Context, filter any, opts ...*mopt.DeleteOptions) (
DeleteMany(ctx context.Context, filter any, opts ...options.Lister[options.DeleteManyOptions]) (
*mongo.DeleteResult, error)
// DeleteOne deletes at most one document from the collection that matches the filter.
DeleteOne(ctx context.Context, filter any, opts ...*mopt.DeleteOptions) (
DeleteOne(ctx context.Context, filter any, opts ...options.Lister[options.DeleteOneOptions]) (
*mongo.DeleteResult, error)
// Distinct returns a list of distinct values for the given key across the collection.
Distinct(ctx context.Context, fieldName string, filter any,
opts ...*mopt.DistinctOptions) ([]any, error)
opts ...options.Lister[options.DistinctOptions]) (*mongo.DistinctResult, error)
// Drop drops this collection from database.
Drop(ctx context.Context) error
Drop(ctx context.Context, opts ...options.Lister[options.DropCollectionOptions]) error
// EstimatedDocumentCount returns an estimate of the count of documents in a collection
// using collection metadata.
EstimatedDocumentCount(ctx context.Context, opts ...*mopt.EstimatedDocumentCountOptions) (int64, error)
EstimatedDocumentCount(ctx context.Context, opts ...options.Lister[options.EstimatedDocumentCountOptions]) (int64, error)
// Find finds the documents matching the provided filter.
Find(ctx context.Context, filter any, opts ...*mopt.FindOptions) (*mongo.Cursor, error)
Find(ctx context.Context, filter any, opts ...options.Lister[options.FindOptions]) (*mongo.Cursor, error)
// FindOne returns up to one document that matches the provided filter.
FindOne(ctx context.Context, filter any, opts ...*mopt.FindOneOptions) (
FindOne(ctx context.Context, filter any, opts ...options.Lister[options.FindOneOptions]) (
*mongo.SingleResult, error)
// FindOneAndDelete returns at most one document that matches the filter. If the filter
// matches multiple documents, only the first document is deleted.
FindOneAndDelete(ctx context.Context, filter any, opts ...*mopt.FindOneAndDeleteOptions) (
FindOneAndDelete(ctx context.Context, filter any, opts ...options.Lister[options.FindOneAndDeleteOptions]) (
*mongo.SingleResult, error)
// FindOneAndReplace returns at most one document that matches the filter. If the filter
// matches multiple documents, FindOneAndReplace returns the first document in the
// collection that matches the filter.
FindOneAndReplace(ctx context.Context, filter, replacement any,
opts ...*mopt.FindOneAndReplaceOptions) (*mongo.SingleResult, error)
opts ...options.Lister[options.FindOneAndReplaceOptions]) (*mongo.SingleResult, error)
// FindOneAndUpdate returns at most one document that matches the filter. If the filter
// matches multiple documents, FindOneAndUpdate returns the first document in the
// collection that matches the filter.
FindOneAndUpdate(ctx context.Context, filter, update any,
opts ...*mopt.FindOneAndUpdateOptions) (*mongo.SingleResult, error)
opts ...options.Lister[options.FindOneAndUpdateOptions]) (*mongo.SingleResult, error)
// Indexes returns the index view for this collection.
Indexes() mongo.IndexView
// InsertMany inserts the provided documents.
InsertMany(ctx context.Context, documents []any, opts ...*mopt.InsertManyOptions) (
InsertMany(ctx context.Context, documents []any, opts ...options.Lister[options.InsertManyOptions]) (
*mongo.InsertManyResult, error)
// InsertOne inserts the provided document.
InsertOne(ctx context.Context, document any, opts ...*mopt.InsertOneOptions) (
InsertOne(ctx context.Context, document any, opts ...options.Lister[options.InsertOneOptions]) (
*mongo.InsertOneResult, error)
// ReplaceOne replaces at most one document that matches the filter.
ReplaceOne(ctx context.Context, filter, replacement any,
opts ...*mopt.ReplaceOptions) (*mongo.UpdateResult, error)
opts ...options.Lister[options.ReplaceOptions]) (*mongo.UpdateResult, error)
// UpdateByID updates a single document matching the provided filter.
UpdateByID(ctx context.Context, id, update any,
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error)
opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error)
// UpdateMany updates the provided documents.
UpdateMany(ctx context.Context, filter, update any,
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error)
opts ...options.Lister[options.UpdateManyOptions]) (*mongo.UpdateResult, error)
// UpdateOne updates a single document matching the provided filter.
UpdateOne(ctx context.Context, filter, update any,
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error)
opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error)
// Watch returns a change stream cursor used to receive notifications of changes to the collection.
Watch(ctx context.Context, pipeline any, opts ...*mopt.ChangeStreamOptions) (
Watch(ctx context.Context, pipeline any, opts ...options.Lister[options.ChangeStreamOptions]) (
*mongo.ChangeStream, error)
}
decoratedCollection struct {
*mongo.Collection
name string
brk breaker.Breaker
Collection monCollection
name string
brk breaker.Breaker
}
keepablePromise struct {
@@ -137,7 +138,7 @@ func newCollection(collection *mongo.Collection, brk breaker.Breaker) Collection
}
func (c *decoratedCollection) Aggregate(ctx context.Context, pipeline any,
opts ...*mopt.AggregateOptions) (cur *mongo.Cursor, err error) {
opts ...options.Lister[options.AggregateOptions]) (cur *mongo.Cursor, err error) {
ctx, span := startSpan(ctx, aggregate)
defer func() {
endSpan(span, err)
@@ -157,7 +158,7 @@ func (c *decoratedCollection) Aggregate(ctx context.Context, pipeline any,
}
func (c *decoratedCollection) BulkWrite(ctx context.Context, models []mongo.WriteModel,
opts ...*mopt.BulkWriteOptions) (res *mongo.BulkWriteResult, err error) {
opts ...options.Lister[options.BulkWriteOptions]) (res *mongo.BulkWriteResult, err error) {
ctx, span := startSpan(ctx, bulkWrite)
defer func() {
endSpan(span, err)
@@ -176,8 +177,12 @@ func (c *decoratedCollection) BulkWrite(ctx context.Context, models []mongo.Writ
return
}
func (c *decoratedCollection) Clone(opts ...options.Lister[options.CollectionOptions]) *mongo.Collection {
return c.Collection.Clone(opts...)
}
func (c *decoratedCollection) CountDocuments(ctx context.Context, filter any,
opts ...*mopt.CountOptions) (count int64, err error) {
opts ...options.Lister[options.CountOptions]) (count int64, err error) {
ctx, span := startSpan(ctx, countDocuments)
defer func() {
endSpan(span, err)
@@ -196,8 +201,12 @@ func (c *decoratedCollection) CountDocuments(ctx context.Context, filter any,
return
}
func (c *decoratedCollection) Database() *mongo.Database {
return c.Collection.Database()
}
func (c *decoratedCollection) DeleteMany(ctx context.Context, filter any,
opts ...*mopt.DeleteOptions) (res *mongo.DeleteResult, err error) {
opts ...options.Lister[options.DeleteManyOptions]) (res *mongo.DeleteResult, err error) {
ctx, span := startSpan(ctx, deleteMany)
defer func() {
endSpan(span, err)
@@ -217,7 +226,7 @@ func (c *decoratedCollection) DeleteMany(ctx context.Context, filter any,
}
func (c *decoratedCollection) DeleteOne(ctx context.Context, filter any,
opts ...*mopt.DeleteOptions) (res *mongo.DeleteResult, err error) {
opts ...options.Lister[options.DeleteOneOptions]) (res *mongo.DeleteResult, err error) {
ctx, span := startSpan(ctx, deleteOne)
defer func() {
endSpan(span, err)
@@ -237,7 +246,7 @@ func (c *decoratedCollection) DeleteOne(ctx context.Context, filter any,
}
func (c *decoratedCollection) Distinct(ctx context.Context, fieldName string, filter any,
opts ...*mopt.DistinctOptions) (val []any, err error) {
opts ...options.Lister[options.DistinctOptions]) (res *mongo.DistinctResult, err error) {
ctx, span := startSpan(ctx, distinct)
defer func() {
endSpan(span, err)
@@ -249,15 +258,20 @@ func (c *decoratedCollection) Distinct(ctx context.Context, fieldName string, fi
c.logDurationSimple(ctx, distinct, startTime, err)
}()
val, err = c.Collection.Distinct(ctx, fieldName, filter, opts...)
res = c.Collection.Distinct(ctx, fieldName, filter, opts...)
err = res.Err()
return err
}, acceptable)
return
}
func (c *decoratedCollection) Drop(ctx context.Context, opts ...options.Lister[options.DropCollectionOptions]) error {
return c.Collection.Drop(ctx, opts...)
}
func (c *decoratedCollection) EstimatedDocumentCount(ctx context.Context,
opts ...*mopt.EstimatedDocumentCountOptions) (val int64, err error) {
opts ...options.Lister[options.EstimatedDocumentCountOptions]) (val int64, err error) {
ctx, span := startSpan(ctx, estimatedDocumentCount)
defer func() {
endSpan(span, err)
@@ -277,7 +291,7 @@ func (c *decoratedCollection) EstimatedDocumentCount(ctx context.Context,
}
func (c *decoratedCollection) Find(ctx context.Context, filter any,
opts ...*mopt.FindOptions) (cur *mongo.Cursor, err error) {
opts ...options.Lister[options.FindOptions]) (cur *mongo.Cursor, err error) {
ctx, span := startSpan(ctx, find)
defer func() {
endSpan(span, err)
@@ -297,7 +311,7 @@ func (c *decoratedCollection) Find(ctx context.Context, filter any,
}
func (c *decoratedCollection) FindOne(ctx context.Context, filter any,
opts ...*mopt.FindOneOptions) (res *mongo.SingleResult, err error) {
opts ...options.Lister[options.FindOneOptions]) (res *mongo.SingleResult, err error) {
ctx, span := startSpan(ctx, findOne)
defer func() {
endSpan(span, err)
@@ -318,7 +332,7 @@ func (c *decoratedCollection) FindOne(ctx context.Context, filter any,
}
func (c *decoratedCollection) FindOneAndDelete(ctx context.Context, filter any,
opts ...*mopt.FindOneAndDeleteOptions) (res *mongo.SingleResult, err error) {
opts ...options.Lister[options.FindOneAndDeleteOptions]) (res *mongo.SingleResult, err error) {
ctx, span := startSpan(ctx, findOneAndDelete)
defer func() {
endSpan(span, err)
@@ -339,7 +353,7 @@ func (c *decoratedCollection) FindOneAndDelete(ctx context.Context, filter any,
}
func (c *decoratedCollection) FindOneAndReplace(ctx context.Context, filter any,
replacement any, opts ...*mopt.FindOneAndReplaceOptions) (
replacement any, opts ...options.Lister[options.FindOneAndReplaceOptions]) (
res *mongo.SingleResult, err error) {
ctx, span := startSpan(ctx, findOneAndReplace)
defer func() {
@@ -361,7 +375,7 @@ func (c *decoratedCollection) FindOneAndReplace(ctx context.Context, filter any,
}
func (c *decoratedCollection) FindOneAndUpdate(ctx context.Context, filter, update any,
opts ...*mopt.FindOneAndUpdateOptions) (res *mongo.SingleResult, err error) {
opts ...options.Lister[options.FindOneAndUpdateOptions]) (res *mongo.SingleResult, err error) {
ctx, span := startSpan(ctx, findOneAndUpdate)
defer func() {
endSpan(span, err)
@@ -381,8 +395,12 @@ func (c *decoratedCollection) FindOneAndUpdate(ctx context.Context, filter, upda
return
}
func (c *decoratedCollection) Indexes() mongo.IndexView {
return c.Collection.Indexes()
}
func (c *decoratedCollection) InsertMany(ctx context.Context, documents []any,
opts ...*mopt.InsertManyOptions) (res *mongo.InsertManyResult, err error) {
opts ...options.Lister[options.InsertManyOptions]) (res *mongo.InsertManyResult, err error) {
ctx, span := startSpan(ctx, insertMany)
defer func() {
endSpan(span, err)
@@ -402,7 +420,7 @@ func (c *decoratedCollection) InsertMany(ctx context.Context, documents []any,
}
func (c *decoratedCollection) InsertOne(ctx context.Context, document any,
opts ...*mopt.InsertOneOptions) (res *mongo.InsertOneResult, err error) {
opts ...options.Lister[options.InsertOneOptions]) (res *mongo.InsertOneResult, err error) {
ctx, span := startSpan(ctx, insertOne)
defer func() {
endSpan(span, err)
@@ -422,7 +440,7 @@ func (c *decoratedCollection) InsertOne(ctx context.Context, document any,
}
func (c *decoratedCollection) ReplaceOne(ctx context.Context, filter, replacement any,
opts ...*mopt.ReplaceOptions) (res *mongo.UpdateResult, err error) {
opts ...options.Lister[options.ReplaceOptions]) (res *mongo.UpdateResult, err error) {
ctx, span := startSpan(ctx, replaceOne)
defer func() {
endSpan(span, err)
@@ -442,7 +460,7 @@ func (c *decoratedCollection) ReplaceOne(ctx context.Context, filter, replacemen
}
func (c *decoratedCollection) UpdateByID(ctx context.Context, id, update any,
opts ...*mopt.UpdateOptions) (res *mongo.UpdateResult, err error) {
opts ...options.Lister[options.UpdateOneOptions]) (res *mongo.UpdateResult, err error) {
ctx, span := startSpan(ctx, updateByID)
defer func() {
endSpan(span, err)
@@ -462,7 +480,7 @@ func (c *decoratedCollection) UpdateByID(ctx context.Context, id, update any,
}
func (c *decoratedCollection) UpdateMany(ctx context.Context, filter, update any,
opts ...*mopt.UpdateOptions) (res *mongo.UpdateResult, err error) {
opts ...options.Lister[options.UpdateManyOptions]) (res *mongo.UpdateResult, err error) {
ctx, span := startSpan(ctx, updateMany)
defer func() {
endSpan(span, err)
@@ -482,7 +500,7 @@ func (c *decoratedCollection) UpdateMany(ctx context.Context, filter, update any
}
func (c *decoratedCollection) UpdateOne(ctx context.Context, filter, update any,
opts ...*mopt.UpdateOptions) (res *mongo.UpdateResult, err error) {
opts ...options.Lister[options.UpdateOneOptions]) (res *mongo.UpdateResult, err error) {
ctx, span := startSpan(ctx, updateOne)
defer func() {
endSpan(span, err)
@@ -501,6 +519,11 @@ func (c *decoratedCollection) UpdateOne(ctx context.Context, filter, update any,
return
}
func (c *decoratedCollection) Watch(ctx context.Context, pipeline any, opts ...options.Lister[options.ChangeStreamOptions]) (
*mongo.ChangeStream, error) {
return c.Collection.Watch(ctx, pipeline, opts...)
}
func (c *decoratedCollection) logDuration(ctx context.Context, method string,
startTime time.Duration, err error, docs ...any) {
logDurationWithDocs(ctx, c.name, method, startTime, err, docs...)
@@ -546,3 +569,71 @@ func isDupKeyError(err error) bool {
return e.HasErrorCode(duplicateKeyCode)
}
// monCollection defines a MongoDB collection, used for unit test
type monCollection interface {
// Aggregate executes an aggregation pipeline.
Aggregate(ctx context.Context, pipeline any, opts ...options.Lister[options.AggregateOptions]) (
*mongo.Cursor, error)
// BulkWrite performs a bulk write operation.
BulkWrite(ctx context.Context, models []mongo.WriteModel, opts ...options.Lister[options.BulkWriteOptions]) (
*mongo.BulkWriteResult, error)
// Clone creates a copy of this collection with the same settings.
Clone(opts ...options.Lister[options.CollectionOptions]) *mongo.Collection
// CountDocuments returns the number of documents in the collection that match the filter.
CountDocuments(ctx context.Context, filter any, opts ...options.Lister[options.CountOptions]) (int64, error)
// Database returns the database that this collection is a part of.
Database() *mongo.Database
// DeleteMany deletes documents from the collection that match the filter.
DeleteMany(ctx context.Context, filter any, opts ...options.Lister[options.DeleteManyOptions]) (
*mongo.DeleteResult, error)
// DeleteOne deletes at most one document from the collection that matches the filter.
DeleteOne(ctx context.Context, filter any, opts ...options.Lister[options.DeleteOneOptions]) (
*mongo.DeleteResult, error)
// Distinct returns a list of distinct values for the given key across the collection.
Distinct(ctx context.Context, fieldName string, filter any,
opts ...options.Lister[options.DistinctOptions]) *mongo.DistinctResult
// Drop drops this collection from database.
Drop(ctx context.Context, opts ...options.Lister[options.DropCollectionOptions]) error
// EstimatedDocumentCount returns an estimate of the count of documents in a collection
// using collection metadata.
EstimatedDocumentCount(ctx context.Context, opts ...options.Lister[options.EstimatedDocumentCountOptions]) (int64, error)
// Find finds the documents matching the provided filter.
Find(ctx context.Context, filter any, opts ...options.Lister[options.FindOptions]) (*mongo.Cursor, error)
// FindOne returns up to one document that matches the provided filter.
FindOne(ctx context.Context, filter any, opts ...options.Lister[options.FindOneOptions]) *mongo.SingleResult
// FindOneAndDelete returns at most one document that matches the filter. If the filter
// matches multiple documents, only the first document is deleted.
FindOneAndDelete(ctx context.Context, filter any, opts ...options.Lister[options.FindOneAndDeleteOptions]) *mongo.SingleResult
// FindOneAndReplace returns at most one document that matches the filter. If the filter
// matches multiple documents, FindOneAndReplace returns the first document in the
// collection that matches the filter.
FindOneAndReplace(ctx context.Context, filter, replacement any,
opts ...options.Lister[options.FindOneAndReplaceOptions]) *mongo.SingleResult
// FindOneAndUpdate returns at most one document that matches the filter. If the filter
// matches multiple documents, FindOneAndUpdate returns the first document in the
// collection that matches the filter.
FindOneAndUpdate(ctx context.Context, filter, update any,
opts ...options.Lister[options.FindOneAndUpdateOptions]) *mongo.SingleResult
// Indexes returns the index view for this collection.
Indexes() mongo.IndexView
// InsertMany inserts the provided documents.
InsertMany(ctx context.Context, documents interface{}, opts ...options.Lister[options.InsertManyOptions]) (*mongo.InsertManyResult, error)
// InsertOne inserts the provided document.
InsertOne(ctx context.Context, document any, opts ...options.Lister[options.InsertOneOptions]) (*mongo.InsertOneResult, error)
// ReplaceOne replaces at most one document that matches the filter.
ReplaceOne(ctx context.Context, filter, replacement any,
opts ...options.Lister[options.ReplaceOptions]) (*mongo.UpdateResult, error)
// UpdateByID updates a single document matching the provided filter.
UpdateByID(ctx context.Context, id, update any,
opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error)
// UpdateMany updates the provided documents.
UpdateMany(ctx context.Context, filter, update any,
opts ...options.Lister[options.UpdateManyOptions]) (*mongo.UpdateResult, error)
// UpdateOne updates a single document matching the provided filter.
UpdateOne(ctx context.Context, filter, update any,
opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error)
// Watch returns a change stream cursor used to receive notifications of changes to the collection.
Watch(ctx context.Context, pipeline any, opts ...options.Lister[options.ChangeStreamOptions]) (
*mongo.ChangeStream, error)
}

View File

@@ -0,0 +1,952 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: collection.go
//
// Generated by this command:
//
// mockgen -package mon -destination collection_mock.go -source collection.go Collection,monCollection
//
// Package mon is a generated GoMock package.
package mon
import (
context "context"
reflect "reflect"
mongo "go.mongodb.org/mongo-driver/v2/mongo"
options "go.mongodb.org/mongo-driver/v2/mongo/options"
gomock "go.uber.org/mock/gomock"
)
// MockCollection is a mock of Collection interface.
type MockCollection struct {
ctrl *gomock.Controller
recorder *MockCollectionMockRecorder
isgomock struct{}
}
// MockCollectionMockRecorder is the mock recorder for MockCollection.
type MockCollectionMockRecorder struct {
mock *MockCollection
}
// NewMockCollection creates a new mock instance.
func NewMockCollection(ctrl *gomock.Controller) *MockCollection {
mock := &MockCollection{ctrl: ctrl}
mock.recorder = &MockCollectionMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockCollection) EXPECT() *MockCollectionMockRecorder {
return m.recorder
}
// Aggregate mocks base method.
func (m *MockCollection) Aggregate(ctx context.Context, pipeline any, opts ...options.Lister[options.AggregateOptions]) (*mongo.Cursor, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, pipeline}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Aggregate", varargs...)
ret0, _ := ret[0].(*mongo.Cursor)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Aggregate indicates an expected call of Aggregate.
func (mr *MockCollectionMockRecorder) Aggregate(ctx, pipeline any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, pipeline}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Aggregate", reflect.TypeOf((*MockCollection)(nil).Aggregate), varargs...)
}
// BulkWrite mocks base method.
func (m *MockCollection) BulkWrite(ctx context.Context, models []mongo.WriteModel, opts ...options.Lister[options.BulkWriteOptions]) (*mongo.BulkWriteResult, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, models}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "BulkWrite", varargs...)
ret0, _ := ret[0].(*mongo.BulkWriteResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// BulkWrite indicates an expected call of BulkWrite.
func (mr *MockCollectionMockRecorder) BulkWrite(ctx, models any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, models}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BulkWrite", reflect.TypeOf((*MockCollection)(nil).BulkWrite), varargs...)
}
// Clone mocks base method.
func (m *MockCollection) Clone(opts ...options.Lister[options.CollectionOptions]) *mongo.Collection {
m.ctrl.T.Helper()
varargs := []any{}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Clone", varargs...)
ret0, _ := ret[0].(*mongo.Collection)
return ret0
}
// Clone indicates an expected call of Clone.
func (mr *MockCollectionMockRecorder) Clone(opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clone", reflect.TypeOf((*MockCollection)(nil).Clone), opts...)
}
// CountDocuments mocks base method.
func (m *MockCollection) CountDocuments(ctx context.Context, filter any, opts ...options.Lister[options.CountOptions]) (int64, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, filter}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "CountDocuments", varargs...)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CountDocuments indicates an expected call of CountDocuments.
func (mr *MockCollectionMockRecorder) CountDocuments(ctx, filter any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, filter}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountDocuments", reflect.TypeOf((*MockCollection)(nil).CountDocuments), varargs...)
}
// Database mocks base method.
func (m *MockCollection) Database() *mongo.Database {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Database")
ret0, _ := ret[0].(*mongo.Database)
return ret0
}
// Database indicates an expected call of Database.
func (mr *MockCollectionMockRecorder) Database() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Database", reflect.TypeOf((*MockCollection)(nil).Database))
}
// DeleteMany mocks base method.
func (m *MockCollection) DeleteMany(ctx context.Context, filter any, opts ...options.Lister[options.DeleteManyOptions]) (*mongo.DeleteResult, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, filter}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "DeleteMany", varargs...)
ret0, _ := ret[0].(*mongo.DeleteResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DeleteMany indicates an expected call of DeleteMany.
func (mr *MockCollectionMockRecorder) DeleteMany(ctx, filter any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, filter}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMany", reflect.TypeOf((*MockCollection)(nil).DeleteMany), varargs...)
}
// DeleteOne mocks base method.
func (m *MockCollection) DeleteOne(ctx context.Context, filter any, opts ...options.Lister[options.DeleteOneOptions]) (*mongo.DeleteResult, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, filter}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "DeleteOne", varargs...)
ret0, _ := ret[0].(*mongo.DeleteResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DeleteOne indicates an expected call of DeleteOne.
func (mr *MockCollectionMockRecorder) DeleteOne(ctx, filter any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, filter}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOne", reflect.TypeOf((*MockCollection)(nil).DeleteOne), varargs...)
}
// Distinct mocks base method.
func (m *MockCollection) Distinct(ctx context.Context, fieldName string, filter any, opts ...options.Lister[options.DistinctOptions]) (*mongo.DistinctResult, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, fieldName, filter}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Distinct", varargs...)
ret0, _ := ret[0].(*mongo.DistinctResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Distinct indicates an expected call of Distinct.
func (mr *MockCollectionMockRecorder) Distinct(ctx, fieldName, filter any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, fieldName, filter}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Distinct", reflect.TypeOf((*MockCollection)(nil).Distinct), varargs...)
}
// Drop mocks base method.
func (m *MockCollection) Drop(ctx context.Context, opts ...options.Lister[options.DropCollectionOptions]) error {
m.ctrl.T.Helper()
varargs := []any{ctx}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Drop", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// Drop indicates an expected call of Drop.
func (mr *MockCollectionMockRecorder) Drop(ctx any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Drop", reflect.TypeOf((*MockCollection)(nil).Drop), varargs...)
}
// EstimatedDocumentCount mocks base method.
func (m *MockCollection) EstimatedDocumentCount(ctx context.Context, opts ...options.Lister[options.EstimatedDocumentCountOptions]) (int64, error) {
m.ctrl.T.Helper()
varargs := []any{ctx}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "EstimatedDocumentCount", varargs...)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// EstimatedDocumentCount indicates an expected call of EstimatedDocumentCount.
func (mr *MockCollectionMockRecorder) EstimatedDocumentCount(ctx any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EstimatedDocumentCount", reflect.TypeOf((*MockCollection)(nil).EstimatedDocumentCount), varargs...)
}
// Find mocks base method.
func (m *MockCollection) Find(ctx context.Context, filter any, opts ...options.Lister[options.FindOptions]) (*mongo.Cursor, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, filter}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Find", varargs...)
ret0, _ := ret[0].(*mongo.Cursor)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Find indicates an expected call of Find.
func (mr *MockCollectionMockRecorder) Find(ctx, filter any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, filter}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Find", reflect.TypeOf((*MockCollection)(nil).Find), varargs...)
}
// FindOne mocks base method.
func (m *MockCollection) FindOne(ctx context.Context, filter any, opts ...options.Lister[options.FindOneOptions]) (*mongo.SingleResult, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, filter}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "FindOne", varargs...)
ret0, _ := ret[0].(*mongo.SingleResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// FindOne indicates an expected call of FindOne.
func (mr *MockCollectionMockRecorder) FindOne(ctx, filter any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, filter}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindOne", reflect.TypeOf((*MockCollection)(nil).FindOne), varargs...)
}
// FindOneAndDelete mocks base method.
func (m *MockCollection) FindOneAndDelete(ctx context.Context, filter any, opts ...options.Lister[options.FindOneAndDeleteOptions]) (*mongo.SingleResult, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, filter}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "FindOneAndDelete", varargs...)
ret0, _ := ret[0].(*mongo.SingleResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// FindOneAndDelete indicates an expected call of FindOneAndDelete.
func (mr *MockCollectionMockRecorder) FindOneAndDelete(ctx, filter any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, filter}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindOneAndDelete", reflect.TypeOf((*MockCollection)(nil).FindOneAndDelete), varargs...)
}
// FindOneAndReplace mocks base method.
func (m *MockCollection) FindOneAndReplace(ctx context.Context, filter, replacement any, opts ...options.Lister[options.FindOneAndReplaceOptions]) (*mongo.SingleResult, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, filter, replacement}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "FindOneAndReplace", varargs...)
ret0, _ := ret[0].(*mongo.SingleResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// FindOneAndReplace indicates an expected call of FindOneAndReplace.
func (mr *MockCollectionMockRecorder) FindOneAndReplace(ctx, filter, replacement any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, filter, replacement}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindOneAndReplace", reflect.TypeOf((*MockCollection)(nil).FindOneAndReplace), varargs...)
}
// FindOneAndUpdate mocks base method.
func (m *MockCollection) FindOneAndUpdate(ctx context.Context, filter, update any, opts ...options.Lister[options.FindOneAndUpdateOptions]) (*mongo.SingleResult, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, filter, update}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "FindOneAndUpdate", varargs...)
ret0, _ := ret[0].(*mongo.SingleResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// FindOneAndUpdate indicates an expected call of FindOneAndUpdate.
func (mr *MockCollectionMockRecorder) FindOneAndUpdate(ctx, filter, update any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, filter, update}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindOneAndUpdate", reflect.TypeOf((*MockCollection)(nil).FindOneAndUpdate), varargs...)
}
// Indexes mocks base method.
func (m *MockCollection) Indexes() mongo.IndexView {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Indexes")
ret0, _ := ret[0].(mongo.IndexView)
return ret0
}
// Indexes indicates an expected call of Indexes.
func (mr *MockCollectionMockRecorder) Indexes() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Indexes", reflect.TypeOf((*MockCollection)(nil).Indexes))
}
// InsertMany mocks base method.
func (m *MockCollection) InsertMany(ctx context.Context, documents []any, opts ...options.Lister[options.InsertManyOptions]) (*mongo.InsertManyResult, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, documents}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "InsertMany", varargs...)
ret0, _ := ret[0].(*mongo.InsertManyResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertMany indicates an expected call of InsertMany.
func (mr *MockCollectionMockRecorder) InsertMany(ctx, documents any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, documents}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertMany", reflect.TypeOf((*MockCollection)(nil).InsertMany), varargs...)
}
// InsertOne mocks base method.
func (m *MockCollection) InsertOne(ctx context.Context, document any, opts ...options.Lister[options.InsertOneOptions]) (*mongo.InsertOneResult, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, document}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "InsertOne", varargs...)
ret0, _ := ret[0].(*mongo.InsertOneResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertOne indicates an expected call of InsertOne.
func (mr *MockCollectionMockRecorder) InsertOne(ctx, document any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, document}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertOne", reflect.TypeOf((*MockCollection)(nil).InsertOne), varargs...)
}
// ReplaceOne mocks base method.
func (m *MockCollection) ReplaceOne(ctx context.Context, filter, replacement any, opts ...options.Lister[options.ReplaceOptions]) (*mongo.UpdateResult, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, filter, replacement}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "ReplaceOne", varargs...)
ret0, _ := ret[0].(*mongo.UpdateResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ReplaceOne indicates an expected call of ReplaceOne.
func (mr *MockCollectionMockRecorder) ReplaceOne(ctx, filter, replacement any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, filter, replacement}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceOne", reflect.TypeOf((*MockCollection)(nil).ReplaceOne), varargs...)
}
// UpdateByID mocks base method.
func (m *MockCollection) UpdateByID(ctx context.Context, id, update any, opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, id, update}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "UpdateByID", varargs...)
ret0, _ := ret[0].(*mongo.UpdateResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateByID indicates an expected call of UpdateByID.
func (mr *MockCollectionMockRecorder) UpdateByID(ctx, id, update any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, id, update}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateByID", reflect.TypeOf((*MockCollection)(nil).UpdateByID), varargs...)
}
// UpdateMany mocks base method.
func (m *MockCollection) UpdateMany(ctx context.Context, filter, update any, opts ...options.Lister[options.UpdateManyOptions]) (*mongo.UpdateResult, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, filter, update}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "UpdateMany", varargs...)
ret0, _ := ret[0].(*mongo.UpdateResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateMany indicates an expected call of UpdateMany.
func (mr *MockCollectionMockRecorder) UpdateMany(ctx, filter, update any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, filter, update}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateMany", reflect.TypeOf((*MockCollection)(nil).UpdateMany), varargs...)
}
// UpdateOne mocks base method.
func (m *MockCollection) UpdateOne(ctx context.Context, filter, update any, opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, filter, update}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "UpdateOne", varargs...)
ret0, _ := ret[0].(*mongo.UpdateResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateOne indicates an expected call of UpdateOne.
func (mr *MockCollectionMockRecorder) UpdateOne(ctx, filter, update any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, filter, update}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateOne", reflect.TypeOf((*MockCollection)(nil).UpdateOne), varargs...)
}
// Watch mocks base method.
func (m *MockCollection) Watch(ctx context.Context, pipeline any, opts ...options.Lister[options.ChangeStreamOptions]) (*mongo.ChangeStream, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, pipeline}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Watch", varargs...)
ret0, _ := ret[0].(*mongo.ChangeStream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Watch indicates an expected call of Watch.
func (mr *MockCollectionMockRecorder) Watch(ctx, pipeline any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, pipeline}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockCollection)(nil).Watch), varargs...)
}
// MockmonCollection is a mock of monCollection interface.
type MockmonCollection struct {
ctrl *gomock.Controller
recorder *MockmonCollectionMockRecorder
isgomock struct{}
}
// MockmonCollectionMockRecorder is the mock recorder for MockmonCollection.
type MockmonCollectionMockRecorder struct {
mock *MockmonCollection
}
// NewMockmonCollection creates a new mock instance.
func NewMockmonCollection(ctrl *gomock.Controller) *MockmonCollection {
mock := &MockmonCollection{ctrl: ctrl}
mock.recorder = &MockmonCollectionMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockmonCollection) EXPECT() *MockmonCollectionMockRecorder {
return m.recorder
}
// Aggregate mocks base method.
func (m *MockmonCollection) Aggregate(ctx context.Context, pipeline any, opts ...options.Lister[options.AggregateOptions]) (*mongo.Cursor, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, pipeline}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Aggregate", varargs...)
ret0, _ := ret[0].(*mongo.Cursor)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Aggregate indicates an expected call of Aggregate.
func (mr *MockmonCollectionMockRecorder) Aggregate(ctx, pipeline any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, pipeline}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Aggregate", reflect.TypeOf((*MockmonCollection)(nil).Aggregate), varargs...)
}
// BulkWrite mocks base method.
func (m *MockmonCollection) BulkWrite(ctx context.Context, models []mongo.WriteModel, opts ...options.Lister[options.BulkWriteOptions]) (*mongo.BulkWriteResult, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, models}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "BulkWrite", varargs...)
ret0, _ := ret[0].(*mongo.BulkWriteResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// BulkWrite indicates an expected call of BulkWrite.
func (mr *MockmonCollectionMockRecorder) BulkWrite(ctx, models any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, models}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BulkWrite", reflect.TypeOf((*MockmonCollection)(nil).BulkWrite), varargs...)
}
// Clone mocks base method.
func (m *MockmonCollection) Clone(opts ...options.Lister[options.CollectionOptions]) *mongo.Collection {
m.ctrl.T.Helper()
varargs := []any{}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Clone", varargs...)
ret0, _ := ret[0].(*mongo.Collection)
return ret0
}
// Clone indicates an expected call of Clone.
func (mr *MockmonCollectionMockRecorder) Clone(opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clone", reflect.TypeOf((*MockmonCollection)(nil).Clone), opts...)
}
// CountDocuments mocks base method.
func (m *MockmonCollection) CountDocuments(ctx context.Context, filter any, opts ...options.Lister[options.CountOptions]) (int64, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, filter}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "CountDocuments", varargs...)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CountDocuments indicates an expected call of CountDocuments.
func (mr *MockmonCollectionMockRecorder) CountDocuments(ctx, filter any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, filter}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountDocuments", reflect.TypeOf((*MockmonCollection)(nil).CountDocuments), varargs...)
}
// Database mocks base method.
func (m *MockmonCollection) Database() *mongo.Database {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Database")
ret0, _ := ret[0].(*mongo.Database)
return ret0
}
// Database indicates an expected call of Database.
func (mr *MockmonCollectionMockRecorder) Database() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Database", reflect.TypeOf((*MockmonCollection)(nil).Database))
}
// DeleteMany mocks base method.
func (m *MockmonCollection) DeleteMany(ctx context.Context, filter any, opts ...options.Lister[options.DeleteManyOptions]) (*mongo.DeleteResult, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, filter}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "DeleteMany", varargs...)
ret0, _ := ret[0].(*mongo.DeleteResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DeleteMany indicates an expected call of DeleteMany.
func (mr *MockmonCollectionMockRecorder) DeleteMany(ctx, filter any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, filter}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMany", reflect.TypeOf((*MockmonCollection)(nil).DeleteMany), varargs...)
}
// DeleteOne mocks base method.
func (m *MockmonCollection) DeleteOne(ctx context.Context, filter any, opts ...options.Lister[options.DeleteOneOptions]) (*mongo.DeleteResult, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, filter}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "DeleteOne", varargs...)
ret0, _ := ret[0].(*mongo.DeleteResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DeleteOne indicates an expected call of DeleteOne.
func (mr *MockmonCollectionMockRecorder) DeleteOne(ctx, filter any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, filter}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOne", reflect.TypeOf((*MockmonCollection)(nil).DeleteOne), varargs...)
}
// Distinct mocks base method.
func (m *MockmonCollection) Distinct(ctx context.Context, fieldName string, filter any, opts ...options.Lister[options.DistinctOptions]) *mongo.DistinctResult {
m.ctrl.T.Helper()
varargs := []any{ctx, fieldName, filter}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Distinct", varargs...)
ret0, _ := ret[0].(*mongo.DistinctResult)
return ret0
}
// Distinct indicates an expected call of Distinct.
func (mr *MockmonCollectionMockRecorder) Distinct(ctx, fieldName, filter any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, fieldName, filter}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Distinct", reflect.TypeOf((*MockmonCollection)(nil).Distinct), varargs...)
}
// Drop mocks base method.
func (m *MockmonCollection) Drop(ctx context.Context, opts ...options.Lister[options.DropCollectionOptions]) error {
m.ctrl.T.Helper()
varargs := []any{ctx}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Drop", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// Drop indicates an expected call of Drop.
func (mr *MockmonCollectionMockRecorder) Drop(ctx any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Drop", reflect.TypeOf((*MockmonCollection)(nil).Drop), varargs...)
}
// EstimatedDocumentCount mocks base method.
func (m *MockmonCollection) EstimatedDocumentCount(ctx context.Context, opts ...options.Lister[options.EstimatedDocumentCountOptions]) (int64, error) {
m.ctrl.T.Helper()
varargs := []any{ctx}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "EstimatedDocumentCount", varargs...)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// EstimatedDocumentCount indicates an expected call of EstimatedDocumentCount.
func (mr *MockmonCollectionMockRecorder) EstimatedDocumentCount(ctx any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EstimatedDocumentCount", reflect.TypeOf((*MockmonCollection)(nil).EstimatedDocumentCount), varargs...)
}
// Find mocks base method.
func (m *MockmonCollection) Find(ctx context.Context, filter any, opts ...options.Lister[options.FindOptions]) (*mongo.Cursor, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, filter}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Find", varargs...)
ret0, _ := ret[0].(*mongo.Cursor)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Find indicates an expected call of Find.
func (mr *MockmonCollectionMockRecorder) Find(ctx, filter any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, filter}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Find", reflect.TypeOf((*MockmonCollection)(nil).Find), varargs...)
}
// FindOne mocks base method.
func (m *MockmonCollection) FindOne(ctx context.Context, filter any, opts ...options.Lister[options.FindOneOptions]) *mongo.SingleResult {
m.ctrl.T.Helper()
varargs := []any{ctx, filter}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "FindOne", varargs...)
ret0, _ := ret[0].(*mongo.SingleResult)
return ret0
}
// FindOne indicates an expected call of FindOne.
func (mr *MockmonCollectionMockRecorder) FindOne(ctx, filter any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, filter}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindOne", reflect.TypeOf((*MockmonCollection)(nil).FindOne), varargs...)
}
// FindOneAndDelete mocks base method.
func (m *MockmonCollection) FindOneAndDelete(ctx context.Context, filter any, opts ...options.Lister[options.FindOneAndDeleteOptions]) *mongo.SingleResult {
m.ctrl.T.Helper()
varargs := []any{ctx, filter}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "FindOneAndDelete", varargs...)
ret0, _ := ret[0].(*mongo.SingleResult)
return ret0
}
// FindOneAndDelete indicates an expected call of FindOneAndDelete.
func (mr *MockmonCollectionMockRecorder) FindOneAndDelete(ctx, filter any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, filter}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindOneAndDelete", reflect.TypeOf((*MockmonCollection)(nil).FindOneAndDelete), varargs...)
}
// FindOneAndReplace mocks base method.
func (m *MockmonCollection) FindOneAndReplace(ctx context.Context, filter, replacement any, opts ...options.Lister[options.FindOneAndReplaceOptions]) *mongo.SingleResult {
m.ctrl.T.Helper()
varargs := []any{ctx, filter, replacement}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "FindOneAndReplace", varargs...)
ret0, _ := ret[0].(*mongo.SingleResult)
return ret0
}
// FindOneAndReplace indicates an expected call of FindOneAndReplace.
func (mr *MockmonCollectionMockRecorder) FindOneAndReplace(ctx, filter, replacement any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, filter, replacement}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindOneAndReplace", reflect.TypeOf((*MockmonCollection)(nil).FindOneAndReplace), varargs...)
}
// FindOneAndUpdate mocks base method.
func (m *MockmonCollection) FindOneAndUpdate(ctx context.Context, filter, update any, opts ...options.Lister[options.FindOneAndUpdateOptions]) *mongo.SingleResult {
m.ctrl.T.Helper()
varargs := []any{ctx, filter, update}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "FindOneAndUpdate", varargs...)
ret0, _ := ret[0].(*mongo.SingleResult)
return ret0
}
// FindOneAndUpdate indicates an expected call of FindOneAndUpdate.
func (mr *MockmonCollectionMockRecorder) FindOneAndUpdate(ctx, filter, update any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, filter, update}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindOneAndUpdate", reflect.TypeOf((*MockmonCollection)(nil).FindOneAndUpdate), varargs...)
}
// Indexes mocks base method.
func (m *MockmonCollection) Indexes() mongo.IndexView {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Indexes")
ret0, _ := ret[0].(mongo.IndexView)
return ret0
}
// Indexes indicates an expected call of Indexes.
func (mr *MockmonCollectionMockRecorder) Indexes() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Indexes", reflect.TypeOf((*MockmonCollection)(nil).Indexes))
}
// InsertMany mocks base method.
func (m *MockmonCollection) InsertMany(ctx context.Context, documents any, opts ...options.Lister[options.InsertManyOptions]) (*mongo.InsertManyResult, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, documents}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "InsertMany", varargs...)
ret0, _ := ret[0].(*mongo.InsertManyResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertMany indicates an expected call of InsertMany.
func (mr *MockmonCollectionMockRecorder) InsertMany(ctx, documents any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, documents}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertMany", reflect.TypeOf((*MockmonCollection)(nil).InsertMany), varargs...)
}
// InsertOne mocks base method.
func (m *MockmonCollection) InsertOne(ctx context.Context, document any, opts ...options.Lister[options.InsertOneOptions]) (*mongo.InsertOneResult, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, document}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "InsertOne", varargs...)
ret0, _ := ret[0].(*mongo.InsertOneResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertOne indicates an expected call of InsertOne.
func (mr *MockmonCollectionMockRecorder) InsertOne(ctx, document any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, document}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertOne", reflect.TypeOf((*MockmonCollection)(nil).InsertOne), varargs...)
}
// ReplaceOne mocks base method.
func (m *MockmonCollection) ReplaceOne(ctx context.Context, filter, replacement any, opts ...options.Lister[options.ReplaceOptions]) (*mongo.UpdateResult, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, filter, replacement}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "ReplaceOne", varargs...)
ret0, _ := ret[0].(*mongo.UpdateResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ReplaceOne indicates an expected call of ReplaceOne.
func (mr *MockmonCollectionMockRecorder) ReplaceOne(ctx, filter, replacement any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, filter, replacement}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceOne", reflect.TypeOf((*MockmonCollection)(nil).ReplaceOne), varargs...)
}
// UpdateByID mocks base method.
func (m *MockmonCollection) UpdateByID(ctx context.Context, id, update any, opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, id, update}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "UpdateByID", varargs...)
ret0, _ := ret[0].(*mongo.UpdateResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateByID indicates an expected call of UpdateByID.
func (mr *MockmonCollectionMockRecorder) UpdateByID(ctx, id, update any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, id, update}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateByID", reflect.TypeOf((*MockmonCollection)(nil).UpdateByID), varargs...)
}
// UpdateMany mocks base method.
func (m *MockmonCollection) UpdateMany(ctx context.Context, filter, update any, opts ...options.Lister[options.UpdateManyOptions]) (*mongo.UpdateResult, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, filter, update}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "UpdateMany", varargs...)
ret0, _ := ret[0].(*mongo.UpdateResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateMany indicates an expected call of UpdateMany.
func (mr *MockmonCollectionMockRecorder) UpdateMany(ctx, filter, update any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, filter, update}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateMany", reflect.TypeOf((*MockmonCollection)(nil).UpdateMany), varargs...)
}
// UpdateOne mocks base method.
func (m *MockmonCollection) UpdateOne(ctx context.Context, filter, update any, opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, filter, update}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "UpdateOne", varargs...)
ret0, _ := ret[0].(*mongo.UpdateResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateOne indicates an expected call of UpdateOne.
func (mr *MockmonCollectionMockRecorder) UpdateOne(ctx, filter, update any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, filter, update}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateOne", reflect.TypeOf((*MockmonCollection)(nil).UpdateOne), varargs...)
}
// Watch mocks base method.
func (m *MockmonCollection) Watch(ctx context.Context, pipeline any, opts ...options.Lister[options.ChangeStreamOptions]) (*mongo.ChangeStream, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, pipeline}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Watch", varargs...)
ret0, _ := ret[0].(*mongo.ChangeStream)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Watch indicates an expected call of Watch.
func (mr *MockmonCollectionMockRecorder) Watch(ctx, pipeline any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, pipeline}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockmonCollection)(nil).Watch), varargs...)
}

View File

@@ -10,12 +10,10 @@ import (
"github.com/zeromicro/go-zero/core/logx/logtest"
"github.com/zeromicro/go-zero/core/stringx"
"github.com/zeromicro/go-zero/core/timex"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
mopt "go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"go.uber.org/mock/gomock"
)
var errDummy = errors.New("dummy")
@@ -68,471 +66,345 @@ func TestKeepPromise_keep(t *testing.T) {
}
func TestNewCollection(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
coll := mt.Coll
assert.NotNil(t, coll)
col := newCollection(coll, breaker.GetBreaker("localhost"))
assert.Equal(t, t.Name()+"/test", col.(*decoratedCollection).name)
})
_ = newCollection(&mongo.Collection{}, breaker.GetBreaker("localhost"))
}
func TestCollection_Aggregate(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
coll := mt.Coll
assert.NotNil(t, coll)
col := newCollection(coll, breaker.GetBreaker("localhost"))
ns := mt.Coll.Database().Name() + "." + mt.Coll.Name()
aggRes := mtest.CreateCursorResponse(1, ns, mtest.FirstBatch)
mt.AddMockResponses(aggRes)
assert.Equal(t, t.Name()+"/test", col.(*decoratedCollection).name)
cursor, err := col.Aggregate(context.Background(), mongo.Pipeline{}, mopt.Aggregate())
assert.Nil(t, err)
cursor.Close(context.Background())
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockmonCollection(ctrl)
mockCollection.EXPECT().Aggregate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.Cursor{}, nil)
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
_, err := c.Aggregate(context.Background(), []interface{}{}, options.Aggregate())
assert.Nil(t, err)
}
func TestCollection_BulkWrite(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "ok", Value: 1}}...))
res, err := c.BulkWrite(context.Background(), []mongo.WriteModel{
mongo.NewInsertOneModel().SetDocument(bson.D{{Key: "foo", Value: 1}}),
})
assert.Nil(t, err)
assert.NotNil(t, res)
c.brk = new(dropBreaker)
_, err = c.BulkWrite(context.Background(), []mongo.WriteModel{
mongo.NewInsertOneModel().SetDocument(bson.D{{Key: "foo", Value: 1}}),
})
assert.Equal(t, errDummy, err)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockmonCollection(ctrl)
mockCollection.EXPECT().BulkWrite(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.BulkWriteResult{}, nil)
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
_, err := c.BulkWrite(context.Background(), []mongo.WriteModel{
mongo.NewInsertOneModel().SetDocument(bson.D{{Key: "foo", Value: 1}}),
})
assert.Nil(t, err)
c.brk = new(dropBreaker)
_, err = c.BulkWrite(context.Background(), []mongo.WriteModel{
mongo.NewInsertOneModel().SetDocument(bson.D{{Key: "foo", Value: 1}}),
})
assert.Equal(t, errDummy, err)
}
func TestCollection_CountDocuments(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(mtest.CreateCursorResponse(
1,
"DBName.CollectionName",
mtest.FirstBatch,
bson.D{
{Key: "n", Value: 1},
}))
res, err := c.CountDocuments(context.Background(), bson.D{})
assert.Nil(t, err)
assert.Equal(t, int64(1), res)
c.brk = new(dropBreaker)
_, err = c.CountDocuments(context.Background(), bson.D{{Key: "foo", Value: 1}})
assert.Equal(t, errDummy, err)
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockmonCollection(ctrl)
mockCollection.EXPECT().CountDocuments(gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(0), nil)
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
res, err := c.CountDocuments(context.Background(), bson.D{})
assert.Nil(t, err)
assert.Equal(t, int64(0), res)
c.brk = new(dropBreaker)
_, err = c.CountDocuments(context.Background(), bson.D{{Key: "foo", Value: 1}})
assert.Equal(t, errDummy, err)
}
func TestDecoratedCollection_DeleteMany(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
res, err := c.DeleteMany(context.Background(), bson.D{})
assert.Nil(t, err)
assert.Equal(t, int64(1), res.DeletedCount)
c.brk = new(dropBreaker)
_, err = c.DeleteMany(context.Background(), bson.D{{Key: "foo", Value: 1}})
assert.Equal(t, errDummy, err)
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockmonCollection(ctrl)
mockCollection.EXPECT().DeleteMany(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.DeleteResult{}, nil)
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
_, err := c.DeleteMany(context.Background(), bson.D{})
assert.Nil(t, err)
c.brk = new(dropBreaker)
_, err = c.DeleteMany(context.Background(), bson.D{{Key: "foo", Value: 1}})
assert.Equal(t, errDummy, err)
}
func TestCollection_Distinct(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(bson.D{{Key: "ok", Value: 1}, {Key: "values", Value: []int{1}}})
resp, err := c.Distinct(context.Background(), "foo", bson.D{})
assert.Nil(t, err)
assert.Equal(t, 1, len(resp))
c.brk = new(dropBreaker)
_, err = c.Distinct(context.Background(), "foo", bson.D{{Key: "foo", Value: 1}})
assert.Equal(t, errDummy, err)
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockmonCollection(ctrl)
mockCollection.EXPECT().Distinct(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.DistinctResult{})
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
_, err := c.Distinct(context.Background(), "foo", bson.D{})
assert.Nil(t, err)
c.brk = new(dropBreaker)
_, err = c.Distinct(context.Background(), "foo", bson.D{{Key: "foo", Value: 1}})
assert.Equal(t, errDummy, err)
}
func TestCollection_EstimatedDocumentCount(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(bson.D{{Key: "ok", Value: 1}, {Key: "n", Value: 1}})
res, err := c.EstimatedDocumentCount(context.Background())
assert.Nil(t, err)
assert.Equal(t, int64(1), res)
c.brk = new(dropBreaker)
_, err = c.EstimatedDocumentCount(context.Background())
assert.Equal(t, errDummy, err)
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockmonCollection(ctrl)
mockCollection.EXPECT().EstimatedDocumentCount(gomock.Any(), gomock.Any()).Return(int64(0), nil)
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
_, err := c.EstimatedDocumentCount(context.Background())
assert.Nil(t, err)
c.brk = new(dropBreaker)
_, err = c.EstimatedDocumentCount(context.Background())
assert.Equal(t, errDummy, err)
}
func TestCollection_Find(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
find := mtest.CreateCursorResponse(
1,
"DBName.CollectionName",
mtest.FirstBatch,
bson.D{
{Key: "name", Value: "John"},
})
getMore := mtest.CreateCursorResponse(
1,
"DBName.CollectionName",
mtest.NextBatch,
bson.D{
{Key: "name", Value: "Mary"},
})
killCursors := mtest.CreateCursorResponse(
0,
"DBName.CollectionName",
mtest.NextBatch)
mt.AddMockResponses(find, getMore, killCursors)
filter := bson.D{{Key: "x", Value: 1}}
cursor, err := c.Find(context.Background(), filter, mopt.Find())
assert.Nil(t, err)
defer cursor.Close(context.Background())
var val []struct {
ID primitive.ObjectID `bson:"_id"`
Name string `bson:"name"`
}
assert.Nil(t, cursor.All(context.Background(), &val))
assert.Equal(t, 2, len(val))
assert.Equal(t, "John", val[0].Name)
assert.Equal(t, "Mary", val[1].Name)
c.brk = new(dropBreaker)
_, err = c.Find(context.Background(), filter, mopt.Find())
assert.Equal(t, errDummy, err)
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockmonCollection(ctrl)
mockCollection.EXPECT().Find(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.Cursor{}, nil)
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
filter := bson.D{{Key: "x", Value: 1}}
_, err := c.Find(context.Background(), filter, options.Find())
assert.Nil(t, err)
c.brk = new(dropBreaker)
_, err = c.Find(context.Background(), filter, options.Find())
assert.Equal(t, errDummy, err)
}
func TestCollection_FindOne(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
find := mtest.CreateCursorResponse(
1,
"DBName.CollectionName",
mtest.FirstBatch,
bson.D{
{Key: "name", Value: "John"},
})
getMore := mtest.CreateCursorResponse(
1,
"DBName.CollectionName",
mtest.NextBatch,
bson.D{
{Key: "name", Value: "Mary"},
})
killCursors := mtest.CreateCursorResponse(
0,
"DBName.CollectionName",
mtest.NextBatch)
mt.AddMockResponses(find, getMore, killCursors)
filter := bson.D{{Key: "x", Value: 1}}
resp, err := c.FindOne(context.Background(), filter)
assert.Nil(t, err)
var val struct {
ID primitive.ObjectID `bson:"_id"`
Name string `bson:"name"`
}
assert.Nil(t, resp.Decode(&val))
assert.Equal(t, "John", val.Name)
c.brk = new(dropBreaker)
_, err = c.FindOne(context.Background(), filter)
assert.Equal(t, errDummy, err)
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockmonCollection(ctrl)
mockCollection.EXPECT().FindOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.SingleResult{})
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
filter := bson.D{{Key: "x", Value: 1}}
_, err := c.FindOne(context.Background(), filter)
assert.Equal(t, mongo.ErrNoDocuments, err)
c.brk = new(dropBreaker)
_, err = c.FindOne(context.Background(), filter)
assert.Equal(t, errDummy, err)
}
func TestCollection_FindOneAndDelete(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
filter := bson.D{}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{}...))
_, err := c.FindOneAndDelete(context.Background(), filter, mopt.FindOneAndDelete())
assert.Equal(t, mongo.ErrNoDocuments, err)
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "name", Value: "John"}}},
}...))
resp, err := c.FindOneAndDelete(context.Background(), filter, mopt.FindOneAndDelete())
assert.Nil(t, err)
var val struct {
Name string `bson:"name"`
}
assert.Nil(t, resp.Decode(&val))
assert.Equal(t, "John", val.Name)
c.brk = new(dropBreaker)
_, err = c.FindOneAndDelete(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
assert.Equal(t, errDummy, err)
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockmonCollection(ctrl)
mockCollection.EXPECT().FindOneAndDelete(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.SingleResult{})
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
filter := bson.D{}
mockCollection.EXPECT().FindOneAndDelete(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.SingleResult{})
_, err := c.FindOneAndDelete(context.Background(), filter, options.FindOneAndDelete())
assert.Equal(t, mongo.ErrNoDocuments, err)
_, err = c.FindOneAndDelete(context.Background(), filter, options.FindOneAndDelete())
assert.Equal(t, mongo.ErrNoDocuments, err)
c.brk = new(dropBreaker)
_, err = c.FindOneAndDelete(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
assert.Equal(t, errDummy, err)
}
func TestCollection_FindOneAndReplace(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{}...))
filter := bson.D{{Key: "x", Value: 1}}
replacement := bson.D{{Key: "x", Value: 2}}
opts := mopt.FindOneAndReplace().SetUpsert(true)
_, err := c.FindOneAndReplace(context.Background(), filter, replacement, opts)
assert.Equal(t, mongo.ErrNoDocuments, err)
mt.AddMockResponses(bson.D{{Key: "ok", Value: 1}, {Key: "value", Value: bson.D{
{Key: "name", Value: "John"},
}}})
resp, err := c.FindOneAndReplace(context.Background(), filter, replacement, opts)
assert.Nil(t, err)
var val struct {
Name string `bson:"name"`
}
assert.Nil(t, resp.Decode(&val))
assert.Equal(t, "John", val.Name)
c.brk = new(dropBreaker)
_, err = c.FindOneAndReplace(context.Background(), filter, replacement, opts)
assert.Equal(t, errDummy, err)
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockmonCollection(ctrl)
mockCollection.EXPECT().FindOneAndReplace(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.SingleResult{})
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
filter := bson.D{{Key: "x", Value: 1}}
replacement := bson.D{{Key: "x", Value: 2}}
opts := options.FindOneAndReplace().SetUpsert(true)
mockCollection.EXPECT().FindOneAndReplace(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.SingleResult{})
_, err := c.FindOneAndReplace(context.Background(), filter, replacement, opts)
assert.Equal(t, mongo.ErrNoDocuments, err)
_, err = c.FindOneAndReplace(context.Background(), filter, replacement, opts)
assert.Equal(t, mongo.ErrNoDocuments, err)
c.brk = new(dropBreaker)
_, err = c.FindOneAndReplace(context.Background(), filter, replacement, opts)
assert.Equal(t, errDummy, err)
}
func TestCollection_FindOneAndUpdate(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(bson.D{{Key: "ok", Value: 1}})
filter := bson.D{{Key: "x", Value: 1}}
update := bson.D{{Key: "$x", Value: 2}}
opts := mopt.FindOneAndUpdate().SetUpsert(true)
_, err := c.FindOneAndUpdate(context.Background(), filter, update, opts)
assert.Equal(t, mongo.ErrNoDocuments, err)
mt.AddMockResponses(bson.D{{Key: "ok", Value: 1}, {Key: "value", Value: bson.D{
{Key: "name", Value: "John"},
}}})
resp, err := c.FindOneAndUpdate(context.Background(), filter, update, opts)
assert.Nil(t, err)
var val struct {
Name string `bson:"name"`
}
assert.Nil(t, resp.Decode(&val))
assert.Equal(t, "John", val.Name)
c.brk = new(dropBreaker)
_, err = c.FindOneAndUpdate(context.Background(), filter, update, opts)
assert.Equal(t, errDummy, err)
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockmonCollection(ctrl)
mockCollection.EXPECT().FindOneAndUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.SingleResult{})
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
filter := bson.D{{Key: "x", Value: 1}}
update := bson.D{{Key: "$x", Value: 2}}
mockCollection.EXPECT().FindOneAndUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.SingleResult{})
opts := options.FindOneAndUpdate().SetUpsert(true)
_, err := c.FindOneAndUpdate(context.Background(), filter, update, opts)
assert.Equal(t, mongo.ErrNoDocuments, err)
_, err = c.FindOneAndUpdate(context.Background(), filter, update, opts)
assert.Equal(t, mongo.ErrNoDocuments, err)
c.brk = new(dropBreaker)
_, err = c.FindOneAndUpdate(context.Background(), filter, update, opts)
assert.Equal(t, errDummy, err)
}
func TestCollection_InsertOne(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "ok", Value: 1}}...))
res, err := c.InsertOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
assert.Nil(t, err)
assert.NotNil(t, res)
c.brk = new(dropBreaker)
_, err = c.InsertOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
assert.Equal(t, errDummy, err)
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockmonCollection(ctrl)
mockCollection.EXPECT().InsertOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.InsertOneResult{}, nil)
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
res, err := c.InsertOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
assert.Nil(t, err)
assert.NotNil(t, res)
c.brk = new(dropBreaker)
_, err = c.InsertOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
assert.Equal(t, errDummy, err)
}
func TestCollection_InsertMany(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "ok", Value: 1}}...))
res, err := c.InsertMany(context.Background(), []any{
bson.D{{Key: "foo", Value: "bar"}},
bson.D{{Key: "foo", Value: "baz"}},
})
assert.Nil(t, err)
assert.NotNil(t, res)
assert.Equal(t, 2, len(res.InsertedIDs))
c.brk = new(dropBreaker)
_, err = c.InsertMany(context.Background(), []any{bson.D{{Key: "foo", Value: "bar"}}})
assert.Equal(t, errDummy, err)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockmonCollection(ctrl)
mockCollection.EXPECT().InsertMany(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.InsertManyResult{}, nil)
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
_, err := c.InsertMany(context.Background(), []any{
bson.D{{Key: "foo", Value: "bar"}},
bson.D{{Key: "foo", Value: "baz"}},
})
assert.Nil(t, err)
c.brk = new(dropBreaker)
_, err = c.InsertMany(context.Background(), []any{bson.D{{Key: "foo", Value: "bar"}}})
assert.Equal(t, errDummy, err)
}
func TestCollection_DeleteOne(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
res, err := c.DeleteOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
assert.Nil(t, err)
assert.Equal(t, int64(1), res.DeletedCount)
c.brk = new(dropBreaker)
_, err = c.DeleteOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
assert.Equal(t, errDummy, err)
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockmonCollection(ctrl)
mockCollection.EXPECT().DeleteOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.DeleteResult{}, nil)
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
_, err := c.DeleteOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
assert.Nil(t, err)
c.brk = new(dropBreaker)
_, err = c.DeleteOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
assert.Equal(t, errDummy, err)
}
func TestCollection_DeleteMany(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
res, err := c.DeleteMany(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
assert.Nil(t, err)
assert.Equal(t, int64(1), res.DeletedCount)
c.brk = new(dropBreaker)
_, err = c.DeleteMany(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
assert.Equal(t, errDummy, err)
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockmonCollection(ctrl)
mockCollection.EXPECT().DeleteMany(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.DeleteResult{}, nil)
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
_, err := c.DeleteMany(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
assert.Nil(t, err)
c.brk = new(dropBreaker)
_, err = c.DeleteMany(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
assert.Equal(t, errDummy, err)
}
func TestCollection_ReplaceOne(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
res, err := c.ReplaceOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
bson.D{{Key: "foo", Value: "baz"}},
)
assert.Nil(t, err)
assert.Equal(t, int64(1), res.MatchedCount)
c.brk = new(dropBreaker)
_, err = c.ReplaceOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
bson.D{{Key: "foo", Value: "baz"}})
assert.Equal(t, errDummy, err)
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockmonCollection(ctrl)
mockCollection.EXPECT().ReplaceOne(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
_, err := c.ReplaceOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
bson.D{{Key: "foo", Value: "baz"}},
)
assert.Nil(t, err)
c.brk = new(dropBreaker)
_, err = c.ReplaceOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
bson.D{{Key: "foo", Value: "baz"}})
assert.Equal(t, errDummy, err)
}
func TestCollection_UpdateOne(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
resp, err := c.UpdateOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
assert.Nil(t, err)
assert.Equal(t, int64(1), resp.MatchedCount)
c.brk = new(dropBreaker)
_, err = c.UpdateOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
assert.Equal(t, errDummy, err)
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockmonCollection(ctrl)
mockCollection.EXPECT().UpdateOne(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
_, err := c.UpdateOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
assert.Nil(t, err)
c.brk = new(dropBreaker)
_, err = c.UpdateOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
assert.Equal(t, errDummy, err)
}
func TestCollection_UpdateByID(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
resp, err := c.UpdateByID(context.Background(), primitive.NewObjectID(),
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
assert.Nil(t, err)
assert.Equal(t, int64(1), resp.MatchedCount)
c.brk = new(dropBreaker)
_, err = c.UpdateByID(context.Background(), primitive.NewObjectID(),
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
assert.Equal(t, errDummy, err)
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockmonCollection(ctrl)
mockCollection.EXPECT().UpdateByID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
_, err := c.UpdateByID(context.Background(), bson.NewObjectID(),
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
assert.Nil(t, err)
c.brk = new(dropBreaker)
_, err = c.UpdateByID(context.Background(), bson.NewObjectID(),
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
assert.Equal(t, errDummy, err)
}
func TestCollection_UpdateMany(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
resp, err := c.UpdateMany(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
assert.Nil(t, err)
assert.Equal(t, int64(1), resp.MatchedCount)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockmonCollection(ctrl)
mockCollection.EXPECT().UpdateMany(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
_, err := c.UpdateMany(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
assert.Nil(t, err)
c.brk = new(dropBreaker)
_, err = c.UpdateMany(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
assert.Equal(t, errDummy, err)
}
c.brk = new(dropBreaker)
_, err = c.UpdateMany(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
assert.Equal(t, errDummy, err)
})
func TestCollection_Watch(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockmonCollection(ctrl)
mockCollection.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.ChangeStream{}, nil)
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
_, err := c.Watch(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
assert.Nil(t, err)
}
func TestCollection_Clone(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockmonCollection(ctrl)
mockCollection.EXPECT().Clone(gomock.Any()).Return(nil)
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
cc := c.Clone()
assert.Nil(t, cc)
}
func TestCollection_Database(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockmonCollection(ctrl)
mockCollection.EXPECT().Database().Return(nil)
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
db := c.Database()
assert.Nil(t, db)
}
func TestCollection_Drop(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockmonCollection(ctrl)
mockCollection.EXPECT().Drop(gomock.Any()).Return(nil)
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
err := c.Drop(context.Background())
assert.Nil(t, err)
}
func TestCollection_Indexes(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockmonCollection(ctrl)
idx := mongo.IndexView{}
mockCollection.EXPECT().Indexes().Return(idx)
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
index := c.Indexes()
assert.Equal(t, index, idx)
}
func TestDecoratedCollection_LogDuration(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := NewMockmonCollection(ctrl)
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
buf := logtest.NewCollector(t)
@@ -585,14 +457,6 @@ func TestAcceptable(t *testing.T) {
{"NilDocument", mongo.ErrNilDocument, true},
{"NilCursor", mongo.ErrNilCursor, true},
{"EmptySlice", mongo.ErrEmptySlice, true},
{"SessionEnded", session.ErrSessionEnded, true},
{"NoTransactStarted", session.ErrNoTransactStarted, true},
{"TransactInProgress", session.ErrTransactInProgress, true},
{"AbortAfterCommit", session.ErrAbortAfterCommit, true},
{"AbortTwice", session.ErrAbortTwice, true},
{"CommitAfterAbort", session.ErrCommitAfterAbort, true},
{"UnackWCUnsupported", session.ErrUnackWCUnsupported, true},
{"SnapshotTransaction", session.ErrSnapshotTransaction, true},
{"DuplicateKeyError", mongo.WriteException{WriteErrors: []mongo.WriteError{{Code: duplicateKeyCode}}}, true},
{"OtherError", errors.New("other error"), false},
}
@@ -623,6 +487,14 @@ func TestIsDupKeyError(t *testing.T) {
}
}
func newTestCollection(collection monCollection, brk breaker.Breaker) *decoratedCollection {
return &decoratedCollection{
Collection: collection,
name: "test",
brk: brk,
}
}
type mockPromise struct {
accepted bool
reason string

View File

@@ -0,0 +1,63 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: bulkinserter.go
//
// Generated by this command:
//
// mockgen -package mon -destination collectioninserter_mock.go -source bulkinserter.go collectionInserter
//
// Package mon is a generated GoMock package.
package mon
import (
context "context"
reflect "reflect"
mongo "go.mongodb.org/mongo-driver/v2/mongo"
options "go.mongodb.org/mongo-driver/v2/mongo/options"
gomock "go.uber.org/mock/gomock"
)
// MockcollectionInserter is a mock of collectionInserter interface.
type MockcollectionInserter struct {
ctrl *gomock.Controller
recorder *MockcollectionInserterMockRecorder
isgomock struct{}
}
// MockcollectionInserterMockRecorder is the mock recorder for MockcollectionInserter.
type MockcollectionInserterMockRecorder struct {
mock *MockcollectionInserter
}
// NewMockcollectionInserter creates a new mock instance.
func NewMockcollectionInserter(ctrl *gomock.Controller) *MockcollectionInserter {
mock := &MockcollectionInserter{ctrl: ctrl}
mock.recorder = &MockcollectionInserterMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockcollectionInserter) EXPECT() *MockcollectionInserterMockRecorder {
return m.recorder
}
// InsertMany mocks base method.
func (m *MockcollectionInserter) InsertMany(ctx context.Context, documents any, opts ...options.Lister[options.InsertManyOptions]) (*mongo.InsertManyResult, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, documents}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "InsertMany", varargs...)
ret0, _ := ret[0].(*mongo.InsertManyResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertMany indicates an expected call of InsertMany.
func (mr *MockcollectionInserterMockRecorder) InsertMany(ctx, documents any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, documents}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertMany", reflect.TypeOf((*MockcollectionInserter)(nil).InsertMany), varargs...)
}

View File

@@ -0,0 +1,19 @@
# Migrating from 1.x to 2.0
To upgrade imports of the Go Driver from v1 to v2, we recommend using [marwan-at-work/mod
](https://github.com/marwan-at-work/mod):
```
mod upgrade --mod-name=go.mongodb.org/mongo-driver
```
# Notice
After completing the mod upgrade, code changes are typically unnecessary in the vast majority of cases. However, if your project references packages including but not limited to those listed below, you'll need to manually replace them, as these libraries are no longer present in the v2 version.
```go
go.mongodb.org/mongo-driver/bson/bsonrw => go.mongodb.org/mongo-driver/v2/bson
go.mongodb.org/mongo-driver/bson/bsoncodec => go.mongodb.org/mongo-driver/v2/bson
go.mongodb.org/mongo-driver/bson/primitive => go.mongodb.org/mongo-driver/v2/bson
```
See the following resources to learn more about upgrading from version 1.x to 2.0.:
https://raw.githubusercontent.com/mongodb/mongo-go-driver/refs/heads/master/docs/migration-2.0.md

View File

@@ -1,3 +1,4 @@
//go:generate mockgen -package mon -destination model_mock.go -source model.go monClient monSession
package mon
import (
@@ -7,8 +8,8 @@ import (
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/timex"
"go.mongodb.org/mongo-driver/mongo"
mopt "go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.mongodb.org/mongo-driver/v2/mongo/options"
)
const (
@@ -24,15 +25,15 @@ type (
Model struct {
Collection
name string
cli *mongo.Client
cli monClient
brk breaker.Breaker
opts []Option
}
wrappedSession struct {
mongo.Session
name string
brk breaker.Breaker
Session struct {
session monSession
name string
brk breaker.Breaker
}
)
@@ -61,14 +62,14 @@ func newModel(name string, cli *mongo.Client, coll Collection, brk breaker.Break
return &Model{
name: name,
Collection: coll,
cli: cli,
cli: &wrappedMonClient{c: cli},
brk: brk,
opts: opts,
}
}
// StartSession starts a new session.
func (m *Model) StartSession(opts ...*mopt.SessionOptions) (sess mongo.Session, err error) {
func (m *Model) StartSession(opts ...options.Lister[options.SessionOptions]) (sess *Session, err error) {
starTime := timex.Now()
defer func() {
logDuration(context.Background(), m.name, startSession, starTime, err)
@@ -79,15 +80,16 @@ func (m *Model) StartSession(opts ...*mopt.SessionOptions) (sess mongo.Session,
return nil, sessionErr
}
return &wrappedSession{
Session: session,
return &Session{
session: session,
name: m.name,
brk: m.brk,
}, nil
}
// Aggregate executes an aggregation pipeline.
func (m *Model) Aggregate(ctx context.Context, v, pipeline any, opts ...*mopt.AggregateOptions) error {
func (m *Model) Aggregate(ctx context.Context, v, pipeline any,
opts ...options.Lister[options.AggregateOptions]) error {
cur, err := m.Collection.Aggregate(ctx, pipeline, opts...)
if err != nil {
return err
@@ -98,7 +100,8 @@ func (m *Model) Aggregate(ctx context.Context, v, pipeline any, opts ...*mopt.Ag
}
// DeleteMany deletes documents that match the filter.
func (m *Model) DeleteMany(ctx context.Context, filter any, opts ...*mopt.DeleteOptions) (int64, error) {
func (m *Model) DeleteMany(ctx context.Context, filter any,
opts ...options.Lister[options.DeleteManyOptions]) (int64, error) {
res, err := m.Collection.DeleteMany(ctx, filter, opts...)
if err != nil {
return 0, err
@@ -108,7 +111,8 @@ func (m *Model) DeleteMany(ctx context.Context, filter any, opts ...*mopt.Delete
}
// DeleteOne deletes the first document that matches the filter.
func (m *Model) DeleteOne(ctx context.Context, filter any, opts ...*mopt.DeleteOptions) (int64, error) {
func (m *Model) DeleteOne(ctx context.Context, filter any,
opts ...options.Lister[options.DeleteOneOptions]) (int64, error) {
res, err := m.Collection.DeleteOne(ctx, filter, opts...)
if err != nil {
return 0, err
@@ -118,7 +122,8 @@ func (m *Model) DeleteOne(ctx context.Context, filter any, opts ...*mopt.DeleteO
}
// Find finds documents that match the filter.
func (m *Model) Find(ctx context.Context, v, filter any, opts ...*mopt.FindOptions) error {
func (m *Model) Find(ctx context.Context, v, filter any,
opts ...options.Lister[options.FindOptions]) error {
cur, err := m.Collection.Find(ctx, filter, opts...)
if err != nil {
return err
@@ -129,7 +134,8 @@ func (m *Model) Find(ctx context.Context, v, filter any, opts ...*mopt.FindOptio
}
// FindOne finds the first document that matches the filter.
func (m *Model) FindOne(ctx context.Context, v, filter any, opts ...*mopt.FindOneOptions) error {
func (m *Model) FindOne(ctx context.Context, v, filter any,
opts ...options.Lister[options.FindOneOptions]) error {
res, err := m.Collection.FindOne(ctx, filter, opts...)
if err != nil {
return err
@@ -140,7 +146,7 @@ func (m *Model) FindOne(ctx context.Context, v, filter any, opts ...*mopt.FindOn
// FindOneAndDelete finds a single document and deletes it.
func (m *Model) FindOneAndDelete(ctx context.Context, v, filter any,
opts ...*mopt.FindOneAndDeleteOptions) error {
opts ...options.Lister[options.FindOneAndDeleteOptions]) error {
res, err := m.Collection.FindOneAndDelete(ctx, filter, opts...)
if err != nil {
return err
@@ -151,7 +157,7 @@ func (m *Model) FindOneAndDelete(ctx context.Context, v, filter any,
// FindOneAndReplace finds a single document and replaces it.
func (m *Model) FindOneAndReplace(ctx context.Context, v, filter, replacement any,
opts ...*mopt.FindOneAndReplaceOptions) error {
opts ...options.Lister[options.FindOneAndReplaceOptions]) error {
res, err := m.Collection.FindOneAndReplace(ctx, filter, replacement, opts...)
if err != nil {
return err
@@ -162,7 +168,7 @@ func (m *Model) FindOneAndReplace(ctx context.Context, v, filter, replacement an
// FindOneAndUpdate finds a single document and updates it.
func (m *Model) FindOneAndUpdate(ctx context.Context, v, filter, update any,
opts ...*mopt.FindOneAndUpdateOptions) error {
opts ...options.Lister[options.FindOneAndUpdateOptions]) error {
res, err := m.Collection.FindOneAndUpdate(ctx, filter, update, opts...)
if err != nil {
return err
@@ -171,8 +177,8 @@ func (m *Model) FindOneAndUpdate(ctx context.Context, v, filter, update any,
return res.Decode(v)
}
// AbortTransaction implements the mongo.Session interface.
func (w *wrappedSession) AbortTransaction(ctx context.Context) (err error) {
// AbortTransaction implements the mongo.session interface.
func (w *Session) AbortTransaction(ctx context.Context) (err error) {
ctx, span := startSpan(ctx, abortTransaction)
defer func() {
endSpan(span, err)
@@ -184,12 +190,12 @@ func (w *wrappedSession) AbortTransaction(ctx context.Context) (err error) {
logDuration(ctx, w.name, abortTransaction, starTime, err)
}()
return w.Session.AbortTransaction(ctx)
return w.session.AbortTransaction(ctx)
}, acceptable)
}
// CommitTransaction implements the mongo.Session interface.
func (w *wrappedSession) CommitTransaction(ctx context.Context) (err error) {
// CommitTransaction implements the mongo.session interface.
func (w *Session) CommitTransaction(ctx context.Context) (err error) {
ctx, span := startSpan(ctx, commitTransaction)
defer func() {
endSpan(span, err)
@@ -201,15 +207,15 @@ func (w *wrappedSession) CommitTransaction(ctx context.Context) (err error) {
logDuration(ctx, w.name, commitTransaction, starTime, err)
}()
return w.Session.CommitTransaction(ctx)
return w.session.CommitTransaction(ctx)
}, acceptable)
}
// WithTransaction implements the mongo.Session interface.
func (w *wrappedSession) WithTransaction(
// WithTransaction implements the mongo.session interface.
func (w *Session) WithTransaction(
ctx context.Context,
fn func(sessCtx mongo.SessionContext) (any, error),
opts ...*mopt.TransactionOptions,
fn func(sessCtx context.Context) (any, error),
opts ...options.Lister[options.TransactionOptions],
) (res any, err error) {
ctx, span := startSpan(ctx, withTransaction)
defer func() {
@@ -222,15 +228,15 @@ func (w *wrappedSession) WithTransaction(
logDuration(ctx, w.name, withTransaction, starTime, err)
}()
res, err = w.Session.WithTransaction(ctx, fn, opts...)
res, err = w.session.WithTransaction(ctx, fn, opts...)
return err
}, acceptable)
return
}
// EndSession implements the mongo.Session interface.
func (w *wrappedSession) EndSession(ctx context.Context) {
// EndSession implements the mongo.session interface.
func (w *Session) EndSession(ctx context.Context) {
var err error
ctx, span := startSpan(ctx, endSession)
defer func() {
@@ -243,7 +249,34 @@ func (w *wrappedSession) EndSession(ctx context.Context) {
logDuration(ctx, w.name, endSession, starTime, err)
}()
w.Session.EndSession(ctx)
w.session.EndSession(ctx)
return nil
}, acceptable)
}
type (
// for unit test
monClient interface {
StartSession(opts ...options.Lister[options.SessionOptions]) (monSession, error)
}
monSession interface {
AbortTransaction(ctx context.Context) error
CommitTransaction(ctx context.Context) error
EndSession(ctx context.Context)
WithTransaction(ctx context.Context, fn func(sessCtx context.Context) (any, error),
opts ...options.Lister[options.TransactionOptions]) (any, error)
}
)
type wrappedMonClient struct {
c *mongo.Client
}
// StartSession starts a new session using the underlying *mongo.Client.
// It implements the monClient interface.
// This is used to allow mocking in unit tests.
func (m *wrappedMonClient) StartSession(opts ...options.Lister[options.SessionOptions]) (
monSession, error) {
return m.c.StartSession(opts...)
}

View File

@@ -0,0 +1,145 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: model.go
//
// Generated by this command:
//
// mockgen -package mon -destination model_mock.go -source model.go monClient monSession
//
// Package mon is a generated GoMock package.
package mon
import (
context "context"
reflect "reflect"
options "go.mongodb.org/mongo-driver/v2/mongo/options"
gomock "go.uber.org/mock/gomock"
)
// MockmonClient is a mock of monClient interface.
type MockmonClient struct {
ctrl *gomock.Controller
recorder *MockmonClientMockRecorder
isgomock struct{}
}
// MockmonClientMockRecorder is the mock recorder for MockmonClient.
type MockmonClientMockRecorder struct {
mock *MockmonClient
}
// NewMockmonClient creates a new mock instance.
func NewMockmonClient(ctrl *gomock.Controller) *MockmonClient {
mock := &MockmonClient{ctrl: ctrl}
mock.recorder = &MockmonClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockmonClient) EXPECT() *MockmonClientMockRecorder {
return m.recorder
}
// StartSession mocks base method.
func (m *MockmonClient) StartSession(opts ...options.Lister[options.SessionOptions]) (monSession, error) {
m.ctrl.T.Helper()
varargs := []any{}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "StartSession", varargs...)
ret0, _ := ret[0].(monSession)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// StartSession indicates an expected call of StartSession.
func (mr *MockmonClientMockRecorder) StartSession(opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartSession", reflect.TypeOf((*MockmonClient)(nil).StartSession), opts...)
}
// MockmonSession is a mock of monSession interface.
type MockmonSession struct {
ctrl *gomock.Controller
recorder *MockmonSessionMockRecorder
isgomock struct{}
}
// MockmonSessionMockRecorder is the mock recorder for MockmonSession.
type MockmonSessionMockRecorder struct {
mock *MockmonSession
}
// NewMockmonSession creates a new mock instance.
func NewMockmonSession(ctrl *gomock.Controller) *MockmonSession {
mock := &MockmonSession{ctrl: ctrl}
mock.recorder = &MockmonSessionMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockmonSession) EXPECT() *MockmonSessionMockRecorder {
return m.recorder
}
// AbortTransaction mocks base method.
func (m *MockmonSession) AbortTransaction(ctx context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AbortTransaction", ctx)
ret0, _ := ret[0].(error)
return ret0
}
// AbortTransaction indicates an expected call of AbortTransaction.
func (mr *MockmonSessionMockRecorder) AbortTransaction(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AbortTransaction", reflect.TypeOf((*MockmonSession)(nil).AbortTransaction), ctx)
}
// CommitTransaction mocks base method.
func (m *MockmonSession) CommitTransaction(ctx context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CommitTransaction", ctx)
ret0, _ := ret[0].(error)
return ret0
}
// CommitTransaction indicates an expected call of CommitTransaction.
func (mr *MockmonSessionMockRecorder) CommitTransaction(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CommitTransaction", reflect.TypeOf((*MockmonSession)(nil).CommitTransaction), ctx)
}
// EndSession mocks base method.
func (m *MockmonSession) EndSession(ctx context.Context) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "EndSession", ctx)
}
// EndSession indicates an expected call of EndSession.
func (mr *MockmonSessionMockRecorder) EndSession(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EndSession", reflect.TypeOf((*MockmonSession)(nil).EndSession), ctx)
}
// WithTransaction mocks base method.
func (m *MockmonSession) WithTransaction(ctx context.Context, fn func(context.Context) (any, error), opts ...options.Lister[options.TransactionOptions]) (any, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, fn}
for _, a := range opts {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "WithTransaction", varargs...)
ret0, _ := ret[0].(any)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// WithTransaction indicates an expected call of WithTransaction.
func (mr *MockmonSessionMockRecorder) WithTransaction(ctx, fn any, opts ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, fn}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithTransaction", reflect.TypeOf((*MockmonSession)(nil).WithTransaction), varargs...)
}

View File

@@ -2,224 +2,242 @@ package mon
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
"github.com/zeromicro/go-zero/core/breaker"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/drivertest"
"go.uber.org/mock/gomock"
)
func TestModel_StartSession(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
m := createModel(mt)
sess, err := m.StartSession()
assert.Nil(t, err)
defer sess.EndSession(context.Background())
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockMonCollection := NewMockmonCollection(ctrl)
mockedMonClient := NewMockmonClient(ctrl)
mockMonSession := NewMockmonSession(ctrl)
warpSession := &Session{
session: mockMonSession,
name: "",
brk: breaker.GetBreaker("localhost"),
}
_, err = sess.WithTransaction(context.Background(), func(sessCtx mongo.SessionContext) (any, error) {
_ = sessCtx.StartTransaction()
sessCtx.Client().Database("1")
sessCtx.EndSession(context.Background())
return nil, nil
})
assert.Nil(t, err)
assert.NoError(t, sess.CommitTransaction(context.Background()))
assert.Error(t, sess.AbortTransaction(context.Background()))
m := newTestModel("foo", mockedMonClient, mockMonCollection, breaker.GetBreaker("test"))
mockedMonClient.EXPECT().StartSession(gomock.Any()).Return(warpSession, errors.New("error"))
_, err := m.StartSession()
assert.NotNil(t, err)
mockedMonClient.EXPECT().StartSession(gomock.Any()).Return(warpSession, nil)
sess, err := m.StartSession()
assert.Nil(t, err)
defer sess.EndSession(context.Background())
mockMonSession.EXPECT().WithTransaction(gomock.Any(), gomock.Any()).Return(nil, nil)
mockMonSession.EXPECT().CommitTransaction(gomock.Any()).Return(nil)
mockMonSession.EXPECT().AbortTransaction(gomock.Any()).Return(nil)
mockMonSession.EXPECT().EndSession(gomock.Any())
_, err = sess.WithTransaction(context.Background(), func(sessCtx context.Context) (any, error) {
// _ = sessCtx.StartTransaction()
// sessCtx.Client().Database("1")
// sessCtx.EndSession(context.Background())
return nil, nil
})
assert.Nil(t, err)
assert.NoError(t, sess.CommitTransaction(context.Background()))
assert.NoError(t, sess.AbortTransaction(context.Background()))
}
func TestModel_Aggregate(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
m := createModel(mt)
find := mtest.CreateCursorResponse(
1,
"DBName.CollectionName",
mtest.FirstBatch,
bson.D{
{Key: "name", Value: "John"},
})
getMore := mtest.CreateCursorResponse(
1,
"DBName.CollectionName",
mtest.NextBatch,
bson.D{
{Key: "name", Value: "Mary"},
})
killCursors := mtest.CreateCursorResponse(
0,
"DBName.CollectionName",
mtest.NextBatch)
mt.AddMockResponses(find, getMore, killCursors)
var result []any
err := m.Aggregate(context.Background(), &result, mongo.Pipeline{})
assert.Nil(t, err)
assert.Equal(t, 2, len(result))
assert.Equal(t, "John", result[0].(bson.D).Map()["name"])
assert.Equal(t, "Mary", result[1].(bson.D).Map()["name"])
triggerBreaker(m)
assert.Equal(t, errDummy, m.Aggregate(context.Background(), &result, mongo.Pipeline{}))
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockMonCollection := NewMockmonCollection(ctrl)
mockedMonClient := NewMockmonClient(ctrl)
cursor, err := mongo.NewCursorFromDocuments([]any{
bson.M{
"name": "John",
},
bson.M{
"name": "Mary",
},
}, nil, nil)
assert.NoError(t, err)
mockMonCollection.EXPECT().Aggregate(gomock.Any(), gomock.Any(), gomock.Any()).Return(cursor, nil)
m := newTestModel("foo", mockedMonClient, mockMonCollection, breaker.GetBreaker("test"))
var result []bson.M
err = m.Aggregate(context.Background(), &result, bson.D{})
assert.Nil(t, err)
assert.Equal(t, 2, len(result))
assert.Equal(t, "John", result[0]["name"])
assert.Equal(t, "Mary", result[1]["name"])
triggerBreaker(m)
assert.Equal(t, errDummy, m.Aggregate(context.Background(), &result, bson.D{}))
}
func TestModel_DeleteMany(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
m := createModel(mt)
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
val, err := m.DeleteMany(context.Background(), bson.D{})
assert.Nil(t, err)
assert.Equal(t, int64(1), val)
triggerBreaker(m)
_, err = m.DeleteMany(context.Background(), bson.D{})
assert.Equal(t, errDummy, err)
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockMonCollection := NewMockmonCollection(ctrl)
mockedMonClient := NewMockmonClient(ctrl)
mockMonCollection.EXPECT().DeleteMany(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.DeleteResult{}, nil)
m := newTestModel("foo", mockedMonClient, mockMonCollection, breaker.GetBreaker("test"))
_, err := m.DeleteMany(context.Background(), bson.D{})
assert.Nil(t, err)
triggerBreaker(m)
_, err = m.DeleteMany(context.Background(), bson.D{})
assert.Equal(t, errDummy, err)
}
func TestModel_DeleteOne(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
m := createModel(mt)
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
val, err := m.DeleteOne(context.Background(), bson.D{})
assert.Nil(t, err)
assert.Equal(t, int64(1), val)
triggerBreaker(m)
_, err = m.DeleteOne(context.Background(), bson.D{})
assert.Equal(t, errDummy, err)
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockMonCollection := NewMockmonCollection(ctrl)
mockedMonClient := NewMockmonClient(ctrl)
mockMonCollection.EXPECT().DeleteOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.DeleteResult{}, nil)
m := newTestModel("foo", mockedMonClient, mockMonCollection, breaker.GetBreaker("test"))
_, err := m.DeleteOne(context.Background(), bson.D{})
assert.Nil(t, err)
triggerBreaker(m)
_, err = m.DeleteOne(context.Background(), bson.D{})
assert.Equal(t, errDummy, err)
}
func TestModel_Find(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
m := createModel(mt)
find := mtest.CreateCursorResponse(
1,
"DBName.CollectionName",
mtest.FirstBatch,
bson.D{
{Key: "name", Value: "John"},
})
getMore := mtest.CreateCursorResponse(
1,
"DBName.CollectionName",
mtest.NextBatch,
bson.D{
{Key: "name", Value: "Mary"},
})
killCursors := mtest.CreateCursorResponse(
0,
"DBName.CollectionName",
mtest.NextBatch)
mt.AddMockResponses(find, getMore, killCursors)
var result []any
err := m.Find(context.Background(), &result, bson.D{})
assert.Nil(t, err)
assert.Equal(t, 2, len(result))
assert.Equal(t, "John", result[0].(bson.D).Map()["name"])
assert.Equal(t, "Mary", result[1].(bson.D).Map()["name"])
triggerBreaker(m)
assert.Equal(t, errDummy, m.Find(context.Background(), &result, bson.D{}))
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockMonCollection := NewMockmonCollection(ctrl)
mockedMonClient := NewMockmonClient(ctrl)
cursor, err := mongo.NewCursorFromDocuments([]any{
bson.M{
"name": "John",
},
bson.M{
"name": "Mary",
},
}, nil, nil)
assert.NoError(t, err)
mockMonCollection.EXPECT().Find(gomock.Any(), gomock.Any(), gomock.Any()).Return(cursor, nil)
m := newTestModel("foo", mockedMonClient, mockMonCollection, breaker.GetBreaker("test"))
var result []bson.M
err = m.Find(context.Background(), &result, bson.D{})
assert.Nil(t, err)
assert.Equal(t, 2, len(result))
assert.Equal(t, "John", result[0]["name"])
assert.Equal(t, "Mary", result[1]["name"])
triggerBreaker(m)
assert.Equal(t, errDummy, m.Find(context.Background(), &result, bson.D{}))
}
func TestModel_FindOne(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
m := createModel(mt)
find := mtest.CreateCursorResponse(
1,
"DBName.CollectionName",
mtest.FirstBatch,
bson.D{
{Key: "name", Value: "John"},
})
killCursors := mtest.CreateCursorResponse(
0,
"DBName.CollectionName",
mtest.NextBatch)
mt.AddMockResponses(find, killCursors)
var result bson.D
err := m.FindOne(context.Background(), &result, bson.D{})
assert.Nil(t, err)
assert.Equal(t, "John", result.Map()["name"])
triggerBreaker(m)
assert.Equal(t, errDummy, m.FindOne(context.Background(), &result, bson.D{}))
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockMonCollection := NewMockmonCollection(ctrl)
mockedMonClient := NewMockmonClient(ctrl)
mockMonCollection.EXPECT().FindOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(mongo.NewSingleResultFromDocument(bson.M{"name": "John"}, nil, nil))
m := newTestModel("foo", mockedMonClient, mockMonCollection, breaker.GetBreaker("test"))
var result bson.M
err := m.FindOne(context.Background(), &result, bson.D{})
assert.Nil(t, err)
assert.Equal(t, "John", result["name"])
triggerBreaker(m)
assert.Equal(t, errDummy, m.FindOne(context.Background(), &result, bson.D{}))
}
func TestModel_FindOneAndDelete(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
m := createModel(mt)
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "name", Value: "John"}}},
}...))
var result bson.D
err := m.FindOneAndDelete(context.Background(), &result, bson.D{})
assert.Nil(t, err)
assert.Equal(t, "John", result.Map()["name"])
triggerBreaker(m)
assert.Equal(t, errDummy, m.FindOneAndDelete(context.Background(), &result, bson.D{}))
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockMonCollection := NewMockmonCollection(ctrl)
mockedMonClient := NewMockmonClient(ctrl)
mockMonCollection.EXPECT().FindOneAndDelete(gomock.Any(), gomock.Any(), gomock.Any()).Return(mongo.NewSingleResultFromDocument(bson.M{"name": "John"}, nil, nil))
m := newTestModel("foo", mockedMonClient, mockMonCollection, breaker.GetBreaker("test"))
var result bson.M
err := m.FindOneAndDelete(context.Background(), &result, bson.M{})
assert.Nil(t, err)
assert.Equal(t, "John", result["name"])
triggerBreaker(m)
assert.Equal(t, errDummy, m.FindOneAndDelete(context.Background(), &result, bson.D{}))
}
func TestModel_FindOneAndReplace(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
m := createModel(mt)
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "name", Value: "John"}}},
}...))
var result bson.D
err := m.FindOneAndReplace(context.Background(), &result, bson.D{}, bson.D{
{Key: "name", Value: "Mary"},
})
assert.Nil(t, err)
assert.Equal(t, "John", result.Map()["name"])
triggerBreaker(m)
assert.Equal(t, errDummy, m.FindOneAndReplace(context.Background(), &result, bson.D{}, bson.D{
{Key: "name", Value: "Mary"},
}))
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockMonCollection := NewMockmonCollection(ctrl)
mockedMonClient := NewMockmonClient(ctrl)
mockMonCollection.EXPECT().FindOneAndReplace(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mongo.NewSingleResultFromDocument(bson.M{"name": "John"}, nil, nil))
m := newTestModel("foo", mockedMonClient, mockMonCollection, breaker.GetBreaker("test"))
var result bson.M
err := m.FindOneAndReplace(context.Background(), &result, bson.D{}, bson.D{
{Key: "name", Value: "Mary"},
})
assert.Nil(t, err)
assert.Equal(t, "John", result["name"])
triggerBreaker(m)
assert.Equal(t, errDummy, m.FindOneAndReplace(context.Background(), &result, bson.D{}, bson.D{
{Key: "name", Value: "Mary"},
}))
}
func TestModel_FindOneAndUpdate(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
m := createModel(mt)
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "name", Value: "John"}}},
}...))
var result bson.D
err := m.FindOneAndUpdate(context.Background(), &result, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
})
assert.Nil(t, err)
assert.Equal(t, "John", result.Map()["name"])
triggerBreaker(m)
assert.Equal(t, errDummy, m.FindOneAndUpdate(context.Background(), &result, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
}))
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockMonCollection := NewMockmonCollection(ctrl)
mockedMonClient := NewMockmonClient(ctrl)
mockMonCollection.EXPECT().FindOneAndUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mongo.NewSingleResultFromDocument(bson.M{"name": "John"}, nil, nil))
m := newTestModel("foo", mockedMonClient, mockMonCollection, breaker.GetBreaker("test"))
var result bson.M
err := m.FindOneAndUpdate(context.Background(), &result, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
})
}
assert.Nil(t, err)
assert.Equal(t, "John", result["name"])
triggerBreaker(m)
assert.Equal(t, errDummy, m.FindOneAndUpdate(context.Background(), &result, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
}))
func createModel(mt *mtest.T) *Model {
Inject(mt.Name(), mt.Client)
return MustNewModel(mt.Name(), mt.DB.Name(), mt.Coll.Name())
}
func triggerBreaker(m *Model) {
m.Collection.(*decoratedCollection).brk = new(dropBreaker)
}
func TestMustNewModel(t *testing.T) {
Inject("mongodb://localhost:27017", &mongo.Client{})
MustNewModel("mongodb://localhost:27017", "test", "test")
}
func TestNewModel(t *testing.T) {
NewModel("mongo://localhost:27018", "test", "test")
Inject("mongodb://localhost:27018", &mongo.Client{})
NewModel("mongodb://localhost:27018", "test", "test")
}
func Test_newModel(t *testing.T) {
Inject("mongodb://localhost:27019", &mongo.Client{})
newModel("mongodb://localhost:27019", nil, nil, nil)
}
func Test_mockMonClient_StartSession(t *testing.T) {
md := drivertest.NewMockDeployment()
opts := options.Client()
opts.Deployment = md
client, err := mongo.Connect(opts)
assert.Nil(t, err)
m := wrappedMonClient{
c: client,
}
_, err = m.StartSession()
assert.Nil(t, err)
}
func newTestModel(name string, cli monClient, coll monCollection, brk breaker.Breaker,
opts ...Option) *Model {
return &Model{
name: name,
Collection: newTestCollection(coll, breaker.GetBreaker("localhost")),
cli: cli,
brk: brk,
opts: opts,
}
}

View File

@@ -5,9 +5,8 @@ import (
"time"
"github.com/zeromicro/go-zero/core/syncx"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
mopt "go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo/options"
)
const defaultTimeout = time.Second * 3
@@ -20,16 +19,16 @@ var (
type (
// Option defines the method to customize a mongo model.
Option func(opts *options)
Option func(opts *clientOptions)
// TypeCodec is a struct that stores specific type Encoder/Decoder.
TypeCodec struct {
ValueType reflect.Type
Encoder bsoncodec.ValueEncoder
Decoder bsoncodec.ValueDecoder
Encoder bson.ValueEncoder
Decoder bson.ValueDecoder
}
options = mopt.ClientOptions
clientOptions = options.ClientOptions
)
// DisableLog disables logging of mongo commands, includes info and slow logs.
@@ -50,14 +49,14 @@ func SetSlowThreshold(threshold time.Duration) {
// WithTimeout set the mon client operation timeout.
func WithTimeout(timeout time.Duration) Option {
return func(opts *options) {
return func(opts *clientOptions) {
opts.SetTimeout(timeout)
}
}
// WithTypeCodec registers TypeCodecs to convert custom types.
func WithTypeCodec(typeCodecs ...TypeCodec) Option {
return func(opts *options) {
return func(opts *clientOptions) {
registry := bson.NewRegistry()
for _, v := range typeCodecs {
registry.RegisterTypeEncoder(v.ValueType, v.Encoder)
@@ -68,7 +67,7 @@ func WithTypeCodec(typeCodecs ...TypeCodec) Option {
}
func defaultTimeoutOption() Option {
return func(opts *options) {
return func(opts *clientOptions) {
opts.SetTimeout(defaultTimeout)
}
}

View File

@@ -7,9 +7,8 @@ import (
"time"
"github.com/stretchr/testify/assert"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsonrw"
mopt "go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo/options"
)
func TestSetSlowThreshold(t *testing.T) {
@@ -19,13 +18,13 @@ func TestSetSlowThreshold(t *testing.T) {
}
func Test_defaultTimeoutOption(t *testing.T) {
opts := mopt.Client()
opts := options.Client()
defaultTimeoutOption()(opts)
assert.Equal(t, defaultTimeout, *opts.Timeout)
}
func TestWithTimeout(t *testing.T) {
opts := mopt.Client()
opts := options.Client()
WithTimeout(time.Second)(opts)
assert.Equal(t, time.Second, *opts.Timeout)
}
@@ -57,10 +56,11 @@ func TestDisableInfoLog(t *testing.T) {
}
func TestWithRegistryForTimestampRegisterType(t *testing.T) {
opts := mopt.Client()
opts := options.Client()
// mongoDateTimeEncoder allow user convert time.Time to primitive.DateTime.
var mongoDateTimeEncoder bsoncodec.ValueEncoderFunc = func(ect bsoncodec.EncodeContext, w bsonrw.ValueWriter, value reflect.Value) error {
var mongoDateTimeEncoder bson.ValueEncoderFunc = func(ect bson.EncodeContext,
w bson.ValueWriter, value reflect.Value) error {
// Use reflect, determine if it can be converted to time.Time.
dec, ok := value.Interface().(time.Time)
if !ok {
@@ -70,7 +70,8 @@ func TestWithRegistryForTimestampRegisterType(t *testing.T) {
}
// mongoDateTimeEncoder allow user convert primitive.DateTime to time.Time.
var mongoDateTimeDecoder bsoncodec.ValueDecoderFunc = func(ect bsoncodec.DecodeContext, r bsonrw.ValueReader, value reflect.Value) error {
var mongoDateTimeDecoder bson.ValueDecoderFunc = func(ect bson.DecodeContext,
r bson.ValueReader, value reflect.Value) error {
primTime, err := r.ReadDateTime()
if err != nil {
return fmt.Errorf("error reading primitive.DateTime from ValueReader: %v", err)

View File

@@ -5,7 +5,7 @@ import (
"github.com/zeromicro/go-zero/core/errorx"
"github.com/zeromicro/go-zero/core/trace"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
oteltrace "go.opentelemetry.io/otel/trace"

View File

@@ -8,8 +8,8 @@ import (
"github.com/zeromicro/go-zero/core/stores/mon"
"github.com/zeromicro/go-zero/core/stores/redis"
"github.com/zeromicro/go-zero/core/syncx"
"go.mongodb.org/mongo-driver/mongo"
mopt "go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.mongodb.org/mongo-driver/v2/mongo/options"
)
var (
@@ -78,7 +78,7 @@ func (mm *Model) DelCache(ctx context.Context, keys ...string) error {
// DeleteOne deletes the document with given filter, and remove it from cache.
func (mm *Model) DeleteOne(ctx context.Context, key string, filter any,
opts ...*mopt.DeleteOptions) (int64, error) {
opts ...options.Lister[options.DeleteOneOptions]) (int64, error) {
val, err := mm.Model.DeleteOne(ctx, filter, opts...)
if err != nil {
return 0, err
@@ -93,13 +93,13 @@ func (mm *Model) DeleteOne(ctx context.Context, key string, filter any,
// DeleteOneNoCache deletes the document with given filter.
func (mm *Model) DeleteOneNoCache(ctx context.Context, filter any,
opts ...*mopt.DeleteOptions) (int64, error) {
opts ...options.Lister[options.DeleteOneOptions]) (int64, error) {
return mm.Model.DeleteOne(ctx, filter, opts...)
}
// FindOne unmarshals a record into v with given key and query.
func (mm *Model) FindOne(ctx context.Context, key string, v, filter any,
opts ...*mopt.FindOneOptions) error {
opts ...options.Lister[options.FindOneOptions]) error {
return mm.cache.TakeCtx(ctx, v, key, func(v any) error {
return mm.Model.FindOne(ctx, v, filter, opts...)
})
@@ -107,13 +107,13 @@ func (mm *Model) FindOne(ctx context.Context, key string, v, filter any,
// FindOneNoCache unmarshals a record into v with query, without cache.
func (mm *Model) FindOneNoCache(ctx context.Context, v, filter any,
opts ...*mopt.FindOneOptions) error {
opts ...options.Lister[options.FindOneOptions]) error {
return mm.Model.FindOne(ctx, v, filter, opts...)
}
// FindOneAndDelete deletes the document with given filter, and unmarshals it into v.
func (mm *Model) FindOneAndDelete(ctx context.Context, key string, v, filter any,
opts ...*mopt.FindOneAndDeleteOptions) error {
opts ...options.Lister[options.FindOneAndDeleteOptions]) error {
if err := mm.Model.FindOneAndDelete(ctx, v, filter, opts...); err != nil {
return err
}
@@ -123,13 +123,13 @@ func (mm *Model) FindOneAndDelete(ctx context.Context, key string, v, filter any
// FindOneAndDeleteNoCache deletes the document with given filter, and unmarshals it into v.
func (mm *Model) FindOneAndDeleteNoCache(ctx context.Context, v, filter any,
opts ...*mopt.FindOneAndDeleteOptions) error {
opts ...options.Lister[options.FindOneAndDeleteOptions]) error {
return mm.Model.FindOneAndDelete(ctx, v, filter, opts...)
}
// FindOneAndReplace replaces the document with given filter with replacement, and unmarshals it into v.
func (mm *Model) FindOneAndReplace(ctx context.Context, key string, v, filter any,
replacement any, opts ...*mopt.FindOneAndReplaceOptions) error {
replacement any, opts ...options.Lister[options.FindOneAndReplaceOptions]) error {
if err := mm.Model.FindOneAndReplace(ctx, v, filter, replacement, opts...); err != nil {
return err
}
@@ -139,13 +139,13 @@ func (mm *Model) FindOneAndReplace(ctx context.Context, key string, v, filter an
// FindOneAndReplaceNoCache replaces the document with given filter with replacement, and unmarshals it into v.
func (mm *Model) FindOneAndReplaceNoCache(ctx context.Context, v, filter any,
replacement any, opts ...*mopt.FindOneAndReplaceOptions) error {
replacement any, opts ...options.Lister[options.FindOneAndReplaceOptions]) error {
return mm.Model.FindOneAndReplace(ctx, v, filter, replacement, opts...)
}
// FindOneAndUpdate updates the document with given filter with update, and unmarshals it into v.
func (mm *Model) FindOneAndUpdate(ctx context.Context, key string, v, filter any,
update any, opts ...*mopt.FindOneAndUpdateOptions) error {
update any, opts ...options.Lister[options.FindOneAndUpdateOptions]) error {
if err := mm.Model.FindOneAndUpdate(ctx, v, filter, update, opts...); err != nil {
return err
}
@@ -155,7 +155,7 @@ func (mm *Model) FindOneAndUpdate(ctx context.Context, key string, v, filter any
// FindOneAndUpdateNoCache updates the document with given filter with update, and unmarshals it into v.
func (mm *Model) FindOneAndUpdateNoCache(ctx context.Context, v, filter any,
update any, opts ...*mopt.FindOneAndUpdateOptions) error {
update any, opts ...options.Lister[options.FindOneAndUpdateOptions]) error {
return mm.Model.FindOneAndUpdate(ctx, v, filter, update, opts...)
}
@@ -166,7 +166,7 @@ func (mm *Model) GetCache(key string, v any) error {
// InsertOne inserts a single document into the collection, and remove the cache placeholder.
func (mm *Model) InsertOne(ctx context.Context, key string, document any,
opts ...*mopt.InsertOneOptions) (*mongo.InsertOneResult, error) {
opts ...options.Lister[options.InsertOneOptions]) (*mongo.InsertOneResult, error) {
res, err := mm.Model.InsertOne(ctx, document, opts...)
if err != nil {
return nil, err
@@ -181,13 +181,13 @@ func (mm *Model) InsertOne(ctx context.Context, key string, document any,
// InsertOneNoCache inserts a single document into the collection.
func (mm *Model) InsertOneNoCache(ctx context.Context, document any,
opts ...*mopt.InsertOneOptions) (*mongo.InsertOneResult, error) {
opts ...options.Lister[options.InsertOneOptions]) (*mongo.InsertOneResult, error) {
return mm.Model.InsertOne(ctx, document, opts...)
}
// ReplaceOne replaces a single document in the collection, and remove the cache.
func (mm *Model) ReplaceOne(ctx context.Context, key string, filter, replacement any,
opts ...*mopt.ReplaceOptions) (*mongo.UpdateResult, error) {
opts ...options.Lister[options.ReplaceOptions]) (*mongo.UpdateResult, error) {
res, err := mm.Model.ReplaceOne(ctx, filter, replacement, opts...)
if err != nil {
return nil, err
@@ -202,7 +202,7 @@ func (mm *Model) ReplaceOne(ctx context.Context, key string, filter, replacement
// ReplaceOneNoCache replaces a single document in the collection.
func (mm *Model) ReplaceOneNoCache(ctx context.Context, filter, replacement any,
opts ...*mopt.ReplaceOptions) (*mongo.UpdateResult, error) {
opts ...options.Lister[options.ReplaceOptions]) (*mongo.UpdateResult, error) {
return mm.Model.ReplaceOne(ctx, filter, replacement, opts...)
}
@@ -213,7 +213,7 @@ func (mm *Model) SetCache(key string, v any) error {
// UpdateByID updates the document with given id with update, and remove the cache.
func (mm *Model) UpdateByID(ctx context.Context, key string, id, update any,
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error) {
opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error) {
res, err := mm.Model.UpdateByID(ctx, id, update, opts...)
if err != nil {
return nil, err
@@ -228,13 +228,13 @@ func (mm *Model) UpdateByID(ctx context.Context, key string, id, update any,
// UpdateByIDNoCache updates the document with given id with update.
func (mm *Model) UpdateByIDNoCache(ctx context.Context, id, update any,
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error) {
opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error) {
return mm.Model.UpdateByID(ctx, id, update, opts...)
}
// UpdateMany updates the documents that match filter with update, and remove the cache.
func (mm *Model) UpdateMany(ctx context.Context, keys []string, filter, update any,
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error) {
opts ...options.Lister[options.UpdateManyOptions]) (*mongo.UpdateResult, error) {
res, err := mm.Model.UpdateMany(ctx, filter, update, opts...)
if err != nil {
return nil, err
@@ -249,13 +249,13 @@ func (mm *Model) UpdateMany(ctx context.Context, keys []string, filter, update a
// UpdateManyNoCache updates the documents that match filter with update.
func (mm *Model) UpdateManyNoCache(ctx context.Context, filter, update any,
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error) {
opts ...options.Lister[options.UpdateManyOptions]) (*mongo.UpdateResult, error) {
return mm.Model.UpdateMany(ctx, filter, update, opts...)
}
// UpdateOne updates the first document that matches filter with update, and remove the cache.
func (mm *Model) UpdateOne(ctx context.Context, key string, filter, update any,
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error) {
opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error) {
res, err := mm.Model.UpdateOne(ctx, filter, update, opts...)
if err != nil {
return nil, err
@@ -270,6 +270,6 @@ func (mm *Model) UpdateOne(ctx context.Context, key string, filter, update any,
// UpdateOneNoCache updates the first document that matches filter with update.
func (mm *Model) UpdateOneNoCache(ctx context.Context, filter, update any,
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error) {
opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error) {
return mm.Model.UpdateOne(ctx, filter, update, opts...)
}

View File

@@ -8,506 +8,519 @@ import (
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/cache"
"github.com/zeromicro/go-zero/core/stores/mon"
"github.com/zeromicro/go-zero/core/stores/redis"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.uber.org/mock/gomock"
)
func TestNewModel(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
_, err := newModel("foo", mt.DB.Name(), mt.Coll.Name(), nil)
assert.NotNil(mt, err)
func TestMustNewModel(t *testing.T) {
s, err := miniredis.Run()
assert.Nil(t, err)
original := logx.ExitOnFatal.True()
logx.ExitOnFatal.Set(false)
defer logx.ExitOnFatal.Set(original)
assert.Panics(t, func() {
MustNewModel("foo", "db", "collectino", cache.CacheConf{
cache.NodeConf{
RedisConf: redis.RedisConf{
Host: s.Addr(),
Type: redis.NodeType,
},
Weight: 100,
}})
})
}
func TestMustNewNodeModel(t *testing.T) {
s, err := miniredis.Run()
assert.Nil(t, err)
original := logx.ExitOnFatal.True()
logx.ExitOnFatal.Set(false)
defer logx.ExitOnFatal.Set(original)
assert.Panics(t, func() {
MustNewNodeModel("foo", "db", "collectino", redis.New(s.Addr()))
})
}
func TestNewModel(t *testing.T) {
s, err := miniredis.Run()
assert.Nil(t, err)
_, err = NewModel("foo", "db", "coll", cache.CacheConf{
cache.NodeConf{
RedisConf: redis.RedisConf{
Host: s.Addr(),
Type: redis.NodeType,
},
Weight: 100,
},
})
assert.Error(t, err)
}
func TestNewNodeModel(t *testing.T) {
_, err := NewNodeModel("foo", "db", "coll", nil)
assert.NotNil(t, err)
}
func TestNewModelWithCache(t *testing.T) {
_, err := NewModelWithCache("foo", "db", "coll", nil)
assert.NotNil(t, err)
}
func Test_newModel(t *testing.T) {
mon.Inject("mongodb://localhost:27018", &mongo.Client{})
model, err := newModel("mongodb://localhost:27018", "db", "collection", nil)
assert.Nil(t, err)
assert.NotNil(t, model)
}
func TestModel_DelCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
m := createModel(t, mt)
assert.Nil(t, m.cache.Set("foo", "bar"))
assert.Nil(t, m.cache.Set("bar", "baz"))
assert.Nil(t, m.DelCache(context.Background(), "foo", "bar"))
var v string
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
assert.True(t, m.cache.IsNotFound(m.cache.Get("bar", &v)))
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := mon.NewMockCollection(ctrl)
m := createModel(t, mockCollection)
assert.Nil(t, m.cache.Set("foo", "bar"))
assert.Nil(t, m.cache.Set("bar", "baz"))
assert.Nil(t, m.DelCache(context.Background(), "foo", "bar"))
var v string
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
assert.True(t, m.cache.IsNotFound(m.cache.Get("bar", &v)))
}
func TestModel_DeleteOne(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
m := createModel(t, mt)
assert.Nil(t, m.cache.Set("foo", "bar"))
val, err := m.DeleteOne(context.Background(), "foo", bson.D{{Key: "foo", Value: "bar"}})
assert.Nil(t, err)
assert.Equal(t, int64(1), val)
var v string
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
_, err = m.DeleteOne(context.Background(), "foo", bson.D{{Key: "foo", Value: "bar"}})
assert.NotNil(t, err)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := mon.NewMockCollection(ctrl)
mockCollection.EXPECT().DeleteOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.DeleteResult{DeletedCount: 1}, nil)
m := createModel(t, mockCollection)
assert.Nil(t, m.cache.Set("foo", "bar"))
val, err := m.DeleteOne(context.Background(), "foo", bson.D{{Key: "foo", Value: "bar"}})
assert.Nil(t, err)
assert.Equal(t, int64(1), val)
var v string
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
mockCollection.EXPECT().DeleteOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.DeleteResult{}, errMocked)
_, err = m.DeleteOne(context.Background(), "foo", bson.D{{Key: "foo", Value: "bar"}})
assert.NotNil(t, err)
m.cache = mockedCache{m.cache}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
_, err = m.DeleteOne(context.Background(), "foo", bson.D{{Key: "foo", Value: "bar"}})
assert.Equal(t, errMocked, err)
})
m.cache = mockedCache{m.cache}
mockCollection.EXPECT().DeleteOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.DeleteResult{}, nil)
_, err = m.DeleteOne(context.Background(), "foo", bson.D{{Key: "foo", Value: "bar"}})
assert.Equal(t, errMocked, err)
}
func TestModel_DeleteOneNoCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
m := createModel(t, mt)
assert.Nil(t, m.cache.Set("foo", "bar"))
val, err := m.DeleteOneNoCache(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
assert.Nil(t, err)
assert.Equal(t, int64(1), val)
var v string
assert.Nil(t, m.cache.Get("foo", &v))
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := mon.NewMockCollection(ctrl)
mockCollection.EXPECT().DeleteOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.DeleteResult{DeletedCount: 1}, nil)
m := createModel(t, mockCollection)
assert.Nil(t, m.cache.Set("foo", "bar"))
val, err := m.DeleteOneNoCache(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
assert.Nil(t, err)
assert.Equal(t, int64(1), val)
var v string
assert.Nil(t, m.cache.Get("foo", &v))
}
func TestModel_FindOne(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
resp := mtest.CreateCursorResponse(
1,
"DBName.CollectionName",
mtest.FirstBatch,
bson.D{
{Key: "foo", Value: "bar"},
})
mt.AddMockResponses(resp)
m := createModel(t, mt)
var v struct {
Foo string `bson:"foo"`
}
assert.Nil(t, m.FindOne(context.Background(), "foo", &v, bson.D{}))
assert.Equal(t, "bar", v.Foo)
assert.Nil(t, m.cache.Set("foo", "bar"))
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := mon.NewMockCollection(ctrl)
mockCollection.EXPECT().FindOne(gomock.Any(), gomock.Any(), gomock.Any()).
Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, nil), nil)
m := createModel(t, mockCollection)
var v struct {
Foo string `bson:"foo"`
}
assert.Nil(t, m.FindOne(context.Background(), "foo", &v, bson.D{}))
assert.Equal(t, "bar", v.Foo)
assert.Nil(t, m.cache.Set("foo", "bar"))
}
func TestModel_FindOneNoCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
resp := mtest.CreateCursorResponse(
1,
"DBName.CollectionName",
mtest.FirstBatch,
bson.D{
{Key: "foo", Value: "bar"},
})
mt.AddMockResponses(resp)
m := createModel(t, mt)
v := struct {
Foo string `bson:"foo"`
}{}
assert.Nil(t, m.FindOneNoCache(context.Background(), &v, bson.D{}))
assert.Equal(t, "bar", v.Foo)
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := mon.NewMockCollection(ctrl)
mockCollection.EXPECT().FindOne(gomock.Any(), gomock.Any(), gomock.Any()).
Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, nil), nil)
m := createModel(t, mockCollection)
v := struct {
Foo string `bson:"foo"`
}{}
assert.Nil(t, m.FindOneNoCache(context.Background(), &v, bson.D{}))
assert.Equal(t, "bar", v.Foo)
}
func TestModel_FindOneAndDelete(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
assert.Nil(t, m.cache.Set("foo", "bar"))
v := struct {
Foo string `bson:"foo"`
}{}
assert.Nil(t, m.FindOneAndDelete(context.Background(), "foo", &v, bson.D{}))
assert.Equal(t, "bar", v.Foo)
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
assert.NotNil(t, m.FindOneAndDelete(context.Background(), "foo", &v, bson.D{}))
m.cache = mockedCache{m.cache}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
assert.Equal(t, errMocked, m.FindOneAndDelete(context.Background(), "foo", &v, bson.D{}))
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := mon.NewMockCollection(ctrl)
mockCollection.EXPECT().FindOneAndDelete(gomock.Any(), gomock.Any(), gomock.Any()).
Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, bson.NewRegistry()), nil)
m := createModel(t, mockCollection)
assert.Nil(t, m.cache.Set("foo", "bar"))
v := struct {
Foo string `bson:"foo"`
}{}
assert.Nil(t, m.FindOneAndDelete(context.Background(), "foo", &v, bson.D{}))
assert.Equal(t, "bar", v.Foo)
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
mockCollection.EXPECT().FindOneAndDelete(gomock.Any(), gomock.Any(), gomock.Any()).Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, bson.NewRegistry()), errMocked)
assert.NotNil(t, m.FindOneAndDelete(context.Background(), "foo", &v, bson.D{}))
m.cache = mockedCache{m.cache}
mockCollection.EXPECT().FindOneAndDelete(gomock.Any(), gomock.Any(), gomock.Any()).Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, bson.NewRegistry()), nil)
assert.Equal(t, errMocked, m.FindOneAndDelete(context.Background(), "foo", &v, bson.D{}))
}
func TestModel_FindOneAndDeleteNoCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
v := struct {
Foo string `bson:"foo"`
}{}
assert.Nil(t, m.FindOneAndDeleteNoCache(context.Background(), &v, bson.D{}))
assert.Equal(t, "bar", v.Foo)
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := mon.NewMockCollection(ctrl)
mockCollection.EXPECT().FindOneAndDelete(gomock.Any(), gomock.Any(), gomock.Any()).
Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, nil), nil)
m := createModel(t, mockCollection)
v := struct {
Foo string `bson:"foo"`
}{}
assert.Nil(t, m.FindOneAndDeleteNoCache(context.Background(), &v, bson.D{}))
assert.Equal(t, "bar", v.Foo)
}
func TestModel_FindOneAndReplace(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
assert.Nil(t, m.cache.Set("foo", "bar"))
v := struct {
Foo string `bson:"foo"`
}{}
assert.Nil(t, m.FindOneAndReplace(context.Background(), "foo", &v, bson.D{}, bson.D{
{Key: "name", Value: "Mary"},
}))
assert.Equal(t, "bar", v.Foo)
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
assert.NotNil(t, m.FindOneAndReplace(context.Background(), "foo", &v, bson.D{}, bson.D{
{Key: "name", Value: "Mary"},
}))
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := mon.NewMockCollection(ctrl)
m := createModel(t, mockCollection)
assert.Nil(t, m.cache.Set("foo", "bar"))
v := struct {
Foo string `bson:"foo"`
}{}
mockCollection.EXPECT().FindOneAndReplace(gomock.Any(), gomock.Any(), gomock.Any()).
Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, nil), nil)
assert.Nil(t, m.FindOneAndReplace(context.Background(), "foo", &v, bson.D{}, bson.D{
{Key: "name", Value: "Mary"},
}))
assert.Equal(t, "bar", v.Foo)
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
mockCollection.EXPECT().FindOneAndReplace(gomock.Any(), gomock.Any(), gomock.Any()).
Return(mongo.NewSingleResultFromDocument(bson.M{"name": "Mary"}, nil, nil), errMocked)
assert.NotNil(t, m.FindOneAndReplace(context.Background(), "foo", &v, bson.D{}, bson.D{
{Key: "name", Value: "Mary"},
}))
m.cache = mockedCache{m.cache}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
assert.Equal(t, errMocked, m.FindOneAndReplace(context.Background(), "foo", &v, bson.D{}, bson.D{
{Key: "name", Value: "Mary"},
}))
})
m.cache = mockedCache{m.cache}
mockCollection.EXPECT().FindOneAndReplace(gomock.Any(), gomock.Any(), gomock.Any()).
Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, nil), nil)
assert.Equal(t, errMocked, m.FindOneAndReplace(context.Background(), "foo", &v, bson.D{}, bson.D{
{Key: "name", Value: "Mary"},
}))
}
func TestModel_FindOneAndReplaceNoCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
v := struct {
Foo string `bson:"foo"`
}{}
assert.Nil(t, m.FindOneAndReplaceNoCache(context.Background(), &v, bson.D{}, bson.D{
{Key: "name", Value: "Mary"},
}))
assert.Equal(t, "bar", v.Foo)
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := mon.NewMockCollection(ctrl)
m := createModel(t, mockCollection)
v := struct {
Foo string `bson:"foo"`
}{}
mockCollection.EXPECT().FindOneAndReplace(gomock.Any(), gomock.Any(), gomock.Any()).Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, nil), nil)
assert.Nil(t, m.FindOneAndReplaceNoCache(context.Background(), &v, bson.D{}, bson.D{
{Key: "name", Value: "Mary"},
}))
assert.Equal(t, "bar", v.Foo)
}
func TestModel_FindOneAndUpdate(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
assert.Nil(t, m.cache.Set("foo", "bar"))
v := struct {
Foo string `bson:"foo"`
}{}
assert.Nil(t, m.FindOneAndUpdate(context.Background(), "foo", &v, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
}))
assert.Equal(t, "bar", v.Foo)
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
assert.NotNil(t, m.FindOneAndUpdate(context.Background(), "foo", &v, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
}))
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := mon.NewMockCollection(ctrl)
m := createModel(t, mockCollection)
assert.Nil(t, m.cache.Set("foo", "bar"))
v := struct {
Foo string `bson:"foo"`
}{}
mockCollection.EXPECT().FindOneAndUpdate(gomock.Any(), gomock.Any(), gomock.Any()).
Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, nil), nil)
assert.Nil(t, m.FindOneAndUpdate(context.Background(), "foo", &v, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
}))
assert.Equal(t, "bar", v.Foo)
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
mockCollection.EXPECT().FindOneAndUpdate(gomock.Any(), gomock.Any(), gomock.Any()).
Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, nil), errMocked)
assert.NotNil(t, m.FindOneAndUpdate(context.Background(), "foo", &v, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
}))
m.cache = mockedCache{m.cache}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
assert.Equal(t, errMocked, m.FindOneAndUpdate(context.Background(), "foo", &v, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
}))
})
m.cache = mockedCache{m.cache}
mockCollection.EXPECT().FindOneAndUpdate(gomock.Any(), gomock.Any(), gomock.Any()).
Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, nil), nil)
assert.Equal(t, errMocked, m.FindOneAndUpdate(context.Background(), "foo", &v, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
}))
}
func TestModel_FindOneAndUpdateNoCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
v := struct {
Foo string `bson:"foo"`
}{}
assert.Nil(t, m.FindOneAndUpdateNoCache(context.Background(), &v, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
}))
assert.Equal(t, "bar", v.Foo)
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := mon.NewMockCollection(ctrl)
m := createModel(t, mockCollection)
v := struct {
Foo string `bson:"foo"`
}{}
mockCollection.EXPECT().FindOneAndUpdate(gomock.Any(), gomock.Any(), gomock.Any()).
Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, nil), nil)
assert.Nil(t, m.FindOneAndUpdateNoCache(context.Background(), &v, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
}))
assert.Equal(t, "bar", v.Foo)
}
func TestModel_GetCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
m := createModel(t, mt)
assert.NotNil(t, m.cache)
assert.Nil(t, m.cache.Set("foo", "bar"))
var s string
assert.Nil(t, m.cache.Get("foo", &s))
assert.Equal(t, "bar", s)
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := mon.NewMockCollection(ctrl)
m := createModel(t, mockCollection)
assert.NotNil(t, m.cache)
assert.Nil(t, m.cache.Set("foo", "bar"))
var s string
assert.Nil(t, m.cache.Get("foo", &s))
assert.Equal(t, "bar", s)
}
func TestModel_InsertOne(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
assert.Nil(t, m.cache.Set("foo", "bar"))
resp, err := m.InsertOne(context.Background(), "foo", bson.D{
{Key: "name", Value: "Mary"},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
var v string
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
_, err = m.InsertOne(context.Background(), "foo", bson.D{
{Key: "name", Value: "Mary"},
})
assert.NotNil(t, err)
m.cache = mockedCache{m.cache}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
_, err = m.InsertOne(context.Background(), "foo", bson.D{
{Key: "name", Value: "Mary"},
})
assert.Equal(t, errMocked, err)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := mon.NewMockCollection(ctrl)
m := createModel(t, mockCollection)
assert.Nil(t, m.cache.Set("foo", "bar"))
mockCollection.EXPECT().InsertOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.InsertOneResult{}, nil)
resp, err := m.InsertOne(context.Background(), "foo", bson.D{
{Key: "name", Value: "Mary"},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
var v string
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
mockCollection.EXPECT().InsertOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.InsertOneResult{}, errMocked)
_, err = m.InsertOne(context.Background(), "foo", bson.D{
{Key: "name", Value: "Mary"},
})
assert.NotNil(t, err)
m.cache = mockedCache{m.cache}
mockCollection.EXPECT().InsertOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.InsertOneResult{}, nil)
_, err = m.InsertOne(context.Background(), "foo", bson.D{
{Key: "name", Value: "Mary"},
})
assert.Equal(t, errMocked, err)
}
func TestModel_InsertOneNoCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
resp, err := m.InsertOneNoCache(context.Background(), bson.D{
{Key: "name", Value: "Mary"},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := mon.NewMockCollection(ctrl)
m := createModel(t, mockCollection)
mockCollection.EXPECT().InsertOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.InsertOneResult{}, nil)
resp, err := m.InsertOneNoCache(context.Background(), bson.D{
{Key: "name", Value: "Mary"},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
}
func TestModel_ReplaceOne(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
assert.Nil(t, m.cache.Set("foo", "bar"))
resp, err := m.ReplaceOne(context.Background(), "foo", bson.D{}, bson.D{
{Key: "foo", Value: "baz"},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
var v string
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
_, err = m.ReplaceOne(context.Background(), "foo", bson.D{}, bson.D{
{Key: "foo", Value: "baz"},
})
assert.NotNil(t, err)
m.cache = mockedCache{m.cache}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
_, err = m.ReplaceOne(context.Background(), "foo", bson.D{}, bson.D{
{Key: "foo", Value: "baz"},
})
assert.Equal(t, errMocked, err)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := mon.NewMockCollection(ctrl)
m := createModel(t, mockCollection)
assert.Nil(t, m.cache.Set("foo", "bar"))
mockCollection.EXPECT().ReplaceOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
resp, err := m.ReplaceOne(context.Background(), "foo", bson.D{}, bson.D{
{Key: "foo", Value: "baz"},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
var v string
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
mockCollection.EXPECT().ReplaceOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, errMocked)
_, err = m.ReplaceOne(context.Background(), "foo", bson.D{}, bson.D{
{Key: "foo", Value: "baz"},
})
assert.NotNil(t, err)
m.cache = mockedCache{m.cache}
mockCollection.EXPECT().ReplaceOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
_, err = m.ReplaceOne(context.Background(), "foo", bson.D{}, bson.D{
{Key: "foo", Value: "baz"},
})
assert.Equal(t, errMocked, err)
}
func TestModel_ReplaceOneNoCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
resp, err := m.ReplaceOneNoCache(context.Background(), bson.D{}, bson.D{
{Key: "foo", Value: "baz"},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := mon.NewMockCollection(ctrl)
m := createModel(t, mockCollection)
mockCollection.EXPECT().ReplaceOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
resp, err := m.ReplaceOneNoCache(context.Background(), bson.D{}, bson.D{
{Key: "foo", Value: "baz"},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
}
func TestModel_SetCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
m := createModel(t, mt)
assert.Nil(t, m.SetCache("foo", "bar"))
var v string
assert.Nil(t, m.GetCache("foo", &v))
assert.Equal(t, "bar", v)
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := mon.NewMockCollection(ctrl)
m := createModel(t, mockCollection)
assert.Nil(t, m.SetCache("foo", "bar"))
var v string
assert.Nil(t, m.GetCache("foo", &v))
assert.Equal(t, "bar", v)
}
func TestModel_UpdateByID(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
assert.Nil(t, m.cache.Set("foo", "bar"))
resp, err := m.UpdateByID(context.Background(), "foo", bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
var v string
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
_, err = m.UpdateByID(context.Background(), "foo", bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.NotNil(t, err)
m.cache = mockedCache{m.cache}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
_, err = m.UpdateByID(context.Background(), "foo", bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Equal(t, errMocked, err)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := mon.NewMockCollection(ctrl)
m := createModel(t, mockCollection)
assert.Nil(t, m.cache.Set("foo", "bar"))
mockCollection.EXPECT().UpdateByID(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
resp, err := m.UpdateByID(context.Background(), "foo", bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
var v string
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
mockCollection.EXPECT().UpdateByID(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, errMocked)
_, err = m.UpdateByID(context.Background(), "foo", bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.NotNil(t, err)
m.cache = mockedCache{m.cache}
mockCollection.EXPECT().UpdateByID(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
_, err = m.UpdateByID(context.Background(), "foo", bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Equal(t, errMocked, err)
}
func TestModel_UpdateByIDNoCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
resp, err := m.UpdateByIDNoCache(context.Background(), bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := mon.NewMockCollection(ctrl)
m := createModel(t, mockCollection)
mockCollection.EXPECT().UpdateByID(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
resp, err := m.UpdateByIDNoCache(context.Background(), bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
}
func TestModel_UpdateMany(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
assert.Nil(t, m.cache.Set("foo", "bar"))
assert.Nil(t, m.cache.Set("bar", "baz"))
resp, err := m.UpdateMany(context.Background(), []string{"foo", "bar"}, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
var v string
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
assert.True(t, m.cache.IsNotFound(m.cache.Get("bar", &v)))
_, err = m.UpdateMany(context.Background(), []string{"foo", "bar"}, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.NotNil(t, err)
m.cache = mockedCache{m.cache}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
_, err = m.UpdateMany(context.Background(), []string{"foo", "bar"}, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Equal(t, errMocked, err)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := mon.NewMockCollection(ctrl)
m := createModel(t, mockCollection)
assert.Nil(t, m.cache.Set("foo", "bar"))
assert.Nil(t, m.cache.Set("bar", "baz"))
mockCollection.EXPECT().UpdateMany(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
resp, err := m.UpdateMany(context.Background(), []string{"foo", "bar"}, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
var v string
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
assert.True(t, m.cache.IsNotFound(m.cache.Get("bar", &v)))
mockCollection.EXPECT().UpdateMany(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, errMocked)
_, err = m.UpdateMany(context.Background(), []string{"foo", "bar"}, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.NotNil(t, err)
m.cache = mockedCache{m.cache}
mockCollection.EXPECT().UpdateMany(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
_, err = m.UpdateMany(context.Background(), []string{"foo", "bar"}, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Equal(t, errMocked, err)
}
func TestModel_UpdateManyNoCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
resp, err := m.UpdateManyNoCache(context.Background(), bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := mon.NewMockCollection(ctrl)
m := createModel(t, mockCollection)
mockCollection.EXPECT().UpdateMany(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
resp, err := m.UpdateManyNoCache(context.Background(), bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
}
func TestModel_UpdateOne(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
assert.Nil(t, m.cache.Set("foo", "bar"))
resp, err := m.UpdateOne(context.Background(), "foo", bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
var v string
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
_, err = m.UpdateOne(context.Background(), "foo", bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.NotNil(t, err)
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m.cache = mockedCache{m.cache}
_, err = m.UpdateOne(context.Background(), "foo", bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Equal(t, errMocked, err)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := mon.NewMockCollection(ctrl)
m := createModel(t, mockCollection)
assert.Nil(t, m.cache.Set("foo", "bar"))
mockCollection.EXPECT().UpdateOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
resp, err := m.UpdateOne(context.Background(), "foo", bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
var v string
mockCollection.EXPECT().UpdateOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, errMocked)
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
_, err = m.UpdateOne(context.Background(), "foo", bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.NotNil(t, err)
mockCollection.EXPECT().UpdateOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
m.cache = mockedCache{m.cache}
_, err = m.UpdateOne(context.Background(), "foo", bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Equal(t, errMocked, err)
}
func TestModel_UpdateOneNoCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
resp, err := m.UpdateOneNoCache(context.Background(), bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockCollection := mon.NewMockCollection(ctrl)
m := createModel(t, mockCollection)
mockCollection.EXPECT().UpdateOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
resp, err := m.UpdateOneNoCache(context.Background(), bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
}
func createModel(t *testing.T, mt *mtest.T) *Model {
func createModel(t *testing.T, coll mon.Collection) *Model {
s, err := miniredis.Run()
assert.Nil(t, err)
mon.Inject(mt.Name(), mt.Client)
if atomic.AddInt32(&index, 1)%2 == 0 {
return MustNewNodeModel(mt.Name(), mt.DB.Name(), mt.Coll.Name(), redis.New(s.Addr()))
return mustNewTestNodeModel(coll, redis.New(s.Addr()))
} else {
return MustNewModel(mt.Name(), mt.DB.Name(), mt.Coll.Name(), cache.CacheConf{
return mustNewTestModel(coll, cache.CacheConf{
cache.NodeConf{
RedisConf: redis.RedisConf{
Host: s.Addr(),
@@ -519,6 +532,27 @@ func createModel(t *testing.T, mt *mtest.T) *Model {
}
}
// mustNewTestModel returns a test Model with the given cache.
func mustNewTestModel(collection mon.Collection, c cache.CacheConf, opts ...cache.Option) *Model {
return &Model{
Model: &mon.Model{
Collection: collection,
},
cache: cache.New(c, singleFlight, stats, mongo.ErrNoDocuments, opts...),
}
}
// NewNodeModel returns a test Model with a cache node.
func mustNewTestNodeModel(collection mon.Collection, rds *redis.Redis, opts ...cache.Option) *Model {
c := cache.NewNode(rds, singleFlight, stats, mongo.ErrNoDocuments, opts...)
return &Model{
Model: &mon.Model{
Collection: collection,
},
cache: c,
}
}
var (
errMocked = errors.New("mocked error")
index int32

View File

@@ -0,0 +1,19 @@
# Migrating from 1.x to 2.0
To upgrade imports of the Go Driver from v1 to v2, we recommend using [marwan-at-work/mod
](https://github.com/marwan-at-work/mod):
```
mod upgrade --mod-name=go.mongodb.org/mongo-driver
```
# Notice
After completing the mod upgrade, code changes are typically unnecessary in the vast majority of cases. However, if your project references packages including but not limited to those listed below, you'll need to manually replace them, as these libraries are no longer present in the v2 version.
```go
go.mongodb.org/mongo-driver/bson/bsonrw => go.mongodb.org/mongo-driver/v2/bson
go.mongodb.org/mongo-driver/bson/bsoncodec => go.mongodb.org/mongo-driver/v2/bson
go.mongodb.org/mongo-driver/bson/primitive => go.mongodb.org/mongo-driver/v2/bson
```
See the following resources to learn more about upgrading from version 1.x to 2.0.:
https://raw.githubusercontent.com/mongodb/mongo-go-driver/refs/heads/master/docs/migration-2.0.md

View File

@@ -65,7 +65,7 @@ type (
// RedisNode interface represents a redis node.
RedisNode interface {
red.Cmdable
red.BitMapCmdable
Do(ctx context.Context, args ...any) *red.Cmd
}
// GeoLocation is used with GeoAdd to add geospatial location.
@@ -260,12 +260,34 @@ func (s *Redis) BitPosCtx(ctx context.Context, key string, bit, start, end int64
}
// Blpop uses passed in redis connection to execute blocking queries.
//
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode to avoid
// exhausting the connection pool. Blocking commands hold connections for extended periods and should
// not share the regular connection pool.
//
// Example usage:
//
// node, err := redis.CreateBlockingNode(rds)
// if err != nil {
// // handle error
// }
// defer node.Close()
//
// value, err := rds.Blpop(node, "mylist")
// if err != nil {
// // handle error
// }
//
// Doesn't benefit from pooling redis connections of blocking queries
func (s *Redis) Blpop(node RedisNode, key string) (string, error) {
return s.BlpopCtx(context.Background(), node, key)
}
// BlpopCtx uses passed in redis connection to execute blocking queries.
//
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode.
// See Blpop for usage examples.
//
// Doesn't benefit from pooling redis connections of blocking queries
func (s *Redis) BlpopCtx(ctx context.Context, node RedisNode, key string) (string, error) {
return s.BlpopWithTimeoutCtx(ctx, node, blockingQueryTimeout, key)
@@ -273,12 +295,18 @@ func (s *Redis) BlpopCtx(ctx context.Context, node RedisNode, key string) (strin
// BlpopEx uses passed in redis connection to execute blpop command.
// The difference against Blpop is that this method returns a bool to indicate success.
//
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode.
// See Blpop for usage examples.
func (s *Redis) BlpopEx(node RedisNode, key string) (string, bool, error) {
return s.BlpopExCtx(context.Background(), node, key)
}
// BlpopExCtx uses passed in redis connection to execute blpop command.
// The difference against Blpop is that this method returns a bool to indicate success.
//
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode.
// See Blpop for usage examples.
func (s *Redis) BlpopExCtx(ctx context.Context, node RedisNode, key string) (string, bool, error) {
if node == nil {
return "", false, ErrNilNode
@@ -298,12 +326,18 @@ func (s *Redis) BlpopExCtx(ctx context.Context, node RedisNode, key string) (str
// BlpopWithTimeout uses passed in redis connection to execute blpop command.
// Control blocking query timeout
//
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode.
// See Blpop for usage examples.
func (s *Redis) BlpopWithTimeout(node RedisNode, timeout time.Duration, key string) (string, error) {
return s.BlpopWithTimeoutCtx(context.Background(), node, timeout, key)
}
// BlpopWithTimeoutCtx uses passed in redis connection to execute blpop command.
// Control blocking query timeout
//
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode.
// See Blpop for usage examples.
func (s *Redis) BlpopWithTimeoutCtx(ctx context.Context, node RedisNode, timeout time.Duration,
key string) (string, error) {
if node == nil {
@@ -372,6 +406,25 @@ func (s *Redis) DelCtx(ctx context.Context, keys ...string) (int, error) {
return int(v), nil
}
// Do executes a generic redis command with given arguments.
func (s *Redis) Do(args ...any) (any, error) {
return s.DoCtx(context.Background(), args...)
}
// DoCtx executes a generic redis command with given arguments using the provided context.
func (s *Redis) DoCtx(ctx context.Context, args ...any) (any, error) {
if len(args) == 0 {
return nil, errors.New("missing redis command")
}
conn, err := getRedis(s)
if err != nil {
return nil, err
}
return conn.Do(ctx, args...).Result()
}
// Eval is the implementation of redis eval command.
func (s *Redis) Eval(script string, keys []string, args ...any) (any, error) {
return s.EvalCtx(context.Background(), script, keys, args...)
@@ -631,6 +684,28 @@ func (s *Redis) GetDelCtx(ctx context.Context, key string) (string, error) {
return val, err
}
// GetEx is the implementation of redis getex command.
// Available since: redis version 6.2.0
func (s *Redis) GetEx(key string, seconds int) (string, error) {
return s.GetExCtx(context.Background(), key, seconds)
}
// GetExCtx is the implementation of redis getex command.
// Available since: redis version 6.2.0
func (s *Redis) GetExCtx(ctx context.Context, key string, seconds int) (string, error) {
conn, err := getRedis(s)
if err != nil {
return "", err
}
val, err := conn.GetEx(ctx, key, time.Duration(seconds)*time.Second).Result()
if errors.Is(err, red.Nil) {
return "", nil
}
return val, err
}
// GetSet is the implementation of redis getset command.
func (s *Redis) GetSet(key, value string) (string, error) {
return s.GetSetCtx(context.Background(), key, value)
@@ -1285,10 +1360,12 @@ func (s *Redis) RpushCtx(ctx context.Context, key string, values ...any) (int, e
return int(v), nil
}
// RPopLPush atomically removes the last element from source list and prepends it to destination list.
func (s *Redis) RPopLPush(source string, destination string) (string, error) {
return s.RPopLPushCtx(context.Background(), source, destination)
}
// RPopLPushCtx is the context-aware version of RPopLPush.
func (s *Redis) RPopLPushCtx(ctx context.Context, source string, destination string) (string, error) {
conn, err := getRedis(s)
if err != nil {
@@ -1695,14 +1772,17 @@ func (s *Redis) TtlCtx(ctx context.Context, key string) (int, error) {
return int(duration), nil
}
// TxPipeline returns a Redis transaction pipeline for executing multiple commands atomically.
func (s *Redis) TxPipeline() (pipe Pipeliner, err error) {
conn, err := getRedis(s)
if err != nil {
return nil, err
}
return conn.TxPipeline(), nil
}
// Unlink is similar to Del but removes keys asynchronously in a separate thread.
func (s *Redis) Unlink(keys ...string) (int64, error) {
return s.UnlinkCtx(context.Background(), keys...)
}
@@ -1712,9 +1792,181 @@ func (s *Redis) UnlinkCtx(ctx context.Context, keys ...string) (int64, error) {
if err != nil {
return 0, err
}
return conn.Unlink(ctx, keys...).Result()
}
// XAck acknowledges one or more messages in a Redis stream consumer group.
// It marks the specified messages as successfully processed.
func (s *Redis) XAck(stream string, group string, ids ...string) (int64, error) {
return s.XAckCtx(context.Background(), stream, group, ids...)
}
// XAckCtx is the context-aware version of XAck.
func (s *Redis) XAckCtx(ctx context.Context, stream string, group string, ids ...string) (int64, error) {
conn, err := getRedis(s)
if err != nil {
return 0, err
}
return conn.XAck(ctx, stream, group, ids...).Result()
}
// XAdd adds a new entry to a Redis stream with the specified ID and field-value pairs.
// If noMkStream is true, the command will fail if the stream doesn't exist.
func (s *Redis) XAdd(stream string, noMkStream bool, id string, values any) (string, error) {
return s.XAddCtx(context.Background(), stream, noMkStream, id, values)
}
// XAddCtx is the context-aware version of XAdd.
func (s *Redis) XAddCtx(ctx context.Context, stream string, noMkStream bool, id string, values any) (
string, error) {
conn, err := getRedis(s)
if err != nil {
return "", err
}
return conn.XAdd(ctx, &red.XAddArgs{
Stream: stream,
ID: id,
Values: values,
NoMkStream: noMkStream,
}).Result()
}
// XGroupCreateMkStream creates a consumer group for a Redis stream.
// If the stream doesn't exist, it will be created automatically.
func (s *Redis) XGroupCreateMkStream(stream string, group string, start string) (string, error) {
return s.XGroupCreateMkStreamCtx(context.Background(), stream, group, start)
}
// XGroupCreateMkStreamCtx is the context-aware version of XGroupCreateMkStream.
func (s *Redis) XGroupCreateMkStreamCtx(ctx context.Context, stream string, group string,
start string) (string, error) {
conn, err := getRedis(s)
if err != nil {
return "", err
}
return conn.XGroupCreateMkStream(ctx, stream, group, start).Result()
}
// XGroupCreate creates a consumer group for a Redis stream.
// The stream must already exist, otherwise the command will fail.
func (s *Redis) XGroupCreate(stream string, group string, start string) (string, error) {
return s.XGroupCreateCtx(context.Background(), stream, group, start)
}
// XGroupCreateCtx is the context-aware version of XGroupCreate.
func (s *Redis) XGroupCreateCtx(ctx context.Context, stream string, group string, start string) (
string, error) {
conn, err := getRedis(s)
if err != nil {
return "", err
}
return conn.XGroupCreate(ctx, stream, group, start).Result()
}
// XInfoConsumers returns information about consumers in a Redis stream consumer group.
func (s *Redis) XInfoConsumers(stream string, group string) ([]red.XInfoConsumer, error) {
return s.XInfoConsumersCtx(context.Background(), stream, group)
}
// XInfoConsumersCtx is the context-aware version of XInfoConsumers.
func (s *Redis) XInfoConsumersCtx(ctx context.Context, stream string, group string) (
[]red.XInfoConsumer, error) {
conn, err := getRedis(s)
if err != nil {
return nil, err
}
return conn.XInfoConsumers(ctx, stream, group).Result()
}
// XInfoGroups returns information about consumer groups for a Redis stream.
func (s *Redis) XInfoGroups(stream string) ([]red.XInfoGroup, error) {
return s.XInfoGroupsCtx(context.Background(), stream)
}
// XInfoGroupsCtx is the context-aware version of XInfoGroups.
func (s *Redis) XInfoGroupsCtx(ctx context.Context, stream string) ([]red.XInfoGroup, error) {
conn, err := getRedis(s)
if err != nil {
return nil, err
}
return conn.XInfoGroups(ctx, stream).Result()
}
// XInfoStream returns general information about a Redis stream.
func (s *Redis) XInfoStream(stream string) (*red.XInfoStream, error) {
return s.XInfoStreamCtx(context.Background(), stream)
}
// XInfoStreamCtx is the context-aware version of XInfoStream.
func (s *Redis) XInfoStreamCtx(ctx context.Context, stream string) (*red.XInfoStream, error) {
conn, err := getRedis(s)
if err != nil {
return nil, err
}
return conn.XInfoStream(ctx, stream).Result()
}
// XReadGroup reads messages from Redis streams as part of a consumer group.
// It allows for distributed processing of stream messages with automatic message delivery semantics.
//
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode to avoid
// exhausting the connection pool. Blocking commands hold connections for extended periods and should
// not share the regular connection pool.
//
// Example usage:
//
// node, err := redis.CreateBlockingNode(rds)
// if err != nil {
// // handle error
// }
// defer node.Close()
//
// streams, err := rds.XReadGroup(
// node, // RedisNode created with CreateBlockingNode
// "mygroup", // consumer group name
// "consumer1", // consumer ID
// 10, // max number of messages to read
// 5*time.Second, // block duration
// false, // noAck flag
// "mystream", // stream name
// )
//
// Doesn't benefit from pooling redis connections of blocking queries.
func (s *Redis) XReadGroup(node RedisNode, group string, consumerId string, count int64,
block time.Duration, noAck bool, streams ...string) ([]red.XStream, error) {
return s.XReadGroupCtx(context.Background(), node, group, consumerId, count, block, noAck, streams...)
}
// XReadGroupCtx is the context-aware version of XReadGroup.
//
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode to avoid
// exhausting the connection pool. See XReadGroup for usage examples.
//
// Doesn't benefit from pooling redis connections of blocking queries.
func (s *Redis) XReadGroupCtx(ctx context.Context, node RedisNode, group string, consumerId string,
count int64, block time.Duration, noAck bool, streams ...string) ([]red.XStream, error) {
if node == nil {
return nil, ErrNilNode
}
return node.XReadGroup(ctx, &red.XReadGroupArgs{
Group: group,
Consumer: consumerId,
Count: count,
Block: block,
NoAck: noAck,
Streams: streams,
}).Result()
}
// Zadd is the implementation of redis zadd command.
func (s *Redis) Zadd(key string, score int64, value string) (bool, error) {
return s.ZaddCtx(context.Background(), key, score, value)
@@ -1795,7 +2047,7 @@ func (s *Redis) ZaddsCtx(ctx context.Context, key string, ps ...Pair) (int64, er
return 0, err
}
var zs []red.Z
zs := make([]red.Z, 0, len(ps))
for _, p := range ps {
z := red.Z{Score: float64(p.Score), Member: p.Key}
zs = append(zs, z)

View File

@@ -275,6 +275,36 @@ func TestRedis_Eval(t *testing.T) {
})
}
func TestRedis_Do(t *testing.T) {
runOnRedis(t, func(client *Redis) {
_, err := newRedis(client.Addr, badType()).Do("PING")
assert.NotNil(t, err)
pong, err := client.Do("PING")
assert.Nil(t, err)
assert.Equal(t, "PONG", pong)
ok, err := client.Do("SET", "key1", "value1")
assert.Nil(t, err)
assert.Equal(t, "OK", ok)
val, err := client.Do("GET", "key1")
assert.Nil(t, err)
assert.Equal(t, "value1", val)
_, err = client.Do("GET", "not_exist")
assert.Equal(t, Nil, err)
_, err = client.Do()
assert.NotNil(t, err)
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err = client.DoCtx(ctx, "PING")
assert.Equal(t, context.Canceled, err)
})
}
func TestRedis_ScriptRun(t *testing.T) {
runOnRedis(t, func(client *Redis) {
sc := NewScript(`redis.call("EXISTS", KEYS[1])`)
@@ -916,6 +946,11 @@ func TestRedis_Ping(t *testing.T) {
ok := client.Ping()
assert.True(t, ok)
})
runOnRedisWithError(t, func(client *Redis) {
ok := client.Ping()
assert.False(t, ok)
})
})
}
@@ -1099,6 +1134,45 @@ func TestRedis_GetDel(t *testing.T) {
})
}
func TestRedis_GetEx(t *testing.T) {
t.Run("get_ex", func(t *testing.T) {
runOnRedis(t, func(client *Redis) {
val, err := client.GetEx("getex_key", 10)
assert.Equal(t, "", val)
assert.Nil(t, err)
err = client.Set("getex_key", "getex_value")
assert.Nil(t, err)
val, err = client.GetEx("getex_key", 10)
assert.Nil(t, err)
assert.Equal(t, "getex_value", val)
val, err = client.Get("getex_key")
assert.Nil(t, err)
assert.Equal(t, "getex_value", val)
ttl, err := client.Ttl("getex_key")
assert.Nil(t, err)
assert.True(t, ttl > 0 && ttl <= 10)
val, err = client.GetEx("getex_key", 5)
assert.Nil(t, err)
assert.Equal(t, "getex_value", val)
ttl, err = client.Ttl("getex_key")
assert.Nil(t, err)
assert.True(t, ttl > 0 && ttl <= 5)
})
})
t.Run("get_ex_with_error", func(t *testing.T) {
runOnRedisWithError(t, func(client *Redis) {
_, err := newRedis(client.Addr, badType()).GetEx("hello", 10)
assert.Error(t, err)
})
})
}
func TestRedis_GetSet(t *testing.T) {
t.Run("set_get", func(t *testing.T) {
runOnRedis(t, func(client *Redis) {
@@ -2029,6 +2103,16 @@ func TestRedis_WithUserPass(t *testing.T) {
err := newRedis(client.Addr, WithUser("any"), WithPass("any")).Ping()
assert.NotNil(t, err)
})
runOnRedisWithAccount(t, "foo", "bar", func(client *Redis) {
err := client.Set("key1", "value1")
assert.Nil(t, err)
_, err = newRedis(client.Addr, badType()).Keys("*")
assert.NotNil(t, err)
keys, err := client.Keys("*")
assert.Nil(t, err)
assert.ElementsMatch(t, []string{"key1"}, keys)
})
}
func TestRedis_checkConnection(t *testing.T) {
@@ -2057,6 +2141,19 @@ func runOnRedis(t *testing.T, fn func(client *Redis)) {
}))
}
func runOnRedisWithAccount(t *testing.T, user, pass string, fn func(client *Redis)) {
logx.Disable()
s := miniredis.RunT(t)
s.RequireUserAuth(user, pass)
fn(MustNewRedis(RedisConf{
Host: s.Addr(),
Type: NodeType,
User: user,
Pass: pass,
}))
}
func runOnRedisWithError(t *testing.T, fn func(client *Redis)) {
logx.Disable()
@@ -2175,3 +2272,115 @@ func TestRedisTxPipeline(t *testing.T) {
assert.Equal(t, hashValue, value)
})
}
func TestRedisXGroupCreate(t *testing.T) {
runOnRedis(t, func(client *Redis) {
_, err := newRedis(client.Addr, badType()).XGroupCreate("Source", "Destination", "0")
assert.NotNil(t, err)
redisCli := newRedis(client.Addr)
_, err = redisCli.XGroupCreate("aa", "bb", "0")
assert.NotNil(t, err)
_, err = newRedis(client.Addr, badType()).XGroupCreateMkStream("Source", "Destination", "0")
assert.NotNil(t, err)
_, err = redisCli.XGroupCreateMkStream("aa", "bb", "0")
assert.Nil(t, err)
_, err = redisCli.XGroupCreate("aa", "cc", "0")
assert.Nil(t, err)
})
}
func TestRedisXInfo(t *testing.T) {
runOnRedis(t, func(client *Redis) {
_, err := newRedis(client.Addr, badType()).XInfoStream("Source")
assert.NotNil(t, err)
_, err = newRedis(client.Addr, badType()).XInfoGroups("Source")
assert.NotNil(t, err)
redisCli := newRedis(client.Addr)
stream := "aa"
group := "bb"
_, err = redisCli.XGroupCreateMkStream(stream, group, "$")
assert.Nil(t, err)
_, err = redisCli.XAdd(stream, true, "*", []string{"key1", "value1", "key2", "value2"})
assert.Nil(t, err)
infoStream, err := redisCli.XInfoStream(stream)
assert.Nil(t, err)
assert.Equal(t, int64(1), infoStream.Length)
infoGroups, err := redisCli.XInfoGroups(stream)
assert.Nil(t, err)
assert.Equal(t, int64(1), infoGroups[0].Lag)
assert.Equal(t, group, infoGroups[0].Name)
node, err := getRedis(redisCli)
assert.NoError(t, err)
redisCli.XAdd(stream, true, "*", []string{"key1", "value1", "key2", "value2"})
streamRes, err := redisCli.XReadGroup(node, group, "consumer", 1, 2000, false, stream, ">")
assert.Nil(t, err)
assert.Equal(t, 1, len(streamRes))
assert.Equal(t, "value1", streamRes[0].Messages[0].Values["key1"])
infoConsumers, err := redisCli.XInfoConsumers(stream, group)
assert.Nil(t, err)
assert.Equal(t, 1, len(infoConsumers))
_, err = newRedis(client.Addr, badType()).XInfoConsumers(stream, group)
assert.NotNil(t, err)
})
}
func TestRedisXReadGroup(t *testing.T) {
runOnRedis(t, func(client *Redis) {
_, err := newRedis(client.Addr, badType()).XAdd("bb", true, "*", []string{"key1", "value1", "key2", "value2"})
assert.NotNil(t, err)
_, err = newRedis(client.Addr, badType()).XAck("bb", "aa", "123")
assert.NotNil(t, err)
redisCli := newRedis(client.Addr)
stream := "aa"
group := "bb"
_, err = redisCli.XGroupCreateMkStream(stream, group, "$")
assert.Nil(t, err)
_, err = redisCli.XAdd(stream, true, "*", []string{"key1", "value1", "key2", "value2"})
assert.Nil(t, err)
node, err := getRedis(redisCli)
assert.NoError(t, err)
redisCli.XAdd(stream, true, "*", []string{"key1", "value1", "key2", "value2"})
_, err = redisCli.XReadGroup(nil, group, "consumer", 1, 2000, false, stream, ">")
assert.Error(t, err)
streamRes, err := redisCli.XReadGroup(node, group, "consumer", 1, 2000, false, stream, ">")
assert.Nil(t, err)
assert.Equal(t, 1, len(streamRes))
assert.Equal(t, "value1", streamRes[0].Messages[0].Values["key1"])
_, err = redisCli.XReadGroup(nil, group, "consumer", 1, 2000, false, stream, "0")
assert.Error(t, err)
streamRes1, err := redisCli.XReadGroup(node, group, "consumer", 1, 2000, false, stream, "0")
assert.Nil(t, err)
assert.Equal(t, 1, len(streamRes1))
assert.Equal(t, "value1", streamRes1[0].Messages[0].Values["key1"])
_, err = redisCli.XAck(stream, group, streamRes[0].Messages[0].ID)
assert.Nil(t, err)
_, err = redisCli.XReadGroup(nil, group, "consumer", 1, 2000, false, stream, "0")
assert.Error(t, err)
streamRes2, err := redisCli.XReadGroup(node, group, "consumer", 1, 2000, false, stream, "0")
assert.Nil(t, err)
assert.Greater(t, len(streamRes2), 0, "streamRes2 is empty")
assert.Equal(t, 0, len(streamRes2[0].Messages))
})
}

View File

@@ -13,7 +13,37 @@ type ClosableNode interface {
Close()
}
// CreateBlockingNode returns a ClosableNode.
// CreateBlockingNode creates a dedicated RedisNode for blocking operations.
//
// Blocking Redis commands (like BLPOP, BRPOP, XREADGROUP with block parameter) hold connections
// for extended periods while waiting for data. Using them with the regular Redis connection pool
// can exhaust all available connections, causing other operations to fail or timeout.
//
// CreateBlockingNode creates a separate Redis client with a minimal connection pool (size 1) that
// is dedicated to blocking operations. This ensures blocking commands don't interfere with regular
// Redis operations.
//
// Example usage:
//
// rds := redis.MustNewRedis(redis.RedisConf{
// Host: "localhost:6379",
// Type: redis.NodeType,
// })
//
// // Create a dedicated node for blocking operations
// node, err := redis.CreateBlockingNode(rds)
// if err != nil {
// // handle error
// }
// defer node.Close() // Important: close the node when done
//
// // Use the node for blocking operations
// value, err := rds.Blpop(node, "mylist")
// if err != nil {
// // handle error
// }
//
// The returned ClosableNode must be closed when no longer needed to release resources.
func CreateBlockingNode(r *Redis) (ClosableNode, error) {
timeout := readWriteTimeout + blockingQueryTimeout

View File

@@ -25,8 +25,8 @@ type (
ResultHandler func(sql.Result, error)
// A BulkInserter is used to batch insert records.
// Postgresql is not supported yet, because of the sql is formated with symbol `$`.
// Oracle is not supported yet, because of the sql is formated with symbol `:`.
// Postgresql is not supported yet, because of the sql is formatted with symbol `$`.
// Oracle is not supported yet, because of the sql is formatted with symbol `:`.
BulkInserter struct {
executor *executors.PeriodicalExecutor
inserter *dbInserter

View File

@@ -0,0 +1,29 @@
package sqlx
import "errors"
var (
errEmptyDatasource = errors.New("empty datasource")
errEmptyDriverName = errors.New("empty driver name")
)
// SqlConf defines the configuration for sqlx.
type SqlConf struct {
DataSource string
DriverName string `json:",default=mysql"`
Replicas []string `json:",optional"`
Policy string `json:",default=round-robin,options=round-robin|random"`
}
// Validate validates the SqlxConf.
func (sc SqlConf) Validate() error {
if len(sc.DataSource) == 0 {
return errEmptyDatasource
}
if len(sc.DriverName) == 0 {
return errEmptyDriverName
}
return nil
}

View File

@@ -0,0 +1,29 @@
package sqlx
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/conf"
)
func TestValidate(t *testing.T) {
text := []byte(`DataSource: primary:password@tcp(127.0.0.1:3306)/primary_db
`)
var sc SqlConf
err := conf.LoadFromYamlBytes(text, &sc)
assert.Nil(t, err)
assert.Equal(t, "mysql", sc.DriverName)
assert.Equal(t, policyRoundRobin, sc.Policy)
assert.Nil(t, sc.Validate())
sc = SqlConf{}
assert.Equal(t, errEmptyDatasource, sc.Validate())
sc.DataSource = "primary:password@tcp(127.0.0.1:3306)/primary_db"
assert.Equal(t, errEmptyDriverName, sc.Validate())
sc.DriverName = "mysql"
assert.Nil(t, sc.Validate())
}

View File

@@ -9,7 +9,10 @@ import (
"github.com/zeromicro/go-zero/core/mapping"
)
const tagName = "db"
const (
tagIgnore = "-"
tagName = "db"
)
var (
// ErrNotMatchDestination is an error that indicates not matching destination to scan.
@@ -67,25 +70,16 @@ func getTaggedFieldValueMap(v reflect.Value) (map[string]any, error) {
}
func getValueInterface(value reflect.Value) (any, error) {
switch value.Kind() {
case reflect.Ptr:
if !value.CanInterface() {
return nil, ErrNotReadableValue
}
if value.IsNil() {
baseValueType := mapping.Deref(value.Type())
value.Set(reflect.New(baseValueType))
}
return value.Interface(), nil
default:
if !value.CanAddr() || !value.Addr().CanInterface() {
return nil, ErrNotReadableValue
}
return value.Addr().Interface(), nil
if !value.CanAddr() || !value.Addr().CanInterface() {
return nil, ErrNotReadableValue
}
if value.Kind() == reflect.Pointer && value.IsNil() {
baseValueType := mapping.Deref(value.Type())
value.Set(reflect.New(baseValueType))
}
return value.Addr().Interface(), nil
}
func isScanFailed(err error) bool {
@@ -269,13 +263,17 @@ func unwrapFields(v reflect.Value) []reflect.Value {
continue
}
childType := indirect.Type().Field(i)
if parseTagName(childType) == tagIgnore {
continue
}
if child.Kind() == reflect.Ptr && child.IsNil() {
baseValueType := mapping.Deref(child.Type())
child.Set(reflect.New(baseValueType))
}
child = reflect.Indirect(child)
childType := indirect.Type().Field(i)
if child.Kind() == reflect.Struct && childType.Anonymous {
fields = append(fields, unwrapFields(child)...)
} else {

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,65 @@
package sqlx
import "context"
const (
// policyRoundRobin round-robin policy for selecting replicas.
policyRoundRobin = "round-robin"
// policyRandom random policy for selecting replicas.
policyRandom = "random"
// readPrimaryMode indicates that the operation is a read,
// but should be performed on the primary database instance.
//
// This mode is used in scenarios where data freshness and consistency are critical,
// such as immediately after writes or where replication lag may cause stale reads.
readPrimaryMode readWriteMode = "read-primary"
// readReplicaMode indicates that the operation is a read from replicas.
// This is suitable for scenarios where eventual consistency is acceptable,
// and the goal is to offload traffic from the primary and improve read scalability.
readReplicaMode readWriteMode = "read-replica"
// writeMode indicates that the operation is a write operation (to primary).
writeMode readWriteMode = "write"
// notSpecifiedMode indicates that the read/write mode is not specified.
notSpecifiedMode readWriteMode = ""
)
type readWriteModeKey struct{}
// WithReadPrimary sets the context to read-primary mode.
func WithReadPrimary(ctx context.Context) context.Context {
return context.WithValue(ctx, readWriteModeKey{}, readPrimaryMode)
}
// WithReadReplica sets the context to read-replica mode.
func WithReadReplica(ctx context.Context) context.Context {
return context.WithValue(ctx, readWriteModeKey{}, readReplicaMode)
}
// WithWrite sets the context to write mode, indicating that the operation is a write operation.
func WithWrite(ctx context.Context) context.Context {
return context.WithValue(ctx, readWriteModeKey{}, writeMode)
}
type readWriteMode string
func (m readWriteMode) isValid() bool {
return m == readPrimaryMode || m == readReplicaMode || m == writeMode
}
func getReadWriteMode(ctx context.Context) readWriteMode {
if mode := ctx.Value(readWriteModeKey{}); mode != nil {
if v, ok := mode.(readWriteMode); ok && v.isValid() {
return v
}
}
return notSpecifiedMode
}
func usePrimary(ctx context.Context) bool {
return getReadWriteMode(ctx) != readReplicaMode
}

View File

@@ -0,0 +1,142 @@
package sqlx
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsValid(t *testing.T) {
testCases := []struct {
name string
mode readWriteMode
expected bool
}{
{
name: "valid read-primary mode",
mode: readPrimaryMode,
expected: true,
},
{
name: "valid read-replica mode",
mode: readReplicaMode,
expected: true,
},
{
name: "valid write mode",
mode: writeMode,
expected: true,
},
{
name: "not specified mode (empty)",
mode: notSpecifiedMode,
expected: false,
},
{
name: "invalid custom string",
mode: readWriteMode("delete"),
expected: false,
},
{
name: "case sensitive check",
mode: readWriteMode("READ"),
expected: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actual := tc.mode.isValid()
assert.Equal(t, tc.expected, actual)
})
}
}
func TestWithReadMode(t *testing.T) {
ctx := context.Background()
readPrimaryCtx := WithReadPrimary(ctx)
val := readPrimaryCtx.Value(readWriteModeKey{})
assert.Equal(t, readPrimaryMode, val)
readReplicaCtx := WithReadReplica(ctx)
val = readReplicaCtx.Value(readWriteModeKey{})
assert.Equal(t, readReplicaMode, val)
}
func TestWithWriteMode(t *testing.T) {
ctx := context.Background()
writeCtx := WithWrite(ctx)
val := writeCtx.Value(readWriteModeKey{})
assert.Equal(t, writeMode, val)
}
func TestGetReadWriteMode(t *testing.T) {
t.Run("valid read-primary mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readPrimaryMode)
assert.Equal(t, readPrimaryMode, getReadWriteMode(ctx))
})
t.Run("valid read-replica mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readReplicaMode)
assert.Equal(t, readReplicaMode, getReadWriteMode(ctx))
})
t.Run("valid write mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, writeMode)
assert.Equal(t, writeMode, getReadWriteMode(ctx))
})
t.Run("invalid mode value (wrong type)", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, "not-a-mode")
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
})
t.Run("invalid mode value (wrong value)", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readWriteMode("delete"))
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
})
t.Run("no mode set", func(t *testing.T) {
ctx := context.Background()
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
})
}
func TestUsePrimary(t *testing.T) {
t.Run("context with read-replica mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readReplicaMode)
assert.False(t, usePrimary(ctx))
})
t.Run("context with read-primary mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readPrimaryMode)
assert.True(t, usePrimary(ctx))
})
t.Run("context with write mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, writeMode)
assert.True(t, usePrimary(ctx))
})
t.Run("context with invalid mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readWriteMode("invalid"))
assert.True(t, usePrimary(ctx))
})
t.Run("context with no mode set", func(t *testing.T) {
ctx := context.Background()
assert.True(t, usePrimary(ctx))
})
}
func TestWithModeTwice(t *testing.T) {
ctx := context.Background()
ctx = WithReadPrimary(ctx)
writeCtx := WithWrite(ctx)
val := writeCtx.Value(readWriteModeKey{})
assert.Equal(t, writeMode, val)
}

View File

@@ -4,6 +4,9 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"math/rand"
"sync/atomic"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/errorx"
@@ -52,9 +55,10 @@ type (
beginTx beginnable
brk breaker.Breaker
accept breaker.Acceptable
index uint32
}
connProvider func() (*sql.DB, error)
connProvider func(ctx context.Context) (*sql.DB, error)
sessionConn interface {
Exec(query string, args ...any) (sql.Result, error)
@@ -64,10 +68,41 @@ type (
}
)
// MustNewConn returns a SqlConn with the given SqlConf.
func MustNewConn(c SqlConf, opts ...SqlOption) SqlConn {
conn, err := NewConn(c, opts...)
if err != nil {
logx.Must(err)
}
return conn
}
// NewConn returns a SqlConn with the given SqlConf.
func NewConn(c SqlConf, opts ...SqlOption) (SqlConn, error) {
if err := c.Validate(); err != nil {
return nil, err
}
conn := &commonSqlConn{
onError: func(ctx context.Context, err error) {
logInstanceError(ctx, c.DataSource, err)
},
beginTx: begin,
brk: breaker.NewBreaker(),
}
for _, opt := range opts {
opt(conn)
}
conn.connProv = getConnProvider(conn, c.DriverName, c.DataSource, c.Policy, c.Replicas)
return conn, nil
}
// NewSqlConn returns a SqlConn with given driver name and datasource.
func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
conn := &commonSqlConn{
connProv: func() (*sql.DB, error) {
connProv: func(context.Context) (*sql.DB, error) {
return getSqlConn(driverName, datasource)
},
onError: func(ctx context.Context, err error) {
@@ -87,7 +122,7 @@ func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
// Use it with caution; it's provided for other ORM to interact with.
func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
conn := &commonSqlConn{
connProv: func() (*sql.DB, error) {
connProv: func(ctx context.Context) (*sql.DB, error) {
return db, nil
},
onError: func(ctx context.Context, err error) {
@@ -123,7 +158,7 @@ func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...any) (
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
var conn *sql.DB
conn, err = db.connProv()
conn, err = db.connProv(ctx)
if err != nil {
db.onError(ctx, err)
return err
@@ -151,7 +186,7 @@ func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt Stm
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
var conn *sql.DB
conn, err = db.connProv()
conn, err = db.connProv(ctx)
if err != nil {
db.onError(ctx, err)
return err
@@ -242,7 +277,7 @@ func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v any,
}
func (db *commonSqlConn) RawDB() (*sql.DB, error) {
return db.connProv()
return db.connProv(context.Background())
}
func (db *commonSqlConn) Transact(fn func(Session) error) error {
@@ -288,7 +323,7 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
q string, args ...any) (err error) {
var scanFailed bool
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
conn, err := db.connProv()
conn, err := db.connProv(ctx)
if err != nil {
db.onError(ctx, err)
return err
@@ -311,6 +346,38 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
return
}
func getConnProvider(sc *commonSqlConn, driverName, datasource, policy string, replicas []string) connProvider {
return func(ctx context.Context) (*sql.DB, error) {
replicaCount := len(replicas)
if replicaCount == 0 || usePrimary(ctx) {
return getSqlConn(driverName, datasource)
}
var dsn string
if replicaCount == 1 {
dsn = replicas[0]
} else {
if len(policy) == 0 {
policy = policyRoundRobin
}
switch policy {
case policyRandom:
dsn = replicas[rand.Intn(replicaCount)]
case policyRoundRobin:
index := atomic.AddUint32(&sc.index, 1) - 1
dsn = replicas[index%uint32(replicaCount)]
default:
return nil, fmt.Errorf("unknown policy: %s", policy)
}
}
return getSqlConn(driverName, dsn)
}
}
// WithAcceptable returns a SqlOption that setting the acceptable function.
// acceptable is the func to check if the error can be accepted.
func WithAcceptable(acceptable func(err error) bool) SqlOption {

View File

@@ -1,6 +1,7 @@
package sqlx
import (
"context"
"database/sql"
"errors"
"io"
@@ -98,7 +99,7 @@ func TestSqlConn_RawDB(t *testing.T) {
func TestSqlConn_Errors(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
conn := NewSqlConnFromDB(db)
conn.(*commonSqlConn).connProv = func() (*sql.DB, error) {
conn.(*commonSqlConn).connProv = func(ctx context.Context) (*sql.DB, error) {
return nil, errors.New("error")
}
_, err := conn.Prepare("any")
@@ -138,6 +139,148 @@ func TestSqlConn_Errors(t *testing.T) {
})
}
func TestConfigSqlConn(t *testing.T) {
db, mock, err := sqlmock.New()
assert.NotNil(t, db)
assert.NotNil(t, mock)
assert.Nil(t, err)
connManager.Inject(mockedDatasource, db)
mock.ExpectExec("any")
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}))
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf, withMysqlAcceptable())
_, err = conn.Exec("any", "value")
assert.NotNil(t, err)
_, err = conn.Prepare("any")
assert.NotNil(t, err)
var val string
assert.NotNil(t, conn.QueryRow(&val, "any"))
assert.NotNil(t, conn.QueryRowPartial(&val, "any"))
assert.NotNil(t, conn.QueryRows(&val, "any"))
assert.NotNil(t, conn.QueryRowsPartial(&val, "any"))
}
func TestConfigSqlConnStatement(t *testing.T) {
db, mock, err := sqlmock.New()
assert.NotNil(t, db)
assert.NotNil(t, mock)
assert.Nil(t, err)
connManager.Inject(mockedDatasource, db)
mock.ExpectPrepare("any")
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
mock.ExpectPrepare("any")
row := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(row)
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf, withMysqlAcceptable())
stmt, err := conn.Prepare("any")
assert.NoError(t, err)
res, err := stmt.Exec()
assert.NoError(t, err)
lastInsertID, err := res.LastInsertId()
assert.NoError(t, err)
assert.Equal(t, int64(2), lastInsertID)
rowsAffected, err := res.RowsAffected()
assert.NoError(t, err)
assert.Equal(t, int64(3), rowsAffected)
stmt, err = conn.Prepare("any")
assert.NoError(t, err)
var val string
err = stmt.QueryRow(&val)
assert.NoError(t, err)
assert.Equal(t, "bar", val)
mock.ExpectPrepare("any")
rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(rows)
stmt, err = conn.Prepare("any")
assert.NoError(t, err)
var vals []string
assert.NoError(t, stmt.QueryRowsPartial(&vals))
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
}
func TestConfigSqlConnQuery(t *testing.T) {
db, mock, err := sqlmock.New()
assert.NotNil(t, db)
assert.NotNil(t, mock)
assert.Nil(t, err)
connManager.Inject(mockedDatasource, db)
t.Run("QueryRow", func(t *testing.T) {
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar"))
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf)
var val string
assert.NoError(t, conn.QueryRow(&val, "any"))
assert.Equal(t, "bar", val)
})
t.Run("QueryRowPartial", func(t *testing.T) {
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar"))
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf)
var val string
assert.NoError(t, conn.QueryRowPartial(&val, "any"))
assert.Equal(t, "bar", val)
})
t.Run("QueryRows", func(t *testing.T) {
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar"))
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf)
var vals []string
assert.NoError(t, conn.QueryRows(&vals, "any"))
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
})
t.Run("QueryRowsPartial", func(t *testing.T) {
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar"))
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf)
var vals []string
assert.NoError(t, conn.QueryRowsPartial(&vals, "any"))
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
})
}
func TestConfigSqlConnErr(t *testing.T) {
t.Run("panic on empty config", func(t *testing.T) {
original := logx.ExitOnFatal.True()
logx.ExitOnFatal.Set(false)
defer logx.ExitOnFatal.Set(original)
assert.Panics(t, func() {
MustNewConn(SqlConf{})
})
})
t.Run("on error", func(t *testing.T) {
db, mock, err := sqlmock.New()
assert.NotNil(t, db)
assert.NotNil(t, mock)
assert.Nil(t, err)
connManager.Inject(mockedDatasource, db)
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf)
conn.(*commonSqlConn).connProv = func(ctx context.Context) (*sql.DB, error) {
return nil, errors.New("error")
}
_, err = conn.Prepare("any")
assert.Error(t, err)
})
}
func TestStatement(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectPrepare("any").WillBeClosed()
@@ -303,6 +446,93 @@ func TestWithAcceptable(t *testing.T) {
assert.True(t, conn.accept(acceptableErr3))
}
func TestProvider(t *testing.T) {
defer func() {
_ = connManager.Close()
}()
primaryDSN := "primary:password@tcp(127.0.0.1:3306)/primary_db"
replicasDSN := []string{
"replica_one:pwd@tcp(localhost:3306)/replica_one",
"replica_two:pwd@tcp(localhost:3306)/replica_two",
"replica_three:pwd@tcp(localhost:3306)/replica_three",
}
primaryDB, err := connManager.GetResource(primaryDSN, func() (io.Closer, error) { return sql.Open(mysqlDriverName, primaryDSN) })
assert.Nil(t, err)
assert.NotNil(t, primaryDB)
replicaOneDB, err := connManager.GetResource(replicasDSN[0], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[0]) })
assert.Nil(t, err)
assert.NotNil(t, replicaOneDB)
replicaTwoDB, err := connManager.GetResource(replicasDSN[1], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[1]) })
assert.Nil(t, err)
assert.NotNil(t, replicaTwoDB)
replicaThreeDB, err := connManager.GetResource(replicasDSN[2], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[2]) })
assert.Nil(t, err)
assert.NotNil(t, replicaThreeDB)
sc := &commonSqlConn{}
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, nil)
ctx := context.Background()
db, err := sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, primaryDB, db)
ctx = WithWrite(ctx)
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, primaryDB, db)
ctx = WithReadPrimary(ctx)
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, primaryDB, db)
// no mode set, should return primary
ctx = context.Background()
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, replicasDSN)
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, primaryDB, db)
ctx = WithReadReplica(ctx)
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, []string{replicasDSN[0]})
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, replicaOneDB, db)
// default policy is round-robin
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, replicasDSN)
replicas := []io.Closer{replicaOneDB, replicaTwoDB, replicaThreeDB}
for i := 0; i < len(replicasDSN); i++ {
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, replicas[i], db)
}
// random policy
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRandom, replicasDSN)
for i := 0; i < len(replicasDSN); i++ {
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Contains(t, replicas, db)
}
// unknown policy
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, "unknown", replicasDSN)
_, err = sc.connProv(ctx)
assert.NotNil(t, err)
// empty policy transforms to round-robin
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, "", replicasDSN)
for i := 0; i < len(replicasDSN); i++ {
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, replicas[i], db)
}
}
func buildConn() (mock sqlmock.Sqlmock, err error) {
_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
var db *sql.DB

View File

@@ -27,7 +27,7 @@ func getCachedSqlConn(driverName, server string) (*sql.DB, error) {
return nil, err
}
if driverName != mysqlDriverName {
if driverName == mysqlDriverName {
if cfg, e := mysql.ParseDSN(server); e != nil {
// if cannot parse, don't collect the metrics
logx.Error(e)

View File

@@ -156,7 +156,7 @@ func begin(db *sql.DB) (trans, error) {
func transact(ctx context.Context, db *commonSqlConn, b beginnable,
fn func(context.Context, Session) error) (err error) {
conn, err := db.connProv()
conn, err := db.connProv(ctx)
if err != nil {
db.onError(ctx, err)
return err

Some files were not shown because too many files have changed in this diff Show More