diff --git a/mcp/config.go b/mcp/config.go new file mode 100644 index 000000000..e9a6635ff --- /dev/null +++ b/mcp/config.go @@ -0,0 +1,40 @@ +package mcp + +import ( + "time" + + "github.com/zeromicro/go-zero/rest" +) + +// McpConf defines the configuration for an MCP server. +// It embeds rest.RestConf for HTTP server settings +// and adds MCP-specific configuration options. +type McpConf struct { + rest.RestConf + Mcp struct { + // Name is the server name reported in initialize responses + Name string `json:",optional"` + + // Version is the server version reported in initialize responses + Version string `json:",default=1.0.0"` + + // ProtocolVersion is the MCP protocol version implemented + ProtocolVersion string `json:",default=2024-11-05"` + + // BaseUrl is the base URL for the server, used in SSE endpoint messages + // If not set, defaults to http://localhost:{Port} + BaseUrl string `json:",optional"` + + // SseEndpoint is the path for Server-Sent Events connections + SseEndpoint string `json:",default=/sse"` + + // MessageEndpoint is the path for JSON-RPC requests + MessageEndpoint string `json:",default=/message"` + + // Cors contains allowed CORS origins + Cors []string `json:",optional"` + + // ToolTimeout is the maximum time allowed for tool execution + ToolTimeout time.Duration `json:",default=30s"` + } +} diff --git a/mcp/config_test.go b/mcp/config_test.go new file mode 100644 index 000000000..4156f0a86 --- /dev/null +++ b/mcp/config_test.go @@ -0,0 +1,63 @@ +package mcp + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/zeromicro/go-zero/core/conf" +) + +func TestMcpConfDefaults(t *testing.T) { + // Test default values are set correctly when unmarshalled from JSON + jsonConfig := `name: test-service +port: 8080 +mcp: + name: test-mcp-server + version: 1.0.0 +` + + var c McpConf + err := conf.LoadFromYamlBytes([]byte(jsonConfig), &c) + assert.NoError(t, err) + + // Check default values + assert.Equal(t, "test-mcp-server", c.Mcp.Name) + assert.Equal(t, "1.0.0", c.Mcp.Version, "Default version should be 1.0.0") + assert.Equal(t, "2024-11-05", c.Mcp.ProtocolVersion, "Default protocol version should be 2024-11-05") + assert.Equal(t, "/sse", c.Mcp.SseEndpoint, "Default SSE endpoint should be /sse") + assert.Equal(t, "/message", c.Mcp.MessageEndpoint, "Default message endpoint should be /message") + assert.Equal(t, 30*time.Second, c.Mcp.ToolTimeout, "Default tool timeout should be 30s") +} + +func TestMcpConfCustomValues(t *testing.T) { + // Test custom values can be set + jsonConfig := `{ + "Name": "test-service", + "Port": 8080, + "Mcp": { + "Name": "test-mcp-server", + "Version": "2.0.0", + "ProtocolVersion": "2025-01-01", + "BaseUrl": "http://example.com", + "SseEndpoint": "/custom-sse", + "MessageEndpoint": "/custom-message", + "Cors": ["http://localhost:3000", "http://example.com"], + "ToolTimeout": "60s" + } + }` + + var c McpConf + err := conf.LoadFromJsonBytes([]byte(jsonConfig), &c) + assert.NoError(t, err) + + // Check custom values + assert.Equal(t, "test-mcp-server", c.Mcp.Name, "Name should be inherited from RestConf") + assert.Equal(t, "2.0.0", c.Mcp.Version, "Version should be customizable") + assert.Equal(t, "2025-01-01", c.Mcp.ProtocolVersion, "Protocol version should be customizable") + assert.Equal(t, "http://example.com", c.Mcp.BaseUrl, "BaseUrl should be customizable") + assert.Equal(t, "/custom-sse", c.Mcp.SseEndpoint, "SSE endpoint should be customizable") + assert.Equal(t, "/custom-message", c.Mcp.MessageEndpoint, "Message endpoint should be customizable") + assert.Equal(t, []string{"http://localhost:3000", "http://example.com"}, c.Mcp.Cors, "CORS settings should be customizable") + assert.Equal(t, 60*time.Second, c.Mcp.ToolTimeout, "Tool timeout should be customizable") +} diff --git a/mcp/integration_test.go b/mcp/integration_test.go new file mode 100644 index 000000000..5d7016fc5 --- /dev/null +++ b/mcp/integration_test.go @@ -0,0 +1,443 @@ +package mcp + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// syncResponseRecorder is a thread-safe wrapper around httptest.ResponseRecorder +type syncResponseRecorder struct { + *httptest.ResponseRecorder + mu sync.Mutex +} + +// Create a new synchronized response recorder +func newSyncResponseRecorder() *syncResponseRecorder { + return &syncResponseRecorder{ + ResponseRecorder: httptest.NewRecorder(), + } +} + +// Override Write method to synchronize access +func (srr *syncResponseRecorder) Write(p []byte) (int, error) { + srr.mu.Lock() + defer srr.mu.Unlock() + return srr.ResponseRecorder.Write(p) +} + +// Override WriteHeader method to synchronize access +func (srr *syncResponseRecorder) WriteHeader(statusCode int) { + srr.mu.Lock() + defer srr.mu.Unlock() + srr.ResponseRecorder.WriteHeader(statusCode) +} + +// Override Result method to synchronize access +func (srr *syncResponseRecorder) Result() *http.Response { + srr.mu.Lock() + defer srr.mu.Unlock() + return srr.ResponseRecorder.Result() +} + +// TestHTTPHandlerIntegration tests the HTTP handlers with a real server instance +func TestHTTPHandlerIntegration(t *testing.T) { + // Skip in short test mode + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + // Create a test configuration + conf := McpConf{} + conf.Mcp.Name = "test-integration" + conf.Mcp.Version = "1.0.0-test" + conf.Mcp.ToolTimeout = 1 * time.Second + + // Create a mock server directly + server := &sseMcpServer{ + conf: conf, + clients: make(map[string]*mcpClient), + tools: make(map[string]Tool), + prompts: make(map[string]Prompt), + resources: make(map[string]Resource), + } + + // Register a test tool + err := server.RegisterTool(Tool{ + Name: "echo", + Description: "Echo tool for testing", + InputSchema: InputSchema{ + Type: "object", + Properties: map[string]any{ + "message": map[string]any{ + "type": "string", + "description": "Message to echo", + }, + }, + }, + Handler: func(params map[string]any) (any, error) { + if msg, ok := params["message"].(string); ok { + return fmt.Sprintf("Echo: %s", msg), nil + } + return "Echo: no message provided", nil + }, + }) + require.NoError(t, err) + + // Create a test HTTP request to the SSE endpoint + req := httptest.NewRequest("GET", "/sse", nil) + w := newSyncResponseRecorder() + + // Create a done channel to signal completion of test + done := make(chan bool) + + // Start the SSE handler in a goroutine + go func() { + // lock.Lock() + server.handleSSE(w, req) + // lock.Unlock() + done <- true + }() + + // Allow time for the handler to process + select { + case <-time.After(100 * time.Millisecond): + // Expected - handler would normally block indefinitely + case <-done: + // This shouldn't happen immediately - the handler should block + t.Error("SSE handler returned unexpectedly") + } + + // Check the initial headers + resp := w.Result() + assert.Equal(t, "chunked", resp.Header.Get("Transfer-Encoding")) + resp.Body.Close() + + // The handler creates a client and sends the endpoint message + var sessionId string + + // Give the handler time to set up the client + time.Sleep(50 * time.Millisecond) + + // Check that a client was created + server.clientsLock.Lock() + assert.Equal(t, 1, len(server.clients)) + for id := range server.clients { + sessionId = id + } + server.clientsLock.Unlock() + + require.NotEmpty(t, sessionId, "Expected a session ID to be created") + + // Now that we have a session ID, we can test the message endpoint + messageBody, _ := json.Marshal(Request{ + JsonRpc: "2.0", + ID: 1, + Method: methodInitialize, + Params: json.RawMessage(`{}`), + }) + + // Create a message request + reqURL := fmt.Sprintf("/message?%s=%s", sessionIdKey, sessionId) + msgReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(messageBody)) + msgW := newSyncResponseRecorder() + + // Process the message + server.handleRequest(msgW, msgReq) + + // Check the response + msgResp := msgW.Result() + assert.Equal(t, http.StatusAccepted, msgResp.StatusCode) + msgResp.Body.Close() // Ensure response body is closed +} + +// TestHandlerResponseFlow tests the flow of a full request/response cycle +func TestHandlerResponseFlow(t *testing.T) { + // Create a mock server for testing + server := &sseMcpServer{ + conf: McpConf{}, + clients: map[string]*mcpClient{ + "test-session": { + id: "test-session", + channel: make(chan string, 10), + initialized: true, + }, + }, + tools: make(map[string]Tool), + prompts: make(map[string]Prompt), + resources: make(map[string]Resource), + } + + // Register test resources + server.RegisterTool(Tool{ + Name: "test.tool", + Description: "Test tool", + InputSchema: InputSchema{Type: "object"}, + Handler: func(params map[string]any) (any, error) { + return "tool result", nil + }, + }) + + server.RegisterPrompt(Prompt{ + Name: "test.prompt", + Description: "Test prompt", + }) + + server.RegisterResource(Resource{ + Name: "test.resource", + URI: "http://example.com", + Description: "Test resource", + }) + + // Create a request with session ID parameter + reqURL := fmt.Sprintf("/message?%s=%s", sessionIdKey, "test-session") + + // Test tools/list request + toolsListBody, _ := json.Marshal(Request{ + JsonRpc: "2.0", + ID: 1, + Method: methodToolsList, + Params: json.RawMessage(`{}`), + }) + + toolsReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(toolsListBody)) + toolsW := newSyncResponseRecorder() + + // Process the request + server.handleRequest(toolsW, toolsReq) + + // Check the response code + toolsResp := toolsW.Result() + assert.Equal(t, http.StatusAccepted, toolsResp.StatusCode) + toolsResp.Body.Close() + + // Check the channel message + client := server.clients["test-session"] + select { + case message := <-client.channel: + assert.Contains(t, message, `"tools":[{"name":"test.tool"`) + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for tools/list response") + } + + // Test prompts/list request + promptsListBody, _ := json.Marshal(Request{ + JsonRpc: "2.0", + ID: 2, + Method: methodPromptsList, + Params: json.RawMessage(`{}`), + }) + + promptsReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(promptsListBody)) + promptsW := newSyncResponseRecorder() + + // Process the request + server.handleRequest(promptsW, promptsReq) + + // Check the response code + promptsResp := promptsW.Result() + assert.Equal(t, http.StatusAccepted, promptsResp.StatusCode) + promptsResp.Body.Close() + + // Check the channel message + select { + case message := <-client.channel: + assert.Contains(t, message, `"prompts":[{"name":"test.prompt"`) + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for prompts/list response") + } + + // Test resources/list request + resourcesListBody, _ := json.Marshal(Request{ + JsonRpc: "2.0", + ID: 3, + Method: methodResourcesList, + Params: json.RawMessage(`{}`), + }) + + resourcesReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(resourcesListBody)) + resourcesW := newSyncResponseRecorder() + + // Process the request + server.handleRequest(resourcesW, resourcesReq) + + // Check the response code + resourcesResp := resourcesW.Result() + assert.Equal(t, http.StatusAccepted, resourcesResp.StatusCode) + resourcesResp.Body.Close() + + // Check the channel message + select { + case message := <-client.channel: + assert.Contains(t, message, `"name":"test.resource"`) + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for resources/list response") + } +} + +// TestProcessListMethods tests the list processing methods with pagination +func TestProcessListMethods(t *testing.T) { + server := &sseMcpServer{ + tools: make(map[string]Tool), + prompts: make(map[string]Prompt), + resources: make(map[string]Resource), + } + + // Add some test data + for i := 1; i <= 5; i++ { + tool := Tool{ + Name: fmt.Sprintf("tool%d", i), + Description: fmt.Sprintf("Tool %d", i), + InputSchema: InputSchema{Type: "object"}, + } + server.tools[tool.Name] = tool + + prompt := Prompt{ + Name: fmt.Sprintf("prompt%d", i), + Description: fmt.Sprintf("Prompt %d", i), + } + server.prompts[prompt.Name] = prompt + + resource := Resource{ + Name: fmt.Sprintf("resource%d", i), + URI: fmt.Sprintf("http://example.com/%d", i), + Description: fmt.Sprintf("Resource %d", i), + } + server.resources[resource.Name] = resource + } + + // Create a test client + client := &mcpClient{ + id: "test-client", + channel: make(chan string, 10), + initialized: true, + } + + // Test processListTools + req := Request{ + JsonRpc: "2.0", + ID: 1, + Method: methodToolsList, + Params: json.RawMessage(`{"cursor": "", "_meta": {"progressToken": "token1"}}`), + } + + server.processListTools(client, req) + + // Read response + select { + case response := <-client.channel: + assert.Contains(t, response, `"tools":`) + assert.Contains(t, response, `"progressToken":"token1"`) + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for tools/list response") + } + + // Test processListPrompts + req.ID = 2 + req.Method = methodPromptsList + req.Params = json.RawMessage(`{"cursor": "next"}`) + server.processListPrompts(client, req) + + // Read response + select { + case response := <-client.channel: + assert.Contains(t, response, `"prompts":`) + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for prompts/list response") + } + + // Test processListResources + req.ID = 3 + req.Method = methodResourcesList + req.Params = json.RawMessage(`{"cursor": "next"}`) + server.processListResources(client, req) + + // Read response + select { + case response := <-client.channel: + assert.Contains(t, response, `"resources":`) + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for resources/list response") + } +} + +// TestErrorResponseHandling tests error handling in the server +func TestErrorResponseHandling(t *testing.T) { + server := &sseMcpServer{ + tools: make(map[string]Tool), + prompts: make(map[string]Prompt), + resources: make(map[string]Resource), + } + + // Create a test client + client := &mcpClient{ + id: "test-client", + channel: make(chan string, 10), + initialized: true, + } + + // Test invalid method + req := Request{ + JsonRpc: "2.0", + ID: 1, + Method: "invalid_method", + Params: json.RawMessage(`{}`), + } + + // Mock handleRequest by directly calling error handler + server.sendErrorResponse(client, req.ID, "Method not found", errCodeMethodNotFound) + + // Check response + select { + case response := <-client.channel: + assert.Contains(t, response, `"error":{"code":-32601,"message":"Method not found"}`) + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for error response") + } + + // Test invalid tool + toolReq := Request{ + JsonRpc: "2.0", + ID: 2, + Method: methodToolsCall, + Params: json.RawMessage(`{"name":"non_existent_tool"}`), + } + + // Call process method directly + server.processToolCall(client, toolReq) + + // Check response + select { + case response := <-client.channel: + assert.Contains(t, response, `"error":{"code":-32602,"message":"Tool 'non_existent_tool' not found"}`) + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for error response") + } + + // Test invalid prompt + promptReq := Request{ + JsonRpc: "2.0", + ID: 3, + Method: methodPromptsGet, + Params: json.RawMessage(`{"name":"non_existent_prompt"}`), + } + + // Call process method directly + server.processGetPrompt(client, promptReq) + + // Check response + select { + case response := <-client.channel: + assert.Contains(t, response, `"error":{"code":-32602,"message":"Prompt 'non_existent_prompt' not found"}`) + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for error response") + } +} diff --git a/mcp/readme.md b/mcp/readme.md new file mode 100644 index 000000000..98094aca2 --- /dev/null +++ b/mcp/readme.md @@ -0,0 +1,62 @@ +# Model Context Protocol (MCP) SDK Implementation + +## Overview +This package implements a Model Context Protocol (MCP) server in Go that facilitates real-time communication between AI models and clients using Server-Sent Events (SSE). The implementation provides a framework for building AI-assisted applications with bidirectional communication capabilities. + +## Core Components + +### Server-Sent Events (SSE) Communication +- **Real-time Communication**: Robust SSE-based communication system that maintains persistent connections with clients +- **Connection Management**: Client registration, message broadcasting, and client cleanup mechanisms +- **Event Handling**: Event types for tools, prompts, and resources changes + +### JSON-RPC Implementation +- **Request Processing**: Complete JSON-RPC request processor for handling MCP protocol methods +- **Response Formatting**: Proper response formatting according to JSON-RPC specifications +- **Error Handling**: Comprehensive error handling with appropriate error codes + +### Tool Management +- **Tool Registration**: System to register custom tools with handlers +- **Tool Execution**: Mechanism to execute tool functions with proper timeout handling +- **Result Handling**: Flexible result handling supporting various return types (string, JSON, images) + +### Prompt System +- **Prompt Registration**: System for registering both static and dynamic prompts +- **Argument Validation**: Validation for required arguments and default values for optional ones +- **Message Generation**: Handlers that generate properly formatted conversation messages + +### Resource Management +- **Resource Registration**: System for managing and accessing external resources +- **Content Delivery**: Handlers for delivering resource content to clients on demand +- **Resource Subscription**: Mechanisms for clients to subscribe to resource updates + +### Protocol Features +- **Initialization Sequence**: Proper handshaking with capability negotiation +- **Notification Handling**: Support for both standard and client-specific notifications +- **Message Routing**: Intelligent routing of requests to appropriate handlers + +## Technical Highlights + +### Configuration System +- **Flexible Configuration**: Configuration system with sensible defaults and customization options +- **CORS Support**: Configurable CORS settings for cross-origin requests +- **Server Information**: Proper server identification and versioning + +### Client Session Management +- **Session Tracking**: Client session tracking with unique identifiers +- **Connection Health**: Ping/pong mechanism to maintain connection health +- **Initialization State**: Client initialization state tracking + +### Content Handling +- **Multi-format Content**: Support for text, code, and binary content +- **MIME Type Support**: Proper MIME type identification for various content types +- **Audience Annotations**: Content audience annotations for user/assistant targeting + +## Usage + +To create and use an MCP server, see the examples directory for practical implementation examples including: +- Tool registration and execution +- Static and dynamic prompt creation +- Resource handling with proper URI identification +- Embedded resources in prompt responses +- Client connection management \ No newline at end of file diff --git a/mcp/server.go b/mcp/server.go new file mode 100644 index 000000000..f2a2f8889 --- /dev/null +++ b/mcp/server.go @@ -0,0 +1,974 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + "github.com/google/uuid" + "github.com/zeromicro/go-zero/core/logx" + "github.com/zeromicro/go-zero/rest" +) + +func NewMcpServer(c McpConf) McpServer { + var server *rest.Server + if len(c.Mcp.Cors) == 0 { + server = rest.MustNewServer(c.RestConf) + } else { + server = rest.MustNewServer(c.RestConf, rest.WithCors(c.Mcp.Cors...)) + } + + if len(c.Mcp.Name) == 0 { + c.Mcp.Name = c.Name + } + if len(c.Mcp.BaseUrl) == 0 { + c.Mcp.BaseUrl = fmt.Sprintf("http://localhost:%d", c.Port) + } + + s := &sseMcpServer{ + conf: c, + server: server, + clients: make(map[string]*mcpClient), + tools: make(map[string]Tool), + prompts: make(map[string]Prompt), + resources: make(map[string]Resource), + } + + // SSE endpoint for real-time updates + s.server.AddRoute(rest.Route{ + Method: http.MethodGet, + Path: s.conf.Mcp.SseEndpoint, + Handler: s.handleSSE, + }, rest.WithSSE()) + + // JSON-RPC message endpoint for regular requests + s.server.AddRoute(rest.Route{ + Method: http.MethodPost, + Path: s.conf.Mcp.MessageEndpoint, + Handler: s.handleRequest, + }) + + return s +} + +// RegisterPrompt registers a new prompt with the server +func (s *sseMcpServer) RegisterPrompt(prompt Prompt) { + s.promptsLock.Lock() + s.prompts[prompt.Name] = prompt + s.promptsLock.Unlock() + // Notify clients about the new prompt + s.broadcast(eventPromptsListChanged, map[string][]Prompt{keyPrompts: {prompt}}) +} + +// RegisterResource registers a new resource with the server +func (s *sseMcpServer) RegisterResource(resource Resource) { + s.resourcesLock.Lock() + s.resources[resource.URI] = resource + s.resourcesLock.Unlock() + // Notify clients about the new resource + s.broadcast(eventResourcesListChanged, map[string][]Resource{keyResources: {resource}}) +} + +// RegisterTool registers a new tool with the server +func (s *sseMcpServer) RegisterTool(tool Tool) error { + if tool.Handler == nil { + return fmt.Errorf("tool '%s' has no handler function", tool.Name) + } + + s.toolsLock.Lock() + s.tools[tool.Name] = tool + s.toolsLock.Unlock() + // Notify clients about the new tool + s.broadcast(eventToolsListChanged, map[string][]Tool{keyTools: {tool}}) + return nil +} + +// Start implements McpServer. +func (s *sseMcpServer) Start() { + s.server.Start() +} + +func (s *sseMcpServer) Stop() { + s.server.Stop() +} + +// broadcast sends a message to all connected clients +// It uses Server-Sent Events (SSE) format for real-time communication +func (s *sseMcpServer) broadcast(event string, data any) { + jsonData, err := json.Marshal(data) + if err != nil { + logx.Errorf("Failed to marshal broadcast data: %v", err) + return + } + + // Lock only while reading the clients map + s.clientsLock.Lock() + clients := make([]*mcpClient, 0, len(s.clients)) + for _, client := range s.clients { + clients = append(clients, client) + } + s.clientsLock.Unlock() + + clientCount := len(clients) + if clientCount == 0 { + return + } + + logx.Infof("Broadcasting event '%s' to %d clients", event, clientCount) + + // Use CRLF line endings as per SSE specification + message := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(jsonData)) + + // Send messages without holding the lock + for _, client := range clients { + select { + case client.channel <- message: + // Message sent successfully + default: + // Channel buffer is full, log warning and continue + logx.Errorf("Client channel buffer full, dropping message for client %s", client.id) + } + } +} + +// cleanupClient removes a client from the active clients map +func (s *sseMcpServer) cleanupClient(sessionId string) { + s.clientsLock.Lock() + defer s.clientsLock.Unlock() + + if client, exists := s.clients[sessionId]; exists { + // Close the channel to signal any goroutines waiting on it + close(client.channel) + // Remove from active clients + delete(s.clients, sessionId) + logx.Infof("Cleaned up client %s (remaining clients: %d)", sessionId, len(s.clients)) + } +} + +// handleRequest handles MCP JSON-RPC requests +func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) { + // Extract sessionId from query parameters + sessionId := r.URL.Query().Get(sessionIdKey) + if len(sessionId) == 0 { + http.Error(w, fmt.Sprintf("Missing %s", sessionIdKey), http.StatusBadRequest) + return + } + + // Check if the client with this sessionId exists + s.clientsLock.Lock() + client, exists := s.clients[sessionId] + s.clientsLock.Unlock() + + if !exists { + http.Error(w, fmt.Sprintf("Invalid or expired %s", sessionIdKey), http.StatusBadRequest) + return + } + + var req Request + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Invalid request", http.StatusBadRequest) + return + } + + w.WriteHeader(http.StatusAccepted) + + // For notification methods (no ID), we don't send a response + isNotification := req.ID == 0 + + // Special handling for initialization sequence + // Always allow initialize and notifications/initialized regardless of client state + if req.Method == methodInitialize { + logx.Infof("Processing initialize request with ID: %d", req.ID) + s.processInitialize(client, req) + logx.Infof("Sent initialize response for ID: %d, waiting for notifications/initialized", req.ID) + return + } else if req.Method == methodNotificationsInitialized { + // Handle initialized notification + logx.Info("Received notifications/initialized notification") + if !isNotification { + s.sendErrorResponse(client, req.ID, "Method should be used as a notification", errCodeInvalidRequest) + return + } + s.processNotificationInitialized(client) + return + } else if !client.initialized && req.Method != methodNotificationsCancelled { + // Block most requests until client is initialized (except for cancellations) + s.sendErrorResponse(client, req.ID, "Client not fully initialized, waiting for notifications/initialized", + errCodeClientNotInitialized) + return + } + + // Process normal requests only after initialization + switch req.Method { + case methodToolsCall: + logx.Infof("Received tools call request with ID: %d", req.ID) + s.processToolCall(client, req) + logx.Infof("Sent tools call response for ID: %d", req.ID) + case methodToolsList: + logx.Infof("Processing tools/list request with ID: %d", req.ID) + s.processListTools(client, req) + logx.Infof("Sent tools/list response for ID: %d", req.ID) + case methodPromptsList: + logx.Infof("Processing prompts/list request with ID: %d", req.ID) + s.processListPrompts(client, req) + logx.Infof("Sent prompts/list response for ID: %d", req.ID) + case methodPromptsGet: + logx.Infof("Processing prompts/get request with ID: %d", req.ID) + s.processGetPrompt(client, req) + logx.Infof("Sent prompts/get response for ID: %d", req.ID) + case methodResourcesList: + logx.Infof("Processing resources/list request with ID: %d", req.ID) + s.processListResources(client, req) + logx.Infof("Sent resources/list response for ID: %d", req.ID) + case methodResourcesRead: + logx.Infof("Processing resources/read request with ID: %d", req.ID) + s.processResourcesRead(client, req) + logx.Infof("Sent resources/read response for ID: %d", req.ID) + case methodResourcesSubscribe: + logx.Infof("Processing resources/subscribe request with ID: %d", req.ID) + s.processResourceSubscribe(client, req) + logx.Infof("Sent resources/subscribe response for ID: %d", req.ID) + case methodPing: + logx.Infof("Processing ping request with ID: %d", req.ID) + s.processPing(client, req) + case methodNotificationsCancelled: + logx.Infof("Received notifications/cancelled notification: %v", req.Params) + s.processNotificationCancelled(client, req) + default: + logx.Infof("Unknown method: %s", req.Method) + s.sendErrorResponse(client, req.ID, "Method not found", errCodeMethodNotFound) + } +} + +// handleSSE handles Server-Sent Events connections +func (s *sseMcpServer) handleSSE(w http.ResponseWriter, r *http.Request) { + // Generate a unique session ID for this client + sessionId := uuid.New().String() + + // Create new client with buffered channel to prevent blocking + client := &mcpClient{ + id: sessionId, + channel: make(chan string, eventChanSize), + } + + // Add client to active clients map + s.clientsLock.Lock() + s.clients[sessionId] = client + activeClients := len(s.clients) + s.clientsLock.Unlock() + + logx.Infof("New SSE connection established for client %s (active clients: %d)", + sessionId, activeClients) + + // Set proper SSE headers + w.Header().Set("Transfer-Encoding", "chunked") + + // Enable streaming + flusher, ok := w.(http.Flusher) + if !ok { + logx.Error("Streaming not supported by the underlying http.ResponseWriter") + http.Error(w, "Streaming not supported", http.StatusInternalServerError) + return + } + + // Send the message endpoint URL to the client + endpoint := fmt.Sprintf("%s%s?%s=%s", + s.conf.Mcp.BaseUrl, s.conf.Mcp.MessageEndpoint, sessionIdKey, sessionId) + + // Format and send the endpoint message + endpointMsg := formatSSEMessage(eventEndpoint, []byte(endpoint)) + if _, err := fmt.Fprint(w, endpointMsg); err != nil { + logx.Errorf("Failed to send endpoint message to client %s: %v", sessionId, err) + s.cleanupClient(sessionId) + return + } + flusher.Flush() + + // Set up keep-alive ping and client cleanup + ticker := time.NewTicker(pingInterval.Load()) + defer func() { + ticker.Stop() + s.cleanupClient(sessionId) + logx.Infof("SSE connection closed for client %s", sessionId) + }() + + // Message processing loop + for { + select { + case message, ok := <-client.channel: + if !ok { + // Channel was closed, end connection + logx.Infof("Client channel was closed for %s", sessionId) + return + } + + // Write message to the response + if _, err := fmt.Fprint(w, message); err != nil { + logx.Infof("Failed to write message to client %s: %v", sessionId, err) + return + } + flusher.Flush() + case <-ticker.C: + // Send keep-alive ping to maintain connection + ping := fmt.Sprintf(`{"type":"ping","timestamp":"%s"}`, time.Now().String()) + pingMsg := formatSSEMessage("ping", []byte(ping)) + if _, err := fmt.Fprint(w, pingMsg); err != nil { + logx.Errorf("Failed to send ping to client %s, closing connection: %v", sessionId, err) + return + } + flusher.Flush() + case <-r.Context().Done(): + // Client disconnected or request was canceled + logx.Infof("Client %s disconnected: context done", sessionId) + return + } + } +} + +// processInitialize processes the initialize request +func (s *sseMcpServer) processInitialize(client *mcpClient, req Request) { + // Create a proper JSON-RPC response that preserves the client's request ID + result := initializationResponse{ + ProtocolVersion: s.conf.Mcp.ProtocolVersion, + Capabilities: capabilities{ + Prompts: struct { + ListChanged bool `json:"listChanged"` + }{ + ListChanged: true, + }, + Resources: struct { + Subscribe bool `json:"subscribe"` + ListChanged bool `json:"listChanged"` + }{ + Subscribe: true, + ListChanged: true, + }, + Tools: struct { + ListChanged bool `json:"listChanged"` + }{ + ListChanged: true, + }, + }, + ServerInfo: serverInfo{ + Name: s.conf.Mcp.Name, + Version: s.conf.Mcp.Version, + }, + } + + // Mark client as initialized + client.initialized = true + + // Send response with client's original request ID + s.sendResponse(client, req.ID, result) +} + +// processListTools processes the tools/list request +func (s *sseMcpServer) processListTools(client *mcpClient, req Request) { + // Extract pagination params if any + var nextCursor string + var progressToken any + + // Extract meta data including progress token + if req.Params != nil { + var metaParams struct { + Cursor string `json:"cursor"` + Meta struct { + ProgressToken any `json:"progressToken"` + } `json:"_meta"` + } + if err := json.Unmarshal(req.Params, &metaParams); err == nil { + if len(metaParams.Cursor) > 0 { + nextCursor = metaParams.Cursor + } + progressToken = metaParams.Meta.ProgressToken + } + } + + var toolsList []Tool + s.toolsLock.Lock() + for _, tool := range s.tools { + toolsList = append(toolsList, tool) + } + s.toolsLock.Unlock() + + result := ListToolsResult{ + PaginatedResult: PaginatedResult{ + Result: Result{}, + NextCursor: Cursor(nextCursor), + }, + Tools: toolsList, + } + + // Add meta information if progress token was provided + if progressToken != nil { + result.Result.Meta = map[string]any{ + "progressToken": progressToken, + } + } + + s.sendResponse(client, req.ID, result) +} + +// processListPrompts processes the prompts/list request +func (s *sseMcpServer) processListPrompts(client *mcpClient, req Request) { + // Extract pagination params if any + var nextCursor string + if req.Params != nil { + var cursorParams struct { + Cursor string `json:"cursor"` + } + if err := json.Unmarshal(req.Params, &cursorParams); err == nil && cursorParams.Cursor != "" { + // If we have a valid cursor, we could use it for pagination + // For now, we're not actually implementing pagination, so this is just + // to show how it would be extracted from the request + _ = cursorParams.Cursor + } + } + + // Prepare prompt list + var promptsList []Prompt + s.promptsLock.Lock() + for _, prompt := range s.prompts { + promptsList = append(promptsList, prompt) + } + s.promptsLock.Unlock() + + // In a real implementation, you'd handle pagination here + // For now, we'll return all prompts at once + result := struct { + Prompts []Prompt `json:"prompts"` + NextCursor string `json:"nextCursor,omitempty"` + Meta *struct{} `json:"_meta,omitempty"` + }{ + Prompts: promptsList, + NextCursor: nextCursor, + } + + s.sendResponse(client, req.ID, result) +} + +// processListResources processes the resources/list request +func (s *sseMcpServer) processListResources(client *mcpClient, req Request) { + // Extract pagination params if any + var nextCursor string + var progressToken any + + // Extract meta information including progress token if available + if req.Params != nil { + var metaParams PaginatedParams + if err := json.Unmarshal(req.Params, &metaParams); err == nil { + if len(metaParams.Cursor) > 0 { + nextCursor = metaParams.Cursor + } + progressToken = metaParams.Meta.ProgressToken + } + } + + var resourcesList []Resource + s.resourcesLock.Lock() + for _, resource := range s.resources { + // Create a copy without the handler function which shouldn't be sent to clients + resourceCopy := Resource{ + URI: resource.URI, + Name: resource.Name, + Description: resource.Description, + MimeType: resource.MimeType, + } + resourcesList = append(resourcesList, resourceCopy) + } + s.resourcesLock.Unlock() + + // Create proper ResourcesListResult according to MCP specification + result := ResourcesListResult{ + PaginatedResult: PaginatedResult{ + Result: Result{}, + NextCursor: Cursor(nextCursor), + }, + Resources: resourcesList, + } + + // Add meta information if progress token was provided + if progressToken != nil { + result.Result.Meta = map[string]any{ + "progressToken": progressToken, + } + } + + s.sendResponse(client, req.ID, result) +} + +// processGetPrompt processes the prompts/get request +func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) { + type GetPromptParams struct { + Name string `json:"name"` + Arguments map[string]string `json:"arguments,omitempty"` + } + + var params GetPromptParams + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams) + return + } + + // Check if prompt exists + s.promptsLock.Lock() + prompt, exists := s.prompts[params.Name] + s.promptsLock.Unlock() + if !exists { + message := fmt.Sprintf("Prompt '%s' not found", params.Name) + s.sendErrorResponse(client, req.ID, message, errCodeInvalidParams) + return + } + + logx.Infof("Processing prompt request: %s with %d arguments", prompt.Name, len(params.Arguments)) + + // Validate required arguments + missingArgs := validatePromptArguments(prompt, params.Arguments) + if len(missingArgs) > 0 { + message := fmt.Sprintf("Missing required arguments: %s", strings.Join(missingArgs, ", ")) + s.sendErrorResponse(client, req.ID, message, errCodeInvalidParams) + return + } + + // Apply default values for missing optional arguments + args := applyDefaultArguments(prompt, params.Arguments) + + // Generate messages using handler or static content + var messages []PromptMessage + var err error + + if prompt.Handler != nil { + // Use dynamic handler to generate messages + logx.Info("Using prompt handler to generate content") + messages, err = prompt.Handler(args) + if err != nil { + logx.Errorf("Error from prompt handler: %v", err) + s.sendErrorResponse(client, req.ID, fmt.Sprintf("Error generating prompt content: %v", err), errCodeInternalError) + return + } + } else { + // No handler, generate messages from static content + var messageText string + if prompt.Content != "" { + messageText = prompt.Content + + // Apply argument substitutions to static content + for key, value := range args { + placeholder := fmt.Sprintf("{{%s}}", key) + messageText = strings.Replace(messageText, placeholder, value, -1) + } + } else { + // No content, use a default fallback + topic := "this topic" + if t, ok := args["topic"]; ok && t != "" { + topic = t + } + messageText = fmt.Sprintf("Tell me about %s", topic) + } + + // Create a single user message with the content + messages = []PromptMessage{ + { + Role: roleUser, + Content: TextContent{ + Type: contentTypeText, + Text: messageText, + }, + }, + } + } + + // Construct the response according to MCP spec + result := struct { + Description string `json:"description,omitempty"` + Messages []PromptMessage `json:"messages"` + }{ + Description: prompt.Description, + Messages: messages, + } + + s.sendResponse(client, req.ID, result) +} + +// validatePromptArguments checks if all required arguments are provided +// Returns a list of missing required arguments +func validatePromptArguments(prompt Prompt, providedArgs map[string]string) []string { + var missingArgs []string + + for _, arg := range prompt.Arguments { + if arg.Required { + if value, exists := providedArgs[arg.Name]; !exists || len(value) == 0 { + missingArgs = append(missingArgs, arg.Name) + } + } + } + + return missingArgs +} + +// applyDefaultArguments adds default values for missing optional arguments +func applyDefaultArguments(prompt Prompt, providedArgs map[string]string) map[string]string { + result := make(map[string]string) + + // Copy all provided arguments + for k, v := range providedArgs { + result[k] = v + } + + // Add defaults for missing arguments + for _, arg := range prompt.Arguments { + if _, exists := result[arg.Name]; !exists && arg.Default != "" { + result[arg.Name] = arg.Default + } + } + + return result +} + +// processToolCall processes the tools/call request +func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) { + var toolCallParams struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments,omitempty"` + Meta struct { + ProgressToken any `json:"progressToken"` + } `json:"_meta,omitempty"` + } + + // Handle different types of req.Params + // If it's a RawMessage (JSON), unmarshal it + if err := json.Unmarshal(req.Params, &toolCallParams); err != nil { + logx.Errorf("Failed to unmarshal tool call params: %v", err) + s.sendErrorResponse(client, req.ID, "Invalid tool call parameters", errCodeInvalidParams) + return + } + + // Extract progress token if available + progressToken := toolCallParams.Meta.ProgressToken + + // Find the requested tool + s.toolsLock.Lock() + tool, exists := s.tools[toolCallParams.Name] + s.toolsLock.Unlock() + if !exists { + s.sendErrorResponse(client, req.ID, fmt.Sprintf("Tool '%s' not found", + toolCallParams.Name), errCodeInvalidParams) + return + } + + // Create a context with the configured timeout + ctx, cancel := context.WithTimeout(context.Background(), s.conf.Mcp.ToolTimeout) + defer cancel() + + // Log parameters before execution + logx.Infof("Executing tool '%s' with arguments: %#v", toolCallParams.Name, toolCallParams.Arguments) + + // Execute the tool handler with timeout handling + var result any + var err error + + // Create a channel to receive the result + resultCh := make(chan struct { + result any + err error + }, 1) + + // Execute the tool handler in a goroutine + go func() { + toolResult, toolErr := tool.Handler(toolCallParams.Arguments) + resultCh <- struct { + result any + err error + }{ + result: toolResult, + err: toolErr, + } + }() + + // Wait for either the result or a timeout + select { + case res := <-resultCh: + result = res.result + err = res.err + case <-ctx.Done(): + // Handle timeout + logx.Errorf("Tool execution timed out after %v: %s", s.conf.Mcp.ToolTimeout, toolCallParams.Name) + s.sendErrorResponse(client, req.ID, "Tool execution timed out", errCodeTimeout) + return + } + + // Create the base result structure with metadata + callToolResult := CallToolResult{ + Result: Result{}, + Content: []any{}, + IsError: false, + } + + // Add meta information if progress token was provided + if progressToken != nil { + callToolResult.Result.Meta = map[string]any{ + "progressToken": progressToken, + } + } + + // Check if there was an error during tool execution + if err != nil { + // According to the spec, for tool-level errors (as opposed to protocol-level errors), + // we should report them inside the result with isError=true + logx.Errorf("Tool execution reported error: %v", err) + + callToolResult.Content = []any{ + TextContent{ + Type: contentTypeText, + Text: fmt.Sprintf("Error: %v", err), + }, + } + callToolResult.IsError = true + s.sendResponse(client, req.ID, callToolResult) + return + } + + // Format the response according to the CallToolResult schema + switch v := result.(type) { + case string: + // Simple string becomes text content + callToolResult.Content = append(callToolResult.Content, TextContent{ + Type: contentTypeText, + Text: v, + Annotations: &Annotations{ + Audience: []roleType{roleUser, roleAssistant}, + }, + }) + case map[string]any: + // JSON-like object becomes formatted JSON text + jsonStr, err := json.Marshal(v) + if err != nil { + jsonStr = []byte(err.Error()) + } + callToolResult.Content = append(callToolResult.Content, TextContent{ + Type: contentTypeText, + Text: string(jsonStr), + Annotations: &Annotations{ + Audience: []roleType{roleUser, roleAssistant}, + }, + }) + case TextContent: + // Direct TextContent object + callToolResult.Content = append(callToolResult.Content, v) + case ImageContent: + // Direct ImageContent object + callToolResult.Content = append(callToolResult.Content, v) + case []any: + // Array of content items + callToolResult.Content = v + case ToolResult: + // Handle legacy ToolResult type + switch v.Type { + case contentTypeText: + callToolResult.Content = append(callToolResult.Content, TextContent{ + Type: contentTypeText, + Text: fmt.Sprintf("%v", v.Content), + Annotations: &Annotations{ + Audience: []roleType{roleUser, roleAssistant}, + }, + }) + case contentTypeImage: + if imgData, ok := v.Content.(map[string]any); ok { + callToolResult.Content = append(callToolResult.Content, ImageContent{ + Type: contentTypeImage, + Data: fmt.Sprintf("%v", imgData["data"]), + MimeType: fmt.Sprintf("%v", imgData["mimeType"]), + }) + } + default: + callToolResult.Content = append(callToolResult.Content, TextContent{ + Type: contentTypeText, + Text: fmt.Sprintf("%v", v.Content), + Annotations: &Annotations{ + Audience: []roleType{roleUser, roleAssistant}, + }, + }) + } + default: + // For any other type, convert to string + callToolResult.Content = append(callToolResult.Content, TextContent{ + Type: contentTypeText, + Text: fmt.Sprintf("%v", v), + Annotations: &Annotations{ + Audience: []roleType{roleUser, roleAssistant}, + }, + }) + } + + logx.Infof("Tool call result: %#v", callToolResult) + s.sendResponse(client, req.ID, callToolResult) +} + +// processResourcesRead processes the resources/read request +func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) { + var params ResourceReadParams + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams) + return + } + + // Find resource that matches the URI + s.resourcesLock.Lock() + resource, exists := s.resources[params.URI] + s.resourcesLock.Unlock() + + if !exists { + s.sendErrorResponse(client, req.ID, fmt.Sprintf("Resource with URI '%s' not found", + params.URI), errCodeResourceNotFound) + return + } + + // If no handler is provided, return an empty content array + if resource.Handler == nil { + result := ResourceReadResult{ + Contents: []ResourceContent{ + { + URI: params.URI, + MimeType: resource.MimeType, + Text: "", + }, + }, + } + s.sendResponse(client, req.ID, result) + return + } + + // Execute the resource handler + content, err := resource.Handler() + if err != nil { + s.sendErrorResponse(client, req.ID, fmt.Sprintf("Error reading resource: %v", err), + errCodeInternalError) + return + } + + // Ensure the URI is set if not already provided by the handler + if len(content.URI) == 0 { + content.URI = params.URI + } + + // Ensure MimeType is set if available from the resource definition + if len(content.MimeType) == 0 && resource.MimeType != "" { + content.MimeType = resource.MimeType + } + + // Create response with contents from the handler + // The MCP specification requires a contents array + result := ResourceReadResult{ + Contents: []ResourceContent{content}, + } + + s.sendResponse(client, req.ID, result) +} + +// processResourceSubscribe processes the resources/subscribe request +func (s *sseMcpServer) processResourceSubscribe(client *mcpClient, req Request) { + var params ResourceSubscribeParams + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams) + return + } + + // Check if the resource exists + s.resourcesLock.Lock() + _, exists := s.resources[params.URI] + s.resourcesLock.Unlock() + + if !exists { + s.sendErrorResponse(client, req.ID, fmt.Sprintf("Resource with URI '%s' not found", + params.URI), errCodeResourceNotFound) + return + } + + // Send success response for the subscription + s.sendResponse(client, req.ID, struct{}{}) + + logx.Infof("Client %s subscribed to resource '%s'", client.id, params.URI) +} + +// processNotificationCancelled processes the notifications/cancelled notification +func (s *sseMcpServer) processNotificationCancelled(client *mcpClient, req Request) { + // Extract the requestId that was canceled + type CancelParams struct { + RequestId int64 `json:"requestId"` + Reason string `json:"reason"` + } + + var params CancelParams + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + logx.Errorf("Failed to parse cancellation params: %v", err) + return + } + + logx.Infof("Request %d was cancelled by client. Reason: %s", params.RequestId, params.Reason) +} + +// processNotificationInitialized processes the notifications/initialized notification +func (s *sseMcpServer) processNotificationInitialized(client *mcpClient) { + // Mark the client as properly initialized + client.initialized = true + logx.Infof("Client %s is now fully initialized and ready for normal operations", client.id) +} + +// processPing processes the ping request and responds immediately +func (s *sseMcpServer) processPing(client *mcpClient, req Request) { + // A ping request should simply respond with an empty result to confirm the server is alive + logx.Infof("Received ping request with ID: %d", req.ID) + + // Send an empty response with client's original request ID + s.sendResponse(client, req.ID, struct{}{}) + + logx.Infof("Sent ping response for ID: %d", req.ID) +} + +// sendErrorResponse sends an error response via the SSE channel +func (s *sseMcpServer) sendErrorResponse(client *mcpClient, id int64, message string, code int) { + errorResponse := struct { + JsonRpc string `json:"jsonrpc"` + ID int64 `json:"id"` + Error errorMessage `json:"error"` + }{ + JsonRpc: jsonRpcVersion, + ID: id, + Error: errorMessage{ + Code: code, + Message: message, + }, + } + + // all fields are primitive types, impossible to fail + jsonData, _ := json.Marshal(errorResponse) + // Use CRLF line endings as requested + sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData)) + logx.Infof("Sending error for ID %d: %s", id, sseMessage) + + client.channel <- sseMessage +} + +// sendResponse sends a success response via the SSE channel +func (s *sseMcpServer) sendResponse(client *mcpClient, id int64, result any) { + response := Response{ + JsonRpc: jsonRpcVersion, + ID: id, + Result: result, + } + + jsonData, err := json.Marshal(response) + if err != nil { + s.sendErrorResponse(client, id, "Failed to marshal response", errCodeInternalError) + return + } + + // Use CRLF line endings as requested + sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData)) + logx.Infof("Sending response for ID %d: %s", id, sseMessage) + + client.channel <- sseMessage +} diff --git a/mcp/server_test.go b/mcp/server_test.go new file mode 100644 index 000000000..979060e79 --- /dev/null +++ b/mcp/server_test.go @@ -0,0 +1,3418 @@ +package mcp + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zeromicro/go-zero/core/conf" + "github.com/zeromicro/go-zero/core/logx/logtest" +) + +// mockMcpServer is a helper for testing the MCP server +// It encapsulates the server and test server setup and teardown logic +type mockMcpServer struct { + server *sseMcpServer + testServer *httptest.Server + requestId int64 +} + +// newMockMcpServer initializes a mock MCP server for testing +func newMockMcpServer(t *testing.T) *mockMcpServer { + const yamlConf = `name: test-server +host: localhost +port: 8080 +mcp: + name: mcp-test-server + toolTimeout: 5s +` + + var c McpConf + assert.NoError(t, conf.LoadFromYamlBytes([]byte(yamlConf), &c)) + + server := NewMcpServer(c).(*sseMcpServer) + mux := http.NewServeMux() + mux.HandleFunc(c.Mcp.SseEndpoint, server.handleSSE) + mux.HandleFunc(c.Mcp.MessageEndpoint, server.handleRequest) + testServer := httptest.NewServer(mux) + server.conf.Mcp.BaseUrl = testServer.URL + + return &mockMcpServer{ + server: server, + testServer: testServer, + requestId: 1, + } +} + +// shutdown closes the test server +func (m *mockMcpServer) shutdown() { + m.testServer.Close() +} + +// registerExamplePrompt registers a test prompt +func (m *mockMcpServer) registerExamplePrompt() { + m.server.RegisterPrompt(Prompt{ + Name: "test.prompt", + Description: "A test prompt", + }) +} + +// registerExampleResource registers a test resource +func (m *mockMcpServer) registerExampleResource() { + m.server.RegisterResource(Resource{ + Name: "test.resource", + URI: "file:///test.file", + Description: "A test resource", + }) +} + +// registerExampleTool registers a test tool +func (m *mockMcpServer) registerExampleTool() { + _ = m.server.RegisterTool(Tool{ + Name: "test.tool", + Description: "A test tool", + InputSchema: InputSchema{ + Type: "object", + Properties: map[string]any{ + "input": map[string]any{ + "type": "string", + "description": "Input parameter", + }, + }, + Required: []string{"input"}, + }, + Handler: func(params map[string]any) (any, error) { + input, ok := params["input"].(string) + if !ok { + return nil, fmt.Errorf("invalid input parameter") + } + return fmt.Sprintf("Processed: %s", input), nil + }, + }) +} + +// Helper function to create and add a test client +func addTestClient(server *sseMcpServer, clientID string, initialized bool) *mcpClient { + client := &mcpClient{ + id: clientID, + channel: make(chan string, eventChanSize), + initialized: initialized, + } + server.clientsLock.Lock() + server.clients[clientID] = client + server.clientsLock.Unlock() + return client +} + +func TestNewMcpServer(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + mock.registerExamplePrompt() + mock.registerExampleResource() + mock.registerExampleTool() + + require.NotNil(t, mock.server, "Server should be created") + assert.NotEmpty(t, mock.server.tools, "Tools map should be initialized") + assert.NotEmpty(t, mock.server.prompts, "Prompts map should be initialized") + assert.NotEmpty(t, mock.server.resources, "Resources map should be initialized") +} + +func TestNewMcpServer_WithCors(t *testing.T) { + const yamlConf = `name: test-server +host: localhost +port: 8080 +mcp: + cors: + - http://localhost:3000 + toolTimeout: 5s +` + + var c McpConf + assert.NoError(t, conf.LoadFromYamlBytes([]byte(yamlConf), &c)) + + server := NewMcpServer(c).(*sseMcpServer) + assert.Equal(t, "test-server", server.conf.Name, "Server name should be set") +} + +func TestHandleRequest_badRequest(t *testing.T) { + t.Run("empty session ID", func(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Create a request with an invalid session ID + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 1, + Method: methodToolsCall, + Params: []byte(`{"sessionId": "invalid-session"}`), + } + + jsonBody, _ := json.Marshal(req) + r := httptest.NewRequest(http.MethodPost, "/?session_id=", bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + mock.server.handleRequest(w, r) + assert.Equal(t, http.StatusBadRequest, w.Code) + }) + + t.Run("bad body", func(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + addTestClient(mock.server, "test-session", true) + + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-session", bytes.NewReader([]byte(`{`))) + w := httptest.NewRecorder() + mock.server.handleRequest(w, r) + assert.Equal(t, http.StatusBadRequest, w.Code) + }) +} + +func TestRegisterTool(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + tool := Tool{ + Name: "example.tool", + Description: "An example tool", + InputSchema: InputSchema{ + Type: "object", + Properties: map[string]any{ + "input": map[string]any{ + "type": "string", + "description": "Input parameter", + }, + }, + }, + Handler: func(params map[string]any) (any, error) { + return "result", nil + }, + } + + // Test with valid tool + err := mock.server.RegisterTool(tool) + assert.NoError(t, err, "Should not error with valid tool") + + // Check tool was registered + _, exists := mock.server.tools["example.tool"] + assert.True(t, exists, "Tool should be registered") + + // Test with missing handler + invalidTool := tool + invalidTool.Name = "invalid.tool" + invalidTool.Handler = nil + err = mock.server.RegisterTool(invalidTool) + assert.Error(t, err, "Should error with missing handler") +} + +func TestRegisterPrompt(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + prompt := Prompt{ + Name: "example.prompt", + Description: "An example prompt", + } + + // Test registering prompt + mock.server.RegisterPrompt(prompt) + + // Check prompt was registered + _, exists := mock.server.prompts["example.prompt"] + assert.True(t, exists, "Prompt should be registered") +} + +func TestRegisterResource(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + resource := Resource{ + Name: "example.resource", + URI: "http://example.com/resource", + Description: "An example resource", + } + + // Test registering resource + mock.server.RegisterResource(resource) + + // Check resource was registered + _, exists := mock.server.resources["http://example.com/resource"] + assert.True(t, exists, "Resource should be registered") +} + +// TestToolCallBasic tests the basic functionality of a tool call +func TestToolsList(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Register a test tool + mock.registerExampleTool() + + // Simulate a client to test tool call + client := addTestClient(mock.server, "test-client", true) + + // Create a tool call request + params := struct { + Cursor string `json:"cursor"` + Meta struct { + ProgressToken any `json:"progressToken"` + } `json:"_meta"` + }{ + Cursor: "my-cursor", + } + + paramBytes, _ := json.Marshal(params) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 1, + Method: methodToolsList, + Params: paramBytes, + } + + // Process the tool call + mock.server.processListTools(client, req) + + // Get the response from the client's channel + select { + case response := <-client.channel: + evt, err := parseEvent(response) + assert.NoError(t, err) + + assert.Equal(t, eventMessage, evt.Type, "Event type should be message") + result, ok := evt.Data["result"].(map[string]any) + assert.True(t, ok) + assert.Equal(t, "my-cursor", result["nextCursor"], "Cursor should match") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for tool call response") + } +} + +// TestToolCallBasic tests the basic functionality of a tool call +func TestToolCallBasic(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Register a test tool + mock.registerExampleTool() + + // Simulate a client to test tool call + client := addTestClient(mock.server, "test-client", true) + + // Create a tool call request + params := struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + }{ + Name: "test.tool", + Arguments: map[string]any{ + "input": "test-input", + }, + } + + paramBytes, _ := json.Marshal(params) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 1, + Method: methodToolsCall, + Params: paramBytes, + } + + // Process the tool call + mock.server.processToolCall(client, req) + + // Get the response from the client's channel + select { + case response := <-client.channel: + // Check response format + assert.Contains(t, response, "event: message", "Response should have message event") + assert.Contains(t, response, "data:", "Response should have data") + + // Extract JSON from the SSE response + jsonStart := strings.Index(response, "{") + jsonEnd := strings.LastIndex(response, "}") + jsonStr := response[jsonStart : jsonEnd+1] + + // Parse the JSON + var parsed struct { + Result struct { + Content []map[string]any `json:"content"` + IsError bool `json:"isError"` + } `json:"result"` + } + + err := json.Unmarshal([]byte(jsonStr), &parsed) + require.NoError(t, err, "Response should be valid JSON") + + // Verify the response content + assert.Len(t, parsed.Result.Content, 1, "Response should contain one content item") + assert.Equal(t, "text", parsed.Result.Content[0]["type"], "Content type should be text") + assert.Equal(t, "Processed: test-input", parsed.Result.Content[0]["text"], "Tool result incorrect") + assert.False(t, parsed.Result.IsError, "Response should not be an error") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for tool call response") + } +} + +// TestToolCallMapResult tests a tool that returns a map[string]any result +func TestToolCallMapResult(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Register a tool that returns a map + mapTool := Tool{ + Name: "map.tool", + Description: "A tool that returns a map result", + InputSchema: InputSchema{ + Type: "object", + Properties: map[string]any{}, + }, + Handler: func(params map[string]any) (any, error) { + // Return a complex nested map structure + return map[string]any{ + "string": "value", + "number": 42, + "boolean": true, + "nested": map[string]any{ + "array": []string{"item1", "item2"}, + "obj": map[string]any{ + "key": "value", + }, + }, + "nullValue": nil, + }, nil + }, + } + + err := mock.server.RegisterTool(mapTool) + require.NoError(t, err) + + // Create a client + client := addTestClient(mock.server, "test-client", true) + + // Create a tool call request + params := struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + }{ + Name: "map.tool", + Arguments: map[string]any{}, + } + + paramBytes, _ := json.Marshal(params) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 1, + Method: methodToolsCall, + Params: paramBytes, + } + + // Process the tool call + mock.server.processToolCall(client, req) + + // Get the response from the client's channel + select { + case response := <-client.channel: + // Parse the response + jsonStart := strings.Index(response, "{") + jsonEnd := strings.LastIndex(response, "}") + require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") + + jsonStr := response[jsonStart : jsonEnd+1] + var parsed map[string]any + err := json.Unmarshal([]byte(jsonStr), &parsed) + require.NoError(t, err, "Response should be valid JSON") + + // Get the result + result, ok := parsed["result"].(map[string]any) + require.True(t, ok, "Response should have a result object") + + // Get the content array + content, ok := result["content"].([]any) + require.True(t, ok, "Result should have a content array") + require.NotEmpty(t, content, "Content should not be empty") + + // The first content item should be our map result converted to JSON text + firstItem, ok := content[0].(map[string]any) + require.True(t, ok, "First content item should be an object") + + // Verify it's a text content + contentType, ok := firstItem["type"].(string) + require.True(t, ok, "Content should have a type") + assert.Equal(t, "text", contentType, "Content type should be text") + + // Get the text content which should be our JSON + text, ok := firstItem["text"].(string) + require.True(t, ok, "Content should have text") + + // Verify the text is valid JSON and contains our data + var mapResult map[string]any + err = json.Unmarshal([]byte(text), &mapResult) + require.NoError(t, err, "Text should be valid JSON") + + // Verify the content of our map + assert.Equal(t, "value", mapResult["string"], "String value should match") + assert.Equal(t, float64(42), mapResult["number"], "Number value should match") + assert.Equal(t, true, mapResult["boolean"], "Boolean value should match") + + // Check nested structure + nested, ok := mapResult["nested"].(map[string]any) + require.True(t, ok, "Should have nested map") + + array, ok := nested["array"].([]any) + require.True(t, ok, "Should have array in nested map") + assert.Len(t, array, 2, "Array should have 2 items") + assert.Equal(t, "item1", array[0], "First array item should match") + + obj, ok := nested["obj"].(map[string]any) + require.True(t, ok, "Should have obj in nested map") + assert.Equal(t, "value", obj["key"], "Nested object key should match") + + // Check null value + _, exists := mapResult["nullValue"] + assert.True(t, exists, "Null value key should exist") + + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for tool call response") + } +} + +// TestToolCallArrayResult tests a tool that returns an array result +func TestToolCallArrayResult(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Register a tool that returns an array + arrayTool := Tool{ + Name: "array.tool", + Description: "A tool that returns an array result", + InputSchema: InputSchema{ + Type: "object", + Properties: map[string]any{}, + }, + Handler: func(params map[string]any) (any, error) { + // Return an array of mixed content types + return []any{ + "string item", + 42, + true, + map[string]any{"key": "value"}, + []string{"nested", "array"}, + nil, + }, nil + }, + } + + err := mock.server.RegisterTool(arrayTool) + require.NoError(t, err) + + // Create a client + client := addTestClient(mock.server, "test-client", true) + + // Create a tool call request + params := struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + }{ + Name: "array.tool", + Arguments: map[string]any{}, + } + + paramBytes, _ := json.Marshal(params) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 1, + Method: methodToolsCall, + Params: paramBytes, + } + + // Process the tool call + mock.server.processToolCall(client, req) + + // Get the response from the client's channel + select { + case response := <-client.channel: + // Parse the response + jsonStart := strings.Index(response, "{") + jsonEnd := strings.LastIndex(response, "}") + require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") + + jsonStr := response[jsonStart : jsonEnd+1] + var parsed map[string]any + err := json.Unmarshal([]byte(jsonStr), &parsed) + require.NoError(t, err, "Response should be valid JSON") + + // Get the result + result, ok := parsed["result"].(map[string]any) + require.True(t, ok, "Response should have a result object") + + // Get the content array + content, ok := result["content"].([]any) + require.True(t, ok, "Result should have a content array") + require.Equal(t, 6, len(content), "Content should have 6 items, one for each array item") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for tool call response") + } +} + +// TestToolCallTextContentResult tests a tool that returns a TextContent result +func TestToolCallTextContentResult(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Register a tool that returns a TextContent + textContentTool := Tool{ + Name: "text.content.tool", + Description: "A tool that returns a TextContent result", + InputSchema: InputSchema{ + Type: "object", + Properties: map[string]any{}, + }, + Handler: func(params map[string]any) (any, error) { + // Return a TextContent object directly + return TextContent{ + Type: "text", + Text: "This is a direct TextContent result", + Annotations: &Annotations{ + Audience: []roleType{roleUser, roleAssistant}, + Priority: func() *float64 { p := 0.9; return &p }(), + }, + }, nil + }, + } + + err := mock.server.RegisterTool(textContentTool) + require.NoError(t, err) + + // Create a client + client := addTestClient(mock.server, "test-client", true) + + // Create a tool call request + params := struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + }{ + Name: "text.content.tool", + Arguments: map[string]any{}, + } + + paramBytes, _ := json.Marshal(params) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 1, + Method: methodToolsCall, + Params: paramBytes, + } + + // Process the tool call + mock.server.processToolCall(client, req) + + // Get the response from the client's channel + select { + case response := <-client.channel: + // Parse the response + jsonStart := strings.Index(response, "{") + jsonEnd := strings.LastIndex(response, "}") + require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") + + jsonStr := response[jsonStart : jsonEnd+1] + var parsed map[string]any + err := json.Unmarshal([]byte(jsonStr), &parsed) + require.NoError(t, err, "Response should be valid JSON") + + // Get the result + result, ok := parsed["result"].(map[string]any) + require.True(t, ok, "Response should have a result object") + + // Get the content array + content, ok := result["content"].([]any) + require.True(t, ok, "Result should have a content array") + require.NotEmpty(t, content, "Content should not be empty") + + // The first content item should be our TextContent + firstItem, ok := content[0].(map[string]any) + require.True(t, ok, "First content item should be an object") + + // Verify it's a text content + contentType, ok := firstItem["type"].(string) + require.True(t, ok, "Content should have a type") + assert.Equal(t, "text", contentType, "Content type should be text") + + // Check text content + text, ok := firstItem["text"].(string) + require.True(t, ok, "Content should have text") + assert.Equal(t, "This is a direct TextContent result", text, "Text content should match") + + // Check annotations + annotations, ok := firstItem["annotations"].(map[string]any) + require.True(t, ok, "Should have annotations") + + audience, ok := annotations["audience"].([]any) + require.True(t, ok, "Should have audience in annotations") + assert.Len(t, audience, 2, "Audience should have 2 items") + + priority, ok := annotations["priority"].(float64) + require.True(t, ok, "Should have priority in annotations") + assert.Equal(t, 0.9, priority, "Priority should match") + + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for tool call response") + } +} + +// TestToolCallImageContentResult tests a tool that returns an ImageContent result +func TestToolCallImageContentResult(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Register a tool that returns an ImageContent + imageContentTool := Tool{ + Name: "image.content.tool", + Description: "A tool that returns an ImageContent result", + InputSchema: InputSchema{ + Type: "object", + Properties: map[string]any{}, + }, + Handler: func(params map[string]any) (any, error) { + // Return an ImageContent object directly + return ImageContent{ + Type: "image", + Data: "dGVzdCBpbWFnZSBkYXRhIChiYXNlNjQgZW5jb2RlZCk=", // "test image data (base64 encoded)" in base64 + MimeType: "image/png", + }, nil + }, + } + + err := mock.server.RegisterTool(imageContentTool) + require.NoError(t, err) + + // Create a client + client := addTestClient(mock.server, "test-client", true) + + // Create a tool call request + params := struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + }{ + Name: "image.content.tool", + Arguments: map[string]any{}, + } + + paramBytes, _ := json.Marshal(params) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 1, + Method: methodToolsCall, + Params: paramBytes, + } + + // Process the tool call + mock.server.processToolCall(client, req) + + // Get the response from the client's channel + select { + case response := <-client.channel: + // Parse the response + jsonStart := strings.Index(response, "{") + jsonEnd := strings.LastIndex(response, "}") + require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") + + jsonStr := response[jsonStart : jsonEnd+1] + var parsed map[string]any + err := json.Unmarshal([]byte(jsonStr), &parsed) + require.NoError(t, err, "Response should be valid JSON") + + // Get the result + result, ok := parsed["result"].(map[string]any) + require.True(t, ok, "Response should have a result object") + + // Get the content array + content, ok := result["content"].([]any) + require.True(t, ok, "Result should have a content array") + require.NotEmpty(t, content, "Content should not be empty") + + // The first content item should be our ImageContent + firstItem, ok := content[0].(map[string]any) + require.True(t, ok, "First content item should be an object") + + // Verify it's an image content + contentType, ok := firstItem["type"].(string) + require.True(t, ok, "Content should have a type") + assert.Equal(t, "image", contentType, "Content type should be image") + + // Check image data + data, ok := firstItem["data"].(string) + require.True(t, ok, "Content should have data") + assert.Equal(t, "dGVzdCBpbWFnZSBkYXRhIChiYXNlNjQgZW5jb2RlZCk=", data, "Image data should match") + + // Check mime type + mimeType, ok := firstItem["mimeType"].(string) + require.True(t, ok, "Content should have mimeType") + assert.Equal(t, "image/png", mimeType, "MimeType should match") + + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for tool call response") + } +} + +// TestToolCallToolResultType tests a tool that returns a ToolResult type +func TestToolCallToolResultType(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + toolResultTool := Tool{ + Name: "toolresult.tool", + Description: "A tool that returns a ToolResult object", + InputSchema: InputSchema{ + Type: "object", + Properties: map[string]any{}, + }, + Handler: func(params map[string]any) (any, error) { + return ToolResult{ + Type: "text", + Content: "This is a ToolResult with text content type", + }, nil + }, + } + err := mock.server.RegisterTool(toolResultTool) + require.NoError(t, err) + + toolResultImageTool := Tool{ + Name: "toolresult.image.tool", + Description: "A tool that returns a ToolResult with image content", + InputSchema: InputSchema{ + Type: "object", + Properties: map[string]any{}, + }, + Handler: func(params map[string]any) (any, error) { + return ToolResult{ + Type: "image", + Content: map[string]any{ + "data": "dGVzdCBpbWFnZSBkYXRhIGZvciB0b29sIHJlc3VsdA==", // "test image data for tool result" in base64 + "mimeType": "image/jpeg", + }, + }, nil + }, + } + err = mock.server.RegisterTool(toolResultImageTool) + require.NoError(t, err) + + toolResultAudioTool := Tool{ + Name: "toolresult.audio.tool", + Description: "A tool that returns a ToolResult with audio content", + InputSchema: InputSchema{ + Type: "object", + Properties: map[string]any{}, + }, + Handler: func(params map[string]any) (any, error) { + // Test with image type + return ToolResult{ + Type: "audio", + Content: map[string]any{ + "data": "dGVzdCBpbWFnZSBkYXRhIGZvciB0b29sIHJlc3VsdA==", // "test image data for tool result" in base64 + "mimeType": "audio", + }, + }, nil + }, + } + err = mock.server.RegisterTool(toolResultAudioTool) + require.NoError(t, err) + + toolResultIntType := Tool{ + Name: "toolresult.int.tool", + Description: "A tool that returns a ToolResult with int content", + InputSchema: InputSchema{ + Type: "object", + Properties: map[string]any{}, + }, + Handler: func(params map[string]any) (any, error) { + return 2, nil + }, + } + err = mock.server.RegisterTool(toolResultIntType) + require.NoError(t, err) + + toolResultBadType := Tool{ + Name: "toolresult.bad.tool", + Description: "A tool that returns a ToolResult with bad content", + InputSchema: InputSchema{ + Type: "object", + Properties: map[string]any{}, + }, + Handler: func(params map[string]any) (any, error) { + return map[string]any{ + "type": "custom", + "data": make(chan int), + }, nil + }, + } + err = mock.server.RegisterTool(toolResultBadType) + require.NoError(t, err) + + // Test text ToolResult + t.Run("textToolResult", func(t *testing.T) { + // Create a client + client := addTestClient(mock.server, "test-client-text", true) + + // Create a tool call request + params := struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + }{ + Name: "toolresult.tool", + Arguments: map[string]any{}, + } + + paramBytes, _ := json.Marshal(params) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 1, + Method: methodToolsCall, + Params: paramBytes, + } + + // Process the tool call + mock.server.processToolCall(client, req) + + // Get the response from the client's channel + select { + case response := <-client.channel: + // Parse the response + jsonStart := strings.Index(response, "{") + jsonEnd := strings.LastIndex(response, "}") + require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") + + jsonStr := response[jsonStart : jsonEnd+1] + var parsed map[string]any + err := json.Unmarshal([]byte(jsonStr), &parsed) + require.NoError(t, err, "Response should be valid JSON") + + // Get the result + result, ok := parsed["result"].(map[string]any) + require.True(t, ok, "Response should have a result object") + + // Get the content array + content, ok := result["content"].([]any) + require.True(t, ok, "Result should have a content array") + require.NotEmpty(t, content, "Content should not be empty") + + // The first content item should be converted from ToolResult to TextContent + firstItem, ok := content[0].(map[string]any) + require.True(t, ok, "First content item should be an object") + + // Verify it's a text content + contentType, ok := firstItem["type"].(string) + require.True(t, ok, "Content should have a type") + assert.Equal(t, "text", contentType, "Content type should be text") + + // Check text content + text, ok := firstItem["text"].(string) + require.True(t, ok, "Content should have text") + assert.Equal(t, "This is a ToolResult with text content type", text, "Text content should match") + + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for tool call response") + } + }) + + // Test image ToolResult + t.Run("imageToolResult", func(t *testing.T) { + // Create a client + client := addTestClient(mock.server, "test-client-image", true) + + // Create a tool call request + params := struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + }{ + Name: "toolresult.image.tool", + Arguments: map[string]any{}, + } + + paramBytes, _ := json.Marshal(params) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 2, + Method: methodToolsCall, + Params: paramBytes, + } + + // Process the tool call + mock.server.processToolCall(client, req) + + // Get the response from the client's channel + select { + case response := <-client.channel: + // Parse the response + jsonStart := strings.Index(response, "{") + jsonEnd := strings.LastIndex(response, "}") + require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") + + jsonStr := response[jsonStart : jsonEnd+1] + var parsed map[string]any + err := json.Unmarshal([]byte(jsonStr), &parsed) + require.NoError(t, err, "Response should be valid JSON") + + // Get the result + result, ok := parsed["result"].(map[string]any) + require.True(t, ok, "Response should have a result object") + + // Get the content array + content, ok := result["content"].([]any) + require.True(t, ok, "Result should have a content array") + require.NotEmpty(t, content, "Content should not be empty") + + // The first content item should be converted from ToolResult to ImageContent + firstItem, ok := content[0].(map[string]any) + require.True(t, ok, "First content item should be an object") + + // Verify it's an image content + contentType, ok := firstItem["type"].(string) + require.True(t, ok, "Content should have a type") + assert.Equal(t, "image", contentType, "Content type should be image") + + // Check image data and mime type + data, ok := firstItem["data"].(string) + require.True(t, ok, "Content should have data") + assert.Equal(t, "dGVzdCBpbWFnZSBkYXRhIGZvciB0b29sIHJlc3VsdA==", data, "Image data should match") + + mimeType, ok := firstItem["mimeType"].(string) + require.True(t, ok, "Content should have mimeType") + assert.Equal(t, "image/jpeg", mimeType, "MimeType should match") + + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for tool call response") + } + }) + + // Test image ToolResult + t.Run("audioToolResult", func(t *testing.T) { + // Create a client + client := addTestClient(mock.server, "test-client-image", true) + + // Create a tool call request + params := struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + }{ + Name: "toolresult.audio.tool", + Arguments: map[string]any{}, + } + + paramBytes, _ := json.Marshal(params) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 2, + Method: methodToolsCall, + Params: paramBytes, + } + + // Process the tool call + mock.server.processToolCall(client, req) + + // Get the response from the client's channel + select { + case response := <-client.channel: + // Parse the response + jsonStart := strings.Index(response, "{") + jsonEnd := strings.LastIndex(response, "}") + require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") + + jsonStr := response[jsonStart : jsonEnd+1] + var parsed map[string]any + err := json.Unmarshal([]byte(jsonStr), &parsed) + require.NoError(t, err, "Response should be valid JSON") + + // Get the result + result, ok := parsed["result"].(map[string]any) + require.True(t, ok, "Response should have a result object") + + // Get the content array + content, ok := result["content"].([]any) + require.True(t, ok, "Result should have a content array") + require.NotEmpty(t, content, "Content should not be empty") + + // The first content item should be converted from ToolResult to ImageContent + firstItem, ok := content[0].(map[string]any) + require.True(t, ok, "First content item should be an object") + + // Verify it's an image content + contentType, ok := firstItem["type"].(string) + require.True(t, ok, "Content should have a type") + assert.Equal(t, "text", contentType, "Content type should be image") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for tool call response") + } + }) + + // Test text ToolResult + t.Run("ToolResult with int type", func(t *testing.T) { + // Create a client + client := addTestClient(mock.server, "test-client-text", true) + + // Create a tool call request + params := struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + }{ + Name: "toolresult.int.tool", + Arguments: map[string]any{}, + } + + paramBytes, _ := json.Marshal(params) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 1, + Method: methodToolsCall, + Params: paramBytes, + } + + // Process the tool call + mock.server.processToolCall(client, req) + + // Get the response from the client's channel + select { + case response := <-client.channel: + // Parse the response + jsonStart := strings.Index(response, "{") + jsonEnd := strings.LastIndex(response, "}") + require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") + + jsonStr := response[jsonStart : jsonEnd+1] + var parsed map[string]any + err := json.Unmarshal([]byte(jsonStr), &parsed) + require.NoError(t, err, "Response should be valid JSON") + + // Get the result + result, ok := parsed["result"].(map[string]any) + require.True(t, ok, "Response should have a result object") + + // Get the content array + content, ok := result["content"].([]any) + require.True(t, ok, "Result should have a content array") + require.NotEmpty(t, content, "Content should not be empty") + + // The first content item should be converted from ToolResult to ImageContent + firstItem, ok := content[0].(map[string]any) + require.True(t, ok, "First content item should be an object") + + // Verify it's an image content + contentType, ok := firstItem["type"].(string) + require.True(t, ok, "Content should have a type") + assert.Equal(t, "text", contentType, "Content type should be image") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for tool call response") + } + }) + + // Test text ToolResult + t.Run("ToolResult with bad type", func(t *testing.T) { + // Create a client + client := addTestClient(mock.server, "test-client-text", true) + + // Create a tool call request + params := struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + }{ + Name: "toolresult.bad.tool", + Arguments: map[string]any{}, + } + + paramBytes, _ := json.Marshal(params) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 1, + Method: methodToolsCall, + Params: paramBytes, + } + + // Process the tool call + mock.server.processToolCall(client, req) + + // Get the response from the client's channel + select { + case response := <-client.channel: + assert.Contains(t, response, "json: unsupported type") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for tool call response") + } + }) +} + +// TestToolCallError tests that tool errors are properly handled +func TestToolCallError(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Register a tool that returns an error + err := mock.server.RegisterTool(Tool{ + Name: "error.tool", + Description: "A tool that returns an error", + InputSchema: InputSchema{ + Type: "object", + Properties: map[string]any{}, + }, + Handler: func(params map[string]any) (any, error) { + return nil, fmt.Errorf("tool execution failed") + }, + }) + require.NoError(t, err) + + // Simulate a client + client := addTestClient(mock.server, "test-client", true) + + // Create a tool call request + params := struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + }{ + Name: "error.tool", + Arguments: map[string]any{}, + } + + paramBytes, _ := json.Marshal(params) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 1, + Method: methodToolsCall, + Params: paramBytes, + } + + // Process the tool call + mock.server.processToolCall(client, req) + + // Check the response + select { + case response := <-client.channel: + assert.Contains(t, response, "event: message", "Response should have message event") + assert.Contains(t, response, "Error:", "Response should contain the error message") + assert.Contains(t, response, "isError", "Response should indicate it's an error") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for tool call response") + } +} + +// TestToolCallTimeout tests that tool timeouts are properly handled +func TestToolCallTimeout(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Set a very short timeout for testing + mock.server.conf.Mcp.ToolTimeout = 10 * time.Millisecond + + // Register a tool that times out + err := mock.server.RegisterTool(Tool{ + Name: "timeout.tool", + Description: "A tool that times out", + InputSchema: InputSchema{ + Type: "object", + Properties: map[string]any{}, + }, + Handler: func(params map[string]any) (any, error) { + time.Sleep(50 * time.Millisecond) // Sleep longer than timeout + return "this should never be returned", nil + }, + }) + require.NoError(t, err) + + // Simulate a client + client := addTestClient(mock.server, "test-client", true) + + // Create a tool call request + params := struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + }{ + Name: "timeout.tool", + Arguments: map[string]any{}, + } + + paramBytes, _ := json.Marshal(params) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 1, + Method: methodToolsCall, + Params: paramBytes, + } + + // Process the tool call + mock.server.processToolCall(client, req) + + // Check the response + select { + case response := <-client.channel: + assert.Contains(t, response, "event: message", "Response should have message event") + assert.Contains(t, response, `-32001`, "Response should contain a timeout error code") + case <-time.After(150 * time.Millisecond): + t.Fatal("Timed out waiting for tool call response") + } +} + +// TestInitializeAndNotifications tests the client initialization flow +func TestInitializeAndNotifications(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Create a test client + client := addTestClient(mock.server, "test-client", false) + + // Test initialize request + initReq := Request{ + JsonRpc: jsonRpcVersion, + ID: 1, + Method: "initialize", + Params: json.RawMessage(`{}`), + } + + mock.server.processInitialize(client, initReq) + + // Check that client is initialized after initialize request + assert.True(t, client.initialized, "Client should be marked as initialized after initialize request") + + // Check the response format + select { + case response := <-client.channel: + // Check response contains required initialization fields + assert.Contains(t, response, "protocolVersion", "Response should include protocol version") + assert.Contains(t, response, "capabilities", "Response should include capabilities") + assert.Contains(t, response, "serverInfo", "Response should include server info") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for initialize response") + } + + // Test notification initialized + mock.server.processNotificationInitialized(client) + assert.True(t, client.initialized, "Client should remain initialized after notification") +} + +// TestRequestHandlingWithoutInitialization tests that requests are properly rejected when client is not initialized +func TestRequestHandlingWithoutInitialization(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + mock.registerExampleTool() + + // Create an uninitialized test client + client := addTestClient(mock.server, "test-client", false) + + // Attempt a tool call before initialization + params := struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + }{ + Name: "test.tool", + Arguments: map[string]any{"input": "foo"}, + } + + paramBytes, _ := json.Marshal(params) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 1, + Method: methodToolsCall, + Params: paramBytes, + } + jsonBody, _ := json.Marshal(req) + + // Handle the request - this should fail because client is not initialized + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client", bytes.NewReader(jsonBody)) + mock.server.handleRequest(httptest.NewRecorder(), r) + + // Check error response + select { + case response := <-client.channel: + assert.Contains(t, strings.ToLower(response), "error", "Response should contain an error") + assert.Contains(t, strings.ToLower(response), "not fully initialized", + "Response should mention client not being initialized") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for error response") + } +} + +// TestPing tests the ping request handling +func TestPing(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Create a test client + client := addTestClient(mock.server, "test-client", true) + + // Create a ping request + pingReq := Request{ + JsonRpc: jsonRpcVersion, + ID: 1, + Method: "ping", + Params: json.RawMessage(`{}`), + } + + jsonBody, _ := json.Marshal(pingReq) + + // Handle the request - this should fail because client is not initialized + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client", bytes.NewReader(jsonBody)) + mock.server.handleRequest(httptest.NewRecorder(), r) + + // Check response + select { + case response := <-client.channel: + assert.Contains(t, response, `"result":`, "Response should contain a result field") + assert.Contains(t, response, `"id":1`, "Response should have the same ID as the request") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for ping response") + } +} + +// TestNotificationCancelled tests the notification cancelled handling +func TestNotificationCancelled(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Create a test client + client := addTestClient(mock.server, "test-client", true) + + // Create a cancellation request + paramBytes, _ := json.Marshal(map[string]any{ + "requestId": 123, + "reason": "user_cancelled", + }) + + cancelReq := Request{ + JsonRpc: jsonRpcVersion, + Method: "notifications/cancelled", + Params: paramBytes, + } + + jsonBody, _ := json.Marshal(cancelReq) + + // Handle the request - this should fail because client is not initialized + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client", bytes.NewReader(jsonBody)) + mock.server.handleRequest(httptest.NewRecorder(), r) + + // No response expected for notifications + select { + case <-client.channel: + t.Fatal("No response expected for notifications") + case <-time.After(50 * time.Millisecond): + // This is the expected outcome - no response + } +} + +func TestNotificationCancelled_badParams(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + client := addTestClient(mock.server, "test-client", true) + + cancelReq := Request{ + JsonRpc: jsonRpcVersion, + Method: "notifications/cancelled", + Params: []byte(`invalid json`), + } + + buf := logtest.NewCollector(t) + mock.server.processNotificationCancelled(client, cancelReq) + + select { + case <-client.channel: + t.Fatal("No response expected for notifications") + case <-time.After(50 * time.Millisecond): + assert.Contains(t, buf.String(), "Failed to parse cancellation params") + } +} + +func TestUnknownRequest(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Create a test client + client := addTestClient(mock.server, "test-client", true) + req := Request{ + JsonRpc: jsonRpcVersion, + Method: "unknown", + } + + jsonBody, _ := json.Marshal(req) + + // Handle the request - this should fail because client is not initialized + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client", bytes.NewReader(jsonBody)) + mock.server.handleRequest(httptest.NewRecorder(), r) + + // No response expected for notifications + select { + case message := <-client.channel: + evt, err := parseEvent(message) + require.NoError(t, err, "Should parse event without error") + errCode := evt.Data["error"].(map[string]any)["code"] + // because error code will be converted into float64 + assert.Equal(t, float64(errCodeMethodNotFound), errCode) + case <-time.After(50 * time.Millisecond): + // This is the expected outcome - no response + } +} + +func TestResponseWriter_notFlusher(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Handle the request - this should fail because client is not initialized + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client", http.NoBody) + var w notFlusherResponseWriter + mock.server.handleSSE(&w, r) + assert.Equal(t, http.StatusInternalServerError, w.code) +} + +func TestResponseWriter_cantWrite(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Handle the request - this should fail because client is not initialized + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client", http.NoBody) + var w cantWriteResponseWriter + mock.server.handleSSE(&w, r) + assert.Equal(t, 0, w.code) +} + +func TestHandleSSE_channelClosed(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Handle the request - this should fail because client is not initialized + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client", http.NoBody) + w := httptest.NewRecorder() + var wg sync.WaitGroup + wg.Add(1) + go func() { + mock.server.handleSSE(w, r) + wg.Done() + }() + + buf := logtest.NewCollector(t) + for { + time.Sleep(time.Millisecond) + mock.server.clientsLock.Lock() + if len(mock.server.clients) > 0 { + for _, client := range mock.server.clients { + close(client.channel) + delete(mock.server.clients, client.id) + } + mock.server.clientsLock.Unlock() + break + } + mock.server.clientsLock.Unlock() + } + wg.Wait() + assert.Contains(t, "channel was closed", buf.Content(), "Should log channel closed error") +} + +func TestHandleSSE_badResponseWriter(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Handle the request - this should fail because client is not initialized + r := httptest.NewRequest(http.MethodPost, "/", http.NoBody) + var wg sync.WaitGroup + wg.Add(1) + go func() { + var w writeOnceResponseWriter + mock.server.handleSSE(&w, r) + wg.Done() + }() + + var session string + for { + time.Sleep(time.Millisecond) + mock.server.clientsLock.Lock() + if len(mock.server.clients) > 0 { + for _, client := range mock.server.clients { + session = client.id + } + mock.server.clientsLock.Unlock() + break + } + mock.server.clientsLock.Unlock() + } + + time.Sleep(100 * time.Millisecond) + // Create a ping request + pingReq := Request{ + JsonRpc: jsonRpcVersion, + ID: 1, + Method: "ping", + Params: json.RawMessage(`{}`), + } + + jsonBody, _ := json.Marshal(pingReq) + buf := logtest.NewCollector(t) + + // Handle the request - this should fail because client is not initialized + r = httptest.NewRequest(http.MethodPost, fmt.Sprintf("/?session_id=%s", session), + bytes.NewReader(jsonBody)) + mock.server.handleRequest(httptest.NewRecorder(), r) + + wg.Wait() + assert.Contains(t, "Failed to write", buf.Content()) +} + +// TestGetPrompt tests the prompts/get endpoint +func TestGetPrompt(t *testing.T) { + t.Run("test prompt", func(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Create a test client + client := addTestClient(mock.server, "test-client", true) + + // Register a test prompt + testPrompt := Prompt{ + Name: "test.prompt", + Description: "A test prompt", + } + mock.server.RegisterPrompt(testPrompt) + + // Create a get prompt request + paramBytes, _ := json.Marshal(map[string]any{ + "name": "test.prompt", + "arguments": map[string]string{ + "topic": "test topic", + }, + }) + + promptReq := Request{ + JsonRpc: jsonRpcVersion, + ID: 1, + Method: "prompts/get", + Params: paramBytes, + } + + // Process the request + mock.server.processGetPrompt(client, promptReq) + + // Check response + select { + case response := <-client.channel: + assert.Contains(t, response, "description", "Response should include prompt description") + assert.Contains(t, response, "prompts", "Response should include prompts array") + assert.Contains(t, response, "A test prompt", "Response should include the topic argument") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for prompt response") + } + }) + + t.Run("test prompt with invalid params", func(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Create a test client + client := addTestClient(mock.server, "test-client", true) + + paramBytes := []byte("invalid json") + promptReq := Request{ + JsonRpc: jsonRpcVersion, + ID: 1, + Method: "prompts/get", + Params: paramBytes, + } + + // Process the request + mock.server.processGetPrompt(client, promptReq) + + // Check response + select { + case response := <-client.channel: + evt, err := parseEvent(response) + assert.NoError(t, err, "Should be able to parse event") + errMsg, ok := evt.Data["error"].(map[string]any) + assert.True(t, ok, "Should have error in response") + assert.Equal(t, "Invalid parameters", errMsg["message"], "Error message should match") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for prompt response") + } + }) +} + +// TestBroadcast tests the broadcast functionality +func TestBroadcast(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Create two test clients + client1 := addTestClient(mock.server, "client-1", true) + client2 := addTestClient(mock.server, "client-2", true) + + // Broadcast a test message + testData := map[string]string{"key": "value"} + mock.server.broadcast("test_event", testData) + + // Check both clients received the broadcast + for i, client := range []*mcpClient{client1, client2} { + select { + case response := <-client.channel: + assert.Contains(t, response, `event: test_event`, "Response should have the correct event") + assert.Contains(t, response, `"key":"value"`, "Response should contain the broadcast data") + case <-time.After(100 * time.Millisecond): + t.Fatalf("Timed out waiting for broadcast on client %d", i+1) + } + } + + buf := logtest.NewCollector(t) + mock.server.broadcast("test_event", make(chan string)) + // Check that the broadcast was logged + content := buf.Content() + assert.Contains(t, content, "Failed", "Broadcast should be logged") + + for i := 0; i < eventChanSize; i++ { + mock.server.broadcast("test_event", "test") + } + + done := make(chan struct{}) + go func() { + mock.server.broadcast("test_event", "ignored") + close(done) + }() + + select { + case <-time.After(100 * time.Millisecond): + assert.Fail(t, "broadcast should not block") + case <-done: + } +} + +// TestHandleSSEPing tests the automatic ping functionality in the SSE handler +func TestHandleSSEPing(t *testing.T) { + originalInterval := pingInterval.Load() + pingInterval.Set(50 * time.Millisecond) + defer func() { + pingInterval.Set(originalInterval) + }() + + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Create a request context that can be cancelled + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Create a test ResponseRecorder and Request + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", mock.server.conf.Mcp.SseEndpoint, nil).WithContext(ctx) + + // Create a channel to coordinate the test + done := make(chan struct{}) + + // Set up a custom ResponseRecorder that captures writes and signals the test + customResponseWriter := &testResponseWriter{ + ResponseRecorder: w, + writes: make([]string, 0), + done: done, + pingDetected: false, + } + + // Start the SSE handler in a goroutine + go func() { + mock.server.handleSSE(customResponseWriter, r) + }() + + // Wait for ping or timeout + select { + case <-done: + // A ping was detected + assert.True(t, customResponseWriter.pingDetected, "Ping message should have been sent") + case <-time.After(pingInterval.Load() + 100*time.Millisecond): + t.Fatal("Timed out waiting for ping message") + } + + // Verify that the client was added and cleaned up + mock.server.clientsLock.Lock() + clientCount := len(mock.server.clients) + mock.server.clientsLock.Unlock() + + // Clean up by cancelling the context + cancel() + + // Wait for cleanup to complete + time.Sleep(50 * time.Millisecond) + + // Verify client was removed + mock.server.clientsLock.Lock() + finalClientCount := len(mock.server.clients) + mock.server.clientsLock.Unlock() + + assert.Equal(t, 1, clientCount, "One client should be added during the test") + assert.Equal(t, 0, finalClientCount, "Client should be cleaned up after context cancelation") +} + +// TestHandleSSEPing tests the automatic ping functionality in the SSE handler +func TestHandleSSEPing_writeOnce(t *testing.T) { + originalInterval := pingInterval.Load() + pingInterval.Set(50 * time.Millisecond) + defer func() { + pingInterval.Set(originalInterval) + }() + + buf := logtest.NewCollector(t) + var bufLock sync.Mutex + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Start the SSE handler in a goroutine + go func() { + var w writeOnceResponseWriter + r := httptest.NewRequest(http.MethodGet, mock.server.conf.Mcp.SseEndpoint, http.NoBody) + bufLock.Lock() + defer bufLock.Unlock() + mock.server.handleSSE(&w, r) + }() + + // Wait for ping or timeout + time.Sleep(100 * time.Millisecond) + bufLock.Lock() + assert.Contains(t, "Failed to send ping", buf.Content()) + bufLock.Unlock() +} + +func TestServerStartStop(t *testing.T) { + // Create a simple configuration for testing + const yamlConf = `name: test-server +host: localhost +port: 0 +timeout: 1000 +mcp: + name: mcp-test-server +` + var c McpConf + assert.NoError(t, conf.LoadFromYamlBytes([]byte(yamlConf), &c)) + + // Create the server + s := NewMcpServer(c) + + // Start and stop in goroutine to avoid blocking + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + go func() { + s.Start() + }() + + // Allow a brief moment for startup + time.Sleep(50 * time.Millisecond) + + // Stop the server + s.Stop() + + // Wait for context to ensure we properly stopped or timed out + <-ctx.Done() +} + +// TestNotificationInitialized tests the notifications/initialized handling in detail +func TestNotificationInitialized(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + t.Run("uninitializedClient", func(t *testing.T) { + // Create an uninitialized test client + client := addTestClient(mock.server, "test-client-uninitialized", false) + assert.False(t, client.initialized, "Client should start as uninitialized") + + // Create a notification request + req := Request{ + JsonRpc: jsonRpcVersion, + Method: methodNotificationsInitialized, + // No ID for notifications + Params: json.RawMessage(`{}`), // Empty params acceptable for this notification + } + + // Process through the request handler + jsonBody, _ := json.Marshal(req) + r := httptest.NewRequest(http.MethodPost, "/?session_id="+client.id, bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + mock.server.handleRequest(w, r) + + // Verify client is now initialized + assert.True(t, client.initialized, "Client should be marked as initialized after notifications/initialized") + + // Verify the response code is 202 Accepted + assert.Equal(t, http.StatusAccepted, w.Code, "Response status should be 202 Accepted") + + // No actual response body should be sent for notifications + select { + case <-client.channel: + t.Fatal("No response expected for notifications") + case <-time.After(50 * time.Millisecond): + // This is the expected outcome - no response + } + }) + + t.Run("initializedClient", func(t *testing.T) { + // Create an already initialized client + client := addTestClient(mock.server, "test-client-initialized", true) + assert.True(t, client.initialized, "Client should start as initialized") + + // Directly call processNotificationInitialized + mock.server.processNotificationInitialized(client) + + // Verify client remains initialized + assert.True(t, client.initialized, "Client should remain initialized after notifications/initialized") + + // No response expected + select { + case <-client.channel: + t.Fatal("No response expected for notifications") + case <-time.After(50 * time.Millisecond): + // This is the expected outcome - no response + } + }) + + t.Run("errorOnIncorrectUsage", func(t *testing.T) { + // Create a test client + client := addTestClient(mock.server, "test-client-error", false) + + // Create a request with ID (incorrect usage - should be a notification) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 123, // Adding ID makes this an incorrect usage - should be notification + Method: methodNotificationsInitialized, + Params: json.RawMessage(`{}`), + } + + // Process through the request handler + jsonBody, _ := json.Marshal(req) + r := httptest.NewRequest(http.MethodPost, "/?session_id="+client.id, bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + mock.server.handleRequest(w, r) + + // Should get an error response + select { + case response := <-client.channel: + assert.Contains(t, response, "error", "Response should contain an error") + assert.Contains(t, response, "Method should be used as a notification", "Response should explain notification usage") + assert.Contains(t, response, `"id":123`, "Response should include the original ID") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for error response") + } + + // Client should not be initialized due to error + assert.False(t, client.initialized, "Client should not be initialized after error") + }) +} + +func TestSendBadResponse(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Create a test client + client := addTestClient(mock.server, "test-client", true) + + // Create a response + response := Response{ + JsonRpc: jsonRpcVersion, + ID: 1, + Result: make(chan int), + } + + // Send the response + mock.server.sendResponse(client, 1, response) + + // Check the response in the client's channel + select { + case res := <-client.channel: + evt, err := parseEvent(res) + require.NoError(t, err, "Should parse event without error") + errMsg, ok := evt.Data["error"].(map[string]any) + require.True(t, ok, "Should have error in response") + assert.Equal(t, float64(errCodeInternalError), errMsg["code"], "Error code should match") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for response") + } +} + +// TestMethodToolsCall tests the handling of tools/call method through handleRequest +func TestMethodToolsCall(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + t.Run("validToolCall", func(t *testing.T) { + // Register a test tool + mock.registerExampleTool() + + // Create an initialized client + client := addTestClient(mock.server, "test-client-valid", true) + + // Create a tools call request with progress token metadata + params := struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + Meta struct { + ProgressToken string `json:"progressToken"` + } `json:"_meta"` + }{ + Name: "test.tool", + Arguments: map[string]any{ + "input": "test-input", + }, + Meta: struct { + ProgressToken string `json:"progressToken"` + }{ + ProgressToken: "token123", + }, + } + + paramBytes, err := json.Marshal(params) + require.NoError(t, err, "Failed to marshal tool call parameters") + + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 42, // Specific ID to verify in response + Method: methodToolsCall, + Params: paramBytes, + } + jsonBody, _ := json.Marshal(req) + + // Create HTTP request + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-valid", bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + + // Process through handleRequest (full path) + mock.server.handleRequest(w, r) + + // Verify the HTTP response + assert.Equal(t, http.StatusAccepted, w.Code, "HTTP status should be 202 Accepted") + + // Check the response in client's channel + select { + case response := <-client.channel: + // Verify it's a message event with the expected format + assert.Contains(t, response, "event: message", "Response should be a message event") + + // Parse the JSON part of the SSE message + jsonStart := strings.Index(response, "{") + jsonEnd := strings.LastIndex(response, "}") + require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") + jsonStr := response[jsonStart : jsonEnd+1] + + var parsed map[string]any + err := json.Unmarshal([]byte(jsonStr), &parsed) + require.NoError(t, err, "Should be able to parse response JSON") + + // Validate the structure + result, ok := parsed["result"].(map[string]any) + require.True(t, ok, "Response should have a result object") + + // Verify ID matches our request + id, ok := parsed["id"].(float64) + assert.True(t, ok, "Response should have an ID") + assert.Equal(t, float64(42), id, "Response ID should match request ID") + + // Verify content + content, ok := result["content"].([]any) + require.True(t, ok, "Result should have content array") + assert.NotEmpty(t, content, "Content should not be empty") + + // Check for progress token in metadata + meta, hasMeta := result["_meta"].(map[string]any) + assert.True(t, hasMeta, "Response should include _meta with progress token") + if hasMeta { + token, hasToken := meta["progressToken"].(string) + assert.True(t, hasToken, "Meta should include progress token") + assert.Equal(t, "token123", token, "Progress token should match") + } + + // Check actual result content + if len(content) > 0 { + firstItem, ok := content[0].(map[string]any) + if ok { + assert.Equal(t, "text", firstItem["type"], "Content type should be text") + assert.Contains(t, firstItem["text"], "Processed: test-input", "Content should include processed input") + } + } + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for tool call response") + } + }) + + t.Run("invalidToolName", func(t *testing.T) { + // Create an initialized client + client := addTestClient(mock.server, "test-client-invalid", true) + + // Create a tools call request with invalid tool name + params := struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + }{ + Name: "non-existent-tool", + Arguments: map[string]any{}, + } + + paramBytes, err := json.Marshal(params) + require.NoError(t, err, "Failed to marshal tool call parameters") + + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 43, + Method: methodToolsCall, + Params: paramBytes, + } + jsonBody, _ := json.Marshal(req) + + // Create HTTP request + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-invalid", bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + + // Process through handleRequest + mock.server.handleRequest(w, r) + + // Verify response contains error about non-existent tool + select { + case response := <-client.channel: + assert.Contains(t, response, "error", "Response should contain error") + assert.Contains(t, response, "not found", "Error should mention tool not found") + assert.Contains(t, response, "non-existent-tool", "Error should mention the invalid tool name") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for error response") + } + }) + + t.Run("clientNotInitialized", func(t *testing.T) { + // Register a tool + mock.registerExampleTool() + + // Create an uninitialized client + client := addTestClient(mock.server, "test-client-uninitialized", false) + + // Create a valid tools call request + params := struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + }{ + Name: "test.tool", + Arguments: map[string]any{ + "input": "test-input", + }, + } + + paramBytes, err := json.Marshal(params) + require.NoError(t, err, "Failed to marshal tool call parameters") + + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 44, + Method: methodToolsCall, + Params: paramBytes, + } + jsonBody, _ := json.Marshal(req) + + // Create HTTP request + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-uninitialized", bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + + // Process through handleRequest + mock.server.handleRequest(w, r) + + // Verify response contains error about client not being initialized + select { + case response := <-client.channel: + assert.Contains(t, response, "error", "Response should contain error") + assert.Contains(t, response, "not fully initialized", "Error should mention client not initialized") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for error response") + } + }) +} + +// TestMethodPromptsGet tests the handling of prompts/get method through handleRequest +func TestMethodPromptsGet(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + t.Run("staticPrompt", func(t *testing.T) { + // Register a test prompt with static content + testPrompt := Prompt{ + Name: "static-prompt", + Description: "A static test prompt with placeholders", + Arguments: []PromptArgument{ + { + Name: "name", + Description: "Name to use in greeting", + Required: true, + }, + { + Name: "topic", + Description: "Topic to discuss", + Default: "artificial intelligence", + }, + }, + Content: "Hello {{name}}! Let's talk about {{topic}}.", + } + mock.server.RegisterPrompt(testPrompt) + + // Create an initialized client + client := addTestClient(mock.server, "test-client-static", true) + + // Create a prompts/get request + params := struct { + Name string `json:"name"` + Arguments map[string]string `json:"arguments"` + }{ + Name: "static-prompt", + Arguments: map[string]string{ + "name": "Test User", + // Intentionally not providing "topic" to test default values + }, + } + + paramBytes, err := json.Marshal(params) + require.NoError(t, err, "Failed to marshal prompt get parameters") + + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 70, + Method: methodPromptsGet, + Params: paramBytes, + } + jsonBody, _ := json.Marshal(req) + + // Create HTTP request + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-static", bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + + // Process through handleRequest (full path) + mock.server.handleRequest(w, r) + + // Verify the HTTP response + assert.Equal(t, http.StatusAccepted, w.Code, "HTTP status should be 202 Accepted") + + // Check the response in client's channel + select { + case response := <-client.channel: + // Verify it's a message event with the expected format + assert.Contains(t, response, "event: message", "Response should be a message event") + + // Parse the JSON part of the SSE message + jsonStart := strings.Index(response, "{") + jsonEnd := strings.LastIndex(response, "}") + require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") + jsonStr := response[jsonStart : jsonEnd+1] + + var parsed map[string]any + err := json.Unmarshal([]byte(jsonStr), &parsed) + require.NoError(t, err, "Should be able to parse response JSON") + + // Validate the structure + result, ok := parsed["result"].(map[string]any) + require.True(t, ok, "Response should have a result object") + + // Verify ID matches our request + id, ok := parsed["id"].(float64) + assert.True(t, ok, "Response should have an ID") + assert.Equal(t, float64(70), id, "Response ID should match request ID") + + // Verify description + description, ok := result["description"].(string) + assert.True(t, ok, "Response should include prompt description") + assert.Equal(t, "A static test prompt with placeholders", description, "Description should match") + + // Verify messages + messages, ok := result["messages"].([]any) + require.True(t, ok, "Result should have messages array") + assert.Len(t, messages, 1, "Should have 1 message") + + // Check message content - should have placeholder substitutions + if len(messages) > 0 { + message, ok := messages[0].(map[string]any) + require.True(t, ok, "Message should be an object") + assert.Equal(t, "user", message["role"], "Role should be 'user'") + + content, ok := message["content"].(map[string]any) + require.True(t, ok, "Should have content object") + assert.Equal(t, "text", content["type"], "Content type should be text") + assert.Contains(t, content["text"], "Hello Test User", "Content should include the name argument") + assert.Contains(t, content["text"], "about artificial intelligence", + "Content should include the default topic argument") + } + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for prompt get response") + } + }) + + t.Run("dynamicPrompt", func(t *testing.T) { + // Register a test prompt with a handler function + testPrompt := Prompt{ + Name: "dynamic-prompt", + Description: "A dynamic test prompt with a handler", + Arguments: []PromptArgument{ + { + Name: "username", + Description: "User's name", + Required: true, + }, + { + Name: "question", + Description: "User's question", + Default: "How does this work?", + }, + }, + Handler: func(args map[string]string) ([]PromptMessage, error) { + username := args["username"] + question := args["question"] + + // Create a system message + systemMessage := PromptMessage{ + Role: "system", + Content: TextContent{ + Type: "text", + Text: "You are a helpful assistant.", + }, + } + + // Create a user message + userMessage := PromptMessage{ + Role: "user", + Content: TextContent{ + Type: "text", + Text: fmt.Sprintf("Hi, I'm %s and I'm wondering: %s", username, question), + }, + } + + return []PromptMessage{systemMessage, userMessage}, nil + }, + } + mock.server.RegisterPrompt(testPrompt) + + // Create an initialized client + client := addTestClient(mock.server, "test-client-dynamic", true) + + // Create a prompts/get request + params := struct { + Name string `json:"name"` + Arguments map[string]string `json:"arguments"` + }{ + Name: "dynamic-prompt", + Arguments: map[string]string{ + "username": "Dynamic User", + "question": "How to test this?", + }, + } + + paramBytes, err := json.Marshal(params) + require.NoError(t, err, "Failed to marshal prompt get parameters") + + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 71, + Method: methodPromptsGet, + Params: paramBytes, + } + jsonBody, _ := json.Marshal(req) + + // Create HTTP request + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-dynamic", bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + + // Process through handleRequest + mock.server.handleRequest(w, r) + + // Check the response + select { + case response := <-client.channel: + // Extract and parse JSON + jsonStart := strings.Index(response, "{") + jsonEnd := strings.LastIndex(response, "}") + jsonStr := response[jsonStart : jsonEnd+1] + + var parsed map[string]any + err := json.Unmarshal([]byte(jsonStr), &parsed) + require.NoError(t, err, "Should be able to parse response JSON") + + result, ok := parsed["result"].(map[string]any) + require.True(t, ok, "Response should have a result object") + + // Verify messages - should have 2 messages from handler + messages, ok := result["messages"].([]any) + require.True(t, ok, "Result should have messages array") + assert.Len(t, messages, 2, "Should have 2 messages") + + // Check message content + if len(messages) >= 2 { + // First message should be system + message1, _ := messages[0].(map[string]any) + assert.Equal(t, "system", message1["role"], "First role should be 'system'") + + content1, _ := message1["content"].(map[string]any) + assert.Contains(t, content1["text"], "helpful assistant", "System message should be correct") + + // Second message should be user + message2, _ := messages[1].(map[string]any) + assert.Equal(t, "user", message2["role"], "Second role should be 'user'") + + content2, _ := message2["content"].(map[string]any) + assert.Contains(t, content2["text"], "Dynamic User", "User message should contain username") + assert.Contains(t, content2["text"], "How to test this?", "User message should contain question") + } + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for prompt get response") + } + }) + + t.Run("missingRequiredArgument", func(t *testing.T) { + // Register a test prompt with a required argument + testPrompt := Prompt{ + Name: "required-arg-prompt", + Description: "A prompt with required arguments", + Arguments: []PromptArgument{ + { + Name: "required_arg", + Description: "This argument is required", + Required: true, + }, + }, + } + mock.server.RegisterPrompt(testPrompt) + + // Create an initialized client + client := addTestClient(mock.server, "test-client-missing-arg", true) + + // Create a prompts/get request with missing required argument + params := struct { + Name string `json:"name"` + Arguments map[string]string `json:"arguments"` + }{ + Name: "required-arg-prompt", + Arguments: map[string]string{}, // Empty arguments + } + + paramBytes, _ := json.Marshal(params) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 72, + Method: methodPromptsGet, + Params: paramBytes, + } + jsonBody, _ := json.Marshal(req) + + // Create HTTP request + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-missing-arg", bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + + // Process through handleRequest + mock.server.handleRequest(w, r) + + // Check for error response about missing required argument + select { + case response := <-client.channel: + assert.Contains(t, response, "error", "Response should contain error") + assert.Contains(t, response, "Missing required arguments", "Error should mention missing arguments") + assert.Contains(t, response, "required_arg", "Error should name the missing argument") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for error response") + } + }) + + t.Run("promptNotFound", func(t *testing.T) { + // Create an initialized client + client := addTestClient(mock.server, "test-client-prompt-not-found", true) + + // Create a prompts/get request with non-existent prompt + params := struct { + Name string `json:"name"` + Arguments map[string]string `json:"arguments"` + }{ + Name: "non-existent-prompt", + Arguments: map[string]string{}, + } + + paramBytes, _ := json.Marshal(params) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 73, + Method: methodPromptsGet, + Params: paramBytes, + } + jsonBody, _ := json.Marshal(req) + + // Create HTTP request + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-prompt-not-found", bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + + // Process through handleRequest + mock.server.handleRequest(w, r) + + // Check for error response about non-existent prompt + select { + case response := <-client.channel: + assert.Contains(t, response, "error", "Response should contain error") + assert.Contains(t, response, "Prompt 'non-existent-prompt' not found", "Error should mention prompt not found") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for error response") + } + }) + + t.Run("handlerError", func(t *testing.T) { + // Register a test prompt with a handler that returns an error + testPrompt := Prompt{ + Name: "error-handler-prompt", + Description: "A prompt with a handler that returns an error", + Arguments: []PromptArgument{}, + Handler: func(args map[string]string) ([]PromptMessage, error) { + return nil, fmt.Errorf("test handler error") + }, + } + mock.server.RegisterPrompt(testPrompt) + + // Create an initialized client + client := addTestClient(mock.server, "test-client-handler-error", true) + + // Create a prompts/get request + params := struct { + Name string `json:"name"` + Arguments map[string]string `json:"arguments"` + }{ + Name: "error-handler-prompt", + Arguments: map[string]string{}, + } + + paramBytes, _ := json.Marshal(params) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 74, + Method: methodPromptsGet, + Params: paramBytes, + } + jsonBody, _ := json.Marshal(req) + + // Create HTTP request + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-handler-error", bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + + // Process through handleRequest + mock.server.handleRequest(w, r) + + // Check for error response about handler error + select { + case response := <-client.channel: + assert.Contains(t, response, "error", "Response should contain error") + assert.Contains(t, response, "Error generating prompt content", "Error should mention generating content") + assert.Contains(t, response, "test handler error", "Error should include the handler error message") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for error response") + } + }) + + t.Run("invalidParameters", func(t *testing.T) { + // Create an invalid JSON request + invalidJson := []byte(`{"not valid json`) + + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 75, + Method: methodPromptsGet, + Params: invalidJson, + } + jsonBody, _ := json.Marshal(req) + + // Create HTTP request + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-invalid-params", bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + + // Process through handleRequest + mock.server.handleRequest(w, r) + assert.Equal(t, http.StatusBadRequest, w.Code) + }) + + t.Run("clientNotInitialized", func(t *testing.T) { + // Register a basic prompt + testPrompt := Prompt{ + Name: "basic-prompt", + Description: "A basic test prompt", + } + mock.server.RegisterPrompt(testPrompt) + + // Create an uninitialized client + client := addTestClient(mock.server, "test-client-uninit", false) + + // Create a valid prompts/get request + params := struct { + Name string `json:"name"` + Arguments map[string]string `json:"arguments"` + }{ + Name: "basic-prompt", + Arguments: map[string]string{}, + } + + paramBytes, _ := json.Marshal(params) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 76, + Method: methodPromptsGet, + Params: paramBytes, + } + jsonBody, _ := json.Marshal(req) + + // Create HTTP request + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-uninit", bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + + // Process through handleRequest + mock.server.handleRequest(w, r) + + // Verify response contains error about client not being initialized + select { + case response := <-client.channel: + assert.Contains(t, response, "error", "Response should contain error") + assert.Contains(t, response, "not fully initialized", "Error should mention client not initialized") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for error response") + } + }) +} + +func TestMethodResourcesList(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + t.Run("validResourceWithHandler", func(t *testing.T) { + // Register a test resource with handler + testResource := Resource{ + Name: "test-resource", + URI: "file:///test/resource.txt", + Description: "A test resource with handler", + MimeType: "text/plain", + Handler: func() (ResourceContent, error) { + return ResourceContent{ + URI: "file:///test/resource.txt", + MimeType: "text/plain", + Text: "This is test resource content", + }, nil + }, + } + mock.server.RegisterResource(testResource) + + // Create an initialized client + client := addTestClient(mock.server, "test-client-resources", true) + + // Create a resources/read request + params := PaginatedParams{ + Cursor: "next-cursor", + Meta: struct { + ProgressToken any `json:"progressToken"` + }{ + ProgressToken: "token", + }, + } + + paramBytes, err := json.Marshal(params) + require.NoError(t, err, "Failed to marshal resource read parameters") + + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 50, + Method: methodResourcesList, + Params: paramBytes, + } + jsonBody, _ := json.Marshal(req) + + // Create HTTP request + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-resources", bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + + // Process through handleRequest (full path) + mock.server.handleRequest(w, r) + + // Verify the HTTP response + assert.Equal(t, http.StatusAccepted, w.Code, "HTTP status should be 202 Accepted") + + // Check the response in client's channel + select { + case response := <-client.channel: + evt, err := parseEvent(response) + assert.NoError(t, err) + result, ok := evt.Data["result"].(map[string]any) + assert.True(t, ok) + assert.Equal(t, "next-cursor", result["nextCursor"]) + assert.Equal(t, "token", result["_meta"].(map[string]any)["progressToken"]) + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for resource read response") + } + }) +} + +// TestMethodResourcesRead tests the handling of resources/read method +func TestMethodResourcesRead(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + t.Run("validResourceWithHandler", func(t *testing.T) { + // Register a test resource with handler + testResource := Resource{ + Name: "test-resource", + URI: "file:///test/resource.txt", + Description: "A test resource with handler", + MimeType: "text/plain", + Handler: func() (ResourceContent, error) { + return ResourceContent{ + URI: "file:///test/resource.txt", + MimeType: "text/plain", + Text: "This is test resource content", + }, nil + }, + } + mock.server.RegisterResource(testResource) + + // Create an initialized client + client := addTestClient(mock.server, "test-client-resources", true) + + // Create a resources/read request + params := ResourceReadParams{ + URI: "file:///test/resource.txt", + } + + paramBytes, err := json.Marshal(params) + require.NoError(t, err, "Failed to marshal resource read parameters") + + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 50, + Method: methodResourcesRead, + Params: paramBytes, + } + jsonBody, _ := json.Marshal(req) + + // Create HTTP request + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-resources", bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + + // Process through handleRequest (full path) + mock.server.handleRequest(w, r) + + // Verify the HTTP response + assert.Equal(t, http.StatusAccepted, w.Code, "HTTP status should be 202 Accepted") + + // Check the response in client's channel + select { + case response := <-client.channel: + // Verify it's a message event with the expected format + assert.Contains(t, response, "event: message", "Response should be a message event") + + // Parse the JSON part of the SSE message + jsonStart := strings.Index(response, "{") + jsonEnd := strings.LastIndex(response, "}") + require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") + jsonStr := response[jsonStart : jsonEnd+1] + + var parsed map[string]any + err := json.Unmarshal([]byte(jsonStr), &parsed) + require.NoError(t, err, "Should be able to parse response JSON") + + // Validate the structure + result, ok := parsed["result"].(map[string]any) + require.True(t, ok, "Response should have a result object") + + // Verify ID matches our request + id, ok := parsed["id"].(float64) + assert.True(t, ok, "Response should have an ID") + assert.Equal(t, float64(50), id, "Response ID should match request ID") + + // Verify contents + contents, ok := result["contents"].([]any) + require.True(t, ok, "Result should have contents array") + assert.Len(t, contents, 1, "Contents array should have 1 item") + + // Check content details + if len(contents) > 0 { + content, ok := contents[0].(map[string]any) + require.True(t, ok, "Content should be an object") + assert.Equal(t, "file:///test/resource.txt", content["uri"], "URI should match") + assert.Equal(t, "text/plain", content["mimeType"], "MimeType should match") + assert.Equal(t, "This is test resource content", content["text"], "Text content should match") + } + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for resource read response") + } + }) + + t.Run("resourceWithoutHandler", func(t *testing.T) { + // Register a test resource without handler + testResource := Resource{ + Name: "no-handler-resource", + URI: "file:///test/no-handler.txt", + Description: "A test resource without handler", + MimeType: "text/plain", + // No handler provided + } + mock.server.RegisterResource(testResource) + + // Create an initialized client + client := addTestClient(mock.server, "test-client-no-handler", true) + + // Create a resources/read request + params := ResourceReadParams{ + URI: "file:///test/no-handler.txt", + } + + paramBytes, _ := json.Marshal(params) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 51, + Method: methodResourcesRead, + Params: paramBytes, + } + jsonBody, _ := json.Marshal(req) + + // Create HTTP request + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-no-handler", bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + + // Process through handleRequest + mock.server.handleRequest(w, r) + + // Check for response with empty content + select { + case response := <-client.channel: + assert.Contains(t, response, "event: message", "Response should be a message event") + + // Extract and parse JSON + jsonStart := strings.Index(response, "{") + jsonEnd := strings.LastIndex(response, "}") + jsonStr := response[jsonStart : jsonEnd+1] + + var parsed map[string]any + err := json.Unmarshal([]byte(jsonStr), &parsed) + require.NoError(t, err, "Should be able to parse response JSON") + + // Check contents exists but has empty text + result, ok := parsed["result"].(map[string]any) + require.True(t, ok, "Response should have a result object") + + contents, ok := result["contents"].([]any) + require.True(t, ok, "Result should have contents array") + assert.Len(t, contents, 1, "Contents array should have 1 item") + + // Check content details - should have URI and MimeType but empty text + if len(contents) > 0 { + content, ok := contents[0].(map[string]any) + require.True(t, ok, "Content should be an object") + assert.Equal(t, "file:///test/no-handler.txt", content["uri"], "URI should match") + assert.Equal(t, "text/plain", content["mimeType"], "MimeType should match") + _, ok = content["text"] + assert.False(t, ok, "Text content should be empty string") + } + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for resource read response") + } + }) + + t.Run("resourceNotFound", func(t *testing.T) { + // Create an initialized client + client := addTestClient(mock.server, "test-client-not-found", true) + + // Create a resources/read request with non-existent URI + params := ResourceReadParams{ + URI: "file:///test/non-existent.txt", + } + + paramBytes, _ := json.Marshal(params) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 52, + Method: methodResourcesRead, + Params: paramBytes, + } + jsonBody, _ := json.Marshal(req) + + // Create HTTP request + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-not-found", bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + + // Process through handleRequest + mock.server.handleRequest(w, r) + + // Check for error response about resource not found + select { + case response := <-client.channel: + assert.Contains(t, response, "error", "Response should contain error") + assert.Contains(t, response, "Resource with URI", "Error should mention resource URI") + assert.Contains(t, response, "not found", "Error should indicate resource not found") + assert.Contains(t, response, "file:///test/non-existent.txt", "Error should mention the requested URI") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for error response") + } + }) + + t.Run("invalidParameters", func(t *testing.T) { + // Create an invalid JSON request + invalidJson := []byte(`{"not valid json`) + + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 53, + Method: methodResourcesRead, + Params: invalidJson, + } + jsonBody, _ := json.Marshal(req) + + // Create HTTP request + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-invalid-params", bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + + // Process through handleRequest + mock.server.handleRequest(w, r) + assert.Equal(t, http.StatusBadRequest, w.Code) + }) + + t.Run("invalidParameters direct", func(t *testing.T) { + // Create an invalid JSON request + invalidJson := []byte(`{"not valid json`) + + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 53, + Method: methodResourcesRead, + Params: invalidJson, + } + + // Create an initialized client + client := addTestClient(mock.server, "test-client-resources", true) + + // Process through handleRequest + mock.server.processResourcesRead(client, req) + + select { + case response := <-client.channel: + assert.Contains(t, response, "error", "Response should contain error") + assert.Contains(t, response, "Invalid parameters", "Error should mention invalid parameters") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for error response") + } + }) + + t.Run("handlerError", func(t *testing.T) { + // Register a test resource with handler that returns an error + testResource := Resource{ + Name: "error-resource", + URI: "file:///test/error.txt", + Description: "A test resource with handler that returns error", + MimeType: "text/plain", + Handler: func() (ResourceContent, error) { + return ResourceContent{}, fmt.Errorf("test handler error") + }, + } + mock.server.RegisterResource(testResource) + + // Create an initialized client + client := addTestClient(mock.server, "test-client-handler-error", true) + + // Create a resources/read request + params := ResourceReadParams{ + URI: "file:///test/error.txt", + } + + paramBytes, _ := json.Marshal(params) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 54, + Method: methodResourcesRead, + Params: paramBytes, + } + jsonBody, _ := json.Marshal(req) + + // Create HTTP request + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-handler-error", bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + + // Process through handleRequest + mock.server.handleRequest(w, r) + + // Check for error response about handler error + select { + case response := <-client.channel: + assert.Contains(t, response, "error", "Response should contain error") + assert.Contains(t, response, "Error reading resource", "Error should mention reading resource") + assert.Contains(t, response, "test handler error", "Error should include handler error message") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for error response") + } + }) + + t.Run("handlerMissingURIAndMimeType", func(t *testing.T) { + // Register a test resource with handler that returns content without URI and MimeType + testResource := Resource{ + Name: "missing-fields-resource", + URI: "file:///test/missing-fields.txt", + Description: "A test resource with handler that returns content missing fields", + MimeType: "text/plain", + Handler: func() (ResourceContent, error) { + // Return ResourceContent without URI and MimeType + return ResourceContent{ + Text: "Content with missing fields", + }, nil + }, + } + mock.server.RegisterResource(testResource) + + // Create an initialized client + client := addTestClient(mock.server, "test-client-missing-fields", true) + + // Create a resources/read request + params := ResourceReadParams{ + URI: "file:///test/missing-fields.txt", + } + + paramBytes, _ := json.Marshal(params) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 55, + Method: methodResourcesRead, + Params: paramBytes, + } + jsonBody, _ := json.Marshal(req) + + // Create HTTP request + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-missing-fields", bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + + // Process through handleRequest + mock.server.handleRequest(w, r) + + // Check response - server should fill in the missing fields + select { + case response := <-client.channel: + assert.Contains(t, response, "event: message", "Response should be a message event") + + // Extract and parse JSON + jsonStart := strings.Index(response, "{") + jsonEnd := strings.LastIndex(response, "}") + jsonStr := response[jsonStart : jsonEnd+1] + + var parsed map[string]any + err := json.Unmarshal([]byte(jsonStr), &parsed) + require.NoError(t, err, "Should be able to parse response JSON") + + result, ok := parsed["result"].(map[string]any) + require.True(t, ok, "Response should have a result object") + + contents, ok := result["contents"].([]any) + require.True(t, ok, "Result should have contents array") + assert.Len(t, contents, 1, "Contents array should have 1 item") + + // Check content details - server should fill in missing URI and MimeType + if len(contents) > 0 { + content, ok := contents[0].(map[string]any) + require.True(t, ok, "Content should be an object") + assert.Equal(t, "file:///test/missing-fields.txt", content["uri"], "URI should be filled from request") + assert.Equal(t, "text/plain", content["mimeType"], "MimeType should be filled from resource") + assert.Equal(t, "Content with missing fields", content["text"], "Text content should match") + } + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for resource read response") + } + }) +} + +// TestMethodResourcesSubscribe tests the handling of resources/subscribe method +func TestMethodResourcesSubscribe(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + t.Run("validSubscription", func(t *testing.T) { + // Register a test resource + testResource := Resource{ + Name: "subscribe-resource", + URI: "file:///test/subscribe.txt", + Description: "A test resource for subscription", + MimeType: "text/plain", + } + mock.server.RegisterResource(testResource) + + // Create an initialized client + client := addTestClient(mock.server, "test-client-subscribe", true) + + // Create a resources/subscribe request + params := ResourceSubscribeParams{ + URI: "file:///test/subscribe.txt", + } + + paramBytes, err := json.Marshal(params) + require.NoError(t, err, "Failed to marshal resource subscribe parameters") + + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 60, + Method: methodResourcesSubscribe, + Params: paramBytes, + } + jsonBody, _ := json.Marshal(req) + + // Create HTTP request + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-subscribe", bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + + // Process through handleRequest (full path) + mock.server.handleRequest(w, r) + + // Verify the HTTP response + assert.Equal(t, http.StatusAccepted, w.Code, "HTTP status should be 202 Accepted") + + // Check the response in client's channel - should be an empty success response + select { + case response := <-client.channel: + // Verify it's a message event with the expected format + assert.Contains(t, response, "event: message", "Response should be a message event") + + // Parse the JSON part of the SSE message + jsonStart := strings.Index(response, "{") + jsonEnd := strings.LastIndex(response, "}") + require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") + jsonStr := response[jsonStart : jsonEnd+1] + + var parsed map[string]any + err := json.Unmarshal([]byte(jsonStr), &parsed) + require.NoError(t, err, "Should be able to parse response JSON") + + // Verify ID matches our request + id, ok := parsed["id"].(float64) + assert.True(t, ok, "Response should have an ID") + assert.Equal(t, float64(60), id, "Response ID should match request ID") + + // Verify the result exists and is an empty object + result, ok := parsed["result"].(map[string]any) + require.True(t, ok, "Response should have a result object") + assert.Empty(t, result, "Result should be an empty object for successful subscription") + + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for subscription response") + } + }) + + t.Run("resourceNotFound", func(t *testing.T) { + // Create an initialized client + client := addTestClient(mock.server, "test-client-sub-not-found", true) + + // Create a resources/subscribe request with non-existent URI + params := ResourceSubscribeParams{ + URI: "file:///test/non-existent-subscription.txt", + } + + paramBytes, _ := json.Marshal(params) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 61, + Method: methodResourcesSubscribe, + Params: paramBytes, + } + jsonBody, _ := json.Marshal(req) + + // Create HTTP request + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-sub-not-found", bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + + // Process through handleRequest + mock.server.handleRequest(w, r) + + // Check for error response about resource not found + select { + case response := <-client.channel: + assert.Contains(t, response, "error", "Response should contain error") + assert.Contains(t, response, "Resource with URI", "Error should mention resource URI") + assert.Contains(t, response, "not found", "Error should indicate resource not found") + assert.Contains(t, response, "file:///test/non-existent-subscription.txt", "Error should mention the requested URI") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for error response") + } + }) + + t.Run("invalidParameters", func(t *testing.T) { + // Create an invalid JSON request + invalidJson := []byte(`{"not valid json`) + + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 62, + Method: methodResourcesSubscribe, + Params: invalidJson, + } + jsonBody, _ := json.Marshal(req) + + // Create HTTP request + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-sub-invalid-params", bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + + // Process through handleRequest + mock.server.handleRequest(w, r) + assert.Equal(t, http.StatusBadRequest, w.Code, "HTTP status should be 400 Bad Request") + }) + + t.Run("invalidParameters direct", func(t *testing.T) { + // Create an invalid JSON request + invalidJson := []byte(`{"not valid json`) + + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 62, + Method: methodResourcesSubscribe, + Params: invalidJson, + } + + client := addTestClient(mock.server, "test-client-sub-not-found", true) + mock.server.processResourceSubscribe(client, req) + + select { + case response := <-client.channel: + assert.Contains(t, response, "error", "Response should contain error") + assert.Contains(t, response, "Invalid parameters", "Error should mention invalid parameters") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for error response") + } + }) + + t.Run("clientNotInitialized", func(t *testing.T) { + // Register a test resource + testResource := Resource{ + Name: "subscribe-resource-uninit", + URI: "file:///test/subscribe-uninit.txt", + Description: "A test resource for subscription with uninitialized client", + } + mock.server.RegisterResource(testResource) + + // Create an uninitialized client + client := addTestClient(mock.server, "test-client-sub-uninitialized", false) + + // Create a valid resources/subscribe request + params := ResourceSubscribeParams{ + URI: "file:///test/subscribe-uninit.txt", + } + + paramBytes, _ := json.Marshal(params) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 63, + Method: methodResourcesSubscribe, + Params: paramBytes, + } + jsonBody, _ := json.Marshal(req) + + // Create HTTP request + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-sub-uninitialized", bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + + // Process through handleRequest + mock.server.handleRequest(w, r) + + // Verify response contains error about client not being initialized + select { + case response := <-client.channel: + assert.Contains(t, response, "error", "Response should contain error") + assert.Contains(t, response, "not fully initialized", "Error should mention client not initialized") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for error response") + } + }) + + t.Run("missingURIParameter", func(t *testing.T) { + // Create an initialized client + client := addTestClient(mock.server, "test-client-sub-missing-uri", true) + + // Create a subscription request with empty URI + params := ResourceSubscribeParams{ + URI: "", // Empty URI + } + + paramBytes, _ := json.Marshal(params) + req := Request{ + JsonRpc: jsonRpcVersion, + ID: 64, + Method: methodResourcesSubscribe, + Params: paramBytes, + } + jsonBody, _ := json.Marshal(req) + + // Create HTTP request + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-sub-missing-uri", + bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + + // Process through handleRequest + mock.server.handleRequest(w, r) + + // Check for error response about resource not found (empty URI) + select { + case response := <-client.channel: + assert.Contains(t, response, "error", "Response should contain error") + assert.Contains(t, response, "Resource with URI", "Error should mention resource URI") + assert.Contains(t, response, "not found", "Error should indicate resource not found") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for error response") + } + }) +} + +// TestToolCallUnmarshalError tests the error handling when unmarshaling invalid JSON in processToolCall +func TestToolCallUnmarshalError(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Create an initialized client + client := addTestClient(mock.server, "test-client-unmarshal-error", true) + + // Create a request with invalid JSON in Params + req := Request{ + JsonRpc: "2.0", + ID: 100, + Method: methodToolsCall, + Params: []byte(`{"name": "test.tool", "arguments": {"input": invalid_json}}`), // This is invalid JSON + } + + // Process the tool call directly + mock.server.processToolCall(client, req) + + // Check for error response about invalid JSON + select { + case response := <-client.channel: + assert.Contains(t, response, "error", "Response should contain an error") + assert.Contains(t, response, "Invalid tool call parameters", "Error should mention invalid parameters") + + // Extract error code from response + jsonStart := strings.Index(response, "{") + jsonEnd := strings.LastIndex(response, "}") + require.True(t, jsonStart >= 0 && jsonEnd > jsonStart, "Response should contain valid JSON") + jsonStr := response[jsonStart : jsonEnd+1] + + var parsed struct { + Error struct { + Code int `json:"code"` + } `json:"error"` + } + err := json.Unmarshal([]byte(jsonStr), &parsed) + require.NoError(t, err, "Should be able to parse response JSON") + + // Verify correct error code was returned + assert.Equal(t, errCodeInvalidParams, parsed.Error.Code, "Error code should be errCodeInvalidParams") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for error response") + } +} + +// TestToolCallWithInvalidParams tests the handling when calling handleRequest with invalid JSON params +func TestToolCallWithInvalidParams(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Register a tool to make sure it exists + mock.registerExampleTool() + + // Create a request with invalid JSON + req := Request{ + JsonRpc: "2.0", + ID: 101, + Method: methodToolsCall, + Params: []byte(`{"name": "test.tool", "arguments": {this_is_invalid_json}}`), + } + + jsonBody, _ := json.Marshal(req) + + // Create HTTP request + r := httptest.NewRequest("POST", "/?session_id=test-client-invalid-json", bytes.NewReader(jsonBody)) + w := httptest.NewRecorder() + + // Process through handleRequest + mock.server.handleRequest(w, r) + + // Verify HTTP status is Accepted (even for errors, we accept the request) + assert.Equal(t, http.StatusBadRequest, w.Code) +} + +type mockResponseWriter struct { +} + +func (m *mockResponseWriter) Header() http.Header { + return http.Header{} +} + +func (m *mockResponseWriter) Write(i []byte) (int, error) { + return len(i), nil +} + +func (m *mockResponseWriter) WriteHeader(_ int) { +} + +type notFlusherResponseWriter struct { + mockResponseWriter + code int +} + +func (m *notFlusherResponseWriter) WriteHeader(code int) { + m.code = code +} + +type cantWriteResponseWriter struct { + mockResponseWriter + code int +} + +func (m *cantWriteResponseWriter) Flush() { +} + +func (m *cantWriteResponseWriter) Write(_ []byte) (int, error) { + return 0, fmt.Errorf("can't write") +} + +type writeOnceResponseWriter struct { + mockResponseWriter + times int32 +} + +func (m *writeOnceResponseWriter) Flush() { +} + +func (m *writeOnceResponseWriter) Write(i []byte) (int, error) { + if atomic.AddInt32(&m.times, 1) > 1 { + return 0, fmt.Errorf("write once") + } + return len(i), nil +} + +// testResponseWriter is a custom http.ResponseWriter that captures writes and detects ping messages +type testResponseWriter struct { + *httptest.ResponseRecorder + writes []string + mu sync.Mutex + pingDetected bool + done chan struct{} +} + +// Write overrides the ResponseRecorder's Write method to detect ping messages +func (w *testResponseWriter) Write(b []byte) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + + written, err := w.ResponseRecorder.Write(b) + if err != nil { + return written, err + } + + content := string(b) + w.writes = append(w.writes, content) + + // Check if this is a ping message + if strings.Contains(content, "event: ping") { + w.pingDetected = true + // Signal that we've detected a ping + select { + case w.done <- struct{}{}: + default: + // Channel might be closed or already signaled + } + } + + return written, nil +} + +// Flush implements the http.Flusher interface +func (w *testResponseWriter) Flush() { + w.ResponseRecorder.Flush() +} diff --git a/mcp/types.go b/mcp/types.go new file mode 100644 index 000000000..00cb65582 --- /dev/null +++ b/mcp/types.go @@ -0,0 +1,294 @@ +package mcp + +import ( + "encoding/json" + "sync" + + "github.com/zeromicro/go-zero/rest" +) + +// Cursor is an opaque token used for pagination +type Cursor string + +// Request represents a generic MCP request following JSON-RPC 2.0 specification +type Request struct { + SessionId string `form:"session_id"` // Session identifier for client tracking + JsonRpc string `json:"jsonrpc"` // Must be "2.0" per JSON-RPC spec + ID int64 `json:"id"` // Request identifier for matching responses + Method string `json:"method"` // Method name to invoke + Params json.RawMessage `json:"params"` // Parameters for the method +} + +type PaginatedParams struct { + Cursor string `json:"cursor"` + Meta struct { + ProgressToken any `json:"progressToken"` + } `json:"_meta"` +} + +// Result is the base interface for all results +type Result struct { + Meta map[string]any `json:"_meta,omitempty"` // Optional metadata +} + +// PaginatedResult is a base for results that support pagination +type PaginatedResult struct { + Result + NextCursor Cursor `json:"nextCursor,omitempty"` // Opaque token for fetching next page +} + +// ListToolsResult represents the response to a tools/list request +type ListToolsResult struct { + PaginatedResult + Tools []Tool `json:"tools"` // List of available tools +} + +// Message Content Types + +// roleType represents the sender or recipient of messages in a conversation +type roleType string + +// PromptArgument defines a single argument that can be passed to a prompt +type PromptArgument struct { + Name string `json:"name"` // Argument name + Description string `json:"description,omitempty"` // Human-readable description + Required bool `json:"required,omitempty"` // Whether this argument is required + Default string `json:"default,omitempty"` // Default value if not provided +} + +// PromptHandler is a function that dynamically generates prompt content +type PromptHandler func(args map[string]string) ([]PromptMessage, error) + +// Prompt represents an MCP Prompt definition +type Prompt struct { + Name string `json:"name"` // Unique identifier for the prompt + Description string `json:"description,omitempty"` // Human-readable description + Arguments []PromptArgument `json:"arguments,omitempty"` // Arguments for customization + Content string `json:"-"` // Static content (internal use only) + Handler PromptHandler `json:"-"` // Handler for dynamic content generation +} + +// PromptMessage represents a message in a conversation +type PromptMessage struct { + Role roleType `json:"role"` // Message sender role + Content any `json:"content"` // Message content (TextContent, ImageContent, etc.) +} + +// TextContent represents text content in a message +type TextContent struct { + Type string `json:"type"` // Always "text" + Text string `json:"text"` // The text content + Annotations *Annotations `json:"annotations,omitempty"` // Optional annotations +} + +// ImageContent represents image data in a message +type ImageContent struct { + Type string `json:"type"` // Always "image" + Data string `json:"data"` // Base64-encoded image data + MimeType string `json:"mimeType"` // MIME type (e.g., "image/png") +} + +// AudioContent represents audio data in a message +type AudioContent struct { + Type string `json:"type"` // Always "audio" + Data string `json:"data"` // Base64-encoded audio data + MimeType string `json:"mimeType"` // MIME type (e.g., "audio/mp3") +} + +// FileContent represents file content +type FileContent struct { + URI string `json:"uri"` // URI identifying the file + MimeType string `json:"mimeType"` // MIME type of the file + Text string `json:"text"` // File content as text +} + +// EmbeddedResource represents a resource embedded in a message +type EmbeddedResource struct { + Type string `json:"type"` // Always "resource" + Resource struct { + URI string `json:"uri"` // Resource URI + MimeType string `json:"mimeType"` // MIME type of the resource + Text string `json:"text,omitempty"` // Text content (if available) + Blob string `json:"blob,omitempty"` // Base64 encoded blob data (if available) + } `json:"resource"` // The resource data +} + +// Annotations provides additional metadata for content +type Annotations struct { + Audience []roleType `json:"audience,omitempty"` // Who should see this content + Priority *float64 `json:"priority,omitempty"` // Optional priority (0-1) +} + +// Tool-related Types + +// Tool Definition Types + +// ToolHandler is a function that handles tool calls +type ToolHandler func(params map[string]any) (any, error) + +// Tool represents a Model Context Protocol Tool definition +type Tool struct { + Name string `json:"name"` // Unique identifier for the tool + Description string `json:"description"` // Human-readable description + InputSchema InputSchema `json:"inputSchema"` // JSON Schema for parameters + Handler ToolHandler `json:"-"` // Not sent to clients +} + +// InputSchema represents tool's input schema in JSON Schema format +type InputSchema struct { + Type string `json:"type"` // Always "object" for tool inputs + Properties map[string]any `json:"properties"` // Property definitions + Required []string `json:"required,omitempty"` // List of required properties +} + +// CallToolResult represents a tool call result that conforms to the MCP schema +type CallToolResult struct { + Result + Content []interface{} `json:"content"` // Content items (text, images, etc.) + IsError bool `json:"isError,omitempty"` // True if tool execution failed +} + +// Resource represents a Model Context Protocol Resource definition +type Resource struct { + URI string `json:"uri"` // Unique resource identifier (RFC3986) + Name string `json:"name"` // Human-readable name + Description string `json:"description,omitempty"` // Optional description + MimeType string `json:"mimeType,omitempty"` // Optional MIME type + Handler ResourceHandler `json:"-"` // Internal handler not sent to clients +} + +// ResourceHandler is a function that handles resource read requests +type ResourceHandler func() (ResourceContent, error) + +// ResourceContent represents the content of a resource +type ResourceContent struct { + URI string `json:"uri"` // Resource URI (required) + MimeType string `json:"mimeType,omitempty"` // MIME type of the resource + Text string `json:"text,omitempty"` // Text content (if available) + Blob string `json:"blob,omitempty"` // Base64 encoded blob data (if available) +} + +// ResourcesListResult represents the response to a resources/list request +type ResourcesListResult struct { + PaginatedResult + Resources []Resource `json:"resources"` // List of available resources +} + +// ResourceReadParams contains parameters for a resources/read request +type ResourceReadParams struct { + URI string `json:"uri"` // URI of the resource to read +} + +// ResourceReadResult contains the result of a resources/read request +type ResourceReadResult struct { + Result + Contents []ResourceContent `json:"contents"` // Array of resource content +} + +// ResourceSubscribeParams contains parameters for a resources/subscribe request +type ResourceSubscribeParams struct { + URI string `json:"uri"` // URI of the resource to subscribe to +} + +// ResourceUpdateNotification represents a notification about a resource update +type ResourceUpdateNotification struct { + URI string `json:"uri"` // URI of the updated resource + Content ResourceContent `json:"content"` // New resource content +} + +// Client and Server Types + +// mcpClient represents an SSE client connection +type mcpClient struct { + id string // Unique client identifier + channel chan string // Channel for sending SSE messages + initialized bool // Tracks if client has sent notifications/initialized +} + +// McpServer defines the interface for Model Context Protocol servers +type McpServer interface { + Start() + Stop() + RegisterTool(tool Tool) error + RegisterPrompt(prompt Prompt) + RegisterResource(resource Resource) +} + +// sseMcpServer implements the McpServer interface using SSE +type sseMcpServer struct { + conf McpConf + server *rest.Server + clients map[string]*mcpClient + clientsLock sync.Mutex + tools map[string]Tool + toolsLock sync.Mutex + prompts map[string]Prompt + promptsLock sync.Mutex + resources map[string]Resource + resourcesLock sync.Mutex +} + +// Response Types + +// errorObj represents a JSON-RPC error object +type errorObj struct { + Code int `json:"code"` // Error code + Message string `json:"message"` // Error message +} + +// Response represents a JSON-RPC response +type Response struct { + JsonRpc string `json:"jsonrpc"` // Always "2.0" + ID int64 `json:"id"` // Same as request ID + Result any `json:"result"` // Result object (null if error) + Error *errorObj `json:"error,omitempty"` // Error object (null if success) +} + +// Server Information Types + +// serverInfo provides information about the server +type serverInfo struct { + Name string `json:"name"` // Server name + Version string `json:"version"` // Server version +} + +// capabilities describes the server's capabilities +type capabilities struct { + Logging struct{} `json:"logging"` + Prompts struct { + ListChanged bool `json:"listChanged"` // Server will notify on prompt changes + } `json:"prompts"` + Resources struct { + Subscribe bool `json:"subscribe"` // Server supports resource subscriptions + ListChanged bool `json:"listChanged"` // Server will notify on resource changes + } `json:"resources"` + Tools struct { + ListChanged bool `json:"listChanged"` // Server will notify on tool changes + } `json:"tools"` +} + +// initializationResponse is sent in response to an initialize request +type initializationResponse struct { + ProtocolVersion string `json:"protocolVersion"` // Protocol version + Capabilities capabilities `json:"capabilities"` // Server capabilities + ServerInfo serverInfo `json:"serverInfo"` // Server information +} + +// ToolCallParams contains the parameters for a tool call +type ToolCallParams struct { + Name string `json:"name"` // Tool name + Parameters map[string]any `json:"parameters"` // Tool parameters +} + +// ToolResult contains the result of a tool execution +type ToolResult struct { + Type string `json:"type"` // Content type (text, image, etc.) + Content any `json:"content"` // Result content +} + +// errorMessage represents a detailed error message +type errorMessage struct { + Code int `json:"code"` // Error code + Message string `json:"message"` // Error message + Data any `json:",omitempty"` // Additional error data +} diff --git a/mcp/types_test.go b/mcp/types_test.go new file mode 100644 index 000000000..45cea252a --- /dev/null +++ b/mcp/types_test.go @@ -0,0 +1,212 @@ +package mcp + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestResponseMarshaling(t *testing.T) { + // Test that the Response struct marshals correctly + resp := Response{ + JsonRpc: "2.0", + ID: 123, + Result: map[string]string{ + "key": "value", + }, + } + + data, err := json.Marshal(resp) + assert.NoError(t, err) + assert.Contains(t, string(data), `"jsonrpc":"2.0"`) + assert.Contains(t, string(data), `"id":123`) + assert.Contains(t, string(data), `"result":{"key":"value"}`) + + // Test response with error + respWithError := Response{ + JsonRpc: "2.0", + ID: 456, + Error: &errorObj{ + Code: errCodeInvalidRequest, + Message: "Invalid Request", + }, + } + + data, err = json.Marshal(respWithError) + assert.NoError(t, err) + assert.Contains(t, string(data), `"jsonrpc":"2.0"`) + assert.Contains(t, string(data), `"id":456`) + assert.Contains(t, string(data), `"error":{"code":-32600,"message":"Invalid Request"}`) +} + +func TestRequestUnmarshaling(t *testing.T) { + // Test that the Request struct unmarshals correctly + jsonStr := `{ + "jsonrpc": "2.0", + "id": 789, + "method": "test_method", + "params": {"key": "value"} + }` + + var req Request + err := json.Unmarshal([]byte(jsonStr), &req) + + assert.NoError(t, err) + assert.Equal(t, "2.0", req.JsonRpc) + assert.Equal(t, int64(789), req.ID) + assert.Equal(t, "test_method", req.Method) + + // Check params unmarshaled correctly + var params map[string]string + err = json.Unmarshal(req.Params, ¶ms) + assert.NoError(t, err) + assert.Equal(t, "value", params["key"]) +} + +func TestToolStructs(t *testing.T) { + // Test Tool struct + tool := Tool{ + Name: "test.tool", + Description: "A test tool", + InputSchema: InputSchema{ + Type: "object", + Properties: map[string]any{ + "input": map[string]any{ + "type": "string", + "description": "Input parameter", + }, + }, + Required: []string{"input"}, + }, + Handler: func(params map[string]any) (any, error) { + return "result", nil + }, + } + + // Verify fields are correct + assert.Equal(t, "test.tool", tool.Name) + assert.Equal(t, "A test tool", tool.Description) + assert.Equal(t, "object", tool.InputSchema.Type) + assert.Contains(t, tool.InputSchema.Properties, "input") + propMap, ok := tool.InputSchema.Properties["input"].(map[string]any) + assert.True(t, ok, "Property should be a map") + assert.Equal(t, "string", propMap["type"]) + assert.NotNil(t, tool.Handler) + + // Verify JSON marshalling (which should exclude Handler function) + data, err := json.Marshal(tool) + assert.NoError(t, err) + assert.Contains(t, string(data), `"name":"test.tool"`) + assert.Contains(t, string(data), `"description":"A test tool"`) + assert.Contains(t, string(data), `"inputSchema":`) + assert.NotContains(t, string(data), `"Handler":`) +} + +func TestPromptStructs(t *testing.T) { + // Test Prompt struct + prompt := Prompt{ + Name: "test.prompt", + Description: "A test prompt description", + } + + // Verify fields are correct + assert.Equal(t, "test.prompt", prompt.Name) + assert.Equal(t, "A test prompt description", prompt.Description) + + // Verify JSON marshalling + data, err := json.Marshal(prompt) + assert.NoError(t, err) + assert.Contains(t, string(data), `"name":"test.prompt"`) + assert.Contains(t, string(data), `"description":"A test prompt description"`) +} + +func TestResourceStructs(t *testing.T) { + // Test Resource struct + resource := Resource{ + Name: "test.resource", + URI: "http://example.com/resource", + Description: "A test resource", + } + + // Verify fields are correct + assert.Equal(t, "test.resource", resource.Name) + assert.Equal(t, "http://example.com/resource", resource.URI) + assert.Equal(t, "A test resource", resource.Description) + + // Verify JSON marshalling + data, err := json.Marshal(resource) + assert.NoError(t, err) + assert.Contains(t, string(data), `"name":"test.resource"`) + assert.Contains(t, string(data), `"uri":"http://example.com/resource"`) + assert.Contains(t, string(data), `"description":"A test resource"`) +} + +func TestContentTypes(t *testing.T) { + // Test TextContent + textContent := TextContent{ + Type: "text", + Text: "Sample text", + Annotations: &Annotations{ + Audience: []roleType{roleUser, roleAssistant}, + Priority: ptr(1.0), + }, + } + + data, err := json.Marshal(textContent) + assert.NoError(t, err) + assert.Contains(t, string(data), `"type":"text"`) + assert.Contains(t, string(data), `"text":"Sample text"`) + assert.Contains(t, string(data), `"audience":["user","assistant"]`) + assert.Contains(t, string(data), `"priority":1`) + + // Test ImageContent + imageContent := ImageContent{ + Type: "image", + Data: "base64data", + MimeType: "image/png", + } + + data, err = json.Marshal(imageContent) + assert.NoError(t, err) + assert.Contains(t, string(data), `"type":"image"`) + assert.Contains(t, string(data), `"data":"base64data"`) + assert.Contains(t, string(data), `"mimeType":"image/png"`) + + // Test AudioContent + audioContent := AudioContent{ + Type: "audio", + Data: "base64audio", + MimeType: "audio/mp3", + } + + data, err = json.Marshal(audioContent) + assert.NoError(t, err) + assert.Contains(t, string(data), `"type":"audio"`) + assert.Contains(t, string(data), `"data":"base64audio"`) + assert.Contains(t, string(data), `"mimeType":"audio/mp3"`) +} + +func TestCallToolResult(t *testing.T) { + // Test CallToolResult + result := CallToolResult{ + Result: Result{ + Meta: map[string]any{ + "progressToken": "token123", + }, + }, + Content: []interface{}{ + TextContent{ + Type: "text", + Text: "Sample result", + }, + }, + IsError: false, + } + + data, err := json.Marshal(result) + assert.NoError(t, err) + assert.Contains(t, string(data), `"_meta":{"progressToken":"token123"}`) + assert.Contains(t, string(data), `"content":[{"type":"text","text":"Sample result"}]`) + assert.NotContains(t, string(data), `"isError":`) +} diff --git a/mcp/util.go b/mcp/util.go new file mode 100644 index 000000000..0963b7098 --- /dev/null +++ b/mcp/util.go @@ -0,0 +1,15 @@ +package mcp + +import ( + "fmt" +) + +// ptr is a helper function to get a pointer to a value +func ptr[T any](v T) *T { + return &v +} + +// formatSSEMessage formats a Server-Sent Event message with proper CRLF line endings +func formatSSEMessage(event string, data []byte) string { + return fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(data)) +} diff --git a/mcp/util_test.go b/mcp/util_test.go new file mode 100644 index 000000000..336cc3920 --- /dev/null +++ b/mcp/util_test.go @@ -0,0 +1,63 @@ +package mcp + +import ( + "bufio" + "encoding/json" + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPtr(t *testing.T) { + tests := []struct { + name string + v interface{} + }{ + {"string", "test"}, + {"int", 42}, + {"bool", true}, + {"float", 3.14}, + {"struct", struct{ Name string }{"test"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ptr(tt.v) + assert.NotNil(t, got, "ptr() should not return nil") + assert.Equal(t, tt.v, *got, "dereferenced pointer should equal input value") + }) + } +} + +type Event struct { + Type string + Data map[string]any +} + +func parseEvent(input string) (*Event, error) { + var evt Event + var dataStr string + + scanner := bufio.NewScanner(strings.NewReader(input)) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "event:") { + evt.Type = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + } else if strings.HasPrefix(line, "data:") { + dataStr = strings.TrimSpace(strings.TrimPrefix(line, "data:")) + } + } + if err := scanner.Err(); err != nil { + return nil, err + } + + if len(dataStr) > 0 { + if err := json.Unmarshal([]byte(dataStr), &evt.Data); err != nil { + return nil, fmt.Errorf("failed to parse data: %w", err) + } + } + + return &evt, nil +} diff --git a/mcp/vars.go b/mcp/vars.go new file mode 100644 index 000000000..cc9a300da --- /dev/null +++ b/mcp/vars.go @@ -0,0 +1,137 @@ +package mcp + +import ( + "time" + + "github.com/zeromicro/go-zero/core/syncx" +) + +// Protocol constants +const ( + // JSON-RPC version as defined in the specification + jsonRpcVersion = "2.0" + + // Session identifier key used in request URLs + sessionIdKey = "session_id" +) + +// Server-Sent Events (SSE) event types +const ( + // Standard message event for JSON-RPC responses + eventMessage = "message" + + // Endpoint event for sending endpoint URL to clients + eventEndpoint = "endpoint" +) + +// Content type identifiers +const ( + // Text content type + contentTypeText = "text" + + // Image content type + contentTypeImage = "image" +) + +// Collection keys for broadcast events +const ( + // Key for prompts collection + keyPrompts = "prompts" + + // Key for resources collection + keyResources = "resources" + + // Key for tools collection + keyTools = "tools" +) + +// JSON-RPC error codes +// Standard error codes from JSON-RPC 2.0 spec +const ( + // Invalid JSON was received by the server + errCodeInvalidRequest = -32600 + + // The method does not exist / is not available + errCodeMethodNotFound = -32601 + + // Invalid method parameter(s) + errCodeInvalidParams = -32602 + + // Internal JSON-RPC error + errCodeInternalError = -32603 + + // Tool execution timed out + errCodeTimeout = -32001 + + // Resource not found error + errCodeResourceNotFound = -32002 + + // Client hasn't completed initialization + errCodeClientNotInitialized = -32800 +) + +// User and assistant role definitions +const ( + // The "user" role - the entity asking questions + roleUser roleType = "user" + + // The "assistant" role - the entity providing responses + roleAssistant roleType = "assistant" +) + +// Method names as defined in the MCP specification +const ( + // Initialize the connection between client and server + methodInitialize = "initialize" + + // List available tools + methodToolsList = "tools/list" + + // Call a specific tool + methodToolsCall = "tools/call" + + // List available prompts + methodPromptsList = "prompts/list" + + // Get a specific prompt + methodPromptsGet = "prompts/get" + + // List available resources + methodResourcesList = "resources/list" + + // Read a specific resource + methodResourcesRead = "resources/read" + + // Subscribe to resource updates + methodResourcesSubscribe = "resources/subscribe" + + // Simple ping to check server availability + methodPing = "ping" + + // Notification that client is fully initialized + methodNotificationsInitialized = "notifications/initialized" + + // Notification that a request was canceled + methodNotificationsCancelled = "notifications/cancelled" +) + +// Event names for Server-Sent Events (SSE) +const ( + // Notification of tool list changes + eventToolsListChanged = "tools/list_changed" + + // Notification of prompt list changes + eventPromptsListChanged = "prompts/list_changed" + + // Notification of resource list changes + eventResourcesListChanged = "resources/list_changed" +) + +var ( + // Default channel size for events + eventChanSize = 10 + + // Default ping interval for checking connection availability + // use syncx.ForAtomicDuration to ensure atomicity in test race + pingInterval = syncx.ForAtomicDuration(30 * time.Second) +) diff --git a/mcp/vars_test.go b/mcp/vars_test.go new file mode 100644 index 000000000..5cfdca195 --- /dev/null +++ b/mcp/vars_test.go @@ -0,0 +1,214 @@ +// filepath: /Users/kevin/Develop/go/opensource/go-zero/mcp/vars_test.go +package mcp + +import ( + "encoding/json" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestErrorCodes ensures error codes are applied correctly in error responses +func TestErrorCodes(t *testing.T) { + testCases := []struct { + name string + code int + message string + expected string + }{ + { + name: "invalid request error", + code: errCodeInvalidRequest, + message: "Invalid request", + expected: `"code":-32600`, + }, + { + name: "method not found error", + code: errCodeMethodNotFound, + message: "Method not found", + expected: `"code":-32601`, + }, + { + name: "invalid params error", + code: errCodeInvalidParams, + message: "Invalid parameters", + expected: `"code":-32602`, + }, + { + name: "internal error", + code: errCodeInternalError, + message: "Internal server error", + expected: `"code":-32603`, + }, + { + name: "timeout error", + code: errCodeTimeout, + message: "Operation timed out", + expected: `"code":-32001`, + }, + { + name: "resource not found error", + code: errCodeResourceNotFound, + message: "Resource not found", + expected: `"code":-32002`, + }, + { + name: "client not initialized error", + code: errCodeClientNotInitialized, + message: "Client not initialized", + expected: `"code":-32800`, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + resp := Response{ + JsonRpc: jsonRpcVersion, + ID: int64(1), + Error: &errorObj{ + Code: tc.code, + Message: tc.message, + }, + } + data, err := json.Marshal(resp) + assert.NoError(t, err) + assert.Contains(t, string(data), tc.expected, "Error code should match expected value") + assert.Contains(t, string(data), tc.message, "Error message should be included") + assert.Contains(t, string(data), jsonRpcVersion, "JSON-RPC version should be included") + }) + } +} + +// TestJsonRpcVersion ensures the correct JSON-RPC version is used +func TestJsonRpcVersion(t *testing.T) { + assert.Equal(t, "2.0", jsonRpcVersion, "JSON-RPC version should be 2.0") + + // Test that it's used in responses + resp := Response{ + JsonRpc: jsonRpcVersion, + ID: int64(1), + Result: "test", + } + data, err := json.Marshal(resp) + assert.NoError(t, err) + assert.Contains(t, string(data), `"jsonrpc":"2.0"`, "Response should use correct JSON-RPC version") + + // Test that it's expected in requests + reqStr := `{"jsonrpc":"2.0","id":1,"method":"test"}` + var req Request + err = json.Unmarshal([]byte(reqStr), &req) + assert.NoError(t, err) + assert.Equal(t, jsonRpcVersion, req.JsonRpc, "Request should parse correct JSON-RPC version") +} + +// TestSessionIdKey ensures session ID extraction works correctly +func TestSessionIdKey(t *testing.T) { + // Create a mock server implementation + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Verify the key constant + assert.Equal(t, "session_id", sessionIdKey, "Session ID key should be 'session_id'") + + // Test that session ID is extracted correctly + mockR := httptest.NewRequest("GET", "/?"+sessionIdKey+"=test-session", nil) + + // Since the mock server is using the same session key logic, + // we can test this by accessing the request query parameters directly + sessionID := mockR.URL.Query().Get(sessionIdKey) + assert.Equal(t, "test-session", sessionID, "Session ID should be extracted correctly") +} + +// TestEventTypes ensures event types are set correctly in SSE responses +func TestEventTypes(t *testing.T) { + // Test message event + assert.Equal(t, "message", eventMessage, "Message event should be 'message'") + + // Test endpoint event + assert.Equal(t, "endpoint", eventEndpoint, "Endpoint event should be 'endpoint'") + + // Verify them in an actual SSE format string + messageEvent := "event: " + eventMessage + "\ndata: test\n\n" + assert.Contains(t, messageEvent, "event: message", "Message event should format correctly") + + endpointEvent := "event: " + eventEndpoint + "\ndata: test\n\n" + assert.Contains(t, endpointEvent, "event: endpoint", "Endpoint event should format correctly") +} + +// TestCollectionKeys checks that collection keys are used correctly +func TestCollectionKeys(t *testing.T) { + // Verify collection key constants + assert.Equal(t, "prompts", keyPrompts, "Prompts key should be 'prompts'") + assert.Equal(t, "resources", keyResources, "Resources key should be 'resources'") + assert.Equal(t, "tools", keyTools, "Tools key should be 'tools'") +} + +// TestRoleTypes checks that role types are used correctly +func TestRoleTypes(t *testing.T) { + // Verify role type constants + assert.Equal(t, "user", string(roleUser), "User role should be 'user'") + assert.Equal(t, "assistant", string(roleAssistant), "Assistant role should be 'assistant'") + + // Test in annotations + annotations := Annotations{ + Audience: []roleType{roleUser, roleAssistant}, + } + data, err := json.Marshal(annotations) + assert.NoError(t, err) + assert.Contains(t, string(data), `"audience":["user","assistant"]`, "Role types should marshal correctly") +} + +// TestMethodNames checks that method names are used correctly +func TestMethodNames(t *testing.T) { + // Verify method name constants + methods := map[string]string{ + "initialize": methodInitialize, + "tools/list": methodToolsList, + "tools/call": methodToolsCall, + "prompts/list": methodPromptsList, + "prompts/get": methodPromptsGet, + "resources/list": methodResourcesList, + "resources/read": methodResourcesRead, + "resources/subscribe": methodResourcesSubscribe, + "ping": methodPing, + "notifications/initialized": methodNotificationsInitialized, + "notifications/cancelled": methodNotificationsCancelled, + } + + for expected, actual := range methods { + assert.Equal(t, expected, actual, "Method name should be "+expected) + } + + // Test in a request + for methodName := range methods { + req := Request{ + JsonRpc: jsonRpcVersion, + ID: int64(1), + Method: methodName, + } + data, err := json.Marshal(req) + assert.NoError(t, err) + assert.Contains(t, string(data), `"method":"`+methodName+`"`, "Method name should be used in requests") + } +} + +// TestEventNames checks that event names are used correctly +func TestEventNames(t *testing.T) { + // Verify event name constants + events := map[string]string{ + "tools/list_changed": eventToolsListChanged, + "prompts/list_changed": eventPromptsListChanged, + "resources/list_changed": eventResourcesListChanged, + } + + for expected, actual := range events { + assert.Equal(t, expected, actual, "Event name should be "+expected) + } + + // Test event names in SSE format + for _, eventName := range events { + sseEvent := "event: " + eventName + "\ndata: test\n\n" + assert.Contains(t, sseEvent, "event: "+eventName, "Event name should format correctly in SSE") + } +}