From 31b9ba19a29430407a82bb45f44bc71aefdb5104 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Sun, 9 Jul 2023 15:04:59 +0800 Subject: [PATCH] chore: refactor httpx.TimeoutHandler (#3400) --- rest/handler/timeouthandler.go | 30 ++++++--- rest/handler/timeouthandler_test.go | 98 +++++++++++++++++++++-------- 2 files changed, 91 insertions(+), 37 deletions(-) diff --git a/rest/handler/timeouthandler.go b/rest/handler/timeouthandler.go index 38a681226..c956b49eb 100644 --- a/rest/handler/timeouthandler.go +++ b/rest/handler/timeouthandler.go @@ -127,18 +127,29 @@ type timeoutWriter struct { var _ http.Pusher = (*timeoutWriter)(nil) +// Flush implements the Flusher interface. func (tw *timeoutWriter) Flush() { - dst := tw.w.Header() - for k, vv := range tw.h { - dst[k] = vv + flusher, ok := tw.w.(http.Flusher) + if !ok { + return } - if flusher, ok := tw.w.(http.Flusher); ok { - tw.w.Write(tw.wbuf.Bytes()) - tw.wbuf.Reset() - flusher.Flush() + + header := tw.w.Header() + for k, v := range tw.h { + header[k] = v } + + tw.w.Write(tw.wbuf.Bytes()) + tw.wbuf.Reset() + flusher.Flush() } +// Header returns the underline temporary http.Header. +func (tw *timeoutWriter) Header() http.Header { + return tw.h +} + +// Hijack implements the Hijacker interface. func (tw *timeoutWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { if hijacked, ok := tw.w.(http.Hijacker); ok { return hijacked.Hijack() @@ -147,14 +158,12 @@ func (tw *timeoutWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { return nil, nil, errors.New("server doesn't support hijacking") } -// Header returns the underline temporary http.Header. -func (tw *timeoutWriter) Header() http.Header { return tw.h } - // Push implements the Pusher interface. func (tw *timeoutWriter) Push(target string, opts *http.PushOptions) error { if pusher, ok := tw.w.(http.Pusher); ok { return pusher.Push(target, opts) } + return http.ErrNotSupported } @@ -171,6 +180,7 @@ func (tw *timeoutWriter) Write(p []byte) (int, error) { if !tw.wroteHeader { tw.writeHeaderLocked(http.StatusOK) } + return tw.wbuf.Write(p) } diff --git a/rest/handler/timeouthandler_test.go b/rest/handler/timeouthandler_test.go index 4f496fe00..a234ce097 100644 --- a/rest/handler/timeouthandler_test.go +++ b/rest/handler/timeouthandler_test.go @@ -17,35 +17,63 @@ import ( ) func TestTimeoutWriteFlushOutput(t *testing.T) { - timeoutHandler := TimeoutHandler(1000 * time.Millisecond) - handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream;charset=utf-8") - flusher, ok := w.(http.Flusher) - if !ok { - http.Error(w, "Flushing not supported", http.StatusInternalServerError) - return + t.Run("flusher", func(t *testing.T) { + timeoutHandler := TimeoutHandler(1000 * time.Millisecond) + handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream; charset=utf-8") + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Flushing not supported", http.StatusInternalServerError) + return + } + + for i := 1; i <= 5; i++ { + fmt.Fprint(w, strconv.Itoa(i)+" cats\n\n") + flusher.Flush() + time.Sleep(time.Millisecond) + } + })) + req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + scanner := bufio.NewScanner(resp.Body) + var cats int + for scanner.Scan() { + line := scanner.Text() + if strings.Contains(line, "cats") { + cats++ + } } - for i := 1; i <= 5; i++ { - fmt.Fprint(w, strconv.Itoa(i)+"只猫猫\n\n") - flusher.Flush() - time.Sleep(time.Millisecond) + if err := scanner.Err(); err != nil { + cats = 0 } - })) - req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody) - resp := httptest.NewRecorder() - handler.ServeHTTP(resp, req) - scanner := bufio.NewScanner(resp.Body) - mao := 0 - for scanner.Scan() { - line := scanner.Text() - if strings.Contains(line, "猫猫") { - mao++ - } - } - if err := scanner.Err(); err != nil { - mao = 0 - } - assert.Equal(t, "5只猫猫", strconv.Itoa(mao)+"只猫猫") + assert.Equal(t, 5, cats) + }) + + t.Run("writer", func(t *testing.T) { + recorder := httptest.NewRecorder() + timeoutHandler := TimeoutHandler(1000 * time.Millisecond) + handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream; charset=utf-8") + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Flushing not supported", http.StatusInternalServerError) + return + } + + for i := 1; i <= 5; i++ { + fmt.Fprint(w, strconv.Itoa(i)+" cats\n\n") + flusher.Flush() + time.Sleep(time.Millisecond) + assert.Empty(t, recorder.Body.String()) + } + })) + req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody) + resp := mockedResponseWriter{recorder} + handler.ServeHTTP(resp, req) + assert.Equal(t, "1 cats\n\n2 cats\n\n3 cats\n\n4 cats\n\n5 cats\n\n", + recorder.Body.String()) + }) } func TestTimeout(t *testing.T) { @@ -274,3 +302,19 @@ func (m mockedPusher) WriteHeader(_ int) { func (m mockedPusher) Push(_ string, _ *http.PushOptions) error { panic("implement me") } + +type mockedResponseWriter struct { + http.ResponseWriter +} + +func (m mockedResponseWriter) Header() http.Header { + return m.ResponseWriter.Header() +} + +func (m mockedResponseWriter) Write(bytes []byte) (int, error) { + return m.ResponseWriter.Write(bytes) +} + +func (m mockedResponseWriter) WriteHeader(statusCode int) { + m.ResponseWriter.WriteHeader(statusCode) +}