From 8e7e5695eb2095917864b5d6615dab4c90bde3ac Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Fri, 26 Dec 2025 00:21:45 +0800 Subject: [PATCH] feat(mcp): migrate to official go-sdk with simplified API (#5362) --- go.mod | 3 + go.sum | 8 + mcp/MIGRATION.md | 166 ++ mcp/config.go | 11 +- mcp/config_test.go | 43 +- mcp/integration_test.go | 443 ----- mcp/parser.go | 23 - mcp/parser_test.go | 139 -- mcp/readme.md | 1012 +++-------- mcp/server.go | 990 +---------- mcp/server_test.go | 3691 ++++----------------------------------- mcp/types.go | 392 +---- mcp/types_test.go | 271 --- mcp/util.go | 107 -- mcp/util_test.go | 274 --- mcp/vars.go | 149 -- mcp/vars_test.go | 210 --- 17 files changed, 864 insertions(+), 7068 deletions(-) create mode 100644 mcp/MIGRATION.md delete mode 100644 mcp/integration_test.go delete mode 100644 mcp/parser.go delete mode 100644 mcp/parser_test.go delete mode 100644 mcp/types_test.go delete mode 100644 mcp/util.go delete mode 100644 mcp/util_test.go delete mode 100644 mcp/vars.go delete mode 100644 mcp/vars_test.go diff --git a/go.mod b/go.mod index 21f591bfc..324fbc606 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/grafana/pyroscope-go v1.2.7 github.com/jackc/pgx/v5 v5.7.4 github.com/jhump/protoreflect v1.17.0 + github.com/modelcontextprotocol/go-sdk v1.2.0 github.com/pelletier/go-toml/v2 v2.2.4 github.com/prometheus/client_golang v1.23.2 github.com/redis/go-redis/v9 v9.17.2 @@ -71,6 +72,7 @@ require ( github.com/google/gnostic-models v0.6.8 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/gofuzz v1.2.0 // indirect + github.com/google/jsonschema-go v0.3.0 // indirect github.com/grafana/pyroscope-go/godeltaprof v0.1.9 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect @@ -98,6 +100,7 @@ require ( github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/scram v1.1.2 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect github.com/yuin/gopher-lua v1.1.1 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.15 // indirect diff --git a/go.sum b/go.sum index 94f3f7671..390a8f2e7 100644 --- a/go.sum +++ b/go.sum @@ -62,6 +62,8 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs= @@ -74,6 +76,8 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= +github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJYCmNdQXq6neHJOYx3V6jnqNEec= github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -123,6 +127,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s= +github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -182,6 +188,8 @@ github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM= github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/mcp/MIGRATION.md b/mcp/MIGRATION.md new file mode 100644 index 000000000..bd860e209 --- /dev/null +++ b/mcp/MIGRATION.md @@ -0,0 +1,166 @@ +# Migration to Official MCP SDK + +This document describes the migration from the custom MCP implementation to the official [go-sdk](https://github.com/modelcontextprotocol/go-sdk). + +## Changes + +### Dependencies + +Added the official MCP SDK: +```bash +go get github.com/modelcontextprotocol/go-sdk@v1.2.0 +``` + +### Type System + +All types are now re-exported from the official SDK: +- `Tool` → `sdkmcp.Tool` +- `CallToolRequest` → `sdkmcp.CallToolRequest` +- `CallToolResult` → `sdkmcp.CallToolResult` +- Content types (`TextContent`, `ImageContent`, etc.) +- `Prompt`, `Resource`, `Server`, `ServerSession` + +### Server Interface + +The `McpServer` interface has been simplified: + +```go +type McpServer interface { + Start() + Stop() + Server() *sdkmcp.Server // Returns underlying SDK server +} +``` + +**Important**: The `AddTool`, `AddPrompt`, and `AddResource` methods have been removed. Use the SDK directly: + +```go +// Old (no longer supported) +server.AddTool(tool, handler) + +// New (use SDK directly) +sdkmcp.AddTool(server.Server(), tool, handler) +``` + +### Configuration + +Updated configuration structure: +- Removed: `ProtocolVersion`, `BaseUrl` (SDK manages these) +- Added: `UseStreamable` (choose between SSE and Streamable HTTP transport) + +```yaml +mcp: + name: my-server + version: 1.0.0 + useStreamable: false # false = SSE (2024-11-05), true = Streamable HTTP (2025-03-26) + sseEndpoint: /sse + messageEndpoint: /message + sseTimeout: 24h + messageTimeout: 30s + cors: + - http://localhost:3000 +``` + +### Tool Registration + +The SDK uses Go generics for type-safe tool registration: + +```go +import sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" + +type MyArgs struct { + Value string `json:"value" jsonschema:"description=Input value"` +} + +tool := &mcp.Tool{ + Name: "my_tool", + Description: "Description", +} + +handler := func(ctx context.Context, req *mcp.CallToolRequest, args MyArgs) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "Result"}, + }, + }, nil, nil +} + +// Register with explicit type parameters +sdkmcp.AddTool(server.Server(), tool, handler) +``` + +The SDK automatically generates JSON schemas from struct tags. + +### Transport Support + +Two transports are supported: + +1. **SSE (Server-Sent Events)**: 2024-11-05 MCP spec + - Default (`UseStreamable: false`) + - Endpoint: `/sse` (configurable) + - Bidirectional: client sends messages to `/message` + +2. **Streamable HTTP**: 2025-03-26 MCP spec + - Opt-in (`UseStreamable: true`) + - Endpoint: `/sse` (configurable) + - Newer protocol with improved streaming + +### Example Migration + +**Before:** +```go +server := mcp.NewMcpServer(c) + +tool := &mcp.Tool{Name: "greet", Description: "Greet"} +handler := func(ctx context.Context, req *mcp.CallToolRequest, args GreetArgs) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: "Hello"}}, + }, nil, nil +} + +if err := server.AddTool(tool, handler); err != nil { + log.Fatal(err) +} +``` + +**After:** +```go +import sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" + +server := mcp.NewMcpServer(c) + +tool := &mcp.Tool{Name: "greet", Description: "Greet"} +handler := func(ctx context.Context, req *mcp.CallToolRequest, args GreetArgs) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: "Hello"}}, + }, nil, nil +} + +// Use SDK directly - no error return +sdkmcp.AddTool(server.Server(), tool, handler) +``` + +## Benefits + +1. **Official SDK**: Uses the official Model Context Protocol SDK +2. **Type Safety**: Go generics provide compile-time type checking +3. **Auto Schema**: JSON schemas generated automatically from struct tags +4. **Dual Transport**: Supports both SSE and Streamable HTTP transports +5. **Maintained**: SDK is actively maintained by the MCP team + +## Breaking Changes + +1. `server.AddTool()` removed → use `sdkmcp.AddTool(server.Server(), ...)` +2. `server.AddPrompt()` removed (SDK v1.2.0 limitation) +3. `server.AddResource()` removed (SDK v1.2.0 limitation) +4. Config fields `ProtocolVersion` and `BaseUrl` removed +5. All types now come from SDK (re-exported for convenience) + +## Migration Checklist + +- [ ] Update imports: add `sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"` +- [ ] Replace `server.AddTool()` with `sdkmcp.AddTool(server.Server(), ...)` +- [ ] Remove error handling for tool registration (SDK doesn't return errors) +- [ ] Update config: remove `ProtocolVersion` and `BaseUrl`, add `UseStreamable` +- [ ] Test with both SSE and Streamable transports +- [ ] Update documentation/examples diff --git a/mcp/config.go b/mcp/config.go index 4c404a14b..b796d9d50 100644 --- a/mcp/config.go +++ b/mcp/config.go @@ -18,17 +18,16 @@ type McpConf struct { // Version is the server version reported in initialize responses Version string `json:",default=1.0.0"` - // ProtocolVersion is the MCP protocol version implemented - ProtocolVersion string `json:",default=2024-11-05"` - - // BaseUrl is the base URL for the server, used in SSE endpoint messages - // If not set, defaults to http://localhost:{Port} - BaseUrl string `json:",optional"` + // UseStreamable when true uses Streamable HTTP transport (2025-03-26 spec), + // otherwise uses SSE transport (2024-11-05 spec) + UseStreamable bool `json:",default=false"` // SseEndpoint is the path for Server-Sent Events connections + // Used for SSE transport mode SseEndpoint string `json:",default=/sse"` // MessageEndpoint is the path for JSON-RPC requests + // Used for Streamable HTTP transport mode MessageEndpoint string `json:",default=/message"` // Cors contains allowed CORS origins diff --git a/mcp/config_test.go b/mcp/config_test.go index 5b9d13da3..fabc755bf 100644 --- a/mcp/config_test.go +++ b/mcp/config_test.go @@ -9,7 +9,7 @@ import ( ) func TestMcpConfDefaults(t *testing.T) { - // Test default values are set correctly when unmarshalled from JSON + // Test default values are set correctly jsonConfig := `name: test-service port: 8080 mcp: @@ -23,41 +23,8 @@ mcp: // Check default values assert.Equal(t, "test-mcp-server", c.Mcp.Name) - assert.Equal(t, "1.0.0", c.Mcp.Version, "Default version should be 1.0.0") - assert.Equal(t, "2024-11-05", c.Mcp.ProtocolVersion, "Default protocol version should be 2024-11-05") - assert.Equal(t, "/sse", c.Mcp.SseEndpoint, "Default SSE endpoint should be /sse") - assert.Equal(t, "/message", c.Mcp.MessageEndpoint, "Default message endpoint should be /message") - assert.Equal(t, 30*time.Second, c.Mcp.MessageTimeout, "Default message timeout should be 30s") -} - -func TestMcpConfCustomValues(t *testing.T) { - // Test custom values can be set - jsonConfig := `{ - "Name": "test-service", - "Port": 8080, - "Mcp": { - "Name": "test-mcp-server", - "Version": "2.0.0", - "ProtocolVersion": "2025-01-01", - "BaseUrl": "http://example.com", - "SseEndpoint": "/custom-sse", - "MessageEndpoint": "/custom-message", - "Cors": ["http://localhost:3000", "http://example.com"], - "MessageTimeout": "60s" - } - }` - - var c McpConf - err := conf.LoadFromJsonBytes([]byte(jsonConfig), &c) - assert.NoError(t, err) - - // Check custom values - assert.Equal(t, "test-mcp-server", c.Mcp.Name, "Name should be inherited from RestConf") - assert.Equal(t, "2.0.0", c.Mcp.Version, "Version should be customizable") - assert.Equal(t, "2025-01-01", c.Mcp.ProtocolVersion, "Protocol version should be customizable") - assert.Equal(t, "http://example.com", c.Mcp.BaseUrl, "BaseUrl should be customizable") - assert.Equal(t, "/custom-sse", c.Mcp.SseEndpoint, "SSE endpoint should be customizable") - assert.Equal(t, "/custom-message", c.Mcp.MessageEndpoint, "Message endpoint should be customizable") - assert.Equal(t, []string{"http://localhost:3000", "http://example.com"}, c.Mcp.Cors, "CORS settings should be customizable") - assert.Equal(t, 60*time.Second, c.Mcp.MessageTimeout, "Tool timeout should be customizable") + assert.Equal(t, "1.0.0", c.Mcp.Version) + assert.Equal(t, "/sse", c.Mcp.SseEndpoint) + assert.Equal(t, "/message", c.Mcp.MessageEndpoint) + assert.Equal(t, 30*time.Second, c.Mcp.MessageTimeout) } diff --git a/mcp/integration_test.go b/mcp/integration_test.go deleted file mode 100644 index ed28b72a3..000000000 --- a/mcp/integration_test.go +++ /dev/null @@ -1,443 +0,0 @@ -package mcp - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// syncResponseRecorder is a thread-safe wrapper around httptest.ResponseRecorder -type syncResponseRecorder struct { - *httptest.ResponseRecorder - mu sync.Mutex -} - -// Create a new synchronized response recorder -func newSyncResponseRecorder() *syncResponseRecorder { - return &syncResponseRecorder{ - ResponseRecorder: httptest.NewRecorder(), - } -} - -// Override Write method to synchronize access -func (srr *syncResponseRecorder) Write(p []byte) (int, error) { - srr.mu.Lock() - defer srr.mu.Unlock() - return srr.ResponseRecorder.Write(p) -} - -// Override WriteHeader method to synchronize access -func (srr *syncResponseRecorder) WriteHeader(statusCode int) { - srr.mu.Lock() - defer srr.mu.Unlock() - srr.ResponseRecorder.WriteHeader(statusCode) -} - -// Override Result method to synchronize access -func (srr *syncResponseRecorder) Result() *http.Response { - srr.mu.Lock() - defer srr.mu.Unlock() - return srr.ResponseRecorder.Result() -} - -// TestHTTPHandlerIntegration tests the HTTP handlers with a real server instance -func TestHTTPHandlerIntegration(t *testing.T) { - // Skip in short test mode - if testing.Short() { - t.Skip("Skipping integration test in short mode") - } - - // Create a test configuration - conf := McpConf{} - conf.Mcp.Name = "test-integration" - conf.Mcp.Version = "1.0.0-test" - conf.Mcp.MessageTimeout = 1 * time.Second - - // Create a mock server directly - server := &sseMcpServer{ - conf: conf, - clients: make(map[string]*mcpClient), - tools: make(map[string]Tool), - prompts: make(map[string]Prompt), - resources: make(map[string]Resource), - } - - // Register a test tool - err := server.RegisterTool(Tool{ - Name: "echo", - Description: "Echo tool for testing", - InputSchema: InputSchema{ - Properties: map[string]any{ - "message": map[string]any{ - "type": "string", - "description": "Message to echo", - }, - }, - }, - Handler: func(ctx context.Context, params map[string]any) (any, error) { - if msg, ok := params["message"].(string); ok { - return fmt.Sprintf("Echo: %s", msg), nil - } - return "Echo: no message provided", nil - }, - }) - require.NoError(t, err) - - // Create a test HTTP request to the SSE endpoint - req := httptest.NewRequest("GET", "/sse", nil) - w := newSyncResponseRecorder() - - // Create a done channel to signal completion of test - done := make(chan bool) - - // Start the SSE handler in a goroutine - go func() { - // lock.Lock() - server.handleSSE(w, req) - // lock.Unlock() - done <- true - }() - - // Allow time for the handler to process - select { - case <-time.After(100 * time.Millisecond): - // Expected - handler would normally block indefinitely - case <-done: - // This shouldn't happen immediately - the handler should block - t.Error("SSE handler returned unexpectedly") - } - - // Check the initial headers - resp := w.Result() - assert.Equal(t, "chunked", resp.Header.Get("Transfer-Encoding")) - resp.Body.Close() - - // The handler creates a client and sends the endpoint message - var sessionId string - - // Give the handler time to set up the client - time.Sleep(50 * time.Millisecond) - - // Check that a client was created - server.clientsLock.Lock() - assert.Equal(t, 1, len(server.clients)) - for id := range server.clients { - sessionId = id - } - server.clientsLock.Unlock() - - require.NotEmpty(t, sessionId, "Expected a session ID to be created") - - // Now that we have a session ID, we can test the message endpoint - messageBody, _ := json.Marshal(Request{ - JsonRpc: "2.0", - ID: 1, - Method: methodInitialize, - Params: json.RawMessage(`{}`), - }) - - // Create a message request - reqURL := fmt.Sprintf("/message?%s=%s", sessionIdKey, sessionId) - msgReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(messageBody)) - msgW := newSyncResponseRecorder() - - // Process the message - server.handleRequest(msgW, msgReq) - - // Check the response - msgResp := msgW.Result() - assert.Equal(t, http.StatusAccepted, msgResp.StatusCode) - msgResp.Body.Close() // Ensure response body is closed -} - -// TestHandlerResponseFlow tests the flow of a full request/response cycle -func TestHandlerResponseFlow(t *testing.T) { - // Create a mock server for testing - server := &sseMcpServer{ - conf: McpConf{}, - clients: map[string]*mcpClient{ - "test-session": { - id: "test-session", - channel: make(chan string, 10), - initialized: true, - }, - }, - tools: make(map[string]Tool), - prompts: make(map[string]Prompt), - resources: make(map[string]Resource), - } - - // Register test resources - server.RegisterTool(Tool{ - Name: "test.tool", - Description: "Test tool", - InputSchema: InputSchema{Type: "object"}, - Handler: func(ctx context.Context, params map[string]any) (any, error) { - return "tool result", nil - }, - }) - - server.RegisterPrompt(Prompt{ - Name: "test.prompt", - Description: "Test prompt", - }) - - server.RegisterResource(Resource{ - Name: "test.resource", - URI: "http://example.com", - Description: "Test resource", - }) - - // Create a request with session ID parameter - reqURL := fmt.Sprintf("/message?%s=%s", sessionIdKey, "test-session") - - // Test tools/list request - toolsListBody, _ := json.Marshal(Request{ - JsonRpc: "2.0", - ID: 1, - Method: methodToolsList, - Params: json.RawMessage(`{}`), - }) - - toolsReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(toolsListBody)) - toolsW := newSyncResponseRecorder() - - // Process the request - server.handleRequest(toolsW, toolsReq) - - // Check the response code - toolsResp := toolsW.Result() - assert.Equal(t, http.StatusAccepted, toolsResp.StatusCode) - toolsResp.Body.Close() - - // Check the channel message - client := server.clients["test-session"] - select { - case message := <-client.channel: - assert.Contains(t, message, `"tools":[{"name":"test.tool"`) - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for tools/list response") - } - - // Test prompts/list request - promptsListBody, _ := json.Marshal(Request{ - JsonRpc: "2.0", - ID: 2, - Method: methodPromptsList, - Params: json.RawMessage(`{}`), - }) - - promptsReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(promptsListBody)) - promptsW := newSyncResponseRecorder() - - // Process the request - server.handleRequest(promptsW, promptsReq) - - // Check the response code - promptsResp := promptsW.Result() - assert.Equal(t, http.StatusAccepted, promptsResp.StatusCode) - promptsResp.Body.Close() - - // Check the channel message - select { - case message := <-client.channel: - assert.Contains(t, message, `"prompts":[{"name":"test.prompt"`) - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for prompts/list response") - } - - // Test resources/list request - resourcesListBody, _ := json.Marshal(Request{ - JsonRpc: "2.0", - ID: 3, - Method: methodResourcesList, - Params: json.RawMessage(`{}`), - }) - - resourcesReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(resourcesListBody)) - resourcesW := newSyncResponseRecorder() - - // Process the request - server.handleRequest(resourcesW, resourcesReq) - - // Check the response code - resourcesResp := resourcesW.Result() - assert.Equal(t, http.StatusAccepted, resourcesResp.StatusCode) - resourcesResp.Body.Close() - - // Check the channel message - select { - case message := <-client.channel: - assert.Contains(t, message, `"name":"test.resource"`) - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for resources/list response") - } -} - -// TestProcessListMethods tests the list processing methods with pagination -func TestProcessListMethods(t *testing.T) { - server := &sseMcpServer{ - tools: make(map[string]Tool), - prompts: make(map[string]Prompt), - resources: make(map[string]Resource), - } - - // Add some test data - for i := 1; i <= 5; i++ { - tool := Tool{ - Name: fmt.Sprintf("tool%d", i), - Description: fmt.Sprintf("Tool %d", i), - InputSchema: InputSchema{Type: "object"}, - } - server.tools[tool.Name] = tool - - prompt := Prompt{ - Name: fmt.Sprintf("prompt%d", i), - Description: fmt.Sprintf("Prompt %d", i), - } - server.prompts[prompt.Name] = prompt - - resource := Resource{ - Name: fmt.Sprintf("resource%d", i), - URI: fmt.Sprintf("http://example.com/%d", i), - Description: fmt.Sprintf("Resource %d", i), - } - server.resources[resource.Name] = resource - } - - // Create a test client - client := &mcpClient{ - id: "test-client", - channel: make(chan string, 10), - initialized: true, - } - - // Test processListTools - req := Request{ - JsonRpc: "2.0", - ID: 1, - Method: methodToolsList, - Params: json.RawMessage(`{"cursor": "", "_meta": {"progressToken": "token1"}}`), - } - - server.processListTools(context.Background(), client, req) - - // Read response - select { - case response := <-client.channel: - assert.Contains(t, response, `"tools":`) - assert.Contains(t, response, `"progressToken":"token1"`) - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for tools/list response") - } - - // Test processListPrompts - req.ID = 2 - req.Method = methodPromptsList - req.Params = json.RawMessage(`{"cursor": "next"}`) - server.processListPrompts(context.Background(), client, req) - - // Read response - select { - case response := <-client.channel: - assert.Contains(t, response, `"prompts":`) - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for prompts/list response") - } - - // Test processListResources - req.ID = 3 - req.Method = methodResourcesList - req.Params = json.RawMessage(`{"cursor": "next"}`) - server.processListResources(context.Background(), client, req) - - // Read response - select { - case response := <-client.channel: - assert.Contains(t, response, `"resources":`) - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for resources/list response") - } -} - -// TestErrorResponseHandling tests error handling in the server -func TestErrorResponseHandling(t *testing.T) { - server := &sseMcpServer{ - tools: make(map[string]Tool), - prompts: make(map[string]Prompt), - resources: make(map[string]Resource), - } - - // Create a test client - client := &mcpClient{ - id: "test-client", - channel: make(chan string, 10), - initialized: true, - } - - // Test invalid method - req := Request{ - JsonRpc: "2.0", - ID: 1, - Method: "invalid_method", - Params: json.RawMessage(`{}`), - } - - // Mock handleRequest by directly calling error handler - server.sendErrorResponse(context.Background(), client, req.ID, "Method not found", errCodeMethodNotFound) - - // Check response - select { - case response := <-client.channel: - assert.Contains(t, response, `"error":{"code":-32601,"message":"Method not found"}`) - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for error response") - } - - // Test invalid tool - toolReq := Request{ - JsonRpc: "2.0", - ID: 2, - Method: methodToolsCall, - Params: json.RawMessage(`{"name":"non_existent_tool"}`), - } - - // Call process method directly - server.processToolCall(context.Background(), client, toolReq) - - // Check response - select { - case response := <-client.channel: - assert.Contains(t, response, `"error":{"code":-32602,"message":"Tool 'non_existent_tool' not found"}`) - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for error response") - } - - // Test invalid prompt - promptReq := Request{ - JsonRpc: "2.0", - ID: 3, - Method: methodPromptsGet, - Params: json.RawMessage(`{"name":"non_existent_prompt"}`), - } - - // Call process method directly - server.processGetPrompt(context.Background(), client, promptReq) - - // Check response - select { - case response := <-client.channel: - assert.Contains(t, response, `"error":{"code":-32602,"message":"Prompt 'non_existent_prompt' not found"}`) - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for error response") - } -} diff --git a/mcp/parser.go b/mcp/parser.go deleted file mode 100644 index 45584ce17..000000000 --- a/mcp/parser.go +++ /dev/null @@ -1,23 +0,0 @@ -package mcp - -import ( - "fmt" - - "github.com/zeromicro/go-zero/core/mapping" -) - -// ParseArguments parses the arguments and populates the request object -func ParseArguments(args any, req any) error { - switch arguments := args.(type) { - case map[string]string: - m := make(map[string]any, len(arguments)) - for k, v := range arguments { - m[k] = v - } - return mapping.UnmarshalJsonMap(m, req, mapping.WithStringValues()) - case map[string]any: - return mapping.UnmarshalJsonMap(arguments, req) - default: - return fmt.Errorf("unsupported argument type: %T", arguments) - } -} diff --git a/mcp/parser_test.go b/mcp/parser_test.go deleted file mode 100644 index 43071a579..000000000 --- a/mcp/parser_test.go +++ /dev/null @@ -1,139 +0,0 @@ -package mcp - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -// TestParseArguments_MapStringString tests parsing map[string]string arguments -func TestParseArguments_MapStringString(t *testing.T) { - // Sample request struct to populate - type requestStruct struct { - Name string `json:"name"` - Message string `json:"message"` - Count int `json:"count"` - Enabled bool `json:"enabled"` - } - - // Create test arguments - args := map[string]string{ - "name": "test-name", - "message": "hello world", - "count": "42", - "enabled": "true", - } - - // Create a target object to populate - var req requestStruct - - // Parse the arguments - err := ParseArguments(args, &req) - - // Verify results - assert.NoError(t, err, "Should parse map[string]string without error") - assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed") - assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed") - assert.Equal(t, 42, req.Count, "Count should be correctly parsed to int") - assert.True(t, req.Enabled, "Enabled should be correctly parsed to bool") -} - -// TestParseArguments_MapStringAny tests parsing map[string]any arguments -func TestParseArguments_MapStringAny(t *testing.T) { - // Sample request struct to populate - type requestStruct struct { - Name string `json:"name"` - Message string `json:"message"` - Count int `json:"count"` - Enabled bool `json:"enabled"` - Tags []string `json:"tags"` - Metadata map[string]string `json:"metadata"` - } - - // Create test arguments with mixed types - args := map[string]any{ - "name": "test-name", - "message": "hello world", - "count": 42, // note: this is already an int - "enabled": true, // note: this is already a bool - "tags": []string{"tag1", "tag2"}, - "metadata": map[string]string{ - "key1": "value1", - "key2": "value2", - }, - } - - // Create a target object to populate - var req requestStruct - - // Parse the arguments - err := ParseArguments(args, &req) - - // Verify results - assert.NoError(t, err, "Should parse map[string]any without error") - assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed") - assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed") - assert.Equal(t, 42, req.Count, "Count should be correctly parsed") - assert.True(t, req.Enabled, "Enabled should be correctly parsed") - assert.Equal(t, []string{"tag1", "tag2"}, req.Tags, "Tags should be correctly parsed") - assert.Equal(t, map[string]string{ - "key1": "value1", - "key2": "value2", - }, req.Metadata, "Metadata should be correctly parsed") -} - -// TestParseArguments_UnsupportedType tests parsing with an unsupported type -func TestParseArguments_UnsupportedType(t *testing.T) { - // Sample request struct to populate - type requestStruct struct { - Name string `json:"name"` - Message string `json:"message"` - } - - // Use an unsupported argument type (slice) - args := []string{"not", "a", "map"} - - // Create a target object to populate - var req requestStruct - - // Parse the arguments - err := ParseArguments(args, &req) - - // Verify error is returned with correct message - assert.Error(t, err, "Should return error for unsupported type") - assert.Contains(t, err.Error(), "unsupported argument type", "Error should mention unsupported type") - assert.Contains(t, err.Error(), "[]string", "Error should include the actual type") -} - -// TestParseArguments_EmptyMap tests parsing with empty maps -func TestParseArguments_EmptyMap(t *testing.T) { - // Sample request struct to populate - type requestStruct struct { - Name string `json:"name,optional"` - Message string `json:"message,optional"` - } - - // Test empty map[string]string - t.Run("EmptyMapStringString", func(t *testing.T) { - args := map[string]string{} - var req requestStruct - - err := ParseArguments(args, &req) - - assert.NoError(t, err, "Should parse empty map[string]string without error") - assert.Empty(t, req.Name, "Name should be empty string") - assert.Empty(t, req.Message, "Message should be empty string") - }) - - // Test empty map[string]any - t.Run("EmptyMapStringAny", func(t *testing.T) { - args := map[string]any{} - var req requestStruct - - err := ParseArguments(args, &req) - - assert.NoError(t, err, "Should parse empty map[string]any without error") - assert.Empty(t, req.Name, "Name should be empty string") - assert.Empty(t, req.Message, "Message should be empty string") - }) -} diff --git a/mcp/readme.md b/mcp/readme.md index c4e9fd09e..1c0eec860 100644 --- a/mcp/readme.md +++ b/mcp/readme.md @@ -1,870 +1,252 @@ # Model Context Protocol (MCP) Implementation ## Overview -This package implements the Model Context Protocol (MCP) server specification in Go, providing a framework for real-time communication between AI models and clients using Server-Sent Events (SSE). The implementation follows the standardized protocol for building AI-assisted applications with bidirectional communication capabilities. -## Core Components +This package provides a go-zero integration for the [Model Context Protocol (MCP)](https://modelcontextprotocol.io/) using the official [go-sdk](https://github.com/modelcontextprotocol/go-sdk). It wraps the official MCP SDK to provide a seamless integration with go-zero's REST server framework. -### Server-Sent Events (SSE) Communication -- **Real-time Communication**: Robust SSE-based communication system that maintains persistent connections with clients -- **Connection Management**: Client registration, message broadcasting, and client cleanup mechanisms -- **Event Handling**: Event types for tools, prompts, and resources changes +## Features -### JSON-RPC Implementation -- **Request Processing**: Complete JSON-RPC request processor for handling MCP protocol methods -- **Response Formatting**: Proper response formatting according to JSON-RPC specifications -- **Error Handling**: Comprehensive error handling with appropriate error codes - -### Tool Management -- **Tool Registration**: System to register custom tools with handlers -- **Tool Execution**: Mechanism to execute tool functions with proper timeout handling -- **Result Handling**: Flexible result handling supporting various return types (string, JSON, images) - -### Prompt System -- **Prompt Registration**: System for registering both static and dynamic prompts -- **Argument Validation**: Validation for required arguments and default values for optional ones -- **Message Generation**: Handlers that generate properly formatted conversation messages - -### Resource Management -- **Resource Registration**: System for managing and accessing external resources -- **Content Delivery**: Handlers for delivering resource content to clients on demand -- **Resource Subscription**: Mechanisms for clients to subscribe to resource updates - -### Protocol Features -- **Initialization Sequence**: Proper handshaking with capability negotiation -- **Notification Handling**: Support for both standard and client-specific notifications -- **Message Routing**: Intelligent routing of requests to appropriate handlers - -## Technical Highlights - -### Configuration System -- **Flexible Configuration**: Configuration system with sensible defaults and customization options +- **Official SDK Integration**: Built on top of the official Model Context Protocol Go SDK +- **No SDK Import Required**: Use `mcp.AddTool()` directly without importing the official SDK +- **go-zero Integration**: Seamlessly integrates with go-zero's REST server and configuration system +- **Dual Transport Support**: + - SSE (Server-Sent Events) transport for 2024-11-05 MCP spec + - Streamable HTTP transport for 2025-03-26 MCP spec - **CORS Support**: Configurable CORS settings for cross-origin requests -- **Server Information**: Proper server identification and versioning +- **Type-Safe Tool Handlers**: Generic tool handlers with automatic JSON schema generation +- **Prompts and Resources**: Full support for MCP prompts and resources -### Client Session Management -- **Session Tracking**: Client session tracking with unique identifiers -- **Connection Health**: Ping/pong mechanism to maintain connection health -- **Initialization State**: Client initialization state tracking +## Quick Start -### Content Handling -- **Multi-format Content**: Support for text, code, and binary content -- **MIME Type Support**: Proper MIME type identification for various content types -- **Audience Annotations**: Content audience annotations for user/assistant targeting +### 1. Installation -## Usage - -### Setting Up an MCP Server - -To create and start an MCP server: - -```go -package main - -import ( - "github.com/zeromicro/go-zero/core/conf" - "github.com/zeromicro/go-zero/core/logx" - "github.com/zeromicro/go-zero/mcp" -) - -func main() { - // Load configuration from YAML file - var c mcp.McpConf - conf.MustLoad("config.yaml", &c) - - // Optional: Disable stats logging - logx.DisableStat() - - // Create MCP server - server := mcp.NewMcpServer(c) - - // Register tools, prompts, and resources (examples below) - - // Start the server and ensure it's stopped on exit - defer server.Stop() - server.Start() -} +```bash +go get github.com/zeromicro/go-zero ``` -Sample configuration file (config.yaml): +**Note**: The official MCP SDK is a transitive dependency and will be installed automatically. You don't need to import it directly in your code. + +### 2. Configuration + +Create a configuration file `config.yaml`: ```yaml -name: mcp-server +name: my-mcp-server host: localhost port: 8080 mcp: name: my-mcp-server - messageTimeout: 30s # Timeout for tool calls + version: 1.0.0 + useStreamable: false # Use SSE transport (default), set to true for Streamable HTTP + sseEndpoint: /sse + messageEndpoint: /message + sseTimeout: 24h + messageTimeout: 30s cors: - - http://localhost:3000 # Optional CORS configuration + - http://localhost:3000 ``` -### Registering Tools - -Tools allow AI models to execute custom code through the MCP protocol. - -#### Basic Tool Example: - -```go -// Register a simple echo tool -echoTool := mcp.Tool{ - Name: "echo", - Description: "Echoes back the message provided by the user", - InputSchema: mcp.InputSchema{ - Properties: map[string]any{ - "message": map[string]any{ - "type": "string", - "description": "The message to echo back", - }, - "prefix": map[string]any{ - "type": "string", - "description": "Optional prefix to add to the echoed message", - "default": "Echo: ", - }, - }, - Required: []string{"message"}, - }, - Handler: func(ctx context.Context, params map[string]any) (any, error) { - var req struct { - Message string `json:"message"` - Prefix string `json:"prefix,optional"` - } - - if err := mcp.ParseArguments(params, &req); err != nil { - return nil, fmt.Errorf("failed to parse params: %w", err) - } - - prefix := "Echo: " - if len(req.Prefix) > 0 { - prefix = req.Prefix - } - - return prefix + req.Message, nil - }, -} - -server.RegisterTool(echoTool) -``` - -#### Tool with Different Response Types: - -```go -// Tool returning JSON data -dataTool := mcp.Tool{ - Name: "data.generate", - Description: "Generates sample data in various formats", - InputSchema: mcp.InputSchema{ - Properties: map[string]any{ - "format": map[string]any{ - "type": "string", - "description": "Format of data (json, text)", - "enum": []string{"json", "text"}, - }, - }, - }, - Handler: func(ctx context.Context, params map[string]any) (any, error) { - var req struct { - Format string `json:"format"` - } - - if err := mcp.ParseArguments(params, &req); err != nil { - return nil, fmt.Errorf("failed to parse params: %w", err) - } - - if req.Format == "json" { - // Return structured data - return map[string]any{ - "items": []map[string]any{ - {"id": 1, "name": "Item 1"}, - {"id": 2, "name": "Item 2"}, - }, - "count": 2, - }, nil - } - - // Default to text - return "Sample text data", nil - }, -} - -server.RegisterTool(dataTool) -``` - -#### Image Generation Tool Example: - -```go -// Tool returning image content -imageTool := mcp.Tool{ - Name: "image.generate", - Description: "Generates a simple image", - InputSchema: mcp.InputSchema{ - Properties: map[string]any{ - "type": map[string]any{ - "type": "string", - "description": "Type of image to generate", - "default": "placeholder", - }, - }, - }, - Handler: func(ctx context.Context, params map[string]any) (any, error) { - // Return image content directly - return mcp.ImageContent{ - Data: "base64EncodedImageData...", // Base64 encoded image data - MimeType: "image/png", - }, nil - }, -} - -server.RegisterTool(imageTool) -``` - -#### Using ToolResult for Custom Outputs: - -```go -// Tool that returns a custom ToolResult type -customResultTool := mcp.Tool{ - Name: "custom.result", - Description: "Returns a custom formatted result", - InputSchema: mcp.InputSchema{ - Properties: map[string]any{ - "resultType": map[string]any{ - "type": "string", - "enum": []string{"text", "image"}, - }, - }, - }, - Handler: func(ctx context.Context, params map[string]any) (any, error) { - var req struct { - ResultType string `json:"resultType"` - } - - if err := mcp.ParseArguments(params, &req); err != nil { - return nil, fmt.Errorf("failed to parse params: %w", err) - } - - if req.ResultType == "image" { - return mcp.ToolResult{ - Type: mcp.ContentTypeImage, - Content: map[string]any{ - "data": "base64EncodedImageData...", - "mimeType": "image/jpeg", - }, - }, nil - } - - // Default to text - return mcp.ToolResult{ - Type: mcp.ContentTypeText, - Content: "This is a text result from ToolResult", - }, nil - }, -} - -server.RegisterTool(customResultTool) -``` - -### Registering Prompts - -Prompts are reusable conversation templates for AI models. - -#### Static Prompt Example: - -```go -// Register a simple static prompt with placeholders -server.RegisterPrompt(mcp.Prompt{ - Name: "hello", - Description: "A simple hello prompt", - Arguments: []mcp.PromptArgument{ - { - Name: "name", - Description: "The name to greet", - Required: false, - }, - }, - Content: "Say hello to {{name}} and introduce yourself as an AI assistant.", -}) -``` - -#### Dynamic Prompt with Handler Function: - -```go -// Register a prompt with a dynamic handler function -server.RegisterPrompt(mcp.Prompt{ - Name: "dynamic-prompt", - Description: "A prompt that uses a handler to generate dynamic content", - Arguments: []mcp.PromptArgument{ - { - Name: "username", - Description: "User's name for personalized greeting", - Required: true, - }, - { - Name: "topic", - Description: "Topic of expertise", - Required: true, - }, - }, - Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) { - var req struct { - Username string `json:"username"` - Topic string `json:"topic"` - } - - if err := mcp.ParseArguments(args, &req); err != nil { - return nil, fmt.Errorf("failed to parse args: %w", err) - } - - // Create a user message - userMessage := mcp.PromptMessage{ - Role: mcp.RoleUser, - Content: mcp.TextContent{ - Text: fmt.Sprintf("Hello, I'm %s and I'd like to learn about %s.", req.Username, req.Topic), - }, - } - - // Create an assistant response with current time - currentTime := time.Now().Format(time.RFC1123) - assistantMessage := mcp.PromptMessage{ - Role: mcp.RoleAssistant, - Content: mcp.TextContent{ - Text: fmt.Sprintf("Hello %s! I'm an AI assistant and I'll help you learn about %s. The current time is %s.", - req.Username, req.Topic, currentTime), - }, - } - - // Return both messages as a conversation - return []mcp.PromptMessage{userMessage, assistantMessage}, nil - }, -}) -``` - -#### Multi-Message Prompt with Code Examples: - -```go -// Register a prompt that provides code examples in different programming languages -server.RegisterPrompt(mcp.Prompt{ - Name: "code-example", - Description: "Provides code examples in different programming languages", - Arguments: []mcp.PromptArgument{ - { - Name: "language", - Description: "Programming language for the example", - Required: true, - }, - { - Name: "complexity", - Description: "Complexity level (simple, medium, advanced)", - }, - }, - Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) { - var req struct { - Language string `json:"language"` - Complexity string `json:"complexity,optional"` - } - - if err := mcp.ParseArguments(args, &req); err != nil { - return nil, fmt.Errorf("failed to parse args: %w", err) - } - - // Validate language - supportedLanguages := map[string]bool{"go": true, "python": true, "javascript": true, "rust": true} - if !supportedLanguages[req.Language] { - return nil, fmt.Errorf("unsupported language: %s", req.Language) - } - - // Generate code example based on language and complexity - var codeExample string - - switch req.Language { - case "go": - if req.Complexity == "simple" { - codeExample = ` -package main - -import "fmt" - -func main() { - fmt.Println("Hello, World!") -}` - } else { - codeExample = ` -package main - -import ( - "fmt" - "time" -) - -func main() { - now := time.Now() - fmt.Printf("Hello, World! Current time is %s\n", now.Format(time.RFC3339)) -}` - } - case "python": - // Python example code - if req.Complexity == "simple" { - codeExample = ` -def greet(name): - return f"Hello, {name}!" - -print(greet("World"))` - } else { - codeExample = ` -import datetime - -def greet(name, include_time=False): - message = f"Hello, {name}!" - if include_time: - message += f" Current time is {datetime.datetime.now().isoformat()}" - return message - -print(greet("World", include_time=True))` - } - } - - // Create messages array according to MCP spec - messages := []mcp.PromptMessage{ - { - Role: mcp.RoleAssistant, - Content: mcp.TextContent{ - Text: fmt.Sprintf("You are a helpful coding assistant specialized in %s programming.", req.Language), - }, - }, - { - Role: mcp.RoleUser, - Content: mcp.TextContent{ - Text: fmt.Sprintf("Show me a %s example of a Hello World program in %s.", req.Complexity, req.Language), - }, - }, - { - Role: mcp.RoleAssistant, - Content: mcp.TextContent{ - Text: fmt.Sprintf("Here's a %s example in %s:\n\n```%s%s\n```\n\nHow can I help you implement this?", - req.Complexity, req.Language, req.Language, codeExample), - }, - }, - } - - return messages, nil - }, -}) -``` - -### Registering Resources - -Resources provide access to external content such as files or generated data. - -#### Basic Resource Example: - -```go -// Register a static resource -server.RegisterResource(mcp.Resource{ - Name: "example-document", - URI: "file:///example/document.txt", - Description: "An example document", - MimeType: "text/plain", - Handler: func(ctx context.Context) (mcp.ResourceContent, error) { - return mcp.ResourceContent{ - URI: "file:///example/document.txt", - MimeType: "text/plain", - Text: "This is an example document content.", - }, nil - }, -}) -``` - -#### Dynamic Resource with Code Example: - -```go -// Register a Go code resource with dynamic handler -server.RegisterResource(mcp.Resource{ - Name: "go-example", - URI: "file:///project/src/main.go", - Description: "A simple Go example with multiple files", - MimeType: "text/x-go", - Handler: func(ctx context.Context) (mcp.ResourceContent, error) { - // Return ResourceContent with all required fields - return mcp.ResourceContent{ - URI: "file:///project/src/main.go", - MimeType: "text/x-go", - Text: "package main\n\nimport (\n\t\"fmt\"\n\t\"./greeting\"\n)\n\nfunc main() {\n\tfmt.Println(greeting.Hello(\"world\"))\n}", - }, nil - }, -}) - -// Register a companion file for the above example -server.RegisterResource(mcp.Resource{ - Name: "go-greeting", - URI: "file:///project/src/greeting/greeting.go", - Description: "A greeting package for the Go example", - MimeType: "text/x-go", - Handler: func(ctx context.Context) (mcp.ResourceContent, error) { - return mcp.ResourceContent{ - URI: "file:///project/src/greeting/greeting.go", - MimeType: "text/x-go", - Text: "package greeting\n\nfunc Hello(name string) string {\n\treturn \"Hello, \" + name + \"!\"\n}", - }, nil - }, -}) -``` - -#### Binary Resource Example: - -```go -// Register a binary resource (like an image) -server.RegisterResource(mcp.Resource{ - Name: "example-image", - URI: "file:///example/image.png", - Description: "An example image", - MimeType: "image/png", - Handler: func(ctx context.Context) (mcp.ResourceContent, error) { - // Read image from file or generate it - imageData := "base64EncodedImageData..." // Base64 encoded image data - - return mcp.ResourceContent{ - URI: "file:///example/image.png", - MimeType: "image/png", - Blob: imageData, // For binary data - }, nil - }, -}) -``` - -### Using Resources in Prompts - -You can embed resources in prompt responses to create rich interactions with proper MCP-compliant structure: - -```go -// Register a prompt that embeds a resource -server.RegisterPrompt(mcp.Prompt{ - Name: "resource-example", - Description: "A prompt that embeds a resource", - Arguments: []mcp.PromptArgument{ - { - Name: "file_type", - Description: "Type of file to show (rust or go)", - Required: true, - }, - }, - Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) { - var req struct { - FileType string `json:"file_type"` - } - - if err := mcp.ParseArguments(args, &req); err != nil { - return nil, fmt.Errorf("failed to parse args: %w", err) - } - - var resourceURI, mimeType, fileContent string - if req.FileType == "rust" { - resourceURI = "file:///project/src/main.rs" - mimeType = "text/x-rust" - fileContent = "fn main() {\n println!(\"Hello world!\");\n}" - } else { - resourceURI = "file:///project/src/main.go" - mimeType = "text/x-go" - fileContent = "package main\n\nimport \"fmt\"\n\nfunc main() {\n\tfmt.Println(\"Hello, world!\")\n}" - } - - // Create message with embedded resource using proper MCP format - return []mcp.PromptMessage{ - { - Role: mcp.RoleUser, - Content: mcp.TextContent{ - Text: fmt.Sprintf("Can you explain this %s code?", req.FileType), - }, - }, - { - Role: mcp.RoleAssistant, - Content: mcp.EmbeddedResource{ - Type: mcp.ContentTypeResource, - Resource: struct { - URI string `json:"uri"` - MimeType string `json:"mimeType"` - Text string `json:"text,omitempty"` - Blob string `json:"blob,omitempty"` - }{ - URI: resourceURI, - MimeType: mimeType, - Text: fileContent, - }, - }, - }, - { - Role: mcp.RoleAssistant, - Content: mcp.TextContent{ - Text: fmt.Sprintf("Above is a simple Hello World example in %s. Let me explain how it works.", req.FileType), - }, - }, - }, nil - }, -}) -``` - -### Multiple File Resources Example - -```go -// Register a prompt that demonstrates embedding multiple resource files -server.RegisterPrompt(mcp.Prompt{ - Name: "go-code-example", - Description: "A prompt that correctly embeds multiple resource files", - Arguments: []mcp.PromptArgument{ - { - Name: "format", - Description: "How to format the code display", - }, - }, - Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) { - var req struct { - Format string `json:"format,optional"` - } - - if err := mcp.ParseArguments(args, &req); err != nil { - return nil, fmt.Errorf("failed to parse args: %w", err) - } - - // Get the Go code for multiple files - var mainGoText string = "package main\n\nimport (\n\t\"fmt\"\n\t\"./greeting\"\n)\n\nfunc main() {\n\tfmt.Println(greeting.Hello(\"world\"))\n}" - var greetingGoText string = "package greeting\n\nfunc Hello(name string) string {\n\treturn \"Hello, \" + name + \"!\"\n}" - - // Create message with properly formatted embedded resource per MCP spec - messages := []mcp.PromptMessage{ - { - Role: mcp.RoleUser, - Content: mcp.TextContent{ - Text: "Show me a simple Go example with proper imports.", - }, - }, - { - Role: mcp.RoleAssistant, - Content: mcp.TextContent{ - Text: "Here's a simple Go example project:", - }, - }, - { - Role: mcp.RoleAssistant, - Content: mcp.EmbeddedResource{ - Type: mcp.ContentTypeResource, - Resource: struct { - URI string `json:"uri"` - MimeType string `json:"mimeType"` - Text string `json:"text,omitempty"` - Blob string `json:"blob,omitempty"` - }{ - URI: "file:///project/src/main.go", - MimeType: "text/x-go", - Text: mainGoText, - }, - }, - }, - } - - // Add explanation and additional file if requested - if req.Format == "with_explanation" { - messages = append(messages, mcp.PromptMessage{ - Role: mcp.RoleAssistant, - Content: mcp.TextContent{ - Text: "This example demonstrates a simple Go application with modular structure. The main.go file imports from a local 'greeting' package that provides the Hello function.", - }, - }) - - // Also show the greeting.go file with correct resource format - messages = append(messages, mcp.PromptMessage{ - Role: mcp.RoleAssistant, - Content: mcp.EmbeddedResource{ - Type: mcp.ContentTypeResource, - Resource: struct { - URI string `json:"uri"` - MimeType string `json:"mimeType"` - Text string `json:"text,omitempty"` - Blob string `json:"blob,omitempty"` - }{ - URI: "file:///project/src/greeting/greeting.go", - MimeType: "text/x-go", - Text: greetingGoText, - }, - }, - }) - } - - return messages, nil - }, -}) -``` - -### Complete Application Example - -Here's a complete example demonstrating all the components: +### 3. Create Your Server ```go package main import ( "context" - "fmt" "log" - "time" "github.com/zeromicro/go-zero/core/conf" - "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/mcp" ) +type GreetArgs struct { + Name string `json:"name" jsonschema:"description=Name of the person to greet"` +} + func main() { // Load configuration var c mcp.McpConf - if err := conf.Load("config.yaml", &c); err != nil { - log.Fatalf("Failed to load config: %v", err) - } - - // Set up logging - logx.DisableStat() + conf.MustLoad("config.yaml", &c) // Create MCP server server := mcp.NewMcpServer(c) - defer server.Stop() - // Register a simple echo tool - echoTool := mcp.Tool{ - Name: "echo", - Description: "Echoes back the message provided by the user", - InputSchema: mcp.InputSchema{ - Properties: map[string]any{ - "message": map[string]any{ - "type": "string", - "description": "The message to echo back", - }, - "prefix": map[string]any{ - "type": "string", - "description": "Optional prefix to add to the echoed message", - "default": "Echo: ", - }, - }, - Required: []string{"message"}, - }, - Handler: func(ctx context.Context, params map[string]any) (any, error) { - var req struct { - Message string `json:"message"` - Prefix string `json:"prefix,optional"` - } - - if err := mcp.ParseArguments(params, &req); err != nil { - return nil, fmt.Errorf("failed to parse args: %w", err) - } - - prefix := "Echo: " - if len(req.Prefix) > 0 { - prefix = req.Prefix - } - - return prefix + req.Message, nil - }, + // Register a tool with automatic schema generation using the SDK directly + tool := &mcp.Tool{ + Name: "greet", + Description: "Greet someone by name", } - server.RegisterTool(echoTool) - // Register a static prompt - server.RegisterPrompt(mcp.Prompt{ - Name: "greeting", - Description: "A simple greeting prompt", - Arguments: []mcp.PromptArgument{ - { - Name: "name", - Description: "The name to greet", - Required: true, + handler := func(ctx context.Context, req *mcp.CallToolRequest, args GreetArgs) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "Hello, " + args.Name + "!"}, }, - }, - Content: "Hello {{name}}! How can I assist you today?", - }) + }, nil, nil + } - // Register a dynamic prompt - server.RegisterPrompt(mcp.Prompt{ - Name: "dynamic-prompt", - Description: "A prompt that uses a handler to generate dynamic content", - Arguments: []mcp.PromptArgument{ - { - Name: "username", - Description: "User's name for personalized greeting", - Required: true, - }, - { - Name: "topic", - Description: "Topic of expertise", - Required: true, - }, - }, - Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) { - var req struct { - Username string `json:"username"` - Topic string `json:"topic"` - } + // Register tool with type-safe generics - no need to import official SDK + mcp.AddTool(server, tool, handler) - if err := mcp.ParseArguments(args, &req); err != nil { - return nil, fmt.Errorf("failed to parse args: %w", err) - } - - // Create messages with current time - currentTime := time.Now().Format(time.RFC1123) - return []mcp.PromptMessage{ - { - Role: mcp.RoleUser, - Content: mcp.TextContent{ - Text: fmt.Sprintf("Hello, I'm %s and I'd like to learn about %s.", req.Username, req.Topic), - }, - }, - { - Role: mcp.RoleAssistant, - Content: mcp.TextContent{ - Text: fmt.Sprintf("Hello %s! I'm an AI assistant and I'll help you learn about %s. The current time is %s.", - req.Username, req.Topic, currentTime), - }, - }, - }, nil - }, - }) - - // Register a resource - server.RegisterResource(mcp.Resource{ - Name: "example-doc", - URI: "file:///example/doc.txt", - Description: "An example document", - MimeType: "text/plain", - Handler: func(ctx context.Context) (mcp.ResourceContent, error) { - return mcp.ResourceContent{ - URI: "file:///example/doc.txt", - MimeType: "text/plain", - Text: "This is the content of the example document.", - }, nil - }, - }) - - // Start the server - fmt.Printf("Starting MCP server on %s:%d\n", c.Host, c.Port) + // Start server + defer server.Stop() server.Start() } ``` -## Error Handling +## Adding Tools -The MCP implementation provides comprehensive error handling: +Tools are functions that the MCP client can call. The SDK automatically generates JSON schemas from your struct tags. Use `sdkmcp.AddTool` with the server's underlying SDK server: -- Tool execution errors are properly reported back to clients -- Missing or invalid parameters are detected and reported with appropriate error codes -- Resource and prompt lookup failures are handled gracefully -- Timeout handling for long-running tool executions using context -- Panic recovery to prevent server crashes +```go +type CalculateArgs struct { + Operation string `json:"operation" jsonschema:"enum=add,enum=subtract,enum=multiply,enum=divide"` + A float64 `json:"a" jsonschema:"description=First number"` + B float64 `json:"b" jsonschema:"description=Second number"` +} -## Advanced Features +tool := &mcp.Tool{ + Name: "calculate", + Description: "Perform arithmetic operations", +} -- **Annotations**: Add audience and priority metadata to content -- **Content Types**: Support for text, images, audio, and other content formats -- **Embedded Resources**: Include file resources directly in prompt responses -- **Context Awareness**: All handlers receive context.Context for timeout and cancellation support -- **Progress Tokens**: Support for tracking progress of long-running operations -- **Customizable Timeouts**: Configure execution timeouts for tools and operations +handler := func(ctx context.Context, req *mcp.CallToolRequest, args CalculateArgs) (*mcp.CallToolResult, any, error) { + var result float64 + switch args.Operation { + case "add": + result = args.A + args.B + case "subtract": + result = args.A - args.B + case "multiply": + result = args.A * args.B + case "divide": + if args.B == 0 { + return &mcp.CallToolResult{IsError: true}, nil, fmt.Errorf("division by zero") + } + result = args.A / args.B + } -## Performance Considerations + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: fmt.Sprintf("Result: %v", result)}, + }, + }, result, nil +} -- Tool execution runs with configurable timeouts to prevent blocking -- Efficient client tracking and cleanup to prevent resource leaks -- Proper concurrency handling with mutex protection for shared resources -- Buffered message channels to prevent blocking on client message delivery +// Register tool +mcp.AddTool(server, tool, handler) +``` + +## Adding Prompts + +Prompts provide reusable message templates: + +```go +prompt := &mcp.Prompt{ + Name: "code-review", + Description: "Review code for best practices", +} + +handler := func(ctx context.Context, req *sdkmcp.GetPromptRequest, args map[string]string) (*mcp.GetPromptResult, error) { + code := args["code"] + language := args["language"] + + return &mcp.GetPromptResult{ + Messages: []mcp.PromptMessage{ + { + Role: "user", + Content: &mcp.TextContent{ + Text: fmt.Sprintf("Please review this %s code:\n\n%s", language, code), + }, + }, + }, + }, nil +} + +server.AddPrompt(prompt, handler) +``` + +## Adding Resources + +Resources provide access to data that the model can read: + +```go +resource := &mcp.Resource{ + URI: "file:///docs/readme.md", + Name: "README", + Description: "Project documentation", + MimeType: "text/markdown", +} + +handler := func(ctx context.Context, req *sdkmcp.ReadResourceRequest, uri string) (*mcp.ReadResourceResult, error) { + content, err := os.ReadFile("README.md") + if err != nil { + return nil, err + } + + return &mcp.ReadResourceResult{ + Contents: []mcp.ResourceContents{ + { + URI: uri, + MimeType: "text/markdown", + Text: string(content), + }, + }, + }, nil +} + +server.AddResource(resource, handler) +``` + +## Transport Options + +### SSE Transport (Default) + +The SSE (Server-Sent Events) transport is the original MCP transport from the 2024-11-05 specification: + +```yaml +mcp: + useStreamable: false + sseEndpoint: /sse +``` + +### Streamable HTTP Transport + +The newer Streamable HTTP transport from the 2025-03-26 specification provides better connection management: + +```yaml +mcp: + useStreamable: true + messageEndpoint: /message +``` + +## Configuration Options + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `name` | string | | Server name (required, from RestConf) | +| `host` | string | | Server host (required, from RestConf) | +| `port` | int | | Server port (required, from RestConf) | +| `mcp.name` | string | | MCP server name (defaults to `name`) | +| `mcp.version` | string | `1.0.0` | Server version | +| `mcp.useStreamable` | bool | `false` | Use Streamable HTTP transport instead of SSE | +| `mcp.sseEndpoint` | string | `/sse` | SSE endpoint path | +| `mcp.messageEndpoint` | string | `/message` | Message endpoint path | +| `mcp.sseTimeout` | duration | `24h` | SSE connection timeout | +| `mcp.messageTimeout` | duration | `30s` | Message processing timeout | +| `mcp.cors` | []string | | Allowed CORS origins | + +## Examples + +See the `adhoc/mcp` directory for a complete working example. + +## Official SDK Documentation + +For more details on the underlying MCP SDK, see: +- [Official Go SDK](https://github.com/modelcontextprotocol/go-sdk) +- [MCP Specification](https://modelcontextprotocol.io/) +- [SDK Documentation](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp) + +## License + +This implementation follows the go-zero project license (MIT). diff --git a/mcp/server.go b/mcp/server.go index 8897b4b91..3f97a87a4 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -1,940 +1,124 @@ package mcp import ( - "context" - "encoding/json" - "fmt" "net/http" - "strings" - "time" - "github.com/google/uuid" + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/rest" ) +// McpServer defines the interface for Model Context Protocol servers using the official SDK +type McpServer interface { + // Start starts the HTTP server + Start() + // Stop stops the HTTP server + Stop() +} + +type mcpServerImpl struct { + conf McpConf + httpServer *rest.Server + mcpServer *sdkmcp.Server +} + +// NewMcpServer creates a new MCP server using the official SDK func NewMcpServer(c McpConf) McpServer { - var server *rest.Server + // Create the underlying rest HTTP server + var httpServer *rest.Server if len(c.Mcp.Cors) == 0 { - server = rest.MustNewServer(c.RestConf) + httpServer = rest.MustNewServer(c.RestConf) } else { - server = rest.MustNewServer(c.RestConf, rest.WithCors(c.Mcp.Cors...)) + httpServer = rest.MustNewServer(c.RestConf, rest.WithCors(c.Mcp.Cors...)) } + // Set defaults if len(c.Mcp.Name) == 0 { c.Mcp.Name = c.Name } - if len(c.Mcp.BaseUrl) == 0 { - c.Mcp.BaseUrl = fmt.Sprintf("http://localhost:%d", c.Port) + if len(c.Mcp.Version) == 0 { + c.Mcp.Version = "1.0.0" } - s := &sseMcpServer{ - conf: c, - server: server, - clients: make(map[string]*mcpClient), - tools: make(map[string]Tool), - prompts: make(map[string]Prompt), - resources: make(map[string]Resource), + // Create the MCP SDK server + impl := &sdkmcp.Implementation{ + Name: c.Mcp.Name, + Version: c.Mcp.Version, } - // SSE endpoint for real-time updates - s.server.AddRoute(rest.Route{ - Method: http.MethodGet, - Path: s.conf.Mcp.SseEndpoint, - Handler: s.handleSSE, - }, rest.WithSSE(), rest.WithTimeout(c.Mcp.SseTimeout)) + mcpServer := sdkmcp.NewServer(impl, nil) - // JSON-RPC message endpoint for regular requests - s.server.AddRoute(rest.Route{ - Method: http.MethodPost, - Path: s.conf.Mcp.MessageEndpoint, - Handler: s.handleRequest, - }, rest.WithTimeout(c.Mcp.MessageTimeout)) + s := &mcpServerImpl{ + conf: c, + httpServer: httpServer, + mcpServer: mcpServer, + } + + // Choose transport based on configuration + if c.Mcp.UseStreamable { + s.setupStreamableTransport() + } else { + s.setupSSETransport() + } return s } -// RegisterPrompt registers a new prompt with the server -func (s *sseMcpServer) RegisterPrompt(prompt Prompt) { - s.promptsLock.Lock() - s.prompts[prompt.Name] = prompt - s.promptsLock.Unlock() - // Notify clients about the new prompt - s.broadcast(eventPromptsListChanged, map[string][]Prompt{keyPrompts: {prompt}}) -} - -// RegisterResource registers a new resource with the server -func (s *sseMcpServer) RegisterResource(resource Resource) { - s.resourcesLock.Lock() - s.resources[resource.URI] = resource - s.resourcesLock.Unlock() - // Notify clients about the new resource - s.broadcast(eventResourcesListChanged, map[string][]Resource{keyResources: {resource}}) -} - -// RegisterTool registers a new tool with the server -func (s *sseMcpServer) RegisterTool(tool Tool) error { - if tool.Handler == nil { - return fmt.Errorf("tool '%s' has no handler function", tool.Name) - } - - s.toolsLock.Lock() - s.tools[tool.Name] = tool - s.toolsLock.Unlock() - // Notify clients about the new tool - s.broadcast(eventToolsListChanged, map[string][]Tool{keyTools: {tool}}) - return nil -} - // Start implements McpServer. -func (s *sseMcpServer) Start() { - s.server.Start() +func (s *mcpServerImpl) Start() { + logx.Infof("Starting MCP server %s v%s on %s:%d", + s.conf.Mcp.Name, s.conf.Mcp.Version, s.conf.Host, s.conf.Port) + s.httpServer.Start() } -func (s *sseMcpServer) Stop() { - s.server.Stop() +// Stop implements McpServer. +func (s *mcpServerImpl) Stop() { + logx.Info("Stopping MCP server") + s.httpServer.Stop() } -// broadcast sends a message to all connected clients -// It uses Server-Sent Events (SSE) format for real-time communication -func (s *sseMcpServer) broadcast(event string, data any) { - jsonData, err := json.Marshal(data) - if err != nil { - logx.Errorf("Failed to marshal broadcast data: %v", err) - return - } +// setupSSETransport configures the server to use SSE transport (2024-11-05 spec) +func (s *mcpServerImpl) setupSSETransport() { + // Create SSE handler that returns our MCP server for each connection + handler := sdkmcp.NewSSEHandler(func(r *http.Request) *sdkmcp.Server { + logx.Infof("New SSE connection from %s", r.RemoteAddr) + return s.mcpServer + }, nil) - // Lock only while reading the clients map - s.clientsLock.Lock() - clients := make([]*mcpClient, 0, len(s.clients)) - for _, client := range s.clients { - clients = append(clients, client) - } - s.clientsLock.Unlock() + // Register the SSE endpoint + s.httpServer.AddRoute(rest.Route{ + Method: http.MethodGet, + Path: s.conf.Mcp.SseEndpoint, + Handler: handler.ServeHTTP, + }, rest.WithSSE(), rest.WithTimeout(s.conf.Mcp.SseTimeout)) - clientCount := len(clients) - if clientCount == 0 { - return - } - - logx.Infof("Broadcasting event '%s' to %d clients", event, clientCount) - - // Use CRLF line endings as per SSE specification - message := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(jsonData)) - - // Send messages without holding the lock - for _, client := range clients { - select { - case client.channel <- message: - // Message sent successfully - default: - // Channel buffer is full, log warning and continue - logx.Errorf("Client channel buffer full, dropping message for client %s", client.id) - } - } + // The SSE handler also handles POST requests to message endpoints + // We need to route those as well + s.httpServer.AddRoute(rest.Route{ + Method: http.MethodPost, + Path: s.conf.Mcp.SseEndpoint, + Handler: handler.ServeHTTP, + }, rest.WithTimeout(s.conf.Mcp.MessageTimeout)) } -// cleanupClient removes a client from the active clients map -func (s *sseMcpServer) cleanupClient(sessionId string) { - s.clientsLock.Lock() - defer s.clientsLock.Unlock() +// setupStreamableTransport configures the server to use Streamable HTTP transport (2025-03-26 spec) +func (s *mcpServerImpl) setupStreamableTransport() { + // Create Streamable HTTP handler + handler := sdkmcp.NewStreamableHTTPHandler(func(r *http.Request) *sdkmcp.Server { + logx.Infof("New streamable connection from %s", r.RemoteAddr) + return s.mcpServer + }, nil) - if client, exists := s.clients[sessionId]; exists { - // Close the channel to signal any goroutines waiting on it - close(client.channel) - // Remove from active clients - delete(s.clients, sessionId) - logx.Infof("Cleaned up client %s (remaining clients: %d)", sessionId, len(s.clients)) - } -} - -// handleRequest handles MCP JSON-RPC requests -func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) { - // Extract sessionId from query parameters - sessionId := r.URL.Query().Get(sessionIdKey) - if len(sessionId) == 0 { - http.Error(w, fmt.Sprintf("Missing %s", sessionIdKey), http.StatusBadRequest) - return - } - - // Check if the client with this sessionId exists - s.clientsLock.Lock() - client, exists := s.clients[sessionId] - s.clientsLock.Unlock() - - if !exists { - http.Error(w, fmt.Sprintf("Invalid or expired %s", sessionIdKey), http.StatusBadRequest) - return - } - - var req Request - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Invalid request", http.StatusBadRequest) - return - } - - // For notification methods (no ID), we don't send a response - isNotification, err := req.isNotification() - if err != nil { - http.Error(w, "Invalid request.ID", http.StatusBadRequest) - } - - w.WriteHeader(http.StatusAccepted) - - // Special handling for initialization sequence - // Always allow initialize and notifications/initialized regardless of client state - if req.Method == methodInitialize { - logx.Infof("Processing initialize request with ID: %v", req.ID) - s.processInitialize(r.Context(), client, req) - logx.Infof("Sent initialize response for ID: %v, waiting for notifications/initialized", req.ID) - return - } else if req.Method == methodNotificationsInitialized { - // Handle initialized notification - logx.Info("Received notifications/initialized notification") - if !isNotification { - s.sendErrorResponse(r.Context(), client, req.ID, - "Method should be used as a notification", errCodeInvalidRequest) - return - } - s.processNotificationInitialized(client) - return - } else if !client.initialized && req.Method != methodNotificationsCancelled { - // Block most requests until client is initialized (except for cancellations) - s.sendErrorResponse(r.Context(), client, req.ID, - "Client not fully initialized, waiting for notifications/initialized", - errCodeClientNotInitialized) - return - } - - // Process normal requests only after initialization - switch req.Method { - case methodToolsCall: - logx.Infof("Received tools call request with ID: %v", req.ID) - s.processToolCall(r.Context(), client, req) - logx.Infof("Sent tools call response for ID: %v", req.ID) - case methodToolsList: - logx.Infof("Processing tools/list request with ID: %v", req.ID) - s.processListTools(r.Context(), client, req) - logx.Infof("Sent tools/list response for ID: %v", req.ID) - case methodPromptsList: - logx.Infof("Processing prompts/list request with ID: %v", req.ID) - s.processListPrompts(r.Context(), client, req) - logx.Infof("Sent prompts/list response for ID: %v", req.ID) - case methodPromptsGet: - logx.Infof("Processing prompts/get request with ID: %v", req.ID) - s.processGetPrompt(r.Context(), client, req) - logx.Infof("Sent prompts/get response for ID: %v", req.ID) - case methodResourcesList: - logx.Infof("Processing resources/list request with ID: %v", req.ID) - s.processListResources(r.Context(), client, req) - logx.Infof("Sent resources/list response for ID: %v", req.ID) - case methodResourcesRead: - logx.Infof("Processing resources/read request with ID: %v", req.ID) - s.processResourcesRead(r.Context(), client, req) - logx.Infof("Sent resources/read response for ID: %v", req.ID) - case methodResourcesSubscribe: - logx.Infof("Processing resources/subscribe request with ID: %v", req.ID) - s.processResourceSubscribe(r.Context(), client, req) - logx.Infof("Sent resources/subscribe response for ID: %v", req.ID) - case methodPing: - logx.Infof("Processing ping request with ID: %v", req.ID) - s.processPing(r.Context(), client, req) - case methodNotificationsCancelled: - logx.Infof("Received notifications/cancelled notification: %v", req.ID) - s.processNotificationCancelled(r.Context(), client, req) - default: - logx.Infof("Unknown method: %s from client: %v", req.Method, req.ID) - s.sendErrorResponse(r.Context(), client, req.ID, "Method not found", errCodeMethodNotFound) - } -} - -// handleSSE handles Server-Sent Events connections -func (s *sseMcpServer) handleSSE(w http.ResponseWriter, r *http.Request) { - // Generate a unique session ID for this client - sessionId := uuid.New().String() - - // Create new client with buffered channel to prevent blocking - client := &mcpClient{ - id: sessionId, - channel: make(chan string, eventChanSize), - } - - // Add client to active clients map - s.clientsLock.Lock() - s.clients[sessionId] = client - activeClients := len(s.clients) - s.clientsLock.Unlock() - - logx.Infof("New SSE connection established for client %s (active clients: %d)", - sessionId, activeClients) - - // Set proper SSE headers - w.Header().Set("Transfer-Encoding", "chunked") - - // Enable streaming - flusher, ok := w.(http.Flusher) - if !ok { - logx.Error("Streaming not supported by the underlying http.ResponseWriter") - http.Error(w, "Streaming not supported", http.StatusInternalServerError) - return - } - - // Send the message endpoint URL to the client - endpoint := fmt.Sprintf("%s%s?%s=%s", - s.conf.Mcp.BaseUrl, s.conf.Mcp.MessageEndpoint, sessionIdKey, sessionId) - - // Format and send the endpoint message - endpointMsg := formatSSEMessage(eventEndpoint, []byte(endpoint)) - if _, err := fmt.Fprint(w, endpointMsg); err != nil { - logx.Errorf("Failed to send endpoint message to client %s: %v", sessionId, err) - s.cleanupClient(sessionId) - return - } - flusher.Flush() - - // Set up keep-alive ping and client cleanup - ticker := time.NewTicker(pingInterval.Load()) - defer func() { - ticker.Stop() - s.cleanupClient(sessionId) - logx.Infof("SSE connection closed for client %s", sessionId) - }() - - // Message processing loop - for { - select { - case message, ok := <-client.channel: - if !ok { - // Channel was closed, end connection - logx.Infof("Client channel was closed for %s", sessionId) - return - } - - // Write message to the response - if _, err := fmt.Fprint(w, message); err != nil { - logx.Infof("Failed to write message to client %s: %v", sessionId, err) - return - } - flusher.Flush() - case <-ticker.C: - // Send keep-alive ping to maintain connection - ping := fmt.Sprintf(`{"type":"ping","timestamp":"%s"}`, time.Now().String()) - pingMsg := formatSSEMessage("ping", []byte(ping)) - if _, err := fmt.Fprint(w, pingMsg); err != nil { - logx.Errorf("Failed to send ping to client %s, closing connection: %v", sessionId, err) - return - } - flusher.Flush() - case <-r.Context().Done(): - // Client disconnected or request was canceled or timed out - logx.Infof("Client %s disconnected: context done", sessionId) - return - } - } -} - -// processInitialize processes the initialize request -func (s *sseMcpServer) processInitialize(ctx context.Context, client *mcpClient, req Request) { - // Create a proper JSON-RPC response that preserves the client's request ID - result := initializationResponse{ - ProtocolVersion: s.conf.Mcp.ProtocolVersion, - Capabilities: capabilities{ - Prompts: struct { - ListChanged bool `json:"listChanged"` - }{ - ListChanged: true, - }, - Resources: struct { - Subscribe bool `json:"subscribe"` - ListChanged bool `json:"listChanged"` - }{ - Subscribe: true, - ListChanged: true, - }, - Tools: struct { - ListChanged bool `json:"listChanged"` - }{ - ListChanged: true, - }, - }, - ServerInfo: serverInfo{ - Name: s.conf.Mcp.Name, - Version: s.conf.Mcp.Version, - }, - } - - // Mark client as initialized - client.initialized = true - - // Send response with client's original request ID - s.sendResponse(ctx, client, req.ID, result) -} - -// processListTools processes the tools/list request -func (s *sseMcpServer) processListTools(ctx context.Context, client *mcpClient, req Request) { - // Extract pagination params if any - var nextCursor string - var progressToken any - - // Extract meta data including progress token - if req.Params != nil { - var metaParams struct { - Cursor string `json:"cursor"` - Meta struct { - ProgressToken any `json:"progressToken"` - } `json:"_meta"` - } - if err := json.Unmarshal(req.Params, &metaParams); err == nil { - if len(metaParams.Cursor) > 0 { - nextCursor = metaParams.Cursor - } - progressToken = metaParams.Meta.ProgressToken - } - } - - s.toolsLock.Lock() - toolsList := make([]Tool, 0, len(s.tools)) - for _, tool := range s.tools { - if len(tool.InputSchema.Type) == 0 { - tool.InputSchema.Type = ContentTypeObject - } - toolsList = append(toolsList, tool) - } - s.toolsLock.Unlock() - - result := ListToolsResult{ - PaginatedResult: PaginatedResult{ - Result: Result{}, - NextCursor: Cursor(nextCursor), - }, - Tools: toolsList, - } - - // Add meta information if progress token was provided - if progressToken != nil { - result.Result.Meta = map[string]any{ - progressTokenKey: progressToken, - } - } - - s.sendResponse(ctx, client, req.ID, result) -} - -// processListPrompts processes the prompts/list request -func (s *sseMcpServer) processListPrompts(ctx context.Context, client *mcpClient, req Request) { - // Extract pagination params if any - var nextCursor string - if req.Params != nil { - var cursorParams struct { - Cursor string `json:"cursor"` - } - if err := json.Unmarshal(req.Params, &cursorParams); err == nil && cursorParams.Cursor != "" { - // If we have a valid cursor, we could use it for pagination - // For now, we're not actually implementing pagination, so this is just - // to show how it would be extracted from the request - _ = cursorParams.Cursor - } - } - - // Prepare prompt list - s.promptsLock.Lock() - promptsList := make([]Prompt, 0, len(s.prompts)) - for _, prompt := range s.prompts { - promptsList = append(promptsList, prompt) - } - s.promptsLock.Unlock() - - // In a real implementation, you'd handle pagination here - // For now, we'll return all prompts at once - result := struct { - Prompts []Prompt `json:"prompts"` - NextCursor string `json:"nextCursor,omitempty"` - Meta *struct{} `json:"_meta,omitempty"` - }{ - Prompts: promptsList, - NextCursor: nextCursor, - } - - s.sendResponse(ctx, client, req.ID, result) -} - -// processListResources processes the resources/list request -func (s *sseMcpServer) processListResources(ctx context.Context, client *mcpClient, req Request) { - // Extract pagination params if any - var nextCursor string - var progressToken any - - // Extract meta information including progress token if available - if req.Params != nil { - var metaParams PaginatedParams - if err := json.Unmarshal(req.Params, &metaParams); err == nil { - if len(metaParams.Cursor) > 0 { - nextCursor = metaParams.Cursor - } - progressToken = metaParams.Meta.ProgressToken - } - } - - s.resourcesLock.Lock() - resourcesList := make([]Resource, 0, len(s.resources)) - for _, resource := range s.resources { - // Create a copy without the handler function which shouldn't be sent to clients - resourceCopy := Resource{ - URI: resource.URI, - Name: resource.Name, - Description: resource.Description, - MimeType: resource.MimeType, - } - resourcesList = append(resourcesList, resourceCopy) - } - s.resourcesLock.Unlock() - - // Create proper ResourcesListResult according to MCP specification - result := ResourcesListResult{ - PaginatedResult: PaginatedResult{ - Result: Result{}, - NextCursor: Cursor(nextCursor), - }, - Resources: resourcesList, - } - - // Add meta information if progress token was provided - if progressToken != nil { - result.Result.Meta = map[string]any{ - progressTokenKey: progressToken, - } - } - - s.sendResponse(ctx, client, req.ID, result) -} - -// processGetPrompt processes the prompts/get request -func (s *sseMcpServer) processGetPrompt(ctx context.Context, client *mcpClient, req Request) { - type GetPromptParams struct { - Name string `json:"name"` - Arguments map[string]string `json:"arguments,omitempty"` - } - - var params GetPromptParams - if err := json.Unmarshal(req.Params, ¶ms); err != nil { - s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams) - return - } - - // Check if prompt exists - s.promptsLock.Lock() - prompt, exists := s.prompts[params.Name] - s.promptsLock.Unlock() - if !exists { - message := fmt.Sprintf("Prompt '%s' not found", params.Name) - s.sendErrorResponse(ctx, client, req.ID, message, errCodeInvalidParams) - return - } - - logx.Infof("Processing prompt request: %s with %d arguments", prompt.Name, len(params.Arguments)) - - // Validate required arguments - missingArgs := validatePromptArguments(prompt, params.Arguments) - if len(missingArgs) > 0 { - message := fmt.Sprintf("Missing required arguments: %s", strings.Join(missingArgs, ", ")) - s.sendErrorResponse(ctx, client, req.ID, message, errCodeInvalidParams) - return - } - - // Ensure arguments are initialized to an empty map if nil - if params.Arguments == nil { - params.Arguments = make(map[string]string) - } - args := params.Arguments - - // Generate messages using handler or static content - var messages []PromptMessage - var err error - - if prompt.Handler != nil { - // Use dynamic handler to generate messages - messages, err = prompt.Handler(ctx, args) - if err != nil { - logx.Errorf("Error from prompt handler: %v", err) - s.sendErrorResponse(ctx, client, req.ID, - fmt.Sprintf("Error generating prompt content: %v", err), errCodeInternalError) - return - } - } else { - // No handler, generate messages from static content - var messageText string - if len(prompt.Content) > 0 { - messageText = prompt.Content - - // Apply argument substitutions to static content - for key, value := range args { - placeholder := fmt.Sprintf("{{%s}}", key) - messageText = strings.Replace(messageText, placeholder, value, -1) - } - } - - // Create a single user message with the content - messages = []PromptMessage{ - { - Role: RoleUser, - Content: TextContent{ - Text: messageText, - }, - }, - } - } - - // Construct the response according to MCP spec - result := struct { - Description string `json:"description,omitempty"` - Messages []PromptMessage `json:"messages"` - }{ - Description: prompt.Description, - Messages: toTypedPromptMessages(messages), - } - - s.sendResponse(ctx, client, req.ID, result) -} - -// processToolCall processes the tools/call request -func (s *sseMcpServer) processToolCall(ctx context.Context, client *mcpClient, req Request) { - var toolCallParams struct { - Name string `json:"name"` - Arguments map[string]any `json:"arguments,omitempty"` - Meta struct { - ProgressToken any `json:"progressToken"` - } `json:"_meta,omitempty"` - } - - // Handle different types of req.Params - // If it's a RawMessage (JSON), unmarshal it - if err := json.Unmarshal(req.Params, &toolCallParams); err != nil { - logx.Errorf("Failed to unmarshal tool call params: %v", err) - s.sendErrorResponse(ctx, client, req.ID, "Invalid tool call parameters", errCodeInvalidParams) - return - } - - // Extract progress token if available - progressToken := toolCallParams.Meta.ProgressToken - - // Find the requested tool - s.toolsLock.Lock() - tool, exists := s.tools[toolCallParams.Name] - s.toolsLock.Unlock() - if !exists { - s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Tool '%s' not found", - toolCallParams.Name), errCodeInvalidParams) - return - } - - // Log parameters before execution - logx.Infof("Executing tool '%s' with arguments: %#v", toolCallParams.Name, toolCallParams.Arguments) - - // Execute the tool handler with timeout handling - var result any - var err error - - // Create a channel to receive the result - // make sure to have 1 size buffer to avoid channel leak if timeout - resultCh := make(chan struct { - result any - err error - }, 1) - - // Execute the tool handler in a goroutine - go func() { - toolResult, toolErr := tool.Handler(ctx, toolCallParams.Arguments) - resultCh <- struct { - result any - err error - }{ - result: toolResult, - err: toolErr, - } - }() - - // Wait for either the result or a timeout - select { - case res := <-resultCh: - result = res.result - err = res.err - case <-ctx.Done(): - // Handle request timeout - logx.Errorf("Tool execution timed out after %v: %s", s.conf.Mcp.MessageTimeout, toolCallParams.Name) - s.sendErrorResponse(ctx, client, req.ID, "Tool execution timed out", errCodeTimeout) - return - } - - // Create the base result structure with metadata - callToolResult := CallToolResult{ - Result: Result{}, - Content: []any{}, - IsError: false, - } - - // Add meta information if progress token was provided - if progressToken != nil { - callToolResult.Result.Meta = map[string]any{ - progressTokenKey: progressToken, - } - } - - // Check if there was an error during tool execution - if err != nil { - // According to the spec, for tool-level errors (as opposed to protocol-level errors), - // we should report them inside the result with isError=true - logx.Errorf("Tool execution reported error: %v", err) - - callToolResult.Content = []any{ - TextContent{ - Text: fmt.Sprintf("Error: %v", err), - }, - } - callToolResult.IsError = true - s.sendResponse(ctx, client, req.ID, callToolResult) - return - } - - // Format the response according to the CallToolResult schema - switch v := result.(type) { - case string: - // Simple string becomes text content - callToolResult.Content = append(callToolResult.Content, TextContent{ - Text: v, - Annotations: &Annotations{ - Audience: []RoleType{RoleUser, RoleAssistant}, - }, - }) - case map[string]any: - // JSON-like object becomes formatted JSON text - jsonStr, err := json.Marshal(v) - if err != nil { - jsonStr = []byte(err.Error()) - } - callToolResult.Content = append(callToolResult.Content, TextContent{ - Text: string(jsonStr), - Annotations: &Annotations{ - Audience: []RoleType{RoleUser, RoleAssistant}, - }, - }) - case TextContent: - callToolResult.Content = append(callToolResult.Content, v) - case ImageContent: - callToolResult.Content = append(callToolResult.Content, v) - case []any: - callToolResult.Content = v - case ToolResult: - // Handle legacy ToolResult type - switch v.Type { - case ContentTypeText: - callToolResult.Content = append(callToolResult.Content, TextContent{ - Text: fmt.Sprintf("%v", v.Content), - Annotations: &Annotations{ - Audience: []RoleType{RoleUser, RoleAssistant}, - }, - }) - case ContentTypeImage: - if imgData, ok := v.Content.(map[string]any); ok { - callToolResult.Content = append(callToolResult.Content, ImageContent{ - Data: fmt.Sprintf("%v", imgData["data"]), - MimeType: fmt.Sprintf("%v", imgData["mimeType"]), - }) - } - default: - callToolResult.Content = append(callToolResult.Content, TextContent{ - Text: fmt.Sprintf("%v", v.Content), - Annotations: &Annotations{ - Audience: []RoleType{RoleUser, RoleAssistant}, - }, - }) - } - default: - // For any other type, convert to string - callToolResult.Content = append(callToolResult.Content, TextContent{ - Text: fmt.Sprintf("%v", v), - Annotations: &Annotations{ - Audience: []RoleType{RoleUser, RoleAssistant}, - }, - }) - } - - callToolResult.Content = toTypedContents(callToolResult.Content) - logx.Infof("Tool call result: %#v", callToolResult) - - s.sendResponse(ctx, client, req.ID, callToolResult) -} - -// processResourcesRead processes the resources/read request -func (s *sseMcpServer) processResourcesRead(ctx context.Context, client *mcpClient, req Request) { - var params ResourceReadParams - if err := json.Unmarshal(req.Params, ¶ms); err != nil { - s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams) - return - } - - // Find resource that matches the URI - s.resourcesLock.Lock() - resource, exists := s.resources[params.URI] - s.resourcesLock.Unlock() - - if !exists { - s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Resource with URI '%s' not found", - params.URI), errCodeResourceNotFound) - return - } - - // If no handler is provided, return an empty content array - if resource.Handler == nil { - result := ResourceReadResult{ - Contents: []ResourceContent{ - { - URI: params.URI, - MimeType: resource.MimeType, - Text: "", - }, - }, - } - s.sendResponse(ctx, client, req.ID, result) - return - } - - // Execute the resource handler - content, err := resource.Handler(ctx) - if err != nil { - s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Error reading resource: %v", err), - errCodeInternalError) - return - } - - // Ensure the URI is set if not already provided by the handler - if len(content.URI) == 0 { - content.URI = params.URI - } - - // Ensure MimeType is set if available from the resource definition - if len(content.MimeType) == 0 && len(resource.MimeType) > 0 { - content.MimeType = resource.MimeType - } - - // Create response with contents from the handler - // The MCP specification requires a contents array - result := ResourceReadResult{ - Contents: []ResourceContent{content}, - } - - s.sendResponse(ctx, client, req.ID, result) -} - -// processResourceSubscribe processes the resources/subscribe request -func (s *sseMcpServer) processResourceSubscribe(ctx context.Context, client *mcpClient, req Request) { - var params ResourceSubscribeParams - if err := json.Unmarshal(req.Params, ¶ms); err != nil { - s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams) - return - } - - // Check if the resource exists - s.resourcesLock.Lock() - _, exists := s.resources[params.URI] - s.resourcesLock.Unlock() - - if !exists { - s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Resource with URI '%s' not found", - params.URI), errCodeResourceNotFound) - return - } - - // Send success response for the subscription - s.sendResponse(ctx, client, req.ID, struct{}{}) -} - -// processNotificationCancelled processes the notifications/cancelled notification -func (s *sseMcpServer) processNotificationCancelled(ctx context.Context, client *mcpClient, req Request) { - // Extract the requestId that was canceled - type CancelParams struct { - RequestId int64 `json:"requestId"` - Reason string `json:"reason"` - } - - var params CancelParams - if err := json.Unmarshal(req.Params, ¶ms); err != nil { - logx.Errorf("Failed to parse cancellation params: %v", err) - return - } - - logx.Infof("Request %d was cancelled by client. Reason: %s", params.RequestId, params.Reason) -} - -// processNotificationInitialized processes the notifications/initialized notification -func (s *sseMcpServer) processNotificationInitialized(client *mcpClient) { - // Mark the client as properly initialized - client.initialized = true - logx.Infof("Client %s is now fully initialized and ready for normal operations", client.id) -} - -// processPing processes the ping request and responds immediately -func (s *sseMcpServer) processPing(ctx context.Context, client *mcpClient, req Request) { - // A ping request should simply respond with an empty result to confirm the server is alive - logx.Infof("Received ping request with ID: %d", req.ID) - - // Send an empty response with client's original request ID - s.sendResponse(ctx, client, req.ID, struct{}{}) -} - -// sendErrorResponse sends an error response via the SSE channel -func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient, - id any, message string, code int) { - errorResponse := struct { - JsonRpc string `json:"jsonrpc"` - ID any `json:"id"` - Error errorMessage `json:"error"` - }{ - JsonRpc: jsonRpcVersion, - ID: id, - Error: errorMessage{ - Code: code, - Message: message, - }, - } - - // all fields are primitive types, impossible to fail - jsonData, _ := json.Marshal(errorResponse) - // Use CRLF line endings as requested - sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData)) - logx.Infof("Sending error for ID %v: %s", id, sseMessage) - - // cannot receive from ctx.Done() because we're sending to the channel for SSE messages - select { - case client.channel <- sseMessage: - default: - // Channel buffer is full, log warning and continue - logx.Infof("Client %s channel is full while sending error response with ID %d", client.id, id) - } -} - -// sendResponse sends a success response via the SSE channel -func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id any, result any) { - response := Response{ - JsonRpc: jsonRpcVersion, - ID: id, - Result: result, - } - - jsonData, err := json.Marshal(response) - if err != nil { - s.sendErrorResponse(ctx, client, id, "Failed to marshal response", errCodeInternalError) - return - } - - // Use CRLF line endings as requested - sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData)) - logx.Infof("Sending response for ID %v: %s", id, sseMessage) - - // cannot receive from ctx.Done() because we're sending to the channel for SSE messages - select { - case client.channel <- sseMessage: - default: - // Channel buffer is full, log warning and continue - logx.Infof("Client %s channel is full while sending response with ID %v", client.id, id) - } + // Register the message endpoint (handles both GET for SSE and POST for messages) + s.httpServer.AddRoute(rest.Route{ + Method: http.MethodGet, + Path: s.conf.Mcp.MessageEndpoint, + Handler: handler.ServeHTTP, + }, rest.WithSSE(), rest.WithTimeout(s.conf.Mcp.SseTimeout)) + + s.httpServer.AddRoute(rest.Route{ + Method: http.MethodPost, + Path: s.conf.Mcp.MessageEndpoint, + Handler: handler.ServeHTTP, + }, rest.WithTimeout(s.conf.Mcp.MessageTimeout)) } diff --git a/mcp/server_test.go b/mcp/server_test.go index ad683e5cc..833ca87d8 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -3,3449 +3,372 @@ package mcp import ( "bytes" "context" - "encoding/json" - "fmt" "net/http" "net/http/httptest" - "strings" - "sync" - "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "github.com/zeromicro/go-zero/core/conf" - "github.com/zeromicro/go-zero/core/logx/logtest" ) -// mockMcpServer is a helper for testing the MCP server -// It encapsulates the server and test server setup and teardown logic -type mockMcpServer struct { - server *sseMcpServer - testServer *httptest.Server - requestId int64 -} - -// newMockMcpServer initializes a mock MCP server for testing -func newMockMcpServer(t *testing.T) *mockMcpServer { - const yamlConf = `name: test-server -host: localhost -port: 8080 -mcp: - name: mcp-test-server - messageTimeout: 5s -` - - var c McpConf - assert.NoError(t, conf.LoadFromYamlBytes([]byte(yamlConf), &c)) - - server := NewMcpServer(c).(*sseMcpServer) - mux := http.NewServeMux() - mux.HandleFunc(c.Mcp.SseEndpoint, server.handleSSE) - mux.HandleFunc(c.Mcp.MessageEndpoint, server.handleRequest) - testServer := httptest.NewServer(mux) - server.conf.Mcp.BaseUrl = testServer.URL - - return &mockMcpServer{ - server: server, - testServer: testServer, - requestId: 1, - } -} - -// shutdown closes the test server -func (m *mockMcpServer) shutdown() { - m.testServer.Close() -} - -// registerExamplePrompt registers a test prompt -func (m *mockMcpServer) registerExamplePrompt() { - m.server.RegisterPrompt(Prompt{ - Name: "test.prompt", - Description: "A test prompt", - }) -} - -// registerExampleResource registers a test resource -func (m *mockMcpServer) registerExampleResource() { - m.server.RegisterResource(Resource{ - Name: "test.resource", - URI: "file:///test.file", - Description: "A test resource", - }) -} - -// registerExampleTool registers a test tool -func (m *mockMcpServer) registerExampleTool() { - _ = m.server.RegisterTool(Tool{ - Name: "test.tool", - Description: "A test tool", - InputSchema: InputSchema{ - Properties: map[string]any{ - "input": map[string]any{ - "type": "string", - "description": "Input parameter", - }, - }, - Required: []string{"input"}, - }, - Handler: func(ctx context.Context, params map[string]any) (any, error) { - input, ok := params["input"].(string) - if !ok { - return nil, fmt.Errorf("invalid input parameter") - } - return fmt.Sprintf("Processed: %s", input), nil - }, - }) -} - -// Helper function to create and add a test client -func addTestClient(server *sseMcpServer, clientID string, initialized bool) *mcpClient { - client := &mcpClient{ - id: clientID, - channel: make(chan string, eventChanSize), - initialized: initialized, - } - server.clientsLock.Lock() - server.clients[clientID] = client - server.clientsLock.Unlock() - return client -} - func TestNewMcpServer(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() + c := McpConf{} + c.Host = "localhost" + c.Port = 8080 + c.Mcp.Name = "test-server" + c.Mcp.Version = "1.0.0" - mock.registerExamplePrompt() - mock.registerExampleResource() - mock.registerExampleTool() - - require.NotNil(t, mock.server, "Server should be created") - assert.NotEmpty(t, mock.server.tools, "Tools map should be initialized") - assert.NotEmpty(t, mock.server.prompts, "Prompts map should be initialized") - assert.NotEmpty(t, mock.server.resources, "Resources map should be initialized") + server := NewMcpServer(c) + assert.NotNil(t, server) } -func TestNewMcpServer_WithCors(t *testing.T) { - const yamlConf = `name: test-server -host: localhost -port: 8080 -mcp: - cors: - - http://localhost:3000 - messageTimeout: 5s -` +func TestNewMcpServerWithDefaults(t *testing.T) { + c := McpConf{} + c.Name = "default-server" + c.Host = "localhost" + c.Port = 8082 - var c McpConf - assert.NoError(t, conf.LoadFromYamlBytes([]byte(yamlConf), &c)) + server := NewMcpServer(c) + impl := server.(*mcpServerImpl) - server := NewMcpServer(c).(*sseMcpServer) - assert.Equal(t, "test-server", server.conf.Name, "Server name should be set") + // Check defaults are set + assert.Equal(t, "default-server", impl.conf.Mcp.Name) + assert.Equal(t, "1.0.0", impl.conf.Mcp.Version) } -func TestHandleRequest_badRequest(t *testing.T) { - t.Run("empty session ID", func(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() +func TestNewMcpServerWithCORS(t *testing.T) { + c := McpConf{} + c.Host = "localhost" + c.Port = 8083 + c.Mcp.Name = "cors-server" + c.Mcp.Cors = []string{"http://localhost:3000", "http://example.com"} - // Create a request with an invalid session ID - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 1, - Method: methodToolsCall, - Params: []byte(`{"sessionId": "invalid-session"}`), - } - - jsonBody, _ := json.Marshal(req) - r := httptest.NewRequest(http.MethodPost, "/?session_id=", bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - mock.server.handleRequest(w, r) - assert.Equal(t, http.StatusBadRequest, w.Code) - }) - - t.Run("bad body", func(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - addTestClient(mock.server, "test-session", true) - - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-session", bytes.NewReader([]byte(`{`))) - w := httptest.NewRecorder() - mock.server.handleRequest(w, r) - assert.Equal(t, http.StatusBadRequest, w.Code) - }) - - t.Run("bad id", func(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - addTestClient(mock.server, "test-session", true) - - body := `{"jsonrpc": "2.0", "id": {}, "method": "tools.call", "params": {}}` - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-session", bytes.NewReader([]byte(body))) - w := httptest.NewRecorder() - mock.server.handleRequest(w, r) - assert.Equal(t, http.StatusBadRequest, w.Code) - assert.Contains(t, w.Body.String(), "Invalid request.ID") - }) + server := NewMcpServer(c) + assert.NotNil(t, server) } -func TestRegisterTool(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() +func TestSetupSSETransport(t *testing.T) { + c := McpConf{} + c.Host = "localhost" + c.Port = 8084 + c.Mcp.Name = "sse-server" + c.Mcp.UseStreamable = false + c.Mcp.SseEndpoint = "/sse" + c.Mcp.MessageTimeout = 30 * time.Second + c.Mcp.SseTimeout = 24 * time.Hour - tool := Tool{ - Name: "example.tool", - Description: "An example tool", - InputSchema: InputSchema{ - Properties: map[string]any{ - "input": map[string]any{ - "type": "string", - "description": "Input parameter", - }, + server := NewMcpServer(c) + impl := server.(*mcpServerImpl) + + assert.NotNil(t, impl.httpServer) + assert.False(t, impl.conf.Mcp.UseStreamable) +} + +func TestSetupStreamableTransport(t *testing.T) { + c := McpConf{} + c.Host = "localhost" + c.Port = 8085 + c.Mcp.Name = "streamable-server" + c.Mcp.UseStreamable = true + c.Mcp.MessageEndpoint = "/message" + c.Mcp.MessageTimeout = 30 * time.Second + c.Mcp.SseTimeout = 24 * time.Hour + + server := NewMcpServer(c) + impl := server.(*mcpServerImpl) + + assert.NotNil(t, impl.httpServer) + assert.True(t, impl.conf.Mcp.UseStreamable) +} + +func TestServerImplementsInterface(t *testing.T) { + c := McpConf{} + c.Host = "localhost" + c.Port = 8086 + c.Mcp.Name = "interface-test" + + var _ McpServer = NewMcpServer(c) +} + +func TestAddTool(t *testing.T) { + c := McpConf{} + c.Host = "localhost" + c.Port = 8081 + c.Mcp.Name = "test-server" + + server := NewMcpServer(c) + + type Args struct { + Name string `json:"name"` + } + + tool := &Tool{ + Name: "greet", + Description: "Say hello", + } + + handler := func(ctx context.Context, req *CallToolRequest, args Args) (*CallToolResult, any, error) { + return &CallToolResult{ + Content: []Content{ + &TextContent{Text: "Hello " + args.Name}, }, - }, - Handler: func(ctx context.Context, params map[string]any) (any, error) { - return "result", nil - }, + }, nil, nil } - // Test with valid tool - err := mock.server.RegisterTool(tool) - assert.NoError(t, err, "Should not error with valid tool") - - // Check tool was registered - _, exists := mock.server.tools["example.tool"] - assert.True(t, exists, "Tool should be registered") - - // Test with missing handler - invalidTool := tool - invalidTool.Name = "invalid.tool" - invalidTool.Handler = nil - err = mock.server.RegisterTool(invalidTool) - assert.Error(t, err, "Should error with missing handler") + // Register the tool using mcp.AddTool + AddTool(server, tool, handler) } -func TestRegisterPrompt(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() +func TestAddToolWithStructuredOutput(t *testing.T) { + c := McpConf{} + c.Host = "localhost" + c.Port = 8087 + c.Mcp.Name = "structured-test" - prompt := Prompt{ - Name: "example.prompt", - Description: "An example prompt", + server := NewMcpServer(c) + + type CalculateArgs struct { + A int `json:"a"` + B int `json:"b"` } - // Test registering prompt - mock.server.RegisterPrompt(prompt) - - // Check prompt was registered - _, exists := mock.server.prompts["example.prompt"] - assert.True(t, exists, "Prompt should be registered") -} - -func TestRegisterResource(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - resource := Resource{ - Name: "example.resource", - URI: "http://example.com/resource", - Description: "An example resource", + type CalculateResult struct { + Sum int `json:"sum"` } - // Test registering resource - mock.server.RegisterResource(resource) - - // Check resource was registered - _, exists := mock.server.resources["http://example.com/resource"] - assert.True(t, exists, "Resource should be registered") -} - -// TestToolCallBasic tests the basic functionality of a tool call -func TestToolsList(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Register a test tool - mock.registerExampleTool() - - // Simulate a client to test tool call - client := addTestClient(mock.server, "test-client", true) - - // Create a tool call request - params := struct { - Cursor string `json:"cursor"` - Meta struct { - ProgressToken any `json:"progressToken"` - } `json:"_meta"` - }{ - Cursor: "my-cursor", + tool := &Tool{ + Name: "add", + Description: "Add two numbers", } - paramBytes, _ := json.Marshal(params) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 1, - Method: methodToolsList, - Params: paramBytes, - } - - // Process the tool call - mock.server.processListTools(context.Background(), client, req) - - // Get the response from the client's channel - select { - case response := <-client.channel: - evt, err := parseEvent(response) - assert.NoError(t, err) - - assert.Equal(t, eventMessage, evt.Type, "Event type should be message") - result, ok := evt.Data["result"].(map[string]any) - assert.True(t, ok) - assert.Equal(t, "my-cursor", result["nextCursor"], "Cursor should match") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for tool call response") - } -} - -// TestToolCallBasic tests the basic functionality of a tool call -func TestToolCallBasic(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Register a test tool - mock.registerExampleTool() - - // Simulate a client to test tool call - client := addTestClient(mock.server, "test-client", true) - - // Create a tool call request - params := struct { - Name string `json:"name"` - Arguments map[string]any `json:"arguments"` - }{ - Name: "test.tool", - Arguments: map[string]any{ - "input": "test-input", - }, - } - - paramBytes, _ := json.Marshal(params) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 1, - Method: methodToolsCall, - Params: paramBytes, - } - - // Process the tool call - mock.server.processToolCall(context.Background(), client, req) - - // Get the response from the client's channel - select { - case response := <-client.channel: - // Check response format - assert.Contains(t, response, "event: message", "Response should have message event") - assert.Contains(t, response, "data:", "Response should have data") - - // Extract JSON from the SSE response - jsonStart := strings.Index(response, "{") - jsonEnd := strings.LastIndex(response, "}") - jsonStr := response[jsonStart : jsonEnd+1] - - // Parse the JSON - var parsed struct { - Result struct { - Content []map[string]any `json:"content"` - IsError bool `json:"isError"` - } `json:"result"` - } - - err := json.Unmarshal([]byte(jsonStr), &parsed) - require.NoError(t, err, "Response should be valid JSON") - - // Verify the response content - assert.Len(t, parsed.Result.Content, 1, "Response should contain one content item") - assert.Equal(t, "Processed: test-input", parsed.Result.Content[0][ContentTypeText], "Tool result incorrect") - assert.False(t, parsed.Result.IsError, "Response should not be an error") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for tool call response") - } -} - -// TestToolCallMapResult tests a tool that returns a map[string]any result -func TestToolCallMapResult(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Register a tool that returns a map - mapTool := Tool{ - Name: "map.tool", - Description: "A tool that returns a map result", - InputSchema: InputSchema{ - Properties: map[string]any{}, - }, - Handler: func(ctx context.Context, params map[string]any) (any, error) { - // Return a complex nested map structure - return map[string]any{ - "string": "value", - "number": 42, - "boolean": true, - "nested": map[string]any{ - "array": []string{"item1", "item2"}, - "obj": map[string]any{ - "key": "value", - }, - }, - "nullValue": nil, - }, nil - }, - } - - err := mock.server.RegisterTool(mapTool) - require.NoError(t, err) - - // Create a client - client := addTestClient(mock.server, "test-client", true) - - // Create a tool call request - params := struct { - Name string `json:"name"` - Arguments map[string]any `json:"arguments"` - }{ - Name: "map.tool", - Arguments: map[string]any{}, - } - - paramBytes, _ := json.Marshal(params) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 1, - Method: methodToolsCall, - Params: paramBytes, - } - - // Process the tool call - mock.server.processToolCall(context.Background(), client, req) - - // Get the response from the client's channel - select { - case response := <-client.channel: - // Parse the response - jsonStart := strings.Index(response, "{") - jsonEnd := strings.LastIndex(response, "}") - require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") - - jsonStr := response[jsonStart : jsonEnd+1] - var parsed map[string]any - err := json.Unmarshal([]byte(jsonStr), &parsed) - require.NoError(t, err, "Response should be valid JSON") - - // Get the result - result, ok := parsed["result"].(map[string]any) - require.True(t, ok, "Response should have a result object") - - // Get the content array - content, ok := result["content"].([]any) - require.True(t, ok, "Result should have a content array") - require.NotEmpty(t, content, "Content should not be empty") - - // The first content item should be our map result converted to JSON text - firstItem, ok := content[0].(map[string]any) - require.True(t, ok, "First content item should be an object") - - // Get the text content which should be our JSON - text, ok := firstItem[ContentTypeText].(string) - require.True(t, ok, "Content should have text") - - // Verify the text is valid JSON and contains our data - var mapResult map[string]any - err = json.Unmarshal([]byte(text), &mapResult) - require.NoError(t, err, "Text should be valid JSON") - - // Verify the content of our map - assert.Equal(t, "value", mapResult["string"], "String value should match") - assert.Equal(t, float64(42), mapResult["number"], "Number value should match") - assert.Equal(t, true, mapResult["boolean"], "Boolean value should match") - - // Check nested structure - nested, ok := mapResult["nested"].(map[string]any) - require.True(t, ok, "Should have nested map") - - array, ok := nested["array"].([]any) - require.True(t, ok, "Should have array in nested map") - assert.Len(t, array, 2, "Array should have 2 items") - assert.Equal(t, "item1", array[0], "First array item should match") - - obj, ok := nested["obj"].(map[string]any) - require.True(t, ok, "Should have obj in nested map") - assert.Equal(t, "value", obj["key"], "Nested object key should match") - - // Check null value - _, exists := mapResult["nullValue"] - assert.True(t, exists, "Null value key should exist") - - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for tool call response") - } -} - -// TestToolCallArrayResult tests a tool that returns an array result -func TestToolCallArrayResult(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Register a tool that returns an array - arrayTool := Tool{ - Name: "array.tool", - Description: "A tool that returns an array result", - InputSchema: InputSchema{ - Properties: map[string]any{}, - }, - Handler: func(ctx context.Context, params map[string]any) (any, error) { - // Return an array of mixed content types - return []any{ - "string item", - 42, - true, - map[string]any{"key": "value"}, - []string{"nested", "array"}, - nil, - }, nil - }, - } - - err := mock.server.RegisterTool(arrayTool) - require.NoError(t, err) - - // Create a client - client := addTestClient(mock.server, "test-client", true) - - // Create a tool call request - params := struct { - Name string `json:"name"` - Arguments map[string]any `json:"arguments"` - }{ - Name: "array.tool", - Arguments: map[string]any{}, - } - - paramBytes, _ := json.Marshal(params) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 1, - Method: methodToolsCall, - Params: paramBytes, - } - - // Process the tool call - mock.server.processToolCall(context.Background(), client, req) - - // Get the response from the client's channel - select { - case response := <-client.channel: - // Parse the response - jsonStart := strings.Index(response, "{") - jsonEnd := strings.LastIndex(response, "}") - require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") - - jsonStr := response[jsonStart : jsonEnd+1] - var parsed map[string]any - err := json.Unmarshal([]byte(jsonStr), &parsed) - require.NoError(t, err, "Response should be valid JSON") - - // Get the result - result, ok := parsed["result"].(map[string]any) - require.True(t, ok, "Response should have a result object") - - // Get the content array - content, ok := result["content"].([]any) - require.True(t, ok, "Result should have a content array") - require.Equal(t, 6, len(content), "Content should have 6 items, one for each array item") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for tool call response") - } -} - -// TestToolCallTextContentResult tests a tool that returns a TextContent result -func TestToolCallTextContentResult(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Register a tool that returns a TextContent - textContentTool := Tool{ - Name: "text.content.tool", - Description: "A tool that returns a TextContent result", - InputSchema: InputSchema{ - Properties: map[string]any{}, - }, - Handler: func(ctx context.Context, params map[string]any) (any, error) { - // Return a TextContent object directly - return TextContent{ - Text: "This is a direct TextContent result", - Annotations: &Annotations{ - Audience: []RoleType{RoleUser, RoleAssistant}, - Priority: func() *float64 { p := 0.9; return &p }(), - }, - }, nil - }, - } - - err := mock.server.RegisterTool(textContentTool) - require.NoError(t, err) - - // Create a client - client := addTestClient(mock.server, "test-client", true) - - // Create a tool call request - params := struct { - Name string `json:"name"` - Arguments map[string]any `json:"arguments"` - }{ - Name: "text.content.tool", - Arguments: map[string]any{}, - } - - paramBytes, _ := json.Marshal(params) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 1, - Method: methodToolsCall, - Params: paramBytes, - } - - // Process the tool call - mock.server.processToolCall(context.Background(), client, req) - - // Get the response from the client's channel - select { - case response := <-client.channel: - // Parse the response - jsonStart := strings.Index(response, "{") - jsonEnd := strings.LastIndex(response, "}") - require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") - - jsonStr := response[jsonStart : jsonEnd+1] - var parsed map[string]any - err := json.Unmarshal([]byte(jsonStr), &parsed) - require.NoError(t, err, "Response should be valid JSON") - - // Get the result - result, ok := parsed["result"].(map[string]any) - require.True(t, ok, "Response should have a result object") - - // Get the content array - content, ok := result["content"].([]any) - require.True(t, ok, "Result should have a content array") - require.NotEmpty(t, content, "Content should not be empty") - - // The first content item should be our TextContent - firstItem, ok := content[0].(map[string]any) - require.True(t, ok, "First content item should be an object") - - // Check annotations - annotations, ok := firstItem["annotations"].(map[string]any) - require.True(t, ok, "Should have annotations") - - audience, ok := annotations["audience"].([]any) - require.True(t, ok, "Should have audience in annotations") - assert.Len(t, audience, 2, "Audience should have 2 items") - - priority, ok := annotations["priority"].(float64) - require.True(t, ok, "Should have priority in annotations") - assert.Equal(t, 0.9, priority, "Priority should match") - - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for tool call response") - } -} - -// TestToolCallImageContentResult tests a tool that returns an ImageContent result -func TestToolCallImageContentResult(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Register a tool that returns an ImageContent - imageContentTool := Tool{ - Name: "image.content.tool", - Description: "A tool that returns an ImageContent result", - InputSchema: InputSchema{ - Properties: map[string]any{}, - }, - Handler: func(ctx context.Context, params map[string]any) (any, error) { - // Return an ImageContent object directly - return ImageContent{ - Data: "dGVzdCBpbWFnZSBkYXRhIChiYXNlNjQgZW5jb2RlZCk=", // "test image data (base64 encoded)" in base64 - MimeType: "image/png", - }, nil - }, - } - - err := mock.server.RegisterTool(imageContentTool) - require.NoError(t, err) - - // Create a client - client := addTestClient(mock.server, "test-client", true) - - // Create a tool call request - params := struct { - Name string `json:"name"` - Arguments map[string]any `json:"arguments"` - }{ - Name: "image.content.tool", - Arguments: map[string]any{}, - } - - paramBytes, _ := json.Marshal(params) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 1, - Method: methodToolsCall, - Params: paramBytes, - } - - // Process the tool call - mock.server.processToolCall(context.Background(), client, req) - - // Get the response from the client's channel - select { - case response := <-client.channel: - // Parse the response - jsonStart := strings.Index(response, "{") - jsonEnd := strings.LastIndex(response, "}") - require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") - - jsonStr := response[jsonStart : jsonEnd+1] - var parsed map[string]any - err := json.Unmarshal([]byte(jsonStr), &parsed) - require.NoError(t, err, "Response should be valid JSON") - - // Get the result - result, ok := parsed["result"].(map[string]any) - require.True(t, ok, "Response should have a result object") - - // Get the content array - content, ok := result["content"].([]any) - require.True(t, ok, "Result should have a content array") - require.NotEmpty(t, content, "Content should not be empty") - - // The first content item should be our ImageContent - firstItem, ok := content[0].(map[string]any) - require.True(t, ok, "First content item should be an object") - - // Check image data - data, ok := firstItem["data"].(string) - require.True(t, ok, "Content should have data") - assert.Equal(t, "dGVzdCBpbWFnZSBkYXRhIChiYXNlNjQgZW5jb2RlZCk=", data, "Image data should match") - - // Check mime type - mimeType, ok := firstItem["mimeType"].(string) - require.True(t, ok, "Content should have mimeType") - assert.Equal(t, "image/png", mimeType, "MimeType should match") - - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for tool call response") - } -} - -// TestToolCallToolResultType tests a tool that returns a ToolResult type -func TestToolCallToolResultType(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - toolResultTool := Tool{ - Name: "toolresult.tool", - Description: "A tool that returns a ToolResult object", - InputSchema: InputSchema{ - Type: ContentTypeObject, - Properties: map[string]any{}, - }, - Handler: func(ctx context.Context, params map[string]any) (any, error) { - return ToolResult{ - Type: ContentTypeText, - Content: "This is a ToolResult with text content type", - }, nil - }, - } - err := mock.server.RegisterTool(toolResultTool) - require.NoError(t, err) - - toolResultImageTool := Tool{ - Name: "toolresult.image.tool", - Description: "A tool that returns a ToolResult with image content", - InputSchema: InputSchema{ - Type: ContentTypeObject, - Properties: map[string]any{}, - }, - Handler: func(ctx context.Context, params map[string]any) (any, error) { - return ToolResult{ - Type: "image", - Content: map[string]any{ - "data": "dGVzdCBpbWFnZSBkYXRhIGZvciB0b29sIHJlc3VsdA==", // "test image data for tool result" in base64 - "mimeType": "image/jpeg", - }, - }, nil - }, - } - err = mock.server.RegisterTool(toolResultImageTool) - require.NoError(t, err) - - toolResultAudioTool := Tool{ - Name: "toolresult.audio.tool", - Description: "A tool that returns a ToolResult with audio content", - InputSchema: InputSchema{ - Properties: map[string]any{}, - }, - Handler: func(ctx context.Context, params map[string]any) (any, error) { - // Test with image type - return ToolResult{ - Type: "audio", - Content: map[string]any{ - "data": "dGVzdCBpbWFnZSBkYXRhIGZvciB0b29sIHJlc3VsdA==", // "test image data for tool result" in base64 - "mimeType": "audio", - }, - }, nil - }, - } - err = mock.server.RegisterTool(toolResultAudioTool) - require.NoError(t, err) - - toolResultIntType := Tool{ - Name: "toolresult.int.tool", - Description: "A tool that returns a ToolResult with int content", - InputSchema: InputSchema{ - Properties: map[string]any{}, - }, - Handler: func(ctx context.Context, params map[string]any) (any, error) { - return 2, nil - }, - } - err = mock.server.RegisterTool(toolResultIntType) - require.NoError(t, err) - - toolResultBadType := Tool{ - Name: "toolresult.bad.tool", - Description: "A tool that returns a ToolResult with bad content", - InputSchema: InputSchema{ - Properties: map[string]any{}, - }, - Handler: func(ctx context.Context, params map[string]any) (any, error) { - return map[string]any{ - "type": "custom", - "data": make(chan int), - }, nil - }, - } - err = mock.server.RegisterTool(toolResultBadType) - require.NoError(t, err) - - // Test text ToolResult - t.Run("textToolResult", func(t *testing.T) { - // Create a client - client := addTestClient(mock.server, "test-client-text", true) - - // Create a tool call request - params := struct { - Name string `json:"name"` - Arguments map[string]any `json:"arguments"` - }{ - Name: "toolresult.tool", - Arguments: map[string]any{}, - } - - paramBytes, _ := json.Marshal(params) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 1, - Method: methodToolsCall, - Params: paramBytes, - } - - // Process the tool call - mock.server.processToolCall(context.Background(), client, req) - - // Get the response from the client's channel - select { - case response := <-client.channel: - // Parse the response - jsonStart := strings.Index(response, "{") - jsonEnd := strings.LastIndex(response, "}") - require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") - - jsonStr := response[jsonStart : jsonEnd+1] - var parsed map[string]any - err := json.Unmarshal([]byte(jsonStr), &parsed) - require.NoError(t, err, "Response should be valid JSON") - - // Get the result - result, ok := parsed["result"].(map[string]any) - require.True(t, ok, "Response should have a result object") - - // Get the content array - content, ok := result["content"].([]any) - require.True(t, ok, "Result should have a content array") - require.NotEmpty(t, content, "Content should not be empty") - - // The first content item should be converted from ToolResult to TextContent - firstItem, ok := content[0].(map[string]any) - require.True(t, ok, "First content item should be an object") - - // Check text content - text, ok := firstItem[ContentTypeText].(string) - require.True(t, ok, "Content should have text") - assert.Equal(t, "This is a ToolResult with text content type", text, "Text content should match") - - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for tool call response") - } - }) - - // Test image ToolResult - t.Run("imageToolResult", func(t *testing.T) { - // Create a client - client := addTestClient(mock.server, "test-client-image", true) - - // Create a tool call request - params := struct { - Name string `json:"name"` - Arguments map[string]any `json:"arguments"` - }{ - Name: "toolresult.image.tool", - Arguments: map[string]any{}, - } - - paramBytes, _ := json.Marshal(params) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 2, - Method: methodToolsCall, - Params: paramBytes, - } - - // Process the tool call - mock.server.processToolCall(context.Background(), client, req) - - // Get the response from the client's channel - select { - case response := <-client.channel: - // Parse the response - jsonStart := strings.Index(response, "{") - jsonEnd := strings.LastIndex(response, "}") - require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") - - jsonStr := response[jsonStart : jsonEnd+1] - var parsed map[string]any - err := json.Unmarshal([]byte(jsonStr), &parsed) - require.NoError(t, err, "Response should be valid JSON") - - // Get the result - result, ok := parsed["result"].(map[string]any) - require.True(t, ok, "Response should have a result object") - - // Get the content array - content, ok := result["content"].([]any) - require.True(t, ok, "Result should have a content array") - require.NotEmpty(t, content, "Content should not be empty") - - // The first content item should be converted from ToolResult to ImageContent - firstItem, ok := content[0].(map[string]any) - require.True(t, ok, "First content item should be an object") - - // Check image data and mime type - data, ok := firstItem["data"].(string) - require.True(t, ok, "Content should have data") - assert.Equal(t, "dGVzdCBpbWFnZSBkYXRhIGZvciB0b29sIHJlc3VsdA==", data, "Image data should match") - - mimeType, ok := firstItem["mimeType"].(string) - require.True(t, ok, "Content should have mimeType") - assert.Equal(t, "image/jpeg", mimeType, "MimeType should match") - - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for tool call response") - } - }) - - // Test image ToolResult - t.Run("audioToolResult", func(t *testing.T) { - // Create a client - client := addTestClient(mock.server, "test-client-image", true) - - // Create a tool call request - params := struct { - Name string `json:"name"` - Arguments map[string]any `json:"arguments"` - }{ - Name: "toolresult.audio.tool", - Arguments: map[string]any{}, - } - - paramBytes, _ := json.Marshal(params) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 2, - Method: methodToolsCall, - Params: paramBytes, - } - - // Process the tool call - mock.server.processToolCall(context.Background(), client, req) - - // Get the response from the client's channel - select { - case response := <-client.channel: - // Parse the response - jsonStart := strings.Index(response, "{") - jsonEnd := strings.LastIndex(response, "}") - require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") - - jsonStr := response[jsonStart : jsonEnd+1] - var parsed map[string]any - err := json.Unmarshal([]byte(jsonStr), &parsed) - require.NoError(t, err, "Response should be valid JSON") - - // Get the result - result, ok := parsed["result"].(map[string]any) - require.True(t, ok, "Response should have a result object") - - // Get the content array - content, ok := result["content"].([]any) - require.True(t, ok, "Result should have a content array") - require.NotEmpty(t, content, "Content should not be empty") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for tool call response") - } - }) - - // Test text ToolResult - t.Run("ToolResult with int type", func(t *testing.T) { - // Create a client - client := addTestClient(mock.server, "test-client-text", true) - - // Create a tool call request - params := struct { - Name string `json:"name"` - Arguments map[string]any `json:"arguments"` - }{ - Name: "toolresult.int.tool", - Arguments: map[string]any{}, - } - - paramBytes, _ := json.Marshal(params) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 1, - Method: methodToolsCall, - Params: paramBytes, - } - - // Process the tool call - mock.server.processToolCall(context.Background(), client, req) - - // Get the response from the client's channel - select { - case response := <-client.channel: - // Parse the response - jsonStart := strings.Index(response, "{") - jsonEnd := strings.LastIndex(response, "}") - require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") - - jsonStr := response[jsonStart : jsonEnd+1] - var parsed map[string]any - err := json.Unmarshal([]byte(jsonStr), &parsed) - require.NoError(t, err, "Response should be valid JSON") - - // Get the result - result, ok := parsed["result"].(map[string]any) - require.True(t, ok, "Response should have a result object") - - // Get the content array - content, ok := result["content"].([]any) - require.True(t, ok, "Result should have a content array") - require.NotEmpty(t, content, "Content should not be empty") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for tool call response") - } - }) - - // Test text ToolResult - t.Run("ToolResult with bad type", func(t *testing.T) { - // Create a client - client := addTestClient(mock.server, "test-client-text", true) - - // Create a tool call request - params := struct { - Name string `json:"name"` - Arguments map[string]any `json:"arguments"` - }{ - Name: "toolresult.bad.tool", - Arguments: map[string]any{}, - } - - paramBytes, _ := json.Marshal(params) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 1, - Method: methodToolsCall, - Params: paramBytes, - } - - // Process the tool call - mock.server.processToolCall(context.Background(), client, req) - - // Get the response from the client's channel - select { - case response := <-client.channel: - assert.Contains(t, response, "json: unsupported type") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for tool call response") - } - }) -} - -// TestToolCallError tests that tool errors are properly handled -func TestToolCallError(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Register a tool that returns an error - err := mock.server.RegisterTool(Tool{ - Name: "error.tool", - Description: "A tool that returns an error", - InputSchema: InputSchema{ - Properties: map[string]any{}, - }, - Handler: func(ctx context.Context, params map[string]any) (any, error) { - return nil, fmt.Errorf("tool execution failed") - }, - }) - require.NoError(t, err) - - // Simulate a client - client := addTestClient(mock.server, "test-client", true) - - // Create a tool call request - params := struct { - Name string `json:"name"` - Arguments map[string]any `json:"arguments"` - }{ - Name: "error.tool", - Arguments: map[string]any{}, - } - - paramBytes, _ := json.Marshal(params) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 1, - Method: methodToolsCall, - Params: paramBytes, - } - - // Process the tool call - mock.server.processToolCall(context.Background(), client, req) - - // Check the response - select { - case response := <-client.channel: - assert.Contains(t, response, "event: message", "Response should have message event") - assert.Contains(t, response, "Error:", "Response should contain the error message") - assert.Contains(t, response, "isError", "Response should indicate it's an error") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for tool call response") - } -} - -// TestToolCallTimeout tests that tool timeouts are properly handled -func TestToolCallTimeout(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Register a tool that times out - err := mock.server.RegisterTool(Tool{ - Name: "timeout.tool", - Description: "A tool that times out", - InputSchema: InputSchema{ - Properties: map[string]any{}, - }, - Handler: func(ctx context.Context, params map[string]any) (any, error) { - <-ctx.Done() - return nil, fmt.Errorf("tool execution timed out") - }, - }) - require.NoError(t, err) - - // Simulate a client - client := addTestClient(mock.server, "test-client", true) - - // Create a tool call request - params := struct { - Name string `json:"name"` - Arguments map[string]any `json:"arguments"` - }{ - Name: "timeout.tool", - Arguments: map[string]any{}, - } - - paramBytes, _ := json.Marshal(params) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 1, - Method: methodToolsCall, - Params: paramBytes, - } - jsonBody, _ := json.Marshal(req) - - // Create HTTP request - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client", bytes.NewReader(jsonBody)) - ctx, cancel := context.WithTimeout(r.Context(), 10*time.Millisecond) - defer cancel() - r = r.WithContext(ctx) - w := httptest.NewRecorder() - - // Process through handleRequest - go mock.server.handleRequest(w, r) - - // Check the response - select { - case response := <-client.channel: - assert.Contains(t, response, "event: message", "Response should have message event") - assert.Contains(t, response, `-32001`, "Response should contain a timeout error code") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for tool call response") - } -} - -// TestInitializeAndNotifications tests the client initialization flow -func TestInitializeAndNotifications(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Create a test client - client := addTestClient(mock.server, "test-client", false) - - // Test initialize request - initReq := Request{ - JsonRpc: jsonRpcVersion, - ID: 1, - Method: "initialize", - Params: json.RawMessage(`{}`), - } - - mock.server.processInitialize(context.Background(), client, initReq) - - // Check that client is initialized after initialize request - assert.True(t, client.initialized, "Client should be marked as initialized after initialize request") - - // Check the response format - select { - case response := <-client.channel: - // Check response contains required initialization fields - assert.Contains(t, response, "protocolVersion", "Response should include protocol version") - assert.Contains(t, response, "capabilities", "Response should include capabilities") - assert.Contains(t, response, "serverInfo", "Response should include server info") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for initialize response") - } - - // Test notification initialized - mock.server.processNotificationInitialized(client) - assert.True(t, client.initialized, "Client should remain initialized after notification") -} - -// TestRequestHandlingWithoutInitialization tests that requests are properly rejected when client is not initialized -func TestRequestHandlingWithoutInitialization(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - mock.registerExampleTool() - - // Create an uninitialized test client - client := addTestClient(mock.server, "test-client", false) - - // Attempt a tool call before initialization - params := struct { - Name string `json:"name"` - Arguments map[string]any `json:"arguments"` - }{ - Name: "test.tool", - Arguments: map[string]any{"input": "foo"}, - } - - paramBytes, _ := json.Marshal(params) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 1, - Method: methodToolsCall, - Params: paramBytes, - } - jsonBody, _ := json.Marshal(req) - - // Handle the request - this should fail because client is not initialized - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client", bytes.NewReader(jsonBody)) - mock.server.handleRequest(httptest.NewRecorder(), r) - - // Check error response - select { - case response := <-client.channel: - assert.Contains(t, strings.ToLower(response), "error", "Response should contain an error") - assert.Contains(t, strings.ToLower(response), "not fully initialized", - "Response should mention client not being initialized") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for error response") - } -} - -// TestPing tests the ping request handling -func TestPing(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Create a test client - client := addTestClient(mock.server, "test-client", true) - - // Create a ping request - pingReq := Request{ - JsonRpc: jsonRpcVersion, - ID: 1, - Method: "ping", - Params: json.RawMessage(`{}`), - } - - jsonBody, _ := json.Marshal(pingReq) - - // Handle the request - this should fail because client is not initialized - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client", bytes.NewReader(jsonBody)) - mock.server.handleRequest(httptest.NewRecorder(), r) - - // Check response - select { - case response := <-client.channel: - assert.Contains(t, response, `"result":`, "Response should contain a result field") - assert.Contains(t, response, `"id":1`, "Response should have the same ID as the request") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for ping response") - } -} - -// TestNotificationCancelled tests the notification cancelled handling -func TestNotificationCancelled(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Create a test client - client := addTestClient(mock.server, "test-client", true) - - // Create a cancellation request - paramBytes, _ := json.Marshal(map[string]any{ - "requestId": 123, - "reason": "user_cancelled", - }) - - cancelReq := Request{ - JsonRpc: jsonRpcVersion, - Method: "notifications/cancelled", - Params: paramBytes, - } - - jsonBody, _ := json.Marshal(cancelReq) - - // Handle the request - this should fail because client is not initialized - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client", bytes.NewReader(jsonBody)) - mock.server.handleRequest(httptest.NewRecorder(), r) - - // No response expected for notifications - select { - case <-client.channel: - t.Fatal("No response expected for notifications") - case <-time.After(50 * time.Millisecond): - // This is the expected outcome - no response - } -} - -func TestNotificationCancelled_badParams(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - client := addTestClient(mock.server, "test-client", true) - - cancelReq := Request{ - JsonRpc: jsonRpcVersion, - Method: "notifications/cancelled", - Params: []byte(`invalid json`), - } - - buf := logtest.NewCollector(t) - mock.server.processNotificationCancelled(context.Background(), client, cancelReq) - - select { - case <-client.channel: - t.Fatal("No response expected for notifications") - case <-time.After(50 * time.Millisecond): - assert.Contains(t, buf.String(), "Failed to parse cancellation params") - } -} - -func TestUnknownRequest(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Create a test client - client := addTestClient(mock.server, "test-client", true) - req := Request{ - JsonRpc: jsonRpcVersion, - Method: "unknown", - } - - jsonBody, _ := json.Marshal(req) - - // Handle the request - this should fail because client is not initialized - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client", bytes.NewReader(jsonBody)) - mock.server.handleRequest(httptest.NewRecorder(), r) - - // No response expected for notifications - select { - case message := <-client.channel: - evt, err := parseEvent(message) - require.NoError(t, err, "Should parse event without error") - errCode := evt.Data["error"].(map[string]any)["code"] - // because error code will be converted into float64 - assert.Equal(t, float64(errCodeMethodNotFound), errCode) - case <-time.After(50 * time.Millisecond): - // This is the expected outcome - no response - } -} - -func TestResponseWriter_notFlusher(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Handle the request - this should fail because client is not initialized - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client", http.NoBody) - var w notFlusherResponseWriter - mock.server.handleSSE(&w, r) - assert.Equal(t, http.StatusInternalServerError, w.code) -} - -func TestResponseWriter_cantWrite(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Handle the request - this should fail because client is not initialized - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client", http.NoBody) - var w cantWriteResponseWriter - mock.server.handleSSE(&w, r) - assert.Equal(t, 0, w.code) -} - -func TestHandleSSE_channelClosed(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Handle the request - this should fail because client is not initialized - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client", http.NoBody) - w := httptest.NewRecorder() - var wg sync.WaitGroup - wg.Add(1) - go func() { - mock.server.handleSSE(w, r) - wg.Done() - }() - - buf := logtest.NewCollector(t) - for { - time.Sleep(time.Millisecond) - mock.server.clientsLock.Lock() - if len(mock.server.clients) > 0 { - for _, client := range mock.server.clients { - close(client.channel) - delete(mock.server.clients, client.id) - } - mock.server.clientsLock.Unlock() - break - } - mock.server.clientsLock.Unlock() - } - wg.Wait() - assert.Contains(t, "channel was closed", buf.Content(), "Should log channel closed error") -} - -func TestHandleSSE_badResponseWriter(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Handle the request - this should fail because client is not initialized - r := httptest.NewRequest(http.MethodPost, "/", http.NoBody) - var wg sync.WaitGroup - wg.Add(1) - go func() { - var w writeOnceResponseWriter - mock.server.handleSSE(&w, r) - wg.Done() - }() - - var session string - for { - time.Sleep(time.Millisecond) - mock.server.clientsLock.Lock() - if len(mock.server.clients) > 0 { - for _, client := range mock.server.clients { - session = client.id - } - mock.server.clientsLock.Unlock() - break - } - mock.server.clientsLock.Unlock() - } - - time.Sleep(100 * time.Millisecond) - // Create a ping request - pingReq := Request{ - JsonRpc: jsonRpcVersion, - ID: 1, - Method: "ping", - Params: json.RawMessage(`{}`), - } - - jsonBody, _ := json.Marshal(pingReq) - buf := logtest.NewCollector(t) - - // Handle the request - this should fail because client is not initialized - r = httptest.NewRequest(http.MethodPost, fmt.Sprintf("/?session_id=%s", session), - bytes.NewReader(jsonBody)) - mock.server.handleRequest(httptest.NewRecorder(), r) - - wg.Wait() - assert.Contains(t, "Failed to write", buf.Content()) -} - -// TestGetPrompt tests the prompts/get endpoint -func TestGetPrompt(t *testing.T) { - t.Run("test prompt", func(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Create a test client - client := addTestClient(mock.server, "test-client", true) - - // Register a test prompt - testPrompt := Prompt{ - Name: "test.prompt", - Description: "A test prompt", - } - mock.server.RegisterPrompt(testPrompt) - - // Create a get prompt request - paramBytes, _ := json.Marshal(map[string]any{ - "name": "test.prompt", - "arguments": map[string]string{ - "topic": "test topic", + handler := func(ctx context.Context, req *CallToolRequest, args CalculateArgs) (*CallToolResult, CalculateResult, error) { + result := CalculateResult{Sum: args.A + args.B} + return &CallToolResult{ + Content: []Content{ + &TextContent{Text: "Sum calculated"}, }, - }) + }, result, nil + } - promptReq := Request{ - JsonRpc: jsonRpcVersion, - ID: 1, - Method: "prompts/get", - Params: paramBytes, - } - - // Process the request - mock.server.processGetPrompt(context.Background(), client, promptReq) - - // Check response - select { - case response := <-client.channel: - assert.Contains(t, response, "description", "Response should include prompt description") - assert.Contains(t, response, "prompts", "Response should include prompts array") - assert.Contains(t, response, "A test prompt", "Response should include the topic argument") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for prompt response") - } - }) - - t.Run("test prompt with invalid params", func(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Create a test client - client := addTestClient(mock.server, "test-client", true) - - paramBytes := []byte("invalid json") - promptReq := Request{ - JsonRpc: jsonRpcVersion, - ID: 1, - Method: "prompts/get", - Params: paramBytes, - } - - // Process the request - mock.server.processGetPrompt(context.Background(), client, promptReq) - - // Check response - select { - case response := <-client.channel: - evt, err := parseEvent(response) - assert.NoError(t, err, "Should be able to parse event") - errMsg, ok := evt.Data["error"].(map[string]any) - assert.True(t, ok, "Should have error in response") - assert.Equal(t, "Invalid parameters", errMsg["message"], "Error message should match") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for prompt response") - } - }) - - t.Run("test prompt with nil params", func(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Create a test client - client := addTestClient(mock.server, "test-client", true) - - // Register a test prompt - testPrompt := Prompt{ - Name: "test.prompt", - Description: "A test prompt", - } - mock.server.RegisterPrompt(testPrompt) - - // Create a get prompt request - paramBytes, _ := json.Marshal(map[string]any{ - "name": "test.prompt", - }) - promptReq := Request{ - JsonRpc: jsonRpcVersion, - ID: 1, - Method: "prompts/get", - Params: paramBytes, - } - - // Process the request - mock.server.processGetPrompt(context.Background(), client, promptReq) - - // Check response - select { - case response := <-client.channel: - _, err := parseEvent(response) - assert.NoError(t, err, "Should be able to parse event") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for prompt response") - } - }) + AddTool(server, tool, handler) } -// TestBroadcast tests the broadcast functionality -func TestBroadcast(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() +func TestServerLifecycle(t *testing.T) { + c := McpConf{} + c.Host = "127.0.0.1" + c.Port = 0 // Use random port + c.Mcp.Name = "lifecycle-test" + c.Mcp.SseEndpoint = "/sse" + c.Mcp.MessageEndpoint = "/message" - // Create two test clients - client1 := addTestClient(mock.server, "client-1", true) - client2 := addTestClient(mock.server, "client-2", true) + server := NewMcpServer(c) - // Broadcast a test message - testData := map[string]string{"key": "value"} - mock.server.broadcast("test_event", testData) + // Test that Start and Stop can be called + // We don't actually start it to avoid port conflicts in tests + impl := server.(*mcpServerImpl) + assert.NotNil(t, impl.httpServer) - // Check both clients received the broadcast - for i, client := range []*mcpClient{client1, client2} { - select { - case response := <-client.channel: - assert.Contains(t, response, `event: test_event`, "Response should have the correct event") - assert.Contains(t, response, `"key":"value"`, "Response should contain the broadcast data") - case <-time.After(100 * time.Millisecond): - t.Fatalf("Timed out waiting for broadcast on client %d", i+1) - } - } - - buf := logtest.NewCollector(t) - mock.server.broadcast("test_event", make(chan string)) - // Check that the broadcast was logged - content := buf.Content() - assert.Contains(t, content, "Failed", "Broadcast should be logged") - - for i := 0; i < eventChanSize; i++ { - mock.server.broadcast("test_event", "test") - } - - done := make(chan struct{}) - go func() { - mock.server.broadcast("test_event", "ignored") - close(done) - }() - - select { - case <-time.After(100 * time.Millisecond): - assert.Fail(t, "broadcast should not block") - case <-done: - } -} - -// TestHandleSSEPing tests the automatic ping functionality in the SSE handler -func TestHandleSSEPing(t *testing.T) { - originalInterval := pingInterval.Load() - pingInterval.Set(50 * time.Millisecond) + // Just verify the methods exist and can be called + // Actual server start/stop is tested in integration tests defer func() { - pingInterval.Set(originalInterval) + if r := recover(); r == nil { + // If no panic, call stop + server.Stop() + } }() - - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Create a request context that can be cancelled - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Create a test ResponseRecorder and Request - w := httptest.NewRecorder() - r := httptest.NewRequest("GET", mock.server.conf.Mcp.SseEndpoint, nil).WithContext(ctx) - - // Create a channel to coordinate the test - done := make(chan struct{}) - - // Set up a custom ResponseRecorder that captures writes and signals the test - customResponseWriter := &testResponseWriter{ - ResponseRecorder: w, - writes: make([]string, 0), - done: done, - pingDetected: false, - } - - // Start the SSE handler in a goroutine - go func() { - mock.server.handleSSE(customResponseWriter, r) - }() - - // Wait for ping or timeout - select { - case <-done: - // A ping was detected - assert.True(t, customResponseWriter.pingDetected, "Ping message should have been sent") - case <-time.After(pingInterval.Load() + 100*time.Millisecond): - t.Fatal("Timed out waiting for ping message") - } - - // Verify that the client was added and cleaned up - mock.server.clientsLock.Lock() - clientCount := len(mock.server.clients) - mock.server.clientsLock.Unlock() - - // Clean up by cancelling the context - cancel() - - // Wait for cleanup to complete - time.Sleep(50 * time.Millisecond) - - // Verify client was removed - mock.server.clientsLock.Lock() - finalClientCount := len(mock.server.clients) - mock.server.clientsLock.Unlock() - - assert.Equal(t, 1, clientCount, "One client should be added during the test") - assert.Equal(t, 0, finalClientCount, "Client should be cleaned up after context cancellation") -} - -// TestHandleSSEPing tests the automatic ping functionality in the SSE handler -func TestHandleSSEPing_writeOnce(t *testing.T) { - originalInterval := pingInterval.Load() - pingInterval.Set(50 * time.Millisecond) - defer func() { - pingInterval.Set(originalInterval) - }() - - buf := logtest.NewCollector(t) - var bufLock sync.Mutex - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Start the SSE handler in a goroutine - go func() { - var w writeOnceResponseWriter - r := httptest.NewRequest(http.MethodGet, mock.server.conf.Mcp.SseEndpoint, http.NoBody) - bufLock.Lock() - defer bufLock.Unlock() - mock.server.handleSSE(&w, r) - }() - - // Wait for ping or timeout - time.Sleep(100 * time.Millisecond) - bufLock.Lock() - assert.Contains(t, "Failed to send ping", buf.Content()) - bufLock.Unlock() } func TestServerStartStop(t *testing.T) { - // Create a simple configuration for testing - const yamlConf = `name: test-server -host: localhost -port: 0 -timeout: 1000 -mcp: - name: mcp-test-server -` - var c McpConf - assert.NoError(t, conf.LoadFromYamlBytes([]byte(yamlConf), &c)) + // Create server with a unique port + c := McpConf{} + c.Host = "127.0.0.1" + c.Port = 18080 // Use high port to avoid conflicts + c.Mcp.Name = "start-stop-test" + c.Mcp.SseEndpoint = "/sse" + c.Mcp.MessageEndpoint = "/message" + c.Mcp.SseTimeout = 1 * time.Second + c.Mcp.MessageTimeout = 1 * time.Second - // Create the server - s := NewMcpServer(c) + server := NewMcpServer(c) - // Start and stop in goroutine to avoid blocking - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() + // Test that we can call Stop even without Start + // (This tests the Stop method coverage) + server.Stop() +} +func TestServerStartActual(t *testing.T) { + // Create server with specific port + c := McpConf{} + c.Host = "127.0.0.1" + c.Port = 19080 // Use specific high port + c.Mcp.Name = "actual-start-test" + c.Mcp.SseEndpoint = "/sse" + c.Mcp.MessageEndpoint = "/message" + c.Mcp.SseTimeout = 1 * time.Second + c.Mcp.MessageTimeout = 1 * time.Second + + server := NewMcpServer(c) + + // Start server in goroutine go func() { - s.Start() + server.Start() // This blocks until Stop() is called }() - // Allow a brief moment for startup - time.Sleep(50 * time.Millisecond) + // Give server time to start + time.Sleep(300 * time.Millisecond) + + // Make a test request to the SSE endpoint to trigger the handler callback + client := &http.Client{Timeout: 500 * time.Millisecond} + req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:19080/sse", nil) + if err == nil { + req.Header.Set("Accept", "text/event-stream") + resp, err := client.Do(req) + if err == nil { + resp.Body.Close() + // Server is responding - this proves Start() worked + // and the SSE handler callback was called + assert.True(t, resp.StatusCode == http.StatusOK || resp.StatusCode > 0) + } + } // Stop the server - s.Stop() + server.Stop() - // Wait for context to ensure we properly stopped or timed out - <-ctx.Done() + // Give it time to shutdown + time.Sleep(100 * time.Millisecond) } -// TestNotificationInitialized tests the notifications/initialized handling in detail -func TestNotificationInitialized(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - t.Run("uninitializedClient", func(t *testing.T) { - // Create an uninitialized test client - client := addTestClient(mock.server, "test-client-uninitialized", false) - assert.False(t, client.initialized, "Client should start as uninitialized") - - // Create a notification request - req := Request{ - JsonRpc: jsonRpcVersion, - Method: methodNotificationsInitialized, - // No ID for notifications - Params: json.RawMessage(`{}`), // Empty params acceptable for this notification - } - - // Process through the request handler - jsonBody, _ := json.Marshal(req) - r := httptest.NewRequest(http.MethodPost, "/?session_id="+client.id, bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - mock.server.handleRequest(w, r) - - // Verify client is now initialized - assert.True(t, client.initialized, "Client should be marked as initialized after notifications/initialized") - - // Verify the response code is 202 Accepted - assert.Equal(t, http.StatusAccepted, w.Code, "Response status should be 202 Accepted") - - // No actual response body should be sent for notifications - select { - case <-client.channel: - t.Fatal("No response expected for notifications") - case <-time.After(50 * time.Millisecond): - // This is the expected outcome - no response - } - }) - - t.Run("initializedClient", func(t *testing.T) { - // Create an already initialized client - client := addTestClient(mock.server, "test-client-initialized", true) - assert.True(t, client.initialized, "Client should start as initialized") - - // Directly call processNotificationInitialized - mock.server.processNotificationInitialized(client) - - // Verify client remains initialized - assert.True(t, client.initialized, "Client should remain initialized after notifications/initialized") - - // No response expected - select { - case <-client.channel: - t.Fatal("No response expected for notifications") - case <-time.After(50 * time.Millisecond): - // This is the expected outcome - no response - } - }) - - t.Run("errorOnIncorrectUsage", func(t *testing.T) { - // Create a test client - client := addTestClient(mock.server, "test-client-error", false) - - // Create a request with ID (incorrect usage - should be a notification) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 123, // Adding ID makes this an incorrect usage - should be notification - Method: methodNotificationsInitialized, - Params: json.RawMessage(`{}`), - } - - // Process through the request handler - jsonBody, _ := json.Marshal(req) - r := httptest.NewRequest(http.MethodPost, "/?session_id="+client.id, bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - mock.server.handleRequest(w, r) - - // Should get an error response - select { - case response := <-client.channel: - assert.Contains(t, response, "error", "Response should contain an error") - assert.Contains(t, response, "Method should be used as a notification", "Response should explain notification usage") - assert.Contains(t, response, `"id":123`, "Response should include the original ID") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for error response") - } - - // Client should not be initialized due to error - assert.False(t, client.initialized, "Client should not be initialized after error") - }) -} - -func TestSendResponse(t *testing.T) { - t.Run("bad response", func(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Create a test client - client := addTestClient(mock.server, "test-client", true) - - // Create a response - response := Response{ - JsonRpc: jsonRpcVersion, - ID: 1, - Result: make(chan int), - } - - // Send the response - mock.server.sendResponse(context.Background(), client, 1, response) - - // Check the response in the client's channel - select { - case res := <-client.channel: - evt, err := parseEvent(res) - require.NoError(t, err, "Should parse event without error") - errMsg, ok := evt.Data["error"].(map[string]any) - require.True(t, ok, "Should have error in response") - assert.Equal(t, float64(errCodeInternalError), errMsg["code"], "Error code should match") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for response") - } - }) - - t.Run("channel full", func(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Create a test client - client := addTestClient(mock.server, "test-client", true) - for i := 0; i < eventChanSize; i++ { - client.channel <- "test" - } - - // Create a response - response := Response{ - JsonRpc: jsonRpcVersion, - ID: 1, - Result: "foo", - } - - buf := logtest.NewCollector(t) - // Send the response - mock.server.sendResponse(context.Background(), client, 1, response) - // Check the response in the client's channel - assert.Contains(t, buf.String(), "channel is full") - }) -} - -func TestSendErrorResponse(t *testing.T) { - t.Run("channel full", func(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Create a test client - client := addTestClient(mock.server, "test-client", true) - for i := 0; i < eventChanSize; i++ { - client.channel <- "test" - } - - buf := logtest.NewCollector(t) - // Send the response - mock.server.sendErrorResponse(context.Background(), client, 1, "foo", errCodeInternalError) - // Check the response in the client's channel - assert.Contains(t, buf.String(), "channel is full") - }) -} - -// TestMethodToolsCall tests the handling of tools/call method through handleRequest -func TestMethodToolsCall(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - t.Run("validToolCall", func(t *testing.T) { - // Register a test tool - mock.registerExampleTool() - - // Create an initialized client - client := addTestClient(mock.server, "test-client-valid", true) - - // Create a tools call request with progress token metadata - params := struct { - Name string `json:"name"` - Arguments map[string]any `json:"arguments"` - Meta struct { - ProgressToken string `json:"progressToken"` - } `json:"_meta"` - }{ - Name: "test.tool", - Arguments: map[string]any{ - "input": "test-input", - }, - Meta: struct { - ProgressToken string `json:"progressToken"` - }{ - ProgressToken: "token123", - }, - } - - paramBytes, err := json.Marshal(params) - require.NoError(t, err, "Failed to marshal tool call parameters") - - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 42, // Specific ID to verify in response - Method: methodToolsCall, - Params: paramBytes, - } - jsonBody, _ := json.Marshal(req) - - // Create HTTP request - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-valid", bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - - // Process through handleRequest (full path) - mock.server.handleRequest(w, r) - - // Verify the HTTP response - assert.Equal(t, http.StatusAccepted, w.Code, "HTTP status should be 202 Accepted") - - // Check the response in client's channel - select { - case response := <-client.channel: - // Verify it's a message event with the expected format - assert.Contains(t, response, "event: message", "Response should be a message event") - - // Parse the JSON part of the SSE message - jsonStart := strings.Index(response, "{") - jsonEnd := strings.LastIndex(response, "}") - require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") - jsonStr := response[jsonStart : jsonEnd+1] - - var parsed map[string]any - err := json.Unmarshal([]byte(jsonStr), &parsed) - require.NoError(t, err, "Should be able to parse response JSON") - - // Validate the structure - result, ok := parsed["result"].(map[string]any) - require.True(t, ok, "Response should have a result object") - - // Verify ID matches our request - id, ok := parsed["id"].(float64) - assert.True(t, ok, "Response should have an ID") - assert.Equal(t, float64(42), id, "Response ID should match request ID") - - // Verify content - content, ok := result["content"].([]any) - require.True(t, ok, "Result should have content array") - assert.NotEmpty(t, content, "Content should not be empty") - - // Check for progress token in metadata - meta, hasMeta := result["_meta"].(map[string]any) - assert.True(t, hasMeta, "Response should include _meta with progress token") - if hasMeta { - token, hasToken := meta["progressToken"].(string) - assert.True(t, hasToken, "Meta should include progress token") - assert.Equal(t, "token123", token, "Progress token should match") - } - - // Check actual result content - if len(content) > 0 { - firstItem, ok := content[0].(map[string]any) - if ok { - assert.Contains(t, firstItem[ContentTypeText], "Processed: test-input", "Content should include processed input") - } - } - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for tool call response") - } - }) - - t.Run("invalidToolName", func(t *testing.T) { - // Create an initialized client - client := addTestClient(mock.server, "test-client-invalid", true) - - // Create a tools call request with invalid tool name - params := struct { - Name string `json:"name"` - Arguments map[string]any `json:"arguments"` - }{ - Name: "non-existent-tool", - Arguments: map[string]any{}, - } - - paramBytes, err := json.Marshal(params) - require.NoError(t, err, "Failed to marshal tool call parameters") - - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 43, - Method: methodToolsCall, - Params: paramBytes, - } - jsonBody, _ := json.Marshal(req) - - // Create HTTP request - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-invalid", bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - - // Process through handleRequest - mock.server.handleRequest(w, r) - - // Verify response contains error about non-existent tool - select { - case response := <-client.channel: - assert.Contains(t, response, "error", "Response should contain error") - assert.Contains(t, response, "not found", "Error should mention tool not found") - assert.Contains(t, response, "non-existent-tool", "Error should mention the invalid tool name") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for error response") - } - }) - - t.Run("clientNotInitialized", func(t *testing.T) { - // Register a tool - mock.registerExampleTool() - - // Create an uninitialized client - client := addTestClient(mock.server, "test-client-uninitialized", false) - - // Create a valid tools call request - params := struct { - Name string `json:"name"` - Arguments map[string]any `json:"arguments"` - }{ - Name: "test.tool", - Arguments: map[string]any{ - "input": "test-input", - }, - } - - paramBytes, err := json.Marshal(params) - require.NoError(t, err, "Failed to marshal tool call parameters") - - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 44, - Method: methodToolsCall, - Params: paramBytes, - } - jsonBody, _ := json.Marshal(req) - - // Create HTTP request - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-uninitialized", bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - - // Process through handleRequest - mock.server.handleRequest(w, r) - - // Verify response contains error about client not being initialized - select { - case response := <-client.channel: - assert.Contains(t, response, "error", "Response should contain error") - assert.Contains(t, response, "not fully initialized", "Error should mention client not initialized") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for error response") - } - }) -} - -// TestMethodPromptsGet tests the handling of prompts/get method through handleRequest -func TestMethodPromptsGet(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - t.Run("staticPrompt", func(t *testing.T) { - // Register a test prompt with static content - testPrompt := Prompt{ - Name: "static-prompt", - Description: "A static test prompt with placeholders", - Arguments: []PromptArgument{ - { - Name: "name", - Description: "Name to use in greeting", - Required: true, - }, - { - Name: "topic", - Description: "Topic to discuss", - }, - }, - Content: "Hello {{name}}! Let's talk about {{topic}}.", - } - mock.server.RegisterPrompt(testPrompt) - - // Create an initialized client - client := addTestClient(mock.server, "test-client-static", true) - - // Create a prompts/get request - params := struct { - Name string `json:"name"` - Arguments map[string]string `json:"arguments"` - }{ - Name: "static-prompt", - Arguments: map[string]string{ - "name": "Test User", - // Intentionally not providing "topic" to test default values - }, - } - - paramBytes, err := json.Marshal(params) - require.NoError(t, err, "Failed to marshal prompt get parameters") - - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 70, - Method: methodPromptsGet, - Params: paramBytes, - } - jsonBody, _ := json.Marshal(req) - - // Create HTTP request - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-static", bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - - // Process through handleRequest (full path) - mock.server.handleRequest(w, r) - - // Verify the HTTP response - assert.Equal(t, http.StatusAccepted, w.Code, "HTTP status should be 202 Accepted") - - // Check the response in client's channel - select { - case response := <-client.channel: - // Verify it's a message event with the expected format - assert.Contains(t, response, "event: message", "Response should be a message event") - - // Parse the JSON part of the SSE message - jsonStart := strings.Index(response, "{") - jsonEnd := strings.LastIndex(response, "}") - require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") - jsonStr := response[jsonStart : jsonEnd+1] - - var parsed map[string]any - err := json.Unmarshal([]byte(jsonStr), &parsed) - require.NoError(t, err, "Should be able to parse response JSON") - - // Validate the structure - result, ok := parsed["result"].(map[string]any) - require.True(t, ok, "Response should have a result object") - - // Verify ID matches our request - id, ok := parsed["id"].(float64) - assert.True(t, ok, "Response should have an ID") - assert.Equal(t, float64(70), id, "Response ID should match request ID") - - // Verify description - description, ok := result["description"].(string) - assert.True(t, ok, "Response should include prompt description") - assert.Equal(t, "A static test prompt with placeholders", description, "Description should match") - - // Verify messages - messages, ok := result["messages"].([]any) - require.True(t, ok, "Result should have messages array") - assert.Len(t, messages, 1, "Should have 1 message") - - // Check message content - should have placeholder substitutions - if len(messages) > 0 { - message, ok := messages[0].(map[string]any) - require.True(t, ok, "Message should be an object") - assert.Equal(t, string(RoleUser), message["role"], "Role should be 'user'") - - content, ok := message["content"].(map[string]any) - require.True(t, ok, "Should have content object") - assert.Equal(t, ContentTypeText, content["type"], "Content type should be text") - assert.Contains(t, content[ContentTypeText], "Hello Test User", "Content should include the name argument") - } - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for prompt get response") - } - }) - - t.Run("dynamicPrompt", func(t *testing.T) { - // Register a test prompt with a handler function - testPrompt := Prompt{ - Name: "dynamic-prompt", - Description: "A dynamic test prompt with a handler", - Arguments: []PromptArgument{ - { - Name: "username", - Description: "User's name", - Required: true, - }, - { - Name: "question", - Description: "User's question", - }, - }, - Handler: func(ctx context.Context, args map[string]string) ([]PromptMessage, error) { - username := args["username"] - question := args["question"] - - // Create a system message - systemMessage := PromptMessage{ - Role: RoleAssistant, - Content: TextContent{ - Text: "You are a helpful assistant.", - }, - } - - // Create a user message - userMessage := PromptMessage{ - Role: RoleUser, - Content: TextContent{ - Text: fmt.Sprintf("Hi, I'm %s and I'm wondering: %s", username, question), - }, - } - - return []PromptMessage{systemMessage, userMessage}, nil - }, - } - mock.server.RegisterPrompt(testPrompt) - - // Create an initialized client - client := addTestClient(mock.server, "test-client-dynamic", true) - - // Create a prompts/get request - params := struct { - Name string `json:"name"` - Arguments map[string]string `json:"arguments"` - }{ - Name: "dynamic-prompt", - Arguments: map[string]string{ - "username": "Dynamic User", - "question": "How to test this?", - }, - } - - paramBytes, err := json.Marshal(params) - require.NoError(t, err, "Failed to marshal prompt get parameters") - - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 71, - Method: methodPromptsGet, - Params: paramBytes, - } - jsonBody, _ := json.Marshal(req) - - // Create HTTP request - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-dynamic", bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - - // Process through handleRequest - mock.server.handleRequest(w, r) - - // Check the response - select { - case response := <-client.channel: - // Extract and parse JSON - jsonStart := strings.Index(response, "{") - jsonEnd := strings.LastIndex(response, "}") - jsonStr := response[jsonStart : jsonEnd+1] - - var parsed map[string]any - err := json.Unmarshal([]byte(jsonStr), &parsed) - require.NoError(t, err, "Should be able to parse response JSON") - - result, ok := parsed["result"].(map[string]any) - require.True(t, ok, "Response should have a result object") - - // Verify messages - should have 2 messages from handler - messages, ok := result["messages"].([]any) - require.True(t, ok, "Result should have messages array") - assert.Len(t, messages, 2, "Should have 2 messages") - - // Check message content - if len(messages) >= 2 { - // First message should be assistant - message1, _ := messages[0].(map[string]any) - assert.Equal(t, string(RoleAssistant), message1["role"], "First role should be 'system'") - - content1, _ := message1["content"].(map[string]any) - assert.Contains(t, content1[ContentTypeText], "helpful assistant", "System message should be correct") - - // Second message should be user - message2, _ := messages[1].(map[string]any) - assert.Equal(t, string(RoleUser), message2["role"], "Second role should be 'user'") - - content2, _ := message2["content"].(map[string]any) - assert.Contains(t, content2[ContentTypeText], "Dynamic User", "User message should contain username") - assert.Contains(t, content2[ContentTypeText], "How to test this?", "User message should contain question") - } - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for prompt get response") - } - }) - - t.Run("missingRequiredArgument", func(t *testing.T) { - // Register a test prompt with a required argument - testPrompt := Prompt{ - Name: "required-arg-prompt", - Description: "A prompt with required arguments", - Arguments: []PromptArgument{ - { - Name: "required_arg", - Description: "This argument is required", - Required: true, - }, - }, - } - mock.server.RegisterPrompt(testPrompt) - - // Create an initialized client - client := addTestClient(mock.server, "test-client-missing-arg", true) - - // Create a prompts/get request with missing required argument - params := struct { - Name string `json:"name"` - Arguments map[string]string `json:"arguments"` - }{ - Name: "required-arg-prompt", - Arguments: map[string]string{}, // Empty arguments - } - - paramBytes, _ := json.Marshal(params) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 72, - Method: methodPromptsGet, - Params: paramBytes, - } - jsonBody, _ := json.Marshal(req) - - // Create HTTP request - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-missing-arg", bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - - // Process through handleRequest - mock.server.handleRequest(w, r) - - // Check for error response about missing required argument - select { - case response := <-client.channel: - assert.Contains(t, response, "error", "Response should contain error") - assert.Contains(t, response, "Missing required arguments", "Error should mention missing arguments") - assert.Contains(t, response, "required_arg", "Error should name the missing argument") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for error response") - } - }) - - t.Run("promptNotFound", func(t *testing.T) { - // Create an initialized client - client := addTestClient(mock.server, "test-client-prompt-not-found", true) - - // Create a prompts/get request with non-existent prompt - params := struct { - Name string `json:"name"` - Arguments map[string]string `json:"arguments"` - }{ - Name: "non-existent-prompt", - Arguments: map[string]string{}, - } - - paramBytes, _ := json.Marshal(params) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 73, - Method: methodPromptsGet, - Params: paramBytes, - } - jsonBody, _ := json.Marshal(req) - - // Create HTTP request - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-prompt-not-found", bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - - // Process through handleRequest - mock.server.handleRequest(w, r) - - // Check for error response about non-existent prompt - select { - case response := <-client.channel: - assert.Contains(t, response, "error", "Response should contain error") - assert.Contains(t, response, "Prompt 'non-existent-prompt' not found", "Error should mention prompt not found") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for error response") - } - }) - - t.Run("handlerError", func(t *testing.T) { - // Register a test prompt with a handler that returns an error - testPrompt := Prompt{ - Name: "error-handler-prompt", - Description: "A prompt with a handler that returns an error", - Arguments: []PromptArgument{}, - Handler: func(ctx context.Context, args map[string]string) ([]PromptMessage, error) { - return nil, fmt.Errorf("test handler error") - }, - } - mock.server.RegisterPrompt(testPrompt) - - // Create an initialized client - client := addTestClient(mock.server, "test-client-handler-error", true) - - // Create a prompts/get request - params := struct { - Name string `json:"name"` - Arguments map[string]string `json:"arguments"` - }{ - Name: "error-handler-prompt", - Arguments: map[string]string{}, - } - - paramBytes, _ := json.Marshal(params) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 74, - Method: methodPromptsGet, - Params: paramBytes, - } - jsonBody, _ := json.Marshal(req) - - // Create HTTP request - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-handler-error", bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - - // Process through handleRequest - mock.server.handleRequest(w, r) - - // Check for error response about handler error - select { - case response := <-client.channel: - assert.Contains(t, response, "error", "Response should contain error") - assert.Contains(t, response, "Error generating prompt content", "Error should mention generating content") - assert.Contains(t, response, "test handler error", "Error should include the handler error message") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for error response") - } - }) - - t.Run("invalidParameters", func(t *testing.T) { - // Create an invalid JSON request - invalidJson := []byte(`{"not valid json`) - - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 75, - Method: methodPromptsGet, - Params: invalidJson, - } - jsonBody, _ := json.Marshal(req) - - // Create HTTP request - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-invalid-params", bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - - // Process through handleRequest - mock.server.handleRequest(w, r) - assert.Equal(t, http.StatusBadRequest, w.Code) - }) - - t.Run("clientNotInitialized", func(t *testing.T) { - // Register a basic prompt - testPrompt := Prompt{ - Name: "basic-prompt", - Description: "A basic test prompt", - } - mock.server.RegisterPrompt(testPrompt) - - // Create an uninitialized client - client := addTestClient(mock.server, "test-client-uninit", false) - - // Create a valid prompts/get request - params := struct { - Name string `json:"name"` - Arguments map[string]string `json:"arguments"` - }{ - Name: "basic-prompt", - Arguments: map[string]string{}, - } - - paramBytes, _ := json.Marshal(params) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 76, - Method: methodPromptsGet, - Params: paramBytes, - } - jsonBody, _ := json.Marshal(req) - - // Create HTTP request - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-uninit", bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - - // Process through handleRequest - mock.server.handleRequest(w, r) - - // Verify response contains error about client not being initialized - select { - case response := <-client.channel: - assert.Contains(t, response, "error", "Response should contain error") - assert.Contains(t, response, "not fully initialized", "Error should mention client not initialized") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for error response") - } - }) -} - -func TestMethodResourcesList(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - t.Run("validResourceWithHandler", func(t *testing.T) { - // Register a test resource with handler - testResource := Resource{ - Name: "test-resource", - URI: "file:///test/resource.txt", - Description: "A test resource with handler", - MimeType: "text/plain", - Handler: func(ctx context.Context) (ResourceContent, error) { - return ResourceContent{ - URI: "file:///test/resource.txt", - MimeType: "text/plain", - Text: "This is test resource content", - }, nil - }, - } - mock.server.RegisterResource(testResource) - - // Create an initialized client - client := addTestClient(mock.server, "test-client-resources", true) - - // Create a resources/read request - params := PaginatedParams{ - Cursor: "next-cursor", - Meta: struct { - ProgressToken any `json:"progressToken"` - }{ - ProgressToken: "token", - }, - } - - paramBytes, err := json.Marshal(params) - require.NoError(t, err, "Failed to marshal resource read parameters") - - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 50, - Method: methodResourcesList, - Params: paramBytes, - } - jsonBody, _ := json.Marshal(req) - - // Create HTTP request - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-resources", bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - - // Process through handleRequest (full path) - mock.server.handleRequest(w, r) - - // Verify the HTTP response - assert.Equal(t, http.StatusAccepted, w.Code, "HTTP status should be 202 Accepted") - - // Check the response in client's channel - select { - case response := <-client.channel: - evt, err := parseEvent(response) - assert.NoError(t, err) - result, ok := evt.Data["result"].(map[string]any) - assert.True(t, ok) - assert.Equal(t, "next-cursor", result["nextCursor"]) - assert.Equal(t, "token", result["_meta"].(map[string]any)["progressToken"]) - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for resource read response") - } - }) -} - -// TestMethodResourcesRead tests the handling of resources/read method -func TestMethodResourcesRead(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - t.Run("validResourceWithHandler", func(t *testing.T) { - // Register a test resource with handler - testResource := Resource{ - Name: "test-resource", - URI: "file:///test/resource.txt", - Description: "A test resource with handler", - MimeType: "text/plain", - Handler: func(ctx context.Context) (ResourceContent, error) { - return ResourceContent{ - URI: "file:///test/resource.txt", - MimeType: "text/plain", - Text: "This is test resource content", - }, nil - }, - } - mock.server.RegisterResource(testResource) - - // Create an initialized client - client := addTestClient(mock.server, "test-client-resources", true) - - // Create a resources/read request - params := ResourceReadParams{ - URI: "file:///test/resource.txt", - } - - paramBytes, err := json.Marshal(params) - require.NoError(t, err, "Failed to marshal resource read parameters") - - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 50, - Method: methodResourcesRead, - Params: paramBytes, - } - jsonBody, _ := json.Marshal(req) - - // Create HTTP request - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-resources", bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - - // Process through handleRequest (full path) - mock.server.handleRequest(w, r) - - // Verify the HTTP response - assert.Equal(t, http.StatusAccepted, w.Code, "HTTP status should be 202 Accepted") - - // Check the response in client's channel - select { - case response := <-client.channel: - // Verify it's a message event with the expected format - assert.Contains(t, response, "event: message", "Response should be a message event") - - // Parse the JSON part of the SSE message - jsonStart := strings.Index(response, "{") - jsonEnd := strings.LastIndex(response, "}") - require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") - jsonStr := response[jsonStart : jsonEnd+1] - - var parsed map[string]any - err := json.Unmarshal([]byte(jsonStr), &parsed) - require.NoError(t, err, "Should be able to parse response JSON") - - // Validate the structure - result, ok := parsed["result"].(map[string]any) - require.True(t, ok, "Response should have a result object") - - // Verify ID matches our request - id, ok := parsed["id"].(float64) - assert.True(t, ok, "Response should have an ID") - assert.Equal(t, float64(50), id, "Response ID should match request ID") - - // Verify contents - contents, ok := result["contents"].([]any) - require.True(t, ok, "Result should have contents array") - assert.Len(t, contents, 1, "Contents array should have 1 item") - - // Check content details - if len(contents) > 0 { - content, ok := contents[0].(map[string]any) - require.True(t, ok, "Content should be an object") - assert.Equal(t, "file:///test/resource.txt", content["uri"], "URI should match") - assert.Equal(t, "text/plain", content["mimeType"], "MimeType should match") - assert.Equal(t, "This is test resource content", content[ContentTypeText], "Text content should match") - } - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for resource read response") - } - }) - - t.Run("resourceWithoutHandler", func(t *testing.T) { - // Register a test resource without handler - testResource := Resource{ - Name: "no-handler-resource", - URI: "file:///test/no-handler.txt", - Description: "A test resource without handler", - MimeType: "text/plain", - // No handler provided - } - mock.server.RegisterResource(testResource) - - // Create an initialized client - client := addTestClient(mock.server, "test-client-no-handler", true) - - // Create a resources/read request - params := ResourceReadParams{ - URI: "file:///test/no-handler.txt", - } - - paramBytes, _ := json.Marshal(params) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 51, - Method: methodResourcesRead, - Params: paramBytes, - } - jsonBody, _ := json.Marshal(req) - - // Create HTTP request - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-no-handler", bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - - // Process through handleRequest - mock.server.handleRequest(w, r) - - // Check for response with empty content - select { - case response := <-client.channel: - assert.Contains(t, response, "event: message", "Response should be a message event") - - // Extract and parse JSON - jsonStart := strings.Index(response, "{") - jsonEnd := strings.LastIndex(response, "}") - jsonStr := response[jsonStart : jsonEnd+1] - - var parsed map[string]any - err := json.Unmarshal([]byte(jsonStr), &parsed) - require.NoError(t, err, "Should be able to parse response JSON") - - // Check contents exists but has empty text - result, ok := parsed["result"].(map[string]any) - require.True(t, ok, "Response should have a result object") - - contents, ok := result["contents"].([]any) - require.True(t, ok, "Result should have contents array") - assert.Len(t, contents, 1, "Contents array should have 1 item") - - // Check content details - should have URI and MimeType but empty text - if len(contents) > 0 { - content, ok := contents[0].(map[string]any) - require.True(t, ok, "Content should be an object") - assert.Equal(t, "file:///test/no-handler.txt", content["uri"], "URI should match") - assert.Equal(t, "text/plain", content["mimeType"], "MimeType should match") - _, ok = content[ContentTypeText] - assert.False(t, ok, "Text content should be empty string") - } - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for resource read response") - } - }) - - t.Run("resourceNotFound", func(t *testing.T) { - // Create an initialized client - client := addTestClient(mock.server, "test-client-not-found", true) - - // Create a resources/read request with non-existent URI - params := ResourceReadParams{ - URI: "file:///test/non-existent.txt", - } - - paramBytes, _ := json.Marshal(params) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 52, - Method: methodResourcesRead, - Params: paramBytes, - } - jsonBody, _ := json.Marshal(req) - - // Create HTTP request - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-not-found", bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - - // Process through handleRequest - mock.server.handleRequest(w, r) - - // Check for error response about resource not found - select { - case response := <-client.channel: - assert.Contains(t, response, "error", "Response should contain error") - assert.Contains(t, response, "Resource with URI", "Error should mention resource URI") - assert.Contains(t, response, "not found", "Error should indicate resource not found") - assert.Contains(t, response, "file:///test/non-existent.txt", "Error should mention the requested URI") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for error response") - } - }) - - t.Run("invalidParameters", func(t *testing.T) { - // Create an invalid JSON request - invalidJson := []byte(`{"not valid json`) - - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 53, - Method: methodResourcesRead, - Params: invalidJson, - } - jsonBody, _ := json.Marshal(req) - - // Create HTTP request - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-invalid-params", bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - - // Process through handleRequest - mock.server.handleRequest(w, r) - assert.Equal(t, http.StatusBadRequest, w.Code) - }) - - t.Run("invalidParameters direct", func(t *testing.T) { - // Create an invalid JSON request - invalidJson := []byte(`{"not valid json`) - - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 53, - Method: methodResourcesRead, - Params: invalidJson, - } - - // Create an initialized client - client := addTestClient(mock.server, "test-client-resources", true) - - // Process through handleRequest - mock.server.processResourcesRead(context.Background(), client, req) - - select { - case response := <-client.channel: - assert.Contains(t, response, "error", "Response should contain error") - assert.Contains(t, response, "Invalid parameters", "Error should mention invalid parameters") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for error response") - } - }) - - t.Run("handlerError", func(t *testing.T) { - // Register a test resource with handler that returns an error - testResource := Resource{ - Name: "error-resource", - URI: "file:///test/error.txt", - Description: "A test resource with handler that returns error", - MimeType: "text/plain", - Handler: func(ctx context.Context) (ResourceContent, error) { - return ResourceContent{}, fmt.Errorf("test handler error") - }, - } - mock.server.RegisterResource(testResource) - - // Create an initialized client - client := addTestClient(mock.server, "test-client-handler-error", true) - - // Create a resources/read request - params := ResourceReadParams{ - URI: "file:///test/error.txt", - } - - paramBytes, _ := json.Marshal(params) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 54, - Method: methodResourcesRead, - Params: paramBytes, - } - jsonBody, _ := json.Marshal(req) - - // Create HTTP request - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-handler-error", bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - - // Process through handleRequest - mock.server.handleRequest(w, r) - - // Check for error response about handler error - select { - case response := <-client.channel: - assert.Contains(t, response, "error", "Response should contain error") - assert.Contains(t, response, "Error reading resource", "Error should mention reading resource") - assert.Contains(t, response, "test handler error", "Error should include handler error message") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for error response") - } - }) - - t.Run("handlerMissingURIAndMimeType", func(t *testing.T) { - // Register a test resource with handler that returns content without URI and MimeType - testResource := Resource{ - Name: "missing-fields-resource", - URI: "file:///test/missing-fields.txt", - Description: "A test resource with handler that returns content missing fields", - MimeType: "text/plain", - Handler: func(ctx context.Context) (ResourceContent, error) { - // Return ResourceContent without URI and MimeType - return ResourceContent{ - Text: "Content with missing fields", - }, nil - }, - } - mock.server.RegisterResource(testResource) - - // Create an initialized client - client := addTestClient(mock.server, "test-client-missing-fields", true) - - // Create a resources/read request - params := ResourceReadParams{ - URI: "file:///test/missing-fields.txt", - } - - paramBytes, _ := json.Marshal(params) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 55, - Method: methodResourcesRead, - Params: paramBytes, - } - jsonBody, _ := json.Marshal(req) - - // Create HTTP request - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-missing-fields", bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - - // Process through handleRequest - mock.server.handleRequest(w, r) - - // Check response - server should fill in the missing fields - select { - case response := <-client.channel: - assert.Contains(t, response, "event: message", "Response should be a message event") - - // Extract and parse JSON - jsonStart := strings.Index(response, "{") - jsonEnd := strings.LastIndex(response, "}") - jsonStr := response[jsonStart : jsonEnd+1] - - var parsed map[string]any - err := json.Unmarshal([]byte(jsonStr), &parsed) - require.NoError(t, err, "Should be able to parse response JSON") - - result, ok := parsed["result"].(map[string]any) - require.True(t, ok, "Response should have a result object") - - contents, ok := result["contents"].([]any) - require.True(t, ok, "Result should have contents array") - assert.Len(t, contents, 1, "Contents array should have 1 item") - - // Check content details - server should fill in missing URI and MimeType - if len(contents) > 0 { - content, ok := contents[0].(map[string]any) - require.True(t, ok, "Content should be an object") - assert.Equal(t, "file:///test/missing-fields.txt", content["uri"], "URI should be filled from request") - assert.Equal(t, "text/plain", content["mimeType"], "MimeType should be filled from resource") - assert.Equal(t, "Content with missing fields", content[ContentTypeText], "Text content should match") - } - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for resource read response") - } - }) -} - -// TestMethodResourcesSubscribe tests the handling of resources/subscribe method -func TestMethodResourcesSubscribe(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - t.Run("validSubscription", func(t *testing.T) { - // Register a test resource - testResource := Resource{ - Name: "subscribe-resource", - URI: "file:///test/subscribe.txt", - Description: "A test resource for subscription", - MimeType: "text/plain", - } - mock.server.RegisterResource(testResource) - - // Create an initialized client - client := addTestClient(mock.server, "test-client-subscribe", true) - - // Create a resources/subscribe request - params := ResourceSubscribeParams{ - URI: "file:///test/subscribe.txt", - } - - paramBytes, err := json.Marshal(params) - require.NoError(t, err, "Failed to marshal resource subscribe parameters") - - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 60, - Method: methodResourcesSubscribe, - Params: paramBytes, - } - jsonBody, _ := json.Marshal(req) - - // Create HTTP request - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-subscribe", bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - - // Process through handleRequest (full path) - mock.server.handleRequest(w, r) - - // Verify the HTTP response - assert.Equal(t, http.StatusAccepted, w.Code, "HTTP status should be 202 Accepted") - - // Check the response in client's channel - should be an empty success response - select { - case response := <-client.channel: - // Verify it's a message event with the expected format - assert.Contains(t, response, "event: message", "Response should be a message event") - - // Parse the JSON part of the SSE message - jsonStart := strings.Index(response, "{") - jsonEnd := strings.LastIndex(response, "}") - require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") - jsonStr := response[jsonStart : jsonEnd+1] - - var parsed map[string]any - err := json.Unmarshal([]byte(jsonStr), &parsed) - require.NoError(t, err, "Should be able to parse response JSON") - - // Verify ID matches our request - id, ok := parsed["id"].(float64) - assert.True(t, ok, "Response should have an ID") - assert.Equal(t, float64(60), id, "Response ID should match request ID") - - // Verify the result exists and is an empty object - result, ok := parsed["result"].(map[string]any) - require.True(t, ok, "Response should have a result object") - assert.Empty(t, result, "Result should be an empty object for successful subscription") - - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for subscription response") - } - }) - - t.Run("resourceNotFound", func(t *testing.T) { - // Create an initialized client - client := addTestClient(mock.server, "test-client-sub-not-found", true) - - // Create a resources/subscribe request with non-existent URI - params := ResourceSubscribeParams{ - URI: "file:///test/non-existent-subscription.txt", - } - - paramBytes, _ := json.Marshal(params) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 61, - Method: methodResourcesSubscribe, - Params: paramBytes, - } - jsonBody, _ := json.Marshal(req) - - // Create HTTP request - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-sub-not-found", bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - - // Process through handleRequest - mock.server.handleRequest(w, r) - - // Check for error response about resource not found - select { - case response := <-client.channel: - assert.Contains(t, response, "error", "Response should contain error") - assert.Contains(t, response, "Resource with URI", "Error should mention resource URI") - assert.Contains(t, response, "not found", "Error should indicate resource not found") - assert.Contains(t, response, "file:///test/non-existent-subscription.txt", "Error should mention the requested URI") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for error response") - } - }) - - t.Run("invalidParameters", func(t *testing.T) { - // Create an invalid JSON request - invalidJson := []byte(`{"not valid json`) - - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 62, - Method: methodResourcesSubscribe, - Params: invalidJson, - } - jsonBody, _ := json.Marshal(req) - - // Create HTTP request - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-sub-invalid-params", bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - - // Process through handleRequest - mock.server.handleRequest(w, r) - assert.Equal(t, http.StatusBadRequest, w.Code, "HTTP status should be 400 Bad Request") - }) - - t.Run("invalidParameters direct", func(t *testing.T) { - // Create an invalid JSON request - invalidJson := []byte(`{"not valid json`) - - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 62, - Method: methodResourcesSubscribe, - Params: invalidJson, - } - - client := addTestClient(mock.server, "test-client-sub-not-found", true) - mock.server.processResourceSubscribe(context.Background(), client, req) - - select { - case response := <-client.channel: - assert.Contains(t, response, "error", "Response should contain error") - assert.Contains(t, response, "Invalid parameters", "Error should mention invalid parameters") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for error response") - } - }) - - t.Run("clientNotInitialized", func(t *testing.T) { - // Register a test resource - testResource := Resource{ - Name: "subscribe-resource-uninit", - URI: "file:///test/subscribe-uninit.txt", - Description: "A test resource for subscription with uninitialized client", - } - mock.server.RegisterResource(testResource) - - // Create an uninitialized client - client := addTestClient(mock.server, "test-client-sub-uninitialized", false) - - // Create a valid resources/subscribe request - params := ResourceSubscribeParams{ - URI: "file:///test/subscribe-uninit.txt", - } - - paramBytes, _ := json.Marshal(params) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 63, - Method: methodResourcesSubscribe, - Params: paramBytes, - } - jsonBody, _ := json.Marshal(req) - - // Create HTTP request - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-sub-uninitialized", bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - - // Process through handleRequest - mock.server.handleRequest(w, r) - - // Verify response contains error about client not being initialized - select { - case response := <-client.channel: - assert.Contains(t, response, "error", "Response should contain error") - assert.Contains(t, response, "not fully initialized", "Error should mention client not initialized") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for error response") - } - }) - - t.Run("missingURIParameter", func(t *testing.T) { - // Create an initialized client - client := addTestClient(mock.server, "test-client-sub-missing-uri", true) - - // Create a subscription request with empty URI - params := ResourceSubscribeParams{ - URI: "", // Empty URI - } - - paramBytes, _ := json.Marshal(params) - req := Request{ - JsonRpc: jsonRpcVersion, - ID: 64, - Method: methodResourcesSubscribe, - Params: paramBytes, - } - jsonBody, _ := json.Marshal(req) - - // Create HTTP request - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-sub-missing-uri", - bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - - // Process through handleRequest - mock.server.handleRequest(w, r) - - // Check for error response about resource not found (empty URI) - select { - case response := <-client.channel: - assert.Contains(t, response, "error", "Response should contain error") - assert.Contains(t, response, "Resource with URI", "Error should mention resource URI") - assert.Contains(t, response, "not found", "Error should indicate resource not found") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for error response") - } - }) -} - -// TestToolCallUnmarshalError tests the error handling when unmarshaling invalid JSON in processToolCall -func TestToolCallUnmarshalError(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Create an initialized client - client := addTestClient(mock.server, "test-client-unmarshal-error", true) - - // Create a request with invalid JSON in Params - req := Request{ - JsonRpc: "2.0", - ID: 100, - Method: methodToolsCall, - Params: []byte(`{"name": "test.tool", "arguments": {"input": invalid_json}}`), // This is invalid JSON - } - - // Process the tool call directly - mock.server.processToolCall(context.Background(), client, req) - - // Check for error response about invalid JSON - select { - case response := <-client.channel: - assert.Contains(t, response, "error", "Response should contain an error") - assert.Contains(t, response, "Invalid tool call parameters", "Error should mention invalid parameters") - - // Extract error code from response - jsonStart := strings.Index(response, "{") - jsonEnd := strings.LastIndex(response, "}") - require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") - jsonStr := response[jsonStart : jsonEnd+1] - - var parsed struct { - Error struct { - Code int `json:"code"` - } `json:"error"` - } - err := json.Unmarshal([]byte(jsonStr), &parsed) - require.NoError(t, err, "Should be able to parse response JSON") - - // Verify correct error code was returned - assert.Equal(t, errCodeInvalidParams, parsed.Error.Code, "Error code should be errCodeInvalidParams") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for error response") - } -} - -// TestToolCallWithInvalidParams tests the handling when calling handleRequest with invalid JSON params -func TestToolCallWithInvalidParams(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Register a tool to make sure it exists - mock.registerExampleTool() - - // Create a request with invalid JSON - req := Request{ - JsonRpc: "2.0", - ID: 101, - Method: methodToolsCall, - Params: []byte(`{"name": "test.tool", "arguments": {this_is_invalid_json}}`), - } - - jsonBody, _ := json.Marshal(req) - - // Create HTTP request - r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-invalid-json", bytes.NewReader(jsonBody)) - w := httptest.NewRecorder() - - // Process through handleRequest - mock.server.handleRequest(w, r) - - // Verify HTTP status is Accepted (even for errors, we accept the request) - assert.Equal(t, http.StatusBadRequest, w.Code) -} - -type mockResponseWriter struct { -} - -func (m *mockResponseWriter) Header() http.Header { - return http.Header{} -} - -func (m *mockResponseWriter) Write(i []byte) (int, error) { - return len(i), nil -} - -func (m *mockResponseWriter) WriteHeader(_ int) { -} - -type notFlusherResponseWriter struct { - mockResponseWriter - code int -} - -func (m *notFlusherResponseWriter) WriteHeader(code int) { - m.code = code -} - -type cantWriteResponseWriter struct { - mockResponseWriter - code int -} - -func (m *cantWriteResponseWriter) Flush() { -} - -func (m *cantWriteResponseWriter) Write(_ []byte) (int, error) { - return 0, fmt.Errorf("can't write") -} - -type writeOnceResponseWriter struct { - mockResponseWriter - times int32 -} - -func (m *writeOnceResponseWriter) Flush() { -} - -func (m *writeOnceResponseWriter) Write(i []byte) (int, error) { - if atomic.AddInt32(&m.times, 1) > 1 { - return 0, fmt.Errorf("write once") - } - return len(i), nil -} - -// testResponseWriter is a custom http.ResponseWriter that captures writes and detects ping messages -type testResponseWriter struct { - *httptest.ResponseRecorder - writes []string - mu sync.Mutex - pingDetected bool - done chan struct{} -} - -// Write overrides the ResponseRecorder's Write method to detect ping messages -func (w *testResponseWriter) Write(b []byte) (int, error) { - w.mu.Lock() - defer w.mu.Unlock() - - written, err := w.ResponseRecorder.Write(b) - if err != nil { - return written, err - } - - content := string(b) - w.writes = append(w.writes, content) - - // Check if this is a ping message - if strings.Contains(content, "event: ping") { - w.pingDetected = true - // Signal that we've detected a ping - select { - case w.done <- struct{}{}: - default: - // Channel might be closed or already signaled +func TestServerStartStreamable(t *testing.T) { + // Test with Streamable transport + c := McpConf{} + c.Host = "127.0.0.1" + c.Port = 19081 + c.Mcp.Name = "streamable-start-test" + c.Mcp.UseStreamable = true + c.Mcp.SseEndpoint = "/sse" + c.Mcp.MessageEndpoint = "/message" + c.Mcp.SseTimeout = 1 * time.Second + c.Mcp.MessageTimeout = 1 * time.Second + + server := NewMcpServer(c) + + // Start server in goroutine + go func() { + server.Start() + }() + + // Give server time to start + time.Sleep(300 * time.Millisecond) + + // Make a GET request first (SSE connection) + client := &http.Client{Timeout: 500 * time.Millisecond} + req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:19081/message", nil) + if err == nil { + req.Header.Set("Accept", "text/event-stream") + resp, err := client.Do(req) + if err == nil { + resp.Body.Close() + // GET request should work + assert.True(t, resp.StatusCode > 0) } } - return written, nil + // Also make a POST request (for message) + jsonData := []byte(`{"jsonrpc":"2.0","method":"ping","id":1}`) + req2, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:19081/message", bytes.NewBuffer(jsonData)) + if err == nil { + req2.Header.Set("Content-Type", "application/json") + resp, err := client.Do(req2) + if err == nil { + resp.Body.Close() + // POST request should also work + assert.True(t, resp.StatusCode > 0) + } + } + + // Stop the server + server.Stop() + + // Give it time to shutdown + time.Sleep(100 * time.Millisecond) } -// Flush implements the http.Flusher interface -func (w *testResponseWriter) Flush() { - w.ResponseRecorder.Flush() +func TestSSEHandlerCallback(t *testing.T) { + c := McpConf{} + c.Host = "127.0.0.1" + c.Port = 0 + c.Mcp.Name = "sse-handler-test" + c.Mcp.UseStreamable = false + c.Mcp.SseEndpoint = "/sse" + c.Mcp.MessageEndpoint = "/message" + + server := NewMcpServer(c) + impl := server.(*mcpServerImpl) + + // Verify the server is set up correctly + assert.NotNil(t, impl.mcpServer) + assert.False(t, impl.conf.Mcp.UseStreamable) +} + +func TestStreamableHandlerCallback(t *testing.T) { + c := McpConf{} + c.Host = "127.0.0.1" + c.Port = 0 + c.Mcp.Name = "streamable-handler-test" + c.Mcp.UseStreamable = true + c.Mcp.SseEndpoint = "/sse" + c.Mcp.MessageEndpoint = "/message" + + server := NewMcpServer(c) + impl := server.(*mcpServerImpl) + + // Verify the server is set up correctly + assert.NotNil(t, impl.mcpServer) + assert.True(t, impl.conf.Mcp.UseStreamable) +} + +func TestSSEEndpointAccess(t *testing.T) { + c := McpConf{} + c.Host = "127.0.0.1" + c.Port = 0 + c.Mcp.Name = "sse-endpoint-test" + c.Mcp.UseStreamable = false + c.Mcp.SseEndpoint = "/sse" + + server := NewMcpServer(c) + impl := server.(*mcpServerImpl) + + // Create a test request + req := httptest.NewRequest(http.MethodGet, "/sse", nil) + req.Header.Set("Accept", "text/event-stream") + + // The server should be configured with SSE endpoints + assert.NotNil(t, impl.httpServer) + assert.Equal(t, "/sse", impl.conf.Mcp.SseEndpoint) +} + +func TestStreamableEndpointAccess(t *testing.T) { + c := McpConf{} + c.Host = "127.0.0.1" + c.Port = 0 + c.Mcp.Name = "streamable-endpoint-test" + c.Mcp.UseStreamable = true + c.Mcp.MessageEndpoint = "/message" + + server := NewMcpServer(c) + impl := server.(*mcpServerImpl) + + // The server should be configured with streamable endpoints + assert.NotNil(t, impl.httpServer) + assert.Equal(t, "/message", impl.conf.Mcp.MessageEndpoint) +} + +func TestConfig(t *testing.T) { + var c McpConf + err := conf.FillDefault(&c) + assert.NoError(t, err) + assert.Equal(t, "1.0.0", c.Mcp.Version) + assert.Equal(t, "/sse", c.Mcp.SseEndpoint) + assert.Equal(t, "/message", c.Mcp.MessageEndpoint) } diff --git a/mcp/types.go b/mcp/types.go index 86119e1c6..7f4bff959 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -2,316 +2,96 @@ package mcp import ( "context" - "encoding/json" - "fmt" - "sync" - "github.com/zeromicro/go-zero/rest" + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" ) -// Cursor is an opaque token used for pagination -type Cursor string +// Re-export commonly used SDK types for convenience +type ( + // Tool types + Tool = sdkmcp.Tool + CallToolParams = sdkmcp.CallToolParams + CallToolResult = sdkmcp.CallToolResult + CallToolRequest = sdkmcp.CallToolRequest -// Request represents a generic MCP request following JSON-RPC 2.0 specification -type Request struct { - SessionId string `form:"session_id"` // Session identifier for client tracking - JsonRpc string `json:"jsonrpc"` // Must be "2.0" per JSON-RPC spec - ID any `json:"id"` // Request identifier for matching responses - Method string `json:"method"` // Method name to invoke - Params json.RawMessage `json:"params"` // Parameters for the method -} + // Content types + Content = sdkmcp.Content + TextContent = sdkmcp.TextContent + ImageContent = sdkmcp.ImageContent + AudioContent = sdkmcp.AudioContent -func (r Request) isNotification() (bool, error) { - switch val := r.ID.(type) { - case int: - return val == 0, nil - case int64: - return val == 0, nil - case float64: - return val == 0.0, nil - case string: - return len(val) == 0, nil - case nil: - return true, nil - default: - return false, fmt.Errorf("invalid type %T", val) + // Prompt types + Prompt = sdkmcp.Prompt + PromptMessage = sdkmcp.PromptMessage + GetPromptParams = sdkmcp.GetPromptParams + GetPromptResult = sdkmcp.GetPromptResult + + // Resource types + Resource = sdkmcp.Resource + ResourceContents = sdkmcp.ResourceContents + ReadResourceParams = sdkmcp.ReadResourceParams + ReadResourceResult = sdkmcp.ReadResourceResult + + // Session and server types + Server = sdkmcp.Server + ServerSession = sdkmcp.ServerSession + ServerOptions = sdkmcp.ServerOptions + Implementation = sdkmcp.Implementation + + // Transport types + SSEHandler = sdkmcp.SSEHandler + StreamableHTTPHandler = sdkmcp.StreamableHTTPHandler +) + +// ToolHandler is a generic function signature for tool handlers. +// Handlers should accept context, request, and typed arguments, and return +// a result, metadata, and error. +// +// Deprecated: Use ToolHandlerFor directly from the SDK types. +type ToolHandler[Args any, Meta any] func( + ctx context.Context, + req *CallToolRequest, + args Args, +) (*CallToolResult, Meta, error) + +// PromptHandler is a function signature for prompt handlers. +type PromptHandler func( + ctx context.Context, + req *sdkmcp.GetPromptRequest, + args map[string]string, +) (*GetPromptResult, error) + +// ResourceHandler is a function signature for resource handlers. +type ResourceHandler func( + ctx context.Context, + req *sdkmcp.ReadResourceRequest, + uri string, +) (*ReadResourceResult, error) + +// AddTool registers a tool with the MCP server using type-safe generics. +// The SDK automatically generates JSON schema from the Args struct tags. +// +// Example: +// +// type GreetArgs struct { +// Name string `json:"name" jsonschema:"description=Name to greet"` +// } +// +// tool := &mcp.Tool{ +// Name: "greet", +// Description: "Greet someone", +// } +// +// handler := func(ctx context.Context, req *mcp.CallToolRequest, args GreetArgs) (*mcp.CallToolResult, any, error) { +// return &mcp.CallToolResult{ +// Content: []mcp.Content{&mcp.TextContent{Text: "Hello " + args.Name}}, +// }, nil, nil +// } +// +// mcp.AddTool(server, tool, handler) +func AddTool[In, Out any](server McpServer, tool *Tool, handler func(context.Context, *CallToolRequest, In) (*CallToolResult, Out, error)) { + // Access internal server - only works with mcpServerImpl + if impl, ok := server.(*mcpServerImpl); ok { + sdkmcp.AddTool(impl.mcpServer, tool, handler) } } - -type PaginatedParams struct { - Cursor string `json:"cursor"` - Meta struct { - ProgressToken any `json:"progressToken"` - } `json:"_meta"` -} - -// Result is the base interface for all results -type Result struct { - Meta map[string]any `json:"_meta,omitempty"` // Optional metadata -} - -// PaginatedResult is a base for results that support pagination -type PaginatedResult struct { - Result - NextCursor Cursor `json:"nextCursor,omitempty"` // Opaque token for fetching next page -} - -// ListToolsResult represents the response to a tools/list request -type ListToolsResult struct { - PaginatedResult - Tools []Tool `json:"tools"` // List of available tools -} - -// Message Content Types - -// RoleType represents the sender or recipient of messages in a conversation -type RoleType string - -// PromptArgument defines a single argument that can be passed to a prompt -type PromptArgument struct { - Name string `json:"name"` // Argument name - Description string `json:"description,omitempty"` // Human-readable description - Required bool `json:"required,omitempty"` // Whether this argument is required -} - -// PromptHandler is a function that dynamically generates prompt content -type PromptHandler func(ctx context.Context, args map[string]string) ([]PromptMessage, error) - -// Prompt represents an MCP Prompt definition -type Prompt struct { - Name string `json:"name"` // Unique identifier for the prompt - Description string `json:"description,omitempty"` // Human-readable description - Arguments []PromptArgument `json:"arguments,omitempty"` // Arguments for customization - Content string `json:"-"` // Static content (internal use only) - Handler PromptHandler `json:"-"` // Handler for dynamic content generation -} - -// PromptMessage represents a message in a conversation -type PromptMessage struct { - Role RoleType `json:"role"` // Message sender role - Content any `json:"content"` // Message content (TextContent, ImageContent, etc.) -} - -// TextContent represents text content in a message -type TextContent struct { - Text string `json:"text"` // The text content - Annotations *Annotations `json:"annotations,omitempty"` // Optional annotations -} - -type typedTextContent struct { - Type string `json:"type"` - TextContent -} - -// ImageContent represents image data in a message -type ImageContent struct { - Data string `json:"data"` // Base64-encoded image data - MimeType string `json:"mimeType"` // MIME type (e.g., "image/png") -} - -type typedImageContent struct { - Type string `json:"type"` - ImageContent -} - -// AudioContent represents audio data in a message -type AudioContent struct { - Data string `json:"data"` // Base64-encoded audio data - MimeType string `json:"mimeType"` // MIME type (e.g., "audio/mp3") -} - -type typedAudioContent struct { - Type string `json:"type"` - AudioContent -} - -// FileContent represents file content -type FileContent struct { - URI string `json:"uri"` // URI identifying the file - MimeType string `json:"mimeType"` // MIME type of the file - Text string `json:"text"` // File content as text -} - -// EmbeddedResource represents a resource embedded in a message -type EmbeddedResource struct { - Type string `json:"type"` // Always "resource" - Resource ResourceContent `json:"resource"` // The resource data -} - -// Annotations provides additional metadata for content -type Annotations struct { - Audience []RoleType `json:"audience,omitempty"` // Who should see this content - Priority *float64 `json:"priority,omitempty"` // Optional priority (0-1) -} - -// Tool-related Types - -// ToolHandler is a function that handles tool calls -type ToolHandler func(ctx context.Context, params map[string]any) (any, error) - -// Tool represents a Model Context Protocol Tool definition -type Tool struct { - Name string `json:"name"` // Unique identifier for the tool - Description string `json:"description"` // Human-readable description - InputSchema InputSchema `json:"inputSchema"` // JSON Schema for parameters - Handler ToolHandler `json:"-"` // Not sent to clients -} - -// InputSchema represents tool's input schema in JSON Schema format -type InputSchema struct { - Type string `json:"type"` - Properties map[string]any `json:"properties"` // Property definitions - Required []string `json:"required,omitempty"` // List of required properties -} - -// CallToolResult represents a tool call result that conforms to the MCP schema -type CallToolResult struct { - Result - Content []any `json:"content"` // Content items (text, images, etc.) - IsError bool `json:"isError,omitempty"` // True if tool execution failed -} - -// Resource represents a Model Context Protocol Resource definition -type Resource struct { - URI string `json:"uri"` // Unique resource identifier (RFC3986) - Name string `json:"name"` // Human-readable name - Description string `json:"description,omitempty"` // Optional description - MimeType string `json:"mimeType,omitempty"` // Optional MIME type - Handler ResourceHandler `json:"-"` // Internal handler not sent to clients -} - -// ResourceHandler is a function that handles resource read requests -type ResourceHandler func(ctx context.Context) (ResourceContent, error) - -// ResourceContent represents the content of a resource -type ResourceContent struct { - URI string `json:"uri"` // Resource URI (required) - MimeType string `json:"mimeType,omitempty"` // MIME type of the resource - Text string `json:"text,omitempty"` // Text content (if available) - Blob string `json:"blob,omitempty"` // Base64 encoded blob data (if available) -} - -// ResourcesListResult represents the response to a resources/list request -type ResourcesListResult struct { - PaginatedResult - Resources []Resource `json:"resources"` // List of available resources -} - -// ResourceReadParams contains parameters for a resources/read request -type ResourceReadParams struct { - URI string `json:"uri"` // URI of the resource to read -} - -// ResourceReadResult contains the result of a resources/read request -type ResourceReadResult struct { - Result - Contents []ResourceContent `json:"contents"` // Array of resource content -} - -// ResourceSubscribeParams contains parameters for a resources/subscribe request -type ResourceSubscribeParams struct { - URI string `json:"uri"` // URI of the resource to subscribe to -} - -// ResourceUpdateNotification represents a notification about a resource update -type ResourceUpdateNotification struct { - URI string `json:"uri"` // URI of the updated resource - Content ResourceContent `json:"content"` // New resource content -} - -// Client and Server Types - -// mcpClient represents an SSE client connection -type mcpClient struct { - id string // Unique client identifier - channel chan string // Channel for sending SSE messages - initialized bool // Tracks if client has sent notifications/initialized -} - -// McpServer defines the interface for Model Context Protocol servers -type McpServer interface { - Start() - Stop() - RegisterTool(tool Tool) error - RegisterPrompt(prompt Prompt) - RegisterResource(resource Resource) -} - -// sseMcpServer implements the McpServer interface using SSE -type sseMcpServer struct { - conf McpConf - server *rest.Server - clients map[string]*mcpClient - clientsLock sync.Mutex - tools map[string]Tool - toolsLock sync.Mutex - prompts map[string]Prompt - promptsLock sync.Mutex - resources map[string]Resource - resourcesLock sync.Mutex -} - -// Response Types - -// errorObj represents a JSON-RPC error object -type errorObj struct { - Code int `json:"code"` // Error code - Message string `json:"message"` // Error message -} - -// Response represents a JSON-RPC response -type Response struct { - JsonRpc string `json:"jsonrpc"` // Always "2.0" - ID any `json:"id"` // Same as request ID - Result any `json:"result"` // Result object (null if error) - Error *errorObj `json:"error,omitempty"` // Error object (null if success) -} - -// Server Information Types - -// serverInfo provides information about the server -type serverInfo struct { - Name string `json:"name"` // Server name - Version string `json:"version"` // Server version -} - -// capabilities describes the server's capabilities -type capabilities struct { - Logging struct{} `json:"logging"` - Prompts struct { - ListChanged bool `json:"listChanged"` // Server will notify on prompt changes - } `json:"prompts"` - Resources struct { - Subscribe bool `json:"subscribe"` // Server supports resource subscriptions - ListChanged bool `json:"listChanged"` // Server will notify on resource changes - } `json:"resources"` - Tools struct { - ListChanged bool `json:"listChanged"` // Server will notify on tool changes - } `json:"tools"` -} - -// initializationResponse is sent in response to an initialize request -type initializationResponse struct { - ProtocolVersion string `json:"protocolVersion"` // Protocol version - Capabilities capabilities `json:"capabilities"` // Server capabilities - ServerInfo serverInfo `json:"serverInfo"` // Server information -} - -// ToolCallParams contains the parameters for a tool call -type ToolCallParams struct { - Name string `json:"name"` // Tool name - Parameters map[string]any `json:"parameters"` // Tool parameters -} - -// ToolResult contains the result of a tool execution -type ToolResult struct { - Type string `json:"type"` // Content type (text, image, etc.) - Content any `json:"content"` // Result content -} - -// errorMessage represents a detailed error message -type errorMessage struct { - Code int `json:"code"` // Error code - Message string `json:"message"` // Error message - Data any `json:",omitempty"` // Additional error data -} diff --git a/mcp/types_test.go b/mcp/types_test.go deleted file mode 100644 index ba27100c9..000000000 --- a/mcp/types_test.go +++ /dev/null @@ -1,271 +0,0 @@ -package mcp - -import ( - "context" - "encoding/json" - "errors" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestResponseMarshaling(t *testing.T) { - // Test that the Response struct marshals correctly - resp := Response{ - JsonRpc: "2.0", - ID: 123, - Result: map[string]string{ - "key": "value", - }, - } - - data, err := json.Marshal(resp) - assert.NoError(t, err) - assert.Contains(t, string(data), `"jsonrpc":"2.0"`) - assert.Contains(t, string(data), `"id":123`) - assert.Contains(t, string(data), `"result":{"key":"value"}`) - - // Test response with error - respWithError := Response{ - JsonRpc: "2.0", - ID: 456, - Error: &errorObj{ - Code: errCodeInvalidRequest, - Message: "Invalid Request", - }, - } - - data, err = json.Marshal(respWithError) - assert.NoError(t, err) - assert.Contains(t, string(data), `"jsonrpc":"2.0"`) - assert.Contains(t, string(data), `"id":456`) - assert.Contains(t, string(data), `"error":{"code":-32600,"message":"Invalid Request"}`) -} - -func TestRequestUnmarshaling(t *testing.T) { - // Test that the Request struct unmarshals correctly - jsonStr := `{ - "jsonrpc": "2.0", - "id": 789, - "method": "test_method", - "params": {"key": "value"} - }` - - var req Request - err := json.Unmarshal([]byte(jsonStr), &req) - - assert.NoError(t, err) - assert.Equal(t, "2.0", req.JsonRpc) - assert.Equal(t, float64(789), req.ID) - assert.Equal(t, "test_method", req.Method) - - // Check params unmarshaled correctly - var params map[string]string - err = json.Unmarshal(req.Params, ¶ms) - assert.NoError(t, err) - assert.Equal(t, "value", params["key"]) -} - -func TestToolStructs(t *testing.T) { - // Test Tool struct - tool := Tool{ - Name: "test.tool", - Description: "A test tool", - InputSchema: InputSchema{ - Type: "object", - Properties: map[string]any{ - "input": map[string]any{ - "type": "string", - "description": "Input parameter", - }, - }, - Required: []string{"input"}, - }, - Handler: func(ctx context.Context, params map[string]any) (any, error) { - return "result", nil - }, - } - - // Verify fields are correct - assert.Equal(t, "test.tool", tool.Name) - assert.Equal(t, "A test tool", tool.Description) - assert.Equal(t, "object", tool.InputSchema.Type) - assert.Contains(t, tool.InputSchema.Properties, "input") - propMap, ok := tool.InputSchema.Properties["input"].(map[string]any) - assert.True(t, ok, "Property should be a map") - assert.Equal(t, "string", propMap["type"]) - assert.NotNil(t, tool.Handler) - - // Verify JSON marshalling (which should exclude Handler function) - data, err := json.Marshal(tool) - assert.NoError(t, err) - assert.Contains(t, string(data), `"name":"test.tool"`) - assert.Contains(t, string(data), `"description":"A test tool"`) - assert.Contains(t, string(data), `"inputSchema":`) - assert.NotContains(t, string(data), `"Handler":`) -} - -func TestPromptStructs(t *testing.T) { - // Test Prompt struct - prompt := Prompt{ - Name: "test.prompt", - Description: "A test prompt description", - } - - // Verify fields are correct - assert.Equal(t, "test.prompt", prompt.Name) - assert.Equal(t, "A test prompt description", prompt.Description) - - // Verify JSON marshalling - data, err := json.Marshal(prompt) - assert.NoError(t, err) - assert.Contains(t, string(data), `"name":"test.prompt"`) - assert.Contains(t, string(data), `"description":"A test prompt description"`) -} - -func TestResourceStructs(t *testing.T) { - // Test Resource struct - resource := Resource{ - Name: "test.resource", - URI: "http://example.com/resource", - Description: "A test resource", - } - - // Verify fields are correct - assert.Equal(t, "test.resource", resource.Name) - assert.Equal(t, "http://example.com/resource", resource.URI) - assert.Equal(t, "A test resource", resource.Description) - - // Verify JSON marshalling - data, err := json.Marshal(resource) - assert.NoError(t, err) - assert.Contains(t, string(data), `"name":"test.resource"`) - assert.Contains(t, string(data), `"uri":"http://example.com/resource"`) - assert.Contains(t, string(data), `"description":"A test resource"`) -} - -func TestContentTypes(t *testing.T) { - // Test TextContent - textContent := TextContent{ - Text: "Sample text", - Annotations: &Annotations{ - Audience: []RoleType{RoleUser, RoleAssistant}, - Priority: ptr(1.0), - }, - } - - data, err := json.Marshal(textContent) - assert.NoError(t, err) - assert.Contains(t, string(data), `"text":"Sample text"`) - assert.Contains(t, string(data), `"audience":["user","assistant"]`) - assert.Contains(t, string(data), `"priority":1`) - - // Test ImageContent - imageContent := ImageContent{ - Data: "base64data", - MimeType: "image/png", - } - - data, err = json.Marshal(imageContent) - assert.NoError(t, err) - assert.Contains(t, string(data), `"data":"base64data"`) - assert.Contains(t, string(data), `"mimeType":"image/png"`) - - // Test AudioContent - audioContent := AudioContent{ - Data: "base64audio", - MimeType: "audio/mp3", - } - - data, err = json.Marshal(audioContent) - assert.NoError(t, err) - assert.Contains(t, string(data), `"data":"base64audio"`) - assert.Contains(t, string(data), `"mimeType":"audio/mp3"`) -} - -func TestCallToolResult(t *testing.T) { - // Test CallToolResult - result := CallToolResult{ - Result: Result{ - Meta: map[string]any{ - "progressToken": "token123", - }, - }, - Content: []interface{}{ - TextContent{ - Text: "Sample result", - }, - }, - IsError: false, - } - - data, err := json.Marshal(result) - assert.NoError(t, err) - assert.Contains(t, string(data), `"_meta":{"progressToken":"token123"}`) - assert.Contains(t, string(data), `"content":[{"text":"Sample result"}]`) - assert.NotContains(t, string(data), `"isError":`) -} - -func TestRequest_isNotification(t *testing.T) { - tests := []struct { - name string - id any - want bool - wantErr error - }{ - // integer test cases - {name: "int zero", id: 0, want: true, wantErr: nil}, - {name: "int non-zero", id: 1, want: false, wantErr: nil}, - {name: "int64 zero", id: int64(0), want: true, wantErr: nil}, - {name: "int64 max", id: int64(9223372036854775807), want: false, wantErr: nil}, - - // floating point number test cases - {name: "float64 zero", id: float64(0.0), want: true, wantErr: nil}, - {name: "float64 positive", id: float64(0.000001), want: false, wantErr: nil}, - {name: "float64 negative", id: float64(-0.000001), want: false, wantErr: nil}, - {name: "float64 epsilon", id: float64(1e-300), want: false, wantErr: nil}, - - // string test cases - {name: "empty string", id: "", want: true, wantErr: nil}, - {name: "non-empty string", id: "abc", want: false, wantErr: nil}, - {name: "space string", id: " ", want: false, wantErr: nil}, - {name: "unicode string", id: "こんにちは", want: false, wantErr: nil}, - - // special cases - {name: "nil", id: nil, want: true, wantErr: nil}, - - // logical type test cases - {name: "bool true", id: true, want: false, wantErr: errors.New("invalid type bool")}, - {name: "bool false", id: false, want: false, wantErr: errors.New("invalid type bool")}, - {name: "struct type", id: struct{}{}, want: false, wantErr: errors.New("invalid type struct {}")}, - {name: "slice type", id: []int{1, 2, 3}, want: false, wantErr: errors.New("invalid type []int")}, - {name: "map type", id: map[string]int{"a": 1}, want: false, wantErr: errors.New("invalid type map[string]int")}, - {name: "pointer type", id: new(int), want: false, wantErr: errors.New("invalid type *int")}, - {name: "func type", id: func() {}, want: false, wantErr: errors.New("invalid type func()")}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req := Request{ - SessionId: "test-session", - JsonRpc: "2.0", - ID: tt.id, - Method: "testMethod", - Params: json.RawMessage(`{}`), - } - - got, err := req.isNotification() - - if (err != nil) != (tt.wantErr != nil) { - t.Fatalf("error presence mismatch: got error = %v, wantErr %v", err, tt.wantErr) - } - if err != nil && tt.wantErr != nil && err.Error() != tt.wantErr.Error() { - t.Fatalf("error message mismatch:\ngot %q\nwant %q", err.Error(), tt.wantErr.Error()) - } - - if got != tt.want { - t.Errorf("isNotification() = %v, want %v for ID %v (%T)", got, tt.want, tt.id, tt.id) - } - }) - } -} diff --git a/mcp/util.go b/mcp/util.go deleted file mode 100644 index 282840ffd..000000000 --- a/mcp/util.go +++ /dev/null @@ -1,107 +0,0 @@ -package mcp - -import "fmt" - -// formatSSEMessage formats a Server-Sent Event message with proper CRLF line endings -func formatSSEMessage(event string, data []byte) string { - return fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(data)) -} - -// ptr is a helper function to get a pointer to a value -func ptr[T any](v T) *T { - return &v -} - -func toTypedContents(contents []any) []any { - typedContents := make([]any, len(contents)) - - for i, content := range contents { - switch v := content.(type) { - case TextContent: - typedContents[i] = typedTextContent{ - Type: ContentTypeText, - TextContent: v, - } - case ImageContent: - typedContents[i] = typedImageContent{ - Type: ContentTypeImage, - ImageContent: v, - } - case AudioContent: - typedContents[i] = typedAudioContent{ - Type: ContentTypeAudio, - AudioContent: v, - } - default: - typedContents[i] = typedTextContent{ - Type: ContentTypeText, - TextContent: TextContent{ - Text: fmt.Sprintf("Unknown content type: %T", v), - }, - } - } - } - - return typedContents -} - -func toTypedPromptMessages(messages []PromptMessage) []PromptMessage { - typedMessages := make([]PromptMessage, len(messages)) - - for i, msg := range messages { - switch v := msg.Content.(type) { - case TextContent: - typedMessages[i] = PromptMessage{ - Role: msg.Role, - Content: typedTextContent{ - Type: ContentTypeText, - TextContent: v, - }, - } - case ImageContent: - typedMessages[i] = PromptMessage{ - Role: msg.Role, - Content: typedImageContent{ - Type: ContentTypeImage, - ImageContent: v, - }, - } - case AudioContent: - typedMessages[i] = PromptMessage{ - Role: msg.Role, - Content: typedAudioContent{ - Type: ContentTypeAudio, - AudioContent: v, - }, - } - default: - typedMessages[i] = PromptMessage{ - Role: msg.Role, - Content: typedTextContent{ - Type: ContentTypeText, - TextContent: TextContent{ - Text: fmt.Sprintf("Unknown content type: %T", v), - }, - }, - } - } - } - - return typedMessages -} - -// validatePromptArguments checks if all required arguments are provided -// Returns a list of missing required arguments -func validatePromptArguments(prompt Prompt, providedArgs map[string]string) []string { - var missingArgs []string - - for _, arg := range prompt.Arguments { - if arg.Required { - if value, exists := providedArgs[arg.Name]; !exists || len(value) == 0 { - missingArgs = append(missingArgs, arg.Name) - } - } - } - - return missingArgs -} diff --git a/mcp/util_test.go b/mcp/util_test.go deleted file mode 100644 index 0014378b6..000000000 --- a/mcp/util_test.go +++ /dev/null @@ -1,274 +0,0 @@ -package mcp - -import ( - "bufio" - "encoding/json" - "fmt" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -type Event struct { - Type string - Data map[string]any -} - -func parseEvent(input string) (*Event, error) { - var evt Event - var dataStr string - - scanner := bufio.NewScanner(strings.NewReader(input)) - for scanner.Scan() { - line := scanner.Text() - if strings.HasPrefix(line, "event:") { - evt.Type = strings.TrimSpace(strings.TrimPrefix(line, "event:")) - } else if strings.HasPrefix(line, "data:") { - dataStr = strings.TrimSpace(strings.TrimPrefix(line, "data:")) - } - } - if err := scanner.Err(); err != nil { - return nil, err - } - - if len(dataStr) > 0 { - if err := json.Unmarshal([]byte(dataStr), &evt.Data); err != nil { - return nil, fmt.Errorf("failed to parse data: %w", err) - } - } - - return &evt, nil -} - -// TestToTypedPromptMessages tests the toTypedPromptMessages function -func TestToTypedPromptMessages(t *testing.T) { - // Test with multiple message types in one test - t.Run("MixedContentTypes", func(t *testing.T) { - // Create test data with different content types - messages := []PromptMessage{ - { - Role: RoleUser, - Content: TextContent{ - Text: "Hello, this is a text message", - Annotations: &Annotations{ - Audience: []RoleType{RoleUser, RoleAssistant}, - Priority: ptr(0.8), - }, - }, - }, - { - Role: RoleAssistant, - Content: ImageContent{ - Data: "base64ImageData", - MimeType: "image/jpeg", - }, - }, - { - Role: RoleUser, - Content: AudioContent{ - Data: "base64AudioData", - MimeType: "audio/mp3", - }, - }, - { - Role: "system", - Content: "This is a simple string that should be handled as unknown type", - }, - } - - // Call the function - result := toTypedPromptMessages(messages) - - // Validate results - require.Len(t, result, 4, "Should return the same number of messages") - - // Validate first message (TextContent) - msg := result[0] - assert.Equal(t, RoleUser, msg.Role, "Role should be preserved") - - // Type assertion using reflection since Content is an interface - typed, ok := msg.Content.(typedTextContent) - require.True(t, ok, "Should be typedTextContent") - assert.Equal(t, ContentTypeText, typed.Type, "Type should be text") - assert.Equal(t, "Hello, this is a text message", typed.Text, "Text content should be preserved") - require.NotNil(t, typed.Annotations, "Annotations should be preserved") - assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved") - require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved") - assert.Equal(t, 0.8, *typed.Annotations.Priority, "Priority value should be preserved") - - // Validate second message (ImageContent) - msg = result[1] - assert.Equal(t, RoleAssistant, msg.Role, "Role should be preserved") - - // Type assertion for image content - typedImg, ok := msg.Content.(typedImageContent) - require.True(t, ok, "Should be typedImageContent") - assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image") - assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved") - assert.Equal(t, "image/jpeg", typedImg.MimeType, "MimeType should be preserved") - - // Validate third message (AudioContent) - msg = result[2] - assert.Equal(t, RoleUser, msg.Role, "Role should be preserved") - - // Type assertion for audio content - typedAudio, ok := msg.Content.(typedAudioContent) - require.True(t, ok, "Should be typedAudioContent") - assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio") - assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved") - assert.Equal(t, "audio/mp3", typedAudio.MimeType, "MimeType should be preserved") - - // Validate fourth message (unknown type converted to TextContent) - msg = result[3] - assert.Equal(t, RoleType("system"), msg.Role, "Role should be preserved") - - // Should be converted to a typedTextContent with error message - typedUnknown, ok := msg.Content.(typedTextContent) - require.True(t, ok, "Unknown content should be converted to typedTextContent") - assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text") - assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type") - assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type") - }) - - // Test empty input - t.Run("EmptyInput", func(t *testing.T) { - messages := []PromptMessage{} - result := toTypedPromptMessages(messages) - assert.Empty(t, result, "Should return empty slice for empty input") - }) - - // Test with nil annotations - t.Run("NilAnnotations", func(t *testing.T) { - messages := []PromptMessage{ - { - Role: RoleUser, - Content: TextContent{ - Text: "Text with nil annotations", - Annotations: nil, - }, - }, - } - - result := toTypedPromptMessages(messages) - require.Len(t, result, 1, "Should return one message") - - typed, ok := result[0].Content.(typedTextContent) - require.True(t, ok, "Should be typedTextContent") - assert.Equal(t, ContentTypeText, typed.Type, "Type should be text") - assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved") - assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil") - }) -} - -// TestToTypedContents tests the toTypedContents function -func TestToTypedContents(t *testing.T) { - // Test with multiple content types in one test - t.Run("MixedContentTypes", func(t *testing.T) { - // Create test data with different content types - contents := []any{ - TextContent{ - Text: "Hello, this is a text content", - Annotations: &Annotations{ - Audience: []RoleType{RoleUser, RoleAssistant}, - Priority: ptr(0.7), - }, - }, - ImageContent{ - Data: "base64ImageData", - MimeType: "image/png", - }, - AudioContent{ - Data: "base64AudioData", - MimeType: "audio/wav", - }, - "This is a simple string that should be handled as unknown type", - } - - // Call the function - result := toTypedContents(contents) - - // Validate results - require.Len(t, result, 4, "Should return the same number of contents") - - // Validate first content (TextContent) - typed, ok := result[0].(typedTextContent) - require.True(t, ok, "Should be typedTextContent") - assert.Equal(t, ContentTypeText, typed.Type, "Type should be text") - assert.Equal(t, "Hello, this is a text content", typed.Text, "Text content should be preserved") - require.NotNil(t, typed.Annotations, "Annotations should be preserved") - assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved") - require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved") - assert.Equal(t, 0.7, *typed.Annotations.Priority, "Priority value should be preserved") - - // Validate second content (ImageContent) - typedImg, ok := result[1].(typedImageContent) - require.True(t, ok, "Should be typedImageContent") - assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image") - assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved") - assert.Equal(t, "image/png", typedImg.MimeType, "MimeType should be preserved") - - // Validate third content (AudioContent) - typedAudio, ok := result[2].(typedAudioContent) - require.True(t, ok, "Should be typedAudioContent") - assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio") - assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved") - assert.Equal(t, "audio/wav", typedAudio.MimeType, "MimeType should be preserved") - - // Validate fourth content (unknown type converted to TextContent) - typedUnknown, ok := result[3].(typedTextContent) - require.True(t, ok, "Unknown content should be converted to typedTextContent") - assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text") - assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type") - assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type") - }) - - // Test empty input - t.Run("EmptyInput", func(t *testing.T) { - contents := []any{} - result := toTypedContents(contents) - assert.Empty(t, result, "Should return empty slice for empty input") - }) - - // Test with nil annotations - t.Run("NilAnnotations", func(t *testing.T) { - contents := []any{ - TextContent{ - Text: "Text with nil annotations", - Annotations: nil, - }, - } - - result := toTypedContents(contents) - require.Len(t, result, 1, "Should return one content") - - typed, ok := result[0].(typedTextContent) - require.True(t, ok, "Should be typedTextContent") - assert.Equal(t, ContentTypeText, typed.Type, "Type should be text") - assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved") - assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil") - }) - - // Test with custom struct (should be handled as unknown type) - t.Run("CustomStruct", func(t *testing.T) { - type CustomContent struct { - Data string - } - - contents := []any{ - CustomContent{ - Data: "custom data", - }, - } - - result := toTypedContents(contents) - require.Len(t, result, 1, "Should return one content") - - typed, ok := result[0].(typedTextContent) - require.True(t, ok, "Custom struct should be converted to typedTextContent") - assert.Equal(t, ContentTypeText, typed.Type, "Type should be text") - assert.Contains(t, typed.Text, "Unknown content type:", "Should contain error about unknown type") - assert.Contains(t, typed.Text, "CustomContent", "Should mention the actual type") - }) -} diff --git a/mcp/vars.go b/mcp/vars.go deleted file mode 100644 index 7a5b6fb8e..000000000 --- a/mcp/vars.go +++ /dev/null @@ -1,149 +0,0 @@ -package mcp - -import ( - "time" - - "github.com/zeromicro/go-zero/core/syncx" -) - -// Protocol constants -const ( - // JSON-RPC version as defined in the specification - jsonRpcVersion = "2.0" - - // Session identifier key used in request URLs - sessionIdKey = "session_id" - - // progressTokenKey is used to track progress of long-running tasks - progressTokenKey = "progressToken" -) - -// Server-Sent Events (SSE) event types -const ( - // Standard message event for JSON-RPC responses - eventMessage = "message" - - // Endpoint event for sending endpoint URL to clients - eventEndpoint = "endpoint" -) - -// Content type identifiers -const ( - // ContentTypeObject is object content type - ContentTypeObject = "object" - - // ContentTypeText is text content type - ContentTypeText = "text" - - // ContentTypeImage is image content type - ContentTypeImage = "image" - - // ContentTypeAudio is audio content type - ContentTypeAudio = "audio" - - // ContentTypeResource is resource content type - ContentTypeResource = "resource" -) - -// Collection keys for broadcast events -const ( - // Key for prompts collection - keyPrompts = "prompts" - - // Key for resources collection - keyResources = "resources" - - // Key for tools collection - keyTools = "tools" -) - -// JSON-RPC error codes -// Standard error codes from JSON-RPC 2.0 spec -const ( - // Invalid JSON was received by the server - errCodeInvalidRequest = -32600 - - // The method does not exist / is not available - errCodeMethodNotFound = -32601 - - // Invalid method parameter(s) - errCodeInvalidParams = -32602 - - // Internal JSON-RPC error - errCodeInternalError = -32603 - - // Tool execution timed out - errCodeTimeout = -32001 - - // Resource not found error - errCodeResourceNotFound = -32002 - - // Client hasn't completed initialization - errCodeClientNotInitialized = -32800 -) - -// User and assistant role definitions -const ( - // RoleUser is the "user" role - the entity asking questions - RoleUser RoleType = "user" - - // RoleAssistant is the "assistant" role - the entity providing responses - RoleAssistant RoleType = "assistant" -) - -// Method names as defined in the MCP specification -const ( - // Initialize the connection between client and server - methodInitialize = "initialize" - - // List available tools - methodToolsList = "tools/list" - - // Call a specific tool - methodToolsCall = "tools/call" - - // List available prompts - methodPromptsList = "prompts/list" - - // Get a specific prompt - methodPromptsGet = "prompts/get" - - // List available resources - methodResourcesList = "resources/list" - - // Read a specific resource - methodResourcesRead = "resources/read" - - // Subscribe to resource updates - methodResourcesSubscribe = "resources/subscribe" - - // Simple ping to check server availability - methodPing = "ping" - - // Notification that client is fully initialized - methodNotificationsInitialized = "notifications/initialized" - - // Notification that a request was canceled - methodNotificationsCancelled = "notifications/cancelled" -) - -// Event names for Server-Sent Events (SSE) -const ( - // Notification of tool list changes - eventToolsListChanged = "tools/list_changed" - - // Notification of prompt list changes - eventPromptsListChanged = "prompts/list_changed" - - // Notification of resource list changes - eventResourcesListChanged = "resources/list_changed" -) - -var ( - // Default channel size for events - eventChanSize = 10 - - // Default ping interval for checking connection availability - // use syncx.ForAtomicDuration to ensure atomicity in test race - pingInterval = syncx.ForAtomicDuration(30 * time.Second) -) diff --git a/mcp/vars_test.go b/mcp/vars_test.go deleted file mode 100644 index 4a894a16f..000000000 --- a/mcp/vars_test.go +++ /dev/null @@ -1,210 +0,0 @@ -// filepath: /Users/kevin/Develop/go/opensource/go-zero/mcp/vars_test.go -package mcp - -import ( - "encoding/json" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" -) - -// TestErrorCodes ensures error codes are applied correctly in error responses -func TestErrorCodes(t *testing.T) { - testCases := []struct { - name string - code int - message string - expected string - }{ - { - name: "invalid request error", - code: errCodeInvalidRequest, - message: "Invalid request", - expected: `"code":-32600`, - }, - { - name: "method not found error", - code: errCodeMethodNotFound, - message: "Method not found", - expected: `"code":-32601`, - }, - { - name: "invalid params error", - code: errCodeInvalidParams, - message: "Invalid parameters", - expected: `"code":-32602`, - }, - { - name: "internal error", - code: errCodeInternalError, - message: "Internal server error", - expected: `"code":-32603`, - }, - { - name: "timeout error", - code: errCodeTimeout, - message: "Operation timed out", - expected: `"code":-32001`, - }, - { - name: "resource not found error", - code: errCodeResourceNotFound, - message: "Resource not found", - expected: `"code":-32002`, - }, - { - name: "client not initialized error", - code: errCodeClientNotInitialized, - message: "Client not initialized", - expected: `"code":-32800`, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - resp := Response{ - JsonRpc: jsonRpcVersion, - ID: int64(1), - Error: &errorObj{ - Code: tc.code, - Message: tc.message, - }, - } - data, err := json.Marshal(resp) - assert.NoError(t, err) - assert.Contains(t, string(data), tc.expected, "Error code should match expected value") - assert.Contains(t, string(data), tc.message, "Error message should be included") - assert.Contains(t, string(data), jsonRpcVersion, "JSON-RPC version should be included") - }) - } -} - -// TestJsonRpcVersion ensures the correct JSON-RPC version is used -func TestJsonRpcVersion(t *testing.T) { - assert.Equal(t, "2.0", jsonRpcVersion, "JSON-RPC version should be 2.0") - - // Test that it's used in responses - resp := Response{ - JsonRpc: jsonRpcVersion, - ID: int64(1), - Result: "test", - } - data, err := json.Marshal(resp) - assert.NoError(t, err) - assert.Contains(t, string(data), `"jsonrpc":"2.0"`, "Response should use correct JSON-RPC version") - - // Test that it's expected in requests - reqStr := `{"jsonrpc":"2.0","id":1,"method":"test"}` - var req Request - err = json.Unmarshal([]byte(reqStr), &req) - assert.NoError(t, err) - assert.Equal(t, jsonRpcVersion, req.JsonRpc, "Request should parse correct JSON-RPC version") -} - -// TestSessionIdKey ensures session ID extraction works correctly -func TestSessionIdKey(t *testing.T) { - // Create a mock server implementation - mock := newMockMcpServer(t) - defer mock.shutdown() - - // Verify the key constant - assert.Equal(t, "session_id", sessionIdKey, "Session ID key should be 'session_id'") - - // Test that session ID is extracted correctly - mockR := httptest.NewRequest("GET", "/?"+sessionIdKey+"=test-session", nil) - - // Since the mock server is using the same session key logic, - // we can test this by accessing the request query parameters directly - sessionID := mockR.URL.Query().Get(sessionIdKey) - assert.Equal(t, "test-session", sessionID, "Session ID should be extracted correctly") -} - -// TestEventTypes ensures event types are set correctly in SSE responses -func TestEventTypes(t *testing.T) { - // Test message event - assert.Equal(t, "message", eventMessage, "Message event should be 'message'") - - // Test endpoint event - assert.Equal(t, "endpoint", eventEndpoint, "Endpoint event should be 'endpoint'") - - // Verify them in an actual SSE format string - messageEvent := "event: " + eventMessage + "\ndata: test\n\n" - assert.Contains(t, messageEvent, "event: message", "Message event should format correctly") - - endpointEvent := "event: " + eventEndpoint + "\ndata: test\n\n" - assert.Contains(t, endpointEvent, "event: endpoint", "Endpoint event should format correctly") -} - -// TestCollectionKeys checks that collection keys are used correctly -func TestCollectionKeys(t *testing.T) { - // Verify collection key constants - assert.Equal(t, "prompts", keyPrompts, "Prompts key should be 'prompts'") - assert.Equal(t, "resources", keyResources, "Resources key should be 'resources'") - assert.Equal(t, "tools", keyTools, "Tools key should be 'tools'") -} - -// TestRoleTypes checks that role types are used correctly -func TestRoleTypes(t *testing.T) { - // Test in annotations - annotations := Annotations{ - Audience: []RoleType{RoleUser, RoleAssistant}, - } - data, err := json.Marshal(annotations) - assert.NoError(t, err) - assert.Contains(t, string(data), `"audience":["user","assistant"]`, "Role types should marshal correctly") -} - -// TestMethodNames checks that method names are used correctly -func TestMethodNames(t *testing.T) { - // Verify method name constants - methods := map[string]string{ - "initialize": methodInitialize, - "tools/list": methodToolsList, - "tools/call": methodToolsCall, - "prompts/list": methodPromptsList, - "prompts/get": methodPromptsGet, - "resources/list": methodResourcesList, - "resources/read": methodResourcesRead, - "resources/subscribe": methodResourcesSubscribe, - "ping": methodPing, - "notifications/initialized": methodNotificationsInitialized, - "notifications/cancelled": methodNotificationsCancelled, - } - - for expected, actual := range methods { - assert.Equal(t, expected, actual, "Method name should be "+expected) - } - - // Test in a request - for methodName := range methods { - req := Request{ - JsonRpc: jsonRpcVersion, - ID: int64(1), - Method: methodName, - } - data, err := json.Marshal(req) - assert.NoError(t, err) - assert.Contains(t, string(data), `"method":"`+methodName+`"`, "Method name should be used in requests") - } -} - -// TestEventNames checks that event names are used correctly -func TestEventNames(t *testing.T) { - // Verify event name constants - events := map[string]string{ - "tools/list_changed": eventToolsListChanged, - "prompts/list_changed": eventPromptsListChanged, - "resources/list_changed": eventResourcesListChanged, - } - - for expected, actual := range events { - assert.Equal(t, expected, actual, "Event name should be "+expected) - } - - // Test event names in SSE format - for _, eventName := range events { - sseEvent := "event: " + eventName + "\ndata: test\n\n" - assert.Contains(t, sseEvent, "event: "+eventName, "Event name should format correctly in SSE") - } -}