2023-03-04 12:08:11 +00:00
|
|
|
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
|
2022-03-01 10:17:01 +00:00
|
|
|
import functools
|
2022-11-16 07:45:57 +00:00
|
|
|
import itertools
|
2022-03-08 10:18:06 +00:00
|
|
|
from collections import OrderedDict
|
2022-05-27 02:25:08 +00:00
|
|
|
from copy import deepcopy
|
2022-11-16 07:45:57 +00:00
|
|
|
from typing import Any, Iterator, Optional, Tuple
|
|
|
|
|
2022-03-01 10:17:01 +00:00
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
|
|
import torch.nn as nn
|
2022-11-16 07:45:57 +00:00
|
|
|
from torch.distributed import ProcessGroup
|
|
|
|
from torch.nn.parameter import Parameter
|
|
|
|
|
2022-03-01 10:17:01 +00:00
|
|
|
from colossalai.context.parallel_mode import ParallelMode
|
|
|
|
from colossalai.core import global_context as gpc
|
|
|
|
from colossalai.logging import get_dist_logger
|
2022-11-16 07:45:57 +00:00
|
|
|
from colossalai.utils import disposable, get_current_device
|
2022-04-11 08:47:57 +00:00
|
|
|
from colossalai.utils.memory import colo_device_memory_capacity
|
2023-04-04 05:48:16 +00:00
|
|
|
from colossalai.zero.gemini.memory_tracer import MemStatsCollector, StaticMemStatsCollector
|
|
|
|
from colossalai.zero.legacy.gemini.ophooks import register_ophooks_recursively
|
|
|
|
from colossalai.zero.legacy.gemini.paramhooks import BaseParamHookMgr
|
|
|
|
from colossalai.zero.legacy.gemini.stateful_tensor import TensorState
|
|
|
|
from colossalai.zero.legacy.gemini.stateful_tensor_mgr import StatefulTensorMgr
|
|
|
|
from colossalai.zero.legacy.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory
|
|
|
|
from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_move_to_cpu
|
|
|
|
from colossalai.zero.legacy.shard_utils import BaseShardStrategy
|
|
|
|
from colossalai.zero.legacy.sharded_model.reduce_scatter import ReduceScatterBucketer
|
2022-03-15 09:07:35 +00:00
|
|
|
|
2022-11-16 07:45:57 +00:00
|
|
|
from ._utils import (
|
|
|
|
cast_float_arguments,
|
2023-06-05 07:58:31 +00:00
|
|
|
cast_tensor_to_bf16,
|
2022-11-16 07:45:57 +00:00
|
|
|
cast_tensor_to_fp16,
|
|
|
|
cast_tensor_to_fp32,
|
|
|
|
chunk_and_pad,
|
|
|
|
free_storage,
|
|
|
|
get_gradient_predivide_factor,
|
|
|
|
)
|
2023-04-04 05:48:16 +00:00
|
|
|
from .zero_hook import ZeroHook
|
2022-03-02 10:28:29 +00:00
|
|
|
|
2022-06-02 05:48:22 +00:00
|
|
|
try:
|
|
|
|
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
|
|
|
|
except ImportError:
|
|
|
|
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
|
|
|
|
|
2022-03-01 10:17:01 +00:00
|
|
|
|
|
|
|
class ShardedModelV2(nn.Module):
    """
    A wrapper for a PyTorch module that shards the model parameters across the memory of multiple GPUs.
    Only `1/#nproc` of the parameters and gradients are stored in local CUDA memory, so forward and backward
    passes can be executed with a limited CUDA memory budget.

    Note:
        You must use ``ShardedModelV2`` with ``ShardedOptimizerV2``.

    Note:
        If you enable ``reuse_fp16_shard``, make sure you don't use gradient accumulation and that your
        optimizer can work with fp16 gradients and fp32 parameters.

    Args:
        module (nn.Module): A sharded module, which must be initialized by `ZeroInitContext`.
        shard_strategy (BaseShardStrategy): A shard strategy to manage shard behavior.
        process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None,
            in which case the global data-parallel group is used.
        reduce_scatter_process_group (Optional[ProcessGroup], optional): Reduce-scatter process group.
            Generally, it should be `None`, and it's the same as `process_group`. Defaults to None.
        reduce_scatter_bucket_size_mb (int, optional): Reduce-scatter bucket size in *MB*. Defaults to 25.
        fp32_reduce_scatter (bool, optional): If set to `True`, gradients are cast to ``param.dtype`` before
            reduce-scatter (intended to force an FP32 reduce-scatter). Defaults to False.
        tensor_placement_policy (str): Which device to place *held* tensors. It can be 'cpu', 'cuda' and 'auto'.
            If it's 'cpu', parameters, gradients and optimizer states will be offloaded to CPU, which means min CUDA memory will be used.
            If it's 'cuda', they won't be offloaded, which means max CUDA memory will be used.
            If it's 'auto', they are moved dynamically based on CPU and CUDA memory usage, using the heterogeneous memory space evenly.
            Note that the 'auto' policy only works well when no other processes use CUDA during your training.
            Defaults to 'cuda'.
        gradient_predivide_factor (Optional[float], optional): Gradient is divided by this value before reduce-scatter.
            If None, a factor is derived automatically from the world size. Defaults to 1.0.
        reuse_fp16_shard (bool, optional): Whether to reuse the fp16 shard for param and grad.
            Enabling this can reduce GPU memory usage, but you have to make sure you disable it when using gradient accumulation.
            In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad).
            We find that PyTorch's optimizers don't support mixed precision,
            so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False.
        bf16 (bool, optional): Whether to use bfloat16 for param and grad. Defaults to False.
        **kwargs: ``warmup_non_model_data_ratio`` is applied to the placement policy when
            ``tensor_placement_policy == 'auto'``; otherwise it is ignored with a warning.
    """

    def __init__(self,
                 module: nn.Module,
                 shard_strategy: BaseShardStrategy,
                 process_group: Optional[ProcessGroup] = None,
                 reduce_scatter_process_group: Optional[ProcessGroup] = None,
                 reduce_scatter_bucket_size_mb: int = 25,
                 fp32_reduce_scatter: bool = False,
                 tensor_placement_policy: str = 'cuda',
                 gradient_predivide_factor: Optional[float] = 1.0,
                 reuse_fp16_shard: bool = False,
                 bf16: bool = False,
                 *args,
                 **kwargs):
        assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
        super().__init__()
        self.logger = get_dist_logger()
        self.bf16 = bf16

        # We force users to use ZeroInitContext: every parameter must carry `colo_attr`, and a
        # single submodule must not mix sharded and unsharded parameters.
        for submodule in module.modules():
            sharded_cnt = 0
            unshard_cnt = 0
            for param in submodule.parameters(recurse=False):
                assert hasattr(param, 'colo_attr'), 'You must use ZeroInitContext to init your module first.'
                if param.colo_attr.param_is_sharded:
                    sharded_cnt += 1
                else:
                    unshard_cnt += 1
            assert (not sharded_cnt) or (not unshard_cnt), 'nn.Module can not both have shard param and unshard param'
            submodule.param_is_sharded = (sharded_cnt > 0)

        # Remember which parameters were sharded at wrap time.
        self.sharded_params = []
        self.unshard_params = []
        for param in module.parameters():
            if param.colo_attr.param_is_sharded:
                self.sharded_params.append(param)
            else:
                self.unshard_params.append(param)

        self.module = module
        self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
        self.reduce_scatter_process_group = reduce_scatter_process_group or self.process_group
        self.world_size = dist.get_world_size(self.process_group)
        self.rank = dist.get_rank(self.process_group)
        self.shard_strategy = shard_strategy

        # The memory tracer is only needed by the 'auto' placement policy.
        self._use_memory_tracer = tensor_placement_policy == 'auto'
        if self._use_memory_tracer:
            self._memstats_collector = MemStatsCollector()
            self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
            self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
        else:
            self._memstats_collector = None
        self._tensor_placement_policy: TensorPlacementPolicy = TensorPlacementPolicyFactory.create(
            tensor_placement_policy)(mem_stats_collector=self._memstats_collector)

        if 'warmup_non_model_data_ratio' in kwargs:
            if tensor_placement_policy != 'auto':
                self.logger.warning('setting warmup_non_model_data_ratio is useless if not use auto placement')
            else:
                ratio = kwargs['warmup_non_model_data_ratio']
                self._tensor_placement_policy._warmup_non_model_data_ratio = ratio
                self.logger.info(f'setting warmup_non_model_data_ratio as {ratio} for auto placement')

        self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy)
        param_tensor_list = [p.colo_attr.sharded_data_tensor for p in module.parameters() if hasattr(p, 'colo_attr')]
        self._stateful_tensor_mgr.register_stateful_tensor_list(param_tensor_list)

        # Register hooks
        self._ophook_list = [
            ZeroHook(self.shard_strategy, self._memstats_collector, self._stateful_tensor_mgr, self.process_group)
        ]
        register_ophooks_recursively(self.module, self._ophook_list)
        self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
        self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)

        self.fp32_reduce_scatter = fp32_reduce_scatter
        self._cpu_offload: bool = tensor_placement_policy != 'cuda'
        for param in module.parameters():
            # Init `offload_grad`
            param.colo_attr.offload_grad = self._cpu_offload

        # We find that if gradient_predivide_factor != 1.0, there may be a precision problem,
        # so we use 1.0 as the default gradient_predivide_factor.
        # However, if you set gradient_predivide_factor to None, we will set
        # gradient_predivide_factor to a value >= 1.0 automatically.
        self.gradient_predivide_factor: float = gradient_predivide_factor if \
            gradient_predivide_factor is not None else \
            get_gradient_predivide_factor(self.world_size)
        self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor

        self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
        self.reducer = ReduceScatterBucketer(reduce_scatter_bucket_size_mb)
        self._require_backward_grad_sync: bool = True

        self._cuda_margin_space = 0
        self.reuse_fp16_shard = reuse_fp16_shard

        # Counts gradients containing inf or nan; incremented in `_save_grad`.
        self.overflow_counter = 0

    def adjust_stateful_tensor_layout(self) -> None:
        """Ask the stateful tensor manager to re-layout tensors according to the placement policy."""
        self._stateful_tensor_mgr.adjust_layout()

    @property
    def use_memory_tracer(self):
        """Whether the memory tracer is active (only for the 'auto' placement policy)."""
        return self._use_memory_tracer

    @property
    def cuda_margin_space(self):
        """CUDA memory left after the peak forward/backward usage, updated after each backward."""
        return self._cuda_margin_space

    @property
    def cpu_offload(self):
        """Whether gradients are offloaded to CPU (any placement policy other than 'cuda')."""
        return self._cpu_offload

    def dump_memory_stats(self, filename: Optional[str] = 'dump_mem_stats.log') -> None:
        """
        Dump the information collected by the memory tracer to a file (global rank 0 only).
        Does nothing unless the memory tracer is enabled. Typical usage::

            try:
                # forward: model(inputs)
                # backward: optimizer.backward()
            except Exception as e:
                model.dump_memory_stats()
                exit(0)
        """
        if self._use_memory_tracer:
            self.logger.error(f'dump memory tracer collected information to a {filename}', ranks=[0])
            if gpc.get_global_rank() == 0:
                with open(filename, 'w+') as f:
                    f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
                    f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
                    f.write('CUDA model data (GB)\n')
                    f.write('\n')
                    f.write('CUDA non model data (GB)\n')
                    f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda')))
                    f.write('CPU non model data (GB)\n')
                    f.write(str(self._memstats_collector._memstats.non_model_data_list('cpu')))
                    f.write('\n')

    def _pre_forward_operations(self, *args):
        """Start memory collection (if enabled), put parameter payloads in HOLD and start an iteration."""
        # the operation will affect the memory tracer behavior in ZeroHook
        if self._memstats_collector:
            self._start_collect_memstats()

        for p in self.module.parameters():
            if hasattr(p, 'colo_attr'):
                p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)

        self._stateful_tensor_mgr.start_iter()

    def _post_forward_operations(self):
        """Return every parameter payload to the HOLD state after forward."""
        for p in self.module.parameters():
            if hasattr(p, 'colo_attr'):
                p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)

    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Cast floating-point inputs to fp16 (or bf16) and run the wrapped module."""
        self._pre_forward_operations(*args)
        cast_fn = cast_tensor_to_bf16 if self.bf16 else cast_tensor_to_fp16
        args, kwargs = cast_float_arguments(cast_fn, *args, **kwargs)
        outputs = self.module(*args, **kwargs)
        self._post_forward_operations()
        return outputs

    def backward(self, loss):
        """Run backward from ``loss`` and finish the iteration (grad reduction, re-sharding, hooks)."""
        loss.backward()
        self._post_backward_operations()
        for ophook in self._ophook_list:
            ophook.post_iter()

    def backward_by_grad(self, tensor, grad):
        """Run backward from ``tensor`` with explicit ``grad`` and finish the iteration."""
        torch.autograd.backward(tensors=tensor, grad_tensors=grad)
        self._post_backward_operations()
        for ophook in self._ophook_list:
            ophook.post_iter()

    def _update_memstats(self):
        """Finish memory collection (once) and recompute the CUDA margin space."""
        if self._memstats_collector:
            self._finish_collect_memstats()
            # cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used.
            # the way to calculate margin space is based on the assumption that
            # model data is fixed in cuda during training.
            # cuda margin space can be used to store OS.
            self._cuda_margin_space = colo_device_memory_capacity(
                get_current_device()) - self._memstats_collector._memstats.max_overall_cuda

    @torch.no_grad()
    def _post_backward_operations(self) -> None:
        """
        The method includes operations required to be processed after backward:
        1. update memory tracer.
        2. flush the gradient in buckets, reducing partial gradients in each process.
        3. shard tensors not handled in the zero hook.
        4. set parameters' ``.grad`` to None on sync passes.
        """
        # 1. update memory tracer.
        self._update_memstats()

        # 2. flush the gradient in buckets. Reducing partial gradients in each process.
        if self._require_backward_grad_sync:
            # Flush any unreduced buckets in the post_backward stream.
            with torch.cuda.stream(self.comm_stream):
                self.reducer.flush()
            torch.cuda.current_stream().wait_stream(self.comm_stream)
        self.reducer.free()

        # 3. shard tensors not handled in the zero hook
        tensor_list = []
        for p in self.sharded_params:
            if not p.colo_attr.param_is_sharded:
                tensor_list.append(p.colo_attr.sharded_data_tensor)
                p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
                p.colo_attr.set_data_none()
        self.shard_strategy.shard(tensor_list, self.process_group)

        # 4. set all parameters' grad to None
        for p in self.module.parameters():
            if not p.requires_grad:
                continue
            # Leave the gradient accumulation state as-is if not synchronizing this pass.
            # A sync pass reduces gradients across the process group; a no-sync pass does not.
            # If _require_backward_grad_sync is False,
            # p.grad remains the accumulated unsharded gradient from prior no-sync passes.
            # We also allow interleaving no-sync passes with sync passes, if desired.
            if not self._require_backward_grad_sync:
                continue

            p.grad = None

    @torch.no_grad()
    def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
        """
        At the start of :func:`_grad_post_backward_hook`, ``grad`` contains the
        full gradient for the local batch. The reduce-scatter op will save
        a single shard of the summed gradient across all
        GPUs into ``param.colo_attr`` (see :meth:`_save_grad`). This shard will align
        with the current GPU rank. For example::

            before reduce_scatter:
                param.grad (GPU #0): [1, 2, 3, 4]
                param.grad (GPU #1): [5, 6, 7, 8]

            after reduce_scatter:
                param.grad (GPU #0): [6, 8]    # 1+5, 2+6
                param.grad (GPU #1): [10, 12]  # 3+7, 4+8

        The local GPU's ``optim.step`` is responsible for updating a single
        shard of params, also corresponding to the current GPU's rank. This
        alignment is created by ``param.colo_attr``, which ensures that
        the local optimizer only sees the relevant parameter shard.

        Returns:
            None if ``grad`` is None or gradients are not being synchronized; otherwise a tensor
            shaped like ``grad`` whose storage has been freed, handed back to autograd in place
            of the full gradient.
        """
        if grad is None:
            return
        assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients'
        if not self._require_backward_grad_sync:
            return
        # used to cheat Pytorch, since we can't return None
        empty_grad = torch.empty_like(grad)
        free_storage(empty_grad)
        # As torch didn't allow modifying grad in hook, we make a copy
        grad = grad.clone()
        if param.colo_attr.is_replicated:
            self._reduce_scatter_handler(param, grad)
        else:
            self._save_grad(param, grad)
        return empty_grad

    def _reduce_scatter_handler(self, param: Parameter, grad: torch.Tensor) -> None:
        """Pre-divide ``grad`` and reduce-scatter it on the communication stream."""
        self.comm_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.comm_stream):
            if self.fp32_reduce_scatter:
                # NOTE(review): this casts to ``param.dtype``; if the parameter payload is fp16/bf16 at
                # this point, the cast is a no-op rather than an FP32 promotion — confirm intent.
                grad.data = grad.data.to(param.dtype)
            if self.gradient_predivide_factor > 1.0:
                # Average grad by world_size for consistency with PyTorch DDP.
                grad.data.div_(self.gradient_predivide_factor)
            if self.world_size > 1:
                grad_chunks = chunk_and_pad(grad, self.reduce_scatter_process_group.size())
                self.reducer.reduce_scatter_async(grad_chunks,
                                                  group=self.reduce_scatter_process_group,
                                                  callback_fn=functools.partial(self._reduce_scatter_callback, param))
            else:
                self._reduce_scatter_callback(param, grad)
        torch.cuda.current_stream().wait_stream(self.comm_stream)

    def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
        """Flatten the reduced shard, apply the post-divide factor, and save it."""
        assert isinstance(reduced_grad,
                          torch.Tensor), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}"
        reduced_grad.data = reduced_grad.data.contiguous().view(-1)
        if self.gradient_postdivide_factor > 1:
            # Average grad by world_size for consistency with PyTorch DDP.
            reduced_grad.data.div_(self.gradient_postdivide_factor)
        self._save_grad(param, reduced_grad)

    # FIXME(ver217): refactor the below line when impl eviction policy
    def _save_grad(self, param: Parameter, grad: torch.Tensor):
        """Record overflow, optionally offload ``grad`` to CPU, and store it in ``param.colo_attr``."""

        # record whether we have overflow
        self.overflow_counter += torch.isinf(grad).any().item()
        self.overflow_counter += torch.isnan(grad).any().item()

        # move gradient to cpu
        if param.colo_attr.offload_grad:
            colo_model_data_move_to_cpu(grad)

        if self.reuse_fp16_shard:
            # make parameters point to gradient

            assert param.colo_attr.saved_grad.is_null(
            ), 'Gradient accumulation is not supported when reuse_fp16_shard=True'

            param.colo_attr.grad_payload_reset(grad.data)
            # release the memory of param
            # we set a false None for parameter's payload
            # so we can get parameter's device and dtype later in optimizer
            param.colo_attr.data_payload_reset(torch.empty(0, device=grad.device, dtype=grad.dtype))

            if param.colo_attr.is_replicated:
                param.colo_attr.sharded_data_tensor.is_sharded = True
        else:

            fp32_grad = cast_tensor_to_fp32(grad)

            if param.colo_attr.saved_grad.is_null():
                param.colo_attr.grad_payload_reset(fp32_grad)
            else:
                # accumulate into the existing fp32 gradient payload
                param.colo_attr.grad_payload.add_(fp32_grad.view_as(param.colo_attr.grad_payload))

        # keep saved_grad in HOLD state
        param.colo_attr.saved_grad.trans_state(TensorState.HOLD)

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        """Delegate to the wrapped module (names carry no ``module.`` prefix)."""
        return self.module.parameters(recurse=recurse)

    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
        """Delegate to the wrapped module (names carry no ``module.`` prefix)."""
        return self.module.named_parameters(prefix, recurse)

    def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
        """Return a full (gathered) state dict of the wrapped module with tensors moved to CPU."""
        return self._colo_state_dict(destination,
                                     prefix,
                                     keep_vars,
                                     shard_strategy=self.shard_strategy,
                                     state_dict_func=nn.Module.state_dict,
                                     module_to_load=self.module,
                                     sharded_params=self.sharded_params,
                                     process_group=self.process_group)

    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True) -> None:
        """Load a full (unsharded) state dict, as produced by :meth:`state_dict`, into the wrapped module.

        Parameters tracked as sharded are re-sharded over ``self.process_group`` after loading;
        unsharded parameters keep the full tensor. Buffers present in ``state_dict`` are copied in place.

        Args:
            state_dict: Mapping from parameter/buffer name (without a ``module.`` prefix) to a full tensor.
            strict: If True, raise when a parameter of the wrapped module is missing from ``state_dict``.

        Raises:
            RuntimeError: If ``strict`` is True and a parameter key is missing.
        """
        sharded_param_ids = {id(p) for p in self.sharded_params}
        for name, p in self.named_parameters():
            if name in state_dict:
                p.colo_attr.data_payload_reset(state_dict[name].to(dtype=p.colo_attr.data_payload.dtype,
                                                                   device=p.colo_attr.data_payload.device))
                # The payload now holds the full tensor. Re-shard only parameters managed as sharded,
                # using the same data-parallel group as every other shard/gather call in this class.
                if id(p) in sharded_param_ids:
                    # Force re-shard
                    p.colo_attr.sharded_data_tensor.is_sharded = False
                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.process_group)
            elif strict:
                raise RuntimeError(f'Missing key in state_dict: {name}')
        # `state_dict()` also emits persistent buffers; restore them so a save/load round-trip is lossless.
        with torch.no_grad():
            for name, buf in self.module.named_buffers():
                if name in state_dict:
                    buf.copy_(state_dict[name])

    def _colo_state_dict(self,
                         destination=None,
                         prefix='',
                         keep_vars=False,
                         shard_strategy: Optional[BaseShardStrategy] = None,
                         state_dict_func=None,
                         module_to_load=None,
                         sharded_params=None,
                         process_group=None) -> 'OrderedDict[str, torch.Tensor]':
        """Build a full state dict on CPU, temporarily gathering sharded parameters.

        Args:
            destination, prefix, keep_vars: Forwarded to ``state_dict_func``.
            shard_strategy: Strategy used to gather before and re-shard after; if None, no gathering is done.
            state_dict_func: Callable ``(module, destination, prefix, keep_vars)`` producing the state dict.
            module_to_load: Module passed to ``state_dict_func``; defaults to ``self``.
            sharded_params: Parameters to gather. If None or empty, all parameters of ``self`` whose
                ``colo_attr.param_is_sharded`` is True are used. The given list is never modified.
            process_group: Process group passed to gather/shard.

        Returns:
            A dict mapping names to values, with tensors moved to CPU.
        """
        if not sharded_params:
            # Build a fresh list instead of appending into a caller-owned or shared default list.
            sharded_params = [param for param in self.parameters() if param.colo_attr.param_is_sharded]
        # Compare with None: truth-testing an nn.Module may call __len__ (which asserts on this class).
        if module_to_load is None:
            module_to_load = self

        if shard_strategy is not None:
            shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group)
            for p in sharded_params:
                p.data = p.colo_attr.data_payload
        try:
            gathered_state_dict = state_dict_func(module_to_load, destination, prefix, keep_vars)
            gathered_state_dict = {
                k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in gathered_state_dict.items()
            }
        finally:
            # Re-shard even if building the state dict failed, so parameters are never left gathered.
            if shard_strategy is not None:
                shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group)
                for p in sharded_params:
                    p.colo_attr.set_data_none()
        return gathered_state_dict

    def _colo_load_from_state_dict(self,
                                   state_dict,
                                   prefix,
                                   local_metadata,
                                   strict,
                                   missing_keys,
                                   unexpected_keys,
                                   error_msgs,
                                   shard_strategy=None,
                                   process_group=None):
        r"""Copies parameters and buffers from :attr:`state_dict` into only
        this module, but not its descendants. This is called on every submodule
        in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
        module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
        For state dicts without metadata, :attr:`local_metadata` is empty.
        Subclasses can achieve class-specific backward compatible loading using
        the version number at `local_metadata.get("version", None)`.

        .. note::
            :attr:`state_dict` is not the same object as the input
            :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
            it can be modified.

        Args:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            prefix (str): the prefix for parameters and buffers used in this
                module
            local_metadata (dict): a dict containing the metadata for this module.
            strict (bool): whether to strictly enforce that the keys in
                :attr:`state_dict` with :attr:`prefix` match the names of
                parameters and buffers in this module
            missing_keys (list of str): if ``strict=True``, add missing keys to
                this list
            unexpected_keys (list of str): if ``strict=True``, add unexpected
                keys to this list
            error_msgs (list of str): error messages should be added to this
                list, and will be reported together in
                :meth:`~torch.nn.Module.load_state_dict`
            shard_strategy (Optional[BaseShardStrategy], optional): A shard strategy to manage shard behavior. Defaults to None.
            process_group (Optional[ProcessGroup], optional): Process group used when re-sharding loaded
                parameters. Defaults to None (the shard strategy's default group).
        """
        for hook in self._load_state_dict_pre_hooks.values():
            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
        local_state = {k: v for k, v in local_name_params if v is not None}

        for name, param in local_state.items():
            key = prefix + name
            if key in state_dict:
                input_param = state_dict[key]
                if hasattr(param, 'colo_attr'):
                    param.colo_attr.data_payload_reset(
                        input_param.to(dtype=param.colo_attr.data_payload.dtype,
                                       device=param.colo_attr.data_payload.device))
                    if shard_strategy is not None:
                        # Force re-shard
                        param.colo_attr.sharded_data_tensor.is_sharded = False
                        shard_strategy.shard([param.colo_attr.sharded_data_tensor], process_group)
                else:
                    # This is used to avoid copying uninitialized parameters into
                    # non-lazy modules, since they dont have the hook to do the checks
                    # in such case, it will error when accessing the .shape attribute.
                    is_param_lazy = torch.nn.parameter.is_lazy(param)
                    # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
                    if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
                        input_param = input_param[0]

                    if not is_param_lazy and input_param.shape != param.shape:
                        # local shape should match the one in checkpoint
                        error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                          'the shape in current model is {}.'.format(
                                              key, input_param.shape, param.shape))
                        continue
                    try:
                        with torch.no_grad():
                            param.copy_(input_param)
                    except Exception as ex:
                        error_msgs.append('While copying the parameter named "{}", '
                                          'whose dimensions in the model are {} and '
                                          'whose dimensions in the checkpoint are {}, '
                                          'an exception occurred : {}.'.format(key, param.size(), input_param.size(),
                                                                               ex.args))
            elif strict:
                missing_keys.append(key)

        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
        if getattr(self.__class__, "set_extra_state", nn.Module.set_extra_state) is not nn.Module.set_extra_state:
            if extra_state_key in state_dict:
                self.set_extra_state(state_dict[extra_state_key])
            elif strict:
                missing_keys.append(extra_state_key)
        elif strict and (extra_state_key in state_dict):
            unexpected_keys.append(extra_state_key)

        if strict:
            for key in state_dict.keys():
                if key.startswith(prefix) and key != extra_state_key:
                    input_name = key[len(prefix):]
                    input_name = input_name.split('.', 1)[0]    # get the name of param/buffer/child
                    if input_name not in self._modules and input_name not in local_state:
                        unexpected_keys.append(key)

    def __getitem__(self, idx: int):
        """Index into the wrapped module; only valid when it is an ``nn.ModuleList``."""
        assert isinstance(self.module, nn.ModuleList)
        return self.module[idx]

    def __len__(self):
        """Length of the wrapped module; only valid when it is an ``nn.ModuleList``."""
        assert isinstance(self.module, nn.ModuleList)
        return len(self.module)

    def __iter__(self):
        """Iterate the wrapped module; only valid when it is an ``nn.ModuleList``."""
        assert isinstance(self.module, nn.ModuleList)
        return iter(self.module)
|