mirror of https://github.com/hpcaitech/ColossalAI
[fp8] optimize all-gather (#6043)
* [fp8] optimize all-gather * [fp8] fix all gather fp8 ring * [fp8] enable compile * [fp8] fix all gather fp8 ringpull/6045/head
parent
c650a906db
commit
c3b5caff0e
|
@ -8,6 +8,7 @@ from packaging.version import Version
|
|||
from torch.distributed import ReduceOp
|
||||
|
||||
SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0")
|
||||
SCALE_BYTES = 4
|
||||
|
||||
|
||||
class Handle:
|
||||
|
@ -22,7 +23,9 @@ class Handle:
|
|||
self.remain_ops()
|
||||
|
||||
|
||||
def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def cast_to_fp8(
|
||||
inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False, out=None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
|
||||
Args:
|
||||
|
@ -55,12 +58,15 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -
|
|||
scale = fp8_max / per_tensor_max
|
||||
scale_inv = 1.0 / scale
|
||||
|
||||
ret = (scale * inp.float()).to(fp8_type)
|
||||
if out is not None:
|
||||
ret = torch.mul(scale, inp.float(), out=out)
|
||||
else:
|
||||
ret = (scale * inp.float()).to(fp8_type)
|
||||
return ret, torch.unsqueeze(scale_inv, dim=0)
|
||||
|
||||
|
||||
def cast_from_fp8(
|
||||
inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype, per_channel_scale=False
|
||||
inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype, per_channel_scale=False, out=None
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
|
@ -74,9 +80,15 @@ def cast_from_fp8(
|
|||
raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.")
|
||||
|
||||
if per_channel_scale:
|
||||
ret = scale_inv[:, None] * inp.float()
|
||||
if out is not None:
|
||||
return torch.mul(scale_inv[:, None], inp.float(), out=out)
|
||||
else:
|
||||
ret = scale_inv[:, None] * inp.float()
|
||||
else:
|
||||
ret = scale_inv * inp.float()
|
||||
if out is not None:
|
||||
return torch.mul(scale_inv, inp.float(), out=out)
|
||||
else:
|
||||
ret = scale_inv * inp.float()
|
||||
return ret.to(ret_type)
|
||||
|
||||
|
||||
|
@ -664,6 +676,90 @@ def gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: boo
|
|||
cast_op()
|
||||
|
||||
|
||||
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
|
||||
def all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
|
||||
world_size = dist.get_world_size(group)
|
||||
shape = input_.shape
|
||||
input_type = input_.dtype
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
|
||||
combined_buffer = torch.empty(world_size * (SCALE_BYTES + input_.numel()), dtype=torch.uint8, device=input_.device)
|
||||
combined_buffers = list(combined_buffer.chunk(world_size, dim=0))
|
||||
cur_buffer = combined_buffers[dist.get_rank(group)]
|
||||
ret = cur_buffer[SCALE_BYTES:].view(fp8_type)
|
||||
ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret)
|
||||
cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale
|
||||
# cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8)
|
||||
dist.all_gather(combined_buffers, cur_buffer, group=group, async_op=async_op)
|
||||
for out, buf in zip(output_list, combined_buffers):
|
||||
scale = buf[:SCALE_BYTES].clone().view(scale.dtype)
|
||||
output = buf[SCALE_BYTES:].view(fp8_type)
|
||||
cast_from_fp8(output.view(shape), scale, input_type, out=out)
|
||||
# output = combined_buffer.view(world_size, -1)[:, SCALE_BYTES:].view(fp8_type)
|
||||
# scales = combined_buffer.view(world_size, -1)[:, :SCALE_BYTES].view(torch.float)
|
||||
# output = output.float() * scales
|
||||
# for i, out in enumerate(output_list):
|
||||
# out.copy_(output[i].view(shape))
|
||||
|
||||
|
||||
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
|
||||
def all_gather_fp8_ring(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
|
||||
world_size = dist.get_world_size(group)
|
||||
rank = dist.get_rank(group)
|
||||
|
||||
send_rank = (rank + 1) % world_size
|
||||
recv_rank = (rank - 1) % world_size
|
||||
|
||||
shape = input_.shape
|
||||
input_type = input_.dtype
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
|
||||
combined_buffer = torch.empty(world_size * (SCALE_BYTES + input_.numel()), dtype=torch.uint8, device=input_.device)
|
||||
combined_buffers = list(combined_buffer.chunk(world_size, dim=0))
|
||||
cur_buffer = combined_buffers[dist.get_rank(group)]
|
||||
ret = cur_buffer[SCALE_BYTES:].view(fp8_type)
|
||||
ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret)
|
||||
# cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8)
|
||||
cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale
|
||||
|
||||
def send_recv(idx):
|
||||
send_idx = (rank - idx) % world_size
|
||||
recv_idx = (rank - idx - 1) % world_size
|
||||
ops = dist.batch_isend_irecv(
|
||||
[
|
||||
dist.P2POp(dist.isend, combined_buffers[send_idx], send_rank, group=group),
|
||||
dist.P2POp(dist.irecv, combined_buffers[recv_idx], recv_rank, group=group),
|
||||
]
|
||||
)
|
||||
return ops
|
||||
|
||||
def cast(idx):
|
||||
cast_idx = (rank - idx - 1) % world_size
|
||||
scale = combined_buffers[cast_idx][:SCALE_BYTES].clone().view(torch.float)
|
||||
output = combined_buffers[cast_idx][SCALE_BYTES:].view(fp8_type)
|
||||
cast_from_fp8(output.view(shape), scale, input_type, out=output_list[cast_idx])
|
||||
|
||||
# warmup
|
||||
ops = send_recv(0)
|
||||
output_list[rank].copy_(input_)
|
||||
for op in ops:
|
||||
op.wait()
|
||||
ops = []
|
||||
|
||||
# 1p-1c
|
||||
for i in range(1, world_size - 1):
|
||||
new_ops = send_recv(i)
|
||||
for op in ops:
|
||||
op.wait()
|
||||
cast(i - 1)
|
||||
ops = new_ops
|
||||
|
||||
# cooldown
|
||||
for op in ops:
|
||||
op.wait()
|
||||
cast(world_size - 2)
|
||||
|
||||
|
||||
class _LinearFp8(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
|
|
Loading…
Reference in New Issue