mirror of https://github.com/hpcaitech/ColossalAI
[fp8] disable all_to_all_fp8 in intranode (#6045)
* enhance all_to_all_fp8 with internode comm control * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * disable some fp8 ops due to performance issue * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/6053/head
parent
26e553937b
commit
5ce6dd75bf
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
@ -23,6 +24,24 @@ class Handle:
|
|||
self.remain_ops()
|
||||
|
||||
|
||||
def process_group_is_intranode(pg):
|
||||
if pg is None:
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
|
||||
pg = _get_default_group()
|
||||
|
||||
local_world_size = None
|
||||
for var in ["LOCAL_WORLD_SIZE", "OMPI_COMM_WORLD_LOCAL_SIZE", "SLURM_TASKS_PER_NODE"]:
|
||||
if var in os.environ:
|
||||
local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
|
||||
if local_world_size is None:
|
||||
local_world_size = torch.cuda.device_count()
|
||||
|
||||
group_ranks = dist.get_process_group_ranks(pg)
|
||||
group_ranks_node_ids = [rank // local_world_size for rank in group_ranks]
|
||||
return min(group_ranks_node_ids) == max(group_ranks_node_ids)
|
||||
|
||||
|
||||
def cast_to_fp8(
|
||||
inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False, out=None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
@ -92,7 +111,7 @@ def cast_from_fp8(
|
|||
return ret.to(ret_type)
|
||||
|
||||
|
||||
def all_reduce_fp8(
|
||||
def _all_reduce_fp8(
|
||||
tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
|
||||
) -> Optional[Handle]:
|
||||
r"""
|
||||
|
@ -159,7 +178,15 @@ def all_reduce_fp8(
|
|||
cat_op()
|
||||
|
||||
|
||||
def all_to_all_single_fp8(
|
||||
def all_reduce_fp8(
|
||||
tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
|
||||
) -> Optional[Handle]:
|
||||
# fall back to default op due to performance issue
|
||||
return dist.all_reduce(tensor, op=op, group=group, async_op=async_op)
|
||||
|
||||
|
||||
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
|
||||
def _all_to_all_single_fp8(
|
||||
output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
|
||||
) -> Optional[Handle]:
|
||||
r"""
|
||||
|
@ -222,6 +249,33 @@ def all_to_all_single_fp8(
|
|||
cast_op()
|
||||
|
||||
|
||||
def all_to_all_single_fp8(
|
||||
output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
|
||||
) -> Optional[Handle]:
|
||||
r"""
|
||||
This is wrapper for _all_to_all_single_fp8.
|
||||
"""
|
||||
if process_group_is_intranode(group):
|
||||
return dist.all_to_all_single(
|
||||
output,
|
||||
input,
|
||||
output_split_sizes=output_split_sizes,
|
||||
input_split_sizes=input_split_sizes,
|
||||
group=group,
|
||||
async_op=async_op,
|
||||
)
|
||||
else:
|
||||
return _all_to_all_single_fp8(
|
||||
output,
|
||||
input,
|
||||
fp8_format=fp8_format,
|
||||
output_split_sizes=output_split_sizes,
|
||||
input_split_sizes=input_split_sizes,
|
||||
group=group,
|
||||
async_op=async_op,
|
||||
)
|
||||
|
||||
|
||||
def cast_to_fp8_pipeline(inp: Any) -> None:
|
||||
"""
|
||||
Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline.
|
||||
|
@ -293,7 +347,7 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
|
|||
del inp["dtype"]
|
||||
|
||||
|
||||
def reduce_scatter_fp8(
|
||||
def _reduce_scatter_fp8(
|
||||
output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False
|
||||
) -> Optional[Handle]:
|
||||
r"""
|
||||
|
@ -338,6 +392,13 @@ def reduce_scatter_fp8(
|
|||
cast_op()
|
||||
|
||||
|
||||
def reduce_scatter_fp8(
|
||||
output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False
|
||||
) -> Optional[Handle]:
|
||||
# fall back to default op due to performance issue
|
||||
return dist.reduce_scatter(output, input_list, group=group, async_op=async_op)
|
||||
|
||||
|
||||
def fp8_compress_ddp_grad_comm_hook_async(
|
||||
process_group: dist.ProcessGroup,
|
||||
bucket: dist.GradBucket,
|
||||
|
@ -617,10 +678,9 @@ def all_gather_into_tensor_flat_fp8(
|
|||
cast_op()
|
||||
|
||||
|
||||
def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
|
||||
|
||||
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
|
||||
def _all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
|
||||
world_size = dist.get_world_size(group)
|
||||
|
||||
input_type = input_list[0].dtype
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
scale_list = []
|
||||
|
@ -651,6 +711,13 @@ def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async
|
|||
cast_op()
|
||||
|
||||
|
||||
def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
|
||||
if process_group_is_intranode(group):
|
||||
return dist.all_to_all(output_list, input_list, group=group, async_op=async_op)
|
||||
else:
|
||||
return _all_to_all_fp8(output_list, input_list, group=group, fp8_format=fp8_format, async_op=async_op)
|
||||
|
||||
|
||||
def gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
|
||||
|
||||
world_size = dist.get_world_size(group)
|
||||
|
|
Loading…
Reference in New Issue