Cloudreve/middleware/common.go

152 lines
4.3 KiB
Go

package middleware
import (
"context"
"fmt"
"github.com/cloudreve/Cloudreve/v4/application/constants"
"github.com/cloudreve/Cloudreve/v4/application/dependency"
"github.com/cloudreve/Cloudreve/v4/pkg/auth/requestinfo"
"github.com/cloudreve/Cloudreve/v4/pkg/cluster"
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
"github.com/cloudreve/Cloudreve/v4/pkg/request"
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
"github.com/cloudreve/Cloudreve/v4/pkg/util"
"github.com/gin-gonic/gin"
"github.com/gofrs/uuid"
"net/http"
"time"
)
// HashID decodes the hashed "id" route parameter into the real ID of an
// object of the given IDType and records it through util.WithValue under the
// hashid.ObjectIDCtx key. Requests without an "id" parameter are passed along
// unchanged; an ID that cannot be decoded aborts the chain with a parameter
// error response.
func HashID(IDType int) gin.HandlerFunc {
	return func(c *gin.Context) {
		dep := dependency.FromContext(c)

		hashed := c.Param("id")
		if hashed == "" {
			// No ID in the route, nothing to decode.
			c.Next()
			return
		}

		realID, decodeErr := dep.HashIDEncoder().Decode(hashed, IDType)
		if decodeErr != nil {
			c.JSON(200, serializer.ParamErr(c, "Failed to parse object ID", decodeErr))
			c.Abort()
			return
		}

		util.WithValue(c, hashid.ObjectIDCtx{}, realID)
		c.Next()
	}
}
// IsFunctionEnabled lets the request continue only when check returns true.
// Otherwise it replies with a serializer.CodeFeatureNotEnabled error and
// aborts the handler chain.
func IsFunctionEnabled(check func(c *gin.Context) bool) gin.HandlerFunc {
	return func(c *gin.Context) {
		if enabled := check(c); enabled {
			c.Next()
			return
		}

		c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeFeatureNotEnabled, "This feature is not enabled", nil))
		c.Abort()
	}
}
// CacheControl disables client-side caching by setting the Cache-Control
// response header to "private, no-cache".
func CacheControl() gin.HandlerFunc {
	const (
		headerName  = "Cache-Control"
		headerValue = "private, no-cache"
	)

	return func(c *gin.Context) {
		c.Header(headerName, headerValue)
	}
}
// Sandbox sets the Content-Security-Policy response header to "sandbox".
func Sandbox() gin.HandlerFunc {
	const (
		headerName  = "Content-Security-Policy"
		headerValue = "sandbox"
	)

	return func(c *gin.Context) {
		c.Header(headerName, headerValue)
	}
}
// StaticResourceCache applies the caching policy for static resources: it
// sets a public Cache-Control header whose max-age is read from the setting
// provider on every request.
func StaticResourceCache(dep dependency.Dep) gin.HandlerFunc {
	// The provider is resolved once; the max-age value is queried per request.
	provider := dep.SettingProvider()

	return func(c *gin.Context) {
		maxAge := provider.PublicResourceMaxAge(c)
		c.Header("Cache-Control", fmt.Sprintf("public, max-age=%d", maxAge))
	}
}
// MobileRequestOnly lets a request through only when it carries a non-empty
// header named constants.CrHeaderPrefix+"ios". Any other request is
// redirected (301 Moved Permanently) to the configured site URL and the
// handler chain is aborted.
func MobileRequestOnly() gin.HandlerFunc {
	return func(c *gin.Context) {
		dep := dependency.FromContext(c)

		if marker := c.GetHeader(constants.CrHeaderPrefix + "ios"); marker != "" {
			c.Next()
			return
		}

		siteURL := dep.SettingProvider().SiteURL(c)
		c.Redirect(http.StatusMovedPermanently, siteURL.String())
		c.Abort()
	}
}
// InitializeHandling is placed at the beginning of the handler chain and
// prepares the request context:
//  1. Takes the correlation ID from the request header, or generates a new
//     UUID v4 when the header is missing or not a valid UUID.
//  2. Forks the dependency manager with a logger prefixed by that ID.
//  3. Stores the correlation ID, request info (client IP, host, user agent),
//     logger and slave node ID in the request context. The node ID comes
//     from the "nodeId" route parameter, falling back to the slave node ID
//     header.
func InitializeHandling(dep dependency.Dep) gin.HandlerFunc {
	return func(c *gin.Context) {
		info := &requestinfo.RequestInfo{
			IP:        c.ClientIP(),
			Host:      c.Request.Host,
			UserAgent: c.Request.UserAgent(),
		}

		// Reuse the caller-supplied correlation ID when it parses as a UUID.
		correlationID := uuid.FromStringOrNil(c.GetHeader(request.CorrelationHeader))
		if correlationID == uuid.Nil {
			correlationID = uuid.Must(uuid.NewV4())
		}

		logger := dep.Logger().CopyWithPrefix(fmt.Sprintf("[Cid: %s]", correlationID))

		// The route parameter takes precedence over the header.
		nodeID := c.Param("nodeId")
		if nodeID == "" {
			nodeID = c.GetHeader(request.SlaveNodeIDHeader)
		}

		ctx := dep.ForkWithLogger(c.Request.Context(), logger)
		ctx = context.WithValue(ctx, logging.CorrelationIDCtx{}, correlationID)
		ctx = context.WithValue(ctx, requestinfo.RequestInfoCtx{}, info)
		ctx = context.WithValue(ctx, logging.LoggerCtx{}, logger)
		ctx = context.WithValue(ctx, cluster.SlaveNodeIDCtx{}, nodeID)

		c.Request = c.Request.WithContext(ctx)
		c.Next()
	}
}
// InitializeHandlingSlave copies the master site's ID, URL and version from
// the request headers into the request context, under the
// cluster.MasterSiteIDCtx, cluster.MasterSiteUrlCtx and
// cluster.MasterSiteVersionCtx keys respectively.
func InitializeHandlingSlave() gin.HandlerFunc {
	return func(c *gin.Context) {
		siteID := c.GetHeader(request.SiteIDHeader)
		siteURL := c.GetHeader(request.SiteURLHeader)
		siteVersion := c.GetHeader(request.SiteVersionHeader)

		ctx := c.Request.Context()
		ctx = context.WithValue(ctx, cluster.MasterSiteIDCtx{}, siteID)
		ctx = context.WithValue(ctx, cluster.MasterSiteUrlCtx{}, siteURL)
		ctx = context.WithValue(ctx, cluster.MasterSiteVersionCtx{}, siteVersion)

		c.Request = c.Request.WithContext(ctx)
		c.Next()
	}
}
// Logging records one log entry per request after the rest of the handler
// chain has run. The entry carries the response status, HTTP method, client
// IP, request path (with query string, if any), private gin errors and the
// time at which handling started.
func Logging() gin.HandlerFunc {
	return func(c *gin.Context) {
		// Capture the start time and the requested target before any
		// downstream handler runs.
		begin := time.Now()
		target := c.Request.URL.Path
		if query := c.Request.URL.RawQuery; query != "" {
			target += "?" + query
		}

		// Run the remaining handlers.
		c.Next()

		logger := logging.FromContext(c)
		logging.Request(
			logger,
			true,
			c.Writer.Status(),
			c.Request.Method,
			c.ClientIP(),
			target,
			c.Errors.ByType(gin.ErrorTypePrivate).String(),
			begin,
		)
	}
}