[Feature]: support FP8 communication in DDP, FSDP, Gemini (#5928)

* support fp8_communication in the Torch DDP grad comm, FSDP grad comm, and FSDP params comm

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* implement communication hook for FSDP params all-gather

* added unit test for fp8 operators

* support fp8 communication in GeminiPlugin

* update training scripts to support fsdp and fp8 communication

* fixed some minor bugs observed in unit test

* add all_gather_into_tensor_flat_fp8

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

skip the test if torch < 2.2.0

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

skip the test if torch < 2.2.0

skip the test if torch < 2.2.0

* add fp8_comm flag

* rebase latest fp8 operators

* rebase latest fp8 operators

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Hanks 2024-08-08 15:55:01 +08:00 committed by GitHub
parent 7739629b9d
commit b480eec738
14 changed files with 602 additions and 14 deletions
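For orientation, a minimal usage sketch of the new fp8_communication flag, based on the plugin signatures in this diff (the model/optimizer setup is omitted and illustrative):

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, TorchDDPPlugin, TorchFSDPPlugin

colossalai.launch_from_torch()

# Each of the three data-parallel plugins accepts the flag; it defaults to False
# (full-precision communication).
plugin = TorchDDPPlugin(fp8_communication=True)
# plugin = GeminiPlugin(initial_scale=2**5, fp8_communication=True)
# plugin = TorchFSDPPlugin(fp8_communication=True)

booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
#     model, optimizer, criterion, dataloader, lr_scheduler
# )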

View File

@@ -364,6 +364,7 @@ class GeminiPlugin(DPPluginBase):
        enable_sequence_overlap: bool = False,
        enable_async_reduce: bool = True,
        verbose: bool = False,
+        fp8_communication: bool = False,
    ) -> None:
        super().__init__()
        assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
@@ -395,6 +396,7 @@ class GeminiPlugin(DPPluginBase):
            master_weights=master_weights,
            max_prefetch=max_prefetch,
            enable_async_reduce=enable_async_reduce,
+            fp8_communication=fp8_communication,
        )
        self.zero_optim_config = dict(
            gpu_margin_mem_ratio=gpu_margin_mem_ratio,

View File

@@ -177,6 +177,7 @@ class TorchDDPPlugin(DPPluginBase):
        check_reduction: bool = False,
        gradient_as_bucket_view: bool = False,
        static_graph: bool = False,
+        fp8_communication: bool = False,
    ) -> None:
        super().__init__()
        self.ddp_kwargs = dict(
@@ -187,6 +188,7 @@ class TorchDDPPlugin(DPPluginBase):
            gradient_as_bucket_view=gradient_as_bucket_view,
            static_graph=static_graph,
        )
+        self.fp8_communication = fp8_communication

    def support_no_sync(self) -> bool:
        return True
@@ -226,6 +228,11 @@
        if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
            optimizer = OptimizerWrapper(optimizer)

+        if self.fp8_communication:
+            from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async
+
+            model.module.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_async)
+
        return model, optimizer, criterion, dataloader, lr_scheduler

    def control_checkpoint_io(self) -> bool:

View File

@@ -298,6 +298,7 @@ class TorchFSDPPlugin(DPPluginBase):
            ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
            param_init_fn: Optional[Callable[[nn.Module], None]] = None,
            sync_module_states: bool = False,
+            fp8_communication: bool = False,
        ):
            super().__init__()
            self.fsdp_kwargs = dict(
@@ -311,6 +312,7 @@ class TorchFSDPPlugin(DPPluginBase):
                param_init_fn=param_init_fn,
                sync_module_states=sync_module_states,
            )
+            self.fp8_communication = fp8_communication

    else:
        raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
@@ -347,6 +349,19 @@
        # wrap the model with PyTorch FSDP
        fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)

+        if self.fp8_communication:
+            from colossalai.quantization.utils import patch_fsdp_params_comm_hook
+
+            patch_fsdp_params_comm_hook()
+
+            from colossalai.quantization.fp8 import fp8_compress_fsdp_params_comm_hook
+
+            fsdp_model.module.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook)
+
+            from colossalai.quantization.fp8 import fp8_compress_fsdp_grad_comm_hook
+
+            fsdp_model.module.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook)
+
        if optimizer is not None:
            if len(optimizer.param_groups) > 1:
                warnings.warn(

View File

@@ -15,6 +15,7 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -
        scale: scaling factor for fp8 casting. If it is None, then it is computed automatically. Per-channel scaling
        is applied if input tensor is 2 dimension, otherwise, per-tensor scaling is applied.
        fp8_format: e4m3 or e5m2
+
    Returns:
        Tuples: A tuple (fp8_tensor, scale)
    """
@@ -29,12 +30,13 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -
        per_channel_max = inp.abs().max(dim=-1).values.float()
        per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
        scale = fp8_max / per_channel_max[:, None]
+        scale_inv = per_channel_max / fp8_max
    else:
        per_tensor_max = inp.abs().max().float()
        per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
        scale = fp8_max / per_tensor_max
        scale_inv = 1.0 / scale

    ret = (scale * inp.float()).to(fp8_type)
    return ret, scale_inv
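As a quick reference, here is a minimal round trip through these two helpers, mirroring the unit test added later in this PR (shape, dtype, device, and tolerances are illustrative):

import torch

from colossalai.quantization.fp8 import cast_from_fp8, cast_to_fp8

x = torch.rand(64, 32, dtype=torch.bfloat16, device="cuda")
x_fp8, scale_inv = cast_to_fp8(x, fp8_format="e4m3")  # quantize to fp8 and return the inverse scale
x_back = cast_from_fp8(x_fp8, scale_inv, x.dtype)     # dequantize back to the original dtype
torch.testing.assert_close(x_back, x, rtol=0.1, atol=0.1)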
@@ -185,7 +187,11 @@ def cast_to_fp8_pipeline(inp: Any) -> None:
        return
    assert "hidden_states" in inp, "required by pipeline parallelism."
+    assert (
+        inp["hidden_states"].size(-1) % 2 == 0
+    ), "tensor size(-1) must be divisible by 2 to view Float8_e4m3fn as BFloat16 or Float16"
    inp_tensor = inp["hidden_states"]
+    inp_dtype = inp_tensor.dtype
    min_val, max_val = inp_tensor.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs())
@@ -206,6 +212,7 @@ def cast_to_fp8_pipeline(inp: Any) -> None:
    inp_tensor.data = q_tensor.to(fp8_type).view(fp8_view_type)
    inp["fp8_scale"] = scale.float().reciprocal()
+    inp["dtype"] = torch.zeros_like(scale).to(inp_dtype)


def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
@@ -230,10 +237,11 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
    else:
        raise TypeError("Only float16, bfloat16 are implemented.")

-    inp_tensor.data = inp_tensor.data.view(fp8_type).to(torch.float16) * scale
+    inp_tensor.data = inp_tensor.data.view(fp8_type).to(inp["dtype"]) * scale

    if del_metadata:
        del inp["fp8_scale"]
+        del inp["dtype"]


def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2") -> None:
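A minimal sketch of the pipeline cast round trip, based on the unit test added in this PR (the dict layout mirrors what pipeline parallelism exchanges between stages; size(-1) must be even, and the device is illustrative):

import torch

from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline

inp = {"hidden_states": torch.rand(4, 16, dtype=torch.float16, device="cuda")}
cast_to_fp8_pipeline(inp)    # hidden_states becomes fp8 data viewed as fp16; "fp8_scale" and "dtype" metadata are added
cast_from_fp8_pipeline(inp)  # restores the original dtype and deletes the metadata keys
assert inp["hidden_states"].dtype == torch.float16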
@@ -273,6 +281,199 @@ def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2") -> None:
    output.data = summed_out


def fp8_compress_ddp_grad_comm_hook_async(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket,
    fp8_format: str = "e5m2",
) -> torch.futures.Future[torch.Tensor]:
    """
    Compress by casting ``GradBucket`` to FP8 floating-point format divided by process group size.

    This DDP communication hook implements a simple gradient compression approach that casts the ``GradBucket`` tensor
    to FP8 floating-point format (``torch.float8_e5m2`` or ``torch.float8_e4m3fn``), and then divides it
    by the process group size.
    Once the compressed gradient tensors are all-reduced, the chained callback ``decompress`` casts them back
    to the input data type (such as ``float32``).

    Example::
        >>> ddp_model.register_comm_hook(process_group, fp8_compress_ddp_grad_comm_hook_async)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD

    input_tensor = bucket.buffer()
    world_size = dist.get_world_size()
    input_type = input_tensor.dtype
    input_device = input_tensor.device
    flat_padded_x = input_tensor.flatten()

    if flat_padded_x.size(0) % world_size != 0:
        pad_size = world_size - flat_padded_x.size(0) % world_size
        flat_padded_x = F.pad(flat_padded_x, (0, pad_size))

    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    ret, scale = cast_to_fp8(flat_padded_x, fp8_format=fp8_format)

    inp = ret.view(torch.uint8)
    output_chunks_single = torch.empty_like(inp)
    split_sizes = [inp.numel() // world_size for _ in range(world_size)]
    fut0 = dist.all_to_all_single(
        output_chunks_single,
        inp,
        output_split_sizes=split_sizes,
        input_split_sizes=split_sizes,
        group=group_to_use,
        async_op=True,
    ).get_future()

    scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
    fut1 = dist.all_gather_into_tensor(
        torch.cat(scale_list, dim=0), scale, group=group_to_use, async_op=True
    ).get_future()
    all_to_all_fut = torch.futures.collect_all([fut0, fut1])

    def sum_and_allgather(fut):
        output_chunks_single = fut.value()[0].wait()[0]
        scale_list_single = fut.value()[1].wait()[0]

        output_chunks = list(torch.chunk(output_chunks_single, world_size, dim=0))
        scale_list = scale_list_single.chunk(world_size, dim=0)

        summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
        for scale, out in zip(scale_list, output_chunks):
            out = out.view(fp8_type)
            summed_out += cast_from_fp8(out, scale, input_type)
        summed_out.div_(world_size)

        summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)

        tensor_list_single = torch.empty(summed_out_fp8.size(0) * world_size, device=input_device, dtype=torch.uint8)
        fut2 = dist.all_gather_into_tensor(
            tensor_list_single, summed_out_fp8.view(torch.uint8), group=group_to_use, async_op=True
        ).get_future()

        scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
        fut3 = dist.all_gather_into_tensor(
            torch.cat(scale_list, dim=0), scale, group=group_to_use, async_op=True
        ).get_future()
        fut_combined2 = torch.futures.collect_all([fut2, fut3])
        return fut_combined2

    def decompress(fut):
        tensor_list_single = fut.value().wait()[0].value()[0]
        scale_list_single = fut.value().wait()[1].value()[0]

        tensor_list = list(torch.chunk(tensor_list_single, world_size, dim=0))
        scale_list = scale_list_single.chunk(world_size, dim=0)

        for i in range(world_size):
            tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
        out = torch.cat(tensor_list, dim=0)

        input_tensor_size = input_tensor.numel()
        input_shape = input_tensor.shape
        out = out[:input_tensor_size]
        input_tensor.copy_(out.view(input_shape).to(input_type))
        return input_tensor

    return all_to_all_fut.then(sum_and_allgather).then(decompress)


def fp8_compress_ddp_grad_comm_hook_sync(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket,
    fp8_format="e5m2",
) -> torch.futures.Future[torch.Tensor]:
    """
    Return a future that wraps the input, after the input is all-reduced. However, the all-reduce communication is synchronous.
    This breaks the overlapping between all-reduce communication and backward computation.

    This hook should **only** be used for debugging purposes, instead of the normal gradient synchronization.
    For an asynchronous implementation, use fp8_compress_ddp_grad_comm_hook_async instead.

    Example::
        >>> # xdoctest: +SKIP
        >>> ddp_model.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_sync)
    """
    buffer = bucket.buffer()
    all_reduce_fp8(buffer, fp8_format=fp8_format)

    fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
    fut.set_result(bucket.buffer())

    return fut


def fp8_compress_fsdp_grad_comm_hook(
    state: object,
    unsharded_gradient_flattened: torch.Tensor,
    sharded_gradient: torch.Tensor,
    group=None,
    fp8_format="e5m2",
) -> None:
    """
    This communication hook implements a simple gradient compression approach that casts the ``unsharded_gradient_flattened``
    tensor to FP8 floating-point format (``torch.float8_e5m2`` or ``torch.float8_e4m3fn``), and then performs the
    reduce-scatter logic using ``all_to_all`` and ``all_gather`` among the process group.

    Example::
        >>> fsdp_model.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook)
    """
    grad = unsharded_gradient_flattened
    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    input_type = grad.dtype
    input_device = grad.device
    world_size = dist.get_world_size(group=group)

    grad_fp8, scale = cast_to_fp8(grad, fp8_format=fp8_format)
    uint8_buffer = torch.empty_like(grad_fp8).view(torch.uint8)
    dist.all_to_all_single(uint8_buffer, grad_fp8.view(torch.uint8), group=group)

    scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
    dist.all_gather(scale_list, scale, group=group)

    buffer_list = list(torch.chunk(uint8_buffer.view(fp8_type), world_size, dim=0))
    sharded_gradient.zero_()
    for tensor, scale in zip(buffer_list, scale_list):
        sharded_gradient += cast_from_fp8(tensor, scale, input_type)


def fp8_compress_fsdp_params_comm_hook(
    state: object,
    padded_unsharded_flat_param: torch.Tensor,
    sharded_flat_param: torch.Tensor,
    group=None,
    fp8_format="e5m2",
) -> None:
    """
    This hook is pending official support for a parameter communication hook in FSDP (e.g. ``register_params_comm_hook``).

    Example::
        >>> fsdp_model.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook)
    """
    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    fp8_max = torch.finfo(fp8_type).max
    inp = sharded_flat_param
    out = padded_unsharded_flat_param

    per_tensor_max = inp.abs().max().float()
    per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
    dist.all_reduce(per_tensor_max, op=torch.distributed.ReduceOp.MAX, group=group)

    scale = fp8_max / per_tensor_max
    fp8_sharded_flat_param = (scale * inp.float()).to(fp8_type).view(torch.uint8)

    fp8_out = torch.empty(out.shape, dtype=torch.uint8, device=out.device)
    dist.all_gather_into_tensor(
        fp8_out,
        fp8_sharded_flat_param,
        group=group,
    )
    padded_unsharded_flat_param.copy_((fp8_out.view(fp8_type).float() / scale).to(out.dtype))


def split_chunk_by_channel(
    chunk: torch.Tensor, channel_size: int, num_channels: int, rank: int = 0, world_size: int = 1
):

@@ -342,7 +543,7 @@ def all_gather_into_tensor_flat_fp8(
        scale_inv = 1.0 / scale
    buffer = torch.empty_like(output_tensor, dtype=fp8_type)
    dist.all_gather_into_tensor(buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group)
-    numel = np.prod(output_shape)
+    numel = output_shape.numel()
    valid_buffer = buffer[:numel].reshape(output_shape)
    valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2))
    output_tensor[:numel].copy_(valid_buffer.view(-1))
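For orientation, a minimal sketch of how all_gather_into_tensor_flat_fp8 is invoked, mirroring the Chunk call site later in this diff (every rank must call it collectively with equal-sized shards; sizes and device are illustrative and torch.distributed is assumed to be initialized):

import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8

world_size = dist.get_world_size()
shard = torch.rand(1024, device="cuda")                  # this rank's flat shard
output = torch.empty(1024 * world_size, device="cuda")   # flat buffer receiving the gathered result
all_gather_into_tensor_flat_fp8(output, shard, output.shape, dist.group.WORLD)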

View File

@@ -0,0 +1,112 @@
import torch
import torch.distributed as dist
from packaging import version
from torch import Tensor
from torch.distributed.fsdp._common_utils import _no_dispatch_record_stream
from torch.distributed.utils import _p_assert


def _all_gather_flat_param(
    self,
    padded_unsharded_flat_param: Tensor,
) -> Tensor:
    """
    All-gather the handle's flat parameter to the destination ``padded_unsharded_flat_param``.
    Then switch to use the all-gathered tensor.
    """
    _p_assert(
        hasattr(self, "process_group") and hasattr(self, "world_size"),
        "Expects a process group and world size to have been set via `shard()`",
    )
    sharded_flat_param = self.flat_param.data
    expected_numel = sharded_flat_param.numel() * self.world_size
    _p_assert(
        padded_unsharded_flat_param.numel() == expected_numel,
        f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
    )

    pg = self._fake_process_group if self._use_fake_all_gather else self.process_group

    # HACK this should be handled by C10D
    if sharded_flat_param.is_cpu:  # type: ignore[attr-defined]
        tensor_list = list(torch.chunk(padded_unsharded_flat_param, dist.get_world_size(pg)))
        work = dist.all_gather(tensor_list, sharded_flat_param, group=pg)
    else:
        if self._comm_hook is None:
            dist.all_gather_into_tensor(
                padded_unsharded_flat_param,
                sharded_flat_param,
                pg,
            )
        else:
            self._comm_hook(None, padded_unsharded_flat_param, sharded_flat_param, pg)

    if self._offload_params:
        # In case of offloading, `flat_param.data` (i.e. sharded param) is
        # created on the pre-unshard stream. We need to hand it over to the
        # unshard stream for all-gather
        _no_dispatch_record_stream(
            sharded_flat_param,
            self._device_handle.current_stream(),  # unshard_stream
        )
    return padded_unsharded_flat_param


def register_params_comm_hook(self, state: object, hook: callable):
    """Register a communication hook for FlatParamHandle.

    This is an enhancement that provides a flexible hook to users where they can specify how FSDP unshards
    parameters across multiple workers.

    .. warning ::
        FSDP communication hook should be registered before running an initial forward pass
        and only once.

    Args:
        state (object): Passed to the hook to maintain any state information during the training process.
        hook (Callable): Callable, which has one of the following signatures:
                        1) ``hook: Callable[torch.Tensor] -> None``:
                           This function takes in a Python tensor, which represents
                           the full, flattened, unsharded gradient with respect to all variables
                           corresponding to the model this FSDP unit is wrapping
                           (that are not wrapped by other FSDP sub-units).
                           It then performs all necessary processing and returns ``None``;
                        2) ``hook: Callable[torch.Tensor, torch.Tensor] -> None``:
                           This function takes in two Python tensors, the first one represents
                           the full, flattened, unsharded gradient with respect to all variables
                           corresponding to the model this FSDP unit is wrapping
                           (that are not wrapped by other FSDP sub-units). The latter
                           represents a pre-sized tensor to store a chunk of a sharded gradient after
                           reduction.
                        In both cases, callable performs all necessary processing and returns ``None``.
                        Callables with signature 1 are expected to handle gradient communication for a `NO_SHARD` case.
                        Callables with signature 2 are expected to handle gradient communication for sharded cases.
    """
    if not self.check_is_root():
        raise AssertionError("register_comm_hook can only be called on a root instance.")

    # if fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES:
    #     raise AssertionError(
    #         f"Communication hook is not supported for hybrid strategies: {fsdp_state.sharding_strategy}"
    #     )
    if self._handle._comm_hook is not None:
        raise AssertionError("A communication hook is already registered")
    if not callable(hook):
        raise ValueError(f"The communication hook must be callable but got {hook}")
    self._handle._comm_hook = hook
    self._handle._comm_hook_state = state


def patch_fsdp_params_comm_hook():
    if version.parse(torch.__version__) >= version.parse("2.2.0"):
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.fsdp._flat_param import FlatParamHandle

        FlatParamHandle._comm_hook = None
        FlatParamHandle._comm_hook_state = None
        FlatParamHandle._all_gather_flat_param = _all_gather_flat_param
        FSDP.register_params_comm_hook = register_params_comm_hook
    else:
        raise RuntimeError("This fsdp_params_comm_hook patch is not supported while torch version under 2.2.0.")
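For reference, a minimal sketch of how the patch and the FP8 hooks fit together on a plain FSDP model, mirroring the unit test added in this PR (requires torch >= 2.2.0, at least two GPUs, and an initialized process group):

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from colossalai.quantization.fp8 import fp8_compress_fsdp_grad_comm_hook, fp8_compress_fsdp_params_comm_hook
from colossalai.quantization.utils import patch_fsdp_params_comm_hook

patch_fsdp_params_comm_hook()  # monkey-patch FSDP so that register_params_comm_hook becomes available

fsdp_model = FSDP(nn.Linear(128, 128).cuda())
fsdp_model.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook)            # fp8 gradient reduce-scatter
fsdp_model.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook)   # fp8 parameter all-gather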

View File

@@ -166,6 +166,7 @@ class Chunk:
        self.grad_chunk = None
        # the async all-reduce/reduce-scatter work of this grad chunk (None means sync)
        self.grad_reduce_work = None
+        self.fp8_communication = False

    @property
    def memory_usage(self) -> Dict[str, int]:
@@ -521,6 +522,14 @@
            alloc_storage(self.cuda_global_chunk)
            assert self.cuda_global_chunk.is_contiguous()

-            work = dist.all_gather_into_tensor(
-                self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op
-            )
+            if self.fp8_communication:
+                assert async_op == False, "fp8 all-gather does not support async_op!"
+                from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8
+
+                work = all_gather_into_tensor_flat_fp8(
+                    self.cuda_global_chunk, self.cuda_shard, self.cuda_global_chunk.shape, self.torch_pg
+                )
+            else:
+                work = dist.all_gather_into_tensor(
+                    self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op
+                )

View File

@@ -26,6 +26,7 @@ class ChunkManager:
        init_device: Optional[torch.device] = None,
        reuse_fp16_chunk: bool = True,
        max_prefetch: int = 0,
+        fp8_communication: bool = False,
    ) -> None:
        self.device = init_device or get_accelerator().get_current_device()
        self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
@@ -44,6 +45,7 @@
        self.accumulating_grads = False
        self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())
        self._prefetch_stream = get_accelerator().Stream() if max_prefetch else None
+        self.fp8_communication = fp8_communication

    def register_tensor(
        self,
@@ -101,6 +103,8 @@
                extra_dp_group=extra_dp_group,
                **chunk_kwargs,
            )
+            if self.fp8_communication:
+                chunk.fp8_communication = True
            chunk_group.append(chunk)

        chunk.append_tensor(tensor)

View File

@@ -98,6 +98,7 @@ class GeminiDDP(ModelWrapper):
        extra_dp_group: Optional[ProcessGroup] = None,
        verbose: bool = False,
        enable_async_reduce: bool = True,
+        fp8_communication: bool = False,
    ) -> None:
        assert mixed_precision in (torch.float16, torch.bfloat16)
        reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
@@ -122,6 +123,8 @@
            verbose=verbose,
            max_prefetch=max_prefetch,
        )
+        if fp8_communication:
+            self.chunk_manager.fp8_communication = True
        self.gemini_manager = GeminiManager(
            placement_policy,
            self.chunk_manager,

View File

@@ -179,7 +179,7 @@ def main():
        "--plugin",
        type=str,
        default="torch_ddp",
-        choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel"],
+        choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel", "torch_fsdp"],
        help="plugin to use",
    )
    parser.add_argument(
@@ -215,9 +215,9 @@ def main():
    if args.plugin == "torch_ddp_fp16":
        booster_kwargs["mixed_precision"] = "fp16"
    if args.plugin.startswith("torch_ddp"):
-        plugin = TorchDDPPlugin()
+        plugin = TorchDDPPlugin(fp8_communication=args.use_fp8_comm)
    elif args.plugin == "gemini":
-        plugin = GeminiPlugin(initial_scale=2**5)
+        plugin = GeminiPlugin(initial_scale=2**5, fp8_communication=args.use_fp8_comm)
    elif args.plugin == "low_level_zero":
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    elif args.plugin == "hybrid_parallel":
@@ -235,6 +235,17 @@ def main():
            initial_scale=1,
            fp8_communication=args.use_fp8_comm,
        )
+    elif args.plugin == "torch_fsdp":
+        from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
+
+        from colossalai.booster.plugin import TorchFSDPPlugin
+
+        plugin = TorchFSDPPlugin(
+            mixed_precision=MixedPrecision(
+                param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
+            ),
+            fp8_communication=args.use_fp8_comm,
+        )

    booster = Booster(plugin=plugin, **booster_kwargs)

View File

@@ -212,7 +212,7 @@ def main():
    if args.plugin == "torch_ddp_fp16":
        booster_kwargs["mixed_precision"] = "fp16"
    if args.plugin.startswith("torch_ddp"):
-        plugin = TorchDDPPlugin()
+        plugin = TorchDDPPlugin(fp8_communication=args.use_fp8_comm)
    elif args.plugin == "gemini":
        plugin = GeminiPlugin(initial_scale=2**5)
    elif args.plugin == "low_level_zero":

View File

@@ -98,7 +98,7 @@ def main():
    parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
    parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
    parser.add_argument("--no_cache", action="store_true")
-    parser.add_argument("--overlap_allgather", action="store_true")
+    parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
    args = parser.parse_args()

    colossalai.launch_from_torch()
@@ -158,6 +158,7 @@ def main():
                    buffer_dtype=torch.float16,
                ),
                param_init_fn=empty_init(),
+                fp8_communication=args.use_fp8_comm,
            )
        else:
            plugin = TorchFSDPPlugin(
@@ -165,7 +166,8 @@ def main():
                    param_dtype=torch.float16,
                    reduce_dtype=torch.float16,
                    buffer_dtype=torch.float16,
-                )
+                ),
+                fp8_communication=args.use_fp8_comm,
            )
    elif args.plugin == "fsdp_cpu":
        if use_empty_init:
@@ -177,6 +179,7 @@ def main():
                ),
                cpu_offload=CPUOffload(offload_params=True),
                param_init_fn=empty_init(),
+                fp8_communication=args.use_fp8_comm,
            )
        else:
            plugin = TorchFSDPPlugin(
@@ -186,6 +189,7 @@ def main():
                    buffer_dtype=torch.float16,
                ),
                cpu_offload=CPUOffload(offload_params=True),
+                fp8_communication=args.use_fp8_comm,
            )
    elif args.plugin == "3d":
        plugin = HybridParallelPlugin(
@@ -200,9 +204,9 @@ def main():
            enable_flash_attention=args.xformers,
            microbatch_size=args.mbs,
            precision="bf16",
+            dp_outside=False,
            overlap_p2p=args.overlap,
            enable_metadata_cache=not args.no_cache,
-            overlap_allgather=args.overlap_allgather,
            **hybrid_kwargs,
        )
    elif args.plugin == "3d_cpu":
@@ -293,7 +297,7 @@ def main():
    with get_profile_context(
        args.profile,
        args.ignore_steps,
-        1,  # avoid creating massive log files
+        len(dataloader) - 1,
        save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
    ) as prof:
        if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:

View File

@@ -0,0 +1,26 @@
import torch
from torch.testing import assert_close

from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import cast_from_fp8, cast_from_fp8_pipeline, cast_to_fp8, cast_to_fp8_pipeline
from colossalai.testing import parameterize


@parameterize("shape", [(100, 10), (10, 100), (3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)])
@parameterize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@parameterize("fp8_format", ["e4m3", "e5m2"])
def test_fp8_cast(shape, dtype, fp8_format):
    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    ret, scale_inv = cast_to_fp8(x, fp8_format=fp8_format)
    out = cast_from_fp8(ret, scale_inv, x.dtype)
    assert_close(out, x, rtol=0.1, atol=0.1)

    if x.size(-1) % 2 == 0:
        inp_dict = {"hidden_states": x.clone()}
        cast_to_fp8_pipeline(inp_dict)
        cast_from_fp8_pipeline(inp_dict)
        assert_close(inp_dict["hidden_states"], x, rtol=0.1, atol=0.1)


if __name__ == "__main__":
    test_fp8_cast()

View File

@@ -0,0 +1,87 @@
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

# example modified from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    def get_grads_after_one_iteration(hook=None):
        torch.manual_seed(0)
        # create model and move it to GPU with id rank
        model = ToyModel().to(rank)

        ddp_model = DDP(model, device_ids=[rank])

        if hook is not None:
            ddp_model.register_comm_hook(None, hook)

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10))
        labels = torch.randn(20, 5).to(rank)
        loss_fn(outputs, labels).backward()
        optimizer.step()

        torch.distributed.barrier()

        grad_dict = {}
        for name, params in ddp_model.named_parameters():
            grad_dict[name] = params.grad
        return grad_dict

    from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async, fp8_compress_ddp_grad_comm_hook_sync

    grad_dict = get_grads_after_one_iteration()
    for hook in [fp8_compress_ddp_grad_comm_hook_sync, fp8_compress_ddp_grad_comm_hook_async]:
        grad_dict_w_hook = get_grads_after_one_iteration(hook)
        if dist.get_rank() == 0:
            for name in grad_dict:
                assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1)

    cleanup()


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    world_size = n_gpus
    run_demo(demo_basic, world_size)

View File

@@ -0,0 +1,107 @@
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from packaging import version
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing import assert_close

from colossalai import launch
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

# example modified from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(100, 100)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(100, 50)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


@parameterize("mode", ["grad", "params"])
def run_model(mode):
    rank = dist.get_rank()

    from colossalai.quantization.utils import patch_fsdp_params_comm_hook

    patch_fsdp_params_comm_hook()

    def get_grads_after_one_iteration(grad_hook=None, params_hook=None):
        torch.manual_seed(0)
        # create model and move it to GPU with id rank
        model = ToyModel().to(rank)
        fsdp_model = FSDP(model)

        if grad_hook is not None:
            fsdp_model.register_comm_hook(None, grad_hook)

        if params_hook is not None:
            fsdp_model.register_params_comm_hook(None, params_hook)

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(fsdp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        outputs = fsdp_model(torch.randn(20, 100))
        labels = torch.randn(20, 50).to(rank)
        loss_fn(outputs, labels).backward()
        optimizer.step()

        torch.distributed.barrier()

        grad_dict = {}
        for name, params in fsdp_model.named_parameters():
            grad_dict[name] = params.grad
        return grad_dict

    from colossalai.quantization.fp8 import fp8_compress_fsdp_grad_comm_hook, fp8_compress_fsdp_params_comm_hook

    if mode == "grad":
        grad_dict = get_grads_after_one_iteration()
        for hook in [
            fp8_compress_fsdp_grad_comm_hook,
        ]:
            grad_dict_w_hook = get_grads_after_one_iteration(grad_hook=hook)
            if dist.get_rank() == 0:
                for name in grad_dict:
                    assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1)
    elif mode == "params":
        grad_dict = get_grads_after_one_iteration()
        for hook in [
            fp8_compress_fsdp_params_comm_hook,
        ]:
            grad_dict_w_hook = get_grads_after_one_iteration(params_hook=hook)
            if dist.get_rank() == 0:
                for name in grad_dict:
                    assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1)
    else:
        raise NotImplementedError


def demo_basic(rank, world_size, port):
    print(f"Running basic FSDP example on rank {rank}.")
    launch(rank=rank, world_size=world_size, port=port, host="localhost")
    run_model()
    cleanup()


@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("2.2.0"), reason="torch version < 2.2.0.")
@rerun_if_address_is_in_use()
def test_fsdp():
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    spawn(demo_basic, n_gpus)


if __name__ == "__main__":
    test_fsdp()