Files
go-zero/core/collection/timingwheel_test.go

794 lines
16 KiB
Go

package collection
import (
"sort"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/stringx"
"github.com/zeromicro/go-zero/core/syncx"
"github.com/zeromicro/go-zero/core/timex"
)
// testStep is the interval of every wheel built by these tests; all timer
// delays below are expressed as multiples (or fractions) of it.
const testStep = time.Minute

// waitTime is the upper bound a test waits for an expected expiry callback.
const waitTime = time.Second
// TestNewTimingWheel verifies that the constructor rejects a zero interval.
func TestNewTimingWheel(t *testing.T) {
	noop := func(key, value any) {}
	_, err := NewTimingWheel(0, 10, noop)
	assert.NotNil(t, err)
}
// TestTimingWheel_Drain checks that Drain hands every pending timer to the
// callback, that a second Drain on the emptied wheel delivers nothing, and
// that Drain reports ErrClosed once the wheel has been stopped.
func TestTimingWheel_Drain(t *testing.T) {
	ticker := timex.NewFakeTicker()
	tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
	}, ticker)
	tw.SetTimer("first", 3, testStep*4)
	tw.SetTimer("second", 5, testStep*7)
	tw.SetTimer("third", 7, testStep*7)

	var keys []string
	var vals []int
	var lock sync.Mutex
	var wg sync.WaitGroup
	wg.Add(3)
	// Drain callbacks are guarded as if they may run concurrently.
	tw.Drain(func(key, value any) {
		lock.Lock()
		defer lock.Unlock()
		keys = append(keys, key.(string))
		vals = append(vals, value.(int))
		wg.Done()
	})

	// Bound the wait so a missing callback fails the test instead of hanging
	// it until the global go test timeout.
	drained := make(chan struct{})
	go func() {
		wg.Wait()
		close(drained)
	}()
	select {
	case <-drained:
	case <-time.After(waitTime):
		t.Fatal("timeout waiting for drained timers")
	}

	sort.Strings(keys)
	sort.Ints(vals)
	assert.Equal(t, 3, len(keys))
	assert.EqualValues(t, []string{"first", "second", "third"}, keys)
	assert.EqualValues(t, []int{3, 5, 7}, vals)

	// The wheel is empty now, so this callback must never run. Like the one
	// above it may not run on this goroutine, so the counter is atomic to keep
	// the check race-free if it ever does fire.
	var count int32
	tw.Drain(func(key, value any) {
		atomic.AddInt32(&count, 1)
	})
	time.Sleep(time.Millisecond * 100)
	assert.Equal(t, int32(0), atomic.LoadInt32(&count))

	tw.Stop()
	assert.Equal(t, ErrClosed, tw.Drain(func(key, value any) {}))
}
// TestTimingWheel_SetTimerSoon checks that a timer whose delay is shorter than
// one step fires after a single tick, exactly once, with its key and value.
func TestTimingWheel_SetTimerSoon(t *testing.T) {
	fired := syncx.NewAtomicBool()
	ft := timex.NewFakeTicker()
	onExpire := func(k, v any) {
		assert.True(t, fired.CompareAndSwap(false, true))
		assert.Equal(t, "any", k)
		assert.Equal(t, 3, v.(int))
		ft.Done()
	}
	tw, _ := NewTimingWheelWithTicker(testStep, 10, onExpire, ft)
	defer tw.Stop()

	tw.SetTimer("any", 3, testStep>>1)
	ft.Tick()
	assert.Nil(t, ft.Wait(waitTime))
	assert.True(t, fired.True())
}
// TestTimingWheel_SetTimerTwice checks that setting the same key twice keeps
// only the second value, which fires exactly once.
func TestTimingWheel_SetTimerTwice(t *testing.T) {
	fired := syncx.NewAtomicBool()
	ft := timex.NewFakeTicker()
	onExpire := func(k, v any) {
		assert.True(t, fired.CompareAndSwap(false, true))
		assert.Equal(t, "any", k)
		assert.Equal(t, 5, v.(int))
		ft.Done()
	}
	tw, _ := NewTimingWheelWithTicker(testStep, 10, onExpire, ft)
	defer tw.Stop()

	tw.SetTimer("any", 3, testStep*4)
	tw.SetTimer("any", 5, testStep*7)
	for remaining := 8; remaining > 0; remaining-- {
		ft.Tick()
	}
	assert.Nil(t, ft.Wait(waitTime))
	assert.True(t, fired.True())
}
// TestTimingWheel_SetTimerWrongDelay checks that a negative delay does not panic.
func TestTimingWheel_SetTimerWrongDelay(t *testing.T) {
	ft := timex.NewFakeTicker()
	tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {}, ft)
	defer tw.Stop()

	setNegative := func() {
		tw.SetTimer("any", 3, -testStep)
	}
	assert.NotPanics(t, setNegative)
}
// TestTimingWheel_SetTimerAfterClose checks that SetTimer on a stopped wheel
// returns ErrClosed.
func TestTimingWheel_SetTimerAfterClose(t *testing.T) {
	ft := timex.NewFakeTicker()
	tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {}, ft)
	tw.Stop()

	err := tw.SetTimer("any", 3, testStep)
	assert.Equal(t, ErrClosed, err)
}
// TestTimingWheel_MoveTimer checks that moving a timer later postpones its
// expiry and that bogus moves do not trigger it early. It also checks that
// MoveTimer returns ErrClosed after Stop.
func TestTimingWheel_MoveTimer(t *testing.T) {
	fired := syncx.NewAtomicBool()
	ft := timex.NewFakeTicker()
	onExpire := func(k, v any) {
		assert.True(t, fired.CompareAndSwap(false, true))
		assert.Equal(t, "any", k)
		assert.Equal(t, 3, v.(int))
		ft.Done()
	}
	tw, _ := NewTimingWheelWithTicker(testStep, 3, onExpire, ft)
	advance := func(n int) {
		for i := 0; i < n; i++ {
			ft.Tick()
		}
	}

	tw.SetTimer("any", 3, testStep*4)
	tw.MoveTimer("any", testStep*7)
	// Neither a negative delay nor an unknown key may make the timer fire early.
	tw.MoveTimer("any", -testStep)
	tw.MoveTimer("none", testStep)

	advance(5)
	assert.False(t, fired.True())
	advance(3)
	assert.Nil(t, ft.Wait(waitTime))
	assert.True(t, fired.True())

	tw.Stop()
	assert.Equal(t, ErrClosed, tw.MoveTimer("any", time.Millisecond))
}
// TestTimingWheel_MoveTimerSoon checks that moving a pending timer to a delay
// shorter than one step makes it fire without any further tick.
func TestTimingWheel_MoveTimerSoon(t *testing.T) {
	fired := syncx.NewAtomicBool()
	ft := timex.NewFakeTicker()
	onExpire := func(k, v any) {
		assert.True(t, fired.CompareAndSwap(false, true))
		assert.Equal(t, "any", k)
		assert.Equal(t, 3, v.(int))
		ft.Done()
	}
	tw, _ := NewTimingWheelWithTicker(testStep, 3, onExpire, ft)
	defer tw.Stop()

	tw.SetTimer("any", 3, testStep*4)
	tw.MoveTimer("any", testStep>>1)
	assert.Nil(t, ft.Wait(waitTime))
	assert.True(t, fired.True())
}
// TestTimingWheel_MoveTimerEarlier checks that moving a timer to an earlier
// delay makes it fire within that earlier window.
func TestTimingWheel_MoveTimerEarlier(t *testing.T) {
	fired := syncx.NewAtomicBool()
	ft := timex.NewFakeTicker()
	onExpire := func(k, v any) {
		assert.True(t, fired.CompareAndSwap(false, true))
		assert.Equal(t, "any", k)
		assert.Equal(t, 3, v.(int))
		ft.Done()
	}
	tw, _ := NewTimingWheelWithTicker(testStep, 10, onExpire, ft)
	defer tw.Stop()

	tw.SetTimer("any", 3, testStep*4)
	tw.MoveTimer("any", testStep*2)
	for remaining := 3; remaining > 0; remaining-- {
		ft.Tick()
	}
	assert.Nil(t, ft.Wait(waitTime))
	assert.True(t, fired.True())
}
// TestTimingWheel_RemoveTimer checks that removing a live key, an unknown key
// and a nil key never panics. It also checks that RemoveTimer returns
// ErrClosed after Stop.
func TestTimingWheel_RemoveTimer(t *testing.T) {
	ft := timex.NewFakeTicker()
	tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {}, ft)
	tw.SetTimer("any", 3, testStep)

	assert.NotPanics(t, func() {
		for _, key := range []any{"any", "none", nil} {
			tw.RemoveTimer(key)
		}
	})
	for remaining := 5; remaining > 0; remaining-- {
		ft.Tick()
	}

	tw.Stop()
	assert.Equal(t, ErrClosed, tw.RemoveTimer("any"))
}
// TestTimingWheel_SetTimer checks that a timer set with a delay of setAt steps
// fires on tick number setAt, including delays that wrap around the wheel one
// or more times.
func TestTimingWheel_SetTimer(t *testing.T) {
	// maxTicks bounds the tick loop so a timer that never fires fails the
	// subtest instead of spinning until the global go test timeout.
	const maxTicks = 100

	// The last three rows used to duplicate earlier 5-slot rows verbatim; a
	// 3-slot wheel makes them exercise multiple full circles instead.
	tests := []struct {
		slots int
		setAt time.Duration
	}{
		{slots: 5, setAt: 5},
		{slots: 5, setAt: 7},
		{slots: 5, setAt: 10},
		{slots: 5, setAt: 12},
		{slots: 3, setAt: 7},
		{slots: 3, setAt: 10},
		{slots: 3, setAt: 12},
	}
	for _, test := range tests {
		test := test
		t.Run(stringx.RandId(), func(t *testing.T) {
			t.Parallel()
			var count int32
			ticker := timex.NewFakeTicker()
			tick := func() {
				atomic.AddInt32(&count, 1)
				ticker.Tick()
				time.Sleep(time.Millisecond)
			}
			var actual int32
			done := make(chan lang.PlaceholderType)
			tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value any) {
				assert.Equal(t, 1, key.(int))
				assert.Equal(t, 2, value.(int))
				actual = atomic.LoadInt32(&count)
				close(done)
			}, ticker)
			// Stop here: deferring Stop on a nil wheel would panic.
			if !assert.Nil(t, err) {
				return
			}
			defer tw.Stop()

			tw.SetTimer(1, 2, testStep*test.setAt)
			for i := 0; i < maxTicks; i++ {
				select {
				case <-done:
					// actual was written before close(done), so reading it here is safe.
					assert.Equal(t, int32(test.setAt), actual)
					return
				default:
					tick()
				}
			}
			t.Fatalf("timer set at %d steps did not fire within %d ticks", test.setAt, maxTicks)
		})
	}
}
// TestTimingWheel_SetAndMoveThenStart checks that moving a timer before any
// tick has happened makes it fire on tick moveAt rather than setAt.
func TestTimingWheel_SetAndMoveThenStart(t *testing.T) {
	// maxTicks bounds the tick loop so a timer that never fires fails the
	// subtest instead of spinning until the global go test timeout.
	const maxTicks = 100

	tests := []struct {
		slots  int
		setAt  time.Duration
		moveAt time.Duration
	}{
		{slots: 5, setAt: 3, moveAt: 5},
		{slots: 5, setAt: 3, moveAt: 7},
		{slots: 5, setAt: 3, moveAt: 10},
		{slots: 5, setAt: 3, moveAt: 12},
		{slots: 5, setAt: 5, moveAt: 7},
		{slots: 5, setAt: 5, moveAt: 10},
		{slots: 5, setAt: 5, moveAt: 12},
	}
	for _, test := range tests {
		test := test
		t.Run(stringx.RandId(), func(t *testing.T) {
			t.Parallel()
			var count int32
			ticker := timex.NewFakeTicker()
			tick := func() {
				atomic.AddInt32(&count, 1)
				ticker.Tick()
				time.Sleep(time.Millisecond * 10)
			}
			var actual int32
			done := make(chan lang.PlaceholderType)
			tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value any) {
				actual = atomic.LoadInt32(&count)
				close(done)
			}, ticker)
			// Stop here: deferring Stop on a nil wheel would panic.
			if !assert.Nil(t, err) {
				return
			}
			defer tw.Stop()

			tw.SetTimer(1, 2, testStep*test.setAt)
			tw.MoveTimer(1, testStep*test.moveAt)
			for i := 0; i < maxTicks; i++ {
				select {
				case <-done:
					// actual was written before close(done), so reading it here is safe.
					assert.Equal(t, int32(test.moveAt), actual)
					return
				default:
					tick()
				}
			}
			t.Fatalf("timer moved to %d steps did not fire within %d ticks", test.moveAt, maxTicks)
		})
	}
}
// TestTimingWheel_SetAndMoveTwice checks that after two consecutive moves only
// the last one counts: the timer fires on tick moveAgainAt.
func TestTimingWheel_SetAndMoveTwice(t *testing.T) {
	// maxTicks bounds the tick loop so a timer that never fires fails the
	// subtest instead of spinning until the global go test timeout.
	const maxTicks = 100

	tests := []struct {
		slots       int
		setAt       time.Duration
		moveAt      time.Duration
		moveAgainAt time.Duration
	}{
		{slots: 5, setAt: 3, moveAt: 5, moveAgainAt: 10},
		{slots: 5, setAt: 3, moveAt: 7, moveAgainAt: 12},
		{slots: 5, setAt: 3, moveAt: 10, moveAgainAt: 15},
		{slots: 5, setAt: 3, moveAt: 12, moveAgainAt: 17},
		{slots: 5, setAt: 5, moveAt: 7, moveAgainAt: 12},
		{slots: 5, setAt: 5, moveAt: 10, moveAgainAt: 17},
		{slots: 5, setAt: 5, moveAt: 12, moveAgainAt: 17},
	}
	for _, test := range tests {
		test := test
		t.Run(stringx.RandId(), func(t *testing.T) {
			t.Parallel()
			var count int32
			ticker := timex.NewFakeTicker()
			tick := func() {
				atomic.AddInt32(&count, 1)
				ticker.Tick()
				time.Sleep(time.Millisecond * 10)
			}
			var actual int32
			done := make(chan lang.PlaceholderType)
			tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value any) {
				actual = atomic.LoadInt32(&count)
				close(done)
			}, ticker)
			// Stop here: deferring Stop on a nil wheel would panic.
			if !assert.Nil(t, err) {
				return
			}
			defer tw.Stop()

			tw.SetTimer(1, 2, testStep*test.setAt)
			tw.MoveTimer(1, testStep*test.moveAt)
			tw.MoveTimer(1, testStep*test.moveAgainAt)
			for i := 0; i < maxTicks; i++ {
				select {
				case <-done:
					// actual was written before close(done), so reading it here is safe.
					assert.Equal(t, int32(test.moveAgainAt), actual)
					return
				default:
					tick()
				}
			}
			t.Fatalf("timer moved to %d steps did not fire within %d ticks", test.moveAgainAt, maxTicks)
		})
	}
}
// TestTimingWheel_ElapsedAndSet checks that a timer set after the wheel has
// already advanced `elapsed` ticks fires `setAt` ticks later, i.e. on tick
// elapsed+setAt.
func TestTimingWheel_ElapsedAndSet(t *testing.T) {
	// maxTicks bounds the tick loop so a timer that never fires fails the
	// subtest instead of spinning until the global go test timeout.
	const maxTicks = 100

	tests := []struct {
		slots   int
		elapsed time.Duration
		setAt   time.Duration
	}{
		{slots: 5, elapsed: 3, setAt: 5},
		{slots: 5, elapsed: 3, setAt: 7},
		{slots: 5, elapsed: 3, setAt: 10},
		{slots: 5, elapsed: 3, setAt: 12},
		{slots: 5, elapsed: 5, setAt: 7},
		{slots: 5, elapsed: 5, setAt: 10},
		{slots: 5, elapsed: 5, setAt: 12},
	}
	for _, test := range tests {
		test := test
		t.Run(stringx.RandId(), func(t *testing.T) {
			t.Parallel()
			var count int32
			ticker := timex.NewFakeTicker()
			tick := func() {
				atomic.AddInt32(&count, 1)
				ticker.Tick()
				time.Sleep(time.Millisecond * 10)
			}
			var actual int32
			done := make(chan lang.PlaceholderType)
			tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value any) {
				actual = atomic.LoadInt32(&count)
				close(done)
			}, ticker)
			// Stop here: deferring Stop on a nil wheel would panic.
			if !assert.Nil(t, err) {
				return
			}
			defer tw.Stop()

			for i := 0; i < int(test.elapsed); i++ {
				tick()
			}
			tw.SetTimer(1, 2, testStep*test.setAt)
			for i := 0; i < maxTicks; i++ {
				select {
				case <-done:
					// actual was written before close(done), so reading it here is safe.
					assert.Equal(t, int32(test.elapsed+test.setAt), actual)
					return
				default:
					tick()
				}
			}
			t.Fatalf("timer set at %d steps did not fire within %d ticks", test.setAt, maxTicks)
		})
	}
}
// TestTimingWheel_ElapsedAndSetThenMove checks that a timer set and then moved
// after the wheel has already advanced `elapsed` ticks fires on tick
// elapsed+moveAt.
func TestTimingWheel_ElapsedAndSetThenMove(t *testing.T) {
	// maxTicks bounds the tick loop so a timer that never fires fails the
	// subtest instead of spinning until the global go test timeout.
	const maxTicks = 100

	tests := []struct {
		slots   int
		elapsed time.Duration
		setAt   time.Duration
		moveAt  time.Duration
	}{
		{slots: 5, elapsed: 3, setAt: 5, moveAt: 10},
		{slots: 5, elapsed: 3, setAt: 7, moveAt: 12},
		{slots: 5, elapsed: 3, setAt: 10, moveAt: 15},
		{slots: 5, elapsed: 3, setAt: 12, moveAt: 16},
		{slots: 5, elapsed: 5, setAt: 7, moveAt: 12},
		{slots: 5, elapsed: 5, setAt: 10, moveAt: 15},
		{slots: 5, elapsed: 5, setAt: 12, moveAt: 17},
	}
	for _, test := range tests {
		test := test
		t.Run(stringx.RandId(), func(t *testing.T) {
			t.Parallel()
			var count int32
			ticker := timex.NewFakeTicker()
			tick := func() {
				atomic.AddInt32(&count, 1)
				ticker.Tick()
				time.Sleep(time.Millisecond * 10)
			}
			var actual int32
			done := make(chan lang.PlaceholderType)
			tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value any) {
				actual = atomic.LoadInt32(&count)
				close(done)
			}, ticker)
			// Stop here: deferring Stop on a nil wheel would panic.
			if !assert.Nil(t, err) {
				return
			}
			defer tw.Stop()

			for i := 0; i < int(test.elapsed); i++ {
				tick()
			}
			tw.SetTimer(1, 2, testStep*test.setAt)
			tw.MoveTimer(1, testStep*test.moveAt)
			for i := 0; i < maxTicks; i++ {
				select {
				case <-done:
					// actual was written before close(done), so reading it here is safe.
					assert.Equal(t, int32(test.elapsed+test.moveAt), actual)
					return
				default:
					tick()
				}
			}
			t.Fatalf("timer moved to %d steps did not fire within %d ticks", test.moveAt, maxTicks)
		})
	}
}
// TestMoveAndRemoveTask checks that a timer which was moved and then removed
// before its (moved) expiry never fires, even after many more ticks.
func TestMoveAndRemoveTask(t *testing.T) {
	ticker := timex.NewFakeTicker()
	tick := func(v int) {
		for i := 0; i < v; i++ {
			ticker.Tick()
		}
	}
	// keys is written by the expiry callback (presumably not on the test
	// goroutine) and read by the test, so every access goes through lock.
	var (
		lock sync.Mutex
		keys []int
	)
	tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
		assert.Equal(t, "any", k)
		assert.Equal(t, 3, v.(int))
		lock.Lock()
		keys = append(keys, v.(int))
		lock.Unlock()
		ticker.Done()
	}, ticker)
	defer tw.Stop()

	tw.SetTimer("any", 3, testStep*8)
	tick(6)
	tw.MoveTimer("any", testStep*7)
	tick(3)
	tw.RemoveTimer("any")
	tick(30)
	time.Sleep(time.Millisecond)

	lock.Lock()
	fired := len(keys)
	lock.Unlock()
	assert.Equal(t, 0, fired)
}
// TestTimingWheel_DrainClosureBug tests the closure capture bug in drainAll:
// every drained timer must reach the callback with its own key and value.
// Issue: https://github.com/zeromicro/go-zero/issues/5314
func TestTimingWheel_DrainClosureBug(t *testing.T) {
	const count = 10

	ticker := timex.NewFakeTicker()
	tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {}, ticker)
	defer tw.Stop()

	// Set multiple timers with different values
	for i := 0; i < count; i++ {
		tw.SetTimer(i, i*10, testStep*5)
	}
	// Give time for timers to be set
	time.Sleep(time.Millisecond * 100)

	var mu sync.Mutex
	received := make(map[int]int)
	var wg sync.WaitGroup
	wg.Add(count)
	tw.Drain(func(key, value any) {
		mu.Lock()
		defer mu.Unlock()
		k := key.(int)
		v := value.(int)
		received[k] = v
		wg.Done()
	})

	// Bound the wait so a missing callback fails the test instead of hanging it.
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()
	select {
	case <-done:
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for drained tasks")
	}

	mu.Lock()
	defer mu.Unlock()
	// A capture bug that delivers the same entry repeatedly still calls the
	// callback count times, but collapses the map; check every key arrived.
	assert.Equal(t, count, len(received), "should have drained all tasks")
	// Check if all values match their keys
	for k, v := range received {
		expected := k * 10
		assert.Equal(t, expected, v, "key %d should have value %d, got %d", k, expected, v)
	}
}
// TestTimingWheel_RunTasksClosureBug guards against the closure capture bug in
// runTasks: timers expiring on the same tick must each run with their own
// key and value.
// Issue: https://github.com/zeromicro/go-zero/issues/5314
func TestTimingWheel_RunTasksClosureBug(t *testing.T) {
	const count = 10
	var (
		mu       sync.Mutex
		wg       sync.WaitGroup
		executed = make(map[int]int)
	)
	ticker := timex.NewFakeTicker()
	record := func(k, v any) {
		mu.Lock()
		defer mu.Unlock()
		executed[k.(int)] = v.(int)
		wg.Done()
	}
	tw, _ := NewTimingWheelWithTicker(testStep, 10, record, ticker)
	defer tw.Stop()

	// All timers share the same one-step delay so they expire on one tick.
	wg.Add(count)
	for key := 0; key < count; key++ {
		tw.SetTimer(key, key*10, testStep)
	}
	ticker.Tick()

	// Wait for every callback, but never longer than the deadline.
	allDone := make(chan struct{})
	go func() {
		defer close(allDone)
		wg.Wait()
	}()
	select {
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for tasks to execute")
	case <-allDone:
	}

	assert.Equal(t, count, len(executed), "should have executed all tasks")
	for key, got := range executed {
		want := key * 10
		assert.Equal(t, want, got, "key %d should have value %d, got %d", key, want, got)
	}
}
// TestTimingWheel_RunTasksRaceCondition stresses runTasks with many timers that
// expire on the same tick. The scenario is repeated several times to make a
// loop-variable capture bug (duplicate or mismatched values) likely to surface.
func TestTimingWheel_RunTasksRaceCondition(t *testing.T) {
	const (
		attempts = 10
		count    = 50
	)
	for attempt := 0; attempt < attempts; attempt++ {
		t.Run("", func(t *testing.T) {
			var (
				mu        sync.Mutex
				wg        sync.WaitGroup
				keyValues = make(map[int][]int)
			)
			ticker := timex.NewFakeTicker()
			tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
				// A tiny pause widens the window in which a race could show up.
				time.Sleep(time.Microsecond)
				mu.Lock()
				defer mu.Unlock()
				key, val := k.(int), v.(int)
				keyValues[key] = append(keyValues[key], val)
				wg.Done()
			}, ticker)
			defer tw.Stop()

			// Register many same-tick timers back to back.
			wg.Add(count)
			for key := 0; key < count; key++ {
				tw.SetTimer(key, key*100, testStep)
			}
			ticker.Tick()

			allDone := make(chan struct{})
			go func() {
				defer close(allDone)
				wg.Wait()
			}()
			select {
			case <-time.After(5 * time.Second):
				t.Fatal("timeout waiting for tasks")
			case <-allDone:
			}

			// Each key must run exactly once and carry its own value.
			wrong := 0
			for key, values := range keyValues {
				assert.Equal(t, 1, len(values), "key %d should only execute once, got %v", key, values)
				if len(values) == 0 {
					continue
				}
				if want := key * 100; values[0] != want {
					wrong++
					t.Logf("BUG DETECTED: key %d should have value %d, got %d", key, want, values[0])
				}
			}
			if wrong > 0 {
				t.Errorf("Found %d tasks with wrong values due to closure bug", wrong)
			}
		})
	}
}
// BenchmarkTimingWheel measures one SetTimer/SetTimer/MoveTimer/RemoveTimer
// round per iteration on a wheel created with NewTimingWheel.
//
// The wheel is stopped when the benchmark returns so it does not keep running
// (presumably on a real ticker) after the benchmark. Construction is excluded
// from the timed region.
func BenchmarkTimingWheel(b *testing.B) {
	b.ReportAllocs()
	tw, err := NewTimingWheel(time.Second, 100, func(k, v any) {})
	if err != nil {
		b.Fatal(err)
	}
	defer tw.Stop()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		tw.SetTimer(i, i, time.Second)
		tw.SetTimer(b.N+i, b.N+i, time.Second)
		tw.MoveTimer(i, time.Second*time.Duration(i))
		tw.RemoveTimer(i)
	}
}