From b72c948a9f018b3f4e53e6c639b6ace817275568 Mon Sep 17 00:00:00 2001 From: banbxio Date: Tue, 25 Mar 2025 00:43:05 +0800 Subject: [PATCH] refactor(downloader): update wget to http --- downloader/downloader.go | 90 ++++++++++++++++++++++++++++++++++++++++ downloader/wget.go | 49 ---------------------- http/downloader.go | 16 ++++--- 3 files changed, 100 insertions(+), 55 deletions(-) delete mode 100644 downloader/wget.go diff --git a/downloader/downloader.go b/downloader/downloader.go index d80955d1..937e2112 100644 --- a/downloader/downloader.go +++ b/downloader/downloader.go @@ -1,6 +1,96 @@ package downloader +import ( + "errors" + "fmt" + "io" + "net/http" + "os" + "path" + "strings" +) + type Downloader interface { Download(url string, filename string, pathname string) error GetRatio() float64 } + +type DownloadTask struct { + Filename string + Path string + Url string +} + +func NewDownloadTask(filename string, path string, url string) *DownloadTask { + return &DownloadTask{ + Filename: filename, + Path: path, + Url: url, + } +} + +func (t *DownloadTask) Valid() error { + if t.Filename == "" { + return errors.New("filename is empty") + } + if t.Url == "" { + return errors.New("url is empty") + } + if t.Path == "" { + return errors.New("path is empty") + } + if strings.Contains(t.Path, ",,") { + return errors.New("path is invalid") + } + return nil +} + +func isExists(pathname string) bool { + _, err := os.Stat(pathname) + return err == nil || !os.IsNotExist(err) +} + +func isDir(pathname string) bool { + info, err := os.Stat(pathname) + if err != nil { + return false + } + return info.IsDir() +} + +func (t *DownloadTask) Download() error { + if err := t.Valid(); err != nil { + return err + } + if isExists(t.Path) { + if !isDir(t.Path) { + return errors.New("path is not a directory") + } + if isExists(path.Join(t.Path, t.Filename)) { + return errors.New("file already exists") + } + } else { + if err := os.Mkdir(t.Path, 0755); err != nil { + return err + } + } + resp, err := http.Get(t.Url) + if err != nil { + return fmt.Errorf("failed to download %s: %w", t.Url, err) + } + defer resp.Body.Close() + contentLength := resp.ContentLength + if contentLength < 0 { + return errors.New("failed to get content length") + } + file, err := os.Create(path.Join(t.Path, t.Filename)) + if err != nil { + return fmt.Errorf("failed to create file %s: %w", t.Filename, err) + } + defer file.Close() + _, err = io.Copy(file, resp.Body) + if err != nil { + return fmt.Errorf("failed to write file %s: %w", t.Filename, err) + } + return nil +} diff --git a/downloader/wget.go b/downloader/wget.go deleted file mode 100644 index dcbfec98..00000000 --- a/downloader/wget.go +++ /dev/null @@ -1,49 +0,0 @@ -package downloader - -import ( - "fmt" - "os" - "os/exec" - "path/filepath" -) - -type Wget struct { - URL string `json:"url,omitempty"` - Filename string `json:"filename,omitempty"` - Pathname string `json:"pathname,omitempty"` - total int64 - received int64 -} - -func newWget(url string, filename string, pathname string) *Wget { - return &Wget{ - URL: url, - Filename: filename, - Pathname: pathname, - } -} - -func (w *Wget) Download(url string, filename string, pathname string) error { - _, err := os.Stat(pathname) - if err != nil && os.IsNotExist(err) { - err := os.Mkdir(pathname, 0755) - if err != nil { - return err - } - } - downloadFilepath := filepath.Join(pathname, filename) - _, err = os.Stat(downloadFilepath) - if err != nil && os.IsExist(err) { - return err - } - output, err := exec.Command("wget", "-O", downloadFilepath, url).CombinedOutput() - if err != nil { - return err - } - fmt.Printf("%s\n", output) - return nil -} - -func (w *Wget) GetRatio() float64 { - return float64(w.received) / float64(w.total) -} diff --git a/http/downloader.go b/http/downloader.go index a41dabc6..c384b525 100644 --- a/http/downloader.go +++ b/http/downloader.go @@ -2,23 +2,27 @@ package http import ( "encoding/json" - "fmt" "github.com/filebrowser/filebrowser/v2/downloader" "net/http" ) func downloadHandler() handleFunc { return withUser(func(w http.ResponseWriter, r *http.Request, d *data) (int, error) { - fmt.Printf("wget: %v\n", d.user.Perm.Create) - if !d.user.Perm.Create { + if !d.user.Perm.Create || !d.Check(r.URL.Path) { return http.StatusForbidden, nil } - var wget downloader.Wget - if err := json.NewDecoder(r.Body).Decode(&wget); err != nil { + var params struct { + URL string `json:"url"` + Filename string `json:"filename"` + Pathname string `json:"pathname"` + } + if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil { return http.StatusBadRequest, err } + downloadTask := downloader.NewDownloadTask(params.Filename, params.Pathname, params.URL) + + err := downloadTask.Download() - err := wget.Download(wget.URL, wget.Filename, wget.Pathname) if err != nil { return http.StatusInternalServerError, err }