mirror of https://github.com/portainer/portainer
feat(security): add request rate limiter on authentication endpoint (#1866)
parent
6360e6a20b
commit
55a96767bb
|
@ -37,14 +37,14 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewAuthHandler returns a new instance of AuthHandler.
|
// NewAuthHandler returns a new instance of AuthHandler.
|
||||||
func NewAuthHandler(bouncer *security.RequestBouncer, authDisabled bool) *AuthHandler {
|
func NewAuthHandler(bouncer *security.RequestBouncer, rateLimiter *security.RateLimiter, authDisabled bool) *AuthHandler {
|
||||||
h := &AuthHandler{
|
h := &AuthHandler{
|
||||||
Router: mux.NewRouter(),
|
Router: mux.NewRouter(),
|
||||||
Logger: log.New(os.Stderr, "", log.LstdFlags),
|
Logger: log.New(os.Stderr, "", log.LstdFlags),
|
||||||
authDisabled: authDisabled,
|
authDisabled: authDisabled,
|
||||||
}
|
}
|
||||||
h.Handle("/auth",
|
h.Handle("/auth",
|
||||||
bouncer.PublicAccess(http.HandlerFunc(h.handlePostAuth))).Methods(http.MethodPost)
|
rateLimiter.LimitAccess(bouncer.PublicAccess(http.HandlerFunc(h.handlePostAuth)))).Methods(http.MethodPost)
|
||||||
|
|
||||||
return h
|
return h
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,47 @@
|
||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/g07cha/defender"
|
||||||
|
"github.com/portainer/portainer"
|
||||||
|
httperror "github.com/portainer/portainer/http/error"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RateLimiter represents an entity that manages request rate limiting
|
||||||
|
type RateLimiter struct {
|
||||||
|
*defender.Defender
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRateLimiter initializes a new RateLimiter
|
||||||
|
func NewRateLimiter(maxRequests int, duration time.Duration, banDuration time.Duration) *RateLimiter {
|
||||||
|
messages := make(chan struct{})
|
||||||
|
limiter := defender.New(maxRequests, duration, banDuration)
|
||||||
|
go limiter.CleanupTask(messages)
|
||||||
|
return &RateLimiter{
|
||||||
|
limiter,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LimitAccess wraps current request with check if remote address does not goes above the defined limits
|
||||||
|
func (limiter *RateLimiter) LimitAccess(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ip := StripAddrPort(r.RemoteAddr)
|
||||||
|
if banned := limiter.Inc(ip); banned == true {
|
||||||
|
httperror.WriteErrorResponse(w, portainer.ErrResourceAccessDenied, http.StatusForbidden, nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// StripAddrPort removes port from IP address
|
||||||
|
func StripAddrPort(addr string) string {
|
||||||
|
portIndex := strings.LastIndex(addr, ":")
|
||||||
|
if portIndex != -1 {
|
||||||
|
addr = addr[:portIndex]
|
||||||
|
}
|
||||||
|
return addr
|
||||||
|
}
|
|
@ -0,0 +1,69 @@
|
||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLimitAccess(t *testing.T) {
|
||||||
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Request below the limit", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
rateLimiter := NewRateLimiter(10, 1*time.Second, 1*time.Hour)
|
||||||
|
handler := rateLimiter.LimitAccess(testHandler)
|
||||||
|
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if status := rr.Code; status != http.StatusOK {
|
||||||
|
t.Errorf("handler returned wrong status code: got %v want %v",
|
||||||
|
status, http.StatusOK)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Request above the limit", func(t *testing.T) {
|
||||||
|
rateLimiter := NewRateLimiter(1, 1*time.Second, 1*time.Hour)
|
||||||
|
handler := rateLimiter.LimitAccess(testHandler)
|
||||||
|
|
||||||
|
ts := httptest.NewServer(handler)
|
||||||
|
defer ts.Close()
|
||||||
|
http.Get(ts.URL)
|
||||||
|
resp, err := http.Get(ts.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if status := resp.StatusCode; status != http.StatusForbidden {
|
||||||
|
t.Errorf("handler returned wrong status code: got %v want %v",
|
||||||
|
status, http.StatusForbidden)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStripAddrPort(t *testing.T) {
|
||||||
|
t.Run("IP with port", func(t *testing.T) {
|
||||||
|
result := StripAddrPort("127.0.0.1:1000")
|
||||||
|
if result != "127.0.0.1" {
|
||||||
|
t.Errorf("Expected IP with address to be '127.0.0.1', but it was %s instead", result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("IP without port", func(t *testing.T) {
|
||||||
|
result := StripAddrPort("127.0.0.1")
|
||||||
|
if result != "127.0.0.1" {
|
||||||
|
t.Errorf("Expected IP with address to be '127.0.0.1', but it was %s instead", result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Local IP", func(t *testing.T) {
|
||||||
|
result := StripAddrPort("[::1]:1000")
|
||||||
|
if result != "[::1]" {
|
||||||
|
t.Errorf("Expected IP with address to be '[::1]', but it was %s instead", result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -1,6 +1,8 @@
|
||||||
package http
|
package http
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/portainer/portainer"
|
"github.com/portainer/portainer"
|
||||||
"github.com/portainer/portainer/http/handler"
|
"github.com/portainer/portainer/http/handler"
|
||||||
"github.com/portainer/portainer/http/handler/extensions"
|
"github.com/portainer/portainer/http/handler/extensions"
|
||||||
|
@ -53,9 +55,10 @@ func (server *Server) Start() error {
|
||||||
SignatureService: server.SignatureService,
|
SignatureService: server.SignatureService,
|
||||||
}
|
}
|
||||||
proxyManager := proxy.NewManager(proxyManagerParameters)
|
proxyManager := proxy.NewManager(proxyManagerParameters)
|
||||||
|
rateLimiter := security.NewRateLimiter(10, 1*time.Second, 1*time.Hour)
|
||||||
|
|
||||||
var fileHandler = handler.NewFileHandler(filepath.Join(server.AssetsPath, "public"))
|
var fileHandler = handler.NewFileHandler(filepath.Join(server.AssetsPath, "public"))
|
||||||
var authHandler = handler.NewAuthHandler(requestBouncer, server.AuthDisabled)
|
var authHandler = handler.NewAuthHandler(requestBouncer, rateLimiter, server.AuthDisabled)
|
||||||
authHandler.UserService = server.UserService
|
authHandler.UserService = server.UserService
|
||||||
authHandler.CryptoService = server.CryptoService
|
authHandler.CryptoService = server.CryptoService
|
||||||
authHandler.JWTService = server.JWTService
|
authHandler.JWTService = server.JWTService
|
||||||
|
|
Loading…
Reference in New Issue