feat(security): add request rate limiter on authentication endpoint (#1866)

pull/1872/head
Konstantin Azizov 2018-05-07 20:01:39 +02:00 committed by Anthony Lapenna
parent 6360e6a20b
commit 55a96767bb
4 changed files with 122 additions and 3 deletions

View File

@ -37,14 +37,14 @@ const (
)
// NewAuthHandler returns a new instance of AuthHandler.
func NewAuthHandler(bouncer *security.RequestBouncer, authDisabled bool) *AuthHandler {
func NewAuthHandler(bouncer *security.RequestBouncer, rateLimiter *security.RateLimiter, authDisabled bool) *AuthHandler {
h := &AuthHandler{
Router: mux.NewRouter(),
Logger: log.New(os.Stderr, "", log.LstdFlags),
authDisabled: authDisabled,
}
h.Handle("/auth",
bouncer.PublicAccess(http.HandlerFunc(h.handlePostAuth))).Methods(http.MethodPost)
rateLimiter.LimitAccess(bouncer.PublicAccess(http.HandlerFunc(h.handlePostAuth)))).Methods(http.MethodPost)
return h
}

View File

@ -0,0 +1,47 @@
package security
import (
"net/http"
"strings"
"time"
"github.com/g07cha/defender"
"github.com/portainer/portainer"
httperror "github.com/portainer/portainer/http/error"
)
// RateLimiter represents an entity that manages request rate limiting
type RateLimiter struct {
*defender.Defender
}
// NewRateLimiter initializes a new RateLimiter
func NewRateLimiter(maxRequests int, duration time.Duration, banDuration time.Duration) *RateLimiter {
messages := make(chan struct{})
limiter := defender.New(maxRequests, duration, banDuration)
go limiter.CleanupTask(messages)
return &RateLimiter{
limiter,
}
}
// LimitAccess wraps current request with check if remote address does not goes above the defined limits
func (limiter *RateLimiter) LimitAccess(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ip := StripAddrPort(r.RemoteAddr)
if banned := limiter.Inc(ip); banned == true {
httperror.WriteErrorResponse(w, portainer.ErrResourceAccessDenied, http.StatusForbidden, nil)
return
}
next.ServeHTTP(w, r)
})
}
// StripAddrPort removes port from IP address
func StripAddrPort(addr string) string {
portIndex := strings.LastIndex(addr, ":")
if portIndex != -1 {
addr = addr[:portIndex]
}
return addr
}

View File

@ -0,0 +1,69 @@
package security
import (
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestLimitAccess(t *testing.T) {
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
t.Run("Request below the limit", func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
rateLimiter := NewRateLimiter(10, 1*time.Second, 1*time.Hour)
handler := rateLimiter.LimitAccess(testHandler)
handler.ServeHTTP(rr, req)
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
})
t.Run("Request above the limit", func(t *testing.T) {
rateLimiter := NewRateLimiter(1, 1*time.Second, 1*time.Hour)
handler := rateLimiter.LimitAccess(testHandler)
ts := httptest.NewServer(handler)
defer ts.Close()
http.Get(ts.URL)
resp, err := http.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
if status := resp.StatusCode; status != http.StatusForbidden {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusForbidden)
}
})
}
func TestStripAddrPort(t *testing.T) {
t.Run("IP with port", func(t *testing.T) {
result := StripAddrPort("127.0.0.1:1000")
if result != "127.0.0.1" {
t.Errorf("Expected IP with address to be '127.0.0.1', but it was %s instead", result)
}
})
t.Run("IP without port", func(t *testing.T) {
result := StripAddrPort("127.0.0.1")
if result != "127.0.0.1" {
t.Errorf("Expected IP with address to be '127.0.0.1', but it was %s instead", result)
}
})
t.Run("Local IP", func(t *testing.T) {
result := StripAddrPort("[::1]:1000")
if result != "[::1]" {
t.Errorf("Expected IP with address to be '[::1]', but it was %s instead", result)
}
})
}

View File

@ -1,6 +1,8 @@
package http
import (
"time"
"github.com/portainer/portainer"
"github.com/portainer/portainer/http/handler"
"github.com/portainer/portainer/http/handler/extensions"
@ -53,9 +55,10 @@ func (server *Server) Start() error {
SignatureService: server.SignatureService,
}
proxyManager := proxy.NewManager(proxyManagerParameters)
rateLimiter := security.NewRateLimiter(10, 1*time.Second, 1*time.Hour)
var fileHandler = handler.NewFileHandler(filepath.Join(server.AssetsPath, "public"))
var authHandler = handler.NewAuthHandler(requestBouncer, server.AuthDisabled)
var authHandler = handler.NewAuthHandler(requestBouncer, rateLimiter, server.AuthDisabled)
authHandler.UserService = server.UserService
authHandler.CryptoService = server.CryptoService
authHandler.JWTService = server.JWTService