Fix the problem that mcp request id is not of int type (#4914)

This commit is contained in:
MarkJoyMa
2025-06-07 10:37:18 +08:00
committed by GitHub
parent 4b2095ed03
commit d4cccca387
3 changed files with 120 additions and 31 deletions

View File

@@ -173,17 +173,20 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
return return
} }
w.WriteHeader(http.StatusAccepted)
// For notification methods (no ID), we don't send a response // For notification methods (no ID), we don't send a response
isNotification := req.ID == 0 isNotification, err := req.isNotification()
if err != nil {
http.Error(w, "Invalid request.ID", http.StatusBadRequest)
}
w.WriteHeader(http.StatusAccepted)
// Special handling for initialization sequence // Special handling for initialization sequence
// Always allow initialize and notifications/initialized regardless of client state // Always allow initialize and notifications/initialized regardless of client state
if req.Method == methodInitialize { if req.Method == methodInitialize {
logx.Infof("Processing initialize request with ID: %d", req.ID) logx.Infof("Processing initialize request with ID: %v", req.ID)
s.processInitialize(r.Context(), client, req) s.processInitialize(r.Context(), client, req)
logx.Infof("Sent initialize response for ID: %d, waiting for notifications/initialized", req.ID) logx.Infof("Sent initialize response for ID: %v, waiting for notifications/initialized", req.ID)
return return
} else if req.Method == methodNotificationsInitialized { } else if req.Method == methodNotificationsInitialized {
// Handle initialized notification // Handle initialized notification
@@ -206,41 +209,41 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
// Process normal requests only after initialization // Process normal requests only after initialization
switch req.Method { switch req.Method {
case methodToolsCall: case methodToolsCall:
logx.Infof("Received tools call request with ID: %d", req.ID) logx.Infof("Received tools call request with ID: %v", req.ID)
s.processToolCall(r.Context(), client, req) s.processToolCall(r.Context(), client, req)
logx.Infof("Sent tools call response for ID: %d", req.ID) logx.Infof("Sent tools call response for ID: %v", req.ID)
case methodToolsList: case methodToolsList:
logx.Infof("Processing tools/list request with ID: %d", req.ID) logx.Infof("Processing tools/list request with ID: %v", req.ID)
s.processListTools(r.Context(), client, req) s.processListTools(r.Context(), client, req)
logx.Infof("Sent tools/list response for ID: %d", req.ID) logx.Infof("Sent tools/list response for ID: %v", req.ID)
case methodPromptsList: case methodPromptsList:
logx.Infof("Processing prompts/list request with ID: %d", req.ID) logx.Infof("Processing prompts/list request with ID: %v", req.ID)
s.processListPrompts(r.Context(), client, req) s.processListPrompts(r.Context(), client, req)
logx.Infof("Sent prompts/list response for ID: %d", req.ID) logx.Infof("Sent prompts/list response for ID: %v", req.ID)
case methodPromptsGet: case methodPromptsGet:
logx.Infof("Processing prompts/get request with ID: %d", req.ID) logx.Infof("Processing prompts/get request with ID: %v", req.ID)
s.processGetPrompt(r.Context(), client, req) s.processGetPrompt(r.Context(), client, req)
logx.Infof("Sent prompts/get response for ID: %d", req.ID) logx.Infof("Sent prompts/get response for ID: %v", req.ID)
case methodResourcesList: case methodResourcesList:
logx.Infof("Processing resources/list request with ID: %d", req.ID) logx.Infof("Processing resources/list request with ID: %v", req.ID)
s.processListResources(r.Context(), client, req) s.processListResources(r.Context(), client, req)
logx.Infof("Sent resources/list response for ID: %d", req.ID) logx.Infof("Sent resources/list response for ID: %v", req.ID)
case methodResourcesRead: case methodResourcesRead:
logx.Infof("Processing resources/read request with ID: %d", req.ID) logx.Infof("Processing resources/read request with ID: %v", req.ID)
s.processResourcesRead(r.Context(), client, req) s.processResourcesRead(r.Context(), client, req)
logx.Infof("Sent resources/read response for ID: %d", req.ID) logx.Infof("Sent resources/read response for ID: %v", req.ID)
case methodResourcesSubscribe: case methodResourcesSubscribe:
logx.Infof("Processing resources/subscribe request with ID: %d", req.ID) logx.Infof("Processing resources/subscribe request with ID: %v", req.ID)
s.processResourceSubscribe(r.Context(), client, req) s.processResourceSubscribe(r.Context(), client, req)
logx.Infof("Sent resources/subscribe response for ID: %d", req.ID) logx.Infof("Sent resources/subscribe response for ID: %v", req.ID)
case methodPing: case methodPing:
logx.Infof("Processing ping request with ID: %d", req.ID) logx.Infof("Processing ping request with ID: %v", req.ID)
s.processPing(r.Context(), client, req) s.processPing(r.Context(), client, req)
case methodNotificationsCancelled: case methodNotificationsCancelled:
logx.Infof("Received notifications/cancelled notification: %d", req.ID) logx.Infof("Received notifications/cancelled notification: %v", req.ID)
s.processNotificationCancelled(r.Context(), client, req) s.processNotificationCancelled(r.Context(), client, req)
default: default:
logx.Infof("Unknown method: %s from client: %d", req.Method, req.ID) logx.Infof("Unknown method: %s from client: %v", req.Method, req.ID)
s.sendErrorResponse(r.Context(), client, req.ID, "Method not found", errCodeMethodNotFound) s.sendErrorResponse(r.Context(), client, req.ID, "Method not found", errCodeMethodNotFound)
} }
} }
@@ -880,10 +883,10 @@ func (s *sseMcpServer) processPing(ctx context.Context, client *mcpClient, req R
// sendErrorResponse sends an error response via the SSE channel // sendErrorResponse sends an error response via the SSE channel
func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient, func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
id int64, message string, code int) { id any, message string, code int) {
errorResponse := struct { errorResponse := struct {
JsonRpc string `json:"jsonrpc"` JsonRpc string `json:"jsonrpc"`
ID int64 `json:"id"` ID any `json:"id"`
Error errorMessage `json:"error"` Error errorMessage `json:"error"`
}{ }{
JsonRpc: jsonRpcVersion, JsonRpc: jsonRpcVersion,
@@ -898,7 +901,7 @@ func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
jsonData, _ := json.Marshal(errorResponse) jsonData, _ := json.Marshal(errorResponse)
// Use CRLF line endings as requested // Use CRLF line endings as requested
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData)) sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
logx.Infof("Sending error for ID %d: %s", id, sseMessage) logx.Infof("Sending error for ID %v: %s", id, sseMessage)
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages // cannot receive from ctx.Done() because we're sending to the channel for SSE messages
select { select {
@@ -910,7 +913,7 @@ func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
} }
// sendResponse sends a success response via the SSE channel // sendResponse sends a success response via the SSE channel
func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id int64, result any) { func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id any, result any) {
response := Response{ response := Response{
JsonRpc: jsonRpcVersion, JsonRpc: jsonRpcVersion,
ID: id, ID: id,
@@ -925,13 +928,13 @@ func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id i
// Use CRLF line endings as requested // Use CRLF line endings as requested
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData)) sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
logx.Infof("Sending response for ID %d: %s", id, sseMessage) logx.Infof("Sending response for ID %v: %s", id, sseMessage)
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages // cannot receive from ctx.Done() because we're sending to the channel for SSE messages
select { select {
case client.channel <- sseMessage: case client.channel <- sseMessage:
default: default:
// Channel buffer is full, log warning and continue // Channel buffer is full, log warning and continue
logx.Infof("Client %s channel is full while sending response with ID %d", client.id, id) logx.Infof("Client %s channel is full while sending response with ID %v", client.id, id)
} }
} }

