mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-07 23:20:00 +08:00
@@ -34,7 +34,10 @@ type McpConf struct {
|
||||
// Cors contains allowed CORS origins
|
||||
Cors []string `json:",optional"`
|
||||
|
||||
// ToolTimeout is the maximum time allowed for tool execution
|
||||
ToolTimeout time.Duration `json:",default=30s"`
|
||||
// SseTimeout is the maximum time allowed for SSE connections
|
||||
SseTimeout time.Duration `json:",default=24h"`
|
||||
|
||||
// MessageTimeout is the maximum time allowed for request execution
|
||||
MessageTimeout time.Duration `json:",default=30s"`
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ mcp:
|
||||
assert.Equal(t, "2024-11-05", c.Mcp.ProtocolVersion, "Default protocol version should be 2024-11-05")
|
||||
assert.Equal(t, "/sse", c.Mcp.SseEndpoint, "Default SSE endpoint should be /sse")
|
||||
assert.Equal(t, "/message", c.Mcp.MessageEndpoint, "Default message endpoint should be /message")
|
||||
assert.Equal(t, 30*time.Second, c.Mcp.ToolTimeout, "Default tool timeout should be 30s")
|
||||
assert.Equal(t, 30*time.Second, c.Mcp.MessageTimeout, "Default message timeout should be 30s")
|
||||
}
|
||||
|
||||
func TestMcpConfCustomValues(t *testing.T) {
|
||||
@@ -43,7 +43,7 @@ func TestMcpConfCustomValues(t *testing.T) {
|
||||
"SseEndpoint": "/custom-sse",
|
||||
"MessageEndpoint": "/custom-message",
|
||||
"Cors": ["http://localhost:3000", "http://example.com"],
|
||||
"ToolTimeout": "60s"
|
||||
"MessageTimeout": "60s"
|
||||
}
|
||||
}`
|
||||
|
||||
@@ -59,5 +59,5 @@ func TestMcpConfCustomValues(t *testing.T) {
|
||||
assert.Equal(t, "/custom-sse", c.Mcp.SseEndpoint, "SSE endpoint should be customizable")
|
||||
assert.Equal(t, "/custom-message", c.Mcp.MessageEndpoint, "Message endpoint should be customizable")
|
||||
assert.Equal(t, []string{"http://localhost:3000", "http://example.com"}, c.Mcp.Cors, "CORS settings should be customizable")
|
||||
assert.Equal(t, 60*time.Second, c.Mcp.ToolTimeout, "Tool timeout should be customizable")
|
||||
assert.Equal(t, 60*time.Second, c.Mcp.MessageTimeout, "Tool timeout should be customizable")
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package mcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -59,7 +60,7 @@ func TestHTTPHandlerIntegration(t *testing.T) {
|
||||
conf := McpConf{}
|
||||
conf.Mcp.Name = "test-integration"
|
||||
conf.Mcp.Version = "1.0.0-test"
|
||||
conf.Mcp.ToolTimeout = 1 * time.Second
|
||||
conf.Mcp.MessageTimeout = 1 * time.Second
|
||||
|
||||
// Create a mock server directly
|
||||
server := &sseMcpServer{
|
||||
@@ -75,7 +76,6 @@ func TestHTTPHandlerIntegration(t *testing.T) {
|
||||
Name: "echo",
|
||||
Description: "Echo tool for testing",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"message": map[string]any{
|
||||
"type": "string",
|
||||
@@ -83,7 +83,7 @@ func TestHTTPHandlerIntegration(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
if msg, ok := params["message"].(string); ok {
|
||||
return fmt.Sprintf("Echo: %s", msg), nil
|
||||
}
|
||||
@@ -181,7 +181,7 @@ func TestHandlerResponseFlow(t *testing.T) {
|
||||
Name: "test.tool",
|
||||
Description: "Test tool",
|
||||
InputSchema: InputSchema{Type: "object"},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return "tool result", nil
|
||||
},
|
||||
})
|
||||
@@ -329,7 +329,7 @@ func TestProcessListMethods(t *testing.T) {
|
||||
Params: json.RawMessage(`{"cursor": "", "_meta": {"progressToken": "token1"}}`),
|
||||
}
|
||||
|
||||
server.processListTools(client, req)
|
||||
server.processListTools(context.Background(), client, req)
|
||||
|
||||
// Read response
|
||||
select {
|
||||
@@ -344,7 +344,7 @@ func TestProcessListMethods(t *testing.T) {
|
||||
req.ID = 2
|
||||
req.Method = methodPromptsList
|
||||
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
||||
server.processListPrompts(client, req)
|
||||
server.processListPrompts(context.Background(), client, req)
|
||||
|
||||
// Read response
|
||||
select {
|
||||
@@ -358,7 +358,7 @@ func TestProcessListMethods(t *testing.T) {
|
||||
req.ID = 3
|
||||
req.Method = methodResourcesList
|
||||
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
||||
server.processListResources(client, req)
|
||||
server.processListResources(context.Background(), client, req)
|
||||
|
||||
// Read response
|
||||
select {
|
||||
@@ -393,7 +393,7 @@ func TestErrorResponseHandling(t *testing.T) {
|
||||
}
|
||||
|
||||
// Mock handleRequest by directly calling error handler
|
||||
server.sendErrorResponse(client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||
server.sendErrorResponse(context.Background(), client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||
|
||||
// Check response
|
||||
select {
|
||||
@@ -412,7 +412,7 @@ func TestErrorResponseHandling(t *testing.T) {
|
||||
}
|
||||
|
||||
// Call process method directly
|
||||
server.processToolCall(client, toolReq)
|
||||
server.processToolCall(context.Background(), client, toolReq)
|
||||
|
||||
// Check response
|
||||
select {
|
||||
@@ -431,7 +431,7 @@ func TestErrorResponseHandling(t *testing.T) {
|
||||
}
|
||||
|
||||
// Call process method directly
|
||||
server.processGetPrompt(client, promptReq)
|
||||
server.processGetPrompt(context.Background(), client, promptReq)
|
||||
|
||||
// Check response
|
||||
select {
|
||||
|
||||
23
mcp/parser.go
Normal file
23
mcp/parser.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/mapping"
|
||||
)
|
||||
|
||||
// ParseArguments parses the arguments and populates the request object
|
||||
func ParseArguments(args any, req any) error {
|
||||
switch arguments := args.(type) {
|
||||
case map[string]string:
|
||||
m := make(map[string]any, len(arguments))
|
||||
for k, v := range arguments {
|
||||
m[k] = v
|
||||
}
|
||||
return mapping.UnmarshalJsonMap(m, req, mapping.WithStringValues())
|
||||
case map[string]any:
|
||||
return mapping.UnmarshalJsonMap(arguments, req)
|
||||
default:
|
||||
return fmt.Errorf("unsupported argument type: %T", arguments)
|
||||
}
|
||||
}
|
||||
139
mcp/parser_test.go
Normal file
139
mcp/parser_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestParseArguments_MapStringString tests parsing map[string]string arguments
|
||||
func TestParseArguments_MapStringString(t *testing.T) {
|
||||
// Sample request struct to populate
|
||||
type requestStruct struct {
|
||||
Name string `json:"name"`
|
||||
Message string `json:"message"`
|
||||
Count int `json:"count"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// Create test arguments
|
||||
args := map[string]string{
|
||||
"name": "test-name",
|
||||
"message": "hello world",
|
||||
"count": "42",
|
||||
"enabled": "true",
|
||||
}
|
||||
|
||||
// Create a target object to populate
|
||||
var req requestStruct
|
||||
|
||||
// Parse the arguments
|
||||
err := ParseArguments(args, &req)
|
||||
|
||||
// Verify results
|
||||
assert.NoError(t, err, "Should parse map[string]string without error")
|
||||
assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed")
|
||||
assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed")
|
||||
assert.Equal(t, 42, req.Count, "Count should be correctly parsed to int")
|
||||
assert.True(t, req.Enabled, "Enabled should be correctly parsed to bool")
|
||||
}
|
||||
|
||||
// TestParseArguments_MapStringAny tests parsing map[string]any arguments
|
||||
func TestParseArguments_MapStringAny(t *testing.T) {
|
||||
// Sample request struct to populate
|
||||
type requestStruct struct {
|
||||
Name string `json:"name"`
|
||||
Message string `json:"message"`
|
||||
Count int `json:"count"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Tags []string `json:"tags"`
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
}
|
||||
|
||||
// Create test arguments with mixed types
|
||||
args := map[string]any{
|
||||
"name": "test-name",
|
||||
"message": "hello world",
|
||||
"count": 42, // note: this is already an int
|
||||
"enabled": true, // note: this is already a bool
|
||||
"tags": []string{"tag1", "tag2"},
|
||||
"metadata": map[string]string{
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
},
|
||||
}
|
||||
|
||||
// Create a target object to populate
|
||||
var req requestStruct
|
||||
|
||||
// Parse the arguments
|
||||
err := ParseArguments(args, &req)
|
||||
|
||||
// Verify results
|
||||
assert.NoError(t, err, "Should parse map[string]any without error")
|
||||
assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed")
|
||||
assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed")
|
||||
assert.Equal(t, 42, req.Count, "Count should be correctly parsed")
|
||||
assert.True(t, req.Enabled, "Enabled should be correctly parsed")
|
||||
assert.Equal(t, []string{"tag1", "tag2"}, req.Tags, "Tags should be correctly parsed")
|
||||
assert.Equal(t, map[string]string{
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
}, req.Metadata, "Metadata should be correctly parsed")
|
||||
}
|
||||
|
||||
// TestParseArguments_UnsupportedType tests parsing with an unsupported type
|
||||
func TestParseArguments_UnsupportedType(t *testing.T) {
|
||||
// Sample request struct to populate
|
||||
type requestStruct struct {
|
||||
Name string `json:"name"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// Use an unsupported argument type (slice)
|
||||
args := []string{"not", "a", "map"}
|
||||
|
||||
// Create a target object to populate
|
||||
var req requestStruct
|
||||
|
||||
// Parse the arguments
|
||||
err := ParseArguments(args, &req)
|
||||
|
||||
// Verify error is returned with correct message
|
||||
assert.Error(t, err, "Should return error for unsupported type")
|
||||
assert.Contains(t, err.Error(), "unsupported argument type", "Error should mention unsupported type")
|
||||
assert.Contains(t, err.Error(), "[]string", "Error should include the actual type")
|
||||
}
|
||||
|
||||
// TestParseArguments_EmptyMap tests parsing with empty maps
|
||||
func TestParseArguments_EmptyMap(t *testing.T) {
|
||||
// Sample request struct to populate
|
||||
type requestStruct struct {
|
||||
Name string `json:"name,optional"`
|
||||
Message string `json:"message,optional"`
|
||||
}
|
||||
|
||||
// Test empty map[string]string
|
||||
t.Run("EmptyMapStringString", func(t *testing.T) {
|
||||
args := map[string]string{}
|
||||
var req requestStruct
|
||||
|
||||
err := ParseArguments(args, &req)
|
||||
|
||||
assert.NoError(t, err, "Should parse empty map[string]string without error")
|
||||
assert.Empty(t, req.Name, "Name should be empty string")
|
||||
assert.Empty(t, req.Message, "Message should be empty string")
|
||||
})
|
||||
|
||||
// Test empty map[string]any
|
||||
t.Run("EmptyMapStringAny", func(t *testing.T) {
|
||||
args := map[string]any{}
|
||||
var req requestStruct
|
||||
|
||||
err := ParseArguments(args, &req)
|
||||
|
||||
assert.NoError(t, err, "Should parse empty map[string]any without error")
|
||||
assert.Empty(t, req.Name, "Name should be empty string")
|
||||
assert.Empty(t, req.Message, "Message should be empty string")
|
||||
})
|
||||
}
|
||||
824
mcp/readme.md
824
mcp/readme.md
@@ -1,7 +1,7 @@
|
||||
# Model Context Protocol (MCP) SDK Implementation
|
||||
# Model Context Protocol (MCP) Implementation
|
||||
|
||||
## Overview
|
||||
This package implements a Model Context Protocol (MCP) server in Go that facilitates real-time communication between AI models and clients using Server-Sent Events (SSE). The implementation provides a framework for building AI-assisted applications with bidirectional communication capabilities.
|
||||
This package implements the Model Context Protocol (MCP) server specification in Go, providing a framework for real-time communication between AI models and clients using Server-Sent Events (SSE). The implementation follows the standardized protocol for building AI-assisted applications with bidirectional communication capabilities.
|
||||
|
||||
## Core Components
|
||||
|
||||
@@ -54,9 +54,817 @@ This package implements a Model Context Protocol (MCP) server in Go that facilit
|
||||
|
||||
## Usage
|
||||
|
||||
To create and use an MCP server, see the examples directory for practical implementation examples including:
|
||||
- Tool registration and execution
|
||||
- Static and dynamic prompt creation
|
||||
- Resource handling with proper URI identification
|
||||
- Embedded resources in prompt responses
|
||||
- Client connection management
|
||||
### Setting Up an MCP Server
|
||||
|
||||
To create and start an MCP server:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/zeromicro/go-zero/core/conf"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/mcp"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Load configuration from YAML file
|
||||
var c mcp.McpConf
|
||||
conf.MustLoad("config.yaml", &c)
|
||||
|
||||
// Optional: Disable stats logging
|
||||
logx.DisableStat()
|
||||
|
||||
// Create MCP server
|
||||
server := mcp.NewMcpServer(c)
|
||||
|
||||
// Register tools, prompts, and resources (examples below)
|
||||
|
||||
// Start the server and ensure it's stopped on exit
|
||||
defer server.Stop()
|
||||
server.Start()
|
||||
}
|
||||
```
|
||||
|
||||
Sample configuration file (config.yaml):
|
||||
|
||||
```yaml
|
||||
name: mcp-server
|
||||
host: localhost
|
||||
port: 8080
|
||||
mcp:
|
||||
name: my-mcp-server
|
||||
messageTimeout: 30s # Timeout for tool calls
|
||||
cors:
|
||||
- http://localhost:3000 # Optional CORS configuration
|
||||
```
|
||||
|
||||
### Registering Tools
|
||||
|
||||
Tools allow AI models to execute custom code through the MCP protocol.
|
||||
|
||||
#### Basic Tool Example:
|
||||
|
||||
```go
|
||||
// Register a simple echo tool
|
||||
echoTool := mcp.Tool{
|
||||
Name: "echo",
|
||||
Description: "Echoes back the message provided by the user",
|
||||
InputSchema: mcp.InputSchema{
|
||||
Properties: map[string]any{
|
||||
"message": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The message to echo back",
|
||||
},
|
||||
"prefix": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional prefix to add to the echoed message",
|
||||
"default": "Echo: ",
|
||||
},
|
||||
},
|
||||
Required: []string{"message"},
|
||||
},
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
var req struct {
|
||||
Message string `json:"message"`
|
||||
Prefix string `json:"prefix,optional"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse params: %w", err)
|
||||
}
|
||||
|
||||
prefix := "Echo: "
|
||||
if len(req.Prefix) > 0 {
|
||||
prefix = req.Prefix
|
||||
}
|
||||
|
||||
return prefix + req.Message, nil
|
||||
},
|
||||
}
|
||||
|
||||
server.RegisterTool(echoTool)
|
||||
```
|
||||
|
||||
#### Tool with Different Response Types:
|
||||
|
||||
```go
|
||||
// Tool returning JSON data
|
||||
dataTool := mcp.Tool{
|
||||
Name: "data.generate",
|
||||
Description: "Generates sample data in various formats",
|
||||
InputSchema: mcp.InputSchema{
|
||||
Properties: map[string]any{
|
||||
"format": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Format of data (json, text)",
|
||||
"enum": []string{"json", "text"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
var req struct {
|
||||
Format string `json:"format"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse params: %w", err)
|
||||
}
|
||||
|
||||
if req.Format == "json" {
|
||||
// Return structured data
|
||||
return map[string]any{
|
||||
"items": []map[string]any{
|
||||
{"id": 1, "name": "Item 1"},
|
||||
{"id": 2, "name": "Item 2"},
|
||||
},
|
||||
"count": 2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Default to text
|
||||
return "Sample text data", nil
|
||||
},
|
||||
}
|
||||
|
||||
server.RegisterTool(dataTool)
|
||||
```
|
||||
|
||||
#### Image Generation Tool Example:
|
||||
|
||||
```go
|
||||
// Tool returning image content
|
||||
imageTool := mcp.Tool{
|
||||
Name: "image.generate",
|
||||
Description: "Generates a simple image",
|
||||
InputSchema: mcp.InputSchema{
|
||||
Properties: map[string]any{
|
||||
"type": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Type of image to generate",
|
||||
"default": "placeholder",
|
||||
},
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
// Return image content directly
|
||||
return mcp.ImageContent{
|
||||
Data: "base64EncodedImageData...", // Base64 encoded image data
|
||||
MimeType: "image/png",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
server.RegisterTool(imageTool)
|
||||
```
|
||||
|
||||
#### Using ToolResult for Custom Outputs:
|
||||
|
||||
```go
|
||||
// Tool that returns a custom ToolResult type
|
||||
customResultTool := mcp.Tool{
|
||||
Name: "custom.result",
|
||||
Description: "Returns a custom formatted result",
|
||||
InputSchema: mcp.InputSchema{
|
||||
Properties: map[string]any{
|
||||
"resultType": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []string{"text", "image"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
var req struct {
|
||||
ResultType string `json:"resultType"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse params: %w", err)
|
||||
}
|
||||
|
||||
if req.ResultType == "image" {
|
||||
return mcp.ToolResult{
|
||||
Type: mcp.ContentTypeImage,
|
||||
Content: map[string]any{
|
||||
"data": "base64EncodedImageData...",
|
||||
"mimeType": "image/jpeg",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Default to text
|
||||
return mcp.ToolResult{
|
||||
Type: mcp.ContentTypeText,
|
||||
Content: "This is a text result from ToolResult",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
server.RegisterTool(customResultTool)
|
||||
```
|
||||
|
||||
### Registering Prompts
|
||||
|
||||
Prompts are reusable conversation templates for AI models.
|
||||
|
||||
#### Static Prompt Example:
|
||||
|
||||
```go
|
||||
// Register a simple static prompt with placeholders
|
||||
server.RegisterPrompt(mcp.Prompt{
|
||||
Name: "hello",
|
||||
Description: "A simple hello prompt",
|
||||
Arguments: []mcp.PromptArgument{
|
||||
{
|
||||
Name: "name",
|
||||
Description: "The name to greet",
|
||||
Required: false,
|
||||
},
|
||||
},
|
||||
Content: "Say hello to {{name}} and introduce yourself as an AI assistant.",
|
||||
})
|
||||
```
|
||||
|
||||
#### Dynamic Prompt with Handler Function:
|
||||
|
||||
```go
|
||||
// Register a prompt with a dynamic handler function
|
||||
server.RegisterPrompt(mcp.Prompt{
|
||||
Name: "dynamic-prompt",
|
||||
Description: "A prompt that uses a handler to generate dynamic content",
|
||||
Arguments: []mcp.PromptArgument{
|
||||
{
|
||||
Name: "username",
|
||||
Description: "User's name for personalized greeting",
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "topic",
|
||||
Description: "Topic of expertise",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||
var req struct {
|
||||
Username string `json:"username"`
|
||||
Topic string `json:"topic"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||
}
|
||||
|
||||
// Create a user message
|
||||
userMessage := mcp.PromptMessage{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("Hello, I'm %s and I'd like to learn about %s.", req.Username, req.Topic),
|
||||
},
|
||||
}
|
||||
|
||||
// Create an assistant response with current time
|
||||
currentTime := time.Now().Format(time.RFC1123)
|
||||
assistantMessage := mcp.PromptMessage{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("Hello %s! I'm an AI assistant and I'll help you learn about %s. The current time is %s.",
|
||||
req.Username, req.Topic, currentTime),
|
||||
},
|
||||
}
|
||||
|
||||
// Return both messages as a conversation
|
||||
return []mcp.PromptMessage{userMessage, assistantMessage}, nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
#### Multi-Message Prompt with Code Examples:
|
||||
|
||||
```go
|
||||
// Register a prompt that provides code examples in different programming languages
|
||||
server.RegisterPrompt(mcp.Prompt{
|
||||
Name: "code-example",
|
||||
Description: "Provides code examples in different programming languages",
|
||||
Arguments: []mcp.PromptArgument{
|
||||
{
|
||||
Name: "language",
|
||||
Description: "Programming language for the example",
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "complexity",
|
||||
Description: "Complexity level (simple, medium, advanced)",
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||
var req struct {
|
||||
Language string `json:"language"`
|
||||
Complexity string `json:"complexity,optional"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||
}
|
||||
|
||||
// Validate language
|
||||
supportedLanguages := map[string]bool{"go": true, "python": true, "javascript": true, "rust": true}
|
||||
if !supportedLanguages[req.Language] {
|
||||
return nil, fmt.Errorf("unsupported language: %s", req.Language)
|
||||
}
|
||||
|
||||
// Generate code example based on language and complexity
|
||||
var codeExample string
|
||||
|
||||
switch req.Language {
|
||||
case "go":
|
||||
if req.Complexity == "simple" {
|
||||
codeExample = `
|
||||
package main
|
||||
|
||||
import "fmt"
|
||||
|
||||
func main() {
|
||||
fmt.Println("Hello, World!")
|
||||
}`
|
||||
} else {
|
||||
codeExample = `
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
now := time.Now()
|
||||
fmt.Printf("Hello, World! Current time is %s\n", now.Format(time.RFC3339))
|
||||
}`
|
||||
}
|
||||
case "python":
|
||||
// Python example code
|
||||
if req.Complexity == "simple" {
|
||||
codeExample = `
|
||||
def greet(name):
|
||||
return f"Hello, {name}!"
|
||||
|
||||
print(greet("World"))`
|
||||
} else {
|
||||
codeExample = `
|
||||
import datetime
|
||||
|
||||
def greet(name, include_time=False):
|
||||
message = f"Hello, {name}!"
|
||||
if include_time:
|
||||
message += f" Current time is {datetime.datetime.now().isoformat()}"
|
||||
return message
|
||||
|
||||
print(greet("World", include_time=True))`
|
||||
}
|
||||
}
|
||||
|
||||
// Create messages array according to MCP spec
|
||||
messages := []mcp.PromptMessage{
|
||||
{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("You are a helpful coding assistant specialized in %s programming.", req.Language),
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("Show me a %s example of a Hello World program in %s.", req.Complexity, req.Language),
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("Here's a %s example in %s:\n\n```%s%s\n```\n\nHow can I help you implement this?",
|
||||
req.Complexity, req.Language, req.Language, codeExample),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
### Registering Resources
|
||||
|
||||
Resources provide access to external content such as files or generated data.
|
||||
|
||||
#### Basic Resource Example:
|
||||
|
||||
```go
|
||||
// Register a static resource
|
||||
server.RegisterResource(mcp.Resource{
|
||||
Name: "example-document",
|
||||
URI: "file:///example/document.txt",
|
||||
Description: "An example document",
|
||||
MimeType: "text/plain",
|
||||
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||
return mcp.ResourceContent{
|
||||
URI: "file:///example/document.txt",
|
||||
MimeType: "text/plain",
|
||||
Text: "This is an example document content.",
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
#### Dynamic Resource with Code Example:
|
||||
|
||||
```go
|
||||
// Register a Go code resource with dynamic handler
|
||||
server.RegisterResource(mcp.Resource{
|
||||
Name: "go-example",
|
||||
URI: "file:///project/src/main.go",
|
||||
Description: "A simple Go example with multiple files",
|
||||
MimeType: "text/x-go",
|
||||
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||
// Return ResourceContent with all required fields
|
||||
return mcp.ResourceContent{
|
||||
URI: "file:///project/src/main.go",
|
||||
MimeType: "text/x-go",
|
||||
Text: "package main\n\nimport (\n\t\"fmt\"\n\t\"./greeting\"\n)\n\nfunc main() {\n\tfmt.Println(greeting.Hello(\"world\"))\n}",
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register a companion file for the above example
|
||||
server.RegisterResource(mcp.Resource{
|
||||
Name: "go-greeting",
|
||||
URI: "file:///project/src/greeting/greeting.go",
|
||||
Description: "A greeting package for the Go example",
|
||||
MimeType: "text/x-go",
|
||||
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||
return mcp.ResourceContent{
|
||||
URI: "file:///project/src/greeting/greeting.go",
|
||||
MimeType: "text/x-go",
|
||||
Text: "package greeting\n\nfunc Hello(name string) string {\n\treturn \"Hello, \" + name + \"!\"\n}",
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
#### Binary Resource Example:
|
||||
|
||||
```go
|
||||
// Register a binary resource (like an image)
|
||||
server.RegisterResource(mcp.Resource{
|
||||
Name: "example-image",
|
||||
URI: "file:///example/image.png",
|
||||
Description: "An example image",
|
||||
MimeType: "image/png",
|
||||
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||
// Read image from file or generate it
|
||||
imageData := "base64EncodedImageData..." // Base64 encoded image data
|
||||
|
||||
return mcp.ResourceContent{
|
||||
URI: "file:///example/image.png",
|
||||
MimeType: "image/png",
|
||||
Blob: imageData, // For binary data
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
### Using Resources in Prompts
|
||||
|
||||
You can embed resources in prompt responses to create rich interactions with proper MCP-compliant structure:
|
||||
|
||||
```go
|
||||
// Register a prompt that embeds a resource
|
||||
server.RegisterPrompt(mcp.Prompt{
|
||||
Name: "resource-example",
|
||||
Description: "A prompt that embeds a resource",
|
||||
Arguments: []mcp.PromptArgument{
|
||||
{
|
||||
Name: "file_type",
|
||||
Description: "Type of file to show (rust or go)",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||
var req struct {
|
||||
FileType string `json:"file_type"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||
}
|
||||
|
||||
var resourceURI, mimeType, fileContent string
|
||||
if req.FileType == "rust" {
|
||||
resourceURI = "file:///project/src/main.rs"
|
||||
mimeType = "text/x-rust"
|
||||
fileContent = "fn main() {\n println!(\"Hello world!\");\n}"
|
||||
} else {
|
||||
resourceURI = "file:///project/src/main.go"
|
||||
mimeType = "text/x-go"
|
||||
fileContent = "package main\n\nimport \"fmt\"\n\nfunc main() {\n\tfmt.Println(\"Hello, world!\")\n}"
|
||||
}
|
||||
|
||||
// Create message with embedded resource using proper MCP format
|
||||
return []mcp.PromptMessage{
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("Can you explain this %s code?", req.FileType),
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.EmbeddedResource{
|
||||
Type: mcp.ContentTypeResource,
|
||||
Resource: struct {
|
||||
URI string `json:"uri"`
|
||||
MimeType string `json:"mimeType"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Blob string `json:"blob,omitempty"`
|
||||
}{
|
||||
URI: resourceURI,
|
||||
MimeType: mimeType,
|
||||
Text: fileContent,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("Above is a simple Hello World example in %s. Let me explain how it works.", req.FileType),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
### Multiple File Resources Example
|
||||
|
||||
```go
|
||||
// Register a prompt that demonstrates embedding multiple resource files
|
||||
server.RegisterPrompt(mcp.Prompt{
|
||||
Name: "go-code-example",
|
||||
Description: "A prompt that correctly embeds multiple resource files",
|
||||
Arguments: []mcp.PromptArgument{
|
||||
{
|
||||
Name: "format",
|
||||
Description: "How to format the code display",
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||
var req struct {
|
||||
Format string `json:"format,optional"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||
}
|
||||
|
||||
// Get the Go code for multiple files
|
||||
var mainGoText string = "package main\n\nimport (\n\t\"fmt\"\n\t\"./greeting\"\n)\n\nfunc main() {\n\tfmt.Println(greeting.Hello(\"world\"))\n}"
|
||||
var greetingGoText string = "package greeting\n\nfunc Hello(name string) string {\n\treturn \"Hello, \" + name + \"!\"\n}"
|
||||
|
||||
// Create message with properly formatted embedded resource per MCP spec
|
||||
messages := []mcp.PromptMessage{
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{
|
||||
Text: "Show me a simple Go example with proper imports.",
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.TextContent{
|
||||
Text: "Here's a simple Go example project:",
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.EmbeddedResource{
|
||||
Type: mcp.ContentTypeResource,
|
||||
Resource: struct {
|
||||
URI string `json:"uri"`
|
||||
MimeType string `json:"mimeType"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Blob string `json:"blob,omitempty"`
|
||||
}{
|
||||
URI: "file:///project/src/main.go",
|
||||
MimeType: "text/x-go",
|
||||
Text: mainGoText,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Add explanation and additional file if requested
|
||||
if req.Format == "with_explanation" {
|
||||
messages = append(messages, mcp.PromptMessage{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.TextContent{
|
||||
Text: "This example demonstrates a simple Go application with modular structure. The main.go file imports from a local 'greeting' package that provides the Hello function.",
|
||||
},
|
||||
})
|
||||
|
||||
// Also show the greeting.go file with correct resource format
|
||||
messages = append(messages, mcp.PromptMessage{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.EmbeddedResource{
|
||||
Type: mcp.ContentTypeResource,
|
||||
Resource: struct {
|
||||
URI string `json:"uri"`
|
||||
MimeType string `json:"mimeType"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Blob string `json:"blob,omitempty"`
|
||||
}{
|
||||
URI: "file:///project/src/greeting/greeting.go",
|
||||
MimeType: "text/x-go",
|
||||
Text: greetingGoText,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
### Complete Application Example
|
||||
|
||||
Here's a complete example demonstrating all the components:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/conf"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/mcp"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Load configuration
|
||||
var c mcp.McpConf
|
||||
if err := conf.Load("config.yaml", &c); err != nil {
|
||||
log.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Set up logging
|
||||
logx.DisableStat()
|
||||
|
||||
// Create MCP server
|
||||
server := mcp.NewMcpServer(c)
|
||||
defer server.Stop()
|
||||
|
||||
// Register a simple echo tool
|
||||
echoTool := mcp.Tool{
|
||||
Name: "echo",
|
||||
Description: "Echoes back the message provided by the user",
|
||||
InputSchema: mcp.InputSchema{
|
||||
Properties: map[string]any{
|
||||
"message": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The message to echo back",
|
||||
},
|
||||
"prefix": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional prefix to add to the echoed message",
|
||||
"default": "Echo: ",
|
||||
},
|
||||
},
|
||||
Required: []string{"message"},
|
||||
},
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
var req struct {
|
||||
Message string `json:"message"`
|
||||
Prefix string `json:"prefix,optional"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||
}
|
||||
|
||||
prefix := "Echo: "
|
||||
if len(req.Prefix) > 0 {
|
||||
prefix = req.Prefix
|
||||
}
|
||||
|
||||
return prefix + req.Message, nil
|
||||
},
|
||||
}
|
||||
server.RegisterTool(echoTool)
|
||||
|
||||
// Register a static prompt
|
||||
server.RegisterPrompt(mcp.Prompt{
|
||||
Name: "greeting",
|
||||
Description: "A simple greeting prompt",
|
||||
Arguments: []mcp.PromptArgument{
|
||||
{
|
||||
Name: "name",
|
||||
Description: "The name to greet",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
Content: "Hello {{name}}! How can I assist you today?",
|
||||
})
|
||||
|
||||
// Register a dynamic prompt
|
||||
server.RegisterPrompt(mcp.Prompt{
|
||||
Name: "dynamic-prompt",
|
||||
Description: "A prompt that uses a handler to generate dynamic content",
|
||||
Arguments: []mcp.PromptArgument{
|
||||
{
|
||||
Name: "username",
|
||||
Description: "User's name for personalized greeting",
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "topic",
|
||||
Description: "Topic of expertise",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||
var req struct {
|
||||
Username string `json:"username"`
|
||||
Topic string `json:"topic"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||
}
|
||||
|
||||
// Create messages with current time
|
||||
currentTime := time.Now().Format(time.RFC1123)
|
||||
return []mcp.PromptMessage{
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("Hello, I'm %s and I'd like to learn about %s.", req.Username, req.Topic),
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("Hello %s! I'm an AI assistant and I'll help you learn about %s. The current time is %s.",
|
||||
req.Username, req.Topic, currentTime),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register a resource
|
||||
server.RegisterResource(mcp.Resource{
|
||||
Name: "example-doc",
|
||||
URI: "file:///example/doc.txt",
|
||||
Description: "An example document",
|
||||
MimeType: "text/plain",
|
||||
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||
return mcp.ResourceContent{
|
||||
URI: "file:///example/doc.txt",
|
||||
MimeType: "text/plain",
|
||||
Text: "This is the content of the example document.",
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
|
||||
// Start the server
|
||||
fmt.Printf("Starting MCP server on %s:%d\n", c.Host, c.Port)
|
||||
server.Start()
|
||||
}
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
The MCP implementation provides comprehensive error handling:
|
||||
|
||||
- Tool execution errors are properly reported back to clients
|
||||
- Missing or invalid parameters are detected and reported with appropriate error codes
|
||||
- Resource and prompt lookup failures are handled gracefully
|
||||
- Timeout handling for long-running tool executions using context
|
||||
- Panic recovery to prevent server crashes
|
||||
|
||||
## Advanced Features
|
||||
|
||||
- **Annotations**: Add audience and priority metadata to content
|
||||
- **Content Types**: Support for text, images, audio, and other content formats
|
||||
- **Embedded Resources**: Include file resources directly in prompt responses
|
||||
- **Context Awareness**: All handlers receive context.Context for timeout and cancellation support
|
||||
- **Progress Tokens**: Support for tracking progress of long-running operations
|
||||
- **Customizable Timeouts**: Configure execution timeouts for tools and operations
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
- Tool execution runs with configurable timeouts to prevent blocking
|
||||
- Efficient client tracking and cleanup to prevent resource leaks
|
||||
- Proper concurrency handling with mutex protection for shared resources
|
||||
- Buffered message channels to prevent blocking on client message delivery
|
||||
|
||||
239
mcp/server.go
239
mcp/server.go
@@ -42,14 +42,14 @@ func NewMcpServer(c McpConf) McpServer {
|
||||
Method: http.MethodGet,
|
||||
Path: s.conf.Mcp.SseEndpoint,
|
||||
Handler: s.handleSSE,
|
||||
}, rest.WithSSE())
|
||||
}, rest.WithSSE(), rest.WithTimeout(c.Mcp.SseTimeout))
|
||||
|
||||
// JSON-RPC message endpoint for regular requests
|
||||
s.server.AddRoute(rest.Route{
|
||||
Method: http.MethodPost,
|
||||
Path: s.conf.Mcp.MessageEndpoint,
|
||||
Handler: s.handleRequest,
|
||||
})
|
||||
}, rest.WithTimeout(c.Mcp.MessageTimeout))
|
||||
|
||||
return s
|
||||
}
|
||||
@@ -182,21 +182,23 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
// Always allow initialize and notifications/initialized regardless of client state
|
||||
if req.Method == methodInitialize {
|
||||
logx.Infof("Processing initialize request with ID: %d", req.ID)
|
||||
s.processInitialize(client, req)
|
||||
s.processInitialize(r.Context(), client, req)
|
||||
logx.Infof("Sent initialize response for ID: %d, waiting for notifications/initialized", req.ID)
|
||||
return
|
||||
} else if req.Method == methodNotificationsInitialized {
|
||||
// Handle initialized notification
|
||||
logx.Info("Received notifications/initialized notification")
|
||||
if !isNotification {
|
||||
s.sendErrorResponse(client, req.ID, "Method should be used as a notification", errCodeInvalidRequest)
|
||||
s.sendErrorResponse(r.Context(), client, req.ID,
|
||||
"Method should be used as a notification", errCodeInvalidRequest)
|
||||
return
|
||||
}
|
||||
s.processNotificationInitialized(client)
|
||||
return
|
||||
} else if !client.initialized && req.Method != methodNotificationsCancelled {
|
||||
// Block most requests until client is initialized (except for cancellations)
|
||||
s.sendErrorResponse(client, req.ID, "Client not fully initialized, waiting for notifications/initialized",
|
||||
s.sendErrorResponse(r.Context(), client, req.ID,
|
||||
"Client not fully initialized, waiting for notifications/initialized",
|
||||
errCodeClientNotInitialized)
|
||||
return
|
||||
}
|
||||
@@ -205,41 +207,41 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
switch req.Method {
|
||||
case methodToolsCall:
|
||||
logx.Infof("Received tools call request with ID: %d", req.ID)
|
||||
s.processToolCall(client, req)
|
||||
s.processToolCall(r.Context(), client, req)
|
||||
logx.Infof("Sent tools call response for ID: %d", req.ID)
|
||||
case methodToolsList:
|
||||
logx.Infof("Processing tools/list request with ID: %d", req.ID)
|
||||
s.processListTools(client, req)
|
||||
s.processListTools(r.Context(), client, req)
|
||||
logx.Infof("Sent tools/list response for ID: %d", req.ID)
|
||||
case methodPromptsList:
|
||||
logx.Infof("Processing prompts/list request with ID: %d", req.ID)
|
||||
s.processListPrompts(client, req)
|
||||
s.processListPrompts(r.Context(), client, req)
|
||||
logx.Infof("Sent prompts/list response for ID: %d", req.ID)
|
||||
case methodPromptsGet:
|
||||
logx.Infof("Processing prompts/get request with ID: %d", req.ID)
|
||||
s.processGetPrompt(client, req)
|
||||
s.processGetPrompt(r.Context(), client, req)
|
||||
logx.Infof("Sent prompts/get response for ID: %d", req.ID)
|
||||
case methodResourcesList:
|
||||
logx.Infof("Processing resources/list request with ID: %d", req.ID)
|
||||
s.processListResources(client, req)
|
||||
s.processListResources(r.Context(), client, req)
|
||||
logx.Infof("Sent resources/list response for ID: %d", req.ID)
|
||||
case methodResourcesRead:
|
||||
logx.Infof("Processing resources/read request with ID: %d", req.ID)
|
||||
s.processResourcesRead(client, req)
|
||||
s.processResourcesRead(r.Context(), client, req)
|
||||
logx.Infof("Sent resources/read response for ID: %d", req.ID)
|
||||
case methodResourcesSubscribe:
|
||||
logx.Infof("Processing resources/subscribe request with ID: %d", req.ID)
|
||||
s.processResourceSubscribe(client, req)
|
||||
s.processResourceSubscribe(r.Context(), client, req)
|
||||
logx.Infof("Sent resources/subscribe response for ID: %d", req.ID)
|
||||
case methodPing:
|
||||
logx.Infof("Processing ping request with ID: %d", req.ID)
|
||||
s.processPing(client, req)
|
||||
s.processPing(r.Context(), client, req)
|
||||
case methodNotificationsCancelled:
|
||||
logx.Infof("Received notifications/cancelled notification: %v", req.Params)
|
||||
s.processNotificationCancelled(client, req)
|
||||
logx.Infof("Received notifications/cancelled notification: %d", req.ID)
|
||||
s.processNotificationCancelled(r.Context(), client, req)
|
||||
default:
|
||||
logx.Infof("Unknown method: %s", req.Method)
|
||||
s.sendErrorResponse(client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||
logx.Infof("Unknown method: %s from client: %d", req.Method, req.ID)
|
||||
s.sendErrorResponse(r.Context(), client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -321,7 +323,7 @@ func (s *sseMcpServer) handleSSE(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
flusher.Flush()
|
||||
case <-r.Context().Done():
|
||||
// Client disconnected or request was canceled
|
||||
// Client disconnected or request was canceled or timed out
|
||||
logx.Infof("Client %s disconnected: context done", sessionId)
|
||||
return
|
||||
}
|
||||
@@ -329,7 +331,7 @@ func (s *sseMcpServer) handleSSE(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// processInitialize processes the initialize request
|
||||
func (s *sseMcpServer) processInitialize(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processInitialize(ctx context.Context, client *mcpClient, req Request) {
|
||||
// Create a proper JSON-RPC response that preserves the client's request ID
|
||||
result := initializationResponse{
|
||||
ProtocolVersion: s.conf.Mcp.ProtocolVersion,
|
||||
@@ -362,11 +364,11 @@ func (s *sseMcpServer) processInitialize(client *mcpClient, req Request) {
|
||||
client.initialized = true
|
||||
|
||||
// Send response with client's original request ID
|
||||
s.sendResponse(client, req.ID, result)
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
}
|
||||
|
||||
// processListTools processes the tools/list request
|
||||
func (s *sseMcpServer) processListTools(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processListTools(ctx context.Context, client *mcpClient, req Request) {
|
||||
// Extract pagination params if any
|
||||
var nextCursor string
|
||||
var progressToken any
|
||||
@@ -390,6 +392,9 @@ func (s *sseMcpServer) processListTools(client *mcpClient, req Request) {
|
||||
var toolsList []Tool
|
||||
s.toolsLock.Lock()
|
||||
for _, tool := range s.tools {
|
||||
if len(tool.InputSchema.Type) == 0 {
|
||||
tool.InputSchema.Type = ContentTypeObject
|
||||
}
|
||||
toolsList = append(toolsList, tool)
|
||||
}
|
||||
s.toolsLock.Unlock()
|
||||
@@ -405,15 +410,15 @@ func (s *sseMcpServer) processListTools(client *mcpClient, req Request) {
|
||||
// Add meta information if progress token was provided
|
||||
if progressToken != nil {
|
||||
result.Result.Meta = map[string]any{
|
||||
"progressToken": progressToken,
|
||||
progressTokenKey: progressToken,
|
||||
}
|
||||
}
|
||||
|
||||
s.sendResponse(client, req.ID, result)
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
}
|
||||
|
||||
// processListPrompts processes the prompts/list request
|
||||
func (s *sseMcpServer) processListPrompts(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processListPrompts(ctx context.Context, client *mcpClient, req Request) {
|
||||
// Extract pagination params if any
|
||||
var nextCursor string
|
||||
if req.Params != nil {
|
||||
@@ -447,11 +452,11 @@ func (s *sseMcpServer) processListPrompts(client *mcpClient, req Request) {
|
||||
NextCursor: nextCursor,
|
||||
}
|
||||
|
||||
s.sendResponse(client, req.ID, result)
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
}
|
||||
|
||||
// processListResources processes the resources/list request
|
||||
func (s *sseMcpServer) processListResources(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processListResources(ctx context.Context, client *mcpClient, req Request) {
|
||||
// Extract pagination params if any
|
||||
var nextCursor string
|
||||
var progressToken any
|
||||
@@ -493,15 +498,15 @@ func (s *sseMcpServer) processListResources(client *mcpClient, req Request) {
|
||||
// Add meta information if progress token was provided
|
||||
if progressToken != nil {
|
||||
result.Result.Meta = map[string]any{
|
||||
"progressToken": progressToken,
|
||||
progressTokenKey: progressToken,
|
||||
}
|
||||
}
|
||||
|
||||
s.sendResponse(client, req.ID, result)
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
}
|
||||
|
||||
// processGetPrompt processes the prompts/get request
|
||||
func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processGetPrompt(ctx context.Context, client *mcpClient, req Request) {
|
||||
type GetPromptParams struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]string `json:"arguments,omitempty"`
|
||||
@@ -509,7 +514,7 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
||||
|
||||
var params GetPromptParams
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -519,7 +524,7 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
||||
s.promptsLock.Unlock()
|
||||
if !exists {
|
||||
message := fmt.Sprintf("Prompt '%s' not found", params.Name)
|
||||
s.sendErrorResponse(client, req.ID, message, errCodeInvalidParams)
|
||||
s.sendErrorResponse(ctx, client, req.ID, message, errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -529,12 +534,15 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
||||
missingArgs := validatePromptArguments(prompt, params.Arguments)
|
||||
if len(missingArgs) > 0 {
|
||||
message := fmt.Sprintf("Missing required arguments: %s", strings.Join(missingArgs, ", "))
|
||||
s.sendErrorResponse(client, req.ID, message, errCodeInvalidParams)
|
||||
s.sendErrorResponse(ctx, client, req.ID, message, errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
// Apply default values for missing optional arguments
|
||||
args := applyDefaultArguments(prompt, params.Arguments)
|
||||
// Ensure arguments are initialized to an empty map if nil
|
||||
if params.Arguments == nil {
|
||||
params.Arguments = make(map[string]string)
|
||||
}
|
||||
args := params.Arguments
|
||||
|
||||
// Generate messages using handler or static content
|
||||
var messages []PromptMessage
|
||||
@@ -542,17 +550,17 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
||||
|
||||
if prompt.Handler != nil {
|
||||
// Use dynamic handler to generate messages
|
||||
logx.Info("Using prompt handler to generate content")
|
||||
messages, err = prompt.Handler(args)
|
||||
messages, err = prompt.Handler(ctx, args)
|
||||
if err != nil {
|
||||
logx.Errorf("Error from prompt handler: %v", err)
|
||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Error generating prompt content: %v", err), errCodeInternalError)
|
||||
s.sendErrorResponse(ctx, client, req.ID,
|
||||
fmt.Sprintf("Error generating prompt content: %v", err), errCodeInternalError)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// No handler, generate messages from static content
|
||||
var messageText string
|
||||
if prompt.Content != "" {
|
||||
if len(prompt.Content) > 0 {
|
||||
messageText = prompt.Content
|
||||
|
||||
// Apply argument substitutions to static content
|
||||
@@ -560,21 +568,13 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
||||
placeholder := fmt.Sprintf("{{%s}}", key)
|
||||
messageText = strings.Replace(messageText, placeholder, value, -1)
|
||||
}
|
||||
} else {
|
||||
// No content, use a default fallback
|
||||
topic := "this topic"
|
||||
if t, ok := args["topic"]; ok && t != "" {
|
||||
topic = t
|
||||
}
|
||||
messageText = fmt.Sprintf("Tell me about %s", topic)
|
||||
}
|
||||
|
||||
// Create a single user message with the content
|
||||
messages = []PromptMessage{
|
||||
{
|
||||
Role: roleUser,
|
||||
Role: RoleUser,
|
||||
Content: TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: messageText,
|
||||
},
|
||||
},
|
||||
@@ -587,49 +587,14 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
||||
Messages []PromptMessage `json:"messages"`
|
||||
}{
|
||||
Description: prompt.Description,
|
||||
Messages: messages,
|
||||
Messages: toTypedPromptMessages(messages),
|
||||
}
|
||||
|
||||
s.sendResponse(client, req.ID, result)
|
||||
}
|
||||
|
||||
// validatePromptArguments checks if all required arguments are provided
|
||||
// Returns a list of missing required arguments
|
||||
func validatePromptArguments(prompt Prompt, providedArgs map[string]string) []string {
|
||||
var missingArgs []string
|
||||
|
||||
for _, arg := range prompt.Arguments {
|
||||
if arg.Required {
|
||||
if value, exists := providedArgs[arg.Name]; !exists || len(value) == 0 {
|
||||
missingArgs = append(missingArgs, arg.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return missingArgs
|
||||
}
|
||||
|
||||
// applyDefaultArguments adds default values for missing optional arguments
|
||||
func applyDefaultArguments(prompt Prompt, providedArgs map[string]string) map[string]string {
|
||||
result := make(map[string]string)
|
||||
|
||||
// Copy all provided arguments
|
||||
for k, v := range providedArgs {
|
||||
result[k] = v
|
||||
}
|
||||
|
||||
// Add defaults for missing arguments
|
||||
for _, arg := range prompt.Arguments {
|
||||
if _, exists := result[arg.Name]; !exists && arg.Default != "" {
|
||||
result[arg.Name] = arg.Default
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
}
|
||||
|
||||
// processToolCall processes the tools/call request
|
||||
func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processToolCall(ctx context.Context, client *mcpClient, req Request) {
|
||||
var toolCallParams struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
@@ -642,7 +607,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
// If it's a RawMessage (JSON), unmarshal it
|
||||
if err := json.Unmarshal(req.Params, &toolCallParams); err != nil {
|
||||
logx.Errorf("Failed to unmarshal tool call params: %v", err)
|
||||
s.sendErrorResponse(client, req.ID, "Invalid tool call parameters", errCodeInvalidParams)
|
||||
s.sendErrorResponse(ctx, client, req.ID, "Invalid tool call parameters", errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -654,15 +619,11 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
tool, exists := s.tools[toolCallParams.Name]
|
||||
s.toolsLock.Unlock()
|
||||
if !exists {
|
||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Tool '%s' not found",
|
||||
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Tool '%s' not found",
|
||||
toolCallParams.Name), errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
// Create a context with the configured timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), s.conf.Mcp.ToolTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Log parameters before execution
|
||||
logx.Infof("Executing tool '%s' with arguments: %#v", toolCallParams.Name, toolCallParams.Arguments)
|
||||
|
||||
@@ -671,6 +632,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
var err error
|
||||
|
||||
// Create a channel to receive the result
|
||||
// make sure to have 1 size buffer to avoid channel leak if timeout
|
||||
resultCh := make(chan struct {
|
||||
result any
|
||||
err error
|
||||
@@ -678,7 +640,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
|
||||
// Execute the tool handler in a goroutine
|
||||
go func() {
|
||||
toolResult, toolErr := tool.Handler(toolCallParams.Arguments)
|
||||
toolResult, toolErr := tool.Handler(ctx, toolCallParams.Arguments)
|
||||
resultCh <- struct {
|
||||
result any
|
||||
err error
|
||||
@@ -694,9 +656,9 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
result = res.result
|
||||
err = res.err
|
||||
case <-ctx.Done():
|
||||
// Handle timeout
|
||||
logx.Errorf("Tool execution timed out after %v: %s", s.conf.Mcp.ToolTimeout, toolCallParams.Name)
|
||||
s.sendErrorResponse(client, req.ID, "Tool execution timed out", errCodeTimeout)
|
||||
// Handle request timeout
|
||||
logx.Errorf("Tool execution timed out after %v: %s", s.conf.Mcp.MessageTimeout, toolCallParams.Name)
|
||||
s.sendErrorResponse(ctx, client, req.ID, "Tool execution timed out", errCodeTimeout)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -710,7 +672,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
// Add meta information if progress token was provided
|
||||
if progressToken != nil {
|
||||
callToolResult.Result.Meta = map[string]any{
|
||||
"progressToken": progressToken,
|
||||
progressTokenKey: progressToken,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -722,12 +684,11 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
|
||||
callToolResult.Content = []any{
|
||||
TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: fmt.Sprintf("Error: %v", err),
|
||||
},
|
||||
}
|
||||
callToolResult.IsError = true
|
||||
s.sendResponse(client, req.ID, callToolResult)
|
||||
s.sendResponse(ctx, client, req.ID, callToolResult)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -736,10 +697,9 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
case string:
|
||||
// Simple string becomes text content
|
||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: v,
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
},
|
||||
})
|
||||
case map[string]any:
|
||||
@@ -749,69 +709,63 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
jsonStr = []byte(err.Error())
|
||||
}
|
||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: string(jsonStr),
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
},
|
||||
})
|
||||
case TextContent:
|
||||
// Direct TextContent object
|
||||
callToolResult.Content = append(callToolResult.Content, v)
|
||||
case ImageContent:
|
||||
// Direct ImageContent object
|
||||
callToolResult.Content = append(callToolResult.Content, v)
|
||||
case []any:
|
||||
// Array of content items
|
||||
callToolResult.Content = v
|
||||
case ToolResult:
|
||||
// Handle legacy ToolResult type
|
||||
switch v.Type {
|
||||
case contentTypeText:
|
||||
case ContentTypeText:
|
||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: fmt.Sprintf("%v", v.Content),
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
},
|
||||
})
|
||||
case contentTypeImage:
|
||||
case ContentTypeImage:
|
||||
if imgData, ok := v.Content.(map[string]any); ok {
|
||||
callToolResult.Content = append(callToolResult.Content, ImageContent{
|
||||
Type: contentTypeImage,
|
||||
Data: fmt.Sprintf("%v", imgData["data"]),
|
||||
MimeType: fmt.Sprintf("%v", imgData["mimeType"]),
|
||||
})
|
||||
}
|
||||
default:
|
||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: fmt.Sprintf("%v", v.Content),
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
},
|
||||
})
|
||||
}
|
||||
default:
|
||||
// For any other type, convert to string
|
||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: fmt.Sprintf("%v", v),
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
callToolResult.Content = toTypedContents(callToolResult.Content)
|
||||
logx.Infof("Tool call result: %#v", callToolResult)
|
||||
s.sendResponse(client, req.ID, callToolResult)
|
||||
|
||||
s.sendResponse(ctx, client, req.ID, callToolResult)
|
||||
}
|
||||
|
||||
// processResourcesRead processes the resources/read request
|
||||
func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processResourcesRead(ctx context.Context, client *mcpClient, req Request) {
|
||||
var params ResourceReadParams
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -821,7 +775,7 @@ func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
|
||||
s.resourcesLock.Unlock()
|
||||
|
||||
if !exists {
|
||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||
params.URI), errCodeResourceNotFound)
|
||||
return
|
||||
}
|
||||
@@ -837,14 +791,14 @@ func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
|
||||
},
|
||||
},
|
||||
}
|
||||
s.sendResponse(client, req.ID, result)
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
return
|
||||
}
|
||||
|
||||
// Execute the resource handler
|
||||
content, err := resource.Handler()
|
||||
content, err := resource.Handler(ctx)
|
||||
if err != nil {
|
||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Error reading resource: %v", err),
|
||||
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Error reading resource: %v", err),
|
||||
errCodeInternalError)
|
||||
return
|
||||
}
|
||||
@@ -865,14 +819,14 @@ func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
|
||||
Contents: []ResourceContent{content},
|
||||
}
|
||||
|
||||
s.sendResponse(client, req.ID, result)
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
}
|
||||
|
||||
// processResourceSubscribe processes the resources/subscribe request
|
||||
func (s *sseMcpServer) processResourceSubscribe(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processResourceSubscribe(ctx context.Context, client *mcpClient, req Request) {
|
||||
var params ResourceSubscribeParams
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -882,19 +836,17 @@ func (s *sseMcpServer) processResourceSubscribe(client *mcpClient, req Request)
|
||||
s.resourcesLock.Unlock()
|
||||
|
||||
if !exists {
|
||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||
params.URI), errCodeResourceNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Send success response for the subscription
|
||||
s.sendResponse(client, req.ID, struct{}{})
|
||||
|
||||
logx.Infof("Client %s subscribed to resource '%s'", client.id, params.URI)
|
||||
s.sendResponse(ctx, client, req.ID, struct{}{})
|
||||
}
|
||||
|
||||
// processNotificationCancelled processes the notifications/cancelled notification
|
||||
func (s *sseMcpServer) processNotificationCancelled(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processNotificationCancelled(ctx context.Context, client *mcpClient, req Request) {
|
||||
// Extract the requestId that was canceled
|
||||
type CancelParams struct {
|
||||
RequestId int64 `json:"requestId"`
|
||||
@@ -918,18 +870,17 @@ func (s *sseMcpServer) processNotificationInitialized(client *mcpClient) {
|
||||
}
|
||||
|
||||
// processPing processes the ping request and responds immediately
|
||||
func (s *sseMcpServer) processPing(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processPing(ctx context.Context, client *mcpClient, req Request) {
|
||||
// A ping request should simply respond with an empty result to confirm the server is alive
|
||||
logx.Infof("Received ping request with ID: %d", req.ID)
|
||||
|
||||
// Send an empty response with client's original request ID
|
||||
s.sendResponse(client, req.ID, struct{}{})
|
||||
|
||||
logx.Infof("Sent ping response for ID: %d", req.ID)
|
||||
s.sendResponse(ctx, client, req.ID, struct{}{})
|
||||
}
|
||||
|
||||
// sendErrorResponse sends an error response via the SSE channel
|
||||
func (s *sseMcpServer) sendErrorResponse(client *mcpClient, id int64, message string, code int) {
|
||||
func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
|
||||
id int64, message string, code int) {
|
||||
errorResponse := struct {
|
||||
JsonRpc string `json:"jsonrpc"`
|
||||
ID int64 `json:"id"`
|
||||
@@ -949,11 +900,17 @@ func (s *sseMcpServer) sendErrorResponse(client *mcpClient, id int64, message st
|
||||
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||
logx.Infof("Sending error for ID %d: %s", id, sseMessage)
|
||||
|
||||
client.channel <- sseMessage
|
||||
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
|
||||
select {
|
||||
case client.channel <- sseMessage:
|
||||
default:
|
||||
// Channel buffer is full, log warning and continue
|
||||
logx.Infof("Client %s channel is full while sending error response with ID %d", client.id, id)
|
||||
}
|
||||
}
|
||||
|
||||
// sendResponse sends a success response via the SSE channel
|
||||
func (s *sseMcpServer) sendResponse(client *mcpClient, id int64, result any) {
|
||||
func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id int64, result any) {
|
||||
response := Response{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: id,
|
||||
@@ -962,7 +919,7 @@ func (s *sseMcpServer) sendResponse(client *mcpClient, id int64, result any) {
|
||||
|
||||
jsonData, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
s.sendErrorResponse(client, id, "Failed to marshal response", errCodeInternalError)
|
||||
s.sendErrorResponse(ctx, client, id, "Failed to marshal response", errCodeInternalError)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -970,5 +927,11 @@ func (s *sseMcpServer) sendResponse(client *mcpClient, id int64, result any) {
|
||||
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||
logx.Infof("Sending response for ID %d: %s", id, sseMessage)
|
||||
|
||||
client.channel <- sseMessage
|
||||
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
|
||||
select {
|
||||
case client.channel <- sseMessage:
|
||||
default:
|
||||
// Channel buffer is full, log warning and continue
|
||||
logx.Infof("Client %s channel is full while sending response with ID %d", client.id, id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ host: localhost
|
||||
port: 8080
|
||||
mcp:
|
||||
name: mcp-test-server
|
||||
toolTimeout: 5s
|
||||
messageTimeout: 5s
|
||||
`
|
||||
|
||||
var c McpConf
|
||||
@@ -82,7 +82,6 @@ func (m *mockMcpServer) registerExampleTool() {
|
||||
Name: "test.tool",
|
||||
Description: "A test tool",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"input": map[string]any{
|
||||
"type": "string",
|
||||
@@ -91,7 +90,7 @@ func (m *mockMcpServer) registerExampleTool() {
|
||||
},
|
||||
Required: []string{"input"},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
input, ok := params["input"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid input parameter")
|
||||
@@ -135,7 +134,7 @@ port: 8080
|
||||
mcp:
|
||||
cors:
|
||||
- http://localhost:3000
|
||||
toolTimeout: 5s
|
||||
messageTimeout: 5s
|
||||
`
|
||||
|
||||
var c McpConf
|
||||
@@ -186,7 +185,6 @@ func TestRegisterTool(t *testing.T) {
|
||||
Name: "example.tool",
|
||||
Description: "An example tool",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"input": map[string]any{
|
||||
"type": "string",
|
||||
@@ -194,7 +192,7 @@ func TestRegisterTool(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return "result", nil
|
||||
},
|
||||
}
|
||||
@@ -280,7 +278,7 @@ func TestToolsList(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processListTools(client, req)
|
||||
mock.server.processListTools(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -328,7 +326,7 @@ func TestToolCallBasic(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -355,8 +353,7 @@ func TestToolCallBasic(t *testing.T) {
|
||||
|
||||
// Verify the response content
|
||||
assert.Len(t, parsed.Result.Content, 1, "Response should contain one content item")
|
||||
assert.Equal(t, "text", parsed.Result.Content[0]["type"], "Content type should be text")
|
||||
assert.Equal(t, "Processed: test-input", parsed.Result.Content[0]["text"], "Tool result incorrect")
|
||||
assert.Equal(t, "Processed: test-input", parsed.Result.Content[0][ContentTypeText], "Tool result incorrect")
|
||||
assert.False(t, parsed.Result.IsError, "Response should not be an error")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for tool call response")
|
||||
@@ -373,10 +370,9 @@ func TestToolCallMapResult(t *testing.T) {
|
||||
Name: "map.tool",
|
||||
Description: "A tool that returns a map result",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
// Return a complex nested map structure
|
||||
return map[string]any{
|
||||
"string": "value",
|
||||
@@ -417,7 +413,7 @@ func TestToolCallMapResult(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -445,13 +441,8 @@ func TestToolCallMapResult(t *testing.T) {
|
||||
firstItem, ok := content[0].(map[string]any)
|
||||
require.True(t, ok, "First content item should be an object")
|
||||
|
||||
// Verify it's a text content
|
||||
contentType, ok := firstItem["type"].(string)
|
||||
require.True(t, ok, "Content should have a type")
|
||||
assert.Equal(t, "text", contentType, "Content type should be text")
|
||||
|
||||
// Get the text content which should be our JSON
|
||||
text, ok := firstItem["text"].(string)
|
||||
text, ok := firstItem[ContentTypeText].(string)
|
||||
require.True(t, ok, "Content should have text")
|
||||
|
||||
// Verify the text is valid JSON and contains our data
|
||||
@@ -496,10 +487,9 @@ func TestToolCallArrayResult(t *testing.T) {
|
||||
Name: "array.tool",
|
||||
Description: "A tool that returns an array result",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
// Return an array of mixed content types
|
||||
return []any{
|
||||
"string item",
|
||||
@@ -536,7 +526,7 @@ func TestToolCallArrayResult(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -574,16 +564,14 @@ func TestToolCallTextContentResult(t *testing.T) {
|
||||
Name: "text.content.tool",
|
||||
Description: "A tool that returns a TextContent result",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
// Return a TextContent object directly
|
||||
return TextContent{
|
||||
Type: "text",
|
||||
Text: "This is a direct TextContent result",
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
Priority: func() *float64 { p := 0.9; return &p }(),
|
||||
},
|
||||
}, nil
|
||||
@@ -614,7 +602,7 @@ func TestToolCallTextContentResult(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -642,16 +630,6 @@ func TestToolCallTextContentResult(t *testing.T) {
|
||||
firstItem, ok := content[0].(map[string]any)
|
||||
require.True(t, ok, "First content item should be an object")
|
||||
|
||||
// Verify it's a text content
|
||||
contentType, ok := firstItem["type"].(string)
|
||||
require.True(t, ok, "Content should have a type")
|
||||
assert.Equal(t, "text", contentType, "Content type should be text")
|
||||
|
||||
// Check text content
|
||||
text, ok := firstItem["text"].(string)
|
||||
require.True(t, ok, "Content should have text")
|
||||
assert.Equal(t, "This is a direct TextContent result", text, "Text content should match")
|
||||
|
||||
// Check annotations
|
||||
annotations, ok := firstItem["annotations"].(map[string]any)
|
||||
require.True(t, ok, "Should have annotations")
|
||||
@@ -679,13 +657,11 @@ func TestToolCallImageContentResult(t *testing.T) {
|
||||
Name: "image.content.tool",
|
||||
Description: "A tool that returns an ImageContent result",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
// Return an ImageContent object directly
|
||||
return ImageContent{
|
||||
Type: "image",
|
||||
Data: "dGVzdCBpbWFnZSBkYXRhIChiYXNlNjQgZW5jb2RlZCk=", // "test image data (base64 encoded)" in base64
|
||||
MimeType: "image/png",
|
||||
}, nil
|
||||
@@ -716,7 +692,7 @@ func TestToolCallImageContentResult(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -744,11 +720,6 @@ func TestToolCallImageContentResult(t *testing.T) {
|
||||
firstItem, ok := content[0].(map[string]any)
|
||||
require.True(t, ok, "First content item should be an object")
|
||||
|
||||
// Verify it's an image content
|
||||
contentType, ok := firstItem["type"].(string)
|
||||
require.True(t, ok, "Content should have a type")
|
||||
assert.Equal(t, "image", contentType, "Content type should be image")
|
||||
|
||||
// Check image data
|
||||
data, ok := firstItem["data"].(string)
|
||||
require.True(t, ok, "Content should have data")
|
||||
@@ -773,12 +744,12 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
Name: "toolresult.tool",
|
||||
Description: "A tool that returns a ToolResult object",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Type: ContentTypeObject,
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return ToolResult{
|
||||
Type: "text",
|
||||
Type: ContentTypeText,
|
||||
Content: "This is a ToolResult with text content type",
|
||||
}, nil
|
||||
},
|
||||
@@ -790,10 +761,10 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
Name: "toolresult.image.tool",
|
||||
Description: "A tool that returns a ToolResult with image content",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Type: ContentTypeObject,
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return ToolResult{
|
||||
Type: "image",
|
||||
Content: map[string]any{
|
||||
@@ -810,10 +781,9 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
Name: "toolresult.audio.tool",
|
||||
Description: "A tool that returns a ToolResult with audio content",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
// Test with image type
|
||||
return ToolResult{
|
||||
Type: "audio",
|
||||
@@ -831,10 +801,9 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
Name: "toolresult.int.tool",
|
||||
Description: "A tool that returns a ToolResult with int content",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return 2, nil
|
||||
},
|
||||
}
|
||||
@@ -845,10 +814,9 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
Name: "toolresult.bad.tool",
|
||||
Description: "A tool that returns a ToolResult with bad content",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return map[string]any{
|
||||
"type": "custom",
|
||||
"data": make(chan int),
|
||||
@@ -881,7 +849,7 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -909,13 +877,8 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
firstItem, ok := content[0].(map[string]any)
|
||||
require.True(t, ok, "First content item should be an object")
|
||||
|
||||
// Verify it's a text content
|
||||
contentType, ok := firstItem["type"].(string)
|
||||
require.True(t, ok, "Content should have a type")
|
||||
assert.Equal(t, "text", contentType, "Content type should be text")
|
||||
|
||||
// Check text content
|
||||
text, ok := firstItem["text"].(string)
|
||||
text, ok := firstItem[ContentTypeText].(string)
|
||||
require.True(t, ok, "Content should have text")
|
||||
assert.Equal(t, "This is a ToolResult with text content type", text, "Text content should match")
|
||||
|
||||
@@ -947,7 +910,7 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -975,11 +938,6 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
firstItem, ok := content[0].(map[string]any)
|
||||
require.True(t, ok, "First content item should be an object")
|
||||
|
||||
// Verify it's an image content
|
||||
contentType, ok := firstItem["type"].(string)
|
||||
require.True(t, ok, "Content should have a type")
|
||||
assert.Equal(t, "image", contentType, "Content type should be image")
|
||||
|
||||
// Check image data and mime type
|
||||
data, ok := firstItem["data"].(string)
|
||||
require.True(t, ok, "Content should have data")
|
||||
@@ -1017,7 +975,7 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -1040,15 +998,6 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
content, ok := result["content"].([]any)
|
||||
require.True(t, ok, "Result should have a content array")
|
||||
require.NotEmpty(t, content, "Content should not be empty")
|
||||
|
||||
// The first content item should be converted from ToolResult to ImageContent
|
||||
firstItem, ok := content[0].(map[string]any)
|
||||
require.True(t, ok, "First content item should be an object")
|
||||
|
||||
// Verify it's an image content
|
||||
contentType, ok := firstItem["type"].(string)
|
||||
require.True(t, ok, "Content should have a type")
|
||||
assert.Equal(t, "text", contentType, "Content type should be image")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for tool call response")
|
||||
}
|
||||
@@ -1077,7 +1026,7 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -1100,15 +1049,6 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
content, ok := result["content"].([]any)
|
||||
require.True(t, ok, "Result should have a content array")
|
||||
require.NotEmpty(t, content, "Content should not be empty")
|
||||
|
||||
// The first content item should be converted from ToolResult to ImageContent
|
||||
firstItem, ok := content[0].(map[string]any)
|
||||
require.True(t, ok, "First content item should be an object")
|
||||
|
||||
// Verify it's an image content
|
||||
contentType, ok := firstItem["type"].(string)
|
||||
require.True(t, ok, "Content should have a type")
|
||||
assert.Equal(t, "text", contentType, "Content type should be image")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for tool call response")
|
||||
}
|
||||
@@ -1137,7 +1077,7 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -1159,10 +1099,9 @@ func TestToolCallError(t *testing.T) {
|
||||
Name: "error.tool",
|
||||
Description: "A tool that returns an error",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return nil, fmt.Errorf("tool execution failed")
|
||||
},
|
||||
})
|
||||
@@ -1189,7 +1128,7 @@ func TestToolCallError(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Check the response
|
||||
select {
|
||||
@@ -1207,20 +1146,16 @@ func TestToolCallTimeout(t *testing.T) {
|
||||
mock := newMockMcpServer(t)
|
||||
defer mock.shutdown()
|
||||
|
||||
// Set a very short timeout for testing
|
||||
mock.server.conf.Mcp.ToolTimeout = 10 * time.Millisecond
|
||||
|
||||
// Register a tool that times out
|
||||
err := mock.server.RegisterTool(Tool{
|
||||
Name: "timeout.tool",
|
||||
Description: "A tool that times out",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
time.Sleep(50 * time.Millisecond) // Sleep longer than timeout
|
||||
return "this should never be returned", nil
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
<-ctx.Done()
|
||||
return nil, fmt.Errorf("tool execution timed out")
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -1244,16 +1179,24 @@ func TestToolCallTimeout(t *testing.T) {
|
||||
Method: methodToolsCall,
|
||||
Params: paramBytes,
|
||||
}
|
||||
jsonBody, _ := json.Marshal(req)
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
// Create HTTP request
|
||||
r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client", bytes.NewReader(jsonBody))
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
r = r.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Process through handleRequest
|
||||
go mock.server.handleRequest(w, r)
|
||||
|
||||
// Check the response
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
assert.Contains(t, response, "event: message", "Response should have message event")
|
||||
assert.Contains(t, response, `-32001`, "Response should contain a timeout error code")
|
||||
case <-time.After(150 * time.Millisecond):
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for tool call response")
|
||||
}
|
||||
}
|
||||
@@ -1274,7 +1217,7 @@ func TestInitializeAndNotifications(t *testing.T) {
|
||||
Params: json.RawMessage(`{}`),
|
||||
}
|
||||
|
||||
mock.server.processInitialize(client, initReq)
|
||||
mock.server.processInitialize(context.Background(), client, initReq)
|
||||
|
||||
// Check that client is initialized after initialize request
|
||||
assert.True(t, client.initialized, "Client should be marked as initialized after initialize request")
|
||||
@@ -1418,7 +1361,7 @@ func TestNotificationCancelled_badParams(t *testing.T) {
|
||||
}
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
mock.server.processNotificationCancelled(client, cancelReq)
|
||||
mock.server.processNotificationCancelled(context.Background(), client, cancelReq)
|
||||
|
||||
select {
|
||||
case <-client.channel:
|
||||
@@ -1593,7 +1536,7 @@ func TestGetPrompt(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the request
|
||||
mock.server.processGetPrompt(client, promptReq)
|
||||
mock.server.processGetPrompt(context.Background(), client, promptReq)
|
||||
|
||||
// Check response
|
||||
select {
|
||||
@@ -1622,7 +1565,7 @@ func TestGetPrompt(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the request
|
||||
mock.server.processGetPrompt(client, promptReq)
|
||||
mock.server.processGetPrompt(context.Background(), client, promptReq)
|
||||
|
||||
// Check response
|
||||
select {
|
||||
@@ -1636,6 +1579,44 @@ func TestGetPrompt(t *testing.T) {
|
||||
t.Fatal("Timed out waiting for prompt response")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("test prompt with nil params", func(t *testing.T) {
|
||||
mock := newMockMcpServer(t)
|
||||
defer mock.shutdown()
|
||||
|
||||
// Create a test client
|
||||
client := addTestClient(mock.server, "test-client", true)
|
||||
|
||||
// Register a test prompt
|
||||
testPrompt := Prompt{
|
||||
Name: "test.prompt",
|
||||
Description: "A test prompt",
|
||||
}
|
||||
mock.server.RegisterPrompt(testPrompt)
|
||||
|
||||
// Create a get prompt request
|
||||
paramBytes, _ := json.Marshal(map[string]any{
|
||||
"name": "test.prompt",
|
||||
})
|
||||
promptReq := Request{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: 1,
|
||||
Method: "prompts/get",
|
||||
Params: paramBytes,
|
||||
}
|
||||
|
||||
// Process the request
|
||||
mock.server.processGetPrompt(context.Background(), client, promptReq)
|
||||
|
||||
// Check response
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
_, err := parseEvent(response)
|
||||
assert.NoError(t, err, "Should be able to parse event")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for prompt response")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestBroadcast tests the broadcast functionality
|
||||
@@ -1903,34 +1884,79 @@ func TestNotificationInitialized(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestSendBadResponse(t *testing.T) {
|
||||
mock := newMockMcpServer(t)
|
||||
defer mock.shutdown()
|
||||
func TestSendResponse(t *testing.T) {
|
||||
t.Run("bad response", func(t *testing.T) {
|
||||
mock := newMockMcpServer(t)
|
||||
defer mock.shutdown()
|
||||
|
||||
// Create a test client
|
||||
client := addTestClient(mock.server, "test-client", true)
|
||||
// Create a test client
|
||||
client := addTestClient(mock.server, "test-client", true)
|
||||
|
||||
// Create a response
|
||||
response := Response{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: 1,
|
||||
Result: make(chan int),
|
||||
}
|
||||
// Create a response
|
||||
response := Response{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: 1,
|
||||
Result: make(chan int),
|
||||
}
|
||||
|
||||
// Send the response
|
||||
mock.server.sendResponse(client, 1, response)
|
||||
// Send the response
|
||||
mock.server.sendResponse(context.Background(), client, 1, response)
|
||||
|
||||
// Check the response in the client's channel
|
||||
select {
|
||||
case res := <-client.channel:
|
||||
evt, err := parseEvent(res)
|
||||
require.NoError(t, err, "Should parse event without error")
|
||||
errMsg, ok := evt.Data["error"].(map[string]any)
|
||||
require.True(t, ok, "Should have error in response")
|
||||
assert.Equal(t, float64(errCodeInternalError), errMsg["code"], "Error code should match")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for response")
|
||||
}
|
||||
// Check the response in the client's channel
|
||||
select {
|
||||
case res := <-client.channel:
|
||||
evt, err := parseEvent(res)
|
||||
require.NoError(t, err, "Should parse event without error")
|
||||
errMsg, ok := evt.Data["error"].(map[string]any)
|
||||
require.True(t, ok, "Should have error in response")
|
||||
assert.Equal(t, float64(errCodeInternalError), errMsg["code"], "Error code should match")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for response")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("channel full", func(t *testing.T) {
|
||||
mock := newMockMcpServer(t)
|
||||
defer mock.shutdown()
|
||||
|
||||
// Create a test client
|
||||
client := addTestClient(mock.server, "test-client", true)
|
||||
for i := 0; i < eventChanSize; i++ {
|
||||
client.channel <- "test"
|
||||
}
|
||||
|
||||
// Create a response
|
||||
response := Response{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: 1,
|
||||
Result: "foo",
|
||||
}
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
// Send the response
|
||||
mock.server.sendResponse(context.Background(), client, 1, response)
|
||||
// Check the response in the client's channel
|
||||
assert.Contains(t, buf.String(), "channel is full")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSendErrorResponse(t *testing.T) {
|
||||
t.Run("channel full", func(t *testing.T) {
|
||||
mock := newMockMcpServer(t)
|
||||
defer mock.shutdown()
|
||||
|
||||
// Create a test client
|
||||
client := addTestClient(mock.server, "test-client", true)
|
||||
for i := 0; i < eventChanSize; i++ {
|
||||
client.channel <- "test"
|
||||
}
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
// Send the response
|
||||
mock.server.sendErrorResponse(context.Background(), client, 1, "foo", errCodeInternalError)
|
||||
// Check the response in the client's channel
|
||||
assert.Contains(t, buf.String(), "channel is full")
|
||||
})
|
||||
}
|
||||
|
||||
// TestMethodToolsCall tests the handling of tools/call method through handleRequest
|
||||
@@ -2028,8 +2054,7 @@ func TestMethodToolsCall(t *testing.T) {
|
||||
if len(content) > 0 {
|
||||
firstItem, ok := content[0].(map[string]any)
|
||||
if ok {
|
||||
assert.Equal(t, "text", firstItem["type"], "Content type should be text")
|
||||
assert.Contains(t, firstItem["text"], "Processed: test-input", "Content should include processed input")
|
||||
assert.Contains(t, firstItem[ContentTypeText], "Processed: test-input", "Content should include processed input")
|
||||
}
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
@@ -2145,7 +2170,6 @@ func TestMethodPromptsGet(t *testing.T) {
|
||||
{
|
||||
Name: "topic",
|
||||
Description: "Topic to discuss",
|
||||
Default: "artificial intelligence",
|
||||
},
|
||||
},
|
||||
Content: "Hello {{name}}! Let's talk about {{topic}}.",
|
||||
@@ -2227,14 +2251,12 @@ func TestMethodPromptsGet(t *testing.T) {
|
||||
if len(messages) > 0 {
|
||||
message, ok := messages[0].(map[string]any)
|
||||
require.True(t, ok, "Message should be an object")
|
||||
assert.Equal(t, "user", message["role"], "Role should be 'user'")
|
||||
assert.Equal(t, string(RoleUser), message["role"], "Role should be 'user'")
|
||||
|
||||
content, ok := message["content"].(map[string]any)
|
||||
require.True(t, ok, "Should have content object")
|
||||
assert.Equal(t, "text", content["type"], "Content type should be text")
|
||||
assert.Contains(t, content["text"], "Hello Test User", "Content should include the name argument")
|
||||
assert.Contains(t, content["text"], "about artificial intelligence",
|
||||
"Content should include the default topic argument")
|
||||
assert.Equal(t, ContentTypeText, content["type"], "Content type should be text")
|
||||
assert.Contains(t, content[ContentTypeText], "Hello Test User", "Content should include the name argument")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for prompt get response")
|
||||
@@ -2255,27 +2277,24 @@ func TestMethodPromptsGet(t *testing.T) {
|
||||
{
|
||||
Name: "question",
|
||||
Description: "User's question",
|
||||
Default: "How does this work?",
|
||||
},
|
||||
},
|
||||
Handler: func(args map[string]string) ([]PromptMessage, error) {
|
||||
Handler: func(ctx context.Context, args map[string]string) ([]PromptMessage, error) {
|
||||
username := args["username"]
|
||||
question := args["question"]
|
||||
|
||||
// Create a system message
|
||||
systemMessage := PromptMessage{
|
||||
Role: "system",
|
||||
Role: RoleAssistant,
|
||||
Content: TextContent{
|
||||
Type: "text",
|
||||
Text: "You are a helpful assistant.",
|
||||
},
|
||||
}
|
||||
|
||||
// Create a user message
|
||||
userMessage := PromptMessage{
|
||||
Role: "user",
|
||||
Role: RoleUser,
|
||||
Content: TextContent{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("Hi, I'm %s and I'm wondering: %s", username, question),
|
||||
},
|
||||
}
|
||||
@@ -2340,20 +2359,20 @@ func TestMethodPromptsGet(t *testing.T) {
|
||||
|
||||
// Check message content
|
||||
if len(messages) >= 2 {
|
||||
// First message should be system
|
||||
// First message should be assistant
|
||||
message1, _ := messages[0].(map[string]any)
|
||||
assert.Equal(t, "system", message1["role"], "First role should be 'system'")
|
||||
assert.Equal(t, string(RoleAssistant), message1["role"], "First role should be 'system'")
|
||||
|
||||
content1, _ := message1["content"].(map[string]any)
|
||||
assert.Contains(t, content1["text"], "helpful assistant", "System message should be correct")
|
||||
assert.Contains(t, content1[ContentTypeText], "helpful assistant", "System message should be correct")
|
||||
|
||||
// Second message should be user
|
||||
message2, _ := messages[1].(map[string]any)
|
||||
assert.Equal(t, "user", message2["role"], "Second role should be 'user'")
|
||||
assert.Equal(t, string(RoleUser), message2["role"], "Second role should be 'user'")
|
||||
|
||||
content2, _ := message2["content"].(map[string]any)
|
||||
assert.Contains(t, content2["text"], "Dynamic User", "User message should contain username")
|
||||
assert.Contains(t, content2["text"], "How to test this?", "User message should contain question")
|
||||
assert.Contains(t, content2[ContentTypeText], "Dynamic User", "User message should contain username")
|
||||
assert.Contains(t, content2[ContentTypeText], "How to test this?", "User message should contain question")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for prompt get response")
|
||||
@@ -2459,7 +2478,7 @@ func TestMethodPromptsGet(t *testing.T) {
|
||||
Name: "error-handler-prompt",
|
||||
Description: "A prompt with a handler that returns an error",
|
||||
Arguments: []PromptArgument{},
|
||||
Handler: func(args map[string]string) ([]PromptMessage, error) {
|
||||
Handler: func(ctx context.Context, args map[string]string) ([]PromptMessage, error) {
|
||||
return nil, fmt.Errorf("test handler error")
|
||||
},
|
||||
}
|
||||
@@ -2583,7 +2602,7 @@ func TestMethodResourcesList(t *testing.T) {
|
||||
URI: "file:///test/resource.txt",
|
||||
Description: "A test resource with handler",
|
||||
MimeType: "text/plain",
|
||||
Handler: func() (ResourceContent, error) {
|
||||
Handler: func(ctx context.Context) (ResourceContent, error) {
|
||||
return ResourceContent{
|
||||
URI: "file:///test/resource.txt",
|
||||
MimeType: "text/plain",
|
||||
@@ -2654,7 +2673,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
||||
URI: "file:///test/resource.txt",
|
||||
Description: "A test resource with handler",
|
||||
MimeType: "text/plain",
|
||||
Handler: func() (ResourceContent, error) {
|
||||
Handler: func(ctx context.Context) (ResourceContent, error) {
|
||||
return ResourceContent{
|
||||
URI: "file:///test/resource.txt",
|
||||
MimeType: "text/plain",
|
||||
@@ -2729,7 +2748,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
||||
require.True(t, ok, "Content should be an object")
|
||||
assert.Equal(t, "file:///test/resource.txt", content["uri"], "URI should match")
|
||||
assert.Equal(t, "text/plain", content["mimeType"], "MimeType should match")
|
||||
assert.Equal(t, "This is test resource content", content["text"], "Text content should match")
|
||||
assert.Equal(t, "This is test resource content", content[ContentTypeText], "Text content should match")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for resource read response")
|
||||
@@ -2799,7 +2818,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
||||
require.True(t, ok, "Content should be an object")
|
||||
assert.Equal(t, "file:///test/no-handler.txt", content["uri"], "URI should match")
|
||||
assert.Equal(t, "text/plain", content["mimeType"], "MimeType should match")
|
||||
_, ok = content["text"]
|
||||
_, ok = content[ContentTypeText]
|
||||
assert.False(t, ok, "Text content should be empty string")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
@@ -2880,7 +2899,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
||||
client := addTestClient(mock.server, "test-client-resources", true)
|
||||
|
||||
// Process through handleRequest
|
||||
mock.server.processResourcesRead(client, req)
|
||||
mock.server.processResourcesRead(context.Background(), client, req)
|
||||
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
@@ -2898,7 +2917,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
||||
URI: "file:///test/error.txt",
|
||||
Description: "A test resource with handler that returns error",
|
||||
MimeType: "text/plain",
|
||||
Handler: func() (ResourceContent, error) {
|
||||
Handler: func(ctx context.Context) (ResourceContent, error) {
|
||||
return ResourceContent{}, fmt.Errorf("test handler error")
|
||||
},
|
||||
}
|
||||
@@ -2946,7 +2965,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
||||
URI: "file:///test/missing-fields.txt",
|
||||
Description: "A test resource with handler that returns content missing fields",
|
||||
MimeType: "text/plain",
|
||||
Handler: func() (ResourceContent, error) {
|
||||
Handler: func(ctx context.Context) (ResourceContent, error) {
|
||||
// Return ResourceContent without URI and MimeType
|
||||
return ResourceContent{
|
||||
Text: "Content with missing fields",
|
||||
@@ -3006,7 +3025,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
||||
require.True(t, ok, "Content should be an object")
|
||||
assert.Equal(t, "file:///test/missing-fields.txt", content["uri"], "URI should be filled from request")
|
||||
assert.Equal(t, "text/plain", content["mimeType"], "MimeType should be filled from resource")
|
||||
assert.Equal(t, "Content with missing fields", content["text"], "Text content should match")
|
||||
assert.Equal(t, "Content with missing fields", content[ContentTypeText], "Text content should match")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for resource read response")
|
||||
@@ -3159,7 +3178,7 @@ func TestMethodResourcesSubscribe(t *testing.T) {
|
||||
}
|
||||
|
||||
client := addTestClient(mock.server, "test-client-sub-not-found", true)
|
||||
mock.server.processResourceSubscribe(client, req)
|
||||
mock.server.processResourceSubscribe(context.Background(), client, req)
|
||||
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
@@ -3268,7 +3287,7 @@ func TestToolCallUnmarshalError(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call directly
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Check for error response about invalid JSON
|
||||
select {
|
||||
@@ -3316,7 +3335,7 @@ func TestToolCallWithInvalidParams(t *testing.T) {
|
||||
jsonBody, _ := json.Marshal(req)
|
||||
|
||||
// Create HTTP request
|
||||
r := httptest.NewRequest("POST", "/?session_id=test-client-invalid-json", bytes.NewReader(jsonBody))
|
||||
r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-invalid-json", bytes.NewReader(jsonBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Process through handleRequest
|
||||
|
||||
42
mcp/types.go
42
mcp/types.go
@@ -1,6 +1,7 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
|
||||
@@ -45,19 +46,18 @@ type ListToolsResult struct {
|
||||
|
||||
// Message Content Types
|
||||
|
||||
// roleType represents the sender or recipient of messages in a conversation
|
||||
type roleType string
|
||||
// RoleType represents the sender or recipient of messages in a conversation
|
||||
type RoleType string
|
||||
|
||||
// PromptArgument defines a single argument that can be passed to a prompt
|
||||
type PromptArgument struct {
|
||||
Name string `json:"name"` // Argument name
|
||||
Description string `json:"description,omitempty"` // Human-readable description
|
||||
Required bool `json:"required,omitempty"` // Whether this argument is required
|
||||
Default string `json:"default,omitempty"` // Default value if not provided
|
||||
}
|
||||
|
||||
// PromptHandler is a function that dynamically generates prompt content
|
||||
type PromptHandler func(args map[string]string) ([]PromptMessage, error)
|
||||
type PromptHandler func(ctx context.Context, args map[string]string) ([]PromptMessage, error)
|
||||
|
||||
// Prompt represents an MCP Prompt definition
|
||||
type Prompt struct {
|
||||
@@ -70,31 +70,43 @@ type Prompt struct {
|
||||
|
||||
// PromptMessage represents a message in a conversation
|
||||
type PromptMessage struct {
|
||||
Role roleType `json:"role"` // Message sender role
|
||||
Role RoleType `json:"role"` // Message sender role
|
||||
Content any `json:"content"` // Message content (TextContent, ImageContent, etc.)
|
||||
}
|
||||
|
||||
// TextContent represents text content in a message
|
||||
type TextContent struct {
|
||||
Type string `json:"type"` // Always "text"
|
||||
Text string `json:"text"` // The text content
|
||||
Annotations *Annotations `json:"annotations,omitempty"` // Optional annotations
|
||||
}
|
||||
|
||||
type typedTextContent struct {
|
||||
Type string `json:"type"`
|
||||
TextContent
|
||||
}
|
||||
|
||||
// ImageContent represents image data in a message
|
||||
type ImageContent struct {
|
||||
Type string `json:"type"` // Always "image"
|
||||
Data string `json:"data"` // Base64-encoded image data
|
||||
MimeType string `json:"mimeType"` // MIME type (e.g., "image/png")
|
||||
}
|
||||
|
||||
type typedImageContent struct {
|
||||
Type string `json:"type"`
|
||||
ImageContent
|
||||
}
|
||||
|
||||
// AudioContent represents audio data in a message
|
||||
type AudioContent struct {
|
||||
Type string `json:"type"` // Always "audio"
|
||||
Data string `json:"data"` // Base64-encoded audio data
|
||||
MimeType string `json:"mimeType"` // MIME type (e.g., "audio/mp3")
|
||||
}
|
||||
|
||||
type typedAudioContent struct {
|
||||
Type string `json:"type"`
|
||||
AudioContent
|
||||
}
|
||||
|
||||
// FileContent represents file content
|
||||
type FileContent struct {
|
||||
URI string `json:"uri"` // URI identifying the file
|
||||
@@ -115,16 +127,14 @@ type EmbeddedResource struct {
|
||||
|
||||
// Annotations provides additional metadata for content
|
||||
type Annotations struct {
|
||||
Audience []roleType `json:"audience,omitempty"` // Who should see this content
|
||||
Audience []RoleType `json:"audience,omitempty"` // Who should see this content
|
||||
Priority *float64 `json:"priority,omitempty"` // Optional priority (0-1)
|
||||
}
|
||||
|
||||
// Tool-related Types
|
||||
|
||||
// Tool Definition Types
|
||||
|
||||
// ToolHandler is a function that handles tool calls
|
||||
type ToolHandler func(params map[string]any) (any, error)
|
||||
type ToolHandler func(ctx context.Context, params map[string]any) (any, error)
|
||||
|
||||
// Tool represents a Model Context Protocol Tool definition
|
||||
type Tool struct {
|
||||
@@ -136,7 +146,7 @@ type Tool struct {
|
||||
|
||||
// InputSchema represents tool's input schema in JSON Schema format
|
||||
type InputSchema struct {
|
||||
Type string `json:"type"` // Always "object" for tool inputs
|
||||
Type string `json:"type"`
|
||||
Properties map[string]any `json:"properties"` // Property definitions
|
||||
Required []string `json:"required,omitempty"` // List of required properties
|
||||
}
|
||||
@@ -144,8 +154,8 @@ type InputSchema struct {
|
||||
// CallToolResult represents a tool call result that conforms to the MCP schema
|
||||
type CallToolResult struct {
|
||||
Result
|
||||
Content []interface{} `json:"content"` // Content items (text, images, etc.)
|
||||
IsError bool `json:"isError,omitempty"` // True if tool execution failed
|
||||
Content []any `json:"content"` // Content items (text, images, etc.)
|
||||
IsError bool `json:"isError,omitempty"` // True if tool execution failed
|
||||
}
|
||||
|
||||
// Resource represents a Model Context Protocol Resource definition
|
||||
@@ -158,7 +168,7 @@ type Resource struct {
|
||||
}
|
||||
|
||||
// ResourceHandler is a function that handles resource read requests
|
||||
type ResourceHandler func() (ResourceContent, error)
|
||||
type ResourceHandler func(ctx context.Context) (ResourceContent, error)
|
||||
|
||||
// ResourceContent represents the content of a resource
|
||||
type ResourceContent struct {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
@@ -79,7 +80,7 @@ func TestToolStructs(t *testing.T) {
|
||||
},
|
||||
Required: []string{"input"},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return "result", nil
|
||||
},
|
||||
}
|
||||
@@ -145,44 +146,38 @@ func TestResourceStructs(t *testing.T) {
|
||||
func TestContentTypes(t *testing.T) {
|
||||
// Test TextContent
|
||||
textContent := TextContent{
|
||||
Type: "text",
|
||||
Text: "Sample text",
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
Priority: ptr(1.0),
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(textContent)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"type":"text"`)
|
||||
assert.Contains(t, string(data), `"text":"Sample text"`)
|
||||
assert.Contains(t, string(data), `"audience":["user","assistant"]`)
|
||||
assert.Contains(t, string(data), `"priority":1`)
|
||||
|
||||
// Test ImageContent
|
||||
imageContent := ImageContent{
|
||||
Type: "image",
|
||||
Data: "base64data",
|
||||
MimeType: "image/png",
|
||||
}
|
||||
|
||||
data, err = json.Marshal(imageContent)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"type":"image"`)
|
||||
assert.Contains(t, string(data), `"data":"base64data"`)
|
||||
assert.Contains(t, string(data), `"mimeType":"image/png"`)
|
||||
|
||||
// Test AudioContent
|
||||
audioContent := AudioContent{
|
||||
Type: "audio",
|
||||
Data: "base64audio",
|
||||
MimeType: "audio/mp3",
|
||||
}
|
||||
|
||||
data, err = json.Marshal(audioContent)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"type":"audio"`)
|
||||
assert.Contains(t, string(data), `"data":"base64audio"`)
|
||||
assert.Contains(t, string(data), `"mimeType":"audio/mp3"`)
|
||||
}
|
||||
@@ -197,7 +192,6 @@ func TestCallToolResult(t *testing.T) {
|
||||
},
|
||||
Content: []interface{}{
|
||||
TextContent{
|
||||
Type: "text",
|
||||
Text: "Sample result",
|
||||
},
|
||||
},
|
||||
@@ -207,6 +201,6 @@ func TestCallToolResult(t *testing.T) {
|
||||
data, err := json.Marshal(result)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"_meta":{"progressToken":"token123"}`)
|
||||
assert.Contains(t, string(data), `"content":[{"type":"text","text":"Sample result"}]`)
|
||||
assert.Contains(t, string(data), `"content":[{"text":"Sample result"}]`)
|
||||
assert.NotContains(t, string(data), `"isError":`)
|
||||
}
|
||||
|
||||
104
mcp/util.go
104
mcp/util.go
@@ -1,15 +1,107 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
import "fmt"
|
||||
|
||||
// formatSSEMessage formats a Server-Sent Event message with proper CRLF line endings
|
||||
func formatSSEMessage(event string, data []byte) string {
|
||||
return fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(data))
|
||||
}
|
||||
|
||||
// ptr is a helper function to get a pointer to a value
|
||||
func ptr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
// formatSSEMessage formats a Server-Sent Event message with proper CRLF line endings
|
||||
func formatSSEMessage(event string, data []byte) string {
|
||||
return fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(data))
|
||||
func toTypedContents(contents []any) []any {
|
||||
typedContents := make([]any, len(contents))
|
||||
|
||||
for i, content := range contents {
|
||||
switch v := content.(type) {
|
||||
case TextContent:
|
||||
typedContents[i] = typedTextContent{
|
||||
Type: ContentTypeText,
|
||||
TextContent: v,
|
||||
}
|
||||
case ImageContent:
|
||||
typedContents[i] = typedImageContent{
|
||||
Type: ContentTypeImage,
|
||||
ImageContent: v,
|
||||
}
|
||||
case AudioContent:
|
||||
typedContents[i] = typedAudioContent{
|
||||
Type: ContentTypeAudio,
|
||||
AudioContent: v,
|
||||
}
|
||||
default:
|
||||
typedContents[i] = typedTextContent{
|
||||
Type: ContentTypeText,
|
||||
TextContent: TextContent{
|
||||
Text: fmt.Sprintf("Unknown content type: %T", v),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return typedContents
|
||||
}
|
||||
|
||||
func toTypedPromptMessages(messages []PromptMessage) []PromptMessage {
|
||||
typedMessages := make([]PromptMessage, len(messages))
|
||||
|
||||
for i, msg := range messages {
|
||||
switch v := msg.Content.(type) {
|
||||
case TextContent:
|
||||
typedMessages[i] = PromptMessage{
|
||||
Role: msg.Role,
|
||||
Content: typedTextContent{
|
||||
Type: ContentTypeText,
|
||||
TextContent: v,
|
||||
},
|
||||
}
|
||||
case ImageContent:
|
||||
typedMessages[i] = PromptMessage{
|
||||
Role: msg.Role,
|
||||
Content: typedImageContent{
|
||||
Type: ContentTypeImage,
|
||||
ImageContent: v,
|
||||
},
|
||||
}
|
||||
case AudioContent:
|
||||
typedMessages[i] = PromptMessage{
|
||||
Role: msg.Role,
|
||||
Content: typedAudioContent{
|
||||
Type: ContentTypeAudio,
|
||||
AudioContent: v,
|
||||
},
|
||||
}
|
||||
default:
|
||||
typedMessages[i] = PromptMessage{
|
||||
Role: msg.Role,
|
||||
Content: typedTextContent{
|
||||
Type: ContentTypeText,
|
||||
TextContent: TextContent{
|
||||
Text: fmt.Sprintf("Unknown content type: %T", v),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return typedMessages
|
||||
}
|
||||
|
||||
// validatePromptArguments checks if all required arguments are provided
|
||||
// Returns a list of missing required arguments
|
||||
func validatePromptArguments(prompt Prompt, providedArgs map[string]string) []string {
|
||||
var missingArgs []string
|
||||
|
||||
for _, arg := range prompt.Arguments {
|
||||
if arg.Required {
|
||||
if value, exists := providedArgs[arg.Name]; !exists || len(value) == 0 {
|
||||
missingArgs = append(missingArgs, arg.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return missingArgs
|
||||
}
|
||||
|
||||
253
mcp/util_test.go
253
mcp/util_test.go
@@ -8,29 +8,9 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPtr(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
v interface{}
|
||||
}{
|
||||
{"string", "test"},
|
||||
{"int", 42},
|
||||
{"bool", true},
|
||||
{"float", 3.14},
|
||||
{"struct", struct{ Name string }{"test"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ptr(tt.v)
|
||||
assert.NotNil(t, got, "ptr() should not return nil")
|
||||
assert.Equal(t, tt.v, *got, "dereferenced pointer should equal input value")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type Event struct {
|
||||
Type string
|
||||
Data map[string]any
|
||||
@@ -61,3 +41,234 @@ func parseEvent(input string) (*Event, error) {
|
||||
|
||||
return &evt, nil
|
||||
}
|
||||
|
||||
// TestToTypedPromptMessages tests the toTypedPromptMessages function
|
||||
func TestToTypedPromptMessages(t *testing.T) {
|
||||
// Test with multiple message types in one test
|
||||
t.Run("MixedContentTypes", func(t *testing.T) {
|
||||
// Create test data with different content types
|
||||
messages := []PromptMessage{
|
||||
{
|
||||
Role: RoleUser,
|
||||
Content: TextContent{
|
||||
Text: "Hello, this is a text message",
|
||||
Annotations: &Annotations{
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
Priority: ptr(0.8),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: RoleAssistant,
|
||||
Content: ImageContent{
|
||||
Data: "base64ImageData",
|
||||
MimeType: "image/jpeg",
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: RoleUser,
|
||||
Content: AudioContent{
|
||||
Data: "base64AudioData",
|
||||
MimeType: "audio/mp3",
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "system",
|
||||
Content: "This is a simple string that should be handled as unknown type",
|
||||
},
|
||||
}
|
||||
|
||||
// Call the function
|
||||
result := toTypedPromptMessages(messages)
|
||||
|
||||
// Validate results
|
||||
require.Len(t, result, 4, "Should return the same number of messages")
|
||||
|
||||
// Validate first message (TextContent)
|
||||
msg := result[0]
|
||||
assert.Equal(t, RoleUser, msg.Role, "Role should be preserved")
|
||||
|
||||
// Type assertion using reflection since Content is an interface
|
||||
typed, ok := msg.Content.(typedTextContent)
|
||||
require.True(t, ok, "Should be typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||
assert.Equal(t, "Hello, this is a text message", typed.Text, "Text content should be preserved")
|
||||
require.NotNil(t, typed.Annotations, "Annotations should be preserved")
|
||||
assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved")
|
||||
require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved")
|
||||
assert.Equal(t, 0.8, *typed.Annotations.Priority, "Priority value should be preserved")
|
||||
|
||||
// Validate second message (ImageContent)
|
||||
msg = result[1]
|
||||
assert.Equal(t, RoleAssistant, msg.Role, "Role should be preserved")
|
||||
|
||||
// Type assertion for image content
|
||||
typedImg, ok := msg.Content.(typedImageContent)
|
||||
require.True(t, ok, "Should be typedImageContent")
|
||||
assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image")
|
||||
assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved")
|
||||
assert.Equal(t, "image/jpeg", typedImg.MimeType, "MimeType should be preserved")
|
||||
|
||||
// Validate third message (AudioContent)
|
||||
msg = result[2]
|
||||
assert.Equal(t, RoleUser, msg.Role, "Role should be preserved")
|
||||
|
||||
// Type assertion for audio content
|
||||
typedAudio, ok := msg.Content.(typedAudioContent)
|
||||
require.True(t, ok, "Should be typedAudioContent")
|
||||
assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio")
|
||||
assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved")
|
||||
assert.Equal(t, "audio/mp3", typedAudio.MimeType, "MimeType should be preserved")
|
||||
|
||||
// Validate fourth message (unknown type converted to TextContent)
|
||||
msg = result[3]
|
||||
assert.Equal(t, RoleType("system"), msg.Role, "Role should be preserved")
|
||||
|
||||
// Should be converted to a typedTextContent with error message
|
||||
typedUnknown, ok := msg.Content.(typedTextContent)
|
||||
require.True(t, ok, "Unknown content should be converted to typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text")
|
||||
assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type")
|
||||
assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type")
|
||||
})
|
||||
|
||||
// Test empty input
|
||||
t.Run("EmptyInput", func(t *testing.T) {
|
||||
messages := []PromptMessage{}
|
||||
result := toTypedPromptMessages(messages)
|
||||
assert.Empty(t, result, "Should return empty slice for empty input")
|
||||
})
|
||||
|
||||
// Test with nil annotations
|
||||
t.Run("NilAnnotations", func(t *testing.T) {
|
||||
messages := []PromptMessage{
|
||||
{
|
||||
Role: RoleUser,
|
||||
Content: TextContent{
|
||||
Text: "Text with nil annotations",
|
||||
Annotations: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := toTypedPromptMessages(messages)
|
||||
require.Len(t, result, 1, "Should return one message")
|
||||
|
||||
typed, ok := result[0].Content.(typedTextContent)
|
||||
require.True(t, ok, "Should be typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||
assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved")
|
||||
assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil")
|
||||
})
|
||||
}
|
||||
|
||||
// TestToTypedContents tests the toTypedContents function
|
||||
func TestToTypedContents(t *testing.T) {
|
||||
// Test with multiple content types in one test
|
||||
t.Run("MixedContentTypes", func(t *testing.T) {
|
||||
// Create test data with different content types
|
||||
contents := []any{
|
||||
TextContent{
|
||||
Text: "Hello, this is a text content",
|
||||
Annotations: &Annotations{
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
Priority: ptr(0.7),
|
||||
},
|
||||
},
|
||||
ImageContent{
|
||||
Data: "base64ImageData",
|
||||
MimeType: "image/png",
|
||||
},
|
||||
AudioContent{
|
||||
Data: "base64AudioData",
|
||||
MimeType: "audio/wav",
|
||||
},
|
||||
"This is a simple string that should be handled as unknown type",
|
||||
}
|
||||
|
||||
// Call the function
|
||||
result := toTypedContents(contents)
|
||||
|
||||
// Validate results
|
||||
require.Len(t, result, 4, "Should return the same number of contents")
|
||||
|
||||
// Validate first content (TextContent)
|
||||
typed, ok := result[0].(typedTextContent)
|
||||
require.True(t, ok, "Should be typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||
assert.Equal(t, "Hello, this is a text content", typed.Text, "Text content should be preserved")
|
||||
require.NotNil(t, typed.Annotations, "Annotations should be preserved")
|
||||
assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved")
|
||||
require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved")
|
||||
assert.Equal(t, 0.7, *typed.Annotations.Priority, "Priority value should be preserved")
|
||||
|
||||
// Validate second content (ImageContent)
|
||||
typedImg, ok := result[1].(typedImageContent)
|
||||
require.True(t, ok, "Should be typedImageContent")
|
||||
assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image")
|
||||
assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved")
|
||||
assert.Equal(t, "image/png", typedImg.MimeType, "MimeType should be preserved")
|
||||
|
||||
// Validate third content (AudioContent)
|
||||
typedAudio, ok := result[2].(typedAudioContent)
|
||||
require.True(t, ok, "Should be typedAudioContent")
|
||||
assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio")
|
||||
assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved")
|
||||
assert.Equal(t, "audio/wav", typedAudio.MimeType, "MimeType should be preserved")
|
||||
|
||||
// Validate fourth content (unknown type converted to TextContent)
|
||||
typedUnknown, ok := result[3].(typedTextContent)
|
||||
require.True(t, ok, "Unknown content should be converted to typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text")
|
||||
assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type")
|
||||
assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type")
|
||||
})
|
||||
|
||||
// Test empty input
|
||||
t.Run("EmptyInput", func(t *testing.T) {
|
||||
contents := []any{}
|
||||
result := toTypedContents(contents)
|
||||
assert.Empty(t, result, "Should return empty slice for empty input")
|
||||
})
|
||||
|
||||
// Test with nil annotations
|
||||
t.Run("NilAnnotations", func(t *testing.T) {
|
||||
contents := []any{
|
||||
TextContent{
|
||||
Text: "Text with nil annotations",
|
||||
Annotations: nil,
|
||||
},
|
||||
}
|
||||
|
||||
result := toTypedContents(contents)
|
||||
require.Len(t, result, 1, "Should return one content")
|
||||
|
||||
typed, ok := result[0].(typedTextContent)
|
||||
require.True(t, ok, "Should be typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||
assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved")
|
||||
assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil")
|
||||
})
|
||||
|
||||
// Test with custom struct (should be handled as unknown type)
|
||||
t.Run("CustomStruct", func(t *testing.T) {
|
||||
type CustomContent struct {
|
||||
Data string
|
||||
}
|
||||
|
||||
contents := []any{
|
||||
CustomContent{
|
||||
Data: "custom data",
|
||||
},
|
||||
}
|
||||
|
||||
result := toTypedContents(contents)
|
||||
require.Len(t, result, 1, "Should return one content")
|
||||
|
||||
typed, ok := result[0].(typedTextContent)
|
||||
require.True(t, ok, "Custom struct should be converted to typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||
assert.Contains(t, typed.Text, "Unknown content type:", "Should contain error about unknown type")
|
||||
assert.Contains(t, typed.Text, "CustomContent", "Should mention the actual type")
|
||||
})
|
||||
}
|
||||
|
||||
28
mcp/vars.go
28
mcp/vars.go
@@ -13,6 +13,9 @@ const (
|
||||
|
||||
// Session identifier key used in request URLs
|
||||
sessionIdKey = "session_id"
|
||||
|
||||
// progressTokenKey is used to track progress of long-running tasks
|
||||
progressTokenKey = "progressToken"
|
||||
)
|
||||
|
||||
// Server-Sent Events (SSE) event types
|
||||
@@ -26,11 +29,20 @@ const (
|
||||
|
||||
// Content type identifiers
|
||||
const (
|
||||
// Text content type
|
||||
contentTypeText = "text"
|
||||
// ContentTypeObject is object content type
|
||||
ContentTypeObject = "object"
|
||||
|
||||
// Image content type
|
||||
contentTypeImage = "image"
|
||||
// ContentTypeText is text content type
|
||||
ContentTypeText = "text"
|
||||
|
||||
// ContentTypeImage is image content type
|
||||
ContentTypeImage = "image"
|
||||
|
||||
// ContentTypeAudio is audio content type
|
||||
ContentTypeAudio = "audio"
|
||||
|
||||
// ContentTypeResource is resource content type
|
||||
ContentTypeResource = "resource"
|
||||
)
|
||||
|
||||
// Collection keys for broadcast events
|
||||
@@ -72,11 +84,11 @@ const (
|
||||
|
||||
// User and assistant role definitions
|
||||
const (
|
||||
// The "user" role - the entity asking questions
|
||||
roleUser roleType = "user"
|
||||
// RoleUser is the "user" role - the entity asking questions
|
||||
RoleUser RoleType = "user"
|
||||
|
||||
// The "assistant" role - the entity providing responses
|
||||
roleAssistant roleType = "assistant"
|
||||
// RoleAssistant is the "assistant" role - the entity providing responses
|
||||
RoleAssistant RoleType = "assistant"
|
||||
)
|
||||
|
||||
// Method names as defined in the MCP specification
|
||||
|
||||
@@ -146,13 +146,9 @@ func TestCollectionKeys(t *testing.T) {
|
||||
|
||||
// TestRoleTypes checks that role types are used correctly
|
||||
func TestRoleTypes(t *testing.T) {
|
||||
// Verify role type constants
|
||||
assert.Equal(t, "user", string(roleUser), "User role should be 'user'")
|
||||
assert.Equal(t, "assistant", string(roleAssistant), "Assistant role should be 'assistant'")
|
||||
|
||||
// Test in annotations
|
||||
annotations := Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
}
|
||||
data, err := json.Marshal(annotations)
|
||||
assert.NoError(t, err)
|
||||
|
||||
Reference in New Issue
Block a user