mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] remove static tracer (#2083)
parent 28ef3f29af
commit 1f99205827

```diff
@@ -26,27 +26,13 @@ class GeminiManager:
         chunk_manager (ChunkManager): A ``ChunkManager`` instance.
     """
 
-    def __init__(self,
-                 placement_policy: str,
-                 chunk_manager: ChunkManager,
-                 module: Optional[torch.nn.Module] = None,
-                 use_static_memstats: bool = False) -> None:
+    def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
 
         assert placement_policy in PlacementPolicyFactory.get_polocy_names()
         self.policy_name = placement_policy
         policy_cls = PlacementPolicyFactory.create(placement_policy)
         self._chunk_manager = chunk_manager
-        # self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager) if policy_cls.need_mem_stats else None
-        self.use_static_memstats = use_static_memstats
-        if policy_cls.need_mem_stats:
-            if use_static_memstats:
-                assert module is not None
-                self._mem_stats_collector = StaticMemStatsCollector(module, chunk_manager)
-            else:
-                self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager)
-        else:
-            self._mem_stats_collector = None
-
+        self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager) if policy_cls.need_mem_stats else None
         self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
         self._compute_list: List[Tuple[Chunk, ...]] = []
         self._compute_idx: int = -1
```
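
This hunk drops the static-tracer parameters (`module`, `use_static_memstats`) and the `StaticMemStatsCollector` branch from `GeminiManager.__init__`, restoring the one-line conditional: a `ChunkMemStatsCollector` is created only when the placement policy needs memory statistics. A minimal sketch of that construction pattern, using hypothetical stand-in classes rather than the real ColossalAI types:

```python
# Stand-ins for ChunkMemStatsCollector and a placement policy; both are
# hypothetical, only the conditional-construction pattern mirrors the patch.
from typing import Optional


class StubCollector:

    def __init__(self, chunk_manager) -> None:
        self.chunk_manager = chunk_manager


class StubPolicy:
    # Policies that move chunks at runtime report need_mem_stats = True.
    need_mem_stats = True


def build_collector(policy_cls, chunk_manager) -> Optional[StubCollector]:
    # One expression replaces the removed static/runtime if-else ladder.
    return StubCollector(chunk_manager) if policy_cls.need_mem_stats else None


assert build_collector(StubPolicy, chunk_manager="stub") is not None
```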

```diff
@@ -60,11 +46,7 @@ class GeminiManager:
 
     def pre_iter(self, *args):
         if self._mem_stats_collector and self._warmup:
-            if self.use_static_memstats:
-                self._mem_stats_collector.init_mem_stats(*args)
-                self._warmup = False
-            else:
-                self._mem_stats_collector.start_collection()
+            self._mem_stats_collector.start_collection()
 
     def post_iter(self):
         """This function must be called when each iteration finishes
```
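
With the static path gone, `pre_iter` no longer calls `init_mem_stats` or clears `_warmup` itself; it simply starts runtime collection while the manager is warming up. A hedged sketch of that warmup protocol (the collector and manager here are stand-ins, not the real `ChunkMemStatsCollector` and `GeminiManager`):

```python
# Hypothetical stand-in collector; only the start of a warmup collection
# window is modeled, which is all pre_iter now relies on.
class StubMemStatsCollector:

    def __init__(self) -> None:
        self.collecting = False

    def start_collection(self) -> None:
        self.collecting = True


class StubGeminiManager:

    def __init__(self, collector) -> None:
        self._mem_stats_collector = collector
        self._warmup = True

    def pre_iter(self, *args) -> None:
        # Single code path after the patch: collect real statistics
        # during warmup iterations, nothing else.
        if self._mem_stats_collector and self._warmup:
            self._mem_stats_collector.start_collection()


manager = StubGeminiManager(StubMemStatsCollector())
manager.pre_iter()
assert manager._mem_stats_collector.collecting
```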

```diff
@@ -9,6 +9,16 @@ __all__ = ['RuntimeMemTracer']
 
 
 class RuntimeMemTracer():
+    """RuntimeMemTracer for module training using ColoParameter.
+
+    Traces non-model memory usage during the fwd+bwd process.
+    Statistics are obtained by feeding dummy tensors shaped like the real
+    training inputs and running a single fwd+bwd pass.
+
+    NOTE:
+    1. The premise for using this tracer is that the target DNN executes the same operations at each iteration.
+    2. Module buffers are viewed as non-model data.
+    """
 
     def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
         super().__init__()
```
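
The new docstring describes the tracing recipe: dummy inputs shaped like a real batch, one fwd+bwd pass, non-model data recorded along the way. The snippet below illustrates that idea with plain PyTorch forward hooks; it is a conceptual sketch of the recipe, not the `RuntimeMemTracer` API:

```python
import torch
import torch.nn as nn

# A toy module and a dummy batch with the same shape as real training inputs.
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 8))
dummy = torch.randn(4, 32)

# Record activation sizes (non-model data) via forward hooks.
activation_bytes = []
hooks = [
    layer.register_forward_hook(
        lambda mod, inp, out: activation_bytes.append(out.numel() * out.element_size()))
    for layer in model
]

# A single fwd+bwd pass, as the docstring prescribes.
model(dummy).sum().backward()
for hook in hooks:
    hook.remove()

print(f"traced activation (non-model) data: {sum(activation_bytes)} bytes")
```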

```diff
@@ -50,5 +50,5 @@ class GeminiDDP(ZeroDDP):
                                            hidden_dim=hidden_dim,
                                            search_range_mb=search_range_mb,
                                            min_chunk_size_mb=min_chunk_size_mb)
-        gemini_manager = GeminiManager(placement_policy, chunk_manager, module)
+        gemini_manager = GeminiManager(placement_policy, chunk_manager)
         super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32)
```
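
`GeminiDDP` is updated to match: it now constructs the manager with only the two remaining arguments, and the old three-argument call would raise a `TypeError`. A stand-in illustration (not the real class; `"cuda"` as a policy name is an assumption):

```python
# Hypothetical stub mirroring the narrowed two-argument signature.
class StubGeminiManager:

    def __init__(self, placement_policy: str, chunk_manager) -> None:
        self.policy_name = placement_policy
        self._chunk_manager = chunk_manager


StubGeminiManager("cuda", chunk_manager=object())    # new call shape

try:
    StubGeminiManager("cuda", object(), object())     # old three-argument call
except TypeError as exc:
    print(f"stale call site fails fast: {exc}")
```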

```diff
@@ -117,7 +117,7 @@ def run_1d_hybrid_tp(model_name):
         else:
             output_torch = model_torch(data, label)
             loss_torch = output_torch
-        assert torch.allclose(loss, loss_torch, rtol=1e-2)
+        assert torch.allclose(loss, loss_torch, rtol=1e-2), f"model_name {model_name} failed"
         torch.distributed.barrier()
 
         loss.backward()
```
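
The test change only attaches a message to the existing tolerance check, so a failure reports which model diverged rather than a bare `AssertionError`. A tiny runnable illustration (the model names are hypothetical):

```python
import torch

for model_name in ("bert", "simple_net"):   # hypothetical test models
    loss = torch.tensor(1.000)
    loss_torch = torch.tensor(1.001)
    # Within rtol=1e-2, so both pass; on failure the f-string names the model.
    assert torch.allclose(loss, loss_torch, rtol=1e-2), f"model_name {model_name} failed"
```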