diff --git a/drivers/github_release/driver.go b/drivers/github_release/driver.go index a6f6d69f..18dd57b5 100644 --- a/drivers/github_release/driver.go +++ b/drivers/github_release/driver.go @@ -20,7 +20,7 @@ type GithubRelease struct { model.Storage Addition - api *ApiContext + api *APIContext repo repository } @@ -56,7 +56,7 @@ func (d *GithubRelease) Init(ctx context.Context) error { return err } - d.api = NewApiContext(d.Addition.Token, nil) + d.api = NewAPIContext(d.Addition.Token, nil) repo, err := newRepository(d.Addition.Repo) if err != nil { diff --git a/drivers/github_release/github.go b/drivers/github_release/github.go index 5505122e..0edbbd8b 100644 --- a/drivers/github_release/github.go +++ b/drivers/github_release/github.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net/http" + "strconv" "time" "github.com/alist-org/alist/v3/internal/model" @@ -12,49 +13,85 @@ import ( "github.com/pkg/errors" ) -const GITHUB_API_VERSION = "2022-11-28" +const ( + GITHUB_API_VERSION = "2022-11-28" + DEFAULT_TIMEOUT = 10 * time.Second +) -type ApiContext struct { - token string - version string - client *http.Client +var ErrRateLimitExceeded = errors.New("rate limit exceeded") + +// RateLimit 表示 GitHub API 的速率限制信息 +type RateLimit struct { + Limit uint + Remaining uint + Reset time.Time } -func NewApiContext(token string, client *http.Client) *ApiContext { - ret := ApiContext{ - token: token, - version: GITHUB_API_VERSION, - client: client, +// GitHubError 表示 GitHub API 返回的错误信息 +type GitHubError struct { + Message string `json:"message"` + DocumentationURL string `json:"documentation_url"` + StatusCode int +} + +func (e *GitHubError) Error() string { + return fmt.Sprintf("github api error: %s (status: %d)", e.Message, e.StatusCode) +} + +// parseHTTPError 解析 GitHub API 的错误响应 +func parseHTTPError(statusCode int, body []byte) error { + var v GitHubError + err := utils.Json.Unmarshal(body, &v) + if err != nil { + return &GitHubError{ + Message: string(body), + StatusCode: statusCode, + } + } + v.StatusCode = statusCode + return &v +} + +// parseRateLimit 从响应头中解析速率限制信息 +func parseRateLimit(header http.Header) *RateLimit { + limit, _ := strconv.Atoi(header.Get("X-RateLimit-Limit")) + remaining, _ := strconv.Atoi(header.Get("X-RateLimit-Remaining")) + reset, _ := strconv.ParseInt(header.Get("X-RateLimit-Reset"), 10, 64) + + return &RateLimit{ + Limit: uint(limit), + Remaining: uint(remaining), + Reset: time.Unix(reset, 0), + } +} + +// APIContext 表示 GitHub API 的上下文信息 +type APIContext struct { + token string + version string + client *http.Client + defaultTimeout time.Duration + rateLimit *RateLimit +} + +// NewAPIContext 创建一个新的 GitHub API 上下文 +func NewAPIContext(token string, client *http.Client) *APIContext { + ret := APIContext{ + token: token, + version: GITHUB_API_VERSION, + client: client, + defaultTimeout: DEFAULT_TIMEOUT, } if ret.client == nil { - ret.client = http.DefaultClient + ret.client = &http.Client{ + Timeout: ret.defaultTimeout, + } } return &ret } -// parseHTTPError 解析 HTTP 错误. -func parseHTTPError(body []byte) error { - var v map[string]interface{} - err := utils.Json.Unmarshal(body, &v) - if err != nil { - return errors.New(string(body)) - } - - iface, ok := v["message"] - if !ok { - return errors.New(string(body)) - } - - message, ok := iface.(string) - if !ok { - return errors.New(string(body)) - } - - return errors.New(message) -} - // sleepWithContext 在指定的时间内等待, 如果 context 被取消则提前返回. func sleepWithContext(ctx context.Context, d time.Duration) error { timer := time.NewTimer(d) @@ -69,7 +106,7 @@ func sleepWithContext(ctx context.Context, d time.Duration) error { } // getWithRetry 获取 GitHub API 并重试. -func (a *ApiContext) getWithRetry(ctx context.Context, url string) (*http.Response, error) { +func (a *APIContext) getWithRetry(ctx context.Context, url string) (*http.Response, error) { backoff := Backoff{} for { @@ -81,6 +118,11 @@ func (a *ApiContext) getWithRetry(ctx context.Context, url string) (*http.Respon // non-2xx code does not cause error if err != nil { + // 如果错误是速率限制错误, 则直接返回 + if errors.Is(err, ErrRateLimitExceeded) { + return nil, err + } + // retry when error is not nil p, retryAgain := backoff.Pause() if !retryAgain { @@ -115,7 +157,7 @@ func (a *ApiContext) getWithRetry(ctx context.Context, url string) (*http.Respon // retry when server error p, retryAgain := backoff.Pause() if !retryAgain { - return nil, parseHTTPError(body) + return nil, parseHTTPError(response.StatusCode, body) } utils.Log.Debugf("query github api error: status code %d, retry after %s", response.StatusCode, p) @@ -125,18 +167,18 @@ func (a *ApiContext) getWithRetry(ctx context.Context, url string) (*http.Respon continue } - return nil, parseHTTPError(body) + return nil, parseHTTPError(response.StatusCode, body) } } // SetAuthHeader 为请求头添加 GitHub API 所需的认证头. // 这是一个副作用函数, 会直接修改传入的 header. -func (a *ApiContext) SetAuthHeader(header http.Header) { +func (a *APIContext) SetAuthHeader(header http.Header) { header.Set("Authorization", fmt.Sprintf("Bearer %s", a.token)) } // get 获取 GitHub API. -func (a *ApiContext) get(ctx context.Context, url string) (*http.Response, error) { +func (a *APIContext) get(ctx context.Context, url string) (*http.Response, error) { request, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return nil, err @@ -150,11 +192,21 @@ func (a *ApiContext) get(ctx context.Context, url string) (*http.Response, error return nil, err } + // 更新速率限制信息 + a.rateLimit = parseRateLimit(response.Header) + + // 如果剩余请求数为 0, 等待到重置时间 + if a.rateLimit.Remaining == 0 { + waitTime := time.Until(a.rateLimit.Reset) + utils.Log.Warnf("rate limit exceeded, will wait for %s", waitTime) + return nil, ErrRateLimitExceeded + } + return response, nil } // GetReleases 获取仓库信息. -func (a *ApiContext) GetReleases(ctx context.Context, repo repository, perPage int) ([]model.Obj, error) { +func (a *APIContext) GetReleases(ctx context.Context, repo repository, perPage int) ([]model.Obj, error) { if perPage < 1 { perPage = 30 } @@ -170,10 +222,6 @@ func (a *ApiContext) GetReleases(ctx context.Context, repo repository, perPage i return nil, errors.Wrap(err, "failed to read response body") } - if response.StatusCode != http.StatusOK { - return nil, parseHTTPError(body) - } - releases := []Release{} err = utils.Json.Unmarshal(body, &releases) if err != nil { @@ -187,8 +235,48 @@ func (a *ApiContext) GetReleases(ctx context.Context, repo repository, perPage i return tree, nil } +// GetLatestRelease 获取最新 release. +func (a *APIContext) GetLatestRelease(ctx context.Context, repo repository) (model.Obj, error) { + url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo.UrlEncode()) + response, err := a.getWithRetry(ctx, url) + if err != nil { + var githubErr *GitHubError + if errors.As(err, &githubErr) && githubErr.StatusCode == http.StatusNotFound { + return nil, ErrNoRelease + } + return nil, errors.Wrap(err, "get latest release") + } + defer response.Body.Close() + + body, err := io.ReadAll(response.Body) + if err != nil { + return nil, errors.Wrap(err, "read response body") + } + + if response.StatusCode == http.StatusNotFound { + return nil, ErrNoRelease + } + + if response.StatusCode != http.StatusOK { + err := parseHTTPError(response.StatusCode, body) + var githubErr *GitHubError + if errors.As(err, &githubErr) && githubErr.StatusCode == http.StatusNotFound { + return nil, ErrNoRelease + } + return nil, err + } + + var release Release + if err := utils.Json.Unmarshal(body, &release); err != nil { + return nil, errors.Wrap(err, "unmarshal release data") + } + + release.SetLatestFlag(true) + return &release, nil +} + // GetRelease 获取指定 tag 的 release. -func (a *ApiContext) GetRelease(ctx context.Context, repo repository, id int64) (*Release, error) { +func (a *APIContext) GetRelease(ctx context.Context, repo repository, id int64) (*Release, error) { url := fmt.Sprintf("https://api.github.com/repos/%s/releases/%d", repo.UrlEncode(), id) response, err := a.getWithRetry(ctx, url) if err != nil { @@ -201,10 +289,6 @@ func (a *ApiContext) GetRelease(ctx context.Context, repo repository, id int64) return nil, errors.Wrap(err, "failed to read response body") } - if response.StatusCode != http.StatusOK { - return nil, parseHTTPError(body) - } - release := Release{} err = utils.Json.Unmarshal(body, &release) if err != nil { @@ -215,7 +299,7 @@ func (a *ApiContext) GetRelease(ctx context.Context, repo repository, id int64) } // GetReleaseAsset 获取指定 tag 的 release 的 assets. -func (a *ApiContext) GetReleaseAsset(ctx context.Context, repo repository, ID int64) (*Asset, error) { +func (a *APIContext) GetReleaseAsset(ctx context.Context, repo repository, ID int64) (*Asset, error) { url := fmt.Sprintf("https://api.github.com/repos/%s/releases/assets/%d", repo.UrlEncode(), ID) response, err := a.getWithRetry(ctx, url) if err != nil { @@ -228,10 +312,6 @@ func (a *ApiContext) GetReleaseAsset(ctx context.Context, repo repository, ID in return nil, errors.Wrap(err, "failed to read response body") } - if response.StatusCode != http.StatusOK { - return nil, parseHTTPError(body) - } - asset := Asset{} err = utils.Json.Unmarshal(body, &asset) if err != nil { @@ -244,36 +324,3 @@ func (a *ApiContext) GetReleaseAsset(ctx context.Context, repo repository, ID in var ( ErrNoRelease = errors.New("no release found") ) - -// GetLatestRelease 获取最新 release. -func (a *ApiContext) GetLatestRelease(ctx context.Context, repo repository) (model.Obj, error) { - url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo.UrlEncode()) - response, err := a.getWithRetry(ctx, url) - if err != nil { - return nil, err - } - defer response.Body.Close() - - body, err := io.ReadAll(response.Body) - if err != nil { - return nil, errors.Wrap(err, "get latest release failed") - } - - if response.StatusCode != http.StatusOK { - if response.StatusCode == http.StatusNotFound { - // identify no release - return nil, ErrNoRelease - } - return nil, parseHTTPError(body) - } - - release := Release{} - err = utils.Json.Unmarshal(body, &release) - if err != nil { - return nil, errors.Wrap(err, "get latest release failed") - } - - release.SetLatestFlag(true) - - return &release, nil -} diff --git a/drivers/github_release/github_test.go b/drivers/github_release/github_test.go new file mode 100644 index 00000000..cc4b03db --- /dev/null +++ b/drivers/github_release/github_test.go @@ -0,0 +1,155 @@ +package github_release + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestParseRateLimit(t *testing.T) { + header := http.Header{} + header.Set("X-RateLimit-Limit", "60") + header.Set("X-RateLimit-Remaining", "59") + header.Set("X-RateLimit-Reset", "1735689600") // 2025-01-01 00:00:00 UTC + + rateLimit := parseRateLimit(header) + + assert.Equal(t, uint(60), rateLimit.Limit) + assert.Equal(t, uint(59), rateLimit.Remaining) + assert.Equal(t, time.Unix(1735689600, 0), rateLimit.Reset) +} + +func TestGitHubError(t *testing.T) { + err := &GitHubError{ + Message: "API rate limit exceeded", + StatusCode: 403, + } + + assert.Equal(t, "github api error: API rate limit exceeded (status: 403)", err.Error()) +} + +func TestNewAPIContext(t *testing.T) { + token := "test-token" + client := &http.Client{} + ctx := NewAPIContext(token, client) + + assert.Equal(t, token, ctx.token) + assert.Equal(t, GITHUB_API_VERSION, ctx.version) + assert.Equal(t, client, ctx.client) + assert.Equal(t, DEFAULT_TIMEOUT, ctx.defaultTimeout) +} + +func TestAPIContext_SetAuthHeader(t *testing.T) { + token := "test-token" + ctx := NewAPIContext(token, nil) + header := http.Header{} + + ctx.SetAuthHeader(header) + assert.Equal(t, "Bearer "+token, header.Get("Authorization")) +} + +func TestAPIContext_GetWithRetry_RateLimit(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-RateLimit-Limit", "60") + w.Header().Set("X-RateLimit-Remaining", "0") + w.Header().Set("X-RateLimit-Reset", "1735689600") + w.WriteHeader(http.StatusForbidden) + w.Write([]byte(`{"message": "API rate limit exceeded"}`)) + })) + defer server.Close() + + ctx := NewAPIContext("test-token", server.Client()) + _, err := ctx.getWithRetry(context.Background(), server.URL) + + assert.ErrorIs(t, err, ErrRateLimitExceeded) +} + +type testRoundTripper struct { + handler http.HandlerFunc +} + +func (t *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + // 创建一个响应记录器 + w := httptest.NewRecorder() + // 调用处理函数 + t.handler.ServeHTTP(w, req) + // 将响应记录器转换为响应 + return w.Result(), nil +} + +func TestAPIContext_GetLatestRelease(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 验证请求路径 + assert.Equal(t, "/repos/test-owner/test-repo/releases/latest", r.URL.Path) + + // 验证请求头 + assert.Equal(t, "Bearer test-token", r.Header.Get("Authorization")) + assert.Equal(t, "application/vnd.github+json", r.Header.Get("Accept")) + + // 设置速率限制头部 + w.Header().Set("X-RateLimit-Limit", "60") + w.Header().Set("X-RateLimit-Remaining", "59") + w.Header().Set("X-RateLimit-Reset", "1735689600") + + // 设置响应头和内容 + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{ + "id": 1, + "tag_name": "v1.0.0", + "name": "Release 1.0.0", + "published_at": "2025-01-01T00:00:00Z", + "created_at": "2025-01-01T00:00:00Z", + "assets": [] + }`)) + }) + + // 创建一个自定义的 HTTP 客户端 + client := &http.Client{ + Transport: &testRoundTripper{handler: handler}, + } + + ctx := NewAPIContext("test-token", client) + repo := repository{owner: "test-owner", name: "test-repo"} + release, err := ctx.GetLatestRelease(context.Background(), repo) + + if assert.NoError(t, err) { + assert.NotNil(t, release) + assert.Equal(t, "latest(v1.0.0)", release.GetName()) + } +} + +func TestAPIContext_GetLatestRelease_NoRelease(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 验证请求路径 + assert.Equal(t, "/repos/test-owner/test-repo/releases/latest", r.URL.Path) + + // 验证请求头 + assert.Equal(t, "Bearer test-token", r.Header.Get("Authorization")) + assert.Equal(t, "application/vnd.github+json", r.Header.Get("Accept")) + + // 设置速率限制头部 + w.Header().Set("X-RateLimit-Limit", "60") + w.Header().Set("X-RateLimit-Remaining", "59") + w.Header().Set("X-RateLimit-Reset", "1735689600") + + // 返回 404 状态码 + w.WriteHeader(http.StatusNotFound) + w.Write([]byte(`{"message": "Not Found"}`)) + }) + + // 创建一个自定义的 HTTP 客户端 + client := &http.Client{ + Transport: &testRoundTripper{handler: handler}, + } + + ctx := NewAPIContext("test-token", client) + repo := repository{owner: "test-owner", name: "test-repo"} + _, err := ctx.GetLatestRelease(context.Background(), repo) + + assert.ErrorIs(t, err, ErrNoRelease) +}