mirror of https://github.com/cloudreve/Cloudreve
Test: balancer / auth / controller in pkg
parent
f0089045d7
commit
416f4c1dd2
|
@ -9,6 +9,7 @@ import (
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/crontab"
|
"github.com/cloudreve/Cloudreve/v3/pkg/crontab"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/email"
|
"github.com/cloudreve/Cloudreve/v3/pkg/email"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/task"
|
"github.com/cloudreve/Cloudreve/v3/pkg/task"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
@ -53,7 +54,7 @@ func Init(path string) {
|
||||||
{
|
{
|
||||||
"master",
|
"master",
|
||||||
func() {
|
func() {
|
||||||
aria2.Init(false)
|
aria2.Init(false, cluster.Default, mq.GlobalMQ)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
@ -33,7 +33,7 @@ func GetLoadBalancer() balancer.Balancer {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init 初始化
|
// Init 初始化
|
||||||
func Init(isReload bool) {
|
func Init(isReload bool, pool cluster.Pool, mqClient mq.MQ) {
|
||||||
Lock.Lock()
|
Lock.Lock()
|
||||||
LB = balancer.NewBalancer("RoundRobin")
|
LB = balancer.NewBalancer("RoundRobin")
|
||||||
Lock.Unlock()
|
Lock.Unlock()
|
||||||
|
@ -44,7 +44,7 @@ func Init(isReload bool) {
|
||||||
|
|
||||||
for i := 0; i < len(unfinished); i++ {
|
for i := 0; i < len(unfinished); i++ {
|
||||||
// 创建任务监控
|
// 创建任务监控
|
||||||
monitor.NewMonitor(&unfinished[i], cluster.Default, mq.GlobalMQ)
|
monitor.NewMonitor(&unfinished[i], pool, mqClient)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,14 +2,15 @@ package aria2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/mocks"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
testMock "github.com/stretchr/testify/mock"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/DATA-DOG/go-sqlmock"
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor"
|
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var mock sqlmock.Sqlmock
|
var mock sqlmock.Sqlmock
|
||||||
|
@ -27,66 +28,39 @@ func TestMain(m *testing.M) {
|
||||||
m.Run()
|
m.Run()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDummyAria2(t *testing.T) {
|
|
||||||
asserts := assert.New(t)
|
|
||||||
instance := DummyAria2{}
|
|
||||||
asserts.Error(instance.CreateTask(nil, nil))
|
|
||||||
_, err := instance.Status(nil)
|
|
||||||
asserts.Error(err)
|
|
||||||
asserts.Error(instance.Cancel(nil))
|
|
||||||
asserts.Error(instance.Select(nil, nil))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInit(t *testing.T) {
|
func TestInit(t *testing.T) {
|
||||||
monitor.MAX_RETRY = 0
|
a := assert.New(t)
|
||||||
asserts := assert.New(t)
|
mockPool := &mocks.NodePoolMock{}
|
||||||
cache.Set("setting_aria2_token", "1", 0)
|
mockPool.On("GetNodeByID", testMock.Anything).Return(nil)
|
||||||
cache.Set("setting_aria2_call_timeout", "5", 0)
|
mockQueue := mq.NewMQ()
|
||||||
cache.Set("setting_aria2_options", `[]`, 0)
|
|
||||||
|
|
||||||
// 未指定RPC地址,跳过
|
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
|
||||||
|
Init(false, mockPool, mockQueue)
|
||||||
|
a.NoError(mock.ExpectationsWereMet())
|
||||||
|
mockPool.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTestRPCConnection(t *testing.T) {
|
||||||
|
a := assert.New(t)
|
||||||
|
|
||||||
|
// url not legal
|
||||||
{
|
{
|
||||||
cache.Set("setting_aria2_rpcurl", "", 0)
|
res, err := TestRPCConnection(string([]byte{0x7f}), "", 10)
|
||||||
Init(false)
|
a.Error(err)
|
||||||
asserts.IsType(&DummyAria2{}, Instance)
|
a.Empty(res.Version)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 无法解析服务器地址
|
// rpc failed
|
||||||
{
|
{
|
||||||
cache.Set("setting_aria2_rpcurl", string(byte(0x7f)), 0)
|
res, err := TestRPCConnection("ws://0.0.0.0", "", 0)
|
||||||
Init(false)
|
a.Error(err)
|
||||||
asserts.IsType(&DummyAria2{}, Instance)
|
a.Empty(res.Version)
|
||||||
}
|
|
||||||
|
|
||||||
// 无法解析全局配置
|
|
||||||
{
|
|
||||||
Instance = &RPCService{}
|
|
||||||
cache.Set("setting_aria2_options", "?", 0)
|
|
||||||
cache.Set("setting_aria2_rpcurl", "ws://127.0.0.1:1234", 0)
|
|
||||||
Init(false)
|
|
||||||
asserts.IsType(&DummyAria2{}, Instance)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 连接失败
|
|
||||||
{
|
|
||||||
cache.Set("setting_aria2_options", "{}", 0)
|
|
||||||
cache.Set("setting_aria2_rpcurl", "http://127.0.0.1:1234", 0)
|
|
||||||
cache.Set("setting_aria2_call_timeout", "1", 0)
|
|
||||||
cache.Set("setting_aria2_interval", "100", 0)
|
|
||||||
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"g_id"}).AddRow("1"))
|
|
||||||
Init(false)
|
|
||||||
asserts.NoError(mock.ExpectationsWereMet())
|
|
||||||
asserts.IsType(&RPCService{}, Instance)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetStatus(t *testing.T) {
|
func TestGetLoadBalancer(t *testing.T) {
|
||||||
asserts := assert.New(t)
|
a := assert.New(t)
|
||||||
asserts.Equal(4, GetStatus("complete"))
|
a.NotPanics(func() {
|
||||||
asserts.Equal(1, GetStatus("active"))
|
GetLoadBalancer()
|
||||||
asserts.Equal(0, GetStatus("waiting"))
|
})
|
||||||
asserts.Equal(2, GetStatus("paused"))
|
|
||||||
asserts.Equal(3, GetStatus("error"))
|
|
||||||
asserts.Equal(5, GetStatus("removed"))
|
|
||||||
asserts.Equal(6, GetStatus("?"))
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,114 +0,0 @@
|
||||||
package aria2
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"path/filepath"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
|
||||||
)
|
|
||||||
|
|
||||||
// RPCService 通过RPC服务的Aria2任务管理器
|
|
||||||
type RPCService struct {
|
|
||||||
options *clientOptions
|
|
||||||
Caller rpc.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
type clientOptions struct {
|
|
||||||
Options map[string]interface{} // 创建下载时额外添加的设置
|
|
||||||
}
|
|
||||||
|
|
||||||
// Init 初始化
|
|
||||||
func (client *RPCService) Init(server, secret string, timeout int, options map[string]interface{}) error {
|
|
||||||
// 客户端已存在,则关闭先前连接
|
|
||||||
if client.Caller != nil {
|
|
||||||
client.Caller.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
client.options = &clientOptions{
|
|
||||||
Options: options,
|
|
||||||
}
|
|
||||||
caller, err := rpc.New(context.Background(), server, secret, time.Duration(timeout)*time.Second,
|
|
||||||
mq.GlobalMQ)
|
|
||||||
client.Caller = caller
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Status 查询下载状态
|
|
||||||
func (client *RPCService) Status(task *model.Download) (rpc.StatusInfo, error) {
|
|
||||||
res, err := client.Caller.TellStatus(task.GID)
|
|
||||||
if err != nil {
|
|
||||||
// 失败后重试
|
|
||||||
util.Log().Debug("无法获取离线下载状态,%s,10秒钟后重试", err)
|
|
||||||
time.Sleep(time.Duration(10) * time.Second)
|
|
||||||
res, err = client.Caller.TellStatus(task.GID)
|
|
||||||
}
|
|
||||||
|
|
||||||
return res, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cancel 取消下载
|
|
||||||
func (client *RPCService) Cancel(task *model.Download) error {
|
|
||||||
// 取消下载任务
|
|
||||||
_, err := client.Caller.Remove(task.GID)
|
|
||||||
if err != nil {
|
|
||||||
util.Log().Warning("无法取消离线下载任务[%s], %s", task.GID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
//// 删除临时文件
|
|
||||||
//util.Log().Debug("离线下载任务[%s]已取消,1 分钟后删除临时文件", task.GID)
|
|
||||||
//go func(task *model.Download) {
|
|
||||||
// select {
|
|
||||||
// case <-time.After(time.Duration(60) * time.Second):
|
|
||||||
// err := os.RemoveAll(task.Parent)
|
|
||||||
// if err != nil {
|
|
||||||
// util.Log().Warning("无法删除离线下载临时目录[%s], %s", task.Parent, err)
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
//}(task)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Select 选取要下载的文件
|
|
||||||
func (client *RPCService) Select(task *model.Download, files []int) error {
|
|
||||||
var selected = make([]string, len(files))
|
|
||||||
for i := 0; i < len(files); i++ {
|
|
||||||
selected[i] = strconv.Itoa(files[i])
|
|
||||||
}
|
|
||||||
_, err := client.Caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateTask 创建新任务
|
|
||||||
func (client *RPCService) CreateTask(task *model.Download, groupOptions map[string]interface{}) (string, error) {
|
|
||||||
// 生成存储路径
|
|
||||||
path := filepath.Join(
|
|
||||||
model.GetSettingByName("aria2_temp_path"),
|
|
||||||
"aria2",
|
|
||||||
strconv.FormatInt(time.Now().UnixNano(), 10),
|
|
||||||
)
|
|
||||||
|
|
||||||
// 创建下载任务
|
|
||||||
options := map[string]interface{}{
|
|
||||||
"dir": path,
|
|
||||||
}
|
|
||||||
for k, v := range client.options.Options {
|
|
||||||
options[k] = v
|
|
||||||
}
|
|
||||||
for k, v := range groupOptions {
|
|
||||||
options[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
gid, err := client.Caller.AddURI(task.Source, options)
|
|
||||||
if err != nil || gid == "" {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return gid, nil
|
|
||||||
}
|
|
|
@ -1,52 +0,0 @@
|
||||||
package aria2
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestRPCService_Init(t *testing.T) {
|
|
||||||
asserts := assert.New(t)
|
|
||||||
caller := &RPCService{}
|
|
||||||
asserts.Error(caller.Init("ws://", "", 1, nil))
|
|
||||||
asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRPCService_Status(t *testing.T) {
|
|
||||||
asserts := assert.New(t)
|
|
||||||
caller := &RPCService{}
|
|
||||||
asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
|
|
||||||
|
|
||||||
_, err := caller.Status(&model.Download{})
|
|
||||||
asserts.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRPCService_Cancel(t *testing.T) {
|
|
||||||
asserts := assert.New(t)
|
|
||||||
caller := &RPCService{}
|
|
||||||
asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
|
|
||||||
|
|
||||||
err := caller.Cancel(&model.Download{Parent: "test"})
|
|
||||||
asserts.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRPCService_Select(t *testing.T) {
|
|
||||||
asserts := assert.New(t)
|
|
||||||
caller := &RPCService{}
|
|
||||||
asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
|
|
||||||
|
|
||||||
err := caller.Select(&model.Download{Parent: "test"}, []int{1, 2, 3})
|
|
||||||
asserts.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRPCService_CreateTask(t *testing.T) {
|
|
||||||
asserts := assert.New(t)
|
|
||||||
caller := &RPCService{}
|
|
||||||
asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
|
|
||||||
cache.Set("setting_aria2_temp_path", "test", 0)
|
|
||||||
err := caller.CreateTask(&model.Download{Parent: "test"}, map[string]interface{}{"1": "1"})
|
|
||||||
asserts.Error(err)
|
|
||||||
}
|
|
|
@ -1,52 +0,0 @@
|
||||||
package aria2
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNotifier_Notify(t *testing.T) {
|
|
||||||
asserts := assert.New(t)
|
|
||||||
notifier2 := &Notifier{}
|
|
||||||
notifyChan := make(chan StatusEvent, 10)
|
|
||||||
notifier2.Subscribe(notifyChan, "1")
|
|
||||||
|
|
||||||
// 未订阅
|
|
||||||
{
|
|
||||||
notifier2.Notify([]rpc.Event{rpc.Event{Gid: ""}}, 1)
|
|
||||||
asserts.Len(notifyChan, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 订阅
|
|
||||||
{
|
|
||||||
notifier2.Notify([]rpc.Event{{Gid: "1"}}, 1)
|
|
||||||
asserts.Len(notifyChan, 1)
|
|
||||||
<-notifyChan
|
|
||||||
|
|
||||||
notifier2.OnBtDownloadComplete([]rpc.Event{{Gid: "1"}})
|
|
||||||
asserts.Len(notifyChan, 1)
|
|
||||||
<-notifyChan
|
|
||||||
|
|
||||||
notifier2.OnDownloadStart([]rpc.Event{{Gid: "1"}})
|
|
||||||
asserts.Len(notifyChan, 1)
|
|
||||||
<-notifyChan
|
|
||||||
|
|
||||||
notifier2.OnDownloadPause([]rpc.Event{{Gid: "1"}})
|
|
||||||
asserts.Len(notifyChan, 1)
|
|
||||||
<-notifyChan
|
|
||||||
|
|
||||||
notifier2.OnDownloadStop([]rpc.Event{{Gid: "1"}})
|
|
||||||
asserts.Len(notifyChan, 1)
|
|
||||||
<-notifyChan
|
|
||||||
|
|
||||||
notifier2.OnDownloadComplete([]rpc.Event{{Gid: "1"}})
|
|
||||||
asserts.Len(notifyChan, 1)
|
|
||||||
<-notifyChan
|
|
||||||
|
|
||||||
notifier2.OnDownloadError([]rpc.Event{{Gid: "1"}})
|
|
||||||
asserts.Len(notifyChan, 1)
|
|
||||||
<-notifyChan
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -17,8 +17,10 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrAuthFailed = serializer.NewError(serializer.CodeNoPermissionErr, "鉴权失败", nil)
|
ErrAuthFailed = serializer.NewError(serializer.CodeNoPermissionErr, "鉴权失败", nil)
|
||||||
ErrExpired = serializer.NewError(serializer.CodeSignExpired, "签名已过期", nil)
|
ErrAuthHeaderMissing = serializer.NewError(serializer.CodeNoPermissionErr, "authorization header is missing", nil)
|
||||||
|
ErrExpiresMissing = serializer.NewError(serializer.CodeNoPermissionErr, "expire timestamp is missing", nil)
|
||||||
|
ErrExpired = serializer.NewError(serializer.CodeSignExpired, "签名已过期", nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
// General 通用的认证接口
|
// General 通用的认证接口
|
||||||
|
@ -55,7 +57,7 @@ func CheckRequest(instance Auth, r *http.Request) error {
|
||||||
ok bool
|
ok bool
|
||||||
)
|
)
|
||||||
if sign, ok = r.Header["Authorization"]; !ok || len(sign) == 0 {
|
if sign, ok = r.Header["Authorization"]; !ok || len(sign) == 0 {
|
||||||
return ErrAuthFailed
|
return ErrAuthHeaderMissing
|
||||||
}
|
}
|
||||||
sign[0] = strings.TrimPrefix(sign[0], "Bearer ")
|
sign[0] = strings.TrimPrefix(sign[0], "Bearer ")
|
||||||
|
|
||||||
|
|
|
@ -80,6 +80,19 @@ func TestCheckRequest(t *testing.T) {
|
||||||
asserts := assert.New(t)
|
asserts := assert.New(t)
|
||||||
General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
|
General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
|
||||||
|
|
||||||
|
// 缺少请求头
|
||||||
|
{
|
||||||
|
req, err := http.NewRequest(
|
||||||
|
"POST",
|
||||||
|
"http://127.0.0.1/api/v3/upload",
|
||||||
|
strings.NewReader("I am body."),
|
||||||
|
)
|
||||||
|
asserts.NoError(err)
|
||||||
|
err = CheckRequest(General, req)
|
||||||
|
asserts.Error(err)
|
||||||
|
asserts.Equal(ErrAuthHeaderMissing, err)
|
||||||
|
}
|
||||||
|
|
||||||
// 非上传请求 验证成功
|
// 非上传请求 验证成功
|
||||||
{
|
{
|
||||||
req, err := http.NewRequest(
|
req, err := http.NewRequest(
|
||||||
|
|
|
@ -33,7 +33,7 @@ func (auth HMACAuth) Check(body string, sign string) error {
|
||||||
signSlice := strings.Split(sign, ":")
|
signSlice := strings.Split(sign, ":")
|
||||||
// 如果未携带expires字段
|
// 如果未携带expires字段
|
||||||
if signSlice[len(signSlice)-1] == "" {
|
if signSlice[len(signSlice)-1] == "" {
|
||||||
return ErrAuthFailed
|
return ErrExpiresMissing
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证是否过期
|
// 验证是否过期
|
||||||
|
|
|
@ -0,0 +1,12 @@
|
||||||
|
package balancer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewBalancer(t *testing.T) {
|
||||||
|
a := assert.New(t)
|
||||||
|
a.NotNil(NewBalancer(""))
|
||||||
|
a.IsType(&RoundRobin{}, NewBalancer("RoundRobin"))
|
||||||
|
}
|
|
@ -0,0 +1,42 @@
|
||||||
|
package balancer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRoundRobin_NextIndex(t *testing.T) {
|
||||||
|
a := assert.New(t)
|
||||||
|
r := &RoundRobin{}
|
||||||
|
total := 5
|
||||||
|
for i := 1; i < total; i++ {
|
||||||
|
a.Equal(i, r.NextIndex(total))
|
||||||
|
}
|
||||||
|
for i := 0; i < total; i++ {
|
||||||
|
a.Equal(i, r.NextIndex(total))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundRobin_NextPeer(t *testing.T) {
|
||||||
|
a := assert.New(t)
|
||||||
|
r := &RoundRobin{}
|
||||||
|
|
||||||
|
// not slice
|
||||||
|
{
|
||||||
|
err, _ := r.NextPeer("s")
|
||||||
|
a.Equal(ErrInputNotSlice, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// no nodes
|
||||||
|
{
|
||||||
|
err, _ := r.NextPeer([]string{})
|
||||||
|
a.Equal(ErrNoAvaliableNode, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// pass
|
||||||
|
{
|
||||||
|
err, res := r.NextPeer([]string{"a"})
|
||||||
|
a.NoError(err)
|
||||||
|
a.Equal("a", res.(string))
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,254 @@
|
||||||
|
package cluster
|
||||||
|
|
||||||
|
import (
|
||||||
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
testMock "github.com/stretchr/testify/mock"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInitController(t *testing.T) {
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
InitController()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSlaveController_HandleHeartBeat(t *testing.T) {
|
||||||
|
a := assert.New(t)
|
||||||
|
c := &slaveController{
|
||||||
|
masters: make(map[string]MasterInfo),
|
||||||
|
}
|
||||||
|
|
||||||
|
// first heart beat
|
||||||
|
{
|
||||||
|
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
|
||||||
|
SiteID: "1",
|
||||||
|
Node: &model.Node{},
|
||||||
|
})
|
||||||
|
a.NoError(err)
|
||||||
|
|
||||||
|
_, err = c.HandleHeartBeat(&serializer.NodePingReq{
|
||||||
|
SiteID: "2",
|
||||||
|
Node: &model.Node{},
|
||||||
|
})
|
||||||
|
a.NoError(err)
|
||||||
|
|
||||||
|
a.Len(c.masters, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// second heart beat, no fresh
|
||||||
|
{
|
||||||
|
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
|
||||||
|
SiteID: "1",
|
||||||
|
SiteURL: "http://127.0.0.1",
|
||||||
|
Node: &model.Node{},
|
||||||
|
})
|
||||||
|
a.NoError(err)
|
||||||
|
a.Len(c.masters, 2)
|
||||||
|
a.Empty(c.masters["1"].URL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// second heart beat, fresh
|
||||||
|
{
|
||||||
|
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
|
||||||
|
SiteID: "1",
|
||||||
|
IsUpdate: true,
|
||||||
|
SiteURL: "http://127.0.0.1",
|
||||||
|
Node: &model.Node{},
|
||||||
|
})
|
||||||
|
a.NoError(err)
|
||||||
|
a.Len(c.masters, 2)
|
||||||
|
a.Equal("http://127.0.0.1", c.masters["1"].URL.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// second heart beat, fresh, url illegal
|
||||||
|
{
|
||||||
|
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
|
||||||
|
SiteID: "1",
|
||||||
|
IsUpdate: true,
|
||||||
|
SiteURL: string([]byte{0x7f}),
|
||||||
|
Node: &model.Node{},
|
||||||
|
})
|
||||||
|
a.Error(err)
|
||||||
|
a.Len(c.masters, 2)
|
||||||
|
a.Equal("http://127.0.0.1", c.masters["1"].URL.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type nodeMock struct {
|
||||||
|
testMock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n nodeMock) Init(node *model.Node) {
|
||||||
|
n.Called(node)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n nodeMock) IsFeatureEnabled(feature string) bool {
|
||||||
|
args := n.Called(feature)
|
||||||
|
return args.Bool(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n nodeMock) SubscribeStatusChange(callback func(isActive bool, id uint)) {
|
||||||
|
n.Called(callback)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n nodeMock) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
|
||||||
|
args := n.Called(req)
|
||||||
|
return args.Get(0).(*serializer.NodePingResp), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n nodeMock) IsActive() bool {
|
||||||
|
args := n.Called()
|
||||||
|
return args.Bool(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n nodeMock) GetAria2Instance() common.Aria2 {
|
||||||
|
args := n.Called()
|
||||||
|
return args.Get(0).(common.Aria2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n nodeMock) ID() uint {
|
||||||
|
args := n.Called()
|
||||||
|
return args.Get(0).(uint)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n nodeMock) Kill() {
|
||||||
|
n.Called()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n nodeMock) IsMater() bool {
|
||||||
|
args := n.Called()
|
||||||
|
return args.Bool(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n nodeMock) MasterAuthInstance() auth.Auth {
|
||||||
|
args := n.Called()
|
||||||
|
return args.Get(0).(auth.Auth)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n nodeMock) SlaveAuthInstance() auth.Auth {
|
||||||
|
args := n.Called()
|
||||||
|
return args.Get(0).(auth.Auth)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n nodeMock) DBModel() *model.Node {
|
||||||
|
args := n.Called()
|
||||||
|
return args.Get(0).(*model.Node)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSlaveController_GetAria2Instance(t *testing.T) {
|
||||||
|
a := assert.New(t)
|
||||||
|
mockNode := &nodeMock{}
|
||||||
|
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
|
||||||
|
c := &slaveController{
|
||||||
|
masters: map[string]MasterInfo{
|
||||||
|
"1": {Instance: mockNode},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// node node found
|
||||||
|
{
|
||||||
|
res, err := c.GetAria2Instance("2")
|
||||||
|
a.Nil(res)
|
||||||
|
a.Equal(ErrMasterNotFound, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// node found
|
||||||
|
{
|
||||||
|
res, err := c.GetAria2Instance("1")
|
||||||
|
a.NotNil(res)
|
||||||
|
a.NoError(err)
|
||||||
|
mockNode.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
type requestMock struct {
|
||||||
|
testMock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r requestMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response {
|
||||||
|
return r.Called(method, target, body, opts).Get(0).(*request.Response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSlaveController_SendNotification(t *testing.T) {
|
||||||
|
a := assert.New(t)
|
||||||
|
c := &slaveController{
|
||||||
|
masters: map[string]MasterInfo{
|
||||||
|
"1": {},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// node not exit
|
||||||
|
{
|
||||||
|
a.Equal(ErrMasterNotFound, c.SendNotification("2", "", mq.Message{}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// gob encode error
|
||||||
|
{
|
||||||
|
type randomType struct{}
|
||||||
|
a.Error(c.SendNotification("1", "", mq.Message{
|
||||||
|
Content: randomType{},
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// return none 200
|
||||||
|
{
|
||||||
|
mockRequest := &requestMock{}
|
||||||
|
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s1", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||||
|
Response: &http.Response{StatusCode: http.StatusConflict},
|
||||||
|
})
|
||||||
|
c := &slaveController{
|
||||||
|
masters: map[string]MasterInfo{
|
||||||
|
"1": {Client: mockRequest},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
a.Error(c.SendNotification("1", "s1", mq.Message{}))
|
||||||
|
mockRequest.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// master return error
|
||||||
|
{
|
||||||
|
mockRequest := &requestMock{}
|
||||||
|
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s2", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||||
|
Response: &http.Response{
|
||||||
|
StatusCode: 200,
|
||||||
|
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
c := &slaveController{
|
||||||
|
masters: map[string]MasterInfo{
|
||||||
|
"1": {Client: mockRequest},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
a.Equal(1, c.SendNotification("1", "s2", mq.Message{}).(serializer.AppError).Code)
|
||||||
|
mockRequest.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// success
|
||||||
|
{
|
||||||
|
mockRequest := &requestMock{}
|
||||||
|
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s3", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||||
|
Response: &http.Response{
|
||||||
|
StatusCode: 200,
|
||||||
|
Body: ioutil.NopCloser(strings.NewReader("{\"code\":0}")),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
c := &slaveController{
|
||||||
|
masters: map[string]MasterInfo{
|
||||||
|
"1": {Client: mockRequest},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
a.NoError(c.SendNotification("1", "s3", mq.Message{}))
|
||||||
|
mockRequest.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
}
|
|
@ -8,9 +8,11 @@ import (
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
|
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/task"
|
"github.com/cloudreve/Cloudreve/v3/pkg/task"
|
||||||
testMock "github.com/stretchr/testify/mock"
|
testMock "github.com/stretchr/testify/mock"
|
||||||
|
"io"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SlaveControllerMock struct {
|
type SlaveControllerMock struct {
|
||||||
|
@ -184,3 +186,11 @@ func (t TaskPoolMock) Add(num int) {
|
||||||
func (t TaskPoolMock) Submit(job task.Job) {
|
func (t TaskPoolMock) Submit(job task.Job) {
|
||||||
t.Called(job)
|
t.Called(job)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type RequestMock struct {
|
||||||
|
testMock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r RequestMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response {
|
||||||
|
return r.Called(method, target, body, opts).Get(0).(*request.Response)
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package controllers
|
package controllers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
|
@ -72,7 +74,7 @@ func AdminReloadService(c *gin.Context) {
|
||||||
case "email":
|
case "email":
|
||||||
email.Init()
|
email.Init()
|
||||||
case "aria2":
|
case "aria2":
|
||||||
aria2.Init(true)
|
aria2.Init(true, cluster.Default, mq.GlobalMQ)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(200, serializer.Response{})
|
c.JSON(200, serializer.Response{})
|
||||||
|
|
|
@ -48,9 +48,7 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取 Aria2 负载均衡器
|
// 获取 Aria2 负载均衡器
|
||||||
aria2.Lock.RLock()
|
lb := aria2.GetLoadBalancer()
|
||||||
lb := aria2.LB
|
|
||||||
aria2.Lock.RUnlock()
|
|
||||||
|
|
||||||
// 获取 Aria2 实例
|
// 获取 Aria2 实例
|
||||||
err, node := cluster.Default.BalanceNodeByFeature("aria2", lb)
|
err, node := cluster.Default.BalanceNodeByFeature("aria2", lb)
|
||||||
|
|
Loading…
Reference in New Issue