From 61c20b44bca90bd4d37bc68fd507bab3a7afcb70 Mon Sep 17 00:00:00 2001
From: Jiarui Fang <fangjiarui123@gmail.com>
Date: Wed, 20 Apr 2022 10:05:39 +0800
Subject: [PATCH] [log] local throughput metrics (#811)

* Revert "[zero] add ZeroTensorShardStrategy (#793)"

This reverts commit 88759e289efd0a7b5e0d7bf8e01dbe29db85cf71.

* [gemini] set cpu memory capacity

* [log] local throughput collecting

* polish

* polish

* polish

* polish code

* polish
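
A minimal usage sketch for the new use_local option. The engine/dataloader
setup is assumed to follow the standard colossalai trainer example and is not
part of this patch; only the use_local=True argument is introduced here.

    from colossalai.logging import get_dist_logger
    from colossalai.trainer import Trainer, hooks

    # `engine` and `train_dataloader` are assumed to come from the usual
    # colossalai.initialize(...) setup.
    logger = get_dist_logger()
    trainer = Trainer(engine=engine, logger=logger)
    hook_list = [
        hooks.LossHook(),
        hooks.LogMetricByEpochHook(logger),
        # use_local=True skips the all-reduce and reports throughput computed
        # from this rank's own step time.
        hooks.ThroughputHook(ignored_steps=5, use_local=True),
    ]
    trainer.fit(train_dataloader=train_dataloader,
                epochs=10,
                hooks=hook_list,
                display_progress=True)
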
---
 colossalai/trainer/hooks/_metric_hook.py      | 25 +++++++++++++------
 colossalai/zero/__init__.py                   |  4 +--
 .../zero/sharded_optim/sharded_optim_v2.py    |  3 ++-
 3 files changed, 21 insertions(+), 11 deletions(-)
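
For clarity, the arithmetic behind the two branches added to
ThroughputMetric.get_last_step_value() reduces to the sketch below. Plain
Python floats stand in for the tensors and collectives used in the hunk; the
helper names are illustrative only, not part of the patch.

    def local_samples_per_sec(local_num_samples, local_used_time, world_size):
        # use_local=True: no collective; assume each data-parallel rank saw
        # roughly the same number of samples and scale by the world size.
        return local_num_samples * world_size / (local_used_time + 1e-12)

    def global_samples_per_sec(per_rank_num_samples, per_rank_used_times):
        # use_local=False: total samples across ranks over the mean step time,
        # which is what the all_reduce pair in the hunk computes.
        mean_time = sum(per_rank_used_times) / len(per_rank_used_times)
        return sum(per_rank_num_samples) / (mean_time + 1e-12)

    # e.g. 4 ranks, 32 samples each, 0.5 s local step time -> ~256 samples/s
    print(local_samples_per_sec(32, 0.5, 4))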

diff --git a/colossalai/trainer/hooks/_metric_hook.py b/colossalai/trainer/hooks/_metric_hook.py
index 001557a35..dbca20169 100644
--- a/colossalai/trainer/hooks/_metric_hook.py
+++ b/colossalai/trainer/hooks/_metric_hook.py
@@ -124,7 +124,7 @@ class LossMetric(Metric):
     def get_last_step_value(self) -> str:
         """Returns :attr:`last_step_loss`.
         """
-        return str(self.last_step_loss)
+        return str(self.last_step_loss.cpu().item())
 
     @staticmethod
     def is_better(a, b):
@@ -207,7 +207,7 @@ class AccuracyMetric(Metric):
     def get_last_step_value(self) -> str:
         self.last_step_sum = all_reduce(self.last_step_sum, ParallelMode.DATA)
         self.last_step_correct = all_reduce(self.last_step_correct, ParallelMode.DATA)
-        return str(_format_number((self.last_step_correct / self.last_step_sum).item()))
+        return str(_format_number((self.last_step_correct / self.last_step_sum).cpu().item()))
 
     def get_accumulated_value(self):
         self.accumulated_sum = all_reduce(self.accumulated_sum, ParallelMode.DATA)
@@ -324,7 +324,7 @@ class ThroughputMetric(Metric):
         epoch_only (bool): Whether the metric only read for the full epoch.
     """
 
-    def __init__(self, epoch_only: bool, ignored_steps: int = 0, tflop_per_step: int = 0):
+    def __init__(self, epoch_only: bool, ignored_steps: int = 0, tflop_per_step: int = 0, use_local: bool = False):
         super().__init__(epoch_only=epoch_only)
         self.ignored_steps = ignored_steps
         self.cur_steps = 0
@@ -333,6 +333,7 @@ class ThroughputMetric(Metric):
         self.last_step_num_samples = torch.zeros(1, device=get_current_device())
         self.last_step_used_time = torch.zeros(1, device=get_current_device())
         self._tflop_per_step = tflop_per_step
+        self._use_local = use_local
 
     def reset(self) -> None:
         # self.cur_steps = 0
@@ -350,9 +351,13 @@ class ThroughputMetric(Metric):
             self.accumulated_used_time += self.last_step_used_time
 
     def get_last_step_value(self) -> str:
-        self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
-            gpc.get_world_size(ParallelMode.DATA)
-        self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
+        if self._use_local:
+            self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
+        else:
+            self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
+                 gpc.get_world_size(ParallelMode.DATA)
+            self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
+
         sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
         if self._tflop_per_step > 0:
             tflops = _format_number(self._tflop_per_step / (self.last_step_used_time.item() + 1e-12))
@@ -380,19 +385,23 @@ class ThroughputHook(MetricHook):
         priority (int, optional): Priority in the printing, hooks with small priority will be printed in front
             defaults to 10. If different hooks share same priority, the order of printing would
             depend on the hooks order in the hook list.
+        tflop_per_step (int, optional): Tera floating point operations per step, defaults to 0.
+        use_local (bool, optional): Whether to compute throughput from the local step time only, defaults to False.
     """
 
-    def __init__(self, ignored_steps: int = 0, priority: int = 10, tflop_per_step: int = 0):
+    def __init__(self, ignored_steps: int = 0, priority: int = 10, tflop_per_step: int = 0, use_local: bool = False):
         super().__init__(priority)
         self.ignored_steps = ignored_steps
         self._tflop_per_step = tflop_per_step
+        self._use_local = use_local
 
     def after_hook_is_attached(self, trainer):
         self._check_metric_states_initialization(trainer)
         if self._is_stage_to_compute:
             self.metric = ThroughputMetric(epoch_only=True,
                                            ignored_steps=self.ignored_steps,
-                                           tflop_per_step=self._tflop_per_step)
+                                           tflop_per_step=self._tflop_per_step,
+                                           use_local=self._use_local)
 
             # register the metric
             trainer.states['metrics']['train']['Throughput'] = self.metric
diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py
index 714474ea5..1ea7c73e3 100644
--- a/colossalai/zero/__init__.py
+++ b/colossalai/zero/__init__.py
@@ -23,10 +23,10 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model
 
     logger = get_dist_logger('convert_to_zero_v2')
 
-    logger.info(f'optimizer_config is {optimizer_config}')
+    logger.info(f'optimizer_config is {optimizer_config}', ranks=[0])
     if optimizer_config is None:
         optimizer_config = dict()
-    logger.info(f'model_config is {model_config}')
+    logger.info(f'model_config is {model_config}', ranks=[0])
     if model_config is None:
         model_config = dict()
 
diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index b5d980e04..9f6ee7e03 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -122,7 +122,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._register_master_weight()
         if self.gpu_margin_mem_ratio != 0.0 and not isinstance(sharded_model._tensor_placement_policy,
                                                                AutoTensorPlacementPolicy):
-            self._logger.warning(f'gpu_margin_mem_ratio is meaningless when tensor_placement_policy is not "auto"')
+            self._logger.warning('gpu_margin_mem_ratio is meaningless when tensor_placement_policy is not "auto"',
+                                 ranks=[0])
 
         if self._verbose:
             self._logger.debug(
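
As context for the ranks=[0] changes above, a minimal sketch of rank-restricted
logging with the distributed logger. The launch call is an assumption about the
surrounding setup (e.g. under torchrun) and is not part of this patch.

    import colossalai
    from colossalai.logging import get_dist_logger

    colossalai.launch_from_torch(config={})
    logger = get_dist_logger()
    # ranks=[0] emits the message only on the global rank-0 process instead
    # of printing one copy per rank.
    logger.info('printed once, on rank 0', ranks=[0])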