mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-11 00:40:00 +08:00
@@ -34,7 +34,10 @@ type McpConf struct {
|
|||||||
// Cors contains allowed CORS origins
|
// Cors contains allowed CORS origins
|
||||||
Cors []string `json:",optional"`
|
Cors []string `json:",optional"`
|
||||||
|
|
||||||
// ToolTimeout is the maximum time allowed for tool execution
|
// SseTimeout is the maximum time allowed for SSE connections
|
||||||
ToolTimeout time.Duration `json:",default=30s"`
|
SseTimeout time.Duration `json:",default=24h"`
|
||||||
|
|
||||||
|
// MessageTimeout is the maximum time allowed for request execution
|
||||||
|
MessageTimeout time.Duration `json:",default=30s"`
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ mcp:
|
|||||||
assert.Equal(t, "2024-11-05", c.Mcp.ProtocolVersion, "Default protocol version should be 2024-11-05")
|
assert.Equal(t, "2024-11-05", c.Mcp.ProtocolVersion, "Default protocol version should be 2024-11-05")
|
||||||
assert.Equal(t, "/sse", c.Mcp.SseEndpoint, "Default SSE endpoint should be /sse")
|
assert.Equal(t, "/sse", c.Mcp.SseEndpoint, "Default SSE endpoint should be /sse")
|
||||||
assert.Equal(t, "/message", c.Mcp.MessageEndpoint, "Default message endpoint should be /message")
|
assert.Equal(t, "/message", c.Mcp.MessageEndpoint, "Default message endpoint should be /message")
|
||||||
assert.Equal(t, 30*time.Second, c.Mcp.ToolTimeout, "Default tool timeout should be 30s")
|
assert.Equal(t, 30*time.Second, c.Mcp.MessageTimeout, "Default message timeout should be 30s")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMcpConfCustomValues(t *testing.T) {
|
func TestMcpConfCustomValues(t *testing.T) {
|
||||||
@@ -43,7 +43,7 @@ func TestMcpConfCustomValues(t *testing.T) {
|
|||||||
"SseEndpoint": "/custom-sse",
|
"SseEndpoint": "/custom-sse",
|
||||||
"MessageEndpoint": "/custom-message",
|
"MessageEndpoint": "/custom-message",
|
||||||
"Cors": ["http://localhost:3000", "http://example.com"],
|
"Cors": ["http://localhost:3000", "http://example.com"],
|
||||||
"ToolTimeout": "60s"
|
"MessageTimeout": "60s"
|
||||||
}
|
}
|
||||||
}`
|
}`
|
||||||
|
|
||||||
@@ -59,5 +59,5 @@ func TestMcpConfCustomValues(t *testing.T) {
|
|||||||
assert.Equal(t, "/custom-sse", c.Mcp.SseEndpoint, "SSE endpoint should be customizable")
|
assert.Equal(t, "/custom-sse", c.Mcp.SseEndpoint, "SSE endpoint should be customizable")
|
||||||
assert.Equal(t, "/custom-message", c.Mcp.MessageEndpoint, "Message endpoint should be customizable")
|
assert.Equal(t, "/custom-message", c.Mcp.MessageEndpoint, "Message endpoint should be customizable")
|
||||||
assert.Equal(t, []string{"http://localhost:3000", "http://example.com"}, c.Mcp.Cors, "CORS settings should be customizable")
|
assert.Equal(t, []string{"http://localhost:3000", "http://example.com"}, c.Mcp.Cors, "CORS settings should be customizable")
|
||||||
assert.Equal(t, 60*time.Second, c.Mcp.ToolTimeout, "Tool timeout should be customizable")
|
assert.Equal(t, 60*time.Second, c.Mcp.MessageTimeout, "Tool timeout should be customizable")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package mcp
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -59,7 +60,7 @@ func TestHTTPHandlerIntegration(t *testing.T) {
|
|||||||
conf := McpConf{}
|
conf := McpConf{}
|
||||||
conf.Mcp.Name = "test-integration"
|
conf.Mcp.Name = "test-integration"
|
||||||
conf.Mcp.Version = "1.0.0-test"
|
conf.Mcp.Version = "1.0.0-test"
|
||||||
conf.Mcp.ToolTimeout = 1 * time.Second
|
conf.Mcp.MessageTimeout = 1 * time.Second
|
||||||
|
|
||||||
// Create a mock server directly
|
// Create a mock server directly
|
||||||
server := &sseMcpServer{
|
server := &sseMcpServer{
|
||||||
@@ -75,7 +76,6 @@ func TestHTTPHandlerIntegration(t *testing.T) {
|
|||||||
Name: "echo",
|
Name: "echo",
|
||||||
Description: "Echo tool for testing",
|
Description: "Echo tool for testing",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]any{
|
Properties: map[string]any{
|
||||||
"message": map[string]any{
|
"message": map[string]any{
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -83,7 +83,7 @@ func TestHTTPHandlerIntegration(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
if msg, ok := params["message"].(string); ok {
|
if msg, ok := params["message"].(string); ok {
|
||||||
return fmt.Sprintf("Echo: %s", msg), nil
|
return fmt.Sprintf("Echo: %s", msg), nil
|
||||||
}
|
}
|
||||||
@@ -181,7 +181,7 @@ func TestHandlerResponseFlow(t *testing.T) {
|
|||||||
Name: "test.tool",
|
Name: "test.tool",
|
||||||
Description: "Test tool",
|
Description: "Test tool",
|
||||||
InputSchema: InputSchema{Type: "object"},
|
InputSchema: InputSchema{Type: "object"},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
return "tool result", nil
|
return "tool result", nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -329,7 +329,7 @@ func TestProcessListMethods(t *testing.T) {
|
|||||||
Params: json.RawMessage(`{"cursor": "", "_meta": {"progressToken": "token1"}}`),
|
Params: json.RawMessage(`{"cursor": "", "_meta": {"progressToken": "token1"}}`),
|
||||||
}
|
}
|
||||||
|
|
||||||
server.processListTools(client, req)
|
server.processListTools(context.Background(), client, req)
|
||||||
|
|
||||||
// Read response
|
// Read response
|
||||||
select {
|
select {
|
||||||
@@ -344,7 +344,7 @@ func TestProcessListMethods(t *testing.T) {
|
|||||||
req.ID = 2
|
req.ID = 2
|
||||||
req.Method = methodPromptsList
|
req.Method = methodPromptsList
|
||||||
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
||||||
server.processListPrompts(client, req)
|
server.processListPrompts(context.Background(), client, req)
|
||||||
|
|
||||||
// Read response
|
// Read response
|
||||||
select {
|
select {
|
||||||
@@ -358,7 +358,7 @@ func TestProcessListMethods(t *testing.T) {
|
|||||||
req.ID = 3
|
req.ID = 3
|
||||||
req.Method = methodResourcesList
|
req.Method = methodResourcesList
|
||||||
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
||||||
server.processListResources(client, req)
|
server.processListResources(context.Background(), client, req)
|
||||||
|
|
||||||
// Read response
|
// Read response
|
||||||
select {
|
select {
|
||||||
@@ -393,7 +393,7 @@ func TestErrorResponseHandling(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Mock handleRequest by directly calling error handler
|
// Mock handleRequest by directly calling error handler
|
||||||
server.sendErrorResponse(client, req.ID, "Method not found", errCodeMethodNotFound)
|
server.sendErrorResponse(context.Background(), client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||||
|
|
||||||
// Check response
|
// Check response
|
||||||
select {
|
select {
|
||||||
@@ -412,7 +412,7 @@ func TestErrorResponseHandling(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Call process method directly
|
// Call process method directly
|
||||||
server.processToolCall(client, toolReq)
|
server.processToolCall(context.Background(), client, toolReq)
|
||||||
|
|
||||||
// Check response
|
// Check response
|
||||||
select {
|
select {
|
||||||
@@ -431,7 +431,7 @@ func TestErrorResponseHandling(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Call process method directly
|
// Call process method directly
|
||||||
server.processGetPrompt(client, promptReq)
|
server.processGetPrompt(context.Background(), client, promptReq)
|
||||||
|
|
||||||
// Check response
|
// Check response
|
||||||
select {
|
select {
|
||||||
|
|||||||
23
mcp/parser.go
Normal file
23
mcp/parser.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/core/mapping"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseArguments parses the arguments and populates the request object
|
||||||
|
func ParseArguments(args any, req any) error {
|
||||||
|
switch arguments := args.(type) {
|
||||||
|
case map[string]string:
|
||||||
|
m := make(map[string]any, len(arguments))
|
||||||
|
for k, v := range arguments {
|
||||||
|
m[k] = v
|
||||||
|
}
|
||||||
|
return mapping.UnmarshalJsonMap(m, req, mapping.WithStringValues())
|
||||||
|
case map[string]any:
|
||||||
|
return mapping.UnmarshalJsonMap(arguments, req)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported argument type: %T", arguments)
|
||||||
|
}
|
||||||
|
}
|
||||||
139
mcp/parser_test.go
Normal file
139
mcp/parser_test.go
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestParseArguments_MapStringString tests parsing map[string]string arguments
|
||||||
|
func TestParseArguments_MapStringString(t *testing.T) {
|
||||||
|
// Sample request struct to populate
|
||||||
|
type requestStruct struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Count int `json:"count"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create test arguments
|
||||||
|
args := map[string]string{
|
||||||
|
"name": "test-name",
|
||||||
|
"message": "hello world",
|
||||||
|
"count": "42",
|
||||||
|
"enabled": "true",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a target object to populate
|
||||||
|
var req requestStruct
|
||||||
|
|
||||||
|
// Parse the arguments
|
||||||
|
err := ParseArguments(args, &req)
|
||||||
|
|
||||||
|
// Verify results
|
||||||
|
assert.NoError(t, err, "Should parse map[string]string without error")
|
||||||
|
assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed")
|
||||||
|
assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed")
|
||||||
|
assert.Equal(t, 42, req.Count, "Count should be correctly parsed to int")
|
||||||
|
assert.True(t, req.Enabled, "Enabled should be correctly parsed to bool")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestParseArguments_MapStringAny tests parsing map[string]any arguments
|
||||||
|
func TestParseArguments_MapStringAny(t *testing.T) {
|
||||||
|
// Sample request struct to populate
|
||||||
|
type requestStruct struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Count int `json:"count"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
Tags []string `json:"tags"`
|
||||||
|
Metadata map[string]string `json:"metadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create test arguments with mixed types
|
||||||
|
args := map[string]any{
|
||||||
|
"name": "test-name",
|
||||||
|
"message": "hello world",
|
||||||
|
"count": 42, // note: this is already an int
|
||||||
|
"enabled": true, // note: this is already a bool
|
||||||
|
"tags": []string{"tag1", "tag2"},
|
||||||
|
"metadata": map[string]string{
|
||||||
|
"key1": "value1",
|
||||||
|
"key2": "value2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a target object to populate
|
||||||
|
var req requestStruct
|
||||||
|
|
||||||
|
// Parse the arguments
|
||||||
|
err := ParseArguments(args, &req)
|
||||||
|
|
||||||
|
// Verify results
|
||||||
|
assert.NoError(t, err, "Should parse map[string]any without error")
|
||||||
|
assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed")
|
||||||
|
assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed")
|
||||||
|
assert.Equal(t, 42, req.Count, "Count should be correctly parsed")
|
||||||
|
assert.True(t, req.Enabled, "Enabled should be correctly parsed")
|
||||||
|
assert.Equal(t, []string{"tag1", "tag2"}, req.Tags, "Tags should be correctly parsed")
|
||||||
|
assert.Equal(t, map[string]string{
|
||||||
|
"key1": "value1",
|
||||||
|
"key2": "value2",
|
||||||
|
}, req.Metadata, "Metadata should be correctly parsed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestParseArguments_UnsupportedType tests parsing with an unsupported type
|
||||||
|
func TestParseArguments_UnsupportedType(t *testing.T) {
|
||||||
|
// Sample request struct to populate
|
||||||
|
type requestStruct struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use an unsupported argument type (slice)
|
||||||
|
args := []string{"not", "a", "map"}
|
||||||
|
|
||||||
|
// Create a target object to populate
|
||||||
|
var req requestStruct
|
||||||
|
|
||||||
|
// Parse the arguments
|
||||||
|
err := ParseArguments(args, &req)
|
||||||
|
|
||||||
|
// Verify error is returned with correct message
|
||||||
|
assert.Error(t, err, "Should return error for unsupported type")
|
||||||
|
assert.Contains(t, err.Error(), "unsupported argument type", "Error should mention unsupported type")
|
||||||
|
assert.Contains(t, err.Error(), "[]string", "Error should include the actual type")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestParseArguments_EmptyMap tests parsing with empty maps
|
||||||
|
func TestParseArguments_EmptyMap(t *testing.T) {
|
||||||
|
// Sample request struct to populate
|
||||||
|
type requestStruct struct {
|
||||||
|
Name string `json:"name,optional"`
|
||||||
|
Message string `json:"message,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test empty map[string]string
|
||||||
|
t.Run("EmptyMapStringString", func(t *testing.T) {
|
||||||
|
args := map[string]string{}
|
||||||
|
var req requestStruct
|
||||||
|
|
||||||
|
err := ParseArguments(args, &req)
|
||||||
|
|
||||||
|
assert.NoError(t, err, "Should parse empty map[string]string without error")
|
||||||
|
assert.Empty(t, req.Name, "Name should be empty string")
|
||||||
|
assert.Empty(t, req.Message, "Message should be empty string")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test empty map[string]any
|
||||||
|
t.Run("EmptyMapStringAny", func(t *testing.T) {
|
||||||
|
args := map[string]any{}
|
||||||
|
var req requestStruct
|
||||||
|
|
||||||
|
err := ParseArguments(args, &req)
|
||||||
|
|
||||||
|
assert.NoError(t, err, "Should parse empty map[string]any without error")
|
||||||
|
assert.Empty(t, req.Name, "Name should be empty string")
|
||||||
|
assert.Empty(t, req.Message, "Message should be empty string")
|
||||||
|
})
|
||||||
|
}
|
||||||
824
mcp/readme.md
824
mcp/readme.md
@@ -1,7 +1,7 @@
|
|||||||
# Model Context Protocol (MCP) SDK Implementation
|
# Model Context Protocol (MCP) Implementation
|
||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
This package implements a Model Context Protocol (MCP) server in Go that facilitates real-time communication between AI models and clients using Server-Sent Events (SSE). The implementation provides a framework for building AI-assisted applications with bidirectional communication capabilities.
|
This package implements the Model Context Protocol (MCP) server specification in Go, providing a framework for real-time communication between AI models and clients using Server-Sent Events (SSE). The implementation follows the standardized protocol for building AI-assisted applications with bidirectional communication capabilities.
|
||||||
|
|
||||||
## Core Components
|
## Core Components
|
||||||
|
|
||||||
@@ -54,9 +54,817 @@ This package implements a Model Context Protocol (MCP) server in Go that facilit
|
|||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
To create and use an MCP server, see the examples directory for practical implementation examples including:
|
### Setting Up an MCP Server
|
||||||
- Tool registration and execution
|
|
||||||
- Static and dynamic prompt creation
|
To create and start an MCP server:
|
||||||
- Resource handling with proper URI identification
|
|
||||||
- Embedded resources in prompt responses
|
```go
|
||||||
- Client connection management
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/zeromicro/go-zero/core/conf"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
|
"github.com/zeromicro/go-zero/mcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Load configuration from YAML file
|
||||||
|
var c mcp.McpConf
|
||||||
|
conf.MustLoad("config.yaml", &c)
|
||||||
|
|
||||||
|
// Optional: Disable stats logging
|
||||||
|
logx.DisableStat()
|
||||||
|
|
||||||
|
// Create MCP server
|
||||||
|
server := mcp.NewMcpServer(c)
|
||||||
|
|
||||||
|
// Register tools, prompts, and resources (examples below)
|
||||||
|
|
||||||
|
// Start the server and ensure it's stopped on exit
|
||||||
|
defer server.Stop()
|
||||||
|
server.Start()
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Sample configuration file (config.yaml):
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
name: mcp-server
|
||||||
|
host: localhost
|
||||||
|
port: 8080
|
||||||
|
mcp:
|
||||||
|
name: my-mcp-server
|
||||||
|
messageTimeout: 30s # Timeout for tool calls
|
||||||
|
cors:
|
||||||
|
- http://localhost:3000 # Optional CORS configuration
|
||||||
|
```
|
||||||
|
|
||||||
|
### Registering Tools
|
||||||
|
|
||||||
|
Tools allow AI models to execute custom code through the MCP protocol.
|
||||||
|
|
||||||
|
#### Basic Tool Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a simple echo tool
|
||||||
|
echoTool := mcp.Tool{
|
||||||
|
Name: "echo",
|
||||||
|
Description: "Echoes back the message provided by the user",
|
||||||
|
InputSchema: mcp.InputSchema{
|
||||||
|
Properties: map[string]any{
|
||||||
|
"message": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "The message to echo back",
|
||||||
|
},
|
||||||
|
"prefix": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Optional prefix to add to the echoed message",
|
||||||
|
"default": "Echo: ",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Required: []string{"message"},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
var req struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
Prefix string `json:"prefix,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse params: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix := "Echo: "
|
||||||
|
if len(req.Prefix) > 0 {
|
||||||
|
prefix = req.Prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
return prefix + req.Message, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
server.RegisterTool(echoTool)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Tool with Different Response Types:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Tool returning JSON data
|
||||||
|
dataTool := mcp.Tool{
|
||||||
|
Name: "data.generate",
|
||||||
|
Description: "Generates sample data in various formats",
|
||||||
|
InputSchema: mcp.InputSchema{
|
||||||
|
Properties: map[string]any{
|
||||||
|
"format": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Format of data (json, text)",
|
||||||
|
"enum": []string{"json", "text"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
var req struct {
|
||||||
|
Format string `json:"format"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse params: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Format == "json" {
|
||||||
|
// Return structured data
|
||||||
|
return map[string]any{
|
||||||
|
"items": []map[string]any{
|
||||||
|
{"id": 1, "name": "Item 1"},
|
||||||
|
{"id": 2, "name": "Item 2"},
|
||||||
|
},
|
||||||
|
"count": 2,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to text
|
||||||
|
return "Sample text data", nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
server.RegisterTool(dataTool)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Image Generation Tool Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Tool returning image content
|
||||||
|
imageTool := mcp.Tool{
|
||||||
|
Name: "image.generate",
|
||||||
|
Description: "Generates a simple image",
|
||||||
|
InputSchema: mcp.InputSchema{
|
||||||
|
Properties: map[string]any{
|
||||||
|
"type": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Type of image to generate",
|
||||||
|
"default": "placeholder",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
// Return image content directly
|
||||||
|
return mcp.ImageContent{
|
||||||
|
Data: "base64EncodedImageData...", // Base64 encoded image data
|
||||||
|
MimeType: "image/png",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
server.RegisterTool(imageTool)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Using ToolResult for Custom Outputs:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Tool that returns a custom ToolResult type
|
||||||
|
customResultTool := mcp.Tool{
|
||||||
|
Name: "custom.result",
|
||||||
|
Description: "Returns a custom formatted result",
|
||||||
|
InputSchema: mcp.InputSchema{
|
||||||
|
Properties: map[string]any{
|
||||||
|
"resultType": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"enum": []string{"text", "image"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
var req struct {
|
||||||
|
ResultType string `json:"resultType"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse params: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.ResultType == "image" {
|
||||||
|
return mcp.ToolResult{
|
||||||
|
Type: mcp.ContentTypeImage,
|
||||||
|
Content: map[string]any{
|
||||||
|
"data": "base64EncodedImageData...",
|
||||||
|
"mimeType": "image/jpeg",
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to text
|
||||||
|
return mcp.ToolResult{
|
||||||
|
Type: mcp.ContentTypeText,
|
||||||
|
Content: "This is a text result from ToolResult",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
server.RegisterTool(customResultTool)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Registering Prompts
|
||||||
|
|
||||||
|
Prompts are reusable conversation templates for AI models.
|
||||||
|
|
||||||
|
#### Static Prompt Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a simple static prompt with placeholders
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "hello",
|
||||||
|
Description: "A simple hello prompt",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "name",
|
||||||
|
Description: "The name to greet",
|
||||||
|
Required: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Content: "Say hello to {{name}} and introduce yourself as an AI assistant.",
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Dynamic Prompt with Handler Function:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a prompt with a dynamic handler function
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "dynamic-prompt",
|
||||||
|
Description: "A prompt that uses a handler to generate dynamic content",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "username",
|
||||||
|
Description: "User's name for personalized greeting",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "topic",
|
||||||
|
Description: "Topic of expertise",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||||
|
var req struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Topic string `json:"topic"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a user message
|
||||||
|
userMessage := mcp.PromptMessage{
|
||||||
|
Role: mcp.RoleUser,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Hello, I'm %s and I'd like to learn about %s.", req.Username, req.Topic),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an assistant response with current time
|
||||||
|
currentTime := time.Now().Format(time.RFC1123)
|
||||||
|
assistantMessage := mcp.PromptMessage{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Hello %s! I'm an AI assistant and I'll help you learn about %s. The current time is %s.",
|
||||||
|
req.Username, req.Topic, currentTime),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return both messages as a conversation
|
||||||
|
return []mcp.PromptMessage{userMessage, assistantMessage}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Multi-Message Prompt with Code Examples:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a prompt that provides code examples in different programming languages
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "code-example",
|
||||||
|
Description: "Provides code examples in different programming languages",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "language",
|
||||||
|
Description: "Programming language for the example",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "complexity",
|
||||||
|
Description: "Complexity level (simple, medium, advanced)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||||
|
var req struct {
|
||||||
|
Language string `json:"language"`
|
||||||
|
Complexity string `json:"complexity,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate language
|
||||||
|
supportedLanguages := map[string]bool{"go": true, "python": true, "javascript": true, "rust": true}
|
||||||
|
if !supportedLanguages[req.Language] {
|
||||||
|
return nil, fmt.Errorf("unsupported language: %s", req.Language)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate code example based on language and complexity
|
||||||
|
var codeExample string
|
||||||
|
|
||||||
|
switch req.Language {
|
||||||
|
case "go":
|
||||||
|
if req.Complexity == "simple" {
|
||||||
|
codeExample = `
|
||||||
|
package main
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
fmt.Println("Hello, World!")
|
||||||
|
}`
|
||||||
|
} else {
|
||||||
|
codeExample = `
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
now := time.Now()
|
||||||
|
fmt.Printf("Hello, World! Current time is %s\n", now.Format(time.RFC3339))
|
||||||
|
}`
|
||||||
|
}
|
||||||
|
case "python":
|
||||||
|
// Python example code
|
||||||
|
if req.Complexity == "simple" {
|
||||||
|
codeExample = `
|
||||||
|
def greet(name):
|
||||||
|
return f"Hello, {name}!"
|
||||||
|
|
||||||
|
print(greet("World"))`
|
||||||
|
} else {
|
||||||
|
codeExample = `
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
def greet(name, include_time=False):
|
||||||
|
message = f"Hello, {name}!"
|
||||||
|
if include_time:
|
||||||
|
message += f" Current time is {datetime.datetime.now().isoformat()}"
|
||||||
|
return message
|
||||||
|
|
||||||
|
print(greet("World", include_time=True))`
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create messages array according to MCP spec
|
||||||
|
messages := []mcp.PromptMessage{
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("You are a helpful coding assistant specialized in %s programming.", req.Language),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleUser,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Show me a %s example of a Hello World program in %s.", req.Complexity, req.Language),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Here's a %s example in %s:\n\n```%s%s\n```\n\nHow can I help you implement this?",
|
||||||
|
req.Complexity, req.Language, req.Language, codeExample),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Registering Resources
|
||||||
|
|
||||||
|
Resources provide access to external content such as files or generated data.
|
||||||
|
|
||||||
|
#### Basic Resource Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a static resource
|
||||||
|
server.RegisterResource(mcp.Resource{
|
||||||
|
Name: "example-document",
|
||||||
|
URI: "file:///example/document.txt",
|
||||||
|
Description: "An example document",
|
||||||
|
MimeType: "text/plain",
|
||||||
|
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||||
|
return mcp.ResourceContent{
|
||||||
|
URI: "file:///example/document.txt",
|
||||||
|
MimeType: "text/plain",
|
||||||
|
Text: "This is an example document content.",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Dynamic Resource with Code Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a Go code resource with dynamic handler
|
||||||
|
server.RegisterResource(mcp.Resource{
|
||||||
|
Name: "go-example",
|
||||||
|
URI: "file:///project/src/main.go",
|
||||||
|
Description: "A simple Go example with multiple files",
|
||||||
|
MimeType: "text/x-go",
|
||||||
|
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||||
|
// Return ResourceContent with all required fields
|
||||||
|
return mcp.ResourceContent{
|
||||||
|
URI: "file:///project/src/main.go",
|
||||||
|
MimeType: "text/x-go",
|
||||||
|
Text: "package main\n\nimport (\n\t\"fmt\"\n\t\"./greeting\"\n)\n\nfunc main() {\n\tfmt.Println(greeting.Hello(\"world\"))\n}",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Register a companion file for the above example
|
||||||
|
server.RegisterResource(mcp.Resource{
|
||||||
|
Name: "go-greeting",
|
||||||
|
URI: "file:///project/src/greeting/greeting.go",
|
||||||
|
Description: "A greeting package for the Go example",
|
||||||
|
MimeType: "text/x-go",
|
||||||
|
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||||
|
return mcp.ResourceContent{
|
||||||
|
URI: "file:///project/src/greeting/greeting.go",
|
||||||
|
MimeType: "text/x-go",
|
||||||
|
Text: "package greeting\n\nfunc Hello(name string) string {\n\treturn \"Hello, \" + name + \"!\"\n}",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Binary Resource Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a binary resource (like an image)
|
||||||
|
server.RegisterResource(mcp.Resource{
|
||||||
|
Name: "example-image",
|
||||||
|
URI: "file:///example/image.png",
|
||||||
|
Description: "An example image",
|
||||||
|
MimeType: "image/png",
|
||||||
|
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||||
|
// Read image from file or generate it
|
||||||
|
imageData := "base64EncodedImageData..." // Base64 encoded image data
|
||||||
|
|
||||||
|
return mcp.ResourceContent{
|
||||||
|
URI: "file:///example/image.png",
|
||||||
|
MimeType: "image/png",
|
||||||
|
Blob: imageData, // For binary data
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Using Resources in Prompts
|
||||||
|
|
||||||
|
You can embed resources in prompt responses to create rich interactions with proper MCP-compliant structure:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a prompt that embeds a resource
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "resource-example",
|
||||||
|
Description: "A prompt that embeds a resource",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "file_type",
|
||||||
|
Description: "Type of file to show (rust or go)",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||||
|
var req struct {
|
||||||
|
FileType string `json:"file_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resourceURI, mimeType, fileContent string
|
||||||
|
if req.FileType == "rust" {
|
||||||
|
resourceURI = "file:///project/src/main.rs"
|
||||||
|
mimeType = "text/x-rust"
|
||||||
|
fileContent = "fn main() {\n println!(\"Hello world!\");\n}"
|
||||||
|
} else {
|
||||||
|
resourceURI = "file:///project/src/main.go"
|
||||||
|
mimeType = "text/x-go"
|
||||||
|
fileContent = "package main\n\nimport \"fmt\"\n\nfunc main() {\n\tfmt.Println(\"Hello, world!\")\n}"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create message with embedded resource using proper MCP format
|
||||||
|
return []mcp.PromptMessage{
|
||||||
|
{
|
||||||
|
Role: mcp.RoleUser,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Can you explain this %s code?", req.FileType),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.EmbeddedResource{
|
||||||
|
Type: mcp.ContentTypeResource,
|
||||||
|
Resource: struct {
|
||||||
|
URI string `json:"uri"`
|
||||||
|
MimeType string `json:"mimeType"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
Blob string `json:"blob,omitempty"`
|
||||||
|
}{
|
||||||
|
URI: resourceURI,
|
||||||
|
MimeType: mimeType,
|
||||||
|
Text: fileContent,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Above is a simple Hello World example in %s. Let me explain how it works.", req.FileType),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multiple File Resources Example
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a prompt that demonstrates embedding multiple resource files
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "go-code-example",
|
||||||
|
Description: "A prompt that correctly embeds multiple resource files",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "format",
|
||||||
|
Description: "How to format the code display",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||||
|
var req struct {
|
||||||
|
Format string `json:"format,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the Go code for multiple files
|
||||||
|
var mainGoText string = "package main\n\nimport (\n\t\"fmt\"\n\t\"./greeting\"\n)\n\nfunc main() {\n\tfmt.Println(greeting.Hello(\"world\"))\n}"
|
||||||
|
var greetingGoText string = "package greeting\n\nfunc Hello(name string) string {\n\treturn \"Hello, \" + name + \"!\"\n}"
|
||||||
|
|
||||||
|
// Create message with properly formatted embedded resource per MCP spec
|
||||||
|
messages := []mcp.PromptMessage{
|
||||||
|
{
|
||||||
|
Role: mcp.RoleUser,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: "Show me a simple Go example with proper imports.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: "Here's a simple Go example project:",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.EmbeddedResource{
|
||||||
|
Type: mcp.ContentTypeResource,
|
||||||
|
Resource: struct {
|
||||||
|
URI string `json:"uri"`
|
||||||
|
MimeType string `json:"mimeType"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
Blob string `json:"blob,omitempty"`
|
||||||
|
}{
|
||||||
|
URI: "file:///project/src/main.go",
|
||||||
|
MimeType: "text/x-go",
|
||||||
|
Text: mainGoText,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add explanation and additional file if requested
|
||||||
|
if req.Format == "with_explanation" {
|
||||||
|
messages = append(messages, mcp.PromptMessage{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: "This example demonstrates a simple Go application with modular structure. The main.go file imports from a local 'greeting' package that provides the Hello function.",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Also show the greeting.go file with correct resource format
|
||||||
|
messages = append(messages, mcp.PromptMessage{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.EmbeddedResource{
|
||||||
|
Type: mcp.ContentTypeResource,
|
||||||
|
Resource: struct {
|
||||||
|
URI string `json:"uri"`
|
||||||
|
MimeType string `json:"mimeType"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
Blob string `json:"blob,omitempty"`
|
||||||
|
}{
|
||||||
|
URI: "file:///project/src/greeting/greeting.go",
|
||||||
|
MimeType: "text/x-go",
|
||||||
|
Text: greetingGoText,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Complete Application Example
|
||||||
|
|
||||||
|
Here's a complete example demonstrating all the components:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/core/conf"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
|
"github.com/zeromicro/go-zero/mcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Load configuration
|
||||||
|
var c mcp.McpConf
|
||||||
|
if err := conf.Load("config.yaml", &c); err != nil {
|
||||||
|
log.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up logging
|
||||||
|
logx.DisableStat()
|
||||||
|
|
||||||
|
// Create MCP server
|
||||||
|
server := mcp.NewMcpServer(c)
|
||||||
|
defer server.Stop()
|
||||||
|
|
||||||
|
// Register a simple echo tool
|
||||||
|
echoTool := mcp.Tool{
|
||||||
|
Name: "echo",
|
||||||
|
Description: "Echoes back the message provided by the user",
|
||||||
|
InputSchema: mcp.InputSchema{
|
||||||
|
Properties: map[string]any{
|
||||||
|
"message": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "The message to echo back",
|
||||||
|
},
|
||||||
|
"prefix": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Optional prefix to add to the echoed message",
|
||||||
|
"default": "Echo: ",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Required: []string{"message"},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
var req struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
Prefix string `json:"prefix,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix := "Echo: "
|
||||||
|
if len(req.Prefix) > 0 {
|
||||||
|
prefix = req.Prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
return prefix + req.Message, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
server.RegisterTool(echoTool)
|
||||||
|
|
||||||
|
// Register a static prompt
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "greeting",
|
||||||
|
Description: "A simple greeting prompt",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "name",
|
||||||
|
Description: "The name to greet",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Content: "Hello {{name}}! How can I assist you today?",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Register a dynamic prompt
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "dynamic-prompt",
|
||||||
|
Description: "A prompt that uses a handler to generate dynamic content",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "username",
|
||||||
|
Description: "User's name for personalized greeting",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "topic",
|
||||||
|
Description: "Topic of expertise",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||||
|
var req struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Topic string `json:"topic"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create messages with current time
|
||||||
|
currentTime := time.Now().Format(time.RFC1123)
|
||||||
|
return []mcp.PromptMessage{
|
||||||
|
{
|
||||||
|
Role: mcp.RoleUser,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Hello, I'm %s and I'd like to learn about %s.", req.Username, req.Topic),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Hello %s! I'm an AI assistant and I'll help you learn about %s. The current time is %s.",
|
||||||
|
req.Username, req.Topic, currentTime),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Register a resource
|
||||||
|
server.RegisterResource(mcp.Resource{
|
||||||
|
Name: "example-doc",
|
||||||
|
URI: "file:///example/doc.txt",
|
||||||
|
Description: "An example document",
|
||||||
|
MimeType: "text/plain",
|
||||||
|
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||||
|
return mcp.ResourceContent{
|
||||||
|
URI: "file:///example/doc.txt",
|
||||||
|
MimeType: "text/plain",
|
||||||
|
Text: "This is the content of the example document.",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Start the server
|
||||||
|
fmt.Printf("Starting MCP server on %s:%d\n", c.Host, c.Port)
|
||||||
|
server.Start()
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
The MCP implementation provides comprehensive error handling:
|
||||||
|
|
||||||
|
- Tool execution errors are properly reported back to clients
|
||||||
|
- Missing or invalid parameters are detected and reported with appropriate error codes
|
||||||
|
- Resource and prompt lookup failures are handled gracefully
|
||||||
|
- Timeout handling for long-running tool executions using context
|
||||||
|
- Panic recovery to prevent server crashes
|
||||||
|
|
||||||
|
## Advanced Features
|
||||||
|
|
||||||
|
- **Annotations**: Add audience and priority metadata to content
|
||||||
|
- **Content Types**: Support for text, images, audio, and other content formats
|
||||||
|
- **Embedded Resources**: Include file resources directly in prompt responses
|
||||||
|
- **Context Awareness**: All handlers receive context.Context for timeout and cancellation support
|
||||||
|
- **Progress Tokens**: Support for tracking progress of long-running operations
|
||||||
|
- **Customizable Timeouts**: Configure execution timeouts for tools and operations
|
||||||
|
|
||||||
|
## Performance Considerations
|
||||||
|
|
||||||
|
- Tool execution runs with configurable timeouts to prevent blocking
|
||||||
|
- Efficient client tracking and cleanup to prevent resource leaks
|
||||||
|
- Proper concurrency handling with mutex protection for shared resources
|
||||||
|
- Buffered message channels to prevent blocking on client message delivery
|
||||||
|
|||||||
239
mcp/server.go
239
mcp/server.go
@@ -42,14 +42,14 @@ func NewMcpServer(c McpConf) McpServer {
|
|||||||
Method: http.MethodGet,
|
Method: http.MethodGet,
|
||||||
Path: s.conf.Mcp.SseEndpoint,
|
Path: s.conf.Mcp.SseEndpoint,
|
||||||
Handler: s.handleSSE,
|
Handler: s.handleSSE,
|
||||||
}, rest.WithSSE())
|
}, rest.WithSSE(), rest.WithTimeout(c.Mcp.SseTimeout))
|
||||||
|
|
||||||
// JSON-RPC message endpoint for regular requests
|
// JSON-RPC message endpoint for regular requests
|
||||||
s.server.AddRoute(rest.Route{
|
s.server.AddRoute(rest.Route{
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Path: s.conf.Mcp.MessageEndpoint,
|
Path: s.conf.Mcp.MessageEndpoint,
|
||||||
Handler: s.handleRequest,
|
Handler: s.handleRequest,
|
||||||
})
|
}, rest.WithTimeout(c.Mcp.MessageTimeout))
|
||||||
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
@@ -182,21 +182,23 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Always allow initialize and notifications/initialized regardless of client state
|
// Always allow initialize and notifications/initialized regardless of client state
|
||||||
if req.Method == methodInitialize {
|
if req.Method == methodInitialize {
|
||||||
logx.Infof("Processing initialize request with ID: %d", req.ID)
|
logx.Infof("Processing initialize request with ID: %d", req.ID)
|
||||||
s.processInitialize(client, req)
|
s.processInitialize(r.Context(), client, req)
|
||||||
logx.Infof("Sent initialize response for ID: %d, waiting for notifications/initialized", req.ID)
|
logx.Infof("Sent initialize response for ID: %d, waiting for notifications/initialized", req.ID)
|
||||||
return
|
return
|
||||||
} else if req.Method == methodNotificationsInitialized {
|
} else if req.Method == methodNotificationsInitialized {
|
||||||
// Handle initialized notification
|
// Handle initialized notification
|
||||||
logx.Info("Received notifications/initialized notification")
|
logx.Info("Received notifications/initialized notification")
|
||||||
if !isNotification {
|
if !isNotification {
|
||||||
s.sendErrorResponse(client, req.ID, "Method should be used as a notification", errCodeInvalidRequest)
|
s.sendErrorResponse(r.Context(), client, req.ID,
|
||||||
|
"Method should be used as a notification", errCodeInvalidRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.processNotificationInitialized(client)
|
s.processNotificationInitialized(client)
|
||||||
return
|
return
|
||||||
} else if !client.initialized && req.Method != methodNotificationsCancelled {
|
} else if !client.initialized && req.Method != methodNotificationsCancelled {
|
||||||
// Block most requests until client is initialized (except for cancellations)
|
// Block most requests until client is initialized (except for cancellations)
|
||||||
s.sendErrorResponse(client, req.ID, "Client not fully initialized, waiting for notifications/initialized",
|
s.sendErrorResponse(r.Context(), client, req.ID,
|
||||||
|
"Client not fully initialized, waiting for notifications/initialized",
|
||||||
errCodeClientNotInitialized)
|
errCodeClientNotInitialized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -205,41 +207,41 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
switch req.Method {
|
switch req.Method {
|
||||||
case methodToolsCall:
|
case methodToolsCall:
|
||||||
logx.Infof("Received tools call request with ID: %d", req.ID)
|
logx.Infof("Received tools call request with ID: %d", req.ID)
|
||||||
s.processToolCall(client, req)
|
s.processToolCall(r.Context(), client, req)
|
||||||
logx.Infof("Sent tools call response for ID: %d", req.ID)
|
logx.Infof("Sent tools call response for ID: %d", req.ID)
|
||||||
case methodToolsList:
|
case methodToolsList:
|
||||||
logx.Infof("Processing tools/list request with ID: %d", req.ID)
|
logx.Infof("Processing tools/list request with ID: %d", req.ID)
|
||||||
s.processListTools(client, req)
|
s.processListTools(r.Context(), client, req)
|
||||||
logx.Infof("Sent tools/list response for ID: %d", req.ID)
|
logx.Infof("Sent tools/list response for ID: %d", req.ID)
|
||||||
case methodPromptsList:
|
case methodPromptsList:
|
||||||
logx.Infof("Processing prompts/list request with ID: %d", req.ID)
|
logx.Infof("Processing prompts/list request with ID: %d", req.ID)
|
||||||
s.processListPrompts(client, req)
|
s.processListPrompts(r.Context(), client, req)
|
||||||
logx.Infof("Sent prompts/list response for ID: %d", req.ID)
|
logx.Infof("Sent prompts/list response for ID: %d", req.ID)
|
||||||
case methodPromptsGet:
|
case methodPromptsGet:
|
||||||
logx.Infof("Processing prompts/get request with ID: %d", req.ID)
|
logx.Infof("Processing prompts/get request with ID: %d", req.ID)
|
||||||
s.processGetPrompt(client, req)
|
s.processGetPrompt(r.Context(), client, req)
|
||||||
logx.Infof("Sent prompts/get response for ID: %d", req.ID)
|
logx.Infof("Sent prompts/get response for ID: %d", req.ID)
|
||||||
case methodResourcesList:
|
case methodResourcesList:
|
||||||
logx.Infof("Processing resources/list request with ID: %d", req.ID)
|
logx.Infof("Processing resources/list request with ID: %d", req.ID)
|
||||||
s.processListResources(client, req)
|
s.processListResources(r.Context(), client, req)
|
||||||
logx.Infof("Sent resources/list response for ID: %d", req.ID)
|
logx.Infof("Sent resources/list response for ID: %d", req.ID)
|
||||||
case methodResourcesRead:
|
case methodResourcesRead:
|
||||||
logx.Infof("Processing resources/read request with ID: %d", req.ID)
|
logx.Infof("Processing resources/read request with ID: %d", req.ID)
|
||||||
s.processResourcesRead(client, req)
|
s.processResourcesRead(r.Context(), client, req)
|
||||||
logx.Infof("Sent resources/read response for ID: %d", req.ID)
|
logx.Infof("Sent resources/read response for ID: %d", req.ID)
|
||||||
case methodResourcesSubscribe:
|
case methodResourcesSubscribe:
|
||||||
logx.Infof("Processing resources/subscribe request with ID: %d", req.ID)
|
logx.Infof("Processing resources/subscribe request with ID: %d", req.ID)
|
||||||
s.processResourceSubscribe(client, req)
|
s.processResourceSubscribe(r.Context(), client, req)
|
||||||
logx.Infof("Sent resources/subscribe response for ID: %d", req.ID)
|
logx.Infof("Sent resources/subscribe response for ID: %d", req.ID)
|
||||||
case methodPing:
|
case methodPing:
|
||||||
logx.Infof("Processing ping request with ID: %d", req.ID)
|
logx.Infof("Processing ping request with ID: %d", req.ID)
|
||||||
s.processPing(client, req)
|
s.processPing(r.Context(), client, req)
|
||||||
case methodNotificationsCancelled:
|
case methodNotificationsCancelled:
|
||||||
logx.Infof("Received notifications/cancelled notification: %v", req.Params)
|
logx.Infof("Received notifications/cancelled notification: %d", req.ID)
|
||||||
s.processNotificationCancelled(client, req)
|
s.processNotificationCancelled(r.Context(), client, req)
|
||||||
default:
|
default:
|
||||||
logx.Infof("Unknown method: %s", req.Method)
|
logx.Infof("Unknown method: %s from client: %d", req.Method, req.ID)
|
||||||
s.sendErrorResponse(client, req.ID, "Method not found", errCodeMethodNotFound)
|
s.sendErrorResponse(r.Context(), client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -321,7 +323,7 @@ func (s *sseMcpServer) handleSSE(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
case <-r.Context().Done():
|
case <-r.Context().Done():
|
||||||
// Client disconnected or request was canceled
|
// Client disconnected or request was canceled or timed out
|
||||||
logx.Infof("Client %s disconnected: context done", sessionId)
|
logx.Infof("Client %s disconnected: context done", sessionId)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -329,7 +331,7 @@ func (s *sseMcpServer) handleSSE(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// processInitialize processes the initialize request
|
// processInitialize processes the initialize request
|
||||||
func (s *sseMcpServer) processInitialize(client *mcpClient, req Request) {
|
func (s *sseMcpServer) processInitialize(ctx context.Context, client *mcpClient, req Request) {
|
||||||
// Create a proper JSON-RPC response that preserves the client's request ID
|
// Create a proper JSON-RPC response that preserves the client's request ID
|
||||||
result := initializationResponse{
|
result := initializationResponse{
|
||||||
ProtocolVersion: s.conf.Mcp.ProtocolVersion,
|
ProtocolVersion: s.conf.Mcp.ProtocolVersion,
|
||||||
@@ -362,11 +364,11 @@ func (s *sseMcpServer) processInitialize(client *mcpClient, req Request) {
|
|||||||
client.initialized = true
|
client.initialized = true
|
||||||
|
|
||||||
// Send response with client's original request ID
|
// Send response with client's original request ID
|
||||||
s.sendResponse(client, req.ID, result)
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
// processListTools processes the tools/list request
|
// processListTools processes the tools/list request
|
||||||
func (s *sseMcpServer) processListTools(client *mcpClient, req Request) {
|
func (s *sseMcpServer) processListTools(ctx context.Context, client *mcpClient, req Request) {
|
||||||
// Extract pagination params if any
|
// Extract pagination params if any
|
||||||
var nextCursor string
|
var nextCursor string
|
||||||
var progressToken any
|
var progressToken any
|
||||||
@@ -390,6 +392,9 @@ func (s *sseMcpServer) processListTools(client *mcpClient, req Request) {
|
|||||||
var toolsList []Tool
|
var toolsList []Tool
|
||||||
s.toolsLock.Lock()
|
s.toolsLock.Lock()
|
||||||
for _, tool := range s.tools {
|
for _, tool := range s.tools {
|
||||||
|
if len(tool.InputSchema.Type) == 0 {
|
||||||
|
tool.InputSchema.Type = ContentTypeObject
|
||||||
|
}
|
||||||
toolsList = append(toolsList, tool)
|
toolsList = append(toolsList, tool)
|
||||||
}
|
}
|
||||||
s.toolsLock.Unlock()
|
s.toolsLock.Unlock()
|
||||||
@@ -405,15 +410,15 @@ func (s *sseMcpServer) processListTools(client *mcpClient, req Request) {
|
|||||||
// Add meta information if progress token was provided
|
// Add meta information if progress token was provided
|
||||||
if progressToken != nil {
|
if progressToken != nil {
|
||||||
result.Result.Meta = map[string]any{
|
result.Result.Meta = map[string]any{
|
||||||
"progressToken": progressToken,
|
progressTokenKey: progressToken,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.sendResponse(client, req.ID, result)
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
// processListPrompts processes the prompts/list request
|
// processListPrompts processes the prompts/list request
|
||||||
func (s *sseMcpServer) processListPrompts(client *mcpClient, req Request) {
|
func (s *sseMcpServer) processListPrompts(ctx context.Context, client *mcpClient, req Request) {
|
||||||
// Extract pagination params if any
|
// Extract pagination params if any
|
||||||
var nextCursor string
|
var nextCursor string
|
||||||
if req.Params != nil {
|
if req.Params != nil {
|
||||||
@@ -447,11 +452,11 @@ func (s *sseMcpServer) processListPrompts(client *mcpClient, req Request) {
|
|||||||
NextCursor: nextCursor,
|
NextCursor: nextCursor,
|
||||||
}
|
}
|
||||||
|
|
||||||
s.sendResponse(client, req.ID, result)
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
// processListResources processes the resources/list request
|
// processListResources processes the resources/list request
|
||||||
func (s *sseMcpServer) processListResources(client *mcpClient, req Request) {
|
func (s *sseMcpServer) processListResources(ctx context.Context, client *mcpClient, req Request) {
|
||||||
// Extract pagination params if any
|
// Extract pagination params if any
|
||||||
var nextCursor string
|
var nextCursor string
|
||||||
var progressToken any
|
var progressToken any
|
||||||
@@ -493,15 +498,15 @@ func (s *sseMcpServer) processListResources(client *mcpClient, req Request) {
|
|||||||
// Add meta information if progress token was provided
|
// Add meta information if progress token was provided
|
||||||
if progressToken != nil {
|
if progressToken != nil {
|
||||||
result.Result.Meta = map[string]any{
|
result.Result.Meta = map[string]any{
|
||||||
"progressToken": progressToken,
|
progressTokenKey: progressToken,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.sendResponse(client, req.ID, result)
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
// processGetPrompt processes the prompts/get request
|
// processGetPrompt processes the prompts/get request
|
||||||
func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
func (s *sseMcpServer) processGetPrompt(ctx context.Context, client *mcpClient, req Request) {
|
||||||
type GetPromptParams struct {
|
type GetPromptParams struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Arguments map[string]string `json:"arguments,omitempty"`
|
Arguments map[string]string `json:"arguments,omitempty"`
|
||||||
@@ -509,7 +514,7 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
|||||||
|
|
||||||
var params GetPromptParams
|
var params GetPromptParams
|
||||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||||
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -519,7 +524,7 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
|||||||
s.promptsLock.Unlock()
|
s.promptsLock.Unlock()
|
||||||
if !exists {
|
if !exists {
|
||||||
message := fmt.Sprintf("Prompt '%s' not found", params.Name)
|
message := fmt.Sprintf("Prompt '%s' not found", params.Name)
|
||||||
s.sendErrorResponse(client, req.ID, message, errCodeInvalidParams)
|
s.sendErrorResponse(ctx, client, req.ID, message, errCodeInvalidParams)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -529,12 +534,15 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
|||||||
missingArgs := validatePromptArguments(prompt, params.Arguments)
|
missingArgs := validatePromptArguments(prompt, params.Arguments)
|
||||||
if len(missingArgs) > 0 {
|
if len(missingArgs) > 0 {
|
||||||
message := fmt.Sprintf("Missing required arguments: %s", strings.Join(missingArgs, ", "))
|
message := fmt.Sprintf("Missing required arguments: %s", strings.Join(missingArgs, ", "))
|
||||||
s.sendErrorResponse(client, req.ID, message, errCodeInvalidParams)
|
s.sendErrorResponse(ctx, client, req.ID, message, errCodeInvalidParams)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply default values for missing optional arguments
|
// Ensure arguments are initialized to an empty map if nil
|
||||||
args := applyDefaultArguments(prompt, params.Arguments)
|
if params.Arguments == nil {
|
||||||
|
params.Arguments = make(map[string]string)
|
||||||
|
}
|
||||||
|
args := params.Arguments
|
||||||
|
|
||||||
// Generate messages using handler or static content
|
// Generate messages using handler or static content
|
||||||
var messages []PromptMessage
|
var messages []PromptMessage
|
||||||
@@ -542,17 +550,17 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
|||||||
|
|
||||||
if prompt.Handler != nil {
|
if prompt.Handler != nil {
|
||||||
// Use dynamic handler to generate messages
|
// Use dynamic handler to generate messages
|
||||||
logx.Info("Using prompt handler to generate content")
|
messages, err = prompt.Handler(ctx, args)
|
||||||
messages, err = prompt.Handler(args)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logx.Errorf("Error from prompt handler: %v", err)
|
logx.Errorf("Error from prompt handler: %v", err)
|
||||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Error generating prompt content: %v", err), errCodeInternalError)
|
s.sendErrorResponse(ctx, client, req.ID,
|
||||||
|
fmt.Sprintf("Error generating prompt content: %v", err), errCodeInternalError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// No handler, generate messages from static content
|
// No handler, generate messages from static content
|
||||||
var messageText string
|
var messageText string
|
||||||
if prompt.Content != "" {
|
if len(prompt.Content) > 0 {
|
||||||
messageText = prompt.Content
|
messageText = prompt.Content
|
||||||
|
|
||||||
// Apply argument substitutions to static content
|
// Apply argument substitutions to static content
|
||||||
@@ -560,21 +568,13 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
|||||||
placeholder := fmt.Sprintf("{{%s}}", key)
|
placeholder := fmt.Sprintf("{{%s}}", key)
|
||||||
messageText = strings.Replace(messageText, placeholder, value, -1)
|
messageText = strings.Replace(messageText, placeholder, value, -1)
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// No content, use a default fallback
|
|
||||||
topic := "this topic"
|
|
||||||
if t, ok := args["topic"]; ok && t != "" {
|
|
||||||
topic = t
|
|
||||||
}
|
|
||||||
messageText = fmt.Sprintf("Tell me about %s", topic)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a single user message with the content
|
// Create a single user message with the content
|
||||||
messages = []PromptMessage{
|
messages = []PromptMessage{
|
||||||
{
|
{
|
||||||
Role: roleUser,
|
Role: RoleUser,
|
||||||
Content: TextContent{
|
Content: TextContent{
|
||||||
Type: contentTypeText,
|
|
||||||
Text: messageText,
|
Text: messageText,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -587,49 +587,14 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
|||||||
Messages []PromptMessage `json:"messages"`
|
Messages []PromptMessage `json:"messages"`
|
||||||
}{
|
}{
|
||||||
Description: prompt.Description,
|
Description: prompt.Description,
|
||||||
Messages: messages,
|
Messages: toTypedPromptMessages(messages),
|
||||||
}
|
}
|
||||||
|
|
||||||
s.sendResponse(client, req.ID, result)
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
}
|
|
||||||
|
|
||||||
// validatePromptArguments checks if all required arguments are provided
|
|
||||||
// Returns a list of missing required arguments
|
|
||||||
func validatePromptArguments(prompt Prompt, providedArgs map[string]string) []string {
|
|
||||||
var missingArgs []string
|
|
||||||
|
|
||||||
for _, arg := range prompt.Arguments {
|
|
||||||
if arg.Required {
|
|
||||||
if value, exists := providedArgs[arg.Name]; !exists || len(value) == 0 {
|
|
||||||
missingArgs = append(missingArgs, arg.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return missingArgs
|
|
||||||
}
|
|
||||||
|
|
||||||
// applyDefaultArguments adds default values for missing optional arguments
|
|
||||||
func applyDefaultArguments(prompt Prompt, providedArgs map[string]string) map[string]string {
|
|
||||||
result := make(map[string]string)
|
|
||||||
|
|
||||||
// Copy all provided arguments
|
|
||||||
for k, v := range providedArgs {
|
|
||||||
result[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add defaults for missing arguments
|
|
||||||
for _, arg := range prompt.Arguments {
|
|
||||||
if _, exists := result[arg.Name]; !exists && arg.Default != "" {
|
|
||||||
result[arg.Name] = arg.Default
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// processToolCall processes the tools/call request
|
// processToolCall processes the tools/call request
|
||||||
func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
func (s *sseMcpServer) processToolCall(ctx context.Context, client *mcpClient, req Request) {
|
||||||
var toolCallParams struct {
|
var toolCallParams struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Arguments map[string]any `json:"arguments,omitempty"`
|
Arguments map[string]any `json:"arguments,omitempty"`
|
||||||
@@ -642,7 +607,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
|||||||
// If it's a RawMessage (JSON), unmarshal it
|
// If it's a RawMessage (JSON), unmarshal it
|
||||||
if err := json.Unmarshal(req.Params, &toolCallParams); err != nil {
|
if err := json.Unmarshal(req.Params, &toolCallParams); err != nil {
|
||||||
logx.Errorf("Failed to unmarshal tool call params: %v", err)
|
logx.Errorf("Failed to unmarshal tool call params: %v", err)
|
||||||
s.sendErrorResponse(client, req.ID, "Invalid tool call parameters", errCodeInvalidParams)
|
s.sendErrorResponse(ctx, client, req.ID, "Invalid tool call parameters", errCodeInvalidParams)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -654,15 +619,11 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
|||||||
tool, exists := s.tools[toolCallParams.Name]
|
tool, exists := s.tools[toolCallParams.Name]
|
||||||
s.toolsLock.Unlock()
|
s.toolsLock.Unlock()
|
||||||
if !exists {
|
if !exists {
|
||||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Tool '%s' not found",
|
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Tool '%s' not found",
|
||||||
toolCallParams.Name), errCodeInvalidParams)
|
toolCallParams.Name), errCodeInvalidParams)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a context with the configured timeout
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), s.conf.Mcp.ToolTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Log parameters before execution
|
// Log parameters before execution
|
||||||
logx.Infof("Executing tool '%s' with arguments: %#v", toolCallParams.Name, toolCallParams.Arguments)
|
logx.Infof("Executing tool '%s' with arguments: %#v", toolCallParams.Name, toolCallParams.Arguments)
|
||||||
|
|
||||||
@@ -671,6 +632,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
|||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Create a channel to receive the result
|
// Create a channel to receive the result
|
||||||
|
// make sure to have 1 size buffer to avoid channel leak if timeout
|
||||||
resultCh := make(chan struct {
|
resultCh := make(chan struct {
|
||||||
result any
|
result any
|
||||||
err error
|
err error
|
||||||
@@ -678,7 +640,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
|||||||
|
|
||||||
// Execute the tool handler in a goroutine
|
// Execute the tool handler in a goroutine
|
||||||
go func() {
|
go func() {
|
||||||
toolResult, toolErr := tool.Handler(toolCallParams.Arguments)
|
toolResult, toolErr := tool.Handler(ctx, toolCallParams.Arguments)
|
||||||
resultCh <- struct {
|
resultCh <- struct {
|
||||||
result any
|
result any
|
||||||
err error
|
err error
|
||||||
@@ -694,9 +656,9 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
|||||||
result = res.result
|
result = res.result
|
||||||
err = res.err
|
err = res.err
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
// Handle timeout
|
// Handle request timeout
|
||||||
logx.Errorf("Tool execution timed out after %v: %s", s.conf.Mcp.ToolTimeout, toolCallParams.Name)
|
logx.Errorf("Tool execution timed out after %v: %s", s.conf.Mcp.MessageTimeout, toolCallParams.Name)
|
||||||
s.sendErrorResponse(client, req.ID, "Tool execution timed out", errCodeTimeout)
|
s.sendErrorResponse(ctx, client, req.ID, "Tool execution timed out", errCodeTimeout)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -710,7 +672,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
|||||||
// Add meta information if progress token was provided
|
// Add meta information if progress token was provided
|
||||||
if progressToken != nil {
|
if progressToken != nil {
|
||||||
callToolResult.Result.Meta = map[string]any{
|
callToolResult.Result.Meta = map[string]any{
|
||||||
"progressToken": progressToken,
|
progressTokenKey: progressToken,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -722,12 +684,11 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
|||||||
|
|
||||||
callToolResult.Content = []any{
|
callToolResult.Content = []any{
|
||||||
TextContent{
|
TextContent{
|
||||||
Type: contentTypeText,
|
|
||||||
Text: fmt.Sprintf("Error: %v", err),
|
Text: fmt.Sprintf("Error: %v", err),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
callToolResult.IsError = true
|
callToolResult.IsError = true
|
||||||
s.sendResponse(client, req.ID, callToolResult)
|
s.sendResponse(ctx, client, req.ID, callToolResult)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -736,10 +697,9 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
|||||||
case string:
|
case string:
|
||||||
// Simple string becomes text content
|
// Simple string becomes text content
|
||||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||||
Type: contentTypeText,
|
|
||||||
Text: v,
|
Text: v,
|
||||||
Annotations: &Annotations{
|
Annotations: &Annotations{
|
||||||
Audience: []roleType{roleUser, roleAssistant},
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
case map[string]any:
|
case map[string]any:
|
||||||
@@ -749,69 +709,63 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
|||||||
jsonStr = []byte(err.Error())
|
jsonStr = []byte(err.Error())
|
||||||
}
|
}
|
||||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||||
Type: contentTypeText,
|
|
||||||
Text: string(jsonStr),
|
Text: string(jsonStr),
|
||||||
Annotations: &Annotations{
|
Annotations: &Annotations{
|
||||||
Audience: []roleType{roleUser, roleAssistant},
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
case TextContent:
|
case TextContent:
|
||||||
// Direct TextContent object
|
|
||||||
callToolResult.Content = append(callToolResult.Content, v)
|
callToolResult.Content = append(callToolResult.Content, v)
|
||||||
case ImageContent:
|
case ImageContent:
|
||||||
// Direct ImageContent object
|
|
||||||
callToolResult.Content = append(callToolResult.Content, v)
|
callToolResult.Content = append(callToolResult.Content, v)
|
||||||
case []any:
|
case []any:
|
||||||
// Array of content items
|
|
||||||
callToolResult.Content = v
|
callToolResult.Content = v
|
||||||
case ToolResult:
|
case ToolResult:
|
||||||
// Handle legacy ToolResult type
|
// Handle legacy ToolResult type
|
||||||
switch v.Type {
|
switch v.Type {
|
||||||
case contentTypeText:
|
case ContentTypeText:
|
||||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||||
Type: contentTypeText,
|
|
||||||
Text: fmt.Sprintf("%v", v.Content),
|
Text: fmt.Sprintf("%v", v.Content),
|
||||||
Annotations: &Annotations{
|
Annotations: &Annotations{
|
||||||
Audience: []roleType{roleUser, roleAssistant},
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
case contentTypeImage:
|
case ContentTypeImage:
|
||||||
if imgData, ok := v.Content.(map[string]any); ok {
|
if imgData, ok := v.Content.(map[string]any); ok {
|
||||||
callToolResult.Content = append(callToolResult.Content, ImageContent{
|
callToolResult.Content = append(callToolResult.Content, ImageContent{
|
||||||
Type: contentTypeImage,
|
|
||||||
Data: fmt.Sprintf("%v", imgData["data"]),
|
Data: fmt.Sprintf("%v", imgData["data"]),
|
||||||
MimeType: fmt.Sprintf("%v", imgData["mimeType"]),
|
MimeType: fmt.Sprintf("%v", imgData["mimeType"]),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||||
Type: contentTypeText,
|
|
||||||
Text: fmt.Sprintf("%v", v.Content),
|
Text: fmt.Sprintf("%v", v.Content),
|
||||||
Annotations: &Annotations{
|
Annotations: &Annotations{
|
||||||
Audience: []roleType{roleUser, roleAssistant},
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
// For any other type, convert to string
|
// For any other type, convert to string
|
||||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||||
Type: contentTypeText,
|
|
||||||
Text: fmt.Sprintf("%v", v),
|
Text: fmt.Sprintf("%v", v),
|
||||||
Annotations: &Annotations{
|
Annotations: &Annotations{
|
||||||
Audience: []roleType{roleUser, roleAssistant},
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
callToolResult.Content = toTypedContents(callToolResult.Content)
|
||||||
logx.Infof("Tool call result: %#v", callToolResult)
|
logx.Infof("Tool call result: %#v", callToolResult)
|
||||||
s.sendResponse(client, req.ID, callToolResult)
|
|
||||||
|
s.sendResponse(ctx, client, req.ID, callToolResult)
|
||||||
}
|
}
|
||||||
|
|
||||||
// processResourcesRead processes the resources/read request
|
// processResourcesRead processes the resources/read request
|
||||||
func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
|
func (s *sseMcpServer) processResourcesRead(ctx context.Context, client *mcpClient, req Request) {
|
||||||
var params ResourceReadParams
|
var params ResourceReadParams
|
||||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||||
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -821,7 +775,7 @@ func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
|
|||||||
s.resourcesLock.Unlock()
|
s.resourcesLock.Unlock()
|
||||||
|
|
||||||
if !exists {
|
if !exists {
|
||||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||||
params.URI), errCodeResourceNotFound)
|
params.URI), errCodeResourceNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -837,14 +791,14 @@ func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
s.sendResponse(client, req.ID, result)
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute the resource handler
|
// Execute the resource handler
|
||||||
content, err := resource.Handler()
|
content, err := resource.Handler(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Error reading resource: %v", err),
|
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Error reading resource: %v", err),
|
||||||
errCodeInternalError)
|
errCodeInternalError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -865,14 +819,14 @@ func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
|
|||||||
Contents: []ResourceContent{content},
|
Contents: []ResourceContent{content},
|
||||||
}
|
}
|
||||||
|
|
||||||
s.sendResponse(client, req.ID, result)
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
// processResourceSubscribe processes the resources/subscribe request
|
// processResourceSubscribe processes the resources/subscribe request
|
||||||
func (s *sseMcpServer) processResourceSubscribe(client *mcpClient, req Request) {
|
func (s *sseMcpServer) processResourceSubscribe(ctx context.Context, client *mcpClient, req Request) {
|
||||||
var params ResourceSubscribeParams
|
var params ResourceSubscribeParams
|
||||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||||
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -882,19 +836,17 @@ func (s *sseMcpServer) processResourceSubscribe(client *mcpClient, req Request)
|
|||||||
s.resourcesLock.Unlock()
|
s.resourcesLock.Unlock()
|
||||||
|
|
||||||
if !exists {
|
if !exists {
|
||||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||||
params.URI), errCodeResourceNotFound)
|
params.URI), errCodeResourceNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send success response for the subscription
|
// Send success response for the subscription
|
||||||
s.sendResponse(client, req.ID, struct{}{})
|
s.sendResponse(ctx, client, req.ID, struct{}{})
|
||||||
|
|
||||||
logx.Infof("Client %s subscribed to resource '%s'", client.id, params.URI)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// processNotificationCancelled processes the notifications/cancelled notification
|
// processNotificationCancelled processes the notifications/cancelled notification
|
||||||
func (s *sseMcpServer) processNotificationCancelled(client *mcpClient, req Request) {
|
func (s *sseMcpServer) processNotificationCancelled(ctx context.Context, client *mcpClient, req Request) {
|
||||||
// Extract the requestId that was canceled
|
// Extract the requestId that was canceled
|
||||||
type CancelParams struct {
|
type CancelParams struct {
|
||||||
RequestId int64 `json:"requestId"`
|
RequestId int64 `json:"requestId"`
|
||||||
@@ -918,18 +870,17 @@ func (s *sseMcpServer) processNotificationInitialized(client *mcpClient) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// processPing processes the ping request and responds immediately
|
// processPing processes the ping request and responds immediately
|
||||||
func (s *sseMcpServer) processPing(client *mcpClient, req Request) {
|
func (s *sseMcpServer) processPing(ctx context.Context, client *mcpClient, req Request) {
|
||||||
// A ping request should simply respond with an empty result to confirm the server is alive
|
// A ping request should simply respond with an empty result to confirm the server is alive
|
||||||
logx.Infof("Received ping request with ID: %d", req.ID)
|
logx.Infof("Received ping request with ID: %d", req.ID)
|
||||||
|
|
||||||
// Send an empty response with client's original request ID
|
// Send an empty response with client's original request ID
|
||||||
s.sendResponse(client, req.ID, struct{}{})
|
s.sendResponse(ctx, client, req.ID, struct{}{})
|
||||||
|
|
||||||
logx.Infof("Sent ping response for ID: %d", req.ID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendErrorResponse sends an error response via the SSE channel
|
// sendErrorResponse sends an error response via the SSE channel
|
||||||
func (s *sseMcpServer) sendErrorResponse(client *mcpClient, id int64, message string, code int) {
|
func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
|
||||||
|
id int64, message string, code int) {
|
||||||
errorResponse := struct {
|
errorResponse := struct {
|
||||||
JsonRpc string `json:"jsonrpc"`
|
JsonRpc string `json:"jsonrpc"`
|
||||||
ID int64 `json:"id"`
|
ID int64 `json:"id"`
|
||||||
@@ -949,11 +900,17 @@ func (s *sseMcpServer) sendErrorResponse(client *mcpClient, id int64, message st
|
|||||||
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||||
logx.Infof("Sending error for ID %d: %s", id, sseMessage)
|
logx.Infof("Sending error for ID %d: %s", id, sseMessage)
|
||||||
|
|
||||||
client.channel <- sseMessage
|
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
|
||||||
|
select {
|
||||||
|
case client.channel <- sseMessage:
|
||||||
|
default:
|
||||||
|
// Channel buffer is full, log warning and continue
|
||||||
|
logx.Infof("Client %s channel is full while sending error response with ID %d", client.id, id)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendResponse sends a success response via the SSE channel
|
// sendResponse sends a success response via the SSE channel
|
||||||
func (s *sseMcpServer) sendResponse(client *mcpClient, id int64, result any) {
|
func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id int64, result any) {
|
||||||
response := Response{
|
response := Response{
|
||||||
JsonRpc: jsonRpcVersion,
|
JsonRpc: jsonRpcVersion,
|
||||||
ID: id,
|
ID: id,
|
||||||
@@ -962,7 +919,7 @@ func (s *sseMcpServer) sendResponse(client *mcpClient, id int64, result any) {
|
|||||||
|
|
||||||
jsonData, err := json.Marshal(response)
|
jsonData, err := json.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.sendErrorResponse(client, id, "Failed to marshal response", errCodeInternalError)
|
s.sendErrorResponse(ctx, client, id, "Failed to marshal response", errCodeInternalError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -970,5 +927,11 @@ func (s *sseMcpServer) sendResponse(client *mcpClient, id int64, result any) {
|
|||||||
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||||
logx.Infof("Sending response for ID %d: %s", id, sseMessage)
|
logx.Infof("Sending response for ID %d: %s", id, sseMessage)
|
||||||
|
|
||||||
client.channel <- sseMessage
|
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
|
||||||
|
select {
|
||||||
|
case client.channel <- sseMessage:
|
||||||
|
default:
|
||||||
|
// Channel buffer is full, log warning and continue
|
||||||
|
logx.Infof("Client %s channel is full while sending response with ID %d", client.id, id)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ host: localhost
|
|||||||
port: 8080
|
port: 8080
|
||||||
mcp:
|
mcp:
|
||||||
name: mcp-test-server
|
name: mcp-test-server
|
||||||
toolTimeout: 5s
|
messageTimeout: 5s
|
||||||
`
|
`
|
||||||
|
|
||||||
var c McpConf
|
var c McpConf
|
||||||
@@ -82,7 +82,6 @@ func (m *mockMcpServer) registerExampleTool() {
|
|||||||
Name: "test.tool",
|
Name: "test.tool",
|
||||||
Description: "A test tool",
|
Description: "A test tool",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]any{
|
Properties: map[string]any{
|
||||||
"input": map[string]any{
|
"input": map[string]any{
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -91,7 +90,7 @@ func (m *mockMcpServer) registerExampleTool() {
|
|||||||
},
|
},
|
||||||
Required: []string{"input"},
|
Required: []string{"input"},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
input, ok := params["input"].(string)
|
input, ok := params["input"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid input parameter")
|
return nil, fmt.Errorf("invalid input parameter")
|
||||||
@@ -135,7 +134,7 @@ port: 8080
|
|||||||
mcp:
|
mcp:
|
||||||
cors:
|
cors:
|
||||||
- http://localhost:3000
|
- http://localhost:3000
|
||||||
toolTimeout: 5s
|
messageTimeout: 5s
|
||||||
`
|
`
|
||||||
|
|
||||||
var c McpConf
|
var c McpConf
|
||||||
@@ -186,7 +185,6 @@ func TestRegisterTool(t *testing.T) {
|
|||||||
Name: "example.tool",
|
Name: "example.tool",
|
||||||
Description: "An example tool",
|
Description: "An example tool",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]any{
|
Properties: map[string]any{
|
||||||
"input": map[string]any{
|
"input": map[string]any{
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -194,7 +192,7 @@ func TestRegisterTool(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
return "result", nil
|
return "result", nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -280,7 +278,7 @@ func TestToolsList(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call
|
// Process the tool call
|
||||||
mock.server.processListTools(client, req)
|
mock.server.processListTools(context.Background(), client, req)
|
||||||
|
|
||||||
// Get the response from the client's channel
|
// Get the response from the client's channel
|
||||||
select {
|
select {
|
||||||
@@ -328,7 +326,7 @@ func TestToolCallBasic(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call
|
// Process the tool call
|
||||||
mock.server.processToolCall(client, req)
|
mock.server.processToolCall(context.Background(), client, req)
|
||||||
|
|
||||||
// Get the response from the client's channel
|
// Get the response from the client's channel
|
||||||
select {
|
select {
|
||||||
@@ -355,8 +353,7 @@ func TestToolCallBasic(t *testing.T) {
|
|||||||
|
|
||||||
// Verify the response content
|
// Verify the response content
|
||||||
assert.Len(t, parsed.Result.Content, 1, "Response should contain one content item")
|
assert.Len(t, parsed.Result.Content, 1, "Response should contain one content item")
|
||||||
assert.Equal(t, "text", parsed.Result.Content[0]["type"], "Content type should be text")
|
assert.Equal(t, "Processed: test-input", parsed.Result.Content[0][ContentTypeText], "Tool result incorrect")
|
||||||
assert.Equal(t, "Processed: test-input", parsed.Result.Content[0]["text"], "Tool result incorrect")
|
|
||||||
assert.False(t, parsed.Result.IsError, "Response should not be an error")
|
assert.False(t, parsed.Result.IsError, "Response should not be an error")
|
||||||
case <-time.After(100 * time.Millisecond):
|
case <-time.After(100 * time.Millisecond):
|
||||||
t.Fatal("Timed out waiting for tool call response")
|
t.Fatal("Timed out waiting for tool call response")
|
||||||
@@ -373,10 +370,9 @@ func TestToolCallMapResult(t *testing.T) {
|
|||||||
Name: "map.tool",
|
Name: "map.tool",
|
||||||
Description: "A tool that returns a map result",
|
Description: "A tool that returns a map result",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]any{},
|
Properties: map[string]any{},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
// Return a complex nested map structure
|
// Return a complex nested map structure
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
"string": "value",
|
"string": "value",
|
||||||
@@ -417,7 +413,7 @@ func TestToolCallMapResult(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call
|
// Process the tool call
|
||||||
mock.server.processToolCall(client, req)
|
mock.server.processToolCall(context.Background(), client, req)
|
||||||
|
|
||||||
// Get the response from the client's channel
|
// Get the response from the client's channel
|
||||||
select {
|
select {
|
||||||
@@ -445,13 +441,8 @@ func TestToolCallMapResult(t *testing.T) {
|
|||||||
firstItem, ok := content[0].(map[string]any)
|
firstItem, ok := content[0].(map[string]any)
|
||||||
require.True(t, ok, "First content item should be an object")
|
require.True(t, ok, "First content item should be an object")
|
||||||
|
|
||||||
// Verify it's a text content
|
|
||||||
contentType, ok := firstItem["type"].(string)
|
|
||||||
require.True(t, ok, "Content should have a type")
|
|
||||||
assert.Equal(t, "text", contentType, "Content type should be text")
|
|
||||||
|
|
||||||
// Get the text content which should be our JSON
|
// Get the text content which should be our JSON
|
||||||
text, ok := firstItem["text"].(string)
|
text, ok := firstItem[ContentTypeText].(string)
|
||||||
require.True(t, ok, "Content should have text")
|
require.True(t, ok, "Content should have text")
|
||||||
|
|
||||||
// Verify the text is valid JSON and contains our data
|
// Verify the text is valid JSON and contains our data
|
||||||
@@ -496,10 +487,9 @@ func TestToolCallArrayResult(t *testing.T) {
|
|||||||
Name: "array.tool",
|
Name: "array.tool",
|
||||||
Description: "A tool that returns an array result",
|
Description: "A tool that returns an array result",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]any{},
|
Properties: map[string]any{},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
// Return an array of mixed content types
|
// Return an array of mixed content types
|
||||||
return []any{
|
return []any{
|
||||||
"string item",
|
"string item",
|
||||||
@@ -536,7 +526,7 @@ func TestToolCallArrayResult(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call
|
// Process the tool call
|
||||||
mock.server.processToolCall(client, req)
|
mock.server.processToolCall(context.Background(), client, req)
|
||||||
|
|
||||||
// Get the response from the client's channel
|
// Get the response from the client's channel
|
||||||
select {
|
select {
|
||||||
@@ -574,16 +564,14 @@ func TestToolCallTextContentResult(t *testing.T) {
|
|||||||
Name: "text.content.tool",
|
Name: "text.content.tool",
|
||||||
Description: "A tool that returns a TextContent result",
|
Description: "A tool that returns a TextContent result",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]any{},
|
Properties: map[string]any{},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
// Return a TextContent object directly
|
// Return a TextContent object directly
|
||||||
return TextContent{
|
return TextContent{
|
||||||
Type: "text",
|
|
||||||
Text: "This is a direct TextContent result",
|
Text: "This is a direct TextContent result",
|
||||||
Annotations: &Annotations{
|
Annotations: &Annotations{
|
||||||
Audience: []roleType{roleUser, roleAssistant},
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
Priority: func() *float64 { p := 0.9; return &p }(),
|
Priority: func() *float64 { p := 0.9; return &p }(),
|
||||||
},
|
},
|
||||||
}, nil
|
}, nil
|
||||||
@@ -614,7 +602,7 @@ func TestToolCallTextContentResult(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call
|
// Process the tool call
|
||||||
mock.server.processToolCall(client, req)
|
mock.server.processToolCall(context.Background(), client, req)
|
||||||
|
|
||||||
// Get the response from the client's channel
|
// Get the response from the client's channel
|
||||||
select {
|
select {
|
||||||
@@ -642,16 +630,6 @@ func TestToolCallTextContentResult(t *testing.T) {
|
|||||||
firstItem, ok := content[0].(map[string]any)
|
firstItem, ok := content[0].(map[string]any)
|
||||||
require.True(t, ok, "First content item should be an object")
|
require.True(t, ok, "First content item should be an object")
|
||||||
|
|
||||||
// Verify it's a text content
|
|
||||||
contentType, ok := firstItem["type"].(string)
|
|
||||||
require.True(t, ok, "Content should have a type")
|
|
||||||
assert.Equal(t, "text", contentType, "Content type should be text")
|
|
||||||
|
|
||||||
// Check text content
|
|
||||||
text, ok := firstItem["text"].(string)
|
|
||||||
require.True(t, ok, "Content should have text")
|
|
||||||
assert.Equal(t, "This is a direct TextContent result", text, "Text content should match")
|
|
||||||
|
|
||||||
// Check annotations
|
// Check annotations
|
||||||
annotations, ok := firstItem["annotations"].(map[string]any)
|
annotations, ok := firstItem["annotations"].(map[string]any)
|
||||||
require.True(t, ok, "Should have annotations")
|
require.True(t, ok, "Should have annotations")
|
||||||
@@ -679,13 +657,11 @@ func TestToolCallImageContentResult(t *testing.T) {
|
|||||||
Name: "image.content.tool",
|
Name: "image.content.tool",
|
||||||
Description: "A tool that returns an ImageContent result",
|
Description: "A tool that returns an ImageContent result",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]any{},
|
Properties: map[string]any{},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
// Return an ImageContent object directly
|
// Return an ImageContent object directly
|
||||||
return ImageContent{
|
return ImageContent{
|
||||||
Type: "image",
|
|
||||||
Data: "dGVzdCBpbWFnZSBkYXRhIChiYXNlNjQgZW5jb2RlZCk=", // "test image data (base64 encoded)" in base64
|
Data: "dGVzdCBpbWFnZSBkYXRhIChiYXNlNjQgZW5jb2RlZCk=", // "test image data (base64 encoded)" in base64
|
||||||
MimeType: "image/png",
|
MimeType: "image/png",
|
||||||
}, nil
|
}, nil
|
||||||
@@ -716,7 +692,7 @@ func TestToolCallImageContentResult(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call
|
// Process the tool call
|
||||||
mock.server.processToolCall(client, req)
|
mock.server.processToolCall(context.Background(), client, req)
|
||||||
|
|
||||||
// Get the response from the client's channel
|
// Get the response from the client's channel
|
||||||
select {
|
select {
|
||||||
@@ -744,11 +720,6 @@ func TestToolCallImageContentResult(t *testing.T) {
|
|||||||
firstItem, ok := content[0].(map[string]any)
|
firstItem, ok := content[0].(map[string]any)
|
||||||
require.True(t, ok, "First content item should be an object")
|
require.True(t, ok, "First content item should be an object")
|
||||||
|
|
||||||
// Verify it's an image content
|
|
||||||
contentType, ok := firstItem["type"].(string)
|
|
||||||
require.True(t, ok, "Content should have a type")
|
|
||||||
assert.Equal(t, "image", contentType, "Content type should be image")
|
|
||||||
|
|
||||||
// Check image data
|
// Check image data
|
||||||
data, ok := firstItem["data"].(string)
|
data, ok := firstItem["data"].(string)
|
||||||
require.True(t, ok, "Content should have data")
|
require.True(t, ok, "Content should have data")
|
||||||
@@ -773,12 +744,12 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
Name: "toolresult.tool",
|
Name: "toolresult.tool",
|
||||||
Description: "A tool that returns a ToolResult object",
|
Description: "A tool that returns a ToolResult object",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
Type: ContentTypeObject,
|
||||||
Properties: map[string]any{},
|
Properties: map[string]any{},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
return ToolResult{
|
return ToolResult{
|
||||||
Type: "text",
|
Type: ContentTypeText,
|
||||||
Content: "This is a ToolResult with text content type",
|
Content: "This is a ToolResult with text content type",
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
@@ -790,10 +761,10 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
Name: "toolresult.image.tool",
|
Name: "toolresult.image.tool",
|
||||||
Description: "A tool that returns a ToolResult with image content",
|
Description: "A tool that returns a ToolResult with image content",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
Type: ContentTypeObject,
|
||||||
Properties: map[string]any{},
|
Properties: map[string]any{},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
return ToolResult{
|
return ToolResult{
|
||||||
Type: "image",
|
Type: "image",
|
||||||
Content: map[string]any{
|
Content: map[string]any{
|
||||||
@@ -810,10 +781,9 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
Name: "toolresult.audio.tool",
|
Name: "toolresult.audio.tool",
|
||||||
Description: "A tool that returns a ToolResult with audio content",
|
Description: "A tool that returns a ToolResult with audio content",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]any{},
|
Properties: map[string]any{},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
// Test with image type
|
// Test with image type
|
||||||
return ToolResult{
|
return ToolResult{
|
||||||
Type: "audio",
|
Type: "audio",
|
||||||
@@ -831,10 +801,9 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
Name: "toolresult.int.tool",
|
Name: "toolresult.int.tool",
|
||||||
Description: "A tool that returns a ToolResult with int content",
|
Description: "A tool that returns a ToolResult with int content",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]any{},
|
Properties: map[string]any{},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
return 2, nil
|
return 2, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -845,10 +814,9 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
Name: "toolresult.bad.tool",
|
Name: "toolresult.bad.tool",
|
||||||
Description: "A tool that returns a ToolResult with bad content",
|
Description: "A tool that returns a ToolResult with bad content",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]any{},
|
Properties: map[string]any{},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
"type": "custom",
|
"type": "custom",
|
||||||
"data": make(chan int),
|
"data": make(chan int),
|
||||||
@@ -881,7 +849,7 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call
|
// Process the tool call
|
||||||
mock.server.processToolCall(client, req)
|
mock.server.processToolCall(context.Background(), client, req)
|
||||||
|
|
||||||
// Get the response from the client's channel
|
// Get the response from the client's channel
|
||||||
select {
|
select {
|
||||||
@@ -909,13 +877,8 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
firstItem, ok := content[0].(map[string]any)
|
firstItem, ok := content[0].(map[string]any)
|
||||||
require.True(t, ok, "First content item should be an object")
|
require.True(t, ok, "First content item should be an object")
|
||||||
|
|
||||||
// Verify it's a text content
|
|
||||||
contentType, ok := firstItem["type"].(string)
|
|
||||||
require.True(t, ok, "Content should have a type")
|
|
||||||
assert.Equal(t, "text", contentType, "Content type should be text")
|
|
||||||
|
|
||||||
// Check text content
|
// Check text content
|
||||||
text, ok := firstItem["text"].(string)
|
text, ok := firstItem[ContentTypeText].(string)
|
||||||
require.True(t, ok, "Content should have text")
|
require.True(t, ok, "Content should have text")
|
||||||
assert.Equal(t, "This is a ToolResult with text content type", text, "Text content should match")
|
assert.Equal(t, "This is a ToolResult with text content type", text, "Text content should match")
|
||||||
|
|
||||||
@@ -947,7 +910,7 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call
|
// Process the tool call
|
||||||
mock.server.processToolCall(client, req)
|
mock.server.processToolCall(context.Background(), client, req)
|
||||||
|
|
||||||
// Get the response from the client's channel
|
// Get the response from the client's channel
|
||||||
select {
|
select {
|
||||||
@@ -975,11 +938,6 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
firstItem, ok := content[0].(map[string]any)
|
firstItem, ok := content[0].(map[string]any)
|
||||||
require.True(t, ok, "First content item should be an object")
|
require.True(t, ok, "First content item should be an object")
|
||||||
|
|
||||||
// Verify it's an image content
|
|
||||||
contentType, ok := firstItem["type"].(string)
|
|
||||||
require.True(t, ok, "Content should have a type")
|
|
||||||
assert.Equal(t, "image", contentType, "Content type should be image")
|
|
||||||
|
|
||||||
// Check image data and mime type
|
// Check image data and mime type
|
||||||
data, ok := firstItem["data"].(string)
|
data, ok := firstItem["data"].(string)
|
||||||
require.True(t, ok, "Content should have data")
|
require.True(t, ok, "Content should have data")
|
||||||
@@ -1017,7 +975,7 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call
|
// Process the tool call
|
||||||
mock.server.processToolCall(client, req)
|
mock.server.processToolCall(context.Background(), client, req)
|
||||||
|
|
||||||
// Get the response from the client's channel
|
// Get the response from the client's channel
|
||||||
select {
|
select {
|
||||||
@@ -1040,15 +998,6 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
content, ok := result["content"].([]any)
|
content, ok := result["content"].([]any)
|
||||||
require.True(t, ok, "Result should have a content array")
|
require.True(t, ok, "Result should have a content array")
|
||||||
require.NotEmpty(t, content, "Content should not be empty")
|
require.NotEmpty(t, content, "Content should not be empty")
|
||||||
|
|
||||||
// The first content item should be converted from ToolResult to ImageContent
|
|
||||||
firstItem, ok := content[0].(map[string]any)
|
|
||||||
require.True(t, ok, "First content item should be an object")
|
|
||||||
|
|
||||||
// Verify it's an image content
|
|
||||||
contentType, ok := firstItem["type"].(string)
|
|
||||||
require.True(t, ok, "Content should have a type")
|
|
||||||
assert.Equal(t, "text", contentType, "Content type should be image")
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
case <-time.After(100 * time.Millisecond):
|
||||||
t.Fatal("Timed out waiting for tool call response")
|
t.Fatal("Timed out waiting for tool call response")
|
||||||
}
|
}
|
||||||
@@ -1077,7 +1026,7 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call
|
// Process the tool call
|
||||||
mock.server.processToolCall(client, req)
|
mock.server.processToolCall(context.Background(), client, req)
|
||||||
|
|
||||||
// Get the response from the client's channel
|
// Get the response from the client's channel
|
||||||
select {
|
select {
|
||||||
@@ -1100,15 +1049,6 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
content, ok := result["content"].([]any)
|
content, ok := result["content"].([]any)
|
||||||
require.True(t, ok, "Result should have a content array")
|
require.True(t, ok, "Result should have a content array")
|
||||||
require.NotEmpty(t, content, "Content should not be empty")
|
require.NotEmpty(t, content, "Content should not be empty")
|
||||||
|
|
||||||
// The first content item should be converted from ToolResult to ImageContent
|
|
||||||
firstItem, ok := content[0].(map[string]any)
|
|
||||||
require.True(t, ok, "First content item should be an object")
|
|
||||||
|
|
||||||
// Verify it's an image content
|
|
||||||
contentType, ok := firstItem["type"].(string)
|
|
||||||
require.True(t, ok, "Content should have a type")
|
|
||||||
assert.Equal(t, "text", contentType, "Content type should be image")
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
case <-time.After(100 * time.Millisecond):
|
||||||
t.Fatal("Timed out waiting for tool call response")
|
t.Fatal("Timed out waiting for tool call response")
|
||||||
}
|
}
|
||||||
@@ -1137,7 +1077,7 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call
|
// Process the tool call
|
||||||
mock.server.processToolCall(client, req)
|
mock.server.processToolCall(context.Background(), client, req)
|
||||||
|
|
||||||
// Get the response from the client's channel
|
// Get the response from the client's channel
|
||||||
select {
|
select {
|
||||||
@@ -1159,10 +1099,9 @@ func TestToolCallError(t *testing.T) {
|
|||||||
Name: "error.tool",
|
Name: "error.tool",
|
||||||
Description: "A tool that returns an error",
|
Description: "A tool that returns an error",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]any{},
|
Properties: map[string]any{},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
return nil, fmt.Errorf("tool execution failed")
|
return nil, fmt.Errorf("tool execution failed")
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -1189,7 +1128,7 @@ func TestToolCallError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call
|
// Process the tool call
|
||||||
mock.server.processToolCall(client, req)
|
mock.server.processToolCall(context.Background(), client, req)
|
||||||
|
|
||||||
// Check the response
|
// Check the response
|
||||||
select {
|
select {
|
||||||
@@ -1207,20 +1146,16 @@ func TestToolCallTimeout(t *testing.T) {
|
|||||||
mock := newMockMcpServer(t)
|
mock := newMockMcpServer(t)
|
||||||
defer mock.shutdown()
|
defer mock.shutdown()
|
||||||
|
|
||||||
// Set a very short timeout for testing
|
|
||||||
mock.server.conf.Mcp.ToolTimeout = 10 * time.Millisecond
|
|
||||||
|
|
||||||
// Register a tool that times out
|
// Register a tool that times out
|
||||||
err := mock.server.RegisterTool(Tool{
|
err := mock.server.RegisterTool(Tool{
|
||||||
Name: "timeout.tool",
|
Name: "timeout.tool",
|
||||||
Description: "A tool that times out",
|
Description: "A tool that times out",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]any{},
|
Properties: map[string]any{},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
time.Sleep(50 * time.Millisecond) // Sleep longer than timeout
|
<-ctx.Done()
|
||||||
return "this should never be returned", nil
|
return nil, fmt.Errorf("tool execution timed out")
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1244,16 +1179,24 @@ func TestToolCallTimeout(t *testing.T) {
|
|||||||
Method: methodToolsCall,
|
Method: methodToolsCall,
|
||||||
Params: paramBytes,
|
Params: paramBytes,
|
||||||
}
|
}
|
||||||
|
jsonBody, _ := json.Marshal(req)
|
||||||
|
|
||||||
// Process the tool call
|
// Create HTTP request
|
||||||
mock.server.processToolCall(client, req)
|
r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client", bytes.NewReader(jsonBody))
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Process through handleRequest
|
||||||
|
go mock.server.handleRequest(w, r)
|
||||||
|
|
||||||
// Check the response
|
// Check the response
|
||||||
select {
|
select {
|
||||||
case response := <-client.channel:
|
case response := <-client.channel:
|
||||||
assert.Contains(t, response, "event: message", "Response should have message event")
|
assert.Contains(t, response, "event: message", "Response should have message event")
|
||||||
assert.Contains(t, response, `-32001`, "Response should contain a timeout error code")
|
assert.Contains(t, response, `-32001`, "Response should contain a timeout error code")
|
||||||
case <-time.After(150 * time.Millisecond):
|
case <-time.After(100 * time.Millisecond):
|
||||||
t.Fatal("Timed out waiting for tool call response")
|
t.Fatal("Timed out waiting for tool call response")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1274,7 +1217,7 @@ func TestInitializeAndNotifications(t *testing.T) {
|
|||||||
Params: json.RawMessage(`{}`),
|
Params: json.RawMessage(`{}`),
|
||||||
}
|
}
|
||||||
|
|
||||||
mock.server.processInitialize(client, initReq)
|
mock.server.processInitialize(context.Background(), client, initReq)
|
||||||
|
|
||||||
// Check that client is initialized after initialize request
|
// Check that client is initialized after initialize request
|
||||||
assert.True(t, client.initialized, "Client should be marked as initialized after initialize request")
|
assert.True(t, client.initialized, "Client should be marked as initialized after initialize request")
|
||||||
@@ -1418,7 +1361,7 @@ func TestNotificationCancelled_badParams(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
buf := logtest.NewCollector(t)
|
buf := logtest.NewCollector(t)
|
||||||
mock.server.processNotificationCancelled(client, cancelReq)
|
mock.server.processNotificationCancelled(context.Background(), client, cancelReq)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-client.channel:
|
case <-client.channel:
|
||||||
@@ -1593,7 +1536,7 @@ func TestGetPrompt(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the request
|
// Process the request
|
||||||
mock.server.processGetPrompt(client, promptReq)
|
mock.server.processGetPrompt(context.Background(), client, promptReq)
|
||||||
|
|
||||||
// Check response
|
// Check response
|
||||||
select {
|
select {
|
||||||
@@ -1622,7 +1565,7 @@ func TestGetPrompt(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the request
|
// Process the request
|
||||||
mock.server.processGetPrompt(client, promptReq)
|
mock.server.processGetPrompt(context.Background(), client, promptReq)
|
||||||
|
|
||||||
// Check response
|
// Check response
|
||||||
select {
|
select {
|
||||||
@@ -1636,6 +1579,44 @@ func TestGetPrompt(t *testing.T) {
|
|||||||
t.Fatal("Timed out waiting for prompt response")
|
t.Fatal("Timed out waiting for prompt response")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("test prompt with nil params", func(t *testing.T) {
|
||||||
|
mock := newMockMcpServer(t)
|
||||||
|
defer mock.shutdown()
|
||||||
|
|
||||||
|
// Create a test client
|
||||||
|
client := addTestClient(mock.server, "test-client", true)
|
||||||
|
|
||||||
|
// Register a test prompt
|
||||||
|
testPrompt := Prompt{
|
||||||
|
Name: "test.prompt",
|
||||||
|
Description: "A test prompt",
|
||||||
|
}
|
||||||
|
mock.server.RegisterPrompt(testPrompt)
|
||||||
|
|
||||||
|
// Create a get prompt request
|
||||||
|
paramBytes, _ := json.Marshal(map[string]any{
|
||||||
|
"name": "test.prompt",
|
||||||
|
})
|
||||||
|
promptReq := Request{
|
||||||
|
JsonRpc: jsonRpcVersion,
|
||||||
|
ID: 1,
|
||||||
|
Method: "prompts/get",
|
||||||
|
Params: paramBytes,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the request
|
||||||
|
mock.server.processGetPrompt(context.Background(), client, promptReq)
|
||||||
|
|
||||||
|
// Check response
|
||||||
|
select {
|
||||||
|
case response := <-client.channel:
|
||||||
|
_, err := parseEvent(response)
|
||||||
|
assert.NoError(t, err, "Should be able to parse event")
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for prompt response")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestBroadcast tests the broadcast functionality
|
// TestBroadcast tests the broadcast functionality
|
||||||
@@ -1903,34 +1884,79 @@ func TestNotificationInitialized(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSendBadResponse(t *testing.T) {
|
func TestSendResponse(t *testing.T) {
|
||||||
mock := newMockMcpServer(t)
|
t.Run("bad response", func(t *testing.T) {
|
||||||
defer mock.shutdown()
|
mock := newMockMcpServer(t)
|
||||||
|
defer mock.shutdown()
|
||||||
|
|
||||||
// Create a test client
|
// Create a test client
|
||||||
client := addTestClient(mock.server, "test-client", true)
|
client := addTestClient(mock.server, "test-client", true)
|
||||||
|
|
||||||
// Create a response
|
// Create a response
|
||||||
response := Response{
|
response := Response{
|
||||||
JsonRpc: jsonRpcVersion,
|
JsonRpc: jsonRpcVersion,
|
||||||
ID: 1,
|
ID: 1,
|
||||||
Result: make(chan int),
|
Result: make(chan int),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send the response
|
// Send the response
|
||||||
mock.server.sendResponse(client, 1, response)
|
mock.server.sendResponse(context.Background(), client, 1, response)
|
||||||
|
|
||||||
// Check the response in the client's channel
|
// Check the response in the client's channel
|
||||||
select {
|
select {
|
||||||
case res := <-client.channel:
|
case res := <-client.channel:
|
||||||
evt, err := parseEvent(res)
|
evt, err := parseEvent(res)
|
||||||
require.NoError(t, err, "Should parse event without error")
|
require.NoError(t, err, "Should parse event without error")
|
||||||
errMsg, ok := evt.Data["error"].(map[string]any)
|
errMsg, ok := evt.Data["error"].(map[string]any)
|
||||||
require.True(t, ok, "Should have error in response")
|
require.True(t, ok, "Should have error in response")
|
||||||
assert.Equal(t, float64(errCodeInternalError), errMsg["code"], "Error code should match")
|
assert.Equal(t, float64(errCodeInternalError), errMsg["code"], "Error code should match")
|
||||||
case <-time.After(100 * time.Millisecond):
|
case <-time.After(100 * time.Millisecond):
|
||||||
t.Fatal("Timed out waiting for response")
|
t.Fatal("Timed out waiting for response")
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("channel full", func(t *testing.T) {
|
||||||
|
mock := newMockMcpServer(t)
|
||||||
|
defer mock.shutdown()
|
||||||
|
|
||||||
|
// Create a test client
|
||||||
|
client := addTestClient(mock.server, "test-client", true)
|
||||||
|
for i := 0; i < eventChanSize; i++ {
|
||||||
|
client.channel <- "test"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a response
|
||||||
|
response := Response{
|
||||||
|
JsonRpc: jsonRpcVersion,
|
||||||
|
ID: 1,
|
||||||
|
Result: "foo",
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := logtest.NewCollector(t)
|
||||||
|
// Send the response
|
||||||
|
mock.server.sendResponse(context.Background(), client, 1, response)
|
||||||
|
// Check the response in the client's channel
|
||||||
|
assert.Contains(t, buf.String(), "channel is full")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendErrorResponse(t *testing.T) {
|
||||||
|
t.Run("channel full", func(t *testing.T) {
|
||||||
|
mock := newMockMcpServer(t)
|
||||||
|
defer mock.shutdown()
|
||||||
|
|
||||||
|
// Create a test client
|
||||||
|
client := addTestClient(mock.server, "test-client", true)
|
||||||
|
for i := 0; i < eventChanSize; i++ {
|
||||||
|
client.channel <- "test"
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := logtest.NewCollector(t)
|
||||||
|
// Send the response
|
||||||
|
mock.server.sendErrorResponse(context.Background(), client, 1, "foo", errCodeInternalError)
|
||||||
|
// Check the response in the client's channel
|
||||||
|
assert.Contains(t, buf.String(), "channel is full")
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestMethodToolsCall tests the handling of tools/call method through handleRequest
|
// TestMethodToolsCall tests the handling of tools/call method through handleRequest
|
||||||
@@ -2028,8 +2054,7 @@ func TestMethodToolsCall(t *testing.T) {
|
|||||||
if len(content) > 0 {
|
if len(content) > 0 {
|
||||||
firstItem, ok := content[0].(map[string]any)
|
firstItem, ok := content[0].(map[string]any)
|
||||||
if ok {
|
if ok {
|
||||||
assert.Equal(t, "text", firstItem["type"], "Content type should be text")
|
assert.Contains(t, firstItem[ContentTypeText], "Processed: test-input", "Content should include processed input")
|
||||||
assert.Contains(t, firstItem["text"], "Processed: test-input", "Content should include processed input")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case <-time.After(100 * time.Millisecond):
|
case <-time.After(100 * time.Millisecond):
|
||||||
@@ -2145,7 +2170,6 @@ func TestMethodPromptsGet(t *testing.T) {
|
|||||||
{
|
{
|
||||||
Name: "topic",
|
Name: "topic",
|
||||||
Description: "Topic to discuss",
|
Description: "Topic to discuss",
|
||||||
Default: "artificial intelligence",
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Content: "Hello {{name}}! Let's talk about {{topic}}.",
|
Content: "Hello {{name}}! Let's talk about {{topic}}.",
|
||||||
@@ -2227,14 +2251,12 @@ func TestMethodPromptsGet(t *testing.T) {
|
|||||||
if len(messages) > 0 {
|
if len(messages) > 0 {
|
||||||
message, ok := messages[0].(map[string]any)
|
message, ok := messages[0].(map[string]any)
|
||||||
require.True(t, ok, "Message should be an object")
|
require.True(t, ok, "Message should be an object")
|
||||||
assert.Equal(t, "user", message["role"], "Role should be 'user'")
|
assert.Equal(t, string(RoleUser), message["role"], "Role should be 'user'")
|
||||||
|
|
||||||
content, ok := message["content"].(map[string]any)
|
content, ok := message["content"].(map[string]any)
|
||||||
require.True(t, ok, "Should have content object")
|
require.True(t, ok, "Should have content object")
|
||||||
assert.Equal(t, "text", content["type"], "Content type should be text")
|
assert.Equal(t, ContentTypeText, content["type"], "Content type should be text")
|
||||||
assert.Contains(t, content["text"], "Hello Test User", "Content should include the name argument")
|
assert.Contains(t, content[ContentTypeText], "Hello Test User", "Content should include the name argument")
|
||||||
assert.Contains(t, content["text"], "about artificial intelligence",
|
|
||||||
"Content should include the default topic argument")
|
|
||||||
}
|
}
|
||||||
case <-time.After(100 * time.Millisecond):
|
case <-time.After(100 * time.Millisecond):
|
||||||
t.Fatal("Timed out waiting for prompt get response")
|
t.Fatal("Timed out waiting for prompt get response")
|
||||||
@@ -2255,27 +2277,24 @@ func TestMethodPromptsGet(t *testing.T) {
|
|||||||
{
|
{
|
||||||
Name: "question",
|
Name: "question",
|
||||||
Description: "User's question",
|
Description: "User's question",
|
||||||
Default: "How does this work?",
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Handler: func(args map[string]string) ([]PromptMessage, error) {
|
Handler: func(ctx context.Context, args map[string]string) ([]PromptMessage, error) {
|
||||||
username := args["username"]
|
username := args["username"]
|
||||||
question := args["question"]
|
question := args["question"]
|
||||||
|
|
||||||
// Create a system message
|
// Create a system message
|
||||||
systemMessage := PromptMessage{
|
systemMessage := PromptMessage{
|
||||||
Role: "system",
|
Role: RoleAssistant,
|
||||||
Content: TextContent{
|
Content: TextContent{
|
||||||
Type: "text",
|
|
||||||
Text: "You are a helpful assistant.",
|
Text: "You are a helpful assistant.",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a user message
|
// Create a user message
|
||||||
userMessage := PromptMessage{
|
userMessage := PromptMessage{
|
||||||
Role: "user",
|
Role: RoleUser,
|
||||||
Content: TextContent{
|
Content: TextContent{
|
||||||
Type: "text",
|
|
||||||
Text: fmt.Sprintf("Hi, I'm %s and I'm wondering: %s", username, question),
|
Text: fmt.Sprintf("Hi, I'm %s and I'm wondering: %s", username, question),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -2340,20 +2359,20 @@ func TestMethodPromptsGet(t *testing.T) {
|
|||||||
|
|
||||||
// Check message content
|
// Check message content
|
||||||
if len(messages) >= 2 {
|
if len(messages) >= 2 {
|
||||||
// First message should be system
|
// First message should be assistant
|
||||||
message1, _ := messages[0].(map[string]any)
|
message1, _ := messages[0].(map[string]any)
|
||||||
assert.Equal(t, "system", message1["role"], "First role should be 'system'")
|
assert.Equal(t, string(RoleAssistant), message1["role"], "First role should be 'system'")
|
||||||
|
|
||||||
content1, _ := message1["content"].(map[string]any)
|
content1, _ := message1["content"].(map[string]any)
|
||||||
assert.Contains(t, content1["text"], "helpful assistant", "System message should be correct")
|
assert.Contains(t, content1[ContentTypeText], "helpful assistant", "System message should be correct")
|
||||||
|
|
||||||
// Second message should be user
|
// Second message should be user
|
||||||
message2, _ := messages[1].(map[string]any)
|
message2, _ := messages[1].(map[string]any)
|
||||||
assert.Equal(t, "user", message2["role"], "Second role should be 'user'")
|
assert.Equal(t, string(RoleUser), message2["role"], "Second role should be 'user'")
|
||||||
|
|
||||||
content2, _ := message2["content"].(map[string]any)
|
content2, _ := message2["content"].(map[string]any)
|
||||||
assert.Contains(t, content2["text"], "Dynamic User", "User message should contain username")
|
assert.Contains(t, content2[ContentTypeText], "Dynamic User", "User message should contain username")
|
||||||
assert.Contains(t, content2["text"], "How to test this?", "User message should contain question")
|
assert.Contains(t, content2[ContentTypeText], "How to test this?", "User message should contain question")
|
||||||
}
|
}
|
||||||
case <-time.After(100 * time.Millisecond):
|
case <-time.After(100 * time.Millisecond):
|
||||||
t.Fatal("Timed out waiting for prompt get response")
|
t.Fatal("Timed out waiting for prompt get response")
|
||||||
@@ -2459,7 +2478,7 @@ func TestMethodPromptsGet(t *testing.T) {
|
|||||||
Name: "error-handler-prompt",
|
Name: "error-handler-prompt",
|
||||||
Description: "A prompt with a handler that returns an error",
|
Description: "A prompt with a handler that returns an error",
|
||||||
Arguments: []PromptArgument{},
|
Arguments: []PromptArgument{},
|
||||||
Handler: func(args map[string]string) ([]PromptMessage, error) {
|
Handler: func(ctx context.Context, args map[string]string) ([]PromptMessage, error) {
|
||||||
return nil, fmt.Errorf("test handler error")
|
return nil, fmt.Errorf("test handler error")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -2583,7 +2602,7 @@ func TestMethodResourcesList(t *testing.T) {
|
|||||||
URI: "file:///test/resource.txt",
|
URI: "file:///test/resource.txt",
|
||||||
Description: "A test resource with handler",
|
Description: "A test resource with handler",
|
||||||
MimeType: "text/plain",
|
MimeType: "text/plain",
|
||||||
Handler: func() (ResourceContent, error) {
|
Handler: func(ctx context.Context) (ResourceContent, error) {
|
||||||
return ResourceContent{
|
return ResourceContent{
|
||||||
URI: "file:///test/resource.txt",
|
URI: "file:///test/resource.txt",
|
||||||
MimeType: "text/plain",
|
MimeType: "text/plain",
|
||||||
@@ -2654,7 +2673,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
|||||||
URI: "file:///test/resource.txt",
|
URI: "file:///test/resource.txt",
|
||||||
Description: "A test resource with handler",
|
Description: "A test resource with handler",
|
||||||
MimeType: "text/plain",
|
MimeType: "text/plain",
|
||||||
Handler: func() (ResourceContent, error) {
|
Handler: func(ctx context.Context) (ResourceContent, error) {
|
||||||
return ResourceContent{
|
return ResourceContent{
|
||||||
URI: "file:///test/resource.txt",
|
URI: "file:///test/resource.txt",
|
||||||
MimeType: "text/plain",
|
MimeType: "text/plain",
|
||||||
@@ -2729,7 +2748,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
|||||||
require.True(t, ok, "Content should be an object")
|
require.True(t, ok, "Content should be an object")
|
||||||
assert.Equal(t, "file:///test/resource.txt", content["uri"], "URI should match")
|
assert.Equal(t, "file:///test/resource.txt", content["uri"], "URI should match")
|
||||||
assert.Equal(t, "text/plain", content["mimeType"], "MimeType should match")
|
assert.Equal(t, "text/plain", content["mimeType"], "MimeType should match")
|
||||||
assert.Equal(t, "This is test resource content", content["text"], "Text content should match")
|
assert.Equal(t, "This is test resource content", content[ContentTypeText], "Text content should match")
|
||||||
}
|
}
|
||||||
case <-time.After(100 * time.Millisecond):
|
case <-time.After(100 * time.Millisecond):
|
||||||
t.Fatal("Timed out waiting for resource read response")
|
t.Fatal("Timed out waiting for resource read response")
|
||||||
@@ -2799,7 +2818,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
|||||||
require.True(t, ok, "Content should be an object")
|
require.True(t, ok, "Content should be an object")
|
||||||
assert.Equal(t, "file:///test/no-handler.txt", content["uri"], "URI should match")
|
assert.Equal(t, "file:///test/no-handler.txt", content["uri"], "URI should match")
|
||||||
assert.Equal(t, "text/plain", content["mimeType"], "MimeType should match")
|
assert.Equal(t, "text/plain", content["mimeType"], "MimeType should match")
|
||||||
_, ok = content["text"]
|
_, ok = content[ContentTypeText]
|
||||||
assert.False(t, ok, "Text content should be empty string")
|
assert.False(t, ok, "Text content should be empty string")
|
||||||
}
|
}
|
||||||
case <-time.After(100 * time.Millisecond):
|
case <-time.After(100 * time.Millisecond):
|
||||||
@@ -2880,7 +2899,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
|||||||
client := addTestClient(mock.server, "test-client-resources", true)
|
client := addTestClient(mock.server, "test-client-resources", true)
|
||||||
|
|
||||||
// Process through handleRequest
|
// Process through handleRequest
|
||||||
mock.server.processResourcesRead(client, req)
|
mock.server.processResourcesRead(context.Background(), client, req)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case response := <-client.channel:
|
case response := <-client.channel:
|
||||||
@@ -2898,7 +2917,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
|||||||
URI: "file:///test/error.txt",
|
URI: "file:///test/error.txt",
|
||||||
Description: "A test resource with handler that returns error",
|
Description: "A test resource with handler that returns error",
|
||||||
MimeType: "text/plain",
|
MimeType: "text/plain",
|
||||||
Handler: func() (ResourceContent, error) {
|
Handler: func(ctx context.Context) (ResourceContent, error) {
|
||||||
return ResourceContent{}, fmt.Errorf("test handler error")
|
return ResourceContent{}, fmt.Errorf("test handler error")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -2946,7 +2965,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
|||||||
URI: "file:///test/missing-fields.txt",
|
URI: "file:///test/missing-fields.txt",
|
||||||
Description: "A test resource with handler that returns content missing fields",
|
Description: "A test resource with handler that returns content missing fields",
|
||||||
MimeType: "text/plain",
|
MimeType: "text/plain",
|
||||||
Handler: func() (ResourceContent, error) {
|
Handler: func(ctx context.Context) (ResourceContent, error) {
|
||||||
// Return ResourceContent without URI and MimeType
|
// Return ResourceContent without URI and MimeType
|
||||||
return ResourceContent{
|
return ResourceContent{
|
||||||
Text: "Content with missing fields",
|
Text: "Content with missing fields",
|
||||||
@@ -3006,7 +3025,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
|||||||
require.True(t, ok, "Content should be an object")
|
require.True(t, ok, "Content should be an object")
|
||||||
assert.Equal(t, "file:///test/missing-fields.txt", content["uri"], "URI should be filled from request")
|
assert.Equal(t, "file:///test/missing-fields.txt", content["uri"], "URI should be filled from request")
|
||||||
assert.Equal(t, "text/plain", content["mimeType"], "MimeType should be filled from resource")
|
assert.Equal(t, "text/plain", content["mimeType"], "MimeType should be filled from resource")
|
||||||
assert.Equal(t, "Content with missing fields", content["text"], "Text content should match")
|
assert.Equal(t, "Content with missing fields", content[ContentTypeText], "Text content should match")
|
||||||
}
|
}
|
||||||
case <-time.After(100 * time.Millisecond):
|
case <-time.After(100 * time.Millisecond):
|
||||||
t.Fatal("Timed out waiting for resource read response")
|
t.Fatal("Timed out waiting for resource read response")
|
||||||
@@ -3159,7 +3178,7 @@ func TestMethodResourcesSubscribe(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
client := addTestClient(mock.server, "test-client-sub-not-found", true)
|
client := addTestClient(mock.server, "test-client-sub-not-found", true)
|
||||||
mock.server.processResourceSubscribe(client, req)
|
mock.server.processResourceSubscribe(context.Background(), client, req)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case response := <-client.channel:
|
case response := <-client.channel:
|
||||||
@@ -3268,7 +3287,7 @@ func TestToolCallUnmarshalError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call directly
|
// Process the tool call directly
|
||||||
mock.server.processToolCall(client, req)
|
mock.server.processToolCall(context.Background(), client, req)
|
||||||
|
|
||||||
// Check for error response about invalid JSON
|
// Check for error response about invalid JSON
|
||||||
select {
|
select {
|
||||||
@@ -3316,7 +3335,7 @@ func TestToolCallWithInvalidParams(t *testing.T) {
|
|||||||
jsonBody, _ := json.Marshal(req)
|
jsonBody, _ := json.Marshal(req)
|
||||||
|
|
||||||
// Create HTTP request
|
// Create HTTP request
|
||||||
r := httptest.NewRequest("POST", "/?session_id=test-client-invalid-json", bytes.NewReader(jsonBody))
|
r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-invalid-json", bytes.NewReader(jsonBody))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
// Process through handleRequest
|
// Process through handleRequest
|
||||||
|
|||||||
42
mcp/types.go
42
mcp/types.go
@@ -1,6 +1,7 @@
|
|||||||
package mcp
|
package mcp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -45,19 +46,18 @@ type ListToolsResult struct {
|
|||||||
|
|
||||||
// Message Content Types
|
// Message Content Types
|
||||||
|
|
||||||
// roleType represents the sender or recipient of messages in a conversation
|
// RoleType represents the sender or recipient of messages in a conversation
|
||||||
type roleType string
|
type RoleType string
|
||||||
|
|
||||||
// PromptArgument defines a single argument that can be passed to a prompt
|
// PromptArgument defines a single argument that can be passed to a prompt
|
||||||
type PromptArgument struct {
|
type PromptArgument struct {
|
||||||
Name string `json:"name"` // Argument name
|
Name string `json:"name"` // Argument name
|
||||||
Description string `json:"description,omitempty"` // Human-readable description
|
Description string `json:"description,omitempty"` // Human-readable description
|
||||||
Required bool `json:"required,omitempty"` // Whether this argument is required
|
Required bool `json:"required,omitempty"` // Whether this argument is required
|
||||||
Default string `json:"default,omitempty"` // Default value if not provided
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// PromptHandler is a function that dynamically generates prompt content
|
// PromptHandler is a function that dynamically generates prompt content
|
||||||
type PromptHandler func(args map[string]string) ([]PromptMessage, error)
|
type PromptHandler func(ctx context.Context, args map[string]string) ([]PromptMessage, error)
|
||||||
|
|
||||||
// Prompt represents an MCP Prompt definition
|
// Prompt represents an MCP Prompt definition
|
||||||
type Prompt struct {
|
type Prompt struct {
|
||||||
@@ -70,31 +70,43 @@ type Prompt struct {
|
|||||||
|
|
||||||
// PromptMessage represents a message in a conversation
|
// PromptMessage represents a message in a conversation
|
||||||
type PromptMessage struct {
|
type PromptMessage struct {
|
||||||
Role roleType `json:"role"` // Message sender role
|
Role RoleType `json:"role"` // Message sender role
|
||||||
Content any `json:"content"` // Message content (TextContent, ImageContent, etc.)
|
Content any `json:"content"` // Message content (TextContent, ImageContent, etc.)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TextContent represents text content in a message
|
// TextContent represents text content in a message
|
||||||
type TextContent struct {
|
type TextContent struct {
|
||||||
Type string `json:"type"` // Always "text"
|
|
||||||
Text string `json:"text"` // The text content
|
Text string `json:"text"` // The text content
|
||||||
Annotations *Annotations `json:"annotations,omitempty"` // Optional annotations
|
Annotations *Annotations `json:"annotations,omitempty"` // Optional annotations
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type typedTextContent struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
TextContent
|
||||||
|
}
|
||||||
|
|
||||||
// ImageContent represents image data in a message
|
// ImageContent represents image data in a message
|
||||||
type ImageContent struct {
|
type ImageContent struct {
|
||||||
Type string `json:"type"` // Always "image"
|
|
||||||
Data string `json:"data"` // Base64-encoded image data
|
Data string `json:"data"` // Base64-encoded image data
|
||||||
MimeType string `json:"mimeType"` // MIME type (e.g., "image/png")
|
MimeType string `json:"mimeType"` // MIME type (e.g., "image/png")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type typedImageContent struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
ImageContent
|
||||||
|
}
|
||||||
|
|
||||||
// AudioContent represents audio data in a message
|
// AudioContent represents audio data in a message
|
||||||
type AudioContent struct {
|
type AudioContent struct {
|
||||||
Type string `json:"type"` // Always "audio"
|
|
||||||
Data string `json:"data"` // Base64-encoded audio data
|
Data string `json:"data"` // Base64-encoded audio data
|
||||||
MimeType string `json:"mimeType"` // MIME type (e.g., "audio/mp3")
|
MimeType string `json:"mimeType"` // MIME type (e.g., "audio/mp3")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type typedAudioContent struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
AudioContent
|
||||||
|
}
|
||||||
|
|
||||||
// FileContent represents file content
|
// FileContent represents file content
|
||||||
type FileContent struct {
|
type FileContent struct {
|
||||||
URI string `json:"uri"` // URI identifying the file
|
URI string `json:"uri"` // URI identifying the file
|
||||||
@@ -115,16 +127,14 @@ type EmbeddedResource struct {
|
|||||||
|
|
||||||
// Annotations provides additional metadata for content
|
// Annotations provides additional metadata for content
|
||||||
type Annotations struct {
|
type Annotations struct {
|
||||||
Audience []roleType `json:"audience,omitempty"` // Who should see this content
|
Audience []RoleType `json:"audience,omitempty"` // Who should see this content
|
||||||
Priority *float64 `json:"priority,omitempty"` // Optional priority (0-1)
|
Priority *float64 `json:"priority,omitempty"` // Optional priority (0-1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tool-related Types
|
// Tool-related Types
|
||||||
|
|
||||||
// Tool Definition Types
|
|
||||||
|
|
||||||
// ToolHandler is a function that handles tool calls
|
// ToolHandler is a function that handles tool calls
|
||||||
type ToolHandler func(params map[string]any) (any, error)
|
type ToolHandler func(ctx context.Context, params map[string]any) (any, error)
|
||||||
|
|
||||||
// Tool represents a Model Context Protocol Tool definition
|
// Tool represents a Model Context Protocol Tool definition
|
||||||
type Tool struct {
|
type Tool struct {
|
||||||
@@ -136,7 +146,7 @@ type Tool struct {
|
|||||||
|
|
||||||
// InputSchema represents tool's input schema in JSON Schema format
|
// InputSchema represents tool's input schema in JSON Schema format
|
||||||
type InputSchema struct {
|
type InputSchema struct {
|
||||||
Type string `json:"type"` // Always "object" for tool inputs
|
Type string `json:"type"`
|
||||||
Properties map[string]any `json:"properties"` // Property definitions
|
Properties map[string]any `json:"properties"` // Property definitions
|
||||||
Required []string `json:"required,omitempty"` // List of required properties
|
Required []string `json:"required,omitempty"` // List of required properties
|
||||||
}
|
}
|
||||||
@@ -144,8 +154,8 @@ type InputSchema struct {
|
|||||||
// CallToolResult represents a tool call result that conforms to the MCP schema
|
// CallToolResult represents a tool call result that conforms to the MCP schema
|
||||||
type CallToolResult struct {
|
type CallToolResult struct {
|
||||||
Result
|
Result
|
||||||
Content []interface{} `json:"content"` // Content items (text, images, etc.)
|
Content []any `json:"content"` // Content items (text, images, etc.)
|
||||||
IsError bool `json:"isError,omitempty"` // True if tool execution failed
|
IsError bool `json:"isError,omitempty"` // True if tool execution failed
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resource represents a Model Context Protocol Resource definition
|
// Resource represents a Model Context Protocol Resource definition
|
||||||
@@ -158,7 +168,7 @@ type Resource struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ResourceHandler is a function that handles resource read requests
|
// ResourceHandler is a function that handles resource read requests
|
||||||
type ResourceHandler func() (ResourceContent, error)
|
type ResourceHandler func(ctx context.Context) (ResourceContent, error)
|
||||||
|
|
||||||
// ResourceContent represents the content of a resource
|
// ResourceContent represents the content of a resource
|
||||||
type ResourceContent struct {
|
type ResourceContent struct {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package mcp
|
package mcp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -79,7 +80,7 @@ func TestToolStructs(t *testing.T) {
|
|||||||
},
|
},
|
||||||
Required: []string{"input"},
|
Required: []string{"input"},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
return "result", nil
|
return "result", nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -145,44 +146,38 @@ func TestResourceStructs(t *testing.T) {
|
|||||||
func TestContentTypes(t *testing.T) {
|
func TestContentTypes(t *testing.T) {
|
||||||
// Test TextContent
|
// Test TextContent
|
||||||
textContent := TextContent{
|
textContent := TextContent{
|
||||||
Type: "text",
|
|
||||||
Text: "Sample text",
|
Text: "Sample text",
|
||||||
Annotations: &Annotations{
|
Annotations: &Annotations{
|
||||||
Audience: []roleType{roleUser, roleAssistant},
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
Priority: ptr(1.0),
|
Priority: ptr(1.0),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(textContent)
|
data, err := json.Marshal(textContent)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Contains(t, string(data), `"type":"text"`)
|
|
||||||
assert.Contains(t, string(data), `"text":"Sample text"`)
|
assert.Contains(t, string(data), `"text":"Sample text"`)
|
||||||
assert.Contains(t, string(data), `"audience":["user","assistant"]`)
|
assert.Contains(t, string(data), `"audience":["user","assistant"]`)
|
||||||
assert.Contains(t, string(data), `"priority":1`)
|
assert.Contains(t, string(data), `"priority":1`)
|
||||||
|
|
||||||
// Test ImageContent
|
// Test ImageContent
|
||||||
imageContent := ImageContent{
|
imageContent := ImageContent{
|
||||||
Type: "image",
|
|
||||||
Data: "base64data",
|
Data: "base64data",
|
||||||
MimeType: "image/png",
|
MimeType: "image/png",
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err = json.Marshal(imageContent)
|
data, err = json.Marshal(imageContent)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Contains(t, string(data), `"type":"image"`)
|
|
||||||
assert.Contains(t, string(data), `"data":"base64data"`)
|
assert.Contains(t, string(data), `"data":"base64data"`)
|
||||||
assert.Contains(t, string(data), `"mimeType":"image/png"`)
|
assert.Contains(t, string(data), `"mimeType":"image/png"`)
|
||||||
|
|
||||||
// Test AudioContent
|
// Test AudioContent
|
||||||
audioContent := AudioContent{
|
audioContent := AudioContent{
|
||||||
Type: "audio",
|
|
||||||
Data: "base64audio",
|
Data: "base64audio",
|
||||||
MimeType: "audio/mp3",
|
MimeType: "audio/mp3",
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err = json.Marshal(audioContent)
|
data, err = json.Marshal(audioContent)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Contains(t, string(data), `"type":"audio"`)
|
|
||||||
assert.Contains(t, string(data), `"data":"base64audio"`)
|
assert.Contains(t, string(data), `"data":"base64audio"`)
|
||||||
assert.Contains(t, string(data), `"mimeType":"audio/mp3"`)
|
assert.Contains(t, string(data), `"mimeType":"audio/mp3"`)
|
||||||
}
|
}
|
||||||
@@ -197,7 +192,6 @@ func TestCallToolResult(t *testing.T) {
|
|||||||
},
|
},
|
||||||
Content: []interface{}{
|
Content: []interface{}{
|
||||||
TextContent{
|
TextContent{
|
||||||
Type: "text",
|
|
||||||
Text: "Sample result",
|
Text: "Sample result",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -207,6 +201,6 @@ func TestCallToolResult(t *testing.T) {
|
|||||||
data, err := json.Marshal(result)
|
data, err := json.Marshal(result)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Contains(t, string(data), `"_meta":{"progressToken":"token123"}`)
|
assert.Contains(t, string(data), `"_meta":{"progressToken":"token123"}`)
|
||||||
assert.Contains(t, string(data), `"content":[{"type":"text","text":"Sample result"}]`)
|
assert.Contains(t, string(data), `"content":[{"text":"Sample result"}]`)
|
||||||
assert.NotContains(t, string(data), `"isError":`)
|
assert.NotContains(t, string(data), `"isError":`)
|
||||||
}
|
}
|
||||||
|
|||||||
104
mcp/util.go
104
mcp/util.go
@@ -1,15 +1,107 @@
|
|||||||
package mcp
|
package mcp
|
||||||
|
|
||||||
import (
|
import "fmt"
|
||||||
"fmt"
|
|
||||||
)
|
// formatSSEMessage formats a Server-Sent Event message with proper CRLF line endings
|
||||||
|
func formatSSEMessage(event string, data []byte) string {
|
||||||
|
return fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(data))
|
||||||
|
}
|
||||||
|
|
||||||
// ptr is a helper function to get a pointer to a value
|
// ptr is a helper function to get a pointer to a value
|
||||||
func ptr[T any](v T) *T {
|
func ptr[T any](v T) *T {
|
||||||
return &v
|
return &v
|
||||||
}
|
}
|
||||||
|
|
||||||
// formatSSEMessage formats a Server-Sent Event message with proper CRLF line endings
|
func toTypedContents(contents []any) []any {
|
||||||
func formatSSEMessage(event string, data []byte) string {
|
typedContents := make([]any, len(contents))
|
||||||
return fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(data))
|
|
||||||
|
for i, content := range contents {
|
||||||
|
switch v := content.(type) {
|
||||||
|
case TextContent:
|
||||||
|
typedContents[i] = typedTextContent{
|
||||||
|
Type: ContentTypeText,
|
||||||
|
TextContent: v,
|
||||||
|
}
|
||||||
|
case ImageContent:
|
||||||
|
typedContents[i] = typedImageContent{
|
||||||
|
Type: ContentTypeImage,
|
||||||
|
ImageContent: v,
|
||||||
|
}
|
||||||
|
case AudioContent:
|
||||||
|
typedContents[i] = typedAudioContent{
|
||||||
|
Type: ContentTypeAudio,
|
||||||
|
AudioContent: v,
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
typedContents[i] = typedTextContent{
|
||||||
|
Type: ContentTypeText,
|
||||||
|
TextContent: TextContent{
|
||||||
|
Text: fmt.Sprintf("Unknown content type: %T", v),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return typedContents
|
||||||
|
}
|
||||||
|
|
||||||
|
func toTypedPromptMessages(messages []PromptMessage) []PromptMessage {
|
||||||
|
typedMessages := make([]PromptMessage, len(messages))
|
||||||
|
|
||||||
|
for i, msg := range messages {
|
||||||
|
switch v := msg.Content.(type) {
|
||||||
|
case TextContent:
|
||||||
|
typedMessages[i] = PromptMessage{
|
||||||
|
Role: msg.Role,
|
||||||
|
Content: typedTextContent{
|
||||||
|
Type: ContentTypeText,
|
||||||
|
TextContent: v,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
case ImageContent:
|
||||||
|
typedMessages[i] = PromptMessage{
|
||||||
|
Role: msg.Role,
|
||||||
|
Content: typedImageContent{
|
||||||
|
Type: ContentTypeImage,
|
||||||
|
ImageContent: v,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
case AudioContent:
|
||||||
|
typedMessages[i] = PromptMessage{
|
||||||
|
Role: msg.Role,
|
||||||
|
Content: typedAudioContent{
|
||||||
|
Type: ContentTypeAudio,
|
||||||
|
AudioContent: v,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
typedMessages[i] = PromptMessage{
|
||||||
|
Role: msg.Role,
|
||||||
|
Content: typedTextContent{
|
||||||
|
Type: ContentTypeText,
|
||||||
|
TextContent: TextContent{
|
||||||
|
Text: fmt.Sprintf("Unknown content type: %T", v),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return typedMessages
|
||||||
|
}
|
||||||
|
|
||||||
|
// validatePromptArguments checks if all required arguments are provided
|
||||||
|
// Returns a list of missing required arguments
|
||||||
|
func validatePromptArguments(prompt Prompt, providedArgs map[string]string) []string {
|
||||||
|
var missingArgs []string
|
||||||
|
|
||||||
|
for _, arg := range prompt.Arguments {
|
||||||
|
if arg.Required {
|
||||||
|
if value, exists := providedArgs[arg.Name]; !exists || len(value) == 0 {
|
||||||
|
missingArgs = append(missingArgs, arg.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return missingArgs
|
||||||
}
|
}
|
||||||
|
|||||||
253
mcp/util_test.go
253
mcp/util_test.go
@@ -8,29 +8,9 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPtr(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
v interface{}
|
|
||||||
}{
|
|
||||||
{"string", "test"},
|
|
||||||
{"int", 42},
|
|
||||||
{"bool", true},
|
|
||||||
{"float", 3.14},
|
|
||||||
{"struct", struct{ Name string }{"test"}},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got := ptr(tt.v)
|
|
||||||
assert.NotNil(t, got, "ptr() should not return nil")
|
|
||||||
assert.Equal(t, tt.v, *got, "dereferenced pointer should equal input value")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type Event struct {
|
type Event struct {
|
||||||
Type string
|
Type string
|
||||||
Data map[string]any
|
Data map[string]any
|
||||||
@@ -61,3 +41,234 @@ func parseEvent(input string) (*Event, error) {
|
|||||||
|
|
||||||
return &evt, nil
|
return &evt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestToTypedPromptMessages tests the toTypedPromptMessages function
|
||||||
|
func TestToTypedPromptMessages(t *testing.T) {
|
||||||
|
// Test with multiple message types in one test
|
||||||
|
t.Run("MixedContentTypes", func(t *testing.T) {
|
||||||
|
// Create test data with different content types
|
||||||
|
messages := []PromptMessage{
|
||||||
|
{
|
||||||
|
Role: RoleUser,
|
||||||
|
Content: TextContent{
|
||||||
|
Text: "Hello, this is a text message",
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
|
Priority: ptr(0.8),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: RoleAssistant,
|
||||||
|
Content: ImageContent{
|
||||||
|
Data: "base64ImageData",
|
||||||
|
MimeType: "image/jpeg",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: RoleUser,
|
||||||
|
Content: AudioContent{
|
||||||
|
Data: "base64AudioData",
|
||||||
|
MimeType: "audio/mp3",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "system",
|
||||||
|
Content: "This is a simple string that should be handled as unknown type",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the function
|
||||||
|
result := toTypedPromptMessages(messages)
|
||||||
|
|
||||||
|
// Validate results
|
||||||
|
require.Len(t, result, 4, "Should return the same number of messages")
|
||||||
|
|
||||||
|
// Validate first message (TextContent)
|
||||||
|
msg := result[0]
|
||||||
|
assert.Equal(t, RoleUser, msg.Role, "Role should be preserved")
|
||||||
|
|
||||||
|
// Type assertion using reflection since Content is an interface
|
||||||
|
typed, ok := msg.Content.(typedTextContent)
|
||||||
|
require.True(t, ok, "Should be typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||||
|
assert.Equal(t, "Hello, this is a text message", typed.Text, "Text content should be preserved")
|
||||||
|
require.NotNil(t, typed.Annotations, "Annotations should be preserved")
|
||||||
|
assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved")
|
||||||
|
require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved")
|
||||||
|
assert.Equal(t, 0.8, *typed.Annotations.Priority, "Priority value should be preserved")
|
||||||
|
|
||||||
|
// Validate second message (ImageContent)
|
||||||
|
msg = result[1]
|
||||||
|
assert.Equal(t, RoleAssistant, msg.Role, "Role should be preserved")
|
||||||
|
|
||||||
|
// Type assertion for image content
|
||||||
|
typedImg, ok := msg.Content.(typedImageContent)
|
||||||
|
require.True(t, ok, "Should be typedImageContent")
|
||||||
|
assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image")
|
||||||
|
assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved")
|
||||||
|
assert.Equal(t, "image/jpeg", typedImg.MimeType, "MimeType should be preserved")
|
||||||
|
|
||||||
|
// Validate third message (AudioContent)
|
||||||
|
msg = result[2]
|
||||||
|
assert.Equal(t, RoleUser, msg.Role, "Role should be preserved")
|
||||||
|
|
||||||
|
// Type assertion for audio content
|
||||||
|
typedAudio, ok := msg.Content.(typedAudioContent)
|
||||||
|
require.True(t, ok, "Should be typedAudioContent")
|
||||||
|
assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio")
|
||||||
|
assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved")
|
||||||
|
assert.Equal(t, "audio/mp3", typedAudio.MimeType, "MimeType should be preserved")
|
||||||
|
|
||||||
|
// Validate fourth message (unknown type converted to TextContent)
|
||||||
|
msg = result[3]
|
||||||
|
assert.Equal(t, RoleType("system"), msg.Role, "Role should be preserved")
|
||||||
|
|
||||||
|
// Should be converted to a typedTextContent with error message
|
||||||
|
typedUnknown, ok := msg.Content.(typedTextContent)
|
||||||
|
require.True(t, ok, "Unknown content should be converted to typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text")
|
||||||
|
assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type")
|
||||||
|
assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test empty input
|
||||||
|
t.Run("EmptyInput", func(t *testing.T) {
|
||||||
|
messages := []PromptMessage{}
|
||||||
|
result := toTypedPromptMessages(messages)
|
||||||
|
assert.Empty(t, result, "Should return empty slice for empty input")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test with nil annotations
|
||||||
|
t.Run("NilAnnotations", func(t *testing.T) {
|
||||||
|
messages := []PromptMessage{
|
||||||
|
{
|
||||||
|
Role: RoleUser,
|
||||||
|
Content: TextContent{
|
||||||
|
Text: "Text with nil annotations",
|
||||||
|
Annotations: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := toTypedPromptMessages(messages)
|
||||||
|
require.Len(t, result, 1, "Should return one message")
|
||||||
|
|
||||||
|
typed, ok := result[0].Content.(typedTextContent)
|
||||||
|
require.True(t, ok, "Should be typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||||
|
assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved")
|
||||||
|
assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToTypedContents tests the toTypedContents function
|
||||||
|
func TestToTypedContents(t *testing.T) {
|
||||||
|
// Test with multiple content types in one test
|
||||||
|
t.Run("MixedContentTypes", func(t *testing.T) {
|
||||||
|
// Create test data with different content types
|
||||||
|
contents := []any{
|
||||||
|
TextContent{
|
||||||
|
Text: "Hello, this is a text content",
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
|
Priority: ptr(0.7),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ImageContent{
|
||||||
|
Data: "base64ImageData",
|
||||||
|
MimeType: "image/png",
|
||||||
|
},
|
||||||
|
AudioContent{
|
||||||
|
Data: "base64AudioData",
|
||||||
|
MimeType: "audio/wav",
|
||||||
|
},
|
||||||
|
"This is a simple string that should be handled as unknown type",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the function
|
||||||
|
result := toTypedContents(contents)
|
||||||
|
|
||||||
|
// Validate results
|
||||||
|
require.Len(t, result, 4, "Should return the same number of contents")
|
||||||
|
|
||||||
|
// Validate first content (TextContent)
|
||||||
|
typed, ok := result[0].(typedTextContent)
|
||||||
|
require.True(t, ok, "Should be typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||||
|
assert.Equal(t, "Hello, this is a text content", typed.Text, "Text content should be preserved")
|
||||||
|
require.NotNil(t, typed.Annotations, "Annotations should be preserved")
|
||||||
|
assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved")
|
||||||
|
require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved")
|
||||||
|
assert.Equal(t, 0.7, *typed.Annotations.Priority, "Priority value should be preserved")
|
||||||
|
|
||||||
|
// Validate second content (ImageContent)
|
||||||
|
typedImg, ok := result[1].(typedImageContent)
|
||||||
|
require.True(t, ok, "Should be typedImageContent")
|
||||||
|
assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image")
|
||||||
|
assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved")
|
||||||
|
assert.Equal(t, "image/png", typedImg.MimeType, "MimeType should be preserved")
|
||||||
|
|
||||||
|
// Validate third content (AudioContent)
|
||||||
|
typedAudio, ok := result[2].(typedAudioContent)
|
||||||
|
require.True(t, ok, "Should be typedAudioContent")
|
||||||
|
assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio")
|
||||||
|
assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved")
|
||||||
|
assert.Equal(t, "audio/wav", typedAudio.MimeType, "MimeType should be preserved")
|
||||||
|
|
||||||
|
// Validate fourth content (unknown type converted to TextContent)
|
||||||
|
typedUnknown, ok := result[3].(typedTextContent)
|
||||||
|
require.True(t, ok, "Unknown content should be converted to typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text")
|
||||||
|
assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type")
|
||||||
|
assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test empty input
|
||||||
|
t.Run("EmptyInput", func(t *testing.T) {
|
||||||
|
contents := []any{}
|
||||||
|
result := toTypedContents(contents)
|
||||||
|
assert.Empty(t, result, "Should return empty slice for empty input")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test with nil annotations
|
||||||
|
t.Run("NilAnnotations", func(t *testing.T) {
|
||||||
|
contents := []any{
|
||||||
|
TextContent{
|
||||||
|
Text: "Text with nil annotations",
|
||||||
|
Annotations: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := toTypedContents(contents)
|
||||||
|
require.Len(t, result, 1, "Should return one content")
|
||||||
|
|
||||||
|
typed, ok := result[0].(typedTextContent)
|
||||||
|
require.True(t, ok, "Should be typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||||
|
assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved")
|
||||||
|
assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test with custom struct (should be handled as unknown type)
|
||||||
|
t.Run("CustomStruct", func(t *testing.T) {
|
||||||
|
type CustomContent struct {
|
||||||
|
Data string
|
||||||
|
}
|
||||||
|
|
||||||
|
contents := []any{
|
||||||
|
CustomContent{
|
||||||
|
Data: "custom data",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := toTypedContents(contents)
|
||||||
|
require.Len(t, result, 1, "Should return one content")
|
||||||
|
|
||||||
|
typed, ok := result[0].(typedTextContent)
|
||||||
|
require.True(t, ok, "Custom struct should be converted to typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||||
|
assert.Contains(t, typed.Text, "Unknown content type:", "Should contain error about unknown type")
|
||||||
|
assert.Contains(t, typed.Text, "CustomContent", "Should mention the actual type")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
28
mcp/vars.go
28
mcp/vars.go
@@ -13,6 +13,9 @@ const (
|
|||||||
|
|
||||||
// Session identifier key used in request URLs
|
// Session identifier key used in request URLs
|
||||||
sessionIdKey = "session_id"
|
sessionIdKey = "session_id"
|
||||||
|
|
||||||
|
// progressTokenKey is used to track progress of long-running tasks
|
||||||
|
progressTokenKey = "progressToken"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server-Sent Events (SSE) event types
|
// Server-Sent Events (SSE) event types
|
||||||
@@ -26,11 +29,20 @@ const (
|
|||||||
|
|
||||||
// Content type identifiers
|
// Content type identifiers
|
||||||
const (
|
const (
|
||||||
// Text content type
|
// ContentTypeObject is object content type
|
||||||
contentTypeText = "text"
|
ContentTypeObject = "object"
|
||||||
|
|
||||||
// Image content type
|
// ContentTypeText is text content type
|
||||||
contentTypeImage = "image"
|
ContentTypeText = "text"
|
||||||
|
|
||||||
|
// ContentTypeImage is image content type
|
||||||
|
ContentTypeImage = "image"
|
||||||
|
|
||||||
|
// ContentTypeAudio is audio content type
|
||||||
|
ContentTypeAudio = "audio"
|
||||||
|
|
||||||
|
// ContentTypeResource is resource content type
|
||||||
|
ContentTypeResource = "resource"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Collection keys for broadcast events
|
// Collection keys for broadcast events
|
||||||
@@ -72,11 +84,11 @@ const (
|
|||||||
|
|
||||||
// User and assistant role definitions
|
// User and assistant role definitions
|
||||||
const (
|
const (
|
||||||
// The "user" role - the entity asking questions
|
// RoleUser is the "user" role - the entity asking questions
|
||||||
roleUser roleType = "user"
|
RoleUser RoleType = "user"
|
||||||
|
|
||||||
// The "assistant" role - the entity providing responses
|
// RoleAssistant is the "assistant" role - the entity providing responses
|
||||||
roleAssistant roleType = "assistant"
|
RoleAssistant RoleType = "assistant"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Method names as defined in the MCP specification
|
// Method names as defined in the MCP specification
|
||||||
|
|||||||
@@ -146,13 +146,9 @@ func TestCollectionKeys(t *testing.T) {
|
|||||||
|
|
||||||
// TestRoleTypes checks that role types are used correctly
|
// TestRoleTypes checks that role types are used correctly
|
||||||
func TestRoleTypes(t *testing.T) {
|
func TestRoleTypes(t *testing.T) {
|
||||||
// Verify role type constants
|
|
||||||
assert.Equal(t, "user", string(roleUser), "User role should be 'user'")
|
|
||||||
assert.Equal(t, "assistant", string(roleAssistant), "Assistant role should be 'assistant'")
|
|
||||||
|
|
||||||
// Test in annotations
|
// Test in annotations
|
||||||
annotations := Annotations{
|
annotations := Annotations{
|
||||||
Audience: []roleType{roleUser, roleAssistant},
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
}
|
}
|
||||||
data, err := json.Marshal(annotations)
|
data, err := json.Marshal(annotations)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|||||||
Reference in New Issue
Block a user