package limit

import (
	"context"
	"testing"
	"time"

	"git.wishpal.cn/wishpal_ironfan/xframe/component/logger"
	"git.wishpal.cn/wishpal_ironfan/xframe/component/storage/redis"
	"git.wishpal.cn/wishpal_ironfan/xframe/component/storage/redis/redistest"
	"github.com/alicebob/miniredis/v2"
	"github.com/stretchr/testify/assert"
)

// init disables the xframe logger for every test in this package — presumably
// so the limiter's logging during the simulated redis outage does not clutter
// test output.
func init() {
	logger.Disable()
}

// TestTokenLimit_WithCtx checks that AllowCtx grants a token while ctx is live,
// then refuses every request once ctx has been cancelled without the limiter
// starting its monitor.
func TestTokenLimit_WithCtx(t *testing.T) {
	s, err := miniredis.Run()
	if !assert.NoError(t, err) {
		return
	}
	// Register shutdown immediately so the server is released even if a later
	// setup step fails.
	defer s.Close()

	rds, err := redis.NewRedis(redis.RedisConf{
		Host: s.Addr(),
		Type: redis.NodeType,
	})
	if !assert.NoError(t, err) {
		return
	}

	const (
		total = 100
		rate  = 5
		burst = 10
	)
	l := NewTokenLimiter(rate, burst, rds, "tokenlimit")

	ctx, cancel := context.WithCancel(context.Background())
	assert.True(t, l.AllowCtx(ctx))

	cancel()
	for i := 0; i < total; i++ {
		assert.False(t, l.AllowCtx(ctx))
		// The test expects a cancelled context not to start the monitor.
		assert.False(t, l.monitorStarted)
	}
}

// TestTokenLimit_Rescue takes redis offline before any request is made, brings
// it back halfway through, and expects at least burst+rate requests to be
// admitted across the outage (presumably via the limiter's rescue fallback).
func TestTokenLimit_Rescue(t *testing.T) {
	s, err := miniredis.Run()
	if !assert.NoError(t, err) {
		return
	}
	// The server is closed and restarted below; this deferred Close shuts down
	// the restarted instance so it does not outlive the test.
	defer s.Close()

	rds, err := redis.NewRedis(redis.RedisConf{
		Host: s.Addr(),
		Type: redis.NodeType,
	})
	if !assert.NoError(t, err) {
		return
	}

	const (
		total = 100
		rate  = 5
		burst = 10
	)
	l := NewTokenLimiter(rate, burst, rds, "tokenlimit")

	// Simulate a redis outage.
	s.Close()

	var allowed int
	for i := 0; i < total; i++ {
		time.Sleep(time.Second / time.Duration(total))
		if i == total/2 {
			// Bring redis back halfway through the run.
			assert.NoError(t, s.Restart())
		}
		if l.Allow() {
			allowed++
		}

		// make sure start monitor more than once doesn't matter
		l.startMonitor()
	}

	assert.GreaterOrEqual(t, allowed, burst+rate)
}

// TestTokenLimit_Take spreads total requests over roughly one second and
// expects at least the burst plus one second's worth of refill to be admitted.
func TestTokenLimit_Take(t *testing.T) {
	store, err := redistest.CreateRedis(t)
	if !assert.NoError(t, err) {
		return
	}

	const (
		total = 100
		rate  = 5
		burst = 10
	)
	l := NewTokenLimiter(rate, burst, store, "tokenlimit")

	var allowed int
	for i := 0; i < total; i++ {
		time.Sleep(time.Second / time.Duration(total))
		if l.Allow() {
			allowed++
		}
	}

	assert.GreaterOrEqual(t, allowed, burst+rate)
}

// TestTokenLimit_TakeBurst fires total requests back to back and expects at
// least the full burst to be admitted.
func TestTokenLimit_TakeBurst(t *testing.T) {
	store, err := redistest.CreateRedis(t)
	if !assert.NoError(t, err) {
		return
	}

	const (
		total = 100
		rate  = 5
		burst = 10
	)
	l := NewTokenLimiter(rate, burst, store, "tokenlimit")

	var allowed int
	for i := 0; i < total; i++ {
		if l.Allow() {
			allowed++
		}
	}

	assert.GreaterOrEqual(t, allowed, burst)
}