MemStatsCollectorStatic (#1765)

pull/1805/head
Zihao 2022-11-07 16:49:03 +08:00 committed by GitHub
parent 327d07c44a
commit 20e255d4e8
4 changed files with 142 additions and 11 deletions

@@ -6,7 +6,7 @@ import torch
from colossalai.gemini.chunk import Chunk, ChunkManager
-from .memory_tracer.memstats_collector import MemStatsCollectorV2
+from .memory_tracer.memstats_collector import MemStatsCollectorV2, MemStatsCollectorStatic
from .placement_policy import PlacementPolicyFactory
@@ -26,12 +26,26 @@ class GeminiManager:
        chunk_manager (ChunkManager): A ``ChunkManager`` instance.
    """

-    def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
+    def __init__(self, placement_policy: str,
+                 chunk_manager: ChunkManager,
+                 module: Optional[torch.nn.Module] = None,
+                 use_static_memstats: bool = False) -> None:
        assert placement_policy in PlacementPolicyFactory.get_polocy_names()
        self.policy_name = placement_policy
        policy_cls = PlacementPolicyFactory.create(placement_policy)
        self._chunk_manager = chunk_manager
-        self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None
+        # self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None
+        self.use_static_memstats = use_static_memstats
+        if policy_cls.need_mem_stats:
+            if use_static_memstats:
+                assert module is not None
+                self._mem_stats_collector = MemStatsCollectorStatic(module, chunk_manager)
+            else:
+                self._mem_stats_collector = MemStatsCollectorV2(chunk_manager)
+        else:
+            self._mem_stats_collector = None
        self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
        self._compute_list: List[Tuple[Chunk, ...]] = []
        self._compute_idx: int = -1
@@ -43,9 +57,13 @@ class GeminiManager:
        self._warmup = True
        self._comp_cuda_demand_time = 0

-    def pre_iter(self):
+    def pre_iter(self, *args):
        if self._mem_stats_collector and self._warmup:
-            self._mem_stats_collector.start_collection()
+            if self.use_static_memstats:
+                self._mem_stats_collector.init_mem_stats(*args)
+                self._warmup = False
+            else:
+                self._mem_stats_collector.start_collection()

    def post_iter(self):
        """This function must be called when each iteration finishes

@@ -5,8 +5,16 @@ from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.gemini.chunk import ChunkManager
import torch
+import torch.nn as nn
import time
-from typing import List
+from typing import List, Optional
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.fx.profiler import (calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size)
+from torch.fx import symbolic_trace
+
+if is_compatible_with_meta():
+    from colossalai.fx.profiler import MetaTensor


class MemStatsCollector:
@@ -150,3 +158,101 @@ class MemStatsCollectorV2(MemStatsCollector):
    @property
    def cuda_margin_mem(self) -> float:
        return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda'))
+
+
+class MemStatsCollectorStatic(MemStatsCollectorV2):
+    """
+    A Static Memory statistic collector.
+    """
+
+    def __init__(self, module: nn.Module, chunk_manager: ChunkManager) -> None:
+        super().__init__(chunk_manager)
+        self.module = module
+        self.module_info_list = []
+
+    def init_mem_stats(self, *inputs):
+        self.register_opnodes_recursively(self.module)
+        self.refactor_module()
+
+        self.module = self.module.cpu()
+        self.module.train()
+
+        data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs]
+        gm = symbolic_trace(self.module)
+        interp = MetaInfoProp(gm)
+        interp.propagate(*data)
+
+        total_mem = 0
+        for inp in inputs:
+            total_mem += inp.numel() * inp.element_size()
+
+        last_node = None
+        module_name_list = [mInfo.module_full_name for mInfo in self.module_info_list]
+        for node in gm.graph.nodes:
+            total_mem = total_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node)
+            if node.op == "call_module":
+                if node.name.endswith("_0") and node.name[:-2] in module_name_list:
+                    self._non_model_data_cuda_list.append(total_mem)
+            last_node = node
+        self._non_model_data_cuda_list.append(total_mem)
+        self._non_model_data_cuda_list = self._non_model_data_cuda_list[1:]
+
+        cur_module_mem_fwd = 0
+        cur_module_mem_bwd = 0
+        grad_module_out = last_node.meta["fwd_mem_out"]
+        for node in gm.graph.nodes.__reversed__():
+            cur_module_mem_fwd = cur_module_mem_fwd + calculate_fwd_tmp(node) + calculate_fwd_out(node)
+            cur_module_mem_bwd = cur_module_mem_bwd + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
+            if node.op == "call_module":
+                if node.name.endswith("_0") and node.name[:-2] in module_name_list:
+                    self._non_model_data_cuda_list.append(total_mem + grad_module_out + cur_module_mem_bwd)
+                    total_mem = total_mem - cur_module_mem_fwd
+                    cur_module_mem_fwd = 0
+                    cur_module_mem_bwd = 0
+                    grad_module_out = node.meta["bwd_mem_out"]
+
+        self._step_total = len(self._non_model_data_cuda_list)
+        self.recover_module()
+
+    def refactor_module(self):
+        for modInfo in self.module_info_list:
+            temp_node = nn.Sequential(nn.ReLU(), modInfo.module)
+            modInfo.parent_module.__setattr__(modInfo.module_name, temp_node)
+
+    def recover_module(self):
+        for modInfo in self.module_info_list:
+            modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module)
+
+    def register_opnodes_recursively(self,
+                                     module: torch.nn.Module,
+                                     name: str = "",
+                                     full_name: str = "",
+                                     parent_module: Optional[torch.nn.Module] = None):
+        assert isinstance(module, torch.nn.Module)
+
+        for child_name, child in module.named_children():
+            self.register_opnodes_recursively(child, child_name, full_name + "_" + child_name, module)
+
+        # Early return on modules with no parameters.
+        if len(list(module.parameters(recurse=False))) == 0:
+            return
+
+        self.module_info_list.append(ModuleInfos(module, name, full_name[1:], parent_module))
+
+
+class ModuleInfos:
+
+    def __init__(self,
+                 module: torch.nn.Module,
+                 module_name: str,
+                 module_full_name: str,
+                 parent_module: torch.nn.Module):
+        self.module = module
+        self.module_name = module_name
+        self.module_full_name = module_full_name
+        self.parent_module = parent_module
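
To make the idea behind MemStatsCollectorStatic easier to follow, here is a self-contained, deliberately simplified sketch of the same strategy using plain torch.fx: trace the module once, propagate shapes without touching GPU memory, and accumulate per-node activation sizes instead of measuring memory at runtime. It uses torch.fx's ShapeProp rather than ColossalAI's MetaInfoProp/calculate_fwd_tmp/calculate_fwd_out, so the numbers are only illustrative and not the collector's actual estimate.

    import torch
    import torch.nn as nn
    from torch.fx import symbolic_trace
    from torch.fx.passes.shape_prop import ShapeProp, TensorMetadata

    def estimate_forward_activation_bytes(model: nn.Module, sample_input: torch.Tensor):
        """Rough running total of forward activation memory per traced node."""
        gm = symbolic_trace(model)              # build an FX graph of the model
        ShapeProp(gm).propagate(sample_input)   # record output shape/dtype in node.meta
        running, per_node = 0, []
        for node in gm.graph.nodes:
            meta = node.meta.get('tensor_meta')
            if isinstance(meta, TensorMetadata):    # skip nodes without a single tensor output
                running += meta.shape.numel() * torch.empty((), dtype=meta.dtype).element_size()
            per_node.append((node.name, running))
        return per_node

    if __name__ == '__main__':
        model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))
        print(estimate_forward_activation_bytes(model, torch.rand(32, 64)))

In the collector above, refactor_module appears to serve a related purpose: each parameterized submodule is temporarily wrapped in nn.Sequential(nn.ReLU(), module) so that tracing produces a recognizable call_module node (name ending in "_0") at each module boundary, which is where the running totals are sampled; recover_module then restores the original attributes.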

@@ -267,7 +267,7 @@ class ZeroDDP(ColoDDP):
    def forward(self, *args, **kwargs):
        args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
        self.module.zero_grad(set_to_none=True)
-        self.gemini_manager.pre_iter()
+        self.gemini_manager.pre_iter(*args)
        with ParamOpHookManager.use_hooks(self.param_op_hook):
            outputs = self.module(*args, **kwargs)
        if self.force_outputs_fp32:
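
The only behavioural change in this file is plumbing: forward() now hands its positional arguments to GeminiManager.pre_iter so that, on the warm-up iteration, the static collector can build meta inputs with the same shapes. A rough user-side sketch follows; the ZeroDDP constructor arguments and the ColoDDP-style backward call are assumptions, not taken from this commit.

    # Illustrative only: assumes ZeroDDP(module, gemini_manager) and a ColoDDP-style
    # backward(loss); adapt to your actual wrapping code.
    zero_model = ZeroDDP(model, gemini_manager)   # gemini_manager built with use_static_memstats=True

    for data, label in train_loader:
        logits = zero_model(data)         # forward() calls gemini_manager.pre_iter(data) internally
        loss = criterion(logits, label)
        zero_model.backward(loss)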

@@ -13,7 +13,7 @@ from colossalai.zero.utils import ZeroHook
from colossalai.gemini.paramhooks import BaseParamHookMgr
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device, disposable
-from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector
+from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector, MemStatsCollectorStatic
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
@@ -77,6 +77,7 @@ class ShardedModelV2(nn.Module):
                 tensor_placement_policy: str = 'cuda',
                 gradient_predivide_factor: Optional[float] = 1.0,
                 reuse_fp16_shard: bool = False,
+                 user_static_memstats: bool = False,
                 *args,
                 **kwargs):
        assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
@@ -110,10 +111,14 @@ class ShardedModelV2(nn.Module):
        self.world_size = dist.get_world_size(self.process_group)
        self.rank = dist.get_rank(self.process_group)
        self.shard_strategy = shard_strategy
+        self.user_static_memstats = user_static_memstats
        self._use_memory_tracer = tensor_placement_policy == 'auto'
        if self._use_memory_tracer:
-            self._memstats_collector = MemStatsCollector()
+            if self.user_static_memstats:
+                self._memstats_collector = MemStatsCollectorStatic(self.module)
+            else:
+                self._memstats_collector = MemStatsCollector()
            self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
            self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
        else:
@@ -206,9 +211,11 @@ class ShardedModelV2(nn.Module):
                f.write(str(self._memstats_collector.non_model_data_list('cpu', 'GB')))
                f.write('\n')

-    def _pre_forward_operations(self):
+    def _pre_forward_operations(self, *args):
        # the operation will affect the memory tracer behavior in ZeroHook
        if self._memstats_collector:
+            if self.user_static_memstats:
+                self.init_mem_stats(*args)
            self._start_collect_memstats()

        for p in self.module.parameters():
@@ -223,7 +230,7 @@ class ShardedModelV2(nn.Module):
                p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)

    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
-        self._pre_forward_operations()
+        self._pre_forward_operations(*args)
        args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
        outputs = self.module(*args, **kwargs)
        self._post_forward_operations()
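
Finally, a hedged sketch of how the new flag would be enabled on the sharded-model path. The TensorShardStrategy import and the positional order of module and shard_strategy are assumptions based on this version of ColossalAI; everything else follows the signature shown above.

    # Sketch only: only arguments relevant to the memory tracer are shown.
    from colossalai.zero.shard_utils import TensorShardStrategy   # assumed import path

    sharded_model = ShardedModelV2(model,
                                   TensorShardStrategy(),
                                   tensor_placement_policy='auto',   # 'auto' turns the memory tracer on
                                   user_static_memstats=True)        # select the static collector
    outputs = sharded_model(data)   # forward() now passes *args through to _pre_forward_operations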