mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-13 09:50:00 +08:00
feat: mcp server sdk (#4794)
Signed-off-by: kevin <wanjunfeng@gmail.com>
This commit is contained in:
40
mcp/config.go
Normal file
40
mcp/config.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/rest"
|
||||||
|
)
|
||||||
|
|
||||||
|
// McpConf defines the configuration for an MCP server.
|
||||||
|
// It embeds rest.RestConf for HTTP server settings
|
||||||
|
// and adds MCP-specific configuration options.
|
||||||
|
type McpConf struct {
|
||||||
|
rest.RestConf
|
||||||
|
Mcp struct {
|
||||||
|
// Name is the server name reported in initialize responses
|
||||||
|
Name string `json:",optional"`
|
||||||
|
|
||||||
|
// Version is the server version reported in initialize responses
|
||||||
|
Version string `json:",default=1.0.0"`
|
||||||
|
|
||||||
|
// ProtocolVersion is the MCP protocol version implemented
|
||||||
|
ProtocolVersion string `json:",default=2024-11-05"`
|
||||||
|
|
||||||
|
// BaseUrl is the base URL for the server, used in SSE endpoint messages
|
||||||
|
// If not set, defaults to http://localhost:{Port}
|
||||||
|
BaseUrl string `json:",optional"`
|
||||||
|
|
||||||
|
// SseEndpoint is the path for Server-Sent Events connections
|
||||||
|
SseEndpoint string `json:",default=/sse"`
|
||||||
|
|
||||||
|
// MessageEndpoint is the path for JSON-RPC requests
|
||||||
|
MessageEndpoint string `json:",default=/message"`
|
||||||
|
|
||||||
|
// Cors contains allowed CORS origins
|
||||||
|
Cors []string `json:",optional"`
|
||||||
|
|
||||||
|
// ToolTimeout is the maximum time allowed for tool execution
|
||||||
|
ToolTimeout time.Duration `json:",default=30s"`
|
||||||
|
}
|
||||||
|
}
|
||||||
63
mcp/config_test.go
Normal file
63
mcp/config_test.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/conf"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMcpConfDefaults(t *testing.T) {
|
||||||
|
// Test default values are set correctly when unmarshalled from JSON
|
||||||
|
jsonConfig := `name: test-service
|
||||||
|
port: 8080
|
||||||
|
mcp:
|
||||||
|
name: test-mcp-server
|
||||||
|
version: 1.0.0
|
||||||
|
`
|
||||||
|
|
||||||
|
var c McpConf
|
||||||
|
err := conf.LoadFromYamlBytes([]byte(jsonConfig), &c)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Check default values
|
||||||
|
assert.Equal(t, "test-mcp-server", c.Mcp.Name)
|
||||||
|
assert.Equal(t, "1.0.0", c.Mcp.Version, "Default version should be 1.0.0")
|
||||||
|
assert.Equal(t, "2024-11-05", c.Mcp.ProtocolVersion, "Default protocol version should be 2024-11-05")
|
||||||
|
assert.Equal(t, "/sse", c.Mcp.SseEndpoint, "Default SSE endpoint should be /sse")
|
||||||
|
assert.Equal(t, "/message", c.Mcp.MessageEndpoint, "Default message endpoint should be /message")
|
||||||
|
assert.Equal(t, 30*time.Second, c.Mcp.ToolTimeout, "Default tool timeout should be 30s")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMcpConfCustomValues(t *testing.T) {
|
||||||
|
// Test custom values can be set
|
||||||
|
jsonConfig := `{
|
||||||
|
"Name": "test-service",
|
||||||
|
"Port": 8080,
|
||||||
|
"Mcp": {
|
||||||
|
"Name": "test-mcp-server",
|
||||||
|
"Version": "2.0.0",
|
||||||
|
"ProtocolVersion": "2025-01-01",
|
||||||
|
"BaseUrl": "http://example.com",
|
||||||
|
"SseEndpoint": "/custom-sse",
|
||||||
|
"MessageEndpoint": "/custom-message",
|
||||||
|
"Cors": ["http://localhost:3000", "http://example.com"],
|
||||||
|
"ToolTimeout": "60s"
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
var c McpConf
|
||||||
|
err := conf.LoadFromJsonBytes([]byte(jsonConfig), &c)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Check custom values
|
||||||
|
assert.Equal(t, "test-mcp-server", c.Mcp.Name, "Name should be inherited from RestConf")
|
||||||
|
assert.Equal(t, "2.0.0", c.Mcp.Version, "Version should be customizable")
|
||||||
|
assert.Equal(t, "2025-01-01", c.Mcp.ProtocolVersion, "Protocol version should be customizable")
|
||||||
|
assert.Equal(t, "http://example.com", c.Mcp.BaseUrl, "BaseUrl should be customizable")
|
||||||
|
assert.Equal(t, "/custom-sse", c.Mcp.SseEndpoint, "SSE endpoint should be customizable")
|
||||||
|
assert.Equal(t, "/custom-message", c.Mcp.MessageEndpoint, "Message endpoint should be customizable")
|
||||||
|
assert.Equal(t, []string{"http://localhost:3000", "http://example.com"}, c.Mcp.Cors, "CORS settings should be customizable")
|
||||||
|
assert.Equal(t, 60*time.Second, c.Mcp.ToolTimeout, "Tool timeout should be customizable")
|
||||||
|
}
|
||||||
443
mcp/integration_test.go
Normal file
443
mcp/integration_test.go
Normal file
@@ -0,0 +1,443 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// syncResponseRecorder is a thread-safe wrapper around httptest.ResponseRecorder
|
||||||
|
type syncResponseRecorder struct {
|
||||||
|
*httptest.ResponseRecorder
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new synchronized response recorder
|
||||||
|
func newSyncResponseRecorder() *syncResponseRecorder {
|
||||||
|
return &syncResponseRecorder{
|
||||||
|
ResponseRecorder: httptest.NewRecorder(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override Write method to synchronize access
|
||||||
|
func (srr *syncResponseRecorder) Write(p []byte) (int, error) {
|
||||||
|
srr.mu.Lock()
|
||||||
|
defer srr.mu.Unlock()
|
||||||
|
return srr.ResponseRecorder.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override WriteHeader method to synchronize access
|
||||||
|
func (srr *syncResponseRecorder) WriteHeader(statusCode int) {
|
||||||
|
srr.mu.Lock()
|
||||||
|
defer srr.mu.Unlock()
|
||||||
|
srr.ResponseRecorder.WriteHeader(statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override Result method to synchronize access
|
||||||
|
func (srr *syncResponseRecorder) Result() *http.Response {
|
||||||
|
srr.mu.Lock()
|
||||||
|
defer srr.mu.Unlock()
|
||||||
|
return srr.ResponseRecorder.Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHTTPHandlerIntegration tests the HTTP handlers with a real server instance
|
||||||
|
func TestHTTPHandlerIntegration(t *testing.T) {
|
||||||
|
// Skip in short test mode
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a test configuration
|
||||||
|
conf := McpConf{}
|
||||||
|
conf.Mcp.Name = "test-integration"
|
||||||
|
conf.Mcp.Version = "1.0.0-test"
|
||||||
|
conf.Mcp.ToolTimeout = 1 * time.Second
|
||||||
|
|
||||||
|
// Create a mock server directly
|
||||||
|
server := &sseMcpServer{
|
||||||
|
conf: conf,
|
||||||
|
clients: make(map[string]*mcpClient),
|
||||||
|
tools: make(map[string]Tool),
|
||||||
|
prompts: make(map[string]Prompt),
|
||||||
|
resources: make(map[string]Resource),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register a test tool
|
||||||
|
err := server.RegisterTool(Tool{
|
||||||
|
Name: "echo",
|
||||||
|
Description: "Echo tool for testing",
|
||||||
|
InputSchema: InputSchema{
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]any{
|
||||||
|
"message": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Message to echo",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(params map[string]any) (any, error) {
|
||||||
|
if msg, ok := params["message"].(string); ok {
|
||||||
|
return fmt.Sprintf("Echo: %s", msg), nil
|
||||||
|
}
|
||||||
|
return "Echo: no message provided", nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create a test HTTP request to the SSE endpoint
|
||||||
|
req := httptest.NewRequest("GET", "/sse", nil)
|
||||||
|
w := newSyncResponseRecorder()
|
||||||
|
|
||||||
|
// Create a done channel to signal completion of test
|
||||||
|
done := make(chan bool)
|
||||||
|
|
||||||
|
// Start the SSE handler in a goroutine
|
||||||
|
go func() {
|
||||||
|
// lock.Lock()
|
||||||
|
server.handleSSE(w, req)
|
||||||
|
// lock.Unlock()
|
||||||
|
done <- true
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Allow time for the handler to process
|
||||||
|
select {
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
// Expected - handler would normally block indefinitely
|
||||||
|
case <-done:
|
||||||
|
// This shouldn't happen immediately - the handler should block
|
||||||
|
t.Error("SSE handler returned unexpectedly")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the initial headers
|
||||||
|
resp := w.Result()
|
||||||
|
assert.Equal(t, "chunked", resp.Header.Get("Transfer-Encoding"))
|
||||||
|
resp.Body.Close()
|
||||||
|
|
||||||
|
// The handler creates a client and sends the endpoint message
|
||||||
|
var sessionId string
|
||||||
|
|
||||||
|
// Give the handler time to set up the client
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// Check that a client was created
|
||||||
|
server.clientsLock.Lock()
|
||||||
|
assert.Equal(t, 1, len(server.clients))
|
||||||
|
for id := range server.clients {
|
||||||
|
sessionId = id
|
||||||
|
}
|
||||||
|
server.clientsLock.Unlock()
|
||||||
|
|
||||||
|
require.NotEmpty(t, sessionId, "Expected a session ID to be created")
|
||||||
|
|
||||||
|
// Now that we have a session ID, we can test the message endpoint
|
||||||
|
messageBody, _ := json.Marshal(Request{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 1,
|
||||||
|
Method: methodInitialize,
|
||||||
|
Params: json.RawMessage(`{}`),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create a message request
|
||||||
|
reqURL := fmt.Sprintf("/message?%s=%s", sessionIdKey, sessionId)
|
||||||
|
msgReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(messageBody))
|
||||||
|
msgW := newSyncResponseRecorder()
|
||||||
|
|
||||||
|
// Process the message
|
||||||
|
server.handleRequest(msgW, msgReq)
|
||||||
|
|
||||||
|
// Check the response
|
||||||
|
msgResp := msgW.Result()
|
||||||
|
assert.Equal(t, http.StatusAccepted, msgResp.StatusCode)
|
||||||
|
msgResp.Body.Close() // Ensure response body is closed
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandlerResponseFlow tests the flow of a full request/response cycle
|
||||||
|
func TestHandlerResponseFlow(t *testing.T) {
|
||||||
|
// Create a mock server for testing
|
||||||
|
server := &sseMcpServer{
|
||||||
|
conf: McpConf{},
|
||||||
|
clients: map[string]*mcpClient{
|
||||||
|
"test-session": {
|
||||||
|
id: "test-session",
|
||||||
|
channel: make(chan string, 10),
|
||||||
|
initialized: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
tools: make(map[string]Tool),
|
||||||
|
prompts: make(map[string]Prompt),
|
||||||
|
resources: make(map[string]Resource),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register test resources
|
||||||
|
server.RegisterTool(Tool{
|
||||||
|
Name: "test.tool",
|
||||||
|
Description: "Test tool",
|
||||||
|
InputSchema: InputSchema{Type: "object"},
|
||||||
|
Handler: func(params map[string]any) (any, error) {
|
||||||
|
return "tool result", nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
server.RegisterPrompt(Prompt{
|
||||||
|
Name: "test.prompt",
|
||||||
|
Description: "Test prompt",
|
||||||
|
})
|
||||||
|
|
||||||
|
server.RegisterResource(Resource{
|
||||||
|
Name: "test.resource",
|
||||||
|
URI: "http://example.com",
|
||||||
|
Description: "Test resource",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create a request with session ID parameter
|
||||||
|
reqURL := fmt.Sprintf("/message?%s=%s", sessionIdKey, "test-session")
|
||||||
|
|
||||||
|
// Test tools/list request
|
||||||
|
toolsListBody, _ := json.Marshal(Request{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 1,
|
||||||
|
Method: methodToolsList,
|
||||||
|
Params: json.RawMessage(`{}`),
|
||||||
|
})
|
||||||
|
|
||||||
|
toolsReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(toolsListBody))
|
||||||
|
toolsW := newSyncResponseRecorder()
|
||||||
|
|
||||||
|
// Process the request
|
||||||
|
server.handleRequest(toolsW, toolsReq)
|
||||||
|
|
||||||
|
// Check the response code
|
||||||
|
toolsResp := toolsW.Result()
|
||||||
|
assert.Equal(t, http.StatusAccepted, toolsResp.StatusCode)
|
||||||
|
toolsResp.Body.Close()
|
||||||
|
|
||||||
|
// Check the channel message
|
||||||
|
client := server.clients["test-session"]
|
||||||
|
select {
|
||||||
|
case message := <-client.channel:
|
||||||
|
assert.Contains(t, message, `"tools":[{"name":"test.tool"`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for tools/list response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test prompts/list request
|
||||||
|
promptsListBody, _ := json.Marshal(Request{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 2,
|
||||||
|
Method: methodPromptsList,
|
||||||
|
Params: json.RawMessage(`{}`),
|
||||||
|
})
|
||||||
|
|
||||||
|
promptsReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(promptsListBody))
|
||||||
|
promptsW := newSyncResponseRecorder()
|
||||||
|
|
||||||
|
// Process the request
|
||||||
|
server.handleRequest(promptsW, promptsReq)
|
||||||
|
|
||||||
|
// Check the response code
|
||||||
|
promptsResp := promptsW.Result()
|
||||||
|
assert.Equal(t, http.StatusAccepted, promptsResp.StatusCode)
|
||||||
|
promptsResp.Body.Close()
|
||||||
|
|
||||||
|
// Check the channel message
|
||||||
|
select {
|
||||||
|
case message := <-client.channel:
|
||||||
|
assert.Contains(t, message, `"prompts":[{"name":"test.prompt"`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for prompts/list response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test resources/list request
|
||||||
|
resourcesListBody, _ := json.Marshal(Request{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 3,
|
||||||
|
Method: methodResourcesList,
|
||||||
|
Params: json.RawMessage(`{}`),
|
||||||
|
})
|
||||||
|
|
||||||
|
resourcesReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(resourcesListBody))
|
||||||
|
resourcesW := newSyncResponseRecorder()
|
||||||
|
|
||||||
|
// Process the request
|
||||||
|
server.handleRequest(resourcesW, resourcesReq)
|
||||||
|
|
||||||
|
// Check the response code
|
||||||
|
resourcesResp := resourcesW.Result()
|
||||||
|
assert.Equal(t, http.StatusAccepted, resourcesResp.StatusCode)
|
||||||
|
resourcesResp.Body.Close()
|
||||||
|
|
||||||
|
// Check the channel message
|
||||||
|
select {
|
||||||
|
case message := <-client.channel:
|
||||||
|
assert.Contains(t, message, `"name":"test.resource"`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for resources/list response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProcessListMethods tests the list processing methods with pagination
|
||||||
|
func TestProcessListMethods(t *testing.T) {
|
||||||
|
server := &sseMcpServer{
|
||||||
|
tools: make(map[string]Tool),
|
||||||
|
prompts: make(map[string]Prompt),
|
||||||
|
resources: make(map[string]Resource),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add some test data
|
||||||
|
for i := 1; i <= 5; i++ {
|
||||||
|
tool := Tool{
|
||||||
|
Name: fmt.Sprintf("tool%d", i),
|
||||||
|
Description: fmt.Sprintf("Tool %d", i),
|
||||||
|
InputSchema: InputSchema{Type: "object"},
|
||||||
|
}
|
||||||
|
server.tools[tool.Name] = tool
|
||||||
|
|
||||||
|
prompt := Prompt{
|
||||||
|
Name: fmt.Sprintf("prompt%d", i),
|
||||||
|
Description: fmt.Sprintf("Prompt %d", i),
|
||||||
|
}
|
||||||
|
server.prompts[prompt.Name] = prompt
|
||||||
|
|
||||||
|
resource := Resource{
|
||||||
|
Name: fmt.Sprintf("resource%d", i),
|
||||||
|
URI: fmt.Sprintf("http://example.com/%d", i),
|
||||||
|
Description: fmt.Sprintf("Resource %d", i),
|
||||||
|
}
|
||||||
|
server.resources[resource.Name] = resource
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a test client
|
||||||
|
client := &mcpClient{
|
||||||
|
id: "test-client",
|
||||||
|
channel: make(chan string, 10),
|
||||||
|
initialized: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test processListTools
|
||||||
|
req := Request{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 1,
|
||||||
|
Method: methodToolsList,
|
||||||
|
Params: json.RawMessage(`{"cursor": "", "_meta": {"progressToken": "token1"}}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
server.processListTools(client, req)
|
||||||
|
|
||||||
|
// Read response
|
||||||
|
select {
|
||||||
|
case response := <-client.channel:
|
||||||
|
assert.Contains(t, response, `"tools":`)
|
||||||
|
assert.Contains(t, response, `"progressToken":"token1"`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for tools/list response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test processListPrompts
|
||||||
|
req.ID = 2
|
||||||
|
req.Method = methodPromptsList
|
||||||
|
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
||||||
|
server.processListPrompts(client, req)
|
||||||
|
|
||||||
|
// Read response
|
||||||
|
select {
|
||||||
|
case response := <-client.channel:
|
||||||
|
assert.Contains(t, response, `"prompts":`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for prompts/list response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test processListResources
|
||||||
|
req.ID = 3
|
||||||
|
req.Method = methodResourcesList
|
||||||
|
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
||||||
|
server.processListResources(client, req)
|
||||||
|
|
||||||
|
// Read response
|
||||||
|
select {
|
||||||
|
case response := <-client.channel:
|
||||||
|
assert.Contains(t, response, `"resources":`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for resources/list response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestErrorResponseHandling tests error handling in the server
|
||||||
|
func TestErrorResponseHandling(t *testing.T) {
|
||||||
|
server := &sseMcpServer{
|
||||||
|
tools: make(map[string]Tool),
|
||||||
|
prompts: make(map[string]Prompt),
|
||||||
|
resources: make(map[string]Resource),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a test client
|
||||||
|
client := &mcpClient{
|
||||||
|
id: "test-client",
|
||||||
|
channel: make(chan string, 10),
|
||||||
|
initialized: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test invalid method
|
||||||
|
req := Request{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 1,
|
||||||
|
Method: "invalid_method",
|
||||||
|
Params: json.RawMessage(`{}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mock handleRequest by directly calling error handler
|
||||||
|
server.sendErrorResponse(client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||||
|
|
||||||
|
// Check response
|
||||||
|
select {
|
||||||
|
case response := <-client.channel:
|
||||||
|
assert.Contains(t, response, `"error":{"code":-32601,"message":"Method not found"}`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for error response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test invalid tool
|
||||||
|
toolReq := Request{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 2,
|
||||||
|
Method: methodToolsCall,
|
||||||
|
Params: json.RawMessage(`{"name":"non_existent_tool"}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call process method directly
|
||||||
|
server.processToolCall(client, toolReq)
|
||||||
|
|
||||||
|
// Check response
|
||||||
|
select {
|
||||||
|
case response := <-client.channel:
|
||||||
|
assert.Contains(t, response, `"error":{"code":-32602,"message":"Tool 'non_existent_tool' not found"}`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for error response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test invalid prompt
|
||||||
|
promptReq := Request{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 3,
|
||||||
|
Method: methodPromptsGet,
|
||||||
|
Params: json.RawMessage(`{"name":"non_existent_prompt"}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call process method directly
|
||||||
|
server.processGetPrompt(client, promptReq)
|
||||||
|
|
||||||
|
// Check response
|
||||||
|
select {
|
||||||
|
case response := <-client.channel:
|
||||||
|
assert.Contains(t, response, `"error":{"code":-32602,"message":"Prompt 'non_existent_prompt' not found"}`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for error response")
|
||||||
|
}
|
||||||
|
}
|
||||||
62
mcp/readme.md
Normal file
62
mcp/readme.md
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
# Model Context Protocol (MCP) SDK Implementation
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
This package implements a Model Context Protocol (MCP) server in Go that facilitates real-time communication between AI models and clients using Server-Sent Events (SSE). The implementation provides a framework for building AI-assisted applications with bidirectional communication capabilities.
|
||||||
|
|
||||||
|
## Core Components
|
||||||
|
|
||||||
|
### Server-Sent Events (SSE) Communication
|
||||||
|
- **Real-time Communication**: Robust SSE-based communication system that maintains persistent connections with clients
|
||||||
|
- **Connection Management**: Client registration, message broadcasting, and client cleanup mechanisms
|
||||||
|
- **Event Handling**: Event types for tools, prompts, and resources changes
|
||||||
|
|
||||||
|
### JSON-RPC Implementation
|
||||||
|
- **Request Processing**: Complete JSON-RPC request processor for handling MCP protocol methods
|
||||||
|
- **Response Formatting**: Proper response formatting according to JSON-RPC specifications
|
||||||
|
- **Error Handling**: Comprehensive error handling with appropriate error codes
|
||||||
|
|
||||||
|
### Tool Management
|
||||||
|
- **Tool Registration**: System to register custom tools with handlers
|
||||||
|
- **Tool Execution**: Mechanism to execute tool functions with proper timeout handling
|
||||||
|
- **Result Handling**: Flexible result handling supporting various return types (string, JSON, images)
|
||||||
|
|
||||||
|
### Prompt System
|
||||||
|
- **Prompt Registration**: System for registering both static and dynamic prompts
|
||||||
|
- **Argument Validation**: Validation for required arguments and default values for optional ones
|
||||||
|
- **Message Generation**: Handlers that generate properly formatted conversation messages
|
||||||
|
|
||||||
|
### Resource Management
|
||||||
|
- **Resource Registration**: System for managing and accessing external resources
|
||||||
|
- **Content Delivery**: Handlers for delivering resource content to clients on demand
|
||||||
|
- **Resource Subscription**: Mechanisms for clients to subscribe to resource updates
|
||||||
|
|
||||||
|
### Protocol Features
|
||||||
|
- **Initialization Sequence**: Proper handshaking with capability negotiation
|
||||||
|
- **Notification Handling**: Support for both standard and client-specific notifications
|
||||||
|
- **Message Routing**: Intelligent routing of requests to appropriate handlers
|
||||||
|
|
||||||
|
## Technical Highlights
|
||||||
|
|
||||||
|
### Configuration System
|
||||||
|
- **Flexible Configuration**: Configuration system with sensible defaults and customization options
|
||||||
|
- **CORS Support**: Configurable CORS settings for cross-origin requests
|
||||||
|
- **Server Information**: Proper server identification and versioning
|
||||||
|
|
||||||
|
### Client Session Management
|
||||||
|
- **Session Tracking**: Client session tracking with unique identifiers
|
||||||
|
- **Connection Health**: Ping/pong mechanism to maintain connection health
|
||||||
|
- **Initialization State**: Client initialization state tracking
|
||||||
|
|
||||||
|
### Content Handling
|
||||||
|
- **Multi-format Content**: Support for text, code, and binary content
|
||||||
|
- **MIME Type Support**: Proper MIME type identification for various content types
|
||||||
|
- **Audience Annotations**: Content audience annotations for user/assistant targeting
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
To create and use an MCP server, see the examples directory for practical implementation examples including:
|
||||||
|
- Tool registration and execution
|
||||||
|
- Static and dynamic prompt creation
|
||||||
|
- Resource handling with proper URI identification
|
||||||
|
- Embedded resources in prompt responses
|
||||||
|
- Client connection management
|
||||||
974
mcp/server.go
Normal file
974
mcp/server.go
Normal file
@@ -0,0 +1,974 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
|
"github.com/zeromicro/go-zero/rest"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewMcpServer(c McpConf) McpServer {
|
||||||
|
var server *rest.Server
|
||||||
|
if len(c.Mcp.Cors) == 0 {
|
||||||
|
server = rest.MustNewServer(c.RestConf)
|
||||||
|
} else {
|
||||||
|
server = rest.MustNewServer(c.RestConf, rest.WithCors(c.Mcp.Cors...))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(c.Mcp.Name) == 0 {
|
||||||
|
c.Mcp.Name = c.Name
|
||||||
|
}
|
||||||
|
if len(c.Mcp.BaseUrl) == 0 {
|
||||||
|
c.Mcp.BaseUrl = fmt.Sprintf("http://localhost:%d", c.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &sseMcpServer{
|
||||||
|
conf: c,
|
||||||
|
server: server,
|
||||||
|
clients: make(map[string]*mcpClient),
|
||||||
|
tools: make(map[string]Tool),
|
||||||
|
prompts: make(map[string]Prompt),
|
||||||
|
resources: make(map[string]Resource),
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSE endpoint for real-time updates
|
||||||
|
s.server.AddRoute(rest.Route{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: s.conf.Mcp.SseEndpoint,
|
||||||
|
Handler: s.handleSSE,
|
||||||
|
}, rest.WithSSE())
|
||||||
|
|
||||||
|
// JSON-RPC message endpoint for regular requests
|
||||||
|
s.server.AddRoute(rest.Route{
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Path: s.conf.Mcp.MessageEndpoint,
|
||||||
|
Handler: s.handleRequest,
|
||||||
|
})
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterPrompt registers a new prompt with the server
|
||||||
|
func (s *sseMcpServer) RegisterPrompt(prompt Prompt) {
|
||||||
|
s.promptsLock.Lock()
|
||||||
|
s.prompts[prompt.Name] = prompt
|
||||||
|
s.promptsLock.Unlock()
|
||||||
|
// Notify clients about the new prompt
|
||||||
|
s.broadcast(eventPromptsListChanged, map[string][]Prompt{keyPrompts: {prompt}})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterResource registers a new resource with the server
|
||||||
|
func (s *sseMcpServer) RegisterResource(resource Resource) {
|
||||||
|
s.resourcesLock.Lock()
|
||||||
|
s.resources[resource.URI] = resource
|
||||||
|
s.resourcesLock.Unlock()
|
||||||
|
// Notify clients about the new resource
|
||||||
|
s.broadcast(eventResourcesListChanged, map[string][]Resource{keyResources: {resource}})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterTool registers a new tool with the server
|
||||||
|
func (s *sseMcpServer) RegisterTool(tool Tool) error {
|
||||||
|
if tool.Handler == nil {
|
||||||
|
return fmt.Errorf("tool '%s' has no handler function", tool.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.toolsLock.Lock()
|
||||||
|
s.tools[tool.Name] = tool
|
||||||
|
s.toolsLock.Unlock()
|
||||||
|
// Notify clients about the new tool
|
||||||
|
s.broadcast(eventToolsListChanged, map[string][]Tool{keyTools: {tool}})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start implements McpServer.
|
||||||
|
func (s *sseMcpServer) Start() {
|
||||||
|
s.server.Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sseMcpServer) Stop() {
|
||||||
|
s.server.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// broadcast sends a message to all connected clients
|
||||||
|
// It uses Server-Sent Events (SSE) format for real-time communication
|
||||||
|
func (s *sseMcpServer) broadcast(event string, data any) {
|
||||||
|
jsonData, err := json.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
logx.Errorf("Failed to marshal broadcast data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lock only while reading the clients map
|
||||||
|
s.clientsLock.Lock()
|
||||||
|
clients := make([]*mcpClient, 0, len(s.clients))
|
||||||
|
for _, client := range s.clients {
|
||||||
|
clients = append(clients, client)
|
||||||
|
}
|
||||||
|
s.clientsLock.Unlock()
|
||||||
|
|
||||||
|
clientCount := len(clients)
|
||||||
|
if clientCount == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logx.Infof("Broadcasting event '%s' to %d clients", event, clientCount)
|
||||||
|
|
||||||
|
// Use CRLF line endings as per SSE specification
|
||||||
|
message := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(jsonData))
|
||||||
|
|
||||||
|
// Send messages without holding the lock
|
||||||
|
for _, client := range clients {
|
||||||
|
select {
|
||||||
|
case client.channel <- message:
|
||||||
|
// Message sent successfully
|
||||||
|
default:
|
||||||
|
// Channel buffer is full, log warning and continue
|
||||||
|
logx.Errorf("Client channel buffer full, dropping message for client %s", client.id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupClient removes a client from the active clients map
|
||||||
|
func (s *sseMcpServer) cleanupClient(sessionId string) {
|
||||||
|
s.clientsLock.Lock()
|
||||||
|
defer s.clientsLock.Unlock()
|
||||||
|
|
||||||
|
if client, exists := s.clients[sessionId]; exists {
|
||||||
|
// Close the channel to signal any goroutines waiting on it
|
||||||
|
close(client.channel)
|
||||||
|
// Remove from active clients
|
||||||
|
delete(s.clients, sessionId)
|
||||||
|
logx.Infof("Cleaned up client %s (remaining clients: %d)", sessionId, len(s.clients))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleRequest handles MCP JSON-RPC requests
|
||||||
|
func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Extract sessionId from query parameters
|
||||||
|
sessionId := r.URL.Query().Get(sessionIdKey)
|
||||||
|
if len(sessionId) == 0 {
|
||||||
|
http.Error(w, fmt.Sprintf("Missing %s", sessionIdKey), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the client with this sessionId exists
|
||||||
|
s.clientsLock.Lock()
|
||||||
|
client, exists := s.clients[sessionId]
|
||||||
|
s.clientsLock.Unlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
http.Error(w, fmt.Sprintf("Invalid or expired %s", sessionIdKey), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req Request
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
http.Error(w, "Invalid request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusAccepted)
|
||||||
|
|
||||||
|
// For notification methods (no ID), we don't send a response
|
||||||
|
isNotification := req.ID == 0
|
||||||
|
|
||||||
|
// Special handling for initialization sequence
|
||||||
|
// Always allow initialize and notifications/initialized regardless of client state
|
||||||
|
if req.Method == methodInitialize {
|
||||||
|
logx.Infof("Processing initialize request with ID: %d", req.ID)
|
||||||
|
s.processInitialize(client, req)
|
||||||
|
logx.Infof("Sent initialize response for ID: %d, waiting for notifications/initialized", req.ID)
|
||||||
|
return
|
||||||
|
} else if req.Method == methodNotificationsInitialized {
|
||||||
|
// Handle initialized notification
|
||||||
|
logx.Info("Received notifications/initialized notification")
|
||||||
|
if !isNotification {
|
||||||
|
s.sendErrorResponse(client, req.ID, "Method should be used as a notification", errCodeInvalidRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.processNotificationInitialized(client)
|
||||||
|
return
|
||||||
|
} else if !client.initialized && req.Method != methodNotificationsCancelled {
|
||||||
|
// Block most requests until client is initialized (except for cancellations)
|
||||||
|
s.sendErrorResponse(client, req.ID, "Client not fully initialized, waiting for notifications/initialized",
|
||||||
|
errCodeClientNotInitialized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process normal requests only after initialization
|
||||||
|
switch req.Method {
|
||||||
|
case methodToolsCall:
|
||||||
|
logx.Infof("Received tools call request with ID: %d", req.ID)
|
||||||
|
s.processToolCall(client, req)
|
||||||
|
logx.Infof("Sent tools call response for ID: %d", req.ID)
|
||||||
|
case methodToolsList:
|
||||||
|
logx.Infof("Processing tools/list request with ID: %d", req.ID)
|
||||||
|
s.processListTools(client, req)
|
||||||
|
logx.Infof("Sent tools/list response for ID: %d", req.ID)
|
||||||
|
case methodPromptsList:
|
||||||
|
logx.Infof("Processing prompts/list request with ID: %d", req.ID)
|
||||||
|
s.processListPrompts(client, req)
|
||||||
|
logx.Infof("Sent prompts/list response for ID: %d", req.ID)
|
||||||
|
case methodPromptsGet:
|
||||||
|
logx.Infof("Processing prompts/get request with ID: %d", req.ID)
|
||||||
|
s.processGetPrompt(client, req)
|
||||||
|
logx.Infof("Sent prompts/get response for ID: %d", req.ID)
|
||||||
|
case methodResourcesList:
|
||||||
|
logx.Infof("Processing resources/list request with ID: %d", req.ID)
|
||||||
|
s.processListResources(client, req)
|
||||||
|
logx.Infof("Sent resources/list response for ID: %d", req.ID)
|
||||||
|
case methodResourcesRead:
|
||||||
|
logx.Infof("Processing resources/read request with ID: %d", req.ID)
|
||||||
|
s.processResourcesRead(client, req)
|
||||||
|
logx.Infof("Sent resources/read response for ID: %d", req.ID)
|
||||||
|
case methodResourcesSubscribe:
|
||||||
|
logx.Infof("Processing resources/subscribe request with ID: %d", req.ID)
|
||||||
|
s.processResourceSubscribe(client, req)
|
||||||
|
logx.Infof("Sent resources/subscribe response for ID: %d", req.ID)
|
||||||
|
case methodPing:
|
||||||
|
logx.Infof("Processing ping request with ID: %d", req.ID)
|
||||||
|
s.processPing(client, req)
|
||||||
|
case methodNotificationsCancelled:
|
||||||
|
logx.Infof("Received notifications/cancelled notification: %v", req.Params)
|
||||||
|
s.processNotificationCancelled(client, req)
|
||||||
|
default:
|
||||||
|
logx.Infof("Unknown method: %s", req.Method)
|
||||||
|
s.sendErrorResponse(client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleSSE handles Server-Sent Events connections
|
||||||
|
func (s *sseMcpServer) handleSSE(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Generate a unique session ID for this client
|
||||||
|
sessionId := uuid.New().String()
|
||||||
|
|
||||||
|
// Create new client with buffered channel to prevent blocking
|
||||||
|
client := &mcpClient{
|
||||||
|
id: sessionId,
|
||||||
|
channel: make(chan string, eventChanSize),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add client to active clients map
|
||||||
|
s.clientsLock.Lock()
|
||||||
|
s.clients[sessionId] = client
|
||||||
|
activeClients := len(s.clients)
|
||||||
|
s.clientsLock.Unlock()
|
||||||
|
|
||||||
|
logx.Infof("New SSE connection established for client %s (active clients: %d)",
|
||||||
|
sessionId, activeClients)
|
||||||
|
|
||||||
|
// Set proper SSE headers
|
||||||
|
w.Header().Set("Transfer-Encoding", "chunked")
|
||||||
|
|
||||||
|
// Enable streaming
|
||||||
|
flusher, ok := w.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
logx.Error("Streaming not supported by the underlying http.ResponseWriter")
|
||||||
|
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the message endpoint URL to the client
|
||||||
|
endpoint := fmt.Sprintf("%s%s?%s=%s",
|
||||||
|
s.conf.Mcp.BaseUrl, s.conf.Mcp.MessageEndpoint, sessionIdKey, sessionId)
|
||||||
|
|
||||||
|
// Format and send the endpoint message
|
||||||
|
endpointMsg := formatSSEMessage(eventEndpoint, []byte(endpoint))
|
||||||
|
if _, err := fmt.Fprint(w, endpointMsg); err != nil {
|
||||||
|
logx.Errorf("Failed to send endpoint message to client %s: %v", sessionId, err)
|
||||||
|
s.cleanupClient(sessionId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
|
||||||
|
// Set up keep-alive ping and client cleanup
|
||||||
|
ticker := time.NewTicker(pingInterval.Load())
|
||||||
|
defer func() {
|
||||||
|
ticker.Stop()
|
||||||
|
s.cleanupClient(sessionId)
|
||||||
|
logx.Infof("SSE connection closed for client %s", sessionId)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Message processing loop
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case message, ok := <-client.channel:
|
||||||
|
if !ok {
|
||||||
|
// Channel was closed, end connection
|
||||||
|
logx.Infof("Client channel was closed for %s", sessionId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write message to the response
|
||||||
|
if _, err := fmt.Fprint(w, message); err != nil {
|
||||||
|
logx.Infof("Failed to write message to client %s: %v", sessionId, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
case <-ticker.C:
|
||||||
|
// Send keep-alive ping to maintain connection
|
||||||
|
ping := fmt.Sprintf(`{"type":"ping","timestamp":"%s"}`, time.Now().String())
|
||||||
|
pingMsg := formatSSEMessage("ping", []byte(ping))
|
||||||
|
if _, err := fmt.Fprint(w, pingMsg); err != nil {
|
||||||
|
logx.Errorf("Failed to send ping to client %s, closing connection: %v", sessionId, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
case <-r.Context().Done():
|
||||||
|
// Client disconnected or request was canceled
|
||||||
|
logx.Infof("Client %s disconnected: context done", sessionId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processInitialize processes the initialize request
|
||||||
|
func (s *sseMcpServer) processInitialize(client *mcpClient, req Request) {
|
||||||
|
// Create a proper JSON-RPC response that preserves the client's request ID
|
||||||
|
result := initializationResponse{
|
||||||
|
ProtocolVersion: s.conf.Mcp.ProtocolVersion,
|
||||||
|
Capabilities: capabilities{
|
||||||
|
Prompts: struct {
|
||||||
|
ListChanged bool `json:"listChanged"`
|
||||||
|
}{
|
||||||
|
ListChanged: true,
|
||||||
|
},
|
||||||
|
Resources: struct {
|
||||||
|
Subscribe bool `json:"subscribe"`
|
||||||
|
ListChanged bool `json:"listChanged"`
|
||||||
|
}{
|
||||||
|
Subscribe: true,
|
||||||
|
ListChanged: true,
|
||||||
|
},
|
||||||
|
Tools: struct {
|
||||||
|
ListChanged bool `json:"listChanged"`
|
||||||
|
}{
|
||||||
|
ListChanged: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ServerInfo: serverInfo{
|
||||||
|
Name: s.conf.Mcp.Name,
|
||||||
|
Version: s.conf.Mcp.Version,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark client as initialized
|
||||||
|
client.initialized = true
|
||||||
|
|
||||||
|
// Send response with client's original request ID
|
||||||
|
s.sendResponse(client, req.ID, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processListTools processes the tools/list request
|
||||||
|
func (s *sseMcpServer) processListTools(client *mcpClient, req Request) {
|
||||||
|
// Extract pagination params if any
|
||||||
|
var nextCursor string
|
||||||
|
var progressToken any
|
||||||
|
|
||||||
|
// Extract meta data including progress token
|
||||||
|
if req.Params != nil {
|
||||||
|
var metaParams struct {
|
||||||
|
Cursor string `json:"cursor"`
|
||||||
|
Meta struct {
|
||||||
|
ProgressToken any `json:"progressToken"`
|
||||||
|
} `json:"_meta"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(req.Params, &metaParams); err == nil {
|
||||||
|
if len(metaParams.Cursor) > 0 {
|
||||||
|
nextCursor = metaParams.Cursor
|
||||||
|
}
|
||||||
|
progressToken = metaParams.Meta.ProgressToken
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var toolsList []Tool
|
||||||
|
s.toolsLock.Lock()
|
||||||
|
for _, tool := range s.tools {
|
||||||
|
toolsList = append(toolsList, tool)
|
||||||
|
}
|
||||||
|
s.toolsLock.Unlock()
|
||||||
|
|
||||||
|
result := ListToolsResult{
|
||||||
|
PaginatedResult: PaginatedResult{
|
||||||
|
Result: Result{},
|
||||||
|
NextCursor: Cursor(nextCursor),
|
||||||
|
},
|
||||||
|
Tools: toolsList,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add meta information if progress token was provided
|
||||||
|
if progressToken != nil {
|
||||||
|
result.Result.Meta = map[string]any{
|
||||||
|
"progressToken": progressToken,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sendResponse(client, req.ID, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processListPrompts processes the prompts/list request
|
||||||
|
func (s *sseMcpServer) processListPrompts(client *mcpClient, req Request) {
|
||||||
|
// Extract pagination params if any
|
||||||
|
var nextCursor string
|
||||||
|
if req.Params != nil {
|
||||||
|
var cursorParams struct {
|
||||||
|
Cursor string `json:"cursor"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(req.Params, &cursorParams); err == nil && cursorParams.Cursor != "" {
|
||||||
|
// If we have a valid cursor, we could use it for pagination
|
||||||
|
// For now, we're not actually implementing pagination, so this is just
|
||||||
|
// to show how it would be extracted from the request
|
||||||
|
_ = cursorParams.Cursor
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare prompt list
|
||||||
|
var promptsList []Prompt
|
||||||
|
s.promptsLock.Lock()
|
||||||
|
for _, prompt := range s.prompts {
|
||||||
|
promptsList = append(promptsList, prompt)
|
||||||
|
}
|
||||||
|
s.promptsLock.Unlock()
|
||||||
|
|
||||||
|
// In a real implementation, you'd handle pagination here
|
||||||
|
// For now, we'll return all prompts at once
|
||||||
|
result := struct {
|
||||||
|
Prompts []Prompt `json:"prompts"`
|
||||||
|
NextCursor string `json:"nextCursor,omitempty"`
|
||||||
|
Meta *struct{} `json:"_meta,omitempty"`
|
||||||
|
}{
|
||||||
|
Prompts: promptsList,
|
||||||
|
NextCursor: nextCursor,
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sendResponse(client, req.ID, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processListResources processes the resources/list request
|
||||||
|
func (s *sseMcpServer) processListResources(client *mcpClient, req Request) {
|
||||||
|
// Extract pagination params if any
|
||||||
|
var nextCursor string
|
||||||
|
var progressToken any
|
||||||
|
|
||||||
|
// Extract meta information including progress token if available
|
||||||
|
if req.Params != nil {
|
||||||
|
var metaParams PaginatedParams
|
||||||
|
if err := json.Unmarshal(req.Params, &metaParams); err == nil {
|
||||||
|
if len(metaParams.Cursor) > 0 {
|
||||||
|
nextCursor = metaParams.Cursor
|
||||||
|
}
|
||||||
|
progressToken = metaParams.Meta.ProgressToken
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var resourcesList []Resource
|
||||||
|
s.resourcesLock.Lock()
|
||||||
|
for _, resource := range s.resources {
|
||||||
|
// Create a copy without the handler function which shouldn't be sent to clients
|
||||||
|
resourceCopy := Resource{
|
||||||
|
URI: resource.URI,
|
||||||
|
Name: resource.Name,
|
||||||
|
Description: resource.Description,
|
||||||
|
MimeType: resource.MimeType,
|
||||||
|
}
|
||||||
|
resourcesList = append(resourcesList, resourceCopy)
|
||||||
|
}
|
||||||
|
s.resourcesLock.Unlock()
|
||||||
|
|
||||||
|
// Create proper ResourcesListResult according to MCP specification
|
||||||
|
result := ResourcesListResult{
|
||||||
|
PaginatedResult: PaginatedResult{
|
||||||
|
Result: Result{},
|
||||||
|
NextCursor: Cursor(nextCursor),
|
||||||
|
},
|
||||||
|
Resources: resourcesList,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add meta information if progress token was provided
|
||||||
|
if progressToken != nil {
|
||||||
|
result.Result.Meta = map[string]any{
|
||||||
|
"progressToken": progressToken,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sendResponse(client, req.ID, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processGetPrompt processes the prompts/get request
|
||||||
|
func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
||||||
|
type GetPromptParams struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments map[string]string `json:"arguments,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var params GetPromptParams
|
||||||
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||||
|
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if prompt exists
|
||||||
|
s.promptsLock.Lock()
|
||||||
|
prompt, exists := s.prompts[params.Name]
|
||||||
|
s.promptsLock.Unlock()
|
||||||
|
if !exists {
|
||||||
|
message := fmt.Sprintf("Prompt '%s' not found", params.Name)
|
||||||
|
s.sendErrorResponse(client, req.ID, message, errCodeInvalidParams)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logx.Infof("Processing prompt request: %s with %d arguments", prompt.Name, len(params.Arguments))
|
||||||
|
|
||||||
|
// Validate required arguments
|
||||||
|
missingArgs := validatePromptArguments(prompt, params.Arguments)
|
||||||
|
if len(missingArgs) > 0 {
|
||||||
|
message := fmt.Sprintf("Missing required arguments: %s", strings.Join(missingArgs, ", "))
|
||||||
|
s.sendErrorResponse(client, req.ID, message, errCodeInvalidParams)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply default values for missing optional arguments
|
||||||
|
args := applyDefaultArguments(prompt, params.Arguments)
|
||||||
|
|
||||||
|
// Generate messages using handler or static content
|
||||||
|
var messages []PromptMessage
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if prompt.Handler != nil {
|
||||||
|
// Use dynamic handler to generate messages
|
||||||
|
logx.Info("Using prompt handler to generate content")
|
||||||
|
messages, err = prompt.Handler(args)
|
||||||
|
if err != nil {
|
||||||
|
logx.Errorf("Error from prompt handler: %v", err)
|
||||||
|
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Error generating prompt content: %v", err), errCodeInternalError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// No handler, generate messages from static content
|
||||||
|
var messageText string
|
||||||
|
if prompt.Content != "" {
|
||||||
|
messageText = prompt.Content
|
||||||
|
|
||||||
|
// Apply argument substitutions to static content
|
||||||
|
for key, value := range args {
|
||||||
|
placeholder := fmt.Sprintf("{{%s}}", key)
|
||||||
|
messageText = strings.Replace(messageText, placeholder, value, -1)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// No content, use a default fallback
|
||||||
|
topic := "this topic"
|
||||||
|
if t, ok := args["topic"]; ok && t != "" {
|
||||||
|
topic = t
|
||||||
|
}
|
||||||
|
messageText = fmt.Sprintf("Tell me about %s", topic)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a single user message with the content
|
||||||
|
messages = []PromptMessage{
|
||||||
|
{
|
||||||
|
Role: roleUser,
|
||||||
|
Content: TextContent{
|
||||||
|
Type: contentTypeText,
|
||||||
|
Text: messageText,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct the response according to MCP spec
|
||||||
|
result := struct {
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
Messages []PromptMessage `json:"messages"`
|
||||||
|
}{
|
||||||
|
Description: prompt.Description,
|
||||||
|
Messages: messages,
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sendResponse(client, req.ID, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validatePromptArguments checks if all required arguments are provided
|
||||||
|
// Returns a list of missing required arguments
|
||||||
|
func validatePromptArguments(prompt Prompt, providedArgs map[string]string) []string {
|
||||||
|
var missingArgs []string
|
||||||
|
|
||||||
|
for _, arg := range prompt.Arguments {
|
||||||
|
if arg.Required {
|
||||||
|
if value, exists := providedArgs[arg.Name]; !exists || len(value) == 0 {
|
||||||
|
missingArgs = append(missingArgs, arg.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return missingArgs
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyDefaultArguments adds default values for missing optional arguments
|
||||||
|
func applyDefaultArguments(prompt Prompt, providedArgs map[string]string) map[string]string {
|
||||||
|
result := make(map[string]string)
|
||||||
|
|
||||||
|
// Copy all provided arguments
|
||||||
|
for k, v := range providedArgs {
|
||||||
|
result[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add defaults for missing arguments
|
||||||
|
for _, arg := range prompt.Arguments {
|
||||||
|
if _, exists := result[arg.Name]; !exists && arg.Default != "" {
|
||||||
|
result[arg.Name] = arg.Default
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// processToolCall processes the tools/call request
|
||||||
|
func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||||
|
var toolCallParams struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments map[string]any `json:"arguments,omitempty"`
|
||||||
|
Meta struct {
|
||||||
|
ProgressToken any `json:"progressToken"`
|
||||||
|
} `json:"_meta,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle different types of req.Params
|
||||||
|
// If it's a RawMessage (JSON), unmarshal it
|
||||||
|
if err := json.Unmarshal(req.Params, &toolCallParams); err != nil {
|
||||||
|
logx.Errorf("Failed to unmarshal tool call params: %v", err)
|
||||||
|
s.sendErrorResponse(client, req.ID, "Invalid tool call parameters", errCodeInvalidParams)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract progress token if available
|
||||||
|
progressToken := toolCallParams.Meta.ProgressToken
|
||||||
|
|
||||||
|
// Find the requested tool
|
||||||
|
s.toolsLock.Lock()
|
||||||
|
tool, exists := s.tools[toolCallParams.Name]
|
||||||
|
s.toolsLock.Unlock()
|
||||||
|
if !exists {
|
||||||
|
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Tool '%s' not found",
|
||||||
|
toolCallParams.Name), errCodeInvalidParams)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a context with the configured timeout
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), s.conf.Mcp.ToolTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Log parameters before execution
|
||||||
|
logx.Infof("Executing tool '%s' with arguments: %#v", toolCallParams.Name, toolCallParams.Arguments)
|
||||||
|
|
||||||
|
// Execute the tool handler with timeout handling
|
||||||
|
var result any
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Create a channel to receive the result
|
||||||
|
resultCh := make(chan struct {
|
||||||
|
result any
|
||||||
|
err error
|
||||||
|
}, 1)
|
||||||
|
|
||||||
|
// Execute the tool handler in a goroutine
|
||||||
|
go func() {
|
||||||
|
toolResult, toolErr := tool.Handler(toolCallParams.Arguments)
|
||||||
|
resultCh <- struct {
|
||||||
|
result any
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
result: toolResult,
|
||||||
|
err: toolErr,
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for either the result or a timeout
|
||||||
|
select {
|
||||||
|
case res := <-resultCh:
|
||||||
|
result = res.result
|
||||||
|
err = res.err
|
||||||
|
case <-ctx.Done():
|
||||||
|
// Handle timeout
|
||||||
|
logx.Errorf("Tool execution timed out after %v: %s", s.conf.Mcp.ToolTimeout, toolCallParams.Name)
|
||||||
|
s.sendErrorResponse(client, req.ID, "Tool execution timed out", errCodeTimeout)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the base result structure with metadata
|
||||||
|
callToolResult := CallToolResult{
|
||||||
|
Result: Result{},
|
||||||
|
Content: []any{},
|
||||||
|
IsError: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add meta information if progress token was provided
|
||||||
|
if progressToken != nil {
|
||||||
|
callToolResult.Result.Meta = map[string]any{
|
||||||
|
"progressToken": progressToken,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if there was an error during tool execution
|
||||||
|
if err != nil {
|
||||||
|
// According to the spec, for tool-level errors (as opposed to protocol-level errors),
|
||||||
|
// we should report them inside the result with isError=true
|
||||||
|
logx.Errorf("Tool execution reported error: %v", err)
|
||||||
|
|
||||||
|
callToolResult.Content = []any{
|
||||||
|
TextContent{
|
||||||
|
Type: contentTypeText,
|
||||||
|
Text: fmt.Sprintf("Error: %v", err),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
callToolResult.IsError = true
|
||||||
|
s.sendResponse(client, req.ID, callToolResult)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format the response according to the CallToolResult schema
|
||||||
|
switch v := result.(type) {
|
||||||
|
case string:
|
||||||
|
// Simple string becomes text content
|
||||||
|
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||||
|
Type: contentTypeText,
|
||||||
|
Text: v,
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []roleType{roleUser, roleAssistant},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
case map[string]any:
|
||||||
|
// JSON-like object becomes formatted JSON text
|
||||||
|
jsonStr, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
jsonStr = []byte(err.Error())
|
||||||
|
}
|
||||||
|
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||||
|
Type: contentTypeText,
|
||||||
|
Text: string(jsonStr),
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []roleType{roleUser, roleAssistant},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
case TextContent:
|
||||||
|
// Direct TextContent object
|
||||||
|
callToolResult.Content = append(callToolResult.Content, v)
|
||||||
|
case ImageContent:
|
||||||
|
// Direct ImageContent object
|
||||||
|
callToolResult.Content = append(callToolResult.Content, v)
|
||||||
|
case []any:
|
||||||
|
// Array of content items
|
||||||
|
callToolResult.Content = v
|
||||||
|
case ToolResult:
|
||||||
|
// Handle legacy ToolResult type
|
||||||
|
switch v.Type {
|
||||||
|
case contentTypeText:
|
||||||
|
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||||
|
Type: contentTypeText,
|
||||||
|
Text: fmt.Sprintf("%v", v.Content),
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []roleType{roleUser, roleAssistant},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
case contentTypeImage:
|
||||||
|
if imgData, ok := v.Content.(map[string]any); ok {
|
||||||
|
callToolResult.Content = append(callToolResult.Content, ImageContent{
|
||||||
|
Type: contentTypeImage,
|
||||||
|
Data: fmt.Sprintf("%v", imgData["data"]),
|
||||||
|
MimeType: fmt.Sprintf("%v", imgData["mimeType"]),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||||
|
Type: contentTypeText,
|
||||||
|
Text: fmt.Sprintf("%v", v.Content),
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []roleType{roleUser, roleAssistant},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// For any other type, convert to string
|
||||||
|
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||||
|
Type: contentTypeText,
|
||||||
|
Text: fmt.Sprintf("%v", v),
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []roleType{roleUser, roleAssistant},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
logx.Infof("Tool call result: %#v", callToolResult)
|
||||||
|
s.sendResponse(client, req.ID, callToolResult)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processResourcesRead processes the resources/read request
|
||||||
|
func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
|
||||||
|
var params ResourceReadParams
|
||||||
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||||
|
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find resource that matches the URI
|
||||||
|
s.resourcesLock.Lock()
|
||||||
|
resource, exists := s.resources[params.URI]
|
||||||
|
s.resourcesLock.Unlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||||
|
params.URI), errCodeResourceNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no handler is provided, return an empty content array
|
||||||
|
if resource.Handler == nil {
|
||||||
|
result := ResourceReadResult{
|
||||||
|
Contents: []ResourceContent{
|
||||||
|
{
|
||||||
|
URI: params.URI,
|
||||||
|
MimeType: resource.MimeType,
|
||||||
|
Text: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s.sendResponse(client, req.ID, result)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the resource handler
|
||||||
|
content, err := resource.Handler()
|
||||||
|
if err != nil {
|
||||||
|
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Error reading resource: %v", err),
|
||||||
|
errCodeInternalError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the URI is set if not already provided by the handler
|
||||||
|
if len(content.URI) == 0 {
|
||||||
|
content.URI = params.URI
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure MimeType is set if available from the resource definition
|
||||||
|
if len(content.MimeType) == 0 && resource.MimeType != "" {
|
||||||
|
content.MimeType = resource.MimeType
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create response with contents from the handler
|
||||||
|
// The MCP specification requires a contents array
|
||||||
|
result := ResourceReadResult{
|
||||||
|
Contents: []ResourceContent{content},
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sendResponse(client, req.ID, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processResourceSubscribe processes the resources/subscribe request
|
||||||
|
func (s *sseMcpServer) processResourceSubscribe(client *mcpClient, req Request) {
|
||||||
|
var params ResourceSubscribeParams
|
||||||
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||||
|
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the resource exists
|
||||||
|
s.resourcesLock.Lock()
|
||||||
|
_, exists := s.resources[params.URI]
|
||||||
|
s.resourcesLock.Unlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||||
|
params.URI), errCodeResourceNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send success response for the subscription
|
||||||
|
s.sendResponse(client, req.ID, struct{}{})
|
||||||
|
|
||||||
|
logx.Infof("Client %s subscribed to resource '%s'", client.id, params.URI)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processNotificationCancelled processes the notifications/cancelled notification
|
||||||
|
func (s *sseMcpServer) processNotificationCancelled(client *mcpClient, req Request) {
|
||||||
|
// Extract the requestId that was canceled
|
||||||
|
type CancelParams struct {
|
||||||
|
RequestId int64 `json:"requestId"`
|
||||||
|
Reason string `json:"reason"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var params CancelParams
|
||||||
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||||
|
logx.Errorf("Failed to parse cancellation params: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logx.Infof("Request %d was cancelled by client. Reason: %s", params.RequestId, params.Reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processNotificationInitialized processes the notifications/initialized notification
|
||||||
|
func (s *sseMcpServer) processNotificationInitialized(client *mcpClient) {
|
||||||
|
// Mark the client as properly initialized
|
||||||
|
client.initialized = true
|
||||||
|
logx.Infof("Client %s is now fully initialized and ready for normal operations", client.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processPing processes the ping request and responds immediately
|
||||||
|
func (s *sseMcpServer) processPing(client *mcpClient, req Request) {
|
||||||
|
// A ping request should simply respond with an empty result to confirm the server is alive
|
||||||
|
logx.Infof("Received ping request with ID: %d", req.ID)
|
||||||
|
|
||||||
|
// Send an empty response with client's original request ID
|
||||||
|
s.sendResponse(client, req.ID, struct{}{})
|
||||||
|
|
||||||
|
logx.Infof("Sent ping response for ID: %d", req.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendErrorResponse sends an error response via the SSE channel
|
||||||
|
func (s *sseMcpServer) sendErrorResponse(client *mcpClient, id int64, message string, code int) {
|
||||||
|
errorResponse := struct {
|
||||||
|
JsonRpc string `json:"jsonrpc"`
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Error errorMessage `json:"error"`
|
||||||
|
}{
|
||||||
|
JsonRpc: jsonRpcVersion,
|
||||||
|
ID: id,
|
||||||
|
Error: errorMessage{
|
||||||
|
Code: code,
|
||||||
|
Message: message,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// all fields are primitive types, impossible to fail
|
||||||
|
jsonData, _ := json.Marshal(errorResponse)
|
||||||
|
// Use CRLF line endings as requested
|
||||||
|
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||||
|
logx.Infof("Sending error for ID %d: %s", id, sseMessage)
|
||||||
|
|
||||||
|
client.channel <- sseMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendResponse sends a success response via the SSE channel
|
||||||
|
func (s *sseMcpServer) sendResponse(client *mcpClient, id int64, result any) {
|
||||||
|
response := Response{
|
||||||
|
JsonRpc: jsonRpcVersion,
|
||||||
|
ID: id,
|
||||||
|
Result: result,
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
s.sendErrorResponse(client, id, "Failed to marshal response", errCodeInternalError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use CRLF line endings as requested
|
||||||
|
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||||
|
logx.Infof("Sending response for ID %d: %s", id, sseMessage)
|
||||||
|
|
||||||
|
client.channel <- sseMessage
|
||||||
|
}
|
||||||
3418
mcp/server_test.go
Normal file
3418
mcp/server_test.go
Normal file
File diff suppressed because it is too large
Load Diff
294
mcp/types.go
Normal file
294
mcp/types.go
Normal file
@@ -0,0 +1,294 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/rest"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Cursor is an opaque token used for pagination
|
||||||
|
type Cursor string
|
||||||
|
|
||||||
|
// Request represents a generic MCP request following JSON-RPC 2.0 specification
|
||||||
|
type Request struct {
|
||||||
|
SessionId string `form:"session_id"` // Session identifier for client tracking
|
||||||
|
JsonRpc string `json:"jsonrpc"` // Must be "2.0" per JSON-RPC spec
|
||||||
|
ID int64 `json:"id"` // Request identifier for matching responses
|
||||||
|
Method string `json:"method"` // Method name to invoke
|
||||||
|
Params json.RawMessage `json:"params"` // Parameters for the method
|
||||||
|
}
|
||||||
|
|
||||||
|
type PaginatedParams struct {
|
||||||
|
Cursor string `json:"cursor"`
|
||||||
|
Meta struct {
|
||||||
|
ProgressToken any `json:"progressToken"`
|
||||||
|
} `json:"_meta"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Result is the base interface for all results
|
||||||
|
type Result struct {
|
||||||
|
Meta map[string]any `json:"_meta,omitempty"` // Optional metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
// PaginatedResult is a base for results that support pagination
|
||||||
|
type PaginatedResult struct {
|
||||||
|
Result
|
||||||
|
NextCursor Cursor `json:"nextCursor,omitempty"` // Opaque token for fetching next page
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListToolsResult represents the response to a tools/list request
|
||||||
|
type ListToolsResult struct {
|
||||||
|
PaginatedResult
|
||||||
|
Tools []Tool `json:"tools"` // List of available tools
|
||||||
|
}
|
||||||
|
|
||||||
|
// Message Content Types
|
||||||
|
|
||||||
|
// roleType represents the sender or recipient of messages in a conversation
|
||||||
|
type roleType string
|
||||||
|
|
||||||
|
// PromptArgument defines a single argument that can be passed to a prompt
|
||||||
|
type PromptArgument struct {
|
||||||
|
Name string `json:"name"` // Argument name
|
||||||
|
Description string `json:"description,omitempty"` // Human-readable description
|
||||||
|
Required bool `json:"required,omitempty"` // Whether this argument is required
|
||||||
|
Default string `json:"default,omitempty"` // Default value if not provided
|
||||||
|
}
|
||||||
|
|
||||||
|
// PromptHandler is a function that dynamically generates prompt content
|
||||||
|
type PromptHandler func(args map[string]string) ([]PromptMessage, error)
|
||||||
|
|
||||||
|
// Prompt represents an MCP Prompt definition
|
||||||
|
type Prompt struct {
|
||||||
|
Name string `json:"name"` // Unique identifier for the prompt
|
||||||
|
Description string `json:"description,omitempty"` // Human-readable description
|
||||||
|
Arguments []PromptArgument `json:"arguments,omitempty"` // Arguments for customization
|
||||||
|
Content string `json:"-"` // Static content (internal use only)
|
||||||
|
Handler PromptHandler `json:"-"` // Handler for dynamic content generation
|
||||||
|
}
|
||||||
|
|
||||||
|
// PromptMessage represents a message in a conversation
|
||||||
|
type PromptMessage struct {
|
||||||
|
Role roleType `json:"role"` // Message sender role
|
||||||
|
Content any `json:"content"` // Message content (TextContent, ImageContent, etc.)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TextContent represents text content in a message
|
||||||
|
type TextContent struct {
|
||||||
|
Type string `json:"type"` // Always "text"
|
||||||
|
Text string `json:"text"` // The text content
|
||||||
|
Annotations *Annotations `json:"annotations,omitempty"` // Optional annotations
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImageContent represents image data in a message
|
||||||
|
type ImageContent struct {
|
||||||
|
Type string `json:"type"` // Always "image"
|
||||||
|
Data string `json:"data"` // Base64-encoded image data
|
||||||
|
MimeType string `json:"mimeType"` // MIME type (e.g., "image/png")
|
||||||
|
}
|
||||||
|
|
||||||
|
// AudioContent represents audio data in a message
|
||||||
|
type AudioContent struct {
|
||||||
|
Type string `json:"type"` // Always "audio"
|
||||||
|
Data string `json:"data"` // Base64-encoded audio data
|
||||||
|
MimeType string `json:"mimeType"` // MIME type (e.g., "audio/mp3")
|
||||||
|
}
|
||||||
|
|
||||||
|
// FileContent represents file content
|
||||||
|
type FileContent struct {
|
||||||
|
URI string `json:"uri"` // URI identifying the file
|
||||||
|
MimeType string `json:"mimeType"` // MIME type of the file
|
||||||
|
Text string `json:"text"` // File content as text
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmbeddedResource represents a resource embedded in a message
|
||||||
|
type EmbeddedResource struct {
|
||||||
|
Type string `json:"type"` // Always "resource"
|
||||||
|
Resource struct {
|
||||||
|
URI string `json:"uri"` // Resource URI
|
||||||
|
MimeType string `json:"mimeType"` // MIME type of the resource
|
||||||
|
Text string `json:"text,omitempty"` // Text content (if available)
|
||||||
|
Blob string `json:"blob,omitempty"` // Base64 encoded blob data (if available)
|
||||||
|
} `json:"resource"` // The resource data
|
||||||
|
}
|
||||||
|
|
||||||
|
// Annotations provides additional metadata for content
|
||||||
|
type Annotations struct {
|
||||||
|
Audience []roleType `json:"audience,omitempty"` // Who should see this content
|
||||||
|
Priority *float64 `json:"priority,omitempty"` // Optional priority (0-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tool-related Types
|
||||||
|
|
||||||
|
// Tool Definition Types
|
||||||
|
|
||||||
|
// ToolHandler is a function that handles tool calls
|
||||||
|
type ToolHandler func(params map[string]any) (any, error)
|
||||||
|
|
||||||
|
// Tool represents a Model Context Protocol Tool definition
|
||||||
|
type Tool struct {
|
||||||
|
Name string `json:"name"` // Unique identifier for the tool
|
||||||
|
Description string `json:"description"` // Human-readable description
|
||||||
|
InputSchema InputSchema `json:"inputSchema"` // JSON Schema for parameters
|
||||||
|
Handler ToolHandler `json:"-"` // Not sent to clients
|
||||||
|
}
|
||||||
|
|
||||||
|
// InputSchema represents tool's input schema in JSON Schema format
|
||||||
|
type InputSchema struct {
|
||||||
|
Type string `json:"type"` // Always "object" for tool inputs
|
||||||
|
Properties map[string]any `json:"properties"` // Property definitions
|
||||||
|
Required []string `json:"required,omitempty"` // List of required properties
|
||||||
|
}
|
||||||
|
|
||||||
|
// CallToolResult represents a tool call result that conforms to the MCP schema
|
||||||
|
type CallToolResult struct {
|
||||||
|
Result
|
||||||
|
Content []interface{} `json:"content"` // Content items (text, images, etc.)
|
||||||
|
IsError bool `json:"isError,omitempty"` // True if tool execution failed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resource represents a Model Context Protocol Resource definition
|
||||||
|
type Resource struct {
|
||||||
|
URI string `json:"uri"` // Unique resource identifier (RFC3986)
|
||||||
|
Name string `json:"name"` // Human-readable name
|
||||||
|
Description string `json:"description,omitempty"` // Optional description
|
||||||
|
MimeType string `json:"mimeType,omitempty"` // Optional MIME type
|
||||||
|
Handler ResourceHandler `json:"-"` // Internal handler not sent to clients
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResourceHandler is a function that handles resource read requests
|
||||||
|
type ResourceHandler func() (ResourceContent, error)
|
||||||
|
|
||||||
|
// ResourceContent represents the content of a resource
|
||||||
|
type ResourceContent struct {
|
||||||
|
URI string `json:"uri"` // Resource URI (required)
|
||||||
|
MimeType string `json:"mimeType,omitempty"` // MIME type of the resource
|
||||||
|
Text string `json:"text,omitempty"` // Text content (if available)
|
||||||
|
Blob string `json:"blob,omitempty"` // Base64 encoded blob data (if available)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResourcesListResult represents the response to a resources/list request
|
||||||
|
type ResourcesListResult struct {
|
||||||
|
PaginatedResult
|
||||||
|
Resources []Resource `json:"resources"` // List of available resources
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResourceReadParams contains parameters for a resources/read request
|
||||||
|
type ResourceReadParams struct {
|
||||||
|
URI string `json:"uri"` // URI of the resource to read
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResourceReadResult contains the result of a resources/read request
|
||||||
|
type ResourceReadResult struct {
|
||||||
|
Result
|
||||||
|
Contents []ResourceContent `json:"contents"` // Array of resource content
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResourceSubscribeParams contains parameters for a resources/subscribe request
|
||||||
|
type ResourceSubscribeParams struct {
|
||||||
|
URI string `json:"uri"` // URI of the resource to subscribe to
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResourceUpdateNotification represents a notification about a resource update
|
||||||
|
type ResourceUpdateNotification struct {
|
||||||
|
URI string `json:"uri"` // URI of the updated resource
|
||||||
|
Content ResourceContent `json:"content"` // New resource content
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client and Server Types
|
||||||
|
|
||||||
|
// mcpClient represents an SSE client connection
|
||||||
|
type mcpClient struct {
|
||||||
|
id string // Unique client identifier
|
||||||
|
channel chan string // Channel for sending SSE messages
|
||||||
|
initialized bool // Tracks if client has sent notifications/initialized
|
||||||
|
}
|
||||||
|
|
||||||
|
// McpServer defines the interface for Model Context Protocol servers
|
||||||
|
type McpServer interface {
|
||||||
|
Start()
|
||||||
|
Stop()
|
||||||
|
RegisterTool(tool Tool) error
|
||||||
|
RegisterPrompt(prompt Prompt)
|
||||||
|
RegisterResource(resource Resource)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sseMcpServer implements the McpServer interface using SSE
|
||||||
|
type sseMcpServer struct {
|
||||||
|
conf McpConf
|
||||||
|
server *rest.Server
|
||||||
|
clients map[string]*mcpClient
|
||||||
|
clientsLock sync.Mutex
|
||||||
|
tools map[string]Tool
|
||||||
|
toolsLock sync.Mutex
|
||||||
|
prompts map[string]Prompt
|
||||||
|
promptsLock sync.Mutex
|
||||||
|
resources map[string]Resource
|
||||||
|
resourcesLock sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// Response Types
|
||||||
|
|
||||||
|
// errorObj represents a JSON-RPC error object
|
||||||
|
type errorObj struct {
|
||||||
|
Code int `json:"code"` // Error code
|
||||||
|
Message string `json:"message"` // Error message
|
||||||
|
}
|
||||||
|
|
||||||
|
// Response represents a JSON-RPC response
|
||||||
|
type Response struct {
|
||||||
|
JsonRpc string `json:"jsonrpc"` // Always "2.0"
|
||||||
|
ID int64 `json:"id"` // Same as request ID
|
||||||
|
Result any `json:"result"` // Result object (null if error)
|
||||||
|
Error *errorObj `json:"error,omitempty"` // Error object (null if success)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server Information Types
|
||||||
|
|
||||||
|
// serverInfo provides information about the server
|
||||||
|
type serverInfo struct {
|
||||||
|
Name string `json:"name"` // Server name
|
||||||
|
Version string `json:"version"` // Server version
|
||||||
|
}
|
||||||
|
|
||||||
|
// capabilities describes the server's capabilities
|
||||||
|
type capabilities struct {
|
||||||
|
Logging struct{} `json:"logging"`
|
||||||
|
Prompts struct {
|
||||||
|
ListChanged bool `json:"listChanged"` // Server will notify on prompt changes
|
||||||
|
} `json:"prompts"`
|
||||||
|
Resources struct {
|
||||||
|
Subscribe bool `json:"subscribe"` // Server supports resource subscriptions
|
||||||
|
ListChanged bool `json:"listChanged"` // Server will notify on resource changes
|
||||||
|
} `json:"resources"`
|
||||||
|
Tools struct {
|
||||||
|
ListChanged bool `json:"listChanged"` // Server will notify on tool changes
|
||||||
|
} `json:"tools"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// initializationResponse is sent in response to an initialize request
|
||||||
|
type initializationResponse struct {
|
||||||
|
ProtocolVersion string `json:"protocolVersion"` // Protocol version
|
||||||
|
Capabilities capabilities `json:"capabilities"` // Server capabilities
|
||||||
|
ServerInfo serverInfo `json:"serverInfo"` // Server information
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToolCallParams contains the parameters for a tool call
|
||||||
|
type ToolCallParams struct {
|
||||||
|
Name string `json:"name"` // Tool name
|
||||||
|
Parameters map[string]any `json:"parameters"` // Tool parameters
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToolResult contains the result of a tool execution
|
||||||
|
type ToolResult struct {
|
||||||
|
Type string `json:"type"` // Content type (text, image, etc.)
|
||||||
|
Content any `json:"content"` // Result content
|
||||||
|
}
|
||||||
|
|
||||||
|
// errorMessage represents a detailed error message
|
||||||
|
type errorMessage struct {
|
||||||
|
Code int `json:"code"` // Error code
|
||||||
|
Message string `json:"message"` // Error message
|
||||||
|
Data any `json:",omitempty"` // Additional error data
|
||||||
|
}
|
||||||
212
mcp/types_test.go
Normal file
212
mcp/types_test.go
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestResponseMarshaling(t *testing.T) {
|
||||||
|
// Test that the Response struct marshals correctly
|
||||||
|
resp := Response{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 123,
|
||||||
|
Result: map[string]string{
|
||||||
|
"key": "value",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(resp)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"jsonrpc":"2.0"`)
|
||||||
|
assert.Contains(t, string(data), `"id":123`)
|
||||||
|
assert.Contains(t, string(data), `"result":{"key":"value"}`)
|
||||||
|
|
||||||
|
// Test response with error
|
||||||
|
respWithError := Response{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 456,
|
||||||
|
Error: &errorObj{
|
||||||
|
Code: errCodeInvalidRequest,
|
||||||
|
Message: "Invalid Request",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err = json.Marshal(respWithError)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"jsonrpc":"2.0"`)
|
||||||
|
assert.Contains(t, string(data), `"id":456`)
|
||||||
|
assert.Contains(t, string(data), `"error":{"code":-32600,"message":"Invalid Request"}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestUnmarshaling(t *testing.T) {
|
||||||
|
// Test that the Request struct unmarshals correctly
|
||||||
|
jsonStr := `{
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": 789,
|
||||||
|
"method": "test_method",
|
||||||
|
"params": {"key": "value"}
|
||||||
|
}`
|
||||||
|
|
||||||
|
var req Request
|
||||||
|
err := json.Unmarshal([]byte(jsonStr), &req)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "2.0", req.JsonRpc)
|
||||||
|
assert.Equal(t, int64(789), req.ID)
|
||||||
|
assert.Equal(t, "test_method", req.Method)
|
||||||
|
|
||||||
|
// Check params unmarshaled correctly
|
||||||
|
var params map[string]string
|
||||||
|
err = json.Unmarshal(req.Params, ¶ms)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "value", params["key"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToolStructs(t *testing.T) {
|
||||||
|
// Test Tool struct
|
||||||
|
tool := Tool{
|
||||||
|
Name: "test.tool",
|
||||||
|
Description: "A test tool",
|
||||||
|
InputSchema: InputSchema{
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]any{
|
||||||
|
"input": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Input parameter",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Required: []string{"input"},
|
||||||
|
},
|
||||||
|
Handler: func(params map[string]any) (any, error) {
|
||||||
|
return "result", nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify fields are correct
|
||||||
|
assert.Equal(t, "test.tool", tool.Name)
|
||||||
|
assert.Equal(t, "A test tool", tool.Description)
|
||||||
|
assert.Equal(t, "object", tool.InputSchema.Type)
|
||||||
|
assert.Contains(t, tool.InputSchema.Properties, "input")
|
||||||
|
propMap, ok := tool.InputSchema.Properties["input"].(map[string]any)
|
||||||
|
assert.True(t, ok, "Property should be a map")
|
||||||
|
assert.Equal(t, "string", propMap["type"])
|
||||||
|
assert.NotNil(t, tool.Handler)
|
||||||
|
|
||||||
|
// Verify JSON marshalling (which should exclude Handler function)
|
||||||
|
data, err := json.Marshal(tool)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"name":"test.tool"`)
|
||||||
|
assert.Contains(t, string(data), `"description":"A test tool"`)
|
||||||
|
assert.Contains(t, string(data), `"inputSchema":`)
|
||||||
|
assert.NotContains(t, string(data), `"Handler":`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPromptStructs(t *testing.T) {
|
||||||
|
// Test Prompt struct
|
||||||
|
prompt := Prompt{
|
||||||
|
Name: "test.prompt",
|
||||||
|
Description: "A test prompt description",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify fields are correct
|
||||||
|
assert.Equal(t, "test.prompt", prompt.Name)
|
||||||
|
assert.Equal(t, "A test prompt description", prompt.Description)
|
||||||
|
|
||||||
|
// Verify JSON marshalling
|
||||||
|
data, err := json.Marshal(prompt)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"name":"test.prompt"`)
|
||||||
|
assert.Contains(t, string(data), `"description":"A test prompt description"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResourceStructs(t *testing.T) {
|
||||||
|
// Test Resource struct
|
||||||
|
resource := Resource{
|
||||||
|
Name: "test.resource",
|
||||||
|
URI: "http://example.com/resource",
|
||||||
|
Description: "A test resource",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify fields are correct
|
||||||
|
assert.Equal(t, "test.resource", resource.Name)
|
||||||
|
assert.Equal(t, "http://example.com/resource", resource.URI)
|
||||||
|
assert.Equal(t, "A test resource", resource.Description)
|
||||||
|
|
||||||
|
// Verify JSON marshalling
|
||||||
|
data, err := json.Marshal(resource)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"name":"test.resource"`)
|
||||||
|
assert.Contains(t, string(data), `"uri":"http://example.com/resource"`)
|
||||||
|
assert.Contains(t, string(data), `"description":"A test resource"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContentTypes(t *testing.T) {
|
||||||
|
// Test TextContent
|
||||||
|
textContent := TextContent{
|
||||||
|
Type: "text",
|
||||||
|
Text: "Sample text",
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []roleType{roleUser, roleAssistant},
|
||||||
|
Priority: ptr(1.0),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(textContent)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"type":"text"`)
|
||||||
|
assert.Contains(t, string(data), `"text":"Sample text"`)
|
||||||
|
assert.Contains(t, string(data), `"audience":["user","assistant"]`)
|
||||||
|
assert.Contains(t, string(data), `"priority":1`)
|
||||||
|
|
||||||
|
// Test ImageContent
|
||||||
|
imageContent := ImageContent{
|
||||||
|
Type: "image",
|
||||||
|
Data: "base64data",
|
||||||
|
MimeType: "image/png",
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err = json.Marshal(imageContent)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"type":"image"`)
|
||||||
|
assert.Contains(t, string(data), `"data":"base64data"`)
|
||||||
|
assert.Contains(t, string(data), `"mimeType":"image/png"`)
|
||||||
|
|
||||||
|
// Test AudioContent
|
||||||
|
audioContent := AudioContent{
|
||||||
|
Type: "audio",
|
||||||
|
Data: "base64audio",
|
||||||
|
MimeType: "audio/mp3",
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err = json.Marshal(audioContent)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"type":"audio"`)
|
||||||
|
assert.Contains(t, string(data), `"data":"base64audio"`)
|
||||||
|
assert.Contains(t, string(data), `"mimeType":"audio/mp3"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallToolResult(t *testing.T) {
|
||||||
|
// Test CallToolResult
|
||||||
|
result := CallToolResult{
|
||||||
|
Result: Result{
|
||||||
|
Meta: map[string]any{
|
||||||
|
"progressToken": "token123",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Content: []interface{}{
|
||||||
|
TextContent{
|
||||||
|
Type: "text",
|
||||||
|
Text: "Sample result",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
IsError: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(result)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"_meta":{"progressToken":"token123"}`)
|
||||||
|
assert.Contains(t, string(data), `"content":[{"type":"text","text":"Sample result"}]`)
|
||||||
|
assert.NotContains(t, string(data), `"isError":`)
|
||||||
|
}
|
||||||
15
mcp/util.go
Normal file
15
mcp/util.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ptr is a helper function to get a pointer to a value
|
||||||
|
func ptr[T any](v T) *T {
|
||||||
|
return &v
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatSSEMessage formats a Server-Sent Event message with proper CRLF line endings
|
||||||
|
func formatSSEMessage(event string, data []byte) string {
|
||||||
|
return fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(data))
|
||||||
|
}
|
||||||
63
mcp/util_test.go
Normal file
63
mcp/util_test.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPtr(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
v interface{}
|
||||||
|
}{
|
||||||
|
{"string", "test"},
|
||||||
|
{"int", 42},
|
||||||
|
{"bool", true},
|
||||||
|
{"float", 3.14},
|
||||||
|
{"struct", struct{ Name string }{"test"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := ptr(tt.v)
|
||||||
|
assert.NotNil(t, got, "ptr() should not return nil")
|
||||||
|
assert.Equal(t, tt.v, *got, "dereferenced pointer should equal input value")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Event struct {
|
||||||
|
Type string
|
||||||
|
Data map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseEvent(input string) (*Event, error) {
|
||||||
|
var evt Event
|
||||||
|
var dataStr string
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(strings.NewReader(input))
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
if strings.HasPrefix(line, "event:") {
|
||||||
|
evt.Type = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
|
||||||
|
} else if strings.HasPrefix(line, "data:") {
|
||||||
|
dataStr = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(dataStr) > 0 {
|
||||||
|
if err := json.Unmarshal([]byte(dataStr), &evt.Data); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse data: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &evt, nil
|
||||||
|
}
|
||||||
137
mcp/vars.go
Normal file
137
mcp/vars.go
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/core/syncx"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Protocol constants
|
||||||
|
const (
|
||||||
|
// JSON-RPC version as defined in the specification
|
||||||
|
jsonRpcVersion = "2.0"
|
||||||
|
|
||||||
|
// Session identifier key used in request URLs
|
||||||
|
sessionIdKey = "session_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Server-Sent Events (SSE) event types
|
||||||
|
const (
|
||||||
|
// Standard message event for JSON-RPC responses
|
||||||
|
eventMessage = "message"
|
||||||
|
|
||||||
|
// Endpoint event for sending endpoint URL to clients
|
||||||
|
eventEndpoint = "endpoint"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Content type identifiers
|
||||||
|
const (
|
||||||
|
// Text content type
|
||||||
|
contentTypeText = "text"
|
||||||
|
|
||||||
|
// Image content type
|
||||||
|
contentTypeImage = "image"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Collection keys for broadcast events
|
||||||
|
const (
|
||||||
|
// Key for prompts collection
|
||||||
|
keyPrompts = "prompts"
|
||||||
|
|
||||||
|
// Key for resources collection
|
||||||
|
keyResources = "resources"
|
||||||
|
|
||||||
|
// Key for tools collection
|
||||||
|
keyTools = "tools"
|
||||||
|
)
|
||||||
|
|
||||||
|
// JSON-RPC error codes
|
||||||
|
// Standard error codes from JSON-RPC 2.0 spec
|
||||||
|
const (
|
||||||
|
// Invalid JSON was received by the server
|
||||||
|
errCodeInvalidRequest = -32600
|
||||||
|
|
||||||
|
// The method does not exist / is not available
|
||||||
|
errCodeMethodNotFound = -32601
|
||||||
|
|
||||||
|
// Invalid method parameter(s)
|
||||||
|
errCodeInvalidParams = -32602
|
||||||
|
|
||||||
|
// Internal JSON-RPC error
|
||||||
|
errCodeInternalError = -32603
|
||||||
|
|
||||||
|
// Tool execution timed out
|
||||||
|
errCodeTimeout = -32001
|
||||||
|
|
||||||
|
// Resource not found error
|
||||||
|
errCodeResourceNotFound = -32002
|
||||||
|
|
||||||
|
// Client hasn't completed initialization
|
||||||
|
errCodeClientNotInitialized = -32800
|
||||||
|
)
|
||||||
|
|
||||||
|
// User and assistant role definitions
|
||||||
|
const (
|
||||||
|
// The "user" role - the entity asking questions
|
||||||
|
roleUser roleType = "user"
|
||||||
|
|
||||||
|
// The "assistant" role - the entity providing responses
|
||||||
|
roleAssistant roleType = "assistant"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Method names as defined in the MCP specification
|
||||||
|
const (
|
||||||
|
// Initialize the connection between client and server
|
||||||
|
methodInitialize = "initialize"
|
||||||
|
|
||||||
|
// List available tools
|
||||||
|
methodToolsList = "tools/list"
|
||||||
|
|
||||||
|
// Call a specific tool
|
||||||
|
methodToolsCall = "tools/call"
|
||||||
|
|
||||||
|
// List available prompts
|
||||||
|
methodPromptsList = "prompts/list"
|
||||||
|
|
||||||
|
// Get a specific prompt
|
||||||
|
methodPromptsGet = "prompts/get"
|
||||||
|
|
||||||
|
// List available resources
|
||||||
|
methodResourcesList = "resources/list"
|
||||||
|
|
||||||
|
// Read a specific resource
|
||||||
|
methodResourcesRead = "resources/read"
|
||||||
|
|
||||||
|
// Subscribe to resource updates
|
||||||
|
methodResourcesSubscribe = "resources/subscribe"
|
||||||
|
|
||||||
|
// Simple ping to check server availability
|
||||||
|
methodPing = "ping"
|
||||||
|
|
||||||
|
// Notification that client is fully initialized
|
||||||
|
methodNotificationsInitialized = "notifications/initialized"
|
||||||
|
|
||||||
|
// Notification that a request was canceled
|
||||||
|
methodNotificationsCancelled = "notifications/cancelled"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Event names for Server-Sent Events (SSE)
|
||||||
|
const (
|
||||||
|
// Notification of tool list changes
|
||||||
|
eventToolsListChanged = "tools/list_changed"
|
||||||
|
|
||||||
|
// Notification of prompt list changes
|
||||||
|
eventPromptsListChanged = "prompts/list_changed"
|
||||||
|
|
||||||
|
// Notification of resource list changes
|
||||||
|
eventResourcesListChanged = "resources/list_changed"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// Default channel size for events
|
||||||
|
eventChanSize = 10
|
||||||
|
|
||||||
|
// Default ping interval for checking connection availability
|
||||||
|
// use syncx.ForAtomicDuration to ensure atomicity in test race
|
||||||
|
pingInterval = syncx.ForAtomicDuration(30 * time.Second)
|
||||||
|
)
|
||||||
214
mcp/vars_test.go
Normal file
214
mcp/vars_test.go
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
// filepath: /Users/kevin/Develop/go/opensource/go-zero/mcp/vars_test.go
|
||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestErrorCodes ensures error codes are applied correctly in error responses
|
||||||
|
func TestErrorCodes(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
code int
|
||||||
|
message string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "invalid request error",
|
||||||
|
code: errCodeInvalidRequest,
|
||||||
|
message: "Invalid request",
|
||||||
|
expected: `"code":-32600`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "method not found error",
|
||||||
|
code: errCodeMethodNotFound,
|
||||||
|
message: "Method not found",
|
||||||
|
expected: `"code":-32601`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid params error",
|
||||||
|
code: errCodeInvalidParams,
|
||||||
|
message: "Invalid parameters",
|
||||||
|
expected: `"code":-32602`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "internal error",
|
||||||
|
code: errCodeInternalError,
|
||||||
|
message: "Internal server error",
|
||||||
|
expected: `"code":-32603`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "timeout error",
|
||||||
|
code: errCodeTimeout,
|
||||||
|
message: "Operation timed out",
|
||||||
|
expected: `"code":-32001`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "resource not found error",
|
||||||
|
code: errCodeResourceNotFound,
|
||||||
|
message: "Resource not found",
|
||||||
|
expected: `"code":-32002`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "client not initialized error",
|
||||||
|
code: errCodeClientNotInitialized,
|
||||||
|
message: "Client not initialized",
|
||||||
|
expected: `"code":-32800`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
resp := Response{
|
||||||
|
JsonRpc: jsonRpcVersion,
|
||||||
|
ID: int64(1),
|
||||||
|
Error: &errorObj{
|
||||||
|
Code: tc.code,
|
||||||
|
Message: tc.message,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(resp)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), tc.expected, "Error code should match expected value")
|
||||||
|
assert.Contains(t, string(data), tc.message, "Error message should be included")
|
||||||
|
assert.Contains(t, string(data), jsonRpcVersion, "JSON-RPC version should be included")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestJsonRpcVersion ensures the correct JSON-RPC version is used
|
||||||
|
func TestJsonRpcVersion(t *testing.T) {
|
||||||
|
assert.Equal(t, "2.0", jsonRpcVersion, "JSON-RPC version should be 2.0")
|
||||||
|
|
||||||
|
// Test that it's used in responses
|
||||||
|
resp := Response{
|
||||||
|
JsonRpc: jsonRpcVersion,
|
||||||
|
ID: int64(1),
|
||||||
|
Result: "test",
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(resp)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"jsonrpc":"2.0"`, "Response should use correct JSON-RPC version")
|
||||||
|
|
||||||
|
// Test that it's expected in requests
|
||||||
|
reqStr := `{"jsonrpc":"2.0","id":1,"method":"test"}`
|
||||||
|
var req Request
|
||||||
|
err = json.Unmarshal([]byte(reqStr), &req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, jsonRpcVersion, req.JsonRpc, "Request should parse correct JSON-RPC version")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSessionIdKey ensures session ID extraction works correctly
|
||||||
|
func TestSessionIdKey(t *testing.T) {
|
||||||
|
// Create a mock server implementation
|
||||||
|
mock := newMockMcpServer(t)
|
||||||
|
defer mock.shutdown()
|
||||||
|
|
||||||
|
// Verify the key constant
|
||||||
|
assert.Equal(t, "session_id", sessionIdKey, "Session ID key should be 'session_id'")
|
||||||
|
|
||||||
|
// Test that session ID is extracted correctly
|
||||||
|
mockR := httptest.NewRequest("GET", "/?"+sessionIdKey+"=test-session", nil)
|
||||||
|
|
||||||
|
// Since the mock server is using the same session key logic,
|
||||||
|
// we can test this by accessing the request query parameters directly
|
||||||
|
sessionID := mockR.URL.Query().Get(sessionIdKey)
|
||||||
|
assert.Equal(t, "test-session", sessionID, "Session ID should be extracted correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEventTypes ensures event types are set correctly in SSE responses
|
||||||
|
func TestEventTypes(t *testing.T) {
|
||||||
|
// Test message event
|
||||||
|
assert.Equal(t, "message", eventMessage, "Message event should be 'message'")
|
||||||
|
|
||||||
|
// Test endpoint event
|
||||||
|
assert.Equal(t, "endpoint", eventEndpoint, "Endpoint event should be 'endpoint'")
|
||||||
|
|
||||||
|
// Verify them in an actual SSE format string
|
||||||
|
messageEvent := "event: " + eventMessage + "\ndata: test\n\n"
|
||||||
|
assert.Contains(t, messageEvent, "event: message", "Message event should format correctly")
|
||||||
|
|
||||||
|
endpointEvent := "event: " + eventEndpoint + "\ndata: test\n\n"
|
||||||
|
assert.Contains(t, endpointEvent, "event: endpoint", "Endpoint event should format correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCollectionKeys checks that collection keys are used correctly
|
||||||
|
func TestCollectionKeys(t *testing.T) {
|
||||||
|
// Verify collection key constants
|
||||||
|
assert.Equal(t, "prompts", keyPrompts, "Prompts key should be 'prompts'")
|
||||||
|
assert.Equal(t, "resources", keyResources, "Resources key should be 'resources'")
|
||||||
|
assert.Equal(t, "tools", keyTools, "Tools key should be 'tools'")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRoleTypes checks that role types are used correctly
|
||||||
|
func TestRoleTypes(t *testing.T) {
|
||||||
|
// Verify role type constants
|
||||||
|
assert.Equal(t, "user", string(roleUser), "User role should be 'user'")
|
||||||
|
assert.Equal(t, "assistant", string(roleAssistant), "Assistant role should be 'assistant'")
|
||||||
|
|
||||||
|
// Test in annotations
|
||||||
|
annotations := Annotations{
|
||||||
|
Audience: []roleType{roleUser, roleAssistant},
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(annotations)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"audience":["user","assistant"]`, "Role types should marshal correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMethodNames checks that method names are used correctly
|
||||||
|
func TestMethodNames(t *testing.T) {
|
||||||
|
// Verify method name constants
|
||||||
|
methods := map[string]string{
|
||||||
|
"initialize": methodInitialize,
|
||||||
|
"tools/list": methodToolsList,
|
||||||
|
"tools/call": methodToolsCall,
|
||||||
|
"prompts/list": methodPromptsList,
|
||||||
|
"prompts/get": methodPromptsGet,
|
||||||
|
"resources/list": methodResourcesList,
|
||||||
|
"resources/read": methodResourcesRead,
|
||||||
|
"resources/subscribe": methodResourcesSubscribe,
|
||||||
|
"ping": methodPing,
|
||||||
|
"notifications/initialized": methodNotificationsInitialized,
|
||||||
|
"notifications/cancelled": methodNotificationsCancelled,
|
||||||
|
}
|
||||||
|
|
||||||
|
for expected, actual := range methods {
|
||||||
|
assert.Equal(t, expected, actual, "Method name should be "+expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test in a request
|
||||||
|
for methodName := range methods {
|
||||||
|
req := Request{
|
||||||
|
JsonRpc: jsonRpcVersion,
|
||||||
|
ID: int64(1),
|
||||||
|
Method: methodName,
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"method":"`+methodName+`"`, "Method name should be used in requests")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEventNames checks that event names are used correctly
|
||||||
|
func TestEventNames(t *testing.T) {
|
||||||
|
// Verify event name constants
|
||||||
|
events := map[string]string{
|
||||||
|
"tools/list_changed": eventToolsListChanged,
|
||||||
|
"prompts/list_changed": eventPromptsListChanged,
|
||||||
|
"resources/list_changed": eventResourcesListChanged,
|
||||||
|
}
|
||||||
|
|
||||||
|
for expected, actual := range events {
|
||||||
|
assert.Equal(t, expected, actual, "Event name should be "+expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test event names in SSE format
|
||||||
|
for _, eventName := range events {
|
||||||
|
sseEvent := "event: " + eventName + "\ndata: test\n\n"
|
||||||
|
assert.Contains(t, sseEvent, "event: "+eventName, "Event name should format correctly in SSE")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user