mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-07-03 01:01:50 +08:00
fix(goctl): include nested client aliases (#5627)
Co-authored-by: Deepak kudi <deepakkudi23@adsl-172-10-9-116.dsl.sndg02.sbcglobal.net> Co-authored-by: kevin <wanjunfeng@gmail.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -67,12 +67,9 @@ func (g *Generator) genCallGroup(ctx DirContext, proto parser.Proto, cfg *conf.C
|
||||
serviceName := stringx.From(service.Name).ToCamel()
|
||||
|
||||
// Collect only the message types actually used by this service's RPCs,
|
||||
// so that each client file only aliases its own request/response types.
|
||||
usedTypes := collection.NewSet[string]()
|
||||
for _, rpc := range service.RPC {
|
||||
usedTypes.Add(parser.CamelCase(rpc.RequestType))
|
||||
usedTypes.Add(parser.CamelCase(rpc.ReturnsType))
|
||||
}
|
||||
// so that each client file only aliases its own request/response types
|
||||
// and their same-file message dependencies.
|
||||
usedTypes := collectServiceUsedTypes(proto.Message, service)
|
||||
|
||||
alias := collection.NewSet[string]()
|
||||
var hasSameNameBetweenMessageAndService bool
|
||||
@@ -337,17 +334,85 @@ func (g *Generator) getInterfaceFuncs(goPackage, mainGoPackage string, service p
|
||||
return functions, nil
|
||||
}
|
||||
|
||||
// collectServiceUsedTypes returns the set of CamelCase message names that are
|
||||
// reachable from any of the service's RPC request or response types via field
|
||||
// references within the same proto file. This ensures per-service client files
|
||||
// alias their own request/response types and all transitively-referenced message
|
||||
// types, but never unrelated messages from other services.
|
||||
func collectServiceUsedTypes(messages []parser.Message, service parser.Service) *collection.Set[string] {
|
||||
messageByName := make(map[string]*proto.Message, len(messages))
|
||||
for _, item := range messages {
|
||||
msgName := parser.CamelCase(getMessageName(*item.Message))
|
||||
messageByName[msgName] = item.Message
|
||||
}
|
||||
|
||||
usedTypes := collection.NewSet[string]()
|
||||
for _, rpc := range service.RPC {
|
||||
collectMessageDependencies(rpc.RequestType, messageByName, usedTypes)
|
||||
collectMessageDependencies(rpc.ReturnsType, messageByName, usedTypes)
|
||||
}
|
||||
|
||||
return usedTypes
|
||||
}
|
||||
|
||||
// collectMessageDependencies recursively adds protoType and all message types
|
||||
// referenced by its fields into usedTypes, looking up messages by CamelCase
|
||||
// name in messageByName. The cycle guard (usedTypes.Contains) prevents
|
||||
// infinite recursion on circular field references.
|
||||
func collectMessageDependencies(protoType string, messageByName map[string]*proto.Message,
|
||||
usedTypes *collection.Set[string]) {
|
||||
for _, candidate := range messageTypeCandidates(protoType) {
|
||||
msg, ok := messageByName[candidate]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if usedTypes.Contains(candidate) {
|
||||
return
|
||||
}
|
||||
|
||||
usedTypes.Add(candidate)
|
||||
for _, elem := range msg.Elements {
|
||||
switch field := elem.(type) {
|
||||
case *proto.NormalField:
|
||||
collectMessageDependencies(field.Type, messageByName, usedTypes)
|
||||
case *proto.MapField:
|
||||
// Map key types are always scalars in proto3; only the value type
|
||||
// can be a message.
|
||||
collectMessageDependencies(field.Type, messageByName, usedTypes)
|
||||
case *proto.Oneof:
|
||||
for _, oneofElem := range field.Elements {
|
||||
if oneofField, ok := oneofElem.(*proto.OneOfField); ok {
|
||||
collectMessageDependencies(oneofField.Type, messageByName, usedTypes)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// messageTypeCandidates returns the CamelCase lookup keys to try for a proto
|
||||
// field type. Two candidates are produced to handle both simple names
|
||||
// ("MyMsg") and dotted/qualified names ("pkg.MyMsg" → "PkgMyMsg").
|
||||
func messageTypeCandidates(protoType string) []string {
|
||||
protoType = strings.TrimPrefix(protoType, ".")
|
||||
return []string{
|
||||
parser.CamelCase(protoType),
|
||||
parser.CamelCase(strings.ReplaceAll(protoType, ".", "_")),
|
||||
}
|
||||
}
|
||||
|
||||
// buildExtraImportLines converts a set of import paths into quoted import lines
|
||||
// for use in the call.tpl {{.extraImports}} placeholder.
|
||||
func buildExtraImportLines(extraImports *collection.Set[string]) string {
|
||||
if extraImports.Count() == 0 {
|
||||
return ""
|
||||
}
|
||||
keys := extraImports.Keys()
|
||||
sort.Strings(keys)
|
||||
lines := make([]string, 0, len(keys))
|
||||
for _, k := range keys {
|
||||
lines = append(lines, fmt.Sprintf(`"%s"`, k))
|
||||
}
|
||||
return strings.Join(lines, "\n\t")
|
||||
if extraImports.Count() == 0 {
|
||||
return ""
|
||||
}
|
||||
keys := extraImports.Keys()
|
||||
sort.Strings(keys)
|
||||
lines := make([]string, 0, len(keys))
|
||||
for _, k := range keys {
|
||||
lines = append(lines, fmt.Sprintf(`"%s"`, k))
|
||||
}
|
||||
return strings.Join(lines, "\n\t")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user