mirror of https://github.com/hpcaitech/ColossalAI
* support fp8_communication in the Torch DDP grad comm, FSDP grad comm, and FSDP params comm
* implement communication hook for FSDP params all-gather
* added unit test for fp8 operators
* support fp8 communication in GeminiPlugin
* update training scripts to support fsdp and fp8 communication
* fixed some minor bugs observed in unit test
* add all_gather_into_tensor_flat_fp8
* skip the tests if torch < 2.2.0
* add fp8_comm flag
* rebase latest fp8 operators
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Hanks · 4 months ago · committed by GitHub · pull/5978/head
14 changed files with 602 additions and 14 deletions
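The commit message lists where fp8 communication is wired in (Torch DDP gradient buckets, FSDP gradient and parameter collectives, GeminiPlugin), but the diff below only shows the FSDP patch and the tests. As rough orientation only, a training script would presumably opt in through the booster plugin; the fp8_communication keyword in the sketch below is an assumption inferred from the "add fp8_comm flag" entry, so check the plugin signatures in this PR for the exact name.

# Hypothetical sketch: opting into fp8 communication via a booster plugin.
# The `fp8_communication` kwarg name is assumed from the commit message and
# may differ from the actual GeminiPlugin signature in this PR.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch()
plugin = GeminiPlugin(fp8_communication=True)  # assumed flag name
booster = Booster(plugin=plugin)
# model, optimizer, ... = booster.boost(model, optimizer, ...)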
@@ -0,0 +1,112 @@
import torch
import torch.distributed as dist
from packaging import version
from torch import Tensor
from torch.distributed.fsdp._common_utils import _no_dispatch_record_stream
from torch.distributed.utils import _p_assert


def _all_gather_flat_param(
    self,
    padded_unsharded_flat_param: Tensor,
) -> Tensor:
    """
    All-gather the handle's flat parameter to the destination ``padded_unsharded_flat_param``.

    Then switch to use the all-gathered tensor.
    """
    _p_assert(
        hasattr(self, "process_group") and hasattr(self, "world_size"),
        "Expects a process group and world size to have been set via `shard()`",
    )
    sharded_flat_param = self.flat_param.data
    expected_numel = sharded_flat_param.numel() * self.world_size
    _p_assert(
        padded_unsharded_flat_param.numel() == expected_numel,
        f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
    )

    pg = self._fake_process_group if self._use_fake_all_gather else self.process_group

    # HACK this should be handled by C10D
    if sharded_flat_param.is_cpu:  # type: ignore[attr-defined]
        tensor_list = list(torch.chunk(padded_unsharded_flat_param, dist.get_world_size(pg)))
        work = dist.all_gather(tensor_list, sharded_flat_param, group=pg)
    else:
        if self._comm_hook is None:
            dist.all_gather_into_tensor(
                padded_unsharded_flat_param,
                sharded_flat_param,
                pg,
            )
        else:
            self._comm_hook(None, padded_unsharded_flat_param, sharded_flat_param, pg)

    if self._offload_params:
        # In case of offloading, `flat_param.data` (i.e. the sharded param) is
        # created on the pre-unshard stream. We need to hand it over to the
        # unshard stream for the all-gather.
        _no_dispatch_record_stream(
            sharded_flat_param,
            self._device_handle.current_stream(),  # unshard_stream
        )
    return padded_unsharded_flat_param


def register_params_comm_hook(self, state: object, hook: callable):
    """Register a communication hook for FlatParamHandle.

    This is an enhancement that provides a flexible hook to users where they can specify how FSDP unshards
    parameters across multiple workers.

    .. warning ::
        FSDP communication hook should be registered before running an initial forward pass
        and only once.

    Args:
        state (object): Passed to the hook to maintain any state information during the training process.
        hook (Callable): Callable, which has one of the following signatures:
                        1) ``hook: Callable[torch.Tensor] -> None``:
                           This function takes in a Python tensor, which represents
                           the full, flattened, unsharded gradient with respect to all variables
                           corresponding to the model this FSDP unit is wrapping
                           (that are not wrapped by other FSDP sub-units).
                           It then performs all necessary processing and returns ``None``;
                        2) ``hook: Callable[torch.Tensor, torch.Tensor] -> None``:
                           This function takes in two Python tensors, the first one represents
                           the full, flattened, unsharded gradient with respect to all variables
                           corresponding to the model this FSDP unit is wrapping
                           (that are not wrapped by other FSDP sub-units). The latter
                           represents a pre-sized tensor to store a chunk of a sharded gradient after
                           reduction.
                           In both cases, the callable performs all necessary processing and returns ``None``.
                           Callables with signature 1 are expected to handle gradient communication for a `NO_SHARD` case.
                           Callables with signature 2 are expected to handle gradient communication for sharded cases.
    """
    if not self.check_is_root():
        raise AssertionError("register_comm_hook can only be called on a root instance.")

    # if fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES:
    #     raise AssertionError(
    #         f"Communication hook is not supported for hybrid strategies: {fsdp_state.sharding_strategy}"
    #     )
    if self._handle._comm_hook is not None:
        raise AssertionError("A communication hook is already registered.")
    if not callable(hook):
        raise ValueError(f"The communication hook must be callable but got {hook}")
    self._handle._comm_hook = hook
    self._handle._comm_hook_state = state


def patch_fsdp_params_comm_hook():
    if version.parse(torch.__version__) >= version.parse("2.2.0"):
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.fsdp._flat_param import FlatParamHandle

        FlatParamHandle._comm_hook = None
        FlatParamHandle._comm_hook_state = None
        FlatParamHandle._all_gather_flat_param = _all_gather_flat_param
        FSDP.register_params_comm_hook = register_params_comm_hook
    else:
        raise RuntimeError("This fsdp_params_comm_hook patch is not supported for torch versions below 2.2.0.")
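The patched branch above invokes the hook as self._comm_hook(None, padded_unsharded_flat_param, sharded_flat_param, pg), which fixes the user-facing signature. A minimal pass-through hook that merely reproduces the default all-gather (the name below is made up, shown only as a template; the fp8 hook exercised in the tests later in this PR would compress before communicating) could look like this:

import torch.distributed as dist

def passthrough_params_comm_hook(state, padded_unsharded_flat_param, sharded_flat_param, group):
    # Same behaviour as the default branch in `_all_gather_flat_param`:
    # gather every rank's sharded flat parameter into the pre-allocated
    # unsharded buffer. A real hook would compress `sharded_flat_param`
    # before the collective and decompress into the output buffer.
    dist.all_gather_into_tensor(padded_unsharded_flat_param, sharded_flat_param, group=group)

# After patch_fsdp_params_comm_hook(), register it on the FSDP root:
#   fsdp_model.register_params_comm_hook(None, passthrough_params_comm_hook)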
@@ -0,0 +1,26 @@
import torch
from torch.testing import assert_close

from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import cast_from_fp8, cast_from_fp8_pipeline, cast_to_fp8, cast_to_fp8_pipeline
from colossalai.testing import parameterize


@parameterize("shape", [(100, 10), (10, 100), (3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)])
@parameterize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@parameterize("fp8_format", ["e4m3", "e5m2"])
def test_fp8_cast(shape, dtype, fp8_format):
    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    ret, scale_inv = cast_to_fp8(x, fp8_format=fp8_format)
    out = cast_from_fp8(ret, scale_inv, x.dtype)
    assert_close(out, x, rtol=0.1, atol=0.1)

    if x.size(-1) % 2 == 0:
        inp_dict = {"hidden_states": x.clone()}
        cast_to_fp8_pipeline(inp_dict)
        cast_from_fp8_pipeline(inp_dict)
        assert_close(inp_dict["hidden_states"], x, rtol=0.1, atol=0.1)


if __name__ == "__main__":
    test_fp8_cast()
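cast_to_fp8 and cast_from_fp8 come from colossalai.quantization.fp8 and are only exercised, not defined, in this diff. For orientation, a naive per-tensor scaled cast with the same call shape might look like the sketch below (the naive_* helpers are hypothetical, not the library's implementation). The loose rtol/atol of 0.1 in the test reflects the small fp8 mantissa: e4m3 keeps 3 mantissa bits, e5m2 only 2.

import torch

def naive_cast_to_fp8(x: torch.Tensor, fp8_format: str = "e4m3"):
    # Per-tensor dynamic scaling: map the largest magnitude close to the fp8
    # maximum, cast, and return the inverse scale so the receiver can undo it.
    fp8_dtype = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    fp8_max = torch.finfo(fp8_dtype).max
    amax = x.abs().max().float().clamp(min=1e-12)
    scale = fp8_max / amax
    return (x.float() * scale).to(fp8_dtype), 1.0 / scale

def naive_cast_from_fp8(ret: torch.Tensor, scale_inv: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Undo the scaling in fp32, then cast back to the original dtype.
    return (ret.float() * scale_inv).to(dtype)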
@@ -0,0 +1,87 @@
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

# example modified from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    def get_grads_after_one_iteration(hook=None):
        torch.manual_seed(0)
        # create model and move it to GPU with id rank
        model = ToyModel().to(rank)

        ddp_model = DDP(model, device_ids=[rank])

        if hook is not None:
            ddp_model.register_comm_hook(None, hook)

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10))
        labels = torch.randn(20, 5).to(rank)
        loss_fn(outputs, labels).backward()
        optimizer.step()

        torch.distributed.barrier()

        grad_dict = {}
        for name, params in ddp_model.named_parameters():
            grad_dict[name] = params.grad
        return grad_dict

    from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async, fp8_compress_ddp_grad_comm_hook_sync

    grad_dict = get_grads_after_one_iteration()
    for hook in [fp8_compress_ddp_grad_comm_hook_sync, fp8_compress_ddp_grad_comm_hook_async]:
        grad_dict_w_hook = get_grads_after_one_iteration(hook)
        if dist.get_rank() == 0:
            for name in grad_dict:
                assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1)

    cleanup()


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    world_size = n_gpus
    run_demo(demo_basic, world_size)
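fp8_compress_ddp_grad_comm_hook_sync and fp8_compress_ddp_grad_comm_hook_async are imported from colossalai.quantization.fp8; their bodies are not part of this diff. As a hedged sketch of the general idea only (not the PR's implementation), a DDP gradient comm hook that ships fp8 payloads could be written against the standard hook signature as follows. NCCL cannot reduce float8 directly, so the sketch gathers raw bytes plus per-rank scales and averages locally:

import torch
import torch.distributed as dist

def naive_fp8_ddp_comm_hook(state, bucket):
    # `bucket.buffer()` is the flattened gradient for this bucket; DDP expects
    # the returned future to resolve to the averaged tensor of the same shape.
    world_size = dist.get_world_size()
    grad = bucket.buffer()

    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    amax = grad.abs().max().clamp(min=1e-12).reshape(1)
    scale = fp8_max / amax
    payload = (grad * scale).to(torch.float8_e4m3fn).view(torch.uint8)

    gathered_payloads = [torch.empty_like(payload) for _ in range(world_size)]
    gathered_scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(gathered_payloads, payload)
    dist.all_gather(gathered_scales, scale)

    # Decompress every rank's contribution and average locally.
    avg = torch.zeros_like(grad)
    for p, s in zip(gathered_payloads, gathered_scales):
        avg += p.view(torch.float8_e4m3fn).to(grad.dtype) / s
    avg /= world_size

    fut = torch.futures.Future()
    fut.set_result(avg)
    return fut

# Registered exactly like the hooks in the test above:
#   ddp_model.register_comm_hook(None, naive_fp8_ddp_comm_hook)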
@@ -0,0 +1,107 @@
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from packaging import version
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing import assert_close

from colossalai import launch
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

# example modified from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(100, 100)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(100, 50)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


@parameterize("mode", ["grad", "params"])
def run_model(mode):
    rank = dist.get_rank()

    from colossalai.quantization.utils import patch_fsdp_params_comm_hook

    patch_fsdp_params_comm_hook()

    def get_grads_after_one_iteration(grad_hook=None, params_hook=None):
        torch.manual_seed(0)
        # create model and move it to GPU with id rank
        model = ToyModel().to(rank)
        fsdp_model = FSDP(model)

        if grad_hook is not None:
            fsdp_model.register_comm_hook(None, grad_hook)

        if params_hook is not None:
            fsdp_model.register_params_comm_hook(None, params_hook)

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(fsdp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        outputs = fsdp_model(torch.randn(20, 100))
        labels = torch.randn(20, 50).to(rank)
        loss_fn(outputs, labels).backward()
        optimizer.step()

        torch.distributed.barrier()

        grad_dict = {}
        for name, params in fsdp_model.named_parameters():
            grad_dict[name] = params.grad
        return grad_dict

    from colossalai.quantization.fp8 import fp8_compress_fsdp_grad_comm_hook, fp8_compress_fsdp_params_comm_hook

    if mode == "grad":
        grad_dict = get_grads_after_one_iteration()
        for hook in [
            fp8_compress_fsdp_grad_comm_hook,
        ]:
            grad_dict_w_hook = get_grads_after_one_iteration(grad_hook=hook)
            if dist.get_rank() == 0:
                for name in grad_dict:
                    assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1)
    elif mode == "params":
        grad_dict = get_grads_after_one_iteration()
        for hook in [
            fp8_compress_fsdp_params_comm_hook,
        ]:
            grad_dict_w_hook = get_grads_after_one_iteration(params_hook=hook)
            if dist.get_rank() == 0:
                for name in grad_dict:
                    assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1)
    else:
        raise NotImplementedError


def demo_basic(rank, world_size, port):
    print(f"Running basic FSDP example on rank {rank}.")
    launch(rank=rank, world_size=world_size, port=port, host="localhost")
    run_model()
    cleanup()


@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("2.2.0"), reason="torch version < 2.2.0.")
@rerun_if_address_is_in_use()
def test_fsdp():
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    spawn(demo_basic, n_gpus)


if __name__ == "__main__":
    test_fsdp()
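The params mode above relies on fp8_compress_fsdp_params_comm_hook, which is also defined outside this diff. Given the hook signature fixed by the patched _all_gather_flat_param (state, padded_unsharded_flat_param, sharded_flat_param, group), a hedged sketch of an fp8 parameter all-gather, again not the PR's actual implementation, might look like:

import torch
import torch.distributed as dist

def naive_fp8_params_comm_hook(state, padded_unsharded_flat_param, sharded_flat_param, group):
    # Compress the local shard to fp8, gather raw bytes plus per-rank scales,
    # then decompress each shard into its slice of the unsharded buffer.
    world_size = dist.get_world_size(group)
    shard_numel = sharded_flat_param.numel()

    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    amax = sharded_flat_param.abs().max().clamp(min=1e-12).reshape(1)
    scale = fp8_max / amax
    payload = (sharded_flat_param * scale).to(torch.float8_e4m3fn).view(torch.uint8)

    gathered_payload = torch.empty(world_size * shard_numel, dtype=torch.uint8, device=payload.device)
    dist.all_gather_into_tensor(gathered_payload, payload, group=group)
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(scales, scale, group=group)

    for r in range(world_size):
        chunk = gathered_payload[r * shard_numel:(r + 1) * shard_numel].view(torch.float8_e4m3fn)
        padded_unsharded_flat_param[r * shard_numel:(r + 1) * shard_numel].copy_(
            chunk.to(padded_unsharded_flat_param.dtype) / scales[r]
        )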