diff --git a/gateway/server.go b/gateway/server.go index fe2703953..aea4f2f9e 100644 --- a/gateway/server.go +++ b/gateway/server.go @@ -27,12 +27,16 @@ import ( const defaultHttpScheme = "http" type ( + // MiddlewareFunc defines the function signature for middleware. + MiddlewareFunc func(next http.HandlerFunc) http.HandlerFunc + // Server is a gateway server. Server struct { *rest.Server upstreams []Upstream conns []zrpc.Client processHeader func(http.Header) []string + middlewares []MiddlewareFunc dialer func(conf zrpc.RpcClientConf) zrpc.Client } @@ -105,7 +109,7 @@ func (s *Server) build() error { func (s *Server) buildGrpcHandler(source grpcurl.DescriptorSource, resolver jsonpb.AnyResolver, cli zrpc.Client, rpcPath string) func(http.ResponseWriter, *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { + handler := func(w http.ResponseWriter, r *http.Request) { parser, err := internal.NewRequestParser(r, resolver) if err != nil { httpx.ErrorCtx(r.Context(), w, err) @@ -124,6 +128,8 @@ func (s *Server) buildGrpcHandler(source grpcurl.DescriptorSource, resolver json httpx.ErrorCtx(r.Context(), w, st.Err()) } } + + return s.buildChainHandler(handler) } func (s *Server) buildGrpcRoute(up Upstream, writer mr.Writer[rest.Route], cancel func(error)) { @@ -177,7 +183,7 @@ func (s *Server) buildGrpcRoute(up Upstream, writer mr.Writer[rest.Route], cance } func (s *Server) buildHttpHandler(target *HttpClientConf) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { + handler := func(w http.ResponseWriter, r *http.Request) { w.Header().Set(httpx.ContentType, httpx.JsonContentType) req, err := buildRequestWithNewTarget(r, target) if err != nil { @@ -213,6 +219,8 @@ func (s *Server) buildHttpHandler(target *HttpClientConf) http.HandlerFunc { logc.Error(r.Context(), err) } } + + return s.buildChainHandler(handler) } func (s *Server) buildHttpRoute(up Upstream, writer mr.Writer[rest.Route]) { @@ -263,6 +271,21 @@ func WithHeaderProcessor(processHeader func(http.Header) []string) func(*Server) } } +// WithMiddleware adds one or more middleware functions to process HTTP requests. +// Multiple middlewares will be executed in the order they were passed (like an onion model). +func WithMiddleware(middlewares ...MiddlewareFunc) func(*Server) { + return func(s *Server) { + s.middlewares = append(s.middlewares, middlewares...) + } +} + +func (s *Server) buildChainHandler(handler http.HandlerFunc) http.HandlerFunc { + for i := len(s.middlewares) - 1; i >= 0; i-- { + handler = s.middlewares[i](handler) + } + return handler +} + func buildRequestWithNewTarget(r *http.Request, target *HttpClientConf) (*http.Request, error) { u := *r.URL u.Host = target.Target diff --git a/gateway/server_test.go b/gateway/server_test.go index 80c7693c0..7eff9eb47 100644 --- a/gateway/server_test.go +++ b/gateway/server_test.go @@ -325,3 +325,50 @@ type badResponseWriter struct { func (w *badResponseWriter) Write([]byte) (int, error) { return 0, errors.New("bad writer") } + +func TestWithMiddleware(t *testing.T) { + var callOrder []string + + firstMiddleware := func(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + callOrder = append(callOrder, "first-start") + w.Header().Set("X-First-Middleware", "called") + next(w, r) + callOrder = append(callOrder, "first-end") + } + } + + secondMiddleware := func(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + callOrder = append(callOrder, "second-start") + w.Header().Set("X-Second-Middleware", "called") + next(w, r) + callOrder = append(callOrder, "second-end") + } + } + + var c GatewayConf + err := conf.FillDefault(&c) + assert.Nil(t, err) + // Test multiple middlewares in one call + server1 := MustNewServer(c, WithMiddleware(firstMiddleware, secondMiddleware)) + assert.Len(t, server1.middlewares, 2, "Should have 2 middlewares from one call") + // Test multiple middleware calls + server2 := MustNewServer(c, WithMiddleware(firstMiddleware), WithMiddleware(secondMiddleware)) + assert.Len(t, server2.middlewares, 2, "Should have 2 middlewares from separate calls") + // Test execution order (onion model) + finalHandler := func(w http.ResponseWriter, r *http.Request) { + callOrder = append(callOrder, "handler") + w.WriteHeader(http.StatusOK) + } + + testHandler := server1.buildChainHandler(finalHandler) + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/test", nil) + testHandler(w, r) + + expectedOrder := []string{"first-start", "second-start", "handler", "second-end", "first-end"} + assert.Equal(t, expectedOrder, callOrder, "Middleware execution should follow onion model") + assert.Equal(t, "called", w.Header().Get("X-First-Middleware")) + assert.Equal(t, "called", w.Header().Get("X-Second-Middleware")) +}