mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-07 06:59:59 +08:00
Refactor routes and harden AddTool (#5375)
This commit is contained in:
@@ -85,20 +85,7 @@ func (s *mcpServerImpl) setupSSETransport() {
|
||||
return s.mcpServer
|
||||
}, nil)
|
||||
|
||||
// Register the SSE endpoint
|
||||
s.httpServer.AddRoute(rest.Route{
|
||||
Method: http.MethodGet,
|
||||
Path: s.conf.Mcp.SseEndpoint,
|
||||
Handler: handler.ServeHTTP,
|
||||
}, rest.WithSSE(), rest.WithTimeout(s.conf.Mcp.SseTimeout))
|
||||
|
||||
// The SSE handler also handles POST requests to message endpoints
|
||||
// We need to route those as well
|
||||
s.httpServer.AddRoute(rest.Route{
|
||||
Method: http.MethodPost,
|
||||
Path: s.conf.Mcp.SseEndpoint,
|
||||
Handler: handler.ServeHTTP,
|
||||
}, rest.WithTimeout(s.conf.Mcp.MessageTimeout))
|
||||
s.registerRoutes(handler, s.conf.Mcp.SseEndpoint)
|
||||
}
|
||||
|
||||
// setupStreamableTransport configures the server to use Streamable HTTP transport (2025-03-26 spec)
|
||||
@@ -109,16 +96,20 @@ func (s *mcpServerImpl) setupStreamableTransport() {
|
||||
return s.mcpServer
|
||||
}, nil)
|
||||
|
||||
// Register the message endpoint (handles both GET for SSE and POST for messages)
|
||||
s.registerRoutes(handler, s.conf.Mcp.MessageEndpoint)
|
||||
}
|
||||
|
||||
func (s *mcpServerImpl) registerRoutes(handler http.Handler, endpoint string) {
|
||||
// Register the endpoint (handles both GET for SSE and POST for messages)
|
||||
s.httpServer.AddRoute(rest.Route{
|
||||
Method: http.MethodGet,
|
||||
Path: s.conf.Mcp.MessageEndpoint,
|
||||
Path: endpoint,
|
||||
Handler: handler.ServeHTTP,
|
||||
}, rest.WithSSE(), rest.WithTimeout(s.conf.Mcp.SseTimeout))
|
||||
|
||||
s.httpServer.AddRoute(rest.Route{
|
||||
Method: http.MethodPost,
|
||||
Path: s.conf.Mcp.MessageEndpoint,
|
||||
Path: endpoint,
|
||||
Handler: handler.ServeHTTP,
|
||||
}, rest.WithTimeout(s.conf.Mcp.MessageTimeout))
|
||||
}
|
||||
|
||||
@@ -372,3 +372,22 @@ func TestConfig(t *testing.T) {
|
||||
assert.Equal(t, "/sse", c.Mcp.SseEndpoint)
|
||||
assert.Equal(t, "/message", c.Mcp.MessageEndpoint)
|
||||
}
|
||||
|
||||
type mockMcpServer struct{}
|
||||
|
||||
func (m *mockMcpServer) Start() {}
|
||||
func (m *mockMcpServer) Stop() {}
|
||||
|
||||
func TestAddToolWithCustomServer(t *testing.T) {
|
||||
server := &mockMcpServer{}
|
||||
// Should not panic, but log error
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("AddTool panicked with custom server: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
AddTool(server, &Tool{Name: "test"}, func(ctx context.Context, req *CallToolRequest, args struct{}) (*CallToolResult, any, error) {
|
||||
return nil, nil, nil
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
|
||||
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
|
||||
// Re-export commonly used SDK types for convenience
|
||||
@@ -93,5 +94,7 @@ func AddTool[In, Out any](server McpServer, tool *Tool, handler func(context.Con
|
||||
// Access internal server - only works with mcpServerImpl
|
||||
if impl, ok := server.(*mcpServerImpl); ok {
|
||||
sdkmcp.AddTool(impl.mcpServer, tool, handler)
|
||||
} else {
|
||||
logx.Error("AddTool: server must be of type *mcpServerImpl to use this helper")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user