mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-07 15:10:01 +08:00
Fix the problem that mcp request id is not of int type (#4914)
This commit is contained in:
@@ -173,17 +173,20 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.WriteHeader(http.StatusAccepted)
|
|
||||||
|
|
||||||
// For notification methods (no ID), we don't send a response
|
// For notification methods (no ID), we don't send a response
|
||||||
isNotification := req.ID == 0
|
isNotification, err := req.isNotification()
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Invalid request.ID", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusAccepted)
|
||||||
|
|
||||||
// Special handling for initialization sequence
|
// Special handling for initialization sequence
|
||||||
// Always allow initialize and notifications/initialized regardless of client state
|
// Always allow initialize and notifications/initialized regardless of client state
|
||||||
if req.Method == methodInitialize {
|
if req.Method == methodInitialize {
|
||||||
logx.Infof("Processing initialize request with ID: %d", req.ID)
|
logx.Infof("Processing initialize request with ID: %v", req.ID)
|
||||||
s.processInitialize(r.Context(), client, req)
|
s.processInitialize(r.Context(), client, req)
|
||||||
logx.Infof("Sent initialize response for ID: %d, waiting for notifications/initialized", req.ID)
|
logx.Infof("Sent initialize response for ID: %v, waiting for notifications/initialized", req.ID)
|
||||||
return
|
return
|
||||||
} else if req.Method == methodNotificationsInitialized {
|
} else if req.Method == methodNotificationsInitialized {
|
||||||
// Handle initialized notification
|
// Handle initialized notification
|
||||||
@@ -206,41 +209,41 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Process normal requests only after initialization
|
// Process normal requests only after initialization
|
||||||
switch req.Method {
|
switch req.Method {
|
||||||
case methodToolsCall:
|
case methodToolsCall:
|
||||||
logx.Infof("Received tools call request with ID: %d", req.ID)
|
logx.Infof("Received tools call request with ID: %v", req.ID)
|
||||||
s.processToolCall(r.Context(), client, req)
|
s.processToolCall(r.Context(), client, req)
|
||||||
logx.Infof("Sent tools call response for ID: %d", req.ID)
|
logx.Infof("Sent tools call response for ID: %v", req.ID)
|
||||||
case methodToolsList:
|
case methodToolsList:
|
||||||
logx.Infof("Processing tools/list request with ID: %d", req.ID)
|
logx.Infof("Processing tools/list request with ID: %v", req.ID)
|
||||||
s.processListTools(r.Context(), client, req)
|
s.processListTools(r.Context(), client, req)
|
||||||
logx.Infof("Sent tools/list response for ID: %d", req.ID)
|
logx.Infof("Sent tools/list response for ID: %v", req.ID)
|
||||||
case methodPromptsList:
|
case methodPromptsList:
|
||||||
logx.Infof("Processing prompts/list request with ID: %d", req.ID)
|
logx.Infof("Processing prompts/list request with ID: %v", req.ID)
|
||||||
s.processListPrompts(r.Context(), client, req)
|
s.processListPrompts(r.Context(), client, req)
|
||||||
logx.Infof("Sent prompts/list response for ID: %d", req.ID)
|
logx.Infof("Sent prompts/list response for ID: %v", req.ID)
|
||||||
case methodPromptsGet:
|
case methodPromptsGet:
|
||||||
logx.Infof("Processing prompts/get request with ID: %d", req.ID)
|
logx.Infof("Processing prompts/get request with ID: %v", req.ID)
|
||||||
s.processGetPrompt(r.Context(), client, req)
|
s.processGetPrompt(r.Context(), client, req)
|
||||||
logx.Infof("Sent prompts/get response for ID: %d", req.ID)
|
logx.Infof("Sent prompts/get response for ID: %v", req.ID)
|
||||||
case methodResourcesList:
|
case methodResourcesList:
|
||||||
logx.Infof("Processing resources/list request with ID: %d", req.ID)
|
logx.Infof("Processing resources/list request with ID: %v", req.ID)
|
||||||
s.processListResources(r.Context(), client, req)
|
s.processListResources(r.Context(), client, req)
|
||||||
logx.Infof("Sent resources/list response for ID: %d", req.ID)
|
logx.Infof("Sent resources/list response for ID: %v", req.ID)
|
||||||
case methodResourcesRead:
|
case methodResourcesRead:
|
||||||
logx.Infof("Processing resources/read request with ID: %d", req.ID)
|
logx.Infof("Processing resources/read request with ID: %v", req.ID)
|
||||||
s.processResourcesRead(r.Context(), client, req)
|
s.processResourcesRead(r.Context(), client, req)
|
||||||
logx.Infof("Sent resources/read response for ID: %d", req.ID)
|
logx.Infof("Sent resources/read response for ID: %v", req.ID)
|
||||||
case methodResourcesSubscribe:
|
case methodResourcesSubscribe:
|
||||||
logx.Infof("Processing resources/subscribe request with ID: %d", req.ID)
|
logx.Infof("Processing resources/subscribe request with ID: %v", req.ID)
|
||||||
s.processResourceSubscribe(r.Context(), client, req)
|
s.processResourceSubscribe(r.Context(), client, req)
|
||||||
logx.Infof("Sent resources/subscribe response for ID: %d", req.ID)
|
logx.Infof("Sent resources/subscribe response for ID: %v", req.ID)
|
||||||
case methodPing:
|
case methodPing:
|
||||||
logx.Infof("Processing ping request with ID: %d", req.ID)
|
logx.Infof("Processing ping request with ID: %v", req.ID)
|
||||||
s.processPing(r.Context(), client, req)
|
s.processPing(r.Context(), client, req)
|
||||||
case methodNotificationsCancelled:
|
case methodNotificationsCancelled:
|
||||||
logx.Infof("Received notifications/cancelled notification: %d", req.ID)
|
logx.Infof("Received notifications/cancelled notification: %v", req.ID)
|
||||||
s.processNotificationCancelled(r.Context(), client, req)
|
s.processNotificationCancelled(r.Context(), client, req)
|
||||||
default:
|
default:
|
||||||
logx.Infof("Unknown method: %s from client: %d", req.Method, req.ID)
|
logx.Infof("Unknown method: %s from client: %v", req.Method, req.ID)
|
||||||
s.sendErrorResponse(r.Context(), client, req.ID, "Method not found", errCodeMethodNotFound)
|
s.sendErrorResponse(r.Context(), client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -880,10 +883,10 @@ func (s *sseMcpServer) processPing(ctx context.Context, client *mcpClient, req R
|
|||||||
|
|
||||||
// sendErrorResponse sends an error response via the SSE channel
|
// sendErrorResponse sends an error response via the SSE channel
|
||||||
func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
|
func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
|
||||||
id int64, message string, code int) {
|
id any, message string, code int) {
|
||||||
errorResponse := struct {
|
errorResponse := struct {
|
||||||
JsonRpc string `json:"jsonrpc"`
|
JsonRpc string `json:"jsonrpc"`
|
||||||
ID int64 `json:"id"`
|
ID any `json:"id"`
|
||||||
Error errorMessage `json:"error"`
|
Error errorMessage `json:"error"`
|
||||||
}{
|
}{
|
||||||
JsonRpc: jsonRpcVersion,
|
JsonRpc: jsonRpcVersion,
|
||||||
@@ -898,7 +901,7 @@ func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
|
|||||||
jsonData, _ := json.Marshal(errorResponse)
|
jsonData, _ := json.Marshal(errorResponse)
|
||||||
// Use CRLF line endings as requested
|
// Use CRLF line endings as requested
|
||||||
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||||
logx.Infof("Sending error for ID %d: %s", id, sseMessage)
|
logx.Infof("Sending error for ID %v: %s", id, sseMessage)
|
||||||
|
|
||||||
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
|
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
|
||||||
select {
|
select {
|
||||||
@@ -910,7 +913,7 @@ func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sendResponse sends a success response via the SSE channel
|
// sendResponse sends a success response via the SSE channel
|
||||||
func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id int64, result any) {
|
func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id any, result any) {
|
||||||
response := Response{
|
response := Response{
|
||||||
JsonRpc: jsonRpcVersion,
|
JsonRpc: jsonRpcVersion,
|
||||||
ID: id,
|
ID: id,
|
||||||
@@ -925,13 +928,13 @@ func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id i
|
|||||||
|
|
||||||
// Use CRLF line endings as requested
|
// Use CRLF line endings as requested
|
||||||
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||||
logx.Infof("Sending response for ID %d: %s", id, sseMessage)
|
logx.Infof("Sending response for ID %v: %s", id, sseMessage)
|
||||||
|
|
||||||
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
|
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
|
||||||
select {
|
select {
|
||||||
case client.channel <- sseMessage:
|
case client.channel <- sseMessage:
|
||||||
default:
|
default:
|
||||||
// Channel buffer is full, log warning and continue
|
// Channel buffer is full, log warning and continue
|
||||||
logx.Infof("Client %s channel is full while sending response with ID %d", client.id, id)
|
logx.Infof("Client %s channel is full while sending response with ID %v", client.id, id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
25
mcp/types.go
25
mcp/types.go
@@ -3,6 +3,7 @@ package mcp
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/rest"
|
"github.com/zeromicro/go-zero/rest"
|
||||||
@@ -15,11 +16,31 @@ type Cursor string
|
|||||||
type Request struct {
|
type Request struct {
|
||||||
SessionId string `form:"session_id"` // Session identifier for client tracking
|
SessionId string `form:"session_id"` // Session identifier for client tracking
|
||||||
JsonRpc string `json:"jsonrpc"` // Must be "2.0" per JSON-RPC spec
|
JsonRpc string `json:"jsonrpc"` // Must be "2.0" per JSON-RPC spec
|
||||||
ID int64 `json:"id"` // Request identifier for matching responses
|
ID any `json:"id"` // Request identifier for matching responses
|
||||||
Method string `json:"method"` // Method name to invoke
|
Method string `json:"method"` // Method name to invoke
|
||||||
Params json.RawMessage `json:"params"` // Parameters for the method
|
Params json.RawMessage `json:"params"` // Parameters for the method
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r Request) isNotification() (bool, error) {
|
||||||
|
var isNotification bool
|
||||||
|
switch val := r.ID.(type) {
|
||||||
|
case int:
|
||||||
|
isNotification = val == 0
|
||||||
|
case int64:
|
||||||
|
isNotification = val == 0
|
||||||
|
case float64:
|
||||||
|
isNotification = val == 0.0
|
||||||
|
case string:
|
||||||
|
isNotification = len(val) == 0
|
||||||
|
case nil:
|
||||||
|
isNotification = true
|
||||||
|
default:
|
||||||
|
return false, fmt.Errorf("invalid type %T", val)
|
||||||
|
}
|
||||||
|
|
||||||
|
return isNotification, nil
|
||||||
|
}
|
||||||
|
|
||||||
type PaginatedParams struct {
|
type PaginatedParams struct {
|
||||||
Cursor string `json:"cursor"`
|
Cursor string `json:"cursor"`
|
||||||
Meta struct {
|
Meta struct {
|
||||||
@@ -244,7 +265,7 @@ type errorObj struct {
|
|||||||
// Response represents a JSON-RPC response
|
// Response represents a JSON-RPC response
|
||||||
type Response struct {
|
type Response struct {
|
||||||
JsonRpc string `json:"jsonrpc"` // Always "2.0"
|
JsonRpc string `json:"jsonrpc"` // Always "2.0"
|
||||||
ID int64 `json:"id"` // Same as request ID
|
ID any `json:"id"` // Same as request ID
|
||||||
Result any `json:"result"` // Result object (null if error)
|
Result any `json:"result"` // Result object (null if error)
|
||||||
Error *errorObj `json:"error,omitempty"` // Error object (null if success)
|
Error *errorObj `json:"error,omitempty"` // Error object (null if success)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package mcp
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -55,7 +56,7 @@ func TestRequestUnmarshaling(t *testing.T) {
|
|||||||
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "2.0", req.JsonRpc)
|
assert.Equal(t, "2.0", req.JsonRpc)
|
||||||
assert.Equal(t, int64(789), req.ID)
|
assert.Equal(t, float64(789), req.ID)
|
||||||
assert.Equal(t, "test_method", req.Method)
|
assert.Equal(t, "test_method", req.Method)
|
||||||
|
|
||||||
// Check params unmarshaled correctly
|
// Check params unmarshaled correctly
|
||||||
@@ -204,3 +205,67 @@ func TestCallToolResult(t *testing.T) {
|
|||||||
assert.Contains(t, string(data), `"content":[{"text":"Sample result"}]`)
|
assert.Contains(t, string(data), `"content":[{"text":"Sample result"}]`)
|
||||||
assert.NotContains(t, string(data), `"isError":`)
|
assert.NotContains(t, string(data), `"isError":`)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRequest_isNotification(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
id any
|
||||||
|
want bool
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
// integer test cases
|
||||||
|
{name: "int zero", id: 0, want: true, wantErr: nil},
|
||||||
|
{name: "int non-zero", id: 1, want: false, wantErr: nil},
|
||||||
|
{name: "int64 zero", id: int64(0), want: true, wantErr: nil},
|
||||||
|
{name: "int64 max", id: int64(9223372036854775807), want: false, wantErr: nil},
|
||||||
|
|
||||||
|
// floating point number test cases
|
||||||
|
{name: "float64 zero", id: float64(0.0), want: true, wantErr: nil},
|
||||||
|
{name: "float64 positive", id: float64(0.000001), want: false, wantErr: nil},
|
||||||
|
{name: "float64 negative", id: float64(-0.000001), want: false, wantErr: nil},
|
||||||
|
{name: "float64 epsilon", id: float64(1e-300), want: false, wantErr: nil},
|
||||||
|
|
||||||
|
// string test cases
|
||||||
|
{name: "empty string", id: "", want: true, wantErr: nil},
|
||||||
|
{name: "non-empty string", id: "abc", want: false, wantErr: nil},
|
||||||
|
{name: "space string", id: " ", want: false, wantErr: nil},
|
||||||
|
{name: "unicode string", id: "こんにちは", want: false, wantErr: nil},
|
||||||
|
|
||||||
|
// special cases
|
||||||
|
{name: "nil", id: nil, want: true, wantErr: nil},
|
||||||
|
|
||||||
|
// logical type test cases
|
||||||
|
{name: "bool true", id: true, want: false, wantErr: errors.New("invalid type bool")},
|
||||||
|
{name: "bool false", id: false, want: false, wantErr: errors.New("invalid type bool")},
|
||||||
|
{name: "struct type", id: struct{}{}, want: false, wantErr: errors.New("invalid type struct {}")},
|
||||||
|
{name: "slice type", id: []int{1, 2, 3}, want: false, wantErr: errors.New("invalid type []int")},
|
||||||
|
{name: "map type", id: map[string]int{"a": 1}, want: false, wantErr: errors.New("invalid type map[string]int")},
|
||||||
|
{name: "pointer type", id: new(int), want: false, wantErr: errors.New("invalid type *int")},
|
||||||
|
{name: "func type", id: func() {}, want: false, wantErr: errors.New("invalid type func()")},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := Request{
|
||||||
|
SessionId: "test-session",
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: tt.id,
|
||||||
|
Method: "testMethod",
|
||||||
|
Params: json.RawMessage(`{}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := req.isNotification()
|
||||||
|
|
||||||
|
if (err != nil) != (tt.wantErr != nil) {
|
||||||
|
t.Fatalf("error presence mismatch: got error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
if err != nil && tt.wantErr != nil && err.Error() != tt.wantErr.Error() {
|
||||||
|
t.Fatalf("error message mismatch:\ngot %q\nwant %q", err.Error(), tt.wantErr.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("isNotification() = %v, want %v for ID %v (%T)", got, tt.want, tt.id, tt.id)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user