package csrf

import (
	"crypto/rand"
	"fmt"
	"net/http"

	httperror "github.com/portainer/portainer/pkg/libhttp/error"

	gorillacsrf "github.com/gorilla/csrf"
	"github.com/portainer/portainer/api/http/security"
	"github.com/urfave/negroni"
)

func WithProtect(handler http.Handler) (http.Handler, error) {
	handler = withSendCSRFToken(handler)

	token := make([]byte, 32)
	_, err := rand.Read(token)
	if err != nil {
		return nil, fmt.Errorf("failed to generate CSRF token: %w", err)
	}

	handler = gorillacsrf.Protect([]byte(token), gorillacsrf.Path("/"))(handler)

	return withSkipCSRF(handler), nil
}

// withSendCSRFToken returns a handler that serves the request through a
// negroni response writer on which a Before callback is registered.
//
// When the callback runs and the recorded status code is in the 2xx range,
// it sets the X-CSRF-Token response header to the value returned by
// gorillacsrf.Token for the request. Other status codes leave the headers
// untouched.
func withSendCSRFToken(handler http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		wrapped := negroni.NewResponseWriter(w)

		wrapped.Before(func(rw negroni.ResponseWriter) {
			// Only successful responses carry the token header.
			if code := rw.Status(); code < 200 || code >= 300 {
				return
			}

			rw.Header().Set("X-CSRF-Token", gorillacsrf.Token(r))
		})

		handler.ServeHTTP(wrapped, r)
	})
}

// withSkipCSRF returns a handler that consults security.ShouldSkipCSRFCheck
// for every request before delegating to handler.
//
// If that check returns an error, a 403 Forbidden response is written through
// httperror.WriteError and handler is not invoked. If it reports that the
// check should be skipped, the request is replaced with the result of
// gorillacsrf.UnsafeSkipCheck before being passed on. Otherwise the request
// is forwarded unchanged.
func withSkipCSRF(handler http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		skip, err := security.ShouldSkipCSRFCheck(r)

		switch {
		case err != nil:
			httperror.WriteError(w, http.StatusForbidden, err.Error(), err)
			return
		case skip:
			r = gorillacsrf.UnsafeSkipCheck(r)
		}

		handler.ServeHTTP(w, r)
	})
}