[Feature] Zigzag Ring attention (#5905)

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add sp_mode to benchmark; fix varlen interface

* update softmax_lse shape by new interface

* change tester name

* remove buffer clone; support packed seq layout

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/6015/head
Edenzzzz 2024-08-16 13:56:38 +08:00 committed by GitHub
parent 887d2d579b
commit f5c84af0b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
50 changed files with 1870 additions and 326 deletions

View File

@ -12,6 +12,7 @@ repos:
hooks: hooks:
- id: isort - id: isort
name: sort all imports (python) name: sort all imports (python)
args: ["--profile", "black"] # avoid conflict with black
- repo: https://github.com/psf/black-pre-commit-mirror - repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.8.0 rev: 24.8.0

View File

@ -32,7 +32,7 @@ from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackw
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
from colossalai.shardformer.layer.utils import SeqParallelUtils from colossalai.shardformer.layer.utils import SeqParallelUtils, is_share_sp_tp
from colossalai.shardformer.policies.base_policy import Policy from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.colo_parameter import ColoParameter from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.d_tensor.api import is_distributed_tensor from colossalai.tensor.d_tensor.api import is_distributed_tensor
@ -42,7 +42,7 @@ from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_hand
from .pp_plugin_base import PipelinePluginBase from .pp_plugin_base import PipelinePluginBase
SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"] SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all", "ring_attn"]
PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}
@ -72,7 +72,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
self.dp_group = dp_group self.dp_group = dp_group
self.tp_group = tp_group self.tp_group = tp_group
self.sp_group = sp_group self.sp_group = sp_group
self.use_dpp = use_ddp self.use_ddp = use_ddp
self.require_grad_sync = True self.require_grad_sync = True
self.overlap_allgather = overlap_allgather self.overlap_allgather = overlap_allgather
@ -139,8 +139,8 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
# Disable automatic gradient synchronization. # Disable automatic gradient synchronization.
self.require_grad_sync = False self.require_grad_sync = False
try: try:
if self.use_dpp: if self.use_ddp:
# If using data parallel processing (use_dpp), disable synchronization too. # If using data parallel processing (use_ddp), disable synchronization too.
with self.module.no_sync(): with self.module.no_sync():
yield yield
else: else:
@ -188,7 +188,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
""" """
if self.shard_config.enable_sequence_parallelism: if self.shard_config.enable_sequence_parallelism:
if self.shard_config.sequence_parallelism_mode == "all_to_all": if self.shard_config.sequence_parallelism_mode in ["all_to_all", "ring_attn"]:
return return
if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
@ -970,6 +970,9 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True. enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64. make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64.
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism
inner_ring_size (int, optional): The inner ring size of 2D Ring Attention when sp mode is "ring_attn".
It's advisable to not tune this (especially in single-node settings) and let it be heuristically set based on topology by default.
""" """
def __init__( def __init__(
@ -1017,6 +1020,7 @@ class HybridParallelPlugin(PipelinePluginBase):
dp_outside: bool = True, dp_outside: bool = True,
overlap_p2p: bool = True, overlap_p2p: bool = True,
overlap_allgather: bool = False, overlap_allgather: bool = False,
inner_ring_size: int = None,
) -> None: ) -> None:
super().__init__() super().__init__()
@ -1041,9 +1045,11 @@ class HybridParallelPlugin(PipelinePluginBase):
) )
self.sp_size = 1 self.sp_size = 1
self.dp_size = dist.get_world_size() // (tp_size * pp_size) self.dp_size = dist.get_world_size() // (tp_size * pp_size)
elif self.sequence_parallelism_mode in ["all_to_all"]: elif self.sequence_parallelism_mode in ["all_to_all", "ring_attn"]:
self.sp_size = 1 if sp_size is None else sp_size self.sp_size = 1 if sp_size is None else sp_size
self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size) self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size)
if self.sequence_parallelism_mode == "ring_attn":
enable_flash_attention = True
else: else:
self.dp_size = dist.get_world_size() // (tp_size * pp_size) self.dp_size = dist.get_world_size() // (tp_size * pp_size)
assert ( assert (
@ -1063,10 +1069,21 @@ class HybridParallelPlugin(PipelinePluginBase):
self.enable_sequence_parallelism = enable_sequence_parallelism self.enable_sequence_parallelism = enable_sequence_parallelism
if dp_outside: if dp_outside:
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) if sequence_parallelism_mode == "ring_attn":
# Swap tp and sp since 2D Ring has better inter-node latency
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.sp_size, self.tp_size)
self.sp_axis = 2
self.tp_axis = 3
else:
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
else: else:
self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) if sequence_parallelism_mode == "ring_attn":
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.sp_size, self.tp_size)
self.sp_axis = 2
self.tp_axis = 3
else:
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
self.stage_manager = None self.stage_manager = None
self.schedule = None self.schedule = None
@ -1108,6 +1125,8 @@ class HybridParallelPlugin(PipelinePluginBase):
) )
else: else:
raise NotImplementedError() raise NotImplementedError()
if sequence_parallelism_mode == "ring_attn":
assert parallel_output, "Ring Attention doesn't support gathering output yet."
self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis) self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
@ -1132,6 +1151,7 @@ class HybridParallelPlugin(PipelinePluginBase):
parallel_output=parallel_output, parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by, make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config, gradient_checkpoint_config=gradient_checkpoint_config,
inner_ring_size=inner_ring_size,
) )
self.amp_config = dict( self.amp_config = dict(
initial_scale=initial_scale, initial_scale=initial_scale,
@ -1216,15 +1236,15 @@ class HybridParallelPlugin(PipelinePluginBase):
zero_stage = 0 zero_stage = 0
if not isinstance(model, ModelWrapper): if not isinstance(model, ModelWrapper):
# Shouldn't use pp (frequent grad accumulation) with torch ddp
use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or ( use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
self.dp_size == 1 self.dp_size == 1 and self.pp_size == 1
and self.pp_size == 1
and self.enable_sequence_parallelism
and self.sequence_parallelism_mode == "all_to_all"
) )
# sync gradients across DP * SP ranks
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": # Apply Hybrid ZeRO across DP * SP ranks
if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis]) dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
self.dp_size = get_world_size(dp_group)
else: else:
dp_group = self.dp_group dp_group = self.dp_group
model = HybridParallelModule( model = HybridParallelModule(

View File

@ -203,7 +203,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
return return
Path(checkpoint).mkdir(parents=True, exist_ok=True) Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Devices along the same dp_group share the same copies of model. # Devices along the same dp_group share the same copies of model.
# So only let the device with dp_rank == 0 save the model. # So only let the device with dp_rank == 0 save the model.
if self.dp_rank != 0: if self.dp_rank != 0:
@ -643,14 +642,12 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
assert isinstance(model, ModelWrapper), "Please boost the model before saving!" assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather() model._force_wait_all_gather()
model = model.unwrap() model = model.unwrap()
if self.dp_rank != 0: if self.dp_rank != 0:
return return
# The logic of collecting parameter shards along tp degree # The logic of collecting parameter shards along tp degree
# has been implemented by _save_to_state_dict method of ParallelModule in Shardformer. # has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
state_dict = model.state_dict() state_dict = model.state_dict()
if self.pp_size == 1: if self.pp_size == 1:
# When pipeline is not used, let master rank directly save the collected state_dict. # When pipeline is not used, let master rank directly save the collected state_dict.
if self.tp_rank == 0: if self.tp_rank == 0:
@ -660,7 +657,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_dict_list = [None for _ in range(self.pp_size)] state_dict_list = [None for _ in range(self.pp_size)]
dist.barrier(self.pp_group) dist.barrier(self.pp_group)
dist.all_gather_object(state_dict_list, state_dict, self.pp_group) dist.all_gather_object(state_dict_list, state_dict, self.pp_group)
# Only the master rank do the saving. # Only the master rank do the saving.
if self.coordinator.is_master(): if self.coordinator.is_master():
complete_state_dict = dict() complete_state_dict = dict()

View File

@ -62,7 +62,6 @@ def new_from_pretrained(
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
cache_dir = kwargs.pop("cache_dir", None) cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None) use_auth_token = kwargs.pop("use_auth_token", None)
@ -116,7 +115,6 @@ def new_from_pretrained(
cache_dir=cache_dir, cache_dir=cache_dir,
return_unused_kwargs=True, return_unused_kwargs=True,
force_download=force_download, force_download=force_download,
resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
@ -195,7 +193,6 @@ def new_from_pretrained(
"cache_dir": cache_dir, "cache_dir": cache_dir,
"force_download": force_download, "force_download": force_download,
"proxies": proxies, "proxies": proxies,
"resume_download": resume_download,
"local_files_only": local_files_only, "local_files_only": local_files_only,
"use_auth_token": use_auth_token, "use_auth_token": use_auth_token,
"user_agent": user_agent, "user_agent": user_agent,
@ -312,7 +309,6 @@ def new_from_pretrained(
pretrained_model_name_or_path, pretrained_model_name_or_path,
cache_dir=cache_dir, cache_dir=cache_dir,
force_download=force_download, force_download=force_download,
resume_download=resume_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, use_auth_token=use_auth_token,

View File

@ -171,7 +171,7 @@ class OpenMoeForCausalLMPolicy(OpenMoePolicy):
policy = super().module_policy() policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm # add a new item for causal lm
# TODO: recursively assign ep group foe all modules # TODO: recursively assign ep group foe all modules
new_item = { new_item = {
OpenMoeForCausalLM: ModulePolicyDescription( OpenMoeForCausalLM: ModulePolicyDescription(

View File

@ -81,6 +81,9 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True) handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have # Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated # all-reduce scheduled first and have GPU resources allocated
# TODO: This seems to only work if you add torch.cuda.Event.wait()
# _ = torch.zeros(1, device=grad_output.device)
grad_weight = grad_output.t().matmul(total_input) grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None grad_bias = grad_output.sum(dim=0) if use_bias else None

View File

@ -64,7 +64,10 @@ class DistributedLogger:
self._logger.propagate = False self._logger.propagate = False
DistributedLogger.__instances[name] = self DistributedLogger.__instances[name] = self
self.rank = dist.get_rank() if dist.is_initialized() else 0
@property
def rank(self):
return dist.get_rank() if dist.is_initialized() else 0
@staticmethod @staticmethod
def __get_call_info(): def __get_call_info():

View File

@ -286,7 +286,6 @@ class InterleavedSchedule(PipelineSchedule):
# for the first stage, input_obj is None # for the first stage, input_obj is None
# for other stages, input_obj is the output of the previous stage containing hidden_states etc. # for other stages, input_obj is the output of the previous stage containing hidden_states etc.
# Only attention_mask from micro_batch is used # Only attention_mask from micro_batch is used
with self.stage_manager.switch_model_chunk_id(model_chunk_id): with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if isinstance(model_chunk, ModuleList): if isinstance(model_chunk, ModuleList):
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)

View File

@ -244,6 +244,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
output_obj = model_forward(model, micro_batch, input_obj) output_obj = model_forward(model, micro_batch, input_obj)
if self.stage_manager.is_last_stage(): if self.stage_manager.is_last_stage():
loss = criterion(output_obj, micro_batch) / self.num_microbatches loss = criterion(output_obj, micro_batch) / self.num_microbatches
if accum_loss is not None: if accum_loss is not None:
accum_loss.add_(loss.detach()) accum_loss.add_(loss.detach())
if outputs is not None: if outputs is not None:

View File

@ -1,5 +1,5 @@
from ._operation import all_to_all_comm from ._operation import all_to_all_comm
from .attn import AttnMaskType, ColoAttention from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
@ -31,5 +31,7 @@ __all__ = [
"VocabParallelLMHead1D", "VocabParallelLMHead1D",
"AttnMaskType", "AttnMaskType",
"ColoAttention", "ColoAttention",
"RingAttention",
"get_pad_info",
"all_to_all_comm", "all_to_all_comm",
] ]

View File

@ -2,6 +2,8 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn.functional as F import torch.nn.functional as F
from .utils import is_share_sp_tp
try: try:
import fused_mix_prec_layer_norm_cuda import fused_mix_prec_layer_norm_cuda
except: except:
@ -93,7 +95,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
if ctx.async_grad_allreduce: if ctx.async_grad_allreduce:
# Asynchronous all-reduce # Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
grad_weight = total_input.t().matmul(grad_output) grad_weight = total_input.t().matmul(grad_output)
@ -143,7 +145,9 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
if ctx.async_grad_allreduce: if ctx.async_grad_allreduce:
# Asynchronous all-reduce # Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have _ = torch.zeros(1, device=grad_input.device)
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
if _grad_accum_fusion_available and weight.grad is not None: if _grad_accum_fusion_available and weight.grad is not None:
@ -331,7 +335,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
).contiguous() ).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
if _grad_accum_fusion_available and weight.grad is not None: if _grad_accum_fusion_available and weight.grad is not None:
@ -646,8 +650,8 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
).contiguous() ).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py # all-reduce scheduled first and have GPU resources allocated
grad_weight = total_input.t().matmul(grad_output) grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None grad_bias = grad_output.sum(dim=0) if use_bias else None
@ -721,16 +725,20 @@ class _ReduceForward(torch.autograd.Function):
Args: Args:
input_: input matrix. input_: input matrix.
parallel_mode: parallel mode. process_group: communication group.
""" """
@staticmethod @staticmethod
def forward(ctx, input_, process_group): def forward(ctx, input_, process_group, grad_scale=None):
ctx.grad_scale = grad_scale
return _reduce(input_, process_group) return _reduce(input_, process_group)
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
return grad_output, None if ctx.grad_scale is not None:
grad_output = grad_output * ctx.grad_scale
return grad_output, None, None
class _ReduceBackward(torch.autograd.Function): class _ReduceBackward(torch.autograd.Function):
@ -979,8 +987,8 @@ def split_forward_gather_backward(input_, dim, process_group, grad_scale=None):
return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale) return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale)
def reduce_forward(input_, process_group): def reduce_forward(input_, process_group, grad_scale=None):
return _ReduceForward.apply(input_, process_group) return _ReduceForward.apply(input_, process_group, grad_scale)
def reduce_backward(input_, process_group): def reduce_backward(input_, process_group):
@ -989,3 +997,13 @@ def reduce_backward(input_, process_group):
def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1): def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim) return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
def gather_sp_output(hidden_states, sp_group, sp_mode):
"""
Gather the output of the last layer for cross entropy computation
"""
# Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group)
scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group)
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=scale)
return hidden_states

View File

