diff --git a/internal/model/user.go b/internal/model/user.go index b0a75867..0d0461a3 100644 --- a/internal/model/user.go +++ b/internal/model/user.go @@ -2,6 +2,7 @@ package model import ( "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/pkg/utils" "github.com/pkg/errors" ) @@ -89,3 +90,7 @@ func (u User) CanWebdavRead() bool { func (u User) CanWebdavManage() bool { return u.IsAdmin() || (u.Permission>>9)&1 == 1 } + +func (u User) JoinPath(reqPath string) (string, error) { + return utils.JoinBasePath(u.BasePath, reqPath) +} diff --git a/pkg/utils/path.go b/pkg/utils/path.go index be09a86b..0be9bb4a 100644 --- a/pkg/utils/path.go +++ b/pkg/utils/path.go @@ -6,6 +6,8 @@ import ( "path/filepath" "runtime" "strings" + + "github.com/alist-org/alist/v3/internal/errs" ) // StandardizePath convert path like '/' '/root' '/a/b' @@ -60,3 +62,10 @@ func EncodePath(path string, all ...bool) string { } return strings.Join(seg, "/") } + +func JoinBasePath(basePath, reqPath string) (string, error) { + if strings.HasSuffix(reqPath, "..") || strings.Contains(reqPath, "../") { + return "", errs.RelativePath + } + return stdpath.Join(basePath, reqPath), nil +} diff --git a/server/handles/aria2.go b/server/handles/aria2.go index 7096a83c..9948a90a 100644 --- a/server/handles/aria2.go +++ b/server/handles/aria2.go @@ -1,8 +1,6 @@ package handles import ( - stdpath "path" - "github.com/alist-org/alist/v3/internal/aria2" "github.com/alist-org/alist/v3/internal/conf" "github.com/alist-org/alist/v3/internal/db" @@ -58,9 +56,13 @@ func AddAria2(c *gin.Context) { common.ErrorResp(c, err, 400) return } - req.Path = stdpath.Join(user.BasePath, req.Path) + reqPath, err := user.JoinPath(req.Path) + if err != nil { + common.ErrorResp(c, err, 403) + return + } for _, url := range req.Urls { - err := aria2.AddURI(c, url, req.Path) + err := aria2.AddURI(c, url, reqPath) if err != nil { common.ErrorResp(c, err, 500) return diff --git a/server/handles/fsmanage.go b/server/handles/fsmanage.go index dc2f9787..1cfa44e9 100644 --- a/server/handles/fsmanage.go +++ b/server/handles/fsmanage.go @@ -26,25 +26,29 @@ func FsMkdir(c *gin.Context) { return } user := c.MustGet("user").(*model.User) - req.Path = stdpath.Join(user.BasePath, req.Path) + reqPath, err := user.JoinPath(req.Path) + if err != nil { + common.ErrorResp(c, err, 403) + return + } if !user.CanWrite() { - meta, err := db.GetNearestMeta(stdpath.Dir(req.Path)) + meta, err := db.GetNearestMeta(stdpath.Dir(reqPath)) if err != nil { if !errors.Is(errors.Cause(err), errs.MetaNotFound) { common.ErrorResp(c, err, 500, true) return } } - if !common.CanWrite(meta, req.Path) { + if !common.CanWrite(meta, reqPath) { common.ErrorResp(c, errs.PermissionDenied, 403) return } } - if err := fs.MakeDir(c, req.Path); err != nil { + if err := fs.MakeDir(c, reqPath); err != nil { common.ErrorResp(c, err, 500) return } - fs.ClearCache(stdpath.Dir(req.Path)) + fs.ClearCache(stdpath.Dir(reqPath)) common.SuccessResp(c) } @@ -69,17 +73,25 @@ func FsMove(c *gin.Context) { common.ErrorResp(c, errs.PermissionDenied, 403) return } - req.SrcDir = stdpath.Join(user.BasePath, req.SrcDir) - req.DstDir = stdpath.Join(user.BasePath, req.DstDir) + srcDir, err := user.JoinPath(req.SrcDir) + if err != nil { + common.ErrorResp(c, err, 403) + return + } + dstDir, err := user.JoinPath(req.DstDir) + if err != nil { + common.ErrorResp(c, err, 403) + return + } for _, name := range req.Names { - err := fs.Move(c, stdpath.Join(req.SrcDir, name), req.DstDir) + err := fs.Move(c, stdpath.Join(srcDir, name), dstDir) if err != nil { common.ErrorResp(c, err, 500) return } } - fs.ClearCache(req.SrcDir) - fs.ClearCache(req.DstDir) + fs.ClearCache(srcDir) + fs.ClearCache(dstDir) common.SuccessResp(c) } @@ -98,11 +110,19 @@ func FsCopy(c *gin.Context) { common.ErrorResp(c, errs.PermissionDenied, 403) return } - req.SrcDir = stdpath.Join(user.BasePath, req.SrcDir) - req.DstDir = stdpath.Join(user.BasePath, req.DstDir) + srcDir, err := user.JoinPath(req.SrcDir) + if err != nil { + common.ErrorResp(c, err, 403) + return + } + dstDir, err := user.JoinPath(req.DstDir) + if err != nil { + common.ErrorResp(c, err, 403) + return + } var addedTask []string for _, name := range req.Names { - ok, err := fs.Copy(c, stdpath.Join(req.SrcDir, name), req.DstDir) + ok, err := fs.Copy(c, stdpath.Join(srcDir, name), dstDir) if ok { addedTask = append(addedTask, name) } @@ -112,7 +132,7 @@ func FsCopy(c *gin.Context) { } } if len(req.Names) != len(addedTask) { - fs.ClearCache(req.DstDir) + fs.ClearCache(dstDir) } if len(addedTask) > 0 { common.SuccessResp(c, fmt.Sprintf("Added %d tasks", len(addedTask))) @@ -137,12 +157,16 @@ func FsRename(c *gin.Context) { common.ErrorResp(c, errs.PermissionDenied, 403) return } - req.Path = stdpath.Join(user.BasePath, req.Path) - if err := fs.Rename(c, req.Path, req.Name); err != nil { + reqPath, err := user.JoinPath(req.Path) + if err != nil { + common.ErrorResp(c, err, 403) + return + } + if err := fs.Rename(c, reqPath, req.Name); err != nil { common.ErrorResp(c, err, 500) return } - fs.ClearCache(stdpath.Dir(req.Path)) + fs.ClearCache(stdpath.Dir(reqPath)) common.SuccessResp(c) } @@ -166,9 +190,13 @@ func FsRemove(c *gin.Context) { common.ErrorResp(c, errs.PermissionDenied, 403) return } - req.Dir = stdpath.Join(user.BasePath, req.Dir) + reqDir, err := user.JoinPath(req.Dir) + if err != nil { + common.ErrorResp(c, err, 403) + return + } for _, name := range req.Names { - err := fs.Remove(c, stdpath.Join(req.Dir, name)) + err := fs.Remove(c, stdpath.Join(reqDir, name)) if err != nil { common.ErrorResp(c, err, 500) return @@ -185,8 +213,10 @@ func Link(c *gin.Context) { common.ErrorResp(c, err, 400) return } - user := c.MustGet("user").(*model.User) - rawPath := stdpath.Join(user.BasePath, req.Path) + //user := c.MustGet("user").(*model.User) + //rawPath := stdpath.Join(user.BasePath, req.Path) + // why need not join base_path? because it's always the full path + rawPath := req.Path storage, err := fs.GetStorage(rawPath) if err != nil { common.ErrorResp(c, err, 500) diff --git a/server/handles/fsread.go b/server/handles/fsread.go index 40ef2927..a27f4026 100644 --- a/server/handles/fsread.go +++ b/server/handles/fsread.go @@ -56,8 +56,12 @@ func FsList(c *gin.Context) { } req.Validate() user := c.MustGet("user").(*model.User) - req.Path = stdpath.Join(user.BasePath, req.Path) - meta, err := db.GetNearestMeta(req.Path) + reqPath, err := user.JoinPath(req.Path) + if err != nil { + common.ErrorResp(c, err, 403) + return + } + meta, err := db.GetNearestMeta(reqPath) if err != nil { if !errors.Is(errors.Cause(err), errs.MetaNotFound) { common.ErrorResp(c, err, 500, true) @@ -65,30 +69,30 @@ func FsList(c *gin.Context) { } } c.Set("meta", meta) - if !common.CanAccess(user, meta, req.Path, req.Password) { + if !common.CanAccess(user, meta, reqPath, req.Password) { common.ErrorStrResp(c, "password is incorrect", 403) return } - if !user.CanWrite() && !common.CanWrite(meta, req.Path) && req.Refresh { + if !user.CanWrite() && !common.CanWrite(meta, reqPath) && req.Refresh { common.ErrorStrResp(c, "Refresh without permission", 403) return } - objs, err := fs.List(c, req.Path, req.Refresh) + objs, err := fs.List(c, reqPath, req.Refresh) if err != nil { common.ErrorResp(c, err, 500) return } total, objs := pagination(objs, &req.PageReq) provider := "unknown" - storage, err := fs.GetStorage(req.Path) + storage, err := fs.GetStorage(reqPath) if err == nil { provider = storage.GetStorage().Driver } common.SuccessResp(c, FsListResp{ - Content: toObjsResp(objs, req.Path, isEncrypt(meta, req.Path)), + Content: toObjsResp(objs, reqPath, isEncrypt(meta, reqPath)), Total: int64(total), - Readme: getReadme(meta, req.Path), - Write: user.CanWrite() || common.CanWrite(meta, req.Path), + Readme: getReadme(meta, reqPath), + Write: user.CanWrite() || common.CanWrite(meta, reqPath), Provider: provider, }) } @@ -100,15 +104,21 @@ func FsDirs(c *gin.Context) { return } user := c.MustGet("user").(*model.User) + var reqPath string if req.ForceRoot { if !user.IsAdmin() { common.ErrorStrResp(c, "Permission denied", 403) return } } else { - req.Path = stdpath.Join(user.BasePath, req.Path) + tmp, err := user.JoinPath(req.Path) + if err != nil { + common.ErrorResp(c, err, 403) + return + } + reqPath = tmp } - meta, err := db.GetNearestMeta(req.Path) + meta, err := db.GetNearestMeta(reqPath) if err != nil { if !errors.Is(errors.Cause(err), errs.MetaNotFound) { common.ErrorResp(c, err, 500, true) @@ -116,11 +126,11 @@ func FsDirs(c *gin.Context) { } } c.Set("meta", meta) - if !common.CanAccess(user, meta, req.Path, req.Password) { + if !common.CanAccess(user, meta, reqPath, req.Password) { common.ErrorStrResp(c, "password is incorrect", 403) return } - objs, err := fs.List(c, req.Path) + objs, err := fs.List(c, reqPath) if err != nil { common.ErrorResp(c, err, 500) return @@ -218,8 +228,12 @@ func FsGet(c *gin.Context) { return } user := c.MustGet("user").(*model.User) - req.Path = stdpath.Join(user.BasePath, req.Path) - meta, err := db.GetNearestMeta(req.Path) + reqPath, err := user.JoinPath(req.Path) + if err != nil { + common.ErrorResp(c, err, 403) + return + } + meta, err := db.GetNearestMeta(reqPath) if err != nil { if !errors.Is(errors.Cause(err), errs.MetaNotFound) { common.ErrorResp(c, err, 500) @@ -227,18 +241,18 @@ func FsGet(c *gin.Context) { } } c.Set("meta", meta) - if !common.CanAccess(user, meta, req.Path, req.Password) { + if !common.CanAccess(user, meta, reqPath, req.Password) { common.ErrorStrResp(c, "password is incorrect", 403) return } - obj, err := fs.Get(c, req.Path) + obj, err := fs.Get(c, reqPath) if err != nil { common.ErrorResp(c, err, 500) return } var rawURL string - storage, err := fs.GetStorage(req.Path) + storage, err := fs.GetStorage(reqPath) provider := "unknown" if err == nil { provider = storage.Config().Name @@ -252,13 +266,13 @@ func FsGet(c *gin.Context) { if storage.GetStorage().DownProxyUrl != "" { rawURL = fmt.Sprintf("%s%s?sign=%s", strings.Split(storage.GetStorage().DownProxyUrl, "\n")[0], - utils.EncodePath(req.Path, true), - sign.Sign(req.Path)) + utils.EncodePath(reqPath, true), + sign.Sign(reqPath)) } else { rawURL = fmt.Sprintf("%s/p%s?sign=%s", common.GetApiUrl(c.Request), - utils.EncodePath(req.Path, true), - sign.Sign(req.Path)) + utils.EncodePath(reqPath, true), + sign.Sign(reqPath)) } } else { // file have raw url @@ -266,7 +280,7 @@ func FsGet(c *gin.Context) { rawURL = u.URL() } else { // if storage is not proxy, use raw url by fs.Link - link, _, err := fs.Link(c, req.Path, model.LinkArgs{IP: c.ClientIP(), Header: c.Request.Header}) + link, _, err := fs.Link(c, reqPath, model.LinkArgs{IP: c.ClientIP(), Header: c.Request.Header}) if err != nil { common.ErrorResp(c, err, 500) return @@ -276,7 +290,7 @@ func FsGet(c *gin.Context) { } } var related []model.Obj - parentPath := stdpath.Dir(req.Path) + parentPath := stdpath.Dir(reqPath) sameLevelFiles, err := fs.List(c, parentPath) if err == nil { related = filterRelated(sameLevelFiles, obj) @@ -288,11 +302,11 @@ func FsGet(c *gin.Context) { Size: obj.GetSize(), IsDir: obj.IsDir(), Modified: obj.ModTime(), - Sign: common.Sign(obj, parentPath, isEncrypt(meta, req.Path)), + Sign: common.Sign(obj, parentPath, isEncrypt(meta, reqPath)), Type: utils.GetFileType(obj.GetName()), }, RawURL: rawURL, - Readme: getReadme(meta, req.Path), + Readme: getReadme(meta, reqPath), Provider: provider, Related: toObjsResp(related, parentPath, isEncrypt(parentMeta, parentPath)), }) @@ -324,7 +338,12 @@ func FsOther(c *gin.Context) { return } user := c.MustGet("user").(*model.User) - req.Path = stdpath.Join(user.BasePath, req.Path) + var err error + req.Path, err = user.JoinPath(req.Path) + if err != nil { + common.ErrorResp(c, err, 403) + return + } meta, err := db.GetNearestMeta(req.Path) if err != nil { if !errors.Is(errors.Cause(err), errs.MetaNotFound) { diff --git a/server/handles/fsup.go b/server/handles/fsup.go index a5f466f4..b32fca50 100644 --- a/server/handles/fsup.go +++ b/server/handles/fsup.go @@ -21,8 +21,11 @@ func FsStream(c *gin.Context) { } asTask := c.GetHeader("As-Task") == "true" user := c.MustGet("user").(*model.User) - path = stdpath.Join(user.BasePath, path) - + path, err = user.JoinPath(path) + if err != nil { + common.ErrorResp(c, err, 403) + return + } dir, name := stdpath.Split(path) sizeStr := c.GetHeader("Content-Length") size, err := strconv.ParseInt(sizeStr, 10, 64) @@ -61,8 +64,11 @@ func FsForm(c *gin.Context) { } asTask := c.GetHeader("As-Task") == "true" user := c.MustGet("user").(*model.User) - path = stdpath.Join(user.BasePath, path) - + path, err = user.JoinPath(path) + if err != nil { + common.ErrorResp(c, err, 403) + return + } storage, err := fs.GetStorage(path) if err != nil { common.ErrorResp(c, err, 400) diff --git a/server/middlewares/fsup.go b/server/middlewares/fsup.go index 91b9f470..3e503cc6 100644 --- a/server/middlewares/fsup.go +++ b/server/middlewares/fsup.go @@ -22,7 +22,11 @@ func FsUp(c *gin.Context) { return } user := c.MustGet("user").(*model.User) - path = stdpath.Join(user.BasePath, path) + path, err = user.JoinPath(path) + if err != nil { + common.ErrorResp(c, err, 403) + return + } meta, err := db.GetNearestMeta(stdpath.Dir(path)) if err != nil { if !errors.Is(errors.Cause(err), errs.MetaNotFound) { diff --git a/server/webdav/webdav.go b/server/webdav/webdav.go index f8cb8fcf..ac9d975a 100644 --- a/server/webdav/webdav.go +++ b/server/webdav/webdav.go @@ -178,7 +178,10 @@ func (h *Handler) handleOptions(w http.ResponseWriter, r *http.Request) (status } ctx := r.Context() user := ctx.Value("user").(*model.User) - reqPath = path.Join(user.BasePath, reqPath) + reqPath, err = user.JoinPath(reqPath) + if err != nil { + return 403, err + } allow := "OPTIONS, LOCK, PUT, MKCOL" if fi, err := fs.Get(ctx, reqPath); err == nil { if fi.IsDir() { @@ -203,7 +206,10 @@ func (h *Handler) handleGetHeadPost(w http.ResponseWriter, r *http.Request) (sta // TODO: check locks for read-only access?? ctx := r.Context() user := ctx.Value("user").(*model.User) - reqPath = path.Join(user.BasePath, reqPath) + reqPath, err = user.JoinPath(reqPath) + if err != nil { + return 403, err + } fi, err := fs.Get(ctx, reqPath) if err != nil { return http.StatusNotFound, err @@ -258,7 +264,10 @@ func (h *Handler) handleDelete(w http.ResponseWriter, r *http.Request) (status i ctx := r.Context() user := ctx.Value("user").(*model.User) - reqPath = path.Join(user.BasePath, reqPath) + reqPath, err = user.JoinPath(reqPath) + if err != nil { + return 403, err + } // TODO: return MultiStatus where appropriate. // "godoc os RemoveAll" says that "If the path does not exist, RemoveAll @@ -291,7 +300,10 @@ func (h *Handler) handlePut(w http.ResponseWriter, r *http.Request) (status int, // comments in http.checkEtag. ctx := r.Context() user := ctx.Value("user").(*model.User) - reqPath = path.Join(user.BasePath, reqPath) + reqPath, err = user.JoinPath(reqPath) + if err != nil { + return 403, err + } obj := model.Object{ Name: path.Base(reqPath), Size: r.ContentLength, @@ -337,7 +349,10 @@ func (h *Handler) handleMkcol(w http.ResponseWriter, r *http.Request) (status in ctx := r.Context() user := ctx.Value("user").(*model.User) - reqPath = path.Join(user.BasePath, reqPath) + reqPath, err = user.JoinPath(reqPath) + if err != nil { + return 403, err + } if r.ContentLength > 0 { return http.StatusUnsupportedMediaType, nil @@ -384,8 +399,14 @@ func (h *Handler) handleCopyMove(w http.ResponseWriter, r *http.Request) (status ctx := r.Context() user := ctx.Value("user").(*model.User) - src = path.Join(user.BasePath, src) - dst = path.Join(user.BasePath, dst) + src, err = user.JoinPath(src) + if err != nil { + return 403, err + } + dst, err = user.JoinPath(dst) + if err != nil { + return 403, err + } if r.Method == "COPY" { // Section 7.5.1 says that a COPY only needs to lock the destination, @@ -476,7 +497,10 @@ func (h *Handler) handleLock(w http.ResponseWriter, r *http.Request) (retStatus } } reqPath, status, err := h.stripPrefix(r.URL.Path) - reqPath = path.Join(user.BasePath, reqPath) + reqPath, err = user.JoinPath(reqPath) + if err != nil { + return 403, err + } if err != nil { return status, err } @@ -557,7 +581,10 @@ func (h *Handler) handlePropfind(w http.ResponseWriter, r *http.Request) (status } ctx := r.Context() user := ctx.Value("user").(*model.User) - reqPath = path.Join(user.BasePath, reqPath) + reqPath, err = user.JoinPath(reqPath) + if err != nil { + return 403, err + } fi, err := fs.Get(ctx, reqPath) if err != nil { if errs.IsObjectNotFound(err) { @@ -633,8 +660,10 @@ func (h *Handler) handleProppatch(w http.ResponseWriter, r *http.Request) (statu ctx := r.Context() user := ctx.Value("user").(*model.User) - reqPath = path.Join(user.BasePath, reqPath) - + reqPath, err = user.JoinPath(reqPath) + if err != nil { + return 403, err + } if _, err := fs.Get(ctx, reqPath); err != nil { if errs.IsObjectNotFound(err) { return http.StatusNotFound, err