mirror of https://github.com/portainer/portainer
67 lines
1.4 KiB
Go
67 lines
1.4 KiB
Go
package csrf
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"fmt"
|
|
"net/http"
|
|
|
|
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
|
|
|
gorillacsrf "github.com/gorilla/csrf"
|
|
"github.com/portainer/portainer/api/http/security"
|
|
"github.com/urfave/negroni"
|
|
)
|
|
|
|
func WithProtect(handler http.Handler) (http.Handler, error) {
|
|
handler = withSendCSRFToken(handler)
|
|
|
|
token := make([]byte, 32)
|
|
_, err := rand.Read(token)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate CSRF token: %w", err)
|
|
}
|
|
|
|
handler = gorillacsrf.Protect(
|
|
[]byte(token),
|
|
gorillacsrf.Path("/"),
|
|
gorillacsrf.Secure(false),
|
|
)(handler)
|
|
|
|
return withSkipCSRF(handler), nil
|
|
}
|
|
|
|
func withSendCSRFToken(handler http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
sw := negroni.NewResponseWriter(w)
|
|
|
|
sw.Before(func(sw negroni.ResponseWriter) {
|
|
statusCode := sw.Status()
|
|
if statusCode >= 200 && statusCode < 300 {
|
|
csrfToken := gorillacsrf.Token(r)
|
|
sw.Header().Set("X-CSRF-Token", csrfToken)
|
|
}
|
|
})
|
|
|
|
handler.ServeHTTP(sw, r)
|
|
|
|
})
|
|
}
|
|
|
|
func withSkipCSRF(handler http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
skip, err := security.ShouldSkipCSRFCheck(r)
|
|
if err != nil {
|
|
httperror.WriteError(w, http.StatusForbidden, err.Error(), err)
|
|
return
|
|
}
|
|
|
|
if skip {
|
|
r = gorillacsrf.UnsafeSkipCheck(r)
|
|
}
|
|
|
|
handler.ServeHTTP(w, r)
|
|
})
|
|
}
|