// Files
// go-zero/mcp/request_metadata_test.go
//
// 186 lines
// 5.3 KiB
// Go

package mcp
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/rest/pathvar"
)
// TestDefaultRequestMetadataExtractor checks that the metadata returned by
// DefaultRequestMetadataExtractor contains the header, query parameter and
// path variable that were set on the request.
func TestDefaultRequestMetadataExtractor(t *testing.T) {
	// Build a request carrying one value in each metadata source.
	r := httptest.NewRequest(http.MethodGet, "/sse?tenant=t1&trace=abc", nil)
	r.Header.Add("X-Tenant-Id", "tenant-from-header")
	r = pathvar.WithVars(r, map[string]string{"tool": "sum"})

	got := DefaultRequestMetadataExtractor(r)

	tenantHeader, found := got.Headers["X-Tenant-Id"]
	assert.True(t, found)
	assert.Equal(t, []string{"tenant-from-header"}, tenantHeader)
	assert.Equal(t, []string{"t1"}, got.Query["tenant"])
	assert.Equal(t, "sum", got.Path["tool"])
}
func TestRequestMetadataContextHelpers(t *testing.T) {
ctx := context.WithValue(context.Background(), requestMetadataCtxKey{}, RequestMetadata{
Headers: map[string][]string{"X-Trace-Id": {"trace-1"}},
Query: map[string][]string{"tenant": {"foo"}},
Path: map[string]string{"scope": "prod"},
})
metadata, ok := RequestMetadataFromContext(ctx)
assert.True(t, ok)
assert.Equal(t, []string{"trace-1"}, metadata.Headers["X-Trace-Id"])
header, ok := HeaderFromContext(ctx, "x-trace-id")
assert.True(t, ok)
assert.Equal(t, "trace-1", header)
query, ok := QueryFromContext(ctx, "tenant")
assert.True(t, ok)
assert.Equal(t, "foo", query)
path, ok := PathFromContext(ctx, "scope")
assert.True(t, ok)
assert.Equal(t, "prod", path)
}
func TestRequestMetadataContextHelpersMissingKeys(t *testing.T) {
ctx := context.WithValue(context.Background(), requestMetadataCtxKey{}, RequestMetadata{
Headers: map[string][]string{"X-Trace-Id": {"trace-1"}},
Query: map[string][]string{"tenant": {"foo"}},
Path: map[string]string{"scope": "prod"},
})
_, ok := HeaderFromContext(ctx, "x-missing")
assert.False(t, ok)
_, ok = QueryFromContext(ctx, "missing")
assert.False(t, ok)
_, ok = PathFromContext(ctx, "missing")
assert.False(t, ok)
}
// TestRequestMetadataFromContextNotFound checks that every accessor reports
// "not found" when the context carries no request metadata at all.
func TestRequestMetadataFromContextNotFound(t *testing.T) {
	bare := context.Background()

	_, metadataFound := RequestMetadataFromContext(bare)
	assert.False(t, metadataFound)

	_, headerFound := HeaderFromContext(bare, "x-test")
	assert.False(t, headerFound)

	_, queryFound := QueryFromContext(bare, "tenant")
	assert.False(t, queryFound)

	_, pathFound := PathFromContext(bare, "tenant")
	assert.False(t, pathFound)
}
// TestWrapRequestMetadata runs a request through wrapRequestMetadata with
// DefaultRequestMetadataExtractor configured and checks, from inside the
// wrapped handler, that the request's header and query parameter are
// readable from the request context.
func TestWrapRequestMetadata(t *testing.T) {
	srv := &mcpServerImpl{
		options: serverOptions{
			requestMetadataExtractor: DefaultRequestMetadataExtractor,
		},
	}

	invoked := false
	inner := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
		invoked = true
		ctx := r.Context()

		tenantHeader, found := HeaderFromContext(ctx, "x-tenant-id")
		assert.True(t, found)
		assert.Equal(t, "tenant-1", tenantHeader)

		tenantQuery, found := QueryFromContext(ctx, "tenant")
		assert.True(t, found)
		assert.Equal(t, "q-tenant", tenantQuery)
	})
	wrapped := srv.wrapRequestMetadata(inner)

	req := httptest.NewRequest(http.MethodGet, "/sse?tenant=q-tenant", nil)
	req.Header.Set("X-Tenant-Id", "tenant-1")
	wrapped.ServeHTTP(httptest.NewRecorder(), req)

	// Guard against the assertions above being skipped silently.
	assert.True(t, invoked)
}
// TestWrapRequestMetadataNoExtractor uses a zero-value server (no extractor
// configured) and checks that the wrapped handler still runs and finds no
// request metadata in its context.
func TestWrapRequestMetadataNoExtractor(t *testing.T) {
	srv := &mcpServerImpl{}

	invoked := false
	inner := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
		invoked = true
		_, found := RequestMetadataFromContext(r.Context())
		assert.False(t, found)
	})
	wrapped := srv.wrapRequestMetadata(inner)

	recorder := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/sse", nil)
	wrapped.ServeHTTP(recorder, req)

	// Guard against the assertion above being skipped silently.
	assert.True(t, invoked)
}
func TestWrapRequestMetadataCanonicalizesCustomHeaders(t *testing.T) {
s := &mcpServerImpl{
options: serverOptions{
requestMetadataExtractor: func(*http.Request) RequestMetadata {
return RequestMetadata{
Headers: map[string][]string{
"x-tenant-id": {"tenant-lower"},
},
}
},
},
}
called := false
handler := s.wrapRequestMetadata(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
called = true
header, ok := HeaderFromContext(r.Context(), "X-Tenant-Id")
assert.True(t, ok)
assert.Equal(t, "tenant-lower", header)
}))
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/sse", nil))
assert.True(t, called)
}
func TestRequestMetadataFromContextReturnsCopy(t *testing.T) {
ctx := context.WithValue(context.Background(), requestMetadataCtxKey{}, RequestMetadata{
Headers: map[string][]string{"X-Trace-Id": {"trace-1"}},
})
metadata, ok := RequestMetadataFromContext(ctx)
assert.True(t, ok)
metadata.Headers["X-Trace-Id"][0] = "mutated"
metadata.Headers["X-New"] = []string{"new"}
fresh, ok := RequestMetadataFromContext(ctx)
assert.True(t, ok)
assert.Equal(t, []string{"trace-1"}, fresh.Headers["X-Trace-Id"])
assert.Nil(t, fresh.Headers["X-New"])
}
// TestRequestMetadataFromContextWithEmptyAndCanonicalizedHeaders covers two
// inputs to RequestMetadataFromContext:
//
//   - a zero-value RequestMetadata, for which Headers, Query and Path are
//     expected to come back nil;
//   - headers stored under two spellings of the same name ("x-tenant-id" and
//     "X-Tenant-Id"), whose values are both expected under "X-Tenant-Id".
func TestRequestMetadataFromContextWithEmptyAndCanonicalizedHeaders(t *testing.T) {
	emptyCtx := context.WithValue(context.Background(), requestMetadataCtxKey{}, RequestMetadata{})
	empty, ok := RequestMetadataFromContext(emptyCtx)
	assert.True(t, ok)
	assert.Nil(t, empty.Headers)
	assert.Nil(t, empty.Query)
	assert.Nil(t, empty.Path)

	ctx := context.WithValue(context.Background(), requestMetadataCtxKey{}, RequestMetadata{
		Headers: map[string][]string{
			"x-tenant-id": {"a"},
			"X-Tenant-Id": {"b"},
		},
	})
	metadata, ok := RequestMetadataFromContext(ctx)
	assert.True(t, ok)
	// "a" and "b" originate from two distinct entries of a Go map, and the
	// language does not specify map iteration order. Compare the merged values
	// without regard to order so the test cannot depend on which entry happens
	// to be visited first.
	// NOTE(review): if the implementation guarantees a merge order, tighten
	// this back to assert.Equal with that order.
	assert.ElementsMatch(t, []string{"a", "b"}, metadata.Headers["X-Tenant-Id"])
}