mirror of https://github.com/hpcaitech/ColossalAI
[Feature] Zigzag Ring attention (#5905)
* halfway
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* unified cross entropy func for all shardformer models
* remove redundant lines
* add basic ring attn; debug cross entropy
* fwd bwd logic complete
* fwd bwd logic complete; add experimental triton rescale
* precision tests passed
* precision tests passed
* fix typos and remove misc files
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* add sp_mode to benchmark; fix varlen interface
* update softmax_lse shape by new interface
* change tester name
* remove buffer clone; support packed seq layout
* add varlen tests
* fix typo
* all tests passed
* add dkv_group; fix mask
* remove debug statements

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 887d2d579b
commit f5c84af0b0
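For orientation, this PR wires a new sequence-parallelism mode, "ring_attn", through HybridParallelPlugin and Shardformer, and exposes an optional inner_ring_size knob for the 2D (double) ring; the diff below also shows that ring_attn forces flash attention on and currently requires parallel_output. A minimal configuration sketch, assuming the usual Booster workflow (the concrete sizes are only an example, not taken from this diff):

    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    # 8 GPUs: pure sequence parallelism with Zigzag Ring Attention
    plugin = HybridParallelPlugin(
        tp_size=1,
        pp_size=1,
        sp_size=8,
        enable_sequence_parallelism=True,
        sequence_parallelism_mode="ring_attn",  # new mode added by this PR
        inner_ring_size=None,  # optional; a topology heuristic picks it by default
    )
    booster = Booster(plugin=plugin)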
@@ -12,6 +12,7 @@ repos:
     hooks:
       - id: isort
         name: sort all imports (python)
+        args: ["--profile", "black"] # avoid conflict with black

   - repo: https://github.com/psf/black-pre-commit-mirror
     rev: 24.8.0
@@ -32,7 +32,7 @@ from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackw
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
-from colossalai.shardformer.layer.utils import SeqParallelUtils
+from colossalai.shardformer.layer.utils import SeqParallelUtils, is_share_sp_tp
 from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.d_tensor.api import is_distributed_tensor
@@ -42,7 +42,7 @@ from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_hand

 from .pp_plugin_base import PipelinePluginBase

-SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
+SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all", "ring_attn"]

 PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}

@@ -72,7 +72,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         self.dp_group = dp_group
         self.tp_group = tp_group
         self.sp_group = sp_group
-        self.use_dpp = use_ddp
+        self.use_ddp = use_ddp
         self.require_grad_sync = True
         self.overlap_allgather = overlap_allgather

@@ -139,8 +139,8 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         # Disable automatic gradient synchronization.
         self.require_grad_sync = False
         try:
-            if self.use_dpp:
-                # If using data parallel processing (use_dpp), disable synchronization too.
+            if self.use_ddp:
+                # If using data parallel processing (use_ddp), disable synchronization too.
                 with self.module.no_sync():
                     yield
             else:
@@ -188,7 +188,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         """

         if self.shard_config.enable_sequence_parallelism:
-            if self.shard_config.sequence_parallelism_mode == "all_to_all":
+            if self.shard_config.sequence_parallelism_mode in ["all_to_all", "ring_attn"]:
                 return

             if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
@@ -970,6 +970,9 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
         make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64.
         overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism
+        inner_ring_size (int, optional): The inner ring size of 2D Ring Attention when sp mode is "ring_attn".
+            It's advisable to not tune this (especially in single-node settings) and let it be heuristically set based on topology by default.

     """

     def __init__(
@@ -1017,6 +1020,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         dp_outside: bool = True,
         overlap_p2p: bool = True,
         overlap_allgather: bool = False,
+        inner_ring_size: int = None,
     ) -> None:
         super().__init__()

@@ -1041,9 +1045,11 @@ class HybridParallelPlugin(PipelinePluginBase):
             )
             self.sp_size = 1
             self.dp_size = dist.get_world_size() // (tp_size * pp_size)
-        elif self.sequence_parallelism_mode in ["all_to_all"]:
+        elif self.sequence_parallelism_mode in ["all_to_all", "ring_attn"]:
             self.sp_size = 1 if sp_size is None else sp_size
             self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size)
+            if self.sequence_parallelism_mode == "ring_attn":
+                enable_flash_attention = True
         else:
             self.dp_size = dist.get_world_size() // (tp_size * pp_size)
             assert (
@@ -1063,10 +1069,21 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.enable_sequence_parallelism = enable_sequence_parallelism
         if dp_outside:
             self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
-            self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
+            if sequence_parallelism_mode == "ring_attn":
+                # Swap tp and sp since 2D Ring has better inter-node latency
+                self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.sp_size, self.tp_size)
+                self.sp_axis = 2
+                self.tp_axis = 3
+            else:
+                self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
         else:
             self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
-            self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
+            if sequence_parallelism_mode == "ring_attn":
+                self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.sp_size, self.tp_size)
+                self.sp_axis = 2
+                self.tp_axis = 3
+            else:
+                self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)

         self.stage_manager = None
         self.schedule = None
@@ -1108,6 +1125,8 @@ class HybridParallelPlugin(PipelinePluginBase):
             )
         else:
             raise NotImplementedError()
+        if sequence_parallelism_mode == "ring_attn":
+            assert parallel_output, "Ring Attention doesn't support gathering output yet."

         self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
         self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
@@ -1132,6 +1151,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             parallel_output=parallel_output,
             make_vocab_size_divisible_by=make_vocab_size_divisible_by,
             gradient_checkpoint_config=gradient_checkpoint_config,
+            inner_ring_size=inner_ring_size,
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
@@ -1216,15 +1236,15 @@ class HybridParallelPlugin(PipelinePluginBase):
             zero_stage = 0

         if not isinstance(model, ModelWrapper):
+            # Shouldn't use pp (frequent grad accumulation) with torch ddp
             use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
-                self.dp_size == 1
-                and self.pp_size == 1
-                and self.enable_sequence_parallelism
-                and self.sequence_parallelism_mode == "all_to_all"
+                self.dp_size == 1 and self.pp_size == 1
             )
-            # sync gradients across DP * SP ranks
-            if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
+            # Apply Hybrid ZeRO across DP * SP ranks
+            if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
                 dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+                self.dp_size = get_world_size(dp_group)
             else:
                 dp_group = self.dp_group
             model = HybridParallelModule(
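Note on the last hunk above: when the chosen sequence-parallel mode does not share the tensor-parallel communication group (the new is_share_sp_tp check), the group handed to ZeRO/DDP is now built over both the DP and SP mesh axes, and dp_size is updated from that group. As a concrete illustration (numbers are only an example): with tp_size=1, pp_size=1, dp_size=2 and sp_size=4 on 8 GPUs, gradient averaging and ZeRO sharding now span all 2 * 4 = 8 ranks rather than just the 2 data-parallel ranks.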
@@ -203,7 +203,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             return

         Path(checkpoint).mkdir(parents=True, exist_ok=True)

         # Devices along the same dp_group share the same copies of model.
         # So only let the device with dp_rank == 0 save the model.
         if self.dp_rank != 0:
@@ -643,14 +642,12 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
         model._force_wait_all_gather()
         model = model.unwrap()

         if self.dp_rank != 0:
             return

         # The logic of collecting parameter shards along tp degree
         # has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
         state_dict = model.state_dict()

         if self.pp_size == 1:
             # When pipeline is not used, let master rank directly save the collected state_dict.
             if self.tp_rank == 0:
@@ -660,7 +657,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             state_dict_list = [None for _ in range(self.pp_size)]
             dist.barrier(self.pp_group)
             dist.all_gather_object(state_dict_list, state_dict, self.pp_group)

             # Only the master rank do the saving.
             if self.coordinator.is_master():
                 complete_state_dict = dict()
@@ -62,7 +62,6 @@ def new_from_pretrained(
     config = kwargs.pop("config", None)
     cache_dir = kwargs.pop("cache_dir", None)
     force_download = kwargs.pop("force_download", False)
-    resume_download = kwargs.pop("resume_download", False)
     proxies = kwargs.pop("proxies", None)
     local_files_only = kwargs.pop("local_files_only", False)
     use_auth_token = kwargs.pop("use_auth_token", None)
@@ -116,7 +115,6 @@ def new_from_pretrained(
             cache_dir=cache_dir,
             return_unused_kwargs=True,
             force_download=force_download,
-            resume_download=resume_download,
             proxies=proxies,
             local_files_only=local_files_only,
             use_auth_token=use_auth_token,
@@ -195,7 +193,6 @@ def new_from_pretrained(
         "cache_dir": cache_dir,
         "force_download": force_download,
         "proxies": proxies,
-        "resume_download": resume_download,
         "local_files_only": local_files_only,
         "use_auth_token": use_auth_token,
         "user_agent": user_agent,
@@ -312,7 +309,6 @@ def new_from_pretrained(
                 pretrained_model_name_or_path,
                 cache_dir=cache_dir,
                 force_download=force_download,
-                resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
                 use_auth_token=use_auth_token,
@@ -171,7 +171,7 @@ class OpenMoeForCausalLMPolicy(OpenMoePolicy):
         policy = super().module_policy()

         if self.shard_config.enable_tensor_parallelism:
-            # add a new item for casual lm
+            # add a new item for causal lm
             # TODO: recursively assign ep group foe all modules
             new_item = {
                 OpenMoeForCausalLM: ModulePolicyDescription(
@@ -81,6 +81,9 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
             handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
         # Delay the start of weight gradient computation shortly (3us) to have
         # all-reduce scheduled first and have GPU resources allocated
+        # TODO: This seems to only work if you add torch.cuda.Event.wait()

+        # _ = torch.zeros(1, device=grad_output.device)

         grad_weight = grad_output.t().matmul(total_input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None
@@ -64,7 +64,10 @@ class DistributedLogger:
         self._logger.propagate = False

         DistributedLogger.__instances[name] = self
-        self.rank = dist.get_rank() if dist.is_initialized() else 0
+
+    @property
+    def rank(self):
+        return dist.get_rank() if dist.is_initialized() else 0

     @staticmethod
     def __get_call_info():
@@ -286,7 +286,6 @@ class InterleavedSchedule(PipelineSchedule):
         # for the first stage, input_obj is None
         # for other stages, input_obj is the output of the previous stage containing hidden_states etc.
         # Only attention_mask from micro_batch is used

         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             if isinstance(model_chunk, ModuleList):
                 output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
@@ -244,6 +244,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         output_obj = model_forward(model, micro_batch, input_obj)
         if self.stage_manager.is_last_stage():
             loss = criterion(output_obj, micro_batch) / self.num_microbatches

             if accum_loss is not None:
                 accum_loss.add_(loss.detach())
             if outputs is not None:
@@ -1,5 +1,5 @@
 from ._operation import all_to_all_comm
-from .attn import AttnMaskType, ColoAttention
+from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
 from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
 from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
@@ -31,5 +31,7 @@ __all__ = [
     "VocabParallelLMHead1D",
     "AttnMaskType",
     "ColoAttention",
+    "RingAttention",
+    "get_pad_info",
     "all_to_all_comm",
 ]
@@ -2,6 +2,8 @@ import torch
 import torch.distributed as dist
 import torch.nn.functional as F

+from .utils import is_share_sp_tp
+
 try:
     import fused_mix_prec_layer_norm_cuda
 except:
@@ -93,7 +95,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
         if ctx.async_grad_allreduce:
             # Asynchronous all-reduce
             handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
-            # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
+            # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
             # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py

         grad_weight = total_input.t().matmul(grad_output)
@@ -143,7 +145,9 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
         if ctx.async_grad_allreduce:
             # Asynchronous all-reduce
             handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
-            # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
+            _ = torch.zeros(1, device=grad_input.device)
+
+            # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
             # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py

         if _grad_accum_fusion_available and weight.grad is not None:
@@ -331,7 +335,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
                 input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
             ).contiguous()
             handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
-            # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
+            # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
             # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py

         if _grad_accum_fusion_available and weight.grad is not None:
@@ -646,8 +650,8 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
                 input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
             ).contiguous()
             handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
-            # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
-            # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
+            # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
+            # all-reduce scheduled first and have GPU resources allocated

             grad_weight = total_input.t().matmul(grad_output)
             grad_bias = grad_output.sum(dim=0) if use_bias else None
@@ -721,16 +725,20 @@ class _ReduceForward(torch.autograd.Function):

     Args:
         input_: input matrix.
-        parallel_mode: parallel mode.
+        process_group: communication group.

     """

     @staticmethod
-    def forward(ctx, input_, process_group):
+    def forward(ctx, input_, process_group, grad_scale=None):
+        ctx.grad_scale = grad_scale
         return _reduce(input_, process_group)

     @staticmethod
     def backward(ctx, grad_output):
-        return grad_output, None
+        if ctx.grad_scale is not None:
+            grad_output = grad_output * ctx.grad_scale
+        return grad_output, None, None


 class _ReduceBackward(torch.autograd.Function):
@@ -979,8 +987,8 @@ def split_forward_gather_backward(input_, dim, process_group, grad_scale=None):
     return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale)


-def reduce_forward(input_, process_group):
-    return _ReduceForward.apply(input_, process_group)
+def reduce_forward(input_, process_group, grad_scale=None):
+    return _ReduceForward.apply(input_, process_group, grad_scale)


 def reduce_backward(input_, process_group):
@@ -989,3 +997,13 @@ def reduce_backward(input_, process_group):

 def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
     return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
+
+
+def gather_sp_output(hidden_states, sp_group, sp_mode):
+    """
+    Gather the output of the last layer for cross entropy computation
+    """
+    # Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group)
+    scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group)
+    hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=scale)
+    return hidden_states
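gather_sp_output above is meant to be called right before the loss: it gathers the sequence-split hidden states of the last layer along the sequence dimension, and for SP modes that do not share the TP group it scales the backward gradient by the SP world size to undo the ZeRO gradient averaging applied over the DP * SP group. A minimal usage sketch (the call site and shapes are assumptions for illustration, not part of this diff):

    # each SP rank holds hidden_states of shape (B, Sq // sp_size, H)
    hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode="ring_attn")
    # now every rank holds the full (B, Sq, H) tensor for the LM head / cross entropy
    logits = lm_head(hidden_states)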
@@ -2,7 +2,10 @@ from enum import Enum
 from typing import Callable, Dict, Optional, Tuple

 import torch
+import torch.distributed
+import torch.distributed as dist
 import torch.nn.functional as F
+from einops import rearrange

 from colossalai.kernel.kernel_loader import (
     FlashAttentionForFloatAndCustomMaskLoader,
@@ -10,12 +13,18 @@ from colossalai.kernel.kernel_loader import (
     FlashAttentionWithCustomMaskLoader,
     KernelLoader,
 )
+from colossalai.logging import get_dist_logger
+
+from .utils import RingComm, get_half_index, split_varlen_zigzag

 __all__ = [
     "AttnMaskType",
     "ColoAttention",
 ]

+_flash_attn_forward = _flash_attn_backward = None
+_unpad_input = _pad_input = None
+

 class AttnMaskType(Enum):
     CUSTOM = 0
@@ -38,20 +47,32 @@ def invert_mask(mask: torch.Tensor) -> torch.Tensor:


 # adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
-def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.Tensor]:
+def get_pad_info(
+    padding_mask: torch.Tensor, invert: Optional[bool] = False, return_indices: Optional[bool] = True
+) -> Tuple[int, torch.Tensor, torch.Tensor]:
     """Get padding information from padding mask.

     Args:
-        padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, S]
+        padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, Skv]
+        invert (Optional[bool], optional): Whether to reverse the padding mask.
+        return_indices (Optional[bool], optional): Whether to return the indices of non-masked tokens.

     Returns:
-        Tuple[int, torch.Tensor, torch.Tensor]: Tuple of (max_seq_len, cu_seqlens, indices)
+        max_seqlen_in_batch (int): Maximum sequence length in the batch.
+        cu_seqlens (torch.Tensor): Shape [B+1]. Cumulative sequence lengths of the sequences in the batch.
+        indices (torch.Tensor): Shape [total_nonzero]. The indices of non-masked tokens from the flattened input sequence.
     """
+    if invert:
+        padding_mask = padding_mask.logical_not()
     seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
-    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
+    if return_indices:
+        indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
+
     max_seqlen_in_batch = seqlens_in_batch.max().item()
     cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
-    return max_seqlen_in_batch, cu_seqlens, indices
+    if return_indices:
+        return max_seqlen_in_batch, cu_seqlens, indices
+    return max_seqlen_in_batch, cu_seqlens


 class ColoAttention:
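To make the new get_pad_info contract concrete, a small worked example (illustrative only; 1 marks a valid token as in the docstring):

    import torch

    padding_mask = torch.tensor([[1, 1, 1, 0],
                                 [1, 1, 0, 0]])  # two sequences of lengths 3 and 2, padded to 4
    max_seqlen, cu_seqlens, indices = get_pad_info(padding_mask)
    # max_seqlen == 3
    # cu_seqlens == tensor([0, 3, 5], dtype=torch.int32)
    # indices    == tensor([0, 1, 2, 4, 5])  # valid-token positions in the flattened [B * Skv] view
    # get_pad_info(padding_mask, return_indices=False) returns only (max_seqlen, cu_seqlens)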
@@ -107,6 +128,7 @@ class ColoAttention:
         q_padding_mask: Optional[torch.Tensor] = None,
         kv_padding_mask: Optional[torch.Tensor] = None,
         is_causal: bool = False,
+        invert: bool = True,
     ) -> Dict[str, torch.Tensor]:
         """Return a dictionary of keyword arguments for attention function. It supports 4 mask type.
         1. custom mask: no padding mask and is_causal=False, return {}, users should handle attention mask by themselves.
@@ -124,7 +146,7 @@ class ColoAttention:
                 The shape should be [B, Skv]. ``1`` means valid token, and ``0`` means padding token.
                 If it's None and ``q_padding_mask`` is not None, it will be set to ``q_padding_mask``. Defaults to None.
             is_causal (bool, optional): Whether to use causal attention mask. Defaults to False.
+            invert_mask (bool, optional): Whether to invert the mask. Defaults to True.
         Returns:
             Dict[str, torch.Tensor]: Dictionary of keyword arguments for attention function.
         """
@@ -154,7 +176,7 @@ class ColoAttention:
             assert kv_padding_mask.shape == (
                 b,
                 s_kv,
-            ), f"q_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
+            ), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
             attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
             outputs.update(
                 {
@@ -172,7 +194,8 @@ class ColoAttention:
                 attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
             else:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED
-            attention_mask = invert_mask(attention_mask).unsqueeze(1)
+            if invert:
+                attention_mask = invert_mask(attention_mask).unsqueeze(1)
             outputs["attention_mask"] = attention_mask
         return outputs

@@ -191,6 +214,7 @@ class ColoAttention:
         kv_indices: Optional[torch.Tensor] = None,
         dropout_p: float = 0.0,
         scale: Optional[float] = None,
+        **kwargs,
     ) -> torch.Tensor:
         """Flash Attention function. It supports 4 mask type.
         1. custom mask: recv attention_mask
@@ -199,9 +223,9 @@ class ColoAttention:
         4. padded causal mask: recv attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, indices

         Args:
-            q (torch.Tensor): Query tensor. Shape should be [B, N, Sq, D]
-            k (torch.Tensor): Key tensor. Shape should be [B, N, Skv, D]
-            v (torch.Tensor): Value tensor. Shape should be [B, N, Skv, D]
+            q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D]
+            k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Skv, D]
+            v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Skv, D]
             attention_mask (Optional[torch.Tensor], optional): Attention mask tensor. Shape should be [B, 1, Sq, Skv]. Defaults to None.
             attention_mask_type (AttnMaskType, optional): Attention mask type. Defaults to AttnMaskType.CUSTOM.
             cu_seqlens_q (Optional[torch.Tensor], optional): The cumulative sequence lengths
@@ -218,7 +242,7 @@ class ColoAttention:
             scale (Optional[float], optional): Scaling factor applied prior to softmax. Defaults to None.

         Returns:
-            torch.Tensor: Output tensor. Shape should be [B, N, Sq, D]
+            torch.Tensor: Output tensor. Shape should be [B, nHeads, Sq, D]
         """
         # known issue: sdpa does not support attention mask which contains whole row of masked tokens, which leads to nan
         # this case is usaul when padding mask is used and self attention is performed
@@ -252,6 +276,7 @@ class ColoAttention:
         else:
             # if attention_mask is None, attention_mask_type should be the default value
             assert attention_mask_type == AttnMaskType.CUSTOM

         # kernel dispatch
         mask_type = attention_mask_type if attention_mask is not None else None
         attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type)
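The hunk below contains the core of the feature: lazy loaders for flash-attn's varlen kernels, the _rescale_out_lse correction helper, and the RingAttention autograd function with zigzag load balancing plus an optional double (2D) ring. For orientation, the zigzag layout adopted from zhuzilin/ring-flash-attention splits the sequence into 2 * sp_size chunks and gives rank i chunks i and 2 * sp_size - 1 - i, so each rank carries a comparable share of the causal-attention work. A sketch of that assignment (an assumed illustration, not code from this commit):

    def zigzag_chunk_ids(sp_size: int, rank: int):
        # sequence viewed as 2 * sp_size equal chunks
        return [rank, 2 * sp_size - 1 - rank]

    # sp_size = 4: rank 0 -> [0, 7], rank 1 -> [1, 6], rank 2 -> [2, 5], rank 3 -> [3, 4]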
@@ -274,3 +299,858 @@ class ColoAttention:
             q_indices=q_indices,
             kv_indices=kv_indices,
         )
+
+
+def _load_varlen_helpers():
+    """Helper to load functions for padding and unpadding packed sequences.
+    Use only when flash attn is installed
+    """
+    global _pad_input, _unpad_input
+    # Flash attn claims this is more efficient than torch's bool indexing due to avoiding
+    # broadcast
+    if _pad_input is None or _unpad_input is None:
+        try:
+            from flash_attn.bert_padding import index_first_axis, pad_input
+
+            def unpad_input(hidden_states: torch.Tensor, indices: torch.Tensor):
+                return index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices)
+
+            _pad_input = pad_input
+            _unpad_input = unpad_input
+        except ImportError as e:
+            raise RuntimeError(
+                f"Flash Attention is not installed. You can install it via 'pip install flash-attn --no-build-isolation'"
+            ) from e
+
+
+def _load_flash_attn():
+    """A light-weight loader to check whether flash-attn is installed.
+    Can't use ColoAttention._dispatch_kernel because we mutate the backward pass
+    """
+    global _flash_attn_forward, _flash_attn_backward
+    if _flash_attn_forward is None or _flash_attn_backward is None:
+        try:
+            from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward
+            from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward
+        except ImportError as e:
+            raise RuntimeError(
+                f"Flash Attention is not installed. You can install it via 'pip install flash-attn --no-build-isolation'"
+            ) from e
+
+    _load_varlen_helpers()
+
+
+# NOTE: This can cause spawned processes to hang on exit
+# with python 3.9
+@torch.compile()
+def _rescale_out_lse(out, block_out, lse, block_lse):
+    """
+    Compute the new attention denominator:
+        exp(lse) + exp(block_lse) = exp(max_scale) * (exp(min_scale - max_scale) + 1)
+    Args:
+        out: (T, H, D)
+        block_out: (T, H, D)
+        lse: (H, T, 1)
+        block_lse: (H, T, 1)
+    """
+
+    # min_scale = torch.min(lse, block_lse)
+    # max_scale = torch.max(lse, block_lse)
+    # new_lse = max_scale + torch.log(1 + torch.exp(min_scale - max_scale))
+
+    # NOTE: directly assigning to .data here is buggy
+    # probably due to casting dtypes/strides
+    new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
+
+    new_block_lse = torch.exp(block_lse - new_lse)
+    out = (torch.exp(lse - new_lse) * out + new_block_lse * block_out).to(out)
+    lse = new_lse
+
+    # Equivalent to the above
+    # See https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
+    # out = (out - F.sigmoid(block_lse - lse) * (out - block_out))
+    # lse = (lse - F.logsigmoid(lse - block_lse))
+    return out, lse
+
+
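The identity that _rescale_out_lse relies on is the usual online-softmax merge: the running log-sum-exp satisfies lse_new = log(exp(lse) + exp(block_lse)) = lse + log(1 + exp(block_lse - lse)), and the two partial outputs are re-weighted by exp(lse - lse_new) and exp(block_lse - lse_new), so the merged result stays correctly normalized after every ring step. A quick numerical check of the update used above (illustrative only):

    import torch

    lse, block_lse = torch.tensor(1.0), torch.tensor(2.5)
    new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
    assert torch.allclose(new_lse, torch.logaddexp(lse, block_lse))  # log(exp(lse) + exp(block_lse))
    # the weights applied to the two partial outputs sum to 1:
    assert torch.allclose(torch.exp(lse - new_lse) + torch.exp(block_lse - new_lse), torch.tensor(1.0))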
+class RingAttention(torch.autograd.Function):
+    """Implements the Ring Attention from `Ring Attention with Blockwise Transformers for Near-Infinite Context`
+    (https://arxiv.org/abs/2310.01889).
+    For load-balancing we adopted the "zigzag" attention scheme from https://github.com/zhuzilin/ring-flash-attention/tree/main
+    For portable integration with more models, we don't follow the spirit of "block-wise FNN" in the original paper,
+    which requires fusing FFN with the Flash Attention kernel/function (see https://arxiv.org/pdf/2305.19370;
+    implemented in Jax and not optimized).
+    We adopt the double ring topology from LoongTrain (https://arxiv.org/pdf/2406.18485) to fully utilize available
+    NICs on each node, by computing attention within a inner ring first and then sending all KVs to the next
+    ring at once.
+    """
+
+    # Globle cache to avoid recomputation for same-lengthed sequences
+    CU_SEQLENS: torch.Tensor = None  # [B+1]
+    TOTAL_SEQLEN: int = None
+    HALF_INDICES: Tuple = None
+    SUPPORTED_MASK_TYPES = (AttnMaskType.CAUSAL, AttnMaskType.PADDED_CAUSAL)
+    ATTN_DONE: torch.cuda.Event = None
+    SP_STREAM: torch.cuda.Stream = None
+    SP_GROUP: dist.ProcessGroup = None
+    # duplicate process group for concurrent NCCL streams
+    # while both PyTorch and NCCL warns(https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7)
+    # against this, in practice it seems to work fine.
+    INNER_RING_GROUP: dist.ProcessGroup = None
+    INNER_RING_GROUP_COPY: dist.ProcessGroup = None
+    INTER_RING_GROUP: dist.ProcessGroup = None
+    INTER_RING_GROUP_COPY: dist.ProcessGroup = None
+
+    @staticmethod
+    def get_double_ring_groups(sp_group, inner_ring_size=None):
+        """
+        Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
+        shouldn't be larger than the number of NICs on each node.
+        Args:
+            sp_group (dist.ProcessGroup): Process group for sequence parallelism
+            inner_ring_size (Optional[int], optional): Inner ring size. Defaults to None.
+        Returns:
+            Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
+        """
+        sp_size = dist.get_world_size(sp_group)
+        sp_rank = dist.get_rank(sp_group)
+
+        if inner_ring_size is None:
+            if torch.cuda.device_count() >= dist.get_world_size():
+                # single node, no need to consider NICs
+                return sp_group, sp_group
+            if sp_size <= 4:
+                inner_ring_size = min(2, sp_size)
+            else:
+                inner_ring_size = min(4, sp_size)
+        else:
+            assert (
+                inner_ring_size <= sp_size and sp_size % inner_ring_size == 0
+            ), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"
+
+        if inner_ring_size == sp_size:
+            return sp_group, sp_group
+        assert (
+            sp_size % inner_ring_size == 0
+        ), f"sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"
+
+        logger = get_dist_logger()
+        logger.info(
+            f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for speed-ups!",
+            ranks=[0],
+        )
+        num_rings = sp_size // inner_ring_size
+        inner_ring_group = None
+        inter_ring_group = None
+
+        # Create inner ring groups
+        for i in range(inner_ring_size):
+            ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
+            group = dist.new_group(ranks)
+            if sp_rank in ranks:
+                inner_ring_group = group
+
+        # Create inter ring groups
+        for i in range(num_rings):
+            ranks = list(range(i, sp_size, num_rings))
+            group = dist.new_group(ranks)
+            if sp_rank in ranks:
+                inter_ring_group = group
+
+        return inner_ring_group, inter_ring_group
+
+    @staticmethod
+    def attention(
+        q,  # (B, H, Sq, D)
+        k,
+        v,
+        sp_group,
+        attention_mask_type,
+        cu_seqlens=None,
+        max_seqlen=None,
+        valid_indices=None,
+        dropout_p=0.0,
+        softmax_scale=None,
+        deterministic=False,
+        return_softmax=False,
+        inner_ring_size=None,
+        **kwargs,
+    ):
+        """
+        Ring Attention forward pass supporting variable-length sequences. When using varlen mode,
+        each sequence in the batch should have length divisible by sp_size * 2.
+
+        Args:
+            q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D]
+            k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Sq, Sq, D]
+            v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Sq, Sq, D]
+            sp_group (Optional[dist.ProcessGroup]): Process group for sequence parallelism
+            sp_tream (torch.cuda.Stream): An different stream for output correction.
+            cu_seqlens (Optional[torch.Tensor], optional): The cumulative sequence lengths
+                of the sequences in the batch, used to index into q.
+                Shape should be [B+1].
+            max_seqlen (Optional[int], optional): Maximum query sequence length in the batch.
+            valid_indices (Optional[torch.Tensor], optional): The indices of non-masked tokens from get_pad_info.
+                Shape should be [t].
+            dropout_p (float, optional): Dropout probability. Defaults to 0.0.
+            softmax_scale (Optional[float], optional): Scaling factor applied prior to softmax.
+            deterministic (bool, optional): Whether to force deterministic backward pass. See https://github.com/Dao-AILab/flash-attention/issues/349
+            return_softmax (bool, optional): Whether to return the softmax denominator (logsumexp).
+            inner_ring_size (Optional[int], optional): Inner ring size of the 2D ring. By default use a heuristic to decide.
+
+        Returns:
+            out: Output tensor of shape [B, nHeads, Sq, D] or [T, nHeads, D] if pad_output is False.
+            softmax_lse: (if return_softmax is True) Softmax denominator (logsumexp).
+                Shape should be [total_q_seqlen, nHeads]
+        """
+        # Check input args
+        _load_flash_attn()
+        if RingAttention.ATTN_DONE is None:
+            RingAttention.ATTN_DONE = torch.cuda.Event()
+        if RingAttention.SP_STREAM is None:
+            RingAttention.SP_STREAM = torch.cuda.Stream()
+
+        assert (
+            q.shape[2] == k.shape[2]
+        ), "Q, K and V having different sequence lengths (inference or cross-attn)\
+            is not supported yet in training."
+        assert (
+            attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES
+        ), f"Mask type {attention_mask_type} is not supported yet."
+
+        clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg))
+
+        if RingAttention.SP_GROUP is not sp_group:
+            RingAttention.SP_GROUP = sp_group
+            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size)
+            RingAttention.INNER_RING_GROUP = inner_ring_group
+            RingAttention.INTER_RING_GROUP = inter_ring_group
+        else:
+            inner_ring_group = RingAttention.INNER_RING_GROUP
+            inter_ring_group = RingAttention.INTER_RING_GROUP
+
+        # (B, H, Sq, D) -> (B, Sq, H, D)
+        q, k, v = [x.transpose(1, 2).contiguous() for x in (q, k, v)]
+        pad_output = q.dim() == 4
+
+        # Get sequence length info for varlen forward
+        if attention_mask_type == AttnMaskType.CAUSAL:
+            # All sequences share the same length
+            b, sq, h, d = q.shape
+            max_seqlen = sq
+            # Cache to avoid recreation for a single sequence
+            if sq * b == RingAttention.TOTAL_SEQLEN:
+                cu_seqlens = RingAttention.CU_SEQLENS
+            else:
+                cu_seqlens = torch.arange(0, b * sq + 1, sq, device=q.device, dtype=torch.int32)
+                RingAttention.TOTAL_SEQLEN = b * sq
+
+        # "Packed" mode where sequences of different lengths are packed into [total_q_seqlen, H, D]
+        elif attention_mask_type == AttnMaskType.PADDED_CAUSAL:
+            assert (
+                cu_seqlens is not None and max_seqlen is not None and valid_indices is not None
+            ), "Packed mode requires pre-computed cu_seqlens and max_seq_len."
+            if pad_output:
+                b, sq, h, d = q.shape
+                q, k, v = [_unpad_input(x, valid_indices) for x in (q, k, v)]
+
+        out, softmax_lse = RingAttention.apply(
+            q,
+            k,
+            v,
+            sp_group,
+            RingAttention.SP_STREAM,
+            cu_seqlens,
+            max_seqlen,
+            dropout_p,
+            softmax_scale,
+            deterministic,
+            return_softmax,
+            attention_mask_type == AttnMaskType.PADDED_CAUSAL,
+            inner_ring_group,
+            inter_ring_group,
+        )
+
+        if attention_mask_type == AttnMaskType.PADDED_CAUSAL:
+            if pad_output:
+                out = _pad_input(out, valid_indices, b, sq)  # (T, ...) -> (B, Sq, ...)
+                out = out.transpose(1, 2)  # (B, Sq, H, D) -> (B, H, Sq, D)
+        else:
+            out = out.transpose(1, 2)
+
+        if return_softmax:
+            return out, softmax_lse
+        return out
+
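A hypothetical call into the public entry point above, to show how the pieces fit together (shapes, dtypes and the surrounding setup are assumptions for illustration, not taken from the diff):

    # every SP rank passes its local zigzag shard of the sequence
    out = RingAttention.attention(
        q, k, v,                      # each (B, nHeads, Sq_local, D), fp16/bf16
        sp_group,                     # sequence-parallel process group
        AttnMaskType.CAUSAL,          # or PADDED_CAUSAL with cu_seqlens / max_seqlen / valid_indices
        dropout_p=0.0,
        inner_ring_size=None,         # let the topology heuristic build the 2D ring
    )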
|
@staticmethod
|
||||||
|
def forward(
|
||||||
|
ctx,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
sp_group: dist.ProcessGroup,
|
||||||
|
sp_stream: torch.cuda.Stream,
|
||||||
|
cu_seqlens: torch.Tensor,
|
||||||
|
max_seqlen: int,
|
||||||
|
dropout_p: float = 0.0,
|
||||||
|
softmax_scale: Optional[float] = None,
|
||||||
|
deterministic: Optional[bool] = False,
|
||||||
|
return_softmax: Optional[bool] = False,
|
||||||
|
is_packed: Optional[bool] = False,
|
||||||
|
inner_ring_group: Optional[dist.ProcessGroup] = None,
|
||||||
|
inter_ring_group: Optional[dist.ProcessGroup] = None,
|
||||||
|
):
|
||||||
|
|
||||||
|
cu_seqlens_q = cu_seqlens_kv = cu_seqlens
|
||||||
|
max_seqlen_q = max_seqlen_kv = max_seqlen
|
||||||
|
cu_seqlens_half = cu_seqlens // 2
|
||||||
|
max_seqlen_half = max_seqlen // 2
|
||||||
|
|
||||||
|
misc_kwargs = {
|
||||||
|
"window_size": (-1, -1),
|
||||||
|
"alibi_slopes": None,
|
||||||
|
"softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale,
|
||||||
|
"dropout_p": dropout_p,
|
||||||
|
"block_table": None,
|
||||||
|
"softcap": 0.0,
|
||||||
|
"return_softmax": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
RingAttention.HALF_INDICES is not None
|
||||||
|
and cu_seqlens.shape == RingAttention.CU_SEQLENS.shape
|
||||||
|
and (cu_seqlens == RingAttention.CU_SEQLENS).all()
|
||||||
|
):
|
||||||
|
half_idx_front, half_idx_back = RingAttention.HALF_INDICES
|
||||||
|
else:
|
||||||
|
half_idx_front = get_half_index(cu_seqlens, front=True)
|
||||||
|
half_idx_back = get_half_index(cu_seqlens, front=False)
|
||||||
|
RingAttention.HALF_INDICES = (half_idx_front, half_idx_back)
|
||||||
|
RingAttention.CU_SEQLENS = cu_seqlens
|
||||||
|
|
||||||
|
if is_packed:
|
||||||
|
t, h, d = q.shape
|
||||||
|
else:
|
||||||
|
b, sq, h, d = q.shape
|
||||||
|
t = b * sq
|
||||||
|
# Be careful about GQA/MQA in reshape
|
||||||
|
q, k, v = [x.view(t, *x.shape[-2:]) for x in (q, k, v)]
|
||||||
|
|
||||||
|
if inner_ring_group is None or inter_ring_group is None:
|
||||||
|
# Use one ring if not specified
|
||||||
|
inner_ring_group = inter_ring_group = sp_group
|
||||||
|
|
||||||
|
sp_size = dist.get_world_size(sp_group)
|
||||||
|
sp_rank = dist.get_rank(sp_group)
|
||||||
|
# Attempt to achieve concurrent comm in the two-stream forward
|
||||||
|
local_kv_comms = [RingComm(inner_ring_group) for _ in range(2)]
|
||||||
|
inter_ring_comm = RingComm(inter_ring_group)
|
||||||
|
local_sp_size = dist.get_world_size(inner_ring_group)
|
||||||
|
local_sp_rank = dist.get_rank(inner_ring_group)
|
||||||
|
inter_ring_rank = dist.get_rank(inter_ring_group) if inter_ring_group is not sp_group else 0
|
||||||
|
num_rings = dist.get_world_size(inter_ring_group) if inter_ring_group is not sp_group else 1
|
||||||
|
|
||||||
|
# Non-contiguous indexing copies to a new contiguous tensor,
|
||||||
|
# so only do it once
|
||||||
|
if sp_rank != sp_size - 1:
|
||||||
|
q1 = q[half_idx_back]
|
||||||
|
|
||||||
|
# Pre-allocate double buffer for overlapping and receiving next step's inputs
|
||||||
|
kv_buffers = [torch.stack((k, v))] # (2, B, Sq, H, D)
|
||||||
|
kv_buffers.append(torch.empty_like(kv_buffers[0]))
|
||||||
|
|
||||||
|
# outputs
|
||||||
|
out = None
|
||||||
|
block_out = [None, None]
|
||||||
|
softmax_lse = [None, None]
|
||||||
|
block_softmax_lse = [None, None] # log sum exp, the denominator of softmax in attention
|
||||||
|
rng_states = [None for _ in range(sp_size)]
|
||||||
|
sp_streams = [torch.cuda.current_stream(), sp_stream]
|
||||||
|
|
||||||
|
def _forward(q, k, v, causal):
|
||||||
|
(
|
||||||
|
_,
|
||||||
|
_,
|
||||||
|
_,
|
||||||
|
_,
|
||||||
|
out,
|
||||||
|
softmax_lse,
|
||||||
|
_,
|
||||||
|
rng_state,
|
||||||
|
) = _flash_attn_forward(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
|
||||||
|
cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
|
||||||
|
max_seqlen_q if q.shape[0] == t else max_seqlen_half,
|
||||||
|
max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
|
||||||
|
causal=causal,
|
||||||
|
**misc_kwargs,
|
||||||
|
)
|
||||||
|
return out, softmax_lse, rng_state
|
||||||
|
|
||||||
|
def _local_ring_forward():
|
||||||
|
# (Hopefully) overlap output correction with next flash attn
|
||||||
|
for i in range(local_sp_size):
|
||||||
|
with torch.cuda.stream(sp_streams[i % 2]):
|
||||||
|
# Wait for current kv from prev rank
|
||||||
|
# NOTE: waiting outside the current stream will NOT correctly synchronize.
|
||||||
|
if i > 0:
|
||||||
|
local_kv_comms[(i + 1) % 2].wait()
|
||||||
|
|
||||||
|
# Avoid overwriting attn input when it shares mem with buffer
|
||||||
|
if not RingAttention.ATTN_DONE.query():
|
||||||
|
kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
|
||||||
|
if i < local_sp_size - 1:
|
||||||
|
local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
|
||||||
|
|
||||||
|
if i == 0:
|
||||||
|
# Compute with local KV; no mask
|
||||||
|
kv_block = kv_buffers[0]
|
||||||
|
q_block = q
|
||||||
|
(block_out[i % 2], block_softmax_lse[i % 2], rng_states[i]) = _forward( # (T, H, D) # (H, T)
|
||||||
|
q_block, kv_block[0], kv_block[1], causal=True
|
||||||
|
)
|
||||||
|
elif i <= local_sp_rank:
|
||||||
|
# Received the "surrounding" kv chunks
|
||||||
|
# Drop the second half of received kv
|
||||||
|
# (2, t // 2, H, D)
|
||||||
|
kv_block = kv_buffers[i % 2][:, half_idx_front]
|
||||||
|
q_block = q
|
||||||
|
(
|
||||||
|
block_out[i % 2], # (T, H, D)
|
||||||
|
block_softmax_lse[i % 2], # (H, T)
|
||||||
|
rng_states[i],
|
||||||
|
) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
|
||||||
|
else:
|
||||||
|
# Received the inner kv chunks
|
||||||
|
# Drop the first half of q
|
||||||
|
kv_block = kv_buffers[i % 2]
|
||||||
|
q_block = q1
|
||||||
|
(
|
||||||
|
block_out[i % 2], # (T, H, D)
|
||||||
|
block_softmax_lse[i % 2], # (H, T)
|
||||||
|
rng_states[i],
|
||||||
|
) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
|
||||||
|
RingAttention.ATTN_DONE.record()
|
||||||
|
|
||||||
|
block_softmax_lse[i % 2] = (
|
||||||
|
block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
|
||||||
|
) # (H, T) -> (T, H, 1)
|
||||||
|
assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1]
|
||||||
|
# Output and log sum exp correction. Ideally overlap this with the next flash attn kernel.
|
||||||
|
# In reality this always finishes before next flash attn; no need for extra sync.
|
||||||
|
if i == 0:
|
||||||
|
out = block_out[0]
|
||||||
|
softmax_lse = block_softmax_lse[0]
|
||||||
|
elif i <= local_sp_rank:
|
||||||
|
out, softmax_lse = _rescale_out_lse(
|
||||||
|
out, block_out[i % 2], softmax_lse, block_softmax_lse[i % 2]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
out[half_idx_back], softmax_lse[half_idx_back] = _rescale_out_lse(
|
||||||
|
out[half_idx_back], block_out[i % 2], softmax_lse[half_idx_back], block_softmax_lse[i % 2]
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.cuda.current_stream().wait_stream(sp_stream)
|
||||||
|
return out, softmax_lse
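# --- Editorial sketch (illustrative, not part of this patch) ---
# The correction applied by _rescale_out_lse above merges two partial attention
# results computed over disjoint KV blocks for the same queries, using the standard
# log-sum-exp combination from flash attention. A minimal, unfused reference,
# assuming out/block_out are (T, H, D) and lse/block_lse are (T, H, 1) float tensors:
import torch

def naive_rescale_out_lse(out, block_out, lse, block_lse):
    # new_lse = log(exp(lse) + exp(block_lse))
    new_lse = lse + torch.log1p(torch.exp(block_lse - lse))
    # Re-weight both partial outputs by their share of the merged softmax denominator
    new_out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
    return new_out, new_lse
# The in-tree _rescale_out_lse may fuse this (e.g. with a Triton rescale kernel),
# but the math it implements is this merge.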
|
||||||
|
|
||||||
|
def _other_ring_forward(ring_num_idx, out, softmax_lse):
|
||||||
|
# Loop through the inner ring after receiving
|
||||||
|
# all new KVs from the previous inner ring
|
||||||
|
for i in range(local_sp_size):
|
||||||
|
with torch.cuda.stream(sp_streams[i % 2]):
|
||||||
|
if not RingAttention.ATTN_DONE.query():
|
||||||
|
kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
|
||||||
|
if i < local_sp_size - 1:
|
||||||
|
local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
|
||||||
|
|
||||||
|
# Send & recv KV
|
||||||
|
if i > 0:
|
||||||
|
local_kv_comms[(i + 1) % 2].wait()
|
||||||
|
|
||||||
|
if ring_num_idx > inter_ring_rank:
|
||||||
|
kv_block = kv_buffers[i % 2]
|
||||||
|
(
|
||||||
|
block_out[i % 2],
|
||||||
|
block_softmax_lse[i % 2],
|
||||||
|
rng_states[i + local_sp_size * ring_num_idx],
|
||||||
|
) = _forward(q1, kv_block[0], kv_block[1], causal=False)
|
||||||
|
RingAttention.ATTN_DONE.record()
|
||||||
|
block_softmax_lse[i % 2] = (
|
||||||
|
block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
|
||||||
|
)
|
||||||
|
out[half_idx_back], softmax_lse[half_idx_back] = _rescale_out_lse(
|
||||||
|
out[half_idx_back], block_out[i % 2], softmax_lse[half_idx_back], block_softmax_lse[i % 2]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
kv_block = kv_buffers[i % 2][:, half_idx_front]
|
||||||
|
(
|
||||||
|
block_out[i % 2],
|
||||||
|
block_softmax_lse[i % 2],
|
||||||
|
rng_states[i + local_sp_size * ring_num_idx],
|
||||||
|
) = _forward(q, kv_block[0], kv_block[1], causal=False)
|
||||||
|
RingAttention.ATTN_DONE.record()
|
||||||
|
block_softmax_lse[i % 2] = (
|
||||||
|
block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
|
||||||
|
)
|
||||||
|
out, softmax_lse = _rescale_out_lse(
|
||||||
|
out, block_out[i % 2], softmax_lse, block_softmax_lse[i % 2]
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.cuda.current_stream().wait_stream(sp_stream)
|
||||||
|
return out, softmax_lse
|
||||||
|
|
||||||
|
# Send and recv KV between rings at once to maximize NIC util.
|
||||||
|
inter_ring_kv = None
|
||||||
|
for ring_num_idx in range(num_rings):
|
||||||
|
if ring_num_idx > 0:
|
||||||
|
inter_ring_comm.wait()
|
||||||
|
# Reset indices
|
||||||
|
kv_buffers[0] = inter_ring_kv
|
||||||
|
|
||||||
|
if ring_num_idx < num_rings - 1:
|
||||||
|
if ring_num_idx == 0:
|
||||||
|
to_send = kv_buffers[0]
|
||||||
|
else:
|
||||||
|
# The last received KV
|
||||||
|
to_send = kv_buffers[(local_sp_size - 1) % 2]
|
||||||
|
inter_ring_kv = inter_ring_comm.send_recv(to_send)
|
||||||
|
|
||||||
|
if ring_num_idx == 0:
|
||||||
|
out, softmax_lse = _local_ring_forward()
|
||||||
|
else:
|
||||||
|
out, softmax_lse = _other_ring_forward(ring_num_idx, out, softmax_lse)
|
||||||
|
|
||||||
|
out = out.to(q.dtype)
|
||||||
|
if not is_packed:
|
||||||
|
out = out.view(b, sq, h, d)
|
||||||
|
q, k, v = [x.view(b, sq, *x.shape[-2:]) for x in (q, k, v)] # (T, H, D) -> (B, Sq, H, D)
|
||||||
|
softmax_lse = softmax_lse.squeeze(-1)
|
||||||
|
|
||||||
|
ctx.sp_group = sp_group
|
||||||
|
ctx.max_seqlen_q = ctx.max_seqlen_kv = max_seqlen
|
||||||
|
misc_kwargs["deterministic"] = deterministic
|
||||||
|
del misc_kwargs["return_softmax"]
|
||||||
|
ctx.misc_kwargs = misc_kwargs
|
||||||
|
ctx.is_packed = is_packed
|
||||||
|
|
||||||
|
ctx.kv_group = inner_ring_group
|
||||||
|
ctx.inter_kv_group = inter_ring_group
|
||||||
|
|
||||||
|
ctx.save_for_backward(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
softmax_lse.transpose(0, 1).contiguous(), # (T, H) -> (H, T)
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_kv,
|
||||||
|
half_idx_front,
|
||||||
|
half_idx_back,
|
||||||
|
*rng_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
if return_softmax:
|
||||||
|
return out, softmax_lse
|
||||||
|
return out, None
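# --- Editorial sketch (illustrative, not part of this patch) ---
# The per-step block schedule used by _local_ring_forward above, written out as a
# tiny standalone helper: for inner-ring rank r and ring step i it shows which query
# half, which KV half, and which mask the flash-attention call uses.
def zigzag_step_schedule(local_sp_rank: int, local_sp_size: int):
    schedule = []
    for i in range(local_sp_size):
        if i == 0:
            # Local block: full q vs. full local kv, causal mask
            schedule.append(("q: full", "kv: full", "causal"))
        elif i <= local_sp_rank:
            # KV from a preceding rank: keep only its first half, no mask
            schedule.append(("q: full", "kv: first half", "no mask"))
        else:
            # KV from a later rank: attend with only the second half of q, no mask
            schedule.append(("q: second half", "kv: full", "no mask"))
    return schedule

# e.g. zigzag_step_schedule(1, 4) ->
# [("q: full", "kv: full", "causal"), ("q: full", "kv: first half", "no mask"),
#  ("q: second half", "kv: full", "no mask"), ("q: second half", "kv: full", "no mask")]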
|
||||||
|
|
||||||
|
def backward(ctx, dout, _):
|
||||||
|
"""
|
||||||
|
During backward, we accumulate q grads on each rank locally, but iterate kv and their grads
|
||||||
|
over all ranks for accumulation.
|
||||||
|
"""
|
||||||
|
(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_kv, half_idx_front, half_idx_back) = ctx.saved_tensors[:9]
|
||||||
|
rng_states = ctx.saved_tensors[9:]
|
||||||
|
|
||||||
|
is_packed = ctx.is_packed
|
||||||
|
max_seqlen_q = ctx.max_seqlen_q
|
||||||
|
max_seqlen_kv = ctx.max_seqlen_kv
|
||||||
|
cu_seqlens_half = cu_seqlens_q // 2
|
||||||
|
max_seqlen_half = max_seqlen_q // 2
|
||||||
|
misc_kwargs = ctx.misc_kwargs
|
||||||
|
del misc_kwargs["block_table"]
|
||||||
|
|
||||||
|
assert (
|
||||||
|
out.shape == dout.shape == q.shape
|
||||||
|
), f"out {out.shape} and dout {dout.shape} should have the same shape ({q.shape})."
|
||||||
|
|
||||||
|
if is_packed:
|
||||||
|
t, h, d = q.shape
|
||||||
|
else:
|
||||||
|
b, sq, h, d = q.shape
|
||||||
|
t = b * sq
|
||||||
|
q, k, v, out, dout = [x.view(t, *x.shape[-2:]) for x in (q, k, v, out, dout)]
|
||||||
|
|
||||||
|
# Sequence parallel args
|
||||||
|
sp_group = ctx.sp_group
|
||||||
|
local_kv_group = ctx.kv_group
|
||||||
|
inter_kv_group = ctx.inter_kv_group
|
||||||
|
|
||||||
|
local_sp_rank = dist.get_rank(sp_group)
|
||||||
|
sp_size = dist.get_world_size(sp_group)
|
||||||
|
# Using separate streams (pg) for concurrent kv and dkv comm may
|
||||||
|
# cause NCCL "software caused connection abort" here...
|
||||||
|
local_kv_comm = RingComm(local_kv_group)
|
||||||
|
local_dkv_comm = RingComm(local_kv_group)
|
||||||
|
inter_kv_comm = RingComm(inter_kv_group)
|
||||||
|
inter_dkv_comm = RingComm(inter_kv_group)
|
||||||
|
local_sp_size = dist.get_world_size(local_kv_group)
|
||||||
|
local_sp_rank = dist.get_rank(local_kv_group)
|
||||||
|
|
||||||
|
if dist.get_world_size(inter_kv_group) != sp_size:
|
||||||
|
num_rings = dist.get_world_size(inter_kv_group)
|
||||||
|
inter_ring_rank = dist.get_rank(inter_kv_group)
|
||||||
|
else:
|
||||||
|
num_rings = 1
|
||||||
|
inter_ring_rank = 0
|
||||||
|
|
||||||
|
if local_sp_rank != sp_size - 1:
|
||||||
|
softmax_lse1 = softmax_lse[:, half_idx_back]
|
||||||
|
dout = dout.contiguous()
|
||||||
|
|
||||||
|
# Double comm buffers for sending and receiving kv
|
||||||
|
kv_buffers = [torch.stack((k, v))] # (2, T, H, D)
|
||||||
|
kv_buffers.append(torch.empty_like(kv_buffers[0]))
|
||||||
|
|
||||||
|
dq = None # (T, H, D)
|
||||||
|
# Intermediate outputs
|
||||||
|
dq_block = torch.empty_like(q) # (T, H, D)
|
||||||
|
dk_block = torch.empty_like(k) # (T, H, D)
|
||||||
|
dv_block = torch.empty_like(v) # (T, H, D)
|
||||||
|
dkv_buffers = [torch.empty_like(kv, dtype=torch.float32) for kv in kv_buffers] # (T, H, D)
|
||||||
|
del k, v
|
||||||
|
|
||||||
|
def _backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, rng_state, causal):
|
||||||
|
_flash_attn_backward(
|
||||||
|
dout,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
softmax_lse,
|
||||||
|
dq,
|
||||||
|
dk,
|
||||||
|
dv,
|
||||||
|
cu_seqlens_q if dq.shape[0] == t else cu_seqlens_half,
|
||||||
|
cu_seqlens_kv if dk.shape[0] == t else cu_seqlens_half,
|
||||||
|
max_seqlen_q if dq.shape[0] == t else max_seqlen_half,
|
||||||
|
max_seqlen_kv if dk.shape[0] == t else max_seqlen_half,
|
||||||
|
causal=causal,
|
||||||
|
rng_state=rng_state,
|
||||||
|
**misc_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTE: We avoid using two streams due to doubled buffers
|
||||||
|
# and that backward is more communication intensive.
|
||||||
|
def _local_ring_backward():
|
||||||
|
for i in range(local_sp_size):
|
||||||
|
if i > 0:
|
||||||
|
local_kv_comm.wait()
|
||||||
|
|
||||||
|
if i < local_sp_size - 1:
|
||||||
|
# Send kv to next rank for backward
|
||||||
|
local_kv_comm.send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
|
||||||
|
|
||||||
|
if i == 0:
|
||||||
|
# Backward with local kv
|
||||||
|
k_, v_ = kv_buffers[i % 2]
|
||||||
|
q_, dout_, out_ = q, dout, out
|
||||||
|
dq_, dk_, dv_ = dq_block, dk_block, dv_block
|
||||||
|
_backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_states[i], causal=True)
|
||||||
|
|
||||||
|
elif i <= local_sp_rank:
|
||||||
|
# Drop the second half of kv
|
||||||
|
# (T, H, D) -> (T // 2, H, D)
|
||||||
|
k_, v_ = [x[half_idx_front] for x in kv_buffers[i % 2]]
|
||||||
|
dk_, dv_ = [x[: t // 2] for x in (dk_block, dv_block)]
|
||||||
|
dq_, q_, out_, dout_ = (dq_block, q, out, dout)
|
||||||
|
_backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_states[i], causal=False)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Drop the first half of q
|
||||||
|
k_, v_ = kv_buffers[i % 2]
|
||||||
|
dk_, dv_ = dk_block, dv_block
|
||||||
|
q_, out_, dout_ = [x[half_idx_back] for x in (q, out, dout)]
|
||||||
|
dq_ = dq_block[: t // 2]
|
||||||
|
_backward(dout_, q_, k_, v_, out_, softmax_lse1, dq_, dk_, dv_, rng_states[i], causal=False)
|
||||||
|
|
||||||
|
# Accumulate grads
|
||||||
|
if i == 0:
|
||||||
|
dq = dq_block.float()
|
||||||
|
dkv_buffers[i % 2][0] = dk_block.float()
|
||||||
|
dkv_buffers[i % 2][1] = dv_block.float()
|
||||||
|
else:
|
||||||
|
# Accumulate local dq
|
||||||
|
if i <= local_sp_rank:
|
||||||
|
dq += dq_ # (T, H, D)
|
||||||
|
else:
|
||||||
|
dq[half_idx_back] += dq_
|
||||||
|
|
||||||
|
# Wait for the kv grad accumulators traveling around the ring
|
||||||
|
local_dkv_comm.wait()
|
||||||
|
|
||||||
|
if i <= local_sp_rank:
|
||||||
|
# q blocks "surrounded" by kv blocks
|
||||||
|
dkv_buffers[i % 2][0][half_idx_front] += dk_
|
||||||
|
dkv_buffers[i % 2][1][half_idx_front] += dv_
|
||||||
|
else:
|
||||||
|
# q blocks "surrounding" kv blocks
|
||||||
|
dkv_buffers[i % 2][0] += dk_
|
||||||
|
dkv_buffers[i % 2][1] += dv_
|
||||||
|
local_dkv_comm.send_recv(send_tensor=dkv_buffers[i % 2], recv_tensor=dkv_buffers[(i + 1) % 2])
|
||||||
|
|
||||||
|
local_dkv_comm.wait()
|
||||||
|
dkv_recv = dkv_buffers[local_sp_size % 2]
|
||||||
|
dkv_send = dkv_buffers[(local_sp_size - 1) % 2]
|
||||||
|
return dq, dkv_recv, dkv_send
|
||||||
|
|
||||||
|
def _other_ring_backward(ring_num_idx, dq):
|
||||||
|
if ring_num_idx > inter_ring_rank:
|
||||||
|
# Indexing is expensive
|
||||||
|
q_, out_, dout_ = [x[half_idx_back] for x in (q, out, dout)]
|
||||||
|
else:
|
||||||
|
q_, out_, dout_ = (q, out, dout)
|
||||||
|
|
||||||
|
for i in range(local_sp_size):
|
||||||
|
if i > 0:
|
||||||
|
local_kv_comm.wait()
|
||||||
|
|
||||||
|
if i < local_sp_size - 1:
|
||||||
|
local_kv_comm.send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
|
||||||
|
|
||||||
|
rng_state = rng_states[i + local_sp_size * ring_num_idx]
|
||||||
|
if ring_num_idx > inter_ring_rank:
|
||||||
|
k_, v_ = kv_buffers[i % 2]
|
||||||
|
dk_, dv_ = dk_block, dv_block
|
||||||
|
dq_ = dq_block[: t // 2]
|
||||||
|
_backward(dout_, q_, k_, v_, out_, softmax_lse1, dq_, dk_, dv_, rng_state, causal=False)
|
||||||
|
|
||||||
|
dq[half_idx_back] += dq_
|
||||||
|
if i > 0:
|
||||||
|
local_dkv_comm.wait()
|
||||||
|
else:
|
||||||
|
inter_dkv_comm.wait()
|
||||||
|
|
||||||
|
dkv_buffers[i % 2][0] += dk_
|
||||||
|
dkv_buffers[i % 2][1] += dv_
|
||||||
|
else:
|
||||||
|
k_, v_ = [x[half_idx_front] for x in kv_buffers[i % 2]]
|
||||||
|
dk_, dv_ = [x[: t // 2] for x in (dk_block, dv_block)]
|
||||||
|
dq_ = dq_block
|
||||||
|
_backward(dout_, q_, k_, v_, out_, softmax_lse, dq_, dk_, dv_, rng_state, causal=False)
|
||||||
|
|
||||||
|
dq += dq_
|
||||||
|
if i > 0:
|
||||||
|
local_dkv_comm.wait()
|
||||||
|
else:
|
||||||
|
inter_dkv_comm.wait()
|
||||||
|
|
||||||
|
dkv_buffers[i % 2][0][half_idx_front] += dk_
|
||||||
|
dkv_buffers[i % 2][1][half_idx_front] += dv_
|
||||||
|
|
||||||
|
local_dkv_comm.send_recv(send_tensor=dkv_buffers[i % 2], recv_tensor=dkv_buffers[(i + 1) % 2])
|
||||||
|
|
||||||
|
local_dkv_comm.wait()
|
||||||
|
dkv_recv = dkv_buffers[local_sp_size % 2]
|
||||||
|
dkv_send = dkv_buffers[(local_sp_size - 1) % 2]
|
||||||
|
return dq, dkv_recv, dkv_send
|
||||||
|
|
||||||
|
inter_ring_kv = None
|
||||||
|
for ring_num_idx in range(num_rings):
|
||||||
|
if ring_num_idx > 0:
|
||||||
|
inter_kv_comm.wait()
|
||||||
|
kv_buffers[0] = inter_ring_kv
|
||||||
|
|
||||||
|
if ring_num_idx < num_rings - 1:
|
||||||
|
# Re-allocate a buffer in each inter-ring step
|
||||||
|
inter_ring_kv = inter_kv_comm.send_recv(kv_buffers[0])
|
||||||
|
|
||||||
|
if ring_num_idx == 0:
|
||||||
|
dq, dkv_recv, dkv_send = _local_ring_backward()
|
||||||
|
else:
|
||||||
|
dq, dkv_recv, dkv_send = _other_ring_backward(ring_num_idx, dq)
|
||||||
|
|
||||||
|
if num_rings > 1:
|
||||||
|
# Reuse the local buffers
|
||||||
|
inter_dkv_comm.send_recv(send_tensor=dkv_recv, recv_tensor=dkv_send)
|
||||||
|
# Reset indices
|
||||||
|
dkv_buffers[0] = dkv_send
|
||||||
|
dkv_buffers[1] = dkv_recv
|
||||||
|
if ring_num_idx == num_rings - 1:
|
||||||
|
inter_dkv_comm.wait()
|
||||||
|
dkv_recv = dkv_buffers[0]
|
||||||
|
|
||||||
|
dq, dk, dv = [x.to(q.dtype) for x in (dq, *dkv_recv)]
|
||||||
|
if not is_packed:
|
||||||
|
dq, dk, dv = [x.view(b, sq, *x.shape[-2:]) for x in (dq, dk, dv)]
|
||||||
|
|
||||||
|
return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def prepare_varlen_batch(
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
sp_group: dist.ProcessGroup,
|
||||||
|
inputs_embeds: torch.Tensor = None,
|
||||||
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
|
is_label: bool = False,
|
||||||
|
is_2d: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Preprocess a batch of padded sequences by splitting each input sequence by sp_size
|
||||||
|
sequence-wise and packing them into one sequence. Updates the mask info accordingly.
|
||||||
|
Args:
|
||||||
|
attention_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked.
|
||||||
|
sp_group (dist.ProcessGroup): Process group for sequence parallelism
|
||||||
|
inputs_embeds (torch.Tensor): Input embeddings. Shape should be [B, Sq, ...]
|
||||||
|
position_ids (Optional[torch.Tensor], optional): Position ids of shape [Sq] or [1, Sq]. Defaults to None.
|
||||||
|
is_label (bool, optional): Whether inputs_embeds is instead a label tensor. If True, mask out the first
|
||||||
|
token of each sequence.
|
||||||
|
is_2d (bool, optional): Whether to return 2D outputs padded to max_seqlen // sp_size or flatten
|
||||||
|
the batch dim to a packed 1d sequence. Contingent on model forward shape definitions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
inputs_embeds: Packed input embeddings of shape [B, Sq // sp_size, ...].
|
||||||
|
mask_info: A dictionary of mask info.
|
||||||
|
position_ids: Packed position ids of shape [..., Sq // sp_size].
|
||||||
|
|
||||||
|
"""
|
||||||
|
_load_varlen_helpers()
|
||||||
|
sp_size = dist.get_world_size(group=sp_group)
|
||||||
|
sp_rank = dist.get_rank(group=sp_group)
|
||||||
|
mask_info = {}
|
||||||
|
mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(attention_mask, return_indices=False)
|
||||||
|
|
||||||
|
# Unpad, split seq-wise, then pad back to (B, max_seqlen // sp_size)
|
||||||
|
# Split mask to compute local nonzero position indices
|
||||||
|
# (B, Sq) -> (B, max_seqlen // sp_size)
|
||||||
|
attention_mask = attention_mask[:, : mask_info["max_seqlen"]]
|
||||||
|
if inputs_embeds is not None:
|
||||||
|
inputs_embeds = inputs_embeds[:, : mask_info["max_seqlen"]]
|
||||||
|
inputs_embeds = split_varlen_zigzag(
|
||||||
|
inputs_embeds,
|
||||||
|
mask_info["cu_seqlens"],
|
||||||
|
sp_group,
|
||||||
|
mask_info["max_seqlen"],
|
||||||
|
is_2d=is_2d,
|
||||||
|
is_label=is_label,
|
||||||
|
)
|
||||||
|
attention_mask = split_varlen_zigzag(
|
||||||
|
attention_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_2d=is_2d
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_ids is not None:
|
||||||
|
indices = torch.tensor([sp_rank, 2 * sp_size - sp_rank - 1], device=position_ids.device)
|
||||||
|
position_ids = (
|
||||||
|
position_ids[..., : mask_info["max_seqlen"]] # unpad
|
||||||
|
.view(-1, sp_size * 2, mask_info["max_seqlen"] // (sp_size * 2))
|
||||||
|
.index_select(-2, indices)
|
||||||
|
.view(-1, mask_info["max_seqlen"] // sp_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
mask_info["max_seqlen"] //= sp_size
|
||||||
|
mask_info["valid_indices"] = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||||
|
mask_info["cu_seqlens"] //= sp_size
|
||||||
|
mask_info["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
|
||||||
|
return inputs_embeds, mask_info, position_ids
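# --- Editorial sketch (illustrative, not part of this patch) ---
# Minimal single-process illustration of the zigzag position-id selection above,
# assuming sp_size = 2 and max_seqlen = 8, so each rank keeps chunks
# [sp_rank, 2 * sp_size - sp_rank - 1] of the 2 * sp_size chunks:
import torch

sp_size, max_seqlen = 2, 8
position_ids = torch.arange(max_seqlen).view(1, -1)  # [[0, 1, 2, 3, 4, 5, 6, 7]]
for sp_rank in range(sp_size):
    idx = torch.tensor([sp_rank, 2 * sp_size - sp_rank - 1])
    local = (
        position_ids.view(-1, sp_size * 2, max_seqlen // (sp_size * 2))
        .index_select(-2, idx)
        .view(-1, max_seqlen // sp_size)
    )
    print(sp_rank, local.tolist())  # rank 0 -> [[0, 1, 6, 7]], rank 1 -> [[2, 3, 4, 5]]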
|
||||||
|
|
|
@ -200,9 +200,7 @@ class Linear1D_Col(ParallelModule):
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
bias = self.bias if not self.skip_bias_add else None
|
bias = self.bias if not self.skip_bias_add else None
|
||||||
|
|
||||||
if self.seq_parallel_mode is None:
|
if self.seq_parallel_mode == "split_gather":
|
||||||
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
|
|
||||||
elif self.seq_parallel_mode == "split_gather":
|
|
||||||
input_parallel = gather_forward_reducescatter_backward(
|
input_parallel = gather_forward_reducescatter_backward(
|
||||||
input_parallel, self.process_group, self.seq_parallel_dim
|
input_parallel, self.process_group, self.seq_parallel_dim
|
||||||
)
|
)
|
||||||
|
@ -211,6 +209,8 @@ class Linear1D_Col(ParallelModule):
|
||||||
output_parallel = linear_gather_forward_reducescatter_backward(
|
output_parallel = linear_gather_forward_reducescatter_backward(
|
||||||
input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
|
input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
|
||||||
|
|
||||||
if self.gather_output:
|
if self.gather_output:
|
||||||
# All-gather across the partitions.
|
# All-gather across the partitions.
|
||||||
|
@ -416,10 +416,7 @@ class Linear1D_Row(ParallelModule):
|
||||||
handle.wait()
|
handle.wait()
|
||||||
output = torch.cat(output_parallel_list, dim=-1)
|
output = torch.cat(output_parallel_list, dim=-1)
|
||||||
else:
|
else:
|
||||||
if self.seq_parallel_mode is None:
|
if self.seq_parallel_mode == "split_gather":
|
||||||
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
|
|
||||||
output = reduce_forward(output_parallel, self.process_group)
|
|
||||||
elif self.seq_parallel_mode == "split_gather":
|
|
||||||
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
|
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
|
||||||
output = reducescatter_forward_gather_backward(
|
output = reducescatter_forward_gather_backward(
|
||||||
output_parallel, self.process_group, self.seq_parallel_dim
|
output_parallel, self.process_group, self.seq_parallel_dim
|
||||||
|
@ -432,6 +429,9 @@ class Linear1D_Row(ParallelModule):
|
||||||
dim=self.seq_parallel_dim,
|
dim=self.seq_parallel_dim,
|
||||||
ring=True,
|
ring=True,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
|
||||||
|
output = reduce_forward(output_parallel, self.process_group)
|
||||||
|
|
||||||
if not self.skip_bias_add:
|
if not self.skip_bias_add:
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
|
|
|
@ -4,10 +4,15 @@ from torch.autograd import Function
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
|
from colossalai.shardformer.layer._operation import reduce_forward
|
||||||
from colossalai.shardformer.shard import ShardConfig
|
from colossalai.shardformer.shard import ShardConfig
|
||||||
|
|
||||||
|
from .utils import is_share_sp_tp
|
||||||
|
|
||||||
__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"]
|
__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"]
|
||||||
|
|
||||||
|
_IGNORE_IDX = -100
|
||||||
|
|
||||||
|
|
||||||
class DistCrossEntropy(Function):
|
class DistCrossEntropy(Function):
|
||||||
r"""
|
r"""
|
||||||
|
@ -26,11 +31,12 @@ class DistCrossEntropy(Function):
|
||||||
process_group: ProcessGroup,
|
process_group: ProcessGroup,
|
||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
|
mode="mean",
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Calculate the cross entropy loss before gather, the origin loss function is as follows:
|
Calculate the cross entropy loss before gather, the origin loss function is as follows:
|
||||||
loss = -log(exp(x[class])/sum(exp(x[i])))
|
loss = -log(exp(x[class])/sum(exp(x[i])))
|
||||||
and can be rewrite as:
|
and can be rewritten as:
|
||||||
loss = log(sum(exp(x[i]))) - x[class]
|
loss = log(sum(exp(x[i]))) - x[class]
|
||||||
|
|
||||||
To avoid the `nan` of log(sum(exp(x[i]))), we subtract the max of x[i]
|
To avoid the `nan` of log(sum(exp(x[i]))), we subtract the max of x[i]
|
||||||
|
@ -44,12 +50,10 @@ class DistCrossEntropy(Function):
|
||||||
Returns:
|
Returns:
|
||||||
:class:`torch.Tensor`: The cross entropy loss
|
:class:`torch.Tensor`: The cross entropy loss
|
||||||
"""
|
"""
|
||||||
|
assert mode in ["mean", "sum"]
|
||||||
# get the max
|
# get the max
|
||||||
logits_max = torch.max(vocab_logits, dim=-1)[0]
|
logits_max = torch.max(vocab_logits, dim=-1)[0]
|
||||||
dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group)
|
handle = dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group, async_op=True)
|
||||||
|
|
||||||
# minus the max to avoid the result of sum of exp is too large and the log is nan
|
|
||||||
vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
|
|
||||||
|
|
||||||
# mask the target in the local device
|
# mask the target in the local device
|
||||||
rank = dist.get_rank(group=process_group)
|
rank = dist.get_rank(group=process_group)
|
||||||
|
@ -70,24 +74,25 @@ class DistCrossEntropy(Function):
|
||||||
mask = (target < down_threshold) | (target >= up_threshold)
|
mask = (target < down_threshold) | (target >= up_threshold)
|
||||||
masked_target = target.clone() - down_threshold
|
masked_target = target.clone() - down_threshold
|
||||||
masked_target[mask] = 0
|
masked_target[mask] = 0
|
||||||
|
masked_target_1d = masked_target.view(-1).contiguous()
|
||||||
|
|
||||||
|
# subtract the max so that the sum of exp does not overflow and the log does not become nan
|
||||||
|
handle.wait()
|
||||||
|
vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
|
||||||
# reshape the logits and target
|
# reshape the logits and target
|
||||||
# reshape the vocab_logits to [batch_size * seq_len, vocab_size]
|
# reshape the vocab_logits to [batch_size * seq_len, vocab_size]
|
||||||
# reshape the labels to [batch_size * seq_len]
|
# reshape the labels to [batch_size * seq_len]
|
||||||
self_vocab_size = vocab_logits.size()[-1]
|
self_vocab_size = vocab_logits.size()[-1]
|
||||||
logits_2d = vocab_logits.view(-1, self_vocab_size)
|
logits_2d = vocab_logits.view(-1, self_vocab_size)
|
||||||
masked_target_1d = masked_target.view(-1)
|
|
||||||
|
|
||||||
# extract the x[class] and set the x[other device] to zero
|
# extract the x[class] and set the x[other device] to zero
|
||||||
pred_logits_1d = logits_2d[
|
idx = torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device)
|
||||||
torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device), masked_target_1d
|
pred_logits_1d = logits_2d[idx, masked_target_1d].contiguous()
|
||||||
]
|
|
||||||
pred_logits_1d = pred_logits_1d.clone().contiguous()
|
|
||||||
pred_logits = pred_logits_1d.view_as(target)
|
pred_logits = pred_logits_1d.view_as(target)
|
||||||
pred_logits[mask] = 0.0
|
pred_logits[mask] = 0.0
|
||||||
|
|
||||||
# allreduce the get all x(i,y)
|
# all-reduce to get full x[i, y]
|
||||||
dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group)
|
handle = dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group, async_op=True)
|
||||||
exp_logits = vocab_logits
|
exp_logits = vocab_logits
|
||||||
torch.exp(vocab_logits, out=exp_logits)
|
torch.exp(vocab_logits, out=exp_logits)
|
||||||
sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32)
|
sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32)
|
||||||
|
@ -95,23 +100,29 @@ class DistCrossEntropy(Function):
|
||||||
|
|
||||||
# calculate the loss
|
# calculate the loss
|
||||||
# loss = log(sum(exp(x[i]))) - x[class]
|
# loss = log(sum(exp(x[i]))) - x[class]
|
||||||
|
handle.wait()
|
||||||
loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits)
|
loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits)
|
||||||
num_non_zero = torch.sum(loss != 0.0)
|
if mode == "mean":
|
||||||
ctx.inv_num_non_zero = 1.0 / num_non_zero
|
num_non_zero = torch.sum(loss != 0.0)
|
||||||
loss = torch.sum(loss).div_(num_non_zero)
|
ctx.inv_num_non_zero = 1.0 / num_non_zero
|
||||||
|
loss = torch.sum(loss).div_(num_non_zero)
|
||||||
|
else:
|
||||||
|
loss = torch.sum(loss)
|
||||||
|
|
||||||
# calculate the softmax
|
# calculate the softmax
|
||||||
exp_logits = exp_logits.div(sum_exp_logits.unsqueeze(dim=-1)).to(dtype)
|
exp_logits = exp_logits.div(sum_exp_logits.unsqueeze(dim=-1)).to(dtype)
|
||||||
exp_logits[target == ignore_index] = 0.0
|
exp_logits[target == ignore_index] = 0.0
|
||||||
ctx.save_for_backward(exp_logits, mask, masked_target_1d)
|
ctx.save_for_backward(exp_logits, mask, masked_target_1d)
|
||||||
ctx.dtype = dtype
|
ctx.dtype = dtype
|
||||||
|
ctx.mode = mode
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
# retrieve the saved tensors
|
# retrieve the saved tensors
|
||||||
grad_output = grad_output * ctx.inv_num_non_zero
|
if ctx.mode == "mean":
|
||||||
|
grad_output = grad_output * ctx.inv_num_non_zero
|
||||||
exp_logits, mask, masked_target_1d = ctx.saved_tensors
|
exp_logits, mask, masked_target_1d = ctx.saved_tensors
|
||||||
|
|
||||||
# use exp logits as the input grad
|
# use exp logits as the input grad
|
||||||
|
@ -123,55 +134,113 @@ class DistCrossEntropy(Function):
|
||||||
grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update
|
grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update
|
||||||
|
|
||||||
grad_logits.mul_(grad_output.unsqueeze(dim=-1))
|
grad_logits.mul_(grad_output.unsqueeze(dim=-1))
|
||||||
return grad_logits, None, None, None, None, None
|
return grad_logits, None, None, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
def cross_entropy_1d(
|
def cross_entropy_1d(
|
||||||
vocab_logits: torch.Tensor,
|
vocab_logits: torch.Tensor,
|
||||||
labels: torch.Tensor,
|
labels: torch.Tensor,
|
||||||
ignore_index: int = -100,
|
ignore_index: int = _IGNORE_IDX,
|
||||||
process_group: ProcessGroup = None,
|
process_group: ProcessGroup = None,
|
||||||
vocab_size: int = None,
|
vocab_size: int = None,
|
||||||
dtype: torch.dtype = None,
|
dtype: torch.dtype = None,
|
||||||
|
mode: str = "mean",
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)
|
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype, mode)
|
||||||
|
|
||||||
|
|
||||||
def dist_cross_entropy(
|
def dist_cross_entropy(
|
||||||
labels: torch.Tensor,
|
labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor, # [B, S, Vocab_size]
|
||||||
shard_config: ShardConfig,
|
shard_config: ShardConfig,
|
||||||
out_features: int,
|
out_features: int,
|
||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
seq_dim: int = 1,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Helper to compute cross entropy loss for most shardformer models,
|
Helper to compute cross entropy loss for most shardformer models supporting PP, TP and SP.
|
||||||
compatible with PP, TP and SP.
|
|
||||||
"""
|
"""
|
||||||
if labels is not None:
|
# Split labels if not gather output
|
||||||
# Shift so that tokens < n predict n
|
sp_group = shard_config.sequence_parallel_process_group
|
||||||
shift_logits = logits[..., :-1, :].contiguous()
|
sp_rank = dist.get_rank(sp_group)
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
sp_size = shard_config.sequence_parallel_size
|
||||||
# Flatten the tokens
|
sp_mode = shard_config.sequence_parallelism_mode
|
||||||
loss_fct = CrossEntropyLoss()
|
parallel_output = shard_config.parallel_output
|
||||||
shift_labels = shift_labels.view(-1)
|
is_tp = shard_config.enable_tensor_parallelism
|
||||||
shift_labels = shift_labels.to(shift_logits.device)
|
is_packed = labels.dim() == 2
|
||||||
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
|
if is_packed:
|
||||||
# Cross entropy with all-reduce for TP
|
bs, seq_len = labels.shape
|
||||||
new_vocab_size = logits.shape[-1]
|
else:
|
||||||
shift_logits = shift_logits.view(-1, new_vocab_size)
|
# padded sequence
|
||||||
loss = cross_entropy_1d(
|
seq_len = labels.shape[-1]
|
||||||
shift_logits,
|
logits = logits.reshape(-1, *logits.shape[2:])
|
||||||
shift_labels,
|
seq_dim = 0
|
||||||
process_group=shard_config.tensor_parallel_process_group,
|
|
||||||
vocab_size=out_features,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# NOTE if use TP and not parallel_output, the output is gathered.
|
|
||||||
# see VocabParallelLMHead1D
|
|
||||||
shift_logits = shift_logits.view(-1, vocab_size)
|
|
||||||
loss = loss_fct(shift_logits, shift_labels)
|
|
||||||
|
|
||||||
return loss
|
# Shift labels to predict the next token, and remove the tail logit predicting <EOS>
|
||||||
|
is_sp = sp_size > 1 and (not is_share_sp_tp(sp_mode))
|
||||||
|
split_labels_here = seq_len // sp_size == logits.size(seq_dim) # ring attn splits labels before forward
|
||||||
|
|
||||||
|
if sp_mode == "ring_attn":
|
||||||
|
# For Zigzag Ring Attention, labels should've been split and
|
||||||
|
# shifted by RingAttention.prepare_varlen_batch()
|
||||||
|
if sp_rank == 0:
|
||||||
|
logits = logits[..., :-1, :]
|
||||||
|
logits = torch.cat([logits, torch.full_like(logits[:, :1, :], _IGNORE_IDX)], dim=seq_dim)
|
||||||
|
elif is_sp:
|
||||||
|
# Shift only once: either before splitting or in the last rank without splitting
|
||||||
|
if split_labels_here or (sp_rank == sp_size - 1):
|
||||||
|
labels = labels[..., 1:]
|
||||||
|
if split_labels_here:
|
||||||
|
labels = labels.split(seq_len // sp_size, dim=-1)[sp_rank]
|
||||||
|
|
||||||
|
if sp_rank == sp_size - 1:
|
||||||
|
logits = logits[..., :-1, :]
|
||||||
|
# Pad logits and labels to the same shape across all ranks for TP all_reduce
|
||||||
|
if is_tp and parallel_output:
|
||||||
|
# If is packed sequence (label dim is 1), then each seq already has the end label token padded.
|
||||||
|
# torch.cat is faster than F.pad...
|
||||||
|
pad_shape = (logits.shape[0], 1, *logits.shape[2:]) if is_packed else (1, *logits.shape[1:])
|
||||||
|
padding = torch.full(pad_shape, _IGNORE_IDX, dtype=logits.dtype, device=logits.device)
|
||||||
|
logits = torch.cat([logits, padding], dim=seq_dim)
|
||||||
|
pad_shape = (labels.shape[0], 1) if is_packed else (1,)
|
||||||
|
padding = torch.full(pad_shape, _IGNORE_IDX, dtype=labels.dtype, device=labels.device)
|
||||||
|
labels = torch.cat([labels, padding], dim=seq_dim)
|
||||||
|
else:
|
||||||
|
labels = labels[..., 1:]
|
||||||
|
logits = logits[..., :-1, :]
|
||||||
|
labels = labels.contiguous()
|
||||||
|
logits = logits.contiguous()
|
||||||
|
num_nonzero = (labels != _IGNORE_IDX).sum()
|
||||||
|
assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}"
|
||||||
|
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = CrossEntropyLoss(ignore_index=_IGNORE_IDX, reduction="sum")
|
||||||
|
labels = labels.view(-1)
|
||||||
|
|
||||||
|
if is_tp and parallel_output:
|
||||||
|
# Cross entropy with all-reduce for TP
|
||||||
|
new_vocab_size = logits.shape[-1]
|
||||||
|
logits = logits.view(-1, new_vocab_size)
|
||||||
|
loss = cross_entropy_1d(
|
||||||
|
logits,
|
||||||
|
labels,
|
||||||
|
process_group=shard_config.tensor_parallel_process_group,
|
||||||
|
vocab_size=out_features,
|
||||||
|
dtype=dtype,
|
||||||
|
mode="sum",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# NOTE if use TP and not parallel_output, the output is gathered in VocabParallelLMHead1D
|
||||||
|
logits = logits.view(-1, vocab_size)
|
||||||
|
loss = loss_fct(logits, labels)
|
||||||
|
|
||||||
|
# Reduce the loss instead of gathering logits over the seq dim, to save memory and communication
|
||||||
|
if split_labels_here or sp_mode == "ring_attn":
|
||||||
|
# Get the global non-zero count
|
||||||
|
loss = torch.stack((loss, num_nonzero))
|
||||||
|
# Rescale to offset the grad / (DP * SP) in HybridParallelPlugin
|
||||||
|
loss = reduce_forward(loss, sp_group, grad_scale=sp_size)
|
||||||
|
loss, num_nonzero = loss[0], loss[1].detach()
|
||||||
|
loss = (loss / num_nonzero).squeeze()
|
||||||
|
return loss
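# --- Editorial sketch (illustrative, not part of this patch) ---
# Why summing local losses and dividing by the global non-ignored token count
# (the "sum" mode plus reduce_forward above) matches a plain mean loss over the
# full sequence. Toy single-process check, with 0.0 standing in for ignored tokens:
import torch

losses = torch.tensor([0.5, 1.5, 0.0, 2.0])
shards = losses.chunk(2)                      # pretend each SP rank holds one shard
total = sum(s.sum() for s in shards)          # what all_reduce(SUM) of local sums gives
count = sum((s != 0).sum() for s in shards)   # what all_reduce(SUM) of local counts gives
assert torch.isclose(total / count, losses[losses != 0].mean())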
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import List
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
@ -289,3 +289,199 @@ def create_randomizer_with_offset(
|
||||||
Randomizer.increment_index()
|
Randomizer.increment_index()
|
||||||
|
|
||||||
return Randomizer(seed=base_seed)
|
return Randomizer(seed=base_seed)
|
||||||
|
|
||||||
|
|
||||||
|
def split_batch_zigzag(
|
||||||
|
batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim: int = 1, is_label: bool = False
|
||||||
|
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Split the input along the sequence dimension for Ring Attention. Naively splitting the attention mask
|
||||||
|
in the causal setting will result in the preceding ranks having much less workload.
|
||||||
|
We split after "folding" the 2D attention mask in half (https://github.com/zhuzilin/ring-flash-attention/issues/2).
|
||||||
|
For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch (List[torch.Tensor] or Tensor): The input tensor(s) to split.
|
||||||
|
sp_group (ProcessGroup): The process group for sequence parallelism.
|
||||||
|
seq_dim (int): The sequence dimension to split.
|
||||||
|
is_label (bool): If True, mask and shift the tensor for next token prediction.
|
||||||
|
|
||||||
|
"""
|
||||||
|
sp_size = dist.get_world_size(sp_group)
|
||||||
|
sp_rank = dist.get_rank(sp_group)
|
||||||
|
if isinstance(batch, torch.Tensor):
|
||||||
|
batch = [batch]
|
||||||
|
seq_dim = seq_dim if seq_dim != -1 else batch[0].dim() - 1
|
||||||
|
|
||||||
|
if sp_size > 1:
|
||||||
|
for idx, tensor in enumerate(batch):
|
||||||
|
assert (
|
||||||
|
tensor.shape[seq_dim] // (sp_size * 2) > 1 and tensor.shape[seq_dim] % (sp_size * 2) == 0
|
||||||
|
), f"Bro, the seq length {tensor.shape[seq_dim]} for tensor {idx} can't be split by {sp_size * 2}!"
|
||||||
|
if is_label:
|
||||||
|
assert tensor.dim() == 2, "Label shape should be (B, Seqlen)"
|
||||||
|
tensor = torch.cat([tensor[:, 1:], torch.full_like(tensor[:, :1], -100)], dim=1)
|
||||||
|
|
||||||
|
tensor = tensor.view(
|
||||||
|
*tensor.shape[:seq_dim],
|
||||||
|
2 * sp_size,
|
||||||
|
tensor.shape[seq_dim] // (2 * sp_size),
|
||||||
|
*tensor.shape[seq_dim + 1 :],
|
||||||
|
)
|
||||||
|
indices = torch.tensor([sp_rank, 2 * sp_size - 1 - sp_rank], device=tensor.device)
|
||||||
|
tensor = tensor.index_select(seq_dim, indices).contiguous()
|
||||||
|
# (B, 2, Sq // (2 * sp_size), ...) -> (B, Sq // sp_size, ...)
|
||||||
|
batch[idx] = tensor.view(*tensor.shape[:seq_dim], -1, *tensor.shape[seq_dim + 2 :])
|
||||||
|
|
||||||
|
if len(batch) == 1:
|
||||||
|
return batch[0]
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
def split_varlen_zigzag(
|
||||||
|
batch: Union[List[torch.Tensor], torch.Tensor],
|
||||||
|
cu_seqlens: torch.Tensor,
|
||||||
|
sp_group: ProcessGroup,
|
||||||
|
max_seqlen: int = 0,
|
||||||
|
is_2d: bool = False,
|
||||||
|
is_label: bool = False,
|
||||||
|
) -> Union[List[torch.Tensor], torch.Tensor]:
|
||||||
|
"""Split each sequence in a batch of packed sequences in a zigzag fashion.
|
||||||
|
For each tensor in batch, return packed sequences if is_2d is False;
|
||||||
|
else return a padded batch of sequences.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch (List[torch.Tensor]): Packed sequences of shape (B * Sq, ...), or (B, Sq, ...) if is_2d.
|
||||||
|
cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape (B + 1) before splitting.
|
||||||
|
sp_group (ProcessGroup): The process group for sequence parallelism.
|
||||||
|
max_seqlen (int): The maximum sequence length in the batch before splitting.
|
||||||
|
is_2d (bool): If True, then input has batch size and sequence length split into two dimensions.
|
||||||
|
is_label (bool): If True, mask out the first token in each sequence (<Start of Sentence>).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
batch (List[torch.Tensor]): Packed sequences of shape (B * max_seqlen // sp_size)
|
||||||
|
or (B, max_seqlen // sp_size, ...) if is_2d
|
||||||
|
"""
|
||||||
|
sp_size = dist.get_world_size(sp_group)
|
||||||
|
sp_rank = dist.get_rank(sp_group)
|
||||||
|
if is_2d:
|
||||||
|
assert max_seqlen > 0, "max_seqlen must be provided for 2D input"
|
||||||
|
|
||||||
|
if isinstance(batch, torch.Tensor):
|
||||||
|
batch = [batch]
|
||||||
|
for i, packed_seq in enumerate(batch):
|
||||||
|
device = packed_seq.device
|
||||||
|
dtype = packed_seq.dtype
|
||||||
|
|
||||||
|
if is_2d:
|
||||||
|
assert max_seqlen % (sp_size * 2) == 0
|
||||||
|
# Recreate a padded tensor with the new max seqlen
|
||||||
|
shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:])
|
||||||
|
local_seq = torch.zeros(shape, dtype=dtype, device=device)
|
||||||
|
else:
|
||||||
|
total_seqlen = cu_seqlens[-1]
|
||||||
|
assert (
|
||||||
|
total_seqlen % (2 * sp_size) == 0
|
||||||
|
), f"total_seqlen {total_seqlen} must be divisible by 2 * sp_size = {2 * sp_size}"
|
||||||
|
local_seq = []
|
||||||
|
|
||||||
|
for j in range(len(cu_seqlens) - 1):
|
||||||
|
start, end = cu_seqlens[j], cu_seqlens[j + 1]
|
||||||
|
seqlen = end - start
|
||||||
|
assert (
|
||||||
|
seqlen % (2 * sp_size) == 0
|
||||||
|
), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting"
|
||||||
|
|
||||||
|
if is_2d:
|
||||||
|
seq = packed_seq[j][:seqlen]
|
||||||
|
if is_label:
|
||||||
|
# Shift one position to the right for next token prediction
|
||||||
|
seq = torch.cat([seq[1:], torch.tensor([-100], dtype=dtype, device=device)])
|
||||||
|
|
||||||
|
seq = seq.chunk(2 * sp_size, dim=0)
|
||||||
|
half = seqlen // sp_size // 2
|
||||||
|
local_seq[j][:half] = seq[sp_rank]
|
||||||
|
local_seq[j][half : seqlen // sp_size] = seq[2 * sp_size - 1 - sp_rank]
|
||||||
|
else:
|
||||||
|
seq = packed_seq[start:end]
|
||||||
|
if is_label:
|
||||||
|
seq = torch.cat([seq[1:], torch.tensor([-100], dtype=dtype, device=device)])
|
||||||
|
seq = seq.chunk(sp_size * 2)
|
||||||
|
local_seq.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]])
|
||||||
|
|
||||||
|
if is_2d:
|
||||||
|
batch[i] = local_seq.contiguous()
|
||||||
|
else:
|
||||||
|
batch[i] = torch.cat(local_seq, dim=0)
|
||||||
|
|
||||||
|
if len(batch) == 1:
|
||||||
|
batch = batch[0]
|
||||||
|
return batch
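# --- Editorial sketch (illustrative, not part of this patch) ---
# Single-process illustration of the packed (is_2d=False) branch above: each sequence
# is chunked into 2 * sp_size pieces and rank r keeps pieces r and 2 * sp_size - 1 - r,
# concatenated back into a packed layout.
import torch

sp_size = 2
cu_seqlens = torch.tensor([0, 4, 8])
packed = torch.arange(8)  # two packed sequences: [0..3] and [4..7]
for sp_rank in range(sp_size):
    local = []
    for j in range(len(cu_seqlens) - 1):
        seq = packed[cu_seqlens[j] : cu_seqlens[j + 1]].chunk(2 * sp_size)
        local.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]])
    print(sp_rank, torch.cat(local).tolist())
# rank 0 -> [0, 3, 4, 7]; rank 1 -> [1, 2, 5, 6]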
|
||||||
|
|
||||||
|
|
||||||
|
def is_share_sp_tp(sp_mode: str):
|
||||||
|
"""sp_mode "ring" and "split_gather" use the TP group as SP group
|
||||||
|
to split both the vocab and sequence, so we must gather the sequence
|
||||||
|
to correctly compute logits at each position.
|
||||||
|
"""
|
||||||
|
return sp_mode in ["ring", "split_gather"]
|
||||||
|
|
||||||
|
|
||||||
|
class RingComm:
|
||||||
|
def __init__(self, process_group: dist.ProcessGroup):
|
||||||
|
self._process_group = process_group
|
||||||
|
self._ops = []
|
||||||
|
self.rank = dist.get_rank(self._process_group)
|
||||||
|
self.world_size = dist.get_world_size(self._process_group)
|
||||||
|
self._reqs = []
|
||||||
|
|
||||||
|
self.send_rank = (self.rank + 1) % self.world_size
|
||||||
|
self.recv_rank = (self.rank - 1) % self.world_size
|
||||||
|
|
||||||
|
self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
|
||||||
|
self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)
|
||||||
|
|
||||||
|
def send_recv(
|
||||||
|
self,
|
||||||
|
send_tensor: torch.Tensor,
|
||||||
|
recv_tensor: Optional[torch.Tensor] = None,
|
||||||
|
commit: bool = True,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if recv_tensor is None:
|
||||||
|
res = torch.empty_like(send_tensor)
|
||||||
|
else:
|
||||||
|
res = recv_tensor
|
||||||
|
|
||||||
|
# Empirically, batch_isend_irecv does not deadlock even when the
|
||||||
|
# send/recv ops are not reordered based on rank
|
||||||
|
send_op = dist.P2POp(dist.isend, send_tensor, self.send_rank, group=self._process_group)
|
||||||
|
recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
|
||||||
|
self._ops.extend([send_op, recv_op])
|
||||||
|
|
||||||
|
if commit:
|
||||||
|
self._reqs = dist.batch_isend_irecv(self._ops)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def commit(self):
|
||||||
|
assert len(self._ops) > 0, "No ops to commit"
|
||||||
|
self._reqs = dist.batch_isend_irecv(self._ops)
|
||||||
|
|
||||||
|
def wait(self):
|
||||||
|
assert len(self._reqs) > 0, "No requests to wait for"
|
||||||
|
for req in self._reqs:
|
||||||
|
req.wait()
|
||||||
|
self._reqs = []
|
||||||
|
self._ops = []
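# --- Editorial sketch (illustrative, not part of this patch) ---
# Typical double-buffered use of RingComm, mirroring the ring loops above. The
# attend callback and sp_group are placeholders; torch.distributed must already
# be initialized for this to run.
import torch
import torch.distributed as dist

def ring_pass_kv(k: torch.Tensor, v: torch.Tensor, sp_group: dist.ProcessGroup, attend):
    comm = RingComm(sp_group)
    kv = torch.stack((k, v))
    kv_buffers = [kv, torch.empty_like(kv)]
    outputs = []
    for i in range(comm.world_size):
        if i > 0:
            comm.wait()  # this step's KV (and last step's send) have completed
        if i < comm.world_size - 1:
            # Post the next exchange before computing, to overlap comm with attention
            comm.send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
        outputs.append(attend(kv_buffers[i % 2][0], kv_buffers[i % 2][1]))
    return outputs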
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def get_half_index(cu_seqlens, *, front: bool):
|
||||||
|
index = torch.zeros(cu_seqlens[-1], dtype=torch.bool, device=cu_seqlens.device)
|
||||||
|
for i in range(len(cu_seqlens) - 1):
|
||||||
|
start, end = cu_seqlens[i], cu_seqlens[i + 1]
|
||||||
|
if front:
|
||||||
|
end = (start + end) // 2
|
||||||
|
else:
|
||||||
|
start = (start + end) // 2
|
||||||
|
index[start:end] = True
|
||||||
|
return index
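# --- Editorial sketch (illustrative, not part of this patch) ---
# Example: for two packed sequences of length 4 each (cu_seqlens = [0, 4, 8]),
# the "front" mask selects the first half of every sequence and the "back" mask
# the second half.
import torch

cu = torch.tensor([0, 4, 8])
print(get_half_index(cu, front=True).int().tolist())   # [1, 1, 0, 0, 1, 1, 0, 0]
print(get_half_index(cu, front=False).int().tolist())  # [0, 0, 1, 1, 0, 0, 1, 1]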
|
||||||
|
|
|
@ -26,6 +26,8 @@ from colossalai.shardformer.shard import ShardConfig
|
||||||
|
|
||||||
from ..layer import ColoAttention, dist_cross_entropy
|
from ..layer import ColoAttention, dist_cross_entropy
|
||||||
|
|
||||||
|
_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
|
||||||
|
|
||||||
|
|
||||||
class CommandPipelineForwards:
|
class CommandPipelineForwards:
|
||||||
"""
|
"""
|
||||||
|
@ -349,7 +351,7 @@ class CommandPipelineForwards:
|
||||||
return {"hidden_states": hidden_states}
|
return {"hidden_states": hidden_states}
|
||||||
|
|
||||||
|
|
||||||
def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
|
def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
@ -362,7 +364,7 @@ def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
||||||
if sp_mode is not None:
|
if sp_mode is not None:
|
||||||
assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
|
assert sp_mode in _SUPPORTED_SP_MODE, f"SP mode {sp_mode} is not supported by {type(self)} yet"
|
||||||
assert (sp_size is not None) and (
|
assert (sp_size is not None) and (
|
||||||
sp_group is not None
|
sp_group is not None
|
||||||
), "Must specify sp_size and sp_group for sequence parallel"
|
), "Must specify sp_size and sp_group for sequence parallel"
|
||||||
|
@ -459,7 +461,7 @@ def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
|
||||||
return forward
|
return forward
|
||||||
|
|
||||||
|
|
||||||
def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
|
def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
@ -24,14 +25,14 @@ from transformers.models.llama.modeling_llama import (
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
|
|
||||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
from colossalai.shardformer.layer._operation import (
|
from colossalai.shardformer.layer import AttnMaskType
|
||||||
all_to_all_comm,
|
from colossalai.shardformer.layer._operation import all_to_all_comm, gather_sp_output, split_forward_gather_backward
|
||||||
gather_forward_split_backward,
|
from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
|
||||||
split_forward_gather_backward,
|
|
||||||
)
|
|
||||||
from colossalai.shardformer.shard import ShardConfig
|
from colossalai.shardformer.shard import ShardConfig
|
||||||
|
|
||||||
from ..layer import ColoAttention, dist_cross_entropy
|
from ..layer import ColoAttention, RingAttention, dist_cross_entropy
|
||||||
|
|
||||||
|
_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
|
||||||
|
|
||||||
|
|
||||||
class LlamaPipelineForwards:
|
class LlamaPipelineForwards:
|
||||||
|
@ -57,6 +58,10 @@ class LlamaPipelineForwards:
|
||||||
hidden_states: Optional[torch.FloatTensor] = None,
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
stage_index: Optional[List[int]] = None,
|
stage_index: Optional[List[int]] = None,
|
||||||
shard_config: ShardConfig = None,
|
shard_config: ShardConfig = None,
|
||||||
|
# Split output only when computing cross entropy using llama_for_causal_lm_forward
|
||||||
|
# or get_lm_forward_with_dist_cross_entropy
|
||||||
|
# Default to True to avoid bug when calling classification forward from huggingface
|
||||||
|
force_sp_output_gather: bool = True,
|
||||||
):
|
):
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
@ -97,7 +102,7 @@ class LlamaPipelineForwards:
|
||||||
sp_group = shard_config.sequence_parallel_process_group
|
sp_group = shard_config.sequence_parallel_process_group
|
||||||
sp_size = shard_config.sequence_parallel_size
|
sp_size = shard_config.sequence_parallel_size
|
||||||
if sp_mode == "all_to_all" and not stage_manager.is_first_stage():
|
if sp_mode == "all_to_all" and not stage_manager.is_first_stage():
|
||||||
# For correct positions ids. The states will be gather along the seq dim in the attention layer later.
|
# For generating full positions ids, as the states will be gather along the seq dim in the attention layer later.
|
||||||
seq_length *= sp_size
|
seq_length *= sp_size
|
||||||
|
|
||||||
past_seen_tokens = 0
|
past_seen_tokens = 0
|
||||||
|
@ -127,22 +132,36 @@ class LlamaPipelineForwards:
|
||||||
position_ids = cache_position.unsqueeze(0)
|
position_ids = cache_position.unsqueeze(0)
|
||||||
# embed positions, for the first stage, hidden_states is the input embeddings,
|
# embed positions, for the first stage, hidden_states is the input embeddings,
|
||||||
# for the other stages, hidden_states is the output of the previous stage
|
# for the other stages, hidden_states is the output of the previous stage
|
||||||
if shard_config.enable_flash_attention:
|
if not stage_manager.is_first_stage() and sp_mode == "ring_attn":
|
||||||
|
_, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group)
|
||||||
|
elif shard_config.enable_flash_attention:
|
||||||
# in this case, attention_mask is a dict rather than a tensor
|
# in this case, attention_mask is a dict rather than a tensor
|
||||||
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
|
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
|
||||||
attention_mask = ColoAttention.prepare_attn_kwargs(
|
attn_kwargs = ColoAttention.prepare_attn_kwargs(
|
||||||
mask_shape,
|
mask_shape,
|
||||||
hidden_states.dtype,
|
hidden_states.dtype,
|
||||||
hidden_states.device,
|
hidden_states.device,
|
||||||
q_padding_mask=attention_mask,
|
q_padding_mask=attention_mask,
|
||||||
is_causal=True,
|
is_causal=True,
|
||||||
|
invert=(sp_mode != "ring_attn"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position)
|
attn_kwargs = self._update_causal_mask(attention_mask, hidden_states, cache_position)
|
||||||
|
|
||||||
# Support SP + PP
|
# Support SP + PP
|
||||||
|
# TODO: support padded casual cu_seqlens across stages
|
||||||
if stage_manager.is_first_stage():
|
if stage_manager.is_first_stage():
|
||||||
if sp_mode in ["ring", "split_gather"]:
|
# Ring Attention zigzag batch processing
|
||||||
|
if sp_mode == "ring_attn":
|
||||||
|
assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
|
||||||
|
if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
|
||||||
|
hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
|
||||||
|
attention_mask, sp_group, hidden_states, position_ids
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group)
|
||||||
|
|
||||||
|
elif is_share_sp_tp(sp_mode):
|
||||||
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
|
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
|
||||||
elif sp_mode == "all_to_all":
|
elif sp_mode == "all_to_all":
|
||||||
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
|
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
|
||||||
|
@ -177,12 +196,11 @@ class LlamaPipelineForwards:
|
||||||
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
|
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if idx - start_idx < num_ckpt_layers:
|
if idx - start_idx < num_ckpt_layers:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
decoder_layer.__call__,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attn_kwargs,
|
||||||
position_ids,
|
position_ids,
|
||||||
past_key_values,
|
past_key_values,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
|
@ -192,14 +210,13 @@ class LlamaPipelineForwards:
|
||||||
else:
|
else:
|
||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attn_kwargs,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
past_key_value=past_key_values,
|
past_key_value=past_key_values,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
|
@ -209,10 +226,8 @@ class LlamaPipelineForwards:
|
||||||
|
|
||||||
if stage_manager.is_last_stage():
|
if stage_manager.is_last_stage():
|
||||||
hidden_states = self.norm(hidden_states)
|
hidden_states = self.norm(hidden_states)
|
||||||
if sp_mode == "ring" or sp_mode == "split_gather":
|
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
|
||||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
|
hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode)
|
||||||
elif sp_mode == "all_to_all":
|
|
||||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
|
|
||||||
|
|
||||||
# add hidden states from the last decoder layer
|
# add hidden states from the last decoder layer
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
|
@ -298,6 +313,15 @@ class LlamaPipelineForwards:
|
||||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||||
output_hidden_states = False
|
output_hidden_states = False
|
||||||
|
|
||||||
|
if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output:
|
||||||
|
# Split labels in a zigzag fashion too
|
||||||
|
sp_group = shard_config.sequence_parallel_process_group
|
||||||
|
if attention_mask.bool().all():
|
||||||
|
labels = split_batch_zigzag(labels, sp_group, seq_dim=1)
|
||||||
|
else:
|
||||||
|
# [B, max_seqlen // sp_size]
|
||||||
|
labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
outputs = LlamaPipelineForwards.llama_model_forward(
|
outputs = LlamaPipelineForwards.llama_model_forward(
|
||||||
self.model,
|
self.model,
|
||||||
|
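The hunk above splits the labels with the same zigzag layout that ring attention uses for the inputs, so each rank's logits stay aligned with its slice of the labels. A minimal single-process sketch of what a zigzag split does (the helper name zigzag_split is illustrative, not the library API):

import torch

def zigzag_split(batch: torch.Tensor, sp_size: int, rank: int, seq_dim: int = 1) -> torch.Tensor:
    # Cut the sequence into 2 * sp_size chunks and give rank i chunks i and (2 * sp_size - 1 - i),
    # so every rank holds one early and one late chunk of the causal sequence and the
    # attention work per rank is roughly balanced.
    chunks = batch.chunk(2 * sp_size, dim=seq_dim)
    return torch.cat([chunks[rank], chunks[2 * sp_size - 1 - rank]], dim=seq_dim)

# Example with 2 ranks and sequence length 8:
labels = torch.arange(8).unsqueeze(0)           # [1, 8]
print(zigzag_split(labels, sp_size=2, rank=0))  # tensor([[0, 1, 6, 7]])
print(zigzag_split(labels, sp_size=2, rank=1))  # tensor([[2, 3, 4, 5]])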
@@ -315,6 +339,7 @@ class LlamaPipelineForwards:
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
stage_index=stage_index,
|
stage_index=stage_index,
|
||||||
shard_config=shard_config,
|
shard_config=shard_config,
|
||||||
|
force_sp_output_gather=False,
|
||||||
)
|
)
|
||||||
past_key_values = None
|
past_key_values = None
|
||||||
|
|
||||||
|
@@ -457,11 +482,11 @@ class LlamaPipelineForwards:
|
||||||
return {"hidden_states": hidden_states}
|
return {"hidden_states": hidden_states}
|
||||||
|
|
||||||
|
|
||||||
def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
|
def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[Union[torch.Tensor, Dict]] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_value: Optional[Cache] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
|
@@ -470,7 +495,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
||||||
if sp_mode is not None:
|
if sp_mode is not None:
|
||||||
assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
|
assert sp_mode in _SUPPORTED_SP_MODE, f"SP mode {sp_mode} is not supported by {type(self)} yet"
|
||||||
assert (sp_size is not None) and (
|
assert (sp_size is not None) and (
|
||||||
sp_group is not None
|
sp_group is not None
|
||||||
), "Must specify sp_size and sp_group for sequence parallel"
|
), "Must specify sp_size and sp_group for sequence parallel"
|
||||||
|
@@ -481,7 +506,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
|
||||||
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
# sp: modify sp_len when sequence parallel mode is ring
|
# sp: modify sp_len when sequence parallel mode is ring
|
||||||
if sp_mode in ["split_gather", "ring"]:
|
if is_share_sp_tp(sp_mode):
|
||||||
q_len *= sp_size
|
q_len *= sp_size
|
||||||
|
|
||||||
if self.config.pretraining_tp > 1:
|
if self.config.pretraining_tp > 1:
|
||||||
|
@@ -526,6 +551,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||||
|
|
||||||
|
@@ -537,12 +563,21 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
|
||||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
if shard_config.enable_flash_attention:
|
if sp_mode == "ring_attn":
|
||||||
|
attn_output = RingAttention.attention(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
sp_group,
|
||||||
|
**attention_mask,
|
||||||
|
inner_ring_size=shard_config.inner_ring_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif shard_config.enable_flash_attention:
|
||||||
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
|
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
|
||||||
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
|
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
|
||||||
else:
|
else:
|
||||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||||
|
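For context on the dispatch above: in "ring_attn" mode each rank only ever sees one block of keys and values at a time, and the partial softmax results are merged with log-sum-exp rescaling, which is what keeps the blockwise computation exact. A toy single-process sketch of that merge (no communication, no causal mask; names are illustrative, not the RingAttention API):

import torch

def ring_attention_reference(q, k, v, num_chunks):
    # q, k, v: [B, H, S, D]. Process K/V in num_chunks blocks, as if each block arrived from
    # the next rank in the ring, and fold the partial outputs together with a running
    # log-sum-exp rescale (the same trick flash attention uses internally).
    scale = q.shape[-1] ** -0.5
    out = torch.zeros_like(q)
    lse = torch.full(q.shape[:-1], float("-inf"))                # running log-sum-exp, [B, H, S]
    for k_blk, v_blk in zip(k.chunk(num_chunks, dim=2), v.chunk(num_chunks, dim=2)):
        scores = q @ k_blk.transpose(-2, -1) * scale             # [B, H, S, blk]
        blk_lse = torch.logsumexp(scores, dim=-1)                # [B, H, S]
        blk_out = torch.softmax(scores, dim=-1) @ v_blk          # [B, H, S, D]
        new_lse = torch.logaddexp(lse, blk_lse)
        out = out * (lse - new_lse).exp().unsqueeze(-1) + blk_out * (blk_lse - new_lse).exp().unsqueeze(-1)
        lse = new_lse
    return out

q, k, v = (torch.randn(1, 2, 8, 4) for _ in range(3))
full = torch.softmax(q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5, dim=-1) @ v
assert torch.allclose(ring_attention_reference(q, k, v, num_chunks=4), full, atol=1e-5)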
@@ -588,7 +623,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
|
||||||
return forward
|
return forward
|
||||||
|
|
||||||
|
|
||||||
def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
|
def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
@@ -603,6 +638,10 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
# Split output only when computing cross entropy using llama_for_causal_lm_forward
|
||||||
|
# or get_lm_forward_with_dist_cross_entropy
|
||||||
|
# Default to True to avoid bugs when calling the classification forward from HuggingFace
|
||||||
|
force_sp_output_gather: bool = True,
|
||||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
|
@@ -629,32 +668,45 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=
|
||||||
|
|
||||||
past_seen_tokens = 0
|
past_seen_tokens = 0
|
||||||
seq_len = inputs_embeds.shape[1]
|
seq_len = inputs_embeds.shape[1]
|
||||||
|
batch_size = inputs_embeds.shape[0]
|
||||||
if use_cache: # kept for BC (cache positions)
|
if use_cache: # kept for BC (cache positions)
|
||||||
if not isinstance(past_key_values, StaticCache):
|
if not isinstance(past_key_values, StaticCache):
|
||||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
past_seen_tokens = past_key_values.get_seq_length()
|
past_seen_tokens = past_key_values.get_seq_length()
|
||||||
|
|
||||||
if cache_position is None:
|
if cache_position is None:
|
||||||
if isinstance(past_key_values, StaticCache):
|
if isinstance(past_key_values, StaticCache):
|
||||||
raise ValueError("cache_position is a required argument when using StaticCache.")
|
raise ValueError("cache_position is a required argument when using StaticCache.")
|
||||||
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device)
|
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=inputs_embeds.device)
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = cache_position.unsqueeze(0)
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
# in this case, attention_mask is a dict rather than a tensor
|
|
||||||
if shard_config.enable_flash_attention:
|
if shard_config.enable_flash_attention:
|
||||||
mask_shape = (inputs_embeds.shape[0], 1, seq_len, past_seen_tokens + seq_len)
|
mask_shape = (batch_size, 1, seq_len, past_seen_tokens + seq_len)
|
||||||
attention_mask = ColoAttention.prepare_attn_kwargs(
|
attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
|
||||||
mask_shape,
|
mask_shape,
|
||||||
inputs_embeds.dtype,
|
inputs_embeds.dtype,
|
||||||
inputs_embeds.device,
|
inputs_embeds.device,
|
||||||
q_padding_mask=attention_mask,
|
q_padding_mask=attention_mask,
|
||||||
is_causal=True,
|
is_causal=True,
|
||||||
|
invert=(sp_mode != "ring_attn"),
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
|
|
||||||
|
|
||||||
if sp_mode in ["ring", "split_gather"]:
|
else:
|
||||||
|
attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
|
||||||
|
|
||||||
|
# Ring Attention zigzag batch processing
|
||||||
|
if sp_mode == "ring_attn":
|
||||||
|
assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
|
||||||
|
if attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
|
||||||
|
inputs_embeds, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
|
||||||
|
attention_mask, sp_group, inputs_embeds, position_ids
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
inputs_embeds, position_ids = split_batch_zigzag([inputs_embeds, position_ids], sp_group)
|
||||||
|
attn_kwargs = {"attention_mask_type": attn_kwargs["attention_mask_type"]} # drop redundant tensors
|
||||||
|
|
||||||
|
elif is_share_sp_tp(sp_mode):
|
||||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
|
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
|
||||||
elif sp_mode == "all_to_all":
|
elif sp_mode == "all_to_all":
|
||||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
|
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
|
||||||
|
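On the mask handling above: once flash attention is enabled, the dense attention mask is replaced by a small kwargs dict carrying the mask type plus variable-length metadata, and for "ring_attn" the heavy tensors are dropped once the batch has been re-laid out in zigzag order. A sketch of deriving that varlen metadata from a [B, S] padding mask (field names are illustrative, not the exact ColoAttention schema):

import torch

def build_varlen_kwargs(padding_mask: torch.Tensor) -> dict:
    # padding_mask: [B, S] with 1 for real tokens and 0 for padding.
    seqlens = padding_mask.sum(dim=1, dtype=torch.int32)                       # tokens per sample
    cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
    return {
        "is_causal": True,
        "cu_seqlens": cu_seqlens,                                              # [B + 1] prefix sums used by varlen kernels
        "max_seqlen": int(seqlens.max()),
        "indices": padding_mask.flatten().nonzero().squeeze(-1),               # positions of real tokens
    }

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
print(build_varlen_kwargs(mask)["cu_seqlens"])   # tensor([0, 3, 5], dtype=torch.int32)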
@@ -672,7 +724,7 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
decoder_layer.__call__,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attn_kwargs,
|
||||||
position_ids,
|
position_ids,
|
||||||
past_key_values,
|
past_key_values,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
|
@@ -683,7 +735,7 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=
|
||||||
else:
|
else:
|
||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attn_kwargs,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
past_key_value=past_key_values,
|
past_key_value=past_key_values,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
|
@@ -700,11 +752,9 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=
|
||||||
all_self_attns += (layer_outputs[1],)
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
hidden_states = self.norm(hidden_states)
|
hidden_states = self.norm(hidden_states)
|
||||||
|
# Cases that don't support parallelizing cross entropy computation along sequence
|
||||||
if sp_mode == "ring" or sp_mode == "split_gather":
|
if (not shard_config.parallel_output) or is_share_sp_tp(sp_mode) or force_sp_output_gather:
|
||||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
|
hidden_states = gather_sp_output(hidden_states, sp_group, sp_mode)
|
||||||
elif sp_mode == "all_to_all":
|
|
||||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
|
|
||||||
|
|
||||||
# add hidden states from the last decoder layer
|
# add hidden states from the last decoder layer
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
|
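About the conditional gather above: after the final norm the hidden states are still sharded along the sequence dimension, and they are only all-gathered when the loss cannot be computed shard-locally (parallel_output disabled, a forced gather, or an SP mode that shares its group with TP). A bare-bones sketch of that gather with plain torch.distributed, assuming an initialized process group and a contiguous (non-zigzag) shard layout; the real helper also handles autograd and gradient scaling:

import torch
import torch.distributed as dist

def gather_seq_dim(hidden: torch.Tensor, group, seq_dim: int = 1) -> torch.Tensor:
    # Each rank holds [B, S / sp_size, H]; all-gather the shards and glue them back to [B, S, H].
    world_size = dist.get_world_size(group)
    shards = [torch.empty_like(hidden) for _ in range(world_size)]
    dist.all_gather(shards, hidden.contiguous(), group=group)
    return torch.cat(shards, dim=seq_dim)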
@@ -777,6 +827,15 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
||||||
)
|
)
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output:
|
||||||
|
# Special processing: Split labels in a zigzag fashion too
|
||||||
|
sp_group = shard_config.sequence_parallel_process_group
|
||||||
|
if attention_mask.bool().all():
|
||||||
|
labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True)
|
||||||
|
else:
|
||||||
|
# [B, max_seq_len // sp_size]
|
||||||
|
labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
|
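On the loss path above: with parallel_output kept on, logits and labels both stay sharded along the sequence, so each rank can only produce a partial sum of token losses and a partial token count; the final loss is the global sum over the global count. A self-contained single-process sketch of that reduction (the causal shift is omitted, and in the distributed version the two scalars would be all-reduced across the SP group):

import torch
import torch.nn.functional as F

def sharded_causal_lm_loss(logit_shards, label_shards, ignore_index=-100):
    # Each (logits, labels) pair is one rank's sequence shard: [B, S_local, V] and [B, S_local].
    total_loss, total_tokens = torch.zeros(()), torch.zeros(())
    for logits, labels in zip(logit_shards, label_shards):
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)), labels.reshape(-1),
            ignore_index=ignore_index, reduction="sum",
        )
        total_loss = total_loss + loss
        total_tokens = total_tokens + (labels != ignore_index).sum()
        # The distributed version all-reduces total_loss and total_tokens here.
    return total_loss / total_tokens

logits = torch.randn(2, 8, 11)
labels = torch.randint(0, 11, (2, 8))
ref = F.cross_entropy(logits.reshape(-1, 11), labels.reshape(-1))
sharded = sharded_causal_lm_loss(logits.chunk(2, dim=1), labels.chunk(2, dim=1))
assert torch.allclose(ref, sharded, atol=1e-6)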
@@ -789,6 +848,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
|
force_sp_output_gather=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
|
@@ -799,7 +859,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
||||||
else:
|
else:
|
||||||
logits = self.lm_head(hidden_states)
|
logits = self.lm_head(hidden_states)
|
||||||
logits = logits.float()
|
logits = logits.float()
|
||||||
|
|
||||||
loss = dist_cross_entropy(
|
loss = dist_cross_entropy(
|
||||||
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
|
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
|
||||||
)
|
)
|
||||||
|
|
|
@@ -75,6 +75,7 @@ class Policy(ABC):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.shard_config: Optional[ShardConfig] = None
|
self.shard_config: Optional[ShardConfig] = None
|
||||||
self.model: Optional[Module] = None
|
self.model: Optional[Module] = None
|
||||||
|
self.is_causal = None # Whether we're doing causal lm, i.e. using cross entropy
|
||||||
|
|
||||||
def set_model(self, model: nn.Module) -> None:
|
def set_model(self, model: nn.Module) -> None:
|
||||||
r"""
|
r"""
|
||||||
|
|
|
@@ -69,13 +69,18 @@ class CommandPolicy(Policy):
|
||||||
sp_size = self.shard_config.sequence_parallel_size or None
|
sp_size = self.shard_config.sequence_parallel_size or None
|
||||||
sp_group = self.shard_config.sequence_parallel_process_group or None
|
sp_group = self.shard_config.sequence_parallel_process_group or None
|
||||||
sp_partial_derived = sp_mode in ["split_gather", "ring"]
|
sp_partial_derived = sp_mode in ["split_gather", "ring"]
|
||||||
|
if sp_mode == "ring_attn" and not self.is_causal:
|
||||||
|
raise ValueError("Ring attention is only meant for causal language modeling.")
|
||||||
|
|
||||||
|
tp_size = self.shard_config.tensor_parallel_size or None
|
||||||
|
num_q_heads = self.model.config.num_attention_heads
|
||||||
|
num_kv_heads = getattr(self.model.config, "num_key_value_heads", None)
|
||||||
if sp_mode == "all_to_all":
|
if sp_mode == "all_to_all":
|
||||||
decoder_attribute_replacement = {
|
num_q_heads //= sp_size
|
||||||
"num_heads": self.model.config.num_attention_heads // sp_size,
|
decoder_attribute_replacement = {"num_heads": num_q_heads}
|
||||||
}
|
if num_kv_heads:
|
||||||
if getattr(self.model.config, "num_key_value_heads", False):
|
num_kv_heads //= sp_size
|
||||||
decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
|
decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads
|
||||||
|
|
||||||
policy[attn_cls] = ModulePolicyDescription(
|
policy[attn_cls] = ModulePolicyDescription(
|
||||||
attribute_replacement=decoder_attribute_replacement,
|
attribute_replacement=decoder_attribute_replacement,
|
||||||
|
@@ -104,21 +109,18 @@ class CommandPolicy(Policy):
|
||||||
|
|
||||||
if self.shard_config.enable_tensor_parallelism:
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
assert (
|
assert (
|
||||||
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
|
num_q_heads % tp_size == 0
|
||||||
), f"The number of attention heads must be divisible by tensor parallel size."
|
), f"The number of attention heads must be divisible by tensor parallel size."
|
||||||
if hasattr(self.model.config, "num_key_value_heads"):
|
if hasattr(self.model.config, "num_key_value_heads"):
|
||||||
assert (
|
assert (
|
||||||
self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size
|
num_kv_heads >= tp_size and num_kv_heads % tp_size == 0
|
||||||
and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
|
|
||||||
), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size."
|
), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size."
|
||||||
decoder_attribute_replacement = {
|
decoder_attribute_replacement = {
|
||||||
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
"self_attn.hidden_size": self.model.config.hidden_size // tp_size,
|
||||||
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
"self_attn.num_heads": num_q_heads // tp_size,
|
||||||
}
|
}
|
||||||
if getattr(self.model.config, "num_key_value_heads", False):
|
if getattr(self.model.config, "num_key_value_heads", False):
|
||||||
decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
|
decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads // tp_size
|
||||||
self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
|
|
||||||
)
|
|
||||||
|
|
||||||
policy[CohereDecoderLayer] = ModulePolicyDescription(
|
policy[CohereDecoderLayer] = ModulePolicyDescription(
|
||||||
attribute_replacement=decoder_attribute_replacement,
|
attribute_replacement=decoder_attribute_replacement,
|
||||||
|
@@ -290,10 +292,11 @@ class CommandForCausalLMPolicy(CommandPolicy):
|
||||||
def module_policy(self):
|
def module_policy(self):
|
||||||
from transformers import CohereForCausalLM
|
from transformers import CohereForCausalLM
|
||||||
|
|
||||||
|
self.is_causal = True
|
||||||
policy = super().module_policy()
|
policy = super().module_policy()
|
||||||
|
|
||||||
if self.shard_config.enable_tensor_parallelism:
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
# add a new item for casual lm
|
# add a new item for causal lm
|
||||||
new_item = {
|
new_item = {
|
||||||
CohereForCausalLM: ModulePolicyDescription(
|
CohereForCausalLM: ModulePolicyDescription(
|
||||||
sub_module_replacement=[
|
sub_module_replacement=[
|
||||||
|
|
|
@@ -298,7 +298,7 @@ class DeepseekForCausalLMPolicy(DeepseekPolicy):
|
||||||
policy = super().module_policy()
|
policy = super().module_policy()
|
||||||
# TODO: assign pg mesh from plugin to all modules
|
# TODO: assign pg mesh from plugin to all modules
|
||||||
if self.shard_config.enable_tensor_parallelism:
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
# add a new item for casual lm
|
# add a new item for causal lm
|
||||||
new_item = {
|
new_item = {
|
||||||
"DeepseekForCausalLM": ModulePolicyDescription(
|
"DeepseekForCausalLM": ModulePolicyDescription(
|
||||||
sub_module_replacement=[
|
sub_module_replacement=[
|
||||||
|
|
|
@@ -69,13 +69,20 @@ class LlamaPolicy(Policy):
|
||||||
sp_size = self.shard_config.sequence_parallel_size or None
|
sp_size = self.shard_config.sequence_parallel_size or None
|
||||||
sp_group = self.shard_config.sequence_parallel_process_group or None
|
sp_group = self.shard_config.sequence_parallel_process_group or None
|
||||||
sp_partial_derived = sp_mode in ["split_gather", "ring"]
|
sp_partial_derived = sp_mode in ["split_gather", "ring"]
|
||||||
|
if sp_mode == "ring_attn" and not self.is_causal:
|
||||||
|
raise ValueError("Ring attention is only meant for causal language modeling.")
|
||||||
|
|
||||||
|
tp_size = self.shard_config.tensor_parallel_size
|
||||||
|
# Modified by SP and TP
|
||||||
|
num_q_heads = self.model.config.num_attention_heads
|
||||||
|
num_kv_heads = getattr(self.model.config, "num_key_value_heads", None)
|
||||||
|
|
||||||
if sp_mode == "all_to_all":
|
if sp_mode == "all_to_all":
|
||||||
decoder_attribute_replacement = {
|
num_q_heads //= sp_size
|
||||||
"num_heads": self.model.config.num_attention_heads // sp_size,
|
decoder_attribute_replacement = {"num_heads": num_q_heads}
|
||||||
}
|
if num_kv_heads:
|
||||||
if getattr(self.model.config, "num_key_value_heads", False):
|
num_kv_heads //= sp_size
|
||||||
decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
|
decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads
|
||||||
|
|
||||||
policy[attn_cls] = ModulePolicyDescription(
|
policy[attn_cls] = ModulePolicyDescription(
|
||||||
attribute_replacement=decoder_attribute_replacement,
|
attribute_replacement=decoder_attribute_replacement,
|
||||||
|
@@ -104,21 +111,20 @@ class LlamaPolicy(Policy):
|
||||||
|
|
||||||
if self.shard_config.enable_tensor_parallelism:
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
assert (
|
assert (
|
||||||
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
|
num_q_heads % tp_size == 0
|
||||||
), f"The number of attention heads must be divisible by tensor parallel size."
|
), f"The number of attention heads must be divisible by tensor parallel size."
|
||||||
if hasattr(self.model.config, "num_key_value_heads"):
|
if hasattr(self.model.config, "num_key_value_heads"):
|
||||||
assert (
|
assert (
|
||||||
self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size
|
num_kv_heads >= tp_size and num_kv_heads % tp_size == 0
|
||||||
and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
|
|
||||||
), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size."
|
), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size."
|
||||||
|
num_q_heads //= tp_size
|
||||||
decoder_attribute_replacement = {
|
decoder_attribute_replacement = {
|
||||||
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
"self_attn.hidden_size": self.model.config.hidden_size // tp_size,
|
||||||
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
"self_attn.num_heads": num_q_heads,
|
||||||
}
|
}
|
||||||
if getattr(self.model.config, "num_key_value_heads", False):
|
if getattr(self.model.config, "num_key_value_heads", False):
|
||||||
decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
|
num_kv_heads //= tp_size
|
||||||
self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
|
decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads
|
||||||
)
|
|
||||||
|
|
||||||
policy[LlamaDecoderLayer] = ModulePolicyDescription(
|
policy[LlamaDecoderLayer] = ModulePolicyDescription(
|
||||||
attribute_replacement=decoder_attribute_replacement,
|
attribute_replacement=decoder_attribute_replacement,
|
||||||
|
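A quick worked example of the head bookkeeping introduced above, assuming 32 query heads and 8 KV heads with sp_size=2 (all_to_all) and tp_size=2: the all_to_all scatter first reduces 32 -> 16 query and 8 -> 4 KV heads, then tensor parallelism halves them again to 8 and 2 per device. The same arithmetic as a tiny check, with the divisibility asserts mirroring the policy's:

def heads_per_device(num_q_heads: int, num_kv_heads: int, sp_size: int = 1, tp_size: int = 1):
    # all_to_all sequence parallelism scatters along the head dim first, then TP splits again.
    for divisor in (sp_size, tp_size):
        assert num_q_heads % divisor == 0 and num_kv_heads % divisor == 0, "heads must divide evenly"
        num_q_heads //= divisor
        num_kv_heads //= divisor
    return num_q_heads, num_kv_heads

print(heads_per_device(32, 8, sp_size=2, tp_size=2))  # (8, 2)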
@@ -295,10 +301,11 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
|
||||||
def module_policy(self):
|
def module_policy(self):
|
||||||
from transformers import LlamaForCausalLM
|
from transformers import LlamaForCausalLM
|
||||||
|
|
||||||
|
self.is_causal = True
|
||||||
policy = super().module_policy()
|
policy = super().module_policy()
|
||||||
|
|
||||||
if self.shard_config.enable_tensor_parallelism:
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
# add a new item for casual lm
|
# add a new item for causal lm
|
||||||
new_item = {
|
new_item = {
|
||||||
LlamaForCausalLM: ModulePolicyDescription(
|
LlamaForCausalLM: ModulePolicyDescription(
|
||||||
sub_module_replacement=[
|
sub_module_replacement=[
|
||||||
|
@@ -313,10 +320,6 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
if self.shard_config.parallel_output:
|
|
||||||
new_item[LlamaForCausalLM].method_replacement = {
|
|
||||||
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
new_item = {
|
new_item = {
|
||||||
LlamaForCausalLM: ModulePolicyDescription(
|
LlamaForCausalLM: ModulePolicyDescription(
|
||||||
|
@@ -336,7 +339,11 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
|
||||||
self.set_pipeline_forward(
|
self.set_pipeline_forward(
|
||||||
model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy
|
model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy
|
||||||
)
|
)
|
||||||
|
elif self.shard_config.enable_tensor_parallelism or self.shard_config.enable_sequence_parallelism:
|
||||||
|
# Compute loss distributedly along the sequence dimension
|
||||||
|
new_item[LlamaForCausalLM].method_replacement = {
|
||||||
|
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
|
||||||
|
}
|
||||||
return policy
|
return policy
|
||||||
|
|
||||||
def get_held_layers(self) -> List[Module]:
|
def get_held_layers(self) -> List[Module]:
|
||||||
|
|
|
@@ -271,7 +271,7 @@ class MistralForCausalLMPolicy(MistralPolicy):
|
||||||
policy = super().module_policy()
|
policy = super().module_policy()
|
||||||
|
|
||||||
if self.shard_config.enable_tensor_parallelism:
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
# add a new item for casual lm
|
# add a new item for causal lm
|
||||||
new_item = {
|
new_item = {
|
||||||
MistralForCausalLM: ModulePolicyDescription(
|
MistralForCausalLM: ModulePolicyDescription(
|
||||||
sub_module_replacement=[
|
sub_module_replacement=[
|
||||||
|
|
|
@@ -275,7 +275,7 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
|
||||||
policy = super().module_policy()
|
policy = super().module_policy()
|
||||||
# TODO: assign pg mesh from plugin to all modules
|
# TODO: assign pg mesh from plugin to all modules
|
||||||
if self.shard_config.enable_tensor_parallelism:
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
# add a new item for casual lm
|
# add a new item for causal lm
|
||||||
new_item = {
|
new_item = {
|
||||||
MixtralForCausalLM: ModulePolicyDescription(
|
MixtralForCausalLM: ModulePolicyDescription(
|
||||||
sub_module_replacement=[
|
sub_module_replacement=[
|
||||||
|
|
|
@@ -313,7 +313,7 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
|
||||||
setattr(self.shard_config, "causal_lm", True)
|
setattr(self.shard_config, "causal_lm", True)
|
||||||
|
|
||||||
if self.shard_config.enable_tensor_parallelism:
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
# add a new item for casual lm
|
# add a new item for causal lm
|
||||||
new_item = {
|
new_item = {
|
||||||
Qwen2ForCausalLM: ModulePolicyDescription(
|
Qwen2ForCausalLM: ModulePolicyDescription(
|
||||||
sub_module_replacement=[
|
sub_module_replacement=[
|
||||||
|
|
|
@@ -10,7 +10,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
from .grad_ckpt_config import GradientCheckpointConfig
|
from .grad_ckpt_config import GradientCheckpointConfig
|
||||||
|
|
||||||
__all__ = ["ShardConfig"]
|
__all__ = ["ShardConfig"]
|
||||||
SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
|
SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all", "ring_attn"]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@@ -29,6 +29,8 @@ class ShardConfig:
|
||||||
enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
|
enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
|
||||||
gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
|
gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
|
||||||
enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
|
enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
|
||||||
|
parallel_output (bool): For TP: whether to parallelize the cross entropy computation along the feature dim.
|
||||||
|
For SP: set to True to NOT gather the output along the seq dim.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tensor_parallel_process_group: Optional[ProcessGroup] = None
|
tensor_parallel_process_group: Optional[ProcessGroup] = None
|
||||||
|
@@ -47,10 +49,11 @@ class ShardConfig:
|
||||||
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
|
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
|
||||||
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
|
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# For ring attention
|
||||||
|
inner_ring_size: Optional[int] = None
|
||||||
# for moe related
|
# for moe related
|
||||||
moe_dp_group: Optional[ProcessGroup] = None
|
moe_dp_group: Optional[ProcessGroup] = None
|
||||||
ep_group: Optional[ProcessGroup] = None
|
ep_group: Optional[ProcessGroup] = None
|
||||||
|
|
||||||
# pipeline_parallel_size: int
|
# pipeline_parallel_size: int
|
||||||
# data_parallel_size: int
|
# data_parallel_size: int
|
||||||
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
|
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
|
||||||
|
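For orientation, the new inner_ring_size field and the relaxed all_to_all check above are what a ring-attention run ends up exercising through the plugin. A hedged configuration sketch mirroring the benchmark changes later in this diff (kwargs should be checked against your plugin version; inner_ring_size itself lives on ShardConfig and may not be exposed by every plugin):

# Run under torchrun / colossalai run so the distributed environment is set up.
import colossalai
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    sp_size=4,                              # sequence-parallel world size
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="ring_attn",  # the new mode added to SUPPORT_SP_MODE
    enable_flash_attention=True,            # ring_attn requires the flash-attention path
    precision="bf16",
)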
@@ -80,9 +83,9 @@ class ShardConfig:
|
||||||
self.enable_tensor_parallelism
|
self.enable_tensor_parallelism
|
||||||
), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True"
|
), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True"
|
||||||
elif self.sequence_parallelism_mode in ["all_to_all"]:
|
elif self.sequence_parallelism_mode in ["all_to_all"]:
|
||||||
assert (
|
# assert (
|
||||||
not self.enable_tensor_parallelism
|
# not self.enable_tensor_parallelism
|
||||||
), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False"
|
# ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False"
|
||||||
if self.enable_sequence_overlap:
|
if self.enable_sequence_overlap:
|
||||||
self.enable_sequence_overlap = False
|
self.enable_sequence_overlap = False
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
|
|
|
@@ -28,6 +28,7 @@ warnings.filterwarnings("ignore")
|
||||||
# Constants
|
# Constants
|
||||||
# ==============================
|
# ==============================
|
||||||
|
|
||||||
|
# We have lots of llamas for your choice!
|
||||||
MODEL_CONFIGS = {
|
MODEL_CONFIGS = {
|
||||||
"100m": LlamaConfig(
|
"100m": LlamaConfig(
|
||||||
max_position_embeddings=4096,
|
max_position_embeddings=4096,
|
||||||
|
@@ -36,6 +37,7 @@ MODEL_CONFIGS = {
|
||||||
intermediate_size=2048,
|
intermediate_size=2048,
|
||||||
hidden_size=1024,
|
hidden_size=1024,
|
||||||
),
|
),
|
||||||
|
"5b": LlamaConfig(max_position_embeddings=4096, num_key_value_heads=8),
|
||||||
"7b": LlamaConfig(max_position_embeddings=4096),
|
"7b": LlamaConfig(max_position_embeddings=4096),
|
||||||
"13b": LlamaConfig(
|
"13b": LlamaConfig(
|
||||||
hidden_size=5120,
|
hidden_size=5120,
|
||||||
|
@@ -68,9 +70,6 @@ def main():
|
||||||
default="gemini",
|
default="gemini",
|
||||||
help="Choose which plugin to use",
|
help="Choose which plugin to use",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--overlap", action="store_true", help="Overlap communication with computation in Pipeline Parallel."
|
|
||||||
)
|
|
||||||
parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size")
|
parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size")
|
||||||
parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run")
|
parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run")
|
||||||
parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore")
|
parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore")
|
||||||
|
@@ -94,11 +93,24 @@ def main():
|
||||||
|
|
||||||
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
|
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
|
||||||
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
|
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
|
||||||
parser.add_argument("--profile", action="store_true", help="Profile the code", default=False)
|
parser.add_argument("--profile", action="store_true", help="Profile the code")
|
||||||
|
parser.add_argument(
|
||||||
|
"--nsys",
|
||||||
|
action="store_true",
|
||||||
|
help="Use nsys for profiling. \
|
||||||
|
You should put something like this before colossalai launch: \
|
||||||
|
nsys profile -w true -t cuda,cudnn,cublas -s cpu --capture-range=cudaProfilerApi --capture-range-end=stop --cudabacktrace=true -x true --python-backtrace=cuda -o prof_out",
|
||||||
|
)
|
||||||
parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
|
parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
|
||||||
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
|
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
|
||||||
parser.add_argument("--no_cache", action="store_true")
|
parser.add_argument("--no_cache", action="store_true")
|
||||||
parser.add_argument("--overlap_allgather", action="store_true")
|
parser.add_argument("--overlap_allgather", action="store_true")
|
||||||
|
parser.add_argument(
|
||||||
|
"--sp_mode",
|
||||||
|
default="all_to_all",
|
||||||
|
choices=["all_to_all", "ring_attn", "ring", "split_gather"],
|
||||||
|
help="Sequence parallelism mode",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
colossalai.launch_from_torch()
|
colossalai.launch_from_torch()
|
||||||
|
@@ -195,12 +207,12 @@ def main():
|
||||||
num_model_chunks=args.n_chunks,
|
num_model_chunks=args.n_chunks,
|
||||||
zero_stage=args.zero,
|
zero_stage=args.zero,
|
||||||
sp_size=args.sp,
|
sp_size=args.sp,
|
||||||
|
sequence_parallelism_mode=args.sp_mode,
|
||||||
enable_sequence_parallelism=args.sp > 1,
|
enable_sequence_parallelism=args.sp > 1,
|
||||||
enable_fused_normalization=torch.cuda.is_available(),
|
enable_fused_normalization=torch.cuda.is_available(),
|
||||||
enable_flash_attention=args.xformers,
|
enable_flash_attention=args.xformers,
|
||||||
microbatch_size=args.mbs,
|
microbatch_size=args.mbs,
|
||||||
precision="bf16",
|
precision="bf16",
|
||||||
overlap_p2p=args.overlap,
|
|
||||||
enable_metadata_cache=not args.no_cache,
|
enable_metadata_cache=not args.no_cache,
|
||||||
overlap_allgather=args.overlap_allgather,
|
overlap_allgather=args.overlap_allgather,
|
||||||
**hybrid_kwargs,
|
**hybrid_kwargs,
|
||||||
|
@@ -218,7 +230,6 @@ def main():
|
||||||
microbatch_size=args.mbs,
|
microbatch_size=args.mbs,
|
||||||
initial_scale=2**8,
|
initial_scale=2**8,
|
||||||
precision="bf16",
|
precision="bf16",
|
||||||
overlap_p2p=args.overlap,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||||
|
@@ -295,6 +306,7 @@ def main():
|
||||||
args.ignore_steps,
|
args.ignore_steps,
|
||||||
1, # avoid creating massive log files
|
1, # avoid creating massive log files
|
||||||
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
|
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
|
||||||
|
nsys=args.nsys,
|
||||||
) as prof:
|
) as prof:
|
||||||
if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
|
if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
|
||||||
data_iter = iter(dataloader)
|
data_iter = iter(dataloader)
|
||||||
|
@@ -320,13 +332,16 @@ def main():
|
||||||
performance_evaluator.on_step_start(step)
|
performance_evaluator.on_step_start(step)
|
||||||
outputs = model(**batch)
|
outputs = model(**batch)
|
||||||
loss = outputs[0]
|
loss = outputs[0]
|
||||||
|
del outputs # free memory
|
||||||
|
|
||||||
|
if dist.get_rank() == dist.get_world_size() - 1:
|
||||||
|
print(f"Step {step} loss: {loss}")
|
||||||
booster.backward(loss, optimizer)
|
booster.backward(loss, optimizer)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
performance_evaluator.on_step_end(**batch)
|
performance_evaluator.on_step_end(**batch)
|
||||||
prof.step()
|
prof.step()
|
||||||
|
|
||||||
performance_evaluator.on_fit_end()
|
performance_evaluator.on_fit_end()
|
||||||
coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")
|
coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")
|
||||||
|
|
||||||
|
|
|
@@ -17,7 +17,7 @@ limitations under the License.
|
||||||
## OPT
|
## OPT
|
||||||
Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments.
|
Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments.
|
||||||
|
|
||||||
The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost.
|
The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Causal Language Modelling at low cost.
|
||||||
|
|
||||||
|
|
||||||
## Our Modifications
|
## Our Modifications
|
||||||
|
|
|
@@ -28,7 +28,7 @@ def all_reduce_mean(x: float, world_size: int) -> float:
|
||||||
return tensor.item()
|
return tensor.item()
|
||||||
|
|
||||||
|
|
||||||
def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
|
def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir, nsys=False):
|
||||||
class DummyProfiler:
|
class DummyProfiler:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.step_number = 0
|
self.step_number = 0
|
||||||
|
@@ -42,7 +42,29 @@ def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class NsysProfiler:
|
||||||
|
def __init__(self, warmup_steps, active_steps):
|
||||||
|
self.step_number = 0
|
||||||
|
self.warmup_steps = warmup_steps
|
||||||
|
self.active_steps = active_steps
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
if self.step_number == self.warmup_steps:
|
||||||
|
torch.cuda.cudart().cudaProfilerStart()
|
||||||
|
elif self.step_number == self.warmup_steps + self.active_steps:
|
||||||
|
torch.cuda.cudart().cudaProfilerStop()
|
||||||
|
self.step_number += 1
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
pass
|
||||||
|
|
||||||
if enable_flag:
|
if enable_flag:
|
||||||
|
if nsys:
|
||||||
|
return NsysProfiler(warmup_steps, active_steps)
|
||||||
|
|
||||||
return profile(
|
return profile(
|
||||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||||
schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
|
schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
|
||||||
|
|
|
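A usage note on the NsysProfiler shim added above: it only brackets the selected steps with cudaProfilerStart/Stop, so nsys itself must be launched with --capture-range=cudaProfilerApi as the new help text shows. A standalone sketch of the same gating (requires a CUDA build of PyTorch; the class name here is illustrative, not the benchmark's helper):

import torch

class CudaProfilerWindow:
    """Minimal stand-in for the NsysProfiler above: capture only steps [warmup, warmup + active)."""

    def __init__(self, warmup_steps: int, active_steps: int):
        self.step_number, self.warmup_steps, self.active_steps = 0, warmup_steps, active_steps

    def step(self):
        if self.step_number == self.warmup_steps:
            torch.cuda.cudart().cudaProfilerStart()   # nsys starts capturing here
        elif self.step_number == self.warmup_steps + self.active_steps:
            torch.cuda.cudart().cudaProfilerStop()    # ...and stops here
        self.step_number += 1

# Typical loop: warm up 2 steps, capture 3, then keep running unprofiled.
prof = CudaProfilerWindow(warmup_steps=2, active_steps=3)
for _ in range(8):
    # train_step()
    prof.step()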
@@ -19,7 +19,7 @@ limitations under the License.
|
||||||
## OPT
|
## OPT
|
||||||
Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments.
|
Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments.
|
||||||
|
|
||||||
The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost.
|
The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Causal Language Modelling at low cost.
|
||||||
|
|
||||||
We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before
|
We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before
|
||||||
the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling).
|
the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling).
|
||||||
|
|
|
@@ -57,14 +57,14 @@ class FlashAttentionDaoCudaExtension(_Extension):
|
||||||
q_indices: Optional[torch.Tensor] = None,
|
q_indices: Optional[torch.Tensor] = None,
|
||||||
kv_indices: Optional[torch.Tensor] = None,
|
kv_indices: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
# [B, N, S, D] -> [B, S, N, D]
|
# [B, H, S, D] -> [B, S, H, D]
|
||||||
q = q.transpose(1, 2)
|
q = q.transpose(1, 2)
|
||||||
k = k.transpose(1, 2)
|
k = k.transpose(1, 2)
|
||||||
v = v.transpose(1, 2)
|
v = v.transpose(1, 2)
|
||||||
b, s_q = q.shape[:2]
|
b, s_q = q.shape[:2]
|
||||||
if cu_seqlens_q is not None:
|
if cu_seqlens_q is not None:
|
||||||
# padded / padded causal
|
# padded / padded causal
|
||||||
# unpad input: [B, S, N, D] -> [T, N, D]
|
# unpad input: [B, S, H, D] -> [T, H, D]
|
||||||
q = _unpad_input(q, q_indices)
|
q = _unpad_input(q, q_indices)
|
||||||
kv = _unpad_input(torch.stack(tensors=(k, v), dim=2), kv_indices)
|
kv = _unpad_input(torch.stack(tensors=(k, v), dim=2), kv_indices)
|
||||||
attn_output = flash_attn_varlen_kvpacked_func(
|
attn_output = flash_attn_varlen_kvpacked_func(
|
||||||
|
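The comment fixes above are about layout bookkeeping: the dispatcher hands the kernel heads-first tensors [B, H, S, D], the flash kernels want [B, S, H, D], and the varlen path additionally packs real tokens into [T, H, D]. A small sketch of that unpadding step driven purely by padding-mask indices (illustrative, not the exact _unpad_input helper):

import torch

def unpad(x: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor:
    # x: [B, S, H, D]; padding_mask: [B, S] with 1 for real tokens.
    # Flatten batch and sequence, then keep only rows holding real tokens -> [T, H, D].
    indices = padding_mask.flatten().nonzero().squeeze(-1)
    return x.flatten(0, 1)[indices]

x = torch.randn(2, 4, 3, 8)                        # [B, S, H, D]
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
print(unpad(x, mask).shape)                        # torch.Size([5, 3, 8])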
@@ -78,7 +78,7 @@ class FlashAttentionDaoCudaExtension(_Extension):
|
||||||
softmax_scale=scale,
|
softmax_scale=scale,
|
||||||
causal=is_causal,
|
causal=is_causal,
|
||||||
)
|
)
|
||||||
# pad output: [T, N, D] -> [B, S, N, D]
|
# pad output: [T, H, D] -> [B, S, H, D]
|
||||||
attn_output = pad_input(attn_output, q_indices, b, s_q)
|
attn_output = pad_input(attn_output, q_indices, b, s_q)
|
||||||
else:
|
else:
|
||||||
# causal / no attn mask
|
# causal / no attn mask
|
||||||
|
@@ -90,7 +90,7 @@ class FlashAttentionDaoCudaExtension(_Extension):
|
||||||
softmax_scale=scale,
|
softmax_scale=scale,
|
||||||
causal=is_causal,
|
causal=is_causal,
|
||||||
)
|
)
|
||||||
# [B, S, N, D] -> [B, N, S, D]
|
# [B, S, H, D] -> [B, H, S, D]
|
||||||
return attn_output.transpose(1, 2)
|
return attn_output.transpose(1, 2)
|
||||||
|
|
||||||
return flash_attention
|
return flash_attention
|
||||||
|
|
|
@@ -22,9 +22,9 @@ COMMON_MODELS = [
|
||||||
"transformers_bloom_for_causal_lm",
|
"transformers_bloom_for_causal_lm",
|
||||||
"transformers_falcon_for_causal_lm",
|
"transformers_falcon_for_causal_lm",
|
||||||
"transformers_chatglm_for_conditional_generation",
|
"transformers_chatglm_for_conditional_generation",
|
||||||
"transformers_llama_for_casual_lm",
|
"transformers_llama_for_causal_lm",
|
||||||
"transformers_vit_for_masked_image_modeling",
|
"transformers_vit_for_masked_image_modeling",
|
||||||
"transformers_mistral_for_casual_lm",
|
"transformers_mistral_for_causal_lm",
|
||||||
]
|
]
|
||||||
|
|
||||||
IS_FAST_TEST = os.environ.get("FAST_TEST", "0") == "1"
|
IS_FAST_TEST = os.environ.get("FAST_TEST", "0") == "1"
|
||||||
|
|
|
@@ -32,8 +32,8 @@ if HAS_COMMAND:
|
||||||
|
|
||||||
return dict(input_ids=input_ids, attention_mask=attention_mask)
|
return dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
|
|
||||||
# label is needed for casual lm
|
# label is needed for causal lm
|
||||||
def data_gen_for_casual_lm():
|
def data_gen_for_causal_lm():
|
||||||
data = data_gen()
|
data = data_gen()
|
||||||
labels = data["input_ids"].clone()
|
labels = data["input_ids"].clone()
|
||||||
data["labels"] = labels
|
data["labels"] = labels
|
||||||
|
@@ -44,7 +44,7 @@ if HAS_COMMAND:
|
||||||
|
|
||||||
# function to get the loss
|
# function to get the loss
|
||||||
loss_fn = lambda output: output["last_hidden_state"].mean()
|
loss_fn = lambda output: output["last_hidden_state"].mean()
|
||||||
loss_fn_for_casual_lm = lambda output: output["loss"]
|
loss_fn_for_causal_lm = lambda output: output["loss"]
|
||||||
loss_fn_for_seq_classification = lambda output: output["logits"].mean()
|
loss_fn_for_seq_classification = lambda output: output["logits"].mean()
|
||||||
|
|
||||||
config = CohereConfig(
|
config = CohereConfig(
|
||||||
|
@@ -70,10 +70,10 @@ if HAS_COMMAND:
|
||||||
model_attribute=ModelAttribute(has_control_flow=True),
|
model_attribute=ModelAttribute(has_control_flow=True),
|
||||||
)
|
)
|
||||||
model_zoo.register(
|
model_zoo.register(
|
||||||
name="transformers_command_for_casual_lm",
|
name="transformers_command_for_causal_lm",
|
||||||
model_fn=lambda: transformers.CohereForCausalLM(config),
|
model_fn=lambda: transformers.CohereForCausalLM(config),
|
||||||
data_gen_fn=data_gen_for_casual_lm,
|
data_gen_fn=data_gen_for_causal_lm,
|
||||||
output_transform_fn=output_transform_fn,
|
output_transform_fn=output_transform_fn,
|
||||||
loss_fn=loss_fn_for_casual_lm,
|
loss_fn=loss_fn_for_causal_lm,
|
||||||
model_attribute=ModelAttribute(has_control_flow=True),
|
model_attribute=ModelAttribute(has_control_flow=True),
|
||||||
)
|
)
|
||||||
|
|
|
@@ -33,20 +33,21 @@ if HAS_LLAMA:
|
||||||
[1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
|
[1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
|
||||||
]
|
]
|
||||||
).long()
|
).long()
|
||||||
|
attention_mask = torch.ones_like(input_ids)
|
||||||
attention_mask = torch.Tensor(
|
|
||||||
[
|
|
||||||
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
|
|
||||||
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
|
|
||||||
]
|
|
||||||
).long()
|
|
||||||
|
|
||||||
return dict(input_ids=input_ids, attention_mask=attention_mask)
|
return dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
|
|
||||||
# label is needed for casual lm
|
# label is needed for causal lm
|
||||||
def data_gen_for_casual_lm():
|
def data_gen_for_causal_lm():
|
||||||
data = data_gen()
|
data = data_gen()
|
||||||
|
|
||||||
|
# Test padded sequence
|
||||||
|
padding = torch.zeros(2, data["input_ids"].shape[1] // 2, dtype=torch.long)
|
||||||
|
data["input_ids"] = torch.cat([data["input_ids"], padding], dim=1)
|
||||||
|
data["attention_mask"] = torch.cat([data["attention_mask"], padding], dim=1)
|
||||||
|
|
||||||
|
ignore_idx = -100
|
||||||
labels = data["input_ids"].clone()
|
labels = data["input_ids"].clone()
|
||||||
|
labels[~data["attention_mask"].bool()] = ignore_idx
|
||||||
data["labels"] = labels
|
data["labels"] = labels
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
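The test change above pads the batch and sets padded label positions to the ignore index, so they contribute nothing to the loss; that is standard cross-entropy behaviour and can be checked directly:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 8, 11)          # [B, S, V]
labels = torch.randint(0, 11, (2, 8))
attention_mask = torch.tensor([[1] * 8, [1] * 4 + [0] * 4])

masked = labels.clone()
masked[~attention_mask.bool()] = -100   # padded positions are ignored by cross entropy

loss_full = F.cross_entropy(logits[attention_mask.bool()], labels[attention_mask.bool()])
loss_masked = F.cross_entropy(logits.reshape(-1, 11), masked.reshape(-1), ignore_index=-100)
assert torch.allclose(loss_full, loss_masked)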
@@ -55,7 +56,7 @@ if HAS_LLAMA:
|
||||||
|
|
||||||
# function to get the loss
|
# function to get the loss
|
||||||
loss_fn = lambda output: output["last_hidden_state"].mean()
|
loss_fn = lambda output: output["last_hidden_state"].mean()
|
||||||
loss_fn_for_casual_lm = lambda output: output["loss"]
|
loss_fn_for_causal_lm = lambda output: output["loss"]
|
||||||
loss_fn_for_seq_classification = lambda output: output["logits"].mean()
|
loss_fn_for_seq_classification = lambda output: output["logits"].mean()
|
||||||
|
|
||||||
config = LlamaConfig(
|
config = LlamaConfig(
|
||||||
|
@@ -70,9 +71,17 @@ if HAS_LLAMA:
|
||||||
config.pad_token_id = config.eos_token_id
|
config.pad_token_id = config.eos_token_id
|
||||||
|
|
||||||
# register the following models
|
# register the following models
|
||||||
# transformers.LlamaModel,
|
|
||||||
# transformers.LlamaForCausalLM,
|
# transformers.LlamaForCausalLM,
|
||||||
|
# transformers.LlamaModel,
|
||||||
# transformers.LlamaForSequenceClassification,
|
# transformers.LlamaForSequenceClassification,
|
||||||
|
model_zoo.register(
|
||||||
|
name="transformers_llama_for_causal_lm",
|
||||||
|
model_fn=lambda: transformers.LlamaForCausalLM(config),
|
||||||
|
data_gen_fn=data_gen_for_causal_lm,
|
||||||
|
output_transform_fn=output_transform_fn,
|
||||||
|
loss_fn=loss_fn_for_causal_lm,
|
||||||
|
model_attribute=ModelAttribute(has_control_flow=True),
|
||||||
|
)
|
||||||
model_zoo.register(
|
model_zoo.register(
|
||||||
name="transformers_llama",
|
name="transformers_llama",
|
||||||
model_fn=lambda: transformers.LlamaModel(config),
|
model_fn=lambda: transformers.LlamaModel(config),
|
||||||
|
@@ -81,14 +90,6 @@ if HAS_LLAMA:
|
||||||
loss_fn=loss_fn,
|
loss_fn=loss_fn,
|
||||||
model_attribute=ModelAttribute(has_control_flow=True),
|
model_attribute=ModelAttribute(has_control_flow=True),
|
||||||
)
|
)
|
||||||
model_zoo.register(
|
|
||||||
name="transformers_llama_for_casual_lm",
|
|
||||||
model_fn=lambda: transformers.LlamaForCausalLM(config),
|
|
||||||
data_gen_fn=data_gen_for_casual_lm,
|
|
||||||
output_transform_fn=output_transform_fn,
|
|
||||||
loss_fn=loss_fn_for_casual_lm,
|
|
||||||
model_attribute=ModelAttribute(has_control_flow=True),
|
|
||||||
)
|
|
||||||
model_zoo.register(
|
model_zoo.register(
|
||||||
name="transformers_llama_for_sequence_classification",
|
name="transformers_llama_for_sequence_classification",
|
||||||
model_fn=lambda: transformers.LlamaForSequenceClassification(config),
|
model_fn=lambda: transformers.LlamaForSequenceClassification(config),
|
||||||
|
|
|
@@ -64,7 +64,7 @@ model_zoo.register(
|
||||||
model_attribute=ModelAttribute(has_control_flow=True),
|
model_attribute=ModelAttribute(has_control_flow=True),
|
||||||
)
|
)
|
||||||
model_zoo.register(
|
model_zoo.register(
|
||||||
name="transformers_mistral_for_casual_lm",
|
name="transformers_mistral_for_causal_lm",
|
||||||
model_fn=lambda: transformers.MistralForCausalLM(config),
|
model_fn=lambda: transformers.MistralForCausalLM(config),
|
||||||
data_gen_fn=data_gen_for_lm,
|
data_gen_fn=data_gen_for_lm,
|
||||||
output_transform_fn=output_transform_fn,
|
output_transform_fn=output_transform_fn,
|
||||||
|
|
|
@@ -33,8 +33,8 @@ if HAS_QWEN2:
|
||||||
attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long()
|
attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long()
|
||||||
return dict(input_ids=input_ids, attention_mask=attention_mask)
|
return dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
|
|
||||||
# label is needed for casual lm
|
# label is needed for causal lm
|
||||||
def data_gen_for_casual_lm():
|
def data_gen_for_causal_lm():
|
||||||
data = data_gen()
|
data = data_gen()
|
||||||
labels = data["input_ids"].clone()
|
labels = data["input_ids"].clone()
|
||||||
data["labels"] = labels
|
data["labels"] = labels
|
||||||
|
@@ -45,7 +45,7 @@ if HAS_QWEN2:
|
||||||
|
|
||||||
# function to get the loss
|
# function to get the loss
|
||||||
loss_fn = lambda output: output["last_hidden_state"].mean()
|
loss_fn = lambda output: output["last_hidden_state"].mean()
|
||||||
loss_fn_for_casual_lm = lambda output: output["loss"]
|
loss_fn_for_causal_lm = lambda output: output["loss"]
|
||||||
loss_fn_for_seq_classification = lambda output: output["logits"].mean()
|
loss_fn_for_seq_classification = lambda output: output["logits"].mean()
|
||||||
|
|
||||||
config = Qwen2Config(
|
config = Qwen2Config(
|
||||||
|
@@ -72,11 +72,11 @@ if HAS_QWEN2:
|
||||||
model_attribute=ModelAttribute(has_control_flow=True),
|
model_attribute=ModelAttribute(has_control_flow=True),
|
||||||
)
|
)
|
||||||
model_zoo.register(
|
model_zoo.register(
|
||||||
name="transformers_qwen2_for_casual_lm",
|
name="transformers_qwen2_for_causal_lm",
|
||||||
model_fn=lambda: transformers.Qwen2ForCausalLM(config),
|
model_fn=lambda: transformers.Qwen2ForCausalLM(config),
|
||||||
data_gen_fn=data_gen_for_casual_lm,
|
data_gen_fn=data_gen_for_causal_lm,
|
||||||
output_transform_fn=output_transform_fn,
|
output_transform_fn=output_transform_fn,
|
||||||
loss_fn=loss_fn_for_casual_lm,
|
loss_fn=loss_fn_for_causal_lm,
|
||||||
model_attribute=ModelAttribute(has_control_flow=True),
|
model_attribute=ModelAttribute(has_control_flow=True),
|
||||||
)
|
)
|
||||||
model_zoo.register(
|
model_zoo.register(
|
||||||
|
|
|
@@ -97,7 +97,7 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True):
|
||||||
|
|
||||||
# TODO(ver217): add more models
|
# TODO(ver217): add more models
|
||||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry(
|
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry(
|
||||||
"transformers_llama_for_casual_lm"
|
"transformers_llama_for_causal_lm"
|
||||||
).items():
|
).items():
|
||||||
err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn)
|
err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn)
|
||||||
|
|
||||||
|
|
|
@@ -105,7 +105,7 @@ def check_low_level_zero_lora(stage, model_name, early_stop: bool = True):
|
||||||
sub_model_zoo = model_zoo.get_sub_registry(model_name)
|
sub_model_zoo = model_zoo.get_sub_registry(model_name)
|
||||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||||
task_type = None
|
task_type = None
|
||||||
if name == "transformers_llama_for_casual_lm":
|
if name == "transformers_llama_for_causal_lm":
|
||||||
task_type = "CAUSAL_LM"
|
task_type = "CAUSAL_LM"
|
||||||
if name == "transformers_llama_for_sequence_classification":
|
if name == "transformers_llama_for_sequence_classification":
|
||||||
task_type = "SEQ_CLS"
|
task_type = "SEQ_CLS"
|
||||||
|
|
|
@@ -74,7 +74,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
 @clear_cache_before_run()
 @parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS)
 @parameterize("shard", [True, False])
-@parameterize("model_name", ["transformers_llama_for_casual_lm"])
+@parameterize("model_name", ["transformers_llama_for_causal_lm"])
 @parameterize("size_per_shard", [32])
 @parameterize("tp_size", [1, 2])
 @parameterize("zero_size", [2])
@@ -20,7 +20,7 @@ from tests.kit.model_zoo import model_zoo
 
 @clear_cache_before_run()
 @parameterize("shard", [False, True])
-@parameterize("model_name", ["transformers_llama_for_casual_lm"])
+@parameterize("model_name", ["transformers_llama_for_causal_lm"])
 def exam_torch_load_from_gemini(shard: bool, model_name: str):
 (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
 criterion = lambda x: x.mean()
@@ -39,7 +39,7 @@ else:
 
 
 @parameterize("shard", [True, False])
-@parameterize("model_name", ["transformers_llama_for_casual_lm"])
+@parameterize("model_name", ["transformers_llama_for_causal_lm"])
 @parameterize("size_per_shard", [32])
 @parameterize("test_config", TEST_CONFIGS)
 @clear_cache_before_run()
@@ -149,7 +149,7 @@ def check_low_level_zero_lora_checkpointIO(
 if name != "transformers_llama":
 continue
 task_type = None
-if name == "transformers_llama_for_casual_lm":
+if name == "transformers_llama_for_causal_lm":
 task_type = "CAUSAL_LM"
 if name == "transformers_llama_for_sequence_classification":
 task_type = "SEQ_CLS"
@@ -18,7 +18,7 @@ from tests.kit.model_zoo import model_zoo
 
 
 @clear_cache_before_run()
-@parameterize("model_name", ["transformers_llama_for_casual_lm"])
+@parameterize("model_name", ["transformers_llama_for_causal_lm"])
 @parameterize("plugin_type", ["ddp", "zero", "gemini"])
 def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32):
 (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
@@ -91,7 +91,7 @@ def run_lora_test():
 sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
 for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
 task_type = None
-if name == "transformers_llama_for_casual_lm":
+if name == "transformers_llama_for_causal_lm":
 task_type = "CAUSAL_LM"
 if name == "transformers_llama_for_sequence_classification":
 task_type = "SEQ_CLS"
@@ -6,6 +6,7 @@ import pytest
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.testing import assert_close
 
 import colossalai
 from colossalai.cluster import ProcessGroupMesh
@@ -107,13 +108,13 @@ def run_pp(
 
 # check loss
 if stage_manager.is_last_stage(ignore_chunk=True):
-assert torch.allclose(torch_loss, pp_ret["loss"])
+assert_close(torch_loss, pp_ret["loss"])
 
 # check gradients
 for i in range(num_model_chunk):
 idx = world_size * i + rank
-assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
-assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
+assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
+assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
 
 # step
 torch_optimizer.step()
@@ -123,8 +124,8 @@ def run_pp(
 # check updated param
 for i in range(num_model_chunk):
 idx = world_size * i + rank
-assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
-assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
+assert_close(torch_model.layers[idx].weight, sharded_model[i].weight)
+assert_close(torch_model.layers[idx].bias, sharded_model[i].bias)
 
 # forward only
 with torch.no_grad():
@@ -135,14 +136,14 @@ def run_pp(
 sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
 )
 if stage_manager.is_last_stage(ignore_chunk=True):
-assert torch.allclose(torch_loss, pp_ret["loss"])
+assert_close(torch_loss, pp_ret["loss"])
 
 for layer in sharded_model:
 if layer.weight.grad is None:
 assert layer.weight.grad is None and layer.bias.grad is None
 else:
-assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
-assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
+assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad))
+assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad))
 
 
 @pytest.mark.dist
@@ -6,6 +6,7 @@ import pytest
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.testing import assert_close
 
 import colossalai
 from colossalai.cluster import ProcessGroupMesh
@@ -103,13 +104,13 @@ def examine_pp(num_microbatch: int, batch_size: int):
 
 # check loss
 if stage_manager.is_last_stage():
-assert torch.allclose(torch_loss, pp_ret["loss"])
+assert_close(torch_loss, pp_ret["loss"])
 
 # check gradients
 for i in range(len(sharded_model)):
 idx = rank * num_local_layer + i
-assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
-assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
+assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
+assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
 
 # step
 torch_optimizer.step()
@@ -119,8 +120,8 @@ def examine_pp(num_microbatch: int, batch_size: int):
 # check updated param
 for i in range(len(sharded_model)):
 idx = rank * num_local_layer + i
-assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
-assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
+assert_close(torch_model.layers[idx].weight, sharded_model[i].weight)
+assert_close(torch_model.layers[idx].bias, sharded_model[i].bias)
 
 # forward only
 with torch.no_grad():
@@ -131,14 +132,14 @@ def examine_pp(num_microbatch: int, batch_size: int):
 sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
 )
 if stage_manager.is_last_stage():
-assert torch.allclose(torch_loss, pp_ret["loss"])
+assert_close(torch_loss, pp_ret["loss"])
 
 for layer in sharded_model:
 if layer.weight.grad is None:
 assert layer.weight.grad is None and layer.bias.grad is None
 else:
-assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
-assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
+assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad))
+assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad))
 
 
 def run_dist(
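Editor's note: the two pipeline-schedule test files above consistently swap bare `assert torch.allclose(...)` for `torch.testing.assert_close`. A minimal standalone sketch of the difference (illustration only, not part of the commit): assert_close picks dtype-aware default tolerances and, on failure, reports how many elements mismatch and the largest absolute/relative error, which is far more useful than a bare AssertionError when fp16 pipeline stages drift slightly.

# Illustration only: comparing two nearly-equal fp16 tensors.
import torch
from torch.testing import assert_close

a = torch.randn(4, 8, dtype=torch.float16)
b = a + 1e-3  # small perturbation, e.g. from a different reduction order

# torch.allclose only returns a bool (likely False under its strict defaults),
# so `assert torch.allclose(a, b)` fails without any diagnostic detail.
print(torch.allclose(a, b))

# assert_close raises with the mismatch count and the greatest abs/rel difference;
# explicit tolerances mirror how the tests loosen checks for fp16/bf16 runs.
assert_close(a, b, rtol=1e-2, atol=1e-2)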
@@ -88,6 +88,7 @@ def check_attn_func(dtype: torch.dtype, attn_func, attn_kwargs: dict, padding_ma
 padding_mask = padding_mask[:, None, :, None].logical_not()
 ref_output = ref_output.masked_fill(padding_mask, 0)
 output = output.masked_fill(padding_mask, 0)
+
 assert_close(output, ref_output, **tols)
 output.mean().backward()
 ref_output.mean().backward()
@@ -128,6 +129,8 @@ def test_flash_attn_func(dtype: torch.dtype):
 attn_kwargs, padding_mask = gen_kwargs_func(dtype)
 for attn_func, name, need_postprocess in attn_funcs:
 print(f"{dtype}, {name}, {mask_type}")
+if mask_type == "padded":
+pass
 if need_postprocess:
 check_attn_func(dtype, attn_func, post_process_kwargs_for_raw_attn(attn_kwargs), padding_mask)
 else:
@@ -0,0 +1,186 @@
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_func
+from torch.testing import assert_close
+
+import colossalai
+from colossalai.shardformer.layer import AttnMaskType
+from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention
+from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils import get_current_device
+
+
+@parameterize("seq_len", [4096])
+@parameterize("bs", [2])
+@parameterize("nheads", [5])
+@parameterize("d", [128])
+@parameterize("dtype", [torch.bfloat16, torch.float16])
+def check_ring_attn(seq_len, bs, nheads, d, dtype):
+torch.cuda.manual_seed(2)
+device = get_current_device()
+sp_group = dist.group.WORLD
+sp_size = dist.get_world_size()
+# Some outliers may seem large, but our errors are still lower than
+# than Megatron-LM context parallel's
+# (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215)
+# and the original zigzag implementation's (https://github.com/zhuzilin/ring-flash-attention/tree/main)
+atol = rtol = 7e-3
+
+# Setup inputs
+qkv = torch.randn(bs, seq_len, 3, nheads, d, device=device, dtype=dtype, requires_grad=True)
+local_qkv = split_batch_zigzag(qkv, sp_group)
+q, k, v = local_qkv.unbind(dim=-3)
+q, k, v = [x.squeeze(2).detach().clone().transpose(1, 2) for x in (q, k, v)] # (B, nHeads, Sq, D)
+q.requires_grad = k.requires_grad = v.requires_grad = True
+
+# Ring attention vs single GPU
+ring_out, ring_lse = RingAttention.attention(
+q,
+k,
+v,
+sp_group,
+AttnMaskType.CAUSAL,
+return_softmax=True,
+inner_ring_size=max(2, sp_size // 2),
+# inner_ring_size=4
+)
+ring_out = ring_out.transpose(1, 2)
+out, lse, _ = flash_attn_qkvpacked_func(
+qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True
+)
+
+# Checkout out and softmax denominator
+local_out = split_batch_zigzag(out, sp_group)
+local_lse = split_batch_zigzag(lse, sp_group, seq_dim=-1)
+local_lse = local_lse.transpose(1, 2).contiguous().view(-1, ring_lse.shape[-1]) # (B, nHeads, Sq) -> (T, nHeads)
+assert_close(ring_lse, local_lse, atol=atol, rtol=rtol)
+assert_close(ring_out, local_out, atol=atol, rtol=rtol)
+
+# Check grads
+ring_out.sum().backward()
+out.sum().backward()
+ring_dq, ring_dk, ring_dv = [x.transpose(1, 2) for x in (q.grad, k.grad, v.grad)]
+dqkv = qkv.grad
+local_dqkv = split_batch_zigzag(dqkv, sp_group)
+
+assert_close(ring_dq, local_dqkv[:, :, 0], atol=atol, rtol=rtol)
+assert_close(ring_dk, local_dqkv[:, :, 1], atol=atol, rtol=rtol)
+assert_close(ring_dv, local_dqkv[:, :, 2], atol=atol, rtol=rtol)
+if dist.get_rank() == 0:
+print(
+f"sp_size {dist.get_world_size()}, inner ring size {dist.get_world_size(RingAttention.INNER_RING_GROUP)} passed."
+)
+
+
+@parameterize("seqlen", [4096])
+@parameterize("bs", [2])
+@parameterize("nheads", [5])
+@parameterize("d", [128])
+@parameterize("dtype", [torch.bfloat16, torch.float16])
+def check_packed_seq(seqlen, bs, nheads, d, dtype):
+device = get_current_device()
+sp_group = dist.group.WORLD
+sp_size = dist.get_world_size()
+atol = rtol = 7e-3
+torch.cuda.manual_seed(2)
+# Prepare varlen attention mask
+padding_mask = torch.ones((bs, seqlen), dtype=torch.int, device=device)
+padding_mask[: bs // 2, (seqlen // 4) * 3 :] = 0
+padding_mask[:, seqlen // 2 :] = 0
+
+input_embeds = torch.randn(bs, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True)
+
+# Forward
+# out = ColoAttention.attention(q, k, v, **mask_info)
+flat_input = input_embeds.view(-1, nheads, d)[padding_mask.flatten().nonzero().squeeze()]
+qkv = torch.stack([flat_input] * 3, dim=1)
+qkv.retain_grad()
+
+input_embeds, mask_info, _ = RingAttention.prepare_varlen_batch(padding_mask, sp_group, input_embeds)
+out, lse, _ = flash_attn_varlen_qkvpacked_func(
+qkv,
+mask_info["cu_seqlens"] * sp_size,
+mask_info["max_seqlen"] * sp_size,
+return_attn_probs=True,
+causal=True,
+# deterministic=True
+)
+# Test the splitting function
+local_input = split_varlen_zigzag(
+flat_input, mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size
+)
+assert (local_input == input_embeds.view(-1, nheads, d)[mask_info["valid_indices"]]).all()
+del local_input, flat_input
+
+q_ring, k_ring, v_ring = [input_embeds.clone().transpose(1, 2) for _ in range(3)]
+q_ring.retain_grad()
+k_ring.retain_grad()
+v_ring.retain_grad()
+
+ring_out, ring_lse = RingAttention.attention(
+q_ring,
+k_ring,
+v_ring,
+sp_group,
+**mask_info,
+pad_output=False,
+return_softmax=True,
+# deterministic=True
+)
+ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d)
+# Check output
+lse = lse.transpose(0, 1)
+out, lse = split_varlen_zigzag(
+[out, lse], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size
+)
+assert_close(lse, ring_lse, atol=atol, rtol=rtol)
+assert_close(out, ring_out, atol=atol, rtol=rtol)
+
+# Check grads
+labels = torch.ones(out.shape[0], dtype=dtype, device=device)
+F.mse_loss(out.sum((-2, -1)), labels).backward()
+F.mse_loss(ring_out.sum((-2, -1)), labels[: ring_out.shape[0]]).backward()
+dq, dk, dv = [
+split_varlen_zigzag(
+qkv.grad[:, i], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size
+)
+for i in range(3)
+]
+dq_ring, dk_ring, dv_ring = [
+x.transpose(1, 2).reshape(-1, nheads, d)[mask_info["valid_indices"]]
+for x in (q_ring.grad, k_ring.grad, v_ring.grad)
+]
+
+assert_close(dq, dq_ring, atol=atol, rtol=rtol)
+assert_close(dk, dk_ring, atol=atol, rtol=rtol)
+assert_close(dv, dv_ring, atol=atol, rtol=rtol)
+
+
+def launch_single_ring(rank, world_size, port):
+colossalai.launch(rank, world_size, "localhost", port)
+check_packed_seq()
+check_ring_attn()
+
+
+def launch_double_ring(rank, world_size, port):
+colossalai.launch(rank, world_size, "localhost", port)
+check_ring_attn()
+
+
+@rerun_if_address_is_in_use()
+@parameterize("world_size", [2])
+def test_ring_attn(world_size):
+spawn(launch_single_ring, nprocs=world_size)
+
+
+@rerun_if_address_is_in_use()
+@parameterize("world_size", [4])
+def test_double_ring(world_size):
+spawn(launch_double_ring, nprocs=world_size)
+
+
+if __name__ == "__main__":
+test_ring_attn()
+test_double_ring()
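Editor's note: the new test above leans on split_batch_zigzag / split_varlen_zigzag (from colossalai.shardformer.layer.utils) to give every rank a balanced share of the causal-attention work. A minimal sketch of the zigzag assignment the test assumes (illustration only; the real helpers also handle varlen layouts and padding): the sequence is cut into 2 * sp_size chunks and rank i keeps chunk i together with chunk 2 * sp_size - 1 - i, pairing a cheap early chunk with an expensive late chunk under the causal mask.

# Hypothetical illustration of zigzag splitting along the sequence dimension.
import torch

def zigzag_split(x: torch.Tensor, sp_size: int, rank: int, seq_dim: int = 1) -> torch.Tensor:
    # Cut the sequence into 2 * sp_size equal chunks.
    chunks = x.chunk(2 * sp_size, dim=seq_dim)
    # Pair the rank's "cheap" front chunk with its mirrored "expensive" back chunk.
    return torch.cat([chunks[rank], chunks[2 * sp_size - 1 - rank]], dim=seq_dim)

# Example: a length-16 sequence split across sp_size=4 ranks -> 4 tokens per rank.
x = torch.arange(16).view(1, 16)
for rank in range(4):
    print(rank, zigzag_split(x, sp_size=4, rank=rank).tolist())
# rank 0 holds tokens [0, 1, 14, 15], rank 1 holds [2, 3, 12, 13], and so on.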
@@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup
 from torch.nn import Module
 from torch.optim import Adam, Optimizer
 from torch.testing import assert_close
+from transformers.modeling_outputs import BaseModelOutputWithPast
 
 from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
@@ -259,7 +260,6 @@ def run_forward_backward_with_hybrid_plugin(
 org_output = org_model(**unshard_test_data)
 org_loss = criterion(org_output)
 org_loss.backward()
-
 
 return org_loss, org_output, sharded_loss, sharded_output
 
@@ -302,11 +302,12 @@ def run_forward_backward_with_low_level_zero_plugin(
 
 
 def check_output_hidden_state(
-org_output: Tensor,
-sharded_output: Tensor,
+org_output: BaseModelOutputWithPast,
+sharded_output: BaseModelOutputWithPast,
 stage_manager: Optional[PipelineStageManager] = None,
 atol: float = 1e-5,
 rtol: float = 1e-3,
+shard_config: Optional[ShardConfig] = None,
 ):
 org_hidden_state = org_output.last_hidden_state
 
@@ -315,6 +316,14 @@ def check_output_hidden_state(
 else:
 sharded_hidden_state = sharded_output.last_hidden_state
 
+# Check if the output sequence is gathered before cross entropy
+if shard_config is not None:
+seq_dim = 1
+sp_group = shard_config.sequence_parallel_process_group
+sp_size = shard_config.sequence_parallel_size
+if org_hidden_state.shape[seq_dim] == sharded_hidden_state.shape[seq_dim] * sp_size:
+org_hidden_state = org_hidden_state.chunk(sp_size, dim=seq_dim)[dist.get_rank(sp_group)]
+
 assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol)
 
 
@@ -374,8 +383,11 @@ def get_grad_tensors_for_check(
 shard_grad = torch.cat(shard_grad_list, dim=dim)
 
 # embedding may be resized when using tensor parallel
-if shard_grad.shape[0] > org_grad.shape[0]:
-shard_grad = shard_grad[: org_grad.shape[0], :]
+try:
+if shard_grad.shape[0] > org_grad.shape[0]:
+shard_grad = shard_grad[: org_grad.shape[0], :]
+except:
+pass
 if verbose and dist.get_rank() == 0:
 print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
 
@@ -404,9 +416,6 @@ def check_grad(
 org_grad = getattr_(org_model, suffix).weight.grad
 shard_grad = getattr_(sharded_model, suffix).weight.grad
 shard_weight = getattr_(sharded_model, suffix).weight
-# if verbose and dist.get_rank() == 0:
-# print("shard_weight", shard_weight)
-# print("org_grad", org_grad)
 if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
 shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))]
 dist.all_gather(shard_grad_list, shard_grad, tp_group)
@@ -440,7 +449,7 @@ def check_all_grad_tensors(check_tensors):
 "org_grad": tensor to be compared from the original model
 "shard_grad": tensor to be compared from the sharded model
 """
-for suffix, check_info in check_tensors.items():
+for idx, (suffix, check_info) in enumerate(check_tensors.items()):
 org_grad = check_info["org_grad"]
 shard_grad = check_info["shard_grad"]
 rtol = check_info["rtol"]
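Editor's note: the check_output_hidden_state change above compares against only the local sequence slice whenever the sharded model's last_hidden_state was never gathered across the sequence-parallel group. A standalone sketch of that comparison logic (illustration only; single process, so the "rank" is just an index rather than dist.get_rank):

# Illustration only: compare a full-sequence reference against one rank's un-gathered slice.
import torch
from torch.testing import assert_close

sp_size, rank, seq_dim = 4, 1, 1
full = torch.randn(2, 32, 64)                    # reference hidden state over the full sequence
local = full.chunk(sp_size, dim=seq_dim)[rank]   # what a sequence-parallel rank would hold

# Only slice the reference when the sharded output really is sp_size times shorter.
if full.shape[seq_dim] == local.shape[seq_dim] * sp_size:
    full = full.chunk(sp_size, dim=seq_dim)[rank]
assert_close(full.float(), local.float())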
@@ -271,7 +271,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 ],
 )
 def run_command_test(test_config):
-sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")
+sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_causal_lm")
 
 for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
 check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
@@ -321,7 +321,7 @@ def run_command_test(test_config):
 ],
 )
 def run_command_3d_test(test_config):
-sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")
+sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_causal_lm")
 
 for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
 check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
@@ -63,7 +63,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
 ):
 master2working = sharded_optimizer.get_master_to_working_map()
-for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
+for (name, p1), p2 in zip(
+llama_model.named_parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]
+):
 working_p = master2working[id(p2)]
 grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p))
 grad_index = (
@@ -73,7 +75,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 )
 grad = grads[grad_index]
 sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
-assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)
+try:
+assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)
+except Exception as e:
+raise RuntimeError(f"Failed to check grad for {name}") from e
 
 # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
 grads_to_check = {}
@@ -114,89 +119,130 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 atol, rtol = 5e-3, 5e-3
 
 if org_model.__class__.__name__ == "LlamaModel":
-check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+check_output_hidden_state(
+org_output,
+sharded_output,
+stage_manager,
+atol=atol,
+rtol=rtol,
+shard_config=booster.plugin.shard_config,
+)
 check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
 
 # check weights
 if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
 if test_config["precision"] == "fp32":
 atol, rtol = 1e-4, 1e-3
 else:
 atol, rtol = 5e-3, 5e-3
-try:
-check_weight(
-llama_model,
-shard_llama_model,
-col_layer_for_check,
-tp_group,
-atol=atol,
-rtol=rtol,
-dim=1,
-verbose=False,
-)
-except Exception as e:
-print(f"Failed config: {test_config}")
-raise e
+check_weight(
+llama_model,
+shard_llama_model,
+col_layer_for_check,
+tp_group,
+atol=atol,
+rtol=rtol,
+dim=1,
+verbose=False,
+)
 
 # check grads
 check_all_grad_tensors(grads_to_check)
 
 torch.cuda.empty_cache()
 
 
 @parameterize(
 "test_config",
 [
-{ # Ulysess + Flash attention
+# Double Ring Attention
+{
+"tp_size": 1,
+"pp_size": 1,
+"sp_size": 4,
+"num_microbatches": 1,
+"enable_sequence_parallelism": True,
+"sequence_parallelism_mode": "ring_attn",
+"use_lazy_init": True,
+"zero_stage": 0,
+"precision": "fp16",
+"initial_scale": 1,
+"inner_ring_size": 2,
+},
+# Ring Attention + PP
+{
+"tp_size": 1,
+"pp_size": 2,
+"sp_size": 2,
+"num_microbatches": 2,
+"enable_sequence_parallelism": True,
+"sequence_parallelism_mode": "ring_attn",
+"use_lazy_init": True,
+"zero_stage": 1,
+"precision": "fp16",
+"initial_scale": 1,
+},
+# Ring Attention + TP
+{
+"tp_size": 2,
+"pp_size": 1,
+"sp_size": 2,
+"num_microbatches": 1,
+"enable_sequence_parallelism": True,
+"sequence_parallelism_mode": "ring_attn",
+"use_lazy_init": True,
+"zero_stage": 2,
+"precision": "fp16",
+"initial_scale": 1,
+},
+{ # Ulysess + TP
+"tp_size": 2,
+"pp_size": 1,
+"sp_size": 2,
+"num_microbatches": 1,
+"enable_sequence_parallelism": True,
+"sequence_parallelism_mode": "all_to_all",
+"enable_all_optimization": True,
+"use_lazy_init": True,
+"zero_stage": 0,
+"precision": "fp16",
+"initial_scale": 1,
+},
+{ # Ulysess + PP
 "tp_size": 1,
 "pp_size": 2,
 "sp_size": 2,
 "num_microbatches": 2,
 "enable_sequence_parallelism": True,
 "sequence_parallelism_mode": "all_to_all",
-"enable_flash_attention": True,
+"enable_all_optimization": True,
 "use_lazy_init": True,
 "zero_stage": 0,
 "precision": "fp16",
 "initial_scale": 1,
 },
-{ # Test ring + Flash attention
-"tp_size": 2,
-"pp_size": 1,
-"sp_size": 2,
-"num_microbatches": 1,
-"enable_sequence_parallelism": True,
-"sequence_parallelism_mode": "ring",
-"enable_flash_attention": True,
-"use_lazy_init": True,
-"zero_stage": 2,
-"precision": "fp16",
-"initial_scale": 1,
-},
-{
-"tp_size": 1,
-"pp_size": 1,
-"sp_size": 2,
-"num_microbatches": 1,
-"enable_sequence_parallelism": True,
-"sequence_parallelism_mode": "all_to_all",
-"use_lazy_init": True,
-"zero_stage": 1,
-"precision": "fp16",
-"initial_scale": 1,
-},
 {
 "tp_size": 4,
 "pp_size": 1,
 "num_microbatches": 1,
 "enable_sequence_parallelism": True,
 "sequence_parallelism_mode": "split_gather",
-"enable_flash_attention": False,
+"enable_flash_attention": True,
 "use_lazy_init": True,
 "precision": "fp16",
 "initial_scale": 1,
 },
+{
+"tp_size": 2,
+"pp_size": 1,
+"sp_size": 1,
+"num_microbatches": 1,
+"enable_sequence_parallelism": True,
+"sequence_parallelism_mode": "ring",
+"enable_flash_attention": True,
+"use_lazy_init": True,
+"zero_stage": 2,
+"precision": "fp16",
+"initial_scale": 1,
+},
 {
 "tp_size": 2,
 "pp_size": 2,
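Editor's note: the new test configs above exercise the ring_attn sequence-parallelism mode under ZeRO, pipeline, and tensor parallelism. A rough sketch of how such a config maps onto plugin construction (assumptions: the test harness forwards these keys to HybridParallelPlugin, and inner_ring_size is the double-ring knob introduced by this PR; treat the exact signature as illustrative, not authoritative):

# Sketch only; keyword names mirror the test_config dicts above and may not match
# the plugin signature exactly.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    sp_size=4,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="ring_attn",
    zero_stage=0,
    precision="fp16",
    initial_scale=1,
    inner_ring_size=2,  # double ring: two inner rings of 2 ranks inside an sp group of 4
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, and dataloader would then be wrapped via booster.boost(...)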
@@ -240,12 +286,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 )
 def run_llama_test(test_config):
 sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
-
 for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name:
+continue
 try:
 check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
 except Exception as e:
-print(f"Failed config: {test_config}")
+print(f"Failed config: {test_config}, model name: {name}")
 raise e
 clear_layout_converter()
 Randomizer.reset_index()