[Gemini] remove static tracer (#2083)

pull/2086/head
Jiarui Fang 2022-12-06 12:53:58 +08:00 committed by GitHub
parent 28ef3f29af
commit 1f99205827
4 changed files with 15 additions and 23 deletions


@@ -26,27 +26,13 @@ class GeminiManager:
         chunk_manager (ChunkManager): A ``ChunkManager`` instance.
     """

-    def __init__(self,
-                 placement_policy: str,
-                 chunk_manager: ChunkManager,
-                 module: Optional[torch.nn.Module] = None,
-                 use_static_memstats: bool = False) -> None:
+    def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
         assert placement_policy in PlacementPolicyFactory.get_polocy_names()
         self.policy_name = placement_policy
         policy_cls = PlacementPolicyFactory.create(placement_policy)
         self._chunk_manager = chunk_manager
-        # self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager) if policy_cls.need_mem_stats else None
-        self.use_static_memstats = use_static_memstats
-        if policy_cls.need_mem_stats:
-            if use_static_memstats:
-                assert module is not None
-                self._mem_stats_collector = StaticMemStatsCollector(module, chunk_manager)
-            else:
-                self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager)
-        else:
-            self._mem_stats_collector = None
+        self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager) if policy_cls.need_mem_stats else None
         self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
         self._compute_list: List[Tuple[Chunk, ...]] = []
         self._compute_idx: int = -1
@@ -60,11 +46,7 @@ class GeminiManager:

     def pre_iter(self, *args):
         if self._mem_stats_collector and self._warmup:
-            if self.use_static_memstats:
-                self._mem_stats_collector.init_mem_stats(*args)
-                self._warmup = False
-            else:
-                self._mem_stats_collector.start_collection()
+            self._mem_stats_collector.start_collection()

     def post_iter(self):
         """This function must be called when each iteration finishes


@@ -9,6 +9,16 @@ __all__ = ['RuntimeMemTracer']

 class RuntimeMemTracer():
     """RuntimeMemTracer for module training using ColoParameter.
+
+    Traces non-model memory usage during the fwd+bwd process.
+    The statistics are obtained by feeding inputs with the same shapes as the
+    real training data and running a single fwd+bwd pass.
+
+    NOTE:
+    1. The premise for using this tracer is that the target DNN executes the
+       same operations at every iteration.
+    2. Module buffers are viewed as non-model data.
     """

     def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
         super().__init__()
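The docstring describes the tracing approach: run one fwd+bwd pass with inputs shaped like the real batch and count everything beyond the model's own footprint as non-model data. A sketch of that idea using plain PyTorch memory counters; this illustrates the technique only and is not RuntimeMemTracer's actual implementation (the function name and the use of CUDA peak-memory stats are assumptions, and the crude delta also includes gradient memory):

    import torch

    def trace_non_model_mem(module: torch.nn.Module, input_shape, dtype=torch.half):
        # Hypothetical helper illustrating the single fwd+bwd tracing idea.
        module = module.to(device="cuda", dtype=dtype)
        torch.cuda.reset_peak_memory_stats()
        model_mem = torch.cuda.memory_allocated()          # parameters (+ buffers)
        dummy = torch.randn(*input_shape, device="cuda", dtype=dtype)
        module(dummy).sum().backward()                     # single fwd+bwd
        peak = torch.cuda.max_memory_allocated()
        return peak - model_mem                            # non-model data at peak

    if torch.cuda.is_available():
        non_model = trace_non_model_mem(torch.nn.Linear(1024, 1024), (32, 1024))
        print(f"peak non-model memory: {non_model / 1024**2:.1f} MB")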


@@ -50,5 +50,5 @@ class GeminiDDP(ZeroDDP):
                                            hidden_dim=hidden_dim,
                                            search_range_mb=search_range_mb,
                                            min_chunk_size_mb=min_chunk_size_mb)
-        gemini_manager = GeminiManager(placement_policy, chunk_manager, module)
+        gemini_manager = GeminiManager(placement_policy, chunk_manager)
         super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32)
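For call sites, the only change is that module is no longer forwarded to GeminiManager. A minimal stub-based sketch of the updated wiring; the Stub classes are placeholders, and only the dropped argument comes from the diff:

    # Stubs standing in for the real classes; only the missing `module`
    # argument to the manager reflects this commit.
    class GeminiManagerStub:
        def __init__(self, placement_policy: str, chunk_manager) -> None:
            self.placement_policy = placement_policy
            self.chunk_manager = chunk_manager

    class ZeroDDPStub:
        def __init__(self, module, gemini_manager) -> None:
            self.module, self.gemini_manager = module, gemini_manager

    class GeminiDDPStub(ZeroDDPStub):
        def __init__(self, module, placement_policy: str, chunk_manager) -> None:
            gemini_manager = GeminiManagerStub(placement_policy, chunk_manager)  # no module arg
            super().__init__(module, gemini_manager)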


@@ -117,7 +117,7 @@ def run_1d_hybrid_tp(model_name):
         else:
             output_torch = model_torch(data, label)
             loss_torch = output_torch
-        assert torch.allclose(loss, loss_torch, rtol=1e-2)
+        assert torch.allclose(loss, loss_torch, rtol=1e-2), f"model_name {model_name} failed"
         torch.distributed.barrier()

         loss.backward()
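The test change only adds a failure message, which matters because run_1d_hybrid_tp runs per model: without it, a failing allclose gives no hint which model diverged. A self-contained illustration of the pattern (the check_close helper is hypothetical):

    import torch

    def check_close(loss, loss_torch, model_name: str):
        # Attach the failing parameter so a parametrized run names the bad case.
        assert torch.allclose(loss, loss_torch, rtol=1e-2), f"model_name {model_name} failed"

    check_close(torch.tensor(1.000), torch.tensor(1.005), "toy_model")  # within rtol, passes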