diff --git a/rest/engine.go b/rest/engine.go index e57786caf..436dda937 100644 --- a/rest/engine.go +++ b/rest/engine.go @@ -15,6 +15,7 @@ import ( "github.com/zeromicro/go-zero/rest/handler" "github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/rest/internal" + "github.com/zeromicro/go-zero/rest/internal/header" "github.com/zeromicro/go-zero/rest/internal/response" ) @@ -54,6 +55,9 @@ func newEngine(c RestConf) *engine { } func (ng *engine) addRoutes(r featuredRoutes) { + if r.sse { + r.routes = buildSSERoutes(r.routes) + } ng.routes = append(ng.routes, r) // need to guarantee the timeout is the max of all routes @@ -63,6 +67,20 @@ func (ng *engine) addRoutes(r featuredRoutes) { } } +func buildSSERoutes(routes []Route) []Route { + for i, route := range routes { + h := route.Handler + routes[i].Handler = func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(header.ContentType, header.ContentTypeEventStream) + w.Header().Set(header.CacheControl, header.CacheControlNoCache) + w.Header().Set(header.Connection, header.ConnectionKeepAlive) + h(w, r) + } + } + + return routes +} + func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain, verifier func(chain.Chain) chain.Chain) chain.Chain { if fr.jwt.enabled { diff --git a/rest/httpc/requests.go b/rest/httpc/requests.go index bc1251596..3692eaae0 100644 --- a/rest/httpc/requests.go +++ b/rest/httpc/requests.go @@ -105,7 +105,7 @@ func buildRequest(ctx context.Context, method, url string, data any) (*http.Requ req.URL.RawQuery = buildFormQuery(u, val[formKey]) fillHeader(req, val[headerKey]) if hasJsonBody { - req.Header.Set(header.ContentType, header.JsonContentType) + req.Header.Set(header.ContentType, header.ContentTypeJson) } return req, nil diff --git a/rest/httpc/requests_test.go b/rest/httpc/requests_test.go index 440752110..a5e66e494 100644 --- a/rest/httpc/requests_test.go +++ b/rest/httpc/requests_test.go @@ -45,7 +45,7 @@ func TestDoRequest_NotFound(t *testing.T) { defer svr.Close() req, err := http.NewRequest(http.MethodPost, svr.URL, nil) assert.Nil(t, err) - req.Header.Set(header.ContentType, header.JsonContentType) + req.Header.Set(header.ContentType, header.ContentTypeJson) resp, err := DoRequest(req) assert.Nil(t, err) assert.Equal(t, http.StatusNotFound, resp.StatusCode) diff --git a/rest/httpc/responses_test.go b/rest/httpc/responses_test.go index c6c7dd3ef..8e3826a11 100644 --- a/rest/httpc/responses_test.go +++ b/rest/httpc/responses_test.go @@ -18,7 +18,7 @@ func TestParse(t *testing.T) { } svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("foo", "bar") - w.Header().Set(header.ContentType, header.JsonContentType) + w.Header().Set(header.ContentType, header.ContentTypeJson) w.Write([]byte(`{"name":"kevin","value":100}`)) })) defer svr.Close() @@ -38,7 +38,7 @@ func TestParseHeaderError(t *testing.T) { } svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("foo", "bar") - w.Header().Set(header.ContentType, header.JsonContentType) + w.Header().Set(header.ContentType, header.ContentTypeJson) })) defer svr.Close() req, err := http.NewRequest(http.MethodGet, svr.URL, nil) @@ -54,7 +54,7 @@ func TestParseNoBody(t *testing.T) { } svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("foo", "bar") - w.Header().Set(header.ContentType, header.JsonContentType) + w.Header().Set(header.ContentType, header.ContentTypeJson) })) defer svr.Close() req, err := http.NewRequest(http.MethodGet, svr.URL, nil) @@ -72,7 +72,7 @@ func TestParseWithZeroValue(t *testing.T) { } svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("foo", "0") - w.Header().Set(header.ContentType, header.JsonContentType) + w.Header().Set(header.ContentType, header.ContentTypeJson) w.Write([]byte(`{"bar":0}`)) })) defer svr.Close() @@ -90,7 +90,7 @@ func TestParseWithNegativeContentLength(t *testing.T) { Bar int `json:"bar"` } svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set(header.ContentType, header.JsonContentType) + w.Header().Set(header.ContentType, header.ContentTypeJson) w.Write([]byte(`{"bar":0}`)) })) defer svr.Close() @@ -124,7 +124,7 @@ func TestParseWithNegativeContentLength(t *testing.T) { func TestParseWithNegativeContentLengthNoBody(t *testing.T) { var val struct{} svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set(header.ContentType, header.JsonContentType) + w.Header().Set(header.ContentType, header.ContentTypeJson) })) defer svr.Close() req, err := http.NewRequest(http.MethodGet, svr.URL, nil) @@ -156,7 +156,7 @@ func TestParseWithNegativeContentLengthNoBody(t *testing.T) { func TestParseJsonBody_BodyError(t *testing.T) { var val struct{} svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set(header.ContentType, header.JsonContentType) + w.Header().Set(header.ContentType, header.ContentTypeJson) })) defer svr.Close() req, err := http.NewRequest(http.MethodGet, svr.URL, nil) diff --git a/rest/httpc/service_test.go b/rest/httpc/service_test.go index f4a32efd0..8e49ddbc9 100644 --- a/rest/httpc/service_test.go +++ b/rest/httpc/service_test.go @@ -44,7 +44,7 @@ func TestNamedService_DoRequestPost(t *testing.T) { service := NewService("foo") req, err := http.NewRequest(http.MethodPost, svr.URL, nil) assert.Nil(t, err) - req.Header.Set(header.ContentType, header.JsonContentType) + req.Header.Set(header.ContentType, header.ContentTypeJson) resp, err := service.DoRequest(req) assert.Nil(t, err) assert.Equal(t, http.StatusNotFound, resp.StatusCode) diff --git a/rest/httpx/requests_test.go b/rest/httpx/requests_test.go index 8297c48d9..8728f8b58 100644 --- a/rest/httpx/requests_test.go +++ b/rest/httpx/requests_test.go @@ -476,7 +476,7 @@ func TestParseJsonBody(t *testing.T) { body := `{"name":"kevin", "age": 18}` r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) - r.Header.Set(ContentType, header.JsonContentType) + r.Header.Set(ContentType, header.ContentTypeJson) if assert.NoError(t, Parse(r, &v)) { assert.Equal(t, "kevin", v.Name) @@ -492,7 +492,7 @@ func TestParseJsonBody(t *testing.T) { body := `{"name":"kevin", "ag": 18}` r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) - r.Header.Set(ContentType, header.JsonContentType) + r.Header.Set(ContentType, header.ContentTypeJson) assert.Error(t, Parse(r, &v)) }) @@ -517,7 +517,7 @@ func TestParseJsonBody(t *testing.T) { body := `[{"name":"kevin", "age": 18}]` r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) - r.Header.Set(ContentType, header.JsonContentType) + r.Header.Set(ContentType, header.ContentTypeJson) assert.NoError(t, Parse(r, &v)) assert.Equal(t, 1, len(v)) @@ -537,7 +537,7 @@ func TestParseJsonBody(t *testing.T) { body := `[{"name":"apple", "age": 18}]` r := httptest.NewRequest(http.MethodPost, "/a?product=tree", strings.NewReader(body)) - r.Header.Set(ContentType, header.JsonContentType) + r.Header.Set(ContentType, header.ContentTypeJson) assert.NoError(t, Parse(r, &v)) assert.Equal(t, 1, len(v)) @@ -555,7 +555,7 @@ func TestParseJsonBody(t *testing.T) { body, _ := json.Marshal(v1) t.Logf("body:%s", string(body)) r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(body))) - r.Header.Set(ContentType, header.JsonContentType) + r.Header.Set(ContentType, header.ContentTypeJson) var v2 v err := ParseJsonBody(r, &v2) if assert.NoError(t, err) { @@ -609,7 +609,7 @@ func TestParseHeaders(t *testing.T) { request.Header.Add("addrs", "addr2") request.Header.Add("X-Forwarded-For", "10.0.10.11") request.Header.Add("x-real-ip", "10.0.11.10") - request.Header.Add("Accept", header.JsonContentType) + request.Header.Add("Accept", header.ContentTypeJson) err = ParseHeaders(request, &v) if err != nil { t.Fatal(err) @@ -619,7 +619,7 @@ func TestParseHeaders(t *testing.T) { assert.Equal(t, []string{"addr1", "addr2"}, v.Addrs) assert.Equal(t, "10.0.10.11", v.XForwardedFor) assert.Equal(t, "10.0.11.10", v.XRealIP) - assert.Equal(t, header.JsonContentType, v.Accept) + assert.Equal(t, header.ContentTypeJson, v.Accept) } func TestParseHeaders_Error(t *testing.T) { @@ -711,7 +711,7 @@ func TestParseWithFloatPtr(t *testing.T) { } body := `{"weightFloat32": 3.2}` r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) - r.Header.Set(ContentType, header.JsonContentType) + r.Header.Set(ContentType, header.ContentTypeJson) if assert.NoError(t, Parse(r, &v)) { assert.Equal(t, float32(3.2), *v.WeightFloat32) diff --git a/rest/httpx/responses.go b/rest/httpx/responses.go index 40995cc7e..6f2de0953 100644 --- a/rest/httpx/responses.go +++ b/rest/httpx/responses.go @@ -179,7 +179,7 @@ func doWriteJson(w http.ResponseWriter, code int, v any) error { return fmt.Errorf("marshal json failed, error: %w", err) } - w.Header().Set(ContentType, header.JsonContentType) + w.Header().Set(ContentType, header.ContentTypeJson) w.WriteHeader(code) if n, err := w.Write(bs); err != nil { diff --git a/rest/httpx/vars.go b/rest/httpx/vars.go index 750bcd203..de97fdcd7 100644 --- a/rest/httpx/vars.go +++ b/rest/httpx/vars.go @@ -10,7 +10,7 @@ const ( // ContentType means Content-Type. ContentType = header.ContentType // JsonContentType means application/json. - JsonContentType = header.JsonContentType + JsonContentType = header.ContentTypeJson // KeyField means key. KeyField = "key" // SecretField means secret. diff --git a/rest/internal/header/headers.go b/rest/internal/header/headers.go index 29fd5e27f..a07577c42 100644 --- a/rest/internal/header/headers.go +++ b/rest/internal/header/headers.go @@ -3,8 +3,18 @@ package header const ( // ApplicationJson stands for application/json. ApplicationJson = "application/json" + // CacheControl is the header key for Cache-Control. + CacheControl = "Cache-Control" + // CacheControlNoCache is the value for Cache-Control: no-cache. + CacheControlNoCache = "no-cache" + // Connection is the header key for Connection. + Connection = "Connection" + // ConnectionKeepAlive is the value for Connection: keep-alive. + ConnectionKeepAlive = "keep-alive" // ContentType is the header key for Content-Type. ContentType = "Content-Type" - // JsonContentType is the content type for JSON. - JsonContentType = "application/json; charset=utf-8" + // ContentTypeJson is the content type for JSON. + ContentTypeJson = "application/json; charset=utf-8" + // ContentTypeEventStream is the content type for event stream. + ContentTypeEventStream = "text/event-stream" ) diff --git a/rest/router/patrouter_test.go b/rest/router/patrouter_test.go index 02f21ece3..3208e6fda 100644 --- a/rest/router/patrouter_test.go +++ b/rest/router/patrouter_test.go @@ -628,7 +628,7 @@ func TestParseWrappedRequest(t *testing.T) { func TestParseWrappedGetRequestWithJsonHeader(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017", bytes.NewReader(nil)) assert.Nil(t, err) - r.Header.Set(httpx.ContentType, header.JsonContentType) + r.Header.Set(httpx.ContentType, header.ContentTypeJson) type ( Request struct { @@ -661,7 +661,7 @@ func TestParseWrappedGetRequestWithJsonHeader(t *testing.T) { func TestParseWrappedHeadRequestWithJsonHeader(t *testing.T) { r, err := http.NewRequest(http.MethodHead, "http://hello.com/kevin/2017", bytes.NewReader(nil)) assert.Nil(t, err) - r.Header.Set(httpx.ContentType, header.JsonContentType) + r.Header.Set(httpx.ContentType, header.ContentTypeJson) type ( Request struct { @@ -758,7 +758,7 @@ func TestParseWithAllUtf8(t *testing.T) { r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`)) assert.Nil(t, err) - r.Header.Set(httpx.ContentType, header.JsonContentType) + r.Header.Set(httpx.ContentType, header.ContentTypeJson) router := NewRouter() err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( @@ -948,7 +948,7 @@ func TestParseWithMissingAllPaths(t *testing.T) { func TestParseGetWithContentLengthHeader(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil) assert.Nil(t, err) - r.Header.Set(httpx.ContentType, header.JsonContentType) + r.Header.Set(httpx.ContentType, header.ContentTypeJson) r.Header.Set(contentLength, "1024") router := NewRouter() @@ -976,7 +976,7 @@ func TestParseJsonPostWithTypeMismatch(t *testing.T) { r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", bytes.NewBufferString(`{"time": "20170912"}`)) assert.Nil(t, err) - r.Header.Set(httpx.ContentType, header.JsonContentType) + r.Header.Set(httpx.ContentType, header.ContentTypeJson) router := NewRouter() err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( @@ -1002,7 +1002,7 @@ func TestParseJsonPostWithInt2String(t *testing.T) { r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017", bytes.NewBufferString(`{"time": 20170912}`)) assert.Nil(t, err) - r.Header.Set(httpx.ContentType, header.JsonContentType) + r.Header.Set(httpx.ContentType, header.ContentTypeJson) router := NewRouter() err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( diff --git a/rest/server.go b/rest/server.go index 9aac50e44..b97ec5aa5 100644 --- a/rest/server.go +++ b/rest/server.go @@ -63,6 +63,11 @@ func NewServer(c RestConf, opts ...RunOption) (*Server, error) { return server, nil } +// AddRoute adds given route into the Server. +func (s *Server) AddRoute(r Route, opts ...RouteOption) { + s.AddRoutes([]Route{r}, opts...) +} + // AddRoutes add given routes into the Server. func (s *Server) AddRoutes(rs []Route, opts ...RouteOption) { r := featuredRoutes{ @@ -74,11 +79,6 @@ func (s *Server) AddRoutes(rs []Route, opts ...RouteOption) { s.ngin.addRoutes(r) } -// AddRoute adds given route into the Server. -func (s *Server) AddRoute(r Route, opts ...RouteOption) { - s.AddRoutes([]Route{r}, opts...) -} - // PrintRoutes prints the added routes to stdout. func (s *Server) PrintRoutes() { s.ngin.print() @@ -279,6 +279,14 @@ func WithSignature(signature SignatureConf) RouteOption { } } +// WithSSE returns a RouteOption to enable server-sent events. +func WithSSE() RouteOption { + return func(r *featuredRoutes) { + r.sse = true + r.timeout = 0 + } +} + // WithTimeout returns a RouteOption to set timeout with given value. func WithTimeout(timeout time.Duration) RouteOption { return func(r *featuredRoutes) { diff --git a/rest/server_test.go b/rest/server_test.go index f9d3d8813..9a513ddc3 100644 --- a/rest/server_test.go +++ b/rest/server_test.go @@ -20,6 +20,7 @@ import ( "github.com/zeromicro/go-zero/rest/chain" "github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/rest/internal/cors" + "github.com/zeromicro/go-zero/rest/internal/header" "github.com/zeromicro/go-zero/rest/router" ) @@ -754,6 +755,40 @@ Port: 54321 } } +func TestServerEventStream(t *testing.T) { + server := MustNewServer(RestConf{}) + server.AddRoutes([]Route{ + { + Method: http.MethodGet, + Path: "/foo", + Handler: func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("foo")) + }, + }, + { + Method: http.MethodGet, + Path: "/bar", + Handler: func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("bar")) + }, + }, + }, WithSSE()) + + check := func(val string) { + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%s", val), http.NoBody) + assert.Nil(t, err) + rr := httptest.NewRecorder() + serve(server, rr, req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, header.ContentTypeEventStream, rr.Header().Get(header.ContentType)) + assert.Equal(t, header.CacheControlNoCache, rr.Header().Get(header.CacheControl)) + assert.Equal(t, header.ConnectionKeepAlive, rr.Header().Get(header.Connection)) + assert.Equal(t, val, rr.Body.String()) + } + check("foo") + check("bar") +} + //go:embed testdata var content embed.FS @@ -770,7 +805,7 @@ func TestServerEmbedFileSystem(t *testing.T) { } // serve is for test purpose, allow developer to do a unit test with -// all defined router without starting an HTTP Server. +// all defined routes without starting an HTTP Server. // // For example: // diff --git a/rest/types.go b/rest/types.go index f7be79964..6d0bcab54 100644 --- a/rest/types.go +++ b/rest/types.go @@ -35,6 +35,7 @@ type ( priority bool jwt jwtSetting signature signatureSetting + sse bool routes []Route maxBytes int64 }