mirror of https://github.com/Xhofe/alist
fix(github_release): corect the error logic when the latest release is not found
* rename APIContext * add github_test.go * add github api rate limit logpull/7859/head
parent
39a924d4c3
commit
0bce51dd53
|
@ -20,7 +20,7 @@ type GithubRelease struct {
|
|||
model.Storage
|
||||
Addition
|
||||
|
||||
api *ApiContext
|
||||
api *APIContext
|
||||
repo repository
|
||||
}
|
||||
|
||||
|
@ -56,7 +56,7 @@ func (d *GithubRelease) Init(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
d.api = NewApiContext(d.Addition.Token, nil)
|
||||
d.api = NewAPIContext(d.Addition.Token, nil)
|
||||
|
||||
repo, err := newRepository(d.Addition.Repo)
|
||||
if err != nil {
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/alist-org/alist/v3/internal/model"
|
||||
|
@ -12,49 +13,85 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const GITHUB_API_VERSION = "2022-11-28"
|
||||
const (
|
||||
GITHUB_API_VERSION = "2022-11-28"
|
||||
DEFAULT_TIMEOUT = 10 * time.Second
|
||||
)
|
||||
|
||||
type ApiContext struct {
|
||||
token string
|
||||
version string
|
||||
client *http.Client
|
||||
var ErrRateLimitExceeded = errors.New("rate limit exceeded")
|
||||
|
||||
// RateLimit 表示 GitHub API 的速率限制信息
|
||||
type RateLimit struct {
|
||||
Limit uint
|
||||
Remaining uint
|
||||
Reset time.Time
|
||||
}
|
||||
|
||||
func NewApiContext(token string, client *http.Client) *ApiContext {
|
||||
ret := ApiContext{
|
||||
token: token,
|
||||
version: GITHUB_API_VERSION,
|
||||
client: client,
|
||||
// GitHubError 表示 GitHub API 返回的错误信息
|
||||
type GitHubError struct {
|
||||
Message string `json:"message"`
|
||||
DocumentationURL string `json:"documentation_url"`
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
func (e *GitHubError) Error() string {
|
||||
return fmt.Sprintf("github api error: %s (status: %d)", e.Message, e.StatusCode)
|
||||
}
|
||||
|
||||
// parseHTTPError 解析 GitHub API 的错误响应
|
||||
func parseHTTPError(statusCode int, body []byte) error {
|
||||
var v GitHubError
|
||||
err := utils.Json.Unmarshal(body, &v)
|
||||
if err != nil {
|
||||
return &GitHubError{
|
||||
Message: string(body),
|
||||
StatusCode: statusCode,
|
||||
}
|
||||
}
|
||||
v.StatusCode = statusCode
|
||||
return &v
|
||||
}
|
||||
|
||||
// parseRateLimit 从响应头中解析速率限制信息
|
||||
func parseRateLimit(header http.Header) *RateLimit {
|
||||
limit, _ := strconv.Atoi(header.Get("X-RateLimit-Limit"))
|
||||
remaining, _ := strconv.Atoi(header.Get("X-RateLimit-Remaining"))
|
||||
reset, _ := strconv.ParseInt(header.Get("X-RateLimit-Reset"), 10, 64)
|
||||
|
||||
return &RateLimit{
|
||||
Limit: uint(limit),
|
||||
Remaining: uint(remaining),
|
||||
Reset: time.Unix(reset, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// APIContext 表示 GitHub API 的上下文信息
|
||||
type APIContext struct {
|
||||
token string
|
||||
version string
|
||||
client *http.Client
|
||||
defaultTimeout time.Duration
|
||||
rateLimit *RateLimit
|
||||
}
|
||||
|
||||
// NewAPIContext 创建一个新的 GitHub API 上下文
|
||||
func NewAPIContext(token string, client *http.Client) *APIContext {
|
||||
ret := APIContext{
|
||||
token: token,
|
||||
version: GITHUB_API_VERSION,
|
||||
client: client,
|
||||
defaultTimeout: DEFAULT_TIMEOUT,
|
||||
}
|
||||
|
||||
if ret.client == nil {
|
||||
ret.client = http.DefaultClient
|
||||
ret.client = &http.Client{
|
||||
Timeout: ret.defaultTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
return &ret
|
||||
}
|
||||
|
||||
// parseHTTPError 解析 HTTP 错误.
|
||||
func parseHTTPError(body []byte) error {
|
||||
var v map[string]interface{}
|
||||
err := utils.Json.Unmarshal(body, &v)
|
||||
if err != nil {
|
||||
return errors.New(string(body))
|
||||
}
|
||||
|
||||
iface, ok := v["message"]
|
||||
if !ok {
|
||||
return errors.New(string(body))
|
||||
}
|
||||
|
||||
message, ok := iface.(string)
|
||||
if !ok {
|
||||
return errors.New(string(body))
|
||||
}
|
||||
|
||||
return errors.New(message)
|
||||
}
|
||||
|
||||
// sleepWithContext 在指定的时间内等待, 如果 context 被取消则提前返回.
|
||||
func sleepWithContext(ctx context.Context, d time.Duration) error {
|
||||
timer := time.NewTimer(d)
|
||||
|
@ -69,7 +106,7 @@ func sleepWithContext(ctx context.Context, d time.Duration) error {
|
|||
}
|
||||
|
||||
// getWithRetry 获取 GitHub API 并重试.
|
||||
func (a *ApiContext) getWithRetry(ctx context.Context, url string) (*http.Response, error) {
|
||||
func (a *APIContext) getWithRetry(ctx context.Context, url string) (*http.Response, error) {
|
||||
backoff := Backoff{}
|
||||
|
||||
for {
|
||||
|
@ -81,6 +118,11 @@ func (a *ApiContext) getWithRetry(ctx context.Context, url string) (*http.Respon
|
|||
|
||||
// non-2xx code does not cause error
|
||||
if err != nil {
|
||||
// 如果错误是速率限制错误, 则直接返回
|
||||
if errors.Is(err, ErrRateLimitExceeded) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// retry when error is not nil
|
||||
p, retryAgain := backoff.Pause()
|
||||
if !retryAgain {
|
||||
|
@ -115,7 +157,7 @@ func (a *ApiContext) getWithRetry(ctx context.Context, url string) (*http.Respon
|
|||
// retry when server error
|
||||
p, retryAgain := backoff.Pause()
|
||||
if !retryAgain {
|
||||
return nil, parseHTTPError(body)
|
||||
return nil, parseHTTPError(response.StatusCode, body)
|
||||
}
|
||||
utils.Log.Debugf("query github api error: status code %d, retry after %s", response.StatusCode, p)
|
||||
|
||||
|
@ -125,18 +167,18 @@ func (a *ApiContext) getWithRetry(ctx context.Context, url string) (*http.Respon
|
|||
continue
|
||||
}
|
||||
|
||||
return nil, parseHTTPError(body)
|
||||
return nil, parseHTTPError(response.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
// SetAuthHeader 为请求头添加 GitHub API 所需的认证头.
|
||||
// 这是一个副作用函数, 会直接修改传入的 header.
|
||||
func (a *ApiContext) SetAuthHeader(header http.Header) {
|
||||
func (a *APIContext) SetAuthHeader(header http.Header) {
|
||||
header.Set("Authorization", fmt.Sprintf("Bearer %s", a.token))
|
||||
}
|
||||
|
||||
// get 获取 GitHub API.
|
||||
func (a *ApiContext) get(ctx context.Context, url string) (*http.Response, error) {
|
||||
func (a *APIContext) get(ctx context.Context, url string) (*http.Response, error) {
|
||||
request, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -150,11 +192,21 @@ func (a *ApiContext) get(ctx context.Context, url string) (*http.Response, error
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// 更新速率限制信息
|
||||
a.rateLimit = parseRateLimit(response.Header)
|
||||
|
||||
// 如果剩余请求数为 0, 等待到重置时间
|
||||
if a.rateLimit.Remaining == 0 {
|
||||
waitTime := time.Until(a.rateLimit.Reset)
|
||||
utils.Log.Warnf("rate limit exceeded, will wait for %s", waitTime)
|
||||
return nil, ErrRateLimitExceeded
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// GetReleases 获取仓库信息.
|
||||
func (a *ApiContext) GetReleases(ctx context.Context, repo repository, perPage int) ([]model.Obj, error) {
|
||||
func (a *APIContext) GetReleases(ctx context.Context, repo repository, perPage int) ([]model.Obj, error) {
|
||||
if perPage < 1 {
|
||||
perPage = 30
|
||||
}
|
||||
|
@ -170,10 +222,6 @@ func (a *ApiContext) GetReleases(ctx context.Context, repo repository, perPage i
|
|||
return nil, errors.Wrap(err, "failed to read response body")
|
||||
}
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return nil, parseHTTPError(body)
|
||||
}
|
||||
|
||||
releases := []Release{}
|
||||
err = utils.Json.Unmarshal(body, &releases)
|
||||
if err != nil {
|
||||
|
@ -187,8 +235,48 @@ func (a *ApiContext) GetReleases(ctx context.Context, repo repository, perPage i
|
|||
return tree, nil
|
||||
}
|
||||
|
||||
// GetLatestRelease 获取最新 release.
|
||||
func (a *APIContext) GetLatestRelease(ctx context.Context, repo repository) (model.Obj, error) {
|
||||
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo.UrlEncode())
|
||||
response, err := a.getWithRetry(ctx, url)
|
||||
if err != nil {
|
||||
var githubErr *GitHubError
|
||||
if errors.As(err, &githubErr) && githubErr.StatusCode == http.StatusNotFound {
|
||||
return nil, ErrNoRelease
|
||||
}
|
||||
return nil, errors.Wrap(err, "get latest release")
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "read response body")
|
||||
}
|
||||
|
||||
if response.StatusCode == http.StatusNotFound {
|
||||
return nil, ErrNoRelease
|
||||
}
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
err := parseHTTPError(response.StatusCode, body)
|
||||
var githubErr *GitHubError
|
||||
if errors.As(err, &githubErr) && githubErr.StatusCode == http.StatusNotFound {
|
||||
return nil, ErrNoRelease
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var release Release
|
||||
if err := utils.Json.Unmarshal(body, &release); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal release data")
|
||||
}
|
||||
|
||||
release.SetLatestFlag(true)
|
||||
return &release, nil
|
||||
}
|
||||
|
||||
// GetRelease 获取指定 tag 的 release.
|
||||
func (a *ApiContext) GetRelease(ctx context.Context, repo repository, id int64) (*Release, error) {
|
||||
func (a *APIContext) GetRelease(ctx context.Context, repo repository, id int64) (*Release, error) {
|
||||
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/%d", repo.UrlEncode(), id)
|
||||
response, err := a.getWithRetry(ctx, url)
|
||||
if err != nil {
|
||||
|
@ -201,10 +289,6 @@ func (a *ApiContext) GetRelease(ctx context.Context, repo repository, id int64)
|
|||
return nil, errors.Wrap(err, "failed to read response body")
|
||||
}
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return nil, parseHTTPError(body)
|
||||
}
|
||||
|
||||
release := Release{}
|
||||
err = utils.Json.Unmarshal(body, &release)
|
||||
if err != nil {
|
||||
|
@ -215,7 +299,7 @@ func (a *ApiContext) GetRelease(ctx context.Context, repo repository, id int64)
|
|||
}
|
||||
|
||||
// GetReleaseAsset 获取指定 tag 的 release 的 assets.
|
||||
func (a *ApiContext) GetReleaseAsset(ctx context.Context, repo repository, ID int64) (*Asset, error) {
|
||||
func (a *APIContext) GetReleaseAsset(ctx context.Context, repo repository, ID int64) (*Asset, error) {
|
||||
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/assets/%d", repo.UrlEncode(), ID)
|
||||
response, err := a.getWithRetry(ctx, url)
|
||||
if err != nil {
|
||||
|
@ -228,10 +312,6 @@ func (a *ApiContext) GetReleaseAsset(ctx context.Context, repo repository, ID in
|
|||
return nil, errors.Wrap(err, "failed to read response body")
|
||||
}
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return nil, parseHTTPError(body)
|
||||
}
|
||||
|
||||
asset := Asset{}
|
||||
err = utils.Json.Unmarshal(body, &asset)
|
||||
if err != nil {
|
||||
|
@ -244,36 +324,3 @@ func (a *ApiContext) GetReleaseAsset(ctx context.Context, repo repository, ID in
|
|||
var (
|
||||
ErrNoRelease = errors.New("no release found")
|
||||
)
|
||||
|
||||
// GetLatestRelease 获取最新 release.
|
||||
func (a *ApiContext) GetLatestRelease(ctx context.Context, repo repository) (model.Obj, error) {
|
||||
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo.UrlEncode())
|
||||
response, err := a.getWithRetry(ctx, url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get latest release failed")
|
||||
}
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
if response.StatusCode == http.StatusNotFound {
|
||||
// identify no release
|
||||
return nil, ErrNoRelease
|
||||
}
|
||||
return nil, parseHTTPError(body)
|
||||
}
|
||||
|
||||
release := Release{}
|
||||
err = utils.Json.Unmarshal(body, &release)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get latest release failed")
|
||||
}
|
||||
|
||||
release.SetLatestFlag(true)
|
||||
|
||||
return &release, nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,155 @@
|
|||
package github_release
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseRateLimit(t *testing.T) {
|
||||
header := http.Header{}
|
||||
header.Set("X-RateLimit-Limit", "60")
|
||||
header.Set("X-RateLimit-Remaining", "59")
|
||||
header.Set("X-RateLimit-Reset", "1735689600") // 2025-01-01 00:00:00 UTC
|
||||
|
||||
rateLimit := parseRateLimit(header)
|
||||
|
||||
assert.Equal(t, uint(60), rateLimit.Limit)
|
||||
assert.Equal(t, uint(59), rateLimit.Remaining)
|
||||
assert.Equal(t, time.Unix(1735689600, 0), rateLimit.Reset)
|
||||
}
|
||||
|
||||
func TestGitHubError(t *testing.T) {
|
||||
err := &GitHubError{
|
||||
Message: "API rate limit exceeded",
|
||||
StatusCode: 403,
|
||||
}
|
||||
|
||||
assert.Equal(t, "github api error: API rate limit exceeded (status: 403)", err.Error())
|
||||
}
|
||||
|
||||
func TestNewAPIContext(t *testing.T) {
|
||||
token := "test-token"
|
||||
client := &http.Client{}
|
||||
ctx := NewAPIContext(token, client)
|
||||
|
||||
assert.Equal(t, token, ctx.token)
|
||||
assert.Equal(t, GITHUB_API_VERSION, ctx.version)
|
||||
assert.Equal(t, client, ctx.client)
|
||||
assert.Equal(t, DEFAULT_TIMEOUT, ctx.defaultTimeout)
|
||||
}
|
||||
|
||||
func TestAPIContext_SetAuthHeader(t *testing.T) {
|
||||
token := "test-token"
|
||||
ctx := NewAPIContext(token, nil)
|
||||
header := http.Header{}
|
||||
|
||||
ctx.SetAuthHeader(header)
|
||||
assert.Equal(t, "Bearer "+token, header.Get("Authorization"))
|
||||
}
|
||||
|
||||
func TestAPIContext_GetWithRetry_RateLimit(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-RateLimit-Limit", "60")
|
||||
w.Header().Set("X-RateLimit-Remaining", "0")
|
||||
w.Header().Set("X-RateLimit-Reset", "1735689600")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Write([]byte(`{"message": "API rate limit exceeded"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx := NewAPIContext("test-token", server.Client())
|
||||
_, err := ctx.getWithRetry(context.Background(), server.URL)
|
||||
|
||||
assert.ErrorIs(t, err, ErrRateLimitExceeded)
|
||||
}
|
||||
|
||||
type testRoundTripper struct {
|
||||
handler http.HandlerFunc
|
||||
}
|
||||
|
||||
func (t *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// 创建一个响应记录器
|
||||
w := httptest.NewRecorder()
|
||||
// 调用处理函数
|
||||
t.handler.ServeHTTP(w, req)
|
||||
// 将响应记录器转换为响应
|
||||
return w.Result(), nil
|
||||
}
|
||||
|
||||
func TestAPIContext_GetLatestRelease(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 验证请求路径
|
||||
assert.Equal(t, "/repos/test-owner/test-repo/releases/latest", r.URL.Path)
|
||||
|
||||
// 验证请求头
|
||||
assert.Equal(t, "Bearer test-token", r.Header.Get("Authorization"))
|
||||
assert.Equal(t, "application/vnd.github+json", r.Header.Get("Accept"))
|
||||
|
||||
// 设置速率限制头部
|
||||
w.Header().Set("X-RateLimit-Limit", "60")
|
||||
w.Header().Set("X-RateLimit-Remaining", "59")
|
||||
w.Header().Set("X-RateLimit-Reset", "1735689600")
|
||||
|
||||
// 设置响应头和内容
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{
|
||||
"id": 1,
|
||||
"tag_name": "v1.0.0",
|
||||
"name": "Release 1.0.0",
|
||||
"published_at": "2025-01-01T00:00:00Z",
|
||||
"created_at": "2025-01-01T00:00:00Z",
|
||||
"assets": []
|
||||
}`))
|
||||
})
|
||||
|
||||
// 创建一个自定义的 HTTP 客户端
|
||||
client := &http.Client{
|
||||
Transport: &testRoundTripper{handler: handler},
|
||||
}
|
||||
|
||||
ctx := NewAPIContext("test-token", client)
|
||||
repo := repository{owner: "test-owner", name: "test-repo"}
|
||||
release, err := ctx.GetLatestRelease(context.Background(), repo)
|
||||
|
||||
if assert.NoError(t, err) {
|
||||
assert.NotNil(t, release)
|
||||
assert.Equal(t, "latest(v1.0.0)", release.GetName())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIContext_GetLatestRelease_NoRelease(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 验证请求路径
|
||||
assert.Equal(t, "/repos/test-owner/test-repo/releases/latest", r.URL.Path)
|
||||
|
||||
// 验证请求头
|
||||
assert.Equal(t, "Bearer test-token", r.Header.Get("Authorization"))
|
||||
assert.Equal(t, "application/vnd.github+json", r.Header.Get("Accept"))
|
||||
|
||||
// 设置速率限制头部
|
||||
w.Header().Set("X-RateLimit-Limit", "60")
|
||||
w.Header().Set("X-RateLimit-Remaining", "59")
|
||||
w.Header().Set("X-RateLimit-Reset", "1735689600")
|
||||
|
||||
// 返回 404 状态码
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
w.Write([]byte(`{"message": "Not Found"}`))
|
||||
})
|
||||
|
||||
// 创建一个自定义的 HTTP 客户端
|
||||
client := &http.Client{
|
||||
Transport: &testRoundTripper{handler: handler},
|
||||
}
|
||||
|
||||
ctx := NewAPIContext("test-token", client)
|
||||
repo := repository{owner: "test-owner", name: "test-repo"}
|
||||
_, err := ctx.GetLatestRelease(context.Background(), repo)
|
||||
|
||||
assert.ErrorIs(t, err, ErrNoRelease)
|
||||
}
|
Loading…
Reference in New Issue