Compare commits

...

137 Commits

Author SHA1 Message Date
Kevin Wan
c9ff6a10d3 feat: support serverless in rest (#5001)
Signed-off-by: kevin <wanjunfeng@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-07-13 00:00:52 +08:00
Kevin Wan
a71e56de52 fix: context key error in sql read write mode (#5000) 2025-07-12 06:58:08 +08:00
Kevin Wan
bae8d4f4c8 chore: refactoring sql read write mode (#4990)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-07-11 01:05:55 +08:00
zhoushuguang
8c6266f338 sql read write support (#4976)
Co-authored-by: light.zhou <light.zhou@bkyo.io>
2025-07-09 16:04:56 +00:00
Kevin Wan
95d5b81f44 chore: optimize pr 4979 (#4988)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-07-09 23:55:24 +08:00
geekeryy
bca7bbc142 fix: correct duration type comparison in environment variable processing (#4979) 2025-07-09 15:22:27 +00:00
Kevin Wan
df9a52664b fix issue #4986 2025-07-08 13:58:48 +00:00
Kevin Wan
937cf0db96 Update readme-cn.md (#4983) 2025-07-04 11:02:49 +08:00
Kevin Wan
75cebb65f8 fix: timeout 0s not working (#4932)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-07-01 17:01:24 +08:00
dependabot[bot]
410f56e73a chore(deps): bump github.com/redis/go-redis/v9 from 9.10.0 to 9.11.0 (#4969)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-06-25 18:35:01 +08:00
dependabot[bot]
017909a3ab chore(deps): bump github.com/emicklei/proto from 1.14.1 to 1.14.2 in /tools/goctl (#4961)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-06-19 15:40:18 +08:00
kesonan
0d31e6c375 (goctl): fix #4943 (#4953) 2025-06-14 15:36:30 +00:00
Kevin Wan
0ba86b1849 chore: add more tests (#4949) 2025-06-13 22:10:08 +08:00
wanwu
4cacc4d9d3 fix: the time.Duration type panics due to numerical values (#4944)
Co-authored-by: sam.yang <sam.yang@yijinin.com>
2025-06-12 15:11:07 +00:00
Eric
a99c14da4a fix: typo of the logic of CpuThreshold in comments (#4942)
Co-authored-by: zhouyy <zhouyy@ickey.cn>
2025-06-12 08:28:44 +00:00
Kevin Wan
985582264a chore: fix warnings (#4940) 2025-06-12 00:04:29 +08:00
Kevin Wan
8364e341e1 chore: update go-zero dep (#4933)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-06-09 18:08:20 +08:00
Kevin Wan
0f2b589d4d Revert "fix: api group set timeout: 0s not working." (#4931) 2025-06-08 23:14:38 +08:00
spectatorMrZ
19fec36d24 fix: api group set timeout: 0s not working. (#4785) 2025-06-08 14:50:21 +00:00
Kevin Wan
f037bf344d chore: add more tests (#4930) 2025-06-08 22:08:04 +08:00
MarkJoyMa
d99cf35b07 Feat/continue profiling (#4867)
Co-authored-by: aiden.ma <Aiden.ma@yijinin.com>
Co-authored-by: aiden.ma <aiden.ma@bkyo.io>
2025-06-07 21:12:31 +08:00
Kevin Wan
f459f1b5ff chore: update goctl version (#4929) 2025-06-07 21:01:35 +08:00
Haiwei Zhang
0140fd417b feat(goctl): generate mongo model with cache prefix (#4907) 2025-06-07 12:54:33 +00:00
jaron
7969e0ca38 fix(goctl): Fix getting swagger consume types (#4903) 2025-06-07 12:46:34 +00:00
Kevin Wan
91c885b5b0 chore: add more unit tests for mcp (#4928) 2025-06-07 20:41:57 +08:00
MarkJoyMa
d4cccca387 Fix the problem that mcp request id is not of int type (#4914) 2025-06-07 10:37:18 +08:00
dependabot[bot]
4b2095ed03 chore(deps): bump github.com/redis/go-redis/v9 from 9.9.0 to 9.10.0 (#4926) 2025-06-07 10:07:26 +08:00
dependabot[bot]
1229eeb2d2 chore(deps): bump go.mongodb.org/mongo-driver from 1.17.3 to 1.17.4 (#4924) 2025-06-06 19:45:26 +08:00
dependabot[bot]
9142b146c5 chore(deps): bump github.com/alicebob/miniredis/v2 from 2.34.0 to 2.35.0 (#4919)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-06-06 10:09:15 +08:00
Kevin Wan
8a1b2d5aed chore: fix typo (#4920) 2025-06-05 22:51:22 +08:00
Leon cap
da5d39e6ca fix: correct spelling of 'cancellation' in timeout handler comment (#4916)
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
2025-06-05 22:42:53 +08:00
Leon cap
68c5a17c67 fix: correct spelling of 'underlying' in Header method comment (#4918) 2025-06-05 10:36:21 +00:00
Leon cap
b53f9f5f2d fix: correct spelling of 'TimeoutHandler' in timeout handler comment (#4917) 2025-06-04 15:48:37 +00:00
dependabot[bot]
36d57626b6 chore(deps): bump github.com/redis/go-redis/v9 from 9.8.0 to 9.9.0 (#4905)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-05-28 11:32:57 +08:00
Kevin Wan
4e36ba832f Update readme.md (#4897) 2025-05-25 22:25:56 +08:00
Kevin Wan
a44954a771 fix: don't set read/write timeout if timeout middleware disabled (#4895)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-05-25 15:07:58 +08:00
kesonan
f3edd4b880 goctl: v1.8.4-beta (#4890) 2025-05-25 05:36:56 +00:00
Kevin Wan
2de3e397ff chore: revert go version to 1.21 (#4893) 2025-05-24 18:17:49 +08:00
Qiu shao
a435eb56f2 perf(hash): optimize Md5Hex encoding performance (#4891) 2025-05-24 08:41:21 +00:00
Kevin Wan
d80761c147 chore: refactor coding style (#4887) 2025-05-22 23:29:40 +08:00
me-cs
e7bd0d8b60 update:To standardize the time format, use the go standard library's own (#4875) 2025-05-22 15:26:53 +00:00
me-cs
b109b3ef4c update:use builtin cmp func (#4879) 2025-05-22 15:19:13 +00:00
Kevin Wan
e3c371ac89 chore: refactor 2025-05-20 12:59:35 +00:00
燕归来
15eb6f4f6d test(hash): modify TestConsistentHashTransferOnFailure to more reasonable test transfer ratio (#4874) 2025-05-20 12:51:50 +00:00
me-cs
4d3681b71c Optimize slicing operations (#4877) 2025-05-20 11:36:02 +00:00
dependabot[bot]
a682bda0bb chore(deps): bump github.com/jackc/pgx/v5 from 5.7.4 to 5.7.5 (#4871)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-05-20 10:19:06 +08:00
kesonan
45b27ad93a goctl: 1.8.4-beta (#4869) 2025-05-19 15:30:50 +00:00
Kevin Wan
292a8302a1 chore: optimize mcp (#4866) 2025-05-17 12:28:06 +08:00
kesonan
91ab1f6d2b goctl features of 1.8.4-alpha (#4849) 2025-05-15 13:59:48 +00:00
Kevin Wan
5048c350ae chore: fix test failure in profilecenter_test.go 2025-05-15 13:31:53 +00:00
Kevin Wan
94edc32f3e chore: optimize profile center and remove tablewriter dependency 2025-05-15 13:22:27 +00:00
Kevin Wan
ec989b2e2a chore: for backward compatibility (#4852)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-05-11 20:19:00 +08:00
me-cs
82fe802e81 update:Use the official sync.OnceFunc (#4840) 2025-05-11 12:08:43 +00:00
me-cs
072d68f897 update:Use the official slice operate func (#4841) 2025-05-11 11:48:54 +00:00
Kevin Wan
2e91ba5811 chore: refactor rest file server (#4851)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-05-11 12:44:43 +08:00
shaouai
5564c43197 feat: serve files using embed.FS (#4847) 2025-05-10 15:43:13 +00:00
Kevin Wan
e55158b0f7 chore: update deps in goctl (#4830) 2025-05-04 16:18:02 +08:00
Kevin Wan
69aa7fe346 feat: improve mcp (#4828)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-05-04 15:29:14 +08:00
kesonan
c3820a95c1 release goctl swagger (#4829) 2025-05-04 06:05:41 +00:00
Kevin Wan
493f3bad0f fix: allow special characters like periods in API route paths (#4827) 2025-05-04 11:07:43 +08:00
Kevin Wan
eb0d5ad3a4 Update readme-cn.md (#4826) 2025-05-02 17:31:49 +08:00
wwwfeng
14192050ae fix: goctl api tsgen (#4726) 2025-05-02 09:15:32 +00:00
spectatorMrZ
9193e771e3 fix: pg gen model missing cache prefix (#4788)
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
2025-05-02 08:52:55 +00:00
hoshi
808b4e496a feat:add cache prefix support to pg (#4741)
Co-authored-by: 郑好 <zheng.hao1@outlook.com>
2025-05-02 08:10:47 +00:00
kesonan
e416d01f8d api format from stdin (#4772) 2025-05-02 07:59:47 +00:00
Joe Bird
789c5de873 fix(marshaler): fix bug when marshal array (#4790) 2025-05-02 07:54:00 +00:00
Rankgice
52078a0c14 Fix the issue of generating swagger @doc "xxx" that fails, and use th… (#4816)
Co-authored-by: lxr <qiyuechuqi@an-idear.com>
2025-05-02 07:06:03 +00:00
Qiu shao
7ef13116a0 feat: add http to http method test case (#4820) 2025-05-02 06:59:36 +00:00
dependabot[bot]
6b8053410a chore(deps): bump github.com/redis/go-redis/v9 from 9.7.3 to 9.8.0 (#4821)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-05-02 14:48:30 +08:00
dependabot[bot]
81c6928445 chore(deps): bump github.com/emicklei/proto from 1.14.0 to 1.14.1 in /tools/goctl (#4817)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-05-02 14:29:42 +08:00
Kevin Wan
761c2dd716 chore: update goctl version (#4822) 2025-05-02 14:18:54 +08:00
kesonan
aeceb3cfbe Update goctl version to 1.8.3-beta (#4812) 2025-04-27 15:53:49 +00:00
kesonan
15ea07aad1 goctl: support custom swagger authentication (#4811) 2025-04-27 15:43:37 +00:00
shaouai
98bebbc74f feat(swagger): allow users to specify the generated swagger file name (#4809) 2025-04-27 15:34:52 +00:00
kesonan
eafd11d949 goctl: supported api types group for EXPERIMENTAL(实验性功能:支持 api type 结构体按照分组名称拆分文件) (#4810) 2025-04-27 15:18:33 +00:00
Kevin Wan
b251ce346e feat: mcp server sdk (#4794)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-04-27 23:06:37 +08:00
kesonan
812140ba36 fix: goctl swagger missing security definition and submit json body data error (#4808) 2025-04-25 14:58:45 +00:00
kesonan
44735e949c fix array schmea generation incorrect (#4801) 2025-04-23 23:59:01 +00:00
kesonan
bf313c3c56 fix: swagger separator incorrect in Windows OS (#4799) 2025-04-23 13:51:02 +00:00
kesonan
94e7753262 fix: the parameter "required" in the Swagger document generated for repair is incorrect (#4791) 2025-04-21 04:18:02 +00:00
kesonan
9c478626d2 feature/goctl-api-swagger (#4780) 2025-04-17 14:38:55 +00:00
Kevin Wan
801c283478 Delete issue-translator.yml 2025-04-10 12:01:21 +08:00
Kevin Wan
2a54faf997 chore: coding style (#4771) 2025-04-10 09:28:42 +08:00
Hanggang Z
ecd98f3653 chore: add more orm_test (#4766) 2025-04-09 13:49:00 +00:00
soasurs
61641581eb fix: form fields of request optional (#4755)
Signed-off-by: soasurs <soasurs@gmail.com>
2025-04-08 13:05:21 +00:00
Kevin Wan
6f2730d5ae chore: update goctl version (#4754) 2025-04-06 19:09:02 +08:00
dependabot[bot]
0eff777b62 chore(deps): bump github.com/jackc/pgx/v5 from 5.7.2 to 5.7.4 (#4737)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-03-26 22:41:55 +08:00
dependabot[bot]
cafbf535f7 chore(deps): bump github.com/golang-jwt/jwt/v4 from 4.5.1 to 4.5.2 (#4734) 2025-03-26 16:28:36 +08:00
Kevin Wan
6edfce63e3 feat: add rest.WithSSE to build SSE route easier (#4729) 2025-03-22 13:38:13 +08:00
dependabot[bot]
cdb0098b18 chore(deps): bump github.com/redis/go-redis/v9 from 9.7.1 to 9.7.3 (#4722) 2025-03-21 09:51:31 +08:00
Kevin Wan
620c7f9693 chore: add more tests (#4718) 2025-03-19 23:54:04 +08:00
Meng Ye
dba444a382 feat: support redis getdel command (#4709) 2025-03-19 23:40:14 +08:00
dependabot[bot]
b24fb3ebf7 chore(deps): bump github.com/fullstorydev/grpcurl from 1.9.2 to 1.9.3 (#4701) 2025-03-12 12:05:28 +08:00
POABOB
967f0926eb fix: fix the bug of the numeric/decimal data type in pg (#4686) 2025-03-07 11:12:02 +00:00
dependabot[bot]
e68c683df9 chore(deps): bump github.com/prometheus/client_golang from 1.21.0 to 1.21.1 (#4683)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-03-05 11:14:23 +08:00
Kevin Wan
247985a065 chore: refactoring & add more tests (#4677) 2025-03-02 20:11:10 +08:00
anlynn
80573af0d8 feat: Add support for serialization of anonymous fields in HTTP client(httpc) (#4676)
Co-authored-by: 李安琳 <anlynn@gmail.com>
2025-03-02 19:10:50 +08:00
Kevin Wan
c0394b631a chore: fix display problems in version-check workflow (#4675) 2025-03-01 22:18:49 +08:00
Kevin Wan
68d1aba377 chore: upgrade go-zero version in goctl (#4674) 2025-03-01 21:23:23 +08:00
Kevin Wan
3315e60272 chore: performance tunning for stable runner (#4670) 2025-02-26 19:19:24 +08:00
dependabot[bot]
327ef73700 chore(deps): bump go.mongodb.org/mongo-driver from 1.17.2 to 1.17.3 (#4669) 2025-02-26 10:03:43 +08:00
dependabot[bot]
eb11521655 chore(deps): bump github.com/redis/go-redis/v9 from 9.7.0 to 9.7.1 (#4665) 2025-02-22 08:19:28 +08:00
Kevin Wan
4c37545e55 Update readme-cn.md (#4664) 2025-02-20 21:13:33 +08:00
dependabot[bot]
2f47c1fba4 chore(deps): bump github.com/prometheus/client_golang from 1.20.5 to 1.21.0 (#4663) 2025-02-20 11:14:13 +08:00
dependabot[bot]
16d54d0ace chore(deps): bump github.com/go-sql-driver/mysql from 1.8.1 to 1.9.0 in /tools/goctl (#4662)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-02-19 11:38:41 +08:00
dependabot[bot]
9925bcbf99 chore(deps): bump github.com/go-sql-driver/mysql from 1.8.1 to 1.9.0 (#4661)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-02-19 11:13:41 +08:00
dependabot[bot]
38a5ecb796 chore(deps): bump github.com/spf13/cobra from 1.8.1 to 1.9.1 in /tools/goctl (#4660)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-02-18 10:47:42 +08:00
Kevin Wan
af78fc7c5f chore: add more tests (#4656) 2025-02-14 23:43:34 +08:00
Kevin Wan
790302b486 fix: should not ignore slowThreshold (#4655) 2025-02-14 23:14:57 +08:00
Nanosk07
6a0672b801 fix: SlowThreshold configuration not taking effect (#4654) 2025-02-14 14:56:25 +00:00
Kevin Wan
560c61612c chore: refactor (#4652) 2025-02-14 09:41:54 +08:00
Kevin Wan
6a988dc4a9 fix: make test purpose method private (#4649) 2025-02-14 00:31:53 +08:00
Kevin Wan
15842c3c7a Create version-check.yml (#4646) 2025-02-13 19:39:22 +08:00
Rui Chen
f2914a74df fix: update version to match with the release (#4645) 2025-02-13 01:11:09 +00:00
Kevin Wan
f113d512e8 chore: coding style (#4644) 2025-02-12 23:48:39 +08:00
kesonan
7a4818da59 Generate caches that support custom key prefix. (#4643) 2025-02-12 15:31:30 +00:00
dependabot[bot]
48d0709ca6 chore(deps): bump golang.org/x/net from 0.34.0 to 0.35.0 (#4640) 2025-02-11 09:25:16 +08:00
Kevin Wan
f747585518 chore: simplify http query array parsing (#4637) 2025-02-09 01:00:52 +08:00
xuerbujia
507ff96546 feat add tag switch to disable form array of split comma format (#4633)
Co-authored-by: wuhongyu <readboy@DESKTOP-T8INU17>
2025-02-09 00:34:41 +08:00
Kevin Wan
651eabb4c6 chore: refactor gateway http context (#4636) 2025-02-08 21:18:07 +08:00
#Suyghur
e6b4372056 fix(gateway): fixed http gateway context propagation error (#4634) 2025-02-08 09:50:26 +00:00
Kevin Wan
24073969a1 fix: redis username not working in redis v7 (#4632) 2025-02-08 12:21:35 +08:00
Kevin Wan
ca797ed22c chore: add trailing newlines (#4631) 2025-02-07 23:39:02 +08:00
youzipi
e347d3f8f8 fix(goctl): allow duplicate_path_expression under different prefix (#4626) 2025-02-07 08:37:11 +00:00
dependabot[bot]
396393b336 chore(deps): bump google.golang.org/protobuf from 1.36.4 to 1.36.5 (#4627) 2025-02-07 10:05:22 +08:00
dependabot[bot]
1f0531b254 chore(deps): bump google.golang.org/protobuf from 1.36.4 to 1.36.5 in /tools/goctl (#4628) 2025-02-07 08:25:54 +08:00
ningzi
77fb271a06 feat(goctl): support go work (#4332) (#4344) 2025-02-05 14:22:40 +00:00
dependabot[bot]
af7cf79963 chore(deps): bump golang.org/x/time from 0.9.0 to 0.10.0 (#4621)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-02-05 22:00:54 +08:00
dependabot[bot]
7926d396d7 chore(deps): bump golang.org/x/text from 0.21.0 to 0.22.0 in /tools/goctl (#4620)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-02-05 10:16:06 +08:00
dependabot[bot]
080cd3df84 chore(deps): bump golang.org/x/sys from 0.29.0 to 0.30.0 (#4622) 2025-02-05 09:27:03 +08:00
Kevin Wan
c4e1a6a2d8 chore: refactor mapreduce (#4619) 2025-02-01 00:12:37 +08:00
Kevin Wan
4e71e95e44 chore: add comments (#4618) 2025-01-31 22:42:46 +08:00
JiChen
84db9bcd15 fix: global fields apply to Third-party log module (#4400) 2025-01-31 13:51:20 +00:00
dependabot[bot]
b28f79ac11 chore(deps): bump github.com/spf13/pflag from 1.0.5 to 1.0.6 in /tools/goctl (#4615) 2025-01-30 16:49:47 +08:00
Kevin Wan
e134e77b2b chore: update go-zero to v1.7.6 for goctl (#4614) 2025-01-29 12:47:23 +08:00
Kevin Wan
f669d84ce8 chore: not using goproxy by default (#4613) 2025-01-29 12:28:47 +08:00
Kevin Wan
9213b8ac27 chore: update go-zero to v1.8.0 for goctl (#4611) 2025-01-29 10:21:40 +08:00
186 changed files with 24196 additions and 972 deletions

View File

@@ -1,18 +0,0 @@
name: 'issue-translator'
on:
issue_comment:
types: [created]
issues:
types: [opened]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: usthe/issues-translate-action@v2.7
with:
IS_MODIFY_TITLE: true
# not require, default false, . Decide whether to modify the issue title
# if true, the robot account @Issues-translate-bot must have modification permissions, invite @Issues-translate-bot to your project or use your custom bot.
CUSTOM_BOT_NOTE: Bot detected the issue body's language is not English, translate it automatically. 👯👭🏻🧑‍🤝‍🧑👫🧑🏿‍🤝‍🧑🏻👩🏾‍🤝‍👨🏿👬🏿
# not require. Customize the translation robot prefix message.

42
.github/workflows/version-check.yml vendored Normal file
View File

@@ -0,0 +1,42 @@
name: Release Version Check
on:
push:
tags:
- 'tools/goctl/v*'
workflow_dispatch:
jobs:
version-check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: '1.21'
- name: Extract tag version
id: get_version
run: |
# Extract version from tools/goctl/v* format
VERSION="${GITHUB_REF#refs/tags/tools/goctl/v}"
echo "VERSION=$VERSION" >> $GITHUB_ENV
echo "Extracted version: $VERSION"
- name: Check version in goctl source code
run: |
# Change to goctl directory
cd tools/goctl
# Check version in BuildVersion constant
VERSION_IN_CODE=$(grep -r "const BuildVersion =" . | grep -o '".*"' | tr -d '"')
echo "Version in code: $VERSION_IN_CODE"
echo "Expected version: $VERSION"
if [ "$VERSION_IN_CODE" != "$VERSION" ]; then
echo "Version mismatch: Version in code ($VERSION_IN_CODE) doesn't match tag version ($VERSION)"
exit 1
fi
echo "✅ Version check passed!"

View File

@@ -8,16 +8,12 @@ import (
"sync"
"time"
"github.com/zeromicro/go-zero/core/mathx"
"github.com/zeromicro/go-zero/core/proc"
"github.com/zeromicro/go-zero/core/stat"
"github.com/zeromicro/go-zero/core/stringx"
)
const (
numHistoryReasons = 5
timeFormat = "15:04:05"
)
const numHistoryReasons = 5
// ErrServiceUnavailable is returned when the Breaker state is open.
var ErrServiceUnavailable = errors.New("circuit breaker is open")
@@ -262,9 +258,9 @@ type errorWindow struct {
func (ew *errorWindow) add(reason string) {
ew.lock.Lock()
ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(timeFormat), reason)
ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(time.TimeOnly), reason)
ew.index = (ew.index + 1) % numHistoryReasons
ew.count = mathx.MinInt(ew.count+1, numHistoryReasons)
ew.count = min(ew.count+1, numHistoryReasons)
ew.lock.Unlock()
}

View File

@@ -423,7 +423,7 @@ func TestRegistry_Monitor(t *testing.T) {
GetRegistry().clusters = map[string]*cluster{
getClusterKey(endpoints): {
watchers: map[watchKey]*watchValue{
watchKey{
{
key: "foo",
exactMatch: true,
}: {
@@ -449,7 +449,7 @@ func TestRegistry_Unmonitor(t *testing.T) {
GetRegistry().clusters = map[string]*cluster{
getClusterKey(endpoints): {
watchers: map[watchKey]*watchValue{
watchKey{
{
key: "foo",
exactMatch: true,
}: {

View File

@@ -86,21 +86,16 @@ func TestConsistentHashIncrementalTransfer(t *testing.T) {
func TestConsistentHashTransferOnFailure(t *testing.T) {
index := 41
keys, newKeys := getKeysBeforeAndAfterFailure(t, "localhost:", index)
var transferred int
for k, v := range newKeys {
if v != keys[k] {
transferred++
}
}
ratio := float32(transferred) / float32(requestSize)
assert.True(t, ratio < 2.5/float32(keySize), fmt.Sprintf("%d: %f", index, ratio))
ratioNotExists := getTransferRatioOnFailure(t, index)
assert.True(t, ratioNotExists == 0, fmt.Sprintf("%d: %f", index, ratioNotExists))
index = 13
ratio := getTransferRatioOnFailure(t, index)
assert.True(t, ratio < 2.5/keySize, fmt.Sprintf("%d: %f", index, ratio))
}
func TestConsistentHashLeastTransferOnFailure(t *testing.T) {
prefix := "localhost:"
index := 41
index := 13
keys, newKeys := getKeysBeforeAndAfterFailure(t, prefix, index)
for k, v := range keys {
newV := newKeys[k]
@@ -164,6 +159,17 @@ func getKeysBeforeAndAfterFailure(t *testing.T, prefix string, index int) (map[i
return keys, newKeys
}
func getTransferRatioOnFailure(t *testing.T, index int) float32 {
keys, newKeys := getKeysBeforeAndAfterFailure(t, "localhost:", index)
var transferred int
for k, v := range newKeys {
if v != keys[k] {
transferred++
}
}
return float32(transferred) / float32(requestSize)
}
type mockNode struct {
addr string
id int

View File

@@ -2,7 +2,7 @@ package hash
import (
"crypto/md5"
"fmt"
"encoding/hex"
"github.com/spaolacci/murmur3"
)
@@ -20,6 +20,7 @@ func Md5(data []byte) []byte {
}
// Md5Hex returns the md5 hex string of data.
// This function is optimized for better performance than fmt.Sprintf.
func Md5Hex(data []byte) string {
return fmt.Sprintf("%x", Md5(data))
return hex.EncodeToString(Md5(data))
}

View File

@@ -7,12 +7,11 @@ import (
)
var (
fieldsContextKey contextKey
globalFields atomic.Value
globalFieldsLock sync.Mutex
)
type contextKey struct{}
type fieldsKey struct{}
// AddGlobalFields adds global fields.
func AddGlobalFields(fields ...LogField) {
@@ -29,16 +28,16 @@ func AddGlobalFields(fields ...LogField) {
// ContextWithFields returns a new context with the given fields.
func ContextWithFields(ctx context.Context, fields ...LogField) context.Context {
if val := ctx.Value(fieldsContextKey); val != nil {
if val := ctx.Value(fieldsKey{}); val != nil {
if arr, ok := val.([]LogField); ok {
allFields := make([]LogField, 0, len(arr)+len(fields))
allFields = append(allFields, arr...)
allFields = append(allFields, fields...)
return context.WithValue(ctx, fieldsContextKey, allFields)
return context.WithValue(ctx, fieldsKey{}, allFields)
}
}
return context.WithValue(ctx, fieldsContextKey, fields)
return context.WithValue(ctx, fieldsKey{}, fields)
}
// WithFields returns a new logger with the given fields.

View File

@@ -34,7 +34,7 @@ func TestAddGlobalFields(t *testing.T) {
func TestContextWithFields(t *testing.T) {
ctx := ContextWithFields(context.Background(), Field("a", 1), Field("b", 2))
vals := ctx.Value(fieldsContextKey)
vals := ctx.Value(fieldsKey{})
assert.NotNil(t, vals)
fields, ok := vals.([]LogField)
assert.True(t, ok)
@@ -43,7 +43,7 @@ func TestContextWithFields(t *testing.T) {
func TestWithFields(t *testing.T) {
ctx := WithFields(context.Background(), Field("a", 1), Field("b", 2))
vals := ctx.Value(fieldsContextKey)
vals := ctx.Value(fieldsKey{})
assert.NotNil(t, vals)
fields, ok := vals.([]LogField)
assert.True(t, ok)
@@ -55,7 +55,7 @@ func TestWithFieldsAppend(t *testing.T) {
ctx := context.WithValue(context.Background(), dummyKey, "dummy")
ctx = ContextWithFields(ctx, Field("a", 1), Field("b", 2))
ctx = ContextWithFields(ctx, Field("c", 3), Field("d", 4))
vals := ctx.Value(fieldsContextKey)
vals := ctx.Value(fieldsKey{})
assert.NotNil(t, vals)
fields, ok := vals.([]LogField)
assert.True(t, ok)
@@ -80,8 +80,8 @@ func TestWithFieldsAppendCopy(t *testing.T) {
ctxa := ContextWithFields(ctx, af)
ctxb := ContextWithFields(ctx, bf)
assert.EqualValues(t, af, ctxa.Value(fieldsContextKey).([]LogField)[count])
assert.EqualValues(t, bf, ctxb.Value(fieldsContextKey).([]LogField)[count])
assert.EqualValues(t, af, ctxa.Value(fieldsKey{}).([]LogField)[count])
assert.EqualValues(t, bf, ctxb.Value(fieldsKey{}).([]LogField)[count])
}
func BenchmarkAtomicValue(b *testing.B) {

View File

@@ -560,7 +560,7 @@ func shallLogStat() bool {
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
// The caller should check shallLog before calling this function.
func writeDebug(val any, fields ...LogField) {
getWriter().Debug(val, addCaller(fields...)...)
getWriter().Debug(val, mergeGlobalFields(addCaller(fields...))...)
}
// writeError writes v into the error log.
@@ -568,7 +568,7 @@ func writeDebug(val any, fields ...LogField) {
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
// The caller should check shallLog before calling this function.
func writeError(val any, fields ...LogField) {
getWriter().Error(val, addCaller(fields...)...)
getWriter().Error(val, mergeGlobalFields(addCaller(fields...))...)
}
// writeInfo writes v into info log.
@@ -576,7 +576,7 @@ func writeError(val any, fields ...LogField) {
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
// The caller should check shallLog before calling this function.
func writeInfo(val any, fields ...LogField) {
getWriter().Info(val, addCaller(fields...)...)
getWriter().Info(val, mergeGlobalFields(addCaller(fields...))...)
}
// writeSevere writes v into severe log.
@@ -592,7 +592,7 @@ func writeSevere(msg string) {
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
// The caller should check shallLog before calling this function.
func writeSlow(val any, fields ...LogField) {
getWriter().Slow(val, addCaller(fields...)...)
getWriter().Slow(val, mergeGlobalFields(addCaller(fields...))...)
}
// writeStack writes v into stack log.
@@ -608,5 +608,5 @@ func writeStack(msg string) {
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
// The caller should check shallLog before calling this function.
func writeStat(msg string) {
getWriter().Stat(msg, addCaller()...)
getWriter().Stat(msg, mergeGlobalFields(addCaller())...)
}

View File

@@ -206,7 +206,9 @@ func (l *richLogger) WithFields(fields ...LogField) Logger {
func (l *richLogger) buildFields(fields ...LogField) []LogField {
fields = append(l.fields, fields...)
// caller field should always appear together with global fields
fields = append(fields, Field(callerKey, getCaller(callerDepth+l.callerSkip)))
fields = mergeGlobalFields(fields)
if l.ctx == nil {
return fields
@@ -222,7 +224,7 @@ func (l *richLogger) buildFields(fields ...LogField) []LogField {
fields = append(fields, Field(spanKey, spanID))
}
val := l.ctx.Value(fieldsContextKey)
val := l.ctx.Value(fieldsKey{})
if val != nil {
if arr, ok := val.([]LogField); ok {
fields = append(fields, arr...)

View File

@@ -18,7 +18,6 @@ import (
)
const (
dateFormat = "2006-01-02"
hoursPerDay = 24
bufferSize = 100
defaultDirMode = 0o755
@@ -116,7 +115,7 @@ func (r *DailyRotateRule) OutdatedFiles() []string {
}
var buf strings.Builder
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay*r.days)).Format(dateFormat)
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay*r.days)).Format(time.DateOnly)
buf.WriteString(r.filename)
buf.WriteString(r.delimiter)
buf.WriteString(boundary)
@@ -425,7 +424,7 @@ func compressLogFile(file string) {
}
func getNowDate() string {
return time.Now().Format(dateFormat)
return time.Now().Format(time.DateOnly)
}
func getNowDateInRFC3339Format() string {

View File

@@ -52,7 +52,7 @@ func TestDailyRotateRuleOutdatedFiles(t *testing.T) {
})
t.Run("temp files", func(t *testing.T) {
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
assert.NoError(t, err)
_ = f1.Close()
@@ -73,7 +73,7 @@ func TestDailyRotateRuleOutdatedFiles(t *testing.T) {
func TestDailyRotateRuleShallRotate(t *testing.T) {
var rule DailyRotateRule
rule.rotatedTime = time.Now().Add(time.Hour * 24).Format(dateFormat)
rule.rotatedTime = time.Now().Add(time.Hour * 24).Format(time.DateOnly)
assert.True(t, rule.ShallRotate(0))
}
@@ -117,12 +117,12 @@ func TestSizeLimitRotateRuleOutdatedFiles(t *testing.T) {
})
t.Run("temp files", func(t *testing.T) {
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
assert.NoError(t, err)
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
assert.NoError(t, err)
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
assert.NoError(t, err)
t.Cleanup(func() {
@@ -144,12 +144,12 @@ func TestSizeLimitRotateRuleOutdatedFiles(t *testing.T) {
})
t.Run("no backups", func(t *testing.T) {
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
assert.NoError(t, err)
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
assert.NoError(t, err)
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
assert.NoError(t, err)
t.Cleanup(func() {
@@ -319,7 +319,7 @@ func TestRotateLoggerWrite(t *testing.T) {
}
// the following write calls cannot be changed to Write, because of DATA RACE.
logger.write([]byte(`foo`))
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(dateFormat)
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(time.DateOnly)
logger.write([]byte(`bar`))
logger.Close()
logger.write([]byte(`baz`))
@@ -447,7 +447,7 @@ func TestRotateLoggerWithSizeLimitRotateRuleWrite(t *testing.T) {
}
// the following write calls cannot be changed to Write, because of DATA RACE.
logger.write([]byte(`foo`))
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(dateFormat)
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(time.DateOnly)
logger.write([]byte(`bar`))
logger.Close()
logger.write([]byte(`baz`))

View File

@@ -17,15 +17,27 @@ import (
)
type (
// Writer is the interface for writing logs.
// It's designed to let users customize their own log writer,
// such as writing logs to a kafka, a database, or using third-party loggers.
Writer interface {
// Alert sends an alert message, if your writer implemented alerting functionality.
Alert(v any)
// Close closes the writer.
Close() error
// Debug logs a message at debug level.
Debug(v any, fields ...LogField)
// Error logs a message at error level.
Error(v any, fields ...LogField)
// Info logs a message at info level.
Info(v any, fields ...LogField)
// Severe logs a message at severe level.
Severe(v any)
// Slow logs a message at slow level.
Slow(v any, fields ...LogField)
// Stack logs a message at error level.
Stack(v any)
// Stat logs a message at stat level.
Stat(v any, fields ...LogField)
}
@@ -324,20 +336,6 @@ func buildPlainFields(fields logEntry) []string {
return items
}
func combineGlobalFields(fields []LogField) []LogField {
globals := globalFields.Load()
if globals == nil {
return fields
}
gf := globals.([]LogField)
ret := make([]LogField, 0, len(gf)+len(fields))
ret = append(ret, gf...)
ret = append(ret, fields...)
return ret
}
func marshalJson(t interface{}) ([]byte, error) {
var buf bytes.Buffer
encoder := json.NewEncoder(&buf)
@@ -352,6 +350,20 @@ func marshalJson(t interface{}) ([]byte, error) {
return buf.Bytes(), err
}
func mergeGlobalFields(fields []LogField) []LogField {
globals := globalFields.Load()
if globals == nil {
return fields
}
gf := globals.([]LogField)
ret := make([]LogField, 0, len(gf)+len(fields))
ret = append(ret, gf...)
ret = append(ret, fields...)
return ret
}
func output(writer io.Writer, level string, val any, fields ...LogField) {
// only truncate string content, don't know how to truncate the values of other types.
if v, ok := val.(string); ok {
@@ -362,7 +374,6 @@ func output(writer io.Writer, level string, val any, fields ...LogField) {
}
}
fields = combineGlobalFields(fields)
// +3 for timestamp, level and content
entry := make(logEntry, len(fields)+3)
for _, field := range fields {

View File

@@ -13,6 +13,15 @@ const (
// Marshal marshals the given val and returns the map that contains the fields.
// optional=another is not implemented, and it's hard to implement and not commonly used.
// support anonymous field, e.g.:
//
// type Foo struct {
// Token string `header:"token"`
// }
// type FooB struct {
// Foo
// Bar string `json:"bar"`
// }
func Marshal(val any) (map[string]map[string]any, error) {
ret := make(map[string]map[string]any)
tp := reflect.TypeOf(val)
@@ -44,6 +53,16 @@ func getTag(field reflect.StructField) (string, bool) {
return strings.TrimSpace(tag), false
}
func insertValue(collector map[string]map[string]any, tag string, key string, val any) {
if m, ok := collector[tag]; ok {
m[key] = val
} else {
collector[tag] = map[string]any{
key: val,
}
}
}
func processMember(field reflect.StructField, value reflect.Value,
collector map[string]map[string]any) error {
var key string
@@ -69,15 +88,20 @@ func processMember(field reflect.StructField, value reflect.Value,
val = fmt.Sprint(val)
}
m, ok := collector[tag]
if ok {
m[key] = val
} else {
m = map[string]any{
key: val,
if field.Anonymous {
anonCollector, err := Marshal(val)
if err != nil {
return err
}
for anonTag, anonMap := range anonCollector {
for anonKey, anonVal := range anonMap {
insertValue(collector, anonTag, anonKey, anonVal)
}
}
} else {
insertValue(collector, tag, key, val)
}
collector[tag] = m
return nil
}
@@ -118,7 +142,7 @@ func validateOptional(field reflect.StructField, value reflect.Value) error {
if value.IsNil() {
return fmt.Errorf("field %q is nil", field.Name)
}
case reflect.Array, reflect.Slice, reflect.Map:
case reflect.Slice, reflect.Map:
if value.IsNil() || value.Len() == 0 {
return fmt.Errorf("field %q is empty", field.Name)
}

View File

@@ -27,6 +27,124 @@ func TestMarshal(t *testing.T) {
assert.True(t, m[emptyTag]["Anonymous"].(bool))
}
func TestMarshal_Anonymous(t *testing.T) {
t.Run("anonymous", func(t *testing.T) {
type BaseHeader struct {
Token string `header:"token"`
}
v := struct {
Name string `json:"name"`
Address string `json:"address,options=[beijing,shanghai]"`
Age int `json:"age"`
BaseHeader
}{
Name: "kevin",
Address: "shanghai",
Age: 20,
BaseHeader: BaseHeader{
Token: "token_xxx",
},
}
m, err := Marshal(v)
assert.Nil(t, err)
assert.Equal(t, "kevin", m["json"]["name"])
assert.Equal(t, "shanghai", m["json"]["address"])
assert.Equal(t, 20, m["json"]["age"].(int))
assert.Equal(t, "token_xxx", m["header"]["token"])
v1 := struct {
Name string `json:"name"`
Address string `json:"address,options=[beijing,shanghai]"`
Age int `json:"age"`
BaseHeader
}{
Name: "kevin",
Address: "shanghai",
Age: 20,
}
m1, err1 := Marshal(v1)
assert.Nil(t, err1)
assert.Equal(t, "kevin", m1["json"]["name"])
assert.Equal(t, "shanghai", m1["json"]["address"])
assert.Equal(t, 20, m1["json"]["age"].(int))
type AnotherHeader struct {
Version string `header:"version"`
}
v2 := struct {
Name string `json:"name"`
Address string `json:"address,options=[beijing,shanghai]"`
Age int `json:"age"`
BaseHeader
AnotherHeader
}{
Name: "kevin",
Address: "shanghai",
Age: 20,
BaseHeader: BaseHeader{
Token: "token_xxx",
},
AnotherHeader: AnotherHeader{
Version: "v1.0",
},
}
m2, err2 := Marshal(v2)
assert.Nil(t, err2)
assert.Equal(t, "kevin", m2["json"]["name"])
assert.Equal(t, "shanghai", m2["json"]["address"])
assert.Equal(t, 20, m2["json"]["age"].(int))
assert.Equal(t, "token_xxx", m2["header"]["token"])
assert.Equal(t, "v1.0", m2["header"]["version"])
type PointerHeader struct {
Ref *string `header:"ref"`
}
ref := "reference"
v3 := struct {
Name string `json:"name"`
Address string `json:"address,options=[beijing,shanghai]"`
Age int `json:"age"`
PointerHeader
}{
Name: "kevin",
Address: "shanghai",
Age: 20,
PointerHeader: PointerHeader{
Ref: &ref,
},
}
m3, err3 := Marshal(v3)
assert.Nil(t, err3)
assert.Equal(t, "kevin", m3["json"]["name"])
assert.Equal(t, "shanghai", m3["json"]["address"])
assert.Equal(t, 20, m3["json"]["age"].(int))
assert.Equal(t, "reference", *m3["header"]["ref"].(*string))
})
t.Run("bad anonymous", func(t *testing.T) {
type BaseHeader struct {
Token string `json:"token,options=[a,b]"`
}
v := struct {
Name string `json:"name"`
Address string `json:"address,options=[beijing,shanghai]"`
Age int `json:"age"`
BaseHeader
}{
Name: "kevin",
Address: "shanghai",
Age: 20,
BaseHeader: BaseHeader{
Token: "c",
},
}
_, err := Marshal(v)
assert.NotNil(t, err)
})
}
func TestMarshal_Ptr(t *testing.T) {
v := &struct {
Name string `path:"name"`
@@ -344,3 +462,15 @@ func TestMarshal_FromString(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, "10", m["json"]["age"].(string))
}
func TestMarshal_Array(t *testing.T) {
v := struct {
H [1]int `json:"h,string"`
}{
H: [1]int{1},
}
m, err := Marshal(v)
assert.Nil(t, err)
assert.Equal(t, "[1]", m["json"]["h"].(string))
}

View File

@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"reflect"
"slices"
"strconv"
"strings"
"sync"
@@ -15,11 +16,9 @@ import (
"github.com/zeromicro/go-zero/core/jsonx"
"github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/proc"
"github.com/zeromicro/go-zero/core/stringx"
)
const (
comma = ","
defaultKeyName = "key"
delimiter = '.'
ignoreKey = "-"
@@ -31,14 +30,15 @@ var (
errValueNotSettable = errors.New("value is not settable")
errValueNotStruct = errors.New("value type is not struct")
keyUnmarshaler = NewUnmarshaler(defaultKeyName)
boolType = reflect.TypeOf(false)
durationType = reflect.TypeOf(time.Duration(0))
stringType = reflect.TypeOf("")
cacheKeys = make(map[string][]string)
cacheKeysLock sync.Mutex
defaultCache = make(map[string]any)
defaultCacheLock sync.Mutex
emptyMap = map[string]any{}
emptyValue = reflect.ValueOf(lang.Placeholder)
stringSliceType = reflect.TypeOf([]string{})
)
type (
@@ -152,10 +152,6 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value,
return nil
}
if u.opts.fromArray {
refValue = makeStringSlice(refValue)
}
var valid bool
conv := reflect.MakeSlice(reflect.SliceOf(baseType), refValue.Len(), refValue.Cap())
@@ -628,9 +624,19 @@ func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value re
return u.fillSliceFromString(fieldType, value, mapValue, fullName)
case valueKind == reflect.String && derefedFieldType == durationType:
return fillDurationValue(fieldType, value, mapValue.(string))
v, err := convertToString(mapValue, fullName)
if err != nil {
return err
}
return fillDurationValue(fieldType, value, v)
case valueKind == reflect.String && typeKind == reflect.Struct && u.implementsUnmarshaler(fieldType):
return u.fillUnmarshalerStruct(fieldType, value, mapValue.(string))
v, err := convertToString(mapValue, fullName)
if err != nil {
return err
}
return u.fillUnmarshalerStruct(fieldType, value, v)
default:
return u.processFieldPrimitive(fieldType, value, mapValue, opts, fullName)
}
@@ -761,24 +767,24 @@ func (u *Unmarshaler) processFieldWithEnvValue(fieldType reflect.Type, value ref
return err
}
fieldKind := fieldType.Kind()
switch fieldKind {
case reflect.Bool:
derefType := Deref(fieldType)
switch derefType {
case boolType:
val, err := strconv.ParseBool(envVal)
if err != nil {
return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
}
value.SetBool(val)
SetValue(fieldType, value, reflect.ValueOf(val))
return nil
case durationType.Kind():
case durationType:
if err := fillDurationValue(fieldType, value, envVal); err != nil {
return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
}
return nil
case reflect.String:
value.SetString(envVal)
case stringType:
SetValue(fieldType, value, reflect.ValueOf(envVal))
return nil
default:
return u.processFieldPrimitiveWithJSONNumber(fieldType, value, json.Number(envVal), opts, fullName)
@@ -900,7 +906,7 @@ func (u *Unmarshaler) processNamedFieldWithValueFromString(fieldType reflect.Typ
valueKind.String())
}
if !stringx.Contains(options, checkValue) {
if !slices.Contains(options, checkValue) {
return fmt.Errorf(`value "%s" for field %q is not defined in options "%v"`,
mapValue, key, options)
}
@@ -1189,35 +1195,6 @@ func join(elem ...string) string {
return builder.String()
}
func makeStringSlice(refValue reflect.Value) reflect.Value {
if refValue.Len() != 1 {
return refValue
}
element := refValue.Index(0)
if element.Kind() != reflect.String {
return refValue
}
val, ok := element.Interface().(string)
if !ok {
return refValue
}
splits := strings.Split(val, comma)
if len(splits) <= 1 {
return refValue
}
slice := reflect.MakeSlice(stringSliceType, len(splits), len(splits))
for i, split := range splits {
// allow empty strings
slice.Index(i).Set(reflect.ValueOf(split))
}
return slice
}
func newInitError(name string) error {
return fmt.Errorf("field %q is not set", name)
}

View File

@@ -203,6 +203,20 @@ func TestUnmarshalDuration(t *testing.T) {
}
}
func TestUnmarshalDurationUnexpectedError(t *testing.T) {
type inner struct {
Duration time.Duration `key:"duration"`
}
content := "{\"duration\": 1}"
var m = map[string]any{}
err := jsonx.Unmarshal([]byte(content), &m)
assert.NoError(t, err)
var in inner
err = UnmarshalKey(m, &in)
assert.Error(t, err)
assert.Contains(t, err.Error(), "expect string")
}
func TestUnmarshalDurationDefault(t *testing.T) {
type inner struct {
Int int `key:"int"`
@@ -1462,9 +1476,7 @@ func TestUnmarshalIntSlice(t *testing.T) {
ast := assert.New(t)
unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray())
if ast.NoError(unmarshaler.Unmarshal(m, &v)) {
ast.ElementsMatch([]int{1, 2}, v.Ages)
}
ast.Error(unmarshaler.Unmarshal(m, &v))
})
}
@@ -1546,7 +1558,22 @@ func TestUnmarshalStringSliceFromString(t *testing.T) {
ast := assert.New(t)
unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray())
if ast.NoError(unmarshaler.Unmarshal(m, &v)) {
ast.ElementsMatch([]string{"", ""}, v.Names)
ast.ElementsMatch([]string{","}, v.Names)
}
})
t.Run("slice from valid strings with comma", func(t *testing.T) {
var v struct {
Names []string `key:"names"`
}
m := map[string]any{
"names": []string{"aa,bb"},
}
ast := assert.New(t)
unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray())
if ast.NoError(unmarshaler.Unmarshal(m, &v)) {
ast.ElementsMatch([]string{"aa,bb"}, v.Names)
}
})
@@ -4652,6 +4679,23 @@ func TestUnmarshal_EnvInt(t *testing.T) {
}
}
func TestUnmarshal_EnvInt64(t *testing.T) {
type Value struct {
Age int64 `key:"age,env=TEST_NAME_INT64"`
}
const (
envName = "TEST_NAME_INT64"
envVal = "88"
)
t.Setenv(envName, envVal)
var v Value
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
assert.Equal(t, int64(88), v.Age)
}
}
func TestUnmarshal_EnvIntOverwrite(t *testing.T) {
type Value struct {
Age int `key:"age,env=TEST_NAME_INT"`
@@ -4757,20 +4801,33 @@ func TestUnmarshal_EnvBoolBad(t *testing.T) {
}
func TestUnmarshal_EnvDuration(t *testing.T) {
type Value struct {
Duration time.Duration `key:"duration,env=TEST_NAME_DURATION"`
}
const (
envName = "TEST_NAME_DURATION"
envVal = "1s"
)
t.Setenv(envName, envVal)
var v Value
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
assert.Equal(t, time.Second, v.Duration)
}
t.Run("valid duration", func(t *testing.T) {
type Value struct {
Duration time.Duration `key:"duration,env=TEST_NAME_DURATION"`
}
var v Value
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
assert.Equal(t, time.Second, v.Duration)
}
})
t.Run("ptr of duration", func(t *testing.T) {
type Value struct {
Duration *time.Duration `key:"duration,env=TEST_NAME_DURATION"`
}
var v Value
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
assert.Equal(t, time.Second, *v.Duration)
}
})
}
func TestUnmarshal_EnvDurationBadValue(t *testing.T) {
@@ -5982,6 +6039,16 @@ func TestUnmarshal_Unmarshaler(t *testing.T) {
}, &v))
assert.Nil(t, v.Foo)
})
t.Run("json.Number", func(t *testing.T) {
v := struct {
Foo *mockUnmarshaler `json:"name"`
}{}
m := map[string]any{
"name": json.Number("123"),
}
assert.Error(t, UnmarshalJsonMap(m, &v))
})
}
func TestParseJsonStringValue(t *testing.T) {

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"math"
"reflect"
"slices"
"strconv"
"strings"
"sync"
@@ -91,6 +92,15 @@ func ValidatePtr(v reflect.Value) error {
return nil
}
func convertToString(val any, fullName string) (string, error) {
v, ok := val.(string)
if !ok {
return "", fmt.Errorf("expect string for field %s, but got type %T", fullName, val)
}
return v, nil
}
func convertTypeFromString(kind reflect.Kind, str string) (any, error) {
switch kind {
case reflect.Bool:
@@ -634,11 +644,11 @@ func validateValueInOptions(val any, options []string) error {
if len(options) > 0 {
switch v := val.(type) {
case string:
if !stringx.Contains(options, v) {
if !slices.Contains(options, v) {
return fmt.Errorf(`error: value %q is not defined in options "%v"`, v, options)
}
default:
if !stringx.Contains(options, Repr(v)) {
if !slices.Contains(options, Repr(v)) {
return fmt.Errorf(`error: value "%v" is not defined in options "%v"`, val, options)
}
}

View File

@@ -1,19 +1,13 @@
package mathx
// MaxInt returns the larger one of a and b.
// Deprecated: use builtin max instead.
func MaxInt(a, b int) int {
if a > b {
return a
}
return b
return max(a, b)
}
// MinInt returns the smaller one of a and b.
// Deprecated: use builtin min instead.
func MinInt(a, b int) int {
if a < b {
return a
}
return b
return min(a, b)
}

View File

@@ -142,89 +142,6 @@ func MapReduceChan[T, U, V any](source <-chan T, mapper MapperFunc[T, U], reduce
return mapReduceWithPanicChan(source, panicChan, mapper, reducer, opts...)
}
// mapReduceWithPanicChan maps all elements from source, and reduce the output elements with given reducer.
func mapReduceWithPanicChan[T, U, V any](source <-chan T, panicChan *onceChan, mapper MapperFunc[T, U],
reducer ReducerFunc[U, V], opts ...Option) (val V, err error) {
options := buildOptions(opts...)
// output is used to write the final result
output := make(chan V)
defer func() {
// reducer can only write once, if more, panic
for range output {
panic("more than one element written in reducer")
}
}()
// collector is used to collect data from mapper, and consume in reducer
collector := make(chan U, options.workers)
// if done is closed, all mappers and reducer should stop processing
done := make(chan struct{})
writer := newGuardedWriter(options.ctx, output, done)
var closeOnce sync.Once
// use atomic type to avoid data race
var retErr errorx.AtomicError
finish := func() {
closeOnce.Do(func() {
close(done)
close(output)
})
}
cancel := once(func(err error) {
if err != nil {
retErr.Set(err)
} else {
retErr.Set(ErrCancelWithNil)
}
drain(source)
finish()
})
go func() {
defer func() {
drain(collector)
if r := recover(); r != nil {
panicChan.write(r)
}
finish()
}()
reducer(collector, writer, cancel)
}()
go executeMappers(mapperContext[T, U]{
ctx: options.ctx,
mapper: func(item T, w Writer[U]) {
mapper(item, w, cancel)
},
source: source,
panicChan: panicChan,
collector: collector,
doneChan: done,
workers: options.workers,
})
select {
case <-options.ctx.Done():
cancel(context.DeadlineExceeded)
err = context.DeadlineExceeded
case v := <-panicChan.channel:
// drain output here, otherwise for loop panic in defer
drain(output)
panic(v)
case v, ok := <-output:
if e := retErr.Load(); e != nil {
err = e
} else if ok {
val = v
} else {
err = ErrReduceNoOutput
}
}
return
}
// MapReduceVoid maps all elements generated from given generate,
// and reduce the output elements with given reducer.
func MapReduceVoid[T, U any](generate GenerateFunc[T], mapper MapperFunc[T, U],
@@ -330,6 +247,89 @@ func executeMappers[T, U any](mCtx mapperContext[T, U]) {
}
}
// mapReduceWithPanicChan maps all elements from source, and reduce the output elements with given reducer.
func mapReduceWithPanicChan[T, U, V any](source <-chan T, panicChan *onceChan, mapper MapperFunc[T, U],
reducer ReducerFunc[U, V], opts ...Option) (val V, err error) {
options := buildOptions(opts...)
// output is used to write the final result
output := make(chan V)
defer func() {
// reducer can only write once, if more, panic
for range output {
panic("more than one element written in reducer")
}
}()
// collector is used to collect data from mapper, and consume in reducer
collector := make(chan U, options.workers)
// if done is closed, all mappers and reducer should stop processing
done := make(chan struct{})
writer := newGuardedWriter(options.ctx, output, done)
var closeOnce sync.Once
// use atomic type to avoid data race
var retErr errorx.AtomicError
finish := func() {
closeOnce.Do(func() {
close(done)
close(output)
})
}
cancel := once(func(err error) {
if err != nil {
retErr.Set(err)
} else {
retErr.Set(ErrCancelWithNil)
}
drain(source)
finish()
})
go func() {
defer func() {
drain(collector)
if r := recover(); r != nil {
panicChan.write(r)
}
finish()
}()
reducer(collector, writer, cancel)
}()
go executeMappers(mapperContext[T, U]{
ctx: options.ctx,
mapper: func(item T, w Writer[U]) {
mapper(item, w, cancel)
},
source: source,
panicChan: panicChan,
collector: collector,
doneChan: done,
workers: options.workers,
})
select {
case <-options.ctx.Done():
cancel(context.DeadlineExceeded)
err = context.DeadlineExceeded
case v := <-panicChan.channel:
// drain output here, otherwise for loop panic in defer
drain(output)
panic(v)
case v, ok := <-output:
if e := retErr.Load(); e != nil {
err = e
} else if ok {
val = v
} else {
err = ErrReduceNoOutput
}
}
return
}
func newOptions() *mapReduceOptions {
return &mapReduceOptions{
ctx: context.Background(),

View File

@@ -1,13 +1,12 @@
package prof
import (
"bytes"
"strconv"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/olekukonko/tablewriter"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/threading"
)
@@ -28,46 +27,15 @@ type (
const flushInterval = 5 * time.Minute
var (
pc = &profileCenter{
slots: make(map[string]*profileSlot),
}
once sync.Once
)
func report(name string, duration time.Duration) {
updated := func() bool {
pc.lock.RLock()
defer pc.lock.RUnlock()
slot, ok := pc.slots[name]
if ok {
atomic.AddInt64(&slot.lifecount, 1)
atomic.AddInt64(&slot.lastcount, 1)
atomic.AddInt64(&slot.lifecycle, int64(duration))
atomic.AddInt64(&slot.lastcycle, int64(duration))
}
return ok
}()
if !updated {
func() {
pc.lock.Lock()
defer pc.lock.Unlock()
pc.slots[name] = &profileSlot{
lifecount: 1,
lastcount: 1,
lifecycle: int64(duration),
lastcycle: int64(duration),
}
}()
}
once.Do(flushRepeatly)
var pc = &profileCenter{
slots: make(map[string]*profileSlot),
}
func flushRepeatly() {
func init() {
flushRepeatedly()
}
func flushRepeatedly() {
threading.GoSafe(func() {
for {
time.Sleep(flushInterval)
@@ -76,42 +44,64 @@ func flushRepeatly() {
})
}
func report(name string, duration time.Duration) {
slot := loadOrStoreSlot(name, duration)
atomic.AddInt64(&slot.lifecount, 1)
atomic.AddInt64(&slot.lastcount, 1)
atomic.AddInt64(&slot.lifecycle, int64(duration))
atomic.AddInt64(&slot.lastcycle, int64(duration))
}
func loadOrStoreSlot(name string, duration time.Duration) *profileSlot {
pc.lock.RLock()
slot, ok := pc.slots[name]
pc.lock.RUnlock()
if ok {
return slot
}
pc.lock.Lock()
defer pc.lock.Unlock()
// double-check
if slot, ok = pc.slots[name]; ok {
return slot
}
slot = &profileSlot{}
pc.slots[name] = slot
return slot
}
func generateReport() string {
var buffer bytes.Buffer
buffer.WriteString("Profiling report\n")
var data [][]string
var builder strings.Builder
builder.WriteString("Profiling report\n")
builder.WriteString("QUEUE,LIFECOUNT,LIFECYCLE,LASTCOUNT,LASTCYCLE\n")
calcFn := func(total, count int64) string {
if count == 0 {
return "-"
}
return (time.Duration(total) / time.Duration(count)).String()
}
func() {
pc.lock.Lock()
defer pc.lock.Unlock()
pc.lock.Lock()
for key, slot := range pc.slots {
builder.WriteString(fmt.Sprintf("%s,%d,%s,%d,%s\n",
key,
slot.lifecount,
calcFn(slot.lifecycle, slot.lifecount),
slot.lastcount,
calcFn(slot.lastcycle, slot.lastcount),
))
for key, slot := range pc.slots {
data = append(data, []string{
key,
strconv.FormatInt(slot.lifecount, 10),
calcFn(slot.lifecycle, slot.lifecount),
strconv.FormatInt(slot.lastcount, 10),
calcFn(slot.lastcycle, slot.lastcount),
})
// reset last cycle stats
atomic.StoreInt64(&slot.lastcount, 0)
atomic.StoreInt64(&slot.lastcycle, 0)
}
pc.lock.Unlock()
// reset the data for last cycle
slot.lastcount = 0
slot.lastcycle = 0
}
}()
table := tablewriter.NewWriter(&buffer)
table.SetHeader([]string{"QUEUE", "LIFECOUNT", "LIFECYCLE", "LASTCOUNT", "LASTCYCLE"})
table.SetBorder(false)
table.AppendBulk(data)
table.Render()
return buffer.String()
return builder.String()
}

View File

@@ -8,7 +8,6 @@ import (
)
func TestReport(t *testing.T) {
once.Do(func() {})
assert.NotContains(t, generateReport(), "foo")
report("foo", time.Second)
assert.Contains(t, generateReport(), "foo")

View File

@@ -8,6 +8,7 @@ import (
"github.com/zeromicro/go-zero/core/stat"
"github.com/zeromicro/go-zero/core/trace"
"github.com/zeromicro/go-zero/internal/devserver"
"github.com/zeromicro/go-zero/internal/profiling"
)
const (
@@ -38,6 +39,8 @@ type (
Telemetry trace.Config `json:",optional"`
DevServer DevServerConfig `json:",optional"`
Shutdown proc.ShutdownConf `json:",optional"`
// Profiling is the configuration for continuous profiling.
Profiling profiling.Config `json:",optional"`
}
)
@@ -70,7 +73,9 @@ func (sc ServiceConf) SetUp() error {
if len(sc.MetricsUrl) > 0 {
stat.SetReportWriter(stat.NewRemoteWriter(sc.MetricsUrl))
}
devserver.StartAgent(sc.DevServer)
profiling.Start(sc.Profiling)
return nil
}

View File

@@ -1,9 +1,10 @@
package service
import (
"sync"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/proc"
"github.com/zeromicro/go-zero/core/syncx"
"github.com/zeromicro/go-zero/core/threading"
)
@@ -35,7 +36,7 @@ type (
// NewServiceGroup returns a ServiceGroup.
func NewServiceGroup() *ServiceGroup {
sg := new(ServiceGroup)
sg.stopOnce = syncx.Once(sg.doStop)
sg.stopOnce = sync.OnceFunc(sg.doStop)
return sg
}

View File

@@ -19,7 +19,6 @@ import (
const (
clusterNameKey = "CLUSTER_NAME"
testEnv = "test.v"
timeFormat = "2006-01-02 15:04:05"
)
var (
@@ -45,7 +44,7 @@ func Report(msg string) {
if fn != nil {
reported := lessExecutor.DoOrDiscard(func() {
var builder strings.Builder
builder.WriteString(fmt.Sprintln(time.Now().Format(timeFormat)))
builder.WriteString(fmt.Sprintln(time.Now().Format(time.DateTime)))
if len(clusterName) > 0 {
builder.WriteString(fmt.Sprintf("cluster: %s\n", clusterName))
}

View File

@@ -609,6 +609,28 @@ func (s *Redis) GetBitCtx(ctx context.Context, key string, offset int64) (int, e
return int(v), nil
}
// GetDel is the implementation of redis getdel command.
// Available since: redis version 6.2.0
func (s *Redis) GetDel(key string) (string, error) {
return s.GetDelCtx(context.Background(), key)
}
// GetDelCtx is the implementation of redis getdel command.
// Available since: redis version 6.2.0
func (s *Redis) GetDelCtx(ctx context.Context, key string) (string, error) {
conn, err := getRedis(s)
if err != nil {
return "", err
}
val, err := conn.GetDel(ctx, key).Result()
if errors.Is(err, red.Nil) {
return "", nil
}
return val, err
}
// GetSet is the implementation of redis getset command.
func (s *Redis) GetSet(key, value string) (string, error) {
return s.GetSetCtx(context.Background(), key, value)

View File

@@ -1071,6 +1071,34 @@ func TestRedis_Set(t *testing.T) {
})
}
func TestRedis_GetDel(t *testing.T) {
t.Run("get_del", func(t *testing.T) {
runOnRedis(t, func(client *Redis) {
val, err := newRedis(client.Addr).GetDel("hello")
assert.Equal(t, "", val)
assert.Nil(t, err)
err = client.Set("hello", "world")
assert.Nil(t, err)
val, err = client.Get("hello")
assert.Nil(t, err)
assert.Equal(t, "world", val)
val, err = client.GetDel("hello")
assert.Nil(t, err)
assert.Equal(t, "world", val)
val, err = client.Get("hello")
assert.Nil(t, err)
assert.Equal(t, "", val)
})
})
t.Run("get_del_with_error", func(t *testing.T) {
runOnRedisWithError(t, func(client *Redis) {
_, err := newRedis(client.Addr, badType()).GetDel("hello")
assert.Error(t, err)
})
})
}
func TestRedis_GetSet(t *testing.T) {
t.Run("set_get", func(t *testing.T) {
runOnRedis(t, func(client *Redis) {

View File

@@ -21,6 +21,7 @@ func CreateBlockingNode(r *Redis) (ClosableNode, error) {
case NodeType:
client := red.NewClient(&red.Options{
Addr: r.Addr,
Username: r.User,
Password: r.Pass,
DB: defaultDatabase,
MaxRetries: maxRetries,
@@ -32,6 +33,7 @@ func CreateBlockingNode(r *Redis) (ClosableNode, error) {
case ClusterType:
client := red.NewClusterClient(&red.ClusterOptions{
Addrs: splitClusterAddrs(r.Addr),
Username: r.User,
Password: r.Pass,
MaxRetries: maxRetries,
PoolSize: 1,

View File

@@ -31,6 +31,7 @@ func getClient(r *Redis) (*red.Client, error) {
}
store := red.NewClient(&red.Options{
Addr: r.Addr,
Username: r.User,
Password: r.Pass,
DB: defaultDatabase,
MaxRetries: maxRetries,

View File

@@ -28,6 +28,7 @@ func getCluster(r *Redis) (*red.ClusterClient, error) {
}
store := red.NewClusterClient(&red.ClusterOptions{
Addrs: splitClusterAddrs(r.Addr),
Username: r.User,
Password: r.Pass,
MaxRetries: maxRetries,
MinIdleConns: idleConns,

View File

@@ -0,0 +1,29 @@
package sqlx
import "errors"
var (
errEmptyDatasource = errors.New("empty datasource")
errEmptyDriverName = errors.New("empty driver name")
)
// SqlConf defines the configuration for sqlx.
type SqlConf struct {
DataSource string
DriverName string `json:",default=mysql"`
Replicas []string `json:",optional"`
Policy string `json:",default=round-robin,options=round-robin|random"`
}
// Validate validates the SqlxConf.
func (sc SqlConf) Validate() error {
if len(sc.DataSource) == 0 {
return errEmptyDatasource
}
if len(sc.DriverName) == 0 {
return errEmptyDriverName
}
return nil
}

View File

@@ -0,0 +1,29 @@
package sqlx
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/conf"
)
func TestValidate(t *testing.T) {
text := []byte(`DataSource: primary:password@tcp(127.0.0.1:3306)/primary_db
`)
var sc SqlConf
err := conf.LoadFromYamlBytes(text, &sc)
assert.Nil(t, err)
assert.Equal(t, "mysql", sc.DriverName)
assert.Equal(t, policyRoundRobin, sc.Policy)
assert.Nil(t, sc.Validate())
sc = SqlConf{}
assert.Equal(t, errEmptyDatasource, sc.Validate())
sc.DataSource = "primary:password@tcp(127.0.0.1:3306)/primary_db"
assert.Equal(t, errEmptyDriverName, sc.Validate())
sc.DriverName = "mysql"
assert.Nil(t, sc.Validate())
}

View File

@@ -267,6 +267,20 @@ func TestUnmarshalRowStruct(t *testing.T) {
}, "select name, age from users where user=?", "anyone"), ErrNotMatchDestination)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
value := new(struct {
Name string
age int
})
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select name, age from users where user=?", "anyone"), ErrNotMatchDestination)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -310,6 +324,20 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
}, "select name, age from users where user=?", "anyone"), ErrNotReadableValue)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
value := new(struct {
age int `db:"age"`
Name string `db:"name"`
})
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select name, age from users where user=?", "anyone"), ErrNotReadableValue)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var value struct {
Age *int `db:"age"`
@@ -1307,25 +1335,25 @@ func TestAnonymousStructPr(t *testing.T) {
}
func TestAnonymousStructPrError(t *testing.T) {
type Score struct {
Discipline string `db:"discipline"`
score uint `db:"score"`
}
type ClassType struct {
Grade sql.NullString `db:"grade"`
ClassName *string `db:"class_name"`
}
type Class struct {
*ClassType
Score
}
var value []*struct {
Age int64 `db:"age"`
Class
Name string `db:"name"`
}
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
type Score struct {
Discipline string `db:"discipline"`
score uint `db:"score"`
}
type ClassType struct {
Grade sql.NullString `db:"grade"`
ClassName *string `db:"class_name"`
}
type Class struct {
*ClassType
Score
}
var value []*struct {
Age int64 `db:"age"`
Class
Name string `db:"name"`
}
rs := sqlmock.NewRows([]string{
"name",
"age",
@@ -1338,10 +1366,50 @@ func TestAnonymousStructPrError(t *testing.T) {
AddRow("second", 3, "grade one", "chinese", "class three grade two", 99)
mock.ExpectQuery("select (.+) from users where user=?").
WithArgs("anyone").WillReturnRows(rs)
assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age, grade, discipline, class_name, score from users where user=?",
"anyone"))
"anyone"), ErrNotReadableValue)
if len(value) > 0 {
assert.Equal(t, value[0].score, 0)
}
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
type Score struct {
Discipline string
score uint
}
type ClassType struct {
Grade sql.NullString
ClassName *string
}
type Class struct {
*ClassType
Score
}
var value []*struct {
Age int64
Class
Name string
}
rs := sqlmock.NewRows([]string{
"name",
"age",
"grade",
"discipline",
"class_name",
"score",
}).
AddRow("first", 2, nil, "math", "experimental class", 100).
AddRow("second", 3, "grade one", "chinese", "class three grade two", 99)
mock.ExpectQuery("select (.+) from users where user=?").
WithArgs("anyone").WillReturnRows(rs)
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age, grade, discipline, class_name, score from users where user=?",
"anyone"), ErrNotMatchDestination)
if len(value) > 0 {
assert.Equal(t, value[0].score, 0)
}

View File

@@ -0,0 +1,65 @@
package sqlx
import "context"
const (
// policyRoundRobin round-robin policy for selecting replicas.
policyRoundRobin = "round-robin"
// policyRandom random policy for selecting replicas.
policyRandom = "random"
// readPrimaryMode indicates that the operation is a read,
// but should be performed on the primary database instance.
//
// This mode is used in scenarios where data freshness and consistency are critical,
// such as immediately after writes or where replication lag may cause stale reads.
readPrimaryMode readWriteMode = "read-primary"
// readReplicaMode indicates that the operation is a read from replicas.
// This is suitable for scenarios where eventual consistency is acceptable,
// and the goal is to offload traffic from the primary and improve read scalability.
readReplicaMode readWriteMode = "read-replica"
// writeMode indicates that the operation is a write operation (to primary).
writeMode readWriteMode = "write"
// notSpecifiedMode indicates that the read/write mode is not specified.
notSpecifiedMode readWriteMode = ""
)
type readWriteModeKey struct{}
// WithReadPrimary sets the context to read-primary mode.
func WithReadPrimary(ctx context.Context) context.Context {
return context.WithValue(ctx, readWriteModeKey{}, readPrimaryMode)
}
// WithReadReplica sets the context to read-replica mode.
func WithReadReplica(ctx context.Context) context.Context {
return context.WithValue(ctx, readWriteModeKey{}, readReplicaMode)
}
// WithWrite sets the context to write mode, indicating that the operation is a write operation.
func WithWrite(ctx context.Context) context.Context {
return context.WithValue(ctx, readWriteModeKey{}, writeMode)
}
type readWriteMode string
func (m readWriteMode) isValid() bool {
return m == readPrimaryMode || m == readReplicaMode || m == writeMode
}
func getReadWriteMode(ctx context.Context) readWriteMode {
if mode := ctx.Value(readWriteModeKey{}); mode != nil {
if v, ok := mode.(readWriteMode); ok && v.isValid() {
return v
}
}
return notSpecifiedMode
}
func usePrimary(ctx context.Context) bool {
return getReadWriteMode(ctx) != readReplicaMode
}

View File

@@ -0,0 +1,142 @@
package sqlx
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsValid(t *testing.T) {
testCases := []struct {
name string
mode readWriteMode
expected bool
}{
{
name: "valid read-primary mode",
mode: readPrimaryMode,
expected: true,
},
{
name: "valid read-replica mode",
mode: readReplicaMode,
expected: true,
},
{
name: "valid write mode",
mode: writeMode,
expected: true,
},
{
name: "not specified mode (empty)",
mode: notSpecifiedMode,
expected: false,
},
{
name: "invalid custom string",
mode: readWriteMode("delete"),
expected: false,
},
{
name: "case sensitive check",
mode: readWriteMode("READ"),
expected: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actual := tc.mode.isValid()
assert.Equal(t, tc.expected, actual)
})
}
}
func TestWithReadMode(t *testing.T) {
ctx := context.Background()
readPrimaryCtx := WithReadPrimary(ctx)
val := readPrimaryCtx.Value(readWriteModeKey{})
assert.Equal(t, readPrimaryMode, val)
readReplicaCtx := WithReadReplica(ctx)
val = readReplicaCtx.Value(readWriteModeKey{})
assert.Equal(t, readReplicaMode, val)
}
func TestWithWriteMode(t *testing.T) {
ctx := context.Background()
writeCtx := WithWrite(ctx)
val := writeCtx.Value(readWriteModeKey{})
assert.Equal(t, writeMode, val)
}
func TestGetReadWriteMode(t *testing.T) {
t.Run("valid read-primary mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readPrimaryMode)
assert.Equal(t, readPrimaryMode, getReadWriteMode(ctx))
})
t.Run("valid read-replica mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readReplicaMode)
assert.Equal(t, readReplicaMode, getReadWriteMode(ctx))
})
t.Run("valid write mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, writeMode)
assert.Equal(t, writeMode, getReadWriteMode(ctx))
})
t.Run("invalid mode value (wrong type)", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, "not-a-mode")
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
})
t.Run("invalid mode value (wrong value)", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readWriteMode("delete"))
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
})
t.Run("no mode set", func(t *testing.T) {
ctx := context.Background()
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
})
}
func TestUsePrimary(t *testing.T) {
t.Run("context with read-replica mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readReplicaMode)
assert.False(t, usePrimary(ctx))
})
t.Run("context with read-primary mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readPrimaryMode)
assert.True(t, usePrimary(ctx))
})
t.Run("context with write mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, writeMode)
assert.True(t, usePrimary(ctx))
})
t.Run("context with invalid mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readWriteMode("invalid"))
assert.True(t, usePrimary(ctx))
})
t.Run("context with no mode set", func(t *testing.T) {
ctx := context.Background()
assert.True(t, usePrimary(ctx))
})
}
func TestWithModeTwice(t *testing.T) {
ctx := context.Background()
ctx = WithReadPrimary(ctx)
writeCtx := WithWrite(ctx)
val := writeCtx.Value(readWriteModeKey{})
assert.Equal(t, writeMode, val)
}

View File

@@ -4,6 +4,9 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"math/rand"
"sync/atomic"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/errorx"
@@ -52,9 +55,10 @@ type (
beginTx beginnable
brk breaker.Breaker
accept breaker.Acceptable
index uint32
}
connProvider func() (*sql.DB, error)
connProvider func(ctx context.Context) (*sql.DB, error)
sessionConn interface {
Exec(query string, args ...any) (sql.Result, error)
@@ -64,10 +68,41 @@ type (
}
)
// MustNewConn returns a SqlConn with the given SqlConf.
func MustNewConn(c SqlConf, opts ...SqlOption) SqlConn {
conn, err := NewConn(c, opts...)
if err != nil {
logx.Must(err)
}
return conn
}
// NewConn returns a SqlConn with the given SqlConf.
func NewConn(c SqlConf, opts ...SqlOption) (SqlConn, error) {
if err := c.Validate(); err != nil {
return nil, err
}
conn := &commonSqlConn{
onError: func(ctx context.Context, err error) {
logInstanceError(ctx, c.DataSource, err)
},
beginTx: begin,
brk: breaker.NewBreaker(),
}
for _, opt := range opts {
opt(conn)
}
conn.connProv = getConnProvider(conn, c.DriverName, c.DataSource, c.Policy, c.Replicas)
return conn, nil
}
// NewSqlConn returns a SqlConn with given driver name and datasource.
func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
conn := &commonSqlConn{
connProv: func() (*sql.DB, error) {
connProv: func(context.Context) (*sql.DB, error) {
return getSqlConn(driverName, datasource)
},
onError: func(ctx context.Context, err error) {
@@ -87,7 +122,7 @@ func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
// Use it with caution; it's provided for other ORM to interact with.
func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
conn := &commonSqlConn{
connProv: func() (*sql.DB, error) {
connProv: func(ctx context.Context) (*sql.DB, error) {
return db, nil
},
onError: func(ctx context.Context, err error) {
@@ -123,7 +158,7 @@ func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...any) (
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
var conn *sql.DB
conn, err = db.connProv()
conn, err = db.connProv(ctx)
if err != nil {
db.onError(ctx, err)
return err
@@ -151,7 +186,7 @@ func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt Stm
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
var conn *sql.DB
conn, err = db.connProv()
conn, err = db.connProv(ctx)
if err != nil {
db.onError(ctx, err)
return err
@@ -242,7 +277,7 @@ func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v any,
}
func (db *commonSqlConn) RawDB() (*sql.DB, error) {
return db.connProv()
return db.connProv(context.Background())
}
func (db *commonSqlConn) Transact(fn func(Session) error) error {
@@ -288,7 +323,7 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
q string, args ...any) (err error) {
var scanFailed bool
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
conn, err := db.connProv()
conn, err := db.connProv(ctx)
if err != nil {
db.onError(ctx, err)
return err
@@ -311,6 +346,38 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
return
}
func getConnProvider(sc *commonSqlConn, driverName, datasource, policy string, replicas []string) connProvider {
return func(ctx context.Context) (*sql.DB, error) {
replicaCount := len(replicas)
if replicaCount == 0 || usePrimary(ctx) {
return getSqlConn(driverName, datasource)
}
var dsn string
if replicaCount == 1 {
dsn = replicas[0]
} else {
if len(policy) == 0 {
policy = policyRoundRobin
}
switch policy {
case policyRandom:
dsn = replicas[rand.Intn(replicaCount)]
case policyRoundRobin:
index := atomic.AddUint32(&sc.index, 1) - 1
dsn = replicas[index%uint32(replicaCount)]
default:
return nil, fmt.Errorf("unknown policy: %s", policy)
}
}
return getSqlConn(driverName, dsn)
}
}
// WithAcceptable returns a SqlOption that setting the acceptable function.
// acceptable is the func to check if the error can be accepted.
func WithAcceptable(acceptable func(err error) bool) SqlOption {

View File

@@ -1,6 +1,7 @@
package sqlx
import (
"context"
"database/sql"
"errors"
"io"
@@ -98,7 +99,7 @@ func TestSqlConn_RawDB(t *testing.T) {
func TestSqlConn_Errors(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
conn := NewSqlConnFromDB(db)
conn.(*commonSqlConn).connProv = func() (*sql.DB, error) {
conn.(*commonSqlConn).connProv = func(ctx context.Context) (*sql.DB, error) {
return nil, errors.New("error")
}
_, err := conn.Prepare("any")
@@ -138,6 +139,148 @@ func TestSqlConn_Errors(t *testing.T) {
})
}
func TestConfigSqlConn(t *testing.T) {
db, mock, err := sqlmock.New()
assert.NotNil(t, db)
assert.NotNil(t, mock)
assert.Nil(t, err)
connManager.Inject(mockedDatasource, db)
mock.ExpectExec("any")
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}))
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf, withMysqlAcceptable())
_, err = conn.Exec("any", "value")
assert.NotNil(t, err)
_, err = conn.Prepare("any")
assert.NotNil(t, err)
var val string
assert.NotNil(t, conn.QueryRow(&val, "any"))
assert.NotNil(t, conn.QueryRowPartial(&val, "any"))
assert.NotNil(t, conn.QueryRows(&val, "any"))
assert.NotNil(t, conn.QueryRowsPartial(&val, "any"))
}
func TestConfigSqlConnStatement(t *testing.T) {
db, mock, err := sqlmock.New()
assert.NotNil(t, db)
assert.NotNil(t, mock)
assert.Nil(t, err)
connManager.Inject(mockedDatasource, db)
mock.ExpectPrepare("any")
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
mock.ExpectPrepare("any")
row := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(row)
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf, withMysqlAcceptable())
stmt, err := conn.Prepare("any")
assert.NoError(t, err)
res, err := stmt.Exec()
assert.NoError(t, err)
lastInsertID, err := res.LastInsertId()
assert.NoError(t, err)
assert.Equal(t, int64(2), lastInsertID)
rowsAffected, err := res.RowsAffected()
assert.NoError(t, err)
assert.Equal(t, int64(3), rowsAffected)
stmt, err = conn.Prepare("any")
assert.NoError(t, err)
var val string
err = stmt.QueryRow(&val)
assert.NoError(t, err)
assert.Equal(t, "bar", val)
mock.ExpectPrepare("any")
rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(rows)
stmt, err = conn.Prepare("any")
assert.NoError(t, err)
var vals []string
assert.NoError(t, stmt.QueryRowsPartial(&vals))
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
}
func TestConfigSqlConnQuery(t *testing.T) {
db, mock, err := sqlmock.New()
assert.NotNil(t, db)
assert.NotNil(t, mock)
assert.Nil(t, err)
connManager.Inject(mockedDatasource, db)
t.Run("QueryRow", func(t *testing.T) {
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar"))
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf)
var val string
assert.NoError(t, conn.QueryRow(&val, "any"))
assert.Equal(t, "bar", val)
})
t.Run("QueryRowPartial", func(t *testing.T) {
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar"))
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf)
var val string
assert.NoError(t, conn.QueryRowPartial(&val, "any"))
assert.Equal(t, "bar", val)
})
t.Run("QueryRows", func(t *testing.T) {
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar"))
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf)
var vals []string
assert.NoError(t, conn.QueryRows(&vals, "any"))
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
})
t.Run("QueryRowsPartial", func(t *testing.T) {
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar"))
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf)
var vals []string
assert.NoError(t, conn.QueryRowsPartial(&vals, "any"))
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
})
}
func TestConfigSqlConnErr(t *testing.T) {
t.Run("panic on empty config", func(t *testing.T) {
original := logx.ExitOnFatal.True()
logx.ExitOnFatal.Set(false)
defer logx.ExitOnFatal.Set(original)
assert.Panics(t, func() {
MustNewConn(SqlConf{})
})
})
t.Run("on error", func(t *testing.T) {
db, mock, err := sqlmock.New()
assert.NotNil(t, db)
assert.NotNil(t, mock)
assert.Nil(t, err)
connManager.Inject(mockedDatasource, db)
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf)
conn.(*commonSqlConn).connProv = func(ctx context.Context) (*sql.DB, error) {
return nil, errors.New("error")
}
_, err = conn.Prepare("any")
assert.Error(t, err)
})
}
func TestStatement(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectPrepare("any").WillBeClosed()
@@ -303,6 +446,93 @@ func TestWithAcceptable(t *testing.T) {
assert.True(t, conn.accept(acceptableErr3))
}
func TestProvider(t *testing.T) {
defer func() {
_ = connManager.Close()
}()
primaryDSN := "primary:password@tcp(127.0.0.1:3306)/primary_db"
replicasDSN := []string{
"replica_one:pwd@tcp(localhost:3306)/replica_one",
"replica_two:pwd@tcp(localhost:3306)/replica_two",
"replica_three:pwd@tcp(localhost:3306)/replica_three",
}
primaryDB, err := connManager.GetResource(primaryDSN, func() (io.Closer, error) { return sql.Open(mysqlDriverName, primaryDSN) })
assert.Nil(t, err)
assert.NotNil(t, primaryDB)
replicaOneDB, err := connManager.GetResource(replicasDSN[0], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[0]) })
assert.Nil(t, err)
assert.NotNil(t, replicaOneDB)
replicaTwoDB, err := connManager.GetResource(replicasDSN[1], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[1]) })
assert.Nil(t, err)
assert.NotNil(t, replicaTwoDB)
replicaThreeDB, err := connManager.GetResource(replicasDSN[2], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[2]) })
assert.Nil(t, err)
assert.NotNil(t, replicaThreeDB)
sc := &commonSqlConn{}
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, nil)
ctx := context.Background()
db, err := sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, primaryDB, db)
ctx = WithWrite(ctx)
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, primaryDB, db)
ctx = WithReadPrimary(ctx)
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, primaryDB, db)
// no mode set, should return primary
ctx = context.Background()
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, replicasDSN)
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, primaryDB, db)
ctx = WithReadReplica(ctx)
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, []string{replicasDSN[0]})
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, replicaOneDB, db)
// default policy is round-robin
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, replicasDSN)
replicas := []io.Closer{replicaOneDB, replicaTwoDB, replicaThreeDB}
for i := 0; i < len(replicasDSN); i++ {
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, replicas[i], db)
}
// random policy
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRandom, replicasDSN)
for i := 0; i < len(replicasDSN); i++ {
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Contains(t, replicas, db)
}
// unknown policy
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, "unknown", replicasDSN)
_, err = sc.connProv(ctx)
assert.NotNil(t, err)
// empty policy transforms to round-robin
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, "", replicasDSN)
for i := 0; i < len(replicasDSN); i++ {
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, replicas[i], db)
}
}
func buildConn() (mock sqlmock.Sqlmock, err error) {
_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
var db *sql.DB

View File

@@ -27,7 +27,7 @@ func getCachedSqlConn(driverName, server string) (*sql.DB, error) {
return nil, err
}
if driverName != mysqlDriverName {
if driverName == mysqlDriverName {
if cfg, e := mysql.ParseDSN(server); e != nil {
// if cannot parse, don't collect the metrics
logx.Error(e)

View File

@@ -156,7 +156,7 @@ func begin(db *sql.DB) (trans, error) {
func transact(ctx context.Context, db *commonSqlConn, b beginnable,
fn func(context.Context, Session) error) (err error) {
conn, err := db.connProv()
conn, err := db.connProv(ctx)
if err != nil {
db.onError(ctx, err)
return err

View File

@@ -117,7 +117,7 @@ func TestTxExceptions(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
conn := &commonSqlConn{
connProv: func() (*sql.DB, error) {
connProv: func(ctx context.Context) (*sql.DB, error) {
return nil, errors.New("foo")
},
beginTx: begin,

View File

@@ -2,6 +2,7 @@ package stringx
import (
"errors"
"slices"
"unicode"
"github.com/zeromicro/go-zero/core/lang"
@@ -15,14 +16,9 @@ var (
)
// Contains checks if str is in list.
// Deprecated: use slices.Contains instead.
func Contains(list []string, str string) bool {
for _, each := range list {
if each == str {
return true
}
}
return false
return slices.Contains(list, str)
}
// Filter filters chars from s with given filter function.
@@ -123,11 +119,7 @@ func Remove(strings []string, strs ...string) []string {
// Reverse reverses s.
func Reverse(s string) string {
runes := []rune(s)
for from, to := 0, len(runes)-1; from < to; from, to = from+1, to-1 {
runes[from], runes[to] = runes[to], runes[from]
}
slices.Reverse(runes)
return string(runes)
}

View File

@@ -7,6 +7,28 @@ import (
"github.com/stretchr/testify/assert"
)
func TestContainsString(t *testing.T) {
cases := []struct {
slice []string
value string
expect bool
}{
{[]string{"1"}, "1", true},
{[]string{"1"}, "2", false},
{[]string{"1", "2"}, "1", true},
{[]string{"1", "2"}, "3", false},
{nil, "3", false},
{nil, "", false},
}
for _, each := range cases {
t.Run(path.Join(each.slice...), func(t *testing.T) {
actual := Contains(each.slice, each.value)
assert.Equal(t, each.expect, actual)
})
}
}
func TestNotEmpty(t *testing.T) {
cases := []struct {
args []string
@@ -41,28 +63,6 @@ func TestNotEmpty(t *testing.T) {
}
}
func TestContainsString(t *testing.T) {
cases := []struct {
slice []string
value string
expect bool
}{
{[]string{"1"}, "1", true},
{[]string{"1"}, "2", false},
{[]string{"1", "2"}, "1", true},
{[]string{"1", "2"}, "3", false},
{nil, "3", false},
{nil, "", false},
}
for _, each := range cases {
t.Run(path.Join(each.slice...), func(t *testing.T) {
actual := Contains(each.slice, each.value)
assert.Equal(t, each.expect, actual)
})
}
}
func TestFilter(t *testing.T) {
cases := []struct {
input string

View File

@@ -3,9 +3,7 @@ package syncx
import "sync"
// Once returns a func that guarantees fn can only called once.
// Deprecated: use sync.OnceFunc instead.
func Once(fn func()) func() {
once := new(sync.Once)
return func() {
once.Do(fn)
}
return sync.OnceFunc(fn)
}

View File

@@ -5,6 +5,7 @@ import (
"runtime"
"sync"
"sync/atomic"
"time"
)
const factor = 10
@@ -100,6 +101,6 @@ func (r *StableRunner[I, O]) Wait() {
close(r.done)
r.runner.Wait()
for atomic.LoadUint64(&r.consumedIndex) < atomic.LoadUint64(&r.writtenIndex) {
runtime.Gosched()
time.Sleep(time.Millisecond)
}
}

View File

@@ -1,10 +1,10 @@
package utils
import (
"cmp"
"strconv"
"strings"
"github.com/zeromicro/go-zero/core/mathx"
"github.com/zeromicro/go-zero/core/stringx"
)
@@ -39,7 +39,7 @@ func compare(v1, v2 string) int {
fields1, fields2 := strings.Split(v1, "."), strings.Split(v2, ".")
ver1, ver2 := strsToInts(fields1), strsToInts(fields2)
ver1len, ver2len := len(ver1), len(ver2)
shorter := mathx.MinInt(ver1len, ver2len)
shorter := min(ver1len, ver2len)
for i := 0; i < shorter; i++ {
if ver1[i] == ver2[i] {
@@ -50,14 +50,7 @@ func compare(v1, v2 string) int {
return 1
}
}
if ver1len < ver2len {
return -1
} else if ver1len == ver2len {
return 0
} else {
return 1
}
return cmp.Compare(ver1len, ver2len)
}
func strsToInts(strs []string) []int64 {

View File

@@ -185,6 +185,8 @@ func (s *Server) buildHttpHandler(target *HttpClientConf) http.HandlerFunc {
return
}
// set the timeout if it's configured, take effect only if it's greater than 0
// and less than the deadline of the original request
if target.Timeout > 0 {
timeout := time.Duration(target.Timeout) * time.Millisecond
ctx, cancel := context.WithTimeout(r.Context(), timeout)
@@ -276,7 +278,7 @@ func buildRequestWithNewTarget(r *http.Request, target *HttpClientConf) (*http.R
}
}
return &http.Request{
newReq := &http.Request{
Method: r.Method,
URL: &u,
Header: r.Header.Clone(),
@@ -285,7 +287,10 @@ func buildRequestWithNewTarget(r *http.Request, target *HttpClientConf) (*http.R
ProtoMinor: r.ProtoMinor,
ContentLength: r.ContentLength,
Body: io.NopCloser(r.Body),
}, nil
}
// make sure the context is passed to the new request
return newReq.WithContext(r.Context()), nil
}
func createDescriptorSource(cli zrpc.Client, up Upstream) (grpcurl.DescriptorSource, error) {

View File

@@ -201,6 +201,13 @@ func TestHttpToHttp(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
})
t.Run("method not allowed", func(t *testing.T) {
resp, err := httpc.Do(context.Background(), http.MethodPost,
"http://localhost:18882/api/ping", nil)
assert.NoError(t, err)
assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
})
}
func TestHttpToHttpBadUpstream(t *testing.T) {

42
go.mod
View File

@@ -4,25 +4,25 @@ go 1.21
require (
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/alicebob/miniredis/v2 v2.34.0
github.com/alicebob/miniredis/v2 v2.35.0
github.com/fatih/color v1.18.0
github.com/fullstorydev/grpcurl v1.9.2
github.com/go-sql-driver/mysql v1.8.1
github.com/golang-jwt/jwt/v4 v4.5.1
github.com/fullstorydev/grpcurl v1.9.3
github.com/go-sql-driver/mysql v1.9.0
github.com/golang-jwt/jwt/v4 v4.5.2
github.com/golang/mock v1.6.0
github.com/golang/protobuf v1.5.4
github.com/google/uuid v1.6.0
github.com/jackc/pgx/v5 v5.7.2
github.com/grafana/pyroscope-go v1.2.2
github.com/jackc/pgx/v5 v5.7.4
github.com/jhump/protoreflect v1.17.0
github.com/olekukonko/tablewriter v0.0.5
github.com/pelletier/go-toml/v2 v2.2.2
github.com/prometheus/client_golang v1.20.5
github.com/redis/go-redis/v9 v9.7.0
github.com/prometheus/client_golang v1.21.1
github.com/redis/go-redis/v9 v9.11.0
github.com/spaolacci/murmur3 v1.1.0
github.com/stretchr/testify v1.10.0
go.etcd.io/etcd/api/v3 v3.5.15
go.etcd.io/etcd/client/v3 v3.5.15
go.mongodb.org/mongo-driver v1.17.2
go.mongodb.org/mongo-driver v1.17.4
go.opentelemetry.io/otel v1.24.0
go.opentelemetry.io/otel/exporters/jaeger v1.17.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.24.0
@@ -33,12 +33,12 @@ require (
go.opentelemetry.io/otel/trace v1.24.0
go.uber.org/automaxprocs v1.6.0
go.uber.org/goleak v1.3.0
golang.org/x/net v0.34.0
golang.org/x/sys v0.29.0
golang.org/x/time v0.9.0
golang.org/x/net v0.35.0
golang.org/x/sys v0.30.0
golang.org/x/time v0.10.0
google.golang.org/genproto/googleapis/api v0.0.0-20240711142825-46eb208f015d
google.golang.org/grpc v1.65.0
google.golang.org/protobuf v1.36.4
google.golang.org/protobuf v1.36.5
gopkg.in/cheggaaa/pb.v1 v1.0.28
gopkg.in/h2non/gock.v1 v1.1.2
gopkg.in/yaml.v2 v2.4.0
@@ -50,7 +50,6 @@ require (
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/bufbuild/protocompile v0.14.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
@@ -73,6 +72,7 @@ require (
github.com/google/gnostic-models v0.6.8 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/gofuzz v1.2.0 // indirect
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
@@ -80,7 +80,7 @@ require (
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.17.9 // indirect
github.com/klauspost/compress v1.17.11 // indirect
github.com/kylelemons/godebug v1.1.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
@@ -93,7 +93,7 @@ require (
github.com/openzipkin/zipkin-go v0.4.3 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.55.0 // indirect
github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
@@ -109,11 +109,11 @@ require (
go.uber.org/atomic v1.10.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
go.uber.org/zap v1.24.0 // indirect
golang.org/x/crypto v0.32.0 // indirect
golang.org/x/oauth2 v0.21.0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/term v0.28.0 // indirect
golang.org/x/text v0.21.0 // indirect
golang.org/x/crypto v0.33.0 // indirect
golang.org/x/oauth2 v0.24.0 // indirect
golang.org/x/sync v0.11.0 // indirect
golang.org/x/term v0.29.0 // indirect
golang.org/x/text v0.22.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect

85
go.sum
View File

@@ -2,10 +2,8 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE=
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
github.com/alicebob/miniredis/v2 v2.34.0 h1:mBFWMaJSNL9RwdGRyEDoAAv8OQc5UlEhLDQggTglU/0=
github.com/alicebob/miniredis/v2 v2.34.0/go.mod h1:kWShP4b58T1CW0Y5dViCd5ztzrDqRWqM3nksiyXk5s8=
github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI=
github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
@@ -40,8 +38,8 @@ github.com/envoyproxy/protoc-gen-validate v1.0.4 h1:gVPz/FMfvh57HdSJQyvBtF00j8JU
github.com/envoyproxy/protoc-gen-validate v1.0.4/go.mod h1:qys6tmnRsYrQqIhm2bvKZH4Blx/1gTIZ2UKVY1M+Yew=
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
github.com/fullstorydev/grpcurl v1.9.2 h1:ObqVQTZW7aFnhuqQoppUrvep2duMBanB0UYK2Mm8euo=
github.com/fullstorydev/grpcurl v1.9.2/go.mod h1:jLfcF55HAz6TYIJY9xFFWgsl0D7o2HlxA5Z4lUG0Tdo=
github.com/fullstorydev/grpcurl v1.9.3 h1:PC1Xi3w+JAvEE2Tg2Gf2RfVgPbf9+tbuQr1ZkyVU3jk=
github.com/fullstorydev/grpcurl v1.9.3/go.mod h1:/b4Wxe8bG6ndAjlfSUjwseQReUDUvBJiFEB7UllOlUE=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
@@ -55,15 +53,15 @@ github.com/go-openapi/jsonreference v0.20.2/go.mod h1:Bl1zwGIM8/wsvqjsOQLJ/SH+En
github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14=
github.com/go-openapi/swag v0.22.4 h1:QLMzNJnMGPRNDCbySlcj1x01tzU8/9LTTL9hZZZogBU=
github.com/go-openapi/swag v0.22.4/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/go-sql-driver/mysql v1.9.0 h1:Y0zIbQXhQKmQgTp44Y1dp3wTXcn804QoTptLZT1vtvo=
github.com/go-sql-driver/mysql v1.9.0/go.mod h1:pDetrLJeA3oMujJuvXc8RJoasr589B6A9fwzD3QMrqw=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v4 v4.5.1 h1:JdqV9zKUdtaa9gdPlywC3aeoEsR681PlKC+4F5gQgeo=
github.com/golang-jwt/jwt/v4 v4.5.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
@@ -82,6 +80,10 @@ github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJY
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/grafana/pyroscope-go v1.2.2 h1:uvKCyZMD724RkaCEMrSTC38Yn7AnFe8S2wiAIYdDPCE=
github.com/grafana/pyroscope-go v1.2.2/go.mod h1:zzT9QXQAp2Iz2ZdS216UiV8y9uXJYQiGE1q8v1FyhqU=
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 h1:iwOtYXeeVSAeYefJNaxDytgjKtUuKQbJqgAIjlnicKg=
github.com/grafana/pyroscope-go/godeltaprof v0.1.8/go.mod h1:2+l7K7twW49Ct4wFluZD3tZ6e0SjanjcUUBPVD/UuGU=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k=
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
@@ -90,8 +92,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.7.2 h1:mLoDLV6sonKlvjIEsV56SkWNCnuNv531l94GaIzO+XI=
github.com/jackc/pgx/v5 v5.7.2/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg=
github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jhump/protoreflect v1.17.0 h1:qOEr613fac2lOuTgWN4tPAtLL7fUSbuJL5X5XumQh94=
@@ -103,8 +105,8 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
@@ -121,7 +123,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -135,8 +136,6 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32 h1:W6apQkHrMkS0Muv8G/TipAy/FJl/rCYT0+EuS8+Z0z4=
github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4=
github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o=
github.com/onsi/gomega v1.29.0 h1:KIA/t2t5UBzoirT4H9tsML45GEbo3ouUnBHsCfD2tVg=
@@ -151,16 +150,16 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
github.com/prometheus/client_golang v1.20.5 h1:cxppBPuYhUnsO6yo/aoRol4L7q7UFfdm+bR9r+8l63Y=
github.com/prometheus/client_golang v1.20.5/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE=
github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk=
github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg=
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc=
github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8=
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/redis/go-redis/v9 v9.7.0 h1:HhLSs+B6O021gwzl+locl0zEDnyNkxMtf/Z3NNBMa9E=
github.com/redis/go-redis/v9 v9.7.0/go.mod h1:f6zhXITC7JUJIlPEiBOTXxJgPLdZcA93GewI7inzyWw=
github.com/redis/go-redis/v9 v9.11.0 h1:E3S08Gl/nJNn5vkxd2i78wZxWAPNZgUNTp8WIJUAiIs=
github.com/redis/go-redis/v9 v9.11.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
@@ -203,8 +202,8 @@ go.etcd.io/etcd/client/pkg/v3 v3.5.15 h1:fo0HpWz/KlHGMCC+YejpiCmyWDEuIpnTDzpJLB5
go.etcd.io/etcd/client/pkg/v3 v3.5.15/go.mod h1:mXDI4NAOwEiszrHCb0aqfAYNCrZP4e9hRca3d1YK8EU=
go.etcd.io/etcd/client/v3 v3.5.15 h1:23M0eY4Fd/inNv1ZfU3AxrbbOdW79r9V9Rl62Nm6ip4=
go.etcd.io/etcd/client/v3 v3.5.15/go.mod h1:CLSJxrYjvLtHsrPKsy7LmZEE+DK2ktfd2bN4RhBMwlU=
go.mongodb.org/mongo-driver v1.17.2 h1:gvZyk8352qSfzyZ2UMWcpDpMSGEr1eqE4T793SqyhzM=
go.mongodb.org/mongo-driver v1.17.2/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ=
go.mongodb.org/mongo-driver v1.17.4 h1:jUorfmVzljjr0FLzYQsGP8cgN/qzzxlY9Vh0C9KFXVw=
go.mongodb.org/mongo-driver v1.17.4/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ=
go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo=
go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo=
go.opentelemetry.io/otel/exporters/jaeger v1.17.0 h1:D7UpUy2Xc2wsi1Ras6V40q806WM07rqoCWzXu7Sqy+4=
@@ -241,8 +240,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
@@ -254,17 +253,17 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs=
golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE=
golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -276,20 +275,20 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg=
golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek=
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4=
golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
@@ -308,8 +307,8 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 h1:
google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY=
google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc=
google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ=
google.golang.org/protobuf v1.36.4 h1:6A3ZDJHn/eNqc1i+IdefRzy/9PokBTPvcqMySR7NNIM=
google.golang.org/protobuf v1.36.4/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=

View File

@@ -0,0 +1,263 @@
package profiling
import (
"runtime"
"sync"
"time"
"github.com/grafana/pyroscope-go"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/proc"
"github.com/zeromicro/go-zero/core/stat"
"github.com/zeromicro/go-zero/core/threading"
)
const (
defaultCheckInterval = time.Second * 10
defaultProfilingDuration = time.Minute * 2
defaultUploadRate = time.Second * 15
)
type (
Config struct {
// Name is the name of the application.
Name string `json:",optional,inherit"`
// ServerAddr is the address of the profiling server.
ServerAddr string
// AuthUser is the username for basic authentication.
AuthUser string `json:",optional"`
// AuthPassword is the password for basic authentication.
AuthPassword string `json:",optional"`
// UploadRate is the duration for which profiling data is uploaded.
UploadRate time.Duration `json:",default=15s"`
// CheckInterval is the interval to check if profiling should start.
CheckInterval time.Duration `json:",default=10s"`
// ProfilingDuration is the duration for which profiling data is collected.
ProfilingDuration time.Duration `json:",default=2m"`
// CpuThreshold the collection is allowed only when the current service cpu > CpuThreshold
CpuThreshold int64 `json:",default=700,range=[0:1000)"`
// ProfileType is the type of profiling to be performed.
ProfileType ProfileType
}
ProfileType struct {
// Logger is a flag to enable or disable logging.
Logger bool `json:",default=false"`
// CPU is a flag to disable CPU profiling.
CPU bool `json:",default=true"`
// Goroutines is a flag to disable goroutine profiling.
Goroutines bool `json:",default=true"`
// Memory is a flag to disable memory profiling.
Memory bool `json:",default=true"`
// Mutex is a flag to disable mutex profiling.
Mutex bool `json:",default=false"`
// Block is a flag to disable block profiling.
Block bool `json:",default=false"`
}
profiler interface {
Start() error
Stop() error
}
pyroscopeProfiler struct {
c Config
profiler *pyroscope.Profiler
}
)
var (
once sync.Once
newProfiler = func(c Config) profiler {
return newPyroscopeProfiler(c)
}
)
// Start initializes the pyroscope profiler with the given configuration.
func Start(c Config) {
// check if the profiling is enabled
if len(c.ServerAddr) == 0 {
return
}
// set default values for the configuration
if c.ProfilingDuration <= 0 {
c.ProfilingDuration = defaultProfilingDuration
}
// set default values for the configuration
if c.CheckInterval <= 0 {
c.CheckInterval = defaultCheckInterval
}
if c.UploadRate <= 0 {
c.UploadRate = defaultUploadRate
}
once.Do(func() {
logx.Info("continuous profiling started")
threading.GoSafe(func() {
startPyroscope(c, proc.Done())
})
})
}
// startPyroscope starts the pyroscope profiler with the given configuration.
func startPyroscope(c Config, done <-chan struct{}) {
var (
pr profiler
err error
latestProfilingTime time.Time
intervalTicker = time.NewTicker(c.CheckInterval)
profilingTicker = time.NewTicker(c.ProfilingDuration)
)
defer profilingTicker.Stop()
defer intervalTicker.Stop()
for {
select {
case <-intervalTicker.C:
// Check if the machine is overloaded and if the profiler is not running
if pr == nil && isCpuOverloaded(c) {
pr = newProfiler(c)
if err := pr.Start(); err != nil {
logx.Errorf("failed to start profiler: %v", err)
continue
}
// record the latest profiling time
latestProfilingTime = time.Now()
logx.Infof("pyroscope profiler started.")
}
case <-profilingTicker.C:
// check if the profiling duration has passed
if !time.Now().After(latestProfilingTime.Add(c.ProfilingDuration)) {
continue
}
// check if the profiler is already running, if so, skip
if pr != nil {
if err = pr.Stop(); err != nil {
logx.Errorf("failed to stop profiler: %v", err)
}
logx.Infof("pyroscope profiler stopped.")
pr = nil
}
case <-done:
logx.Infof("continuous profiling stopped.")
return
}
}
}
// genPyroscopeConf generates the pyroscope configuration based on the given config.
func genPyroscopeConf(c Config) pyroscope.Config {
pConf := pyroscope.Config{
UploadRate: c.UploadRate,
ApplicationName: c.Name,
BasicAuthUser: c.AuthUser, // http basic auth user
BasicAuthPassword: c.AuthPassword, // http basic auth password
ServerAddress: c.ServerAddr,
Logger: nil,
HTTPHeaders: map[string]string{},
// you can provide static tags via a map:
Tags: map[string]string{
"name": c.Name,
},
}
if c.ProfileType.Logger {
pConf.Logger = logx.WithCallerSkip(0)
}
if c.ProfileType.CPU {
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileCPU)
}
if c.ProfileType.Goroutines {
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileGoroutines)
}
if c.ProfileType.Memory {
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileAllocObjects, pyroscope.ProfileAllocSpace,
pyroscope.ProfileInuseObjects, pyroscope.ProfileInuseSpace)
}
if c.ProfileType.Mutex {
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileMutexCount, pyroscope.ProfileMutexDuration)
}
if c.ProfileType.Block {
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileBlockCount, pyroscope.ProfileBlockDuration)
}
logx.Infof("applicationName: %s", pConf.ApplicationName)
return pConf
}
// isCpuOverloaded checks the machine performance based on the given configuration.
func isCpuOverloaded(c Config) bool {
currentValue := stat.CpuUsage()
if currentValue >= c.CpuThreshold {
logx.Infof("continuous profiling cpu overload, cpu: %d", currentValue)
return true
}
return false
}
func newPyroscopeProfiler(c Config) profiler {
return &pyroscopeProfiler{
c: c,
}
}
func (p *pyroscopeProfiler) Start() error {
pConf := genPyroscopeConf(p.c)
// set mutex and block profile rate
setFraction(p.c)
prof, err := pyroscope.Start(pConf)
if err != nil {
resetFraction(p.c)
return err
}
p.profiler = prof
return nil
}
func (p *pyroscopeProfiler) Stop() error {
if p.profiler == nil {
return nil
}
if err := p.profiler.Stop(); err != nil {
return err
}
resetFraction(p.c)
p.profiler = nil
return nil
}
func setFraction(c Config) {
// These 2 lines are only required if you're using mutex or block profiling
if c.ProfileType.Mutex {
runtime.SetMutexProfileFraction(10) // 10/seconds
}
if c.ProfileType.Block {
runtime.SetBlockProfileRate(1000 * 1000) // 1/millisecond
}
}
func resetFraction(c Config) {
// These 2 lines are only required if you're using mutex or block profiling
if c.ProfileType.Mutex {
runtime.SetMutexProfileFraction(0)
}
if c.ProfileType.Block {
runtime.SetBlockProfileRate(0)
}
}

View File

@@ -0,0 +1,177 @@
package profiling
import (
"sync"
"testing"
"time"
"github.com/grafana/pyroscope-go"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/conf"
"github.com/zeromicro/go-zero/core/syncx"
)
func TestStart(t *testing.T) {
t.Run("profiling", func(t *testing.T) {
var c Config
assert.NoError(t, conf.FillDefault(&c))
c.Name = "test"
p := newProfiler(c)
assert.NotNil(t, p)
assert.NoError(t, p.Start())
assert.NoError(t, p.Stop())
})
t.Run("invalid config", func(t *testing.T) {
mp := &mockProfiler{}
newProfiler = func(c Config) profiler {
return mp
}
Start(Config{})
Start(Config{
ServerAddr: "localhost:4040",
})
})
t.Run("test start profiler", func(t *testing.T) {
mp := &mockProfiler{}
newProfiler = func(c Config) profiler {
return mp
}
c := Config{
Name: "test",
ServerAddr: "localhost:4040",
CheckInterval: time.Millisecond,
ProfilingDuration: time.Millisecond * 10,
CpuThreshold: 0,
}
var done = make(chan struct{})
go startPyroscope(c, done)
time.Sleep(time.Millisecond * 50)
close(done)
assert.True(t, mp.started.True())
assert.True(t, mp.stopped.True())
})
t.Run("test start profiler with cpu overloaded", func(t *testing.T) {
mp := &mockProfiler{}
newProfiler = func(c Config) profiler {
return mp
}
c := Config{
Name: "test",
ServerAddr: "localhost:4040",
CheckInterval: time.Millisecond,
ProfilingDuration: time.Millisecond * 10,
CpuThreshold: 900,
}
var done = make(chan struct{})
go startPyroscope(c, done)
time.Sleep(time.Millisecond * 50)
close(done)
assert.False(t, mp.started.True())
})
t.Run("start/stop err", func(t *testing.T) {
mp := &mockProfiler{
err: assert.AnError,
}
newProfiler = func(c Config) profiler {
return mp
}
c := Config{
Name: "test",
ServerAddr: "localhost:4040",
CheckInterval: time.Millisecond,
ProfilingDuration: time.Millisecond * 10,
CpuThreshold: 0,
}
var done = make(chan struct{})
go startPyroscope(c, done)
time.Sleep(time.Millisecond * 50)
close(done)
assert.False(t, mp.started.True())
assert.False(t, mp.stopped.True())
})
}
func TestGenPyroscopeConf(t *testing.T) {
c := Config{
Name: "",
ServerAddr: "localhost:4040",
AuthUser: "user",
AuthPassword: "password",
ProfileType: ProfileType{
Logger: true,
CPU: true,
Goroutines: true,
Memory: true,
Mutex: true,
Block: true,
},
}
pyroscopeConf := genPyroscopeConf(c)
assert.Equal(t, c.ServerAddr, pyroscopeConf.ServerAddress)
assert.Equal(t, c.AuthUser, pyroscopeConf.BasicAuthUser)
assert.Equal(t, c.AuthPassword, pyroscopeConf.BasicAuthPassword)
assert.Equal(t, c.Name, pyroscopeConf.ApplicationName)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileCPU)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileGoroutines)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileAllocObjects)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileAllocSpace)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileInuseObjects)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileInuseSpace)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileMutexCount)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileMutexDuration)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileBlockCount)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileBlockDuration)
setFraction(c)
resetFraction(c)
newPyroscopeProfiler(c)
}
func TestNewPyroscopeProfiler(t *testing.T) {
p := newPyroscopeProfiler(Config{})
assert.Error(t, p.Start())
assert.NoError(t, p.Stop())
}
type mockProfiler struct {
mutex sync.Mutex
started syncx.AtomicBool
stopped syncx.AtomicBool
err error
}
func (m *mockProfiler) Start() error {
m.mutex.Lock()
if m.err == nil {
m.started.Set(true)
}
m.mutex.Unlock()
return m.err
}
func (m *mockProfiler) Stop() error {
m.mutex.Lock()
if m.err == nil {
m.stopped.Set(true)
}
m.mutex.Unlock()
return m.err
}

43
mcp/config.go Normal file
View File

@@ -0,0 +1,43 @@
package mcp
import (
"time"
"github.com/zeromicro/go-zero/rest"
)
// McpConf defines the configuration for an MCP server.
// It embeds rest.RestConf for HTTP server settings
// and adds MCP-specific configuration options.
type McpConf struct {
rest.RestConf
Mcp struct {
// Name is the server name reported in initialize responses
Name string `json:",optional"`
// Version is the server version reported in initialize responses
Version string `json:",default=1.0.0"`
// ProtocolVersion is the MCP protocol version implemented
ProtocolVersion string `json:",default=2024-11-05"`
// BaseUrl is the base URL for the server, used in SSE endpoint messages
// If not set, defaults to http://localhost:{Port}
BaseUrl string `json:",optional"`
// SseEndpoint is the path for Server-Sent Events connections
SseEndpoint string `json:",default=/sse"`
// MessageEndpoint is the path for JSON-RPC requests
MessageEndpoint string `json:",default=/message"`
// Cors contains allowed CORS origins
Cors []string `json:",optional"`
// SseTimeout is the maximum time allowed for SSE connections
SseTimeout time.Duration `json:",default=24h"`
// MessageTimeout is the maximum time allowed for request execution
MessageTimeout time.Duration `json:",default=30s"`
}
}

63
mcp/config_test.go Normal file
View File

@@ -0,0 +1,63 @@
package mcp
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/conf"
)
func TestMcpConfDefaults(t *testing.T) {
// Test default values are set correctly when unmarshalled from JSON
jsonConfig := `name: test-service
port: 8080
mcp:
name: test-mcp-server
version: 1.0.0
`
var c McpConf
err := conf.LoadFromYamlBytes([]byte(jsonConfig), &c)
assert.NoError(t, err)
// Check default values
assert.Equal(t, "test-mcp-server", c.Mcp.Name)
assert.Equal(t, "1.0.0", c.Mcp.Version, "Default version should be 1.0.0")
assert.Equal(t, "2024-11-05", c.Mcp.ProtocolVersion, "Default protocol version should be 2024-11-05")
assert.Equal(t, "/sse", c.Mcp.SseEndpoint, "Default SSE endpoint should be /sse")
assert.Equal(t, "/message", c.Mcp.MessageEndpoint, "Default message endpoint should be /message")
assert.Equal(t, 30*time.Second, c.Mcp.MessageTimeout, "Default message timeout should be 30s")
}
func TestMcpConfCustomValues(t *testing.T) {
// Test custom values can be set
jsonConfig := `{
"Name": "test-service",
"Port": 8080,
"Mcp": {
"Name": "test-mcp-server",
"Version": "2.0.0",
"ProtocolVersion": "2025-01-01",
"BaseUrl": "http://example.com",
"SseEndpoint": "/custom-sse",
"MessageEndpoint": "/custom-message",
"Cors": ["http://localhost:3000", "http://example.com"],
"MessageTimeout": "60s"
}
}`
var c McpConf
err := conf.LoadFromJsonBytes([]byte(jsonConfig), &c)
assert.NoError(t, err)
// Check custom values
assert.Equal(t, "test-mcp-server", c.Mcp.Name, "Name should be inherited from RestConf")
assert.Equal(t, "2.0.0", c.Mcp.Version, "Version should be customizable")
assert.Equal(t, "2025-01-01", c.Mcp.ProtocolVersion, "Protocol version should be customizable")
assert.Equal(t, "http://example.com", c.Mcp.BaseUrl, "BaseUrl should be customizable")
assert.Equal(t, "/custom-sse", c.Mcp.SseEndpoint, "SSE endpoint should be customizable")
assert.Equal(t, "/custom-message", c.Mcp.MessageEndpoint, "Message endpoint should be customizable")
assert.Equal(t, []string{"http://localhost:3000", "http://example.com"}, c.Mcp.Cors, "CORS settings should be customizable")
assert.Equal(t, 60*time.Second, c.Mcp.MessageTimeout, "Tool timeout should be customizable")
}

443
mcp/integration_test.go Normal file
View File

@@ -0,0 +1,443 @@
package mcp
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// syncResponseRecorder is a thread-safe wrapper around httptest.ResponseRecorder
type syncResponseRecorder struct {
*httptest.ResponseRecorder
mu sync.Mutex
}
// Create a new synchronized response recorder
func newSyncResponseRecorder() *syncResponseRecorder {
return &syncResponseRecorder{
ResponseRecorder: httptest.NewRecorder(),
}
}
// Override Write method to synchronize access
func (srr *syncResponseRecorder) Write(p []byte) (int, error) {
srr.mu.Lock()
defer srr.mu.Unlock()
return srr.ResponseRecorder.Write(p)
}
// Override WriteHeader method to synchronize access
func (srr *syncResponseRecorder) WriteHeader(statusCode int) {
srr.mu.Lock()
defer srr.mu.Unlock()
srr.ResponseRecorder.WriteHeader(statusCode)
}
// Override Result method to synchronize access
func (srr *syncResponseRecorder) Result() *http.Response {
srr.mu.Lock()
defer srr.mu.Unlock()
return srr.ResponseRecorder.Result()
}
// TestHTTPHandlerIntegration tests the HTTP handlers with a real server instance
func TestHTTPHandlerIntegration(t *testing.T) {
// Skip in short test mode
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
// Create a test configuration
conf := McpConf{}
conf.Mcp.Name = "test-integration"
conf.Mcp.Version = "1.0.0-test"
conf.Mcp.MessageTimeout = 1 * time.Second
// Create a mock server directly
server := &sseMcpServer{
conf: conf,
clients: make(map[string]*mcpClient),
tools: make(map[string]Tool),
prompts: make(map[string]Prompt),
resources: make(map[string]Resource),
}
// Register a test tool
err := server.RegisterTool(Tool{
Name: "echo",
Description: "Echo tool for testing",
InputSchema: InputSchema{
Properties: map[string]any{
"message": map[string]any{
"type": "string",
"description": "Message to echo",
},
},
},
Handler: func(ctx context.Context, params map[string]any) (any, error) {
if msg, ok := params["message"].(string); ok {
return fmt.Sprintf("Echo: %s", msg), nil
}
return "Echo: no message provided", nil
},
})
require.NoError(t, err)
// Create a test HTTP request to the SSE endpoint
req := httptest.NewRequest("GET", "/sse", nil)
w := newSyncResponseRecorder()
// Create a done channel to signal completion of test
done := make(chan bool)
// Start the SSE handler in a goroutine
go func() {
// lock.Lock()
server.handleSSE(w, req)
// lock.Unlock()
done <- true
}()
// Allow time for the handler to process
select {
case <-time.After(100 * time.Millisecond):
// Expected - handler would normally block indefinitely
case <-done:
// This shouldn't happen immediately - the handler should block
t.Error("SSE handler returned unexpectedly")
}
// Check the initial headers
resp := w.Result()
assert.Equal(t, "chunked", resp.Header.Get("Transfer-Encoding"))
resp.Body.Close()
// The handler creates a client and sends the endpoint message
var sessionId string
// Give the handler time to set up the client
time.Sleep(50 * time.Millisecond)
// Check that a client was created
server.clientsLock.Lock()
assert.Equal(t, 1, len(server.clients))
for id := range server.clients {
sessionId = id
}
server.clientsLock.Unlock()
require.NotEmpty(t, sessionId, "Expected a session ID to be created")
// Now that we have a session ID, we can test the message endpoint
messageBody, _ := json.Marshal(Request{
JsonRpc: "2.0",
ID: 1,
Method: methodInitialize,
Params: json.RawMessage(`{}`),
})
// Create a message request
reqURL := fmt.Sprintf("/message?%s=%s", sessionIdKey, sessionId)
msgReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(messageBody))
msgW := newSyncResponseRecorder()
// Process the message
server.handleRequest(msgW, msgReq)
// Check the response
msgResp := msgW.Result()
assert.Equal(t, http.StatusAccepted, msgResp.StatusCode)
msgResp.Body.Close() // Ensure response body is closed
}
// TestHandlerResponseFlow tests the flow of a full request/response cycle
func TestHandlerResponseFlow(t *testing.T) {
// Create a mock server for testing
server := &sseMcpServer{
conf: McpConf{},
clients: map[string]*mcpClient{
"test-session": {
id: "test-session",
channel: make(chan string, 10),
initialized: true,
},
},
tools: make(map[string]Tool),
prompts: make(map[string]Prompt),
resources: make(map[string]Resource),
}
// Register test resources
server.RegisterTool(Tool{
Name: "test.tool",
Description: "Test tool",
InputSchema: InputSchema{Type: "object"},
Handler: func(ctx context.Context, params map[string]any) (any, error) {
return "tool result", nil
},
})
server.RegisterPrompt(Prompt{
Name: "test.prompt",
Description: "Test prompt",
})
server.RegisterResource(Resource{
Name: "test.resource",
URI: "http://example.com",
Description: "Test resource",
})
// Create a request with session ID parameter
reqURL := fmt.Sprintf("/message?%s=%s", sessionIdKey, "test-session")
// Test tools/list request
toolsListBody, _ := json.Marshal(Request{
JsonRpc: "2.0",
ID: 1,
Method: methodToolsList,
Params: json.RawMessage(`{}`),
})
toolsReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(toolsListBody))
toolsW := newSyncResponseRecorder()
// Process the request
server.handleRequest(toolsW, toolsReq)
// Check the response code
toolsResp := toolsW.Result()
assert.Equal(t, http.StatusAccepted, toolsResp.StatusCode)
toolsResp.Body.Close()
// Check the channel message
client := server.clients["test-session"]
select {
case message := <-client.channel:
assert.Contains(t, message, `"tools":[{"name":"test.tool"`)
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for tools/list response")
}
// Test prompts/list request
promptsListBody, _ := json.Marshal(Request{
JsonRpc: "2.0",
ID: 2,
Method: methodPromptsList,
Params: json.RawMessage(`{}`),
})
promptsReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(promptsListBody))
promptsW := newSyncResponseRecorder()
// Process the request
server.handleRequest(promptsW, promptsReq)
// Check the response code
promptsResp := promptsW.Result()
assert.Equal(t, http.StatusAccepted, promptsResp.StatusCode)
promptsResp.Body.Close()
// Check the channel message
select {
case message := <-client.channel:
assert.Contains(t, message, `"prompts":[{"name":"test.prompt"`)
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for prompts/list response")
}
// Test resources/list request
resourcesListBody, _ := json.Marshal(Request{
JsonRpc: "2.0",
ID: 3,
Method: methodResourcesList,
Params: json.RawMessage(`{}`),
})
resourcesReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(resourcesListBody))
resourcesW := newSyncResponseRecorder()
// Process the request
server.handleRequest(resourcesW, resourcesReq)
// Check the response code
resourcesResp := resourcesW.Result()
assert.Equal(t, http.StatusAccepted, resourcesResp.StatusCode)
resourcesResp.Body.Close()
// Check the channel message
select {
case message := <-client.channel:
assert.Contains(t, message, `"name":"test.resource"`)
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for resources/list response")
}
}
// TestProcessListMethods tests the list processing methods with pagination
func TestProcessListMethods(t *testing.T) {
server := &sseMcpServer{
tools: make(map[string]Tool),
prompts: make(map[string]Prompt),
resources: make(map[string]Resource),
}
// Add some test data
for i := 1; i <= 5; i++ {
tool := Tool{
Name: fmt.Sprintf("tool%d", i),
Description: fmt.Sprintf("Tool %d", i),
InputSchema: InputSchema{Type: "object"},
}
server.tools[tool.Name] = tool
prompt := Prompt{
Name: fmt.Sprintf("prompt%d", i),
Description: fmt.Sprintf("Prompt %d", i),
}
server.prompts[prompt.Name] = prompt
resource := Resource{
Name: fmt.Sprintf("resource%d", i),
URI: fmt.Sprintf("http://example.com/%d", i),
Description: fmt.Sprintf("Resource %d", i),
}
server.resources[resource.Name] = resource
}
// Create a test client
client := &mcpClient{
id: "test-client",
channel: make(chan string, 10),
initialized: true,
}
// Test processListTools
req := Request{
JsonRpc: "2.0",
ID: 1,
Method: methodToolsList,
Params: json.RawMessage(`{"cursor": "", "_meta": {"progressToken": "token1"}}`),
}
server.processListTools(context.Background(), client, req)
// Read response
select {
case response := <-client.channel:
assert.Contains(t, response, `"tools":`)
assert.Contains(t, response, `"progressToken":"token1"`)
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for tools/list response")
}
// Test processListPrompts
req.ID = 2
req.Method = methodPromptsList
req.Params = json.RawMessage(`{"cursor": "next"}`)
server.processListPrompts(context.Background(), client, req)
// Read response
select {
case response := <-client.channel:
assert.Contains(t, response, `"prompts":`)
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for prompts/list response")
}
// Test processListResources
req.ID = 3
req.Method = methodResourcesList
req.Params = json.RawMessage(`{"cursor": "next"}`)
server.processListResources(context.Background(), client, req)
// Read response
select {
case response := <-client.channel:
assert.Contains(t, response, `"resources":`)
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for resources/list response")
}
}
// TestErrorResponseHandling tests error handling in the server
func TestErrorResponseHandling(t *testing.T) {
server := &sseMcpServer{
tools: make(map[string]Tool),
prompts: make(map[string]Prompt),
resources: make(map[string]Resource),
}
// Create a test client
client := &mcpClient{
id: "test-client",
channel: make(chan string, 10),
initialized: true,
}
// Test invalid method
req := Request{
JsonRpc: "2.0",
ID: 1,
Method: "invalid_method",
Params: json.RawMessage(`{}`),
}
// Mock handleRequest by directly calling error handler
server.sendErrorResponse(context.Background(), client, req.ID, "Method not found", errCodeMethodNotFound)
// Check response
select {
case response := <-client.channel:
assert.Contains(t, response, `"error":{"code":-32601,"message":"Method not found"}`)
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for error response")
}
// Test invalid tool
toolReq := Request{
JsonRpc: "2.0",
ID: 2,
Method: methodToolsCall,
Params: json.RawMessage(`{"name":"non_existent_tool"}`),
}
// Call process method directly
server.processToolCall(context.Background(), client, toolReq)
// Check response
select {
case response := <-client.channel:
assert.Contains(t, response, `"error":{"code":-32602,"message":"Tool 'non_existent_tool' not found"}`)
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for error response")
}
// Test invalid prompt
promptReq := Request{
JsonRpc: "2.0",
ID: 3,
Method: methodPromptsGet,
Params: json.RawMessage(`{"name":"non_existent_prompt"}`),
}
// Call process method directly
server.processGetPrompt(context.Background(), client, promptReq)
// Check response
select {
case response := <-client.channel:
assert.Contains(t, response, `"error":{"code":-32602,"message":"Prompt 'non_existent_prompt' not found"}`)
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for error response")
}
}

23
mcp/parser.go Normal file
View File

@@ -0,0 +1,23 @@
package mcp
import (
"fmt"
"github.com/zeromicro/go-zero/core/mapping"
)
// ParseArguments parses the arguments and populates the request object
func ParseArguments(args any, req any) error {
switch arguments := args.(type) {
case map[string]string:
m := make(map[string]any, len(arguments))
for k, v := range arguments {
m[k] = v
}
return mapping.UnmarshalJsonMap(m, req, mapping.WithStringValues())
case map[string]any:
return mapping.UnmarshalJsonMap(arguments, req)
default:
return fmt.Errorf("unsupported argument type: %T", arguments)
}
}

139
mcp/parser_test.go Normal file
View File

@@ -0,0 +1,139 @@
package mcp
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestParseArguments_MapStringString tests parsing map[string]string arguments
func TestParseArguments_MapStringString(t *testing.T) {
// Sample request struct to populate
type requestStruct struct {
Name string `json:"name"`
Message string `json:"message"`
Count int `json:"count"`
Enabled bool `json:"enabled"`
}
// Create test arguments
args := map[string]string{
"name": "test-name",
"message": "hello world",
"count": "42",
"enabled": "true",
}
// Create a target object to populate
var req requestStruct
// Parse the arguments
err := ParseArguments(args, &req)
// Verify results
assert.NoError(t, err, "Should parse map[string]string without error")
assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed")
assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed")
assert.Equal(t, 42, req.Count, "Count should be correctly parsed to int")
assert.True(t, req.Enabled, "Enabled should be correctly parsed to bool")
}
// TestParseArguments_MapStringAny tests parsing map[string]any arguments
func TestParseArguments_MapStringAny(t *testing.T) {
// Sample request struct to populate
type requestStruct struct {
Name string `json:"name"`
Message string `json:"message"`
Count int `json:"count"`
Enabled bool `json:"enabled"`
Tags []string `json:"tags"`
Metadata map[string]string `json:"metadata"`
}
// Create test arguments with mixed types
args := map[string]any{
"name": "test-name",
"message": "hello world",
"count": 42, // note: this is already an int
"enabled": true, // note: this is already a bool
"tags": []string{"tag1", "tag2"},
"metadata": map[string]string{
"key1": "value1",
"key2": "value2",
},
}
// Create a target object to populate
var req requestStruct
// Parse the arguments
err := ParseArguments(args, &req)
// Verify results
assert.NoError(t, err, "Should parse map[string]any without error")
assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed")
assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed")
assert.Equal(t, 42, req.Count, "Count should be correctly parsed")
assert.True(t, req.Enabled, "Enabled should be correctly parsed")
assert.Equal(t, []string{"tag1", "tag2"}, req.Tags, "Tags should be correctly parsed")
assert.Equal(t, map[string]string{
"key1": "value1",
"key2": "value2",
}, req.Metadata, "Metadata should be correctly parsed")
}
// TestParseArguments_UnsupportedType tests parsing with an unsupported type
func TestParseArguments_UnsupportedType(t *testing.T) {
// Sample request struct to populate
type requestStruct struct {
Name string `json:"name"`
Message string `json:"message"`
}
// Use an unsupported argument type (slice)
args := []string{"not", "a", "map"}
// Create a target object to populate
var req requestStruct
// Parse the arguments
err := ParseArguments(args, &req)
// Verify error is returned with correct message
assert.Error(t, err, "Should return error for unsupported type")
assert.Contains(t, err.Error(), "unsupported argument type", "Error should mention unsupported type")
assert.Contains(t, err.Error(), "[]string", "Error should include the actual type")
}
// TestParseArguments_EmptyMap tests parsing with empty maps
func TestParseArguments_EmptyMap(t *testing.T) {
// Sample request struct to populate
type requestStruct struct {
Name string `json:"name,optional"`
Message string `json:"message,optional"`
}
// Test empty map[string]string
t.Run("EmptyMapStringString", func(t *testing.T) {
args := map[string]string{}
var req requestStruct
err := ParseArguments(args, &req)
assert.NoError(t, err, "Should parse empty map[string]string without error")
assert.Empty(t, req.Name, "Name should be empty string")
assert.Empty(t, req.Message, "Message should be empty string")
})
// Test empty map[string]any
t.Run("EmptyMapStringAny", func(t *testing.T) {
args := map[string]any{}
var req requestStruct
err := ParseArguments(args, &req)
assert.NoError(t, err, "Should parse empty map[string]any without error")
assert.Empty(t, req.Name, "Name should be empty string")
assert.Empty(t, req.Message, "Message should be empty string")
})
}

870
mcp/readme.md Normal file
View File

@@ -0,0 +1,870 @@
# Model Context Protocol (MCP) Implementation
## Overview
This package implements the Model Context Protocol (MCP) server specification in Go, providing a framework for real-time communication between AI models and clients using Server-Sent Events (SSE). The implementation follows the standardized protocol for building AI-assisted applications with bidirectional communication capabilities.
## Core Components
### Server-Sent Events (SSE) Communication
- **Real-time Communication**: Robust SSE-based communication system that maintains persistent connections with clients
- **Connection Management**: Client registration, message broadcasting, and client cleanup mechanisms
- **Event Handling**: Event types for tools, prompts, and resources changes
### JSON-RPC Implementation
- **Request Processing**: Complete JSON-RPC request processor for handling MCP protocol methods
- **Response Formatting**: Proper response formatting according to JSON-RPC specifications
- **Error Handling**: Comprehensive error handling with appropriate error codes
### Tool Management
- **Tool Registration**: System to register custom tools with handlers
- **Tool Execution**: Mechanism to execute tool functions with proper timeout handling
- **Result Handling**: Flexible result handling supporting various return types (string, JSON, images)
### Prompt System
- **Prompt Registration**: System for registering both static and dynamic prompts
- **Argument Validation**: Validation for required arguments and default values for optional ones
- **Message Generation**: Handlers that generate properly formatted conversation messages
### Resource Management
- **Resource Registration**: System for managing and accessing external resources
- **Content Delivery**: Handlers for delivering resource content to clients on demand
- **Resource Subscription**: Mechanisms for clients to subscribe to resource updates
### Protocol Features
- **Initialization Sequence**: Proper handshaking with capability negotiation
- **Notification Handling**: Support for both standard and client-specific notifications
- **Message Routing**: Intelligent routing of requests to appropriate handlers
## Technical Highlights
### Configuration System
- **Flexible Configuration**: Configuration system with sensible defaults and customization options
- **CORS Support**: Configurable CORS settings for cross-origin requests
- **Server Information**: Proper server identification and versioning
### Client Session Management
- **Session Tracking**: Client session tracking with unique identifiers
- **Connection Health**: Ping/pong mechanism to maintain connection health
- **Initialization State**: Client initialization state tracking
### Content Handling
- **Multi-format Content**: Support for text, code, and binary content
- **MIME Type Support**: Proper MIME type identification for various content types
- **Audience Annotations**: Content audience annotations for user/assistant targeting
## Usage
### Setting Up an MCP Server
To create and start an MCP server:
```go
package main
import (
"github.com/zeromicro/go-zero/core/conf"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/mcp"
)
func main() {
// Load configuration from YAML file
var c mcp.McpConf
conf.MustLoad("config.yaml", &c)
// Optional: Disable stats logging
logx.DisableStat()
// Create MCP server
server := mcp.NewMcpServer(c)
// Register tools, prompts, and resources (examples below)
// Start the server and ensure it's stopped on exit
defer server.Stop()
server.Start()
}
```
Sample configuration file (config.yaml):
```yaml
name: mcp-server
host: localhost
port: 8080
mcp:
name: my-mcp-server
messageTimeout: 30s # Timeout for tool calls
cors:
- http://localhost:3000 # Optional CORS configuration
```
### Registering Tools
Tools allow AI models to execute custom code through the MCP protocol.
#### Basic Tool Example:
```go
// Register a simple echo tool
echoTool := mcp.Tool{
Name: "echo",
Description: "Echoes back the message provided by the user",
InputSchema: mcp.InputSchema{
Properties: map[string]any{
"message": map[string]any{
"type": "string",
"description": "The message to echo back",
},
"prefix": map[string]any{
"type": "string",
"description": "Optional prefix to add to the echoed message",
"default": "Echo: ",
},
},
Required: []string{"message"},
},
Handler: func(ctx context.Context, params map[string]any) (any, error) {
var req struct {
Message string `json:"message"`
Prefix string `json:"prefix,optional"`
}
if err := mcp.ParseArguments(params, &req); err != nil {
return nil, fmt.Errorf("failed to parse params: %w", err)
}
prefix := "Echo: "
if len(req.Prefix) > 0 {
prefix = req.Prefix
}
return prefix + req.Message, nil
},
}
server.RegisterTool(echoTool)
```
#### Tool with Different Response Types:
```go
// Tool returning JSON data
dataTool := mcp.Tool{
Name: "data.generate",
Description: "Generates sample data in various formats",
InputSchema: mcp.InputSchema{
Properties: map[string]any{
"format": map[string]any{
"type": "string",
"description": "Format of data (json, text)",
"enum": []string{"json", "text"},
},
},
},
Handler: func(ctx context.Context, params map[string]any) (any, error) {
var req struct {
Format string `json:"format"`
}
if err := mcp.ParseArguments(params, &req); err != nil {
return nil, fmt.Errorf("failed to parse params: %w", err)
}
if req.Format == "json" {
// Return structured data
return map[string]any{
"items": []map[string]any{
{"id": 1, "name": "Item 1"},
{"id": 2, "name": "Item 2"},
},
"count": 2,
}, nil
}
// Default to text
return "Sample text data", nil
},
}
server.RegisterTool(dataTool)
```
#### Image Generation Tool Example:
```go
// Tool returning image content
imageTool := mcp.Tool{
Name: "image.generate",
Description: "Generates a simple image",
InputSchema: mcp.InputSchema{
Properties: map[string]any{
"type": map[string]any{
"type": "string",
"description": "Type of image to generate",
"default": "placeholder",
},
},
},
Handler: func(ctx context.Context, params map[string]any) (any, error) {
// Return image content directly
return mcp.ImageContent{
Data: "base64EncodedImageData...", // Base64 encoded image data
MimeType: "image/png",
}, nil
},
}
server.RegisterTool(imageTool)
```
#### Using ToolResult for Custom Outputs:
```go
// Tool that returns a custom ToolResult type
customResultTool := mcp.Tool{
Name: "custom.result",
Description: "Returns a custom formatted result",
InputSchema: mcp.InputSchema{
Properties: map[string]any{
"resultType": map[string]any{
"type": "string",
"enum": []string{"text", "image"},
},
},
},
Handler: func(ctx context.Context, params map[string]any) (any, error) {
var req struct {
ResultType string `json:"resultType"`
}
if err := mcp.ParseArguments(params, &req); err != nil {
return nil, fmt.Errorf("failed to parse params: %w", err)
}
if req.ResultType == "image" {
return mcp.ToolResult{
Type: mcp.ContentTypeImage,
Content: map[string]any{
"data": "base64EncodedImageData...",
"mimeType": "image/jpeg",
},
}, nil
}
// Default to text
return mcp.ToolResult{
Type: mcp.ContentTypeText,
Content: "This is a text result from ToolResult",
}, nil
},
}
server.RegisterTool(customResultTool)
```
### Registering Prompts
Prompts are reusable conversation templates for AI models.
#### Static Prompt Example:
```go
// Register a simple static prompt with placeholders
server.RegisterPrompt(mcp.Prompt{
Name: "hello",
Description: "A simple hello prompt",
Arguments: []mcp.PromptArgument{
{
Name: "name",
Description: "The name to greet",
Required: false,
},
},
Content: "Say hello to {{name}} and introduce yourself as an AI assistant.",
})
```
#### Dynamic Prompt with Handler Function:
```go
// Register a prompt with a dynamic handler function
server.RegisterPrompt(mcp.Prompt{
Name: "dynamic-prompt",
Description: "A prompt that uses a handler to generate dynamic content",
Arguments: []mcp.PromptArgument{
{
Name: "username",
Description: "User's name for personalized greeting",
Required: true,
},
{
Name: "topic",
Description: "Topic of expertise",
Required: true,
},
},
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
var req struct {
Username string `json:"username"`
Topic string `json:"topic"`
}
if err := mcp.ParseArguments(args, &req); err != nil {
return nil, fmt.Errorf("failed to parse args: %w", err)
}
// Create a user message
userMessage := mcp.PromptMessage{
Role: mcp.RoleUser,
Content: mcp.TextContent{
Text: fmt.Sprintf("Hello, I'm %s and I'd like to learn about %s.", req.Username, req.Topic),
},
}
// Create an assistant response with current time
currentTime := time.Now().Format(time.RFC1123)
assistantMessage := mcp.PromptMessage{
Role: mcp.RoleAssistant,
Content: mcp.TextContent{
Text: fmt.Sprintf("Hello %s! I'm an AI assistant and I'll help you learn about %s. The current time is %s.",
req.Username, req.Topic, currentTime),
},
}
// Return both messages as a conversation
return []mcp.PromptMessage{userMessage, assistantMessage}, nil
},
})
```
#### Multi-Message Prompt with Code Examples:
```go
// Register a prompt that provides code examples in different programming languages
server.RegisterPrompt(mcp.Prompt{
Name: "code-example",
Description: "Provides code examples in different programming languages",
Arguments: []mcp.PromptArgument{
{
Name: "language",
Description: "Programming language for the example",
Required: true,
},
{
Name: "complexity",
Description: "Complexity level (simple, medium, advanced)",
},
},
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
var req struct {
Language string `json:"language"`
Complexity string `json:"complexity,optional"`
}
if err := mcp.ParseArguments(args, &req); err != nil {
return nil, fmt.Errorf("failed to parse args: %w", err)
}
// Validate language
supportedLanguages := map[string]bool{"go": true, "python": true, "javascript": true, "rust": true}
if !supportedLanguages[req.Language] {
return nil, fmt.Errorf("unsupported language: %s", req.Language)
}
// Generate code example based on language and complexity
var codeExample string
switch req.Language {
case "go":
if req.Complexity == "simple" {
codeExample = `
package main
import "fmt"
func main() {
fmt.Println("Hello, World!")
}`
} else {
codeExample = `
package main
import (
"fmt"
"time"
)
func main() {
now := time.Now()
fmt.Printf("Hello, World! Current time is %s\n", now.Format(time.RFC3339))
}`
}
case "python":
// Python example code
if req.Complexity == "simple" {
codeExample = `
def greet(name):
return f"Hello, {name}!"
print(greet("World"))`
} else {
codeExample = `
import datetime
def greet(name, include_time=False):
message = f"Hello, {name}!"
if include_time:
message += f" Current time is {datetime.datetime.now().isoformat()}"
return message
print(greet("World", include_time=True))`
}
}
// Create messages array according to MCP spec
messages := []mcp.PromptMessage{
{
Role: mcp.RoleAssistant,
Content: mcp.TextContent{
Text: fmt.Sprintf("You are a helpful coding assistant specialized in %s programming.", req.Language),
},
},
{
Role: mcp.RoleUser,
Content: mcp.TextContent{
Text: fmt.Sprintf("Show me a %s example of a Hello World program in %s.", req.Complexity, req.Language),
},
},
{
Role: mcp.RoleAssistant,
Content: mcp.TextContent{
Text: fmt.Sprintf("Here's a %s example in %s:\n\n```%s%s\n```\n\nHow can I help you implement this?",
req.Complexity, req.Language, req.Language, codeExample),
},
},
}
return messages, nil
},
})
```
### Registering Resources
Resources provide access to external content such as files or generated data.
#### Basic Resource Example:
```go
// Register a static resource
server.RegisterResource(mcp.Resource{
Name: "example-document",
URI: "file:///example/document.txt",
Description: "An example document",
MimeType: "text/plain",
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
return mcp.ResourceContent{
URI: "file:///example/document.txt",
MimeType: "text/plain",
Text: "This is an example document content.",
}, nil
},
})
```
#### Dynamic Resource with Code Example:
```go
// Register a Go code resource with dynamic handler
server.RegisterResource(mcp.Resource{
Name: "go-example",
URI: "file:///project/src/main.go",
Description: "A simple Go example with multiple files",
MimeType: "text/x-go",
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
// Return ResourceContent with all required fields
return mcp.ResourceContent{
URI: "file:///project/src/main.go",
MimeType: "text/x-go",
Text: "package main\n\nimport (\n\t\"fmt\"\n\t\"./greeting\"\n)\n\nfunc main() {\n\tfmt.Println(greeting.Hello(\"world\"))\n}",
}, nil
},
})
// Register a companion file for the above example
server.RegisterResource(mcp.Resource{
Name: "go-greeting",
URI: "file:///project/src/greeting/greeting.go",
Description: "A greeting package for the Go example",
MimeType: "text/x-go",
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
return mcp.ResourceContent{
URI: "file:///project/src/greeting/greeting.go",
MimeType: "text/x-go",
Text: "package greeting\n\nfunc Hello(name string) string {\n\treturn \"Hello, \" + name + \"!\"\n}",
}, nil
},
})
```
#### Binary Resource Example:
```go
// Register a binary resource (like an image)
server.RegisterResource(mcp.Resource{
Name: "example-image",
URI: "file:///example/image.png",
Description: "An example image",
MimeType: "image/png",
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
// Read image from file or generate it
imageData := "base64EncodedImageData..." // Base64 encoded image data
return mcp.ResourceContent{
URI: "file:///example/image.png",
MimeType: "image/png",
Blob: imageData, // For binary data
}, nil
},
})
```
### Using Resources in Prompts
You can embed resources in prompt responses to create rich interactions with proper MCP-compliant structure:
```go
// Register a prompt that embeds a resource
server.RegisterPrompt(mcp.Prompt{
Name: "resource-example",
Description: "A prompt that embeds a resource",
Arguments: []mcp.PromptArgument{
{
Name: "file_type",
Description: "Type of file to show (rust or go)",
Required: true,
},
},
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
var req struct {
FileType string `json:"file_type"`
}
if err := mcp.ParseArguments(args, &req); err != nil {
return nil, fmt.Errorf("failed to parse args: %w", err)
}
var resourceURI, mimeType, fileContent string
if req.FileType == "rust" {
resourceURI = "file:///project/src/main.rs"
mimeType = "text/x-rust"
fileContent = "fn main() {\n println!(\"Hello world!\");\n}"
} else {
resourceURI = "file:///project/src/main.go"
mimeType = "text/x-go"
fileContent = "package main\n\nimport \"fmt\"\n\nfunc main() {\n\tfmt.Println(\"Hello, world!\")\n}"
}
// Create message with embedded resource using proper MCP format
return []mcp.PromptMessage{
{
Role: mcp.RoleUser,
Content: mcp.TextContent{
Text: fmt.Sprintf("Can you explain this %s code?", req.FileType),
},
},
{
Role: mcp.RoleAssistant,
Content: mcp.EmbeddedResource{
Type: mcp.ContentTypeResource,
Resource: struct {
URI string `json:"uri"`
MimeType string `json:"mimeType"`
Text string `json:"text,omitempty"`
Blob string `json:"blob,omitempty"`
}{
URI: resourceURI,
MimeType: mimeType,
Text: fileContent,
},
},
},
{
Role: mcp.RoleAssistant,
Content: mcp.TextContent{
Text: fmt.Sprintf("Above is a simple Hello World example in %s. Let me explain how it works.", req.FileType),
},
},
}, nil
},
})
```
### Multiple File Resources Example
```go
// Register a prompt that demonstrates embedding multiple resource files
server.RegisterPrompt(mcp.Prompt{
Name: "go-code-example",
Description: "A prompt that correctly embeds multiple resource files",
Arguments: []mcp.PromptArgument{
{
Name: "format",
Description: "How to format the code display",
},
},
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
var req struct {
Format string `json:"format,optional"`
}
if err := mcp.ParseArguments(args, &req); err != nil {
return nil, fmt.Errorf("failed to parse args: %w", err)
}
// Get the Go code for multiple files
var mainGoText string = "package main\n\nimport (\n\t\"fmt\"\n\t\"./greeting\"\n)\n\nfunc main() {\n\tfmt.Println(greeting.Hello(\"world\"))\n}"
var greetingGoText string = "package greeting\n\nfunc Hello(name string) string {\n\treturn \"Hello, \" + name + \"!\"\n}"
// Create message with properly formatted embedded resource per MCP spec
messages := []mcp.PromptMessage{
{
Role: mcp.RoleUser,
Content: mcp.TextContent{
Text: "Show me a simple Go example with proper imports.",
},
},
{
Role: mcp.RoleAssistant,
Content: mcp.TextContent{
Text: "Here's a simple Go example project:",
},
},
{
Role: mcp.RoleAssistant,
Content: mcp.EmbeddedResource{
Type: mcp.ContentTypeResource,
Resource: struct {
URI string `json:"uri"`
MimeType string `json:"mimeType"`
Text string `json:"text,omitempty"`
Blob string `json:"blob,omitempty"`
}{
URI: "file:///project/src/main.go",
MimeType: "text/x-go",
Text: mainGoText,
},
},
},
}
// Add explanation and additional file if requested
if req.Format == "with_explanation" {
messages = append(messages, mcp.PromptMessage{
Role: mcp.RoleAssistant,
Content: mcp.TextContent{
Text: "This example demonstrates a simple Go application with modular structure. The main.go file imports from a local 'greeting' package that provides the Hello function.",
},
})
// Also show the greeting.go file with correct resource format
messages = append(messages, mcp.PromptMessage{
Role: mcp.RoleAssistant,
Content: mcp.EmbeddedResource{
Type: mcp.ContentTypeResource,
Resource: struct {
URI string `json:"uri"`
MimeType string `json:"mimeType"`
Text string `json:"text,omitempty"`
Blob string `json:"blob,omitempty"`
}{
URI: "file:///project/src/greeting/greeting.go",
MimeType: "text/x-go",
Text: greetingGoText,
},
},
})
}
return messages, nil
},
})
```
### Complete Application Example
Here's a complete example demonstrating all the components:
```go
package main
import (
"context"
"fmt"
"log"
"time"
"github.com/zeromicro/go-zero/core/conf"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/mcp"
)
func main() {
// Load configuration
var c mcp.McpConf
if err := conf.Load("config.yaml", &c); err != nil {
log.Fatalf("Failed to load config: %v", err)
}
// Set up logging
logx.DisableStat()
// Create MCP server
server := mcp.NewMcpServer(c)
defer server.Stop()
// Register a simple echo tool
echoTool := mcp.Tool{
Name: "echo",
Description: "Echoes back the message provided by the user",
InputSchema: mcp.InputSchema{
Properties: map[string]any{
"message": map[string]any{
"type": "string",
"description": "The message to echo back",
},
"prefix": map[string]any{
"type": "string",
"description": "Optional prefix to add to the echoed message",
"default": "Echo: ",
},
},
Required: []string{"message"},
},
Handler: func(ctx context.Context, params map[string]any) (any, error) {
var req struct {
Message string `json:"message"`
Prefix string `json:"prefix,optional"`
}
if err := mcp.ParseArguments(params, &req); err != nil {
return nil, fmt.Errorf("failed to parse args: %w", err)
}
prefix := "Echo: "
if len(req.Prefix) > 0 {
prefix = req.Prefix
}
return prefix + req.Message, nil
},
}
server.RegisterTool(echoTool)
// Register a static prompt
server.RegisterPrompt(mcp.Prompt{
Name: "greeting",
Description: "A simple greeting prompt",
Arguments: []mcp.PromptArgument{
{
Name: "name",
Description: "The name to greet",
Required: true,
},
},
Content: "Hello {{name}}! How can I assist you today?",
})
// Register a dynamic prompt
server.RegisterPrompt(mcp.Prompt{
Name: "dynamic-prompt",
Description: "A prompt that uses a handler to generate dynamic content",
Arguments: []mcp.PromptArgument{
{
Name: "username",
Description: "User's name for personalized greeting",
Required: true,
},
{
Name: "topic",
Description: "Topic of expertise",
Required: true,
},
},
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
var req struct {
Username string `json:"username"`
Topic string `json:"topic"`
}
if err := mcp.ParseArguments(args, &req); err != nil {
return nil, fmt.Errorf("failed to parse args: %w", err)
}
// Create messages with current time
currentTime := time.Now().Format(time.RFC1123)
return []mcp.PromptMessage{
{
Role: mcp.RoleUser,
Content: mcp.TextContent{
Text: fmt.Sprintf("Hello, I'm %s and I'd like to learn about %s.", req.Username, req.Topic),
},
},
{
Role: mcp.RoleAssistant,
Content: mcp.TextContent{
Text: fmt.Sprintf("Hello %s! I'm an AI assistant and I'll help you learn about %s. The current time is %s.",
req.Username, req.Topic, currentTime),
},
},
}, nil
},
})
// Register a resource
server.RegisterResource(mcp.Resource{
Name: "example-doc",
URI: "file:///example/doc.txt",
Description: "An example document",
MimeType: "text/plain",
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
return mcp.ResourceContent{
URI: "file:///example/doc.txt",
MimeType: "text/plain",
Text: "This is the content of the example document.",
}, nil
},
})
// Start the server
fmt.Printf("Starting MCP server on %s:%d\n", c.Host, c.Port)
server.Start()
}
```
## Error Handling
The MCP implementation provides comprehensive error handling:
- Tool execution errors are properly reported back to clients
- Missing or invalid parameters are detected and reported with appropriate error codes
- Resource and prompt lookup failures are handled gracefully
- Timeout handling for long-running tool executions using context
- Panic recovery to prevent server crashes
## Advanced Features
- **Annotations**: Add audience and priority metadata to content
- **Content Types**: Support for text, images, audio, and other content formats
- **Embedded Resources**: Include file resources directly in prompt responses
- **Context Awareness**: All handlers receive context.Context for timeout and cancellation support
- **Progress Tokens**: Support for tracking progress of long-running operations
- **Customizable Timeouts**: Configure execution timeouts for tools and operations
## Performance Considerations
- Tool execution runs with configurable timeouts to prevent blocking
- Efficient client tracking and cleanup to prevent resource leaks
- Proper concurrency handling with mutex protection for shared resources
- Buffered message channels to prevent blocking on client message delivery

940
mcp/server.go Normal file
View File

@@ -0,0 +1,940 @@
package mcp
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"github.com/google/uuid"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/rest"
)
func NewMcpServer(c McpConf) McpServer {
var server *rest.Server
if len(c.Mcp.Cors) == 0 {
server = rest.MustNewServer(c.RestConf)
} else {
server = rest.MustNewServer(c.RestConf, rest.WithCors(c.Mcp.Cors...))
}
if len(c.Mcp.Name) == 0 {
c.Mcp.Name = c.Name
}
if len(c.Mcp.BaseUrl) == 0 {
c.Mcp.BaseUrl = fmt.Sprintf("http://localhost:%d", c.Port)
}
s := &sseMcpServer{
conf: c,
server: server,
clients: make(map[string]*mcpClient),
tools: make(map[string]Tool),
prompts: make(map[string]Prompt),
resources: make(map[string]Resource),
}
// SSE endpoint for real-time updates
s.server.AddRoute(rest.Route{
Method: http.MethodGet,
Path: s.conf.Mcp.SseEndpoint,
Handler: s.handleSSE,
}, rest.WithSSE(), rest.WithTimeout(c.Mcp.SseTimeout))
// JSON-RPC message endpoint for regular requests
s.server.AddRoute(rest.Route{
Method: http.MethodPost,
Path: s.conf.Mcp.MessageEndpoint,
Handler: s.handleRequest,
}, rest.WithTimeout(c.Mcp.MessageTimeout))
return s
}
// RegisterPrompt registers a new prompt with the server
func (s *sseMcpServer) RegisterPrompt(prompt Prompt) {
s.promptsLock.Lock()
s.prompts[prompt.Name] = prompt
s.promptsLock.Unlock()
// Notify clients about the new prompt
s.broadcast(eventPromptsListChanged, map[string][]Prompt{keyPrompts: {prompt}})
}
// RegisterResource registers a new resource with the server
func (s *sseMcpServer) RegisterResource(resource Resource) {
s.resourcesLock.Lock()
s.resources[resource.URI] = resource
s.resourcesLock.Unlock()
// Notify clients about the new resource
s.broadcast(eventResourcesListChanged, map[string][]Resource{keyResources: {resource}})
}
// RegisterTool registers a new tool with the server
func (s *sseMcpServer) RegisterTool(tool Tool) error {
if tool.Handler == nil {
return fmt.Errorf("tool '%s' has no handler function", tool.Name)
}
s.toolsLock.Lock()
s.tools[tool.Name] = tool
s.toolsLock.Unlock()
// Notify clients about the new tool
s.broadcast(eventToolsListChanged, map[string][]Tool{keyTools: {tool}})
return nil
}
// Start implements McpServer.
func (s *sseMcpServer) Start() {
s.server.Start()
}
func (s *sseMcpServer) Stop() {
s.server.Stop()
}
// broadcast sends a message to all connected clients
// It uses Server-Sent Events (SSE) format for real-time communication
func (s *sseMcpServer) broadcast(event string, data any) {
jsonData, err := json.Marshal(data)
if err != nil {
logx.Errorf("Failed to marshal broadcast data: %v", err)
return
}
// Lock only while reading the clients map
s.clientsLock.Lock()
clients := make([]*mcpClient, 0, len(s.clients))
for _, client := range s.clients {
clients = append(clients, client)
}
s.clientsLock.Unlock()
clientCount := len(clients)
if clientCount == 0 {
return
}
logx.Infof("Broadcasting event '%s' to %d clients", event, clientCount)
// Use CRLF line endings as per SSE specification
message := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(jsonData))
// Send messages without holding the lock
for _, client := range clients {
select {
case client.channel <- message:
// Message sent successfully
default:
// Channel buffer is full, log warning and continue
logx.Errorf("Client channel buffer full, dropping message for client %s", client.id)
}
}
}
// cleanupClient removes a client from the active clients map
func (s *sseMcpServer) cleanupClient(sessionId string) {
s.clientsLock.Lock()
defer s.clientsLock.Unlock()
if client, exists := s.clients[sessionId]; exists {
// Close the channel to signal any goroutines waiting on it
close(client.channel)
// Remove from active clients
delete(s.clients, sessionId)
logx.Infof("Cleaned up client %s (remaining clients: %d)", sessionId, len(s.clients))
}
}
// handleRequest handles MCP JSON-RPC requests
func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
// Extract sessionId from query parameters
sessionId := r.URL.Query().Get(sessionIdKey)
if len(sessionId) == 0 {
http.Error(w, fmt.Sprintf("Missing %s", sessionIdKey), http.StatusBadRequest)
return
}
// Check if the client with this sessionId exists
s.clientsLock.Lock()
client, exists := s.clients[sessionId]
s.clientsLock.Unlock()
if !exists {
http.Error(w, fmt.Sprintf("Invalid or expired %s", sessionIdKey), http.StatusBadRequest)
return
}
var req Request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request", http.StatusBadRequest)
return
}
// For notification methods (no ID), we don't send a response
isNotification, err := req.isNotification()
if err != nil {
http.Error(w, "Invalid request.ID", http.StatusBadRequest)
}
w.WriteHeader(http.StatusAccepted)
// Special handling for initialization sequence
// Always allow initialize and notifications/initialized regardless of client state
if req.Method == methodInitialize {
logx.Infof("Processing initialize request with ID: %v", req.ID)
s.processInitialize(r.Context(), client, req)
logx.Infof("Sent initialize response for ID: %v, waiting for notifications/initialized", req.ID)
return
} else if req.Method == methodNotificationsInitialized {
// Handle initialized notification
logx.Info("Received notifications/initialized notification")
if !isNotification {
s.sendErrorResponse(r.Context(), client, req.ID,
"Method should be used as a notification", errCodeInvalidRequest)
return
}
s.processNotificationInitialized(client)
return
} else if !client.initialized && req.Method != methodNotificationsCancelled {
// Block most requests until client is initialized (except for cancellations)
s.sendErrorResponse(r.Context(), client, req.ID,
"Client not fully initialized, waiting for notifications/initialized",
errCodeClientNotInitialized)
return
}
// Process normal requests only after initialization
switch req.Method {
case methodToolsCall:
logx.Infof("Received tools call request with ID: %v", req.ID)
s.processToolCall(r.Context(), client, req)
logx.Infof("Sent tools call response for ID: %v", req.ID)
case methodToolsList:
logx.Infof("Processing tools/list request with ID: %v", req.ID)
s.processListTools(r.Context(), client, req)
logx.Infof("Sent tools/list response for ID: %v", req.ID)
case methodPromptsList:
logx.Infof("Processing prompts/list request with ID: %v", req.ID)
s.processListPrompts(r.Context(), client, req)
logx.Infof("Sent prompts/list response for ID: %v", req.ID)
case methodPromptsGet:
logx.Infof("Processing prompts/get request with ID: %v", req.ID)
s.processGetPrompt(r.Context(), client, req)
logx.Infof("Sent prompts/get response for ID: %v", req.ID)
case methodResourcesList:
logx.Infof("Processing resources/list request with ID: %v", req.ID)
s.processListResources(r.Context(), client, req)
logx.Infof("Sent resources/list response for ID: %v", req.ID)
case methodResourcesRead:
logx.Infof("Processing resources/read request with ID: %v", req.ID)
s.processResourcesRead(r.Context(), client, req)
logx.Infof("Sent resources/read response for ID: %v", req.ID)
case methodResourcesSubscribe:
logx.Infof("Processing resources/subscribe request with ID: %v", req.ID)
s.processResourceSubscribe(r.Context(), client, req)
logx.Infof("Sent resources/subscribe response for ID: %v", req.ID)
case methodPing:
logx.Infof("Processing ping request with ID: %v", req.ID)
s.processPing(r.Context(), client, req)
case methodNotificationsCancelled:
logx.Infof("Received notifications/cancelled notification: %v", req.ID)
s.processNotificationCancelled(r.Context(), client, req)
default:
logx.Infof("Unknown method: %s from client: %v", req.Method, req.ID)
s.sendErrorResponse(r.Context(), client, req.ID, "Method not found", errCodeMethodNotFound)
}
}
// handleSSE handles Server-Sent Events connections
func (s *sseMcpServer) handleSSE(w http.ResponseWriter, r *http.Request) {
// Generate a unique session ID for this client
sessionId := uuid.New().String()
// Create new client with buffered channel to prevent blocking
client := &mcpClient{
id: sessionId,
channel: make(chan string, eventChanSize),
}
// Add client to active clients map
s.clientsLock.Lock()
s.clients[sessionId] = client
activeClients := len(s.clients)
s.clientsLock.Unlock()
logx.Infof("New SSE connection established for client %s (active clients: %d)",
sessionId, activeClients)
// Set proper SSE headers
w.Header().Set("Transfer-Encoding", "chunked")
// Enable streaming
flusher, ok := w.(http.Flusher)
if !ok {
logx.Error("Streaming not supported by the underlying http.ResponseWriter")
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
return
}
// Send the message endpoint URL to the client
endpoint := fmt.Sprintf("%s%s?%s=%s",
s.conf.Mcp.BaseUrl, s.conf.Mcp.MessageEndpoint, sessionIdKey, sessionId)
// Format and send the endpoint message
endpointMsg := formatSSEMessage(eventEndpoint, []byte(endpoint))
if _, err := fmt.Fprint(w, endpointMsg); err != nil {
logx.Errorf("Failed to send endpoint message to client %s: %v", sessionId, err)
s.cleanupClient(sessionId)
return
}
flusher.Flush()
// Set up keep-alive ping and client cleanup
ticker := time.NewTicker(pingInterval.Load())
defer func() {
ticker.Stop()
s.cleanupClient(sessionId)
logx.Infof("SSE connection closed for client %s", sessionId)
}()
// Message processing loop
for {
select {
case message, ok := <-client.channel:
if !ok {
// Channel was closed, end connection
logx.Infof("Client channel was closed for %s", sessionId)
return
}
// Write message to the response
if _, err := fmt.Fprint(w, message); err != nil {
logx.Infof("Failed to write message to client %s: %v", sessionId, err)
return
}
flusher.Flush()
case <-ticker.C:
// Send keep-alive ping to maintain connection
ping := fmt.Sprintf(`{"type":"ping","timestamp":"%s"}`, time.Now().String())
pingMsg := formatSSEMessage("ping", []byte(ping))
if _, err := fmt.Fprint(w, pingMsg); err != nil {
logx.Errorf("Failed to send ping to client %s, closing connection: %v", sessionId, err)
return
}
flusher.Flush()
case <-r.Context().Done():
// Client disconnected or request was canceled or timed out
logx.Infof("Client %s disconnected: context done", sessionId)
return
}
}
}
// processInitialize processes the initialize request
func (s *sseMcpServer) processInitialize(ctx context.Context, client *mcpClient, req Request) {
// Create a proper JSON-RPC response that preserves the client's request ID
result := initializationResponse{
ProtocolVersion: s.conf.Mcp.ProtocolVersion,
Capabilities: capabilities{
Prompts: struct {
ListChanged bool `json:"listChanged"`
}{
ListChanged: true,
},
Resources: struct {
Subscribe bool `json:"subscribe"`
ListChanged bool `json:"listChanged"`
}{
Subscribe: true,
ListChanged: true,
},
Tools: struct {
ListChanged bool `json:"listChanged"`
}{
ListChanged: true,
},
},
ServerInfo: serverInfo{
Name: s.conf.Mcp.Name,
Version: s.conf.Mcp.Version,
},
}
// Mark client as initialized
client.initialized = true
// Send response with client's original request ID
s.sendResponse(ctx, client, req.ID, result)
}
// processListTools processes the tools/list request
func (s *sseMcpServer) processListTools(ctx context.Context, client *mcpClient, req Request) {
// Extract pagination params if any
var nextCursor string
var progressToken any
// Extract meta data including progress token
if req.Params != nil {
var metaParams struct {
Cursor string `json:"cursor"`
Meta struct {
ProgressToken any `json:"progressToken"`
} `json:"_meta"`
}
if err := json.Unmarshal(req.Params, &metaParams); err == nil {
if len(metaParams.Cursor) > 0 {
nextCursor = metaParams.Cursor
}
progressToken = metaParams.Meta.ProgressToken
}
}
var toolsList []Tool
s.toolsLock.Lock()
for _, tool := range s.tools {
if len(tool.InputSchema.Type) == 0 {
tool.InputSchema.Type = ContentTypeObject
}
toolsList = append(toolsList, tool)
}
s.toolsLock.Unlock()
result := ListToolsResult{
PaginatedResult: PaginatedResult{
Result: Result{},
NextCursor: Cursor(nextCursor),
},
Tools: toolsList,
}
// Add meta information if progress token was provided
if progressToken != nil {
result.Result.Meta = map[string]any{
progressTokenKey: progressToken,
}
}
s.sendResponse(ctx, client, req.ID, result)
}
// processListPrompts processes the prompts/list request
func (s *sseMcpServer) processListPrompts(ctx context.Context, client *mcpClient, req Request) {
// Extract pagination params if any
var nextCursor string
if req.Params != nil {
var cursorParams struct {
Cursor string `json:"cursor"`
}
if err := json.Unmarshal(req.Params, &cursorParams); err == nil && cursorParams.Cursor != "" {
// If we have a valid cursor, we could use it for pagination
// For now, we're not actually implementing pagination, so this is just
// to show how it would be extracted from the request
_ = cursorParams.Cursor
}
}
// Prepare prompt list
var promptsList []Prompt
s.promptsLock.Lock()
for _, prompt := range s.prompts {
promptsList = append(promptsList, prompt)
}
s.promptsLock.Unlock()
// In a real implementation, you'd handle pagination here
// For now, we'll return all prompts at once
result := struct {
Prompts []Prompt `json:"prompts"`
NextCursor string `json:"nextCursor,omitempty"`
Meta *struct{} `json:"_meta,omitempty"`
}{
Prompts: promptsList,
NextCursor: nextCursor,
}
s.sendResponse(ctx, client, req.ID, result)
}
// processListResources processes the resources/list request
func (s *sseMcpServer) processListResources(ctx context.Context, client *mcpClient, req Request) {
// Extract pagination params if any
var nextCursor string
var progressToken any
// Extract meta information including progress token if available
if req.Params != nil {
var metaParams PaginatedParams
if err := json.Unmarshal(req.Params, &metaParams); err == nil {
if len(metaParams.Cursor) > 0 {
nextCursor = metaParams.Cursor
}
progressToken = metaParams.Meta.ProgressToken
}
}
var resourcesList []Resource
s.resourcesLock.Lock()
for _, resource := range s.resources {
// Create a copy without the handler function which shouldn't be sent to clients
resourceCopy := Resource{
URI: resource.URI,
Name: resource.Name,
Description: resource.Description,
MimeType: resource.MimeType,
}
resourcesList = append(resourcesList, resourceCopy)
}
s.resourcesLock.Unlock()
// Create proper ResourcesListResult according to MCP specification
result := ResourcesListResult{
PaginatedResult: PaginatedResult{
Result: Result{},
NextCursor: Cursor(nextCursor),
},
Resources: resourcesList,
}
// Add meta information if progress token was provided
if progressToken != nil {
result.Result.Meta = map[string]any{
progressTokenKey: progressToken,
}
}
s.sendResponse(ctx, client, req.ID, result)
}
// processGetPrompt processes the prompts/get request
func (s *sseMcpServer) processGetPrompt(ctx context.Context, client *mcpClient, req Request) {
type GetPromptParams struct {
Name string `json:"name"`
Arguments map[string]string `json:"arguments,omitempty"`
}
var params GetPromptParams
if err := json.Unmarshal(req.Params, &params); err != nil {
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
return
}
// Check if prompt exists
s.promptsLock.Lock()
prompt, exists := s.prompts[params.Name]
s.promptsLock.Unlock()
if !exists {
message := fmt.Sprintf("Prompt '%s' not found", params.Name)
s.sendErrorResponse(ctx, client, req.ID, message, errCodeInvalidParams)
return
}
logx.Infof("Processing prompt request: %s with %d arguments", prompt.Name, len(params.Arguments))
// Validate required arguments
missingArgs := validatePromptArguments(prompt, params.Arguments)
if len(missingArgs) > 0 {
message := fmt.Sprintf("Missing required arguments: %s", strings.Join(missingArgs, ", "))
s.sendErrorResponse(ctx, client, req.ID, message, errCodeInvalidParams)
return
}
// Ensure arguments are initialized to an empty map if nil
if params.Arguments == nil {
params.Arguments = make(map[string]string)
}
args := params.Arguments
// Generate messages using handler or static content
var messages []PromptMessage
var err error
if prompt.Handler != nil {
// Use dynamic handler to generate messages
messages, err = prompt.Handler(ctx, args)
if err != nil {
logx.Errorf("Error from prompt handler: %v", err)
s.sendErrorResponse(ctx, client, req.ID,
fmt.Sprintf("Error generating prompt content: %v", err), errCodeInternalError)
return
}
} else {
// No handler, generate messages from static content
var messageText string
if len(prompt.Content) > 0 {
messageText = prompt.Content
// Apply argument substitutions to static content
for key, value := range args {
placeholder := fmt.Sprintf("{{%s}}", key)
messageText = strings.Replace(messageText, placeholder, value, -1)
}
}
// Create a single user message with the content
messages = []PromptMessage{
{
Role: RoleUser,
Content: TextContent{
Text: messageText,
},
},
}
}
// Construct the response according to MCP spec
result := struct {
Description string `json:"description,omitempty"`
Messages []PromptMessage `json:"messages"`
}{
Description: prompt.Description,
Messages: toTypedPromptMessages(messages),
}
s.sendResponse(ctx, client, req.ID, result)
}
// processToolCall processes the tools/call request
func (s *sseMcpServer) processToolCall(ctx context.Context, client *mcpClient, req Request) {
var toolCallParams struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments,omitempty"`
Meta struct {
ProgressToken any `json:"progressToken"`
} `json:"_meta,omitempty"`
}
// Handle different types of req.Params
// If it's a RawMessage (JSON), unmarshal it
if err := json.Unmarshal(req.Params, &toolCallParams); err != nil {
logx.Errorf("Failed to unmarshal tool call params: %v", err)
s.sendErrorResponse(ctx, client, req.ID, "Invalid tool call parameters", errCodeInvalidParams)
return
}
// Extract progress token if available
progressToken := toolCallParams.Meta.ProgressToken
// Find the requested tool
s.toolsLock.Lock()
tool, exists := s.tools[toolCallParams.Name]
s.toolsLock.Unlock()
if !exists {
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Tool '%s' not found",
toolCallParams.Name), errCodeInvalidParams)
return
}
// Log parameters before execution
logx.Infof("Executing tool '%s' with arguments: %#v", toolCallParams.Name, toolCallParams.Arguments)
// Execute the tool handler with timeout handling
var result any
var err error
// Create a channel to receive the result
// make sure to have 1 size buffer to avoid channel leak if timeout
resultCh := make(chan struct {
result any
err error
}, 1)
// Execute the tool handler in a goroutine
go func() {
toolResult, toolErr := tool.Handler(ctx, toolCallParams.Arguments)
resultCh <- struct {
result any
err error
}{
result: toolResult,
err: toolErr,
}
}()
// Wait for either the result or a timeout
select {
case res := <-resultCh:
result = res.result
err = res.err
case <-ctx.Done():
// Handle request timeout
logx.Errorf("Tool execution timed out after %v: %s", s.conf.Mcp.MessageTimeout, toolCallParams.Name)
s.sendErrorResponse(ctx, client, req.ID, "Tool execution timed out", errCodeTimeout)
return
}
// Create the base result structure with metadata
callToolResult := CallToolResult{
Result: Result{},
Content: []any{},
IsError: false,
}
// Add meta information if progress token was provided
if progressToken != nil {
callToolResult.Result.Meta = map[string]any{
progressTokenKey: progressToken,
}
}
// Check if there was an error during tool execution
if err != nil {
// According to the spec, for tool-level errors (as opposed to protocol-level errors),
// we should report them inside the result with isError=true
logx.Errorf("Tool execution reported error: %v", err)
callToolResult.Content = []any{
TextContent{
Text: fmt.Sprintf("Error: %v", err),
},
}
callToolResult.IsError = true
s.sendResponse(ctx, client, req.ID, callToolResult)
return
}
// Format the response according to the CallToolResult schema
switch v := result.(type) {
case string:
// Simple string becomes text content
callToolResult.Content = append(callToolResult.Content, TextContent{
Text: v,
Annotations: &Annotations{
Audience: []RoleType{RoleUser, RoleAssistant},
},
})
case map[string]any:
// JSON-like object becomes formatted JSON text
jsonStr, err := json.Marshal(v)
if err != nil {
jsonStr = []byte(err.Error())
}
callToolResult.Content = append(callToolResult.Content, TextContent{
Text: string(jsonStr),
Annotations: &Annotations{
Audience: []RoleType{RoleUser, RoleAssistant},
},
})
case TextContent:
callToolResult.Content = append(callToolResult.Content, v)
case ImageContent:
callToolResult.Content = append(callToolResult.Content, v)
case []any:
callToolResult.Content = v
case ToolResult:
// Handle legacy ToolResult type
switch v.Type {
case ContentTypeText:
callToolResult.Content = append(callToolResult.Content, TextContent{
Text: fmt.Sprintf("%v", v.Content),
Annotations: &Annotations{
Audience: []RoleType{RoleUser, RoleAssistant},
},
})
case ContentTypeImage:
if imgData, ok := v.Content.(map[string]any); ok {
callToolResult.Content = append(callToolResult.Content, ImageContent{
Data: fmt.Sprintf("%v", imgData["data"]),
MimeType: fmt.Sprintf("%v", imgData["mimeType"]),
})
}
default:
callToolResult.Content = append(callToolResult.Content, TextContent{
Text: fmt.Sprintf("%v", v.Content),
Annotations: &Annotations{
Audience: []RoleType{RoleUser, RoleAssistant},
},
})
}
default:
// For any other type, convert to string
callToolResult.Content = append(callToolResult.Content, TextContent{
Text: fmt.Sprintf("%v", v),
Annotations: &Annotations{
Audience: []RoleType{RoleUser, RoleAssistant},
},
})
}
callToolResult.Content = toTypedContents(callToolResult.Content)
logx.Infof("Tool call result: %#v", callToolResult)
s.sendResponse(ctx, client, req.ID, callToolResult)
}
// processResourcesRead processes the resources/read request
func (s *sseMcpServer) processResourcesRead(ctx context.Context, client *mcpClient, req Request) {
var params ResourceReadParams
if err := json.Unmarshal(req.Params, &params); err != nil {
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
return
}
// Find resource that matches the URI
s.resourcesLock.Lock()
resource, exists := s.resources[params.URI]
s.resourcesLock.Unlock()
if !exists {
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
params.URI), errCodeResourceNotFound)
return
}
// If no handler is provided, return an empty content array
if resource.Handler == nil {
result := ResourceReadResult{
Contents: []ResourceContent{
{
URI: params.URI,
MimeType: resource.MimeType,
Text: "",
},
},
}
s.sendResponse(ctx, client, req.ID, result)
return
}
// Execute the resource handler
content, err := resource.Handler(ctx)
if err != nil {
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Error reading resource: %v", err),
errCodeInternalError)
return
}
// Ensure the URI is set if not already provided by the handler
if len(content.URI) == 0 {
content.URI = params.URI
}
// Ensure MimeType is set if available from the resource definition
if len(content.MimeType) == 0 && len(resource.MimeType) > 0 {
content.MimeType = resource.MimeType
}
// Create response with contents from the handler
// The MCP specification requires a contents array
result := ResourceReadResult{
Contents: []ResourceContent{content},
}
s.sendResponse(ctx, client, req.ID, result)
}
// processResourceSubscribe processes the resources/subscribe request
func (s *sseMcpServer) processResourceSubscribe(ctx context.Context, client *mcpClient, req Request) {
var params ResourceSubscribeParams
if err := json.Unmarshal(req.Params, &params); err != nil {
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
return
}
// Check if the resource exists
s.resourcesLock.Lock()
_, exists := s.resources[params.URI]
s.resourcesLock.Unlock()
if !exists {
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
params.URI), errCodeResourceNotFound)
return
}
// Send success response for the subscription
s.sendResponse(ctx, client, req.ID, struct{}{})
}
// processNotificationCancelled processes the notifications/cancelled notification
func (s *sseMcpServer) processNotificationCancelled(ctx context.Context, client *mcpClient, req Request) {
// Extract the requestId that was canceled
type CancelParams struct {
RequestId int64 `json:"requestId"`
Reason string `json:"reason"`
}
var params CancelParams
if err := json.Unmarshal(req.Params, &params); err != nil {
logx.Errorf("Failed to parse cancellation params: %v", err)
return
}
logx.Infof("Request %d was cancelled by client. Reason: %s", params.RequestId, params.Reason)
}
// processNotificationInitialized processes the notifications/initialized notification
func (s *sseMcpServer) processNotificationInitialized(client *mcpClient) {
// Mark the client as properly initialized
client.initialized = true
logx.Infof("Client %s is now fully initialized and ready for normal operations", client.id)
}
// processPing processes the ping request and responds immediately
func (s *sseMcpServer) processPing(ctx context.Context, client *mcpClient, req Request) {
// A ping request should simply respond with an empty result to confirm the server is alive
logx.Infof("Received ping request with ID: %d", req.ID)
// Send an empty response with client's original request ID
s.sendResponse(ctx, client, req.ID, struct{}{})
}
// sendErrorResponse sends an error response via the SSE channel
func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
id any, message string, code int) {
errorResponse := struct {
JsonRpc string `json:"jsonrpc"`
ID any `json:"id"`
Error errorMessage `json:"error"`
}{
JsonRpc: jsonRpcVersion,
ID: id,
Error: errorMessage{
Code: code,
Message: message,
},
}
// all fields are primitive types, impossible to fail
jsonData, _ := json.Marshal(errorResponse)
// Use CRLF line endings as requested
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
logx.Infof("Sending error for ID %v: %s", id, sseMessage)
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
select {
case client.channel <- sseMessage:
default:
// Channel buffer is full, log warning and continue
logx.Infof("Client %s channel is full while sending error response with ID %d", client.id, id)
}
}
// sendResponse sends a success response via the SSE channel
func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id any, result any) {
response := Response{
JsonRpc: jsonRpcVersion,
ID: id,
Result: result,
}
jsonData, err := json.Marshal(response)
if err != nil {
s.sendErrorResponse(ctx, client, id, "Failed to marshal response", errCodeInternalError)
return
}
// Use CRLF line endings as requested
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
logx.Infof("Sending response for ID %v: %s", id, sseMessage)
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
select {
case client.channel <- sseMessage:
default:
// Channel buffer is full, log warning and continue
logx.Infof("Client %s channel is full while sending response with ID %v", client.id, id)
}
}

3451
mcp/server_test.go Normal file

File diff suppressed because it is too large Load Diff

317
mcp/types.go Normal file
View File

@@ -0,0 +1,317 @@
package mcp
import (
"context"
"encoding/json"
"fmt"
"sync"
"github.com/zeromicro/go-zero/rest"
)
// Cursor is an opaque token used for pagination
type Cursor string
// Request represents a generic MCP request following JSON-RPC 2.0 specification
type Request struct {
SessionId string `form:"session_id"` // Session identifier for client tracking
JsonRpc string `json:"jsonrpc"` // Must be "2.0" per JSON-RPC spec
ID any `json:"id"` // Request identifier for matching responses
Method string `json:"method"` // Method name to invoke
Params json.RawMessage `json:"params"` // Parameters for the method
}
func (r Request) isNotification() (bool, error) {
switch val := r.ID.(type) {
case int:
return val == 0, nil
case int64:
return val == 0, nil
case float64:
return val == 0.0, nil
case string:
return len(val) == 0, nil
case nil:
return true, nil
default:
return false, fmt.Errorf("invalid type %T", val)
}
}
type PaginatedParams struct {
Cursor string `json:"cursor"`
Meta struct {
ProgressToken any `json:"progressToken"`
} `json:"_meta"`
}
// Result is the base interface for all results
type Result struct {
Meta map[string]any `json:"_meta,omitempty"` // Optional metadata
}
// PaginatedResult is a base for results that support pagination
type PaginatedResult struct {
Result
NextCursor Cursor `json:"nextCursor,omitempty"` // Opaque token for fetching next page
}
// ListToolsResult represents the response to a tools/list request
type ListToolsResult struct {
PaginatedResult
Tools []Tool `json:"tools"` // List of available tools
}
// Message Content Types
// RoleType represents the sender or recipient of messages in a conversation
type RoleType string
// PromptArgument defines a single argument that can be passed to a prompt
type PromptArgument struct {
Name string `json:"name"` // Argument name
Description string `json:"description,omitempty"` // Human-readable description
Required bool `json:"required,omitempty"` // Whether this argument is required
}
// PromptHandler is a function that dynamically generates prompt content
type PromptHandler func(ctx context.Context, args map[string]string) ([]PromptMessage, error)
// Prompt represents an MCP Prompt definition
type Prompt struct {
Name string `json:"name"` // Unique identifier for the prompt
Description string `json:"description,omitempty"` // Human-readable description
Arguments []PromptArgument `json:"arguments,omitempty"` // Arguments for customization
Content string `json:"-"` // Static content (internal use only)
Handler PromptHandler `json:"-"` // Handler for dynamic content generation
}
// PromptMessage represents a message in a conversation
type PromptMessage struct {
Role RoleType `json:"role"` // Message sender role
Content any `json:"content"` // Message content (TextContent, ImageContent, etc.)
}
// TextContent represents text content in a message
type TextContent struct {
Text string `json:"text"` // The text content
Annotations *Annotations `json:"annotations,omitempty"` // Optional annotations
}
type typedTextContent struct {
Type string `json:"type"`
TextContent
}
// ImageContent represents image data in a message
type ImageContent struct {
Data string `json:"data"` // Base64-encoded image data
MimeType string `json:"mimeType"` // MIME type (e.g., "image/png")
}
type typedImageContent struct {
Type string `json:"type"`
ImageContent
}
// AudioContent represents audio data in a message
type AudioContent struct {
Data string `json:"data"` // Base64-encoded audio data
MimeType string `json:"mimeType"` // MIME type (e.g., "audio/mp3")
}
type typedAudioContent struct {
Type string `json:"type"`
AudioContent
}
// FileContent represents file content
type FileContent struct {
URI string `json:"uri"` // URI identifying the file
MimeType string `json:"mimeType"` // MIME type of the file
Text string `json:"text"` // File content as text
}
// EmbeddedResource represents a resource embedded in a message
type EmbeddedResource struct {
Type string `json:"type"` // Always "resource"
Resource ResourceContent `json:"resource"` // The resource data
}
// Annotations provides additional metadata for content
type Annotations struct {
Audience []RoleType `json:"audience,omitempty"` // Who should see this content
Priority *float64 `json:"priority,omitempty"` // Optional priority (0-1)
}
// Tool-related Types
// ToolHandler is a function that handles tool calls
type ToolHandler func(ctx context.Context, params map[string]any) (any, error)
// Tool represents a Model Context Protocol Tool definition
type Tool struct {
Name string `json:"name"` // Unique identifier for the tool
Description string `json:"description"` // Human-readable description
InputSchema InputSchema `json:"inputSchema"` // JSON Schema for parameters
Handler ToolHandler `json:"-"` // Not sent to clients
}
// InputSchema represents tool's input schema in JSON Schema format
type InputSchema struct {
Type string `json:"type"`
Properties map[string]any `json:"properties"` // Property definitions
Required []string `json:"required,omitempty"` // List of required properties
}
// CallToolResult represents a tool call result that conforms to the MCP schema
type CallToolResult struct {
Result
Content []any `json:"content"` // Content items (text, images, etc.)
IsError bool `json:"isError,omitempty"` // True if tool execution failed
}
// Resource represents a Model Context Protocol Resource definition
type Resource struct {
URI string `json:"uri"` // Unique resource identifier (RFC3986)
Name string `json:"name"` // Human-readable name
Description string `json:"description,omitempty"` // Optional description
MimeType string `json:"mimeType,omitempty"` // Optional MIME type
Handler ResourceHandler `json:"-"` // Internal handler not sent to clients
}
// ResourceHandler is a function that handles resource read requests
type ResourceHandler func(ctx context.Context) (ResourceContent, error)
// ResourceContent represents the content of a resource
type ResourceContent struct {
URI string `json:"uri"` // Resource URI (required)
MimeType string `json:"mimeType,omitempty"` // MIME type of the resource
Text string `json:"text,omitempty"` // Text content (if available)
Blob string `json:"blob,omitempty"` // Base64 encoded blob data (if available)
}
// ResourcesListResult represents the response to a resources/list request
type ResourcesListResult struct {
PaginatedResult
Resources []Resource `json:"resources"` // List of available resources
}
// ResourceReadParams contains parameters for a resources/read request
type ResourceReadParams struct {
URI string `json:"uri"` // URI of the resource to read
}
// ResourceReadResult contains the result of a resources/read request
type ResourceReadResult struct {
Result
Contents []ResourceContent `json:"contents"` // Array of resource content
}
// ResourceSubscribeParams contains parameters for a resources/subscribe request
type ResourceSubscribeParams struct {
URI string `json:"uri"` // URI of the resource to subscribe to
}
// ResourceUpdateNotification represents a notification about a resource update
type ResourceUpdateNotification struct {
URI string `json:"uri"` // URI of the updated resource
Content ResourceContent `json:"content"` // New resource content
}
// Client and Server Types
// mcpClient represents an SSE client connection
type mcpClient struct {
id string // Unique client identifier
channel chan string // Channel for sending SSE messages
initialized bool // Tracks if client has sent notifications/initialized
}
// McpServer defines the interface for Model Context Protocol servers
type McpServer interface {
Start()
Stop()
RegisterTool(tool Tool) error
RegisterPrompt(prompt Prompt)
RegisterResource(resource Resource)
}
// sseMcpServer implements the McpServer interface using SSE
type sseMcpServer struct {
conf McpConf
server *rest.Server
clients map[string]*mcpClient
clientsLock sync.Mutex
tools map[string]Tool
toolsLock sync.Mutex
prompts map[string]Prompt
promptsLock sync.Mutex
resources map[string]Resource
resourcesLock sync.Mutex
}
// Response Types
// errorObj represents a JSON-RPC error object
type errorObj struct {
Code int `json:"code"` // Error code
Message string `json:"message"` // Error message
}
// Response represents a JSON-RPC response
type Response struct {
JsonRpc string `json:"jsonrpc"` // Always "2.0"
ID any `json:"id"` // Same as request ID
Result any `json:"result"` // Result object (null if error)
Error *errorObj `json:"error,omitempty"` // Error object (null if success)
}
// Server Information Types
// serverInfo provides information about the server
type serverInfo struct {
Name string `json:"name"` // Server name
Version string `json:"version"` // Server version
}
// capabilities describes the server's capabilities
type capabilities struct {
Logging struct{} `json:"logging"`
Prompts struct {
ListChanged bool `json:"listChanged"` // Server will notify on prompt changes
} `json:"prompts"`
Resources struct {
Subscribe bool `json:"subscribe"` // Server supports resource subscriptions
ListChanged bool `json:"listChanged"` // Server will notify on resource changes
} `json:"resources"`
Tools struct {
ListChanged bool `json:"listChanged"` // Server will notify on tool changes
} `json:"tools"`
}
// initializationResponse is sent in response to an initialize request
type initializationResponse struct {
ProtocolVersion string `json:"protocolVersion"` // Protocol version
Capabilities capabilities `json:"capabilities"` // Server capabilities
ServerInfo serverInfo `json:"serverInfo"` // Server information
}
// ToolCallParams contains the parameters for a tool call
type ToolCallParams struct {
Name string `json:"name"` // Tool name
Parameters map[string]any `json:"parameters"` // Tool parameters
}
// ToolResult contains the result of a tool execution
type ToolResult struct {
Type string `json:"type"` // Content type (text, image, etc.)
Content any `json:"content"` // Result content
}
// errorMessage represents a detailed error message
type errorMessage struct {
Code int `json:"code"` // Error code
Message string `json:"message"` // Error message
Data any `json:",omitempty"` // Additional error data
}

271
mcp/types_test.go Normal file
View File

@@ -0,0 +1,271 @@
package mcp
import (
"context"
"encoding/json"
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
func TestResponseMarshaling(t *testing.T) {
// Test that the Response struct marshals correctly
resp := Response{
JsonRpc: "2.0",
ID: 123,
Result: map[string]string{
"key": "value",
},
}
data, err := json.Marshal(resp)
assert.NoError(t, err)
assert.Contains(t, string(data), `"jsonrpc":"2.0"`)
assert.Contains(t, string(data), `"id":123`)
assert.Contains(t, string(data), `"result":{"key":"value"}`)
// Test response with error
respWithError := Response{
JsonRpc: "2.0",
ID: 456,
Error: &errorObj{
Code: errCodeInvalidRequest,
Message: "Invalid Request",
},
}
data, err = json.Marshal(respWithError)
assert.NoError(t, err)
assert.Contains(t, string(data), `"jsonrpc":"2.0"`)
assert.Contains(t, string(data), `"id":456`)
assert.Contains(t, string(data), `"error":{"code":-32600,"message":"Invalid Request"}`)
}
func TestRequestUnmarshaling(t *testing.T) {
// Test that the Request struct unmarshals correctly
jsonStr := `{
"jsonrpc": "2.0",
"id": 789,
"method": "test_method",
"params": {"key": "value"}
}`
var req Request
err := json.Unmarshal([]byte(jsonStr), &req)
assert.NoError(t, err)
assert.Equal(t, "2.0", req.JsonRpc)
assert.Equal(t, float64(789), req.ID)
assert.Equal(t, "test_method", req.Method)
// Check params unmarshaled correctly
var params map[string]string
err = json.Unmarshal(req.Params, &params)
assert.NoError(t, err)
assert.Equal(t, "value", params["key"])
}
func TestToolStructs(t *testing.T) {
// Test Tool struct
tool := Tool{
Name: "test.tool",
Description: "A test tool",
InputSchema: InputSchema{
Type: "object",
Properties: map[string]any{
"input": map[string]any{
"type": "string",
"description": "Input parameter",
},
},
Required: []string{"input"},
},
Handler: func(ctx context.Context, params map[string]any) (any, error) {
return "result", nil
},
}
// Verify fields are correct
assert.Equal(t, "test.tool", tool.Name)
assert.Equal(t, "A test tool", tool.Description)
assert.Equal(t, "object", tool.InputSchema.Type)
assert.Contains(t, tool.InputSchema.Properties, "input")
propMap, ok := tool.InputSchema.Properties["input"].(map[string]any)
assert.True(t, ok, "Property should be a map")
assert.Equal(t, "string", propMap["type"])
assert.NotNil(t, tool.Handler)
// Verify JSON marshalling (which should exclude Handler function)
data, err := json.Marshal(tool)
assert.NoError(t, err)
assert.Contains(t, string(data), `"name":"test.tool"`)
assert.Contains(t, string(data), `"description":"A test tool"`)
assert.Contains(t, string(data), `"inputSchema":`)
assert.NotContains(t, string(data), `"Handler":`)
}
func TestPromptStructs(t *testing.T) {
// Test Prompt struct
prompt := Prompt{
Name: "test.prompt",
Description: "A test prompt description",
}
// Verify fields are correct
assert.Equal(t, "test.prompt", prompt.Name)
assert.Equal(t, "A test prompt description", prompt.Description)
// Verify JSON marshalling
data, err := json.Marshal(prompt)
assert.NoError(t, err)
assert.Contains(t, string(data), `"name":"test.prompt"`)
assert.Contains(t, string(data), `"description":"A test prompt description"`)
}
func TestResourceStructs(t *testing.T) {
// Test Resource struct
resource := Resource{
Name: "test.resource",
URI: "http://example.com/resource",
Description: "A test resource",
}
// Verify fields are correct
assert.Equal(t, "test.resource", resource.Name)
assert.Equal(t, "http://example.com/resource", resource.URI)
assert.Equal(t, "A test resource", resource.Description)
// Verify JSON marshalling
data, err := json.Marshal(resource)
assert.NoError(t, err)
assert.Contains(t, string(data), `"name":"test.resource"`)
assert.Contains(t, string(data), `"uri":"http://example.com/resource"`)
assert.Contains(t, string(data), `"description":"A test resource"`)
}
func TestContentTypes(t *testing.T) {
// Test TextContent
textContent := TextContent{
Text: "Sample text",
Annotations: &Annotations{
Audience: []RoleType{RoleUser, RoleAssistant},
Priority: ptr(1.0),
},
}
data, err := json.Marshal(textContent)
assert.NoError(t, err)
assert.Contains(t, string(data), `"text":"Sample text"`)
assert.Contains(t, string(data), `"audience":["user","assistant"]`)
assert.Contains(t, string(data), `"priority":1`)
// Test ImageContent
imageContent := ImageContent{
Data: "base64data",
MimeType: "image/png",
}
data, err = json.Marshal(imageContent)
assert.NoError(t, err)
assert.Contains(t, string(data), `"data":"base64data"`)
assert.Contains(t, string(data), `"mimeType":"image/png"`)
// Test AudioContent
audioContent := AudioContent{
Data: "base64audio",
MimeType: "audio/mp3",
}
data, err = json.Marshal(audioContent)
assert.NoError(t, err)
assert.Contains(t, string(data), `"data":"base64audio"`)
assert.Contains(t, string(data), `"mimeType":"audio/mp3"`)
}
func TestCallToolResult(t *testing.T) {
// Test CallToolResult
result := CallToolResult{
Result: Result{
Meta: map[string]any{
"progressToken": "token123",
},
},
Content: []interface{}{
TextContent{
Text: "Sample result",
},
},
IsError: false,
}
data, err := json.Marshal(result)
assert.NoError(t, err)
assert.Contains(t, string(data), `"_meta":{"progressToken":"token123"}`)
assert.Contains(t, string(data), `"content":[{"text":"Sample result"}]`)
assert.NotContains(t, string(data), `"isError":`)
}
func TestRequest_isNotification(t *testing.T) {
tests := []struct {
name string
id any
want bool
wantErr error
}{
// integer test cases
{name: "int zero", id: 0, want: true, wantErr: nil},
{name: "int non-zero", id: 1, want: false, wantErr: nil},
{name: "int64 zero", id: int64(0), want: true, wantErr: nil},
{name: "int64 max", id: int64(9223372036854775807), want: false, wantErr: nil},
// floating point number test cases
{name: "float64 zero", id: float64(0.0), want: true, wantErr: nil},
{name: "float64 positive", id: float64(0.000001), want: false, wantErr: nil},
{name: "float64 negative", id: float64(-0.000001), want: false, wantErr: nil},
{name: "float64 epsilon", id: float64(1e-300), want: false, wantErr: nil},
// string test cases
{name: "empty string", id: "", want: true, wantErr: nil},
{name: "non-empty string", id: "abc", want: false, wantErr: nil},
{name: "space string", id: " ", want: false, wantErr: nil},
{name: "unicode string", id: "こんにちは", want: false, wantErr: nil},
// special cases
{name: "nil", id: nil, want: true, wantErr: nil},
// logical type test cases
{name: "bool true", id: true, want: false, wantErr: errors.New("invalid type bool")},
{name: "bool false", id: false, want: false, wantErr: errors.New("invalid type bool")},
{name: "struct type", id: struct{}{}, want: false, wantErr: errors.New("invalid type struct {}")},
{name: "slice type", id: []int{1, 2, 3}, want: false, wantErr: errors.New("invalid type []int")},
{name: "map type", id: map[string]int{"a": 1}, want: false, wantErr: errors.New("invalid type map[string]int")},
{name: "pointer type", id: new(int), want: false, wantErr: errors.New("invalid type *int")},
{name: "func type", id: func() {}, want: false, wantErr: errors.New("invalid type func()")},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := Request{
SessionId: "test-session",
JsonRpc: "2.0",
ID: tt.id,
Method: "testMethod",
Params: json.RawMessage(`{}`),
}
got, err := req.isNotification()
if (err != nil) != (tt.wantErr != nil) {
t.Fatalf("error presence mismatch: got error = %v, wantErr %v", err, tt.wantErr)
}
if err != nil && tt.wantErr != nil && err.Error() != tt.wantErr.Error() {
t.Fatalf("error message mismatch:\ngot %q\nwant %q", err.Error(), tt.wantErr.Error())
}
if got != tt.want {
t.Errorf("isNotification() = %v, want %v for ID %v (%T)", got, tt.want, tt.id, tt.id)
}
})
}
}

107
mcp/util.go Normal file
View File

@@ -0,0 +1,107 @@
package mcp
import "fmt"
// formatSSEMessage formats a Server-Sent Event message with proper CRLF line endings
func formatSSEMessage(event string, data []byte) string {
return fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(data))
}
// ptr is a helper function to get a pointer to a value
func ptr[T any](v T) *T {
return &v
}
func toTypedContents(contents []any) []any {
typedContents := make([]any, len(contents))
for i, content := range contents {
switch v := content.(type) {
case TextContent:
typedContents[i] = typedTextContent{
Type: ContentTypeText,
TextContent: v,
}
case ImageContent:
typedContents[i] = typedImageContent{
Type: ContentTypeImage,
ImageContent: v,
}
case AudioContent:
typedContents[i] = typedAudioContent{
Type: ContentTypeAudio,
AudioContent: v,
}
default:
typedContents[i] = typedTextContent{
Type: ContentTypeText,
TextContent: TextContent{
Text: fmt.Sprintf("Unknown content type: %T", v),
},
}
}
}
return typedContents
}
func toTypedPromptMessages(messages []PromptMessage) []PromptMessage {
typedMessages := make([]PromptMessage, len(messages))
for i, msg := range messages {
switch v := msg.Content.(type) {
case TextContent:
typedMessages[i] = PromptMessage{
Role: msg.Role,
Content: typedTextContent{
Type: ContentTypeText,
TextContent: v,
},
}
case ImageContent:
typedMessages[i] = PromptMessage{
Role: msg.Role,
Content: typedImageContent{
Type: ContentTypeImage,
ImageContent: v,
},
}
case AudioContent:
typedMessages[i] = PromptMessage{
Role: msg.Role,
Content: typedAudioContent{
Type: ContentTypeAudio,
AudioContent: v,
},
}
default:
typedMessages[i] = PromptMessage{
Role: msg.Role,
Content: typedTextContent{
Type: ContentTypeText,
TextContent: TextContent{
Text: fmt.Sprintf("Unknown content type: %T", v),
},
},
}
}
}
return typedMessages
}
// validatePromptArguments checks if all required arguments are provided
// Returns a list of missing required arguments
func validatePromptArguments(prompt Prompt, providedArgs map[string]string) []string {
var missingArgs []string
for _, arg := range prompt.Arguments {
if arg.Required {
if value, exists := providedArgs[arg.Name]; !exists || len(value) == 0 {
missingArgs = append(missingArgs, arg.Name)
}
}
}
return missingArgs
}

274
mcp/util_test.go Normal file
View File

@@ -0,0 +1,274 @@
package mcp
import (
"bufio"
"encoding/json"
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type Event struct {
Type string
Data map[string]any
}
func parseEvent(input string) (*Event, error) {
var evt Event
var dataStr string
scanner := bufio.NewScanner(strings.NewReader(input))
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "event:") {
evt.Type = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
} else if strings.HasPrefix(line, "data:") {
dataStr = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
}
}
if err := scanner.Err(); err != nil {
return nil, err
}
if len(dataStr) > 0 {
if err := json.Unmarshal([]byte(dataStr), &evt.Data); err != nil {
return nil, fmt.Errorf("failed to parse data: %w", err)
}
}
return &evt, nil
}
// TestToTypedPromptMessages tests the toTypedPromptMessages function
func TestToTypedPromptMessages(t *testing.T) {
// Test with multiple message types in one test
t.Run("MixedContentTypes", func(t *testing.T) {
// Create test data with different content types
messages := []PromptMessage{
{
Role: RoleUser,
Content: TextContent{
Text: "Hello, this is a text message",
Annotations: &Annotations{
Audience: []RoleType{RoleUser, RoleAssistant},
Priority: ptr(0.8),
},
},
},
{
Role: RoleAssistant,
Content: ImageContent{
Data: "base64ImageData",
MimeType: "image/jpeg",
},
},
{
Role: RoleUser,
Content: AudioContent{
Data: "base64AudioData",
MimeType: "audio/mp3",
},
},
{
Role: "system",
Content: "This is a simple string that should be handled as unknown type",
},
}
// Call the function
result := toTypedPromptMessages(messages)
// Validate results
require.Len(t, result, 4, "Should return the same number of messages")
// Validate first message (TextContent)
msg := result[0]
assert.Equal(t, RoleUser, msg.Role, "Role should be preserved")
// Type assertion using reflection since Content is an interface
typed, ok := msg.Content.(typedTextContent)
require.True(t, ok, "Should be typedTextContent")
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
assert.Equal(t, "Hello, this is a text message", typed.Text, "Text content should be preserved")
require.NotNil(t, typed.Annotations, "Annotations should be preserved")
assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved")
require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved")
assert.Equal(t, 0.8, *typed.Annotations.Priority, "Priority value should be preserved")
// Validate second message (ImageContent)
msg = result[1]
assert.Equal(t, RoleAssistant, msg.Role, "Role should be preserved")
// Type assertion for image content
typedImg, ok := msg.Content.(typedImageContent)
require.True(t, ok, "Should be typedImageContent")
assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image")
assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved")
assert.Equal(t, "image/jpeg", typedImg.MimeType, "MimeType should be preserved")
// Validate third message (AudioContent)
msg = result[2]
assert.Equal(t, RoleUser, msg.Role, "Role should be preserved")
// Type assertion for audio content
typedAudio, ok := msg.Content.(typedAudioContent)
require.True(t, ok, "Should be typedAudioContent")
assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio")
assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved")
assert.Equal(t, "audio/mp3", typedAudio.MimeType, "MimeType should be preserved")
// Validate fourth message (unknown type converted to TextContent)
msg = result[3]
assert.Equal(t, RoleType("system"), msg.Role, "Role should be preserved")
// Should be converted to a typedTextContent with error message
typedUnknown, ok := msg.Content.(typedTextContent)
require.True(t, ok, "Unknown content should be converted to typedTextContent")
assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text")
assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type")
assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type")
})
// Test empty input
t.Run("EmptyInput", func(t *testing.T) {
messages := []PromptMessage{}
result := toTypedPromptMessages(messages)
assert.Empty(t, result, "Should return empty slice for empty input")
})
// Test with nil annotations
t.Run("NilAnnotations", func(t *testing.T) {
messages := []PromptMessage{
{
Role: RoleUser,
Content: TextContent{
Text: "Text with nil annotations",
Annotations: nil,
},
},
}
result := toTypedPromptMessages(messages)
require.Len(t, result, 1, "Should return one message")
typed, ok := result[0].Content.(typedTextContent)
require.True(t, ok, "Should be typedTextContent")
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved")
assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil")
})
}
// TestToTypedContents tests the toTypedContents function
func TestToTypedContents(t *testing.T) {
// Test with multiple content types in one test
t.Run("MixedContentTypes", func(t *testing.T) {
// Create test data with different content types
contents := []any{
TextContent{
Text: "Hello, this is a text content",
Annotations: &Annotations{
Audience: []RoleType{RoleUser, RoleAssistant},
Priority: ptr(0.7),
},
},
ImageContent{
Data: "base64ImageData",
MimeType: "image/png",
},
AudioContent{
Data: "base64AudioData",
MimeType: "audio/wav",
},
"This is a simple string that should be handled as unknown type",
}
// Call the function
result := toTypedContents(contents)
// Validate results
require.Len(t, result, 4, "Should return the same number of contents")
// Validate first content (TextContent)
typed, ok := result[0].(typedTextContent)
require.True(t, ok, "Should be typedTextContent")
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
assert.Equal(t, "Hello, this is a text content", typed.Text, "Text content should be preserved")
require.NotNil(t, typed.Annotations, "Annotations should be preserved")
assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved")
require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved")
assert.Equal(t, 0.7, *typed.Annotations.Priority, "Priority value should be preserved")
// Validate second content (ImageContent)
typedImg, ok := result[1].(typedImageContent)
require.True(t, ok, "Should be typedImageContent")
assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image")
assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved")
assert.Equal(t, "image/png", typedImg.MimeType, "MimeType should be preserved")
// Validate third content (AudioContent)
typedAudio, ok := result[2].(typedAudioContent)
require.True(t, ok, "Should be typedAudioContent")
assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio")
assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved")
assert.Equal(t, "audio/wav", typedAudio.MimeType, "MimeType should be preserved")
// Validate fourth content (unknown type converted to TextContent)
typedUnknown, ok := result[3].(typedTextContent)
require.True(t, ok, "Unknown content should be converted to typedTextContent")
assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text")
assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type")
assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type")
})
// Test empty input
t.Run("EmptyInput", func(t *testing.T) {
contents := []any{}
result := toTypedContents(contents)
assert.Empty(t, result, "Should return empty slice for empty input")
})
// Test with nil annotations
t.Run("NilAnnotations", func(t *testing.T) {
contents := []any{
TextContent{
Text: "Text with nil annotations",
Annotations: nil,
},
}
result := toTypedContents(contents)
require.Len(t, result, 1, "Should return one content")
typed, ok := result[0].(typedTextContent)
require.True(t, ok, "Should be typedTextContent")
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved")
assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil")
})
// Test with custom struct (should be handled as unknown type)
t.Run("CustomStruct", func(t *testing.T) {
type CustomContent struct {
Data string
}
contents := []any{
CustomContent{
Data: "custom data",
},
}
result := toTypedContents(contents)
require.Len(t, result, 1, "Should return one content")
typed, ok := result[0].(typedTextContent)
require.True(t, ok, "Custom struct should be converted to typedTextContent")
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
assert.Contains(t, typed.Text, "Unknown content type:", "Should contain error about unknown type")
assert.Contains(t, typed.Text, "CustomContent", "Should mention the actual type")
})
}

149
mcp/vars.go Normal file
View File

@@ -0,0 +1,149 @@
package mcp
import (
"time"
"github.com/zeromicro/go-zero/core/syncx"
)
// Protocol constants
const (
// JSON-RPC version as defined in the specification
jsonRpcVersion = "2.0"
// Session identifier key used in request URLs
sessionIdKey = "session_id"
// progressTokenKey is used to track progress of long-running tasks
progressTokenKey = "progressToken"
)
// Server-Sent Events (SSE) event types
const (
// Standard message event for JSON-RPC responses
eventMessage = "message"
// Endpoint event for sending endpoint URL to clients
eventEndpoint = "endpoint"
)
// Content type identifiers
const (
// ContentTypeObject is object content type
ContentTypeObject = "object"
// ContentTypeText is text content type
ContentTypeText = "text"
// ContentTypeImage is image content type
ContentTypeImage = "image"
// ContentTypeAudio is audio content type
ContentTypeAudio = "audio"
// ContentTypeResource is resource content type
ContentTypeResource = "resource"
)
// Collection keys for broadcast events
const (
// Key for prompts collection
keyPrompts = "prompts"
// Key for resources collection
keyResources = "resources"
// Key for tools collection
keyTools = "tools"
)
// JSON-RPC error codes
// Standard error codes from JSON-RPC 2.0 spec
const (
// Invalid JSON was received by the server
errCodeInvalidRequest = -32600
// The method does not exist / is not available
errCodeMethodNotFound = -32601
// Invalid method parameter(s)
errCodeInvalidParams = -32602
// Internal JSON-RPC error
errCodeInternalError = -32603
// Tool execution timed out
errCodeTimeout = -32001
// Resource not found error
errCodeResourceNotFound = -32002
// Client hasn't completed initialization
errCodeClientNotInitialized = -32800
)
// User and assistant role definitions
const (
// RoleUser is the "user" role - the entity asking questions
RoleUser RoleType = "user"
// RoleAssistant is the "assistant" role - the entity providing responses
RoleAssistant RoleType = "assistant"
)
// Method names as defined in the MCP specification
const (
// Initialize the connection between client and server
methodInitialize = "initialize"
// List available tools
methodToolsList = "tools/list"
// Call a specific tool
methodToolsCall = "tools/call"
// List available prompts
methodPromptsList = "prompts/list"
// Get a specific prompt
methodPromptsGet = "prompts/get"
// List available resources
methodResourcesList = "resources/list"
// Read a specific resource
methodResourcesRead = "resources/read"
// Subscribe to resource updates
methodResourcesSubscribe = "resources/subscribe"
// Simple ping to check server availability
methodPing = "ping"
// Notification that client is fully initialized
methodNotificationsInitialized = "notifications/initialized"
// Notification that a request was canceled
methodNotificationsCancelled = "notifications/cancelled"
)
// Event names for Server-Sent Events (SSE)
const (
// Notification of tool list changes
eventToolsListChanged = "tools/list_changed"
// Notification of prompt list changes
eventPromptsListChanged = "prompts/list_changed"
// Notification of resource list changes
eventResourcesListChanged = "resources/list_changed"
)
var (
// Default channel size for events
eventChanSize = 10
// Default ping interval for checking connection availability
// use syncx.ForAtomicDuration to ensure atomicity in test race
pingInterval = syncx.ForAtomicDuration(30 * time.Second)
)

210
mcp/vars_test.go Normal file
View File

@@ -0,0 +1,210 @@
// filepath: /Users/kevin/Develop/go/opensource/go-zero/mcp/vars_test.go
package mcp
import (
"encoding/json"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
// TestErrorCodes ensures error codes are applied correctly in error responses
func TestErrorCodes(t *testing.T) {
testCases := []struct {
name string
code int
message string
expected string
}{
{
name: "invalid request error",
code: errCodeInvalidRequest,
message: "Invalid request",
expected: `"code":-32600`,
},
{
name: "method not found error",
code: errCodeMethodNotFound,
message: "Method not found",
expected: `"code":-32601`,
},
{
name: "invalid params error",
code: errCodeInvalidParams,
message: "Invalid parameters",
expected: `"code":-32602`,
},
{
name: "internal error",
code: errCodeInternalError,
message: "Internal server error",
expected: `"code":-32603`,
},
{
name: "timeout error",
code: errCodeTimeout,
message: "Operation timed out",
expected: `"code":-32001`,
},
{
name: "resource not found error",
code: errCodeResourceNotFound,
message: "Resource not found",
expected: `"code":-32002`,
},
{
name: "client not initialized error",
code: errCodeClientNotInitialized,
message: "Client not initialized",
expected: `"code":-32800`,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
resp := Response{
JsonRpc: jsonRpcVersion,
ID: int64(1),
Error: &errorObj{
Code: tc.code,
Message: tc.message,
},
}
data, err := json.Marshal(resp)
assert.NoError(t, err)
assert.Contains(t, string(data), tc.expected, "Error code should match expected value")
assert.Contains(t, string(data), tc.message, "Error message should be included")
assert.Contains(t, string(data), jsonRpcVersion, "JSON-RPC version should be included")
})
}
}
// TestJsonRpcVersion ensures the correct JSON-RPC version is used
func TestJsonRpcVersion(t *testing.T) {
assert.Equal(t, "2.0", jsonRpcVersion, "JSON-RPC version should be 2.0")
// Test that it's used in responses
resp := Response{
JsonRpc: jsonRpcVersion,
ID: int64(1),
Result: "test",
}
data, err := json.Marshal(resp)
assert.NoError(t, err)
assert.Contains(t, string(data), `"jsonrpc":"2.0"`, "Response should use correct JSON-RPC version")
// Test that it's expected in requests
reqStr := `{"jsonrpc":"2.0","id":1,"method":"test"}`
var req Request
err = json.Unmarshal([]byte(reqStr), &req)
assert.NoError(t, err)
assert.Equal(t, jsonRpcVersion, req.JsonRpc, "Request should parse correct JSON-RPC version")
}
// TestSessionIdKey ensures session ID extraction works correctly
func TestSessionIdKey(t *testing.T) {
// Create a mock server implementation
mock := newMockMcpServer(t)
defer mock.shutdown()
// Verify the key constant
assert.Equal(t, "session_id", sessionIdKey, "Session ID key should be 'session_id'")
// Test that session ID is extracted correctly
mockR := httptest.NewRequest("GET", "/?"+sessionIdKey+"=test-session", nil)
// Since the mock server is using the same session key logic,
// we can test this by accessing the request query parameters directly
sessionID := mockR.URL.Query().Get(sessionIdKey)
assert.Equal(t, "test-session", sessionID, "Session ID should be extracted correctly")
}
// TestEventTypes ensures event types are set correctly in SSE responses
func TestEventTypes(t *testing.T) {
// Test message event
assert.Equal(t, "message", eventMessage, "Message event should be 'message'")
// Test endpoint event
assert.Equal(t, "endpoint", eventEndpoint, "Endpoint event should be 'endpoint'")
// Verify them in an actual SSE format string
messageEvent := "event: " + eventMessage + "\ndata: test\n\n"
assert.Contains(t, messageEvent, "event: message", "Message event should format correctly")
endpointEvent := "event: " + eventEndpoint + "\ndata: test\n\n"
assert.Contains(t, endpointEvent, "event: endpoint", "Endpoint event should format correctly")
}
// TestCollectionKeys checks that collection keys are used correctly
func TestCollectionKeys(t *testing.T) {
// Verify collection key constants
assert.Equal(t, "prompts", keyPrompts, "Prompts key should be 'prompts'")
assert.Equal(t, "resources", keyResources, "Resources key should be 'resources'")
assert.Equal(t, "tools", keyTools, "Tools key should be 'tools'")
}
// TestRoleTypes checks that role types are used correctly
func TestRoleTypes(t *testing.T) {
// Test in annotations
annotations := Annotations{
Audience: []RoleType{RoleUser, RoleAssistant},
}
data, err := json.Marshal(annotations)
assert.NoError(t, err)
assert.Contains(t, string(data), `"audience":["user","assistant"]`, "Role types should marshal correctly")
}
// TestMethodNames checks that method names are used correctly
func TestMethodNames(t *testing.T) {
// Verify method name constants
methods := map[string]string{
"initialize": methodInitialize,
"tools/list": methodToolsList,
"tools/call": methodToolsCall,
"prompts/list": methodPromptsList,
"prompts/get": methodPromptsGet,
"resources/list": methodResourcesList,
"resources/read": methodResourcesRead,
"resources/subscribe": methodResourcesSubscribe,
"ping": methodPing,
"notifications/initialized": methodNotificationsInitialized,
"notifications/cancelled": methodNotificationsCancelled,
}
for expected, actual := range methods {
assert.Equal(t, expected, actual, "Method name should be "+expected)
}
// Test in a request
for methodName := range methods {
req := Request{
JsonRpc: jsonRpcVersion,
ID: int64(1),
Method: methodName,
}
data, err := json.Marshal(req)
assert.NoError(t, err)
assert.Contains(t, string(data), `"method":"`+methodName+`"`, "Method name should be used in requests")
}
}
// TestEventNames checks that event names are used correctly
func TestEventNames(t *testing.T) {
// Verify event name constants
events := map[string]string{
"tools/list_changed": eventToolsListChanged,
"prompts/list_changed": eventPromptsListChanged,
"resources/list_changed": eventResourcesListChanged,
}
for expected, actual := range events {
assert.Equal(t, expected, actual, "Event name should be "+expected)
}
// Test event names in SSE format
for _, eventName := range events {
sseEvent := "event: " + eventName + "\ndata: test\n\n"
assert.Contains(t, sseEvent, "event: "+eventName, "Event name should format correctly in SSE")
}
}

View File

@@ -6,7 +6,6 @@
[English](readme.md) | 简体中文
[![Go](https://github.com/zeromicro/go-zero/workflows/Go/badge.svg?branch=master)](https://github.com/zeromicro/go-zero/actions)
[![Go Report Card](https://goreportcard.com/badge/github.com/zeromicro/go-zero)](https://goreportcard.com/report/github.com/zeromicro/go-zero)
[![goproxy](https://goproxy.cn/stats/github.com/zeromicro/go-zero/badges/download-count.svg)](https://goproxy.cn/stats/github.com/zeromicro/go-zero/badges/download-count.svg)
[![codecov](https://codecov.io/gh/zeromicro/go-zero/branch/master/graph/badge.svg)](https://codecov.io/gh/zeromicro/go-zero)
@@ -301,6 +300,10 @@ go-zero 已被许多公司用于生产部署,接入场景如在线教育、电
>102. 深圳市兴海物联科技有限公司
>103. 爱芯元智半导体股份有限公司
>104. 杭州升恒科技有限公司
>105. 昆仑万维科技股份有限公司
>106. 无锡盛算信息技术有限公司
>107. 深圳市聚货通信息科技有限公司
>108. 浙江银盾云科技有限公司
如果贵公司也已使用 go-zero欢迎在 [登记地址](https://github.com/zeromicro/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。

View File

@@ -7,7 +7,6 @@ go-zero is a web and rpc framework with lots of builtin engineering practices. I
<div align=center>
[![Go](https://github.com/zeromicro/go-zero/workflows/Go/badge.svg?branch=master)](https://github.com/zeromicro/go-zero/actions)
[![codecov](https://codecov.io/gh/zeromicro/go-zero/branch/master/graph/badge.svg)](https://codecov.io/gh/zeromicro/go-zero)
[![Go Report Card](https://goreportcard.com/badge/github.com/zeromicro/go-zero)](https://goreportcard.com/report/github.com/zeromicro/go-zero)
[![Release](https://img.shields.io/github/v/release/zeromicro/go-zero.svg?style=flat-square)](https://github.com/zeromicro/go-zero)
@@ -251,7 +250,3 @@ go-zero enlisted in the [CNCF Cloud Native Landscape](https://landscape.cncf.io/
## Give a Star! ⭐
If you like this project or are using it to learn or start your own solution, give it a star to get updates on new releases. Your support matters!
## Buy me a coffee
<a href="https://www.buymeacoffee.com/kevwan" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me A Coffee" style="height: 60px !important;width: 217px !important;" ></a>

View File

@@ -15,6 +15,7 @@ import (
"github.com/zeromicro/go-zero/rest/handler"
"github.com/zeromicro/go-zero/rest/httpx"
"github.com/zeromicro/go-zero/rest/internal"
"github.com/zeromicro/go-zero/rest/internal/header"
"github.com/zeromicro/go-zero/rest/internal/response"
)
@@ -27,7 +28,10 @@ var ErrSignatureConfig = errors.New("bad config for Signature")
type engine struct {
conf RestConf
routes []featuredRoutes
// timeout is the max timeout of all routes
// timeout is the max timeout of all routes,
// and is used to set http.Server.ReadTimeout and http.Server.WriteTimeout.
// this network timeout is used to avoid DoS attacks by sending data slowly
// or receiving data slowly with many connections to exhaust server resources.
timeout time.Duration
unauthorizedCallback handler.UnauthorizedCallback
unsignedCallback handler.UnsignedCallback
@@ -54,13 +58,26 @@ func newEngine(c RestConf) *engine {
}
func (ng *engine) addRoutes(r featuredRoutes) {
if r.sse {
r.routes = buildSSERoutes(r.routes)
}
ng.routes = append(ng.routes, r)
// need to guarantee the timeout is the max of all routes
// otherwise impossible to set http.Server.ReadTimeout & WriteTimeout
if r.timeout > ng.timeout {
ng.timeout = r.timeout
ng.mightUpdateTimeout(r)
}
func buildSSERoutes(routes []Route) []Route {
for i, route := range routes {
h := route.Handler
routes[i].Handler = func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(header.ContentType, header.ContentTypeEventStream)
w.Header().Set(header.CacheControl, header.CacheControlNoCache)
w.Header().Set(header.Connection, header.ConnectionKeepAlive)
h(w, r)
}
}
return routes
}
func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
@@ -174,11 +191,12 @@ func (ng *engine) checkedMaxBytes(bytes int64) int64 {
return ng.conf.MaxBytes
}
func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
if timeout > 0 {
return timeout
func (ng *engine) checkedTimeout(timeout *time.Duration) time.Duration {
if timeout != nil {
return *timeout
}
// if timeout not set in featured routes, use global timeout
return time.Duration(ng.conf.Timeout) * time.Millisecond
}
@@ -210,6 +228,32 @@ func (ng *engine) getShedder(priority bool) load.Shedder {
return ng.shedder
}
func (ng *engine) hasTimeout() bool {
return ng.conf.Middlewares.Timeout && ng.timeout > 0
}
// mightUpdateTimeout checks if the route timeout is greater than the current,
// and updates the engine's timeout accordingly.
func (ng *engine) mightUpdateTimeout(r featuredRoutes) {
// if global timeout is set to 0, it means no need to set read/write timeout
// if route timeout is nil, no need to update ng.timeout
if ng.timeout == 0 || r.timeout == nil {
return
}
// if route timeout is 0 (means no timeout), cannot set read/write timeout
if *r.timeout == 0 {
ng.timeout = 0
return
}
// need to guarantee the timeout is the max of all routes
// otherwise impossible to set http.Server.ReadTimeout & WriteTimeout
if *r.timeout > ng.timeout {
ng.timeout = *r.timeout
}
}
// notFoundHandler returns a middleware that handles 404 not found requests.
func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -311,7 +355,7 @@ func (ng *engine) start(router httpx.Router, opts ...StartOption) error {
}
// make sure user defined options overwrite default options
opts = append([]StartOption{ng.withTimeout()}, opts...)
opts = append([]StartOption{ng.withNetworkTimeout()}, opts...)
if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
return internal.StartHttp(ng.conf.Host, ng.conf.Port, router, opts...)
@@ -334,18 +378,19 @@ func (ng *engine) use(middleware Middleware) {
ng.middlewares = append(ng.middlewares, middleware)
}
func (ng *engine) withTimeout() internal.StartOption {
func (ng *engine) withNetworkTimeout() internal.StartOption {
return func(svr *http.Server) {
timeout := ng.timeout
if timeout > 0 {
// factor 0.8, to avoid clients send longer content-length than the actual content,
// without this timeout setting, the server will time out and respond 503 Service Unavailable,
// which triggers the circuit breaker.
svr.ReadTimeout = 4 * timeout / 5
// factor 1.1, to avoid servers don't have enough time to write responses.
// setting the factor less than 1.0 may lead clients not receiving the responses.
svr.WriteTimeout = 11 * timeout / 10
if !ng.hasTimeout() {
return
}
// factor 0.8, to avoid clients send longer content-length than the actual content,
// without this timeout setting, the server will time out and respond 503 Service Unavailable,
// which triggers the circuit breaker.
svr.ReadTimeout = 4 * ng.timeout / 5
// factor 1.1, to avoid servers don't have enough time to write responses.
// setting the factor less than 1.0 may lead clients not receiving the responses.
svr.WriteTimeout = 11 * ng.timeout / 10
}
}

View File

@@ -73,7 +73,17 @@ Verbose: true
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
timeout: time.Minute,
timeout: ptrOfDuration(time.Minute),
},
{
jwt: jwtSetting{},
signature: signatureSetting{},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
timeout: ptrOfDuration(0),
},
{
priority: true,
@@ -84,7 +94,7 @@ Verbose: true
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
timeout: time.Second,
timeout: ptrOfDuration(time.Second),
},
{
priority: true,
@@ -227,8 +237,12 @@ Verbose: true
}))
timeout := time.Second * 3
if route.timeout > timeout {
timeout = route.timeout
if route.timeout != nil {
if *route.timeout == 0 {
timeout = 0
} else if *route.timeout > timeout {
timeout = *route.timeout
}
}
assert.Equal(t, timeout, ng.timeout)
})
@@ -236,10 +250,69 @@ Verbose: true
}
}
func TestNewEngine_unsignedCallback(t *testing.T) {
priKeyfile, err := fs.TempFilenameWithText(priKey)
assert.Nil(t, err)
defer os.Remove(priKeyfile)
yaml := `Name: foo
Host: localhost
Port: 0
Middlewares:
Log: false
`
route := featuredRoutes{
priority: true,
jwt: jwtSetting{
enabled: true,
},
signature: signatureSetting{
enabled: true,
SignatureConf: SignatureConf{
Strict: true,
PrivateKeys: []PrivateKeyConf{
{
Fingerprint: "a",
KeyFile: priKeyfile,
},
},
},
},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
}
var index int32
t.Run(fmt.Sprintf("%s-%v", yaml, route.routes), func(t *testing.T) {
var cnf RestConf
assert.Nil(t, conf.LoadFromYamlBytes([]byte(yaml), &cnf))
ng := newEngine(cnf)
if atomic.AddInt32(&index, 1)%2 == 0 {
ng.setUnsignedCallback(func(w http.ResponseWriter, r *http.Request,
next http.Handler, strict bool, code int) {
})
}
ng.addRoutes(route)
ng.use(func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r)
}
})
assert.NotNil(t, ng.start(mockedRouter{}, func(svr *http.Server) {
}))
assert.Equal(t, time.Duration(time.Second*3), ng.timeout)
})
}
func TestEngine_checkedTimeout(t *testing.T) {
tests := []struct {
name string
timeout time.Duration
timeout *time.Duration
expect time.Duration
}{
{
@@ -248,17 +321,17 @@ func TestEngine_checkedTimeout(t *testing.T) {
},
{
name: "less",
timeout: time.Millisecond * 500,
timeout: ptrOfDuration(time.Millisecond * 500),
expect: time.Millisecond * 500,
},
{
name: "equal",
timeout: time.Second,
timeout: ptrOfDuration(time.Second),
expect: time.Second,
},
{
name: "more",
timeout: time.Millisecond * 1500,
timeout: ptrOfDuration(time.Millisecond * 1500),
expect: time.Millisecond * 1500,
},
}
@@ -394,9 +467,14 @@ func TestEngine_withTimeout(t *testing.T) {
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
ng := newEngine(RestConf{Timeout: test.timeout})
ng := newEngine(RestConf{
Timeout: test.timeout,
Middlewares: MiddlewaresConf{
Timeout: true,
},
})
svr := &http.Server{}
ng.withTimeout()(svr)
ng.withNetworkTimeout()(svr)
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
@@ -406,6 +484,62 @@ func TestEngine_withTimeout(t *testing.T) {
}
}
func TestEngine_ReadWriteTimeout(t *testing.T) {
logx.Disable()
tests := []struct {
name string
timeout int64
middleware bool
}{
{
name: "0/false",
timeout: 0,
middleware: false,
},
{
name: "0/true",
timeout: 0,
middleware: true,
},
{
name: "set/false",
timeout: 1000,
middleware: false,
},
{
name: "both set",
timeout: 1000,
middleware: true,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
ng := newEngine(RestConf{
Timeout: test.timeout,
Middlewares: MiddlewaresConf{
Timeout: test.middleware,
},
})
svr := &http.Server{}
ng.withNetworkTimeout()(svr)
assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
assert.Equal(t, time.Duration(0), svr.IdleTimeout)
if test.timeout > 0 && test.middleware {
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*11/10, svr.WriteTimeout)
} else {
assert.Equal(t, time.Duration(0), svr.ReadTimeout)
assert.Equal(t, time.Duration(0), svr.WriteTimeout)
}
})
}
}
func TestEngine_start(t *testing.T) {
logx.Disable()

View File

@@ -106,8 +106,8 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case <-ctx.Done():
tw.mu.Lock()
defer tw.mu.Unlock()
// there isn't any user-defined middleware before TimoutHandler,
// so we can guarantee that cancelation in biz related code won't come here.
// there isn't any user-defined middleware before TimeoutHandler,
// so we can guarantee that cancellation in biz related code won't come here.
httpx.ErrorCtx(r.Context(), w, ctx.Err(), func(w http.ResponseWriter, err error) {
if errors.Is(err, context.Canceled) {
w.WriteHeader(statusClientClosedRequest)
@@ -151,7 +151,7 @@ func (tw *timeoutWriter) Flush() {
flusher.Flush()
}
// Header returns the underline temporary http.Header.
// Header returns the underlying temporary http.Header.
func (tw *timeoutWriter) Header() http.Header {
return tw.h
}

View File

@@ -105,7 +105,7 @@ func buildRequest(ctx context.Context, method, url string, data any) (*http.Requ
req.URL.RawQuery = buildFormQuery(u, val[formKey])
fillHeader(req, val[headerKey])
if hasJsonBody {
req.Header.Set(header.ContentType, header.JsonContentType)
req.Header.Set(header.ContentType, header.ContentTypeJson)
}
return req, nil

View File

@@ -45,7 +45,7 @@ func TestDoRequest_NotFound(t *testing.T) {
defer svr.Close()
req, err := http.NewRequest(http.MethodPost, svr.URL, nil)
assert.Nil(t, err)
req.Header.Set(header.ContentType, header.JsonContentType)
req.Header.Set(header.ContentType, header.ContentTypeJson)
resp, err := DoRequest(req)
assert.Nil(t, err)
assert.Equal(t, http.StatusNotFound, resp.StatusCode)

View File

@@ -18,7 +18,7 @@ func TestParse(t *testing.T) {
}
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("foo", "bar")
w.Header().Set(header.ContentType, header.JsonContentType)
w.Header().Set(header.ContentType, header.ContentTypeJson)
w.Write([]byte(`{"name":"kevin","value":100}`))
}))
defer svr.Close()
@@ -38,7 +38,7 @@ func TestParseHeaderError(t *testing.T) {
}
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("foo", "bar")
w.Header().Set(header.ContentType, header.JsonContentType)
w.Header().Set(header.ContentType, header.ContentTypeJson)
}))
defer svr.Close()
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
@@ -54,7 +54,7 @@ func TestParseNoBody(t *testing.T) {
}
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("foo", "bar")
w.Header().Set(header.ContentType, header.JsonContentType)
w.Header().Set(header.ContentType, header.ContentTypeJson)
}))
defer svr.Close()
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
@@ -72,7 +72,7 @@ func TestParseWithZeroValue(t *testing.T) {
}
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("foo", "0")
w.Header().Set(header.ContentType, header.JsonContentType)
w.Header().Set(header.ContentType, header.ContentTypeJson)
w.Write([]byte(`{"bar":0}`))
}))
defer svr.Close()
@@ -90,7 +90,7 @@ func TestParseWithNegativeContentLength(t *testing.T) {
Bar int `json:"bar"`
}
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(header.ContentType, header.JsonContentType)
w.Header().Set(header.ContentType, header.ContentTypeJson)
w.Write([]byte(`{"bar":0}`))
}))
defer svr.Close()
@@ -124,7 +124,7 @@ func TestParseWithNegativeContentLength(t *testing.T) {
func TestParseWithNegativeContentLengthNoBody(t *testing.T) {
var val struct{}
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(header.ContentType, header.JsonContentType)
w.Header().Set(header.ContentType, header.ContentTypeJson)
}))
defer svr.Close()
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
@@ -156,7 +156,7 @@ func TestParseWithNegativeContentLengthNoBody(t *testing.T) {
func TestParseJsonBody_BodyError(t *testing.T) {
var val struct{}
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(header.ContentType, header.JsonContentType)
w.Header().Set(header.ContentType, header.ContentTypeJson)
}))
defer svr.Close()
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)

View File

@@ -44,7 +44,7 @@ func TestNamedService_DoRequestPost(t *testing.T) {
service := NewService("foo")
req, err := http.NewRequest(http.MethodPost, svr.URL, nil)
assert.Nil(t, err)
req.Header.Set(header.ContentType, header.JsonContentType)
req.Header.Set(header.ContentType, header.ContentTypeJson)
resp, err := service.DoRequest(req)
assert.Nil(t, err)
assert.Equal(t, http.StatusNotFound, resp.StatusCode)

View File

@@ -160,7 +160,7 @@ func TestParseFormArray(t *testing.T) {
http.NoBody)
assert.NoError(t, err)
if assert.NoError(t, Parse(r, &v)) {
assert.ElementsMatch(t, []string{"1", "2", "3"}, v.Names)
assert.ElementsMatch(t, []string{"1,2,3"}, v.Names)
}
})
@@ -189,9 +189,7 @@ func TestParseFormArray(t *testing.T) {
"/a?numbers=1,2,3",
http.NoBody)
assert.NoError(t, err)
if assert.NoError(t, Parse(r, &v)) {
assert.ElementsMatch(t, []int{1, 2, 3}, v.Numbers)
}
assert.Error(t, Parse(r, &v))
})
t.Run("slice with one value on array format brackets", func(t *testing.T) {
@@ -268,6 +266,36 @@ func TestParseFormArray(t *testing.T) {
assert.ElementsMatch(t, []float64{2}, v.Numbers)
}
})
t.Run("slice with one value", func(t *testing.T) {
var v struct {
Codes []string `form:"codes"`
}
r, err := http.NewRequest(
http.MethodGet,
"/a?codes=aaa,bbb,ccc",
http.NoBody)
assert.NoError(t, err)
if assert.NoError(t, Parse(r, &v)) {
assert.ElementsMatch(t, []string{"aaa,bbb,ccc"}, v.Codes)
}
})
t.Run("slice with multiple values", func(t *testing.T) {
var v struct {
Codes []string `form:"codes,arrayComma=false"`
}
r, err := http.NewRequest(
http.MethodGet,
"/a?codes=aaa,bbb,ccc&codes=ccc,ddd,eee",
http.NoBody)
assert.NoError(t, err)
if assert.NoError(t, Parse(r, &v)) {
assert.ElementsMatch(t, []string{"aaa,bbb,ccc", "ccc,ddd,eee"}, v.Codes)
}
})
}
func TestParseForm_Error(t *testing.T) {
@@ -448,7 +476,7 @@ func TestParseJsonBody(t *testing.T) {
body := `{"name":"kevin", "age": 18}`
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
r.Header.Set(ContentType, header.JsonContentType)
r.Header.Set(ContentType, header.ContentTypeJson)
if assert.NoError(t, Parse(r, &v)) {
assert.Equal(t, "kevin", v.Name)
@@ -464,7 +492,7 @@ func TestParseJsonBody(t *testing.T) {
body := `{"name":"kevin", "ag": 18}`
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
r.Header.Set(ContentType, header.JsonContentType)
r.Header.Set(ContentType, header.ContentTypeJson)
assert.Error(t, Parse(r, &v))
})
@@ -489,7 +517,7 @@ func TestParseJsonBody(t *testing.T) {
body := `[{"name":"kevin", "age": 18}]`
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
r.Header.Set(ContentType, header.JsonContentType)
r.Header.Set(ContentType, header.ContentTypeJson)
assert.NoError(t, Parse(r, &v))
assert.Equal(t, 1, len(v))
@@ -509,7 +537,7 @@ func TestParseJsonBody(t *testing.T) {
body := `[{"name":"apple", "age": 18}]`
r := httptest.NewRequest(http.MethodPost, "/a?product=tree", strings.NewReader(body))
r.Header.Set(ContentType, header.JsonContentType)
r.Header.Set(ContentType, header.ContentTypeJson)
assert.NoError(t, Parse(r, &v))
assert.Equal(t, 1, len(v))
@@ -527,7 +555,7 @@ func TestParseJsonBody(t *testing.T) {
body, _ := json.Marshal(v1)
t.Logf("body:%s", string(body))
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(body)))
r.Header.Set(ContentType, header.JsonContentType)
r.Header.Set(ContentType, header.ContentTypeJson)
var v2 v
err := ParseJsonBody(r, &v2)
if assert.NoError(t, err) {
@@ -581,7 +609,7 @@ func TestParseHeaders(t *testing.T) {
request.Header.Add("addrs", "addr2")
request.Header.Add("X-Forwarded-For", "10.0.10.11")
request.Header.Add("x-real-ip", "10.0.11.10")
request.Header.Add("Accept", header.JsonContentType)
request.Header.Add("Accept", header.ContentTypeJson)
err = ParseHeaders(request, &v)
if err != nil {
t.Fatal(err)
@@ -591,7 +619,7 @@ func TestParseHeaders(t *testing.T) {
assert.Equal(t, []string{"addr1", "addr2"}, v.Addrs)
assert.Equal(t, "10.0.10.11", v.XForwardedFor)
assert.Equal(t, "10.0.11.10", v.XRealIP)
assert.Equal(t, header.JsonContentType, v.Accept)
assert.Equal(t, header.ContentTypeJson, v.Accept)
}
func TestParseHeaders_Error(t *testing.T) {
@@ -683,7 +711,7 @@ func TestParseWithFloatPtr(t *testing.T) {
}
body := `{"weightFloat32": 3.2}`
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
r.Header.Set(ContentType, header.JsonContentType)
r.Header.Set(ContentType, header.ContentTypeJson)
if assert.NoError(t, Parse(r, &v)) {
assert.Equal(t, float32(3.2), *v.WeightFloat32)

View File

@@ -179,7 +179,7 @@ func doWriteJson(w http.ResponseWriter, code int, v any) error {
return fmt.Errorf("marshal json failed, error: %w", err)
}
w.Header().Set(ContentType, header.JsonContentType)
w.Header().Set(ContentType, header.ContentTypeJson)
w.WriteHeader(code)
if n, err := w.Write(bs); err != nil {

View File

@@ -10,7 +10,7 @@ const (
// ContentType means Content-Type.
ContentType = header.ContentType
// JsonContentType means application/json.
JsonContentType = header.JsonContentType
JsonContentType = header.ContentTypeJson
// KeyField means key.
KeyField = "key"
// SecretField means secret.

View File

@@ -2,15 +2,16 @@ package fileserver
import (
"net/http"
"path"
"strings"
"sync"
)
// Middleware returns a middleware that serves files from the given file system.
func Middleware(path string, fs http.FileSystem) func(http.HandlerFunc) http.HandlerFunc {
func Middleware(upath string, fs http.FileSystem) func(http.HandlerFunc) http.HandlerFunc {
fileServer := http.FileServer(fs)
pathWithoutTrailSlash := ensureNoTrailingSlash(path)
canServe := createServeChecker(path, fs)
pathWithoutTrailSlash := ensureNoTrailingSlash(upath)
canServe := createServeChecker(upath, fs)
return func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
@@ -28,9 +29,22 @@ func createFileChecker(fs http.FileSystem) func(string) bool {
var lock sync.RWMutex
fileChecker := make(map[string]bool)
return func(path string) bool {
return func(upath string) bool {
// Emulate http.Dir.Opens path normalization for embed.FS.Open.
// http.FileServer redirects any request ending in "/index.html"
// to the same path without the final "index.html".
// So the path here may be empty or end with a "/".
// http.Dir.Open uses this logic to clean the path,
// correctly handling those two cases.
// embed.FS doesnt perform this normalization, so we apply the same logic here.
upath = path.Clean("/" + upath)[1:]
if len(upath) == 0 {
// if the path is empty, we use "." to open the current directory
upath = "."
}
lock.RLock()
exist, ok := fileChecker[path]
exist, ok := fileChecker[upath]
lock.RUnlock()
if ok {
return exist
@@ -39,9 +53,9 @@ func createFileChecker(fs http.FileSystem) func(string) bool {
lock.Lock()
defer lock.Unlock()
file, err := fs.Open(path)
file, err := fs.Open(upath)
exist = err == nil
fileChecker[path] = exist
fileChecker[upath] = exist
if err != nil {
return false
}
@@ -51,8 +65,8 @@ func createFileChecker(fs http.FileSystem) func(string) bool {
}
}
func createServeChecker(path string, fs http.FileSystem) func(r *http.Request) bool {
pathWithTrailSlash := ensureTrailingSlash(path)
func createServeChecker(upath string, fs http.FileSystem) func(r *http.Request) bool {
pathWithTrailSlash := ensureTrailingSlash(upath)
fileChecker := createFileChecker(fs)
return func(r *http.Request) bool {
@@ -62,18 +76,18 @@ func createServeChecker(path string, fs http.FileSystem) func(r *http.Request) b
}
}
func ensureTrailingSlash(path string) string {
if strings.HasSuffix(path, "/") {
return path
func ensureTrailingSlash(upath string) string {
if strings.HasSuffix(upath, "/") {
return upath
}
return path + "/"
return upath + "/"
}
func ensureNoTrailingSlash(path string) string {
if strings.HasSuffix(path, "/") {
return path[:len(path)-1]
func ensureNoTrailingSlash(upath string) string {
if strings.HasSuffix(upath, "/") {
return upath[:len(upath)-1]
}
return path
return upath
}

View File

@@ -1,6 +1,8 @@
package fileserver
import (
"embed"
"io/fs"
"net/http"
"net/http/httptest"
"testing"
@@ -61,6 +63,46 @@ func TestMiddleware(t *testing.T) {
requestPath: "/ws",
expectedStatus: http.StatusAlreadyReported,
},
// http.FileServer redirects any request ending in "/index.html"
// to the same path, without the final "index.html".
{
name: "Serve index.html",
path: "/static",
dir: "testdata",
requestPath: "/static/index.html",
expectedStatus: http.StatusMovedPermanently,
},
{
name: "Serve index.html with path with trailing slash",
path: "/static/",
dir: "testdata",
requestPath: "/static/index.html",
expectedStatus: http.StatusMovedPermanently,
},
{
name: "Serve index.html in a nested directory",
path: "/static",
dir: "testdata",
requestPath: "/static/nested/index.html",
expectedStatus: http.StatusMovedPermanently,
},
{
name: "Request index.html indirectly",
path: "/static",
dir: "testdata",
requestPath: "/static/",
expectedStatus: http.StatusOK,
expectedContent: "hello",
},
{
name: "Request index.html in a nested directory indirectly",
path: "/static",
dir: "testdata",
requestPath: "/static/nested/",
expectedStatus: http.StatusOK,
expectedContent: "hello",
},
}
for _, tt := range tests {
@@ -87,6 +129,128 @@ func TestMiddleware(t *testing.T) {
}
}
var (
//go:embed testdata
testdataFS embed.FS
)
func TestMiddleware_embedFS(t *testing.T) {
tests := []struct {
name string
path string
requestPath string
expectedStatus int
expectedContent string
}{
{
name: "Serve static file",
path: "/static",
requestPath: "/static/example.txt",
expectedStatus: http.StatusOK,
expectedContent: "1",
},
{
name: "Path with trailing slash",
path: "/static/",
requestPath: "/static/example.txt",
expectedStatus: http.StatusOK,
expectedContent: "1",
},
{
name: "Root path",
path: "/",
requestPath: "/example.txt",
expectedStatus: http.StatusOK,
expectedContent: "1",
},
{
name: "Pass through non-matching path",
path: "/static/",
requestPath: "/other/path",
expectedStatus: http.StatusAlreadyReported,
},
{
name: "Not exist file",
path: "/assets",
requestPath: "/assets/not-exist.txt",
expectedStatus: http.StatusAlreadyReported,
},
{
name: "Not exist file in root",
path: "/",
requestPath: "/not-exist.txt",
expectedStatus: http.StatusAlreadyReported,
},
{
name: "websocket request",
path: "/",
requestPath: "/ws",
expectedStatus: http.StatusAlreadyReported,
},
// http.FileServer redirects any request ending in "/index.html"
// to the same path, without the final "index.html".
{
name: "Serve index.html",
path: "/static",
requestPath: "/static/index.html",
expectedStatus: http.StatusMovedPermanently,
},
{
name: "Serve index.html with path with trailing slash",
path: "/static/",
requestPath: "/static/index.html",
expectedStatus: http.StatusMovedPermanently,
},
{
name: "Serve index.html in a nested directory",
path: "/static",
requestPath: "/static/nested/index.html",
expectedStatus: http.StatusMovedPermanently,
},
{
name: "Request index.html indirectly",
path: "/static",
requestPath: "/static/",
expectedStatus: http.StatusOK,
expectedContent: "hello",
},
{
name: "Request index.html in a nested directory indirectly",
path: "/static",
requestPath: "/static/nested/",
expectedStatus: http.StatusOK,
expectedContent: "hello",
},
}
subFS, err := fs.Sub(testdataFS, "testdata")
assert.Nil(t, err)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
middleware := Middleware(tt.path, http.FS(subFS))
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusAlreadyReported)
})
handlerToTest := middleware(nextHandler)
for i := 0; i < 2; i++ {
req := httptest.NewRequest(http.MethodGet, tt.requestPath, nil)
rr := httptest.NewRecorder()
handlerToTest.ServeHTTP(rr, req)
assert.Equal(t, tt.expectedStatus, rr.Code)
if len(tt.expectedContent) > 0 {
assert.Equal(t, tt.expectedContent, rr.Body.String())
}
}
})
}
}
func TestEnsureTrailingSlash(t *testing.T) {
tests := []struct {
input string

View File

@@ -0,0 +1 @@
hello

View File

@@ -0,0 +1 @@
hello

View File

@@ -3,8 +3,18 @@ package header
const (
// ApplicationJson stands for application/json.
ApplicationJson = "application/json"
// CacheControl is the header key for Cache-Control.
CacheControl = "Cache-Control"
// CacheControlNoCache is the value for Cache-Control: no-cache.
CacheControlNoCache = "no-cache"
// Connection is the header key for Connection.
Connection = "Connection"
// ConnectionKeepAlive is the value for Connection: keep-alive.
ConnectionKeepAlive = "keep-alive"
// ContentType is the header key for Content-Type.
ContentType = "Content-Type"
// JsonContentType is the content type for JSON.
JsonContentType = "application/json; charset=utf-8"
// ContentTypeJson is the content type for JSON.
ContentTypeJson = "application/json; charset=utf-8"
// ContentTypeEventStream is the content type for event stream.
ContentTypeEventStream = "text/event-stream"
)

View File

@@ -628,7 +628,7 @@ func TestParseWrappedRequest(t *testing.T) {
func TestParseWrappedGetRequestWithJsonHeader(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017", bytes.NewReader(nil))
assert.Nil(t, err)
r.Header.Set(httpx.ContentType, header.JsonContentType)
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
type (
Request struct {
@@ -661,7 +661,7 @@ func TestParseWrappedGetRequestWithJsonHeader(t *testing.T) {
func TestParseWrappedHeadRequestWithJsonHeader(t *testing.T) {
r, err := http.NewRequest(http.MethodHead, "http://hello.com/kevin/2017", bytes.NewReader(nil))
assert.Nil(t, err)
r.Header.Set(httpx.ContentType, header.JsonContentType)
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
type (
Request struct {
@@ -758,7 +758,7 @@ func TestParseWithAllUtf8(t *testing.T) {
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000",
bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`))
assert.Nil(t, err)
r.Header.Set(httpx.ContentType, header.JsonContentType)
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
router := NewRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
@@ -948,7 +948,7 @@ func TestParseWithMissingAllPaths(t *testing.T) {
func TestParseGetWithContentLengthHeader(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil)
assert.Nil(t, err)
r.Header.Set(httpx.ContentType, header.JsonContentType)
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
r.Header.Set(contentLength, "1024")
router := NewRouter()
@@ -976,7 +976,7 @@ func TestParseJsonPostWithTypeMismatch(t *testing.T) {
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000",
bytes.NewBufferString(`{"time": "20170912"}`))
assert.Nil(t, err)
r.Header.Set(httpx.ContentType, header.JsonContentType)
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
router := NewRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
@@ -1002,7 +1002,7 @@ func TestParseJsonPostWithInt2String(t *testing.T) {
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017",
bytes.NewBufferString(`{"time": 20170912}`))
assert.Nil(t, err)
r.Header.Set(httpx.ContentType, header.JsonContentType)
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
router := NewRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(

View File

@@ -63,6 +63,11 @@ func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
return server, nil
}
// AddRoute adds given route into the Server.
func (s *Server) AddRoute(r Route, opts ...RouteOption) {
s.AddRoutes([]Route{r}, opts...)
}
// AddRoutes add given routes into the Server.
func (s *Server) AddRoutes(rs []Route, opts ...RouteOption) {
r := featuredRoutes{
@@ -74,11 +79,6 @@ func (s *Server) AddRoutes(rs []Route, opts ...RouteOption) {
s.ngin.addRoutes(r)
}
// AddRoute adds given route into the Server.
func (s *Server) AddRoute(r Route, opts ...RouteOption) {
s.AddRoutes([]Route{r}, opts...)
}
// PrintRoutes prints the added routes to stdout.
func (s *Server) PrintRoutes() {
s.ngin.print()
@@ -95,25 +95,6 @@ func (s *Server) Routes() []Route {
return routes
}
// ServeHTTP is for test purpose, allow developer to do a unit test with
// all defined router without starting an HTTP Server.
//
// For example:
//
// server := MustNewServer(...)
// server.addRoute(...) // router a
// server.addRoute(...) // router b
// server.addRoute(...) // router c
//
// r, _ := http.NewRequest(...)
// w := httptest.NewRecorder(...)
// server.ServeHTTP(w, r)
// // verify the response
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.ngin.bindRoutes(s.router)
s.router.ServeHTTP(w, r)
}
// Start starts the Server.
// Graceful shutdown is enabled by default.
// Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
@@ -138,6 +119,16 @@ func (s *Server) Use(middleware Middleware) {
s.ngin.use(middleware)
}
// build builds the Server and binds the routes to the router.
func (s *Server) build() error {
return s.ngin.bindRoutes(s.router)
}
// serve serves the HTTP requests using the Server's router.
func (s *Server) serve(w http.ResponseWriter, r *http.Request) {
s.router.ServeHTTP(w, r)
}
// ToMiddleware converts the given handler to a Middleware.
func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
return func(handle http.HandlerFunc) http.HandlerFunc {
@@ -298,10 +289,18 @@ func WithSignature(signature SignatureConf) RouteOption {
}
}
// WithSSE returns a RouteOption to enable server-sent events.
func WithSSE() RouteOption {
return func(r *featuredRoutes) {
r.sse = true
r.timeout = ptrOfDuration(0)
}
}
// WithTimeout returns a RouteOption to set timeout with given value.
func WithTimeout(timeout time.Duration) RouteOption {
return func(r *featuredRoutes) {
r.timeout = timeout
r.timeout = &timeout
}
}
@@ -336,6 +335,10 @@ func handleError(err error) {
panic(err)
}
func ptrOfDuration(d time.Duration) *time.Duration {
return &d
}
func validateSecret(secret string) {
if len(secret) < 8 {
panic("secret's length can't be less than 8")

View File

@@ -20,6 +20,7 @@ import (
"github.com/zeromicro/go-zero/rest/chain"
"github.com/zeromicro/go-zero/rest/httpx"
"github.com/zeromicro/go-zero/rest/internal/cors"
"github.com/zeromicro/go-zero/rest/internal/header"
"github.com/zeromicro/go-zero/rest/router"
)
@@ -231,7 +232,7 @@ func TestWithFileServerMiddleware(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, tt.requestPath, nil)
rr := httptest.NewRecorder()
server.ServeHTTP(rr, req)
serve(server, rr, req)
assert.Equal(t, tt.expectedStatus, rr.Code)
if len(tt.expectedContent) > 0 {
@@ -344,7 +345,7 @@ func TestWithPriority(t *testing.T) {
func TestWithTimeout(t *testing.T) {
var fr featuredRoutes
WithTimeout(time.Hour)(&fr)
assert.Equal(t, time.Hour, fr.timeout)
assert.Equal(t, time.Hour, *fr.timeout)
}
func TestWithTLSConfig(t *testing.T) {
@@ -458,7 +459,7 @@ Port: 54321
// we would need to verify the behavior here. Since we don't have
// direct access to headers, we'll mock newCorsRouter to capture it.
w := httptest.NewRecorder()
svr.ServeHTTP(w, httptest.NewRequest(http.MethodOptions, "/", nil))
serve(svr, w, httptest.NewRequest(http.MethodOptions, "/", nil))
vals := w.Header().Values("Access-Control-Allow-Headers")
respHeaders := make(map[string]struct{})
@@ -748,12 +749,46 @@ Port: 54321
t.Run(test.name, func(t *testing.T) {
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", test.path, nil)
svr.ServeHTTP(w, req)
serve(svr, w, req)
assert.Equal(t, test.code, w.Code)
})
}
}
func TestServerEventStream(t *testing.T) {
server := MustNewServer(RestConf{})
server.AddRoutes([]Route{
{
Method: http.MethodGet,
Path: "/foo",
Handler: func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("foo"))
},
},
{
Method: http.MethodGet,
Path: "/bar",
Handler: func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("bar"))
},
},
}, WithSSE())
check := func(val string) {
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%s", val), http.NoBody)
assert.Nil(t, err)
rr := httptest.NewRecorder()
serve(server, rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, header.ContentTypeEventStream, rr.Header().Get(header.ContentType))
assert.Equal(t, header.CacheControlNoCache, rr.Header().Get(header.CacheControl))
assert.Equal(t, header.ConnectionKeepAlive, rr.Header().Get(header.Connection))
assert.Equal(t, val, rr.Body.String())
}
check("foo")
check("bar")
}
//go:embed testdata
var content embed.FS
@@ -765,6 +800,25 @@ func TestServerEmbedFileSystem(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/assets/sample.txt", http.NoBody)
assert.Nil(t, err)
rr := httptest.NewRecorder()
server.ServeHTTP(rr, req)
serve(server, rr, req)
assert.Equal(t, sampleContent, rr.Body.String())
}
// serve is for test purpose, allow developer to do a unit test with
// all defined routes without starting an HTTP Server.
//
// For example:
//
// server := MustNewServer(...)
// server.addRoute(...) // router a
// server.addRoute(...) // router b
// server.addRoute(...) // router c
//
// r, _ := http.NewRequest(...)
// w := httptest.NewRecorder(...)
// serve(server, w, r)
// // verify the response
func serve(s *Server, w http.ResponseWriter, r *http.Request) {
_ = s.build()
s.serve(w, r)
}

27
rest/serverless.go Normal file
View File

@@ -0,0 +1,27 @@
package rest
import "net/http"
// Serverless is a wrapper around Server that allows it to be used in serverless environments.
type Serverless struct {
server *Server
}
// NewServerless creates a new Serverless instance from the provided Server.
func NewServerless(server *Server) (*Serverless, error) {
// Ensure the server is built before using it in a serverless context.
// Why not call server.build() when serving requests,
// is because we need to ensure fail fast behavior.
if err := server.build(); err != nil {
return nil, err
}
return &Serverless{
server: server,
}, nil
}
// Serve handles HTTP requests by delegating them to the underlying Server instance.
func (s *Serverless) Serve(w http.ResponseWriter, r *http.Request) {
s.server.serve(w, r)
}

67
rest/serverless_test.go Normal file
View File

@@ -0,0 +1,67 @@
package rest
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/conf"
"github.com/zeromicro/go-zero/core/logx/logtest"
)
func TestNewServerless(t *testing.T) {
logtest.Discard(t)
const configYaml = `
Name: foo
Host: localhost
Port: 0
`
var cnf RestConf
assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
svr, err := NewServer(cnf)
assert.NoError(t, err)
svr.AddRoute(Route{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello World"))
},
})
serverless, err := NewServerless(svr)
assert.NoError(t, err)
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/", nil)
serverless.Serve(w, r)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "Hello World", w.Body.String())
}
func TestNewServerlessWithError(t *testing.T) {
logtest.Discard(t)
const configYaml = `
Name: foo
Host: localhost
Port: 0
`
var cnf RestConf
assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
svr, err := NewServer(cnf)
assert.NoError(t, err)
svr.AddRoute(Route{
Method: http.MethodGet,
Path: "notstartwith/",
Handler: nil,
})
_, err = NewServerless(svr)
assert.Error(t, err)
}

View File

@@ -31,10 +31,11 @@ type (
}
featuredRoutes struct {
timeout time.Duration
timeout *time.Duration
priority bool
jwt jwtSetting
signature signatureSetting
sse bool
routes []Route
maxBytes int64
}

1
tools/goctl/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
dist

View File

@@ -3,7 +3,8 @@ FROM golang:alpine AS builder
LABEL stage=gobuilder
ENV CGO_ENABLED=0
ENV GOPROXY=https://goproxy.cn,direct
# if you are in China, you can use the following command to speed up the download
# ENV GOPROXY=https://goproxy.cn,direct
RUN apk update --no-cache && apk add --no-cache tzdata
RUN go install google.golang.org/protobuf/cmd/protoc-gen-go@latest

View File

@@ -2,6 +2,7 @@ package api
import (
"github.com/spf13/cobra"
"github.com/zeromicro/go-zero/tools/goctl/api/apigen"
"github.com/zeromicro/go-zero/tools/goctl/api/dartgen"
"github.com/zeromicro/go-zero/tools/goctl/api/docgen"
@@ -10,6 +11,7 @@ import (
"github.com/zeromicro/go-zero/tools/goctl/api/javagen"
"github.com/zeromicro/go-zero/tools/goctl/api/ktgen"
"github.com/zeromicro/go-zero/tools/goctl/api/new"
"github.com/zeromicro/go-zero/tools/goctl/api/swagger"
"github.com/zeromicro/go-zero/tools/goctl/api/tsgen"
"github.com/zeromicro/go-zero/tools/goctl/api/validate"
"github.com/zeromicro/go-zero/tools/goctl/config"
@@ -31,6 +33,7 @@ var (
ktCmd = cobrax.NewCommand("kt", cobrax.WithRunE(ktgen.KtCommand))
pluginCmd = cobrax.NewCommand("plugin", cobrax.WithRunE(plugin.PluginCommand))
tsCmd = cobrax.NewCommand("ts", cobrax.WithRunE(tsgen.TsCommand))
swaggerCmd = cobrax.NewCommand("swagger", cobrax.WithRunE(swagger.Command))
)
func init() {
@@ -46,6 +49,7 @@ func init() {
pluginCmdFlags = pluginCmd.Flags()
tsCmdFlags = tsCmd.Flags()
validateCmdFlags = validateCmd.Flags()
swaggerCmdFlags = swaggerCmd.Flags()
)
apiCmdFlags.StringVar(&apigen.VarStringOutput, "o")
@@ -73,6 +77,7 @@ func init() {
goCmdFlags.StringVar(&gogen.VarStringRemote, "remote")
goCmdFlags.StringVar(&gogen.VarStringBranch, "branch")
goCmdFlags.BoolVar(&gogen.VarBoolWithTest, "test")
goCmdFlags.BoolVar(&gogen.VarBoolTypeGroup, "type-group")
goCmdFlags.StringVarWithDefaultValue(&gogen.VarStringStyle, "style", config.DefaultFormat)
javaCmdFlags.StringVar(&javagen.VarStringDir, "dir")
@@ -97,8 +102,13 @@ func init() {
tsCmdFlags.StringVar(&tsgen.VarStringCaller, "caller")
tsCmdFlags.BoolVar(&tsgen.VarBoolUnWrap, "unwrap")
swaggerCmdFlags.StringVar(&swagger.VarStringAPI, "api")
swaggerCmdFlags.StringVar(&swagger.VarStringDir, "dir")
swaggerCmdFlags.StringVar(&swagger.VarStringFilename, "filename")
swaggerCmdFlags.BoolVar(&swagger.VarBoolYaml, "yaml")
validateCmdFlags.StringVar(&validate.VarStringAPI, "api")
// Add sub-commands
Cmd.AddCommand(dartCmd, docCmd, formatCmd, goCmd, javaCmd, ktCmd, newCmd, pluginCmd, tsCmd, validateCmd)
Cmd.AddCommand(dartCmd, docCmd, formatCmd, goCmd, javaCmd, ktCmd, newCmd, pluginCmd, tsCmd, validateCmd, swaggerCmd)
}

View File

@@ -10,10 +10,10 @@ const (
import 'package:shared_preferences/shared_preferences.dart';
import '../data/tokens.dart';
/// 保存tokens到本地
/// store tokens to local
///
/// 传入null则删除本地tokens
/// 返回true设置成功 false:设置失败
/// pass null will clean local stored tokens
/// returns true if success, otherwise false
Future<bool> setTokens(Tokens tokens) async {
var sp = await SharedPreferences.getInstance();
if (tokens == null) {
@@ -23,9 +23,9 @@ Future<bool> setTokens(Tokens tokens) async {
return await sp.setString('tokens', jsonEncode(tokens.toJson()));
}
/// 获取本地存储的tokens
/// get local stored tokens
///
/// 如果没有,则返回null
/// if no, returns null
Future<Tokens> getTokens() async {
try {
var sp = await SharedPreferences.getInstance();
@@ -82,7 +82,8 @@ func genVars(dir string, isLegacy bool, scheme string, hostname string) error {
}
if !fileExists(dir + "vars.dart") {
err = os.WriteFile(dir+"vars.dart", []byte(fmt.Sprintf(`const serverHost='%s://%s';`, scheme, hostname)), 0o644)
err = os.WriteFile(dir+"vars.dart", []byte(fmt.Sprintf(`const serverHost='%s://%s';`,
scheme, hostname)), 0o644)
if err != nil {
return err
}

View File

@@ -42,8 +42,19 @@ var (
func GoFormatApi(_ *cobra.Command, _ []string) error {
var be errorx.BatchError
if VarBoolUseStdin {
if err := apiFormatReader(os.Stdin, VarStringDir, VarBoolSkipCheckDeclare); err != nil {
be.Add(err)
if env.UseExperimental() {
data, err := io.ReadAll(os.Stdin)
if err != nil {
be.Add(err)
} else {
if err := apiF.Source(data, os.Stdout); err != nil {
be.Add(err)
}
}
} else {
if err := apiFormatReader(os.Stdin, VarStringDir, VarBoolSkipCheckDeclare); err != nil {
be.Add(err)
}
}
} else {
if len(VarStringDir) == 0 {

View File

@@ -40,6 +40,8 @@ var (
// VarStringStyle describes the style of output files.
VarStringStyle string
VarBoolWithTest bool
// VarBoolTypeGroup describes whether to group types.
VarBoolTypeGroup bool
)
// GoCommand gen go project files from command line

View File

@@ -6,8 +6,10 @@ import (
"io"
"os"
"path"
"sort"
"strings"
"github.com/zeromicro/go-zero/core/collection"
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
apiutil "github.com/zeromicro/go-zero/tools/goctl/api/util"
"github.com/zeromicro/go-zero/tools/goctl/config"
@@ -39,20 +41,152 @@ func BuildTypes(types []spec.Type) (string, error) {
return builder.String(), nil
}
func genTypes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
val, err := BuildTypes(api.Types)
func getTypeName(tp spec.Type) string {
if tp == nil {
return ""
}
switch val := tp.(type) {
case spec.DefineStruct:
typeName := util.Title(tp.Name())
return typeName
case spec.PointerType:
return getTypeName(val.Type)
case spec.ArrayType:
return getTypeName(val.Value)
}
return ""
}
func genTypesWithGroup(dir string, cfg *config.Config, api *spec.ApiSpec) error {
groupTypes := make(map[string]map[string]spec.Type)
typesBelongToFiles := make(map[string]*collection.Set)
for _, v := range api.Service.Groups {
group := v.GetAnnotation(groupProperty)
if len(group) == 0 {
group = groupTypeDefault
}
// convert filepath to Identifier name spec.
group = strings.TrimPrefix(group, "/")
group = strings.TrimSuffix(group, "/")
group = util.SafeString(group)
for _, v := range v.Routes {
requestTypeName := getTypeName(v.RequestType)
responseTypeName := getTypeName(v.ResponseType)
requestTypeFileSet, ok := typesBelongToFiles[requestTypeName]
if !ok {
requestTypeFileSet = collection.NewSet()
}
if len(requestTypeName) > 0 {
requestTypeFileSet.AddStr(group)
typesBelongToFiles[requestTypeName] = requestTypeFileSet
}
responseTypeFileSet, ok := typesBelongToFiles[responseTypeName]
if !ok {
responseTypeFileSet = collection.NewSet()
}
if len(responseTypeName) > 0 {
responseTypeFileSet.AddStr(group)
typesBelongToFiles[responseTypeName] = responseTypeFileSet
}
}
}
typesInOneFile := make(map[string]*collection.Set)
for typeName, fileSet := range typesBelongToFiles {
count := fileSet.Count()
switch {
case count == 0: // it means there has no structure type or no request/response body
continue
case count == 1: // it means a structure type used in only one group.
groupName := fileSet.KeysStr()[0]
typeSet, ok := typesInOneFile[groupName]
if !ok {
typeSet = collection.NewSet()
}
typeSet.AddStr(typeName)
typesInOneFile[groupName] = typeSet
default: // it means this type is used in multiple groups.
continue
}
}
for _, v := range api.Types {
typeName := util.Title(v.Name())
groupSet, ok := typesBelongToFiles[typeName]
var typeCount int
if !ok {
typeCount = 0
} else {
typeCount = groupSet.Count()
}
if typeCount == 0 { // not belong to any group
types, ok := groupTypes[groupTypeDefault]
if !ok {
types = make(map[string]spec.Type)
}
types[typeName] = v
groupTypes[groupTypeDefault] = types
continue
}
if typeCount == 1 { // belong to one group
groupName := groupSet.KeysStr()[0]
types, ok := groupTypes[groupName]
if !ok {
types = make(map[string]spec.Type)
}
types[typeName] = v
groupTypes[groupName] = types
continue
}
// belong to multiple groups
types, ok := groupTypes[groupTypeDefault]
if !ok {
types = make(map[string]spec.Type)
}
types[typeName] = v
groupTypes[groupTypeDefault] = types
}
for group, typeGroup := range groupTypes {
var types []spec.Type
for _, v := range typeGroup {
types = append(types, v)
}
sort.Slice(types, func(i, j int) bool {
return types[i].Name() < types[j].Name()
})
if err := writeTypes(dir, group, cfg, types); err != nil {
return err
}
}
return nil
}
func writeTypes(dir, baseFilename string, cfg *config.Config, types []spec.Type) error {
if len(types) == 0 {
return nil
}
val, err := BuildTypes(types)
if err != nil {
return err
}
typeFilename, err := format.FileNamingFormat(cfg.NamingFormat, typesFile)
typeFilename, err := format.FileNamingFormat(cfg.NamingFormat, baseFilename)
if err != nil {
return err
}
typeFilename = typeFilename + ".go"
filename := path.Join(dir, typesDir, typeFilename)
os.Remove(filename)
_ = os.Remove(filename)
return genFile(fileGenConfig{
dir: dir,
@@ -70,6 +204,13 @@ func genTypes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
})
}
func genTypes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
if VarBoolTypeGroup {
return genTypesWithGroup(dir, cfg, api)
}
return writeTypes(dir, typesFile, cfg, api.Types)
}
func writeType(writer io.Writer, tp spec.Type) error {
structType, ok := tp.(spec.DefineStruct)
if !ok {

View File

@@ -10,4 +10,6 @@ const (
middlewareDir = internal + "middleware"
typesDir = internal + typesPacket
groupProperty = "group"
groupTypeDefault="types"
)

View File

@@ -8,10 +8,10 @@ import (
"fmt"
"io"
"path"
"slices"
"strings"
"text/template"
"github.com/zeromicro/go-zero/core/stringx"
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
apiutil "github.com/zeromicro/go-zero/tools/goctl/api/util"
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
@@ -96,13 +96,13 @@ func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type
for _, item := range c.responseTypes {
if item.Name() == defineStruct.Name() {
superClassName = "HttpResponseData"
if !stringx.Contains(c.imports, httpResponseData) {
if !slices.Contains(c.imports, httpResponseData) {
c.imports = append(c.imports, httpResponseData)
}
break
}
}
if superClassName == "HttpData" && !stringx.Contains(c.imports, httpData) {
if superClassName == "HttpData" && !slices.Contains(c.imports, httpData) {
c.imports = append(c.imports, httpData)
}
@@ -266,7 +266,7 @@ func (c *componentsContext) genGetSet(writer io.Writer, indent int) error {
tyString := javaType
decorator := ""
javaPrimitiveType := []string{"int", "long", "boolean", "float", "double", "short"}
if !stringx.Contains(javaPrimitiveType, javaType) {
if !slices.Contains(javaPrimitiveType, javaType) {
if member.IsOptional() || member.IsOmitEmpty() {
decorator = "@Nullable "
} else {

View File

@@ -3,9 +3,9 @@ package spec
import (
"errors"
"path"
"slices"
"strings"
"github.com/zeromicro/go-zero/core/stringx"
"github.com/zeromicro/go-zero/tools/goctl/util"
)
@@ -57,14 +57,14 @@ func (m Member) Tags() []*Tag {
// IsOptional returns true if tag is optional
func (m Member) IsOptional() bool {
if !m.IsBodyMember() {
if !m.IsBodyMember() && !m.IsFormMember() {
return false
}
tag := m.Tags()
for _, item := range tag {
if item.Key == bodyTagKey {
if stringx.Contains(item.Options, "optional") {
if item.Key == bodyTagKey || item.Key == formTagKey {
if slices.Contains(item.Options, "optional") {
return true
}
}
@@ -81,7 +81,7 @@ func (m Member) IsOmitEmpty() bool {
tag := m.Tags()
for _, item := range tag {
if item.Key == bodyTagKey {
if stringx.Contains(item.Options, "omitempty") {
if slices.Contains(item.Options, "omitempty") {
return true
}
}
@@ -93,7 +93,7 @@ func (m Member) IsOmitEmpty() bool {
func (m Member) GetPropertyName() (string, error) {
tags := m.Tags()
for _, tag := range tags {
if stringx.Contains(definedKeys, tag.Key) {
if slices.Contains(definedKeys, tag.Key) {
if tag.Name == "-" {
return util.Untitle(m.Name), nil
}

View File

@@ -21,7 +21,7 @@ type (
// ApiSpec describes an api file
ApiSpec struct {
Info Info // Deprecated: useless expression
Info Info
Syntax ApiSyntax // Deprecated: useless expression
Imports []Import // Deprecated: useless expression
Types []Type
@@ -59,11 +59,11 @@ type (
// Member describes the field of a structure
Member struct {
Name string
// 数据类型字面值,如:string、map[int]string、[]int64、[]*User
// data type, for example, string、map[int]string、[]int64、[]*User
Type Type
Tag string
Comment string
// 成员头顶注释说明
// document for the field
Docs Doc
IsInline bool
}

View File

@@ -49,6 +49,9 @@ func Parse(tag string) (*Tags, error) {
// Get gets tag value by specified key
func (t *Tags) Get(key string) (*Tag, error) {
if t == nil {
return nil, errTagNotExist
}
for _, tag := range t.tags {
if tag.Key == key {
return tag, nil
@@ -60,6 +63,9 @@ func (t *Tags) Get(key string) (*Tag, error) {
// Keys returns all keys in Tags
func (t *Tags) Keys() []string {
if t == nil {
return []string{}
}
var keys []string
for _, tag := range t.tags {
keys = append(keys, tag.Key)
@@ -69,5 +75,8 @@ func (t *Tags) Keys() []string {
// Tags returns all tags in Tags
func (t *Tags) Tags() []*Tag {
if t == nil {
return []*Tag{}
}
return t.tags
}

Some files were not shown because too many files have changed in this diff Show More