mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-10 16:30:01 +08:00
239
mcp/server.go
239
mcp/server.go
@@ -42,14 +42,14 @@ func NewMcpServer(c McpConf) McpServer {
|
||||
Method: http.MethodGet,
|
||||
Path: s.conf.Mcp.SseEndpoint,
|
||||
Handler: s.handleSSE,
|
||||
}, rest.WithSSE())
|
||||
}, rest.WithSSE(), rest.WithTimeout(c.Mcp.SseTimeout))
|
||||
|
||||
// JSON-RPC message endpoint for regular requests
|
||||
s.server.AddRoute(rest.Route{
|
||||
Method: http.MethodPost,
|
||||
Path: s.conf.Mcp.MessageEndpoint,
|
||||
Handler: s.handleRequest,
|
||||
})
|
||||
}, rest.WithTimeout(c.Mcp.MessageTimeout))
|
||||
|
||||
return s
|
||||
}
|
||||
@@ -182,21 +182,23 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
// Always allow initialize and notifications/initialized regardless of client state
|
||||
if req.Method == methodInitialize {
|
||||
logx.Infof("Processing initialize request with ID: %d", req.ID)
|
||||
s.processInitialize(client, req)
|
||||
s.processInitialize(r.Context(), client, req)
|
||||
logx.Infof("Sent initialize response for ID: %d, waiting for notifications/initialized", req.ID)
|
||||
return
|
||||
} else if req.Method == methodNotificationsInitialized {
|
||||
// Handle initialized notification
|
||||
logx.Info("Received notifications/initialized notification")
|
||||
if !isNotification {
|
||||
s.sendErrorResponse(client, req.ID, "Method should be used as a notification", errCodeInvalidRequest)
|
||||
s.sendErrorResponse(r.Context(), client, req.ID,
|
||||
"Method should be used as a notification", errCodeInvalidRequest)
|
||||
return
|
||||
}
|
||||
s.processNotificationInitialized(client)
|
||||
return
|
||||
} else if !client.initialized && req.Method != methodNotificationsCancelled {
|
||||
// Block most requests until client is initialized (except for cancellations)
|
||||
s.sendErrorResponse(client, req.ID, "Client not fully initialized, waiting for notifications/initialized",
|
||||
s.sendErrorResponse(r.Context(), client, req.ID,
|
||||
"Client not fully initialized, waiting for notifications/initialized",
|
||||
errCodeClientNotInitialized)
|
||||
return
|
||||
}
|
||||
@@ -205,41 +207,41 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
switch req.Method {
|
||||
case methodToolsCall:
|
||||
logx.Infof("Received tools call request with ID: %d", req.ID)
|
||||
s.processToolCall(client, req)
|
||||
s.processToolCall(r.Context(), client, req)
|
||||
logx.Infof("Sent tools call response for ID: %d", req.ID)
|
||||
case methodToolsList:
|
||||
logx.Infof("Processing tools/list request with ID: %d", req.ID)
|
||||
s.processListTools(client, req)
|
||||
s.processListTools(r.Context(), client, req)
|
||||
logx.Infof("Sent tools/list response for ID: %d", req.ID)
|
||||
case methodPromptsList:
|
||||
logx.Infof("Processing prompts/list request with ID: %d", req.ID)
|
||||
s.processListPrompts(client, req)
|
||||
s.processListPrompts(r.Context(), client, req)
|
||||
logx.Infof("Sent prompts/list response for ID: %d", req.ID)
|
||||
case methodPromptsGet:
|
||||
logx.Infof("Processing prompts/get request with ID: %d", req.ID)
|
||||
s.processGetPrompt(client, req)
|
||||
s.processGetPrompt(r.Context(), client, req)
|
||||
logx.Infof("Sent prompts/get response for ID: %d", req.ID)
|
||||
case methodResourcesList:
|
||||
logx.Infof("Processing resources/list request with ID: %d", req.ID)
|
||||
s.processListResources(client, req)
|
||||
s.processListResources(r.Context(), client, req)
|
||||
logx.Infof("Sent resources/list response for ID: %d", req.ID)
|
||||
case methodResourcesRead:
|
||||
logx.Infof("Processing resources/read request with ID: %d", req.ID)
|
||||
s.processResourcesRead(client, req)
|
||||
s.processResourcesRead(r.Context(), client, req)
|
||||
logx.Infof("Sent resources/read response for ID: %d", req.ID)
|
||||
case methodResourcesSubscribe:
|
||||
logx.Infof("Processing resources/subscribe request with ID: %d", req.ID)
|
||||
s.processResourceSubscribe(client, req)
|
||||
s.processResourceSubscribe(r.Context(), client, req)
|
||||
logx.Infof("Sent resources/subscribe response for ID: %d", req.ID)
|
||||
case methodPing:
|
||||
logx.Infof("Processing ping request with ID: %d", req.ID)
|
||||
s.processPing(client, req)
|
||||
s.processPing(r.Context(), client, req)
|
||||
case methodNotificationsCancelled:
|
||||
logx.Infof("Received notifications/cancelled notification: %v", req.Params)
|
||||
s.processNotificationCancelled(client, req)
|
||||
logx.Infof("Received notifications/cancelled notification: %d", req.ID)
|
||||
s.processNotificationCancelled(r.Context(), client, req)
|
||||
default:
|
||||
logx.Infof("Unknown method: %s", req.Method)
|
||||
s.sendErrorResponse(client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||
logx.Infof("Unknown method: %s from client: %d", req.Method, req.ID)
|
||||
s.sendErrorResponse(r.Context(), client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -321,7 +323,7 @@ func (s *sseMcpServer) handleSSE(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
flusher.Flush()
|
||||
case <-r.Context().Done():
|
||||
// Client disconnected or request was canceled
|
||||
// Client disconnected or request was canceled or timed out
|
||||
logx.Infof("Client %s disconnected: context done", sessionId)
|
||||
return
|
||||
}
|
||||
@@ -329,7 +331,7 @@ func (s *sseMcpServer) handleSSE(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// processInitialize processes the initialize request
|
||||
func (s *sseMcpServer) processInitialize(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processInitialize(ctx context.Context, client *mcpClient, req Request) {
|
||||
// Create a proper JSON-RPC response that preserves the client's request ID
|
||||
result := initializationResponse{
|
||||
ProtocolVersion: s.conf.Mcp.ProtocolVersion,
|
||||
@@ -362,11 +364,11 @@ func (s *sseMcpServer) processInitialize(client *mcpClient, req Request) {
|
||||
client.initialized = true
|
||||
|
||||
// Send response with client's original request ID
|
||||
s.sendResponse(client, req.ID, result)
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
}
|
||||
|
||||
// processListTools processes the tools/list request
|
||||
func (s *sseMcpServer) processListTools(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processListTools(ctx context.Context, client *mcpClient, req Request) {
|
||||
// Extract pagination params if any
|
||||
var nextCursor string
|
||||
var progressToken any
|
||||
@@ -390,6 +392,9 @@ func (s *sseMcpServer) processListTools(client *mcpClient, req Request) {
|
||||
var toolsList []Tool
|
||||
s.toolsLock.Lock()
|
||||
for _, tool := range s.tools {
|
||||
if len(tool.InputSchema.Type) == 0 {
|
||||
tool.InputSchema.Type = ContentTypeObject
|
||||
}
|
||||
toolsList = append(toolsList, tool)
|
||||
}
|
||||
s.toolsLock.Unlock()
|
||||
@@ -405,15 +410,15 @@ func (s *sseMcpServer) processListTools(client *mcpClient, req Request) {
|
||||
// Add meta information if progress token was provided
|
||||
if progressToken != nil {
|
||||
result.Result.Meta = map[string]any{
|
||||
"progressToken": progressToken,
|
||||
progressTokenKey: progressToken,
|
||||
}
|
||||
}
|
||||
|
||||
s.sendResponse(client, req.ID, result)
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
}
|
||||
|
||||
// processListPrompts processes the prompts/list request
|
||||
func (s *sseMcpServer) processListPrompts(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processListPrompts(ctx context.Context, client *mcpClient, req Request) {
|
||||
// Extract pagination params if any
|
||||
var nextCursor string
|
||||
if req.Params != nil {
|
||||
@@ -447,11 +452,11 @@ func (s *sseMcpServer) processListPrompts(client *mcpClient, req Request) {
|
||||
NextCursor: nextCursor,
|
||||
}
|
||||
|
||||
s.sendResponse(client, req.ID, result)
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
}
|
||||
|
||||
// processListResources processes the resources/list request
|
||||
func (s *sseMcpServer) processListResources(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processListResources(ctx context.Context, client *mcpClient, req Request) {
|
||||
// Extract pagination params if any
|
||||
var nextCursor string
|
||||
var progressToken any
|
||||
@@ -493,15 +498,15 @@ func (s *sseMcpServer) processListResources(client *mcpClient, req Request) {
|
||||
// Add meta information if progress token was provided
|
||||
if progressToken != nil {
|
||||
result.Result.Meta = map[string]any{
|
||||
"progressToken": progressToken,
|
||||
progressTokenKey: progressToken,
|
||||
}
|
||||
}
|
||||
|
||||
s.sendResponse(client, req.ID, result)
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
}
|
||||
|
||||
// processGetPrompt processes the prompts/get request
|
||||
func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processGetPrompt(ctx context.Context, client *mcpClient, req Request) {
|
||||
type GetPromptParams struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]string `json:"arguments,omitempty"`
|
||||
@@ -509,7 +514,7 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
||||
|
||||
var params GetPromptParams
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -519,7 +524,7 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
||||
s.promptsLock.Unlock()
|
||||
if !exists {
|
||||
message := fmt.Sprintf("Prompt '%s' not found", params.Name)
|
||||
s.sendErrorResponse(client, req.ID, message, errCodeInvalidParams)
|
||||
s.sendErrorResponse(ctx, client, req.ID, message, errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -529,12 +534,15 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
||||
missingArgs := validatePromptArguments(prompt, params.Arguments)
|
||||
if len(missingArgs) > 0 {
|
||||
message := fmt.Sprintf("Missing required arguments: %s", strings.Join(missingArgs, ", "))
|
||||
s.sendErrorResponse(client, req.ID, message, errCodeInvalidParams)
|
||||
s.sendErrorResponse(ctx, client, req.ID, message, errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
// Apply default values for missing optional arguments
|
||||
args := applyDefaultArguments(prompt, params.Arguments)
|
||||
// Ensure arguments are initialized to an empty map if nil
|
||||
if params.Arguments == nil {
|
||||
params.Arguments = make(map[string]string)
|
||||
}
|
||||
args := params.Arguments
|
||||
|
||||
// Generate messages using handler or static content
|
||||
var messages []PromptMessage
|
||||
@@ -542,17 +550,17 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
||||
|
||||
if prompt.Handler != nil {
|
||||
// Use dynamic handler to generate messages
|
||||
logx.Info("Using prompt handler to generate content")
|
||||
messages, err = prompt.Handler(args)
|
||||
messages, err = prompt.Handler(ctx, args)
|
||||
if err != nil {
|
||||
logx.Errorf("Error from prompt handler: %v", err)
|
||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Error generating prompt content: %v", err), errCodeInternalError)
|
||||
s.sendErrorResponse(ctx, client, req.ID,
|
||||
fmt.Sprintf("Error generating prompt content: %v", err), errCodeInternalError)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// No handler, generate messages from static content
|
||||
var messageText string
|
||||
if prompt.Content != "" {
|
||||
if len(prompt.Content) > 0 {
|
||||
messageText = prompt.Content
|
||||
|
||||
// Apply argument substitutions to static content
|
||||
@@ -560,21 +568,13 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
||||
placeholder := fmt.Sprintf("{{%s}}", key)
|
||||
messageText = strings.Replace(messageText, placeholder, value, -1)
|
||||
}
|
||||
} else {
|
||||
// No content, use a default fallback
|
||||
topic := "this topic"
|
||||
if t, ok := args["topic"]; ok && t != "" {
|
||||
topic = t
|
||||
}
|
||||
messageText = fmt.Sprintf("Tell me about %s", topic)
|
||||
}
|
||||
|
||||
// Create a single user message with the content
|
||||
messages = []PromptMessage{
|
||||
{
|
||||
Role: roleUser,
|
||||
Role: RoleUser,
|
||||
Content: TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: messageText,
|
||||
},
|
||||
},
|
||||
@@ -587,49 +587,14 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
||||
Messages []PromptMessage `json:"messages"`
|
||||
}{
|
||||
Description: prompt.Description,
|
||||
Messages: messages,
|
||||
Messages: toTypedPromptMessages(messages),
|
||||
}
|
||||
|
||||
s.sendResponse(client, req.ID, result)
|
||||
}
|
||||
|
||||
// validatePromptArguments checks if all required arguments are provided
|
||||
// Returns a list of missing required arguments
|
||||
func validatePromptArguments(prompt Prompt, providedArgs map[string]string) []string {
|
||||
var missingArgs []string
|
||||
|
||||
for _, arg := range prompt.Arguments {
|
||||
if arg.Required {
|
||||
if value, exists := providedArgs[arg.Name]; !exists || len(value) == 0 {
|
||||
missingArgs = append(missingArgs, arg.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return missingArgs
|
||||
}
|
||||
|
||||
// applyDefaultArguments adds default values for missing optional arguments
|
||||
func applyDefaultArguments(prompt Prompt, providedArgs map[string]string) map[string]string {
|
||||
result := make(map[string]string)
|
||||
|
||||
// Copy all provided arguments
|
||||
for k, v := range providedArgs {
|
||||
result[k] = v
|
||||
}
|
||||
|
||||
// Add defaults for missing arguments
|
||||
for _, arg := range prompt.Arguments {
|
||||
if _, exists := result[arg.Name]; !exists && arg.Default != "" {
|
||||
result[arg.Name] = arg.Default
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
}
|
||||
|
||||
// processToolCall processes the tools/call request
|
||||
func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processToolCall(ctx context.Context, client *mcpClient, req Request) {
|
||||
var toolCallParams struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
@@ -642,7 +607,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
// If it's a RawMessage (JSON), unmarshal it
|
||||
if err := json.Unmarshal(req.Params, &toolCallParams); err != nil {
|
||||
logx.Errorf("Failed to unmarshal tool call params: %v", err)
|
||||
s.sendErrorResponse(client, req.ID, "Invalid tool call parameters", errCodeInvalidParams)
|
||||
s.sendErrorResponse(ctx, client, req.ID, "Invalid tool call parameters", errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -654,15 +619,11 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
tool, exists := s.tools[toolCallParams.Name]
|
||||
s.toolsLock.Unlock()
|
||||
if !exists {
|
||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Tool '%s' not found",
|
||||
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Tool '%s' not found",
|
||||
toolCallParams.Name), errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
// Create a context with the configured timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), s.conf.Mcp.ToolTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Log parameters before execution
|
||||
logx.Infof("Executing tool '%s' with arguments: %#v", toolCallParams.Name, toolCallParams.Arguments)
|
||||
|
||||
@@ -671,6 +632,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
var err error
|
||||
|
||||
// Create a channel to receive the result
|
||||
// make sure to have 1 size buffer to avoid channel leak if timeout
|
||||
resultCh := make(chan struct {
|
||||
result any
|
||||
err error
|
||||
@@ -678,7 +640,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
|
||||
// Execute the tool handler in a goroutine
|
||||
go func() {
|
||||
toolResult, toolErr := tool.Handler(toolCallParams.Arguments)
|
||||
toolResult, toolErr := tool.Handler(ctx, toolCallParams.Arguments)
|
||||
resultCh <- struct {
|
||||
result any
|
||||
err error
|
||||
@@ -694,9 +656,9 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
result = res.result
|
||||
err = res.err
|
||||
case <-ctx.Done():
|
||||
// Handle timeout
|
||||
logx.Errorf("Tool execution timed out after %v: %s", s.conf.Mcp.ToolTimeout, toolCallParams.Name)
|
||||
s.sendErrorResponse(client, req.ID, "Tool execution timed out", errCodeTimeout)
|
||||
// Handle request timeout
|
||||
logx.Errorf("Tool execution timed out after %v: %s", s.conf.Mcp.MessageTimeout, toolCallParams.Name)
|
||||
s.sendErrorResponse(ctx, client, req.ID, "Tool execution timed out", errCodeTimeout)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -710,7 +672,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
// Add meta information if progress token was provided
|
||||
if progressToken != nil {
|
||||
callToolResult.Result.Meta = map[string]any{
|
||||
"progressToken": progressToken,
|
||||
progressTokenKey: progressToken,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -722,12 +684,11 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
|
||||
callToolResult.Content = []any{
|
||||
TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: fmt.Sprintf("Error: %v", err),
|
||||
},
|
||||
}
|
||||
callToolResult.IsError = true
|
||||
s.sendResponse(client, req.ID, callToolResult)
|
||||
s.sendResponse(ctx, client, req.ID, callToolResult)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -736,10 +697,9 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
case string:
|
||||
// Simple string becomes text content
|
||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: v,
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
},
|
||||
})
|
||||
case map[string]any:
|
||||
@@ -749,69 +709,63 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
jsonStr = []byte(err.Error())
|
||||
}
|
||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: string(jsonStr),
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
},
|
||||
})
|
||||
case TextContent:
|
||||
// Direct TextContent object
|
||||
callToolResult.Content = append(callToolResult.Content, v)
|
||||
case ImageContent:
|
||||
// Direct ImageContent object
|
||||
callToolResult.Content = append(callToolResult.Content, v)
|
||||
case []any:
|
||||
// Array of content items
|
||||
callToolResult.Content = v
|
||||
case ToolResult:
|
||||
// Handle legacy ToolResult type
|
||||
switch v.Type {
|
||||
case contentTypeText:
|
||||
case ContentTypeText:
|
||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: fmt.Sprintf("%v", v.Content),
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
},
|
||||
})
|
||||
case contentTypeImage:
|
||||
case ContentTypeImage:
|
||||
if imgData, ok := v.Content.(map[string]any); ok {
|
||||
callToolResult.Content = append(callToolResult.Content, ImageContent{
|
||||
Type: contentTypeImage,
|
||||
Data: fmt.Sprintf("%v", imgData["data"]),
|
||||
MimeType: fmt.Sprintf("%v", imgData["mimeType"]),
|
||||
})
|
||||
}
|
||||
default:
|
||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: fmt.Sprintf("%v", v.Content),
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
},
|
||||
})
|
||||
}
|
||||
default:
|
||||
// For any other type, convert to string
|
||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: fmt.Sprintf("%v", v),
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
callToolResult.Content = toTypedContents(callToolResult.Content)
|
||||
logx.Infof("Tool call result: %#v", callToolResult)
|
||||
s.sendResponse(client, req.ID, callToolResult)
|
||||
|
||||
s.sendResponse(ctx, client, req.ID, callToolResult)
|
||||
}
|
||||
|
||||
// processResourcesRead processes the resources/read request
|
||||
func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processResourcesRead(ctx context.Context, client *mcpClient, req Request) {
|
||||
var params ResourceReadParams
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -821,7 +775,7 @@ func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
|
||||
s.resourcesLock.Unlock()
|
||||
|
||||
if !exists {
|
||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||
params.URI), errCodeResourceNotFound)
|
||||
return
|
||||
}
|
||||
@@ -837,14 +791,14 @@ func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
|
||||
},
|
||||
},
|
||||
}
|
||||
s.sendResponse(client, req.ID, result)
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
return
|
||||
}
|
||||
|
||||
// Execute the resource handler
|
||||
content, err := resource.Handler()
|
||||
content, err := resource.Handler(ctx)
|
||||
if err != nil {
|
||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Error reading resource: %v", err),
|
||||
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Error reading resource: %v", err),
|
||||
errCodeInternalError)
|
||||
return
|
||||
}
|
||||
@@ -865,14 +819,14 @@ func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
|
||||
Contents: []ResourceContent{content},
|
||||
}
|
||||
|
||||
s.sendResponse(client, req.ID, result)
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
}
|
||||
|
||||
// processResourceSubscribe processes the resources/subscribe request
|
||||
func (s *sseMcpServer) processResourceSubscribe(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processResourceSubscribe(ctx context.Context, client *mcpClient, req Request) {
|
||||
var params ResourceSubscribeParams
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -882,19 +836,17 @@ func (s *sseMcpServer) processResourceSubscribe(client *mcpClient, req Request)
|
||||
s.resourcesLock.Unlock()
|
||||
|
||||
if !exists {
|
||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||
params.URI), errCodeResourceNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Send success response for the subscription
|
||||
s.sendResponse(client, req.ID, struct{}{})
|
||||
|
||||
logx.Infof("Client %s subscribed to resource '%s'", client.id, params.URI)
|
||||
s.sendResponse(ctx, client, req.ID, struct{}{})
|
||||
}
|
||||
|
||||
// processNotificationCancelled processes the notifications/cancelled notification
|
||||
func (s *sseMcpServer) processNotificationCancelled(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processNotificationCancelled(ctx context.Context, client *mcpClient, req Request) {
|
||||
// Extract the requestId that was canceled
|
||||
type CancelParams struct {
|
||||
RequestId int64 `json:"requestId"`
|
||||
@@ -918,18 +870,17 @@ func (s *sseMcpServer) processNotificationInitialized(client *mcpClient) {
|
||||
}
|
||||
|
||||
// processPing processes the ping request and responds immediately
|
||||
func (s *sseMcpServer) processPing(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processPing(ctx context.Context, client *mcpClient, req Request) {
|
||||
// A ping request should simply respond with an empty result to confirm the server is alive
|
||||
logx.Infof("Received ping request with ID: %d", req.ID)
|
||||
|
||||
// Send an empty response with client's original request ID
|
||||
s.sendResponse(client, req.ID, struct{}{})
|
||||
|
||||
logx.Infof("Sent ping response for ID: %d", req.ID)
|
||||
s.sendResponse(ctx, client, req.ID, struct{}{})
|
||||
}
|
||||
|
||||
// sendErrorResponse sends an error response via the SSE channel
|
||||
func (s *sseMcpServer) sendErrorResponse(client *mcpClient, id int64, message string, code int) {
|
||||
func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
|
||||
id int64, message string, code int) {
|
||||
errorResponse := struct {
|
||||
JsonRpc string `json:"jsonrpc"`
|
||||
ID int64 `json:"id"`
|
||||
@@ -949,11 +900,17 @@ func (s *sseMcpServer) sendErrorResponse(client *mcpClient, id int64, message st
|
||||
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||
logx.Infof("Sending error for ID %d: %s", id, sseMessage)
|
||||
|
||||
client.channel <- sseMessage
|
||||
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
|
||||
select {
|
||||
case client.channel <- sseMessage:
|
||||
default:
|
||||
// Channel buffer is full, log warning and continue
|
||||
logx.Infof("Client %s channel is full while sending error response with ID %d", client.id, id)
|
||||
}
|
||||
}
|
||||
|
||||
// sendResponse sends a success response via the SSE channel
|
||||
func (s *sseMcpServer) sendResponse(client *mcpClient, id int64, result any) {
|
||||
func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id int64, result any) {
|
||||
response := Response{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: id,
|
||||
@@ -962,7 +919,7 @@ func (s *sseMcpServer) sendResponse(client *mcpClient, id int64, result any) {
|
||||
|
||||
jsonData, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
s.sendErrorResponse(client, id, "Failed to marshal response", errCodeInternalError)
|
||||
s.sendErrorResponse(ctx, client, id, "Failed to marshal response", errCodeInternalError)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -970,5 +927,11 @@ func (s *sseMcpServer) sendResponse(client *mcpClient, id int64, result any) {
|
||||
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||
logx.Infof("Sending response for ID %d: %s", id, sseMessage)
|
||||
|
||||
client.channel <- sseMessage
|
||||
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
|
||||
select {
|
||||
case client.channel <- sseMessage:
|
||||
default:
|
||||
// Channel buffer is full, log warning and continue
|
||||
logx.Infof("Client %s channel is full while sending response with ID %d", client.id, id)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user