diff --git a/rest/engine.go b/rest/engine.go index 3be437799..87a3d340e 100644 --- a/rest/engine.go +++ b/rest/engine.go @@ -10,7 +10,7 @@ import ( "github.com/zeromicro/go-zero/core/codec" "github.com/zeromicro/go-zero/core/load" - "github.com/zeromicro/go-zero/core/logx" + "github.com/zeromicro/go-zero/core/logc" "github.com/zeromicro/go-zero/core/stat" "github.com/zeromicro/go-zero/rest/chain" "github.com/zeromicro/go-zero/rest/handler" @@ -67,25 +67,6 @@ func (ng *engine) addRoutes(r featuredRoutes) { ng.mightUpdateTimeout(r) } -func buildSSERoutes(routes []Route) []Route { - for i, route := range routes { - h := route.Handler - routes[i].Handler = func(w http.ResponseWriter, r *http.Request) { - rc := http.NewResponseController(w) - err := rc.SetWriteDeadline(time.Time{}) - if err != nil { - logx.Errorf("set conn write deadline failed:%v", err) - } - w.Header().Set(header.ContentType, header.ContentTypeEventStream) - w.Header().Set(header.CacheControl, header.CacheControlNoCache) - w.Header().Set(header.Connection, header.ConnectionKeepAlive) - h(w, r) - } - } - - return routes -} - func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain, verifier func(chain.Chain) chain.Chain) chain.Chain { if fr.jwt.enabled { @@ -400,6 +381,27 @@ func (ng *engine) withNetworkTimeout() internal.StartOption { } } +func buildSSERoutes(routes []Route) []Route { + for i, route := range routes { + h := route.Handler + routes[i].Handler = func(w http.ResponseWriter, r *http.Request) { + // remove the default write deadline set by http.Server, + // because SSE requires the connection to be kept alive indefinitely. + rc := http.NewResponseController(w) + if err := rc.SetWriteDeadline(time.Time{}); err != nil { + logc.Errorf(r.Context(), "set conn write deadline failed: %v", err) + } + + w.Header().Set(header.ContentType, header.ContentTypeEventStream) + w.Header().Set(header.CacheControl, header.CacheControlNoCache) + w.Header().Set(header.Connection, header.ConnectionKeepAlive) + h(w, r) + } + } + + return routes +} + func convertMiddleware(ware Middleware) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return ware(next.ServeHTTP) diff --git a/rest/engine_test.go b/rest/engine_test.go index 6f461294b..0a9aadb8e 100644 --- a/rest/engine_test.go +++ b/rest/engine_test.go @@ -578,3 +578,7 @@ func (m mockedRouter) SetNotFoundHandler(_ http.Handler) { func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) { } + +func ptrOfDuration(d time.Duration) *time.Duration { + return &d +} diff --git a/rest/internal/response/withcoderesponsewriter.go b/rest/internal/response/withcoderesponsewriter.go index 9975aef4c..9f818b7c6 100644 --- a/rest/internal/response/withcoderesponsewriter.go +++ b/rest/internal/response/withcoderesponsewriter.go @@ -49,6 +49,12 @@ func (w *WithCodeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { return nil, nil, errors.New("server doesn't support hijacking") } +// Unwrap returns the underlying http.ResponseWriter. +// This is used by http.ResponseController to unwrap the response writer. +func (w *WithCodeResponseWriter) Unwrap() http.ResponseWriter { + return w.Writer +} + // Write writes bytes into w. func (w *WithCodeResponseWriter) Write(bytes []byte) (int, error) { return w.Writer.Write(bytes) @@ -59,8 +65,3 @@ func (w *WithCodeResponseWriter) WriteHeader(code int) { w.Writer.WriteHeader(code) w.Code = code } - -// Unwrap returns the underlying ResponseWriter. -func (w *WithCodeResponseWriter) Unwrap() http.ResponseWriter { - return w.Writer -} diff --git a/rest/internal/response/withcoderesponsewriter_test.go b/rest/internal/response/withcoderesponsewriter_test.go index 00485d3d9..c72111766 100644 --- a/rest/internal/response/withcoderesponsewriter_test.go +++ b/rest/internal/response/withcoderesponsewriter_test.go @@ -46,3 +46,15 @@ func TestWithCodeResponseWriter_Hijack(t *testing.T) { writer.Hijack() }) } + +func TestWithCodeResponseWriter_Unwrap(t *testing.T) { + resp := httptest.NewRecorder() + writer := NewWithCodeResponseWriter(resp) + unwrapped := writer.Unwrap() + assert.Equal(t, resp, unwrapped) + + // Test with a nested WithCodeResponseWriter + nestedWriter := NewWithCodeResponseWriter(writer) + unwrappedNested := nestedWriter.Unwrap() + assert.Equal(t, resp, unwrappedNested) +} diff --git a/rest/server.go b/rest/server.go index 6cb7914b7..3809d876f 100644 --- a/rest/server.go +++ b/rest/server.go @@ -293,7 +293,6 @@ func WithSignature(signature SignatureConf) RouteOption { func WithSSE() RouteOption { return func(r *featuredRoutes) { r.sse = true - r.timeout = ptrOfDuration(0) } } @@ -335,10 +334,6 @@ func handleError(err error) { panic(err) } -func ptrOfDuration(d time.Duration) *time.Duration { - return &d -} - func validateSecret(secret string) { if len(secret) < 8 { panic("secret's length can't be less than 8")