pull/155/head
yingtongxiong 2023-07-31 20:36:04 +08:00
parent 570e30a6bc
commit e551b4dffc
4 changed files with 74 additions and 52 deletions

View File

@@ -155,15 +155,21 @@ def args_sanity_check():
elif gpc.config.model.dtype in ("torch.float16", "torch.half"):
gpc.config.model.dtype = torch.float16
elif gpc.config.model.dtype == "torch.float32":
assert gpc.config.model.use_flash_attn == False, "when using float32, the use_flash_attn must be False"
assert gpc.config.model.use_flash_attn is False, "when using float32, the use_flash_attn must be False"
gpc.config.model.dtype = torch.float32
elif gpc.config.model.dtype == "torch.tf32":
assert gpc.config.model.use_flash_attn == False, "when using tf32, the use_flash_attn must be False"
assert gpc.config.model.use_flash_attn is False, "when using tf32, the use_flash_attn must be False"
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
gpc.config.model.dtype = torch.float32
else:
assert gpc.config.model.dtype in ["torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"]
assert gpc.config.model.dtype in [
"torch.float16",
"torch.half",
"torch.bfloat16",
"torch.float32",
"torch.tf32",
]
if gpc.is_rank_for_log():
logger.info("+" * 15 + " Model Info " + "+" * 15) # pylint: disable=W1201
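For context, this branch of the sanity check maps the config string onto a torch dtype and, for "torch.tf32", also flips the TF32 backend flags (TF32 keeps float32 storage but lets cuDNN/cuBLAS use TF32 math). A minimal standalone sketch of the same mapping, using a plain dict in place of gpc.config.model (hypothetical helper, not part of this diff):

import torch

def resolve_dtype(model_cfg: dict) -> torch.dtype:
    # model_cfg stands in for gpc.config.model; keys mirror the check above.
    name = model_cfg["dtype"]
    if name in ("torch.float16", "torch.half"):
        return torch.float16
    if name == "torch.bfloat16":
        return torch.bfloat16
    if name == "torch.float32":
        assert model_cfg["use_flash_attn"] is False, "float32 requires use_flash_attn=False"
        return torch.float32
    if name == "torch.tf32":
        assert model_cfg["use_flash_attn"] is False, "tf32 requires use_flash_attn=False"
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True
        return torch.float32  # TF32 is still stored as float32
    raise ValueError(f"unsupported dtype string: {name}")

print(resolve_dtype({"dtype": "torch.half", "use_flash_attn": True}))  # torch.float16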

View File

@@ -5,12 +5,13 @@ from typing import Optional
import torch
import torch.nn.functional as F
from torch.distributed import ProcessGroup
from torch import nn
from torch.distributed import ProcessGroup
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.utils import fused_dense_func_torch, reduce_scatter, all_reduce
from internlm.model.utils import all_reduce, fused_dense_func_torch, reduce_scatter
class ScaleColumnParallelLinear(nn.Linear):
"""
@@ -109,15 +110,20 @@ class RewardModelLinear(ScaleColumnParallelLinear):
class ColumnParallelLinear(nn.Linear):
def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup,
bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None:
def __init__(
self,
in_features: int,
out_features: int,
process_group: ProcessGroup,
bias: bool = True,
sequence_parallel=True,
device=None,
dtype=None,
) -> None:
world_size = torch.distributed.get_world_size(process_group)
if out_features % world_size != 0:
raise ValueError(f'out_features ({out_features}) must be divisible by '
f'world_size ({world_size})')
super().__init__(in_features, out_features // world_size, bias=bias,
device=device, dtype=dtype)
raise ValueError(f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})")
super().__init__(in_features, out_features // world_size, bias=bias, device=device, dtype=dtype)
self.process_group = process_group
self.sequence_parallel = sequence_parallel
@@ -126,22 +132,28 @@ class ColumnParallelLinear(nn.Linear):
# we do an all_gather of x before doing the matmul.
# If not, then the input is already gathered.
return fused_dense_func_torch(x, self.weight, self.bias, process_group=self.process_group,
sequence_parallel=self.sequence_parallel)
return fused_dense_func_torch(
x, self.weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
)
class RowParallelLinear(nn.Linear):
def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup,
bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None:
def __init__(
self,
in_features: int,
out_features: int,
process_group: ProcessGroup,
bias: bool = True,
sequence_parallel=True,
device=None,
dtype=None,
) -> None:
world_size = torch.distributed.get_world_size(process_group)
rank = torch.distributed.get_rank(process_group)
if in_features % world_size != 0:
raise ValueError(f'in_features ({in_features}) must be divisible by '
f'world_size ({world_size})')
raise ValueError(f"in_features ({in_features}) must be divisible by " f"world_size ({world_size})")
# Only rank 0 will have bias
super().__init__(in_features // world_size, out_features, bias=bias and rank == 0,
device=device, dtype=dtype)
super().__init__(in_features // world_size, out_features, bias=bias and rank == 0, device=device, dtype=dtype)
self.process_group = process_group
self.sequence_parallel = sequence_parallel
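As a quick illustration of the two partitioning schemes being reformatted here: ColumnParallelLinear shards the output dimension, so each rank owns out_features // world_size columns, while RowParallelLinear shards the input dimension and keeps the bias only on rank 0. A shape-only sketch with plain nn.Linear (hypothetical sizes, no process group involved):

import torch
from torch import nn

in_features, out_features, world_size, rank = 1024, 4096, 4, 0
assert out_features % world_size == 0 and in_features % world_size == 0

# Column parallelism: each rank holds a slice of the output features.
col_shard = nn.Linear(in_features, out_features // world_size)
print(col_shard.weight.shape)  # torch.Size([1024, 1024]) -- out_features is sharded

# Row parallelism: each rank holds a slice of the input features; only rank 0 keeps the bias.
row_shard = nn.Linear(in_features // world_size, out_features, bias=(rank == 0))
print(row_shard.weight.shape)  # torch.Size([4096, 256])  -- in_features is sharded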

View File

@@ -19,6 +19,7 @@ from internlm.core.context import global_context as gpc
from internlm.model.embedding import RotaryEmbedding
from internlm.model.linear import ColumnParallelLinear, RowParallelLinear
class MHA(nn.Module):
"""
Multi-head self-attention and cross-attention.

View File

@@ -1,16 +1,15 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional
import torch
import torch.nn.functional as F
from typing import Optional
from flash_attn.ops.fused_dense import FusedDenseFunc
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup
from flash_attn.ops.fused_dense import FusedDenseFunc
from internlm.core.context import global_context as gpc
@@ -80,29 +79,32 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
def gather_forward_split_backward(input_, parallel_mode, dim):
return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
world_size = torch.distributed.get_world_size(process_group)
output = torch.empty(world_size * input_.shape[0], *input_.shape[1:],
dtype=input_.dtype, device=input_.device)
handle = torch.distributed.all_gather_into_tensor(output, input_.contiguous(),
group=process_group, async_op=async_op)
output = torch.empty(world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device)
handle = torch.distributed.all_gather_into_tensor(
output, input_.contiguous(), group=process_group, async_op=async_op
)
return output, handle
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
world_size = torch.distributed.get_world_size(process_group)
assert input_.shape[0] % world_size == 0
output = torch.empty(input_.shape[0] // world_size, *input_.shape[1:],
dtype=input_.dtype, device=input_.device)
handle = torch.distributed.reduce_scatter_tensor(output, input_.contiguous(),
group=process_group,
async_op=async_op)
output = torch.empty(input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device)
handle = torch.distributed.reduce_scatter_tensor(
output, input_.contiguous(), group=process_group, async_op=async_op
)
return output, handle
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
input_ = input_.contiguous()
handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
return input_, handle
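Shape-wise, the three raw collectives above behave as follows: all_gather_raw turns each rank's (N, ...) shard into (world_size * N, ...), reduce_scatter_raw sums across ranks and hands each rank back an (M // world_size, ...) slice, and all_reduce_raw sums in place without changing the shape. A single-process simulation of just those semantics (plain tensor ops standing in for the process group, not how the collectives actually execute):

import torch

world_size = 4

# all_gather_raw: every rank receives the shards concatenated along dim 0.
shards = [torch.full((2, 8), float(r)) for r in range(world_size)]
gathered = torch.cat(shards, dim=0)
assert gathered.shape == (8, 8)

# reduce_scatter_raw: elementwise sum across ranks, then rank r keeps chunk r of dim 0.
full = [torch.randn(8, 8) for _ in range(world_size)]
summed = torch.stack(full).sum(dim=0)
assert summed.chunk(world_size, dim=0)[0].shape == (2, 8)

# all_reduce_raw: elementwise sum, shape unchanged on every rank.
assert torch.stack(full).sum(dim=0).shape == (8, 8)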
class ReduceScatterFunc(torch.autograd.Function):
"""Reduce scatter the input from the sequence parallel region and concatenate."""
@@ -137,6 +139,7 @@ reduce_scatter = ReduceScatterFunc.apply
# Supports autograd, but does not support async
all_reduce = AllReduceFunc.apply
def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
assert input.dtype == grad_output.dtype
grad_weight = torch.matmul(grad_output.t(), input)
@@ -145,11 +148,9 @@ def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
class FusedDenseFuncTorch(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, x, weight, bias, return_residual=False, process_group=None,
sequence_parallel=True):
def forward(ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True):
"""
If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
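That docstring is the heart of the sequence-parallel forward: each rank holds only its slice of the sequence, so x is all-gathered along dim 0 before the local matmul against the rank's weight shard. A single-process sketch of the data flow (torch.cat standing in for all_gather_raw; hypothetical sizes):

import torch
import torch.nn.functional as F

world_size, seqlen_per_rank, hidden = 4, 32, 256
weight_shard = torch.randn(512, hidden)  # this rank's slice of a column-parallel weight

# Each rank starts with its own (seqlen_per_rank, hidden) slice of the sequence.
x_local = [torch.randn(seqlen_per_rank, hidden) for _ in range(world_size)]

total_x = torch.cat(x_local, dim=0)      # all_gather_raw: (world_size * seqlen_per_rank, hidden)
out = F.linear(total_x, weight_shard)    # (128, 512): full sequence, local output shard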
@@ -178,7 +179,7 @@ class FusedDenseFuncTorch(torch.autograd.Function):
batch_dim = batch_shape.numel()
# https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
if min(batch_dim, n, *weight.shape) > 65535 * 32:
raise RuntimeError('fused_dense only supports matrix dims <= 2M')
raise RuntimeError("fused_dense only supports matrix dims <= 2M")
output = F.linear(total_x, weight, bias)
if ctx.compute_weight_gradient:
ctx.save_for_backward(x, weight)
@@ -191,7 +192,7 @@ class FusedDenseFuncTorch(torch.autograd.Function):
def backward(ctx, grad_output, *args):
grad_output = grad_output.contiguous()
if ctx.return_residual:
grad_input, = args
(grad_input,) = args
grad_input = grad_input.contiguous()
process_group = ctx.process_group
sequence_parallel = ctx.sequence_parallel
@@ -202,7 +203,7 @@ class FusedDenseFuncTorch(torch.autograd.Function):
else:
total_x = x
else:
weight, = ctx.saved_tensors
(weight,) = ctx.saved_tensors
total_x = None
batch_shape = grad_output.shape[:-1]
batch_dim = batch_shape.numel()
@@ -211,8 +212,7 @@ class FusedDenseFuncTorch(torch.autograd.Function):
if not ctx.return_residual:
grad_input = F.linear(grad_output, weight.t())
else:
grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
grad_output, weight)
grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight)
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
if process_group is not None:
reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
@@ -234,15 +234,18 @@ class FusedDenseFuncTorch(torch.autograd.Function):
return grad_input, grad_weight, grad_bias, None, None, None
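The torch fallback computes the weight and bias gradients explicitly via linear_bias_wgrad_torch (grad_weight = grad_output.t() @ input, and presumably grad_bias = grad_output.sum(0), the standard linear backward). A quick autograd cross-check of those formulas (standalone sketch, not part of the diff):

import torch
import torch.nn.functional as F

x = torch.randn(6, 4, requires_grad=True)   # (batch, in_features)
w = torch.randn(3, 4, requires_grad=True)   # (out_features, in_features)
b = torch.randn(3, requires_grad=True)

out = F.linear(x, w, b)
grad_out = torch.randn_like(out)
out.backward(grad_out)

grad_w_manual = grad_out.t() @ x            # same formula as linear_bias_wgrad_torch
grad_b_manual = grad_out.sum(dim=0)

assert torch.allclose(w.grad, grad_w_manual, atol=1e-5)
assert torch.allclose(b.grad, grad_b_manual, atol=1e-5)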
def fused_dense_func_torch(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
return_residual: bool = False, process_group: Optional[ProcessGroup] = None,
sequence_parallel: bool = True):
dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
def fused_dense_func_torch(
x: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
return_residual: bool = False,
process_group: Optional[ProcessGroup] = None,
sequence_parallel: bool = True,
):
dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
x.dtype == torch.float32 and torch.is_autocast_enabled()
)
if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group,
sequence_parallel)
return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
else:
return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group,
sequence_parallel)
return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
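The dispatch at the end is the key behavior to keep in mind: the flash-attn FusedDenseFunc kernel is only taken when every tensor is on CUDA and the dtype is fp16/bf16 (or fp32 under autocast); anything else falls through to the pure-torch FusedDenseFuncTorch path. A standalone restatement of that eligibility test (hypothetical helper, mirroring but not importing the code above):

import torch

def use_fused_kernel(x: torch.Tensor, weight: torch.Tensor, bias=None) -> bool:
    dtype_eligible = x.dtype in (torch.float16, torch.bfloat16) or (
        x.dtype == torch.float32 and torch.is_autocast_enabled()
    )
    return x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible

x, w = torch.randn(8, 16), torch.randn(32, 16)
print(use_fused_kernel(x, w))  # False: CPU fp32 falls back to FusedDenseFuncTorch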