import torch
import torch.distributed as dist
from packaging import version
from torch import Tensor
from torch.distributed.fsdp._common_utils import _no_dispatch_record_stream
from torch.distributed.utils import _p_assert


def _all_gather_flat_param(
    self,
    padded_unsharded_flat_param: Tensor,
) -> Tensor:
    """
    All-gather the handle's flat parameter to the destination ``padded_unsharded_flat_param``.

    Then switch to use the all-gathered tensor.
    """
    _p_assert(
        hasattr(self, "process_group") and hasattr(self, "world_size"),
        "Expects a process group and world size to have been set via `shard()`",
    )
    sharded_flat_param = self.flat_param.data
    expected_numel = sharded_flat_param.numel() * self.world_size
    _p_assert(
        padded_unsharded_flat_param.numel() == expected_numel,
        f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
    )

    pg = self._fake_process_group if self._use_fake_all_gather else self.process_group

    # HACK this should be handled by C10D
    if sharded_flat_param.is_cpu:  # type: ignore[attr-defined]
        tensor_list = list(torch.chunk(padded_unsharded_flat_param, dist.get_world_size(pg)))
        dist.all_gather(tensor_list, sharded_flat_param, group=pg)
    else:
        if self._comm_hook is None:
            dist.all_gather_into_tensor(
                padded_unsharded_flat_param,
                sharded_flat_param,
                pg,
            )
        else:
            # A registered params comm hook replaces the default all-gather collective.
            self._comm_hook(None, padded_unsharded_flat_param, sharded_flat_param, pg)

    if self._offload_params:
        # In case of offloading, `flat_param.data` (i.e. sharded param) is
        # created on the pre-unshard stream. We need to hand it over to the
        # unshard stream for the all-gather.
        _no_dispatch_record_stream(
            sharded_flat_param,
            self._device_handle.current_stream(),  # unshard_stream
        )
    return padded_unsharded_flat_param
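

# Illustrative sketch (not part of the original patch): the destination tensor filled
# by `_all_gather_flat_param` holds `world_size * shard_numel` elements, and rank i's
# shard occupies chunk i, matching the layout `torch.chunk` produces in the CPU fallback
# above. The sizes below are hypothetical.
def _example_unsharded_layout() -> None:
    world_size, shard_numel = 4, 10  # hypothetical world size and per-rank shard size
    padded_unsharded = torch.empty(world_size * shard_numel)
    chunks = torch.chunk(padded_unsharded, world_size)
    assert len(chunks) == world_size
    assert all(chunk.numel() == shard_numel for chunk in chunks)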


def register_params_comm_hook(self, state: object, hook: callable):
    """Register a communication hook for the FlatParamHandle.

    This enhancement provides a flexible hook with which users can specify how FSDP
    unshards (all-gathers) parameters across multiple workers. When a hook is
    registered, ``_all_gather_flat_param`` invokes it in place of the default
    ``dist.all_gather_into_tensor`` collective.

    .. warning::
        The FSDP communication hook should be registered before running an initial
        forward pass, and only once.

    Args:
        state (object): Stored on the handle as ``_comm_hook_state`` so the hook can
            maintain any state information during the training process.
        hook (Callable): Callable with the signature
            ``hook(state, padded_unsharded_flat_param, sharded_flat_param, group) -> None``,
            where ``padded_unsharded_flat_param`` is the pre-sized destination tensor for
            the full, flattened, unsharded parameter of the model this FSDP unit is wrapping
            (that is not wrapped by other FSDP sub-units), ``sharded_flat_param`` is the
            local shard owned by this rank, and ``group`` is the process group to
            all-gather over. The callable performs all necessary communication and
            processing and returns ``None``.
    """
    if not self.check_is_root():
        raise AssertionError("register_params_comm_hook can only be called on a root instance.")

    # if fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES:
    #     raise AssertionError(
    #         f"Communication hook is not supported for hybrid strategies: {fsdp_state.sharding_strategy}"
    #     )
    if self._handle._comm_hook is not None:
        raise AssertionError("A communication hook is already registered")
    if not callable(hook):
        raise ValueError(f"The communication hook must be callable but got {hook}")
    self._handle._comm_hook = hook
    self._handle._comm_hook_state = state
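

# Illustrative sketch (not part of the original patch): a params comm hook matching the
# call site in `_all_gather_flat_param`, i.e.
# `hook(state, padded_unsharded_flat_param, sharded_flat_param, group)`. This hypothetical
# hook all-gathers the local shard in bf16 to reduce communication volume, then casts the
# result back into the full-precision destination tensor. Note that the patched call site
# currently passes `None` as `state`.
def _example_bf16_all_gather_hook(
    state: object,
    padded_unsharded_flat_param: Tensor,
    sharded_flat_param: Tensor,
    group: dist.ProcessGroup,
) -> None:
    low_precision_output = torch.empty_like(padded_unsharded_flat_param, dtype=torch.bfloat16)
    dist.all_gather_into_tensor(low_precision_output, sharded_flat_param.to(torch.bfloat16), group)
    padded_unsharded_flat_param.copy_(low_precision_output)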


def patch_fsdp_params_comm_hook():
    """Monkey-patch FSDP so that a params communication hook can be registered.

    Attaches ``_comm_hook``/``_comm_hook_state`` slots and the patched
    ``_all_gather_flat_param`` to ``FlatParamHandle``, and exposes
    ``register_params_comm_hook`` on ``FullyShardedDataParallel``.
    """
    if version.parse(torch.__version__) >= version.parse("2.2.0"):
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.fsdp._flat_param import FlatParamHandle

        FlatParamHandle._comm_hook = None
        FlatParamHandle._comm_hook_state = None
        FlatParamHandle._all_gather_flat_param = _all_gather_flat_param
        FSDP.register_params_comm_hook = register_params_comm_hook
    else:
        raise RuntimeError("The fsdp_params_comm_hook patch requires torch >= 2.2.0.")
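

# Illustrative usage sketch (not part of the original patch): apply the monkey patch, wrap a
# model with FSDP, and register a hook (e.g. the bf16 sketch above) on the root FSDP instance
# before the first forward pass. Assumes the default process group has already been initialized.
def _example_usage(model: torch.nn.Module) -> torch.nn.Module:
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    patch_fsdp_params_comm_hook()
    fsdp_model = FSDP(model)
    fsdp_model.register_params_comm_hook(state=None, hook=_example_bf16_all_gather_hook)
    return fsdp_model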