mirror of https://github.com/portainer/portainer
feat(security): add request rate limiter on authentication endpoint (#1866)
parent
6360e6a20b
commit
55a96767bb
|
@ -37,14 +37,14 @@ const (
|
|||
)
|
||||
|
||||
// NewAuthHandler returns a new instance of AuthHandler.
|
||||
func NewAuthHandler(bouncer *security.RequestBouncer, authDisabled bool) *AuthHandler {
|
||||
func NewAuthHandler(bouncer *security.RequestBouncer, rateLimiter *security.RateLimiter, authDisabled bool) *AuthHandler {
|
||||
h := &AuthHandler{
|
||||
Router: mux.NewRouter(),
|
||||
Logger: log.New(os.Stderr, "", log.LstdFlags),
|
||||
authDisabled: authDisabled,
|
||||
}
|
||||
h.Handle("/auth",
|
||||
bouncer.PublicAccess(http.HandlerFunc(h.handlePostAuth))).Methods(http.MethodPost)
|
||||
rateLimiter.LimitAccess(bouncer.PublicAccess(http.HandlerFunc(h.handlePostAuth)))).Methods(http.MethodPost)
|
||||
|
||||
return h
|
||||
}
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
package security
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/g07cha/defender"
|
||||
"github.com/portainer/portainer"
|
||||
httperror "github.com/portainer/portainer/http/error"
|
||||
)
|
||||
|
||||
// RateLimiter represents an entity that manages request rate limiting
|
||||
type RateLimiter struct {
|
||||
*defender.Defender
|
||||
}
|
||||
|
||||
// NewRateLimiter initializes a new RateLimiter
|
||||
func NewRateLimiter(maxRequests int, duration time.Duration, banDuration time.Duration) *RateLimiter {
|
||||
messages := make(chan struct{})
|
||||
limiter := defender.New(maxRequests, duration, banDuration)
|
||||
go limiter.CleanupTask(messages)
|
||||
return &RateLimiter{
|
||||
limiter,
|
||||
}
|
||||
}
|
||||
|
||||
// LimitAccess wraps current request with check if remote address does not goes above the defined limits
|
||||
func (limiter *RateLimiter) LimitAccess(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ip := StripAddrPort(r.RemoteAddr)
|
||||
if banned := limiter.Inc(ip); banned == true {
|
||||
httperror.WriteErrorResponse(w, portainer.ErrResourceAccessDenied, http.StatusForbidden, nil)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// StripAddrPort removes port from IP address
|
||||
func StripAddrPort(addr string) string {
|
||||
portIndex := strings.LastIndex(addr, ":")
|
||||
if portIndex != -1 {
|
||||
addr = addr[:portIndex]
|
||||
}
|
||||
return addr
|
||||
}
|
|
@ -0,0 +1,69 @@
|
|||
package security
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLimitAccess(t *testing.T) {
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
t.Run("Request below the limit", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
rateLimiter := NewRateLimiter(10, 1*time.Second, 1*time.Hour)
|
||||
handler := rateLimiter.LimitAccess(testHandler)
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if status := rr.Code; status != http.StatusOK {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v",
|
||||
status, http.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Request above the limit", func(t *testing.T) {
|
||||
rateLimiter := NewRateLimiter(1, 1*time.Second, 1*time.Hour)
|
||||
handler := rateLimiter.LimitAccess(testHandler)
|
||||
|
||||
ts := httptest.NewServer(handler)
|
||||
defer ts.Close()
|
||||
http.Get(ts.URL)
|
||||
resp, err := http.Get(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if status := resp.StatusCode; status != http.StatusForbidden {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v",
|
||||
status, http.StatusForbidden)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestStripAddrPort(t *testing.T) {
|
||||
t.Run("IP with port", func(t *testing.T) {
|
||||
result := StripAddrPort("127.0.0.1:1000")
|
||||
if result != "127.0.0.1" {
|
||||
t.Errorf("Expected IP with address to be '127.0.0.1', but it was %s instead", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IP without port", func(t *testing.T) {
|
||||
result := StripAddrPort("127.0.0.1")
|
||||
if result != "127.0.0.1" {
|
||||
t.Errorf("Expected IP with address to be '127.0.0.1', but it was %s instead", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Local IP", func(t *testing.T) {
|
||||
result := StripAddrPort("[::1]:1000")
|
||||
if result != "[::1]" {
|
||||
t.Errorf("Expected IP with address to be '[::1]', but it was %s instead", result)
|
||||
}
|
||||
})
|
||||
}
|
|
@ -1,6 +1,8 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/portainer/portainer"
|
||||
"github.com/portainer/portainer/http/handler"
|
||||
"github.com/portainer/portainer/http/handler/extensions"
|
||||
|
@ -53,9 +55,10 @@ func (server *Server) Start() error {
|
|||
SignatureService: server.SignatureService,
|
||||
}
|
||||
proxyManager := proxy.NewManager(proxyManagerParameters)
|
||||
rateLimiter := security.NewRateLimiter(10, 1*time.Second, 1*time.Hour)
|
||||
|
||||
var fileHandler = handler.NewFileHandler(filepath.Join(server.AssetsPath, "public"))
|
||||
var authHandler = handler.NewAuthHandler(requestBouncer, server.AuthDisabled)
|
||||
var authHandler = handler.NewAuthHandler(requestBouncer, rateLimiter, server.AuthDisabled)
|
||||
authHandler.UserService = server.UserService
|
||||
authHandler.CryptoService = server.CryptoService
|
||||
authHandler.JWTService = server.JWTService
|
||||
|
|
Loading…
Reference in New Issue