portainer/api/http/handler/endpoint_group.go

365 lines
12 KiB
Go

package handler
import (
"github.com/portainer/portainer"
httperror "github.com/portainer/portainer/http/error"
"github.com/portainer/portainer/http/security"
"encoding/json"
"log"
"net/http"
"os"
"strconv"
"github.com/asaskevich/govalidator"
"github.com/gorilla/mux"
)
// EndpointGroupHandler is the HTTP API handler exposing the endpoint group
// routes. It embeds a *mux.Router, so the handler itself is what gets mounted.
//
// NewEndpointGroupHandler only initializes the router and the logger; the two
// service fields are left unset by the constructor and are dereferenced by the
// request handlers, so they have to be assigned before requests are served.
type EndpointGroupHandler struct {
	*mux.Router

	// Logger is passed to the error and JSON response helpers.
	Logger *log.Logger

	// EndpointService is used to list endpoints and to persist their group
	// membership.
	EndpointService portainer.EndpointService

	// EndpointGroupService is used to list, read, create, update and delete
	// endpoint groups.
	EndpointGroupService portainer.EndpointGroupService
}
// NewEndpointGroupHandler returns a new instance of EndpointGroupHandler.
//
// The returned handler has its router and a stderr logger set up and all
// endpoint group routes registered. Listing groups goes through
// bouncer.RestrictedAccess; every other route goes through
// bouncer.AdministratorAccess. The service fields are not populated here.
func NewEndpointGroupHandler(bouncer *security.RequestBouncer) *EndpointGroupHandler {
	h := &EndpointGroupHandler{
		Router: mux.NewRouter(),
		Logger: log.New(os.Stderr, "", log.LstdFlags),
	}

	const (
		collectionPath = "/endpoint_groups"
		itemPath       = "/endpoint_groups/{id}"
		accessPath     = "/endpoint_groups/{id}/access"
	)

	// Routes are listed in the order they are registered on the router.
	routes := []struct {
		path    string
		method  string
		handler http.Handler
	}{
		{collectionPath, http.MethodPost, bouncer.AdministratorAccess(http.HandlerFunc(h.handlePostEndpointGroups))},
		{collectionPath, http.MethodGet, bouncer.RestrictedAccess(http.HandlerFunc(h.handleGetEndpointGroups))},
		{itemPath, http.MethodGet, bouncer.AdministratorAccess(http.HandlerFunc(h.handleGetEndpointGroup))},
		{itemPath, http.MethodPut, bouncer.AdministratorAccess(http.HandlerFunc(h.handlePutEndpointGroup))},
		{accessPath, http.MethodPut, bouncer.AdministratorAccess(http.HandlerFunc(h.handlePutEndpointGroupAccess))},
		{itemPath, http.MethodDelete, bouncer.AdministratorAccess(http.HandlerFunc(h.handleDeleteEndpointGroup))},
	}
	for _, route := range routes {
		h.Handle(route.path, route.handler).Methods(route.method)
	}

	return h
}
// postEndpointGroupsResponse is the JSON body written by
// handlePostEndpointGroups after a group has been created.
type postEndpointGroupsResponse struct {
	// ID is the identifier of the created group, serialized under the "Id" key.
	ID int `json:"Id"`
}

// postEndpointGroupsRequest is the JSON payload decoded by
// handlePostEndpointGroups. The `valid` tags are consumed by
// govalidator.ValidateStruct; Name is the only field tagged as required.
type postEndpointGroupsRequest struct {
	Name        string           `valid:"required"`
	Description string           `valid:"-"`
	Labels      []portainer.Pair `valid:""`
	// AssociatedEndpoints lists the endpoints to move into the new group.
	AssociatedEndpoints []portainer.EndpointID `valid:""`
}

