MemStatsCollectorStatic (#1765)

pull/1805/head
Zihao 2022-11-07 16:49:03 +08:00 committed by GitHub
parent 327d07c44a
commit 20e255d4e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 142 additions and 11 deletions

View File

@ -6,7 +6,7 @@ import torch
from colossalai.gemini.chunk import Chunk, ChunkManager
from .memory_tracer.memstats_collector import MemStatsCollectorV2
from .memory_tracer.memstats_collector import MemStatsCollectorV2, MemStatsCollectorStatic
from .placement_policy import PlacementPolicyFactory
@ -26,12 +26,26 @@ class GeminiManager:
chunk_manager (ChunkManager): A ``ChunkManager`` instance.
"""
def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
def __init__(self, placement_policy: str,
chunk_manager: ChunkManager,
module: Optional[torch.nn.Module] = None,
use_static_memstats: bool = False) -> None:
assert placement_policy in PlacementPolicyFactory.get_polocy_names()
self.policy_name = placement_policy
policy_cls = PlacementPolicyFactory.create(placement_policy)
self._chunk_manager = chunk_manager
self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None
# self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None
self.use_static_memstats = use_static_memstats
if policy_cls.need_mem_stats:
if use_static_memstats:
assert module is not None
self._mem_stats_collector = MemStatsCollectorStatic(module, chunk_manager)
else:
self._mem_stats_collector = MemStatsCollectorV2(chunk_manager)
else:
self._mem_stats_collector = None
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
self._compute_list: List[Tuple[Chunk, ...]] = []
self._compute_idx: int = -1
@ -43,9 +57,13 @@ class GeminiManager:
self._warmup = True
self._comp_cuda_demand_time = 0
def pre_iter(self):
def pre_iter(self, *args):
if self._mem_stats_collector and self._warmup:
self._mem_stats_collector.start_collection()
if self.use_static_memstats:
self._mem_stats_collector.init_mem_stats(*args)
self._warmup = False
else:
self._mem_stats_collector.start_collection()
def post_iter(self):
"""This function must be called when each iteration finishes

View File

@ -5,8 +5,16 @@ from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.gemini.chunk import ChunkManager
import torch
import torch.nn as nn
import time
from typing import List
from typing import List, Optional
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import (calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size)
from torch.fx import symbolic_trace
if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor
class MemStatsCollector:
@ -150,3 +158,101 @@ class MemStatsCollectorV2(MemStatsCollector):
@property
def cuda_margin_mem(self) -> float:
return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda'))
class MemStatsCollectorStatic(MemStatsCollectorV2):
"""
A Static Memory statistic collector.
"""
def __init__(self, module: nn.Module, chunk_manager: ChunkManager) -> None:
super().__init__(chunk_manager)
self.module = module
self.module_info_list = []
def init_mem_stats(self, *inputs):
self.register_opnodes_recursively(self.module)
self.refactor_module()
self.module = self.module.cpu()
self.module.train()
data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs]
gm = symbolic_trace(self.module)
interp = MetaInfoProp(gm)
interp.propagate(*data)
total_mem = 0
for inp in inputs:
total_mem += inp.numel() * inp.element_size()
last_node = None
module_name_list = [mInfo.module_full_name for mInfo in self.module_info_list]
for node in gm.graph.nodes:
total_mem = total_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node)
if node.op == "call_module":
if node.name.endswith("_0") and node.name[:-2] in module_name_list:
self._non_model_data_cuda_list.append(total_mem)
last_node = node
self._non_model_data_cuda_list.append(total_mem)
self._non_model_data_cuda_list = self._non_model_data_cuda_list[1:]
cur_module_mem_fwd = 0
cur_module_mem_bwd = 0
grad_module_out = last_node.meta["fwd_mem_out"]
for node in gm.graph.nodes.__reversed__():
cur_module_mem_fwd = cur_module_mem_fwd + calculate_fwd_tmp(node) + calculate_fwd_out(node)
cur_module_mem_bwd = cur_module_mem_bwd + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
if node.op == "call_module":
if node.name.endswith("_0") and node.name[:-2] in module_name_list:
self._non_model_data_cuda_list.append(total_mem + grad_module_out + cur_module_mem_bwd)
total_mem = total_mem - cur_module_mem_fwd
cur_module_mem_fwd = 0
cur_module_mem_bwd = 0
grad_module_out = node.meta["bwd_mem_out"]
self._step_total = len(self._non_model_data_cuda_list)
self.recover_module()
def refactor_module(self):
for modInfo in self.module_info_list:
temp_node = nn.Sequential(nn.ReLU(), modInfo.module)
modInfo.parent_module.__setattr__(modInfo.module_name, temp_node)
def recover_module(self):
for modInfo in self.module_info_list:
modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module)
def register_opnodes_recursively(self,
module: torch.nn.Module,
name: str = "",
full_name: str = "",
parent_module: Optional[torch.nn.Module] = None):
assert isinstance(module, torch.nn.Module)
for child_name, child in module.named_children():
self.register_opnodes_recursively(child, child_name, full_name + "_" + child_name, module)
# Early return on modules with no parameters.
if len(list(module.parameters(recurse=False))) == 0:
return
self.module_info_list.append(ModuleInfos(module, name, full_name[1:], parent_module))
class ModuleInfos:
def __init__(self,
module: torch.nn.Module,
module_name: str,
module_full_name: str,
parent_module: torch.nn.Module):
self.module = module
self.module_name = module_name
self.module_full_name = module_full_name
self.parent_module = parent_module

