|
|
|
@ -19,6 +19,14 @@ def print_rank_0(*args, **kwargs) -> None:
|
|
|
|
|
print(*args, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def divide(x: float, y: float) -> float: |
|
|
|
|
if y == 0: |
|
|
|
|
return float('inf') |
|
|
|
|
elif y == float('inf'): |
|
|
|
|
return float('nan') |
|
|
|
|
return x / y |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
|
|
def all_reduce_mean(x: float, world_size: int) -> float: |
|
|
|
|
if world_size == 1: |
|
|
|
@ -29,6 +37,24 @@ def all_reduce_mean(x: float, world_size: int) -> float:
|
|
|
|
|
return tensor.item() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Timer: |
|
|
|
|
|
|
|
|
|
def __init__(self) -> None: |
|
|
|
|
self.start_time: Optional[float] = None |
|
|
|
|
self.duration: float = 0. |
|
|
|
|
|
|
|
|
|
def start(self) -> None: |
|
|
|
|
self.start_time = time() |
|
|
|
|
|
|
|
|
|
def end(self) -> None: |
|
|
|
|
assert self.start_time is not None |
|
|
|
|
self.duration += time() - self.start_time |
|
|
|
|
self.start_time = None |
|
|
|
|
|
|
|
|
|
def reset(self) -> None: |
|
|
|
|
self.duration = 0. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PerformanceEvaluator(Callback): |
|
|
|
|
""" |
|
|
|
|
Callback for valuate the performance of the model. |
|
|
|
@ -58,27 +84,34 @@ class PerformanceEvaluator(Callback):
|
|
|
|
|
self.ignore_episodes = ignore_episodes |
|
|
|
|
self.disable: bool = False |
|
|
|
|
|
|
|
|
|
self.make_experience_duration: float = 0. |
|
|
|
|
self.make_experience_start_time: Optional[float] = None |
|
|
|
|
self.overall_timer = Timer() |
|
|
|
|
self.make_experience_timer = Timer() |
|
|
|
|
self.learn_timer = Timer() |
|
|
|
|
self.make_experience_num_samples: int = 0 |
|
|
|
|
self.make_experience_flop: int = 0 |
|
|
|
|
self.learn_duration: float = 0. |
|
|
|
|
self.learn_start_time: Optional[float] = None |
|
|
|
|
self.learn_num_samples: int = 0 |
|
|
|
|
self.learn_flop: int = 0 |
|
|
|
|
|
|
|
|
|
def on_episode_start(self, episode: int) -> None: |
|
|
|
|
self.disable = self.ignore_episodes > 0 and episode < self.ignore_episodes |
|
|
|
|
if self.disable: |
|
|
|
|
return |
|
|
|
|
self.overall_timer.start() |
|
|
|
|
|
|
|
|
|
def on_episode_end(self, episode: int) -> None: |
|
|
|
|
if self.disable: |
|
|
|
|
return |
|
|
|
|
self.overall_timer.end() |
|
|
|
|
|
|
|
|
|
def on_make_experience_start(self) -> None: |
|
|
|
|
if self.disable: |
|
|
|
|
return |
|
|
|
|
self.make_experience_start_time = time() |
|
|
|
|
self.make_experience_timer.start() |
|
|
|
|
|
|
|
|
|
def on_make_experience_end(self, experience: Experience) -> None: |
|
|
|
|
if self.disable: |
|
|
|
|
return |
|
|
|
|
self.make_experience_duration += time() - self.make_experience_start_time |
|
|
|
|
self.make_experience_timer.end() |
|
|
|
|
|
|
|
|
|
batch_size, seq_len = experience.sequences.shape |
|
|
|
|
|
|
|
|
@ -101,12 +134,12 @@ class PerformanceEvaluator(Callback):
|
|
|
|
|
def on_learn_batch_start(self) -> None: |
|
|
|
|
if self.disable: |
|
|
|
|
return |
|
|
|
|
self.learn_start_time = time() |
|
|
|
|
self.learn_timer.start() |
|
|
|
|
|
|
|
|
|
def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None: |
|
|
|
|
if self.disable: |
|
|
|
|
return |
|
|
|
|
self.learn_duration += time() - self.learn_start_time |
|
|
|
|
self.learn_timer.end() |
|
|
|
|
|
|
|
|
|
batch_size, seq_len = experience.sequences.shape |
|
|
|
|
|
|
|
|
@ -118,16 +151,33 @@ class PerformanceEvaluator(Callback):
|
|
|
|
|
self.learn_flop += self.critic_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint)) |
|
|
|
|
|
|
|
|
|
def on_fit_end(self) -> None: |
|
|
|
|
avg_make_experience_duration = all_reduce_mean(self.make_experience_duration, self.world_size) |
|
|
|
|
avg_learn_duration = all_reduce_mean(self.learn_duration, self.world_size) |
|
|
|
|
avg_make_experience_duration = all_reduce_mean(self.make_experience_timer.duration, self.world_size) |
|
|
|
|
avg_learn_duration = all_reduce_mean(self.learn_timer.duration, self.world_size) |
|
|
|
|
avg_overall_duration = all_reduce_mean(self.overall_timer.duration, self.world_size) |
|
|
|
|
|
|
|
|
|
avg_make_experience_throughput = self.make_experience_num_samples / (avg_make_experience_duration + 1e-12) |
|
|
|
|
avg_make_experience_throughput = self.make_experience_num_samples * \ |
|
|
|
|
self.world_size / (avg_make_experience_duration + 1e-12) |
|
|
|
|
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12) |
|
|
|
|
|
|
|
|
|
avg_learn_throughput = self.learn_num_samples / (avg_learn_duration + 1e-12) |
|
|
|
|
avg_learn_throughput = self.learn_num_samples * self.world_size / (avg_learn_duration + 1e-12) |
|
|
|
|
avg_learn_tflops = self.learn_flop / 1e12 / (avg_learn_duration + 1e-12) |
|
|
|
|
|
|
|
|
|
num_effective_samples = min(self.learn_num_samples, self.make_experience_num_samples) * self.world_size |
|
|
|
|
|
|
|
|
|
avg_overall_throughput = num_effective_samples / (avg_overall_duration + 1e-12) |
|
|
|
|
|
|
|
|
|
overall_time_per_sample = divide(1, avg_overall_throughput) |
|
|
|
|
make_experience_time_per_sample = divide(avg_make_experience_duration, num_effective_samples) |
|
|
|
|
learn_time_per_sample = divide(avg_learn_duration, num_effective_samples) |
|
|
|
|
|
|
|
|
|
print_rank_0( |
|
|
|
|
f'Making experience throughput: {avg_make_experience_throughput:.3f} samples/sec, TFLOPS: {avg_make_experience_tflops:.3f}' |
|
|
|
|
f'Performance summary:\n' + |
|
|
|
|
f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n' |
|
|
|
|
+ |
|
|
|
|
f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n' |
|
|
|
|
+ f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n' + |
|
|
|
|
f'Overall time per sample: {overall_time_per_sample:.2f} s\n' + |
|
|
|
|
f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n' |
|
|
|
|
+ |
|
|
|
|
f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%' |
|
|
|
|
) |
|
|
|
|
print_rank_0(f'Learning throughput: {avg_learn_throughput:.3f} samples/sec, TFLOPS: {avg_learn_tflops:.3f}') |
|
|
|
|