@ -2,7 +2,10 @@ from enum import Enum
from typing import Callable, Dict, Optional, Tuple from typing import Callable, Dict, Optional, Tuple
import torch import torch
import torch.distributed
import torch.distributed as dist
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange
from colossalai.kernel.kernel_loader import ( from colossalai.kernel.kernel_loader import (
FlashAttentionForFloatAndCustomMaskLoader, FlashAttentionForFloatAndCustomMaskLoader,
@ -10,12 +13,18 @@ from colossalai.kernel.kernel_loader import (
FlashAttentionWithCustomMaskLoader, FlashAttentionWithCustomMaskLoader,
KernelLoader, KernelLoader,
) )
from colossalai.logging import get_dist_logger
from .utils import RingComm, get_half_index, split_varlen_zigzag
__all__ = [ __all__ = [
"AttnMaskType", "AttnMaskType",
"ColoAttention", "ColoAttention",
] ]
_flash_attn_forward = _flash_attn_backward = None
_unpad_input = _pad_input = None
class AttnMaskType(Enum): class AttnMaskType(Enum):
CUSTOM = 0 CUSTOM = 0
@ -38,20 +47,32 @@ def invert_mask(mask: torch.Tensor) -> torch.Tensor:
# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py # adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.Tensor]: def get_pad_info(
padding_mask: torch.Tensor, invert: Optional[bool] = False, return_indices: Optional[bool] = True
) -> Tuple[int, torch.Tensor, torch.Tensor]:
"""Get padding information from padding mask. """Get padding information from padding mask.
Args: Args:
padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, S] padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, Skv]
invert (Optional[bool], optional): Whether to reverse the padding mask.
return_indices (Optional[bool], optional): Whether to return the indices of non-masked tokens.
Returns: Returns:
Tuple[int, torch.Tensor, torch.Tensor]: Tuple of (max_seq_len, cu_seqlens, indices) max_seqlen_in_batch (int): Maximum sequence length in the batch.
cu_seqlens (torch.Tensor): Shape [B+1]. Cumulative sequence lengths of the sequences in the batch.
indices (torch.Tensor): Shape [total_nonzero]. The indices of non-masked tokens from the flattened input sequence.
""" """
if invert:
padding_mask = padding_mask.logical_not()
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() if return_indices:
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item() max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return max_seqlen_in_batch, cu_seqlens, indices if return_indices:
return max_seqlen_in_batch, cu_seqlens, indices
return max_seqlen_in_batch, cu_seqlens
class ColoAttention: class ColoAttention:
@ -107,6 +128,7 @@ class ColoAttention:
q_padding_mask: Optional[torch.Tensor] = None, q_padding_mask: Optional[torch.Tensor] = None,
kv_padding_mask: Optional[torch.Tensor] = None, kv_padding_mask: Optional[torch.Tensor] = None,
is_causal: bool = False, is_causal: bool = False,
invert: bool = True,
) -> Dict[str, torch.Tensor]: ) -> Dict[str, torch.Tensor]:
"""Return a dictionary of keyword arguments for attention function. It supports 4 mask type. """Return a dictionary of keyword arguments for attention function. It supports 4 mask type.
1. custom mask: no padding mask and is_causal=False, return {}, users should handle attention mask by themselves. 1. custom mask: no padding mask and is_causal=False, return {}, users should handle attention mask by themselves.
@ -124,7 +146,7 @@ class ColoAttention:
The shape should be [B, Skv]. ``1`` means valid token, and ``0`` means padding token. The shape should be [B, Skv]. ``1`` means valid token, and ``0`` means padding token.
If it's None and ``q_padding_mask`` is not None, it will be set to ``q_padding_mask``. Defaults to None. If it's None and ``q_padding_mask`` is not None, it will be set to ``q_padding_mask``. Defaults to None.
is_causal (bool, optional): Whether to use causal attention mask. Defaults to False. is_causal (bool, optional): Whether to use causal attention mask. Defaults to False.
invert_mask (bool, optional): Whether to invert the mask. Defaults to True.
Returns: Returns:
Dict[str, torch.Tensor]: Dictionary of keyword arguments for attention function. Dict[str, torch.Tensor]: Dictionary of keyword arguments for attention function.
""" """
@ -154,7 +176,7 @@ class ColoAttention:
assert kv_padding_mask.shape == ( assert kv_padding_mask.shape == (
b, b,
s_kv, s_kv,
), f"q_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})" ), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device) attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
outputs.update( outputs.update(
{ {
@ -172,7 +194,8 @@ class ColoAttention:
attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0) attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
else: else:
outputs["attention_mask_type"] = AttnMaskType.PADDED outputs["attention_mask_type"] = AttnMaskType.PADDED
attention_mask = invert_mask(attention_mask).unsqueeze(1) if invert:
attention_mask = invert_mask(attention_mask).unsqueeze(1)
outputs["attention_mask"] = attention_mask outputs["attention_mask"] = attention_mask
return outputs return outputs
@ -191,6 +214,7 @@ class ColoAttention:
kv_indices: Optional[torch.Tensor] = None, kv_indices: Optional[torch.Tensor] = None,
dropout_p: float = 0.0, dropout_p: float = 0.0,
scale: Optional[float] = None, scale: Optional[float] = None,
**kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
"""Flash Attention function. It supports 4 mask type. """Flash Attention function. It supports 4 mask type.
1. custom mask: recv attention_mask 1. custom mask: recv attention_mask
@ -199,9 +223,9 @@ class ColoAttention:
4. padded causal mask: recv attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, indices 4. padded causal mask: recv attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, indices
Args: Args:
q (torch.Tensor): Query tensor. Shape should be [B, N, Sq, D] q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D]
k (torch.Tensor): Key tensor. Shape should be [B, N, Skv, D] k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Skv, D]
v (torch.Tensor): Value tensor. Shape should be [B, N, Skv, D] v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Skv, D]
attention_mask (Optional[torch.Tensor], optional): Attention mask tensor. Shape should be [B, 1, Sq, Skv]. Defaults to None. attention_mask (Optional[torch.Tensor], optional): Attention mask tensor. Shape should be [B, 1, Sq, Skv]. Defaults to None.
attention_mask_type (AttnMaskType, optional): Attention mask type. Defaults to AttnMaskType.CUSTOM. attention_mask_type (AttnMaskType, optional): Attention mask type. Defaults to AttnMaskType.CUSTOM.
cu_seqlens_q (Optional[torch.Tensor], optional): The cumulative sequence lengths cu_seqlens_q (Optional[torch.Tensor], optional): The cumulative sequence lengths
@ -218,7 +242,7 @@ class ColoAttention:
scale (Optional[float], optional): Scaling factor applied prior to softmax. Defaults to None. scale (Optional[float], optional): Scaling factor applied prior to softmax. Defaults to None.
Returns: Returns:
torch.Tensor: Output tensor. Shape should be [B, N, Sq, D] torch.Tensor: Output tensor. Shape should be [B, nHeads, Sq, D]
""" """
# known issue: sdpa does not support attention mask which contains whole row of masked tokens, which leads to nan # known issue: sdpa does not support attention mask which contains whole row of masked tokens, which leads to nan
# this case is usaul when padding mask is used and self attention is performed # this case is usaul when padding mask is used and self attention is performed
@ -252,6 +276,7 @@ class ColoAttention:
else: else:
# if attention_mask is None, attention_mask_type should be the default value # if attention_mask is None, attention_mask_type should be the default value
assert attention_mask_type == AttnMaskType.CUSTOM assert attention_mask_type == AttnMaskType.CUSTOM
# kernel dispatch # kernel dispatch
mask_type = attention_mask_type if attention_mask is not None else None mask_type = attention_mask_type if attention_mask is not None else None
attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type) attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type)
@ -274,3 +299,858 @@ class ColoAttention:
q_indices=q_indices, q_indices=q_indices,
kv_indices=kv_indices, kv_indices=kv_indices,
) )
def _load_varlen_helpers():
"""Helper to load functions for padding and unpadding packed sequences.
Use only when flash attn is installed
"""
global _pad_input, _unpad_input
# Flash attn claims this is more efficient than torch's bool indexing due to avoiding
# broadcast
if _pad_input is None or _unpad_input is None:
try:
from flash_attn.bert_padding import index_first_axis, pad_input
def unpad_input(hidden_states: torch.Tensor, indices: torch.Tensor):
return index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices)
_pad_input = pad_input
_unpad_input = unpad_input
except ImportError as e:
raise RuntimeError(
f"Flash Attention is not installed. You can install it via 'pip install flash-attn --no-build-isolation'"
) from e
def _load_flash_attn():
"""A light-weight loader to check whether flash-attn is installed.
Can't use ColoAttention._dispatch_kernel because we mutate the backward pass
"""
global _flash_attn_forward, _flash_attn_backward
if _flash_attn_forward is None or _flash_attn_backward is None:
try:
from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward
from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward
except ImportError as e:
raise RuntimeError(
f"Flash Attention is not installed. You can install it via 'pip install flash-attn --no-build-isolation'"
) from e
_load_varlen_helpers()
# NOTE: This can cause spawned processes to hang on exit
# with python 3.9
@torch.compile()
def _rescale_out_lse(out, block_out, lse, block_lse):
"""
Compute the new attention denominator:
exp(lse) + exp(block_lse) = exp(max_scale) * (exp(min_scale - max_scale) + 1)
Args:
out: (T, H, D)
block_out: (T, H, D)
lse: (H, T, 1)
block_lse: (H, T, 1)
"""
# min_scale = torch.min(lse, block_lse)
# max_scale = torch.max(lse, block_lse)
# new_lse = max_scale + torch.log(1 + torch.exp(min_scale - max_scale))
# NOTE: directly assigning to .data here is buggy
# probably due to casting dtypes/strides
new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
new_block_lse = torch.exp(block_lse - new_lse)
out = (torch.exp(lse - new_lse) * out + new_block_lse * block_out).to(out)
lse = new_lse
# Equivalent to the above
# See https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
# out = (out - F.sigmoid(block_lse - lse) * (out - block_out))
# lse = (lse - F.logsigmoid(lse - block_lse))
return out, lse
class RingAttention(torch.autograd.Function):
"""Implements the Ring Attention from `Ring Attention with Blockwise Transformers for Near-Infinite Context`
(https://arxiv.org/abs/2310.01889).
For load-balancing we adopted the "zigzag" attention scheme from https://github.com/zhuzilin/ring-flash-attention/tree/main
For portable integration with more models, we don't follow the spirit of "block-wise FNN" in the original paper,
which requires fusing FFN with the Flash Attention kernel/function (see https://arxiv.org/pdf/2305.19370;
implemented in Jax and not optimized).
We adopt the double ring topology from LoongTrain (https://arxiv.org/pdf/2406.18485) to fully utilize available
NICs on each node, by computing attention within a inner ring first and then sending all KVs to the next
ring at once.
"""
# Globle cache to avoid recomputation for same-lengthed sequences
CU_SEQLENS: torch.Tensor = None # [B+1]
TOTAL_SEQLEN: int = None
HALF_INDICES: Tuple = None
SUPPORTED_MASK_TYPES = (AttnMaskType.CAUSAL, AttnMaskType.PADDED_CAUSAL)
ATTN_DONE: torch.cuda.Event = None
SP_STREAM: torch.cuda.Stream = None
SP_GROUP: dist.ProcessGroup = None
# duplicate process group for concurrent NCCL streams
# while both PyTorch and NCCL warns(https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7)
# against this, in practice it seems to work fine.
INNER_RING_GROUP: dist.ProcessGroup = None
INNER_RING_GROUP_COPY: dist.ProcessGroup = None
INTER_RING_GROUP: dist.ProcessGroup = None
INTER_RING_GROUP_COPY: dist.ProcessGroup = None
@staticmethod
def get_double_ring_groups(sp_group, inner_ring_size=None):
"""
Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
shouldn't be larger than the number of NICs on each node.
Args:
sp_group (dist.ProcessGroup): Process group for sequence parallelism
inner_ring_size (Optional[int], optional): Inner ring size. Defaults to None.
Returns:
Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
"""
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
if inner_ring_size is None:
if torch.cuda.device_count() >= dist.get_world_size():
# single node, no need to consider NICs
return sp_group, sp_group
if sp_size <= 4:
inner_ring_size = min(2, sp_size)
else:
inner_ring_size = min(4, sp_size)
else:
assert (
inner_ring_size <= sp_size and sp_size % inner_ring_size == 0
), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"
if inner_ring_size == sp_size:
return sp_group, sp_group
assert (
sp_size % inner_ring_size == 0
), f"sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"
logger = get_dist_logger()
logger.info(
f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for speed-ups!",
ranks=[0],
)
num_rings = sp_size // inner_ring_size
inner_ring_group = None
inter_ring_group = None
# Create inner ring groups
for i in range(inner_ring_size):
ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
group = dist.new_group(ranks)
if sp_rank in ranks:
inner_ring_group = group
# Create inter ring groups
for i in range(num_rings):
ranks = list(range(i, sp_size, num_rings))
group = dist.new_group(ranks)
if sp_rank in ranks:
inter_ring_group = group
return inner_ring_group, inter_ring_group
@staticmethod
def attention(
q, # (B, H, Sq, D)
k,
v,
sp_group,
attention_mask_type,
cu_seqlens=None,
max_seqlen=None,
valid_indices=None,
dropout_p=0.0,
softmax_scale=None,
deterministic=False,
return_softmax=False,
inner_ring_size=None,
**kwargs,
):
"""
Ring Attention forward pass supporting variable-length sequences. When using varlen mode,
each sequence in the batch should have length divisible by sp_size * 2.
Args:
q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D]
k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Sq, Sq, D]
v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Sq, Sq, D]
sp_group (Optional[dist.ProcessGroup]): Process group for sequence parallelism
sp_tream (torch.cuda.Stream): An different stream for output correction.
cu_seqlens (Optional[torch.Tensor], optional): The cumulative sequence lengths
of the sequences in the batch, used to index into q.
Shape should be [B+1].
max_seqlen (Optional[int], optional): Maximum query sequence length in the batch.
valid_indices (Optional[torch.Tensor], optional): The indices of non-masked tokens from get_pad_info.
Shape should be [t].
dropout_p (float, optional): Dropout probability. Defaults to 0.0.
softmax_scale (Optional[float], optional): Scaling factor applied prior to softmax.
deterministic (bool, optional): Whether to force deterministic backward pass. See https://github.com/Dao-AILab/flash-attention/issues/349
return_softmax (bool, optional): Whether to return the softmax denominator (logsumexp).
inner_ring_size (Optional[int], optional): Inner ring size of the 2D ring. By default use a heuristic to decide.
Returns:
out: Output tensor of shape [B, nHeads, Sq, D] or [T, nHeads, D] if pad_output is False.
softmax_lse: (if return_softmax is True) Softmax denominator (logsumexp).
Shape should be [total_q_seqlen, nHeads]
"""
# Check input args
_load_flash_attn()
if RingAttention.ATTN_DONE is None:
RingAttention.ATTN_DONE = torch.cuda.Event()
if RingAttention.SP_STREAM is None:
RingAttention.SP_STREAM = torch.cuda.Stream()
assert (
q.shape[2] == k.shape[2]
), "Q, K and V having different sequence lengths (inference or cross-attn)\
is not supported yet in training."
assert (
attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES
), f"Mask type {attention_mask_type} is not supported yet."
clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg))
if RingAttention.SP_GROUP is not sp_group:
RingAttention.SP_GROUP = sp_group
inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size)
RingAttention.INNER_RING_GROUP = inner_ring_group
RingAttention.INTER_RING_GROUP = inter_ring_group
else:
inner_ring_group = RingAttention.INNER_RING_GROUP
inter_ring_group = RingAttention.INTER_RING_GROUP
# (B, H, Sq, D) -> (B, Sq, H, D)
q, k, v = [x.transpose(1, 2).contiguous() for x in (q, k, v)]
pad_output = q.dim() == 4
# Get sequence length info for varlen forward
if attention_mask_type == AttnMaskType.CAUSAL:
# All sequences share the same length
b, sq, h, d = q.shape
max_seqlen = sq
# Cache to avoid recreation for a single sequence
if sq * b == RingAttention.TOTAL_SEQLEN:
cu_seqlens = RingAttention.CU_SEQLENS
else:
cu_seqlens = torch.arange(0, b * sq + 1, sq, device=q.device, dtype=torch.int32)
RingAttention.TOTAL_SEQLEN = b * sq
# "Packed" mode where sequences of different lengths are packed into [total_q_seqlen, H, D]
elif attention_mask_type == AttnMaskType.PADDED_CAUSAL:
assert (
cu_seqlens is not None and max_seqlen is not None and valid_indices is not None
), "Packed mode requires pre-computed cu_seqlens and max_seq_len."
if pad_output:
b, sq, h, d = q.shape
q, k, v = [_unpad_input(x, valid_indices) for x in (q, k, v)]
out, softmax_lse = RingAttention.apply(
q,
k,
v,
sp_group,
RingAttention.SP_STREAM,
cu_seqlens,
max_seqlen,
dropout_p,
softmax_scale,
deterministic,
return_softmax,
attention_mask_type == AttnMaskType.PADDED_CAUSAL,
inner_ring_group,
inter_ring_group,
)
if attention_mask_type == AttnMaskType.PADDED_CAUSAL:
if pad_output:
out = _pad_input(out, valid_indices, b, sq) # (T, ...) -> (B, Sq, ...)
out = out.transpose(1, 2) # (B, Sq, H, D) -> (B, H, Sq, D)
else:
out = out.transpose(1, 2)
if return_softmax:
return out, softmax_lse
return out
@staticmethod
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
sp_group: dist.ProcessGroup,
sp_stream: torch.cuda.Stream,
cu_seqlens: torch.Tensor,
max_seqlen: int,
dropout_p: float = 0.0,
softmax_scale: Optional[float] = None,
deterministic: Optional[bool] = False,
return_softmax: Optional[bool] = False,
is_packed: Optional[bool] = False,
inner_ring_group: Optional[dist.ProcessGroup] = None,
inter_ring_group: Optional[dist.ProcessGroup] = None,
):
cu_seqlens_q = cu_seqlens_kv = cu_seqlens
max_seqlen_q = max_seqlen_kv = max_seqlen
cu_seqlens_half = cu_seqlens // 2
max_seqlen_half = max_seqlen // 2
misc_kwargs = {
"window_size": (-1, -1),
"alibi_slopes": None,
"softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale,
"dropout_p": dropout_p,
"block_table": None,
"softcap": 0.0,
"return_softmax": False,
}
if (
RingAttention.HALF_INDICES is not None
and cu_seqlens.shape == RingAttention.CU_SEQLENS.shape
and (cu_seqlens == RingAttention.CU_SEQLENS).all()
):
half_idx_front, half_idx_back = RingAttention.HALF_INDICES
else:
half_idx_front = get_half_index(cu_seqlens, front=True)
half_idx_back = get_half_index(cu_seqlens, front=False)
RingAttention.HALF_INDICES = (half_idx_front, half_idx_back)
RingAttention.CU_SEQLENS = cu_seqlens
if is_packed:
t, h, d = q.shape
else:
b, sq, h, d = q.shape
t = b * sq
# Be careful about GQA/MQA in reshape
q, k, v = [x.view(t, *x.shape[-2:]) for x in (q, k, v)]
if inner_ring_group is None or inter_ring_group is None:
# Use one ring if not specified
inner_ring_group = inter_ring_group = sp_group
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
# Attempt to achieve concurrent comm in the two-stream forward
local_kv_comms = [RingComm(inner_ring_group) for _ in range(2)]
inter_ring_comm = RingComm(inter_ring_group)
local_sp_size = dist.get_world_size(inner_ring_group)
local_sp_rank = dist.get_rank(inner_ring_group)
inter_ring_rank = dist.get_rank(inter_ring_group) if inter_ring_group is not sp_group else 0
num_rings = dist.get_world_size(inter_ring_group) if inter_ring_group is not sp_group else 1
# Non-contiguous indexing copies to a new contiguous tensor,
# so only do it once
if sp_rank != sp_size - 1:
q1 = q[half_idx_back]
# Pre-allocate double buffer for overlapping and receiving next step's inputs
kv_buffers = [torch.stack((k, v))] # (2, B, Sq, H, D)
kv_buffers.append(torch.empty_like(kv_buffers[0]))
# outputs
out = None
block_out = [None, None]
softmax_lse = [None, None]
block_softmax_lse = [None, None] # log sum exp, the denominator of softmax in attention
rng_states = [None for _ in range(sp_size)]
sp_streams = [torch.cuda.current_stream(), sp_stream]
def _forward(q, k, v, causal):
(
_,
_,
_,
_,
out,
softmax_lse,
_,
rng_state,
) = _flash_attn_forward(
q,
k,
v,
cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
max_seqlen_q if q.shape[0] == t else max_seqlen_half,
max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
causal=causal,
**misc_kwargs,
)
return out, softmax_lse, rng_state
def _local_ring_forward():
# (Hopefully) overlap output correction with next flash attn
for i in range(local_sp_size):
with torch.cuda.stream(sp_streams[i % 2]):
# Wait for current kv from prev rank
# NOTE: waiting outside the current stream will NOT correctly synchronize.
if i > 0:
local_kv_comms[(i + 1) % 2].wait()
# Avoid overwriting attn input when it shares mem with buffer
if not RingAttention.ATTN_DONE.query():
kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
if i < local_sp_size - 1:
local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
if i == 0:
# Compute with local KV; no mask
kv_block = kv_buffers[0]
q_block = q
(block_out[i % 2], block_softmax_lse[i % 2], rng_states[i]) = _forward( # (T, H, D) # (H, T)
q_block, kv_block[0], kv_block[1], causal=True
)
elif i <= local_sp_rank:
# Received the "surrounding" kv chunks
# Drop the second half of received kv
# (2, t // 2, H, D)
kv_block = kv_buffers[i % 2][:, half_idx_front]
q_block = q
(
block_out[i % 2], # (T, H, D)
block_softmax_lse[i % 2], # (H, T)
rng_states[i],
) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
else:
# Received the inner kv chunks
# Drop the first half of q
kv_block = kv_buffers[i % 2]
q_block = q1
(
block_out[i % 2], # (T, H, D)
block_softmax_lse[i % 2], # (H, T)
rng_states[i],
) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
RingAttention.ATTN_DONE.record()
block_softmax_lse[i % 2] = (
block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
) # (H, T) -> (T, H, 1)
assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1]
# Output and log sum exp correction. Ideally overlap this with the next flash attn kernel.
# In reality this always finishes before next flash attn; no need for extra sync.
if i == 0:
out = block_out[0]
softmax_lse = block_softmax_lse[0]
elif i <= local_sp_rank:
out, softmax_lse = _rescale_out_lse(
out, block_out[i % 2], softmax_lse, block_softmax_lse[i % 2]
)
else:
out[half_idx_back], softmax_lse[half_idx_back] = _rescale_out_lse(
out[half_idx_back], block_out[i % 2], softmax_lse[half_idx_back], block_softmax_lse[i % 2]
)
torch.cuda.current_stream().wait_stream(sp_stream)
return out, softmax_lse
def _other_ring_forward(ring_num_idx, out, softmax_lse):
# Loop through the inner ring after receiving
# all new KVs from the previous inner ring
for i in range(local_sp_size):
with torch.cuda.stream(sp_streams[i % 2]):
if not RingAttention.ATTN_DONE.query():
kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
if i < local_sp_size - 1:
local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
# Send & recv KV
if i > 0:
local_kv_comms[(i + 1) % 2].wait()
if ring_num_idx > inter_ring_rank:
kv_block = kv_buffers[i % 2]
(
block_out[i % 2],
block_softmax_lse[i % 2],
rng_states[i + local_sp_size * ring_num_idx],
) = _forward(q1, kv_block[0], kv_block[1], causal=False)
RingAttention.ATTN_DONE.record()
block_softmax_lse[i % 2] = (
block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
)
out[half_idx_back], softmax_lse[half_idx_back] = _rescale_out_lse(
out[half_idx_back], block_out[i % 2], softmax_lse[half_idx_back], block_softmax_lse[i % 2]
)
else:
kv_block = kv_buffers[i % 2][:, half_idx_front]
(
block_out[i % 2],
block_softmax_lse[i % 2],
rng_states[i + local_sp_size * ring_num_idx],
) = _forward(q, kv_block[0], kv_block[1], causal=False)
RingAttention.ATTN_DONE.record()
block_softmax_lse[i % 2] = (
block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
)
out, softmax_lse = _rescale_out_lse(
out, block_out[i % 2], softmax_lse, block_softmax_lse[i % 2]
)
torch.cuda.current_stream().wait_stream(sp_stream)
return out, softmax_lse
# Send and recv KV between rings at once to maximize NIC util.
inter_ring_kv = None
for ring_num_idx in range(num_rings):
if ring_num_idx > 0:
inter_ring_comm.wait()
# Reset indices
kv_buffers[0] = inter_ring_kv
if ring_num_idx < num_rings - 1:
if ring_num_idx == 0:
to_send = kv_buffers[0]
else:
# The last received KV
to_send = kv_buffers[(local_sp_size - 1) % 2]
inter_ring_kv = inter_ring_comm.send_recv(to_send)
if ring_num_idx == 0:
out, softmax_lse = _local_ring_forward()
else:
out, softmax_lse = _other_ring_forward(ring_num_idx, out, softmax_lse)
out = out.to(q.dtype)
if not is_packed:
out = out.view(b, sq, h, d)
q, k, v = [x.view(b, sq, *x.shape[-2:]) for x in (q, k, v)] # (T, H, D) -> (B, Sq, H, D)
softmax_lse = softmax_lse.squeeze(-1)
ctx.sp_group = sp_group
ctx.max_seqlen_q = ctx.max_seqlen_kv = max_seqlen
misc_kwargs["deterministic"] = deterministic
del misc_kwargs["return_softmax"]
ctx.misc_kwargs = misc_kwargs
ctx.is_packed = is_packed
ctx.kv_group = inner_ring_group
ctx.inter_kv_group = inter_ring_group
ctx.save_for_backward(
q,
k,
v,
out,
softmax_lse.transpose(0, 1).contiguous(), # (T, H) -> (H, T)
cu_seqlens_q,
cu_seqlens_kv,
half_idx_front,
half_idx_back,
*rng_states,
)
if return_softmax:
return out, softmax_lse
return out, None
def backward(ctx, dout, _):
"""
During backward, we accumulate q grads on each rank locally, but iterate kv and their grads
over all ranks for accumulation.
"""
(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_kv, half_idx_front, half_idx_back) = ctx.saved_tensors[:9]
rng_states = ctx.saved_tensors[9:]
is_packed = ctx.is_packed
max_seqlen_q = ctx.max_seqlen_q
max_seqlen_kv = ctx.max_seqlen_kv
cu_seqlens_half = cu_seqlens_q // 2
max_seqlen_half = max_seqlen_q // 2
misc_kwargs = ctx.misc_kwargs
del misc_kwargs["block_table"]
assert (
out.shape == dout.shape == q.shape
), f"out {out.shape} and dout {dout.shape} should have the same shape ({q.shape})."
if is_packed:
t, h, d = q.shape
else:
b, sq, h, d = q.shape
t = b * sq
q, k, v, out, dout = [x.view(t, *x.shape[-2:]) for x in (q, k, v, out, dout)]
# Sequence parallel args
sp_group = ctx.sp_group
local_kv_group = ctx.kv_group
inter_kv_group = ctx.inter_kv_group
local_sp_rank = dist.get_rank(sp_group)
sp_size = dist.get_world_size(sp_group)
# Using separate streams (pg) for concurrent kv and dkv comm may
# cause NCCL "software caused connection abort" here...
local_kv_comm = RingComm(local_kv_group)
local_dkv_comm = RingComm(local_kv_group)
inter_kv_comm = RingComm(inter_kv_group)
inter_dkv_comm = RingComm(inter_kv_group)
local_sp_size = dist.get_world_size(local_kv_group)
local_sp_rank = dist.get_rank(local_kv_group)
if dist.get_world_size(inter_kv_group) != sp_size:
num_rings = dist.get_world_size(inter_kv_group)
inter_ring_rank = dist.get_rank(inter_kv_group)
else:
num_rings = 1
inter_ring_rank = 0
if local_sp_rank != sp_size - 1:
softmax_lse1 = softmax_lse[:, half_idx_back]
dout = dout.contiguous()
# Double comm buffers for sending and receiving kv
kv_buffers = [torch.stack((k, v))] # (2, T, H, D)
kv_buffers.append(torch.empty_like(kv_buffers[0]))
dq = None # (T, H, D)
# Intermediate outputs
dq_block = torch.empty_like(q) # (T, H, D)
dk_block = torch.empty_like(k) # (T, H, D)
dv_block = torch.empty_like(v) # (T, H, D)
dkv_buffers = [torch.empty_like(kv, dtype=torch.float32) for kv in kv_buffers] # (T, H, D)
del k, v
def _backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, rng_state, causal):
_flash_attn_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
cu_seqlens_q if dq.shape[0] == t else cu_seqlens_half,
cu_seqlens_kv if dk.shape[0] == t else cu_seqlens_half,
max_seqlen_q if dq.shape[0] == t else max_seqlen_half,
max_seqlen_kv if dk.shape[0] == t else max_seqlen_half,
causal=causal,
rng_state=rng_state,
**misc_kwargs,
)
# NOTE: We avoid using two streams due to doubled buffers
# and that backward is more communication intensive.
def _local_ring_backward():
for i in range(local_sp_size):
if i > 0:
local_kv_comm.wait()
if i < local_sp_size - 1:
# Send kv to next rank for backward
local_kv_comm.send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
if i == 0:
# Backward with local kv
k_, v_ = kv_buffers[i % 2]
q_, dout_, out_ = q, dout, out
dq_, dk_, dv_ = dq_block, dk_block, dv_block
_backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_states[i], causal=True)
elif i <= local_sp_rank:
# Drop the second half of kv
# (T, H, D) -> (T // 2, H, D)
k_, v_ = [x[half_idx_front] for x in kv_buffers[i % 2]]
dk_, dv_ = [x[: t // 2] for x in (dk_block, dv_block)]
dq_, q_, out_, dout_ = (dq_block, q, out, dout)
_backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_states[i], causal=False)
else:
# Drop the first half of q
k_, v_ = kv_buffers[i % 2]
dk_, dv_ = dk_block, dv_block
q_, out_, dout_ = [x[half_idx_back] for x in (q, out, dout)]
dq_ = dq_block[: t // 2]
_backward(dout_, q_, k_, v_, out_, softmax_lse1, dq_, dk_, dv_, rng_states[i], causal=False)
# Accumulate grads
if i == 0:
dq = dq_block.float()
dkv_buffers[i % 2][0] = dk_block.float()
dkv_buffers[i % 2][1] = dv_block.float()
else:
# Accumulate local dq
if i <= local_sp_rank:
dq += dq_ # (T, H, D)
else:
dq[half_idx_back] += dq_
# Wait for mobile kv grad accumulators
local_dkv_comm.wait()
if i <= local_sp_rank:
# q blocks "surrounded" by kv blocks
dkv_buffers[i % 2][0][half_idx_front] += dk_
dkv_buffers[i % 2][1][half_idx_front] += dv_
else:
# q blocks "surrounding" kv blocks
dkv_buffers[i % 2][0] += dk_
dkv_buffers[i % 2][1] += dv_
local_dkv_comm.send_recv(send_tensor=dkv_buffers[i % 2], recv_tensor=dkv_buffers[(i + 1) % 2])
local_dkv_comm.wait()
dkv_recv = dkv_buffers[local_sp_size % 2]
dkv_send = dkv_buffers[(local_sp_size - 1) % 2]
return dq, dkv_recv, dkv_send
def _other_ring_backward(ring_num_idx, dq):
if ring_num_idx > inter_ring_rank:
# Indexing is expensive
q_, out_, dout_ = [x[half_idx_back] for x in (q, out, dout)]
else:
q_, out_, dout_ = (q, out, dout)
for i in range(local_sp_size):
if i > 0:
local_kv_comm.wait()
if i < local_sp_size - 1:
local_kv_comm.send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
rng_state = rng_states[i + local_sp_size * ring_num_idx]
if ring_num_idx > inter_ring_rank:
k_, v_ = kv_buffers[i % 2]
dk_, dv_ = dk_block, dv_block
dq_ = dq_block[: t // 2]
_backward(dout_, q_, k_, v_, out_, softmax_lse1, dq_, dk_, dv_, rng_state, causal=False)
dq[half_idx_back] += dq_
if i > 0:
local_dkv_comm.wait()
else:
inter_dkv_comm.wait()
dkv_buffers[i % 2][0] += dk_
dkv_buffers[i % 2][1] += dv_
else:
k_, v_ = [x[half_idx_front] for x in kv_buffers[i % 2]]
dk_, dv_ = [x[: t // 2] for x in (dk_block, dv_block)]
dq_ = dq_block
_backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_state, causal=False)
dq += dq_
if i > 0:
local_dkv_comm.wait()
else:
inter_dkv_comm.wait()
dkv_buffers[i % 2][0][half_idx_front] += dk_
dkv_buffers[i % 2][1][half_idx_front] += dv_
local_dkv_comm.send_recv(send_tensor=dkv_buffers[i % 2], recv_tensor=dkv_buffers[(i + 1) % 2])
local_dkv_comm.wait()
dkv_recv = dkv_buffers[local_sp_size % 2]
dkv_send = dkv_buffers[(local_sp_size - 1) % 2]
return dq, dkv_recv, dkv_send
inter_ring_kv = None
for ring_num_idx in range(num_rings):
if ring_num_idx > 0:
inter_kv_comm.wait()
kv_buffers[0] = inter_ring_kv
if ring_num_idx < num_rings - 1:
# Re-allocate a buffer in each inter-ring step
inter_ring_kv = inter_kv_comm.send_recv(kv_buffers[0])
if ring_num_idx == 0:
dq, dkv_recv, dkv_send = _local_ring_backward()
else:
dq, dkv_recv, dkv_send = _other_ring_backward(ring_num_idx, dq)
if num_rings > 1:
# Reuse the local buffers
inter_dkv_comm.send_recv(send_tensor=dkv_recv, recv_tensor=dkv_send)
# Reset indices
dkv_buffers[0] = dkv_send
dkv_buffers[1] = dkv_recv
if ring_num_idx == num_rings - 1:
inter_dkv_comm.wait()
dkv_recv = dkv_buffers[0]
dq, dk, dv = [x.to(q.dtype) for x in (dq, *dkv_recv)]
if not is_packed:
dq, dk, dv = [x.view(b, sq, *x.shape[-2:]) for x in (dq, dk, dv)]
return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None)
@staticmethod
def prepare_varlen_batch(
attention_mask: torch.Tensor,
sp_group: dist.ProcessGroup,
inputs_embeds: torch.Tensor = None,
position_ids: Optional[torch.Tensor] = None,
is_label: bool = False,
is_2d: bool = True,
):
"""
Preprocess a batch of padded sequence by splitting input sequence by sp_size
sequence-wise and packing them into one sequence. Updates the mask info accordingly.
Args:
attention_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked.
sp_group (dist.ProcessGroup): Process group for sequence parallelism
inputs_embeds (torch.Tensor): Input embeddings. Shape should be [B, Sq, ...]
position_ids (Optional[torch.Tensor], optional): Position ids of shape [Sq] or [1, Sq]. Defaults to None.
is_label (bool, optional): Whether inputs_embeds is instead a label tensor. If True, mask out the first
token of each sequence.
is_2d (bool, optional): Whether to return 2D outputs padded to max_seqlen // sp_size or flatten
the batch dim to a packed 1d sequence. Contingent on model forward shape definitions.
Returns:
inputs_embeds: Packed input embeddings of shape [B, Sq // sp_size, ...].
mask_info: A dictionary of mask info.
position_ids: Packed position ids of shape [..., Sq // sp_size].
"""
_load_varlen_helpers()
sp_size = dist.get_world_size(group=sp_group)
sp_rank = dist.get_rank(group=sp_group)
mask_info = {}
mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(attention_mask, return_indices=False)
# Unpad, split seq-wise, then pad back to (B, max_seqlen // sp_size)
# Split mask to compute local nonzero position indices
# (B, Sq) -> (B, max_seqlen // sp_size)
attention_mask = attention_mask[:, : mask_info["max_seqlen"]]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[:, : mask_info["max_seqlen"]]
inputs_embeds = split_varlen_zigzag(
inputs_embeds,
mask_info["cu_seqlens"],
sp_group,
mask_info["max_seqlen"],
is_2d=is_2d,
is_label=is_label,
)
attention_mask = split_varlen_zigzag(
attention_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_2d=is_2d
)
if position_ids is not None:
indices = torch.tensor([sp_rank, 2 * sp_size - sp_rank - 1], device=inputs_embeds.device)
position_ids = (
position_ids[..., : mask_info["max_seqlen"]] # unpad
.view(-1, sp_size * 2, mask_info["max_seqlen"] // (sp_size * 2))
.index_select(-2, indices)
.view(-1, mask_info["max_seqlen"] // sp_size)
)
mask_info["max_seqlen"] //= sp_size
mask_info["valid_indices"] = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
mask_info["cu_seqlens"] //= sp_size
mask_info["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
return inputs_embeds, mask_info, position_ids

View File

@ -200,9 +200,7 @@ class Linear1D_Col(ParallelModule):
# Matrix multiply. # Matrix multiply.
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
if self.seq_parallel_mode is None: if self.seq_parallel_mode == "split_gather":
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
elif self.seq_parallel_mode == "split_gather":
input_parallel = gather_forward_reducescatter_backward( input_parallel = gather_forward_reducescatter_backward(
input_parallel, self.process_group, self.seq_parallel_dim input_parallel, self.process_group, self.seq_parallel_dim
) )
@ -211,6 +209,8 @@ class Linear1D_Col(ParallelModule):
output_parallel = linear_gather_forward_reducescatter_backward( output_parallel = linear_gather_forward_reducescatter_backward(
input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
) )
else:
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
if self.gather_output: if self.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
@ -416,10 +416,7 @@ class Linear1D_Row(ParallelModule):
handle.wait() handle.wait()
output = torch.cat(output_parallel_list, dim=-1) output = torch.cat(output_parallel_list, dim=-1)
else: else:
if self.seq_parallel_mode is None: if self.seq_parallel_mode == "split_gather":
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
output = reduce_forward(output_parallel, self.process_group)
elif self.seq_parallel_mode == "split_gather":
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
output = reducescatter_forward_gather_backward( output = reducescatter_forward_gather_backward(
output_parallel, self.process_group, self.seq_parallel_dim output_parallel, self.process_group, self.seq_parallel_dim
@ -432,6 +429,9 @@ class Linear1D_Row(ParallelModule):
dim=self.seq_parallel_dim, dim=self.seq_parallel_dim,
ring=True, ring=True,
) )
else:
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
output = reduce_forward(output_parallel, self.process_group)
if not self.skip_bias_add: if not self.skip_bias_add:
if self.bias is not None: if self.bias is not None:

View File

@ -4,10 +4,15 @@ from torch.autograd import Function
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from colossalai.shardformer.layer._operation import reduce_forward
from colossalai.shardformer.shard import ShardConfig from colossalai.shardformer.shard import ShardConfig
from .utils import is_share_sp_tp
__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"] __all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"]
_IGNORE_IDX = -100
class DistCrossEntropy(Function): class DistCrossEntropy(Function):
r""" r"""
@ -26,11 +31,12 @@ class DistCrossEntropy(Function):
process_group: ProcessGroup, process_group: ProcessGroup,
vocab_size: int, vocab_size: int,
dtype=torch.float32, dtype=torch.float32,
mode="mean",
): ):
r""" r"""
Calculate the cross entropy loss before gather, the origin loss function is as follows: Calculate the cross entropy loss before gather, the origin loss function is as follows:
loss = -log(exp(x[class])/sum(exp(x[i])) loss = -log(exp(x[class])/sum(exp(x[i]))
and can be rewrite as: and can be rewriten as:
loss = log(sum(exp(x[i])) - x[class] loss = log(sum(exp(x[i])) - x[class]
To avoid the `nan` of log(sum(exp(x[i]))), we minus the max of x[i] To avoid the `nan` of log(sum(exp(x[i]))), we minus the max of x[i]
@ -44,12 +50,10 @@ class DistCrossEntropy(Function):
Returns: Returns:
:class:`torch.Tensor`: The cross entropy loss :class:`torch.Tensor`: The cross entropy loss
""" """
assert mode in ["mean", "sum"]
# get the max # get the max
logits_max = torch.max(vocab_logits, dim=-1)[0] logits_max = torch.max(vocab_logits, dim=-1)[0]
dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group) handle = dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group, async_op=True)
# minus the max to avoid the result of sum of exp is too large and the log is nan
vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
# mask the target in the local device # mask the target in the local device
rank = dist.get_rank(group=process_group) rank = dist.get_rank(group=process_group)
@ -70,24 +74,25 @@ class DistCrossEntropy(Function):
mask = (target < down_threshold) | (target >= up_threshold) mask = (target < down_threshold) | (target >= up_threshold)
masked_target = target.clone() - down_threshold masked_target = target.clone() - down_threshold
masked_target[mask] = 0 masked_target[mask] = 0
masked_target_1d = masked_target.view(-1).contiguous()
# minus the max to avoid the result of sum of exp is too large and the log is nan
handle.wait()
vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
# reshape the logits and target # reshape the logits and target
# reshape the vocab_logits to [bath_size * seq_len, vocab_size] # reshape the vocab_logits to [bath_size * seq_len, vocab_size]
# reshape the labels to [bath_size * seq_len] # reshape the labels to [bath_size * seq_len]
self_vocab_size = vocab_logits.size()[-1] self_vocab_size = vocab_logits.size()[-1]
logits_2d = vocab_logits.view(-1, self_vocab_size) logits_2d = vocab_logits.view(-1, self_vocab_size)
masked_target_1d = masked_target.view(-1)
# extract the x[class] and set the x[other device] to zero # extract the x[class] and set the x[other device] to zero
pred_logits_1d = logits_2d[ idx = torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device)
torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device), masked_target_1d pred_logits_1d = logits_2d[idx, masked_target_1d].contiguous()
]
pred_logits_1d = pred_logits_1d.clone().contiguous()
pred_logits = pred_logits_1d.view_as(target) pred_logits = pred_logits_1d.view_as(target)
pred_logits[mask] = 0.0 pred_logits[mask] = 0.0
# allreduce the get all x(i,y) # all-reduce to get full x[i, y]
dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group) handle = dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group, async_op=True)
exp_logits = vocab_logits exp_logits = vocab_logits
torch.exp(vocab_logits, out=exp_logits) torch.exp(vocab_logits, out=exp_logits)
sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32) sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32)
@ -95,23 +100,29 @@ class DistCrossEntropy(Function):
# calculate the loss # calculate the loss
# loss = log(sum(exp(x[i]))) - x[class] # loss = log(sum(exp(x[i]))) - x[class]
handle.wait()
loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits) loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits)
num_non_zero = torch.sum(loss != 0.0) if mode == "mean":
ctx.inv_num_non_zero = 1.0 / num_non_zero num_non_zero = torch.sum(loss != 0.0)
loss = torch.sum(loss).div_(num_non_zero) ctx.inv_num_non_zero = 1.0 / num_non_zero
loss = torch.sum(loss).div_(num_non_zero)
else:
loss = torch.sum(loss)
# calculate the softmax # calculate the softmax
exp_logits = exp_logits.div(sum_exp_logits.unsqueeze(dim=-1)).to(dtype) exp_logits = exp_logits.div(sum_exp_logits.unsqueeze(dim=-1)).to(dtype)
exp_logits[target == ignore_index] = 0.0 exp_logits[target == ignore_index] = 0.0
ctx.save_for_backward(exp_logits, mask, masked_target_1d) ctx.save_for_backward(exp_logits, mask, masked_target_1d)
ctx.dtype = dtype ctx.dtype = dtype
ctx.mode = mode
return loss return loss
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
# retrieve the saved tensors # retrieve the saved tensors
grad_output = grad_output * ctx.inv_num_non_zero if ctx.mode == "mean":
grad_output = grad_output * ctx.inv_num_non_zero
exp_logits, mask, masked_target_1d = ctx.saved_tensors exp_logits, mask, masked_target_1d = ctx.saved_tensors
# use exp logits as the input grad # use exp logits as the input grad
@ -123,55 +134,113 @@ class DistCrossEntropy(Function):
grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update
grad_logits.mul_(grad_output.unsqueeze(dim=-1)) grad_logits.mul_(grad_output.unsqueeze(dim=-1))
return grad_logits, None, None, None, None, None return grad_logits, None, None, None, None, None, None
def cross_entropy_1d( def cross_entropy_1d(
vocab_logits: torch.Tensor, vocab_logits: torch.Tensor,
labels: torch.Tensor, labels: torch.Tensor,
ignore_index: int = -100, ignore_index: int = _IGNORE_IDX,
process_group: ProcessGroup = None, process_group: ProcessGroup = None,
vocab_size: int = None, vocab_size: int = None,
dtype: torch.dtype = None, dtype: torch.dtype = None,
mode: str = "mean",
) -> torch.Tensor: ) -> torch.Tensor:
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype) return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype, mode)
def dist_cross_entropy( def dist_cross_entropy(
labels: torch.Tensor, labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
logits: torch.Tensor, logits: torch.Tensor, # [B, S, Vocab_size]
shard_config: ShardConfig, shard_config: ShardConfig,
out_features: int, out_features: int,
vocab_size: int, vocab_size: int,
dtype: torch.dtype, dtype: torch.dtype,
seq_dim: int = 1,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Helper to compute cross entropy loss for most shardformer models, Helper to compute cross entropy loss for most shardformer models supporting PP, TP and SP.
compatible with PP, TP and SP.
""" """
if labels is not None: # Split labels if not gather output
# Shift so that tokens < n predict n sp_group = shard_config.sequence_parallel_process_group
shift_logits = logits[..., :-1, :].contiguous() sp_rank = dist.get_rank(sp_group)
shift_labels = labels[..., 1:].contiguous() sp_size = shard_config.sequence_parallel_size
# Flatten the tokens sp_mode = shard_config.sequence_parallelism_mode
loss_fct = CrossEntropyLoss() parallel_output = shard_config.parallel_output
shift_labels = shift_labels.view(-1) is_tp = shard_config.enable_tensor_parallelism
shift_labels = shift_labels.to(shift_logits.device) is_packed = labels.dim() == 2
if shard_config.enable_tensor_parallelism and shard_config.parallel_output: if is_packed:
# Cross entropy with all-reduce for TP bs, seq_len = labels.shape
new_vocab_size = logits.shape[-1] else:
shift_logits = shift_logits.view(-1, new_vocab_size) # padded sequence
loss = cross_entropy_1d( seq_len = labels.shape[-1]
shift_logits, logits = logits.reshape(-1, *logits.shape[2:])
shift_labels, seq_dim = 0
process_group=shard_config.tensor_parallel_process_group,
vocab_size=out_features,
dtype=dtype,
)
else:
# NOTE if use TP and not parallel_output, the output is gathered.
# see VocabParallelLMHead1D
shift_logits = shift_logits.view(-1, vocab_size)
loss = loss_fct(shift_logits, shift_labels)
return loss # Shift labels to predict the next token, and remove the tail logit predicting <EOS>
is_sp = sp_size > 1 and (not is_share_sp_tp(sp_mode))
split_labels_here = seq_len // sp_size == logits.size(seq_dim) # ring attn splits labels before forward
if sp_mode == "ring_attn":
# For Zigzag Ring Attention, labels should've been split and
# shifted by RingAttention.prepare_varlen_batch()
if sp_rank == 0:
logits = logits[..., :-1, :]
logits = torch.cat([logits, torch.full_like(logits[:, :1, :], _IGNORE_IDX)], dim=seq_dim)
elif is_sp:
# Shift only once: either before splitting or in the last rank without splitting
if split_labels_here or (sp_rank == sp_size - 1):
labels = labels[..., 1:]
if split_labels_here:
labels = labels.split(seq_len // sp_size, dim=-1)[sp_rank]
if sp_rank == sp_size - 1:
logits = logits[..., :-1, :]
# Pad logits and labels to the same shape across all ranks for TP all_reduce
if is_tp and parallel_output:
# If is packed sequence (label dim is 1), then each seq already has the end label token padded.
# torch.cat is faster than F.pad...
pad_shape = (logits.shape[0], 1, *logits.shape[2:]) if is_packed else (1, *logits.shape[1:])
padding = torch.full(pad_shape, _IGNORE_IDX, dtype=logits.dtype, device=logits.device)
logits = torch.cat([logits, padding], dim=seq_dim)
pad_shape = (labels.shape[0], 1) if is_packed else (1,)
padding = torch.full(pad_shape, _IGNORE_IDX, dtype=labels.dtype, device=labels.device)
labels = torch.cat([labels, padding], dim=seq_dim)
else:
labels = labels[..., 1:]
logits = logits[..., :-1, :]
labels = labels.contiguous()
logits = logits.contiguous()
num_nonzero = (labels != _IGNORE_IDX).sum()
assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}"
# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=_IGNORE_IDX, reduction="sum")
labels = labels.view(-1)
if is_tp and parallel_output:
# Cross entropy with all-reduce for TP
new_vocab_size = logits.shape[-1]
logits = logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
logits,
labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=out_features,
dtype=dtype,
mode="sum",
)
else:
# NOTE if use TP and not parallel_output, the output is gathered in VocabParallelLMHead1D
logits = logits.view(-1, vocab_size)
loss = loss_fct(logits, labels)
# Reduce loss instead of gathering logits over seq dim for savings
if split_labels_here or sp_mode == "ring_attn":
# Get the global non-zero count
loss = torch.stack((loss, num_nonzero))
# Rescale to offset the grad / (DP * SP) in HybridParallelPlugin
loss = reduce_forward(loss, sp_group, grad_scale=sp_size)
loss, num_nonzero = loss[0], loss[1].detach()
loss = (loss / num_nonzero).squeeze()
return loss

View File

@ -1,5 +1,5 @@
from contextlib import contextmanager from contextlib import contextmanager
from typing import List from typing import List, Optional, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -289,3 +289,199 @@ def create_randomizer_with_offset(
Randomizer.increment_index() Randomizer.increment_index()
return Randomizer(seed=base_seed) return Randomizer(seed=base_seed)
def split_batch_zigzag(
batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim: int = 1, is_label: bool = False
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""
Split the input along the sequence dimension for Ring Attention. Naively spliting the attention mask
in the causal setting will result in the preceding ranks having much less workload.
We split after "folding" the 2D attention mask in half (https://github.com/zhuzilin/ring-flash-attention/issues/2).
For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |.
Args:
batch (List[torch.Tensor] or Tensor): The input tensor(s) to split.
sp_group (ProcessGroup): The process group for sequence parallelism.
seq_dim (int): The sequence dimension to split.
is_label (bool): If True, mask and shift the tensor for next token prediction.
"""
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
if isinstance(batch, torch.Tensor):
batch = [batch]
seq_dim = seq_dim if seq_dim != -1 else batch[0].dim() - 1
if sp_size > 1:
for idx, tensor in enumerate(batch):
assert (
tensor.shape[seq_dim] // (sp_size * 2) > 1 and tensor.shape[seq_dim] % (sp_size * 2) == 0
), f"Bro, the seq length {tensor.shape[seq_dim]} for tensor {idx} can't be split by {sp_size * 2}!"
if is_label:
assert tensor.dim() == 2, "Label shape should be (B, Seqlen)"
tensor = torch.cat([tensor[:, 1:], torch.full_like(tensor[:, :1], -100)], dim=1)
tensor = tensor.view(
*tensor.shape[:seq_dim],
2 * sp_size,
tensor.shape[seq_dim] // (2 * sp_size),
*tensor.shape[seq_dim + 1 :],
)
indices = torch.tensor([sp_rank, 2 * sp_size - 1 - sp_rank], device=tensor.device)
tensor = tensor.index_select(seq_dim, indices).contiguous()
# (B, 2, Sq // (2 * sp_size), ...) -> (B, Sq // sp_size, ...)
batch[idx] = tensor.view(*tensor.shape[:seq_dim], -1, *tensor.shape[seq_dim + 2 :])
if len(batch) == 1:
return batch[0]
return batch
def split_varlen_zigzag(
batch: Union[List[torch.Tensor], torch.Tensor],
cu_seqlens: torch.Tensor,
sp_group: ProcessGroup,
max_seqlen: int = 0,
is_2d: bool = False,
is_label: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
"""Split each sequence in a batch of packed sequences in a zigzag fashion.
For each tensor in batch, return packed sequences if is_2d is False;
else return a padded batch of sequences.
Args:
batch (List[torch.Tensor]): Packed sequences of shape (B * Sq, ...), or (B, Sq, ...) if is_2d.
cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape (B + 1) before splitting.
sp_group (ProcessGroup): The process group for sequence parallelism.
max_seqlen (int): The maximum sequence length in the batch before splitting.
is_2d (bool): If True, then input has batch size and sequence length split into two dimensions.
is_label (bool): If True, mask out the first token in each sequence (<Start of Sentence>).
Returns:
batch (List[torch.Tensor]): Packed sequences of shape (B * max_seqlen // sp_size)
or (B, max_seqlen // sp_size, ...) if is_2d
"""
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
if is_2d:
assert max_seqlen > 0, "max_seqlen must be provided for 2D input"
if isinstance(batch, torch.Tensor):
batch = [batch]
for i, packed_seq in enumerate(batch):
device = packed_seq.device
dtype = packed_seq.dtype
if is_2d:
assert max_seqlen % (sp_size * 2) == 0
# Recreate a padded tensor with the new max seqlen
shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:])
local_seq = torch.zeros(shape, dtype=dtype, device=device)
else:
total_seqlen = cu_seqlens[-1]
assert (
total_seqlen % (2 * sp_size) == 0
), f"total_seqlen {total_seqlen} must be divisible by 2 * sp_size = {2 * sp_size}"
local_seq = []
for j in range(len(cu_seqlens) - 1):
start, end = cu_seqlens[j], cu_seqlens[j + 1]
seqlen = end - start
assert (
seqlen % (2 * sp_size) == 0
), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting"
if is_2d:
seq = packed_seq[j][:seqlen]
if is_label:
# Shift one position to the right for next token prediction
seq = torch.cat([seq[1:], torch.tensor([-100], dtype=dtype, device=device)])
seq = seq.chunk(2 * sp_size, dim=0)
half = seqlen // sp_size // 2
local_seq[j][:half] = seq[sp_rank]
local_seq[j][half : seqlen // sp_size] = seq[2 * sp_size - 1 - sp_rank]
else:
seq = packed_seq[start:end]
if is_label:
seq = torch.cat(seq[1:], torch.tensor([-100], dtype=dtype, device=device))
seq = seq.chunk(sp_size * 2)
local_seq.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]])
if is_2d:
batch[i] = local_seq.contiguous()
else:
batch[i] = torch.cat(local_seq, dim=0)
if len(batch) == 1:
batch = batch[0]
return batch
def is_share_sp_tp(sp_mode: str):
"""sp_mode "ring" and "split_gather" use the TP group as SP group
to split both the vocab and sequence, so we must gather the sequence
to correctly get logits at each positions.
"""
return sp_mode in ["ring", "split_gather"]
class RingComm:
def __init__(self, process_group: dist.ProcessGroup):
self._process_group = process_group
self._ops = []
self.rank = dist.get_rank(self._process_group)
self.world_size = dist.get_world_size(self._process_group)
self._reqs = []
self.send_rank = (self.rank + 1) % self.world_size
self.recv_rank = (self.rank - 1) % self.world_size
self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)
def send_recv(
self,
send_tensor: torch.Tensor,
recv_tensor: Optional[torch.Tensor] = None,
commit: bool = True,
) -> torch.Tensor:
if recv_tensor is None:
res = torch.empty_like(send_tensor)
else:
res = recv_tensor
# looks like batch_isend_irecv doesn't deadlock even
# when we don't swap send recv ops based on rank
send_op = dist.P2POp(dist.isend, send_tensor, self.send_rank, group=self._process_group)
recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
self._ops.extend([send_op, recv_op])
if commit:
self._reqs = dist.batch_isend_irecv(self._ops)
return res
def commit(self):
assert len(self._ops) > 0, "No ops to commit"
self._reqs = dist.batch_isend_irecv(self._ops)
def wait(self):
assert len(self._reqs) > 0, "No requests to wait for"
for req in self._reqs:
req.wait()
self._reqs = []
self._ops = []
@torch.jit.script
def get_half_index(cu_seqlens, *, front: bool):
index = torch.zeros(cu_seqlens[-1], dtype=torch.bool, device=cu_seqlens.device)
for i in range(len(cu_seqlens) - 1):
start, end = cu_seqlens[i], cu_seqlens[i + 1]
if front:
end = (start + end) // 2
else:
start = (start + end) // 2
index[start:end] = True
return index

View File

@ -26,6 +26,8 @@ from colossalai.shardformer.shard import ShardConfig
from ..layer import ColoAttention, dist_cross_entropy from ..layer import ColoAttention, dist_cross_entropy
_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
class CommandPipelineForwards: class CommandPipelineForwards:
""" """
@ -349,7 +351,7 @@ class CommandPipelineForwards:
return {"hidden_states": hidden_states} return {"hidden_states": hidden_states}
def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -362,7 +364,7 @@ def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
**kwargs, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
if sp_mode is not None: if sp_mode is not None:
assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode" assert sp_mode in _SUPPORTED_SP_MODE, f"SP mode {sp_mode} is not supported by {type(self)} yet"
assert (sp_size is not None) and ( assert (sp_size is not None) and (
sp_group is not None sp_group is not None
), "Must specify sp_size and sp_group for sequence parallel" ), "Must specify sp_size and sp_group for sequence parallel"
@ -459,7 +461,7 @@ def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
return forward return forward
def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
def forward( def forward(

View File

@ -1,8 +1,9 @@
import math import math
import warnings import warnings
from typing import List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import torch import torch
import torch.distributed
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
@ -24,14 +25,14 @@ from transformers.models.llama.modeling_llama import (
from transformers.utils import logging from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import ( from colossalai.shardformer.layer import AttnMaskType
all_to_all_comm, from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward
gather_forward_split_backward, from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
split_forward_gather_backward,
)
from colossalai.shardformer.shard import ShardConfig from colossalai.shardformer.shard import ShardConfig
from ..layer import ColoAttention, dist_cross_entropy from ..layer import ColoAttention, RingAttention, dist_cross_entropy
_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
class LlamaPipelineForwards: class LlamaPipelineForwards:
@ -57,6 +58,10 @@ class LlamaPipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None, hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None, shard_config: ShardConfig = None,
# Split output only when computing cross entropy using llama_for_causal_lm_forward
# or get_lm_forward_with_dist_cross_entropy
# Default to True to avoid bug when calling classification forward from huggingface
force_sp_output_gather: bool = True,
): ):
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -97,7 +102,7 @@ class LlamaPipelineForwards:
sp_group = shard_config.sequence_parallel_process_group sp_group = shard_config.sequence_parallel_process_group
sp_size = shard_config.sequence_parallel_size sp_size = shard_config.sequence_parallel_size
if sp_mode == "all_to_all" and not stage_manager.is_first_stage(): if sp_mode == "all_to_all" and not stage_manager.is_first_stage():
# For correct positions ids. The states will be gather along the seq dim in the attention layer later. # For generating full positions ids, as the states will be gather along the seq dim in the attention layer later.
seq_length *= sp_size seq_length *= sp_size
past_seen_tokens = 0 past_seen_tokens = 0
@ -127,22 +132,36 @@ class LlamaPipelineForwards:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
# embed positions, for the first stage, hidden_states is the input embeddings, # embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage # for the other stages, hidden_states is the output of the previous stage
if shard_config.enable_flash_attention: if not stage_manager.is_first_stage() and sp_mode == "ring_attn":
_, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group)
elif shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor # in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
attention_mask = ColoAttention.prepare_attn_kwargs( attn_kwargs = ColoAttention.prepare_attn_kwargs(
mask_shape, mask_shape,
hidden_states.dtype, hidden_states.dtype,
hidden_states.device, hidden_states.device,
q_padding_mask=attention_mask, q_padding_mask=attention_mask,
is_causal=True, is_causal=True,
invert=(sp_mode != "ring_attn"),
) )
else: else:
attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position) attn_kwargs = self._update_causal_mask(attention_mask, hidden_states, cache_position)
# Support SP + PP # Support SP + PP
# TODO: support padded casual cu_seqlens across stages
if stage_manager.is_first_stage(): if stage_manager.is_first_stage():
if sp_mode in ["ring", "split_gather"]: # Ring Attention zigzag batch processing
if sp_mode == "ring_attn":
assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
attention_mask, sp_group, hidden_states, position_ids
)
else:
hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group)
elif is_share_sp_tp(sp_mode):
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group) hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
elif sp_mode == "all_to_all": elif sp_mode == "all_to_all":
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size) hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
@ -177,12 +196,11 @@ class LlamaPipelineForwards:
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
if idx - start_idx < num_ckpt_layers: if idx - start_idx < num_ckpt_layers:
layer_outputs = self._gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__, decoder_layer.__call__,
hidden_states, hidden_states,
attention_mask, attn_kwargs,
position_ids, position_ids,
past_key_values, past_key_values,
output_attentions, output_attentions,
@ -192,14 +210,13 @@ class LlamaPipelineForwards:
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
hidden_states, hidden_states,
attention_mask=attention_mask, attention_mask=attn_kwargs,
position_ids=position_ids, position_ids=position_ids,
past_key_value=past_key_values, past_key_value=past_key_values,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
if use_cache: if use_cache:
@ -209,10 +226,8 @@ class LlamaPipelineForwards:
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather": if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
# add hidden states from the last decoder layer # add hidden states from the last decoder layer
if output_hidden_states: if output_hidden_states:
@ -298,6 +313,15 @@ class LlamaPipelineForwards:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False output_hidden_states = False
if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output:
# Split labels in a zigzag fashion too
sp_group = shard_config.sequence_parallel_process_group
if attention_mask.bool().all():
labels = split_batch_zigzag(labels, sp_group, seq_dim=1)
else:
# [B, max_seqlen // sp_size]
labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = LlamaPipelineForwards.llama_model_forward( outputs = LlamaPipelineForwards.llama_model_forward(
self.model, self.model,
@ -315,6 +339,7 @@ class LlamaPipelineForwards:
hidden_states=hidden_states, hidden_states=hidden_states,
stage_index=stage_index, stage_index=stage_index,
shard_config=shard_config, shard_config=shard_config,
force_sp_output_gather=False,
) )
past_key_values = None past_key_values = None
@ -457,11 +482,11 @@ class LlamaPipelineForwards:
return {"hidden_states": hidden_states} return {"hidden_states": hidden_states}
def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[Union[torch.Tensor, Dict]] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None, past_key_value: Optional[Cache] = None,
output_attentions: bool = False, output_attentions: bool = False,
@ -470,7 +495,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
**kwargs, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
if sp_mode is not None: if sp_mode is not None:
assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode" assert sp_mode in _SUPPORTED_SP_MODE, f"SP mode {sp_mode} is not supported by {type(self)} yet"
assert (sp_size is not None) and ( assert (sp_size is not None) and (
sp_group is not None sp_group is not None
), "Must specify sp_size and sp_group for sequence parallel" ), "Must specify sp_size and sp_group for sequence parallel"
@ -481,7 +506,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
# sp: modify sp_len when sequence parallel mode is ring # sp: modify sp_len when sequence parallel mode is ring
if sp_mode in ["split_gather", "ring"]: if is_share_sp_tp(sp_mode):
q_len *= sp_size q_len *= sp_size
if self.config.pretraining_tp > 1: if self.config.pretraining_tp > 1:
@ -526,6 +551,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
) )
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, position_ids) cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
@ -537,12 +563,21 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
key_states = repeat_kv(key_states, self.num_key_value_groups) key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups)
if shard_config.enable_flash_attention: if sp_mode == "ring_attn":
attn_output = RingAttention.attention(
query_states,
key_states,
value_states,
sp_group,
**attention_mask,
inner_ring_size=shard_config.inner_ring_size,
)
elif shard_config.enable_flash_attention:
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
else: else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError( raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
@ -588,7 +623,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
return forward return forward
def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
def forward( def forward(
@ -603,6 +638,10 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
# Split output only when computing cross entropy using llama_for_causal_lm_forward
# or get_lm_forward_with_dist_cross_entropy
# Default to True to avoid bug when calling classification forward from huggingface
force_sp_output_gather: bool = True,
) -> Union[Tuple, BaseModelOutputWithPast]: ) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
@ -629,32 +668,45 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=
past_seen_tokens = 0 past_seen_tokens = 0
seq_len = inputs_embeds.shape[1] seq_len = inputs_embeds.shape[1]
batch_size = inputs_embeds.shape[0]
if use_cache: # kept for BC (cache positions) if use_cache: # kept for BC (cache positions)
if not isinstance(past_key_values, StaticCache): if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length() past_seen_tokens = past_key_values.get_seq_length()
if cache_position is None: if cache_position is None:
if isinstance(past_key_values, StaticCache): if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.") raise ValueError("cache_position is a required argument when using StaticCache.")
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device) cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device)
if position_ids is None: if position_ids is None:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
# in this case, attention_mask is a dict rather than a tensor
if shard_config.enable_flash_attention: if shard_config.enable_flash_attention:
mask_shape = (inputs_embeds.shape[0], 1, seq_len, past_seen_tokens + seq_len) mask_shape = (batch_size, 1, seq_len, past_seen_tokens + seq_len)
attention_mask = ColoAttention.prepare_attn_kwargs( attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
mask_shape, mask_shape,
inputs_embeds.dtype, inputs_embeds.dtype,
inputs_embeds.device, inputs_embeds.device,
q_padding_mask=attention_mask, q_padding_mask=attention_mask,
is_causal=True, is_causal=True,
invert=(sp_mode != "ring_attn"),
) )
else:
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
if sp_mode in ["ring", "split_gather"]: else:
attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
# Ring Attention zigzag batch processing
if sp_mode == "ring_attn":
assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
inputs_embeds, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
attention_mask, sp_group, inputs_embeds, position_ids
)
else:
inputs_embeds, position_ids = split_batch_zigzag([inputs_embeds, position_ids], sp_group)
attn_kwargs = {"attention_mask_type": attn_kwargs["attention_mask_type"]} # drop redundant tensors
elif is_share_sp_tp(sp_mode):
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
elif sp_mode == "all_to_all": elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
@ -672,7 +724,7 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=
layer_outputs = self._gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__, decoder_layer.__call__,
hidden_states, hidden_states,
attention_mask, attn_kwargs,
position_ids, position_ids,
past_key_values, past_key_values,
output_attentions, output_attentions,
@ -683,7 +735,7 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
hidden_states, hidden_states,
attention_mask=attention_mask, attention_mask=attn_kwargs,
position_ids=position_ids, position_ids=position_ids,
past_key_value=past_key_values, past_key_value=past_key_values,
output_attentions=output_attentions, output_attentions=output_attentions,
@ -700,11 +752,9 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=
all_self_attns += (layer_outputs[1],) all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
# Cases that don't support parallelizing cross entropy computation along sequence
if sp_mode == "ring" or sp_mode == "split_gather": if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather:
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
# add hidden states from the last decoder layer # add hidden states from the last decoder layer
if output_hidden_states: if output_hidden_states:
@ -777,6 +827,15 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output:
# Special processing: Split labels in a zigzag fashion too
sp_group = shard_config.sequence_parallel_process_group
if attention_mask.bool().all():
labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True)
else:
# [B, max_seq_len // sp_size]
labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model( outputs = self.model(
input_ids=input_ids, input_ids=input_ids,
@ -789,6 +848,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position, cache_position=cache_position,
force_sp_output_gather=False,
) )
hidden_states = outputs[0] hidden_states = outputs[0]
@ -799,7 +859,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
else: else:
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
logits = logits.float() logits = logits.float()
loss = dist_cross_entropy( loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
) )

View File

@ -75,6 +75,7 @@ class Policy(ABC):
def __init__(self) -> None: def __init__(self) -> None:
self.shard_config: Optional[ShardConfig] = None self.shard_config: Optional[ShardConfig] = None
self.model: Optional[Module] = None self.model: Optional[Module] = None
self.is_causal = None # Whether we're doing causal lm, i.e. using cross entropy
def set_model(self, model: nn.Module) -> None: def set_model(self, model: nn.Module) -> None:
r""" r"""

View File

@ -69,13 +69,18 @@ class CommandPolicy(Policy):
sp_size = self.shard_config.sequence_parallel_size or None sp_size = self.shard_config.sequence_parallel_size or None
sp_group = self.shard_config.sequence_parallel_process_group or None sp_group = self.shard_config.sequence_parallel_process_group or None
sp_partial_derived = sp_mode in ["split_gather", "ring"] sp_partial_derived = sp_mode in ["split_gather", "ring"]
if sp_mode == "ring_attn" and not self.is_causal:
raise ValueError("Ring attention is only meant for causal language modeling.")
tp_size = self.shard_config.tensor_parallel_size or None
num_q_heads = self.model.config.num_attention_heads
num_kv_heads = getattr(self.model.config, "num_key_value_heads", None)
if sp_mode == "all_to_all": if sp_mode == "all_to_all":
decoder_attribute_replacement = { num_q_heads //= sp_size
"num_heads": self.model.config.num_attention_heads // sp_size, decoder_attribute_replacement = {"num_heads": num_q_heads}
} if num_kv_heads:
if getattr(self.model.config, "num_key_value_heads", False): num_kv_heads //= sp_size
decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads
policy[attn_cls] = ModulePolicyDescription( policy[attn_cls] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement, attribute_replacement=decoder_attribute_replacement,
@ -104,21 +109,18 @@ class CommandPolicy(Policy):
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
assert ( assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 num_q_heads % tp_size == 0
), f"The number of attention heads must be divisible by tensor parallel size." ), f"The number of attention heads must be divisible by tensor parallel size."
if hasattr(self.model.config, "num_key_value_heads"): if hasattr(self.model.config, "num_key_value_heads"):
assert ( assert (
self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size num_kv_heads >= tp_size and num_kv_heads % tp_size == 0
and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size." ), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size."
decoder_attribute_replacement = { decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "self_attn.hidden_size": self.model.config.hidden_size // tp_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, "self_attn.num_heads": num_q_heads // tp_size,
} }
if getattr(self.model.config, "num_key_value_heads", False): if getattr(self.model.config, "num_key_value_heads", False):
decoder_attribute_replacement["self_attn.num_key_value_heads"] = ( decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads // tp_size
self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
)
policy[CohereDecoderLayer] = ModulePolicyDescription( policy[CohereDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement, attribute_replacement=decoder_attribute_replacement,
@ -290,10 +292,11 @@ class CommandForCausalLMPolicy(CommandPolicy):
def module_policy(self): def module_policy(self):
from transformers import CohereForCausalLM from transformers import CohereForCausalLM
self.is_causal = True
policy = super().module_policy() policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm # add a new item for causal lm
new_item = { new_item = {
CohereForCausalLM: ModulePolicyDescription( CohereForCausalLM: ModulePolicyDescription(
sub_module_replacement=[ sub_module_replacement=[

View File

@ -298,7 +298,7 @@ class DeepseekForCausalLMPolicy(DeepseekPolicy):
policy = super().module_policy() policy = super().module_policy()
# TODO: assign pg mesh from plugin to all modules # TODO: assign pg mesh from plugin to all modules
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm # add a new item for causal lm
new_item = { new_item = {
"DeepseekForCausalLM": ModulePolicyDescription( "DeepseekForCausalLM": ModulePolicyDescription(
sub_module_replacement=[ sub_module_replacement=[

View File

@ -69,13 +69,20 @@ class LlamaPolicy(Policy):
sp_size = self.shard_config.sequence_parallel_size or None sp_size = self.shard_config.sequence_parallel_size or None
sp_group = self.shard_config.sequence_parallel_process_group or None sp_group = self.shard_config.sequence_parallel_process_group or None
sp_partial_derived = sp_mode in ["split_gather", "ring"] sp_partial_derived = sp_mode in ["split_gather", "ring"]
if sp_mode == "ring_attn" and not self.is_causal:
raise ValueError("Ring attention is only meant for causal language modeling.")
tp_size = self.shard_config.tensor_parallel_size
# Modified by SP and TP
num_q_heads = self.model.config.num_attention_heads
num_kv_heads = getattr(self.model.config, "num_key_value_heads", None)
if sp_mode == "all_to_all": if sp_mode == "all_to_all":
decoder_attribute_replacement = { num_q_heads //= sp_size
"num_heads": self.model.config.num_attention_heads // sp_size, decoder_attribute_replacement = {"num_heads": num_q_heads}
} if num_kv_heads:
if getattr(self.model.config, "num_key_value_heads", False): num_kv_heads //= sp_size
decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads
policy[attn_cls] = ModulePolicyDescription( policy[attn_cls] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement, attribute_replacement=decoder_attribute_replacement,
@ -104,21 +111,20 @@ class LlamaPolicy(Policy):
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
assert ( assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 num_q_heads % tp_size == 0
), f"The number of attention heads must be divisible by tensor parallel size." ), f"The number of attention heads must be divisible by tensor parallel size."
if hasattr(self.model.config, "num_key_value_heads"): if hasattr(self.model.config, "num_key_value_heads"):
assert ( assert (
self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size num_kv_heads >= tp_size and num_kv_heads % tp_size == 0
and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size." ), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size."
num_q_heads //= tp_size
decoder_attribute_replacement = { decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "self_attn.hidden_size": self.model.config.hidden_size // tp_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, "self_attn.num_heads": num_q_heads,
} }
if getattr(self.model.config, "num_key_value_heads", False): if getattr(self.model.config, "num_key_value_heads", False):
decoder_attribute_replacement["self_attn.num_key_value_heads"] = ( num_kv_heads //= tp_size
self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads
)
policy[LlamaDecoderLayer] = ModulePolicyDescription( policy[LlamaDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement, attribute_replacement=decoder_attribute_replacement,
@ -295,10 +301,11 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
def module_policy(self): def module_policy(self):
from transformers import LlamaForCausalLM from transformers import LlamaForCausalLM
self.is_causal = True
policy = super().module_policy() policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm # add a new item for causal lm
new_item = { new_item = {
LlamaForCausalLM: ModulePolicyDescription( LlamaForCausalLM: ModulePolicyDescription(
sub_module_replacement=[ sub_module_replacement=[
@ -313,10 +320,6 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
], ],
) )
} }
if self.shard_config.parallel_output:
new_item[LlamaForCausalLM].method_replacement = {
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
}
else: else:
new_item = { new_item = {
LlamaForCausalLM: ModulePolicyDescription( LlamaForCausalLM: ModulePolicyDescription(
@ -336,7 +339,11 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
self.set_pipeline_forward( self.set_pipeline_forward(
model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy
) )
elif self.shard_config.enable_tensor_parallelism or self.shard_config.enable_sequence_parallelism:
# Compute loss distributedly along the sequence dimension
new_item[LlamaForCausalLM].method_replacement = {
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
}
return policy return policy
def get_held_layers(self) -> List[Module]: def get_held_layers(self) -> List[Module]:

View File

@ -271,7 +271,7 @@ class MistralForCausalLMPolicy(MistralPolicy):
policy = super().module_policy() policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm # add a new item for causal lm
new_item = { new_item = {
MistralForCausalLM: ModulePolicyDescription( MistralForCausalLM: ModulePolicyDescription(
sub_module_replacement=[ sub_module_replacement=[

View File

@ -275,7 +275,7 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
policy = super().module_policy() policy = super().module_policy()
# TODO: assign pg mesh from plugin to all modules # TODO: assign pg mesh from plugin to all modules
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm # add a new item for causal lm
new_item = { new_item = {
MixtralForCausalLM: ModulePolicyDescription( MixtralForCausalLM: ModulePolicyDescription(
sub_module_replacement=[ sub_module_replacement=[

View File

@ -313,7 +313,7 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
setattr(self.shard_config, "causal_lm", True) setattr(self.shard_config, "causal_lm", True)
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm # add a new item for causal lm
new_item = { new_item = {
Qwen2ForCausalLM: ModulePolicyDescription( Qwen2ForCausalLM: ModulePolicyDescription(
sub_module_replacement=[ sub_module_replacement=[

View File

@ -10,7 +10,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from .grad_ckpt_config import GradientCheckpointConfig from .grad_ckpt_config import GradientCheckpointConfig
__all__ = ["ShardConfig"] __all__ = ["ShardConfig"]
SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"] SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all", "ring_attn"]
@dataclass @dataclass
@ -29,6 +29,8 @@ class ShardConfig:
enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False. enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None. gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
parallel_output (bool): For TP: whether to use parallelize cross entropy computation along the feature dim.
For SP: set to True to NOT gather the output along the seq dim.
""" """
tensor_parallel_process_group: Optional[ProcessGroup] = None tensor_parallel_process_group: Optional[ProcessGroup] = None
@ -47,10 +49,11 @@ class ShardConfig:
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
extra_kwargs: Dict[str, Any] = field(default_factory=dict) extra_kwargs: Dict[str, Any] = field(default_factory=dict)
# For ring attention
inner_ring_size: Optional[int] = None
# for moe related # for moe related
moe_dp_group: Optional[ProcessGroup] = None moe_dp_group: Optional[ProcessGroup] = None
ep_group: Optional[ProcessGroup] = None ep_group: Optional[ProcessGroup] = None
# pipeline_parallel_size: int # pipeline_parallel_size: int
# data_parallel_size: int # data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
@ -80,9 +83,9 @@ class ShardConfig:
self.enable_tensor_parallelism self.enable_tensor_parallelism
), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True" ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True"
elif self.sequence_parallelism_mode in ["all_to_all"]: elif self.sequence_parallelism_mode in ["all_to_all"]:
assert ( # assert (
not self.enable_tensor_parallelism # not self.enable_tensor_parallelism
), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False" # ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False"
if self.enable_sequence_overlap: if self.enable_sequence_overlap:
self.enable_sequence_overlap = False self.enable_sequence_overlap = False
warnings.warn( warnings.warn(

View File

@ -28,6 +28,7 @@ warnings.filterwarnings("ignore")
# Constants # Constants
# ============================== # ==============================
# We have lots of llamas for your choice!
MODEL_CONFIGS = { MODEL_CONFIGS = {
"100m": LlamaConfig( "100m": LlamaConfig(
max_position_embeddings=4096, max_position_embeddings=4096,
@ -36,6 +37,7 @@ MODEL_CONFIGS = {
intermediate_size=2048, intermediate_size=2048,
hidden_size=1024, hidden_size=1024,
), ),
"5b": LlamaConfig(max_position_embeddings=4096, num_key_value_heads=8),
"7b": LlamaConfig(max_position_embeddings=4096), "7b": LlamaConfig(max_position_embeddings=4096),
"13b": LlamaConfig( "13b": LlamaConfig(
hidden_size=5120, hidden_size=5120,
@ -68,9 +70,6 @@ def main():
default="gemini", default="gemini",
help="Choose which plugin to use", help="Choose which plugin to use",
) )
parser.add_argument(
"--overlap", action="store_true", help="Overlap communication with computation in Pipeline Parallel."
)
parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size") parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size")
parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run") parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run")
parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore") parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore")
@ -94,11 +93,24 @@ def main():
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"]) parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval) parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
parser.add_argument("--profile", action="store_true", help="Profile the code", default=False) parser.add_argument("--profile", action="store_true", help="Profile the code")
parser.add_argument(
"--nsys",
action="store_true",
help="Use nsys for profiling. \
You should put something like this before colossalai launch: \
nsys profile -w true -t cuda,cudnn,cublas -s cpu --capture-range=cudaProfilerApi --capture-range-end=stop --cudabacktrace=true -x true --python-backtrace=cuda -o prof_out",
)
parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation") parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
parser.add_argument("--no_cache", action="store_true") parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--overlap_allgather", action="store_true") parser.add_argument("--overlap_allgather", action="store_true")
parser.add_argument(
"--sp_mode",
default="all_to_all",
choices=["all_to_all", "ring_attn", "ring", "split_gather"],
help="Sequence parallelism mode",
)
args = parser.parse_args() args = parser.parse_args()
colossalai.launch_from_torch() colossalai.launch_from_torch()
@ -195,12 +207,12 @@ def main():
num_model_chunks=args.n_chunks, num_model_chunks=args.n_chunks,
zero_stage=args.zero, zero_stage=args.zero,
sp_size=args.sp, sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
enable_sequence_parallelism=args.sp > 1, enable_sequence_parallelism=args.sp > 1,
enable_fused_normalization=torch.cuda.is_available(), enable_fused_normalization=torch.cuda.is_available(),
enable_flash_attention=args.xformers, enable_flash_attention=args.xformers,
microbatch_size=args.mbs, microbatch_size=args.mbs,
precision="bf16", precision="bf16",
overlap_p2p=args.overlap,
enable_metadata_cache=not args.no_cache, enable_metadata_cache=not args.no_cache,
overlap_allgather=args.overlap_allgather, overlap_allgather=args.overlap_allgather,
**hybrid_kwargs, **hybrid_kwargs,
@ -218,7 +230,6 @@ def main():
microbatch_size=args.mbs, microbatch_size=args.mbs,
initial_scale=2**8, initial_scale=2**8,
precision="bf16", precision="bf16",
overlap_p2p=args.overlap,
) )
else: else:
raise ValueError(f"Unknown plugin {args.plugin}") raise ValueError(f"Unknown plugin {args.plugin}")
@ -295,6 +306,7 @@ def main():
args.ignore_steps, args.ignore_steps,
1, # avoid creating massive log files 1, # avoid creating massive log files
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}", save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
nsys=args.nsys,
) as prof: ) as prof:
if isinstance(plugin, HybridParallelPlugin) and args.pp > 1: if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
data_iter = iter(dataloader) data_iter = iter(dataloader)
@ -320,13 +332,16 @@ def main():
performance_evaluator.on_step_start(step) performance_evaluator.on_step_start(step)
outputs = model(**batch) outputs = model(**batch)
loss = outputs[0] loss = outputs[0]
del outputs # free memory
if dist.get_rank() == dist.get_world_size() - 1:
print(f"Step {step} loss: {loss}")
booster.backward(loss, optimizer) booster.backward(loss, optimizer)
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
performance_evaluator.on_step_end(**batch) performance_evaluator.on_step_end(**batch)
prof.step() prof.step()
performance_evaluator.on_fit_end() performance_evaluator.on_fit_end()
coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB") coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")

View File

@ -17,7 +17,7 @@ limitations under the License.
## OPT ## OPT
Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments. Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments.
The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost. The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Causal Language Modelling at low cost.
## Our Modifications ## Our Modifications

View File

@ -28,7 +28,7 @@ def all_reduce_mean(x: float, world_size: int) -> float:
return tensor.item() return tensor.item()
def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir): def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir, nsys=False):
class DummyProfiler: class DummyProfiler:
def __init__(self): def __init__(self):
self.step_number = 0 self.step_number = 0
@ -42,7 +42,29 @@ def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
pass pass
class NsysProfiler:
def __init__(self, warmup_steps, active_steps):
self.step_number = 0
self.warmup_steps = warmup_steps
self.active_steps = active_steps
def step(self):
if self.step_number == self.warmup_steps:
torch.cuda.cudart().cudaProfilerStart()
elif self.step_number == self.warmup_steps + self.active_steps:
torch.cuda.cudart().cudaProfilerStop()
self.step_number += 1
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
pass
if enable_flag: if enable_flag:
if nsys:
return NsysProfiler(warmup_steps, active_steps)
return profile( return profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps), schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),

View File

@ -19,7 +19,7 @@ limitations under the License.
## OPT ## OPT
Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments. Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments.
The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost. The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning causal Language Modelling at low cost.
We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before
the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling). the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling).

View File

@ -57,14 +57,14 @@ class FlashAttentionDaoCudaExtension(_Extension):
q_indices: Optional[torch.Tensor] = None, q_indices: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None, kv_indices: Optional[torch.Tensor] = None,
): ):
# [B, N, S, D] -> [B, S, N, D] # [B, H, S, D] -> [B, S, H, D]
q = q.transpose(1, 2) q = q.transpose(1, 2)
k = k.transpose(1, 2) k = k.transpose(1, 2)
v = v.transpose(1, 2) v = v.transpose(1, 2)
b, s_q = q.shape[:2] b, s_q = q.shape[:2]
if cu_seqlens_q is not None: if cu_seqlens_q is not None:
# padded / padded causal # padded / padded causal
# unpad input: [B, S, N, D] -> [T, N, D] # unpad input: [B, S, H, D] -> [T, H, D]
q = _unpad_input(q, q_indices) q = _unpad_input(q, q_indices)
kv = _unpad_input(torch.stack(tensors=(k, v), dim=2), kv_indices) kv = _unpad_input(torch.stack(tensors=(k, v), dim=2), kv_indices)
attn_output = flash_attn_varlen_kvpacked_func( attn_output = flash_attn_varlen_kvpacked_func(
@ -78,7 +78,7 @@ class FlashAttentionDaoCudaExtension(_Extension):
softmax_scale=scale, softmax_scale=scale,
causal=is_causal, causal=is_causal,
) )
# pad output: [T, N, D] -> [B, S, N, D] # pad output: [T, H, D] -> [B, S, H, D]
attn_output = pad_input(attn_output, q_indices, b, s_q) attn_output = pad_input(attn_output, q_indices, b, s_q)
else: else:
# causal / no attn mask # causal / no attn mask
@ -90,7 +90,7 @@ class FlashAttentionDaoCudaExtension(_Extension):
softmax_scale=scale, softmax_scale=scale,
causal=is_causal, causal=is_causal,
) )
# [B, S, N, D] -> [B, N, S, D] # [B, S, H, D] -> [B, H, S, D]
return attn_output.transpose(1, 2) return attn_output.transpose(1, 2)
return flash_attention return flash_attention

View File

@ -22,9 +22,9 @@ COMMON_MODELS = [
"transformers_bloom_for_causal_lm", "transformers_bloom_for_causal_lm",
"transformers_falcon_for_causal_lm", "transformers_falcon_for_causal_lm",
"transformers_chatglm_for_conditional_generation", "transformers_chatglm_for_conditional_generation",
"transformers_llama_for_casual_lm", "transformers_llama_for_causal_lm",
"transformers_vit_for_masked_image_modeling", "transformers_vit_for_masked_image_modeling",
"transformers_mistral_for_casual_lm", "transformers_mistral_for_causal_lm",
] ]
IS_FAST_TEST = os.environ.get("FAST_TEST", "0") == "1" IS_FAST_TEST = os.environ.get("FAST_TEST", "0") == "1"

View File

@ -32,8 +32,8 @@ if HAS_COMMAND:
return dict(input_ids=input_ids, attention_mask=attention_mask) return dict(input_ids=input_ids, attention_mask=attention_mask)
# label is needed for casual lm # label is needed for causal lm
def data_gen_for_casual_lm(): def data_gen_for_causal_lm():
data = data_gen() data = data_gen()
labels = data["input_ids"].clone() labels = data["input_ids"].clone()
data["labels"] = labels data["labels"] = labels
@ -44,7 +44,7 @@ if HAS_COMMAND:
# function to get the loss # function to get the loss
loss_fn = lambda output: output["last_hidden_state"].mean() loss_fn = lambda output: output["last_hidden_state"].mean()
loss_fn_for_casual_lm = lambda output: output["loss"] loss_fn_for_causal_lm = lambda output: output["loss"]
loss_fn_for_seq_classification = lambda output: output["logits"].mean() loss_fn_for_seq_classification = lambda output: output["logits"].mean()
config = CohereConfig( config = CohereConfig(
@ -70,10 +70,10 @@ if HAS_COMMAND:
model_attribute=ModelAttribute(has_control_flow=True), model_attribute=ModelAttribute(has_control_flow=True),
) )
model_zoo.register( model_zoo.register(
name="transformers_command_for_casual_lm", name="transformers_command_for_causal_lm",
model_fn=lambda: transformers.CohereForCausalLM(config), model_fn=lambda: transformers.CohereForCausalLM(config),
data_gen_fn=data_gen_for_casual_lm, data_gen_fn=data_gen_for_causal_lm,
output_transform_fn=output_transform_fn, output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_casual_lm, loss_fn=loss_fn_for_causal_lm,
model_attribute=ModelAttribute(has_control_flow=True), model_attribute=ModelAttribute(has_control_flow=True),
) )

View File

@ -33,20 +33,21 @@ if HAS_LLAMA:
[1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
] ]
).long() ).long()
attention_mask = torch.ones_like(input_ids)
attention_mask = torch.Tensor(
[
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
]
).long()
return dict(input_ids=input_ids, attention_mask=attention_mask) return dict(input_ids=input_ids, attention_mask=attention_mask)
# label is needed for casual lm # label is needed for causal lm
def data_gen_for_casual_lm(): def data_gen_for_causal_lm():
data = data_gen() data = data_gen()
# Test padded sequence
padding = torch.zeros(2, data["input_ids"].shape[1] // 2, dtype=torch.long)
data["input_ids"] = torch.cat([data["input_ids"], padding], dim=1)
data["attention_mask"] = torch.cat([data["attention_mask"], padding], dim=1)
ignore_idx = -100
labels = data["input_ids"].clone() labels = data["input_ids"].clone()
labels[~data["attention_mask"].bool()] = ignore_idx
data["labels"] = labels data["labels"] = labels
return data return data
@ -55,7 +56,7 @@ if HAS_LLAMA:
# function to get the loss # function to get the loss
loss_fn = lambda output: output["last_hidden_state"].mean() loss_fn = lambda output: output["last_hidden_state"].mean()
loss_fn_for_casual_lm = lambda output: output["loss"] loss_fn_for_causal_lm = lambda output: output["loss"]
loss_fn_for_seq_classification = lambda output: output["logits"].mean() loss_fn_for_seq_classification = lambda output: output["logits"].mean()
config = LlamaConfig( config = LlamaConfig(
@ -70,9 +71,17 @@ if HAS_LLAMA:
config.pad_token_id = config.eos_token_id config.pad_token_id = config.eos_token_id
# register the following models # register the following models
# transformers.LlamaModel,
# transformers.LlamaForCausalLM, # transformers.LlamaForCausalLM,
# transformers.LlamaModel,
# transformers.LlamaForSequenceClassification, # transformers.LlamaForSequenceClassification,
model_zoo.register(
name="transformers_llama_for_causal_lm",
model_fn=lambda: transformers.LlamaForCausalLM(config),
data_gen_fn=data_gen_for_causal_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_causal_lm,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register( model_zoo.register(
name="transformers_llama", name="transformers_llama",
model_fn=lambda: transformers.LlamaModel(config), model_fn=lambda: transformers.LlamaModel(config),
@ -81,14 +90,6 @@ if HAS_LLAMA:
loss_fn=loss_fn, loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True), model_attribute=ModelAttribute(has_control_flow=True),
) )
model_zoo.register(
name="transformers_llama_for_casual_lm",
model_fn=lambda: transformers.LlamaForCausalLM(config),
data_gen_fn=data_gen_for_casual_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_casual_lm,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register( model_zoo.register(
name="transformers_llama_for_sequence_classification", name="transformers_llama_for_sequence_classification",
model_fn=lambda: transformers.LlamaForSequenceClassification(config), model_fn=lambda: transformers.LlamaForSequenceClassification(config),

View File

@ -64,7 +64,7 @@ model_zoo.register(
model_attribute=ModelAttribute(has_control_flow=True), model_attribute=ModelAttribute(has_control_flow=True),
) )
model_zoo.register( model_zoo.register(
name="transformers_mistral_for_casual_lm", name="transformers_mistral_for_causal_lm",
model_fn=lambda: transformers.MistralForCausalLM(config), model_fn=lambda: transformers.MistralForCausalLM(config),
data_gen_fn=data_gen_for_lm, data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn, output_transform_fn=output_transform_fn,

View File

@ -33,8 +33,8 @@ if HAS_QWEN2:
attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long() attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long()
return dict(input_ids=input_ids, attention_mask=attention_mask) return dict(input_ids=input_ids, attention_mask=attention_mask)
# label is needed for casual lm # label is needed for causal lm
def data_gen_for_casual_lm(): def data_gen_for_causal_lm():
data = data_gen() data = data_gen()
labels = data["input_ids"].clone() labels = data["input_ids"].clone()
data["labels"] = labels data["labels"] = labels
@ -45,7 +45,7 @@ if HAS_QWEN2:
# function to get the loss # function to get the loss
loss_fn = lambda output: output["last_hidden_state"].mean() loss_fn = lambda output: output["last_hidden_state"].mean()
loss_fn_for_casual_lm = lambda output: output["loss"] loss_fn_for_causal_lm = lambda output: output["loss"]
loss_fn_for_seq_classification = lambda output: output["logits"].mean() loss_fn_for_seq_classification = lambda output: output["logits"].mean()
config = Qwen2Config( config = Qwen2Config(
@ -72,11 +72,11 @@ if HAS_QWEN2:
model_attribute=ModelAttribute(has_control_flow=True), model_attribute=ModelAttribute(has_control_flow=True),
) )
model_zoo.register( model_zoo.register(
name="transformers_qwen2_for_casual_lm", name="transformers_qwen2_for_causal_lm",
model_fn=lambda: transformers.Qwen2ForCausalLM(config), model_fn=lambda: transformers.Qwen2ForCausalLM(config),
data_gen_fn=data_gen_for_casual_lm, data_gen_fn=data_gen_for_causal_lm,
output_transform_fn=output_transform_fn, output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_casual_lm, loss_fn=loss_fn_for_causal_lm,
model_attribute=ModelAttribute(has_control_flow=True), model_attribute=ModelAttribute(has_control_flow=True),
) )
model_zoo.register( model_zoo.register(

View File

@ -97,7 +97,7 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True):
# TODO(ver217): add more models # TODO(ver217): add more models
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry( for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry(
"transformers_llama_for_casual_lm" "transformers_llama_for_causal_lm"
).items(): ).items():
err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn)

View File

@ -105,7 +105,7 @@ def check_low_level_zero_lora(stage, model_name, early_stop: bool = True):
sub_model_zoo = model_zoo.get_sub_registry(model_name) sub_model_zoo = model_zoo.get_sub_registry(model_name)
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
task_type = None task_type = None
if name == "transformers_llama_for_casual_lm": if name == "transformers_llama_for_causal_lm":
task_type = "CAUSAL_LM" task_type = "CAUSAL_LM"
if name == "transformers_llama_for_sequence_classification": if name == "transformers_llama_for_sequence_classification":
task_type = "SEQ_CLS" task_type = "SEQ_CLS"

View File

@ -74,7 +74,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
@clear_cache_before_run() @clear_cache_before_run()
@parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS) @parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS)
@parameterize("shard", [True, False]) @parameterize("shard", [True, False])
@parameterize("model_name", ["transformers_llama_for_casual_lm"]) @parameterize("model_name", ["transformers_llama_for_causal_lm"])
@parameterize("size_per_shard", [32]) @parameterize("size_per_shard", [32])
@parameterize("tp_size", [1, 2]) @parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2]) @parameterize("zero_size", [2])

View File

@ -20,7 +20,7 @@ from tests.kit.model_zoo import model_zoo
@clear_cache_before_run() @clear_cache_before_run()
@parameterize("shard", [False, True]) @parameterize("shard", [False, True])
@parameterize("model_name", ["transformers_llama_for_casual_lm"]) @parameterize("model_name", ["transformers_llama_for_causal_lm"])
def exam_torch_load_from_gemini(shard: bool, model_name: str): def exam_torch_load_from_gemini(shard: bool, model_name: str):
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean() criterion = lambda x: x.mean()

View File

@ -39,7 +39,7 @@ else:
@parameterize("shard", [True, False]) @parameterize("shard", [True, False])
@parameterize("model_name", ["transformers_llama_for_casual_lm"]) @parameterize("model_name", ["transformers_llama_for_causal_lm"])
@parameterize("size_per_shard", [32]) @parameterize("size_per_shard", [32])
@parameterize("test_config", TEST_CONFIGS) @parameterize("test_config", TEST_CONFIGS)
@clear_cache_before_run() @clear_cache_before_run()

View File

@ -149,7 +149,7 @@ def check_low_level_zero_lora_checkpointIO(
if name != "transformers_llama": if name != "transformers_llama":
continue continue
task_type = None task_type = None
if name == "transformers_llama_for_casual_lm": if name == "transformers_llama_for_causal_lm":
task_type = "CAUSAL_LM" task_type = "CAUSAL_LM"
if name == "transformers_llama_for_sequence_classification": if name == "transformers_llama_for_sequence_classification":
task_type = "SEQ_CLS" task_type = "SEQ_CLS"

View File

@ -18,7 +18,7 @@ from tests.kit.model_zoo import model_zoo
@clear_cache_before_run() @clear_cache_before_run()
@parameterize("model_name", ["transformers_llama_for_casual_lm"]) @parameterize("model_name", ["transformers_llama_for_causal_lm"])
@parameterize("plugin_type", ["ddp", "zero", "gemini"]) @parameterize("plugin_type", ["ddp", "zero", "gemini"])
def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32): def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32):
(model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next( (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(

View File

@ -91,7 +91,7 @@ def run_lora_test():
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
task_type = None task_type = None
if name == "transformers_llama_for_casual_lm": if name == "transformers_llama_for_causal_lm":
task_type = "CAUSAL_LM" task_type = "CAUSAL_LM"
if name == "transformers_llama_for_sequence_classification": if name == "transformers_llama_for_sequence_classification":
task_type = "SEQ_CLS" task_type = "SEQ_CLS"

View File

@ -6,6 +6,7 @@ import pytest
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from torch.testing import assert_close
import colossalai import colossalai
from colossalai.cluster import ProcessGroupMesh from colossalai.cluster import ProcessGroupMesh
@ -107,13 +108,13 @@ def run_pp(
# check loss # check loss
if stage_manager.is_last_stage(ignore_chunk=True): if stage_manager.is_last_stage(ignore_chunk=True):
assert torch.allclose(torch_loss, pp_ret["loss"]) assert_close(torch_loss, pp_ret["loss"])
# check gradients # check gradients
for i in range(num_model_chunk): for i in range(num_model_chunk):
idx = world_size * i + rank idx = world_size * i + rank
assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad) assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad) assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
# step # step
torch_optimizer.step() torch_optimizer.step()
@ -123,8 +124,8 @@ def run_pp(
# check updated param # check updated param
for i in range(num_model_chunk): for i in range(num_model_chunk):
idx = world_size * i + rank idx = world_size * i + rank
assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight) assert_close(torch_model.layers[idx].weight, sharded_model[i].weight)
assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias) assert_close(torch_model.layers[idx].bias, sharded_model[i].bias)
# forward only # forward only
with torch.no_grad(): with torch.no_grad():
@ -135,14 +136,14 @@ def run_pp(
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
) )
if stage_manager.is_last_stage(ignore_chunk=True): if stage_manager.is_last_stage(ignore_chunk=True):
assert torch.allclose(torch_loss, pp_ret["loss"]) assert_close(torch_loss, pp_ret["loss"])
for layer in sharded_model: for layer in sharded_model:
if layer.weight.grad is None: if layer.weight.grad is None:
assert layer.weight.grad is None and layer.bias.grad is None assert layer.weight.grad is None and layer.bias.grad is None
else: else:
assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad)) assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad))
assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad)) assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad))
@pytest.mark.dist @pytest.mark.dist

View File

@ -6,6 +6,7 @@ import pytest
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from torch.testing import assert_close
import colossalai import colossalai
from colossalai.cluster import ProcessGroupMesh from colossalai.cluster import ProcessGroupMesh
@ -103,13 +104,13 @@ def examine_pp(num_microbatch: int, batch_size: int):
# check loss # check loss
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():
assert torch.allclose(torch_loss, pp_ret["loss"]) assert_close(torch_loss, pp_ret["loss"])
# check gradients # check gradients
for i in range(len(sharded_model)): for i in range(len(sharded_model)):
idx = rank * num_local_layer + i idx = rank * num_local_layer + i
assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad) assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad) assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
# step # step
torch_optimizer.step() torch_optimizer.step()
@ -119,8 +120,8 @@ def examine_pp(num_microbatch: int, batch_size: int):
# check updated param # check updated param
for i in range(len(sharded_model)): for i in range(len(sharded_model)):
idx = rank * num_local_layer + i idx = rank * num_local_layer + i
assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight) assert_close(torch_model.layers[idx].weight, sharded_model[i].weight)
assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias) assert_close(torch_model.layers[idx].bias, sharded_model[i].bias)
# forward only # forward only
with torch.no_grad(): with torch.no_grad():
@ -131,14 +132,14 @@ def examine_pp(num_microbatch: int, batch_size: int):
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
) )
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():
assert torch.allclose(torch_loss, pp_ret["loss"]) assert_close(torch_loss, pp_ret["loss"])
for layer in sharded_model: for layer in sharded_model:
if layer.weight.grad is None: if layer.weight.grad is None:
assert layer.weight.grad is None and layer.bias.grad is None assert layer.weight.grad is None and layer.bias.grad is None
else: else:
assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad)) assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad))
assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad)) assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad))
def run_dist( def run_dist(

View File

@ -88,6 +88,7 @@ def check_attn_func(dtype: torch.dtype, attn_func, attn_kwargs: dict, padding_ma
padding_mask = padding_mask[:, None, :, None].logical_not() padding_mask = padding_mask[:, None, :, None].logical_not()
ref_output = ref_output.masked_fill(padding_mask, 0) ref_output = ref_output.masked_fill(padding_mask, 0)
output = output.masked_fill(padding_mask, 0) output = output.masked_fill(padding_mask, 0)
assert_close(output, ref_output, **tols) assert_close(output, ref_output, **tols)
output.mean().backward() output.mean().backward()
ref_output.mean().backward() ref_output.mean().backward()
@ -128,6 +129,8 @@ def test_flash_attn_func(dtype: torch.dtype):
attn_kwargs, padding_mask = gen_kwargs_func(dtype) attn_kwargs, padding_mask = gen_kwargs_func(dtype)
for attn_func, name, need_postprocess in attn_funcs: for attn_func, name, need_postprocess in attn_funcs:
print(f"{dtype}, {name}, {mask_type}") print(f"{dtype}, {name}, {mask_type}")
if mask_type == "padded":
pass
if need_postprocess: if need_postprocess:
check_attn_func(dtype, attn_func, post_process_kwargs_for_raw_attn(attn_kwargs), padding_mask) check_attn_func(dtype, attn_func, post_process_kwargs_for_raw_attn(attn_kwargs), padding_mask)
else: else:

View File

@ -0,0 +1,186 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F
from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_func
from torch.testing import assert_close
import colossalai
from colossalai.shardformer.layer import AttnMaskType
from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention
from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
@parameterize("seq_len", [4096])
@parameterize("bs", [2])
@parameterize("nheads", [5])
@parameterize("d", [128])
@parameterize("dtype", [torch.bfloat16, torch.float16])
def check_ring_attn(seq_len, bs, nheads, d, dtype):
torch.cuda.manual_seed(2)
device = get_current_device()
sp_group = dist.group.WORLD
sp_size = dist.get_world_size()
# Some outliers may seem large, but our errors are still lower than
# than Megatron-LM context parallel's
# (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215)
# and the original zigzag implementation's (https://github.com/zhuzilin/ring-flash-attention/tree/main)
atol = rtol = 7e-3
# Setup inputs
qkv = torch.randn(bs, seq_len, 3, nheads, d, device=device, dtype=dtype, requires_grad=True)
local_qkv = split_batch_zigzag(qkv, sp_group)
q, k, v = local_qkv.unbind(dim=-3)
q, k, v = [x.squeeze(2).detach().clone().transpose(1, 2) for x in (q, k, v)] # (B, nHeads, Sq, D)
q.requires_grad = k.requires_grad = v.requires_grad = True
# Ring attention vs single GPU
ring_out, ring_lse = RingAttention.attention(
q,
k,
v,
sp_group,
AttnMaskType.CAUSAL,
return_softmax=True,
inner_ring_size=max(2, sp_size // 2),
# inner_ring_size=4
)
ring_out = ring_out.transpose(1, 2)
out, lse, _ = flash_attn_qkvpacked_func(
qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True
)
# Checkout out and softmax denominator
local_out = split_batch_zigzag(out, sp_group)
local_lse = split_batch_zigzag(lse, sp_group, seq_dim=-1)
local_lse = local_lse.transpose(1, 2).contiguous().view(-1, ring_lse.shape[-1]) # (B, nHeads, Sq) -> (T, nHeads)
assert_close(ring_lse, local_lse, atol=atol, rtol=rtol)
assert_close(ring_out, local_out, atol=atol, rtol=rtol)
# Check grads
ring_out.sum().backward()
out.sum().backward()
ring_dq, ring_dk, ring_dv = [x.transpose(1, 2) for x in (q.grad, k.grad, v.grad)]
dqkv = qkv.grad
local_dqkv = split_batch_zigzag(dqkv, sp_group)
assert_close(ring_dq, local_dqkv[:, :, 0], atol=atol, rtol=rtol)
assert_close(ring_dk, local_dqkv[:, :, 1], atol=atol, rtol=rtol)
assert_close(ring_dv, local_dqkv[:, :, 2], atol=atol, rtol=rtol)
if dist.get_rank() == 0:
print(
f"sp_size {dist.get_world_size()}, inner ring size {dist.get_world_size(RingAttention.INNER_RING_GROUP)} passed."
)
@parameterize("seqlen", [4096])
@parameterize("bs", [2])
@parameterize("nheads", [5])
@parameterize("d", [128])
@parameterize("dtype", [torch.bfloat16, torch.float16])
def check_packed_seq(seqlen, bs, nheads, d, dtype):
device = get_current_device()
sp_group = dist.group.WORLD
sp_size = dist.get_world_size()
atol = rtol = 7e-3
torch.cuda.manual_seed(2)
# Prepare varlen attention mask
padding_mask = torch.ones((bs, seqlen), dtype=torch.int, device=device)
padding_mask[: bs // 2, (seqlen // 4) * 3 :] = 0
padding_mask[:, seqlen // 2 :] = 0
input_embeds = torch.randn(bs, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True)
# Forward
# out = ColoAttention.attention(q, k, v, **mask_info)
flat_input = input_embeds.view(-1, nheads, d)[padding_mask.flatten().nonzero().squeeze()]
qkv = torch.stack([flat_input] * 3, dim=1)
qkv.retain_grad()
input_embeds, mask_info, _ = RingAttention.prepare_varlen_batch(padding_mask, sp_group, input_embeds)
out, lse, _ = flash_attn_varlen_qkvpacked_func(
qkv,
mask_info["cu_seqlens"] * sp_size,
mask_info["max_seqlen"] * sp_size,
return_attn_probs=True,
causal=True,
# deterministic=True
)
# Test the splitting function
local_input = split_varlen_zigzag(
flat_input, mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size
)
assert (local_input == input_embeds.view(-1, nheads, d)[mask_info["valid_indices"]]).all()
del local_input, flat_input
q_ring, k_ring, v_ring = [input_embeds.clone().transpose(1, 2) for _ in range(3)]
q_ring.retain_grad()
k_ring.retain_grad()
v_ring.retain_grad()
ring_out, ring_lse = RingAttention.attention(
q_ring,
k_ring,
v_ring,
sp_group,
**mask_info,
pad_output=False,
return_softmax=True,
# deterministic=True
)
ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d)
# Check output
lse = lse.transpose(0, 1)
out, lse = split_varlen_zigzag(
[out, lse], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size
)
assert_close(lse, ring_lse, atol=atol, rtol=rtol)
assert_close(out, ring_out, atol=atol, rtol=rtol)
# Check grads
labels = torch.ones(out.shape[0], dtype=dtype, device=device)
F.mse_loss(out.sum((-2, -1)), labels).backward()
F.mse_loss(ring_out.sum((-2, -1)), labels[: ring_out.shape[0]]).backward()
dq, dk, dv = [
split_varlen_zigzag(
qkv.grad[:, i], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size
)
for i in range(3)
]
dq_ring, dk_ring, dv_ring = [
x.transpose(1, 2).reshape(-1, nheads, d)[mask_info["valid_indices"]]
for x in (q_ring.grad, k_ring.grad, v_ring.grad)
]
assert_close(dq, dq_ring, atol=atol, rtol=rtol)
assert_close(dk, dk_ring, atol=atol, rtol=rtol)
assert_close(dv, dv_ring, atol=atol, rtol=rtol)
def launch_single_ring(rank, world_size, port):
colossalai.launch(rank, world_size, "localhost", port)
check_packed_seq()
check_ring_attn()
def launch_double_ring(rank, world_size, port):
colossalai.launch(rank, world_size, "localhost", port)
check_ring_attn()
@rerun_if_address_is_in_use()
@parameterize("world_size", [2])
def test_ring_attn(world_size):
spawn(launch_single_ring, nprocs=world_size)
@rerun_if_address_is_in_use()
@parameterize("world_size", [4])
def test_double_ring(world_size):
spawn(launch_double_ring, nprocs=world_size)
if __name__ == "__main__":
test_ring_attn()
test_double_ring()

View File

@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup
from torch.nn import Module from torch.nn import Module
from torch.optim import Adam, Optimizer from torch.optim import Adam, Optimizer
from torch.testing import assert_close from torch.testing import assert_close
from transformers.modeling_outputs import BaseModelOutputWithPast
from colossalai.accelerator import get_accelerator from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster from colossalai.booster import Booster
@ -259,7 +260,6 @@ def run_forward_backward_with_hybrid_plugin(
org_output = org_model(**unshard_test_data) org_output = org_model(**unshard_test_data)
org_loss = criterion(org_output) org_loss = criterion(org_output)
org_loss.backward() org_loss.backward()
return org_loss, org_output, sharded_loss, sharded_output return org_loss, org_output, sharded_loss, sharded_output
@ -302,11 +302,12 @@ def run_forward_backward_with_low_level_zero_plugin(
def check_output_hidden_state( def check_output_hidden_state(
org_output: Tensor, org_output: BaseModelOutputWithPast,
sharded_output: Tensor, sharded_output: BaseModelOutputWithPast,
stage_manager: Optional[PipelineStageManager] = None, stage_manager: Optional[PipelineStageManager] = None,
atol: float = 1e-5, atol: float = 1e-5,
rtol: float = 1e-3, rtol: float = 1e-3,
shard_config: Optional[ShardConfig] = None,
): ):
org_hidden_state = org_output.last_hidden_state org_hidden_state = org_output.last_hidden_state
@ -315,6 +316,14 @@ def check_output_hidden_state(
else: else:
sharded_hidden_state = sharded_output.last_hidden_state sharded_hidden_state = sharded_output.last_hidden_state
# Check if the output sequence is gathered before cross entropy
if shard_config is not None:
seq_dim = 1
sp_group = shard_config.sequence_parallel_process_group
sp_size = shard_config.sequence_parallel_size
if org_hidden_state.shape[seq_dim] == sharded_hidden_state.shape[seq_dim] * sp_size:
org_hidden_state = org_hidden_state.chunk(sp_size, dim=seq_dim)[dist.get_rank(sp_group)]
assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol) assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol)
@ -374,8 +383,11 @@ def get_grad_tensors_for_check(
shard_grad = torch.cat(shard_grad_list, dim=dim) shard_grad = torch.cat(shard_grad_list, dim=dim)
# embedding may be resized when using tensor parallel # embedding may be resized when using tensor parallel
if shard_grad.shape[0] > org_grad.shape[0]: try:
shard_grad = shard_grad[: org_grad.shape[0], :] if shard_grad.shape[0] > org_grad.shape[0]:
shard_grad = shard_grad[: org_grad.shape[0], :]
except:
pass
if verbose and dist.get_rank() == 0: if verbose and dist.get_rank() == 0:
print(f"'{suffix}' grad: {org_grad}, {shard_grad}") print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
@ -404,9 +416,6 @@ def check_grad(
org_grad = getattr_(org_model, suffix).weight.grad org_grad = getattr_(org_model, suffix).weight.grad
shard_grad = getattr_(sharded_model, suffix).weight.grad shard_grad = getattr_(sharded_model, suffix).weight.grad
shard_weight = getattr_(sharded_model, suffix).weight shard_weight = getattr_(sharded_model, suffix).weight
# if verbose and dist.get_rank() == 0:
# print("shard_weight", shard_weight)
# print("org_grad", org_grad)
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))] shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))]
dist.all_gather(shard_grad_list, shard_grad, tp_group) dist.all_gather(shard_grad_list, shard_grad, tp_group)
@ -440,7 +449,7 @@ def check_all_grad_tensors(check_tensors):
"org_grad": tensor to be compared from the original model "org_grad": tensor to be compared from the original model
"shard_grad": tensor to be compared from the sharded model "shard_grad": tensor to be compared from the sharded model
""" """
for suffix, check_info in check_tensors.items(): for idx, (suffix, check_info) in enumerate(check_tensors.items()):
org_grad = check_info["org_grad"] org_grad = check_info["org_grad"]
shard_grad = check_info["shard_grad"] shard_grad = check_info["shard_grad"]
rtol = check_info["rtol"] rtol = check_info["rtol"]

View File

@ -271,7 +271,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
], ],
) )
def run_command_test(test_config): def run_command_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm") sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_causal_lm")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
@ -321,7 +321,7 @@ def run_command_test(test_config):
], ],
) )
def run_command_3d_test(test_config): def run_command_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm") sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_causal_lm")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

View File

@ -63,7 +63,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
): ):
master2working = sharded_optimizer.get_master_to_working_map() master2working = sharded_optimizer.get_master_to_working_map()
for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]): for (name, p1), p2 in zip(
llama_model.named_parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]
):
working_p = master2working[id(p2)] working_p = master2working[id(p2)]
grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p)) grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p))
grad_index = ( grad_index = (
@ -73,7 +75,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
) )
grad = grads[grad_index] grad = grads[grad_index]
sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()] sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) try:
assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)
except Exception as e:
raise RuntimeError(f"Failed to check grad for {name}") from e
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step. # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {} grads_to_check = {}
@ -114,89 +119,130 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
atol, rtol = 5e-3, 5e-3 atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == "LlamaModel": if org_model.__class__.__name__ == "LlamaModel":
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_output_hidden_state(
org_output,
sharded_output,
stage_manager,
atol=atol,
rtol=rtol,
shard_config=booster.plugin.shard_config,
)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights # check weights
if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
if test_config["precision"] == "fp32": if test_config["precision"] == "fp32":
atol, rtol = 1e-4, 1e-3 atol, rtol = 1e-4, 1e-3
else: else:
atol, rtol = 5e-3, 5e-3 atol, rtol = 5e-3, 5e-3
try: check_weight(
check_weight( llama_model,
llama_model, shard_llama_model,
shard_llama_model, col_layer_for_check,
col_layer_for_check, tp_group,
tp_group, atol=atol,
atol=atol, rtol=rtol,
rtol=rtol, dim=1,
dim=1, verbose=False,
verbose=False, )
)
except Exception as e:
print(f"Failed config: {test_config}")
raise e
# check grads # check grads
check_all_grad_tensors(grads_to_check) check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache() torch.cuda.empty_cache()
@parameterize( @parameterize(
"test_config", "test_config",
[ [
{ # Ulysess + Flash attention # Double Ring Attention
{
"tp_size": 1,
"pp_size": 1,
"sp_size": 4,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring_attn",
"use_lazy_init": True,
"zero_stage": 0,
"precision": "fp16",
"initial_scale": 1,
"inner_ring_size": 2,
},
# Ring Attention + PP
{
"tp_size": 1,
"pp_size": 2,
"sp_size": 2,
"num_microbatches": 2,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring_attn",
"use_lazy_init": True,
"zero_stage": 1,
"precision": "fp16",
"initial_scale": 1,
},
# Ring Attention + TP
{
"tp_size": 2,
"pp_size": 1,
"sp_size": 2,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring_attn",
"use_lazy_init": True,
"zero_stage": 2,
"precision": "fp16",
"initial_scale": 1,
},
{ # Ulysess + TP
"tp_size": 2,
"pp_size": 1,
"sp_size": 2,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"enable_all_optimization": True,
"use_lazy_init": True,
"zero_stage": 0,
"precision": "fp16",
"initial_scale": 1,
},
{ # Ulysess + PP
"tp_size": 1, "tp_size": 1,
"pp_size": 2, "pp_size": 2,
"sp_size": 2, "sp_size": 2,
"num_microbatches": 2, "num_microbatches": 2,
"enable_sequence_parallelism": True, "enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all", "sequence_parallelism_mode": "all_to_all",
"enable_flash_attention": True, "enable_all_optimization": True,
"use_lazy_init": True, "use_lazy_init": True,
"zero_stage": 0, "zero_stage": 0,
"precision": "fp16", "precision": "fp16",
"initial_scale": 1, "initial_scale": 1,
}, },
{ # Test ring + Flash attention
"tp_size": 2,
"pp_size": 1,
"sp_size": 2,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring",
"enable_flash_attention": True,
"use_lazy_init": True,
"zero_stage": 2,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 1,
"sp_size": 2,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"use_lazy_init": True,
"zero_stage": 1,
"precision": "fp16",
"initial_scale": 1,
},
{ {
"tp_size": 4, "tp_size": 4,
"pp_size": 1, "pp_size": 1,
"num_microbatches": 1, "num_microbatches": 1,
"enable_sequence_parallelism": True, "enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather", "sequence_parallelism_mode": "split_gather",
"enable_flash_attention": False, "enable_flash_attention": True,
"use_lazy_init": True, "use_lazy_init": True,
"precision": "fp16", "precision": "fp16",
"initial_scale": 1, "initial_scale": 1,
}, },
{
"tp_size": 2,
"pp_size": 1,
"sp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring",
"enable_flash_attention": True,
"use_lazy_init": True,
"zero_stage": 2,
"precision": "fp16",
"initial_scale": 1,
},
{ {
"tp_size": 2, "tp_size": 2,
"pp_size": 2, "pp_size": 2,
@ -240,12 +286,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
) )
def run_llama_test(test_config): def run_llama_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name:
continue
try: try:
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
except Exception as e: except Exception as e:
print(f"Failed config: {test_config}") print(f"Failed config: {test_config}, model name: {name}")
raise e raise e
clear_layout_converter() clear_layout_converter()
Randomizer.reset_index() Randomizer.reset_index()