mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-07 15:10:01 +08:00
chore: add unit test for WithCodeResponseWriter (#5028)
Signed-off-by: kevin <wanjunfeng@gmail.com>
This commit is contained in:
@@ -10,7 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/codec"
|
"github.com/zeromicro/go-zero/core/codec"
|
||||||
"github.com/zeromicro/go-zero/core/load"
|
"github.com/zeromicro/go-zero/core/load"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logc"
|
||||||
"github.com/zeromicro/go-zero/core/stat"
|
"github.com/zeromicro/go-zero/core/stat"
|
||||||
"github.com/zeromicro/go-zero/rest/chain"
|
"github.com/zeromicro/go-zero/rest/chain"
|
||||||
"github.com/zeromicro/go-zero/rest/handler"
|
"github.com/zeromicro/go-zero/rest/handler"
|
||||||
@@ -67,25 +67,6 @@ func (ng *engine) addRoutes(r featuredRoutes) {
|
|||||||
ng.mightUpdateTimeout(r)
|
ng.mightUpdateTimeout(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildSSERoutes(routes []Route) []Route {
|
|
||||||
for i, route := range routes {
|
|
||||||
h := route.Handler
|
|
||||||
routes[i].Handler = func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
rc := http.NewResponseController(w)
|
|
||||||
err := rc.SetWriteDeadline(time.Time{})
|
|
||||||
if err != nil {
|
|
||||||
logx.Errorf("set conn write deadline failed:%v", err)
|
|
||||||
}
|
|
||||||
w.Header().Set(header.ContentType, header.ContentTypeEventStream)
|
|
||||||
w.Header().Set(header.CacheControl, header.CacheControlNoCache)
|
|
||||||
w.Header().Set(header.Connection, header.ConnectionKeepAlive)
|
|
||||||
h(w, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return routes
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
|
func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
|
||||||
verifier func(chain.Chain) chain.Chain) chain.Chain {
|
verifier func(chain.Chain) chain.Chain) chain.Chain {
|
||||||
if fr.jwt.enabled {
|
if fr.jwt.enabled {
|
||||||
@@ -400,6 +381,27 @@ func (ng *engine) withNetworkTimeout() internal.StartOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildSSERoutes(routes []Route) []Route {
|
||||||
|
for i, route := range routes {
|
||||||
|
h := route.Handler
|
||||||
|
routes[i].Handler = func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// remove the default write deadline set by http.Server,
|
||||||
|
// because SSE requires the connection to be kept alive indefinitely.
|
||||||
|
rc := http.NewResponseController(w)
|
||||||
|
if err := rc.SetWriteDeadline(time.Time{}); err != nil {
|
||||||
|
logc.Errorf(r.Context(), "set conn write deadline failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set(header.ContentType, header.ContentTypeEventStream)
|
||||||
|
w.Header().Set(header.CacheControl, header.CacheControlNoCache)
|
||||||
|
w.Header().Set(header.Connection, header.ConnectionKeepAlive)
|
||||||
|
h(w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return routes
|
||||||
|
}
|
||||||
|
|
||||||
func convertMiddleware(ware Middleware) func(http.Handler) http.Handler {
|
func convertMiddleware(ware Middleware) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return ware(next.ServeHTTP)
|
return ware(next.ServeHTTP)
|
||||||
|
|||||||
@@ -578,3 +578,7 @@ func (m mockedRouter) SetNotFoundHandler(_ http.Handler) {
|
|||||||
|
|
||||||
func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) {
|
func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ptrOfDuration(d time.Duration) *time.Duration {
|
||||||
|
return &d
|
||||||
|
}
|
||||||
|
|||||||
@@ -49,6 +49,12 @@ func (w *WithCodeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
|||||||
return nil, nil, errors.New("server doesn't support hijacking")
|
return nil, nil, errors.New("server doesn't support hijacking")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Unwrap returns the underlying http.ResponseWriter.
|
||||||
|
// This is used by http.ResponseController to unwrap the response writer.
|
||||||
|
func (w *WithCodeResponseWriter) Unwrap() http.ResponseWriter {
|
||||||
|
return w.Writer
|
||||||
|
}
|
||||||
|
|
||||||
// Write writes bytes into w.
|
// Write writes bytes into w.
|
||||||
func (w *WithCodeResponseWriter) Write(bytes []byte) (int, error) {
|
func (w *WithCodeResponseWriter) Write(bytes []byte) (int, error) {
|
||||||
return w.Writer.Write(bytes)
|
return w.Writer.Write(bytes)
|
||||||
@@ -59,8 +65,3 @@ func (w *WithCodeResponseWriter) WriteHeader(code int) {
|
|||||||
w.Writer.WriteHeader(code)
|
w.Writer.WriteHeader(code)
|
||||||
w.Code = code
|
w.Code = code
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unwrap returns the underlying ResponseWriter.
|
|
||||||
func (w *WithCodeResponseWriter) Unwrap() http.ResponseWriter {
|
|
||||||
return w.Writer
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -46,3 +46,15 @@ func TestWithCodeResponseWriter_Hijack(t *testing.T) {
|
|||||||
writer.Hijack()
|
writer.Hijack()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWithCodeResponseWriter_Unwrap(t *testing.T) {
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
writer := NewWithCodeResponseWriter(resp)
|
||||||
|
unwrapped := writer.Unwrap()
|
||||||
|
assert.Equal(t, resp, unwrapped)
|
||||||
|
|
||||||
|
// Test with a nested WithCodeResponseWriter
|
||||||
|
nestedWriter := NewWithCodeResponseWriter(writer)
|
||||||
|
unwrappedNested := nestedWriter.Unwrap()
|
||||||
|
assert.Equal(t, resp, unwrappedNested)
|
||||||
|
}
|
||||||
|
|||||||
@@ -293,7 +293,6 @@ func WithSignature(signature SignatureConf) RouteOption {
|
|||||||
func WithSSE() RouteOption {
|
func WithSSE() RouteOption {
|
||||||
return func(r *featuredRoutes) {
|
return func(r *featuredRoutes) {
|
||||||
r.sse = true
|
r.sse = true
|
||||||
r.timeout = ptrOfDuration(0)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -335,10 +334,6 @@ func handleError(err error) {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ptrOfDuration(d time.Duration) *time.Duration {
|
|
||||||
return &d
|
|
||||||
}
|
|
||||||
|
|
||||||
func validateSecret(secret string) {
|
func validateSecret(secret string) {
|
||||||
if len(secret) < 8 {
|
if len(secret) < 8 {
|
||||||
panic("secret's length can't be less than 8")
|
panic("secret's length can't be less than 8")
|
||||||
|
|||||||
Reference in New Issue
Block a user