View File

@@ -3,6 +3,7 @@ package mcp
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"sync" "sync"
"github.com/zeromicro/go-zero/rest" "github.com/zeromicro/go-zero/rest"
@@ -15,11 +16,31 @@ type Cursor string
type Request struct { type Request struct {
SessionId string `form:"session_id"` // Session identifier for client tracking SessionId string `form:"session_id"` // Session identifier for client tracking
JsonRpc string `json:"jsonrpc"` // Must be "2.0" per JSON-RPC spec JsonRpc string `json:"jsonrpc"` // Must be "2.0" per JSON-RPC spec
ID int64 `json:"id"` // Request identifier for matching responses ID any `json:"id"` // Request identifier for matching responses
Method string `json:"method"` // Method name to invoke Method string `json:"method"` // Method name to invoke
Params json.RawMessage `json:"params"` // Parameters for the method Params json.RawMessage `json:"params"` // Parameters for the method
} }
func (r Request) isNotification() (bool, error) {
var isNotification bool
switch val := r.ID.(type) {
case int:
isNotification = val == 0
case int64:
isNotification = val == 0
case float64:
isNotification = val == 0.0
case string:
isNotification = len(val) == 0
case nil:
isNotification = true
default:
return false, fmt.Errorf("invalid type %T", val)
}
return isNotification, nil
}
type PaginatedParams struct { type PaginatedParams struct {
Cursor string `json:"cursor"` Cursor string `json:"cursor"`
Meta struct { Meta struct {
@@ -244,7 +265,7 @@ type errorObj struct {
// Response represents a JSON-RPC response // Response represents a JSON-RPC response
type Response struct { type Response struct {
JsonRpc string `json:"jsonrpc"` // Always "2.0" JsonRpc string `json:"jsonrpc"` // Always "2.0"
ID int64 `json:"id"` // Same as request ID ID any `json:"id"` // Same as request ID
Result any `json:"result"` // Result object (null if error) Result any `json:"result"` // Result object (null if error)
Error *errorObj `json:"error,omitempty"` // Error object (null if success) Error *errorObj `json:"error,omitempty"` // Error object (null if success)
} }

View File

@@ -3,6 +3,7 @@ package mcp
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -55,7 +56,7 @@ func TestRequestUnmarshaling(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "2.0", req.JsonRpc) assert.Equal(t, "2.0", req.JsonRpc)
assert.Equal(t, int64(789), req.ID) assert.Equal(t, float64(789), req.ID)
assert.Equal(t, "test_method", req.Method) assert.Equal(t, "test_method", req.Method)
// Check params unmarshaled correctly // Check params unmarshaled correctly
@@ -204,3 +205,67 @@ func TestCallToolResult(t *testing.T) {
assert.Contains(t, string(data), `"content":[{"text":"Sample result"}]`) assert.Contains(t, string(data), `"content":[{"text":"Sample result"}]`)
assert.NotContains(t, string(data), `"isError":`) assert.NotContains(t, string(data), `"isError":`)
} }
func TestRequest_isNotification(t *testing.T) {
tests := []struct {
name string
id any
want bool
wantErr error
}{
// integer test cases
{name: "int zero", id: 0, want: true, wantErr: nil},
{name: "int non-zero", id: 1, want: false, wantErr: nil},
{name: "int64 zero", id: int64(0), want: true, wantErr: nil},
{name: "int64 max", id: int64(9223372036854775807), want: false, wantErr: nil},
// floating point number test cases
{name: "float64 zero", id: float64(0.0), want: true, wantErr: nil},
{name: "float64 positive", id: float64(0.000001), want: false, wantErr: nil},
{name: "float64 negative", id: float64(-0.000001), want: false, wantErr: nil},
{name: "float64 epsilon", id: float64(1e-300), want: false, wantErr: nil},
// string test cases
{name: "empty string", id: "", want: true, wantErr: nil},
{name: "non-empty string", id: "abc", want: false, wantErr: nil},
{name: "space string", id: " ", want: false, wantErr: nil},
{name: "unicode string", id: "こんにちは", want: false, wantErr: nil},
// special cases
{name: "nil", id: nil, want: true, wantErr: nil},
// logical type test cases
{name: "bool true", id: true, want: false, wantErr: errors.New("invalid type bool")},
{name: "bool false", id: false, want: false, wantErr: errors.New("invalid type bool")},
{name: "struct type", id: struct{}{}, want: false, wantErr: errors.New("invalid type struct {}")},
{name: "slice type", id: []int{1, 2, 3}, want: false, wantErr: errors.New("invalid type []int")},
{name: "map type", id: map[string]int{"a": 1}, want: false, wantErr: errors.New("invalid type map[string]int")},
{name: "pointer type", id: new(int), want: false, wantErr: errors.New("invalid type *int")},
{name: "func type", id: func() {}, want: false, wantErr: errors.New("invalid type func()")},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := Request{
SessionId: "test-session",
JsonRpc: "2.0",
ID: tt.id,
Method: "testMethod",
Params: json.RawMessage(`{}`),
}
got, err := req.isNotification()
if (err != nil) != (tt.wantErr != nil) {
t.Fatalf("error presence mismatch: got error = %v, wantErr %v", err, tt.wantErr)
}
if err != nil && tt.wantErr != nil && err.Error() != tt.wantErr.Error() {
t.Fatalf("error message mismatch:\ngot %q\nwant %q", err.Error(), tt.wantErr.Error())
}
if got != tt.want {
t.Errorf("isNotification() = %v, want %v for ID %v (%T)", got, tt.want, tt.id, tt.id)
}
})
}
}