portainer/api/http/security/rate_limiter.go

48 lines
1.3 KiB
Go
Raw Normal View History

package security
import (
	"net"
	"net/http"
	"strings"
	"time"

	"github.com/g07cha/defender"
	httperror "github.com/portainer/libhttp/error"
	"github.com/portainer/portainer"
)
// RateLimiter represents an entity that manages request rate limiting.
//
// It embeds *defender.Defender, so the defender's methods are promoted onto
// RateLimiter; the rest of this file relies on that to call Inc directly on
// a *RateLimiter.
type RateLimiter struct {
*defender.Defender
}
// NewRateLimiter builds a RateLimiter backed by a new defender.Defender and
// starts that defender's cleanup task in a background goroutine.
//
// maxRequests, duration and banDuration are forwarded unchanged to
// defender.New; presumably they are the request budget, the window it is
// counted over, and how long an offender stays banned (confirm against the
// defender package documentation).
//
// NOTE(review): the channel handed to CleanupTask is local to this function
// and is never closed or written to, so nothing here can signal the cleanup
// goroutine; it appears to be intended to live as long as the process.
func NewRateLimiter(maxRequests int, duration time.Duration, banDuration time.Duration) *RateLimiter {
	rl := &RateLimiter{
		Defender: defender.New(maxRequests, duration, banDuration),
	}

	quit := make(chan struct{})
	go rl.CleanupTask(quit)

	return rl
}
// LimitAccess returns a handler that wraps next with a rate-limit check keyed
// on the request's remote address (port removed).
//
// Every request is counted through Inc. When Inc returns true, which this
// code treats as "the client is banned", the request is answered with
// 403 Forbidden and next is not invoked; otherwise it is passed on to next.
func (rl *RateLimiter) LimitAccess(next http.Handler) http.Handler {
	guard := func(w http.ResponseWriter, r *http.Request) {
		clientIP := StripAddrPort(r.RemoteAddr)

		// Happy path first: the client is still within its limits.
		if !rl.Inc(clientIP) {
			next.ServeHTTP(w, r)
			return
		}

		httperror.WriteError(w, http.StatusForbidden, "Access denied", portainer.ErrResourceAccessDenied)
	}

	return http.HandlerFunc(guard)
}
// StripAddrPort returns the host part of a "host:port" network address, such
// as the value of http.Request.RemoteAddr.
//
// IPv6 literals are returned without their surrounding square brackets, so
// "[::1]:9000" and "[::1]" both yield "::1". An address that carries no port
// (for example "10.0.0.1" or a bare IPv6 literal such as "2001:db8::1") is
// returned as is instead of being cut at its last colon.
func StripAddrPort(addr string) string {
	// net.SplitHostPort understands bracketed IPv6 literals and reports an
	// error for inputs without a port, which a plain "cut at the last colon"
	// cannot tell apart from the colons inside a bare IPv6 address.
	if host, _, err := net.SplitHostPort(addr); err == nil {
		return host
	}

	// Nothing to strip. Still unwrap a bracketed literal so that the result
	// compares equal to the host extracted from the "[host]:port" form.
	if len(addr) > 1 && strings.HasPrefix(addr, "[") && strings.HasSuffix(addr, "]") {
		return addr[1 : len(addr)-1]
	}
	return addr
}