mirror of https://github.com/hpcaitech/ColossalAI
113 lines
5.1 KiB
Python
113 lines
5.1 KiB
Python
|
from typing import Callable

import torch
import torch.distributed as dist
from packaging import version
from torch import Tensor
from torch.distributed.fsdp._common_utils import _no_dispatch_record_stream
from torch.distributed.utils import _p_assert
|
||
|
|
||
|
|
||
|
def _all_gather_flat_param(
|
||
|
self,
|
||
|
padded_unsharded_flat_param: Tensor,
|
||
|
) -> Tensor:
|
||
|
"""
|
||
|
All-gather the handle's flat parameter to the destination ``padded_unsharded_flat_param``.
|
||
|
|
||
|
Then switch to use the all-gathered tensor.
|
||
|
"""
|
||
|
_p_assert(
|
||
|
hasattr(self, "process_group") and hasattr(self, "world_size"),
|
||
|
"Expects a process group and world size to have been set via `shard()`",
|
||
|
)
|
||
|
sharded_flat_param = self.flat_param.data
|
||
|
expected_numel = sharded_flat_param.numel() * self.world_size
|
||
|
_p_assert(
|
||
|
padded_unsharded_flat_param.numel() == expected_numel,
|
||
|
f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
|
||
|
)
|
||
|
|
||
|
pg = self._fake_process_group if self._use_fake_all_gather else self.process_group
|
||
|
|
||
|
# HACK this should be handled by C10D
|
||
|
if sharded_flat_param.is_cpu: # type: ignore[attr-defined]
|
||
|
tensor_list = list(torch.chunk(padded_unsharded_flat_param, dist.get_world_size(pg)))
|
||
|
work = dist.all_gather(tensor_list, sharded_flat_param, group=pg)
|
||
|
else:
|
||
|
if self._comm_hook is None:
|
||
|
dist.all_gather_into_tensor(
|
||
|
padded_unsharded_flat_param,
|
||
|
sharded_flat_param,
|
||
|
pg,
|
||
|
)
|
||
|
else:
|
||
|
self._comm_hook(None, padded_unsharded_flat_param, sharded_flat_param, pg)
|
||
|
|
||
|
if self._offload_params:
|
||
|
# In case of offloading, `flat_param.data` (i.e. sharded param) is
|
||
|
# created on the pre-unshard stream. We need to hand it over to the
|
||
|
# unshard stream for all-gather
|
||
|
_no_dispatch_record_stream(
|
||
|
sharded_flat_param,
|
||
|
self._device_handle.current_stream(), # unshard_stream
|
||
|
)
|
||
|
return padded_unsharded_flat_param
|
||
|
|
||
|
|
||
|
def register_params_comm_hook(self, state: object, hook: Callable[..., None]):
    """Register a parameter communication hook on this instance's ``FlatParamHandle``.

    The hook lets users customise how FSDP unshards (all-gathers) parameters
    across workers. It is stored on ``self._handle`` as ``_comm_hook``, and
    ``state`` is stored next to it as ``_comm_hook_state``.

    .. warning ::
        The hook should be registered before running an initial forward pass
        and only once; a second registration raises ``AssertionError``.

    .. note ::
        Only ``self._handle`` is updated here. Handles owned by nested FSDP
        sub-units are not touched by this method.

    Args:
        state (object): State information associated with the hook; stored on
            the handle as ``_comm_hook_state``.
        hook (Callable): Callable taking four positional arguments: a state
            slot, the pre-sized destination tensor for the unsharded flat
            parameter, this rank's sharded flat parameter, and the process
            group to communicate over. It is expected to fill the destination
            tensor; its return value is not used.

    Raises:
        AssertionError: If called on a non-root instance, or if a hook is
            already registered on the handle.
        ValueError: If ``hook`` is not callable.
    """
    if not self.check_is_root():
        raise AssertionError("register_params_comm_hook can only be called on a root instance.")

    # NOTE: hybrid sharding strategies are not rejected here; the check for
    # them was left disabled in this implementation.
    handle = self._handle
    if handle._comm_hook is not None:
        raise AssertionError("A communication hook is already registered")
    if not callable(hook):
        raise ValueError(f"The communication hook must be callable but got {hook}")
    handle._comm_hook = hook
    handle._comm_hook_state = state
|
||
|
|
||
|
|
||
|
def patch_fsdp_params_comm_hook():
    """Monkey-patch torch FSDP so a parameter communication hook can be used.

    On torch >= 2.2.0 this:
      * adds class-level ``_comm_hook`` / ``_comm_hook_state`` defaults
        (both ``None``) to ``FlatParamHandle``;
      * replaces ``FlatParamHandle._all_gather_flat_param`` with this
        module's hook-aware implementation;
      * attaches ``register_params_comm_hook`` to ``FullyShardedDataParallel``.

    Raises:
        RuntimeError: If the installed torch version is below 2.2.0.
    """
    minimum_supported = version.parse("2.2.0")
    installed = version.parse(torch.__version__)
    if installed < minimum_supported:
        raise RuntimeError("This fsdp_params_comm_hook patch is not supported while torch version under 2.2.0.")

    # Imported lazily so the module can be imported without touching these
    # FSDP internals until the patch is actually requested.
    from torch.distributed.fsdp import FullyShardedDataParallel
    from torch.distributed.fsdp._flat_param import FlatParamHandle

    # Class-level defaults: handles without a registered hook see ``None``.
    FlatParamHandle._comm_hook = None
    FlatParamHandle._comm_hook_state = None
    FlatParamHandle._all_gather_flat_param = _all_gather_flat_param
    FullyShardedDataParallel.register_params_comm_hook = register_params_comm_hook
|