mirror of https://github.com/hpcaitech/ColossalAI
[fp8] disable all_to_all_fp8 in intranode (#6045)
* enhance all_to_all_fp8 with internode comm control
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* disable some fp8 ops due to performance issue
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Branch: pull/6053/head
parent 26e553937b
commit 5ce6dd75bf
@@ -1,3 +1,4 @@
+import os
 from typing import Any, Optional, Tuple
 
 import numpy as np
@@ -23,6 +24,24 @@ class Handle:
             self.remain_ops()
 
 
+def process_group_is_intranode(pg):
+    if pg is None:
+        from torch.distributed.distributed_c10d import _get_default_group
+
+        pg = _get_default_group()
+
+    local_world_size = None
+    for var in ["LOCAL_WORLD_SIZE", "OMPI_COMM_WORLD_LOCAL_SIZE", "SLURM_TASKS_PER_NODE"]:
+        if var in os.environ:
+            local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
+    if local_world_size is None:
+        local_world_size = torch.cuda.device_count()
+
+    group_ranks = dist.get_process_group_ranks(pg)
+    group_ranks_node_ids = [rank // local_world_size for rank in group_ranks]
+    return min(group_ranks_node_ids) == max(group_ranks_node_ids)
+
+
 def cast_to_fp8(
     inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False, out=None
 ) -> Tuple[torch.Tensor, torch.Tensor]:
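A note on the new helper: process_group_is_intranode maps each rank to a node id by integer-dividing the rank by the local world size and treats the group as intranode when every rank lands on the same node. A minimal standalone sketch of that check (the ranks and the local world size of 4 below are made-up values, not part of this commit):

# Hypothetical topology: 2 nodes x 4 GPUs, so local_world_size = 4.
def is_intranode(ranks, local_world_size):
    node_ids = [rank // local_world_size for rank in ranks]  # node id of each rank
    return min(node_ids) == max(node_ids)                    # all ranks on one node?

print(is_intranode([4, 5, 6, 7], 4))  # True  -> wrapper uses the plain torch.distributed op
print(is_intranode([3, 4], 4))        # False -> wrapper keeps the fp8-compressed path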
@@ -92,7 +111,7 @@ def cast_from_fp8(
     return ret.to(ret_type)
 
 
-def all_reduce_fp8(
+def _all_reduce_fp8(
     tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
 ) -> Optional[Handle]:
     r"""
@@ -159,7 +178,15 @@ def all_reduce_fp8(
         cat_op()
 
 
-def all_to_all_single_fp8(
+def all_reduce_fp8(
+    tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
+) -> Optional[Handle]:
+    # fall back to default op due to performance issue
+    return dist.all_reduce(tensor, op=op, group=group, async_op=async_op)
+
+
+@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
+def _all_to_all_single_fp8(
     output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
 ) -> Optional[Handle]:
     r"""
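The renamed internal _all_to_all_single_fp8 is now compiled with mode="max-autotune-no-cudagraphs", which autotunes the generated kernels while skipping CUDA graph capture, and dynamic=False, which specializes on the shapes actually seen. A small standalone sketch of that decorator applied to a toy fp8 cast (the function and its scaling rule are illustrative assumptions, not taken from this diff):

import torch

@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
def toy_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
    # Per-tensor scale so the largest magnitude fits the fp8 range (illustrative only).
    scale = x.abs().max().clamp(min=1e-12) / 57344.0  # 57344 is the max finite value of float8_e5m2
    return (x / scale).to(torch.float8_e5m2)

x = torch.randn(16, 16)
y = toy_cast_to_fp8(x)  # first call triggers compilation, later calls reuse the compiled kernel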
@@ -222,6 +249,33 @@ def all_to_all_single_fp8(
         cast_op()
 
 
+def all_to_all_single_fp8(
+    output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
+) -> Optional[Handle]:
+    r"""
+    This is wrapper for _all_to_all_single_fp8.
+    """
+    if process_group_is_intranode(group):
+        return dist.all_to_all_single(
+            output,
+            input,
+            output_split_sizes=output_split_sizes,
+            input_split_sizes=input_split_sizes,
+            group=group,
+            async_op=async_op,
+        )
+    else:
+        return _all_to_all_single_fp8(
+            output,
+            input,
+            fp8_format=fp8_format,
+            output_split_sizes=output_split_sizes,
+            input_split_sizes=input_split_sizes,
+            group=group,
+            async_op=async_op,
+        )
+
+
 def cast_to_fp8_pipeline(inp: Any) -> None:
     """
     Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline.
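For orientation, a hedged usage sketch of the new dispatch: the caller uses the same public function for intranode and internode groups, and the wrapper picks the path. The module path colossalai.quantization.fp8 and the launch command are assumptions, not shown in this diff; the script needs an initialized NCCL process group (e.g. torchrun --nproc_per_node=4 this_script.py):

import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import all_to_all_single_fp8  # assumed module path

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
world_size = dist.get_world_size()

# 4 elements destined for each peer rank.
inp = torch.randn(world_size * 4, device="cuda")
out = torch.empty_like(inp)

# Intranode group -> plain dist.all_to_all_single; internode group -> fp8-compressed path.
all_to_all_single_fp8(out, inp, fp8_format="e5m2")

dist.destroy_process_group()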
@@ -293,7 +347,7 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
         del inp["dtype"]
 
 
-def reduce_scatter_fp8(
+def _reduce_scatter_fp8(
     output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False
 ) -> Optional[Handle]:
     r"""
@@ -338,6 +392,13 @@ def reduce_scatter_fp8(
         cast_op()
 
 
+def reduce_scatter_fp8(
+    output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False
+) -> Optional[Handle]:
+    # fall back to default op due to performance issue
+    return dist.reduce_scatter(output, input_list, group=group, async_op=async_op)
+
+
 def fp8_compress_ddp_grad_comm_hook_async(
     process_group: dist.ProcessGroup,
     bucket: dist.GradBucket,
@@ -617,10 +678,9 @@ def all_gather_into_tensor_flat_fp8(
         cast_op()
 
 
-def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
+@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
+def _all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
     world_size = dist.get_world_size(group)
 
     input_type = input_list[0].dtype
     fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
     scale_list = []
@@ -651,6 +711,13 @@ def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async
         cast_op()
 
 
+def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
+    if process_group_is_intranode(group):
+        return dist.all_to_all(output_list, input_list, group=group, async_op=async_op)
+    else:
+        return _all_to_all_fp8(output_list, input_list, group=group, fp8_format=fp8_format, async_op=async_op)
+
+
 def gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
 
     world_size = dist.get_world_size(group)
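The list-based wrapper above follows the same dispatch pattern; a hedged sketch under the same assumptions as before (assumed module path, NCCL group initialized via torchrun):

import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import all_to_all_fp8  # assumed module path

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
world_size = dist.get_world_size()

# One chunk per destination rank and one receive buffer per source rank.
input_list = [torch.randn(4, device="cuda") for _ in range(world_size)]
output_list = [torch.empty(4, device="cuda") for _ in range(world_size)]

# Intranode group -> plain dist.all_to_all; internode group -> fp8-compressed _all_to_all_fp8.
all_to_all_fp8(output_list, input_list)

dist.destroy_process_group()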