2018-05-07 18:01:39 +00:00
|
|
|
package security
|
|
|
|
|
|
|
|
import (
|
2023-05-12 20:55:27 +00:00
|
|
|
"io"
|
2018-05-07 18:01:39 +00:00
|
|
|
"net/http"
|
|
|
|
"net/http/httptest"
|
|
|
|
"testing"
|
|
|
|
"time"
|
|
|
|
)
|
|
|
|
|
|
|
|
func TestLimitAccess(t *testing.T) {
|
|
|
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
})
|
|
|
|
|
|
|
|
t.Run("Request below the limit", func(t *testing.T) {
|
|
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
|
|
rr := httptest.NewRecorder()
|
|
|
|
rateLimiter := NewRateLimiter(10, 1*time.Second, 1*time.Hour)
|
|
|
|
handler := rateLimiter.LimitAccess(testHandler)
|
|
|
|
|
|
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
|
|
|
|
if status := rr.Code; status != http.StatusOK {
|
|
|
|
t.Errorf("handler returned wrong status code: got %v want %v",
|
|
|
|
status, http.StatusOK)
|
|
|
|
}
|
|
|
|
})
|
|
|
|
|
|
|
|
t.Run("Request above the limit", func(t *testing.T) {
|
|
|
|
rateLimiter := NewRateLimiter(1, 1*time.Second, 1*time.Hour)
|
|
|
|
handler := rateLimiter.LimitAccess(testHandler)
|
|
|
|
|
|
|
|
ts := httptest.NewServer(handler)
|
|
|
|
defer ts.Close()
|
|
|
|
http.Get(ts.URL)
|
|
|
|
resp, err := http.Get(ts.URL)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
2023-05-12 20:55:27 +00:00
|
|
|
io.Copy(io.Discard, resp.Body)
|
|
|
|
resp.Body.Close()
|
2018-05-07 18:01:39 +00:00
|
|
|
|
|
|
|
if status := resp.StatusCode; status != http.StatusForbidden {
|
|
|
|
t.Errorf("handler returned wrong status code: got %v want %v",
|
|
|
|
status, http.StatusForbidden)
|
|
|
|
}
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestStripAddrPort(t *testing.T) {
|
|
|
|
t.Run("IP with port", func(t *testing.T) {
|
|
|
|
result := StripAddrPort("127.0.0.1:1000")
|
|
|
|
if result != "127.0.0.1" {
|
|
|
|
t.Errorf("Expected IP with address to be '127.0.0.1', but it was %s instead", result)
|
|
|
|
}
|
|
|
|
})
|
|
|
|
|
|
|
|
t.Run("IP without port", func(t *testing.T) {
|
|
|
|
result := StripAddrPort("127.0.0.1")
|
|
|
|
if result != "127.0.0.1" {
|
|
|
|
t.Errorf("Expected IP with address to be '127.0.0.1', but it was %s instead", result)
|
|
|
|
}
|
|
|
|
})
|
|
|
|
|
|
|
|
t.Run("Local IP", func(t *testing.T) {
|
|
|
|
result := StripAddrPort("[::1]:1000")
|
|
|
|
if result != "[::1]" {
|
|
|
|
t.Errorf("Expected IP with address to be '[::1]', but it was %s instead", result)
|
|
|
|
}
|
|
|
|
})
|
|
|
|
}
|