package security

import (
	"net"
	"net/http"
	"strings"
	"time"

	"github.com/g07cha/defender"
	httperror "github.com/portainer/libhttp/error"
	"github.com/portainer/portainer"
)

// RateLimiter represents an entity that manages request rate limiting.
// It embeds *defender.Defender, so the Defender's methods (such as Inc)
// are available directly on a RateLimiter.
type RateLimiter struct {
	*defender.Defender
}

// NewRateLimiter initializes a new RateLimiter.
//
// maxRequests, duration and banDuration are forwarded unchanged to
// defender.New. A background goroutine running the Defender's CleanupTask is
// started before the limiter is returned.
func NewRateLimiter(maxRequests int, duration time.Duration, banDuration time.Duration) *RateLimiter {
	// NOTE(review): this channel is never sent on, closed, or exposed to the
	// caller, so nothing in this file can signal the cleanup goroutine.
	// Presumably it is CleanupTask's stop signal, which would make the
	// goroutine live for the lifetime of the process — confirm against the
	// defender package before creating short-lived limiters.
	messages := make(chan struct{})
	limiter := defender.New(maxRequests, duration, banDuration)
	go limiter.CleanupTask(messages)
	return &RateLimiter{
		limiter,
	}
}

// LimitAccess wraps next with a check that the request's remote address has
// not gone above the defined limits.
//
// Each request is counted against its remote address with the port stripped.
// When the underlying Defender reports the address as banned, the request is
// answered with 403 Forbidden and next is not invoked.
func (limiter *RateLimiter) LimitAccess(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ip := StripAddrPort(r.RemoteAddr)
		if limiter.Inc(ip) {
			httperror.WriteError(w, http.StatusForbidden, "Access denied", portainer.ErrResourceAccessDenied)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// StripAddrPort removes the port from a network address and returns the host.
//
// "host:port", "ip:port" and "[ipv6]:port" yield the host part without the
// port and without IPv6 brackets. An address that carries no port (for
// example "192.168.1.1", "::1" or "[::1]") is returned as-is apart from
// surrounding brackets, so an IPv6 address is never truncated at one of its
// own colons and the same host maps to the same string with or without a
// port.
func StripAddrPort(addr string) string {
	if host, _, err := net.SplitHostPort(addr); err == nil {
		return host
	}
	// SplitHostPort rejects input without a port (including bare IPv6
	// literals, which contain several colons): keep the address and only
	// drop the optional IPv6 brackets.
	return strings.TrimSuffix(strings.TrimPrefix(addr, "["), "]")
}