mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-06-29 15:31:00 +08:00
Co-authored-by: Deepak kudi <deepakkudi23@adsl-172-10-9-116.dsl.sndg02.sbcglobal.net> Co-authored-by: kevin <wanjunfeng@gmail.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
476 lines
16 KiB
Go
476 lines
16 KiB
Go
package generator
|
|
|
|
import (
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/emicklei/proto"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
conf "github.com/zeromicro/go-zero/tools/goctl/config"
|
|
"github.com/zeromicro/go-zero/tools/goctl/rpc/parser"
|
|
"github.com/zeromicro/go-zero/tools/goctl/util/stringx"
|
|
)
|
|
|
|
// mockDirContext is a minimal DirContext for unit-testing genCallGroup.
|
|
type mockDirContext struct {
|
|
callDir Dir
|
|
pbDir Dir
|
|
protoGo Dir
|
|
}
|
|
|
|
func (m *mockDirContext) GetCall() Dir { return m.callDir }
|
|
func (m *mockDirContext) GetEtc() Dir { return Dir{} }
|
|
func (m *mockDirContext) GetInternal() Dir { return Dir{} }
|
|
func (m *mockDirContext) GetConfig() Dir { return Dir{} }
|
|
func (m *mockDirContext) GetLogic() Dir { return Dir{} }
|
|
func (m *mockDirContext) GetServer() Dir { return Dir{} }
|
|
func (m *mockDirContext) GetSvc() Dir { return Dir{} }
|
|
func (m *mockDirContext) GetPb() Dir { return m.pbDir }
|
|
func (m *mockDirContext) GetProtoGo() Dir { return m.protoGo }
|
|
func (m *mockDirContext) GetMain() Dir { return Dir{} }
|
|
func (m *mockDirContext) GetServiceName() stringx.String { return stringx.From("test") }
|
|
func (m *mockDirContext) SetPbDir(pbDir, grpcDir string) {}
|
|
|
|
// newTestDirContext builds a mockDirContext that writes generated files under
|
|
// callBase, with a pb directory that differs (so alias generation is triggered).
|
|
func newTestDirContext(t *testing.T, callBase, pbBase string, services ...string) *mockDirContext {
|
|
t.Helper()
|
|
for _, svc := range services {
|
|
require.NoError(t, os.MkdirAll(filepath.Join(callBase, strings.ToLower(svc)), 0755))
|
|
}
|
|
require.NoError(t, os.MkdirAll(pbBase, 0755))
|
|
return &mockDirContext{
|
|
callDir: Dir{
|
|
Filename: callBase,
|
|
Package: "example.com/test/call",
|
|
Base: "call",
|
|
GetChildPackage: func(childPath string) (string, error) {
|
|
return filepath.Join(callBase, strings.ToLower(childPath)), nil
|
|
},
|
|
},
|
|
pbDir: Dir{Filename: pbBase, Package: "example.com/test/pb", Base: "pb"},
|
|
protoGo: Dir{
|
|
// Must differ from service dir names so isCallPkgSameToPbPkg stays
|
|
// false and alias generation is triggered.
|
|
Filename: pbBase,
|
|
Package: "example.com/test/pb",
|
|
Base: "pb",
|
|
},
|
|
}
|
|
}
|
|
|
|
// ---- unit tests for collectServiceUsedTypes --------------------------------
|
|
|
|
// TestCollectServiceUsedTypes_DirectOnly verifies that request and response
|
|
// types with no message fields are collected as-is.
|
|
func TestCollectServiceUsedTypes_DirectOnly(t *testing.T) {
|
|
messages := []parser.Message{
|
|
{Message: &proto.Message{Name: "AReq"}},
|
|
{Message: &proto.Message{Name: "AResp"}},
|
|
{Message: &proto.Message{Name: "Unrelated"}},
|
|
}
|
|
service := parser.Service{
|
|
Service: &proto.Service{Name: "ServiceA"},
|
|
RPC: []*parser.RPC{
|
|
{RPC: &proto.RPC{Name: "Do", RequestType: "AReq", ReturnsType: "AResp"}},
|
|
},
|
|
}
|
|
|
|
got := collectServiceUsedTypes(messages, service)
|
|
|
|
assert.True(t, got.Contains("AReq"))
|
|
assert.True(t, got.Contains("AResp"))
|
|
assert.False(t, got.Contains("Unrelated"), "unrelated message must not be collected")
|
|
}
|
|
|
|
// TestCollectServiceUsedTypes_NestedNormalField verifies that a message type
|
|
// referenced via a NormalField inside a response is transitively collected
|
|
// (regression test for issue #5618).
|
|
func TestCollectServiceUsedTypes_NestedNormalField(t *testing.T) {
|
|
messages := []parser.Message{
|
|
{Message: &proto.Message{Name: "AReq"}},
|
|
{Message: &proto.Message{
|
|
Name: "AResp",
|
|
Elements: []proto.Visitee{
|
|
&proto.NormalField{Field: &proto.Field{Name: "items", Type: "AItem"}},
|
|
},
|
|
}},
|
|
{Message: &proto.Message{Name: "AItem"}},
|
|
}
|
|
service := parser.Service{
|
|
Service: &proto.Service{Name: "ServiceA"},
|
|
RPC: []*parser.RPC{
|
|
{RPC: &proto.RPC{Name: "List", RequestType: "AReq", ReturnsType: "AResp"}},
|
|
},
|
|
}
|
|
|
|
got := collectServiceUsedTypes(messages, service)
|
|
|
|
assert.True(t, got.Contains("AReq"))
|
|
assert.True(t, got.Contains("AResp"))
|
|
assert.True(t, got.Contains("AItem"), "field type AItem must be transitively collected")
|
|
}
|
|
|
|
// TestCollectServiceUsedTypes_MapValueField verifies that the value type of a
|
|
// MapField inside a response message is transitively collected.
|
|
func TestCollectServiceUsedTypes_MapValueField(t *testing.T) {
|
|
messages := []parser.Message{
|
|
{Message: &proto.Message{Name: "AReq"}},
|
|
{Message: &proto.Message{
|
|
Name: "AResp",
|
|
Elements: []proto.Visitee{
|
|
&proto.MapField{KeyType: "string", Field: &proto.Field{Name: "index", Type: "AItem"}},
|
|
},
|
|
}},
|
|
{Message: &proto.Message{Name: "AItem"}},
|
|
}
|
|
service := parser.Service{
|
|
Service: &proto.Service{Name: "ServiceA"},
|
|
RPC: []*parser.RPC{
|
|
{RPC: &proto.RPC{Name: "GetMap", RequestType: "AReq", ReturnsType: "AResp"}},
|
|
},
|
|
}
|
|
|
|
got := collectServiceUsedTypes(messages, service)
|
|
|
|
assert.True(t, got.Contains("AResp"))
|
|
assert.True(t, got.Contains("AItem"), "map value type AItem must be transitively collected")
|
|
}
|
|
|
|
// TestCollectServiceUsedTypes_OneofField verifies that message types referenced
|
|
// inside a Oneof element are transitively collected.
|
|
func TestCollectServiceUsedTypes_OneofField(t *testing.T) {
|
|
oneof := &proto.Oneof{Name: "result"}
|
|
oneof.Elements = []proto.Visitee{
|
|
&proto.OneOfField{Field: &proto.Field{Name: "success", Type: "SuccessMsg"}},
|
|
&proto.OneOfField{Field: &proto.Field{Name: "failure", Type: "FailureMsg"}},
|
|
}
|
|
messages := []parser.Message{
|
|
{Message: &proto.Message{Name: "AReq"}},
|
|
{Message: &proto.Message{
|
|
Name: "AResp",
|
|
Elements: []proto.Visitee{oneof},
|
|
}},
|
|
{Message: &proto.Message{Name: "SuccessMsg"}},
|
|
{Message: &proto.Message{Name: "FailureMsg"}},
|
|
}
|
|
service := parser.Service{
|
|
Service: &proto.Service{Name: "ServiceA"},
|
|
RPC: []*parser.RPC{
|
|
{RPC: &proto.RPC{Name: "Do", RequestType: "AReq", ReturnsType: "AResp"}},
|
|
},
|
|
}
|
|
|
|
got := collectServiceUsedTypes(messages, service)
|
|
|
|
assert.True(t, got.Contains("AResp"))
|
|
assert.True(t, got.Contains("SuccessMsg"), "oneof field type SuccessMsg must be collected")
|
|
assert.True(t, got.Contains("FailureMsg"), "oneof field type FailureMsg must be collected")
|
|
}
|
|
|
|
// TestCollectServiceUsedTypes_MultiLevelTransitive verifies that a chain
|
|
// AResp → BMsg → CMsg is fully collected (multi-level transitivity).
|
|
func TestCollectServiceUsedTypes_MultiLevelTransitive(t *testing.T) {
|
|
messages := []parser.Message{
|
|
{Message: &proto.Message{Name: "AReq"}},
|
|
{Message: &proto.Message{
|
|
Name: "AResp",
|
|
Elements: []proto.Visitee{
|
|
&proto.NormalField{Field: &proto.Field{Name: "b", Type: "BMsg"}},
|
|
},
|
|
}},
|
|
{Message: &proto.Message{
|
|
Name: "BMsg",
|
|
Elements: []proto.Visitee{
|
|
&proto.NormalField{Field: &proto.Field{Name: "c", Type: "CMsg"}},
|
|
},
|
|
}},
|
|
{Message: &proto.Message{Name: "CMsg"}},
|
|
}
|
|
service := parser.Service{
|
|
Service: &proto.Service{Name: "ServiceA"},
|
|
RPC: []*parser.RPC{
|
|
{RPC: &proto.RPC{Name: "Do", RequestType: "AReq", ReturnsType: "AResp"}},
|
|
},
|
|
}
|
|
|
|
got := collectServiceUsedTypes(messages, service)
|
|
|
|
assert.True(t, got.Contains("AReq"))
|
|
assert.True(t, got.Contains("AResp"))
|
|
assert.True(t, got.Contains("BMsg"), "BMsg must be transitively collected via AResp")
|
|
assert.True(t, got.Contains("CMsg"), "CMsg must be transitively collected via BMsg")
|
|
}
|
|
|
|
// TestCollectServiceUsedTypes_CycleDetection verifies that circular field
|
|
// references (AResp ↔ BMsg) do not cause infinite recursion.
|
|
func TestCollectServiceUsedTypes_CycleDetection(t *testing.T) {
|
|
messages := []parser.Message{
|
|
{Message: &proto.Message{Name: "AReq"}},
|
|
{Message: &proto.Message{
|
|
Name: "AResp",
|
|
Elements: []proto.Visitee{
|
|
&proto.NormalField{Field: &proto.Field{Name: "b", Type: "BMsg"}},
|
|
},
|
|
}},
|
|
{Message: &proto.Message{
|
|
Name: "BMsg",
|
|
Elements: []proto.Visitee{
|
|
// circular back-reference to AResp
|
|
&proto.NormalField{Field: &proto.Field{Name: "a", Type: "AResp"}},
|
|
},
|
|
}},
|
|
}
|
|
service := parser.Service{
|
|
Service: &proto.Service{Name: "ServiceA"},
|
|
RPC: []*parser.RPC{
|
|
{RPC: &proto.RPC{Name: "Do", RequestType: "AReq", ReturnsType: "AResp"}},
|
|
},
|
|
}
|
|
|
|
// Must not panic or loop; both messages are reachable.
|
|
got := collectServiceUsedTypes(messages, service)
|
|
|
|
assert.True(t, got.Contains("AResp"))
|
|
assert.True(t, got.Contains("BMsg"))
|
|
}
|
|
|
|
// TestCollectServiceUsedTypes_ExcludesUnrelatedService verifies that messages
|
|
// belonging only to another service are not included.
|
|
func TestCollectServiceUsedTypes_ExcludesUnrelatedService(t *testing.T) {
|
|
messages := []parser.Message{
|
|
{Message: &proto.Message{Name: "AReq"}},
|
|
{Message: &proto.Message{Name: "AResp"}},
|
|
{Message: &proto.Message{Name: "BReq"}},
|
|
{Message: &proto.Message{Name: "BResp"}},
|
|
}
|
|
service := parser.Service{
|
|
Service: &proto.Service{Name: "ServiceA"},
|
|
RPC: []*parser.RPC{
|
|
{RPC: &proto.RPC{Name: "DoA", RequestType: "AReq", ReturnsType: "AResp"}},
|
|
},
|
|
}
|
|
|
|
got := collectServiceUsedTypes(messages, service)
|
|
|
|
assert.True(t, got.Contains("AReq"))
|
|
assert.True(t, got.Contains("AResp"))
|
|
assert.False(t, got.Contains("BReq"), "BReq belongs to ServiceB and must be excluded")
|
|
assert.False(t, got.Contains("BResp"), "BResp belongs to ServiceB and must be excluded")
|
|
}
|
|
|
|
// ---- integration tests via genCallGroup ------------------------------------
|
|
|
|
// TestGenCallGroup_OnlyUsedTypesAliased verifies that in multi-service mode
|
|
// each generated client file aliases only its own request/response types and
|
|
// their transitive field dependencies (fix for issues #5481 and #5618).
|
|
func TestGenCallGroup_OnlyUsedTypesAliased(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
callBase := filepath.Join(tmpDir, "call")
|
|
pbBase := filepath.Join(tmpDir, "pb")
|
|
|
|
mctx := newTestDirContext(t, callBase, pbBase, "ServiceA", "ServiceB")
|
|
|
|
// ServiceA: AResp contains a NormalField of type AItem (issue #5618).
|
|
// ServiceB: BResp has no nested message fields.
|
|
// AItem must appear in ServiceA's file but not ServiceB's.
|
|
protoData := parser.Proto{
|
|
Name: "multi.proto",
|
|
PbPackage: "pb",
|
|
Message: []parser.Message{
|
|
{Message: &proto.Message{Name: "AReq"}},
|
|
{Message: &proto.Message{
|
|
Name: "AResp",
|
|
Elements: []proto.Visitee{
|
|
&proto.NormalField{Field: &proto.Field{Name: "items", Type: "AItem"}},
|
|
},
|
|
}},
|
|
{Message: &proto.Message{Name: "AItem"}},
|
|
{Message: &proto.Message{Name: "BReq"}},
|
|
{Message: &proto.Message{Name: "BResp"}},
|
|
},
|
|
Service: parser.Services{
|
|
{
|
|
Service: &proto.Service{Name: "ServiceA"},
|
|
RPC: []*parser.RPC{
|
|
{RPC: &proto.RPC{Name: "DoA", RequestType: "AReq", ReturnsType: "AResp"}},
|
|
},
|
|
},
|
|
{
|
|
Service: &proto.Service{Name: "ServiceB"},
|
|
RPC: []*parser.RPC{
|
|
{RPC: &proto.RPC{Name: "DoB", RequestType: "BReq", ReturnsType: "BResp"}},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
cfg, err := conf.NewConfig("")
|
|
require.NoError(t, err)
|
|
require.NoError(t, NewGenerator("gozero", false).genCallGroup(mctx, protoData, cfg))
|
|
|
|
aFile := normalizeWS(readGenFile(t, callBase, "servicea", "servicea.go"))
|
|
assert.Contains(t, aFile, "AReq = pb.AReq", "ServiceA must alias AReq")
|
|
assert.Contains(t, aFile, "AResp = pb.AResp", "ServiceA must alias AResp")
|
|
assert.Contains(t, aFile, "AItem = pb.AItem", "ServiceA must alias AItem (transitive NormalField)")
|
|
assert.NotContains(t, aFile, "BReq = pb.BReq", "ServiceA must not alias BReq")
|
|
assert.NotContains(t, aFile, "BResp = pb.BResp", "ServiceA must not alias BResp")
|
|
|
|
bFile := normalizeWS(readGenFile(t, callBase, "serviceb", "serviceb.go"))
|
|
assert.Contains(t, bFile, "BReq = pb.BReq", "ServiceB must alias BReq")
|
|
assert.Contains(t, bFile, "BResp = pb.BResp", "ServiceB must alias BResp")
|
|
assert.NotContains(t, bFile, "AReq = pb.AReq", "ServiceB must not alias AReq")
|
|
assert.NotContains(t, bFile, "AResp = pb.AResp", "ServiceB must not alias AResp")
|
|
assert.NotContains(t, bFile, "AItem = pb.AItem", "ServiceB must not alias AItem")
|
|
}
|
|
|
|
// TestGenCallGroup_MapValueAliased verifies that the value type of a MapField
|
|
// inside a service response is included in the generated aliases.
|
|
func TestGenCallGroup_MapValueAliased(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
callBase := filepath.Join(tmpDir, "call")
|
|
pbBase := filepath.Join(tmpDir, "pb")
|
|
|
|
mctx := newTestDirContext(t, callBase, pbBase, "ServiceA")
|
|
|
|
protoData := parser.Proto{
|
|
Name: "map.proto",
|
|
PbPackage: "pb",
|
|
Message: []parser.Message{
|
|
{Message: &proto.Message{Name: "AReq"}},
|
|
{Message: &proto.Message{
|
|
Name: "AResp",
|
|
Elements: []proto.Visitee{
|
|
&proto.MapField{KeyType: "string", Field: &proto.Field{Name: "index", Type: "AItem"}},
|
|
},
|
|
}},
|
|
{Message: &proto.Message{Name: "AItem"}},
|
|
},
|
|
Service: parser.Services{
|
|
{
|
|
Service: &proto.Service{Name: "ServiceA"},
|
|
RPC: []*parser.RPC{
|
|
{RPC: &proto.RPC{Name: "GetMap", RequestType: "AReq", ReturnsType: "AResp"}},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
cfg, err := conf.NewConfig("")
|
|
require.NoError(t, err)
|
|
require.NoError(t, NewGenerator("gozero", false).genCallGroup(mctx, protoData, cfg))
|
|
|
|
aFile := normalizeWS(readGenFile(t, callBase, "servicea", "servicea.go"))
|
|
assert.Contains(t, aFile, "AResp = pb.AResp")
|
|
assert.Contains(t, aFile, "AItem = pb.AItem", "map value type AItem must be aliased")
|
|
}
|
|
|
|
// TestGenCallGroup_OneofAliased verifies that message types referenced inside a
|
|
// Oneof element are included in the generated aliases.
|
|
func TestGenCallGroup_OneofAliased(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
callBase := filepath.Join(tmpDir, "call")
|
|
pbBase := filepath.Join(tmpDir, "pb")
|
|
|
|
mctx := newTestDirContext(t, callBase, pbBase, "ServiceA")
|
|
|
|
oneof := &proto.Oneof{Name: "result"}
|
|
oneof.Elements = []proto.Visitee{
|
|
&proto.OneOfField{Field: &proto.Field{Name: "ok", Type: "SuccessMsg"}},
|
|
&proto.OneOfField{Field: &proto.Field{Name: "err", Type: "FailureMsg"}},
|
|
}
|
|
protoData := parser.Proto{
|
|
Name: "oneof.proto",
|
|
PbPackage: "pb",
|
|
Message: []parser.Message{
|
|
{Message: &proto.Message{Name: "AReq"}},
|
|
{Message: &proto.Message{
|
|
Name: "AResp",
|
|
Elements: []proto.Visitee{oneof},
|
|
}},
|
|
{Message: &proto.Message{Name: "SuccessMsg"}},
|
|
{Message: &proto.Message{Name: "FailureMsg"}},
|
|
},
|
|
Service: parser.Services{
|
|
{
|
|
Service: &proto.Service{Name: "ServiceA"},
|
|
RPC: []*parser.RPC{
|
|
{RPC: &proto.RPC{Name: "Do", RequestType: "AReq", ReturnsType: "AResp"}},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
cfg, err := conf.NewConfig("")
|
|
require.NoError(t, err)
|
|
require.NoError(t, NewGenerator("gozero", false).genCallGroup(mctx, protoData, cfg))
|
|
|
|
aFile := normalizeWS(readGenFile(t, callBase, "servicea", "servicea.go"))
|
|
assert.Contains(t, aFile, "SuccessMsg = pb.SuccessMsg", "oneof type SuccessMsg must be aliased")
|
|
assert.Contains(t, aFile, "FailureMsg = pb.FailureMsg", "oneof type FailureMsg must be aliased")
|
|
}
|
|
|
|
// TestGenCallGroup_MultiLevelTransitiveAliased verifies that a dependency chain
|
|
// AResp → BMsg → CMsg causes all three types to be aliased in the client file.
|
|
func TestGenCallGroup_MultiLevelTransitiveAliased(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
callBase := filepath.Join(tmpDir, "call")
|
|
pbBase := filepath.Join(tmpDir, "pb")
|
|
|
|
mctx := newTestDirContext(t, callBase, pbBase, "ServiceA")
|
|
|
|
protoData := parser.Proto{
|
|
Name: "transitive.proto",
|
|
PbPackage: "pb",
|
|
Message: []parser.Message{
|
|
{Message: &proto.Message{Name: "AReq"}},
|
|
{Message: &proto.Message{
|
|
Name: "AResp",
|
|
Elements: []proto.Visitee{
|
|
&proto.NormalField{Field: &proto.Field{Name: "b", Type: "BMsg"}},
|
|
},
|
|
}},
|
|
{Message: &proto.Message{
|
|
Name: "BMsg",
|
|
Elements: []proto.Visitee{
|
|
&proto.NormalField{Field: &proto.Field{Name: "c", Type: "CMsg"}},
|
|
},
|
|
}},
|
|
{Message: &proto.Message{Name: "CMsg"}},
|
|
},
|
|
Service: parser.Services{
|
|
{
|
|
Service: &proto.Service{Name: "ServiceA"},
|
|
RPC: []*parser.RPC{
|
|
{RPC: &proto.RPC{Name: "Do", RequestType: "AReq", ReturnsType: "AResp"}},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
cfg, err := conf.NewConfig("")
|
|
require.NoError(t, err)
|
|
require.NoError(t, NewGenerator("gozero", false).genCallGroup(mctx, protoData, cfg))
|
|
|
|
aFile := normalizeWS(readGenFile(t, callBase, "servicea", "servicea.go"))
|
|
assert.Contains(t, aFile, "AResp = pb.AResp")
|
|
assert.Contains(t, aFile, "BMsg = pb.BMsg", "BMsg must be transitively aliased via AResp")
|
|
assert.Contains(t, aFile, "CMsg = pb.CMsg", "CMsg must be transitively aliased via BMsg")
|
|
}
|
|
|
|
// readGenFile reads a generated file relative to callBase and returns its content.
|
|
func readGenFile(t *testing.T, callBase string, parts ...string) string {
|
|
t.Helper()
|
|
content, err := os.ReadFile(filepath.Join(append([]string{callBase}, parts...)...))
|
|
require.NoError(t, err)
|
|
return string(content)
|
|
}
|
|
|
|
// normalizeWS replaces runs of whitespace with a single space.
|
|
func normalizeWS(s string) string {
|
|
return strings.Join(strings.Fields(strings.ReplaceAll(s, "\n", " \n ")), " ")
|
|
}
|