mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
847 lines
32 KiB
847 lines
32 KiB
import os
|
|
from typing import Any, Optional, Tuple
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.nn.functional as F
|
|
from packaging.version import Version
|
|
from torch.distributed import ReduceOp
|
|
|
|
from .fp8_config import dynamic_kernel
|
|
|
|
SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0")
|
|
SCALE_BYTES = 4
|
|
try:
|
|
cuda_arch = int("".join(str(i) for i in torch.cuda.get_device_capability()))
|
|
except:
|
|
cuda_arch = 0
|
|
|
|
|
|
class Handle:
|
|
def __init__(self, handles=[], remain_ops=None) -> None:
|
|
self.handles = handles
|
|
self.remain_ops = remain_ops
|
|
|
|
def wait(self):
|
|
for handle in self.handles:
|
|
handle.wait()
|
|
if self.remain_ops:
|
|
self.remain_ops()
|
|
|
|
|
|
def process_group_is_intranode(pg):
|
|
if pg is None:
|
|
from torch.distributed.distributed_c10d import _get_default_group
|
|
|
|
pg = _get_default_group()
|
|
|
|
local_world_size = None
|
|
for var in ["LOCAL_WORLD_SIZE", "OMPI_COMM_WORLD_LOCAL_SIZE", "SLURM_TASKS_PER_NODE"]:
|
|
if var in os.environ:
|
|
local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
|
|
if local_world_size is None:
|
|
local_world_size = torch.cuda.device_count()
|
|
|
|
group_ranks = dist.get_process_group_ranks(pg)
|
|
group_ranks_node_ids = [rank // local_world_size for rank in group_ranks]
|
|
return min(group_ranks_node_ids) == max(group_ranks_node_ids)
|
|
|
|
|
|
def cast_to_fp8(
|
|
inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False, out=None
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
r"""
|
|
casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
|
|
Args:
|
|
inp: input torch Tensor, should be in torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor.
|
|
scale: scaling factor for fp8 casting. If it is None, then it is computed automatically. Per-channel scaling
|
|
is applied if input tensor is 2 dimension, otherwise, per-tensor scaling is applied.
|
|
fp8_format: e4m3 or e5m2
|
|
|
|
Returns:
|
|
Tuples: A tuple (fp8_tensor, scale)
|
|
"""
|
|
|
|
if inp.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
|
|
raise TypeError("Only float16, bfloat16, and float32 are allowed.")
|
|
|
|
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
|
fp8_max = torch.finfo(fp8_type).max
|
|
|
|
if inp.numel() == 0:
|
|
return inp.to(fp8_type), torch.tensor([1.0], device=inp.device)
|
|
else:
|
|
if per_channel_scale:
|
|
per_channel_max = inp.abs().max(dim=-1).values.float()
|
|
per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
|
|
scale = fp8_max / per_channel_max[:, None]
|
|
scale_inv = per_channel_max / fp8_max
|
|
else:
|
|
per_tensor_max = inp.abs().max().float()
|
|
per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
|
|
scale = fp8_max / per_tensor_max
|
|
scale_inv = 1.0 / scale
|
|
|
|
if out is not None:
|
|
ret = torch.mul(scale, inp.float(), out=out)
|
|
else:
|
|
ret = (scale * inp.float()).to(fp8_type)
|
|
return ret, torch.unsqueeze(scale_inv, dim=0)
|
|
|
|
|
|
def cast_from_fp8(
|
|
inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype, per_channel_scale=False, out=None
|
|
) -> torch.Tensor:
|
|
r"""
|
|
Args:
|
|
inp: should be a fp8 torch tensor in one of the types: [torch.float8_e4m3fn, torch.float8_e5m2].
|
|
scale: scaling factor returned by cast_to_fp8 function.
|
|
ret_type: the datatype of the returned tensor.
|
|
Returns:
|
|
torch.Tensor
|
|
"""
|
|
if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
|
raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.")
|
|
|
|
if per_channel_scale:
|
|
if out is not None:
|
|
return torch.mul(scale_inv[:, None], inp.float(), out=out)
|
|
else:
|
|
ret = scale_inv[:, None] * inp.float()
|
|
else:
|
|
if out is not None:
|
|
return torch.mul(scale_inv, inp.float(), out=out)
|
|
else:
|
|
ret = scale_inv * inp.float()
|
|
return ret.to(ret_type)
|
|
|
|
|
|
def _all_reduce_fp8(
|
|
tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
|
|
) -> Optional[Handle]:
|
|
r"""
|
|
This is an in-place operation for compressed all_reduce using fp8.
|
|
It works like dist.all_reduce but during communication the data is cast to fp8 format.
|
|
|
|
Args:
|
|
tensor: torch.Tensor in fp32, fp16, bf16 datatype.
|
|
fp8_format: e4m3 or e5m2
|
|
op: ReduceOp.SUM or ReduceOp.AVG
|
|
|
|
Returns:
|
|
None
|
|
"""
|
|
|
|
world_size = dist.get_world_size(group=group)
|
|
input_type = tensor.dtype
|
|
input_shape = tensor.shape
|
|
input_device = tensor.device
|
|
input_size = tensor.numel()
|
|
flat_padded_x = tensor.flatten()
|
|
|
|
assert op in [ReduceOp.SUM, ReduceOp.AVG], "op can only be ReduceOp.SUM or ReduceOp.AVG"
|
|
|
|
if flat_padded_x.size(0) % world_size != 0:
|
|
pad_size = world_size - flat_padded_x.size(0) % world_size
|
|
flat_padded_x = F.pad(flat_padded_x, (0, pad_size))
|
|
|
|
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
|
ret, scale = cast_to_fp8(flat_padded_x, fp8_format=fp8_format)
|
|
|
|
inp = ret.view(torch.uint8)
|
|
input_chunks = list(torch.chunk(inp, world_size, dim=0))
|
|
output_chunks = list(torch.chunk(torch.empty_like(inp), world_size, dim=0))
|
|
dist.all_to_all(output_chunks, input_chunks, group=group)
|
|
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
|
|
dist.all_gather(scale_list, scale, group=group)
|
|
summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
|
|
|
|
for scale, out in zip(scale_list, output_chunks):
|
|
out = out.view(fp8_type)
|
|
summed_out += cast_from_fp8(out, scale, input_type)
|
|
|
|
if op == ReduceOp.AVG:
|
|
summed_out.div_(world_size)
|
|
|
|
summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
|
|
gather_scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
|
|
|
|
tensor_list = [torch.empty_like(summed_out_fp8.view(torch.uint8)) for _ in range(world_size)]
|
|
gather_tensor_handle = dist.all_gather(
|
|
tensor_list, summed_out_fp8.view(torch.uint8), group=group, async_op=async_op
|
|
)
|
|
|
|
def cat_op():
|
|
for i in range(world_size):
|
|
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
|
|
out = torch.cat(tensor_list, dim=0)
|
|
tensor.copy_(out[:input_size].view(input_shape).to(input_type))
|
|
|
|
if async_op:
|
|
return Handle([gather_scale_handle, gather_tensor_handle], cat_op)
|
|
else:
|
|
cat_op()
|
|
|
|
|
|
def all_reduce_fp8(
|
|
tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
|
|
) -> Optional[Handle]:
|
|
# fall back to default op due to performance issue
|
|
return dist.all_reduce(tensor, op=op, group=group, async_op=async_op)
|
|
|
|
|
|
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89)
|
|
def _all_to_all_single_fp8(
|
|
output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
|
|
) -> Optional[Handle]:
|
|
r"""
|
|
This is an in-place operation for compressed all_reduce using fp8.
|
|
It works like dist.all_to_all_single but during communication the data is cast to fp8 format.
|
|
Args:
|
|
tensor: torch.Tensor in fp32, fp16, bf16 datatype.
|
|
fp8_format: e4m3 or e5m2
|
|
Returns:
|
|
None
|
|
"""
|
|
world_size = dist.get_world_size(group=group)
|
|
input_type = input.dtype
|
|
input_shape = input.shape
|
|
input_device = input.device
|
|
input = input.flatten()
|
|
|
|
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
|
|
|
ret, scale = cast_to_fp8(input, fp8_format=fp8_format)
|
|
|
|
inp = ret.view(torch.uint8)
|
|
if input_split_sizes is not None:
|
|
input_split_sizes = [input_split_sizes[i] * np.prod(input_shape[1:]) for i in range(world_size)]
|
|
input_chunks = list(torch.split(inp, input_split_sizes))
|
|
else:
|
|
input_chunks = list(torch.chunk(inp, world_size, dim=0))
|
|
|
|
if output_split_sizes is not None:
|
|
output_chunks = [
|
|
torch.empty((output_split_sizes[i] * np.prod(input_shape[1:]),), device=input_device, dtype=inp.dtype)
|
|
for i in range(world_size)
|
|
]
|
|
else:
|
|
if dist.get_rank() == world_size - 1:
|
|
output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)]
|
|
else:
|
|
output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]
|
|
|
|
chunk_handle = dist.all_to_all(output_chunks, input_chunks, group=group, async_op=async_op)
|
|
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
|
|
scale_hanle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
|
|
|
|
def cast_op():
|
|
cast_output_chunk = [
|
|
cast_from_fp8(out.view(fp8_type), scale, input_type) for scale, out in zip(scale_list, output_chunks)
|
|
]
|
|
|
|
tensor_out = torch.cat(cast_output_chunk, dim=0)
|
|
outputs_shape = list(input_shape)
|
|
if output_split_sizes is not None:
|
|
outputs_shape[0] = sum(output_split_sizes)
|
|
else:
|
|
outputs_shape = input_shape
|
|
output.data = tensor_out.view(outputs_shape).to(input_type)
|
|
|
|
if async_op:
|
|
return Handle([chunk_handle, scale_hanle], cast_op)
|
|
else:
|
|
cast_op()
|
|
|
|
|
|
def all_to_all_single_fp8(
|
|
output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
|
|
) -> Optional[Handle]:
|
|
r"""
|
|
This is wrapper for _all_to_all_single_fp8.
|
|
"""
|
|
if process_group_is_intranode(group):
|
|
return dist.all_to_all_single(
|
|
output,
|
|
input,
|
|
output_split_sizes=output_split_sizes,
|
|
input_split_sizes=input_split_sizes,
|
|
group=group,
|
|
async_op=async_op,
|
|
)
|
|
else:
|
|
return _all_to_all_single_fp8(
|
|
output,
|
|
input,
|
|
fp8_format=fp8_format,
|
|
output_split_sizes=output_split_sizes,
|
|
input_split_sizes=input_split_sizes,
|
|
group=group,
|
|
async_op=async_op,
|
|
)
|
|
|
|
|
|
def cast_to_fp8_pipeline(inp: Any) -> None:
|
|
"""
|
|
Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline.
|
|
The activations tensor is indexed by 'hidden_states' in the inp dict.
|
|
After FP8 casting, the resulting tensor is saved as float16 or bfloat16 format but the size becomes halved.
|
|
Metadata such as fp8_scale is saved into inp dict for communication.
|
|
"""
|
|
if inp is None:
|
|
return
|
|
# In pipeline parallelism, when inp is torch.Tensor, it only contains one element, thus can be omitted.
|
|
if type(inp) == torch.Tensor:
|
|
return
|
|
|
|
assert "hidden_states" in inp, "required by pipeline parallelism."
|
|
assert (
|
|
inp["hidden_states"].size(-1) % 2 == 0
|
|
), "tensor size(-1) must be divisible by 2 to view Float8_e4m3fn as BFloat16 or Float16"
|
|
inp_tensor = inp["hidden_states"]
|
|
inp_dtype = inp_tensor.dtype
|
|
|
|
min_val, max_val = inp_tensor.aminmax()
|
|
amax = torch.maximum(min_val.abs(), max_val.abs())
|
|
|
|
finfo = torch.finfo(torch.float8_e4m3fn)
|
|
if amax > finfo.max:
|
|
fp8_type = torch.float8_e5m2
|
|
fp8_view_type = torch.float16
|
|
else:
|
|
fp8_type = torch.float8_e4m3fn
|
|
fp8_view_type = torch.bfloat16
|
|
|
|
finfo = torch.finfo(fp8_type)
|
|
scale = torch.tensor(1.0).to(inp_tensor.device) if amax == 0.0 else finfo.max / amax.float()
|
|
q_tensor = inp_tensor.data.float() * scale
|
|
# Todo: Currently we use fp8_view_type <float16, bfloat16> to indicate which fp8 format is used. This is a temporary workaround due to 'Only support tensor for fast send'.
|
|
# inp_tensor needs to be a float datatype to avoid error during gradient placement.
|
|
inp_tensor.data = q_tensor.to(fp8_type).view(fp8_view_type)
|
|
|
|
inp["fp8_scale"] = scale.float().reciprocal()
|
|
inp["dtype"] = torch.zeros_like(scale).to(inp_dtype)
|
|
|
|
|
|
def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
|
|
"""
|
|
Cast the FP8 encoded hidden_states tensor back to original dtype after p2p communication in pipeline.
|
|
del_metadata = False is useful when this function is called before p2p communication.
|
|
"""
|
|
if inp is None:
|
|
return
|
|
if type(inp) == torch.Tensor:
|
|
return
|
|
|
|
assert "hidden_states" in inp, "required by pipeline parallelism."
|
|
inp_tensor = inp["hidden_states"]
|
|
scale = inp["fp8_scale"]
|
|
|
|
fp8_view_type = inp_tensor.dtype
|
|
if fp8_view_type == torch.float16:
|
|
fp8_type = torch.float8_e5m2
|
|
elif fp8_view_type == torch.bfloat16:
|
|
fp8_type = torch.float8_e4m3fn
|
|
else:
|
|
raise TypeError("Only float16, bfloat16 are implemented.")
|
|
|
|
inp_tensor.data = inp_tensor.data.view(fp8_type).to(inp["dtype"]) * scale
|
|
|
|
if del_metadata:
|
|
del inp["fp8_scale"]
|
|
del inp["dtype"]
|
|
|
|
|
|
def _reduce_scatter_fp8(
|
|
output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False
|
|
) -> Optional[Handle]:
|
|
r"""
|
|
This is an in-place operation for compressed reduce_scatter using fp8.
|
|
It works like dist.reduce_scatter but during communication the data is cast to fp8 format.
|
|
|
|
Args:
|
|
tensor: torch.Tensor in fp32, fp16, bf16 datatype.
|
|
fp8_format: e4m3 or e5m2
|
|
|
|
Returns:
|
|
None
|
|
"""
|
|
|
|
input_type = output.dtype
|
|
|
|
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
|
scale_list = []
|
|
cast_input_list = []
|
|
output_chunks = []
|
|
output_scale_list = []
|
|
for input in input_list:
|
|
ret, scale = cast_to_fp8(input, fp8_format=fp8_format)
|
|
scale_list.append(scale)
|
|
ret = ret.view(torch.uint8)
|
|
cast_input_list.append(ret)
|
|
output_chunks.append(torch.empty_like(ret))
|
|
output_scale_list.append(torch.empty_like(scale))
|
|
chunk_handle = dist.all_to_all(output_chunks, cast_input_list, group=group, async_op=async_op)
|
|
scale_handle = dist.all_to_all(output_scale_list, scale_list, group=group, async_op=async_op)
|
|
|
|
def cast_op():
|
|
summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
|
|
for scale, out in zip(output_scale_list, output_chunks):
|
|
out = out.view(fp8_type)
|
|
summed_out += cast_from_fp8(out, scale, input_type)
|
|
output.data = summed_out
|
|
|
|
if async_op:
|
|
return Handle([chunk_handle, scale_handle], cast_op)
|
|
else:
|
|
cast_op()
|
|
|
|
|
|
def reduce_scatter_fp8(
|
|
output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False
|
|
) -> Optional[Handle]:
|
|
# fall back to default op due to performance issue
|
|
return dist.reduce_scatter(output, input_list, group=group, async_op=async_op)
|
|
|
|
|
|
def fp8_compress_ddp_grad_comm_hook_async(
|
|
process_group: dist.ProcessGroup,
|
|
bucket: dist.GradBucket,
|
|
fp8_format: str = "e5m2",
|
|
) -> torch.futures.Future[torch.Tensor]:
|
|
"""
|
|
Compress by casting ``GradBucket`` to FP8 floating-point format divided by process group size.
|
|
|
|
This DDP communication hook implements a simple gradient compression approach that casts ``GradBucket`` tensor
|
|
to FP8 floating-point format (``torch.float8_e5m2`` or ``torch.bfloat16_e4m3``), and then divides it
|
|
by the process group size.
|
|
Once compressed gradient tensors are allreduced, the chained callback ``decompress`` casts it back
|
|
to the input data type (such as ``float32``).
|
|
|
|
Example::
|
|
>>> ddp_model.register_comm_hook(process_group, fp8_compress_ddp_grad_comm_hook_async)
|
|
"""
|
|
group_to_use = process_group if process_group is not None else dist.group.WORLD
|
|
|
|
input_tensor = bucket.buffer()
|
|
world_size = dist.get_world_size()
|
|
input_type = input_tensor.dtype
|
|
input_device = input_tensor.device
|
|
flat_padded_x = input_tensor.flatten()
|
|
|
|
if flat_padded_x.size(0) % world_size != 0:
|
|
pad_size = world_size - flat_padded_x.size(0) % world_size
|
|
flat_padded_x = F.pad(flat_padded_x, (0, pad_size))
|
|
|
|
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
|
ret, scale = cast_to_fp8(flat_padded_x, fp8_format=fp8_format)
|
|
|
|
inp = ret.view(torch.uint8)
|
|
output_chunks_single = torch.empty_like(inp)
|
|
split_sizes = [inp.numel() // world_size for _ in range(world_size)]
|
|
fut0 = dist.all_to_all_single(
|
|
output_chunks_single,
|
|
inp,
|
|
output_split_sizes=split_sizes,
|
|
input_split_sizes=split_sizes,
|
|
group=group_to_use,
|
|
async_op=True,
|
|
).get_future()
|
|
|
|
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
|
|
fut1 = dist.all_gather_into_tensor(
|
|
torch.cat(scale_list, dim=0), scale, group=group_to_use, async_op=True
|
|
).get_future()
|
|
all_to_all_fut = torch.futures.collect_all([fut0, fut1])
|
|
|
|
def sum_and_allgather(fut):
|
|
output_chunks_single = fut.value()[0].wait()[0]
|
|
scale_list_single = fut.value()[1].wait()[0]
|
|
|
|
output_chunks = list(torch.chunk(output_chunks_single, world_size, dim=0))
|
|
scale_list = scale_list_single.chunk(world_size, dim=0)
|
|
|
|
summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
|
|
for scale, out in zip(scale_list, output_chunks):
|
|
out = out.view(fp8_type)
|
|
summed_out += cast_from_fp8(out, scale, input_type)
|
|
summed_out.div_(world_size)
|
|
|
|
summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
|
|
|
|
tensor_list_single = torch.empty(summed_out_fp8.size(0) * world_size, device=input_device, dtype=torch.uint8)
|
|
fut2 = dist.all_gather_into_tensor(
|
|
tensor_list_single, summed_out_fp8.view(torch.uint8), group=group_to_use, async_op=True
|
|
).get_future()
|
|
|
|
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
|
|
fut3 = dist.all_gather_into_tensor(
|
|
torch.cat(scale_list, dim=0), scale, group=group_to_use, async_op=True
|
|
).get_future()
|
|
fut_combined2 = torch.futures.collect_all([fut2, fut3])
|
|
return fut_combined2
|
|
|
|
def decompress(fut):
|
|
tensor_list_single = fut.value().wait()[0].value()[0]
|
|
scale_list_single = fut.value().wait()[1].value()[0]
|
|
|
|
tensor_list = list(torch.chunk(tensor_list_single, world_size, dim=0))
|
|
scale_list = scale_list_single.chunk(world_size, dim=0)
|
|
|
|
for i in range(world_size):
|
|
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
|
|
out = torch.cat(tensor_list, dim=0)
|
|
|
|
input_tensor_size = input_tensor.numel()
|
|
input_shape = input_tensor.shape
|
|
out = out[:input_tensor_size]
|
|
|
|
input_tensor.copy_(out.view(input_shape).to(input_type))
|
|
return input_tensor
|
|
|
|
return all_to_all_fut.then(sum_and_allgather).then(decompress)
|
|
|
|
|
|
def fp8_compress_ddp_grad_comm_hook_sync(
|
|
process_group: dist.ProcessGroup,
|
|
bucket: dist.GradBucket,
|
|
fp8_format="e5m2",
|
|
) -> torch.futures.Future[torch.Tensor]:
|
|
"""
|
|
Return a future that wraps the input, after the input is allreduced. However, the allreduce commnunication is synchronized.
|
|
This breaks the overlapping between allreduce communication and backward compuation.
|
|
|
|
This hook should **only** be used for debugging purposes, instead of the normal gradient synchronization.
|
|
For asynchronized implementation, use fp8_compress_ddp_grad_comm_hook_async instead.
|
|
|
|
Example::
|
|
>>> # xdoctest: +SKIP
|
|
>>> ddp_model.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_sync)
|
|
"""
|
|
|
|
buffer = bucket.buffer()
|
|
all_reduce_fp8(buffer, fp8_format=fp8_format)
|
|
|
|
fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
|
|
fut.set_result(bucket.buffer())
|
|
|
|
return fut
|
|
|
|
|
|
def fp8_compress_fsdp_grad_comm_hook(
|
|
state: object,
|
|
unsharded_gradient_flattened: torch.Tensor,
|
|
sharded_gradient: torch.Tensor,
|
|
group=None,
|
|
fp8_format="e5m2",
|
|
) -> None:
|
|
"""
|
|
This communication hook implements a simple gradient compression approach that casts unsharded_gradient_flattened tensor
|
|
to FP8 floating-point format (``torch.float8_e5m2`` or ``torch.bfloat16_e4m3``), and then perform scatter_allreduce logic
|
|
by using all_to_all and all_gather among the process group.
|
|
|
|
Example::
|
|
>>> fsdp_model.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook)
|
|
"""
|
|
grad = unsharded_gradient_flattened
|
|
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
|
input_type = grad.dtype
|
|
input_device = grad.device
|
|
world_size = dist.get_world_size(group=group)
|
|
|
|
grad_fp8, scale = cast_to_fp8(grad, fp8_format=fp8_format)
|
|
uint8_buffer = torch.empty_like(grad_fp8).view(torch.uint8)
|
|
dist.all_to_all_single(uint8_buffer, grad_fp8.view(torch.uint8), group=group)
|
|
|
|
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
|
|
dist.all_gather(scale_list, scale, group=group)
|
|
|
|
buffer_list = list(torch.chunk(uint8_buffer.view(fp8_type), world_size, dim=0))
|
|
sharded_gradient.zero_()
|
|
for tensor, scale in zip(buffer_list, scale_list):
|
|
sharded_gradient += cast_from_fp8(tensor, scale, input_type)
|
|
|
|
|
|
def fp8_compress_fsdp_params_comm_hook(
|
|
state: object,
|
|
padded_unsharded_flat_param: torch.Tensor,
|
|
sharded_flat_param: torch.Tensor,
|
|
group=None,
|
|
fp8_format="e5m2",
|
|
) -> None:
|
|
"""
|
|
This hook is pending the official support for parameters communication hook in FSDP, e.g. register_params_comm_hook.
|
|
|
|
Example::
|
|
>>> fsdp_model.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook)
|
|
"""
|
|
|
|
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
|
fp8_max = torch.finfo(fp8_type).max
|
|
inp = sharded_flat_param
|
|
out = padded_unsharded_flat_param
|
|
|
|
per_tensor_max = inp.abs().max().float()
|
|
per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
|
|
dist.all_reduce(per_tensor_max, op=torch.distributed.ReduceOp.MAX, group=group)
|
|
|
|
scale = fp8_max / per_tensor_max
|
|
fp8_sharded_flat_param = (scale * inp.float()).to(fp8_type).view(torch.uint8)
|
|
|
|
fp8_out = torch.empty(out.shape, dtype=torch.uint8, device=out.device)
|
|
dist.all_gather_into_tensor(
|
|
fp8_out,
|
|
fp8_sharded_flat_param,
|
|
group=group,
|
|
)
|
|
padded_unsharded_flat_param.copy_((fp8_out.view(fp8_type).float() / scale).to(out.dtype))
|
|
|
|
|
|
def split_chunk_by_channel(
|
|
chunk: torch.Tensor, channel_size: int, num_channels: int, rank: int = 0, world_size: int = 1
|
|
):
|
|
offset = chunk.numel() * rank
|
|
end = offset + chunk.numel()
|
|
break_points = [x for x in range(0, channel_size * num_channels + 1, channel_size) if offset <= x <= end]
|
|
if len(break_points) == 0 or break_points[0] > offset:
|
|
break_points.insert(0, offset)
|
|
if break_points[-1] < end:
|
|
break_points.append(end)
|
|
sizes = [b - a for a, b in zip(break_points[:-1], break_points[1:])]
|
|
return chunk.split(sizes)
|
|
|
|
|
|
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89)
|
|
def _all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
|
|
world_size = dist.get_world_size(group)
|
|
input_type = input_list[0].dtype
|
|
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
|
scale_list = []
|
|
tensor_list = []
|
|
|
|
for i in range(world_size):
|
|
input_tensor = input_list[i]
|
|
ret, scale = cast_to_fp8(input_tensor, fp8_format=fp8_format)
|
|
scale_list.append(scale)
|
|
ret = ret.view(torch.uint8)
|
|
tensor_list.append(ret)
|
|
|
|
output_scale_list = [torch.empty_like(x) for x in scale_list]
|
|
output_tensor_list = [torch.empty_like(x) for x in tensor_list]
|
|
tensor_hanle = dist.all_to_all(output_tensor_list, tensor_list, group=group, async_op=async_op)
|
|
scale_handle = dist.all_to_all(output_scale_list, scale_list, group=group, async_op=async_op)
|
|
|
|
def cast_op():
|
|
for i in range(world_size):
|
|
scale = output_scale_list[i]
|
|
tensor = output_tensor_list[i]
|
|
tensor = tensor.view(fp8_type)
|
|
output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))
|
|
|
|
if async_op:
|
|
return Handle([tensor_hanle, scale_handle], cast_op)
|
|
else:
|
|
cast_op()
|
|
|
|
|
|
def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
|
|
if process_group_is_intranode(group):
|
|
return dist.all_to_all(output_list, input_list, group=group, async_op=async_op)
|
|
else:
|
|
return _all_to_all_fp8(output_list, input_list, group=group, fp8_format=fp8_format, async_op=async_op)
|
|
|
|
|
|
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89)
|
|
def _all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
|
|
world_size = dist.get_world_size(group)
|
|
|
|
input_type = input_.dtype
|
|
ret, scale = cast_to_fp8(input_, fp8_format=fp8_format)
|
|
fp8_type = ret.dtype
|
|
input_ = ret.view(torch.uint8)
|
|
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
|
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)]
|
|
chunk_handle = dist.all_gather(tensor_list, input_, group=group, async_op=async_op)
|
|
scale_hanle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
|
|
|
|
def cast_op():
|
|
for i in range(world_size):
|
|
output = tensor_list[i].view(fp8_type)
|
|
scale = scale_list[i]
|
|
output_list[i].copy_(cast_from_fp8(output, scale, input_type))
|
|
|
|
if async_op:
|
|
return Handle([chunk_handle, scale_hanle], cast_op)
|
|
else:
|
|
cast_op()
|
|
|
|
|
|
def all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
|
|
if process_group_is_intranode(group):
|
|
return dist.all_gather(output_list, input_, group=group, async_op=async_op)
|
|
else:
|
|
return _all_gather_fp8(output_list, input_, group=group, fp8_format=fp8_format, async_op=async_op)
|
|
|
|
|
|
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89)
|
|
def all_gather_fp8_lagacy(
|
|
output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False
|
|
) -> Optional[Handle]:
|
|
world_size = dist.get_world_size(group)
|
|
shape = input_.shape
|
|
input_type = input_.dtype
|
|
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
|
|
|
combined_buffer = torch.empty(world_size * (SCALE_BYTES + input_.numel()), dtype=torch.uint8, device=input_.device)
|
|
combined_buffers = list(combined_buffer.chunk(world_size, dim=0))
|
|
cur_buffer = combined_buffers[dist.get_rank(group)]
|
|
ret = cur_buffer[SCALE_BYTES:].view(fp8_type)
|
|
ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret)
|
|
cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale
|
|
# cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8)
|
|
dist.all_gather(combined_buffers, cur_buffer, group=group, async_op=async_op)
|
|
for out, buf in zip(output_list, combined_buffers):
|
|
scale = buf[:SCALE_BYTES].clone().view(scale.dtype)
|
|
output = buf[SCALE_BYTES:].view(fp8_type)
|
|
cast_from_fp8(output.view(shape), scale, input_type, out=out)
|
|
# output = combined_buffer.view(world_size, -1)[:, SCALE_BYTES:].view(fp8_type)
|
|
# scales = combined_buffer.view(world_size, -1)[:, :SCALE_BYTES].view(torch.float)
|
|
# output = output.float() * scales
|
|
# for i, out in enumerate(output_list):
|
|
# out.copy_(output[i].view(shape))
|
|
|
|
|
|
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89)
|
|
def all_gather_fp8_ring(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
|
|
world_size = dist.get_world_size(group)
|
|
rank = dist.get_rank(group)
|
|
|
|
send_rank = (rank + 1) % world_size
|
|
recv_rank = (rank - 1) % world_size
|
|
|
|
shape = input_.shape
|
|
input_type = input_.dtype
|
|
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
|
|
|
combined_buffer = torch.empty(world_size * (SCALE_BYTES + input_.numel()), dtype=torch.uint8, device=input_.device)
|
|
combined_buffers = list(combined_buffer.chunk(world_size, dim=0))
|
|
cur_buffer = combined_buffers[dist.get_rank(group)]
|
|
ret = cur_buffer[SCALE_BYTES:].view(fp8_type)
|
|
ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret)
|
|
# cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8)
|
|
cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale
|
|
|
|
def send_recv(idx):
|
|
send_idx = (rank - idx) % world_size
|
|
recv_idx = (rank - idx - 1) % world_size
|
|
ops = dist.batch_isend_irecv(
|
|
[
|
|
dist.P2POp(dist.isend, combined_buffers[send_idx], send_rank, group=group),
|
|
dist.P2POp(dist.irecv, combined_buffers[recv_idx], recv_rank, group=group),
|
|
]
|
|
)
|
|
return ops
|
|
|
|
def cast(idx):
|
|
cast_idx = (rank - idx - 1) % world_size
|
|
scale = combined_buffers[cast_idx][:SCALE_BYTES].clone().view(torch.float)
|
|
output = combined_buffers[cast_idx][SCALE_BYTES:].view(fp8_type)
|
|
cast_from_fp8(output.view(shape), scale, input_type, out=output_list[cast_idx])
|
|
|
|
# warmup
|
|
ops = send_recv(0)
|
|
output_list[rank].copy_(input_)
|
|
for op in ops:
|
|
op.wait()
|
|
ops = []
|
|
|
|
# 1p-1c
|
|
for i in range(1, world_size - 1):
|
|
new_ops = send_recv(i)
|
|
for op in ops:
|
|
op.wait()
|
|
cast(i - 1)
|
|
ops = new_ops
|
|
|
|
# cooldown
|
|
for op in ops:
|
|
op.wait()
|
|
cast(world_size - 2)
|
|
|
|
|
|
class _LinearFp8(torch.autograd.Function):
|
|
@staticmethod
|
|
def forward(
|
|
ctx: Any,
|
|
x: torch.Tensor,
|
|
w: torch.Tensor,
|
|
bias: Optional[torch.Tensor],
|
|
) -> Any:
|
|
assert (
|
|
x.dtype in (torch.bfloat16, torch.float16) and x.dtype == w.dtype
|
|
), "Only float16 and bfloat16 are allowed."
|
|
if bias is not None:
|
|
assert bias.dtype == x.dtype, "Bias should have the same dtype as input."
|
|
# ensure x and w are row-major
|
|
x = x.contiguous()
|
|
w = w.contiguous()
|
|
ctx.x_shape = x.shape
|
|
ctx.has_bias = bias is not None
|
|
ctx.out_dtype = x.dtype
|
|
x = x.reshape(-1, x.shape[-1])
|
|
|
|
x_fp8, inv_scale_x = cast_to_fp8(x, fp8_format="e4m3")
|
|
w_fp8, inv_scale_w = cast_to_fp8(w, fp8_format="e4m3")
|
|
ctx.x_fp8 = x_fp8
|
|
ctx.w_fp8_t = w_fp8.t()
|
|
ctx.inv_scale_x = inv_scale_x
|
|
ctx.inv_scale_w = inv_scale_w
|
|
out = torch._scaled_mm(
|
|
x_fp8,
|
|
ctx.w_fp8_t,
|
|
bias=bias,
|
|
out_dtype=ctx.out_dtype,
|
|
scale_a=inv_scale_x,
|
|
scale_b=inv_scale_w,
|
|
use_fast_accum=True,
|
|
)[0]
|
|
return out.reshape(*ctx.x_shape[:-1], w.shape[0])
|
|
|
|
@staticmethod
|
|
def backward(ctx: Any, out_grad) -> Any:
|
|
out_grad = out_grad.reshape(-1, out_grad.shape[-1])
|
|
out_grad_fp8, out_grad_scale = cast_to_fp8(out_grad, fp8_format="e5m2")
|
|
x_grad = torch._scaled_mm(
|
|
out_grad_fp8,
|
|
ctx.w_fp8_t.contiguous().t(),
|
|
out_dtype=ctx.out_dtype,
|
|
scale_a=out_grad_scale,
|
|
scale_b=ctx.inv_scale_w,
|
|
use_fast_accum=True,
|
|
)[0]
|
|
w_grad = torch._scaled_mm(
|
|
out_grad_fp8.t().contiguous(),
|
|
ctx.x_fp8.t().contiguous().t(),
|
|
out_dtype=ctx.out_dtype,
|
|
scale_a=out_grad_scale,
|
|
scale_b=ctx.inv_scale_x,
|
|
use_fast_accum=True,
|
|
)[0]
|
|
bias_grad = None
|
|
if ctx.has_bias:
|
|
bias_grad = out_grad.sum(0)
|
|
return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
|
|
|
|
|
|
@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=dynamic_kernel)
|
|
def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
return _LinearFp8.apply(input, weight, bias)
|
|
|
|
|
|
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
if input.shape[-1] % 16 != 0 or np.prod(input.shape[:-1]) % 16 != 0:
|
|
return F.linear(input, weight, bias)
|
|
out = _linear_fp8(input, weight, bias)
|
|
return out
|