ColossalAI/colossalai/nn/layer/_common_utils.py

64 lines
1.8 KiB
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import torch
from torch import Tensor
from torch import nn
from colossalai.utils import checkpoint
from colossalai.constants import IS_TENSOR_PARALLEL
def divide(numerator, denominator):
""" only allow exact division """
assert numerator % denominator == 0, \
'{} is not divisible by {}'.format(numerator, denominator)
return numerator // denominator
def gelu(x: Tensor) -> Tensor:
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def swish(x: Tensor) -> Tensor:
return x * torch.sigmoid(x)
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
def set_tensor_parallel_attribute(param):
if not hasattr(param, IS_TENSOR_PARALLEL):
setattr(param, IS_TENSOR_PARALLEL, True)
class CheckpointModule(nn.Module):
def __init__(self, checkpoint: bool = True):
super().__init__()
self.checkpoint = checkpoint
self._use_checkpoint = checkpoint
def _forward(self, *args):
raise NotImplementedError(
'CheckpointModule should implement _forward method instead of origin forward')
def forward(self, *args):
if self._use_checkpoint:
return checkpoint(self._forward, *args)
else:
return self._forward(*args)
def train(self, mode: bool = True):
self._use_checkpoint = self.checkpoint
return super().train(mode=mode)
def eval(self):
self._use_checkpoint = False
return super().eval()