// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 package middleware import ( "context" "errors" "net" "testing" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/status" pbacl "github.com/hashicorp/consul/proto-public/pbacl" "github.com/hashicorp/go-hclog" "github.com/hashicorp/consul/agent/consul/rate" ) func TestServerRateLimiterMiddleware_Integration(t *testing.T) { limiter := rate.NewMockRequestLimitsHandler(t) logger := hclog.NewNullLogger() server := grpc.NewServer( grpc.InTapHandle(ServerRateLimiterMiddleware(limiter, NewPanicHandler(logger), logger)), ) pbacl.RegisterACLServiceServer(server, mockACLServer{}) lis, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) t.Cleanup(func() { if err := lis.Close(); err != nil { t.Logf("failed to close listener: %v", err) } }) go server.Serve(lis) t.Cleanup(server.Stop) conn, err := grpc.Dial( lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()), ) require.NoError(t, err) t.Cleanup(func() { if err := conn.Close(); err != nil { t.Logf("failed to close client connection: %v", err) } }) client := pbacl.NewACLServiceClient(conn) ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) t.Run("ErrRetryElsewhere = ResourceExhausted", func(t *testing.T) { limiter.On("Allow", mock.Anything). Run(func(args mock.Arguments) { op := args.Get(0).(rate.Operation) require.Equal(t, "/hashicorp.consul.acl.ACLService/Login", op.Name) addr := op.SourceAddr.(*net.TCPAddr) require.True(t, addr.IP.IsLoopback()) }). Return(rate.ErrRetryElsewhere). Once() _, err = client.Login(ctx, &pbacl.LoginRequest{}) require.Error(t, err) require.Equal(t, codes.ResourceExhausted.String(), status.Code(err).String()) }) t.Run("ErrRetryLater = Unavailable", func(t *testing.T) { limiter.On("Allow", mock.Anything). Return(rate.ErrRetryLater). Once() _, err = client.Login(ctx, &pbacl.LoginRequest{}) require.Error(t, err) require.Equal(t, codes.Unavailable.String(), status.Code(err).String()) }) t.Run("unexpected error", func(t *testing.T) { limiter.On("Allow", mock.Anything). Return(errors.New("uh oh")). Once() _, err = client.Login(ctx, &pbacl.LoginRequest{}) require.Error(t, err) require.Equal(t, codes.Internal.String(), status.Code(err).String()) }) t.Run("operation allowed", func(t *testing.T) { limiter.On("Allow", mock.Anything). Return(nil). Once() _, err = client.Login(ctx, &pbacl.LoginRequest{}) require.NoError(t, err) }) t.Run("Allow panics", func(t *testing.T) { limiter.On("Allow", mock.Anything). Panic("uh oh"). Once() _, err = client.Login(ctx, &pbacl.LoginRequest{}) require.Error(t, err) require.Equal(t, codes.Internal.String(), status.Code(err).String()) }) } type mockACLServer struct { pbacl.ACLServiceServer } func (mockACLServer) Login(context.Context, *pbacl.LoginRequest) (*pbacl.LoginResponse, error) { return &pbacl.LoginResponse{}, nil }