fix(github_release): corect the error logic when the latest release is not found

* rename APIContext
* add github_test.go
* add github api rate limit log
pull/7859/head
Zhang JL 2025-01-25 00:14:39 +08:00
parent 39a924d4c3
commit 0bce51dd53
3 changed files with 289 additions and 87 deletions

View File

@ -20,7 +20,7 @@ type GithubRelease struct {
model.Storage model.Storage
Addition Addition
api *ApiContext api *APIContext
repo repository repo repository
} }
@ -56,7 +56,7 @@ func (d *GithubRelease) Init(ctx context.Context) error {
return err return err
} }
d.api = NewApiContext(d.Addition.Token, nil) d.api = NewAPIContext(d.Addition.Token, nil)
repo, err := newRepository(d.Addition.Repo) repo, err := newRepository(d.Addition.Repo)
if err != nil { if err != nil {

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"strconv"
"time" "time"
"github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/model"
@ -12,49 +13,85 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
const GITHUB_API_VERSION = "2022-11-28" const (
GITHUB_API_VERSION = "2022-11-28"
DEFAULT_TIMEOUT = 10 * time.Second
)
type ApiContext struct { var ErrRateLimitExceeded = errors.New("rate limit exceeded")
token string
version string // RateLimit 表示 GitHub API 的速率限制信息
client *http.Client type RateLimit struct {
Limit uint
Remaining uint
Reset time.Time
} }
func NewApiContext(token string, client *http.Client) *ApiContext { // GitHubError 表示 GitHub API 返回的错误信息
ret := ApiContext{ type GitHubError struct {
token: token, Message string `json:"message"`
version: GITHUB_API_VERSION, DocumentationURL string `json:"documentation_url"`
client: client, StatusCode int
}
func (e *GitHubError) Error() string {
return fmt.Sprintf("github api error: %s (status: %d)", e.Message, e.StatusCode)
}
// parseHTTPError 解析 GitHub API 的错误响应
func parseHTTPError(statusCode int, body []byte) error {
var v GitHubError
err := utils.Json.Unmarshal(body, &v)
if err != nil {
return &GitHubError{
Message: string(body),
StatusCode: statusCode,
}
}
v.StatusCode = statusCode
return &v
}
// parseRateLimit 从响应头中解析速率限制信息
func parseRateLimit(header http.Header) *RateLimit {
limit, _ := strconv.Atoi(header.Get("X-RateLimit-Limit"))
remaining, _ := strconv.Atoi(header.Get("X-RateLimit-Remaining"))
reset, _ := strconv.ParseInt(header.Get("X-RateLimit-Reset"), 10, 64)
return &RateLimit{
Limit: uint(limit),
Remaining: uint(remaining),
Reset: time.Unix(reset, 0),
}
}
// APIContext 表示 GitHub API 的上下文信息
type APIContext struct {
token string
version string
client *http.Client
defaultTimeout time.Duration
rateLimit *RateLimit
}
// NewAPIContext 创建一个新的 GitHub API 上下文
func NewAPIContext(token string, client *http.Client) *APIContext {
ret := APIContext{
token: token,
version: GITHUB_API_VERSION,
client: client,
defaultTimeout: DEFAULT_TIMEOUT,
} }
if ret.client == nil { if ret.client == nil {
ret.client = http.DefaultClient ret.client = &http.Client{
Timeout: ret.defaultTimeout,
}
} }
return &ret return &ret
} }
// parseHTTPError 解析 HTTP 错误.
func parseHTTPError(body []byte) error {
var v map[string]interface{}
err := utils.Json.Unmarshal(body, &v)
if err != nil {
return errors.New(string(body))
}
iface, ok := v["message"]
if !ok {
return errors.New(string(body))
}
message, ok := iface.(string)
if !ok {
return errors.New(string(body))
}
return errors.New(message)
}
// sleepWithContext 在指定的时间内等待, 如果 context 被取消则提前返回. // sleepWithContext 在指定的时间内等待, 如果 context 被取消则提前返回.
func sleepWithContext(ctx context.Context, d time.Duration) error { func sleepWithContext(ctx context.Context, d time.Duration) error {
timer := time.NewTimer(d) timer := time.NewTimer(d)
@ -69,7 +106,7 @@ func sleepWithContext(ctx context.Context, d time.Duration) error {
} }
// getWithRetry 获取 GitHub API 并重试. // getWithRetry 获取 GitHub API 并重试.
func (a *ApiContext) getWithRetry(ctx context.Context, url string) (*http.Response, error) { func (a *APIContext) getWithRetry(ctx context.Context, url string) (*http.Response, error) {
backoff := Backoff{} backoff := Backoff{}
for { for {
@ -81,6 +118,11 @@ func (a *ApiContext) getWithRetry(ctx context.Context, url string) (*http.Respon
// non-2xx code does not cause error // non-2xx code does not cause error
if err != nil { if err != nil {
// 如果错误是速率限制错误, 则直接返回
if errors.Is(err, ErrRateLimitExceeded) {
return nil, err
}
// retry when error is not nil // retry when error is not nil
p, retryAgain := backoff.Pause() p, retryAgain := backoff.Pause()
if !retryAgain { if !retryAgain {
@ -115,7 +157,7 @@ func (a *ApiContext) getWithRetry(ctx context.Context, url string) (*http.Respon
// retry when server error // retry when server error
p, retryAgain := backoff.Pause() p, retryAgain := backoff.Pause()
if !retryAgain { if !retryAgain {
return nil, parseHTTPError(body) return nil, parseHTTPError(response.StatusCode, body)
} }
utils.Log.Debugf("query github api error: status code %d, retry after %s", response.StatusCode, p) utils.Log.Debugf("query github api error: status code %d, retry after %s", response.StatusCode, p)
@ -125,18 +167,18 @@ func (a *ApiContext) getWithRetry(ctx context.Context, url string) (*http.Respon
continue continue
} }
return nil, parseHTTPError(body) return nil, parseHTTPError(response.StatusCode, body)
} }
} }
// SetAuthHeader 为请求头添加 GitHub API 所需的认证头. // SetAuthHeader 为请求头添加 GitHub API 所需的认证头.
// 这是一个副作用函数, 会直接修改传入的 header. // 这是一个副作用函数, 会直接修改传入的 header.
func (a *ApiContext) SetAuthHeader(header http.Header) { func (a *APIContext) SetAuthHeader(header http.Header) {
header.Set("Authorization", fmt.Sprintf("Bearer %s", a.token)) header.Set("Authorization", fmt.Sprintf("Bearer %s", a.token))
} }
// get 获取 GitHub API. // get 获取 GitHub API.
func (a *ApiContext) get(ctx context.Context, url string) (*http.Response, error) { func (a *APIContext) get(ctx context.Context, url string) (*http.Response, error) {
request, err := http.NewRequestWithContext(ctx, "GET", url, nil) request, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil { if err != nil {
return nil, err return nil, err
@ -150,11 +192,21 @@ func (a *ApiContext) get(ctx context.Context, url string) (*http.Response, error
return nil, err return nil, err
} }
// 更新速率限制信息
a.rateLimit = parseRateLimit(response.Header)
// 如果剩余请求数为 0, 等待到重置时间
if a.rateLimit.Remaining == 0 {
waitTime := time.Until(a.rateLimit.Reset)
utils.Log.Warnf("rate limit exceeded, will wait for %s", waitTime)
return nil, ErrRateLimitExceeded
}
return response, nil return response, nil
} }
// GetReleases 获取仓库信息. // GetReleases 获取仓库信息.
func (a *ApiContext) GetReleases(ctx context.Context, repo repository, perPage int) ([]model.Obj, error) { func (a *APIContext) GetReleases(ctx context.Context, repo repository, perPage int) ([]model.Obj, error) {
if perPage < 1 { if perPage < 1 {
perPage = 30 perPage = 30
} }
@ -170,10 +222,6 @@ func (a *ApiContext) GetReleases(ctx context.Context, repo repository, perPage i
return nil, errors.Wrap(err, "failed to read response body") return nil, errors.Wrap(err, "failed to read response body")
} }
if response.StatusCode != http.StatusOK {
return nil, parseHTTPError(body)
}
releases := []Release{} releases := []Release{}
err = utils.Json.Unmarshal(body, &releases) err = utils.Json.Unmarshal(body, &releases)
if err != nil { if err != nil {
@ -187,8 +235,48 @@ func (a *ApiContext) GetReleases(ctx context.Context, repo repository, perPage i
return tree, nil return tree, nil
} }
// GetLatestRelease 获取最新 release.
func (a *APIContext) GetLatestRelease(ctx context.Context, repo repository) (model.Obj, error) {
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo.UrlEncode())
response, err := a.getWithRetry(ctx, url)
if err != nil {
var githubErr *GitHubError
if errors.As(err, &githubErr) && githubErr.StatusCode == http.StatusNotFound {
return nil, ErrNoRelease
}
return nil, errors.Wrap(err, "get latest release")
}
defer response.Body.Close()
body, err := io.ReadAll(response.Body)
if err != nil {
return nil, errors.Wrap(err, "read response body")
}
if response.StatusCode == http.StatusNotFound {
return nil, ErrNoRelease
}
if response.StatusCode != http.StatusOK {
err := parseHTTPError(response.StatusCode, body)
var githubErr *GitHubError
if errors.As(err, &githubErr) && githubErr.StatusCode == http.StatusNotFound {
return nil, ErrNoRelease
}
return nil, err
}
var release Release
if err := utils.Json.Unmarshal(body, &release); err != nil {
return nil, errors.Wrap(err, "unmarshal release data")
}
release.SetLatestFlag(true)
return &release, nil
}
// GetRelease 获取指定 tag 的 release. // GetRelease 获取指定 tag 的 release.
func (a *ApiContext) GetRelease(ctx context.Context, repo repository, id int64) (*Release, error) { func (a *APIContext) GetRelease(ctx context.Context, repo repository, id int64) (*Release, error) {
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/%d", repo.UrlEncode(), id) url := fmt.Sprintf("https://api.github.com/repos/%s/releases/%d", repo.UrlEncode(), id)
response, err := a.getWithRetry(ctx, url) response, err := a.getWithRetry(ctx, url)
if err != nil { if err != nil {
@ -201,10 +289,6 @@ func (a *ApiContext) GetRelease(ctx context.Context, repo repository, id int64)
return nil, errors.Wrap(err, "failed to read response body") return nil, errors.Wrap(err, "failed to read response body")
} }
if response.StatusCode != http.StatusOK {
return nil, parseHTTPError(body)
}
release := Release{} release := Release{}
err = utils.Json.Unmarshal(body, &release) err = utils.Json.Unmarshal(body, &release)
if err != nil { if err != nil {
@ -215,7 +299,7 @@ func (a *ApiContext) GetRelease(ctx context.Context, repo repository, id int64)
} }
// GetReleaseAsset 获取指定 tag 的 release 的 assets. // GetReleaseAsset 获取指定 tag 的 release 的 assets.
func (a *ApiContext) GetReleaseAsset(ctx context.Context, repo repository, ID int64) (*Asset, error) { func (a *APIContext) GetReleaseAsset(ctx context.Context, repo repository, ID int64) (*Asset, error) {
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/assets/%d", repo.UrlEncode(), ID) url := fmt.Sprintf("https://api.github.com/repos/%s/releases/assets/%d", repo.UrlEncode(), ID)
response, err := a.getWithRetry(ctx, url) response, err := a.getWithRetry(ctx, url)
if err != nil { if err != nil {
@ -228,10 +312,6 @@ func (a *ApiContext) GetReleaseAsset(ctx context.Context, repo repository, ID in
return nil, errors.Wrap(err, "failed to read response body") return nil, errors.Wrap(err, "failed to read response body")
} }
if response.StatusCode != http.StatusOK {
return nil, parseHTTPError(body)
}
asset := Asset{} asset := Asset{}
err = utils.Json.Unmarshal(body, &asset) err = utils.Json.Unmarshal(body, &asset)
if err != nil { if err != nil {
@ -244,36 +324,3 @@ func (a *ApiContext) GetReleaseAsset(ctx context.Context, repo repository, ID in
var ( var (
ErrNoRelease = errors.New("no release found") ErrNoRelease = errors.New("no release found")
) )
// GetLatestRelease 获取最新 release.
func (a *ApiContext) GetLatestRelease(ctx context.Context, repo repository) (model.Obj, error) {
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo.UrlEncode())
response, err := a.getWithRetry(ctx, url)
if err != nil {
return nil, err
}
defer response.Body.Close()
body, err := io.ReadAll(response.Body)
if err != nil {
return nil, errors.Wrap(err, "get latest release failed")
}
if response.StatusCode != http.StatusOK {
if response.StatusCode == http.StatusNotFound {
// identify no release
return nil, ErrNoRelease
}
return nil, parseHTTPError(body)
}
release := Release{}
err = utils.Json.Unmarshal(body, &release)
if err != nil {
return nil, errors.Wrap(err, "get latest release failed")
}
release.SetLatestFlag(true)
return &release, nil
}

View File

@ -0,0 +1,155 @@
package github_release
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestParseRateLimit(t *testing.T) {
header := http.Header{}
header.Set("X-RateLimit-Limit", "60")
header.Set("X-RateLimit-Remaining", "59")
header.Set("X-RateLimit-Reset", "1735689600") // 2025-01-01 00:00:00 UTC
rateLimit := parseRateLimit(header)
assert.Equal(t, uint(60), rateLimit.Limit)
assert.Equal(t, uint(59), rateLimit.Remaining)
assert.Equal(t, time.Unix(1735689600, 0), rateLimit.Reset)
}
func TestGitHubError(t *testing.T) {
err := &GitHubError{
Message: "API rate limit exceeded",
StatusCode: 403,
}
assert.Equal(t, "github api error: API rate limit exceeded (status: 403)", err.Error())
}
func TestNewAPIContext(t *testing.T) {
token := "test-token"
client := &http.Client{}
ctx := NewAPIContext(token, client)
assert.Equal(t, token, ctx.token)
assert.Equal(t, GITHUB_API_VERSION, ctx.version)
assert.Equal(t, client, ctx.client)
assert.Equal(t, DEFAULT_TIMEOUT, ctx.defaultTimeout)
}
func TestAPIContext_SetAuthHeader(t *testing.T) {
token := "test-token"
ctx := NewAPIContext(token, nil)
header := http.Header{}
ctx.SetAuthHeader(header)
assert.Equal(t, "Bearer "+token, header.Get("Authorization"))
}
func TestAPIContext_GetWithRetry_RateLimit(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-RateLimit-Limit", "60")
w.Header().Set("X-RateLimit-Remaining", "0")
w.Header().Set("X-RateLimit-Reset", "1735689600")
w.WriteHeader(http.StatusForbidden)
w.Write([]byte(`{"message": "API rate limit exceeded"}`))
}))
defer server.Close()
ctx := NewAPIContext("test-token", server.Client())
_, err := ctx.getWithRetry(context.Background(), server.URL)
assert.ErrorIs(t, err, ErrRateLimitExceeded)
}
type testRoundTripper struct {
handler http.HandlerFunc
}
func (t *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// 创建一个响应记录器
w := httptest.NewRecorder()
// 调用处理函数
t.handler.ServeHTTP(w, req)
// 将响应记录器转换为响应
return w.Result(), nil
}
func TestAPIContext_GetLatestRelease(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 验证请求路径
assert.Equal(t, "/repos/test-owner/test-repo/releases/latest", r.URL.Path)
// 验证请求头
assert.Equal(t, "Bearer test-token", r.Header.Get("Authorization"))
assert.Equal(t, "application/vnd.github+json", r.Header.Get("Accept"))
// 设置速率限制头部
w.Header().Set("X-RateLimit-Limit", "60")
w.Header().Set("X-RateLimit-Remaining", "59")
w.Header().Set("X-RateLimit-Reset", "1735689600")
// 设置响应头和内容
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{
"id": 1,
"tag_name": "v1.0.0",
"name": "Release 1.0.0",
"published_at": "2025-01-01T00:00:00Z",
"created_at": "2025-01-01T00:00:00Z",
"assets": []
}`))
})
// 创建一个自定义的 HTTP 客户端
client := &http.Client{
Transport: &testRoundTripper{handler: handler},
}
ctx := NewAPIContext("test-token", client)
repo := repository{owner: "test-owner", name: "test-repo"}
release, err := ctx.GetLatestRelease(context.Background(), repo)
if assert.NoError(t, err) {
assert.NotNil(t, release)
assert.Equal(t, "latest(v1.0.0)", release.GetName())
}
}
func TestAPIContext_GetLatestRelease_NoRelease(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 验证请求路径
assert.Equal(t, "/repos/test-owner/test-repo/releases/latest", r.URL.Path)
// 验证请求头
assert.Equal(t, "Bearer test-token", r.Header.Get("Authorization"))
assert.Equal(t, "application/vnd.github+json", r.Header.Get("Accept"))
// 设置速率限制头部
w.Header().Set("X-RateLimit-Limit", "60")
w.Header().Set("X-RateLimit-Remaining", "59")
w.Header().Set("X-RateLimit-Reset", "1735689600")
// 返回 404 状态码
w.WriteHeader(http.StatusNotFound)
w.Write([]byte(`{"message": "Not Found"}`))
})
// 创建一个自定义的 HTTP 客户端
client := &http.Client{
Transport: &testRoundTripper{handler: handler},
}
ctx := NewAPIContext("test-token", client)
repo := repository{owner: "test-owner", name: "test-repo"}
_, err := ctx.GetLatestRelease(context.Background(), repo)
assert.ErrorIs(t, err, ErrNoRelease)
}