[fp8] disable all_to_all_fp8 in intranode (#6045)

* enhance all_to_all_fp8 with internode comm control

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* disable some fp8 ops due to performance issue

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Hanks 2024-09-09 13:47:17 +08:00 committed by GitHub
parent 26e553937b
commit 5ce6dd75bf
1 changed file with 73 additions and 6 deletions
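
For context, here is a minimal caller-side sketch of the behaviour this commit introduces. It is not part of the patch: the NCCL/torchrun setup and tensor shapes are assumptions made for illustration, and all_to_all_fp8 / process_group_is_intranode refer to the functions in the diff below.

import torch
import torch.distributed as dist

# Illustrative setup only; assumes a torchrun launch with one process per GPU.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
world_size = dist.get_world_size()

inputs = [torch.randn(4, 8, device="cuda") for _ in range(world_size)]
outputs = [torch.empty(4, 8, device="cuda") for _ in range(world_size)]

# With this change, all_to_all_fp8 first checks process_group_is_intranode(group):
# on a single node it falls back to the plain dist.all_to_all (no FP8 cast), and
# only cross-node groups take the FP8 compression path.
all_to_all_fp8(outputs, inputs, group=None, fp8_format="e5m2")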


@@ -1,3 +1,4 @@
import os
from typing import Any, Optional, Tuple
import numpy as np
@@ -23,6 +24,24 @@ class Handle:
self.remain_ops()
def process_group_is_intranode(pg):
    """Return True if every rank in ``pg`` lives on the same node."""
    if pg is None:
        from torch.distributed.distributed_c10d import _get_default_group
        pg = _get_default_group()
    # Infer ranks-per-node from the launcher environment, falling back to the local GPU count.
    local_world_size = None
    for var in ["LOCAL_WORLD_SIZE", "OMPI_COMM_WORLD_LOCAL_SIZE", "SLURM_TASKS_PER_NODE"]:
        if var in os.environ:
            local_world_size = int(os.environ[var])
            break
    if local_world_size is None:
        local_world_size = torch.cuda.device_count()
    group_ranks = dist.get_process_group_ranks(pg)
    group_ranks_node_ids = [rank // local_world_size for rank in group_ranks]
    return min(group_ranks_node_ids) == max(group_ranks_node_ids)
def cast_to_fp8(
inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False, out=None
) -> Tuple[torch.Tensor, torch.Tensor]:
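
For illustration, a small sketch (not part of the patch) of how the node-id check in process_group_is_intranode above behaves; the 8-ranks-per-node layout and the example rank lists are assumptions made up for the example.

# Sketch only: reproduces the node-id arithmetic used by process_group_is_intranode,
# assuming 8 ranks per node (local_world_size = 8).
local_world_size = 8

intra_ranks = [8, 9, 10, 11]                        # all ranks on node 1
intra_node_ids = [r // local_world_size for r in intra_ranks]
print(min(intra_node_ids) == max(intra_node_ids))   # True -> plain all_to_all is used

inter_ranks = [3, 11]                               # one rank on node 0, one on node 1
inter_node_ids = [r // local_world_size for r in inter_ranks]
print(min(inter_node_ids) == max(inter_node_ids))   # False -> FP8 path is used
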
@@ -92,7 +111,7 @@ def cast_from_fp8(
return ret.to(ret_type)
def all_reduce_fp8(
def _all_reduce_fp8(
tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
) -> Optional[Handle]:
r"""
@@ -159,7 +178,15 @@ def all_reduce_fp8(
cat_op()
def all_to_all_single_fp8(
def all_reduce_fp8(
    tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
) -> Optional[Handle]:
    # fall back to the default op due to performance issues
    return dist.all_reduce(tensor, op=op, group=group, async_op=async_op)
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
def _all_to_all_single_fp8(
output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
) -> Optional[Handle]:
r"""
@@ -222,6 +249,33 @@ def all_to_all_single_fp8(
cast_op()
def all_to_all_single_fp8(
    output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
) -> Optional[Handle]:
    r"""
    This is a wrapper for _all_to_all_single_fp8 that falls back to the default
    all_to_all_single for intranode process groups.
    """
    if process_group_is_intranode(group):
        return dist.all_to_all_single(
            output,
            input,
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
            group=group,
            async_op=async_op,
        )
    else:
        return _all_to_all_single_fp8(
            output,
            input,
            fp8_format=fp8_format,
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
            group=group,
            async_op=async_op,
        )
def cast_to_fp8_pipeline(inp: Any) -> None:
"""
Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline.
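
As a usage note for the all_to_all_single_fp8 wrapper above, a short sketch that is not part of the patch; it assumes an initialized process group, CUDA tensors, and made-up shapes and split sizes.

# Sketch only: caller-side view of the wrapper; semantics match dist.all_to_all_single.
world_size = dist.get_world_size()
seq = torch.randn(world_size * 16, 32, device="cuda")
out = torch.empty_like(seq)

# Intranode groups go straight to dist.all_to_all_single; internode groups are
# cast to FP8 by _all_to_all_single_fp8 before the exchange.
all_to_all_single_fp8(out, seq)

# Explicit split sizes are forwarded unchanged on both paths.
sizes = [16] * world_size
all_to_all_single_fp8(out, seq, output_split_sizes=sizes, input_split_sizes=sizes)
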
@@ -293,7 +347,7 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
del inp["dtype"]
def reduce_scatter_fp8(
def _reduce_scatter_fp8(
output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False
) -> Optional[Handle]:
r"""
@@ -338,6 +392,13 @@ def reduce_scatter_fp8(
cast_op()
def reduce_scatter_fp8(
    output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False
) -> Optional[Handle]:
    # fall back to the default op due to performance issues
    return dist.reduce_scatter(output, input_list, group=group, async_op=async_op)
def fp8_compress_ddp_grad_comm_hook_async(
process_group: dist.ProcessGroup,
bucket: dist.GradBucket,
@@ -617,10 +678,9 @@ def all_gather_into_tensor_flat_fp8(
cast_op()
def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
def _all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
    world_size = dist.get_world_size(group)
    input_type = input_list[0].dtype
    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    scale_list = []
@@ -651,6 +711,13 @@ def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async
cast_op()
def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
    if process_group_is_intranode(group):
        return dist.all_to_all(output_list, input_list, group=group, async_op=async_op)
    else:
        return _all_to_all_fp8(output_list, input_list, group=group, fp8_format=fp8_format, async_op=async_op)
def gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
world_size = dist.get_world_size(group)
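
Finally, a caller-side sketch of the fallback wrappers changed above; it is not part of the patch and assumes an initialized default process group, CUDA tensors, and made-up shapes. The public wrappers keep their signatures but now delegate to the stock collectives, so existing call sites need no changes; the fp8_format argument is simply ignored on the fallback path.

# Sketch only: the renamed _all_reduce_fp8 / _reduce_scatter_fp8 variants still exist,
# but the public wrappers route to the regular collectives due to the performance issue.
world_size = dist.get_world_size()

grad = torch.randn(1024, device="cuda")
all_reduce_fp8(grad, fp8_format="e4m3")          # -> dist.all_reduce(grad, ...)

shard = torch.empty(1024, device="cuda")
chunks = list(torch.randn(world_size * 1024, device="cuda").chunk(world_size))
reduce_scatter_fp8(shard, chunks, group=None)    # -> dist.reduce_scatter(...)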