mirror of https://github.com/cloudreve/Cloudreve
				
				
				
			
		
			
				
	
	
		
			272 lines
		
	
	
		
			7.1 KiB
		
	
	
	
		
			Go
		
	
	
			
		
		
	
	
			272 lines
		
	
	
		
			7.1 KiB
		
	
	
	
		
			Go
		
	
	
package onedrive
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"encoding/gob"
 | 
						|
	"encoding/json"
 | 
						|
	"fmt"
 | 
						|
	"io"
 | 
						|
	"net/http"
 | 
						|
	"net/url"
 | 
						|
	"strconv"
 | 
						|
	"strings"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/cloudreve/Cloudreve/v4/application/dependency"
 | 
						|
	"github.com/cloudreve/Cloudreve/v4/ent"
 | 
						|
	"github.com/cloudreve/Cloudreve/v4/inventory"
 | 
						|
	"github.com/cloudreve/Cloudreve/v4/inventory/types"
 | 
						|
	"github.com/cloudreve/Cloudreve/v4/pkg/credmanager"
 | 
						|
	"github.com/cloudreve/Cloudreve/v4/pkg/request"
 | 
						|
	"github.com/samber/lo"
 | 
						|
)
 | 
						|
 | 
						|
// AccessTokenExpiryMargin is the safety margin, in seconds (600 s = 10 minutes),
// that Credential.Expiry subtracts from the stored expiry timestamp, so that a
// token is reported as expired before it actually lapses.
const AccessTokenExpiryMargin = 600
 | 
						|
 | 
						|
// Error 实现error接口
 | 
						|
func (err OAuthError) Error() string {
 | 
						|
	return err.ErrorDescription
 | 
						|
}
 | 
						|
 | 
						|
// OAuthURL 获取OAuth认证页面URL
 | 
						|
func (client *client) OAuthURL(ctx context.Context, scope []string) string {
 | 
						|
	query := url.Values{
 | 
						|
		"client_id":     {client.policy.BucketName},
 | 
						|
		"scope":         {strings.Join(scope, " ")},
 | 
						|
		"response_type": {"code"},
 | 
						|
		"redirect_uri":  {client.policy.Settings.OauthRedirect},
 | 
						|
		"state":         {strconv.Itoa(client.policy.ID)},
 | 
						|
	}
 | 
						|
	client.endpoints.oAuthEndpoints.authorize.RawQuery = query.Encode()
 | 
						|
	return client.endpoints.oAuthEndpoints.authorize.String()
 | 
						|
}
 | 
						|
 | 
						|
// getOAuthEndpoint gets OAuth endpoints from API endpoint
 | 
						|
