diff --git a/rpcx/internal/balancer/p2c/p2c_test.go b/rpcx/internal/balancer/p2c/p2c_test.go index 2ca1e6c63..5b8a2d3dc 100644 --- a/rpcx/internal/balancer/p2c/p2c_test.go +++ b/rpcx/internal/balancer/p2c/p2c_test.go @@ -4,7 +4,9 @@ import ( "context" "fmt" "strconv" + "sync" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/tal-tech/go-zero/core/logx" @@ -38,6 +40,10 @@ func TestP2cPicker_Pick(t *testing.T) { name: "single", candidates: 1, }, + { + name: "two", + candidates: 2, + }, { name: "multiple", candidates: 100, @@ -46,6 +52,7 @@ func TestP2cPicker_Pick(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + const total = 10000 builder := new(p2cPickerBuilder) ready := make(map[resolver.Address]balancer.SubConn) for i := 0; i < test.candidates; i++ { @@ -55,7 +62,9 @@ func TestP2cPicker_Pick(t *testing.T) { } picker := builder.Build(ready) - for i := 0; i < 10000; i++ { + var wg sync.WaitGroup + wg.Add(total) + for i := 0; i < total; i++ { _, done, err := picker.Pick(context.Background(), balancer.PickInfo{ FullMethodName: "/", Ctx: context.Background(), @@ -64,11 +73,16 @@ func TestP2cPicker_Pick(t *testing.T) { if i%100 == 0 { err = status.Error(codes.DeadlineExceeded, "deadline") } - done(balancer.DoneInfo{ - Err: err, - }) + go func() { + time.Sleep(time.Millisecond) + done(balancer.DoneInfo{ + Err: err, + }) + wg.Done() + }() } + wg.Wait() dist := make(map[interface{}]int) conns := picker.(*p2cPicker).conns for _, conn := range conns { diff --git a/rpcx/internal/resolver/directbuilder.go b/rpcx/internal/resolver/directbuilder.go index dd30a5e97..2db89d203 100644 --- a/rpcx/internal/resolver/directbuilder.go +++ b/rpcx/internal/resolver/directbuilder.go @@ -11,7 +11,9 @@ type directBuilder struct{} func (d *directBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) ( resolver.Resolver, error) { var addrs []resolver.Address - endpoints := strings.Split(target.Endpoint, EndpointSep) + endpoints := strings.FieldsFunc(target.Endpoint, func(r rune) bool { + return r == EndpointSep + }) for _, val := range subset(endpoints, subsetSize) { addrs = append(addrs, resolver.Address{ diff --git a/rpcx/internal/resolver/directbuilder_test.go b/rpcx/internal/resolver/directbuilder_test.go new file mode 100644 index 000000000..2aebc8e45 --- /dev/null +++ b/rpcx/internal/resolver/directbuilder_test.go @@ -0,0 +1,52 @@ +package resolver + +import ( + "fmt" + "strconv" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/core/lang" + "github.com/tal-tech/go-zero/core/mathx" + "google.golang.org/grpc/resolver" +) + +func TestDirectBuilder_Build(t *testing.T) { + tests := []int{ + 0, + 1, + 2, + subsetSize / 2, + subsetSize, + subsetSize * 2, + } + + for _, test := range tests { + t.Run(strconv.Itoa(test), func(t *testing.T) { + var servers []string + for i := 0; i < test; i++ { + servers = append(servers, fmt.Sprintf("localhost:%d", i)) + } + var b directBuilder + cc := new(mockedClientConn) + _, err := b.Build(resolver.Target{ + Scheme: DirectScheme, + Endpoint: strings.Join(servers, ","), + }, cc, resolver.BuildOptions{}) + assert.Nil(t, err) + size := mathx.MinInt(test, subsetSize) + assert.Equal(t, size, len(cc.state.Addresses)) + m := make(map[string]lang.PlaceholderType) + for _, each := range cc.state.Addresses { + m[each.Addr] = lang.Placeholder + } + assert.Equal(t, size, len(m)) + }) + } +} + +func TestDirectBuilder_Scheme(t *testing.T) { + var b directBuilder + assert.Equal(t, DirectScheme, b.Scheme()) +} diff --git a/rpcx/internal/resolver/discovbuilder.go b/rpcx/internal/resolver/discovbuilder.go index e770a78b4..103e94195 100644 --- a/rpcx/internal/resolver/discovbuilder.go +++ b/rpcx/internal/resolver/discovbuilder.go @@ -11,7 +11,9 @@ type discovBuilder struct{} func (d *discovBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) ( resolver.Resolver, error) { - hosts := strings.Split(target.Authority, EndpointSep) + hosts := strings.FieldsFunc(target.Authority, func(r rune) bool { + return r == EndpointSep + }) sub, err := discov.NewSubscriber(hosts, target.Endpoint) if err != nil { return nil, err diff --git a/rpcx/internal/resolver/resolver.go b/rpcx/internal/resolver/resolver.go index 635a8615a..d1622358e 100644 --- a/rpcx/internal/resolver/resolver.go +++ b/rpcx/internal/resolver/resolver.go @@ -5,7 +5,7 @@ import "google.golang.org/grpc/resolver" const ( DirectScheme = "direct" DiscovScheme = "discov" - EndpointSep = "," + EndpointSep = ',' subsetSize = 32 ) diff --git a/rpcx/internal/resolver/resolver_test.go b/rpcx/internal/resolver/resolver_test.go new file mode 100644 index 000000000..eac31b806 --- /dev/null +++ b/rpcx/internal/resolver/resolver_test.go @@ -0,0 +1,36 @@ +package resolver + +import ( + "testing" + + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/serviceconfig" +) + +func TestNopResolver(t *testing.T) { + // make sure ResolveNow & Close don't panic + var r nopResolver + r.ResolveNow(resolver.ResolveNowOptions{}) + r.Close() +} + +type mockedClientConn struct { + state resolver.State +} + +func (m *mockedClientConn) UpdateState(state resolver.State) { + m.state = state +} + +func (m *mockedClientConn) ReportError(err error) { +} + +func (m *mockedClientConn) NewAddress(addresses []resolver.Address) { +} + +func (m *mockedClientConn) NewServiceConfig(serviceConfig string) { +} + +func (m *mockedClientConn) ParseServiceConfig(serviceConfigJSON string) *serviceconfig.ParseResult { + return nil +} diff --git a/rpcx/internal/target.go b/rpcx/internal/target.go index d401472c1..52c7b2a72 100644 --- a/rpcx/internal/target.go +++ b/rpcx/internal/target.go @@ -8,10 +8,11 @@ import ( ) func BuildDirectTarget(endpoints []string) string { - return fmt.Sprintf("%s:///%s", resolver.DirectScheme, strings.Join(endpoints, resolver.EndpointSep)) + return fmt.Sprintf("%s:///%s", resolver.DirectScheme, strings.Join( + endpoints, fmt.Sprint(resolver.EndpointSep))) } func BuildDiscovTarget(endpoints []string, key string) string { - return fmt.Sprintf("%s://%s/%s", resolver.DiscovScheme, - strings.Join(endpoints, resolver.EndpointSep), key) + return fmt.Sprintf("%s://%s/%s", resolver.DiscovScheme, strings.Join( + endpoints, fmt.Sprint(resolver.EndpointSep)), key) }