mirror of https://github.com/hpcaitech/ColossalAI
Merge pull request #6012 from hpcaitech/feature/fp8_comm
[fp8] support fp8 communication and fp8 training for Colossalai
commit 17904cb5bf
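This PR threads two new switches through the booster plugins: use_fp8 (run F.linear through an FP8 kernel via FP8Hook) and fp8_communication (cast collective and pipeline p2p traffic to FP8). The sketch below is not part of the diff; it is a minimal, hedged example of how the flags might be used through the public Booster API (the flag names come from the diff, while the launch call and parallel sizes are illustrative assumptions).

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()               # assumes a torchrun-style environment
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    precision="bf16",
    use_fp8=True,            # new in this PR: fp8 matmuls via FP8Hook
    fp8_communication=True,  # new in this PR: fp8-compressed communication
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model, optimizer, criterion, dataloader, lr_scheduler)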
@ -9,6 +9,7 @@ on:
paths:
- "examples/**"
- "!examples/**.md"
- ".github/workflows/example_check_on_pr.yml"

jobs:
# This job detects changed example files and outputs a matrix containing all the corresponding directory names.

@ -107,7 +108,7 @@ jobs:

- name: Install Colossal-AI
run: |
BUILD_EXT=1 pip install -v .
BUILD_EXT=1 pip install -v -e .

- name: Store Colossal-AI Cache
run: |
@ -366,7 +366,9 @@ class GeminiPlugin(DPPluginBase):
|
|||
enable_jit_fused: bool = False,
|
||||
enable_sequence_overlap: bool = False,
|
||||
enable_async_reduce: bool = True,
|
||||
use_fp8: bool = False,
|
||||
verbose: bool = False,
|
||||
fp8_communication: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
|
||||
|
@ -401,6 +403,8 @@ class GeminiPlugin(DPPluginBase):
|
|||
master_weights=master_weights,
|
||||
max_prefetch=max_prefetch,
|
||||
enable_async_reduce=enable_async_reduce,
|
||||
fp8_communication=fp8_communication,
|
||||
use_fp8=use_fp8,
|
||||
)
|
||||
self.zero_optim_config = dict(
|
||||
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
@ -31,6 +31,7 @@ from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
|
|||
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.quantization import BnbQuantizationConfig, quantize_model
|
||||
from colossalai.quantization.fp8_hook import FP8Hook
|
||||
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
|
||||
from colossalai.shardformer.layer.utils import SeqParallelUtils, is_share_sp_tp
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
|
@ -66,6 +67,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
|
|||
ddp_config: dict,
|
||||
custom_policy: Policy,
|
||||
overlap_allgather: bool = False,
|
||||
use_fp8: bool = False,
|
||||
) -> None:
|
||||
self.stage_manager = shard_config.pipeline_stage_manager
|
||||
self.shard_config = shard_config
|
||||
|
@ -75,6 +77,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
|
|||
self.use_ddp = use_ddp
|
||||
self.require_grad_sync = True
|
||||
self.overlap_allgather = overlap_allgather
|
||||
self.use_fp8 = use_fp8
|
||||
|
||||
shardformer = ShardFormer(shard_config)
|
||||
if custom_policy is not None:
|
||||
|
@ -112,6 +115,9 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
|
|||
module = DDP(module, process_group=dp_group, **ddp_config)
|
||||
|
||||
super().__init__(module)
|
||||
self.op_hooks = []
|
||||
if use_fp8:
|
||||
self.op_hooks.append(FP8Hook())
|
||||
if overlap_allgather:
|
||||
self.op_hook = ZeroOpHook()
|
||||
for p in module.parameters():
|
||||
|
@ -223,7 +229,11 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
|
|||
wait_all_gather_handle(p)
|
||||
|
||||
def _wait_all_gather(self):
|
||||
return ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
|
||||
return (
|
||||
ColoParamOpHookManager.use_hooks(*self.op_hooks)
|
||||
if (self.overlap_allgather or self.use_fp8)
|
||||
else nullcontext()
|
||||
)
|
||||
|
||||
|
||||
def get_param_info(optim: Optimizer):
|
||||
|
@ -969,6 +979,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
|
||||
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
|
||||
make_vocab_size_divisible_by (int, optional): It's used when padding the vocabulary size, to make it choose a faster kernel. Defaults to 64.
fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism. Defaults to False.
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism. Defaults to True.
|
||||
inner_ring_size (int, optional): The inner ring size of 2D Ring Attention when sp mode is "ring_attn".
|
||||
It's advisable to not tune this (especially in single-node settings) and let it be heuristically set based on topology by default.
|
||||
|
@ -1020,6 +1031,8 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
dp_outside: bool = True,
|
||||
overlap_p2p: bool = True,
|
||||
overlap_allgather: bool = False,
|
||||
fp8_communication: bool = False,
|
||||
use_fp8: bool = False,
|
||||
inner_ring_size: int = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
@ -1069,8 +1082,10 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
self.enable_flash_attention = enable_flash_attention
|
||||
self.enable_jit_fused = enable_jit_fused
|
||||
self.enable_sequence_parallelism = enable_sequence_parallelism
|
||||
self.use_fp8 = use_fp8
|
||||
if dp_outside:
|
||||
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
|
||||
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
|
||||
if sequence_parallelism_mode == "ring_attn":
|
||||
# Swap tp and sp since 2D Ring has better inter-node latency
|
||||
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.sp_size, self.tp_size)
|
||||
|
@ -1117,6 +1132,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
microbatch_size=microbatch_size,
|
||||
enable_metadata_cache=enable_metadata_cache,
|
||||
overlap_p2p=overlap_p2p,
|
||||
fp8_communication=fp8_communication,
|
||||
)
|
||||
elif pp_style == "1f1b":
|
||||
self.schedule = OneForwardOneBackwardSchedule(
|
||||
|
@ -1124,6 +1140,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
num_microbatches=num_microbatches,
|
||||
microbatch_size=microbatch_size,
|
||||
enable_metadata_cache=enable_metadata_cache,
|
||||
fp8_communication=fp8_communication,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
@ -1158,6 +1175,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
parallel_output=parallel_output,
|
||||
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
|
||||
gradient_checkpoint_config=gradient_checkpoint_config,
|
||||
fp8_communication=fp8_communication,
|
||||
inner_ring_size=inner_ring_size,
|
||||
)
|
||||
self.amp_config = dict(
|
||||
|
@ -1250,7 +1268,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
|
||||
self.dp_size == 1 and self.pp_size == 1
|
||||
)
|
||||
|
||||
# sync gradients across DP * SP ranks
|
||||
# Apply Hybrid ZeRO across DP * SP ranks
|
||||
if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
|
||||
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
|
||||
|
@ -1268,6 +1286,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
ddp_config=self.ddp_config,
|
||||
custom_policy=self.custom_policy,
|
||||
overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
|
||||
use_fp8=self.use_fp8,
|
||||
)
|
||||
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
|
||||
if zero_stage == 0:
@ -34,6 +34,7 @@ from colossalai.interface.optimizer import DistributedOptim
|
|||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
|
||||
from colossalai.quantization import BnbQuantizationConfig, quantize_model
|
||||
from colossalai.quantization.fp8_hook import FP8Hook
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.zero import LowLevelZeroOptimizer
|
||||
|
@ -62,7 +63,12 @@ class OptimizerParamCheckState(enum.Enum):
|
|||
|
||||
class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
|
||||
def __init__(
|
||||
self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True
|
||||
self,
|
||||
module: nn.Module,
|
||||
precision: str,
|
||||
overlap_allgather: bool = False,
|
||||
cast_inputs: bool = True,
|
||||
use_fp8: bool = False,
|
||||
) -> None:
|
||||
super().__init__(module)
|
||||
self.dtype = None
|
||||
|
@ -75,11 +81,16 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
|
|||
module = module.to(get_accelerator().get_current_device())
|
||||
self.module = module
|
||||
self.convert_fn = None
|
||||
self.use_fp8 = use_fp8
|
||||
if self.dtype is not None and cast_inputs:
|
||||
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
|
||||
self.overlap_allgather = overlap_allgather
|
||||
self.op_hooks = []
|
||||
if overlap_allgather:
|
||||
self.op_hook = ZeroOpHook()
|
||||
self.op_hooks.append(ZeroOpHook())
|
||||
if use_fp8:
|
||||
self.op_hooks.append(FP8Hook())
|
||||
if overlap_allgather or use_fp8:
|
||||
for p in module.parameters():
|
||||
if p.requires_grad and type(p) is not ColoParameter:
|
||||
p.__class__ = ColoParameter
|
||||
|
@ -337,6 +348,8 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
master_weights: bool = True,
|
||||
verbose: bool = False,
|
||||
cast_inputs: bool = True,
|
||||
fp8_communication: bool = False,
|
||||
use_fp8: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
|
||||
|
@ -360,12 +373,14 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
cpu_offload=cpu_offload,
|
||||
master_weights=master_weights,
|
||||
overlap_allgather=overlap_allgather,
|
||||
fp8_communication=fp8_communication,
|
||||
)
|
||||
self.lora_enabled = False
|
||||
self.verbose = verbose
|
||||
self.logger = get_dist_logger()
|
||||
self.cast_inputs = cast_inputs
|
||||
|
||||
self.use_fp8 = use_fp8
|
||||
# set class name with stage, for better error message
|
||||
setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
|
||||
|
||||
|
@ -484,6 +499,7 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
self.precision,
|
||||
overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
|
||||
cast_inputs=self.cast_inputs,
|
||||
use_fp8=self.use_fp8,
|
||||
)
|
||||
|
||||
# TODO: Support Galore + ZeRO
@ -214,6 +214,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
moe_dp_outside: bool = True,
|
||||
overlap_p2p: bool = True,
|
||||
overlap_allgather: bool = False,
|
||||
fp8_communication: bool = False,
|
||||
use_fp8: bool = False,
|
||||
) -> None:
|
||||
self.logger = get_dist_logger()
|
||||
if overlap_communication or zero_stage == 2:
|
||||
|
@ -327,6 +329,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
|
||||
else:
|
||||
self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
|
||||
self.use_fp8 = use_fp8
|
||||
|
||||
self.shard_config = ShardConfig(
|
||||
tensor_parallel_process_group=self.tp_group,
|
||||
|
@ -345,6 +348,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
parallel_output=parallel_output,
|
||||
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
|
||||
gradient_checkpoint_config=gradient_checkpoint_config,
|
||||
fp8_communication=fp8_communication,
|
||||
)
|
||||
self.amp_config = dict(
|
||||
initial_scale=initial_scale,
|
||||
|
@ -431,6 +435,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
use_ddp=use_ddp,
|
||||
ddp_config=self.ddp_config,
|
||||
custom_policy=self.custom_policy,
|
||||
use_fp8=self.use_fp8,
|
||||
)
|
||||
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
|
||||
if self.ep_size > 1:
@ -179,6 +179,7 @@ class TorchDDPPlugin(DPPluginBase):
|
|||
check_reduction: bool = False,
|
||||
gradient_as_bucket_view: bool = False,
|
||||
static_graph: bool = False,
|
||||
fp8_communication: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.ddp_kwargs = dict(
|
||||
|
@ -189,6 +190,7 @@ class TorchDDPPlugin(DPPluginBase):
|
|||
gradient_as_bucket_view=gradient_as_bucket_view,
|
||||
static_graph=static_graph,
|
||||
)
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
def support_no_sync(self) -> bool:
|
||||
return True
|
||||
|
@ -228,6 +230,11 @@ class TorchDDPPlugin(DPPluginBase):
|
|||
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
|
||||
optimizer = OptimizerWrapper(optimizer)
|
||||
|
||||
if self.fp8_communication:
|
||||
from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async
|
||||
|
||||
model.module.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_async)
|
||||
|
||||
return model, optimizer, criterion, dataloader, lr_scheduler
|
||||
|
||||
def control_checkpoint_io(self) -> bool:
@ -298,6 +298,7 @@ class TorchFSDPPlugin(DPPluginBase):
|
|||
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
|
||||
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
|
||||
sync_module_states: bool = False,
|
||||
fp8_communication: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.fsdp_kwargs = dict(
|
||||
|
@ -311,6 +312,7 @@ class TorchFSDPPlugin(DPPluginBase):
|
|||
param_init_fn=param_init_fn,
|
||||
sync_module_states=sync_module_states,
|
||||
)
|
||||
self.fp8_communication = fp8_communication
|
||||
self.logger = get_dist_logger()
|
||||
|
||||
else:
|
||||
|
@ -348,6 +350,19 @@ class TorchFSDPPlugin(DPPluginBase):
|
|||
# wrap the model with PyTorch FSDP
|
||||
fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
|
||||
|
||||
if self.fp8_communication:
|
||||
from colossalai.quantization.utils import patch_fsdp_params_comm_hook
|
||||
|
||||
patch_fsdp_params_comm_hook()
|
||||
|
||||
from colossalai.quantization.fp8 import fp8_compress_fsdp_params_comm_hook
|
||||
|
||||
fsdp_model.module.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook)
|
||||
|
||||
from colossalai.quantization.fp8 import fp8_compress_fsdp_grad_comm_hook
|
||||
|
||||
fsdp_model.module.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook)
|
||||
|
||||
if optimizer is not None:
|
||||
if len(optimizer.param_groups) > 1:
|
||||
self.logger.warning(
@ -6,6 +6,8 @@ from torch import Tensor
|
|||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.quantization.fp8 import all_to_all_single_fp8
|
||||
|
||||
MOE_KERNEL = None
|
||||
|
||||
|
||||
|
@ -380,6 +382,7 @@ def _all_to_all(
|
|||
output_split_sizes: Optional[List[int]] = None,
|
||||
group=None,
|
||||
async_op: bool = False,
|
||||
fp8_communication: bool = False,
|
||||
):
|
||||
"""
|
||||
Returns:
|
||||
|
@ -392,9 +395,14 @@ def _all_to_all(
|
|||
outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
|
||||
inputs = inputs.contiguous()
|
||||
outputs = outputs.contiguous()
|
||||
handle = dist.all_to_all_single(
|
||||
outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
|
||||
)
|
||||
if fp8_communication:
|
||||
handle = all_to_all_single_fp8(
|
||||
outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=False
|
||||
)
|
||||
else:
|
||||
handle = dist.all_to_all_single(
|
||||
outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
|
||||
)
|
||||
return outputs, handle
|
||||
|
||||
|
||||
|
@ -407,6 +415,7 @@ class AllToAllUneven(torch.autograd.Function):
|
|||
output_split_sizes=None,
|
||||
group=None,
|
||||
overlap: bool = False,
|
||||
fp8_communication: bool = False,
|
||||
):
|
||||
"""
|
||||
Returns:
|
||||
|
@ -416,7 +425,9 @@ class AllToAllUneven(torch.autograd.Function):
|
|||
ctx.input_split_sizes = input_split_sizes
|
||||
ctx.output_split_sizes = output_split_sizes
|
||||
ctx.group = group
|
||||
return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap)
|
||||
return _all_to_all(
|
||||
inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication=fp8_communication
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_outputs):
|
||||
|
@ -426,6 +437,7 @@ class AllToAllUneven(torch.autograd.Function):
|
|||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
|
@ -435,8 +447,9 @@ def all_to_all_uneven(
|
|||
output_split_sizes: Optional[List[int]] = None,
|
||||
group=None,
|
||||
overlap: bool = False,
|
||||
fp8_communication: bool = False,
|
||||
):
|
||||
assert (
|
||||
inputs.requires_grad
|
||||
), "Input must require grad to ensure that backward is executed, otherwise it might hang the program."
|
||||
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
|
||||
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication)
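Not part of the diff: a hedged call sketch of the updated helper on two ranks (the enclosing MoE operation module is assumed to export all_to_all_uneven; the split sizes are illustrative, and the input must require grad per the assert above).

import torch

x = torch.randn(6, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True)
outputs, handle = all_to_all_uneven(
    x,
    input_split_sizes=[2, 4],    # rows sent to rank 0 and rank 1
    output_split_sizes=[3, 3],   # rows received from each peer
    group=None,
    overlap=False,
    fp8_communication=True,      # payload now goes through all_to_all_single_fp8
)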
@ -11,6 +11,7 @@ from colossalai.accelerator import get_accelerator
|
|||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
|
||||
|
@ -32,6 +33,7 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
microbatch_size: Optional[int] = None,
|
||||
enable_metadata_cache: bool = True,
|
||||
overlap_p2p: bool = True,
|
||||
fp8_communication: bool = False,
|
||||
) -> None:
|
||||
super().__init__(stage_manager)
|
||||
assert (
|
||||
|
@ -56,6 +58,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
self.tensor_metadata_recv = None
|
||||
self.grad_metadata_recv = None
|
||||
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
||||
"""Load a batch from data iterator.
|
||||
|
||||
|
@ -191,8 +195,12 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
"""
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_last_stage():
|
||||
if self.fp8_communication:
|
||||
cast_to_fp8_pipeline(output_tensor)
|
||||
send_handles = self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(output_tensor)
|
||||
return send_handles
|
||||
return []
|
||||
|
||||
|
@ -210,10 +218,14 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
"""
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_first_stage():
|
||||
if self.fp8_communication:
|
||||
cast_to_fp8_pipeline(input_tensor_grad)
|
||||
send_handles = self.comm.send_backward(
|
||||
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata
|
||||
)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(input_tensor_grad)
|
||||
return send_handles
|
||||
return []
|
||||
|
||||
|
@ -224,6 +236,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
is_send = not self.stage_manager.is_last_stage()
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
|
||||
is_recv = not self.stage_manager.is_first_stage()
|
||||
if self.fp8_communication:
|
||||
cast_to_fp8_pipeline(output_tensor)
|
||||
input_tensor, wait_handles = self.comm.send_forward_recv_forward(
|
||||
output_tensor,
|
||||
is_send,
|
||||
|
@ -237,6 +251,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
if is_recv and self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(output_tensor)
|
||||
return input_tensor, wait_handles
|
||||
|
||||
def send_backward_recv_backward(
|
||||
|
@ -246,6 +262,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
is_send = not self.stage_manager.is_first_stage()
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
|
||||
is_recv = not self.stage_manager.is_last_stage()
|
||||
if self.fp8_communication:
|
||||
cast_to_fp8_pipeline(input_tensor_grad)
|
||||
output_tensor_grad, wait_handles = self.comm.send_backward_recv_backward(
|
||||
input_tensor_grad,
|
||||
is_send,
|
||||
|
@ -258,6 +276,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
self.send_grad_metadata = not self.enable_metadata_cache and is_send
|
||||
if is_recv and self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(input_tensor_grad)
|
||||
return output_tensor_grad, wait_handles
|
||||
|
||||
def forward_step(
|
||||
|
@ -378,6 +398,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
|
||||
# Wait until current input is received
|
||||
_wait_p2p(fwd_wait_handles)
|
||||
if self.fp8_communication and input_obj is not None:
|
||||
cast_from_fp8_pipeline(input_obj)
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
|
||||
if not last_batch:
|
||||
|
@ -440,6 +462,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
|
||||
# Wait for input
|
||||
_wait_p2p(fwd_wait_handles)
|
||||
if self.fp8_communication and input_obj is not None:
|
||||
cast_from_fp8_pipeline(input_obj)
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
input_objs[model_chunk_id].append(input_obj)
|
||||
output_objs[model_chunk_id].append(output_obj)
|
||||
|
@ -466,6 +490,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
|
||||
# Wait for input.
|
||||
_wait_p2p(fwd_wait_handles)
|
||||
if self.fp8_communication and input_obj is not None:
|
||||
cast_from_fp8_pipeline(input_obj)
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
# Add input_obj and output_obj to end of list.
|
||||
input_objs[model_chunk_id].append(input_obj)
|
||||
|
@ -510,6 +536,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
input_obj, fwd_wait_handles = send_forward_recv_forward()
|
||||
# Wait for upstream grad
|
||||
_wait_p2p(bwd_wait_handles)
|
||||
if self.fp8_communication and output_obj_grad is not None:
|
||||
cast_from_fp8_pipeline(output_obj_grad)
|
||||
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
|
||||
# NOTE: It's documented by NCCL that running two concurrent communicators (batch_isend_irecv)
|
||||
# risks deadlock (https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2134/user-guide/docs/usage/communicators.html)
|
||||
|
@ -531,6 +559,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
|
||||
# Wait for upstream grad
|
||||
_wait_p2p(bwd_wait_handles)
|
||||
if self.fp8_communication and output_obj_grad is not None:
|
||||
cast_from_fp8_pipeline(output_obj_grad)
|
||||
# backward local grads
|
||||
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
|
||||
if not last_batch:
@ -10,6 +10,7 @@ from colossalai.accelerator import get_accelerator
|
|||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from ._utils import (
|
||||
|
@ -32,6 +33,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
num_microbatches: Optional[int] = None,
|
||||
microbatch_size: Optional[int] = None,
|
||||
enable_metadata_cache: bool = True,
|
||||
fp8_communication: bool = False,
|
||||
) -> None:
|
||||
"""1F1B pipeline schedule.
|
||||
|
||||
|
@ -61,6 +63,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
self.tensor_metadata_recv = None
|
||||
self.grad_metadata_recv = None
|
||||
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
||||
"""Load a batch from data iterator.
|
||||
|
||||
|
@ -129,6 +133,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(input_tensor)
|
||||
return input_tensor
|
||||
|
||||
def recv_backward(self, next_rank: int = None) -> Any:
|
||||
|
@ -143,6 +149,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
"""
|
||||
if not self.stage_manager.is_last_stage():
|
||||
output_tensor_grad, _ = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(output_tensor_grad)
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
|
||||
|
@ -157,9 +165,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||
"""
|
||||
if not self.stage_manager.is_last_stage():
|
||||
if self.fp8_communication:
|
||||
cast_to_fp8_pipeline(output_tensor)
|
||||
self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(output_tensor, del_metadata=False)
|
||||
|
||||
def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None:
|
||||
"""Sends the gradient tensor to the previous stage in pipeline.
|
||||
For 1F1B.
|
||||
|
@ -169,8 +182,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
prev_rank (int, optional): The rank of the recipient of the tensor
|
||||
"""
|
||||
if not self.stage_manager.is_first_stage():
|
||||
if self.fp8_communication:
|
||||
cast_to_fp8_pipeline(input_tensor_grad)
|
||||
self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(input_tensor_grad, del_metadata=False)
|
||||
|
||||
def send_forward_recv_backward(self, output_tensor: Any, send_first: Optional[bool] = None) -> Any:
|
||||
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
|
||||
|
@ -183,6 +200,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
if not self.stage_manager.is_last_stage():
|
||||
if not self.send_tensor_metadata and self.grad_metadata_recv is not None:
|
||||
send_first = None
|
||||
if self.fp8_communication:
|
||||
cast_to_fp8_pipeline(output_tensor)
|
||||
output_tensor_grad, _ = self.comm.send_forward_recv_backward(
|
||||
output_tensor,
|
||||
send_metadata=self.send_tensor_metadata,
|
||||
|
@ -192,6 +211,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(output_tensor, del_metadata=False)
|
||||
cast_from_fp8_pipeline(output_tensor_grad)
|
||||
|
||||
return output_tensor_grad
|
||||
|
||||
|
@ -206,6 +228,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
if not self.stage_manager.is_first_stage():
|
||||
if not self.send_grad_metadata and self.tensor_metadata_recv is not None:
|
||||
send_first = None # must not fallback
|
||||
if self.fp8_communication:
|
||||
cast_to_fp8_pipeline(input_tensor_grad)
|
||||
input_tensor, _ = self.comm.send_backward_recv_forward(
|
||||
input_tensor_grad,
|
||||
send_metadata=self.send_grad_metadata,
|
||||
|
@ -215,6 +239,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(input_tensor)
|
||||
cast_from_fp8_pipeline(input_tensor_grad, del_metadata=False)
|
||||
|
||||
return input_tensor
@ -0,0 +1,738 @@
|
|||
from typing import Any, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from packaging.version import Version
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0")
|
||||
|
||||
|
||||
class Handle:
|
||||
def __init__(self, handles=[], remain_ops=None) -> None:
|
||||
self.handles = handles
|
||||
self.remain_ops = remain_ops
|
||||
|
||||
def wait(self):
|
||||
for handle in self.handles:
|
||||
handle.wait()
|
||||
if self.remain_ops:
|
||||
self.remain_ops()
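Handle defers the dequantize-and-copy-back step until wait(). A self-contained toy illustration (not part of the diff; the fake work object stands in for a torch.distributed async work handle):

class _FakeWork:                 # stand-in for an async communication work object
    def wait(self):
        print("communication finished")

h = Handle(handles=[_FakeWork()], remain_ops=lambda: print("dequantize and copy back"))
h.wait()                         # waits on every wrapped work, then runs the deferred op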
|
||||
|
||||
|
||||
def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
Cast a torch.Tensor into the specified fp8 format, with per-tensor or per-channel scaling.
Args:
inp: input torch.Tensor; must be torch.float32, torch.float16, or torch.bfloat16.
fp8_format: "e4m3" or "e5m2".
per_channel_scale: if True, one scaling factor is computed per channel (last dim); otherwise a single per-tensor scale is used.

Returns:
Tuple: a tuple (fp8_tensor, scale_inv), where scale_inv is the inverse scale needed to dequantize.
|
||||
"""
|
||||
|
||||
if inp.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
raise TypeError("Only float16, bfloat16, and float32 are allowed.")
|
||||
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
fp8_max = torch.finfo(fp8_type).max
|
||||
|
||||
if inp.numel() == 0:
|
||||
return inp.to(fp8_type), torch.tensor([1.0], device=inp.device)
|
||||
else:
|
||||
if per_channel_scale:
|
||||
per_channel_max = inp.abs().max(dim=-1).values.float()
|
||||
per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
|
||||
scale = fp8_max / per_channel_max[:, None]
|
||||
scale_inv = per_channel_max / fp8_max
|
||||
else:
|
||||
per_tensor_max = inp.abs().max().float()
|
||||
per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
|
||||
scale = fp8_max / per_tensor_max
|
||||
scale_inv = 1.0 / scale
|
||||
|
||||
ret = (scale * inp.float()).to(fp8_type)
|
||||
return ret, scale_inv
|
||||
|
||||
|
||||
def cast_from_fp8(
|
||||
inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype, per_channel_scale=False
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
inp: an fp8 torch tensor; must be torch.float8_e4m3fn or torch.float8_e5m2.
scale_inv: inverse scaling factor returned by the cast_to_fp8 function.
ret_type: the datatype of the returned tensor.
per_channel_scale: whether scale_inv holds one factor per channel rather than a single per-tensor factor.
|
||||
Returns:
|
||||
torch.Tensor
|
||||
"""
|
||||
if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||
raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.")
|
||||
|
||||
if per_channel_scale:
|
||||
ret = scale_inv[:, None] * inp.float()
|
||||
else:
|
||||
ret = scale_inv * inp.float()
|
||||
return ret.to(ret_type)
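A hedged round trip with the two helpers above (not part of the diff; shapes are illustrative, and a CUDA device with the torch fp8 dtypes is assumed):

import torch

x = torch.randn(4, 8, dtype=torch.bfloat16, device="cuda")
x_fp8, scale_inv = cast_to_fp8(x, fp8_format="e4m3")        # quantize, keep the inverse scale
x_back = cast_from_fp8(x_fp8, scale_inv, torch.bfloat16)    # dequantize to the original dtype
print((x - x_back).abs().max())                             # small, fp8-rounding-level error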
|
||||
|
||||
|
||||
def all_reduce_fp8(
|
||||
tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
|
||||
) -> Optional[Handle]:
|
||||
r"""
|
||||
This is an in-place operation for compressed all_reduce using fp8.
|
||||
It works like dist.all_reduce but during communication the data is cast to fp8 format.
|
||||
|
||||
Args:
|
||||
tensor: torch.Tensor in fp32, fp16, bf16 datatype.
|
||||
fp8_format: e4m3 or e5m2
|
||||
op: ReduceOp.SUM or ReduceOp.AVG
|
||||
|
||||
Returns:
|
||||
A Handle to wait on if async_op is True, otherwise None.
|
||||
"""
|
||||
|
||||
world_size = dist.get_world_size(group=group)
|
||||
input_type = tensor.dtype
|
||||
input_shape = tensor.shape
|
||||
input_device = tensor.device
|
||||
input_size = tensor.numel()
|
||||
flat_padded_x = tensor.flatten()
|
||||
|
||||
assert op in [ReduceOp.SUM, ReduceOp.AVG], "op can only be ReduceOp.SUM or ReduceOp.AVG"
|
||||
|
||||
if flat_padded_x.size(0) % world_size != 0:
|
||||
pad_size = world_size - flat_padded_x.size(0) % world_size
|
||||
flat_padded_x = F.pad(flat_padded_x, (0, pad_size))
|
||||
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
ret, scale = cast_to_fp8(flat_padded_x, fp8_format=fp8_format)
|
||||
|
||||
inp = ret.view(torch.uint8)
|
||||
input_chunks = list(torch.chunk(inp, world_size, dim=0))
|
||||
output_chunks = list(torch.chunk(torch.empty_like(inp), world_size, dim=0))
|
||||
dist.all_to_all(output_chunks, input_chunks, group=group)
|
||||
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
|
||||
dist.all_gather(scale_list, scale, group=group)
|
||||
summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
|
||||
|
||||
for scale, out in zip(scale_list, output_chunks):
|
||||
out = out.view(fp8_type)
|
||||
summed_out += cast_from_fp8(out, scale, input_type)
|
||||
|
||||
if op == ReduceOp.AVG:
|
||||
summed_out.div_(world_size)
|
||||
|
||||
summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
|
||||
gather_scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
|
||||
|
||||
tensor_list = [torch.empty_like(summed_out_fp8.view(torch.uint8)) for _ in range(world_size)]
|
||||
gather_tensor_handle = dist.all_gather(
|
||||
tensor_list, summed_out_fp8.view(torch.uint8), group=group, async_op=async_op
|
||||
)
|
||||
|
||||
def cat_op():
|
||||
for i in range(world_size):
|
||||
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
|
||||
out = torch.cat(tensor_list, dim=0)
|
||||
tensor.copy_(out[:input_size].view(input_shape).to(input_type))
|
||||
|
||||
if async_op:
|
||||
return Handle([gather_scale_handle, gather_tensor_handle], cat_op)
|
||||
else:
|
||||
cat_op()
|
||||
|
||||
|
||||
def all_to_all_single_fp8(
|
||||
output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
|
||||
) -> Optional[Handle]:
|
||||
r"""
|
||||
This is an in-place operation for compressed all_to_all_single using fp8.
It works like dist.all_to_all_single, but during communication the data is cast to fp8 format.
Args:
output: torch.Tensor in fp32, fp16, or bf16 datatype that receives the result.
input: torch.Tensor in fp32, fp16, or bf16 datatype to be scattered.
fp8_format: e4m3 or e5m2
Returns:
A Handle to wait on if async_op is True, otherwise None.
|
||||
"""
|
||||
world_size = dist.get_world_size(group=group)
|
||||
input_type = input.dtype
|
||||
input_shape = input.shape
|
||||
input_device = input.device
|
||||
input = input.flatten()
|
||||
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
|
||||
ret, scale = cast_to_fp8(input, fp8_format=fp8_format)
|
||||
|
||||
inp = ret.view(torch.uint8)
|
||||
if input_split_sizes is not None:
|
||||
input_split_sizes = [input_split_sizes[i] * np.prod(input_shape[1:]) for i in range(world_size)]
|
||||
input_chunks = list(torch.split(inp, input_split_sizes))
|
||||
else:
|
||||
input_chunks = list(torch.chunk(inp, world_size, dim=0))
|
||||
|
||||
if output_split_sizes is not None:
|
||||
output_chunks = [
|
||||
torch.empty((output_split_sizes[i] * np.prod(input_shape[1:]),), device=input_device, dtype=inp.dtype)
|
||||
for i in range(world_size)
|
||||
]
|
||||
else:
|
||||
if dist.get_rank() == world_size - 1:
|
||||
output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)]
|
||||
else:
|
||||
output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]
|
||||
|
||||
chunk_handle = dist.all_to_all(output_chunks, input_chunks, group=group, async_op=async_op)
|
||||
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
|
||||
scale_hanle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
|
||||
|
||||
def cast_op():
|
||||
cast_output_chunk = [
|
||||
cast_from_fp8(out.view(fp8_type), scale, input_type) for scale, out in zip(scale_list, output_chunks)
|
||||
]
|
||||
|
||||
tensor_out = torch.cat(cast_output_chunk, dim=0)
|
||||
outputs_shape = list(input_shape)
|
||||
if output_split_sizes is not None:
|
||||
outputs_shape[0] = sum(output_split_sizes)
|
||||
else:
|
||||
outputs_shape = input_shape
|
||||
output.data = tensor_out.view(outputs_shape).to(input_type)
|
||||
|
||||
if async_op:
|
||||
return Handle([chunk_handle, scale_hanle], cast_op)
|
||||
else:
|
||||
cast_op()
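And a hedged drop-in sketch for dist.all_to_all_single with equal splits (not part of the diff; assumes the process group is already initialized):

import torch
import torch.distributed as dist

world_size = dist.get_world_size()
inp = torch.randn(world_size * 4, 32, device="cuda", dtype=torch.bfloat16)
out = torch.empty_like(inp)
all_to_all_single_fp8(out, inp, fp8_format="e5m2")   # blocking; async_op=True returns a Handle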
|
||||
|
||||
|
||||
def cast_to_fp8_pipeline(inp: Any) -> None:
|
||||
"""
|
||||
Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline.
|
||||
The activations tensor is indexed by 'hidden_states' in the inp dict.
|
||||
After FP8 casting, the result is stored viewed as float16 or bfloat16, which halves the communication volume.
|
||||
Metadata such as fp8_scale is saved into inp dict for communication.
|
||||
"""
|
||||
if inp is None:
|
||||
return
|
||||
# In pipeline parallelism, when inp is torch.Tensor, it only contains one element, thus can be omitted.
|
||||
if type(inp) == torch.Tensor:
|
||||
return
|
||||
|
||||
assert "hidden_states" in inp, "required by pipeline parallelism."
|
||||
assert (
|
||||
inp["hidden_states"].size(-1) % 2 == 0
|
||||
), "tensor size(-1) must be divisible by 2 to view Float8_e4m3fn as BFloat16 or Float16"
|
||||
inp_tensor = inp["hidden_states"]
|
||||
inp_dtype = inp_tensor.dtype
|
||||
|
||||
min_val, max_val = inp_tensor.aminmax()
|
||||
amax = torch.maximum(min_val.abs(), max_val.abs())
|
||||
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
if amax > finfo.max:
|
||||
fp8_type = torch.float8_e5m2
|
||||
fp8_view_type = torch.float16
|
||||
else:
|
||||
fp8_type = torch.float8_e4m3fn
|
||||
fp8_view_type = torch.bfloat16
|
||||
|
||||
finfo = torch.finfo(fp8_type)
|
||||
scale = torch.tensor(1.0).to(inp_tensor.device) if amax == 0.0 else finfo.max / amax.float()
|
||||
q_tensor = inp_tensor.data.float() * scale
|
||||
# Todo: Currently we use fp8_view_type <float16, bfloat16> to indicate which fp8 format is used. This is a temporary workaround due to 'Only support tensor for fast send'.
|
||||
# inp_tensor needs to be a float datatype to avoid error during gradient placement.
|
||||
inp_tensor.data = q_tensor.to(fp8_type).view(fp8_view_type)
|
||||
|
||||
inp["fp8_scale"] = scale.float().reciprocal()
|
||||
inp["dtype"] = torch.zeros_like(scale).to(inp_dtype)
|
||||
|
||||
|
||||
def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
|
||||
"""
|
||||
Cast the FP8 encoded hidden_states tensor back to original dtype after p2p communication in pipeline.
|
||||
del_metadata = False is useful when this function is called before p2p communication.
|
||||
"""
|
||||
if inp is None:
|
||||
return
|
||||
if type(inp) == torch.Tensor:
|
||||
return
|
||||
|
||||
assert "hidden_states" in inp, "required by pipeline parallelism."
|
||||
inp_tensor = inp["hidden_states"]
|
||||
scale = inp["fp8_scale"]
|
||||
|
||||
fp8_view_type = inp_tensor.dtype
|
||||
if fp8_view_type == torch.float16:
|
||||
fp8_type = torch.float8_e5m2
|
||||
elif fp8_view_type == torch.bfloat16:
|
||||
fp8_type = torch.float8_e4m3fn
|
||||
else:
|
||||
raise TypeError("Only float16, bfloat16 are implemented.")
|
||||
|
||||
inp_tensor.data = inp_tensor.data.view(fp8_type).to(inp["dtype"]) * scale
|
||||
|
||||
if del_metadata:
|
||||
del inp["fp8_scale"]
|
||||
del inp["dtype"]
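A hedged sketch of the pipeline casting pair on a micro-batch dict (not part of the diff; the last dimension must be even, as asserted in cast_to_fp8_pipeline):

import torch

batch = {"hidden_states": torch.randn(2, 16, 128, dtype=torch.bfloat16, device="cuda")}
cast_to_fp8_pipeline(batch)     # hidden_states now holds the fp8 payload viewed as fp16/bf16,
                                # plus fp8_scale / dtype metadata for the receiver
# ... p2p send/recv of `batch` happens here in the schedule ...
cast_from_fp8_pipeline(batch)   # restores hidden_states to bfloat16 and deletes the metadata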
|
||||
|
||||
|
||||
def reduce_scatter_fp8(
|
||||
output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False
|
||||
) -> Optional[Handle]:
|
||||
r"""
|
||||
This is an in-place operation for compressed reduce_scatter using fp8.
|
||||
It works like dist.reduce_scatter but during communication the data is cast to fp8 format.
|
||||
|
||||
Args:
|
||||
output: torch.Tensor in fp32, fp16, or bf16 datatype that receives the reduced shard.
input_list: list of torch.Tensors in fp32, fp16, or bf16 datatype to be reduce-scattered.
fp8_format: e4m3 or e5m2

Returns:
A Handle to wait on if async_op is True, otherwise None.
|
||||
"""
|
||||
|
||||
input_type = output.dtype
|
||||
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
scale_list = []
|
||||
cast_input_list = []
|
||||
output_chunks = []
|
||||
output_scale_list = []
|
||||
for input in input_list:
|
||||
ret, scale = cast_to_fp8(input, fp8_format=fp8_format)
|
||||
scale_list.append(scale)
|
||||
ret = ret.view(torch.uint8)
|
||||
cast_input_list.append(ret)
|
||||
output_chunks.append(torch.empty_like(ret))
|
||||
output_scale_list.append(torch.empty_like(scale))
|
||||
chunk_handle = dist.all_to_all(output_chunks, cast_input_list, group=group, async_op=async_op)
|
||||
scale_handle = dist.all_to_all(output_scale_list, scale_list, group=group, async_op=async_op)
|
||||
|
||||
def cast_op():
|
||||
summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
|
||||
for scale, out in zip(output_scale_list, output_chunks):
|
||||
out = out.view(fp8_type)
|
||||
summed_out += cast_from_fp8(out, scale, input_type)
|
||||
output.data = summed_out
|
||||
|
||||
if async_op:
|
||||
return Handle([chunk_handle, scale_handle], cast_op)
|
||||
else:
|
||||
cast_op()
|
||||
|
||||
|
||||
def fp8_compress_ddp_grad_comm_hook_async(
|
||||
process_group: dist.ProcessGroup,
|
||||
bucket: dist.GradBucket,
|
||||
fp8_format: str = "e5m2",
|
||||
) -> torch.futures.Future[torch.Tensor]:
|
||||
"""
|
||||
Compress gradients by casting the ``GradBucket`` tensor to FP8 floating-point format and dividing by the process group size.
|
||||
|
||||
This DDP communication hook implements a simple gradient compression approach that casts ``GradBucket`` tensor
|
||||
to FP8 floating-point format (``torch.float8_e5m2`` or ``torch.float8_e4m3fn``), and then divides it
by the process group size.
Once the compressed gradient tensors are allreduced, the chained callback ``decompress`` casts them back
|
||||
to the input data type (such as ``float32``).
|
||||
|
||||
Example::
|
||||
>>> ddp_model.register_comm_hook(process_group, fp8_compress_ddp_grad_comm_hook_async)
|
||||
"""
|
||||
group_to_use = process_group if process_group is not None else dist.group.WORLD
|
||||
|
||||
input_tensor = bucket.buffer()
|
||||
world_size = dist.get_world_size()
|
||||
input_type = input_tensor.dtype
|
||||
input_device = input_tensor.device
|
||||
flat_padded_x = input_tensor.flatten()
|
||||
|
||||
if flat_padded_x.size(0) % world_size != 0:
|
||||
pad_size = world_size - flat_padded_x.size(0) % world_size
|
||||
flat_padded_x = F.pad(flat_padded_x, (0, pad_size))
|
||||
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
ret, scale = cast_to_fp8(flat_padded_x, fp8_format=fp8_format)
|
||||
|
||||
inp = ret.view(torch.uint8)
|
||||
output_chunks_single = torch.empty_like(inp)
|
||||
split_sizes = [inp.numel() // world_size for _ in range(world_size)]
|
||||
fut0 = dist.all_to_all_single(
|
||||
output_chunks_single,
|
||||
inp,
|
||||
output_split_sizes=split_sizes,
|
||||
input_split_sizes=split_sizes,
|
||||
group=group_to_use,
|
||||
async_op=True,
|
||||
).get_future()
|
||||
|
||||
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
|
||||
fut1 = dist.all_gather_into_tensor(
|
||||
torch.cat(scale_list, dim=0), scale, group=group_to_use, async_op=True
|
||||
).get_future()
|
||||
all_to_all_fut = torch.futures.collect_all([fut0, fut1])
|
||||
|
||||
def sum_and_allgather(fut):
|
||||
output_chunks_single = fut.value()[0].wait()[0]
|
||||
scale_list_single = fut.value()[1].wait()[0]
|
||||
|
||||
output_chunks = list(torch.chunk(output_chunks_single, world_size, dim=0))
|
||||
scale_list = scale_list_single.chunk(world_size, dim=0)
|
||||
|
||||
summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
|
||||
for scale, out in zip(scale_list, output_chunks):
|
||||
out = out.view(fp8_type)
|
||||
summed_out += cast_from_fp8(out, scale, input_type)
|
||||
summed_out.div_(world_size)
|
||||
|
||||
summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
|
||||
|
||||
tensor_list_single = torch.empty(summed_out_fp8.size(0) * world_size, device=input_device, dtype=torch.uint8)
|
||||
fut2 = dist.all_gather_into_tensor(
|
||||
tensor_list_single, summed_out_fp8.view(torch.uint8), group=group_to_use, async_op=True
|
||||
).get_future()
|
||||
|
||||
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
|
||||
fut3 = dist.all_gather_into_tensor(
|
||||
torch.cat(scale_list, dim=0), scale, group=group_to_use, async_op=True
|
||||
).get_future()
|
||||
fut_combined2 = torch.futures.collect_all([fut2, fut3])
|
||||
return fut_combined2
|
||||
|
||||
def decompress(fut):
|
||||
tensor_list_single = fut.value().wait()[0].value()[0]
|
||||
scale_list_single = fut.value().wait()[1].value()[0]
|
||||
|
||||
tensor_list = list(torch.chunk(tensor_list_single, world_size, dim=0))
|
||||
scale_list = scale_list_single.chunk(world_size, dim=0)
|
||||
|
||||
for i in range(world_size):
|
||||
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
|
||||
out = torch.cat(tensor_list, dim=0)
|
||||
|
||||
input_tensor_size = input_tensor.numel()
|
||||
input_shape = input_tensor.shape
|
||||
out = out[:input_tensor_size]
|
||||
|
||||
input_tensor.copy_(out.view(input_shape).to(input_type))
|
||||
return input_tensor
|
||||
|
||||
return all_to_all_fut.then(sum_and_allgather).then(decompress)
|
||||
|
||||
|
||||
def fp8_compress_ddp_grad_comm_hook_sync(
|
||||
process_group: dist.ProcessGroup,
|
||||
bucket: dist.GradBucket,
|
||||
fp8_format="e5m2",
|
||||
) -> torch.futures.Future[torch.Tensor]:
|
||||
"""
|
||||
Return a future that wraps the input, after the input is allreduced. However, the allreduce communication is synchronous,
which breaks the overlap between allreduce communication and backward computation.

This hook should **only** be used for debugging purposes, instead of the normal gradient synchronization.
For an asynchronous implementation, use fp8_compress_ddp_grad_comm_hook_async instead.
|
||||
|
||||
Example::
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> ddp_model.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_sync)
|
||||
"""
|
||||
|
||||
buffer = bucket.buffer()
|
||||
all_reduce_fp8(buffer, fp8_format=fp8_format)
|
||||
|
||||
fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
|
||||
fut.set_result(bucket.buffer())
|
||||
|
||||
return fut
|
||||
|
||||
|
||||
def fp8_compress_fsdp_grad_comm_hook(
|
||||
state: object,
|
||||
unsharded_gradient_flattened: torch.Tensor,
|
||||
sharded_gradient: torch.Tensor,
|
||||
group=None,
|
||||
fp8_format="e5m2",
|
||||
) -> None:
|
||||
"""
|
||||
This communication hook implements a simple gradient compression approach that casts unsharded_gradient_flattened tensor
|
||||
to FP8 floating-point format (``torch.float8_e5m2`` or ``torch.float8_e4m3fn``), and then performs the reduce-scatter logic
by using all_to_all and all_gather among the process group.
|
||||
|
||||
Example::
|
||||
>>> fsdp_model.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook)
|
||||
"""
|
||||
grad = unsharded_gradient_flattened
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
input_type = grad.dtype
|
||||
input_device = grad.device
|
||||
world_size = dist.get_world_size(group=group)
|
||||
|
||||
grad_fp8, scale = cast_to_fp8(grad, fp8_format=fp8_format)
|
||||
uint8_buffer = torch.empty_like(grad_fp8).view(torch.uint8)
|
||||
dist.all_to_all_single(uint8_buffer, grad_fp8.view(torch.uint8), group=group)
|
||||
|
||||
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
|
||||
dist.all_gather(scale_list, scale, group=group)
|
||||
|
||||
buffer_list = list(torch.chunk(uint8_buffer.view(fp8_type), world_size, dim=0))
|
||||
sharded_gradient.zero_()
|
||||
for tensor, scale in zip(buffer_list, scale_list):
|
||||
sharded_gradient += cast_from_fp8(tensor, scale, input_type)
|
||||
|
||||
|
||||
def fp8_compress_fsdp_params_comm_hook(
|
||||
state: object,
|
||||
padded_unsharded_flat_param: torch.Tensor,
|
||||
sharded_flat_param: torch.Tensor,
|
||||
group=None,
|
||||
fp8_format="e5m2",
|
||||
) -> None:
|
||||
"""
|
||||
This hook is pending official FSDP support for a parameter communication hook, e.g. register_params_comm_hook.
|
||||
|
||||
Example::
|
||||
>>> fsdp_model.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook)
|
||||
"""
|
||||
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
fp8_max = torch.finfo(fp8_type).max
|
||||
inp = sharded_flat_param
|
||||
out = padded_unsharded_flat_param
|
||||
|
||||
per_tensor_max = inp.abs().max().float()
|
||||
per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
|
||||
dist.all_reduce(per_tensor_max, op=torch.distributed.ReduceOp.MAX, group=group)
|
||||
|
||||
scale = fp8_max / per_tensor_max
|
||||
fp8_sharded_flat_param = (scale * inp.float()).to(fp8_type).view(torch.uint8)
|
||||
|
||||
fp8_out = torch.empty(out.shape, dtype=torch.uint8, device=out.device)
|
||||
dist.all_gather_into_tensor(
|
||||
fp8_out,
|
||||
fp8_sharded_flat_param,
|
||||
group=group,
|
||||
)
|
||||
padded_unsharded_flat_param.copy_((fp8_out.view(fp8_type).float() / scale).to(out.dtype))
|
||||
|
||||
|
||||
def split_chunk_by_channel(
|
||||
chunk: torch.Tensor, channel_size: int, num_channels: int, rank: int = 0, world_size: int = 1
|
||||
):
|
||||
offset = chunk.numel() * rank
|
||||
end = offset + chunk.numel()
|
||||
break_points = [x for x in range(0, channel_size * num_channels + 1, channel_size) if offset <= x <= end]
|
||||
if len(break_points) == 0 or break_points[0] > offset:
|
||||
break_points.insert(0, offset)
|
||||
if break_points[-1] < end:
|
||||
break_points.append(end)
|
||||
sizes = [b - a for a, b in zip(break_points[:-1], break_points[1:])]
|
||||
return chunk.split(sizes)
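A worked toy example of the channel-aligned split above (not part of the diff; sizes are illustrative): rank 1 of 2 holds elements 12..23 of a flattened (3, 8) tensor, so its shard starts mid-channel and the first piece carries only the tail of channel 1.

import torch

chunk = torch.arange(12, 24)        # rank 1's shard of a flattened (3, 8) tensor
pieces = split_chunk_by_channel(chunk, channel_size=8, num_channels=3, rank=1, world_size=2)
print([p.numel() for p in pieces])  # -> [4, 8]: the tail of channel 1, then all of channel 2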
|
||||
|
||||
|
||||
def all_gather_into_tensor_flat_fp8(
|
||||
output_tensor: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
output_shape: torch.Size,
|
||||
group: dist.ProcessGroup,
|
||||
fp8_format: str = "e4m3",
|
||||
async_op: bool = False,
|
||||
) -> Optional[Handle]:
|
||||
"""all gather into tensor in fp8 format
|
||||
|
||||
Args:
|
||||
output_tensor (torch.Tensor): output tensor, which is flattened
|
||||
input_tensor (torch.Tensor): input tensor, which is flattened
output_shape (torch.Size): the unflattened shape of the gathered output
group (dist.ProcessGroup): process group
|
||||
fp8_format (str, optional): fp8 format, e4m3 or e5m2. Defaults to "e4m3".
|
||||
"""
|
||||
assert input_tensor.dim() == 1 and output_tensor.dim() == 1, "input/output tensor should be flattened"
|
||||
world_size = dist.get_world_size(group)
|
||||
assert (
|
||||
output_tensor.numel() == input_tensor.numel() * world_size
|
||||
), "output tensor size should be world_size times of input tensor size"
|
||||
|
||||
input_type = output_tensor.dtype
|
||||
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
fp8_max = torch.finfo(fp8_type).max
|
||||
|
||||
if len(output_shape) == 2:
|
||||
per_channel_max = torch.zeros(output_shape[0], device=output_tensor.device, dtype=torch.float)
|
||||
num_channels, channel_size = output_shape
|
||||
rank = dist.get_rank(group)
|
||||
channel_start_idx = (input_tensor.numel() * rank) // channel_size
|
||||
per_channel_splits = split_chunk_by_channel(input_tensor, channel_size, num_channels, rank, world_size)
|
||||
for i, per_channel_split in enumerate(per_channel_splits):
|
||||
idx = i + channel_start_idx
|
||||
if idx < num_channels:
|
||||
per_channel_max[idx] = per_channel_split.abs().max().float()
|
||||
dist.all_reduce(per_channel_max, op=dist.ReduceOp.MAX, group=group)
|
||||
per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
|
||||
scale = fp8_max / per_channel_max
|
||||
fp8_input = input_tensor.float()
|
||||
fp8_per_channel_splits = split_chunk_by_channel(fp8_input, channel_size, num_channels, rank, world_size)
|
||||
for i, per_channel_split in enumerate(fp8_per_channel_splits):
|
||||
idx = i + channel_start_idx
|
||||
if idx < num_channels:
|
||||
per_channel_split.mul_(scale[idx])
|
||||
fp8_input = fp8_input.to(fp8_type)
|
||||
else:
|
||||
per_tensor_max = input_tensor.abs().max().float()
|
||||
dist.all_reduce(per_tensor_max, op=dist.ReduceOp.MAX, group=group)
|
||||
per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
|
||||
scale = fp8_max / per_tensor_max
|
||||
fp8_input = (scale * input_tensor.float()).to(fp8_type)
|
||||
scale_inv = 1.0 / scale
|
||||
|
||||
buffer = torch.empty_like(output_tensor, dtype=fp8_type)
|
||||
tensor_handle = dist.all_gather_into_tensor(
|
||||
buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group, async_op=async_op
|
||||
)
|
||||
|
||||
def cast_op():
|
||||
numel = output_shape.numel()
|
||||
valid_buffer = buffer[:numel].reshape(output_shape)
|
||||
valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2))
|
||||
output_tensor[:numel].copy_(valid_buffer.view(-1))
|
||||
|
||||
if async_op:
|
||||
return Handle([tensor_handle], cast_op)
|
||||
else:
|
||||
cast_op()
|
||||
|
||||
|
||||
def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
|
||||
|
||||
world_size = dist.get_world_size(group)
|
||||
|
||||
input_type = input_list[0].dtype
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
scale_list = []
|
||||
tensor_list = []
|
||||
|
||||
for i in range(world_size):
|
||||
input_tensor = input_list[i]
|
||||
ret, scale = cast_to_fp8(input_tensor, fp8_format=fp8_format)
|
||||
scale_list.append(scale)
|
||||
ret = ret.view(torch.uint8)
|
||||
tensor_list.append(ret)
|
||||
|
||||
output_scale_list = [torch.empty_like(x) for x in scale_list]
|
||||
output_tensor_list = [torch.empty_like(x) for x in tensor_list]
|
||||
tensor_hanle = dist.all_to_all(output_tensor_list, tensor_list, group=group, async_op=async_op)
|
||||
scale_handle = dist.all_to_all(output_scale_list, scale_list, group=group, async_op=async_op)
|
||||
|
||||
def cast_op():
|
||||
for i in range(world_size):
|
||||
scale = output_scale_list[i]
|
||||
tensor = output_tensor_list[i]
|
||||
tensor = tensor.view(fp8_type)
|
||||
output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))
|
||||
|
||||
if async_op:
|
||||
return Handle([tensor_hanle, scale_handle], cast_op)
|
||||
else:
|
||||
cast_op()
|
||||
|
||||
|
||||
def gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:

    world_size = dist.get_world_size(group)

    input_type = input_.dtype
    ret, scale = cast_to_fp8(input_, fp8_format=fp8_format)
    fp8_type = ret.dtype
    input_ = ret.view(torch.uint8)
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)]
    chunk_handle = dist.all_gather(tensor_list, input_, group=group, async_op=async_op)
    scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)

    def cast_op():
        for i in range(world_size):
            output = tensor_list[i].view(fp8_type)
            scale = scale_list[i]
            output_list[i].copy_(cast_from_fp8(output, scale, input_type))

    if async_op:
        return Handle([chunk_handle, scale_handle], cast_op)
    else:
        cast_op()

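# Usage sketch for `gather_fp8` (illustrative, not part of this file): it mirrors
# `dist.all_gather`, but sends an fp8 payload plus a per-rank scale.
def _example_gather_fp8():
    import torch
    import torch.distributed as dist

    world_size = dist.get_world_size()
    x = torch.randn(1024, dtype=torch.bfloat16, device="cuda")
    gathered = [torch.empty_like(x) for _ in range(world_size)]
    gather_fp8(gathered, x, fp8_format="e4m3")  # results are decompressed back to bf16
    return gathered
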
class _LinearFp8(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        x: torch.Tensor,
        w: torch.Tensor,
        bias: Optional[torch.Tensor],
    ) -> Any:
        assert (
            x.dtype in (torch.bfloat16, torch.float16) and x.dtype == w.dtype
        ), "Only float16 and bfloat16 are allowed."
        if bias is not None:
            assert bias.dtype == x.dtype, "Bias should have the same dtype as input."
        # ensure x and w are row-major
        x = x.contiguous()
        w = w.contiguous()
        ctx.x_shape = x.shape
        ctx.has_bias = bias is not None
        ctx.out_dtype = x.dtype
        x = x.reshape(-1, x.shape[-1])

        x_fp8, inv_scale_x = cast_to_fp8(x, fp8_format="e4m3")
        w_fp8, inv_scale_w = cast_to_fp8(w, fp8_format="e4m3")
        ctx.x_fp8 = x_fp8
        ctx.w_fp8_t = w_fp8.t()
        ctx.inv_scale_x = inv_scale_x
        ctx.inv_scale_w = inv_scale_w
        out = torch._scaled_mm(
            x_fp8,
            ctx.w_fp8_t,
            bias=bias,
            out_dtype=ctx.out_dtype,
            scale_a=inv_scale_x,
            scale_b=inv_scale_w,
            use_fast_accum=True,
        )[0]
        return out.reshape(*ctx.x_shape[:-1], w.shape[0])

    @staticmethod
    def backward(ctx: Any, out_grad) -> Any:
        out_grad = out_grad.reshape(-1, out_grad.shape[-1])
        out_grad_fp8, out_grad_scale = cast_to_fp8(out_grad, fp8_format="e5m2")
        x_grad = torch._scaled_mm(
            out_grad_fp8,
            ctx.w_fp8_t.contiguous().t(),
            out_dtype=ctx.out_dtype,
            scale_a=out_grad_scale,
            scale_b=ctx.inv_scale_w,
            use_fast_accum=True,
        )[0]
        w_grad = torch._scaled_mm(
            out_grad_fp8.t().contiguous(),
            ctx.x_fp8.t().contiguous().t(),
            out_dtype=ctx.out_dtype,
            scale_a=out_grad_scale,
            scale_b=ctx.inv_scale_x,
            use_fast_accum=True,
        )[0]
        bias_grad = None
        if ctx.has_bias:
            bias_grad = out_grad.sum(0)
        return x_grad.reshape(ctx.x_shape), w_grad, bias_grad


@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=False)
def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    return _LinearFp8.apply(input, weight, bias)


def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    out = _linear_fp8(input, weight, bias)
    return out

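# Usage sketch for `linear_fp8` (illustrative, not part of this file): it is a drop-in
# for F.linear on fp16/bf16 tensors. Note that `torch._scaled_mm` needs an fp8-capable
# GPU (Hopper/Ada) and shapes that satisfy its alignment requirements.
def _example_linear_fp8():
    import torch

    x = torch.randn(16, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    w = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    b = torch.randn(128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    # Forward casts x and w to e4m3 for torch._scaled_mm; backward uses e5m2 gradients.
    y = linear_fp8(x, w, b)
    y.sum().backward()
    return y
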
@ -0,0 +1,23 @@
import torch.nn.functional as F

from colossalai.quantization.fp8 import linear_fp8
from colossalai.tensor.param_op_hook import ColoParamOpHook


class FP8Hook(ColoParamOpHook):
    def pre_forward(self, params) -> None:
        pass

    def post_forward(self, params) -> None:
        pass

    def pre_backward(self, params) -> None:
        pass

    def post_backward(self, params) -> None:
        pass

    def rewrite_op(self, func):
        if func is F.linear:
            return linear_fp8
        return func

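# Usage sketch for FP8Hook (illustrative, not part of this file): the hook is enabled
# through the same `ColoParamOpHookManager.use_hooks` mechanism the booster plugins use,
# and it only takes effect for F.linear calls that dispatch through managed
# (ColoTensor-backed) parameters, as they do inside the plugins.
def _example_fp8_hook():
    import torch
    import torch.nn as nn

    from colossalai.quantization.fp8_hook import FP8Hook
    from colossalai.tensor.param_op_hook import ColoParamOpHookManager

    model = nn.Linear(64, 128).cuda().to(torch.bfloat16)
    data = torch.randn(16, 64, dtype=torch.bfloat16, device="cuda")
    with ColoParamOpHookManager.use_hooks(FP8Hook()):
        # Within this context, eligible F.linear calls are rewritten to linear_fp8.
        out = model(data)
    return out
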
@ -0,0 +1,112 @@
import torch
import torch.distributed as dist
from packaging import version
from torch import Tensor
from torch.distributed.fsdp._common_utils import _no_dispatch_record_stream
from torch.distributed.utils import _p_assert


def _all_gather_flat_param(
    self,
    padded_unsharded_flat_param: Tensor,
) -> Tensor:
    """
    All-gather the handle's flat parameter to the destination ``padded_unsharded_flat_param``.

    Then switch to use the all-gathered tensor.
    """
    _p_assert(
        hasattr(self, "process_group") and hasattr(self, "world_size"),
        "Expects a process group and world size to have been set via `shard()`",
    )
    sharded_flat_param = self.flat_param.data
    expected_numel = sharded_flat_param.numel() * self.world_size
    _p_assert(
        padded_unsharded_flat_param.numel() == expected_numel,
        f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
    )

    pg = self._fake_process_group if self._use_fake_all_gather else self.process_group

    # HACK this should be handled by C10D
    if sharded_flat_param.is_cpu:  # type: ignore[attr-defined]
        tensor_list = list(torch.chunk(padded_unsharded_flat_param, dist.get_world_size(pg)))
        work = dist.all_gather(tensor_list, sharded_flat_param, group=pg)
    else:
        if self._comm_hook is None:
            dist.all_gather_into_tensor(
                padded_unsharded_flat_param,
                sharded_flat_param,
                pg,
            )
        else:
            self._comm_hook(None, padded_unsharded_flat_param, sharded_flat_param, pg)

    if self._offload_params:
        # In case of offloading, `flat_param.data` (i.e. sharded param) is
        # created on the pre-unshard stream. We need to hand it over to the
        # unshard stream for all-gather
        _no_dispatch_record_stream(
            sharded_flat_param,
            self._device_handle.current_stream(),  # unshard_stream
        )
    return padded_unsharded_flat_param


def register_params_comm_hook(self, state: object, hook: callable):
    """Register a communication hook for FlatParamHandle.

    This is an enhancement that provides a flexible hook to users where they can specify how FSDP unshards
    parameters across multiple workers.

    .. warning ::
        FSDP communication hook should be registered before running an initial forward pass
        and only once.

    Args:
        state (object): Passed to the hook to maintain any state information during the training process.
        hook (Callable): Callable, which has one of the following signatures:
            1) ``hook: Callable[torch.Tensor] -> None``:
               This function takes in a Python tensor, which represents
               the full, flattened, unsharded gradient with respect to all variables
               corresponding to the model this FSDP unit is wrapping
               (that are not wrapped by other FSDP sub-units).
               It then performs all necessary processing and returns ``None``;
            2) ``hook: Callable[torch.Tensor, torch.Tensor] -> None``:
               This function takes in two Python tensors, the first one represents
               the full, flattened, unsharded gradient with respect to all variables
               corresponding to the model this FSDP unit is wrapping
               (that are not wrapped by other FSDP sub-units). The latter
               represents a pre-sized tensor to store a chunk of a sharded gradient after
               reduction.
            In both cases, the callable performs all necessary processing and returns ``None``.
            Callables with signature 1 are expected to handle gradient communication for a `NO_SHARD` case.
            Callables with signature 2 are expected to handle gradient communication for sharded cases.

    """
    if not self.check_is_root():
        raise AssertionError("register_comm_hook can only be called on a root instance.")

    # if fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES:
    #     raise AssertionError(
    #         f"Communication hook is not supported for hybrid strategies: {fsdp_state.sharding_strategy}"
    #     )
    if self._handle._comm_hook is not None:
        raise AssertionError("A communication hook is already registered")
    if not callable(hook):
        raise ValueError(f"The communication hook must be callable but got {hook}")
    self._handle._comm_hook = hook
    self._handle._comm_hook_state = state


def patch_fsdp_params_comm_hook():
    if version.parse(torch.__version__) >= version.parse("2.2.0"):
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.fsdp._flat_param import FlatParamHandle

        FlatParamHandle._comm_hook = None
        FlatParamHandle._comm_hook_state = None
        FlatParamHandle._all_gather_flat_param = _all_gather_flat_param
        FSDP.register_params_comm_hook = register_params_comm_hook
    else:
        raise RuntimeError("The fsdp_params_comm_hook patch is not supported for torch versions below 2.2.0.")

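# Usage sketch for the FSDP patch above (illustrative): after patching, a root FSDP
# module can register a parameter-communication hook that replaces the default
# all-gather. The hook body below is a simplified stand-in that just performs the plain
# all-gather; the PR wires in an fp8-compressing hook elsewhere.
def _example_register_params_comm_hook():
    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    def allgather_param_hook(state, unsharded_flat_param, sharded_flat_param, group):
        # A real fp8 hook would cast `sharded_flat_param` to fp8, all-gather the
        # compressed payload, then decompress into `unsharded_flat_param`.
        dist.all_gather_into_tensor(unsharded_flat_param, sharded_flat_param, group)

    patch_fsdp_params_comm_hook()
    fsdp_model = FSDP(torch.nn.Linear(64, 64).cuda())
    fsdp_model.register_params_comm_hook(state=None, hook=allgather_param_hook)
    return fsdp_model
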
@ -16,6 +16,14 @@ try:
|
|||
except ImportError:
|
||||
_grad_accum_fusion_available = False
|
||||
|
||||
from colossalai.quantization.fp8 import (
|
||||
all_reduce_fp8,
|
||||
all_to_all_fp8,
|
||||
all_to_all_single_fp8,
|
||||
gather_fp8,
|
||||
reduce_scatter_fp8,
|
||||
)
|
||||
|
||||
|
||||
class FusedLayerNormAffineFunction1D(torch.autograd.Function):
|
||||
r"""Layernorm
|
||||
|
@ -61,11 +69,12 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
|
||||
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
|
||||
ctx.save_for_backward(input_, weight, bias)
|
||||
ctx.use_bias = bias is not None
|
||||
ctx.process_group = process_group
|
||||
ctx.async_grad_allreduce = async_grad_allreduce
|
||||
ctx.fp8_communication = fp8_communication
|
||||
|
||||
output = torch.matmul(input_, weight)
|
||||
|
||||
|
@ -78,6 +87,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
|
|||
def backward(ctx, grad_output):
|
||||
input, weight, bias = ctx.saved_tensors
|
||||
use_bias = ctx.use_bias
|
||||
fp8_communication = ctx.fp8_communication
|
||||
|
||||
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
|
||||
weight = weight.view(weight.shape)
|
||||
|
@ -92,7 +102,9 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
|
|||
grad_output = grad_output.view(-1, grad_output.shape[-1])
|
||||
total_input = total_input.view(-1, total_input.shape[-1])
|
||||
|
||||
if ctx.async_grad_allreduce:
|
||||
if ctx.async_grad_allreduce and fp8_communication:
|
||||
_reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication, fp8_format="e5m2")
|
||||
elif ctx.async_grad_allreduce:
|
||||
# Asynchronous all-reduce
|
||||
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
|
||||
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
|
||||
|
@ -101,10 +113,10 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
|
|||
grad_weight = total_input.t().matmul(grad_output)
|
||||
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||
|
||||
if ctx.async_grad_allreduce:
|
||||
if ctx.async_grad_allreduce and not fp8_communication:
|
||||
handle.wait()
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None, None, None
|
||||
return grad_input, grad_weight, grad_bias, None, None, None, None
|
||||
|
||||
|
||||
class LinearWithAsyncCommunication(torch.autograd.Function):
|
||||
|
@ -113,11 +125,12 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
|
||||
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
|
||||
ctx.save_for_backward(input_, weight, bias)
|
||||
ctx.use_bias = bias is not None
|
||||
ctx.process_group = process_group
|
||||
ctx.async_grad_allreduce = async_grad_allreduce
|
||||
ctx.fp8_communication = fp8_communication
|
||||
if bias is not None:
|
||||
output = F.linear(input_, weight, bias)
|
||||
else:
|
||||
|
@ -129,6 +142,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
|||
def backward(ctx, grad_output):
|
||||
input, weight, bias = ctx.saved_tensors
|
||||
use_bias = ctx.use_bias
|
||||
fp8_communication = ctx.fp8_communication
|
||||
|
||||
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
|
||||
if use_bias:
|
||||
|
@ -144,10 +158,11 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
|||
|
||||
if ctx.async_grad_allreduce:
|
||||
# Asynchronous all-reduce
|
||||
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
|
||||
_ = torch.zeros(1, device=grad_input.device)
|
||||
|
||||
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
|
||||
if fp8_communication:
|
||||
all_reduce_fp8(grad_input, group=ctx.process_group)
|
||||
else:
|
||||
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
|
||||
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
|
||||
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
|
||||
|
||||
if _grad_accum_fusion_available and weight.grad is not None:
|
||||
|
@ -165,10 +180,10 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
|||
|
||||
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||
|
||||
if ctx.async_grad_allreduce:
|
||||
if ctx.async_grad_allreduce and not fp8_communication:
|
||||
handle.wait()
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None, None, None
|
||||
return grad_input, grad_weight, grad_bias, None, None, None, None
|
||||
|
||||
|
||||
def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
|
||||
|
@ -236,17 +251,18 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, process_group, dim):
|
||||
def forward(ctx, input_, process_group, dim, fp8_communication=False):
|
||||
ctx.process_group = process_group
|
||||
ctx.dim = dim
|
||||
ctx.fp8_communication = fp8_communication
|
||||
|
||||
return _gather(input_, dim, process_group)
|
||||
return _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3")
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
dim = ctx.dim
|
||||
process_group = ctx.process_group
|
||||
|
||||
fp8_communication = ctx.fp8_communication
|
||||
# do reduce-scatter
|
||||
new_shape = list(grad_output.shape)
|
||||
assert (
|
||||
|
@ -257,9 +273,13 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim)
|
||||
]
|
||||
output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device)
|
||||
dist.reduce_scatter(output, grad_list, group=process_group)
|
||||
|
||||
return output, None, None
|
||||
if fp8_communication:
|
||||
reduce_scatter_fp8(output, grad_list, group=process_group, fp8_format="e5m2")
|
||||
else:
|
||||
dist.reduce_scatter(output, grad_list, group=process_group)
|
||||
|
||||
return output, None, None, None
|
||||
|
||||
|
||||
class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
||||
|
@ -550,9 +570,10 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, process_group, dim):
|
||||
def forward(ctx, input_, process_group, dim, fp8_communication=False):
|
||||
ctx.dim = dim
|
||||
ctx.process_group = process_group
|
||||
ctx.fp8_communication = fp8_communication
|
||||
|
||||
# do reduce-scatter
|
||||
new_shape = list(input_.shape)
|
||||
|
@ -562,7 +583,10 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
|
|||
new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
|
||||
input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
|
||||
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
|
||||
dist.reduce_scatter(output, input_list, group=process_group)
|
||||
if fp8_communication:
|
||||
reduce_scatter_fp8(output, input_list, group=process_group, fp8_format="e4m3")
|
||||
else:
|
||||
dist.reduce_scatter(output, input_list, group=process_group)
|
||||
|
||||
return output
|
||||
|
||||
|
@ -570,8 +594,9 @@ class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
|
|||
def backward(ctx, grad_output):
|
||||
dim = ctx.dim
|
||||
process_group = ctx.process_group
|
||||
fp8_communication = ctx.fp8_communication
|
||||
|
||||
return _gather(grad_output, dim, process_group), None, None
|
||||
return _gather(grad_output, dim, process_group, fp8_communication, fp8_format="e5m2"), None, None, None
|
||||
|
||||
|
||||
class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
||||
|
@ -586,13 +611,16 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring):
|
||||
def forward(
|
||||
ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
|
||||
):
|
||||
ctx.save_for_backward(input_, weight, bias)
|
||||
ctx.use_bias = bias is not None
|
||||
ctx.process_group = process_group
|
||||
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
|
||||
ctx.dim = dim
|
||||
ctx.overlap = overlap
|
||||
ctx.fp8_communication = fp8_communication
|
||||
|
||||
if ring is True:
|
||||
input_to_gather = {}
|
||||
|
@ -609,7 +637,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
)
|
||||
|
||||
else:
|
||||
input_parallel = _gather(input_, dim, process_group)
|
||||
input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3")
|
||||
|
||||
output = torch.matmul(input_parallel, weight)
|
||||
|
||||
|
@ -624,6 +652,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
dim = ctx.dim
|
||||
process_group = ctx.process_group
|
||||
overlap = ctx.overlap
|
||||
fp8_communication = ctx.fp8_communication
|
||||
|
||||
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
|
||||
weight = weight.view(weight.shape)
|
||||
|
@ -631,7 +660,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
bias = bias.view(bias.shape)
|
||||
|
||||
if not overlap:
|
||||
input_parallel = _gather(input_, dim, process_group)
|
||||
input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e5m2")
|
||||
|
||||
total_input = input_parallel
|
||||
grad_input = grad_output.matmul(weight.T)
|
||||
|
@ -691,7 +720,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
# wait until reduce-scatter finished
|
||||
reducescatter_handle.wait()
|
||||
|
||||
return output, grad_weight, grad_bias, None, None, None, None, None
|
||||
return output, grad_weight, grad_bias, None, None, None, None, None, None
|
||||
|
||||
|
||||
class _SplitForwardGatherBackward(torch.autograd.Function):
|
||||
|
@ -706,17 +735,25 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, dim, process_group, grad_scale=None):
|
||||
def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=False):
|
||||
ctx.process_group = process_group
|
||||
ctx.dim = dim
|
||||
ctx.grad_scale = grad_scale
|
||||
ctx.fp8_communication = fp8_communication
|
||||
return _split(input_, dim, process_group)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.grad_scale is not None:
|
||||
grad_output = grad_output * ctx.grad_scale
|
||||
return _gather(grad_output, ctx.dim, ctx.process_group), None, None, None
|
||||
|
||||
return (
|
||||
_gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication, fp8_format="e5m2"),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class _ReduceForward(torch.autograd.Function):
|
||||
|
@ -730,15 +767,15 @@ class _ReduceForward(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, process_group, grad_scale=None):
|
||||
def forward(ctx, input_, process_group, grad_scale=None, fp8_communication=False):
|
||||
ctx.grad_scale = grad_scale
|
||||
return _reduce(input_, process_group)
|
||||
return _reduce(input_, process_group, fp8_communication, fp8_format="e4m3")
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.grad_scale is not None:
|
||||
grad_output = grad_output * ctx.grad_scale
|
||||
return grad_output, None, None
|
||||
return grad_output, None, None, None
|
||||
|
||||
|
||||
class _ReduceBackward(torch.autograd.Function):
|
||||
|
@ -751,13 +788,15 @@ class _ReduceBackward(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, process_group):
|
||||
def forward(ctx, input_, process_group, fp8_communication=False):
|
||||
ctx.process_group = process_group
|
||||
ctx.fp8_communication = fp8_communication
|
||||
return input_
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _reduce(grad_output, ctx.process_group), None
|
||||
fp8_communication = ctx.fp8_communication
|
||||
return _reduce(grad_output, ctx.process_group, fp8_communication, fp8_format="e5m2"), None, None
|
||||
|
||||
|
||||
class _GatherForwardSplitBackward(torch.autograd.Function):
|
||||
|
@ -770,17 +809,18 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, dim, process_group, grad_scale=None):
|
||||
def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=False):
|
||||
ctx.process_group = process_group
|
||||
ctx.dim = dim
|
||||
ctx.grad_scale = grad_scale
|
||||
return _gather(input_, dim, process_group)
|
||||
|
||||
return _gather(input_, dim, process_group, fp8_communication=fp8_communication, fp8_format="e4m3")
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.grad_scale is not None:
|
||||
grad_output = grad_output * ctx.grad_scale
|
||||
return _split(grad_output, ctx.dim, ctx.process_group), None, None, None
|
||||
return _split(grad_output, ctx.dim, ctx.process_group), None, None, None, None
|
||||
|
||||
|
||||
class _AllToAll(torch.autograd.Function):
|
||||
|
@ -794,26 +834,67 @@ class _AllToAll(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, process_group, scatter_dim, gather_dim):
|
||||
def forward(ctx, input_, process_group, scatter_dim, gather_dim, fp8_communication=False):
|
||||
ctx.process_group = process_group
|
||||
ctx.scatter_dim = scatter_dim
|
||||
ctx.gather_dim = gather_dim
|
||||
ctx.fp8_communication = fp8_communication
|
||||
world_size = dist.get_world_size(process_group)
|
||||
bsz, _, _ = input_.shape
|
||||
|
||||
# using all_to_all_single when batch size is 1
|
||||
if bsz == 1:
|
||||
return _all_to_all_single(input_, world_size, process_group, scatter_dim, gather_dim)
|
||||
return _all_to_all_single(
|
||||
input_,
|
||||
world_size,
|
||||
process_group,
|
||||
scatter_dim,
|
||||
gather_dim,
|
||||
fp8_communication=fp8_communication,
|
||||
fp8_format="e4m3",
|
||||
)
|
||||
else:
|
||||
return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim)
|
||||
return _all_to_all(
|
||||
input_,
|
||||
world_size,
|
||||
process_group,
|
||||
scatter_dim,
|
||||
gather_dim,
|
||||
fp8_communication=fp8_communication,
|
||||
fp8_format="e4m3",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grad_output):
|
||||
def backward(ctx, grad_output):
|
||||
process_group = ctx.process_group
|
||||
scatter_dim = ctx.gather_dim
|
||||
gather_dim = ctx.scatter_dim
|
||||
return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
|
||||
return (return_grad, None, None, None)
|
||||
fp8_communication = ctx.fp8_communication
|
||||
world_size = dist.get_world_size(process_group)
|
||||
bsz, _, _ = grad_output.shape
|
||||
|
||||
if bsz == 1:
|
||||
return_grad = _all_to_all_single(
|
||||
grad_output,
|
||||
world_size,
|
||||
process_group,
|
||||
scatter_dim,
|
||||
gather_dim,
|
||||
fp8_communication=fp8_communication,
|
||||
fp8_format="e5m2",
|
||||
)
|
||||
else:
|
||||
return_grad = _all_to_all(
|
||||
grad_output,
|
||||
world_size,
|
||||
process_group,
|
||||
scatter_dim,
|
||||
gather_dim,
|
||||
fp8_communication=fp8_communication,
|
||||
fp8_format="e5m2",
|
||||
)
|
||||
|
||||
return (return_grad, None, None, None, None)
|
||||
|
||||
|
||||
class HookParameter(torch.autograd.Function):
|
||||
|
@ -839,12 +920,15 @@ def hook_parameter_in_backward(input, weight=None, bias=None):
|
|||
return HookParameter.apply(input, weight, bias)
|
||||
|
||||
|
||||
def _reduce(input_, process_group):
|
||||
def _reduce(input_, process_group, fp8_communication=False, fp8_format="e5m2"):
|
||||
# skip if only one rank involved
|
||||
if dist.get_world_size(process_group) == 1:
|
||||
return input_
|
||||
else:
|
||||
dist.all_reduce(input_, group=process_group)
|
||||
if fp8_communication:
|
||||
all_reduce_fp8(input_, group=process_group, fp8_format=fp8_format)
|
||||
else:
|
||||
dist.all_reduce(input_, group=process_group)
|
||||
return input_
|
||||
|
||||
|
||||
|
@ -868,18 +952,19 @@ def _split(input_, dim=-1, process_group=None):
|
|||
return output
|
||||
|
||||
|
||||
def _gather(input_, dim=-1, process_group=None):
|
||||
def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_format="e5m2"):
|
||||
# skip if only one rank involved
|
||||
world_size = dist.get_world_size(process_group)
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
# all gather
|
||||
input_ = input_.contiguous()
|
||||
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
torch.distributed.all_gather(tensor_list, input_, group=process_group)
|
||||
if fp8_communication:
|
||||
gather_fp8(tensor_list, input_, fp8_format=fp8_format, group=process_group)
|
||||
else:
|
||||
dist.all_gather(tensor_list, input_, group=process_group)
|
||||
|
||||
# concat
|
||||
output = torch.cat(tensor_list, dim=dim).contiguous()
|
||||
|
||||
return output
|
||||
|
@ -909,14 +994,19 @@ def _reduce_scatter(input_, dim=1, process_group=None):
|
|||
return output
|
||||
|
||||
|
||||
def _all_to_all(input_, world_size, group, scatter_dim, gather_dim):
|
||||
def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"):
|
||||
input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
|
||||
output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
|
||||
dist.all_to_all(output_list, input_list, group=group)
|
||||
if fp8_communication:
|
||||
all_to_all_fp8(output_list, input_list, group=group, fp8_format=fp8_format)
|
||||
else:
|
||||
dist.all_to_all(output_list, input_list, group=group)
|
||||
return torch.cat(output_list, dim=gather_dim).contiguous()
|
||||
|
||||
|
||||
def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
|
||||
def _all_to_all_single(
|
||||
input_, seq_world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"
|
||||
):
|
||||
inp_shape = list(input_.shape)
|
||||
inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
|
||||
if scatter_dim < 2:
|
||||
|
@ -929,7 +1019,11 @@ def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
|
|||
)
|
||||
|
||||
output = torch.empty_like(input_t)
|
||||
dist.all_to_all_single(output, input_t, group=group)
|
||||
if fp8_communication:
|
||||
all_to_all_single_fp8(output, input_t, group=group, fp8_format=fp8_format)
|
||||
else:
|
||||
|
||||
dist.all_to_all_single(output, input_t, group=group)
|
||||
|
||||
if scatter_dim < 2:
|
||||
output = output.transpose(0, 1).contiguous()
|
||||
|
@ -943,12 +1037,16 @@ def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
|
|||
).contiguous()
|
||||
|
||||
|
||||
def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
|
||||
return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
|
||||
def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
|
||||
return MatmulWithAsyncCommunication.apply(
|
||||
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication
|
||||
)
|
||||
|
||||
|
||||
def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
|
||||
return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
|
||||
def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
|
||||
return LinearWithAsyncCommunication.apply(
|
||||
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication
|
||||
)
|
||||
|
||||
|
||||
def linear_gather_forward_reducescatter_backward(
|
||||
|
@ -959,12 +1057,12 @@ def linear_gather_forward_reducescatter_backward(
|
|||
)
|
||||
|
||||
|
||||
def gather_forward_reducescatter_backward(input_, process_group, dim):
|
||||
return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim)
|
||||
def gather_forward_reducescatter_backward(input_, process_group, dim, fp8_communication=False):
|
||||
return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim, fp8_communication)
|
||||
|
||||
|
||||
def reducescatter_forward_gather_backward(input_, process_group, dim):
|
||||
return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim)
|
||||
def reducescatter_forward_gather_backward(input_, process_group, dim, fp8_communication=False):
|
||||
return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim, fp8_communication)
|
||||
|
||||
|
||||
def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False):
|
||||
|
@ -972,38 +1070,40 @@ def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, proc
|
|||
|
||||
|
||||
def matmul_gather_forward_reducescatter_backward(
|
||||
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
|
||||
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False, fp8_communication=False
|
||||
):
|
||||
return _MatmulWithGatherForwardReduceScatterBackward.apply(
|
||||
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
|
||||
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
|
||||
)
|
||||
|
||||
|
||||
def gather_forward_split_backward(input_, dim, process_group, grad_scale=None):
|
||||
return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale)
|
||||
def gather_forward_split_backward(input_, dim, process_group, grad_scale=None, fp8_communication=False):
|
||||
return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale, fp8_communication)
|
||||
|
||||
|
||||
def split_forward_gather_backward(input_, dim, process_group, grad_scale=None):
|
||||
return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale)
|
||||
def split_forward_gather_backward(input_, dim, process_group, grad_scale=None, fp8_communication=False):
|
||||
return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale, fp8_communication)
|
||||
|
||||
|
||||
def reduce_forward(input_, process_group, grad_scale=None):
|
||||
return _ReduceForward.apply(input_, process_group, grad_scale)
|
||||
def reduce_forward(input_, process_group, grad_scale=None, fp8_communication=False):
|
||||
return _ReduceForward.apply(input_, process_group, grad_scale, fp8_communication)
|
||||
|
||||
|
||||
def reduce_backward(input_, process_group):
|
||||
return _ReduceBackward.apply(input_, process_group)
|
||||
def reduce_backward(input_, process_group, fp8_communication=False):
|
||||
return _ReduceBackward.apply(input_, process_group, fp8_communication)
|
||||
|
||||
|
||||
def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
|
||||
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
|
||||
def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_communication=False):
|
||||
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication)
|
||||
|
||||
|
||||
def gather_sp_output(hidden_states, sp_group, sp_mode):
|
||||
def gather_sp_output(hidden_states, sp_group, sp_mode, fp8_communication=False):
|
||||
"""
|
||||
Gather the output of the last layer for cross entropy computation
|
||||
"""
|
||||
# Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group)
|
||||
scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group)
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=scale)
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, 1, sp_group, grad_scale=scale, fp8_communication=fp8_communication
|
||||
)
|
||||
return hidden_states
|
||||
|
|
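# Usage sketch for the fp8-aware wrappers above (illustrative): the convention in this
# file is e4m3 for forward-path payloads and e5m2 for gradients. Callers opt in per
# call via `fp8_communication`; this assumes torch.distributed is already initialized.
def _example_fp8_sequence_parallel_comm(hidden_states, sp_group):
    # hidden_states: (batch, seq_len // sp_size, hidden) bf16 tensor on each rank.
    # Scatter the hidden dim / gather the sequence dim with an fp8-compressed all-to-all.
    hidden_states = all_to_all_comm(hidden_states, sp_group, scatter_dim=2, gather_dim=1, fp8_communication=True)
    # Gather the sequence dimension back with an fp8-compressed all-gather.
    return gather_forward_split_backward(hidden_states, 1, sp_group, fp8_communication=True)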
|
@ -68,6 +68,7 @@ class Embedding1D(ParallelModule):
|
|||
gather_output: bool = True,
|
||||
weight: Optional[nn.Parameter] = None,
|
||||
weight_initializer: Callable = init.normal_(),
|
||||
fp8_communication: bool = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
|
@ -81,6 +82,7 @@ class Embedding1D(ParallelModule):
|
|||
self.embed_args = args
|
||||
self.embed_kwargs = kwargs
|
||||
self.gather_output = gather_output
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
# offset the seed with randomizer index and rank
|
||||
seed = torch.random.initial_seed()
|
||||
|
@ -155,7 +157,9 @@ class Embedding1D(ParallelModule):
|
|||
def forward(self, input_: Tensor) -> Tensor:
|
||||
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
||||
if self.gather_output:
|
||||
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
|
||||
output = gather_forward_split_backward(
|
||||
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
return output
|
||||
else:
|
||||
return output_parallel
|
||||
|
@ -274,6 +278,7 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
|
|||
weight: Optional[nn.Parameter] = None,
|
||||
weight_initializer: Callable = init.normal_(),
|
||||
make_vocab_size_divisible_by: int = 64,
|
||||
fp8_communication: bool = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
|
@ -282,6 +287,7 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
|
|||
self.embed_args = args
|
||||
self.embed_kwargs = kwargs
|
||||
self.process_group = process_group
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
tensor_parallel_size = dist.get_world_size(group=process_group)
|
||||
tensor_parallel_rank = dist.get_rank(group=process_group)
|
||||
|
@ -390,5 +396,5 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
|
|||
embedding_output = output_parallel.clone()
|
||||
embedding_output[input_mask, :] = 0.0
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = reduce_forward(embedding_output, self.process_group)
|
||||
output = reduce_forward(embedding_output, self.process_group, fp8_communication=self.fp8_communication)
|
||||
return output
|
||||
|
|
|
@ -84,6 +84,7 @@ class Linear1D_Col(ParallelModule):
|
|||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
fp8_communication: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(weight=weight, bias_=bias_, **kwargs)
|
||||
|
@ -98,6 +99,7 @@ class Linear1D_Col(ParallelModule):
|
|||
self.skip_bias_add = skip_bias_add
|
||||
self.device = device
|
||||
self.process_group = process_group
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
@ -202,19 +204,25 @@ class Linear1D_Col(ParallelModule):
|
|||
|
||||
if self.seq_parallel_mode == "split_gather":
|
||||
input_parallel = gather_forward_reducescatter_backward(
|
||||
input_parallel, self.process_group, self.seq_parallel_dim
|
||||
input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
|
||||
)
|
||||
output_parallel = linear_with_async_comm(
|
||||
input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication
|
||||
)
|
||||
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False)
|
||||
elif self.seq_parallel_mode == "ring":
|
||||
output_parallel = linear_gather_forward_reducescatter_backward(
|
||||
input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
|
||||
)
|
||||
else:
|
||||
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
|
||||
output_parallel = linear_with_async_comm(
|
||||
input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
|
||||
)
|
||||
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
|
||||
output = gather_forward_split_backward(
|
||||
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
else:
|
||||
output = output_parallel
|
||||
|
||||
|
@ -264,6 +272,7 @@ class Linear1D_Row(ParallelModule):
|
|||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
stream_chunk_num: int = 1,
|
||||
fp8_communication: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -278,6 +287,7 @@ class Linear1D_Row(ParallelModule):
|
|||
self.seq_parallel_mode = seq_parallel_mode
|
||||
self.seq_parallel_dim = seq_parallel_dim
|
||||
self.num_partitions = dist.get_world_size(self.process_group)
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
@ -398,7 +408,9 @@ class Linear1D_Row(ParallelModule):
|
|||
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions
|
||||
)
|
||||
input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)
|
||||
input_ = split_forward_gather_backward(
|
||||
input_, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
|
||||
if self.stream_chunk_num > 1:
|
||||
if self.training:
|
||||
|
@ -416,10 +428,13 @@ class Linear1D_Row(ParallelModule):
|
|||
handle.wait()
|
||||
output = torch.cat(output_parallel_list, dim=-1)
|
||||
else:
|
||||
if self.seq_parallel_mode == "split_gather":
|
||||
if self.seq_parallel_mode is None:
|
||||
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
|
||||
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
|
||||
elif self.seq_parallel_mode == "split_gather":
|
||||
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
|
||||
output = reducescatter_forward_gather_backward(
|
||||
output_parallel, self.process_group, self.seq_parallel_dim
|
||||
output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
|
||||
)
|
||||
elif self.seq_parallel_mode == "ring":
|
||||
output = linear_reducescatter_forward_gather_backward(
|
||||
|
@ -562,6 +577,7 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
|
|||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
make_vocab_size_divisible_by: int = 64,
|
||||
fp8_communication: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
# create weight and bias
|
||||
|
@ -592,6 +608,7 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
|
|||
**kwargs,
|
||||
new_num_embeddings=new_out_features,
|
||||
old_num_embeddings=out_features,
|
||||
fp8_communication=fp8_communication,
|
||||
)
|
||||
# get the length of valid embeddings
|
||||
tp_rank = dist.get_rank(process_group)
|
||||
|
|
|
@ -183,6 +183,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
|||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
fp8_communication: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -197,6 +198,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
|||
self.n_fused = n_fused
|
||||
self.process_group = process_group
|
||||
self.async_communication = async_communication
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
@ -314,14 +316,26 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
|||
|
||||
if self.seq_parallel_mode is None:
|
||||
# Set up backprop all-reduce.
|
||||
input_parallel = reduce_backward(input_, self.process_group)
|
||||
input_parallel = reduce_backward(input_, self.process_group, fp8_communication=self.fp8_communication)
|
||||
output_parallel = matmul_with_async_comm(
|
||||
input_parallel, self.weight, bias, self.process_group, self.async_communication
|
||||
input_parallel,
|
||||
self.weight,
|
||||
bias,
|
||||
self.process_group,
|
||||
self.async_communication,
|
||||
fp8_communication=self.fp8_communication,
|
||||
)
|
||||
elif self.seq_parallel_mode == "split_gather":
|
||||
input_parallel = input_
|
||||
output_parallel = matmul_gather_forward_reducescatter_backward(
|
||||
input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap
|
||||
input_parallel,
|
||||
self.weight,
|
||||
bias,
|
||||
self.process_group,
|
||||
True,
|
||||
1,
|
||||
self.overlap,
|
||||
fp8_communication=self.fp8_communication,
|
||||
)
|
||||
elif self.seq_parallel_mode == "ring":
|
||||
input_parallel = input_
|
||||
|
@ -331,7 +345,9 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
|||
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
|
||||
output = gather_forward_split_backward(
|
||||
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
else:
|
||||
output = output_parallel
|
||||
|
||||
|
@ -379,6 +395,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
|||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
stream_chunk_num: int = 1,
|
||||
fp8_communication: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -392,6 +409,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
|||
self.process_group = process_group
|
||||
self.seq_parallel_mode = seq_parallel_mode
|
||||
self.num_partitions = dist.get_world_size(self.process_group)
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
@ -514,7 +532,9 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
|||
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[0] * self.num_partitions
|
||||
)
|
||||
input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)
|
||||
input_ = split_forward_gather_backward(
|
||||
input_, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
|
||||
if self.stream_chunk_num > 1:
|
||||
if self.training:
|
||||
|
@ -535,13 +555,20 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
|||
else:
|
||||
if self.seq_parallel_mode is None:
|
||||
output_parallel = torch.matmul(input_, self.weight)
|
||||
output = reduce_forward(output_parallel, self.process_group)
|
||||
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
|
||||
elif self.seq_parallel_mode == "split_gather":
|
||||
output_parallel = torch.matmul(input_, self.weight)
|
||||
output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
|
||||
output = reducescatter_forward_gather_backward(
|
||||
output_parallel,
|
||||
self.process_group,
|
||||
1,
|
||||
self.fp8_communication,
|
||||
)
|
||||
elif self.seq_parallel_mode == "ring":
|
||||
output_parallel = torch.matmul(input_, self.weight)
|
||||
output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
|
||||
output = reducescatter_forward_gather_backward(
|
||||
output_parallel, self.process_group, 1, self.fp8_communication
|
||||
)
|
||||
|
||||
if not self.skip_bias_add:
|
||||
if self.bias is not None:
|
||||
|
@ -600,6 +627,7 @@ class FusedLinear1D_Col(ParallelModule):
|
|||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
fp8_communication: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
# Keep input parameters
|
||||
|
@ -611,6 +639,7 @@ class FusedLinear1D_Col(ParallelModule):
|
|||
self.n_fused = n_fused
|
||||
self.process_group = process_group
|
||||
self.async_communication = async_communication
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
@ -740,7 +769,9 @@ class FusedLinear1D_Col(ParallelModule):
|
|||
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
|
||||
output = gather_forward_split_backward(
|
||||
output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
|
||||
)
|
||||
else:
|
||||
output = output_parallel
|
||||
|
||||
|
|
|
@ -187,11 +187,17 @@ class BertPipelineForwards:
|
|||
if shard_config is not None and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode == "split_gather":
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states = split_forward_gather_backward(
|
||||
encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
encoder_hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx):
|
||||
|
@ -242,7 +248,10 @@ class BertPipelineForwards:
|
|||
if shard_config is not None and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode == "split_gather":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
if output_hidden_states:
|
||||
|
@ -1135,11 +1144,17 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
|||
# split the input tensor along sequence dimension
|
||||
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
||||
embedding_output = split_forward_gather_backward(
|
||||
embedding_output, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
embedding_output,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states = split_forward_gather_backward(
|
||||
encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
encoder_hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
|
@ -1159,7 +1174,10 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
|||
|
||||
# When sequence parallelism done, gather the output tensor in forward and split it in backward
|
||||
sequence_output = gather_forward_split_backward(
|
||||
sequence_output, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
sequence_output,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
|
|
@ -221,7 +221,10 @@ class BloomPipelineForwards:
|
|||
if shard_config and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode == "split_gather":
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
|
@ -264,7 +267,10 @@ class BloomPipelineForwards:
|
|||
if shard_config and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode == "split_gather":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
|
@ -922,7 +928,10 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
|||
# split the input tensor along sequence dimension
|
||||
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
|
@ -960,7 +969,10 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
|||
|
||||
# When sequence parallelism done, gather the output tensor in forward and split it in backward
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
# Add last hidden state
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
|
|
@ -206,6 +206,15 @@ class ChatGLMPipelineForwards:
|
|||
hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
elif shard_config.sequence_parallelism_mode == "all_to_all":
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
grad_scale=1 / shard_config.sequence_parallel_size,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
elif shard_config.sequence_parallelism_mode == "all_to_all":
|
||||
hidden_states = split_forward_gather_backward(
|
||||
|
@ -245,6 +254,15 @@ class ChatGLMPipelineForwards:
|
|||
hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, dim=0, process_group=shard_config.sequence_parallel_process_group, grad_scale=shard_config.sequence_parallel_size, fp8_communication=shard_config.fp8_communication)

@@ -401,6 +419,12 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
if sp_mode in ["all_to_all"] and self.training:
if use_cache:
logger.warning_once("`use_cache=True` is incompatible with sp mode `{sp_mode}`. Setting `use_cache=False`...")
use_cache = False

@@ -414,6 +438,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
inputs_embeds, dim=0, process_group=sp_group, fp8_communication=shard_config.fp8_communication)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(

@@ -421,6 +446,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
dim=0, process_group=sp_group, grad_scale=1 / sp_size, fp8_communication=shard_config.fp8_communication)
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
inputs_embeds,

@@ -436,6 +462,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group, fp8_communication=shard_config.fp8_communication)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(

@@ -443,6 +470,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
dim=0, process_group=sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication)
if not return_dict:

@@ -532,9 +560,24 @@ def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, s
key_layer = key_layer.reshape(sq, bs, -1)
value_layer = value_layer.reshape(sq, bs, -1)
query_layer = all_to_all_comm(query_layer, sp_group, gather_dim=0)
key_layer = all_to_all_comm(key_layer, sp_group, gather_dim=0)
value_layer = all_to_all_comm(value_layer, sp_group, gather_dim=0)
query_layer = all_to_all_comm(query_layer, sp_group, gather_dim=0, fp8_communication=shard_config.fp8_communication)
key_layer = all_to_all_comm(key_layer, sp_group, gather_dim=0, fp8_communication=shard_config.fp8_communication)
value_layer = all_to_all_comm(value_layer, sp_group, gather_dim=0, fp8_communication=shard_config.fp8_communication)
query_layer = query_layer.view(
sq * sp_size,

@@ -610,7 +653,13 @@ def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, s
context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
if sp_mode == "all_to_all":
context_layer = all_to_all_comm(context_layer, sp_group, gather_dim=2, scatter_dim=0)
context_layer = all_to_all_comm(context_layer, sp_group, gather_dim=2, scatter_dim=0, fp8_communication=shard_config.fp8_communication)
# =================
# Output. [sq, b, h]

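Editor's note: the pattern above repeats across every model touched by this commit — the sequence-parallel collectives now accept an `fp8_communication` flag read from `shard_config`, and when it is set the payload is moved in FP8 instead of the working precision. The snippet below is only a minimal, self-contained illustration of that idea using plain PyTorch casts; it is not the kernel path ColossalAI uses, the helper name `fp8_roundtrip` is invented for the example, and it assumes a PyTorch build (2.1+) that exposes `torch.float8_e4m3fn`.

import torch

def fp8_roundtrip(x: torch.Tensor) -> torch.Tensor:
    # Scale into the representable range of e4m3, cast down, then cast back up.
    # A real FP8 collective would ship the FP8 payload plus the scale over the wire.
    amax = x.abs().max().clamp(min=1e-12)
    scale = 448.0 / amax  # 448 is the largest normal value of float8_e4m3fn
    x_fp8 = (x * scale).to(torch.float8_e4m3fn)
    return x_fp8.to(torch.float32) / scale

if __name__ == "__main__":
    t = torch.randn(4, 8)
    print((t - fp8_roundtrip(t)).abs().max())  # quantization error; small but non-zero
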
@@ -142,6 +142,7 @@ class CommandPipelineForwards:
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group, fp8_communication=shard_config.fp8_communication)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = split_forward_gather_backward(

@@ -149,6 +150,7 @@ class CommandPipelineForwards:
dim=1, process_group=shard_config.sequence_parallel_process_group, grad_scale=1 / shard_config.sequence_parallel_size, fp8_communication=shard_config.fp8_communication)
# decoder layers

@@ -213,6 +215,7 @@ class CommandPipelineForwards:
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group, fp8_communication=shard_config.fp8_communication)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = gather_forward_split_backward(

@@ -220,6 +223,7 @@ class CommandPipelineForwards:
dim=1, process_group=shard_config.sequence_parallel_process_group, grad_scale=shard_config.sequence_parallel_size, fp8_communication=shard_config.fp8_communication)
# add hidden states from the last decoder layer

@@ -384,9 +388,9 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
# sp: all-to-all communication when introducing sequence parallel
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

@@ -448,7 +452,9 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
# sp: all-to-all communication when introducing sequence parallel
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

@@ -528,9 +534,13 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication)
hidden_states = inputs_embeds
# decoder layers

@@ -575,9 +585,13 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication)
# add hidden states from the last decoder layer
if output_hidden_states:

@@ -24,6 +24,7 @@ from colossalai.moe._operation import (
all_to_all_uneven,
)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization.fp8 import all_reduce_fp8
from colossalai.shardformer.layer._operation import (
all_to_all_comm,
gather_forward_split_backward,

@@ -61,7 +62,13 @@ class EPDeepseekMoE(nn.Module):
def __init__(self):
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, fp8_communication: bool = False):
assert tp_group is not None
assert moe_dp_group is not None
assert ep_group is not None

@@ -70,6 +77,7 @@ class EPDeepseekMoE(nn.Module):
self.ep_rank = dist.get_rank(ep_group)
self.num_experts = self.config.n_routed_experts
assert self.num_experts % self.ep_size == 0
self.fp8_communication = fp8_communication
self.ep_group = ep_group
self.num_experts_per_ep = self.num_experts // self.ep_size

@@ -86,9 +94,15 @@ class EPDeepseekMoE(nn.Module):
self.tp_group = tp_group
if self.tp_group.size() > 1:
for expert in held_experts:
expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.tp_group)
expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.tp_group)
expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.tp_group)
expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.tp_group, fp8_communication=self.fp8_communication)
expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.tp_group, fp8_communication=self.fp8_communication)
expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.tp_group, fp8_communication=self.fp8_communication)
for p in self.experts.parameters():
set_moe_tensor_ep_group(p, ep_group)

@@ -106,7 +120,8 @@ class EPDeepseekMoE(nn.Module):
if module.__class__.__name__ == "DeepseekMLP":
return module
module.__class__ = EPDeepseekMoE
module.setup_process_groups(tp_group, moe_dp_group, ep_group)
fp8_communication = kwargs.get("fp8_communication", False)
module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication=fp8_communication)
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

@@ -130,18 +145,32 @@ class EPDeepseekMoE(nn.Module):
output_split_sizes = torch.zeros_like(input_split_sizes)
# [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3]
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
with torch.no_grad():
activate_experts = output_split_sizes[: self.num_experts_per_ep].clone()
for i in range(1, self.ep_size):
activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep]
activate_experts = (activate_experts > 0).float()
dist.all_reduce(activate_experts, group=self.moe_dp_group)
if self.fp8_communication:
all_reduce_fp8(activate_experts, group=self.moe_dp_group)
else:
dist.all_reduce(activate_experts, group=self.moe_dp_group)
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group, fp8_communication=self.fp8_communication)
output_states = EPGradScalerIn.apply(output_states, self.ep_size)
if output_states.size(0) > 0:

@@ -167,7 +196,9 @@ class EPDeepseekMoE(nn.Module):
output_states_list.append(split_states)
output_states = torch.cat(output_states_list)
output_states = EPGradScalerOut.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group, fp8_communication=self.fp8_communication)
recover_token_idx = torch.empty_like(flat_topk_token_idx)
recover_token_idx[flat_topk_token_idx] = torch.arange(
flat_topk_token_idx.size(0), device=flat_topk_token_idx.device

@@ -534,9 +565,9 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
# sp: all-to-all communication when introducing sequence parallel
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim

@@ -595,7 +626,9 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
# sp: all-to-all communication when introducing sequence parallel
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()  # (1, 8, 128)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)  # (1, 4, 256)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication)  # (1, 4, 256)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

@@ -669,6 +702,7 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
# TODO: upgrade transformers to 4.44.0 to fix the bug, remove the hard code.
self._use_flash_attention_2 = shard_config.enable_flash_attention
self._use_sdpa = False if shard_config.enable_flash_attention else self._use_sdpa
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None

@@ -688,9 +722,13 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
)
if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication)
# embed positions
hidden_states = inputs_embeds

@@ -734,9 +772,13 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication)
# add hidden states from the last decoder layer
if output_hidden_states:

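Editor's note: in the MoE path above the commit adds two FP8 hooks — the expert-activation counter is reduced with `all_reduce_fp8` when the flag is on, and token dispatch/combine goes through `all_to_all_uneven(..., fp8_communication=...)`. The sketch below only restates that branch; it assumes a ColossalAI build containing this commit plus an initialized `torch.distributed` environment, and the wrapper name `reduce_activate_experts` is invented for illustration.

import torch
import torch.distributed as dist
from colossalai.quantization.fp8 import all_reduce_fp8  # import added by this commit

def reduce_activate_experts(activate_experts: torch.Tensor, moe_dp_group, fp8_communication: bool) -> None:
    # Same decision the patched EPDeepseekMoE.forward makes: FP8 all-reduce when
    # enabled, otherwise the regular full-precision collective.
    if fp8_communication:
        all_reduce_fp8(activate_experts, group=moe_dp_group)
    else:
        dist.all_reduce(activate_experts, group=moe_dp_group)
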
@@ -221,6 +221,7 @@ class GPT2PipelineForwards:
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group, fp8_communication=shard_config.fp8_communication)
# Going through held blocks.

@@ -276,6 +277,7 @@ class GPT2PipelineForwards:
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group, fp8_communication=shard_config.fp8_communication)
if stage_manager.is_last_stage():

@@ -1119,6 +1121,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
hidden_states, dim=1, process_group=shard_config.sequence_parallel_process_group, fp8_communication=shard_config.fp8_communication)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

@@ -1186,6 +1189,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
hidden_states, dim=1, process_group=shard_config.sequence_parallel_process_group, fp8_communication=shard_config.fp8_communication)
hidden_states = self.ln_f(hidden_states)

@@ -185,6 +185,7 @@ class GPTJPipelineForwards:
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group, fp8_communication=shard_config.fp8_communication)
# Going through held blocks.

@@ -236,6 +237,7 @@ class GPTJPipelineForwards:
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group, fp8_communication=shard_config.fp8_communication)
if stage_manager.is_last_stage():

@@ -915,6 +917,7 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group, fp8_communication=shard_config.fp8_communication)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

@@ -978,6 +981,7 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group, fp8_communication=shard_config.fp8_communication)
hidden_states = self.ln_f(hidden_states)

@@ -162,9 +162,13 @@ class LlamaPipelineForwards:
hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group)
elif is_share_sp_tp(sp_mode):
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication)
elif sp_mode == "all_to_all":
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication)
if self.gradient_checkpointing and self.training and use_cache:
if use_cache:

@@ -227,7 +231,9 @@ class LlamaPipelineForwards:
if stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states)
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode)
hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication)
# add hidden states from the last decoder layer
if output_hidden_states:

@@ -532,9 +538,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
# sp: all-to-all communication when introducing sequence parallel
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

@@ -605,7 +611,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
# sp: all-to-all communication when introducing sequence parallel
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

@@ -707,9 +715,13 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
attn_kwargs = {"attention_mask_type": attn_kwargs["attention_mask_type"]}  # drop redundant tensors
elif is_share_sp_tp(sp_mode):
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication)
hidden_states = inputs_embeds
# decoder layers

@@ -754,7 +766,9 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
hidden_states = self.norm(hidden_states)
# Cases that don't support parallelizing cross entropy computation along sequence
if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather:
hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode)
hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode, fp8_communication=shard_config.fp8_communication)
# add hidden states from the last decoder layer
if output_hidden_states:

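Editor's note: for every attention implementation in this commit the change is mechanical — the q/k/v all-to-alls and the output all-to-all simply forward `shard_config.fp8_communication`. A hedged sketch of that call pattern is below, written against the `all_to_all_comm` signature visible in these hunks; the function name `sp_all_to_all_qkv` is made up here, and actually running it requires an initialized sequence-parallel process group.

from colossalai.shardformer.layer._operation import all_to_all_comm

def sp_all_to_all_qkv(query_states, key_states, value_states, sp_group, shard_config):
    # Scatter heads / gather sequence, optionally carrying the payload in FP8.
    fp8 = shard_config.fp8_communication
    query_states = all_to_all_comm(query_states, sp_group, fp8_communication=fp8)
    key_states = all_to_all_comm(key_states, sp_group, fp8_communication=fp8)
    value_states = all_to_all_comm(value_states, sp_group, fp8_communication=fp8)
    return query_states, key_states, value_states
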
@@ -53,7 +53,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
def __init__(self, *args, **kwargs):
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, fp8_communication: bool = False):
assert tp_group is not None
assert moe_dp_group is not None
assert ep_group is not None

@@ -62,6 +68,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
self.ep_size = dist.get_world_size(ep_group)
self.ep_rank = dist.get_rank(ep_group)
self.ep_group = ep_group
self.fp8_communication = fp8_communication
if self.num_experts % self.ep_size != 0:
raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")

@@ -80,9 +87,15 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
self.tp_group = tp_group
if self.tp_group.size() > 1:
for expert in held_experts:
expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.tp_group)
expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.tp_group)
expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.tp_group)
expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.tp_group, fp8_communication=self.fp8_communication)
expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.tp_group, fp8_communication=self.fp8_communication)
expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.tp_group, fp8_communication=self.fp8_communication)
for p in self.experts.parameters():
set_moe_tensor_ep_group(p, ep_group)

@@ -99,7 +112,8 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
# TODO: better init
LazyInitContext.materialize(module)
module.__class__ = EPMixtralSparseMoeBlock
module.setup_process_groups(tp_group, moe_dp_group, ep_group)
fp8_communication = kwargs.get("fp8_communication", False)
module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication)
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

@@ -120,6 +134,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
output_split_sizes = torch.zeros_like(input_split_sizes)
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
with torch.no_grad():

@@ -132,7 +147,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group, fp8_communication=self.fp8_communication)
# compute expert output
output_states = EPGradScalerIn.apply(output_states, self.ep_size)
if output_states.size(0) > 0:

@@ -162,7 +183,9 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
output_states = torch.cat(output_states_list)
output_states = EPGradScalerOut.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group, fp8_communication=self.fp8_communication)
recover_experts_idx = torch.empty_like(selected_experts_idx)
recover_experts_idx[selected_experts_idx] = torch.arange(

@@ -566,9 +589,9 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
# sp: all-to-all communication when introducing sequence parallel
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

@@ -673,7 +696,9 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
# sp: all-to-all communication when introducing sequence parallel
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()  # (1, 8, 128)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)  # (1, 4, 256)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication)  # (1, 4, 256)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

@@ -780,9 +805,13 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
)
if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication)
hidden_states = inputs_embeds
# decoder layers

@@ -831,9 +860,13 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication)
# add hidden states from the last decoder layer
if output_hidden_states:

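Editor's note: when tensor parallelism is combined with expert parallelism, the expert linears above are rebuilt with `Linear1D_Col` / `Linear1D_Row.from_native_module(..., fp8_communication=...)`, so the TP gather/reduce inside each expert can also run in FP8. A minimal sketch of that wiring follows; the helper name `wrap_expert` is hypothetical, the import path is inferred from how the policies reference these classes, and it assumes a TP process group of size greater than one.

from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row

def wrap_expert(expert, tp_group, fp8_communication: bool):
    # Mirrors the Mixtral/DeepSeek expert sharding in this commit: column-parallel
    # gate/up projections, row-parallel down projection, FP8 flag threaded through.
    expert.w1 = Linear1D_Col.from_native_module(expert.w1, tp_group, fp8_communication=fp8_communication)
    expert.w3 = Linear1D_Col.from_native_module(expert.w3, tp_group, fp8_communication=fp8_communication)
    expert.w2 = Linear1D_Row.from_native_module(expert.w2, tp_group, fp8_communication=fp8_communication)
    return expert
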
@@ -175,6 +175,7 @@ class Qwen2PipelineForwards:
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group, fp8_communication=shard_config.fp8_communication)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = split_forward_gather_backward(

@@ -182,6 +183,7 @@ class Qwen2PipelineForwards:
dim=1, process_group=shard_config.sequence_parallel_process_group, grad_scale=1 / shard_config.sequence_parallel_size, fp8_communication=shard_config.fp8_communication)
# decoder layers

@@ -246,6 +248,7 @@ class Qwen2PipelineForwards:
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group, fp8_communication=shard_config.fp8_communication)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = gather_forward_split_backward(

@@ -253,6 +256,7 @@ class Qwen2PipelineForwards:
dim=1, process_group=shard_config.sequence_parallel_process_group, grad_scale=shard_config.sequence_parallel_size, fp8_communication=shard_config.fp8_communication)
# add hidden states from the last decoder layer
if output_hidden_states:

@@ -516,9 +520,9 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
value_states = self.v_proj(hidden_states)
# sp: all-to-all communication when introducing sequence parallel
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

@@ -604,7 +608,9 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
attn_output = attn_output.transpose(1, 2).contiguous()
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

@@ -702,9 +708,13 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
next_decoder_cache = None
if sp_mode in ["ring", "split_gather"]:
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication)
elif sp_mode == "all_to_all":
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication)
for decoder_layer in self.layers:
if output_hidden_states:

@@ -741,9 +751,13 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication)
# add hidden states from the last decoder layer
if output_hidden_states:

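Editor's note: all of the modeling-side changes above key off a single `ShardConfig.fp8_communication` field, and the policy diffs below read the same flag when they build submodule-replacement kwargs. How the flag is switched on from user code is not shown in these hunks, so the following is only an assumption-laden sketch: it presumes `fp8_communication` is a regular constructor argument of `ShardConfig` (consistent with how the attribute is consumed here) and that a ColossalAI build containing this commit is installed.

from colossalai.shardformer import ShardConfig, ShardFormer

# Assumed kwarg: fp8_communication mirrors the attribute used throughout the patched
# forwards; the other field is a standard ShardConfig option.
shard_config = ShardConfig(
    enable_tensor_parallelism=False,
    fp8_communication=True,
)
shard_former = ShardFormer(shard_config=shard_config)
# shard_former.optimize(model, policy) would then hand the flag to the policies below.
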
@@ -98,6 +98,7 @@ class BertPolicy(Policy):
kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap, "fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(

@@ -106,6 +107,7 @@ class BertPolicy(Policy):
kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap, "fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(

@@ -114,6 +116,7 @@ class BertPolicy(Policy):
kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap, "fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(

@@ -123,7 +126,10 @@ class BertPolicy(Policy):
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel_mode": sp_mode},
kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",

@@ -136,12 +142,16 @@ class BertPolicy(Policy):
"seq_parallel_mode": sp_mode, "overlap": overlap, "skip_bias_add": self.enable_bias_gelu_fused, "fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel_mode": sp_mode},
kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="output.dropout",

@@ -180,6 +190,13 @@ class BertPolicy(Policy):
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=embedding_cls,
kwargs=({"fp8_communication": self.shard_config.fp8_communication} if self.shard_config.enable_tensor_parallelism else {}),
)
],
policy=policy,

@@ -249,6 +266,7 @@ class BertPolicy(Policy):
kwargs={"gather_output": True, "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, "fp8_communication": self.shard_config.fp8_communication},
),
policy=base_policy,

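Editor's note: a recurring detail in these policy hunks is that embedding replacements only receive `fp8_communication` when tensor parallelism is enabled, since the non-TP embedding class does not take the flag. The tiny, dependency-free sketch below just restates that kwargs-selection logic; the name `build_embedding_kwargs` is illustrative and not part of ColossalAI.

def build_embedding_kwargs(shard_config) -> dict:
    # Same shape as the kwargs=( ... if ... else ... ) expressions in these policies.
    base = {"make_vocab_size_divisible_by": shard_config.make_vocab_size_divisible_by}
    if shard_config.enable_tensor_parallelism:
        return {**base, "fp8_communication": shard_config.fp8_communication}
    return base
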
@@ -72,20 +72,30 @@ class BlipPolicy(Policy):
target_module=col_nn.FusedLinear1D_Col,
kwargs={"n_fused": 3, "fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(suffix="self_attn.projection", target_module=col_nn.Linear1D_Row, kwargs={"fp8_communication": self.shard_config.fp8_communication}),
SubModuleReplacementDescription(
suffix="mlp.fc1",
target_module=col_nn.Linear1D_Col,
kwargs={"skip_bias_add": self.enable_bias_gelu_fused},
kwargs={"skip_bias_add": self.enable_bias_gelu_fused, "fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(suffix="mlp.fc2", target_module=col_nn.Linear1D_Row, kwargs={"fp8_communication": self.shard_config.fp8_communication}),
],
)

@@ -114,14 +124,23 @@ class BlipPolicy(Policy):
SubModuleReplacementDescription(suffix="attention.attention.query", target_module=col_nn.Linear1D_Col, kwargs={"fp8_communication": self.shard_config.fp8_communication}),
SubModuleReplacementDescription(suffix="attention.attention.key", target_module=col_nn.Linear1D_Col, kwargs={"fp8_communication": self.shard_config.fp8_communication}),
SubModuleReplacementDescription(suffix="attention.attention.value", target_module=col_nn.Linear1D_Col, kwargs={"fp8_communication": self.shard_config.fp8_communication}),
SubModuleReplacementDescription(
suffix="attention.attention.dropout",

@@ -130,6 +149,9 @@ class BlipPolicy(Policy):
SubModuleReplacementDescription(suffix="attention.output.dense", target_module=col_nn.Linear1D_Row, kwargs={"fp8_communication": self.shard_config.fp8_communication}),
SubModuleReplacementDescription(
suffix="attention.output.dropout",

@@ -138,14 +160,23 @@ class BlipPolicy(Policy):
SubModuleReplacementDescription(suffix="crossattention.attention.query", target_module=col_nn.Linear1D_Col, kwargs={"fp8_communication": self.shard_config.fp8_communication}),
SubModuleReplacementDescription(suffix="crossattention.attention.key", target_module=col_nn.Linear1D_Col, kwargs={"fp8_communication": self.shard_config.fp8_communication}),
SubModuleReplacementDescription(suffix="crossattention.attention.value", target_module=col_nn.Linear1D_Col, kwargs={"fp8_communication": self.shard_config.fp8_communication}),
SubModuleReplacementDescription(
suffix="crossattention.attention.dropout",

@@ -154,6 +185,9 @@ class BlipPolicy(Policy):
SubModuleReplacementDescription(suffix="crossattention.output.dense", target_module=col_nn.Linear1D_Row, kwargs={"fp8_communication": self.shard_config.fp8_communication}),
SubModuleReplacementDescription(
suffix="crossattention.output.dropout",

@@ -162,10 +196,16 @@ class BlipPolicy(Policy):
SubModuleReplacementDescription(suffix="intermediate_query.dense", target_module=col_nn.Linear1D_Col, kwargs={"fp8_communication": self.shard_config.fp8_communication}),
SubModuleReplacementDescription(suffix="output_query.dense", target_module=col_nn.Linear1D_Row, kwargs={"fp8_communication": self.shard_config.fp8_communication}),
SubModuleReplacementDescription(
suffix="output_query.dropout",

@@ -185,26 +225,44 @@ class BlipPolicy(Policy):
SubModuleReplacementDescription(suffix="self_attn.q_proj", target_module=col_nn.Linear1D_Col, kwargs={"fp8_communication": self.shard_config.fp8_communication}),
SubModuleReplacementDescription(suffix="self_attn.k_proj", target_module=col_nn.Linear1D_Col, kwargs={"fp8_communication": self.shard_config.fp8_communication}),
SubModuleReplacementDescription(suffix="self_attn.v_proj", target_module=col_nn.Linear1D_Col, kwargs={"fp8_communication": self.shard_config.fp8_communication}),
SubModuleReplacementDescription(suffix="self_attn.out_proj", target_module=col_nn.Linear1D_Row, kwargs={"fp8_communication": self.shard_config.fp8_communication}),
SubModuleReplacementDescription(suffix="fc1", target_module=col_nn.Linear1D_Col, kwargs={"fp8_communication": self.shard_config.fp8_communication}),
SubModuleReplacementDescription(suffix="fc2", target_module=col_nn.Linear1D_Row, kwargs={"fp8_communication": self.shard_config.fp8_communication}),
],
)

@@ -225,7 +283,14 @@ class BlipPolicy(Policy):
SubModuleReplacementDescription(
suffix="model.decoder.embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=({"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, "fp8_communication": self.shard_config.fp8_communication} if self.shard_config.enable_tensor_parallelism else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}),
),
],
policy=policy,

@@ -241,6 +306,7 @@ class BlipPolicy(Policy):
kwargs={"gather_output": True, "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, "fp8_communication": self.shard_config.fp8_communication},
),
],

@@ -76,12 +76,19 @@ class BloomPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap},
kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap, "fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel_mode": sp_mode},
kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",

@@ -90,12 +97,19 @@ class BloomPolicy(Policy):
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap},
kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap, "fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel_mode": sp_mode},
kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication},
),
],
)

@@ -115,7 +129,14 @@ class BloomPolicy(Policy):
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=({"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, "fp8_communication": self.shard_config.fp8_communication} if self.shard_config.enable_tensor_parallelism else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}),
),
],
policy=policy,

@@ -279,6 +300,7 @@ class BloomForCausalLMPolicy(BloomPolicy):
kwargs=dict(gather_output=not self.shard_config.parallel_output, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by, fp8_communication=self.shard_config.fp8_communication),
),
policy=policy,

@@ -337,7 +359,9 @@ class BloomForSequenceClassificationPolicy(BloomPolicy):
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
),
policy=policy,
target_key=BloomForSequenceClassification,

@@ -374,7 +398,9 @@ class BloomForTokenClassificationPolicy(BloomPolicy):
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="dropout",

@@ -128,12 +128,17 @@ class ChatGLMPolicy(Policy):
"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0, "overlap": overlap, "fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0},
kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0, "fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="self_attention.core_attention.attention_dropout",

@@ -148,7 +153,14 @@ class ChatGLMPolicy(Policy):
SubModuleReplacementDescription(
suffix="embedding.word_embeddings",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=({"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, "fp8_communication": self.shard_config.fp8_communication} if self.shard_config.enable_tensor_parallelism else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}),
),
],
policy=policy,

@@ -128,37 +128,37 @@ class CommandPolicy(Policy):
SubModuleReplacementDescription(suffix="self_attn.q_proj", target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(suffix="self_attn.k_proj", target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(suffix="self_attn.v_proj", target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(suffix="self_attn.o_proj", target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(suffix="mlp.gate_proj", target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(suffix="mlp.up_proj", target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(suffix="mlp.down_proj", target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
],
)

@@ -168,7 +168,14 @@ class CommandPolicy(Policy):
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=({"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, "fp8_communication": self.shard_config.fp8_communication} if self.shard_config.enable_tensor_parallelism else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}),
),
policy=policy,
target_key=CohereModel,

@@ -306,6 +313,7 @@ class CommandForCausalLMPolicy(CommandPolicy):
kwargs={"gather_output": not self.shard_config.parallel_output, "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, "fp8_communication": self.shard_config.fp8_communication},
)
],

@@ -118,18 +118,22 @@ class DeepseekPolicy(Policy):
SubModuleReplacementDescription(suffix="self_attn.q_proj", target_module=Linear1D_Col, kwargs={"fp8_communication": self.shard_config.fp8_communication}),
SubModuleReplacementDescription(suffix="self_attn.k_proj", target_module=Linear1D_Col, kwargs={"fp8_communication": self.shard_config.fp8_communication}),
SubModuleReplacementDescription(suffix="self_attn.v_proj", target_module=Linear1D_Col, kwargs={"fp8_communication": self.shard_config.fp8_communication}),
SubModuleReplacementDescription(suffix="self_attn.o_proj", target_module=Linear1D_Row, kwargs={"fp8_communication": self.shard_config.fp8_communication}),
],
)

@@ -138,7 +142,10 @@ class DeepseekPolicy(Policy):
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, "fp8_communication": self.shard_config.fp8_communication},
),
policy=policy,
target_key="DeepseekModel",

@@ -155,6 +162,7 @@ class DeepseekPolicy(Policy):
"ep_group": self.shard_config.ep_group,
"tp_group": self.shard_config.tensor_parallel_process_group,
"moe_dp_group": self.shard_config.moe_dp_group,
"fp8_communication": self.shard_config.fp8_communication,
},
)
],

@@ -298,14 +306,14 @@ class DeepseekForCausalLMPolicy(DeepseekPolicy):
policy = super().module_policy()
# TODO: assign pg mesh from plugin to all modules
if self.shard_config.enable_tensor_parallelism:
# add a new item for causal lm
new_item = {
"DeepseekForCausalLM": ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True),
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
)
]
)

@@ -105,7 +105,14 @@ class FalconPolicy(Policy):
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=({"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, "fp8_communication": self.shard_config.fp8_communication} if self.shard_config.enable_tensor_parallelism else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}),
),
],
policy=policy,

@@ -110,14 +110,13 @@ class GPT2Policy(Policy):
"n_fused": 3, "seq_parallel_mode": sp_mode, "overlap": overlap, "fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attn.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={"seq_parallel_mode": sp_mode},
kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="mlp.c_fc",

@@ -127,14 +126,13 @@ class GPT2Policy(Policy):
"seq_parallel_mode": sp_mode, "overlap": overlap, "skip_bias_add": self.enable_bias_gelu_fused, "fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={"seq_parallel_mode": sp_mode},
kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",

@@ -164,7 +162,14 @@ class GPT2Policy(Policy):
description=SubModuleReplacementDescription(
suffix="wte",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=({"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, "fp8_communication": self.shard_config.fp8_communication} if self.shard_config.enable_tensor_parallelism else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}),
),
policy=policy,
target_key=GPT2Model,

@@ -334,6 +339,7 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
kwargs={"gather_output": False, "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, "fp8_communication": self.shard_config.fp8_communication},
)
],

@@ -404,6 +410,7 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
kwargs={"gather_output": True, "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, "fp8_communication": self.shard_config.fp8_communication},
)
]

@@ -77,6 +77,7 @@ class GPTJPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={"overlap": overlap, "fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(

@@ -84,6 +85,7 @@ class GPTJPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={"overlap": overlap, "fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(

@@ -91,19 +93,29 @@ class GPTJPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={"overlap": overlap, "fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(suffix="attn.out_proj", target_module=col_nn.Linear1D_Row, kwargs={"fp8_communication": self.shard_config.fp8_communication}),
SubModuleReplacementDescription(suffix="mlp.fc_in", target_module=col_nn.Linear1D_Col, kwargs={"fp8_communication": self.shard_config.fp8_communication}),
SubModuleReplacementDescription(suffix="mlp.fc_out", target_module=col_nn.Linear1D_Row, kwargs={"fp8_communication": self.shard_config.fp8_communication}),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",

@@ -125,7 +137,14 @@ class GPTJPolicy(Policy):
description=SubModuleReplacementDescription(
suffix="wte",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=({"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, "fp8_communication": self.shard_config.fp8_communication} if self.shard_config.enable_tensor_parallelism else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}),
),
policy=policy,
target_key=GPTJModel,

@@ -264,6 +283,7 @@ class GPTJForCausalLMPolicy(GPTJPolicy):
kwargs={"gather_output": True, "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, "fp8_communication": self.shard_config.fp8_communication},
)
]

@ -133,37 +133,37 @@ class LlamaPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
target_module=Linear1D_Row,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.gate_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.up_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.down_proj",
|
||||
target_module=Linear1D_Row,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode),
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
@ -173,7 +173,14 @@ class LlamaPolicy(Policy):
|
|||
description=SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
kwargs=(
|
||||
{
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
}
|
||||
if self.shard_config.enable_tensor_parallelism
|
||||
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
|
||||
),
|
||||
),
|
||||
policy=policy,
|
||||
target_key=LlamaModel,
|
||||
|
@ -316,6 +323,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
|
|||
kwargs={
|
||||
"gather_output": not self.shard_config.parallel_output,
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
)
|
||||
],
|
||||
|
@ -384,7 +392,12 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
|
|||
LlamaForSequenceClassification: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
|
||||
suffix="score",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(
|
||||
gather_output=True,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
|
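A sketch (not part of this PR) of the repeated embedding-kwargs pattern in the policy hunks above; the helper name is invented for illustration. The flag is only forwarded when tensor parallelism is enabled, since the non-parallel padded embedding performs no collective communication:

def _embedding_kwargs(shard_config):
    # Illustrative helper, not present in the ColossalAI code base.
    kwargs = {"make_vocab_size_divisible_by": shard_config.make_vocab_size_divisible_by}
    if shard_config.enable_tensor_parallelism:
        # Only the parallel embedding communicates, so fp8 is only relevant here.
        kwargs["fp8_communication"] = shard_config.fp8_communication
    return kwargs
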
@@ -88,30 +88,51 @@ class MistralPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)
@@ -121,7 +142,14 @@ class MistralPolicy(Policy):
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=MistralModel,
@@ -281,6 +309,7 @@ class MistralForCausalLMPolicy(MistralPolicy):
kwargs={
"gather_output": not self.shard_config.parallel_output,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
)
]
@@ -297,7 +326,9 @@ class MistralForCausalLMPolicy(MistralPolicy):
SubModuleReplacementDescription(
suffix="lm_head",
target_module=PaddingLMHead,
kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
kwargs=dict(
make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
),
)
]
)
@@ -350,7 +381,9 @@ class MistralForSequenceClassificationPolicy(MistralPolicy):
MistralForSequenceClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
suffix="score",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
)
]
)

@@ -114,21 +114,27 @@ class MixtralPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(  # or replicate?
suffix="block_sparse_moe.gate", target_module=Linear1D_Col, kwargs={"gather_output": True}
suffix="block_sparse_moe.gate",
target_module=Linear1D_Col,
kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication},
),
],
)
@@ -138,7 +144,14 @@ class MixtralPolicy(Policy):
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=MixtralModel,
@@ -282,7 +295,7 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True),
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
)
]
)
@@ -336,7 +349,9 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
MixtralForSequenceClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
suffix="score",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
)
]
)

@@ -102,18 +102,30 @@ class OPTPolicy(Policy):
SubModuleReplacementDescription(
suffix="q_proj",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="k_proj",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="v_proj",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="out_proj",
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)
@@ -123,7 +135,14 @@ class OPTPolicy(Policy):
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=OPTDecoder,
@@ -272,6 +291,7 @@ class OPTForCausalLMPolicy(OPTPolicy):
kwargs=dict(
gather_output=not self.shard_config.parallel_output,
make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
fp8_communication=self.shard_config.fp8_communication,
),
),
policy=policy,

@@ -119,37 +119,37 @@ class Qwen2Policy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode),
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
],
)
@@ -159,7 +159,14 @@ class Qwen2Policy(Policy):
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=Qwen2Model,
@@ -313,11 +320,15 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
setattr(self.shard_config, "causal_lm", True)

if self.shard_config.enable_tensor_parallelism:
# add a new item for causal lm
# add a new item for casual lm
new_item = {
Qwen2ForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col)
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(fp8_communication=self.shard_config.fp8_communication),
)
],
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
)
@@ -366,7 +377,9 @@ class Qwen2ForSequenceClassificationPolicy(Qwen2Policy):
Qwen2ForSequenceClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
suffix="score",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
)
]
)

@@ -43,19 +43,29 @@ class SamPolicy(Policy):
target_module=col_nn.FusedLinear1D_Col,
kwargs={
"n_fused": 3,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attn.proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.lin1",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.lin2",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)
@@ -68,58 +78,100 @@ class SamPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_token_to_image.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_token_to_image.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_token_to_image.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_token_to_image.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.lin1",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.lin2",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_image_to_token.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_image_to_token.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_image_to_token.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_image_to_token.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)
@@ -132,18 +184,30 @@ class SamPolicy(Policy):
SubModuleReplacementDescription(
suffix="final_attn_token_to_image.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="final_attn_token_to_image.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="final_attn_token_to_image.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="final_attn_token_to_image.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)

@@ -117,23 +117,38 @@ class T5BasePolicy(Policy):
SubModuleReplacementDescription(
suffix="q",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="k",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="v",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="o",
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="relative_attention_bias",
target_module=Embedding1D,
kwargs=dict(gather_output=False),
kwargs=dict(
gather_output=False,
fp8_communication=self.shard_config.fp8_communication,
),
ignore_if_not_exist=True,
),
],
@@ -151,13 +166,24 @@ class T5BasePolicy(Policy):
SubModuleReplacementDescription(
suffix="wi_0",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="wi_1",
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
suffix="wo",
target_module=Linear1D_Col,
kwargs=dict(
gather_output=True,
fp8_communication=self.shard_config.fp8_communication,
),
),
SubModuleReplacementDescription(
suffix="dropout",
@@ -170,10 +196,16 @@ class T5BasePolicy(Policy):
SubModuleReplacementDescription(
suffix="wi",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="wo",
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="dropout",
@@ -187,7 +219,14 @@ class T5BasePolicy(Policy):
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=T5Stack,
@@ -407,7 +446,14 @@ class T5ModelPolicy(T5BasePolicy):
description=SubModuleReplacementDescription(
suffix="shared",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=T5Model,
@@ -451,7 +497,14 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
description=SubModuleReplacementDescription(
suffix="shared",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=T5ForConditionalGeneration,
@@ -465,6 +518,7 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
),
policy=policy,
@@ -539,7 +593,14 @@ class T5EncoderPolicy(T5BasePolicy):
description=SubModuleReplacementDescription(
suffix="shared",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=T5EncoderModel,

@@ -70,14 +70,23 @@ class ViTPolicy(Policy):
SubModuleReplacementDescription(
suffix="attention.attention.query",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attention.attention.key",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attention.attention.value",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attention.attention.dropout",
@@ -86,6 +95,9 @@ class ViTPolicy(Policy):
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
@@ -96,11 +108,15 @@ class ViTPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="output.dropout",
@@ -215,7 +231,9 @@ class ViTForImageClassificationPolicy(ViTPolicy):
ViTForImageClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="classifier", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
suffix="classifier",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
)
]
)

@@ -91,26 +91,44 @@ class WhisperPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="fc1",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="fc2",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)
@@ -128,42 +146,72 @@ class WhisperPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="self_attn.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="encoder_attn.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="encoder_attn.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="encoder_attn.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="encoder_attn.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="fc1",
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="fc2",
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
],
)
@@ -174,7 +222,14 @@ class WhisperPolicy(Policy):
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
],
policy=policy,
@@ -303,6 +358,7 @@ class WhisperPolicy(Policy):
kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
),
policy=base_policy,

@@ -29,6 +29,7 @@ class ShardConfig:
enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlaps the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism. Defaults to False.
parallel_output (bool): For TP: whether to parallelize the cross entropy computation along the feature dim.
For SP: set to True to NOT gather the output along the seq dim.
"""
@@ -54,6 +55,7 @@ class ShardConfig:
# for moe related
moe_dp_group: Optional[ProcessGroup] = None
ep_group: Optional[ProcessGroup] = None
fp8_communication: bool = False
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']

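A minimal sketch of how the new field is consumed: the plugins build a ShardConfig and every policy above reads self.shard_config.fp8_communication, so fp8 model-parallel communication is toggled by a single flag. Constructor arguments other than fp8_communication are assumptions here (a tensor-parallel process group is normally required), so treat this as illustration rather than a verified recipe:

import torch.distributed as dist
from colossalai.shardformer import ShardConfig, ShardFormer

# Assumes torch.distributed is already initialized and the default group is reused as the TP group.
shard_config = ShardConfig(
    tensor_parallel_process_group=dist.group.WORLD,  # assumption for illustration
    enable_tensor_parallelism=True,
    fp8_communication=True,  # new flag introduced by this PR
)
shard_former = ShardFormer(shard_config=shard_config)
# model, _ = shard_former.optimize(model, policy)  # the policy then forwards the flag to each replaced layer
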
@@ -61,6 +61,8 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
with torch._C.DisableTorchFunction():
new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
args, kwargs = replace_args(args, kwargs, new_args)
with torch._C.DisableTorchFunction():
func = ColoParamOpHookManager.rewrite_op(func)
ret = super().__torch_function__(func, types, args, kwargs)
with torch._C.DisableTorchFunction():
ret = ColoParamOpHookManager.post_op(params, ret)

@@ -30,6 +30,9 @@ class ColoParamOpHook(ABC):
def post_backward(self, params: List[torch.Tensor]) -> None:
pass

def rewrite_op(self, func) -> Any:
return func


class ColoParamOpHookManager:
"""
@@ -101,6 +104,12 @@ class ColoParamOpHookManager:
def has_hook() -> bool:
return len(ColoParamOpHookManager.hooks) > 0

@staticmethod
def rewrite_op(func) -> Any:
for hook in ColoParamOpHookManager.hooks:
func = hook.rewrite_op(func)
return func


class PreFwdPostBwd(torch.autograd.Function):
@staticmethod

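The new rewrite_op extension point lets a hook swap the torch function that is about to run on ColoParameters, which is how FP8Hook can reroute ops through fp8 kernels. Below is a hedged sketch of a custom hook; the logging behaviour is invented for illustration, and the module path colossalai.tensor.param_op_hook is assumed from the surrounding diff:

from colossalai.tensor.param_op_hook import ColoParamOpHook, ColoParamOpHookManager

class LoggingHook(ColoParamOpHook):
    """Illustrative hook: leaves parameters untouched and wraps every op to print its name."""

    def pre_forward(self, params): pass
    def post_forward(self, params): pass
    def pre_backward(self, params): pass
    def post_backward(self, params): pass

    def rewrite_op(self, func):
        def wrapper(*args, **kwargs):
            print(f"running {getattr(func, '__name__', func)} on ColoParameters")
            return func(*args, **kwargs)
        return wrapper

# usage: with ColoParamOpHookManager.use_hooks(LoggingHook()): outputs = model(inputs)
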
@@ -166,6 +166,7 @@ class Chunk:
self.grad_chunk = None
# the async all-reduce/reduce-scatter work of this grad chunk (None means sync)
self.grad_reduce_work = None
self.fp8_communication = False

@property
def memory_usage(self) -> Dict[str, int]:
@@ -521,9 +522,17 @@ class Chunk:

alloc_storage(self.cuda_global_chunk)
assert self.cuda_global_chunk.is_contiguous()
work = dist.all_gather_into_tensor(
self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op
)
if self.fp8_communication:
assert async_op == False, "fp8 all-gather does not support async_op!"
from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8

work = all_gather_into_tensor_flat_fp8(
self.cuda_global_chunk, self.cuda_shard, self.cuda_global_chunk.shape, self.torch_pg
)
else:
work = dist.all_gather_into_tensor(
self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op
)

self.cuda_shard = None
self.is_gathered = True

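A condensed sketch of the branch added above, mainly to highlight that the fp8 all-gather path is synchronous only; the wrapper name and signature around the two calls are illustrative, while the two collective calls themselves mirror the diff:

import torch.distributed as dist
from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8

def gather_shard(full_buffer, shard, group, use_fp8: bool, async_op: bool = False):
    # Hedged sketch mirroring Chunk's gather logic above; `gather_shard` is not a real API.
    if use_fp8:
        assert not async_op, "fp8 all-gather does not support async_op"
        return all_gather_into_tensor_flat_fp8(full_buffer, shard, full_buffer.shape, group)
    return dist.all_gather_into_tensor(full_buffer, shard, group, async_op=async_op)
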
@@ -26,6 +26,7 @@ class ChunkManager:
init_device: Optional[torch.device] = None,
reuse_fp16_chunk: bool = True,
max_prefetch: int = 0,
fp8_communication: bool = False,
) -> None:
self.device = init_device or get_accelerator().get_current_device()
self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
@@ -44,6 +45,7 @@ class ChunkManager:
self.accumulating_grads = False
self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())
self._prefetch_stream = get_accelerator().Stream() if max_prefetch else None
self.fp8_communication = fp8_communication

def register_tensor(
self,
@@ -101,6 +103,8 @@ class ChunkManager:
extra_dp_group=extra_dp_group,
**chunk_kwargs,
)
if self.fp8_communication:
chunk.fp8_communication = True

chunk_group.append(chunk)
chunk.append_tensor(tensor)

@@ -15,6 +15,7 @@ from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_
from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyTensor
from colossalai.logging import get_dist_logger
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.d_tensor import (
distribute_tensor,
@@ -98,6 +99,8 @@ class GeminiDDP(ModelWrapper):
extra_dp_group: Optional[ProcessGroup] = None,
verbose: bool = False,
enable_async_reduce: bool = True,
fp8_communication: bool = False,
use_fp8: bool = False,
) -> None:
assert mixed_precision in (torch.float16, torch.bfloat16)
reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
@@ -122,6 +125,8 @@ class GeminiDDP(ModelWrapper):
verbose=verbose,
max_prefetch=max_prefetch,
)
if fp8_communication:
self.chunk_manager.fp8_communication = True
self.gemini_manager = GeminiManager(
placement_policy,
self.chunk_manager,
@@ -135,6 +140,9 @@ class GeminiDDP(ModelWrapper):
)
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
self.hooks = [self.param_op_hook]
if use_fp8:
self.hooks.append(FP8Hook())
self.fp32_params: List[torch.Tensor] = list()
self.fp16_params: List[ColoParameter] = list()
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
@@ -307,7 +315,7 @@ class GeminiDDP(ModelWrapper):
outputs = self._inference_forward(*args, **kwargs)
else:
self.gemini_manager.pre_iter(*args)
with ColoParamOpHookManager.use_hooks(self.param_op_hook):
with ColoParamOpHookManager.use_hooks(*self.hooks):
outputs = self.module(*args, **kwargs)

if self.force_outputs_fp32:
@@ -316,7 +324,7 @@ class GeminiDDP(ModelWrapper):

def _inference_forward(self, *args, **kwargs):
"""This function is only triggered for inference."""
fwd_ctx = ColoParamOpHookManager.use_hooks(self.param_op_hook)
fwd_ctx = ColoParamOpHookManager.use_hooks(*self.hooks)
if not self.scatter_after_inference:
# gather all chunks
for chunk in self.chunk_manager.get_chunks(self.fp16_params):
@@ -369,7 +377,7 @@ class GeminiDDP(ModelWrapper):

def backward(self, loss: torch.Tensor):
self._pre_backward()
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(*self.hooks):
loss.backward()
self._post_backward()

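At the user level the two Gemini knobs are independent: use_fp8 installs FP8Hook (fp8 compute for linear ops), while fp8_communication quantizes the chunk all-gather traffic. A hedged sketch of enabling both through the plugin; arguments other than the two fp8 flags are placeholders:

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

plugin = GeminiPlugin(
    precision="bf16",        # assumption: any supported precision can be combined with fp8
    use_fp8=True,            # adds FP8Hook to GeminiDDP's op hooks
    fp8_communication=True,  # marks every chunk for fp8 all-gather
)
booster = Booster(plugin=plugin)
# model, optimizer, *_ = booster.boost(model, optimizer)
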
@@ -4,6 +4,8 @@ import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8


class TensorBucket:
def __init__(self, size):
@@ -61,11 +63,14 @@ class TensorBucket:
for old, new in zip(self._bucket, unflattened_tensor_list):
old.copy_(new)

def all_gather(self, group=None):
def all_gather(self, group=None, fp8_communication: bool = False):
flat = self.flatten()
buffers = [torch.empty_like(flat) for _ in range(dist.get_world_size(group))]
dist.all_gather(buffers, flat, group=group)
unflat_buffers = [self.unflatten(buffer) for buffer in buffers]
buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype)
if fp8_communication:
all_gather_into_tensor_flat_fp8(buffer, flat, output_shape=buffer.shape, group=group)
else:
dist.all_gather_into_tensor(buffer, flat, group=group)
unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(dist.get_world_size(group))]
# transpose the list of list
unflat_buffers = list(map(list, zip(*unflat_buffers)))
for unflat_shards, tensor in zip(unflat_buffers, self._bucket):

@@ -20,6 +20,7 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
)
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8, all_reduce_fp8, reduce_scatter_fp8
from colossalai.tensor.moe_tensor.api import is_moe_tensor

from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
@@ -86,6 +87,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
forced_dtype: Optional[torch.dtype] = None,
master_weights: bool = True,  # master weights
overlap_allgather: bool = False,
fp8_communication: bool = False,
):
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)

@@ -127,6 +129,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self._overlap_allgather = overlap_allgather
self._reduce_bucket_size = reduce_bucket_size
self._communication_dtype = communication_dtype
self._fp8_communication = fp8_communication

# gradient clipping
self._clip_grad_norm = clip_grad_norm
@@ -330,7 +333,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
flat_grads = flat_grads.to(self._communication_dtype)

if not self._partition_grads:
dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
if self._fp8_communication:
all_reduce_fp8(flat_grads, group=bucket_store.torch_pg)
else:
dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
if flat_grads.dtype != grad_dtype:
flat_grads = flat_grads.to(grad_dtype)

@@ -340,7 +346,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
else:
flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size))
received_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg)
if self._fp8_communication:
reduce_scatter_fp8(
received_grad,
flat_grads_list,
group=bucket_store.torch_pg,
)
else:
dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg)

if received_grad.dtype != grad_dtype:
received_grad = received_grad.to(grad_dtype)
@@ -562,18 +575,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
set_all_gather_handle(working_param, handle)
else:
if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
if self._fp8_communication:
all_gather_into_tensor_flat_fp8(
padded_working_param, param_to_gather, pg, fp8_format="e4m3"
)
else:
dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
continue
try:
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
except RuntimeError:
self.pg_to_tensor_bucket[pg].all_gather(pg)
self.pg_to_tensor_bucket[pg].all_gather(pg, fp8_communication=self._fp8_communication)
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
if not self._overlap_allgather:
for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
if not tensor_bucket.is_empty():
tensor_bucket.all_gather(pg)
tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)

def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""

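A minimal sketch of turning the flag on at the optimizer level; once fp8_communication=True, the gradient all-reduce/reduce-scatter and the parameter all-gather above all go through the fp8 collectives. Only the flag and the positional optimizer argument are confirmed by this diff; everything else keeps its defaults and the surrounding model/optimizer are assumed to exist:

import torch
from colossalai.zero import LowLevelZeroOptimizer

# `model` is assumed to be a torch.nn.Module created under an initialized distributed context.
optim = torch.optim.AdamW(model.parameters(), lr=1e-4)
zero_optim = LowLevelZeroOptimizer(
    optim,
    fp8_communication=True,  # new flag: fp8 grad reduction and fp8 param all-gather
)
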
@@ -179,7 +179,7 @@ def main():
"--plugin",
type=str,
default="torch_ddp",
choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel"],
choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel", "torch_fsdp"],
help="plugin to use",
)
parser.add_argument(
@@ -190,6 +190,7 @@ def main():
)
parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context")
parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication")
args = parser.parse_args()

if args.model_type == "bert":
@@ -214,9 +215,9 @@ def main():
if args.plugin == "torch_ddp_fp16":
booster_kwargs["mixed_precision"] = "fp16"
if args.plugin.startswith("torch_ddp"):
plugin = TorchDDPPlugin()
plugin = TorchDDPPlugin(fp8_communication=args.use_fp8_comm)
elif args.plugin == "gemini":
plugin = GeminiPlugin(initial_scale=2**5)
plugin = GeminiPlugin(initial_scale=2**5, fp8_communication=args.use_fp8_comm)
elif args.plugin == "low_level_zero":
plugin = LowLevelZeroPlugin(initial_scale=2**5)
elif args.plugin == "hybrid_parallel":
@@ -232,6 +233,18 @@ def main():
zero_stage=1,
precision="fp16",
initial_scale=1,
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "torch_fsdp":
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision

from colossalai.booster.plugin import TorchFSDPPlugin

plugin = TorchFSDPPlugin(
    mixed_precision=MixedPrecision(
        param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
    ),
    fp8_communication=args.use_fp8_comm,
)

booster = Booster(plugin=plugin, **booster_kwargs)

|
|||
help="only gpt2 now",
|
||||
)
|
||||
parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
|
||||
parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context")
|
||||
parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model_type == "gpt2":
|
||||
|
@ -210,7 +212,7 @@ def main():
|
|||
if args.plugin == "torch_ddp_fp16":
|
||||
booster_kwargs["mixed_precision"] = "fp16"
|
||||
if args.plugin.startswith("torch_ddp"):
|
||||
plugin = TorchDDPPlugin()
|
||||
plugin = TorchDDPPlugin(fp8_communication=args.use_fp8_comm)
|
||||
elif args.plugin == "gemini":
|
||||
plugin = GeminiPlugin(initial_scale=2**5)
|
||||
elif args.plugin == "low_level_zero":
|
||||
|
@ -226,6 +228,7 @@ def main():
|
|||
zero_stage=1,
|
||||
precision="fp16",
|
||||
initial_scale=1,
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
)
|
||||
|
||||
booster = Booster(plugin=plugin, **booster_kwargs)
|
||||
|
|
|
@ -104,6 +104,8 @@ def main():
|
|||
parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
|
||||
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
|
||||
parser.add_argument("--no_cache", action="store_true")
|
||||
parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
|
||||
parser.add_argument("--use_fp8", action="store_true")
|
||||
parser.add_argument("--overlap_allgather", action="store_true")
|
||||
parser.add_argument(
|
||||
"--sp_mode",
|
||||
|
@ -148,6 +150,7 @@ def main():
|
|||
enable_flash_attention=args.xformers,
|
||||
max_prefetch=args.prefetch_num,
|
||||
enable_async_reduce=not args.disable_async_reduce,
|
||||
use_fp8=args.use_fp8,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
|
@ -160,6 +163,7 @@ def main():
|
|||
max_prefetch=args.prefetch_num,
|
||||
enable_async_reduce=not args.disable_async_reduce,
|
||||
enable_flash_attention=args.xformers,
|
||||
use_fp8=args.use_fp8,
|
||||
)
|
||||
elif args.plugin == "fsdp":
|
||||
if use_empty_init:
|
||||
|
@ -170,6 +174,7 @@ def main():
|
|||
buffer_dtype=torch.float16,
|
||||
),
|
||||
param_init_fn=empty_init(),
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
)
|
||||
else:
|
||||
plugin = TorchFSDPPlugin(
|
||||
|
@ -177,7 +182,8 @@ def main():
|
|||
param_dtype=torch.float16,
|
||||
reduce_dtype=torch.float16,
|
||||
buffer_dtype=torch.float16,
|
||||
)
|
||||
),
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
)
|
||||
elif args.plugin == "fsdp_cpu":
|
||||
if use_empty_init:
|
||||
|
@ -189,6 +195,7 @@ def main():
|
|||
),
|
||||
cpu_offload=CPUOffload(offload_params=True),
|
||||
param_init_fn=empty_init(),
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
)
|
||||
else:
|
||||
plugin = TorchFSDPPlugin(
|
||||
|
@ -198,6 +205,7 @@ def main():
|
|||
buffer_dtype=torch.float16,
|
||||
),
|
||||
cpu_offload=CPUOffload(offload_params=True),
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
)
|
||||
elif args.plugin == "3d":
|
||||
plugin = HybridParallelPlugin(
|
||||
|
@ -215,6 +223,7 @@ def main():
|
|||
precision="bf16",
|
||||
enable_metadata_cache=not args.no_cache,
|
||||
overlap_allgather=args.overlap_allgather,
|
||||
use_fp8=args.use_fp8,
|
||||
**hybrid_kwargs,
|
||||
)
|
||||
elif args.plugin == "3d_cpu":
|
||||
|
@ -230,6 +239,8 @@ def main():
|
|||
microbatch_size=args.mbs,
|
||||
initial_scale=2**8,
|
||||
precision="bf16",
|
||||
overlap_p2p=args.overlap,
|
||||
use_fp8=args.use_fp8,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
|
@ -259,7 +270,6 @@ def main():
|
|||
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
|
||||
else nullcontext()
|
||||
)
|
||||
|
||||
init_kwargs = {}
|
||||
if config.model_type == "chatglm":
|
||||
init_kwargs["empty_init"] = False
|
||||
|
|
|
@ -9,7 +9,7 @@ torchx-nightly==2022.6.29 # torchrec 0.2.0 requires torchx-nightly. This package
|
|||
torchrec==0.2.0
|
||||
contexttimer
|
||||
einops
|
||||
triton==2.1.0
|
||||
triton
|
||||
requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611
|
||||
SentencePiece
|
||||
ninja
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
from torch.testing import assert_close
|
||||
|
||||
from colossalai import launch
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.quantization.fp8 import all_to_all_single_fp8
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
@parameterize("shape", [(4,), (1, 8, 16), (4, 8, 16)])
|
||||
@parameterize("dtype", [torch.bfloat16, torch.float16])
|
||||
@parameterize("async_op", [True, False])
|
||||
def check_all2all(shape, dtype, async_op):
|
||||
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
|
||||
output = torch.empty_like(x)
|
||||
output_fp8 = torch.empty_like(x)
|
||||
origin_hanle = dist.all_to_all_single(output, x, group=_get_default_group(), async_op=async_op)
|
||||
fp8_handle = all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), async_op=async_op)
|
||||
if async_op:
|
||||
origin_hanle.wait()
|
||||
fp8_handle.wait()
|
||||
assert_close(output, output_fp8, rtol=0.1, atol=0.1)
|
||||
|
||||
|
||||
@parameterize("shape", [(8, 8, 16)])
|
||||
@parameterize("dtype", [torch.bfloat16, torch.float16])
|
||||
@parameterize("async_op", [True, False])
|
||||
def check_all2all_uneven(shape, dtype, async_op):
|
||||
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
|
||||
input_split_sizes = [3, 3, 1, 1]
|
||||
if dist.get_rank() in [0, 1]:
|
||||
output_split_sizes = [3, 3, 3, 3]
|
||||
else:
|
||||
output_split_sizes = [1, 1, 1, 1]
|
||||
output_shape = list(shape)
|
||||
output_shape[0] = sum(output_split_sizes)
|
||||
output = torch.empty(output_shape, device=x.device, dtype=x.dtype)
|
||||
output_fp8 = torch.empty(output_shape, device=x.device, dtype=x.dtype)
|
||||
origin_hanle = dist.all_to_all_single(
|
||||
output,
|
||||
x,
|
||||
output_split_sizes=output_split_sizes,
|
||||
input_split_sizes=input_split_sizes,
|
||||
group=_get_default_group(),
|
||||
async_op=async_op,
|
||||
)
|
||||
fp8_handle = all_to_all_single_fp8(
|
||||
output_fp8,
|
||||
x,
|
||||
output_split_sizes=output_split_sizes,
|
||||
input_split_sizes=input_split_sizes,
|
||||
group=_get_default_group(),
|
||||
async_op=async_op,
|
||||
)
|
||||
if async_op:
|
||||
origin_hanle.wait()
|
||||
fp8_handle.wait()
|
||||
assert_close(output, output_fp8, rtol=0.1, atol=0.1)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
launch(rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
check_all2all()
|
||||
check_all2all_uneven()
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_all_to_all_single():
|
||||
spawn(run_dist, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_all_to_all_single()
|
|
@@ -0,0 +1,39 @@
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing import assert_close

from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_to_all_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


@parameterize("shape", [(16, 8, 4)])
@parameterize("scatter_dim", [0, 1, 2])
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("fp8_format", ["e4m3", "e5m2"])
def check_4gpu(shape, scatter_dim, dtype, fp8_format):
    world_size = dist.get_world_size()
    input_tensor = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    input_tensor_list = list(torch.chunk(input_tensor, world_size, scatter_dim))
    input_tensor_list = [x.contiguous() for x in input_tensor_list]
    output_tensor_list_fp8 = [torch.empty_like(x) for x in input_tensor_list]
    output_tensor_list = [torch.empty_like(x) for x in input_tensor_list]
    all_to_all_fp8(output_tensor_list_fp8, input_tensor_list, group=_get_default_group(), fp8_format=fp8_format)
    dist.all_to_all(output_tensor_list, input_tensor_list, group=_get_default_group())
    assert_close(output_tensor_list_fp8, output_tensor_list, rtol=0.1, atol=0.1)


def run_dist(rank, world_size, port):
    launch(rank=rank, world_size=world_size, port=port, host="localhost")
    check_4gpu()


@rerun_if_address_is_in_use()
def test_all_to_all():
    spawn(run_dist, 4)


if __name__ == "__main__":
    test_all_to_all()

@ -0,0 +1,37 @@
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing import assert_close

from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_to_all_single_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


@parameterize("shape", [(4,), (8, 7), (4, 8, 16)])
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("fp8_format", ["e4m3", "e5m2"])
def check_4gpu(shape, dtype, fp8_format):
    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    output = torch.empty_like(x)
    output_fp8 = torch.empty_like(x)
    all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), fp8_format=fp8_format)
    dist.all_to_all_single(output, x, group=_get_default_group())
    assert_close(output, output_fp8, rtol=0.1, atol=0.1)


def run_dist(rank, world_size, port):
    launch(rank=rank, world_size=world_size, port=port, host="localhost")
    check_4gpu()


@rerun_if_address_is_in_use()
def test_all_to_all_single():
    spawn(run_dist, 4)


if __name__ == "__main__":
    test_all_to_all_single()

@ -0,0 +1,43 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing import assert_close

from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


@parameterize("shape", [(3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)])
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("async_op", [True, False])
def check_4gpu(shape, dtype, async_op):
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    flat_padded_x = x.view(-1)
    if flat_padded_x.size(0) % world_size != 0:
        pad_size = world_size - flat_padded_x.size(0) % world_size
        flat_padded_x = F.pad(flat_padded_x, (0, pad_size))
    output = torch.empty_like(flat_padded_x)
    chunk = flat_padded_x.chunk(world_size)[rank].clone()
    handle = all_gather_into_tensor_flat_fp8(output, chunk, x.shape, group=_get_default_group(), async_op=async_op)
    if async_op:
        handle.wait()
    assert_close(output[: x.numel()], x.view(-1), rtol=0.1, atol=0.1)


def run_dist(rank, world_size, port):
    launch(rank=rank, world_size=world_size, port=port, host="localhost")
    check_4gpu()


@rerun_if_address_is_in_use()
def test_all_gather_flat():
    spawn(run_dist, 4)


if __name__ == "__main__":
    test_all_gather_flat()

@ -0,0 +1,55 @@
import torch
import torch.distributed as dist
from torch.testing import assert_close

from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_reduce_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


@parameterize(
    "shape",
    [
        (3, 7),
        (4, 7),
        (7, 4),
        (8, 9),
        (3,),
        (7,),
        (8,),
    ],
)
@parameterize("dtype", [torch.float16, torch.bfloat16])
@parameterize("fp8_format", ["e4m3", "e5m2"])
@parameterize("async_op", [True, False])
def check_4gpu(shape, dtype, fp8_format, async_op):
    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    x_fp8 = x.clone()
    origin_handle = dist.all_reduce(x, async_op=async_op)
    fp8_handle = all_reduce_fp8(x_fp8, fp8_format=fp8_format, async_op=async_op)
    if async_op:
        origin_handle.wait()
        fp8_handle.wait()
    assert_close(x, x_fp8, rtol=0.1, atol=0.1)

    origin_handle = dist.all_reduce(x, op=dist.ReduceOp.AVG, async_op=async_op)
    fp8_handle = all_reduce_fp8(x_fp8, op=dist.ReduceOp.AVG, fp8_format=fp8_format, async_op=async_op)
    if async_op:
        origin_handle.wait()
        fp8_handle.wait()
    assert_close(x, x_fp8, rtol=0.1, atol=0.1)


def run_dist(rank, world_size, port):
    launch(rank=rank, world_size=world_size, port=port, host="localhost")
    check_4gpu()


@rerun_if_address_is_in_use()
def test_all_reduce():
    spawn(run_dist, 4)


if __name__ == "__main__":
    test_all_reduce()

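The test above exercises all_reduce_fp8 as a drop-in replacement for dist.all_reduce. As an illustration only, a hypothetical training-loop fragment that averages gradients manually could look like the sketch below (the helper name is ours, not part of the library; a process group is assumed to be initialized):

import torch.distributed as dist
import torch.nn as nn

from colossalai.quantization.fp8 import all_reduce_fp8


def average_gradients_fp8(model: nn.Module) -> None:
    # Hypothetical helper: all-reduce every gradient in FP8 (e4m3) and average in place,
    # mirroring the call signature exercised by check_4gpu above.
    for p in model.parameters():
        if p.grad is not None:
            all_reduce_fp8(p.grad, op=dist.ReduceOp.AVG, fp8_format="e4m3")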
@ -0,0 +1,26 @@
import torch
from torch.testing import assert_close

from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import cast_from_fp8, cast_from_fp8_pipeline, cast_to_fp8, cast_to_fp8_pipeline
from colossalai.testing import parameterize


@parameterize("shape", [(100, 10), (10, 100), (3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)])
@parameterize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@parameterize("fp8_format", ["e4m3", "e5m2"])
def test_fp8_cast(shape, dtype, fp8_format):
    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    ret, scale_inv = cast_to_fp8(x, fp8_format=fp8_format)
    out = cast_from_fp8(ret, scale_inv, x.dtype)
    assert_close(out, x, rtol=0.1, atol=0.1)

    if x.size(-1) % 2 == 0:
        inp_dict = {"hidden_states": x.clone()}
        cast_to_fp8_pipeline(inp_dict)
        cast_from_fp8_pipeline(inp_dict)
        assert_close(inp_dict["hidden_states"], x, rtol=0.1, atol=0.1)


if __name__ == "__main__":
    test_fp8_cast()

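For intuition about what cast_to_fp8 / cast_from_fp8 are checking, here is a minimal round-trip sketch in plain PyTorch. The naive_* helpers are ours, written under the assumption of per-tensor scaling and torch >= 2.1 float8 dtypes; the library kernels may differ:

import torch


def naive_cast_to_fp8(x: torch.Tensor, fp8_format: str = "e4m3"):
    # Choose the FP8 dtype and a per-tensor scale so the largest magnitude lands near its max value.
    fp8_dtype = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    fp8_max = torch.finfo(fp8_dtype).max
    scale = fp8_max / x.abs().max().clamp(min=1e-12).float()
    return (x.float() * scale).to(fp8_dtype), 1.0 / scale


def naive_cast_from_fp8(x_fp8: torch.Tensor, scale_inv: torch.Tensor, dtype: torch.dtype):
    # Dequantize with the inverse scale and cast back to the original dtype.
    return (x_fp8.float() * scale_inv).to(dtype)


x = torch.randn(16, 32, dtype=torch.bfloat16)
y, scale_inv = naive_cast_to_fp8(x)
x_rec = naive_cast_from_fp8(y, scale_inv, x.dtype)
# The round trip is lossy, which is why the test above accepts rtol=0.1, atol=0.1.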
@ -0,0 +1,87 @@
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

# example modified from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    def get_grads_after_one_iteration(hook=None):
        torch.manual_seed(0)
        # create model and move it to GPU with id rank
        model = ToyModel().to(rank)

        ddp_model = DDP(model, device_ids=[rank])

        if hook is not None:
            ddp_model.register_comm_hook(None, hook)

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10))
        labels = torch.randn(20, 5).to(rank)
        loss_fn(outputs, labels).backward()
        optimizer.step()

        torch.distributed.barrier()

        grad_dict = {}
        for name, params in ddp_model.named_parameters():
            grad_dict[name] = params.grad
        return grad_dict

    from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async, fp8_compress_ddp_grad_comm_hook_sync

    grad_dict = get_grads_after_one_iteration()
    for hook in [fp8_compress_ddp_grad_comm_hook_sync, fp8_compress_ddp_grad_comm_hook_async]:
        grad_dict_w_hook = get_grads_after_one_iteration(hook)
        if dist.get_rank() == 0:
            for name in grad_dict:
                assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1)

    cleanup()


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    world_size = n_gpus
    run_demo(demo_basic, world_size)

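Outside this self-contained demo, the same hook can be attached to any existing DDP model. A minimal sketch, assuming the process group is already initialized and rank is the local GPU index (the helper name is ours):

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async


def wrap_with_fp8_grad_comm(model: nn.Module, rank: int) -> DDP:
    # Hypothetical helper: wrap the model in DDP and compress its gradient
    # all-reduce to FP8 via the asynchronous hook exercised above.
    ddp_model = DDP(model.to(rank), device_ids=[rank])
    ddp_model.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_async)
    return ddp_model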
@ -0,0 +1,107 @@
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from packaging import version
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing import assert_close

from colossalai import launch
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

# example modified from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(100, 100)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(100, 50)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


@parameterize("mode", ["grad", "params"])
def run_model(mode):
    rank = dist.get_rank()

    from colossalai.quantization.utils import patch_fsdp_params_comm_hook

    patch_fsdp_params_comm_hook()

    def get_grads_after_one_iteration(grad_hook=None, params_hook=None):
        torch.manual_seed(0)
        # create model and move it to GPU with id rank
        model = ToyModel().to(rank)
        fsdp_model = FSDP(model)

        if grad_hook is not None:
            fsdp_model.register_comm_hook(None, grad_hook)

        if params_hook is not None:
            fsdp_model.register_params_comm_hook(None, params_hook)

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(fsdp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        outputs = fsdp_model(torch.randn(20, 100))
        labels = torch.randn(20, 50).to(rank)
        loss_fn(outputs, labels).backward()
        optimizer.step()

        torch.distributed.barrier()

        grad_dict = {}
        for name, params in fsdp_model.named_parameters():
            grad_dict[name] = params.grad
        return grad_dict

    from colossalai.quantization.fp8 import fp8_compress_fsdp_grad_comm_hook, fp8_compress_fsdp_params_comm_hook

    if mode == "grad":
        grad_dict = get_grads_after_one_iteration()
        for hook in [
            fp8_compress_fsdp_grad_comm_hook,
        ]:
            grad_dict_w_hook = get_grads_after_one_iteration(grad_hook=hook)
            if dist.get_rank() == 0:
                for name in grad_dict:
                    assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1)
    elif mode == "params":
        grad_dict = get_grads_after_one_iteration()
        for hook in [
            fp8_compress_fsdp_params_comm_hook,
        ]:
            grad_dict_w_hook = get_grads_after_one_iteration(params_hook=hook)
            if dist.get_rank() == 0:
                for name in grad_dict:
                    assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1)
    else:
        raise NotImplementedError


def demo_basic(rank, world_size, port):
    print(f"Running basic FSDP example on rank {rank}.")
    launch(rank=rank, world_size=world_size, port=port, host="localhost")
    run_model()
    cleanup()


@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("2.2.0"), reason="torch version < 2.2.0.")
@rerun_if_address_is_in_use()
def test_fsdp():
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    spawn(demo_basic, n_gpus)


if __name__ == "__main__":
    test_fsdp()

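The same wiring, condensed outside the test harness: a sketch assuming torch >= 2.2, an initialized process group, and a CUDA device. As in run_model above, patch_fsdp_params_comm_hook must run before the parameter hook is registered; the helper name is ours:

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from colossalai.quantization.fp8 import fp8_compress_fsdp_grad_comm_hook, fp8_compress_fsdp_params_comm_hook
from colossalai.quantization.utils import patch_fsdp_params_comm_hook

patch_fsdp_params_comm_hook()  # adds register_params_comm_hook to FSDP, as used in the test above


def wrap_with_fp8_fsdp(model: nn.Module) -> FSDP:
    # Hypothetical helper: compress both the gradient reduce-scatter and the
    # parameter all-gather of an FSDP model to FP8.
    fsdp_model = FSDP(model.cuda())
    fsdp_model.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook)
    fsdp_model.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook)
    return fsdp_model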
@ -0,0 +1,52 @@
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing import assert_close

from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import gather_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


@parameterize(
    "shape",
    [
        (3, 7),
        (2, 1),
        (1, 2),
        (2, 2),
        (4, 2),
        (5,),
        (4,),
        (2,),
    ],
)
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("fp8_format", ["e4m3", "e5m2"])
@parameterize("async_op", [True, False])
def check_4gpu(shape, dtype, fp8_format, async_op):
    world_size = dist.get_world_size()
    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    output_list = [torch.empty_like(x) for _ in range(world_size)]
    output_list_fp8 = [torch.empty_like(x) for _ in range(world_size)]
    fp8_handle = gather_fp8(output_list_fp8, x, group=_get_default_group(), fp8_format=fp8_format, async_op=async_op)
    origin_handle = dist.all_gather(output_list, x, group=_get_default_group(), async_op=async_op)
    if async_op:
        fp8_handle.wait()
        origin_handle.wait()
    assert_close(output_list, output_list_fp8, rtol=0.1, atol=0.1)


def run_dist(rank, world_size, port):
    launch(rank=rank, world_size=world_size, port=port, host="localhost")
    check_4gpu()


@rerun_if_address_is_in_use()
def test_all_gather():
    spawn(run_dist, 4)


if __name__ == "__main__":
    test_all_gather()

@ -0,0 +1,50 @@
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import linear_fp8
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device

REPLACED = False
TRIGGERED = False


def new_linear_fp8(x, w, bias=None):
    global TRIGGERED
    TRIGGERED = True
    return linear_fp8(x, w, bias)


class FP8TestHook(FP8Hook):
    def rewrite_op(self, func):
        func = super().rewrite_op(func)
        if func is linear_fp8:
            global REPLACED
            REPLACED = True
            return new_linear_fp8
        return func


D_IN, D_OUT = 16, 32
B, S = 2, 64
DTYPE = torch.bfloat16


@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
def test_fp8_hook():
    # create tensors
    w = nn.Parameter(torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE))
    x = torch.rand(B, S, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True)
    w.__class__ = ColoParameter
    w.__init__(w, requires_grad=True)
    hook = FP8TestHook()
    with ColoParamOpHookManager.use_hooks(hook):
        o = F.linear(x, w)
    assert o.shape == (B, S, D_OUT)
    assert REPLACED
    assert TRIGGERED

@ -0,0 +1,45 @@
import pytest
import torch
import torch.nn.functional as F
from torch.testing import assert_close

from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import linear_fp8
from colossalai.utils import get_current_device

D_IN, D_OUT = 16, 32
B, S = 2, 64
DTYPE = torch.bfloat16


@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("use_batch", [True, False])
def test_fp8_linear(use_bias: bool, use_batch: bool):
    # create tensors
    w = torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True)
    ref_w = w.clone().detach().requires_grad_()
    if use_batch:
        x_shape = (B, S, D_IN)
    else:
        x_shape = (S, D_IN)
    x = torch.rand(x_shape, device=get_current_device(), dtype=DTYPE, requires_grad=True)
    ref_x = x.clone().detach().requires_grad_()
    if use_bias:
        bias = torch.rand(D_OUT, device=get_current_device(), dtype=DTYPE, requires_grad=True)
        ref_bias = bias.clone().detach().requires_grad_()
    else:
        bias = None
        ref_bias = None

    out = linear_fp8(x, w, bias)
    assert out.shape == x_shape[:-1] + (D_OUT,)
    out.sum().backward()
    ref_out = F.linear(ref_x, ref_w, ref_bias)
    ref_out.sum().backward()

    assert_close(out, ref_out, rtol=0.2, atol=0.1)
    assert_close(x.grad, ref_x.grad, rtol=0.2, atol=0.1)
    assert_close(w.grad, ref_w.grad, rtol=0.2, atol=0.1)
    if use_bias:
        assert_close(bias.grad, ref_bias.grad, rtol=0.2, atol=0.1)

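Conceptually, an FP8 linear quantizes both operands with per-tensor scales before the matmul and works with the dequantized values afterwards, which is why the test above tolerates rtol=0.2. A rough numerical emulation in plain PyTorch is sketched below; it is our own approximation, assumes torch >= 2.1 float8 dtypes, and uses no fused kernel or custom backward, so it is not the library's linear_fp8:

import torch
import torch.nn.functional as F


def emulated_linear_fp8(x: torch.Tensor, w: torch.Tensor, bias=None) -> torch.Tensor:
    # Quantize inputs and weights to e4m3 with per-tensor scales, dequantize,
    # then run a normal linear; this only reproduces the rounding error, not the speed.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    sx = fp8_max / x.abs().max().clamp(min=1e-12).float()
    sw = fp8_max / w.abs().max().clamp(min=1e-12).float()
    xq = (x.float() * sx).to(torch.float8_e4m3fn).float() / sx
    wq = (w.float() * sw).to(torch.float8_e4m3fn).float() / sw
    return F.linear(xq, wq, bias.float() if bias is not None else None).to(x.dtype)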
@ -0,0 +1,44 @@
import torch
from torch.distributed import reduce_scatter
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing import assert_close

from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import reduce_scatter_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


@parameterize("shape", [(16, 8, 4)])
@parameterize("scatter_dim", [0, 1, 2])
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("fp8_format", ["e4m3", "e5m2"])
@parameterize("async_op", [True, False])
def check_4gpu(shape, scatter_dim, dtype, fp8_format, async_op):
    x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
    input_list = list(torch.chunk(x, dim=scatter_dim, chunks=4))
    input_list = [t.contiguous() for t in input_list]
    output_origin = torch.empty_like(input_list[0])
    output_fp8 = torch.empty_like(input_list[0])
    origin_handle = reduce_scatter(output_origin, input_list, group=_get_default_group(), async_op=async_op)
    fp8_handle = reduce_scatter_fp8(
        output_fp8, input_list, group=_get_default_group(), fp8_format=fp8_format, async_op=async_op
    )
    if async_op:
        origin_handle.wait()
        fp8_handle.wait()
    assert_close(output_origin, output_fp8, rtol=0.1, atol=0.1)


def run_dist(rank, world_size, port):
    launch(rank=rank, world_size=world_size, port=port, host="localhost")
    check_4gpu()


@rerun_if_address_is_in_use()
def test_reduce_scatter():
    spawn(run_dist, 4)


if __name__ == "__main__":
    test_reduce_scatter()

@ -51,7 +51,8 @@ def split_ddp_grad(grad, world_size):
    return splited_grad


def exam_zero_1_2():
@parameterize("fp8_communication", [True, False])
def exam_zero_1_2(fp8_communication: bool):
    """
    In this test, we want to test whether zero stage 1 and 2
    deliver the same numerical results despite different communication

@ -73,10 +74,18 @@ def exam_zero_1_2():
    zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
    zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
    zero1_optimizer = LowLevelZeroOptimizer(
        zero1_optimizer, overlap_communication=True, initial_scale=128, verbose=True
        zero1_optimizer,
        overlap_communication=True,
        initial_scale=128,
        verbose=True,
        fp8_communication=fp8_communication,
    )
    zero2_optimizer = LowLevelZeroOptimizer(
        zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=128
        zero2_optimizer,
        overlap_communication=True,
        partition_grad=True,
        initial_scale=128,
        fp8_communication=fp8_communication,
    )
    # create data
    seed_all(2001 + local_rank)

@ -97,7 +106,10 @@ def exam_zero_1_2():
        if g1 is None or g2 is None:
            assert g1 is None and g2 is None
            continue
        assert torch.allclose(g1, g2)
        if fp8_communication:
            loose_close(g1, g2, dtype=torch.float16)
        else:
            assert torch.allclose(g1, g2)

    # step
    zero1_optimizer.step()

@ -105,7 +117,8 @@ def exam_zero_1_2():

    # check updated param
    for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
        assert torch.allclose(z1p, z2p)
        if not fp8_communication:
            assert torch.allclose(z1p, z2p)


@parameterize("dtype", [torch.float16, torch.bfloat16])
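For reference, a sketch of how the new fp8_communication flag might be enabled from user code, mirroring the constructor arguments exercised in this test. It assumes the usual colossalai.zero import path and an already-built model; the helper name is ours:

import torch
import torch.nn as nn

from colossalai.zero import LowLevelZeroOptimizer


def build_zero1_fp8_optimizer(model: nn.Module) -> LowLevelZeroOptimizer:
    # Sketch: ZeRO-1 optimizer with FP8-compressed gradient communication.
    return LowLevelZeroOptimizer(
        torch.optim.Adam(model.parameters(), lr=1e-3),
        overlap_communication=True,
        initial_scale=128,
        fp8_communication=True,
    )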