diff --git a/gateway/internal/eventhandler.go b/gateway/internal/eventhandler.go index 94d203b01..a09e5c64d 100644 --- a/gateway/internal/eventhandler.go +++ b/gateway/internal/eventhandler.go @@ -1,6 +1,7 @@ package internal import ( + "fmt" "io" "net/http" @@ -12,6 +13,22 @@ import ( "google.golang.org/grpc/status" ) +// MetadataHeaderPrefix is the http prefix that represents custom metadata +// parameters to or from a gRPC call. +const MetadataHeaderPrefix = "Grpc-Metadata-" + +// MetadataTrailerPrefix is prepended to gRPC metadata as it is converted to +// HTTP headers in a response handled by go-zero gateway +const MetadataTrailerPrefix = "Grpc-Trailer-" + +func defaultOutgoingHeaderMatcher(key string) (string, bool) { + return fmt.Sprintf("%s%s", MetadataHeaderPrefix, key), true +} + +func defaultOutgoingTrailerMatcher(key string) (string, bool) { + return fmt.Sprintf("%s%s", MetadataTrailerPrefix, key), true +} + type EventHandler struct { Status *status.Status writer io.Writer @@ -31,9 +48,11 @@ func NewEventHandler(writer io.Writer, resolver jsonpb.AnyResolver) *EventHandle func (h *EventHandler) OnReceiveHeaders(md metadata.MD) { w, ok := h.writer.(http.ResponseWriter) if ok { - for k, v := range md { - for _, val := range v { - w.Header().Add(k, val) + for k, vs := range md { + if h, ok := defaultOutgoingHeaderMatcher(k); ok { + for _, v := range vs { + w.Header().Add(h, v) + } } } } @@ -48,9 +67,11 @@ func (h *EventHandler) OnReceiveResponse(message proto.Message) { func (h *EventHandler) OnReceiveTrailers(status *status.Status, md metadata.MD) { w, ok := h.writer.(http.ResponseWriter) if ok { - for k, v := range md { - for _, val := range v { - w.Header().Add(k, val) + for k, vs := range md { + if h, ok := defaultOutgoingTrailerMatcher(k); ok { + for _, v := range vs { + w.Header().Add(h, v) + } } } } diff --git a/gateway/internal/eventhandler_test.go b/gateway/internal/eventhandler_test.go index 4fd2b56fd..ca7afa6dc 100644 --- a/gateway/internal/eventhandler_test.go +++ b/gateway/internal/eventhandler_test.go @@ -40,8 +40,8 @@ func TestEventHandler_OnReceiveTrailers(t *testing.T) { }, expectedStatus: codes.OK, expectedHeader: map[string][]string{ - "X-Custom-Header": {"value1", "value2"}, - "X-Another-Header": {"single-value"}, + "Grpc-Trailer-X-Custom-Header": {"value1", "value2"}, + "Grpc-Trailer-X-Another-Header": {"single-value"}, }, }, { @@ -100,9 +100,9 @@ func TestEventHandler_OnReceiveHeaders(t *testing.T) { "x-another-header": []string{"single-value"}, }, expectedHeader: map[string][]string{ - "Content-Type": {"application/json"}, - "X-Custom-Header": {"value1", "value2"}, - "X-Another-Header": {"single-value"}, + "Grpc-Metadata-Content-Type": {"application/json"}, + "Grpc-Metadata-X-Custom-Header": {"value1", "value2"}, + "Grpc-Metadata-X-Another-Header": {"single-value"}, }, }, { @@ -158,7 +158,81 @@ func TestEventHandler_OnReceiveHeaders_MultipleValues(t *testing.T) { "x-header-2": []string{"value3"}, }) - // Check that headers are accumulated (not overwritten) - assert.Equal(t, []string{"value1", "value2"}, recorder.Header()["X-Header-1"]) - assert.Equal(t, []string{"value3"}, recorder.Header()["X-Header-2"]) + // Check that headers are accumulated (not overwritten) with proper prefix + assert.Equal(t, []string{"value1", "value2"}, recorder.Header()["Grpc-Metadata-X-Header-1"]) + assert.Equal(t, []string{"value3"}, recorder.Header()["Grpc-Metadata-X-Header-2"]) +} + +func TestEventHandler_OnReceiveHeaders_MetadataPrefix(t *testing.T) { + tests := []struct { + name string + metadata metadata.MD + expectedHeader map[string][]string + }{ + { + name: "all metadata headers should be prefixed with Grpc-Metadata-", + metadata: metadata.MD{ + "content-type": []string{"application/grpc"}, + "x-custom-header": []string{"value1"}, + "authorization": []string{"Bearer token"}, + }, + expectedHeader: map[string][]string{ + "Grpc-Metadata-Content-Type": {"application/grpc"}, + "Grpc-Metadata-X-Custom-Header": {"value1"}, + "Grpc-Metadata-Authorization": {"Bearer token"}, + }, + }, + { + name: "mixed case headers should be prefixed", + metadata: metadata.MD{ + "Content-Type": []string{"APPLICATION/JSON"}, + "X-Custom-Header": []string{"value1"}, + }, + expectedHeader: map[string][]string{ + "Grpc-Metadata-Content-Type": {"APPLICATION/JSON"}, + "Grpc-Metadata-X-Custom-Header": {"value1"}, + }, + }, + { + name: "multiple values for same header", + metadata: metadata.MD{ + "x-multi-header": []string{"value1", "value2", "value3"}, + }, + expectedHeader: map[string][]string{ + "Grpc-Metadata-X-Multi-Header": {"value1", "value2", "value3"}, + }, + }, + { + name: "empty metadata", + metadata: metadata.MD{}, + expectedHeader: map[string][]string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + h := NewEventHandler(recorder, nil) + + h.OnReceiveHeaders(tt.metadata) + + // Check that headers are set correctly + for key, expectedValues := range tt.expectedHeader { + actualValues := recorder.Header()[key] + assert.Equal(t, expectedValues, actualValues, "Header %s should match", key) + } + + // Ensure no unexpected headers are set + for actualKey := range recorder.Header() { + found := false + for expectedKey := range tt.expectedHeader { + if actualKey == expectedKey { + found = true + break + } + } + assert.True(t, found, "Unexpected header found: %s", actualKey) + } + }) + } }