mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] remove static tracer (#2083)
parent 28ef3f29af
commit 1f99205827
@@ -26,27 +26,13 @@ class GeminiManager:
         chunk_manager (ChunkManager): A ``ChunkManager`` instance.
     """

-    def __init__(self,
-                 placement_policy: str,
-                 chunk_manager: ChunkManager,
-                 module: Optional[torch.nn.Module] = None,
-                 use_static_memstats: bool = False) -> None:
-
+    def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
         assert placement_policy in PlacementPolicyFactory.get_polocy_names()
         self.policy_name = placement_policy
         policy_cls = PlacementPolicyFactory.create(placement_policy)
         self._chunk_manager = chunk_manager
-        # self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager) if policy_cls.need_mem_stats else None
-        self.use_static_memstats = use_static_memstats
-        if policy_cls.need_mem_stats:
-            if use_static_memstats:
-                assert module is not None
-                self._mem_stats_collector = StaticMemStatsCollector(module, chunk_manager)
-            else:
-                self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager)
-        else:
-            self._mem_stats_collector = None
-
+        self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager) if policy_cls.need_mem_stats else None
         self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
         self._compute_list: List[Tuple[Chunk, ...]] = []
         self._compute_idx: int = -1
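After this hunk, GeminiManager takes only the placement policy name and the chunk manager; the tracer-related module and use_static_memstats arguments are removed along with the StaticMemStatsCollector branch. A minimal sketch of the new call site follows; the import paths and the "cuda" policy name are assumptions, not shown in this diff:

    # Sketch only: import paths and the "cuda" policy name are assumed, not confirmed by this commit.
    from colossalai.gemini.chunk import ChunkManager        # assumed location
    from colossalai.gemini.gemini_mgr import GeminiManager  # assumed location

    chunk_manager: ChunkManager = ...                       # built elsewhere (see the GeminiDDP hunk below)
    gemini_manager = GeminiManager("cuda", chunk_manager)   # new two-argument signature
    # The old module=... and use_static_memstats=... keyword arguments no longer exist.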
@@ -60,11 +46,7 @@ class GeminiManager:

     def pre_iter(self, *args):
         if self._mem_stats_collector and self._warmup:
-            if self.use_static_memstats:
-                self._mem_stats_collector.init_mem_stats(*args)
-                self._warmup = False
-            else:
-                self._mem_stats_collector.start_collection()
+            self._mem_stats_collector.start_collection()

     def post_iter(self):
         """This function must be called when each iteration finishes
@@ -9,6 +9,16 @@ __all__ = ['RuntimeMemTracer']


 class RuntimeMemTracer():
+    """RuntimeMemTracer for module training using ColoParameter.
+
+    Traces non-model memory usage during the fwd+bwd process.
+    The statistics are obtained by running a single fwd+bwd on inputs with the same shapes
+    as those used in training.
+
+    NOTE()
+    1. The premise for using this tracer is that the target DNN executes the same operations at each iteration.
+    2. Module buffers are viewed as non-model data.
+    """

     def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
         super().__init__()
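The docstring above states the tracer's contract: one fwd+bwd on same-shape inputs, with module buffers counted as non-model data. A construction sketch based only on the __init__ shown in this hunk; the method that launches the traced fwd+bwd is not visible here, so it is left as a hypothetical name:

    import torch

    # my_model is a placeholder for a torch.nn.Module whose parameters are ColoParameter.
    tracer = RuntimeMemTracer(module=my_model, dtype=torch.half)
    # tracer.<run_trace>(sample_inputs)  # hypothetical method name: one fwd+bwd on same-shape inputs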
@@ -50,5 +50,5 @@ class GeminiDDP(ZeroDDP):
                                           hidden_dim=hidden_dim,
                                           search_range_mb=search_range_mb,
                                           min_chunk_size_mb=min_chunk_size_mb)
-        gemini_manager = GeminiManager(placement_policy, chunk_manager, module)
+        gemini_manager = GeminiManager(placement_policy, chunk_manager)
         super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32)
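For GeminiDDP users nothing changes at the call site: only the internal GeminiManager construction drops the module argument. A rough usage sketch; the names below are taken from identifiers visible in this hunk, and whether they are actual keyword parameters of GeminiDDP (and what their defaults are) is an assumption:

    # Sketch only: GeminiDDP's full signature is not shown in this diff.
    model = GeminiDDP(module,                    # the wrapped torch.nn.Module
                      placement_policy="cuda",   # forwarded to GeminiManager
                      pin_memory=True,
                      force_outputs_fp32=False)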
@@ -117,7 +117,7 @@ def run_1d_hybrid_tp(model_name):
         else:
             output_torch = model_torch(data, label)
             loss_torch = output_torch
-        assert torch.allclose(loss, loss_torch, rtol=1e-2)
+        assert torch.allclose(loss, loss_torch, rtol=1e-2), f"model_name {model_name} failed"
         torch.distributed.barrier()

         loss.backward()