mirror of https://github.com/Xhofe/alist
				
				
				
			
							parent
							
								
									18165eb50d
								
							
						
					
					
						commit
						a3b631f9e9
					
				|  | @ -4,6 +4,7 @@ import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"path/filepath" | 	"path/filepath" | ||||||
|  | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/alist-org/alist/v3/internal/driver" | 	"github.com/alist-org/alist/v3/internal/driver" | ||||||
| 	"github.com/alist-org/alist/v3/internal/model" | 	"github.com/alist-org/alist/v3/internal/model" | ||||||
|  | @ -15,7 +16,8 @@ import ( | ||||||
| type SMB struct { | type SMB struct { | ||||||
| 	model.Storage | 	model.Storage | ||||||
| 	Addition | 	Addition | ||||||
| 	fs *smb2.Share | 	fs           *smb2.Share | ||||||
|  | 	lastConnTime time.Time | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (d *SMB) Config() driver.Config { | func (d *SMB) Config() driver.Config { | ||||||
|  | @ -43,11 +45,16 @@ func (d *SMB) Drop(ctx context.Context) error { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (d *SMB) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { | func (d *SMB) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { | ||||||
|  | 	if err := d.checkConn(); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
| 	fullPath := d.getSMBPath(dir) | 	fullPath := d.getSMBPath(dir) | ||||||
| 	rawFiles, err := d.fs.ReadDir(fullPath) | 	rawFiles, err := d.fs.ReadDir(fullPath) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | 		d.cleanLastConnTime() | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  | 	d.updateLastConnTime() | ||||||
| 	var files []model.Obj | 	var files []model.Obj | ||||||
| 	for _, f := range rawFiles { | 	for _, f := range rawFiles { | ||||||
| 		file := model.ObjThumb{ | 		file := model.ObjThumb{ | ||||||
|  | @ -69,46 +76,69 @@ func (d *SMB) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]m | ||||||
| //}
 | //}
 | ||||||
| 
 | 
 | ||||||
| func (d *SMB) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { | func (d *SMB) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { | ||||||
|  | 	if err := d.checkConn(); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
| 	fullPath := d.getSMBPath(file) | 	fullPath := d.getSMBPath(file) | ||||||
| 	remoteFile, err := d.fs.Open(fullPath) | 	remoteFile, err := d.fs.Open(fullPath) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | 		d.cleanLastConnTime() | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  | 	d.updateLastConnTime() | ||||||
| 	return &model.Link{ | 	return &model.Link{ | ||||||
| 		Data: remoteFile, | 		Data: remoteFile, | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (d *SMB) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { | func (d *SMB) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { | ||||||
|  | 	if err := d.checkConn(); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
| 	fullPath := filepath.Join(d.getSMBPath(parentDir), dirName) | 	fullPath := filepath.Join(d.getSMBPath(parentDir), dirName) | ||||||
| 	err := d.fs.MkdirAll(fullPath, 0700) | 	err := d.fs.MkdirAll(fullPath, 0700) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | 		d.cleanLastConnTime() | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  | 	d.updateLastConnTime() | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (d *SMB) Move(ctx context.Context, srcObj, dstDir model.Obj) error { | func (d *SMB) Move(ctx context.Context, srcObj, dstDir model.Obj) error { | ||||||
|  | 	if err := d.checkConn(); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
| 	srcPath := d.getSMBPath(srcObj) | 	srcPath := d.getSMBPath(srcObj) | ||||||
| 	dstPath := filepath.Join(d.getSMBPath(dstDir), srcObj.GetName()) | 	dstPath := filepath.Join(d.getSMBPath(dstDir), srcObj.GetName()) | ||||||
| 	err := d.fs.Rename(srcPath, dstPath) | 	err := d.fs.Rename(srcPath, dstPath) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | 		d.cleanLastConnTime() | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  | 	d.updateLastConnTime() | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (d *SMB) Rename(ctx context.Context, srcObj model.Obj, newName string) error { | func (d *SMB) Rename(ctx context.Context, srcObj model.Obj, newName string) error { | ||||||
|  | 	if err := d.checkConn(); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
| 	srcPath := d.getSMBPath(srcObj) | 	srcPath := d.getSMBPath(srcObj) | ||||||
| 	dstPath := filepath.Join(filepath.Dir(srcPath), newName) | 	dstPath := filepath.Join(filepath.Dir(srcPath), newName) | ||||||
| 	err := d.fs.Rename(srcPath, dstPath) | 	err := d.fs.Rename(srcPath, dstPath) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | 		d.cleanLastConnTime() | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  | 	d.updateLastConnTime() | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (d *SMB) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { | func (d *SMB) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { | ||||||
|  | 	if err := d.checkConn(); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
| 	srcPath := d.getSMBPath(srcObj) | 	srcPath := d.getSMBPath(srcObj) | ||||||
| 	dstPath := filepath.Join(d.getSMBPath(dstDir), srcObj.GetName()) | 	dstPath := filepath.Join(d.getSMBPath(dstDir), srcObj.GetName()) | ||||||
| 	var err error | 	var err error | ||||||
|  | @ -118,12 +148,17 @@ func (d *SMB) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { | ||||||
| 		err = d.CopyFile(srcPath, dstPath) | 		err = d.CopyFile(srcPath, dstPath) | ||||||
| 	} | 	} | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | 		d.cleanLastConnTime() | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  | 	d.updateLastConnTime() | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (d *SMB) Remove(ctx context.Context, obj model.Obj) error { | func (d *SMB) Remove(ctx context.Context, obj model.Obj) error { | ||||||
|  | 	if err := d.checkConn(); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
| 	var err error | 	var err error | ||||||
| 	fullPath := d.getSMBPath(obj) | 	fullPath := d.getSMBPath(obj) | ||||||
| 	if obj.IsDir() { | 	if obj.IsDir() { | ||||||
|  | @ -132,17 +167,24 @@ func (d *SMB) Remove(ctx context.Context, obj model.Obj) error { | ||||||
| 		err = d.fs.Remove(fullPath) | 		err = d.fs.Remove(fullPath) | ||||||
| 	} | 	} | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | 		d.cleanLastConnTime() | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  | 	d.updateLastConnTime() | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (d *SMB) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { | func (d *SMB) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { | ||||||
|  | 	if err := d.checkConn(); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
| 	fullPath := filepath.Join(d.getSMBPath(dstDir), stream.GetName()) | 	fullPath := filepath.Join(d.getSMBPath(dstDir), stream.GetName()) | ||||||
| 	out, err := d.fs.Create(fullPath) | 	out, err := d.fs.Create(fullPath) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | 		d.cleanLastConnTime() | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  | 	d.updateLastConnTime() | ||||||
| 	defer func() { | 	defer func() { | ||||||
| 		_ = out.Close() | 		_ = out.Close() | ||||||
| 		if errors.Is(err, context.Canceled) { | 		if errors.Is(err, context.Canceled) { | ||||||
|  |  | ||||||
|  | @ -6,11 +6,20 @@ import ( | ||||||
| 	"net" | 	"net" | ||||||
| 	"os" | 	"os" | ||||||
| 	"path/filepath" | 	"path/filepath" | ||||||
|  | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/alist-org/alist/v3/internal/model" | 	"github.com/alist-org/alist/v3/internal/model" | ||||||
| 	"github.com/hirochachacha/go-smb2" | 	"github.com/hirochachacha/go-smb2" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | func (d *SMB) updateLastConnTime() { | ||||||
|  | 	d.lastConnTime = time.Now() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (d *SMB) cleanLastConnTime() { | ||||||
|  | 	d.lastConnTime = time.Now().AddDate(0, 0, -1) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (d *SMB) initFS() error { | func (d *SMB) initFS() error { | ||||||
| 	conn, err := net.Dial("tcp", d.Address) | 	conn, err := net.Dial("tcp", d.Address) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -30,9 +39,20 @@ func (d *SMB) initFS() error { | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  | 	d.updateLastConnTime() | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (d *SMB) checkConn() error { | ||||||
|  | 	if time.Since(d.lastConnTime) < 5*time.Minute { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	if d.fs != nil { | ||||||
|  | 		_ = d.fs.Umount() | ||||||
|  | 	} | ||||||
|  | 	return d.initFS() | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (d *SMB) getSMBPath(dir model.Obj) string { | func (d *SMB) getSMBPath(dir model.Obj) string { | ||||||
| 	fullPath := dir.GetPath() | 	fullPath := dir.GetPath() | ||||||
| 	if fullPath[0:1] != "." { | 	if fullPath[0:1] != "." { | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	 BoYanZh
						BoYanZh