Merge pull request #6012 from hpcaitech/feature/fp8_comm

[fp8]  support fp8 communication and fp8 training for Colossalai
pull/6036/head
Hongxin Liu 2024-08-27 10:09:43 +08:00 committed by GitHub
commit 17904cb5bf
70 changed files with 2856 additions and 267 deletions
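A minimal usage sketch of the two new flags from the booster API (the flag names come from the plugin constructors changed below; the launch call, parallel sizes, and the model/optimizer/criterion/dataloader objects are assumptions of this sketch, not part of the PR):

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # usual distributed initialization
# use_fp8 enables fp8 GEMMs inside the model via FP8Hook; fp8_communication casts collectives
# (all-reduce, all-gather, all-to-all, pipeline p2p) to fp8 for transport only.
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, precision="bf16", use_fp8=True, fp8_communication=True)
booster = Booster(plugin=plugin)
model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)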

View File

@ -9,6 +9,7 @@ on:
paths:
- "examples/**"
- "!examples/**.md"
- ".github/workflows/example_check_on_pr.yml"
jobs:
# This detects changed example files and outputs a matrix containing all the corresponding directory names.
@ -107,7 +108,7 @@ jobs:
- name: Install Colossal-AI
run: |
BUILD_EXT=1 pip install -v .
BUILD_EXT=1 pip install -v -e .
- name: Store Colossal-AI Cache
run: |

View File

@ -366,7 +366,9 @@ class GeminiPlugin(DPPluginBase):
enable_jit_fused: bool = False,
enable_sequence_overlap: bool = False,
enable_async_reduce: bool = True,
use_fp8: bool = False,
verbose: bool = False,
fp8_communication: bool = False,
) -> None:
super().__init__()
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
@ -401,6 +403,8 @@ class GeminiPlugin(DPPluginBase):
master_weights=master_weights,
max_prefetch=max_prefetch,
enable_async_reduce=enable_async_reduce,
fp8_communication=fp8_communication,
use_fp8=use_fp8,
)
self.zero_optim_config = dict(
gpu_margin_mem_ratio=gpu_margin_mem_ratio,

View File

@ -31,6 +31,7 @@ from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
from colossalai.shardformer.layer.utils import SeqParallelUtils, is_share_sp_tp
from colossalai.shardformer.policies.base_policy import Policy
@ -66,6 +67,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
ddp_config: dict,
custom_policy: Policy,
overlap_allgather: bool = False,
use_fp8: bool = False,
) -> None:
self.stage_manager = shard_config.pipeline_stage_manager
self.shard_config = shard_config
@ -75,6 +77,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
self.use_ddp = use_ddp
self.require_grad_sync = True
self.overlap_allgather = overlap_allgather
self.use_fp8 = use_fp8
shardformer = ShardFormer(shard_config)
if custom_policy is not None:
@ -112,6 +115,9 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
module = DDP(module, process_group=dp_group, **ddp_config)
super().__init__(module)
self.op_hooks = []
if use_fp8:
self.op_hooks.append(FP8Hook())
if overlap_allgather:
self.op_hook = ZeroOpHook()
for p in module.parameters():
@ -223,7 +229,11 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
wait_all_gather_handle(p)
def _wait_all_gather(self):
return ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
return (
ColoParamOpHookManager.use_hooks(*self.op_hooks)
if (self.overlap_allgather or self.use_fp8)
else nullcontext()
)
def get_param_info(optim: Optimizer):
@ -969,6 +979,7 @@ class HybridParallelPlugin(PipelinePluginBase):
gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose a faster kernel. Defaults to 64.
fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism. Defaults to False.
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism. Defaults to True.
inner_ring_size (int, optional): The inner ring size of 2D Ring Attention when sp mode is "ring_attn".
It's advisable to not tune this (especially in single-node settings) and let it be heuristically set based on topology by default.
@ -1020,6 +1031,8 @@ class HybridParallelPlugin(PipelinePluginBase):
dp_outside: bool = True,
overlap_p2p: bool = True,
overlap_allgather: bool = False,
fp8_communication: bool = False,
use_fp8: bool = False,
inner_ring_size: int = None,
) -> None:
super().__init__()
@ -1069,8 +1082,10 @@ class HybridParallelPlugin(PipelinePluginBase):
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
self.use_fp8 = use_fp8
if dp_outside:
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
if sequence_parallelism_mode == "ring_attn":
# Swap tp and sp since 2D Ring has better inter-node latency
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.sp_size, self.tp_size)
@ -1117,6 +1132,7 @@ class HybridParallelPlugin(PipelinePluginBase):
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
overlap_p2p=overlap_p2p,
fp8_communication=fp8_communication,
)
elif pp_style == "1f1b":
self.schedule = OneForwardOneBackwardSchedule(
@ -1124,6 +1140,7 @@ class HybridParallelPlugin(PipelinePluginBase):
num_microbatches=num_microbatches,
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
fp8_communication=fp8_communication,
)
else:
raise NotImplementedError()
@ -1158,6 +1175,7 @@ class HybridParallelPlugin(PipelinePluginBase):
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,
fp8_communication=fp8_communication,
inner_ring_size=inner_ring_size,
)
self.amp_config = dict(
@ -1250,7 +1268,7 @@ class HybridParallelPlugin(PipelinePluginBase):
use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
self.dp_size == 1 and self.pp_size == 1
)
# sync gradients across DP * SP ranks
# Apply Hybrid ZeRO across DP * SP ranks
if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
@ -1268,6 +1286,7 @@ class HybridParallelPlugin(PipelinePluginBase):
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
use_fp8=self.use_fp8,
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if zero_stage == 0:

View File

@ -34,6 +34,7 @@ from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.zero import LowLevelZeroOptimizer
@ -62,7 +63,12 @@ class OptimizerParamCheckState(enum.Enum):
class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
def __init__(
self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True
self,
module: nn.Module,
precision: str,
overlap_allgather: bool = False,
cast_inputs: bool = True,
use_fp8: bool = False,
) -> None:
super().__init__(module)
self.dtype = None
@ -75,11 +81,16 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
module = module.to(get_accelerator().get_current_device())
self.module = module
self.convert_fn = None
self.use_fp8 = use_fp8
if self.dtype is not None and cast_inputs:
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
self.overlap_allgather = overlap_allgather
self.op_hooks = []
if overlap_allgather:
self.op_hook = ZeroOpHook()
self.op_hooks.append(ZeroOpHook())
if use_fp8:
self.op_hooks.append(FP8Hook())
if overlap_allgather or use_fp8:
for p in module.parameters():
if p.requires_grad and type(p) is not ColoParameter:
p.__class__ = ColoParameter
@ -337,6 +348,8 @@ class LowLevelZeroPlugin(DPPluginBase):
master_weights: bool = True,
verbose: bool = False,
cast_inputs: bool = True,
fp8_communication: bool = False,
use_fp8: bool = False,
) -> None:
super().__init__()
assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
@ -360,12 +373,14 @@ class LowLevelZeroPlugin(DPPluginBase):
cpu_offload=cpu_offload,
master_weights=master_weights,
overlap_allgather=overlap_allgather,
fp8_communication=fp8_communication,
)
self.lora_enabled = False
self.verbose = verbose
self.logger = get_dist_logger()
self.cast_inputs = cast_inputs
self.use_fp8 = use_fp8
# set class name with stage, for better error message
setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
@ -484,6 +499,7 @@ class LowLevelZeroPlugin(DPPluginBase):
self.precision,
overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
cast_inputs=self.cast_inputs,
use_fp8=self.use_fp8,
)
# TODO: Support Galore + ZeRO
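For the ZeRO path, a comparable hedged sketch (the stage, precision, and surrounding training setup are assumptions):

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# fp8_communication compresses the ZeRO collectives; use_fp8 installs FP8Hook so F.linear runs through linear_fp8
plugin = LowLevelZeroPlugin(stage=2, precision="bf16", use_fp8=True, fp8_communication=True)
booster = Booster(plugin=plugin)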

View File

@ -214,6 +214,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
moe_dp_outside: bool = True,
overlap_p2p: bool = True,
overlap_allgather: bool = False,
fp8_communication: bool = False,
use_fp8: bool = False,
) -> None:
self.logger = get_dist_logger()
if overlap_communication or zero_stage == 2:
@ -327,6 +329,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
else:
self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
self.use_fp8 = use_fp8
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
@ -345,6 +348,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,
fp8_communication=fp8_communication,
)
self.amp_config = dict(
initial_scale=initial_scale,
@ -431,6 +435,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
use_ddp=use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
use_fp8=self.use_fp8,
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.ep_size > 1:

View File

@ -179,6 +179,7 @@ class TorchDDPPlugin(DPPluginBase):
check_reduction: bool = False,
gradient_as_bucket_view: bool = False,
static_graph: bool = False,
fp8_communication: bool = False,
) -> None:
super().__init__()
self.ddp_kwargs = dict(
@ -189,6 +190,7 @@ class TorchDDPPlugin(DPPluginBase):
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph,
)
self.fp8_communication = fp8_communication
def support_no_sync(self) -> bool:
return True
@ -228,6 +230,11 @@ class TorchDDPPlugin(DPPluginBase):
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer = OptimizerWrapper(optimizer)
if self.fp8_communication:
from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async
model.module.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_async)
return model, optimizer, criterion, dataloader, lr_scheduler
def control_checkpoint_io(self) -> bool:
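A hedged sketch of the DDP path (only the plugin flag comes from this diff; the rest of the training setup is assumed). When the model is boosted, the plugin registers fp8_compress_ddp_grad_comm_hook_async as shown above:

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

plugin = TorchDDPPlugin(fp8_communication=True)  # gradients are all-reduced through the fp8 DDP comm hook
booster = Booster(plugin=plugin)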

View File

@ -298,6 +298,7 @@ class TorchFSDPPlugin(DPPluginBase):
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
sync_module_states: bool = False,
fp8_communication: bool = False,
):
super().__init__()
self.fsdp_kwargs = dict(
@ -311,6 +312,7 @@ class TorchFSDPPlugin(DPPluginBase):
param_init_fn=param_init_fn,
sync_module_states=sync_module_states,
)
self.fp8_communication = fp8_communication
self.logger = get_dist_logger()
else:
@ -348,6 +350,19 @@ class TorchFSDPPlugin(DPPluginBase):
# wrap the model with PyTorch FSDP
fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
if self.fp8_communication:
from colossalai.quantization.utils import patch_fsdp_params_comm_hook
patch_fsdp_params_comm_hook()
from colossalai.quantization.fp8 import fp8_compress_fsdp_params_comm_hook
fsdp_model.module.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook)
from colossalai.quantization.fp8 import fp8_compress_fsdp_grad_comm_hook
fsdp_model.module.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook)
if optimizer is not None:
if len(optimizer.param_groups) > 1:
self.logger.warning(

View File

@ -6,6 +6,8 @@ from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup
from colossalai.quantization.fp8 import all_to_all_single_fp8
MOE_KERNEL = None
@ -380,6 +382,7 @@ def _all_to_all(
output_split_sizes: Optional[List[int]] = None,
group=None,
async_op: bool = False,
fp8_communication: bool = False,
):
"""
Returns:
@ -392,9 +395,14 @@ def _all_to_all(
outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
inputs = inputs.contiguous()
outputs = outputs.contiguous()
handle = dist.all_to_all_single(
outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
)
if fp8_communication:
handle = all_to_all_single_fp8(
outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=False
)
else:
handle = dist.all_to_all_single(
outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
)
return outputs, handle
@ -407,6 +415,7 @@ class AllToAllUneven(torch.autograd.Function):
output_split_sizes=None,
group=None,
overlap: bool = False,
fp8_communication: bool = False,
):
"""
Returns:
@ -416,7 +425,9 @@ class AllToAllUneven(torch.autograd.Function):
ctx.input_split_sizes = input_split_sizes
ctx.output_split_sizes = output_split_sizes
ctx.group = group
return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap)
return _all_to_all(
inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication=fp8_communication
)
@staticmethod
def backward(ctx: Any, *grad_outputs):
@ -426,6 +437,7 @@ class AllToAllUneven(torch.autograd.Function):
None,
None,
None,
None,
)
@ -435,8 +447,9 @@ def all_to_all_uneven(
output_split_sizes: Optional[List[int]] = None,
group=None,
overlap: bool = False,
fp8_communication: bool = False,
):
assert (
inputs.requires_grad
), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication)
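A hedged sketch of calling the updated API (ep_group, the token tensor, and the split sizes are placeholders; note that fp8_communication currently forces a synchronous all-to-all, as _all_to_all above shows):

# tokens: (num_tokens, hidden) tensor that requires grad; splits are per-rank token counts
out, _ = all_to_all_uneven(
    tokens,
    input_split_sizes=in_splits,
    output_split_sizes=out_splits,
    group=ep_group,
    fp8_communication=True,
)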

View File

@ -11,6 +11,7 @@ from colossalai.accelerator import get_accelerator
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline
from colossalai.utils import get_current_device
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
@ -32,6 +33,7 @@ class InterleavedSchedule(PipelineSchedule):
microbatch_size: Optional[int] = None,
enable_metadata_cache: bool = True,
overlap_p2p: bool = True,
fp8_communication: bool = False,
) -> None:
super().__init__(stage_manager)
assert (
@ -56,6 +58,8 @@ class InterleavedSchedule(PipelineSchedule):
self.tensor_metadata_recv = None
self.grad_metadata_recv = None
self.fp8_communication = fp8_communication
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.
@ -191,8 +195,12 @@ class InterleavedSchedule(PipelineSchedule):
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_last_stage():
if self.fp8_communication:
cast_to_fp8_pipeline(output_tensor)
send_handles = self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
self.send_tensor_metadata = not self.enable_metadata_cache
if self.fp8_communication:
cast_from_fp8_pipeline(output_tensor)
return send_handles
return []
@ -210,10 +218,14 @@ class InterleavedSchedule(PipelineSchedule):
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_first_stage():
if self.fp8_communication:
cast_to_fp8_pipeline(input_tensor_grad)
send_handles = self.comm.send_backward(
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata
)
self.send_grad_metadata = not self.enable_metadata_cache
if self.fp8_communication:
cast_from_fp8_pipeline(input_tensor_grad)
return send_handles
return []
@ -224,6 +236,8 @@ class InterleavedSchedule(PipelineSchedule):
is_send = not self.stage_manager.is_last_stage()
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
is_recv = not self.stage_manager.is_first_stage()
if self.fp8_communication:
cast_to_fp8_pipeline(output_tensor)
input_tensor, wait_handles = self.comm.send_forward_recv_forward(
output_tensor,
is_send,
@ -237,6 +251,8 @@ class InterleavedSchedule(PipelineSchedule):
if is_recv and self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
if self.fp8_communication:
cast_from_fp8_pipeline(output_tensor)
return input_tensor, wait_handles
def send_backward_recv_backward(
@ -246,6 +262,8 @@ class InterleavedSchedule(PipelineSchedule):
is_send = not self.stage_manager.is_first_stage()
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
is_recv = not self.stage_manager.is_last_stage()
if self.fp8_communication:
cast_to_fp8_pipeline(input_tensor_grad)
output_tensor_grad, wait_handles = self.comm.send_backward_recv_backward(
input_tensor_grad,
is_send,
@ -258,6 +276,8 @@ class InterleavedSchedule(PipelineSchedule):
self.send_grad_metadata = not self.enable_metadata_cache and is_send
if is_recv and self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
if self.fp8_communication:
cast_from_fp8_pipeline(input_tensor_grad)
return output_tensor_grad, wait_handles
def forward_step(
@ -378,6 +398,8 @@ class InterleavedSchedule(PipelineSchedule):
# Wait until current input is received
_wait_p2p(fwd_wait_handles)
if self.fp8_communication and input_obj is not None:
cast_from_fp8_pipeline(input_obj)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if not last_batch:
@ -440,6 +462,8 @@ class InterleavedSchedule(PipelineSchedule):
# Wait for input
_wait_p2p(fwd_wait_handles)
if self.fp8_communication and input_obj is not None:
cast_from_fp8_pipeline(input_obj)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
@ -466,6 +490,8 @@ class InterleavedSchedule(PipelineSchedule):
# Wait for input.
_wait_p2p(fwd_wait_handles)
if self.fp8_communication and input_obj is not None:
cast_from_fp8_pipeline(input_obj)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
# Add input_obj and output_obj to end of list.
input_objs[model_chunk_id].append(input_obj)
@ -510,6 +536,8 @@ class InterleavedSchedule(PipelineSchedule):
input_obj, fwd_wait_handles = send_forward_recv_forward()
# Wait for upstream grad
_wait_p2p(bwd_wait_handles)
if self.fp8_communication and output_obj_grad is not None:
cast_from_fp8_pipeline(output_obj_grad)
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
# NOTE: It's documented by NCCL that running two concurrent communicators (batch_isend_irecv)
# risks deadlock (https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2134/user-guide/docs/usage/communicators.html)
@ -531,6 +559,8 @@ class InterleavedSchedule(PipelineSchedule):
# Wait for upstream grad
_wait_p2p(bwd_wait_handles)
if self.fp8_communication and output_obj_grad is not None:
cast_from_fp8_pipeline(output_obj_grad)
# backward local grads
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
if not last_batch:

View File

@ -10,6 +10,7 @@ from colossalai.accelerator import get_accelerator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline
from colossalai.utils import get_current_device
from ._utils import (
@ -32,6 +33,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
enable_metadata_cache: bool = True,
fp8_communication: bool = False,
) -> None:
"""1F1B pipeline schedule.
@ -61,6 +63,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
self.tensor_metadata_recv = None
self.grad_metadata_recv = None
self.fp8_communication = fp8_communication
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.
@ -129,6 +133,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
if self.fp8_communication:
cast_from_fp8_pipeline(input_tensor)
return input_tensor
def recv_backward(self, next_rank: int = None) -> Any:
@ -143,6 +149,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
"""
if not self.stage_manager.is_last_stage():
output_tensor_grad, _ = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
if self.fp8_communication:
cast_from_fp8_pipeline(output_tensor_grad)
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
@ -157,9 +165,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
if self.fp8_communication:
cast_to_fp8_pipeline(output_tensor)
self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
self.send_tensor_metadata = not self.enable_metadata_cache
if self.fp8_communication:
cast_from_fp8_pipeline(output_tensor, del_metadata=False)
def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
For 1F1B.
@ -169,8 +182,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
prev_rank (int, optional): The rank of the recipient of the tensor
"""
if not self.stage_manager.is_first_stage():
if self.fp8_communication:
cast_to_fp8_pipeline(input_tensor_grad)
self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
self.send_grad_metadata = not self.enable_metadata_cache
if self.fp8_communication:
cast_from_fp8_pipeline(input_tensor_grad, del_metadata=False)
def send_forward_recv_backward(self, output_tensor: Any, send_first: Optional[bool] = None) -> Any:
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
@ -183,6 +200,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
if not self.stage_manager.is_last_stage():
if not self.send_tensor_metadata and self.grad_metadata_recv is not None:
send_first = None
if self.fp8_communication:
cast_to_fp8_pipeline(output_tensor)
output_tensor_grad, _ = self.comm.send_forward_recv_backward(
output_tensor,
send_metadata=self.send_tensor_metadata,
@ -192,6 +211,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
self.send_tensor_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
if self.fp8_communication:
cast_from_fp8_pipeline(output_tensor, del_metadata=False)
cast_from_fp8_pipeline(output_tensor_grad)
return output_tensor_grad
@ -206,6 +228,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
if not self.stage_manager.is_first_stage():
if not self.send_grad_metadata and self.tensor_metadata_recv is not None:
send_first = None # must not fallback
if self.fp8_communication:
cast_to_fp8_pipeline(input_tensor_grad)
input_tensor, _ = self.comm.send_backward_recv_forward(
input_tensor_grad,
send_metadata=self.send_grad_metadata,
@ -215,6 +239,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
self.send_grad_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
if self.fp8_communication:
cast_from_fp8_pipeline(input_tensor)
cast_from_fp8_pipeline(input_tensor_grad, del_metadata=False)
return input_tensor

View File

@ -0,0 +1,738 @@
from typing import Any, Optional, Tuple
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from packaging.version import Version
from torch.distributed import ReduceOp
SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0")
class Handle:
def __init__(self, handles=None, remain_ops=None) -> None:
self.handles = handles if handles is not None else []
self.remain_ops = remain_ops
def wait(self):
for handle in self.handles:
handle.wait()
if self.remain_ops:
self.remain_ops()
def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Cast a torch Tensor to the specified fp8 format with per-channel or per-tensor scaling.
Args:
inp: input torch Tensor; must be torch.float32, torch.float16, or torch.bfloat16.
fp8_format: e4m3 or e5m2
per_channel_scale: if True, a scaling factor is computed per channel (along the last dim); otherwise a single per-tensor scale is used.
Returns:
Tuple: a tuple (fp8_tensor, scale_inv), where scale_inv is the inverse scale needed to cast back.
"""
if inp.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
raise TypeError("Only float16, bfloat16, and float32 are allowed.")
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
fp8_max = torch.finfo(fp8_type).max
if inp.numel() == 0:
return inp.to(fp8_type), torch.tensor([1.0], device=inp.device)
else:
if per_channel_scale:
per_channel_max = inp.abs().max(dim=-1).values.float()
per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
scale = fp8_max / per_channel_max[:, None]
scale_inv = per_channel_max / fp8_max
else:
per_tensor_max = inp.abs().max().float()
per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
scale = fp8_max / per_tensor_max
scale_inv = 1.0 / scale
ret = (scale * inp.float()).to(fp8_type)
return ret, scale_inv
def cast_from_fp8(
inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype, per_channel_scale=False
) -> torch.Tensor:
r"""
Args:
inp: an fp8 torch tensor; must be torch.float8_e4m3fn or torch.float8_e5m2.
scale_inv: the inverse scaling factor returned by cast_to_fp8.
ret_type: the datatype of the returned tensor.
per_channel_scale: whether inp was cast with per-channel scaling.
Returns:
torch.Tensor
"""
if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.")
if per_channel_scale:
ret = scale_inv[:, None] * inp.float()
else:
ret = scale_inv * inp.float()
return ret.to(ret_type)
def all_reduce_fp8(
tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
) -> Optional[Handle]:
r"""
This is an in-place operation for compressed all_reduce using fp8.
It works like dist.all_reduce but during communication the data is cast to fp8 format.
Args:
tensor: torch.Tensor in fp32, fp16, bf16 datatype.
fp8_format: e4m3 or e5m2
op: ReduceOp.SUM or ReduceOp.AVG
Returns:
Optional[Handle]: a Handle if async_op is True, otherwise None
"""
world_size = dist.get_world_size(group=group)
input_type = tensor.dtype
input_shape = tensor.shape
input_device = tensor.device
input_size = tensor.numel()
flat_padded_x = tensor.flatten()
assert op in [ReduceOp.SUM, ReduceOp.AVG], "op can only be ReduceOp.SUM or ReduceOp.AVG"
if flat_padded_x.size(0) % world_size != 0:
pad_size = world_size - flat_padded_x.size(0) % world_size
flat_padded_x = F.pad(flat_padded_x, (0, pad_size))
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
ret, scale = cast_to_fp8(flat_padded_x, fp8_format=fp8_format)
inp = ret.view(torch.uint8)
input_chunks = list(torch.chunk(inp, world_size, dim=0))
output_chunks = list(torch.chunk(torch.empty_like(inp), world_size, dim=0))
dist.all_to_all(output_chunks, input_chunks, group=group)
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
dist.all_gather(scale_list, scale, group=group)
summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
for scale, out in zip(scale_list, output_chunks):
out = out.view(fp8_type)
summed_out += cast_from_fp8(out, scale, input_type)
if op == ReduceOp.AVG:
summed_out.div_(world_size)
summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
gather_scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
tensor_list = [torch.empty_like(summed_out_fp8.view(torch.uint8)) for _ in range(world_size)]
gather_tensor_handle = dist.all_gather(
tensor_list, summed_out_fp8.view(torch.uint8), group=group, async_op=async_op
)
def cat_op():
for i in range(world_size):
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
out = torch.cat(tensor_list, dim=0)
tensor.copy_(out[:input_size].view(input_shape).to(input_type))
if async_op:
return Handle([gather_scale_handle, gather_tensor_handle], cat_op)
else:
cat_op()
def all_to_all_single_fp8(
output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
) -> Optional[Handle]:
r"""
This is an in-place operation for compressed all_to_all_single using fp8.
It works like dist.all_to_all_single but during communication the data is cast to fp8 format.
Args:
output: torch.Tensor in fp32, fp16, or bf16 that receives the result.
input: torch.Tensor in fp32, fp16, or bf16 to be scattered.
fp8_format: e4m3 or e5m2
Returns:
Optional[Handle]: a Handle if async_op is True, otherwise None
"""
world_size = dist.get_world_size(group=group)
input_type = input.dtype
input_shape = input.shape
input_device = input.device
input = input.flatten()
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
ret, scale = cast_to_fp8(input, fp8_format=fp8_format)
inp = ret.view(torch.uint8)
if input_split_sizes is not None:
input_split_sizes = [input_split_sizes[i] * np.prod(input_shape[1:]) for i in range(world_size)]
input_chunks = list(torch.split(inp, input_split_sizes))
else:
input_chunks = list(torch.chunk(inp, world_size, dim=0))
if output_split_sizes is not None:
output_chunks = [
torch.empty((output_split_sizes[i] * np.prod(input_shape[1:]),), device=input_device, dtype=inp.dtype)
for i in range(world_size)
]
else:
if dist.get_rank() == world_size - 1:
output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)]
else:
output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]
chunk_handle = dist.all_to_all(output_chunks, input_chunks, group=group, async_op=async_op)
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
def cast_op():
cast_output_chunk = [
cast_from_fp8(out.view(fp8_type), scale, input_type) for scale, out in zip(scale_list, output_chunks)
]
tensor_out = torch.cat(cast_output_chunk, dim=0)
outputs_shape = list(input_shape)
if output_split_sizes is not None:
outputs_shape[0] = sum(output_split_sizes)
else:
outputs_shape = input_shape
output.data = tensor_out.view(outputs_shape).to(input_type)
if async_op:
return Handle([chunk_handle, scale_handle], cast_op)
else:
cast_op()
def cast_to_fp8_pipeline(inp: Any) -> None:
"""
Cast the hidden_states tensor of the inp object to fp8 format before p2p communication in the pipeline.
The activation tensor is indexed by 'hidden_states' in the inp dict.
After FP8 casting, the result is stored viewed as float16 or bfloat16, so the tensor size is halved.
Metadata such as fp8_scale is saved into the inp dict for communication.
"""
if inp is None:
return
# In pipeline parallelism, when inp is a bare torch.Tensor it carries only one element, so fp8 casting is skipped.
if type(inp) == torch.Tensor:
return
assert "hidden_states" in inp, "required by pipeline parallelism."
assert (
inp["hidden_states"].size(-1) % 2 == 0
), "tensor size(-1) must be divisible by 2 to view Float8_e4m3fn as BFloat16 or Float16"
inp_tensor = inp["hidden_states"]
inp_dtype = inp_tensor.dtype
min_val, max_val = inp_tensor.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs())
finfo = torch.finfo(torch.float8_e4m3fn)
if amax > finfo.max:
fp8_type = torch.float8_e5m2
fp8_view_type = torch.float16
else:
fp8_type = torch.float8_e4m3fn
fp8_view_type = torch.bfloat16
finfo = torch.finfo(fp8_type)
scale = torch.tensor(1.0).to(inp_tensor.device) if amax == 0.0 else finfo.max / amax.float()
q_tensor = inp_tensor.data.float() * scale
# TODO: currently fp8_view_type (float16 vs. bfloat16) encodes which fp8 format was used. This is a temporary workaround for the 'Only support tensor for fast send' limitation.
# inp_tensor needs to be a float datatype to avoid error during gradient placement.
inp_tensor.data = q_tensor.to(fp8_type).view(fp8_view_type)
inp["fp8_scale"] = scale.float().reciprocal()
inp["dtype"] = torch.zeros_like(scale).to(inp_dtype)
def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
"""
Cast the FP8-encoded hidden_states tensor back to its original dtype after p2p communication in the pipeline.
Setting del_metadata = False is useful when this function is called before p2p communication.
"""
if inp is None:
return
if type(inp) == torch.Tensor:
return
assert "hidden_states" in inp, "required by pipeline parallelism."
inp_tensor = inp["hidden_states"]
scale = inp["fp8_scale"]
fp8_view_type = inp_tensor.dtype
if fp8_view_type == torch.float16:
fp8_type = torch.float8_e5m2
elif fp8_view_type == torch.bfloat16:
fp8_type = torch.float8_e4m3fn
else:
raise TypeError("Only float16, bfloat16 are implemented.")
inp_tensor.data = inp_tensor.data.view(fp8_type).to(inp["dtype"]) * scale
if del_metadata:
del inp["fp8_scale"]
del inp["dtype"]
def reduce_scatter_fp8(
output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False
) -> Optional[Handle]:
r"""
This is an in-place operation for compressed reduce_scatter using fp8.
It works like dist.reduce_scatter but during communication the data is cast to fp8 format.
Args:
output: the reduce-scattered result tensor in fp32, fp16, or bf16.
input_list: list of input tensors (one per rank) to be reduced and scattered.
fp8_format: e4m3 or e5m2
Returns:
Optional[Handle]: a Handle if async_op is True, otherwise None
"""
input_type = output.dtype
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
scale_list = []
cast_input_list = []
output_chunks = []
output_scale_list = []
for input in input_list:
ret, scale = cast_to_fp8(input, fp8_format=fp8_format)
scale_list.append(scale)
ret = ret.view(torch.uint8)
cast_input_list.append(ret)
output_chunks.append(torch.empty_like(ret))
output_scale_list.append(torch.empty_like(scale))
chunk_handle = dist.all_to_all(output_chunks, cast_input_list, group=group, async_op=async_op)
scale_handle = dist.all_to_all(output_scale_list, scale_list, group=group, async_op=async_op)
def cast_op():
summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
for scale, out in zip(output_scale_list, output_chunks):
out = out.view(fp8_type)
summed_out += cast_from_fp8(out, scale, input_type)
output.data = summed_out
if async_op:
return Handle([chunk_handle, scale_handle], cast_op)
else:
cast_op()
def fp8_compress_ddp_grad_comm_hook_async(
process_group: dist.ProcessGroup,
bucket: dist.GradBucket,
fp8_format: str = "e5m2",
) -> torch.futures.Future[torch.Tensor]:
"""
Compress gradients by casting the ``GradBucket`` tensor to FP8 and dividing it by the process group size.
This DDP communication hook implements a simple gradient compression approach that casts the ``GradBucket`` tensor
to an FP8 floating-point format (``torch.float8_e5m2`` or ``torch.float8_e4m3fn``) and then divides it
by the process group size.
Once the compressed gradient tensors are allreduced, the chained callback ``decompress`` casts them back
to the input data type (such as ``float32``).
Example::
>>> ddp_model.register_comm_hook(process_group, fp8_compress_ddp_grad_comm_hook_async)
"""
group_to_use = process_group if process_group is not None else dist.group.WORLD
input_tensor = bucket.buffer()
world_size = dist.get_world_size()
input_type = input_tensor.dtype
input_device = input_tensor.device
flat_padded_x = input_tensor.flatten()
if flat_padded_x.size(0) % world_size != 0:
pad_size = world_size - flat_padded_x.size(0) % world_size
flat_padded_x = F.pad(flat_padded_x, (0, pad_size))
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
ret, scale = cast_to_fp8(flat_padded_x, fp8_format=fp8_format)
inp = ret.view(torch.uint8)
output_chunks_single = torch.empty_like(inp)
split_sizes = [inp.numel() // world_size for _ in range(world_size)]
fut0 = dist.all_to_all_single(
output_chunks_single,
inp,
output_split_sizes=split_sizes,
input_split_sizes=split_sizes,
group=group_to_use,
async_op=True,
).get_future()
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
fut1 = dist.all_gather_into_tensor(
torch.cat(scale_list, dim=0), scale, group=group_to_use, async_op=True
).get_future()
all_to_all_fut = torch.futures.collect_all([fut0, fut1])
def sum_and_allgather(fut):
output_chunks_single = fut.value()[0].wait()[0]
scale_list_single = fut.value()[1].wait()[0]
output_chunks = list(torch.chunk(output_chunks_single, world_size, dim=0))
scale_list = scale_list_single.chunk(world_size, dim=0)
summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
for scale, out in zip(scale_list, output_chunks):
out = out.view(fp8_type)
summed_out += cast_from_fp8(out, scale, input_type)
summed_out.div_(world_size)
summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
tensor_list_single = torch.empty(summed_out_fp8.size(0) * world_size, device=input_device, dtype=torch.uint8)
fut2 = dist.all_gather_into_tensor(
tensor_list_single, summed_out_fp8.view(torch.uint8), group=group_to_use, async_op=True
).get_future()
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
fut3 = dist.all_gather_into_tensor(
torch.cat(scale_list, dim=0), scale, group=group_to_use, async_op=True
).get_future()
fut_combined2 = torch.futures.collect_all([fut2, fut3])
return fut_combined2
def decompress(fut):
tensor_list_single = fut.value().wait()[0].value()[0]
scale_list_single = fut.value().wait()[1].value()[0]
tensor_list = list(torch.chunk(tensor_list_single, world_size, dim=0))
scale_list = scale_list_single.chunk(world_size, dim=0)
for i in range(world_size):
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
out = torch.cat(tensor_list, dim=0)
input_tensor_size = input_tensor.numel()
input_shape = input_tensor.shape
out = out[:input_tensor_size]
input_tensor.copy_(out.view(input_shape).to(input_type))
return input_tensor
return all_to_all_fut.then(sum_and_allgather).then(decompress)
def fp8_compress_ddp_grad_comm_hook_sync(
process_group: dist.ProcessGroup,
bucket: dist.GradBucket,
fp8_format="e5m2",
) -> torch.futures.Future[torch.Tensor]:
"""
Return a future that wraps the input after it has been allreduced. However, the allreduce communication is synchronous,
which breaks the overlap between allreduce communication and backward computation.
This hook should **only** be used for debugging purposes, not for normal gradient synchronization.
For an asynchronous implementation, use fp8_compress_ddp_grad_comm_hook_async instead.
Example::
>>> # xdoctest: +SKIP
>>> ddp_model.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_sync)
"""
buffer = bucket.buffer()
all_reduce_fp8(buffer, fp8_format=fp8_format)
fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
fut.set_result(bucket.buffer())
return fut
def fp8_compress_fsdp_grad_comm_hook(
state: object,
unsharded_gradient_flattened: torch.Tensor,
sharded_gradient: torch.Tensor,
group=None,
fp8_format="e5m2",
) -> None:
"""
This communication hook implements a simple gradient compression approach that casts the unsharded_gradient_flattened tensor
to an FP8 floating-point format (``torch.float8_e5m2`` or ``torch.float8_e4m3fn``) and then performs the reduce-scatter logic
using all_to_all and all_gather among the process group.
Example::
>>> fsdp_model.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook)
"""
grad = unsharded_gradient_flattened
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
input_type = grad.dtype
input_device = grad.device
world_size = dist.get_world_size(group=group)
grad_fp8, scale = cast_to_fp8(grad, fp8_format=fp8_format)
uint8_buffer = torch.empty_like(grad_fp8).view(torch.uint8)
dist.all_to_all_single(uint8_buffer, grad_fp8.view(torch.uint8), group=group)
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
dist.all_gather(scale_list, scale, group=group)
buffer_list = list(torch.chunk(uint8_buffer.view(fp8_type), world_size, dim=0))
sharded_gradient.zero_()
for tensor, scale in zip(buffer_list, scale_list):
sharded_gradient += cast_from_fp8(tensor, scale, input_type)
def fp8_compress_fsdp_params_comm_hook(
state: object,
padded_unsharded_flat_param: torch.Tensor,
sharded_flat_param: torch.Tensor,
group=None,
fp8_format="e5m2",
) -> None:
"""
This hook anticipates official support for a parameter communication hook in FSDP (e.g. register_params_comm_hook).
Example::
>>> fsdp_model.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook)
"""
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
fp8_max = torch.finfo(fp8_type).max
inp = sharded_flat_param
out = padded_unsharded_flat_param
per_tensor_max = inp.abs().max().float()
per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
dist.all_reduce(per_tensor_max, op=torch.distributed.ReduceOp.MAX, group=group)
scale = fp8_max / per_tensor_max
fp8_sharded_flat_param = (scale * inp.float()).to(fp8_type).view(torch.uint8)
fp8_out = torch.empty(out.shape, dtype=torch.uint8, device=out.device)
dist.all_gather_into_tensor(
fp8_out,
fp8_sharded_flat_param,
group=group,
)
padded_unsharded_flat_param.copy_((fp8_out.view(fp8_type).float() / scale).to(out.dtype))
def split_chunk_by_channel(
chunk: torch.Tensor, channel_size: int, num_channels: int, rank: int = 0, world_size: int = 1
):
offset = chunk.numel() * rank
end = offset + chunk.numel()
break_points = [x for x in range(0, channel_size * num_channels + 1, channel_size) if offset <= x <= end]
if len(break_points) == 0 or break_points[0] > offset:
break_points.insert(0, offset)
if break_points[-1] < end:
break_points.append(end)
sizes = [b - a for a, b in zip(break_points[:-1], break_points[1:])]
return chunk.split(sizes)
def all_gather_into_tensor_flat_fp8(
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
output_shape: torch.Size,
group: dist.ProcessGroup,
fp8_format: str = "e4m3",
async_op: bool = False,
) -> Optional[Handle]:
"""all gather into tensor in fp8 format
Args:
output_tensor (torch.Tensor): output tensor, which is flattened
input_tensor (torch.Tensor): input tensor, which is flattened
group (dist.ProcessGroup): process group
fp8_format (str, optional): fp8 format, e4m3 or e5m2. Defaults to "e4m3".
"""
assert input_tensor.dim() == 1 and output_tensor.dim() == 1, "input/output tensor should be flattened"
world_size = dist.get_world_size(group)
assert (
output_tensor.numel() == input_tensor.numel() * world_size
), "output tensor size should be world_size times of input tensor size"
input_type = output_tensor.dtype
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
fp8_max = torch.finfo(fp8_type).max
if len(output_shape) == 2:
per_channel_max = torch.zeros(output_shape[0], device=output_tensor.device, dtype=torch.float)
num_channels, channel_size = output_shape
rank = dist.get_rank(group)
channel_start_idx = (input_tensor.numel() * rank) // channel_size
per_channel_splits = split_chunk_by_channel(input_tensor, channel_size, num_channels, rank, world_size)
for i, per_channel_split in enumerate(per_channel_splits):
idx = i + channel_start_idx
if idx < num_channels:
per_channel_max[idx] = per_channel_split.abs().max().float()
dist.all_reduce(per_channel_max, op=dist.ReduceOp.MAX, group=group)
per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
scale = fp8_max / per_channel_max
fp8_input = input_tensor.float()
fp8_per_channel_splits = split_chunk_by_channel(fp8_input, channel_size, num_channels, rank, world_size)
for i, per_channel_split in enumerate(fp8_per_channel_splits):
idx = i + channel_start_idx
if idx < num_channels:
per_channel_split.mul_(scale[idx])
fp8_input = fp8_input.to(fp8_type)
else:
per_tensor_max = input_tensor.abs().max().float()
dist.all_reduce(per_tensor_max, op=dist.ReduceOp.MAX, group=group)
per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
scale = fp8_max / per_tensor_max
fp8_input = (scale * input_tensor.float()).to(fp8_type)
scale_inv = 1.0 / scale
buffer = torch.empty_like(output_tensor, dtype=fp8_type)
tensor_handle = dist.all_gather_into_tensor(
buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group, async_op=async_op
)
def cast_op():
numel = output_shape.numel()
valid_buffer = buffer[:numel].reshape(output_shape)
valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2))
output_tensor[:numel].copy_(valid_buffer.view(-1))
if async_op:
return Handle([tensor_handle], cast_op)
else:
cast_op()
def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
world_size = dist.get_world_size(group)
input_type = input_list[0].dtype
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
scale_list = []
tensor_list = []
for i in range(world_size):
input_tensor = input_list[i]
ret, scale = cast_to_fp8(input_tensor, fp8_format=fp8_format)
scale_list.append(scale)
ret = ret.view(torch.uint8)
tensor_list.append(ret)
output_scale_list = [torch.empty_like(x) for x in scale_list]
output_tensor_list = [torch.empty_like(x) for x in tensor_list]
tensor_handle = dist.all_to_all(output_tensor_list, tensor_list, group=group, async_op=async_op)
scale_handle = dist.all_to_all(output_scale_list, scale_list, group=group, async_op=async_op)
def cast_op():
for i in range(world_size):
scale = output_scale_list[i]
tensor = output_tensor_list[i]
tensor = tensor.view(fp8_type)
output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))
if async_op:
return Handle([tensor_handle, scale_handle], cast_op)
else:
cast_op()
def gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
world_size = dist.get_world_size(group)
input_type = input_.dtype
ret, scale = cast_to_fp8(input_, fp8_format=fp8_format)
fp8_type = ret.dtype
input_ = ret.view(torch.uint8)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)]
chunk_handle = dist.all_gather(tensor_list, input_, group=group, async_op=async_op)
scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
def cast_op():
for i in range(world_size):
output = tensor_list[i].view(fp8_type)
scale = scale_list[i]
output_list[i].copy_(cast_from_fp8(output, scale, input_type))
if async_op:
return Handle([chunk_handle, scale_handle], cast_op)
else:
cast_op()
class _LinearFp8(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any,
x: torch.Tensor,
w: torch.Tensor,
bias: Optional[torch.Tensor],
) -> Any:
assert (
x.dtype in (torch.bfloat16, torch.float16) and x.dtype == w.dtype
), "Only float16 and bfloat16 are allowed."
if bias is not None:
assert bias.dtype == x.dtype, "Bias should have the same dtype as input."
# ensure x and w are row-major
x = x.contiguous()
w = w.contiguous()
ctx.x_shape = x.shape
ctx.has_bias = bias is not None
ctx.out_dtype = x.dtype
x = x.reshape(-1, x.shape[-1])
x_fp8, inv_scale_x = cast_to_fp8(x, fp8_format="e4m3")
w_fp8, inv_scale_w = cast_to_fp8(w, fp8_format="e4m3")
ctx.x_fp8 = x_fp8
ctx.w_fp8_t = w_fp8.t()
ctx.inv_scale_x = inv_scale_x
ctx.inv_scale_w = inv_scale_w
out = torch._scaled_mm(
x_fp8,
ctx.w_fp8_t,
bias=bias,
out_dtype=ctx.out_dtype,
scale_a=inv_scale_x,
scale_b=inv_scale_w,
use_fast_accum=True,
)[0]
return out.reshape(*ctx.x_shape[:-1], w.shape[0])
@staticmethod
def backward(ctx: Any, out_grad) -> Any:
out_grad = out_grad.reshape(-1, out_grad.shape[-1])
out_grad_fp8, out_grad_scale = cast_to_fp8(out_grad, fp8_format="e5m2")
x_grad = torch._scaled_mm(
out_grad_fp8,
ctx.w_fp8_t.contiguous().t(),
out_dtype=ctx.out_dtype,
scale_a=out_grad_scale,
scale_b=ctx.inv_scale_w,
use_fast_accum=True,
)[0]
w_grad = torch._scaled_mm(
out_grad_fp8.t().contiguous(),
ctx.x_fp8.t().contiguous().t(),
out_dtype=ctx.out_dtype,
scale_a=out_grad_scale,
scale_b=ctx.inv_scale_x,
use_fast_accum=True,
)[0]
bias_grad = None
if ctx.has_bias:
bias_grad = out_grad.sum(0)
return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=False)
def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _LinearFp8.apply(input, weight, bias)
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
out = _linear_fp8(input, weight, bias)
return out
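A small hedged sketch of the primitives defined above (shapes and device are placeholders; an fp8-capable GPU and a torch build compatible with the torch._scaled_mm call above are assumed):

x = torch.randn(16, 64, dtype=torch.bfloat16, device="cuda")
w = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)

# round-trip cast: quantize to e4m3 with a per-tensor scale, then dequantize with scale_inv
x_fp8, scale_inv = cast_to_fp8(x, fp8_format="e4m3")
x_back = cast_from_fp8(x_fp8, scale_inv, torch.bfloat16)

# fp8 GEMM with autograd support
y = linear_fp8(x, w)
y.sum().backward()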

View File

@ -0,0 +1,23 @@
import torch.nn.functional as F
from colossalai.quantization.fp8 import linear_fp8
from colossalai.tensor.param_op_hook import ColoParamOpHook
class FP8Hook(ColoParamOpHook):
def pre_forward(self, params) -> None:
pass
def post_forward(self, params) -> None:
pass
def pre_backward(self, params) -> None:
pass
def post_backward(self, params) -> None:
pass
def rewrite_op(self, func):
if func is F.linear:
return linear_fp8
return func
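A hedged sketch of how the hook is meant to be used, mirroring how the plugins above install it through ColoParamOpHookManager (the toy model and inputs are assumptions):

from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager

# parameters must be ColoParameter so the hook's rewrite_op is consulted when F.linear is called
for p in model.parameters():
    if p.requires_grad and type(p) is not ColoParameter:
        p.__class__ = ColoParameter

with ColoParamOpHookManager.use_hooks(FP8Hook()):
    out = model(inputs)  # F.linear inside the model is rewritten to linear_fp8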

View File

@ -0,0 +1,112 @@
import torch
import torch.distributed as dist
from packaging import version
from torch import Tensor
from torch.distributed.fsdp._common_utils import _no_dispatch_record_stream
from torch.distributed.utils import _p_assert
def _all_gather_flat_param(
self,
padded_unsharded_flat_param: Tensor,
) -> Tensor:
"""
All-gather the handle's flat parameter to the destination ``padded_unsharded_flat_param``.
Then switch to use the all-gathered tensor.
"""
_p_assert(
hasattr(self, "process_group") and hasattr(self, "world_size"),
"Expects a process group and world size to have been set via `shard()`",
)
sharded_flat_param = self.flat_param.data
expected_numel = sharded_flat_param.numel() * self.world_size
_p_assert(
padded_unsharded_flat_param.numel() == expected_numel,
f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
)
pg = self._fake_process_group if self._use_fake_all_gather else self.process_group
# HACK this should be handled by C10D
if sharded_flat_param.is_cpu: # type: ignore[attr-defined]
tensor_list = list(torch.chunk(padded_unsharded_flat_param, dist.get_world_size(pg)))
work = dist.all_gather(tensor_list, sharded_flat_param, group=pg)
else:
if self._comm_hook is None:
dist.all_gather_into_tensor(
padded_unsharded_flat_param,
sharded_flat_param,
pg,
)
else:
self._comm_hook(None, padded_unsharded_flat_param, sharded_flat_param, pg)
if self._offload_params:
# In case of offloading, `flat_param.data` (i.e. sharded param) is
# created on the pre-unshard stream. We need to hand it over to the
# unshard stream for all-gather
_no_dispatch_record_stream(
sharded_flat_param,
self._device_handle.current_stream(), # unshard_stream
)
return padded_unsharded_flat_param
def register_params_comm_hook(self, state: object, hook: callable):
"""Register a communication hook for FlatParamHandle.
This is an enhancement that provides a flexible hook through which users can specify how FSDP unshards
parameters across multiple workers.
.. warning ::
FSDP communication hook should be registered before running an initial forward pass
and only once.
Args:
state (object): Passed to the hook to maintain any state information during the training process.
hook (Callable): Callable, which has one of the following signatures:
1) ``hook: Callable[torch.Tensor] -> None``:
This function takes in a Python tensor, which represents
the full, flattened, padded, unsharded flat parameter of the model this FSDP unit is wrapping
(excluding parameters wrapped by other FSDP sub-units).
It then performs all necessary processing and returns ``None``;
2) ``hook: Callable[torch.Tensor, torch.Tensor] -> None``:
This function takes in two Python tensors: the first one represents
the full, flattened, padded, unsharded flat parameter of the model this FSDP unit is wrapping
(excluding parameters wrapped by other FSDP sub-units). The latter
represents the local rank's sharded flat parameter that is the source of the all-gather.
In both cases, the callable performs all necessary processing and returns ``None``.
Callables with signature 1 are expected to handle parameter communication for a `NO_SHARD` case.
Callables with signature 2 are expected to handle parameter communication for sharded cases.
"""
if not self.check_is_root():
raise AssertionError("register_comm_hook can only be called on a root instance.")
# if fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES:
# raise AssertionError(
# f"Communication hook is not supported for hybrid strategies: {fsdp_state.sharding_strategy}"
# )
if self._handle._comm_hook is not None:
raise AssertionError("A communication hook is already registered")
if not callable(hook):
raise ValueError(f"The communication hook must be callable but got {hook}")
self._handle._comm_hook = hook
self._handle._comm_hook_state = state
def patch_fsdp_params_comm_hook():
if version.parse(torch.__version__) >= version.parse("2.2.0"):
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._flat_param import FlatParamHandle
FlatParamHandle._comm_hook = None
FlatParamHandle._comm_hook_state = None
FlatParamHandle._all_gather_flat_param = _all_gather_flat_param
FSDP.register_params_comm_hook = register_params_comm_hook
else:
raise RuntimeError("This fsdp_params_comm_hook patch is not supported while torch version under 2.2.0.")

View File

@ -16,6 +16,14 @@ try:
except ImportError:
_grad_accum_fusion_available = False
from colossalai.quantization.fp8 import (
all_reduce_fp8,
all_to_all_fp8,
all_to_all_single_fp8,
gather_fp8,
reduce_scatter_fp8,
)
class FusedLayerNormAffineFunction1D(torch.autograd.Function):
r"""Layernorm
@ -61,11 +69,12 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce
ctx.fp8_communication = fp8_communication
output = torch.matmul(input_, weight)
@ -78,6 +87,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
fp8_communication = ctx.fp8_communication
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
weight = weight.view(weight.shape)
@ -92,7 +102,9 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_allreduce:
if ctx.async_grad_allreduce and fp8_communication:
_reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication, fp8_format="e5m2")
elif ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
@ -101,10 +113,10 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_allreduce:
if ctx.async_grad_allreduce and not fp8_communication:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None
return grad_input, grad_weight, grad_bias, None, None, None, None
class LinearWithAsyncCommunication(torch.autograd.Function):
@ -113,11 +125,12 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce
ctx.fp8_communication = fp8_communication
if bias is not None:
output = F.linear(input_, weight, bias)
else:
@ -129,6 +142,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
fp8_communication = ctx.fp8_communication
# Add a view operation to bias so it can be hooked into Gemini's '__torch_function__'.
if use_bias:
@ -144,10 +158,11 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
_ = torch.zeros(1, device=grad_input.device)
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
if fp8_communication:
all_reduce_fp8(grad_input, group=ctx.process_group)
else:
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
if _grad_accum_fusion_available and weight.grad is not None:
@ -165,10 +180,10 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_allreduce:
if ctx.async_grad_allreduce and not fp8_communication:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None
return grad_input, grad_weight, grad_bias, None, None, None, None
def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
@ -236,17 +251,18 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, process_group, dim):
def forward(ctx, input_, process_group, dim, fp8_communication=False):
ctx.process_group = process_group
ctx.dim = dim
ctx.fp8_communication = fp8_communication
return _gather(input_, dim, process_group)
return _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3")
@staticmethod
def backward(ctx, grad_output):
dim = ctx.dim
process_group = ctx.process_group
fp8_communication = ctx.fp8_communication
# do reduce-scatter
new_shape = list(grad_output.shape)
assert (
@ -257,9 +273,13 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function):
item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device)
dist.reduce_scatter(output, grad_list, group=process_group)
return output, None, None
if fp8_communication:
reduce_scatter_fp8(output, grad_list, group=process_group, fp8_format="e5m2")
else:
dist.reduce_scatter(output, grad_list, group=process_group)
return output, None, None, None
class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
@ -550,9 +570,10 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, process_group, dim):
def forward(ctx, input_, process_group, dim, fp8_communication=False):
ctx.dim = dim
ctx.process_group = process_group
ctx.fp8_communication = fp8_communication
# do reduce-scatter
new_shape = list(input_.shape)
@ -562,7 +583,10 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
dist.reduce_scatter(output, input_list, group=process_group)
if fp8_communication:
reduce_scatter_fp8(output, input_list, group=process_group, fp8_format="e4m3")
else:
dist.reduce_scatter(output, input_list, group=process_group)
return output
@ -570,8 +594,9 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
def backward(ctx, grad_output):
dim = ctx.dim
process_group = ctx.process_group
fp8_communication = ctx.fp8_communication
return _gather(grad_output, dim, process_group), None, None
return _gather(grad_output, dim, process_group, fp8_communication, fp8_format="e5m2"), None, None, None
class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
@ -586,13 +611,16 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring):
def forward(
ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.overlap = overlap
ctx.fp8_communication = fp8_communication
if ring is True:
input_to_gather = {}
@ -609,7 +637,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
)
else:
input_parallel = _gather(input_, dim, process_group)
input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3")
output = torch.matmul(input_parallel, weight)
@ -624,6 +652,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap
fp8_communication = ctx.fp8_communication
# Add a view operation to weight and bias so they can be hooked into Gemini's '__torch_function__'. Used in FusedLayerNorm.
weight = weight.view(weight.shape)
@ -631,7 +660,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
bias = bias.view(bias.shape)
if not overlap:
input_parallel = _gather(input_, dim, process_group)
input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e5m2")
total_input = input_parallel
grad_input = grad_output.matmul(weight.T)
@ -691,7 +720,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
# wait until reduce-scatter finished
reducescatter_handle.wait()
return output, grad_weight, grad_bias, None, None, None, None, None
return output, grad_weight, grad_bias, None, None, None, None, None, None
class _SplitForwardGatherBackward(torch.autograd.Function):
@ -706,17 +735,25 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, dim, process_group, grad_scale=None):
def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=False):
ctx.process_group = process_group
ctx.dim = dim
ctx.grad_scale = grad_scale
ctx.fp8_communication = fp8_communication
return _split(input_, dim, process_group)
@staticmethod
def backward(ctx, grad_output):
if ctx.grad_scale is not None:
grad_output = grad_output * ctx.grad_scale
return _gather(grad_output, ctx.dim, ctx.process_group), None, None, None
return (
_gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication, fp8_format="e5m2"),
None,
None,
None,
None,
)
class _ReduceForward(torch.autograd.Function):
@ -730,15 +767,15 @@ class _ReduceForward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, process_group, grad_scale=None):
def forward(ctx, input_, process_group, grad_scale=None, fp8_communication=False):
ctx.grad_scale = grad_scale
return _reduce(input_, process_group)
return _reduce(input_, process_group, fp8_communication, fp8_format="e4m3")
@staticmethod
def backward(ctx, grad_output):
if ctx.grad_scale is not None:
grad_output = grad_output * ctx.grad_scale
return grad_output, None, None
return grad_output, None, None, None
class _ReduceBackward(torch.autograd.Function):
@ -751,13 +788,15 @@ class _ReduceBackward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, process_group):
def forward(ctx, input_, process_group, fp8_communication=False):
ctx.process_group = process_group
ctx.fp8_communication = fp8_communication
return input_
@staticmethod
def backward(ctx, grad_output):
return _reduce(grad_output, ctx.process_group), None
fp8_communication = ctx.fp8_communication
return _reduce(grad_output, ctx.process_group, fp8_communication, fp8_format="e5m2"), None, None
class _GatherForwardSplitBackward(torch.autograd.Function):
@ -770,17 +809,18 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, dim, process_group, grad_scale=None):
def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=False):
ctx.process_group = process_group
ctx.dim = dim
ctx.grad_scale = grad_scale
return _gather(input_, dim, process_group)
return _gather(input_, dim, process_group, fp8_communication=fp8_communication, fp8_format="e4m3")
@staticmethod
def backward(ctx, grad_output):
if ctx.grad_scale is not None:
grad_output = grad_output * ctx.grad_scale
return _split(grad_output, ctx.dim, ctx.process_group), None, None, None
return _split(grad_output, ctx.dim, ctx.process_group), None, None, None, None
class _AllToAll(torch.autograd.Function):
@ -794,26 +834,67 @@ class _AllToAll(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, process_group, scatter_dim, gather_dim):
def forward(ctx, input_, process_group, scatter_dim, gather_dim, fp8_communication=False):
ctx.process_group = process_group
ctx.scatter_dim = scatter_dim
ctx.gather_dim = gather_dim
ctx.fp8_communication = fp8_communication
world_size = dist.get_world_size(process_group)
bsz, _, _ = input_.shape
# using all_to_all_single when batch size is 1
if bsz == 1:
return _all_to_all_single(input_, world_size, process_group, scatter_dim, gather_dim)
return _all_to_all_single(
input_,
world_size,
process_group,
scatter_dim,
gather_dim,
fp8_communication=fp8_communication,
fp8_format="e4m3",
)
else:
return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim)
return _all_to_all(
input_,
world_size,
process_group,
scatter_dim,
gather_dim,
fp8_communication=fp8_communication,
fp8_format="e4m3",
)
@staticmethod
def backward(ctx, *grad_output):
def backward(ctx, grad_output):
process_group = ctx.process_group
scatter_dim = ctx.gather_dim
gather_dim = ctx.scatter_dim
return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
return (return_grad, None, None, None)
fp8_communication = ctx.fp8_communication
world_size = dist.get_world_size(process_group)
bsz, _, _ = grad_output.shape
if bsz == 1:
return_grad = _all_to_all_single(
grad_output,
world_size,
process_group,
scatter_dim,
gather_dim,
fp8_communication=fp8_communication,
fp8_format="e5m2",
)
else:
return_grad = _all_to_all(
grad_output,
world_size,
process_group,
scatter_dim,
gather_dim,
fp8_communication=fp8_communication,
fp8_format="e5m2",
)
return (return_grad, None, None, None, None)
class HookParameter(torch.autograd.Function):
@ -839,12 +920,15 @@ def hook_parameter_in_backward(input, weight=None, bias=None):
return HookParameter.apply(input, weight, bias)
def _reduce(input_, process_group):
def _reduce(input_, process_group, fp8_communication=False, fp8_format="e5m2"):
# skip if only one rank involved
if dist.get_world_size(process_group) == 1:
return input_
else:
dist.all_reduce(input_, group=process_group)
if fp8_communication:
all_reduce_fp8(input_, group=process_group, fp8_format=fp8_format)
else:
dist.all_reduce(input_, group=process_group)
return input_
@ -868,18 +952,19 @@ def _split(input_, dim=-1, process_group=None):
return output
def _gather(input_, dim=-1, process_group=None):
def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_format="e5m2"):
# skip if only one rank involved
world_size = dist.get_world_size(process_group)
if world_size == 1:
return input_
# all gather
input_ = input_.contiguous()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
torch.distributed.all_gather(tensor_list, input_, group=process_group)
if fp8_communication:
gather_fp8(tensor_list, input_, fp8_format=fp8_format, group=process_group)
else:
dist.all_gather(tensor_list, input_, group=process_group)
# concat
output = torch.cat(tensor_list, dim=dim).contiguous()
return output
@ -909,14 +994,19 @@ def _reduce_scatter(input_, dim=1, process_group=None):
return output
def _all_to_all(input_, world_size, group, scatter_dim, gather_dim):
def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"):
input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
dist.all_to_all(output_list, input_list, group=group)
if fp8_communication:
all_to_all_fp8(output_list, input_list, group=group, fp8_format=fp8_format)
else:
dist.all_to_all(output_list, input_list, group=group)
return torch.cat(output_list, dim=gather_dim).contiguous()
def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
def _all_to_all_single(
input_, seq_world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"
):
inp_shape = list(input_.shape)
inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
if scatter_dim < 2:
@ -929,7 +1019,11 @@ def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
)
output = torch.empty_like(input_t)
dist.all_to_all_single(output, input_t, group=group)
if fp8_communication:
all_to_all_single_fp8(output, input_t, group=group, fp8_format=fp8_format)
else:
dist.all_to_all_single(output, input_t, group=group)
if scatter_dim < 2:
output = output.transpose(0, 1).contiguous()
@ -943,12 +1037,16 @@ def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
).contiguous()
def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
return MatmulWithAsyncCommunication.apply(
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication
)
def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
return LinearWithAsyncCommunication.apply(
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication
)
def linear_gather_forward_reducescatter_backward(
@ -959,12 +1057,12 @@ def linear_gather_forward_reducescatter_backward(
)
def gather_forward_reducescatter_backward(input_, process_group, dim):
return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim)
def gather_forward_reducescatter_backward(input_, process_group, dim, fp8_communication=False):
return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim, fp8_communication)
def reducescatter_forward_gather_backward(input_, process_group, dim):
return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim)
def reducescatter_forward_gather_backward(input_, process_group, dim, fp8_communication=False):
return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim, fp8_communication)
def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False):
@ -972,38 +1070,40 @@ def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, proc
def matmul_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False, fp8_communication=False
):
return _MatmulWithGatherForwardReduceScatterBackward.apply(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
)
def gather_forward_split_backward(input_, dim, process_group, grad_scale=None):
return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale)
def gather_forward_split_backward(input_, dim, process_group, grad_scale=None, fp8_communication=False):
return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale, fp8_communication)
def split_forward_gather_backward(input_, dim, process_group, grad_scale=None):
return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale)
def split_forward_gather_backward(input_, dim, process_group, grad_scale=None, fp8_communication=False):
return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale, fp8_communication)
def reduce_forward(input_, process_group, grad_scale=None):
return _ReduceForward.apply(input_, process_group, grad_scale)
def reduce_forward(input_, process_group, grad_scale=None, fp8_communication=False):
return _ReduceForward.apply(input_, process_group, grad_scale, fp8_communication)
def reduce_backward(input_, process_group):
return _ReduceBackward.apply(input_, process_group)
def reduce_backward(input_, process_group, fp8_communication=False):
return _ReduceBackward.apply(input_, process_group, fp8_communication)
def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_communication=False):
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication)
def gather_sp_output(hidden_states, sp_group, sp_mode):
def gather_sp_output(hidden_states, sp_group, sp_mode, fp8_communication=False):
"""
Gather the output of the last layer for cross entropy computation
"""
# Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group)
scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group)
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=scale)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, grad_scale=scale, fp8_communication=fp8_communication
)
return hidden_states
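
Throughout this file, forward-path collectives quantize their payload with fp8_format="e4m3" while backward-path gradient collectives use "e5m2", trading mantissa bits for the wider dynamic range gradients need. The sketch below shows the cast-communicate-dequantize idea behind such collectives; it assumes torch >= 2.1 float8 dtypes, and the helper names are illustrative rather than the library's all_reduce_fp8 implementation.

import torch
import torch.distributed as dist

def _quantize_fp8(t: torch.Tensor, fp8_format: str = "e4m3"):
    # Per-tensor scaling into the chosen float8 format.
    fp8_dtype = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    scale = torch.finfo(fp8_dtype).max / t.abs().max().clamp(min=1e-12)
    return (t * scale).to(fp8_dtype), scale.reshape(1)

def _dequantize_fp8(q: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    return q.to(dtype) / scale

def sketch_all_reduce_fp8(t: torch.Tensor, group=None, fp8_format: str = "e5m2") -> None:
    # NCCL cannot reduce float8 tensors directly, so each rank gathers the
    # quantized payloads (viewed as bytes) plus their scales and sums them
    # locally in the original precision.
    world_size = dist.get_world_size(group)
    q, scale = _quantize_fp8(t, fp8_format)
    payloads = [torch.empty_like(q.view(torch.uint8)) for _ in range(world_size)]
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(payloads, q.view(torch.uint8), group=group)
    dist.all_gather(scales, scale, group=group)
    t.copy_(sum(_dequantize_fp8(p.view(q.dtype), s, t.dtype) for p, s in zip(payloads, scales)))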

View File

@ -68,6 +68,7 @@ class Embedding1D(ParallelModule):
gather_output: bool = True,
weight: Optional[nn.Parameter] = None,
weight_initializer: Callable = init.normal_(),
fp8_communication: bool = False,
*args,
**kwargs,
):
@ -81,6 +82,7 @@ class Embedding1D(ParallelModule):
self.embed_args = args
self.embed_kwargs = kwargs
self.gather_output = gather_output
self.fp8_communication = fp8_communication
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
@ -155,7 +157,9 @@ class Embedding1D(ParallelModule):
def forward(self, input_: Tensor) -> Tensor:
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
if self.gather_output:
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
output = gather_forward_split_backward(
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
)
return output
else:
return output_parallel
@ -274,6 +278,7 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
weight: Optional[nn.Parameter] = None,
weight_initializer: Callable = init.normal_(),
make_vocab_size_divisible_by: int = 64,
fp8_communication: bool = False,
*args,
**kwargs,
):
@ -282,6 +287,7 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
self.embed_args = args
self.embed_kwargs = kwargs
self.process_group = process_group
self.fp8_communication = fp8_communication
tensor_parallel_size = dist.get_world_size(group=process_group)
tensor_parallel_rank = dist.get_rank(group=process_group)
@ -390,5 +396,5 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
embedding_output = output_parallel.clone()
embedding_output[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = reduce_forward(embedding_output, self.process_group)
output = reduce_forward(embedding_output, self.process_group, fp8_communication=self.fp8_communication)
return output
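
For orientation: reduce_forward all-reduces the partial embedding outputs in the forward pass and lets gradients flow through untouched, while reduce_backward is the mirror image used for column-parallel inputs. A hedged sketch, with the process group supplied by the caller:

import torch
from colossalai.shardformer.layer._operation import reduce_forward

def vocab_parallel_embed(embedding_output: torch.Tensor, tp_group):
    # Each rank holds a vocabulary shard, so the partial embeddings are summed
    # across ranks here; the backward pass is an identity for this op.
    return reduce_forward(embedding_output, tp_group, fp8_communication=True)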

View File

@ -84,6 +84,7 @@ class Linear1D_Col(ParallelModule):
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
fp8_communication: bool = False,
**kwargs,
):
super().__init__(weight=weight, bias_=bias_, **kwargs)
@ -98,6 +99,7 @@ class Linear1D_Col(ParallelModule):
self.skip_bias_add = skip_bias_add
self.device = device
self.process_group = process_group
self.fp8_communication = fp8_communication
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
@ -202,19 +204,25 @@ class Linear1D_Col(ParallelModule):
if self.seq_parallel_mode == "split_gather":
input_parallel = gather_forward_reducescatter_backward(
input_parallel, self.process_group, self.seq_parallel_dim
input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
)
output_parallel = linear_with_async_comm(
input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication
)
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False)
elif self.seq_parallel_mode == "ring":
output_parallel = linear_gather_forward_reducescatter_backward(
input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
)
else:
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
output_parallel = linear_with_async_comm(
input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
)
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
output = gather_forward_split_backward(
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
)
else:
output = output_parallel
@ -264,6 +272,7 @@ class Linear1D_Row(ParallelModule):
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
stream_chunk_num: int = 1,
fp8_communication: bool = False,
):
super().__init__()
@ -278,6 +287,7 @@ class Linear1D_Row(ParallelModule):
self.seq_parallel_mode = seq_parallel_mode
self.seq_parallel_dim = seq_parallel_dim
self.num_partitions = dist.get_world_size(self.process_group)
self.fp8_communication = fp8_communication
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
@ -398,7 +408,9 @@ class Linear1D_Row(ParallelModule):
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions
)
input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)
input_ = split_forward_gather_backward(
input_, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
)
if self.stream_chunk_num > 1:
if self.training:
@ -416,10 +428,13 @@ class Linear1D_Row(ParallelModule):
handle.wait()
output = torch.cat(output_parallel_list, dim=-1)
else:
if self.seq_parallel_mode == "split_gather":
if self.seq_parallel_mode is None:
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
elif self.seq_parallel_mode == "split_gather":
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
output = reducescatter_forward_gather_backward(
output_parallel, self.process_group, self.seq_parallel_dim
output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
)
elif self.seq_parallel_mode == "ring":
output = linear_reducescatter_forward_gather_backward(
@ -562,6 +577,7 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
make_vocab_size_divisible_by: int = 64,
fp8_communication: bool = False,
**kwargs,
):
# create weight and bias
@ -592,6 +608,7 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
**kwargs,
new_num_embeddings=new_out_features,
old_num_embeddings=out_features,
fp8_communication=fp8_communication,
)
# get the length of valid embeddings
tp_rank = dist.get_rank(process_group)
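
The same fp8_communication flag is accepted by from_native_module, the entry point the MoE policies later in this PR use. A hedged usage sketch (layer sizes and the helper name are illustrative):

import torch.nn as nn
import torch.distributed as dist
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row

def build_tp_mlp(tp_group: dist.ProcessGroup) -> nn.Sequential:
    # Column-parallel up-projection followed by row-parallel down-projection;
    # both route their gather/reduce collectives through the fp8 path.
    up = Linear1D_Col.from_native_module(nn.Linear(1024, 4096), tp_group, fp8_communication=True)
    down = Linear1D_Row.from_native_module(nn.Linear(4096, 1024), tp_group, fp8_communication=True)
    return nn.Sequential(up, nn.GELU(), down)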

View File

@ -183,6 +183,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
fp8_communication: bool = False,
):
super().__init__()
@ -197,6 +198,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self.n_fused = n_fused
self.process_group = process_group
self.async_communication = async_communication
self.fp8_communication = fp8_communication
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
@ -314,14 +316,26 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
if self.seq_parallel_mode is None:
# Set up backprop all-reduce.
input_parallel = reduce_backward(input_, self.process_group)
input_parallel = reduce_backward(input_, self.process_group, fp8_communication=self.fp8_communication)
output_parallel = matmul_with_async_comm(
input_parallel, self.weight, bias, self.process_group, self.async_communication
input_parallel,
self.weight,
bias,
self.process_group,
self.async_communication,
fp8_communication=self.fp8_communication,
)
elif self.seq_parallel_mode == "split_gather":
input_parallel = input_
output_parallel = matmul_gather_forward_reducescatter_backward(
input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap
input_parallel,
self.weight,
bias,
self.process_group,
True,
1,
self.overlap,
fp8_communication=self.fp8_communication,
)
elif self.seq_parallel_mode == "ring":
input_parallel = input_
@ -331,7 +345,9 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
output = gather_forward_split_backward(
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
)
else:
output = output_parallel
@ -379,6 +395,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
stream_chunk_num: int = 1,
fp8_communication: bool = False,
):
super().__init__()
@ -392,6 +409,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
self.process_group = process_group
self.seq_parallel_mode = seq_parallel_mode
self.num_partitions = dist.get_world_size(self.process_group)
self.fp8_communication = fp8_communication
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
@ -514,7 +532,9 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
input_.shape, self.weight.shape, self.weight.shape[0] * self.num_partitions
)
input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)
input_ = split_forward_gather_backward(
input_, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
)
if self.stream_chunk_num > 1:
if self.training:
@ -535,13 +555,20 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
else:
if self.seq_parallel_mode is None:
output_parallel = torch.matmul(input_, self.weight)
output = reduce_forward(output_parallel, self.process_group)
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
elif self.seq_parallel_mode == "split_gather":
output_parallel = torch.matmul(input_, self.weight)
output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
output = reducescatter_forward_gather_backward(
output_parallel,
self.process_group,
1,
self.fp8_communication,
)
elif self.seq_parallel_mode == "ring":
output_parallel = torch.matmul(input_, self.weight)
output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
output = reducescatter_forward_gather_backward(
output_parallel, self.process_group, 1, self.fp8_communication
)
if not self.skip_bias_add:
if self.bias is not None:
@ -600,6 +627,7 @@ class FusedLinear1D_Col(ParallelModule):
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
fp8_communication: bool = False,
):
super().__init__()
# Keep input parameters
@ -611,6 +639,7 @@ class FusedLinear1D_Col(ParallelModule):
self.n_fused = n_fused
self.process_group = process_group
self.async_communication = async_communication
self.fp8_communication = fp8_communication
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
@ -740,7 +769,9 @@ class FusedLinear1D_Col(ParallelModule):
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
output = gather_forward_split_backward(
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
)
else:
output = output_parallel

View File

@ -187,11 +187,17 @@ class BertPipelineForwards:
if shard_config is not None and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = split_forward_gather_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
if encoder_hidden_states is not None:
encoder_hidden_states = split_forward_gather_backward(
encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
encoder_hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx):
@ -242,7 +248,10 @@ class BertPipelineForwards:
if shard_config is not None and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = gather_forward_split_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
if output_hidden_states:
@ -1135,11 +1144,17 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
embedding_output = split_forward_gather_backward(
embedding_output, dim=1, process_group=shard_config.tensor_parallel_process_group
embedding_output,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
if encoder_hidden_states is not None:
encoder_hidden_states = split_forward_gather_backward(
encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
encoder_hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
encoder_outputs = self.encoder(
@ -1159,7 +1174,10 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
# When sequence parallelism is done, gather the output tensor in forward and split it in backward
sequence_output = gather_forward_split_backward(
sequence_output, dim=1, process_group=shard_config.tensor_parallel_process_group
sequence_output,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
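
In this split_gather mode, activations are split along the sequence dimension after the embedding and gathered back before the pooler, so each tensor-parallel rank runs the encoder on seq_len/TP_size tokens. A shape-level sketch, assuming an initialized tensor-parallel group and illustrative names:

import torch
from colossalai.shardformer.layer._operation import (
    gather_forward_split_backward,
    split_forward_gather_backward,
)

def sequence_parallel_encode(hidden_states: torch.Tensor, encoder, tp_group):
    # hidden_states: [batch, seq_len, hidden]; with TP size 4 the shard is
    # [batch, seq_len // 4, hidden] between the split and the gather.
    hidden_states = split_forward_gather_backward(
        hidden_states, dim=1, process_group=tp_group, fp8_communication=True
    )
    hidden_states = encoder(hidden_states)
    return gather_forward_split_backward(
        hidden_states, dim=1, process_group=tp_group, fp8_communication=True
    )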

View File

@ -221,7 +221,10 @@ class BloomPipelineForwards:
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = split_forward_gather_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
start_idx, end_idx = stage_index[0], stage_index[1]
@ -264,7 +267,10 @@ class BloomPipelineForwards:
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = gather_forward_split_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
if stage_manager.is_last_stage():
@ -922,7 +928,10 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
hidden_states = split_forward_gather_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@ -960,7 +969,10 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
# When sequence parallelism is done, gather the output tensor in forward and split it in backward
hidden_states = gather_forward_split_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
# Add last hidden state
hidden_states = self.ln_f(hidden_states)

View File

@ -206,6 +206,15 @@ class ChatGLMPipelineForwards:
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = split_forward_gather_backward(
hidden_states,
dim=0,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=1 / shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = split_forward_gather_backward(
@ -245,6 +254,15 @@ class ChatGLMPipelineForwards:
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
hidden_states,
dim=0,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
@ -401,6 +419,12 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
if sp_mode in ["all_to_all"] and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with sp mode `{sp_mode}`. Setting `use_cache=False`..."
)
use_cache = False
if sp_mode in ["all_to_all"] and self.training:
if use_cache:
logger.warning_once(
@ -414,6 +438,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
inputs_embeds,
dim=0,
process_group=sp_group,
fp8_communication=shard_config.fp8_communication,
)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(
@ -421,6 +446,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
dim=0,
process_group=sp_group,
grad_scale=1 / sp_size,
fp8_communication=shard_config.fp8_communication,
)
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
inputs_embeds,
@ -436,6 +462,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
@ -443,6 +470,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
dim=0,
process_group=sp_group,
grad_scale=sp_size,
fp8_communication=shard_config.fp8_communication,
)
if not return_dict:
@ -532,9 +560,24 @@ def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, s
key_layer = key_layer.reshape(sq, bs, -1)
value_layer = value_layer.reshape(sq, bs, -1)
query_layer = all_to_all_comm(query_layer, sp_group, gather_dim=0)
key_layer = all_to_all_comm(key_layer, sp_group, gather_dim=0)
value_layer = all_to_all_comm(value_layer, sp_group, gather_dim=0)
query_layer = all_to_all_comm(
query_layer,
sp_group,
gather_dim=0,
fp8_communication=shard_config.fp8_communication,
)
key_layer = all_to_all_comm(
key_layer,
sp_group,
gather_dim=0,
fp8_communication=shard_config.fp8_communication,
)
value_layer = all_to_all_comm(
value_layer,
sp_group,
gather_dim=0,
fp8_communication=shard_config.fp8_communication,
)
query_layer = query_layer.view(
sq * sp_size,
@ -610,7 +653,13 @@ def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, s
context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
if sp_mode == "all_to_all":
context_layer = all_to_all_comm(context_layer, sp_group, gather_dim=2, scatter_dim=0)
context_layer = all_to_all_comm(
context_layer,
sp_group,
gather_dim=2,
scatter_dim=0,
fp8_communication=shard_config.fp8_communication,
)
# =================
# Output. [sq, b, h]

View File

@ -142,6 +142,7 @@ class CommandPipelineForwards:
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = split_forward_gather_backward(
@ -149,6 +150,7 @@ class CommandPipelineForwards:
dim=1,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=1 / shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
# decoder layers
@ -213,6 +215,7 @@ class CommandPipelineForwards:
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
@ -220,6 +223,7 @@ class CommandPipelineForwards:
dim=1,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
# add hidden states from the last decoder layer
@ -384,9 +388,9 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
# sp: all-to-all communication when introducing sequence parallelism
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@ -448,7 +452,9 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
# sp: all-to-all communication when introducing sequence parallelism
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
attn_output = all_to_all_comm(
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@ -528,9 +534,13 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
)
hidden_states = inputs_embeds
# decoder layers
@ -575,9 +585,13 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
)
# add hidden states from the last decoder layer
if output_hidden_states:
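
In the all_to_all mode above, q/k/v arrive sequence-sharded and the all-to-all re-shards them by attention heads so full-sequence attention can run locally; the reverse all-to-all afterwards restores the sequence shard. A shape-level sketch with assumed names:

from colossalai.shardformer.layer._operation import all_to_all_comm

def reshard_for_attention(q, sp_group):
    # q: [bsz, seq_len // sp_size, num_heads * head_dim]
    # -> [bsz, seq_len, (num_heads // sp_size) * head_dim]
    return all_to_all_comm(q, sp_group, scatter_dim=2, gather_dim=1, fp8_communication=True)

def reshard_after_attention(attn_output, sp_group):
    # attn_output: [bsz, seq_len, (num_heads // sp_size) * head_dim]
    # -> [bsz, seq_len // sp_size, num_heads * head_dim]
    return all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=True)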

View File

@ -24,6 +24,7 @@ from colossalai.moe._operation import (
all_to_all_uneven,
)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization.fp8 import all_reduce_fp8
from colossalai.shardformer.layer._operation import (
all_to_all_comm,
gather_forward_split_backward,
@ -61,7 +62,13 @@ class EPDeepseekMoE(nn.Module):
def __init__(self):
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
def setup_process_groups(
self,
tp_group: ProcessGroup,
moe_dp_group: ProcessGroup,
ep_group: ProcessGroup,
fp8_communication: bool = False,
):
assert tp_group is not None
assert moe_dp_group is not None
assert ep_group is not None
@ -70,6 +77,7 @@ class EPDeepseekMoE(nn.Module):
self.ep_rank = dist.get_rank(ep_group)
self.num_experts = self.config.n_routed_experts
assert self.num_experts % self.ep_size == 0
self.fp8_communication = fp8_communication
self.ep_group = ep_group
self.num_experts_per_ep = self.num_experts // self.ep_size
@ -86,9 +94,15 @@ class EPDeepseekMoE(nn.Module):
self.tp_group = tp_group
if self.tp_group.size() > 1:
for expert in held_experts:
expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.tp_group)
expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.tp_group)
expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.tp_group)
expert.gate_proj = Linear1D_Col.from_native_module(
expert.gate_proj, self.tp_group, fp8_communication=self.fp8_communication
)
expert.up_proj = Linear1D_Col.from_native_module(
expert.up_proj, self.tp_group, fp8_communication=self.fp8_communication
)
expert.down_proj = Linear1D_Row.from_native_module(
expert.down_proj, self.tp_group, fp8_communication=self.fp8_communication
)
for p in self.experts.parameters():
set_moe_tensor_ep_group(p, ep_group)
@ -106,7 +120,8 @@ class EPDeepseekMoE(nn.Module):
if module.__class__.__name__ == "DeepseekMLP":
return module
module.__class__ = EPDeepseekMoE
module.setup_process_groups(tp_group, moe_dp_group, ep_group)
fp8_communication = kwargs.get("fp8_communication", False)
module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication=fp8_communication)
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@ -130,18 +145,32 @@ class EPDeepseekMoE(nn.Module):
output_split_sizes = torch.zeros_like(input_split_sizes)
# [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3]
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
dist.all_to_all_single(
output_split_sizes,
input_split_sizes,
group=self.ep_group,
)
with torch.no_grad():
activate_experts = output_split_sizes[: self.num_experts_per_ep].clone()
for i in range(1, self.ep_size):
activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep]
activate_experts = (activate_experts > 0).float()
dist.all_reduce(activate_experts, group=self.moe_dp_group)
if self.fp8_communication:
all_reduce_fp8(activate_experts, group=self.moe_dp_group)
else:
dist.all_reduce(activate_experts, group=self.moe_dp_group)
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
output_states, _ = all_to_all_uneven(
dispatch_states,
input_split_list,
output_split_list,
self.ep_group,
fp8_communication=self.fp8_communication,
)
output_states = EPGradScalerIn.apply(output_states, self.ep_size)
if output_states.size(0) > 0:
@ -167,7 +196,9 @@ class EPDeepseekMoE(nn.Module):
output_states_list.append(split_states)
output_states = torch.cat(output_states_list)
output_states = EPGradScalerOut.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
dispatch_states, _ = all_to_all_uneven(
output_states, output_split_list, input_split_list, self.ep_group, fp8_communication=self.fp8_communication
)
recover_token_idx = torch.empty_like(flat_topk_token_idx)
recover_token_idx[flat_topk_token_idx] = torch.arange(
flat_topk_token_idx.size(0), device=flat_topk_token_idx.device
@ -534,9 +565,9 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
# sp: all-to-all communication when introducing sequence parallelism
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
@ -595,7 +626,9 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
# sp: all-to-all communication when introducing sequence parallelism
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() # (1, 8, 128)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) # (1, 4, 256)
attn_output = all_to_all_comm(
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
) # (1, 4, 256)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@ -669,6 +702,7 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
# TODO: upgrade transformers to 4.44.0 to fix the bug and remove this hardcoded workaround.
self._use_flash_attention_2 = shard_config.enable_flash_attention
self._use_sdpa = False if shard_config.enable_flash_attention else self._use_sdpa
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
@ -688,9 +722,13 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
)
if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
)
# embed positions
hidden_states = inputs_embeds
@ -734,9 +772,13 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
)
# add hidden states from the last decoder layer
if output_hidden_states:
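
The expert-parallel dispatch above proceeds in two steps: ranks exchange per-expert token counts with a plain all_to_all_single, then collapse those counts into per-rank split lists that drive all_to_all_uneven (optionally over fp8). A sketch of the split-list computation, mirroring the logic above with local helper names:

import torch
import torch.distributed as dist

def compute_ep_split_lists(selected_experts: torch.Tensor, num_experts: int, ep_size: int, ep_group):
    # Count how many tokens this rank routes to each expert, exchange the
    # counts, then sum counts per expert-parallel rank to obtain the uneven
    # split sizes consumed by all_to_all_uneven.
    num_experts_per_ep = num_experts // ep_size
    input_split_sizes = selected_experts.bincount(minlength=num_experts)
    output_split_sizes = torch.zeros_like(input_split_sizes)
    dist.all_to_all_single(output_split_sizes, input_split_sizes, group=ep_group)
    input_split_list = input_split_sizes.view(ep_size, num_experts_per_ep).sum(dim=-1).tolist()
    output_split_list = output_split_sizes.view(ep_size, num_experts_per_ep).sum(dim=-1).tolist()
    return input_split_list, output_split_list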

View File

@ -221,6 +221,7 @@ class GPT2PipelineForwards:
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
# Going through held blocks.
@ -276,6 +277,7 @@ class GPT2PipelineForwards:
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
if stage_manager.is_last_stage():
@ -1119,6 +1121,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
hidden_states,
dim=1,
process_group=shard_config.sequence_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@ -1186,6 +1189,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
hidden_states,
dim=1,
process_group=shard_config.sequence_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
hidden_states = self.ln_f(hidden_states)

View File

@ -185,6 +185,7 @@ class GPTJPipelineForwards:
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
# Going through held blocks.
@ -236,6 +237,7 @@ class GPTJPipelineForwards:
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
if stage_manager.is_last_stage():
@ -915,6 +917,7 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@ -978,6 +981,7 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
hidden_states = self.ln_f(hidden_states)

View File

@ -162,9 +162,13 @@ class LlamaPipelineForwards:
hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group)
elif is_share_sp_tp(sp_mode):
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
hidden_states = split_forward_gather_backward(
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
hidden_states = split_forward_gather_backward(
hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
)
if self.gradient_checkpointing and self.training and use_cache:
if use_cache:
@ -227,7 +231,9 @@ class LlamaPipelineForwards:
if stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states)
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode)
hidden_states = gather_sp_output(
hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication
)
# add hidden states from the last decoder layer
if output_hidden_states:
@ -532,9 +538,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
# sp: all-to-all communication when introducing sequence parallelism
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@ -605,7 +611,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
# sp: all-to-all communication when introducing sequence parallelism
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
attn_output = all_to_all_comm(
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@ -707,9 +715,13 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
attn_kwargs = {"attention_mask_type": attn_kwargs["attention_mask_type"]} # drop redundant tensors
elif is_share_sp_tp(sp_mode):
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
)
hidden_states = inputs_embeds
# decoder layers
@ -754,7 +766,9 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
hidden_states = self.norm(hidden_states)
# Cases that don't support parallelizing cross entropy computation along sequence
if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather:
hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode)
hidden_states = gather_sp_output(
hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication
)
# add hidden states from the last decoder layer
if output_hidden_states:

View File

@ -53,7 +53,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
def __init__(self, *args, **kwargs):
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
def setup_process_groups(
self,
tp_group: ProcessGroup,
moe_dp_group: ProcessGroup,
ep_group: ProcessGroup,
fp8_communication: bool = False,
):
assert tp_group is not None
assert moe_dp_group is not None
assert ep_group is not None
@ -62,6 +68,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
self.ep_size = dist.get_world_size(ep_group)
self.ep_rank = dist.get_rank(ep_group)
self.ep_group = ep_group
self.fp8_communication = fp8_communication
if self.num_experts % self.ep_size != 0:
raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
@ -80,9 +87,15 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
self.tp_group = tp_group
if self.tp_group.size() > 1:
for expert in held_experts:
expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.tp_group)
expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.tp_group)
expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.tp_group)
expert.w1 = Linear1D_Col.from_native_module(
expert.w1, self.tp_group, fp8_communication=self.fp8_communication
)
expert.w3 = Linear1D_Col.from_native_module(
expert.w3, self.tp_group, fp8_communication=self.fp8_communication
)
expert.w2 = Linear1D_Row.from_native_module(
expert.w2, self.tp_group, fp8_communication=self.fp8_communication
)
for p in self.experts.parameters():
set_moe_tensor_ep_group(p, ep_group)
@ -99,7 +112,8 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
# TODO: better init
LazyInitContext.materialize(module)
module.__class__ = EPMixtralSparseMoeBlock
module.setup_process_groups(tp_group, moe_dp_group, ep_group)
fp8_communication = kwargs.get("fp8_communication", False)
module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication)
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@ -120,6 +134,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
output_split_sizes = torch.zeros_like(input_split_sizes)
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
with torch.no_grad():
@ -132,7 +147,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
output_states, _ = all_to_all_uneven(
dispatch_states,
input_split_list,
output_split_list,
self.ep_group,
fp8_communication=self.fp8_communication,
)
# compute expert output
output_states = EPGradScalerIn.apply(output_states, self.ep_size)
if output_states.size(0) > 0:
@ -162,7 +183,9 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
output_states = torch.cat(output_states_list)
output_states = EPGradScalerOut.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
dispatch_states, _ = all_to_all_uneven(
output_states, output_split_list, input_split_list, self.ep_group, fp8_communication=self.fp8_communication
)
recover_experts_idx = torch.empty_like(selected_experts_idx)
recover_experts_idx[selected_experts_idx] = torch.arange(
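Reviewer note (not part of the diff): the expert-parallel dispatch above derives its uneven split sizes from the router output before calling all_to_all_uneven. A small standalone sketch of that bookkeeping; num_experts, ep_size, and the token count are made-up values.

# Illustrative sketch of how the split lists fed to all_to_all_uneven are computed.
import torch

num_experts, ep_size, num_tokens = 8, 2, 16
num_experts_per_ep = num_experts // ep_size
selected_experts = torch.randint(0, num_experts, (num_tokens,))       # routed expert id per token
input_split_sizes = selected_experts.bincount(minlength=num_experts)  # tokens per expert on this rank
# After dist.all_to_all_single exchanges these counts, each rank collapses the
# per-expert counts into per-EP-peer counts:
input_split_list = input_split_sizes.view(ep_size, num_experts_per_ep).sum(dim=-1).tolist()
print(input_split_list)  # how many tokens this rank sends to each EP peer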
@ -566,9 +589,9 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
# sp: all-to-all communication when introducing sequence parallel
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@ -673,7 +696,9 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
# sp: all-to-all communication when introducing sequence parallel
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() # (1, 8, 128)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) # (1, 4, 256)
attn_output = all_to_all_comm(
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
) # (1, 4, 256)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@ -780,9 +805,13 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
)
if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
)
hidden_states = inputs_embeds
# decoder layers
@ -831,9 +860,13 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
)
# add hidden states from the last decoder layer
if output_hidden_states:

View File

@ -175,6 +175,7 @@ class Qwen2PipelineForwards:
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = split_forward_gather_backward(
@ -182,6 +183,7 @@ class Qwen2PipelineForwards:
dim=1,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=1 / shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
# decoder layers
@ -246,6 +248,7 @@ class Qwen2PipelineForwards:
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
@ -253,6 +256,7 @@ class Qwen2PipelineForwards:
dim=1,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
# add hidden states from the last decoder layer
if output_hidden_states:
@ -516,9 +520,9 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
value_states = self.v_proj(hidden_states)
# sp: all-to-all communication when introducing sequence parallel
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@ -604,7 +608,9 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
attn_output = attn_output.transpose(1, 2).contiguous()
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
attn_output = all_to_all_comm(
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@ -702,9 +708,13 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
next_decoder_cache = None
if sp_mode in ["ring", "split_gather"]:
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
hidden_states = split_forward_gather_backward(
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
hidden_states = split_forward_gather_backward(
hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
)
for decoder_layer in self.layers:
if output_hidden_states:
@ -741,9 +751,13 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
)
# add hidden states from the last decoder layer
if output_hidden_states:

View File

@ -98,6 +98,7 @@ class BertPolicy(Policy):
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
@ -106,6 +107,7 @@ class BertPolicy(Policy):
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
@ -114,6 +116,7 @@ class BertPolicy(Policy):
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
@ -123,7 +126,10 @@ class BertPolicy(Policy):
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel_mode": sp_mode},
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
@ -136,12 +142,16 @@ class BertPolicy(Policy):
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel_mode": sp_mode},
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="output.dropout",
@ -180,6 +190,13 @@ class BertPolicy(Policy):
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=embedding_cls,
kwargs=(
{
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {}
),
)
],
policy=policy,
@ -249,6 +266,7 @@ class BertPolicy(Policy):
kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
),
policy=base_policy,

View File

@ -72,20 +72,30 @@ class BlipPolicy(Policy):
target_module=col_nn.FusedLinear1D_Col,
kwargs={
"n_fused": 3,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.projection",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.fc1",
target_module=col_nn.Linear1D_Col,
kwargs={"skip_bias_add": self.enable_bias_gelu_fused},
kwargs={
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.fc2",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)
@ -114,14 +124,23 @@ class BlipPolicy(Policy):
SubModuleReplacementDescription(
suffix="attention.attention.query",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attention.attention.key",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attention.attention.value",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attention.attention.dropout",
@ -130,6 +149,9 @@ class BlipPolicy(Policy):
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
@ -138,14 +160,23 @@ class BlipPolicy(Policy):
SubModuleReplacementDescription(
suffix="crossattention.attention.query",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="crossattention.attention.key",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="crossattention.attention.value",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="crossattention.attention.dropout",
@ -154,6 +185,9 @@ class BlipPolicy(Policy):
SubModuleReplacementDescription(
suffix="crossattention.output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="crossattention.output.dropout",
@ -162,10 +196,16 @@ class BlipPolicy(Policy):
SubModuleReplacementDescription(
suffix="intermediate_query.dense",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="output_query.dense",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="output_query.dropout",
@ -185,26 +225,44 @@ class BlipPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="fc1",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="fc2",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)
@ -225,7 +283,14 @@ class BlipPolicy(Policy):
SubModuleReplacementDescription(
suffix="model.decoder.embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
],
policy=policy,
@ -241,6 +306,7 @@ class BlipPolicy(Policy):
kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
),
],

View File

@ -76,12 +76,19 @@ class BloomPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap},
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel_mode": sp_mode},
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",
@ -90,12 +97,19 @@ class BloomPolicy(Policy):
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap},
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel_mode": sp_mode},
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)
@ -115,7 +129,14 @@ class BloomPolicy(Policy):
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
],
policy=policy,
@ -279,6 +300,7 @@ class BloomForCausalLMPolicy(BloomPolicy):
kwargs=dict(
gather_output=not self.shard_config.parallel_output,
make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
fp8_communication=self.shard_config.fp8_communication,
),
),
policy=policy,
@ -337,7 +359,9 @@ class BloomForSequenceClassificationPolicy(BloomPolicy):
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
suffix="score",
target_module=col_nn.Linear1D_Col,
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
),
policy=policy,
target_key=BloomForSequenceClassification,
@ -374,7 +398,9 @@ class BloomForTokenClassificationPolicy(BloomPolicy):
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
suffix="classifier",
target_module=col_nn.Linear1D_Col,
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="dropout",

View File

@ -128,12 +128,17 @@ class ChatGLMPolicy(Policy):
"seq_parallel_mode": sp_mode,
"seq_parallel_dim": 0,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0},
kwargs={
"seq_parallel_mode": sp_mode,
"seq_parallel_dim": 0,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attention.core_attention.attention_dropout",
@ -148,7 +153,14 @@ class ChatGLMPolicy(Policy):
SubModuleReplacementDescription(
suffix="embedding.word_embeddings",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
],
policy=policy,

View File

@ -128,37 +128,37 @@ class CommandPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
],
)
@ -168,7 +168,14 @@ class CommandPolicy(Policy):
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=CohereModel,
@ -306,6 +313,7 @@ class CommandForCausalLMPolicy(CommandPolicy):
kwargs={
"gather_output": not self.shard_config.parallel_output,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
)
],

View File

@ -118,18 +118,22 @@ class DeepseekPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
),
],
)
@ -138,7 +142,10 @@ class DeepseekPolicy(Policy):
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs={
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
),
policy=policy,
target_key="DeepseekModel",
@ -155,6 +162,7 @@ class DeepseekPolicy(Policy):
"ep_group": self.shard_config.ep_group,
"tp_group": self.shard_config.tensor_parallel_process_group,
"moe_dp_group": self.shard_config.moe_dp_group,
"fp8_communication": self.shard_config.fp8_communication,
},
)
],
@ -298,14 +306,14 @@ class DeepseekForCausalLMPolicy(DeepseekPolicy):
policy = super().module_policy()
# TODO: assign pg mesh from plugin to all modules
if self.shard_config.enable_tensor_parallelism:
# add a new item for causal lm
# add a new item for causal lm
new_item = {
"DeepseekForCausalLM": ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True),
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
)
]
)

View File

@ -105,7 +105,14 @@ class FalconPolicy(Policy):
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
],
policy=policy,

View File

@ -110,14 +110,13 @@ class GPT2Policy(Policy):
"n_fused": 3,
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attn.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={
"seq_parallel_mode": sp_mode,
},
kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="mlp.c_fc",
@ -127,14 +126,13 @@ class GPT2Policy(Policy):
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={
"seq_parallel_mode": sp_mode,
},
kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
@ -164,7 +162,14 @@ class GPT2Policy(Policy):
description=SubModuleReplacementDescription(
suffix="wte",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=GPT2Model,
@ -334,6 +339,7 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
kwargs={
"gather_output": False,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
)
],
@ -404,6 +410,7 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
)
]

View File

@ -77,6 +77,7 @@ class GPTJPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
@ -84,6 +85,7 @@ class GPTJPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
@ -91,19 +93,29 @@ class GPTJPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attn.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.fc_in",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.fc_out",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
@ -125,7 +137,14 @@ class GPTJPolicy(Policy):
description=SubModuleReplacementDescription(
suffix="wte",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=GPTJModel,
@ -264,6 +283,7 @@ class GPTJForCausalLMPolicy(GPTJPolicy):
kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
)
]

View File

@ -133,37 +133,37 @@ class LlamaPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
],
)
@ -173,7 +173,14 @@ class LlamaPolicy(Policy):
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=LlamaModel,
@ -316,6 +323,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
kwargs={
"gather_output": not self.shard_config.parallel_output,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
)
],
@ -384,7 +392,12 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
LlamaForSequenceClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
suffix="score",
target_module=Linear1D_Col,
kwargs=dict(
gather_output=True,
fp8_communication=self.shard_config.fp8_communication,
),
)
]
)

View File

@ -88,30 +88,51 @@ class MistralPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)
@ -121,7 +142,14 @@ class MistralPolicy(Policy):
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=MistralModel,
@ -281,6 +309,7 @@ class MistralForCausalLMPolicy(MistralPolicy):
kwargs={
"gather_output": not self.shard_config.parallel_output,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
)
]
@ -297,7 +326,9 @@ class MistralForCausalLMPolicy(MistralPolicy):
SubModuleReplacementDescription(
suffix="lm_head",
target_module=PaddingLMHead,
kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
kwargs=dict(
make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
),
)
]
)
@ -350,7 +381,9 @@ class MistralForSequenceClassificationPolicy(MistralPolicy):
MistralForSequenceClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
suffix="score",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
)
]
)

View File

@ -114,21 +114,27 @@ class MixtralPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription( # or replicate?
suffix="block_sparse_moe.gate", target_module=Linear1D_Col, kwargs={"gather_output": True}
suffix="block_sparse_moe.gate",
target_module=Linear1D_Col,
kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication},
),
],
)
@ -138,7 +144,14 @@ class MixtralPolicy(Policy):
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=MixtralModel,
@ -282,7 +295,7 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True),
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
)
]
)
@ -336,7 +349,9 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
MixtralForSequenceClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
suffix="score",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
)
]
)

View File

@ -102,18 +102,30 @@ class OPTPolicy(Policy):
SubModuleReplacementDescription(
suffix="q_proj",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="k_proj",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="v_proj",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="out_proj",
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)
@ -123,7 +135,14 @@ class OPTPolicy(Policy):
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=OPTDecoder,
@ -272,6 +291,7 @@ class OPTForCausalLMPolicy(OPTPolicy):
kwargs=dict(
gather_output=not self.shard_config.parallel_output,
make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
fp8_communication=self.shard_config.fp8_communication,
),
),
policy=policy,

View File

@ -119,37 +119,37 @@ class Qwen2Policy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
],
)
@ -159,7 +159,14 @@ class Qwen2Policy(Policy):
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=Qwen2Model,
@ -313,11 +320,15 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
setattr(self.shard_config, "causal_lm", True)
if self.shard_config.enable_tensor_parallelism:
# add a new item for causal lm
# add a new item for causal lm
new_item = {
Qwen2ForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col)
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(fp8_communication=self.shard_config.fp8_communication),
)
],
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
)
@ -366,7 +377,9 @@ class Qwen2ForSequenceClassificationPolicy(Qwen2Policy):
Qwen2ForSequenceClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
suffix="score",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
)
]
)

View File

@ -43,19 +43,29 @@ class SamPolicy(Policy):
target_module=col_nn.FusedLinear1D_Col,
kwargs={
"n_fused": 3,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attn.proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.lin1",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.lin2",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)
@ -68,58 +78,100 @@ class SamPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_token_to_image.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_token_to_image.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_token_to_image.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_token_to_image.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.lin1",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.lin2",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_image_to_token.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_image_to_token.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_image_to_token.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_image_to_token.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)
@ -132,18 +184,30 @@ class SamPolicy(Policy):
SubModuleReplacementDescription(
suffix="final_attn_token_to_image.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="final_attn_token_to_image.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="final_attn_token_to_image.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="final_attn_token_to_image.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)

View File

@ -117,23 +117,38 @@ class T5BasePolicy(Policy):
SubModuleReplacementDescription(
suffix="q",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="k",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="v",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="o",
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="relative_attention_bias",
target_module=Embedding1D,
kwargs=dict(gather_output=False),
kwargs=dict(
gather_output=False,
fp8_communication=self.shard_config.fp8_communication,
),
ignore_if_not_exist=True,
),
],
@ -151,13 +166,24 @@ class T5BasePolicy(Policy):
SubModuleReplacementDescription(
suffix="wi_0 ",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="wi_1",
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
suffix="wo",
target_module=Linear1D_Col,
kwargs=dict(
gather_output=True,
fp8_communication=self.shard_config.fp8_communication,
),
),
SubModuleReplacementDescription(
suffix="dropout",
@ -170,10 +196,16 @@ class T5BasePolicy(Policy):
SubModuleReplacementDescription(
suffix="wi",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="wo",
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="dropout",
@ -187,7 +219,14 @@ class T5BasePolicy(Policy):
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=T5Stack,
@ -407,7 +446,14 @@ class T5ModelPolicy(T5BasePolicy):
description=SubModuleReplacementDescription(
suffix="shared",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=T5Model,
@ -451,7 +497,14 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
description=SubModuleReplacementDescription(
suffix="shared",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=T5ForConditionalGeneration,
@ -465,6 +518,7 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
),
policy=policy,
@ -539,7 +593,14 @@ class T5EncoderPolicy(T5BasePolicy):
description=SubModuleReplacementDescription(
suffix="shared",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=T5EncoderModel,

View File

@ -70,14 +70,23 @@ class ViTPolicy(Policy):
SubModuleReplacementDescription(
suffix="attention.attention.query",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attention.attention.key",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attention.attention.value",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attention.attention.dropout",
@ -86,6 +95,9 @@ class ViTPolicy(Policy):
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
@ -96,11 +108,15 @@ class ViTPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="output.dropout",
@ -215,7 +231,9 @@ class ViTForImageClassificationPolicy(ViTPolicy):
ViTForImageClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="classifier", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
suffix="classifier",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
)
]
)

View File

@ -91,26 +91,44 @@ class WhisperPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="fc1",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="fc2",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)
@ -128,42 +146,72 @@ class WhisperPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="encoder_attn.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="encoder_attn.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="encoder_attn.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="encoder_attn.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="fc1",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="fc2",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)
@ -174,7 +222,14 @@ class WhisperPolicy(Policy):
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
],
policy=policy,
@ -303,6 +358,7 @@ class WhisperPolicy(Policy):
kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
),
policy=base_policy,

View File

@ -29,6 +29,7 @@ class ShardConfig:
enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlaps the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism. Defaults to False.
parallel_output (bool): For TP: whether to parallelize the cross entropy computation along the feature dim.
For SP: set to True to NOT gather the output along the seq dim.
"""
@ -54,6 +55,7 @@ class ShardConfig:
# for moe related
moe_dp_group: Optional[ProcessGroup] = None
ep_group: Optional[ProcessGroup] = None
fp8_communication: bool = False
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
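For orientation, a minimal sketch of switching the new field on when a ShardConfig is built by hand. Only enable_tensor_parallelism and fp8_communication appear in this diff; the tensor-parallel process group that a plugin would normally supply is omitted, so treat this as an illustration rather than verified code.

from colossalai.shardformer import ShardConfig, ShardFormer

# fp8_communication defaults to False; setting it True makes the Linear1D_Col/Row
# modules substituted by the policies above compress their collectives to fp8.
shard_config = ShardConfig(enable_tensor_parallelism=True, fp8_communication=True)
shardformer = ShardFormer(shard_config)
# sharded_model, _ = shardformer.optimize(model, policy)  # `model` and `policy` are placeholders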

View File

@ -61,6 +61,8 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
with torch._C.DisableTorchFunction():
new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
args, kwargs = replace_args(args, kwargs, new_args)
with torch._C.DisableTorchFunction():
func = ColoParamOpHookManager.rewrite_op(func)
ret = super().__torch_function__(func, types, args, kwargs)
with torch._C.DisableTorchFunction():
ret = ColoParamOpHookManager.post_op(params, ret)

View File

@ -30,6 +30,9 @@ class ColoParamOpHook(ABC):
def post_backward(self, params: List[torch.Tensor]) -> None:
pass
def rewrite_op(self, func) -> Any:
return func
class ColoParamOpHookManager:
"""
@ -101,6 +104,12 @@ class ColoParamOpHookManager:
def has_hook() -> bool:
return len(ColoParamOpHookManager.hooks) > 0
@staticmethod
def rewrite_op(func) -> Any:
for hook in ColoParamOpHookManager.hooks:
func = hook.rewrite_op(func)
return func
class PreFwdPostBwd(torch.autograd.Function):
@staticmethod
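rewrite_op lets a hook swap the operator itself before it is dispatched, which is how the fp8 GEMM path can be enabled without touching model code. A sketch of a custom hook built on this interface follows; the class is illustrative, not part of this PR, and it assumes the four pre/post forward/backward methods are the abstract interface. The rewrite only fires when an argument is a ColoParameter, since that is what routes the call through __torch_function__.

import torch.nn.functional as F
from colossalai.quantization.fp8 import linear_fp8
from colossalai.tensor.param_op_hook import ColoParamOpHook, ColoParamOpHookManager

class LinearToFP8Hook(ColoParamOpHook):
    # the parameter hooks are no-ops here; only the op rewrite matters for this sketch
    def pre_forward(self, params): pass
    def post_forward(self, params): pass
    def pre_backward(self, params): pass
    def post_backward(self, params): pass

    def rewrite_op(self, func):
        # route plain linear through the fp8 kernel added in this PR
        return linear_fp8 if func is F.linear else func

# usage sketch: `x` and the ColoParameter `w` are placeholders
# with ColoParamOpHookManager.use_hooks(LinearToFP8Hook()):
#     y = F.linear(x, w)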

View File

@ -166,6 +166,7 @@ class Chunk:
self.grad_chunk = None
# the async all-reduce/reduce-scatter work of this grad chunk (None means sync)
self.grad_reduce_work = None
self.fp8_communication = False
@property
def memory_usage(self) -> Dict[str, int]:
@ -521,9 +522,17 @@ class Chunk:
alloc_storage(self.cuda_global_chunk)
assert self.cuda_global_chunk.is_contiguous()
work = dist.all_gather_into_tensor(
self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op
)
if self.fp8_communication:
assert not async_op, "fp8 all-gather does not support async_op!"
from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8
work = all_gather_into_tensor_flat_fp8(
self.cuda_global_chunk, self.cuda_shard, self.cuda_global_chunk.shape, self.torch_pg
)
else:
work = dist.all_gather_into_tensor(
self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op
)
self.cuda_shard = None
self.is_gathered = True
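For reference, the flat fp8 all-gather used above takes the full output buffer, the local shard, the output shape, and the process group, mirroring the unit test added later in this PR. A hedged sketch, assuming an initialized default process group and a CUDA device:

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8

# `shard` is this rank's slice; the output buffer holds world_size * shard.numel() elements
shard = torch.rand(1024, device="cuda")
output = torch.empty(shard.numel() * dist.get_world_size(), device=shard.device, dtype=shard.dtype)
all_gather_into_tensor_flat_fp8(output, shard, output.shape, _get_default_group())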

View File

@ -26,6 +26,7 @@ class ChunkManager:
init_device: Optional[torch.device] = None,
reuse_fp16_chunk: bool = True,
max_prefetch: int = 0,
fp8_communication: bool = False,
) -> None:
self.device = init_device or get_accelerator().get_current_device()
self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
@ -44,6 +45,7 @@ class ChunkManager:
self.accumulating_grads = False
self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())
self._prefetch_stream = get_accelerator().Stream() if max_prefetch else None
self.fp8_communication = fp8_communication
def register_tensor(
self,
@ -101,6 +103,8 @@ class ChunkManager:
extra_dp_group=extra_dp_group,
**chunk_kwargs,
)
if self.fp8_communication:
chunk.fp8_communication = True
chunk_group.append(chunk)
chunk.append_tensor(tensor)

View File

@ -15,6 +15,7 @@ from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_
from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyTensor
from colossalai.logging import get_dist_logger
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.d_tensor import (
distribute_tensor,
@ -98,6 +99,8 @@ class GeminiDDP(ModelWrapper):
extra_dp_group: Optional[ProcessGroup] = None,
verbose: bool = False,
enable_async_reduce: bool = True,
fp8_communication: bool = False,
use_fp8: bool = False,
) -> None:
assert mixed_precision in (torch.float16, torch.bfloat16)
reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
@ -122,6 +125,8 @@ class GeminiDDP(ModelWrapper):
verbose=verbose,
max_prefetch=max_prefetch,
)
if fp8_communication:
self.chunk_manager.fp8_communication = True
self.gemini_manager = GeminiManager(
placement_policy,
self.chunk_manager,
@ -135,6 +140,9 @@ class GeminiDDP(ModelWrapper):
)
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
self.hooks = [self.param_op_hook]
if use_fp8:
self.hooks.append(FP8Hook())
self.fp32_params: List[torch.Tensor] = list()
self.fp16_params: List[ColoParameter] = list()
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
@ -307,7 +315,7 @@ class GeminiDDP(ModelWrapper):
outputs = self._inference_forward(*args, **kwargs)
else:
self.gemini_manager.pre_iter(*args)
with ColoParamOpHookManager.use_hooks(self.param_op_hook):
with ColoParamOpHookManager.use_hooks(*self.hooks):
outputs = self.module(*args, **kwargs)
if self.force_outputs_fp32:
@ -316,7 +324,7 @@ class GeminiDDP(ModelWrapper):
def _inference_forward(self, *args, **kwargs):
"""This function is only triggered for inference."""
fwd_ctx = ColoParamOpHookManager.use_hooks(self.param_op_hook)
fwd_ctx = ColoParamOpHookManager.use_hooks(*self.hooks)
if not self.scatter_after_inference:
# gather all chunks
for chunk in self.chunk_manager.get_chunks(self.fp16_params):
@ -369,7 +377,7 @@ class GeminiDDP(ModelWrapper):
def backward(self, loss: torch.Tensor):
self._pre_backward()
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(*self.hooks):
loss.backward()
self._post_backward()
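In practice these GeminiDDP switches are reached through GeminiPlugin and the booster, as the example scripts later in this PR do. A hedged sketch; `model` and `optimizer` are placeholders:

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

plugin = GeminiPlugin(
    initial_scale=2**5,
    fp8_communication=True,  # fp8-compressed chunk all-gather (Chunk/ChunkManager above)
    use_fp8=True,            # install FP8Hook so linear layers take the fp8 GEMM path
)
booster = Booster(plugin=plugin)
# model, optimizer, *_ = booster.boost(model, optimizer)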

View File

@ -4,6 +4,8 @@ import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8
class TensorBucket:
def __init__(self, size):
@ -61,11 +63,14 @@ class TensorBucket:
for old, new in zip(self._bucket, unflattened_tensor_list):
old.copy_(new)
def all_gather(self, group=None):
def all_gather(self, group=None, fp8_communication: bool = False):
flat = self.flatten()
buffers = [torch.empty_like(flat) for _ in range(dist.get_world_size(group))]
dist.all_gather(buffers, flat, group=group)
unflat_buffers = [self.unflatten(buffer) for buffer in buffers]
buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype)
if fp8_communication:
all_gather_into_tensor_flat_fp8(buffer, flat, output_shape=buffer.shape, group=group)
else:
dist.all_gather_into_tensor(buffer, flat, group=group)
unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(dist.get_world_size(group))]
# transpose the list of list
unflat_buffers = list(map(list, zip(*unflat_buffers)))
for unflat_shards, tensor in zip(unflat_buffers, self._bucket):

View File

@ -20,6 +20,7 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
)
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8, all_reduce_fp8, reduce_scatter_fp8
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
@ -86,6 +87,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
forced_dtype: Optional[torch.dtype] = None,
master_weights: bool = True, # master weights
overlap_allgather: bool = False,
fp8_communication: bool = False,
):
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
@ -127,6 +129,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self._overlap_allgather = overlap_allgather
self._reduce_bucket_size = reduce_bucket_size
self._communication_dtype = communication_dtype
self._fp8_communication = fp8_communication
# gradient clipping
self._clip_grad_norm = clip_grad_norm
@ -330,7 +333,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
flat_grads = flat_grads.to(self._communication_dtype)
if not self._partition_grads:
dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
if self._fp8_communication:
all_reduce_fp8(flat_grads, group=bucket_store.torch_pg)
else:
dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
if flat_grads.dtype != grad_dtype:
flat_grads = flat_grads.to(grad_dtype)
@ -340,7 +346,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
else:
flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size))
received_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg)
if self._fp8_communication:
reduce_scatter_fp8(
received_grad,
flat_grads_list,
group=bucket_store.torch_pg,
)
else:
dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg)
if received_grad.dtype != grad_dtype:
received_grad = received_grad.to(grad_dtype)
@ -562,18 +575,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
set_all_gather_handle(working_param, handle)
else:
if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
if self._fp8_communication:
all_gather_into_tensor_flat_fp8(
padded_working_param, param_to_gather, pg, fp8_format="e4m3"
)
else:
dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
continue
try:
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
except RuntimeError:
self.pg_to_tensor_bucket[pg].all_gather(pg)
self.pg_to_tensor_bucket[pg].all_gather(pg, fp8_communication=self._fp8_communication)
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
if not self._overlap_allgather:
for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
if not tensor_bucket.is_empty():
tensor_bucket.all_gather(pg)
tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""

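The optimizer-level flag can also be set directly, as the zero stage 1/2 test at the end of this PR does. A minimal sketch; the wrapped Adam and `model` are placeholders, and the import path is assumed:

import torch
from colossalai.zero import LowLevelZeroOptimizer

# fp8_communication makes the gradient all-reduce / reduce-scatter and the
# post-step parameter all-gather use the fp8 collectives imported above
optimizer = LowLevelZeroOptimizer(
    torch.optim.Adam(model.parameters(), lr=1e-3),
    overlap_communication=True,
    initial_scale=128,
    fp8_communication=True,
)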
View File

@ -179,7 +179,7 @@ def main():
"--plugin",
type=str,
default="torch_ddp",
choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel"],
choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel", "torch_fsdp"],
help="plugin to use",
)
parser.add_argument(
@ -190,6 +190,7 @@ def main():
)
parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context")
parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication")
args = parser.parse_args()
if args.model_type == "bert":
@ -214,9 +215,9 @@ def main():
if args.plugin == "torch_ddp_fp16":
booster_kwargs["mixed_precision"] = "fp16"
if args.plugin.startswith("torch_ddp"):
plugin = TorchDDPPlugin()
plugin = TorchDDPPlugin(fp8_communication=args.use_fp8_comm)
elif args.plugin == "gemini":
plugin = GeminiPlugin(initial_scale=2**5)
plugin = GeminiPlugin(initial_scale=2**5, fp8_communication=args.use_fp8_comm)
elif args.plugin == "low_level_zero":
plugin = LowLevelZeroPlugin(initial_scale=2**5)
elif args.plugin == "hybrid_parallel":
@ -232,6 +233,18 @@ def main():
zero_stage=1,
precision="fp16",
initial_scale=1,
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "torch_fsdp":
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
from colossalai.booster.plugin import TorchFSDPPlugin
plugin = TorchFSDPPlugin(
mixed_precision=MixedPrecision(
param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
),
fp8_communication=args.use_fp8_comm,
)
booster = Booster(plugin=plugin, **booster_kwargs)

View File

@ -188,6 +188,8 @@ def main():
help="only gpt2 now",
)
parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context")
parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication")
args = parser.parse_args()
if args.model_type == "gpt2":
@ -210,7 +212,7 @@ def main():
if args.plugin == "torch_ddp_fp16":
booster_kwargs["mixed_precision"] = "fp16"
if args.plugin.startswith("torch_ddp"):
plugin = TorchDDPPlugin()
plugin = TorchDDPPlugin(fp8_communication=args.use_fp8_comm)
elif args.plugin == "gemini":
plugin = GeminiPlugin(initial_scale=2**5)
elif args.plugin == "low_level_zero":
@ -226,6 +228,7 @@ def main():
zero_stage=1,
precision="fp16",
initial_scale=1,
fp8_communication=args.use_fp8_comm,
)
booster = Booster(plugin=plugin, **booster_kwargs)

View File

@ -104,6 +104,8 @@ def main():
parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
parser.add_argument("--use_fp8", action="store_true")
parser.add_argument("--overlap_allgather", action="store_true")
parser.add_argument(
"--sp_mode",
@ -148,6 +150,7 @@ def main():
enable_flash_attention=args.xformers,
max_prefetch=args.prefetch_num,
enable_async_reduce=not args.disable_async_reduce,
use_fp8=args.use_fp8,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
@ -160,6 +163,7 @@ def main():
max_prefetch=args.prefetch_num,
enable_async_reduce=not args.disable_async_reduce,
enable_flash_attention=args.xformers,
use_fp8=args.use_fp8,
)
elif args.plugin == "fsdp":
if use_empty_init:
@ -170,6 +174,7 @@ def main():
buffer_dtype=torch.float16,
),
param_init_fn=empty_init(),
fp8_communication=args.use_fp8_comm,
)
else:
plugin = TorchFSDPPlugin(
@ -177,7 +182,8 @@ def main():
param_dtype=torch.float16,
reduce_dtype=torch.float16,
buffer_dtype=torch.float16,
)
),
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "fsdp_cpu":
if use_empty_init:
@ -189,6 +195,7 @@ def main():
),
cpu_offload=CPUOffload(offload_params=True),
param_init_fn=empty_init(),
fp8_communication=args.use_fp8_comm,
)
else:
plugin = TorchFSDPPlugin(
@ -198,6 +205,7 @@ def main():
buffer_dtype=torch.float16,
),
cpu_offload=CPUOffload(offload_params=True),
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "3d":
plugin = HybridParallelPlugin(
@ -215,6 +223,7 @@ def main():
precision="bf16",
enable_metadata_cache=not args.no_cache,
overlap_allgather=args.overlap_allgather,
use_fp8=args.use_fp8,
**hybrid_kwargs,
)
elif args.plugin == "3d_cpu":
@ -230,6 +239,8 @@ def main():
microbatch_size=args.mbs,
initial_scale=2**8,
precision="bf16",
overlap_p2p=args.overlap,
use_fp8=args.use_fp8,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
@ -259,7 +270,6 @@ def main():
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
else nullcontext()
)
init_kwargs = {}
if config.model_type == "chatglm":
init_kwargs["empty_init"] = False

View File

@ -9,7 +9,7 @@ torchx-nightly==2022.6.29 # torchrec 0.2.0 requires torchx-nightly. This package
torchrec==0.2.0
contexttimer
einops
triton==2.1.0
triton
requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611
SentencePiece
ninja

View File

@ -0,0 +1,75 @@
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing import assert_close
from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_to_all_single_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@parameterize("shape", [(4,), (1, 8, 16), (4, 8, 16)])
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("async_op", [True, False])
def check_all2all(shape, dtype, async_op):
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
output = torch.empty_like(x)
output_fp8 = torch.empty_like(x)
origin_handle = dist.all_to_all_single(output, x, group=_get_default_group(), async_op=async_op)
fp8_handle = all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), async_op=async_op)
if async_op:
origin_handle.wait()
fp8_handle.wait()
assert_close(output, output_fp8, rtol=0.1, atol=0.1)
@parameterize("shape", [(8, 8, 16)])
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("async_op", [True, False])
def check_all2all_uneven(shape, dtype, async_op):
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
input_split_sizes = [3, 3, 1, 1]
if dist.get_rank() in [0, 1]:
output_split_sizes = [3, 3, 3, 3]
else:
output_split_sizes = [1, 1, 1, 1]
output_shape = list(shape)
output_shape[0] = sum(output_split_sizes)
output = torch.empty(output_shape, device=x.device, dtype=x.dtype)
output_fp8 = torch.empty(output_shape, device=x.device, dtype=x.dtype)
origin_handle = dist.all_to_all_single(
output,
x,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=_get_default_group(),
async_op=async_op,
)
fp8_handle = all_to_all_single_fp8(
output_fp8,
x,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=_get_default_group(),
async_op=async_op,
)
if async_op:
origin_handle.wait()
fp8_handle.wait()
assert_close(output, output_fp8, rtol=0.1, atol=0.1)
def run_dist(rank, world_size, port):
launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_all2all()
check_all2all_uneven()
@rerun_if_address_is_in_use()
def test_all_to_all_single():
spawn(run_dist, 4)
if __name__ == "__main__":
test_all_to_all_single()

View File

@ -0,0 +1,39 @@
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing import assert_close
from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_to_all_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@parameterize("shape", [(16, 8, 4)])
@parameterize("scatter_dim", [0, 1, 2])
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("fp8_format", ["e4m3", "e5m2"])
def check_4gpu(shape, scatter_dim, dtype, fp8_format):
world_size = dist.get_world_size()
input_tensor = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
input_tensor_list = list(torch.chunk(input_tensor, world_size, scatter_dim))
input_tensor_list = [x.contiguous() for x in input_tensor_list]
output_tensor_list_fp8 = [torch.empty_like(x) for x in input_tensor_list]
output_tensor_list = [torch.empty_like(x) for x in input_tensor_list]
all_to_all_fp8(output_tensor_list_fp8, input_tensor_list, group=_get_default_group(), fp8_format=fp8_format)
dist.all_to_all(output_tensor_list, input_tensor_list, group=_get_default_group())
assert_close(output_tensor_list_fp8, output_tensor_list, rtol=0.1, atol=0.1)
def run_dist(rank, world_size, port):
launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_4gpu()
@rerun_if_address_is_in_use()
def test_all_to_all():
spawn(run_dist, 4)
if __name__ == "__main__":
test_all_to_all()

View File

@ -0,0 +1,37 @@
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing import assert_close
from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_to_all_single_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@parameterize("shape", [(4), (8, 7), (4, 8, 16)])
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("fp8_format", ["e4m3", "e5m2"])
def check_4gpu(shape, dtype, fp8_format):
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
output = torch.empty_like(x)
output_fp8 = torch.empty_like(x)
all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), fp8_format=fp8_format)
dist.all_to_all_single(output, x, group=_get_default_group())
assert_close(output, output_fp8, rtol=0.1, atol=0.1)
def run_dist(rank, world_size, port):
launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_4gpu()
@rerun_if_address_is_in_use()
def test_all_to_all_single():
spawn(run_dist, 4)
if __name__ == "__main__":
test_all_to_all_single()

View File

@ -0,0 +1,43 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing import assert_close
from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@parameterize("shape", [(3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)])
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("async_op", [True, False])
def check_4gpu(shape, dtype, async_op):
world_size = dist.get_world_size()
rank = dist.get_rank()
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
flat_padded_x = x.view(-1)
if flat_padded_x.size(0) % world_size != 0:
pad_size = world_size - flat_padded_x.size(0) % world_size
flat_padded_x = F.pad(flat_padded_x, (0, pad_size))
output = torch.empty_like(flat_padded_x)
chunk = flat_padded_x.chunk(world_size)[rank].clone()
handle = all_gather_into_tensor_flat_fp8(output, chunk, x.shape, group=_get_default_group(), async_op=async_op)
if async_op:
handle.wait()
assert_close(output[: x.numel()], x.view(-1), rtol=0.1, atol=0.1)
def run_dist(rank, world_size, port):
launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_4gpu()
@rerun_if_address_is_in_use()
def test_all_gather_flat():
spawn(run_dist, 4)
if __name__ == "__main__":
test_all_gather_flat()

View File

@ -0,0 +1,55 @@
import torch
import torch.distributed as dist
from torch.testing import assert_close
from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_reduce_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@parameterize(
"shape",
[
(3, 7),
(4, 7),
(7, 4),
(8, 9),
(3),
(7,),
(8,),
],
)
@parameterize("dtype", [torch.float16, torch.bfloat16])
@parameterize("fp8_format", ["e4m3", "e5m2"])
@parameterize("async_op", [True, False])
def check_4gpu(shape, dtype, fp8_format, async_op):
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
x_fp8 = x.clone()
origin_handle = dist.all_reduce(x, async_op=async_op)
fp8_handle = all_reduce_fp8(x_fp8, fp8_format=fp8_format, async_op=async_op)
if async_op:
origin_handle.wait()
fp8_handle.wait()
assert_close(x, x_fp8, rtol=0.1, atol=0.1)
origin_handle = dist.all_reduce(x, op=dist.ReduceOp.AVG, async_op=async_op)
fp8_handle = all_reduce_fp8(x_fp8, op=dist.ReduceOp.AVG, fp8_format=fp8_format, async_op=async_op)
if async_op:
origin_handle.wait()
fp8_handle.wait()
assert_close(x, x_fp8, rtol=0.1, atol=0.1)
def run_dist(rank, world_size, port):
launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_4gpu()
@rerun_if_address_is_in_use()
def test_all_reduce():
spawn(run_dist, 4)
if __name__ == "__main__":
test_all_reduce()

View File

@ -0,0 +1,26 @@
import torch
from torch.testing import assert_close
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import cast_from_fp8, cast_from_fp8_pipeline, cast_to_fp8, cast_to_fp8_pipeline
from colossalai.testing import parameterize
@parameterize("shape", [(100, 10), (10, 100), (3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)])
@parameterize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@parameterize("fp8_format", ["e4m3", "e5m2"])
def test_fp8_cast(shape, dtype, fp8_format):
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
ret, scale_inv = cast_to_fp8(x, fp8_format=fp8_format)
out = cast_from_fp8(ret, scale_inv, x.dtype)
assert_close(out, x, rtol=0.1, atol=0.1)
if x.size(-1) % 2 == 0:
inp_dict = {"hidden_states": x.clone()}
cast_to_fp8_pipeline(inp_dict)
cast_from_fp8_pipeline(inp_dict)
assert_close(inp_dict["hidden_states"], x, rtol=0.1, atol=0.1)
if __name__ == "__main__":
test_fp8_cast()

View File

@ -0,0 +1,87 @@
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
# example modified from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
def setup(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(10, 10)
self.relu = nn.ReLU()
self.net2 = nn.Linear(10, 5)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
def demo_basic(rank, world_size):
print(f"Running basic DDP example on rank {rank}.")
setup(rank, world_size)
def get_grads_after_one_iteration(hook=None):
torch.manual_seed(0)
# create model and move it to GPU with id rank
model = ToyModel().to(rank)
ddp_model = DDP(model, device_ids=[rank])
if hook is not None:
ddp_model.register_comm_hook(None, hook)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
optimizer.zero_grad()
outputs = ddp_model(torch.randn(20, 10))
labels = torch.randn(20, 5).to(rank)
loss_fn(outputs, labels).backward()
optimizer.step()
torch.distributed.barrier()
grad_dict = {}
for name, params in ddp_model.named_parameters():
grad_dict[name] = params.grad
return grad_dict
from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async, fp8_compress_ddp_grad_comm_hook_sync
grad_dict = get_grads_after_one_iteration()
for hook in [fp8_compress_ddp_grad_comm_hook_sync, fp8_compress_ddp_grad_comm_hook_async]:
grad_dict_w_hook = get_grads_after_one_iteration(hook)
if dist.get_rank() == 0:
for name in grad_dict:
assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1)
cleanup()
def run_demo(demo_fn, world_size):
mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True)
if __name__ == "__main__":
n_gpus = torch.cuda.device_count()
assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
world_size = n_gpus
run_demo(demo_basic, world_size)

View File

@ -0,0 +1,107 @@
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from packaging import version
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing import assert_close
from colossalai import launch
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
# example modified from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
def cleanup():
dist.destroy_process_group()
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(100, 100)
self.relu = nn.ReLU()
self.net2 = nn.Linear(100, 50)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
@parameterize("mode", ["grad", "params"])
def run_model(mode):
rank = dist.get_rank()
from colossalai.quantization.utils import patch_fsdp_params_comm_hook
patch_fsdp_params_comm_hook()
def get_grads_after_one_iteration(grad_hook=None, params_hook=None):
torch.manual_seed(0)
# create model and move it to GPU with id rank
model = ToyModel().to(rank)
fsdp_model = FSDP(model)
if grad_hook is not None:
fsdp_model.register_comm_hook(None, grad_hook)
if params_hook is not None:
fsdp_model.register_params_comm_hook(None, params_hook)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(fsdp_model.parameters(), lr=0.001)
optimizer.zero_grad()
outputs = fsdp_model(torch.randn(20, 100))
labels = torch.randn(20, 50).to(rank)
loss_fn(outputs, labels).backward()
optimizer.step()
torch.distributed.barrier()
grad_dict = {}
for name, params in fsdp_model.named_parameters():
grad_dict[name] = params.grad
return grad_dict
from colossalai.quantization.fp8 import fp8_compress_fsdp_grad_comm_hook, fp8_compress_fsdp_params_comm_hook
if mode == "grad":
grad_dict = get_grads_after_one_iteration()
for hook in [
fp8_compress_fsdp_grad_comm_hook,
]:
grad_dict_w_hook = get_grads_after_one_iteration(grad_hook=hook)
if dist.get_rank() == 0:
for name in grad_dict:
assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1)
elif mode == "params":
grad_dict = get_grads_after_one_iteration()
for hook in [
fp8_compress_fsdp_params_comm_hook,
]:
grad_dict_w_hook = get_grads_after_one_iteration(params_hook=hook)
if dist.get_rank() == 0:
for name in grad_dict:
assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1)
else:
raise NotImplementedError
def demo_basic(rank, world_size, port):
print(f"Running basic FSDP example on rank {rank}.")
launch(rank=rank, world_size=world_size, port=port, host="localhost")
run_model()
cleanup()
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("2.2.0"), reason="torch version < 2.2.0.")
@rerun_if_address_is_in_use()
def test_fsdp():
n_gpus = torch.cuda.device_count()
assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
spawn(demo_basic, n_gpus)
if __name__ == "__main__":
test_fsdp()

View File

@ -0,0 +1,52 @@
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing import assert_close
from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import gather_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@parameterize(
"shape",
[
(3, 7),
(2, 1),
(1, 2),
(2, 2),
(4, 2),
(5,),
(4,),
(2,),
],
)
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("fp8_format", ["e4m3", "e5m2"])
@parameterize("async_op", [True, False])
def check_4gpu(shape, dtype, fp8_format, async_op):
world_size = dist.get_world_size()
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
output_list = [torch.empty_like(x) for _ in range(world_size)]
output_list_fp8 = [torch.empty_like(x) for _ in range(world_size)]
fp8_handle = gather_fp8(output_list_fp8, x, group=_get_default_group(), fp8_format=fp8_format, async_op=async_op)
origin_handle = dist.all_gather(output_list, x, group=_get_default_group(), async_op=async_op)
if async_op:
fp8_handle.wait()
origin_handle.wait()
assert_close(output_list, output_list_fp8, rtol=0.1, atol=0.1)
def run_dist(rank, world_size, port):
launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_4gpu()
@rerun_if_address_is_in_use()
def test_all_gather():
spawn(run_dist, 4)
if __name__ == "__main__":
test_all_gather()

View File

@ -0,0 +1,50 @@
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import linear_fp8
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device
REPLACED = False
TRIGGERED = False
def new_linear_fp8(x, w, bias=None):
global TRIGGERED
TRIGGERED = True
return linear_fp8(x, w, bias)
class FP8TestHook(FP8Hook):
def rewrite_op(self, func):
func = super().rewrite_op(func)
if func is linear_fp8:
global REPLACED
REPLACED = True
return new_linear_fp8
return func
D_IN, D_OUT = 16, 32
B, S = 2, 64
DTYPE = torch.bfloat16
@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
def test_fp8_hook():
# create tensors
w = nn.Parameter(torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE))
x = torch.rand(B, S, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True)
w.__class__ = ColoParameter
w.__init__(w, requires_grad=True)
hook = FP8TestHook()
with ColoParamOpHookManager.use_hooks(hook):
o = F.linear(x, w)
assert o.shape == (B, S, D_OUT)
assert REPLACED
assert TRIGGERED

View File

@ -0,0 +1,45 @@
import pytest
import torch
import torch.nn.functional as F
from torch.testing import assert_close
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import linear_fp8
from colossalai.utils import get_current_device
D_IN, D_OUT = 16, 32
B, S = 2, 64
DTYPE = torch.bfloat16
@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("use_batch", [True, False])
def test_fp8_linear(use_bias: bool, use_batch: bool):
# create tensors
w = torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True)
ref_w = w.clone().detach().requires_grad_()
if use_batch:
x_shape = (B, S, D_IN)
else:
x_shape = (S, D_IN)
x = torch.rand(x_shape, device=get_current_device(), dtype=DTYPE, requires_grad=True)
ref_x = x.clone().detach().requires_grad_()
if use_bias:
bias = torch.rand(D_OUT, device=get_current_device(), dtype=DTYPE, requires_grad=True)
ref_bias = bias.clone().detach().requires_grad_()
else:
bias = None
ref_bias = None
out = linear_fp8(x, w, bias)
assert out.shape == x_shape[:-1] + (D_OUT,)
out.sum().backward()
ref_out = F.linear(ref_x, ref_w, ref_bias)
ref_out.sum().backward()
assert_close(out, ref_out, rtol=0.2, atol=0.1)
assert_close(x.grad, ref_x.grad, rtol=0.2, atol=0.1)
assert_close(w.grad, ref_w.grad, rtol=0.2, atol=0.1)
if use_bias:
assert_close(bias.grad, ref_bias.grad, rtol=0.2, atol=0.1)

View File

@ -0,0 +1,44 @@
import torch
from torch.distributed import reduce_scatter
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing import assert_close
from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import reduce_scatter_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@parameterize("shape", [(16, 8, 4)])
@parameterize("scatter_dim", [0, 1, 2])
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("fp8_format", ["e4m3", "e5m2"])
@parameterize("async_op", [True, False])
def check_4gpu(shape, scatter_dim, dtype, fp8_format, async_op):
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
input_list = list(torch.chunk(x, dim=scatter_dim, chunks=4))
input_list = [t.contiguous() for t in input_list]
output_origin = torch.empty_like(input_list[0])
output_fp8 = torch.empty_like(input_list[0])
origin_handle = reduce_scatter(output_origin, input_list, group=_get_default_group(), async_op=async_op)
fp8_handle = reduce_scatter_fp8(
output_fp8, input_list, group=_get_default_group(), fp8_format=fp8_format, async_op=async_op
)
if async_op:
origin_handle.wait()
fp8_handle.wait()
assert_close(output_origin, output_fp8, rtol=0.1, atol=0.1)
def run_dist(rank, world_size, port):
launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_4gpu()
@rerun_if_address_is_in_use()
def test_reduce_scatter():
spawn(run_dist, 4)
if __name__ == "__main__":
test_reduce_scatter()

View File

@ -51,7 +51,8 @@ def split_ddp_grad(grad, world_size):
return splited_grad
def exam_zero_1_2():
@parameterize("fp8_communication", [True, False])
def exam_zero_1_2(fp8_communication: bool):
"""
In this test, we want to test whether zero stage 1 and 2
deliver the same numerical results despite different communication
@ -73,10 +74,18 @@ def exam_zero_1_2():
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
zero1_optimizer = LowLevelZeroOptimizer(
zero1_optimizer, overlap_communication=True, initial_scale=128, verbose=True
zero1_optimizer,
overlap_communication=True,
initial_scale=128,
verbose=True,
fp8_communication=fp8_communication,
)
zero2_optimizer = LowLevelZeroOptimizer(
zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=128
zero2_optimizer,
overlap_communication=True,
partition_grad=True,
initial_scale=128,
fp8_communication=fp8_communication,
)
# create data
seed_all(2001 + local_rank)
@ -97,7 +106,10 @@ def exam_zero_1_2():
if g1 is None or g2 is None:
assert g1 is None and g2 is None
continue
assert torch.allclose(g1, g2)
if fp8_communication:
loose_close(g1, g2, dtype=torch.float16)
else:
assert torch.allclose(g1, g2)
# step
zero1_optimizer.step()
@ -105,7 +117,8 @@ def exam_zero_1_2():
# check updated param
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
assert torch.allclose(z1p, z2p)
if not fp8_communication:
assert torch.allclose(z1p, z2p)
@parameterize("dtype", [torch.float16, torch.bfloat16])