diff --git a/zrpc/client.go b/zrpc/client.go index e44abc165..fa12daafb 100644 --- a/zrpc/client.go +++ b/zrpc/client.go @@ -1,12 +1,16 @@ package zrpc import ( + "context" + "fmt" "time" "github.com/zeromicro/go-zero/core/conf" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/zrpc/internal" "github.com/zeromicro/go-zero/zrpc/internal/auth" + "github.com/zeromicro/go-zero/zrpc/internal/balancer/consistenthash" + "github.com/zeromicro/go-zero/zrpc/internal/balancer/p2c" "github.com/zeromicro/go-zero/zrpc/internal/clientinterceptors" "google.golang.org/grpc" "google.golang.org/grpc/keepalive" @@ -67,6 +71,9 @@ func NewClient(c RpcClientConf, options ...ClientOption) (Client, error) { }))) } + svcCfg := makeLBServiceConfig(c.BalancerName) + opts = append(opts, WithDialOption(grpc.WithDefaultServiceConfig(svcCfg))) + opts = append(opts, options...) target, err := c.BuildTarget() @@ -111,7 +118,20 @@ func SetClientSlowThreshold(threshold time.Duration) { clientinterceptors.SetSlowThreshold(threshold) } +// SetHashKey sets the hash key into context. +func SetHashKey(ctx context.Context, key string) context.Context { + return consistenthash.SetHashKey(ctx, key) +} + // WithCallTimeout return a call option with given timeout to make a method call. func WithCallTimeout(timeout time.Duration) grpc.CallOption { return clientinterceptors.WithCallTimeout(timeout) } + +func makeLBServiceConfig(balancerName string) string { + if len(balancerName) == 0 { + balancerName = p2c.Name + } + + return fmt.Sprintf(`{"loadBalancingPolicy":"%s"}`, balancerName) +} diff --git a/zrpc/client_test.go b/zrpc/client_test.go index 09a06cbb1..3526cad1e 100644 --- a/zrpc/client_test.go +++ b/zrpc/client_test.go @@ -12,6 +12,8 @@ import ( "github.com/zeromicro/go-zero/core/discov" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/internal/mock" + "github.com/zeromicro/go-zero/zrpc/internal/balancer/consistenthash" + "github.com/zeromicro/go-zero/zrpc/internal/balancer/p2c" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" @@ -245,3 +247,42 @@ func TestNewClientWithTarget(t *testing.T) { assert.NotNil(t, err) } + +func TestMakeLBServiceConfig(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "empty name uses default p2c", + input: "", + expected: fmt.Sprintf(`{"loadBalancingPolicy":"%s"}`, p2c.Name), + }, + { + name: "custom balancer name", + input: "consistent_hash", + expected: `{"loadBalancingPolicy":"consistent_hash"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := makeLBServiceConfig(tt.input) + if got != tt.expected { + t.Errorf("expected %q, got %q", tt.expected, got) + } + }) + } +} + +func TestSetHashKey(t *testing.T) { + ctx := context.Background() + key := "abc123" + + ctx = SetHashKey(ctx, key) + got := consistenthash.GetHashKey(ctx) + assert.Equal(t, key, got) + + assert.Empty(t, consistenthash.GetHashKey(context.Background())) +} diff --git a/zrpc/config.go b/zrpc/config.go index 5e459a3c4..0ae62af2f 100644 --- a/zrpc/config.go +++ b/zrpc/config.go @@ -31,6 +31,7 @@ type ( Timeout int64 `json:",default=2000"` KeepaliveTime time.Duration `json:",optional"` Middlewares ClientMiddlewaresConf + BalancerName string `json:",default=p2c_ewma"` } // A RpcServerConf is a rpc server config. diff --git a/zrpc/config_test.go b/zrpc/config_test.go index 0f7a637e1..d3aaa12e9 100644 --- a/zrpc/config_test.go +++ b/zrpc/config_test.go @@ -4,9 +4,11 @@ import ( "testing" "github.com/stretchr/testify/assert" + zconf "github.com/zeromicro/go-zero/core/conf" "github.com/zeromicro/go-zero/core/discov" "github.com/zeromicro/go-zero/core/service" "github.com/zeromicro/go-zero/core/stores/redis" + "github.com/zeromicro/go-zero/zrpc/internal/balancer/p2c" ) func TestRpcClientConf(t *testing.T) { @@ -39,6 +41,13 @@ func TestRpcClientConf(t *testing.T) { _, err := conf.BuildTarget() assert.Error(t, err) }) + + t.Run("default balancer name", func(t *testing.T) { + var conf RpcClientConf + err := zconf.FillDefault(&conf) + assert.NoError(t, err) + assert.Equal(t, p2c.Name, conf.BalancerName) + }) } func TestRpcServerConf(t *testing.T) { diff --git a/zrpc/internal/balancer/consistenthash/consistenthash.go b/zrpc/internal/balancer/consistenthash/consistenthash.go new file mode 100644 index 000000000..6543cd893 --- /dev/null +++ b/zrpc/internal/balancer/consistenthash/consistenthash.go @@ -0,0 +1,97 @@ +package consistenthash + +import ( + "context" + + "github.com/zeromicro/go-zero/core/hash" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/base" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +const ( + Name = "consistent_hash" + + defaultReplicaCount = 100 +) + +var emptyPickResult balancer.PickResult + +func init() { + balancer.Register(newBuilder()) +} + +type ( + // hashKey is the key type for consistent hash in context. + hashKey struct{} + // pickerBuilder is a builder for picker. + pickerBuilder struct{} + // picker is a picker that uses consistent hash to pick a sub connection. + picker struct { + hashRing *hash.ConsistentHash + conns map[string]balancer.SubConn + } +) + +func (b *pickerBuilder) Build(info base.PickerBuildInfo) balancer.Picker { + readySCs := info.ReadySCs + if len(readySCs) == 0 { + return base.NewErrPicker(balancer.ErrNoSubConnAvailable) + } + + conns := make(map[string]balancer.SubConn, len(readySCs)) + hashRing := hash.NewCustomConsistentHash(defaultReplicaCount, hash.Hash) + for conn, connInfo := range readySCs { + addr := connInfo.Address.Addr + conns[addr] = conn + hashRing.Add(addr) + } + + return &picker{ + hashRing: hashRing, + conns: conns, + } +} + +func newBuilder() balancer.Builder { + return base.NewBalancerBuilder(Name, &pickerBuilder{}, base.Config{HealthCheck: true}) +} + +func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { + hashKey := GetHashKey(info.Ctx) + if len(hashKey) == 0 { + return emptyPickResult, status.Error(codes.InvalidArgument, + "[consistent_hash] missing hash key in context") + } + + if addrAny, ok := p.hashRing.Get(hashKey); ok { + addr, ok := addrAny.(string) + if !ok { + return emptyPickResult, status.Error(codes.Internal, + "[consistent_hash] invalid addr type in consistent hash") + } + + subConn, ok := p.conns[addr] + if !ok { + return emptyPickResult, status.Errorf(codes.Internal, + "[consistent_hash] no subConn for addr: %s", addr) + } + + return balancer.PickResult{SubConn: subConn}, nil + } + + return emptyPickResult, status.Errorf(codes.Unavailable, + "[consistent_hash] no matching conn for hashKey: %s", hashKey) +} + +// SetHashKey sets the hash key into context. +func SetHashKey(ctx context.Context, key string) context.Context { + return context.WithValue(ctx, hashKey{}, key) +} + +// GetHashKey gets the hash key from context. +func GetHashKey(ctx context.Context) string { + v, _ := ctx.Value(hashKey{}).(string) + return v +} diff --git a/zrpc/internal/balancer/consistenthash/consistenthash_test.go b/zrpc/internal/balancer/consistenthash/consistenthash_test.go new file mode 100644 index 000000000..2f3524d46 --- /dev/null +++ b/zrpc/internal/balancer/consistenthash/consistenthash_test.go @@ -0,0 +1,175 @@ +package consistenthash + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/zeromicro/go-zero/core/hash" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/base" + "google.golang.org/grpc/resolver" +) + +type fakeSubConn struct{ id int } + +func (f *fakeSubConn) Connect() {} +func (f *fakeSubConn) UpdateAddresses(_ []resolver.Address) {} +func (f *fakeSubConn) Shutdown() {} +func (f *fakeSubConn) GetOrBuildProducer(b balancer.ProducerBuilder) (balancer.Producer, func()) { + return nil, func() {} +} + +func TestPickerBuilder_EmptyReadySCs(t *testing.T) { + b := &pickerBuilder{} + p := b.Build(base.PickerBuildInfo{ReadySCs: map[balancer.SubConn]base.SubConnInfo{}}) + + _, err := p.Pick(balancer.PickInfo{}) + assert.Equal(t, balancer.ErrNoSubConnAvailable, err) +} + +func TestPickerBuilder_BuildAndRing(t *testing.T) { + subConn1 := &fakeSubConn{id: 1} + subConn2 := &fakeSubConn{id: 2} + addr1 := "127.0.0.1:8080" + addr2 := "127.0.0.1:8081" + + b := &pickerBuilder{} + info := base.PickerBuildInfo{ + ReadySCs: map[balancer.SubConn]base.SubConnInfo{ + subConn1: {Address: resolver.Address{Addr: addr1}}, + subConn2: {Address: resolver.Address{Addr: addr2}}, + }, + } + + p := b.Build(info).(*picker) + assert.NotNil(t, p.hashRing) + assert.Len(t, p.conns, 2) +} + +func TestPicker_HashConsistency(t *testing.T) { + subConn1 := &fakeSubConn{id: 1} + subConn2 := &fakeSubConn{id: 2} + + pb := &pickerBuilder{} + info := base.PickerBuildInfo{ + ReadySCs: map[balancer.SubConn]base.SubConnInfo{ + subConn1: {Address: resolver.Address{Addr: "127.0.0.1:8080"}}, + subConn2: {Address: resolver.Address{Addr: "127.0.0.1:8081"}}, + }, + } + p := pb.Build(info).(*picker) + ctx := SetHashKey(context.Background(), "user_123") + res1, err := p.Pick(balancer.PickInfo{Ctx: ctx}) + assert.NoError(t, err) + assert.NotNil(t, res1.SubConn) + + // Multiple requests with the same key remain consistent + for i := 0; i < 5; i++ { + resN, err := p.Pick(balancer.PickInfo{Ctx: ctx}) + assert.NoError(t, err) + assert.Equal(t, res1.SubConn, resN.SubConn) + } +} + +func TestPicker_MissingKey(t *testing.T) { + subConn := &fakeSubConn{id: 1} + + pb := &pickerBuilder{} + info := base.PickerBuildInfo{ + ReadySCs: map[balancer.SubConn]base.SubConnInfo{ + subConn: {Address: resolver.Address{Addr: "127.0.0.1:8080"}}, + }, + } + p := pb.Build(info).(*picker) + + // No hash key in context + _, err := p.Pick(balancer.PickInfo{Ctx: context.Background()}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "[consistent_hash] missing hash key in context") +} + +func TestPicker_NoMatchingConn(t *testing.T) { + emptyRing := newCustomRingForTest() + p := &picker{ + hashRing: emptyRing, + conns: map[string]balancer.SubConn{}, + } + + _, err := p.Pick(balancer.PickInfo{Ctx: SetHashKey(context.Background(), "someone")}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "[consistent_hash] no matching conn for hashKey: someone") +} + +func TestPicker_InvalidAddrType(t *testing.T) { + ring := newCustomRingForTest() + ring.Add(12345) + + subConn := &fakeSubConn{id: 1} + p := &picker{ + hashRing: ring, + conns: map[string]balancer.SubConn{ + "12345": subConn, + }, + } + + _, err := p.Pick(balancer.PickInfo{Ctx: SetHashKey(context.Background(), "anykey")}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "[consistent_hash] invalid addr type in consistent hash") +} + +func TestPicker_NoSubConnForAddr(t *testing.T) { + ring := newCustomRingForTest() + ring.Add("ghost:9999") + + exist := &fakeSubConn{id: 1} + p := &picker{ + hashRing: ring, + conns: map[string]balancer.SubConn{ + "real:8080": exist, + }, + } + + _, err := p.Pick(balancer.PickInfo{Ctx: SetHashKey(context.Background(), "anykey")}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "[consistent_hash] no subConn for addr: ghost:9999") +} + +func TestSetAndGetHashKey(t *testing.T) { + ctx := context.Background() + key := "abc123" + + ctx = SetHashKey(ctx, key) + got := GetHashKey(ctx) + assert.Equal(t, key, got) + + assert.Empty(t, GetHashKey(context.Background())) +} + +func BenchmarkPicker_HashConsistency(b *testing.B) { + subConn1 := &fakeSubConn{id: 1} + subConn2 := &fakeSubConn{id: 2} + + pb := &pickerBuilder{} + info := base.PickerBuildInfo{ + ReadySCs: map[balancer.SubConn]base.SubConnInfo{ + subConn1: {Address: resolver.Address{Addr: "127.0.0.1:8080"}}, + subConn2: {Address: resolver.Address{Addr: "127.0.0.1:8081"}}, + }, + } + p := pb.Build(info).(*picker) + + ctx := SetHashKey(context.Background(), "hot_user_123") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + res, err := p.Pick(balancer.PickInfo{Ctx: ctx}) + if err != nil || res.SubConn == nil { + b.Fatalf("unexpected result: res=%v err=%v", res.SubConn, err) + } + } +} + +func newCustomRingForTest() *hash.ConsistentHash { + return hash.NewCustomConsistentHash(defaultReplicaCount, hash.Hash) +} diff --git a/zrpc/internal/client.go b/zrpc/internal/client.go index 0d01ae993..8d05f7fed 100644 --- a/zrpc/internal/client.go +++ b/zrpc/internal/client.go @@ -7,7 +7,6 @@ import ( "strings" "time" - "github.com/zeromicro/go-zero/zrpc/internal/balancer/p2c" "github.com/zeromicro/go-zero/zrpc/internal/clientinterceptors" "github.com/zeromicro/go-zero/zrpc/resolver" "google.golang.org/grpc" @@ -53,9 +52,6 @@ func NewClient(target string, middlewares ClientMiddlewaresConf, opts ...ClientO middlewares: middlewares, } - svcCfg := fmt.Sprintf(`{"loadBalancingPolicy":"%s"}`, p2c.Name) - balancerOpt := WithDialOption(grpc.WithDefaultServiceConfig(svcCfg)) - opts = append([]ClientOption{balancerOpt}, opts...) if err := cli.dial(target, opts...); err != nil { return nil, err }