mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-07 06:59:59 +08:00
feat(mcp): migrate to official go-sdk with simplified API (#5362)
This commit is contained in:
3
go.mod
3
go.mod
@@ -14,6 +14,7 @@ require (
|
||||
github.com/grafana/pyroscope-go v1.2.7
|
||||
github.com/jackc/pgx/v5 v5.7.4
|
||||
github.com/jhump/protoreflect v1.17.0
|
||||
github.com/modelcontextprotocol/go-sdk v1.2.0
|
||||
github.com/pelletier/go-toml/v2 v2.2.4
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/redis/go-redis/v9 v9.17.2
|
||||
@@ -71,6 +72,7 @@ require (
|
||||
github.com/google/gnostic-models v0.6.8 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/gofuzz v1.2.0 // indirect
|
||||
github.com/google/jsonschema-go v0.3.0 // indirect
|
||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.9 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect
|
||||
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect
|
||||
@@ -98,6 +100,7 @@ require (
|
||||
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
|
||||
github.com/xdg-go/scram v1.1.2 // indirect
|
||||
github.com/xdg-go/stringprep v1.0.4 // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
|
||||
github.com/yuin/gopher-lua v1.1.1 // indirect
|
||||
go.etcd.io/etcd/client/pkg/v3 v3.5.15 // indirect
|
||||
|
||||
8
go.sum
8
go.sum
@@ -62,6 +62,8 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs=
|
||||
@@ -74,6 +76,8 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
|
||||
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q=
|
||||
github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJYCmNdQXq6neHJOYx3V6jnqNEec=
|
||||
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
@@ -123,6 +127,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
||||
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s=
|
||||
github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
@@ -182,6 +188,8 @@ github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY=
|
||||
github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4=
|
||||
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
|
||||
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
|
||||
166
mcp/MIGRATION.md
Normal file
166
mcp/MIGRATION.md
Normal file
@@ -0,0 +1,166 @@
|
||||
# Migration to Official MCP SDK
|
||||
|
||||
This document describes the migration from the custom MCP implementation to the official [go-sdk](https://github.com/modelcontextprotocol/go-sdk).
|
||||
|
||||
## Changes
|
||||
|
||||
### Dependencies
|
||||
|
||||
Added the official MCP SDK:
|
||||
```bash
|
||||
go get github.com/modelcontextprotocol/go-sdk@v1.2.0
|
||||
```
|
||||
|
||||
### Type System
|
||||
|
||||
All types are now re-exported from the official SDK:
|
||||
- `Tool` → `sdkmcp.Tool`
|
||||
- `CallToolRequest` → `sdkmcp.CallToolRequest`
|
||||
- `CallToolResult` → `sdkmcp.CallToolResult`
|
||||
- Content types (`TextContent`, `ImageContent`, etc.)
|
||||
- `Prompt`, `Resource`, `Server`, `ServerSession`
|
||||
|
||||
### Server Interface
|
||||
|
||||
The `McpServer` interface has been simplified:
|
||||
|
||||
```go
|
||||
type McpServer interface {
|
||||
Start()
|
||||
Stop()
|
||||
Server() *sdkmcp.Server // Returns underlying SDK server
|
||||
}
|
||||
```
|
||||
|
||||
**Important**: The `AddTool`, `AddPrompt`, and `AddResource` methods have been removed. Use the SDK directly:
|
||||
|
||||
```go
|
||||
// Old (no longer supported)
|
||||
server.AddTool(tool, handler)
|
||||
|
||||
// New (use SDK directly)
|
||||
sdkmcp.AddTool(server.Server(), tool, handler)
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
Updated configuration structure:
|
||||
- Removed: `ProtocolVersion`, `BaseUrl` (SDK manages these)
|
||||
- Added: `UseStreamable` (choose between SSE and Streamable HTTP transport)
|
||||
|
||||
```yaml
|
||||
mcp:
|
||||
name: my-server
|
||||
version: 1.0.0
|
||||
useStreamable: false # false = SSE (2024-11-05), true = Streamable HTTP (2025-03-26)
|
||||
sseEndpoint: /sse
|
||||
messageEndpoint: /message
|
||||
sseTimeout: 24h
|
||||
messageTimeout: 30s
|
||||
cors:
|
||||
- http://localhost:3000
|
||||
```
|
||||
|
||||
### Tool Registration
|
||||
|
||||
The SDK uses Go generics for type-safe tool registration:
|
||||
|
||||
```go
|
||||
import sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
type MyArgs struct {
|
||||
Value string `json:"value" jsonschema:"description=Input value"`
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{
|
||||
Name: "my_tool",
|
||||
Description: "Description",
|
||||
}
|
||||
|
||||
handler := func(ctx context.Context, req *mcp.CallToolRequest, args MyArgs) (*mcp.CallToolResult, any, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: "Result"},
|
||||
},
|
||||
}, nil, nil
|
||||
}
|
||||
|
||||
// Register with explicit type parameters
|
||||
sdkmcp.AddTool(server.Server(), tool, handler)
|
||||
```
|
||||
|
||||
The SDK automatically generates JSON schemas from struct tags.
|
||||
|
||||
### Transport Support
|
||||
|
||||
Two transports are supported:
|
||||
|
||||
1. **SSE (Server-Sent Events)**: 2024-11-05 MCP spec
|
||||
- Default (`UseStreamable: false`)
|
||||
- Endpoint: `/sse` (configurable)
|
||||
- Bidirectional: client sends messages to `/message`
|
||||
|
||||
2. **Streamable HTTP**: 2025-03-26 MCP spec
|
||||
- Opt-in (`UseStreamable: true`)
|
||||
- Endpoint: `/sse` (configurable)
|
||||
- Newer protocol with improved streaming
|
||||
|
||||
### Example Migration
|
||||
|
||||
**Before:**
|
||||
```go
|
||||
server := mcp.NewMcpServer(c)
|
||||
|
||||
tool := &mcp.Tool{Name: "greet", Description: "Greet"}
|
||||
handler := func(ctx context.Context, req *mcp.CallToolRequest, args GreetArgs) (*mcp.CallToolResult, any, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{&mcp.TextContent{Text: "Hello"}},
|
||||
}, nil, nil
|
||||
}
|
||||
|
||||
if err := server.AddTool(tool, handler); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
```
|
||||
|
||||
**After:**
|
||||
```go
|
||||
import sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
server := mcp.NewMcpServer(c)
|
||||
|
||||
tool := &mcp.Tool{Name: "greet", Description: "Greet"}
|
||||
handler := func(ctx context.Context, req *mcp.CallToolRequest, args GreetArgs) (*mcp.CallToolResult, any, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{&mcp.TextContent{Text: "Hello"}},
|
||||
}, nil, nil
|
||||
}
|
||||
|
||||
// Use SDK directly - no error return
|
||||
sdkmcp.AddTool(server.Server(), tool, handler)
|
||||
```
|
||||
|
||||
## Benefits
|
||||
|
||||
1. **Official SDK**: Uses the official Model Context Protocol SDK
|
||||
2. **Type Safety**: Go generics provide compile-time type checking
|
||||
3. **Auto Schema**: JSON schemas generated automatically from struct tags
|
||||
4. **Dual Transport**: Supports both SSE and Streamable HTTP transports
|
||||
5. **Maintained**: SDK is actively maintained by the MCP team
|
||||
|
||||
## Breaking Changes
|
||||
|
||||
1. `server.AddTool()` removed → use `sdkmcp.AddTool(server.Server(), ...)`
|
||||
2. `server.AddPrompt()` removed (SDK v1.2.0 limitation)
|
||||
3. `server.AddResource()` removed (SDK v1.2.0 limitation)
|
||||
4. Config fields `ProtocolVersion` and `BaseUrl` removed
|
||||
5. All types now come from SDK (re-exported for convenience)
|
||||
|
||||
## Migration Checklist
|
||||
|
||||
- [ ] Update imports: add `sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"`
|
||||
- [ ] Replace `server.AddTool()` with `sdkmcp.AddTool(server.Server(), ...)`
|
||||
- [ ] Remove error handling for tool registration (SDK doesn't return errors)
|
||||
- [ ] Update config: remove `ProtocolVersion` and `BaseUrl`, add `UseStreamable`
|
||||
- [ ] Test with both SSE and Streamable transports
|
||||
- [ ] Update documentation/examples
|
||||
@@ -18,17 +18,16 @@ type McpConf struct {
|
||||
// Version is the server version reported in initialize responses
|
||||
Version string `json:",default=1.0.0"`
|
||||
|
||||
// ProtocolVersion is the MCP protocol version implemented
|
||||
ProtocolVersion string `json:",default=2024-11-05"`
|
||||
|
||||
// BaseUrl is the base URL for the server, used in SSE endpoint messages
|
||||
// If not set, defaults to http://localhost:{Port}
|
||||
BaseUrl string `json:",optional"`
|
||||
// UseStreamable when true uses Streamable HTTP transport (2025-03-26 spec),
|
||||
// otherwise uses SSE transport (2024-11-05 spec)
|
||||
UseStreamable bool `json:",default=false"`
|
||||
|
||||
// SseEndpoint is the path for Server-Sent Events connections
|
||||
// Used for SSE transport mode
|
||||
SseEndpoint string `json:",default=/sse"`
|
||||
|
||||
// MessageEndpoint is the path for JSON-RPC requests
|
||||
// Used for Streamable HTTP transport mode
|
||||
MessageEndpoint string `json:",default=/message"`
|
||||
|
||||
// Cors contains allowed CORS origins
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
func TestMcpConfDefaults(t *testing.T) {
|
||||
// Test default values are set correctly when unmarshalled from JSON
|
||||
// Test default values are set correctly
|
||||
jsonConfig := `name: test-service
|
||||
port: 8080
|
||||
mcp:
|
||||
@@ -23,41 +23,8 @@ mcp:
|
||||
|
||||
// Check default values
|
||||
assert.Equal(t, "test-mcp-server", c.Mcp.Name)
|
||||
assert.Equal(t, "1.0.0", c.Mcp.Version, "Default version should be 1.0.0")
|
||||
assert.Equal(t, "2024-11-05", c.Mcp.ProtocolVersion, "Default protocol version should be 2024-11-05")
|
||||
assert.Equal(t, "/sse", c.Mcp.SseEndpoint, "Default SSE endpoint should be /sse")
|
||||
assert.Equal(t, "/message", c.Mcp.MessageEndpoint, "Default message endpoint should be /message")
|
||||
assert.Equal(t, 30*time.Second, c.Mcp.MessageTimeout, "Default message timeout should be 30s")
|
||||
}
|
||||
|
||||
func TestMcpConfCustomValues(t *testing.T) {
|
||||
// Test custom values can be set
|
||||
jsonConfig := `{
|
||||
"Name": "test-service",
|
||||
"Port": 8080,
|
||||
"Mcp": {
|
||||
"Name": "test-mcp-server",
|
||||
"Version": "2.0.0",
|
||||
"ProtocolVersion": "2025-01-01",
|
||||
"BaseUrl": "http://example.com",
|
||||
"SseEndpoint": "/custom-sse",
|
||||
"MessageEndpoint": "/custom-message",
|
||||
"Cors": ["http://localhost:3000", "http://example.com"],
|
||||
"MessageTimeout": "60s"
|
||||
}
|
||||
}`
|
||||
|
||||
var c McpConf
|
||||
err := conf.LoadFromJsonBytes([]byte(jsonConfig), &c)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check custom values
|
||||
assert.Equal(t, "test-mcp-server", c.Mcp.Name, "Name should be inherited from RestConf")
|
||||
assert.Equal(t, "2.0.0", c.Mcp.Version, "Version should be customizable")
|
||||
assert.Equal(t, "2025-01-01", c.Mcp.ProtocolVersion, "Protocol version should be customizable")
|
||||
assert.Equal(t, "http://example.com", c.Mcp.BaseUrl, "BaseUrl should be customizable")
|
||||
assert.Equal(t, "/custom-sse", c.Mcp.SseEndpoint, "SSE endpoint should be customizable")
|
||||
assert.Equal(t, "/custom-message", c.Mcp.MessageEndpoint, "Message endpoint should be customizable")
|
||||
assert.Equal(t, []string{"http://localhost:3000", "http://example.com"}, c.Mcp.Cors, "CORS settings should be customizable")
|
||||
assert.Equal(t, 60*time.Second, c.Mcp.MessageTimeout, "Tool timeout should be customizable")
|
||||
assert.Equal(t, "1.0.0", c.Mcp.Version)
|
||||
assert.Equal(t, "/sse", c.Mcp.SseEndpoint)
|
||||
assert.Equal(t, "/message", c.Mcp.MessageEndpoint)
|
||||
assert.Equal(t, 30*time.Second, c.Mcp.MessageTimeout)
|
||||
}
|
||||
|
||||
@@ -1,443 +0,0 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// syncResponseRecorder is a thread-safe wrapper around httptest.ResponseRecorder
|
||||
type syncResponseRecorder struct {
|
||||
*httptest.ResponseRecorder
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Create a new synchronized response recorder
|
||||
func newSyncResponseRecorder() *syncResponseRecorder {
|
||||
return &syncResponseRecorder{
|
||||
ResponseRecorder: httptest.NewRecorder(),
|
||||
}
|
||||
}
|
||||
|
||||
// Override Write method to synchronize access
|
||||
func (srr *syncResponseRecorder) Write(p []byte) (int, error) {
|
||||
srr.mu.Lock()
|
||||
defer srr.mu.Unlock()
|
||||
return srr.ResponseRecorder.Write(p)
|
||||
}
|
||||
|
||||
// Override WriteHeader method to synchronize access
|
||||
func (srr *syncResponseRecorder) WriteHeader(statusCode int) {
|
||||
srr.mu.Lock()
|
||||
defer srr.mu.Unlock()
|
||||
srr.ResponseRecorder.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
// Override Result method to synchronize access
|
||||
func (srr *syncResponseRecorder) Result() *http.Response {
|
||||
srr.mu.Lock()
|
||||
defer srr.mu.Unlock()
|
||||
return srr.ResponseRecorder.Result()
|
||||
}
|
||||
|
||||
// TestHTTPHandlerIntegration tests the HTTP handlers with a real server instance
|
||||
func TestHTTPHandlerIntegration(t *testing.T) {
|
||||
// Skip in short test mode
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
// Create a test configuration
|
||||
conf := McpConf{}
|
||||
conf.Mcp.Name = "test-integration"
|
||||
conf.Mcp.Version = "1.0.0-test"
|
||||
conf.Mcp.MessageTimeout = 1 * time.Second
|
||||
|
||||
// Create a mock server directly
|
||||
server := &sseMcpServer{
|
||||
conf: conf,
|
||||
clients: make(map[string]*mcpClient),
|
||||
tools: make(map[string]Tool),
|
||||
prompts: make(map[string]Prompt),
|
||||
resources: make(map[string]Resource),
|
||||
}
|
||||
|
||||
// Register a test tool
|
||||
err := server.RegisterTool(Tool{
|
||||
Name: "echo",
|
||||
Description: "Echo tool for testing",
|
||||
InputSchema: InputSchema{
|
||||
Properties: map[string]any{
|
||||
"message": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Message to echo",
|
||||
},
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
if msg, ok := params["message"].(string); ok {
|
||||
return fmt.Sprintf("Echo: %s", msg), nil
|
||||
}
|
||||
return "Echo: no message provided", nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a test HTTP request to the SSE endpoint
|
||||
req := httptest.NewRequest("GET", "/sse", nil)
|
||||
w := newSyncResponseRecorder()
|
||||
|
||||
// Create a done channel to signal completion of test
|
||||
done := make(chan bool)
|
||||
|
||||
// Start the SSE handler in a goroutine
|
||||
go func() {
|
||||
// lock.Lock()
|
||||
server.handleSSE(w, req)
|
||||
// lock.Unlock()
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Allow time for the handler to process
|
||||
select {
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// Expected - handler would normally block indefinitely
|
||||
case <-done:
|
||||
// This shouldn't happen immediately - the handler should block
|
||||
t.Error("SSE handler returned unexpectedly")
|
||||
}
|
||||
|
||||
// Check the initial headers
|
||||
resp := w.Result()
|
||||
assert.Equal(t, "chunked", resp.Header.Get("Transfer-Encoding"))
|
||||
resp.Body.Close()
|
||||
|
||||
// The handler creates a client and sends the endpoint message
|
||||
var sessionId string
|
||||
|
||||
// Give the handler time to set up the client
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Check that a client was created
|
||||
server.clientsLock.Lock()
|
||||
assert.Equal(t, 1, len(server.clients))
|
||||
for id := range server.clients {
|
||||
sessionId = id
|
||||
}
|
||||
server.clientsLock.Unlock()
|
||||
|
||||
require.NotEmpty(t, sessionId, "Expected a session ID to be created")
|
||||
|
||||
// Now that we have a session ID, we can test the message endpoint
|
||||
messageBody, _ := json.Marshal(Request{
|
||||
JsonRpc: "2.0",
|
||||
ID: 1,
|
||||
Method: methodInitialize,
|
||||
Params: json.RawMessage(`{}`),
|
||||
})
|
||||
|
||||
// Create a message request
|
||||
reqURL := fmt.Sprintf("/message?%s=%s", sessionIdKey, sessionId)
|
||||
msgReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(messageBody))
|
||||
msgW := newSyncResponseRecorder()
|
||||
|
||||
// Process the message
|
||||
server.handleRequest(msgW, msgReq)
|
||||
|
||||
// Check the response
|
||||
msgResp := msgW.Result()
|
||||
assert.Equal(t, http.StatusAccepted, msgResp.StatusCode)
|
||||
msgResp.Body.Close() // Ensure response body is closed
|
||||
}
|
||||
|
||||
// TestHandlerResponseFlow tests the flow of a full request/response cycle
|
||||
func TestHandlerResponseFlow(t *testing.T) {
|
||||
// Create a mock server for testing
|
||||
server := &sseMcpServer{
|
||||
conf: McpConf{},
|
||||
clients: map[string]*mcpClient{
|
||||
"test-session": {
|
||||
id: "test-session",
|
||||
channel: make(chan string, 10),
|
||||
initialized: true,
|
||||
},
|
||||
},
|
||||
tools: make(map[string]Tool),
|
||||
prompts: make(map[string]Prompt),
|
||||
resources: make(map[string]Resource),
|
||||
}
|
||||
|
||||
// Register test resources
|
||||
server.RegisterTool(Tool{
|
||||
Name: "test.tool",
|
||||
Description: "Test tool",
|
||||
InputSchema: InputSchema{Type: "object"},
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return "tool result", nil
|
||||
},
|
||||
})
|
||||
|
||||
server.RegisterPrompt(Prompt{
|
||||
Name: "test.prompt",
|
||||
Description: "Test prompt",
|
||||
})
|
||||
|
||||
server.RegisterResource(Resource{
|
||||
Name: "test.resource",
|
||||
URI: "http://example.com",
|
||||
Description: "Test resource",
|
||||
})
|
||||
|
||||
// Create a request with session ID parameter
|
||||
reqURL := fmt.Sprintf("/message?%s=%s", sessionIdKey, "test-session")
|
||||
|
||||
// Test tools/list request
|
||||
toolsListBody, _ := json.Marshal(Request{
|
||||
JsonRpc: "2.0",
|
||||
ID: 1,
|
||||
Method: methodToolsList,
|
||||
Params: json.RawMessage(`{}`),
|
||||
})
|
||||
|
||||
toolsReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(toolsListBody))
|
||||
toolsW := newSyncResponseRecorder()
|
||||
|
||||
// Process the request
|
||||
server.handleRequest(toolsW, toolsReq)
|
||||
|
||||
// Check the response code
|
||||
toolsResp := toolsW.Result()
|
||||
assert.Equal(t, http.StatusAccepted, toolsResp.StatusCode)
|
||||
toolsResp.Body.Close()
|
||||
|
||||
// Check the channel message
|
||||
client := server.clients["test-session"]
|
||||
select {
|
||||
case message := <-client.channel:
|
||||
assert.Contains(t, message, `"tools":[{"name":"test.tool"`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for tools/list response")
|
||||
}
|
||||
|
||||
// Test prompts/list request
|
||||
promptsListBody, _ := json.Marshal(Request{
|
||||
JsonRpc: "2.0",
|
||||
ID: 2,
|
||||
Method: methodPromptsList,
|
||||
Params: json.RawMessage(`{}`),
|
||||
})
|
||||
|
||||
promptsReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(promptsListBody))
|
||||
promptsW := newSyncResponseRecorder()
|
||||
|
||||
// Process the request
|
||||
server.handleRequest(promptsW, promptsReq)
|
||||
|
||||
// Check the response code
|
||||
promptsResp := promptsW.Result()
|
||||
assert.Equal(t, http.StatusAccepted, promptsResp.StatusCode)
|
||||
promptsResp.Body.Close()
|
||||
|
||||
// Check the channel message
|
||||
select {
|
||||
case message := <-client.channel:
|
||||
assert.Contains(t, message, `"prompts":[{"name":"test.prompt"`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for prompts/list response")
|
||||
}
|
||||
|
||||
// Test resources/list request
|
||||
resourcesListBody, _ := json.Marshal(Request{
|
||||
JsonRpc: "2.0",
|
||||
ID: 3,
|
||||
Method: methodResourcesList,
|
||||
Params: json.RawMessage(`{}`),
|
||||
})
|
||||
|
||||
resourcesReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(resourcesListBody))
|
||||
resourcesW := newSyncResponseRecorder()
|
||||
|
||||
// Process the request
|
||||
server.handleRequest(resourcesW, resourcesReq)
|
||||
|
||||
// Check the response code
|
||||
resourcesResp := resourcesW.Result()
|
||||
assert.Equal(t, http.StatusAccepted, resourcesResp.StatusCode)
|
||||
resourcesResp.Body.Close()
|
||||
|
||||
// Check the channel message
|
||||
select {
|
||||
case message := <-client.channel:
|
||||
assert.Contains(t, message, `"name":"test.resource"`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for resources/list response")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessListMethods tests the list processing methods with pagination
|
||||
func TestProcessListMethods(t *testing.T) {
|
||||
server := &sseMcpServer{
|
||||
tools: make(map[string]Tool),
|
||||
prompts: make(map[string]Prompt),
|
||||
resources: make(map[string]Resource),
|
||||
}
|
||||
|
||||
// Add some test data
|
||||
for i := 1; i <= 5; i++ {
|
||||
tool := Tool{
|
||||
Name: fmt.Sprintf("tool%d", i),
|
||||
Description: fmt.Sprintf("Tool %d", i),
|
||||
InputSchema: InputSchema{Type: "object"},
|
||||
}
|
||||
server.tools[tool.Name] = tool
|
||||
|
||||
prompt := Prompt{
|
||||
Name: fmt.Sprintf("prompt%d", i),
|
||||
Description: fmt.Sprintf("Prompt %d", i),
|
||||
}
|
||||
server.prompts[prompt.Name] = prompt
|
||||
|
||||
resource := Resource{
|
||||
Name: fmt.Sprintf("resource%d", i),
|
||||
URI: fmt.Sprintf("http://example.com/%d", i),
|
||||
Description: fmt.Sprintf("Resource %d", i),
|
||||
}
|
||||
server.resources[resource.Name] = resource
|
||||
}
|
||||
|
||||
// Create a test client
|
||||
client := &mcpClient{
|
||||
id: "test-client",
|
||||
channel: make(chan string, 10),
|
||||
initialized: true,
|
||||
}
|
||||
|
||||
// Test processListTools
|
||||
req := Request{
|
||||
JsonRpc: "2.0",
|
||||
ID: 1,
|
||||
Method: methodToolsList,
|
||||
Params: json.RawMessage(`{"cursor": "", "_meta": {"progressToken": "token1"}}`),
|
||||
}
|
||||
|
||||
server.processListTools(context.Background(), client, req)
|
||||
|
||||
// Read response
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
assert.Contains(t, response, `"tools":`)
|
||||
assert.Contains(t, response, `"progressToken":"token1"`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for tools/list response")
|
||||
}
|
||||
|
||||
// Test processListPrompts
|
||||
req.ID = 2
|
||||
req.Method = methodPromptsList
|
||||
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
||||
server.processListPrompts(context.Background(), client, req)
|
||||
|
||||
// Read response
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
assert.Contains(t, response, `"prompts":`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for prompts/list response")
|
||||
}
|
||||
|
||||
// Test processListResources
|
||||
req.ID = 3
|
||||
req.Method = methodResourcesList
|
||||
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
||||
server.processListResources(context.Background(), client, req)
|
||||
|
||||
// Read response
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
assert.Contains(t, response, `"resources":`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for resources/list response")
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorResponseHandling tests error handling in the server
|
||||
func TestErrorResponseHandling(t *testing.T) {
|
||||
server := &sseMcpServer{
|
||||
tools: make(map[string]Tool),
|
||||
prompts: make(map[string]Prompt),
|
||||
resources: make(map[string]Resource),
|
||||
}
|
||||
|
||||
// Create a test client
|
||||
client := &mcpClient{
|
||||
id: "test-client",
|
||||
channel: make(chan string, 10),
|
||||
initialized: true,
|
||||
}
|
||||
|
||||
// Test invalid method
|
||||
req := Request{
|
||||
JsonRpc: "2.0",
|
||||
ID: 1,
|
||||
Method: "invalid_method",
|
||||
Params: json.RawMessage(`{}`),
|
||||
}
|
||||
|
||||
// Mock handleRequest by directly calling error handler
|
||||
server.sendErrorResponse(context.Background(), client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||
|
||||
// Check response
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
assert.Contains(t, response, `"error":{"code":-32601,"message":"Method not found"}`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for error response")
|
||||
}
|
||||
|
||||
// Test invalid tool
|
||||
toolReq := Request{
|
||||
JsonRpc: "2.0",
|
||||
ID: 2,
|
||||
Method: methodToolsCall,
|
||||
Params: json.RawMessage(`{"name":"non_existent_tool"}`),
|
||||
}
|
||||
|
||||
// Call process method directly
|
||||
server.processToolCall(context.Background(), client, toolReq)
|
||||
|
||||
// Check response
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
assert.Contains(t, response, `"error":{"code":-32602,"message":"Tool 'non_existent_tool' not found"}`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for error response")
|
||||
}
|
||||
|
||||
// Test invalid prompt
|
||||
promptReq := Request{
|
||||
JsonRpc: "2.0",
|
||||
ID: 3,
|
||||
Method: methodPromptsGet,
|
||||
Params: json.RawMessage(`{"name":"non_existent_prompt"}`),
|
||||
}
|
||||
|
||||
// Call process method directly
|
||||
server.processGetPrompt(context.Background(), client, promptReq)
|
||||
|
||||
// Check response
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
assert.Contains(t, response, `"error":{"code":-32602,"message":"Prompt 'non_existent_prompt' not found"}`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for error response")
|
||||
}
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/mapping"
|
||||
)
|
||||
|
||||
// ParseArguments parses the arguments and populates the request object
|
||||
func ParseArguments(args any, req any) error {
|
||||
switch arguments := args.(type) {
|
||||
case map[string]string:
|
||||
m := make(map[string]any, len(arguments))
|
||||
for k, v := range arguments {
|
||||
m[k] = v
|
||||
}
|
||||
return mapping.UnmarshalJsonMap(m, req, mapping.WithStringValues())
|
||||
case map[string]any:
|
||||
return mapping.UnmarshalJsonMap(arguments, req)
|
||||
default:
|
||||
return fmt.Errorf("unsupported argument type: %T", arguments)
|
||||
}
|
||||
}
|
||||
@@ -1,139 +0,0 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestParseArguments_MapStringString tests parsing map[string]string arguments
|
||||
func TestParseArguments_MapStringString(t *testing.T) {
|
||||
// Sample request struct to populate
|
||||
type requestStruct struct {
|
||||
Name string `json:"name"`
|
||||
Message string `json:"message"`
|
||||
Count int `json:"count"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// Create test arguments
|
||||
args := map[string]string{
|
||||
"name": "test-name",
|
||||
"message": "hello world",
|
||||
"count": "42",
|
||||
"enabled": "true",
|
||||
}
|
||||
|
||||
// Create a target object to populate
|
||||
var req requestStruct
|
||||
|
||||
// Parse the arguments
|
||||
err := ParseArguments(args, &req)
|
||||
|
||||
// Verify results
|
||||
assert.NoError(t, err, "Should parse map[string]string without error")
|
||||
assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed")
|
||||
assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed")
|
||||
assert.Equal(t, 42, req.Count, "Count should be correctly parsed to int")
|
||||
assert.True(t, req.Enabled, "Enabled should be correctly parsed to bool")
|
||||
}
|
||||
|
||||
// TestParseArguments_MapStringAny tests parsing map[string]any arguments
|
||||
func TestParseArguments_MapStringAny(t *testing.T) {
|
||||
// Sample request struct to populate
|
||||
type requestStruct struct {
|
||||
Name string `json:"name"`
|
||||
Message string `json:"message"`
|
||||
Count int `json:"count"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Tags []string `json:"tags"`
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
}
|
||||
|
||||
// Create test arguments with mixed types
|
||||
args := map[string]any{
|
||||
"name": "test-name",
|
||||
"message": "hello world",
|
||||
"count": 42, // note: this is already an int
|
||||
"enabled": true, // note: this is already a bool
|
||||
"tags": []string{"tag1", "tag2"},
|
||||
"metadata": map[string]string{
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
},
|
||||
}
|
||||
|
||||
// Create a target object to populate
|
||||
var req requestStruct
|
||||
|
||||
// Parse the arguments
|
||||
err := ParseArguments(args, &req)
|
||||
|
||||
// Verify results
|
||||
assert.NoError(t, err, "Should parse map[string]any without error")
|
||||
assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed")
|
||||
assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed")
|
||||
assert.Equal(t, 42, req.Count, "Count should be correctly parsed")
|
||||
assert.True(t, req.Enabled, "Enabled should be correctly parsed")
|
||||
assert.Equal(t, []string{"tag1", "tag2"}, req.Tags, "Tags should be correctly parsed")
|
||||
assert.Equal(t, map[string]string{
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
}, req.Metadata, "Metadata should be correctly parsed")
|
||||
}
|
||||
|
||||
// TestParseArguments_UnsupportedType tests parsing with an unsupported type
|
||||
func TestParseArguments_UnsupportedType(t *testing.T) {
|
||||
// Sample request struct to populate
|
||||
type requestStruct struct {
|
||||
Name string `json:"name"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// Use an unsupported argument type (slice)
|
||||
args := []string{"not", "a", "map"}
|
||||
|
||||
// Create a target object to populate
|
||||
var req requestStruct
|
||||
|
||||
// Parse the arguments
|
||||
err := ParseArguments(args, &req)
|
||||
|
||||
// Verify error is returned with correct message
|
||||
assert.Error(t, err, "Should return error for unsupported type")
|
||||
assert.Contains(t, err.Error(), "unsupported argument type", "Error should mention unsupported type")
|
||||
assert.Contains(t, err.Error(), "[]string", "Error should include the actual type")
|
||||
}
|
||||
|
||||
// TestParseArguments_EmptyMap tests parsing with empty maps
|
||||
func TestParseArguments_EmptyMap(t *testing.T) {
|
||||
// Sample request struct to populate
|
||||
type requestStruct struct {
|
||||
Name string `json:"name,optional"`
|
||||
Message string `json:"message,optional"`
|
||||
}
|
||||
|
||||
// Test empty map[string]string
|
||||
t.Run("EmptyMapStringString", func(t *testing.T) {
|
||||
args := map[string]string{}
|
||||
var req requestStruct
|
||||
|
||||
err := ParseArguments(args, &req)
|
||||
|
||||
assert.NoError(t, err, "Should parse empty map[string]string without error")
|
||||
assert.Empty(t, req.Name, "Name should be empty string")
|
||||
assert.Empty(t, req.Message, "Message should be empty string")
|
||||
})
|
||||
|
||||
// Test empty map[string]any
|
||||
t.Run("EmptyMapStringAny", func(t *testing.T) {
|
||||
args := map[string]any{}
|
||||
var req requestStruct
|
||||
|
||||
err := ParseArguments(args, &req)
|
||||
|
||||
assert.NoError(t, err, "Should parse empty map[string]any without error")
|
||||
assert.Empty(t, req.Name, "Name should be empty string")
|
||||
assert.Empty(t, req.Message, "Message should be empty string")
|
||||
})
|
||||
}
|
||||
1012
mcp/readme.md
1012
mcp/readme.md
File diff suppressed because it is too large
Load Diff
990
mcp/server.go
990
mcp/server.go
File diff suppressed because it is too large
Load Diff
3691
mcp/server_test.go
3691
mcp/server_test.go
File diff suppressed because it is too large
Load Diff
392
mcp/types.go
392
mcp/types.go
@@ -2,316 +2,96 @@ package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/zeromicro/go-zero/rest"
|
||||
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// Cursor is an opaque token used for pagination
|
||||
type Cursor string
|
||||
// Re-export commonly used SDK types for convenience
|
||||
type (
|
||||
// Tool types
|
||||
Tool = sdkmcp.Tool
|
||||
CallToolParams = sdkmcp.CallToolParams
|
||||
CallToolResult = sdkmcp.CallToolResult
|
||||
CallToolRequest = sdkmcp.CallToolRequest
|
||||
|
||||
// Request represents a generic MCP request following JSON-RPC 2.0 specification
|
||||
type Request struct {
|
||||
SessionId string `form:"session_id"` // Session identifier for client tracking
|
||||
JsonRpc string `json:"jsonrpc"` // Must be "2.0" per JSON-RPC spec
|
||||
ID any `json:"id"` // Request identifier for matching responses
|
||||
Method string `json:"method"` // Method name to invoke
|
||||
Params json.RawMessage `json:"params"` // Parameters for the method
|
||||
}
|
||||
// Content types
|
||||
Content = sdkmcp.Content
|
||||
TextContent = sdkmcp.TextContent
|
||||
ImageContent = sdkmcp.ImageContent
|
||||
AudioContent = sdkmcp.AudioContent
|
||||
|
||||
func (r Request) isNotification() (bool, error) {
|
||||
switch val := r.ID.(type) {
|
||||
case int:
|
||||
return val == 0, nil
|
||||
case int64:
|
||||
return val == 0, nil
|
||||
case float64:
|
||||
return val == 0.0, nil
|
||||
case string:
|
||||
return len(val) == 0, nil
|
||||
case nil:
|
||||
return true, nil
|
||||
default:
|
||||
return false, fmt.Errorf("invalid type %T", val)
|
||||
// Prompt types
|
||||
Prompt = sdkmcp.Prompt
|
||||
PromptMessage = sdkmcp.PromptMessage
|
||||
GetPromptParams = sdkmcp.GetPromptParams
|
||||
GetPromptResult = sdkmcp.GetPromptResult
|
||||
|
||||
// Resource types
|
||||
Resource = sdkmcp.Resource
|
||||
ResourceContents = sdkmcp.ResourceContents
|
||||
ReadResourceParams = sdkmcp.ReadResourceParams
|
||||
ReadResourceResult = sdkmcp.ReadResourceResult
|
||||
|
||||
// Session and server types
|
||||
Server = sdkmcp.Server
|
||||
ServerSession = sdkmcp.ServerSession
|
||||
ServerOptions = sdkmcp.ServerOptions
|
||||
Implementation = sdkmcp.Implementation
|
||||
|
||||
// Transport types
|
||||
SSEHandler = sdkmcp.SSEHandler
|
||||
StreamableHTTPHandler = sdkmcp.StreamableHTTPHandler
|
||||
)
|
||||
|
||||
// ToolHandler is a generic function signature for tool handlers.
|
||||
// Handlers should accept context, request, and typed arguments, and return
|
||||
// a result, metadata, and error.
|
||||
//
|
||||
// Deprecated: Use ToolHandlerFor directly from the SDK types.
|
||||
type ToolHandler[Args any, Meta any] func(
|
||||
ctx context.Context,
|
||||
req *CallToolRequest,
|
||||
args Args,
|
||||
) (*CallToolResult, Meta, error)
|
||||
|
||||
// PromptHandler is a function signature for prompt handlers.
|
||||
type PromptHandler func(
|
||||
ctx context.Context,
|
||||
req *sdkmcp.GetPromptRequest,
|
||||
args map[string]string,
|
||||
) (*GetPromptResult, error)
|
||||
|
||||
// ResourceHandler is a function signature for resource handlers.
|
||||
type ResourceHandler func(
|
||||
ctx context.Context,
|
||||
req *sdkmcp.ReadResourceRequest,
|
||||
uri string,
|
||||
) (*ReadResourceResult, error)
|
||||
|
||||
// AddTool registers a tool with the MCP server using type-safe generics.
|
||||
// The SDK automatically generates JSON schema from the Args struct tags.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type GreetArgs struct {
|
||||
// Name string `json:"name" jsonschema:"description=Name to greet"`
|
||||
// }
|
||||
//
|
||||
// tool := &mcp.Tool{
|
||||
// Name: "greet",
|
||||
// Description: "Greet someone",
|
||||
// }
|
||||
//
|
||||
// handler := func(ctx context.Context, req *mcp.CallToolRequest, args GreetArgs) (*mcp.CallToolResult, any, error) {
|
||||
// return &mcp.CallToolResult{
|
||||
// Content: []mcp.Content{&mcp.TextContent{Text: "Hello " + args.Name}},
|
||||
// }, nil, nil
|
||||
// }
|
||||
//
|
||||
// mcp.AddTool(server, tool, handler)
|
||||
func AddTool[In, Out any](server McpServer, tool *Tool, handler func(context.Context, *CallToolRequest, In) (*CallToolResult, Out, error)) {
|
||||
// Access internal server - only works with mcpServerImpl
|
||||
if impl, ok := server.(*mcpServerImpl); ok {
|
||||
sdkmcp.AddTool(impl.mcpServer, tool, handler)
|
||||
}
|
||||
}
|
||||
|
||||
type PaginatedParams struct {
|
||||
Cursor string `json:"cursor"`
|
||||
Meta struct {
|
||||
ProgressToken any `json:"progressToken"`
|
||||
} `json:"_meta"`
|
||||
}
|
||||
|
||||
// Result is the base interface for all results
|
||||
type Result struct {
|
||||
Meta map[string]any `json:"_meta,omitempty"` // Optional metadata
|
||||
}
|
||||
|
||||
// PaginatedResult is a base for results that support pagination
|
||||
type PaginatedResult struct {
|
||||
Result
|
||||
NextCursor Cursor `json:"nextCursor,omitempty"` // Opaque token for fetching next page
|
||||
}
|
||||
|
||||
// ListToolsResult represents the response to a tools/list request
|
||||
type ListToolsResult struct {
|
||||
PaginatedResult
|
||||
Tools []Tool `json:"tools"` // List of available tools
|
||||
}
|
||||
|
||||
// Message Content Types
|
||||
|
||||
// RoleType represents the sender or recipient of messages in a conversation
|
||||
type RoleType string
|
||||
|
||||
// PromptArgument defines a single argument that can be passed to a prompt
|
||||
type PromptArgument struct {
|
||||
Name string `json:"name"` // Argument name
|
||||
Description string `json:"description,omitempty"` // Human-readable description
|
||||
Required bool `json:"required,omitempty"` // Whether this argument is required
|
||||
}
|
||||
|
||||
// PromptHandler is a function that dynamically generates prompt content
|
||||
type PromptHandler func(ctx context.Context, args map[string]string) ([]PromptMessage, error)
|
||||
|
||||
// Prompt represents an MCP Prompt definition
|
||||
type Prompt struct {
|
||||
Name string `json:"name"` // Unique identifier for the prompt
|
||||
Description string `json:"description,omitempty"` // Human-readable description
|
||||
Arguments []PromptArgument `json:"arguments,omitempty"` // Arguments for customization
|
||||
Content string `json:"-"` // Static content (internal use only)
|
||||
Handler PromptHandler `json:"-"` // Handler for dynamic content generation
|
||||
}
|
||||
|
||||
// PromptMessage represents a message in a conversation
|
||||
type PromptMessage struct {
|
||||
Role RoleType `json:"role"` // Message sender role
|
||||
Content any `json:"content"` // Message content (TextContent, ImageContent, etc.)
|
||||
}
|
||||
|
||||
// TextContent represents text content in a message
|
||||
type TextContent struct {
|
||||
Text string `json:"text"` // The text content
|
||||
Annotations *Annotations `json:"annotations,omitempty"` // Optional annotations
|
||||
}
|
||||
|
||||
type typedTextContent struct {
|
||||
Type string `json:"type"`
|
||||
TextContent
|
||||
}
|
||||
|
||||
// ImageContent represents image data in a message
|
||||
type ImageContent struct {
|
||||
Data string `json:"data"` // Base64-encoded image data
|
||||
MimeType string `json:"mimeType"` // MIME type (e.g., "image/png")
|
||||
}
|
||||
|
||||
type typedImageContent struct {
|
||||
Type string `json:"type"`
|
||||
ImageContent
|
||||
}
|
||||
|
||||
// AudioContent represents audio data in a message
|
||||
type AudioContent struct {
|
||||
Data string `json:"data"` // Base64-encoded audio data
|
||||
MimeType string `json:"mimeType"` // MIME type (e.g., "audio/mp3")
|
||||
}
|
||||
|
||||
type typedAudioContent struct {
|
||||
Type string `json:"type"`
|
||||
AudioContent
|
||||
}
|
||||
|
||||
// FileContent represents file content
|
||||
type FileContent struct {
|
||||
URI string `json:"uri"` // URI identifying the file
|
||||
MimeType string `json:"mimeType"` // MIME type of the file
|
||||
Text string `json:"text"` // File content as text
|
||||
}
|
||||
|
||||
// EmbeddedResource represents a resource embedded in a message
|
||||
type EmbeddedResource struct {
|
||||
Type string `json:"type"` // Always "resource"
|
||||
Resource ResourceContent `json:"resource"` // The resource data
|
||||
}
|
||||
|
||||
// Annotations provides additional metadata for content
|
||||
type Annotations struct {
|
||||
Audience []RoleType `json:"audience,omitempty"` // Who should see this content
|
||||
Priority *float64 `json:"priority,omitempty"` // Optional priority (0-1)
|
||||
}
|
||||
|
||||
// Tool-related Types
|
||||
|
||||
// ToolHandler is a function that handles tool calls
|
||||
type ToolHandler func(ctx context.Context, params map[string]any) (any, error)
|
||||
|
||||
// Tool represents a Model Context Protocol Tool definition
|
||||
type Tool struct {
|
||||
Name string `json:"name"` // Unique identifier for the tool
|
||||
Description string `json:"description"` // Human-readable description
|
||||
InputSchema InputSchema `json:"inputSchema"` // JSON Schema for parameters
|
||||
Handler ToolHandler `json:"-"` // Not sent to clients
|
||||
}
|
||||
|
||||
// InputSchema represents tool's input schema in JSON Schema format
|
||||
type InputSchema struct {
|
||||
Type string `json:"type"`
|
||||
Properties map[string]any `json:"properties"` // Property definitions
|
||||
Required []string `json:"required,omitempty"` // List of required properties
|
||||
}
|
||||
|
||||
// CallToolResult represents a tool call result that conforms to the MCP schema
|
||||
type CallToolResult struct {
|
||||
Result
|
||||
Content []any `json:"content"` // Content items (text, images, etc.)
|
||||
IsError bool `json:"isError,omitempty"` // True if tool execution failed
|
||||
}
|
||||
|
||||
// Resource represents a Model Context Protocol Resource definition
|
||||
type Resource struct {
|
||||
URI string `json:"uri"` // Unique resource identifier (RFC3986)
|
||||
Name string `json:"name"` // Human-readable name
|
||||
Description string `json:"description,omitempty"` // Optional description
|
||||
MimeType string `json:"mimeType,omitempty"` // Optional MIME type
|
||||
Handler ResourceHandler `json:"-"` // Internal handler not sent to clients
|
||||
}
|
||||
|
||||
// ResourceHandler is a function that handles resource read requests
|
||||
type ResourceHandler func(ctx context.Context) (ResourceContent, error)
|
||||
|
||||
// ResourceContent represents the content of a resource
|
||||
type ResourceContent struct {
|
||||
URI string `json:"uri"` // Resource URI (required)
|
||||
MimeType string `json:"mimeType,omitempty"` // MIME type of the resource
|
||||
Text string `json:"text,omitempty"` // Text content (if available)
|
||||
Blob string `json:"blob,omitempty"` // Base64 encoded blob data (if available)
|
||||
}
|
||||
|
||||
// ResourcesListResult represents the response to a resources/list request
|
||||
type ResourcesListResult struct {
|
||||
PaginatedResult
|
||||
Resources []Resource `json:"resources"` // List of available resources
|
||||
}
|
||||
|
||||
// ResourceReadParams contains parameters for a resources/read request
|
||||
type ResourceReadParams struct {
|
||||
URI string `json:"uri"` // URI of the resource to read
|
||||
}
|
||||
|
||||
// ResourceReadResult contains the result of a resources/read request
|
||||
type ResourceReadResult struct {
|
||||
Result
|
||||
Contents []ResourceContent `json:"contents"` // Array of resource content
|
||||
}
|
||||
|
||||
// ResourceSubscribeParams contains parameters for a resources/subscribe request
|
||||
type ResourceSubscribeParams struct {
|
||||
URI string `json:"uri"` // URI of the resource to subscribe to
|
||||
}
|
||||
|
||||
// ResourceUpdateNotification represents a notification about a resource update
|
||||
type ResourceUpdateNotification struct {
|
||||
URI string `json:"uri"` // URI of the updated resource
|
||||
Content ResourceContent `json:"content"` // New resource content
|
||||
}
|
||||
|
||||
// Client and Server Types
|
||||
|
||||
// mcpClient represents an SSE client connection
|
||||
type mcpClient struct {
|
||||
id string // Unique client identifier
|
||||
channel chan string // Channel for sending SSE messages
|
||||
initialized bool // Tracks if client has sent notifications/initialized
|
||||
}
|
||||
|
||||
// McpServer defines the interface for Model Context Protocol servers
|
||||
type McpServer interface {
|
||||
Start()
|
||||
Stop()
|
||||
RegisterTool(tool Tool) error
|
||||
RegisterPrompt(prompt Prompt)
|
||||
RegisterResource(resource Resource)
|
||||
}
|
||||
|
||||
// sseMcpServer implements the McpServer interface using SSE
|
||||
type sseMcpServer struct {
|
||||
conf McpConf
|
||||
server *rest.Server
|
||||
clients map[string]*mcpClient
|
||||
clientsLock sync.Mutex
|
||||
tools map[string]Tool
|
||||
toolsLock sync.Mutex
|
||||
prompts map[string]Prompt
|
||||
promptsLock sync.Mutex
|
||||
resources map[string]Resource
|
||||
resourcesLock sync.Mutex
|
||||
}
|
||||
|
||||
// Response Types
|
||||
|
||||
// errorObj represents a JSON-RPC error object
|
||||
type errorObj struct {
|
||||
Code int `json:"code"` // Error code
|
||||
Message string `json:"message"` // Error message
|
||||
}
|
||||
|
||||
// Response represents a JSON-RPC response
|
||||
type Response struct {
|
||||
JsonRpc string `json:"jsonrpc"` // Always "2.0"
|
||||
ID any `json:"id"` // Same as request ID
|
||||
Result any `json:"result"` // Result object (null if error)
|
||||
Error *errorObj `json:"error,omitempty"` // Error object (null if success)
|
||||
}
|
||||
|
||||
// Server Information Types
|
||||
|
||||
// serverInfo provides information about the server
|
||||
type serverInfo struct {
|
||||
Name string `json:"name"` // Server name
|
||||
Version string `json:"version"` // Server version
|
||||
}
|
||||
|
||||
// capabilities describes the server's capabilities
|
||||
type capabilities struct {
|
||||
Logging struct{} `json:"logging"`
|
||||
Prompts struct {
|
||||
ListChanged bool `json:"listChanged"` // Server will notify on prompt changes
|
||||
} `json:"prompts"`
|
||||
Resources struct {
|
||||
Subscribe bool `json:"subscribe"` // Server supports resource subscriptions
|
||||
ListChanged bool `json:"listChanged"` // Server will notify on resource changes
|
||||
} `json:"resources"`
|
||||
Tools struct {
|
||||
ListChanged bool `json:"listChanged"` // Server will notify on tool changes
|
||||
} `json:"tools"`
|
||||
}
|
||||
|
||||
// initializationResponse is sent in response to an initialize request
|
||||
type initializationResponse struct {
|
||||
ProtocolVersion string `json:"protocolVersion"` // Protocol version
|
||||
Capabilities capabilities `json:"capabilities"` // Server capabilities
|
||||
ServerInfo serverInfo `json:"serverInfo"` // Server information
|
||||
}
|
||||
|
||||
// ToolCallParams contains the parameters for a tool call
|
||||
type ToolCallParams struct {
|
||||
Name string `json:"name"` // Tool name
|
||||
Parameters map[string]any `json:"parameters"` // Tool parameters
|
||||
}
|
||||
|
||||
// ToolResult contains the result of a tool execution
|
||||
type ToolResult struct {
|
||||
Type string `json:"type"` // Content type (text, image, etc.)
|
||||
Content any `json:"content"` // Result content
|
||||
}
|
||||
|
||||
// errorMessage represents a detailed error message
|
||||
type errorMessage struct {
|
||||
Code int `json:"code"` // Error code
|
||||
Message string `json:"message"` // Error message
|
||||
Data any `json:",omitempty"` // Additional error data
|
||||
}
|
||||
|
||||
@@ -1,271 +0,0 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestResponseMarshaling(t *testing.T) {
|
||||
// Test that the Response struct marshals correctly
|
||||
resp := Response{
|
||||
JsonRpc: "2.0",
|
||||
ID: 123,
|
||||
Result: map[string]string{
|
||||
"key": "value",
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(resp)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"jsonrpc":"2.0"`)
|
||||
assert.Contains(t, string(data), `"id":123`)
|
||||
assert.Contains(t, string(data), `"result":{"key":"value"}`)
|
||||
|
||||
// Test response with error
|
||||
respWithError := Response{
|
||||
JsonRpc: "2.0",
|
||||
ID: 456,
|
||||
Error: &errorObj{
|
||||
Code: errCodeInvalidRequest,
|
||||
Message: "Invalid Request",
|
||||
},
|
||||
}
|
||||
|
||||
data, err = json.Marshal(respWithError)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"jsonrpc":"2.0"`)
|
||||
assert.Contains(t, string(data), `"id":456`)
|
||||
assert.Contains(t, string(data), `"error":{"code":-32600,"message":"Invalid Request"}`)
|
||||
}
|
||||
|
||||
func TestRequestUnmarshaling(t *testing.T) {
|
||||
// Test that the Request struct unmarshals correctly
|
||||
jsonStr := `{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 789,
|
||||
"method": "test_method",
|
||||
"params": {"key": "value"}
|
||||
}`
|
||||
|
||||
var req Request
|
||||
err := json.Unmarshal([]byte(jsonStr), &req)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "2.0", req.JsonRpc)
|
||||
assert.Equal(t, float64(789), req.ID)
|
||||
assert.Equal(t, "test_method", req.Method)
|
||||
|
||||
// Check params unmarshaled correctly
|
||||
var params map[string]string
|
||||
err = json.Unmarshal(req.Params, ¶ms)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "value", params["key"])
|
||||
}
|
||||
|
||||
func TestToolStructs(t *testing.T) {
|
||||
// Test Tool struct
|
||||
tool := Tool{
|
||||
Name: "test.tool",
|
||||
Description: "A test tool",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"input": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Input parameter",
|
||||
},
|
||||
},
|
||||
Required: []string{"input"},
|
||||
},
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return "result", nil
|
||||
},
|
||||
}
|
||||
|
||||
// Verify fields are correct
|
||||
assert.Equal(t, "test.tool", tool.Name)
|
||||
assert.Equal(t, "A test tool", tool.Description)
|
||||
assert.Equal(t, "object", tool.InputSchema.Type)
|
||||
assert.Contains(t, tool.InputSchema.Properties, "input")
|
||||
propMap, ok := tool.InputSchema.Properties["input"].(map[string]any)
|
||||
assert.True(t, ok, "Property should be a map")
|
||||
assert.Equal(t, "string", propMap["type"])
|
||||
assert.NotNil(t, tool.Handler)
|
||||
|
||||
// Verify JSON marshalling (which should exclude Handler function)
|
||||
data, err := json.Marshal(tool)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"name":"test.tool"`)
|
||||
assert.Contains(t, string(data), `"description":"A test tool"`)
|
||||
assert.Contains(t, string(data), `"inputSchema":`)
|
||||
assert.NotContains(t, string(data), `"Handler":`)
|
||||
}
|
||||
|
||||
func TestPromptStructs(t *testing.T) {
|
||||
// Test Prompt struct
|
||||
prompt := Prompt{
|
||||
Name: "test.prompt",
|
||||
Description: "A test prompt description",
|
||||
}
|
||||
|
||||
// Verify fields are correct
|
||||
assert.Equal(t, "test.prompt", prompt.Name)
|
||||
assert.Equal(t, "A test prompt description", prompt.Description)
|
||||
|
||||
// Verify JSON marshalling
|
||||
data, err := json.Marshal(prompt)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"name":"test.prompt"`)
|
||||
assert.Contains(t, string(data), `"description":"A test prompt description"`)
|
||||
}
|
||||
|
||||
func TestResourceStructs(t *testing.T) {
|
||||
// Test Resource struct
|
||||
resource := Resource{
|
||||
Name: "test.resource",
|
||||
URI: "http://example.com/resource",
|
||||
Description: "A test resource",
|
||||
}
|
||||
|
||||
// Verify fields are correct
|
||||
assert.Equal(t, "test.resource", resource.Name)
|
||||
assert.Equal(t, "http://example.com/resource", resource.URI)
|
||||
assert.Equal(t, "A test resource", resource.Description)
|
||||
|
||||
// Verify JSON marshalling
|
||||
data, err := json.Marshal(resource)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"name":"test.resource"`)
|
||||
assert.Contains(t, string(data), `"uri":"http://example.com/resource"`)
|
||||
assert.Contains(t, string(data), `"description":"A test resource"`)
|
||||
}
|
||||
|
||||
func TestContentTypes(t *testing.T) {
|
||||
// Test TextContent
|
||||
textContent := TextContent{
|
||||
Text: "Sample text",
|
||||
Annotations: &Annotations{
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
Priority: ptr(1.0),
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(textContent)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"text":"Sample text"`)
|
||||
assert.Contains(t, string(data), `"audience":["user","assistant"]`)
|
||||
assert.Contains(t, string(data), `"priority":1`)
|
||||
|
||||
// Test ImageContent
|
||||
imageContent := ImageContent{
|
||||
Data: "base64data",
|
||||
MimeType: "image/png",
|
||||
}
|
||||
|
||||
data, err = json.Marshal(imageContent)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"data":"base64data"`)
|
||||
assert.Contains(t, string(data), `"mimeType":"image/png"`)
|
||||
|
||||
// Test AudioContent
|
||||
audioContent := AudioContent{
|
||||
Data: "base64audio",
|
||||
MimeType: "audio/mp3",
|
||||
}
|
||||
|
||||
data, err = json.Marshal(audioContent)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"data":"base64audio"`)
|
||||
assert.Contains(t, string(data), `"mimeType":"audio/mp3"`)
|
||||
}
|
||||
|
||||
func TestCallToolResult(t *testing.T) {
|
||||
// Test CallToolResult
|
||||
result := CallToolResult{
|
||||
Result: Result{
|
||||
Meta: map[string]any{
|
||||
"progressToken": "token123",
|
||||
},
|
||||
},
|
||||
Content: []interface{}{
|
||||
TextContent{
|
||||
Text: "Sample result",
|
||||
},
|
||||
},
|
||||
IsError: false,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(result)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"_meta":{"progressToken":"token123"}`)
|
||||
assert.Contains(t, string(data), `"content":[{"text":"Sample result"}]`)
|
||||
assert.NotContains(t, string(data), `"isError":`)
|
||||
}
|
||||
|
||||
func TestRequest_isNotification(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
id any
|
||||
want bool
|
||||
wantErr error
|
||||
}{
|
||||
// integer test cases
|
||||
{name: "int zero", id: 0, want: true, wantErr: nil},
|
||||
{name: "int non-zero", id: 1, want: false, wantErr: nil},
|
||||
{name: "int64 zero", id: int64(0), want: true, wantErr: nil},
|
||||
{name: "int64 max", id: int64(9223372036854775807), want: false, wantErr: nil},
|
||||
|
||||
// floating point number test cases
|
||||
{name: "float64 zero", id: float64(0.0), want: true, wantErr: nil},
|
||||
{name: "float64 positive", id: float64(0.000001), want: false, wantErr: nil},
|
||||
{name: "float64 negative", id: float64(-0.000001), want: false, wantErr: nil},
|
||||
{name: "float64 epsilon", id: float64(1e-300), want: false, wantErr: nil},
|
||||
|
||||
// string test cases
|
||||
{name: "empty string", id: "", want: true, wantErr: nil},
|
||||
{name: "non-empty string", id: "abc", want: false, wantErr: nil},
|
||||
{name: "space string", id: " ", want: false, wantErr: nil},
|
||||
{name: "unicode string", id: "こんにちは", want: false, wantErr: nil},
|
||||
|
||||
// special cases
|
||||
{name: "nil", id: nil, want: true, wantErr: nil},
|
||||
|
||||
// logical type test cases
|
||||
{name: "bool true", id: true, want: false, wantErr: errors.New("invalid type bool")},
|
||||
{name: "bool false", id: false, want: false, wantErr: errors.New("invalid type bool")},
|
||||
{name: "struct type", id: struct{}{}, want: false, wantErr: errors.New("invalid type struct {}")},
|
||||
{name: "slice type", id: []int{1, 2, 3}, want: false, wantErr: errors.New("invalid type []int")},
|
||||
{name: "map type", id: map[string]int{"a": 1}, want: false, wantErr: errors.New("invalid type map[string]int")},
|
||||
{name: "pointer type", id: new(int), want: false, wantErr: errors.New("invalid type *int")},
|
||||
{name: "func type", id: func() {}, want: false, wantErr: errors.New("invalid type func()")},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := Request{
|
||||
SessionId: "test-session",
|
||||
JsonRpc: "2.0",
|
||||
ID: tt.id,
|
||||
Method: "testMethod",
|
||||
Params: json.RawMessage(`{}`),
|
||||
}
|
||||
|
||||
got, err := req.isNotification()
|
||||
|
||||
if (err != nil) != (tt.wantErr != nil) {
|
||||
t.Fatalf("error presence mismatch: got error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if err != nil && tt.wantErr != nil && err.Error() != tt.wantErr.Error() {
|
||||
t.Fatalf("error message mismatch:\ngot %q\nwant %q", err.Error(), tt.wantErr.Error())
|
||||
}
|
||||
|
||||
if got != tt.want {
|
||||
t.Errorf("isNotification() = %v, want %v for ID %v (%T)", got, tt.want, tt.id, tt.id)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
107
mcp/util.go
107
mcp/util.go
@@ -1,107 +0,0 @@
|
||||
package mcp
|
||||
|
||||
import "fmt"
|
||||
|
||||
// formatSSEMessage formats a Server-Sent Event message with proper CRLF line endings
|
||||
func formatSSEMessage(event string, data []byte) string {
|
||||
return fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(data))
|
||||
}
|
||||
|
||||
// ptr is a helper function to get a pointer to a value
|
||||
func ptr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
func toTypedContents(contents []any) []any {
|
||||
typedContents := make([]any, len(contents))
|
||||
|
||||
for i, content := range contents {
|
||||
switch v := content.(type) {
|
||||
case TextContent:
|
||||
typedContents[i] = typedTextContent{
|
||||
Type: ContentTypeText,
|
||||
TextContent: v,
|
||||
}
|
||||
case ImageContent:
|
||||
typedContents[i] = typedImageContent{
|
||||
Type: ContentTypeImage,
|
||||
ImageContent: v,
|
||||
}
|
||||
case AudioContent:
|
||||
typedContents[i] = typedAudioContent{
|
||||
Type: ContentTypeAudio,
|
||||
AudioContent: v,
|
||||
}
|
||||
default:
|
||||
typedContents[i] = typedTextContent{
|
||||
Type: ContentTypeText,
|
||||
TextContent: TextContent{
|
||||
Text: fmt.Sprintf("Unknown content type: %T", v),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return typedContents
|
||||
}
|
||||
|
||||
func toTypedPromptMessages(messages []PromptMessage) []PromptMessage {
|
||||
typedMessages := make([]PromptMessage, len(messages))
|
||||
|
||||
for i, msg := range messages {
|
||||
switch v := msg.Content.(type) {
|
||||
case TextContent:
|
||||
typedMessages[i] = PromptMessage{
|
||||
Role: msg.Role,
|
||||
Content: typedTextContent{
|
||||
Type: ContentTypeText,
|
||||
TextContent: v,
|
||||
},
|
||||
}
|
||||
case ImageContent:
|
||||
typedMessages[i] = PromptMessage{
|
||||
Role: msg.Role,
|
||||
Content: typedImageContent{
|
||||
Type: ContentTypeImage,
|
||||
ImageContent: v,
|
||||
},
|
||||
}
|
||||
case AudioContent:
|
||||
typedMessages[i] = PromptMessage{
|
||||
Role: msg.Role,
|
||||
Content: typedAudioContent{
|
||||
Type: ContentTypeAudio,
|
||||
AudioContent: v,
|
||||
},
|
||||
}
|
||||
default:
|
||||
typedMessages[i] = PromptMessage{
|
||||
Role: msg.Role,
|
||||
Content: typedTextContent{
|
||||
Type: ContentTypeText,
|
||||
TextContent: TextContent{
|
||||
Text: fmt.Sprintf("Unknown content type: %T", v),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return typedMessages
|
||||
}
|
||||
|
||||
// validatePromptArguments checks if all required arguments are provided
|
||||
// Returns a list of missing required arguments
|
||||
func validatePromptArguments(prompt Prompt, providedArgs map[string]string) []string {
|
||||
var missingArgs []string
|
||||
|
||||
for _, arg := range prompt.Arguments {
|
||||
if arg.Required {
|
||||
if value, exists := providedArgs[arg.Name]; !exists || len(value) == 0 {
|
||||
missingArgs = append(missingArgs, arg.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return missingArgs
|
||||
}
|
||||
274
mcp/util_test.go
274
mcp/util_test.go
@@ -1,274 +0,0 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type Event struct {
|
||||
Type string
|
||||
Data map[string]any
|
||||
}
|
||||
|
||||
func parseEvent(input string) (*Event, error) {
|
||||
var evt Event
|
||||
var dataStr string
|
||||
|
||||
scanner := bufio.NewScanner(strings.NewReader(input))
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "event:") {
|
||||
evt.Type = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
|
||||
} else if strings.HasPrefix(line, "data:") {
|
||||
dataStr = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(dataStr) > 0 {
|
||||
if err := json.Unmarshal([]byte(dataStr), &evt.Data); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse data: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &evt, nil
|
||||
}
|
||||
|
||||
// TestToTypedPromptMessages tests the toTypedPromptMessages function
|
||||
func TestToTypedPromptMessages(t *testing.T) {
|
||||
// Test with multiple message types in one test
|
||||
t.Run("MixedContentTypes", func(t *testing.T) {
|
||||
// Create test data with different content types
|
||||
messages := []PromptMessage{
|
||||
{
|
||||
Role: RoleUser,
|
||||
Content: TextContent{
|
||||
Text: "Hello, this is a text message",
|
||||
Annotations: &Annotations{
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
Priority: ptr(0.8),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: RoleAssistant,
|
||||
Content: ImageContent{
|
||||
Data: "base64ImageData",
|
||||
MimeType: "image/jpeg",
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: RoleUser,
|
||||
Content: AudioContent{
|
||||
Data: "base64AudioData",
|
||||
MimeType: "audio/mp3",
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "system",
|
||||
Content: "This is a simple string that should be handled as unknown type",
|
||||
},
|
||||
}
|
||||
|
||||
// Call the function
|
||||
result := toTypedPromptMessages(messages)
|
||||
|
||||
// Validate results
|
||||
require.Len(t, result, 4, "Should return the same number of messages")
|
||||
|
||||
// Validate first message (TextContent)
|
||||
msg := result[0]
|
||||
assert.Equal(t, RoleUser, msg.Role, "Role should be preserved")
|
||||
|
||||
// Type assertion using reflection since Content is an interface
|
||||
typed, ok := msg.Content.(typedTextContent)
|
||||
require.True(t, ok, "Should be typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||
assert.Equal(t, "Hello, this is a text message", typed.Text, "Text content should be preserved")
|
||||
require.NotNil(t, typed.Annotations, "Annotations should be preserved")
|
||||
assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved")
|
||||
require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved")
|
||||
assert.Equal(t, 0.8, *typed.Annotations.Priority, "Priority value should be preserved")
|
||||
|
||||
// Validate second message (ImageContent)
|
||||
msg = result[1]
|
||||
assert.Equal(t, RoleAssistant, msg.Role, "Role should be preserved")
|
||||
|
||||
// Type assertion for image content
|
||||
typedImg, ok := msg.Content.(typedImageContent)
|
||||
require.True(t, ok, "Should be typedImageContent")
|
||||
assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image")
|
||||
assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved")
|
||||
assert.Equal(t, "image/jpeg", typedImg.MimeType, "MimeType should be preserved")
|
||||
|
||||
// Validate third message (AudioContent)
|
||||
msg = result[2]
|
||||
assert.Equal(t, RoleUser, msg.Role, "Role should be preserved")
|
||||
|
||||
// Type assertion for audio content
|
||||
typedAudio, ok := msg.Content.(typedAudioContent)
|
||||
require.True(t, ok, "Should be typedAudioContent")
|
||||
assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio")
|
||||
assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved")
|
||||
assert.Equal(t, "audio/mp3", typedAudio.MimeType, "MimeType should be preserved")
|
||||
|
||||
// Validate fourth message (unknown type converted to TextContent)
|
||||
msg = result[3]
|
||||
assert.Equal(t, RoleType("system"), msg.Role, "Role should be preserved")
|
||||
|
||||
// Should be converted to a typedTextContent with error message
|
||||
typedUnknown, ok := msg.Content.(typedTextContent)
|
||||
require.True(t, ok, "Unknown content should be converted to typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text")
|
||||
assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type")
|
||||
assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type")
|
||||
})
|
||||
|
||||
// Test empty input
|
||||
t.Run("EmptyInput", func(t *testing.T) {
|
||||
messages := []PromptMessage{}
|
||||
result := toTypedPromptMessages(messages)
|
||||
assert.Empty(t, result, "Should return empty slice for empty input")
|
||||
})
|
||||
|
||||
// Test with nil annotations
|
||||
t.Run("NilAnnotations", func(t *testing.T) {
|
||||
messages := []PromptMessage{
|
||||
{
|
||||
Role: RoleUser,
|
||||
Content: TextContent{
|
||||
Text: "Text with nil annotations",
|
||||
Annotations: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := toTypedPromptMessages(messages)
|
||||
require.Len(t, result, 1, "Should return one message")
|
||||
|
||||
typed, ok := result[0].Content.(typedTextContent)
|
||||
require.True(t, ok, "Should be typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||
assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved")
|
||||
assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil")
|
||||
})
|
||||
}
|
||||
|
||||
// TestToTypedContents tests the toTypedContents function
|
||||
func TestToTypedContents(t *testing.T) {
|
||||
// Test with multiple content types in one test
|
||||
t.Run("MixedContentTypes", func(t *testing.T) {
|
||||
// Create test data with different content types
|
||||
contents := []any{
|
||||
TextContent{
|
||||
Text: "Hello, this is a text content",
|
||||
Annotations: &Annotations{
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
Priority: ptr(0.7),
|
||||
},
|
||||
},
|
||||
ImageContent{
|
||||
Data: "base64ImageData",
|
||||
MimeType: "image/png",
|
||||
},
|
||||
AudioContent{
|
||||
Data: "base64AudioData",
|
||||
MimeType: "audio/wav",
|
||||
},
|
||||
"This is a simple string that should be handled as unknown type",
|
||||
}
|
||||
|
||||
// Call the function
|
||||
result := toTypedContents(contents)
|
||||
|
||||
// Validate results
|
||||
require.Len(t, result, 4, "Should return the same number of contents")
|
||||
|
||||
// Validate first content (TextContent)
|
||||
typed, ok := result[0].(typedTextContent)
|
||||
require.True(t, ok, "Should be typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||
assert.Equal(t, "Hello, this is a text content", typed.Text, "Text content should be preserved")
|
||||
require.NotNil(t, typed.Annotations, "Annotations should be preserved")
|
||||
assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved")
|
||||
require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved")
|
||||
assert.Equal(t, 0.7, *typed.Annotations.Priority, "Priority value should be preserved")
|
||||
|
||||
// Validate second content (ImageContent)
|
||||
typedImg, ok := result[1].(typedImageContent)
|
||||
require.True(t, ok, "Should be typedImageContent")
|
||||
assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image")
|
||||
assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved")
|
||||
assert.Equal(t, "image/png", typedImg.MimeType, "MimeType should be preserved")
|
||||
|
||||
// Validate third content (AudioContent)
|
||||
typedAudio, ok := result[2].(typedAudioContent)
|
||||
require.True(t, ok, "Should be typedAudioContent")
|
||||
assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio")
|
||||
assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved")
|
||||
assert.Equal(t, "audio/wav", typedAudio.MimeType, "MimeType should be preserved")
|
||||
|
||||
// Validate fourth content (unknown type converted to TextContent)
|
||||
typedUnknown, ok := result[3].(typedTextContent)
|
||||
require.True(t, ok, "Unknown content should be converted to typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text")
|
||||
assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type")
|
||||
assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type")
|
||||
})
|
||||
|
||||
// Test empty input
|
||||
t.Run("EmptyInput", func(t *testing.T) {
|
||||
contents := []any{}
|
||||
result := toTypedContents(contents)
|
||||
assert.Empty(t, result, "Should return empty slice for empty input")
|
||||
})
|
||||
|
||||
// Test with nil annotations
|
||||
t.Run("NilAnnotations", func(t *testing.T) {
|
||||
contents := []any{
|
||||
TextContent{
|
||||
Text: "Text with nil annotations",
|
||||
Annotations: nil,
|
||||
},
|
||||
}
|
||||
|
||||
result := toTypedContents(contents)
|
||||
require.Len(t, result, 1, "Should return one content")
|
||||
|
||||
typed, ok := result[0].(typedTextContent)
|
||||
require.True(t, ok, "Should be typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||
assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved")
|
||||
assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil")
|
||||
})
|
||||
|
||||
// Test with custom struct (should be handled as unknown type)
|
||||
t.Run("CustomStruct", func(t *testing.T) {
|
||||
type CustomContent struct {
|
||||
Data string
|
||||
}
|
||||
|
||||
contents := []any{
|
||||
CustomContent{
|
||||
Data: "custom data",
|
||||
},
|
||||
}
|
||||
|
||||
result := toTypedContents(contents)
|
||||
require.Len(t, result, 1, "Should return one content")
|
||||
|
||||
typed, ok := result[0].(typedTextContent)
|
||||
require.True(t, ok, "Custom struct should be converted to typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||
assert.Contains(t, typed.Text, "Unknown content type:", "Should contain error about unknown type")
|
||||
assert.Contains(t, typed.Text, "CustomContent", "Should mention the actual type")
|
||||
})
|
||||
}
|
||||
149
mcp/vars.go
149
mcp/vars.go
@@ -1,149 +0,0 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
)
|
||||
|
||||
// Protocol constants
|
||||
const (
|
||||
// JSON-RPC version as defined in the specification
|
||||
jsonRpcVersion = "2.0"
|
||||
|
||||
// Session identifier key used in request URLs
|
||||
sessionIdKey = "session_id"
|
||||
|
||||
// progressTokenKey is used to track progress of long-running tasks
|
||||
progressTokenKey = "progressToken"
|
||||
)
|
||||
|
||||
// Server-Sent Events (SSE) event types
|
||||
const (
|
||||
// Standard message event for JSON-RPC responses
|
||||
eventMessage = "message"
|
||||
|
||||
// Endpoint event for sending endpoint URL to clients
|
||||
eventEndpoint = "endpoint"
|
||||
)
|
||||
|
||||
// Content type identifiers
|
||||
const (
|
||||
// ContentTypeObject is object content type
|
||||
ContentTypeObject = "object"
|
||||
|
||||
// ContentTypeText is text content type
|
||||
ContentTypeText = "text"
|
||||
|
||||
// ContentTypeImage is image content type
|
||||
ContentTypeImage = "image"
|
||||
|
||||
// ContentTypeAudio is audio content type
|
||||
ContentTypeAudio = "audio"
|
||||
|
||||
// ContentTypeResource is resource content type
|
||||
ContentTypeResource = "resource"
|
||||
)
|
||||
|
||||
// Collection keys for broadcast events
|
||||
const (
|
||||
// Key for prompts collection
|
||||
keyPrompts = "prompts"
|
||||
|
||||
// Key for resources collection
|
||||
keyResources = "resources"
|
||||
|
||||
// Key for tools collection
|
||||
keyTools = "tools"
|
||||
)
|
||||
|
||||
// JSON-RPC error codes
|
||||
// Standard error codes from JSON-RPC 2.0 spec
|
||||
const (
|
||||
// Invalid JSON was received by the server
|
||||
errCodeInvalidRequest = -32600
|
||||
|
||||
// The method does not exist / is not available
|
||||
errCodeMethodNotFound = -32601
|
||||
|
||||
// Invalid method parameter(s)
|
||||
errCodeInvalidParams = -32602
|
||||
|
||||
// Internal JSON-RPC error
|
||||
errCodeInternalError = -32603
|
||||
|
||||
// Tool execution timed out
|
||||
errCodeTimeout = -32001
|
||||
|
||||
// Resource not found error
|
||||
errCodeResourceNotFound = -32002
|
||||
|
||||
// Client hasn't completed initialization
|
||||
errCodeClientNotInitialized = -32800
|
||||
)
|
||||
|
||||
// User and assistant role definitions
|
||||
const (
|
||||
// RoleUser is the "user" role - the entity asking questions
|
||||
RoleUser RoleType = "user"
|
||||
|
||||
// RoleAssistant is the "assistant" role - the entity providing responses
|
||||
RoleAssistant RoleType = "assistant"
|
||||
)
|
||||
|
||||
// Method names as defined in the MCP specification
|
||||
const (
|
||||
// Initialize the connection between client and server
|
||||
methodInitialize = "initialize"
|
||||
|
||||
// List available tools
|
||||
methodToolsList = "tools/list"
|
||||
|
||||
// Call a specific tool
|
||||
methodToolsCall = "tools/call"
|
||||
|
||||
// List available prompts
|
||||
methodPromptsList = "prompts/list"
|
||||
|
||||
// Get a specific prompt
|
||||
methodPromptsGet = "prompts/get"
|
||||
|
||||
// List available resources
|
||||
methodResourcesList = "resources/list"
|
||||
|
||||
// Read a specific resource
|
||||
methodResourcesRead = "resources/read"
|
||||
|
||||
// Subscribe to resource updates
|
||||
methodResourcesSubscribe = "resources/subscribe"
|
||||
|
||||
// Simple ping to check server availability
|
||||
methodPing = "ping"
|
||||
|
||||
// Notification that client is fully initialized
|
||||
methodNotificationsInitialized = "notifications/initialized"
|
||||
|
||||
// Notification that a request was canceled
|
||||
methodNotificationsCancelled = "notifications/cancelled"
|
||||
)
|
||||
|
||||
// Event names for Server-Sent Events (SSE)
|
||||
const (
|
||||
// Notification of tool list changes
|
||||
eventToolsListChanged = "tools/list_changed"
|
||||
|
||||
// Notification of prompt list changes
|
||||
eventPromptsListChanged = "prompts/list_changed"
|
||||
|
||||
// Notification of resource list changes
|
||||
eventResourcesListChanged = "resources/list_changed"
|
||||
)
|
||||
|
||||
var (
|
||||
// Default channel size for events
|
||||
eventChanSize = 10
|
||||
|
||||
// Default ping interval for checking connection availability
|
||||
// use syncx.ForAtomicDuration to ensure atomicity in test race
|
||||
pingInterval = syncx.ForAtomicDuration(30 * time.Second)
|
||||
)
|
||||
210
mcp/vars_test.go
210
mcp/vars_test.go
@@ -1,210 +0,0 @@
|
||||
// filepath: /Users/kevin/Develop/go/opensource/go-zero/mcp/vars_test.go
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestErrorCodes ensures error codes are applied correctly in error responses
|
||||
func TestErrorCodes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
code int
|
||||
message string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "invalid request error",
|
||||
code: errCodeInvalidRequest,
|
||||
message: "Invalid request",
|
||||
expected: `"code":-32600`,
|
||||
},
|
||||
{
|
||||
name: "method not found error",
|
||||
code: errCodeMethodNotFound,
|
||||
message: "Method not found",
|
||||
expected: `"code":-32601`,
|
||||
},
|
||||
{
|
||||
name: "invalid params error",
|
||||
code: errCodeInvalidParams,
|
||||
message: "Invalid parameters",
|
||||
expected: `"code":-32602`,
|
||||
},
|
||||
{
|
||||
name: "internal error",
|
||||
code: errCodeInternalError,
|
||||
message: "Internal server error",
|
||||
expected: `"code":-32603`,
|
||||
},
|
||||
{
|
||||
name: "timeout error",
|
||||
code: errCodeTimeout,
|
||||
message: "Operation timed out",
|
||||
expected: `"code":-32001`,
|
||||
},
|
||||
{
|
||||
name: "resource not found error",
|
||||
code: errCodeResourceNotFound,
|
||||
message: "Resource not found",
|
||||
expected: `"code":-32002`,
|
||||
},
|
||||
{
|
||||
name: "client not initialized error",
|
||||
code: errCodeClientNotInitialized,
|
||||
message: "Client not initialized",
|
||||
expected: `"code":-32800`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp := Response{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: int64(1),
|
||||
Error: &errorObj{
|
||||
Code: tc.code,
|
||||
Message: tc.message,
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(resp)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), tc.expected, "Error code should match expected value")
|
||||
assert.Contains(t, string(data), tc.message, "Error message should be included")
|
||||
assert.Contains(t, string(data), jsonRpcVersion, "JSON-RPC version should be included")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJsonRpcVersion ensures the correct JSON-RPC version is used
|
||||
func TestJsonRpcVersion(t *testing.T) {
|
||||
assert.Equal(t, "2.0", jsonRpcVersion, "JSON-RPC version should be 2.0")
|
||||
|
||||
// Test that it's used in responses
|
||||
resp := Response{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: int64(1),
|
||||
Result: "test",
|
||||
}
|
||||
data, err := json.Marshal(resp)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"jsonrpc":"2.0"`, "Response should use correct JSON-RPC version")
|
||||
|
||||
// Test that it's expected in requests
|
||||
reqStr := `{"jsonrpc":"2.0","id":1,"method":"test"}`
|
||||
var req Request
|
||||
err = json.Unmarshal([]byte(reqStr), &req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, jsonRpcVersion, req.JsonRpc, "Request should parse correct JSON-RPC version")
|
||||
}
|
||||
|
||||
// TestSessionIdKey ensures session ID extraction works correctly
|
||||
func TestSessionIdKey(t *testing.T) {
|
||||
// Create a mock server implementation
|
||||
mock := newMockMcpServer(t)
|
||||
defer mock.shutdown()
|
||||
|
||||
// Verify the key constant
|
||||
assert.Equal(t, "session_id", sessionIdKey, "Session ID key should be 'session_id'")
|
||||
|
||||
// Test that session ID is extracted correctly
|
||||
mockR := httptest.NewRequest("GET", "/?"+sessionIdKey+"=test-session", nil)
|
||||
|
||||
// Since the mock server is using the same session key logic,
|
||||
// we can test this by accessing the request query parameters directly
|
||||
sessionID := mockR.URL.Query().Get(sessionIdKey)
|
||||
assert.Equal(t, "test-session", sessionID, "Session ID should be extracted correctly")
|
||||
}
|
||||
|
||||
// TestEventTypes ensures event types are set correctly in SSE responses
|
||||
func TestEventTypes(t *testing.T) {
|
||||
// Test message event
|
||||
assert.Equal(t, "message", eventMessage, "Message event should be 'message'")
|
||||
|
||||
// Test endpoint event
|
||||
assert.Equal(t, "endpoint", eventEndpoint, "Endpoint event should be 'endpoint'")
|
||||
|
||||
// Verify them in an actual SSE format string
|
||||
messageEvent := "event: " + eventMessage + "\ndata: test\n\n"
|
||||
assert.Contains(t, messageEvent, "event: message", "Message event should format correctly")
|
||||
|
||||
endpointEvent := "event: " + eventEndpoint + "\ndata: test\n\n"
|
||||
assert.Contains(t, endpointEvent, "event: endpoint", "Endpoint event should format correctly")
|
||||
}
|
||||
|
||||
// TestCollectionKeys checks that collection keys are used correctly
|
||||
func TestCollectionKeys(t *testing.T) {
|
||||
// Verify collection key constants
|
||||
assert.Equal(t, "prompts", keyPrompts, "Prompts key should be 'prompts'")
|
||||
assert.Equal(t, "resources", keyResources, "Resources key should be 'resources'")
|
||||
assert.Equal(t, "tools", keyTools, "Tools key should be 'tools'")
|
||||
}
|
||||
|
||||
// TestRoleTypes checks that role types are used correctly
|
||||
func TestRoleTypes(t *testing.T) {
|
||||
// Test in annotations
|
||||
annotations := Annotations{
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
}
|
||||
data, err := json.Marshal(annotations)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"audience":["user","assistant"]`, "Role types should marshal correctly")
|
||||
}
|
||||
|
||||
// TestMethodNames checks that method names are used correctly
|
||||
func TestMethodNames(t *testing.T) {
|
||||
// Verify method name constants
|
||||
methods := map[string]string{
|
||||
"initialize": methodInitialize,
|
||||
"tools/list": methodToolsList,
|
||||
"tools/call": methodToolsCall,
|
||||
"prompts/list": methodPromptsList,
|
||||
"prompts/get": methodPromptsGet,
|
||||
"resources/list": methodResourcesList,
|
||||
"resources/read": methodResourcesRead,
|
||||
"resources/subscribe": methodResourcesSubscribe,
|
||||
"ping": methodPing,
|
||||
"notifications/initialized": methodNotificationsInitialized,
|
||||
"notifications/cancelled": methodNotificationsCancelled,
|
||||
}
|
||||
|
||||
for expected, actual := range methods {
|
||||
assert.Equal(t, expected, actual, "Method name should be "+expected)
|
||||
}
|
||||
|
||||
// Test in a request
|
||||
for methodName := range methods {
|
||||
req := Request{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: int64(1),
|
||||
Method: methodName,
|
||||
}
|
||||
data, err := json.Marshal(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"method":"`+methodName+`"`, "Method name should be used in requests")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEventNames checks that event names are used correctly
|
||||
func TestEventNames(t *testing.T) {
|
||||
// Verify event name constants
|
||||
events := map[string]string{
|
||||
"tools/list_changed": eventToolsListChanged,
|
||||
"prompts/list_changed": eventPromptsListChanged,
|
||||
"resources/list_changed": eventResourcesListChanged,
|
||||
}
|
||||
|
||||
for expected, actual := range events {
|
||||
assert.Equal(t, expected, actual, "Event name should be "+expected)
|
||||
}
|
||||
|
||||
// Test event names in SSE format
|
||||
for _, eventName := range events {
|
||||
sseEvent := "event: " + eventName + "\ndata: test\n\n"
|
||||
assert.Contains(t, sseEvent, "event: "+eventName, "Event name should format correctly in SSE")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user