diff --git a/mcp/config.go b/mcp/config.go index e9a6635ff..4c404a14b 100644 --- a/mcp/config.go +++ b/mcp/config.go @@ -34,7 +34,10 @@ type McpConf struct { // Cors contains allowed CORS origins Cors []string `json:",optional"` - // ToolTimeout is the maximum time allowed for tool execution - ToolTimeout time.Duration `json:",default=30s"` + // SseTimeout is the maximum time allowed for SSE connections + SseTimeout time.Duration `json:",default=24h"` + + // MessageTimeout is the maximum time allowed for request execution + MessageTimeout time.Duration `json:",default=30s"` } } diff --git a/mcp/config_test.go b/mcp/config_test.go index 4156f0a86..5b9d13da3 100644 --- a/mcp/config_test.go +++ b/mcp/config_test.go @@ -27,7 +27,7 @@ mcp: assert.Equal(t, "2024-11-05", c.Mcp.ProtocolVersion, "Default protocol version should be 2024-11-05") assert.Equal(t, "/sse", c.Mcp.SseEndpoint, "Default SSE endpoint should be /sse") assert.Equal(t, "/message", c.Mcp.MessageEndpoint, "Default message endpoint should be /message") - assert.Equal(t, 30*time.Second, c.Mcp.ToolTimeout, "Default tool timeout should be 30s") + assert.Equal(t, 30*time.Second, c.Mcp.MessageTimeout, "Default message timeout should be 30s") } func TestMcpConfCustomValues(t *testing.T) { @@ -43,7 +43,7 @@ func TestMcpConfCustomValues(t *testing.T) { "SseEndpoint": "/custom-sse", "MessageEndpoint": "/custom-message", "Cors": ["http://localhost:3000", "http://example.com"], - "ToolTimeout": "60s" + "MessageTimeout": "60s" } }` @@ -59,5 +59,5 @@ func TestMcpConfCustomValues(t *testing.T) { assert.Equal(t, "/custom-sse", c.Mcp.SseEndpoint, "SSE endpoint should be customizable") assert.Equal(t, "/custom-message", c.Mcp.MessageEndpoint, "Message endpoint should be customizable") assert.Equal(t, []string{"http://localhost:3000", "http://example.com"}, c.Mcp.Cors, "CORS settings should be customizable") - assert.Equal(t, 60*time.Second, c.Mcp.ToolTimeout, "Tool timeout should be customizable") + assert.Equal(t, 60*time.Second, c.Mcp.MessageTimeout, "Tool timeout should be customizable") } diff --git a/mcp/integration_test.go b/mcp/integration_test.go index 5d7016fc5..ed28b72a3 100644 --- a/mcp/integration_test.go +++ b/mcp/integration_test.go @@ -2,6 +2,7 @@ package mcp import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -59,7 +60,7 @@ func TestHTTPHandlerIntegration(t *testing.T) { conf := McpConf{} conf.Mcp.Name = "test-integration" conf.Mcp.Version = "1.0.0-test" - conf.Mcp.ToolTimeout = 1 * time.Second + conf.Mcp.MessageTimeout = 1 * time.Second // Create a mock server directly server := &sseMcpServer{ @@ -75,7 +76,6 @@ func TestHTTPHandlerIntegration(t *testing.T) { Name: "echo", Description: "Echo tool for testing", InputSchema: InputSchema{ - Type: "object", Properties: map[string]any{ "message": map[string]any{ "type": "string", @@ -83,7 +83,7 @@ func TestHTTPHandlerIntegration(t *testing.T) { }, }, }, - Handler: func(params map[string]any) (any, error) { + Handler: func(ctx context.Context, params map[string]any) (any, error) { if msg, ok := params["message"].(string); ok { return fmt.Sprintf("Echo: %s", msg), nil } @@ -181,7 +181,7 @@ func TestHandlerResponseFlow(t *testing.T) { Name: "test.tool", Description: "Test tool", InputSchema: InputSchema{Type: "object"}, - Handler: func(params map[string]any) (any, error) { + Handler: func(ctx context.Context, params map[string]any) (any, error) { return "tool result", nil }, }) @@ -329,7 +329,7 @@ func TestProcessListMethods(t *testing.T) { Params: json.RawMessage(`{"cursor": "", "_meta": {"progressToken": "token1"}}`), } - server.processListTools(client, req) + server.processListTools(context.Background(), client, req) // Read response select { @@ -344,7 +344,7 @@ func TestProcessListMethods(t *testing.T) { req.ID = 2 req.Method = methodPromptsList req.Params = json.RawMessage(`{"cursor": "next"}`) - server.processListPrompts(client, req) + server.processListPrompts(context.Background(), client, req) // Read response select { @@ -358,7 +358,7 @@ func TestProcessListMethods(t *testing.T) { req.ID = 3 req.Method = methodResourcesList req.Params = json.RawMessage(`{"cursor": "next"}`) - server.processListResources(client, req) + server.processListResources(context.Background(), client, req) // Read response select { @@ -393,7 +393,7 @@ func TestErrorResponseHandling(t *testing.T) { } // Mock handleRequest by directly calling error handler - server.sendErrorResponse(client, req.ID, "Method not found", errCodeMethodNotFound) + server.sendErrorResponse(context.Background(), client, req.ID, "Method not found", errCodeMethodNotFound) // Check response select { @@ -412,7 +412,7 @@ func TestErrorResponseHandling(t *testing.T) { } // Call process method directly - server.processToolCall(client, toolReq) + server.processToolCall(context.Background(), client, toolReq) // Check response select { @@ -431,7 +431,7 @@ func TestErrorResponseHandling(t *testing.T) { } // Call process method directly - server.processGetPrompt(client, promptReq) + server.processGetPrompt(context.Background(), client, promptReq) // Check response select { diff --git a/mcp/parser.go b/mcp/parser.go new file mode 100644 index 000000000..45584ce17 --- /dev/null +++ b/mcp/parser.go @@ -0,0 +1,23 @@ +package mcp + +import ( + "fmt" + + "github.com/zeromicro/go-zero/core/mapping" +) + +// ParseArguments parses the arguments and populates the request object +func ParseArguments(args any, req any) error { + switch arguments := args.(type) { + case map[string]string: + m := make(map[string]any, len(arguments)) + for k, v := range arguments { + m[k] = v + } + return mapping.UnmarshalJsonMap(m, req, mapping.WithStringValues()) + case map[string]any: + return mapping.UnmarshalJsonMap(arguments, req) + default: + return fmt.Errorf("unsupported argument type: %T", arguments) + } +} diff --git a/mcp/parser_test.go b/mcp/parser_test.go new file mode 100644 index 000000000..43071a579 --- /dev/null +++ b/mcp/parser_test.go @@ -0,0 +1,139 @@ +package mcp + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestParseArguments_MapStringString tests parsing map[string]string arguments +func TestParseArguments_MapStringString(t *testing.T) { + // Sample request struct to populate + type requestStruct struct { + Name string `json:"name"` + Message string `json:"message"` + Count int `json:"count"` + Enabled bool `json:"enabled"` + } + + // Create test arguments + args := map[string]string{ + "name": "test-name", + "message": "hello world", + "count": "42", + "enabled": "true", + } + + // Create a target object to populate + var req requestStruct + + // Parse the arguments + err := ParseArguments(args, &req) + + // Verify results + assert.NoError(t, err, "Should parse map[string]string without error") + assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed") + assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed") + assert.Equal(t, 42, req.Count, "Count should be correctly parsed to int") + assert.True(t, req.Enabled, "Enabled should be correctly parsed to bool") +} + +// TestParseArguments_MapStringAny tests parsing map[string]any arguments +func TestParseArguments_MapStringAny(t *testing.T) { + // Sample request struct to populate + type requestStruct struct { + Name string `json:"name"` + Message string `json:"message"` + Count int `json:"count"` + Enabled bool `json:"enabled"` + Tags []string `json:"tags"` + Metadata map[string]string `json:"metadata"` + } + + // Create test arguments with mixed types + args := map[string]any{ + "name": "test-name", + "message": "hello world", + "count": 42, // note: this is already an int + "enabled": true, // note: this is already a bool + "tags": []string{"tag1", "tag2"}, + "metadata": map[string]string{ + "key1": "value1", + "key2": "value2", + }, + } + + // Create a target object to populate + var req requestStruct + + // Parse the arguments + err := ParseArguments(args, &req) + + // Verify results + assert.NoError(t, err, "Should parse map[string]any without error") + assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed") + assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed") + assert.Equal(t, 42, req.Count, "Count should be correctly parsed") + assert.True(t, req.Enabled, "Enabled should be correctly parsed") + assert.Equal(t, []string{"tag1", "tag2"}, req.Tags, "Tags should be correctly parsed") + assert.Equal(t, map[string]string{ + "key1": "value1", + "key2": "value2", + }, req.Metadata, "Metadata should be correctly parsed") +} + +// TestParseArguments_UnsupportedType tests parsing with an unsupported type +func TestParseArguments_UnsupportedType(t *testing.T) { + // Sample request struct to populate + type requestStruct struct { + Name string `json:"name"` + Message string `json:"message"` + } + + // Use an unsupported argument type (slice) + args := []string{"not", "a", "map"} + + // Create a target object to populate + var req requestStruct + + // Parse the arguments + err := ParseArguments(args, &req) + + // Verify error is returned with correct message + assert.Error(t, err, "Should return error for unsupported type") + assert.Contains(t, err.Error(), "unsupported argument type", "Error should mention unsupported type") + assert.Contains(t, err.Error(), "[]string", "Error should include the actual type") +} + +// TestParseArguments_EmptyMap tests parsing with empty maps +func TestParseArguments_EmptyMap(t *testing.T) { + // Sample request struct to populate + type requestStruct struct { + Name string `json:"name,optional"` + Message string `json:"message,optional"` + } + + // Test empty map[string]string + t.Run("EmptyMapStringString", func(t *testing.T) { + args := map[string]string{} + var req requestStruct + + err := ParseArguments(args, &req) + + assert.NoError(t, err, "Should parse empty map[string]string without error") + assert.Empty(t, req.Name, "Name should be empty string") + assert.Empty(t, req.Message, "Message should be empty string") + }) + + // Test empty map[string]any + t.Run("EmptyMapStringAny", func(t *testing.T) { + args := map[string]any{} + var req requestStruct + + err := ParseArguments(args, &req) + + assert.NoError(t, err, "Should parse empty map[string]any without error") + assert.Empty(t, req.Name, "Name should be empty string") + assert.Empty(t, req.Message, "Message should be empty string") + }) +} diff --git a/mcp/readme.md b/mcp/readme.md index 98094aca2..c4e9fd09e 100644 --- a/mcp/readme.md +++ b/mcp/readme.md @@ -1,7 +1,7 @@ -# Model Context Protocol (MCP) SDK Implementation +# Model Context Protocol (MCP) Implementation ## Overview -This package implements a Model Context Protocol (MCP) server in Go that facilitates real-time communication between AI models and clients using Server-Sent Events (SSE). The implementation provides a framework for building AI-assisted applications with bidirectional communication capabilities. +This package implements the Model Context Protocol (MCP) server specification in Go, providing a framework for real-time communication between AI models and clients using Server-Sent Events (SSE). The implementation follows the standardized protocol for building AI-assisted applications with bidirectional communication capabilities. ## Core Components @@ -54,9 +54,817 @@ This package implements a Model Context Protocol (MCP) server in Go that facilit ## Usage -To create and use an MCP server, see the examples directory for practical implementation examples including: -- Tool registration and execution -- Static and dynamic prompt creation -- Resource handling with proper URI identification -- Embedded resources in prompt responses -- Client connection management \ No newline at end of file +### Setting Up an MCP Server + +To create and start an MCP server: + +```go +package main + +import ( + "github.com/zeromicro/go-zero/core/conf" + "github.com/zeromicro/go-zero/core/logx" + "github.com/zeromicro/go-zero/mcp" +) + +func main() { + // Load configuration from YAML file + var c mcp.McpConf + conf.MustLoad("config.yaml", &c) + + // Optional: Disable stats logging + logx.DisableStat() + + // Create MCP server + server := mcp.NewMcpServer(c) + + // Register tools, prompts, and resources (examples below) + + // Start the server and ensure it's stopped on exit + defer server.Stop() + server.Start() +} +``` + +Sample configuration file (config.yaml): + +```yaml +name: mcp-server +host: localhost +port: 8080 +mcp: + name: my-mcp-server + messageTimeout: 30s # Timeout for tool calls + cors: + - http://localhost:3000 # Optional CORS configuration +``` + +### Registering Tools + +Tools allow AI models to execute custom code through the MCP protocol. + +#### Basic Tool Example: + +```go +// Register a simple echo tool +echoTool := mcp.Tool{ + Name: "echo", + Description: "Echoes back the message provided by the user", + InputSchema: mcp.InputSchema{ + Properties: map[string]any{ + "message": map[string]any{ + "type": "string", + "description": "The message to echo back", + }, + "prefix": map[string]any{ + "type": "string", + "description": "Optional prefix to add to the echoed message", + "default": "Echo: ", + }, + }, + Required: []string{"message"}, + }, + Handler: func(ctx context.Context, params map[string]any) (any, error) { + var req struct { + Message string `json:"message"` + Prefix string `json:"prefix,optional"` + } + + if err := mcp.ParseArguments(params, &req); err != nil { + return nil, fmt.Errorf("failed to parse params: %w", err) + } + + prefix := "Echo: " + if len(req.Prefix) > 0 { + prefix = req.Prefix + } + + return prefix + req.Message, nil + }, +} + +server.RegisterTool(echoTool) +``` + +#### Tool with Different Response Types: + +```go +// Tool returning JSON data +dataTool := mcp.Tool{ + Name: "data.generate", + Description: "Generates sample data in various formats", + InputSchema: mcp.InputSchema{ + Properties: map[string]any{ + "format": map[string]any{ + "type": "string", + "description": "Format of data (json, text)", + "enum": []string{"json", "text"}, + }, + }, + }, + Handler: func(ctx context.Context, params map[string]any) (any, error) { + var req struct { + Format string `json:"format"` + } + + if err := mcp.ParseArguments(params, &req); err != nil { + return nil, fmt.Errorf("failed to parse params: %w", err) + } + + if req.Format == "json" { + // Return structured data + return map[string]any{ + "items": []map[string]any{ + {"id": 1, "name": "Item 1"}, + {"id": 2, "name": "Item 2"}, + }, + "count": 2, + }, nil + } + + // Default to text + return "Sample text data", nil + }, +} + +server.RegisterTool(dataTool) +``` + +#### Image Generation Tool Example: + +```go +// Tool returning image content +imageTool := mcp.Tool{ + Name: "image.generate", + Description: "Generates a simple image", + InputSchema: mcp.InputSchema{ + Properties: map[string]any{ + "type": map[string]any{ + "type": "string", + "description": "Type of image to generate", + "default": "placeholder", + }, + }, + }, + Handler: func(ctx context.Context, params map[string]any) (any, error) { + // Return image content directly + return mcp.ImageContent{ + Data: "base64EncodedImageData...", // Base64 encoded image data + MimeType: "image/png", + }, nil + }, +} + +server.RegisterTool(imageTool) +``` + +#### Using ToolResult for Custom Outputs: + +```go +// Tool that returns a custom ToolResult type +customResultTool := mcp.Tool{ + Name: "custom.result", + Description: "Returns a custom formatted result", + InputSchema: mcp.InputSchema{ + Properties: map[string]any{ + "resultType": map[string]any{ + "type": "string", + "enum": []string{"text", "image"}, + }, + }, + }, + Handler: func(ctx context.Context, params map[string]any) (any, error) { + var req struct { + ResultType string `json:"resultType"` + } + + if err := mcp.ParseArguments(params, &req); err != nil { + return nil, fmt.Errorf("failed to parse params: %w", err) + } + + if req.ResultType == "image" { + return mcp.ToolResult{ + Type: mcp.ContentTypeImage, + Content: map[string]any{ + "data": "base64EncodedImageData...", + "mimeType": "image/jpeg", + }, + }, nil + } + + // Default to text + return mcp.ToolResult{ + Type: mcp.ContentTypeText, + Content: "This is a text result from ToolResult", + }, nil + }, +} + +server.RegisterTool(customResultTool) +``` + +### Registering Prompts + +Prompts are reusable conversation templates for AI models. + +#### Static Prompt Example: + +```go +// Register a simple static prompt with placeholders +server.RegisterPrompt(mcp.Prompt{ + Name: "hello", + Description: "A simple hello prompt", + Arguments: []mcp.PromptArgument{ + { + Name: "name", + Description: "The name to greet", + Required: false, + }, + }, + Content: "Say hello to {{name}} and introduce yourself as an AI assistant.", +}) +``` + +#### Dynamic Prompt with Handler Function: + +```go +// Register a prompt with a dynamic handler function +server.RegisterPrompt(mcp.Prompt{ + Name: "dynamic-prompt", + Description: "A prompt that uses a handler to generate dynamic content", + Arguments: []mcp.PromptArgument{ + { + Name: "username", + Description: "User's name for personalized greeting", + Required: true, + }, + { + Name: "topic", + Description: "Topic of expertise", + Required: true, + }, + }, + Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) { + var req struct { + Username string `json:"username"` + Topic string `json:"topic"` + } + + if err := mcp.ParseArguments(args, &req); err != nil { + return nil, fmt.Errorf("failed to parse args: %w", err) + } + + // Create a user message + userMessage := mcp.PromptMessage{ + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Text: fmt.Sprintf("Hello, I'm %s and I'd like to learn about %s.", req.Username, req.Topic), + }, + } + + // Create an assistant response with current time + currentTime := time.Now().Format(time.RFC1123) + assistantMessage := mcp.PromptMessage{ + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Text: fmt.Sprintf("Hello %s! I'm an AI assistant and I'll help you learn about %s. The current time is %s.", + req.Username, req.Topic, currentTime), + }, + } + + // Return both messages as a conversation + return []mcp.PromptMessage{userMessage, assistantMessage}, nil + }, +}) +``` + +#### Multi-Message Prompt with Code Examples: + +```go +// Register a prompt that provides code examples in different programming languages +server.RegisterPrompt(mcp.Prompt{ + Name: "code-example", + Description: "Provides code examples in different programming languages", + Arguments: []mcp.PromptArgument{ + { + Name: "language", + Description: "Programming language for the example", + Required: true, + }, + { + Name: "complexity", + Description: "Complexity level (simple, medium, advanced)", + }, + }, + Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) { + var req struct { + Language string `json:"language"` + Complexity string `json:"complexity,optional"` + } + + if err := mcp.ParseArguments(args, &req); err != nil { + return nil, fmt.Errorf("failed to parse args: %w", err) + } + + // Validate language + supportedLanguages := map[string]bool{"go": true, "python": true, "javascript": true, "rust": true} + if !supportedLanguages[req.Language] { + return nil, fmt.Errorf("unsupported language: %s", req.Language) + } + + // Generate code example based on language and complexity + var codeExample string + + switch req.Language { + case "go": + if req.Complexity == "simple" { + codeExample = ` +package main + +import "fmt" + +func main() { + fmt.Println("Hello, World!") +}` + } else { + codeExample = ` +package main + +import ( + "fmt" + "time" +) + +func main() { + now := time.Now() + fmt.Printf("Hello, World! Current time is %s\n", now.Format(time.RFC3339)) +}` + } + case "python": + // Python example code + if req.Complexity == "simple" { + codeExample = ` +def greet(name): + return f"Hello, {name}!" + +print(greet("World"))` + } else { + codeExample = ` +import datetime + +def greet(name, include_time=False): + message = f"Hello, {name}!" + if include_time: + message += f" Current time is {datetime.datetime.now().isoformat()}" + return message + +print(greet("World", include_time=True))` + } + } + + // Create messages array according to MCP spec + messages := []mcp.PromptMessage{ + { + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Text: fmt.Sprintf("You are a helpful coding assistant specialized in %s programming.", req.Language), + }, + }, + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Text: fmt.Sprintf("Show me a %s example of a Hello World program in %s.", req.Complexity, req.Language), + }, + }, + { + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Text: fmt.Sprintf("Here's a %s example in %s:\n\n```%s%s\n```\n\nHow can I help you implement this?", + req.Complexity, req.Language, req.Language, codeExample), + }, + }, + } + + return messages, nil + }, +}) +``` + +### Registering Resources + +Resources provide access to external content such as files or generated data. + +#### Basic Resource Example: + +```go +// Register a static resource +server.RegisterResource(mcp.Resource{ + Name: "example-document", + URI: "file:///example/document.txt", + Description: "An example document", + MimeType: "text/plain", + Handler: func(ctx context.Context) (mcp.ResourceContent, error) { + return mcp.ResourceContent{ + URI: "file:///example/document.txt", + MimeType: "text/plain", + Text: "This is an example document content.", + }, nil + }, +}) +``` + +#### Dynamic Resource with Code Example: + +```go +// Register a Go code resource with dynamic handler +server.RegisterResource(mcp.Resource{ + Name: "go-example", + URI: "file:///project/src/main.go", + Description: "A simple Go example with multiple files", + MimeType: "text/x-go", + Handler: func(ctx context.Context) (mcp.ResourceContent, error) { + // Return ResourceContent with all required fields + return mcp.ResourceContent{ + URI: "file:///project/src/main.go", + MimeType: "text/x-go", + Text: "package main\n\nimport (\n\t\"fmt\"\n\t\"./greeting\"\n)\n\nfunc main() {\n\tfmt.Println(greeting.Hello(\"world\"))\n}", + }, nil + }, +}) + +// Register a companion file for the above example +server.RegisterResource(mcp.Resource{ + Name: "go-greeting", + URI: "file:///project/src/greeting/greeting.go", + Description: "A greeting package for the Go example", + MimeType: "text/x-go", + Handler: func(ctx context.Context) (mcp.ResourceContent, error) { + return mcp.ResourceContent{ + URI: "file:///project/src/greeting/greeting.go", + MimeType: "text/x-go", + Text: "package greeting\n\nfunc Hello(name string) string {\n\treturn \"Hello, \" + name + \"!\"\n}", + }, nil + }, +}) +``` + +#### Binary Resource Example: + +```go +// Register a binary resource (like an image) +server.RegisterResource(mcp.Resource{ + Name: "example-image", + URI: "file:///example/image.png", + Description: "An example image", + MimeType: "image/png", + Handler: func(ctx context.Context) (mcp.ResourceContent, error) { + // Read image from file or generate it + imageData := "base64EncodedImageData..." // Base64 encoded image data + + return mcp.ResourceContent{ + URI: "file:///example/image.png", + MimeType: "image/png", + Blob: imageData, // For binary data + }, nil + }, +}) +``` + +### Using Resources in Prompts + +You can embed resources in prompt responses to create rich interactions with proper MCP-compliant structure: + +```go +// Register a prompt that embeds a resource +server.RegisterPrompt(mcp.Prompt{ + Name: "resource-example", + Description: "A prompt that embeds a resource", + Arguments: []mcp.PromptArgument{ + { + Name: "file_type", + Description: "Type of file to show (rust or go)", + Required: true, + }, + }, + Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) { + var req struct { + FileType string `json:"file_type"` + } + + if err := mcp.ParseArguments(args, &req); err != nil { + return nil, fmt.Errorf("failed to parse args: %w", err) + } + + var resourceURI, mimeType, fileContent string + if req.FileType == "rust" { + resourceURI = "file:///project/src/main.rs" + mimeType = "text/x-rust" + fileContent = "fn main() {\n println!(\"Hello world!\");\n}" + } else { + resourceURI = "file:///project/src/main.go" + mimeType = "text/x-go" + fileContent = "package main\n\nimport \"fmt\"\n\nfunc main() {\n\tfmt.Println(\"Hello, world!\")\n}" + } + + // Create message with embedded resource using proper MCP format + return []mcp.PromptMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Text: fmt.Sprintf("Can you explain this %s code?", req.FileType), + }, + }, + { + Role: mcp.RoleAssistant, + Content: mcp.EmbeddedResource{ + Type: mcp.ContentTypeResource, + Resource: struct { + URI string `json:"uri"` + MimeType string `json:"mimeType"` + Text string `json:"text,omitempty"` + Blob string `json:"blob,omitempty"` + }{ + URI: resourceURI, + MimeType: mimeType, + Text: fileContent, + }, + }, + }, + { + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Text: fmt.Sprintf("Above is a simple Hello World example in %s. Let me explain how it works.", req.FileType), + }, + }, + }, nil + }, +}) +``` + +### Multiple File Resources Example + +```go +// Register a prompt that demonstrates embedding multiple resource files +server.RegisterPrompt(mcp.Prompt{ + Name: "go-code-example", + Description: "A prompt that correctly embeds multiple resource files", + Arguments: []mcp.PromptArgument{ + { + Name: "format", + Description: "How to format the code display", + }, + }, + Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) { + var req struct { + Format string `json:"format,optional"` + } + + if err := mcp.ParseArguments(args, &req); err != nil { + return nil, fmt.Errorf("failed to parse args: %w", err) + } + + // Get the Go code for multiple files + var mainGoText string = "package main\n\nimport (\n\t\"fmt\"\n\t\"./greeting\"\n)\n\nfunc main() {\n\tfmt.Println(greeting.Hello(\"world\"))\n}" + var greetingGoText string = "package greeting\n\nfunc Hello(name string) string {\n\treturn \"Hello, \" + name + \"!\"\n}" + + // Create message with properly formatted embedded resource per MCP spec + messages := []mcp.PromptMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Text: "Show me a simple Go example with proper imports.", + }, + }, + { + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Text: "Here's a simple Go example project:", + }, + }, + { + Role: mcp.RoleAssistant, + Content: mcp.EmbeddedResource{ + Type: mcp.ContentTypeResource, + Resource: struct { + URI string `json:"uri"` + MimeType string `json:"mimeType"` + Text string `json:"text,omitempty"` + Blob string `json:"blob,omitempty"` + }{ + URI: "file:///project/src/main.go", + MimeType: "text/x-go", + Text: mainGoText, + }, + }, + }, + } + + // Add explanation and additional file if requested + if req.Format == "with_explanation" { + messages = append(messages, mcp.PromptMessage{ + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Text: "This example demonstrates a simple Go application with modular structure. The main.go file imports from a local 'greeting' package that provides the Hello function.", + }, + }) + + // Also show the greeting.go file with correct resource format + messages = append(messages, mcp.PromptMessage{ + Role: mcp.RoleAssistant, + Content: mcp.EmbeddedResource{ + Type: mcp.ContentTypeResource, + Resource: struct { + URI string `json:"uri"` + MimeType string `json:"mimeType"` + Text string `json:"text,omitempty"` + Blob string `json:"blob,omitempty"` + }{ + URI: "file:///project/src/greeting/greeting.go", + MimeType: "text/x-go", + Text: greetingGoText, + }, + }, + }) + } + + return messages, nil + }, +}) +``` + +### Complete Application Example + +Here's a complete example demonstrating all the components: + +```go +package main + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/zeromicro/go-zero/core/conf" + "github.com/zeromicro/go-zero/core/logx" + "github.com/zeromicro/go-zero/mcp" +) + +func main() { + // Load configuration + var c mcp.McpConf + if err := conf.Load("config.yaml", &c); err != nil { + log.Fatalf("Failed to load config: %v", err) + } + + // Set up logging + logx.DisableStat() + + // Create MCP server + server := mcp.NewMcpServer(c) + defer server.Stop() + + // Register a simple echo tool + echoTool := mcp.Tool{ + Name: "echo", + Description: "Echoes back the message provided by the user", + InputSchema: mcp.InputSchema{ + Properties: map[string]any{ + "message": map[string]any{ + "type": "string", + "description": "The message to echo back", + }, + "prefix": map[string]any{ + "type": "string", + "description": "Optional prefix to add to the echoed message", + "default": "Echo: ", + }, + }, + Required: []string{"message"}, + }, + Handler: func(ctx context.Context, params map[string]any) (any, error) { + var req struct { + Message string `json:"message"` + Prefix string `json:"prefix,optional"` + } + + if err := mcp.ParseArguments(params, &req); err != nil { + return nil, fmt.Errorf("failed to parse args: %w", err) + } + + prefix := "Echo: " + if len(req.Prefix) > 0 { + prefix = req.Prefix + } + + return prefix + req.Message, nil + }, + } + server.RegisterTool(echoTool) + + // Register a static prompt + server.RegisterPrompt(mcp.Prompt{ + Name: "greeting", + Description: "A simple greeting prompt", + Arguments: []mcp.PromptArgument{ + { + Name: "name", + Description: "The name to greet", + Required: true, + }, + }, + Content: "Hello {{name}}! How can I assist you today?", + }) + + // Register a dynamic prompt + server.RegisterPrompt(mcp.Prompt{ + Name: "dynamic-prompt", + Description: "A prompt that uses a handler to generate dynamic content", + Arguments: []mcp.PromptArgument{ + { + Name: "username", + Description: "User's name for personalized greeting", + Required: true, + }, + { + Name: "topic", + Description: "Topic of expertise", + Required: true, + }, + }, + Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) { + var req struct { + Username string `json:"username"` + Topic string `json:"topic"` + } + + if err := mcp.ParseArguments(args, &req); err != nil { + return nil, fmt.Errorf("failed to parse args: %w", err) + } + + // Create messages with current time + currentTime := time.Now().Format(time.RFC1123) + return []mcp.PromptMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Text: fmt.Sprintf("Hello, I'm %s and I'd like to learn about %s.", req.Username, req.Topic), + }, + }, + { + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Text: fmt.Sprintf("Hello %s! I'm an AI assistant and I'll help you learn about %s. The current time is %s.", + req.Username, req.Topic, currentTime), + }, + }, + }, nil + }, + }) + + // Register a resource + server.RegisterResource(mcp.Resource{ + Name: "example-doc", + URI: "file:///example/doc.txt", + Description: "An example document", + MimeType: "text/plain", + Handler: func(ctx context.Context) (mcp.ResourceContent, error) { + return mcp.ResourceContent{ + URI: "file:///example/doc.txt", + MimeType: "text/plain", + Text: "This is the content of the example document.", + }, nil + }, + }) + + // Start the server + fmt.Printf("Starting MCP server on %s:%d\n", c.Host, c.Port) + server.Start() +} +``` + +## Error Handling + +The MCP implementation provides comprehensive error handling: + +- Tool execution errors are properly reported back to clients +- Missing or invalid parameters are detected and reported with appropriate error codes +- Resource and prompt lookup failures are handled gracefully +- Timeout handling for long-running tool executions using context +- Panic recovery to prevent server crashes + +## Advanced Features + +- **Annotations**: Add audience and priority metadata to content +- **Content Types**: Support for text, images, audio, and other content formats +- **Embedded Resources**: Include file resources directly in prompt responses +- **Context Awareness**: All handlers receive context.Context for timeout and cancellation support +- **Progress Tokens**: Support for tracking progress of long-running operations +- **Customizable Timeouts**: Configure execution timeouts for tools and operations + +## Performance Considerations + +- Tool execution runs with configurable timeouts to prevent blocking +- Efficient client tracking and cleanup to prevent resource leaks +- Proper concurrency handling with mutex protection for shared resources +- Buffered message channels to prevent blocking on client message delivery diff --git a/mcp/server.go b/mcp/server.go index f2a2f8889..64534b830 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -42,14 +42,14 @@ func NewMcpServer(c McpConf) McpServer { Method: http.MethodGet, Path: s.conf.Mcp.SseEndpoint, Handler: s.handleSSE, - }, rest.WithSSE()) + }, rest.WithSSE(), rest.WithTimeout(c.Mcp.SseTimeout)) // JSON-RPC message endpoint for regular requests s.server.AddRoute(rest.Route{ Method: http.MethodPost, Path: s.conf.Mcp.MessageEndpoint, Handler: s.handleRequest, - }) + }, rest.WithTimeout(c.Mcp.MessageTimeout)) return s } @@ -182,21 +182,23 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) { // Always allow initialize and notifications/initialized regardless of client state if req.Method == methodInitialize { logx.Infof("Processing initialize request with ID: %d", req.ID) - s.processInitialize(client, req) + s.processInitialize(r.Context(), client, req) logx.Infof("Sent initialize response for ID: %d, waiting for notifications/initialized", req.ID) return } else if req.Method == methodNotificationsInitialized { // Handle initialized notification logx.Info("Received notifications/initialized notification") if !isNotification { - s.sendErrorResponse(client, req.ID, "Method should be used as a notification", errCodeInvalidRequest) + s.sendErrorResponse(r.Context(), client, req.ID, + "Method should be used as a notification", errCodeInvalidRequest) return } s.processNotificationInitialized(client) return } else if !client.initialized && req.Method != methodNotificationsCancelled { // Block most requests until client is initialized (except for cancellations) - s.sendErrorResponse(client, req.ID, "Client not fully initialized, waiting for notifications/initialized", + s.sendErrorResponse(r.Context(), client, req.ID, + "Client not fully initialized, waiting for notifications/initialized", errCodeClientNotInitialized) return } @@ -205,41 +207,41 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) { switch req.Method { case methodToolsCall: logx.Infof("Received tools call request with ID: %d", req.ID) - s.processToolCall(client, req) + s.processToolCall(r.Context(), client, req) logx.Infof("Sent tools call response for ID: %d", req.ID) case methodToolsList: logx.Infof("Processing tools/list request with ID: %d", req.ID) - s.processListTools(client, req) + s.processListTools(r.Context(), client, req) logx.Infof("Sent tools/list response for ID: %d", req.ID) case methodPromptsList: logx.Infof("Processing prompts/list request with ID: %d", req.ID) - s.processListPrompts(client, req) + s.processListPrompts(r.Context(), client, req) logx.Infof("Sent prompts/list response for ID: %d", req.ID) case methodPromptsGet: logx.Infof("Processing prompts/get request with ID: %d", req.ID) - s.processGetPrompt(client, req) + s.processGetPrompt(r.Context(), client, req) logx.Infof("Sent prompts/get response for ID: %d", req.ID) case methodResourcesList: logx.Infof("Processing resources/list request with ID: %d", req.ID) - s.processListResources(client, req) + s.processListResources(r.Context(), client, req) logx.Infof("Sent resources/list response for ID: %d", req.ID) case methodResourcesRead: logx.Infof("Processing resources/read request with ID: %d", req.ID) - s.processResourcesRead(client, req) + s.processResourcesRead(r.Context(), client, req) logx.Infof("Sent resources/read response for ID: %d", req.ID) case methodResourcesSubscribe: logx.Infof("Processing resources/subscribe request with ID: %d", req.ID) - s.processResourceSubscribe(client, req) + s.processResourceSubscribe(r.Context(), client, req) logx.Infof("Sent resources/subscribe response for ID: %d", req.ID) case methodPing: logx.Infof("Processing ping request with ID: %d", req.ID) - s.processPing(client, req) + s.processPing(r.Context(), client, req) case methodNotificationsCancelled: - logx.Infof("Received notifications/cancelled notification: %v", req.Params) - s.processNotificationCancelled(client, req) + logx.Infof("Received notifications/cancelled notification: %d", req.ID) + s.processNotificationCancelled(r.Context(), client, req) default: - logx.Infof("Unknown method: %s", req.Method) - s.sendErrorResponse(client, req.ID, "Method not found", errCodeMethodNotFound) + logx.Infof("Unknown method: %s from client: %d", req.Method, req.ID) + s.sendErrorResponse(r.Context(), client, req.ID, "Method not found", errCodeMethodNotFound) } } @@ -321,7 +323,7 @@ func (s *sseMcpServer) handleSSE(w http.ResponseWriter, r *http.Request) { } flusher.Flush() case <-r.Context().Done(): - // Client disconnected or request was canceled + // Client disconnected or request was canceled or timed out logx.Infof("Client %s disconnected: context done", sessionId) return } @@ -329,7 +331,7 @@ func (s *sseMcpServer) handleSSE(w http.ResponseWriter, r *http.Request) { } // processInitialize processes the initialize request -func (s *sseMcpServer) processInitialize(client *mcpClient, req Request) { +func (s *sseMcpServer) processInitialize(ctx context.Context, client *mcpClient, req Request) { // Create a proper JSON-RPC response that preserves the client's request ID result := initializationResponse{ ProtocolVersion: s.conf.Mcp.ProtocolVersion, @@ -362,11 +364,11 @@ func (s *sseMcpServer) processInitialize(client *mcpClient, req Request) { client.initialized = true // Send response with client's original request ID - s.sendResponse(client, req.ID, result) + s.sendResponse(ctx, client, req.ID, result) } // processListTools processes the tools/list request -func (s *sseMcpServer) processListTools(client *mcpClient, req Request) { +func (s *sseMcpServer) processListTools(ctx context.Context, client *mcpClient, req Request) { // Extract pagination params if any var nextCursor string var progressToken any @@ -390,6 +392,9 @@ func (s *sseMcpServer) processListTools(client *mcpClient, req Request) { var toolsList []Tool s.toolsLock.Lock() for _, tool := range s.tools { + if len(tool.InputSchema.Type) == 0 { + tool.InputSchema.Type = ContentTypeObject + } toolsList = append(toolsList, tool) } s.toolsLock.Unlock() @@ -405,15 +410,15 @@ func (s *sseMcpServer) processListTools(client *mcpClient, req Request) { // Add meta information if progress token was provided if progressToken != nil { result.Result.Meta = map[string]any{ - "progressToken": progressToken, + progressTokenKey: progressToken, } } - s.sendResponse(client, req.ID, result) + s.sendResponse(ctx, client, req.ID, result) } // processListPrompts processes the prompts/list request -func (s *sseMcpServer) processListPrompts(client *mcpClient, req Request) { +func (s *sseMcpServer) processListPrompts(ctx context.Context, client *mcpClient, req Request) { // Extract pagination params if any var nextCursor string if req.Params != nil { @@ -447,11 +452,11 @@ func (s *sseMcpServer) processListPrompts(client *mcpClient, req Request) { NextCursor: nextCursor, } - s.sendResponse(client, req.ID, result) + s.sendResponse(ctx, client, req.ID, result) } // processListResources processes the resources/list request -func (s *sseMcpServer) processListResources(client *mcpClient, req Request) { +func (s *sseMcpServer) processListResources(ctx context.Context, client *mcpClient, req Request) { // Extract pagination params if any var nextCursor string var progressToken any @@ -493,15 +498,15 @@ func (s *sseMcpServer) processListResources(client *mcpClient, req Request) { // Add meta information if progress token was provided if progressToken != nil { result.Result.Meta = map[string]any{ - "progressToken": progressToken, + progressTokenKey: progressToken, } } - s.sendResponse(client, req.ID, result) + s.sendResponse(ctx, client, req.ID, result) } // processGetPrompt processes the prompts/get request -func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) { +func (s *sseMcpServer) processGetPrompt(ctx context.Context, client *mcpClient, req Request) { type GetPromptParams struct { Name string `json:"name"` Arguments map[string]string `json:"arguments,omitempty"` @@ -509,7 +514,7 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) { var params GetPromptParams if err := json.Unmarshal(req.Params, ¶ms); err != nil { - s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams) + s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams) return } @@ -519,7 +524,7 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) { s.promptsLock.Unlock() if !exists { message := fmt.Sprintf("Prompt '%s' not found", params.Name) - s.sendErrorResponse(client, req.ID, message, errCodeInvalidParams) + s.sendErrorResponse(ctx, client, req.ID, message, errCodeInvalidParams) return } @@ -529,12 +534,15 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) { missingArgs := validatePromptArguments(prompt, params.Arguments) if len(missingArgs) > 0 { message := fmt.Sprintf("Missing required arguments: %s", strings.Join(missingArgs, ", ")) - s.sendErrorResponse(client, req.ID, message, errCodeInvalidParams) + s.sendErrorResponse(ctx, client, req.ID, message, errCodeInvalidParams) return } - // Apply default values for missing optional arguments - args := applyDefaultArguments(prompt, params.Arguments) + // Ensure arguments are initialized to an empty map if nil + if params.Arguments == nil { + params.Arguments = make(map[string]string) + } + args := params.Arguments // Generate messages using handler or static content var messages []PromptMessage @@ -542,17 +550,17 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) { if prompt.Handler != nil { // Use dynamic handler to generate messages - logx.Info("Using prompt handler to generate content") - messages, err = prompt.Handler(args) + messages, err = prompt.Handler(ctx, args) if err != nil { logx.Errorf("Error from prompt handler: %v", err) - s.sendErrorResponse(client, req.ID, fmt.Sprintf("Error generating prompt content: %v", err), errCodeInternalError) + s.sendErrorResponse(ctx, client, req.ID, + fmt.Sprintf("Error generating prompt content: %v", err), errCodeInternalError) return } } else { // No handler, generate messages from static content var messageText string - if prompt.Content != "" { + if len(prompt.Content) > 0 { messageText = prompt.Content // Apply argument substitutions to static content @@ -560,21 +568,13 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) { placeholder := fmt.Sprintf("{{%s}}", key) messageText = strings.Replace(messageText, placeholder, value, -1) } - } else { - // No content, use a default fallback - topic := "this topic" - if t, ok := args["topic"]; ok && t != "" { - topic = t - } - messageText = fmt.Sprintf("Tell me about %s", topic) } // Create a single user message with the content messages = []PromptMessage{ { - Role: roleUser, + Role: RoleUser, Content: TextContent{ - Type: contentTypeText, Text: messageText, }, }, @@ -587,49 +587,14 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) { Messages []PromptMessage `json:"messages"` }{ Description: prompt.Description, - Messages: messages, + Messages: toTypedPromptMessages(messages), } - s.sendResponse(client, req.ID, result) -} - -// validatePromptArguments checks if all required arguments are provided -// Returns a list of missing required arguments -func validatePromptArguments(prompt Prompt, providedArgs map[string]string) []string { - var missingArgs []string - - for _, arg := range prompt.Arguments { - if arg.Required { - if value, exists := providedArgs[arg.Name]; !exists || len(value) == 0 { - missingArgs = append(missingArgs, arg.Name) - } - } - } - - return missingArgs -} - -// applyDefaultArguments adds default values for missing optional arguments -func applyDefaultArguments(prompt Prompt, providedArgs map[string]string) map[string]string { - result := make(map[string]string) - - // Copy all provided arguments - for k, v := range providedArgs { - result[k] = v - } - - // Add defaults for missing arguments - for _, arg := range prompt.Arguments { - if _, exists := result[arg.Name]; !exists && arg.Default != "" { - result[arg.Name] = arg.Default - } - } - - return result + s.sendResponse(ctx, client, req.ID, result) } // processToolCall processes the tools/call request -func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) { +func (s *sseMcpServer) processToolCall(ctx context.Context, client *mcpClient, req Request) { var toolCallParams struct { Name string `json:"name"` Arguments map[string]any `json:"arguments,omitempty"` @@ -642,7 +607,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) { // If it's a RawMessage (JSON), unmarshal it if err := json.Unmarshal(req.Params, &toolCallParams); err != nil { logx.Errorf("Failed to unmarshal tool call params: %v", err) - s.sendErrorResponse(client, req.ID, "Invalid tool call parameters", errCodeInvalidParams) + s.sendErrorResponse(ctx, client, req.ID, "Invalid tool call parameters", errCodeInvalidParams) return } @@ -654,15 +619,11 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) { tool, exists := s.tools[toolCallParams.Name] s.toolsLock.Unlock() if !exists { - s.sendErrorResponse(client, req.ID, fmt.Sprintf("Tool '%s' not found", + s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Tool '%s' not found", toolCallParams.Name), errCodeInvalidParams) return } - // Create a context with the configured timeout - ctx, cancel := context.WithTimeout(context.Background(), s.conf.Mcp.ToolTimeout) - defer cancel() - // Log parameters before execution logx.Infof("Executing tool '%s' with arguments: %#v", toolCallParams.Name, toolCallParams.Arguments) @@ -671,6 +632,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) { var err error // Create a channel to receive the result + // make sure to have 1 size buffer to avoid channel leak if timeout resultCh := make(chan struct { result any err error @@ -678,7 +640,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) { // Execute the tool handler in a goroutine go func() { - toolResult, toolErr := tool.Handler(toolCallParams.Arguments) + toolResult, toolErr := tool.Handler(ctx, toolCallParams.Arguments) resultCh <- struct { result any err error @@ -694,9 +656,9 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) { result = res.result err = res.err case <-ctx.Done(): - // Handle timeout - logx.Errorf("Tool execution timed out after %v: %s", s.conf.Mcp.ToolTimeout, toolCallParams.Name) - s.sendErrorResponse(client, req.ID, "Tool execution timed out", errCodeTimeout) + // Handle request timeout + logx.Errorf("Tool execution timed out after %v: %s", s.conf.Mcp.MessageTimeout, toolCallParams.Name) + s.sendErrorResponse(ctx, client, req.ID, "Tool execution timed out", errCodeTimeout) return } @@ -710,7 +672,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) { // Add meta information if progress token was provided if progressToken != nil { callToolResult.Result.Meta = map[string]any{ - "progressToken": progressToken, + progressTokenKey: progressToken, } } @@ -722,12 +684,11 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) { callToolResult.Content = []any{ TextContent{ - Type: contentTypeText, Text: fmt.Sprintf("Error: %v", err), }, } callToolResult.IsError = true - s.sendResponse(client, req.ID, callToolResult) + s.sendResponse(ctx, client, req.ID, callToolResult) return } @@ -736,10 +697,9 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) { case string: // Simple string becomes text content callToolResult.Content = append(callToolResult.Content, TextContent{ - Type: contentTypeText, Text: v, Annotations: &Annotations{ - Audience: []roleType{roleUser, roleAssistant}, + Audience: []RoleType{RoleUser, RoleAssistant}, }, }) case map[string]any: @@ -749,69 +709,63 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) { jsonStr = []byte(err.Error()) } callToolResult.Content = append(callToolResult.Content, TextContent{ - Type: contentTypeText, Text: string(jsonStr), Annotations: &Annotations{ - Audience: []roleType{roleUser, roleAssistant}, + Audience: []RoleType{RoleUser, RoleAssistant}, }, }) case TextContent: - // Direct TextContent object callToolResult.Content = append(callToolResult.Content, v) case ImageContent: - // Direct ImageContent object callToolResult.Content = append(callToolResult.Content, v) case []any: - // Array of content items callToolResult.Content = v case ToolResult: // Handle legacy ToolResult type switch v.Type { - case contentTypeText: + case ContentTypeText: callToolResult.Content = append(callToolResult.Content, TextContent{ - Type: contentTypeText, Text: fmt.Sprintf("%v", v.Content), Annotations: &Annotations{ - Audience: []roleType{roleUser, roleAssistant}, + Audience: []RoleType{RoleUser, RoleAssistant}, }, }) - case contentTypeImage: + case ContentTypeImage: if imgData, ok := v.Content.(map[string]any); ok { callToolResult.Content = append(callToolResult.Content, ImageContent{ - Type: contentTypeImage, Data: fmt.Sprintf("%v", imgData["data"]), MimeType: fmt.Sprintf("%v", imgData["mimeType"]), }) } default: callToolResult.Content = append(callToolResult.Content, TextContent{ - Type: contentTypeText, Text: fmt.Sprintf("%v", v.Content), Annotations: &Annotations{ - Audience: []roleType{roleUser, roleAssistant}, + Audience: []RoleType{RoleUser, RoleAssistant}, }, }) } default: // For any other type, convert to string callToolResult.Content = append(callToolResult.Content, TextContent{ - Type: contentTypeText, Text: fmt.Sprintf("%v", v), Annotations: &Annotations{ - Audience: []roleType{roleUser, roleAssistant}, + Audience: []RoleType{RoleUser, RoleAssistant}, }, }) } + callToolResult.Content = toTypedContents(callToolResult.Content) logx.Infof("Tool call result: %#v", callToolResult) - s.sendResponse(client, req.ID, callToolResult) + + s.sendResponse(ctx, client, req.ID, callToolResult) } // processResourcesRead processes the resources/read request -func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) { +func (s *sseMcpServer) processResourcesRead(ctx context.Context, client *mcpClient, req Request) { var params ResourceReadParams if err := json.Unmarshal(req.Params, ¶ms); err != nil { - s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams) + s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams) return } @@ -821,7 +775,7 @@ func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) { s.resourcesLock.Unlock() if !exists { - s.sendErrorResponse(client, req.ID, fmt.Sprintf("Resource with URI '%s' not found", + s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Resource with URI '%s' not found", params.URI), errCodeResourceNotFound) return } @@ -837,14 +791,14 @@ func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) { }, }, } - s.sendResponse(client, req.ID, result) + s.sendResponse(ctx, client, req.ID, result) return } // Execute the resource handler - content, err := resource.Handler() + content, err := resource.Handler(ctx) if err != nil { - s.sendErrorResponse(client, req.ID, fmt.Sprintf("Error reading resource: %v", err), + s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Error reading resource: %v", err), errCodeInternalError) return } @@ -865,14 +819,14 @@ func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) { Contents: []ResourceContent{content}, } - s.sendResponse(client, req.ID, result) + s.sendResponse(ctx, client, req.ID, result) } // processResourceSubscribe processes the resources/subscribe request -func (s *sseMcpServer) processResourceSubscribe(client *mcpClient, req Request) { +func (s *sseMcpServer) processResourceSubscribe(ctx context.Context, client *mcpClient, req Request) { var params ResourceSubscribeParams if err := json.Unmarshal(req.Params, ¶ms); err != nil { - s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams) + s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams) return } @@ -882,19 +836,17 @@ func (s *sseMcpServer) processResourceSubscribe(client *mcpClient, req Request) s.resourcesLock.Unlock() if !exists { - s.sendErrorResponse(client, req.ID, fmt.Sprintf("Resource with URI '%s' not found", + s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Resource with URI '%s' not found", params.URI), errCodeResourceNotFound) return } // Send success response for the subscription - s.sendResponse(client, req.ID, struct{}{}) - - logx.Infof("Client %s subscribed to resource '%s'", client.id, params.URI) + s.sendResponse(ctx, client, req.ID, struct{}{}) } // processNotificationCancelled processes the notifications/cancelled notification -func (s *sseMcpServer) processNotificationCancelled(client *mcpClient, req Request) { +func (s *sseMcpServer) processNotificationCancelled(ctx context.Context, client *mcpClient, req Request) { // Extract the requestId that was canceled type CancelParams struct { RequestId int64 `json:"requestId"` @@ -918,18 +870,17 @@ func (s *sseMcpServer) processNotificationInitialized(client *mcpClient) { } // processPing processes the ping request and responds immediately -func (s *sseMcpServer) processPing(client *mcpClient, req Request) { +func (s *sseMcpServer) processPing(ctx context.Context, client *mcpClient, req Request) { // A ping request should simply respond with an empty result to confirm the server is alive logx.Infof("Received ping request with ID: %d", req.ID) // Send an empty response with client's original request ID - s.sendResponse(client, req.ID, struct{}{}) - - logx.Infof("Sent ping response for ID: %d", req.ID) + s.sendResponse(ctx, client, req.ID, struct{}{}) } // sendErrorResponse sends an error response via the SSE channel -func (s *sseMcpServer) sendErrorResponse(client *mcpClient, id int64, message string, code int) { +func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient, + id int64, message string, code int) { errorResponse := struct { JsonRpc string `json:"jsonrpc"` ID int64 `json:"id"` @@ -949,11 +900,17 @@ func (s *sseMcpServer) sendErrorResponse(client *mcpClient, id int64, message st sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData)) logx.Infof("Sending error for ID %d: %s", id, sseMessage) - client.channel <- sseMessage + // cannot receive from ctx.Done() because we're sending to the channel for SSE messages + select { + case client.channel <- sseMessage: + default: + // Channel buffer is full, log warning and continue + logx.Infof("Client %s channel is full while sending error response with ID %d", client.id, id) + } } // sendResponse sends a success response via the SSE channel -func (s *sseMcpServer) sendResponse(client *mcpClient, id int64, result any) { +func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id int64, result any) { response := Response{ JsonRpc: jsonRpcVersion, ID: id, @@ -962,7 +919,7 @@ func (s *sseMcpServer) sendResponse(client *mcpClient, id int64, result any) { jsonData, err := json.Marshal(response) if err != nil { - s.sendErrorResponse(client, id, "Failed to marshal response", errCodeInternalError) + s.sendErrorResponse(ctx, client, id, "Failed to marshal response", errCodeInternalError) return } @@ -970,5 +927,11 @@ func (s *sseMcpServer) sendResponse(client *mcpClient, id int64, result any) { sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData)) logx.Infof("Sending response for ID %d: %s", id, sseMessage) - client.channel <- sseMessage + // cannot receive from ctx.Done() because we're sending to the channel for SSE messages + select { + case client.channel <- sseMessage: + default: + // Channel buffer is full, log warning and continue + logx.Infof("Client %s channel is full while sending response with ID %d", client.id, id) + } } diff --git a/mcp/server_test.go b/mcp/server_test.go index 979060e79..1d6731a0c 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -34,7 +34,7 @@ host: localhost port: 8080 mcp: name: mcp-test-server - toolTimeout: 5s + messageTimeout: 5s ` var c McpConf @@ -82,7 +82,6 @@ func (m *mockMcpServer) registerExampleTool() { Name: "test.tool", Description: "A test tool", InputSchema: InputSchema{ - Type: "object", Properties: map[string]any{ "input": map[string]any{ "type": "string", @@ -91,7 +90,7 @@ func (m *mockMcpServer) registerExampleTool() { }, Required: []string{"input"}, }, - Handler: func(params map[string]any) (any, error) { + Handler: func(ctx context.Context, params map[string]any) (any, error) { input, ok := params["input"].(string) if !ok { return nil, fmt.Errorf("invalid input parameter") @@ -135,7 +134,7 @@ port: 8080 mcp: cors: - http://localhost:3000 - toolTimeout: 5s + messageTimeout: 5s ` var c McpConf @@ -186,7 +185,6 @@ func TestRegisterTool(t *testing.T) { Name: "example.tool", Description: "An example tool", InputSchema: InputSchema{ - Type: "object", Properties: map[string]any{ "input": map[string]any{ "type": "string", @@ -194,7 +192,7 @@ func TestRegisterTool(t *testing.T) { }, }, }, - Handler: func(params map[string]any) (any, error) { + Handler: func(ctx context.Context, params map[string]any) (any, error) { return "result", nil }, } @@ -280,7 +278,7 @@ func TestToolsList(t *testing.T) { } // Process the tool call - mock.server.processListTools(client, req) + mock.server.processListTools(context.Background(), client, req) // Get the response from the client's channel select { @@ -328,7 +326,7 @@ func TestToolCallBasic(t *testing.T) { } // Process the tool call - mock.server.processToolCall(client, req) + mock.server.processToolCall(context.Background(), client, req) // Get the response from the client's channel select { @@ -355,8 +353,7 @@ func TestToolCallBasic(t *testing.T) { // Verify the response content assert.Len(t, parsed.Result.Content, 1, "Response should contain one content item") - assert.Equal(t, "text", parsed.Result.Content[0]["type"], "Content type should be text") - assert.Equal(t, "Processed: test-input", parsed.Result.Content[0]["text"], "Tool result incorrect") + assert.Equal(t, "Processed: test-input", parsed.Result.Content[0][ContentTypeText], "Tool result incorrect") assert.False(t, parsed.Result.IsError, "Response should not be an error") case <-time.After(100 * time.Millisecond): t.Fatal("Timed out waiting for tool call response") @@ -373,10 +370,9 @@ func TestToolCallMapResult(t *testing.T) { Name: "map.tool", Description: "A tool that returns a map result", InputSchema: InputSchema{ - Type: "object", Properties: map[string]any{}, }, - Handler: func(params map[string]any) (any, error) { + Handler: func(ctx context.Context, params map[string]any) (any, error) { // Return a complex nested map structure return map[string]any{ "string": "value", @@ -417,7 +413,7 @@ func TestToolCallMapResult(t *testing.T) { } // Process the tool call - mock.server.processToolCall(client, req) + mock.server.processToolCall(context.Background(), client, req) // Get the response from the client's channel select { @@ -445,13 +441,8 @@ func TestToolCallMapResult(t *testing.T) { firstItem, ok := content[0].(map[string]any) require.True(t, ok, "First content item should be an object") - // Verify it's a text content - contentType, ok := firstItem["type"].(string) - require.True(t, ok, "Content should have a type") - assert.Equal(t, "text", contentType, "Content type should be text") - // Get the text content which should be our JSON - text, ok := firstItem["text"].(string) + text, ok := firstItem[ContentTypeText].(string) require.True(t, ok, "Content should have text") // Verify the text is valid JSON and contains our data @@ -496,10 +487,9 @@ func TestToolCallArrayResult(t *testing.T) { Name: "array.tool", Description: "A tool that returns an array result", InputSchema: InputSchema{ - Type: "object", Properties: map[string]any{}, }, - Handler: func(params map[string]any) (any, error) { + Handler: func(ctx context.Context, params map[string]any) (any, error) { // Return an array of mixed content types return []any{ "string item", @@ -536,7 +526,7 @@ func TestToolCallArrayResult(t *testing.T) { } // Process the tool call - mock.server.processToolCall(client, req) + mock.server.processToolCall(context.Background(), client, req) // Get the response from the client's channel select { @@ -574,16 +564,14 @@ func TestToolCallTextContentResult(t *testing.T) { Name: "text.content.tool", Description: "A tool that returns a TextContent result", InputSchema: InputSchema{ - Type: "object", Properties: map[string]any{}, }, - Handler: func(params map[string]any) (any, error) { + Handler: func(ctx context.Context, params map[string]any) (any, error) { // Return a TextContent object directly return TextContent{ - Type: "text", Text: "This is a direct TextContent result", Annotations: &Annotations{ - Audience: []roleType{roleUser, roleAssistant}, + Audience: []RoleType{RoleUser, RoleAssistant}, Priority: func() *float64 { p := 0.9; return &p }(), }, }, nil @@ -614,7 +602,7 @@ func TestToolCallTextContentResult(t *testing.T) { } // Process the tool call - mock.server.processToolCall(client, req) + mock.server.processToolCall(context.Background(), client, req) // Get the response from the client's channel select { @@ -642,16 +630,6 @@ func TestToolCallTextContentResult(t *testing.T) { firstItem, ok := content[0].(map[string]any) require.True(t, ok, "First content item should be an object") - // Verify it's a text content - contentType, ok := firstItem["type"].(string) - require.True(t, ok, "Content should have a type") - assert.Equal(t, "text", contentType, "Content type should be text") - - // Check text content - text, ok := firstItem["text"].(string) - require.True(t, ok, "Content should have text") - assert.Equal(t, "This is a direct TextContent result", text, "Text content should match") - // Check annotations annotations, ok := firstItem["annotations"].(map[string]any) require.True(t, ok, "Should have annotations") @@ -679,13 +657,11 @@ func TestToolCallImageContentResult(t *testing.T) { Name: "image.content.tool", Description: "A tool that returns an ImageContent result", InputSchema: InputSchema{ - Type: "object", Properties: map[string]any{}, }, - Handler: func(params map[string]any) (any, error) { + Handler: func(ctx context.Context, params map[string]any) (any, error) { // Return an ImageContent object directly return ImageContent{ - Type: "image", Data: "dGVzdCBpbWFnZSBkYXRhIChiYXNlNjQgZW5jb2RlZCk=", // "test image data (base64 encoded)" in base64 MimeType: "image/png", }, nil @@ -716,7 +692,7 @@ func TestToolCallImageContentResult(t *testing.T) { } // Process the tool call - mock.server.processToolCall(client, req) + mock.server.processToolCall(context.Background(), client, req) // Get the response from the client's channel select { @@ -744,11 +720,6 @@ func TestToolCallImageContentResult(t *testing.T) { firstItem, ok := content[0].(map[string]any) require.True(t, ok, "First content item should be an object") - // Verify it's an image content - contentType, ok := firstItem["type"].(string) - require.True(t, ok, "Content should have a type") - assert.Equal(t, "image", contentType, "Content type should be image") - // Check image data data, ok := firstItem["data"].(string) require.True(t, ok, "Content should have data") @@ -773,12 +744,12 @@ func TestToolCallToolResultType(t *testing.T) { Name: "toolresult.tool", Description: "A tool that returns a ToolResult object", InputSchema: InputSchema{ - Type: "object", + Type: ContentTypeObject, Properties: map[string]any{}, }, - Handler: func(params map[string]any) (any, error) { + Handler: func(ctx context.Context, params map[string]any) (any, error) { return ToolResult{ - Type: "text", + Type: ContentTypeText, Content: "This is a ToolResult with text content type", }, nil }, @@ -790,10 +761,10 @@ func TestToolCallToolResultType(t *testing.T) { Name: "toolresult.image.tool", Description: "A tool that returns a ToolResult with image content", InputSchema: InputSchema{ - Type: "object", + Type: ContentTypeObject, Properties: map[string]any{}, }, - Handler: func(params map[string]any) (any, error) { + Handler: func(ctx context.Context, params map[string]any) (any, error) { return ToolResult{ Type: "image", Content: map[string]any{ @@ -810,10 +781,9 @@ func TestToolCallToolResultType(t *testing.T) { Name: "toolresult.audio.tool", Description: "A tool that returns a ToolResult with audio content", InputSchema: InputSchema{ - Type: "object", Properties: map[string]any{}, }, - Handler: func(params map[string]any) (any, error) { + Handler: func(ctx context.Context, params map[string]any) (any, error) { // Test with image type return ToolResult{ Type: "audio", @@ -831,10 +801,9 @@ func TestToolCallToolResultType(t *testing.T) { Name: "toolresult.int.tool", Description: "A tool that returns a ToolResult with int content", InputSchema: InputSchema{ - Type: "object", Properties: map[string]any{}, }, - Handler: func(params map[string]any) (any, error) { + Handler: func(ctx context.Context, params map[string]any) (any, error) { return 2, nil }, } @@ -845,10 +814,9 @@ func TestToolCallToolResultType(t *testing.T) { Name: "toolresult.bad.tool", Description: "A tool that returns a ToolResult with bad content", InputSchema: InputSchema{ - Type: "object", Properties: map[string]any{}, }, - Handler: func(params map[string]any) (any, error) { + Handler: func(ctx context.Context, params map[string]any) (any, error) { return map[string]any{ "type": "custom", "data": make(chan int), @@ -881,7 +849,7 @@ func TestToolCallToolResultType(t *testing.T) { } // Process the tool call - mock.server.processToolCall(client, req) + mock.server.processToolCall(context.Background(), client, req) // Get the response from the client's channel select { @@ -909,13 +877,8 @@ func TestToolCallToolResultType(t *testing.T) { firstItem, ok := content[0].(map[string]any) require.True(t, ok, "First content item should be an object") - // Verify it's a text content - contentType, ok := firstItem["type"].(string) - require.True(t, ok, "Content should have a type") - assert.Equal(t, "text", contentType, "Content type should be text") - // Check text content - text, ok := firstItem["text"].(string) + text, ok := firstItem[ContentTypeText].(string) require.True(t, ok, "Content should have text") assert.Equal(t, "This is a ToolResult with text content type", text, "Text content should match") @@ -947,7 +910,7 @@ func TestToolCallToolResultType(t *testing.T) { } // Process the tool call - mock.server.processToolCall(client, req) + mock.server.processToolCall(context.Background(), client, req) // Get the response from the client's channel select { @@ -975,11 +938,6 @@ func TestToolCallToolResultType(t *testing.T) { firstItem, ok := content[0].(map[string]any) require.True(t, ok, "First content item should be an object") - // Verify it's an image content - contentType, ok := firstItem["type"].(string) - require.True(t, ok, "Content should have a type") - assert.Equal(t, "image", contentType, "Content type should be image") - // Check image data and mime type data, ok := firstItem["data"].(string) require.True(t, ok, "Content should have data") @@ -1017,7 +975,7 @@ func TestToolCallToolResultType(t *testing.T) { } // Process the tool call - mock.server.processToolCall(client, req) + mock.server.processToolCall(context.Background(), client, req) // Get the response from the client's channel select { @@ -1040,15 +998,6 @@ func TestToolCallToolResultType(t *testing.T) { content, ok := result["content"].([]any) require.True(t, ok, "Result should have a content array") require.NotEmpty(t, content, "Content should not be empty") - - // The first content item should be converted from ToolResult to ImageContent - firstItem, ok := content[0].(map[string]any) - require.True(t, ok, "First content item should be an object") - - // Verify it's an image content - contentType, ok := firstItem["type"].(string) - require.True(t, ok, "Content should have a type") - assert.Equal(t, "text", contentType, "Content type should be image") case <-time.After(100 * time.Millisecond): t.Fatal("Timed out waiting for tool call response") } @@ -1077,7 +1026,7 @@ func TestToolCallToolResultType(t *testing.T) { } // Process the tool call - mock.server.processToolCall(client, req) + mock.server.processToolCall(context.Background(), client, req) // Get the response from the client's channel select { @@ -1100,15 +1049,6 @@ func TestToolCallToolResultType(t *testing.T) { content, ok := result["content"].([]any) require.True(t, ok, "Result should have a content array") require.NotEmpty(t, content, "Content should not be empty") - - // The first content item should be converted from ToolResult to ImageContent - firstItem, ok := content[0].(map[string]any) - require.True(t, ok, "First content item should be an object") - - // Verify it's an image content - contentType, ok := firstItem["type"].(string) - require.True(t, ok, "Content should have a type") - assert.Equal(t, "text", contentType, "Content type should be image") case <-time.After(100 * time.Millisecond): t.Fatal("Timed out waiting for tool call response") } @@ -1137,7 +1077,7 @@ func TestToolCallToolResultType(t *testing.T) { } // Process the tool call - mock.server.processToolCall(client, req) + mock.server.processToolCall(context.Background(), client, req) // Get the response from the client's channel select { @@ -1159,10 +1099,9 @@ func TestToolCallError(t *testing.T) { Name: "error.tool", Description: "A tool that returns an error", InputSchema: InputSchema{ - Type: "object", Properties: map[string]any{}, }, - Handler: func(params map[string]any) (any, error) { + Handler: func(ctx context.Context, params map[string]any) (any, error) { return nil, fmt.Errorf("tool execution failed") }, }) @@ -1189,7 +1128,7 @@ func TestToolCallError(t *testing.T) { } // Process the tool call - mock.server.processToolCall(client, req) + mock.server.processToolCall(context.Background(), client, req) // Check the response select { @@ -1207,20 +1146,16 @@ func TestToolCallTimeout(t *testing.T) { mock := newMockMcpServer(t) defer mock.shutdown() - // Set a very short timeout for testing - mock.server.conf.Mcp.ToolTimeout = 10 * time.Millisecond - // Register a tool that times out err := mock.server.RegisterTool(Tool{ Name: "timeout.tool", Description: "A tool that times out", InputSchema: InputSchema{ - Type: "object", Properties: map[string]any{}, }, - Handler: func(params map[string]any) (any, error) { - time.Sleep(50 * time.Millisecond) // Sleep longer than timeout - return "this should never be returned", nil + Handler: func(ctx context.Context, params map[string]any) (any, error) { + <-ctx.Done() + return nil, fmt.Errorf("tool execution timed out") }, }) require.NoError(t, err) @@ -1244,16 +1179,24 @@ func TestToolCallTimeout(t *testing.T) { Method: methodToolsCall, Params: paramBytes, } + jsonBody, _ := json.Marshal(req) - // Process the tool call - mock.server.processToolCall(client, req) + // Create HTTP request + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client", bytes.NewReader(jsonBody)) + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Millisecond) + defer cancel() + r = r.WithContext(ctx) + w := httptest.NewRecorder() + + // Process through handleRequest + go mock.server.handleRequest(w, r) // Check the response select { case response := <-client.channel: assert.Contains(t, response, "event: message", "Response should have message event") assert.Contains(t, response, `-32001`, "Response should contain a timeout error code") - case <-time.After(150 * time.Millisecond): + case <-time.After(100 * time.Millisecond): t.Fatal("Timed out waiting for tool call response") } } @@ -1274,7 +1217,7 @@ func TestInitializeAndNotifications(t *testing.T) { Params: json.RawMessage(`{}`), } - mock.server.processInitialize(client, initReq) + mock.server.processInitialize(context.Background(), client, initReq) // Check that client is initialized after initialize request assert.True(t, client.initialized, "Client should be marked as initialized after initialize request") @@ -1418,7 +1361,7 @@ func TestNotificationCancelled_badParams(t *testing.T) { } buf := logtest.NewCollector(t) - mock.server.processNotificationCancelled(client, cancelReq) + mock.server.processNotificationCancelled(context.Background(), client, cancelReq) select { case <-client.channel: @@ -1593,7 +1536,7 @@ func TestGetPrompt(t *testing.T) { } // Process the request - mock.server.processGetPrompt(client, promptReq) + mock.server.processGetPrompt(context.Background(), client, promptReq) // Check response select { @@ -1622,7 +1565,7 @@ func TestGetPrompt(t *testing.T) { } // Process the request - mock.server.processGetPrompt(client, promptReq) + mock.server.processGetPrompt(context.Background(), client, promptReq) // Check response select { @@ -1636,6 +1579,44 @@ func TestGetPrompt(t *testing.T) { t.Fatal("Timed out waiting for prompt response") } }) + + t.Run("test prompt with nil params", func(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Create a test client + client := addTestClient(mock.server, "test-client", true) + + // Register a test prompt + testPrompt := Prompt{ + Name: "test.prompt", + Description: "A test prompt", + } + mock.server.RegisterPrompt(testPrompt) + + // Create a get prompt request + paramBytes, _ := json.Marshal(map[string]any{ + "name": "test.prompt", + }) + promptReq := Request{ + JsonRpc: jsonRpcVersion, + ID: 1, + Method: "prompts/get", + Params: paramBytes, + } + + // Process the request + mock.server.processGetPrompt(context.Background(), client, promptReq) + + // Check response + select { + case response := <-client.channel: + _, err := parseEvent(response) + assert.NoError(t, err, "Should be able to parse event") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for prompt response") + } + }) } // TestBroadcast tests the broadcast functionality @@ -1903,34 +1884,79 @@ func TestNotificationInitialized(t *testing.T) { }) } -func TestSendBadResponse(t *testing.T) { - mock := newMockMcpServer(t) - defer mock.shutdown() +func TestSendResponse(t *testing.T) { + t.Run("bad response", func(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() - // Create a test client - client := addTestClient(mock.server, "test-client", true) + // Create a test client + client := addTestClient(mock.server, "test-client", true) - // Create a response - response := Response{ - JsonRpc: jsonRpcVersion, - ID: 1, - Result: make(chan int), - } + // Create a response + response := Response{ + JsonRpc: jsonRpcVersion, + ID: 1, + Result: make(chan int), + } - // Send the response - mock.server.sendResponse(client, 1, response) + // Send the response + mock.server.sendResponse(context.Background(), client, 1, response) - // Check the response in the client's channel - select { - case res := <-client.channel: - evt, err := parseEvent(res) - require.NoError(t, err, "Should parse event without error") - errMsg, ok := evt.Data["error"].(map[string]any) - require.True(t, ok, "Should have error in response") - assert.Equal(t, float64(errCodeInternalError), errMsg["code"], "Error code should match") - case <-time.After(100 * time.Millisecond): - t.Fatal("Timed out waiting for response") - } + // Check the response in the client's channel + select { + case res := <-client.channel: + evt, err := parseEvent(res) + require.NoError(t, err, "Should parse event without error") + errMsg, ok := evt.Data["error"].(map[string]any) + require.True(t, ok, "Should have error in response") + assert.Equal(t, float64(errCodeInternalError), errMsg["code"], "Error code should match") + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for response") + } + }) + + t.Run("channel full", func(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Create a test client + client := addTestClient(mock.server, "test-client", true) + for i := 0; i < eventChanSize; i++ { + client.channel <- "test" + } + + // Create a response + response := Response{ + JsonRpc: jsonRpcVersion, + ID: 1, + Result: "foo", + } + + buf := logtest.NewCollector(t) + // Send the response + mock.server.sendResponse(context.Background(), client, 1, response) + // Check the response in the client's channel + assert.Contains(t, buf.String(), "channel is full") + }) +} + +func TestSendErrorResponse(t *testing.T) { + t.Run("channel full", func(t *testing.T) { + mock := newMockMcpServer(t) + defer mock.shutdown() + + // Create a test client + client := addTestClient(mock.server, "test-client", true) + for i := 0; i < eventChanSize; i++ { + client.channel <- "test" + } + + buf := logtest.NewCollector(t) + // Send the response + mock.server.sendErrorResponse(context.Background(), client, 1, "foo", errCodeInternalError) + // Check the response in the client's channel + assert.Contains(t, buf.String(), "channel is full") + }) } // TestMethodToolsCall tests the handling of tools/call method through handleRequest @@ -2028,8 +2054,7 @@ func TestMethodToolsCall(t *testing.T) { if len(content) > 0 { firstItem, ok := content[0].(map[string]any) if ok { - assert.Equal(t, "text", firstItem["type"], "Content type should be text") - assert.Contains(t, firstItem["text"], "Processed: test-input", "Content should include processed input") + assert.Contains(t, firstItem[ContentTypeText], "Processed: test-input", "Content should include processed input") } } case <-time.After(100 * time.Millisecond): @@ -2145,7 +2170,6 @@ func TestMethodPromptsGet(t *testing.T) { { Name: "topic", Description: "Topic to discuss", - Default: "artificial intelligence", }, }, Content: "Hello {{name}}! Let's talk about {{topic}}.", @@ -2227,14 +2251,12 @@ func TestMethodPromptsGet(t *testing.T) { if len(messages) > 0 { message, ok := messages[0].(map[string]any) require.True(t, ok, "Message should be an object") - assert.Equal(t, "user", message["role"], "Role should be 'user'") + assert.Equal(t, string(RoleUser), message["role"], "Role should be 'user'") content, ok := message["content"].(map[string]any) require.True(t, ok, "Should have content object") - assert.Equal(t, "text", content["type"], "Content type should be text") - assert.Contains(t, content["text"], "Hello Test User", "Content should include the name argument") - assert.Contains(t, content["text"], "about artificial intelligence", - "Content should include the default topic argument") + assert.Equal(t, ContentTypeText, content["type"], "Content type should be text") + assert.Contains(t, content[ContentTypeText], "Hello Test User", "Content should include the name argument") } case <-time.After(100 * time.Millisecond): t.Fatal("Timed out waiting for prompt get response") @@ -2255,27 +2277,24 @@ func TestMethodPromptsGet(t *testing.T) { { Name: "question", Description: "User's question", - Default: "How does this work?", }, }, - Handler: func(args map[string]string) ([]PromptMessage, error) { + Handler: func(ctx context.Context, args map[string]string) ([]PromptMessage, error) { username := args["username"] question := args["question"] // Create a system message systemMessage := PromptMessage{ - Role: "system", + Role: RoleAssistant, Content: TextContent{ - Type: "text", Text: "You are a helpful assistant.", }, } // Create a user message userMessage := PromptMessage{ - Role: "user", + Role: RoleUser, Content: TextContent{ - Type: "text", Text: fmt.Sprintf("Hi, I'm %s and I'm wondering: %s", username, question), }, } @@ -2340,20 +2359,20 @@ func TestMethodPromptsGet(t *testing.T) { // Check message content if len(messages) >= 2 { - // First message should be system + // First message should be assistant message1, _ := messages[0].(map[string]any) - assert.Equal(t, "system", message1["role"], "First role should be 'system'") + assert.Equal(t, string(RoleAssistant), message1["role"], "First role should be 'system'") content1, _ := message1["content"].(map[string]any) - assert.Contains(t, content1["text"], "helpful assistant", "System message should be correct") + assert.Contains(t, content1[ContentTypeText], "helpful assistant", "System message should be correct") // Second message should be user message2, _ := messages[1].(map[string]any) - assert.Equal(t, "user", message2["role"], "Second role should be 'user'") + assert.Equal(t, string(RoleUser), message2["role"], "Second role should be 'user'") content2, _ := message2["content"].(map[string]any) - assert.Contains(t, content2["text"], "Dynamic User", "User message should contain username") - assert.Contains(t, content2["text"], "How to test this?", "User message should contain question") + assert.Contains(t, content2[ContentTypeText], "Dynamic User", "User message should contain username") + assert.Contains(t, content2[ContentTypeText], "How to test this?", "User message should contain question") } case <-time.After(100 * time.Millisecond): t.Fatal("Timed out waiting for prompt get response") @@ -2459,7 +2478,7 @@ func TestMethodPromptsGet(t *testing.T) { Name: "error-handler-prompt", Description: "A prompt with a handler that returns an error", Arguments: []PromptArgument{}, - Handler: func(args map[string]string) ([]PromptMessage, error) { + Handler: func(ctx context.Context, args map[string]string) ([]PromptMessage, error) { return nil, fmt.Errorf("test handler error") }, } @@ -2583,7 +2602,7 @@ func TestMethodResourcesList(t *testing.T) { URI: "file:///test/resource.txt", Description: "A test resource with handler", MimeType: "text/plain", - Handler: func() (ResourceContent, error) { + Handler: func(ctx context.Context) (ResourceContent, error) { return ResourceContent{ URI: "file:///test/resource.txt", MimeType: "text/plain", @@ -2654,7 +2673,7 @@ func TestMethodResourcesRead(t *testing.T) { URI: "file:///test/resource.txt", Description: "A test resource with handler", MimeType: "text/plain", - Handler: func() (ResourceContent, error) { + Handler: func(ctx context.Context) (ResourceContent, error) { return ResourceContent{ URI: "file:///test/resource.txt", MimeType: "text/plain", @@ -2729,7 +2748,7 @@ func TestMethodResourcesRead(t *testing.T) { require.True(t, ok, "Content should be an object") assert.Equal(t, "file:///test/resource.txt", content["uri"], "URI should match") assert.Equal(t, "text/plain", content["mimeType"], "MimeType should match") - assert.Equal(t, "This is test resource content", content["text"], "Text content should match") + assert.Equal(t, "This is test resource content", content[ContentTypeText], "Text content should match") } case <-time.After(100 * time.Millisecond): t.Fatal("Timed out waiting for resource read response") @@ -2799,7 +2818,7 @@ func TestMethodResourcesRead(t *testing.T) { require.True(t, ok, "Content should be an object") assert.Equal(t, "file:///test/no-handler.txt", content["uri"], "URI should match") assert.Equal(t, "text/plain", content["mimeType"], "MimeType should match") - _, ok = content["text"] + _, ok = content[ContentTypeText] assert.False(t, ok, "Text content should be empty string") } case <-time.After(100 * time.Millisecond): @@ -2880,7 +2899,7 @@ func TestMethodResourcesRead(t *testing.T) { client := addTestClient(mock.server, "test-client-resources", true) // Process through handleRequest - mock.server.processResourcesRead(client, req) + mock.server.processResourcesRead(context.Background(), client, req) select { case response := <-client.channel: @@ -2898,7 +2917,7 @@ func TestMethodResourcesRead(t *testing.T) { URI: "file:///test/error.txt", Description: "A test resource with handler that returns error", MimeType: "text/plain", - Handler: func() (ResourceContent, error) { + Handler: func(ctx context.Context) (ResourceContent, error) { return ResourceContent{}, fmt.Errorf("test handler error") }, } @@ -2946,7 +2965,7 @@ func TestMethodResourcesRead(t *testing.T) { URI: "file:///test/missing-fields.txt", Description: "A test resource with handler that returns content missing fields", MimeType: "text/plain", - Handler: func() (ResourceContent, error) { + Handler: func(ctx context.Context) (ResourceContent, error) { // Return ResourceContent without URI and MimeType return ResourceContent{ Text: "Content with missing fields", @@ -3006,7 +3025,7 @@ func TestMethodResourcesRead(t *testing.T) { require.True(t, ok, "Content should be an object") assert.Equal(t, "file:///test/missing-fields.txt", content["uri"], "URI should be filled from request") assert.Equal(t, "text/plain", content["mimeType"], "MimeType should be filled from resource") - assert.Equal(t, "Content with missing fields", content["text"], "Text content should match") + assert.Equal(t, "Content with missing fields", content[ContentTypeText], "Text content should match") } case <-time.After(100 * time.Millisecond): t.Fatal("Timed out waiting for resource read response") @@ -3159,7 +3178,7 @@ func TestMethodResourcesSubscribe(t *testing.T) { } client := addTestClient(mock.server, "test-client-sub-not-found", true) - mock.server.processResourceSubscribe(client, req) + mock.server.processResourceSubscribe(context.Background(), client, req) select { case response := <-client.channel: @@ -3268,7 +3287,7 @@ func TestToolCallUnmarshalError(t *testing.T) { } // Process the tool call directly - mock.server.processToolCall(client, req) + mock.server.processToolCall(context.Background(), client, req) // Check for error response about invalid JSON select { @@ -3316,7 +3335,7 @@ func TestToolCallWithInvalidParams(t *testing.T) { jsonBody, _ := json.Marshal(req) // Create HTTP request - r := httptest.NewRequest("POST", "/?session_id=test-client-invalid-json", bytes.NewReader(jsonBody)) + r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-invalid-json", bytes.NewReader(jsonBody)) w := httptest.NewRecorder() // Process through handleRequest diff --git a/mcp/types.go b/mcp/types.go index 00cb65582..9f0c397d8 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -1,6 +1,7 @@ package mcp import ( + "context" "encoding/json" "sync" @@ -45,19 +46,18 @@ type ListToolsResult struct { // Message Content Types -// roleType represents the sender or recipient of messages in a conversation -type roleType string +// RoleType represents the sender or recipient of messages in a conversation +type RoleType string // PromptArgument defines a single argument that can be passed to a prompt type PromptArgument struct { Name string `json:"name"` // Argument name Description string `json:"description,omitempty"` // Human-readable description Required bool `json:"required,omitempty"` // Whether this argument is required - Default string `json:"default,omitempty"` // Default value if not provided } // PromptHandler is a function that dynamically generates prompt content -type PromptHandler func(args map[string]string) ([]PromptMessage, error) +type PromptHandler func(ctx context.Context, args map[string]string) ([]PromptMessage, error) // Prompt represents an MCP Prompt definition type Prompt struct { @@ -70,31 +70,43 @@ type Prompt struct { // PromptMessage represents a message in a conversation type PromptMessage struct { - Role roleType `json:"role"` // Message sender role + Role RoleType `json:"role"` // Message sender role Content any `json:"content"` // Message content (TextContent, ImageContent, etc.) } // TextContent represents text content in a message type TextContent struct { - Type string `json:"type"` // Always "text" Text string `json:"text"` // The text content Annotations *Annotations `json:"annotations,omitempty"` // Optional annotations } +type typedTextContent struct { + Type string `json:"type"` + TextContent +} + // ImageContent represents image data in a message type ImageContent struct { - Type string `json:"type"` // Always "image" Data string `json:"data"` // Base64-encoded image data MimeType string `json:"mimeType"` // MIME type (e.g., "image/png") } +type typedImageContent struct { + Type string `json:"type"` + ImageContent +} + // AudioContent represents audio data in a message type AudioContent struct { - Type string `json:"type"` // Always "audio" Data string `json:"data"` // Base64-encoded audio data MimeType string `json:"mimeType"` // MIME type (e.g., "audio/mp3") } +type typedAudioContent struct { + Type string `json:"type"` + AudioContent +} + // FileContent represents file content type FileContent struct { URI string `json:"uri"` // URI identifying the file @@ -115,16 +127,14 @@ type EmbeddedResource struct { // Annotations provides additional metadata for content type Annotations struct { - Audience []roleType `json:"audience,omitempty"` // Who should see this content + Audience []RoleType `json:"audience,omitempty"` // Who should see this content Priority *float64 `json:"priority,omitempty"` // Optional priority (0-1) } // Tool-related Types -// Tool Definition Types - // ToolHandler is a function that handles tool calls -type ToolHandler func(params map[string]any) (any, error) +type ToolHandler func(ctx context.Context, params map[string]any) (any, error) // Tool represents a Model Context Protocol Tool definition type Tool struct { @@ -136,7 +146,7 @@ type Tool struct { // InputSchema represents tool's input schema in JSON Schema format type InputSchema struct { - Type string `json:"type"` // Always "object" for tool inputs + Type string `json:"type"` Properties map[string]any `json:"properties"` // Property definitions Required []string `json:"required,omitempty"` // List of required properties } @@ -144,8 +154,8 @@ type InputSchema struct { // CallToolResult represents a tool call result that conforms to the MCP schema type CallToolResult struct { Result - Content []interface{} `json:"content"` // Content items (text, images, etc.) - IsError bool `json:"isError,omitempty"` // True if tool execution failed + Content []any `json:"content"` // Content items (text, images, etc.) + IsError bool `json:"isError,omitempty"` // True if tool execution failed } // Resource represents a Model Context Protocol Resource definition @@ -158,7 +168,7 @@ type Resource struct { } // ResourceHandler is a function that handles resource read requests -type ResourceHandler func() (ResourceContent, error) +type ResourceHandler func(ctx context.Context) (ResourceContent, error) // ResourceContent represents the content of a resource type ResourceContent struct { diff --git a/mcp/types_test.go b/mcp/types_test.go index 45cea252a..e0b5c323d 100644 --- a/mcp/types_test.go +++ b/mcp/types_test.go @@ -1,6 +1,7 @@ package mcp import ( + "context" "encoding/json" "testing" @@ -79,7 +80,7 @@ func TestToolStructs(t *testing.T) { }, Required: []string{"input"}, }, - Handler: func(params map[string]any) (any, error) { + Handler: func(ctx context.Context, params map[string]any) (any, error) { return "result", nil }, } @@ -145,44 +146,38 @@ func TestResourceStructs(t *testing.T) { func TestContentTypes(t *testing.T) { // Test TextContent textContent := TextContent{ - Type: "text", Text: "Sample text", Annotations: &Annotations{ - Audience: []roleType{roleUser, roleAssistant}, + Audience: []RoleType{RoleUser, RoleAssistant}, Priority: ptr(1.0), }, } data, err := json.Marshal(textContent) assert.NoError(t, err) - assert.Contains(t, string(data), `"type":"text"`) assert.Contains(t, string(data), `"text":"Sample text"`) assert.Contains(t, string(data), `"audience":["user","assistant"]`) assert.Contains(t, string(data), `"priority":1`) // Test ImageContent imageContent := ImageContent{ - Type: "image", Data: "base64data", MimeType: "image/png", } data, err = json.Marshal(imageContent) assert.NoError(t, err) - assert.Contains(t, string(data), `"type":"image"`) assert.Contains(t, string(data), `"data":"base64data"`) assert.Contains(t, string(data), `"mimeType":"image/png"`) // Test AudioContent audioContent := AudioContent{ - Type: "audio", Data: "base64audio", MimeType: "audio/mp3", } data, err = json.Marshal(audioContent) assert.NoError(t, err) - assert.Contains(t, string(data), `"type":"audio"`) assert.Contains(t, string(data), `"data":"base64audio"`) assert.Contains(t, string(data), `"mimeType":"audio/mp3"`) } @@ -197,7 +192,6 @@ func TestCallToolResult(t *testing.T) { }, Content: []interface{}{ TextContent{ - Type: "text", Text: "Sample result", }, }, @@ -207,6 +201,6 @@ func TestCallToolResult(t *testing.T) { data, err := json.Marshal(result) assert.NoError(t, err) assert.Contains(t, string(data), `"_meta":{"progressToken":"token123"}`) - assert.Contains(t, string(data), `"content":[{"type":"text","text":"Sample result"}]`) + assert.Contains(t, string(data), `"content":[{"text":"Sample result"}]`) assert.NotContains(t, string(data), `"isError":`) } diff --git a/mcp/util.go b/mcp/util.go index 0963b7098..282840ffd 100644 --- a/mcp/util.go +++ b/mcp/util.go @@ -1,15 +1,107 @@ package mcp -import ( - "fmt" -) +import "fmt" + +// formatSSEMessage formats a Server-Sent Event message with proper CRLF line endings +func formatSSEMessage(event string, data []byte) string { + return fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(data)) +} // ptr is a helper function to get a pointer to a value func ptr[T any](v T) *T { return &v } -// formatSSEMessage formats a Server-Sent Event message with proper CRLF line endings -func formatSSEMessage(event string, data []byte) string { - return fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(data)) +func toTypedContents(contents []any) []any { + typedContents := make([]any, len(contents)) + + for i, content := range contents { + switch v := content.(type) { + case TextContent: + typedContents[i] = typedTextContent{ + Type: ContentTypeText, + TextContent: v, + } + case ImageContent: + typedContents[i] = typedImageContent{ + Type: ContentTypeImage, + ImageContent: v, + } + case AudioContent: + typedContents[i] = typedAudioContent{ + Type: ContentTypeAudio, + AudioContent: v, + } + default: + typedContents[i] = typedTextContent{ + Type: ContentTypeText, + TextContent: TextContent{ + Text: fmt.Sprintf("Unknown content type: %T", v), + }, + } + } + } + + return typedContents +} + +func toTypedPromptMessages(messages []PromptMessage) []PromptMessage { + typedMessages := make([]PromptMessage, len(messages)) + + for i, msg := range messages { + switch v := msg.Content.(type) { + case TextContent: + typedMessages[i] = PromptMessage{ + Role: msg.Role, + Content: typedTextContent{ + Type: ContentTypeText, + TextContent: v, + }, + } + case ImageContent: + typedMessages[i] = PromptMessage{ + Role: msg.Role, + Content: typedImageContent{ + Type: ContentTypeImage, + ImageContent: v, + }, + } + case AudioContent: + typedMessages[i] = PromptMessage{ + Role: msg.Role, + Content: typedAudioContent{ + Type: ContentTypeAudio, + AudioContent: v, + }, + } + default: + typedMessages[i] = PromptMessage{ + Role: msg.Role, + Content: typedTextContent{ + Type: ContentTypeText, + TextContent: TextContent{ + Text: fmt.Sprintf("Unknown content type: %T", v), + }, + }, + } + } + } + + return typedMessages +} + +// validatePromptArguments checks if all required arguments are provided +// Returns a list of missing required arguments +func validatePromptArguments(prompt Prompt, providedArgs map[string]string) []string { + var missingArgs []string + + for _, arg := range prompt.Arguments { + if arg.Required { + if value, exists := providedArgs[arg.Name]; !exists || len(value) == 0 { + missingArgs = append(missingArgs, arg.Name) + } + } + } + + return missingArgs } diff --git a/mcp/util_test.go b/mcp/util_test.go index 336cc3920..0014378b6 100644 --- a/mcp/util_test.go +++ b/mcp/util_test.go @@ -8,29 +8,9 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func TestPtr(t *testing.T) { - tests := []struct { - name string - v interface{} - }{ - {"string", "test"}, - {"int", 42}, - {"bool", true}, - {"float", 3.14}, - {"struct", struct{ Name string }{"test"}}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := ptr(tt.v) - assert.NotNil(t, got, "ptr() should not return nil") - assert.Equal(t, tt.v, *got, "dereferenced pointer should equal input value") - }) - } -} - type Event struct { Type string Data map[string]any @@ -61,3 +41,234 @@ func parseEvent(input string) (*Event, error) { return &evt, nil } + +// TestToTypedPromptMessages tests the toTypedPromptMessages function +func TestToTypedPromptMessages(t *testing.T) { + // Test with multiple message types in one test + t.Run("MixedContentTypes", func(t *testing.T) { + // Create test data with different content types + messages := []PromptMessage{ + { + Role: RoleUser, + Content: TextContent{ + Text: "Hello, this is a text message", + Annotations: &Annotations{ + Audience: []RoleType{RoleUser, RoleAssistant}, + Priority: ptr(0.8), + }, + }, + }, + { + Role: RoleAssistant, + Content: ImageContent{ + Data: "base64ImageData", + MimeType: "image/jpeg", + }, + }, + { + Role: RoleUser, + Content: AudioContent{ + Data: "base64AudioData", + MimeType: "audio/mp3", + }, + }, + { + Role: "system", + Content: "This is a simple string that should be handled as unknown type", + }, + } + + // Call the function + result := toTypedPromptMessages(messages) + + // Validate results + require.Len(t, result, 4, "Should return the same number of messages") + + // Validate first message (TextContent) + msg := result[0] + assert.Equal(t, RoleUser, msg.Role, "Role should be preserved") + + // Type assertion using reflection since Content is an interface + typed, ok := msg.Content.(typedTextContent) + require.True(t, ok, "Should be typedTextContent") + assert.Equal(t, ContentTypeText, typed.Type, "Type should be text") + assert.Equal(t, "Hello, this is a text message", typed.Text, "Text content should be preserved") + require.NotNil(t, typed.Annotations, "Annotations should be preserved") + assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved") + require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved") + assert.Equal(t, 0.8, *typed.Annotations.Priority, "Priority value should be preserved") + + // Validate second message (ImageContent) + msg = result[1] + assert.Equal(t, RoleAssistant, msg.Role, "Role should be preserved") + + // Type assertion for image content + typedImg, ok := msg.Content.(typedImageContent) + require.True(t, ok, "Should be typedImageContent") + assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image") + assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved") + assert.Equal(t, "image/jpeg", typedImg.MimeType, "MimeType should be preserved") + + // Validate third message (AudioContent) + msg = result[2] + assert.Equal(t, RoleUser, msg.Role, "Role should be preserved") + + // Type assertion for audio content + typedAudio, ok := msg.Content.(typedAudioContent) + require.True(t, ok, "Should be typedAudioContent") + assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio") + assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved") + assert.Equal(t, "audio/mp3", typedAudio.MimeType, "MimeType should be preserved") + + // Validate fourth message (unknown type converted to TextContent) + msg = result[3] + assert.Equal(t, RoleType("system"), msg.Role, "Role should be preserved") + + // Should be converted to a typedTextContent with error message + typedUnknown, ok := msg.Content.(typedTextContent) + require.True(t, ok, "Unknown content should be converted to typedTextContent") + assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text") + assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type") + assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type") + }) + + // Test empty input + t.Run("EmptyInput", func(t *testing.T) { + messages := []PromptMessage{} + result := toTypedPromptMessages(messages) + assert.Empty(t, result, "Should return empty slice for empty input") + }) + + // Test with nil annotations + t.Run("NilAnnotations", func(t *testing.T) { + messages := []PromptMessage{ + { + Role: RoleUser, + Content: TextContent{ + Text: "Text with nil annotations", + Annotations: nil, + }, + }, + } + + result := toTypedPromptMessages(messages) + require.Len(t, result, 1, "Should return one message") + + typed, ok := result[0].Content.(typedTextContent) + require.True(t, ok, "Should be typedTextContent") + assert.Equal(t, ContentTypeText, typed.Type, "Type should be text") + assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved") + assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil") + }) +} + +// TestToTypedContents tests the toTypedContents function +func TestToTypedContents(t *testing.T) { + // Test with multiple content types in one test + t.Run("MixedContentTypes", func(t *testing.T) { + // Create test data with different content types + contents := []any{ + TextContent{ + Text: "Hello, this is a text content", + Annotations: &Annotations{ + Audience: []RoleType{RoleUser, RoleAssistant}, + Priority: ptr(0.7), + }, + }, + ImageContent{ + Data: "base64ImageData", + MimeType: "image/png", + }, + AudioContent{ + Data: "base64AudioData", + MimeType: "audio/wav", + }, + "This is a simple string that should be handled as unknown type", + } + + // Call the function + result := toTypedContents(contents) + + // Validate results + require.Len(t, result, 4, "Should return the same number of contents") + + // Validate first content (TextContent) + typed, ok := result[0].(typedTextContent) + require.True(t, ok, "Should be typedTextContent") + assert.Equal(t, ContentTypeText, typed.Type, "Type should be text") + assert.Equal(t, "Hello, this is a text content", typed.Text, "Text content should be preserved") + require.NotNil(t, typed.Annotations, "Annotations should be preserved") + assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved") + require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved") + assert.Equal(t, 0.7, *typed.Annotations.Priority, "Priority value should be preserved") + + // Validate second content (ImageContent) + typedImg, ok := result[1].(typedImageContent) + require.True(t, ok, "Should be typedImageContent") + assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image") + assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved") + assert.Equal(t, "image/png", typedImg.MimeType, "MimeType should be preserved") + + // Validate third content (AudioContent) + typedAudio, ok := result[2].(typedAudioContent) + require.True(t, ok, "Should be typedAudioContent") + assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio") + assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved") + assert.Equal(t, "audio/wav", typedAudio.MimeType, "MimeType should be preserved") + + // Validate fourth content (unknown type converted to TextContent) + typedUnknown, ok := result[3].(typedTextContent) + require.True(t, ok, "Unknown content should be converted to typedTextContent") + assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text") + assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type") + assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type") + }) + + // Test empty input + t.Run("EmptyInput", func(t *testing.T) { + contents := []any{} + result := toTypedContents(contents) + assert.Empty(t, result, "Should return empty slice for empty input") + }) + + // Test with nil annotations + t.Run("NilAnnotations", func(t *testing.T) { + contents := []any{ + TextContent{ + Text: "Text with nil annotations", + Annotations: nil, + }, + } + + result := toTypedContents(contents) + require.Len(t, result, 1, "Should return one content") + + typed, ok := result[0].(typedTextContent) + require.True(t, ok, "Should be typedTextContent") + assert.Equal(t, ContentTypeText, typed.Type, "Type should be text") + assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved") + assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil") + }) + + // Test with custom struct (should be handled as unknown type) + t.Run("CustomStruct", func(t *testing.T) { + type CustomContent struct { + Data string + } + + contents := []any{ + CustomContent{ + Data: "custom data", + }, + } + + result := toTypedContents(contents) + require.Len(t, result, 1, "Should return one content") + + typed, ok := result[0].(typedTextContent) + require.True(t, ok, "Custom struct should be converted to typedTextContent") + assert.Equal(t, ContentTypeText, typed.Type, "Type should be text") + assert.Contains(t, typed.Text, "Unknown content type:", "Should contain error about unknown type") + assert.Contains(t, typed.Text, "CustomContent", "Should mention the actual type") + }) +} diff --git a/mcp/vars.go b/mcp/vars.go index cc9a300da..7a5b6fb8e 100644 --- a/mcp/vars.go +++ b/mcp/vars.go @@ -13,6 +13,9 @@ const ( // Session identifier key used in request URLs sessionIdKey = "session_id" + + // progressTokenKey is used to track progress of long-running tasks + progressTokenKey = "progressToken" ) // Server-Sent Events (SSE) event types @@ -26,11 +29,20 @@ const ( // Content type identifiers const ( - // Text content type - contentTypeText = "text" + // ContentTypeObject is object content type + ContentTypeObject = "object" - // Image content type - contentTypeImage = "image" + // ContentTypeText is text content type + ContentTypeText = "text" + + // ContentTypeImage is image content type + ContentTypeImage = "image" + + // ContentTypeAudio is audio content type + ContentTypeAudio = "audio" + + // ContentTypeResource is resource content type + ContentTypeResource = "resource" ) // Collection keys for broadcast events @@ -72,11 +84,11 @@ const ( // User and assistant role definitions const ( - // The "user" role - the entity asking questions - roleUser roleType = "user" + // RoleUser is the "user" role - the entity asking questions + RoleUser RoleType = "user" - // The "assistant" role - the entity providing responses - roleAssistant roleType = "assistant" + // RoleAssistant is the "assistant" role - the entity providing responses + RoleAssistant RoleType = "assistant" ) // Method names as defined in the MCP specification diff --git a/mcp/vars_test.go b/mcp/vars_test.go index 5cfdca195..4a894a16f 100644 --- a/mcp/vars_test.go +++ b/mcp/vars_test.go @@ -146,13 +146,9 @@ func TestCollectionKeys(t *testing.T) { // TestRoleTypes checks that role types are used correctly func TestRoleTypes(t *testing.T) { - // Verify role type constants - assert.Equal(t, "user", string(roleUser), "User role should be 'user'") - assert.Equal(t, "assistant", string(roleAssistant), "Assistant role should be 'assistant'") - // Test in annotations annotations := Annotations{ - Audience: []roleType{roleUser, roleAssistant}, + Audience: []RoleType{RoleUser, RoleAssistant}, } data, err := json.Marshal(annotations) assert.NoError(t, err)