mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-14 02:10:00 +08:00
feat(goctl/rpc): support external proto imports with cross-package ty… (#5472)
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -47,6 +47,7 @@ func (g *Generator) GenCall(ctx DirContext, proto parser.Proto, cfg *conf.Config
|
||||
func (g *Generator) genCallGroup(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
|
||||
dir := ctx.GetCall()
|
||||
head := util.GetHead(proto.Name)
|
||||
pkgMap := parser.BuildProtoPackageMap(proto.ImportedProtos)
|
||||
for _, service := range proto.Service {
|
||||
childPkg, err := dir.GetChildPackage(service.Name)
|
||||
if err != nil {
|
||||
@@ -90,12 +91,13 @@ func (g *Generator) genCallGroup(ctx DirContext, proto parser.Proto, cfg *conf.C
|
||||
serviceName = stringx.From(service.Name + "_zrpc_client").ToCamel()
|
||||
}
|
||||
|
||||
functions, err := g.genFunction(proto.PbPackage, serviceName, service, isCallPkgSameToGrpcPkg)
|
||||
extraImports := collection.NewSet[string]()
|
||||
functions, err := g.genFunction(proto.PbPackage, proto.GoPackage, serviceName, service, isCallPkgSameToGrpcPkg, pkgMap, alias, extraImports)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
iFunctions, err := g.getInterfaceFuncs(proto.PbPackage, service, isCallPkgSameToGrpcPkg)
|
||||
iFunctions, err := g.getInterfaceFuncs(proto.PbPackage, proto.GoPackage, service, isCallPkgSameToGrpcPkg, pkgMap, extraImports)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -112,6 +114,7 @@ func (g *Generator) genCallGroup(ctx DirContext, proto parser.Proto, cfg *conf.C
|
||||
protoGoPackage = ""
|
||||
}
|
||||
|
||||
extraImportLines := buildExtraImportLines(extraImports)
|
||||
aliasKeys := alias.Keys()
|
||||
sort.Strings(aliasKeys)
|
||||
if err = util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]any{
|
||||
@@ -121,6 +124,7 @@ func (g *Generator) genCallGroup(ctx DirContext, proto parser.Proto, cfg *conf.C
|
||||
"filePackage": childDir,
|
||||
"pbPackage": pbPackage,
|
||||
"protoGoPackage": protoGoPackage,
|
||||
"extraImports": extraImportLines,
|
||||
"serviceName": serviceName,
|
||||
"functions": strings.Join(functions, pathx.NL),
|
||||
"interface": strings.Join(iFunctions, pathx.NL),
|
||||
@@ -162,13 +166,15 @@ func (g *Generator) genCallInCompatibility(ctx DirContext, proto parser.Proto,
|
||||
serviceName = stringx.From(service.Name + "_zrpc_client").ToCamel()
|
||||
}
|
||||
|
||||
pkgMap := parser.BuildProtoPackageMap(proto.ImportedProtos)
|
||||
extraImports := collection.NewSet[string]()
|
||||
filename := filepath.Join(dir.Filename, fmt.Sprintf("%s.go", callFilename))
|
||||
functions, err := g.genFunction(proto.PbPackage, serviceName, service, isCallPkgSameToGrpcPkg)
|
||||
functions, err := g.genFunction(proto.PbPackage, proto.GoPackage, serviceName, service, isCallPkgSameToGrpcPkg, pkgMap, alias, extraImports)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
iFunctions, err := g.getInterfaceFuncs(proto.PbPackage, service, isCallPkgSameToGrpcPkg)
|
||||
iFunctions, err := g.getInterfaceFuncs(proto.PbPackage, proto.GoPackage, service, isCallPkgSameToGrpcPkg, pkgMap, extraImports)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -184,6 +190,7 @@ func (g *Generator) genCallInCompatibility(ctx DirContext, proto parser.Proto,
|
||||
pbPackage = ""
|
||||
protoGoPackage = ""
|
||||
}
|
||||
extraImportLines := buildExtraImportLines(extraImports)
|
||||
aliasKeys := alias.Keys()
|
||||
sort.Strings(aliasKeys)
|
||||
return util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]any{
|
||||
@@ -193,6 +200,7 @@ func (g *Generator) genCallInCompatibility(ctx DirContext, proto parser.Proto,
|
||||
"filePackage": dir.Base,
|
||||
"pbPackage": pbPackage,
|
||||
"protoGoPackage": protoGoPackage,
|
||||
"extraImports": extraImportLines,
|
||||
"serviceName": serviceName,
|
||||
"functions": strings.Join(functions, pathx.NL),
|
||||
"interface": strings.Join(iFunctions, pathx.NL),
|
||||
@@ -221,8 +229,9 @@ func getMessageName(msg proto.Message) string {
|
||||
return strings.Join(list, "_")
|
||||
}
|
||||
|
||||
func (g *Generator) genFunction(goPackage string, serviceName string, service parser.Service,
|
||||
isCallPkgSameToGrpcPkg bool) ([]string, error) {
|
||||
func (g *Generator) genFunction(goPackage, mainGoPackage, serviceName string, service parser.Service,
|
||||
isCallPkgSameToGrpcPkg bool, pkgMap map[string]parser.ImportedProto,
|
||||
alias, extraImports *collection.Set[string]) ([]string, error) {
|
||||
functions := make([]string, 0)
|
||||
|
||||
for _, rpc := range service.RPC {
|
||||
@@ -238,13 +247,29 @@ func (g *Generator) genFunction(goPackage string, serviceName string, service pa
|
||||
streamServer = fmt.Sprintf("%s_%s%s", parser.CamelCase(service.Name),
|
||||
parser.CamelCase(rpc.Name), "Client")
|
||||
}
|
||||
|
||||
reqName, reqAlias, reqImport := resolveCallTypeRef(rpc.RequestType, goPackage, mainGoPackage, pkgMap)
|
||||
respName, respAlias, respImport := resolveCallTypeRef(rpc.ReturnsType, goPackage, mainGoPackage, pkgMap)
|
||||
if reqAlias != "" {
|
||||
alias.Add(reqAlias)
|
||||
}
|
||||
if respAlias != "" {
|
||||
alias.Add(respAlias)
|
||||
}
|
||||
if reqImport != "" {
|
||||
extraImports.Add(reqImport)
|
||||
}
|
||||
if respImport != "" {
|
||||
extraImports.Add(respImport)
|
||||
}
|
||||
|
||||
buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]any{
|
||||
"serviceName": serviceName,
|
||||
"rpcServiceName": parser.CamelCase(service.Name),
|
||||
"method": parser.CamelCase(rpc.Name),
|
||||
"package": goPackage,
|
||||
"pbRequest": parser.CamelCase(rpc.RequestType),
|
||||
"pbResponse": parser.CamelCase(rpc.ReturnsType),
|
||||
"pbRequest": reqName,
|
||||
"pbResponse": respName,
|
||||
"hasComment": len(comment) > 0,
|
||||
"comment": comment,
|
||||
"hasReq": !rpc.StreamsRequest,
|
||||
@@ -262,8 +287,9 @@ func (g *Generator) genFunction(goPackage string, serviceName string, service pa
|
||||
return functions, nil
|
||||
}
|
||||
|
||||
func (g *Generator) getInterfaceFuncs(goPackage string, service parser.Service,
|
||||
isCallPkgSameToGrpcPkg bool) ([]string, error) {
|
||||
func (g *Generator) getInterfaceFuncs(goPackage, mainGoPackage string, service parser.Service,
|
||||
isCallPkgSameToGrpcPkg bool, pkgMap map[string]parser.ImportedProto,
|
||||
extraImports *collection.Set[string]) ([]string, error) {
|
||||
functions := make([]string, 0)
|
||||
|
||||
for _, rpc := range service.RPC {
|
||||
@@ -280,15 +306,25 @@ func (g *Generator) getInterfaceFuncs(goPackage string, service parser.Service,
|
||||
streamServer = fmt.Sprintf("%s_%s%s", parser.CamelCase(service.Name),
|
||||
parser.CamelCase(rpc.Name), "Client")
|
||||
}
|
||||
|
||||
reqName, _, reqImport := resolveCallTypeRef(rpc.RequestType, goPackage, mainGoPackage, pkgMap)
|
||||
respName, _, respImport := resolveCallTypeRef(rpc.ReturnsType, goPackage, mainGoPackage, pkgMap)
|
||||
if reqImport != "" {
|
||||
extraImports.Add(reqImport)
|
||||
}
|
||||
if respImport != "" {
|
||||
extraImports.Add(respImport)
|
||||
}
|
||||
|
||||
buffer, err := util.With("interfaceFn").Parse(text).Execute(
|
||||
map[string]any{
|
||||
"hasComment": len(comment) > 0,
|
||||
"comment": comment,
|
||||
"method": parser.CamelCase(rpc.Name),
|
||||
"hasReq": !rpc.StreamsRequest,
|
||||
"pbRequest": parser.CamelCase(rpc.RequestType),
|
||||
"pbRequest": reqName,
|
||||
"notStream": !rpc.StreamsRequest && !rpc.StreamsReturns,
|
||||
"pbResponse": parser.CamelCase(rpc.ReturnsType),
|
||||
"pbResponse": respName,
|
||||
"streamBody": streamServer,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -300,3 +336,18 @@ func (g *Generator) getInterfaceFuncs(goPackage string, service parser.Service,
|
||||
|
||||
return functions, nil
|
||||
}
|
||||
|
||||
// buildExtraImportLines converts a set of import paths into quoted import lines
|
||||
// for use in the call.tpl {{.extraImports}} placeholder.
|
||||
func buildExtraImportLines(extraImports *collection.Set[string]) string {
|
||||
if extraImports.Count() == 0 {
|
||||
return ""
|
||||
}
|
||||
keys := extraImports.Keys()
|
||||
sort.Strings(keys)
|
||||
lines := make([]string, 0, len(keys))
|
||||
for _, k := range keys {
|
||||
lines = append(lines, fmt.Sprintf(`"%s"`, k))
|
||||
}
|
||||
return strings.Join(lines, "\n\t")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user