mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-10 16:30:01 +08:00
fix(timingwheel): add missing Wait() call and improve code clarity (#5315)
Signed-off-by: kevin <wanjunfeng@gmail.com>
This commit is contained in:
@@ -164,6 +164,7 @@ func (tw *TimingWheel) Stop() {
|
|||||||
|
|
||||||
func (tw *TimingWheel) drainAll(fn func(key, value any)) {
|
func (tw *TimingWheel) drainAll(fn func(key, value any)) {
|
||||||
runner := threading.NewTaskRunner(drainWorkers)
|
runner := threading.NewTaskRunner(drainWorkers)
|
||||||
|
|
||||||
for _, slot := range tw.slots {
|
for _, slot := range tw.slots {
|
||||||
for e := slot.Front(); e != nil; {
|
for e := slot.Front(); e != nil; {
|
||||||
task := e.Value.(*timingEntry)
|
task := e.Value.(*timingEntry)
|
||||||
@@ -177,6 +178,8 @@ func (tw *TimingWheel) drainAll(fn func(key, value any)) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
runner.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tw *TimingWheel) getPositionAndCircle(d time.Duration) (pos, circle int) {
|
func (tw *TimingWheel) getPositionAndCircle(d time.Duration) (pos, circle int) {
|
||||||
|
|||||||
@@ -629,6 +629,157 @@ func TestMoveAndRemoveTask(t *testing.T) {
|
|||||||
assert.Equal(t, 0, len(keys))
|
assert.Equal(t, 0, len(keys))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestTimingWheel_DrainClosureBug tests the closure capture bug in drainAll
|
||||||
|
// Issue: https://github.com/zeromicro/go-zero/issues/5314
|
||||||
|
func TestTimingWheel_DrainClosureBug(t *testing.T) {
|
||||||
|
ticker := timex.NewFakeTicker()
|
||||||
|
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {}, ticker)
|
||||||
|
defer tw.Stop()
|
||||||
|
|
||||||
|
// Set multiple timers with different values
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
tw.SetTimer(i, i*10, testStep*5)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Give time for timers to be set
|
||||||
|
time.Sleep(time.Millisecond * 100)
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
received := make(map[int]int)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(10)
|
||||||
|
|
||||||
|
tw.Drain(func(key, value any) {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
k := key.(int)
|
||||||
|
v := value.(int)
|
||||||
|
received[k] = v
|
||||||
|
wg.Done()
|
||||||
|
})
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Check if all values match their keys
|
||||||
|
for k, v := range received {
|
||||||
|
expected := k * 10
|
||||||
|
assert.Equal(t, expected, v, "key %d should have value %d, got %d", k, expected, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTimingWheel_RunTasksClosureBug tests the closure capture bug in runTasks
|
||||||
|
// Issue: https://github.com/zeromicro/go-zero/issues/5314
|
||||||
|
func TestTimingWheel_RunTasksClosureBug(t *testing.T) {
|
||||||
|
ticker := timex.NewFakeTicker()
|
||||||
|
var mu sync.Mutex
|
||||||
|
executed := make(map[int]int)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
key := k.(int)
|
||||||
|
val := v.(int)
|
||||||
|
executed[key] = val
|
||||||
|
wg.Done()
|
||||||
|
}, ticker)
|
||||||
|
defer tw.Stop()
|
||||||
|
|
||||||
|
// Set multiple timers that should fire in the same tick
|
||||||
|
count := 10
|
||||||
|
wg.Add(count)
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
tw.SetTimer(i, i*10, testStep)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Advance ticker to trigger tasks
|
||||||
|
ticker.Tick()
|
||||||
|
|
||||||
|
// Wait for execution with timeout
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
wg.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
// Success
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for tasks to execute")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all tasks executed with correct values
|
||||||
|
assert.Equal(t, count, len(executed), "should have executed all tasks")
|
||||||
|
for k, v := range executed {
|
||||||
|
expected := k * 10
|
||||||
|
assert.Equal(t, expected, v, "key %d should have value %d, got %d", k, expected, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTimingWheel_RunTasksRaceCondition tests for race conditions in runTasks
|
||||||
|
// This test specifically targets the loop variable capture bug
|
||||||
|
func TestTimingWheel_RunTasksRaceCondition(t *testing.T) {
|
||||||
|
// Run multiple times to increase likelihood of catching the bug
|
||||||
|
for attempt := 0; attempt < 10; attempt++ {
|
||||||
|
t.Run("", func(t *testing.T) {
|
||||||
|
ticker := timex.NewFakeTicker()
|
||||||
|
var mu sync.Mutex
|
||||||
|
keyValues := make(map[int][]int)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
|
||||||
|
// Add small delay to increase chance of race
|
||||||
|
time.Sleep(time.Microsecond)
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
key := k.(int)
|
||||||
|
val := v.(int)
|
||||||
|
keyValues[key] = append(keyValues[key], val)
|
||||||
|
wg.Done()
|
||||||
|
}, ticker)
|
||||||
|
defer tw.Stop()
|
||||||
|
|
||||||
|
// Set many timers rapidly to increase chance of race
|
||||||
|
count := 50
|
||||||
|
wg.Add(count)
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
tw.SetTimer(i, i*100, testStep)
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker.Tick()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
wg.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for tasks")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for duplicates or wrong values
|
||||||
|
wrongCount := 0
|
||||||
|
for key, values := range keyValues {
|
||||||
|
assert.Equal(t, 1, len(values), "key %d should only execute once, got %v", key, values)
|
||||||
|
if len(values) > 0 {
|
||||||
|
expected := key * 100
|
||||||
|
if values[0] != expected {
|
||||||
|
wrongCount++
|
||||||
|
t.Logf("BUG DETECTED: key %d should have value %d, got %d", key, expected, values[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if wrongCount > 0 {
|
||||||
|
t.Errorf("Found %d tasks with wrong values due to closure bug", wrongCount)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkTimingWheel(b *testing.B) {
|
func BenchmarkTimingWheel(b *testing.B) {
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ func TestRoutineGroupRun(t *testing.T) {
|
|||||||
|
|
||||||
func TestRoutingGroupRunSafe(t *testing.T) {
|
func TestRoutingGroupRunSafe(t *testing.T) {
|
||||||
logtest.Discard(t)
|
logtest.Discard(t)
|
||||||
|
|
||||||
var count int32
|
var count int32
|
||||||
group := NewRoutineGroup()
|
group := NewRoutineGroup()
|
||||||
var once sync.Once
|
var once sync.Once
|
||||||
|
|||||||
Reference in New Issue
Block a user