ColossalAI/colossalai/quantization/utils.py

113 lines
5.1 KiB
Python
Raw Normal View History

import torch
import torch.distributed as dist
from packaging import version
from torch import Tensor
from torch.distributed.fsdp._common_utils import _no_dispatch_record_stream
from torch.distributed.utils import _p_assert
def _all_gather_flat_param(
self,
padded_unsharded_flat_param: Tensor,
) -> Tensor:
"""
All-gather the handle's flat parameter to the destination ``padded_unsharded_flat_param``.
Then switch to use the all-gathered tensor.
"""
_p_assert(
hasattr(self, "process_group") and hasattr(self, "world_size"),
"Expects a process group and world size to have been set via `shard()`",
)
sharded_flat_param = self.flat_param.data
expected_numel = sharded_flat_param.numel() * self.world_size
_p_assert(
padded_unsharded_flat_param.numel() == expected_numel,
f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
)
pg = self._fake_process_group if self._use_fake_all_gather else self.process_group
# HACK this should be handled by C10D
if sharded_flat_param.is_cpu: # type: ignore[attr-defined]
tensor_list = list(torch.chunk(padded_unsharded_flat_param, dist.get_world_size(pg)))
work = dist.all_gather(tensor_list, sharded_flat_param, group=pg)
else:
if self._comm_hook is None:
dist.all_gather_into_tensor(
padded_unsharded_flat_param,
sharded_flat_param,
pg,
)
else:
self._comm_hook(None, padded_unsharded_flat_param, sharded_flat_param, pg)
if self._offload_params:
# In case of offloading, `flat_param.data` (i.e. sharded param) is
# created on the pre-unshard stream. We need to hand it over to the
# unshard stream for all-gather
_no_dispatch_record_stream(
sharded_flat_param,
self._device_handle.current_stream(), # unshard_stream
)
return padded_unsharded_flat_param
def register_params_comm_hook(self, state: object, hook: callable):
"""Register a communication hook for FlatParamHandle.
This is an enhancement that provides a flexible hook to users where they can specify how FSDP unshards
parameters across multiple workers.
.. warning ::
FSDP communication hook should be registered before running an initial forward pass
and only once.
Args:
state (object): Passed to the hook to maintain any state information during the training process.
hook (Callable): Callable, which has one of the following signatures:
1) ``hook: Callable[torch.Tensor] -> None``:
This function takes in a Python tensor, which represents
the full, flattened, unsharded gradient with respect to all variables
corresponding to the model this FSDP unit is wrapping
(that are not wrapped by other FSDP sub-units).
It then performs all necessary processing and returns ``None``;
2) ``hook: Callable[torch.Tensor, torch.Tensor] -> None``:
This function takes in two Python tensors, the first one represents
the full, flattened, unsharded gradient with respect to all variables
corresponding to the model this FSDP unit is wrapping
(that are not wrapped by other FSDP sub-units). The latter
represents a pre-sized tensor to store a chunk of a sharded gradient after
reduction.
In both cases, callable performs all necessary processing and returns ``None``.
Callables with signature 1 are expected to handle gradient communication for a `NO_SHARD` case.
Callables with signature 2 are expected to handle gradient communication for sharded cases.
"""
if not self.check_is_root():
raise AssertionError("register_comm_hook can only be called on a root instance.")
# if fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES:
# raise AssertionError(
# f"Communication hook is not supported for hybrid strategies: {fsdp_state.sharding_strategy}"
# )
if self._handle._comm_hook is not None:
raise AssertionError("A communication hook is already registered")
if not callable(hook):
raise ValueError(f"The communication hook must be callable but got {hook}")
self._handle._comm_hook = hook
self._handle._comm_hook_state = state
def patch_fsdp_params_comm_hook():
if version.parse(torch.__version__) >= version.parse("2.2.0"):
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._flat_param import FlatParamHandle
FlatParamHandle._comm_hook = None
FlatParamHandle._comm_hook_state = None
FlatParamHandle._all_gather_flat_param = _all_gather_flat_param
FSDP.register_params_comm_hook = register_params_comm_hook
else:
raise RuntimeError("This fsdp_params_comm_hook patch is not supported while torch version under 2.2.0.")