diff --git a/tools/goctl/api/gogen/genroutes.go b/tools/goctl/api/gogen/genroutes.go index 9770a57e1..58d4c243e 100644 --- a/tools/goctl/api/gogen/genroutes.go +++ b/tools/goctl/api/gogen/genroutes.go @@ -40,7 +40,7 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) { ` routesAdditionTemplate = ` server.AddRoutes( - {{.routes}} {{.jwt}}{{.signature}} {{.prefix}} {{.timeout}} {{.maxBytes}} + {{.routes}} {{.jwt}}{{.signature}} {{.prefix}} {{.timeout}} {{.maxBytes}} {{.sse}} ) ` timeoutThreshold = time.Millisecond @@ -63,6 +63,7 @@ type ( routes []route jwtEnabled bool signatureEnabled bool + sseEnabled bool authName string timeout string middlewares []string @@ -123,10 +124,17 @@ func genRoutes(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error if len(g.jwtTrans) > 0 { jwt = jwt + fmt.Sprintf("\n rest.WithJwtTransition(serverCtx.Config.%s.PrevSecret,serverCtx.Config.%s.Secret),", g.jwtTrans, g.jwtTrans) } + var signature, prefix string if g.signatureEnabled { signature = "\n rest.WithSignature(serverCtx.Config.Signature)," } + + var sse string + if g.sseEnabled { + sse = "\n rest.WithSSE()," + } + if len(g.prefix) > 0 { prefix = fmt.Sprintf(` rest.WithPrefix("%s"),`, g.prefix) @@ -172,6 +180,7 @@ rest.WithPrefix("%s"),`, g.prefix) "routes": routes, "jwt": jwt, "signature": signature, + "sse": sse, "prefix": prefix, "timeout": timeout, "maxBytes": maxBytes, @@ -281,6 +290,10 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) { if signature == "true" { groupedRoutes.signatureEnabled = true } + sse := g.GetAnnotation("sse") + if sse == "true" { + groupedRoutes.sseEnabled = true + } middleware := g.GetAnnotation("middleware") if len(middleware) > 0 { groupedRoutes.middlewares = append(groupedRoutes.middlewares, diff --git a/tools/goctl/api/gogen/genroutes_test.go b/tools/goctl/api/gogen/genroutes_test.go index 9aac41822..6555951af 100644 --- a/tools/goctl/api/gogen/genroutes_test.go +++ b/tools/goctl/api/gogen/genroutes_test.go @@ -3,6 +3,8 @@ package gogen import ( "testing" "time" + + "github.com/zeromicro/go-zero/tools/goctl/api/spec" ) func Test_formatDuration(t *testing.T) { @@ -25,3 +27,173 @@ func Test_formatDuration(t *testing.T) { } } } + +func TestSSESupport(t *testing.T) { + // Test API spec with SSE enabled + apiSpec := &spec.ApiSpec{ + Service: spec.Service{ + Groups: []spec.Group{ + { + Annotation: spec.Annotation{ + Properties: map[string]string{ + "sse": "true", + "prefix": "/api/v1", + }, + }, + Routes: []spec.Route{ + { + Method: "get", + Path: "/events", + Handler: "StreamEvents", + }, + }, + }, + }, + }, + } + + groups, err := getRoutes(apiSpec) + if err != nil { + t.Fatalf("getRoutes failed: %v", err) + } + + if len(groups) != 1 { + t.Fatalf("Expected 1 group, got %d", len(groups)) + } + + group := groups[0] + if !group.sseEnabled { + t.Error("Expected SSE to be enabled") + } + + if group.prefix != "/api/v1" { + t.Errorf("Expected prefix '/api/v1', got '%s'", group.prefix) + } + + if len(group.routes) != 1 { + t.Fatalf("Expected 1 route, got %d", len(group.routes)) + } + + route := group.routes[0] + if route.method != "http.MethodGet" { + t.Errorf("Expected method 'http.MethodGet', got '%s'", route.method) + } + + if route.path != "/events" { + t.Errorf("Expected path '/events', got '%s'", route.path) + } +} + +func TestSSEWithOtherFeatures(t *testing.T) { + // Test API spec with SSE and other features + apiSpec := &spec.ApiSpec{ + Service: spec.Service{ + Groups: []spec.Group{ + { + Annotation: spec.Annotation{ + Properties: map[string]string{ + "sse": "true", + "jwt": "Auth", + "signature": "true", + "prefix": "/api/v1", + "timeout": "30s", + "middleware": "AuthMiddleware,LogMiddleware", + }, + }, + Routes: []spec.Route{ + { + Method: "get", + Path: "/events", + Handler: "StreamEvents", + }, + }, + }, + }, + }, + } + + groups, err := getRoutes(apiSpec) + if err != nil { + t.Fatalf("getRoutes failed: %v", err) + } + + if len(groups) != 1 { + t.Fatalf("Expected 1 group, got %d", len(groups)) + } + + group := groups[0] + + // Verify all features are enabled + if !group.sseEnabled { + t.Error("Expected SSE to be enabled") + } + + if !group.jwtEnabled { + t.Error("Expected JWT to be enabled") + } + + if !group.signatureEnabled { + t.Error("Expected signature to be enabled") + } + + if group.authName != "Auth" { + t.Errorf("Expected authName 'Auth', got '%s'", group.authName) + } + + if group.prefix != "/api/v1" { + t.Errorf("Expected prefix '/api/v1', got '%s'", group.prefix) + } + + if group.timeout != "30s" { + t.Errorf("Expected timeout '30s', got '%s'", group.timeout) + } + + expectedMiddlewares := []string{"AuthMiddleware", "LogMiddleware"} + if len(group.middlewares) != len(expectedMiddlewares) { + t.Errorf("Expected %d middlewares, got %d", len(expectedMiddlewares), len(group.middlewares)) + } + + for i, expected := range expectedMiddlewares { + if group.middlewares[i] != expected { + t.Errorf("Expected middleware[%d] '%s', got '%s'", i, expected, group.middlewares[i]) + } + } +} + +func TestSSEDisabled(t *testing.T) { + // Test API spec without SSE + apiSpec := &spec.ApiSpec{ + Service: spec.Service{ + Groups: []spec.Group{ + { + Annotation: spec.Annotation{ + Properties: map[string]string{ + "prefix": "/api/v1", + }, + }, + Routes: []spec.Route{ + { + Method: "get", + Path: "/status", + Handler: "GetStatus", + }, + }, + }, + }, + }, + } + + groups, err := getRoutes(apiSpec) + if err != nil { + t.Fatalf("getRoutes failed: %v", err) + } + + if len(groups) != 1 { + t.Fatalf("Expected 1 group, got %d", len(groups)) + } + + group := groups[0] + if group.sseEnabled { + t.Error("Expected SSE to be disabled") + } +}