Refactor routes and harden AddTool (#5375)

This commit is contained in:
mk0walsk
2026-01-24 14:13:35 +02:00
committed by GitHub
parent 173f76acf9
commit 94e2f5bd12
3 changed files with 30 additions and 17 deletions

View File

@@ -85,20 +85,7 @@ func (s *mcpServerImpl) setupSSETransport() {
return s.mcpServer
}, nil)
// Register the SSE endpoint
s.httpServer.AddRoute(rest.Route{
Method: http.MethodGet,
Path: s.conf.Mcp.SseEndpoint,
Handler: handler.ServeHTTP,
}, rest.WithSSE(), rest.WithTimeout(s.conf.Mcp.SseTimeout))
// The SSE handler also handles POST requests to message endpoints
// We need to route those as well
s.httpServer.AddRoute(rest.Route{
Method: http.MethodPost,
Path: s.conf.Mcp.SseEndpoint,
Handler: handler.ServeHTTP,
}, rest.WithTimeout(s.conf.Mcp.MessageTimeout))
s.registerRoutes(handler, s.conf.Mcp.SseEndpoint)
}
// setupStreamableTransport configures the server to use Streamable HTTP transport (2025-03-26 spec)
@@ -109,16 +96,20 @@ func (s *mcpServerImpl) setupStreamableTransport() {
return s.mcpServer
}, nil)
// Register the message endpoint (handles both GET for SSE and POST for messages)
s.registerRoutes(handler, s.conf.Mcp.MessageEndpoint)
}
func (s *mcpServerImpl) registerRoutes(handler http.Handler, endpoint string) {
// Register the endpoint (handles both GET for SSE and POST for messages)
s.httpServer.AddRoute(rest.Route{
Method: http.MethodGet,
Path: s.conf.Mcp.MessageEndpoint,
Path: endpoint,
Handler: handler.ServeHTTP,
}, rest.WithSSE(), rest.WithTimeout(s.conf.Mcp.SseTimeout))
s.httpServer.AddRoute(rest.Route{
Method: http.MethodPost,
Path: s.conf.Mcp.MessageEndpoint,
Path: endpoint,
Handler: handler.ServeHTTP,
}, rest.WithTimeout(s.conf.Mcp.MessageTimeout))
}

View File

@@ -372,3 +372,22 @@ func TestConfig(t *testing.T) {
assert.Equal(t, "/sse", c.Mcp.SseEndpoint)
assert.Equal(t, "/message", c.Mcp.MessageEndpoint)
}
type mockMcpServer struct{}
func (m *mockMcpServer) Start() {}
func (m *mockMcpServer) Stop() {}
func TestAddToolWithCustomServer(t *testing.T) {
server := &mockMcpServer{}
// Should not panic, but log error
defer func() {
if r := recover(); r != nil {
t.Errorf("AddTool panicked with custom server: %v", r)
}
}()
AddTool(server, &Tool{Name: "test"}, func(ctx context.Context, req *CallToolRequest, args struct{}) (*CallToolResult, any, error) {
return nil, nil, nil
})
}

View File

@@ -4,6 +4,7 @@ import (
"context"
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/zeromicro/go-zero/core/logx"
)
// Re-export commonly used SDK types for convenience
@@ -93,5 +94,7 @@ func AddTool[In, Out any](server McpServer, tool *Tool, handler func(context.Con
// Access internal server - only works with mcpServerImpl
if impl, ok := server.(*mcpServerImpl); ok {
sdkmcp.AddTool(impl.mcpServer, tool, handler)
} else {
logx.Error("AddTool: server must be of type *mcpServerImpl to use this helper")
}
}