mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-07 15:10:01 +08:00
Fix the problem that mcp request id is not of int type (#4914)
This commit is contained in:
@@ -173,17 +173,20 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
|
||||
// For notification methods (no ID), we don't send a response
|
||||
isNotification := req.ID == 0
|
||||
isNotification, err := req.isNotification()
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid request.ID", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
|
||||
// Special handling for initialization sequence
|
||||
// Always allow initialize and notifications/initialized regardless of client state
|
||||
if req.Method == methodInitialize {
|
||||
logx.Infof("Processing initialize request with ID: %d", req.ID)
|
||||
logx.Infof("Processing initialize request with ID: %v", req.ID)
|
||||
s.processInitialize(r.Context(), client, req)
|
||||
logx.Infof("Sent initialize response for ID: %d, waiting for notifications/initialized", req.ID)
|
||||
logx.Infof("Sent initialize response for ID: %v, waiting for notifications/initialized", req.ID)
|
||||
return
|
||||
} else if req.Method == methodNotificationsInitialized {
|
||||
// Handle initialized notification
|
||||
@@ -206,41 +209,41 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
// Process normal requests only after initialization
|
||||
switch req.Method {
|
||||
case methodToolsCall:
|
||||
logx.Infof("Received tools call request with ID: %d", req.ID)
|
||||
logx.Infof("Received tools call request with ID: %v", req.ID)
|
||||
s.processToolCall(r.Context(), client, req)
|
||||
logx.Infof("Sent tools call response for ID: %d", req.ID)
|
||||
logx.Infof("Sent tools call response for ID: %v", req.ID)
|
||||
case methodToolsList:
|
||||
logx.Infof("Processing tools/list request with ID: %d", req.ID)
|
||||
logx.Infof("Processing tools/list request with ID: %v", req.ID)
|
||||
s.processListTools(r.Context(), client, req)
|
||||
logx.Infof("Sent tools/list response for ID: %d", req.ID)
|
||||
logx.Infof("Sent tools/list response for ID: %v", req.ID)
|
||||
case methodPromptsList:
|
||||
logx.Infof("Processing prompts/list request with ID: %d", req.ID)
|
||||
logx.Infof("Processing prompts/list request with ID: %v", req.ID)
|
||||
s.processListPrompts(r.Context(), client, req)
|
||||
logx.Infof("Sent prompts/list response for ID: %d", req.ID)
|
||||
logx.Infof("Sent prompts/list response for ID: %v", req.ID)
|
||||
case methodPromptsGet:
|
||||
logx.Infof("Processing prompts/get request with ID: %d", req.ID)
|
||||
logx.Infof("Processing prompts/get request with ID: %v", req.ID)
|
||||
s.processGetPrompt(r.Context(), client, req)
|
||||
logx.Infof("Sent prompts/get response for ID: %d", req.ID)
|
||||
logx.Infof("Sent prompts/get response for ID: %v", req.ID)
|
||||
case methodResourcesList:
|
||||
logx.Infof("Processing resources/list request with ID: %d", req.ID)
|
||||
logx.Infof("Processing resources/list request with ID: %v", req.ID)
|
||||
s.processListResources(r.Context(), client, req)
|
||||
logx.Infof("Sent resources/list response for ID: %d", req.ID)
|
||||
logx.Infof("Sent resources/list response for ID: %v", req.ID)
|
||||
case methodResourcesRead:
|
||||
logx.Infof("Processing resources/read request with ID: %d", req.ID)
|
||||
logx.Infof("Processing resources/read request with ID: %v", req.ID)
|
||||
s.processResourcesRead(r.Context(), client, req)
|
||||
logx.Infof("Sent resources/read response for ID: %d", req.ID)
|
||||
logx.Infof("Sent resources/read response for ID: %v", req.ID)
|
||||
case methodResourcesSubscribe:
|
||||
logx.Infof("Processing resources/subscribe request with ID: %d", req.ID)
|
||||
logx.Infof("Processing resources/subscribe request with ID: %v", req.ID)
|
||||
s.processResourceSubscribe(r.Context(), client, req)
|
||||
logx.Infof("Sent resources/subscribe response for ID: %d", req.ID)
|
||||
logx.Infof("Sent resources/subscribe response for ID: %v", req.ID)
|
||||
case methodPing:
|
||||
logx.Infof("Processing ping request with ID: %d", req.ID)
|
||||
logx.Infof("Processing ping request with ID: %v", req.ID)
|
||||
s.processPing(r.Context(), client, req)
|
||||
case methodNotificationsCancelled:
|
||||
logx.Infof("Received notifications/cancelled notification: %d", req.ID)
|
||||
logx.Infof("Received notifications/cancelled notification: %v", req.ID)
|
||||
s.processNotificationCancelled(r.Context(), client, req)
|
||||
default:
|
||||
logx.Infof("Unknown method: %s from client: %d", req.Method, req.ID)
|
||||
logx.Infof("Unknown method: %s from client: %v", req.Method, req.ID)
|
||||
s.sendErrorResponse(r.Context(), client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||
}
|
||||
}
|
||||
@@ -880,10 +883,10 @@ func (s *sseMcpServer) processPing(ctx context.Context, client *mcpClient, req R
|
||||
|
||||
// sendErrorResponse sends an error response via the SSE channel
|
||||
func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
|
||||
id int64, message string, code int) {
|
||||
id any, message string, code int) {
|
||||
errorResponse := struct {
|
||||
JsonRpc string `json:"jsonrpc"`
|
||||
ID int64 `json:"id"`
|
||||
ID any `json:"id"`
|
||||
Error errorMessage `json:"error"`
|
||||
}{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
@@ -898,7 +901,7 @@ func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
|
||||
jsonData, _ := json.Marshal(errorResponse)
|
||||
// Use CRLF line endings as requested
|
||||
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||
logx.Infof("Sending error for ID %d: %s", id, sseMessage)
|
||||
logx.Infof("Sending error for ID %v: %s", id, sseMessage)
|
||||
|
||||
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
|
||||
select {
|
||||
@@ -910,7 +913,7 @@ func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
|
||||
}
|
||||
|
||||
// sendResponse sends a success response via the SSE channel
|
||||
func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id int64, result any) {
|
||||
func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id any, result any) {
|
||||
response := Response{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: id,
|
||||
@@ -925,13 +928,13 @@ func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id i
|
||||
|
||||
// Use CRLF line endings as requested
|
||||
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||
logx.Infof("Sending response for ID %d: %s", id, sseMessage)
|
||||
logx.Infof("Sending response for ID %v: %s", id, sseMessage)
|
||||
|
||||
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
|
||||
select {
|
||||
case client.channel <- sseMessage:
|
||||
default:
|
||||
// Channel buffer is full, log warning and continue
|
||||
logx.Infof("Client %s channel is full while sending response with ID %d", client.id, id)
|
||||
logx.Infof("Client %s channel is full while sending response with ID %v", client.id, id)
|
||||
}
|
||||
}
|
||||
|
||||
25
mcp/types.go
25
mcp/types.go
@@ -3,6 +3,7 @@ package mcp
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/zeromicro/go-zero/rest"
|
||||
@@ -15,11 +16,31 @@ type Cursor string
|
||||
type Request struct {
|
||||
SessionId string `form:"session_id"` // Session identifier for client tracking
|
||||
JsonRpc string `json:"jsonrpc"` // Must be "2.0" per JSON-RPC spec
|
||||
ID int64 `json:"id"` // Request identifier for matching responses
|
||||
ID any `json:"id"` // Request identifier for matching responses
|
||||
Method string `json:"method"` // Method name to invoke
|
||||
Params json.RawMessage `json:"params"` // Parameters for the method
|
||||
}
|
||||
|
||||
func (r Request) isNotification() (bool, error) {
|
||||
var isNotification bool
|
||||
switch val := r.ID.(type) {
|
||||
case int:
|
||||
isNotification = val == 0
|
||||
case int64:
|
||||
isNotification = val == 0
|
||||
case float64:
|
||||
isNotification = val == 0.0
|
||||
case string:
|
||||
isNotification = len(val) == 0
|
||||
case nil:
|
||||
isNotification = true
|
||||
default:
|
||||
return false, fmt.Errorf("invalid type %T", val)
|
||||
}
|
||||
|
||||
return isNotification, nil
|
||||
}
|
||||
|
||||
type PaginatedParams struct {
|
||||
Cursor string `json:"cursor"`
|
||||
Meta struct {
|
||||
@@ -244,7 +265,7 @@ type errorObj struct {
|
||||
// Response represents a JSON-RPC response
|
||||
type Response struct {
|
||||
JsonRpc string `json:"jsonrpc"` // Always "2.0"
|
||||
ID int64 `json:"id"` // Same as request ID
|
||||
ID any `json:"id"` // Same as request ID
|
||||
Result any `json:"result"` // Result object (null if error)
|
||||
Error *errorObj `json:"error,omitempty"` // Error object (null if success)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package mcp
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -55,7 +56,7 @@ func TestRequestUnmarshaling(t *testing.T) {
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "2.0", req.JsonRpc)
|
||||
assert.Equal(t, int64(789), req.ID)
|
||||
assert.Equal(t, float64(789), req.ID)
|
||||
assert.Equal(t, "test_method", req.Method)
|
||||
|
||||
// Check params unmarshaled correctly
|
||||
@@ -204,3 +205,67 @@ func TestCallToolResult(t *testing.T) {
|
||||
assert.Contains(t, string(data), `"content":[{"text":"Sample result"}]`)
|
||||
assert.NotContains(t, string(data), `"isError":`)
|
||||
}
|
||||
|
||||
func TestRequest_isNotification(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
id any
|
||||
want bool
|
||||
wantErr error
|
||||
}{
|
||||
// integer test cases
|
||||
{name: "int zero", id: 0, want: true, wantErr: nil},
|
||||
{name: "int non-zero", id: 1, want: false, wantErr: nil},
|
||||
{name: "int64 zero", id: int64(0), want: true, wantErr: nil},
|
||||
{name: "int64 max", id: int64(9223372036854775807), want: false, wantErr: nil},
|
||||
|
||||
// floating point number test cases
|
||||
{name: "float64 zero", id: float64(0.0), want: true, wantErr: nil},
|
||||
{name: "float64 positive", id: float64(0.000001), want: false, wantErr: nil},
|
||||
{name: "float64 negative", id: float64(-0.000001), want: false, wantErr: nil},
|
||||
{name: "float64 epsilon", id: float64(1e-300), want: false, wantErr: nil},
|
||||
|
||||
// string test cases
|
||||
{name: "empty string", id: "", want: true, wantErr: nil},
|
||||
{name: "non-empty string", id: "abc", want: false, wantErr: nil},
|
||||
{name: "space string", id: " ", want: false, wantErr: nil},
|
||||
{name: "unicode string", id: "こんにちは", want: false, wantErr: nil},
|
||||
|
||||
// special cases
|
||||
{name: "nil", id: nil, want: true, wantErr: nil},
|
||||
|
||||
// logical type test cases
|
||||
{name: "bool true", id: true, want: false, wantErr: errors.New("invalid type bool")},
|
||||
{name: "bool false", id: false, want: false, wantErr: errors.New("invalid type bool")},
|
||||
{name: "struct type", id: struct{}{}, want: false, wantErr: errors.New("invalid type struct {}")},
|
||||
{name: "slice type", id: []int{1, 2, 3}, want: false, wantErr: errors.New("invalid type []int")},
|
||||
{name: "map type", id: map[string]int{"a": 1}, want: false, wantErr: errors.New("invalid type map[string]int")},
|
||||
{name: "pointer type", id: new(int), want: false, wantErr: errors.New("invalid type *int")},
|
||||
{name: "func type", id: func() {}, want: false, wantErr: errors.New("invalid type func()")},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := Request{
|
||||
SessionId: "test-session",
|
||||
JsonRpc: "2.0",
|
||||
ID: tt.id,
|
||||
Method: "testMethod",
|
||||
Params: json.RawMessage(`{}`),
|
||||
}
|
||||
|
||||
got, err := req.isNotification()
|
||||
|
||||
if (err != nil) != (tt.wantErr != nil) {
|
||||
t.Fatalf("error presence mismatch: got error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if err != nil && tt.wantErr != nil && err.Error() != tt.wantErr.Error() {
|
||||
t.Fatalf("error message mismatch:\ngot %q\nwant %q", err.Error(), tt.wantErr.Error())
|
||||
}
|
||||
|
||||
if got != tt.want {
|
||||
t.Errorf("isNotification() = %v, want %v for ID %v (%T)", got, tt.want, tt.id, tt.id)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user