mirror of https://github.com/hpcaitech/ColossalAI
MemStatsCollectorStatic (#1765)
parent
327d07c44a
commit
20e255d4e8
|
@ -6,7 +6,7 @@ import torch
|
||||||
|
|
||||||
from colossalai.gemini.chunk import Chunk, ChunkManager
|
from colossalai.gemini.chunk import Chunk, ChunkManager
|
||||||
|
|
||||||
from .memory_tracer.memstats_collector import MemStatsCollectorV2
|
from .memory_tracer.memstats_collector import MemStatsCollectorV2, MemStatsCollectorStatic
|
||||||
from .placement_policy import PlacementPolicyFactory
|
from .placement_policy import PlacementPolicyFactory
|
||||||
|
|
||||||
|
|
||||||
|
@ -26,12 +26,26 @@ class GeminiManager:
|
||||||
chunk_manager (ChunkManager): A ``ChunkManager`` instance.
|
chunk_manager (ChunkManager): A ``ChunkManager`` instance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
|
def __init__(self, placement_policy: str,
|
||||||
|
chunk_manager: ChunkManager,
|
||||||
|
module: Optional[torch.nn.Module] = None,
|
||||||
|
use_static_memstats: bool = False) -> None:
|
||||||
|
|
||||||
assert placement_policy in PlacementPolicyFactory.get_polocy_names()
|
assert placement_policy in PlacementPolicyFactory.get_polocy_names()
|
||||||
self.policy_name = placement_policy
|
self.policy_name = placement_policy
|
||||||
policy_cls = PlacementPolicyFactory.create(placement_policy)
|
policy_cls = PlacementPolicyFactory.create(placement_policy)
|
||||||
self._chunk_manager = chunk_manager
|
self._chunk_manager = chunk_manager
|
||||||
self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None
|
# self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None
|
||||||
|
self.use_static_memstats = use_static_memstats
|
||||||
|
if policy_cls.need_mem_stats:
|
||||||
|
if use_static_memstats:
|
||||||
|
assert module is not None
|
||||||
|
self._mem_stats_collector = MemStatsCollectorStatic(module, chunk_manager)
|
||||||
|
else:
|
||||||
|
self._mem_stats_collector = MemStatsCollectorV2(chunk_manager)
|
||||||
|
else:
|
||||||
|
self._mem_stats_collector = None
|
||||||
|
|
||||||
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
|
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
|
||||||
self._compute_list: List[Tuple[Chunk, ...]] = []
|
self._compute_list: List[Tuple[Chunk, ...]] = []
|
||||||
self._compute_idx: int = -1
|
self._compute_idx: int = -1
|
||||||
|
@ -43,9 +57,13 @@ class GeminiManager:
|
||||||
self._warmup = True
|
self._warmup = True
|
||||||
self._comp_cuda_demand_time = 0
|
self._comp_cuda_demand_time = 0
|
||||||
|
|
||||||
def pre_iter(self):
|
def pre_iter(self, *args):
|
||||||
if self._mem_stats_collector and self._warmup:
|
if self._mem_stats_collector and self._warmup:
|
||||||
self._mem_stats_collector.start_collection()
|
if self.use_static_memstats:
|
||||||
|
self._mem_stats_collector.init_mem_stats(*args)
|
||||||
|
self._warmup = False
|
||||||
|
else:
|
||||||
|
self._mem_stats_collector.start_collection()
|
||||||
|
|
||||||
def post_iter(self):
|
def post_iter(self):
|
||||||
"""This function must be called when each iteration finishes
|
"""This function must be called when each iteration finishes
|
||||||
|
|
|
@ -5,8 +5,16 @@ from colossalai.gemini.stateful_tensor import StatefulTensor
|
||||||
from colossalai.gemini.chunk import ChunkManager
|
from colossalai.gemini.chunk import ChunkManager
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
import time
|
import time
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||||
|
from colossalai.fx.profiler import (calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size)
|
||||||
|
from torch.fx import symbolic_trace
|
||||||
|
|
||||||
|
if is_compatible_with_meta():
|
||||||
|
from colossalai.fx.profiler import MetaTensor
|
||||||
|
|
||||||
|
|
||||||
class MemStatsCollector:
|
class MemStatsCollector:
|
||||||
|
@ -150,3 +158,101 @@ class MemStatsCollectorV2(MemStatsCollector):
|
||||||
@property
|
@property
|
||||||
def cuda_margin_mem(self) -> float:
|
def cuda_margin_mem(self) -> float:
|
||||||
return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda'))
|
return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda'))
|
||||||
|
|
||||||
|
|
||||||
|
class MemStatsCollectorStatic(MemStatsCollectorV2):
|
||||||
|
"""
|
||||||
|
A Static Memory statistic collector.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, module: nn.Module, chunk_manager: ChunkManager) -> None:
|
||||||
|
super().__init__(chunk_manager)
|
||||||
|
self.module = module
|
||||||
|
self.module_info_list = []
|
||||||
|
|
||||||
|
|
||||||
|
def init_mem_stats(self, *inputs):
|
||||||
|
|
||||||
|
self.register_opnodes_recursively(self.module)
|
||||||
|
self.refactor_module()
|
||||||
|
|
||||||
|
self.module = self.module.cpu()
|
||||||
|
self.module.train()
|
||||||
|
|
||||||
|
data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs]
|
||||||
|
gm = symbolic_trace(self.module)
|
||||||
|
interp = MetaInfoProp(gm)
|
||||||
|
interp.propagate(*data)
|
||||||
|
|
||||||
|
total_mem = 0
|
||||||
|
for inp in inputs:
|
||||||
|
total_mem += inp.numel() * inp.element_size()
|
||||||
|
last_node = None
|
||||||
|
module_name_list = [mInfo.module_full_name for mInfo in self.module_info_list]
|
||||||
|
for node in gm.graph.nodes:
|
||||||
|
total_mem = total_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node)
|
||||||
|
if node.op == "call_module":
|
||||||
|
if node.name.endswith("_0") and node.name[:-2] in module_name_list:
|
||||||
|
self._non_model_data_cuda_list.append(total_mem)
|
||||||
|
last_node = node
|
||||||
|
self._non_model_data_cuda_list.append(total_mem)
|
||||||
|
self._non_model_data_cuda_list = self._non_model_data_cuda_list[1:]
|
||||||
|
|
||||||
|
cur_module_mem_fwd = 0
|
||||||
|
cur_module_mem_bwd = 0
|
||||||
|
grad_module_out = last_node.meta["fwd_mem_out"]
|
||||||
|
for node in gm.graph.nodes.__reversed__():
|
||||||
|
cur_module_mem_fwd = cur_module_mem_fwd + calculate_fwd_tmp(node) + calculate_fwd_out(node)
|
||||||
|
cur_module_mem_bwd = cur_module_mem_bwd + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
|
||||||
|
if node.op == "call_module":
|
||||||
|
if node.name.endswith("_0") and node.name[:-2] in module_name_list:
|
||||||
|
self._non_model_data_cuda_list.append(total_mem + grad_module_out + cur_module_mem_bwd)
|
||||||
|
total_mem = total_mem - cur_module_mem_fwd
|
||||||
|
cur_module_mem_fwd = 0
|
||||||
|
cur_module_mem_bwd = 0
|
||||||
|
grad_module_out = node.meta["bwd_mem_out"]
|
||||||
|
|
||||||
|
self._step_total = len(self._non_model_data_cuda_list)
|
||||||
|
self.recover_module()
|
||||||
|
|
||||||
|
|
||||||
|
def refactor_module(self):
|
||||||
|
for modInfo in self.module_info_list:
|
||||||
|
temp_node = nn.Sequential(nn.ReLU(), modInfo.module)
|
||||||
|
modInfo.parent_module.__setattr__(modInfo.module_name, temp_node)
|
||||||
|
|
||||||
|
|
||||||
|
def recover_module(self):
|
||||||
|
for modInfo in self.module_info_list:
|
||||||
|
modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module)
|
||||||
|
|
||||||
|
|
||||||
|
def register_opnodes_recursively(self,
|
||||||
|
module: torch.nn.Module,
|
||||||
|
name: str = "",
|
||||||
|
full_name: str = "",
|
||||||
|
parent_module: Optional[torch.nn.Module] = None):
|
||||||
|
|
||||||
|
assert isinstance(module, torch.nn.Module)
|
||||||
|
|
||||||
|
for child_name, child in module.named_children():
|
||||||
|
self.register_opnodes_recursively(child, child_name, full_name + "_" + child_name, module)
|
||||||
|
|
||||||
|
# Early return on modules with no parameters.
|
||||||
|
if len(list(module.parameters(recurse=False))) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.module_info_list.append(ModuleInfos(module, name, full_name[1:], parent_module))
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleInfos:
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
module: torch.nn.Module,
|
||||||
|
module_name: str,
|
||||||
|
module_full_name: str,
|
||||||
|
parent_module: torch.nn.Module):
|
||||||
|
self.module = module
|
||||||
|
self.module_name = module_name
|
||||||
|
self.module_full_name = module_full_name
|
||||||
|
self.parent_module = parent_module
|
|
@ -267,7 +267,7 @@ class ZeroDDP(ColoDDP):
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
|
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
|
||||||
self.module.zero_grad(set_to_none=True)
|
self.module.zero_grad(set_to_none=True)
|
||||||
self.gemini_manager.pre_iter()
|
self.gemini_manager.pre_iter(*args)
|
||||||
with ParamOpHookManager.use_hooks(self.param_op_hook):
|
with ParamOpHookManager.use_hooks(self.param_op_hook):
|
||||||
outputs = self.module(*args, **kwargs)
|
outputs = self.module(*args, **kwargs)
|
||||||
if self.force_outputs_fp32:
|
if self.force_outputs_fp32:
|
||||||
|
|
|
@ -13,7 +13,7 @@ from colossalai.zero.utils import ZeroHook
|
||||||
from colossalai.gemini.paramhooks import BaseParamHookMgr
|
from colossalai.gemini.paramhooks import BaseParamHookMgr
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.utils import get_current_device, disposable
|
from colossalai.utils import get_current_device, disposable
|
||||||
from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector
|
from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector, MemStatsCollectorStatic
|
||||||
from colossalai.utils.memory import colo_device_memory_capacity
|
from colossalai.utils.memory import colo_device_memory_capacity
|
||||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||||
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
|
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
|
||||||
|
@ -77,6 +77,7 @@ class ShardedModelV2(nn.Module):
|
||||||
tensor_placement_policy: str = 'cuda',
|
tensor_placement_policy: str = 'cuda',
|
||||||
gradient_predivide_factor: Optional[float] = 1.0,
|
gradient_predivide_factor: Optional[float] = 1.0,
|
||||||
reuse_fp16_shard: bool = False,
|
reuse_fp16_shard: bool = False,
|
||||||
|
user_static_memstats: bool = False,
|
||||||
*args,
|
*args,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
|
assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
|
||||||
|
@ -110,10 +111,14 @@ class ShardedModelV2(nn.Module):
|
||||||
self.world_size = dist.get_world_size(self.process_group)
|
self.world_size = dist.get_world_size(self.process_group)
|
||||||
self.rank = dist.get_rank(self.process_group)
|
self.rank = dist.get_rank(self.process_group)
|
||||||
self.shard_strategy = shard_strategy
|
self.shard_strategy = shard_strategy
|
||||||
|
self.user_static_memstats = user_static_memstats
|
||||||
|
|
||||||
self._use_memory_tracer = tensor_placement_policy == 'auto'
|
self._use_memory_tracer = tensor_placement_policy == 'auto'
|
||||||
if self._use_memory_tracer:
|
if self._use_memory_tracer:
|
||||||
self._memstats_collector = MemStatsCollector()
|
if self.user_static_memstats:
|
||||||
|
self._memstats_collector = MemStatsCollectorStatic(self.module)
|
||||||
|
else:
|
||||||
|
self._memstats_collector = MemStatsCollector()
|
||||||
self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
|
self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
|
||||||
self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
|
self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
|
||||||
else:
|
else:
|
||||||
|
@ -206,9 +211,11 @@ class ShardedModelV2(nn.Module):
|
||||||
f.write(str(self._memstats_collector.non_model_data_list('cpu', 'GB')))
|
f.write(str(self._memstats_collector.non_model_data_list('cpu', 'GB')))
|
||||||
f.write('\n')
|
f.write('\n')
|
||||||
|
|
||||||
def _pre_forward_operations(self):
|
def _pre_forward_operations(self, *args):
|
||||||
# the operation will affect the memory tracer behavior in ZeroHook
|
# the operation will affect the memory tracer behavior in ZeroHook
|
||||||
if self._memstats_collector:
|
if self._memstats_collector:
|
||||||
|
if self.user_static_memstats:
|
||||||
|
self.init_mem_stats(*args)
|
||||||
self._start_collect_memstats()
|
self._start_collect_memstats()
|
||||||
|
|
||||||
for p in self.module.parameters():
|
for p in self.module.parameters():
|
||||||
|
@ -223,7 +230,7 @@ class ShardedModelV2(nn.Module):
|
||||||
p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
|
p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
|
||||||
|
|
||||||
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||||
self._pre_forward_operations()
|
self._pre_forward_operations(*args)
|
||||||
args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
|
args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
|
||||||
outputs = self.module(*args, **kwargs)
|
outputs = self.module(*args, **kwargs)
|
||||||
self._post_forward_operations()
|
self._post_forward_operations()
|
||||||
|
|
Loading…
Reference in New Issue