Cloudreve/middleware/auth.go

274 lines
7.4 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package middleware
import (
"net/http"
"github.com/cloudreve/Cloudreve/v4/application/dependency"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/inventory"
"github.com/cloudreve/Cloudreve/v4/inventory/types"
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/driver/oss"
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/fs"
"github.com/cloudreve/Cloudreve/v4/pkg/filemanager/manager"
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
"github.com/cloudreve/Cloudreve/v4/pkg/request"
"github.com/cloudreve/Cloudreve/v4/pkg/util"
"github.com/cloudreve/Cloudreve/v4/pkg/auth"
"github.com/cloudreve/Cloudreve/v4/pkg/serializer"
"github.com/gin-gonic/gin"
)
// CallbackFailedStatusCode is the HTTP status code written by the upload
// callback middlewares in this file when a callback request is rejected.
const CallbackFailedStatusCode = http.StatusUnauthorized
// SignRequired returns a middleware that verifies the signature of the
// incoming request using authInstance.
//
// PUT, POST and PATCH requests are verified with auth.CheckRequest; all other
// methods are verified with auth.CheckURI against the request URL. When
// verification fails, a CodeCredentialInvalid error is written as JSON and the
// handler chain is aborted.
func SignRequired(authInstance auth.Auth) gin.HandlerFunc {
	return func(c *gin.Context) {
		method := c.Request.Method
		signsWholeRequest := method == http.MethodPut ||
			method == http.MethodPost ||
			method == http.MethodPatch

		var signErr error
		if signsWholeRequest {
			signErr = auth.CheckRequest(c, authInstance, c.Request)
		} else {
			signErr = auth.CheckURI(c, authInstance, c.Request.URL)
		}

		if signErr != nil {
			c.JSON(http.StatusOK, serializer.ErrWithDetails(c, serializer.CodeCredentialInvalid, signErr.Error(), signErr))
			c.Abort()
			return
		}

		c.Next()
	}
}
// CurrentUser returns a middleware that resolves the logged-in user for the
// request.
//
// It asks the token authenticator to verify the request, then loads the user
// identified by the user ID found in the context and stores it in the context
// through SetUserCtx. Any error is written as a JSON error response and aborts
// the handler chain.
func CurrentUser() gin.HandlerFunc {
	return func(c *gin.Context) {
		// fail writes the serialized error and stops the handler chain.
		fail := func(e error) {
			c.JSON(http.StatusOK, serializer.Err(c, e))
			c.Abort()
		}

		shouldContinue, verifyErr := dependency.FromContext(c).TokenAuth().VerifyAndRetrieveUser(c)
		if verifyErr != nil {
			fail(verifyErr)
			return
		}

		if shouldContinue {
			// TODO: Logto handler
		}

		if ctxErr := SetUserCtx(c, inventory.UserIDFromContext(c)); ctxErr != nil {
			fail(ctxErr)
			return
		}

		c.Next()
	}
}
// SetUserCtx loads the login user identified by uid and stores it in the
// request context via SetUserCtxByUser.
//
// It returns a CodeDBError application error wrapping the underlying error
// when the user cannot be loaded, and nil otherwise.
func SetUserCtx(c *gin.Context, uid int) error {
	found, lookupErr := dependency.FromContext(c).UserClient().GetLoginUserByID(c, uid)
	if lookupErr != nil {
		return serializer.NewError(serializer.CodeDBError, "failed to get login user", lookupErr)
	}

	SetUserCtxByUser(c, found)
	return nil
}
// SetUserCtxByUser stores user in the request context under the
// inventory.UserCtx{} key.
func SetUserCtxByUser(c *gin.Context, user *ent.User) {
	util.WithValue(c, inventory.UserCtx{}, user)
}
// LoginRequired returns a middleware that only lets the request through when
// the context holds a user that is not the anonymous user. Otherwise it writes
// a CodeCheckLogin error as JSON and aborts the handler chain.
func LoginRequired() gin.HandlerFunc {
	return func(c *gin.Context) {
		current := inventory.UserFromContext(c)
		if current == nil || inventory.IsAnonymousUser(current) {
			c.JSON(http.StatusOK, serializer.ErrWithDetails(c, serializer.CodeCheckLogin, "Login required", nil))
			c.Abort()
			return
		}

		c.Next()
	}
}
// WebDAVAuth returns a middleware that authenticates WebDAV requests with HTTP
// Basic credentials and enforces WebDAV-related permissions.
//
// Behavior:
//   - Requests without Basic credentials are passed through when the method is
//     OPTIONS; otherwise they receive 401 with a WWW-Authenticate challenge.
//   - Credentials are resolved to a user through GetActiveByDavAccount; on
//     failure the request is rejected with 401.
//   - The user must have at least one loaded DAV account (401 otherwise), a
//     loaded group (500 otherwise) and the group must have the WebDAV
//     permission enabled (403 otherwise).
//   - When the first DAV account is flagged read-only, DELETE, PUT, MKCOL,
//     COPY, MOVE, LOCK and UNLOCK are rejected with 403.
//
// On success the authenticated user is stored in the request context.
func WebDAVAuth() gin.HandlerFunc {
	return func(c *gin.Context) {
		username, password, ok := c.Request.BasicAuth()
		if !ok {
			// OPTIONS requests do not require authentication.
			if c.Request.Method == http.MethodOptions {
				c.Next()
				return
			}

			// Challenge the client to retry with Basic credentials.
			c.Writer.Header()["WWW-Authenticate"] = []string{`Basic realm="cloudreve"`}
			c.Status(http.StatusUnauthorized)
			c.Abort()
			return
		}

		dep := dependency.FromContext(c)
		l := dep.Logger()
		userClient := dep.UserClient()

		expectedUser, err := userClient.GetActiveByDavAccount(c, username, password)
		if err != nil {
			// The credential check failed. If the supplied username still
			// matches a known user, attach that user to the context so the
			// failed attempt can be attributed to them (e.g. for audit logs).
			// The lookup is only meaningful for a non-empty username.
			if username != "" {
				if u, lookupErr := userClient.GetByEmail(c, username); lookupErr == nil {
					SetUserCtxByUser(c, u)
				}
			}

			l.Debug("WebDAVAuth: failed to get user %q with provided credential: %s", username, err)
			c.Status(http.StatusUnauthorized)
			c.Abort()
			return
		}

		// Validate that the DAV account edge is loaded and not empty.
		accounts, err := expectedUser.Edges.DavAccountsOrErr()
		if err != nil || len(accounts) == 0 {
			l.Debug("WebDAVAuth: failed to get user dav accounts %q with provided credential: %s", username, err)
			c.Status(http.StatusUnauthorized)
			c.Abort()
			return
		}

		// The user's group must have WebDAV enabled.
		group, err := expectedUser.Edges.GroupOrErr()
		if err != nil {
			l.Debug("WebDAVAuth: user group not found: %s", err)
			c.Status(http.StatusInternalServerError)
			c.Abort()
			return
		}

		if !group.Permissions.Enabled(int(types.GroupPermissionWebDAV)) {
			c.Status(http.StatusForbidden)
			l.Debug("WebDAVAuth: user %q does not have WebDAV permission.", expectedUser.Email)
			c.Abort()
			return
		}

		// Read-only accounts may not issue the listed mutating methods.
		// accounts is non-empty here (checked above).
		if accounts[0].Options.Enabled(int(types.DavAccountReadOnly)) {
			switch c.Request.Method {
			case http.MethodDelete, http.MethodPut, "MKCOL", "COPY", "MOVE", "LOCK", "UNLOCK":
				c.Status(http.StatusForbidden)
				c.Abort()
				return
			}
		}

		SetUserCtxByUser(c, expectedUser)
		c.Next()
	}
}
// UseUploadSession returns a middleware that validates the upload session
// referenced by the request against policyType via uploadCallbackCheck.
//
// When validation fails, the error is written as JSON with
// CallbackFailedStatusCode and the handler chain is aborted.
func UseUploadSession(policyType types.PolicyType) gin.HandlerFunc {
	return func(c *gin.Context) {
		// Validate the session key and resolve the uploading user.
		if checkErr := uploadCallbackCheck(c, policyType); checkErr != nil {
			c.JSON(CallbackFailedStatusCode, serializer.Err(c, checkErr))
			c.Abort()
			return
		}

		c.Next()
	}
}
// uploadCallbackCheck validates the upload session referenced by the
// "sessionID" route parameter of an upload callback request.
//
// It looks the session up in the KV store, stores a pointer to it in the gin
// context under manager.UploadSessionCtx, checks that the session's policy
// type equals policyType, and finally stores the session owner in the request
// context via SetUserCtx.
//
// It returns an application error when the session ID is empty, the session
// is missing or has an unexpected stored type, the policy type does not match,
// or the user cannot be loaded; nil on success.
func uploadCallbackCheck(c *gin.Context, policyType types.PolicyType) error {
	// Validate the callback key.
	sessionID := c.Param("sessionID")
	if sessionID == "" {
		return serializer.NewError(serializer.CodeParamErr, "Session ID cannot be empty", nil)
	}

	dep := dependency.FromContext(c)
	callbackSessionRaw, exist := dep.KV().Get(manager.UploadSessionCachePrefix + sessionID)
	if !exist {
		return serializer.NewError(serializer.CodeUploadSessionExpired, "Upload session does not exist or expired", nil)
	}

	// Use the comma-ok form so an unexpected cached value yields an error
	// response instead of a panic.
	callbackSession, ok := callbackSessionRaw.(fs.UploadSession)
	if !ok {
		return serializer.NewError(serializer.CodeUploadSessionExpired, "Upload session data is invalid", nil)
	}

	c.Set(manager.UploadSessionCtx, &callbackSession)
	if callbackSession.Policy.Type != string(policyType) {
		return serializer.NewError(serializer.CodePolicyNotAllowed, "", nil)
	}

	return SetUserCtx(c, callbackSession.UID)
}
// RemoteCallbackAuth returns a middleware that verifies the signature of a
// remote upload callback request.
//
// It reads the upload session stored in the gin context under
// manager.UploadSessionCtx (the value must already be present; MustGet panics
// otherwise) and verifies the request with an HMAC authenticator keyed by the
// SlaveKey of the policy's node. A missing node or a failed verification is
// answered with CallbackFailedStatusCode and aborts the handler chain.
func RemoteCallbackAuth() gin.HandlerFunc {
	return func(c *gin.Context) {
		// deny writes a CodeCredentialInvalid response and stops the chain.
		deny := func(msg string, cause error) {
			c.JSON(CallbackFailedStatusCode, serializer.ErrWithDetails(c, serializer.CodeCredentialInvalid, msg, cause))
			c.Abort()
		}

		uploadSession := c.MustGet(manager.UploadSessionCtx).(*fs.UploadSession)
		node := uploadSession.Policy.Edges.Node
		if node == nil {
			deny("Node not found", nil)
			return
		}

		// Verify the request signature with the node's key.
		signer := auth.HMACAuth{SecretKey: []byte(node.SlaveKey)}
		if signErr := auth.CheckRequest(c, signer, c.Request); signErr != nil {
			deny(signErr.Error(), signErr)
			return
		}

		c.Next()
	}
}
// OSSCallbackAuth returns a middleware that verifies the signature of an
// Alibaba Cloud OSS upload callback request.
//
// Verification is delegated to oss.VerifyCallbackSignature. On failure the
// error is logged at debug level, a 401 GeneralUploadCallbackFailed JSON body
// is written, and the handler chain is aborted.
func OSSCallbackAuth() gin.HandlerFunc {
	return func(c *gin.Context) {
		dep := dependency.FromContext(c)
		kv := dep.KV()
		client := dep.RequestClient(
			request.WithContext(c),
			request.WithLogger(logging.FromContext(c)),
		)

		if verifyErr := oss.VerifyCallbackSignature(c.Request, kv, client); verifyErr != nil {
			dep.Logger().Debug("Failed to verify callback request: %s", verifyErr)
			c.JSON(http.StatusUnauthorized, serializer.GeneralUploadCallbackFailed{Error: "Failed to verify callback request."})
			c.Abort()
			return
		}

		c.Next()
	}
}
// IsAdmin returns a middleware that only lets the request through when the
// user in the context belongs to a group with the admin permission enabled.
//
// A request with no user in the context, or whose user has no loaded group
// edge, is treated as lacking permission rather than dereferencing a nil
// pointer. In every rejected case a CodeNoPermissionErr error is written as
// JSON and the handler chain is aborted.
func IsAdmin() gin.HandlerFunc {
	return func(c *gin.Context) {
		user := inventory.UserFromContext(c)

		// Guard against a missing user or an unloaded group edge before
		// reading the group's permissions.
		if user == nil || user.Edges.Group == nil ||
			!user.Edges.Group.Permissions.Enabled(int(types.GroupPermissionIsAdmin)) {
			c.JSON(200, serializer.ErrWithDetails(c, serializer.CodeNoPermissionErr, "", nil))
			c.Abort()
			return
		}

		c.Next()
	}
}