diff --git a/cmd/root.go b/cmd/root.go index 59329c5c..8dda7256 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -3,6 +3,7 @@ package cmd import ( "crypto/tls" "errors" + "github.com/patrickmn/go-cache" "io" "io/fs" "log" @@ -13,6 +14,7 @@ import ( "path/filepath" "strings" "syscall" + "time" homedir "github.com/mitchellh/go-homedir" "github.com/spf13/afero" @@ -176,7 +178,9 @@ user created with the credentials from options "username" and "password".`, panic(err) } - handler, err := fbhttp.NewHandler(imgSvc, fileCache, d.store, server, assetsFs) + downloaderCache := cache.New(cache.NoExpiration, 1*time.Minute) + + handler, err := fbhttp.NewHandler(imgSvc, fileCache, d.store, server, assetsFs, downloaderCache) checkErr(err) defer listener.Close() diff --git a/downloader/downloader.go b/downloader/downloader.go deleted file mode 100644 index 937e2112..00000000 --- a/downloader/downloader.go +++ /dev/null @@ -1,96 +0,0 @@ -package downloader - -import ( - "errors" - "fmt" - "io" - "net/http" - "os" - "path" - "strings" -) - -type Downloader interface { - Download(url string, filename string, pathname string) error - GetRatio() float64 -} - -type DownloadTask struct { - Filename string - Path string - Url string -} - -func NewDownloadTask(filename string, path string, url string) *DownloadTask { - return &DownloadTask{ - Filename: filename, - Path: path, - Url: url, - } -} - -func (t *DownloadTask) Valid() error { - if t.Filename == "" { - return errors.New("filename is empty") - } - if t.Url == "" { - return errors.New("url is empty") - } - if t.Path == "" { - return errors.New("path is empty") - } - if strings.Contains(t.Path, ",,") { - return errors.New("path is invalid") - } - return nil -} - -func isExists(pathname string) bool { - _, err := os.Stat(pathname) - return err == nil || !os.IsNotExist(err) -} - -func isDir(pathname string) bool { - info, err := os.Stat(pathname) - if err != nil { - return false - } - return info.IsDir() -} - -func (t *DownloadTask) Download() error { - if err := t.Valid(); err != nil { - return err - } - if isExists(t.Path) { - if !isDir(t.Path) { - return errors.New("path is not a directory") - } - if isExists(path.Join(t.Path, t.Filename)) { - return errors.New("file already exists") - } - } else { - if err := os.Mkdir(t.Path, 0755); err != nil { - return err - } - } - resp, err := http.Get(t.Url) - if err != nil { - return fmt.Errorf("failed to download %s: %w", t.Url, err) - } - defer resp.Body.Close() - contentLength := resp.ContentLength - if contentLength < 0 { - return errors.New("failed to get content length") - } - file, err := os.Create(path.Join(t.Path, t.Filename)) - if err != nil { - return fmt.Errorf("failed to create file %s: %w", t.Filename, err) - } - defer file.Close() - _, err = io.Copy(file, resp.Body) - if err != nil { - return fmt.Errorf("failed to write file %s: %w", t.Filename, err) - } - return nil -} diff --git a/go.mod b/go.mod index a58d5c24..06c4ff53 100644 --- a/go.mod +++ b/go.mod @@ -44,6 +44,7 @@ require ( github.com/go-ole/go-ole v1.2.6 // indirect github.com/golang/geo v0.0.0-20230421003525-6adc56603217 // indirect github.com/golang/snappy v0.0.4 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/klauspost/compress v1.17.7 // indirect @@ -51,6 +52,7 @@ require ( github.com/magiconair/properties v1.8.7 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/nwaples/rardecode v1.1.3 // indirect + github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect diff --git a/go.sum b/go.sum index 4e451c05..97be363b 100644 --- a/go.sum +++ b/go.sum @@ -75,6 +75,8 @@ github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= @@ -116,6 +118,8 @@ github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RR github.com/nwaples/rardecode v1.1.0/go.mod h1:5DzqNKiOdpKKBH87u8VlvAnPZMXcGRhxWkRpHbbfGS0= github.com/nwaples/rardecode v1.1.3 h1:cWCaZwfM5H7nAD6PyEdcVnczzV8i/JtotnyW/dD9lEc= github.com/nwaples/rardecode v1.1.3/go.mod h1:5DzqNKiOdpKKBH87u8VlvAnPZMXcGRhxWkRpHbbfGS0= +github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= +github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= github.com/pierrec/lz4/v4 v4.1.2/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= diff --git a/http/downloader.go b/http/downloader.go index c384b525..7fe4513e 100644 --- a/http/downloader.go +++ b/http/downloader.go @@ -2,11 +2,60 @@ package http import ( "encoding/json" - "github.com/filebrowser/filebrowser/v2/downloader" + "fmt" + "github.com/filebrowser/filebrowser/v2/files" + "github.com/google/uuid" + "github.com/patrickmn/go-cache" + "github.com/spf13/afero" + "io" "net/http" + "os" + "path" ) -func downloadHandler() handleFunc { +type DownloadTask struct { + TaskID uuid.UUID `json:"taskID"` + URL string `json:"url"` + Filename string `json:"filename"` + Pathname string `json:"pathname"` + totalSize int64 + savedSize int64 + cache *cache.Cache +} + +func (d *DownloadTask) Progress() float64 { + if d.totalSize == 0 { + return 0 + } + return float64(d.savedSize) / float64(d.totalSize) +} + +func NewDownloadTask(url, filename, pathname string, downloaderCache *cache.Cache) *DownloadTask { + taskId := uuid.New() + downloadTask := &DownloadTask{ + TaskID: taskId, + URL: url, + Filename: filename, + Pathname: pathname, + cache: downloaderCache, + } + downloaderCache.Set(taskId.String(), downloadTask, cache.NoExpiration) + return downloadTask +} + +type WriteCounter struct { + task *DownloadTask +} + +func (wc *WriteCounter) Write(p []byte) (int, error) { + n := len(p) + wc.task.savedSize += int64(n) + wc.task.cache.Set(wc.task.TaskID.String(), wc.task, cache.NoExpiration) + fmt.Printf("Downloaded %d of %d bytes, percent: %.2f%%\n", wc.task.savedSize, wc.task.totalSize, wc.task.Progress()) + return n, nil +} + +func downloadHandler(downloaderCache *cache.Cache) handleFunc { return withUser(func(w http.ResponseWriter, r *http.Request, d *data) (int, error) { if !d.user.Perm.Create || !d.Check(r.URL.Path) { return http.StatusForbidden, nil @@ -19,13 +68,54 @@ func downloadHandler() handleFunc { if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil { return http.StatusBadRequest, err } - downloadTask := downloader.NewDownloadTask(params.Filename, params.Pathname, params.URL) - err := downloadTask.Download() - - if err != nil { - return http.StatusInternalServerError, err + _, err := os.Stat(path.Join(params.Pathname, params.Filename)) + if err != nil && !os.IsNotExist(err) { + return errToStatus(err), err } - return http.StatusNoContent, nil + downloadTask := NewDownloadTask(params.URL, params.Filename, params.Pathname, downloaderCache) + asyncDownloadWithTask(d.user.Fs, downloadTask) + + _, err = w.Write([]byte(downloadTask.TaskID.String())) + if err != nil { + return errToStatus(err), err + } + return 0, nil }) } + +func downloadWithTask(fs afero.Fs, task *DownloadTask) error { + err := fs.MkdirAll(task.Pathname, files.PermDir) + if err != nil { + return err + } + + file, err := fs.OpenFile(path.Join(task.Pathname, task.Filename), os.O_RDWR|os.O_CREATE|os.O_TRUNC, files.PermFile) + if err != nil { + return err + } + defer file.Close() + resp, err := http.Get(task.URL) + if err != nil { + return err + } + defer resp.Body.Close() + task.totalSize = resp.ContentLength + + _, err = io.Copy(file, io.TeeReader(resp.Body, &WriteCounter{task: task})) + if err != nil { + return err + } + + return nil +} + +func asyncDownloadWithTask(fs afero.Fs, task *DownloadTask) { + go func() { + err := downloadWithTask(fs, task) + if err != nil { + fmt.Printf("Error downloading file: %v\n", err) + return + } + }() +} diff --git a/http/http.go b/http/http.go index 05a65b28..28677173 100644 --- a/http/http.go +++ b/http/http.go @@ -1,6 +1,7 @@ package http import ( + "github.com/patrickmn/go-cache" "io/fs" "net/http" @@ -21,6 +22,7 @@ func NewHandler( store *storage.Storage, server *settings.Server, assetsFs fs.FS, + downloaderCache *cache.Cache, ) (http.Handler, error) { server.Clean() @@ -92,7 +94,7 @@ func NewHandler( public.PathPrefix("/dl").Handler(monkey(publicDlHandler, "/api/public/dl/")).Methods("GET") public.PathPrefix("/share").Handler(monkey(publicShareHandler, "/api/public/share/")).Methods("GET") - api.PathPrefix("/download").Handler(monkey(downloadHandler(), "/api/download/")).Methods("POST") + api.PathPrefix("/download").Handler(monkey(downloadHandler(downloaderCache), "/api/download/")).Methods("POST") return stripPrefix(server.BaseURL, r), nil }