feat: add rest.WithSSE to build SSE route easier (#4729)

This commit is contained in:
Kevin Wan
2025-03-22 13:38:13 +08:00
committed by GitHub
parent cdb0098b18
commit 6edfce63e3
13 changed files with 106 additions and 34 deletions

View File

@@ -15,6 +15,7 @@ import (
"github.com/zeromicro/go-zero/rest/handler"
"github.com/zeromicro/go-zero/rest/httpx"
"github.com/zeromicro/go-zero/rest/internal"
"github.com/zeromicro/go-zero/rest/internal/header"
"github.com/zeromicro/go-zero/rest/internal/response"
)
@@ -54,6 +55,9 @@ func newEngine(c RestConf) *engine {
}
func (ng *engine) addRoutes(r featuredRoutes) {
if r.sse {
r.routes = buildSSERoutes(r.routes)
}
ng.routes = append(ng.routes, r)
// need to guarantee the timeout is the max of all routes
@@ -63,6 +67,20 @@ func (ng *engine) addRoutes(r featuredRoutes) {
}
}
func buildSSERoutes(routes []Route) []Route {
for i, route := range routes {
h := route.Handler
routes[i].Handler = func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(header.ContentType, header.ContentTypeEventStream)
w.Header().Set(header.CacheControl, header.CacheControlNoCache)
w.Header().Set(header.Connection, header.ConnectionKeepAlive)
h(w, r)
}
}
return routes
}
func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
verifier func(chain.Chain) chain.Chain) chain.Chain {
if fr.jwt.enabled {

View File

@@ -105,7 +105,7 @@ func buildRequest(ctx context.Context, method, url string, data any) (*http.Requ
req.URL.RawQuery = buildFormQuery(u, val[formKey])
fillHeader(req, val[headerKey])
if hasJsonBody {
req.Header.Set(header.ContentType, header.JsonContentType)
req.Header.Set(header.ContentType, header.ContentTypeJson)
}
return req, nil

View File

@@ -45,7 +45,7 @@ func TestDoRequest_NotFound(t *testing.T) {
defer svr.Close()
req, err := http.NewRequest(http.MethodPost, svr.URL, nil)
assert.Nil(t, err)
req.Header.Set(header.ContentType, header.JsonContentType)
req.Header.Set(header.ContentType, header.ContentTypeJson)
resp, err := DoRequest(req)
assert.Nil(t, err)
assert.Equal(t, http.StatusNotFound, resp.StatusCode)

View File

@@ -18,7 +18,7 @@ func TestParse(t *testing.T) {
}
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("foo", "bar")
w.Header().Set(header.ContentType, header.JsonContentType)
w.Header().Set(header.ContentType, header.ContentTypeJson)
w.Write([]byte(`{"name":"kevin","value":100}`))
}))
defer svr.Close()
@@ -38,7 +38,7 @@ func TestParseHeaderError(t *testing.T) {
}
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("foo", "bar")
w.Header().Set(header.ContentType, header.JsonContentType)
w.Header().Set(header.ContentType, header.ContentTypeJson)
}))
defer svr.Close()
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
@@ -54,7 +54,7 @@ func TestParseNoBody(t *testing.T) {
}
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("foo", "bar")
w.Header().Set(header.ContentType, header.JsonContentType)
w.Header().Set(header.ContentType, header.ContentTypeJson)
}))
defer svr.Close()
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
@@ -72,7 +72,7 @@ func TestParseWithZeroValue(t *testing.T) {
}
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("foo", "0")
w.Header().Set(header.ContentType, header.JsonContentType)
w.Header().Set(header.ContentType, header.ContentTypeJson)
w.Write([]byte(`{"bar":0}`))
}))
defer svr.Close()
@@ -90,7 +90,7 @@ func TestParseWithNegativeContentLength(t *testing.T) {
Bar int `json:"bar"`
}
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(header.ContentType, header.JsonContentType)
w.Header().Set(header.ContentType, header.ContentTypeJson)
w.Write([]byte(`{"bar":0}`))
}))
defer svr.Close()
@@ -124,7 +124,7 @@ func TestParseWithNegativeContentLength(t *testing.T) {
func TestParseWithNegativeContentLengthNoBody(t *testing.T) {
var val struct{}
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(header.ContentType, header.JsonContentType)
w.Header().Set(header.ContentType, header.ContentTypeJson)
}))
defer svr.Close()
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
@@ -156,7 +156,7 @@ func TestParseWithNegativeContentLengthNoBody(t *testing.T) {
func TestParseJsonBody_BodyError(t *testing.T) {
var val struct{}
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(header.ContentType, header.JsonContentType)
w.Header().Set(header.ContentType, header.ContentTypeJson)
}))
defer svr.Close()
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)

View File

