[Gemini] polish memstats collector (#1962)

pull/1964/head
Jiarui Fang 2022-11-16 15:45:57 +08:00 committed by GitHub
parent fea3cb661c
commit c4739a725a
7 changed files with 201 additions and 174 deletions

View File

@@ -6,7 +6,7 @@ import torch
 from colossalai.gemini.chunk import Chunk, ChunkManager
-from .memory_tracer.memstats_collector import MemStatsCollectorV2, MemStatsCollectorStatic
+from .memory_tracer import ChunkMemStatsCollector, StaticMemStatsCollector
 from .placement_policy import PlacementPolicyFactory
@@ -26,7 +26,8 @@ class GeminiManager:
         chunk_manager (ChunkManager): A ``ChunkManager`` instance.
     """

-    def __init__(self, placement_policy: str,
+    def __init__(self,
+                 placement_policy: str,
                  chunk_manager: ChunkManager,
                  module: Optional[torch.nn.Module] = None,
                  use_static_memstats: bool = False) -> None:
@@ -35,14 +36,14 @@ class GeminiManager:
         self.policy_name = placement_policy
         policy_cls = PlacementPolicyFactory.create(placement_policy)
         self._chunk_manager = chunk_manager
-        # self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None
+        # self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager) if policy_cls.need_mem_stats else None
         self.use_static_memstats = use_static_memstats
         if policy_cls.need_mem_stats:
             if use_static_memstats:
                 assert module is not None
-                self._mem_stats_collector = MemStatsCollectorStatic(module, chunk_manager)
+                self._mem_stats_collector = StaticMemStatsCollector(module, chunk_manager)
             else:
-                self._mem_stats_collector = MemStatsCollectorV2(chunk_manager)
+                self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager)
         else:
             self._mem_stats_collector = None
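For reference, a minimal sketch of how the updated constructor would be called; `chunk_manager` and `model` are assumed to be built elsewhere, and the import path is an assumption rather than part of this commit.

from colossalai.gemini import GeminiManager    # assumed import path

# Runtime collection: an 'auto' policy builds a ChunkMemStatsCollector internally.
manager = GeminiManager('auto', chunk_manager)

# Static (tracing-based) collection: the module must be passed up front so a
# StaticMemStatsCollector can be constructed from it (see the assert above).
manager = GeminiManager('auto', chunk_manager, module=model, use_static_memstats=True)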

View File

@@ -1,5 +1,10 @@
-from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor
-from .memstats_collector import MemStatsCollector
+from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor    # isort:skip
+from .memstats_collector import MemStatsCollector    # isort:skip
+from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER    # isort:skip
+from .chunk_memstats_collector import ChunkMemStatsCollector    # isort:skip
+from .static_memstats_collector import StaticMemStatsCollector    # isort:skip

-__all__ = ['AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER']
+__all__ = [
+    'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
+    'StaticMemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER'
+]
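With the reorganised package, all collectors are importable from the package root, matching the new `__all__` above:

from colossalai.gemini.memory_tracer import (
    AsyncMemoryMonitor,
    ChunkMemStatsCollector,
    MemStatsCollector,
    StaticMemStatsCollector,
    SyncCudaMemoryMonitor,
)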

View File

@@ -0,0 +1,25 @@
+from colossalai.gemini.chunk import ChunkManager
+from colossalai.utils import get_current_device
+from colossalai.utils.memory import colo_device_memory_capacity
+
+from .memstats_collector import MemStatsCollector
+
+
+class ChunkMemStatsCollector(MemStatsCollector):
+
+    def __init__(self, chunk_manager: ChunkManager) -> None:
+        super().__init__()
+        self._chunk_manager = chunk_manager
+
+    def sample_model_data(self) -> None:
+        """Sampling model data statistics.
+        """
+        if self._start_flag:
+            cuda_mem = self._chunk_manager.total_mem['cuda']
+            cpu_mem = self._chunk_manager.total_mem['cpu']
+            self._model_data_cuda_list.append(cuda_mem)
+            self._model_data_cpu_list.append(cpu_mem)
+
+    @property
+    def cuda_margin_mem(self) -> float:
+        return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda'))
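A toy illustration (made-up numbers, not library code) of the headroom that `cuda_margin_mem` reports: device capacity minus the peak overall CUDA usage sampled during collection.

capacity_bytes = 16 * 1024**3         # colo_device_memory_capacity(get_current_device())
peak_overall_cuda = 11 * 1024**3      # max(self.overall_mem_stats('cuda'))
cuda_margin_mem = capacity_bytes - peak_overall_cuda    # 5 GiB left on the device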

View File

@@ -1,20 +1,11 @@
-from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
-from colossalai.utils.memory import colo_device_memory_used, colo_device_memory_capacity
-from colossalai.utils import get_current_device
-from colossalai.gemini.stateful_tensor import StatefulTensor
-from colossalai.gemini.chunk import ChunkManager
+import time
+from typing import List

 import torch
-import torch.nn as nn
-import time
-from typing import List, Optional

-from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.fx.profiler import (calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size)
-from torch.fx import symbolic_trace
-
-if is_compatible_with_meta():
-    from colossalai.fx.profiler import MetaTensor
+from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
+from colossalai.gemini.stateful_tensor import StatefulTensor
+from colossalai.utils.memory import colo_device_memory_used


 class MemStatsCollector:
@@ -138,121 +129,3 @@ class MemStatsCollector:
         self._start_flag = False
         self._step_idx = 0
         self._step_total = 0
-
-
-class MemStatsCollectorV2(MemStatsCollector):
-
-    def __init__(self, chunk_manager: ChunkManager) -> None:
-        super().__init__()
-        self._chunk_manager = chunk_manager
-
-    def sample_model_data(self) -> None:
-        """Sampling model data statistics.
-        """
-        if self._start_flag:
-            cuda_mem = self._chunk_manager.total_mem['cuda']
-            cpu_mem = self._chunk_manager.total_mem['cpu']
-            self._model_data_cuda_list.append(cuda_mem)
-            self._model_data_cpu_list.append(cpu_mem)
-
-    @property
-    def cuda_margin_mem(self) -> float:
-        return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda'))
-
-
-class MemStatsCollectorStatic(MemStatsCollectorV2):
-    """
-    A Static Memory statistic collector.
-    """
-
-    def __init__(self, module: nn.Module, chunk_manager: ChunkManager) -> None:
-        super().__init__(chunk_manager)
-        self.module = module
-        self.module_info_list = []
-
-    def init_mem_stats(self, *inputs):
-
-        self.register_opnodes_recursively(self.module)
-        self.refactor_module()
-
-        self.module = self.module.cpu()
-        self.module.train()
-
-        data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs]
-        gm = symbolic_trace(self.module)
-        interp = MetaInfoProp(gm)
-        interp.propagate(*data)
-
-        total_mem = 0
-        for inp in inputs:
-            total_mem += inp.numel() * inp.element_size()
-
-        last_node = None
-        module_name_list = [mInfo.module_full_name for mInfo in self.module_info_list]
-        for node in gm.graph.nodes:
-            total_mem = total_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node)
-            if node.op == "call_module":
-                if node.name.endswith("_0") and node.name[:-2] in module_name_list:
-                    self._non_model_data_cuda_list.append(total_mem)
-                last_node = node
-        self._non_model_data_cuda_list.append(total_mem)
-        self._non_model_data_cuda_list = self._non_model_data_cuda_list[1:]
-
-        cur_module_mem_fwd = 0
-        cur_module_mem_bwd = 0
-        grad_module_out = last_node.meta["fwd_mem_out"]
-        for node in gm.graph.nodes.__reversed__():
-            cur_module_mem_fwd = cur_module_mem_fwd + calculate_fwd_tmp(node) + calculate_fwd_out(node)
-            cur_module_mem_bwd = cur_module_mem_bwd + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
-            if node.op == "call_module":
-                if node.name.endswith("_0") and node.name[:-2] in module_name_list:
-                    self._non_model_data_cuda_list.append(total_mem + grad_module_out + cur_module_mem_bwd)
-                    total_mem = total_mem - cur_module_mem_fwd
-                    cur_module_mem_fwd = 0
-                    cur_module_mem_bwd = 0
-                    grad_module_out = node.meta["bwd_mem_out"]
-
-        self._step_total = len(self._non_model_data_cuda_list)
-        self.recover_module()
-
-    def refactor_module(self):
-        for modInfo in self.module_info_list:
-            temp_node = nn.Sequential(nn.ReLU(), modInfo.module)
-            modInfo.parent_module.__setattr__(modInfo.module_name, temp_node)
-
-    def recover_module(self):
-        for modInfo in self.module_info_list:
-            modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module)
-
-    def register_opnodes_recursively(self,
-                                     module: torch.nn.Module,
-                                     name: str = "",
-                                     full_name: str = "",
-                                     parent_module: Optional[torch.nn.Module] = None):
-
-        assert isinstance(module, torch.nn.Module)
-
-        for child_name, child in module.named_children():
-            self.register_opnodes_recursively(child, child_name, full_name + "_" + child_name, module)
-
-        # Early return on modules with no parameters.
-        if len(list(module.parameters(recurse=False))) == 0:
-            return
-
-        self.module_info_list.append(ModuleInfos(module, name, full_name[1:], parent_module))
-
-
-class ModuleInfos:
-
-    def __init__(self,
-                 module: torch.nn.Module,
-                 module_name: str,
-                 module_full_name: str,
-                 parent_module: torch.nn.Module):
-        self.module = module
-        self.module_name = module_name
-        self.module_full_name = module_full_name
-        self.parent_module = parent_module

View File

@@ -0,0 +1,105 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch.fx import symbolic_trace
+
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta
+from colossalai.gemini.chunk import ChunkManager
+
+if is_compatible_with_meta():
+    from colossalai.fx.profiler import MetaTensor
+
+from .chunk_memstats_collector import ChunkMemStatsCollector
+
+
+class ModuleInfos:
+
+    def __init__(self, module: torch.nn.Module, module_name: str, module_full_name: str,
+                 parent_module: torch.nn.Module):
+        self.module = module
+        self.module_name = module_name
+        self.module_full_name = module_full_name
+        self.parent_module = parent_module
+
+
+class StaticMemStatsCollector(ChunkMemStatsCollector):
+    """
+    A Static Memory statistic collector.
+    """
+
+    def __init__(self, module: nn.Module, chunk_manager: ChunkManager) -> None:
+        super().__init__(chunk_manager)
+        self.module = module
+        self.module_info_list = []
+
+    def init_mem_stats(self, *inputs):
+
+        self.register_opnodes_recursively(self.module)
+        self.refactor_module()
+
+        self.module = self.module.cpu()
+        self.module.train()
+
+        data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs]
+        gm = symbolic_trace(self.module)
+        interp = MetaInfoProp(gm)
+        interp.propagate(*data)
+
+        total_mem = 0
+        for inp in inputs:
+            total_mem += inp.numel() * inp.element_size()
+
+        last_node = None
+        module_name_list = [mInfo.module_full_name for mInfo in self.module_info_list]
+        for node in gm.graph.nodes:
+            total_mem = total_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node)
+            if node.op == "call_module":
+                if node.name.endswith("_0") and node.name[:-2] in module_name_list:
+                    self._non_model_data_cuda_list.append(total_mem)
+                last_node = node
+        self._non_model_data_cuda_list.append(total_mem)
+        self._non_model_data_cuda_list = self._non_model_data_cuda_list[1:]
+
+        cur_module_mem_fwd = 0
+        cur_module_mem_bwd = 0
+        grad_module_out = last_node.meta["fwd_mem_out"]
+        for node in gm.graph.nodes.__reversed__():
+            cur_module_mem_fwd = cur_module_mem_fwd + calculate_fwd_tmp(node) + calculate_fwd_out(node)
+            cur_module_mem_bwd = cur_module_mem_bwd + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
+            if node.op == "call_module":
+                if node.name.endswith("_0") and node.name[:-2] in module_name_list:
+                    self._non_model_data_cuda_list.append(total_mem + grad_module_out + cur_module_mem_bwd)
+                    total_mem = total_mem - cur_module_mem_fwd
+                    cur_module_mem_fwd = 0
+                    cur_module_mem_bwd = 0
+                    grad_module_out = node.meta["bwd_mem_out"]
+
+        self._step_total = len(self._non_model_data_cuda_list)
+        self.recover_module()
+
+    def refactor_module(self):
+        for modInfo in self.module_info_list:
+            temp_node = nn.Sequential(nn.ReLU(), modInfo.module)
+            modInfo.parent_module.__setattr__(modInfo.module_name, temp_node)
+
+    def recover_module(self):
+        for modInfo in self.module_info_list:
+            modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module)
+
+    def register_opnodes_recursively(self,
+                                     module: torch.nn.Module,
+                                     name: str = "",
+                                     full_name: str = "",
+                                     parent_module: Optional[torch.nn.Module] = None):
+
+        assert isinstance(module, torch.nn.Module)
+
+        for child_name, child in module.named_children():
+            self.register_opnodes_recursively(child, child_name, full_name + "_" + child_name, module)
+
+        # Early return on modules with no parameters.
+        if len(list(module.parameters(recurse=False))) == 0:
+            return
+
+        self.module_info_list.append(ModuleInfos(module, name, full_name[1:], parent_module))
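The wrap-and-restore trick behind `refactor_module`/`recover_module` can be shown in isolation: each parameterised submodule is temporarily replaced by `nn.Sequential(nn.ReLU(), module)` so that its call site appears as a distinct `call_module` node (named like `fc1_0`) in the traced graph, and the original module is put back afterwards. A self-contained sketch on a toy model, written as an assumption about the mechanism rather than a copy of this file:

import torch.nn as nn

class Toy(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 8)
        self.fc2 = nn.Linear(8, 4)

    def forward(self, x):
        return self.fc2(self.fc1(x))

model = Toy()
wrapped = []
for name, child in list(model.named_children()):
    if len(list(child.parameters(recurse=False))) > 0:
        setattr(model, name, nn.Sequential(nn.ReLU(), child))    # mark the call site
        wrapped.append((name, child))

# torch.fx.symbolic_trace(model) would now emit call_module nodes such as fc1_0 and fc2_0,
# which is what init_mem_stats() above matches via node.name.endswith("_0").

for name, child in wrapped:
    setattr(model, name, child)    # restore the original modules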

View File

@@ -1,22 +1,24 @@
+import functools
 from abc import ABC, abstractmethod
 from time import time
-from typing import List, Optional, Tuple, Dict
+from typing import Dict, List, Optional, Tuple, Type

 import torch

+from colossalai.gemini.chunk import Chunk, ChunkManager
+from colossalai.gemini.memory_tracer import ChunkMemStatsCollector
 from colossalai.utils import get_current_device
 from colossalai.utils.memory import colo_device_memory_capacity
-from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollectorV2
-from typing import Type
-import functools
-from colossalai.gemini.chunk import Chunk, ChunkManager


 class PlacementPolicy(ABC):
     need_mem_stats: bool = False

-    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
+    def __init__(self,
+                 chunk_manager: ChunkManager,
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
         self.chunk_manager = chunk_manager
-        self.mem_stats_collector: Optional[MemStatsCollectorV2] = mem_stats_collector
+        self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector

     @abstractmethod
     def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
@@ -29,7 +31,9 @@ class PlacementPolicy(ABC):

 class CPUPlacementPolicy(PlacementPolicy):

-    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
+    def __init__(self,
+                 chunk_manager: ChunkManager,
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
         super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

     def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
@@ -44,7 +48,9 @@ class CPUPlacementPolicy(PlacementPolicy):

 class CUDAPlacementPolicy(PlacementPolicy):

-    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
+    def __init__(self,
+                 chunk_manager: ChunkManager,
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
         assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
         super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
@@ -65,7 +71,9 @@ class AutoPlacementPolicy(PlacementPolicy):
     _warmup_non_model_data_ratio: float = 0.8
     _steady_cuda_cap_ratio: float = 0.9

-    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
+    def __init__(self,
+                 chunk_manager: ChunkManager,
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
         super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

     def evict_tensors(self,
@@ -154,7 +162,9 @@ class ConstPlacementPolicy(PlacementPolicy):
     need_mem_stats: bool = False
     _accessed_memory_boundary = 512 * 1024**2

-    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
+    def __init__(self,
+                 chunk_manager: ChunkManager,
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
         super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

     def evict_tensors(self,
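A hedged sketch of how a policy gets wired to the renamed collector, mirroring what GeminiManager does in the first hunk of this commit; `chunk_manager` is assumed to be created elsewhere, and the factory import path is inferred from the relative import shown there.

from colossalai.gemini.memory_tracer import ChunkMemStatsCollector
from colossalai.gemini.placement_policy import PlacementPolicyFactory    # assumed path

policy_cls = PlacementPolicyFactory.create('auto')
collector = ChunkMemStatsCollector(chunk_manager) if policy_cls.need_mem_stats else None
policy = policy_cls(chunk_manager, mem_stats_collector=collector)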

View File

@@ -1,31 +1,39 @@
 import functools
-from collections import OrderedDict
-from typing import Any, Optional, Iterator, Tuple
-from copy import deepcopy
 import itertools
+from collections import OrderedDict
+from copy import deepcopy
+from typing import Any, Iterator, Optional, Tuple

 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.distributed import ProcessGroup
+from torch.nn.parameter import Parameter

 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.gemini.memory_tracer import MemStatsCollector, StaticMemStatsCollector
 from colossalai.gemini.ophooks import register_ophooks_recursively
-from colossalai.zero.utils import ZeroHook
 from colossalai.gemini.paramhooks import BaseParamHookMgr
+from colossalai.gemini.stateful_tensor import TensorState
+from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
+from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory
+from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device, disposable
-from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector, MemStatsCollectorStatic
+from colossalai.utils import disposable, get_current_device
 from colossalai.utils.memory import colo_device_memory_capacity
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
-from torch.distributed import ProcessGroup
-from torch.nn.parameter import Parameter
-from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu
-from colossalai.gemini.stateful_tensor import TensorState
-from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
-from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy
+from colossalai.zero.utils import ZeroHook

-from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
-                     get_gradient_predivide_factor)
+from ._utils import (
+    cast_float_arguments,
+    cast_tensor_to_fp16,
+    cast_tensor_to_fp32,
+    chunk_and_pad,
+    free_storage,
+    get_gradient_predivide_factor,
+)

 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
@@ -116,7 +124,7 @@ class ShardedModelV2(nn.Module):
         self._use_memory_tracer = tensor_placement_policy == 'auto'
         if self._use_memory_tracer:
             if self.user_static_memstats:
-                self._memstats_collector = MemStatsCollectorStatic(self.module)
+                self._memstats_collector = StaticMemStatsCollector(self.module)
             else:
                 self._memstats_collector = MemStatsCollector()
             self._start_collect_memstats = disposable(self._memstats_collector.start_collection)