feat: improve mcp (#4828)

Signed-off-by: kevin <wanjunfeng@gmail.com>
This commit is contained in:
Kevin Wan
2025-05-04 15:29:14 +08:00
committed by GitHub
parent c3820a95c1
commit 69aa7fe346
14 changed files with 1661 additions and 391 deletions

View File

@@ -34,7 +34,10 @@ type McpConf struct {
// Cors contains allowed CORS origins
Cors []string `json:",optional"`
// ToolTimeout is the maximum time allowed for tool execution
ToolTimeout time.Duration `json:",default=30s"`
// SseTimeout is the maximum time allowed for SSE connections
SseTimeout time.Duration `json:",default=24h"`
// MessageTimeout is the maximum time allowed for request execution
MessageTimeout time.Duration `json:",default=30s"`
}
}

View File

@@ -27,7 +27,7 @@ mcp:
assert.Equal(t, "2024-11-05", c.Mcp.ProtocolVersion, "Default protocol version should be 2024-11-05")
assert.Equal(t, "/sse", c.Mcp.SseEndpoint, "Default SSE endpoint should be /sse")
assert.Equal(t, "/message", c.Mcp.MessageEndpoint, "Default message endpoint should be /message")
assert.Equal(t, 30*time.Second, c.Mcp.ToolTimeout, "Default tool timeout should be 30s")
assert.Equal(t, 30*time.Second, c.Mcp.MessageTimeout, "Default message timeout should be 30s")
}
func TestMcpConfCustomValues(t *testing.T) {
@@ -43,7 +43,7 @@ func TestMcpConfCustomValues(t *testing.T) {
"SseEndpoint": "/custom-sse",
"MessageEndpoint": "/custom-message",
"Cors": ["http://localhost:3000", "http://example.com"],
"ToolTimeout": "60s"
"MessageTimeout": "60s"
}
}`
@@ -59,5 +59,5 @@ func TestMcpConfCustomValues(t *testing.T) {
assert.Equal(t, "/custom-sse", c.Mcp.SseEndpoint, "SSE endpoint should be customizable")
assert.Equal(t, "/custom-message", c.Mcp.MessageEndpoint, "Message endpoint should be customizable")
assert.Equal(t, []string{"http://localhost:3000", "http://example.com"}, c.Mcp.Cors, "CORS settings should be customizable")
assert.Equal(t, 60*time.Second, c.Mcp.ToolTimeout, "Tool timeout should be customizable")
assert.Equal(t, 60*time.Second, c.Mcp.MessageTimeout, "Tool timeout should be customizable")
}

View File

@@ -2,6 +2,7 @@ package mcp
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
@@ -59,7 +60,7 @@ func TestHTTPHandlerIntegration(t *testing.T) {
conf := McpConf{}
conf.Mcp.Name = "test-integration"
conf.Mcp.Version = "1.0.0-test"
conf.Mcp.ToolTimeout = 1 * time.Second
conf.Mcp.MessageTimeout = 1 * time.Second
// Create a mock server directly
server := &sseMcpServer{
@@ -75,7 +76,6 @@ func TestHTTPHandlerIntegration(t *testing.T) {
Name: "echo",
Description: "Echo tool for testing",
InputSchema: InputSchema{
Type: "object",
Properties: map[string]any{
"message": map[string]any{
"type": "string",
@@ -83,7 +83,7 @@ func TestHTTPHandlerIntegration(t *testing.T) {
},
},
},
Handler: func(params map[string]any) (any, error) {
Handler: func(ctx context.Context, params map[string]any) (any, error) {
if msg, ok := params["message"].(string); ok {
return fmt.Sprintf("Echo: %s", msg), nil
}
@@ -181,7 +181,7 @@ func TestHandlerResponseFlow(t *testing.T) {
Name: "test.tool",
Description: "Test tool",
InputSchema: InputSchema{Type: "object"},
Handler: func(params map[string]any) (any, error) {
Handler: func(ctx context.Context, params map[string]any) (any, error) {
return "tool result", nil
},
})
@@ -329,7 +329,7 @@ func TestProcessListMethods(t *testing.T) {
Params: json.RawMessage(`{"cursor": "", "_meta": {"progressToken": "token1"}}`),
}
server.processListTools(client, req)
server.processListTools(context.Background(), client, req)
// Read response
select {
@@ -344,7 +344,7 @@ func TestProcessListMethods(t *testing.T) {
req.ID = 2
req.Method = methodPromptsList
req.Params = json.RawMessage(`{"cursor": "next"}`)
server.processListPrompts(client, req)
server.processListPrompts(context.Background(), client, req)
// Read response
select {
@@ -358,7 +358,7 @@ func TestProcessListMethods(t *testing.T) {
req.ID = 3
req.Method = methodResourcesList
req.Params = json.RawMessage(`{"cursor": "next"}`)
server.processListResources(client, req)
server.processListResources(context.Background(), client, req)
// Read response
select {
@@ -393,7 +393,7 @@ func TestErrorResponseHandling(t *testing.T) {
}
// Mock handleRequest by directly calling error handler
server.sendErrorResponse(client, req.ID, "Method not found", errCodeMethodNotFound)
server.sendErrorResponse(context.Background(), client, req.ID, "Method not found", errCodeMethodNotFound)
// Check response
select {
@@ -412,7 +412,7 @@ func TestErrorResponseHandling(t *testing.T) {
}
// Call process method directly
server.processToolCall(client, toolReq)
server.processToolCall(context.Background(), client, toolReq)
// Check response
select {
@@ -431,7 +431,7 @@ func TestErrorResponseHandling(t *testing.T) {
}
// Call process method directly
server.processGetPrompt(client, promptReq)
server.processGetPrompt(context.Background(), client, promptReq)
// Check response
select {

23
mcp/parser.go Normal file
View File

@@ -0,0 +1,23 @@
package mcp
import (
"fmt"
"github.com/zeromicro/go-zero/core/mapping"
)
// ParseArguments parses the arguments and populates the request object
func ParseArguments(args any, req any) error {
switch arguments := args.(type) {
case map[string]string:
m := make(map[string]any, len(arguments))
for k, v := range arguments {
m[k] = v
}
return mapping.UnmarshalJsonMap(m, req, mapping.WithStringValues())
case map[string]any:
return mapping.UnmarshalJsonMap(arguments, req)
default:
return fmt.Errorf("unsupported argument type: %T", arguments)
}
}

139
mcp/parser_test.go Normal file
View File

@@ -0,0 +1,139 @@
package mcp
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestParseArguments_MapStringString tests parsing map[string]string arguments
func TestParseArguments_MapStringString(t *testing.T) {
// Sample request struct to populate
type requestStruct struct {
Name string `json:"name"`
Message string `json:"message"`
Count int `json:"count"`
Enabled bool `json:"enabled"`
}
// Create test arguments
args := map[string]string{
"name": "test-name",
"message": "hello world",
"count": "42",
"enabled": "true",
}
// Create a target object to populate
var req requestStruct
// Parse the arguments
err := ParseArguments(args, &req)
// Verify results
assert.NoError(t, err, "Should parse map[string]string without error")
assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed")
assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed")
assert.Equal(t, 42, req.Count, "Count should be correctly parsed to int")
assert.True(t, req.Enabled, "Enabled should be correctly parsed to bool")
}
// TestParseArguments_MapStringAny tests parsing map[string]any arguments
func TestParseArguments_MapStringAny(t *testing.T) {
// Sample request struct to populate
type requestStruct struct {
Name string `json:"name"`
Message string `json:"message"`
Count int `json:"count"`
Enabled bool `json:"enabled"`
Tags []string `json:"tags"`
Metadata map[string]string `json:"metadata"`
}
// Create test arguments with mixed types
args := map[string]any{
"name": "test-name",
"message": "hello world",
"count": 42, // note: this is already an int
"enabled": true, // note: this is already a bool
"tags": []string{"tag1", "tag2"},
"metadata": map[string]string{
"key1": "value1",
"key2": "value2",
},
}
// Create a target object to populate
var req requestStruct
// Parse the arguments
err := ParseArguments(args, &req)
// Verify results
assert.NoError(t, err, "Should parse map[string]any without error")
assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed")
assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed")
assert.Equal(t, 42, req.Count, "Count should be correctly parsed")
assert.True(t, req.Enabled, "Enabled should be correctly parsed")
assert.Equal(t, []string{"tag1", "tag2"}, req.Tags, "Tags should be correctly parsed")
assert.Equal(t, map[string]string{
"key1": "value1",
"key2": "value2",
}, req.Metadata, "Metadata should be correctly parsed")
}
// TestParseArguments_UnsupportedType tests parsing with an unsupported type
func TestParseArguments_UnsupportedType(t *testing.T) {
// Sample request struct to populate
type requestStruct struct {
Name string `json:"name"`
Message string `json:"message"`
}
// Use an unsupported argument type (slice)
args := []string{"not", "a", "map"}
// Create a target object to populate
var req requestStruct
// Parse the arguments
err := ParseArguments(args, &req)
// Verify error is returned with correct message
assert.Error(t, err, "Should return error for unsupported type")
assert.Contains(t, err.Error(), "unsupported argument type", "Error should mention unsupported type")
assert.Contains(t, err.Error(), "[]string", "Error should include the actual type")
}
// TestParseArguments_EmptyMap tests parsing with empty maps
func TestParseArguments_EmptyMap(t *testing.T) {
// Sample request struct to populate
type requestStruct struct {
Name string `json:"name,optional"`
Message string `json:"message,optional"`
}
// Test empty map[string]string
t.Run("EmptyMapStringString", func(t *testing.T) {
args := map[string]string{}
var req requestStruct
err := ParseArguments(args, &req)
assert.NoError(t, err, "Should parse empty map[string]string without error")
assert.Empty(t, req.Name, "Name should be empty string")
assert.Empty(t, req.Message, "Message should be empty string")
})
// Test empty map[string]any
t.Run("EmptyMapStringAny", func(t *testing.T) {
args := map[string]any{}
var req requestStruct
err := ParseArguments(args, &req)
assert.NoError(t, err, "Should parse empty map[string]any without error")
assert.Empty(t, req.Name, "Name should be empty string")
assert.Empty(t, req.Message, "Message should be empty string")
})
}

View File

@@ -1,7 +1,7 @@
# Model Context Protocol (MCP) SDK Implementation
# Model Context Protocol (MCP) Implementation
## Overview
This package implements a Model Context Protocol (MCP) server in Go that facilitates real-time communication between AI models and clients using Server-Sent Events (SSE). The implementation provides a framework for building AI-assisted applications with bidirectional communication capabilities.
This package implements the Model Context Protocol (MCP) server specification in Go, providing a framework for real-time communication between AI models and clients using Server-Sent Events (SSE). The implementation follows the standardized protocol for building AI-assisted applications with bidirectional communication capabilities.
## Core Components
@@ -54,9 +54,817 @@ This package implements a Model Context Protocol (MCP) server in Go that facilit
## Usage
To create and use an MCP server, see the examples directory for practical implementation examples including:
- Tool registration and execution
- Static and dynamic prompt creation
- Resource handling with proper URI identification
- Embedded resources in prompt responses
- Client connection management
### Setting Up an MCP Server
To create and start an MCP server:
```go
package main
import (
"github.com/zeromicro/go-zero/core/conf"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/mcp"
)
func main() {
// Load configuration from YAML file
var c mcp.McpConf
conf.MustLoad("config.yaml", &c)
// Optional: Disable stats logging
logx.DisableStat()
// Create MCP server
server := mcp.NewMcpServer(c)
// Register tools, prompts, and resources (examples below)
// Start the server and ensure it's stopped on exit
defer server.Stop()
server.Start()
}
```
Sample configuration file (config.yaml):
```yaml
name: mcp-server
host: localhost
port: 8080
mcp:
name: my-mcp-server
messageTimeout: 30s # Timeout for tool calls
cors:
- http://localhost:3000 # Optional CORS configuration
```
### Registering Tools
Tools allow AI models to execute custom code through the MCP protocol.
#### Basic Tool Example:
```go
// Register a simple echo tool
echoTool := mcp.Tool{
Name: "echo",
Description: "Echoes back the message provided by the user",
InputSchema: mcp.InputSchema{
Properties: map[string]any{
"message": map[string]any{
"type": "string",
"description": "The message to echo back",
},
"prefix": map[string]any{
"type": "string",
"description": "Optional prefix to add to the echoed message",
"default": "Echo: ",
},
},
Required: []string{"message"},
},
Handler: func(ctx context.Context, params map[string]any) (any, error) {
var req struct {
Message string `json:"message"`
Prefix string `json:"prefix,optional"`
}
if err := mcp.ParseArguments(params, &req); err != nil {
return nil, fmt.Errorf("failed to parse params: %w", err)
}
prefix := "Echo: "
if len(req.Prefix) > 0 {
prefix = req.Prefix
}
return prefix + req.Message, nil
},
}
server.RegisterTool(echoTool)
```
#### Tool with Different Response Types:
```go
// Tool returning JSON data
dataTool := mcp.Tool{
Name: "data.generate",
Description: "Generates sample data in various formats",
InputSchema: mcp.InputSchema{
Properties: map[string]any{
"format": map[string]any{
"type": "string",
"description": "Format of data (json, text)",
"enum": []string{"json", "text"},
},
},
},
Handler: func(ctx context.Context, params map[string]any) (any, error) {
var req struct {
Format string `json:"format"`
}
if err := mcp.ParseArguments(params, &req); err != nil {
return nil, fmt.Errorf("failed to parse params: %w", err)
}
if req.Format == "json" {
// Return structured data
return map[string]any{
"items": []map[string]any{
{"id": 1, "name": "Item 1"},
{"id": 2, "name": "Item 2"},
},
"count": 2,
}, nil
}
// Default to text
return "Sample text data", nil
},
}
server.RegisterTool(dataTool)
```
#### Image Generation Tool Example:
```go
// Tool returning image content
imageTool := mcp.Tool{
Name: "image.generate",
Description: "Generates a simple image",
InputSchema: mcp.InputSchema{
Properties: map[string]any{
"type": map[string]any{
"type": "string",
"description": "Type of image to generate",
"default": "placeholder",
},
},
},
Handler: func(ctx context.Context, params map[string]any) (any, error) {
// Return image content directly
return mcp.ImageContent{
Data: "base64EncodedImageData...", // Base64 encoded image data
MimeType: "image/png",
}, nil
},
}
server.RegisterTool(imageTool)
```
#### Using ToolResult for Custom Outputs:
```go
// Tool that returns a custom ToolResult type
customResultTool := mcp.Tool{
Name: "custom.result",
Description: "Returns a custom formatted result",
InputSchema: mcp.InputSchema{
Properties: map[string]any{
"resultType": map[string]any{
"type": "string",
"enum": []string{"text", "image"},
},
},
},
Handler: func(ctx context.Context, params map[string]any) (any, error) {
var req struct {
ResultType string `json:"resultType"`
}
if err := mcp.ParseArguments(params, &req); err != nil {
return nil, fmt.Errorf("failed to parse params: %w", err)
}
if req.ResultType == "image" {
return mcp.ToolResult{
Type: mcp.ContentTypeImage,
Content: map[string]any{
"data": "base64EncodedImageData...",
"mimeType": "image/jpeg",
},
}, nil
}
// Default to text
return mcp.ToolResult{
Type: mcp.ContentTypeText,
Content: "This is a text result from ToolResult",
}, nil
},
}
server.RegisterTool(customResultTool)
```
### Registering Prompts
Prompts are reusable conversation templates for AI models.
#### Static Prompt Example:
```go
// Register a simple static prompt with placeholders
server.RegisterPrompt(mcp.Prompt{
Name: "hello",
Description: "A simple hello prompt",
Arguments: []mcp.PromptArgument{
{
Name: "name",
Description: "The name to greet",
Required: false,
},
},
Content: "Say hello to {{name}} and introduce yourself as an AI assistant.",
})
```
#### Dynamic Prompt with Handler Function:
```go
// Register a prompt with a dynamic handler function
server.RegisterPrompt(mcp.Prompt{
Name: "dynamic-prompt",
Description: "A prompt that uses a handler to generate dynamic content",
Arguments: []mcp.PromptArgument{
{
Name: "username",
Description: "User's name for personalized greeting",
Required: true,
},
{
Name: "topic",
Description: "Topic of expertise",
Required: true,
},
},
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
var req struct {
Username string `json:"username"`
Topic string `json:"topic"`
}
if err := mcp.ParseArguments(args, &req); err != nil {
return nil, fmt.Errorf("failed to parse args: %w", err)
}
// Create a user message
userMessage := mcp.PromptMessage{
Role: mcp.RoleUser,
Content: mcp.TextContent{
Text: fmt.Sprintf("Hello, I'm %s and I'd like to learn about %s.", req.Username, req.Topic),
},
}
// Create an assistant response with current time
currentTime := time.Now().Format(time.RFC1123)
assistantMessage := mcp.PromptMessage{
Role: mcp.RoleAssistant,
Content: mcp.TextContent{
Text: fmt.Sprintf("Hello %s! I'm an AI assistant and I'll help you learn about %s. The current time is %s.",
req.Username, req.Topic, currentTime),
},
}
// Return both messages as a conversation
return []mcp.PromptMessage{userMessage, assistantMessage}, nil
},
})
```
#### Multi-Message Prompt with Code Examples:
```go
// Register a prompt that provides code examples in different programming languages
server.RegisterPrompt(mcp.Prompt{
Name: "code-example",
Description: "Provides code examples in different programming languages",
Arguments: []mcp.PromptArgument{
{
Name: "language",
Description: "Programming language for the example",
Required: true,
},
{
Name: "complexity",
Description: "Complexity level (simple, medium, advanced)",
},
},
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
var req struct {
Language string `json:"language"`
Complexity string `json:"complexity,optional"`
}
if err := mcp.ParseArguments(args, &req); err != nil {
return nil, fmt.Errorf("failed to parse args: %w", err)
}
// Validate language
supportedLanguages := map[string]bool{"go": true, "python": true, "javascript": true, "rust": true}
if !supportedLanguages[req.Language] {
return nil, fmt.Errorf("unsupported language: %s", req.Language)
}
// Generate code example based on language and complexity
var codeExample string
switch req.Language {
case "go":
if req.Complexity == "simple" {
codeExample = `
package main
import "fmt"
func main() {
fmt.Println("Hello, World!")
}`
} else {
codeExample = `
package main
import (
"fmt"
"time"
)
func main() {
now := time.Now()
fmt.Printf("Hello, World! Current time is %s\n", now.Format(time.RFC3339))
}`
}
case "python":
// Python example code
if req.Complexity == "simple" {
codeExample = `
def greet(name):
return f"Hello, {name}!"
print(greet("World"))`
} else {
codeExample = `
import datetime
def greet(name, include_time=False):
message = f"Hello, {name}!"
if include_time:
message += f" Current time is {datetime.datetime.now().isoformat()}"
return message
print(greet("World", include_time=True))`
}
}
// Create messages array according to MCP spec
messages := []mcp.PromptMessage{
{
Role: mcp.RoleAssistant,
Content: mcp.TextContent{
Text: fmt.Sprintf("You are a helpful coding assistant specialized in %s programming.", req.Language),
},
},
{
Role: mcp.RoleUser,
Content: mcp.TextContent{
Text: fmt.Sprintf("Show me a %s example of a Hello World program in %s.", req.Complexity, req.Language),
},
},
{
Role: mcp.RoleAssistant,
Content: mcp.TextContent{
Text: fmt.Sprintf("Here's a %s example in %s:\n\n```%s%s\n```\n\nHow can I help you implement this?",
req.Complexity, req.Language, req.Language, codeExample),
},
},
}
return messages, nil
},
})
```
### Registering Resources
Resources provide access to external content such as files or generated data.
#### Basic Resource Example:
```go
// Register a static resource
server.RegisterResource(mcp.Resource{
Name: "example-document",
URI: "file:///example/document.txt",
Description: "An example document",
MimeType: "text/plain",
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
return mcp.ResourceContent{
URI: "file:///example/document.txt",
MimeType: "text/plain",
Text: "This is an example document content.",
}, nil
},
})
```
#### Dynamic Resource with Code Example:
```go
// Register a Go code resource with dynamic handler
server.RegisterResource(mcp.Resource{
Name: "go-example",
URI: "file:///project/src/main.go",
Description: "A simple Go example with multiple files",
MimeType: "text/x-go",
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
// Return ResourceContent with all required fields
return mcp.ResourceContent{
URI: "file:///project/src/main.go",
MimeType: "text/x-go",
Text: "package main\n\nimport (\n\t\"fmt\"\n\t\"./greeting\"\n)\n\nfunc main() {\n\tfmt.Println(greeting.Hello(\"world\"))\n}",
}, nil
},
})
// Register a companion file for the above example
server.RegisterResource(mcp.Resource{
Name: "go-greeting",
URI: "file:///project/src/greeting/greeting.go",
Description: "A greeting package for the Go example",
MimeType: "text/x-go",
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
return mcp.ResourceContent{
URI: "file:///project/src/greeting/greeting.go",
MimeType: "text/x-go",
Text: "package greeting\n\nfunc Hello(name string) string {\n\treturn \"Hello, \" + name + \"!\"\n}",
}, nil
},
})
```
#### Binary Resource Example:
```go
// Register a binary resource (like an image)
server.RegisterResource(mcp.Resource{
Name: "example-image",
URI: "file:///example/image.png",
Description: "An example image",
MimeType: "image/png",
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
// Read image from file or generate it
imageData := "base64EncodedImageData..." // Base64 encoded image data
return mcp.ResourceContent{
URI: "file:///example/image.png",
MimeType: "image/png",
Blob: imageData, // For binary data
}, nil
},
})
```
### Using Resources in Prompts
You can embed resources in prompt responses to create rich interactions with proper MCP-compliant structure:
```go
// Register a prompt that embeds a resource
server.RegisterPrompt(mcp.Prompt{
Name: "resource-example",
Description: "A prompt that embeds a resource",
Arguments: []mcp.PromptArgument{
{
Name: "file_type",
Description: "Type of file to show (rust or go)",
Required: true,
},
},
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
var req struct {
FileType string `json:"file_type"`
}
if err := mcp.ParseArguments(args, &req); err != nil {
return nil, fmt.Errorf("failed to parse args: %w", err)
}
var resourceURI, mimeType, fileContent string
if req.FileType == "rust" {
resourceURI = "file:///project/src/main.rs"
mimeType = "text/x-rust"
fileContent = "fn main() {\n println!(\"Hello world!\");\n}"
} else {
resourceURI = "file:///project/src/main.go"
mimeType = "text/x-go"
fileContent = "package main\n\nimport \"fmt\"\n\nfunc main() {\n\tfmt.Println(\"Hello, world!\")\n}"
}
// Create message with embedded resource using proper MCP format
return []mcp.PromptMessage{
{
Role: mcp.RoleUser,
Content: mcp.TextContent{
Text: fmt.Sprintf("Can you explain this %s code?", req.FileType),
},
},
{
Role: mcp.RoleAssistant,
Content: mcp.EmbeddedResource{
Type: mcp.ContentTypeResource,
Resource: struct {
URI string `json:"uri"`
MimeType string `json:"mimeType"`
Text string `json:"text,omitempty"`
Blob string `json:"blob,omitempty"`
}{
URI: resourceURI,
MimeType: mimeType,
Text: fileContent,
},
},
},
{
Role: mcp.RoleAssistant,
Content: mcp.TextContent{
Text: fmt.Sprintf("Above is a simple Hello World example in %s. Let me explain how it works.", req.FileType),
},
},
}, nil
},
})
```
### Multiple File Resources Example
```go
// Register a prompt that demonstrates embedding multiple resource files
server.RegisterPrompt(mcp.Prompt{
Name: "go-code-example",
Description: "A prompt that correctly embeds multiple resource files",
Arguments: []mcp.PromptArgument{
{
Name: "format",
Description: "How to format the code display",
},
},
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
var req struct {
Format string `json:"format,optional"`
}
if err := mcp.ParseArguments(args, &req); err != nil {
return nil, fmt.Errorf("failed to parse args: %w", err)
}
// Get the Go code for multiple files
var mainGoText string = "package main\n\nimport (\n\t\"fmt\"\n\t\"./greeting\"\n)\n\nfunc main() {\n\tfmt.Println(greeting.Hello(\"world\"))\n}"
var greetingGoText string = "package greeting\n\nfunc Hello(name string) string {\n\treturn \"Hello, \" + name + \"!\"\n}"
// Create message with properly formatted embedded resource per MCP spec
messages := []mcp.PromptMessage{
{
Role: mcp.RoleUser,
Content: mcp.TextContent{
Text: "Show me a simple Go example with proper imports.",
},
},
{
Role: mcp.RoleAssistant,
Content: mcp.TextContent{
Text: "Here's a simple Go example project:",
},
},
{
Role: mcp.RoleAssistant,
Content: mcp.EmbeddedResource{
Type: mcp.ContentTypeResource,
Resource: struct {
URI string `json:"uri"`
MimeType string `json:"mimeType"`
Text string `json:"text,omitempty"`
Blob string `json:"blob,omitempty"`
}{
URI: "file:///project/src/main.go",
MimeType: "text/x-go",
Text: mainGoText,
},
},
},
}
// Add explanation and additional file if requested
if req.Format == "with_explanation" {
messages = append(messages, mcp.PromptMessage{
Role: mcp.RoleAssistant,
Content: mcp.TextContent{
Text: "This example demonstrates a simple Go application with modular structure. The main.go file imports from a local 'greeting' package that provides the Hello function.",
},
})
// Also show the greeting.go file with correct resource format
messages = append(messages, mcp.PromptMessage{
Role: mcp.RoleAssistant,
Content: mcp.EmbeddedResource{
Type: mcp.ContentTypeResource,
Resource: struct {
URI string `json:"uri"`
MimeType string `json:"mimeType"`
Text string `json:"text,omitempty"`
Blob string `json:"blob,omitempty"`
}{
URI: "file:///project/src/greeting/greeting.go",
MimeType: "text/x-go",
Text: greetingGoText,
},
},
})
}
return messages, nil
},
})
```
### Complete Application Example
Here's a complete example demonstrating all the components:
```go
package main
import (
"context"
"fmt"
"log"
"time"
"github.com/zeromicro/go-zero/core/conf"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/mcp"
)
func main() {
// Load configuration
var c mcp.McpConf
if err := conf.Load("config.yaml", &c); err != nil {
log.Fatalf("Failed to load config: %v", err)
}
// Set up logging
logx.DisableStat()
// Create MCP server
server := mcp.NewMcpServer(c)
defer server.Stop()
// Register a simple echo tool
echoTool := mcp.Tool{
Name: "echo",
Description: "Echoes back the message provided by the user",
InputSchema: mcp.InputSchema{
Properties: map[string]any{
"message": map[string]any{
"type": "string",
"description": "The message to echo back",
},
"prefix": map[string]any{
"type": "string",
"description": "Optional prefix to add to the echoed message",
"default": "Echo: ",
},
},
Required: []string{"message"},
},
Handler: func(ctx context.Context, params map[string]any) (any, error) {
var req struct {
Message string `json:"message"`
Prefix string `json:"prefix,optional"`
}
if err := mcp.ParseArguments(params, &req); err != nil {
return nil, fmt.Errorf("failed to parse args: %w", err)
}
prefix := "Echo: "
if len(req.Prefix) > 0 {
prefix = req.Prefix
}
return prefix + req.Message, nil
},
}
server.RegisterTool(echoTool)
// Register a static prompt
server.RegisterPrompt(mcp.Prompt{
Name: "greeting",
Description: "A simple greeting prompt",
Arguments: []mcp.PromptArgument{
{
Name: "name",
Description: "The name to greet",
Required: true,
},
},
Content: "Hello {{name}}! How can I assist you today?",
})
// Register a dynamic prompt
server.RegisterPrompt(mcp.Prompt{
Name: "dynamic-prompt",
Description: "A prompt that uses a handler to generate dynamic content",
Arguments: []mcp.PromptArgument{
{
Name: "username",
Description: "User's name for personalized greeting",
Required: true,
},
{
Name: "topic",
Description: "Topic of expertise",
Required: true,
},
},
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
var req struct {
Username string `json:"username"`
Topic string `json:"topic"`
}
if err := mcp.ParseArguments(args, &req); err != nil {
return nil, fmt.Errorf("failed to parse args: %w", err)
}
// Create messages with current time
currentTime := time.Now().Format(time.RFC1123)
return []mcp.PromptMessage{
{
Role: mcp.RoleUser,
Content: mcp.TextContent{
Text: fmt.Sprintf("Hello, I'm %s and I'd like to learn about %s.", req.Username, req.Topic),
},
},
{
Role: mcp.RoleAssistant,
Content: mcp.TextContent{
Text: fmt.Sprintf("Hello %s! I'm an AI assistant and I'll help you learn about %s. The current time is %s.",
req.Username, req.Topic, currentTime),
},
},
}, nil
},
})
// Register a resource
server.RegisterResource(mcp.Resource{
Name: "example-doc",
URI: "file:///example/doc.txt",
Description: "An example document",
MimeType: "text/plain",
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
return mcp.ResourceContent{
URI: "file:///example/doc.txt",
MimeType: "text/plain",
Text: "This is the content of the example document.",
}, nil
},
})
// Start the server
fmt.Printf("Starting MCP server on %s:%d\n", c.Host, c.Port)
server.Start()
}
```
## Error Handling
The MCP implementation provides comprehensive error handling:
- Tool execution errors are properly reported back to clients
- Missing or invalid parameters are detected and reported with appropriate error codes
- Resource and prompt lookup failures are handled gracefully
- Timeout handling for long-running tool executions using context
- Panic recovery to prevent server crashes
## Advanced Features
- **Annotations**: Add audience and priority metadata to content
- **Content Types**: Support for text, images, audio, and other content formats
- **Embedded Resources**: Include file resources directly in prompt responses
- **Context Awareness**: All handlers receive context.Context for timeout and cancellation support
- **Progress Tokens**: Support for tracking progress of long-running operations
- **Customizable Timeouts**: Configure execution timeouts for tools and operations
## Performance Considerations
- Tool execution runs with configurable timeouts to prevent blocking
- Efficient client tracking and cleanup to prevent resource leaks
- Proper concurrency handling with mutex protection for shared resources
- Buffered message channels to prevent blocking on client message delivery

View File

@@ -42,14 +42,14 @@ func NewMcpServer(c McpConf) McpServer {
Method: http.MethodGet,
Path: s.conf.Mcp.SseEndpoint,
Handler: s.handleSSE,
}, rest.WithSSE())
}, rest.WithSSE(), rest.WithTimeout(c.Mcp.SseTimeout))
// JSON-RPC message endpoint for regular requests
s.server.AddRoute(rest.Route{
Method: http.MethodPost,
Path: s.conf.Mcp.MessageEndpoint,
Handler: s.handleRequest,
})
}, rest.WithTimeout(c.Mcp.MessageTimeout))
return s
}
@@ -182,21 +182,23 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
// Always allow initialize and notifications/initialized regardless of client state
if req.Method == methodInitialize {
logx.Infof("Processing initialize request with ID: %d", req.ID)
s.processInitialize(client, req)
s.processInitialize(r.Context(), client, req)
logx.Infof("Sent initialize response for ID: %d, waiting for notifications/initialized", req.ID)
return
} else if req.Method == methodNotificationsInitialized {
// Handle initialized notification
logx.Info("Received notifications/initialized notification")
if !isNotification {
s.sendErrorResponse(client, req.ID, "Method should be used as a notification", errCodeInvalidRequest)
s.sendErrorResponse(r.Context(), client, req.ID,
"Method should be used as a notification", errCodeInvalidRequest)
return
}
s.processNotificationInitialized(client)
return
} else if !client.initialized && req.Method != methodNotificationsCancelled {
// Block most requests until client is initialized (except for cancellations)
s.sendErrorResponse(client, req.ID, "Client not fully initialized, waiting for notifications/initialized",
s.sendErrorResponse(r.Context(), client, req.ID,
"Client not fully initialized, waiting for notifications/initialized",
errCodeClientNotInitialized)
return
}
@@ -205,41 +207,41 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
switch req.Method {
case methodToolsCall:
logx.Infof("Received tools call request with ID: %d", req.ID)
s.processToolCall(client, req)
s.processToolCall(r.Context(), client, req)
logx.Infof("Sent tools call response for ID: %d", req.ID)
case methodToolsList:
logx.Infof("Processing tools/list request with ID: %d", req.ID)
s.processListTools(client, req)
s.processListTools(r.Context(), client, req)
logx.Infof("Sent tools/list response for ID: %d", req.ID)
case methodPromptsList:
logx.Infof("Processing prompts/list request with ID: %d", req.ID)
s.processListPrompts(client, req)
s.processListPrompts(r.Context(), client, req)
logx.Infof("Sent prompts/list response for ID: %d", req.ID)
case methodPromptsGet:
logx.Infof("Processing prompts/get request with ID: %d", req.ID)
s.processGetPrompt(client, req)
s.processGetPrompt(r.Context(), client, req)
logx.Infof("Sent prompts/get response for ID: %d", req.ID)
case methodResourcesList:
logx.Infof("Processing resources/list request with ID: %d", req.ID)
s.processListResources(client, req)
s.processListResources(r.Context(), client, req)
logx.Infof("Sent resources/list response for ID: %d", req.ID)
case methodResourcesRead:
logx.Infof("Processing resources/read request with ID: %d", req.ID)
s.processResourcesRead(client, req)
s.processResourcesRead(r.Context(), client, req)
logx.Infof("Sent resources/read response for ID: %d", req.ID)
case methodResourcesSubscribe:
logx.Infof("Processing resources/subscribe request with ID: %d", req.ID)
s.processResourceSubscribe(client, req)
s.processResourceSubscribe(r.Context(), client, req)
logx.Infof("Sent resources/subscribe response for ID: %d", req.ID)
case methodPing:
logx.Infof("Processing ping request with ID: %d", req.ID)
s.processPing(client, req)
s.processPing(r.Context(), client, req)
case methodNotificationsCancelled:
logx.Infof("Received notifications/cancelled notification: %v", req.Params)
s.processNotificationCancelled(client, req)
logx.Infof("Received notifications/cancelled notification: %d", req.ID)
s.processNotificationCancelled(r.Context(), client, req)
default:
logx.Infof("Unknown method: %s", req.Method)
s.sendErrorResponse(client, req.ID, "Method not found", errCodeMethodNotFound)
logx.Infof("Unknown method: %s from client: %d", req.Method, req.ID)
s.sendErrorResponse(r.Context(), client, req.ID, "Method not found", errCodeMethodNotFound)
}
}
@@ -321,7 +323,7 @@ func (s *sseMcpServer) handleSSE(w http.ResponseWriter, r *http.Request) {
}
flusher.Flush()
case <-r.Context().Done():
// Client disconnected or request was canceled
// Client disconnected or request was canceled or timed out
logx.Infof("Client %s disconnected: context done", sessionId)
return
}
@@ -329,7 +331,7 @@ func (s *sseMcpServer) handleSSE(w http.ResponseWriter, r *http.Request) {
}
// processInitialize processes the initialize request
func (s *sseMcpServer) processInitialize(client *mcpClient, req Request) {
func (s *sseMcpServer) processInitialize(ctx context.Context, client *mcpClient, req Request) {
// Create a proper JSON-RPC response that preserves the client's request ID
result := initializationResponse{
ProtocolVersion: s.conf.Mcp.ProtocolVersion,
@@ -362,11 +364,11 @@ func (s *sseMcpServer) processInitialize(client *mcpClient, req Request) {
client.initialized = true
// Send response with client's original request ID
s.sendResponse(client, req.ID, result)
s.sendResponse(ctx, client, req.ID, result)
}
// processListTools processes the tools/list request
func (s *sseMcpServer) processListTools(client *mcpClient, req Request) {
func (s *sseMcpServer) processListTools(ctx context.Context, client *mcpClient, req Request) {
// Extract pagination params if any
var nextCursor string
var progressToken any
@@ -390,6 +392,9 @@ func (s *sseMcpServer) processListTools(client *mcpClient, req Request) {
var toolsList []Tool
s.toolsLock.Lock()
for _, tool := range s.tools {
if len(tool.InputSchema.Type) == 0 {
tool.InputSchema.Type = ContentTypeObject
}
toolsList = append(toolsList, tool)
}
s.toolsLock.Unlock()
@@ -405,15 +410,15 @@ func (s *sseMcpServer) processListTools(client *mcpClient, req Request) {
// Add meta information if progress token was provided
if progressToken != nil {
result.Result.Meta = map[string]any{
"progressToken": progressToken,
progressTokenKey: progressToken,
}
}
s.sendResponse(client, req.ID, result)
s.sendResponse(ctx, client, req.ID, result)
}
// processListPrompts processes the prompts/list request
func (s *sseMcpServer) processListPrompts(client *mcpClient, req Request) {
func (s *sseMcpServer) processListPrompts(ctx context.Context, client *mcpClient, req Request) {
// Extract pagination params if any
var nextCursor string
if req.Params != nil {
@@ -447,11 +452,11 @@ func (s *sseMcpServer) processListPrompts(client *mcpClient, req Request) {
NextCursor: nextCursor,
}
s.sendResponse(client, req.ID, result)
s.sendResponse(ctx, client, req.ID, result)
}
// processListResources processes the resources/list request
func (s *sseMcpServer) processListResources(client *mcpClient, req Request) {
func (s *sseMcpServer) processListResources(ctx context.Context, client *mcpClient, req Request) {
// Extract pagination params if any
var nextCursor string
var progressToken any
@@ -493,15 +498,15 @@ func (s *sseMcpServer) processListResources(client *mcpClient, req Request) {
// Add meta information if progress token was provided
if progressToken != nil {
result.Result.Meta = map[string]any{
"progressToken": progressToken,
progressTokenKey: progressToken,
}
}
s.sendResponse(client, req.ID, result)
s.sendResponse(ctx, client, req.ID, result)
}
// processGetPrompt processes the prompts/get request
func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
func (s *sseMcpServer) processGetPrompt(ctx context.Context, client *mcpClient, req Request) {
type GetPromptParams struct {
Name string `json:"name"`
Arguments map[string]string `json:"arguments,omitempty"`
@@ -509,7 +514,7 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
var params GetPromptParams
if err := json.Unmarshal(req.Params, &params); err != nil {
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
return
}
@@ -519,7 +524,7 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
s.promptsLock.Unlock()
if !exists {
message := fmt.Sprintf("Prompt '%s' not found", params.Name)
s.sendErrorResponse(client, req.ID, message, errCodeInvalidParams)
s.sendErrorResponse(ctx, client, req.ID, message, errCodeInvalidParams)
return
}
@@ -529,12 +534,15 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
missingArgs := validatePromptArguments(prompt, params.Arguments)
if len(missingArgs) > 0 {
message := fmt.Sprintf("Missing required arguments: %s", strings.Join(missingArgs, ", "))
s.sendErrorResponse(client, req.ID, message, errCodeInvalidParams)
s.sendErrorResponse(ctx, client, req.ID, message, errCodeInvalidParams)
return
}
// Apply default values for missing optional arguments
args := applyDefaultArguments(prompt, params.Arguments)
// Ensure arguments are initialized to an empty map if nil
if params.Arguments == nil {
params.Arguments = make(map[string]string)
}
args := params.Arguments
// Generate messages using handler or static content
var messages []PromptMessage
@@ -542,17 +550,17 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
if prompt.Handler != nil {
// Use dynamic handler to generate messages
logx.Info("Using prompt handler to generate content")
messages, err = prompt.Handler(args)
messages, err = prompt.Handler(ctx, args)
if err != nil {
logx.Errorf("Error from prompt handler: %v", err)
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Error generating prompt content: %v", err), errCodeInternalError)
s.sendErrorResponse(ctx, client, req.ID,
fmt.Sprintf("Error generating prompt content: %v", err), errCodeInternalError)
return
}
} else {
// No handler, generate messages from static content
var messageText string
if prompt.Content != "" {
if len(prompt.Content) > 0 {
messageText = prompt.Content
// Apply argument substitutions to static content
@@ -560,21 +568,13 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
placeholder := fmt.Sprintf("{{%s}}", key)
messageText = strings.Replace(messageText, placeholder, value, -1)
}
} else {
// No content, use a default fallback
topic := "this topic"
if t, ok := args["topic"]; ok && t != "" {
topic = t
}
messageText = fmt.Sprintf("Tell me about %s", topic)
}
// Create a single user message with the content
messages = []PromptMessage{
{
Role: roleUser,
Role: RoleUser,
Content: TextContent{
Type: contentTypeText,
Text: messageText,
},
},
@@ -587,49 +587,14 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
Messages []PromptMessage `json:"messages"`
}{
Description: prompt.Description,
Messages: messages,
Messages: toTypedPromptMessages(messages),
}
s.sendResponse(client, req.ID, result)
}
// validatePromptArguments checks if all required arguments are provided
// Returns a list of missing required arguments
func validatePromptArguments(prompt Prompt, providedArgs map[string]string) []string {
var missingArgs []string
for _, arg := range prompt.Arguments {
if arg.Required {
if value, exists := providedArgs[arg.Name]; !exists || len(value) == 0 {
missingArgs = append(missingArgs, arg.Name)
}
}
}
return missingArgs
}
// applyDefaultArguments adds default values for missing optional arguments
func applyDefaultArguments(prompt Prompt, providedArgs map[string]string) map[string]string {
result := make(map[string]string)
// Copy all provided arguments
for k, v := range providedArgs {
result[k] = v
}
// Add defaults for missing arguments
for _, arg := range prompt.Arguments {
if _, exists := result[arg.Name]; !exists && arg.Default != "" {
result[arg.Name] = arg.Default
}
}
return result
s.sendResponse(ctx, client, req.ID, result)
}
// processToolCall processes the tools/call request
func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
func (s *sseMcpServer) processToolCall(ctx context.Context, client *mcpClient, req Request) {
var toolCallParams struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments,omitempty"`
@@ -642,7 +607,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
// If it's a RawMessage (JSON), unmarshal it
if err := json.Unmarshal(req.Params, &toolCallParams); err != nil {
logx.Errorf("Failed to unmarshal tool call params: %v", err)
s.sendErrorResponse(client, req.ID, "Invalid tool call parameters", errCodeInvalidParams)
s.sendErrorResponse(ctx, client, req.ID, "Invalid tool call parameters", errCodeInvalidParams)
return
}
@@ -654,15 +619,11 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
tool, exists := s.tools[toolCallParams.Name]
s.toolsLock.Unlock()
if !exists {
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Tool '%s' not found",
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Tool '%s' not found",
toolCallParams.Name), errCodeInvalidParams)
return
}
// Create a context with the configured timeout
ctx, cancel := context.WithTimeout(context.Background(), s.conf.Mcp.ToolTimeout)
defer cancel()
// Log parameters before execution
logx.Infof("Executing tool '%s' with arguments: %#v", toolCallParams.Name, toolCallParams.Arguments)
@@ -671,6 +632,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
var err error
// Create a channel to receive the result
// make sure to have 1 size buffer to avoid channel leak if timeout
resultCh := make(chan struct {
result any
err error
@@ -678,7 +640,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
// Execute the tool handler in a goroutine
go func() {
toolResult, toolErr := tool.Handler(toolCallParams.Arguments)
toolResult, toolErr := tool.Handler(ctx, toolCallParams.Arguments)
resultCh <- struct {
result any
err error
@@ -694,9 +656,9 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
result = res.result
err = res.err
case <-ctx.Done():
// Handle timeout
logx.Errorf("Tool execution timed out after %v: %s", s.conf.Mcp.ToolTimeout, toolCallParams.Name)
s.sendErrorResponse(client, req.ID, "Tool execution timed out", errCodeTimeout)
// Handle request timeout
logx.Errorf("Tool execution timed out after %v: %s", s.conf.Mcp.MessageTimeout, toolCallParams.Name)
s.sendErrorResponse(ctx, client, req.ID, "Tool execution timed out", errCodeTimeout)
return
}
@@ -710,7 +672,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
// Add meta information if progress token was provided
if progressToken != nil {
callToolResult.Result.Meta = map[string]any{
"progressToken": progressToken,
progressTokenKey: progressToken,
}
}
@@ -722,12 +684,11 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
callToolResult.Content = []any{
TextContent{
Type: contentTypeText,
Text: fmt.Sprintf("Error: %v", err),
},
}
callToolResult.IsError = true
s.sendResponse(client, req.ID, callToolResult)
s.sendResponse(ctx, client, req.ID, callToolResult)
return
}
@@ -736,10 +697,9 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
case string:
// Simple string becomes text content
callToolResult.Content = append(callToolResult.Content, TextContent{
Type: contentTypeText,
Text: v,
Annotations: &Annotations{
Audience: []roleType{roleUser, roleAssistant},
Audience: []RoleType{RoleUser, RoleAssistant},
},
})
case map[string]any:
@@ -749,69 +709,63 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
jsonStr = []byte(err.Error())
}
callToolResult.Content = append(callToolResult.Content, TextContent{
Type: contentTypeText,
Text: string(jsonStr),
Annotations: &Annotations{
Audience: []roleType{roleUser, roleAssistant},
Audience: []RoleType{RoleUser, RoleAssistant},
},
})
case TextContent:
// Direct TextContent object
callToolResult.Content = append(callToolResult.Content, v)
case ImageContent:
// Direct ImageContent object
callToolResult.Content = append(callToolResult.Content, v)
case []any:
// Array of content items
callToolResult.Content = v
case ToolResult:
// Handle legacy ToolResult type
switch v.Type {
case contentTypeText:
case ContentTypeText:
callToolResult.Content = append(callToolResult.Content, TextContent{
Type: contentTypeText,
Text: fmt.Sprintf("%v", v.Content),
Annotations: &Annotations{
Audience: []roleType{roleUser, roleAssistant},
Audience: []RoleType{RoleUser, RoleAssistant},
},
})
case contentTypeImage:
case ContentTypeImage:
if imgData, ok := v.Content.(map[string]any); ok {
callToolResult.Content = append(callToolResult.Content, ImageContent{
Type: contentTypeImage,
Data: fmt.Sprintf("%v", imgData["data"]),
MimeType: fmt.Sprintf("%v", imgData["mimeType"]),
})
}
default:
callToolResult.Content = append(callToolResult.Content, TextContent{
Type: contentTypeText,
Text: fmt.Sprintf("%v", v.Content),
Annotations: &Annotations{
Audience: []roleType{roleUser, roleAssistant},
Audience: []RoleType{RoleUser, RoleAssistant},
},
})
}
default:
// For any other type, convert to string
callToolResult.Content = append(callToolResult.Content, TextContent{
Type: contentTypeText,
Text: fmt.Sprintf("%v", v),
Annotations: &Annotations{
Audience: []roleType{roleUser, roleAssistant},
Audience: []RoleType{RoleUser, RoleAssistant},
},
})
}
callToolResult.Content = toTypedContents(callToolResult.Content)
logx.Infof("Tool call result: %#v", callToolResult)
s.sendResponse(client, req.ID, callToolResult)
s.sendResponse(ctx, client, req.ID, callToolResult)
}
// processResourcesRead processes the resources/read request
func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
func (s *sseMcpServer) processResourcesRead(ctx context.Context, client *mcpClient, req Request) {
var params ResourceReadParams
if err := json.Unmarshal(req.Params, &params); err != nil {
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
return
}
@@ -821,7 +775,7 @@ func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
s.resourcesLock.Unlock()
if !exists {
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
params.URI), errCodeResourceNotFound)
return
}
@@ -837,14 +791,14 @@ func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
},
},
}
s.sendResponse(client, req.ID, result)
s.sendResponse(ctx, client, req.ID, result)
return
}
// Execute the resource handler
content, err := resource.Handler()
content, err := resource.Handler(ctx)
if err != nil {
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Error reading resource: %v", err),
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Error reading resource: %v", err),
errCodeInternalError)
return
}
@@ -865,14 +819,14 @@ func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
Contents: []ResourceContent{content},
}
s.sendResponse(client, req.ID, result)
s.sendResponse(ctx, client, req.ID, result)
}
// processResourceSubscribe processes the resources/subscribe request
func (s *sseMcpServer) processResourceSubscribe(client *mcpClient, req Request) {
func (s *sseMcpServer) processResourceSubscribe(ctx context.Context, client *mcpClient, req Request) {
var params ResourceSubscribeParams
if err := json.Unmarshal(req.Params, &params); err != nil {
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
return
}
@@ -882,19 +836,17 @@ func (s *sseMcpServer) processResourceSubscribe(client *mcpClient, req Request)
s.resourcesLock.Unlock()
if !exists {
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
params.URI), errCodeResourceNotFound)
return
}
// Send success response for the subscription
s.sendResponse(client, req.ID, struct{}{})
logx.Infof("Client %s subscribed to resource '%s'", client.id, params.URI)
s.sendResponse(ctx, client, req.ID, struct{}{})
}
// processNotificationCancelled processes the notifications/cancelled notification
func (s *sseMcpServer) processNotificationCancelled(client *mcpClient, req Request) {
func (s *sseMcpServer) processNotificationCancelled(ctx context.Context, client *mcpClient, req Request) {
// Extract the requestId that was canceled
type CancelParams struct {
RequestId int64 `json:"requestId"`
@@ -918,18 +870,17 @@ func (s *sseMcpServer) processNotificationInitialized(client *mcpClient) {
}
// processPing processes the ping request and responds immediately
func (s *sseMcpServer) processPing(client *mcpClient, req Request) {
func (s *sseMcpServer) processPing(ctx context.Context, client *mcpClient, req Request) {
// A ping request should simply respond with an empty result to confirm the server is alive
logx.Infof("Received ping request with ID: %d", req.ID)
// Send an empty response with client's original request ID
s.sendResponse(client, req.ID, struct{}{})
logx.Infof("Sent ping response for ID: %d", req.ID)
s.sendResponse(ctx, client, req.ID, struct{}{})
}
// sendErrorResponse sends an error response via the SSE channel
func (s *sseMcpServer) sendErrorResponse(client *mcpClient, id int64, message string, code int) {
func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
id int64, message string, code int) {
errorResponse := struct {
JsonRpc string `json:"jsonrpc"`
ID int64 `json:"id"`
@@ -949,11 +900,17 @@ func (s *sseMcpServer) sendErrorResponse(client *mcpClient, id int64, message st
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
logx.Infof("Sending error for ID %d: %s", id, sseMessage)
client.channel <- sseMessage
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
select {
case client.channel <- sseMessage:
default:
// Channel buffer is full, log warning and continue
logx.Infof("Client %s channel is full while sending error response with ID %d", client.id, id)
}
}
// sendResponse sends a success response via the SSE channel
func (s *sseMcpServer) sendResponse(client *mcpClient, id int64, result any) {
func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id int64, result any) {
response := Response{
JsonRpc: jsonRpcVersion,
ID: id,
@@ -962,7 +919,7 @@ func (s *sseMcpServer) sendResponse(client *mcpClient, id int64, result any) {
jsonData, err := json.Marshal(response)
if err != nil {
s.sendErrorResponse(client, id, "Failed to marshal response", errCodeInternalError)
s.sendErrorResponse(ctx, client, id, "Failed to marshal response", errCodeInternalError)
return
}
@@ -970,5 +927,11 @@ func (s *sseMcpServer) sendResponse(client *mcpClient, id int64, result any) {
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
logx.Infof("Sending response for ID %d: %s", id, sseMessage)
client.channel <- sseMessage
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
select {
case client.channel <- sseMessage:
default:
// Channel buffer is full, log warning and continue
logx.Infof("Client %s channel is full while sending response with ID %d", client.id, id)
}
}

View File

@@ -34,7 +34,7 @@ host: localhost
port: 8080
mcp:
name: mcp-test-server
toolTimeout: 5s
messageTimeout: 5s
`
var c McpConf
@@ -82,7 +82,6 @@ func (m *mockMcpServer) registerExampleTool() {
Name: "test.tool",
Description: "A test tool",
InputSchema: InputSchema{
Type: "object",
Properties: map[string]any{
"input": map[string]any{
"type": "string",
@@ -91,7 +90,7 @@ func (m *mockMcpServer) registerExampleTool() {
},
Required: []string{"input"},
},
Handler: func(params map[string]any) (any, error) {
Handler: func(ctx context.Context, params map[string]any) (any, error) {
input, ok := params["input"].(string)
if !ok {
return nil, fmt.Errorf("invalid input parameter")
@@ -135,7 +134,7 @@ port: 8080
mcp:
cors:
- http://localhost:3000
toolTimeout: 5s
messageTimeout: 5s
`
var c McpConf
@@ -186,7 +185,6 @@ func TestRegisterTool(t *testing.T) {
Name: "example.tool",
Description: "An example tool",
InputSchema: InputSchema{
Type: "object",
Properties: map[string]any{
"input": map[string]any{
"type": "string",
@@ -194,7 +192,7 @@ func TestRegisterTool(t *testing.T) {
},
},
},
Handler: func(params map[string]any) (any, error) {
Handler: func(ctx context.Context, params map[string]any) (any, error) {
return "result", nil
},
}
@@ -280,7 +278,7 @@ func TestToolsList(t *testing.T) {
}
// Process the tool call
mock.server.processListTools(client, req)
mock.server.processListTools(context.Background(), client, req)
// Get the response from the client's channel
select {
@@ -328,7 +326,7 @@ func TestToolCallBasic(t *testing.T) {
}
// Process the tool call
mock.server.processToolCall(client, req)
mock.server.processToolCall(context.Background(), client, req)
// Get the response from the client's channel
select {
@@ -355,8 +353,7 @@ func TestToolCallBasic(t *testing.T) {
// Verify the response content
assert.Len(t, parsed.Result.Content, 1, "Response should contain one content item")
assert.Equal(t, "text", parsed.Result.Content[0]["type"], "Content type should be text")
assert.Equal(t, "Processed: test-input", parsed.Result.Content[0]["text"], "Tool result incorrect")
assert.Equal(t, "Processed: test-input", parsed.Result.Content[0][ContentTypeText], "Tool result incorrect")
assert.False(t, parsed.Result.IsError, "Response should not be an error")
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for tool call response")
@@ -373,10 +370,9 @@ func TestToolCallMapResult(t *testing.T) {
Name: "map.tool",
Description: "A tool that returns a map result",
InputSchema: InputSchema{
Type: "object",
Properties: map[string]any{},
},
Handler: func(params map[string]any) (any, error) {
Handler: func(ctx context.Context, params map[string]any) (any, error) {
// Return a complex nested map structure
return map[string]any{
"string": "value",
@@ -417,7 +413,7 @@ func TestToolCallMapResult(t *testing.T) {
}
// Process the tool call
mock.server.processToolCall(client, req)
mock.server.processToolCall(context.Background(), client, req)
// Get the response from the client's channel
select {
@@ -445,13 +441,8 @@ func TestToolCallMapResult(t *testing.T) {
firstItem, ok := content[0].(map[string]any)
require.True(t, ok, "First content item should be an object")
// Verify it's a text content
contentType, ok := firstItem["type"].(string)
require.True(t, ok, "Content should have a type")
assert.Equal(t, "text", contentType, "Content type should be text")
// Get the text content which should be our JSON
text, ok := firstItem["text"].(string)
text, ok := firstItem[ContentTypeText].(string)
require.True(t, ok, "Content should have text")
// Verify the text is valid JSON and contains our data
@@ -496,10 +487,9 @@ func TestToolCallArrayResult(t *testing.T) {
Name: "array.tool",
Description: "A tool that returns an array result",
InputSchema: InputSchema{
Type: "object",
Properties: map[string]any{},
},
Handler: func(params map[string]any) (any, error) {
Handler: func(ctx context.Context, params map[string]any) (any, error) {
// Return an array of mixed content types
return []any{
"string item",
@@ -536,7 +526,7 @@ func TestToolCallArrayResult(t *testing.T) {
}
// Process the tool call
mock.server.processToolCall(client, req)
mock.server.processToolCall(context.Background(), client, req)
// Get the response from the client's channel
select {
@@ -574,16 +564,14 @@ func TestToolCallTextContentResult(t *testing.T) {
Name: "text.content.tool",
Description: "A tool that returns a TextContent result",
InputSchema: InputSchema{
Type: "object",
Properties: map[string]any{},
},
Handler: func(params map[string]any) (any, error) {
Handler: func(ctx context.Context, params map[string]any) (any, error) {
// Return a TextContent object directly
return TextContent{
Type: "text",
Text: "This is a direct TextContent result",
Annotations: &Annotations{
Audience: []roleType{roleUser, roleAssistant},
Audience: []RoleType{RoleUser, RoleAssistant},
Priority: func() *float64 { p := 0.9; return &p }(),
},
}, nil
@@ -614,7 +602,7 @@ func TestToolCallTextContentResult(t *testing.T) {
}
// Process the tool call
mock.server.processToolCall(client, req)
mock.server.processToolCall(context.Background(), client, req)
// Get the response from the client's channel
select {
@@ -642,16 +630,6 @@ func TestToolCallTextContentResult(t *testing.T) {
firstItem, ok := content[0].(map[string]any)
require.True(t, ok, "First content item should be an object")
// Verify it's a text content
contentType, ok := firstItem["type"].(string)
require.True(t, ok, "Content should have a type")
assert.Equal(t, "text", contentType, "Content type should be text")
// Check text content
text, ok := firstItem["text"].(string)
require.True(t, ok, "Content should have text")
assert.Equal(t, "This is a direct TextContent result", text, "Text content should match")
// Check annotations
annotations, ok := firstItem["annotations"].(map[string]any)
require.True(t, ok, "Should have annotations")
@@ -679,13 +657,11 @@ func TestToolCallImageContentResult(t *testing.T) {
Name: "image.content.tool",
Description: "A tool that returns an ImageContent result",
InputSchema: InputSchema{
Type: "object",
Properties: map[string]any{},
},
Handler: func(params map[string]any) (any, error) {
Handler: func(ctx context.Context, params map[string]any) (any, error) {
// Return an ImageContent object directly
return ImageContent{
Type: "image",
Data: "dGVzdCBpbWFnZSBkYXRhIChiYXNlNjQgZW5jb2RlZCk=", // "test image data (base64 encoded)" in base64
MimeType: "image/png",
}, nil
@@ -716,7 +692,7 @@ func TestToolCallImageContentResult(t *testing.T) {
}
// Process the tool call
mock.server.processToolCall(client, req)
mock.server.processToolCall(context.Background(), client, req)
// Get the response from the client's channel
select {
@@ -744,11 +720,6 @@ func TestToolCallImageContentResult(t *testing.T) {
firstItem, ok := content[0].(map[string]any)
require.True(t, ok, "First content item should be an object")
// Verify it's an image content
contentType, ok := firstItem["type"].(string)
require.True(t, ok, "Content should have a type")
assert.Equal(t, "image", contentType, "Content type should be image")
// Check image data
data, ok := firstItem["data"].(string)
require.True(t, ok, "Content should have data")
@@ -773,12 +744,12 @@ func TestToolCallToolResultType(t *testing.T) {
Name: "toolresult.tool",
Description: "A tool that returns a ToolResult object",
InputSchema: InputSchema{
Type: "object",
Type: ContentTypeObject,
Properties: map[string]any{},
},
Handler: func(params map[string]any) (any, error) {
Handler: func(ctx context.Context, params map[string]any) (any, error) {
return ToolResult{
Type: "text",
Type: ContentTypeText,
Content: "This is a ToolResult with text content type",
}, nil
},
@@ -790,10 +761,10 @@ func TestToolCallToolResultType(t *testing.T) {
Name: "toolresult.image.tool",
Description: "A tool that returns a ToolResult with image content",
InputSchema: InputSchema{
Type: "object",
Type: ContentTypeObject,
Properties: map[string]any{},
},
Handler: func(params map[string]any) (any, error) {
Handler: func(ctx context.Context, params map[string]any) (any, error) {
return ToolResult{
Type: "image",
Content: map[string]any{
@@ -810,10 +781,9 @@ func TestToolCallToolResultType(t *testing.T) {
Name: "toolresult.audio.tool",
Description: "A tool that returns a ToolResult with audio content",
InputSchema: InputSchema{
Type: "object",
Properties: map[string]any{},
},
Handler: func(params map[string]any) (any, error) {
Handler: func(ctx context.Context, params map[string]any) (any, error) {
// Test with image type
return ToolResult{
Type: "audio",
@@ -831,10 +801,9 @@ func TestToolCallToolResultType(t *testing.T) {
Name: "toolresult.int.tool",
Description: "A tool that returns a ToolResult with int content",
InputSchema: InputSchema{
Type: "object",
Properties: map[string]any{},
},
Handler: func(params map[string]any) (any, error) {
Handler: func(ctx context.Context, params map[string]any) (any, error) {
return 2, nil
},
}
@@ -845,10 +814,9 @@ func TestToolCallToolResultType(t *testing.T) {
Name: "toolresult.bad.tool",
Description: "A tool that returns a ToolResult with bad content",
InputSchema: InputSchema{
Type: "object",
Properties: map[string]any{},
},
Handler: func(params map[string]any) (any, error) {
Handler: func(ctx context.Context, params map[string]any) (any, error) {
return map[string]any{
"type": "custom",
"data": make(chan int),
@@ -881,7 +849,7 @@ func TestToolCallToolResultType(t *testing.T) {
}
// Process the tool call
mock.server.processToolCall(client, req)
mock.server.processToolCall(context.Background(), client, req)
// Get the response from the client's channel
select {
@@ -909,13 +877,8 @@ func TestToolCallToolResultType(t *testing.T) {
firstItem, ok := content[0].(map[string]any)
require.True(t, ok, "First content item should be an object")
// Verify it's a text content
contentType, ok := firstItem["type"].(string)
require.True(t, ok, "Content should have a type")
assert.Equal(t, "text", contentType, "Content type should be text")
// Check text content
text, ok := firstItem["text"].(string)
text, ok := firstItem[ContentTypeText].(string)
require.True(t, ok, "Content should have text")
assert.Equal(t, "This is a ToolResult with text content type", text, "Text content should match")
@@ -947,7 +910,7 @@ func TestToolCallToolResultType(t *testing.T) {
}
// Process the tool call
mock.server.processToolCall(client, req)
mock.server.processToolCall(context.Background(), client, req)
// Get the response from the client's channel
select {
@@ -975,11 +938,6 @@ func TestToolCallToolResultType(t *testing.T) {
firstItem, ok := content[0].(map[string]any)
require.True(t, ok, "First content item should be an object")
// Verify it's an image content
contentType, ok := firstItem["type"].(string)
require.True(t, ok, "Content should have a type")
assert.Equal(t, "image", contentType, "Content type should be image")
// Check image data and mime type
data, ok := firstItem["data"].(string)
require.True(t, ok, "Content should have data")
@@ -1017,7 +975,7 @@ func TestToolCallToolResultType(t *testing.T) {
}
// Process the tool call
mock.server.processToolCall(client, req)
mock.server.processToolCall(context.Background(), client, req)
// Get the response from the client's channel
select {
@@ -1040,15 +998,6 @@ func TestToolCallToolResultType(t *testing.T) {
content, ok := result["content"].([]any)
require.True(t, ok, "Result should have a content array")
require.NotEmpty(t, content, "Content should not be empty")
// The first content item should be converted from ToolResult to ImageContent
firstItem, ok := content[0].(map[string]any)
require.True(t, ok, "First content item should be an object")
// Verify it's an image content
contentType, ok := firstItem["type"].(string)
require.True(t, ok, "Content should have a type")
assert.Equal(t, "text", contentType, "Content type should be image")
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for tool call response")
}
@@ -1077,7 +1026,7 @@ func TestToolCallToolResultType(t *testing.T) {
}
// Process the tool call
mock.server.processToolCall(client, req)
mock.server.processToolCall(context.Background(), client, req)
// Get the response from the client's channel
select {
@@ -1100,15 +1049,6 @@ func TestToolCallToolResultType(t *testing.T) {
content, ok := result["content"].([]any)
require.True(t, ok, "Result should have a content array")
require.NotEmpty(t, content, "Content should not be empty")
// The first content item should be converted from ToolResult to ImageContent
firstItem, ok := content[0].(map[string]any)
require.True(t, ok, "First content item should be an object")
// Verify it's an image content
contentType, ok := firstItem["type"].(string)
require.True(t, ok, "Content should have a type")
assert.Equal(t, "text", contentType, "Content type should be image")
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for tool call response")
}
@@ -1137,7 +1077,7 @@ func TestToolCallToolResultType(t *testing.T) {
}
// Process the tool call
mock.server.processToolCall(client, req)
mock.server.processToolCall(context.Background(), client, req)
// Get the response from the client's channel
select {
@@ -1159,10 +1099,9 @@ func TestToolCallError(t *testing.T) {
Name: "error.tool",
Description: "A tool that returns an error",
InputSchema: InputSchema{
Type: "object",
Properties: map[string]any{},
},
Handler: func(params map[string]any) (any, error) {
Handler: func(ctx context.Context, params map[string]any) (any, error) {
return nil, fmt.Errorf("tool execution failed")
},
})
@@ -1189,7 +1128,7 @@ func TestToolCallError(t *testing.T) {
}
// Process the tool call
mock.server.processToolCall(client, req)
mock.server.processToolCall(context.Background(), client, req)
// Check the response
select {
@@ -1207,20 +1146,16 @@ func TestToolCallTimeout(t *testing.T) {
mock := newMockMcpServer(t)
defer mock.shutdown()
// Set a very short timeout for testing
mock.server.conf.Mcp.ToolTimeout = 10 * time.Millisecond
// Register a tool that times out
err := mock.server.RegisterTool(Tool{
Name: "timeout.tool",
Description: "A tool that times out",
InputSchema: InputSchema{
Type: "object",
Properties: map[string]any{},
},
Handler: func(params map[string]any) (any, error) {
time.Sleep(50 * time.Millisecond) // Sleep longer than timeout
return "this should never be returned", nil
Handler: func(ctx context.Context, params map[string]any) (any, error) {
<-ctx.Done()
return nil, fmt.Errorf("tool execution timed out")
},
})
require.NoError(t, err)
@@ -1244,16 +1179,24 @@ func TestToolCallTimeout(t *testing.T) {
Method: methodToolsCall,
Params: paramBytes,
}
jsonBody, _ := json.Marshal(req)
// Process the tool call
mock.server.processToolCall(client, req)
// Create HTTP request
r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client", bytes.NewReader(jsonBody))
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Millisecond)
defer cancel()
r = r.WithContext(ctx)
w := httptest.NewRecorder()
// Process through handleRequest
go mock.server.handleRequest(w, r)
// Check the response
select {
case response := <-client.channel:
assert.Contains(t, response, "event: message", "Response should have message event")
assert.Contains(t, response, `-32001`, "Response should contain a timeout error code")
case <-time.After(150 * time.Millisecond):
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for tool call response")
}
}
@@ -1274,7 +1217,7 @@ func TestInitializeAndNotifications(t *testing.T) {
Params: json.RawMessage(`{}`),
}
mock.server.processInitialize(client, initReq)
mock.server.processInitialize(context.Background(), client, initReq)
// Check that client is initialized after initialize request
assert.True(t, client.initialized, "Client should be marked as initialized after initialize request")
@@ -1418,7 +1361,7 @@ func TestNotificationCancelled_badParams(t *testing.T) {
}
buf := logtest.NewCollector(t)
mock.server.processNotificationCancelled(client, cancelReq)
mock.server.processNotificationCancelled(context.Background(), client, cancelReq)
select {
case <-client.channel:
@@ -1593,7 +1536,7 @@ func TestGetPrompt(t *testing.T) {
}
// Process the request
mock.server.processGetPrompt(client, promptReq)
mock.server.processGetPrompt(context.Background(), client, promptReq)
// Check response
select {
@@ -1622,7 +1565,7 @@ func TestGetPrompt(t *testing.T) {
}
// Process the request
mock.server.processGetPrompt(client, promptReq)
mock.server.processGetPrompt(context.Background(), client, promptReq)
// Check response
select {
@@ -1636,6 +1579,44 @@ func TestGetPrompt(t *testing.T) {
t.Fatal("Timed out waiting for prompt response")
}
})
t.Run("test prompt with nil params", func(t *testing.T) {
mock := newMockMcpServer(t)
defer mock.shutdown()
// Create a test client
client := addTestClient(mock.server, "test-client", true)
// Register a test prompt
testPrompt := Prompt{
Name: "test.prompt",
Description: "A test prompt",
}
mock.server.RegisterPrompt(testPrompt)
// Create a get prompt request
paramBytes, _ := json.Marshal(map[string]any{
"name": "test.prompt",
})
promptReq := Request{
JsonRpc: jsonRpcVersion,
ID: 1,
Method: "prompts/get",
Params: paramBytes,
}
// Process the request
mock.server.processGetPrompt(context.Background(), client, promptReq)
// Check response
select {
case response := <-client.channel:
_, err := parseEvent(response)
assert.NoError(t, err, "Should be able to parse event")
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for prompt response")
}
})
}
// TestBroadcast tests the broadcast functionality
@@ -1903,34 +1884,79 @@ func TestNotificationInitialized(t *testing.T) {
})
}
func TestSendBadResponse(t *testing.T) {
mock := newMockMcpServer(t)
defer mock.shutdown()
func TestSendResponse(t *testing.T) {
t.Run("bad response", func(t *testing.T) {
mock := newMockMcpServer(t)
defer mock.shutdown()
// Create a test client
client := addTestClient(mock.server, "test-client", true)
// Create a test client
client := addTestClient(mock.server, "test-client", true)
// Create a response
response := Response{
JsonRpc: jsonRpcVersion,
ID: 1,
Result: make(chan int),
}
// Create a response
response := Response{
JsonRpc: jsonRpcVersion,
ID: 1,
Result: make(chan int),
}
// Send the response
mock.server.sendResponse(client, 1, response)
// Send the response
mock.server.sendResponse(context.Background(), client, 1, response)
// Check the response in the client's channel
select {
case res := <-client.channel:
evt, err := parseEvent(res)
require.NoError(t, err, "Should parse event without error")
errMsg, ok := evt.Data["error"].(map[string]any)
require.True(t, ok, "Should have error in response")
assert.Equal(t, float64(errCodeInternalError), errMsg["code"], "Error code should match")
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for response")
}
// Check the response in the client's channel
select {
case res := <-client.channel:
evt, err := parseEvent(res)
require.NoError(t, err, "Should parse event without error")
errMsg, ok := evt.Data["error"].(map[string]any)
require.True(t, ok, "Should have error in response")
assert.Equal(t, float64(errCodeInternalError), errMsg["code"], "Error code should match")
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for response")
}
})
t.Run("channel full", func(t *testing.T) {
mock := newMockMcpServer(t)
defer mock.shutdown()
// Create a test client
client := addTestClient(mock.server, "test-client", true)
for i := 0; i < eventChanSize; i++ {
client.channel <- "test"
}
// Create a response
response := Response{
JsonRpc: jsonRpcVersion,
ID: 1,
Result: "foo",
}
buf := logtest.NewCollector(t)
// Send the response
mock.server.sendResponse(context.Background(), client, 1, response)
// Check the response in the client's channel
assert.Contains(t, buf.String(), "channel is full")
})
}
func TestSendErrorResponse(t *testing.T) {
t.Run("channel full", func(t *testing.T) {
mock := newMockMcpServer(t)
defer mock.shutdown()
// Create a test client
client := addTestClient(mock.server, "test-client", true)
for i := 0; i < eventChanSize; i++ {
client.channel <- "test"
}
buf := logtest.NewCollector(t)
// Send the response
mock.server.sendErrorResponse(context.Background(), client, 1, "foo", errCodeInternalError)
// Check the response in the client's channel
assert.Contains(t, buf.String(), "channel is full")
})
}
// TestMethodToolsCall tests the handling of tools/call method through handleRequest
@@ -2028,8 +2054,7 @@ func TestMethodToolsCall(t *testing.T) {
if len(content) > 0 {
firstItem, ok := content[0].(map[string]any)
if ok {
assert.Equal(t, "text", firstItem["type"], "Content type should be text")
assert.Contains(t, firstItem["text"], "Processed: test-input", "Content should include processed input")
assert.Contains(t, firstItem[ContentTypeText], "Processed: test-input", "Content should include processed input")
}
}
case <-time.After(100 * time.Millisecond):
@@ -2145,7 +2170,6 @@ func TestMethodPromptsGet(t *testing.T) {
{
Name: "topic",
Description: "Topic to discuss",
Default: "artificial intelligence",
},
},
Content: "Hello {{name}}! Let's talk about {{topic}}.",
@@ -2227,14 +2251,12 @@ func TestMethodPromptsGet(t *testing.T) {
if len(messages) > 0 {
message, ok := messages[0].(map[string]any)
require.True(t, ok, "Message should be an object")
assert.Equal(t, "user", message["role"], "Role should be 'user'")
assert.Equal(t, string(RoleUser), message["role"], "Role should be 'user'")
content, ok := message["content"].(map[string]any)
require.True(t, ok, "Should have content object")
assert.Equal(t, "text", content["type"], "Content type should be text")
assert.Contains(t, content["text"], "Hello Test User", "Content should include the name argument")
assert.Contains(t, content["text"], "about artificial intelligence",
"Content should include the default topic argument")
assert.Equal(t, ContentTypeText, content["type"], "Content type should be text")
assert.Contains(t, content[ContentTypeText], "Hello Test User", "Content should include the name argument")
}
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for prompt get response")
@@ -2255,27 +2277,24 @@ func TestMethodPromptsGet(t *testing.T) {
{
Name: "question",
Description: "User's question",
Default: "How does this work?",
},
},
Handler: func(args map[string]string) ([]PromptMessage, error) {
Handler: func(ctx context.Context, args map[string]string) ([]PromptMessage, error) {
username := args["username"]
question := args["question"]
// Create a system message
systemMessage := PromptMessage{
Role: "system",
Role: RoleAssistant,
Content: TextContent{
Type: "text",
Text: "You are a helpful assistant.",
},
}
// Create a user message
userMessage := PromptMessage{
Role: "user",
Role: RoleUser,
Content: TextContent{
Type: "text",
Text: fmt.Sprintf("Hi, I'm %s and I'm wondering: %s", username, question),
},
}
@@ -2340,20 +2359,20 @@ func TestMethodPromptsGet(t *testing.T) {
// Check message content
if len(messages) >= 2 {
// First message should be system
// First message should be assistant
message1, _ := messages[0].(map[string]any)
assert.Equal(t, "system", message1["role"], "First role should be 'system'")
assert.Equal(t, string(RoleAssistant), message1["role"], "First role should be 'system'")
content1, _ := message1["content"].(map[string]any)
assert.Contains(t, content1["text"], "helpful assistant", "System message should be correct")
assert.Contains(t, content1[ContentTypeText], "helpful assistant", "System message should be correct")
// Second message should be user
message2, _ := messages[1].(map[string]any)
assert.Equal(t, "user", message2["role"], "Second role should be 'user'")
assert.Equal(t, string(RoleUser), message2["role"], "Second role should be 'user'")
content2, _ := message2["content"].(map[string]any)
assert.Contains(t, content2["text"], "Dynamic User", "User message should contain username")
assert.Contains(t, content2["text"], "How to test this?", "User message should contain question")
assert.Contains(t, content2[ContentTypeText], "Dynamic User", "User message should contain username")
assert.Contains(t, content2[ContentTypeText], "How to test this?", "User message should contain question")
}
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for prompt get response")
@@ -2459,7 +2478,7 @@ func TestMethodPromptsGet(t *testing.T) {
Name: "error-handler-prompt",
Description: "A prompt with a handler that returns an error",
Arguments: []PromptArgument{},
Handler: func(args map[string]string) ([]PromptMessage, error) {
Handler: func(ctx context.Context, args map[string]string) ([]PromptMessage, error) {
return nil, fmt.Errorf("test handler error")
},
}
@@ -2583,7 +2602,7 @@ func TestMethodResourcesList(t *testing.T) {
URI: "file:///test/resource.txt",
Description: "A test resource with handler",
MimeType: "text/plain",
Handler: func() (ResourceContent, error) {
Handler: func(ctx context.Context) (ResourceContent, error) {
return ResourceContent{
URI: "file:///test/resource.txt",
MimeType: "text/plain",
@@ -2654,7 +2673,7 @@ func TestMethodResourcesRead(t *testing.T) {
URI: "file:///test/resource.txt",
Description: "A test resource with handler",
MimeType: "text/plain",
Handler: func() (ResourceContent, error) {
Handler: func(ctx context.Context) (ResourceContent, error) {
return ResourceContent{
URI: "file:///test/resource.txt",
MimeType: "text/plain",
@@ -2729,7 +2748,7 @@ func TestMethodResourcesRead(t *testing.T) {
require.True(t, ok, "Content should be an object")
assert.Equal(t, "file:///test/resource.txt", content["uri"], "URI should match")
assert.Equal(t, "text/plain", content["mimeType"], "MimeType should match")
assert.Equal(t, "This is test resource content", content["text"], "Text content should match")
assert.Equal(t, "This is test resource content", content[ContentTypeText], "Text content should match")
}
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for resource read response")
@@ -2799,7 +2818,7 @@ func TestMethodResourcesRead(t *testing.T) {
require.True(t, ok, "Content should be an object")
assert.Equal(t, "file:///test/no-handler.txt", content["uri"], "URI should match")
assert.Equal(t, "text/plain", content["mimeType"], "MimeType should match")
_, ok = content["text"]
_, ok = content[ContentTypeText]
assert.False(t, ok, "Text content should be empty string")
}
case <-time.After(100 * time.Millisecond):
@@ -2880,7 +2899,7 @@ func TestMethodResourcesRead(t *testing.T) {
client := addTestClient(mock.server, "test-client-resources", true)
// Process through handleRequest
mock.server.processResourcesRead(client, req)
mock.server.processResourcesRead(context.Background(), client, req)
select {
case response := <-client.channel:
@@ -2898,7 +2917,7 @@ func TestMethodResourcesRead(t *testing.T) {
URI: "file:///test/error.txt",
Description: "A test resource with handler that returns error",
MimeType: "text/plain",
Handler: func() (ResourceContent, error) {
Handler: func(ctx context.Context) (ResourceContent, error) {
return ResourceContent{}, fmt.Errorf("test handler error")
},
}
@@ -2946,7 +2965,7 @@ func TestMethodResourcesRead(t *testing.T) {
URI: "file:///test/missing-fields.txt",
Description: "A test resource with handler that returns content missing fields",
MimeType: "text/plain",
Handler: func() (ResourceContent, error) {
Handler: func(ctx context.Context) (ResourceContent, error) {
// Return ResourceContent without URI and MimeType
return ResourceContent{
Text: "Content with missing fields",
@@ -3006,7 +3025,7 @@ func TestMethodResourcesRead(t *testing.T) {
require.True(t, ok, "Content should be an object")
assert.Equal(t, "file:///test/missing-fields.txt", content["uri"], "URI should be filled from request")
assert.Equal(t, "text/plain", content["mimeType"], "MimeType should be filled from resource")
assert.Equal(t, "Content with missing fields", content["text"], "Text content should match")
assert.Equal(t, "Content with missing fields", content[ContentTypeText], "Text content should match")
}
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for resource read response")
@@ -3159,7 +3178,7 @@ func TestMethodResourcesSubscribe(t *testing.T) {
}
client := addTestClient(mock.server, "test-client-sub-not-found", true)
mock.server.processResourceSubscribe(client, req)
mock.server.processResourceSubscribe(context.Background(), client, req)
select {
case response := <-client.channel:
@@ -3268,7 +3287,7 @@ func TestToolCallUnmarshalError(t *testing.T) {
}
// Process the tool call directly
mock.server.processToolCall(client, req)
mock.server.processToolCall(context.Background(), client, req)
// Check for error response about invalid JSON
select {
@@ -3316,7 +3335,7 @@ func TestToolCallWithInvalidParams(t *testing.T) {
jsonBody, _ := json.Marshal(req)
// Create HTTP request
r := httptest.NewRequest("POST", "/?session_id=test-client-invalid-json", bytes.NewReader(jsonBody))
r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-invalid-json", bytes.NewReader(jsonBody))
w := httptest.NewRecorder()
// Process through handleRequest

View File

@@ -1,6 +1,7 @@
package mcp
import (
"context"
"encoding/json"
"sync"
@@ -45,19 +46,18 @@ type ListToolsResult struct {
// Message Content Types
// roleType represents the sender or recipient of messages in a conversation
type roleType string
// RoleType represents the sender or recipient of messages in a conversation
type RoleType string
// PromptArgument defines a single argument that can be passed to a prompt
type PromptArgument struct {
Name string `json:"name"` // Argument name
Description string `json:"description,omitempty"` // Human-readable description
Required bool `json:"required,omitempty"` // Whether this argument is required
Default string `json:"default,omitempty"` // Default value if not provided
}
// PromptHandler is a function that dynamically generates prompt content
type PromptHandler func(args map[string]string) ([]PromptMessage, error)
type PromptHandler func(ctx context.Context, args map[string]string) ([]PromptMessage, error)
// Prompt represents an MCP Prompt definition
type Prompt struct {
@@ -70,31 +70,43 @@ type Prompt struct {
// PromptMessage represents a message in a conversation
type PromptMessage struct {
Role roleType `json:"role"` // Message sender role
Role RoleType `json:"role"` // Message sender role
Content any `json:"content"` // Message content (TextContent, ImageContent, etc.)
}
// TextContent represents text content in a message
type TextContent struct {
Type string `json:"type"` // Always "text"
Text string `json:"text"` // The text content
Annotations *Annotations `json:"annotations,omitempty"` // Optional annotations
}
type typedTextContent struct {
Type string `json:"type"`
TextContent
}
// ImageContent represents image data in a message
type ImageContent struct {
Type string `json:"type"` // Always "image"
Data string `json:"data"` // Base64-encoded image data
MimeType string `json:"mimeType"` // MIME type (e.g., "image/png")
}
type typedImageContent struct {
Type string `json:"type"`
ImageContent
}
// AudioContent represents audio data in a message
type AudioContent struct {
Type string `json:"type"` // Always "audio"
Data string `json:"data"` // Base64-encoded audio data
MimeType string `json:"mimeType"` // MIME type (e.g., "audio/mp3")
}
type typedAudioContent struct {
Type string `json:"type"`
AudioContent
}
// FileContent represents file content
type FileContent struct {
URI string `json:"uri"` // URI identifying the file
@@ -115,16 +127,14 @@ type EmbeddedResource struct {
// Annotations provides additional metadata for content
type Annotations struct {
Audience []roleType `json:"audience,omitempty"` // Who should see this content
Audience []RoleType `json:"audience,omitempty"` // Who should see this content
Priority *float64 `json:"priority,omitempty"` // Optional priority (0-1)
}
// Tool-related Types
// Tool Definition Types
// ToolHandler is a function that handles tool calls
type ToolHandler func(params map[string]any) (any, error)
type ToolHandler func(ctx context.Context, params map[string]any) (any, error)
// Tool represents a Model Context Protocol Tool definition
type Tool struct {
@@ -136,7 +146,7 @@ type Tool struct {
// InputSchema represents tool's input schema in JSON Schema format
type InputSchema struct {
Type string `json:"type"` // Always "object" for tool inputs
Type string `json:"type"`
Properties map[string]any `json:"properties"` // Property definitions
Required []string `json:"required,omitempty"` // List of required properties
}
@@ -144,8 +154,8 @@ type InputSchema struct {
// CallToolResult represents a tool call result that conforms to the MCP schema
type CallToolResult struct {
Result
Content []interface{} `json:"content"` // Content items (text, images, etc.)
IsError bool `json:"isError,omitempty"` // True if tool execution failed
Content []any `json:"content"` // Content items (text, images, etc.)
IsError bool `json:"isError,omitempty"` // True if tool execution failed
}
// Resource represents a Model Context Protocol Resource definition
@@ -158,7 +168,7 @@ type Resource struct {
}
// ResourceHandler is a function that handles resource read requests
type ResourceHandler func() (ResourceContent, error)
type ResourceHandler func(ctx context.Context) (ResourceContent, error)
// ResourceContent represents the content of a resource
type ResourceContent struct {

View File

@@ -1,6 +1,7 @@
package mcp
import (
"context"
"encoding/json"
"testing"
@@ -79,7 +80,7 @@ func TestToolStructs(t *testing.T) {
},
Required: []string{"input"},
},
Handler: func(params map[string]any) (any, error) {
Handler: func(ctx context.Context, params map[string]any) (any, error) {
return "result", nil
},
}
@@ -145,44 +146,38 @@ func TestResourceStructs(t *testing.T) {
func TestContentTypes(t *testing.T) {
// Test TextContent
textContent := TextContent{
Type: "text",
Text: "Sample text",
Annotations: &Annotations{
Audience: []roleType{roleUser, roleAssistant},
Audience: []RoleType{RoleUser, RoleAssistant},
Priority: ptr(1.0),
},
}
data, err := json.Marshal(textContent)
assert.NoError(t, err)
assert.Contains(t, string(data), `"type":"text"`)
assert.Contains(t, string(data), `"text":"Sample text"`)
assert.Contains(t, string(data), `"audience":["user","assistant"]`)
assert.Contains(t, string(data), `"priority":1`)
// Test ImageContent
imageContent := ImageContent{
Type: "image",
Data: "base64data",
MimeType: "image/png",
}
data, err = json.Marshal(imageContent)
assert.NoError(t, err)
assert.Contains(t, string(data), `"type":"image"`)
assert.Contains(t, string(data), `"data":"base64data"`)
assert.Contains(t, string(data), `"mimeType":"image/png"`)
// Test AudioContent
audioContent := AudioContent{
Type: "audio",
Data: "base64audio",
MimeType: "audio/mp3",
}
data, err = json.Marshal(audioContent)
assert.NoError(t, err)
assert.Contains(t, string(data), `"type":"audio"`)
assert.Contains(t, string(data), `"data":"base64audio"`)
assert.Contains(t, string(data), `"mimeType":"audio/mp3"`)
}
@@ -197,7 +192,6 @@ func TestCallToolResult(t *testing.T) {
},
Content: []interface{}{
TextContent{
Type: "text",
Text: "Sample result",
},
},
@@ -207,6 +201,6 @@ func TestCallToolResult(t *testing.T) {
data, err := json.Marshal(result)
assert.NoError(t, err)
assert.Contains(t, string(data), `"_meta":{"progressToken":"token123"}`)
assert.Contains(t, string(data), `"content":[{"type":"text","text":"Sample result"}]`)
assert.Contains(t, string(data), `"content":[{"text":"Sample result"}]`)
assert.NotContains(t, string(data), `"isError":`)
}

View File

@@ -1,15 +1,107 @@
package mcp
import (
"fmt"
)
import "fmt"
// formatSSEMessage formats a Server-Sent Event message with proper CRLF line endings
func formatSSEMessage(event string, data []byte) string {
return fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(data))
}
// ptr is a helper function to get a pointer to a value
func ptr[T any](v T) *T {
return &v
}
// formatSSEMessage formats a Server-Sent Event message with proper CRLF line endings
func formatSSEMessage(event string, data []byte) string {
return fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(data))
func toTypedContents(contents []any) []any {
typedContents := make([]any, len(contents))
for i, content := range contents {
switch v := content.(type) {
case TextContent:
typedContents[i] = typedTextContent{
Type: ContentTypeText,
TextContent: v,
}
case ImageContent:
typedContents[i] = typedImageContent{
Type: ContentTypeImage,
ImageContent: v,
}
case AudioContent:
typedContents[i] = typedAudioContent{
Type: ContentTypeAudio,
AudioContent: v,
}
default:
typedContents[i] = typedTextContent{
Type: ContentTypeText,
TextContent: TextContent{
Text: fmt.Sprintf("Unknown content type: %T", v),
},
}
}
}
return typedContents
}
func toTypedPromptMessages(messages []PromptMessage) []PromptMessage {
typedMessages := make([]PromptMessage, len(messages))
for i, msg := range messages {
switch v := msg.Content.(type) {
case TextContent:
typedMessages[i] = PromptMessage{
Role: msg.Role,
Content: typedTextContent{
Type: ContentTypeText,
TextContent: v,
},
}
case ImageContent:
typedMessages[i] = PromptMessage{
Role: msg.Role,
Content: typedImageContent{
Type: ContentTypeImage,
ImageContent: v,
},
}
case AudioContent:
typedMessages[i] = PromptMessage{
Role: msg.Role,
Content: typedAudioContent{
Type: ContentTypeAudio,
AudioContent: v,
},
}
default:
typedMessages[i] = PromptMessage{
Role: msg.Role,
Content: typedTextContent{
Type: ContentTypeText,
TextContent: TextContent{
Text: fmt.Sprintf("Unknown content type: %T", v),
},
},
}
}
}
return typedMessages
}
// validatePromptArguments checks if all required arguments are provided
// Returns a list of missing required arguments
func validatePromptArguments(prompt Prompt, providedArgs map[string]string) []string {
var missingArgs []string
for _, arg := range prompt.Arguments {
if arg.Required {
if value, exists := providedArgs[arg.Name]; !exists || len(value) == 0 {
missingArgs = append(missingArgs, arg.Name)
}
}
}
return missingArgs
}

View File

@@ -8,29 +8,9 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPtr(t *testing.T) {
tests := []struct {
name string
v interface{}
}{
{"string", "test"},
{"int", 42},
{"bool", true},
{"float", 3.14},
{"struct", struct{ Name string }{"test"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ptr(tt.v)
assert.NotNil(t, got, "ptr() should not return nil")
assert.Equal(t, tt.v, *got, "dereferenced pointer should equal input value")
})
}
}
type Event struct {
Type string
Data map[string]any
@@ -61,3 +41,234 @@ func parseEvent(input string) (*Event, error) {
return &evt, nil
}
// TestToTypedPromptMessages tests the toTypedPromptMessages function
func TestToTypedPromptMessages(t *testing.T) {
// Test with multiple message types in one test
t.Run("MixedContentTypes", func(t *testing.T) {
// Create test data with different content types
messages := []PromptMessage{
{
Role: RoleUser,
Content: TextContent{
Text: "Hello, this is a text message",
Annotations: &Annotations{
Audience: []RoleType{RoleUser, RoleAssistant},
Priority: ptr(0.8),
},
},
},
{
Role: RoleAssistant,
Content: ImageContent{
Data: "base64ImageData",
MimeType: "image/jpeg",
},
},
{
Role: RoleUser,
Content: AudioContent{
Data: "base64AudioData",
MimeType: "audio/mp3",
},
},
{
Role: "system",
Content: "This is a simple string that should be handled as unknown type",
},
}
// Call the function
result := toTypedPromptMessages(messages)
// Validate results
require.Len(t, result, 4, "Should return the same number of messages")
// Validate first message (TextContent)
msg := result[0]
assert.Equal(t, RoleUser, msg.Role, "Role should be preserved")
// Type assertion using reflection since Content is an interface
typed, ok := msg.Content.(typedTextContent)
require.True(t, ok, "Should be typedTextContent")
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
assert.Equal(t, "Hello, this is a text message", typed.Text, "Text content should be preserved")
require.NotNil(t, typed.Annotations, "Annotations should be preserved")
assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved")
require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved")
assert.Equal(t, 0.8, *typed.Annotations.Priority, "Priority value should be preserved")
// Validate second message (ImageContent)
msg = result[1]
assert.Equal(t, RoleAssistant, msg.Role, "Role should be preserved")
// Type assertion for image content
typedImg, ok := msg.Content.(typedImageContent)
require.True(t, ok, "Should be typedImageContent")
assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image")
assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved")
assert.Equal(t, "image/jpeg", typedImg.MimeType, "MimeType should be preserved")
// Validate third message (AudioContent)
msg = result[2]
assert.Equal(t, RoleUser, msg.Role, "Role should be preserved")
// Type assertion for audio content
typedAudio, ok := msg.Content.(typedAudioContent)
require.True(t, ok, "Should be typedAudioContent")
assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio")
assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved")
assert.Equal(t, "audio/mp3", typedAudio.MimeType, "MimeType should be preserved")
// Validate fourth message (unknown type converted to TextContent)
msg = result[3]
assert.Equal(t, RoleType("system"), msg.Role, "Role should be preserved")
// Should be converted to a typedTextContent with error message
typedUnknown, ok := msg.Content.(typedTextContent)
require.True(t, ok, "Unknown content should be converted to typedTextContent")
assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text")
assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type")
assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type")
})
// Test empty input
t.Run("EmptyInput", func(t *testing.T) {
messages := []PromptMessage{}
result := toTypedPromptMessages(messages)
assert.Empty(t, result, "Should return empty slice for empty input")
})
// Test with nil annotations
t.Run("NilAnnotations", func(t *testing.T) {
messages := []PromptMessage{
{
Role: RoleUser,
Content: TextContent{
Text: "Text with nil annotations",
Annotations: nil,
},
},
}
result := toTypedPromptMessages(messages)
require.Len(t, result, 1, "Should return one message")
typed, ok := result[0].Content.(typedTextContent)
require.True(t, ok, "Should be typedTextContent")
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved")
assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil")
})
}
// TestToTypedContents tests the toTypedContents function
func TestToTypedContents(t *testing.T) {
// Test with multiple content types in one test
t.Run("MixedContentTypes", func(t *testing.T) {
// Create test data with different content types
contents := []any{
TextContent{
Text: "Hello, this is a text content",
Annotations: &Annotations{
Audience: []RoleType{RoleUser, RoleAssistant},
Priority: ptr(0.7),
},
},
ImageContent{
Data: "base64ImageData",
MimeType: "image/png",
},
AudioContent{
Data: "base64AudioData",
MimeType: "audio/wav",
},
"This is a simple string that should be handled as unknown type",
}
// Call the function
result := toTypedContents(contents)
// Validate results
require.Len(t, result, 4, "Should return the same number of contents")
// Validate first content (TextContent)
typed, ok := result[0].(typedTextContent)
require.True(t, ok, "Should be typedTextContent")
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
assert.Equal(t, "Hello, this is a text content", typed.Text, "Text content should be preserved")
require.NotNil(t, typed.Annotations, "Annotations should be preserved")
assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved")
require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved")
assert.Equal(t, 0.7, *typed.Annotations.Priority, "Priority value should be preserved")
// Validate second content (ImageContent)
typedImg, ok := result[1].(typedImageContent)
require.True(t, ok, "Should be typedImageContent")
assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image")
assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved")
assert.Equal(t, "image/png", typedImg.MimeType, "MimeType should be preserved")
// Validate third content (AudioContent)
typedAudio, ok := result[2].(typedAudioContent)
require.True(t, ok, "Should be typedAudioContent")
assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio")
assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved")
assert.Equal(t, "audio/wav", typedAudio.MimeType, "MimeType should be preserved")
// Validate fourth content (unknown type converted to TextContent)
typedUnknown, ok := result[3].(typedTextContent)
require.True(t, ok, "Unknown content should be converted to typedTextContent")
assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text")
assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type")
assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type")
})
// Test empty input
t.Run("EmptyInput", func(t *testing.T) {
contents := []any{}
result := toTypedContents(contents)
assert.Empty(t, result, "Should return empty slice for empty input")
})
// Test with nil annotations
t.Run("NilAnnotations", func(t *testing.T) {
contents := []any{
TextContent{
Text: "Text with nil annotations",
Annotations: nil,
},
}
result := toTypedContents(contents)
require.Len(t, result, 1, "Should return one content")
typed, ok := result[0].(typedTextContent)
require.True(t, ok, "Should be typedTextContent")
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved")
assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil")
})
// Test with custom struct (should be handled as unknown type)
t.Run("CustomStruct", func(t *testing.T) {
type CustomContent struct {
Data string
}
contents := []any{
CustomContent{
Data: "custom data",
},
}
result := toTypedContents(contents)
require.Len(t, result, 1, "Should return one content")
typed, ok := result[0].(typedTextContent)
require.True(t, ok, "Custom struct should be converted to typedTextContent")
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
assert.Contains(t, typed.Text, "Unknown content type:", "Should contain error about unknown type")
assert.Contains(t, typed.Text, "CustomContent", "Should mention the actual type")
})
}

View File

@@ -13,6 +13,9 @@ const (
// Session identifier key used in request URLs
sessionIdKey = "session_id"
// progressTokenKey is used to track progress of long-running tasks
progressTokenKey = "progressToken"
)
// Server-Sent Events (SSE) event types
@@ -26,11 +29,20 @@ const (
// Content type identifiers
const (
// Text content type
contentTypeText = "text"
// ContentTypeObject is object content type
ContentTypeObject = "object"
// Image content type
contentTypeImage = "image"
// ContentTypeText is text content type
ContentTypeText = "text"
// ContentTypeImage is image content type
ContentTypeImage = "image"
// ContentTypeAudio is audio content type
ContentTypeAudio = "audio"
// ContentTypeResource is resource content type
ContentTypeResource = "resource"
)
// Collection keys for broadcast events
@@ -72,11 +84,11 @@ const (
// User and assistant role definitions
const (
// The "user" role - the entity asking questions
roleUser roleType = "user"
// RoleUser is the "user" role - the entity asking questions
RoleUser RoleType = "user"
// The "assistant" role - the entity providing responses
roleAssistant roleType = "assistant"
// RoleAssistant is the "assistant" role - the entity providing responses
RoleAssistant RoleType = "assistant"
)
// Method names as defined in the MCP specification

View File

@@ -146,13 +146,9 @@ func TestCollectionKeys(t *testing.T) {
// TestRoleTypes checks that role types are used correctly
func TestRoleTypes(t *testing.T) {
// Verify role type constants
assert.Equal(t, "user", string(roleUser), "User role should be 'user'")
assert.Equal(t, "assistant", string(roleAssistant), "Assistant role should be 'assistant'")
// Test in annotations
annotations := Annotations{
Audience: []roleType{roleUser, roleAssistant},
Audience: []RoleType{RoleUser, RoleAssistant},
}
data, err := json.Marshal(annotations)
assert.NoError(t, err)