From a3b631f9e9f33fb9e63ace58e799dca9598424df Mon Sep 17 00:00:00 2001 From: BoYanZh Date: Sun, 30 Oct 2022 15:05:07 +0800 Subject: [PATCH] fix(smb): remount smb before each operation (close #2123 pr #2140) --- drivers/smb/driver.go | 44 ++++++++++++++++++++++++++++++++++++++++++- drivers/smb/util.go | 20 ++++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/drivers/smb/driver.go b/drivers/smb/driver.go index 4589d5a4..a11315d2 100644 --- a/drivers/smb/driver.go +++ b/drivers/smb/driver.go @@ -4,6 +4,7 @@ import ( "context" "errors" "path/filepath" + "time" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/model" @@ -15,7 +16,8 @@ import ( type SMB struct { model.Storage Addition - fs *smb2.Share + fs *smb2.Share + lastConnTime time.Time } func (d *SMB) Config() driver.Config { @@ -43,11 +45,16 @@ func (d *SMB) Drop(ctx context.Context) error { } func (d *SMB) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + if err := d.checkConn(); err != nil { + return nil, err + } fullPath := d.getSMBPath(dir) rawFiles, err := d.fs.ReadDir(fullPath) if err != nil { + d.cleanLastConnTime() return nil, err } + d.updateLastConnTime() var files []model.Obj for _, f := range rawFiles { file := model.ObjThumb{ @@ -69,46 +76,69 @@ func (d *SMB) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]m //} func (d *SMB) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + if err := d.checkConn(); err != nil { + return nil, err + } fullPath := d.getSMBPath(file) remoteFile, err := d.fs.Open(fullPath) if err != nil { + d.cleanLastConnTime() return nil, err } + d.updateLastConnTime() return &model.Link{ Data: remoteFile, }, nil } func (d *SMB) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + if err := d.checkConn(); err != nil { + return err + } fullPath := filepath.Join(d.getSMBPath(parentDir), dirName) err := d.fs.MkdirAll(fullPath, 0700) if err != nil { + d.cleanLastConnTime() return err } + d.updateLastConnTime() return nil } func (d *SMB) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + if err := d.checkConn(); err != nil { + return err + } srcPath := d.getSMBPath(srcObj) dstPath := filepath.Join(d.getSMBPath(dstDir), srcObj.GetName()) err := d.fs.Rename(srcPath, dstPath) if err != nil { + d.cleanLastConnTime() return err } + d.updateLastConnTime() return nil } func (d *SMB) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + if err := d.checkConn(); err != nil { + return err + } srcPath := d.getSMBPath(srcObj) dstPath := filepath.Join(filepath.Dir(srcPath), newName) err := d.fs.Rename(srcPath, dstPath) if err != nil { + d.cleanLastConnTime() return err } + d.updateLastConnTime() return nil } func (d *SMB) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + if err := d.checkConn(); err != nil { + return err + } srcPath := d.getSMBPath(srcObj) dstPath := filepath.Join(d.getSMBPath(dstDir), srcObj.GetName()) var err error @@ -118,12 +148,17 @@ func (d *SMB) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { err = d.CopyFile(srcPath, dstPath) } if err != nil { + d.cleanLastConnTime() return err } + d.updateLastConnTime() return nil } func (d *SMB) Remove(ctx context.Context, obj model.Obj) error { + if err := d.checkConn(); err != nil { + return err + } var err error fullPath := d.getSMBPath(obj) if obj.IsDir() { @@ -132,17 +167,24 @@ func (d *SMB) Remove(ctx context.Context, obj model.Obj) error { err = d.fs.Remove(fullPath) } if err != nil { + d.cleanLastConnTime() return err } + d.updateLastConnTime() return nil } func (d *SMB) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + if err := d.checkConn(); err != nil { + return err + } fullPath := filepath.Join(d.getSMBPath(dstDir), stream.GetName()) out, err := d.fs.Create(fullPath) if err != nil { + d.cleanLastConnTime() return err } + d.updateLastConnTime() defer func() { _ = out.Close() if errors.Is(err, context.Canceled) { diff --git a/drivers/smb/util.go b/drivers/smb/util.go index 144ec171..87c04009 100644 --- a/drivers/smb/util.go +++ b/drivers/smb/util.go @@ -6,11 +6,20 @@ import ( "net" "os" "path/filepath" + "time" "github.com/alist-org/alist/v3/internal/model" "github.com/hirochachacha/go-smb2" ) +func (d *SMB) updateLastConnTime() { + d.lastConnTime = time.Now() +} + +func (d *SMB) cleanLastConnTime() { + d.lastConnTime = time.Now().AddDate(0, 0, -1) +} + func (d *SMB) initFS() error { conn, err := net.Dial("tcp", d.Address) if err != nil { @@ -30,9 +39,20 @@ func (d *SMB) initFS() error { if err != nil { return err } + d.updateLastConnTime() return err } +func (d *SMB) checkConn() error { + if time.Since(d.lastConnTime) < 5*time.Minute { + return nil + } + if d.fs != nil { + _ = d.fs.Umount() + } + return d.initFS() +} + func (d *SMB) getSMBPath(dir model.Obj) string { fullPath := dir.GetPath() if fullPath[0:1] != "." {