mirror of https://github.com/cloudreve/Cloudreve
Feat: aria2 download and transfer in slave node (#1040)
* Feat: retrieve nodes from data table * Feat: master node ping slave node in REST API * Feat: master send scheduled ping request * Feat: inactive nodes recover loop * Modify: remove database operations from aria2 RPC caller implementation * Feat: init aria2 client in master node * Feat: Round Robin load balancer * Feat: create and monitor aria2 task in master node * Feat: salve receive and handle heartbeat * Fix: Node ID will be 0 in download record generated in older version * Feat: sign request headers with all `X-` prefix * Feat: API call to slave node will carry meta data in headers * Feat: call slave aria2 rpc method from master * Feat: get slave aria2 task status Feat: encode slave response data using gob * Feat: aria2 callback to master node / cancel or select task to slave node * Fix: use dummy aria2 client when caller initialize failed in master node * Feat: slave aria2 status event callback / salve RPC auth * Feat: prototype for slave driven filesystem * Feat: retry for init aria2 client in master node * Feat: init request client with global options * Feat: slave receive async task from master * Fix: competition write in request header * Refactor: dependency initialize order * Feat: generic message queue implementation * Feat: message queue implementation * Feat: master waiting slave transfer result * Feat: slave transfer file in stateless policy * Feat: slave transfer file in slave policy * Feat: slave transfer file in local policy * Feat: slave transfer file in OneDrive policy * Fix: failed to initialize update checker http client * Feat: list slave nodes for dashboard * Feat: test aria2 rpc connection in slave * Feat: add and save node * Feat: add and delete node in node pool * Fix: temp file cannot be removed when aria2 task fails * Fix: delete node in admin panel * Feat: edit node and get node info * Modify: delete unused settingspull/1044/head
parent
a3b4a22dbc
commit
056de22edb
2
assets
2
assets
|
|
@ -1 +1 @@
|
||||||
Subproject commit 59890e6b22d69befa8b742a64967b6bab1bb4a3d
|
Subproject commit 8a61a8e4c238ed60a107ace23717cf8f03f957f6
|
||||||
|
|
@ -34,7 +34,7 @@ type GitHubRelease struct {
|
||||||
|
|
||||||
// CheckUpdate 检查更新
|
// CheckUpdate 检查更新
|
||||||
func CheckUpdate() {
|
func CheckUpdate() {
|
||||||
client := request.HTTPClient{}
|
client := request.NewClient()
|
||||||
res, err := client.Request("GET", "https://api.github.com/repos/cloudreve/cloudreve/releases", nil).GetResponse()
|
res, err := client.Request("GET", "https://api.github.com/repos/cloudreve/cloudreve/releases", nil).GetResponse()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.Log().Warning("更新检查失败, %s", err)
|
util.Log().Warning("更新检查失败, %s", err)
|
||||||
|
|
|
||||||
|
|
@ -5,9 +5,11 @@ import (
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/crontab"
|
"github.com/cloudreve/Cloudreve/v3/pkg/crontab"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/email"
|
"github.com/cloudreve/Cloudreve/v3/pkg/email"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/task"
|
"github.com/cloudreve/Cloudreve/v3/pkg/task"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
@ -20,14 +22,86 @@ func Init(path string) {
|
||||||
if !conf.SystemConfig.Debug {
|
if !conf.SystemConfig.Debug {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dependencies := []struct {
|
||||||
|
mode string
|
||||||
|
factory func()
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"both",
|
||||||
|
func() {
|
||||||
cache.Init()
|
cache.Init()
|
||||||
if conf.SystemConfig.Mode == "master" {
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"master",
|
||||||
|
func() {
|
||||||
model.Init()
|
model.Init()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"both",
|
||||||
|
func() {
|
||||||
task.Init()
|
task.Init()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"master",
|
||||||
|
func() {
|
||||||
|
cluster.Init()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"master",
|
||||||
|
func() {
|
||||||
aria2.Init(false)
|
aria2.Init(false)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"master",
|
||||||
|
func() {
|
||||||
email.Init()
|
email.Init()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"master",
|
||||||
|
func() {
|
||||||
crontab.Init()
|
crontab.Init()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"master",
|
||||||
|
func() {
|
||||||
InitStatic()
|
InitStatic()
|
||||||
}
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"slave",
|
||||||
|
func() {
|
||||||
|
slave.Init()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"both",
|
||||||
|
func() {
|
||||||
auth.Init()
|
auth.Init()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dependency := range dependencies {
|
||||||
|
switch dependency.mode {
|
||||||
|
case "master":
|
||||||
|
if conf.SystemConfig.Mode == "master" {
|
||||||
|
dependency.factory()
|
||||||
|
}
|
||||||
|
case "slave":
|
||||||
|
if conf.SystemConfig.Mode == "slave" {
|
||||||
|
dependency.factory()
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
dependency.factory()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
1
go.mod
1
go.mod
|
|
@ -16,6 +16,7 @@ require (
|
||||||
github.com/gin-gonic/gin v1.5.0
|
github.com/gin-gonic/gin v1.5.0
|
||||||
github.com/go-ini/ini v1.50.0
|
github.com/go-ini/ini v1.50.0
|
||||||
github.com/go-mail/mail v2.3.1+incompatible
|
github.com/go-mail/mail v2.3.1+incompatible
|
||||||
|
github.com/gofrs/uuid v4.0.0+incompatible
|
||||||
github.com/gomodule/redigo v2.0.0+incompatible
|
github.com/gomodule/redigo v2.0.0+incompatible
|
||||||
github.com/google/go-querystring v1.0.0
|
github.com/google/go-querystring v1.0.0
|
||||||
github.com/gorilla/websocket v1.4.1
|
github.com/gorilla/websocket v1.4.1
|
||||||
|
|
|
||||||
2
go.sum
2
go.sum
|
|
@ -77,6 +77,8 @@ github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG
|
||||||
github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
|
github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
|
||||||
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
||||||
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
||||||
|
github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw=
|
||||||
|
github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
|
||||||
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
||||||
github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
||||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
|
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
|
||||||
|
|
|
||||||
|
|
@ -22,16 +22,14 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// SignRequired 验证请求签名
|
// SignRequired 验证请求签名
|
||||||
func SignRequired() gin.HandlerFunc {
|
func SignRequired(authInstance auth.Auth) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
var err error
|
var err error
|
||||||
switch c.Request.Method {
|
switch c.Request.Method {
|
||||||
case "PUT", "POST":
|
case "PUT", "POST", "PATCH":
|
||||||
err = auth.CheckRequest(auth.General, c.Request)
|
err = auth.CheckRequest(authInstance, c.Request)
|
||||||
// TODO 生产环境去掉下一行
|
|
||||||
//err = nil
|
|
||||||
default:
|
default:
|
||||||
err = auth.CheckURI(auth.General, c.Request.URL)
|
err = auth.CheckURI(authInstance, c.Request.URL)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -87,11 +87,10 @@ func TestAuthRequired(t *testing.T) {
|
||||||
|
|
||||||
func TestSignRequired(t *testing.T) {
|
func TestSignRequired(t *testing.T) {
|
||||||
asserts := assert.New(t)
|
asserts := assert.New(t)
|
||||||
auth.General = auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
|
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(rec)
|
c, _ := gin.CreateTestContext(rec)
|
||||||
c.Request, _ = http.NewRequest("GET", "/test", nil)
|
c.Request, _ = http.NewRequest("GET", "/test", nil)
|
||||||
SignRequiredFunc := SignRequired()
|
SignRequiredFunc := SignRequired(auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))})
|
||||||
|
|
||||||
// 鉴权失败
|
// 鉴权失败
|
||||||
SignRequiredFunc(c)
|
SignRequiredFunc(c)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,62 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MasterMetadata 解析主机节点发来请求的包含主机节点信息的元数据
|
||||||
|
func MasterMetadata() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
c.Set("MasterSiteID", c.GetHeader("X-Site-Id"))
|
||||||
|
c.Set("MasterSiteURL", c.GetHeader("X-Site-Url"))
|
||||||
|
c.Set("MasterVersion", c.GetHeader("X-Cloudreve-Version"))
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UseSlaveAria2Instance 从机用于获取对应主机节点的Aria2实例
|
||||||
|
func UseSlaveAria2Instance() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
if siteID, exist := c.Get("MasterSiteID"); exist {
|
||||||
|
// 获取对应主机节点的从机Aria2实例
|
||||||
|
caller, err := slave.DefaultController.GetAria2Instance(siteID.(string))
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(200, serializer.Err(serializer.CodeNotSet, "无法获取 Aria2 实例", err))
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set("MasterAria2Instance", caller)
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(200, serializer.ParamErr("未知的主机节点ID", nil))
|
||||||
|
c.Abort()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func SlaveRPCSignRequired() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
nodeID, err := strconv.ParseUint(c.GetHeader("X-Node-Id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(200, serializer.ParamErr("未知的主机节点ID", err))
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
slaveNode := cluster.Default.GetNodeByID(uint(nodeID))
|
||||||
|
if slaveNode == nil {
|
||||||
|
c.JSON(200, serializer.ParamErr("未知的主机节点ID", err))
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
SignRequired(slaveNode.MasterAuthInstance())(c)
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -24,6 +24,7 @@ type Download struct {
|
||||||
Dst string `gorm:"type:text"` // 用户文件系统存储父目录路径
|
Dst string `gorm:"type:text"` // 用户文件系统存储父目录路径
|
||||||
UserID uint // 发起者UID
|
UserID uint // 发起者UID
|
||||||
TaskID uint // 对应的转存任务ID
|
TaskID uint // 对应的转存任务ID
|
||||||
|
NodeID uint // 处理任务的节点ID
|
||||||
|
|
||||||
// 关联模型
|
// 关联模型
|
||||||
User *User `gorm:"PRELOAD:false,association_autoupdate:false"`
|
User *User `gorm:"PRELOAD:false,association_autoupdate:false"`
|
||||||
|
|
@ -114,3 +115,13 @@ func (task *Download) GetOwner() *User {
|
||||||
func (download *Download) Delete() error {
|
func (download *Download) Delete() error {
|
||||||
return DB.Model(download).Delete(download).Error
|
return DB.Model(download).Delete(download).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetNodeID 返回任务所属节点ID
|
||||||
|
func (task *Download) GetNodeID() uint {
|
||||||
|
// 兼容3.4版本之前生成的下载记录
|
||||||
|
if task.NodeID == 0 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
return task.NodeID
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||||
"github.com/fatih/color"
|
"github.com/fatih/color"
|
||||||
|
"github.com/gofrs/uuid"
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -34,8 +35,9 @@ func migration() {
|
||||||
if conf.DatabaseConfig.Type == "mysql" {
|
if conf.DatabaseConfig.Type == "mysql" {
|
||||||
DB = DB.Set("gorm:table_options", "ENGINE=InnoDB")
|
DB = DB.Set("gorm:table_options", "ENGINE=InnoDB")
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.AutoMigrate(&User{}, &Setting{}, &Group{}, &Policy{}, &Folder{}, &File{}, &Share{},
|
DB.AutoMigrate(&User{}, &Setting{}, &Group{}, &Policy{}, &Folder{}, &File{}, &Share{},
|
||||||
&Task{}, &Download{}, &Tag{}, &Webdav{})
|
&Task{}, &Download{}, &Tag{}, &Webdav{}, &Node{})
|
||||||
|
|
||||||
// 创建初始存储策略
|
// 创建初始存储策略
|
||||||
addDefaultPolicy()
|
addDefaultPolicy()
|
||||||
|
|
@ -73,6 +75,8 @@ func addDefaultPolicy() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func addDefaultSettings() {
|
func addDefaultSettings() {
|
||||||
|
siteID, _ := uuid.NewV4()
|
||||||
|
|
||||||
defaultSettings := []Setting{
|
defaultSettings := []Setting{
|
||||||
{Name: "siteURL", Value: `http://localhost`, Type: "basic"},
|
{Name: "siteURL", Value: `http://localhost`, Type: "basic"},
|
||||||
{Name: "siteName", Value: `Cloudreve`, Type: "basic"},
|
{Name: "siteName", Value: `Cloudreve`, Type: "basic"},
|
||||||
|
|
@ -83,6 +87,7 @@ func addDefaultSettings() {
|
||||||
{Name: "siteDes", Value: `Cloudreve`, Type: "basic"},
|
{Name: "siteDes", Value: `Cloudreve`, Type: "basic"},
|
||||||
{Name: "siteTitle", Value: `平步云端`, Type: "basic"},
|
{Name: "siteTitle", Value: `平步云端`, Type: "basic"},
|
||||||
{Name: "siteScript", Value: ``, Type: "basic"},
|
{Name: "siteScript", Value: ``, Type: "basic"},
|
||||||
|
{Name: "siteID", Value: siteID.String(), Type: "basic"},
|
||||||
{Name: "fromName", Value: `Cloudreve`, Type: "mail"},
|
{Name: "fromName", Value: `Cloudreve`, Type: "mail"},
|
||||||
{Name: "mail_keepalive", Value: `30`, Type: "mail"},
|
{Name: "mail_keepalive", Value: `30`, Type: "mail"},
|
||||||
{Name: "fromAdress", Value: `no-reply@acg.blue`, Type: "mail"},
|
{Name: "fromAdress", Value: `no-reply@acg.blue`, Type: "mail"},
|
||||||
|
|
@ -100,10 +105,13 @@ func addDefaultSettings() {
|
||||||
{Name: "upload_credential_timeout", Value: `1800`, Type: "timeout"},
|
{Name: "upload_credential_timeout", Value: `1800`, Type: "timeout"},
|
||||||
{Name: "upload_session_timeout", Value: `86400`, Type: "timeout"},
|
{Name: "upload_session_timeout", Value: `86400`, Type: "timeout"},
|
||||||
{Name: "slave_api_timeout", Value: `60`, Type: "timeout"},
|
{Name: "slave_api_timeout", Value: `60`, Type: "timeout"},
|
||||||
|
{Name: "slave_node_retry", Value: `3`, Type: "slave"},
|
||||||
|
{Name: "slave_ping_interval", Value: `60`, Type: "slave"},
|
||||||
|
{Name: "slave_recover_interval", Value: `120`, Type: "slave"},
|
||||||
|
{Name: "slave_transfer_timeout", Value: `172800`, Type: "timeout"},
|
||||||
{Name: "onedrive_monitor_timeout", Value: `600`, Type: "timeout"},
|
{Name: "onedrive_monitor_timeout", Value: `600`, Type: "timeout"},
|
||||||
{Name: "share_download_session_timeout", Value: `2073600`, Type: "timeout"},
|
{Name: "share_download_session_timeout", Value: `2073600`, Type: "timeout"},
|
||||||
{Name: "onedrive_callback_check", Value: `20`, Type: "timeout"},
|
{Name: "onedrive_callback_check", Value: `20`, Type: "timeout"},
|
||||||
{Name: "aria2_call_timeout", Value: `5`, Type: "timeout"},
|
|
||||||
{Name: "folder_props_timeout", Value: `300`, Type: "timeout"},
|
{Name: "folder_props_timeout", Value: `300`, Type: "timeout"},
|
||||||
{Name: "onedrive_chunk_retries", Value: `1`, Type: "retry"},
|
{Name: "onedrive_chunk_retries", Value: `1`, Type: "retry"},
|
||||||
{Name: "onedrive_source_timeout", Value: `1800`, Type: "timeout"},
|
{Name: "onedrive_source_timeout", Value: `1800`, Type: "timeout"},
|
||||||
|
|
@ -131,11 +139,6 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
|
||||||
{Name: "gravatar_server", Value: `https://www.gravatar.com/`, Type: "avatar"},
|
{Name: "gravatar_server", Value: `https://www.gravatar.com/`, Type: "avatar"},
|
||||||
{Name: "defaultTheme", Value: `#3f51b5`, Type: "basic"},
|
{Name: "defaultTheme", Value: `#3f51b5`, Type: "basic"},
|
||||||
{Name: "themes", Value: `{"#3f51b5":{"palette":{"primary":{"main":"#3f51b5"},"secondary":{"main":"#f50057"}}},"#2196f3":{"palette":{"primary":{"main":"#2196f3"},"secondary":{"main":"#FFC107"}}},"#673AB7":{"palette":{"primary":{"main":"#673AB7"},"secondary":{"main":"#2196F3"}}},"#E91E63":{"palette":{"primary":{"main":"#E91E63"},"secondary":{"main":"#42A5F5","contrastText":"#fff"}}},"#FF5722":{"palette":{"primary":{"main":"#FF5722"},"secondary":{"main":"#3F51B5"}}},"#FFC107":{"palette":{"primary":{"main":"#FFC107"},"secondary":{"main":"#26C6DA"}}},"#8BC34A":{"palette":{"primary":{"main":"#8BC34A","contrastText":"#fff"},"secondary":{"main":"#FF8A65","contrastText":"#fff"}}},"#009688":{"palette":{"primary":{"main":"#009688"},"secondary":{"main":"#4DD0E1","contrastText":"#fff"}}},"#607D8B":{"palette":{"primary":{"main":"#607D8B"},"secondary":{"main":"#F06292"}}},"#795548":{"palette":{"primary":{"main":"#795548"},"secondary":{"main":"#4CAF50","contrastText":"#fff"}}}}`, Type: "basic"},
|
{Name: "themes", Value: `{"#3f51b5":{"palette":{"primary":{"main":"#3f51b5"},"secondary":{"main":"#f50057"}}},"#2196f3":{"palette":{"primary":{"main":"#2196f3"},"secondary":{"main":"#FFC107"}}},"#673AB7":{"palette":{"primary":{"main":"#673AB7"},"secondary":{"main":"#2196F3"}}},"#E91E63":{"palette":{"primary":{"main":"#E91E63"},"secondary":{"main":"#42A5F5","contrastText":"#fff"}}},"#FF5722":{"palette":{"primary":{"main":"#FF5722"},"secondary":{"main":"#3F51B5"}}},"#FFC107":{"palette":{"primary":{"main":"#FFC107"},"secondary":{"main":"#26C6DA"}}},"#8BC34A":{"palette":{"primary":{"main":"#8BC34A","contrastText":"#fff"},"secondary":{"main":"#FF8A65","contrastText":"#fff"}}},"#009688":{"palette":{"primary":{"main":"#009688"},"secondary":{"main":"#4DD0E1","contrastText":"#fff"}}},"#607D8B":{"palette":{"primary":{"main":"#607D8B"},"secondary":{"main":"#F06292"}}},"#795548":{"palette":{"primary":{"main":"#795548"},"secondary":{"main":"#4CAF50","contrastText":"#fff"}}}}`, Type: "basic"},
|
||||||
{Name: "aria2_token", Value: ``, Type: "aria2"},
|
|
||||||
{Name: "aria2_rpcurl", Value: ``, Type: "aria2"},
|
|
||||||
{Name: "aria2_temp_path", Value: ``, Type: "aria2"},
|
|
||||||
{Name: "aria2_options", Value: `{}`, Type: "aria2"},
|
|
||||||
{Name: "aria2_interval", Value: `60`, Type: "aria2"},
|
|
||||||
{Name: "max_worker_num", Value: `10`, Type: "task"},
|
{Name: "max_worker_num", Value: `10`, Type: "task"},
|
||||||
{Name: "max_parallel_transfer", Value: `4`, Type: "task"},
|
{Name: "max_parallel_transfer", Value: `4`, Type: "task"},
|
||||||
{Name: "secret_key", Value: util.RandStringRunes(256), Type: "auth"},
|
{Name: "secret_key", Value: util.RandStringRunes(256), Type: "auth"},
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,91 @@
|
||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Node 从机节点信息模型
|
||||||
|
type Node struct {
|
||||||
|
gorm.Model
|
||||||
|
Status NodeStatus // 节点状态
|
||||||
|
Name string // 节点别名
|
||||||
|
Type ModelType // 节点状态
|
||||||
|
Server string // 服务器地址
|
||||||
|
SlaveKey string `gorm:"type:text"` // 主->从 通信密钥
|
||||||
|
MasterKey string `gorm:"type:text"` // 从->主 通信密钥
|
||||||
|
Aria2Enabled bool // 是否支持用作离线下载节点
|
||||||
|
Aria2Options string `gorm:"type:text"` // 离线下载配置
|
||||||
|
Rank int // 负载均衡权重
|
||||||
|
|
||||||
|
// 数据库忽略字段
|
||||||
|
Aria2OptionsSerialized Aria2Option `gorm:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Aria2Option 非公有的Aria2配置属性
|
||||||
|
type Aria2Option struct {
|
||||||
|
// RPC 服务器地址
|
||||||
|
Server string `json:"server,omitempty"`
|
||||||
|
// RPC 密钥
|
||||||
|
Token string `json:"token,omitempty"`
|
||||||
|
// 临时下载目录
|
||||||
|
TempPath string `json:"temp_path,omitempty"`
|
||||||
|
// 附加下载配置
|
||||||
|
Options string `json:"options,omitempty"`
|
||||||
|
// 下载监控间隔
|
||||||
|
Interval int `json:"interval,omitempty"`
|
||||||
|
// RPC API 请求超时
|
||||||
|
Timeout int `json:"timeout,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type NodeStatus int
|
||||||
|
type ModelType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
NodeActive NodeStatus = iota
|
||||||
|
NodeSuspend
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
SlaveNodeType ModelType = iota
|
||||||
|
MasterNodeType
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetNodeByID 用ID获取节点
|
||||||
|
func GetNodeByID(ID interface{}) (Node, error) {
|
||||||
|
var node Node
|
||||||
|
result := DB.First(&node, ID)
|
||||||
|
return node, result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNodesByStatus 根据给定状态获取节点
|
||||||
|
func GetNodesByStatus(status ...NodeStatus) ([]Node, error) {
|
||||||
|
var nodes []Node
|
||||||
|
result := DB.Where("status in (?)", status).Find(&nodes)
|
||||||
|
return nodes, result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// AfterFind 找到节点后的钩子
|
||||||
|
func (node *Node) AfterFind() (err error) {
|
||||||
|
// 解析离线下载设置到 Aria2OptionsSerialized
|
||||||
|
if node.Aria2Options != "" {
|
||||||
|
err = json.Unmarshal([]byte(node.Aria2Options), &node.Aria2OptionsSerialized)
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeSave Save策略前的钩子
|
||||||
|
func (node *Node) BeforeSave() (err error) {
|
||||||
|
optionsValue, err := json.Marshal(&node.Aria2OptionsSerialized)
|
||||||
|
node.Aria2Options = string(optionsValue)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetStatus 设置节点启用状态
|
||||||
|
func (node *Node) SetStatus(status NodeStatus) error {
|
||||||
|
node.Status = status
|
||||||
|
return DB.Model(node).Updates(map[string]interface{}{
|
||||||
|
"status": status,
|
||||||
|
}).Error
|
||||||
|
}
|
||||||
|
|
@ -37,6 +37,7 @@ type Policy struct {
|
||||||
|
|
||||||
// 数据库忽略字段
|
// 数据库忽略字段
|
||||||
OptionsSerialized PolicyOption `gorm:"-"`
|
OptionsSerialized PolicyOption `gorm:"-"`
|
||||||
|
MasterID string `gorm:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PolicyOption 非公有的存储策略属性
|
// PolicyOption 非公有的存储策略属性
|
||||||
|
|
@ -277,6 +278,13 @@ func (policy *Policy) SaveAndClearCache() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SaveAndClearCache 更新并清理缓存
|
||||||
|
func (policy *Policy) UpdateAccessKeyAndClearCache(s string) error {
|
||||||
|
err := DB.Model(policy).UpdateColumn("access_key", s).Error
|
||||||
|
policy.ClearCache()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// ClearCache 清空policy缓存
|
// ClearCache 清空policy缓存
|
||||||
func (policy *Policy) ClearCache() {
|
func (policy *Policy) ClearCache() {
|
||||||
cache.Deletes([]string{strconv.FormatUint(uint64(policy.ID), 10)}, "policy_")
|
cache.Deletes([]string{strconv.FormatUint(uint64(policy.ID), 10)}, "policy_")
|
||||||
|
|
|
||||||
|
|
@ -30,12 +30,16 @@ func GetSettingByName(name string) string {
|
||||||
if optionValue, ok := cache.Get(cacheKey); ok {
|
if optionValue, ok := cache.Get(cacheKey); ok {
|
||||||
return optionValue.(string)
|
return optionValue.(string)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 尝试数据库中查找
|
// 尝试数据库中查找
|
||||||
|
if DB != nil {
|
||||||
result := DB.Where("name = ?", name).First(&setting)
|
result := DB.Where("name = ?", name).First(&setting)
|
||||||
if result.Error == nil {
|
if result.Error == nil {
|
||||||
_ = cache.Set(cacheKey, setting.Value, -1)
|
_ = cache.Set(cacheKey, setting.Value, -1)
|
||||||
return setting.Value
|
return setting.Value
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,169 +1,65 @@
|
||||||
package aria2
|
package aria2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"context"
|
||||||
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Instance 默认使用的Aria2处理实例
|
// Instance 默认使用的Aria2处理实例
|
||||||
var Instance Aria2 = &DummyAria2{}
|
var Instance common.Aria2 = &common.DummyAria2{}
|
||||||
|
|
||||||
|
// LB 获取 Aria2 节点的负载均衡器
|
||||||
|
var LB balancer.Balancer
|
||||||
|
|
||||||
// Lock Instance的读写锁
|
// Lock Instance的读写锁
|
||||||
var Lock sync.RWMutex
|
var Lock sync.RWMutex
|
||||||
|
|
||||||
// EventNotifier 任务状态更新通知处理器
|
// GetLoadBalancer 返回供Aria2使用的负载均衡器
|
||||||
var EventNotifier = &Notifier{}
|
func GetLoadBalancer() balancer.Balancer {
|
||||||
|
Lock.RLock()
|
||||||
// Aria2 离线下载处理接口
|
defer Lock.RUnlock()
|
||||||
type Aria2 interface {
|
return LB
|
||||||
// CreateTask 创建新的任务
|
|
||||||
CreateTask(task *model.Download, options map[string]interface{}) error
|
|
||||||
// 返回状态信息
|
|
||||||
Status(task *model.Download) (rpc.StatusInfo, error)
|
|
||||||
// 取消任务
|
|
||||||
Cancel(task *model.Download) error
|
|
||||||
// 选择要下载的文件
|
|
||||||
Select(task *model.Download, files []int) error
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
// URLTask 从URL添加的任务
|
|
||||||
URLTask = iota
|
|
||||||
// TorrentTask 种子任务
|
|
||||||
TorrentTask
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// Ready 准备就绪
|
|
||||||
Ready = iota
|
|
||||||
// Downloading 下载中
|
|
||||||
Downloading
|
|
||||||
// Paused 暂停中
|
|
||||||
Paused
|
|
||||||
// Error 出错
|
|
||||||
Error
|
|
||||||
// Complete 完成
|
|
||||||
Complete
|
|
||||||
// Canceled 取消/停止
|
|
||||||
Canceled
|
|
||||||
// Unknown 未知状态
|
|
||||||
Unknown
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// ErrNotEnabled 功能未开启错误
|
|
||||||
ErrNotEnabled = serializer.NewError(serializer.CodeNoPermissionErr, "离线下载功能未开启", nil)
|
|
||||||
// ErrUserNotFound 未找到下载任务创建者
|
|
||||||
ErrUserNotFound = serializer.NewError(serializer.CodeNotFound, "无法找到任务创建者", nil)
|
|
||||||
)
|
|
||||||
|
|
||||||
// DummyAria2 未开启Aria2功能时使用的默认处理器
|
|
||||||
type DummyAria2 struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateTask 创建新任务,此处直接返回未开启错误
|
|
||||||
func (instance *DummyAria2) CreateTask(model *model.Download, options map[string]interface{}) error {
|
|
||||||
return ErrNotEnabled
|
|
||||||
}
|
|
||||||
|
|
||||||
// Status 返回未开启错误
|
|
||||||
func (instance *DummyAria2) Status(task *model.Download) (rpc.StatusInfo, error) {
|
|
||||||
return rpc.StatusInfo{}, ErrNotEnabled
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cancel 返回未开启错误
|
|
||||||
func (instance *DummyAria2) Cancel(task *model.Download) error {
|
|
||||||
return ErrNotEnabled
|
|
||||||
}
|
|
||||||
|
|
||||||
// Select 返回未开启错误
|
|
||||||
func (instance *DummyAria2) Select(task *model.Download, files []int) error {
|
|
||||||
return ErrNotEnabled
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init 初始化
|
// Init 初始化
|
||||||
func Init(isReload bool) {
|
func Init(isReload bool) {
|
||||||
Lock.Lock()
|
Lock.Lock()
|
||||||
defer Lock.Unlock()
|
LB = balancer.NewBalancer("RoundRobin")
|
||||||
|
Lock.Unlock()
|
||||||
// 关闭上个初始连接
|
|
||||||
if previousClient, ok := Instance.(*RPCService); ok {
|
|
||||||
if previousClient.Caller != nil {
|
|
||||||
util.Log().Debug("关闭上个 aria2 连接")
|
|
||||||
previousClient.Caller.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
options := model.GetSettingByNames("aria2_rpcurl", "aria2_token", "aria2_options")
|
|
||||||
timeout := model.GetIntSetting("aria2_call_timeout", 5)
|
|
||||||
if options["aria2_rpcurl"] == "" {
|
|
||||||
Instance = &DummyAria2{}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
util.Log().Info("初始化 aria2 RPC 服务[%s]", options["aria2_rpcurl"])
|
|
||||||
client := &RPCService{}
|
|
||||||
|
|
||||||
// 解析RPC服务地址
|
|
||||||
server, err := url.Parse(options["aria2_rpcurl"])
|
|
||||||
if err != nil {
|
|
||||||
util.Log().Warning("无法解析 aria2 RPC 服务地址,%s", err)
|
|
||||||
Instance = &DummyAria2{}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
server.Path = "/jsonrpc"
|
|
||||||
|
|
||||||
// 加载自定义下载配置
|
|
||||||
var globalOptions map[string]interface{}
|
|
||||||
err = json.Unmarshal([]byte(options["aria2_options"]), &globalOptions)
|
|
||||||
if err != nil {
|
|
||||||
util.Log().Warning("无法解析 aria2 全局配置,%s", err)
|
|
||||||
Instance = &DummyAria2{}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := client.Init(server.String(), options["aria2_token"], timeout, globalOptions); err != nil {
|
|
||||||
util.Log().Warning("初始化 aria2 RPC 服务失败,%s", err)
|
|
||||||
Instance = &DummyAria2{}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
Instance = client
|
|
||||||
|
|
||||||
if !isReload {
|
if !isReload {
|
||||||
// 从数据库中读取未完成任务,创建监控
|
// 从数据库中读取未完成任务,创建监控
|
||||||
unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading)
|
unfinished := model.GetDownloadsByStatus(common.Ready, common.Paused, common.Downloading)
|
||||||
|
|
||||||
for i := 0; i < len(unfinished); i++ {
|
for i := 0; i < len(unfinished); i++ {
|
||||||
// 创建任务监控
|
// 创建任务监控
|
||||||
NewMonitor(&unfinished[i])
|
monitor.NewMonitor(&unfinished[i])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestRPCConnection 发送测试用的 RPC 请求,测试服务连通性
|
||||||
|
func TestRPCConnection(server, secret string, timeout int) (rpc.VersionInfo, error) {
|
||||||
|
// 解析RPC服务地址
|
||||||
|
rpcServer, err := url.Parse(server)
|
||||||
|
if err != nil {
|
||||||
|
return rpc.VersionInfo{}, fmt.Errorf("cannot parse RPC server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getStatus 将给定的状态字符串转换为状态标识数字
|
rpcServer.Path = "/jsonrpc"
|
||||||
func getStatus(status string) int {
|
caller, err := rpc.New(context.Background(), rpcServer.String(), secret, time.Duration(timeout)*time.Second, nil)
|
||||||
switch status {
|
if err != nil {
|
||||||
case "complete":
|
return rpc.VersionInfo{}, fmt.Errorf("cannot initialize rpc connection: %w", err)
|
||||||
return Complete
|
|
||||||
case "active":
|
|
||||||
return Downloading
|
|
||||||
case "waiting":
|
|
||||||
return Ready
|
|
||||||
case "paused":
|
|
||||||
return Paused
|
|
||||||
case "error":
|
|
||||||
return Error
|
|
||||||
case "removed":
|
|
||||||
return Canceled
|
|
||||||
default:
|
|
||||||
return Unknown
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return caller.GetVersion()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ import (
|
||||||
|
|
||||||
"github.com/DATA-DOG/go-sqlmock"
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
@ -37,7 +38,7 @@ func TestDummyAria2(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInit(t *testing.T) {
|
func TestInit(t *testing.T) {
|
||||||
MAX_RETRY = 0
|
monitor.MAX_RETRY = 0
|
||||||
asserts := assert.New(t)
|
asserts := assert.New(t)
|
||||||
cache.Set("setting_aria2_token", "1", 0)
|
cache.Set("setting_aria2_token", "1", 0)
|
||||||
cache.Set("setting_aria2_call_timeout", "5", 0)
|
cache.Set("setting_aria2_call_timeout", "5", 0)
|
||||||
|
|
@ -81,11 +82,11 @@ func TestInit(t *testing.T) {
|
||||||
|
|
||||||
func TestGetStatus(t *testing.T) {
|
func TestGetStatus(t *testing.T) {
|
||||||
asserts := assert.New(t)
|
asserts := assert.New(t)
|
||||||
asserts.Equal(4, getStatus("complete"))
|
asserts.Equal(4, GetStatus("complete"))
|
||||||
asserts.Equal(1, getStatus("active"))
|
asserts.Equal(1, GetStatus("active"))
|
||||||
asserts.Equal(0, getStatus("waiting"))
|
asserts.Equal(0, GetStatus("waiting"))
|
||||||
asserts.Equal(2, getStatus("paused"))
|
asserts.Equal(2, GetStatus("paused"))
|
||||||
asserts.Equal(3, getStatus("error"))
|
asserts.Equal(3, GetStatus("error"))
|
||||||
asserts.Equal(5, getStatus("removed"))
|
asserts.Equal(5, GetStatus("removed"))
|
||||||
asserts.Equal(6, getStatus("?"))
|
asserts.Equal(6, GetStatus("?"))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ import (
|
||||||
|
|
||||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -33,7 +34,7 @@ func (client *RPCService) Init(server, secret string, timeout int, options map[s
|
||||||
Options: options,
|
Options: options,
|
||||||
}
|
}
|
||||||
caller, err := rpc.New(context.Background(), server, secret, time.Duration(timeout)*time.Second,
|
caller, err := rpc.New(context.Background(), server, secret, time.Duration(timeout)*time.Second,
|
||||||
EventNotifier)
|
mq.GlobalMQ)
|
||||||
client.Caller = caller
|
client.Caller = caller
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -85,7 +86,7 @@ func (client *RPCService) Select(task *model.Download, files []int) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateTask 创建新任务
|
// CreateTask 创建新任务
|
||||||
func (client *RPCService) CreateTask(task *model.Download, groupOptions map[string]interface{}) error {
|
func (client *RPCService) CreateTask(task *model.Download, groupOptions map[string]interface{}) (string, error) {
|
||||||
// 生成存储路径
|
// 生成存储路径
|
||||||
path := filepath.Join(
|
path := filepath.Join(
|
||||||
model.GetSettingByName("aria2_temp_path"),
|
model.GetSettingByName("aria2_temp_path"),
|
||||||
|
|
@ -106,18 +107,8 @@ func (client *RPCService) CreateTask(task *model.Download, groupOptions map[stri
|
||||||
|
|
||||||
gid, err := client.Caller.AddURI(task.Source, options)
|
gid, err := client.Caller.AddURI(task.Source, options)
|
||||||
if err != nil || gid == "" {
|
if err != nil || gid == "" {
|
||||||
return err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 保存到数据库
|
return gid, nil
|
||||||
task.GID = gid
|
|
||||||
_, err = task.Create()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建任务监控
|
|
||||||
NewMonitor(task)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,114 @@
|
||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Aria2 离线下载处理接口
|
||||||
|
type Aria2 interface {
|
||||||
|
// Init 初始化客户端连接
|
||||||
|
Init() error
|
||||||
|
// CreateTask 创建新的任务
|
||||||
|
CreateTask(task *model.Download, options map[string]interface{}) (string, error)
|
||||||
|
// 返回状态信息
|
||||||
|
Status(task *model.Download) (rpc.StatusInfo, error)
|
||||||
|
// 取消任务
|
||||||
|
Cancel(task *model.Download) error
|
||||||
|
// 选择要下载的文件
|
||||||
|
Select(task *model.Download, files []int) error
|
||||||
|
// 获取离线下载配置
|
||||||
|
GetConfig() model.Aria2Option
|
||||||
|
// 删除临时下载文件
|
||||||
|
DeleteTempFile(*model.Download) error
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// URLTask 从URL添加的任务
|
||||||
|
URLTask = iota
|
||||||
|
// TorrentTask 种子任务
|
||||||
|
TorrentTask
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Ready 准备就绪
|
||||||
|
Ready = iota
|
||||||
|
// Downloading 下载中
|
||||||
|
Downloading
|
||||||
|
// Paused 暂停中
|
||||||
|
Paused
|
||||||
|
// Error 出错
|
||||||
|
Error
|
||||||
|
// Complete 完成
|
||||||
|
Complete
|
||||||
|
// Canceled 取消/停止
|
||||||
|
Canceled
|
||||||
|
// Unknown 未知状态
|
||||||
|
Unknown
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrNotEnabled 功能未开启错误
|
||||||
|
ErrNotEnabled = serializer.NewError(serializer.CodeNoPermissionErr, "离线下载功能未开启", nil)
|
||||||
|
// ErrUserNotFound 未找到下载任务创建者
|
||||||
|
ErrUserNotFound = serializer.NewError(serializer.CodeNotFound, "无法找到任务创建者", nil)
|
||||||
|
)
|
||||||
|
|
||||||
|
// DummyAria2 未开启Aria2功能时使用的默认处理器
|
||||||
|
type DummyAria2 struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (instance *DummyAria2) Init() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTask 创建新任务,此处直接返回未开启错误
|
||||||
|
func (instance *DummyAria2) CreateTask(model *model.Download, options map[string]interface{}) (string, error) {
|
||||||
|
return "", ErrNotEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// Status 返回未开启错误
|
||||||
|
func (instance *DummyAria2) Status(task *model.Download) (rpc.StatusInfo, error) {
|
||||||
|
return rpc.StatusInfo{}, ErrNotEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cancel 返回未开启错误
|
||||||
|
func (instance *DummyAria2) Cancel(task *model.Download) error {
|
||||||
|
return ErrNotEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select 返回未开启错误
|
||||||
|
func (instance *DummyAria2) Select(task *model.Download, files []int) error {
|
||||||
|
return ErrNotEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConfig 返回空的
|
||||||
|
func (instance *DummyAria2) GetConfig() model.Aria2Option {
|
||||||
|
return model.Aria2Option{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConfig 返回空的
|
||||||
|
func (instance *DummyAria2) DeleteTempFile(src *model.Download) error {
|
||||||
|
return ErrNotEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStatus 将给定的状态字符串转换为状态标识数字
|
||||||
|
func GetStatus(status string) int {
|
||||||
|
switch status {
|
||||||
|
case "complete":
|
||||||
|
return Complete
|
||||||
|
case "active":
|
||||||
|
return Downloading
|
||||||
|
case "waiting":
|
||||||
|
return Ready
|
||||||
|
case "paused":
|
||||||
|
return Paused
|
||||||
|
case "error":
|
||||||
|
return Error
|
||||||
|
case "removed":
|
||||||
|
return Canceled
|
||||||
|
default:
|
||||||
|
return Unknown
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,19 +1,21 @@
|
||||||
package aria2
|
package monitor
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"os"
|
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local"
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/task"
|
"github.com/cloudreve/Cloudreve/v3/pkg/task"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||||
)
|
)
|
||||||
|
|
@ -23,32 +25,34 @@ type Monitor struct {
|
||||||
Task *model.Download
|
Task *model.Download
|
||||||
Interval time.Duration
|
Interval time.Duration
|
||||||
|
|
||||||
notifier chan StatusEvent
|
notifier <-chan mq.Message
|
||||||
|
node cluster.Node
|
||||||
retried int
|
retried int
|
||||||
}
|
}
|
||||||
|
|
||||||
// StatusEvent 状态改变事件
|
|
||||||
type StatusEvent struct {
|
|
||||||
GID string
|
|
||||||
Status int
|
|
||||||
}
|
|
||||||
|
|
||||||
var MAX_RETRY = 10
|
var MAX_RETRY = 10
|
||||||
|
|
||||||
// NewMonitor 新建上传状态监控
|
// NewMonitor 新建离线下载状态监控
|
||||||
func NewMonitor(task *model.Download) {
|
func NewMonitor(task *model.Download) {
|
||||||
monitor := &Monitor{
|
monitor := &Monitor{
|
||||||
Task: task,
|
Task: task,
|
||||||
Interval: time.Duration(model.GetIntSetting("aria2_interval", 10)) * time.Second,
|
notifier: make(chan mq.Message),
|
||||||
notifier: make(chan StatusEvent),
|
node: cluster.Default.GetNodeByID(task.GetNodeID()),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if monitor.node != nil {
|
||||||
|
monitor.Interval = time.Duration(monitor.node.GetAria2Instance().GetConfig().Interval) * time.Second
|
||||||
go monitor.Loop()
|
go monitor.Loop()
|
||||||
EventNotifier.Subscribe(monitor.notifier, monitor.Task.GID)
|
|
||||||
|
monitor.notifier = mq.GlobalMQ.Subscribe(monitor.Task.GID, 0)
|
||||||
|
} else {
|
||||||
|
monitor.setErrorStatus(errors.New("节点不可用"))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Loop 开启监控循环
|
// Loop 开启监控循环
|
||||||
func (monitor *Monitor) Loop() {
|
func (monitor *Monitor) Loop() {
|
||||||
defer EventNotifier.Unsubscribe(monitor.Task.GID)
|
defer mq.GlobalMQ.Unsubscribe(monitor.Task.GID, monitor.notifier)
|
||||||
|
|
||||||
// 首次循环立即更新
|
// 首次循环立即更新
|
||||||
interval := time.Duration(0)
|
interval := time.Duration(0)
|
||||||
|
|
@ -70,9 +74,7 @@ func (monitor *Monitor) Loop() {
|
||||||
|
|
||||||
// Update 更新状态,返回值表示是否退出监控
|
// Update 更新状态,返回值表示是否退出监控
|
||||||
func (monitor *Monitor) Update() bool {
|
func (monitor *Monitor) Update() bool {
|
||||||
Lock.RLock()
|
status, err := monitor.node.GetAria2Instance().Status(monitor.Task)
|
||||||
status, err := Instance.Status(monitor.Task)
|
|
||||||
Lock.RUnlock()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
monitor.retried++
|
monitor.retried++
|
||||||
|
|
@ -102,6 +104,7 @@ func (monitor *Monitor) Update() bool {
|
||||||
if err := monitor.UpdateTaskInfo(status); err != nil {
|
if err := monitor.UpdateTaskInfo(status); err != nil {
|
||||||
util.Log().Warning("无法更新下载任务[%s]的任务信息[%s],", monitor.Task.GID, err)
|
util.Log().Warning("无法更新下载任务[%s]的任务信息[%s],", monitor.Task.GID, err)
|
||||||
monitor.setErrorStatus(err)
|
monitor.setErrorStatus(err)
|
||||||
|
monitor.RemoveTempFolder()
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -115,7 +118,7 @@ func (monitor *Monitor) Update() bool {
|
||||||
case "active", "waiting", "paused":
|
case "active", "waiting", "paused":
|
||||||
return false
|
return false
|
||||||
case "removed":
|
case "removed":
|
||||||
monitor.Task.Status = Canceled
|
monitor.Task.Status = common.Canceled
|
||||||
monitor.Task.Save()
|
monitor.Task.Save()
|
||||||
monitor.RemoveTempFolder()
|
monitor.RemoveTempFolder()
|
||||||
return true
|
return true
|
||||||
|
|
@ -130,7 +133,7 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error {
|
||||||
originSize := monitor.Task.TotalSize
|
originSize := monitor.Task.TotalSize
|
||||||
|
|
||||||
monitor.Task.GID = status.Gid
|
monitor.Task.GID = status.Gid
|
||||||
monitor.Task.Status = getStatus(status.Status)
|
monitor.Task.Status = common.GetStatus(status.Status)
|
||||||
|
|
||||||
// 文件大小、已下载大小
|
// 文件大小、已下载大小
|
||||||
total, err := strconv.ParseUint(status.TotalLength, 10, 64)
|
total, err := strconv.ParseUint(status.TotalLength, 10, 64)
|
||||||
|
|
@ -164,9 +167,7 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error {
|
||||||
// 文件大小更新后,对文件限制等进行校验
|
// 文件大小更新后,对文件限制等进行校验
|
||||||
if err := monitor.ValidateFile(); err != nil {
|
if err := monitor.ValidateFile(); err != nil {
|
||||||
// 验证失败时取消任务
|
// 验证失败时取消任务
|
||||||
Lock.RLock()
|
monitor.node.GetAria2Instance().Cancel(monitor.Task)
|
||||||
Instance.Cancel(monitor.Task)
|
|
||||||
Lock.RUnlock()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -179,7 +180,7 @@ func (monitor *Monitor) ValidateFile() error {
|
||||||
// 找到任务创建者
|
// 找到任务创建者
|
||||||
user := monitor.Task.GetOwner()
|
user := monitor.Task.GetOwner()
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return ErrUserNotFound
|
return common.ErrUserNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建文件系统
|
// 创建文件系统
|
||||||
|
|
@ -230,28 +231,31 @@ func (monitor *Monitor) Error(status rpc.StatusInfo) bool {
|
||||||
|
|
||||||
// RemoveTempFolder 清理下载临时目录
|
// RemoveTempFolder 清理下载临时目录
|
||||||
func (monitor *Monitor) RemoveTempFolder() {
|
func (monitor *Monitor) RemoveTempFolder() {
|
||||||
err := os.RemoveAll(monitor.Task.Parent)
|
monitor.node.GetAria2Instance().DeleteTempFile(monitor.Task)
|
||||||
if err != nil {
|
|
||||||
util.Log().Warning("无法删除离线下载临时目录[%s], %s", monitor.Task.Parent, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Complete 完成下载,返回是否中断监控
|
// Complete 完成下载,返回是否中断监控
|
||||||
func (monitor *Monitor) Complete(status rpc.StatusInfo) bool {
|
func (monitor *Monitor) Complete(status rpc.StatusInfo) bool {
|
||||||
// 创建中转任务
|
// 创建中转任务
|
||||||
file := make([]string, 0, len(monitor.Task.StatusInfo.Files))
|
file := make([]string, 0, len(monitor.Task.StatusInfo.Files))
|
||||||
|
sizes := make(map[string]uint64, len(monitor.Task.StatusInfo.Files))
|
||||||
for i := 0; i < len(monitor.Task.StatusInfo.Files); i++ {
|
for i := 0; i < len(monitor.Task.StatusInfo.Files); i++ {
|
||||||
if monitor.Task.StatusInfo.Files[i].Selected == "true" {
|
fileInfo := monitor.Task.StatusInfo.Files[i]
|
||||||
file = append(file, monitor.Task.StatusInfo.Files[i].Path)
|
if fileInfo.Selected == "true" {
|
||||||
|
file = append(file, fileInfo.Path)
|
||||||
|
size, _ := strconv.ParseUint(fileInfo.Length, 10, 64)
|
||||||
|
sizes[fileInfo.Path] = size
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
job, err := task.NewTransferTask(
|
job, err := task.NewTransferTask(
|
||||||
monitor.Task.UserID,
|
monitor.Task.UserID,
|
||||||
file,
|
file,
|
||||||
monitor.Task.Dst,
|
monitor.Task.Dst,
|
||||||
monitor.Task.Parent,
|
monitor.Task.Parent,
|
||||||
true,
|
true,
|
||||||
|
monitor.node.ID(),
|
||||||
|
sizes,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
monitor.setErrorStatus(err)
|
monitor.setErrorStatus(err)
|
||||||
|
|
@ -269,7 +273,7 @@ func (monitor *Monitor) Complete(status rpc.StatusInfo) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (monitor *Monitor) setErrorStatus(err error) {
|
func (monitor *Monitor) setErrorStatus(err error) {
|
||||||
monitor.Task.Status = Error
|
monitor.Task.Status = common.Error
|
||||||
monitor.Task.Error = err.Error()
|
monitor.Task.Error = err.Error()
|
||||||
monitor.Task.Save()
|
monitor.Task.Save()
|
||||||
}
|
}
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
package aria2
|
package monitor
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
|
@ -7,6 +7,8 @@ import (
|
||||||
|
|
||||||
"github.com/DATA-DOG/go-sqlmock"
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
||||||
|
|
@ -44,13 +46,13 @@ func (m InstanceMock) Select(task *model.Download, files []int) error {
|
||||||
func TestNewMonitor(t *testing.T) {
|
func TestNewMonitor(t *testing.T) {
|
||||||
asserts := assert.New(t)
|
asserts := assert.New(t)
|
||||||
NewMonitor(&model.Download{GID: "gid"})
|
NewMonitor(&model.Download{GID: "gid"})
|
||||||
_, ok := EventNotifier.Subscribes.Load("gid")
|
_, ok := common.EventNotifier.Subscribes.Load("gid")
|
||||||
asserts.True(ok)
|
asserts.True(ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMonitor_Loop(t *testing.T) {
|
func TestMonitor_Loop(t *testing.T) {
|
||||||
asserts := assert.New(t)
|
asserts := assert.New(t)
|
||||||
notifier := make(chan StatusEvent)
|
notifier := make(chan common.StatusEvent)
|
||||||
MAX_RETRY = 0
|
MAX_RETRY = 0
|
||||||
monitor := &Monitor{
|
monitor := &Monitor{
|
||||||
Task: &model.Download{GID: "gid"},
|
Task: &model.Download{GID: "gid"},
|
||||||
|
|
@ -76,10 +78,10 @@ func TestMonitor_Update(t *testing.T) {
|
||||||
{
|
{
|
||||||
MAX_RETRY = 1
|
MAX_RETRY = 1
|
||||||
testInstance := new(InstanceMock)
|
testInstance := new(InstanceMock)
|
||||||
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, errors.New("error"))
|
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{}, errors.New("error"))
|
||||||
file, _ := util.CreatNestedFile("TestMonitor_Update/1")
|
file, _ := util.CreatNestedFile("TestMonitor_Update/1")
|
||||||
file.Close()
|
file.Close()
|
||||||
Instance = testInstance
|
aria2.Instance = testInstance
|
||||||
asserts.False(monitor.Update())
|
asserts.False(monitor.Update())
|
||||||
asserts.True(monitor.Update())
|
asserts.True(monitor.Update())
|
||||||
testInstance.AssertExpectations(t)
|
testInstance.AssertExpectations(t)
|
||||||
|
|
@ -89,16 +91,16 @@ func TestMonitor_Update(t *testing.T) {
|
||||||
// 磁力链下载重定向
|
// 磁力链下载重定向
|
||||||
{
|
{
|
||||||
testInstance := new(InstanceMock)
|
testInstance := new(InstanceMock)
|
||||||
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{
|
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{
|
||||||
FollowedBy: []string{"1"},
|
FollowedBy: []string{"1"},
|
||||||
}, nil)
|
}, nil)
|
||||||
monitor.Task.ID = 1
|
monitor.Task.ID = 1
|
||||||
mock.ExpectBegin()
|
aria2.mock.ExpectBegin()
|
||||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||||
mock.ExpectCommit()
|
aria2.mock.ExpectCommit()
|
||||||
Instance = testInstance
|
aria2.Instance = testInstance
|
||||||
asserts.False(monitor.Update())
|
asserts.False(monitor.Update())
|
||||||
asserts.NoError(mock.ExpectationsWereMet())
|
asserts.NoError(aria2.mock.ExpectationsWereMet())
|
||||||
testInstance.AssertExpectations(t)
|
testInstance.AssertExpectations(t)
|
||||||
asserts.EqualValues("1", monitor.Task.GID)
|
asserts.EqualValues("1", monitor.Task.GID)
|
||||||
}
|
}
|
||||||
|
|
@ -106,82 +108,82 @@ func TestMonitor_Update(t *testing.T) {
|
||||||
// 无法更新任务信息
|
// 无法更新任务信息
|
||||||
{
|
{
|
||||||
testInstance := new(InstanceMock)
|
testInstance := new(InstanceMock)
|
||||||
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, nil)
|
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{}, nil)
|
||||||
monitor.Task.ID = 1
|
monitor.Task.ID = 1
|
||||||
mock.ExpectBegin()
|
aria2.mock.ExpectBegin()
|
||||||
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
|
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
|
||||||
mock.ExpectRollback()
|
aria2.mock.ExpectRollback()
|
||||||
Instance = testInstance
|
aria2.Instance = testInstance
|
||||||
asserts.True(monitor.Update())
|
asserts.True(monitor.Update())
|
||||||
asserts.NoError(mock.ExpectationsWereMet())
|
asserts.NoError(aria2.mock.ExpectationsWereMet())
|
||||||
testInstance.AssertExpectations(t)
|
testInstance.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 返回未知状态
|
// 返回未知状态
|
||||||
{
|
{
|
||||||
testInstance := new(InstanceMock)
|
testInstance := new(InstanceMock)
|
||||||
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "?"}, nil)
|
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "?"}, nil)
|
||||||
mock.ExpectBegin()
|
aria2.mock.ExpectBegin()
|
||||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||||
mock.ExpectCommit()
|
aria2.mock.ExpectCommit()
|
||||||
Instance = testInstance
|
aria2.Instance = testInstance
|
||||||
asserts.True(monitor.Update())
|
asserts.True(monitor.Update())
|
||||||
asserts.NoError(mock.ExpectationsWereMet())
|
asserts.NoError(aria2.mock.ExpectationsWereMet())
|
||||||
testInstance.AssertExpectations(t)
|
testInstance.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 返回被取消状态
|
// 返回被取消状态
|
||||||
{
|
{
|
||||||
testInstance := new(InstanceMock)
|
testInstance := new(InstanceMock)
|
||||||
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "removed"}, nil)
|
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "removed"}, nil)
|
||||||
mock.ExpectBegin()
|
aria2.mock.ExpectBegin()
|
||||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||||
mock.ExpectCommit()
|
aria2.mock.ExpectCommit()
|
||||||
mock.ExpectBegin()
|
aria2.mock.ExpectBegin()
|
||||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||||
mock.ExpectCommit()
|
aria2.mock.ExpectCommit()
|
||||||
Instance = testInstance
|
aria2.Instance = testInstance
|
||||||
asserts.True(monitor.Update())
|
asserts.True(monitor.Update())
|
||||||
asserts.NoError(mock.ExpectationsWereMet())
|
asserts.NoError(aria2.mock.ExpectationsWereMet())
|
||||||
testInstance.AssertExpectations(t)
|
testInstance.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 返回活跃状态
|
// 返回活跃状态
|
||||||
{
|
{
|
||||||
testInstance := new(InstanceMock)
|
testInstance := new(InstanceMock)
|
||||||
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "active"}, nil)
|
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "active"}, nil)
|
||||||
mock.ExpectBegin()
|
aria2.mock.ExpectBegin()
|
||||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||||
mock.ExpectCommit()
|
aria2.mock.ExpectCommit()
|
||||||
Instance = testInstance
|
aria2.Instance = testInstance
|
||||||
asserts.False(monitor.Update())
|
asserts.False(monitor.Update())
|
||||||
asserts.NoError(mock.ExpectationsWereMet())
|
asserts.NoError(aria2.mock.ExpectationsWereMet())
|
||||||
testInstance.AssertExpectations(t)
|
testInstance.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 返回错误状态
|
// 返回错误状态
|
||||||
{
|
{
|
||||||
testInstance := new(InstanceMock)
|
testInstance := new(InstanceMock)
|
||||||
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "error"}, nil)
|
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "error"}, nil)
|
||||||
mock.ExpectBegin()
|
aria2.mock.ExpectBegin()
|
||||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||||
mock.ExpectCommit()
|
aria2.mock.ExpectCommit()
|
||||||
Instance = testInstance
|
aria2.Instance = testInstance
|
||||||
asserts.True(monitor.Update())
|
asserts.True(monitor.Update())
|
||||||
asserts.NoError(mock.ExpectationsWereMet())
|
asserts.NoError(aria2.mock.ExpectationsWereMet())
|
||||||
testInstance.AssertExpectations(t)
|
testInstance.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 返回完成
|
// 返回完成
|
||||||
{
|
{
|
||||||
testInstance := new(InstanceMock)
|
testInstance := new(InstanceMock)
|
||||||
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "complete"}, nil)
|
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "complete"}, nil)
|
||||||
mock.ExpectBegin()
|
aria2.mock.ExpectBegin()
|
||||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||||
mock.ExpectCommit()
|
aria2.mock.ExpectCommit()
|
||||||
Instance = testInstance
|
aria2.Instance = testInstance
|
||||||
asserts.True(monitor.Update())
|
asserts.True(monitor.Update())
|
||||||
asserts.NoError(mock.ExpectationsWereMet())
|
asserts.NoError(aria2.mock.ExpectationsWereMet())
|
||||||
testInstance.AssertExpectations(t)
|
testInstance.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -198,34 +200,34 @@ func TestMonitor_UpdateTaskInfo(t *testing.T) {
|
||||||
|
|
||||||
// 失败
|
// 失败
|
||||||
{
|
{
|
||||||
mock.ExpectBegin()
|
aria2.mock.ExpectBegin()
|
||||||
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
|
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
|
||||||
mock.ExpectRollback()
|
aria2.mock.ExpectRollback()
|
||||||
err := monitor.UpdateTaskInfo(rpc.StatusInfo{})
|
err := monitor.UpdateTaskInfo(rpc.StatusInfo{})
|
||||||
asserts.NoError(mock.ExpectationsWereMet())
|
asserts.NoError(aria2.mock.ExpectationsWereMet())
|
||||||
asserts.Error(err)
|
asserts.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新成功,无需校验
|
// 更新成功,无需校验
|
||||||
{
|
{
|
||||||
mock.ExpectBegin()
|
aria2.mock.ExpectBegin()
|
||||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||||
mock.ExpectCommit()
|
aria2.mock.ExpectCommit()
|
||||||
err := monitor.UpdateTaskInfo(rpc.StatusInfo{})
|
err := monitor.UpdateTaskInfo(rpc.StatusInfo{})
|
||||||
asserts.NoError(mock.ExpectationsWereMet())
|
asserts.NoError(aria2.mock.ExpectationsWereMet())
|
||||||
asserts.NoError(err)
|
asserts.NoError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新成功,大小改变,需要校验,校验失败
|
// 更新成功,大小改变,需要校验,校验失败
|
||||||
{
|
{
|
||||||
testInstance := new(InstanceMock)
|
testInstance := new(InstanceMock)
|
||||||
testInstance.On("Cancel", testMock.Anything).Return(nil)
|
testInstance.On("SlaveCancel", testMock.Anything).Return(nil)
|
||||||
Instance = testInstance
|
aria2.Instance = testInstance
|
||||||
mock.ExpectBegin()
|
aria2.mock.ExpectBegin()
|
||||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||||
mock.ExpectCommit()
|
aria2.mock.ExpectCommit()
|
||||||
err := monitor.UpdateTaskInfo(rpc.StatusInfo{TotalLength: "1"})
|
err := monitor.UpdateTaskInfo(rpc.StatusInfo{TotalLength: "1"})
|
||||||
asserts.NoError(mock.ExpectationsWereMet())
|
asserts.NoError(aria2.mock.ExpectationsWereMet())
|
||||||
asserts.Error(err)
|
asserts.Error(err)
|
||||||
testInstance.AssertExpectations(t)
|
testInstance.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
|
|
@ -308,17 +310,17 @@ func TestMonitor_Complete(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
cache.Set("setting_max_worker_num", "1", 0)
|
cache.Set("setting_max_worker_num", "1", 0)
|
||||||
mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id"}))
|
aria2.mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id"}))
|
||||||
task.Init()
|
task.Init()
|
||||||
mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
|
aria2.mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
|
||||||
mock.ExpectQuery("SELECT(.+)policies").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
|
aria2.mock.ExpectQuery("SELECT(.+)policies").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
|
||||||
mock.ExpectBegin()
|
aria2.mock.ExpectBegin()
|
||||||
mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1))
|
aria2.mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||||
mock.ExpectCommit()
|
aria2.mock.ExpectCommit()
|
||||||
|
|
||||||
mock.ExpectBegin()
|
aria2.mock.ExpectBegin()
|
||||||
mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1))
|
aria2.mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||||
mock.ExpectCommit()
|
aria2.mock.ExpectCommit()
|
||||||
asserts.True(monitor.Complete(rpc.StatusInfo{}))
|
asserts.True(monitor.Complete(rpc.StatusInfo{}))
|
||||||
asserts.NoError(mock.ExpectationsWereMet())
|
asserts.NoError(aria2.mock.ExpectationsWereMet())
|
||||||
}
|
}
|
||||||
|
|
@ -1,64 +0,0 @@
|
||||||
package aria2
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Notifier aria2实践通知处理
|
|
||||||
type Notifier struct {
|
|
||||||
Subscribes sync.Map
|
|
||||||
}
|
|
||||||
|
|
||||||
// Subscribe 订阅事件通知
|
|
||||||
func (notifier *Notifier) Subscribe(target chan StatusEvent, gid string) {
|
|
||||||
notifier.Subscribes.Store(gid, target)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unsubscribe 取消订阅事件通知
|
|
||||||
func (notifier *Notifier) Unsubscribe(gid string) {
|
|
||||||
notifier.Subscribes.Delete(gid)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Notify 发送通知
|
|
||||||
func (notifier *Notifier) Notify(events []rpc.Event, status int) {
|
|
||||||
for _, event := range events {
|
|
||||||
if target, ok := notifier.Subscribes.Load(event.Gid); ok {
|
|
||||||
target.(chan StatusEvent) <- StatusEvent{
|
|
||||||
GID: event.Gid,
|
|
||||||
Status: status,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnDownloadStart 下载开始
|
|
||||||
func (notifier *Notifier) OnDownloadStart(events []rpc.Event) {
|
|
||||||
notifier.Notify(events, Downloading)
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnDownloadPause 下载暂停
|
|
||||||
func (notifier *Notifier) OnDownloadPause(events []rpc.Event) {
|
|
||||||
notifier.Notify(events, Paused)
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnDownloadStop 下载停止
|
|
||||||
func (notifier *Notifier) OnDownloadStop(events []rpc.Event) {
|
|
||||||
notifier.Notify(events, Canceled)
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnDownloadComplete 下载完成
|
|
||||||
func (notifier *Notifier) OnDownloadComplete(events []rpc.Event) {
|
|
||||||
notifier.Notify(events, Complete)
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnDownloadError 下载出错
|
|
||||||
func (notifier *Notifier) OnDownloadError(events []rpc.Event) {
|
|
||||||
notifier.Notify(events, Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnBtDownloadComplete BT下载完成
|
|
||||||
func (notifier *Notifier) OnBtDownloadComplete(events []rpc.Event) {
|
|
||||||
notifier.Notify(events, Complete)
|
|
||||||
}
|
|
||||||
|
|
@ -2,9 +2,11 @@ package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
|
@ -30,9 +32,8 @@ type Auth interface {
|
||||||
Check(body string, sign string) error
|
Check(body string, sign string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignRequest 对PUT\POST等复杂HTTP请求签名,如果请求Header中
|
// SignRequest 对PUT\POST等复杂HTTP请求签名,只会对URI部分、
|
||||||
// 包含 X-Policy, 则此请求会被认定为上传请求,只会对URI部分和
|
// 请求正文、`X-`开头的header进行签名
|
||||||
// Policy部分进行签名。其他请求则会对URI和Body部分进行签名。
|
|
||||||
func SignRequest(instance Auth, r *http.Request, expires int64) *http.Request {
|
func SignRequest(instance Auth, r *http.Request, expires int64) *http.Request {
|
||||||
// 处理有效期
|
// 处理有效期
|
||||||
if expires > 0 {
|
if expires > 0 {
|
||||||
|
|
@ -61,20 +62,31 @@ func CheckRequest(instance Auth, r *http.Request) error {
|
||||||
return instance.Check(getSignContent(r), sign[0])
|
return instance.Check(getSignContent(r), sign[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
// getSignContent 根据请求Header中是否包含X-Policy判断是否为上传请求,
|
// getSignContent 签名请求 path、正文、以`X-`开头的 Header. 如果 Header 中包含 `X-Policy`,
|
||||||
// 返回待签名/验证的字符串
|
// 则不对正文签名。返回待签名/验证的字符串
|
||||||
func getSignContent(r *http.Request) (rawSignString string) {
|
func getSignContent(r *http.Request) (rawSignString string) {
|
||||||
if policy, ok := r.Header["X-Policy"]; ok {
|
// 读取所有body正文
|
||||||
rawSignString = serializer.NewRequestSignString(r.URL.Path, policy[0], "")
|
|
||||||
} else {
|
|
||||||
var body = []byte{}
|
var body = []byte{}
|
||||||
|
if _, ok := r.Header["X-Policy"]; !ok {
|
||||||
if r.Body != nil {
|
if r.Body != nil {
|
||||||
body, _ = ioutil.ReadAll(r.Body)
|
body, _ = ioutil.ReadAll(r.Body)
|
||||||
_ = r.Body.Close()
|
_ = r.Body.Close()
|
||||||
r.Body = ioutil.NopCloser(bytes.NewReader(body))
|
r.Body = ioutil.NopCloser(bytes.NewReader(body))
|
||||||
}
|
}
|
||||||
rawSignString = serializer.NewRequestSignString(r.URL.Path, "", string(body))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 决定要签名的header
|
||||||
|
var signedHeader []string
|
||||||
|
for k, _ := range r.Header {
|
||||||
|
if strings.HasPrefix(k, "X-") && k != "X-Filename" {
|
||||||
|
signedHeader = append(signedHeader, fmt.Sprintf("%s=%s", k, r.Header.Get(k)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Strings(signedHeader)
|
||||||
|
|
||||||
|
// 读取所有待签名Header
|
||||||
|
rawSignString = serializer.NewRequestSignString(r.URL.Path, strings.Join(signedHeader, "&"), string(body))
|
||||||
|
|
||||||
return rawSignString
|
return rawSignString
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
package balancer
|
||||||
|
|
||||||
|
type Balancer interface {
|
||||||
|
NextPeer(nodes interface{}) (error, interface{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBalancer 根据策略标识返回新的负载均衡器
|
||||||
|
func NewBalancer(strategy string) Balancer {
|
||||||
|
switch strategy {
|
||||||
|
case "RoundRobin":
|
||||||
|
return &RoundRobin{}
|
||||||
|
default:
|
||||||
|
return &RoundRobin{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
package balancer
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrInputNotSlice = errors.New("Input value is not silice")
|
||||||
|
ErrNoAvaliableNode = errors.New("No nodes avaliable")
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,30 @@
|
||||||
|
package balancer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RoundRobin struct {
|
||||||
|
current uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// NextPeer 返回轮盘的下一节点
|
||||||
|
func (r *RoundRobin) NextPeer(nodes interface{}) (error, interface{}) {
|
||||||
|
v := reflect.ValueOf(nodes)
|
||||||
|
if v.Kind() != reflect.Slice {
|
||||||
|
return ErrInputNotSlice, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if v.Len() == 0 {
|
||||||
|
return ErrNoAvaliableNode, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
next := r.NextIndex(v.Len())
|
||||||
|
return nil, v.Index(next).Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
// NextIndex 返回下一个节点下标
|
||||||
|
func (r *RoundRobin) NextIndex(total int) int {
|
||||||
|
return int(atomic.AddUint64(&r.current, uint64(1)) % uint64(total))
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
package cluster
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrFeatureNotExist = errors.New("No nodes in nodepool match the feature specificed")
|
||||||
|
ErrIlegalPath = errors.New("path out of boundary of setting temp folder")
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,265 @@
|
||||||
|
package cluster
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||||
|
"github.com/gofrs/uuid"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const deleteTempFileDuration = 60 * time.Second
|
||||||
|
|
||||||
|
type MasterNode struct {
|
||||||
|
Model *model.Node
|
||||||
|
aria2RPC rpcService
|
||||||
|
lock sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// RPCService 通过RPC服务的Aria2任务管理器
|
||||||
|
type rpcService struct {
|
||||||
|
Caller rpc.Client
|
||||||
|
Initialized bool
|
||||||
|
|
||||||
|
parent *MasterNode
|
||||||
|
options *clientOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
type clientOptions struct {
|
||||||
|
Options map[string]interface{} // 创建下载时额外添加的设置
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init 初始化节点
|
||||||
|
func (node *MasterNode) Init(nodeModel *model.Node) {
|
||||||
|
node.lock.Lock()
|
||||||
|
node.Model = nodeModel
|
||||||
|
node.aria2RPC.parent = node
|
||||||
|
node.lock.Unlock()
|
||||||
|
|
||||||
|
node.lock.RLock()
|
||||||
|
if node.Model.Aria2Enabled {
|
||||||
|
node.lock.RUnlock()
|
||||||
|
node.aria2RPC.Init()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
node.lock.RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (node *MasterNode) ID() uint {
|
||||||
|
node.lock.RLock()
|
||||||
|
defer node.lock.RUnlock()
|
||||||
|
|
||||||
|
return node.Model.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (node *MasterNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
|
||||||
|
return &serializer.NodePingResp{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsFeatureEnabled 查询节点的某项功能是否启用
|
||||||
|
func (node *MasterNode) IsFeatureEnabled(feature string) bool {
|
||||||
|
node.lock.RLock()
|
||||||
|
defer node.lock.RUnlock()
|
||||||
|
|
||||||
|
switch feature {
|
||||||
|
case "aria2":
|
||||||
|
return node.Model.Aria2Enabled
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (node *MasterNode) MasterAuthInstance() auth.Auth {
|
||||||
|
node.lock.RLock()
|
||||||
|
defer node.lock.RUnlock()
|
||||||
|
|
||||||
|
return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (node *MasterNode) SlaveAuthInstance() auth.Auth {
|
||||||
|
node.lock.RLock()
|
||||||
|
defer node.lock.RUnlock()
|
||||||
|
|
||||||
|
return auth.HMACAuth{SecretKey: []byte(node.Model.SlaveKey)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SubscribeStatusChange 订阅节点状态更改
|
||||||
|
func (node *MasterNode) SubscribeStatusChange(callback func(isActive bool, id uint)) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsActive 返回节点是否在线
|
||||||
|
func (node *MasterNode) IsActive() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Kill 结束aria2请求
|
||||||
|
func (node *MasterNode) Kill() {
|
||||||
|
if node.aria2RPC.Caller != nil {
|
||||||
|
node.aria2RPC.Caller.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAria2Instance 获取主机Aria2实例
|
||||||
|
func (node *MasterNode) GetAria2Instance() common.Aria2 {
|
||||||
|
node.lock.RLock()
|
||||||
|
|
||||||
|
if !node.Model.Aria2Enabled {
|
||||||
|
node.lock.RUnlock()
|
||||||
|
return &common.DummyAria2{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !node.aria2RPC.Initialized {
|
||||||
|
node.lock.RUnlock()
|
||||||
|
node.aria2RPC.Init()
|
||||||
|
return &common.DummyAria2{}
|
||||||
|
}
|
||||||
|
|
||||||
|
defer node.lock.RUnlock()
|
||||||
|
return &node.aria2RPC
|
||||||
|
}
|
||||||
|
|
||||||
|
func (node *MasterNode) IsMater() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (node *MasterNode) DBModel() *model.Node {
|
||||||
|
node.lock.RLock()
|
||||||
|
defer node.lock.RUnlock()
|
||||||
|
|
||||||
|
return node.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rpcService) Init() error {
|
||||||
|
r.parent.lock.Lock()
|
||||||
|
defer r.parent.lock.Unlock()
|
||||||
|
r.Initialized = false
|
||||||
|
|
||||||
|
// 客户端已存在,则关闭先前连接
|
||||||
|
if r.Caller != nil {
|
||||||
|
r.Caller.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析RPC服务地址
|
||||||
|
server, err := url.Parse(r.parent.Model.Aria2OptionsSerialized.Server)
|
||||||
|
if err != nil {
|
||||||
|
util.Log().Warning("无法解析主机 Aria2 RPC 服务地址,%s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
server.Path = "/jsonrpc"
|
||||||
|
|
||||||
|
// 加载自定义下载配置
|
||||||
|
var globalOptions map[string]interface{}
|
||||||
|
if r.parent.Model.Aria2OptionsSerialized.Options != "" {
|
||||||
|
err = json.Unmarshal([]byte(r.parent.Model.Aria2OptionsSerialized.Options), &globalOptions)
|
||||||
|
if err != nil {
|
||||||
|
util.Log().Warning("无法解析主机 Aria2 配置,%s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
r.options = &clientOptions{
|
||||||
|
Options: globalOptions,
|
||||||
|
}
|
||||||
|
timeout := r.parent.Model.Aria2OptionsSerialized.Timeout
|
||||||
|
caller, err := rpc.New(context.Background(), server.String(), r.parent.Model.Aria2OptionsSerialized.Token, time.Duration(timeout)*time.Second, mq.GlobalMQ)
|
||||||
|
|
||||||
|
r.Caller = caller
|
||||||
|
r.Initialized = err == nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rpcService) CreateTask(task *model.Download, groupOptions map[string]interface{}) (string, error) {
|
||||||
|
r.parent.lock.RLock()
|
||||||
|
// 生成存储路径
|
||||||
|
guid, _ := uuid.NewV4()
|
||||||
|
path := filepath.Join(
|
||||||
|
r.parent.Model.Aria2OptionsSerialized.TempPath,
|
||||||
|
"aria2",
|
||||||
|
guid.String(),
|
||||||
|
)
|
||||||
|
r.parent.lock.RUnlock()
|
||||||
|
|
||||||
|
// 创建下载任务
|
||||||
|
options := map[string]interface{}{
|
||||||
|
"dir": path,
|
||||||
|
}
|
||||||
|
for k, v := range r.options.Options {
|
||||||
|
options[k] = v
|
||||||
|
}
|
||||||
|
for k, v := range groupOptions {
|
||||||
|
options[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
gid, err := r.Caller.AddURI(task.Source, options)
|
||||||
|
if err != nil || gid == "" {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return gid, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rpcService) Status(task *model.Download) (rpc.StatusInfo, error) {
|
||||||
|
res, err := r.Caller.TellStatus(task.GID)
|
||||||
|
if err != nil {
|
||||||
|
// 失败后重试
|
||||||
|
util.Log().Debug("无法获取离线下载状态,%s,10秒钟后重试", err)
|
||||||
|
time.Sleep(time.Duration(10) * time.Second)
|
||||||
|
res, err = r.Caller.TellStatus(task.GID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rpcService) Cancel(task *model.Download) error {
|
||||||
|
// 取消下载任务
|
||||||
|
_, err := r.Caller.Remove(task.GID)
|
||||||
|
if err != nil {
|
||||||
|
util.Log().Warning("无法取消离线下载任务[%s], %s", task.GID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rpcService) Select(task *model.Download, files []int) error {
|
||||||
|
var selected = make([]string, len(files))
|
||||||
|
for i := 0; i < len(files); i++ {
|
||||||
|
selected[i] = strconv.Itoa(files[i])
|
||||||
|
}
|
||||||
|
_, err := r.Caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rpcService) GetConfig() model.Aria2Option {
|
||||||
|
r.parent.lock.RLock()
|
||||||
|
defer r.parent.lock.RUnlock()
|
||||||
|
|
||||||
|
return r.parent.Model.Aria2OptionsSerialized
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *rpcService) DeleteTempFile(task *model.Download) error {
|
||||||
|
s.parent.lock.RLock()
|
||||||
|
defer s.parent.lock.RUnlock()
|
||||||
|
|
||||||
|
// 避免被aria2占用,异步执行删除
|
||||||
|
go func(src string) {
|
||||||
|
time.Sleep(deleteTempFileDuration)
|
||||||
|
err := os.RemoveAll(src)
|
||||||
|
if err != nil {
|
||||||
|
util.Log().Warning("无法删除离线下载临时目录[%s], %s", src, err)
|
||||||
|
}
|
||||||
|
}(task.Parent)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,60 @@
|
||||||
|
package cluster
|
||||||
|
|
||||||
|
import (
|
||||||
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Node interface {
|
||||||
|
// Init a node from database model
|
||||||
|
Init(node *model.Node)
|
||||||
|
|
||||||
|
// Check if given feature is enabled
|
||||||
|
IsFeatureEnabled(feature string) bool
|
||||||
|
|
||||||
|
// Subscribe node status change to a callback function
|
||||||
|
SubscribeStatusChange(callback func(isActive bool, id uint))
|
||||||
|
|
||||||
|
// Ping the node
|
||||||
|
Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error)
|
||||||
|
|
||||||
|
// Returns if the node is active
|
||||||
|
IsActive() bool
|
||||||
|
|
||||||
|
// Get instances for aria2 calls
|
||||||
|
GetAria2Instance() common.Aria2
|
||||||
|
|
||||||
|
// Returns unique id of this node
|
||||||
|
ID() uint
|
||||||
|
|
||||||
|
// Kill node and recycle resources
|
||||||
|
Kill()
|
||||||
|
|
||||||
|
// Returns if current node is master node
|
||||||
|
IsMater() bool
|
||||||
|
|
||||||
|
// Get auth instance used to check RPC call from slave to master
|
||||||
|
MasterAuthInstance() auth.Auth
|
||||||
|
|
||||||
|
// Get auth instance used to check RPC call from master to slave
|
||||||
|
SlaveAuthInstance() auth.Auth
|
||||||
|
|
||||||
|
// Get node DB model
|
||||||
|
DBModel() *model.Node
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new node from DB model
|
||||||
|
func NewNodeFromDBModel(node *model.Node) Node {
|
||||||
|
switch node.Type {
|
||||||
|
case model.SlaveNodeType:
|
||||||
|
slave := &SlaveNode{}
|
||||||
|
slave.Init(node)
|
||||||
|
return slave
|
||||||
|
default:
|
||||||
|
master := &MasterNode{}
|
||||||
|
master.Init(node)
|
||||||
|
return master
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,176 @@
|
||||||
|
package cluster
|
||||||
|
|
||||||
|
import (
|
||||||
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
var Default *NodePool
|
||||||
|
|
||||||
|
// 需要分类的节点组
|
||||||
|
var featureGroup = []string{"aria2"}
|
||||||
|
|
||||||
|
// Pool 节点池
|
||||||
|
type Pool interface {
|
||||||
|
// Returns active node selected by given feature and load balancer
|
||||||
|
BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, Node)
|
||||||
|
|
||||||
|
// Returns node by ID
|
||||||
|
GetNodeByID(id uint) Node
|
||||||
|
|
||||||
|
// Add given node into pool. If node existed, refresh node.
|
||||||
|
Add(node *model.Node)
|
||||||
|
|
||||||
|
// Delete and kill node from pool by given node id
|
||||||
|
Delete(id uint)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NodePool 通用节点池
|
||||||
|
type NodePool struct {
|
||||||
|
active map[uint]Node
|
||||||
|
inactive map[uint]Node
|
||||||
|
|
||||||
|
featureMap map[string][]Node
|
||||||
|
|
||||||
|
lock sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init 初始化从机节点池
|
||||||
|
func Init() {
|
||||||
|
Default = &NodePool{
|
||||||
|
featureMap: make(map[string][]Node),
|
||||||
|
}
|
||||||
|
if err := Default.initFromDB(); err != nil {
|
||||||
|
util.Log().Warning("节点池初始化失败, %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pool *NodePool) buildIndexMap() {
|
||||||
|
pool.lock.Lock()
|
||||||
|
for _, feature := range featureGroup {
|
||||||
|
pool.featureMap[feature] = make([]Node, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range pool.active {
|
||||||
|
for _, feature := range featureGroup {
|
||||||
|
if v.IsFeatureEnabled(feature) {
|
||||||
|
pool.featureMap[feature] = append(pool.featureMap[feature], v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pool.lock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pool *NodePool) GetNodeByID(id uint) Node {
|
||||||
|
pool.lock.RLock()
|
||||||
|
defer pool.lock.RUnlock()
|
||||||
|
|
||||||
|
if node, ok := pool.active[id]; ok {
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
return pool.inactive[id]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pool *NodePool) nodeStatusChange(isActive bool, id uint) {
|
||||||
|
util.Log().Debug("从机节点 [ID=%d] 状态变更 [Active=%t]", id, isActive)
|
||||||
|
pool.lock.Lock()
|
||||||
|
if isActive {
|
||||||
|
node := pool.inactive[id]
|
||||||
|
delete(pool.inactive, id)
|
||||||
|
pool.active[id] = node
|
||||||
|
} else {
|
||||||
|
node := pool.active[id]
|
||||||
|
delete(pool.active, id)
|
||||||
|
pool.inactive[id] = node
|
||||||
|
}
|
||||||
|
pool.lock.Unlock()
|
||||||
|
|
||||||
|
pool.buildIndexMap()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pool *NodePool) initFromDB() error {
|
||||||
|
nodes, err := model.GetNodesByStatus(model.NodeActive)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
pool.lock.Lock()
|
||||||
|
pool.active = make(map[uint]Node)
|
||||||
|
pool.inactive = make(map[uint]Node)
|
||||||
|
for i := 0; i < len(nodes); i++ {
|
||||||
|
pool.add(&nodes[i])
|
||||||
|
}
|
||||||
|
pool.lock.Unlock()
|
||||||
|
|
||||||
|
pool.buildIndexMap()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pool *NodePool) add(node *model.Node) {
|
||||||
|
newNode := NewNodeFromDBModel(node)
|
||||||
|
if newNode.IsActive() {
|
||||||
|
pool.active[node.ID] = newNode
|
||||||
|
} else {
|
||||||
|
pool.inactive[node.ID] = newNode
|
||||||
|
}
|
||||||
|
|
||||||
|
// 订阅节点状态变更
|
||||||
|
newNode.SubscribeStatusChange(func(isActive bool, id uint) {
|
||||||
|
pool.nodeStatusChange(isActive, id)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pool *NodePool) Add(node *model.Node) {
|
||||||
|
pool.lock.Lock()
|
||||||
|
defer pool.buildIndexMap()
|
||||||
|
defer pool.lock.Unlock()
|
||||||
|
|
||||||
|
if _, ok := pool.active[node.ID]; ok {
|
||||||
|
// TODO: refresh node
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := pool.inactive[node.ID]; ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pool.add(node)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pool *NodePool) Delete(id uint) {
|
||||||
|
pool.lock.Lock()
|
||||||
|
defer pool.buildIndexMap()
|
||||||
|
defer pool.lock.Unlock()
|
||||||
|
|
||||||
|
if node, ok := pool.active[id]; ok {
|
||||||
|
node.Kill()
|
||||||
|
delete(pool.active, id)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if node, ok := pool.inactive[id]; ok {
|
||||||
|
node.Kill()
|
||||||
|
delete(pool.inactive, id)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNodeByFeature 根据 feature 和 LoadBalancer 取出节点
|
||||||
|
func (pool *NodePool) BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, Node) {
|
||||||
|
pool.lock.RLock()
|
||||||
|
defer pool.lock.RUnlock()
|
||||||
|
if nodes, ok := pool.featureMap[feature]; ok {
|
||||||
|
err, res := lb.NextPeer(nodes)
|
||||||
|
if err == nil {
|
||||||
|
return nil, res.(Node)
|
||||||
|
}
|
||||||
|
|
||||||
|
return err, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return ErrFeatureNotExist, nil
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,405 @@
|
||||||
|
package cluster
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||||
|
"io"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SlaveNode struct {
|
||||||
|
Model *model.Node
|
||||||
|
Active bool
|
||||||
|
|
||||||
|
caller slaveCaller
|
||||||
|
callback func(bool, uint)
|
||||||
|
close chan bool
|
||||||
|
lock sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
type slaveCaller struct {
|
||||||
|
parent *SlaveNode
|
||||||
|
Client request.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init 初始化节点
|
||||||
|
func (node *SlaveNode) Init(nodeModel *model.Node) {
|
||||||
|
node.lock.Lock()
|
||||||
|
defer node.lock.Unlock()
|
||||||
|
node.Model = nodeModel
|
||||||
|
|
||||||
|
// Init http request client
|
||||||
|
var endpoint *url.URL
|
||||||
|
if serverURL, err := url.Parse(node.Model.Server); err == nil {
|
||||||
|
var controller *url.URL
|
||||||
|
controller, _ = url.Parse("/api/v3/slave")
|
||||||
|
endpoint = serverURL.ResolveReference(controller)
|
||||||
|
}
|
||||||
|
|
||||||
|
signTTL := model.GetIntSetting("slave_api_timeout", 60)
|
||||||
|
node.caller.Client = request.NewClient(
|
||||||
|
request.WithMasterMeta(),
|
||||||
|
request.WithTimeout(time.Duration(signTTL)*time.Second),
|
||||||
|
request.WithCredential(auth.HMACAuth{SecretKey: []byte(nodeModel.SlaveKey)}, int64(signTTL)),
|
||||||
|
request.WithEndpoint(endpoint.String()),
|
||||||
|
)
|
||||||
|
|
||||||
|
node.caller.parent = node
|
||||||
|
node.Active = true
|
||||||
|
if node.close != nil {
|
||||||
|
node.close <- true
|
||||||
|
}
|
||||||
|
|
||||||
|
go node.StartPingLoop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsFeatureEnabled 查询节点的某项功能是否启用
|
||||||
|
func (node *SlaveNode) IsFeatureEnabled(feature string) bool {
|
||||||
|
node.lock.RLock()
|
||||||
|
defer node.lock.RUnlock()
|
||||||
|
|
||||||
|
switch feature {
|
||||||
|
case "aria2":
|
||||||
|
return node.Model.Aria2Enabled
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SubscribeStatusChange 订阅节点状态更改
|
||||||
|
func (node *SlaveNode) SubscribeStatusChange(callback func(bool, uint)) {
|
||||||
|
node.lock.Lock()
|
||||||
|
node.callback = callback
|
||||||
|
node.lock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ping 从机节点,返回从机负载
|
||||||
|
func (node *SlaveNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
|
||||||
|
reqBodyEncoded, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyReader := strings.NewReader(string(reqBodyEncoded))
|
||||||
|
|
||||||
|
resp, err := node.caller.Client.Request(
|
||||||
|
"POST",
|
||||||
|
"heartbeat",
|
||||||
|
bodyReader,
|
||||||
|
).CheckHTTPResponse(200).DecodeResponse()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理列取结果
|
||||||
|
if resp.Code != 0 {
|
||||||
|
return nil, serializer.NewErrorFromResponse(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
var res serializer.NodePingResp
|
||||||
|
|
||||||
|
if resStr, ok := resp.Data.(string); ok {
|
||||||
|
err = json.Unmarshal([]byte(resStr), &res)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsActive 返回节点是否在线
|
||||||
|
func (node *SlaveNode) IsActive() bool {
|
||||||
|
node.lock.RLock()
|
||||||
|
defer node.lock.RUnlock()
|
||||||
|
|
||||||
|
return node.Active
|
||||||
|
}
|
||||||
|
|
||||||
|
// Kill 结束节点内相关循环
|
||||||
|
func (node *SlaveNode) Kill() {
|
||||||
|
node.lock.RLock()
|
||||||
|
defer node.lock.RUnlock()
|
||||||
|
|
||||||
|
if node.close != nil {
|
||||||
|
close(node.close)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAria2Instance 获取从机Aria2实例
|
||||||
|
func (node *SlaveNode) GetAria2Instance() common.Aria2 {
|
||||||
|
node.lock.RLock()
|
||||||
|
defer node.lock.RUnlock()
|
||||||
|
|
||||||
|
if !node.Model.Aria2Enabled {
|
||||||
|
return &common.DummyAria2{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &node.caller
|
||||||
|
}
|
||||||
|
|
||||||
|
func (node *SlaveNode) ID() uint {
|
||||||
|
node.lock.RLock()
|
||||||
|
defer node.lock.RUnlock()
|
||||||
|
|
||||||
|
return node.Model.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (node *SlaveNode) StartPingLoop() {
|
||||||
|
node.lock.Lock()
|
||||||
|
node.close = make(chan bool)
|
||||||
|
node.lock.Unlock()
|
||||||
|
|
||||||
|
tickDuration := time.Duration(model.GetIntSetting("slave_ping_interval", 300)) * time.Second
|
||||||
|
recoverDuration := time.Duration(model.GetIntSetting("slave_recover_interval", 600)) * time.Second
|
||||||
|
pingTicker := time.Duration(0)
|
||||||
|
|
||||||
|
util.Log().Debug("从机节点 [%s] 启动心跳循环", node.Model.Name)
|
||||||
|
retry := 0
|
||||||
|
recoverMode := false
|
||||||
|
isFirstLoop := true
|
||||||
|
|
||||||
|
loop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-time.After(pingTicker):
|
||||||
|
if pingTicker == 0 {
|
||||||
|
pingTicker = tickDuration
|
||||||
|
}
|
||||||
|
|
||||||
|
util.Log().Debug("从机节点 [%s] 发送Ping", node.Model.Name)
|
||||||
|
res, err := node.Ping(node.getHeartbeatContent(isFirstLoop))
|
||||||
|
isFirstLoop = false
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
util.Log().Debug("Ping从机节点 [%s] 时发生错误: %s", node.Model.Name, err)
|
||||||
|
retry++
|
||||||
|
if retry >= model.GetIntSetting("slave_node_retry", 3) {
|
||||||
|
util.Log().Debug("从机节点 [%s] Ping 重试已达到最大限制,将从机节点标记为不可用", node.Model.Name)
|
||||||
|
node.changeStatus(false)
|
||||||
|
|
||||||
|
if !recoverMode {
|
||||||
|
// 启动恢复监控循环
|
||||||
|
util.Log().Debug("从机节点 [%s] 进入恢复模式", node.Model.Name)
|
||||||
|
pingTicker = recoverDuration
|
||||||
|
recoverMode = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if recoverMode {
|
||||||
|
util.Log().Debug("从机节点 [%s] 复活", node.Model.Name)
|
||||||
|
pingTicker = tickDuration
|
||||||
|
recoverMode = false
|
||||||
|
isFirstLoop = true
|
||||||
|
}
|
||||||
|
|
||||||
|
util.Log().Debug("从机节点 [%s] 状态: %s", node.Model.Name, res)
|
||||||
|
node.changeStatus(true)
|
||||||
|
retry = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-node.close:
|
||||||
|
util.Log().Debug("从机节点 [%s] 收到关闭信号", node.Model.Name)
|
||||||
|
break loop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (node *SlaveNode) IsMater() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (node *SlaveNode) MasterAuthInstance() auth.Auth {
|
||||||
|
node.lock.RLock()
|
||||||
|
defer node.lock.RUnlock()
|
||||||
|
|
||||||
|
return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (node *SlaveNode) SlaveAuthInstance() auth.Auth {
|
||||||
|
node.lock.RLock()
|
||||||
|
defer node.lock.RUnlock()
|
||||||
|
|
||||||
|
return auth.HMACAuth{SecretKey: []byte(node.Model.SlaveKey)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (node *SlaveNode) DBModel() *model.Node {
|
||||||
|
node.lock.RLock()
|
||||||
|
defer node.lock.RUnlock()
|
||||||
|
|
||||||
|
return node.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
// getHeartbeatContent gets serializer.NodePingReq used to send heartbeat to slave
|
||||||
|
func (node *SlaveNode) getHeartbeatContent(isUpdate bool) *serializer.NodePingReq {
|
||||||
|
return &serializer.NodePingReq{
|
||||||
|
SiteURL: model.GetSiteURL().String(),
|
||||||
|
IsUpdate: isUpdate,
|
||||||
|
SiteID: model.GetSettingByName("siteID"),
|
||||||
|
Node: node.Model,
|
||||||
|
CredentialTTL: model.GetIntSetting("slave_api_timeout", 60),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (node *SlaveNode) changeStatus(isActive bool) {
|
||||||
|
node.lock.RLock()
|
||||||
|
id := node.Model.ID
|
||||||
|
if isActive != node.Active {
|
||||||
|
node.lock.RUnlock()
|
||||||
|
node.lock.Lock()
|
||||||
|
node.Active = isActive
|
||||||
|
node.lock.Unlock()
|
||||||
|
node.callback(isActive, id)
|
||||||
|
} else {
|
||||||
|
node.lock.RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *slaveCaller) Init() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendAria2Call send remote aria2 call to slave node
|
||||||
|
func (s *slaveCaller) SendAria2Call(body *serializer.SlaveAria2Call, scope string) (*serializer.Response, error) {
|
||||||
|
reqReader, err := getAria2RequestBody(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.Client.Request(
|
||||||
|
"POST",
|
||||||
|
"aria2/"+scope,
|
||||||
|
reqReader,
|
||||||
|
).CheckHTTPResponse(200).DecodeResponse()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *slaveCaller) CreateTask(task *model.Download, options map[string]interface{}) (string, error) {
|
||||||
|
s.parent.lock.RLock()
|
||||||
|
defer s.parent.lock.RUnlock()
|
||||||
|
|
||||||
|
req := &serializer.SlaveAria2Call{
|
||||||
|
Task: task,
|
||||||
|
GroupOptions: options,
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := s.SendAria2Call(req, "task")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Code != 0 {
|
||||||
|
return "", serializer.NewErrorFromResponse(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res.Data.(string), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *slaveCaller) Status(task *model.Download) (rpc.StatusInfo, error) {
|
||||||
|
s.parent.lock.RLock()
|
||||||
|
defer s.parent.lock.RUnlock()
|
||||||
|
|
||||||
|
req := &serializer.SlaveAria2Call{
|
||||||
|
Task: task,
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := s.SendAria2Call(req, "status")
|
||||||
|
if err != nil {
|
||||||
|
return rpc.StatusInfo{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Code != 0 {
|
||||||
|
return rpc.StatusInfo{}, serializer.NewErrorFromResponse(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
var status rpc.StatusInfo
|
||||||
|
res.GobDecode(&status)
|
||||||
|
|
||||||
|
return status, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *slaveCaller) Cancel(task *model.Download) error {
|
||||||
|
s.parent.lock.RLock()
|
||||||
|
defer s.parent.lock.RUnlock()
|
||||||
|
|
||||||
|
req := &serializer.SlaveAria2Call{
|
||||||
|
Task: task,
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := s.SendAria2Call(req, "cancel")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Code != 0 {
|
||||||
|
return serializer.NewErrorFromResponse(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *slaveCaller) Select(task *model.Download, files []int) error {
|
||||||
|
s.parent.lock.RLock()
|
||||||
|
defer s.parent.lock.RUnlock()
|
||||||
|
|
||||||
|
req := &serializer.SlaveAria2Call{
|
||||||
|
Task: task,
|
||||||
|
Files: files,
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := s.SendAria2Call(req, "select")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Code != 0 {
|
||||||
|
return serializer.NewErrorFromResponse(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *slaveCaller) GetConfig() model.Aria2Option {
|
||||||
|
s.parent.lock.RLock()
|
||||||
|
defer s.parent.lock.RUnlock()
|
||||||
|
|
||||||
|
return s.parent.Model.Aria2OptionsSerialized
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *slaveCaller) DeleteTempFile(task *model.Download) error {
|
||||||
|
s.parent.lock.RLock()
|
||||||
|
defer s.parent.lock.RUnlock()
|
||||||
|
|
||||||
|
req := &serializer.SlaveAria2Call{
|
||||||
|
Task: task,
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := s.SendAria2Call(req, "delete")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Code != 0 {
|
||||||
|
return serializer.NewErrorFromResponse(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getAria2RequestBody(body *serializer.SlaveAria2Call) (io.Reader, error) {
|
||||||
|
reqBodyEncoded, err := json.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.NewReader(string(reqBodyEncoded)), nil
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,39 @@
|
||||||
|
package driver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||||
|
"io"
|
||||||
|
"net/url"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Handler 存储策略适配器
|
||||||
|
type Handler interface {
|
||||||
|
// 上传文件, dst为文件存储路径,size 为文件大小。上下文关闭
|
||||||
|
// 时,应取消上传并清理临时文件
|
||||||
|
Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error
|
||||||
|
|
||||||
|
// 删除一个或多个给定路径的文件,返回删除失败的文件路径列表及错误
|
||||||
|
Delete(ctx context.Context, files []string) ([]string, error)
|
||||||
|
|
||||||
|
// 获取文件内容
|
||||||
|
Get(ctx context.Context, path string) (response.RSCloser, error)
|
||||||
|
|
||||||
|
// 获取缩略图,可直接在ContentResponse中返回文件数据流,也可指
|
||||||
|
// 定为重定向
|
||||||
|
Thumb(ctx context.Context, path string) (*response.ContentResponse, error)
|
||||||
|
|
||||||
|
// 获取外链/下载地址,
|
||||||
|
// url - 站点本身地址,
|
||||||
|
// isDownload - 是否直接下载
|
||||||
|
Source(ctx context.Context, path string, url url.URL, ttl int64, isDownload bool, speed int) (string, error)
|
||||||
|
|
||||||
|
// Token 获取有效期为ttl的上传凭证和签名,同时回调会话有效期为sessionTTL
|
||||||
|
Token(ctx context.Context, ttl int64, callbackKey string) (serializer.UploadCredential, error)
|
||||||
|
|
||||||
|
// List 递归列取远程端path路径下文件、目录,不包含path本身,
|
||||||
|
// 返回的对象路径以path作为起始根目录.
|
||||||
|
// recursive - 是否递归列出
|
||||||
|
List(ctx context.Context, path string, recursive bool) ([]response.Object, error)
|
||||||
|
}
|
||||||
|
|
@ -55,7 +55,7 @@ func NewClient(policy *model.Policy) (*Client, error) {
|
||||||
ClientID: policy.BucketName,
|
ClientID: policy.BucketName,
|
||||||
ClientSecret: policy.SecretKey,
|
ClientSecret: policy.SecretKey,
|
||||||
Redirect: policy.OptionsSerialized.OdRedirect,
|
Redirect: policy.OptionsSerialized.OdRedirect,
|
||||||
Request: request.HTTPClient{},
|
Request: request.NewClient(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if client.Endpoints.DriverResource == "" {
|
if client.Endpoints.DriverResource == "" {
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ import (
|
||||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||||
|
|
@ -27,6 +28,16 @@ type Driver struct {
|
||||||
HTTPClient request.Client
|
HTTPClient request.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewDriver 从存储策略初始化新的Driver实例
|
||||||
|
func NewDriver(policy *model.Policy) (driver.Handler, error) {
|
||||||
|
client, err := NewClient(policy)
|
||||||
|
return Driver{
|
||||||
|
Policy: policy,
|
||||||
|
Client: client,
|
||||||
|
HTTPClient: request.NewClient(),
|
||||||
|
}, err
|
||||||
|
}
|
||||||
|
|
||||||
// List 列取项目
|
// List 列取项目
|
||||||
func (handler Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
|
func (handler Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
|
||||||
base = strings.TrimPrefix(base, "/")
|
base = strings.TrimPrefix(base, "/")
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,25 @@
|
||||||
|
package onedrive
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
|
// CredentialLock 针对存储策略凭证的锁
|
||||||
|
type CredentialLock interface {
|
||||||
|
Lock(uint)
|
||||||
|
Unlock(uint)
|
||||||
|
}
|
||||||
|
|
||||||
|
var GlobalMutex = mutexMap{}
|
||||||
|
|
||||||
|
type mutexMap struct {
|
||||||
|
locks sync.Map
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mutexMap) Lock(id uint) {
|
||||||
|
lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{})
|
||||||
|
lock.(*sync.Mutex).Lock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mutexMap) Unlock(id uint) {
|
||||||
|
lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{})
|
||||||
|
lock.(*sync.Mutex).Unlock()
|
||||||
|
}
|
||||||
|
|
@ -10,7 +10,9 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -124,6 +126,13 @@ func (client *Client) ObtainToken(ctx context.Context, opts ...Option) (*Credent
|
||||||
|
|
||||||
// UpdateCredential 更新凭证,并检查有效期
|
// UpdateCredential 更新凭证,并检查有效期
|
||||||
func (client *Client) UpdateCredential(ctx context.Context) error {
|
func (client *Client) UpdateCredential(ctx context.Context) error {
|
||||||
|
if conf.SystemConfig.Mode == "slave" {
|
||||||
|
return client.fetchCredentialFromMaster(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
GlobalMutex.Lock(client.Policy.ID)
|
||||||
|
defer GlobalMutex.Unlock(client.Policy.ID)
|
||||||
|
|
||||||
// 如果已存在凭证
|
// 如果已存在凭证
|
||||||
if client.Credential != nil && client.Credential.AccessToken != "" {
|
if client.Credential != nil && client.Credential.AccessToken != "" {
|
||||||
// 检查已有凭证是否过期
|
// 检查已有凭证是否过期
|
||||||
|
|
@ -160,11 +169,21 @@ func (client *Client) UpdateCredential(ctx context.Context) error {
|
||||||
client.Credential = credential
|
client.Credential = credential
|
||||||
|
|
||||||
// 更新存储策略的 RefreshToken
|
// 更新存储策略的 RefreshToken
|
||||||
client.Policy.AccessKey = credential.RefreshToken
|
client.Policy.UpdateAccessKeyAndClearCache(credential.RefreshToken)
|
||||||
client.Policy.SaveAndClearCache()
|
|
||||||
|
|
||||||
// 更新缓存
|
// 更新缓存
|
||||||
cache.Set("onedrive_"+client.ClientID, *credential, int(expires))
|
cache.Set("onedrive_"+client.ClientID, *credential, int(expires))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateCredential 更新凭证,并检查有效期
|
||||||
|
func (client *Client) fetchCredentialFromMaster(ctx context.Context) error {
|
||||||
|
res, err := slave.DefaultController.GetOneDriveToken(client.Policy.MasterID, client.Policy.ID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
client.Credential = &Credential{AccessToken: res}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,7 @@ func GetPublicKey(r *http.Request) ([]byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取公钥
|
// 获取公钥
|
||||||
client := request.HTTPClient{}
|
client := request.NewClient()
|
||||||
body, err := client.Request("GET", string(pubURL), nil).
|
body, err := client.Request("GET", string(pubURL), nil).
|
||||||
CheckHTTPResponse(200).
|
CheckHTTPResponse(200).
|
||||||
GetResponse()
|
GetResponse()
|
||||||
|
|
|
||||||
|
|
@ -292,7 +292,7 @@ func TestDriver_Get(t *testing.T) {
|
||||||
BucketName: "test",
|
BucketName: "test",
|
||||||
Server: "oss-cn-shanghai.aliyuncs.com",
|
Server: "oss-cn-shanghai.aliyuncs.com",
|
||||||
},
|
},
|
||||||
HTTPClient: request.HTTPClient{},
|
HTTPClient: request.NewClient(),
|
||||||
}
|
}
|
||||||
cache.Set("setting_preview_timeout", "3600", 0)
|
cache.Set("setting_preview_timeout", "3600", 0)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,7 @@ func (handler Driver) List(ctx context.Context, path string, recursive bool) ([]
|
||||||
handler.getAPIUrl("list"),
|
handler.getAPIUrl("list"),
|
||||||
bodyReader,
|
bodyReader,
|
||||||
request.WithCredential(handler.AuthInstance, int64(signTTL)),
|
request.WithCredential(handler.AuthInstance, int64(signTTL)),
|
||||||
|
request.WithMasterMeta(),
|
||||||
).CheckHTTPResponse(200).DecodeResponse()
|
).CheckHTTPResponse(200).DecodeResponse()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return res, err
|
return res, err
|
||||||
|
|
@ -97,7 +98,7 @@ func (handler Driver) getAPIUrl(scope string, routes ...string) string {
|
||||||
|
|
||||||
// Get 获取文件内容
|
// Get 获取文件内容
|
||||||
func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
|
func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
|
||||||
// 尝试获取速度限制 TODO 是否需要在这里限制?
|
// 尝试获取速度限制
|
||||||
speedLimit := 0
|
speedLimit := 0
|
||||||
if user, ok := ctx.Value(fsctx.UserCtx).(model.User); ok {
|
if user, ok := ctx.Value(fsctx.UserCtx).(model.User); ok {
|
||||||
speedLimit = user.Group.SpeedLimit
|
speedLimit = user.Group.SpeedLimit
|
||||||
|
|
@ -116,6 +117,7 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser,
|
||||||
nil,
|
nil,
|
||||||
request.WithContext(ctx),
|
request.WithContext(ctx),
|
||||||
request.WithTimeout(time.Duration(0)),
|
request.WithTimeout(time.Duration(0)),
|
||||||
|
request.WithMasterMeta(),
|
||||||
).CheckHTTPResponse(200).GetRSCloser()
|
).CheckHTTPResponse(200).GetRSCloser()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -168,13 +170,15 @@ func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, s
|
||||||
handler.Policy.GetUploadURL(),
|
handler.Policy.GetUploadURL(),
|
||||||
file,
|
file,
|
||||||
request.WithHeader(map[string][]string{
|
request.WithHeader(map[string][]string{
|
||||||
"Authorization": {credential.Token},
|
|
||||||
"X-Policy": {credential.Policy},
|
"X-Policy": {credential.Policy},
|
||||||
"X-FileName": {fileName},
|
"X-FileName": {fileName},
|
||||||
"X-Overwrite": {overwrite},
|
"X-Overwrite": {overwrite},
|
||||||
}),
|
}),
|
||||||
request.WithContentLength(int64(size)),
|
request.WithContentLength(int64(size)),
|
||||||
request.WithTimeout(time.Duration(0)),
|
request.WithTimeout(time.Duration(0)),
|
||||||
|
request.WithMasterMeta(),
|
||||||
|
request.WithSlaveMeta(handler.Policy.AccessKey),
|
||||||
|
request.WithCredential(handler.AuthInstance, int64(credentialTTL)),
|
||||||
).CheckHTTPResponse(200).DecodeResponse()
|
).CheckHTTPResponse(200).DecodeResponse()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -206,6 +210,8 @@ func (handler Driver) Delete(ctx context.Context, files []string) ([]string, err
|
||||||
handler.getAPIUrl("delete"),
|
handler.getAPIUrl("delete"),
|
||||||
bodyReader,
|
bodyReader,
|
||||||
request.WithCredential(handler.AuthInstance, int64(signTTL)),
|
request.WithCredential(handler.AuthInstance, int64(signTTL)),
|
||||||
|
request.WithMasterMeta(),
|
||||||
|
request.WithSlaveMeta(handler.Policy.AccessKey),
|
||||||
).CheckHTTPResponse(200).GetResponse()
|
).CheckHTTPResponse(200).GetResponse()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return files, err
|
return files, err
|
||||||
|
|
|
||||||
|
|
@ -172,7 +172,7 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser,
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取文件数据流
|
// 获取文件数据流
|
||||||
client := request.HTTPClient{}
|
client := request.NewClient()
|
||||||
resp, err := client.Request(
|
resp, err := client.Request(
|
||||||
"GET",
|
"GET",
|
||||||
downloadURL,
|
downloadURL,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
package masterinslave
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrNotImplemented = errors.New("this method of shadowed policy is not implemented")
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,56 @@
|
||||||
|
package masterinslave
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||||
|
"io"
|
||||||
|
"net/url"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Driver 影子存储策略,用于在从机端上传文件
|
||||||
|
type Driver struct {
|
||||||
|
master cluster.Node
|
||||||
|
handler driver.Handler
|
||||||
|
policy *model.Policy
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDriver 返回新的处理器
|
||||||
|
func NewDriver(master cluster.Node, handler driver.Handler, policy *model.Policy) driver.Handler {
|
||||||
|
return &Driver{
|
||||||
|
master: master,
|
||||||
|
handler: handler,
|
||||||
|
policy: policy,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Driver) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error {
|
||||||
|
return d.handler.Put(ctx, file, dst, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
|
||||||
|
return d.handler.Delete(ctx, files)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
|
||||||
|
return nil, ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Driver) Thumb(ctx context.Context, path string) (*response.ContentResponse, error) {
|
||||||
|
return nil, ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Driver) Source(ctx context.Context, path string, url url.URL, ttl int64, isDownload bool, speed int) (string, error) {
|
||||||
|
return "", ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Driver) Token(ctx context.Context, ttl int64, callbackKey string) (serializer.UploadCredential, error) {
|
||||||
|
return serializer.UploadCredential{}, ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) {
|
||||||
|
return nil, ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
package slaveinmaster
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrNotImplemented = errors.New("this method of shadowed policy is not implemented")
|
||||||
|
ErrSlaveSrcPathNotExist = errors.New("cannot determine source file path in slave node")
|
||||||
|
ErrWaitResultTimeout = errors.New("timeout waiting for slave transfer result")
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,121 @@
|
||||||
|
package slaveinmaster
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||||
|
"io"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Driver 影子存储策略,将上传任务指派给从机节点处理,并等待从机通知上传结果
|
||||||
|
type Driver struct {
|
||||||
|
node cluster.Node
|
||||||
|
handler driver.Handler
|
||||||
|
policy *model.Policy
|
||||||
|
client request.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDriver 返回新的从机指派处理器
|
||||||
|
func NewDriver(node cluster.Node, handler driver.Handler, policy *model.Policy) driver.Handler {
|
||||||
|
var endpoint *url.URL
|
||||||
|
if serverURL, err := url.Parse(node.DBModel().Server); err == nil {
|
||||||
|
var controller *url.URL
|
||||||
|
controller, _ = url.Parse("/api/v3/slave")
|
||||||
|
endpoint = serverURL.ResolveReference(controller)
|
||||||
|
}
|
||||||
|
|
||||||
|
signTTL := model.GetIntSetting("slave_api_timeout", 60)
|
||||||
|
return &Driver{
|
||||||
|
node: node,
|
||||||
|
handler: handler,
|
||||||
|
policy: policy,
|
||||||
|
client: request.NewClient(
|
||||||
|
request.WithMasterMeta(),
|
||||||
|
request.WithTimeout(time.Duration(signTTL)*time.Second),
|
||||||
|
request.WithCredential(node.SlaveAuthInstance(), int64(signTTL)),
|
||||||
|
request.WithEndpoint(endpoint.String()),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put 将ctx中指定的从机物理文件由从机上传到目标存储策略
|
||||||
|
func (d *Driver) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error {
|
||||||
|
src, ok := ctx.Value(fsctx.SlaveSrcPath).(string)
|
||||||
|
if !ok {
|
||||||
|
return ErrSlaveSrcPathNotExist
|
||||||
|
}
|
||||||
|
|
||||||
|
req := serializer.SlaveTransferReq{
|
||||||
|
Src: src,
|
||||||
|
Dst: dst,
|
||||||
|
Policy: d.policy,
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 订阅转存结果
|
||||||
|
resChan := mq.GlobalMQ.Subscribe(req.Hash(model.GetSettingByName("siteID")), 0)
|
||||||
|
defer mq.GlobalMQ.Unsubscribe(req.Hash(model.GetSettingByName("siteID")), resChan)
|
||||||
|
|
||||||
|
res, err := d.client.Request("PUT", "task/transfer", bytes.NewReader(body)).
|
||||||
|
CheckHTTPResponse(200).
|
||||||
|
DecodeResponse()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Code != 0 {
|
||||||
|
return serializer.NewErrorFromResponse(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 等待转存结果或者超时
|
||||||
|
waitTimeout := model.GetIntSetting("slave_transfer_timeout", 172800)
|
||||||
|
select {
|
||||||
|
case <-time.After(time.Duration(waitTimeout) * time.Second):
|
||||||
|
return ErrWaitResultTimeout
|
||||||
|
case msg := <-resChan:
|
||||||
|
if msg.Event != serializer.SlaveTransferSuccess {
|
||||||
|
return errors.New(msg.Content.(serializer.SlaveTransferResult).Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
|
||||||
|
return d.handler.Delete(ctx, files)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
|
||||||
|
return nil, ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Driver) Thumb(ctx context.Context, path string) (*response.ContentResponse, error) {
|
||||||
|
return nil, ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Driver) Source(ctx context.Context, path string, url url.URL, ttl int64, isDownload bool, speed int) (string, error) {
|
||||||
|
return "", ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Driver) Token(ctx context.Context, ttl int64, callbackKey string) (serializer.UploadCredential, error) {
|
||||||
|
return serializer.UploadCredential{}, ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) {
|
||||||
|
return nil, ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
@ -1,8 +1,12 @@
|
||||||
package filesystem
|
package filesystem
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/masterinslave"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/slaveinmaster"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
@ -19,7 +23,6 @@ import (
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/remote"
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/remote"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/s3"
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/s3"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/upyun"
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/upyun"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
|
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
|
@ -43,36 +46,6 @@ type FileHeader interface {
|
||||||
GetVirtualPath() string
|
GetVirtualPath() string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handler 存储策略适配器
|
|
||||||
type Handler interface {
|
|
||||||
// 上传文件, dst为文件存储路径,size 为文件大小。上下文关闭
|
|
||||||
// 时,应取消上传并清理临时文件
|
|
||||||
Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error
|
|
||||||
|
|
||||||
// 删除一个或多个给定路径的文件,返回删除失败的文件路径列表及错误
|
|
||||||
Delete(ctx context.Context, files []string) ([]string, error)
|
|
||||||
|
|
||||||
// 获取文件内容
|
|
||||||
Get(ctx context.Context, path string) (response.RSCloser, error)
|
|
||||||
|
|
||||||
// 获取缩略图,可直接在ContentResponse中返回文件数据流,也可指
|
|
||||||
// 定为重定向
|
|
||||||
Thumb(ctx context.Context, path string) (*response.ContentResponse, error)
|
|
||||||
|
|
||||||
// 获取外链/下载地址,
|
|
||||||
// url - 站点本身地址,
|
|
||||||
// isDownload - 是否直接下载
|
|
||||||
Source(ctx context.Context, path string, url url.URL, ttl int64, isDownload bool, speed int) (string, error)
|
|
||||||
|
|
||||||
// Token 获取有效期为ttl的上传凭证和签名,同时回调会话有效期为sessionTTL
|
|
||||||
Token(ctx context.Context, ttl int64, callbackKey string) (serializer.UploadCredential, error)
|
|
||||||
|
|
||||||
// List 递归列取远程端path路径下文件、目录,不包含path本身,
|
|
||||||
// 返回的对象路径以path作为起始根目录.
|
|
||||||
// recursive - 是否递归列出
|
|
||||||
List(ctx context.Context, path string, recursive bool) ([]response.Object, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileSystem 管理文件的文件系统
|
// FileSystem 管理文件的文件系统
|
||||||
type FileSystem struct {
|
type FileSystem struct {
|
||||||
// 文件系统所有者
|
// 文件系统所有者
|
||||||
|
|
@ -96,7 +69,7 @@ type FileSystem struct {
|
||||||
/*
|
/*
|
||||||
文件系统处理适配器
|
文件系统处理适配器
|
||||||
*/
|
*/
|
||||||
Handler Handler
|
Handler driver.Handler
|
||||||
|
|
||||||
// 回收锁
|
// 回收锁
|
||||||
recycleLock sync.Mutex
|
recycleLock sync.Mutex
|
||||||
|
|
@ -134,7 +107,6 @@ func NewFileSystem(user *model.User) (*FileSystem, error) {
|
||||||
// 分配存储策略适配器
|
// 分配存储策略适配器
|
||||||
err := fs.DispatchHandler()
|
err := fs.DispatchHandler()
|
||||||
|
|
||||||
// TODO 分配默认钩子
|
|
||||||
return fs, err
|
return fs, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -159,7 +131,6 @@ func NewAnonymousFileSystem() (*FileSystem, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DispatchHandler 根据存储策略分配文件适配器
|
// DispatchHandler 根据存储策略分配文件适配器
|
||||||
// TODO 完善测试
|
|
||||||
func (fs *FileSystem) DispatchHandler() error {
|
func (fs *FileSystem) DispatchHandler() error {
|
||||||
var policyType string
|
var policyType string
|
||||||
var currentPolicy *model.Policy
|
var currentPolicy *model.Policy
|
||||||
|
|
@ -184,7 +155,7 @@ func (fs *FileSystem) DispatchHandler() error {
|
||||||
case "remote":
|
case "remote":
|
||||||
fs.Handler = remote.Driver{
|
fs.Handler = remote.Driver{
|
||||||
Policy: currentPolicy,
|
Policy: currentPolicy,
|
||||||
Client: request.HTTPClient{},
|
Client: request.NewClient(),
|
||||||
AuthInstance: auth.HMACAuth{[]byte(currentPolicy.SecretKey)},
|
AuthInstance: auth.HMACAuth{[]byte(currentPolicy.SecretKey)},
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
@ -196,7 +167,7 @@ func (fs *FileSystem) DispatchHandler() error {
|
||||||
case "oss":
|
case "oss":
|
||||||
fs.Handler = oss.Driver{
|
fs.Handler = oss.Driver{
|
||||||
Policy: currentPolicy,
|
Policy: currentPolicy,
|
||||||
HTTPClient: request.HTTPClient{},
|
HTTPClient: request.NewClient(),
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
case "upyun":
|
case "upyun":
|
||||||
|
|
@ -205,13 +176,9 @@ func (fs *FileSystem) DispatchHandler() error {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
case "onedrive":
|
case "onedrive":
|
||||||
client, err := onedrive.NewClient(currentPolicy)
|
var odErr error
|
||||||
fs.Handler = onedrive.Driver{
|
fs.Handler, odErr = onedrive.NewDriver(currentPolicy)
|
||||||
Policy: currentPolicy,
|
return odErr
|
||||||
Client: client,
|
|
||||||
HTTPClient: request.HTTPClient{},
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
case "cos":
|
case "cos":
|
||||||
u, _ := url.Parse(currentPolicy.Server)
|
u, _ := url.Parse(currentPolicy.Server)
|
||||||
b := &cossdk.BaseURL{BucketURL: u}
|
b := &cossdk.BaseURL{BucketURL: u}
|
||||||
|
|
@ -223,7 +190,7 @@ func (fs *FileSystem) DispatchHandler() error {
|
||||||
SecretKey: currentPolicy.SecretKey,
|
SecretKey: currentPolicy.SecretKey,
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
HTTPClient: request.HTTPClient{},
|
HTTPClient: request.NewClient(),
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
case "s3":
|
case "s3":
|
||||||
|
|
@ -272,6 +239,30 @@ func NewFileSystemFromCallback(c *gin.Context) (*FileSystem, error) {
|
||||||
return fs, err
|
return fs, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SwitchToSlaveHandler 将负责上传的 Handler 切换为从机节点
|
||||||
|
func (fs *FileSystem) SwitchToSlaveHandler(node cluster.Node) {
|
||||||
|
fs.Handler = slaveinmaster.NewDriver(node, fs.Handler, &fs.User.Policy)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SwitchToShadowHandler 将负责上传的 Handler 切换为从机节点转存使用的影子处理器
|
||||||
|
func (fs *FileSystem) SwitchToShadowHandler(master cluster.Node, masterURL, masterID string) {
|
||||||
|
switch fs.Policy.Type {
|
||||||
|
case "remote":
|
||||||
|
fs.Policy.Type = "local"
|
||||||
|
fs.DispatchHandler()
|
||||||
|
case "local":
|
||||||
|
fs.Policy.Type = "remote"
|
||||||
|
fs.Policy.Server = masterURL
|
||||||
|
fs.Policy.AccessKey = fmt.Sprintf("%d", master.ID())
|
||||||
|
fs.Policy.SecretKey = master.DBModel().MasterKey
|
||||||
|
fs.DispatchHandler()
|
||||||
|
case "onedrive":
|
||||||
|
fs.Policy.MasterID = masterID
|
||||||
|
}
|
||||||
|
|
||||||
|
fs.Handler = masterinslave.NewDriver(master, fs.Handler, fs.Policy)
|
||||||
|
}
|
||||||
|
|
||||||
// SetTargetFile 设置当前处理的目标文件
|
// SetTargetFile 设置当前处理的目标文件
|
||||||
func (fs *FileSystem) SetTargetFile(files *[]model.File) {
|
func (fs *FileSystem) SetTargetFile(files *[]model.File) {
|
||||||
if len(fs.FileTarget) == 0 {
|
if len(fs.FileTarget) == 0 {
|
||||||
|
|
|
||||||
|
|
@ -41,4 +41,6 @@ const (
|
||||||
ValidateCapacityOnceCtx
|
ValidateCapacityOnceCtx
|
||||||
// 禁止上传时同名覆盖操作
|
// 禁止上传时同名覆盖操作
|
||||||
DisableOverwrite
|
DisableOverwrite
|
||||||
|
// 文件在从机节点中的路径
|
||||||
|
SlaveSrcPath
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -228,13 +228,15 @@ func (fs *FileSystem) UploadFromStream(ctx context.Context, src io.ReadCloser, d
|
||||||
}
|
}
|
||||||
|
|
||||||
// UploadFromPath 将本机已有文件上传到用户的文件系统
|
// UploadFromPath 将本机已有文件上传到用户的文件系统
|
||||||
func (fs *FileSystem) UploadFromPath(ctx context.Context, src, dst string) error {
|
func (fs *FileSystem) UploadFromPath(ctx context.Context, src, dst string, resetPolicy bool) error {
|
||||||
// 重设存储策略
|
// 重设存储策略
|
||||||
|
if resetPolicy {
|
||||||
fs.Policy = &fs.User.Policy
|
fs.Policy = &fs.User.Policy
|
||||||
err := fs.DispatchHandler()
|
err := fs.DispatchHandler()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
file, err := os.Open(util.RelativePath(src))
|
file, err := os.Open(util.RelativePath(src))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -226,13 +226,13 @@ func TestFileSystem_UploadFromPath(t *testing.T) {
|
||||||
|
|
||||||
// 文件不存在
|
// 文件不存在
|
||||||
{
|
{
|
||||||
err := fs.UploadFromPath(ctx, "test/not_exist", "/")
|
err := fs.UploadFromPath(ctx, "test/not_exist", "/", true)
|
||||||
asserts.Error(err)
|
asserts.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 文存在,上传失败
|
// 文存在,上传失败
|
||||||
{
|
{
|
||||||
err := fs.UploadFromPath(ctx, "tests/test.zip", "/")
|
err := fs.UploadFromPath(ctx, "tests/test.zip", "/", true)
|
||||||
asserts.Error(err)
|
asserts.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,160 @@
|
||||||
|
package mq
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/gob"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Message 消息事件正文
|
||||||
|
type Message struct {
|
||||||
|
// 消息触发者
|
||||||
|
TriggeredBy string
|
||||||
|
|
||||||
|
// 事件标识
|
||||||
|
Event string
|
||||||
|
|
||||||
|
// 消息正文
|
||||||
|
Content interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type CallbackFunc func(Message)
|
||||||
|
|
||||||
|
// MQ 消息队列
|
||||||
|
type MQ interface {
|
||||||
|
rpc.Notifier
|
||||||
|
|
||||||
|
// 发布一个消息
|
||||||
|
Publish(string, Message)
|
||||||
|
|
||||||
|
// 订阅一个消息主题
|
||||||
|
Subscribe(string, int) <-chan Message
|
||||||
|
|
||||||
|
// 订阅一个消息主题,注册触发回调函数
|
||||||
|
SubscribeCallback(string, CallbackFunc)
|
||||||
|
|
||||||
|
// 取消订阅一个消息主题
|
||||||
|
Unsubscribe(string, <-chan Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
var GlobalMQ = NewMQ()
|
||||||
|
|
||||||
|
func NewMQ() MQ {
|
||||||
|
return &inMemoryMQ{
|
||||||
|
topics: make(map[string][]chan Message),
|
||||||
|
callbacks: make(map[string][]CallbackFunc),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
gob.Register(Message{})
|
||||||
|
gob.Register([]rpc.Event{})
|
||||||
|
}
|
||||||
|
|
||||||
|
type inMemoryMQ struct {
|
||||||
|
topics map[string][]chan Message
|
||||||
|
callbacks map[string][]CallbackFunc
|
||||||
|
sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *inMemoryMQ) Publish(topic string, message Message) {
|
||||||
|
i.RLock()
|
||||||
|
subscribersChan, okChan := i.topics[topic]
|
||||||
|
subscribersCallback, okCallback := i.callbacks[topic]
|
||||||
|
i.RUnlock()
|
||||||
|
|
||||||
|
if okChan {
|
||||||
|
go func(subscribersChan []chan Message) {
|
||||||
|
for i := 0; i < len(subscribersChan); i++ {
|
||||||
|
select {
|
||||||
|
case subscribersChan[i] <- message:
|
||||||
|
case <-time.After(time.Millisecond * 500):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(subscribersChan)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
if okCallback {
|
||||||
|
for i := 0; i < len(subscribersCallback); i++ {
|
||||||
|
go subscribersCallback[i](message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *inMemoryMQ) Subscribe(topic string, buffer int) <-chan Message {
|
||||||
|
ch := make(chan Message, buffer)
|
||||||
|
i.Lock()
|
||||||
|
i.topics[topic] = append(i.topics[topic], ch)
|
||||||
|
i.Unlock()
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *inMemoryMQ) SubscribeCallback(topic string, callbackFunc CallbackFunc) {
|
||||||
|
i.Lock()
|
||||||
|
i.callbacks[topic] = append(i.callbacks[topic], callbackFunc)
|
||||||
|
i.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *inMemoryMQ) Unsubscribe(topic string, sub <-chan Message) {
|
||||||
|
i.Lock()
|
||||||
|
defer i.Unlock()
|
||||||
|
|
||||||
|
subscribers, ok := i.topics[topic]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var newSubs []chan Message
|
||||||
|
for _, subscriber := range subscribers {
|
||||||
|
if subscriber == sub {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newSubs = append(newSubs, subscriber)
|
||||||
|
}
|
||||||
|
|
||||||
|
i.topics[topic] = newSubs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *inMemoryMQ) Aria2Notify(events []rpc.Event, status int) {
|
||||||
|
for _, event := range events {
|
||||||
|
i.Publish(event.Gid, Message{
|
||||||
|
TriggeredBy: event.Gid,
|
||||||
|
Event: strconv.FormatInt(int64(status), 10),
|
||||||
|
Content: events,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnDownloadStart 下载开始
|
||||||
|
func (i *inMemoryMQ) OnDownloadStart(events []rpc.Event) {
|
||||||
|
i.Aria2Notify(events, common.Downloading)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnDownloadPause 下载暂停
|
||||||
|
func (i *inMemoryMQ) OnDownloadPause(events []rpc.Event) {
|
||||||
|
i.Aria2Notify(events, common.Paused)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnDownloadStop 下载停止
|
||||||
|
func (i *inMemoryMQ) OnDownloadStop(events []rpc.Event) {
|
||||||
|
i.Aria2Notify(events, common.Canceled)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnDownloadComplete 下载完成
|
||||||
|
func (i *inMemoryMQ) OnDownloadComplete(events []rpc.Event) {
|
||||||
|
i.Aria2Notify(events, common.Complete)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnDownloadError 下载出错
|
||||||
|
func (i *inMemoryMQ) OnDownloadError(events []rpc.Event) {
|
||||||
|
i.Aria2Notify(events, common.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnBtDownloadComplete BT下载完成
|
||||||
|
func (i *inMemoryMQ) OnBtDownloadComplete(events []rpc.Event) {
|
||||||
|
i.Aria2Notify(events, common.Complete)
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,149 @@
|
||||||
|
package mq
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPublishAndSubscribe(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
asserts := assert.New(t)
|
||||||
|
mq := NewMQ()
|
||||||
|
|
||||||
|
// No subscriber
|
||||||
|
{
|
||||||
|
asserts.NotPanics(func() {
|
||||||
|
mq.Publish("No subscriber", Message{})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// One channel subscriber
|
||||||
|
{
|
||||||
|
topic := "One channel subscriber"
|
||||||
|
msg := Message{TriggeredBy: "Tester"}
|
||||||
|
notifier := mq.Subscribe(topic, 0)
|
||||||
|
mq.Publish(topic, msg)
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
wg.Done()
|
||||||
|
msgRecv := <-notifier
|
||||||
|
asserts.Equal(msg, msgRecv)
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// two channel subscriber
|
||||||
|
{
|
||||||
|
topic := "two channel subscriber"
|
||||||
|
msg := Message{TriggeredBy: "Tester"}
|
||||||
|
notifier := mq.Subscribe(topic, 0)
|
||||||
|
notifier2 := mq.Subscribe(topic, 0)
|
||||||
|
mq.Publish(topic, msg)
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(2)
|
||||||
|
go func() {
|
||||||
|
wg.Done()
|
||||||
|
msgRecv := <-notifier
|
||||||
|
asserts.Equal(msg, msgRecv)
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
wg.Done()
|
||||||
|
msgRecv := <-notifier2
|
||||||
|
asserts.Equal(msg, msgRecv)
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// two channel subscriber, one timeout
|
||||||
|
{
|
||||||
|
topic := "two channel subscriber, one timeout"
|
||||||
|
msg := Message{TriggeredBy: "Tester"}
|
||||||
|
mq.Subscribe(topic, 0)
|
||||||
|
notifier2 := mq.Subscribe(topic, 0)
|
||||||
|
mq.Publish(topic, msg)
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
wg.Done()
|
||||||
|
msgRecv := <-notifier2
|
||||||
|
asserts.Equal(msg, msgRecv)
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// two channel subscriber, one unsubscribe
|
||||||
|
{
|
||||||
|
topic := "two channel subscriber, one unsubscribe"
|
||||||
|
msg := Message{TriggeredBy: "Tester"}
|
||||||
|
mq.Subscribe(topic, 0)
|
||||||
|
notifier2 := mq.Subscribe(topic, 0)
|
||||||
|
notifier := mq.Subscribe(topic, 0)
|
||||||
|
mq.Unsubscribe(topic, notifier)
|
||||||
|
mq.Publish(topic, msg)
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
wg.Done()
|
||||||
|
msgRecv := <-notifier2
|
||||||
|
asserts.Equal(msg, msgRecv)
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-notifier:
|
||||||
|
t.Error()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAria2Interface(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
asserts := assert.New(t)
|
||||||
|
mq := NewMQ()
|
||||||
|
var (
|
||||||
|
OnDownloadStart int
|
||||||
|
OnDownloadPause int
|
||||||
|
OnDownloadStop int
|
||||||
|
OnDownloadComplete int
|
||||||
|
OnDownloadError int
|
||||||
|
)
|
||||||
|
l := sync.Mutex{}
|
||||||
|
|
||||||
|
mq.SubscribeCallback("TestAria2Interface", func(message Message) {
|
||||||
|
asserts.Equal("TestAria2Interface", message.TriggeredBy)
|
||||||
|
l.Lock()
|
||||||
|
defer l.Unlock()
|
||||||
|
switch message.Event {
|
||||||
|
case "1":
|
||||||
|
OnDownloadStart++
|
||||||
|
case "2":
|
||||||
|
OnDownloadPause++
|
||||||
|
case "5":
|
||||||
|
OnDownloadStop++
|
||||||
|
case "4":
|
||||||
|
OnDownloadComplete++
|
||||||
|
case "3":
|
||||||
|
OnDownloadError++
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
mq.OnDownloadStart([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}})
|
||||||
|
mq.OnDownloadPause([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}})
|
||||||
|
mq.OnDownloadStop([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}})
|
||||||
|
mq.OnDownloadComplete([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}})
|
||||||
|
mq.OnDownloadError([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}})
|
||||||
|
mq.OnBtDownloadComplete([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}})
|
||||||
|
|
||||||
|
time.Sleep(time.Duration(500) * time.Millisecond)
|
||||||
|
|
||||||
|
asserts.Equal(2, OnDownloadStart)
|
||||||
|
asserts.Equal(2, OnDownloadPause)
|
||||||
|
asserts.Equal(2, OnDownloadStop)
|
||||||
|
asserts.Equal(4, OnDownloadComplete)
|
||||||
|
asserts.Equal(2, OnDownloadError)
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,110 @@
|
||||||
|
package request
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Option 发送请求的额外设置
|
||||||
|
type Option interface {
|
||||||
|
apply(*options)
|
||||||
|
}
|
||||||
|
|
||||||
|
type options struct {
|
||||||
|
timeout time.Duration
|
||||||
|
header http.Header
|
||||||
|
sign auth.Auth
|
||||||
|
signTTL int64
|
||||||
|
ctx context.Context
|
||||||
|
contentLength int64
|
||||||
|
masterMeta bool
|
||||||
|
endpoint *url.URL
|
||||||
|
slaveNodeID string
|
||||||
|
}
|
||||||
|
|
||||||
|
type optionFunc func(*options)
|
||||||
|
|
||||||
|
func (f optionFunc) apply(o *options) {
|
||||||
|
f(o)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDefaultOption() *options {
|
||||||
|
return &options{
|
||||||
|
header: http.Header{},
|
||||||
|
timeout: time.Duration(30) * time.Second,
|
||||||
|
contentLength: -1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTimeout 设置请求超时
|
||||||
|
func WithTimeout(t time.Duration) Option {
|
||||||
|
return optionFunc(func(o *options) {
|
||||||
|
o.timeout = t
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithContext 设置请求上下文
|
||||||
|
func WithContext(c context.Context) Option {
|
||||||
|
return optionFunc(func(o *options) {
|
||||||
|
o.ctx = c
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCredential 对请求进行签名
|
||||||
|
func WithCredential(instance auth.Auth, ttl int64) Option {
|
||||||
|
return optionFunc(func(o *options) {
|
||||||
|
o.sign = instance
|
||||||
|
o.signTTL = ttl
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithHeader 设置请求Header
|
||||||
|
func WithHeader(header http.Header) Option {
|
||||||
|
return optionFunc(func(o *options) {
|
||||||
|
for k, v := range header {
|
||||||
|
o.header[k] = v
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithoutHeader 设置清除请求Header
|
||||||
|
func WithoutHeader(header []string) Option {
|
||||||
|
return optionFunc(func(o *options) {
|
||||||
|
for _, v := range header {
|
||||||
|
delete(o.header, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithContentLength 设置请求大小
|
||||||
|
func WithContentLength(s int64) Option {
|
||||||
|
return optionFunc(func(o *options) {
|
||||||
|
o.contentLength = s
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMasterMeta 请求时携带主机信息
|
||||||
|
func WithMasterMeta() Option {
|
||||||
|
return optionFunc(func(o *options) {
|
||||||
|
o.masterMeta = true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSlaveMeta 请求时携带从机信息
|
||||||
|
func WithSlaveMeta(s string) Option {
|
||||||
|
return optionFunc(func(o *options) {
|
||||||
|
o.slaveNodeID = s
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Endpoint 使用同一的请求Endpoint
|
||||||
|
func WithEndpoint(endpoint string) Option {
|
||||||
|
endpointURL, _ := url.Parse(endpoint)
|
||||||
|
return optionFunc(func(o *options) {
|
||||||
|
o.endpoint = endpointURL
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
@ -1,23 +1,25 @@
|
||||||
package request
|
package request
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"path"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GeneralClient 通用 HTTP Client
|
// GeneralClient 通用 HTTP Client
|
||||||
var GeneralClient Client = HTTPClient{}
|
var GeneralClient Client = NewClient()
|
||||||
|
|
||||||
// Response 请求的响应或错误信息
|
// Response 请求的响应或错误信息
|
||||||
type Response struct {
|
type Response struct {
|
||||||
|
|
@ -32,90 +34,30 @@ type Client interface {
|
||||||
|
|
||||||
// HTTPClient 实现 Client 接口
|
// HTTPClient 实现 Client 接口
|
||||||
type HTTPClient struct {
|
type HTTPClient struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
options *options
|
||||||
}
|
}
|
||||||
|
|
||||||
// Option 发送请求的额外设置
|
func NewClient(opts ...Option) Client {
|
||||||
type Option interface {
|
client := &HTTPClient{
|
||||||
apply(*options)
|
options: newDefaultOption(),
|
||||||
}
|
}
|
||||||
|
|
||||||
type options struct {
|
for _, o := range opts {
|
||||||
timeout time.Duration
|
o.apply(client.options)
|
||||||
header http.Header
|
|
||||||
sign auth.Auth
|
|
||||||
signTTL int64
|
|
||||||
ctx context.Context
|
|
||||||
contentLength int64
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type optionFunc func(*options)
|
return client
|
||||||
|
|
||||||
func (f optionFunc) apply(o *options) {
|
|
||||||
f(o)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newDefaultOption() *options {
|
|
||||||
return &options{
|
|
||||||
header: http.Header{},
|
|
||||||
timeout: time.Duration(30) * time.Second,
|
|
||||||
contentLength: -1,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithTimeout 设置请求超时
|
|
||||||
func WithTimeout(t time.Duration) Option {
|
|
||||||
return optionFunc(func(o *options) {
|
|
||||||
o.timeout = t
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithContext 设置请求上下文
|
|
||||||
func WithContext(c context.Context) Option {
|
|
||||||
return optionFunc(func(o *options) {
|
|
||||||
o.ctx = c
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithCredential 对请求进行签名
|
|
||||||
func WithCredential(instance auth.Auth, ttl int64) Option {
|
|
||||||
return optionFunc(func(o *options) {
|
|
||||||
o.sign = instance
|
|
||||||
o.signTTL = ttl
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithHeader 设置请求Header
|
|
||||||
func WithHeader(header http.Header) Option {
|
|
||||||
return optionFunc(func(o *options) {
|
|
||||||
for k, v := range header {
|
|
||||||
o.header[k] = v
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithoutHeader 设置清除请求Header
|
|
||||||
func WithoutHeader(header []string) Option {
|
|
||||||
return optionFunc(func(o *options) {
|
|
||||||
for _, v := range header {
|
|
||||||
delete(o.header, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithContentLength 设置请求大小
|
|
||||||
func WithContentLength(s int64) Option {
|
|
||||||
return optionFunc(func(o *options) {
|
|
||||||
o.contentLength = s
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Request 发送HTTP请求
|
// Request 发送HTTP请求
|
||||||
func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response {
|
func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response {
|
||||||
// 应用额外设置
|
// 应用额外设置
|
||||||
options := newDefaultOption()
|
c.mu.Lock()
|
||||||
|
options := *c.options
|
||||||
|
c.mu.Unlock()
|
||||||
for _, o := range opts {
|
for _, o := range opts {
|
||||||
o.apply(options)
|
o.apply(&options)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建请求客户端
|
// 创建请求客户端
|
||||||
|
|
@ -126,6 +68,13 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio
|
||||||
body = nil
|
body = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 确定请求URL
|
||||||
|
if options.endpoint != nil {
|
||||||
|
targetURL := *options.endpoint
|
||||||
|
targetURL.Path = path.Join(targetURL.Path, target)
|
||||||
|
target = targetURL.String()
|
||||||
|
}
|
||||||
|
|
||||||
// 创建请求
|
// 创建请求
|
||||||
var (
|
var (
|
||||||
req *http.Request
|
req *http.Request
|
||||||
|
|
@ -141,14 +90,36 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio
|
||||||
}
|
}
|
||||||
|
|
||||||
// 添加请求相关设置
|
// 添加请求相关设置
|
||||||
req.Header = options.header
|
if options.header != nil {
|
||||||
|
for k, v := range options.header {
|
||||||
|
req.Header.Add(k, strings.Join(v, " "))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.masterMeta && conf.SystemConfig.Mode == "master" {
|
||||||
|
req.Header.Add("X-Site-Url", model.GetSiteURL().String())
|
||||||
|
req.Header.Add("X-Site-Id", model.GetSettingByName("siteID"))
|
||||||
|
req.Header.Add("X-Cloudreve-Version", conf.BackendVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.slaveNodeID != "" && conf.SystemConfig.Mode == "slave" {
|
||||||
|
req.Header.Add("X-Node-Id", options.slaveNodeID)
|
||||||
|
}
|
||||||
|
|
||||||
if options.contentLength != -1 {
|
if options.contentLength != -1 {
|
||||||
req.ContentLength = options.contentLength
|
req.ContentLength = options.contentLength
|
||||||
}
|
}
|
||||||
|
|
||||||
// 签名请求
|
// 签名请求
|
||||||
if options.sign != nil {
|
if options.sign != nil {
|
||||||
|
switch method {
|
||||||
|
case "PUT", "POST", "PATCH":
|
||||||
auth.SignRequest(options.sign, req, options.signTTL)
|
auth.SignRequest(options.sign, req, options.signTTL)
|
||||||
|
default:
|
||||||
|
if resURL, err := auth.SignURI(options.sign, req.URL.String(), options.signTTL); err == nil {
|
||||||
|
req.URL = resURL
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 发送请求
|
// 发送请求
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TODO: move to slave pkg
|
||||||
// RemoteCallback 发送远程存储策略上传回调请求
|
// RemoteCallback 发送远程存储策略上传回调请求
|
||||||
func RemoteCallback(url string, body serializer.UploadCallback) error {
|
func RemoteCallback(url string, body serializer.UploadCallback) error {
|
||||||
callbackBody, err := json.Marshal(struct {
|
callbackBody, err := json.Marshal(struct {
|
||||||
|
|
|
||||||
|
|
@ -5,16 +5,15 @@ import "encoding/json"
|
||||||
// RequestRawSign 待签名的HTTP请求
|
// RequestRawSign 待签名的HTTP请求
|
||||||
type RequestRawSign struct {
|
type RequestRawSign struct {
|
||||||
Path string
|
Path string
|
||||||
Policy string
|
Header string
|
||||||
Body string
|
Body string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRequestSignString 返回JSON格式的待签名字符串
|
// NewRequestSignString 返回JSON格式的待签名字符串
|
||||||
// TODO 测试
|
func NewRequestSignString(path, header, body string) string {
|
||||||
func NewRequestSignString(path, policy, body string) string {
|
|
||||||
req := RequestRawSign{
|
req := RequestRawSign{
|
||||||
Path: path,
|
Path: path,
|
||||||
Policy: policy,
|
Header: header,
|
||||||
Body: body,
|
Body: body,
|
||||||
}
|
}
|
||||||
res, _ := json.Marshal(req)
|
res, _ := json.Marshal(req)
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,9 @@
|
||||||
package serializer
|
package serializer
|
||||||
|
|
||||||
import "github.com/gin-gonic/gin"
|
import (
|
||||||
|
"errors"
|
||||||
// Response 基础序列化器
|
"github.com/gin-gonic/gin"
|
||||||
type Response struct {
|
)
|
||||||
Code int `json:"code"`
|
|
||||||
Data interface{} `json:"data,omitempty"`
|
|
||||||
Msg string `json:"msg"`
|
|
||||||
Error string `json:"error,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// AppError 应用错误,实现了error接口
|
// AppError 应用错误,实现了error接口
|
||||||
type AppError struct {
|
type AppError struct {
|
||||||
|
|
@ -17,7 +12,7 @@ type AppError struct {
|
||||||
RawError error
|
RawError error
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewError 返回新的错误对象 todo:测试 还有下面的
|
// NewError 返回新的错误对象
|
||||||
func NewError(code int, msg string, err error) AppError {
|
func NewError(code int, msg string, err error) AppError {
|
||||||
return AppError{
|
return AppError{
|
||||||
Code: code,
|
Code: code,
|
||||||
|
|
@ -26,6 +21,15 @@ func NewError(code int, msg string, err error) AppError {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewErrorFromResponse 从 serializer.Response 构建错误
|
||||||
|
func NewErrorFromResponse(resp *Response) AppError {
|
||||||
|
return AppError{
|
||||||
|
Code: resp.Code,
|
||||||
|
Msg: resp.Msg,
|
||||||
|
RawError: errors.New(resp.Error),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithError 将应用error携带标准库中的error
|
// WithError 将应用error携带标准库中的error
|
||||||
func (err *AppError) WithError(raw error) AppError {
|
func (err *AppError) WithError(raw error) AppError {
|
||||||
err.RawError = raw
|
err.RawError = raw
|
||||||
|
|
@ -66,6 +70,8 @@ const (
|
||||||
CodeGroupNotAllowed = 40007
|
CodeGroupNotAllowed = 40007
|
||||||
// CodeAdminRequired 非管理用户组
|
// CodeAdminRequired 非管理用户组
|
||||||
CodeAdminRequired = 40008
|
CodeAdminRequired = 40008
|
||||||
|
// CodeMasterNotFound 主机节点未注册
|
||||||
|
CodeMasterNotFound = 40009
|
||||||
// CodeDBError 数据库操作失败
|
// CodeDBError 数据库操作失败
|
||||||
CodeDBError = 50001
|
CodeDBError = 50001
|
||||||
// CodeEncryptError 加密失败
|
// CodeEncryptError 加密失败
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,35 @@
|
||||||
|
package serializer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/gob"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Response 基础序列化器
|
||||||
|
type Response struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Data interface{} `json:"data,omitempty"`
|
||||||
|
Msg string `json:"msg"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewResponseWithGobData 返回Data字段使用gob编码的Response
|
||||||
|
func NewResponseWithGobData(data interface{}) Response {
|
||||||
|
var w bytes.Buffer
|
||||||
|
encoder := gob.NewEncoder(&w)
|
||||||
|
if err := encoder.Encode(data); err != nil {
|
||||||
|
return Err(CodeInternalSetting, "无法编码返回结果", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return Response{Data: w.Bytes()}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GobDecode 将 Response 正文解码至目标指针
|
||||||
|
func (r *Response) GobDecode(target interface{}) {
|
||||||
|
src := r.Data.(string)
|
||||||
|
raw := make([]byte, len(src)*len(src)/base64.StdEncoding.DecodedLen(len(src)))
|
||||||
|
base64.StdEncoding.Decode(raw, []byte(src))
|
||||||
|
decoder := gob.NewDecoder(bytes.NewBuffer(raw))
|
||||||
|
decoder.Decode(target)
|
||||||
|
}
|
||||||
|
|
@ -1,5 +1,12 @@
|
||||||
package serializer
|
package serializer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha1"
|
||||||
|
"encoding/gob"
|
||||||
|
"fmt"
|
||||||
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
|
)
|
||||||
|
|
||||||
// RemoteDeleteRequest 远程策略删除接口请求正文
|
// RemoteDeleteRequest 远程策略删除接口请求正文
|
||||||
type RemoteDeleteRequest struct {
|
type RemoteDeleteRequest struct {
|
||||||
Files []string `json:"files"`
|
Files []string `json:"files"`
|
||||||
|
|
@ -10,3 +17,51 @@ type ListRequest struct {
|
||||||
Path string `json:"path"`
|
Path string `json:"path"`
|
||||||
Recursive bool `json:"recursive"`
|
Recursive bool `json:"recursive"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NodePingReq 从机节点Ping请求
|
||||||
|
type NodePingReq struct {
|
||||||
|
SiteURL string `json:"site_url"`
|
||||||
|
SiteID string `json:"site_id"`
|
||||||
|
IsUpdate bool `json:"is_update"`
|
||||||
|
CredentialTTL int `json:"credential_ttl"`
|
||||||
|
Node *model.Node `json:"node"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NodePingResp 从机节点Ping响应
|
||||||
|
type NodePingResp struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// SlaveAria2Call 从机有关Aria2的请求正文
|
||||||
|
type SlaveAria2Call struct {
|
||||||
|
Task *model.Download `json:"task"`
|
||||||
|
GroupOptions map[string]interface{} `json:"group_options"`
|
||||||
|
Files []int `json:"files"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SlaveTransferReq 从机中转任务创建请求
|
||||||
|
type SlaveTransferReq struct {
|
||||||
|
Src string `json:"src"`
|
||||||
|
Dst string `json:"dst"`
|
||||||
|
Policy *model.Policy `json:"policy"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hash 返回创建请求的唯一标识,保持创建请求幂等
|
||||||
|
func (s *SlaveTransferReq) Hash(id string) string {
|
||||||
|
h := sha1.New()
|
||||||
|
h.Write([]byte(fmt.Sprintf("transfer-%s-%s-%s-%d", id, s.Src, s.Dst, s.Policy.ID)))
|
||||||
|
bs := h.Sum(nil)
|
||||||
|
return fmt.Sprintf("%x", bs)
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
SlaveTransferSuccess = "success"
|
||||||
|
SlaveTransferFailed = "failed"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SlaveTransferResult struct {
|
||||||
|
Error string
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
gob.Register(SlaveTransferResult{})
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
package slave
|
||||||
|
|
||||||
|
import "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrMasterNotFound = serializer.NewError(serializer.CodeMasterNotFound, "未知的主机节点", nil)
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,209 @@
|
||||||
|
package slave
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/gob"
|
||||||
|
"fmt"
|
||||||
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
var DefaultController Controller
|
||||||
|
|
||||||
|
// Controller controls communications between master and slave
|
||||||
|
type Controller interface {
|
||||||
|
// Handle heartbeat sent from master
|
||||||
|
HandleHeartBeat(*serializer.NodePingReq) (serializer.NodePingResp, error)
|
||||||
|
|
||||||
|
// Get Aria2 Instance by master node ID
|
||||||
|
GetAria2Instance(string) (common.Aria2, error)
|
||||||
|
|
||||||
|
// Send event change message to master node
|
||||||
|
SendNotification(string, string, mq.Message) error
|
||||||
|
|
||||||
|
// Submit async task into task pool
|
||||||
|
SubmitTask(string, interface{}, string, func(interface{})) error
|
||||||
|
|
||||||
|
// Get master node info
|
||||||
|
GetMasterInfo(string) (*MasterInfo, error)
|
||||||
|
|
||||||
|
// Get master OneDrive policy credential
|
||||||
|
GetOneDriveToken(string, uint) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type slaveController struct {
|
||||||
|
masters map[string]MasterInfo
|
||||||
|
lock sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// info of master node
|
||||||
|
type MasterInfo struct {
|
||||||
|
ID string
|
||||||
|
TTL int
|
||||||
|
URL *url.URL
|
||||||
|
// used to invoke aria2 rpc calls
|
||||||
|
Instance cluster.Node
|
||||||
|
Client request.Client
|
||||||
|
|
||||||
|
jobTracker map[string]bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func Init() {
|
||||||
|
DefaultController = &slaveController{
|
||||||
|
masters: make(map[string]MasterInfo),
|
||||||
|
}
|
||||||
|
gob.Register(rpc.StatusInfo{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializer.NodePingResp, error) {
|
||||||
|
c.lock.Lock()
|
||||||
|
defer c.lock.Unlock()
|
||||||
|
|
||||||
|
req.Node.AfterFind()
|
||||||
|
|
||||||
|
// close old node if exist
|
||||||
|
origin, ok := c.masters[req.SiteID]
|
||||||
|
|
||||||
|
if (ok && req.IsUpdate) || !ok {
|
||||||
|
if ok {
|
||||||
|
origin.Instance.Kill()
|
||||||
|
}
|
||||||
|
|
||||||
|
masterUrl, err := url.Parse(req.SiteURL)
|
||||||
|
if err != nil {
|
||||||
|
return serializer.NodePingResp{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.masters[req.SiteID] = MasterInfo{
|
||||||
|
ID: req.SiteID,
|
||||||
|
URL: masterUrl,
|
||||||
|
TTL: req.CredentialTTL,
|
||||||
|
Client: request.NewClient(
|
||||||
|
request.WithEndpoint(masterUrl.String()),
|
||||||
|
request.WithSlaveMeta(fmt.Sprintf("%d", req.Node.ID)),
|
||||||
|
request.WithCredential(auth.HMACAuth{
|
||||||
|
SecretKey: []byte(req.Node.MasterKey),
|
||||||
|
}, int64(req.CredentialTTL)),
|
||||||
|
),
|
||||||
|
jobTracker: make(map[string]bool),
|
||||||
|
Instance: cluster.NewNodeFromDBModel(&model.Node{
|
||||||
|
Model: gorm.Model{ID: req.Node.ID},
|
||||||
|
MasterKey: req.Node.MasterKey,
|
||||||
|
Type: model.MasterNodeType,
|
||||||
|
Aria2Enabled: req.Node.Aria2Enabled,
|
||||||
|
Aria2OptionsSerialized: req.Node.Aria2OptionsSerialized,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return serializer.NodePingResp{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *slaveController) GetAria2Instance(id string) (common.Aria2, error) {
|
||||||
|
c.lock.RLock()
|
||||||
|
defer c.lock.RUnlock()
|
||||||
|
|
||||||
|
if node, ok := c.masters[id]; ok {
|
||||||
|
return node.Instance.GetAria2Instance(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, ErrMasterNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *slaveController) SendNotification(id, subject string, msg mq.Message) error {
|
||||||
|
c.lock.RLock()
|
||||||
|
|
||||||
|
if node, ok := c.masters[id]; ok {
|
||||||
|
c.lock.RUnlock()
|
||||||
|
|
||||||
|
body := bytes.Buffer{}
|
||||||
|
enc := gob.NewEncoder(&body)
|
||||||
|
if err := enc.Encode(&msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := node.Client.Request(
|
||||||
|
"PUT",
|
||||||
|
fmt.Sprintf("/api/v3/slave/notification/%s", subject),
|
||||||
|
&body,
|
||||||
|
).CheckHTTPResponse(200).DecodeResponse()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Code != 0 {
|
||||||
|
return serializer.NewErrorFromResponse(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.lock.RUnlock()
|
||||||
|
return ErrMasterNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// SubmitTask 提交异步任务
|
||||||
|
func (c *slaveController) SubmitTask(id string, job interface{}, hash string, submitter func(interface{})) error {
|
||||||
|
c.lock.RLock()
|
||||||
|
defer c.lock.RUnlock()
|
||||||
|
|
||||||
|
if node, ok := c.masters[id]; ok {
|
||||||
|
if _, ok := node.jobTracker[hash]; ok {
|
||||||
|
// 任务已存在,直接返回
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
submitter(job)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return ErrMasterNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMasterInfo 获取主机节点信息
|
||||||
|
func (c *slaveController) GetMasterInfo(id string) (*MasterInfo, error) {
|
||||||
|
c.lock.RLock()
|
||||||
|
defer c.lock.RUnlock()
|
||||||
|
|
||||||
|
if node, ok := c.masters[id]; ok {
|
||||||
|
return &node, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, ErrMasterNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOneDriveToken 获取主机OneDrive凭证
|
||||||
|
func (c *slaveController) GetOneDriveToken(id string, policyID uint) (string, error) {
|
||||||
|
c.lock.RLock()
|
||||||
|
|
||||||
|
if node, ok := c.masters[id]; ok {
|
||||||
|
c.lock.RUnlock()
|
||||||
|
|
||||||
|
res, err := node.Client.Request(
|
||||||
|
"GET",
|
||||||
|
fmt.Sprintf("/api/v3/slave/credential/onedrive/%d", policyID),
|
||||||
|
nil,
|
||||||
|
).CheckHTTPResponse(200).DecodeResponse()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Code != 0 {
|
||||||
|
return "", serializer.NewErrorFromResponse(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res.Data.(string), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.lock.RUnlock()
|
||||||
|
return "", ErrMasterNotFound
|
||||||
|
}
|
||||||
|
|
@ -106,7 +106,7 @@ func (job *CompressTask) Do() {
|
||||||
job.TaskModel.SetProgress(TransferringProgress)
|
job.TaskModel.SetProgress(TransferringProgress)
|
||||||
|
|
||||||
// 上传文件
|
// 上传文件
|
||||||
err = fs.UploadFromPath(ctx, zipFile, job.TaskProps.Dst)
|
err = fs.UploadFromPath(ctx, zipFile, job.TaskProps.Dst, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
job.SetErrorMsg(err.Error())
|
job.SetErrorMsg(err.Error())
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -96,9 +96,11 @@ func Resume() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if job != nil {
|
||||||
TaskPoll.Submit(job)
|
TaskPoll.Submit(job)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// GetJobFromModel 从数据库给定模型获取任务
|
// GetJobFromModel 从数据库给定模型获取任务
|
||||||
func GetJobFromModel(task *model.Task) (Job, error) {
|
func GetJobFromModel(task *model.Task) (Job, error) {
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ package task
|
||||||
|
|
||||||
import (
|
import (
|
||||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -56,5 +57,7 @@ func Init() {
|
||||||
TaskPoll.Add(maxWorker)
|
TaskPoll.Add(maxWorker)
|
||||||
util.Log().Info("初始化任务队列,WorkerNum = %d", maxWorker)
|
util.Log().Info("初始化任务队列,WorkerNum = %d", maxWorker)
|
||||||
|
|
||||||
|
if conf.SystemConfig.Mode == "master" {
|
||||||
Resume()
|
Resume()
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,145 @@
|
||||||
|
package slavetask
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/task"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TransferTask 文件中转任务
|
||||||
|
type TransferTask struct {
|
||||||
|
Err *task.JobError
|
||||||
|
Req *serializer.SlaveTransferReq
|
||||||
|
MasterID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Props 获取任务属性
|
||||||
|
func (job *TransferTask) Props() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type 获取任务类型
|
||||||
|
func (job *TransferTask) Type() int {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creator 获取创建者ID
|
||||||
|
func (job *TransferTask) Creator() uint {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model 获取任务的数据库模型
|
||||||
|
func (job *TransferTask) Model() *model.Task {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetStatus 设定状态
|
||||||
|
func (job *TransferTask) SetStatus(status int) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetError 设定任务失败信息
|
||||||
|
func (job *TransferTask) SetError(err *task.JobError) {
|
||||||
|
job.Err = err
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetErrorMsg 设定任务失败信息
|
||||||
|
func (job *TransferTask) SetErrorMsg(msg string, err error) {
|
||||||
|
jobErr := &task.JobError{Msg: msg}
|
||||||
|
if err != nil {
|
||||||
|
jobErr.Error = err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
job.SetError(jobErr)
|
||||||
|
|
||||||
|
notifyMsg := mq.Message{
|
||||||
|
TriggeredBy: job.MasterID,
|
||||||
|
Event: serializer.SlaveTransferFailed,
|
||||||
|
Content: serializer.SlaveTransferResult{
|
||||||
|
Error: err.Error(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := slave.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), notifyMsg); err != nil {
|
||||||
|
util.Log().Warning("无法发送转存失败通知到从机, ", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetError 返回任务失败信息
|
||||||
|
func (job *TransferTask) GetError() *task.JobError {
|
||||||
|
return job.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do 开始执行任务
|
||||||
|
func (job *TransferTask) Do() {
|
||||||
|
defer job.Recycle()
|
||||||
|
|
||||||
|
fs, err := filesystem.NewAnonymousFileSystem()
|
||||||
|
if err != nil {
|
||||||
|
job.SetErrorMsg("无法初始化匿名文件系统", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fs.Policy = job.Req.Policy
|
||||||
|
if err := fs.DispatchHandler(); err != nil {
|
||||||
|
job.SetErrorMsg("无法分发存储策略", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
master, err := slave.DefaultController.GetMasterInfo(job.MasterID)
|
||||||
|
if err != nil {
|
||||||
|
job.SetErrorMsg("找不到主机节点", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fs.SwitchToShadowHandler(master.Instance, master.URL.String(), master.ID)
|
||||||
|
ctx := context.WithValue(context.Background(), fsctx.DisableOverwrite, true)
|
||||||
|
file, err := os.Open(util.RelativePath(job.Req.Src))
|
||||||
|
if err != nil {
|
||||||
|
job.SetErrorMsg("无法读取源文件", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
// 获取源文件大小
|
||||||
|
fi, err := file.Stat()
|
||||||
|
if err != nil {
|
||||||
|
job.SetErrorMsg("无法获取源文件大小", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
size := fi.Size()
|
||||||
|
|
||||||
|
err = fs.Handler.Put(ctx, file, job.Req.Dst, uint64(size))
|
||||||
|
if err != nil {
|
||||||
|
job.SetErrorMsg("文件上传失败", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := mq.Message{
|
||||||
|
TriggeredBy: job.MasterID,
|
||||||
|
Event: serializer.SlaveTransferSuccess,
|
||||||
|
Content: serializer.SlaveTransferResult{},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := slave.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), msg); err != nil {
|
||||||
|
util.Log().Warning("无法发送转存成功通知到从机, ", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recycle 回收临时文件
|
||||||
|
func (job *TransferTask) Recycle() {
|
||||||
|
err := os.RemoveAll(filepath.Dir(job.Req.Src))
|
||||||
|
if err != nil {
|
||||||
|
util.Log().Warning("无法删除中转临时目录[%s], %s", job.Req.Src, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||||
|
|
@ -27,10 +28,13 @@ type TransferTask struct {
|
||||||
// TransferProps 中转任务属性
|
// TransferProps 中转任务属性
|
||||||
type TransferProps struct {
|
type TransferProps struct {
|
||||||
Src []string `json:"src"` // 原始文件
|
Src []string `json:"src"` // 原始文件
|
||||||
|
SrcSizes map[string]uint64 `json:"src_size"` // 原始文件的大小信息,从机转存时使用
|
||||||
Parent string `json:"parent"` // 父目录
|
Parent string `json:"parent"` // 父目录
|
||||||
Dst string `json:"dst"` // 目的目录ID
|
Dst string `json:"dst"` // 目的目录ID
|
||||||
// 将会保留原始文件的目录结构,Src 除去 Parent 开头作为最终路径
|
// 将会保留原始文件的目录结构,Src 除去 Parent 开头作为最终路径
|
||||||
TrimPath bool `json:"trim_path"`
|
TrimPath bool `json:"trim_path"`
|
||||||
|
// 负责处理中专任务的节点ID
|
||||||
|
NodeID uint `json:"node_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Props 获取任务属性
|
// Props 获取任务属性
|
||||||
|
|
@ -104,7 +108,24 @@ func (job *TransferTask) Do() {
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.WithValue(context.Background(), fsctx.DisableOverwrite, true)
|
ctx := context.WithValue(context.Background(), fsctx.DisableOverwrite, true)
|
||||||
err = fs.UploadFromPath(ctx, file, dst)
|
ctx = context.WithValue(ctx, fsctx.SlaveSrcPath, file)
|
||||||
|
if job.TaskProps.NodeID > 1 {
|
||||||
|
// 指定为从机中转
|
||||||
|
|
||||||
|
// 获取从机节点
|
||||||
|
node := cluster.Default.GetNodeByID(job.TaskProps.NodeID)
|
||||||
|
if node == nil {
|
||||||
|
job.SetErrorMsg("从机节点不可用", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 切换为从机节点处理上传
|
||||||
|
fs.SwitchToSlaveHandler(node)
|
||||||
|
err = fs.UploadFromStream(ctx, nil, dst, job.TaskProps.SrcSizes[file])
|
||||||
|
} else {
|
||||||
|
// 主机节点中转
|
||||||
|
err = fs.UploadFromPath(ctx, file, dst, true)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
job.SetErrorMsg("文件转存失败", err)
|
job.SetErrorMsg("文件转存失败", err)
|
||||||
}
|
}
|
||||||
|
|
@ -114,15 +135,16 @@ func (job *TransferTask) Do() {
|
||||||
|
|
||||||
// Recycle 回收临时文件
|
// Recycle 回收临时文件
|
||||||
func (job *TransferTask) Recycle() {
|
func (job *TransferTask) Recycle() {
|
||||||
|
if job.TaskProps.NodeID == 1 {
|
||||||
err := os.RemoveAll(job.TaskProps.Parent)
|
err := os.RemoveAll(job.TaskProps.Parent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.Log().Warning("无法删除中转临时目录[%s], %s", job.TaskProps.Parent, err)
|
util.Log().Warning("无法删除中转临时目录[%s], %s", job.TaskProps.Parent, err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTransferTask 新建中转任务
|
// NewTransferTask 新建中转任务
|
||||||
func NewTransferTask(user uint, src []string, dst, parent string, trim bool) (Job, error) {
|
func NewTransferTask(user uint, src []string, dst, parent string, trim bool, node uint, sizes map[string]uint64) (Job, error) {
|
||||||
creator, err := model.GetActiveUserByID(user)
|
creator, err := model.GetActiveUserByID(user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -135,6 +157,8 @@ func NewTransferTask(user uint, src []string, dst, parent string, trim bool) (Jo
|
||||||
Parent: parent,
|
Parent: parent,
|
||||||
Dst: dst,
|
Dst: dst,
|
||||||
TrimPath: trim,
|
TrimPath: trim,
|
||||||
|
NodeID: node,
|
||||||
|
SrcSizes: sizes,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ package controllers
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/email"
|
"github.com/cloudreve/Cloudreve/v3/pkg/email"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||||
|
|
@ -24,7 +25,7 @@ func AdminSummary(c *gin.Context) {
|
||||||
|
|
||||||
// AdminNews 获取社区新闻
|
// AdminNews 获取社区新闻
|
||||||
func AdminNews(c *gin.Context) {
|
func AdminNews(c *gin.Context) {
|
||||||
r := request.HTTPClient{}
|
r := request.NewClient()
|
||||||
res := r.Request("GET", "https://forum.cloudreve.org/api/discussions?include=startUser%2ClastUser%2CstartPost%2Ctags&filter%5Bq%5D=%20tag%3Anotice&sort=-startTime&page%5Blimit%5D=10", nil)
|
res := r.Request("GET", "https://forum.cloudreve.org/api/discussions?include=startUser%2ClastUser%2CstartPost%2Ctags&filter%5Bq%5D=%20tag%3Anotice&sort=-startTime&page%5Blimit%5D=10", nil)
|
||||||
if res.Err == nil {
|
if res.Err == nil {
|
||||||
io.Copy(c.Writer, res.Response.Body)
|
io.Copy(c.Writer, res.Response.Body)
|
||||||
|
|
@ -92,7 +93,13 @@ func AdminSendTestMail(c *gin.Context) {
|
||||||
func AdminTestAria2(c *gin.Context) {
|
func AdminTestAria2(c *gin.Context) {
|
||||||
var service admin.Aria2TestService
|
var service admin.Aria2TestService
|
||||||
if err := c.ShouldBindJSON(&service); err == nil {
|
if err := c.ShouldBindJSON(&service); err == nil {
|
||||||
res := service.Test()
|
var res serializer.Response
|
||||||
|
if service.Type == model.MasterNodeType {
|
||||||
|
res = service.TestMaster()
|
||||||
|
} else {
|
||||||
|
res = service.TestSlave()
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(200, res)
|
c.JSON(200, res)
|
||||||
} else {
|
} else {
|
||||||
c.JSON(200, ErrorResponse(err))
|
c.JSON(200, ErrorResponse(err))
|
||||||
|
|
@ -425,3 +432,58 @@ func AdminListFolders(c *gin.Context) {
|
||||||
c.JSON(200, ErrorResponse(err))
|
c.JSON(200, ErrorResponse(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AdminListNodes 列出从机节点
|
||||||
|
func AdminListNodes(c *gin.Context) {
|
||||||
|
var service admin.AdminListService
|
||||||
|
if err := c.ShouldBindJSON(&service); err == nil {
|
||||||
|
res := service.Nodes()
|
||||||
|
c.JSON(200, res)
|
||||||
|
} else {
|
||||||
|
c.JSON(200, ErrorResponse(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdminAddNode 新建节点
|
||||||
|
func AdminAddNode(c *gin.Context) {
|
||||||
|
var service admin.AddNodeService
|
||||||
|
if err := c.ShouldBindJSON(&service); err == nil {
|
||||||
|
res := service.Add()
|
||||||
|
c.JSON(200, res)
|
||||||
|
} else {
|
||||||
|
c.JSON(200, ErrorResponse(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdminToggleNode 启用/暂停节点
|
||||||
|
func AdminToggleNode(c *gin.Context) {
|
||||||
|
var service admin.ToggleNodeService
|
||||||
|
if err := c.ShouldBindUri(&service); err == nil {
|
||||||
|
res := service.Toggle()
|
||||||
|
c.JSON(200, res)
|
||||||
|
} else {
|
||||||
|
c.JSON(200, ErrorResponse(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdminDeleteGroup 删除用户组
|
||||||
|
func AdminDeleteNode(c *gin.Context) {
|
||||||
|
var service admin.NodeService
|
||||||
|
if err := c.ShouldBindUri(&service); err == nil {
|
||||||
|
res := service.Delete()
|
||||||
|
c.JSON(200, res)
|
||||||
|
} else {
|
||||||
|
c.JSON(200, ErrorResponse(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdminGetNode 获取节点详情
|
||||||
|
func AdminGetNode(c *gin.Context) {
|
||||||
|
var service admin.NodeService
|
||||||
|
if err := c.ShouldBindUri(&service); err == nil {
|
||||||
|
res := service.Get()
|
||||||
|
c.JSON(200, res)
|
||||||
|
} else {
|
||||||
|
c.JSON(200, ErrorResponse(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ package controllers
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
ariaCall "github.com/cloudreve/Cloudreve/v3/pkg/aria2"
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||||
"github.com/cloudreve/Cloudreve/v3/service/aria2"
|
"github.com/cloudreve/Cloudreve/v3/service/aria2"
|
||||||
"github.com/cloudreve/Cloudreve/v3/service/explorer"
|
"github.com/cloudreve/Cloudreve/v3/service/explorer"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
|
@ -13,7 +13,7 @@ import (
|
||||||
func AddAria2URL(c *gin.Context) {
|
func AddAria2URL(c *gin.Context) {
|
||||||
var addService aria2.AddURLService
|
var addService aria2.AddURLService
|
||||||
if err := c.ShouldBindJSON(&addService); err == nil {
|
if err := c.ShouldBindJSON(&addService); err == nil {
|
||||||
res := addService.Add(c, ariaCall.URLTask)
|
res := addService.Add(c, common.URLTask)
|
||||||
c.JSON(200, res)
|
c.JSON(200, res)
|
||||||
} else {
|
} else {
|
||||||
c.JSON(200, ErrorResponse(err))
|
c.JSON(200, ErrorResponse(err))
|
||||||
|
|
@ -52,7 +52,7 @@ func AddAria2Torrent(c *gin.Context) {
|
||||||
|
|
||||||
if err := c.ShouldBindJSON(&addService); err == nil {
|
if err := c.ShouldBindJSON(&addService); err == nil {
|
||||||
addService.URL = res.Data.(string)
|
addService.URL = res.Data.(string)
|
||||||
res := addService.Add(c, ariaCall.URLTask)
|
res := addService.Add(c, common.URLTask)
|
||||||
c.JSON(200, res)
|
c.JSON(200, res)
|
||||||
} else {
|
} else {
|
||||||
c.JSON(200, ErrorResponse(err))
|
c.JSON(200, ErrorResponse(err))
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,9 @@ import (
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||||
"github.com/cloudreve/Cloudreve/v3/service/admin"
|
"github.com/cloudreve/Cloudreve/v3/service/admin"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/service/aria2"
|
||||||
"github.com/cloudreve/Cloudreve/v3/service/explorer"
|
"github.com/cloudreve/Cloudreve/v3/service/explorer"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/service/node"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -175,3 +177,102 @@ func SlaveList(c *gin.Context) {
|
||||||
c.JSON(200, ErrorResponse(err))
|
c.JSON(200, ErrorResponse(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SlaveHeartbeat 接受主机心跳包
|
||||||
|
func SlaveHeartbeat(c *gin.Context) {
|
||||||
|
var service serializer.NodePingReq
|
||||||
|
if err := c.ShouldBindJSON(&service); err == nil {
|
||||||
|
res := node.HandleMasterHeartbeat(&service)
|
||||||
|
c.JSON(200, res)
|
||||||
|
} else {
|
||||||
|
c.JSON(200, ErrorResponse(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SlaveAria2Create 创建 Aria2 任务
|
||||||
|
func SlaveAria2Create(c *gin.Context) {
|
||||||
|
var service serializer.SlaveAria2Call
|
||||||
|
if err := c.ShouldBindJSON(&service); err == nil {
|
||||||
|
res := aria2.Add(c, &service)
|
||||||
|
c.JSON(200, res)
|
||||||
|
} else {
|
||||||
|
c.JSON(200, ErrorResponse(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SlaveAria2Status 查询从机 Aria2 任务状态
|
||||||
|
func SlaveAria2Status(c *gin.Context) {
|
||||||
|
var service serializer.SlaveAria2Call
|
||||||
|
if err := c.ShouldBindJSON(&service); err == nil {
|
||||||
|
res := aria2.SlaveStatus(c, &service)
|
||||||
|
c.JSON(200, res)
|
||||||
|
} else {
|
||||||
|
c.JSON(200, ErrorResponse(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SlaveCancelAria2Task 取消从机离线下载任务
|
||||||
|
func SlaveCancelAria2Task(c *gin.Context) {
|
||||||
|
var service serializer.SlaveAria2Call
|
||||||
|
if err := c.ShouldBindJSON(&service); err == nil {
|
||||||
|
res := aria2.SlaveCancel(c, &service)
|
||||||
|
c.JSON(200, res)
|
||||||
|
} else {
|
||||||
|
c.JSON(200, ErrorResponse(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SlaveSelectTask 从机选取离线下载文件
|
||||||
|
func SlaveSelectTask(c *gin.Context) {
|
||||||
|
var service serializer.SlaveAria2Call
|
||||||
|
if err := c.ShouldBindJSON(&service); err == nil {
|
||||||
|
res := aria2.SlaveSelect(c, &service)
|
||||||
|
c.JSON(200, res)
|
||||||
|
} else {
|
||||||
|
c.JSON(200, ErrorResponse(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SlaveCreateTransferTask 从机创建中转任务
|
||||||
|
func SlaveCreateTransferTask(c *gin.Context) {
|
||||||
|
var service serializer.SlaveTransferReq
|
||||||
|
if err := c.ShouldBindJSON(&service); err == nil {
|
||||||
|
res := explorer.CreateTransferTask(c, &service)
|
||||||
|
c.JSON(200, res)
|
||||||
|
} else {
|
||||||
|
c.JSON(200, ErrorResponse(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SlaveNotificationPush 处理从机发送的消息推送
|
||||||
|
func SlaveNotificationPush(c *gin.Context) {
|
||||||
|
var service node.SlaveNotificationService
|
||||||
|
if err := c.ShouldBindUri(&service); err == nil {
|
||||||
|
res := service.HandleSlaveNotificationPush(c)
|
||||||
|
c.JSON(200, res)
|
||||||
|
} else {
|
||||||
|
c.JSON(200, ErrorResponse(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SlaveGetOneDriveCredential 从机获取主机的OneDrive存储策略凭证
|
||||||
|
func SlaveGetOneDriveCredential(c *gin.Context) {
|
||||||
|
var service node.OneDriveCredentialService
|
||||||
|
if err := c.ShouldBindUri(&service); err == nil {
|
||||||
|
res := service.Get(c)
|
||||||
|
c.JSON(200, res)
|
||||||
|
} else {
|
||||||
|
c.JSON(200, ErrorResponse(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SlaveSelectTask 从机删除离线下载临时文件
|
||||||
|
func SlaveDeleteTempFile(c *gin.Context) {
|
||||||
|
var service serializer.SlaveAria2Call
|
||||||
|
if err := c.ShouldBindJSON(&service); err == nil {
|
||||||
|
res := aria2.SlaveDeleteTemp(c, &service)
|
||||||
|
c.JSON(200, res)
|
||||||
|
} else {
|
||||||
|
c.JSON(200, ErrorResponse(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ package routers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/cloudreve/Cloudreve/v3/middleware"
|
"github.com/cloudreve/Cloudreve/v3/middleware"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
|
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||||
|
|
@ -29,7 +30,9 @@ func InitSlaveRouter() *gin.Engine {
|
||||||
InitCORS(r)
|
InitCORS(r)
|
||||||
v3 := r.Group("/api/v3/slave")
|
v3 := r.Group("/api/v3/slave")
|
||||||
// 鉴权中间件
|
// 鉴权中间件
|
||||||
v3.Use(middleware.SignRequired())
|
v3.Use(middleware.SignRequired(auth.General))
|
||||||
|
// 主机信息解析
|
||||||
|
v3.Use(middleware.MasterMetadata())
|
||||||
|
|
||||||
/*
|
/*
|
||||||
路由
|
路由
|
||||||
|
|
@ -37,6 +40,10 @@ func InitSlaveRouter() *gin.Engine {
|
||||||
{
|
{
|
||||||
// Ping
|
// Ping
|
||||||
v3.POST("ping", controllers.SlavePing)
|
v3.POST("ping", controllers.SlavePing)
|
||||||
|
// 测试 Aria2 RPC 连接
|
||||||
|
v3.POST("ping/aria2", controllers.AdminTestAria2)
|
||||||
|
// 接收主机心跳包
|
||||||
|
v3.POST("heartbeat", controllers.SlaveHeartbeat)
|
||||||
// 上传
|
// 上传
|
||||||
v3.POST("upload", controllers.SlaveUpload)
|
v3.POST("upload", controllers.SlaveUpload)
|
||||||
// 下载
|
// 下载
|
||||||
|
|
@ -49,6 +56,28 @@ func InitSlaveRouter() *gin.Engine {
|
||||||
v3.POST("delete", controllers.SlaveDelete)
|
v3.POST("delete", controllers.SlaveDelete)
|
||||||
// 列出文件
|
// 列出文件
|
||||||
v3.POST("list", controllers.SlaveList)
|
v3.POST("list", controllers.SlaveList)
|
||||||
|
|
||||||
|
// 离线下载
|
||||||
|
aria2 := v3.Group("aria2")
|
||||||
|
aria2.Use(middleware.UseSlaveAria2Instance())
|
||||||
|
{
|
||||||
|
// 创建离线下载任务
|
||||||
|
aria2.POST("task", controllers.SlaveAria2Create)
|
||||||
|
// 获取任务状态
|
||||||
|
aria2.POST("status", controllers.SlaveAria2Status)
|
||||||
|
// 取消离线下载任务
|
||||||
|
aria2.POST("cancel", controllers.SlaveCancelAria2Task)
|
||||||
|
// 选取任务文件
|
||||||
|
aria2.POST("select", controllers.SlaveSelectTask)
|
||||||
|
// 删除任务临时文件
|
||||||
|
aria2.POST("delete", controllers.SlaveDeleteTempFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 异步任务
|
||||||
|
task := v3.Group("task")
|
||||||
|
{
|
||||||
|
task.PUT("transfer", controllers.SlaveCreateTransferTask)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
@ -131,7 +160,7 @@ func InitMasterRouter() *gin.Engine {
|
||||||
user.PATCH("reset", controllers.UserReset)
|
user.PATCH("reset", controllers.UserReset)
|
||||||
// 邮件激活
|
// 邮件激活
|
||||||
user.GET("activate/:id",
|
user.GET("activate/:id",
|
||||||
middleware.SignRequired(),
|
middleware.SignRequired(auth.General),
|
||||||
middleware.HashID(hashid.UserID),
|
middleware.HashID(hashid.UserID),
|
||||||
controllers.UserActivate,
|
controllers.UserActivate,
|
||||||
)
|
)
|
||||||
|
|
@ -159,7 +188,7 @@ func InitMasterRouter() *gin.Engine {
|
||||||
|
|
||||||
// 需要携带签名验证的
|
// 需要携带签名验证的
|
||||||
sign := v3.Group("")
|
sign := v3.Group("")
|
||||||
sign.Use(middleware.SignRequired())
|
sign.Use(middleware.SignRequired(auth.General))
|
||||||
{
|
{
|
||||||
file := sign.Group("file")
|
file := sign.Group("file")
|
||||||
{
|
{
|
||||||
|
|
@ -174,6 +203,18 @@ func InitMasterRouter() *gin.Engine {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 从机的 RPC 通信
|
||||||
|
slave := v3.Group("slave")
|
||||||
|
slave.Use(middleware.SlaveRPCSignRequired())
|
||||||
|
{
|
||||||
|
// 事件通知
|
||||||
|
slave.PUT("notification/:subject", controllers.SlaveNotificationPush)
|
||||||
|
// 上传
|
||||||
|
slave.POST("upload", controllers.SlaveUpload)
|
||||||
|
// OneDrive 存储策略凭证
|
||||||
|
slave.GET("credential/onedrive/:id", controllers.SlaveGetOneDriveCredential)
|
||||||
|
}
|
||||||
|
|
||||||
// 回调接口
|
// 回调接口
|
||||||
callback := v3.Group("callback")
|
callback := v3.Group("callback")
|
||||||
{
|
{
|
||||||
|
|
@ -405,6 +446,22 @@ func InitMasterRouter() *gin.Engine {
|
||||||
task.POST("import", controllers.AdminCreateImportTask)
|
task.POST("import", controllers.AdminCreateImportTask)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
node := admin.Group("node")
|
||||||
|
{
|
||||||
|
// 列出从机节点
|
||||||
|
node.POST("list", controllers.AdminListNodes)
|
||||||
|
// 列出从机节点
|
||||||
|
node.POST("aria2/test", controllers.AdminTestAria2)
|
||||||
|
// 创建/保存节点
|
||||||
|
node.POST("", controllers.AdminAddNode)
|
||||||
|
// 启用/暂停节点
|
||||||
|
node.PATCH("enable/:id/:desired", controllers.AdminToggleNode)
|
||||||
|
// 删除节点
|
||||||
|
node.DELETE(":id", controllers.AdminDeleteNode)
|
||||||
|
// 获取节点
|
||||||
|
node.GET(":id", controllers.AdminGetNode)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 用户
|
// 用户
|
||||||
|
|
|
||||||
|
|
@ -1,43 +1,71 @@
|
||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Aria2TestService aria2连接测试服务
|
// Aria2TestService aria2连接测试服务
|
||||||
type Aria2TestService struct {
|
type Aria2TestService struct {
|
||||||
Server string `json:"server" binding:"required"`
|
Server string `json:"server" binding:"required"`
|
||||||
|
RPC string `json:"rpc" binding:"required"`
|
||||||
|
Secret string `json:"secret" binding:"required"`
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
|
Type model.ModelType `json:"type"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test 测试aria2连接
|
// Test 测试aria2连接
|
||||||
func (service *Aria2TestService) Test() serializer.Response {
|
func (service *Aria2TestService) TestMaster() serializer.Response {
|
||||||
testRPC := aria2.RPCService{}
|
res, err := aria2.TestRPCConnection(service.RPC, service.Token, 5)
|
||||||
|
|
||||||
// 解析RPC服务地址
|
|
||||||
server, err := url.Parse(service.Server)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return serializer.ParamErr("无法解析 aria2 RPC 服务地址, "+err.Error(), nil)
|
return serializer.ParamErr(err.Error(), err)
|
||||||
}
|
|
||||||
server.Path = "/jsonrpc"
|
|
||||||
|
|
||||||
if err := testRPC.Init(server.String(), service.Token, 5, map[string]interface{}{}); err != nil {
|
|
||||||
return serializer.ParamErr("无法初始化连接, "+err.Error(), nil)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
defer testRPC.Caller.Close()
|
if res.Version == "" {
|
||||||
|
|
||||||
info, err := testRPC.Caller.GetVersion()
|
|
||||||
if err != nil {
|
|
||||||
return serializer.ParamErr("无法请求 RPC 服务, "+err.Error(), nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
if info.Version == "" {
|
|
||||||
return serializer.ParamErr("RPC 服务返回非预期响应", nil)
|
return serializer.ParamErr("RPC 服务返回非预期响应", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
return serializer.Response{Data: info.Version}
|
return serializer.Response{Data: res.Version}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *Aria2TestService) TestSlave() serializer.Response {
|
||||||
|
slave, err := url.Parse(service.Server)
|
||||||
|
if err != nil {
|
||||||
|
return serializer.ParamErr("无法解析从机端地址,"+err.Error(), nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
controller, _ := url.Parse("/api/v3/slave/ping/aria2")
|
||||||
|
|
||||||
|
// 请求正文
|
||||||
|
service.Type = model.MasterNodeType
|
||||||
|
bodyByte, _ := json.Marshal(service)
|
||||||
|
|
||||||
|
r := request.NewClient()
|
||||||
|
res, err := r.Request(
|
||||||
|
"POST",
|
||||||
|
slave.ResolveReference(controller).String(),
|
||||||
|
bytes.NewReader(bodyByte),
|
||||||
|
request.WithTimeout(time.Duration(10)*time.Second),
|
||||||
|
request.WithCredential(
|
||||||
|
auth.HMACAuth{SecretKey: []byte(service.Secret)},
|
||||||
|
int64(model.GetIntSetting("slave_api_timeout", 60)),
|
||||||
|
),
|
||||||
|
).DecodeResponse()
|
||||||
|
if err != nil {
|
||||||
|
return serializer.ParamErr("无连接到从机,"+err.Error(), nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Code != 0 {
|
||||||
|
return serializer.ParamErr("成功接到从机,但是从机返回:"+res.Msg, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return serializer.Response{Data: res.Data.(string)}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,138 @@
|
||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AddNodeService 节点添加服务
|
||||||
|
type AddNodeService struct {
|
||||||
|
Node model.Node `json:"node" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add 添加节点
|
||||||
|
func (service *AddNodeService) Add() serializer.Response {
|
||||||
|
if service.Node.ID > 0 {
|
||||||
|
if err := model.DB.Save(&service.Node).Error; err != nil {
|
||||||
|
return serializer.ParamErr("节点保存失败", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := model.DB.Create(&service.Node).Error; err != nil {
|
||||||
|
return serializer.ParamErr("节点添加失败", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return serializer.Response{Data: service.Node.ID}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Nodes 列出从机节点
|
||||||
|
func (service *AdminListService) Nodes() serializer.Response {
|
||||||
|
var res []model.Node
|
||||||
|
total := 0
|
||||||
|
|
||||||
|
tx := model.DB.Model(&model.Node{})
|
||||||
|
if service.OrderBy != "" {
|
||||||
|
tx = tx.Order(service.OrderBy)
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range service.Conditions {
|
||||||
|
tx = tx.Where(k+" = ?", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(service.Searches) > 0 {
|
||||||
|
search := ""
|
||||||
|
for k, v := range service.Searches {
|
||||||
|
search += k + " like '%" + v + "%' OR "
|
||||||
|
}
|
||||||
|
search = strings.TrimSuffix(search, " OR ")
|
||||||
|
tx = tx.Where(search)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算总数用于分页
|
||||||
|
tx.Count(&total)
|
||||||
|
|
||||||
|
// 查询记录
|
||||||
|
tx.Limit(service.PageSize).Offset((service.Page - 1) * service.PageSize).Find(&res)
|
||||||
|
|
||||||
|
isActive := make(map[uint]bool)
|
||||||
|
for i := 0; i < len(res); i++ {
|
||||||
|
if node := cluster.Default.GetNodeByID(res[i].ID); node != nil {
|
||||||
|
isActive[res[i].ID] = node.IsActive()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return serializer.Response{Data: map[string]interface{}{
|
||||||
|
"total": total,
|
||||||
|
"items": res,
|
||||||
|
"active": isActive,
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToggleNodeService 开关节点服务
|
||||||
|
type ToggleNodeService struct {
|
||||||
|
ID uint `uri:"id"`
|
||||||
|
Desired model.NodeStatus `uri:"desired"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Toggle 开关节点
|
||||||
|
func (service *ToggleNodeService) Toggle() serializer.Response {
|
||||||
|
node, err := model.GetNodeByID(service.ID)
|
||||||
|
if err != nil {
|
||||||
|
return serializer.DBErr("找不到节点", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 是否为系统节点
|
||||||
|
if node.ID <= 1 {
|
||||||
|
return serializer.Err(serializer.CodeNoPermissionErr, "系统节点无法更改", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = node.SetStatus(service.Desired); err != nil {
|
||||||
|
return serializer.DBErr("无法更改节点状态", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if service.Desired == model.NodeActive {
|
||||||
|
cluster.Default.Add(&node)
|
||||||
|
} else {
|
||||||
|
cluster.Default.Delete(node.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return serializer.Response{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NodeService 节点ID服务
|
||||||
|
type NodeService struct {
|
||||||
|
ID uint `uri:"id" json:"id" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 删除节点
|
||||||
|
func (service *NodeService) Delete() serializer.Response {
|
||||||
|
// 查找用户组
|
||||||
|
node, err := model.GetNodeByID(service.ID)
|
||||||
|
if err != nil {
|
||||||
|
return serializer.Err(serializer.CodeNotFound, "节点不存在", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 是否为系统节点
|
||||||
|
if node.ID <= 1 {
|
||||||
|
return serializer.Err(serializer.CodeNoPermissionErr, "系统节点无法删除", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cluster.Default.Delete(node.ID)
|
||||||
|
if err := model.DB.Delete(&node).Error; err != nil {
|
||||||
|
return serializer.DBErr("无法删除节点", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return serializer.Response{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get 获取节点详情
|
||||||
|
func (service *NodeService) Get() serializer.Response {
|
||||||
|
node, err := model.GetNodeByID(service.ID)
|
||||||
|
if err != nil {
|
||||||
|
return serializer.Err(serializer.CodeNotFound, "节点不存在", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return serializer.Response{Data: node}
|
||||||
|
}
|
||||||
|
|
@ -151,7 +151,7 @@ func (service *PolicyService) AddCORS() serializer.Response {
|
||||||
case "oss":
|
case "oss":
|
||||||
handler := oss.Driver{
|
handler := oss.Driver{
|
||||||
Policy: &policy,
|
Policy: &policy,
|
||||||
HTTPClient: request.HTTPClient{},
|
HTTPClient: request.NewClient(),
|
||||||
}
|
}
|
||||||
if err := handler.CORS(); err != nil {
|
if err := handler.CORS(); err != nil {
|
||||||
return serializer.Err(serializer.CodeInternalSetting, "跨域策略添加失败", err)
|
return serializer.Err(serializer.CodeInternalSetting, "跨域策略添加失败", err)
|
||||||
|
|
@ -161,7 +161,7 @@ func (service *PolicyService) AddCORS() serializer.Response {
|
||||||
b := &cossdk.BaseURL{BucketURL: u}
|
b := &cossdk.BaseURL{BucketURL: u}
|
||||||
handler := cos.Driver{
|
handler := cos.Driver{
|
||||||
Policy: &policy,
|
Policy: &policy,
|
||||||
HTTPClient: request.HTTPClient{},
|
HTTPClient: request.NewClient(),
|
||||||
Client: cossdk.NewClient(b, &http.Client{
|
Client: cossdk.NewClient(b, &http.Client{
|
||||||
Transport: &cossdk.AuthorizationTransport{
|
Transport: &cossdk.AuthorizationTransport{
|
||||||
SecretID: policy.AccessKey,
|
SecretID: policy.AccessKey,
|
||||||
|
|
@ -195,7 +195,7 @@ func (service *SlavePingService) Test() serializer.Response {
|
||||||
|
|
||||||
controller, _ := url.Parse("/api/v3/site/ping")
|
controller, _ := url.Parse("/api/v3/site/ping")
|
||||||
|
|
||||||
r := request.HTTPClient{}
|
r := request.NewClient()
|
||||||
res, err := r.Request(
|
res, err := r.Request(
|
||||||
"GET",
|
"GET",
|
||||||
master.ResolveReference(controller).String(),
|
master.ResolveReference(controller).String(),
|
||||||
|
|
@ -229,7 +229,7 @@ func (service *SlaveTestService) Test() serializer.Response {
|
||||||
}
|
}
|
||||||
bodyByte, _ := json.Marshal(body)
|
bodyByte, _ := json.Marshal(body)
|
||||||
|
|
||||||
r := request.HTTPClient{}
|
r := request.NewClient()
|
||||||
res, err := r.Request(
|
res, err := r.Request(
|
||||||
"POST",
|
"POST",
|
||||||
slave.ResolveReference(controller).String(),
|
slave.ResolveReference(controller).String(),
|
||||||
|
|
@ -245,7 +245,7 @@ func (service *SlaveTestService) Test() serializer.Response {
|
||||||
}
|
}
|
||||||
|
|
||||||
if res.Code != 0 {
|
if res.Code != 0 {
|
||||||
return serializer.ParamErr("成功接到从机,但是"+res.Msg, nil)
|
return serializer.ParamErr("成功接到从机,但是从机返回:"+res.Msg, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
return serializer.Response{}
|
return serializer.Response{}
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,14 @@ package aria2
|
||||||
import (
|
import (
|
||||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -14,7 +20,7 @@ type AddURLService struct {
|
||||||
Dst string `json:"dst" binding:"required,min=1"`
|
Dst string `json:"dst" binding:"required,min=1"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add 创建新的链接离线下载任务
|
// Add 主机创建新的链接离线下载任务
|
||||||
func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Response {
|
func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Response {
|
||||||
// 创建文件系统
|
// 创建文件系统
|
||||||
fs, err := filesystem.NewFileSystemFromContext(c)
|
fs, err := filesystem.NewFileSystemFromContext(c)
|
||||||
|
|
@ -35,19 +41,60 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo
|
||||||
|
|
||||||
// 创建任务
|
// 创建任务
|
||||||
task := &model.Download{
|
task := &model.Download{
|
||||||
Status: aria2.Ready,
|
Status: common.Ready,
|
||||||
Type: taskType,
|
Type: taskType,
|
||||||
Dst: service.Dst,
|
Dst: service.Dst,
|
||||||
UserID: fs.User.ID,
|
UserID: fs.User.ID,
|
||||||
Source: service.URL,
|
Source: service.URL,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 获取 Aria2 负载均衡器
|
||||||
aria2.Lock.RLock()
|
aria2.Lock.RLock()
|
||||||
if err := aria2.Instance.CreateTask(task, fs.User.Group.OptionsSerialized.Aria2Options); err != nil {
|
lb := aria2.LB
|
||||||
aria2.Lock.RUnlock()
|
aria2.Lock.RUnlock()
|
||||||
|
|
||||||
|
// 获取 Aria2 实例
|
||||||
|
err, node := cluster.Default.BalanceNodeByFeature("aria2", lb)
|
||||||
|
if err != nil {
|
||||||
|
return serializer.Err(serializer.CodeInternalSetting, "Aria2 实例获取失败", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建任务
|
||||||
|
gid, err := node.GetAria2Instance().CreateTask(task, fs.User.Group.OptionsSerialized.Aria2Options)
|
||||||
|
if err != nil {
|
||||||
return serializer.Err(serializer.CodeNotSet, "任务创建失败", err)
|
return serializer.Err(serializer.CodeNotSet, "任务创建失败", err)
|
||||||
}
|
}
|
||||||
aria2.Lock.RUnlock()
|
|
||||||
|
task.GID = gid
|
||||||
|
task.NodeID = node.ID()
|
||||||
|
_, err = task.Create()
|
||||||
|
if err != nil {
|
||||||
|
return serializer.DBErr("任务创建失败", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建任务监控
|
||||||
|
monitor.NewMonitor(task)
|
||||||
|
|
||||||
return serializer.Response{}
|
return serializer.Response{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add 从机创建新的链接离线下载任务
|
||||||
|
func Add(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response {
|
||||||
|
caller, _ := c.Get("MasterAria2Instance")
|
||||||
|
|
||||||
|
// 创建任务
|
||||||
|
gid, err := caller.(common.Aria2).CreateTask(service.Task, service.GroupOptions)
|
||||||
|
if err != nil {
|
||||||
|
return serializer.Err(serializer.CodeInternalSetting, "无法创建离线下载任务", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建事件通知回调
|
||||||
|
siteID, _ := c.Get("MasterSiteID")
|
||||||
|
mq.GlobalMQ.SubscribeCallback(gid, func(message mq.Message) {
|
||||||
|
if err := slave.DefaultController.SendNotification(siteID.(string), message.TriggeredBy, message); err != nil {
|
||||||
|
util.Log().Warning("无法发送离线下载任务状态变更通知, %s", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return serializer.Response{Data: gid}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,8 @@ package aria2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
@ -25,14 +26,14 @@ type DownloadListService struct {
|
||||||
// Finished 获取已完成的任务
|
// Finished 获取已完成的任务
|
||||||
func (service *DownloadListService) Finished(c *gin.Context, user *model.User) serializer.Response {
|
func (service *DownloadListService) Finished(c *gin.Context, user *model.User) serializer.Response {
|
||||||
// 查找下载记录
|
// 查找下载记录
|
||||||
downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, aria2.Error, aria2.Complete, aria2.Canceled, aria2.Unknown)
|
downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Error, common.Complete, common.Canceled, common.Unknown)
|
||||||
return serializer.BuildFinishedListResponse(downloads)
|
return serializer.BuildFinishedListResponse(downloads)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Downloading 获取正在下载中的任务
|
// Downloading 获取正在下载中的任务
|
||||||
func (service *DownloadListService) Downloading(c *gin.Context, user *model.User) serializer.Response {
|
func (service *DownloadListService) Downloading(c *gin.Context, user *model.User) serializer.Response {
|
||||||
// 查找下载记录
|
// 查找下载记录
|
||||||
downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, aria2.Downloading, aria2.Paused, aria2.Ready)
|
downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Downloading, common.Paused, common.Ready)
|
||||||
return serializer.BuildDownloadingResponse(downloads)
|
return serializer.BuildDownloadingResponse(downloads)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -47,7 +48,7 @@ func (service *DownloadTaskService) Delete(c *gin.Context) serializer.Response {
|
||||||
return serializer.Err(serializer.CodeNotFound, "下载记录不存在", err)
|
return serializer.Err(serializer.CodeNotFound, "下载记录不存在", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if download.Status >= aria2.Error {
|
if download.Status >= common.Error {
|
||||||
// 如果任务已完成,则删除任务记录
|
// 如果任务已完成,则删除任务记录
|
||||||
if err := download.Delete(); err != nil {
|
if err := download.Delete(); err != nil {
|
||||||
return serializer.Err(serializer.CodeDBError, "任务记录删除失败", err)
|
return serializer.Err(serializer.CodeDBError, "任务记录删除失败", err)
|
||||||
|
|
@ -56,9 +57,12 @@ func (service *DownloadTaskService) Delete(c *gin.Context) serializer.Response {
|
||||||
}
|
}
|
||||||
|
|
||||||
// 取消任务
|
// 取消任务
|
||||||
aria2.Lock.RLock()
|
node := cluster.Default.GetNodeByID(download.GetNodeID())
|
||||||
defer aria2.Lock.RUnlock()
|
if node == nil {
|
||||||
if err := aria2.Instance.Cancel(download); err != nil {
|
return serializer.Err(serializer.CodeInternalSetting, "目标节点不可用", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := node.GetAria2Instance().Cancel(download); err != nil {
|
||||||
return serializer.Err(serializer.CodeNotSet, "操作失败", err)
|
return serializer.Err(serializer.CodeNotSet, "操作失败", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -76,17 +80,72 @@ func (service *SelectFileService) Select(c *gin.Context) serializer.Response {
|
||||||
return serializer.Err(serializer.CodeNotFound, "下载记录不存在", err)
|
return serializer.Err(serializer.CodeNotFound, "下载记录不存在", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if download.StatusInfo.BitTorrent.Mode != "multi" || (download.Status != aria2.Downloading && download.Status != aria2.Paused) {
|
if download.StatusInfo.BitTorrent.Mode != "multi" || (download.Status != common.Downloading && download.Status != common.Paused) {
|
||||||
return serializer.Err(serializer.CodeNoPermissionErr, "此下载任务无法选取文件", err)
|
return serializer.Err(serializer.CodeNoPermissionErr, "此下载任务无法选取文件", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 选取下载
|
// 选取下载
|
||||||
aria2.Lock.RLock()
|
node := cluster.Default.GetNodeByID(download.GetNodeID())
|
||||||
defer aria2.Lock.RUnlock()
|
if err := node.GetAria2Instance().Select(download, service.Indexes); err != nil {
|
||||||
if err := aria2.Instance.Select(download, service.Indexes); err != nil {
|
|
||||||
return serializer.Err(serializer.CodeNotSet, "操作失败", err)
|
return serializer.Err(serializer.CodeNotSet, "操作失败", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return serializer.Response{}
|
return serializer.Response{}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SlaveStatus 从机查询离线任务状态
|
||||||
|
func SlaveStatus(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response {
|
||||||
|
caller, _ := c.Get("MasterAria2Instance")
|
||||||
|
|
||||||
|
// 查询任务
|
||||||
|
status, err := caller.(common.Aria2).Status(service.Task)
|
||||||
|
if err != nil {
|
||||||
|
return serializer.Err(serializer.CodeInternalSetting, "离线下载任务查询失败", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return serializer.NewResponseWithGobData(status)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// SlaveCancel 取消从机离线下载任务
|
||||||
|
func SlaveCancel(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response {
|
||||||
|
caller, _ := c.Get("MasterAria2Instance")
|
||||||
|
|
||||||
|
// 查询任务
|
||||||
|
err := caller.(common.Aria2).Cancel(service.Task)
|
||||||
|
if err != nil {
|
||||||
|
return serializer.Err(serializer.CodeInternalSetting, "任务取消失败", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return serializer.Response{}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// SlaveSelect 从机选取离线下载任务文件
|
||||||
|
func SlaveSelect(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response {
|
||||||
|
caller, _ := c.Get("MasterAria2Instance")
|
||||||
|
|
||||||
|
// 查询任务
|
||||||
|
err := caller.(common.Aria2).Select(service.Task, service.Files)
|
||||||
|
if err != nil {
|
||||||
|
return serializer.Err(serializer.CodeInternalSetting, "任务选取失败", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return serializer.Response{}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// SlaveSelect 从机选取离线下载任务文件
|
||||||
|
func SlaveDeleteTemp(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response {
|
||||||
|
caller, _ := c.Get("MasterAria2Instance")
|
||||||
|
|
||||||
|
// 查询任务
|
||||||
|
err := caller.(common.Aria2).DeleteTempFile(service.Task)
|
||||||
|
if err != nil {
|
||||||
|
return serializer.Err(serializer.CodeInternalSetting, "临时文件删除失败", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return serializer.Response{}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ package explorer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
|
@ -20,7 +19,6 @@ import (
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/jinzhu/gorm"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// SingleFileService 对单文件进行操作的五福,path为文件完整路径
|
// SingleFileService 对单文件进行操作的五福,path为文件完整路径
|
||||||
|
|
@ -43,29 +41,6 @@ type DownloadService struct {
|
||||||
ID string `uri:"id" binding:"required"`
|
ID string `uri:"id" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// SlaveDownloadService 从机文件下載服务
|
|
||||||
type SlaveDownloadService struct {
|
|
||||||
PathEncoded string `uri:"path" binding:"required"`
|
|
||||||
Name string `uri:"name" binding:"required"`
|
|
||||||
Speed int `uri:"speed" binding:"min=0"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SlaveFileService 从机单文件文件相关服务
|
|
||||||
type SlaveFileService struct {
|
|
||||||
PathEncoded string `uri:"path" binding:"required"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SlaveFilesService 从机多文件相关服务
|
|
||||||
type SlaveFilesService struct {
|
|
||||||
Files []string `json:"files" binding:"required,gt=0"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SlaveListService 从机列表服务
|
|
||||||
type SlaveListService struct {
|
|
||||||
Path string `json:"path" binding:"required,min=1,max=65535"`
|
|
||||||
Recursive bool `json:"recursive"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// New 创建新文件
|
// New 创建新文件
|
||||||
func (service *SingleFileService) Create(c *gin.Context) serializer.Response {
|
func (service *SingleFileService) Create(c *gin.Context) serializer.Response {
|
||||||
// 创建文件系统
|
// 创建文件系统
|
||||||
|
|
@ -449,106 +424,3 @@ func (service *FileIDService) PutContent(ctx context.Context, c *gin.Context) se
|
||||||
Code: 0,
|
Code: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeFile 通过签名的URL下载从机文件
|
|
||||||
func (service *SlaveDownloadService) ServeFile(ctx context.Context, c *gin.Context, isDownload bool) serializer.Response {
|
|
||||||
// 创建文件系统
|
|
||||||
fs, err := filesystem.NewAnonymousFileSystem()
|
|
||||||
if err != nil {
|
|
||||||
return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)
|
|
||||||
}
|
|
||||||
defer fs.Recycle()
|
|
||||||
|
|
||||||
// 解码文件路径
|
|
||||||
fileSource, err := base64.RawURLEncoding.DecodeString(service.PathEncoded)
|
|
||||||
if err != nil {
|
|
||||||
return serializer.ParamErr("无法解析的文件地址", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 根据URL里的信息创建一个文件对象和用户对象
|
|
||||||
file := model.File{
|
|
||||||
Name: service.Name,
|
|
||||||
SourceName: string(fileSource),
|
|
||||||
Policy: model.Policy{
|
|
||||||
Model: gorm.Model{ID: 1},
|
|
||||||
Type: "local",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
fs.User = &model.User{
|
|
||||||
Group: model.Group{SpeedLimit: service.Speed},
|
|
||||||
}
|
|
||||||
fs.FileTarget = []model.File{file}
|
|
||||||
|
|
||||||
// 开始处理下载
|
|
||||||
ctx = context.WithValue(ctx, fsctx.GinCtx, c)
|
|
||||||
rs, err := fs.GetDownloadContent(ctx, 0)
|
|
||||||
if err != nil {
|
|
||||||
return serializer.Err(serializer.CodeNotSet, err.Error(), err)
|
|
||||||
}
|
|
||||||
defer rs.Close()
|
|
||||||
|
|
||||||
// 设置下载文件名
|
|
||||||
if isDownload {
|
|
||||||
c.Header("Content-Disposition", "attachment; filename=\""+url.PathEscape(fs.FileTarget[0].Name)+"\"")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 发送文件
|
|
||||||
http.ServeContent(c.Writer, c.Request, fs.FileTarget[0].Name, time.Now(), rs)
|
|
||||||
|
|
||||||
return serializer.Response{
|
|
||||||
Code: 0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete 通过签名的URL删除从机文件
|
|
||||||
func (service *SlaveFilesService) Delete(ctx context.Context, c *gin.Context) serializer.Response {
|
|
||||||
// 创建文件系统
|
|
||||||
fs, err := filesystem.NewAnonymousFileSystem()
|
|
||||||
if err != nil {
|
|
||||||
return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)
|
|
||||||
}
|
|
||||||
defer fs.Recycle()
|
|
||||||
|
|
||||||
// 删除文件
|
|
||||||
failed, err := fs.Handler.Delete(ctx, service.Files)
|
|
||||||
if err != nil {
|
|
||||||
// 将Data字段写为字符串方便主控端解析
|
|
||||||
data, _ := json.Marshal(serializer.RemoteDeleteRequest{Files: failed})
|
|
||||||
|
|
||||||
return serializer.Response{
|
|
||||||
Code: serializer.CodeNotFullySuccess,
|
|
||||||
Data: string(data),
|
|
||||||
Msg: fmt.Sprintf("有 %d 个文件未能成功删除", len(failed)),
|
|
||||||
Error: err.Error(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return serializer.Response{Code: 0}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Thumb 通过签名URL获取从机文件缩略图
|
|
||||||
func (service *SlaveFileService) Thumb(ctx context.Context, c *gin.Context) serializer.Response {
|
|
||||||
// 创建文件系统
|
|
||||||
fs, err := filesystem.NewAnonymousFileSystem()
|
|
||||||
if err != nil {
|
|
||||||
return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)
|
|
||||||
}
|
|
||||||
defer fs.Recycle()
|
|
||||||
|
|
||||||
// 解码文件路径
|
|
||||||
fileSource, err := base64.RawURLEncoding.DecodeString(service.PathEncoded)
|
|
||||||
if err != nil {
|
|
||||||
return serializer.ParamErr("无法解析的文件地址", err)
|
|
||||||
}
|
|
||||||
fs.FileTarget = []model.File{{SourceName: string(fileSource), PicInfo: "1,1"}}
|
|
||||||
|
|
||||||
// 获取缩略图
|
|
||||||
resp, err := fs.GetThumb(ctx, 0)
|
|
||||||
if err != nil {
|
|
||||||
return serializer.Err(serializer.CodeNotSet, "无法获取缩略图", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer resp.Content.Close()
|
|
||||||
http.ServeContent(c.Writer, c.Request, "thumb.png", time.Now(), resp.Content)
|
|
||||||
|
|
||||||
return serializer.Response{Code: 0}
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,166 @@
|
||||||
|
package explorer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/task"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/task/slavetask"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SlaveDownloadService 从机文件下載服务
|
||||||
|
type SlaveDownloadService struct {
|
||||||
|
PathEncoded string `uri:"path" binding:"required"`
|
||||||
|
Name string `uri:"name" binding:"required"`
|
||||||
|
Speed int `uri:"speed" binding:"min=0"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SlaveFileService 从机单文件文件相关服务
|
||||||
|
type SlaveFileService struct {
|
||||||
|
PathEncoded string `uri:"path" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SlaveFilesService 从机多文件相关服务
|
||||||
|
type SlaveFilesService struct {
|
||||||
|
Files []string `json:"files" binding:"required,gt=0"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SlaveListService 从机列表服务
|
||||||
|
type SlaveListService struct {
|
||||||
|
Path string `json:"path" binding:"required,min=1,max=65535"`
|
||||||
|
Recursive bool `json:"recursive"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeFile 通过签名的URL下载从机文件
|
||||||
|
func (service *SlaveDownloadService) ServeFile(ctx context.Context, c *gin.Context, isDownload bool) serializer.Response {
|
||||||
|
// 创建文件系统
|
||||||
|
fs, err := filesystem.NewAnonymousFileSystem()
|
||||||
|
if err != nil {
|
||||||
|
return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)
|
||||||
|
}
|
||||||
|
defer fs.Recycle()
|
||||||
|
|
||||||
|
// 解码文件路径
|
||||||
|
fileSource, err := base64.RawURLEncoding.DecodeString(service.PathEncoded)
|
||||||
|
if err != nil {
|
||||||
|
return serializer.ParamErr("无法解析的文件地址", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据URL里的信息创建一个文件对象和用户对象
|
||||||
|
file := model.File{
|
||||||
|
Name: service.Name,
|
||||||
|
SourceName: string(fileSource),
|
||||||
|
Policy: model.Policy{
|
||||||
|
Model: gorm.Model{ID: 1},
|
||||||
|
Type: "local",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
fs.User = &model.User{
|
||||||
|
Group: model.Group{SpeedLimit: service.Speed},
|
||||||
|
}
|
||||||
|
fs.FileTarget = []model.File{file}
|
||||||
|
|
||||||
|
// 开始处理下载
|
||||||
|
ctx = context.WithValue(ctx, fsctx.GinCtx, c)
|
||||||
|
rs, err := fs.GetDownloadContent(ctx, 0)
|
||||||
|
if err != nil {
|
||||||
|
return serializer.Err(serializer.CodeNotSet, err.Error(), err)
|
||||||
|
}
|
||||||
|
defer rs.Close()
|
||||||
|
|
||||||
|
// 设置下载文件名
|
||||||
|
if isDownload {
|
||||||
|
c.Header("Content-Disposition", "attachment; filename=\""+url.PathEscape(fs.FileTarget[0].Name)+"\"")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送文件
|
||||||
|
http.ServeContent(c.Writer, c.Request, fs.FileTarget[0].Name, time.Now(), rs)
|
||||||
|
|
||||||
|
return serializer.Response{
|
||||||
|
Code: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 通过签名的URL删除从机文件
|
||||||
|
func (service *SlaveFilesService) Delete(ctx context.Context, c *gin.Context) serializer.Response {
|
||||||
|
// 创建文件系统
|
||||||
|
fs, err := filesystem.NewAnonymousFileSystem()
|
||||||
|
if err != nil {
|
||||||
|
return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)
|
||||||
|
}
|
||||||
|
defer fs.Recycle()
|
||||||
|
|
||||||
|
// 删除文件
|
||||||
|
failed, err := fs.Handler.Delete(ctx, service.Files)
|
||||||
|
if err != nil {
|
||||||
|
// 将Data字段写为字符串方便主控端解析
|
||||||
|
data, _ := json.Marshal(serializer.RemoteDeleteRequest{Files: failed})
|
||||||
|
|
||||||
|
return serializer.Response{
|
||||||
|
Code: serializer.CodeNotFullySuccess,
|
||||||
|
Data: string(data),
|
||||||
|
Msg: fmt.Sprintf("有 %d 个文件未能成功删除", len(failed)),
|
||||||
|
Error: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return serializer.Response{Code: 0}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Thumb 通过签名URL获取从机文件缩略图
|
||||||
|
func (service *SlaveFileService) Thumb(ctx context.Context, c *gin.Context) serializer.Response {
|
||||||
|
// 创建文件系统
|
||||||
|
fs, err := filesystem.NewAnonymousFileSystem()
|
||||||
|
if err != nil {
|
||||||
|
return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)
|
||||||
|
}
|
||||||
|
defer fs.Recycle()
|
||||||
|
|
||||||
|
// 解码文件路径
|
||||||
|
fileSource, err := base64.RawURLEncoding.DecodeString(service.PathEncoded)
|
||||||
|
if err != nil {
|
||||||
|
return serializer.ParamErr("无法解析的文件地址", err)
|
||||||
|
}
|
||||||
|
fs.FileTarget = []model.File{{SourceName: string(fileSource), PicInfo: "1,1"}}
|
||||||
|
|
||||||
|
// 获取缩略图
|
||||||
|
resp, err := fs.GetThumb(ctx, 0)
|
||||||
|
if err != nil {
|
||||||
|
return serializer.Err(serializer.CodeNotSet, "无法获取缩略图", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Content.Close()
|
||||||
|
http.ServeContent(c.Writer, c.Request, "thumb.png", time.Now(), resp.Content)
|
||||||
|
|
||||||
|
return serializer.Response{Code: 0}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTransferTask 创建从机文件转存任务
|
||||||
|
func CreateTransferTask(c *gin.Context, req *serializer.SlaveTransferReq) serializer.Response {
|
||||||
|
if id, ok := c.Get("MasterSiteID"); ok {
|
||||||
|
job := &slavetask.TransferTask{
|
||||||
|
Req: req,
|
||||||
|
MasterID: id.(string),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := slave.DefaultController.SubmitTask(job.MasterID, job, req.Hash(job.MasterID), func(job interface{}) {
|
||||||
|
task.TaskPoll.Submit(job.(task.Job))
|
||||||
|
}); err != nil {
|
||||||
|
return serializer.Err(serializer.CodeInternalSetting, "任务创建失败", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return serializer.Response{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return serializer.ParamErr("未知的主机节点ID", nil)
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,62 @@
|
||||||
|
package node
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/gob"
|
||||||
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SlaveNotificationService struct {
|
||||||
|
Subject string `uri:"subject" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OneDriveCredentialService struct {
|
||||||
|
PolicyID uint `uri:"id" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func HandleMasterHeartbeat(req *serializer.NodePingReq) serializer.Response {
|
||||||
|
res, err := slave.DefaultController.HandleHeartBeat(req)
|
||||||
|
if err != nil {
|
||||||
|
return serializer.Err(serializer.CodeInternalSetting, "Cannot initialize slave controller", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return serializer.Response{
|
||||||
|
Code: 0,
|
||||||
|
Data: res,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleSlaveNotificationPush 转发从机的消息通知到本机消息队列
|
||||||
|
func (s *SlaveNotificationService) HandleSlaveNotificationPush(c *gin.Context) serializer.Response {
|
||||||
|
var msg mq.Message
|
||||||
|
dec := gob.NewDecoder(c.Request.Body)
|
||||||
|
if err := dec.Decode(&msg); err != nil {
|
||||||
|
return serializer.ParamErr("Cannot parse notification message", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mq.GlobalMQ.Publish(s.Subject, msg)
|
||||||
|
return serializer.Response{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get 获取主机OneDrive策略的AccessToken
|
||||||
|
func (s *OneDriveCredentialService) Get(c *gin.Context) serializer.Response {
|
||||||
|
policy, err := model.GetPolicyByID(s.PolicyID)
|
||||||
|
if err != nil {
|
||||||
|
return serializer.Err(serializer.CodeNotFound, "Cannot found storage policy", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := onedrive.NewClient(&policy)
|
||||||
|
if err != nil {
|
||||||
|
return serializer.Err(serializer.CodeInternalSetting, "Cannot initialize OneDrive client", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := client.UpdateCredential(c); err != nil {
|
||||||
|
return serializer.Err(serializer.CodeInternalSetting, "Cannot refresh OneDrive credential", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return serializer.Response{Data: client.Credential.AccessToken}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue