mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-07 15:10:01 +08:00
chore: add unit test for WithCodeResponseWriter (#5028)
Signed-off-by: kevin <wanjunfeng@gmail.com>
This commit is contained in:
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
"github.com/zeromicro/go-zero/core/codec"
|
||||
"github.com/zeromicro/go-zero/core/load"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/logc"
|
||||
"github.com/zeromicro/go-zero/core/stat"
|
||||
"github.com/zeromicro/go-zero/rest/chain"
|
||||
"github.com/zeromicro/go-zero/rest/handler"
|
||||
@@ -67,25 +67,6 @@ func (ng *engine) addRoutes(r featuredRoutes) {
|
||||
ng.mightUpdateTimeout(r)
|
||||
}
|
||||
|
||||
func buildSSERoutes(routes []Route) []Route {
|
||||
for i, route := range routes {
|
||||
h := route.Handler
|
||||
routes[i].Handler = func(w http.ResponseWriter, r *http.Request) {
|
||||
rc := http.NewResponseController(w)
|
||||
err := rc.SetWriteDeadline(time.Time{})
|
||||
if err != nil {
|
||||
logx.Errorf("set conn write deadline failed:%v", err)
|
||||
}
|
||||
w.Header().Set(header.ContentType, header.ContentTypeEventStream)
|
||||
w.Header().Set(header.CacheControl, header.CacheControlNoCache)
|
||||
w.Header().Set(header.Connection, header.ConnectionKeepAlive)
|
||||
h(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
return routes
|
||||
}
|
||||
|
||||
func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
|
||||
verifier func(chain.Chain) chain.Chain) chain.Chain {
|
||||
if fr.jwt.enabled {
|
||||
@@ -400,6 +381,27 @@ func (ng *engine) withNetworkTimeout() internal.StartOption {
|
||||
}
|
||||
}
|
||||
|
||||
func buildSSERoutes(routes []Route) []Route {
|
||||
for i, route := range routes {
|
||||
h := route.Handler
|
||||
routes[i].Handler = func(w http.ResponseWriter, r *http.Request) {
|
||||
// remove the default write deadline set by http.Server,
|
||||
// because SSE requires the connection to be kept alive indefinitely.
|
||||
rc := http.NewResponseController(w)
|
||||
if err := rc.SetWriteDeadline(time.Time{}); err != nil {
|
||||
logc.Errorf(r.Context(), "set conn write deadline failed: %v", err)
|
||||
}
|
||||
|
||||
w.Header().Set(header.ContentType, header.ContentTypeEventStream)
|
||||
w.Header().Set(header.CacheControl, header.CacheControlNoCache)
|
||||
w.Header().Set(header.Connection, header.ConnectionKeepAlive)
|
||||
h(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
return routes
|
||||
}
|
||||
|
||||
func convertMiddleware(ware Middleware) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return ware(next.ServeHTTP)
|
||||
|
||||
@@ -578,3 +578,7 @@ func (m mockedRouter) SetNotFoundHandler(_ http.Handler) {
|
||||
|
||||
func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) {
|
||||
}
|
||||
|
||||
func ptrOfDuration(d time.Duration) *time.Duration {
|
||||
return &d
|
||||
}
|
||||
|
||||
@@ -49,6 +49,12 @@ func (w *WithCodeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return nil, nil, errors.New("server doesn't support hijacking")
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying http.ResponseWriter.
|
||||
// This is used by http.ResponseController to unwrap the response writer.
|
||||
func (w *WithCodeResponseWriter) Unwrap() http.ResponseWriter {
|
||||
return w.Writer
|
||||
}
|
||||
|
||||
// Write writes bytes into w.
|
||||
func (w *WithCodeResponseWriter) Write(bytes []byte) (int, error) {
|
||||
return w.Writer.Write(bytes)
|
||||
@@ -59,8 +65,3 @@ func (w *WithCodeResponseWriter) WriteHeader(code int) {
|
||||
w.Writer.WriteHeader(code)
|
||||
w.Code = code
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying ResponseWriter.
|
||||
func (w *WithCodeResponseWriter) Unwrap() http.ResponseWriter {
|
||||
return w.Writer
|
||||
}
|
||||
|
||||
@@ -46,3 +46,15 @@ func TestWithCodeResponseWriter_Hijack(t *testing.T) {
|
||||
writer.Hijack()
|
||||
})
|
||||
}
|
||||
|
||||
func TestWithCodeResponseWriter_Unwrap(t *testing.T) {
|
||||
resp := httptest.NewRecorder()
|
||||
writer := NewWithCodeResponseWriter(resp)
|
||||
unwrapped := writer.Unwrap()
|
||||
assert.Equal(t, resp, unwrapped)
|
||||
|
||||
// Test with a nested WithCodeResponseWriter
|
||||
nestedWriter := NewWithCodeResponseWriter(writer)
|
||||
unwrappedNested := nestedWriter.Unwrap()
|
||||
assert.Equal(t, resp, unwrappedNested)
|
||||
}
|
||||
|
||||
@@ -293,7 +293,6 @@ func WithSignature(signature SignatureConf) RouteOption {
|
||||
func WithSSE() RouteOption {
|
||||
return func(r *featuredRoutes) {
|
||||
r.sse = true
|
||||
r.timeout = ptrOfDuration(0)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -335,10 +334,6 @@ func handleError(err error) {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
func ptrOfDuration(d time.Duration) *time.Duration {
|
||||
return &d
|
||||
}
|
||||
|
||||
func validateSecret(secret string) {
|
||||
if len(secret) < 8 {
|
||||
panic("secret's length can't be less than 8")
|
||||
|
||||
Reference in New Issue
Block a user