mirror of https://github.com/cloudreve/Cloudreve
239 lines
5.7 KiB
Go
239 lines
5.7 KiB
Go
package credmanager
|
|
|
|
import (
|
|
"context"
|
|
"encoding/gob"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/cloudreve/Cloudreve/v4/pkg/auth"
|
|
"github.com/cloudreve/Cloudreve/v4/pkg/cache"
|
|
"github.com/cloudreve/Cloudreve/v4/pkg/cluster"
|
|
"github.com/cloudreve/Cloudreve/v4/pkg/cluster/routes"
|
|
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
|
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
|
"github.com/cloudreve/Cloudreve/v4/pkg/request"
|
|
)
|
|
|
|
type (
	// CredManager is a central store for OAuth credentials that require
	// periodic refreshing. It is primarily used by the OneDrive storage policy.
	CredManager interface {
		// Obtain returns the credential stored under key. The master-side
		// implementation refreshes it first if it has expired.
		Obtain(ctx context.Context, key string) (Credential, error)
		// Upsert inserts or updates the given credentials in the manager.
		Upsert(ctx context.Context, cred ...Credential) error
		// RefreshAll refreshes every credential known to the manager.
		// Implementations that cannot refresh (e.g. on slave nodes) may do nothing.
		RefreshAll(ctx context.Context)
	}

	// Credential is a token managed by a CredManager.
	Credential interface {
		// String returns the credential's string form (for CredentialResponse,
		// the raw token).
		String() string
		// Refresh returns a renewed credential; the manager stores the returned
		// value in place of the receiver.
		Refresh(ctx context.Context) (Credential, error)
		// Key identifies the credential inside the manager's KV store.
		Key() string
		// Expiry reports when the credential stops being valid; the master
		// manager refreshes it once this time is no longer in the future.
		Expiry() time.Time
		// RefreshedAt reports when the credential was last refreshed, or nil
		// when that is not tracked.
		RefreshedAt() *time.Time
	}
)
|
|
|
|
func init() {
|
|
gob.Register(CredentialResponse{})
|
|
}
|
|
|
|
func New(kv cache.Driver) CredManager {
|
|
return &credManager{
|
|
kv: kv,
|
|
locks: make(map[string]*sync.Mutex),
|
|
}
|
|
}
|
|
|
|
type (
|
|
credManager struct {
|
|
kv cache.Driver
|
|
mu sync.RWMutex
|
|
|
|
locks map[string]*sync.Mutex
|
|
}
|
|
)
|
|
|
|
// ErrNotFound is wrapped into the error returned by the master-side Obtain
// when no credential is stored under the requested key; match it with
// errors.Is.
var ErrNotFound = errors.New("credential not found")
|
|
|
|
func (m *credManager) Upsert(ctx context.Context, cred ...Credential) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
l := logging.FromContext(ctx)
|
|
for _, c := range cred {
|
|
l.Info("CredManager: Upsert credential for key %q...", c.Key())
|
|
if err := m.kv.Set(c.Key(), c, 0); err != nil {
|
|
return fmt.Errorf("failed to update credential in KV for key %q: %w", c.Key(), err)
|
|
}
|
|
|
|
if _, ok := m.locks[c.Key()]; !ok {
|
|
m.locks[c.Key()] = &sync.Mutex{}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *credManager) Obtain(ctx context.Context, key string) (Credential, error) {
|
|
m.mu.RLock()
|
|
itemRaw, ok := m.kv.Get(key)
|
|
if !ok {
|
|
m.mu.RUnlock()
|
|
return nil, fmt.Errorf("credential not found for key %q: %w", key, ErrNotFound)
|
|
}
|
|
|
|
l := logging.FromContext(ctx)
|
|
|
|
item := itemRaw.(Credential)
|
|
if _, ok := m.locks[key]; !ok {
|
|
m.locks[key] = &sync.Mutex{}
|
|
}
|
|
m.locks[key].Lock()
|
|
defer m.locks[key].Unlock()
|
|
m.mu.RUnlock()
|
|
|
|
if item.Expiry().After(time.Now()) {
|
|
// Credential is still valid
|
|
return item, nil
|
|
}
|
|
|
|
// Credential is expired, refresh it
|
|
l.Info("Refreshing credential for key %q...", key)
|
|
newCred, err := item.Refresh(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to refresh credential for key %q: %w", key, err)
|
|
}
|
|
|
|
l.Info("New credential for key %q is obtained, expire at %s", key, newCred.Expiry().String())
|
|
if err := m.kv.Set(key, newCred, 0); err != nil {
|
|
return nil, fmt.Errorf("failed to update credential in KV for key %q: %w", key, err)
|
|
}
|
|
|
|
return newCred, nil
|
|
}
|
|
|
|
func (m *credManager) RefreshAll(ctx context.Context) {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
|
|
l := logging.FromContext(ctx)
|
|
for key := range m.locks {
|
|
l.Info("Refreshing credential for key %q...", key)
|
|
m.locks[key].Lock()
|
|
defer m.locks[key].Unlock()
|
|
|
|
itemRaw, ok := m.kv.Get(key)
|
|
if !ok {
|
|
l.Warning("Credential not found for key %q", key)
|
|
continue
|
|
}
|
|
|
|
item := itemRaw.(Credential)
|
|
newCred, err := item.Refresh(ctx)
|
|
if err != nil {
|
|
l.Warning("Failed to refresh credential for key %q: %s", key, err)
|
|
continue
|
|
}
|
|
|
|
l.Info("New credential for key %q is obtained, expire at %s", key, newCred.Expiry().String())
|
|
if err := m.kv.Set(key, newCred, 0); err != nil {
|
|
l.Warning("Failed to update credential in KV for key %q: %s", key, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
type (
|
|
slaveCredManager struct {
|
|
kv cache.Driver
|
|
client request.Client
|
|
}
|
|
|
|
CredentialResponse struct {
|
|
Token string `json:"token"`
|
|
ExpireAt time.Time `json:"expire_at"`
|
|
}
|
|
)
|
|
|
|
func NewSlaveManager(kv cache.Driver, config conf.ConfigProvider) CredManager {
|
|
return &slaveCredManager{
|
|
kv: kv,
|
|
client: request.NewClient(
|
|
config,
|
|
request.WithCredential(auth.HMACAuth{
|
|
[]byte(config.Slave().Secret),
|
|
}, int64(config.Slave().SignatureTTL)),
|
|
),
|
|
}
|
|
}
|
|
|
|
func (c CredentialResponse) String() string {
|
|
return c.Token
|
|
}
|
|
|
|
func (c CredentialResponse) Refresh(ctx context.Context) (Credential, error) {
|
|
return c, nil
|
|
}
|
|
|
|
func (c CredentialResponse) Key() string {
|
|
return ""
|
|
}
|
|
|
|
func (c CredentialResponse) Expiry() time.Time {
|
|
return c.ExpireAt
|
|
}
|
|
|
|
func (c CredentialResponse) RefreshedAt() *time.Time {
|
|
return nil
|
|
}
|
|
|
|
func (m *slaveCredManager) Upsert(ctx context.Context, cred ...Credential) error {
|
|
return nil
|
|
}
|
|
|
|
func (m *slaveCredManager) Obtain(ctx context.Context, key string) (Credential, error) {
|
|
itemRaw, ok := m.kv.Get(key)
|
|
if !ok {
|
|
return m.requestCredFromMaster(ctx, key)
|
|
}
|
|
|
|
return itemRaw.(Credential), nil
|
|
}
|
|
|
|
// No op on slave node
|
|
func (m *slaveCredManager) RefreshAll(ctx context.Context) {}
|
|
|
|
func (m *slaveCredManager) requestCredFromMaster(ctx context.Context, key string) (Credential, error) {
|
|
l := logging.FromContext(ctx)
|
|
l.Info("SlaveCredManager: Requesting credential for key %q from master...", key)
|
|
|
|
requestDst := routes.MasterGetCredentialUrl(cluster.MasterSiteUrlFromContext(ctx), key)
|
|
resp, err := m.client.Request(
|
|
http.MethodGet,
|
|
requestDst.String(),
|
|
nil,
|
|
request.WithContext(ctx),
|
|
request.WithLogger(l),
|
|
request.WithSlaveMeta(cluster.NodeIdFromContext(ctx)),
|
|
request.WithCorrelationID(),
|
|
).CheckHTTPResponse(http.StatusOK).DecodeResponse()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to request credential from master: %w", err)
|
|
}
|
|
|
|
cred := &CredentialResponse{}
|
|
resp.GobDecode(&cred)
|
|
|
|
if err := m.kv.Set(key, *cred, max(int(time.Until(cred.Expiry()).Seconds()), 1)); err != nil {
|
|
return nil, fmt.Errorf("failed to update credential in KV for key %q: %w", key, err)
|
|
}
|
|
|
|
return cred, nil
|
|
}
|