|
|
|
@ -5,20 +5,18 @@ package rate
|
|
|
|
|
|
|
|
|
|
import ( |
|
|
|
|
"bytes" |
|
|
|
|
"context" |
|
|
|
|
"github.com/hashicorp/consul/agent/metrics" |
|
|
|
|
"github.com/stretchr/testify/require" |
|
|
|
|
"net" |
|
|
|
|
"net/netip" |
|
|
|
|
"testing" |
|
|
|
|
|
|
|
|
|
"golang.org/x/time/rate" |
|
|
|
|
|
|
|
|
|
"github.com/stretchr/testify/mock" |
|
|
|
|
"github.com/stretchr/testify/require" |
|
|
|
|
|
|
|
|
|
"github.com/hashicorp/go-hclog" |
|
|
|
|
"github.com/stretchr/testify/mock" |
|
|
|
|
|
|
|
|
|
"github.com/hashicorp/consul/agent/consul/multilimiter" |
|
|
|
|
"github.com/hashicorp/consul/agent/metrics" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
//
|
|
|
|
@ -226,10 +224,10 @@ func TestHandler(t *testing.T) {
|
|
|
|
|
for desc, tc := range testCases { |
|
|
|
|
t.Run(desc, func(t *testing.T) { |
|
|
|
|
sink := metrics.TestSetupMetrics(t, "") |
|
|
|
|
limiter := newMockLimiter(t) |
|
|
|
|
limiter := multilimiter.NewMockRateLimiter(t) |
|
|
|
|
limiter.On("UpdateConfig", mock.Anything, mock.Anything).Return() |
|
|
|
|
for _, c := range tc.checks { |
|
|
|
|
limiter.On("Allow", c.limit).Return(c.allow) |
|
|
|
|
limiter.On("Allow", mock.Anything).Return(c.allow) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
leaderStatusProvider := NewMockLeaderStatusProvider(t) |
|
|
|
@ -376,7 +374,7 @@ func TestAllow(t *testing.T) {
|
|
|
|
|
type testCase struct { |
|
|
|
|
description string |
|
|
|
|
cfg *HandlerConfig |
|
|
|
|
expectedAllowCalls int |
|
|
|
|
expectedAllowCalls bool |
|
|
|
|
} |
|
|
|
|
testCases := []testCase{ |
|
|
|
|
{ |
|
|
|
@ -390,7 +388,7 @@ func TestAllow(t *testing.T) {
|
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
expectedAllowCalls: 0, |
|
|
|
|
expectedAllowCalls: false, |
|
|
|
|
}, |
|
|
|
|
{ |
|
|
|
|
description: "RateLimiter gets called when mode is permissive.", |
|
|
|
@ -403,7 +401,7 @@ func TestAllow(t *testing.T) {
|
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
expectedAllowCalls: 1, |
|
|
|
|
expectedAllowCalls: true, |
|
|
|
|
}, |
|
|
|
|
{ |
|
|
|
|
description: "RateLimiter gets called when mode is enforcing.", |
|
|
|
@ -416,14 +414,14 @@ func TestAllow(t *testing.T) {
|
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
expectedAllowCalls: 1, |
|
|
|
|
expectedAllowCalls: true, |
|
|
|
|
}, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
for _, tc := range testCases { |
|
|
|
|
t.Run(tc.description, func(t *testing.T) { |
|
|
|
|
mockRateLimiter := multilimiter.NewMockRateLimiter(t) |
|
|
|
|
if tc.expectedAllowCalls > 0 { |
|
|
|
|
if tc.expectedAllowCalls { |
|
|
|
|
mockRateLimiter.On("Allow", mock.Anything).Return(func(entity multilimiter.LimitedEntity) bool { return true }) |
|
|
|
|
} |
|
|
|
|
mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return() |
|
|
|
@ -435,31 +433,7 @@ func TestAllow(t *testing.T) {
|
|
|
|
|
addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:1234")) |
|
|
|
|
mockRateLimiter.Calls = nil |
|
|
|
|
handler.Allow(Operation{Name: "test", SourceAddr: addr}) |
|
|
|
|
mockRateLimiter.AssertNumberOfCalls(t, "Allow", tc.expectedAllowCalls) |
|
|
|
|
mockRateLimiter.AssertExpectations(t) |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
var _ multilimiter.RateLimiter = (*mockLimiter)(nil) |
|
|
|
|
|
|
|
|
|
func newMockLimiter(t *testing.T) *mockLimiter { |
|
|
|
|
l := &mockLimiter{} |
|
|
|
|
l.Mock.Test(t) |
|
|
|
|
|
|
|
|
|
t.Cleanup(func() { l.AssertExpectations(t) }) |
|
|
|
|
|
|
|
|
|
return l |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
type mockLimiter struct { |
|
|
|
|
mock.Mock |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (m *mockLimiter) Allow(v multilimiter.LimitedEntity) bool { return m.Called(v).Bool(0) } |
|
|
|
|
func (m *mockLimiter) Run(ctx context.Context) { m.Called(ctx) } |
|
|
|
|
func (m *mockLimiter) UpdateConfig(cfg multilimiter.LimiterConfig, prefix []byte) { |
|
|
|
|
m.Called(cfg, prefix) |
|
|
|
|
} |
|
|
|
|
func (m *mockLimiter) DeleteConfig(prefix []byte) { |
|
|
|
|
m.Called(prefix) |
|
|
|
|
} |
|
|
|
|