Fix the problem that mcp request id is not of int type (#4914)

This commit is contained in:
MarkJoyMa
2025-06-07 10:37:18 +08:00
committed by GitHub
parent 4b2095ed03
commit d4cccca387
3 changed files with 120 additions and 31 deletions

View File

@@ -173,17 +173,20 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
return
}
w.WriteHeader(http.StatusAccepted)
// For notification methods (no ID), we don't send a response
isNotification := req.ID == 0
isNotification, err := req.isNotification()
if err != nil {
http.Error(w, "Invalid request.ID", http.StatusBadRequest)
}
w.WriteHeader(http.StatusAccepted)
// Special handling for initialization sequence
// Always allow initialize and notifications/initialized regardless of client state
if req.Method == methodInitialize {
logx.Infof("Processing initialize request with ID: %d", req.ID)
logx.Infof("Processing initialize request with ID: %v", req.ID)
s.processInitialize(r.Context(), client, req)
logx.Infof("Sent initialize response for ID: %d, waiting for notifications/initialized", req.ID)
logx.Infof("Sent initialize response for ID: %v, waiting for notifications/initialized", req.ID)
return
} else if req.Method == methodNotificationsInitialized {
// Handle initialized notification
@@ -206,41 +209,41 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
// Process normal requests only after initialization
switch req.Method {
case methodToolsCall:
logx.Infof("Received tools call request with ID: %d", req.ID)
logx.Infof("Received tools call request with ID: %v", req.ID)
s.processToolCall(r.Context(), client, req)
logx.Infof("Sent tools call response for ID: %d", req.ID)
logx.Infof("Sent tools call response for ID: %v", req.ID)
case methodToolsList:
logx.Infof("Processing tools/list request with ID: %d", req.ID)
logx.Infof("Processing tools/list request with ID: %v", req.ID)
s.processListTools(r.Context(), client, req)
logx.Infof("Sent tools/list response for ID: %d", req.ID)
logx.Infof("Sent tools/list response for ID: %v", req.ID)
case methodPromptsList:
logx.Infof("Processing prompts/list request with ID: %d", req.ID)
logx.Infof("Processing prompts/list request with ID: %v", req.ID)
s.processListPrompts(r.Context(), client, req)
logx.Infof("Sent prompts/list response for ID: %d", req.ID)
logx.Infof("Sent prompts/list response for ID: %v", req.ID)
case methodPromptsGet:
logx.Infof("Processing prompts/get request with ID: %d", req.ID)
logx.Infof("Processing prompts/get request with ID: %v", req.ID)
s.processGetPrompt(r.Context(), client, req)
logx.Infof("Sent prompts/get response for ID: %d", req.ID)
logx.Infof("Sent prompts/get response for ID: %v", req.ID)
case methodResourcesList:
logx.Infof("Processing resources/list request with ID: %d", req.ID)
logx.Infof("Processing resources/list request with ID: %v", req.ID)
s.processListResources(r.Context(), client, req)
logx.Infof("Sent resources/list response for ID: %d", req.ID)
logx.Infof("Sent resources/list response for ID: %v", req.ID)
case methodResourcesRead:
logx.Infof("Processing resources/read request with ID: %d", req.ID)
logx.Infof("Processing resources/read request with ID: %v", req.ID)
s.processResourcesRead(r.Context(), client, req)
logx.Infof("Sent resources/read response for ID: %d", req.ID)
logx.Infof("Sent resources/read response for ID: %v", req.ID)
case methodResourcesSubscribe:
logx.Infof("Processing resources/subscribe request with ID: %d", req.ID)
logx.Infof("Processing resources/subscribe request with ID: %v", req.ID)
s.processResourceSubscribe(r.Context(), client, req)
logx.Infof("Sent resources/subscribe response for ID: %d", req.ID)
logx.Infof("Sent resources/subscribe response for ID: %v", req.ID)
case methodPing:
logx.Infof("Processing ping request with ID: %d", req.ID)
logx.Infof("Processing ping request with ID: %v", req.ID)
s.processPing(r.Context(), client, req)
case methodNotificationsCancelled:
logx.Infof("Received notifications/cancelled notification: %d", req.ID)
logx.Infof("Received notifications/cancelled notification: %v", req.ID)
s.processNotificationCancelled(r.Context(), client, req)
default:
logx.Infof("Unknown method: %s from client: %d", req.Method, req.ID)
logx.Infof("Unknown method: %s from client: %v", req.Method, req.ID)
s.sendErrorResponse(r.Context(), client, req.ID, "Method not found", errCodeMethodNotFound)
}
}
@@ -880,10 +883,10 @@ func (s *sseMcpServer) processPing(ctx context.Context, client *mcpClient, req R
// sendErrorResponse sends an error response via the SSE channel
func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
id int64, message string, code int) {
id any, message string, code int) {
errorResponse := struct {
JsonRpc string `json:"jsonrpc"`
ID int64 `json:"id"`
ID any `json:"id"`
Error errorMessage `json:"error"`
}{
JsonRpc: jsonRpcVersion,
@@ -898,7 +901,7 @@ func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
jsonData, _ := json.Marshal(errorResponse)
// Use CRLF line endings as requested
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
logx.Infof("Sending error for ID %d: %s", id, sseMessage)
logx.Infof("Sending error for ID %v: %s", id, sseMessage)
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
select {
@@ -910,7 +913,7 @@ func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
}
// sendResponse sends a success response via the SSE channel
func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id int64, result any) {
func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id any, result any) {
response := Response{
JsonRpc: jsonRpcVersion,
ID: id,
@@ -925,13 +928,13 @@ func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id i
// Use CRLF line endings as requested
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
logx.Infof("Sending response for ID %d: %s", id, sseMessage)
logx.Infof("Sending response for ID %v: %s", id, sseMessage)
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
select {
case client.channel <- sseMessage:
default:
// Channel buffer is full, log warning and continue
logx.Infof("Client %s channel is full while sending response with ID %d", client.id, id)
logx.Infof("Client %s channel is full while sending response with ID %v", client.id, id)
}
}

View File

@@ -3,6 +3,7 @@ package mcp
import (
"context"
"encoding/json"
"fmt"
"sync"
"github.com/zeromicro/go-zero/rest"
@@ -15,11 +16,31 @@ type Cursor string
type Request struct {
SessionId string `form:"session_id"` // Session identifier for client tracking
JsonRpc string `json:"jsonrpc"` // Must be "2.0" per JSON-RPC spec
ID int64 `json:"id"` // Request identifier for matching responses
ID any `json:"id"` // Request identifier for matching responses
Method string `json:"method"` // Method name to invoke
Params json.RawMessage `json:"params"` // Parameters for the method
}
func (r Request) isNotification() (bool, error) {
var isNotification bool
switch val := r.ID.(type) {
case int:
isNotification = val == 0
case int64:
isNotification = val == 0
case float64:
isNotification = val == 0.0
case string:
isNotification = len(val) == 0
case nil:
isNotification = true
default:
return false, fmt.Errorf("invalid type %T", val)
}
return isNotification, nil
}
type PaginatedParams struct {
Cursor string `json:"cursor"`
Meta struct {
@@ -244,7 +265,7 @@ type errorObj struct {
// Response represents a JSON-RPC response
type Response struct {
JsonRpc string `json:"jsonrpc"` // Always "2.0"
ID int64 `json:"id"` // Same as request ID
ID any `json:"id"` // Same as request ID
Result any `json:"result"` // Result object (null if error)
Error *errorObj `json:"error,omitempty"` // Error object (null if success)
}

View File

@@ -3,6 +3,7 @@ package mcp
import (
"context"
"encoding/json"
"errors"
"testing"
"github.com/stretchr/testify/assert"
@@ -55,7 +56,7 @@ func TestRequestUnmarshaling(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "2.0", req.JsonRpc)
assert.Equal(t, int64(789), req.ID)
assert.Equal(t, float64(789), req.ID)
assert.Equal(t, "test_method", req.Method)
// Check params unmarshaled correctly
@@ -204,3 +205,67 @@ func TestCallToolResult(t *testing.T) {
assert.Contains(t, string(data), `"content":[{"text":"Sample result"}]`)
assert.NotContains(t, string(data), `"isError":`)
}
func TestRequest_isNotification(t *testing.T) {
tests := []struct {
name string
id any
want bool
wantErr error
}{
// integer test cases
{name: "int zero", id: 0, want: true, wantErr: nil},
{name: "int non-zero", id: 1, want: false, wantErr: nil},
{name: "int64 zero", id: int64(0), want: true, wantErr: nil},
{name: "int64 max", id: int64(9223372036854775807), want: false, wantErr: nil},
// floating point number test cases
{name: "float64 zero", id: float64(0.0), want: true, wantErr: nil},
{name: "float64 positive", id: float64(0.000001), want: false, wantErr: nil},
{name: "float64 negative", id: float64(-0.000001), want: false, wantErr: nil},
{name: "float64 epsilon", id: float64(1e-300), want: false, wantErr: nil},
// string test cases
{name: "empty string", id: "", want: true, wantErr: nil},
{name: "non-empty string", id: "abc", want: false, wantErr: nil},
{name: "space string", id: " ", want: false, wantErr: nil},
{name: "unicode string", id: "こんにちは", want: false, wantErr: nil},
// special cases
{name: "nil", id: nil, want: true, wantErr: nil},
// logical type test cases
{name: "bool true", id: true, want: false, wantErr: errors.New("invalid type bool")},
{name: "bool false", id: false, want: false, wantErr: errors.New("invalid type bool")},
{name: "struct type", id: struct{}{}, want: false, wantErr: errors.New("invalid type struct {}")},
{name: "slice type", id: []int{1, 2, 3}, want: false, wantErr: errors.New("invalid type []int")},
{name: "map type", id: map[string]int{"a": 1}, want: false, wantErr: errors.New("invalid type map[string]int")},
{name: "pointer type", id: new(int), want: false, wantErr: errors.New("invalid type *int")},
{name: "func type", id: func() {}, want: false, wantErr: errors.New("invalid type func()")},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := Request{
SessionId: "test-session",
JsonRpc: "2.0",
ID: tt.id,
Method: "testMethod",
Params: json.RawMessage(`{}`),
}
got, err := req.isNotification()
if (err != nil) != (tt.wantErr != nil) {
t.Fatalf("error presence mismatch: got error = %v, wantErr %v", err, tt.wantErr)
}
if err != nil && tt.wantErr != nil && err.Error() != tt.wantErr.Error() {
t.Fatalf("error message mismatch:\ngot %q\nwant %q", err.Error(), tt.wantErr.Error())
}
if got != tt.want {
t.Errorf("isNotification() = %v, want %v for ID %v (%T)", got, tt.want, tt.id, tt.id)
}
})
}
}