diff --git a/tools/goctl/api/gogen/genhandlers.go b/tools/goctl/api/gogen/genhandlers.go index 4611be4d7..0aa335aa6 100644 --- a/tools/goctl/api/gogen/genhandlers.go +++ b/tools/goctl/api/gogen/genhandlers.go @@ -15,8 +15,12 @@ import ( const defaultLogicPackage = "logic" -//go:embed handler.tpl -var handlerTemplate string +var ( + //go:embed handler.tpl + handlerTemplate string + //go:embed sse_handler.tpl + sseHandlerTemplate string +) func genHandler(dir, rootPkg string, cfg *config.Config, group spec.Group, route spec.Route) error { handler := getHandlerName(route) @@ -32,6 +36,12 @@ func genHandler(dir, rootPkg string, cfg *config.Config, group spec.Group, route return err } + var builtinTemplate = handlerTemplate + sse := group.GetAnnotation("sse") + if sse == "true" { + builtinTemplate = sseHandlerTemplate + } + return genFile(fileGenConfig{ dir: dir, subdir: getHandlerFolderPath(group, route), @@ -39,12 +49,13 @@ func genHandler(dir, rootPkg string, cfg *config.Config, group spec.Group, route templateName: "handlerTemplate", category: category, templateFile: handlerTemplateFile, - builtinTemplate: handlerTemplate, + builtinTemplate: builtinTemplate, data: map[string]any{ "PkgName": pkgName, "ImportPackages": genHandlerImports(group, route, rootPkg), "HandlerName": handler, "RequestType": util.Title(route.RequestTypeName()), + "ResponseType": responseGoTypeName(route, typesPacket), "LogicName": logicName, "LogicType": strings.Title(getLogicName(route)), "Call": strings.Title(strings.TrimSuffix(handler, "Handler")), @@ -73,7 +84,8 @@ func genHandlerImports(group spec.Group, route spec.Route, parentPkg string) str fmt.Sprintf("\"%s\"", pathx.JoinPackages(parentPkg, getLogicFolderPath(group, route))), fmt.Sprintf("\"%s\"", pathx.JoinPackages(parentPkg, contextDir)), } - if len(route.RequestTypeName()) > 0 { + sse := group.GetAnnotation("sse") + if len(route.RequestTypeName()) > 0 || sse == "true" { imports = append(imports, fmt.Sprintf("\"%s\"\n", pathx.JoinPackages(parentPkg, typesDir))) } diff --git a/tools/goctl/api/gogen/genlogic.go b/tools/goctl/api/gogen/genlogic.go index e162eeb64..fcf39ff8a 100644 --- a/tools/goctl/api/gogen/genlogic.go +++ b/tools/goctl/api/gogen/genlogic.go @@ -15,8 +15,13 @@ import ( "github.com/zeromicro/go-zero/tools/goctl/vars" ) -//go:embed logic.tpl -var logicTemplate string +var ( + //go:embed logic.tpl + logicTemplate string + + //go:embed sse_logic.tpl + sseLogicTemplate string +) func genLogic(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error { for _, g := range api.Service.Groups { @@ -54,6 +59,20 @@ func genLogicByRoute(dir, rootPkg string, cfg *config.Config, group spec.Group, } subDir := getLogicFolderPath(group, route) + builtinTemplate := logicTemplate + sse := group.GetAnnotation("sse") + if sse == "true" { + builtinTemplate = sseLogicTemplate + responseString = "error" + returnString = "return nil" + resp := responseGoTypeName(route, typesPacket) + if len(requestString) == 0 { + requestString = "client chan<- " + resp + } else { + requestString += ", client chan<- " + resp + } + } + return genFile(fileGenConfig{ dir: dir, subdir: subDir, @@ -61,7 +80,7 @@ func genLogicByRoute(dir, rootPkg string, cfg *config.Config, group spec.Group, templateName: "logicTemplate", category: category, templateFile: logicTemplateFile, - builtinTemplate: logicTemplate, + builtinTemplate: builtinTemplate, data: map[string]any{ "pkgName": subDir[strings.LastIndex(subDir, "/")+1:], "imports": imports, diff --git a/tools/goctl/api/gogen/sse_handler.tpl b/tools/goctl/api/gogen/sse_handler.tpl new file mode 100644 index 000000000..2ed4cf322 --- /dev/null +++ b/tools/goctl/api/gogen/sse_handler.tpl @@ -0,0 +1,63 @@ +package {{.PkgName}} + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/zeromicro/go-zero/core/logc" + "github.com/zeromicro/go-zero/core/threading" + {{if .HasRequest}}"github.com/zeromicro/go-zero/rest/httpx"{{end}} + {{.ImportPackages}} +) + +{{if .HasDoc}}{{.Doc}}{{end}} +func {{.HandlerName}}(svcCtx *svc.ServiceContext) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + {{if .HasRequest}}var req types.{{.RequestType}} + if err := httpx.Parse(r, &req); err != nil { + httpx.ErrorCtx(r.Context(), w, err) + return + } + + {{end}}// Buffer size of 16 is chosen as a reasonable default to balance throughput and memory usage. + // You can change this based on your application's needs. + // if your go-zero version less than 1.8.1, you need to add 3 lines below. + // w.Header().Set("Content-Type", "text/event-stream") + // w.Header().Set("Cache-Control", "no-cache") + // w.Header().Set("Connection", "keep-alive") + client := make(chan {{.ResponseType}}, 16) + defer func() { + close(client) + }() + l := {{.LogicName}}.New{{.LogicType}}(r.Context(), svcCtx) + threading.GoSafeCtx(r.Context(), func() { + err := l.{{.Call}}({{if .HasRequest}}&req, {{end}}client) + if err != nil { + logc.Errorw(r.Context(), "{{.HandlerName}}", logc.Field("error", err)) + return + } + }) + + for { + select { + case data := <-client: + output, err := json.Marshal(data) + if err != nil { + logc.Errorw(r.Context(), "{{.HandlerName}}", logc.Field("error", err)) + continue + } + + if _, err := fmt.Fprintf(w, "data: %s\n\n", string(output)); err != nil { + logc.Errorw(r.Context(), "{{.HandlerName}}", logc.Field("error", err)) + return + } + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + case <-r.Context().Done(): + return + } + } + } +} diff --git a/tools/goctl/api/gogen/sse_logic.tpl b/tools/goctl/api/gogen/sse_logic.tpl new file mode 100644 index 000000000..b29fbd303 --- /dev/null +++ b/tools/goctl/api/gogen/sse_logic.tpl @@ -0,0 +1,26 @@ +package {{.pkgName}} + +import ( + {{.imports}} +) + +type {{.logic}} struct { + logx.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +{{if .hasDoc}}{{.doc}}{{end}} +func New{{.logic}}(ctx context.Context, svcCtx *svc.ServiceContext) *{{.logic}} { + return &{{.logic}}{ + Logger: logx.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *{{.logic}}) {{.function}}({{.request}}) {{.responseType}} { + // todo: add your logic here and delete this line + + {{.returnString}} +} diff --git a/tools/goctl/api/gogen/template.go b/tools/goctl/api/gogen/template.go index 1bb05dc5a..a00f3741c 100644 --- a/tools/goctl/api/gogen/template.go +++ b/tools/goctl/api/gogen/template.go @@ -12,8 +12,10 @@ const ( contextTemplateFile = "context.tpl" etcTemplateFile = "etc.tpl" handlerTemplateFile = "handler.tpl" + sseHandlerTemplateFile = "sse_handler.tpl" handlerTestTemplateFile = "handler_test.tpl" logicTemplateFile = "logic.tpl" + sseLogicTemplateFile = "sse_logic.tpl" logicTestTemplateFile = "logic_test.tpl" mainTemplateFile = "main.tpl" middlewareImplementCodeFile = "middleware.tpl" @@ -27,8 +29,10 @@ var templates = map[string]string{ contextTemplateFile: contextTemplate, etcTemplateFile: etcTemplate, handlerTemplateFile: handlerTemplate, + sseHandlerTemplateFile: sseHandlerTemplate, handlerTestTemplateFile: handlerTestTemplate, logicTemplateFile: logicTemplate, + sseLogicTemplateFile: sseLogicTemplate, logicTestTemplateFile: logicTestTemplate, mainTemplateFile: mainTemplate, middlewareImplementCodeFile: middlewareImplementCode, diff --git a/tools/goctl/internal/version/version.go b/tools/goctl/internal/version/version.go index 6c6795f36..ae9cabaf7 100644 --- a/tools/goctl/internal/version/version.go +++ b/tools/goctl/internal/version/version.go @@ -6,9 +6,9 @@ import ( ) // BuildVersion is the version of goctl. -const BuildVersion = "1.8.5" +const BuildVersion = "1.8.6-alpha" -var tag = map[string]int{"pre-alpha": 0, "alpha": 1, "pre-bata": 2, "beta": 3, "released": 4, "": 5} +var tag = map[string]int{"pre-alpha": 0, "alpha": 1, "pre-beta": 2, "beta": 3, "released": 4, "": 5} // GetGoctlVersion returns BuildVersion func GetGoctlVersion() string { diff --git a/tools/goctl/pkg/parser/api/parser/analyzer.go b/tools/goctl/pkg/parser/api/parser/analyzer.go index 34722ef64..dd4e211a5 100644 --- a/tools/goctl/pkg/parser/api/parser/analyzer.go +++ b/tools/goctl/pkg/parser/api/parser/analyzer.go @@ -244,6 +244,7 @@ func (a *Analyzer) fillService() error { group.Annotation.Properties = a.convertKV(item.AtServerStmt.Values) } + sse := group.GetAnnotation("sse") == "true" for _, astRoute := range item.Routes { head, leading := astRoute.CommentGroup() route := spec.Route{ @@ -277,6 +278,13 @@ func (a *Analyzer) fillService() error { } route.ResponseType = responseType } + if route.ResponseType == nil && sse { + if route.RequestType != nil { + return ast.SyntaxError(astRoute.Route.Request.Pos(), "missing response type") + } else { + return ast.SyntaxError(astRoute.Route.Path.Pos(), "missing response type") + } + } if err := a.fillRouteType(&route); err != nil { return err