#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Optional

import torch
import torch.nn.functional as F
from flash_attn.ops.fused_dense import (
    ColumnParallelLinear,
    RowParallelLinear,
    fused_dense_func,
)
from torch import nn

from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context import global_context as gpc


class ScaleColumnParallelLinear(nn.Linear):
    """
    Column-parallel linear layer whose weight gradient can be rescaled.

    Each rank of ``process_group`` holds ``out_features // world_size`` output
    rows of the weight.

    Args:
        in_features (int): size of each input sample
        out_features (int): size of each output sample (before splitting across ranks)
        process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
        bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
                    in the config.
        sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
                                    we do an all_gather of x before doing the matmul.
                                    If not, then the input is already gathered.
        device (Optional[Union[str, torch.device]]): The device will be used.
        dtype (Optional[torch.dtype]): The type of data.
        weight_scale (float): Factor applied to the gradient flowing into the weight. For training stability.
                    1 by default, which disables the rescaling.

    Raises:
        ValueError: if ``out_features`` is not divisible by the world size of ``process_group``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: Optional[torch.distributed.ProcessGroup],
        bias: bool = True,
        sequence_parallel: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        weight_scale: float = 1,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        if out_features % world_size != 0:
            raise ValueError(f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})")
        # Only this rank's slice of the output dimension is allocated.
        super().__init__(in_features, out_features // world_size, bias=bias, device=device, dtype=dtype)
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.weight_scale = weight_scale

    def forward(self, input):  # pylint: disable=W0622
        """Apply the (optionally gradient-scaled) linear projection to ``input``."""
        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
        # we do an all_gather of x before doing the matmul.
        # If not, then the input is already gathered.
        if self.weight_scale != 1:
            # s * w + (1 - s) * w.detach() has the same value as w, but only the
            # first term carries gradient, so the weight gradient is scaled by s.
            weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
        else:
            weight = self.weight
        return fused_dense_func(
            input,
            weight,
            self.bias,
            process_group=self.process_group,
            sequence_parallel=self.sequence_parallel,
        )


class RewardModelLinear(ScaleColumnParallelLinear):
    """
    RewardModelLinear.

    Same layer as :class:`ScaleColumnParallelLinear`, except that right after
    construction the parameters are broadcast over ``process_group`` from the
    first rank of the tensor-parallel group. The forward pass is inherited
    unchanged.

    Args:
        in_features (int): size of each input sample
        out_features (int): size of each output sample
        process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
        bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
                    in the config.
        sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
                                    we do an all_gather of x before doing the matmul.
                                    If not, then the input is already gathered.
        device (Optional[Union[str, torch.device]]): The device will be used.
        dtype (Optional[torch.dtype]): The type of data.
        weight_scale (float): Factor applied to the gradient flowing into the weight. For training stability.
                    1 by default.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: Optional[torch.distributed.ProcessGroup],
        bias: bool = True,
        sequence_parallel: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        weight_scale: float = 1,
    ) -> None:
        super().__init__(in_features, out_features, process_group, bias, sequence_parallel, device, dtype, weight_scale)
        # Source of the broadcast: first global rank of the tensor-parallel group.
        src_rank = gpc.get_ranks_in_group(ParallelMode.TENSOR)[0]
        torch.distributed.broadcast(self.weight, src_rank, process_group)
        if self.bias is not None:
            torch.distributed.broadcast(self.bias, src_rank, process_group)


class FeedForward(nn.Module):
    """
    FeedForward.

    Gated feed-forward block computing ``w3(silu(w1(x)) * w2(x))``.

    Args:
        in_features (int): size of each input sample
        hidden_features (int): size of hidden state of FFN. Rounded up to the next multiple of ``multiple_of``.
        out_features (Optional[int]): size of each output sample. Falls back to ``in_features`` when None.
        process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
        bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
                    in the config.
        device (Optional[Union[str, torch.device]]): The device will be used.
        dtype (Optional[torch.dtype]): The type of data.
        multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: Optional[int] = None,
        process_group: Optional[torch.distributed.ProcessGroup] = None,
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        multiple_of: int = 256,
    ):
        super().__init__()

        # The signature defaults out_features to None; without a fallback that
        # None would be handed to RowParallelLinear as a layer size.
        if out_features is None:
            out_features = in_features

        # Round hidden_features up to the nearest multiple of `multiple_of`.
        hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)

        self.w1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            process_group,
            bias,
            sequence_parallel=False,
            device=device,
            dtype=dtype,
        )
        self.w2 = ColumnParallelLinear(
            in_features,
            hidden_features,
            process_group,
            bias,
            sequence_parallel=False,
            device=device,
            dtype=dtype,
        )
        self.w3 = RowParallelLinear(
            hidden_features,
            out_features,
            process_group,
            bias=bias,
            sequence_parallel=False,
            device=device,
            dtype=dtype,
        )
        # need to assign tp attribute so that colossalai know it is tensor parallel module
        if gpc.get_world_size(ParallelMode.TENSOR) > 1:
            for name in ["w1", "w2", "w3"]:
                for param in getattr(self, name).parameters():
                    setattr(param, IS_TENSOR_PARALLEL, True)

    def forward(self, x):
        """Return ``w3(silu(w1(x)) * w2(x))``."""
        out = self.w3(F.silu(self.w1(x)) * self.w2(x))
        return out