diff --git a/drivers/pikpak/driver.go b/drivers/pikpak/driver.go index e27263dd..2dab2a9b 100644 --- a/drivers/pikpak/driver.go +++ b/drivers/pikpak/driver.go @@ -41,10 +41,6 @@ func (d *PikPak) Init(ctx context.Context) (err error) { d.ClientSecret = "dbw2OtmVEeuUvIptb1Coyg" } - withClient := func(ctx context.Context) context.Context { - return context.WithValue(ctx, oauth2.HTTPClient, base.HttpClient) - } - oauth2Config := &oauth2.Config{ ClientID: d.ClientID, ClientSecret: d.ClientSecret, @@ -55,11 +51,13 @@ func (d *PikPak) Init(ctx context.Context) (err error) { }, } - oauth2Token, err := oauth2Config.PasswordCredentialsToken(withClient(ctx), d.Username, d.Password) - if err != nil { - return err - } - d.oauth2Token = oauth2Config.TokenSource(withClient(context.Background()), oauth2Token) + d.oauth2Token = oauth2.ReuseTokenSource(nil, utils.TokenSource(func() (*oauth2.Token, error) { + return oauth2Config.PasswordCredentialsToken( + context.WithValue(context.Background(), oauth2.HTTPClient, base.HttpClient), + d.Username, + d.Password, + ) + })) return nil } diff --git a/drivers/pikpak_share/driver.go b/drivers/pikpak_share/driver.go index 58c2c8c4..1862db06 100644 --- a/drivers/pikpak_share/driver.go +++ b/drivers/pikpak_share/driver.go @@ -33,10 +33,6 @@ func (d *PikPakShare) Init(ctx context.Context) error { d.ClientSecret = "dbw2OtmVEeuUvIptb1Coyg" } - withClient := func(ctx context.Context) context.Context { - return context.WithValue(ctx, oauth2.HTTPClient, base.HttpClient) - } - oauth2Config := &oauth2.Config{ ClientID: d.ClientID, ClientSecret: d.ClientSecret, @@ -47,17 +43,16 @@ func (d *PikPakShare) Init(ctx context.Context) error { }, } - oauth2Token, err := oauth2Config.PasswordCredentialsToken(withClient(ctx), d.Username, d.Password) - if err != nil { - return err - } - d.oauth2Token = oauth2Config.TokenSource(withClient(context.Background()), oauth2Token) + d.oauth2Token = oauth2.ReuseTokenSource(nil, utils.TokenSource(func() (*oauth2.Token, error) { + return oauth2Config.PasswordCredentialsToken( + context.WithValue(context.Background(), oauth2.HTTPClient, base.HttpClient), + d.Username, + d.Password, + ) + })) if d.SharePwd != "" { - err = d.getSharePassToken() - if err != nil { - return err - } + return d.getSharePassToken() } return nil } diff --git a/pkg/utils/oauth2.go b/pkg/utils/oauth2.go new file mode 100644 index 00000000..c1ad1612 --- /dev/null +++ b/pkg/utils/oauth2.go @@ -0,0 +1,15 @@ +package utils + +import "golang.org/x/oauth2" + +type tokenSource struct { + fn func() (*oauth2.Token, error) +} + +func (t *tokenSource) Token() (*oauth2.Token, error) { + return t.fn() +} + +func TokenSource(fn func() (*oauth2.Token, error)) oauth2.TokenSource { + return &tokenSource{fn} +}