mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-07 06:59:59 +08:00
feat(mcp): migrate to official go-sdk with simplified API (#5362)
This commit is contained in:
3
go.mod
3
go.mod
@@ -14,6 +14,7 @@ require (
|
|||||||
github.com/grafana/pyroscope-go v1.2.7
|
github.com/grafana/pyroscope-go v1.2.7
|
||||||
github.com/jackc/pgx/v5 v5.7.4
|
github.com/jackc/pgx/v5 v5.7.4
|
||||||
github.com/jhump/protoreflect v1.17.0
|
github.com/jhump/protoreflect v1.17.0
|
||||||
|
github.com/modelcontextprotocol/go-sdk v1.2.0
|
||||||
github.com/pelletier/go-toml/v2 v2.2.4
|
github.com/pelletier/go-toml/v2 v2.2.4
|
||||||
github.com/prometheus/client_golang v1.23.2
|
github.com/prometheus/client_golang v1.23.2
|
||||||
github.com/redis/go-redis/v9 v9.17.2
|
github.com/redis/go-redis/v9 v9.17.2
|
||||||
@@ -71,6 +72,7 @@ require (
|
|||||||
github.com/google/gnostic-models v0.6.8 // indirect
|
github.com/google/gnostic-models v0.6.8 // indirect
|
||||||
github.com/google/go-cmp v0.7.0 // indirect
|
github.com/google/go-cmp v0.7.0 // indirect
|
||||||
github.com/google/gofuzz v1.2.0 // indirect
|
github.com/google/gofuzz v1.2.0 // indirect
|
||||||
|
github.com/google/jsonschema-go v0.3.0 // indirect
|
||||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.9 // indirect
|
github.com/grafana/pyroscope-go/godeltaprof v0.1.9 // indirect
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect
|
||||||
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect
|
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect
|
||||||
@@ -98,6 +100,7 @@ require (
|
|||||||
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
|
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
|
||||||
github.com/xdg-go/scram v1.1.2 // indirect
|
github.com/xdg-go/scram v1.1.2 // indirect
|
||||||
github.com/xdg-go/stringprep v1.0.4 // indirect
|
github.com/xdg-go/stringprep v1.0.4 // indirect
|
||||||
|
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
|
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
|
||||||
github.com/yuin/gopher-lua v1.1.1 // indirect
|
github.com/yuin/gopher-lua v1.1.1 // indirect
|
||||||
go.etcd.io/etcd/client/pkg/v3 v3.5.15 // indirect
|
go.etcd.io/etcd/client/pkg/v3 v3.5.15 // indirect
|
||||||
|
|||||||
8
go.sum
8
go.sum
@@ -62,6 +62,8 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
|||||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
|
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||||
github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs=
|
github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs=
|
||||||
@@ -74,6 +76,8 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX
|
|||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
|
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
|
||||||
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
|
github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q=
|
||||||
|
github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||||
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJYCmNdQXq6neHJOYx3V6jnqNEec=
|
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJYCmNdQXq6neHJOYx3V6jnqNEec=
|
||||||
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
@@ -123,6 +127,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE
|
|||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
||||||
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||||
|
github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s=
|
||||||
|
github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
@@ -182,6 +188,8 @@ github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY=
|
|||||||
github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4=
|
github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4=
|
||||||
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
|
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
|
||||||
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
|
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
|
||||||
|
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||||
|
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
|
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
|
||||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
|
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
|
||||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
|
|||||||
166
mcp/MIGRATION.md
Normal file
166
mcp/MIGRATION.md
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
# Migration to Official MCP SDK
|
||||||
|
|
||||||
|
This document describes the migration from the custom MCP implementation to the official [go-sdk](https://github.com/modelcontextprotocol/go-sdk).
|
||||||
|
|
||||||
|
## Changes
|
||||||
|
|
||||||
|
### Dependencies
|
||||||
|
|
||||||
|
Added the official MCP SDK:
|
||||||
|
```bash
|
||||||
|
go get github.com/modelcontextprotocol/go-sdk@v1.2.0
|
||||||
|
```
|
||||||
|
|
||||||
|
### Type System
|
||||||
|
|
||||||
|
All types are now re-exported from the official SDK:
|
||||||
|
- `Tool` → `sdkmcp.Tool`
|
||||||
|
- `CallToolRequest` → `sdkmcp.CallToolRequest`
|
||||||
|
- `CallToolResult` → `sdkmcp.CallToolResult`
|
||||||
|
- Content types (`TextContent`, `ImageContent`, etc.)
|
||||||
|
- `Prompt`, `Resource`, `Server`, `ServerSession`
|
||||||
|
|
||||||
|
### Server Interface
|
||||||
|
|
||||||
|
The `McpServer` interface has been simplified:
|
||||||
|
|
||||||
|
```go
|
||||||
|
type McpServer interface {
|
||||||
|
Start()
|
||||||
|
Stop()
|
||||||
|
Server() *sdkmcp.Server // Returns underlying SDK server
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Important**: The `AddTool`, `AddPrompt`, and `AddResource` methods have been removed. Use the SDK directly:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Old (no longer supported)
|
||||||
|
server.AddTool(tool, handler)
|
||||||
|
|
||||||
|
// New (use SDK directly)
|
||||||
|
sdkmcp.AddTool(server.Server(), tool, handler)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Configuration
|
||||||
|
|
||||||
|
Updated configuration structure:
|
||||||
|
- Removed: `ProtocolVersion`, `BaseUrl` (SDK manages these)
|
||||||
|
- Added: `UseStreamable` (choose between SSE and Streamable HTTP transport)
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
mcp:
|
||||||
|
name: my-server
|
||||||
|
version: 1.0.0
|
||||||
|
useStreamable: false # false = SSE (2024-11-05), true = Streamable HTTP (2025-03-26)
|
||||||
|
sseEndpoint: /sse
|
||||||
|
messageEndpoint: /message
|
||||||
|
sseTimeout: 24h
|
||||||
|
messageTimeout: 30s
|
||||||
|
cors:
|
||||||
|
- http://localhost:3000
|
||||||
|
```
|
||||||
|
|
||||||
|
### Tool Registration
|
||||||
|
|
||||||
|
The SDK uses Go generics for type-safe tool registration:
|
||||||
|
|
||||||
|
```go
|
||||||
|
import sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
|
|
||||||
|
type MyArgs struct {
|
||||||
|
Value string `json:"value" jsonschema:"description=Input value"`
|
||||||
|
}
|
||||||
|
|
||||||
|
tool := &mcp.Tool{
|
||||||
|
Name: "my_tool",
|
||||||
|
Description: "Description",
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := func(ctx context.Context, req *mcp.CallToolRequest, args MyArgs) (*mcp.CallToolResult, any, error) {
|
||||||
|
return &mcp.CallToolResult{
|
||||||
|
Content: []mcp.Content{
|
||||||
|
&mcp.TextContent{Text: "Result"},
|
||||||
|
},
|
||||||
|
}, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register with explicit type parameters
|
||||||
|
sdkmcp.AddTool(server.Server(), tool, handler)
|
||||||
|
```
|
||||||
|
|
||||||
|
The SDK automatically generates JSON schemas from struct tags.
|
||||||
|
|
||||||
|
### Transport Support
|
||||||
|
|
||||||
|
Two transports are supported:
|
||||||
|
|
||||||
|
1. **SSE (Server-Sent Events)**: 2024-11-05 MCP spec
|
||||||
|
- Default (`UseStreamable: false`)
|
||||||
|
- Endpoint: `/sse` (configurable)
|
||||||
|
- Bidirectional: client sends messages to `/message`
|
||||||
|
|
||||||
|
2. **Streamable HTTP**: 2025-03-26 MCP spec
|
||||||
|
- Opt-in (`UseStreamable: true`)
|
||||||
|
- Endpoint: `/sse` (configurable)
|
||||||
|
- Newer protocol with improved streaming
|
||||||
|
|
||||||
|
### Example Migration
|
||||||
|
|
||||||
|
**Before:**
|
||||||
|
```go
|
||||||
|
server := mcp.NewMcpServer(c)
|
||||||
|
|
||||||
|
tool := &mcp.Tool{Name: "greet", Description: "Greet"}
|
||||||
|
handler := func(ctx context.Context, req *mcp.CallToolRequest, args GreetArgs) (*mcp.CallToolResult, any, error) {
|
||||||
|
return &mcp.CallToolResult{
|
||||||
|
Content: []mcp.Content{&mcp.TextContent{Text: "Hello"}},
|
||||||
|
}, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := server.AddTool(tool, handler); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**After:**
|
||||||
|
```go
|
||||||
|
import sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
|
|
||||||
|
server := mcp.NewMcpServer(c)
|
||||||
|
|
||||||
|
tool := &mcp.Tool{Name: "greet", Description: "Greet"}
|
||||||
|
handler := func(ctx context.Context, req *mcp.CallToolRequest, args GreetArgs) (*mcp.CallToolResult, any, error) {
|
||||||
|
return &mcp.CallToolResult{
|
||||||
|
Content: []mcp.Content{&mcp.TextContent{Text: "Hello"}},
|
||||||
|
}, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use SDK directly - no error return
|
||||||
|
sdkmcp.AddTool(server.Server(), tool, handler)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Benefits
|
||||||
|
|
||||||
|
1. **Official SDK**: Uses the official Model Context Protocol SDK
|
||||||
|
2. **Type Safety**: Go generics provide compile-time type checking
|
||||||
|
3. **Auto Schema**: JSON schemas generated automatically from struct tags
|
||||||
|
4. **Dual Transport**: Supports both SSE and Streamable HTTP transports
|
||||||
|
5. **Maintained**: SDK is actively maintained by the MCP team
|
||||||
|
|
||||||
|
## Breaking Changes
|
||||||
|
|
||||||
|
1. `server.AddTool()` removed → use `sdkmcp.AddTool(server.Server(), ...)`
|
||||||
|
2. `server.AddPrompt()` removed (SDK v1.2.0 limitation)
|
||||||
|
3. `server.AddResource()` removed (SDK v1.2.0 limitation)
|
||||||
|
4. Config fields `ProtocolVersion` and `BaseUrl` removed
|
||||||
|
5. All types now come from SDK (re-exported for convenience)
|
||||||
|
|
||||||
|
## Migration Checklist
|
||||||
|
|
||||||
|
- [ ] Update imports: add `sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"`
|
||||||
|
- [ ] Replace `server.AddTool()` with `sdkmcp.AddTool(server.Server(), ...)`
|
||||||
|
- [ ] Remove error handling for tool registration (SDK doesn't return errors)
|
||||||
|
- [ ] Update config: remove `ProtocolVersion` and `BaseUrl`, add `UseStreamable`
|
||||||
|
- [ ] Test with both SSE and Streamable transports
|
||||||
|
- [ ] Update documentation/examples
|
||||||
@@ -18,17 +18,16 @@ type McpConf struct {
|
|||||||
// Version is the server version reported in initialize responses
|
// Version is the server version reported in initialize responses
|
||||||
Version string `json:",default=1.0.0"`
|
Version string `json:",default=1.0.0"`
|
||||||
|
|
||||||
// ProtocolVersion is the MCP protocol version implemented
|
// UseStreamable when true uses Streamable HTTP transport (2025-03-26 spec),
|
||||||
ProtocolVersion string `json:",default=2024-11-05"`
|
// otherwise uses SSE transport (2024-11-05 spec)
|
||||||
|
UseStreamable bool `json:",default=false"`
|
||||||
// BaseUrl is the base URL for the server, used in SSE endpoint messages
|
|
||||||
// If not set, defaults to http://localhost:{Port}
|
|
||||||
BaseUrl string `json:",optional"`
|
|
||||||
|
|
||||||
// SseEndpoint is the path for Server-Sent Events connections
|
// SseEndpoint is the path for Server-Sent Events connections
|
||||||
|
// Used for SSE transport mode
|
||||||
SseEndpoint string `json:",default=/sse"`
|
SseEndpoint string `json:",default=/sse"`
|
||||||
|
|
||||||
// MessageEndpoint is the path for JSON-RPC requests
|
// MessageEndpoint is the path for JSON-RPC requests
|
||||||
|
// Used for Streamable HTTP transport mode
|
||||||
MessageEndpoint string `json:",default=/message"`
|
MessageEndpoint string `json:",default=/message"`
|
||||||
|
|
||||||
// Cors contains allowed CORS origins
|
// Cors contains allowed CORS origins
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestMcpConfDefaults(t *testing.T) {
|
func TestMcpConfDefaults(t *testing.T) {
|
||||||
// Test default values are set correctly when unmarshalled from JSON
|
// Test default values are set correctly
|
||||||
jsonConfig := `name: test-service
|
jsonConfig := `name: test-service
|
||||||
port: 8080
|
port: 8080
|
||||||
mcp:
|
mcp:
|
||||||
@@ -23,41 +23,8 @@ mcp:
|
|||||||
|
|
||||||
// Check default values
|
// Check default values
|
||||||
assert.Equal(t, "test-mcp-server", c.Mcp.Name)
|
assert.Equal(t, "test-mcp-server", c.Mcp.Name)
|
||||||
assert.Equal(t, "1.0.0", c.Mcp.Version, "Default version should be 1.0.0")
|
assert.Equal(t, "1.0.0", c.Mcp.Version)
|
||||||
assert.Equal(t, "2024-11-05", c.Mcp.ProtocolVersion, "Default protocol version should be 2024-11-05")
|
assert.Equal(t, "/sse", c.Mcp.SseEndpoint)
|
||||||
assert.Equal(t, "/sse", c.Mcp.SseEndpoint, "Default SSE endpoint should be /sse")
|
assert.Equal(t, "/message", c.Mcp.MessageEndpoint)
|
||||||
assert.Equal(t, "/message", c.Mcp.MessageEndpoint, "Default message endpoint should be /message")
|
assert.Equal(t, 30*time.Second, c.Mcp.MessageTimeout)
|
||||||
assert.Equal(t, 30*time.Second, c.Mcp.MessageTimeout, "Default message timeout should be 30s")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMcpConfCustomValues(t *testing.T) {
|
|
||||||
// Test custom values can be set
|
|
||||||
jsonConfig := `{
|
|
||||||
"Name": "test-service",
|
|
||||||
"Port": 8080,
|
|
||||||
"Mcp": {
|
|
||||||
"Name": "test-mcp-server",
|
|
||||||
"Version": "2.0.0",
|
|
||||||
"ProtocolVersion": "2025-01-01",
|
|
||||||
"BaseUrl": "http://example.com",
|
|
||||||
"SseEndpoint": "/custom-sse",
|
|
||||||
"MessageEndpoint": "/custom-message",
|
|
||||||
"Cors": ["http://localhost:3000", "http://example.com"],
|
|
||||||
"MessageTimeout": "60s"
|
|
||||||
}
|
|
||||||
}`
|
|
||||||
|
|
||||||
var c McpConf
|
|
||||||
err := conf.LoadFromJsonBytes([]byte(jsonConfig), &c)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
// Check custom values
|
|
||||||
assert.Equal(t, "test-mcp-server", c.Mcp.Name, "Name should be inherited from RestConf")
|
|
||||||
assert.Equal(t, "2.0.0", c.Mcp.Version, "Version should be customizable")
|
|
||||||
assert.Equal(t, "2025-01-01", c.Mcp.ProtocolVersion, "Protocol version should be customizable")
|
|
||||||
assert.Equal(t, "http://example.com", c.Mcp.BaseUrl, "BaseUrl should be customizable")
|
|
||||||
assert.Equal(t, "/custom-sse", c.Mcp.SseEndpoint, "SSE endpoint should be customizable")
|
|
||||||
assert.Equal(t, "/custom-message", c.Mcp.MessageEndpoint, "Message endpoint should be customizable")
|
|
||||||
assert.Equal(t, []string{"http://localhost:3000", "http://example.com"}, c.Mcp.Cors, "CORS settings should be customizable")
|
|
||||||
assert.Equal(t, 60*time.Second, c.Mcp.MessageTimeout, "Tool timeout should be customizable")
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,443 +0,0 @@
|
|||||||
package mcp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
// syncResponseRecorder is a thread-safe wrapper around httptest.ResponseRecorder
|
|
||||||
type syncResponseRecorder struct {
|
|
||||||
*httptest.ResponseRecorder
|
|
||||||
mu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new synchronized response recorder
|
|
||||||
func newSyncResponseRecorder() *syncResponseRecorder {
|
|
||||||
return &syncResponseRecorder{
|
|
||||||
ResponseRecorder: httptest.NewRecorder(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Override Write method to synchronize access
|
|
||||||
func (srr *syncResponseRecorder) Write(p []byte) (int, error) {
|
|
||||||
srr.mu.Lock()
|
|
||||||
defer srr.mu.Unlock()
|
|
||||||
return srr.ResponseRecorder.Write(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Override WriteHeader method to synchronize access
|
|
||||||
func (srr *syncResponseRecorder) WriteHeader(statusCode int) {
|
|
||||||
srr.mu.Lock()
|
|
||||||
defer srr.mu.Unlock()
|
|
||||||
srr.ResponseRecorder.WriteHeader(statusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Override Result method to synchronize access
|
|
||||||
func (srr *syncResponseRecorder) Result() *http.Response {
|
|
||||||
srr.mu.Lock()
|
|
||||||
defer srr.mu.Unlock()
|
|
||||||
return srr.ResponseRecorder.Result()
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestHTTPHandlerIntegration tests the HTTP handlers with a real server instance
|
|
||||||
func TestHTTPHandlerIntegration(t *testing.T) {
|
|
||||||
// Skip in short test mode
|
|
||||||
if testing.Short() {
|
|
||||||
t.Skip("Skipping integration test in short mode")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a test configuration
|
|
||||||
conf := McpConf{}
|
|
||||||
conf.Mcp.Name = "test-integration"
|
|
||||||
conf.Mcp.Version = "1.0.0-test"
|
|
||||||
conf.Mcp.MessageTimeout = 1 * time.Second
|
|
||||||
|
|
||||||
// Create a mock server directly
|
|
||||||
server := &sseMcpServer{
|
|
||||||
conf: conf,
|
|
||||||
clients: make(map[string]*mcpClient),
|
|
||||||
tools: make(map[string]Tool),
|
|
||||||
prompts: make(map[string]Prompt),
|
|
||||||
resources: make(map[string]Resource),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register a test tool
|
|
||||||
err := server.RegisterTool(Tool{
|
|
||||||
Name: "echo",
|
|
||||||
Description: "Echo tool for testing",
|
|
||||||
InputSchema: InputSchema{
|
|
||||||
Properties: map[string]any{
|
|
||||||
"message": map[string]any{
|
|
||||||
"type": "string",
|
|
||||||
"description": "Message to echo",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
|
||||||
if msg, ok := params["message"].(string); ok {
|
|
||||||
return fmt.Sprintf("Echo: %s", msg), nil
|
|
||||||
}
|
|
||||||
return "Echo: no message provided", nil
|
|
||||||
},
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Create a test HTTP request to the SSE endpoint
|
|
||||||
req := httptest.NewRequest("GET", "/sse", nil)
|
|
||||||
w := newSyncResponseRecorder()
|
|
||||||
|
|
||||||
// Create a done channel to signal completion of test
|
|
||||||
done := make(chan bool)
|
|
||||||
|
|
||||||
// Start the SSE handler in a goroutine
|
|
||||||
go func() {
|
|
||||||
// lock.Lock()
|
|
||||||
server.handleSSE(w, req)
|
|
||||||
// lock.Unlock()
|
|
||||||
done <- true
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Allow time for the handler to process
|
|
||||||
select {
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
// Expected - handler would normally block indefinitely
|
|
||||||
case <-done:
|
|
||||||
// This shouldn't happen immediately - the handler should block
|
|
||||||
t.Error("SSE handler returned unexpectedly")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check the initial headers
|
|
||||||
resp := w.Result()
|
|
||||||
assert.Equal(t, "chunked", resp.Header.Get("Transfer-Encoding"))
|
|
||||||
resp.Body.Close()
|
|
||||||
|
|
||||||
// The handler creates a client and sends the endpoint message
|
|
||||||
var sessionId string
|
|
||||||
|
|
||||||
// Give the handler time to set up the client
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
|
|
||||||
// Check that a client was created
|
|
||||||
server.clientsLock.Lock()
|
|
||||||
assert.Equal(t, 1, len(server.clients))
|
|
||||||
for id := range server.clients {
|
|
||||||
sessionId = id
|
|
||||||
}
|
|
||||||
server.clientsLock.Unlock()
|
|
||||||
|
|
||||||
require.NotEmpty(t, sessionId, "Expected a session ID to be created")
|
|
||||||
|
|
||||||
// Now that we have a session ID, we can test the message endpoint
|
|
||||||
messageBody, _ := json.Marshal(Request{
|
|
||||||
JsonRpc: "2.0",
|
|
||||||
ID: 1,
|
|
||||||
Method: methodInitialize,
|
|
||||||
Params: json.RawMessage(`{}`),
|
|
||||||
})
|
|
||||||
|
|
||||||
// Create a message request
|
|
||||||
reqURL := fmt.Sprintf("/message?%s=%s", sessionIdKey, sessionId)
|
|
||||||
msgReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(messageBody))
|
|
||||||
msgW := newSyncResponseRecorder()
|
|
||||||
|
|
||||||
// Process the message
|
|
||||||
server.handleRequest(msgW, msgReq)
|
|
||||||
|
|
||||||
// Check the response
|
|
||||||
msgResp := msgW.Result()
|
|
||||||
assert.Equal(t, http.StatusAccepted, msgResp.StatusCode)
|
|
||||||
msgResp.Body.Close() // Ensure response body is closed
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestHandlerResponseFlow tests the flow of a full request/response cycle
|
|
||||||
func TestHandlerResponseFlow(t *testing.T) {
|
|
||||||
// Create a mock server for testing
|
|
||||||
server := &sseMcpServer{
|
|
||||||
conf: McpConf{},
|
|
||||||
clients: map[string]*mcpClient{
|
|
||||||
"test-session": {
|
|
||||||
id: "test-session",
|
|
||||||
channel: make(chan string, 10),
|
|
||||||
initialized: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
tools: make(map[string]Tool),
|
|
||||||
prompts: make(map[string]Prompt),
|
|
||||||
resources: make(map[string]Resource),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register test resources
|
|
||||||
server.RegisterTool(Tool{
|
|
||||||
Name: "test.tool",
|
|
||||||
Description: "Test tool",
|
|
||||||
InputSchema: InputSchema{Type: "object"},
|
|
||||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
|
||||||
return "tool result", nil
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
server.RegisterPrompt(Prompt{
|
|
||||||
Name: "test.prompt",
|
|
||||||
Description: "Test prompt",
|
|
||||||
})
|
|
||||||
|
|
||||||
server.RegisterResource(Resource{
|
|
||||||
Name: "test.resource",
|
|
||||||
URI: "http://example.com",
|
|
||||||
Description: "Test resource",
|
|
||||||
})
|
|
||||||
|
|
||||||
// Create a request with session ID parameter
|
|
||||||
reqURL := fmt.Sprintf("/message?%s=%s", sessionIdKey, "test-session")
|
|
||||||
|
|
||||||
// Test tools/list request
|
|
||||||
toolsListBody, _ := json.Marshal(Request{
|
|
||||||
JsonRpc: "2.0",
|
|
||||||
ID: 1,
|
|
||||||
Method: methodToolsList,
|
|
||||||
Params: json.RawMessage(`{}`),
|
|
||||||
})
|
|
||||||
|
|
||||||
toolsReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(toolsListBody))
|
|
||||||
toolsW := newSyncResponseRecorder()
|
|
||||||
|
|
||||||
// Process the request
|
|
||||||
server.handleRequest(toolsW, toolsReq)
|
|
||||||
|
|
||||||
// Check the response code
|
|
||||||
toolsResp := toolsW.Result()
|
|
||||||
assert.Equal(t, http.StatusAccepted, toolsResp.StatusCode)
|
|
||||||
toolsResp.Body.Close()
|
|
||||||
|
|
||||||
// Check the channel message
|
|
||||||
client := server.clients["test-session"]
|
|
||||||
select {
|
|
||||||
case message := <-client.channel:
|
|
||||||
assert.Contains(t, message, `"tools":[{"name":"test.tool"`)
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Fatal("Timed out waiting for tools/list response")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test prompts/list request
|
|
||||||
promptsListBody, _ := json.Marshal(Request{
|
|
||||||
JsonRpc: "2.0",
|
|
||||||
ID: 2,
|
|
||||||
Method: methodPromptsList,
|
|
||||||
Params: json.RawMessage(`{}`),
|
|
||||||
})
|
|
||||||
|
|
||||||
promptsReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(promptsListBody))
|
|
||||||
promptsW := newSyncResponseRecorder()
|
|
||||||
|
|
||||||
// Process the request
|
|
||||||
server.handleRequest(promptsW, promptsReq)
|
|
||||||
|
|
||||||
// Check the response code
|
|
||||||
promptsResp := promptsW.Result()
|
|
||||||
assert.Equal(t, http.StatusAccepted, promptsResp.StatusCode)
|
|
||||||
promptsResp.Body.Close()
|
|
||||||
|
|
||||||
// Check the channel message
|
|
||||||
select {
|
|
||||||
case message := <-client.channel:
|
|
||||||
assert.Contains(t, message, `"prompts":[{"name":"test.prompt"`)
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Fatal("Timed out waiting for prompts/list response")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test resources/list request
|
|
||||||
resourcesListBody, _ := json.Marshal(Request{
|
|
||||||
JsonRpc: "2.0",
|
|
||||||
ID: 3,
|
|
||||||
Method: methodResourcesList,
|
|
||||||
Params: json.RawMessage(`{}`),
|
|
||||||
})
|
|
||||||
|
|
||||||
resourcesReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(resourcesListBody))
|
|
||||||
resourcesW := newSyncResponseRecorder()
|
|
||||||
|
|
||||||
// Process the request
|
|
||||||
server.handleRequest(resourcesW, resourcesReq)
|
|
||||||
|
|
||||||
// Check the response code
|
|
||||||
resourcesResp := resourcesW.Result()
|
|
||||||
assert.Equal(t, http.StatusAccepted, resourcesResp.StatusCode)
|
|
||||||
resourcesResp.Body.Close()
|
|
||||||
|
|
||||||
// Check the channel message
|
|
||||||
select {
|
|
||||||
case message := <-client.channel:
|
|
||||||
assert.Contains(t, message, `"name":"test.resource"`)
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Fatal("Timed out waiting for resources/list response")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestProcessListMethods tests the list processing methods with pagination
|
|
||||||
func TestProcessListMethods(t *testing.T) {
|
|
||||||
server := &sseMcpServer{
|
|
||||||
tools: make(map[string]Tool),
|
|
||||||
prompts: make(map[string]Prompt),
|
|
||||||
resources: make(map[string]Resource),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add some test data
|
|
||||||
for i := 1; i <= 5; i++ {
|
|
||||||
tool := Tool{
|
|
||||||
Name: fmt.Sprintf("tool%d", i),
|
|
||||||
Description: fmt.Sprintf("Tool %d", i),
|
|
||||||
InputSchema: InputSchema{Type: "object"},
|
|
||||||
}
|
|
||||||
server.tools[tool.Name] = tool
|
|
||||||
|
|
||||||
prompt := Prompt{
|
|
||||||
Name: fmt.Sprintf("prompt%d", i),
|
|
||||||
Description: fmt.Sprintf("Prompt %d", i),
|
|
||||||
}
|
|
||||||
server.prompts[prompt.Name] = prompt
|
|
||||||
|
|
||||||
resource := Resource{
|
|
||||||
Name: fmt.Sprintf("resource%d", i),
|
|
||||||
URI: fmt.Sprintf("http://example.com/%d", i),
|
|
||||||
Description: fmt.Sprintf("Resource %d", i),
|
|
||||||
}
|
|
||||||
server.resources[resource.Name] = resource
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a test client
|
|
||||||
client := &mcpClient{
|
|
||||||
id: "test-client",
|
|
||||||
channel: make(chan string, 10),
|
|
||||||
initialized: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test processListTools
|
|
||||||
req := Request{
|
|
||||||
JsonRpc: "2.0",
|
|
||||||
ID: 1,
|
|
||||||
Method: methodToolsList,
|
|
||||||
Params: json.RawMessage(`{"cursor": "", "_meta": {"progressToken": "token1"}}`),
|
|
||||||
}
|
|
||||||
|
|
||||||
server.processListTools(context.Background(), client, req)
|
|
||||||
|
|
||||||
// Read response
|
|
||||||
select {
|
|
||||||
case response := <-client.channel:
|
|
||||||
assert.Contains(t, response, `"tools":`)
|
|
||||||
assert.Contains(t, response, `"progressToken":"token1"`)
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Fatal("Timed out waiting for tools/list response")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test processListPrompts
|
|
||||||
req.ID = 2
|
|
||||||
req.Method = methodPromptsList
|
|
||||||
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
|
||||||
server.processListPrompts(context.Background(), client, req)
|
|
||||||
|
|
||||||
// Read response
|
|
||||||
select {
|
|
||||||
case response := <-client.channel:
|
|
||||||
assert.Contains(t, response, `"prompts":`)
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Fatal("Timed out waiting for prompts/list response")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test processListResources
|
|
||||||
req.ID = 3
|
|
||||||
req.Method = methodResourcesList
|
|
||||||
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
|
||||||
server.processListResources(context.Background(), client, req)
|
|
||||||
|
|
||||||
// Read response
|
|
||||||
select {
|
|
||||||
case response := <-client.channel:
|
|
||||||
assert.Contains(t, response, `"resources":`)
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Fatal("Timed out waiting for resources/list response")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestErrorResponseHandling tests error handling in the server
|
|
||||||
func TestErrorResponseHandling(t *testing.T) {
|
|
||||||
server := &sseMcpServer{
|
|
||||||
tools: make(map[string]Tool),
|
|
||||||
prompts: make(map[string]Prompt),
|
|
||||||
resources: make(map[string]Resource),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a test client
|
|
||||||
client := &mcpClient{
|
|
||||||
id: "test-client",
|
|
||||||
channel: make(chan string, 10),
|
|
||||||
initialized: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test invalid method
|
|
||||||
req := Request{
|
|
||||||
JsonRpc: "2.0",
|
|
||||||
ID: 1,
|
|
||||||
Method: "invalid_method",
|
|
||||||
Params: json.RawMessage(`{}`),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mock handleRequest by directly calling error handler
|
|
||||||
server.sendErrorResponse(context.Background(), client, req.ID, "Method not found", errCodeMethodNotFound)
|
|
||||||
|
|
||||||
// Check response
|
|
||||||
select {
|
|
||||||
case response := <-client.channel:
|
|
||||||
assert.Contains(t, response, `"error":{"code":-32601,"message":"Method not found"}`)
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Fatal("Timed out waiting for error response")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test invalid tool
|
|
||||||
toolReq := Request{
|
|
||||||
JsonRpc: "2.0",
|
|
||||||
ID: 2,
|
|
||||||
Method: methodToolsCall,
|
|
||||||
Params: json.RawMessage(`{"name":"non_existent_tool"}`),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Call process method directly
|
|
||||||
server.processToolCall(context.Background(), client, toolReq)
|
|
||||||
|
|
||||||
// Check response
|
|
||||||
select {
|
|
||||||
case response := <-client.channel:
|
|
||||||
assert.Contains(t, response, `"error":{"code":-32602,"message":"Tool 'non_existent_tool' not found"}`)
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Fatal("Timed out waiting for error response")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test invalid prompt
|
|
||||||
promptReq := Request{
|
|
||||||
JsonRpc: "2.0",
|
|
||||||
ID: 3,
|
|
||||||
Method: methodPromptsGet,
|
|
||||||
Params: json.RawMessage(`{"name":"non_existent_prompt"}`),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Call process method directly
|
|
||||||
server.processGetPrompt(context.Background(), client, promptReq)
|
|
||||||
|
|
||||||
// Check response
|
|
||||||
select {
|
|
||||||
case response := <-client.channel:
|
|
||||||
assert.Contains(t, response, `"error":{"code":-32602,"message":"Prompt 'non_existent_prompt' not found"}`)
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Fatal("Timed out waiting for error response")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
package mcp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/mapping"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ParseArguments parses the arguments and populates the request object
|
|
||||||
func ParseArguments(args any, req any) error {
|
|
||||||
switch arguments := args.(type) {
|
|
||||||
case map[string]string:
|
|
||||||
m := make(map[string]any, len(arguments))
|
|
||||||
for k, v := range arguments {
|
|
||||||
m[k] = v
|
|
||||||
}
|
|
||||||
return mapping.UnmarshalJsonMap(m, req, mapping.WithStringValues())
|
|
||||||
case map[string]any:
|
|
||||||
return mapping.UnmarshalJsonMap(arguments, req)
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unsupported argument type: %T", arguments)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,139 +0,0 @@
|
|||||||
package mcp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestParseArguments_MapStringString tests parsing map[string]string arguments
|
|
||||||
func TestParseArguments_MapStringString(t *testing.T) {
|
|
||||||
// Sample request struct to populate
|
|
||||||
type requestStruct struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
Count int `json:"count"`
|
|
||||||
Enabled bool `json:"enabled"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create test arguments
|
|
||||||
args := map[string]string{
|
|
||||||
"name": "test-name",
|
|
||||||
"message": "hello world",
|
|
||||||
"count": "42",
|
|
||||||
"enabled": "true",
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a target object to populate
|
|
||||||
var req requestStruct
|
|
||||||
|
|
||||||
// Parse the arguments
|
|
||||||
err := ParseArguments(args, &req)
|
|
||||||
|
|
||||||
// Verify results
|
|
||||||
assert.NoError(t, err, "Should parse map[string]string without error")
|
|
||||||
assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed")
|
|
||||||
assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed")
|
|
||||||
assert.Equal(t, 42, req.Count, "Count should be correctly parsed to int")
|
|
||||||
assert.True(t, req.Enabled, "Enabled should be correctly parsed to bool")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestParseArguments_MapStringAny tests parsing map[string]any arguments
|
|
||||||
func TestParseArguments_MapStringAny(t *testing.T) {
|
|
||||||
// Sample request struct to populate
|
|
||||||
type requestStruct struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
Count int `json:"count"`
|
|
||||||
Enabled bool `json:"enabled"`
|
|
||||||
Tags []string `json:"tags"`
|
|
||||||
Metadata map[string]string `json:"metadata"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create test arguments with mixed types
|
|
||||||
args := map[string]any{
|
|
||||||
"name": "test-name",
|
|
||||||
"message": "hello world",
|
|
||||||
"count": 42, // note: this is already an int
|
|
||||||
"enabled": true, // note: this is already a bool
|
|
||||||
"tags": []string{"tag1", "tag2"},
|
|
||||||
"metadata": map[string]string{
|
|
||||||
"key1": "value1",
|
|
||||||
"key2": "value2",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a target object to populate
|
|
||||||
var req requestStruct
|
|
||||||
|
|
||||||
// Parse the arguments
|
|
||||||
err := ParseArguments(args, &req)
|
|
||||||
|
|
||||||
// Verify results
|
|
||||||
assert.NoError(t, err, "Should parse map[string]any without error")
|
|
||||||
assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed")
|
|
||||||
assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed")
|
|
||||||
assert.Equal(t, 42, req.Count, "Count should be correctly parsed")
|
|
||||||
assert.True(t, req.Enabled, "Enabled should be correctly parsed")
|
|
||||||
assert.Equal(t, []string{"tag1", "tag2"}, req.Tags, "Tags should be correctly parsed")
|
|
||||||
assert.Equal(t, map[string]string{
|
|
||||||
"key1": "value1",
|
|
||||||
"key2": "value2",
|
|
||||||
}, req.Metadata, "Metadata should be correctly parsed")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestParseArguments_UnsupportedType tests parsing with an unsupported type
|
|
||||||
func TestParseArguments_UnsupportedType(t *testing.T) {
|
|
||||||
// Sample request struct to populate
|
|
||||||
type requestStruct struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use an unsupported argument type (slice)
|
|
||||||
args := []string{"not", "a", "map"}
|
|
||||||
|
|
||||||
// Create a target object to populate
|
|
||||||
var req requestStruct
|
|
||||||
|
|
||||||
// Parse the arguments
|
|
||||||
err := ParseArguments(args, &req)
|
|
||||||
|
|
||||||
// Verify error is returned with correct message
|
|
||||||
assert.Error(t, err, "Should return error for unsupported type")
|
|
||||||
assert.Contains(t, err.Error(), "unsupported argument type", "Error should mention unsupported type")
|
|
||||||
assert.Contains(t, err.Error(), "[]string", "Error should include the actual type")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestParseArguments_EmptyMap tests parsing with empty maps
|
|
||||||
func TestParseArguments_EmptyMap(t *testing.T) {
|
|
||||||
// Sample request struct to populate
|
|
||||||
type requestStruct struct {
|
|
||||||
Name string `json:"name,optional"`
|
|
||||||
Message string `json:"message,optional"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test empty map[string]string
|
|
||||||
t.Run("EmptyMapStringString", func(t *testing.T) {
|
|
||||||
args := map[string]string{}
|
|
||||||
var req requestStruct
|
|
||||||
|
|
||||||
err := ParseArguments(args, &req)
|
|
||||||
|
|
||||||
assert.NoError(t, err, "Should parse empty map[string]string without error")
|
|
||||||
assert.Empty(t, req.Name, "Name should be empty string")
|
|
||||||
assert.Empty(t, req.Message, "Message should be empty string")
|
|
||||||
})
|
|
||||||
|
|
||||||
// Test empty map[string]any
|
|
||||||
t.Run("EmptyMapStringAny", func(t *testing.T) {
|
|
||||||
args := map[string]any{}
|
|
||||||
var req requestStruct
|
|
||||||
|
|
||||||
err := ParseArguments(args, &req)
|
|
||||||
|
|
||||||
assert.NoError(t, err, "Should parse empty map[string]any without error")
|
|
||||||
assert.Empty(t, req.Name, "Name should be empty string")
|
|
||||||
assert.Empty(t, req.Message, "Message should be empty string")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
1012
mcp/readme.md
1012
mcp/readme.md
File diff suppressed because it is too large
Load Diff
990
mcp/server.go
990
mcp/server.go
File diff suppressed because it is too large
Load Diff
3691
mcp/server_test.go
3691
mcp/server_test.go
File diff suppressed because it is too large
Load Diff
392
mcp/types.go
392
mcp/types.go
@@ -2,316 +2,96 @@ package mcp
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/rest"
|
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Cursor is an opaque token used for pagination
|
// Re-export commonly used SDK types for convenience
|
||||||
type Cursor string
|
type (
|
||||||
|
// Tool types
|
||||||
|
Tool = sdkmcp.Tool
|
||||||
|
CallToolParams = sdkmcp.CallToolParams
|
||||||
|
CallToolResult = sdkmcp.CallToolResult
|
||||||
|
CallToolRequest = sdkmcp.CallToolRequest
|
||||||
|
|
||||||
// Request represents a generic MCP request following JSON-RPC 2.0 specification
|
// Content types
|
||||||
type Request struct {
|
Content = sdkmcp.Content
|
||||||
SessionId string `form:"session_id"` // Session identifier for client tracking
|
TextContent = sdkmcp.TextContent
|
||||||
JsonRpc string `json:"jsonrpc"` // Must be "2.0" per JSON-RPC spec
|
ImageContent = sdkmcp.ImageContent
|
||||||
ID any `json:"id"` // Request identifier for matching responses
|
AudioContent = sdkmcp.AudioContent
|
||||||
Method string `json:"method"` // Method name to invoke
|
|
||||||
Params json.RawMessage `json:"params"` // Parameters for the method
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r Request) isNotification() (bool, error) {
|
// Prompt types
|
||||||
switch val := r.ID.(type) {
|
Prompt = sdkmcp.Prompt
|
||||||
case int:
|
PromptMessage = sdkmcp.PromptMessage
|
||||||
return val == 0, nil
|
GetPromptParams = sdkmcp.GetPromptParams
|
||||||
case int64:
|
GetPromptResult = sdkmcp.GetPromptResult
|
||||||
return val == 0, nil
|
|
||||||
case float64:
|
// Resource types
|
||||||
return val == 0.0, nil
|
Resource = sdkmcp.Resource
|
||||||
case string:
|
ResourceContents = sdkmcp.ResourceContents
|
||||||
return len(val) == 0, nil
|
ReadResourceParams = sdkmcp.ReadResourceParams
|
||||||
case nil:
|
ReadResourceResult = sdkmcp.ReadResourceResult
|
||||||
return true, nil
|
|
||||||
default:
|
// Session and server types
|
||||||
return false, fmt.Errorf("invalid type %T", val)
|
Server = sdkmcp.Server
|
||||||
|
ServerSession = sdkmcp.ServerSession
|
||||||
|
ServerOptions = sdkmcp.ServerOptions
|
||||||
|
Implementation = sdkmcp.Implementation
|
||||||
|
|
||||||
|
// Transport types
|
||||||
|
SSEHandler = sdkmcp.SSEHandler
|
||||||
|
StreamableHTTPHandler = sdkmcp.StreamableHTTPHandler
|
||||||
|
)
|
||||||
|
|
||||||
|
// ToolHandler is a generic function signature for tool handlers.
|
||||||
|
// Handlers should accept context, request, and typed arguments, and return
|
||||||
|
// a result, metadata, and error.
|
||||||
|
//
|
||||||
|
// Deprecated: Use ToolHandlerFor directly from the SDK types.
|
||||||
|
type ToolHandler[Args any, Meta any] func(
|
||||||
|
ctx context.Context,
|
||||||
|
req *CallToolRequest,
|
||||||
|
args Args,
|
||||||
|
) (*CallToolResult, Meta, error)
|
||||||
|
|
||||||
|
// PromptHandler is a function signature for prompt handlers.
|
||||||
|
type PromptHandler func(
|
||||||
|
ctx context.Context,
|
||||||
|
req *sdkmcp.GetPromptRequest,
|
||||||
|
args map[string]string,
|
||||||
|
) (*GetPromptResult, error)
|
||||||
|
|
||||||
|
// ResourceHandler is a function signature for resource handlers.
|
||||||
|
type ResourceHandler func(
|
||||||
|
ctx context.Context,
|
||||||
|
req *sdkmcp.ReadResourceRequest,
|
||||||
|
uri string,
|
||||||
|
) (*ReadResourceResult, error)
|
||||||
|
|
||||||
|
// AddTool registers a tool with the MCP server using type-safe generics.
|
||||||
|
// The SDK automatically generates JSON schema from the Args struct tags.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// type GreetArgs struct {
|
||||||
|
// Name string `json:"name" jsonschema:"description=Name to greet"`
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// tool := &mcp.Tool{
|
||||||
|
// Name: "greet",
|
||||||
|
// Description: "Greet someone",
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// handler := func(ctx context.Context, req *mcp.CallToolRequest, args GreetArgs) (*mcp.CallToolResult, any, error) {
|
||||||
|
// return &mcp.CallToolResult{
|
||||||
|
// Content: []mcp.Content{&mcp.TextContent{Text: "Hello " + args.Name}},
|
||||||
|
// }, nil, nil
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// mcp.AddTool(server, tool, handler)
|
||||||
|
func AddTool[In, Out any](server McpServer, tool *Tool, handler func(context.Context, *CallToolRequest, In) (*CallToolResult, Out, error)) {
|
||||||
|
// Access internal server - only works with mcpServerImpl
|
||||||
|
if impl, ok := server.(*mcpServerImpl); ok {
|
||||||
|
sdkmcp.AddTool(impl.mcpServer, tool, handler)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type PaginatedParams struct {
|
|
||||||
Cursor string `json:"cursor"`
|
|
||||||
Meta struct {
|
|
||||||
ProgressToken any `json:"progressToken"`
|
|
||||||
} `json:"_meta"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Result is the base interface for all results
|
|
||||||
type Result struct {
|
|
||||||
Meta map[string]any `json:"_meta,omitempty"` // Optional metadata
|
|
||||||
}
|
|
||||||
|
|
||||||
// PaginatedResult is a base for results that support pagination
|
|
||||||
type PaginatedResult struct {
|
|
||||||
Result
|
|
||||||
NextCursor Cursor `json:"nextCursor,omitempty"` // Opaque token for fetching next page
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListToolsResult represents the response to a tools/list request
|
|
||||||
type ListToolsResult struct {
|
|
||||||
PaginatedResult
|
|
||||||
Tools []Tool `json:"tools"` // List of available tools
|
|
||||||
}
|
|
||||||
|
|
||||||
// Message Content Types
|
|
||||||
|
|
||||||
// RoleType represents the sender or recipient of messages in a conversation
|
|
||||||
type RoleType string
|
|
||||||
|
|
||||||
// PromptArgument defines a single argument that can be passed to a prompt
|
|
||||||
type PromptArgument struct {
|
|
||||||
Name string `json:"name"` // Argument name
|
|
||||||
Description string `json:"description,omitempty"` // Human-readable description
|
|
||||||
Required bool `json:"required,omitempty"` // Whether this argument is required
|
|
||||||
}
|
|
||||||
|
|
||||||
// PromptHandler is a function that dynamically generates prompt content
|
|
||||||
type PromptHandler func(ctx context.Context, args map[string]string) ([]PromptMessage, error)
|
|
||||||
|
|
||||||
// Prompt represents an MCP Prompt definition
|
|
||||||
type Prompt struct {
|
|
||||||
Name string `json:"name"` // Unique identifier for the prompt
|
|
||||||
Description string `json:"description,omitempty"` // Human-readable description
|
|
||||||
Arguments []PromptArgument `json:"arguments,omitempty"` // Arguments for customization
|
|
||||||
Content string `json:"-"` // Static content (internal use only)
|
|
||||||
Handler PromptHandler `json:"-"` // Handler for dynamic content generation
|
|
||||||
}
|
|
||||||
|
|
||||||
// PromptMessage represents a message in a conversation
|
|
||||||
type PromptMessage struct {
|
|
||||||
Role RoleType `json:"role"` // Message sender role
|
|
||||||
Content any `json:"content"` // Message content (TextContent, ImageContent, etc.)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TextContent represents text content in a message
|
|
||||||
type TextContent struct {
|
|
||||||
Text string `json:"text"` // The text content
|
|
||||||
Annotations *Annotations `json:"annotations,omitempty"` // Optional annotations
|
|
||||||
}
|
|
||||||
|
|
||||||
type typedTextContent struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
TextContent
|
|
||||||
}
|
|
||||||
|
|
||||||
// ImageContent represents image data in a message
|
|
||||||
type ImageContent struct {
|
|
||||||
Data string `json:"data"` // Base64-encoded image data
|
|
||||||
MimeType string `json:"mimeType"` // MIME type (e.g., "image/png")
|
|
||||||
}
|
|
||||||
|
|
||||||
type typedImageContent struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
ImageContent
|
|
||||||
}
|
|
||||||
|
|
||||||
// AudioContent represents audio data in a message
|
|
||||||
type AudioContent struct {
|
|
||||||
Data string `json:"data"` // Base64-encoded audio data
|
|
||||||
MimeType string `json:"mimeType"` // MIME type (e.g., "audio/mp3")
|
|
||||||
}
|
|
||||||
|
|
||||||
type typedAudioContent struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
AudioContent
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileContent represents file content
|
|
||||||
type FileContent struct {
|
|
||||||
URI string `json:"uri"` // URI identifying the file
|
|
||||||
MimeType string `json:"mimeType"` // MIME type of the file
|
|
||||||
Text string `json:"text"` // File content as text
|
|
||||||
}
|
|
||||||
|
|
||||||
// EmbeddedResource represents a resource embedded in a message
|
|
||||||
type EmbeddedResource struct {
|
|
||||||
Type string `json:"type"` // Always "resource"
|
|
||||||
Resource ResourceContent `json:"resource"` // The resource data
|
|
||||||
}
|
|
||||||
|
|
||||||
// Annotations provides additional metadata for content
|
|
||||||
type Annotations struct {
|
|
||||||
Audience []RoleType `json:"audience,omitempty"` // Who should see this content
|
|
||||||
Priority *float64 `json:"priority,omitempty"` // Optional priority (0-1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tool-related Types
|
|
||||||
|
|
||||||
// ToolHandler is a function that handles tool calls
|
|
||||||
type ToolHandler func(ctx context.Context, params map[string]any) (any, error)
|
|
||||||
|
|
||||||
// Tool represents a Model Context Protocol Tool definition
|
|
||||||
type Tool struct {
|
|
||||||
Name string `json:"name"` // Unique identifier for the tool
|
|
||||||
Description string `json:"description"` // Human-readable description
|
|
||||||
InputSchema InputSchema `json:"inputSchema"` // JSON Schema for parameters
|
|
||||||
Handler ToolHandler `json:"-"` // Not sent to clients
|
|
||||||
}
|
|
||||||
|
|
||||||
// InputSchema represents tool's input schema in JSON Schema format
|
|
||||||
type InputSchema struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Properties map[string]any `json:"properties"` // Property definitions
|
|
||||||
Required []string `json:"required,omitempty"` // List of required properties
|
|
||||||
}
|
|
||||||
|
|
||||||
// CallToolResult represents a tool call result that conforms to the MCP schema
|
|
||||||
type CallToolResult struct {
|
|
||||||
Result
|
|
||||||
Content []any `json:"content"` // Content items (text, images, etc.)
|
|
||||||
IsError bool `json:"isError,omitempty"` // True if tool execution failed
|
|
||||||
}
|
|
||||||
|
|
||||||
// Resource represents a Model Context Protocol Resource definition
|
|
||||||
type Resource struct {
|
|
||||||
URI string `json:"uri"` // Unique resource identifier (RFC3986)
|
|
||||||
Name string `json:"name"` // Human-readable name
|
|
||||||
Description string `json:"description,omitempty"` // Optional description
|
|
||||||
MimeType string `json:"mimeType,omitempty"` // Optional MIME type
|
|
||||||
Handler ResourceHandler `json:"-"` // Internal handler not sent to clients
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResourceHandler is a function that handles resource read requests
|
|
||||||
type ResourceHandler func(ctx context.Context) (ResourceContent, error)
|
|
||||||
|
|
||||||
// ResourceContent represents the content of a resource
|
|
||||||
type ResourceContent struct {
|
|
||||||
URI string `json:"uri"` // Resource URI (required)
|
|
||||||
MimeType string `json:"mimeType,omitempty"` // MIME type of the resource
|
|
||||||
Text string `json:"text,omitempty"` // Text content (if available)
|
|
||||||
Blob string `json:"blob,omitempty"` // Base64 encoded blob data (if available)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResourcesListResult represents the response to a resources/list request
|
|
||||||
type ResourcesListResult struct {
|
|
||||||
PaginatedResult
|
|
||||||
Resources []Resource `json:"resources"` // List of available resources
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResourceReadParams contains parameters for a resources/read request
|
|
||||||
type ResourceReadParams struct {
|
|
||||||
URI string `json:"uri"` // URI of the resource to read
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResourceReadResult contains the result of a resources/read request
|
|
||||||
type ResourceReadResult struct {
|
|
||||||
Result
|
|
||||||
Contents []ResourceContent `json:"contents"` // Array of resource content
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResourceSubscribeParams contains parameters for a resources/subscribe request
|
|
||||||
type ResourceSubscribeParams struct {
|
|
||||||
URI string `json:"uri"` // URI of the resource to subscribe to
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResourceUpdateNotification represents a notification about a resource update
|
|
||||||
type ResourceUpdateNotification struct {
|
|
||||||
URI string `json:"uri"` // URI of the updated resource
|
|
||||||
Content ResourceContent `json:"content"` // New resource content
|
|
||||||
}
|
|
||||||
|
|
||||||
// Client and Server Types
|
|
||||||
|
|
||||||
// mcpClient represents an SSE client connection
|
|
||||||
type mcpClient struct {
|
|
||||||
id string // Unique client identifier
|
|
||||||
channel chan string // Channel for sending SSE messages
|
|
||||||
initialized bool // Tracks if client has sent notifications/initialized
|
|
||||||
}
|
|
||||||
|
|
||||||
// McpServer defines the interface for Model Context Protocol servers
|
|
||||||
type McpServer interface {
|
|
||||||
Start()
|
|
||||||
Stop()
|
|
||||||
RegisterTool(tool Tool) error
|
|
||||||
RegisterPrompt(prompt Prompt)
|
|
||||||
RegisterResource(resource Resource)
|
|
||||||
}
|
|
||||||
|
|
||||||
// sseMcpServer implements the McpServer interface using SSE
|
|
||||||
type sseMcpServer struct {
|
|
||||||
conf McpConf
|
|
||||||
server *rest.Server
|
|
||||||
clients map[string]*mcpClient
|
|
||||||
clientsLock sync.Mutex
|
|
||||||
tools map[string]Tool
|
|
||||||
toolsLock sync.Mutex
|
|
||||||
prompts map[string]Prompt
|
|
||||||
promptsLock sync.Mutex
|
|
||||||
resources map[string]Resource
|
|
||||||
resourcesLock sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// Response Types
|
|
||||||
|
|
||||||
// errorObj represents a JSON-RPC error object
|
|
||||||
type errorObj struct {
|
|
||||||
Code int `json:"code"` // Error code
|
|
||||||
Message string `json:"message"` // Error message
|
|
||||||
}
|
|
||||||
|
|
||||||
// Response represents a JSON-RPC response
|
|
||||||
type Response struct {
|
|
||||||
JsonRpc string `json:"jsonrpc"` // Always "2.0"
|
|
||||||
ID any `json:"id"` // Same as request ID
|
|
||||||
Result any `json:"result"` // Result object (null if error)
|
|
||||||
Error *errorObj `json:"error,omitempty"` // Error object (null if success)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Server Information Types
|
|
||||||
|
|
||||||
// serverInfo provides information about the server
|
|
||||||
type serverInfo struct {
|
|
||||||
Name string `json:"name"` // Server name
|
|
||||||
Version string `json:"version"` // Server version
|
|
||||||
}
|
|
||||||
|
|
||||||
// capabilities describes the server's capabilities
|
|
||||||
type capabilities struct {
|
|
||||||
Logging struct{} `json:"logging"`
|
|
||||||
Prompts struct {
|
|
||||||
ListChanged bool `json:"listChanged"` // Server will notify on prompt changes
|
|
||||||
} `json:"prompts"`
|
|
||||||
Resources struct {
|
|
||||||
Subscribe bool `json:"subscribe"` // Server supports resource subscriptions
|
|
||||||
ListChanged bool `json:"listChanged"` // Server will notify on resource changes
|
|
||||||
} `json:"resources"`
|
|
||||||
Tools struct {
|
|
||||||
ListChanged bool `json:"listChanged"` // Server will notify on tool changes
|
|
||||||
} `json:"tools"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// initializationResponse is sent in response to an initialize request
|
|
||||||
type initializationResponse struct {
|
|
||||||
ProtocolVersion string `json:"protocolVersion"` // Protocol version
|
|
||||||
Capabilities capabilities `json:"capabilities"` // Server capabilities
|
|
||||||
ServerInfo serverInfo `json:"serverInfo"` // Server information
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToolCallParams contains the parameters for a tool call
|
|
||||||
type ToolCallParams struct {
|
|
||||||
Name string `json:"name"` // Tool name
|
|
||||||
Parameters map[string]any `json:"parameters"` // Tool parameters
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToolResult contains the result of a tool execution
|
|
||||||
type ToolResult struct {
|
|
||||||
Type string `json:"type"` // Content type (text, image, etc.)
|
|
||||||
Content any `json:"content"` // Result content
|
|
||||||
}
|
|
||||||
|
|
||||||
// errorMessage represents a detailed error message
|
|
||||||
type errorMessage struct {
|
|
||||||
Code int `json:"code"` // Error code
|
|
||||||
Message string `json:"message"` // Error message
|
|
||||||
Data any `json:",omitempty"` // Additional error data
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,271 +0,0 @@
|
|||||||
package mcp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestResponseMarshaling(t *testing.T) {
|
|
||||||
// Test that the Response struct marshals correctly
|
|
||||||
resp := Response{
|
|
||||||
JsonRpc: "2.0",
|
|
||||||
ID: 123,
|
|
||||||
Result: map[string]string{
|
|
||||||
"key": "value",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := json.Marshal(resp)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Contains(t, string(data), `"jsonrpc":"2.0"`)
|
|
||||||
assert.Contains(t, string(data), `"id":123`)
|
|
||||||
assert.Contains(t, string(data), `"result":{"key":"value"}`)
|
|
||||||
|
|
||||||
// Test response with error
|
|
||||||
respWithError := Response{
|
|
||||||
JsonRpc: "2.0",
|
|
||||||
ID: 456,
|
|
||||||
Error: &errorObj{
|
|
||||||
Code: errCodeInvalidRequest,
|
|
||||||
Message: "Invalid Request",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err = json.Marshal(respWithError)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Contains(t, string(data), `"jsonrpc":"2.0"`)
|
|
||||||
assert.Contains(t, string(data), `"id":456`)
|
|
||||||
assert.Contains(t, string(data), `"error":{"code":-32600,"message":"Invalid Request"}`)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRequestUnmarshaling(t *testing.T) {
|
|
||||||
// Test that the Request struct unmarshals correctly
|
|
||||||
jsonStr := `{
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": 789,
|
|
||||||
"method": "test_method",
|
|
||||||
"params": {"key": "value"}
|
|
||||||
}`
|
|
||||||
|
|
||||||
var req Request
|
|
||||||
err := json.Unmarshal([]byte(jsonStr), &req)
|
|
||||||
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, "2.0", req.JsonRpc)
|
|
||||||
assert.Equal(t, float64(789), req.ID)
|
|
||||||
assert.Equal(t, "test_method", req.Method)
|
|
||||||
|
|
||||||
// Check params unmarshaled correctly
|
|
||||||
var params map[string]string
|
|
||||||
err = json.Unmarshal(req.Params, ¶ms)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, "value", params["key"])
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestToolStructs(t *testing.T) {
|
|
||||||
// Test Tool struct
|
|
||||||
tool := Tool{
|
|
||||||
Name: "test.tool",
|
|
||||||
Description: "A test tool",
|
|
||||||
InputSchema: InputSchema{
|
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]any{
|
|
||||||
"input": map[string]any{
|
|
||||||
"type": "string",
|
|
||||||
"description": "Input parameter",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Required: []string{"input"},
|
|
||||||
},
|
|
||||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
|
||||||
return "result", nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify fields are correct
|
|
||||||
assert.Equal(t, "test.tool", tool.Name)
|
|
||||||
assert.Equal(t, "A test tool", tool.Description)
|
|
||||||
assert.Equal(t, "object", tool.InputSchema.Type)
|
|
||||||
assert.Contains(t, tool.InputSchema.Properties, "input")
|
|
||||||
propMap, ok := tool.InputSchema.Properties["input"].(map[string]any)
|
|
||||||
assert.True(t, ok, "Property should be a map")
|
|
||||||
assert.Equal(t, "string", propMap["type"])
|
|
||||||
assert.NotNil(t, tool.Handler)
|
|
||||||
|
|
||||||
// Verify JSON marshalling (which should exclude Handler function)
|
|
||||||
data, err := json.Marshal(tool)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Contains(t, string(data), `"name":"test.tool"`)
|
|
||||||
assert.Contains(t, string(data), `"description":"A test tool"`)
|
|
||||||
assert.Contains(t, string(data), `"inputSchema":`)
|
|
||||||
assert.NotContains(t, string(data), `"Handler":`)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPromptStructs(t *testing.T) {
|
|
||||||
// Test Prompt struct
|
|
||||||
prompt := Prompt{
|
|
||||||
Name: "test.prompt",
|
|
||||||
Description: "A test prompt description",
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify fields are correct
|
|
||||||
assert.Equal(t, "test.prompt", prompt.Name)
|
|
||||||
assert.Equal(t, "A test prompt description", prompt.Description)
|
|
||||||
|
|
||||||
// Verify JSON marshalling
|
|
||||||
data, err := json.Marshal(prompt)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Contains(t, string(data), `"name":"test.prompt"`)
|
|
||||||
assert.Contains(t, string(data), `"description":"A test prompt description"`)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResourceStructs(t *testing.T) {
|
|
||||||
// Test Resource struct
|
|
||||||
resource := Resource{
|
|
||||||
Name: "test.resource",
|
|
||||||
URI: "http://example.com/resource",
|
|
||||||
Description: "A test resource",
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify fields are correct
|
|
||||||
assert.Equal(t, "test.resource", resource.Name)
|
|
||||||
assert.Equal(t, "http://example.com/resource", resource.URI)
|
|
||||||
assert.Equal(t, "A test resource", resource.Description)
|
|
||||||
|
|
||||||
// Verify JSON marshalling
|
|
||||||
data, err := json.Marshal(resource)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Contains(t, string(data), `"name":"test.resource"`)
|
|
||||||
assert.Contains(t, string(data), `"uri":"http://example.com/resource"`)
|
|
||||||
assert.Contains(t, string(data), `"description":"A test resource"`)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestContentTypes(t *testing.T) {
|
|
||||||
// Test TextContent
|
|
||||||
textContent := TextContent{
|
|
||||||
Text: "Sample text",
|
|
||||||
Annotations: &Annotations{
|
|
||||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
|
||||||
Priority: ptr(1.0),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := json.Marshal(textContent)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Contains(t, string(data), `"text":"Sample text"`)
|
|
||||||
assert.Contains(t, string(data), `"audience":["user","assistant"]`)
|
|
||||||
assert.Contains(t, string(data), `"priority":1`)
|
|
||||||
|
|
||||||
// Test ImageContent
|
|
||||||
imageContent := ImageContent{
|
|
||||||
Data: "base64data",
|
|
||||||
MimeType: "image/png",
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err = json.Marshal(imageContent)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Contains(t, string(data), `"data":"base64data"`)
|
|
||||||
assert.Contains(t, string(data), `"mimeType":"image/png"`)
|
|
||||||
|
|
||||||
// Test AudioContent
|
|
||||||
audioContent := AudioContent{
|
|
||||||
Data: "base64audio",
|
|
||||||
MimeType: "audio/mp3",
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err = json.Marshal(audioContent)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Contains(t, string(data), `"data":"base64audio"`)
|
|
||||||
assert.Contains(t, string(data), `"mimeType":"audio/mp3"`)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCallToolResult(t *testing.T) {
|
|
||||||
// Test CallToolResult
|
|
||||||
result := CallToolResult{
|
|
||||||
Result: Result{
|
|
||||||
Meta: map[string]any{
|
|
||||||
"progressToken": "token123",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Content: []interface{}{
|
|
||||||
TextContent{
|
|
||||||
Text: "Sample result",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IsError: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := json.Marshal(result)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Contains(t, string(data), `"_meta":{"progressToken":"token123"}`)
|
|
||||||
assert.Contains(t, string(data), `"content":[{"text":"Sample result"}]`)
|
|
||||||
assert.NotContains(t, string(data), `"isError":`)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRequest_isNotification(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
id any
|
|
||||||
want bool
|
|
||||||
wantErr error
|
|
||||||
}{
|
|
||||||
// integer test cases
|
|
||||||
{name: "int zero", id: 0, want: true, wantErr: nil},
|
|
||||||
{name: "int non-zero", id: 1, want: false, wantErr: nil},
|
|
||||||
{name: "int64 zero", id: int64(0), want: true, wantErr: nil},
|
|
||||||
{name: "int64 max", id: int64(9223372036854775807), want: false, wantErr: nil},
|
|
||||||
|
|
||||||
// floating point number test cases
|
|
||||||
{name: "float64 zero", id: float64(0.0), want: true, wantErr: nil},
|
|
||||||
{name: "float64 positive", id: float64(0.000001), want: false, wantErr: nil},
|
|
||||||
{name: "float64 negative", id: float64(-0.000001), want: false, wantErr: nil},
|
|
||||||
{name: "float64 epsilon", id: float64(1e-300), want: false, wantErr: nil},
|
|
||||||
|
|
||||||
// string test cases
|
|
||||||
{name: "empty string", id: "", want: true, wantErr: nil},
|
|
||||||
{name: "non-empty string", id: "abc", want: false, wantErr: nil},
|
|
||||||
{name: "space string", id: " ", want: false, wantErr: nil},
|
|
||||||
{name: "unicode string", id: "こんにちは", want: false, wantErr: nil},
|
|
||||||
|
|
||||||
// special cases
|
|
||||||
{name: "nil", id: nil, want: true, wantErr: nil},
|
|
||||||
|
|
||||||
// logical type test cases
|
|
||||||
{name: "bool true", id: true, want: false, wantErr: errors.New("invalid type bool")},
|
|
||||||
{name: "bool false", id: false, want: false, wantErr: errors.New("invalid type bool")},
|
|
||||||
{name: "struct type", id: struct{}{}, want: false, wantErr: errors.New("invalid type struct {}")},
|
|
||||||
{name: "slice type", id: []int{1, 2, 3}, want: false, wantErr: errors.New("invalid type []int")},
|
|
||||||
{name: "map type", id: map[string]int{"a": 1}, want: false, wantErr: errors.New("invalid type map[string]int")},
|
|
||||||
{name: "pointer type", id: new(int), want: false, wantErr: errors.New("invalid type *int")},
|
|
||||||
{name: "func type", id: func() {}, want: false, wantErr: errors.New("invalid type func()")},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
req := Request{
|
|
||||||
SessionId: "test-session",
|
|
||||||
JsonRpc: "2.0",
|
|
||||||
ID: tt.id,
|
|
||||||
Method: "testMethod",
|
|
||||||
Params: json.RawMessage(`{}`),
|
|
||||||
}
|
|
||||||
|
|
||||||
got, err := req.isNotification()
|
|
||||||
|
|
||||||
if (err != nil) != (tt.wantErr != nil) {
|
|
||||||
t.Fatalf("error presence mismatch: got error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
if err != nil && tt.wantErr != nil && err.Error() != tt.wantErr.Error() {
|
|
||||||
t.Fatalf("error message mismatch:\ngot %q\nwant %q", err.Error(), tt.wantErr.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("isNotification() = %v, want %v for ID %v (%T)", got, tt.want, tt.id, tt.id)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
107
mcp/util.go
107
mcp/util.go
@@ -1,107 +0,0 @@
|
|||||||
package mcp
|
|
||||||
|
|
||||||
import "fmt"
|
|
||||||
|
|
||||||
// formatSSEMessage formats a Server-Sent Event message with proper CRLF line endings
|
|
||||||
func formatSSEMessage(event string, data []byte) string {
|
|
||||||
return fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(data))
|
|
||||||
}
|
|
||||||
|
|
||||||
// ptr is a helper function to get a pointer to a value
|
|
||||||
func ptr[T any](v T) *T {
|
|
||||||
return &v
|
|
||||||
}
|
|
||||||
|
|
||||||
func toTypedContents(contents []any) []any {
|
|
||||||
typedContents := make([]any, len(contents))
|
|
||||||
|
|
||||||
for i, content := range contents {
|
|
||||||
switch v := content.(type) {
|
|
||||||
case TextContent:
|
|
||||||
typedContents[i] = typedTextContent{
|
|
||||||
Type: ContentTypeText,
|
|
||||||
TextContent: v,
|
|
||||||
}
|
|
||||||
case ImageContent:
|
|
||||||
typedContents[i] = typedImageContent{
|
|
||||||
Type: ContentTypeImage,
|
|
||||||
ImageContent: v,
|
|
||||||
}
|
|
||||||
case AudioContent:
|
|
||||||
typedContents[i] = typedAudioContent{
|
|
||||||
Type: ContentTypeAudio,
|
|
||||||
AudioContent: v,
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
typedContents[i] = typedTextContent{
|
|
||||||
Type: ContentTypeText,
|
|
||||||
TextContent: TextContent{
|
|
||||||
Text: fmt.Sprintf("Unknown content type: %T", v),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return typedContents
|
|
||||||
}
|
|
||||||
|
|
||||||
func toTypedPromptMessages(messages []PromptMessage) []PromptMessage {
|
|
||||||
typedMessages := make([]PromptMessage, len(messages))
|
|
||||||
|
|
||||||
for i, msg := range messages {
|
|
||||||
switch v := msg.Content.(type) {
|
|
||||||
case TextContent:
|
|
||||||
typedMessages[i] = PromptMessage{
|
|
||||||
Role: msg.Role,
|
|
||||||
Content: typedTextContent{
|
|
||||||
Type: ContentTypeText,
|
|
||||||
TextContent: v,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
case ImageContent:
|
|
||||||
typedMessages[i] = PromptMessage{
|
|
||||||
Role: msg.Role,
|
|
||||||
Content: typedImageContent{
|
|
||||||
Type: ContentTypeImage,
|
|
||||||
ImageContent: v,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
case AudioContent:
|
|
||||||
typedMessages[i] = PromptMessage{
|
|
||||||
Role: msg.Role,
|
|
||||||
Content: typedAudioContent{
|
|
||||||
Type: ContentTypeAudio,
|
|
||||||
AudioContent: v,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
typedMessages[i] = PromptMessage{
|
|
||||||
Role: msg.Role,
|
|
||||||
Content: typedTextContent{
|
|
||||||
Type: ContentTypeText,
|
|
||||||
TextContent: TextContent{
|
|
||||||
Text: fmt.Sprintf("Unknown content type: %T", v),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return typedMessages
|
|
||||||
}
|
|
||||||
|
|
||||||
// validatePromptArguments checks if all required arguments are provided
|
|
||||||
// Returns a list of missing required arguments
|
|
||||||
func validatePromptArguments(prompt Prompt, providedArgs map[string]string) []string {
|
|
||||||
var missingArgs []string
|
|
||||||
|
|
||||||
for _, arg := range prompt.Arguments {
|
|
||||||
if arg.Required {
|
|
||||||
if value, exists := providedArgs[arg.Name]; !exists || len(value) == 0 {
|
|
||||||
missingArgs = append(missingArgs, arg.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return missingArgs
|
|
||||||
}
|
|
||||||
274
mcp/util_test.go
274
mcp/util_test.go
@@ -1,274 +0,0 @@
|
|||||||
package mcp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Event struct {
|
|
||||||
Type string
|
|
||||||
Data map[string]any
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseEvent(input string) (*Event, error) {
|
|
||||||
var evt Event
|
|
||||||
var dataStr string
|
|
||||||
|
|
||||||
scanner := bufio.NewScanner(strings.NewReader(input))
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := scanner.Text()
|
|
||||||
if strings.HasPrefix(line, "event:") {
|
|
||||||
evt.Type = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
|
|
||||||
} else if strings.HasPrefix(line, "data:") {
|
|
||||||
dataStr = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := scanner.Err(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(dataStr) > 0 {
|
|
||||||
if err := json.Unmarshal([]byte(dataStr), &evt.Data); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse data: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &evt, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestToTypedPromptMessages tests the toTypedPromptMessages function
|
|
||||||
func TestToTypedPromptMessages(t *testing.T) {
|
|
||||||
// Test with multiple message types in one test
|
|
||||||
t.Run("MixedContentTypes", func(t *testing.T) {
|
|
||||||
// Create test data with different content types
|
|
||||||
messages := []PromptMessage{
|
|
||||||
{
|
|
||||||
Role: RoleUser,
|
|
||||||
Content: TextContent{
|
|
||||||
Text: "Hello, this is a text message",
|
|
||||||
Annotations: &Annotations{
|
|
||||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
|
||||||
Priority: ptr(0.8),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Role: RoleAssistant,
|
|
||||||
Content: ImageContent{
|
|
||||||
Data: "base64ImageData",
|
|
||||||
MimeType: "image/jpeg",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Role: RoleUser,
|
|
||||||
Content: AudioContent{
|
|
||||||
Data: "base64AudioData",
|
|
||||||
MimeType: "audio/mp3",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Role: "system",
|
|
||||||
Content: "This is a simple string that should be handled as unknown type",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Call the function
|
|
||||||
result := toTypedPromptMessages(messages)
|
|
||||||
|
|
||||||
// Validate results
|
|
||||||
require.Len(t, result, 4, "Should return the same number of messages")
|
|
||||||
|
|
||||||
// Validate first message (TextContent)
|
|
||||||
msg := result[0]
|
|
||||||
assert.Equal(t, RoleUser, msg.Role, "Role should be preserved")
|
|
||||||
|
|
||||||
// Type assertion using reflection since Content is an interface
|
|
||||||
typed, ok := msg.Content.(typedTextContent)
|
|
||||||
require.True(t, ok, "Should be typedTextContent")
|
|
||||||
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
|
||||||
assert.Equal(t, "Hello, this is a text message", typed.Text, "Text content should be preserved")
|
|
||||||
require.NotNil(t, typed.Annotations, "Annotations should be preserved")
|
|
||||||
assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved")
|
|
||||||
require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved")
|
|
||||||
assert.Equal(t, 0.8, *typed.Annotations.Priority, "Priority value should be preserved")
|
|
||||||
|
|
||||||
// Validate second message (ImageContent)
|
|
||||||
msg = result[1]
|
|
||||||
assert.Equal(t, RoleAssistant, msg.Role, "Role should be preserved")
|
|
||||||
|
|
||||||
// Type assertion for image content
|
|
||||||
typedImg, ok := msg.Content.(typedImageContent)
|
|
||||||
require.True(t, ok, "Should be typedImageContent")
|
|
||||||
assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image")
|
|
||||||
assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved")
|
|
||||||
assert.Equal(t, "image/jpeg", typedImg.MimeType, "MimeType should be preserved")
|
|
||||||
|
|
||||||
// Validate third message (AudioContent)
|
|
||||||
msg = result[2]
|
|
||||||
assert.Equal(t, RoleUser, msg.Role, "Role should be preserved")
|
|
||||||
|
|
||||||
// Type assertion for audio content
|
|
||||||
typedAudio, ok := msg.Content.(typedAudioContent)
|
|
||||||
require.True(t, ok, "Should be typedAudioContent")
|
|
||||||
assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio")
|
|
||||||
assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved")
|
|
||||||
assert.Equal(t, "audio/mp3", typedAudio.MimeType, "MimeType should be preserved")
|
|
||||||
|
|
||||||
// Validate fourth message (unknown type converted to TextContent)
|
|
||||||
msg = result[3]
|
|
||||||
assert.Equal(t, RoleType("system"), msg.Role, "Role should be preserved")
|
|
||||||
|
|
||||||
// Should be converted to a typedTextContent with error message
|
|
||||||
typedUnknown, ok := msg.Content.(typedTextContent)
|
|
||||||
require.True(t, ok, "Unknown content should be converted to typedTextContent")
|
|
||||||
assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text")
|
|
||||||
assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type")
|
|
||||||
assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type")
|
|
||||||
})
|
|
||||||
|
|
||||||
// Test empty input
|
|
||||||
t.Run("EmptyInput", func(t *testing.T) {
|
|
||||||
messages := []PromptMessage{}
|
|
||||||
result := toTypedPromptMessages(messages)
|
|
||||||
assert.Empty(t, result, "Should return empty slice for empty input")
|
|
||||||
})
|
|
||||||
|
|
||||||
// Test with nil annotations
|
|
||||||
t.Run("NilAnnotations", func(t *testing.T) {
|
|
||||||
messages := []PromptMessage{
|
|
||||||
{
|
|
||||||
Role: RoleUser,
|
|
||||||
Content: TextContent{
|
|
||||||
Text: "Text with nil annotations",
|
|
||||||
Annotations: nil,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
result := toTypedPromptMessages(messages)
|
|
||||||
require.Len(t, result, 1, "Should return one message")
|
|
||||||
|
|
||||||
typed, ok := result[0].Content.(typedTextContent)
|
|
||||||
require.True(t, ok, "Should be typedTextContent")
|
|
||||||
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
|
||||||
assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved")
|
|
||||||
assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestToTypedContents tests the toTypedContents function
|
|
||||||
func TestToTypedContents(t *testing.T) {
|
|
||||||
// Test with multiple content types in one test
|
|
||||||
t.Run("MixedContentTypes", func(t *testing.T) {
|
|
||||||
// Create test data with different content types
|
|
||||||
contents := []any{
|
|
||||||
TextContent{
|
|
||||||
Text: "Hello, this is a text content",
|
|
||||||
Annotations: &Annotations{
|
|
||||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
|
||||||
Priority: ptr(0.7),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
ImageContent{
|
|
||||||
Data: "base64ImageData",
|
|
||||||
MimeType: "image/png",
|
|
||||||
},
|
|
||||||
AudioContent{
|
|
||||||
Data: "base64AudioData",
|
|
||||||
MimeType: "audio/wav",
|
|
||||||
},
|
|
||||||
"This is a simple string that should be handled as unknown type",
|
|
||||||
}
|
|
||||||
|
|
||||||
// Call the function
|
|
||||||
result := toTypedContents(contents)
|
|
||||||
|
|
||||||
// Validate results
|
|
||||||
require.Len(t, result, 4, "Should return the same number of contents")
|
|
||||||
|
|
||||||
// Validate first content (TextContent)
|
|
||||||
typed, ok := result[0].(typedTextContent)
|
|
||||||
require.True(t, ok, "Should be typedTextContent")
|
|
||||||
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
|
||||||
assert.Equal(t, "Hello, this is a text content", typed.Text, "Text content should be preserved")
|
|
||||||
require.NotNil(t, typed.Annotations, "Annotations should be preserved")
|
|
||||||
assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved")
|
|
||||||
require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved")
|
|
||||||
assert.Equal(t, 0.7, *typed.Annotations.Priority, "Priority value should be preserved")
|
|
||||||
|
|
||||||
// Validate second content (ImageContent)
|
|
||||||
typedImg, ok := result[1].(typedImageContent)
|
|
||||||
require.True(t, ok, "Should be typedImageContent")
|
|
||||||
assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image")
|
|
||||||
assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved")
|
|
||||||
assert.Equal(t, "image/png", typedImg.MimeType, "MimeType should be preserved")
|
|
||||||
|
|
||||||
// Validate third content (AudioContent)
|
|
||||||
typedAudio, ok := result[2].(typedAudioContent)
|
|
||||||
require.True(t, ok, "Should be typedAudioContent")
|
|
||||||
assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio")
|
|
||||||
assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved")
|
|
||||||
assert.Equal(t, "audio/wav", typedAudio.MimeType, "MimeType should be preserved")
|
|
||||||
|
|
||||||
// Validate fourth content (unknown type converted to TextContent)
|
|
||||||
typedUnknown, ok := result[3].(typedTextContent)
|
|
||||||
require.True(t, ok, "Unknown content should be converted to typedTextContent")
|
|
||||||
assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text")
|
|
||||||
assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type")
|
|
||||||
assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type")
|
|
||||||
})
|
|
||||||
|
|
||||||
// Test empty input
|
|
||||||
t.Run("EmptyInput", func(t *testing.T) {
|
|
||||||
contents := []any{}
|
|
||||||
result := toTypedContents(contents)
|
|
||||||
assert.Empty(t, result, "Should return empty slice for empty input")
|
|
||||||
})
|
|
||||||
|
|
||||||
// Test with nil annotations
|
|
||||||
t.Run("NilAnnotations", func(t *testing.T) {
|
|
||||||
contents := []any{
|
|
||||||
TextContent{
|
|
||||||
Text: "Text with nil annotations",
|
|
||||||
Annotations: nil,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
result := toTypedContents(contents)
|
|
||||||
require.Len(t, result, 1, "Should return one content")
|
|
||||||
|
|
||||||
typed, ok := result[0].(typedTextContent)
|
|
||||||
require.True(t, ok, "Should be typedTextContent")
|
|
||||||
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
|
||||||
assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved")
|
|
||||||
assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil")
|
|
||||||
})
|
|
||||||
|
|
||||||
// Test with custom struct (should be handled as unknown type)
|
|
||||||
t.Run("CustomStruct", func(t *testing.T) {
|
|
||||||
type CustomContent struct {
|
|
||||||
Data string
|
|
||||||
}
|
|
||||||
|
|
||||||
contents := []any{
|
|
||||||
CustomContent{
|
|
||||||
Data: "custom data",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
result := toTypedContents(contents)
|
|
||||||
require.Len(t, result, 1, "Should return one content")
|
|
||||||
|
|
||||||
typed, ok := result[0].(typedTextContent)
|
|
||||||
require.True(t, ok, "Custom struct should be converted to typedTextContent")
|
|
||||||
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
|
||||||
assert.Contains(t, typed.Text, "Unknown content type:", "Should contain error about unknown type")
|
|
||||||
assert.Contains(t, typed.Text, "CustomContent", "Should mention the actual type")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
149
mcp/vars.go
149
mcp/vars.go
@@ -1,149 +0,0 @@
|
|||||||
package mcp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/syncx"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Protocol constants
|
|
||||||
const (
|
|
||||||
// JSON-RPC version as defined in the specification
|
|
||||||
jsonRpcVersion = "2.0"
|
|
||||||
|
|
||||||
// Session identifier key used in request URLs
|
|
||||||
sessionIdKey = "session_id"
|
|
||||||
|
|
||||||
// progressTokenKey is used to track progress of long-running tasks
|
|
||||||
progressTokenKey = "progressToken"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Server-Sent Events (SSE) event types
|
|
||||||
const (
|
|
||||||
// Standard message event for JSON-RPC responses
|
|
||||||
eventMessage = "message"
|
|
||||||
|
|
||||||
// Endpoint event for sending endpoint URL to clients
|
|
||||||
eventEndpoint = "endpoint"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Content type identifiers
|
|
||||||
const (
|
|
||||||
// ContentTypeObject is object content type
|
|
||||||
ContentTypeObject = "object"
|
|
||||||
|
|
||||||
// ContentTypeText is text content type
|
|
||||||
ContentTypeText = "text"
|
|
||||||
|
|
||||||
// ContentTypeImage is image content type
|
|
||||||
ContentTypeImage = "image"
|
|
||||||
|
|
||||||
// ContentTypeAudio is audio content type
|
|
||||||
ContentTypeAudio = "audio"
|
|
||||||
|
|
||||||
// ContentTypeResource is resource content type
|
|
||||||
ContentTypeResource = "resource"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Collection keys for broadcast events
|
|
||||||
const (
|
|
||||||
// Key for prompts collection
|
|
||||||
keyPrompts = "prompts"
|
|
||||||
|
|
||||||
// Key for resources collection
|
|
||||||
keyResources = "resources"
|
|
||||||
|
|
||||||
// Key for tools collection
|
|
||||||
keyTools = "tools"
|
|
||||||
)
|
|
||||||
|
|
||||||
// JSON-RPC error codes
|
|
||||||
// Standard error codes from JSON-RPC 2.0 spec
|
|
||||||
const (
|
|
||||||
// Invalid JSON was received by the server
|
|
||||||
errCodeInvalidRequest = -32600
|
|
||||||
|
|
||||||
// The method does not exist / is not available
|
|
||||||
errCodeMethodNotFound = -32601
|
|
||||||
|
|
||||||
// Invalid method parameter(s)
|
|
||||||
errCodeInvalidParams = -32602
|
|
||||||
|
|
||||||
// Internal JSON-RPC error
|
|
||||||
errCodeInternalError = -32603
|
|
||||||
|
|
||||||
// Tool execution timed out
|
|
||||||
errCodeTimeout = -32001
|
|
||||||
|
|
||||||
// Resource not found error
|
|
||||||
errCodeResourceNotFound = -32002
|
|
||||||
|
|
||||||
// Client hasn't completed initialization
|
|
||||||
errCodeClientNotInitialized = -32800
|
|
||||||
)
|
|
||||||
|
|
||||||
// User and assistant role definitions
|
|
||||||
const (
|
|
||||||
// RoleUser is the "user" role - the entity asking questions
|
|
||||||
RoleUser RoleType = "user"
|
|
||||||
|
|
||||||
// RoleAssistant is the "assistant" role - the entity providing responses
|
|
||||||
RoleAssistant RoleType = "assistant"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Method names as defined in the MCP specification
|
|
||||||
const (
|
|
||||||
// Initialize the connection between client and server
|
|
||||||
methodInitialize = "initialize"
|
|
||||||
|
|
||||||
// List available tools
|
|
||||||
methodToolsList = "tools/list"
|
|
||||||
|
|
||||||
// Call a specific tool
|
|
||||||
methodToolsCall = "tools/call"
|
|
||||||
|
|
||||||
// List available prompts
|
|
||||||
methodPromptsList = "prompts/list"
|
|
||||||
|
|
||||||
// Get a specific prompt
|
|
||||||
methodPromptsGet = "prompts/get"
|
|
||||||
|
|
||||||
// List available resources
|
|
||||||
methodResourcesList = "resources/list"
|
|
||||||
|
|
||||||
// Read a specific resource
|
|
||||||
methodResourcesRead = "resources/read"
|
|
||||||
|
|
||||||
// Subscribe to resource updates
|
|
||||||
methodResourcesSubscribe = "resources/subscribe"
|
|
||||||
|
|
||||||
// Simple ping to check server availability
|
|
||||||
methodPing = "ping"
|
|
||||||
|
|
||||||
// Notification that client is fully initialized
|
|
||||||
methodNotificationsInitialized = "notifications/initialized"
|
|
||||||
|
|
||||||
// Notification that a request was canceled
|
|
||||||
methodNotificationsCancelled = "notifications/cancelled"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Event names for Server-Sent Events (SSE)
|
|
||||||
const (
|
|
||||||
// Notification of tool list changes
|
|
||||||
eventToolsListChanged = "tools/list_changed"
|
|
||||||
|
|
||||||
// Notification of prompt list changes
|
|
||||||
eventPromptsListChanged = "prompts/list_changed"
|
|
||||||
|
|
||||||
// Notification of resource list changes
|
|
||||||
eventResourcesListChanged = "resources/list_changed"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// Default channel size for events
|
|
||||||
eventChanSize = 10
|
|
||||||
|
|
||||||
// Default ping interval for checking connection availability
|
|
||||||
// use syncx.ForAtomicDuration to ensure atomicity in test race
|
|
||||||
pingInterval = syncx.ForAtomicDuration(30 * time.Second)
|
|
||||||
)
|
|
||||||
210
mcp/vars_test.go
210
mcp/vars_test.go
@@ -1,210 +0,0 @@
|
|||||||
// filepath: /Users/kevin/Develop/go/opensource/go-zero/mcp/vars_test.go
|
|
||||||
package mcp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestErrorCodes ensures error codes are applied correctly in error responses
|
|
||||||
func TestErrorCodes(t *testing.T) {
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
code int
|
|
||||||
message string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "invalid request error",
|
|
||||||
code: errCodeInvalidRequest,
|
|
||||||
message: "Invalid request",
|
|
||||||
expected: `"code":-32600`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "method not found error",
|
|
||||||
code: errCodeMethodNotFound,
|
|
||||||
message: "Method not found",
|
|
||||||
expected: `"code":-32601`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid params error",
|
|
||||||
code: errCodeInvalidParams,
|
|
||||||
message: "Invalid parameters",
|
|
||||||
expected: `"code":-32602`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "internal error",
|
|
||||||
code: errCodeInternalError,
|
|
||||||
message: "Internal server error",
|
|
||||||
expected: `"code":-32603`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "timeout error",
|
|
||||||
code: errCodeTimeout,
|
|
||||||
message: "Operation timed out",
|
|
||||||
expected: `"code":-32001`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "resource not found error",
|
|
||||||
code: errCodeResourceNotFound,
|
|
||||||
message: "Resource not found",
|
|
||||||
expected: `"code":-32002`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "client not initialized error",
|
|
||||||
code: errCodeClientNotInitialized,
|
|
||||||
message: "Client not initialized",
|
|
||||||
expected: `"code":-32800`,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
resp := Response{
|
|
||||||
JsonRpc: jsonRpcVersion,
|
|
||||||
ID: int64(1),
|
|
||||||
Error: &errorObj{
|
|
||||||
Code: tc.code,
|
|
||||||
Message: tc.message,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
data, err := json.Marshal(resp)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Contains(t, string(data), tc.expected, "Error code should match expected value")
|
|
||||||
assert.Contains(t, string(data), tc.message, "Error message should be included")
|
|
||||||
assert.Contains(t, string(data), jsonRpcVersion, "JSON-RPC version should be included")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestJsonRpcVersion ensures the correct JSON-RPC version is used
|
|
||||||
func TestJsonRpcVersion(t *testing.T) {
|
|
||||||
assert.Equal(t, "2.0", jsonRpcVersion, "JSON-RPC version should be 2.0")
|
|
||||||
|
|
||||||
// Test that it's used in responses
|
|
||||||
resp := Response{
|
|
||||||
JsonRpc: jsonRpcVersion,
|
|
||||||
ID: int64(1),
|
|
||||||
Result: "test",
|
|
||||||
}
|
|
||||||
data, err := json.Marshal(resp)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Contains(t, string(data), `"jsonrpc":"2.0"`, "Response should use correct JSON-RPC version")
|
|
||||||
|
|
||||||
// Test that it's expected in requests
|
|
||||||
reqStr := `{"jsonrpc":"2.0","id":1,"method":"test"}`
|
|
||||||
var req Request
|
|
||||||
err = json.Unmarshal([]byte(reqStr), &req)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, jsonRpcVersion, req.JsonRpc, "Request should parse correct JSON-RPC version")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSessionIdKey ensures session ID extraction works correctly
|
|
||||||
func TestSessionIdKey(t *testing.T) {
|
|
||||||
// Create a mock server implementation
|
|
||||||
mock := newMockMcpServer(t)
|
|
||||||
defer mock.shutdown()
|
|
||||||
|
|
||||||
// Verify the key constant
|
|
||||||
assert.Equal(t, "session_id", sessionIdKey, "Session ID key should be 'session_id'")
|
|
||||||
|
|
||||||
// Test that session ID is extracted correctly
|
|
||||||
mockR := httptest.NewRequest("GET", "/?"+sessionIdKey+"=test-session", nil)
|
|
||||||
|
|
||||||
// Since the mock server is using the same session key logic,
|
|
||||||
// we can test this by accessing the request query parameters directly
|
|
||||||
sessionID := mockR.URL.Query().Get(sessionIdKey)
|
|
||||||
assert.Equal(t, "test-session", sessionID, "Session ID should be extracted correctly")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestEventTypes ensures event types are set correctly in SSE responses
|
|
||||||
func TestEventTypes(t *testing.T) {
|
|
||||||
// Test message event
|
|
||||||
assert.Equal(t, "message", eventMessage, "Message event should be 'message'")
|
|
||||||
|
|
||||||
// Test endpoint event
|
|
||||||
assert.Equal(t, "endpoint", eventEndpoint, "Endpoint event should be 'endpoint'")
|
|
||||||
|
|
||||||
// Verify them in an actual SSE format string
|
|
||||||
messageEvent := "event: " + eventMessage + "\ndata: test\n\n"
|
|
||||||
assert.Contains(t, messageEvent, "event: message", "Message event should format correctly")
|
|
||||||
|
|
||||||
endpointEvent := "event: " + eventEndpoint + "\ndata: test\n\n"
|
|
||||||
assert.Contains(t, endpointEvent, "event: endpoint", "Endpoint event should format correctly")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestCollectionKeys checks that collection keys are used correctly
|
|
||||||
func TestCollectionKeys(t *testing.T) {
|
|
||||||
// Verify collection key constants
|
|
||||||
assert.Equal(t, "prompts", keyPrompts, "Prompts key should be 'prompts'")
|
|
||||||
assert.Equal(t, "resources", keyResources, "Resources key should be 'resources'")
|
|
||||||
assert.Equal(t, "tools", keyTools, "Tools key should be 'tools'")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRoleTypes checks that role types are used correctly
|
|
||||||
func TestRoleTypes(t *testing.T) {
|
|
||||||
// Test in annotations
|
|
||||||
annotations := Annotations{
|
|
||||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
|
||||||
}
|
|
||||||
data, err := json.Marshal(annotations)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Contains(t, string(data), `"audience":["user","assistant"]`, "Role types should marshal correctly")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestMethodNames checks that method names are used correctly
|
|
||||||
func TestMethodNames(t *testing.T) {
|
|
||||||
// Verify method name constants
|
|
||||||
methods := map[string]string{
|
|
||||||
"initialize": methodInitialize,
|
|
||||||
"tools/list": methodToolsList,
|
|
||||||
"tools/call": methodToolsCall,
|
|
||||||
"prompts/list": methodPromptsList,
|
|
||||||
"prompts/get": methodPromptsGet,
|
|
||||||
"resources/list": methodResourcesList,
|
|
||||||
"resources/read": methodResourcesRead,
|
|
||||||
"resources/subscribe": methodResourcesSubscribe,
|
|
||||||
"ping": methodPing,
|
|
||||||
"notifications/initialized": methodNotificationsInitialized,
|
|
||||||
"notifications/cancelled": methodNotificationsCancelled,
|
|
||||||
}
|
|
||||||
|
|
||||||
for expected, actual := range methods {
|
|
||||||
assert.Equal(t, expected, actual, "Method name should be "+expected)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test in a request
|
|
||||||
for methodName := range methods {
|
|
||||||
req := Request{
|
|
||||||
JsonRpc: jsonRpcVersion,
|
|
||||||
ID: int64(1),
|
|
||||||
Method: methodName,
|
|
||||||
}
|
|
||||||
data, err := json.Marshal(req)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Contains(t, string(data), `"method":"`+methodName+`"`, "Method name should be used in requests")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestEventNames checks that event names are used correctly
|
|
||||||
func TestEventNames(t *testing.T) {
|
|
||||||
// Verify event name constants
|
|
||||||
events := map[string]string{
|
|
||||||
"tools/list_changed": eventToolsListChanged,
|
|
||||||
"prompts/list_changed": eventPromptsListChanged,
|
|
||||||
"resources/list_changed": eventResourcesListChanged,
|
|
||||||
}
|
|
||||||
|
|
||||||
for expected, actual := range events {
|
|
||||||
assert.Equal(t, expected, actual, "Event name should be "+expected)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test event names in SSE format
|
|
||||||
for _, eventName := range events {
|
|
||||||
sseEvent := "event: " + eventName + "\ndata: test\n\n"
|
|
||||||
assert.Contains(t, sseEvent, "event: "+eventName, "Event name should format correctly in SSE")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user