fix(goctl): include nested client aliases (#5627)

Co-authored-by: Deepak kudi <deepakkudi23@adsl-172-10-9-116.dsl.sndg02.sbcglobal.net>
Co-authored-by: kevin <wanjunfeng@gmail.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Puneet Dixit
2026-06-27 21:39:03 +05:30
committed by GitHub
parent d318de1212
commit f910257ec9
2 changed files with 475 additions and 65 deletions

View File

@@ -67,12 +67,9 @@ func (g *Generator) genCallGroup(ctx DirContext, proto parser.Proto, cfg *conf.C
serviceName := stringx.From(service.Name).ToCamel()
// Collect only the message types actually used by this service's RPCs,
// so that each client file only aliases its own request/response types.
usedTypes := collection.NewSet[string]()
for _, rpc := range service.RPC {
usedTypes.Add(parser.CamelCase(rpc.RequestType))
usedTypes.Add(parser.CamelCase(rpc.ReturnsType))
}
// so that each client file only aliases its own request/response types
// and their same-file message dependencies.
usedTypes := collectServiceUsedTypes(proto.Message, service)
alias := collection.NewSet[string]()
var hasSameNameBetweenMessageAndService bool
@@ -337,17 +334,85 @@ func (g *Generator) getInterfaceFuncs(goPackage, mainGoPackage string, service p
return functions, nil
}
// collectServiceUsedTypes returns the set of CamelCase message names that are
// reachable from any of the service's RPC request or response types via field
// references within the same proto file. This ensures per-service client files
// alias their own request/response types and all transitively-referenced message
// types, but never unrelated messages from other services.
func collectServiceUsedTypes(messages []parser.Message, service parser.Service) *collection.Set[string] {
messageByName := make(map[string]*proto.Message, len(messages))
for _, item := range messages {
msgName := parser.CamelCase(getMessageName(*item.Message))
messageByName[msgName] = item.Message
}
usedTypes := collection.NewSet[string]()
for _, rpc := range service.RPC {
collectMessageDependencies(rpc.RequestType, messageByName, usedTypes)
collectMessageDependencies(rpc.ReturnsType, messageByName, usedTypes)
}
return usedTypes
}
// collectMessageDependencies recursively adds protoType and all message types
// referenced by its fields into usedTypes, looking up messages by CamelCase
// name in messageByName. The cycle guard (usedTypes.Contains) prevents
// infinite recursion on circular field references.
func collectMessageDependencies(protoType string, messageByName map[string]*proto.Message,
usedTypes *collection.Set[string]) {
for _, candidate := range messageTypeCandidates(protoType) {
msg, ok := messageByName[candidate]
if !ok {
continue
}
if usedTypes.Contains(candidate) {
return
}
usedTypes.Add(candidate)
for _, elem := range msg.Elements {
switch field := elem.(type) {
case *proto.NormalField:
collectMessageDependencies(field.Type, messageByName, usedTypes)
case *proto.MapField:
// Map key types are always scalars in proto3; only the value type
// can be a message.
collectMessageDependencies(field.Type, messageByName, usedTypes)
case *proto.Oneof:
for _, oneofElem := range field.Elements {
if oneofField, ok := oneofElem.(*proto.OneOfField); ok {
collectMessageDependencies(oneofField.Type, messageByName, usedTypes)
}
}
}
}
return
}
}
// messageTypeCandidates returns the CamelCase lookup keys to try for a proto
// field type. Two candidates are produced to handle both simple names
// ("MyMsg") and dotted/qualified names ("pkg.MyMsg" → "PkgMyMsg").
func messageTypeCandidates(protoType string) []string {
protoType = strings.TrimPrefix(protoType, ".")
return []string{
parser.CamelCase(protoType),
parser.CamelCase(strings.ReplaceAll(protoType, ".", "_")),
}
}
// buildExtraImportLines converts a set of import paths into quoted import lines
// for use in the call.tpl {{.extraImports}} placeholder.
func buildExtraImportLines(extraImports *collection.Set[string]) string {
if extraImports.Count() == 0 {
return ""
}
keys := extraImports.Keys()
sort.Strings(keys)
lines := make([]string, 0, len(keys))
for _, k := range keys {
lines = append(lines, fmt.Sprintf(`"%s"`, k))
}
return strings.Join(lines, "\n\t")
if extraImports.Count() == 0 {
return ""
}
keys := extraImports.Keys()
sort.Strings(keys)
lines := make([]string, 0, len(keys))
for _, k := range keys {
lines = append(lines, fmt.Sprintf(`"%s"`, k))
}
return strings.Join(lines, "\n\t")
}