@@ -44,7 +44,7 @@ func TestNamedService_DoRequestPost(t *testing.T) {
service := NewService("foo")
req, err := http.NewRequest(http.MethodPost, svr.URL, nil)
assert.Nil(t, err)
req.Header.Set(header.ContentType, header.JsonContentType)
req.Header.Set(header.ContentType, header.ContentTypeJson)
resp, err := service.DoRequest(req)
assert.Nil(t, err)
assert.Equal(t, http.StatusNotFound, resp.StatusCode)

View File

@@ -476,7 +476,7 @@ func TestParseJsonBody(t *testing.T) {
body := `{"name":"kevin", "age": 18}`
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
r.Header.Set(ContentType, header.JsonContentType)
r.Header.Set(ContentType, header.ContentTypeJson)
if assert.NoError(t, Parse(r, &v)) {
assert.Equal(t, "kevin", v.Name)
@@ -492,7 +492,7 @@ func TestParseJsonBody(t *testing.T) {
body := `{"name":"kevin", "ag": 18}`
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
r.Header.Set(ContentType, header.JsonContentType)
r.Header.Set(ContentType, header.ContentTypeJson)
assert.Error(t, Parse(r, &v))
})
@@ -517,7 +517,7 @@ func TestParseJsonBody(t *testing.T) {
body := `[{"name":"kevin", "age": 18}]`
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
r.Header.Set(ContentType, header.JsonContentType)
r.Header.Set(ContentType, header.ContentTypeJson)
assert.NoError(t, Parse(r, &v))
assert.Equal(t, 1, len(v))
@@ -537,7 +537,7 @@ func TestParseJsonBody(t *testing.T) {
body := `[{"name":"apple", "age": 18}]`
r := httptest.NewRequest(http.MethodPost, "/a?product=tree", strings.NewReader(body))
r.Header.Set(ContentType, header.JsonContentType)
r.Header.Set(ContentType, header.ContentTypeJson)
assert.NoError(t, Parse(r, &v))
assert.Equal(t, 1, len(v))
@@ -555,7 +555,7 @@ func TestParseJsonBody(t *testing.T) {
body, _ := json.Marshal(v1)
t.Logf("body:%s", string(body))
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(body)))
r.Header.Set(ContentType, header.JsonContentType)
r.Header.Set(ContentType, header.ContentTypeJson)
var v2 v
err := ParseJsonBody(r, &v2)
if assert.NoError(t, err) {
@@ -609,7 +609,7 @@ func TestParseHeaders(t *testing.T) {
request.Header.Add("addrs", "addr2")
request.Header.Add("X-Forwarded-For", "10.0.10.11")
request.Header.Add("x-real-ip", "10.0.11.10")
request.Header.Add("Accept", header.JsonContentType)
request.Header.Add("Accept", header.ContentTypeJson)
err = ParseHeaders(request, &v)
if err != nil {
t.Fatal(err)
@@ -619,7 +619,7 @@ func TestParseHeaders(t *testing.T) {
assert.Equal(t, []string{"addr1", "addr2"}, v.Addrs)
assert.Equal(t, "10.0.10.11", v.XForwardedFor)
assert.Equal(t, "10.0.11.10", v.XRealIP)
assert.Equal(t, header.JsonContentType, v.Accept)
assert.Equal(t, header.ContentTypeJson, v.Accept)
}
func TestParseHeaders_Error(t *testing.T) {
@@ -711,7 +711,7 @@ func TestParseWithFloatPtr(t *testing.T) {
}
body := `{"weightFloat32": 3.2}`
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
r.Header.Set(ContentType, header.JsonContentType)
r.Header.Set(ContentType, header.ContentTypeJson)
if assert.NoError(t, Parse(r, &v)) {
assert.Equal(t, float32(3.2), *v.WeightFloat32)

View File

@@ -179,7 +179,7 @@ func doWriteJson(w http.ResponseWriter, code int, v any) error {
return fmt.Errorf("marshal json failed, error: %w", err)
}
w.Header().Set(ContentType, header.JsonContentType)
w.Header().Set(ContentType, header.ContentTypeJson)
w.WriteHeader(code)
if n, err := w.Write(bs); err != nil {

View File

@@ -10,7 +10,7 @@ const (
// ContentType means Content-Type.
ContentType = header.ContentType
// JsonContentType means application/json.
JsonContentType = header.JsonContentType
JsonContentType = header.ContentTypeJson
// KeyField means key.
KeyField = "key"
// SecretField means secret.

View File

@@ -3,8 +3,18 @@ package header
const (
// ApplicationJson stands for application/json.
ApplicationJson = "application/json"
// CacheControl is the header key for Cache-Control.
CacheControl = "Cache-Control"
// CacheControlNoCache is the value for Cache-Control: no-cache.
CacheControlNoCache = "no-cache"
// Connection is the header key for Connection.
Connection = "Connection"
// ConnectionKeepAlive is the value for Connection: keep-alive.
ConnectionKeepAlive = "keep-alive"
// ContentType is the header key for Content-Type.
ContentType = "Content-Type"
// JsonContentType is the content type for JSON.
JsonContentType = "application/json; charset=utf-8"
// ContentTypeJson is the content type for JSON.
ContentTypeJson = "application/json; charset=utf-8"
// ContentTypeEventStream is the content type for event stream.
ContentTypeEventStream = "text/event-stream"
)

View File

@@ -628,7 +628,7 @@ func TestParseWrappedRequest(t *testing.T) {
func TestParseWrappedGetRequestWithJsonHeader(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017", bytes.NewReader(nil))
assert.Nil(t, err)
r.Header.Set(httpx.ContentType, header.JsonContentType)
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
type (
Request struct {
@@ -661,7 +661,7 @@ func TestParseWrappedGetRequestWithJsonHeader(t *testing.T) {
func TestParseWrappedHeadRequestWithJsonHeader(t *testing.T) {
r, err := http.NewRequest(http.MethodHead, "http://hello.com/kevin/2017", bytes.NewReader(nil))
assert.Nil(t, err)
r.Header.Set(httpx.ContentType, header.JsonContentType)
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
type (
Request struct {
@@ -758,7 +758,7 @@ func TestParseWithAllUtf8(t *testing.T) {
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000",
bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`))
assert.Nil(t, err)
r.Header.Set(httpx.ContentType, header.JsonContentType)
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
router := NewRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
@@ -948,7 +948,7 @@ func TestParseWithMissingAllPaths(t *testing.T) {
func TestParseGetWithContentLengthHeader(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil)
assert.Nil(t, err)
r.Header.Set(httpx.ContentType, header.JsonContentType)
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
r.Header.Set(contentLength, "1024")
router := NewRouter()
@@ -976,7 +976,7 @@ func TestParseJsonPostWithTypeMismatch(t *testing.T) {
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000",
bytes.NewBufferString(`{"time": "20170912"}`))
assert.Nil(t, err)
r.Header.Set(httpx.ContentType, header.JsonContentType)
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
router := NewRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
@@ -1002,7 +1002,7 @@ func TestParseJsonPostWithInt2String(t *testing.T) {
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017",
bytes.NewBufferString(`{"time": 20170912}`))
assert.Nil(t, err)
r.Header.Set(httpx.ContentType, header.JsonContentType)
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
router := NewRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(

View File

@@ -63,6 +63,11 @@ func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
return server, nil
}
// AddRoute adds given route into the Server.
func (s *Server) AddRoute(r Route, opts ...RouteOption) {
s.AddRoutes([]Route{r}, opts...)
}
// AddRoutes add given routes into the Server.
func (s *Server) AddRoutes(rs []Route, opts ...RouteOption) {
r := featuredRoutes{
@@ -74,11 +79,6 @@ func (s *Server) AddRoutes(rs []Route, opts ...RouteOption) {
s.ngin.addRoutes(r)
}
// AddRoute adds given route into the Server.
func (s *Server) AddRoute(r Route, opts ...RouteOption) {
s.AddRoutes([]Route{r}, opts...)
}
// PrintRoutes prints the added routes to stdout.
func (s *Server) PrintRoutes() {
s.ngin.print()
@@ -279,6 +279,14 @@ func WithSignature(signature SignatureConf) RouteOption {
}
}
// WithSSE returns a RouteOption to enable server-sent events.
func WithSSE() RouteOption {
return func(r *featuredRoutes) {
r.sse = true
r.timeout = 0
}
}
// WithTimeout returns a RouteOption to set timeout with given value.
func WithTimeout(timeout time.Duration) RouteOption {
return func(r *featuredRoutes) {

View File

@@ -20,6 +20,7 @@ import (
"github.com/zeromicro/go-zero/rest/chain"
"github.com/zeromicro/go-zero/rest/httpx"
"github.com/zeromicro/go-zero/rest/internal/cors"
"github.com/zeromicro/go-zero/rest/internal/header"
"github.com/zeromicro/go-zero/rest/router"
)
@@ -754,6 +755,40 @@ Port: 54321
}
}
func TestServerEventStream(t *testing.T) {
server := MustNewServer(RestConf{})
server.AddRoutes([]Route{
{
Method: http.MethodGet,
Path: "/foo",
Handler: func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("foo"))
},
},
{
Method: http.MethodGet,
Path: "/bar",
Handler: func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("bar"))
},
},
}, WithSSE())
check := func(val string) {
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%s", val), http.NoBody)
assert.Nil(t, err)
rr := httptest.NewRecorder()
serve(server, rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, header.ContentTypeEventStream, rr.Header().Get(header.ContentType))
assert.Equal(t, header.CacheControlNoCache, rr.Header().Get(header.CacheControl))
assert.Equal(t, header.ConnectionKeepAlive, rr.Header().Get(header.Connection))
assert.Equal(t, val, rr.Body.String())
}
check("foo")
check("bar")
}
//go:embed testdata
var content embed.FS
@@ -770,7 +805,7 @@ func TestServerEmbedFileSystem(t *testing.T) {
}
// serve is for test purpose, allow developer to do a unit test with
// all defined router without starting an HTTP Server.
// all defined routes without starting an HTTP Server.
//
// For example:
//

View File

@@ -35,6 +35,7 @@ type (
priority bool
jwt jwtSetting
signature signatureSetting
sse bool
routes []Route
maxBytes int64
}