mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-07 06:59:59 +08:00
186 lines
5.3 KiB
Go
186 lines
5.3 KiB
Go
|
|
package mcp
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"net/http"
|
||
|
|
"net/http/httptest"
|
||
|
|
"testing"
|
||
|
|
|
||
|
|
"github.com/stretchr/testify/assert"
|
||
|
|
"github.com/zeromicro/go-zero/rest/pathvar"
|
||
|
|
)
|
||
|
|
|
||
|
|
func TestDefaultRequestMetadataExtractor(t *testing.T) {
|
||
|
|
req := httptest.NewRequest(http.MethodGet, "/sse?tenant=t1&trace=abc", nil)
|
||
|
|
req.Header.Add("X-Tenant-Id", "tenant-from-header")
|
||
|
|
req = pathvar.WithVars(req, map[string]string{"tool": "sum"})
|
||
|
|
|
||
|
|
metadata := DefaultRequestMetadataExtractor(req)
|
||
|
|
header, ok := metadata.Headers["X-Tenant-Id"]
|
||
|
|
assert.True(t, ok)
|
||
|
|
assert.Equal(t, []string{"tenant-from-header"}, header)
|
||
|
|
assert.Equal(t, []string{"t1"}, metadata.Query["tenant"])
|
||
|
|
assert.Equal(t, "sum", metadata.Path["tool"])
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRequestMetadataContextHelpers(t *testing.T) {
|
||
|
|
ctx := context.WithValue(context.Background(), requestMetadataCtxKey{}, RequestMetadata{
|
||
|
|
Headers: map[string][]string{"X-Trace-Id": {"trace-1"}},
|
||
|
|
Query: map[string][]string{"tenant": {"foo"}},
|
||
|
|
Path: map[string]string{"scope": "prod"},
|
||
|
|
})
|
||
|
|
|
||
|
|
metadata, ok := RequestMetadataFromContext(ctx)
|
||
|
|
assert.True(t, ok)
|
||
|
|
assert.Equal(t, []string{"trace-1"}, metadata.Headers["X-Trace-Id"])
|
||
|
|
|
||
|
|
header, ok := HeaderFromContext(ctx, "x-trace-id")
|
||
|
|
assert.True(t, ok)
|
||
|
|
assert.Equal(t, "trace-1", header)
|
||
|
|
|
||
|
|
query, ok := QueryFromContext(ctx, "tenant")
|
||
|
|
assert.True(t, ok)
|
||
|
|
assert.Equal(t, "foo", query)
|
||
|
|
|
||
|
|
path, ok := PathFromContext(ctx, "scope")
|
||
|
|
assert.True(t, ok)
|
||
|
|
assert.Equal(t, "prod", path)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRequestMetadataContextHelpersMissingKeys(t *testing.T) {
|
||
|
|
ctx := context.WithValue(context.Background(), requestMetadataCtxKey{}, RequestMetadata{
|
||
|
|
Headers: map[string][]string{"X-Trace-Id": {"trace-1"}},
|
||
|
|
Query: map[string][]string{"tenant": {"foo"}},
|
||
|
|
Path: map[string]string{"scope": "prod"},
|
||
|
|
})
|
||
|
|
|
||
|
|
_, ok := HeaderFromContext(ctx, "x-missing")
|
||
|
|
assert.False(t, ok)
|
||
|
|
|
||
|
|
_, ok = QueryFromContext(ctx, "missing")
|
||
|
|
assert.False(t, ok)
|
||
|
|
|
||
|
|
_, ok = PathFromContext(ctx, "missing")
|
||
|
|
assert.False(t, ok)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRequestMetadataFromContextNotFound(t *testing.T) {
|
||
|
|
_, ok := RequestMetadataFromContext(context.Background())
|
||
|
|
assert.False(t, ok)
|
||
|
|
|
||
|
|
_, ok = HeaderFromContext(context.Background(), "x-test")
|
||
|
|
assert.False(t, ok)
|
||
|
|
|
||
|
|
_, ok = QueryFromContext(context.Background(), "tenant")
|
||
|
|
assert.False(t, ok)
|
||
|
|
|
||
|
|
_, ok = PathFromContext(context.Background(), "tenant")
|
||
|
|
assert.False(t, ok)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestWrapRequestMetadata(t *testing.T) {
|
||
|
|
s := &mcpServerImpl{
|
||
|
|
options: serverOptions{
|
||
|
|
requestMetadataExtractor: DefaultRequestMetadataExtractor,
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
called := false
|
||
|
|
handler := s.wrapRequestMetadata(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||
|
|
called = true
|
||
|
|
header, ok := HeaderFromContext(r.Context(), "x-tenant-id")
|
||
|
|
assert.True(t, ok)
|
||
|
|
assert.Equal(t, "tenant-1", header)
|
||
|
|
|
||
|
|
query, ok := QueryFromContext(r.Context(), "tenant")
|
||
|
|
assert.True(t, ok)
|
||
|
|
assert.Equal(t, "q-tenant", query)
|
||
|
|
}))
|
||
|
|
|
||
|
|
req := httptest.NewRequest(http.MethodGet, "/sse?tenant=q-tenant", nil)
|
||
|
|
req.Header.Set("X-Tenant-Id", "tenant-1")
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(rr, req)
|
||
|
|
|
||
|
|
assert.True(t, called)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestWrapRequestMetadataNoExtractor(t *testing.T) {
|
||
|
|
s := &mcpServerImpl{}
|
||
|
|
|
||
|
|
called := false
|
||
|
|
handler := s.wrapRequestMetadata(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||
|
|
called = true
|
||
|
|
_, ok := RequestMetadataFromContext(r.Context())
|
||
|
|
assert.False(t, ok)
|
||
|
|
}))
|
||
|
|
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/sse", nil))
|
||
|
|
|
||
|
|
assert.True(t, called)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestWrapRequestMetadataCanonicalizesCustomHeaders(t *testing.T) {
|
||
|
|
s := &mcpServerImpl{
|
||
|
|
options: serverOptions{
|
||
|
|
requestMetadataExtractor: func(*http.Request) RequestMetadata {
|
||
|
|
return RequestMetadata{
|
||
|
|
Headers: map[string][]string{
|
||
|
|
"x-tenant-id": {"tenant-lower"},
|
||
|
|
},
|
||
|
|
}
|
||
|
|
},
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
called := false
|
||
|
|
handler := s.wrapRequestMetadata(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||
|
|
called = true
|
||
|
|
header, ok := HeaderFromContext(r.Context(), "X-Tenant-Id")
|
||
|
|
assert.True(t, ok)
|
||
|
|
assert.Equal(t, "tenant-lower", header)
|
||
|
|
}))
|
||
|
|
|
||
|
|
rr := httptest.NewRecorder()
|
||
|
|
handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/sse", nil))
|
||
|
|
|
||
|
|
assert.True(t, called)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRequestMetadataFromContextReturnsCopy(t *testing.T) {
|
||
|
|
ctx := context.WithValue(context.Background(), requestMetadataCtxKey{}, RequestMetadata{
|
||
|
|
Headers: map[string][]string{"X-Trace-Id": {"trace-1"}},
|
||
|
|
})
|
||
|
|
|
||
|
|
metadata, ok := RequestMetadataFromContext(ctx)
|
||
|
|
assert.True(t, ok)
|
||
|
|
metadata.Headers["X-Trace-Id"][0] = "mutated"
|
||
|
|
metadata.Headers["X-New"] = []string{"new"}
|
||
|
|
|
||
|
|
fresh, ok := RequestMetadataFromContext(ctx)
|
||
|
|
assert.True(t, ok)
|
||
|
|
assert.Equal(t, []string{"trace-1"}, fresh.Headers["X-Trace-Id"])
|
||
|
|
assert.Nil(t, fresh.Headers["X-New"])
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRequestMetadataFromContextWithEmptyAndCanonicalizedHeaders(t *testing.T) {
|
||
|
|
emptyCtx := context.WithValue(context.Background(), requestMetadataCtxKey{}, RequestMetadata{})
|
||
|
|
empty, ok := RequestMetadataFromContext(emptyCtx)
|
||
|
|
assert.True(t, ok)
|
||
|
|
assert.Nil(t, empty.Headers)
|
||
|
|
assert.Nil(t, empty.Query)
|
||
|
|
assert.Nil(t, empty.Path)
|
||
|
|
|
||
|
|
ctx := context.WithValue(context.Background(), requestMetadataCtxKey{}, RequestMetadata{
|
||
|
|
Headers: map[string][]string{
|
||
|
|
"x-tenant-id": {"a"},
|
||
|
|
"X-Tenant-Id": {"b"},
|
||
|
|
},
|
||
|
|
})
|
||
|
|
|
||
|
|
metadata, ok := RequestMetadataFromContext(ctx)
|
||
|
|
assert.True(t, ok)
|
||
|
|
assert.Equal(t, []string{"a", "b"}, metadata.Headers["X-Tenant-Id"])
|
||
|
|
}
|