diff --git a/core/collection/timingwheel.go b/core/collection/timingwheel.go index 5bd8dd281..3de88739a 100644 --- a/core/collection/timingwheel.go +++ b/core/collection/timingwheel.go @@ -164,6 +164,7 @@ func (tw *TimingWheel) Stop() { func (tw *TimingWheel) drainAll(fn func(key, value any)) { runner := threading.NewTaskRunner(drainWorkers) + for _, slot := range tw.slots { for e := slot.Front(); e != nil; { task := e.Value.(*timingEntry) @@ -177,6 +178,8 @@ func (tw *TimingWheel) drainAll(fn func(key, value any)) { } } } + + runner.Wait() } func (tw *TimingWheel) getPositionAndCircle(d time.Duration) (pos, circle int) { diff --git a/core/collection/timingwheel_test.go b/core/collection/timingwheel_test.go index c997725ff..4d992eef3 100644 --- a/core/collection/timingwheel_test.go +++ b/core/collection/timingwheel_test.go @@ -629,6 +629,157 @@ func TestMoveAndRemoveTask(t *testing.T) { assert.Equal(t, 0, len(keys)) } +// TestTimingWheel_DrainClosureBug tests the closure capture bug in drainAll +// Issue: https://github.com/zeromicro/go-zero/issues/5314 +func TestTimingWheel_DrainClosureBug(t *testing.T) { + ticker := timex.NewFakeTicker() + tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {}, ticker) + defer tw.Stop() + + // Set multiple timers with different values + for i := 0; i < 10; i++ { + tw.SetTimer(i, i*10, testStep*5) + } + + // Give time for timers to be set + time.Sleep(time.Millisecond * 100) + + var mu sync.Mutex + received := make(map[int]int) + var wg sync.WaitGroup + wg.Add(10) + + tw.Drain(func(key, value any) { + mu.Lock() + defer mu.Unlock() + k := key.(int) + v := value.(int) + received[k] = v + wg.Done() + }) + + wg.Wait() + + // Check if all values match their keys + for k, v := range received { + expected := k * 10 + assert.Equal(t, expected, v, "key %d should have value %d, got %d", k, expected, v) + } +} + +// TestTimingWheel_RunTasksClosureBug tests the closure capture bug in runTasks +// Issue: https://github.com/zeromicro/go-zero/issues/5314 +func TestTimingWheel_RunTasksClosureBug(t *testing.T) { + ticker := timex.NewFakeTicker() + var mu sync.Mutex + executed := make(map[int]int) + var wg sync.WaitGroup + + tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) { + mu.Lock() + defer mu.Unlock() + key := k.(int) + val := v.(int) + executed[key] = val + wg.Done() + }, ticker) + defer tw.Stop() + + // Set multiple timers that should fire in the same tick + count := 10 + wg.Add(count) + for i := 0; i < count; i++ { + tw.SetTimer(i, i*10, testStep) + } + + // Advance ticker to trigger tasks + ticker.Tick() + + // Wait for execution with timeout + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // Success + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for tasks to execute") + } + + // Verify all tasks executed with correct values + assert.Equal(t, count, len(executed), "should have executed all tasks") + for k, v := range executed { + expected := k * 10 + assert.Equal(t, expected, v, "key %d should have value %d, got %d", k, expected, v) + } +} + +// TestTimingWheel_RunTasksRaceCondition tests for race conditions in runTasks +// This test specifically targets the loop variable capture bug +func TestTimingWheel_RunTasksRaceCondition(t *testing.T) { + // Run multiple times to increase likelihood of catching the bug + for attempt := 0; attempt < 10; attempt++ { + t.Run("", func(t *testing.T) { + ticker := timex.NewFakeTicker() + var mu sync.Mutex + keyValues := make(map[int][]int) + var wg sync.WaitGroup + + tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) { + // Add small delay to increase chance of race + time.Sleep(time.Microsecond) + mu.Lock() + defer mu.Unlock() + key := k.(int) + val := v.(int) + keyValues[key] = append(keyValues[key], val) + wg.Done() + }, ticker) + defer tw.Stop() + + // Set many timers rapidly to increase chance of race + count := 50 + wg.Add(count) + for i := 0; i < count; i++ { + tw.SetTimer(i, i*100, testStep) + } + + ticker.Tick() + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for tasks") + } + + // Check for duplicates or wrong values + wrongCount := 0 + for key, values := range keyValues { + assert.Equal(t, 1, len(values), "key %d should only execute once, got %v", key, values) + if len(values) > 0 { + expected := key * 100 + if values[0] != expected { + wrongCount++ + t.Logf("BUG DETECTED: key %d should have value %d, got %d", key, expected, values[0]) + } + } + } + if wrongCount > 0 { + t.Errorf("Found %d tasks with wrong values due to closure bug", wrongCount) + } + }) + } +} + func BenchmarkTimingWheel(b *testing.B) { b.ReportAllocs() diff --git a/core/threading/routinegroup_test.go b/core/threading/routinegroup_test.go index fce7ffff3..a5dda3924 100644 --- a/core/threading/routinegroup_test.go +++ b/core/threading/routinegroup_test.go @@ -25,6 +25,7 @@ func TestRoutineGroupRun(t *testing.T) { func TestRoutingGroupRunSafe(t *testing.T) { logtest.Discard(t) + var count int32 group := NewRoutineGroup() var once sync.Once