From d68cf4920c879005bbc2c422ed03677165ddb4e8 Mon Sep 17 00:00:00 2001 From: guonaihong Date: Thu, 14 Aug 2025 22:32:09 +0800 Subject: [PATCH] fix: 5076 (#5083) Co-authored-by: Kevin Wan Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- gateway/internal/eventhandler.go | 24 ++++- gateway/internal/eventhandler_test.go | 144 ++++++++++++++++++++++++++ 2 files changed, 165 insertions(+), 3 deletions(-) diff --git a/gateway/internal/eventhandler.go b/gateway/internal/eventhandler.go index e3a0f0c2d..84dde7f5f 100644 --- a/gateway/internal/eventhandler.go +++ b/gateway/internal/eventhandler.go @@ -2,6 +2,7 @@ package internal import ( "io" + "net/http" "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/proto" @@ -33,15 +34,32 @@ func (h *EventHandler) OnReceiveResponse(message proto.Message) { } } -func (h *EventHandler) OnReceiveTrailers(status *status.Status, _ metadata.MD) { +func (h *EventHandler) OnReceiveTrailers(status *status.Status, md metadata.MD) { + w, ok := h.writer.(http.ResponseWriter) + if ok { + for k, v := range md { + for _, v2 := range v { + w.Header().Add(http.CanonicalHeaderKey(k), v2) + } + } + } h.Status = status } - func (h *EventHandler) OnResolveMethod(_ *desc.MethodDescriptor) { } func (h *EventHandler) OnSendHeaders(_ metadata.MD) { } -func (h *EventHandler) OnReceiveHeaders(_ metadata.MD) { +func (h *EventHandler) OnReceiveHeaders(md metadata.MD) { + w, ok := h.writer.(http.ResponseWriter) + if ok { + for k, v := range md { + for _, v2 := range v { + canonicalKey := http.CanonicalHeaderKey(k) + for _, v2 := range v { + w.Header().Add(canonicalKey, v2) + } + } + } } diff --git a/gateway/internal/eventhandler_test.go b/gateway/internal/eventhandler_test.go index cf81ace1b..4fd2b56fd 100644 --- a/gateway/internal/eventhandler_test.go +++ b/gateway/internal/eventhandler_test.go @@ -2,10 +2,12 @@ package internal import ( "io" + "net/http/httptest" "testing" "github.com/stretchr/testify/assert" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) @@ -18,3 +20,145 @@ func TestEventHandler(t *testing.T) { assert.Equal(t, codes.OK, h.Status.Code()) h.OnReceiveResponse(nil) } + +func TestEventHandler_OnReceiveTrailers(t *testing.T) { + tests := []struct { + name string + writer io.Writer + status *status.Status + metadata metadata.MD + expectedStatus codes.Code + expectedHeader map[string][]string + }{ + { + name: "with http.ResponseWriter and metadata", + writer: httptest.NewRecorder(), + status: status.New(codes.OK, "success"), + metadata: metadata.MD{ + "x-custom-header": []string{"value1", "value2"}, + "x-another-header": []string{"single-value"}, + }, + expectedStatus: codes.OK, + expectedHeader: map[string][]string{ + "X-Custom-Header": {"value1", "value2"}, + "X-Another-Header": {"single-value"}, + }, + }, + { + name: "with http.ResponseWriter and nil metadata", + writer: httptest.NewRecorder(), + status: status.New(codes.Internal, "error"), + metadata: nil, + expectedStatus: codes.Internal, + expectedHeader: map[string][]string{}, + }, + { + name: "with non-http.ResponseWriter", + writer: io.Discard, + status: status.New(codes.OK, "success"), + metadata: metadata.MD{"x-header": []string{"value"}}, + expectedStatus: codes.OK, + expectedHeader: nil, // headers should not be set + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := NewEventHandler(tt.writer, nil) + + h.OnReceiveTrailers(tt.status, tt.metadata) + + // Check status is set correctly + assert.Equal(t, tt.expectedStatus, h.Status.Code()) + + // Check headers are set correctly if writer is http.ResponseWriter + if recorder, ok := tt.writer.(*httptest.ResponseRecorder); ok { + if tt.expectedHeader != nil { + for key, expectedValues := range tt.expectedHeader { + actualValues := recorder.Header()[key] + assert.Equal(t, expectedValues, actualValues, "Header %s should match", key) + } + } + } + }) + } +} + +func TestEventHandler_OnReceiveHeaders(t *testing.T) { + tests := []struct { + name string + writer io.Writer + metadata metadata.MD + expectedHeader map[string][]string + }{ + { + name: "with http.ResponseWriter and metadata", + writer: httptest.NewRecorder(), + metadata: metadata.MD{ + "content-type": []string{"application/json"}, + "x-custom-header": []string{"value1", "value2"}, + "x-another-header": []string{"single-value"}, + }, + expectedHeader: map[string][]string{ + "Content-Type": {"application/json"}, + "X-Custom-Header": {"value1", "value2"}, + "X-Another-Header": {"single-value"}, + }, + }, + { + name: "with http.ResponseWriter and nil metadata", + writer: httptest.NewRecorder(), + metadata: nil, + expectedHeader: map[string][]string{}, + }, + { + name: "with http.ResponseWriter and empty metadata", + writer: httptest.NewRecorder(), + metadata: metadata.MD{}, + expectedHeader: map[string][]string{}, + }, + { + name: "with non-http.ResponseWriter", + writer: io.Discard, + metadata: metadata.MD{"x-header": []string{"value"}}, + expectedHeader: nil, // headers should not be set + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := NewEventHandler(tt.writer, nil) + + h.OnReceiveHeaders(tt.metadata) + + // Check headers are set correctly if writer is http.ResponseWriter + if recorder, ok := tt.writer.(*httptest.ResponseRecorder); ok { + if tt.expectedHeader != nil { + for key, expectedValues := range tt.expectedHeader { + actualValues := recorder.Header()[key] + assert.Equal(t, expectedValues, actualValues, "Header %s should match", key) + } + } + } + }) + } +} + +func TestEventHandler_OnReceiveHeaders_MultipleValues(t *testing.T) { + recorder := httptest.NewRecorder() + h := NewEventHandler(recorder, nil) + + // Test that multiple calls to OnReceiveHeaders accumulate headers + h.OnReceiveHeaders(metadata.MD{ + "x-header-1": []string{"value1"}, + }) + + h.OnReceiveHeaders(metadata.MD{ + "x-header-1": []string{"value2"}, // Should add to existing header + "x-header-2": []string{"value3"}, + }) + + // Check that headers are accumulated (not overwritten) + assert.Equal(t, []string{"value1", "value2"}, recorder.Header()["X-Header-1"]) + assert.Equal(t, []string{"value3"}, recorder.Header()["X-Header-2"]) +}