mirror of https://github.com/InternLM/InternLM

commit 189a313da6 (parent bd4af3a31f)

    support fstp and refactor code
@@ -5,7 +5,7 @@ SEQ_LEN = 2048
 HIDDEN_SIZE = 4096
 NUM_ATTENTION_HEAD = 32
 MLP_RATIO = 8 / 3
-NUM_LAYER = 32
+NUM_LAYER = 4
 VOCAB_SIZE = 103168
 
 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
@@ -55,7 +55,7 @@ data = dict(
     # defaults to the value of micro_num
     valid_micro_num=4,
     # defaults to 0, means disable evaluate
-    valid_every=1000,
+    valid_every=10,
     pack_sample_into_one=False,
     total_steps=50000,
     skip_batches="",
@@ -64,7 +64,7 @@ data = dict(
     min_length=50,
     # train_folder=TRAIN_FOLDER,
     # valid_folder=VALID_FOLDER,
-    empty_cache_and_diag_interval=10,
+    empty_cache_and_diag_interval=100,
     diag_outlier_ratio=1.1,
 )
 
@@ -135,7 +135,7 @@ model = dict(
     num_layers=NUM_LAYER,
     mlp_ratio=MLP_RATIO,
     apply_post_layer_norm=False,
-    dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
+    dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
     norm_type="rmsnorm",
     layer_norm_epsilon=1e-5,
     use_flash_attn=True,
@@ -155,7 +155,7 @@ tensor parallel: tensor parallel size, usually the number of GPUs per node.
 """
 parallel = dict(
     zero1=-1,
-    tensor=2,
+    tensor=dict(size=2, mode='fstp'), # the mode should be 'origin_tp' or 'fstp'
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=True,
 )
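
Note: a minimal sketch of the two tensor-parallel configurations this change accepts (illustrative, not part of the commit). 'origin_tp' keeps the original tensor-parallel path and behaves like the old tensor=2 form, while 'fstp' selects the fully sharded path added below and, per the sanity check in this commit, requires sequence_parallel=True:

    parallel = dict(
        zero1=-1,
        tensor=dict(size=2, mode='origin_tp'),  # same behaviour as the old tensor=2
        # tensor=dict(size=2, mode='fstp'),     # fully sharded tensor parallel (needs sequence_parallel=True)
        pipeline=dict(size=1, interleaved_overlap=True),
        sequence_parallel=True,
    )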
@@ -568,7 +568,8 @@ class ParallelContext(metaclass=SingletonMeta):
         # during model construction), this is because the random state will be different in different tensor parallel
         # device of the same data parallel group. The underlying reason is that the device of tp_rank = 0 will perform
         # additional random operations during the RowParallelLinear module building process.
-        set_mode(ParallelMode.DUMMY)
+        # set_mode(ParallelMode.DUMMY)
+        set_mode(ParallelMode.TENSOR)
 
         seeds = get_seeds()
         seed_str = ", ".join([f"{k}: {v}" for k, v in seeds.items()])
@@ -279,6 +279,12 @@ def args_sanity_check():
     assert not (
         gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False
     ), "sequence parallel does not support use_flash_attn=False"
 
+    if gpc.config.parallel["tensor"].get("mode", None) is None:
+        gpc.config.parallel["tensor"]["mode"] = "origin_tp"
+
+    if gpc.config.parallel["tensor"].get("mode", None) == 'fstp':
+        assert gpc.config.parallel.sequence_parallel is True, "when the tp_mode is fstp, the sequence_parallel should be True."
+
     # currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
     if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1:
@@ -4,44 +4,20 @@
 from typing import Optional
 
 import torch
+import torch.nn.functional as F
 from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
 from flash_attn.utils.distributed import all_reduce, reduce_scatter, all_gather_raw, reduce_scatter_raw
 from torch import Tensor
 from torch import nn
 from torch.cuda.amp import custom_bwd, custom_fwd
 
+# import fused_dense_cuda # from apex
+import fused_dense_lib as fused_dense_cuda
 
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.utils import Silu, fused_dense_func_torch
 
-from typing import Optional
-from functools import partial
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch import Tensor
-from torch.distributed import ProcessGroup
-from torch.cuda.amp import custom_bwd, custom_fwd
-
-# import fused_dense_cuda # from apex
-import fused_dense_lib as fused_dense_cuda
-
-from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_fwd, sqrelu_bwd
-from flash_attn.utils.distributed import all_gather_raw, all_reduce_raw
-# reduce_scatter_raw
-from flash_attn.utils.distributed import reduce_scatter, all_reduce
-
-def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False, op=torch.distributed.ReduceOp.SUM):
-    world_size = torch.distributed.get_world_size(process_group)
-    assert input_.shape[0] % world_size == 0
-    output = torch.empty(
-        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
-    )
-    handle = torch.distributed.reduce_scatter_tensor(
-        output, input_.contiguous(), op=op, group=process_group, async_op=async_op
-    )
-    return output, handle
 
 class ScaleColumnParallelLinear(nn.Linear):
     """
@@ -231,7 +207,7 @@ class FeedForward(nn.Module):
         out = self.w3(Silu(w1_o, w2_o))
         return out
 
-class FusedDenseFunc_fsdp(torch.autograd.Function):
+class FSDPFusedDenseFunc(torch.autograd.Function):
 
     @staticmethod
     @custom_fwd
@@ -243,21 +219,26 @@ class FusedDenseFunc_fsdp(torch.autograd.Function):
 
         if torch.is_autocast_enabled():
             x = x.to(dtype=torch.get_autocast_gpu_dtype())
-        x = x.contiguous()
-        total_x = x
+        total_x = x.contiguous()
 
-        # do all_gather for weight and bias before actual computation
-        total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
-        if bias is not None:
-            total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
-            handle_bias.wait()
+        world_size = gpc.get_world_size(ParallelMode.TENSOR)
+        if world_size > 1:
+            # do all_gather for weight and bias before actual computation
+            total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
+            if bias is not None:
+                total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True)
+                handle_bias.wait()
+            else:
+                total_bias = bias
+            handle_weight.wait()
         else:
+            total_weight = weight
             total_bias = bias
 
         if torch.is_autocast_enabled():
             total_weight = total_weight.to(dtype=torch.get_autocast_gpu_dtype())
             total_bias = total_bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
-        handle_weight.wait()
         total_weight = total_weight.contiguous()
         batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
         batch_dim = batch_shape.numel()
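
Note: the forward pass above follows a ZeRO/FSDP-style pattern: each rank stores only a shard of the weight, all-gathers the full weight just before the matmul, and computes on the gathered copy. A minimal stand-alone sketch of that pattern with plain torch.distributed calls (illustrative; the commit itself uses flash_attn's all_gather_raw and the fused_dense CUDA kernel):

    import torch
    import torch.distributed as dist
    import torch.nn.functional as F

    def fstp_style_linear_forward(x, weight_shard, bias, process_group):
        # Gather the full weight from its dim-0 shards, then run the linear.
        # (Bias handling is simplified; the commit also all-gathers a sharded bias.)
        world_size = dist.get_world_size(process_group)
        if world_size > 1:
            total_weight = torch.empty(
                weight_shard.shape[0] * world_size,
                *weight_shard.shape[1:],
                dtype=weight_shard.dtype,
                device=weight_shard.device,
            )
            dist.all_gather_into_tensor(total_weight, weight_shard.contiguous(), group=process_group)
        else:
            total_weight = weight_shard
        return F.linear(x, total_weight, bias)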
@@ -289,9 +270,13 @@ class FusedDenseFunc_fsdp(torch.autograd.Function):
         batch_dim = batch_shape.numel()
         grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
 
-        # do all-gather for weight before backward
-        total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
-        handle_weight.wait()
+        world_size = gpc.get_world_size(ParallelMode.TENSOR)
+        if world_size > 1:
+            # do all-gather for weight before backward
+            total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
+            handle_weight.wait()
+        else:
+            total_weight = weight
 
         if ctx.needs_input_grad[0]:
             if not ctx.return_residual:
@@ -300,32 +285,24 @@ class FusedDenseFunc_fsdp(torch.autograd.Function):
                 grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
                                          grad_output, total_weight)
             grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
-            # if process_group is not None:
-            # import pdb; pdb.set_trace()
-            # grad_input, handle_grad_input = reduce_scatter_raw(grad_input, process_group, async_op=True)
-            # grad_input, handle_grad_input = all_reduce_raw(grad_input, process_group, async_op=True)
 
         else:
             grad_input = None
-            # import pdb; pdb.set_trace()
         if ctx.needs_input_grad[1]:
             assert ctx.compute_weight_gradient
 
             grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
                 total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
             )
-            grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
-            if grad_bias is not None:
-                grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
-                handle_grad_bias.wait()
-            handle_grad_weight.wait()
+            if world_size > 1:
+                grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
+                if grad_bias is not None:
+                    grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
+                    handle_grad_bias.wait()
+                handle_grad_weight.wait()
         else:
             grad_weight = None
             grad_bias = grad_output if ctx.needs_input_grad[2] else None
-        # if process_group is not None and ctx.needs_input_grad[0]:
-        # handle_grad_input.wait()
-        # import pdb; pdb.set_trace()
         return grad_input, grad_weight, grad_bias, None, None, None
 
 
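
Note: in the backward pass each rank produces a gradient for the full, gathered weight; reduce-scatter then sums these gradients across ranks and hands every rank only the slice matching its local shard. A rough stand-alone sketch (plain torch.distributed; the commit uses the asynchronous reduce_scatter_raw helper from flash_attn):

    import torch
    import torch.distributed as dist

    def shard_weight_grad(full_grad_weight, process_group):
        # Sum-reduce the full gradient across the tensor-parallel group and keep
        # only this rank's dim-0 slice, matching the locally stored weight shard.
        world_size = dist.get_world_size(process_group)
        if world_size == 1:
            return full_grad_weight
        grad_shard = torch.empty(
            full_grad_weight.shape[0] // world_size,
            *full_grad_weight.shape[1:],
            dtype=full_grad_weight.dtype,
            device=full_grad_weight.device,
        )
        dist.reduce_scatter_tensor(grad_shard, full_grad_weight.contiguous(), group=process_group)
        return grad_shard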
@@ -334,7 +311,7 @@ def fsdp_fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = No
     dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                       or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
     if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
-        return FusedDenseFunc_fsdp.apply(x, weight, bias, return_residual, process_group)
+        return FSDPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
     else:
         assert process_group is None
         out = F.linear(x, weight, bias)
@@ -426,5 +403,5 @@ class FSDPFeedForward(nn.Module):
     def forward(self, x):
         w1_o = self.w1(x)
         w2_o = self.w2(x)
-        out = self.w3(Silu(w1_o, w2_o))
+        out = self.w3(F.silu(w1_o) * w2_o)
         return out
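
Note: this change appears behaviour-preserving, assuming InternLM's Silu helper computes silu(w1_o) * w2_o (as its use in FeedForward suggests); a quick equivalence check under that assumption:

    import torch
    import torch.nn.functional as F

    def silu_gate(w1_o, w2_o):
        # assumed definition of internlm.model.utils.Silu
        return F.silu(w1_o) * w2_o

    w1_o, w2_o = torch.randn(2, 8), torch.randn(2, 8)
    assert torch.allclose(silu_gate(w1_o, w2_o), F.silu(w1_o) * w2_o)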
@@ -74,6 +74,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         use_scaled_init: bool = True,
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
+        tp_mode: str = 'origin_tp',
     ):
         super().__init__()
         self.checkpoint = checkpoint
@@ -98,6 +99,7 @@ class PackedFlashBaseLayer1D(nn.Module):
             use_flash_attn=use_flash_attn,
             device=device,
             dtype=dtype,
+            tp_mode=tp_mode,
         )
 
         self.dropout1 = nn.Dropout(drop_rate)
@@ -109,16 +111,8 @@ class PackedFlashBaseLayer1D(nn.Module):
         self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
 
         if use_swiglu:
-            # self.mlp = FeedForward(
-            # hidden_size,
-            # int(hidden_size * mlp_ratio),
-            # out_features=hidden_size,
-            # process_group=gpc.get_group(ParallelMode.TENSOR),
-            # bias=False,
-            # device=device,
-            # dtype=dtype,
-            # )
-            self.mlp = FSDPFeedForward(
+            mlp_cls = FeedForward if tp_mode == 'origin_tp' else FSDPFeedForward
+            self.mlp = mlp_cls(
                 hidden_size,
                 int(hidden_size * mlp_ratio),
                 out_features=hidden_size,
@@ -179,6 +173,7 @@ class PackedFlashBaseLayer1D(nn.Module):
                 else:
                     normal_(std=0.006 if "fc1" in name else 0.0015)(param.data)
 
+
     def forward(self, hidden_states, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None):
         if self.checkpoint and self.training:
             return activation_checkpoint(
@@ -300,12 +295,12 @@ class PackedFlashInternLm1D(nn.Module):
         super().__init__()
 
         checkpoint_layer_num = int(num_layers * checkpoint)
+        self.tp_mode = gpc.config.parallel["tensor"]["mode"]
 
         if is_reward:
             head_cls = RewardModelLinear
         else:
-            # head_cls = ScaleColumnParallelLinear
-            head_cls = FSDPScaleLinear
+            head_cls = ScaleColumnParallelLinear
         if first:
             if embed_split_hidden:
                 self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
@@ -346,6 +341,7 @@ class PackedFlashInternLm1D(nn.Module):
                     use_scaled_init=use_scaled_init,
                     use_swiglu=use_swiglu,
                     use_flash_attn=use_flash_attn,
+                    tp_mode=self.tp_mode,
                 )
                 for lid in range(num_layers)
             ]
@@ -391,7 +387,8 @@ class PackedFlashInternLm1D(nn.Module):
             assert len(indexes) == 1
             # The indexes are used to indicate the actual position IDs of each token in the packed input.
             indexes = indexes[0]
-            if gpc.config.parallel.sequence_parallel:
+            # if the tensor parallel mode is 'fstp', the indexes should also be split in sequence dimension.
+            if gpc.config.parallel.sequence_parallel and self.tp_mode == 'fstp':
                 indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0)
 
         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
@@ -408,8 +405,12 @@ class PackedFlashInternLm1D(nn.Module):
         if hasattr(self, "norm"):
             hidden_states = self.norm(hidden_states.float())
         if hasattr(self, "head"):
+            # if hidden_states.ndim == 3:
+            # import pdb; pdb.set_trace()
+            # hidden_states = self.head(hidden_states, dim=1)
+            # else:
+            # hidden_states = self.head(hidden_states)
             hidden_states = self.head(hidden_states)
-            hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=0)
 
         if not self.parallel_output:
             hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
@@ -57,49 +57,29 @@ class DistributedAttention(torch.nn.Module):
     Arguments:
         local_attention (Module): local attention with q,k,v
         sequence_process_group (ProcessGroup): sequence parallel process group
-        scatter_idx (int): scatter_idx for all2all comm
-        gather_idx (int): gather_idx for all2all comm
+        first_scatter_idx (int): scatter_idx for the first all2all comm
+        first_gather_idx (int): gather_idx for the first all2all comm
+        second_scatter_idx (int): scatter_idx for the second all2all comm
+        second_gather_idx (int): gather_idx for the second all2all comm
     """
 
     def __init__(
         self,
        local_attention: Module,
         sequence_process_group: dist.ProcessGroup,
-        scatter_idx: int = 2,
-        gather_idx: int = 0,
+        first_scatter_idx: int = 2,
+        first_gather_idx: int = 0,
+        second_scatter_idx: int = 0,
+        second_gather_idx: int = 1,
     ) -> None:
 
         super(DistributedAttention, self).__init__()
         self.local_attn = local_attention
         self.spg = sequence_process_group
-        self.scatter_idx = scatter_idx
-        self.gather_idx = gather_idx
-
-    # def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any) -> Tensor:
-    # """ forward
-
-    # Arguments:
-    # query (Tensor): query input to the layer
-    # key (Tensor): key input to the layer
-    # value (Tensor): value input to the layer
-    # args: other args
-
-    # Returns:
-    # * output (Tensor): context output
-    # """
-    # # TODO Merge three alltoall calls into one
-    # #in shape : e.g., [s/p:h:]
-    # query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx)
-    # key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
-    # value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)
-
-    # #out shape : e.g., [s:h/p:]
-    # context_layer = self.local_attn(query_layer, key_layer, value_layer, *args)
-
-    # output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)
-
-    # #out e.g., [s/p::h]
-    # return output
+        self.first_scatter_idx = first_scatter_idx
+        self.first_gather_idx = first_gather_idx
+        self.second_scatter_idx = second_scatter_idx
+        self.second_gather_idx = second_gather_idx
 
     def forward(self, qkv: Tensor, **kwargs: Any) -> Tensor:
         """ forward
|
@ -114,15 +94,21 @@ class DistributedAttention(torch.nn.Module):
|
||||||
* output (Tensor): context output
|
* output (Tensor): context output
|
||||||
"""
|
"""
|
||||||
# TODO Merge three alltoall calls into one
|
# TODO Merge three alltoall calls into one
|
||||||
#in shape : e.g., [s/p:h:]
|
if qkv.ndim == 5:
|
||||||
qkv = _SeqAllToAll.apply(self.spg, qkv, 2, 0)
|
# in shape: [seq/tp_size, 3, head, head_dim]
|
||||||
# key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
|
qkv = _SeqAllToAll.apply(self.spg, qkv, self.first_scatter_idx + 1, self.first_gather_idx + 1)
|
||||||
# value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)
|
#out shape : [seq, head/tp_size, head_dim]
|
||||||
|
context_layer = self.local_attn(qkv, **kwargs)
|
||||||
#out shape : e.g., [s:h/p:]
|
# in shape: [seq, head/tp_size, head_dim]
|
||||||
context_layer = self.local_attn(qkv, **kwargs)
|
output = _SeqAllToAll.apply(self.spg, context_layer, self.second_scatter_idx + 1, self.second_gather_idx + 1)
|
||||||
|
else:
|
||||||
output = _SeqAllToAll.apply(self.spg, context_layer, 0, 1)
|
|
||||||
|
# in shape: [seq/tp_size, 3, head, head_dim]
|
||||||
|
qkv = _SeqAllToAll.apply(self.spg, qkv, self.first_scatter_idx, self.first_gather_idx)
|
||||||
|
#out shape : [seq, head/tp_size, head_dim]
|
||||||
|
context_layer = self.local_attn(qkv, **kwargs)
|
||||||
|
# in shape: [seq, head/tp_size, head_dim]
|
||||||
|
output = _SeqAllToAll.apply(self.spg, context_layer, self.second_scatter_idx, self.second_gather_idx)
|
||||||
|
|
||||||
#out e.g., [s/p::h]
|
#out e.g., [s/p::h]
|
||||||
return output
|
return output
|
||||||
|
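
Note: a single-process shape walkthrough of the two all2all calls above for the packed 4-D branch, with torch.cat/chunk standing in for the actual communication (illustrative only; tp denotes the tensor-parallel world size):

    import torch

    tp, seq, heads, head_dim = 2, 8, 4, 16
    qkv_shard = torch.randn(seq // tp, 3, heads, head_dim)      # this rank's packed qkv slice

    # first all2all: gather the sequence dim (0), scatter the head dim (2)
    qkv_full_seq = torch.cat([qkv_shard] * tp, dim=0)           # -> [seq, 3, heads, head_dim]
    qkv_local_heads = torch.chunk(qkv_full_seq, tp, dim=2)[0]   # -> [seq, 3, heads/tp, head_dim]

    # local attention sees the whole sequence but only a shard of the heads
    context = qkv_local_heads[:, 0]                             # stand-in for the attention output

    # second all2all: gather the head dim (1), scatter the sequence dim (0)
    context_full_heads = torch.cat([context] * tp, dim=1)       # -> [seq, heads, head_dim]
    out_shard = torch.chunk(context_full_heads, tp, dim=0)[0]   # -> [seq/tp, heads, head_dim]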
@@ -171,6 +157,7 @@ class MHA(nn.Module):
         use_flash_attn: bool = True,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        tp_mode: str = 'origin_tp',
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
@@ -198,16 +185,8 @@ class MHA(nn.Module):
             self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)
 
         # notice here should change bias=True
-        # self.Wqkv = ColumnParallelLinearTorch(
-        # embed_dim,
-        # 3 * embed_dim,
-        # process_group,
-        # bias=True,
-        # sequence_parallel=gpc.config.parallel.sequence_parallel,
-        # **factory_kwargs,
-        # ) # according to https://spaces.ac.cn/archives/9577
-
-        self.Wqkv = FSDPLinear(
+        Wqkv_cls = ColumnParallelLinearTorch if tp_mode == 'origin_tp' else FSDPLinear
+        self.Wqkv = Wqkv_cls(
             embed_dim,
             3 * embed_dim,
             process_group,
@@ -222,25 +201,20 @@ class MHA(nn.Module):
         self.inner_cross_attn = inner_cross_attn_cls(
             causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
         )
-        self.inner_attn_sp = DistributedAttention(self.inner_attn, sequence_process_group=process_group, scatter_idx=3, gather_idx=0)
-        self.inner_cross_attn_sp = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group, scatter_idx=3, gather_idx=0)
+        if tp_mode == 'fstp':
+            self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=process_group)
+            self.inner_cross_attn = DistributedAttention(self.inner_cross_attn, sequence_process_group=process_group)
 
         # output projection always have the bias (for now)
-        # self.out_proj = RowParallelLinearTorch(
-        # embed_dim,
-        # embed_dim,
-        # process_group,
-        # sequence_parallel=gpc.config.parallel.sequence_parallel,
-        # **factory_kwargs,
-        # )
-        self.out_proj = FSDPLinear(
+        out_proj_cls = RowParallelLinearTorch if tp_mode == 'origin_tp' else FSDPLinear
+        self.out_proj = out_proj_cls(
             embed_dim,
             embed_dim,
             process_group,
             sequence_parallel=gpc.config.parallel.sequence_parallel,
             **factory_kwargs,
         )
 
         # need to assign tp attribute so that internlm know it is tensor parallel module
         if gpc.get_world_size(ParallelMode.TENSOR) > 1:
             for name in ["out_proj", "Wqkv"]:
@@ -343,11 +317,9 @@ class MHA(nn.Module):
                 with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                     if qkv.dtype not in [torch.float16, torch.bfloat16]:
                         qkv = qkv.to(torch.bfloat16)
-                    # context = self.inner_attn(qkv, **kwargs).to(x.dtype)
-                    context = self.inner_attn_sp(qkv, **kwargs).to(x.dtype)
+                    context = self.inner_attn(qkv, **kwargs).to(x.dtype)
             else:
-                # context = self.inner_attn(qkv, **kwargs)
-                context = self.inner_attn_sp(qkv, **kwargs)
+                context = self.inner_attn(qkv, **kwargs)
 
         else:
             raise RuntimeError("Not support this right now")
@@ -54,7 +54,10 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape
 def switch_sequence_parallel_mode():
     prev_mode = gpc.config.parallel.sequence_parallel
     try:
-        gpc.config.parallel.sequence_parallel = False
+        if gpc.config.parallel["tensor"]["mode"] == 'fstp':
+            gpc.config.parallel.sequence_parallel = True
+        else:
+            gpc.config.parallel.sequence_parallel = False
         yield
     finally:
         gpc.config.parallel.sequence_parallel = prev_mode
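
Note: a self-contained toy version of the switching logic above (a plain dict stands in for gpc.config), showing that sequence_parallel stays enabled during evaluation when the tensor-parallel mode is fstp and is restored afterwards:

    from contextlib import contextmanager

    config = {"parallel": {"tensor": {"size": 2, "mode": "fstp"}, "sequence_parallel": False}}

    @contextmanager
    def toy_switch_sequence_parallel_mode():
        prev = config["parallel"]["sequence_parallel"]
        try:
            config["parallel"]["sequence_parallel"] = config["parallel"]["tensor"]["mode"] == "fstp"
            yield
        finally:
            config["parallel"]["sequence_parallel"] = prev

    with toy_switch_sequence_parallel_mode():
        assert config["parallel"]["sequence_parallel"] is True   # kept on for fstp
    assert config["parallel"]["sequence_parallel"] is False      # previous value restored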