diff --git a/rest/handler/loghandler.go b/rest/handler/loghandler.go index e6552fc02..f28e0d262 100644 --- a/rest/handler/loghandler.go +++ b/rest/handler/loghandler.go @@ -24,12 +24,16 @@ import ( ) const ( - limitBodyBytes = 1024 - limitDetailedBodyBytes = 4096 - defaultSlowThreshold = time.Millisecond * 500 + limitBodyBytes = 1024 + limitDetailedBodyBytes = 4096 + defaultSlowThreshold = time.Millisecond * 500 + defaultSSESlowThreshold = time.Minute * 3 ) -var slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold) +var ( + slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold) + sseSlowThreshold = syncx.ForAtomicDuration(defaultSSESlowThreshold) +) // LogHandler returns a middleware that logs http request and response. func LogHandler(next http.Handler) http.Handler { @@ -109,6 +113,11 @@ func SetSlowThreshold(threshold time.Duration) { slowThreshold.Set(threshold) } +// SetSSESlowThreshold sets the slow threshold for SSE requests. +func SetSSESlowThreshold(threshold time.Duration) { + sseSlowThreshold.Set(threshold) +} + func dumpRequest(r *http.Request) string { reqContent, err := httputil.DumpRequest(r, true) if err != nil { @@ -129,7 +138,8 @@ func logBrief(r *http.Request, code int, timer *utils.ElapsedTimer, logs *intern logger := logx.WithContext(r.Context()).WithDuration(duration) buf.WriteString(fmt.Sprintf("[HTTP] %s - %s %s - %s - %s", wrapStatusCode(code), wrapMethod(r.Method), r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent())) - if duration > slowThreshold.Load() { + + if duration > getSlowThreshold(r) { logger.Slowf("[HTTP] %s - %s %s - %s - %s - slowcall(%s)", wrapStatusCode(code), wrapMethod(r.Method), r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent(), timex.ReprOfDuration(duration)) @@ -160,7 +170,8 @@ func logDetails(r *http.Request, response *detailLoggedResponseWriter, timer *ut logger := logx.WithContext(r.Context()) buf.WriteString(fmt.Sprintf("[HTTP] %s - %d - %s - %s\n=> %s\n", r.Method, code, r.RemoteAddr, timex.ReprOfDuration(duration), dumpRequest(r))) - if duration > slowThreshold.Load() { + + if duration > getSlowThreshold(r) { logger.Slowf("[HTTP] %s - %d - %s - slowcall(%s)\n=> %s\n", r.Method, code, r.RemoteAddr, timex.ReprOfDuration(duration), dumpRequest(r)) } @@ -223,3 +234,12 @@ func wrapStatusCode(code int) string { return logx.WithColorPadding(strconv.Itoa(code), colour) } + +func getSlowThreshold(r *http.Request) time.Duration { + threshold := slowThreshold.Load() + if r.Header.Get(headerAccept) == valueSSE { + threshold = sseSlowThreshold.Load() + } + + return threshold +} diff --git a/rest/handler/loghandler_test.go b/rest/handler/loghandler_test.go index 85afcfe84..0629e221f 100644 --- a/rest/handler/loghandler_test.go +++ b/rest/handler/loghandler_test.go @@ -88,6 +88,96 @@ func TestLogHandlerSlow(t *testing.T) { } } +func TestLogHandlerSSE(t *testing.T) { + handlers := []func(handler http.Handler) http.Handler{ + LogHandler, + DetailedLogHandler, + } + + for _, logHandler := range handlers { + t.Run("SSE request with normal duration", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody) + req.Header.Set(headerAccept, valueSSE) + + handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(defaultSlowThreshold + time.Second) + w.WriteHeader(http.StatusOK) + })) + + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) + }) + + t.Run("SSE request exceeding SSE threshold", func(t *testing.T) { + originalThreshold := sseSlowThreshold.Load() + SetSSESlowThreshold(time.Millisecond * 100) + defer SetSSESlowThreshold(originalThreshold) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody) + req.Header.Set(headerAccept, valueSSE) + + handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(time.Millisecond * 150) + w.WriteHeader(http.StatusOK) + })) + + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) + }) + } +} + +func TestLogHandlerThresholdSelection(t *testing.T) { + tests := []struct { + name string + acceptHeader string + expectedIsSSE bool + }{ + { + name: "Regular HTTP request", + acceptHeader: "text/html", + expectedIsSSE: false, + }, + { + name: "SSE request", + acceptHeader: valueSSE, + expectedIsSSE: true, + }, + { + name: "No Accept header", + acceptHeader: "", + expectedIsSSE: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody) + if tt.acceptHeader != "" { + req.Header.Set(headerAccept, tt.acceptHeader) + } + + SetSlowThreshold(time.Millisecond * 100) + SetSSESlowThreshold(time.Millisecond * 200) + defer func() { + SetSlowThreshold(defaultSlowThreshold) + SetSSESlowThreshold(defaultSSESlowThreshold) + }() + + handler := LogHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(time.Millisecond * 150) + w.WriteHeader(http.StatusOK) + })) + + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) + }) + } +} + func TestDetailedLogHandler_LargeBody(t *testing.T) { lbuf := logtest.NewCollector(t) @@ -139,6 +229,12 @@ func TestSetSlowThreshold(t *testing.T) { assert.Equal(t, time.Second, slowThreshold.Load()) } +func TestSetSSESlowThreshold(t *testing.T) { + assert.Equal(t, defaultSSESlowThreshold, sseSlowThreshold.Load()) + SetSSESlowThreshold(time.Minute * 10) + assert.Equal(t, time.Minute*10, sseSlowThreshold.Load()) +} + func TestWrapMethodWithColor(t *testing.T) { // no tty assert.Equal(t, http.MethodGet, wrapMethod(http.MethodGet))