mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-13 09:50:00 +08:00
feat(mcp): add opt-in request metadata bridge for tool handlers (#5550)
This commit is contained in:
@@ -3,11 +3,14 @@ package mcp
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/conf"
|
||||
)
|
||||
@@ -391,3 +394,148 @@ func TestAddToolWithCustomServer(t *testing.T) {
|
||||
return nil, nil, nil
|
||||
})
|
||||
}
|
||||
|
||||
func TestRequestMetadataIntegrationSSEToolCall(t *testing.T) {
|
||||
port := getFreePort(t)
|
||||
|
||||
c := McpConf{}
|
||||
c.Host = "127.0.0.1"
|
||||
c.Port = port
|
||||
c.Mcp.Name = "metadata-integration-test"
|
||||
c.Mcp.UseStreamable = false
|
||||
c.Mcp.SseEndpoint = "/sse/:scope"
|
||||
c.Mcp.MessageTimeout = 2 * time.Second
|
||||
c.Mcp.SseTimeout = 2 * time.Second
|
||||
|
||||
server := NewMcpServerWithOptions(c, WithRequestMetadataExtractor(DefaultRequestMetadataExtractor))
|
||||
|
||||
tool := &Tool{
|
||||
Name: "inspect_metadata",
|
||||
Description: "Inspect metadata in handler context",
|
||||
}
|
||||
|
||||
type Args struct{}
|
||||
|
||||
AddTool(server, tool, func(ctx context.Context, req *CallToolRequest, args Args) (*CallToolResult, any, error) {
|
||||
header, ok := HeaderFromContext(ctx, "x-tenant-id")
|
||||
if !ok || header != "tenant-header" {
|
||||
return nil, nil, fmt.Errorf("unexpected header from context: %q", header)
|
||||
}
|
||||
|
||||
query, ok := QueryFromContext(ctx, "tenant")
|
||||
if !ok || query != "tenant-query" {
|
||||
return nil, nil, fmt.Errorf("unexpected query from context: %q", query)
|
||||
}
|
||||
|
||||
scope, ok := PathFromContext(ctx, "scope")
|
||||
if !ok || scope != "prod" {
|
||||
return nil, nil, fmt.Errorf("unexpected path from context: %q", scope)
|
||||
}
|
||||
|
||||
return &CallToolResult{
|
||||
Content: []Content{&TextContent{Text: "metadata-ok"}},
|
||||
}, nil, nil
|
||||
})
|
||||
|
||||
go server.Start()
|
||||
t.Cleanup(server.Stop)
|
||||
|
||||
baseURL := fmt.Sprintf("http://127.0.0.1:%d/sse/prod?tenant=tenant-query", port)
|
||||
waitForServerReady(t, baseURL, 2*time.Second)
|
||||
|
||||
client := sdkmcp.NewClient(&sdkmcp.Implementation{
|
||||
Name: "metadata-client",
|
||||
Version: "1.0.0",
|
||||
}, nil)
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 2 * time.Second,
|
||||
Transport: metadataHeaderRoundTripper{
|
||||
next: http.DefaultTransport,
|
||||
},
|
||||
}
|
||||
|
||||
transport := &sdkmcp.SSEClientTransport{
|
||||
Endpoint: baseURL,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
|
||||
session, err := client.Connect(context.Background(), transport, nil)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = session.Close()
|
||||
})
|
||||
|
||||
res, err := session.CallTool(context.Background(), &sdkmcp.CallToolParams{
|
||||
Name: "inspect_metadata",
|
||||
Arguments: map[string]any{},
|
||||
})
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
if !assert.NotNil(t, res) {
|
||||
return
|
||||
}
|
||||
assert.False(t, res.IsError)
|
||||
}
|
||||
|
||||
type metadataHeaderRoundTripper struct {
|
||||
next http.RoundTripper
|
||||
}
|
||||
|
||||
func (r metadataHeaderRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
next := r.next
|
||||
if next == nil {
|
||||
next = http.DefaultTransport
|
||||
}
|
||||
|
||||
clone := req.Clone(req.Context())
|
||||
clone.Header.Set("X-Tenant-Id", "tenant-header")
|
||||
return next.RoundTrip(clone)
|
||||
}
|
||||
|
||||
func getFreePort(t *testing.T) int {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if !assert.NoError(t, err) {
|
||||
return 0
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
addr, ok := listener.Addr().(*net.TCPAddr)
|
||||
if !assert.True(t, ok) {
|
||||
return 0
|
||||
}
|
||||
|
||||
return addr.Port
|
||||
}
|
||||
|
||||
func waitForServerReady(t *testing.T, endpoint string, timeout time.Duration) {
|
||||
t.Helper()
|
||||
|
||||
client := &http.Client{Timeout: 200 * time.Millisecond}
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
req, err := http.NewRequest(http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to build readiness request: %v", err)
|
||||
}
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err == nil {
|
||||
_ = resp.Body.Close()
|
||||
if resp.StatusCode > 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Fatalf("server did not become ready for %s within %s", endpoint, timeout)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user