mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] polish memstats collector (#1962)
parent fea3cb661c
commit c4739a725a
@@ -6,7 +6,7 @@ import torch

 from colossalai.gemini.chunk import Chunk, ChunkManager
-from .memory_tracer.memstats_collector import MemStatsCollectorV2, MemStatsCollectorStatic
+from .memory_tracer import ChunkMemStatsCollector, StaticMemStatsCollector
 from .placement_policy import PlacementPolicyFactory

@@ -26,7 +26,8 @@ class GeminiManager:
         chunk_manager (ChunkManager): A ``ChunkManager`` instance.
     """

-    def __init__(self, placement_policy: str,
+    def __init__(self,
+                 placement_policy: str,
                  chunk_manager: ChunkManager,
                  module: Optional[torch.nn.Module] = None,
                  use_static_memstats: bool = False) -> None:
@@ -35,14 +36,14 @@ class GeminiManager:
         self.policy_name = placement_policy
         policy_cls = PlacementPolicyFactory.create(placement_policy)
         self._chunk_manager = chunk_manager
-        # self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None
+        # self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager) if policy_cls.need_mem_stats else None
         self.use_static_memstats = use_static_memstats
         if policy_cls.need_mem_stats:
             if use_static_memstats:
                 assert module is not None
-                self._mem_stats_collector = MemStatsCollectorStatic(module, chunk_manager)
+                self._mem_stats_collector = StaticMemStatsCollector(module, chunk_manager)
             else:
-                self._mem_stats_collector = MemStatsCollectorV2(chunk_manager)
+                self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager)
         else:
             self._mem_stats_collector = None
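A minimal usage sketch (not part of this patch) of the polished constructor above. The `build_gemini_manager` helper and the `from colossalai.gemini import GeminiManager` import path are assumptions for illustration; only the keyword arguments come from the hunk above.

import torch.nn as nn

from colossalai.gemini import GeminiManager    # assumed re-export of the class edited above
from colossalai.gemini.chunk import ChunkManager


def build_gemini_manager(model: nn.Module, chunk_manager: ChunkManager) -> GeminiManager:
    # 'auto' placement needs memory stats; passing the module together with
    # use_static_memstats=True selects the StaticMemStatsCollector branch above.
    return GeminiManager('auto', chunk_manager, module=model, use_static_memstats=True)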
@@ -1,5 +1,10 @@
-from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor
-from .memstats_collector import MemStatsCollector
+from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor    # isort:skip
+from .memstats_collector import MemStatsCollector    # isort:skip
+from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER    # isort:skip
+from .chunk_memstats_collector import ChunkMemStatsCollector    # isort:skip
+from .static_memstats_collector import StaticMemStatsCollector    # isort:skip

-__all__ = ['AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER']
+__all__ = [
+    'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
+    'StaticMemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER'
+]
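With the re-exports above in place, downstream code can pull both collectors from the tracer package in one line. A short illustration (not part of the patch), using only names listed in `__all__`:

from colossalai.gemini.memory_tracer import ChunkMemStatsCollector, StaticMemStatsCollector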
@@ -0,0 +1,25 @@
+from colossalai.gemini.chunk import ChunkManager
+from colossalai.utils import get_current_device
+from colossalai.utils.memory import colo_device_memory_capacity
+
+from .memstats_collector import MemStatsCollector
+
+
+class ChunkMemStatsCollector(MemStatsCollector):
+
+    def __init__(self, chunk_manager: ChunkManager) -> None:
+        super().__init__()
+        self._chunk_manager = chunk_manager
+
+    def sample_model_data(self) -> None:
+        """Sampling model data statistics.
+        """
+        if self._start_flag:
+            cuda_mem = self._chunk_manager.total_mem['cuda']
+            cpu_mem = self._chunk_manager.total_mem['cpu']
+            self._model_data_cuda_list.append(cuda_mem)
+            self._model_data_cpu_list.append(cpu_mem)
+
+    @property
+    def cuda_margin_mem(self) -> float:
+        return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda'))
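A hedged usage sketch of the new collector, not part of the patch: `profile_warmup_step` and `run_one_step` are hypothetical helpers, and `finish_collection` is assumed to come from the `MemStatsCollector` base class; `start_collection` and `sample_model_data` appear elsewhere in this patch.

from colossalai.gemini.chunk import ChunkManager
from colossalai.gemini.memory_tracer import ChunkMemStatsCollector


def profile_warmup_step(chunk_manager: ChunkManager, run_one_step) -> ChunkMemStatsCollector:
    # Collection phase: model data is read straight from the chunk manager,
    # so one sample call per warmup step is enough.
    collector = ChunkMemStatsCollector(chunk_manager)
    collector.start_collection()
    run_one_step()                  # forward/backward of the warmup iteration
    collector.sample_model_data()   # appends the 'cuda' and 'cpu' chunk totals
    collector.finish_collection()   # assumed base-class API
    return collector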
@@ -1,26 +1,17 @@
-from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
-from colossalai.utils.memory import colo_device_memory_used, colo_device_memory_capacity
-from colossalai.utils import get_current_device
-from colossalai.gemini.stateful_tensor import StatefulTensor
-from colossalai.gemini.chunk import ChunkManager
+import time
+from typing import List

 import torch
-import torch.nn as nn
-import time
-from typing import List, Optional
-
-from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.fx.profiler import (calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size)
-from torch.fx import symbolic_trace
-
-if is_compatible_with_meta():
-    from colossalai.fx.profiler import MetaTensor
+
+from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
+from colossalai.gemini.stateful_tensor import StatefulTensor
+from colossalai.utils.memory import colo_device_memory_used


 class MemStatsCollector:
     """
     A Memory statistic collector.
     It works in two phases.
     Phase 1. Collection Phase: collect memory usage statistics of CPU and GPU.
                 The first iteration of DNN training.
     Phase 2. Runtime Phase: use the read-only collected stats
@@ -138,121 +129,3 @@ class MemStatsCollector:
         self._start_flag = False
         self._step_idx = 0
         self._step_total = 0
-
-
-class MemStatsCollectorV2(MemStatsCollector):
-
-    def __init__(self, chunk_manager: ChunkManager) -> None:
-        super().__init__()
-        self._chunk_manager = chunk_manager
-
-    def sample_model_data(self) -> None:
-        """Sampling model data statistics.
-        """
-        if self._start_flag:
-            cuda_mem = self._chunk_manager.total_mem['cuda']
-            cpu_mem = self._chunk_manager.total_mem['cpu']
-            self._model_data_cuda_list.append(cuda_mem)
-            self._model_data_cpu_list.append(cpu_mem)
-
-    @property
-    def cuda_margin_mem(self) -> float:
-        return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda'))
-
-
-class MemStatsCollectorStatic(MemStatsCollectorV2):
-    """
-    A Static Memory statistic collector.
-    """
-
-    def __init__(self, module: nn.Module, chunk_manager: ChunkManager) -> None:
-        super().__init__(chunk_manager)
-        self.module = module
-        self.module_info_list = []
-
-
-    def init_mem_stats(self, *inputs):
-
-        self.register_opnodes_recursively(self.module)
-        self.refactor_module()
-
-        self.module = self.module.cpu()
-        self.module.train()
-
-        data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs]
-        gm = symbolic_trace(self.module)
-        interp = MetaInfoProp(gm)
-        interp.propagate(*data)
-
-        total_mem = 0
-        for inp in inputs:
-            total_mem += inp.numel() * inp.element_size()
-        last_node = None
-        module_name_list = [mInfo.module_full_name for mInfo in self.module_info_list]
-        for node in gm.graph.nodes:
-            total_mem = total_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node)
-            if node.op == "call_module":
-                if node.name.endswith("_0") and node.name[:-2] in module_name_list:
-                    self._non_model_data_cuda_list.append(total_mem)
-            last_node = node
-        self._non_model_data_cuda_list.append(total_mem)
-        self._non_model_data_cuda_list = self._non_model_data_cuda_list[1:]
-
-        cur_module_mem_fwd = 0
-        cur_module_mem_bwd = 0
-        grad_module_out = last_node.meta["fwd_mem_out"]
-        for node in gm.graph.nodes.__reversed__():
-            cur_module_mem_fwd = cur_module_mem_fwd + calculate_fwd_tmp(node) + calculate_fwd_out(node)
-            cur_module_mem_bwd = cur_module_mem_bwd + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
-            if node.op == "call_module":
-                if node.name.endswith("_0") and node.name[:-2] in module_name_list:
-                    self._non_model_data_cuda_list.append(total_mem + grad_module_out + cur_module_mem_bwd)
-                    total_mem = total_mem - cur_module_mem_fwd
-                    cur_module_mem_fwd = 0
-                    cur_module_mem_bwd = 0
-                    grad_module_out = node.meta["bwd_mem_out"]
-
-        self._step_total = len(self._non_model_data_cuda_list)
-        self.recover_module()
-
-
-    def refactor_module(self):
-        for modInfo in self.module_info_list:
-            temp_node = nn.Sequential(nn.ReLU(), modInfo.module)
-            modInfo.parent_module.__setattr__(modInfo.module_name, temp_node)
-
-
-    def recover_module(self):
-        for modInfo in self.module_info_list:
-            modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module)
-
-
-    def register_opnodes_recursively(self,
-                                     module: torch.nn.Module,
-                                     name: str = "",
-                                     full_name: str = "",
-                                     parent_module: Optional[torch.nn.Module] = None):
-
-        assert isinstance(module, torch.nn.Module)
-
-        for child_name, child in module.named_children():
-            self.register_opnodes_recursively(child, child_name, full_name + "_" + child_name, module)
-
-        # Early return on modules with no parameters.
-        if len(list(module.parameters(recurse=False))) == 0:
-            return
-
-        self.module_info_list.append(ModuleInfos(module, name, full_name[1:], parent_module))
-
-
-class ModuleInfos:
-
-    def __init__(self,
-                 module: torch.nn.Module,
-                 module_name: str,
-                 module_full_name: str,
-                 parent_module: torch.nn.Module):
-        self.module = module
-        self.module_name = module_name
-        self.module_full_name = module_full_name
-        self.parent_module = parent_module
@@ -0,0 +1,105 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch.fx import symbolic_trace
+
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta
+from colossalai.gemini.chunk import ChunkManager
+
+if is_compatible_with_meta():
+    from colossalai.fx.profiler import MetaTensor
+
+from .chunk_memstats_collector import ChunkMemStatsCollector
+
+
+class ModuleInfos:
+
+    def __init__(self, module: torch.nn.Module, module_name: str, module_full_name: str,
+                 parent_module: torch.nn.Module):
+        self.module = module
+        self.module_name = module_name
+        self.module_full_name = module_full_name
+        self.parent_module = parent_module
+
+
+class StaticMemStatsCollector(ChunkMemStatsCollector):
+    """
+    A Static Memory statistic collector.
+    """
+
+    def __init__(self, module: nn.Module, chunk_manager: ChunkManager) -> None:
+        super().__init__(chunk_manager)
+        self.module = module
+        self.module_info_list = []
+
+    def init_mem_stats(self, *inputs):
+
+        self.register_opnodes_recursively(self.module)
+        self.refactor_module()
+
+        self.module = self.module.cpu()
+        self.module.train()
+
+        data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs]
+        gm = symbolic_trace(self.module)
+        interp = MetaInfoProp(gm)
+        interp.propagate(*data)
+
+        total_mem = 0
+        for inp in inputs:
+            total_mem += inp.numel() * inp.element_size()
+        last_node = None
+        module_name_list = [mInfo.module_full_name for mInfo in self.module_info_list]
+        for node in gm.graph.nodes:
+            total_mem = total_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node)
+            if node.op == "call_module":
+                if node.name.endswith("_0") and node.name[:-2] in module_name_list:
+                    self._non_model_data_cuda_list.append(total_mem)
+            last_node = node
+        self._non_model_data_cuda_list.append(total_mem)
+        self._non_model_data_cuda_list = self._non_model_data_cuda_list[1:]
+
+        cur_module_mem_fwd = 0
+        cur_module_mem_bwd = 0
+        grad_module_out = last_node.meta["fwd_mem_out"]
+        for node in gm.graph.nodes.__reversed__():
+            cur_module_mem_fwd = cur_module_mem_fwd + calculate_fwd_tmp(node) + calculate_fwd_out(node)
+            cur_module_mem_bwd = cur_module_mem_bwd + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
+            if node.op == "call_module":
+                if node.name.endswith("_0") and node.name[:-2] in module_name_list:
+                    self._non_model_data_cuda_list.append(total_mem + grad_module_out + cur_module_mem_bwd)
+                    total_mem = total_mem - cur_module_mem_fwd
+                    cur_module_mem_fwd = 0
+                    cur_module_mem_bwd = 0
+                    grad_module_out = node.meta["bwd_mem_out"]
+
+        self._step_total = len(self._non_model_data_cuda_list)
+        self.recover_module()
+
+    def refactor_module(self):
+        for modInfo in self.module_info_list:
+            temp_node = nn.Sequential(nn.ReLU(), modInfo.module)
+            modInfo.parent_module.__setattr__(modInfo.module_name, temp_node)
+
+    def recover_module(self):
+        for modInfo in self.module_info_list:
+            modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module)
+
+    def register_opnodes_recursively(self,
+                                     module: torch.nn.Module,
+                                     name: str = "",
+                                     full_name: str = "",
+                                     parent_module: Optional[torch.nn.Module] = None):
+
+        assert isinstance(module, torch.nn.Module)
+
+        for child_name, child in module.named_children():
+            self.register_opnodes_recursively(child, child_name, full_name + "_" + child_name, module)
+
+        # Early return on modules with no parameters.
+        if len(list(module.parameters(recurse=False))) == 0:
+            return
+
+        self.module_info_list.append(ModuleInfos(module, name, full_name[1:], parent_module))
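As a side note, the wrap-and-restore trick that refactor_module and recover_module rely on can be shown in plain PyTorch. The sketch below is illustrative only; `wrap_parameterized_children`, `unwrap`, and the `Toy` model are hypothetical names, not part of the patch.

import torch.nn as nn


def wrap_parameterized_children(parent: nn.Module, registry: list) -> None:
    # Mirrors register_opnodes_recursively + refactor_module: every submodule
    # that directly owns parameters is temporarily wrapped so that symbolic
    # tracing produces a distinct call_module boundary for it.
    for name, child in list(parent.named_children()):
        wrap_parameterized_children(child, registry)
        if len(list(child.parameters(recurse=False))) == 0:
            continue
        registry.append((parent, name, child))
        setattr(parent, name, nn.Sequential(nn.ReLU(), child))


def unwrap(registry: list) -> None:
    # Mirrors recover_module: put the original submodules back in place.
    for parent, name, child in registry:
        setattr(parent, name, child)


class Toy(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc2(self.fc1(x))


model = Toy()
registry = []
wrap_parameterized_children(model, registry)
# ... symbolic_trace / MetaInfoProp would run on the wrapped model here ...
unwrap(registry)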
@@ -1,22 +1,24 @@
 import functools
 from abc import ABC, abstractmethod
 from time import time
-from typing import List, Optional, Tuple, Dict
+from typing import Dict, List, Optional, Tuple, Type

 import torch

+from colossalai.gemini.chunk import Chunk, ChunkManager
+from colossalai.gemini.memory_tracer import ChunkMemStatsCollector
 from colossalai.utils import get_current_device
 from colossalai.utils.memory import colo_device_memory_capacity

-from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollectorV2
-from typing import Type
-import functools
-from colossalai.gemini.chunk import Chunk, ChunkManager


 class PlacementPolicy(ABC):
     need_mem_stats: bool = False

-    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
+    def __init__(self,
+                 chunk_manager: ChunkManager,
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
         self.chunk_manager = chunk_manager
-        self.mem_stats_collector: Optional[MemStatsCollectorV2] = mem_stats_collector
+        self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector

     @abstractmethod
     def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
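To illustrate the constructor pattern every concrete policy below now follows, here is a hedged sketch of a do-nothing policy. It is not in the patch: `NoOpPlacementPolicy` is hypothetical and the import path of `PlacementPolicy` is assumed; the signatures themselves come from the hunk above.

from typing import List, Optional, Tuple

from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.gemini.memory_tracer import ChunkMemStatsCollector
from colossalai.gemini.placement_policy import PlacementPolicy    # assumed module path


class NoOpPlacementPolicy(PlacementPolicy):
    """Illustrative policy that never evicts anything."""

    def __init__(self,
                 chunk_manager: ChunkManager,
                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

    def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
        # (bytes freed, seconds spent) -- nothing is moved in this sketch.
        return 0, 0.0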
@@ -29,7 +31,9 @@ class PlacementPolicy(ABC):

 class CPUPlacementPolicy(PlacementPolicy):

-    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
+    def __init__(self,
+                 chunk_manager: ChunkManager,
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
         super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

     def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
@@ -44,7 +48,9 @@ class CPUPlacementPolicy(PlacementPolicy):

 class CUDAPlacementPolicy(PlacementPolicy):

-    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
+    def __init__(self,
+                 chunk_manager: ChunkManager,
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
         assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
         super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
@@ -65,7 +71,9 @@ class AutoPlacementPolicy(PlacementPolicy):
     _warmup_non_model_data_ratio: float = 0.8
     _steady_cuda_cap_ratio: float = 0.9

-    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
+    def __init__(self,
+                 chunk_manager: ChunkManager,
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
         super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

     def evict_tensors(self,
@@ -154,7 +162,9 @@ class ConstPlacementPolicy(PlacementPolicy):
     need_mem_stats: bool = False
     _accessed_memory_boundary = 512 * 1024**2

-    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
+    def __init__(self,
+                 chunk_manager: ChunkManager,
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
         super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

     def evict_tensors(self,
@@ -1,31 +1,39 @@
 import functools
-from collections import OrderedDict
-from typing import Any, Optional, Iterator, Tuple
-from copy import deepcopy
 import itertools
+from collections import OrderedDict
+from copy import deepcopy
+from typing import Any, Iterator, Optional, Tuple

 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.distributed import ProcessGroup
+from torch.nn.parameter import Parameter

 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.gemini.memory_tracer import MemStatsCollector, StaticMemStatsCollector
 from colossalai.gemini.ophooks import register_ophooks_recursively
-from colossalai.zero.utils import ZeroHook
 from colossalai.gemini.paramhooks import BaseParamHookMgr
+from colossalai.gemini.stateful_tensor import TensorState
+from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
+from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory
+from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device, disposable
-from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector, MemStatsCollectorStatic
+from colossalai.utils import disposable, get_current_device
 from colossalai.utils.memory import colo_device_memory_capacity
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
-from torch.distributed import ProcessGroup
-from torch.nn.parameter import Parameter
-from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu
-from colossalai.gemini.stateful_tensor import TensorState
-from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
-from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy
+from colossalai.zero.utils import ZeroHook

-from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
-                     get_gradient_predivide_factor)
+from ._utils import (
+    cast_float_arguments,
+    cast_tensor_to_fp16,
+    cast_tensor_to_fp32,
+    chunk_and_pad,
+    free_storage,
+    get_gradient_predivide_factor,
+)

 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
@@ -49,7 +57,7 @@ class ShardedModelV2(nn.Module):
         module (nn.Module): A sharded module, which must be initialized by `ZeroInitContext`.
         shard_strategy (BaseShardStrategy): A shard strategy to manage shard behavior.
         process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None.
-        reduce_scatter_process_group (Optional[ProcessGroup], optional): Reduce-scatter process group.
+        reduce_scatter_process_group (Optional[ProcessGroup], optional): Reduce-scatter process group.
             Generally, it should be `None`, and it's the same as `process_group`. Defaults to None.
         reduce_scatter_bucket_size_mb (int, optional): Reduce-scatter bucket size in *MB*. Defaults to 25.
         fp32_reduce_scatter (bool, optional): If set to `True`, gradients are forced to FP32 before reduce-scatter. Defaults to False.
@@ -60,10 +68,10 @@ class ShardedModelV2(nn.Module):
             Note that 'auto' policy can only work well when no other processes use CUDA during your training.
             Defaults to 'cuda'.
         gradient_predivide_factor (Optional[float], optional): Gradient is divived by this value before reduce-scatter. Defaults to 1.0.
-        reuse_fp16_shard (bool, optional): Whether to reuse fp16 shard for param and grad.
-            Enabling this can reduce GPU memory usage, but you have to make sure you disable it when using gradient accumulation.
-            In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad).
-            We find that PyTorch's optimizers don't support mixed precision,
+        reuse_fp16_shard (bool, optional): Whether to reuse fp16 shard for param and grad.
+            Enabling this can reduce GPU memory usage, but you have to make sure you disable it when using gradient accumulation.
+            In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad).
+            We find that PyTorch's optimizers don't support mixed precision,
             so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False.
     """
@@ -116,7 +124,7 @@ class ShardedModelV2(nn.Module):
         self._use_memory_tracer = tensor_placement_policy == 'auto'
         if self._use_memory_tracer:
             if self.user_static_memstats:
-                self._memstats_collector = MemStatsCollectorStatic(self.module)
+                self._memstats_collector = StaticMemStatsCollector(self.module)
             else:
                 self._memstats_collector = MemStatsCollector()
             self._start_collect_memstats = disposable(self._memstats_collector.start_collection)