mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-10 16:30:01 +08:00
@@ -34,7 +34,7 @@ host: localhost
|
||||
port: 8080
|
||||
mcp:
|
||||
name: mcp-test-server
|
||||
toolTimeout: 5s
|
||||
messageTimeout: 5s
|
||||
`
|
||||
|
||||
var c McpConf
|
||||
@@ -82,7 +82,6 @@ func (m *mockMcpServer) registerExampleTool() {
|
||||
Name: "test.tool",
|
||||
Description: "A test tool",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"input": map[string]any{
|
||||
"type": "string",
|
||||
@@ -91,7 +90,7 @@ func (m *mockMcpServer) registerExampleTool() {
|
||||
},
|
||||
Required: []string{"input"},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
input, ok := params["input"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid input parameter")
|
||||
@@ -135,7 +134,7 @@ port: 8080
|
||||
mcp:
|
||||
cors:
|
||||
- http://localhost:3000
|
||||
toolTimeout: 5s
|
||||
messageTimeout: 5s
|
||||
`
|
||||
|
||||
var c McpConf
|
||||
@@ -186,7 +185,6 @@ func TestRegisterTool(t *testing.T) {
|
||||
Name: "example.tool",
|
||||
Description: "An example tool",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"input": map[string]any{
|
||||
"type": "string",
|
||||
@@ -194,7 +192,7 @@ func TestRegisterTool(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return "result", nil
|
||||
},
|
||||
}
|
||||
@@ -280,7 +278,7 @@ func TestToolsList(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processListTools(client, req)
|
||||
mock.server.processListTools(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -328,7 +326,7 @@ func TestToolCallBasic(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -355,8 +353,7 @@ func TestToolCallBasic(t *testing.T) {
|
||||
|
||||
// Verify the response content
|
||||
assert.Len(t, parsed.Result.Content, 1, "Response should contain one content item")
|
||||
assert.Equal(t, "text", parsed.Result.Content[0]["type"], "Content type should be text")
|
||||
assert.Equal(t, "Processed: test-input", parsed.Result.Content[0]["text"], "Tool result incorrect")
|
||||
assert.Equal(t, "Processed: test-input", parsed.Result.Content[0][ContentTypeText], "Tool result incorrect")
|
||||
assert.False(t, parsed.Result.IsError, "Response should not be an error")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for tool call response")
|
||||
@@ -373,10 +370,9 @@ func TestToolCallMapResult(t *testing.T) {
|
||||
Name: "map.tool",
|
||||
Description: "A tool that returns a map result",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
// Return a complex nested map structure
|
||||
return map[string]any{
|
||||
"string": "value",
|
||||
@@ -417,7 +413,7 @@ func TestToolCallMapResult(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -445,13 +441,8 @@ func TestToolCallMapResult(t *testing.T) {
|
||||
firstItem, ok := content[0].(map[string]any)
|
||||
require.True(t, ok, "First content item should be an object")
|
||||
|
||||
// Verify it's a text content
|
||||
contentType, ok := firstItem["type"].(string)
|
||||
require.True(t, ok, "Content should have a type")
|
||||
assert.Equal(t, "text", contentType, "Content type should be text")
|
||||
|
||||
// Get the text content which should be our JSON
|
||||
text, ok := firstItem["text"].(string)
|
||||
text, ok := firstItem[ContentTypeText].(string)
|
||||
require.True(t, ok, "Content should have text")
|
||||
|
||||
// Verify the text is valid JSON and contains our data
|
||||
@@ -496,10 +487,9 @@ func TestToolCallArrayResult(t *testing.T) {
|
||||
Name: "array.tool",
|
||||
Description: "A tool that returns an array result",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
// Return an array of mixed content types
|
||||
return []any{
|
||||
"string item",
|
||||
@@ -536,7 +526,7 @@ func TestToolCallArrayResult(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -574,16 +564,14 @@ func TestToolCallTextContentResult(t *testing.T) {
|
||||
Name: "text.content.tool",
|
||||
Description: "A tool that returns a TextContent result",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
// Return a TextContent object directly
|
||||
return TextContent{
|
||||
Type: "text",
|
||||
Text: "This is a direct TextContent result",
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
Priority: func() *float64 { p := 0.9; return &p }(),
|
||||
},
|
||||
}, nil
|
||||
@@ -614,7 +602,7 @@ func TestToolCallTextContentResult(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -642,16 +630,6 @@ func TestToolCallTextContentResult(t *testing.T) {
|
||||
firstItem, ok := content[0].(map[string]any)
|
||||
require.True(t, ok, "First content item should be an object")
|
||||
|
||||
// Verify it's a text content
|
||||
contentType, ok := firstItem["type"].(string)
|
||||
require.True(t, ok, "Content should have a type")
|
||||
assert.Equal(t, "text", contentType, "Content type should be text")
|
||||
|
||||
// Check text content
|
||||
text, ok := firstItem["text"].(string)
|
||||
require.True(t, ok, "Content should have text")
|
||||
assert.Equal(t, "This is a direct TextContent result", text, "Text content should match")
|
||||
|
||||
// Check annotations
|
||||
annotations, ok := firstItem["annotations"].(map[string]any)
|
||||
require.True(t, ok, "Should have annotations")
|
||||
@@ -679,13 +657,11 @@ func TestToolCallImageContentResult(t *testing.T) {
|
||||
Name: "image.content.tool",
|
||||
Description: "A tool that returns an ImageContent result",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
// Return an ImageContent object directly
|
||||
return ImageContent{
|
||||
Type: "image",
|
||||
Data: "dGVzdCBpbWFnZSBkYXRhIChiYXNlNjQgZW5jb2RlZCk=", // "test image data (base64 encoded)" in base64
|
||||
MimeType: "image/png",
|
||||
}, nil
|
||||
@@ -716,7 +692,7 @@ func TestToolCallImageContentResult(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -744,11 +720,6 @@ func TestToolCallImageContentResult(t *testing.T) {
|
||||
firstItem, ok := content[0].(map[string]any)
|
||||
require.True(t, ok, "First content item should be an object")
|
||||
|
||||
// Verify it's an image content
|
||||
contentType, ok := firstItem["type"].(string)
|
||||
require.True(t, ok, "Content should have a type")
|
||||
assert.Equal(t, "image", contentType, "Content type should be image")
|
||||
|
||||
// Check image data
|
||||
data, ok := firstItem["data"].(string)
|
||||
require.True(t, ok, "Content should have data")
|
||||
@@ -773,12 +744,12 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
Name: "toolresult.tool",
|
||||
Description: "A tool that returns a ToolResult object",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Type: ContentTypeObject,
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return ToolResult{
|
||||
Type: "text",
|
||||
Type: ContentTypeText,
|
||||
Content: "This is a ToolResult with text content type",
|
||||
}, nil
|
||||
},
|
||||
@@ -790,10 +761,10 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
Name: "toolresult.image.tool",
|
||||
Description: "A tool that returns a ToolResult with image content",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Type: ContentTypeObject,
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return ToolResult{
|
||||
Type: "image",
|
||||
Content: map[string]any{
|
||||
@@ -810,10 +781,9 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
Name: "toolresult.audio.tool",
|
||||
Description: "A tool that returns a ToolResult with audio content",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
// Test with image type
|
||||
return ToolResult{
|
||||
Type: "audio",
|
||||
@@ -831,10 +801,9 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
Name: "toolresult.int.tool",
|
||||
Description: "A tool that returns a ToolResult with int content",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return 2, nil
|
||||
},
|
||||
}
|
||||
@@ -845,10 +814,9 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
Name: "toolresult.bad.tool",
|
||||
Description: "A tool that returns a ToolResult with bad content",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return map[string]any{
|
||||
"type": "custom",
|
||||
"data": make(chan int),
|
||||
@@ -881,7 +849,7 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -909,13 +877,8 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
firstItem, ok := content[0].(map[string]any)
|
||||
require.True(t, ok, "First content item should be an object")
|
||||
|
||||
// Verify it's a text content
|
||||
contentType, ok := firstItem["type"].(string)
|
||||
require.True(t, ok, "Content should have a type")
|
||||
assert.Equal(t, "text", contentType, "Content type should be text")
|
||||
|
||||
// Check text content
|
||||
text, ok := firstItem["text"].(string)
|
||||
text, ok := firstItem[ContentTypeText].(string)
|
||||
require.True(t, ok, "Content should have text")
|
||||
assert.Equal(t, "This is a ToolResult with text content type", text, "Text content should match")
|
||||
|
||||
@@ -947,7 +910,7 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -975,11 +938,6 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
firstItem, ok := content[0].(map[string]any)
|
||||
require.True(t, ok, "First content item should be an object")
|
||||
|
||||
// Verify it's an image content
|
||||
contentType, ok := firstItem["type"].(string)
|
||||
require.True(t, ok, "Content should have a type")
|
||||
assert.Equal(t, "image", contentType, "Content type should be image")
|
||||
|
||||
// Check image data and mime type
|
||||
data, ok := firstItem["data"].(string)
|
||||
require.True(t, ok, "Content should have data")
|
||||
@@ -1017,7 +975,7 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -1040,15 +998,6 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
content, ok := result["content"].([]any)
|
||||
require.True(t, ok, "Result should have a content array")
|
||||
require.NotEmpty(t, content, "Content should not be empty")
|
||||
|
||||
// The first content item should be converted from ToolResult to ImageContent
|
||||
firstItem, ok := content[0].(map[string]any)
|
||||
require.True(t, ok, "First content item should be an object")
|
||||
|
||||
// Verify it's an image content
|
||||
contentType, ok := firstItem["type"].(string)
|
||||
require.True(t, ok, "Content should have a type")
|
||||
assert.Equal(t, "text", contentType, "Content type should be image")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for tool call response")
|
||||
}
|
||||
@@ -1077,7 +1026,7 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -1100,15 +1049,6 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
content, ok := result["content"].([]any)
|
||||
require.True(t, ok, "Result should have a content array")
|
||||
require.NotEmpty(t, content, "Content should not be empty")
|
||||
|
||||
// The first content item should be converted from ToolResult to ImageContent
|
||||
firstItem, ok := content[0].(map[string]any)
|
||||
require.True(t, ok, "First content item should be an object")
|
||||
|
||||
// Verify it's an image content
|
||||
contentType, ok := firstItem["type"].(string)
|
||||
require.True(t, ok, "Content should have a type")
|
||||
assert.Equal(t, "text", contentType, "Content type should be image")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for tool call response")
|
||||
}
|
||||
@@ -1137,7 +1077,7 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -1159,10 +1099,9 @@ func TestToolCallError(t *testing.T) {
|
||||
Name: "error.tool",
|
||||
Description: "A tool that returns an error",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return nil, fmt.Errorf("tool execution failed")
|
||||
},
|
||||
})
|
||||
@@ -1189,7 +1128,7 @@ func TestToolCallError(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Check the response
|
||||
select {
|
||||
@@ -1207,20 +1146,16 @@ func TestToolCallTimeout(t *testing.T) {
|
||||
mock := newMockMcpServer(t)
|
||||
defer mock.shutdown()
|
||||
|
||||
// Set a very short timeout for testing
|
||||
mock.server.conf.Mcp.ToolTimeout = 10 * time.Millisecond
|
||||
|
||||
// Register a tool that times out
|
||||
err := mock.server.RegisterTool(Tool{
|
||||
Name: "timeout.tool",
|
||||
Description: "A tool that times out",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
time.Sleep(50 * time.Millisecond) // Sleep longer than timeout
|
||||
return "this should never be returned", nil
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
<-ctx.Done()
|
||||
return nil, fmt.Errorf("tool execution timed out")
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -1244,16 +1179,24 @@ func TestToolCallTimeout(t *testing.T) {
|
||||
Method: methodToolsCall,
|
||||
Params: paramBytes,
|
||||
}
|
||||
jsonBody, _ := json.Marshal(req)
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
// Create HTTP request
|
||||
r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client", bytes.NewReader(jsonBody))
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
r = r.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Process through handleRequest
|
||||
go mock.server.handleRequest(w, r)
|
||||
|
||||
// Check the response
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
assert.Contains(t, response, "event: message", "Response should have message event")
|
||||
assert.Contains(t, response, `-32001`, "Response should contain a timeout error code")
|
||||
case <-time.After(150 * time.Millisecond):
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for tool call response")
|
||||
}
|
||||
}
|
||||
@@ -1274,7 +1217,7 @@ func TestInitializeAndNotifications(t *testing.T) {
|
||||
Params: json.RawMessage(`{}`),
|
||||
}
|
||||
|
||||
mock.server.processInitialize(client, initReq)
|
||||
mock.server.processInitialize(context.Background(), client, initReq)
|
||||
|
||||
// Check that client is initialized after initialize request
|
||||
assert.True(t, client.initialized, "Client should be marked as initialized after initialize request")
|
||||
@@ -1418,7 +1361,7 @@ func TestNotificationCancelled_badParams(t *testing.T) {
|
||||
}
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
mock.server.processNotificationCancelled(client, cancelReq)
|
||||
mock.server.processNotificationCancelled(context.Background(), client, cancelReq)
|
||||
|
||||
select {
|
||||
case <-client.channel:
|
||||
@@ -1593,7 +1536,7 @@ func TestGetPrompt(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the request
|
||||
mock.server.processGetPrompt(client, promptReq)
|
||||
mock.server.processGetPrompt(context.Background(), client, promptReq)
|
||||
|
||||
// Check response
|
||||
select {
|
||||
@@ -1622,7 +1565,7 @@ func TestGetPrompt(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the request
|
||||
mock.server.processGetPrompt(client, promptReq)
|
||||
mock.server.processGetPrompt(context.Background(), client, promptReq)
|
||||
|
||||
// Check response
|
||||
select {
|
||||
@@ -1636,6 +1579,44 @@ func TestGetPrompt(t *testing.T) {
|
||||
t.Fatal("Timed out waiting for prompt response")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("test prompt with nil params", func(t *testing.T) {
|
||||
mock := newMockMcpServer(t)
|
||||
defer mock.shutdown()
|
||||
|
||||
// Create a test client
|
||||
client := addTestClient(mock.server, "test-client", true)
|
||||
|
||||
// Register a test prompt
|
||||
testPrompt := Prompt{
|
||||
Name: "test.prompt",
|
||||
Description: "A test prompt",
|
||||
}
|
||||
mock.server.RegisterPrompt(testPrompt)
|
||||
|
||||
// Create a get prompt request
|
||||
paramBytes, _ := json.Marshal(map[string]any{
|
||||
"name": "test.prompt",
|
||||
})
|
||||
promptReq := Request{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: 1,
|
||||
Method: "prompts/get",
|
||||
Params: paramBytes,
|
||||
}
|
||||
|
||||
// Process the request
|
||||
mock.server.processGetPrompt(context.Background(), client, promptReq)
|
||||
|
||||
// Check response
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
_, err := parseEvent(response)
|
||||
assert.NoError(t, err, "Should be able to parse event")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for prompt response")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestBroadcast tests the broadcast functionality
|
||||
@@ -1903,34 +1884,79 @@ func TestNotificationInitialized(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestSendBadResponse(t *testing.T) {
|
||||
mock := newMockMcpServer(t)
|
||||
defer mock.shutdown()
|
||||
func TestSendResponse(t *testing.T) {
|
||||
t.Run("bad response", func(t *testing.T) {
|
||||
mock := newMockMcpServer(t)
|
||||
defer mock.shutdown()
|
||||
|
||||
// Create a test client
|
||||
client := addTestClient(mock.server, "test-client", true)
|
||||
// Create a test client
|
||||
client := addTestClient(mock.server, "test-client", true)
|
||||
|
||||
// Create a response
|
||||
response := Response{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: 1,
|
||||
Result: make(chan int),
|
||||
}
|
||||
// Create a response
|
||||
response := Response{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: 1,
|
||||
Result: make(chan int),
|
||||
}
|
||||
|
||||
// Send the response
|
||||
mock.server.sendResponse(client, 1, response)
|
||||
// Send the response
|
||||
mock.server.sendResponse(context.Background(), client, 1, response)
|
||||
|
||||
// Check the response in the client's channel
|
||||
select {
|
||||
case res := <-client.channel:
|
||||
evt, err := parseEvent(res)
|
||||
require.NoError(t, err, "Should parse event without error")
|
||||
errMsg, ok := evt.Data["error"].(map[string]any)
|
||||
require.True(t, ok, "Should have error in response")
|
||||
assert.Equal(t, float64(errCodeInternalError), errMsg["code"], "Error code should match")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for response")
|
||||
}
|
||||
// Check the response in the client's channel
|
||||
select {
|
||||
case res := <-client.channel:
|
||||
evt, err := parseEvent(res)
|
||||
require.NoError(t, err, "Should parse event without error")
|
||||
errMsg, ok := evt.Data["error"].(map[string]any)
|
||||
require.True(t, ok, "Should have error in response")
|
||||
assert.Equal(t, float64(errCodeInternalError), errMsg["code"], "Error code should match")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for response")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("channel full", func(t *testing.T) {
|
||||
mock := newMockMcpServer(t)
|
||||
defer mock.shutdown()
|
||||
|
||||
// Create a test client
|
||||
client := addTestClient(mock.server, "test-client", true)
|
||||
for i := 0; i < eventChanSize; i++ {
|
||||
client.channel <- "test"
|
||||
}
|
||||
|
||||
// Create a response
|
||||
response := Response{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: 1,
|
||||
Result: "foo",
|
||||
}
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
// Send the response
|
||||
mock.server.sendResponse(context.Background(), client, 1, response)
|
||||
// Check the response in the client's channel
|
||||
assert.Contains(t, buf.String(), "channel is full")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSendErrorResponse(t *testing.T) {
|
||||
t.Run("channel full", func(t *testing.T) {
|
||||
mock := newMockMcpServer(t)
|
||||
defer mock.shutdown()
|
||||
|
||||
// Create a test client
|
||||
client := addTestClient(mock.server, "test-client", true)
|
||||
for i := 0; i < eventChanSize; i++ {
|
||||
client.channel <- "test"
|
||||
}
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
// Send the response
|
||||
mock.server.sendErrorResponse(context.Background(), client, 1, "foo", errCodeInternalError)
|
||||
// Check the response in the client's channel
|
||||
assert.Contains(t, buf.String(), "channel is full")
|
||||
})
|
||||
}
|
||||
|
||||
// TestMethodToolsCall tests the handling of tools/call method through handleRequest
|
||||
@@ -2028,8 +2054,7 @@ func TestMethodToolsCall(t *testing.T) {
|
||||
if len(content) > 0 {
|
||||
firstItem, ok := content[0].(map[string]any)
|
||||
if ok {
|
||||
assert.Equal(t, "text", firstItem["type"], "Content type should be text")
|
||||
assert.Contains(t, firstItem["text"], "Processed: test-input", "Content should include processed input")
|
||||
assert.Contains(t, firstItem[ContentTypeText], "Processed: test-input", "Content should include processed input")
|
||||
}
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
@@ -2145,7 +2170,6 @@ func TestMethodPromptsGet(t *testing.T) {
|
||||
{
|
||||
Name: "topic",
|
||||
Description: "Topic to discuss",
|
||||
Default: "artificial intelligence",
|
||||
},
|
||||
},
|
||||
Content: "Hello {{name}}! Let's talk about {{topic}}.",
|
||||
@@ -2227,14 +2251,12 @@ func TestMethodPromptsGet(t *testing.T) {
|
||||
if len(messages) > 0 {
|
||||
message, ok := messages[0].(map[string]any)
|
||||
require.True(t, ok, "Message should be an object")
|
||||
assert.Equal(t, "user", message["role"], "Role should be 'user'")
|
||||
assert.Equal(t, string(RoleUser), message["role"], "Role should be 'user'")
|
||||
|
||||
content, ok := message["content"].(map[string]any)
|
||||
require.True(t, ok, "Should have content object")
|
||||
assert.Equal(t, "text", content["type"], "Content type should be text")
|
||||
assert.Contains(t, content["text"], "Hello Test User", "Content should include the name argument")
|
||||
assert.Contains(t, content["text"], "about artificial intelligence",
|
||||
"Content should include the default topic argument")
|
||||
assert.Equal(t, ContentTypeText, content["type"], "Content type should be text")
|
||||
assert.Contains(t, content[ContentTypeText], "Hello Test User", "Content should include the name argument")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for prompt get response")
|
||||
@@ -2255,27 +2277,24 @@ func TestMethodPromptsGet(t *testing.T) {
|
||||
{
|
||||
Name: "question",
|
||||
Description: "User's question",
|
||||
Default: "How does this work?",
|
||||
},
|
||||
},
|
||||
Handler: func(args map[string]string) ([]PromptMessage, error) {
|
||||
Handler: func(ctx context.Context, args map[string]string) ([]PromptMessage, error) {
|
||||
username := args["username"]
|
||||
question := args["question"]
|
||||
|
||||
// Create a system message
|
||||
systemMessage := PromptMessage{
|
||||
Role: "system",
|
||||
Role: RoleAssistant,
|
||||
Content: TextContent{
|
||||
Type: "text",
|
||||
Text: "You are a helpful assistant.",
|
||||
},
|
||||
}
|
||||
|
||||
// Create a user message
|
||||
userMessage := PromptMessage{
|
||||
Role: "user",
|
||||
Role: RoleUser,
|
||||
Content: TextContent{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("Hi, I'm %s and I'm wondering: %s", username, question),
|
||||
},
|
||||
}
|
||||
@@ -2340,20 +2359,20 @@ func TestMethodPromptsGet(t *testing.T) {
|
||||
|
||||
// Check message content
|
||||
if len(messages) >= 2 {
|
||||
// First message should be system
|
||||
// First message should be assistant
|
||||
message1, _ := messages[0].(map[string]any)
|
||||
assert.Equal(t, "system", message1["role"], "First role should be 'system'")
|
||||
assert.Equal(t, string(RoleAssistant), message1["role"], "First role should be 'system'")
|
||||
|
||||
content1, _ := message1["content"].(map[string]any)
|
||||
assert.Contains(t, content1["text"], "helpful assistant", "System message should be correct")
|
||||
assert.Contains(t, content1[ContentTypeText], "helpful assistant", "System message should be correct")
|
||||
|
||||
// Second message should be user
|
||||
message2, _ := messages[1].(map[string]any)
|
||||
assert.Equal(t, "user", message2["role"], "Second role should be 'user'")
|
||||
assert.Equal(t, string(RoleUser), message2["role"], "Second role should be 'user'")
|
||||
|
||||
content2, _ := message2["content"].(map[string]any)
|
||||
assert.Contains(t, content2["text"], "Dynamic User", "User message should contain username")
|
||||
assert.Contains(t, content2["text"], "How to test this?", "User message should contain question")
|
||||
assert.Contains(t, content2[ContentTypeText], "Dynamic User", "User message should contain username")
|
||||
assert.Contains(t, content2[ContentTypeText], "How to test this?", "User message should contain question")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for prompt get response")
|
||||
@@ -2459,7 +2478,7 @@ func TestMethodPromptsGet(t *testing.T) {
|
||||
Name: "error-handler-prompt",
|
||||
Description: "A prompt with a handler that returns an error",
|
||||
Arguments: []PromptArgument{},
|
||||
Handler: func(args map[string]string) ([]PromptMessage, error) {
|
||||
Handler: func(ctx context.Context, args map[string]string) ([]PromptMessage, error) {
|
||||
return nil, fmt.Errorf("test handler error")
|
||||
},
|
||||
}
|
||||
@@ -2583,7 +2602,7 @@ func TestMethodResourcesList(t *testing.T) {
|
||||
URI: "file:///test/resource.txt",
|
||||
Description: "A test resource with handler",
|
||||
MimeType: "text/plain",
|
||||
Handler: func() (ResourceContent, error) {
|
||||
Handler: func(ctx context.Context) (ResourceContent, error) {
|
||||
return ResourceContent{
|
||||
URI: "file:///test/resource.txt",
|
||||
MimeType: "text/plain",
|
||||
@@ -2654,7 +2673,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
||||
URI: "file:///test/resource.txt",
|
||||
Description: "A test resource with handler",
|
||||
MimeType: "text/plain",
|
||||
Handler: func() (ResourceContent, error) {
|
||||
Handler: func(ctx context.Context) (ResourceContent, error) {
|
||||
return ResourceContent{
|
||||
URI: "file:///test/resource.txt",
|
||||
MimeType: "text/plain",
|
||||
@@ -2729,7 +2748,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
||||
require.True(t, ok, "Content should be an object")
|
||||
assert.Equal(t, "file:///test/resource.txt", content["uri"], "URI should match")
|
||||
assert.Equal(t, "text/plain", content["mimeType"], "MimeType should match")
|
||||
assert.Equal(t, "This is test resource content", content["text"], "Text content should match")
|
||||
assert.Equal(t, "This is test resource content", content[ContentTypeText], "Text content should match")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for resource read response")
|
||||
@@ -2799,7 +2818,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
||||
require.True(t, ok, "Content should be an object")
|
||||
assert.Equal(t, "file:///test/no-handler.txt", content["uri"], "URI should match")
|
||||
assert.Equal(t, "text/plain", content["mimeType"], "MimeType should match")
|
||||
_, ok = content["text"]
|
||||
_, ok = content[ContentTypeText]
|
||||
assert.False(t, ok, "Text content should be empty string")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
@@ -2880,7 +2899,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
||||
client := addTestClient(mock.server, "test-client-resources", true)
|
||||
|
||||
// Process through handleRequest
|
||||
mock.server.processResourcesRead(client, req)
|
||||
mock.server.processResourcesRead(context.Background(), client, req)
|
||||
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
@@ -2898,7 +2917,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
||||
URI: "file:///test/error.txt",
|
||||
Description: "A test resource with handler that returns error",
|
||||
MimeType: "text/plain",
|
||||
Handler: func() (ResourceContent, error) {
|
||||
Handler: func(ctx context.Context) (ResourceContent, error) {
|
||||
return ResourceContent{}, fmt.Errorf("test handler error")
|
||||
},
|
||||
}
|
||||
@@ -2946,7 +2965,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
||||
URI: "file:///test/missing-fields.txt",
|
||||
Description: "A test resource with handler that returns content missing fields",
|
||||
MimeType: "text/plain",
|
||||
Handler: func() (ResourceContent, error) {
|
||||
Handler: func(ctx context.Context) (ResourceContent, error) {
|
||||
// Return ResourceContent without URI and MimeType
|
||||
return ResourceContent{
|
||||
Text: "Content with missing fields",
|
||||
@@ -3006,7 +3025,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
||||
require.True(t, ok, "Content should be an object")
|
||||
assert.Equal(t, "file:///test/missing-fields.txt", content["uri"], "URI should be filled from request")
|
||||
assert.Equal(t, "text/plain", content["mimeType"], "MimeType should be filled from resource")
|
||||
assert.Equal(t, "Content with missing fields", content["text"], "Text content should match")
|
||||
assert.Equal(t, "Content with missing fields", content[ContentTypeText], "Text content should match")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for resource read response")
|
||||
@@ -3159,7 +3178,7 @@ func TestMethodResourcesSubscribe(t *testing.T) {
|
||||
}
|
||||
|
||||
client := addTestClient(mock.server, "test-client-sub-not-found", true)
|
||||
mock.server.processResourceSubscribe(client, req)
|
||||
mock.server.processResourceSubscribe(context.Background(), client, req)
|
||||
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
@@ -3268,7 +3287,7 @@ func TestToolCallUnmarshalError(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call directly
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Check for error response about invalid JSON
|
||||
select {
|
||||
@@ -3316,7 +3335,7 @@ func TestToolCallWithInvalidParams(t *testing.T) {
|
||||
jsonBody, _ := json.Marshal(req)
|
||||
|
||||
// Create HTTP request
|
||||
r := httptest.NewRequest("POST", "/?session_id=test-client-invalid-json", bytes.NewReader(jsonBody))
|
||||
r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-invalid-json", bytes.NewReader(jsonBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Process through handleRequest
|
||||
|
||||
Reference in New Issue
Block a user