mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-07 15:10:01 +08:00
feat: mcp server sdk (#4794)
Signed-off-by: kevin <wanjunfeng@gmail.com>
This commit is contained in:
40
mcp/config.go
Normal file
40
mcp/config.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/rest"
|
||||
)
|
||||
|
||||
// McpConf defines the configuration for an MCP server.
|
||||
// It embeds rest.RestConf for HTTP server settings
|
||||
// and adds MCP-specific configuration options.
|
||||
type McpConf struct {
|
||||
rest.RestConf
|
||||
Mcp struct {
|
||||
// Name is the server name reported in initialize responses
|
||||
Name string `json:",optional"`
|
||||
|
||||
// Version is the server version reported in initialize responses
|
||||
Version string `json:",default=1.0.0"`
|
||||
|
||||
// ProtocolVersion is the MCP protocol version implemented
|
||||
ProtocolVersion string `json:",default=2024-11-05"`
|
||||
|
||||
// BaseUrl is the base URL for the server, used in SSE endpoint messages
|
||||
// If not set, defaults to http://localhost:{Port}
|
||||
BaseUrl string `json:",optional"`
|
||||
|
||||
// SseEndpoint is the path for Server-Sent Events connections
|
||||
SseEndpoint string `json:",default=/sse"`
|
||||
|
||||
// MessageEndpoint is the path for JSON-RPC requests
|
||||
MessageEndpoint string `json:",default=/message"`
|
||||
|
||||
// Cors contains allowed CORS origins
|
||||
Cors []string `json:",optional"`
|
||||
|
||||
// ToolTimeout is the maximum time allowed for tool execution
|
||||
ToolTimeout time.Duration `json:",default=30s"`
|
||||
}
|
||||
}
|
||||
63
mcp/config_test.go
Normal file
63
mcp/config_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/conf"
|
||||
)
|
||||
|
||||
func TestMcpConfDefaults(t *testing.T) {
|
||||
// Test default values are set correctly when unmarshalled from JSON
|
||||
jsonConfig := `name: test-service
|
||||
port: 8080
|
||||
mcp:
|
||||
name: test-mcp-server
|
||||
version: 1.0.0
|
||||
`
|
||||
|
||||
var c McpConf
|
||||
err := conf.LoadFromYamlBytes([]byte(jsonConfig), &c)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check default values
|
||||
assert.Equal(t, "test-mcp-server", c.Mcp.Name)
|
||||
assert.Equal(t, "1.0.0", c.Mcp.Version, "Default version should be 1.0.0")
|
||||
assert.Equal(t, "2024-11-05", c.Mcp.ProtocolVersion, "Default protocol version should be 2024-11-05")
|
||||
assert.Equal(t, "/sse", c.Mcp.SseEndpoint, "Default SSE endpoint should be /sse")
|
||||
assert.Equal(t, "/message", c.Mcp.MessageEndpoint, "Default message endpoint should be /message")
|
||||
assert.Equal(t, 30*time.Second, c.Mcp.ToolTimeout, "Default tool timeout should be 30s")
|
||||
}
|
||||
|
||||
func TestMcpConfCustomValues(t *testing.T) {
|
||||
// Test custom values can be set
|
||||
jsonConfig := `{
|
||||
"Name": "test-service",
|
||||
"Port": 8080,
|
||||
"Mcp": {
|
||||
"Name": "test-mcp-server",
|
||||
"Version": "2.0.0",
|
||||
"ProtocolVersion": "2025-01-01",
|
||||
"BaseUrl": "http://example.com",
|
||||
"SseEndpoint": "/custom-sse",
|
||||
"MessageEndpoint": "/custom-message",
|
||||
"Cors": ["http://localhost:3000", "http://example.com"],
|
||||
"ToolTimeout": "60s"
|
||||
}
|
||||
}`
|
||||
|
||||
var c McpConf
|
||||
err := conf.LoadFromJsonBytes([]byte(jsonConfig), &c)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check custom values
|
||||
assert.Equal(t, "test-mcp-server", c.Mcp.Name, "Name should be inherited from RestConf")
|
||||
assert.Equal(t, "2.0.0", c.Mcp.Version, "Version should be customizable")
|
||||
assert.Equal(t, "2025-01-01", c.Mcp.ProtocolVersion, "Protocol version should be customizable")
|
||||
assert.Equal(t, "http://example.com", c.Mcp.BaseUrl, "BaseUrl should be customizable")
|
||||
assert.Equal(t, "/custom-sse", c.Mcp.SseEndpoint, "SSE endpoint should be customizable")
|
||||
assert.Equal(t, "/custom-message", c.Mcp.MessageEndpoint, "Message endpoint should be customizable")
|
||||
assert.Equal(t, []string{"http://localhost:3000", "http://example.com"}, c.Mcp.Cors, "CORS settings should be customizable")
|
||||
assert.Equal(t, 60*time.Second, c.Mcp.ToolTimeout, "Tool timeout should be customizable")
|
||||
}
|
||||
443
mcp/integration_test.go
Normal file
443
mcp/integration_test.go
Normal file
@@ -0,0 +1,443 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// syncResponseRecorder is a thread-safe wrapper around httptest.ResponseRecorder
|
||||
type syncResponseRecorder struct {
|
||||
*httptest.ResponseRecorder
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Create a new synchronized response recorder
|
||||
func newSyncResponseRecorder() *syncResponseRecorder {
|
||||
return &syncResponseRecorder{
|
||||
ResponseRecorder: httptest.NewRecorder(),
|
||||
}
|
||||
}
|
||||
|
||||
// Override Write method to synchronize access
|
||||
func (srr *syncResponseRecorder) Write(p []byte) (int, error) {
|
||||
srr.mu.Lock()
|
||||
defer srr.mu.Unlock()
|
||||
return srr.ResponseRecorder.Write(p)
|
||||
}
|
||||
|
||||
// Override WriteHeader method to synchronize access
|
||||
func (srr *syncResponseRecorder) WriteHeader(statusCode int) {
|
||||
srr.mu.Lock()
|
||||
defer srr.mu.Unlock()
|
||||
srr.ResponseRecorder.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
// Override Result method to synchronize access
|
||||
func (srr *syncResponseRecorder) Result() *http.Response {
|
||||
srr.mu.Lock()
|
||||
defer srr.mu.Unlock()
|
||||
return srr.ResponseRecorder.Result()
|
||||
}
|
||||
|
||||
// TestHTTPHandlerIntegration tests the HTTP handlers with a real server instance
|
||||
func TestHTTPHandlerIntegration(t *testing.T) {
|
||||
// Skip in short test mode
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
// Create a test configuration
|
||||
conf := McpConf{}
|
||||
conf.Mcp.Name = "test-integration"
|
||||
conf.Mcp.Version = "1.0.0-test"
|
||||
conf.Mcp.ToolTimeout = 1 * time.Second
|
||||
|
||||
// Create a mock server directly
|
||||
server := &sseMcpServer{
|
||||
conf: conf,
|
||||
clients: make(map[string]*mcpClient),
|
||||
tools: make(map[string]Tool),
|
||||
prompts: make(map[string]Prompt),
|
||||
resources: make(map[string]Resource),
|
||||
}
|
||||
|
||||
// Register a test tool
|
||||
err := server.RegisterTool(Tool{
|
||||
Name: "echo",
|
||||
Description: "Echo tool for testing",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"message": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Message to echo",
|
||||
},
|
||||
},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
if msg, ok := params["message"].(string); ok {
|
||||
return fmt.Sprintf("Echo: %s", msg), nil
|
||||
}
|
||||
return "Echo: no message provided", nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a test HTTP request to the SSE endpoint
|
||||
req := httptest.NewRequest("GET", "/sse", nil)
|
||||
w := newSyncResponseRecorder()
|
||||
|
||||
// Create a done channel to signal completion of test
|
||||
done := make(chan bool)
|
||||
|
||||
// Start the SSE handler in a goroutine
|
||||
go func() {
|
||||
// lock.Lock()
|
||||
server.handleSSE(w, req)
|
||||
// lock.Unlock()
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Allow time for the handler to process
|
||||
select {
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// Expected - handler would normally block indefinitely
|
||||
case <-done:
|
||||
// This shouldn't happen immediately - the handler should block
|
||||
t.Error("SSE handler returned unexpectedly")
|
||||
}
|
||||
|
||||
// Check the initial headers
|
||||
resp := w.Result()
|
||||
assert.Equal(t, "chunked", resp.Header.Get("Transfer-Encoding"))
|
||||
resp.Body.Close()
|
||||
|
||||
// The handler creates a client and sends the endpoint message
|
||||
var sessionId string
|
||||
|
||||
// Give the handler time to set up the client
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Check that a client was created
|
||||
server.clientsLock.Lock()
|
||||
assert.Equal(t, 1, len(server.clients))
|
||||
for id := range server.clients {
|
||||
sessionId = id
|
||||
}
|
||||
server.clientsLock.Unlock()
|
||||
|
||||
require.NotEmpty(t, sessionId, "Expected a session ID to be created")
|
||||
|
||||
// Now that we have a session ID, we can test the message endpoint
|
||||
messageBody, _ := json.Marshal(Request{
|
||||
JsonRpc: "2.0",
|
||||
ID: 1,
|
||||
Method: methodInitialize,
|
||||
Params: json.RawMessage(`{}`),
|
||||
})
|
||||
|
||||
// Create a message request
|
||||
reqURL := fmt.Sprintf("/message?%s=%s", sessionIdKey, sessionId)
|
||||
msgReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(messageBody))
|
||||
msgW := newSyncResponseRecorder()
|
||||
|
||||
// Process the message
|
||||
server.handleRequest(msgW, msgReq)
|
||||
|
||||
// Check the response
|
||||
msgResp := msgW.Result()
|
||||
assert.Equal(t, http.StatusAccepted, msgResp.StatusCode)
|
||||
msgResp.Body.Close() // Ensure response body is closed
|
||||
}
|
||||
|
||||
// TestHandlerResponseFlow tests the flow of a full request/response cycle
|
||||
func TestHandlerResponseFlow(t *testing.T) {
|
||||
// Create a mock server for testing
|
||||
server := &sseMcpServer{
|
||||
conf: McpConf{},
|
||||
clients: map[string]*mcpClient{
|
||||
"test-session": {
|
||||
id: "test-session",
|
||||
channel: make(chan string, 10),
|
||||
initialized: true,
|
||||
},
|
||||
},
|
||||
tools: make(map[string]Tool),
|
||||
prompts: make(map[string]Prompt),
|
||||
resources: make(map[string]Resource),
|
||||
}
|
||||
|
||||
// Register test resources
|
||||
server.RegisterTool(Tool{
|
||||
Name: "test.tool",
|
||||
Description: "Test tool",
|
||||
InputSchema: InputSchema{Type: "object"},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
return "tool result", nil
|
||||
},
|
||||
})
|
||||
|
||||
server.RegisterPrompt(Prompt{
|
||||
Name: "test.prompt",
|
||||
Description: "Test prompt",
|
||||
})
|
||||
|
||||
server.RegisterResource(Resource{
|
||||
Name: "test.resource",
|
||||
URI: "http://example.com",
|
||||
Description: "Test resource",
|
||||
})
|
||||
|
||||
// Create a request with session ID parameter
|
||||
reqURL := fmt.Sprintf("/message?%s=%s", sessionIdKey, "test-session")
|
||||
|
||||
// Test tools/list request
|
||||
toolsListBody, _ := json.Marshal(Request{
|
||||
JsonRpc: "2.0",
|
||||
ID: 1,
|
||||
Method: methodToolsList,
|
||||
Params: json.RawMessage(`{}`),
|
||||
})
|
||||
|
||||
toolsReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(toolsListBody))
|
||||
toolsW := newSyncResponseRecorder()
|
||||
|
||||
// Process the request
|
||||
server.handleRequest(toolsW, toolsReq)
|
||||
|
||||
// Check the response code
|
||||
toolsResp := toolsW.Result()
|
||||
assert.Equal(t, http.StatusAccepted, toolsResp.StatusCode)
|
||||
toolsResp.Body.Close()
|
||||
|
||||
// Check the channel message
|
||||
client := server.clients["test-session"]
|
||||
select {
|
||||
case message := <-client.channel:
|
||||
assert.Contains(t, message, `"tools":[{"name":"test.tool"`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for tools/list response")
|
||||
}
|
||||
|
||||
// Test prompts/list request
|
||||
promptsListBody, _ := json.Marshal(Request{
|
||||
JsonRpc: "2.0",
|
||||
ID: 2,
|
||||
Method: methodPromptsList,
|
||||
Params: json.RawMessage(`{}`),
|
||||
})
|
||||
|
||||
promptsReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(promptsListBody))
|
||||
promptsW := newSyncResponseRecorder()
|
||||
|
||||
// Process the request
|
||||
server.handleRequest(promptsW, promptsReq)
|
||||
|
||||
// Check the response code
|
||||
promptsResp := promptsW.Result()
|
||||
assert.Equal(t, http.StatusAccepted, promptsResp.StatusCode)
|
||||
promptsResp.Body.Close()
|
||||
|
||||
// Check the channel message
|
||||
select {
|
||||
case message := <-client.channel:
|
||||
assert.Contains(t, message, `"prompts":[{"name":"test.prompt"`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for prompts/list response")
|
||||
}
|
||||
|
||||
// Test resources/list request
|
||||
resourcesListBody, _ := json.Marshal(Request{
|
||||
JsonRpc: "2.0",
|
||||
ID: 3,
|
||||
Method: methodResourcesList,
|
||||
Params: json.RawMessage(`{}`),
|
||||
})
|
||||
|
||||
resourcesReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(resourcesListBody))
|
||||
resourcesW := newSyncResponseRecorder()
|
||||
|
||||
// Process the request
|
||||
server.handleRequest(resourcesW, resourcesReq)
|
||||
|
||||
// Check the response code
|
||||
resourcesResp := resourcesW.Result()
|
||||
assert.Equal(t, http.StatusAccepted, resourcesResp.StatusCode)
|
||||
resourcesResp.Body.Close()
|
||||
|
||||
// Check the channel message
|
||||
select {
|
||||
case message := <-client.channel:
|
||||
assert.Contains(t, message, `"name":"test.resource"`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for resources/list response")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessListMethods tests the list processing methods with pagination
|
||||
func TestProcessListMethods(t *testing.T) {
|
||||
server := &sseMcpServer{
|
||||
tools: make(map[string]Tool),
|
||||
prompts: make(map[string]Prompt),
|
||||
resources: make(map[string]Resource),
|
||||
}
|
||||
|
||||
// Add some test data
|
||||
for i := 1; i <= 5; i++ {
|
||||
tool := Tool{
|
||||
Name: fmt.Sprintf("tool%d", i),
|
||||
Description: fmt.Sprintf("Tool %d", i),
|
||||
InputSchema: InputSchema{Type: "object"},
|
||||
}
|
||||
server.tools[tool.Name] = tool
|
||||
|
||||
prompt := Prompt{
|
||||
Name: fmt.Sprintf("prompt%d", i),
|
||||
Description: fmt.Sprintf("Prompt %d", i),
|
||||
}
|
||||
server.prompts[prompt.Name] = prompt
|
||||
|
||||
resource := Resource{
|
||||
Name: fmt.Sprintf("resource%d", i),
|
||||
URI: fmt.Sprintf("http://example.com/%d", i),
|
||||
Description: fmt.Sprintf("Resource %d", i),
|
||||
}
|
||||
server.resources[resource.Name] = resource
|
||||
}
|
||||
|
||||
// Create a test client
|
||||
client := &mcpClient{
|
||||
id: "test-client",
|
||||
channel: make(chan string, 10),
|
||||
initialized: true,
|
||||
}
|
||||
|
||||
// Test processListTools
|
||||
req := Request{
|
||||
JsonRpc: "2.0",
|
||||
ID: 1,
|
||||
Method: methodToolsList,
|
||||
Params: json.RawMessage(`{"cursor": "", "_meta": {"progressToken": "token1"}}`),
|
||||
}
|
||||
|
||||
server.processListTools(client, req)
|
||||
|
||||
// Read response
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
assert.Contains(t, response, `"tools":`)
|
||||
assert.Contains(t, response, `"progressToken":"token1"`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for tools/list response")
|
||||
}
|
||||
|
||||
// Test processListPrompts
|
||||
req.ID = 2
|
||||
req.Method = methodPromptsList
|
||||
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
||||
server.processListPrompts(client, req)
|
||||
|
||||
// Read response
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
assert.Contains(t, response, `"prompts":`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for prompts/list response")
|
||||
}
|
||||
|
||||
// Test processListResources
|
||||
req.ID = 3
|
||||
req.Method = methodResourcesList
|
||||
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
||||
server.processListResources(client, req)
|
||||
|
||||
// Read response
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
assert.Contains(t, response, `"resources":`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for resources/list response")
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorResponseHandling tests error handling in the server
|
||||
func TestErrorResponseHandling(t *testing.T) {
|
||||
server := &sseMcpServer{
|
||||
tools: make(map[string]Tool),
|
||||
prompts: make(map[string]Prompt),
|
||||
resources: make(map[string]Resource),
|
||||
}
|
||||
|
||||
// Create a test client
|
||||
client := &mcpClient{
|
||||
id: "test-client",
|
||||
channel: make(chan string, 10),
|
||||
initialized: true,
|
||||
}
|
||||
|
||||
// Test invalid method
|
||||
req := Request{
|
||||
JsonRpc: "2.0",
|
||||
ID: 1,
|
||||
Method: "invalid_method",
|
||||
Params: json.RawMessage(`{}`),
|
||||
}
|
||||
|
||||
// Mock handleRequest by directly calling error handler
|
||||
server.sendErrorResponse(client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||
|
||||
// Check response
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
assert.Contains(t, response, `"error":{"code":-32601,"message":"Method not found"}`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for error response")
|
||||
}
|
||||
|
||||
// Test invalid tool
|
||||
toolReq := Request{
|
||||
JsonRpc: "2.0",
|
||||
ID: 2,
|
||||
Method: methodToolsCall,
|
||||
Params: json.RawMessage(`{"name":"non_existent_tool"}`),
|
||||
}
|
||||
|
||||
// Call process method directly
|
||||
server.processToolCall(client, toolReq)
|
||||
|
||||
// Check response
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
assert.Contains(t, response, `"error":{"code":-32602,"message":"Tool 'non_existent_tool' not found"}`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for error response")
|
||||
}
|
||||
|
||||
// Test invalid prompt
|
||||
promptReq := Request{
|
||||
JsonRpc: "2.0",
|
||||
ID: 3,
|
||||
Method: methodPromptsGet,
|
||||
Params: json.RawMessage(`{"name":"non_existent_prompt"}`),
|
||||
}
|
||||
|
||||
// Call process method directly
|
||||
server.processGetPrompt(client, promptReq)
|
||||
|
||||
// Check response
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
assert.Contains(t, response, `"error":{"code":-32602,"message":"Prompt 'non_existent_prompt' not found"}`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for error response")
|
||||
}
|
||||
}
|
||||
62
mcp/readme.md
Normal file
62
mcp/readme.md
Normal file
@@ -0,0 +1,62 @@
|
||||
# Model Context Protocol (MCP) SDK Implementation
|
||||
|
||||
## Overview
|
||||
This package implements a Model Context Protocol (MCP) server in Go that facilitates real-time communication between AI models and clients using Server-Sent Events (SSE). The implementation provides a framework for building AI-assisted applications with bidirectional communication capabilities.
|
||||
|
||||
## Core Components
|
||||
|
||||
### Server-Sent Events (SSE) Communication
|
||||
- **Real-time Communication**: Robust SSE-based communication system that maintains persistent connections with clients
|
||||
- **Connection Management**: Client registration, message broadcasting, and client cleanup mechanisms
|
||||
- **Event Handling**: Event types for tools, prompts, and resources changes
|
||||
|
||||
### JSON-RPC Implementation
|
||||
- **Request Processing**: Complete JSON-RPC request processor for handling MCP protocol methods
|
||||
- **Response Formatting**: Proper response formatting according to JSON-RPC specifications
|
||||
- **Error Handling**: Comprehensive error handling with appropriate error codes
|
||||
|
||||
### Tool Management
|
||||
- **Tool Registration**: System to register custom tools with handlers
|
||||
- **Tool Execution**: Mechanism to execute tool functions with proper timeout handling
|
||||
- **Result Handling**: Flexible result handling supporting various return types (string, JSON, images)
|
||||
|
||||
### Prompt System
|
||||
- **Prompt Registration**: System for registering both static and dynamic prompts
|
||||
- **Argument Validation**: Validation for required arguments and default values for optional ones
|
||||
- **Message Generation**: Handlers that generate properly formatted conversation messages
|
||||
|
||||
### Resource Management
|
||||
- **Resource Registration**: System for managing and accessing external resources
|
||||
- **Content Delivery**: Handlers for delivering resource content to clients on demand
|
||||
- **Resource Subscription**: Mechanisms for clients to subscribe to resource updates
|
||||
|
||||
### Protocol Features
|
||||
- **Initialization Sequence**: Proper handshaking with capability negotiation
|
||||
- **Notification Handling**: Support for both standard and client-specific notifications
|
||||
- **Message Routing**: Intelligent routing of requests to appropriate handlers
|
||||
|
||||
## Technical Highlights
|
||||
|
||||
### Configuration System
|
||||
- **Flexible Configuration**: Configuration system with sensible defaults and customization options
|
||||
- **CORS Support**: Configurable CORS settings for cross-origin requests
|
||||
- **Server Information**: Proper server identification and versioning
|
||||
|
||||
### Client Session Management
|
||||
- **Session Tracking**: Client session tracking with unique identifiers
|
||||
- **Connection Health**: Ping/pong mechanism to maintain connection health
|
||||
- **Initialization State**: Client initialization state tracking
|
||||
|
||||
### Content Handling
|
||||
- **Multi-format Content**: Support for text, code, and binary content
|
||||
- **MIME Type Support**: Proper MIME type identification for various content types
|
||||
- **Audience Annotations**: Content audience annotations for user/assistant targeting
|
||||
|
||||
## Usage
|
||||
|
||||
To create and use an MCP server, see the examples directory for practical implementation examples including:
|
||||
- Tool registration and execution
|
||||
- Static and dynamic prompt creation
|
||||
- Resource handling with proper URI identification
|
||||
- Embedded resources in prompt responses
|
||||
- Client connection management
|
||||
974
mcp/server.go
Normal file
974
mcp/server.go
Normal file
@@ -0,0 +1,974 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/rest"
|
||||
)
|
||||
|
||||
func NewMcpServer(c McpConf) McpServer {
|
||||
var server *rest.Server
|
||||
if len(c.Mcp.Cors) == 0 {
|
||||
server = rest.MustNewServer(c.RestConf)
|
||||
} else {
|
||||
server = rest.MustNewServer(c.RestConf, rest.WithCors(c.Mcp.Cors...))
|
||||
}
|
||||
|
||||
if len(c.Mcp.Name) == 0 {
|
||||
c.Mcp.Name = c.Name
|
||||
}
|
||||
if len(c.Mcp.BaseUrl) == 0 {
|
||||
c.Mcp.BaseUrl = fmt.Sprintf("http://localhost:%d", c.Port)
|
||||
}
|
||||
|
||||
s := &sseMcpServer{
|
||||
conf: c,
|
||||
server: server,
|
||||
clients: make(map[string]*mcpClient),
|
||||
tools: make(map[string]Tool),
|
||||
prompts: make(map[string]Prompt),
|
||||
resources: make(map[string]Resource),
|
||||
}
|
||||
|
||||
// SSE endpoint for real-time updates
|
||||
s.server.AddRoute(rest.Route{
|
||||
Method: http.MethodGet,
|
||||
Path: s.conf.Mcp.SseEndpoint,
|
||||
Handler: s.handleSSE,
|
||||
}, rest.WithSSE())
|
||||
|
||||
// JSON-RPC message endpoint for regular requests
|
||||
s.server.AddRoute(rest.Route{
|
||||
Method: http.MethodPost,
|
||||
Path: s.conf.Mcp.MessageEndpoint,
|
||||
Handler: s.handleRequest,
|
||||
})
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// RegisterPrompt registers a new prompt with the server
|
||||
func (s *sseMcpServer) RegisterPrompt(prompt Prompt) {
|
||||
s.promptsLock.Lock()
|
||||
s.prompts[prompt.Name] = prompt
|
||||
s.promptsLock.Unlock()
|
||||
// Notify clients about the new prompt
|
||||
s.broadcast(eventPromptsListChanged, map[string][]Prompt{keyPrompts: {prompt}})
|
||||
}
|
||||
|
||||
// RegisterResource registers a new resource with the server
|
||||
func (s *sseMcpServer) RegisterResource(resource Resource) {
|
||||
s.resourcesLock.Lock()
|
||||
s.resources[resource.URI] = resource
|
||||
s.resourcesLock.Unlock()
|
||||
// Notify clients about the new resource
|
||||
s.broadcast(eventResourcesListChanged, map[string][]Resource{keyResources: {resource}})
|
||||
}
|
||||
|
||||
// RegisterTool registers a new tool with the server
|
||||
func (s *sseMcpServer) RegisterTool(tool Tool) error {
|
||||
if tool.Handler == nil {
|
||||
return fmt.Errorf("tool '%s' has no handler function", tool.Name)
|
||||
}
|
||||
|
||||
s.toolsLock.Lock()
|
||||
s.tools[tool.Name] = tool
|
||||
s.toolsLock.Unlock()
|
||||
// Notify clients about the new tool
|
||||
s.broadcast(eventToolsListChanged, map[string][]Tool{keyTools: {tool}})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start implements McpServer.
|
||||
func (s *sseMcpServer) Start() {
|
||||
s.server.Start()
|
||||
}
|
||||
|
||||
func (s *sseMcpServer) Stop() {
|
||||
s.server.Stop()
|
||||
}
|
||||
|
||||
// broadcast sends a message to all connected clients
|
||||
// It uses Server-Sent Events (SSE) format for real-time communication
|
||||
func (s *sseMcpServer) broadcast(event string, data any) {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
logx.Errorf("Failed to marshal broadcast data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Lock only while reading the clients map
|
||||
s.clientsLock.Lock()
|
||||
clients := make([]*mcpClient, 0, len(s.clients))
|
||||
for _, client := range s.clients {
|
||||
clients = append(clients, client)
|
||||
}
|
||||
s.clientsLock.Unlock()
|
||||
|
||||
clientCount := len(clients)
|
||||
if clientCount == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
logx.Infof("Broadcasting event '%s' to %d clients", event, clientCount)
|
||||
|
||||
// Use CRLF line endings as per SSE specification
|
||||
message := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(jsonData))
|
||||
|
||||
// Send messages without holding the lock
|
||||
for _, client := range clients {
|
||||
select {
|
||||
case client.channel <- message:
|
||||
// Message sent successfully
|
||||
default:
|
||||
// Channel buffer is full, log warning and continue
|
||||
logx.Errorf("Client channel buffer full, dropping message for client %s", client.id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupClient removes a client from the active clients map
|
||||
func (s *sseMcpServer) cleanupClient(sessionId string) {
|
||||
s.clientsLock.Lock()
|
||||
defer s.clientsLock.Unlock()
|
||||
|
||||
if client, exists := s.clients[sessionId]; exists {
|
||||
// Close the channel to signal any goroutines waiting on it
|
||||
close(client.channel)
|
||||
// Remove from active clients
|
||||
delete(s.clients, sessionId)
|
||||
logx.Infof("Cleaned up client %s (remaining clients: %d)", sessionId, len(s.clients))
|
||||
}
|
||||
}
|
||||
|
||||
// handleRequest handles MCP JSON-RPC requests
|
||||
func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
// Extract sessionId from query parameters
|
||||
sessionId := r.URL.Query().Get(sessionIdKey)
|
||||
if len(sessionId) == 0 {
|
||||
http.Error(w, fmt.Sprintf("Missing %s", sessionIdKey), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the client with this sessionId exists
|
||||
s.clientsLock.Lock()
|
||||
client, exists := s.clients[sessionId]
|
||||
s.clientsLock.Unlock()
|
||||
|
||||
if !exists {
|
||||
http.Error(w, fmt.Sprintf("Invalid or expired %s", sessionIdKey), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var req Request
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
|
||||
// For notification methods (no ID), we don't send a response
|
||||
isNotification := req.ID == 0
|
||||
|
||||
// Special handling for initialization sequence
|
||||
// Always allow initialize and notifications/initialized regardless of client state
|
||||
if req.Method == methodInitialize {
|
||||
logx.Infof("Processing initialize request with ID: %d", req.ID)
|
||||
s.processInitialize(client, req)
|
||||
logx.Infof("Sent initialize response for ID: %d, waiting for notifications/initialized", req.ID)
|
||||
return
|
||||
} else if req.Method == methodNotificationsInitialized {
|
||||
// Handle initialized notification
|
||||
logx.Info("Received notifications/initialized notification")
|
||||
if !isNotification {
|
||||
s.sendErrorResponse(client, req.ID, "Method should be used as a notification", errCodeInvalidRequest)
|
||||
return
|
||||
}
|
||||
s.processNotificationInitialized(client)
|
||||
return
|
||||
} else if !client.initialized && req.Method != methodNotificationsCancelled {
|
||||
// Block most requests until client is initialized (except for cancellations)
|
||||
s.sendErrorResponse(client, req.ID, "Client not fully initialized, waiting for notifications/initialized",
|
||||
errCodeClientNotInitialized)
|
||||
return
|
||||
}
|
||||
|
||||
// Process normal requests only after initialization
|
||||
switch req.Method {
|
||||
case methodToolsCall:
|
||||
logx.Infof("Received tools call request with ID: %d", req.ID)
|
||||
s.processToolCall(client, req)
|
||||
logx.Infof("Sent tools call response for ID: %d", req.ID)
|
||||
case methodToolsList:
|
||||
logx.Infof("Processing tools/list request with ID: %d", req.ID)
|
||||
s.processListTools(client, req)
|
||||
logx.Infof("Sent tools/list response for ID: %d", req.ID)
|
||||
case methodPromptsList:
|
||||
logx.Infof("Processing prompts/list request with ID: %d", req.ID)
|
||||
s.processListPrompts(client, req)
|
||||
logx.Infof("Sent prompts/list response for ID: %d", req.ID)
|
||||
case methodPromptsGet:
|
||||
logx.Infof("Processing prompts/get request with ID: %d", req.ID)
|
||||
s.processGetPrompt(client, req)
|
||||
logx.Infof("Sent prompts/get response for ID: %d", req.ID)
|
||||
case methodResourcesList:
|
||||
logx.Infof("Processing resources/list request with ID: %d", req.ID)
|
||||
s.processListResources(client, req)
|
||||
logx.Infof("Sent resources/list response for ID: %d", req.ID)
|
||||
case methodResourcesRead:
|
||||
logx.Infof("Processing resources/read request with ID: %d", req.ID)
|
||||
s.processResourcesRead(client, req)
|
||||
logx.Infof("Sent resources/read response for ID: %d", req.ID)
|
||||
case methodResourcesSubscribe:
|
||||
logx.Infof("Processing resources/subscribe request with ID: %d", req.ID)
|
||||
s.processResourceSubscribe(client, req)
|
||||
logx.Infof("Sent resources/subscribe response for ID: %d", req.ID)
|
||||
case methodPing:
|
||||
logx.Infof("Processing ping request with ID: %d", req.ID)
|
||||
s.processPing(client, req)
|
||||
case methodNotificationsCancelled:
|
||||
logx.Infof("Received notifications/cancelled notification: %v", req.Params)
|
||||
s.processNotificationCancelled(client, req)
|
||||
default:
|
||||
logx.Infof("Unknown method: %s", req.Method)
|
||||
s.sendErrorResponse(client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
// handleSSE handles Server-Sent Events connections
|
||||
func (s *sseMcpServer) handleSSE(w http.ResponseWriter, r *http.Request) {
|
||||
// Generate a unique session ID for this client
|
||||
sessionId := uuid.New().String()
|
||||
|
||||
// Create new client with buffered channel to prevent blocking
|
||||
client := &mcpClient{
|
||||
id: sessionId,
|
||||
channel: make(chan string, eventChanSize),
|
||||
}
|
||||
|
||||
// Add client to active clients map
|
||||
s.clientsLock.Lock()
|
||||
s.clients[sessionId] = client
|
||||
activeClients := len(s.clients)
|
||||
s.clientsLock.Unlock()
|
||||
|
||||
logx.Infof("New SSE connection established for client %s (active clients: %d)",
|
||||
sessionId, activeClients)
|
||||
|
||||
// Set proper SSE headers
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
|
||||
// Enable streaming
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
logx.Error("Streaming not supported by the underlying http.ResponseWriter")
|
||||
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Send the message endpoint URL to the client
|
||||
endpoint := fmt.Sprintf("%s%s?%s=%s",
|
||||
s.conf.Mcp.BaseUrl, s.conf.Mcp.MessageEndpoint, sessionIdKey, sessionId)
|
||||
|
||||
// Format and send the endpoint message
|
||||
endpointMsg := formatSSEMessage(eventEndpoint, []byte(endpoint))
|
||||
if _, err := fmt.Fprint(w, endpointMsg); err != nil {
|
||||
logx.Errorf("Failed to send endpoint message to client %s: %v", sessionId, err)
|
||||
s.cleanupClient(sessionId)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
// Set up keep-alive ping and client cleanup
|
||||
ticker := time.NewTicker(pingInterval.Load())
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
s.cleanupClient(sessionId)
|
||||
logx.Infof("SSE connection closed for client %s", sessionId)
|
||||
}()
|
||||
|
||||
// Message processing loop
|
||||
for {
|
||||
select {
|
||||
case message, ok := <-client.channel:
|
||||
if !ok {
|
||||
// Channel was closed, end connection
|
||||
logx.Infof("Client channel was closed for %s", sessionId)
|
||||
return
|
||||
}
|
||||
|
||||
// Write message to the response
|
||||
if _, err := fmt.Fprint(w, message); err != nil {
|
||||
logx.Infof("Failed to write message to client %s: %v", sessionId, err)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
case <-ticker.C:
|
||||
// Send keep-alive ping to maintain connection
|
||||
ping := fmt.Sprintf(`{"type":"ping","timestamp":"%s"}`, time.Now().String())
|
||||
pingMsg := formatSSEMessage("ping", []byte(ping))
|
||||
if _, err := fmt.Fprint(w, pingMsg); err != nil {
|
||||
logx.Errorf("Failed to send ping to client %s, closing connection: %v", sessionId, err)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
case <-r.Context().Done():
|
||||
// Client disconnected or request was canceled
|
||||
logx.Infof("Client %s disconnected: context done", sessionId)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processInitialize processes the initialize request
|
||||
func (s *sseMcpServer) processInitialize(client *mcpClient, req Request) {
|
||||
// Create a proper JSON-RPC response that preserves the client's request ID
|
||||
result := initializationResponse{
|
||||
ProtocolVersion: s.conf.Mcp.ProtocolVersion,
|
||||
Capabilities: capabilities{
|
||||
Prompts: struct {
|
||||
ListChanged bool `json:"listChanged"`
|
||||
}{
|
||||
ListChanged: true,
|
||||
},
|
||||
Resources: struct {
|
||||
Subscribe bool `json:"subscribe"`
|
||||
ListChanged bool `json:"listChanged"`
|
||||
}{
|
||||
Subscribe: true,
|
||||
ListChanged: true,
|
||||
},
|
||||
Tools: struct {
|
||||
ListChanged bool `json:"listChanged"`
|
||||
}{
|
||||
ListChanged: true,
|
||||
},
|
||||
},
|
||||
ServerInfo: serverInfo{
|
||||
Name: s.conf.Mcp.Name,
|
||||
Version: s.conf.Mcp.Version,
|
||||
},
|
||||
}
|
||||
|
||||
// Mark client as initialized
|
||||
client.initialized = true
|
||||
|
||||
// Send response with client's original request ID
|
||||
s.sendResponse(client, req.ID, result)
|
||||
}
|
||||
|
||||
// processListTools processes the tools/list request
|
||||
func (s *sseMcpServer) processListTools(client *mcpClient, req Request) {
|
||||
// Extract pagination params if any
|
||||
var nextCursor string
|
||||
var progressToken any
|
||||
|
||||
// Extract meta data including progress token
|
||||
if req.Params != nil {
|
||||
var metaParams struct {
|
||||
Cursor string `json:"cursor"`
|
||||
Meta struct {
|
||||
ProgressToken any `json:"progressToken"`
|
||||
} `json:"_meta"`
|
||||
}
|
||||
if err := json.Unmarshal(req.Params, &metaParams); err == nil {
|
||||
if len(metaParams.Cursor) > 0 {
|
||||
nextCursor = metaParams.Cursor
|
||||
}
|
||||
progressToken = metaParams.Meta.ProgressToken
|
||||
}
|
||||
}
|
||||
|
||||
var toolsList []Tool
|
||||
s.toolsLock.Lock()
|
||||
for _, tool := range s.tools {
|
||||
toolsList = append(toolsList, tool)
|
||||
}
|
||||
s.toolsLock.Unlock()
|
||||
|
||||
result := ListToolsResult{
|
||||
PaginatedResult: PaginatedResult{
|
||||
Result: Result{},
|
||||
NextCursor: Cursor(nextCursor),
|
||||
},
|
||||
Tools: toolsList,
|
||||
}
|
||||
|
||||
// Add meta information if progress token was provided
|
||||
if progressToken != nil {
|
||||
result.Result.Meta = map[string]any{
|
||||
"progressToken": progressToken,
|
||||
}
|
||||
}
|
||||
|
||||
s.sendResponse(client, req.ID, result)
|
||||
}
|
||||
|
||||
// processListPrompts processes the prompts/list request
|
||||
func (s *sseMcpServer) processListPrompts(client *mcpClient, req Request) {
|
||||
// Extract pagination params if any
|
||||
var nextCursor string
|
||||
if req.Params != nil {
|
||||
var cursorParams struct {
|
||||
Cursor string `json:"cursor"`
|
||||
}
|
||||
if err := json.Unmarshal(req.Params, &cursorParams); err == nil && cursorParams.Cursor != "" {
|
||||
// If we have a valid cursor, we could use it for pagination
|
||||
// For now, we're not actually implementing pagination, so this is just
|
||||
// to show how it would be extracted from the request
|
||||
_ = cursorParams.Cursor
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare prompt list
|
||||
var promptsList []Prompt
|
||||
s.promptsLock.Lock()
|
||||
for _, prompt := range s.prompts {
|
||||
promptsList = append(promptsList, prompt)
|
||||
}
|
||||
s.promptsLock.Unlock()
|
||||
|
||||
// In a real implementation, you'd handle pagination here
|
||||
// For now, we'll return all prompts at once
|
||||
result := struct {
|
||||
Prompts []Prompt `json:"prompts"`
|
||||
NextCursor string `json:"nextCursor,omitempty"`
|
||||
Meta *struct{} `json:"_meta,omitempty"`
|
||||
}{
|
||||
Prompts: promptsList,
|
||||
NextCursor: nextCursor,
|
||||
}
|
||||
|
||||
s.sendResponse(client, req.ID, result)
|
||||
}
|
||||
|
||||
// processListResources processes the resources/list request
|
||||
func (s *sseMcpServer) processListResources(client *mcpClient, req Request) {
|
||||
// Extract pagination params if any
|
||||
var nextCursor string
|
||||
var progressToken any
|
||||
|
||||
// Extract meta information including progress token if available
|
||||
if req.Params != nil {
|
||||
var metaParams PaginatedParams
|
||||
if err := json.Unmarshal(req.Params, &metaParams); err == nil {
|
||||
if len(metaParams.Cursor) > 0 {
|
||||
nextCursor = metaParams.Cursor
|
||||
}
|
||||
progressToken = metaParams.Meta.ProgressToken
|
||||
}
|
||||
}
|
||||
|
||||
var resourcesList []Resource
|
||||
s.resourcesLock.Lock()
|
||||
for _, resource := range s.resources {
|
||||
// Create a copy without the handler function which shouldn't be sent to clients
|
||||
resourceCopy := Resource{
|
||||
URI: resource.URI,
|
||||
Name: resource.Name,
|
||||
Description: resource.Description,
|
||||
MimeType: resource.MimeType,
|
||||
}
|
||||
resourcesList = append(resourcesList, resourceCopy)
|
||||
}
|
||||
s.resourcesLock.Unlock()
|
||||
|
||||
// Create proper ResourcesListResult according to MCP specification
|
||||
result := ResourcesListResult{
|
||||
PaginatedResult: PaginatedResult{
|
||||
Result: Result{},
|
||||
NextCursor: Cursor(nextCursor),
|
||||
},
|
||||
Resources: resourcesList,
|
||||
}
|
||||
|
||||
// Add meta information if progress token was provided
|
||||
if progressToken != nil {
|
||||
result.Result.Meta = map[string]any{
|
||||
"progressToken": progressToken,
|
||||
}
|
||||
}
|
||||
|
||||
s.sendResponse(client, req.ID, result)
|
||||
}
|
||||
|
||||
// processGetPrompt processes the prompts/get request
|
||||
func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
||||
type GetPromptParams struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]string `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
var params GetPromptParams
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if prompt exists
|
||||
s.promptsLock.Lock()
|
||||
prompt, exists := s.prompts[params.Name]
|
||||
s.promptsLock.Unlock()
|
||||
if !exists {
|
||||
message := fmt.Sprintf("Prompt '%s' not found", params.Name)
|
||||
s.sendErrorResponse(client, req.ID, message, errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
logx.Infof("Processing prompt request: %s with %d arguments", prompt.Name, len(params.Arguments))
|
||||
|
||||
// Validate required arguments
|
||||
missingArgs := validatePromptArguments(prompt, params.Arguments)
|
||||
if len(missingArgs) > 0 {
|
||||
message := fmt.Sprintf("Missing required arguments: %s", strings.Join(missingArgs, ", "))
|
||||
s.sendErrorResponse(client, req.ID, message, errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
// Apply default values for missing optional arguments
|
||||
args := applyDefaultArguments(prompt, params.Arguments)
|
||||
|
||||
// Generate messages using handler or static content
|
||||
var messages []PromptMessage
|
||||
var err error
|
||||
|
||||
if prompt.Handler != nil {
|
||||
// Use dynamic handler to generate messages
|
||||
logx.Info("Using prompt handler to generate content")
|
||||
messages, err = prompt.Handler(args)
|
||||
if err != nil {
|
||||
logx.Errorf("Error from prompt handler: %v", err)
|
||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Error generating prompt content: %v", err), errCodeInternalError)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// No handler, generate messages from static content
|
||||
var messageText string
|
||||
if prompt.Content != "" {
|
||||
messageText = prompt.Content
|
||||
|
||||
// Apply argument substitutions to static content
|
||||
for key, value := range args {
|
||||
placeholder := fmt.Sprintf("{{%s}}", key)
|
||||
messageText = strings.Replace(messageText, placeholder, value, -1)
|
||||
}
|
||||
} else {
|
||||
// No content, use a default fallback
|
||||
topic := "this topic"
|
||||
if t, ok := args["topic"]; ok && t != "" {
|
||||
topic = t
|
||||
}
|
||||
messageText = fmt.Sprintf("Tell me about %s", topic)
|
||||
}
|
||||
|
||||
// Create a single user message with the content
|
||||
messages = []PromptMessage{
|
||||
{
|
||||
Role: roleUser,
|
||||
Content: TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: messageText,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Construct the response according to MCP spec
|
||||
result := struct {
|
||||
Description string `json:"description,omitempty"`
|
||||
Messages []PromptMessage `json:"messages"`
|
||||
}{
|
||||
Description: prompt.Description,
|
||||
Messages: messages,
|
||||
}
|
||||
|
||||
s.sendResponse(client, req.ID, result)
|
||||
}
|
||||
|
||||
// validatePromptArguments checks if all required arguments are provided
|
||||
// Returns a list of missing required arguments
|
||||
func validatePromptArguments(prompt Prompt, providedArgs map[string]string) []string {
|
||||
var missingArgs []string
|
||||
|
||||
for _, arg := range prompt.Arguments {
|
||||
if arg.Required {
|
||||
if value, exists := providedArgs[arg.Name]; !exists || len(value) == 0 {
|
||||
missingArgs = append(missingArgs, arg.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return missingArgs
|
||||
}
|
||||
|
||||
// applyDefaultArguments adds default values for missing optional arguments
|
||||
func applyDefaultArguments(prompt Prompt, providedArgs map[string]string) map[string]string {
|
||||
result := make(map[string]string)
|
||||
|
||||
// Copy all provided arguments
|
||||
for k, v := range providedArgs {
|
||||
result[k] = v
|
||||
}
|
||||
|
||||
// Add defaults for missing arguments
|
||||
for _, arg := range prompt.Arguments {
|
||||
if _, exists := result[arg.Name]; !exists && arg.Default != "" {
|
||||
result[arg.Name] = arg.Default
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// processToolCall processes the tools/call request
|
||||
func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
var toolCallParams struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Meta struct {
|
||||
ProgressToken any `json:"progressToken"`
|
||||
} `json:"_meta,omitempty"`
|
||||
}
|
||||
|
||||
// Handle different types of req.Params
|
||||
// If it's a RawMessage (JSON), unmarshal it
|
||||
if err := json.Unmarshal(req.Params, &toolCallParams); err != nil {
|
||||
logx.Errorf("Failed to unmarshal tool call params: %v", err)
|
||||
s.sendErrorResponse(client, req.ID, "Invalid tool call parameters", errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract progress token if available
|
||||
progressToken := toolCallParams.Meta.ProgressToken
|
||||
|
||||
// Find the requested tool
|
||||
s.toolsLock.Lock()
|
||||
tool, exists := s.tools[toolCallParams.Name]
|
||||
s.toolsLock.Unlock()
|
||||
if !exists {
|
||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Tool '%s' not found",
|
||||
toolCallParams.Name), errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
// Create a context with the configured timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), s.conf.Mcp.ToolTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Log parameters before execution
|
||||
logx.Infof("Executing tool '%s' with arguments: %#v", toolCallParams.Name, toolCallParams.Arguments)
|
||||
|
||||
// Execute the tool handler with timeout handling
|
||||
var result any
|
||||
var err error
|
||||
|
||||
// Create a channel to receive the result
|
||||
resultCh := make(chan struct {
|
||||
result any
|
||||
err error
|
||||
}, 1)
|
||||
|
||||
// Execute the tool handler in a goroutine
|
||||
go func() {
|
||||
toolResult, toolErr := tool.Handler(toolCallParams.Arguments)
|
||||
resultCh <- struct {
|
||||
result any
|
||||
err error
|
||||
}{
|
||||
result: toolResult,
|
||||
err: toolErr,
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for either the result or a timeout
|
||||
select {
|
||||
case res := <-resultCh:
|
||||
result = res.result
|
||||
err = res.err
|
||||
case <-ctx.Done():
|
||||
// Handle timeout
|
||||
logx.Errorf("Tool execution timed out after %v: %s", s.conf.Mcp.ToolTimeout, toolCallParams.Name)
|
||||
s.sendErrorResponse(client, req.ID, "Tool execution timed out", errCodeTimeout)
|
||||
return
|
||||
}
|
||||
|
||||
// Create the base result structure with metadata
|
||||
callToolResult := CallToolResult{
|
||||
Result: Result{},
|
||||
Content: []any{},
|
||||
IsError: false,
|
||||
}
|
||||
|
||||
// Add meta information if progress token was provided
|
||||
if progressToken != nil {
|
||||
callToolResult.Result.Meta = map[string]any{
|
||||
"progressToken": progressToken,
|
||||
}
|
||||
}
|
||||
|
||||
// Check if there was an error during tool execution
|
||||
if err != nil {
|
||||
// According to the spec, for tool-level errors (as opposed to protocol-level errors),
|
||||
// we should report them inside the result with isError=true
|
||||
logx.Errorf("Tool execution reported error: %v", err)
|
||||
|
||||
callToolResult.Content = []any{
|
||||
TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: fmt.Sprintf("Error: %v", err),
|
||||
},
|
||||
}
|
||||
callToolResult.IsError = true
|
||||
s.sendResponse(client, req.ID, callToolResult)
|
||||
return
|
||||
}
|
||||
|
||||
// Format the response according to the CallToolResult schema
|
||||
switch v := result.(type) {
|
||||
case string:
|
||||
// Simple string becomes text content
|
||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: v,
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
},
|
||||
})
|
||||
case map[string]any:
|
||||
// JSON-like object becomes formatted JSON text
|
||||
jsonStr, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
jsonStr = []byte(err.Error())
|
||||
}
|
||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: string(jsonStr),
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
},
|
||||
})
|
||||
case TextContent:
|
||||
// Direct TextContent object
|
||||
callToolResult.Content = append(callToolResult.Content, v)
|
||||
case ImageContent:
|
||||
// Direct ImageContent object
|
||||
callToolResult.Content = append(callToolResult.Content, v)
|
||||
case []any:
|
||||
// Array of content items
|
||||
callToolResult.Content = v
|
||||
case ToolResult:
|
||||
// Handle legacy ToolResult type
|
||||
switch v.Type {
|
||||
case contentTypeText:
|
||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: fmt.Sprintf("%v", v.Content),
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
},
|
||||
})
|
||||
case contentTypeImage:
|
||||
if imgData, ok := v.Content.(map[string]any); ok {
|
||||
callToolResult.Content = append(callToolResult.Content, ImageContent{
|
||||
Type: contentTypeImage,
|
||||
Data: fmt.Sprintf("%v", imgData["data"]),
|
||||
MimeType: fmt.Sprintf("%v", imgData["mimeType"]),
|
||||
})
|
||||
}
|
||||
default:
|
||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: fmt.Sprintf("%v", v.Content),
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
},
|
||||
})
|
||||
}
|
||||
default:
|
||||
// For any other type, convert to string
|
||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: fmt.Sprintf("%v", v),
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
logx.Infof("Tool call result: %#v", callToolResult)
|
||||
s.sendResponse(client, req.ID, callToolResult)
|
||||
}
|
||||
|
||||
// processResourcesRead processes the resources/read request
|
||||
func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
|
||||
var params ResourceReadParams
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
// Find resource that matches the URI
|
||||
s.resourcesLock.Lock()
|
||||
resource, exists := s.resources[params.URI]
|
||||
s.resourcesLock.Unlock()
|
||||
|
||||
if !exists {
|
||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||
params.URI), errCodeResourceNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// If no handler is provided, return an empty content array
|
||||
if resource.Handler == nil {
|
||||
result := ResourceReadResult{
|
||||
Contents: []ResourceContent{
|
||||
{
|
||||
URI: params.URI,
|
||||
MimeType: resource.MimeType,
|
||||
Text: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
s.sendResponse(client, req.ID, result)
|
||||
return
|
||||
}
|
||||
|
||||
// Execute the resource handler
|
||||
content, err := resource.Handler()
|
||||
if err != nil {
|
||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Error reading resource: %v", err),
|
||||
errCodeInternalError)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure the URI is set if not already provided by the handler
|
||||
if len(content.URI) == 0 {
|
||||
content.URI = params.URI
|
||||
}
|
||||
|
||||
// Ensure MimeType is set if available from the resource definition
|
||||
if len(content.MimeType) == 0 && resource.MimeType != "" {
|
||||
content.MimeType = resource.MimeType
|
||||
}
|
||||
|
||||
// Create response with contents from the handler
|
||||
// The MCP specification requires a contents array
|
||||
result := ResourceReadResult{
|
||||
Contents: []ResourceContent{content},
|
||||
}
|
||||
|
||||
s.sendResponse(client, req.ID, result)
|
||||
}
|
||||
|
||||
// processResourceSubscribe processes the resources/subscribe request
|
||||
func (s *sseMcpServer) processResourceSubscribe(client *mcpClient, req Request) {
|
||||
var params ResourceSubscribeParams
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the resource exists
|
||||
s.resourcesLock.Lock()
|
||||
_, exists := s.resources[params.URI]
|
||||
s.resourcesLock.Unlock()
|
||||
|
||||
if !exists {
|
||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||
params.URI), errCodeResourceNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Send success response for the subscription
|
||||
s.sendResponse(client, req.ID, struct{}{})
|
||||
|
||||
logx.Infof("Client %s subscribed to resource '%s'", client.id, params.URI)
|
||||
}
|
||||
|
||||
// processNotificationCancelled processes the notifications/cancelled notification
|
||||
func (s *sseMcpServer) processNotificationCancelled(client *mcpClient, req Request) {
|
||||
// Extract the requestId that was canceled
|
||||
type CancelParams struct {
|
||||
RequestId int64 `json:"requestId"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
var params CancelParams
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||
logx.Errorf("Failed to parse cancellation params: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logx.Infof("Request %d was cancelled by client. Reason: %s", params.RequestId, params.Reason)
|
||||
}
|
||||
|
||||
// processNotificationInitialized processes the notifications/initialized notification
|
||||
func (s *sseMcpServer) processNotificationInitialized(client *mcpClient) {
|
||||
// Mark the client as properly initialized
|
||||
client.initialized = true
|
||||
logx.Infof("Client %s is now fully initialized and ready for normal operations", client.id)
|
||||
}
|
||||
|
||||
// processPing processes the ping request and responds immediately
|
||||
func (s *sseMcpServer) processPing(client *mcpClient, req Request) {
|
||||
// A ping request should simply respond with an empty result to confirm the server is alive
|
||||
logx.Infof("Received ping request with ID: %d", req.ID)
|
||||
|
||||
// Send an empty response with client's original request ID
|
||||
s.sendResponse(client, req.ID, struct{}{})
|
||||
|
||||
logx.Infof("Sent ping response for ID: %d", req.ID)
|
||||
}
|
||||
|
||||
// sendErrorResponse sends an error response via the SSE channel
|
||||
func (s *sseMcpServer) sendErrorResponse(client *mcpClient, id int64, message string, code int) {
|
||||
errorResponse := struct {
|
||||
JsonRpc string `json:"jsonrpc"`
|
||||
ID int64 `json:"id"`
|
||||
Error errorMessage `json:"error"`
|
||||
}{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: id,
|
||||
Error: errorMessage{
|
||||
Code: code,
|
||||
Message: message,
|
||||
},
|
||||
}
|
||||
|
||||
// all fields are primitive types, impossible to fail
|
||||
jsonData, _ := json.Marshal(errorResponse)
|
||||
// Use CRLF line endings as requested
|
||||
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||
logx.Infof("Sending error for ID %d: %s", id, sseMessage)
|
||||
|
||||
client.channel <- sseMessage
|
||||
}
|
||||
|
||||
// sendResponse sends a success response via the SSE channel
|
||||
func (s *sseMcpServer) sendResponse(client *mcpClient, id int64, result any) {
|
||||
response := Response{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: id,
|
||||
Result: result,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
s.sendErrorResponse(client, id, "Failed to marshal response", errCodeInternalError)
|
||||
return
|
||||
}
|
||||
|
||||
// Use CRLF line endings as requested
|
||||
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||
logx.Infof("Sending response for ID %d: %s", id, sseMessage)
|
||||
|
||||
client.channel <- sseMessage
|
||||
}
|
||||
3418
mcp/server_test.go
Normal file
3418
mcp/server_test.go
Normal file
File diff suppressed because it is too large
Load Diff
294
mcp/types.go
Normal file
294
mcp/types.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
|
||||
"github.com/zeromicro/go-zero/rest"
|
||||
)
|
||||
|
||||
// Cursor is an opaque token used for pagination
|
||||
type Cursor string
|
||||
|
||||
// Request represents a generic MCP request following JSON-RPC 2.0 specification
|
||||
type Request struct {
|
||||
SessionId string `form:"session_id"` // Session identifier for client tracking
|
||||
JsonRpc string `json:"jsonrpc"` // Must be "2.0" per JSON-RPC spec
|
||||
ID int64 `json:"id"` // Request identifier for matching responses
|
||||
Method string `json:"method"` // Method name to invoke
|
||||
Params json.RawMessage `json:"params"` // Parameters for the method
|
||||
}
|
||||
|
||||
type PaginatedParams struct {
|
||||
Cursor string `json:"cursor"`
|
||||
Meta struct {
|
||||
ProgressToken any `json:"progressToken"`
|
||||
} `json:"_meta"`
|
||||
}
|
||||
|
||||
// Result is the base interface for all results
|
||||
type Result struct {
|
||||
Meta map[string]any `json:"_meta,omitempty"` // Optional metadata
|
||||
}
|
||||
|
||||
// PaginatedResult is a base for results that support pagination
|
||||
type PaginatedResult struct {
|
||||
Result
|
||||
NextCursor Cursor `json:"nextCursor,omitempty"` // Opaque token for fetching next page
|
||||
}
|
||||
|
||||
// ListToolsResult represents the response to a tools/list request
|
||||
type ListToolsResult struct {
|
||||
PaginatedResult
|
||||
Tools []Tool `json:"tools"` // List of available tools
|
||||
}
|
||||
|
||||
// Message Content Types
|
||||
|
||||
// roleType represents the sender or recipient of messages in a conversation
|
||||
type roleType string
|
||||
|
||||
// PromptArgument defines a single argument that can be passed to a prompt
|
||||
type PromptArgument struct {
|
||||
Name string `json:"name"` // Argument name
|
||||
Description string `json:"description,omitempty"` // Human-readable description
|
||||
Required bool `json:"required,omitempty"` // Whether this argument is required
|
||||
Default string `json:"default,omitempty"` // Default value if not provided
|
||||
}
|
||||
|
||||
// PromptHandler is a function that dynamically generates prompt content
|
||||
type PromptHandler func(args map[string]string) ([]PromptMessage, error)
|
||||
|
||||
// Prompt represents an MCP Prompt definition
|
||||
type Prompt struct {
|
||||
Name string `json:"name"` // Unique identifier for the prompt
|
||||
Description string `json:"description,omitempty"` // Human-readable description
|
||||
Arguments []PromptArgument `json:"arguments,omitempty"` // Arguments for customization
|
||||
Content string `json:"-"` // Static content (internal use only)
|
||||
Handler PromptHandler `json:"-"` // Handler for dynamic content generation
|
||||
}
|
||||
|
||||
// PromptMessage represents a message in a conversation
|
||||
type PromptMessage struct {
|
||||
Role roleType `json:"role"` // Message sender role
|
||||
Content any `json:"content"` // Message content (TextContent, ImageContent, etc.)
|
||||
}
|
||||
|
||||
// TextContent represents text content in a message
|
||||
type TextContent struct {
|
||||
Type string `json:"type"` // Always "text"
|
||||
Text string `json:"text"` // The text content
|
||||
Annotations *Annotations `json:"annotations,omitempty"` // Optional annotations
|
||||
}
|
||||
|
||||
// ImageContent represents image data in a message
|
||||
type ImageContent struct {
|
||||
Type string `json:"type"` // Always "image"
|
||||
Data string `json:"data"` // Base64-encoded image data
|
||||
MimeType string `json:"mimeType"` // MIME type (e.g., "image/png")
|
||||
}
|
||||
|
||||
// AudioContent represents audio data in a message
|
||||
type AudioContent struct {
|
||||
Type string `json:"type"` // Always "audio"
|
||||
Data string `json:"data"` // Base64-encoded audio data
|
||||
MimeType string `json:"mimeType"` // MIME type (e.g., "audio/mp3")
|
||||
}
|
||||
|
||||
// FileContent represents file content
|
||||
type FileContent struct {
|
||||
URI string `json:"uri"` // URI identifying the file
|
||||
MimeType string `json:"mimeType"` // MIME type of the file
|
||||
Text string `json:"text"` // File content as text
|
||||
}
|
||||
|
||||
// EmbeddedResource represents a resource embedded in a message
|
||||
type EmbeddedResource struct {
|
||||
Type string `json:"type"` // Always "resource"
|
||||
Resource struct {
|
||||
URI string `json:"uri"` // Resource URI
|
||||
MimeType string `json:"mimeType"` // MIME type of the resource
|
||||
Text string `json:"text,omitempty"` // Text content (if available)
|
||||
Blob string `json:"blob,omitempty"` // Base64 encoded blob data (if available)
|
||||
} `json:"resource"` // The resource data
|
||||
}
|
||||
|
||||
// Annotations provides additional metadata for content
|
||||
type Annotations struct {
|
||||
Audience []roleType `json:"audience,omitempty"` // Who should see this content
|
||||
Priority *float64 `json:"priority,omitempty"` // Optional priority (0-1)
|
||||
}
|
||||
|
||||
// Tool-related Types
|
||||
|
||||
// Tool Definition Types
|
||||
|
||||
// ToolHandler is a function that handles tool calls
|
||||
type ToolHandler func(params map[string]any) (any, error)
|
||||
|
||||
// Tool represents a Model Context Protocol Tool definition
|
||||
type Tool struct {
|
||||
Name string `json:"name"` // Unique identifier for the tool
|
||||
Description string `json:"description"` // Human-readable description
|
||||
InputSchema InputSchema `json:"inputSchema"` // JSON Schema for parameters
|
||||
Handler ToolHandler `json:"-"` // Not sent to clients
|
||||
}
|
||||
|
||||
// InputSchema represents tool's input schema in JSON Schema format
|
||||
type InputSchema struct {
|
||||
Type string `json:"type"` // Always "object" for tool inputs
|
||||
Properties map[string]any `json:"properties"` // Property definitions
|
||||
Required []string `json:"required,omitempty"` // List of required properties
|
||||
}
|
||||
|
||||
// CallToolResult represents a tool call result that conforms to the MCP schema
|
||||
type CallToolResult struct {
|
||||
Result
|
||||
Content []interface{} `json:"content"` // Content items (text, images, etc.)
|
||||
IsError bool `json:"isError,omitempty"` // True if tool execution failed
|
||||
}
|
||||
|
||||
// Resource represents a Model Context Protocol Resource definition
|
||||
type Resource struct {
|
||||
URI string `json:"uri"` // Unique resource identifier (RFC3986)
|
||||
Name string `json:"name"` // Human-readable name
|
||||
Description string `json:"description,omitempty"` // Optional description
|
||||
MimeType string `json:"mimeType,omitempty"` // Optional MIME type
|
||||
Handler ResourceHandler `json:"-"` // Internal handler not sent to clients
|
||||
}
|
||||
|
||||
// ResourceHandler is a function that handles resource read requests
|
||||
type ResourceHandler func() (ResourceContent, error)
|
||||
|
||||
// ResourceContent represents the content of a resource
|
||||
type ResourceContent struct {
|
||||
URI string `json:"uri"` // Resource URI (required)
|
||||
MimeType string `json:"mimeType,omitempty"` // MIME type of the resource
|
||||
Text string `json:"text,omitempty"` // Text content (if available)
|
||||
Blob string `json:"blob,omitempty"` // Base64 encoded blob data (if available)
|
||||
}
|
||||
|
||||
// ResourcesListResult represents the response to a resources/list request
|
||||
type ResourcesListResult struct {
|
||||
PaginatedResult
|
||||
Resources []Resource `json:"resources"` // List of available resources
|
||||
}
|
||||
|
||||
// ResourceReadParams contains parameters for a resources/read request
|
||||
type ResourceReadParams struct {
|
||||
URI string `json:"uri"` // URI of the resource to read
|
||||
}
|
||||
|
||||
// ResourceReadResult contains the result of a resources/read request
|
||||
type ResourceReadResult struct {
|
||||
Result
|
||||
Contents []ResourceContent `json:"contents"` // Array of resource content
|
||||
}
|
||||
|
||||
// ResourceSubscribeParams contains parameters for a resources/subscribe request
|
||||
type ResourceSubscribeParams struct {
|
||||
URI string `json:"uri"` // URI of the resource to subscribe to
|
||||
}
|
||||
|
||||
// ResourceUpdateNotification represents a notification about a resource update
|
||||
type ResourceUpdateNotification struct {
|
||||
URI string `json:"uri"` // URI of the updated resource
|
||||
Content ResourceContent `json:"content"` // New resource content
|
||||
}
|
||||
|
||||
// Client and Server Types
|
||||
|
||||
// mcpClient represents an SSE client connection
|
||||
type mcpClient struct {
|
||||
id string // Unique client identifier
|
||||
channel chan string // Channel for sending SSE messages
|
||||
initialized bool // Tracks if client has sent notifications/initialized
|
||||
}
|
||||
|
||||
// McpServer defines the interface for Model Context Protocol servers
|
||||
type McpServer interface {
|
||||
Start()
|
||||
Stop()
|
||||
RegisterTool(tool Tool) error
|
||||
RegisterPrompt(prompt Prompt)
|
||||
RegisterResource(resource Resource)
|
||||
}
|
||||
|
||||
// sseMcpServer implements the McpServer interface using SSE
|
||||
type sseMcpServer struct {
|
||||
conf McpConf
|
||||
server *rest.Server
|
||||
clients map[string]*mcpClient
|
||||
clientsLock sync.Mutex
|
||||
tools map[string]Tool
|
||||
toolsLock sync.Mutex
|
||||
prompts map[string]Prompt
|
||||
promptsLock sync.Mutex
|
||||
resources map[string]Resource
|
||||
resourcesLock sync.Mutex
|
||||
}
|
||||
|
||||
// Response Types
|
||||
|
||||
// errorObj represents a JSON-RPC error object
|
||||
type errorObj struct {
|
||||
Code int `json:"code"` // Error code
|
||||
Message string `json:"message"` // Error message
|
||||
}
|
||||
|
||||
// Response represents a JSON-RPC response
|
||||
type Response struct {
|
||||
JsonRpc string `json:"jsonrpc"` // Always "2.0"
|
||||
ID int64 `json:"id"` // Same as request ID
|
||||
Result any `json:"result"` // Result object (null if error)
|
||||
Error *errorObj `json:"error,omitempty"` // Error object (null if success)
|
||||
}
|
||||
|
||||
// Server Information Types
|
||||
|
||||
// serverInfo provides information about the server
|
||||
type serverInfo struct {
|
||||
Name string `json:"name"` // Server name
|
||||
Version string `json:"version"` // Server version
|
||||
}
|
||||
|
||||
// capabilities describes the server's capabilities
|
||||
type capabilities struct {
|
||||
Logging struct{} `json:"logging"`
|
||||
Prompts struct {
|
||||
ListChanged bool `json:"listChanged"` // Server will notify on prompt changes
|
||||
} `json:"prompts"`
|
||||
Resources struct {
|
||||
Subscribe bool `json:"subscribe"` // Server supports resource subscriptions
|
||||
ListChanged bool `json:"listChanged"` // Server will notify on resource changes
|
||||
} `json:"resources"`
|
||||
Tools struct {
|
||||
ListChanged bool `json:"listChanged"` // Server will notify on tool changes
|
||||
} `json:"tools"`
|
||||
}
|
||||
|
||||
// initializationResponse is sent in response to an initialize request
|
||||
type initializationResponse struct {
|
||||
ProtocolVersion string `json:"protocolVersion"` // Protocol version
|
||||
Capabilities capabilities `json:"capabilities"` // Server capabilities
|
||||
ServerInfo serverInfo `json:"serverInfo"` // Server information
|
||||
}
|
||||
|
||||
// ToolCallParams contains the parameters for a tool call
|
||||
type ToolCallParams struct {
|
||||
Name string `json:"name"` // Tool name
|
||||
Parameters map[string]any `json:"parameters"` // Tool parameters
|
||||
}
|
||||
|
||||
// ToolResult contains the result of a tool execution
|
||||
type ToolResult struct {
|
||||
Type string `json:"type"` // Content type (text, image, etc.)
|
||||
Content any `json:"content"` // Result content
|
||||
}
|
||||
|
||||
// errorMessage represents a detailed error message
|
||||
type errorMessage struct {
|
||||
Code int `json:"code"` // Error code
|
||||
Message string `json:"message"` // Error message
|
||||
Data any `json:",omitempty"` // Additional error data
|
||||
}
|
||||
212
mcp/types_test.go
Normal file
212
mcp/types_test.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestResponseMarshaling(t *testing.T) {
|
||||
// Test that the Response struct marshals correctly
|
||||
resp := Response{
|
||||
JsonRpc: "2.0",
|
||||
ID: 123,
|
||||
Result: map[string]string{
|
||||
"key": "value",
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(resp)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"jsonrpc":"2.0"`)
|
||||
assert.Contains(t, string(data), `"id":123`)
|
||||
assert.Contains(t, string(data), `"result":{"key":"value"}`)
|
||||
|
||||
// Test response with error
|
||||
respWithError := Response{
|
||||
JsonRpc: "2.0",
|
||||
ID: 456,
|
||||
Error: &errorObj{
|
||||
Code: errCodeInvalidRequest,
|
||||
Message: "Invalid Request",
|
||||
},
|
||||
}
|
||||
|
||||
data, err = json.Marshal(respWithError)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"jsonrpc":"2.0"`)
|
||||
assert.Contains(t, string(data), `"id":456`)
|
||||
assert.Contains(t, string(data), `"error":{"code":-32600,"message":"Invalid Request"}`)
|
||||
}
|
||||
|
||||
func TestRequestUnmarshaling(t *testing.T) {
|
||||
// Test that the Request struct unmarshals correctly
|
||||
jsonStr := `{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 789,
|
||||
"method": "test_method",
|
||||
"params": {"key": "value"}
|
||||
}`
|
||||
|
||||
var req Request
|
||||
err := json.Unmarshal([]byte(jsonStr), &req)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "2.0", req.JsonRpc)
|
||||
assert.Equal(t, int64(789), req.ID)
|
||||
assert.Equal(t, "test_method", req.Method)
|
||||
|
||||
// Check params unmarshaled correctly
|
||||
var params map[string]string
|
||||
err = json.Unmarshal(req.Params, ¶ms)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "value", params["key"])
|
||||
}
|
||||
|
||||
func TestToolStructs(t *testing.T) {
|
||||
// Test Tool struct
|
||||
tool := Tool{
|
||||
Name: "test.tool",
|
||||
Description: "A test tool",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"input": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Input parameter",
|
||||
},
|
||||
},
|
||||
Required: []string{"input"},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
return "result", nil
|
||||
},
|
||||
}
|
||||
|
||||
// Verify fields are correct
|
||||
assert.Equal(t, "test.tool", tool.Name)
|
||||
assert.Equal(t, "A test tool", tool.Description)
|
||||
assert.Equal(t, "object", tool.InputSchema.Type)
|
||||
assert.Contains(t, tool.InputSchema.Properties, "input")
|
||||
propMap, ok := tool.InputSchema.Properties["input"].(map[string]any)
|
||||
assert.True(t, ok, "Property should be a map")
|
||||
assert.Equal(t, "string", propMap["type"])
|
||||
assert.NotNil(t, tool.Handler)
|
||||
|
||||
// Verify JSON marshalling (which should exclude Handler function)
|
||||
data, err := json.Marshal(tool)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"name":"test.tool"`)
|
||||
assert.Contains(t, string(data), `"description":"A test tool"`)
|
||||
assert.Contains(t, string(data), `"inputSchema":`)
|
||||
assert.NotContains(t, string(data), `"Handler":`)
|
||||
}
|
||||
|
||||
func TestPromptStructs(t *testing.T) {
|
||||
// Test Prompt struct
|
||||
prompt := Prompt{
|
||||
Name: "test.prompt",
|
||||
Description: "A test prompt description",
|
||||
}
|
||||
|
||||
// Verify fields are correct
|
||||
assert.Equal(t, "test.prompt", prompt.Name)
|
||||
assert.Equal(t, "A test prompt description", prompt.Description)
|
||||
|
||||
// Verify JSON marshalling
|
||||
data, err := json.Marshal(prompt)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"name":"test.prompt"`)
|
||||
assert.Contains(t, string(data), `"description":"A test prompt description"`)
|
||||
}
|
||||
|
||||
func TestResourceStructs(t *testing.T) {
|
||||
// Test Resource struct
|
||||
resource := Resource{
|
||||
Name: "test.resource",
|
||||
URI: "http://example.com/resource",
|
||||
Description: "A test resource",
|
||||
}
|
||||
|
||||
// Verify fields are correct
|
||||
assert.Equal(t, "test.resource", resource.Name)
|
||||
assert.Equal(t, "http://example.com/resource", resource.URI)
|
||||
assert.Equal(t, "A test resource", resource.Description)
|
||||
|
||||
// Verify JSON marshalling
|
||||
data, err := json.Marshal(resource)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"name":"test.resource"`)
|
||||
assert.Contains(t, string(data), `"uri":"http://example.com/resource"`)
|
||||
assert.Contains(t, string(data), `"description":"A test resource"`)
|
||||
}
|
||||
|
||||
func TestContentTypes(t *testing.T) {
|
||||
// Test TextContent
|
||||
textContent := TextContent{
|
||||
Type: "text",
|
||||
Text: "Sample text",
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
Priority: ptr(1.0),
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(textContent)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"type":"text"`)
|
||||
assert.Contains(t, string(data), `"text":"Sample text"`)
|
||||
assert.Contains(t, string(data), `"audience":["user","assistant"]`)
|
||||
assert.Contains(t, string(data), `"priority":1`)
|
||||
|
||||
// Test ImageContent
|
||||
imageContent := ImageContent{
|
||||
Type: "image",
|
||||
Data: "base64data",
|
||||
MimeType: "image/png",
|
||||
}
|
||||
|
||||
data, err = json.Marshal(imageContent)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"type":"image"`)
|
||||
assert.Contains(t, string(data), `"data":"base64data"`)
|
||||
assert.Contains(t, string(data), `"mimeType":"image/png"`)
|
||||
|
||||
// Test AudioContent
|
||||
audioContent := AudioContent{
|
||||
Type: "audio",
|
||||
Data: "base64audio",
|
||||
MimeType: "audio/mp3",
|
||||
}
|
||||
|
||||
data, err = json.Marshal(audioContent)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"type":"audio"`)
|
||||
assert.Contains(t, string(data), `"data":"base64audio"`)
|
||||
assert.Contains(t, string(data), `"mimeType":"audio/mp3"`)
|
||||
}
|
||||
|
||||
func TestCallToolResult(t *testing.T) {
|
||||
// Test CallToolResult
|
||||
result := CallToolResult{
|
||||
Result: Result{
|
||||
Meta: map[string]any{
|
||||
"progressToken": "token123",
|
||||
},
|
||||
},
|
||||
Content: []interface{}{
|
||||
TextContent{
|
||||
Type: "text",
|
||||
Text: "Sample result",
|
||||
},
|
||||
},
|
||||
IsError: false,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(result)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"_meta":{"progressToken":"token123"}`)
|
||||
assert.Contains(t, string(data), `"content":[{"type":"text","text":"Sample result"}]`)
|
||||
assert.NotContains(t, string(data), `"isError":`)
|
||||
}
|
||||
15
mcp/util.go
Normal file
15
mcp/util.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ptr is a helper function to get a pointer to a value
|
||||
func ptr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
// formatSSEMessage formats a Server-Sent Event message with proper CRLF line endings
|
||||
func formatSSEMessage(event string, data []byte) string {
|
||||
return fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(data))
|
||||
}
|
||||
63
mcp/util_test.go
Normal file
63
mcp/util_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPtr(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
v interface{}
|
||||
}{
|
||||
{"string", "test"},
|
||||
{"int", 42},
|
||||
{"bool", true},
|
||||
{"float", 3.14},
|
||||
{"struct", struct{ Name string }{"test"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ptr(tt.v)
|
||||
assert.NotNil(t, got, "ptr() should not return nil")
|
||||
assert.Equal(t, tt.v, *got, "dereferenced pointer should equal input value")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type Event struct {
|
||||
Type string
|
||||
Data map[string]any
|
||||
}
|
||||
|
||||
func parseEvent(input string) (*Event, error) {
|
||||
var evt Event
|
||||
var dataStr string
|
||||
|
||||
scanner := bufio.NewScanner(strings.NewReader(input))
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "event:") {
|
||||
evt.Type = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
|
||||
} else if strings.HasPrefix(line, "data:") {
|
||||
dataStr = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(dataStr) > 0 {
|
||||
if err := json.Unmarshal([]byte(dataStr), &evt.Data); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse data: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &evt, nil
|
||||
}
|
||||
137
mcp/vars.go
Normal file
137
mcp/vars.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
)
|
||||
|
||||
// Protocol constants
|
||||
const (
|
||||
// JSON-RPC version as defined in the specification
|
||||
jsonRpcVersion = "2.0"
|
||||
|
||||
// Session identifier key used in request URLs
|
||||
sessionIdKey = "session_id"
|
||||
)
|
||||
|
||||
// Server-Sent Events (SSE) event types
|
||||
const (
|
||||
// Standard message event for JSON-RPC responses
|
||||
eventMessage = "message"
|
||||
|
||||
// Endpoint event for sending endpoint URL to clients
|
||||
eventEndpoint = "endpoint"
|
||||
)
|
||||
|
||||
// Content type identifiers
|
||||
const (
|
||||
// Text content type
|
||||
contentTypeText = "text"
|
||||
|
||||
// Image content type
|
||||
contentTypeImage = "image"
|
||||
)
|
||||
|
||||
// Collection keys for broadcast events
|
||||
const (
|
||||
// Key for prompts collection
|
||||
keyPrompts = "prompts"
|
||||
|
||||
// Key for resources collection
|
||||
keyResources = "resources"
|
||||
|
||||
// Key for tools collection
|
||||
keyTools = "tools"
|
||||
)
|
||||
|
||||
// JSON-RPC error codes
|
||||
// Standard error codes from JSON-RPC 2.0 spec
|
||||
const (
|
||||
// Invalid JSON was received by the server
|
||||
errCodeInvalidRequest = -32600
|
||||
|
||||
// The method does not exist / is not available
|
||||
errCodeMethodNotFound = -32601
|
||||
|
||||
// Invalid method parameter(s)
|
||||
errCodeInvalidParams = -32602
|
||||
|
||||
// Internal JSON-RPC error
|
||||
errCodeInternalError = -32603
|
||||
|
||||
// Tool execution timed out
|
||||
errCodeTimeout = -32001
|
||||
|
||||
// Resource not found error
|
||||
errCodeResourceNotFound = -32002
|
||||
|
||||
// Client hasn't completed initialization
|
||||
errCodeClientNotInitialized = -32800
|
||||
)
|
||||
|
||||
// User and assistant role definitions
|
||||
const (
|
||||
// The "user" role - the entity asking questions
|
||||
roleUser roleType = "user"
|
||||
|
||||
// The "assistant" role - the entity providing responses
|
||||
roleAssistant roleType = "assistant"
|
||||
)
|
||||
|
||||
// Method names as defined in the MCP specification
|
||||
const (
|
||||
// Initialize the connection between client and server
|
||||
methodInitialize = "initialize"
|
||||
|
||||
// List available tools
|
||||
methodToolsList = "tools/list"
|
||||
|
||||
// Call a specific tool
|
||||
methodToolsCall = "tools/call"
|
||||
|
||||
// List available prompts
|
||||
methodPromptsList = "prompts/list"
|
||||
|
||||
// Get a specific prompt
|
||||
methodPromptsGet = "prompts/get"
|
||||
|
||||
// List available resources
|
||||
methodResourcesList = "resources/list"
|
||||
|
||||
// Read a specific resource
|
||||
methodResourcesRead = "resources/read"
|
||||
|
||||
// Subscribe to resource updates
|
||||
methodResourcesSubscribe = "resources/subscribe"
|
||||
|
||||
// Simple ping to check server availability
|
||||
methodPing = "ping"
|
||||
|
||||
// Notification that client is fully initialized
|
||||
methodNotificationsInitialized = "notifications/initialized"
|
||||
|
||||
// Notification that a request was canceled
|
||||
methodNotificationsCancelled = "notifications/cancelled"
|
||||
)
|
||||
|
||||
// Event names for Server-Sent Events (SSE)
|
||||
const (
|
||||
// Notification of tool list changes
|
||||
eventToolsListChanged = "tools/list_changed"
|
||||
|
||||
// Notification of prompt list changes
|
||||
eventPromptsListChanged = "prompts/list_changed"
|
||||
|
||||
// Notification of resource list changes
|
||||
eventResourcesListChanged = "resources/list_changed"
|
||||
)
|
||||
|
||||
var (
|
||||
// Default channel size for events
|
||||
eventChanSize = 10
|
||||
|
||||
// Default ping interval for checking connection availability
|
||||
// use syncx.ForAtomicDuration to ensure atomicity in test race
|
||||
pingInterval = syncx.ForAtomicDuration(30 * time.Second)
|
||||
)
|
||||
214
mcp/vars_test.go
Normal file
214
mcp/vars_test.go
Normal file
@@ -0,0 +1,214 @@
|
||||
// filepath: /Users/kevin/Develop/go/opensource/go-zero/mcp/vars_test.go
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestErrorCodes ensures error codes are applied correctly in error responses
|
||||
func TestErrorCodes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
code int
|
||||
message string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "invalid request error",
|
||||
code: errCodeInvalidRequest,
|
||||
message: "Invalid request",
|
||||
expected: `"code":-32600`,
|
||||
},
|
||||
{
|
||||
name: "method not found error",
|
||||
code: errCodeMethodNotFound,
|
||||
message: "Method not found",
|
||||
expected: `"code":-32601`,
|
||||
},
|
||||
{
|
||||
name: "invalid params error",
|
||||
code: errCodeInvalidParams,
|
||||
message: "Invalid parameters",
|
||||
expected: `"code":-32602`,
|
||||
},
|
||||
{
|
||||
name: "internal error",
|
||||
code: errCodeInternalError,
|
||||
message: "Internal server error",
|
||||
expected: `"code":-32603`,
|
||||
},
|
||||
{
|
||||
name: "timeout error",
|
||||
code: errCodeTimeout,
|
||||
message: "Operation timed out",
|
||||
expected: `"code":-32001`,
|
||||
},
|
||||
{
|
||||
name: "resource not found error",
|
||||
code: errCodeResourceNotFound,
|
||||
message: "Resource not found",
|
||||
expected: `"code":-32002`,
|
||||
},
|
||||
{
|
||||
name: "client not initialized error",
|
||||
code: errCodeClientNotInitialized,
|
||||
message: "Client not initialized",
|
||||
expected: `"code":-32800`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp := Response{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: int64(1),
|
||||
Error: &errorObj{
|
||||
Code: tc.code,
|
||||
Message: tc.message,
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(resp)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), tc.expected, "Error code should match expected value")
|
||||
assert.Contains(t, string(data), tc.message, "Error message should be included")
|
||||
assert.Contains(t, string(data), jsonRpcVersion, "JSON-RPC version should be included")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJsonRpcVersion ensures the correct JSON-RPC version is used
|
||||
func TestJsonRpcVersion(t *testing.T) {
|
||||
assert.Equal(t, "2.0", jsonRpcVersion, "JSON-RPC version should be 2.0")
|
||||
|
||||
// Test that it's used in responses
|
||||
resp := Response{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: int64(1),
|
||||
Result: "test",
|
||||
}
|
||||
data, err := json.Marshal(resp)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"jsonrpc":"2.0"`, "Response should use correct JSON-RPC version")
|
||||
|
||||
// Test that it's expected in requests
|
||||
reqStr := `{"jsonrpc":"2.0","id":1,"method":"test"}`
|
||||
var req Request
|
||||
err = json.Unmarshal([]byte(reqStr), &req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, jsonRpcVersion, req.JsonRpc, "Request should parse correct JSON-RPC version")
|
||||
}
|
||||
|
||||
// TestSessionIdKey ensures session ID extraction works correctly
|
||||
func TestSessionIdKey(t *testing.T) {
|
||||
// Create a mock server implementation
|
||||
mock := newMockMcpServer(t)
|
||||
defer mock.shutdown()
|
||||
|
||||
// Verify the key constant
|
||||
assert.Equal(t, "session_id", sessionIdKey, "Session ID key should be 'session_id'")
|
||||
|
||||
// Test that session ID is extracted correctly
|
||||
mockR := httptest.NewRequest("GET", "/?"+sessionIdKey+"=test-session", nil)
|
||||
|
||||
// Since the mock server is using the same session key logic,
|
||||
// we can test this by accessing the request query parameters directly
|
||||
sessionID := mockR.URL.Query().Get(sessionIdKey)
|
||||
assert.Equal(t, "test-session", sessionID, "Session ID should be extracted correctly")
|
||||
}
|
||||
|
||||
// TestEventTypes ensures event types are set correctly in SSE responses
|
||||
func TestEventTypes(t *testing.T) {
|
||||
// Test message event
|
||||
assert.Equal(t, "message", eventMessage, "Message event should be 'message'")
|
||||
|
||||
// Test endpoint event
|
||||
assert.Equal(t, "endpoint", eventEndpoint, "Endpoint event should be 'endpoint'")
|
||||
|
||||
// Verify them in an actual SSE format string
|
||||
messageEvent := "event: " + eventMessage + "\ndata: test\n\n"
|
||||
assert.Contains(t, messageEvent, "event: message", "Message event should format correctly")
|
||||
|
||||
endpointEvent := "event: " + eventEndpoint + "\ndata: test\n\n"
|
||||
assert.Contains(t, endpointEvent, "event: endpoint", "Endpoint event should format correctly")
|
||||
}
|
||||
|
||||
// TestCollectionKeys checks that collection keys are used correctly
|
||||
func TestCollectionKeys(t *testing.T) {
|
||||
// Verify collection key constants
|
||||
assert.Equal(t, "prompts", keyPrompts, "Prompts key should be 'prompts'")
|
||||
assert.Equal(t, "resources", keyResources, "Resources key should be 'resources'")
|
||||
assert.Equal(t, "tools", keyTools, "Tools key should be 'tools'")
|
||||
}
|
||||
|
||||
// TestRoleTypes checks that role types are used correctly
|
||||
func TestRoleTypes(t *testing.T) {
|
||||
// Verify role type constants
|
||||
assert.Equal(t, "user", string(roleUser), "User role should be 'user'")
|
||||
assert.Equal(t, "assistant", string(roleAssistant), "Assistant role should be 'assistant'")
|
||||
|
||||
// Test in annotations
|
||||
annotations := Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
}
|
||||
data, err := json.Marshal(annotations)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"audience":["user","assistant"]`, "Role types should marshal correctly")
|
||||
}
|
||||
|
||||
// TestMethodNames checks that method names are used correctly
|
||||
func TestMethodNames(t *testing.T) {
|
||||
// Verify method name constants
|
||||
methods := map[string]string{
|
||||
"initialize": methodInitialize,
|
||||
"tools/list": methodToolsList,
|
||||
"tools/call": methodToolsCall,
|
||||
"prompts/list": methodPromptsList,
|
||||
"prompts/get": methodPromptsGet,
|
||||
"resources/list": methodResourcesList,
|
||||
"resources/read": methodResourcesRead,
|
||||
"resources/subscribe": methodResourcesSubscribe,
|
||||
"ping": methodPing,
|
||||
"notifications/initialized": methodNotificationsInitialized,
|
||||
"notifications/cancelled": methodNotificationsCancelled,
|
||||
}
|
||||
|
||||
for expected, actual := range methods {
|
||||
assert.Equal(t, expected, actual, "Method name should be "+expected)
|
||||
}
|
||||
|
||||
// Test in a request
|
||||
for methodName := range methods {
|
||||
req := Request{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: int64(1),
|
||||
Method: methodName,
|
||||
}
|
||||
data, err := json.Marshal(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"method":"`+methodName+`"`, "Method name should be used in requests")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEventNames checks that event names are used correctly
|
||||
func TestEventNames(t *testing.T) {
|
||||
// Verify event name constants
|
||||
events := map[string]string{
|
||||
"tools/list_changed": eventToolsListChanged,
|
||||
"prompts/list_changed": eventPromptsListChanged,
|
||||
"resources/list_changed": eventResourcesListChanged,
|
||||
}
|
||||
|
||||
for expected, actual := range events {
|
||||
assert.Equal(t, expected, actual, "Event name should be "+expected)
|
||||
}
|
||||
|
||||
// Test event names in SSE format
|
||||
for _, eventName := range events {
|
||||
sseEvent := "event: " + eventName + "\ndata: test\n\n"
|
||||
assert.Contains(t, sseEvent, "event: "+eventName, "Event name should format correctly in SSE")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user