mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-07 15:10:01 +08:00
feat(gateway): add custom middleware support with onion model (#5035)
This commit is contained in:
@@ -27,12 +27,16 @@ import (
|
|||||||
const defaultHttpScheme = "http"
|
const defaultHttpScheme = "http"
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
// MiddlewareFunc defines the function signature for middleware.
|
||||||
|
MiddlewareFunc func(next http.HandlerFunc) http.HandlerFunc
|
||||||
|
|
||||||
// Server is a gateway server.
|
// Server is a gateway server.
|
||||||
Server struct {
|
Server struct {
|
||||||
*rest.Server
|
*rest.Server
|
||||||
upstreams []Upstream
|
upstreams []Upstream
|
||||||
conns []zrpc.Client
|
conns []zrpc.Client
|
||||||
processHeader func(http.Header) []string
|
processHeader func(http.Header) []string
|
||||||
|
middlewares []MiddlewareFunc
|
||||||
dialer func(conf zrpc.RpcClientConf) zrpc.Client
|
dialer func(conf zrpc.RpcClientConf) zrpc.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,7 +109,7 @@ func (s *Server) build() error {
|
|||||||
|
|
||||||
func (s *Server) buildGrpcHandler(source grpcurl.DescriptorSource, resolver jsonpb.AnyResolver,
|
func (s *Server) buildGrpcHandler(source grpcurl.DescriptorSource, resolver jsonpb.AnyResolver,
|
||||||
cli zrpc.Client, rpcPath string) func(http.ResponseWriter, *http.Request) {
|
cli zrpc.Client, rpcPath string) func(http.ResponseWriter, *http.Request) {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||||
parser, err := internal.NewRequestParser(r, resolver)
|
parser, err := internal.NewRequestParser(r, resolver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpx.ErrorCtx(r.Context(), w, err)
|
httpx.ErrorCtx(r.Context(), w, err)
|
||||||
@@ -124,6 +128,8 @@ func (s *Server) buildGrpcHandler(source grpcurl.DescriptorSource, resolver json
|
|||||||
httpx.ErrorCtx(r.Context(), w, st.Err())
|
httpx.ErrorCtx(r.Context(), w, st.Err())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return s.buildChainHandler(handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) buildGrpcRoute(up Upstream, writer mr.Writer[rest.Route], cancel func(error)) {
|
func (s *Server) buildGrpcRoute(up Upstream, writer mr.Writer[rest.Route], cancel func(error)) {
|
||||||
@@ -177,7 +183,7 @@ func (s *Server) buildGrpcRoute(up Upstream, writer mr.Writer[rest.Route], cance
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) buildHttpHandler(target *HttpClientConf) http.HandlerFunc {
|
func (s *Server) buildHttpHandler(target *HttpClientConf) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set(httpx.ContentType, httpx.JsonContentType)
|
w.Header().Set(httpx.ContentType, httpx.JsonContentType)
|
||||||
req, err := buildRequestWithNewTarget(r, target)
|
req, err := buildRequestWithNewTarget(r, target)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -213,6 +219,8 @@ func (s *Server) buildHttpHandler(target *HttpClientConf) http.HandlerFunc {
|
|||||||
logc.Error(r.Context(), err)
|
logc.Error(r.Context(), err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return s.buildChainHandler(handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) buildHttpRoute(up Upstream, writer mr.Writer[rest.Route]) {
|
func (s *Server) buildHttpRoute(up Upstream, writer mr.Writer[rest.Route]) {
|
||||||
@@ -263,6 +271,21 @@ func WithHeaderProcessor(processHeader func(http.Header) []string) func(*Server)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithMiddleware adds one or more middleware functions to process HTTP requests.
|
||||||
|
// Multiple middlewares will be executed in the order they were passed (like an onion model).
|
||||||
|
func WithMiddleware(middlewares ...MiddlewareFunc) func(*Server) {
|
||||||
|
return func(s *Server) {
|
||||||
|
s.middlewares = append(s.middlewares, middlewares...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) buildChainHandler(handler http.HandlerFunc) http.HandlerFunc {
|
||||||
|
for i := len(s.middlewares) - 1; i >= 0; i-- {
|
||||||
|
handler = s.middlewares[i](handler)
|
||||||
|
}
|
||||||
|
return handler
|
||||||
|
}
|
||||||
|
|
||||||
func buildRequestWithNewTarget(r *http.Request, target *HttpClientConf) (*http.Request, error) {
|
func buildRequestWithNewTarget(r *http.Request, target *HttpClientConf) (*http.Request, error) {
|
||||||
u := *r.URL
|
u := *r.URL
|
||||||
u.Host = target.Target
|
u.Host = target.Target
|
||||||
|
|||||||
@@ -325,3 +325,50 @@ type badResponseWriter struct {
|
|||||||
func (w *badResponseWriter) Write([]byte) (int, error) {
|
func (w *badResponseWriter) Write([]byte) (int, error) {
|
||||||
return 0, errors.New("bad writer")
|
return 0, errors.New("bad writer")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWithMiddleware(t *testing.T) {
|
||||||
|
var callOrder []string
|
||||||
|
|
||||||
|
firstMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
callOrder = append(callOrder, "first-start")
|
||||||
|
w.Header().Set("X-First-Middleware", "called")
|
||||||
|
next(w, r)
|
||||||
|
callOrder = append(callOrder, "first-end")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
secondMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
callOrder = append(callOrder, "second-start")
|
||||||
|
w.Header().Set("X-Second-Middleware", "called")
|
||||||
|
next(w, r)
|
||||||
|
callOrder = append(callOrder, "second-end")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var c GatewayConf
|
||||||
|
err := conf.FillDefault(&c)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
// Test multiple middlewares in one call
|
||||||
|
server1 := MustNewServer(c, WithMiddleware(firstMiddleware, secondMiddleware))
|
||||||
|
assert.Len(t, server1.middlewares, 2, "Should have 2 middlewares from one call")
|
||||||
|
// Test multiple middleware calls
|
||||||
|
server2 := MustNewServer(c, WithMiddleware(firstMiddleware), WithMiddleware(secondMiddleware))
|
||||||
|
assert.Len(t, server2.middlewares, 2, "Should have 2 middlewares from separate calls")
|
||||||
|
// Test execution order (onion model)
|
||||||
|
finalHandler := func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
callOrder = append(callOrder, "handler")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
testHandler := server1.buildChainHandler(finalHandler)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
testHandler(w, r)
|
||||||
|
|
||||||
|
expectedOrder := []string{"first-start", "second-start", "handler", "second-end", "first-end"}
|
||||||
|
assert.Equal(t, expectedOrder, callOrder, "Middleware execution should follow onion model")
|
||||||
|
assert.Equal(t, "called", w.Header().Get("X-First-Middleware"))
|
||||||
|
assert.Equal(t, "called", w.Header().Get("X-Second-Middleware"))
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user