mirror of https://github.com/InternLM/InternLM
fix lint
parent 29df765f65
commit f191853bf4
@@ -283,8 +283,10 @@ def args_sanity_check():
     if gpc.config.parallel["tensor"].get("mode", None) is None:
         gpc.config.parallel["tensor"]["mode"] = "origin_tp"
 
-    if gpc.config.parallel["tensor"].get("mode", None) == 'fstp':
-        assert gpc.config.parallel.sequence_parallel is True, "when the tp_mode is fstp, the sequence_parallel should be True."
+    if gpc.config.parallel["tensor"].get("mode", None) == "fstp":
+        assert (
+            gpc.config.parallel.sequence_parallel is True
+        ), "when the tp_mode is fstp, the sequence_parallel should be True."
 
     # currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
     if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1:
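A minimal config sketch, not taken from this commit, of the fields the check above reads. Only the dict-style `tensor` entry with its `mode` key and the `sequence_parallel` flag come from the diff; the surrounding keys and sizes are assumptions.

    # Hypothetical parallel section of a training config: with tensor mode "fstp",
    # sequence_parallel must be True or args_sanity_check() trips the assertion above.
    parallel = dict(
        zero1=8,  # assumed value
        pipeline=dict(size=1, interleaved_overlap=True),  # assumed value
        tensor=dict(size=8, mode="fstp"),  # "origin_tp" is the default filled in above
        sequence_parallel=True,
    )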
@@ -9,10 +9,9 @@ from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
 from flash_attn.utils.distributed import all_reduce, reduce_scatter
 from torch import nn
-
 
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.utils import Silu, fused_dense_func_torch, fsdp_fused_dense_func
+from internlm.model.utils import Silu, fstp_fused_dense_func, fused_dense_func_torch
 
 
 class ScaleColumnParallelLinear(nn.Linear):
@@ -124,7 +123,12 @@ class ColumnParallelLinearTorch(ColumnParallelLinear):
         # If not, then the input is already gathered.
 
         return fused_dense_func_torch(
-            x, self.weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel, gather_dim=gather_dim,
+            x,
+            self.weight,
+            self.bias,
+            process_group=self.process_group,
+            sequence_parallel=self.sequence_parallel,
+            gather_dim=gather_dim,
         )
 
 
@@ -204,31 +208,13 @@ class FeedForward(nn.Module):
         out = self.w3(Silu(w1_o, w2_o))
         return out
 
 
-class FSDPLinear(ColumnParallelLinear):
+class FSTPLinear(ColumnParallelLinear):
     def forward(self, x):
-        return fsdp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group)
+        return fstp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group)
-
-
-class FSDPScaleLinear(ScaleColumnParallelLinear):
-
-    def forward(self, input): # pylint: disable=W0622
-        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
-        # we do an all_gather of x before doing the matmul.
-        # If not, then the input is already gathered.
-        if self.weight_scale != 1:
-            weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
-        else:
-            weight = self.weight
-        return fsdp_fused_dense_func(
-            input,
-            weight,
-            self.bias,
-            process_group=self.process_group,
-        )
 
 
-class FSDPFeedForward(nn.Module):
+class FSTPFeedForward(nn.Module):
     """
     FeedForward.
 
@@ -259,7 +245,7 @@ class FSDPFeedForward(nn.Module):
 
         hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
 
-        self.w1 = FSDPLinear(
+        self.w1 = FSTPLinear(
             in_features,
             hidden_features,
             process_group,
@@ -268,7 +254,7 @@ class FSDPFeedForward(nn.Module):
             device=device,
             dtype=dtype,
         )
-        self.w2 = FSDPLinear(
+        self.w2 = FSTPLinear(
             in_features,
             hidden_features,
             process_group,
@@ -277,7 +263,7 @@ class FSDPFeedForward(nn.Module):
             device=device,
             dtype=dtype,
         )
-        self.w3 = FSDPLinear(
+        self.w3 = FSTPLinear(
             hidden_features,
             out_features,
             process_group,
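For reference, the `out = self.w3(Silu(w1_o, w2_o))` context line above is the gated (SwiGLU-style) MLP that both FeedForward and the renamed FSTPFeedForward compute. A self-contained sketch of that computation with plain nn.Linear layers; shapes are illustrative, and it assumes Silu(a, b) in internlm.model.utils is F.silu(a) * b.

    import torch
    import torch.nn.functional as F
    from torch import nn


    def swiglu_ffn(x: torch.Tensor, w1: nn.Linear, w2: nn.Linear, w3: nn.Linear) -> torch.Tensor:
        # Gate branch w1 goes through SiLU and is multiplied with the linear branch w2,
        # then w3 projects back to the model width.
        return w3(F.silu(w1(x)) * w2(x))


    if __name__ == "__main__":
        hidden, inner = 16, 44
        x = torch.randn(2, 3, hidden)
        w1, w2, w3 = nn.Linear(hidden, inner), nn.Linear(hidden, inner), nn.Linear(inner, hidden)
        print(swiglu_ffn(x, w1, w2, w3).shape)  # torch.Size([2, 3, 16])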
@@ -15,13 +15,16 @@ from internlm.initialize.initialize_tensor import normal_, scaled_init_method_no
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import (
     FeedForward,
+    FSTPFeedForward,
     RewardModelLinear,
     ScaleColumnParallelLinear,
-    FSDPScaleLinear,
-    FSDPFeedForward,
 )
 from internlm.model.multi_head_attention import MHA
-from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm, split_forward_gather_backward
+from internlm.model.utils import (
+    gather_forward_split_backward,
+    split_forward_gather_backward,
+    try_import_RMSNorm,
+)
 from internlm.solver.pipeline_utils import partition_uniform
 from internlm.utils.checkpoint import activation_checkpoint
 from internlm.utils.common import filter_kwargs
@@ -74,7 +77,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         use_scaled_init: bool = True,
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
-        tp_mode: str = 'origin_tp',
+        tp_mode: str = "origin_tp",
     ):
         super().__init__()
         self.checkpoint = checkpoint
@@ -111,7 +114,7 @@ class PackedFlashBaseLayer1D(nn.Module):
             self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
 
         if use_swiglu:
-            mlp_cls = FeedForward if tp_mode == 'origin_tp' else FSDPFeedForward
+            mlp_cls = FeedForward if tp_mode == "origin_tp" else FSTPFeedForward
             self.mlp = mlp_cls(
                 hidden_size,
                 int(hidden_size * mlp_ratio),
@@ -173,7 +176,6 @@ class PackedFlashBaseLayer1D(nn.Module):
                 else:
                     normal_(std=0.006 if "fc1" in name else 0.0015)(param.data)
 
-
     def forward(self, hidden_states, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None):
         if self.checkpoint and self.training:
             return activation_checkpoint(
@@ -341,7 +343,7 @@ class PackedFlashInternLm1D(nn.Module):
                     use_scaled_init=use_scaled_init,
                     use_swiglu=use_swiglu,
                     use_flash_attn=use_flash_attn,
-                    tp_mode = self.tp_mode,
+                    tp_mode=self.tp_mode,
                 )
                 for lid in range(num_layers)
             ]
@@ -388,7 +390,7 @@ class PackedFlashInternLm1D(nn.Module):
             # The indexes are used to indicate the actual position IDs of each token in the packed input.
             indexes = indexes[0]
             # if the tensor parallel mode is 'fstp', the indexes should also be split in sequence dimension.
-            if gpc.config.parallel.sequence_parallel and self.tp_mode == 'fstp':
+            if gpc.config.parallel.sequence_parallel and self.tp_mode == "fstp":
                 indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0)
 
         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
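The last hunk above splits the packed position `indexes` along the sequence dimension when tp_mode is "fstp". A single-process stand-in for the forward behaviour of split_forward_gather_backward (an assumption; the repository's version is autograd-aware and gathers on the backward pass):

    import torch


    def split_indexes_for_rank(indexes: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
        # Each sequence-parallel rank keeps only its contiguous chunk of the packed position ids.
        return torch.chunk(indexes, world_size, dim=0)[rank]


    if __name__ == "__main__":
        packed_positions = torch.arange(8)  # position ids of 8 packed tokens
        print(split_indexes_for_rank(packed_positions, world_size=4, rank=1))  # tensor([2, 3])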
@@ -2,9 +2,10 @@
 # -*- encoding: utf-8 -*-
 
 import warnings
-from typing import Optional
+from typing import Any, Optional, Tuple
 
 import torch
+import torch.distributed as dist
 from einops import rearrange
 from flash_attn.modules.mha import (
     CrossAttention,
@@ -13,26 +14,25 @@ from flash_attn.modules.mha import (
     SelfAttention,
     _update_kv_cache,
 )
-from torch import nn
+from torch import Tensor, nn
+from torch.nn import Module
 
 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.embedding import DynamicNTKScalingRotaryEmbedding, RotaryEmbedding
-from internlm.model.linear import ColumnParallelLinearTorch, RowParallelLinearTorch, FSDPLinear
-
-import torch
-
-from typing import Any, Tuple
-from torch import Tensor
-from torch.nn import Module
-
-import torch.distributed as dist
+from internlm.model.linear import (
+    ColumnParallelLinearTorch,
+    FSTPLinear,
+    RowParallelLinearTorch,
+)
 
 
 # adpated from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py
 class _SeqAllToAll(torch.autograd.Function):
+    "sequence alltoall"
 
     @staticmethod
-    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
+    def forward(ctx: Any, group: dist.ProcessGroup, input_: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
 
         ctx.group = group
         ctx.scatter_idx = scatter_idx
@@ -40,7 +40,7 @@ class _SeqAllToAll(torch.autograd.Function):
 
         seq_world_size = dist.get_world_size(group)
 
-        input_list = [t.contiguous() for t in torch.tensor_split(input, seq_world_size, scatter_idx)]
+        input_list = [t.contiguous() for t in torch.tensor_split(input_, seq_world_size, scatter_idx)]
         output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
         # TODO Use all_to_all_single instead
         dist.all_to_all(output_list, input_list, group=group)
@@ -51,6 +51,7 @@ class _SeqAllToAll(torch.autograd.Function):
         return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None)
 
+
 # adpated from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py
 class DistributedAttention(torch.nn.Module):
     """Initialization.
 
@@ -73,7 +74,7 @@ class DistributedAttention(torch.nn.Module):
         second_gather_idx: int = 1,
     ) -> None:
 
-        super(DistributedAttention, self).__init__()
+        super().__init__()
         self.local_attn = local_attention
         self.spg = sequence_process_group
         self.first_scatter_idx = first_scatter_idx
@@ -82,7 +83,7 @@ class DistributedAttention(torch.nn.Module):
         self.second_gather_idx = second_gather_idx
 
     def forward(self, qkv: Tensor, **kwargs: Any) -> Tensor:
-        """ forward
+        """forward
 
         Arguments:
             query (Tensor): query input to the layer
@@ -93,24 +94,25 @@ class DistributedAttention(torch.nn.Module):
         Returns:
             * output (Tensor): context output
         """
         # TODO Merge three alltoall calls into one
+        # Evaluation
         if qkv.ndim == 5:
-            # in shape: [seq/tp_size, 3, head, head_dim]
+            # in shape: [batch, seq/tp_size, 3, head, head_dim]
             qkv = _SeqAllToAll.apply(self.spg, qkv, self.first_scatter_idx + 1, self.first_gather_idx + 1)
-            #out shape : [seq, head/tp_size, head_dim]
+            # out shape : [batch, seq, head/tp_size, head_dim]
             context_layer = self.local_attn(qkv, **kwargs)
-            # in shape: [seq, head/tp_size, head_dim]
-            output = _SeqAllToAll.apply(self.spg, context_layer, self.second_scatter_idx + 1, self.second_gather_idx + 1)
-        else:
+            # in shape: [batch, seq, head/tp_size, head_dim]
+            output = _SeqAllToAll.apply(
+                self.spg, context_layer, self.second_scatter_idx + 1, self.second_gather_idx + 1
+            )
+        else:  # training
             # in shape: [seq/tp_size, 3, head, head_dim]
             qkv = _SeqAllToAll.apply(self.spg, qkv, self.first_scatter_idx, self.first_gather_idx)
-            #out shape : [seq, head/tp_size, head_dim]
+            # out shape : [seq, head/tp_size, head_dim]
            context_layer = self.local_attn(qkv, **kwargs)
            # in shape: [seq, head/tp_size, head_dim]
            output = _SeqAllToAll.apply(self.spg, context_layer, self.second_scatter_idx, self.second_gather_idx)
 
-        #out e.g., [s/p::h]
+        # out e.g., [s/p::h]
         return output
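A shape-only model of what the two _SeqAllToAll calls in DistributedAttention.forward do, matching the in/out shape comments above. It assumes the first call scatters the head dimension (index 2) and gathers the sequence dimension (index 0), as the shape comments describe; only second_gather_idx's default of 1 is actually visible in the constructor hunk.

    def seq_alltoall_shape(shape, scatter_idx, gather_idx, world_size):
        # The scattered dimension is divided across the sequence-parallel group
        # while the gathered dimension is multiplied back.
        out = list(shape)
        out[scatter_idx] //= world_size
        out[gather_idx] *= world_size
        return out


    if __name__ == "__main__":
        # Training path with tp_size=4: qkv [seq/tp, 3, head, head_dim] -> [seq, 3, head/tp, head_dim]
        print(seq_alltoall_shape([128, 3, 32, 64], scatter_idx=2, gather_idx=0, world_size=4))
        # -> [512, 3, 8, 64]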
@@ -157,7 +159,7 @@ class MHA(nn.Module):
         use_flash_attn: bool = True,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
-        tp_mode: str = 'origin_tp',
+        tp_mode: str = "origin_tp",
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
@@ -185,7 +187,7 @@ class MHA(nn.Module):
             self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)
 
         # notice here should change bias=True
-        Wqkv_cls = ColumnParallelLinearTorch if tp_mode == 'origin_tp' else FSDPLinear
+        Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
         self.Wqkv = Wqkv_cls(
             embed_dim,
             3 * embed_dim,
@@ -201,12 +203,12 @@ class MHA(nn.Module):
         self.inner_cross_attn = inner_cross_attn_cls(
             causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
         )
-        if tp_mode == 'fstp':
+        if tp_mode == "fstp":
             self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group)
             self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group)
 
         # output projection always have the bias (for now)
-        out_proj_cls = RowParallelLinearTorch if tp_mode == 'origin_tp' else FSDPLinear
+        out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
         self.out_proj = out_proj_cls(
             embed_dim,
             embed_dim,
@@ -214,7 +216,6 @@ class MHA(nn.Module):
             sequence_parallel=gpc.config.parallel.sequence_parallel,
             **factory_kwargs,
         )
-
         # need to assign tp attribute so that internlm know it is tensor parallel module
         if gpc.get_world_size(ParallelMode.TENSOR) > 1:
             for name in ["out_proj", "Wqkv"]:
@@ -311,7 +312,6 @@ class MHA(nn.Module):
             qkv = rearrange(qkv, "t (three h d) -> t three h d", three=3, d=self.head_dim) # total x 3 x n_head x d
             qkv = self.rotary_emb(qkv, **kwargs)
             kwargs.pop("indexes")
-
             if inference_params is None:
                 if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
                     with torch.cuda.amp.autocast(dtype=torch.bfloat16):

@@ -3,18 +3,14 @@
 
 from typing import Optional
 
+import fused_dense_lib as fused_dense_cuda
 import torch
 import torch.nn.functional as F
-from flash_attn.utils.distributed import (
-    all_reduce_raw,
-    reduce_scatter_raw,
-)
+from flash_attn.utils.distributed import all_reduce_raw, reduce_scatter_raw
 from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.distributed import ProcessGroup
 
-import fused_dense_lib as fused_dense_cuda
-
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.utils.logger import get_logger
@@ -123,8 +119,9 @@ def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool =
     shape = list(input_.shape)
     shape[gather_dim] = shape[gather_dim] * world_size
     output = torch.empty(shape, dtype=input_.dtype, device=input_.device)
-    handle = torch.distributed.all_gather_into_tensor(output, input_.contiguous(),
-                                                      group=process_group, async_op=async_op)
+    handle = torch.distributed.all_gather_into_tensor(
+        output, input_.contiguous(), group=process_group, async_op=async_op
+    )
     return output, handle
 
 
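A usage sketch of the same primitive the reformatted all_gather_raw wraps; this is not code from the commit, and it needs an initialized torch.distributed process group, so it is illustrative only.

    import torch
    import torch.distributed as dist


    def gather_weight_async(weight_shard: torch.Tensor, process_group) -> torch.Tensor:
        world_size = dist.get_world_size(process_group)
        shape = list(weight_shard.shape)
        shape[0] *= world_size  # gather along dim 0; the real function takes a gather_dim argument
        full = torch.empty(shape, dtype=weight_shard.dtype, device=weight_shard.device)
        handle = dist.all_gather_into_tensor(
            full, weight_shard.contiguous(), group=process_group, async_op=True
        )
        # ...independent compute can overlap with the collective here...
        handle.wait()  # block before the gathered tensor is read
        return full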
@@ -137,11 +134,11 @@ def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias):
 
 # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
 class FusedDenseFunc(torch.autograd.Function):
+    "tp fused dense function"
 
     @staticmethod
     @custom_fwd
-    def forward(ctx, x, weight, bias, return_residual=False, process_group=None,
-                sequence_parallel=True, gather_dim=0):
+    def forward(ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True, gather_dim=0):
         """
         If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
         with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
@@ -171,7 +168,7 @@ class FusedDenseFunc(torch.autograd.Function):
         batch_dim = batch_shape.numel()
         # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
         if min(batch_dim, n, *weight.shape) > 65535 * 32:
-            raise RuntimeError('fused_dense only supports matrix dims <= 2M')
+            raise RuntimeError("fused_dense only supports matrix dims <= 2M")
         output = F.linear(total_x, weight, bias)
         if ctx.compute_weight_gradient:
             ctx.save_for_backward(x, weight)
@@ -184,7 +181,7 @@ class FusedDenseFunc(torch.autograd.Function):
     def backward(ctx, grad_output, *args):
         grad_output = grad_output.contiguous()
         if ctx.return_residual:
-            grad_input, = args
+            (grad_input,) = args
             grad_input = grad_input.contiguous()
         process_group = ctx.process_group
         sequence_parallel = ctx.sequence_parallel
@@ -197,7 +194,7 @@ class FusedDenseFunc(torch.autograd.Function):
             else:
                 total_x = x
         else:
-            weight, = ctx.saved_tensors
+            (weight,) = ctx.saved_tensors
             total_x = None
         batch_shape = grad_output.shape[:-1]
         batch_dim = batch_shape.numel()
@@ -206,8 +203,7 @@ class FusedDenseFunc(torch.autograd.Function):
             if not ctx.return_residual:
                 grad_input = F.linear(grad_output, weight.t())
             else:
-                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
-                                         grad_output, weight)
+                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight)
             grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
             if process_group is not None:
                 reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
@@ -282,7 +278,8 @@ class FusedDenseFuncTorch(FusedDenseFunc):
         return grad_input, grad_weight, grad_bias, None, None, None, None
 
 
-class FSDPFusedDenseFunc(torch.autograd.Function):
+class FSTPFusedDenseFunc(torch.autograd.Function):
+    "FSTP fused dense function"
 
     @staticmethod
     @custom_fwd
@@ -319,7 +316,7 @@ class FSDPFusedDenseFunc(torch.autograd.Function):
         batch_dim = batch_shape.numel()
         # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
         if min(batch_dim, n, *total_weight.shape) > 65535 * 32:
-            raise RuntimeError('fused_dense only supports matrix dims <= 2M')
+            raise RuntimeError("fused_dense only supports matrix dims <= 2M")
         output = F.linear(total_x, total_weight, total_bias)
         if ctx.compute_weight_gradient:
             ctx.save_for_backward(x, weight)
@@ -332,14 +329,14 @@ class FSDPFusedDenseFunc(torch.autograd.Function):
     def backward(ctx, grad_output, *args):
         grad_output = grad_output.contiguous()
         if ctx.return_residual:
-            grad_input, = args
+            (grad_input,) = args
             grad_input = grad_input.contiguous()
         process_group = ctx.process_group
         if ctx.compute_weight_gradient:
             x, weight = ctx.saved_tensors
             total_x = x
         else:
-            weight, = ctx.saved_tensors
+            (weight,) = ctx.saved_tensors
             total_x = None
         batch_shape = grad_output.shape[:-1]
         batch_dim = batch_shape.numel()
@@ -357,8 +354,7 @@ class FSDPFusedDenseFunc(torch.autograd.Function):
             if not ctx.return_residual:
                 grad_input = F.linear(grad_output, total_weight.t())
             else:
-                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
-                                         grad_output, total_weight)
+                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, total_weight)
             grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
         else:
             grad_input = None
@@ -399,12 +395,14 @@ def fused_dense_func_torch(
     return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel, gather_dim)
 
 
-def fsdp_fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
-                          return_residual: bool = False, process_group = None):
-    dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
-                      or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
+def fstp_fused_dense_func(
+    x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, return_residual: bool = False, process_group=None
+):
+    dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
+        x.dtype == torch.float32 and torch.is_autocast_enabled()
+    )
     if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
-        return FSDPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
+        return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
     else:
         assert process_group is None
         out = F.linear(x, weight, bias)
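A CPU-side check of the fallback branch of the renamed fstp_fused_dense_func, runnable as-is: with no process group and a plain float32 tensor outside autocast, dtype_eligible is False and the call reduces to F.linear.

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 16)
    weight = torch.randn(32, 16)
    bias = torch.zeros(32)
    dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
        x.dtype == torch.float32 and torch.is_autocast_enabled()
    )
    assert not dtype_eligible  # eager fp32 on CPU
    out = F.linear(x, weight, bias)  # what the else-branch computes
    print(out.shape)  # torch.Size([4, 32])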
@@ -54,7 +54,7 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape
 def switch_sequence_parallel_mode():
     prev_mode = gpc.config.parallel.sequence_parallel
     try:
-        if gpc.config.parallel["tensor"]["mode"] == 'fstp':
+        if gpc.config.parallel["tensor"]["mode"] == "fstp":
             gpc.config.parallel.sequence_parallel = True
         else:
             gpc.config.parallel.sequence_parallel = False
@@ -106,10 +106,14 @@ def evaluate_on_val_dls(
             total_val_bsz = len(batch[1])
             assert total_val_bsz % data_cfg.micro_bsz == 0
             num_microbatches = total_val_bsz // data_cfg.micro_bsz
-            if gpc.config.parallel['tensor']['mode'] == 'fstp':
+            if gpc.config.parallel["tensor"]["mode"] == "fstp":
                 sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
                 tensor_shape = torch.Size(
-                    [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1] // sequence_world_size, gpc.config.HIDDEN_SIZE]
+                    [
+                        data_cfg.micro_bsz,
+                        batch[0]["input_ids"].shape[1] // sequence_world_size,
+                        gpc.config.HIDDEN_SIZE,
+                    ]
                 )
             else:
                 tensor_shape = torch.Size(
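A worked example of the fstp branch above, with illustrative values not taken from the commit: with micro_bsz=1, a sequence length of 4096, hidden size 4096 and a tensor-parallel world size of 8, each rank's pipeline communication tensor covers 4096 // 8 = 512 positions.

    import torch

    micro_bsz, seq_len, hidden_size, sequence_world_size = 1, 4096, 4096, 8
    tensor_shape = torch.Size([micro_bsz, seq_len // sequence_world_size, hidden_size])
    print(tensor_shape)  # torch.Size([1, 512, 4096])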