mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
138 lines
4.3 KiB
138 lines
4.3 KiB
import json
import os
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn.functional as F
|
|
|
|
|
|
def get_model_numel(model: torch.nn.Module) -> int:
    """Return the total number of scalar elements over all parameters of ``model``.

    Parameters are taken from ``model.parameters()``, so a parameter shared by
    several submodules is counted once.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
|
|
|
|
|
|
def compute_reward(
    r: Union[torch.Tensor, float],
    kl_coef: float,
    log_probs: torch.Tensor,
    log_probs_base: torch.Tensor,
    action_mask: Optional[torch.Tensor] = None,
    reward_eps=5,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build per-token rewards from a sequence-level reward and a KL penalty.

    Args:
        r: Sequence-level reward. Either a tensor indexed by batch row (e.g. shape
            [batch_size]) or a Python number, which is then applied to every row.
        kl_coef: Weight of the per-token penalty ``-kl_coef * (log_probs - log_probs_base)``.
        log_probs: [batch_size, response_length]
        log_probs_base: [batch_size, response_length]
        action_mask: [batch_size, response_length]. Row ``i`` is treated as having
            ``action_mask[i].sum()`` valid positions starting at column 0, i.e. the mask
            is assumed to be left-aligned (ones, then zeros). ``None`` marks every
            position as valid.
        reward_eps: ``r`` is clamped to ``[-reward_eps, reward_eps]`` before use.

    Returns:
        A tuple ``(reward, kl)``, both [batch_size, response_length]:
            reward: the KL penalty plus the clipped ``r`` added to each of the first
                ``action_mask[i].sum()`` positions of row ``i``; all later positions are 0.
            kl: masked per-token estimate ``exp(d) - 1 - d`` with
                ``d = log_probs - log_probs_base``; entries with ``d >= 10`` use
                ``exp(0)`` in place of ``exp(d)`` to avoid overflow.

    Raises:
        AssertionError: if some row of ``action_mask`` has no valid position.
    """
    if action_mask is None:
        # No mask given: every position is a valid action token.
        action_mask = torch.ones_like(log_probs)

    # Work with the log of the probability ratio rather than the ratio itself.
    log_ratio = log_probs - log_probs_base
    reward = -kl_coef * log_ratio * action_mask

    if not isinstance(r, torch.Tensor):
        r = torch.tensor(r, dtype=log_probs.dtype, device=log_probs.device)
    r_clip = torch.clamp(r, -reward_eps, reward_eps)
    if r_clip.dim() == 0:
        # A scalar reward is shared by every sequence in the batch.
        r_clip = r_clip.expand(action_mask.size(0))

    # Fetch all row lengths in a single device->host transfer; int() also makes
    # float masks usable as slice bounds.
    lengths = action_mask.sum(dim=-1).tolist()
    for i, length in enumerate(lengths):
        n_valid = int(length)
        assert n_valid > 0, f"action_mask row {i} has no valid positions"
        reward[i, :n_valid] += r_clip[i]
        # Assign instead of multiplying by 0 so NaN/inf at padded positions are cleared.
        reward[i, n_valid:] = 0
    return reward, ((log_ratio * (log_ratio < 10)).exp() - 1 - log_ratio) * action_mask
|
|
|
|
|
|
def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
|
|
"""
|
|
Compute the log probabilities from logits for the given labels.
|
|
|
|
Args:
|
|
logits (torch.Tensor): The input logits.
|
|
labels (torch.Tensor): The target labels.
|
|
|
|
Returns:
|
|
torch.Tensor: The log probabilities corresponding to the labels.
|
|
"""
|
|
log_probs = F.log_softmax(logits, dim=-1)
|
|
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
|
|
return log_probs_labels.squeeze(-1)
|
|
|
|
|
|
def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
    """Calculate action log probs.

    Args:
        logits (torch.Tensor): Output logits of the actor (e.g. ``Actor.forward(...).logits``);
            indexed below as (batch_size, sequence_length, vocab_size).
        sequences (torch.LongTensor): Input token ids, (batch_size, sequence_length).
        num_actions (int): Number of trailing tokens whose log probs are returned.

    Returns:
        torch.Tensor: Log probs of the last ``num_actions`` predicted tokens, with at most
            ``num_actions`` columns. ``num_actions == 0`` yields an empty (batch_size, 0) tensor.
    """
    # Logit position t scores token t + 1: drop the last logit and the first token.
    log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
    # Use an explicit start index: ``log_probs[:, -0:]`` would return the whole tensor
    # instead of an empty one when num_actions == 0.
    start = max(log_probs.size(1) - num_actions, 0)
    return log_probs[:, start:]
|
|
|
|
|
|
def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """
    Average ``tensor`` along ``dim`` while counting only positions selected by ``mask``.

    Args:
        tensor (torch.Tensor): Values to average.
        mask (torch.Tensor): Mask with the same shape as ``tensor`` (typically 0/1);
            positions with mask 0 contribute nothing.
        dim (int, optional): Dimension to reduce. Default is 1.

    Returns:
        torch.Tensor: ``sum(tensor * mask) / (sum(mask) + 1e-8)`` along ``dim``; the
        small epsilon avoids a division by zero for a fully masked slice.
    """
    numerator = (tensor * mask).sum(dim=dim)
    denominator = mask.sum(dim=dim) + 1e-8
    return numerator / denominator
|
|
|
|
|
|
def calc_masked_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, mask: torch.Tensor) -> torch.Tensor:
    """
    Calculate the masked log probabilities for a given sequence of logits.

    Args:
        logits (torch.Tensor): The input logits tensor of shape (batch_size, sequence_length, vocab_size).
        sequences (torch.LongTensor): The input sequence tensor of shape (batch_size, sequence_length).
        mask (torch.Tensor): Either (batch_size, sequence_length - 1), already aligned with the
            shifted targets ``sequences[:, 1:]``, or (batch_size, sequence_length), aligned with
            ``sequences``; in the latter case its first column is dropped so that it lines up
            with the shifted targets.

    Returns:
        torch.Tensor: The masked log probabilities tensor of shape (batch_size, sequence_length - 1).
    """
    # logits are probabilities of the next token, so we shift them to the left by one
    log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
    if mask.size(-1) == sequences.size(-1):
        # Mask is aligned with the unshifted tokens: shift it like the targets.
        mask = mask[..., 1:]
    return log_probs * mask
|
|
|
|
|
|
def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
    """
    Read ``file_path`` as UTF-8 text and return the parsed JSON document.

    Errors from ``open`` (e.g. a missing file) and from ``json.load`` (malformed
    JSON) propagate to the caller.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed
|
|
|
|
|
|
def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
    """
    Write ``data`` to ``file_path`` as UTF-8 JSON, overwriting any existing file.

    Non-ASCII characters are written as-is (not ``\\u``-escaped), and the output
    is indented by 4 spaces.
    """
    with open(file_path, "w", encoding="utf-8") as out_file:
        json.dump(data, out_file, indent=4, ensure_ascii=False)
|
|
|
|
|
|
def disable_dropout(model: torch.nn.Module):
    """
    Set the drop probability of every ``torch.nn.Dropout`` submodule to 0, in place.

    Only modules that are instances of ``torch.nn.Dropout`` (including ``model`` itself
    and any subclasses) are touched; other dropout variants such as ``Dropout2d`` keep
    their probability. This is used in PPO training.

    Args:
        model (torch.nn.Module): The PyTorch model to modify.

    Returns:
        None
    """
    dropout_layers = (layer for layer in model.modules() if isinstance(layer, torch.nn.Dropout))
    for layer in dropout_layers:
        layer.p = 0.0
|