chore: add unit test for WithCodeResponseWriter (#5028)

Signed-off-by: kevin <wanjunfeng@gmail.com>
This commit is contained in:
Kevin Wan
2025-07-25 21:45:47 +08:00
committed by GitHub
parent 0be63c3625
commit 25f37ca750
5 changed files with 44 additions and 30 deletions

View File

@@ -10,7 +10,7 @@ import (
"github.com/zeromicro/go-zero/core/codec"
"github.com/zeromicro/go-zero/core/load"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/logc"
"github.com/zeromicro/go-zero/core/stat"
"github.com/zeromicro/go-zero/rest/chain"
"github.com/zeromicro/go-zero/rest/handler"
@@ -67,25 +67,6 @@ func (ng *engine) addRoutes(r featuredRoutes) {
ng.mightUpdateTimeout(r)
}
func buildSSERoutes(routes []Route) []Route {
for i, route := range routes {
h := route.Handler
routes[i].Handler = func(w http.ResponseWriter, r *http.Request) {
rc := http.NewResponseController(w)
err := rc.SetWriteDeadline(time.Time{})
if err != nil {
logx.Errorf("set conn write deadline failed:%v", err)
}
w.Header().Set(header.ContentType, header.ContentTypeEventStream)
w.Header().Set(header.CacheControl, header.CacheControlNoCache)
w.Header().Set(header.Connection, header.ConnectionKeepAlive)
h(w, r)
}
}
return routes
}
func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
verifier func(chain.Chain) chain.Chain) chain.Chain {
if fr.jwt.enabled {
@@ -400,6 +381,27 @@ func (ng *engine) withNetworkTimeout() internal.StartOption {
}
}
func buildSSERoutes(routes []Route) []Route {
for i, route := range routes {
h := route.Handler
routes[i].Handler = func(w http.ResponseWriter, r *http.Request) {
// remove the default write deadline set by http.Server,
// because SSE requires the connection to be kept alive indefinitely.
rc := http.NewResponseController(w)
if err := rc.SetWriteDeadline(time.Time{}); err != nil {
logc.Errorf(r.Context(), "set conn write deadline failed: %v", err)
}
w.Header().Set(header.ContentType, header.ContentTypeEventStream)
w.Header().Set(header.CacheControl, header.CacheControlNoCache)
w.Header().Set(header.Connection, header.ConnectionKeepAlive)
h(w, r)
}
}
return routes
}
func convertMiddleware(ware Middleware) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return ware(next.ServeHTTP)

View File

@@ -578,3 +578,7 @@ func (m mockedRouter) SetNotFoundHandler(_ http.Handler) {
func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) {
}
func ptrOfDuration(d time.Duration) *time.Duration {
return &d
}

View File

@@ -49,6 +49,12 @@ func (w *WithCodeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, errors.New("server doesn't support hijacking")
}
// Unwrap returns the underlying http.ResponseWriter.
// This is used by http.ResponseController to unwrap the response writer.
func (w *WithCodeResponseWriter) Unwrap() http.ResponseWriter {
return w.Writer
}
// Write writes bytes into w.
func (w *WithCodeResponseWriter) Write(bytes []byte) (int, error) {
return w.Writer.Write(bytes)
@@ -59,8 +65,3 @@ func (w *WithCodeResponseWriter) WriteHeader(code int) {
w.Writer.WriteHeader(code)
w.Code = code
}
// Unwrap returns the underlying ResponseWriter.
func (w *WithCodeResponseWriter) Unwrap() http.ResponseWriter {
return w.Writer
}

View File

@@ -46,3 +46,15 @@ func TestWithCodeResponseWriter_Hijack(t *testing.T) {
writer.Hijack()
})
}
func TestWithCodeResponseWriter_Unwrap(t *testing.T) {
resp := httptest.NewRecorder()
writer := NewWithCodeResponseWriter(resp)
unwrapped := writer.Unwrap()
assert.Equal(t, resp, unwrapped)
// Test with a nested WithCodeResponseWriter
nestedWriter := NewWithCodeResponseWriter(writer)
unwrappedNested := nestedWriter.Unwrap()
assert.Equal(t, resp, unwrappedNested)
}

View File

@@ -293,7 +293,6 @@ func WithSignature(signature SignatureConf) RouteOption {
func WithSSE() RouteOption {
return func(r *featuredRoutes) {
r.sse = true
r.timeout = ptrOfDuration(0)
}
}
@@ -335,10 +334,6 @@ func handleError(err error) {
panic(err)
}
func ptrOfDuration(d time.Duration) *time.Duration {
return &d
}
func validateSecret(secret string) {
if len(secret) < 8 {
panic("secret's length can't be less than 8")