// putEndpointGroupAccessRequest is the JSON payload decoded by
// handlePutEndpointGroupAccess. Identifiers arrive as plain ints and are
// converted to portainer.UserID / portainer.TeamID by the handler.
type putEndpointGroupAccessRequest struct {
	AuthorizedUsers []int `valid:"-"`
	AuthorizedTeams []int `valid:"-"`
}

// putEndpointGroupsRequest is the JSON payload decoded by
// handlePutEndpointGroup. Unlike the POST payload, Name is not required.
type putEndpointGroupsRequest struct {
	Name        string           `valid:"-"`
	Description string           `valid:"-"`
	Labels      []portainer.Pair `valid:""`
	// AssociatedEndpoints lists the endpoints that should belong to the group.
	AssociatedEndpoints []portainer.EndpointID `valid:""`
}
// handleGetEndpointGroups handles GET requests on /endpoint_groups
//
// It loads every endpoint group, passes the list through
// security.FilterEndpointGroups together with the restricted request context
// of the caller, and writes the filtered list as JSON. Any failure along the
// way is answered with a 500.
func (handler *EndpointGroupHandler) handleGetEndpointGroups(w http.ResponseWriter, r *http.Request) {
	// All error paths of this handler share the same status code.
	fail := func(cause error) {
		httperror.WriteErrorResponse(w, cause, http.StatusInternalServerError, handler.Logger)
	}

	restricted, err := security.RetrieveRestrictedRequestContext(r)
	if err != nil {
		fail(err)
		return
	}

	allGroups, err := handler.EndpointGroupService.EndpointGroups()
	if err != nil {
		fail(err)
		return
	}

	visibleGroups, err := security.FilterEndpointGroups(allGroups, restricted)
	if err != nil {
		fail(err)
		return
	}

	encodeJSON(w, visibleGroups, handler.Logger)
}
// handlePostEndpointGroups handles POST requests on /endpoint_groups
//
// The payload is decoded and validated (400 on failure), a group with empty
// authorized user and team lists is created, and every endpoint that currently
// sits in group 1 and is named in AssociatedEndpoints is moved into the new
// group. Service failures are answered with a 500; on success the identifier
// of the new group is written as JSON.
func (handler *EndpointGroupHandler) handlePostEndpointGroups(w http.ResponseWriter, r *http.Request) {
	var payload postEndpointGroupsRequest
	err := json.NewDecoder(r.Body).Decode(&payload)
	if err != nil {
		httperror.WriteErrorResponse(w, ErrInvalidJSON, http.StatusBadRequest, handler.Logger)
		return
	}

	if _, err = govalidator.ValidateStruct(payload); err != nil {
		httperror.WriteErrorResponse(w, ErrInvalidRequestFormat, http.StatusBadRequest, handler.Logger)
		return
	}

	group := &portainer.EndpointGroup{
		Name:            payload.Name,
		Description:     payload.Description,
		Labels:          payload.Labels,
		AuthorizedUsers: []portainer.UserID{},
		AuthorizedTeams: []portainer.TeamID{},
	}

	if err = handler.EndpointGroupService.CreateEndpointGroup(group); err != nil {
		httperror.WriteErrorResponse(w, err, http.StatusInternalServerError, handler.Logger)
		return
	}

	allEndpoints, err := handler.EndpointService.Endpoints()
	if err != nil {
		httperror.WriteErrorResponse(w, err, http.StatusInternalServerError, handler.Logger)
		return
	}

	for _, endpoint := range allEndpoints {
		// Only endpoints currently in group 1 are candidates for assignment.
		if endpoint.GroupID != portainer.EndpointGroupID(1) {
			continue
		}

		err = handler.checkForGroupAssignment(endpoint, group.ID, payload.AssociatedEndpoints)
		if err != nil {
			httperror.WriteErrorResponse(w, err, http.StatusInternalServerError, handler.Logger)
			return
		}
	}

	encodeJSON(w, &postEndpointGroupsResponse{ID: int(group.ID)}, handler.Logger)
}
// handleGetEndpointGroup handles GET requests on /endpoint_groups/:id
//
// A non-numeric id yields a 400, an unknown group a 404, and any other lookup
// failure a 500. On success the group is written as JSON.
func (handler *EndpointGroupHandler) handleGetEndpointGroup(w http.ResponseWriter, r *http.Request) {
	rawID := mux.Vars(r)["id"]

	numericID, err := strconv.Atoi(rawID)
	if err != nil {
		httperror.WriteErrorResponse(w, err, http.StatusBadRequest, handler.Logger)
		return
	}

	group, err := handler.EndpointGroupService.EndpointGroup(portainer.EndpointGroupID(numericID))
	switch {
	case err == portainer.ErrEndpointGroupNotFound:
		httperror.WriteErrorResponse(w, err, http.StatusNotFound, handler.Logger)
		return
	case err != nil:
		httperror.WriteErrorResponse(w, err, http.StatusInternalServerError, handler.Logger)
		return
	}

	encodeJSON(w, group, handler.Logger)
}
// handlePutEndpointGroupAccess handles PUT requests on /endpoint_groups/:id/access
//
// It replaces the authorized user and/or team lists of a group. A list is only
// replaced when it is present in the payload (non-nil after decoding); an
// empty list clears the corresponding authorizations. A non-numeric id,
// malformed JSON or failed validation yields a 400, an unknown group a 404 and
// any service failure a 500. No response body is written on success.
func (handler *EndpointGroupHandler) handlePutEndpointGroupAccess(w http.ResponseWriter, r *http.Request) {
	rawID := mux.Vars(r)["id"]

	numericID, err := strconv.Atoi(rawID)
	if err != nil {
		httperror.WriteErrorResponse(w, err, http.StatusBadRequest, handler.Logger)
		return
	}

	var payload putEndpointGroupAccessRequest
	err = json.NewDecoder(r.Body).Decode(&payload)
	if err != nil {
		httperror.WriteErrorResponse(w, ErrInvalidJSON, http.StatusBadRequest, handler.Logger)
		return
	}

	if _, err = govalidator.ValidateStruct(payload); err != nil {
		httperror.WriteErrorResponse(w, ErrInvalidRequestFormat, http.StatusBadRequest, handler.Logger)
		return
	}

	group, err := handler.EndpointGroupService.EndpointGroup(portainer.EndpointGroupID(numericID))
	switch {
	case err == portainer.ErrEndpointGroupNotFound:
		httperror.WriteErrorResponse(w, err, http.StatusNotFound, handler.Logger)
		return
	case err != nil:
		httperror.WriteErrorResponse(w, err, http.StatusInternalServerError, handler.Logger)
		return
	}

	if payload.AuthorizedUsers != nil {
		// Convert the raw ints into typed user identifiers.
		userIDs := make([]portainer.UserID, len(payload.AuthorizedUsers))
		for i, raw := range payload.AuthorizedUsers {
			userIDs[i] = portainer.UserID(raw)
		}
		group.AuthorizedUsers = userIDs
	}

	if payload.AuthorizedTeams != nil {
		// Convert the raw ints into typed team identifiers.
		teamIDs := make([]portainer.TeamID, len(payload.AuthorizedTeams))
		for i, raw := range payload.AuthorizedTeams {
			teamIDs[i] = portainer.TeamID(raw)
		}
		group.AuthorizedTeams = teamIDs
	}

	if err = handler.EndpointGroupService.UpdateEndpointGroup(group.ID, group); err != nil {
		httperror.WriteErrorResponse(w, err, http.StatusInternalServerError, handler.Logger)
		return
	}
}
// handlePutEndpointGroup handles PUT requests on /endpoint_groups/:id
//
// The update is partial. Name and Description are replaced only when the
// payload carries a non-empty value. Labels and AssociatedEndpoints are applied
// only when they are present in the payload (non-nil after decoding), so a
// request that omits them no longer wipes the labels or moves every member
// endpoint back to group 1. Sending an empty JSON array still clears the
// labels, respectively unassigns all endpoints of the group.
//
// A non-numeric id, malformed JSON or failed validation yields a 400, an
// unknown group a 404 and any service failure a 500. No response body is
// written on success.
func (handler *EndpointGroupHandler) handlePutEndpointGroup(w http.ResponseWriter, r *http.Request) {
	vars := mux.Vars(r)
	id := vars["id"]

	endpointGroupID, err := strconv.Atoi(id)
	if err != nil {
		httperror.WriteErrorResponse(w, err, http.StatusBadRequest, handler.Logger)
		return
	}

	var req putEndpointGroupsRequest
	if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
		httperror.WriteErrorResponse(w, ErrInvalidJSON, http.StatusBadRequest, handler.Logger)
		return
	}

	_, err = govalidator.ValidateStruct(req)
	if err != nil {
		httperror.WriteErrorResponse(w, ErrInvalidRequestFormat, http.StatusBadRequest, handler.Logger)
		return
	}

	groupID := portainer.EndpointGroupID(endpointGroupID)
	endpointGroup, err := handler.EndpointGroupService.EndpointGroup(groupID)
	if err == portainer.ErrEndpointGroupNotFound {
		httperror.WriteErrorResponse(w, err, http.StatusNotFound, handler.Logger)
		return
	} else if err != nil {
		httperror.WriteErrorResponse(w, err, http.StatusInternalServerError, handler.Logger)
		return
	}

	if req.Name != "" {
		endpointGroup.Name = req.Name
	}

	if req.Description != "" {
		endpointGroup.Description = req.Description
	}

	// A nil slice means the field was omitted (or null) in the payload; keep
	// the stored labels in that case instead of overwriting them with nil.
	if req.Labels != nil {
		endpointGroup.Labels = req.Labels
	}

	err = handler.EndpointGroupService.UpdateEndpointGroup(endpointGroup.ID, endpointGroup)
	if err != nil {
		httperror.WriteErrorResponse(w, err, http.StatusInternalServerError, handler.Logger)
		return
	}

	// Without an explicit association list there is nothing to reconcile.
	// Running the reconciliation with a nil list would unassign every endpoint
	// that currently belongs to the group.
	if req.AssociatedEndpoints == nil {
		return
	}

	endpoints, err := handler.EndpointService.Endpoints()
	if err != nil {
		httperror.WriteErrorResponse(w, err, http.StatusInternalServerError, handler.Logger)
		return
	}

	for _, endpoint := range endpoints {
		err = handler.updateEndpointGroup(endpoint, groupID, req.AssociatedEndpoints)
		if err != nil {
			httperror.WriteErrorResponse(w, err, http.StatusInternalServerError, handler.Logger)
			return
		}
	}
}
// updateEndpointGroup reconciles a single endpoint with the requested
// membership of the group identified by groupID.
//
// An endpoint that already belongs to the group is checked for unassignment,
// an endpoint in group 1 is checked for assignment, and endpoints of any other
// group are left untouched. The membership check is evaluated first, so it
// wins when groupID itself is 1. The error of the underlying endpoint update,
// if any, is returned.
func (handler *EndpointGroupHandler) updateEndpointGroup(endpoint portainer.Endpoint, groupID portainer.EndpointGroupID, associatedEndpoints []portainer.EndpointID) error {
	switch endpoint.GroupID {
	case groupID:
		return handler.checkForGroupUnassignment(endpoint, associatedEndpoints)
	case portainer.EndpointGroupID(1):
		return handler.checkForGroupAssignment(endpoint, groupID, associatedEndpoints)
	default:
		return nil
	}
}
// checkForGroupUnassignment moves the endpoint back to group 1 unless its ID
// is listed in associatedEndpoints.
//
// When the endpoint is listed nothing is persisted and nil is returned;
// otherwise the result of EndpointService.UpdateEndpoint is returned. The
// endpoint is received by value, so the caller's copy is not modified.
func (handler *EndpointGroupHandler) checkForGroupUnassignment(endpoint portainer.Endpoint, associatedEndpoints []portainer.EndpointID) error {
	stillAssociated := false
	for i := range associatedEndpoints {
		if associatedEndpoints[i] == endpoint.ID {
			stillAssociated = true
			break
		}
	}

	if stillAssociated {
		return nil
	}

	endpoint.GroupID = portainer.EndpointGroupID(1)
	return handler.EndpointService.UpdateEndpoint(endpoint.ID, &endpoint)
}
// checkForGroupAssignment moves the endpoint into the group identified by
// groupID when its ID is listed in associatedEndpoints.
//
// When the endpoint is not listed nothing is persisted and nil is returned;
// otherwise the result of EndpointService.UpdateEndpoint is returned. The
// endpoint is received by value, so the caller's copy is not modified.
func (handler *EndpointGroupHandler) checkForGroupAssignment(endpoint portainer.Endpoint, groupID portainer.EndpointGroupID, associatedEndpoints []portainer.EndpointID) error {
	requested := false
	for i := range associatedEndpoints {
		if associatedEndpoints[i] == endpoint.ID {
			requested = true
			break
		}
	}

	if !requested {
		return nil
	}

	endpoint.GroupID = groupID
	return handler.EndpointService.UpdateEndpoint(endpoint.ID, &endpoint)
}
// handleDeleteEndpointGroup handles DELETE requests on /endpoint_groups/:id
//
// Group 1 cannot be removed and is rejected with a 403
// (portainer.ErrCannotRemoveDefaultGroup). A non-numeric id yields a 400, an
// unknown group a 404 and any service failure a 500. After the group has been
// deleted, every endpoint that belonged to it is moved to group 1. No response
// body is written on success.
func (handler *EndpointGroupHandler) handleDeleteEndpointGroup(w http.ResponseWriter, r *http.Request) {
	rawID := mux.Vars(r)["id"]

	numericID, err := strconv.Atoi(rawID)
	if err != nil {
		httperror.WriteErrorResponse(w, err, http.StatusBadRequest, handler.Logger)
		return
	}

	if numericID == 1 {
		httperror.WriteErrorResponse(w, portainer.ErrCannotRemoveDefaultGroup, http.StatusForbidden, handler.Logger)
		return
	}

	groupID := portainer.EndpointGroupID(numericID)

	// The lookup is only used to tell an unknown group apart from other errors.
	_, err = handler.EndpointGroupService.EndpointGroup(groupID)
	switch {
	case err == portainer.ErrEndpointGroupNotFound:
		httperror.WriteErrorResponse(w, err, http.StatusNotFound, handler.Logger)
		return
	case err != nil:
		httperror.WriteErrorResponse(w, err, http.StatusInternalServerError, handler.Logger)
		return
	}

	if err = handler.EndpointGroupService.DeleteEndpointGroup(groupID); err != nil {
		httperror.WriteErrorResponse(w, err, http.StatusInternalServerError, handler.Logger)
		return
	}

	allEndpoints, err := handler.EndpointService.Endpoints()
	if err != nil {
		httperror.WriteErrorResponse(w, err, http.StatusInternalServerError, handler.Logger)
		return
	}

	for _, endpoint := range allEndpoints {
		if endpoint.GroupID != groupID {
			continue
		}

		// Re-home former members of the deleted group in group 1.
		endpoint.GroupID = portainer.EndpointGroupID(1)
		if err = handler.EndpointService.UpdateEndpoint(endpoint.ID, &endpoint); err != nil {
			httperror.WriteErrorResponse(w, err, http.StatusInternalServerError, handler.Logger)
			return
		}
	}
}