mirror of https://github.com/hpcaitech/ColossalAI
[kernel] move all symlinks of kernel to `colossalai._C` (#1971)
parent 7e24b9b9ee
commit f8a7148dec
@@ -38,7 +38,6 @@ jobs:
pip install -r requirements/requirements.txt
pip install -v -e .
cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
cp /__w/ColossalAI/ColossalAI/*.so /github/home/cuda_ext_cache/
pip install -r requirements/requirements-test.txt
- name: Unit Testing
  run: |
@@ -36,7 +36,6 @@ jobs:
pip install -r requirements/requirements.txt
pip install -v -e .
cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
cp /__w/ColossalAI/ColossalAI/*.so /github/home/cuda_ext_cache/
pip install -r requirements/requirements-test.txt
- name: Unit Testing
  run: |
@@ -1,3 +1,3 @@
include *.txt README.md
recursive-include requirements *.txt
recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc
recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc *.pyi
@@ -0,0 +1,9 @@
from . import (
    cpu_optim,
    fused_optim,
    layer_norm,
    moe,
    multihead_attention,
    scaled_masked_softmax,
    scaled_upper_triang_masked_softmax,
)
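The pattern used throughout this commit is a guarded import of the relocated extension modules. A minimal sketch of that pattern, not part of the diff itself (the alias and fallback are illustrative):

# Guarded import of a compiled kernel now living under colossalai._C.
try:
    import colossalai._C.fused_optim as fused_optim    # only available in a source build with CUDA
except ImportError:
    fused_optim = None    # callers fall back to a pure-PyTorch path, as the modules below do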
@@ -0,0 +1,8 @@
from torch import Tensor


class CPUAdamOptimizer:

    def __init__(self, lr: float, beta1: float, beta2: float, eps: float,
                 weight_decay: float, adamw_mode: float) -> None: ...

    def step(self, step: int, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, bias_correction: bool,
             param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor, loss_scale: float) -> None: ...
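For reference, a minimal sketch of driving the stubbed CPUAdamOptimizer directly; the tensors and hyper-parameters are illustrative and a source build of the extension is required:

import torch
import colossalai._C.cpu_optim as cpu_optim

param = torch.rand(64)
grad = torch.rand(64)
exp_avg = torch.zeros_like(param)
exp_avg_sq = torch.zeros_like(param)

opt = cpu_optim.CPUAdamOptimizer(1e-3, 0.9, 0.999, 1e-8, 0.0, True)
# step(step, lr, beta1, beta2, eps, weight_decay, bias_correction, param, grad, exp_avg, exp_avg_sq, loss_scale)
opt.step(1, 1e-3, 0.9, 0.999, 1e-8, 0.0, True, param, grad, exp_avg, exp_avg_sq, -1)    # -1: no loss scaling (illustrative)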
@@ -0,0 +1,23 @@
from typing import List

from torch import Tensor


def multi_tensor_scale(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], scale: float) -> None:
    ...


def multi_tensor_sgd(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], weight_decay: float,
                     momentum: float, dampening: float, lr: float, nesterov: bool, first_run: bool, weight_decay_after_momentum: bool, scale: float) -> None:
    ...


def multi_tensor_adam(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float, beta2: float, epsilon: float, step: int, mode: int, bias_correction: int, weight_decay: float) -> None:
    ...


def multi_tensor_lamb(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float, beta2: float, epsilon: float, step: int, bias_correction: int, weight_decay: float, grad_averaging: int, mode: int, global_grad_norm: Tensor, max_grad_norm: float, use_nvlamb_python: bool) -> None:
    ...


def multi_tensor_l2norm(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], per_tensor_python: bool) -> None:
    ...
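For reference, a sketch of how these entry points are invoked through multi_tensor_applier elsewhere in this diff; the applier prepends its chunk size, and the tensors below are illustrative:

import torch
import colossalai._C.fused_optim as fused_optim
from colossalai.utils import multi_tensor_applier

src = [torch.randn(8, device='cuda') for _ in range(3)]
dst = [torch.empty_like(t) for t in src]
overflow_buf = torch.cuda.IntTensor([0])

# Scaling by 1.0 copies src into dst, the same trick _multi_tensor_copy_this_to_that uses below.
multi_tensor_applier(fused_optim.multi_tensor_scale, overflow_buf, [src, dst], 1.0)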
@@ -0,0 +1,11 @@
from typing import List

from torch import Tensor


def forward_affine(input: Tensor, normalized_shape: List[int], gamma: Tensor, beta: Tensor, epsilon: float) -> List[Tensor]:
    ...


def backward_affine(dout: Tensor, mean: Tensor, invvar: Tensor, input: Tensor,
                    normalized_shape: List[int], gamma: Tensor, beta: Tensor, epsilon: float) -> List[Tensor]:
    ...
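A short sketch of calling the layer-norm kernel directly, mirroring FusedLayerNormAffineFunction.forward later in this diff (shapes and epsilon are illustrative; a CUDA source build is required):

import torch
import colossalai._C.layer_norm as layer_norm

hidden = 1024
x = torch.randn(4, 16, hidden, device='cuda')
gamma = torch.ones(hidden, device='cuda')
beta = torch.zeros(hidden, device='cuda')

# Returns the normalized output plus the per-row mean and inverse variance reused by backward_affine.
output, mean, invvar = layer_norm.forward_affine(x.contiguous(), [hidden], gamma, beta, 1e-5)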
@@ -0,0 +1,20 @@
from torch import Tensor


def cumsum_sub_one(mask: Tensor) -> Tensor:
    ...


def dispatch_forward(s: int, ec: int, h: int, batch_tokens: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor:
    ...


def dispatch_backward(s: int, ec: int, h: int, expert_grad: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor:
    ...


def combine_forward(s: int, e: int, c: int, h: int, expert_tokens: Tensor, logits: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor:
    ...


def combine_backward(s: int, e: int, c: int, h: int, tokens_grad: Tensor, expert_tokens: Tensor, logits: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor:
    ...
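For reference, a simplified sketch of the kernel-or-fallback pattern that moe_cumsum (further down in this diff) builds on cumsum_sub_one; the real helper also checks the size of dim 0 before choosing the kernel:

import torch

try:
    import colossalai._C.moe as moe_kernel
except ImportError:
    moe_kernel = None

def cumsum_sub_one(mask: torch.Tensor) -> torch.Tensor:
    # Equivalent to torch.cumsum(mask, dim=0) - 1, computed by the CUDA kernel when available.
    if moe_kernel is not None and mask.is_cuda:
        return moe_kernel.cumsum_sub_one(mask)
    return torch.cumsum(mask, dim=0) - 1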
@@ -0,0 +1,55 @@
from typing import List

from torch import Tensor
from torch.distributed import ProcessGroup


def multihead_attention_fw_fp32(layer_id: int, input: Tensor, input_mask: Tensor,
                                in_proj_weight: Tensor, in_proj_bias: Tensor,
                                out_proj_weight: Tensor, out_proj_bias: Tensor,
                                norm_weight: Tensor, norm_bias: Tensor,
                                training_mode: bool, prelayernorm: bool) -> List[Tensor]:
    ...


def multihead_attention_fw_fp16(layer_id: int, input: Tensor, input_mask: Tensor,
                                in_proj_weight: Tensor, in_proj_bias: Tensor,
                                out_proj_weight: Tensor, out_proj_bias: Tensor,
                                norm_weight: Tensor, norm_bias: Tensor,
                                training_mode: bool, prelayernorm: bool) -> List[Tensor]:
    ...


def multihead_attention_bw_fp32(layer_id: int, grad_dec_output: Tensor,
                                output: Tensor, input: Tensor,
                                input_mask: Tensor, in_proj_weight: Tensor,
                                in_proj_bias: Tensor, out_proj_weight: Tensor,
                                out_proj_bias: Tensor, norm_weight: Tensor,
                                norm_bias: Tensor) -> List[Tensor]:
    ...


def multihead_attention_bw_fp16(layer_id: int, grad_dec_output: Tensor,
                                output: Tensor, input: Tensor,
                                input_mask: Tensor, in_proj_weight: Tensor,
                                in_proj_bias: Tensor, out_proj_weight: Tensor,
                                out_proj_bias: Tensor, norm_weight: Tensor,
                                norm_bias: Tensor) -> List[Tensor]:
    ...


def create_multihead_attention_fp32(layer_id: int, max_batch_tokens: int,
                                    max_seq_len: int, hidden_dim: int, num_heads: int,
                                    attn_prob_dropout_ratio: float,
                                    hidden_dropout_ratio: float,
                                    pre_or_postLayerNorm: bool,
                                    pg: ProcessGroup) -> int:
    ...


def create_multihead_attention_fp16(layer_id: int, max_batch_tokens: int,
                                    max_seq_len: int, hidden_dim: int, num_heads: int,
                                    attn_prob_dropout_ratio: float,
                                    hidden_dropout_ratio: float,
                                    pre_or_postLayerNorm: bool,
                                    pg: ProcessGroup) -> int:
    ...
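A sketch of how the fp16/fp32 entry points above are selected at runtime, matching MultiHeadAttention1DFunc.forward later in the diff (the helper name is illustrative):

import colossalai._C.multihead_attention as mha_kernel

def pick_forward_func(fp16: bool):
    # The autograd function dispatches on the configured precision.
    return mha_kernel.multihead_attention_fw_fp16 if fp16 else mha_kernel.multihead_attention_fw_fp32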
@@ -0,0 +1,12 @@
from torch import Tensor


def forward(input: Tensor, mask: Tensor, scale: float) -> Tensor:
    ...


def backward(output_grads: Tensor, softmax_results: Tensor, scale: float) -> Tensor:
    ...


def get_batch_per_block(query_seq_len: int, key_seq_len: int, batches: int, attn_heads: int) -> int:
    ...
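For reference, a minimal forward/backward round trip through the masked-softmax kernel, following the autograd functions later in this diff; shapes, dtypes and the Megatron-style mask layout are illustrative, and the scale is passed as an element of a 1-element tensor:

import torch
import colossalai._C.scaled_masked_softmax as smax

b, heads, sq, sk = 2, 8, 128, 128
inputs = torch.randn(b, heads, sq, sk, dtype=torch.half, device='cuda')
mask = torch.zeros(b, 1, sq, sk, dtype=torch.bool, device='cuda')
scale_t = torch.tensor([1.0])

probs = smax.forward(inputs, mask, scale_t[0])
grads = smax.backward(torch.ones_like(probs), probs, scale_t[0])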
@@ -0,0 +1,8 @@
from torch import Tensor


def forward(input: Tensor, scale: float) -> Tensor:
    ...


def backward(output_grads: Tensor, softmax_results: Tensor, scale: float) -> Tensor:
    ...

@@ -5,7 +5,7 @@ import torch
import torch.distributed as dist

try:
import colossal_C
import colossalai._C.fused_optim
except:
print('Colossalai should be built with cuda extension to use the FP16 optimizer')

@@ -35,7 +35,7 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
if overflow_buf:
overflow_buf.fill_(0)
# Scaling with factor `1.0` is equivalent to copy.
multi_tensor_applier(colossal_C.multi_tensor_scale, overflow_buf, [this, that], 1.0)
multi_tensor_applier(colossalai._C.fused_optim.multi_tensor_scale, overflow_buf, [this, that], 1.0)
else:
for this_, that_ in zip(this, that):
that_.copy_(this_)

@@ -1,5 +1,6 @@
import click
import subprocess

import click
import torch
from torch.utils.cpp_extension import CUDA_HOME

@@ -17,7 +18,7 @@ def check_installation():

def _check_cuda_extension_installed():
try:
import colossal_C
import colossalai._C.fused_optim
is_cuda_extension_installed = u'\u2713'
except ImportError:
is_cuda_extension_installed = 'x'
@@ -3,14 +3,11 @@
with some changes. """

import numbers
import torch
from torch.nn.parameter import Parameter
from torch.nn import init
from torch.cuda.amp import custom_fwd, custom_bwd
import importlib

global colossal_layer_norm_cuda
colossal_layer_norm_cuda = None
import torch
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn import init
from torch.nn.parameter import Parameter


class FusedLayerNormAffineFunction(torch.autograd.Function):

@@ -18,13 +15,17 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input, weight, bias, normalized_shape, eps):
try:
import colossalai._C.layer_norm
except ImportError:
raise RuntimeError('FusedLayerNormAffineFunction requires cuda extensions')

ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = colossal_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, bias_,
output, mean, invvar = colossalai._C.layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_,
ctx.eps)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)

@@ -33,11 +34,15 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
try:
import colossalai._C.layer_norm
except ImportError:
raise RuntimeError('FusedLayerNormAffineFunction requires cuda extensions')

input_, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias \
= colossal_layer_norm_cuda.backward_affine(
= colossalai._C.layer_norm.backward_affine(
grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape,
weight_, bias_, ctx.eps)

@@ -50,13 +55,6 @@ class MixedFusedLayerNorm(torch.nn.Module):
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None):
super(MixedFusedLayerNorm, self).__init__()

global colossal_layer_norm_cuda
if colossal_layer_norm_cuda is None:
try:
colossal_layer_norm_cuda = importlib.import_module("colossal_layer_norm_cuda")
except ImportError:
raise RuntimeError('MixedFusedLayerNorm requires cuda extensions')

if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
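For reference, a sketch of exercising the fused layer norm through the autograd function patched above; this mirrors what MixedFusedLayerNorm.forward does in the unchanged part of the file (shapes are illustrative):

import torch

hidden = 1024
x = torch.randn(8, hidden, device='cuda', requires_grad=True)
weight = torch.ones(hidden, device='cuda', requires_grad=True)
bias = torch.zeros(hidden, device='cuda', requires_grad=True)

out = FusedLayerNormAffineFunction.apply(x, weight, bias, torch.Size((hidden,)), 1e-5)
out.sum().backward()    # backward_affine runs under the hood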
@@ -1,5 +1,4 @@
import math
import importlib
from dataclasses import dataclass

import torch

@@ -37,21 +36,21 @@ colossal_multihead_attention = None

@dataclass
class Config:
max_batch_tokens: int # max batch token numbers
max_seq_len: int # max sequence length
hidden_size: int # size of transformer hidden layers
nhead: int # number of heads in attention
attn_prob_dropout_ratio: float # attention score dropout ratio
hidden_dropout_ratio: float # dropout ration before residual
norm_first: bool # norm_first
fp16: bool # fp16 presion
max_batch_tokens: int # max batch token numbers
max_seq_len: int # max sequence length
hidden_size: int # size of transformer hidden layers
nhead: int # number of heads in attention
attn_prob_dropout_ratio: float # attention score dropout ratio
hidden_dropout_ratio: float # dropout ration before residual
norm_first: bool # norm_first
fp16: bool # fp16 presion


class MultiHeadAttention1DFunc(Function):

@staticmethod
def forward(ctx, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight,
out_proj_bias, norm_weight, norm_bias, config):
def forward(ctx, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, out_proj_bias, norm_weight,
norm_bias, config):
cuda_module = colossal_multihead_attention
forward_func = (cuda_module.multihead_attention_fw_fp16
if config.fp16 else cuda_module.multihead_attention_fw_fp32)

@@ -59,13 +58,12 @@ class MultiHeadAttention1DFunc(Function):
input = input.to(torch.half)
input_mask = input_mask.to(torch.half)

(output,) = forward_func(config.layer_id, input, input_mask, in_proj_weight, in_proj_bias,
out_proj_weight, out_proj_bias, norm_weight, norm_bias,
config.training, config.norm_first)
(output,) = forward_func(config.layer_id, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight,
out_proj_bias, norm_weight, norm_bias, config.training, config.norm_first)

if config.is_grad_enabled and config.training:
ctx.save_for_backward(output, input, input_mask, in_proj_weight, in_proj_bias,
out_proj_weight, out_proj_bias, norm_weight, norm_bias)
ctx.save_for_backward(output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight,
out_proj_bias, norm_weight, norm_bias)
ctx.config = config
return output

@@ -98,8 +96,8 @@ class MultiHeadAttention1DFunc(Function):
ctx.config.layer_id, grad_output, output, input, input_mask, in_proj_weight,
in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, norm_bias)

return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight,
grad_out_proj_bias, grad_norm_weight, grad_norm_bias, None)
return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias,
grad_norm_weight, grad_norm_bias, None)


class MultiHeadAttention(nn.Module):

@@ -121,19 +119,11 @@ class MultiHeadAttention(nn.Module):

layer_id = 0

def __init__(self,
hidden_size,
nhead,
batch_size,
max_seq_len,
dropout=0.0,
norm_first=False,
fp16=True,
pg=None):
def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, norm_first=False, fp16=True, pg=None):
super(MultiHeadAttention, self).__init__()

self.config = Config(batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout,
dropout, norm_first, fp16)
self.config = Config(batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first,
fp16)
check_config(self.config)
self.pg = pg
self.pg_size = 1

@@ -146,7 +136,8 @@ class MultiHeadAttention(nn.Module):
global colossal_multihead_attention
if colossal_multihead_attention is None:
try:
colossal_multihead_attention = importlib.import_module("colossal_multihead_attention")
import colossalai._C.multihead_attention
colossal_multihead_attention = colossalai._C.multihead_attention
except ImportError:
raise RuntimeError('MultiHeadAttention requires cuda extensions')

@@ -215,14 +206,13 @@ class MultiHeadAttention(nn.Module):

with torch.no_grad():
self.in_proj_weight.copy_(
attn_qkvw_global.view(3, hs, hs)[
:, int(hs * rank_in_pg / self.pg_size):
int(hs * (rank_in_pg + 1) / self.pg_size),
:])
attn_qkvw_global.view(3, hs, hs)[:,
int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) /
self.pg_size), :])
self.in_proj_bias.copy_(
attn_qkvb_global.view(3, hs)[
:, int(hs * rank_in_pg / self.pg_size):
int(hs * (rank_in_pg + 1) / self.pg_size)])
attn_qkvb_global.view(3, hs)[:,
int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) /
self.pg_size)])

attn_ow_global = torch.empty(hs, hs)
nn.init.xavier_uniform_(attn_ow_global, 1.0)

@@ -230,9 +220,9 @@ class MultiHeadAttention(nn.Module):
torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg)
attn_ow_global = attn_ow_global.cpu()
with torch.no_grad():
self.out_proj_weight.copy_(attn_ow_global[
:, int(hs * rank_in_pg / self.pg_size):
int(hs * (rank_in_pg + 1) / self.pg_size)])
self.out_proj_weight.copy_(attn_ow_global[:,
int(hs * rank_in_pg /
self.pg_size):int(hs * (rank_in_pg + 1) / self.pg_size)])

else:
attn_qkvw = self.in_proj_weight.view(-1, hs)

@@ -243,10 +233,7 @@ class MultiHeadAttention(nn.Module):
nn.init.xavier_uniform_(self.out_proj_weight, 1.0)

def state_dict(self, destination=None, prefix="", keep_vars=False):
destination = torch.nn.Module.state_dict(self,
destination=destination,
prefix=prefix,
keep_vars=keep_vars)
destination = torch.nn.Module.state_dict(self, destination=destination, prefix=prefix, keep_vars=keep_vars)
return destination

def forward(self, hidden_states, encoder_padding_mask):

@@ -257,8 +244,7 @@ class MultiHeadAttention(nn.Module):

bs, sl, dim = hidden_states.size()
if bs * sl > self.config.max_batch_tokens:
raise ValueError(
f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.")
raise ValueError(f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.")
if sl > self.config.max_seq_len:
raise ValueError(f"Sequence length {sl} exceeds the limit {self.config.max_seq_len}.")
if len(encoder_padding_mask.size()) == 1:

@@ -266,9 +252,8 @@ class MultiHeadAttention(nn.Module):
else:
assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1)

output = MultiHeadAttention1DFunc.apply(hidden_states, encoder_padding_mask,
self.in_proj_weight, self.in_proj_bias,
self.out_proj_weight, self.out_proj_bias,
output = MultiHeadAttention1DFunc.apply(hidden_states, encoder_padding_mask, self.in_proj_weight,
self.in_proj_bias, self.out_proj_weight, self.out_proj_bias,
self.norm_weight, self.norm_bias, self.config)

return output.to(self.precision)
@@ -1,9 +1,10 @@
"""This code from NVIDIA Megatron
with some changes. """

import enum

import torch
import torch.nn as nn
import enum


class AttnMaskType(enum.Enum):

@@ -23,12 +24,12 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, scale):
try:
import colossal_scaled_upper_triang_masked_softmax
import colossalai._C.scaled_upper_triang_masked_softmax
except ImportError:
raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')

scale_t = torch.tensor([scale])
softmax_results = colossal_scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
softmax_results = colossalai._C.scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])

ctx.save_for_backward(softmax_results, scale_t)
return softmax_results

@@ -36,12 +37,13 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
@staticmethod
def backward(ctx, output_grads):
try:
import colossal_scaled_upper_triang_masked_softmax
import colossalai._C.scaled_upper_triang_masked_softmax
except ImportError:
raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')

softmax_results, scale_t = ctx.saved_tensors
input_grads = colossal_scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
input_grads = colossalai._C.scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results,
scale_t[0])

return input_grads, None

@@ -58,26 +60,26 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, mask, scale):
try:
import colossal_scaled_masked_softmax
import colossalai._C.scaled_masked_softmax
except ImportError:
raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')

scale_t = torch.tensor([scale])

softmax_results = colossal_scaled_masked_softmax.forward(inputs, mask, scale_t[0])
softmax_results = colossalai._C.scaled_masked_softmax.forward(inputs, mask, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results

@staticmethod
def backward(ctx, output_grads):
try:
import colossal_scaled_masked_softmax
import colossalai._C.scaled_masked_softmax
except ImportError:
raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')

softmax_results, scale_t = ctx.saved_tensors

input_grads = colossal_scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
input_grads = colossalai._C.scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None, None

@@ -184,8 +186,8 @@ class FusedScaleMaskSoftmax(nn.Module):
@staticmethod
def get_batch_per_block(sq, sk, b, np):
try:
import colossal_scaled_masked_softmax
import colossalai._C.scaled_masked_softmax
except ImportError:
raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')

return colossal_scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
return colossalai._C.scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
@@ -1,153 +1,154 @@
import torch
import torch.distributed as dist
from torch import Tensor
from typing import Any, Tuple, Optional
from torch.distributed import ProcessGroup

COL_MOE_KERNEL_FLAG = False
try:
import colossal_moe_cuda

COL_MOE_KERNEL_FLAG = True
except ImportError:
print("If you want to activate cuda mode for MoE, please install with cuda_ext!")


class AllGather(torch.autograd.Function):

@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
if ctx is not None:
ctx.comm_grp = group

comm_size = dist.get_world_size(group)
if comm_size == 1:
return inputs.unsqueeze(0)

buffer_shape = (comm_size,) + inputs.shape
outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
dist.all_gather(buffer_list, inputs, group=group)
return outputs

@staticmethod
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None


class ReduceScatter(torch.autograd.Function):

@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
if ctx is not None:
ctx.comm_grp = group

comm_size = dist.get_world_size(group)
if comm_size == 1:
return inputs.squeeze(0)

if not inputs.is_contiguous():
inputs = inputs.contiguous()

output_shape = inputs.shape[1:]
outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
dist.reduce_scatter(outputs, buffer_list, group=group)
return outputs

@staticmethod
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
return AllGather.forward(None, grad_outputs, ctx.comm_grp), None


class AllToAll(torch.autograd.Function):
"""Dispatches input tensor [e, c, h] to all experts by all_to_all_single
operation in torch.distributed.
"""

@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
if ctx is not None:
ctx.comm_grp = group
if not inputs.is_contiguous():
inputs = inputs.contiguous()
if dist.get_world_size(group) == 1:
return inputs
output = torch.empty_like(inputs)
dist.all_to_all_single(output, inputs, group=group)
return output

@staticmethod
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None


class MoeDispatch(torch.autograd.Function):

@staticmethod
def forward(ctx, tokens, mask, dest_idx, ec):
s = tokens.size(0)
h = tokens.size(1)

expert_input = colossal_moe_cuda.dispatch_forward(s, ec, h, tokens, mask, dest_idx)

ctx.save_for_backward(mask, dest_idx)
ctx.s = s
ctx.h = h
ctx.ec = ec

return expert_input

@staticmethod
def backward(ctx, output_grad):
mask, dest_idx = ctx.saved_tensors
d_tokens = colossal_moe_cuda.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
return d_tokens, None, None, None


class MoeCombine(torch.autograd.Function):

@staticmethod
def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
assert logits.dtype == torch.float32

s = logits.size(0)
e = logits.size(1)
c = ec // e
h = expert_tokens.size(-1)

fp16_flag = (expert_tokens.dtype == torch.float16)
cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
ctokens = colossal_moe_cuda.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
output = ctokens.to(torch.float16) if fp16_flag else ctokens

ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
ctx.s = s
ctx.e = e
ctx.c = c
ctx.h = h
ctx.fp16_flag = fp16_flag

return output

@staticmethod
def backward(ctx, tokens_grad):
expert_tokens, logits, mask, dest_idx = ctx.saved_tensors

cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \
else tokens_grad
cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
d_expert, d_logits = colossal_moe_cuda.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits,
mask, dest_idx)
d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert

return d_expert, d_logits, None, None, None


def moe_cumsum(inputs: Tensor):
dim0 = inputs.size(0)
flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
if flag and COL_MOE_KERNEL_FLAG:
return colossal_moe_cuda.cumsum_sub_one(inputs)
else:
return torch.cumsum(inputs, dim=0) - 1
from typing import Any, Optional, Tuple

import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup

COL_MOE_KERNEL_FLAG = False
try:
import colossalai._C.moe

COL_MOE_KERNEL_FLAG = True
except ImportError:
print("If you want to activate cuda mode for MoE, please install with cuda_ext!")


class AllGather(torch.autograd.Function):

@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
if ctx is not None:
ctx.comm_grp = group

comm_size = dist.get_world_size(group)
if comm_size == 1:
return inputs.unsqueeze(0)

buffer_shape = (comm_size,) + inputs.shape
outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
dist.all_gather(buffer_list, inputs, group=group)
return outputs

@staticmethod
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None


class ReduceScatter(torch.autograd.Function):

@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
if ctx is not None:
ctx.comm_grp = group

comm_size = dist.get_world_size(group)
if comm_size == 1:
return inputs.squeeze(0)

if not inputs.is_contiguous():
inputs = inputs.contiguous()

output_shape = inputs.shape[1:]
outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
dist.reduce_scatter(outputs, buffer_list, group=group)
return outputs

@staticmethod
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
return AllGather.forward(None, grad_outputs, ctx.comm_grp), None


class AllToAll(torch.autograd.Function):
"""Dispatches input tensor [e, c, h] to all experts by all_to_all_single
operation in torch.distributed.
"""

@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
if ctx is not None:
ctx.comm_grp = group
if not inputs.is_contiguous():
inputs = inputs.contiguous()
if dist.get_world_size(group) == 1:
return inputs
output = torch.empty_like(inputs)
dist.all_to_all_single(output, inputs, group=group)
return output

@staticmethod
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None


class MoeDispatch(torch.autograd.Function):

@staticmethod
def forward(ctx, tokens, mask, dest_idx, ec):
s = tokens.size(0)
h = tokens.size(1)

expert_input = colossalai._C.moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx)

ctx.save_for_backward(mask, dest_idx)
ctx.s = s
ctx.h = h
ctx.ec = ec

return expert_input

@staticmethod
def backward(ctx, output_grad):
mask, dest_idx = ctx.saved_tensors
d_tokens = colossalai._C.moe.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
return d_tokens, None, None, None


class MoeCombine(torch.autograd.Function):

@staticmethod
def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
assert logits.dtype == torch.float32

s = logits.size(0)
e = logits.size(1)
c = ec // e
h = expert_tokens.size(-1)

fp16_flag = (expert_tokens.dtype == torch.float16)
cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
ctokens = colossalai._C.moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
output = ctokens.to(torch.float16) if fp16_flag else ctokens

ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
ctx.s = s
ctx.e = e
ctx.c = c
ctx.h = h
ctx.fp16_flag = fp16_flag

return output

@staticmethod
def backward(ctx, tokens_grad):
expert_tokens, logits, mask, dest_idx = ctx.saved_tensors

cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \
else tokens_grad
cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
d_expert, d_logits = colossalai._C.moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits,
mask, dest_idx)
d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert

return d_expert, d_logits, None, None, None


def moe_cumsum(inputs: Tensor):
dim0 = inputs.size(0)
flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
if flag and COL_MOE_KERNEL_FLAG:
return colossalai._C.moe.cumsum_sub_one(inputs)
else:
return torch.cumsum(inputs, dim=0) - 1
@@ -1,9 +1,11 @@
import math
from typing import Optional

import torch

from colossalai.registry import OPTIMIZERS

from .nvme_optimizer import NVMeOptimizer
from typing import Optional


@OPTIMIZERS.register_module

@@ -11,7 +13,7 @@ class CPUAdam(NVMeOptimizer):
"""Implements Adam algorithm.

Supports parameters updating on both GPU and CPU, depanding on the device of paramters.
But the parameters and gradients should on the same device:
But the parameters and gradients should on the same device:
* Parameters on CPU and gradients on CPU is allowed.
* Parameters on GPU and gradients on GPU is allowed.
* Parameters on GPU and gradients on CPU is **not** allowed.

@@ -44,7 +46,7 @@ class CPUAdam(NVMeOptimizer):
(default: False) NOT SUPPORTED yet in CPUAdam!
adamw_mode (boolean, optional): Apply L2 regularization or weight decay
True for decoupled weight decay(also known as AdamW) (default: True)
simd_log (boolean, optional): whether to show if you are using SIMD to
simd_log (boolean, optional): whether to show if you are using SIMD to
accelerate. (default: False)
nvme_offload_fraction (float, optional): Fraction of optimizer states to be offloaded to NVMe. Defaults to 0.0.
nvme_offload_dir (Optional[str], optional): Directory to save NVMe offload files.

@@ -75,10 +77,11 @@ class CPUAdam(NVMeOptimizer):
super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
self.adamw_mode = adamw_mode
try:
import cpu_adam
import colossalai._C.cpu_optim
except ImportError:
raise ImportError('Please install colossalai from source code to use CPUAdam')
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
self.cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay,
adamw_mode)

def torch_adam_update(self,
data,
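For context, a minimal sketch of the public optimizer that wraps this kernel; the model and hyper-parameters are illustrative:

import torch
import torch.nn as nn
from colossalai.nn.optimizer import CPUAdam

model = nn.Linear(32, 32)    # parameters and gradients must live on the same device
optimizer = CPUAdam(model.parameters(), lr=1e-3, weight_decay=0.0, adamw_mode=True)

loss = model(torch.randn(4, 32)).sum()
loss.backward()
optimizer.step()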
@@ -20,7 +20,7 @@ class FusedAdam(torch.optim.Optimizer):
:class:`colossalai.nn.optimizer.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,
or ``torch.optim.Adam`` with ``adamw_mode=False``

:class:`colossalai.nn.optimizer.FusedAdam` may be used with or without Amp.
:class:`colossalai.nn.optimizer.FusedAdam` may be used with or without Amp.

Adam was been proposed in `Adam: A Method for Stochastic Optimization`_.

@@ -65,10 +65,11 @@ class FusedAdam(torch.optim.Optimizer):
self.adamw_mode = 1 if adamw_mode else 0
self.set_grad_none = set_grad_none
if multi_tensor_applier.available:
import colossal_C
import colossalai._C.fused_optim

# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self.multi_tensor_adam = colossal_C.multi_tensor_adam
self.multi_tensor_adam = colossalai._C.fused_optim.multi_tensor_adam
else:
raise RuntimeError('FusedAdam requires cuda extensions')

@@ -76,13 +76,13 @@ class FusedLAMB(torch.optim.Optimizer):
max_grad_norm=max_grad_norm)
super(FusedLAMB, self).__init__(params, defaults)
if multi_tensor_applier.available:
import colossal_C
self.multi_tensor_l2norm = colossal_C.multi_tensor_l2norm
import colossalai._C.fused_optim
self.multi_tensor_l2norm = colossalai._C.fused_optim.multi_tensor_l2norm
# Skip buffer
self._dummy_overflow_buf = torch.tensor([0],
dtype=torch.int,
device=self.param_groups[0]["params"][0].device)
self.multi_tensor_lamb = colossal_C.multi_tensor_lamb
self.multi_tensor_lamb = colossalai._C.fused_optim.multi_tensor_lamb
else:
raise RuntimeError('FusedLAMB requires cuda extensions')

@@ -20,7 +20,7 @@ class FusedSGD(Optimizer):

:class:`colossalai.nn.optimizer.FusedSGD` may be used as a drop-in replacement for ``torch.optim.SGD``

:class:`colossalai.nn.optimizer.FusedSGD` may be used with or without Amp.
:class:`colossalai.nn.optimizer.FusedSGD` may be used with or without Amp.

Nesterov momentum is based on the formula from
`On the importance of initialization and momentum in deep learning`__.

@@ -80,12 +80,13 @@ class FusedSGD(Optimizer):
self.wd_after_momentum = wd_after_momentum

if multi_tensor_applier.available:
import colossal_C
import colossalai._C.fused_optim

# Skip buffer
self._dummy_overflow_buf = torch.tensor([0],
dtype=torch.int,
device=self.param_groups[0]["params"][0].device)
self.multi_tensor_sgd = colossal_C.multi_tensor_sgd
self.multi_tensor_sgd = colossalai._C.fused_optim.multi_tensor_sgd
else:
raise RuntimeError('FusedSGD requires cuda extensions')

@@ -77,14 +77,15 @@ class HybridAdam(NVMeOptimizer):
super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
self.adamw_mode = adamw_mode
try:
import colossal_C
import cpu_adam
import colossalai._C.cpu_optim
import colossalai._C.fused_optim
except ImportError:
raise ImportError('Please install colossalai from source code to use HybridAdam')

self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
self.cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay,
adamw_mode)

self.gpu_adam_op = colossal_C.multi_tensor_adam
self.gpu_adam_op = colossalai._C.fused_optim.multi_tensor_adam
self._dummy_overflow_buf = torch.cuda.IntTensor([0])

@torch.no_grad()
@@ -1,32 +1,33 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import functools
import os
import random
import socket
from pathlib import Path
from typing import Callable, List, Union, Dict, Optional
import functools
from typing import Callable, Dict, List, Optional, Union

import torch
from torch._six import inf
from torch.nn.parameter import Parameter

try:
import colossal_C
import colossalai._C.fused_optim
except:
pass

from collections import defaultdict
from contextlib import contextmanager

import torch.distributed as dist
from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES)

from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from .multi_tensor_apply import multi_tensor_applier

from colossalai.tensor import ColoParameter, ProcessGroup
from collections import defaultdict

from .multi_tensor_apply import multi_tensor_applier


def print_rank_0(msg: str, logger=None):

@@ -132,7 +133,7 @@ def _calc_l2_norm(grads):
if len(grads) > 0:
dummy_overflow_buf = torch.cuda.IntTensor([0])
norm, _ = multi_tensor_applier(
colossal_C.multi_tensor_l2norm,
colossalai._C.fused_optim.multi_tensor_l2norm,
dummy_overflow_buf,
[grads],
False # no per-parameter norm

@@ -269,7 +270,8 @@ def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None:
cpu_grads.append(p.grad.detach())
if len(cuda_grads) > 0:
dummy_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads], clip_coef)
multi_tensor_applier(colossalai._C.fused_optim.multi_tensor_scale, dummy_overflow_buf,
[cuda_grads, cuda_grads], clip_coef)
for g in cpu_grads:
g.mul_(clip_coef)

@@ -395,7 +397,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
if enable_cuda_kernels:
grads = [p.grad.detach() for p in params]
dummy_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff)
multi_tensor_applier(colossalai._C.fused_optim.multi_tensor_scale, dummy_overflow_buf, [grads, grads],
clip_coeff)
else:
for p in params:
p.grad.detach().mul_(clip_coeff)

@@ -14,7 +14,7 @@ class MultiTensorApply(object):

def __init__(self, chunk_size):
try:
import colossal_C
import colossalai._C.fused_optim
MultiTensorApply.available = True
self.chunk_size = chunk_size
except ImportError as err:
setup.py (98 lines changed)
@@ -1,7 +1,8 @@
import os
import subprocess
import re
from setuptools import find_packages, setup, Extension
import subprocess

from setuptools import Extension, find_packages, setup

# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))

@@ -104,7 +105,7 @@ def get_version():
if build_cuda_ext:
try:
import torch
from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CUDAExtension)
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])

@@ -148,7 +149,7 @@ if build_cuda_ext:
extra_cuda_flags = ['-lineinfo']

ext_modules.append(
cuda_ext_helper('colossal_C', [
cuda_ext_helper('colossalai._C.fused_optim', [
'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.cu', 'multi_tensor_scale_kernel.cu',
'multi_tensor_adam.cu', 'multi_tensor_l2norm_kernel.cu', 'multi_tensor_lamb.cu'
], extra_cuda_flags + cc_flag))

@@ -159,21 +160,21 @@ if build_cuda_ext:
]

ext_modules.append(
cuda_ext_helper('colossal_scaled_upper_triang_masked_softmax',
cuda_ext_helper('colossalai._C.scaled_upper_triang_masked_softmax',
['scaled_upper_triang_masked_softmax.cpp', 'scaled_upper_triang_masked_softmax_cuda.cu'],
extra_cuda_flags + cc_flag))

ext_modules.append(
cuda_ext_helper('colossal_scaled_masked_softmax',
cuda_ext_helper('colossalai._C.scaled_masked_softmax',
['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'], extra_cuda_flags + cc_flag))

ext_modules.append(
cuda_ext_helper('colossal_moe_cuda', ['moe_cuda.cpp', 'moe_cuda_kernel.cu'], extra_cuda_flags + cc_flag))
cuda_ext_helper('colossalai._C.moe', ['moe_cuda.cpp', 'moe_cuda_kernel.cu'], extra_cuda_flags + cc_flag))

extra_cuda_flags = ['-maxrregcount=50']

ext_modules.append(
cuda_ext_helper('colossal_layer_norm_cuda', ['layer_norm_cuda.cpp', 'layer_norm_cuda_kernel.cu'],
cuda_ext_helper('colossalai._C.layer_norm', ['layer_norm_cuda.cpp', 'layer_norm_cuda_kernel.cu'],
extra_cuda_flags + cc_flag))

extra_cuda_flags = [

@@ -182,54 +183,53 @@ if build_cuda_ext:
]

ext_modules.append(
cuda_ext_helper('colossal_multihead_attention', [
cuda_ext_helper('colossalai._C.multihead_attention', [
'multihead_attention_1d.cpp', 'kernels/cublas_wrappers.cu', 'kernels/transform_kernels.cu',
'kernels/dropout_kernels.cu', 'kernels/normalize_kernels.cu', 'kernels/softmax_kernels.cu',
'kernels/general_kernels.cu', 'kernels/cuda_util.cu'
], extra_cuda_flags + cc_flag))

extra_cxx_flags = ['-std=c++14', '-lcudart', '-lcublas', '-g', '-Wno-reorder', '-fopenmp', '-march=native']
ext_modules.append(cuda_ext_helper('cpu_adam', ['cpu_adam.cpp'], extra_cuda_flags, extra_cxx_flags))
ext_modules.append(cuda_ext_helper('colossalai._C.cpu_optim', ['cpu_adam.cpp'], extra_cuda_flags, extra_cxx_flags))

setup(
name='colossalai',
version=get_version(),
packages=find_packages(exclude=(
'benchmark',
'docker',
'tests',
'docs',
'examples',
'tests',
'scripts',
'requirements',
'*.egg-info',
)),
description='An integrated large-scale model training system with efficient parallelization techniques',
long_description=fetch_readme(),
long_description_content_type='text/markdown',
license='Apache Software License 2.0',
url='https://www.colossalai.org',
project_urls={
'Forum': 'https://github.com/hpcaitech/ColossalAI/discussions',
'Bug Tracker': 'https://github.com/hpcaitech/ColossalAI/issues',
'Examples': 'https://github.com/hpcaitech/ColossalAI-Examples',
'Documentation': 'http://colossalai.readthedocs.io',
'Github': 'https://github.com/hpcaitech/ColossalAI',
},
ext_modules=ext_modules,
cmdclass={'build_ext': BuildExtension} if ext_modules else {},
install_requires=fetch_requirements('requirements/requirements.txt'),
entry_points='''
setup(name='colossalai',
version=get_version(),
packages=find_packages(exclude=(
'benchmark',
'docker',
'tests',
'docs',
'examples',
'tests',
'scripts',
'requirements',
'*.egg-info',
)),
description='An integrated large-scale model training system with efficient parallelization techniques',
long_description=fetch_readme(),
long_description_content_type='text/markdown',
license='Apache Software License 2.0',
url='https://www.colossalai.org',
project_urls={
'Forum': 'https://github.com/hpcaitech/ColossalAI/discussions',
'Bug Tracker': 'https://github.com/hpcaitech/ColossalAI/issues',
'Examples': 'https://github.com/hpcaitech/ColossalAI-Examples',
'Documentation': 'http://colossalai.readthedocs.io',
'Github': 'https://github.com/hpcaitech/ColossalAI',
},
ext_modules=ext_modules,
cmdclass={'build_ext': BuildExtension} if ext_modules else {},
install_requires=fetch_requirements('requirements/requirements.txt'),
entry_points='''
[console_scripts]
colossalai=colossalai.cli:cli
''',
python_requires='>=3.6',
classifiers=[
'Programming Language :: Python :: 3',
'License :: OSI Approved :: Apache Software License',
'Environment :: GPU :: NVIDIA CUDA',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: System :: Distributed Computing',
],
)
python_requires='>=3.6',
classifiers=[
'Programming Language :: Python :: 3',
'License :: OSI Approved :: Apache Software License',
'Environment :: GPU :: NVIDIA CUDA',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: System :: Distributed Computing',
],
package_data={'colossalai': ['_C/*.pyi']})
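The key change above is that cuda_ext_helper now receives dotted module names: with torch.utils.cpp_extension, naming an extension after a package path places the built library inside that package, so it is importable as colossalai._C.<kernel> without symlinking top-level modules. A hedged sketch of the idea (the source list is a subset and purely illustrative):

from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# The compiled artifact ends up at colossalai/_C/fused_optim*.so,
# making `import colossalai._C.fused_optim` work directly.
ext = CUDAExtension(name='colossalai._C.fused_optim',
                    sources=['colossal_C_frontend.cpp', 'multi_tensor_adam.cu'])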
@@ -1,4 +1,5 @@
import math

import torch

from colossalai.testing import parameterize

@@ -66,8 +67,8 @@ def test_cpu_adam(adamw, step, p_dtype, g_dtype):
exp_avg_sq_copy = exp_avg_sq.clone()

try:
import cpu_adam
cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)
import colossalai._C.cpu_optim
cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)
except:
raise ImportError("Import cpu adam error, please install colossal from source code")

@@ -1,8 +1,8 @@
from numpy import dtype
import math

import torch
import torch.nn as nn

import math
from numpy import dtype

from colossalai.testing import parameterize
from colossalai.utils import multi_tensor_applier

@@ -47,11 +47,11 @@ def torch_adam_update(
@parameterize('g_dtype', [torch.float, torch.half])
def test_adam(adamw, step, p_dtype, g_dtype):
try:
import colossal_C
fused_adam = colossal_C.multi_tensor_adam
import colossalai._C.fused_optim
fused_adam = colossalai._C.fused_optim.multi_tensor_adam
dummy_overflow_buf = torch.cuda.IntTensor([0])
except:
raise ImportError("No colossal_C kernel installed.")
raise ImportError("No colossalai._C.fused_optim kernel installed.")

count = 0
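For test environments without a source build, a hedged alternative to raising ImportError is to skip via pytest; this is not how the tests above are written, just a sketch:

import pytest

# Skips the calling test module cleanly when the extensions were not compiled.
fused_optim = pytest.importorskip("colossalai._C.fused_optim")
cpu_optim = pytest.importorskip("colossalai._C.cpu_optim")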