feat(goctl/rpc): support external proto imports with cross-package ty… (#5472)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
kesonan
2026-03-22 12:01:20 +08:00
committed by GitHub
parent c12c82b2f6
commit 004995f06a
93 changed files with 4871 additions and 270 deletions

View File

@@ -47,6 +47,7 @@ func (g *Generator) GenCall(ctx DirContext, proto parser.Proto, cfg *conf.Config
func (g *Generator) genCallGroup(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
dir := ctx.GetCall()
head := util.GetHead(proto.Name)
pkgMap := parser.BuildProtoPackageMap(proto.ImportedProtos)
for _, service := range proto.Service {
childPkg, err := dir.GetChildPackage(service.Name)
if err != nil {
@@ -90,12 +91,13 @@ func (g *Generator) genCallGroup(ctx DirContext, proto parser.Proto, cfg *conf.C
serviceName = stringx.From(service.Name + "_zrpc_client").ToCamel()
}
functions, err := g.genFunction(proto.PbPackage, serviceName, service, isCallPkgSameToGrpcPkg)
extraImports := collection.NewSet[string]()
functions, err := g.genFunction(proto.PbPackage, proto.GoPackage, serviceName, service, isCallPkgSameToGrpcPkg, pkgMap, alias, extraImports)
if err != nil {
return err
}
iFunctions, err := g.getInterfaceFuncs(proto.PbPackage, service, isCallPkgSameToGrpcPkg)
iFunctions, err := g.getInterfaceFuncs(proto.PbPackage, proto.GoPackage, service, isCallPkgSameToGrpcPkg, pkgMap, extraImports)
if err != nil {
return err
}
@@ -112,6 +114,7 @@ func (g *Generator) genCallGroup(ctx DirContext, proto parser.Proto, cfg *conf.C
protoGoPackage = ""
}
extraImportLines := buildExtraImportLines(extraImports)
aliasKeys := alias.Keys()
sort.Strings(aliasKeys)
if err = util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]any{
@@ -121,6 +124,7 @@ func (g *Generator) genCallGroup(ctx DirContext, proto parser.Proto, cfg *conf.C
"filePackage": childDir,
"pbPackage": pbPackage,
"protoGoPackage": protoGoPackage,
"extraImports": extraImportLines,
"serviceName": serviceName,
"functions": strings.Join(functions, pathx.NL),
"interface": strings.Join(iFunctions, pathx.NL),
@@ -162,13 +166,15 @@ func (g *Generator) genCallInCompatibility(ctx DirContext, proto parser.Proto,
serviceName = stringx.From(service.Name + "_zrpc_client").ToCamel()
}
pkgMap := parser.BuildProtoPackageMap(proto.ImportedProtos)
extraImports := collection.NewSet[string]()
filename := filepath.Join(dir.Filename, fmt.Sprintf("%s.go", callFilename))
functions, err := g.genFunction(proto.PbPackage, serviceName, service, isCallPkgSameToGrpcPkg)
functions, err := g.genFunction(proto.PbPackage, proto.GoPackage, serviceName, service, isCallPkgSameToGrpcPkg, pkgMap, alias, extraImports)
if err != nil {
return err
}
iFunctions, err := g.getInterfaceFuncs(proto.PbPackage, service, isCallPkgSameToGrpcPkg)
iFunctions, err := g.getInterfaceFuncs(proto.PbPackage, proto.GoPackage, service, isCallPkgSameToGrpcPkg, pkgMap, extraImports)
if err != nil {
return err
}
@@ -184,6 +190,7 @@ func (g *Generator) genCallInCompatibility(ctx DirContext, proto parser.Proto,
pbPackage = ""
protoGoPackage = ""
}
extraImportLines := buildExtraImportLines(extraImports)
aliasKeys := alias.Keys()
sort.Strings(aliasKeys)
return util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]any{
@@ -193,6 +200,7 @@ func (g *Generator) genCallInCompatibility(ctx DirContext, proto parser.Proto,
"filePackage": dir.Base,
"pbPackage": pbPackage,
"protoGoPackage": protoGoPackage,
"extraImports": extraImportLines,
"serviceName": serviceName,
"functions": strings.Join(functions, pathx.NL),
"interface": strings.Join(iFunctions, pathx.NL),
@@ -221,8 +229,9 @@ func getMessageName(msg proto.Message) string {
return strings.Join(list, "_")
}
func (g *Generator) genFunction(goPackage string, serviceName string, service parser.Service,
isCallPkgSameToGrpcPkg bool) ([]string, error) {
func (g *Generator) genFunction(goPackage, mainGoPackage, serviceName string, service parser.Service,
isCallPkgSameToGrpcPkg bool, pkgMap map[string]parser.ImportedProto,
alias, extraImports *collection.Set[string]) ([]string, error) {
functions := make([]string, 0)
for _, rpc := range service.RPC {
@@ -238,13 +247,29 @@ func (g *Generator) genFunction(goPackage string, serviceName string, service pa
streamServer = fmt.Sprintf("%s_%s%s", parser.CamelCase(service.Name),
parser.CamelCase(rpc.Name), "Client")
}
reqName, reqAlias, reqImport := resolveCallTypeRef(rpc.RequestType, goPackage, mainGoPackage, pkgMap)
respName, respAlias, respImport := resolveCallTypeRef(rpc.ReturnsType, goPackage, mainGoPackage, pkgMap)
if reqAlias != "" {
alias.Add(reqAlias)
}
if respAlias != "" {
alias.Add(respAlias)
}
if reqImport != "" {
extraImports.Add(reqImport)
}
if respImport != "" {
extraImports.Add(respImport)
}
buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]any{
"serviceName": serviceName,
"rpcServiceName": parser.CamelCase(service.Name),
"method": parser.CamelCase(rpc.Name),
"package": goPackage,
"pbRequest": parser.CamelCase(rpc.RequestType),
"pbResponse": parser.CamelCase(rpc.ReturnsType),
"pbRequest": reqName,
"pbResponse": respName,
"hasComment": len(comment) > 0,
"comment": comment,
"hasReq": !rpc.StreamsRequest,
@@ -262,8 +287,9 @@ func (g *Generator) genFunction(goPackage string, serviceName string, service pa
return functions, nil
}
func (g *Generator) getInterfaceFuncs(goPackage string, service parser.Service,
isCallPkgSameToGrpcPkg bool) ([]string, error) {
func (g *Generator) getInterfaceFuncs(goPackage, mainGoPackage string, service parser.Service,
isCallPkgSameToGrpcPkg bool, pkgMap map[string]parser.ImportedProto,
extraImports *collection.Set[string]) ([]string, error) {
functions := make([]string, 0)
for _, rpc := range service.RPC {
@@ -280,15 +306,25 @@ func (g *Generator) getInterfaceFuncs(goPackage string, service parser.Service,
streamServer = fmt.Sprintf("%s_%s%s", parser.CamelCase(service.Name),
parser.CamelCase(rpc.Name), "Client")
}
reqName, _, reqImport := resolveCallTypeRef(rpc.RequestType, goPackage, mainGoPackage, pkgMap)
respName, _, respImport := resolveCallTypeRef(rpc.ReturnsType, goPackage, mainGoPackage, pkgMap)
if reqImport != "" {
extraImports.Add(reqImport)
}
if respImport != "" {
extraImports.Add(respImport)
}
buffer, err := util.With("interfaceFn").Parse(text).Execute(
map[string]any{
"hasComment": len(comment) > 0,
"comment": comment,
"method": parser.CamelCase(rpc.Name),
"hasReq": !rpc.StreamsRequest,
"pbRequest": parser.CamelCase(rpc.RequestType),
"pbRequest": reqName,
"notStream": !rpc.StreamsRequest && !rpc.StreamsReturns,
"pbResponse": parser.CamelCase(rpc.ReturnsType),
"pbResponse": respName,
"streamBody": streamServer,
})
if err != nil {
@@ -300,3 +336,18 @@ func (g *Generator) getInterfaceFuncs(goPackage string, service parser.Service,
return functions, nil
}
// buildExtraImportLines converts a set of import paths into quoted import lines
// for use in the call.tpl {{.extraImports}} placeholder.
func buildExtraImportLines(extraImports *collection.Set[string]) string {
if extraImports.Count() == 0 {
return ""
}
keys := extraImports.Keys()
sort.Strings(keys)
lines := make([]string, 0, len(keys))
for _, k := range keys {
lines = append(lines, fmt.Sprintf(`"%s"`, k))
}
return strings.Join(lines, "\n\t")
}