mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-07 06:59:59 +08:00
feat(mcp): add opt-in request metadata bridge for tool handlers (#5550)
This commit is contained in:
33
mcp/options.go
Normal file
33
mcp/options.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// RequestMetadataExtractor extracts request metadata for downstream handlers.
|
||||||
|
type RequestMetadataExtractor func(*http.Request) RequestMetadata
|
||||||
|
|
||||||
|
// McpOption customizes MCP server construction.
|
||||||
|
type McpOption interface {
|
||||||
|
apply(*serverOptions)
|
||||||
|
}
|
||||||
|
|
||||||
|
type mcpOptionFunc func(*serverOptions)
|
||||||
|
|
||||||
|
func (f mcpOptionFunc) apply(opts *serverOptions) {
|
||||||
|
f(opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
type serverOptions struct {
|
||||||
|
requestMetadataExtractor RequestMetadataExtractor
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultServerOptions() serverOptions {
|
||||||
|
return serverOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRequestMetadataExtractor installs an extractor that runs for each incoming
|
||||||
|
// MCP HTTP request, and stores the extracted metadata into handler context.
|
||||||
|
func WithRequestMetadataExtractor(extractor RequestMetadataExtractor) McpOption {
|
||||||
|
return mcpOptionFunc(func(opts *serverOptions) {
|
||||||
|
opts.requestMetadataExtractor = extractor
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -15,6 +15,7 @@ This package provides a go-zero integration for the [Model Context Protocol (MCP
|
|||||||
- **CORS Support**: Configurable CORS settings for cross-origin requests
|
- **CORS Support**: Configurable CORS settings for cross-origin requests
|
||||||
- **Type-Safe Tool Handlers**: Generic tool handlers with automatic JSON schema generation
|
- **Type-Safe Tool Handlers**: Generic tool handlers with automatic JSON schema generation
|
||||||
- **Prompts and Resources**: Full support for MCP prompts and resources
|
- **Prompts and Resources**: Full support for MCP prompts and resources
|
||||||
|
- **Request Metadata Bridge**: Optional request metadata extraction into handler context
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
@@ -220,6 +221,35 @@ mcp:
|
|||||||
messageEndpoint: /message
|
messageEndpoint: /message
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Request Metadata Bridge
|
||||||
|
|
||||||
|
For multi-tenant or request-context-aware tools, you can extract selected HTTP request metadata once at the transport boundary and read it from `context.Context` in handlers.
|
||||||
|
|
||||||
|
```go
|
||||||
|
server := mcp.NewMcpServerWithOptions(c,
|
||||||
|
mcp.WithRequestMetadataExtractor(mcp.DefaultRequestMetadataExtractor),
|
||||||
|
)
|
||||||
|
|
||||||
|
handler := func(ctx context.Context, req *mcp.CallToolRequest, args SomeArgs) (*mcp.CallToolResult, any, error) {
|
||||||
|
tenant, _ := mcp.HeaderFromContext(ctx, "X-Tenant-Id")
|
||||||
|
traceID, _ := mcp.QueryFromContext(ctx, "trace")
|
||||||
|
scope, _ := mcp.PathFromContext(ctx, "scope")
|
||||||
|
|
||||||
|
_ = tenant
|
||||||
|
_ = traceID
|
||||||
|
_ = scope
|
||||||
|
|
||||||
|
return &mcp.CallToolResult{}, nil, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Available helpers:
|
||||||
|
|
||||||
|
- `RequestMetadataFromContext(ctx)`
|
||||||
|
- `HeaderFromContext(ctx, key)`
|
||||||
|
- `QueryFromContext(ctx, key)`
|
||||||
|
- `PathFromContext(ctx, key)`
|
||||||
|
|
||||||
## Configuration Options
|
## Configuration Options
|
||||||
|
|
||||||
| Field | Type | Default | Description |
|
| Field | Type | Default | Description |
|
||||||
|
|||||||
150
mcp/request_metadata.go
Normal file
150
mcp/request_metadata.go
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/rest/pathvar"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RequestMetadata carries selected request-scoped values into MCP handlers.
|
||||||
|
type RequestMetadata struct {
|
||||||
|
Headers map[string][]string
|
||||||
|
Query map[string][]string
|
||||||
|
Path map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
type requestMetadataCtxKey struct{}
|
||||||
|
|
||||||
|
// RequestMetadataFromContext returns metadata extracted at the transport boundary.
|
||||||
|
func RequestMetadataFromContext(ctx context.Context) (RequestMetadata, bool) {
|
||||||
|
metadata, ok := requestMetadataFromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return RequestMetadata{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return normalizeRequestMetadata(metadata), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// HeaderFromContext returns the first header value for key.
|
||||||
|
func HeaderFromContext(ctx context.Context, key string) (string, bool) {
|
||||||
|
metadata, ok := requestMetadataFromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
vals := metadata.Headers[http.CanonicalHeaderKey(key)]
|
||||||
|
if len(vals) == 0 {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
return vals[0], true
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryFromContext returns the first query value for key.
|
||||||
|
func QueryFromContext(ctx context.Context, key string) (string, bool) {
|
||||||
|
metadata, ok := requestMetadataFromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
vals := metadata.Query[key]
|
||||||
|
if len(vals) == 0 {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
return vals[0], true
|
||||||
|
}
|
||||||
|
|
||||||
|
// PathFromContext returns the path variable value for key.
|
||||||
|
func PathFromContext(ctx context.Context, key string) (string, bool) {
|
||||||
|
metadata, ok := requestMetadataFromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
val, ok := metadata.Path[key]
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
return val, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestMetadataFromContext(ctx context.Context) (RequestMetadata, bool) {
|
||||||
|
metadata, ok := ctx.Value(requestMetadataCtxKey{}).(RequestMetadata)
|
||||||
|
if !ok {
|
||||||
|
return RequestMetadata{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return metadata, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultRequestMetadataExtractor extracts headers, query values, and path variables.
|
||||||
|
func DefaultRequestMetadataExtractor(r *http.Request) RequestMetadata {
|
||||||
|
metadata := RequestMetadata{
|
||||||
|
Headers: make(map[string][]string, len(r.Header)),
|
||||||
|
Query: make(map[string][]string),
|
||||||
|
Path: clonePathVars(pathvar.Vars(r)),
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, vals := range r.Header {
|
||||||
|
metadata.Headers[http.CanonicalHeaderKey(key)] = append([]string(nil), vals...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.URL != nil {
|
||||||
|
for key, vals := range r.URL.Query() {
|
||||||
|
metadata.Query[key] = append([]string(nil), vals...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeRequestMetadata(metadata RequestMetadata) RequestMetadata {
|
||||||
|
return RequestMetadata{
|
||||||
|
Headers: cloneCanonicalHeaderValues(metadata.Headers),
|
||||||
|
Query: cloneHeaderValues(metadata.Query),
|
||||||
|
Path: clonePathVars(metadata.Path),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneHeaderValues(values map[string][]string) map[string][]string {
|
||||||
|
if len(values) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cloned := make(map[string][]string, len(values))
|
||||||
|
for key, vals := range values {
|
||||||
|
cloned[key] = append([]string(nil), vals...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneCanonicalHeaderValues(values map[string][]string) map[string][]string {
|
||||||
|
if len(values) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cloned := make(map[string][]string, len(values))
|
||||||
|
for key, vals := range values {
|
||||||
|
canonical := http.CanonicalHeaderKey(key)
|
||||||
|
cloned[canonical] = append(cloned[canonical], vals...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
func clonePathVars(values map[string]string) map[string]string {
|
||||||
|
if len(values) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cloned := make(map[string]string, len(values))
|
||||||
|
for key, val := range values {
|
||||||
|
cloned[key] = val
|
||||||
|
}
|
||||||
|
|
||||||
|
return cloned
|
||||||
|
}
|
||||||
185
mcp/request_metadata_test.go
Normal file
185
mcp/request_metadata_test.go
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/rest/pathvar"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDefaultRequestMetadataExtractor(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/sse?tenant=t1&trace=abc", nil)
|
||||||
|
req.Header.Add("X-Tenant-Id", "tenant-from-header")
|
||||||
|
req = pathvar.WithVars(req, map[string]string{"tool": "sum"})
|
||||||
|
|
||||||
|
metadata := DefaultRequestMetadataExtractor(req)
|
||||||
|
header, ok := metadata.Headers["X-Tenant-Id"]
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, []string{"tenant-from-header"}, header)
|
||||||
|
assert.Equal(t, []string{"t1"}, metadata.Query["tenant"])
|
||||||
|
assert.Equal(t, "sum", metadata.Path["tool"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestMetadataContextHelpers(t *testing.T) {
|
||||||
|
ctx := context.WithValue(context.Background(), requestMetadataCtxKey{}, RequestMetadata{
|
||||||
|
Headers: map[string][]string{"X-Trace-Id": {"trace-1"}},
|
||||||
|
Query: map[string][]string{"tenant": {"foo"}},
|
||||||
|
Path: map[string]string{"scope": "prod"},
|
||||||
|
})
|
||||||
|
|
||||||
|
metadata, ok := RequestMetadataFromContext(ctx)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, []string{"trace-1"}, metadata.Headers["X-Trace-Id"])
|
||||||
|
|
||||||
|
header, ok := HeaderFromContext(ctx, "x-trace-id")
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "trace-1", header)
|
||||||
|
|
||||||
|
query, ok := QueryFromContext(ctx, "tenant")
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "foo", query)
|
||||||
|
|
||||||
|
path, ok := PathFromContext(ctx, "scope")
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "prod", path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestMetadataContextHelpersMissingKeys(t *testing.T) {
|
||||||
|
ctx := context.WithValue(context.Background(), requestMetadataCtxKey{}, RequestMetadata{
|
||||||
|
Headers: map[string][]string{"X-Trace-Id": {"trace-1"}},
|
||||||
|
Query: map[string][]string{"tenant": {"foo"}},
|
||||||
|
Path: map[string]string{"scope": "prod"},
|
||||||
|
})
|
||||||
|
|
||||||
|
_, ok := HeaderFromContext(ctx, "x-missing")
|
||||||
|
assert.False(t, ok)
|
||||||
|
|
||||||
|
_, ok = QueryFromContext(ctx, "missing")
|
||||||
|
assert.False(t, ok)
|
||||||
|
|
||||||
|
_, ok = PathFromContext(ctx, "missing")
|
||||||
|
assert.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestMetadataFromContextNotFound(t *testing.T) {
|
||||||
|
_, ok := RequestMetadataFromContext(context.Background())
|
||||||
|
assert.False(t, ok)
|
||||||
|
|
||||||
|
_, ok = HeaderFromContext(context.Background(), "x-test")
|
||||||
|
assert.False(t, ok)
|
||||||
|
|
||||||
|
_, ok = QueryFromContext(context.Background(), "tenant")
|
||||||
|
assert.False(t, ok)
|
||||||
|
|
||||||
|
_, ok = PathFromContext(context.Background(), "tenant")
|
||||||
|
assert.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapRequestMetadata(t *testing.T) {
|
||||||
|
s := &mcpServerImpl{
|
||||||
|
options: serverOptions{
|
||||||
|
requestMetadataExtractor: DefaultRequestMetadataExtractor,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
called := false
|
||||||
|
handler := s.wrapRequestMetadata(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||||
|
called = true
|
||||||
|
header, ok := HeaderFromContext(r.Context(), "x-tenant-id")
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "tenant-1", header)
|
||||||
|
|
||||||
|
query, ok := QueryFromContext(r.Context(), "tenant")
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "q-tenant", query)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/sse?tenant=q-tenant", nil)
|
||||||
|
req.Header.Set("X-Tenant-Id", "tenant-1")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
assert.True(t, called)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapRequestMetadataNoExtractor(t *testing.T) {
|
||||||
|
s := &mcpServerImpl{}
|
||||||
|
|
||||||
|
called := false
|
||||||
|
handler := s.wrapRequestMetadata(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||||
|
called = true
|
||||||
|
_, ok := RequestMetadataFromContext(r.Context())
|
||||||
|
assert.False(t, ok)
|
||||||
|
}))
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/sse", nil))
|
||||||
|
|
||||||
|
assert.True(t, called)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapRequestMetadataCanonicalizesCustomHeaders(t *testing.T) {
|
||||||
|
s := &mcpServerImpl{
|
||||||
|
options: serverOptions{
|
||||||
|
requestMetadataExtractor: func(*http.Request) RequestMetadata {
|
||||||
|
return RequestMetadata{
|
||||||
|
Headers: map[string][]string{
|
||||||
|
"x-tenant-id": {"tenant-lower"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
called := false
|
||||||
|
handler := s.wrapRequestMetadata(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||||
|
called = true
|
||||||
|
header, ok := HeaderFromContext(r.Context(), "X-Tenant-Id")
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "tenant-lower", header)
|
||||||
|
}))
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/sse", nil))
|
||||||
|
|
||||||
|
assert.True(t, called)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestMetadataFromContextReturnsCopy(t *testing.T) {
|
||||||
|
ctx := context.WithValue(context.Background(), requestMetadataCtxKey{}, RequestMetadata{
|
||||||
|
Headers: map[string][]string{"X-Trace-Id": {"trace-1"}},
|
||||||
|
})
|
||||||
|
|
||||||
|
metadata, ok := RequestMetadataFromContext(ctx)
|
||||||
|
assert.True(t, ok)
|
||||||
|
metadata.Headers["X-Trace-Id"][0] = "mutated"
|
||||||
|
metadata.Headers["X-New"] = []string{"new"}
|
||||||
|
|
||||||
|
fresh, ok := RequestMetadataFromContext(ctx)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, []string{"trace-1"}, fresh.Headers["X-Trace-Id"])
|
||||||
|
assert.Nil(t, fresh.Headers["X-New"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestMetadataFromContextWithEmptyAndCanonicalizedHeaders(t *testing.T) {
|
||||||
|
emptyCtx := context.WithValue(context.Background(), requestMetadataCtxKey{}, RequestMetadata{})
|
||||||
|
empty, ok := RequestMetadataFromContext(emptyCtx)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Nil(t, empty.Headers)
|
||||||
|
assert.Nil(t, empty.Query)
|
||||||
|
assert.Nil(t, empty.Path)
|
||||||
|
|
||||||
|
ctx := context.WithValue(context.Background(), requestMetadataCtxKey{}, RequestMetadata{
|
||||||
|
Headers: map[string][]string{
|
||||||
|
"x-tenant-id": {"a"},
|
||||||
|
"X-Tenant-Id": {"b"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
metadata, ok := RequestMetadataFromContext(ctx)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, []string{"a", "b"}, metadata.Headers["X-Tenant-Id"])
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package mcp
|
package mcp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
@@ -20,10 +21,23 @@ type mcpServerImpl struct {
|
|||||||
conf McpConf
|
conf McpConf
|
||||||
httpServer *rest.Server
|
httpServer *rest.Server
|
||||||
mcpServer *sdkmcp.Server
|
mcpServer *sdkmcp.Server
|
||||||
|
options serverOptions
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMcpServer creates a new MCP server using the official SDK
|
// NewMcpServer creates a new MCP server using the official SDK
|
||||||
func NewMcpServer(c McpConf) McpServer {
|
func NewMcpServer(c McpConf) McpServer {
|
||||||
|
return NewMcpServerWithOptions(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMcpServerWithOptions creates a new MCP server with optional customizations.
|
||||||
|
func NewMcpServerWithOptions(c McpConf, opts ...McpOption) McpServer {
|
||||||
|
serverOpts := defaultServerOptions()
|
||||||
|
for _, opt := range opts {
|
||||||
|
if opt != nil {
|
||||||
|
opt.apply(&serverOpts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Create the underlying rest HTTP server
|
// Create the underlying rest HTTP server
|
||||||
var httpServer *rest.Server
|
var httpServer *rest.Server
|
||||||
if len(c.Mcp.Cors) == 0 {
|
if len(c.Mcp.Cors) == 0 {
|
||||||
@@ -52,6 +66,7 @@ func NewMcpServer(c McpConf) McpServer {
|
|||||||
conf: c,
|
conf: c,
|
||||||
httpServer: httpServer,
|
httpServer: httpServer,
|
||||||
mcpServer: mcpServer,
|
mcpServer: mcpServer,
|
||||||
|
options: serverOpts,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Choose transport based on configuration
|
// Choose transport based on configuration
|
||||||
@@ -85,7 +100,7 @@ func (s *mcpServerImpl) setupSSETransport() {
|
|||||||
return s.mcpServer
|
return s.mcpServer
|
||||||
}, nil)
|
}, nil)
|
||||||
|
|
||||||
s.registerRoutes(handler, s.conf.Mcp.SseEndpoint)
|
s.registerRoutes(s.wrapRequestMetadata(handler), s.conf.Mcp.SseEndpoint)
|
||||||
}
|
}
|
||||||
|
|
||||||
// setupStreamableTransport configures the server to use Streamable HTTP transport (2025-03-26 spec)
|
// setupStreamableTransport configures the server to use Streamable HTTP transport (2025-03-26 spec)
|
||||||
@@ -96,7 +111,7 @@ func (s *mcpServerImpl) setupStreamableTransport() {
|
|||||||
return s.mcpServer
|
return s.mcpServer
|
||||||
}, nil)
|
}, nil)
|
||||||
|
|
||||||
s.registerRoutes(handler, s.conf.Mcp.MessageEndpoint)
|
s.registerRoutes(s.wrapRequestMetadata(handler), s.conf.Mcp.MessageEndpoint)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mcpServerImpl) registerRoutes(handler http.Handler, endpoint string) {
|
func (s *mcpServerImpl) registerRoutes(handler http.Handler, endpoint string) {
|
||||||
@@ -113,3 +128,16 @@ func (s *mcpServerImpl) registerRoutes(handler http.Handler, endpoint string) {
|
|||||||
Handler: handler.ServeHTTP,
|
Handler: handler.ServeHTTP,
|
||||||
}, rest.WithTimeout(s.conf.Mcp.MessageTimeout))
|
}, rest.WithTimeout(s.conf.Mcp.MessageTimeout))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *mcpServerImpl) wrapRequestMetadata(next http.Handler) http.Handler {
|
||||||
|
extractor := s.options.requestMetadataExtractor
|
||||||
|
if extractor == nil {
|
||||||
|
return next
|
||||||
|
}
|
||||||
|
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
metadata := normalizeRequestMetadata(extractor(r))
|
||||||
|
ctx := context.WithValue(r.Context(), requestMetadataCtxKey{}, metadata)
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,11 +3,14 @@ package mcp
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/zeromicro/go-zero/core/conf"
|
"github.com/zeromicro/go-zero/core/conf"
|
||||||
)
|
)
|
||||||
@@ -391,3 +394,148 @@ func TestAddToolWithCustomServer(t *testing.T) {
|
|||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRequestMetadataIntegrationSSEToolCall(t *testing.T) {
|
||||||
|
port := getFreePort(t)
|
||||||
|
|
||||||
|
c := McpConf{}
|
||||||
|
c.Host = "127.0.0.1"
|
||||||
|
c.Port = port
|
||||||
|
c.Mcp.Name = "metadata-integration-test"
|
||||||
|
c.Mcp.UseStreamable = false
|
||||||
|
c.Mcp.SseEndpoint = "/sse/:scope"
|
||||||
|
c.Mcp.MessageTimeout = 2 * time.Second
|
||||||
|
c.Mcp.SseTimeout = 2 * time.Second
|
||||||
|
|
||||||
|
server := NewMcpServerWithOptions(c, WithRequestMetadataExtractor(DefaultRequestMetadataExtractor))
|
||||||
|
|
||||||
|
tool := &Tool{
|
||||||
|
Name: "inspect_metadata",
|
||||||
|
Description: "Inspect metadata in handler context",
|
||||||
|
}
|
||||||
|
|
||||||
|
type Args struct{}
|
||||||
|
|
||||||
|
AddTool(server, tool, func(ctx context.Context, req *CallToolRequest, args Args) (*CallToolResult, any, error) {
|
||||||
|
header, ok := HeaderFromContext(ctx, "x-tenant-id")
|
||||||
|
if !ok || header != "tenant-header" {
|
||||||
|
return nil, nil, fmt.Errorf("unexpected header from context: %q", header)
|
||||||
|
}
|
||||||
|
|
||||||
|
query, ok := QueryFromContext(ctx, "tenant")
|
||||||
|
if !ok || query != "tenant-query" {
|
||||||
|
return nil, nil, fmt.Errorf("unexpected query from context: %q", query)
|
||||||
|
}
|
||||||
|
|
||||||
|
scope, ok := PathFromContext(ctx, "scope")
|
||||||
|
if !ok || scope != "prod" {
|
||||||
|
return nil, nil, fmt.Errorf("unexpected path from context: %q", scope)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &CallToolResult{
|
||||||
|
Content: []Content{&TextContent{Text: "metadata-ok"}},
|
||||||
|
}, nil, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
go server.Start()
|
||||||
|
t.Cleanup(server.Stop)
|
||||||
|
|
||||||
|
baseURL := fmt.Sprintf("http://127.0.0.1:%d/sse/prod?tenant=tenant-query", port)
|
||||||
|
waitForServerReady(t, baseURL, 2*time.Second)
|
||||||
|
|
||||||
|
client := sdkmcp.NewClient(&sdkmcp.Implementation{
|
||||||
|
Name: "metadata-client",
|
||||||
|
Version: "1.0.0",
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Timeout: 2 * time.Second,
|
||||||
|
Transport: metadataHeaderRoundTripper{
|
||||||
|
next: http.DefaultTransport,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
transport := &sdkmcp.SSEClientTransport{
|
||||||
|
Endpoint: baseURL,
|
||||||
|
HTTPClient: httpClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := client.Connect(context.Background(), transport, nil)
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = session.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
res, err := session.CallTool(context.Background(), &sdkmcp.CallToolParams{
|
||||||
|
Name: "inspect_metadata",
|
||||||
|
Arguments: map[string]any{},
|
||||||
|
})
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !assert.NotNil(t, res) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
assert.False(t, res.IsError)
|
||||||
|
}
|
||||||
|
|
||||||
|
type metadataHeaderRoundTripper struct {
|
||||||
|
next http.RoundTripper
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r metadataHeaderRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
next := r.next
|
||||||
|
if next == nil {
|
||||||
|
next = http.DefaultTransport
|
||||||
|
}
|
||||||
|
|
||||||
|
clone := req.Clone(req.Context())
|
||||||
|
clone.Header.Set("X-Tenant-Id", "tenant-header")
|
||||||
|
return next.RoundTrip(clone)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getFreePort(t *testing.T) int {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
defer listener.Close()
|
||||||
|
|
||||||
|
addr, ok := listener.Addr().(*net.TCPAddr)
|
||||||
|
if !assert.True(t, ok) {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return addr.Port
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForServerReady(t *testing.T, endpoint string, timeout time.Duration) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
client := &http.Client{Timeout: 200 * time.Millisecond}
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
req, err := http.NewRequest(http.MethodGet, endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to build readiness request: %v", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Accept", "text/event-stream")
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err == nil {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
if resp.StatusCode > 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Fatalf("server did not become ready for %s within %s", endpoint, timeout)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user