[fp8] optimize all-gather (#6043)

* [fp8] optimize all-gather

* [fp8] fix all gather fp8 ring

* [fp8] enable compile

* [fp8] fix all gather fp8 ring
pull/6045/head
Hongxin Liu 2024-09-03 15:45:17 +08:00 committed by GitHub
parent c650a906db
commit c3b5caff0e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 101 additions and 5 deletions

View File

@ -8,6 +8,7 @@ from packaging.version import Version
from torch.distributed import ReduceOp
SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0")
SCALE_BYTES = 4
class Handle:
@ -22,7 +23,9 @@ class Handle:
self.remain_ops()
def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
def cast_to_fp8(
inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False, out=None
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
Args:
@ -55,12 +58,15 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -
scale = fp8_max / per_tensor_max
scale_inv = 1.0 / scale
ret = (scale * inp.float()).to(fp8_type)
if out is not None:
ret = torch.mul(scale, inp.float(), out=out)
else:
ret = (scale * inp.float()).to(fp8_type)
return ret, torch.unsqueeze(scale_inv, dim=0)
def cast_from_fp8(
inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype, per_channel_scale=False
inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype, per_channel_scale=False, out=None
) -> torch.Tensor:
r"""
Args:
@ -74,9 +80,15 @@ def cast_from_fp8(
raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.")
if per_channel_scale:
ret = scale_inv[:, None] * inp.float()
if out is not None:
return torch.mul(scale_inv[:, None], inp.float(), out=out)
else:
ret = scale_inv[:, None] * inp.float()
else:
ret = scale_inv * inp.float()
if out is not None:
return torch.mul(scale_inv, inp.float(), out=out)
else:
ret = scale_inv * inp.float()
return ret.to(ret_type)
@ -664,6 +676,90 @@ def gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: boo
cast_op()
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
def all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
world_size = dist.get_world_size(group)
shape = input_.shape
input_type = input_.dtype
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
combined_buffer = torch.empty(world_size * (SCALE_BYTES + input_.numel()), dtype=torch.uint8, device=input_.device)
combined_buffers = list(combined_buffer.chunk(world_size, dim=0))
cur_buffer = combined_buffers[dist.get_rank(group)]
ret = cur_buffer[SCALE_BYTES:].view(fp8_type)
ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret)
cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale
# cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8)
dist.all_gather(combined_buffers, cur_buffer, group=group, async_op=async_op)
for out, buf in zip(output_list, combined_buffers):
scale = buf[:SCALE_BYTES].clone().view(scale.dtype)
output = buf[SCALE_BYTES:].view(fp8_type)
cast_from_fp8(output.view(shape), scale, input_type, out=out)
# output = combined_buffer.view(world_size, -1)[:, SCALE_BYTES:].view(fp8_type)
# scales = combined_buffer.view(world_size, -1)[:, :SCALE_BYTES].view(torch.float)
# output = output.float() * scales
# for i, out in enumerate(output_list):
# out.copy_(output[i].view(shape))
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
def all_gather_fp8_ring(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
world_size = dist.get_world_size(group)
rank = dist.get_rank(group)
send_rank = (rank + 1) % world_size
recv_rank = (rank - 1) % world_size
shape = input_.shape
input_type = input_.dtype
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
combined_buffer = torch.empty(world_size * (SCALE_BYTES + input_.numel()), dtype=torch.uint8, device=input_.device)
combined_buffers = list(combined_buffer.chunk(world_size, dim=0))
cur_buffer = combined_buffers[dist.get_rank(group)]
ret = cur_buffer[SCALE_BYTES:].view(fp8_type)
ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret)
# cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8)
cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale
def send_recv(idx):
send_idx = (rank - idx) % world_size
recv_idx = (rank - idx - 1) % world_size
ops = dist.batch_isend_irecv(
[
dist.P2POp(dist.isend, combined_buffers[send_idx], send_rank, group=group),
dist.P2POp(dist.irecv, combined_buffers[recv_idx], recv_rank, group=group),
]
)
return ops
def cast(idx):
cast_idx = (rank - idx - 1) % world_size
scale = combined_buffers[cast_idx][:SCALE_BYTES].clone().view(torch.float)
output = combined_buffers[cast_idx][SCALE_BYTES:].view(fp8_type)
cast_from_fp8(output.view(shape), scale, input_type, out=output_list[cast_idx])
# warmup
ops = send_recv(0)
output_list[rank].copy_(input_)
for op in ops:
op.wait()
ops = []
# 1p-1c
for i in range(1, world_size - 1):
new_ops = send_recv(i)
for op in ops:
op.wait()
cast(i - 1)
ops = new_ops
# cooldown
for op in ops:
op.wait()
cast(world_size - 2)
class _LinearFp8(torch.autograd.Function):
@staticmethod
def forward(