View File

@ -267,7 +267,7 @@ class ZeroDDP(ColoDDP):
def forward(self, *args, **kwargs):
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
self.module.zero_grad(set_to_none=True)
self.gemini_manager.pre_iter()
self.gemini_manager.pre_iter(*args)
with ParamOpHookManager.use_hooks(self.param_op_hook):
outputs = self.module(*args, **kwargs)
if self.force_outputs_fp32:

View File

@ -13,7 +13,7 @@ from colossalai.zero.utils import ZeroHook
from colossalai.gemini.paramhooks import BaseParamHookMgr
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device, disposable
from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector, MemStatsCollectorStatic
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
@ -77,6 +77,7 @@ class ShardedModelV2(nn.Module):
tensor_placement_policy: str = 'cuda',
gradient_predivide_factor: Optional[float] = 1.0,
reuse_fp16_shard: bool = False,
user_static_memstats: bool = False,
*args,
**kwargs):
assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
@ -110,10 +111,14 @@ class ShardedModelV2(nn.Module):
self.world_size = dist.get_world_size(self.process_group)
self.rank = dist.get_rank(self.process_group)
self.shard_strategy = shard_strategy
self.user_static_memstats = user_static_memstats
self._use_memory_tracer = tensor_placement_policy == 'auto'
if self._use_memory_tracer:
self._memstats_collector = MemStatsCollector()
if self.user_static_memstats:
self._memstats_collector = MemStatsCollectorStatic(self.module)
else:
self._memstats_collector = MemStatsCollector()
self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
else:
@ -206,9 +211,11 @@ class ShardedModelV2(nn.Module):
f.write(str(self._memstats_collector.non_model_data_list('cpu', 'GB')))
f.write('\n')
def _pre_forward_operations(self):
def _pre_forward_operations(self, *args):
# the operation will affect the memory tracer behavior in ZeroHook
if self._memstats_collector:
if self.user_static_memstats:
self.init_mem_stats(*args)
self._start_collect_memstats()
for p in self.module.parameters():
@ -223,7 +230,7 @@ class ShardedModelV2(nn.Module):
p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
self._pre_forward_operations()
self._pre_forward_operations(*args)
args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
outputs = self.module(*args, **kwargs)
self._post_forward_operations()