mirror of https://github.com/hpcaitech/ColossalAI
[fp8] optimize all-gather (#6043)
* [fp8] optimize all-gather
* [fp8] fix all gather fp8 ring
* [fp8] enable compile
* [fp8] fix all gather fp8 ring
parent c650a906db
commit c3b5caff0e
@@ -8,6 +8,7 @@ from packaging.version import Version
 from torch.distributed import ReduceOp

 SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0")
+SCALE_BYTES = 4


 class Handle:
@@ -22,7 +23,9 @@ class Handle:
         self.remain_ops()


-def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
+def cast_to_fp8(
+    inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False, out=None
+) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
     Args:
@@ -55,12 +58,15 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -
         scale = fp8_max / per_tensor_max
         scale_inv = 1.0 / scale

-    ret = (scale * inp.float()).to(fp8_type)
+    if out is not None:
+        ret = torch.mul(scale, inp.float(), out=out)
+    else:
+        ret = (scale * inp.float()).to(fp8_type)
     return ret, torch.unsqueeze(scale_inv, dim=0)


 def cast_from_fp8(
-    inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype, per_channel_scale=False
+    inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype, per_channel_scale=False, out=None
 ) -> torch.Tensor:
     r"""
     Args:
@@ -74,7 +80,13 @@ def cast_from_fp8
         raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.")

     if per_channel_scale:
-        ret = scale_inv[:, None] * inp.float()
+        if out is not None:
+            return torch.mul(scale_inv[:, None], inp.float(), out=out)
+        else:
+            ret = scale_inv[:, None] * inp.float()
     else:
-        ret = scale_inv * inp.float()
+        if out is not None:
+            return torch.mul(scale_inv, inp.float(), out=out)
+        else:
+            ret = scale_inv * inp.float()
     return ret.to(ret_type)
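For context, the new out= parameters let callers hand both cast functions preallocated destination tensors, which is what the all-gather path added below relies on to write fp8 data straight into a packed communication buffer. A minimal round-trip sketch, assuming a CUDA device with float8 dtypes and the module path colossalai.quantization.fp8; the shape, dtype, and tolerance are illustrative only:

# Round-trip a bf16 tensor through fp8 using the new out= buffers (sketch).
import torch
from colossalai.quantization.fp8 import cast_to_fp8, cast_from_fp8

x = torch.randn(1024, device="cuda", dtype=torch.bfloat16)

# Preallocated fp8 destination: cast_to_fp8 writes into it via torch.mul(..., out=...).
fp8_buf = torch.empty_like(x, dtype=torch.float8_e4m3fn)
fp8_x, scale_inv = cast_to_fp8(x, fp8_format="e4m3", out=fp8_buf)

# Preallocated destination in the original dtype: when out is given,
# cast_from_fp8 returns early and the result lands in out's dtype.
y = torch.empty_like(x)
cast_from_fp8(fp8_x, scale_inv, x.dtype, out=y)

# e4m3 keeps only 3 mantissa bits, so compare with a loose tolerance.
assert torch.allclose(y.float(), x.float(), rtol=0.13, atol=0.13)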
@@ -664,6 +676,90 @@ def gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: boo
         cast_op()


+@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
+def all_gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
+    world_size = dist.get_world_size(group)
+    shape = input_.shape
+    input_type = input_.dtype
+    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
+
+    combined_buffer = torch.empty(world_size * (SCALE_BYTES + input_.numel()), dtype=torch.uint8, device=input_.device)
+    combined_buffers = list(combined_buffer.chunk(world_size, dim=0))
+    cur_buffer = combined_buffers[dist.get_rank(group)]
+    ret = cur_buffer[SCALE_BYTES:].view(fp8_type)
+    ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret)
+    cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale
+    # cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8)
+    dist.all_gather(combined_buffers, cur_buffer, group=group, async_op=async_op)
+    for out, buf in zip(output_list, combined_buffers):
+        scale = buf[:SCALE_BYTES].clone().view(scale.dtype)
+        output = buf[SCALE_BYTES:].view(fp8_type)
+        cast_from_fp8(output.view(shape), scale, input_type, out=out)
+    # output = combined_buffer.view(world_size, -1)[:, SCALE_BYTES:].view(fp8_type)
+    # scales = combined_buffer.view(world_size, -1)[:, :SCALE_BYTES].view(torch.float)
+    # output = output.float() * scales
+    # for i, out in enumerate(output_list):
+    #     out.copy_(output[i].view(shape))
+
+
+@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
+def all_gather_fp8_ring(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
+    world_size = dist.get_world_size(group)
+    rank = dist.get_rank(group)
+
+    send_rank = (rank + 1) % world_size
+    recv_rank = (rank - 1) % world_size
+
+    shape = input_.shape
+    input_type = input_.dtype
+    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
+
+    combined_buffer = torch.empty(world_size * (SCALE_BYTES + input_.numel()), dtype=torch.uint8, device=input_.device)
+    combined_buffers = list(combined_buffer.chunk(world_size, dim=0))
+    cur_buffer = combined_buffers[dist.get_rank(group)]
+    ret = cur_buffer[SCALE_BYTES:].view(fp8_type)
+    ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret)
+    # cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8)
+    cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale
+
+    def send_recv(idx):
+        send_idx = (rank - idx) % world_size
+        recv_idx = (rank - idx - 1) % world_size
+        ops = dist.batch_isend_irecv(
+            [
+                dist.P2POp(dist.isend, combined_buffers[send_idx], send_rank, group=group),
+                dist.P2POp(dist.irecv, combined_buffers[recv_idx], recv_rank, group=group),
+            ]
+        )
+        return ops
+
+    def cast(idx):
+        cast_idx = (rank - idx - 1) % world_size
+        scale = combined_buffers[cast_idx][:SCALE_BYTES].clone().view(torch.float)
+        output = combined_buffers[cast_idx][SCALE_BYTES:].view(fp8_type)
+        cast_from_fp8(output.view(shape), scale, input_type, out=output_list[cast_idx])
+
+    # warmup
+    ops = send_recv(0)
+    output_list[rank].copy_(input_)
+    for op in ops:
+        op.wait()
+    ops = []
+
+    # 1p-1c
+    for i in range(1, world_size - 1):
+        new_ops = send_recv(i)
+        for op in ops:
+            op.wait()
+        cast(i - 1)
+        ops = new_ops
+
+    # cooldown
+    for op in ops:
+        op.wait()
+    cast(world_size - 2)
+
+
 class _LinearFp8(torch.autograd.Function):
     @staticmethod
     def forward(
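As a usage sketch (not part of the patch): all_gather_fp8 is a drop-in for dist.all_gather on float tensors, with output_list holding world_size preallocated tensors shaped like input_. Each rank packs a 4-byte float32 scale (SCALE_BYTES) in front of its fp8 payload inside one uint8 buffer, so a single all-gather moves scales and data together, and each received chunk is dequantized straight into its output tensor. A hypothetical torchrun-launched script, assuming NCCL, one CUDA device per rank, a PyTorch recent enough for the torch.compile wrapper, and the module path colossalai.quantization.fp8:

# Launch with: torchrun --nproc_per_node=<gpus> this_script.py  (illustrative)
import torch
import torch.distributed as dist
from colossalai.quantization.fp8 import all_gather_fp8

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())

# Every rank contributes a shard filled with its own rank id.
x = torch.full((1024,), float(rank), device="cuda", dtype=torch.bfloat16)
outputs = [torch.empty_like(x) for _ in range(world_size)]

# fp8 all-gather: cast once, gather the packed uint8 buffers, dequantize per chunk.
all_gather_fp8(outputs, x, fp8_format="e4m3", async_op=False)

for r, out in enumerate(outputs):
    assert torch.allclose(out, torch.full_like(out, float(r)), rtol=1e-2, atol=1e-2)

dist.destroy_process_group()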
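all_gather_fp8_ring replaces the single collective with world_size - 1 point-to-point steps around a ring: each step forwards the most recently received chunk to the next rank while dequantizing the chunk that arrived in the previous step, so the cast kernels hide behind the transfers (the warmup / 1p-1c / cooldown phases in the code above). A small single-process sketch of just the index schedule, with no communication; world_size = 4 is an arbitrary example value:

# Which chunk each rank sends, receives, and casts at every ring step
# (mirrors send_recv(idx) / cast(idx) in all_gather_fp8_ring).
world_size = 4  # example value

for rank in range(world_size):
    send_rank = (rank + 1) % world_size
    recv_rank = (rank - 1) % world_size
    print(f"rank {rank} -> sends to {send_rank}, receives from {recv_rank}")
    for i in range(world_size - 1):
        send_idx = (rank - i) % world_size      # chunk forwarded this step
        recv_idx = (rank - i - 1) % world_size  # chunk arriving this step
        if i == 0:
            overlap = f"copy own chunk {rank} into output"      # warmup
        else:
            overlap = f"cast chunk {(rank - i) % world_size}"   # 1p-1c
        print(f"  step {i}: send chunk {send_idx}, recv chunk {recv_idx}, overlapped work: {overlap}")
    # after the last transfer completes, the final received chunk is cast
    print(f"  cooldown: cast chunk {(rank + 1) % world_size}")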