mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-10 08:29:58 +08:00
fix: 5076 (#5083)
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -2,6 +2,7 @@ package internal
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"github.com/golang/protobuf/jsonpb"
|
"github.com/golang/protobuf/jsonpb"
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
@@ -33,15 +34,32 @@ func (h *EventHandler) OnReceiveResponse(message proto.Message) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *EventHandler) OnReceiveTrailers(status *status.Status, _ metadata.MD) {
|
func (h *EventHandler) OnReceiveTrailers(status *status.Status, md metadata.MD) {
|
||||||
|
w, ok := h.writer.(http.ResponseWriter)
|
||||||
|
if ok {
|
||||||
|
for k, v := range md {
|
||||||
|
for _, v2 := range v {
|
||||||
|
w.Header().Add(http.CanonicalHeaderKey(k), v2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
h.Status = status
|
h.Status = status
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *EventHandler) OnResolveMethod(_ *desc.MethodDescriptor) {
|
func (h *EventHandler) OnResolveMethod(_ *desc.MethodDescriptor) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *EventHandler) OnSendHeaders(_ metadata.MD) {
|
func (h *EventHandler) OnSendHeaders(_ metadata.MD) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *EventHandler) OnReceiveHeaders(_ metadata.MD) {
|
func (h *EventHandler) OnReceiveHeaders(md metadata.MD) {
|
||||||
|
w, ok := h.writer.(http.ResponseWriter)
|
||||||
|
if ok {
|
||||||
|
for k, v := range md {
|
||||||
|
for _, v2 := range v {
|
||||||
|
canonicalKey := http.CanonicalHeaderKey(k)
|
||||||
|
for _, v2 := range v {
|
||||||
|
w.Header().Add(canonicalKey, v2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,10 +2,12 @@ package internal
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/metadata"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -18,3 +20,145 @@ func TestEventHandler(t *testing.T) {
|
|||||||
assert.Equal(t, codes.OK, h.Status.Code())
|
assert.Equal(t, codes.OK, h.Status.Code())
|
||||||
h.OnReceiveResponse(nil)
|
h.OnReceiveResponse(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEventHandler_OnReceiveTrailers(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
writer io.Writer
|
||||||
|
status *status.Status
|
||||||
|
metadata metadata.MD
|
||||||
|
expectedStatus codes.Code
|
||||||
|
expectedHeader map[string][]string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with http.ResponseWriter and metadata",
|
||||||
|
writer: httptest.NewRecorder(),
|
||||||
|
status: status.New(codes.OK, "success"),
|
||||||
|
metadata: metadata.MD{
|
||||||
|
"x-custom-header": []string{"value1", "value2"},
|
||||||
|
"x-another-header": []string{"single-value"},
|
||||||
|
},
|
||||||
|
expectedStatus: codes.OK,
|
||||||
|
expectedHeader: map[string][]string{
|
||||||
|
"X-Custom-Header": {"value1", "value2"},
|
||||||
|
"X-Another-Header": {"single-value"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with http.ResponseWriter and nil metadata",
|
||||||
|
writer: httptest.NewRecorder(),
|
||||||
|
status: status.New(codes.Internal, "error"),
|
||||||
|
metadata: nil,
|
||||||
|
expectedStatus: codes.Internal,
|
||||||
|
expectedHeader: map[string][]string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with non-http.ResponseWriter",
|
||||||
|
writer: io.Discard,
|
||||||
|
status: status.New(codes.OK, "success"),
|
||||||
|
metadata: metadata.MD{"x-header": []string{"value"}},
|
||||||
|
expectedStatus: codes.OK,
|
||||||
|
expectedHeader: nil, // headers should not be set
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
h := NewEventHandler(tt.writer, nil)
|
||||||
|
|
||||||
|
h.OnReceiveTrailers(tt.status, tt.metadata)
|
||||||
|
|
||||||
|
// Check status is set correctly
|
||||||
|
assert.Equal(t, tt.expectedStatus, h.Status.Code())
|
||||||
|
|
||||||
|
// Check headers are set correctly if writer is http.ResponseWriter
|
||||||
|
if recorder, ok := tt.writer.(*httptest.ResponseRecorder); ok {
|
||||||
|
if tt.expectedHeader != nil {
|
||||||
|
for key, expectedValues := range tt.expectedHeader {
|
||||||
|
actualValues := recorder.Header()[key]
|
||||||
|
assert.Equal(t, expectedValues, actualValues, "Header %s should match", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventHandler_OnReceiveHeaders(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
writer io.Writer
|
||||||
|
metadata metadata.MD
|
||||||
|
expectedHeader map[string][]string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with http.ResponseWriter and metadata",
|
||||||
|
writer: httptest.NewRecorder(),
|
||||||
|
metadata: metadata.MD{
|
||||||
|
"content-type": []string{"application/json"},
|
||||||
|
"x-custom-header": []string{"value1", "value2"},
|
||||||
|
"x-another-header": []string{"single-value"},
|
||||||
|
},
|
||||||
|
expectedHeader: map[string][]string{
|
||||||
|
"Content-Type": {"application/json"},
|
||||||
|
"X-Custom-Header": {"value1", "value2"},
|
||||||
|
"X-Another-Header": {"single-value"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with http.ResponseWriter and nil metadata",
|
||||||
|
writer: httptest.NewRecorder(),
|
||||||
|
metadata: nil,
|
||||||
|
expectedHeader: map[string][]string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with http.ResponseWriter and empty metadata",
|
||||||
|
writer: httptest.NewRecorder(),
|
||||||
|
metadata: metadata.MD{},
|
||||||
|
expectedHeader: map[string][]string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with non-http.ResponseWriter",
|
||||||
|
writer: io.Discard,
|
||||||
|
metadata: metadata.MD{"x-header": []string{"value"}},
|
||||||
|
expectedHeader: nil, // headers should not be set
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
h := NewEventHandler(tt.writer, nil)
|
||||||
|
|
||||||
|
h.OnReceiveHeaders(tt.metadata)
|
||||||
|
|
||||||
|
// Check headers are set correctly if writer is http.ResponseWriter
|
||||||
|
if recorder, ok := tt.writer.(*httptest.ResponseRecorder); ok {
|
||||||
|
if tt.expectedHeader != nil {
|
||||||
|
for key, expectedValues := range tt.expectedHeader {
|
||||||
|
actualValues := recorder.Header()[key]
|
||||||
|
assert.Equal(t, expectedValues, actualValues, "Header %s should match", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventHandler_OnReceiveHeaders_MultipleValues(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
h := NewEventHandler(recorder, nil)
|
||||||
|
|
||||||
|
// Test that multiple calls to OnReceiveHeaders accumulate headers
|
||||||
|
h.OnReceiveHeaders(metadata.MD{
|
||||||
|
"x-header-1": []string{"value1"},
|
||||||
|
})
|
||||||
|
|
||||||
|
h.OnReceiveHeaders(metadata.MD{
|
||||||
|
"x-header-1": []string{"value2"}, // Should add to existing header
|
||||||
|
"x-header-2": []string{"value3"},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Check that headers are accumulated (not overwritten)
|
||||||
|
assert.Equal(t, []string{"value1", "value2"}, recorder.Header()["X-Header-1"])
|
||||||
|
assert.Equal(t, []string{"value3"}, recorder.Header()["X-Header-2"])
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user