From d4cccca3872d72fd6240a7979f83799e0f2d78b4 Mon Sep 17 00:00:00 2001 From: MarkJoyMa <64180138+MarkJoyMa@users.noreply.github.com> Date: Sat, 7 Jun 2025 10:37:18 +0800 Subject: [PATCH] Fix the problem that mcp request id is not of int type (#4914) --- mcp/server.go | 59 +++++++++++++++++++++-------------------- mcp/types.go | 25 ++++++++++++++++-- mcp/types_test.go | 67 ++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 120 insertions(+), 31 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index b3143ae9f..1c95095ce 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -173,17 +173,20 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) { return } - w.WriteHeader(http.StatusAccepted) - // For notification methods (no ID), we don't send a response - isNotification := req.ID == 0 + isNotification, err := req.isNotification() + if err != nil { + http.Error(w, "Invalid request.ID", http.StatusBadRequest) + } + + w.WriteHeader(http.StatusAccepted) // Special handling for initialization sequence // Always allow initialize and notifications/initialized regardless of client state if req.Method == methodInitialize { - logx.Infof("Processing initialize request with ID: %d", req.ID) + logx.Infof("Processing initialize request with ID: %v", req.ID) s.processInitialize(r.Context(), client, req) - logx.Infof("Sent initialize response for ID: %d, waiting for notifications/initialized", req.ID) + logx.Infof("Sent initialize response for ID: %v, waiting for notifications/initialized", req.ID) return } else if req.Method == methodNotificationsInitialized { // Handle initialized notification @@ -206,41 +209,41 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) { // Process normal requests only after initialization switch req.Method { case methodToolsCall: - logx.Infof("Received tools call request with ID: %d", req.ID) + logx.Infof("Received tools call request with ID: %v", req.ID) s.processToolCall(r.Context(), client, req) - logx.Infof("Sent tools call response for ID: %d", req.ID) + logx.Infof("Sent tools call response for ID: %v", req.ID) case methodToolsList: - logx.Infof("Processing tools/list request with ID: %d", req.ID) + logx.Infof("Processing tools/list request with ID: %v", req.ID) s.processListTools(r.Context(), client, req) - logx.Infof("Sent tools/list response for ID: %d", req.ID) + logx.Infof("Sent tools/list response for ID: %v", req.ID) case methodPromptsList: - logx.Infof("Processing prompts/list request with ID: %d", req.ID) + logx.Infof("Processing prompts/list request with ID: %v", req.ID) s.processListPrompts(r.Context(), client, req) - logx.Infof("Sent prompts/list response for ID: %d", req.ID) + logx.Infof("Sent prompts/list response for ID: %v", req.ID) case methodPromptsGet: - logx.Infof("Processing prompts/get request with ID: %d", req.ID) + logx.Infof("Processing prompts/get request with ID: %v", req.ID) s.processGetPrompt(r.Context(), client, req) - logx.Infof("Sent prompts/get response for ID: %d", req.ID) + logx.Infof("Sent prompts/get response for ID: %v", req.ID) case methodResourcesList: - logx.Infof("Processing resources/list request with ID: %d", req.ID) + logx.Infof("Processing resources/list request with ID: %v", req.ID) s.processListResources(r.Context(), client, req) - logx.Infof("Sent resources/list response for ID: %d", req.ID) + logx.Infof("Sent resources/list response for ID: %v", req.ID) case methodResourcesRead: - logx.Infof("Processing resources/read request with ID: %d", req.ID) + logx.Infof("Processing resources/read request with ID: %v", req.ID) s.processResourcesRead(r.Context(), client, req) - logx.Infof("Sent resources/read response for ID: %d", req.ID) + logx.Infof("Sent resources/read response for ID: %v", req.ID) case methodResourcesSubscribe: - logx.Infof("Processing resources/subscribe request with ID: %d", req.ID) + logx.Infof("Processing resources/subscribe request with ID: %v", req.ID) s.processResourceSubscribe(r.Context(), client, req) - logx.Infof("Sent resources/subscribe response for ID: %d", req.ID) + logx.Infof("Sent resources/subscribe response for ID: %v", req.ID) case methodPing: - logx.Infof("Processing ping request with ID: %d", req.ID) + logx.Infof("Processing ping request with ID: %v", req.ID) s.processPing(r.Context(), client, req) case methodNotificationsCancelled: - logx.Infof("Received notifications/cancelled notification: %d", req.ID) + logx.Infof("Received notifications/cancelled notification: %v", req.ID) s.processNotificationCancelled(r.Context(), client, req) default: - logx.Infof("Unknown method: %s from client: %d", req.Method, req.ID) + logx.Infof("Unknown method: %s from client: %v", req.Method, req.ID) s.sendErrorResponse(r.Context(), client, req.ID, "Method not found", errCodeMethodNotFound) } } @@ -880,10 +883,10 @@ func (s *sseMcpServer) processPing(ctx context.Context, client *mcpClient, req R // sendErrorResponse sends an error response via the SSE channel func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient, - id int64, message string, code int) { + id any, message string, code int) { errorResponse := struct { JsonRpc string `json:"jsonrpc"` - ID int64 `json:"id"` + ID any `json:"id"` Error errorMessage `json:"error"` }{ JsonRpc: jsonRpcVersion, @@ -898,7 +901,7 @@ func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient, jsonData, _ := json.Marshal(errorResponse) // Use CRLF line endings as requested sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData)) - logx.Infof("Sending error for ID %d: %s", id, sseMessage) + logx.Infof("Sending error for ID %v: %s", id, sseMessage) // cannot receive from ctx.Done() because we're sending to the channel for SSE messages select { @@ -910,7 +913,7 @@ func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient, } // sendResponse sends a success response via the SSE channel -func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id int64, result any) { +func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id any, result any) { response := Response{ JsonRpc: jsonRpcVersion, ID: id, @@ -925,13 +928,13 @@ func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id i // Use CRLF line endings as requested sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData)) - logx.Infof("Sending response for ID %d: %s", id, sseMessage) + logx.Infof("Sending response for ID %v: %s", id, sseMessage) // cannot receive from ctx.Done() because we're sending to the channel for SSE messages select { case client.channel <- sseMessage: default: // Channel buffer is full, log warning and continue - logx.Infof("Client %s channel is full while sending response with ID %d", client.id, id) + logx.Infof("Client %s channel is full while sending response with ID %v", client.id, id) } } diff --git a/mcp/types.go b/mcp/types.go index e711be582..5a056313f 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -3,6 +3,7 @@ package mcp import ( "context" "encoding/json" + "fmt" "sync" "github.com/zeromicro/go-zero/rest" @@ -15,11 +16,31 @@ type Cursor string type Request struct { SessionId string `form:"session_id"` // Session identifier for client tracking JsonRpc string `json:"jsonrpc"` // Must be "2.0" per JSON-RPC spec - ID int64 `json:"id"` // Request identifier for matching responses + ID any `json:"id"` // Request identifier for matching responses Method string `json:"method"` // Method name to invoke Params json.RawMessage `json:"params"` // Parameters for the method } +func (r Request) isNotification() (bool, error) { + var isNotification bool + switch val := r.ID.(type) { + case int: + isNotification = val == 0 + case int64: + isNotification = val == 0 + case float64: + isNotification = val == 0.0 + case string: + isNotification = len(val) == 0 + case nil: + isNotification = true + default: + return false, fmt.Errorf("invalid type %T", val) + } + + return isNotification, nil +} + type PaginatedParams struct { Cursor string `json:"cursor"` Meta struct { @@ -244,7 +265,7 @@ type errorObj struct { // Response represents a JSON-RPC response type Response struct { JsonRpc string `json:"jsonrpc"` // Always "2.0" - ID int64 `json:"id"` // Same as request ID + ID any `json:"id"` // Same as request ID Result any `json:"result"` // Result object (null if error) Error *errorObj `json:"error,omitempty"` // Error object (null if success) } diff --git a/mcp/types_test.go b/mcp/types_test.go index e0b5c323d..ba27100c9 100644 --- a/mcp/types_test.go +++ b/mcp/types_test.go @@ -3,6 +3,7 @@ package mcp import ( "context" "encoding/json" + "errors" "testing" "github.com/stretchr/testify/assert" @@ -55,7 +56,7 @@ func TestRequestUnmarshaling(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "2.0", req.JsonRpc) - assert.Equal(t, int64(789), req.ID) + assert.Equal(t, float64(789), req.ID) assert.Equal(t, "test_method", req.Method) // Check params unmarshaled correctly @@ -204,3 +205,67 @@ func TestCallToolResult(t *testing.T) { assert.Contains(t, string(data), `"content":[{"text":"Sample result"}]`) assert.NotContains(t, string(data), `"isError":`) } + +func TestRequest_isNotification(t *testing.T) { + tests := []struct { + name string + id any + want bool + wantErr error + }{ + // integer test cases + {name: "int zero", id: 0, want: true, wantErr: nil}, + {name: "int non-zero", id: 1, want: false, wantErr: nil}, + {name: "int64 zero", id: int64(0), want: true, wantErr: nil}, + {name: "int64 max", id: int64(9223372036854775807), want: false, wantErr: nil}, + + // floating point number test cases + {name: "float64 zero", id: float64(0.0), want: true, wantErr: nil}, + {name: "float64 positive", id: float64(0.000001), want: false, wantErr: nil}, + {name: "float64 negative", id: float64(-0.000001), want: false, wantErr: nil}, + {name: "float64 epsilon", id: float64(1e-300), want: false, wantErr: nil}, + + // string test cases + {name: "empty string", id: "", want: true, wantErr: nil}, + {name: "non-empty string", id: "abc", want: false, wantErr: nil}, + {name: "space string", id: " ", want: false, wantErr: nil}, + {name: "unicode string", id: "こんにちは", want: false, wantErr: nil}, + + // special cases + {name: "nil", id: nil, want: true, wantErr: nil}, + + // logical type test cases + {name: "bool true", id: true, want: false, wantErr: errors.New("invalid type bool")}, + {name: "bool false", id: false, want: false, wantErr: errors.New("invalid type bool")}, + {name: "struct type", id: struct{}{}, want: false, wantErr: errors.New("invalid type struct {}")}, + {name: "slice type", id: []int{1, 2, 3}, want: false, wantErr: errors.New("invalid type []int")}, + {name: "map type", id: map[string]int{"a": 1}, want: false, wantErr: errors.New("invalid type map[string]int")}, + {name: "pointer type", id: new(int), want: false, wantErr: errors.New("invalid type *int")}, + {name: "func type", id: func() {}, want: false, wantErr: errors.New("invalid type func()")}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := Request{ + SessionId: "test-session", + JsonRpc: "2.0", + ID: tt.id, + Method: "testMethod", + Params: json.RawMessage(`{}`), + } + + got, err := req.isNotification() + + if (err != nil) != (tt.wantErr != nil) { + t.Fatalf("error presence mismatch: got error = %v, wantErr %v", err, tt.wantErr) + } + if err != nil && tt.wantErr != nil && err.Error() != tt.wantErr.Error() { + t.Fatalf("error message mismatch:\ngot %q\nwant %q", err.Error(), tt.wantErr.Error()) + } + + if got != tt.want { + t.Errorf("isNotification() = %v, want %v for ID %v (%T)", got, tt.want, tt.id, tt.id) + } + }) + } +}