2023-03-28 12:25:36 +00:00
|
|
|
from time import time
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
|
|
from coati.experience_maker import Experience
|
|
|
|
|
|
|
|
from .base import Callback
|
|
|
|
|
|
|
|
|
|
|
|
def get_world_size() -> int:
|
|
|
|
if dist.is_initialized():
|
|
|
|
return dist.get_world_size()
|
|
|
|
return 1
|
|
|
|
|
|
|
|
|
|
|
|
def print_rank_0(*args, **kwargs) -> None:
|
|
|
|
if not dist.is_initialized() or dist.get_rank() == 0:
|
|
|
|
print(*args, **kwargs)
|
|
|
|
|
|
|
|
|
2023-04-26 09:34:59 +00:00
|
|
|
def divide(x: float, y: float) -> float:
|
|
|
|
if y == 0:
|
|
|
|
return float('inf')
|
|
|
|
elif y == float('inf'):
|
|
|
|
return float('nan')
|
|
|
|
return x / y
|
|
|
|
|
|
|
|
|
2023-03-28 12:25:36 +00:00
|
|
|
@torch.no_grad()
|
|
|
|
def all_reduce_mean(x: float, world_size: int) -> float:
|
|
|
|
if world_size == 1:
|
|
|
|
return x
|
|
|
|
tensor = torch.tensor([x], device=torch.cuda.current_device())
|
|
|
|
dist.all_reduce(tensor)
|
|
|
|
tensor = tensor / world_size
|
|
|
|
return tensor.item()
|
|
|
|
|
|
|
|
|
2023-04-26 09:34:59 +00:00
|
|
|
class Timer:
|
|
|
|
|
|
|
|
def __init__(self) -> None:
|
|
|
|
self.start_time: Optional[float] = None
|
|
|
|
self.duration: float = 0.
|
|
|
|
|
|
|
|
def start(self) -> None:
|
|
|
|
self.start_time = time()
|
|
|
|
|
|
|
|
def end(self) -> None:
|
|
|
|
assert self.start_time is not None
|
|
|
|
self.duration += time() - self.start_time
|
|
|
|
self.start_time = None
|
|
|
|
|
|
|
|
def reset(self) -> None:
|
|
|
|
self.duration = 0.
|
|
|
|
|
|
|
|
|
2023-03-28 12:25:36 +00:00
|
|
|
class PerformanceEvaluator(Callback):
|
|
|
|
"""
|
|
|
|
Callback for valuate the performance of the model.
|
|
|
|
Args:
|
|
|
|
actor_num_params: The number of parameters of the actor model.
|
|
|
|
critic_num_params: The number of parameters of the critic model.
|
|
|
|
initial_model_num_params: The number of parameters of the initial model.
|
|
|
|
reward_model_num_params: The number of parameters of the reward model.
|
|
|
|
enable_grad_checkpoint: Whether to enable gradient checkpointing.
|
|
|
|
ignore_episodes: The number of episodes to ignore when calculating the performance.
|
|
|
|
"""
|
|
|
|
|
|
|
|
def __init__(self,
|
|
|
|
actor_num_params: int,
|
|
|
|
critic_num_params: int,
|
|
|
|
initial_model_num_params: int,
|
|
|
|
reward_model_num_params: int,
|
|
|
|
enable_grad_checkpoint: bool = False,
|
|
|
|
ignore_episodes: int = 0) -> None:
|
|
|
|
super().__init__()
|
|
|
|
self.world_size = get_world_size()
|
|
|
|
self.actor_num_params = actor_num_params
|
|
|
|
self.critic_num_params = critic_num_params
|
|
|
|
self.initial_model_num_params = initial_model_num_params
|
|
|
|
self.reward_model_num_params = reward_model_num_params
|
|
|
|
self.enable_grad_checkpoint = enable_grad_checkpoint
|
|
|
|
self.ignore_episodes = ignore_episodes
|
|
|
|
self.disable: bool = False
|
|
|
|
|
2023-04-26 09:34:59 +00:00
|
|
|
self.overall_timer = Timer()
|
|
|
|
self.make_experience_timer = Timer()
|
|
|
|
self.learn_timer = Timer()
|
2023-03-28 12:25:36 +00:00
|
|
|
self.make_experience_num_samples: int = 0
|
|
|
|
self.make_experience_flop: int = 0
|
|
|
|
self.learn_num_samples: int = 0
|
|
|
|
self.learn_flop: int = 0
|
|
|
|
|
|
|
|
def on_episode_start(self, episode: int) -> None:
|
|
|
|
self.disable = self.ignore_episodes > 0 and episode < self.ignore_episodes
|
2023-04-26 09:34:59 +00:00
|
|
|
if self.disable:
|
|
|
|
return
|
|
|
|
self.overall_timer.start()
|
|
|
|
|
|
|
|
def on_episode_end(self, episode: int) -> None:
|
|
|
|
if self.disable:
|
|
|
|
return
|
|
|
|
self.overall_timer.end()
|
2023-03-28 12:25:36 +00:00
|
|
|
|
|
|
|
def on_make_experience_start(self) -> None:
|
|
|
|
if self.disable:
|
|
|
|
return
|
2023-04-26 09:34:59 +00:00
|
|
|
self.make_experience_timer.start()
|
2023-03-28 12:25:36 +00:00
|
|
|
|
|
|
|
def on_make_experience_end(self, experience: Experience) -> None:
|
|
|
|
if self.disable:
|
|
|
|
return
|
2023-04-26 09:34:59 +00:00
|
|
|
self.make_experience_timer.end()
|
2023-03-28 12:25:36 +00:00
|
|
|
|
|
|
|
batch_size, seq_len = experience.sequences.shape
|
|
|
|
|
|
|
|
self.make_experience_num_samples += batch_size
|
|
|
|
|
|
|
|
# actor generate
|
|
|
|
num_actions = experience.action_mask.size(1)
|
|
|
|
input_len = seq_len - num_actions
|
|
|
|
total_seq_len = (input_len + seq_len - 1) * num_actions / 2
|
|
|
|
self.make_experience_flop += self.actor_num_params * batch_size * total_seq_len * 2
|
|
|
|
# actor forward
|
|
|
|
self.make_experience_flop += self.actor_num_params * batch_size * seq_len * 2
|
|
|
|
# critic forward
|
|
|
|
self.make_experience_flop += self.critic_num_params * batch_size * seq_len * 2
|
|
|
|
# initial model forward
|
|
|
|
self.make_experience_flop += self.initial_model_num_params * batch_size * seq_len * 2
|
|
|
|
# reward model forward
|
|
|
|
self.make_experience_flop += self.reward_model_num_params * batch_size * seq_len * 2
|
|
|
|
|
|
|
|
def on_learn_batch_start(self) -> None:
|
|
|
|
if self.disable:
|
|
|
|
return
|
2023-04-26 09:34:59 +00:00
|
|
|
self.learn_timer.start()
|
2023-03-28 12:25:36 +00:00
|
|
|
|
|
|
|
def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
|
|
|
|
if self.disable:
|
|
|
|
return
|
2023-04-26 09:34:59 +00:00
|
|
|
self.learn_timer.end()
|
2023-03-28 12:25:36 +00:00
|
|
|
|
|
|
|
batch_size, seq_len = experience.sequences.shape
|
|
|
|
|
|
|
|
self.learn_num_samples += batch_size
|
|
|
|
|
|
|
|
# actor forward-backward, 3 means forward(1) + backward(2)
|
|
|
|
self.learn_flop += self.actor_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
|
2023-04-20 09:22:15 +00:00
|
|
|
# critic forward-backward
|
2023-03-28 12:25:36 +00:00
|
|
|
self.learn_flop += self.critic_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
|
|
|
|
|
|
|
|
def on_fit_end(self) -> None:
|
2023-04-26 09:34:59 +00:00
|
|
|
avg_make_experience_duration = all_reduce_mean(self.make_experience_timer.duration, self.world_size)
|
|
|
|
avg_learn_duration = all_reduce_mean(self.learn_timer.duration, self.world_size)
|
|
|
|
avg_overall_duration = all_reduce_mean(self.overall_timer.duration, self.world_size)
|
2023-03-28 12:25:36 +00:00
|
|
|
|
2023-04-26 09:34:59 +00:00
|
|
|
avg_make_experience_throughput = self.make_experience_num_samples * \
|
|
|
|
self.world_size / (avg_make_experience_duration + 1e-12)
|
2023-03-28 12:25:36 +00:00
|
|
|
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
|
|
|
|
|
2023-04-26 09:34:59 +00:00
|
|
|
avg_learn_throughput = self.learn_num_samples * self.world_size / (avg_learn_duration + 1e-12)
|
2023-03-28 12:25:36 +00:00
|
|
|
avg_learn_tflops = self.learn_flop / 1e12 / (avg_learn_duration + 1e-12)
|
|
|
|
|
2023-04-26 09:34:59 +00:00
|
|
|
num_effective_samples = min(self.learn_num_samples, self.make_experience_num_samples) * self.world_size
|
|
|
|
|
|
|
|
avg_overall_throughput = num_effective_samples / (avg_overall_duration + 1e-12)
|
|
|
|
|
|
|
|
overall_time_per_sample = divide(1, avg_overall_throughput)
|
|
|
|
make_experience_time_per_sample = divide(avg_make_experience_duration, num_effective_samples)
|
|
|
|
learn_time_per_sample = divide(avg_learn_duration, num_effective_samples)
|
|
|
|
|
2023-03-28 12:25:36 +00:00
|
|
|
print_rank_0(
|
2023-08-02 02:17:36 +00:00
|
|
|
f'Performance summary:\n'
|
|
|
|
+ f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n'
|
|
|
|
|
|
|
|
+ f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n'
|
|
|
|
+ f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n'
|
|
|
|
+ f'Overall time per sample: {overall_time_per_sample:.2f} s\n'
|
|
|
|
+ f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n'
|
|
|
|
|
|
|
|
+ f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%'
|
2023-03-28 12:25:36 +00:00
|
|
|
)
|