From 5b74b9ab7b970dcc43100cf69b15a4543c8c5853 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Sat, 25 Apr 2026 17:11:04 +0800 Subject: [PATCH] feat(mcp): add opt-in request metadata bridge for tool handlers (#5550) --- mcp/options.go | 33 +++++++ mcp/readme.md | 30 ++++++ mcp/request_metadata.go | 150 ++++++++++++++++++++++++++++ mcp/request_metadata_test.go | 185 +++++++++++++++++++++++++++++++++++ mcp/server.go | 32 +++++- mcp/server_test.go | 148 ++++++++++++++++++++++++++++ 6 files changed, 576 insertions(+), 2 deletions(-) create mode 100644 mcp/options.go create mode 100644 mcp/request_metadata.go create mode 100644 mcp/request_metadata_test.go diff --git a/mcp/options.go b/mcp/options.go new file mode 100644 index 000000000..09d70df75 --- /dev/null +++ b/mcp/options.go @@ -0,0 +1,33 @@ +package mcp + +import "net/http" + +// RequestMetadataExtractor extracts request metadata for downstream handlers. +type RequestMetadataExtractor func(*http.Request) RequestMetadata + +// McpOption customizes MCP server construction. +type McpOption interface { + apply(*serverOptions) +} + +type mcpOptionFunc func(*serverOptions) + +func (f mcpOptionFunc) apply(opts *serverOptions) { + f(opts) +} + +type serverOptions struct { + requestMetadataExtractor RequestMetadataExtractor +} + +func defaultServerOptions() serverOptions { + return serverOptions{} +} + +// WithRequestMetadataExtractor installs an extractor that runs for each incoming +// MCP HTTP request, and stores the extracted metadata into handler context. +func WithRequestMetadataExtractor(extractor RequestMetadataExtractor) McpOption { + return mcpOptionFunc(func(opts *serverOptions) { + opts.requestMetadataExtractor = extractor + }) +} diff --git a/mcp/readme.md b/mcp/readme.md index 1c0eec860..0f2664af5 100644 --- a/mcp/readme.md +++ b/mcp/readme.md @@ -15,6 +15,7 @@ This package provides a go-zero integration for the [Model Context Protocol (MCP - **CORS Support**: Configurable CORS settings for cross-origin requests - **Type-Safe Tool Handlers**: Generic tool handlers with automatic JSON schema generation - **Prompts and Resources**: Full support for MCP prompts and resources +- **Request Metadata Bridge**: Optional request metadata extraction into handler context ## Quick Start @@ -220,6 +221,35 @@ mcp: messageEndpoint: /message ``` +## Request Metadata Bridge + +For multi-tenant or request-context-aware tools, you can extract selected HTTP request metadata once at the transport boundary and read it from `context.Context` in handlers. + +```go +server := mcp.NewMcpServerWithOptions(c, + mcp.WithRequestMetadataExtractor(mcp.DefaultRequestMetadataExtractor), +) + +handler := func(ctx context.Context, req *mcp.CallToolRequest, args SomeArgs) (*mcp.CallToolResult, any, error) { + tenant, _ := mcp.HeaderFromContext(ctx, "X-Tenant-Id") + traceID, _ := mcp.QueryFromContext(ctx, "trace") + scope, _ := mcp.PathFromContext(ctx, "scope") + + _ = tenant + _ = traceID + _ = scope + + return &mcp.CallToolResult{}, nil, nil +} +``` + +Available helpers: + +- `RequestMetadataFromContext(ctx)` +- `HeaderFromContext(ctx, key)` +- `QueryFromContext(ctx, key)` +- `PathFromContext(ctx, key)` + ## Configuration Options | Field | Type | Default | Description | diff --git a/mcp/request_metadata.go b/mcp/request_metadata.go new file mode 100644 index 000000000..eeb1a48ca --- /dev/null +++ b/mcp/request_metadata.go @@ -0,0 +1,150 @@ +package mcp + +import ( + "context" + "net/http" + + "github.com/zeromicro/go-zero/rest/pathvar" +) + +// RequestMetadata carries selected request-scoped values into MCP handlers. +type RequestMetadata struct { + Headers map[string][]string + Query map[string][]string + Path map[string]string +} + +type requestMetadataCtxKey struct{} + +// RequestMetadataFromContext returns metadata extracted at the transport boundary. +func RequestMetadataFromContext(ctx context.Context) (RequestMetadata, bool) { + metadata, ok := requestMetadataFromContext(ctx) + if !ok { + return RequestMetadata{}, false + } + + return normalizeRequestMetadata(metadata), true +} + +// HeaderFromContext returns the first header value for key. +func HeaderFromContext(ctx context.Context, key string) (string, bool) { + metadata, ok := requestMetadataFromContext(ctx) + if !ok { + return "", false + } + + vals := metadata.Headers[http.CanonicalHeaderKey(key)] + if len(vals) == 0 { + return "", false + } + + return vals[0], true +} + +// QueryFromContext returns the first query value for key. +func QueryFromContext(ctx context.Context, key string) (string, bool) { + metadata, ok := requestMetadataFromContext(ctx) + if !ok { + return "", false + } + + vals := metadata.Query[key] + if len(vals) == 0 { + return "", false + } + + return vals[0], true +} + +// PathFromContext returns the path variable value for key. +func PathFromContext(ctx context.Context, key string) (string, bool) { + metadata, ok := requestMetadataFromContext(ctx) + if !ok { + return "", false + } + + val, ok := metadata.Path[key] + if !ok { + return "", false + } + + return val, true +} + +func requestMetadataFromContext(ctx context.Context) (RequestMetadata, bool) { + metadata, ok := ctx.Value(requestMetadataCtxKey{}).(RequestMetadata) + if !ok { + return RequestMetadata{}, false + } + + return metadata, true +} + +// DefaultRequestMetadataExtractor extracts headers, query values, and path variables. +func DefaultRequestMetadataExtractor(r *http.Request) RequestMetadata { + metadata := RequestMetadata{ + Headers: make(map[string][]string, len(r.Header)), + Query: make(map[string][]string), + Path: clonePathVars(pathvar.Vars(r)), + } + + for key, vals := range r.Header { + metadata.Headers[http.CanonicalHeaderKey(key)] = append([]string(nil), vals...) + } + + if r.URL != nil { + for key, vals := range r.URL.Query() { + metadata.Query[key] = append([]string(nil), vals...) + } + } + + return metadata +} + +func normalizeRequestMetadata(metadata RequestMetadata) RequestMetadata { + return RequestMetadata{ + Headers: cloneCanonicalHeaderValues(metadata.Headers), + Query: cloneHeaderValues(metadata.Query), + Path: clonePathVars(metadata.Path), + } +} + +func cloneHeaderValues(values map[string][]string) map[string][]string { + if len(values) == 0 { + return nil + } + + cloned := make(map[string][]string, len(values)) + for key, vals := range values { + cloned[key] = append([]string(nil), vals...) + } + + return cloned +} + +func cloneCanonicalHeaderValues(values map[string][]string) map[string][]string { + if len(values) == 0 { + return nil + } + + cloned := make(map[string][]string, len(values)) + for key, vals := range values { + canonical := http.CanonicalHeaderKey(key) + cloned[canonical] = append(cloned[canonical], vals...) + } + + return cloned +} + +func clonePathVars(values map[string]string) map[string]string { + if len(values) == 0 { + return nil + } + + cloned := make(map[string]string, len(values)) + for key, val := range values { + cloned[key] = val + } + + return cloned +} diff --git a/mcp/request_metadata_test.go b/mcp/request_metadata_test.go new file mode 100644 index 000000000..9e2433b48 --- /dev/null +++ b/mcp/request_metadata_test.go @@ -0,0 +1,185 @@ +package mcp + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/zeromicro/go-zero/rest/pathvar" +) + +func TestDefaultRequestMetadataExtractor(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/sse?tenant=t1&trace=abc", nil) + req.Header.Add("X-Tenant-Id", "tenant-from-header") + req = pathvar.WithVars(req, map[string]string{"tool": "sum"}) + + metadata := DefaultRequestMetadataExtractor(req) + header, ok := metadata.Headers["X-Tenant-Id"] + assert.True(t, ok) + assert.Equal(t, []string{"tenant-from-header"}, header) + assert.Equal(t, []string{"t1"}, metadata.Query["tenant"]) + assert.Equal(t, "sum", metadata.Path["tool"]) +} + +func TestRequestMetadataContextHelpers(t *testing.T) { + ctx := context.WithValue(context.Background(), requestMetadataCtxKey{}, RequestMetadata{ + Headers: map[string][]string{"X-Trace-Id": {"trace-1"}}, + Query: map[string][]string{"tenant": {"foo"}}, + Path: map[string]string{"scope": "prod"}, + }) + + metadata, ok := RequestMetadataFromContext(ctx) + assert.True(t, ok) + assert.Equal(t, []string{"trace-1"}, metadata.Headers["X-Trace-Id"]) + + header, ok := HeaderFromContext(ctx, "x-trace-id") + assert.True(t, ok) + assert.Equal(t, "trace-1", header) + + query, ok := QueryFromContext(ctx, "tenant") + assert.True(t, ok) + assert.Equal(t, "foo", query) + + path, ok := PathFromContext(ctx, "scope") + assert.True(t, ok) + assert.Equal(t, "prod", path) +} + +func TestRequestMetadataContextHelpersMissingKeys(t *testing.T) { + ctx := context.WithValue(context.Background(), requestMetadataCtxKey{}, RequestMetadata{ + Headers: map[string][]string{"X-Trace-Id": {"trace-1"}}, + Query: map[string][]string{"tenant": {"foo"}}, + Path: map[string]string{"scope": "prod"}, + }) + + _, ok := HeaderFromContext(ctx, "x-missing") + assert.False(t, ok) + + _, ok = QueryFromContext(ctx, "missing") + assert.False(t, ok) + + _, ok = PathFromContext(ctx, "missing") + assert.False(t, ok) +} + +func TestRequestMetadataFromContextNotFound(t *testing.T) { + _, ok := RequestMetadataFromContext(context.Background()) + assert.False(t, ok) + + _, ok = HeaderFromContext(context.Background(), "x-test") + assert.False(t, ok) + + _, ok = QueryFromContext(context.Background(), "tenant") + assert.False(t, ok) + + _, ok = PathFromContext(context.Background(), "tenant") + assert.False(t, ok) +} + +func TestWrapRequestMetadata(t *testing.T) { + s := &mcpServerImpl{ + options: serverOptions{ + requestMetadataExtractor: DefaultRequestMetadataExtractor, + }, + } + + called := false + handler := s.wrapRequestMetadata(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + called = true + header, ok := HeaderFromContext(r.Context(), "x-tenant-id") + assert.True(t, ok) + assert.Equal(t, "tenant-1", header) + + query, ok := QueryFromContext(r.Context(), "tenant") + assert.True(t, ok) + assert.Equal(t, "q-tenant", query) + })) + + req := httptest.NewRequest(http.MethodGet, "/sse?tenant=q-tenant", nil) + req.Header.Set("X-Tenant-Id", "tenant-1") + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.True(t, called) +} + +func TestWrapRequestMetadataNoExtractor(t *testing.T) { + s := &mcpServerImpl{} + + called := false + handler := s.wrapRequestMetadata(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + called = true + _, ok := RequestMetadataFromContext(r.Context()) + assert.False(t, ok) + })) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/sse", nil)) + + assert.True(t, called) +} + +func TestWrapRequestMetadataCanonicalizesCustomHeaders(t *testing.T) { + s := &mcpServerImpl{ + options: serverOptions{ + requestMetadataExtractor: func(*http.Request) RequestMetadata { + return RequestMetadata{ + Headers: map[string][]string{ + "x-tenant-id": {"tenant-lower"}, + }, + } + }, + }, + } + + called := false + handler := s.wrapRequestMetadata(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + called = true + header, ok := HeaderFromContext(r.Context(), "X-Tenant-Id") + assert.True(t, ok) + assert.Equal(t, "tenant-lower", header) + })) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/sse", nil)) + + assert.True(t, called) +} + +func TestRequestMetadataFromContextReturnsCopy(t *testing.T) { + ctx := context.WithValue(context.Background(), requestMetadataCtxKey{}, RequestMetadata{ + Headers: map[string][]string{"X-Trace-Id": {"trace-1"}}, + }) + + metadata, ok := RequestMetadataFromContext(ctx) + assert.True(t, ok) + metadata.Headers["X-Trace-Id"][0] = "mutated" + metadata.Headers["X-New"] = []string{"new"} + + fresh, ok := RequestMetadataFromContext(ctx) + assert.True(t, ok) + assert.Equal(t, []string{"trace-1"}, fresh.Headers["X-Trace-Id"]) + assert.Nil(t, fresh.Headers["X-New"]) +} + +func TestRequestMetadataFromContextWithEmptyAndCanonicalizedHeaders(t *testing.T) { + emptyCtx := context.WithValue(context.Background(), requestMetadataCtxKey{}, RequestMetadata{}) + empty, ok := RequestMetadataFromContext(emptyCtx) + assert.True(t, ok) + assert.Nil(t, empty.Headers) + assert.Nil(t, empty.Query) + assert.Nil(t, empty.Path) + + ctx := context.WithValue(context.Background(), requestMetadataCtxKey{}, RequestMetadata{ + Headers: map[string][]string{ + "x-tenant-id": {"a"}, + "X-Tenant-Id": {"b"}, + }, + }) + + metadata, ok := RequestMetadataFromContext(ctx) + assert.True(t, ok) + assert.Equal(t, []string{"a", "b"}, metadata.Headers["X-Tenant-Id"]) +} diff --git a/mcp/server.go b/mcp/server.go index fe7e1e8c0..32a10ddc7 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -1,6 +1,7 @@ package mcp import ( + "context" "net/http" sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" @@ -20,10 +21,23 @@ type mcpServerImpl struct { conf McpConf httpServer *rest.Server mcpServer *sdkmcp.Server + options serverOptions } // NewMcpServer creates a new MCP server using the official SDK func NewMcpServer(c McpConf) McpServer { + return NewMcpServerWithOptions(c) +} + +// NewMcpServerWithOptions creates a new MCP server with optional customizations. +func NewMcpServerWithOptions(c McpConf, opts ...McpOption) McpServer { + serverOpts := defaultServerOptions() + for _, opt := range opts { + if opt != nil { + opt.apply(&serverOpts) + } + } + // Create the underlying rest HTTP server var httpServer *rest.Server if len(c.Mcp.Cors) == 0 { @@ -52,6 +66,7 @@ func NewMcpServer(c McpConf) McpServer { conf: c, httpServer: httpServer, mcpServer: mcpServer, + options: serverOpts, } // Choose transport based on configuration @@ -85,7 +100,7 @@ func (s *mcpServerImpl) setupSSETransport() { return s.mcpServer }, nil) - s.registerRoutes(handler, s.conf.Mcp.SseEndpoint) + s.registerRoutes(s.wrapRequestMetadata(handler), s.conf.Mcp.SseEndpoint) } // setupStreamableTransport configures the server to use Streamable HTTP transport (2025-03-26 spec) @@ -96,7 +111,7 @@ func (s *mcpServerImpl) setupStreamableTransport() { return s.mcpServer }, nil) - s.registerRoutes(handler, s.conf.Mcp.MessageEndpoint) + s.registerRoutes(s.wrapRequestMetadata(handler), s.conf.Mcp.MessageEndpoint) } func (s *mcpServerImpl) registerRoutes(handler http.Handler, endpoint string) { @@ -113,3 +128,16 @@ func (s *mcpServerImpl) registerRoutes(handler http.Handler, endpoint string) { Handler: handler.ServeHTTP, }, rest.WithTimeout(s.conf.Mcp.MessageTimeout)) } + +func (s *mcpServerImpl) wrapRequestMetadata(next http.Handler) http.Handler { + extractor := s.options.requestMetadataExtractor + if extractor == nil { + return next + } + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + metadata := normalizeRequestMetadata(extractor(r)) + ctx := context.WithValue(r.Context(), requestMetadataCtxKey{}, metadata) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} diff --git a/mcp/server_test.go b/mcp/server_test.go index 171d07fe6..37870b4c8 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -3,11 +3,14 @@ package mcp import ( "bytes" "context" + "fmt" + "net" "net/http" "net/http/httptest" "testing" "time" + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/stretchr/testify/assert" "github.com/zeromicro/go-zero/core/conf" ) @@ -391,3 +394,148 @@ func TestAddToolWithCustomServer(t *testing.T) { return nil, nil, nil }) } + +func TestRequestMetadataIntegrationSSEToolCall(t *testing.T) { + port := getFreePort(t) + + c := McpConf{} + c.Host = "127.0.0.1" + c.Port = port + c.Mcp.Name = "metadata-integration-test" + c.Mcp.UseStreamable = false + c.Mcp.SseEndpoint = "/sse/:scope" + c.Mcp.MessageTimeout = 2 * time.Second + c.Mcp.SseTimeout = 2 * time.Second + + server := NewMcpServerWithOptions(c, WithRequestMetadataExtractor(DefaultRequestMetadataExtractor)) + + tool := &Tool{ + Name: "inspect_metadata", + Description: "Inspect metadata in handler context", + } + + type Args struct{} + + AddTool(server, tool, func(ctx context.Context, req *CallToolRequest, args Args) (*CallToolResult, any, error) { + header, ok := HeaderFromContext(ctx, "x-tenant-id") + if !ok || header != "tenant-header" { + return nil, nil, fmt.Errorf("unexpected header from context: %q", header) + } + + query, ok := QueryFromContext(ctx, "tenant") + if !ok || query != "tenant-query" { + return nil, nil, fmt.Errorf("unexpected query from context: %q", query) + } + + scope, ok := PathFromContext(ctx, "scope") + if !ok || scope != "prod" { + return nil, nil, fmt.Errorf("unexpected path from context: %q", scope) + } + + return &CallToolResult{ + Content: []Content{&TextContent{Text: "metadata-ok"}}, + }, nil, nil + }) + + go server.Start() + t.Cleanup(server.Stop) + + baseURL := fmt.Sprintf("http://127.0.0.1:%d/sse/prod?tenant=tenant-query", port) + waitForServerReady(t, baseURL, 2*time.Second) + + client := sdkmcp.NewClient(&sdkmcp.Implementation{ + Name: "metadata-client", + Version: "1.0.0", + }, nil) + + httpClient := &http.Client{ + Timeout: 2 * time.Second, + Transport: metadataHeaderRoundTripper{ + next: http.DefaultTransport, + }, + } + + transport := &sdkmcp.SSEClientTransport{ + Endpoint: baseURL, + HTTPClient: httpClient, + } + + session, err := client.Connect(context.Background(), transport, nil) + if !assert.NoError(t, err) { + return + } + t.Cleanup(func() { + _ = session.Close() + }) + + res, err := session.CallTool(context.Background(), &sdkmcp.CallToolParams{ + Name: "inspect_metadata", + Arguments: map[string]any{}, + }) + if !assert.NoError(t, err) { + return + } + + if !assert.NotNil(t, res) { + return + } + assert.False(t, res.IsError) +} + +type metadataHeaderRoundTripper struct { + next http.RoundTripper +} + +func (r metadataHeaderRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + next := r.next + if next == nil { + next = http.DefaultTransport + } + + clone := req.Clone(req.Context()) + clone.Header.Set("X-Tenant-Id", "tenant-header") + return next.RoundTrip(clone) +} + +func getFreePort(t *testing.T) int { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if !assert.NoError(t, err) { + return 0 + } + defer listener.Close() + + addr, ok := listener.Addr().(*net.TCPAddr) + if !assert.True(t, ok) { + return 0 + } + + return addr.Port +} + +func waitForServerReady(t *testing.T, endpoint string, timeout time.Duration) { + t.Helper() + + client := &http.Client{Timeout: 200 * time.Millisecond} + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + req, err := http.NewRequest(http.MethodGet, endpoint, nil) + if err != nil { + t.Fatalf("failed to build readiness request: %v", err) + } + req.Header.Set("Accept", "text/event-stream") + + resp, err := client.Do(req) + if err == nil { + _ = resp.Body.Close() + if resp.StatusCode > 0 { + return + } + } + + time.Sleep(20 * time.Millisecond) + } + + t.Fatalf("server did not become ready for %s within %s", endpoint, timeout) +}