From 3c9b6335fb18a7c7cfe8125c66e215656de47a57 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Fri, 18 Jul 2025 21:25:40 +0800 Subject: [PATCH] chore: refactor set in collection package (#5016) --- core/collection/set.go | 338 ++---------------- core/collection/set_test.go | 154 +++----- rest/handler/tracehandler.go | 4 +- .../serverinterceptors/statinterceptor.go | 8 +- .../statinterceptor_test.go | 11 +- 5 files changed, 74 insertions(+), 441 deletions(-) diff --git a/core/collection/set.go b/core/collection/set.go index 396a7eae2..41fbb3cb5 100644 --- a/core/collection/set.go +++ b/core/collection/set.go @@ -1,85 +1,45 @@ package collection -import ( - "github.com/zeromicro/go-zero/core/lang" - "github.com/zeromicro/go-zero/core/logx" -) +import "github.com/zeromicro/go-zero/core/lang" -const ( - unmanaged = iota - untyped - intType - int64Type - uintType - uint64Type - stringType -) - -// TypedSet is a type-safe generic set collection. It's not thread-safe, -// use with synchronization for concurrent access. -// -// Advantages over the legacy Set: -// - Compile-time type safety (no runtime type validation needed) -// - Better performance (no type assertions or reflection overhead) -// - Cleaner API (single Add method instead of multiple type-specific methods) -// - No need for type-specific Keys methods (KeysInt, KeysStr, etc.) -// - Zero-allocation for empty checks and direct type access -type TypedSet[T comparable] struct { +// Set is a type-safe generic set collection. +// It's not thread-safe, use with synchronization for concurrent access. +type Set[T comparable] struct { data map[T]lang.PlaceholderType } -// NewTypedSet returns a new type-safe set. -func NewTypedSet[T comparable]() *TypedSet[T] { - return &TypedSet[T]{ +// NewSet returns a new type-safe set. +func NewSet[T comparable]() *Set[T] { + return &Set[T]{ data: make(map[T]lang.PlaceholderType), } } -// NewIntSet returns a new int-typed set. -func NewIntSet() *TypedSet[int] { - return NewTypedSet[int]() -} - -// NewInt64Set returns a new int64-typed set. -func NewInt64Set() *TypedSet[int64] { - return NewTypedSet[int64]() -} - -// NewUintSet returns a new uint-typed set. -func NewUintSet() *TypedSet[uint] { - return NewTypedSet[uint]() -} - -// NewUint64Set returns a new uint64-typed set. -func NewUint64Set() *TypedSet[uint64] { - return NewTypedSet[uint64]() -} - -// NewStringSet returns a new string-typed set. -func NewStringSet() *TypedSet[string] { - return NewTypedSet[string]() -} - // Add adds items to the set. Duplicates are automatically ignored. -func (s *TypedSet[T]) Add(items ...T) { +func (s *Set[T]) Add(items ...T) { for _, item := range items { s.data[item] = lang.Placeholder } } +// Clear removes all items from the set. +func (s *Set[T]) Clear() { + clear(s.data) +} + // Contains checks if an item exists in the set. -func (s *TypedSet[T]) Contains(item T) bool { +func (s *Set[T]) Contains(item T) bool { _, ok := s.data[item] return ok } -// Remove removes an item from the set. -func (s *TypedSet[T]) Remove(item T) { - delete(s.data, item) +// Count returns the number of items in the set. +func (s *Set[T]) Count() int { + return len(s.data) } // Keys returns all elements in the set as a slice. -func (s *TypedSet[T]) Keys() []T { +func (s *Set[T]) Keys() []T { keys := make([]T, 0, len(s.data)) for key := range s.data { keys = append(keys, key) @@ -87,263 +47,7 @@ func (s *TypedSet[T]) Keys() []T { return keys } -// Count returns the number of items in the set. -func (s *TypedSet[T]) Count() int { - return len(s.data) -} - -// Clear removes all items from the set. -func (s *TypedSet[T]) Clear() { - s.data = make(map[T]lang.PlaceholderType) -} - -// Set is not thread-safe, for concurrent use, make sure to use it with synchronization. -// Deprecated: Use TypedSet[T] instead for better type safety and performance. -// TypedSet provides compile-time type checking and eliminates the need for type-specific methods. -type Set struct { - data map[any]lang.PlaceholderType - tp int -} - -// NewSet returns a managed Set, can only put the values with the same type. -// Deprecated: Use NewTypedSet[T]() instead for better type safety and performance. -// Example: NewIntSet() instead of NewSet() with AddInt() -func NewSet() *Set { - return &Set{ - data: make(map[any]lang.PlaceholderType), - tp: untyped, - } -} - -// NewUnmanagedSet returns an unmanaged Set, which can put values with different types. -// Deprecated: Use TypedSet[any] or multiple TypedSet instances for different types instead. -// If you really need mixed types, consider using map[any]struct{} directly. -func NewUnmanagedSet() *Set { - return &Set{ - data: make(map[any]lang.PlaceholderType), - tp: unmanaged, - } -} - -// Add adds i into s. -// Deprecated: Use TypedSet[T].Add() instead for better type safety and performance. -func (s *Set) Add(i ...any) { - for _, each := range i { - s.add(each) - } -} - -// AddInt adds int values ii into s. -// Deprecated: Use NewIntSet().Add() instead for better type safety and performance. -// Example: intSet := NewIntSet(); intSet.Add(1, 2, 3) -func (s *Set) AddInt(ii ...int) { - for _, each := range ii { - s.add(each) - } -} - -// AddInt64 adds int64 values ii into s. -// Deprecated: Use NewInt64Set().Add() instead for better type safety and performance. -// Example: int64Set := NewInt64Set(); int64Set.Add(1, 2, 3) -func (s *Set) AddInt64(ii ...int64) { - for _, each := range ii { - s.add(each) - } -} - -// AddUint adds uint values ii into s. -// Deprecated: Use NewUintSet().Add() instead for better type safety and performance. -// Example: uintSet := NewUintSet(); uintSet.Add(1, 2, 3) -func (s *Set) AddUint(ii ...uint) { - for _, each := range ii { - s.add(each) - } -} - -// AddUint64 adds uint64 values ii into s. -// Deprecated: Use NewUint64Set().Add() instead for better type safety and performance. -// Example: uint64Set := NewUint64Set(); uint64Set.Add(1, 2, 3) -func (s *Set) AddUint64(ii ...uint64) { - for _, each := range ii { - s.add(each) - } -} - -// AddStr adds string values ss into s. -// Deprecated: Use NewStringSet().Add() instead for better type safety and performance. -// Example: stringSet := NewStringSet(); stringSet.Add("a", "b", "c") -func (s *Set) AddStr(ss ...string) { - for _, each := range ss { - s.add(each) - } -} - -// Contains checks if i is in s. -// Deprecated: Use TypedSet[T].Contains() instead for better type safety and performance. -func (s *Set) Contains(i any) bool { - if len(s.data) == 0 { - return false - } - - s.validate(i) - _, ok := s.data[i] - return ok -} - -// Keys returns the keys in s. -// Deprecated: Use TypedSet[T].Keys() instead for better type safety and performance. -func (s *Set) Keys() []any { - var keys []any - - for key := range s.data { - keys = append(keys, key) - } - - return keys -} - -// KeysInt returns the int keys in s. -// Deprecated: Use NewIntSet().Keys() instead for better type safety and performance. -// The TypedSet version returns []int directly without type casting. -func (s *Set) KeysInt() []int { - var keys []int - - for key := range s.data { - if intKey, ok := key.(int); ok { - keys = append(keys, intKey) - } - } - - return keys -} - -// KeysInt64 returns int64 keys in s. -// Deprecated: Use NewInt64Set().Keys() instead for better type safety and performance. -// The TypedSet version returns []int64 directly without type casting. -func (s *Set) KeysInt64() []int64 { - var keys []int64 - - for key := range s.data { - if intKey, ok := key.(int64); ok { - keys = append(keys, intKey) - } - } - - return keys -} - -// KeysUint returns uint keys in s. -// Deprecated: Use NewUintSet().Keys() instead for better type safety and performance. -// The TypedSet version returns []uint directly without type casting. -func (s *Set) KeysUint() []uint { - var keys []uint - - for key := range s.data { - if intKey, ok := key.(uint); ok { - keys = append(keys, intKey) - } - } - - return keys -} - -// KeysUint64 returns uint64 keys in s. -// -// Deprecated: Use NewUint64Set().Keys() instead for better type safety and performance. -// The TypedSet version returns []uint64 directly without type casting. -func (s *Set) KeysUint64() []uint64 { - var keys []uint64 - - for key := range s.data { - if intKey, ok := key.(uint64); ok { - keys = append(keys, intKey) - } - } - - return keys -} - -// KeysStr returns string keys in s. -// Deprecated: Use NewStringSet().Keys() instead for better type safety and performance. -// The TypedSet version returns []string directly without type casting. -func (s *Set) KeysStr() []string { - var keys []string - - for key := range s.data { - if strKey, ok := key.(string); ok { - keys = append(keys, strKey) - } - } - - return keys -} - -// Remove removes i from s. -// Deprecated: Use TypedSet[T].Remove() instead for better type safety and performance. -func (s *Set) Remove(i any) { - s.validate(i) - delete(s.data, i) -} - -// Count returns the number of items in s. -// Deprecated: Use TypedSet[T].Count() instead for better type safety and performance. -func (s *Set) Count() int { - return len(s.data) -} - -func (s *Set) add(i any) { - switch s.tp { - case unmanaged: - // do nothing - case untyped: - s.setType(i) - default: - s.validate(i) - } - s.data[i] = lang.Placeholder -} - -func (s *Set) setType(i any) { - // s.tp can only be untyped here - switch i.(type) { - case int: - s.tp = intType - case int64: - s.tp = int64Type - case uint: - s.tp = uintType - case uint64: - s.tp = uint64Type - case string: - s.tp = stringType - } -} - -func (s *Set) validate(i any) { - if s.tp == unmanaged { - return - } - - switch i.(type) { - case int: - if s.tp != intType { - logx.Errorf("element is int, but set contains elements with type %d", s.tp) - } - case int64: - if s.tp != int64Type { - logx.Errorf("element is int64, but set contains elements with type %d", s.tp) - } - case uint: - if s.tp != uintType { - logx.Errorf("element is uint, but set contains elements with type %d", s.tp) - } - case uint64: - if s.tp != uint64Type { - logx.Errorf("element is uint64, but set contains elements with type %d", s.tp) - } - case string: - if s.tp != stringType { - logx.Errorf("element is string, but set contains elements with type %d", s.tp) - } - } +// Remove removes an item from the set. +func (s *Set[T]) Remove(item T) { + delete(s.data, item) } diff --git a/core/collection/set_test.go b/core/collection/set_test.go index 6cb7e4692..11397f961 100644 --- a/core/collection/set_test.go +++ b/core/collection/set_test.go @@ -12,9 +12,9 @@ func init() { logx.Disable() } -// TypedSet functionality tests +// Set functionality tests func TestTypedSetInt(t *testing.T) { - set := NewIntSet() + set := NewSet[int]() values := []int{1, 2, 3, 2, 1} // Contains duplicates // Test adding @@ -39,7 +39,7 @@ func TestTypedSetInt(t *testing.T) { } func TestTypedSetStringOps(t *testing.T) { - set := NewStringSet() + set := NewSet[string]() values := []string{"a", "b", "c", "b", "a"} set.Add(values...) @@ -56,7 +56,7 @@ func TestTypedSetStringOps(t *testing.T) { } func TestTypedSetClear(t *testing.T) { - set := NewIntSet() + set := NewSet[int]() set.Add(1, 2, 3) assert.Equal(t, 3, set.Count()) @@ -66,7 +66,7 @@ func TestTypedSetClear(t *testing.T) { } func TestTypedSetEmpty(t *testing.T) { - set := NewIntSet() + set := NewSet[int]() assert.Equal(t, 0, set.Count()) assert.False(t, set.Contains(1)) assert.Empty(t, set.Keys()) @@ -74,16 +74,16 @@ func TestTypedSetEmpty(t *testing.T) { func TestTypedSetMultipleTypes(t *testing.T) { // Test different typed generic sets - intSet := NewIntSet() - int64Set := NewInt64Set() - uintSet := NewUintSet() - uint64Set := NewUint64Set() - stringSet := NewStringSet() + intSet := NewSet[int]() + int64Set := NewSet[int64]() + uintSet := NewSet[uint]() + uint64Set := NewSet[uint64]() + stringSet := NewSet[string]() intSet.Add(1, 2, 3) - int64Set.Add(int64(1), int64(2), int64(3)) - uintSet.Add(uint(1), uint(2), uint(3)) - uint64Set.Add(uint64(1), uint64(2), uint64(3)) + int64Set.Add(1, 2, 3) + uintSet.Add(1, 2, 3) + uint64Set.Add(1, 2, 3) stringSet.Add("1", "2", "3") assert.Equal(t, 3, intSet.Count()) @@ -93,9 +93,9 @@ func TestTypedSetMultipleTypes(t *testing.T) { assert.Equal(t, 3, stringSet.Count()) } -// TypedSet benchmarks +// Set benchmarks func BenchmarkTypedIntSet(b *testing.B) { - s := NewIntSet() + s := NewSet[int]() for i := 0; i < b.N; i++ { s.Add(i) _ = s.Contains(i) @@ -103,7 +103,7 @@ func BenchmarkTypedIntSet(b *testing.B) { } func BenchmarkTypedStringSet(b *testing.B) { - s := NewStringSet() + s := NewSet[string]() for i := 0; i < b.N; i++ { s.Add(string(rune(i))) _ = s.Contains(string(rune(i))) @@ -119,26 +119,10 @@ func BenchmarkRawSet(b *testing.B) { } } -func BenchmarkUnmanagedSet(b *testing.B) { - s := NewUnmanagedSet() - for i := 0; i < b.N; i++ { - s.Add(i) - _ = s.Contains(i) - } -} - -func BenchmarkSet(b *testing.B) { - s := NewSet() - for i := 0; i < b.N; i++ { - s.AddInt(i) - _ = s.Contains(i) - } -} - func TestAdd(t *testing.T) { // given - set := NewUnmanagedSet() - values := []any{1, 2, 3} + set := NewSet[int]() + values := []int{1, 2, 3} // when set.Add(values...) @@ -150,82 +134,74 @@ func TestAdd(t *testing.T) { func TestAddInt(t *testing.T) { // given - set := NewSet() + set := NewSet[int]() values := []int{1, 2, 3} // when - set.AddInt(values...) + set.Add(values...) // then assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3)) - keys := set.KeysInt() + keys := set.Keys() sort.Ints(keys) assert.EqualValues(t, values, keys) } func TestAddInt64(t *testing.T) { // given - set := NewSet() + set := NewSet[int64]() values := []int64{1, 2, 3} // when - set.AddInt64(values...) + set.Add(values...) // then - assert.True(t, set.Contains(int64(1)) && set.Contains(int64(2)) && set.Contains(int64(3))) - assert.Equal(t, len(values), len(set.KeysInt64())) + assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3)) + assert.Equal(t, len(values), len(set.Keys())) } func TestAddUint(t *testing.T) { // given - set := NewSet() + set := NewSet[uint]() values := []uint{1, 2, 3} // when - set.AddUint(values...) + set.Add(values...) // then - assert.True(t, set.Contains(uint(1)) && set.Contains(uint(2)) && set.Contains(uint(3))) - assert.Equal(t, len(values), len(set.KeysUint())) + assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3)) + assert.Equal(t, len(values), len(set.Keys())) } func TestAddUint64(t *testing.T) { // given - set := NewSet() + set := NewSet[uint64]() values := []uint64{1, 2, 3} // when - set.AddUint64(values...) + set.Add(values...) // then - assert.True(t, set.Contains(uint64(1)) && set.Contains(uint64(2)) && set.Contains(uint64(3))) - assert.Equal(t, len(values), len(set.KeysUint64())) + assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3)) + assert.Equal(t, len(values), len(set.Keys())) } func TestAddStr(t *testing.T) { // given - set := NewSet() + set := NewSet[string]() values := []string{"1", "2", "3"} // when - set.AddStr(values...) + set.Add(values...) // then assert.True(t, set.Contains("1") && set.Contains("2") && set.Contains("3")) - assert.Equal(t, len(values), len(set.KeysStr())) + assert.Equal(t, len(values), len(set.Keys())) } func TestContainsWithoutElements(t *testing.T) { // given - set := NewSet() - - // then - assert.False(t, set.Contains(1)) -} - -func TestContainsUnmanagedWithoutElements(t *testing.T) { - // given - set := NewUnmanagedSet() + set := NewSet[int]() // then assert.False(t, set.Contains(1)) @@ -233,8 +209,8 @@ func TestContainsUnmanagedWithoutElements(t *testing.T) { func TestRemove(t *testing.T) { // given - set := NewSet() - set.Add([]any{1, 2, 3}...) + set := NewSet[int]() + set.Add([]int{1, 2, 3}...) // when set.Remove(2) @@ -245,57 +221,9 @@ func TestRemove(t *testing.T) { func TestCount(t *testing.T) { // given - set := NewSet() - set.Add([]any{1, 2, 3}...) + set := NewSet[int]() + set.Add([]int{1, 2, 3}...) // then assert.Equal(t, set.Count(), 3) } - -func TestKeysIntMismatch(t *testing.T) { - set := NewSet() - set.add(int64(1)) - set.add(2) - vals := set.KeysInt() - assert.EqualValues(t, []int{2}, vals) -} - -func TestKeysInt64Mismatch(t *testing.T) { - set := NewSet() - set.add(1) - set.add(int64(2)) - vals := set.KeysInt64() - assert.EqualValues(t, []int64{2}, vals) -} - -func TestKeysUintMismatch(t *testing.T) { - set := NewSet() - set.add(1) - set.add(uint(2)) - vals := set.KeysUint() - assert.EqualValues(t, []uint{2}, vals) -} - -func TestKeysUint64Mismatch(t *testing.T) { - set := NewSet() - set.add(1) - set.add(uint64(2)) - vals := set.KeysUint64() - assert.EqualValues(t, []uint64{2}, vals) -} - -func TestKeysStrMismatch(t *testing.T) { - set := NewSet() - set.add(1) - set.add("2") - vals := set.KeysStr() - assert.EqualValues(t, []string{"2"}, vals) -} - -func TestSetType(t *testing.T) { - set := NewUnmanagedSet() - set.add(1) - set.add("2") - vals := set.Keys() - assert.ElementsMatch(t, []any{1, "2"}, vals) -} diff --git a/rest/handler/tracehandler.go b/rest/handler/tracehandler.go index 7a11da462..5df49db2e 100644 --- a/rest/handler/tracehandler.go +++ b/rest/handler/tracehandler.go @@ -29,8 +29,8 @@ func TraceHandler(serviceName, path string, opts ...TraceOption) func(http.Handl opt(&options) } - ignorePaths := collection.NewSet() - ignorePaths.AddStr(options.traceIgnorePaths...) + ignorePaths := collection.NewSet[string]() + ignorePaths.Add(options.traceIgnorePaths...) return func(next http.Handler) http.Handler { tracer := otel.Tracer(trace.TraceName) diff --git a/zrpc/internal/serverinterceptors/statinterceptor.go b/zrpc/internal/serverinterceptors/statinterceptor.go index b98721f55..32279ba5a 100644 --- a/zrpc/internal/serverinterceptors/statinterceptor.go +++ b/zrpc/internal/serverinterceptors/statinterceptor.go @@ -43,8 +43,8 @@ func SetSlowThreshold(threshold time.Duration) { // UnaryStatInterceptor returns a func that uses given metrics to report stats. func UnaryStatInterceptor(metrics *stat.Metrics, conf StatConf) grpc.UnaryServerInterceptor { - staticNotLoggingContentMethods := collection.NewSet() - staticNotLoggingContentMethods.AddStr(conf.IgnoreContentMethods...) + staticNotLoggingContentMethods := collection.NewSet[string]() + staticNotLoggingContentMethods.Add(conf.IgnoreContentMethods...) return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { @@ -68,7 +68,7 @@ func isSlow(duration, durationThreshold time.Duration) bool { } func logDuration(ctx context.Context, method string, req any, duration time.Duration, - ignoreMethods *collection.Set, durationThreshold time.Duration) { + ignoreMethods *collection.Set[string], durationThreshold time.Duration) { var addr string client, ok := peer.FromContext(ctx) if ok { @@ -92,7 +92,7 @@ func logDuration(ctx context.Context, method string, req any, duration time.Dura } } -func shouldLogContent(method string, ignoreMethods *collection.Set) bool { +func shouldLogContent(method string, ignoreMethods *collection.Set[string]) bool { _, ok := ignoreContentMethods.Load(method) return !ok && !ignoreMethods.Contains(method) } diff --git a/zrpc/internal/serverinterceptors/statinterceptor_test.go b/zrpc/internal/serverinterceptors/statinterceptor_test.go index 322446e26..1fa4f2325 100644 --- a/zrpc/internal/serverinterceptors/statinterceptor_test.go +++ b/zrpc/internal/serverinterceptors/statinterceptor_test.go @@ -88,7 +88,7 @@ func TestLogDuration(t *testing.T) { assert.NotPanics(t, func() { logDuration(test.ctx, "foo", test.req, test.duration, - collection.NewSet(), test.durationThreshold) + collection.NewSet[string](), test.durationThreshold) }) }) } @@ -150,7 +150,7 @@ func TestLogDurationWithoutContent(t *testing.T) { assert.NotPanics(t, func() { logDuration(test.ctx, "foo", test.req, test.duration, - collection.NewSet(), test.durationThreshold) + collection.NewSet[string](), test.durationThreshold) }) }) } @@ -206,9 +206,10 @@ func Test_shouldLogContent(t *testing.T) { t.Cleanup(func() { ignoreContentMethods = sync.Map{} }) - set := collection.NewSet() - set.AddStr(tt.args.staticNotLoggingContentMethods...) - assert.Equalf(t, tt.want, shouldLogContent(tt.args.method, set), "shouldLogContent(%v, %v)", tt.args.method, tt.args.staticNotLoggingContentMethods) + set := collection.NewSet[string]() + set.Add(tt.args.staticNotLoggingContentMethods...) + assert.Equalf(t, tt.want, shouldLogContent(tt.args.method, set), + "shouldLogContent(%v, %v)", tt.args.method, tt.args.staticNotLoggingContentMethods) }) } }