From 94e2f5bd129027cca5d4fc571159d3b0303740d6 Mon Sep 17 00:00:00 2001 From: mk0walsk Date: Sat, 24 Jan 2026 14:13:35 +0200 Subject: [PATCH] Refactor routes and harden AddTool (#5375) --- mcp/server.go | 25 ++++++++----------------- mcp/server_test.go | 19 +++++++++++++++++++ mcp/types.go | 3 +++ 3 files changed, 30 insertions(+), 17 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index 3f97a87a4..fe7e1e8c0 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -85,20 +85,7 @@ func (s *mcpServerImpl) setupSSETransport() { return s.mcpServer }, nil) - // Register the SSE endpoint - s.httpServer.AddRoute(rest.Route{ - Method: http.MethodGet, - Path: s.conf.Mcp.SseEndpoint, - Handler: handler.ServeHTTP, - }, rest.WithSSE(), rest.WithTimeout(s.conf.Mcp.SseTimeout)) - - // The SSE handler also handles POST requests to message endpoints - // We need to route those as well - s.httpServer.AddRoute(rest.Route{ - Method: http.MethodPost, - Path: s.conf.Mcp.SseEndpoint, - Handler: handler.ServeHTTP, - }, rest.WithTimeout(s.conf.Mcp.MessageTimeout)) + s.registerRoutes(handler, s.conf.Mcp.SseEndpoint) } // setupStreamableTransport configures the server to use Streamable HTTP transport (2025-03-26 spec) @@ -109,16 +96,20 @@ func (s *mcpServerImpl) setupStreamableTransport() { return s.mcpServer }, nil) - // Register the message endpoint (handles both GET for SSE and POST for messages) + s.registerRoutes(handler, s.conf.Mcp.MessageEndpoint) +} + +func (s *mcpServerImpl) registerRoutes(handler http.Handler, endpoint string) { + // Register the endpoint (handles both GET for SSE and POST for messages) s.httpServer.AddRoute(rest.Route{ Method: http.MethodGet, - Path: s.conf.Mcp.MessageEndpoint, + Path: endpoint, Handler: handler.ServeHTTP, }, rest.WithSSE(), rest.WithTimeout(s.conf.Mcp.SseTimeout)) s.httpServer.AddRoute(rest.Route{ Method: http.MethodPost, - Path: s.conf.Mcp.MessageEndpoint, + Path: endpoint, Handler: handler.ServeHTTP, }, rest.WithTimeout(s.conf.Mcp.MessageTimeout)) } diff --git a/mcp/server_test.go b/mcp/server_test.go index 833ca87d8..171d07fe6 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -372,3 +372,22 @@ func TestConfig(t *testing.T) { assert.Equal(t, "/sse", c.Mcp.SseEndpoint) assert.Equal(t, "/message", c.Mcp.MessageEndpoint) } + +type mockMcpServer struct{} + +func (m *mockMcpServer) Start() {} +func (m *mockMcpServer) Stop() {} + +func TestAddToolWithCustomServer(t *testing.T) { + server := &mockMcpServer{} + // Should not panic, but log error + defer func() { + if r := recover(); r != nil { + t.Errorf("AddTool panicked with custom server: %v", r) + } + }() + + AddTool(server, &Tool{Name: "test"}, func(ctx context.Context, req *CallToolRequest, args struct{}) (*CallToolResult, any, error) { + return nil, nil, nil + }) +} diff --git a/mcp/types.go b/mcp/types.go index 7f4bff959..ccf515d8d 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -4,6 +4,7 @@ import ( "context" sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/zeromicro/go-zero/core/logx" ) // Re-export commonly used SDK types for convenience @@ -93,5 +94,7 @@ func AddTool[In, Out any](server McpServer, tool *Tool, handler func(context.Con // Access internal server - only works with mcpServerImpl if impl, ok := server.(*mcpServerImpl); ok { sdkmcp.AddTool(impl.mcpServer, tool, handler) + } else { + logx.Error("AddTool: server must be of type *mcpServerImpl to use this helper") } }