func getOAuthEndpoint(apiEndpoint string) *oauthEndpoint {
 | 
						|
	base, err := url.Parse(apiEndpoint)
 | 
						|
	if err != nil {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
	var (
 | 
						|
		token     *url.URL
 | 
						|
		authorize *url.URL
 | 
						|
	)
 | 
						|
	switch base.Host {
 | 
						|
	//case "login.live.com":
 | 
						|
	//	token, _ = url.Parse("https://login.live.com/oauth20_token.srf")
 | 
						|
	//	authorize, _ = url.Parse("https://login.live.com/oauth20_authorize.srf")
 | 
						|
	case "microsoftgraph.chinacloudapi.cn":
 | 
						|
		token, _ = url.Parse("https://login.chinacloudapi.cn/common/oauth2/v2.0/token")
 | 
						|
		authorize, _ = url.Parse("https://login.chinacloudapi.cn/common/oauth2/v2.0/authorize")
 | 
						|
	default:
 | 
						|
		token, _ = url.Parse("https://login.microsoftonline.com/common/oauth2/v2.0/token")
 | 
						|
		authorize, _ = url.Parse("https://login.microsoftonline.com/common/oauth2/v2.0/authorize")
 | 
						|
	}
 | 
						|
 | 
						|
	return &oauthEndpoint{
 | 
						|
		token:     *token,
 | 
						|
		authorize: *authorize,
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// Credential is the credential returned when obtaining a token, together with
// the bookkeeping fields needed to refresh it later.
type Credential struct {
	// ExpiresIn is decoded from "expires_in". obtainToken rewrites it into an
	// absolute Unix timestamp in seconds, which is how Expiry interprets it.
	ExpiresIn int64 `json:"expires_in"`

	// AccessToken is the bearer token; it is what String returns.
	AccessToken string `json:"access_token"`

	// RefreshToken is the token Refresh exchanges for a new token set.
	RefreshToken string `json:"refresh_token"`

	// RefreshedAtUnix is the Unix time (seconds) of the last successful
	// Refresh; zero means the credential has never been refreshed.
	RefreshedAtUnix int64 `json:"refreshed_at"`

	// PolicyID is the ID of the storage policy this credential belongs to.
	PolicyID int `json:"policy_id"`
}
 | 
						|
 | 
						|
func init() {
 | 
						|
	gob.Register(Credential{})
 | 
						|
}
 | 
						|
 | 
						|
func (c Credential) Refresh(ctx context.Context) (credmanager.Credential, error) {
 | 
						|
	if c.RefreshToken == "" {
 | 
						|
		return nil, ErrInvalidRefreshToken
 | 
						|
	}
 | 
						|
 | 
						|
	dep := dependency.FromContext(ctx)
 | 
						|
	storagePolicyClient := dep.StoragePolicyClient()
 | 
						|
	policy, err := storagePolicyClient.GetPolicyByID(ctx, c.PolicyID)
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("failed to get storage policy: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	oauthBase := getOAuthEndpoint(policy.Server)
 | 
						|
 | 
						|
	newCredential, err := obtainToken(ctx, &obtainTokenArgs{
 | 
						|
		clientId:      policy.BucketName,
 | 
						|
		redirect:      policy.Settings.OauthRedirect,
 | 
						|
		secret:        policy.SecretKey,
 | 
						|
		refreshToken:  c.RefreshToken,
 | 
						|
		client:        dep.RequestClient(request.WithLogger(dep.Logger())),
 | 
						|
		tokenEndpoint: oauthBase.token.String(),
 | 
						|
		policyID:      c.PolicyID,
 | 
						|
	})
 | 
						|
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	c.RefreshToken = newCredential.RefreshToken
 | 
						|
	c.AccessToken = newCredential.AccessToken
 | 
						|
	c.ExpiresIn = newCredential.ExpiresIn
 | 
						|
	c.RefreshedAtUnix = time.Now().Unix()
 | 
						|
 | 
						|
	// Write refresh token to db
 | 
						|
	if err := storagePolicyClient.UpdateAccessKey(ctx, policy, newCredential.RefreshToken); err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return c, nil
 | 
						|
}
 | 
						|
 | 
						|
func (c Credential) Key() string {
 | 
						|
	return CredentialKey(c.PolicyID)
 | 
						|
}
 | 
						|
 | 
						|
func (c Credential) Expiry() time.Time {
 | 
						|
	return time.Unix(c.ExpiresIn-AccessTokenExpiryMargin, 0)
 | 
						|
}
 | 
						|
 | 
						|
func (c Credential) String() string {
 | 
						|
	return c.AccessToken
 | 
						|
}
 | 
						|
 | 
						|
func (c Credential) RefreshedAt() *time.Time {
 | 
						|
	if c.RefreshedAtUnix == 0 {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
	refreshedAt := time.Unix(c.RefreshedAtUnix, 0)
 | 
						|
	return &refreshedAt
 | 
						|
}
 | 
						|
 | 
						|
// ObtainToken 通过code或refresh_token兑换token
 | 
						|
func (client *client) ObtainToken(ctx context.Context, opts ...Option) (*Credential, error) {
 | 
						|
	options := newDefaultOption()
 | 
						|
	for _, o := range opts {
 | 
						|
		o.apply(options)
 | 
						|
	}
 | 
						|
 | 
						|
	return obtainToken(ctx, &obtainTokenArgs{
 | 
						|
		clientId:      client.policy.BucketName,
 | 
						|
		redirect:      client.policy.Settings.OauthRedirect,
 | 
						|
		secret:        client.policy.SecretKey,
 | 
						|
		code:          options.code,
 | 
						|
		refreshToken:  options.refreshToken,
 | 
						|
		client:        client.httpClient,
 | 
						|
		tokenEndpoint: client.endpoints.oAuthEndpoints.token.String(),
 | 
						|
		policyID:      client.policy.ID,
 | 
						|
	})
 | 
						|
 | 
						|
}
 | 
						|
 | 
						|
type obtainTokenArgs struct {
 | 
						|
	clientId      string
 | 
						|
	redirect      string
 | 
						|
	secret        string
 | 
						|
	code          string
 | 
						|
	refreshToken  string
 | 
						|
	client        request.Client
 | 
						|
	tokenEndpoint string
 | 
						|
	policyID      int
 | 
						|
}
 | 
						|
 | 
						|
// obtainToken fetch new access token from Microsoft Graph API
 | 
						|
func obtainToken(ctx context.Context, args *obtainTokenArgs) (*Credential, error) {
 | 
						|
	body := url.Values{
 | 
						|
		"client_id":     {args.clientId},
 | 
						|
		"redirect_uri":  {args.redirect},
 | 
						|
		"client_secret": {args.secret},
 | 
						|
	}
 | 
						|
	if args.code != "" {
 | 
						|
		body.Add("grant_type", "authorization_code")
 | 
						|
		body.Add("code", args.code)
 | 
						|
	} else {
 | 
						|
		body.Add("grant_type", "refresh_token")
 | 
						|
		body.Add("refresh_token", args.refreshToken)
 | 
						|
	}
 | 
						|
	strBody := body.Encode()
 | 
						|
 | 
						|
	res := args.client.Request(
 | 
						|
		"POST",
 | 
						|
		args.tokenEndpoint,
 | 
						|
		io.NopCloser(strings.NewReader(strBody)),
 | 
						|
		request.WithHeader(http.Header{
 | 
						|
			"Content-Type": {"application/x-www-form-urlencoded"}},
 | 
						|
		),
 | 
						|
		request.WithContentLength(int64(len(strBody))),
 | 
						|
		request.WithContext(ctx),
 | 
						|
	)
 | 
						|
	if res.Err != nil {
 | 
						|
		return nil, res.Err
 | 
						|
	}
 | 
						|
 | 
						|
	respBody, err := res.GetResponse()
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	var (
 | 
						|
		errResp    OAuthError
 | 
						|
		credential Credential
 | 
						|
		decodeErr  error
 | 
						|
	)
 | 
						|
 | 
						|
	if res.Response.StatusCode != 200 {
 | 
						|
		decodeErr = json.Unmarshal([]byte(respBody), &errResp)
 | 
						|
	} else {
 | 
						|
		decodeErr = json.Unmarshal([]byte(respBody), &credential)
 | 
						|
	}
 | 
						|
	if decodeErr != nil {
 | 
						|
		return nil, decodeErr
 | 
						|
	}
 | 
						|
 | 
						|
	if errResp.ErrorType != "" {
 | 
						|
		return nil, errResp
 | 
						|
	}
 | 
						|
 | 
						|
	credential.PolicyID = args.policyID
 | 
						|
	credential.ExpiresIn = time.Now().Unix() + credential.ExpiresIn
 | 
						|
	if args.code != "" {
 | 
						|
		credential.ExpiresIn = time.Now().Unix() - 10
 | 
						|
	}
 | 
						|
	return &credential, nil
 | 
						|
}
 | 
						|
 | 
						|
// UpdateCredential 更新凭证,并检查有效期
 | 
						|
func (client *client) UpdateCredential(ctx context.Context) error {
 | 
						|
	newCred, err := client.cred.Obtain(ctx, CredentialKey(client.policy.ID))
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("failed to obtain token from CredManager: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	client.credential = newCred
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// RetrieveOneDriveCredentials retrieves OneDrive credentials from DB inventory
 | 
						|
func RetrieveOneDriveCredentials(ctx context.Context, storagePolicyClient inventory.StoragePolicyClient) ([]credmanager.Credential, error) {
 | 
						|
	odPolicies, err := storagePolicyClient.ListPolicyByType(ctx, types.PolicyTypeOd)
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("failed to list OneDrive policies: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	return lo.Map(odPolicies, func(item *ent.StoragePolicy, index int) credmanager.Credential {
 | 
						|
		return &Credential{
 | 
						|
			PolicyID:     item.ID,
 | 
						|
			ExpiresIn:    0,
 | 
						|
			RefreshToken: item.AccessKey,
 | 
						|
		}
 | 
						|
	}), nil
 | 
						|
}
 | 
						|
 | 
						|
// CredentialKey returns the key under which the credential of the storage
// policy with the given ID is identified: "cred_od_" followed by the ID in
// decimal.
func CredentialKey(policyId int) string {
	const keyFormat = "cred_od_%d"
	return fmt.Sprintf(keyFormat, policyId)
}
 |