[kernel] move all symlinks of kernel to `colossalai._C` (#1971)

pull/1972/head
ver217 2022-11-17 13:42:33 +08:00 committed by GitHub
parent 7e24b9b9ee
commit f8a7148dec
27 changed files with 463 additions and 322 deletions

View File

@@ -38,7 +38,6 @@ jobs:
          pip install -r requirements/requirements.txt
          pip install -v -e .
          cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
-         cp /__w/ColossalAI/ColossalAI/*.so /github/home/cuda_ext_cache/
          pip install -r requirements/requirements-test.txt
       - name: Unit Testing
         run: |

View File

@@ -36,7 +36,6 @@ jobs:
          pip install -r requirements/requirements.txt
          pip install -v -e .
          cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
-         cp /__w/ColossalAI/ColossalAI/*.so /github/home/cuda_ext_cache/
          pip install -r requirements/requirements-test.txt
       - name: Unit Testing
         run: |

View File

@@ -1,3 +1,3 @@
 include *.txt README.md
 recursive-include requirements *.txt
-recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc
+recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc *.pyi

View File

@ -0,0 +1,9 @@
from . import (
    cpu_optim,
    fused_optim,
    layer_norm,
    moe,
    multihead_attention,
    scaled_masked_softmax,
    scaled_upper_triang_masked_softmax,
)
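For orientation only (not part of the diff): after this change the former top-level extension modules are imported as submodules of colossalai._C. A minimal sketch of the mapping, assuming ColossalAI was built with its CUDA extensions:

import colossalai._C.fused_optim                            # was: import colossal_C
import colossalai._C.cpu_optim                              # was: import cpu_adam
import colossalai._C.moe                                    # was: import colossal_moe_cuda
import colossalai._C.layer_norm                             # was: colossal_layer_norm_cuda
import colossalai._C.multihead_attention                    # was: colossal_multihead_attention
import colossalai._C.scaled_masked_softmax                  # was: colossal_scaled_masked_softmax
import colossalai._C.scaled_upper_triang_masked_softmax     # was: colossal_scaled_upper_triang_masked_softmax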

View File

@ -0,0 +1,8 @@
from torch import Tensor
class CPUAdamOptimizer:

    def __init__(self, lr: float, beta1: float, beta2: float, eps: float,
                 weight_decay: float, adamw_mode: float) -> None: ...

    def step(self, step: int, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float,
             bias_correction: bool, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor,
             loss_scale: float) -> None: ...
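A hedged usage sketch of the stub above (not part of the diff), mirroring how CPUAdam wires up this kernel later in this commit. Tensor sizes, hyperparameters, and the -1 loss scale are illustrative assumptions, and a source build of the extension is required:

import torch
import colossalai._C.cpu_optim as cpu_optim

param = torch.rand(64)          # parameters and gradients live on the CPU here
grad = torch.rand(64)
exp_avg = torch.zeros(64)       # Adam first-moment state
exp_avg_sq = torch.zeros(64)    # Adam second-moment state

opt = cpu_optim.CPUAdamOptimizer(1e-3, 0.9, 0.999, 1e-8, 0.0, True)
# step, lr, beta1, beta2, eps, weight_decay, bias_correction, the four tensors,
# then a loss scale (-1 is assumed here to mean "no loss scaling").
opt.step(1, 1e-3, 0.9, 0.999, 1e-8, 0.0, True, param, grad, exp_avg, exp_avg_sq, -1)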

View File

@ -0,0 +1,23 @@
from typing import List
from torch import Tensor
def multi_tensor_scale(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], scale: float) -> None:
    ...

def multi_tensor_sgd(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], weight_decay: float,
                     momentum: float, dampening: float, lr: float, nesterov: bool, first_run: bool,
                     weight_decay_after_momentum: bool, scale: float) -> None:
    ...

def multi_tensor_adam(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float,
                      beta2: float, epsilon: float, step: int, mode: int, bias_correction: int,
                      weight_decay: float) -> None:
    ...

def multi_tensor_lamb(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float,
                      beta2: float, epsilon: float, step: int, bias_correction: int, weight_decay: float,
                      grad_averaging: int, mode: int, global_grad_norm: Tensor, max_grad_norm: float,
                      use_nvlamb_python: bool) -> None:
    ...

def multi_tensor_l2norm(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]],
                        per_tensor_python: bool) -> None:
    ...
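For reference (not part of the diff): a minimal sketch of driving one of these bindings through multi_tensor_applier, modeled on the call sites further down in this commit. Tensor shapes are illustrative and a CUDA build is assumed:

import torch
import colossalai._C.fused_optim as fused_optim
from colossalai.utils import multi_tensor_applier

src = [torch.randn(1024, device='cuda') for _ in range(3)]
dst = [torch.empty_like(t) for t in src]
overflow_buf = torch.cuda.IntTensor([0])    # noop/overflow flag consumed by the kernel

# Scaling by 1.0 copies src into dst, the same trick _multi_tensor_copy_this_to_that uses below.
multi_tensor_applier(fused_optim.multi_tensor_scale, overflow_buf, [src, dst], 1.0)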

View File

@ -0,0 +1,11 @@
from typing import List
from torch import Tensor
def forward_affine(input: Tensor, normalized_shape: List[int], gamma: Tensor, beta: Tensor,
                   epsilon: float) -> List[Tensor]:
    ...

def backward_affine(dout: Tensor, mean: Tensor, invvar: Tensor, input: Tensor,
                    normalized_shape: List[int], gamma: Tensor, beta: Tensor, epsilon: float) -> List[Tensor]:
    ...
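A small illustrative sketch of the layer-norm binding above (not part of the diff), mirroring FusedLayerNormAffineFunction later in this commit; the shapes and epsilon are arbitrary and a built extension on a CUDA device is assumed:

import torch
import colossalai._C.layer_norm as layer_norm

hidden = 256
x = torch.randn(8, hidden, device='cuda').contiguous()
gamma = torch.ones(hidden, device='cuda')
beta = torch.zeros(hidden, device='cuda')

# Returns the normalized output plus the per-row mean and inverse variance
# that backward_affine reuses.
output, mean, invvar = layer_norm.forward_affine(x, [hidden], gamma, beta, 1e-5)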

colossalai/_C/moe.pyi (new file, 20 lines)
View File

@ -0,0 +1,20 @@
from torch import Tensor
def cumsum_sub_one(mask: Tensor) -> Tensor:
    ...

def dispatch_forward(s: int, ec: int, h: int, batch_tokens: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor:
    ...

def dispatch_backward(s: int, ec: int, h: int, expert_grad: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor:
    ...

def combine_forward(s: int, e: int, c: int, h: int, expert_tokens: Tensor, logits: Tensor, mask: Tensor,
                    dest_idx: Tensor) -> Tensor:
    ...

def combine_backward(s: int, e: int, c: int, h: int, tokens_grad: Tensor, expert_tokens: Tensor, logits: Tensor,
                     mask: Tensor, dest_idx: Tensor) -> Tensor:
    ...
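For reference (not part of the diff): a sketch of the cumsum binding, which moe_cumsum later in this commit treats as interchangeable with torch.cumsum(mask, dim=0) - 1. The mask dtype and shape here are assumptions, and a CUDA build is required:

import torch
import colossalai._C.moe as moe_kernel

mask = torch.randint(0, 2, (1024, 4), dtype=torch.int32, device='cuda')    # dtype is an assumption
out = moe_kernel.cumsum_sub_one(mask)
ref = torch.cumsum(mask, dim=0) - 1    # fallback path used when the kernel is unavailable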

View File

@ -0,0 +1,55 @@
from typing import List
from torch import Tensor
from torch.distributed import ProcessGroup
def multihead_attention_fw_fp32(layer_id: int, input: Tensor, input_mask: Tensor,
                                in_proj_weight: Tensor, in_proj_bias: Tensor,
                                out_proj_weight: Tensor, out_proj_bias: Tensor,
                                norm_weight: Tensor, norm_bias: Tensor,
                                training_mode: bool, prelayernorm: bool) -> List[Tensor]:
    ...

def multihead_attention_fw_fp16(layer_id: int, input: Tensor, input_mask: Tensor,
                                in_proj_weight: Tensor, in_proj_bias: Tensor,
                                out_proj_weight: Tensor, out_proj_bias: Tensor,
                                norm_weight: Tensor, norm_bias: Tensor,
                                training_mode: bool, prelayernorm: bool) -> List[Tensor]:
    ...

def multihead_attention_bw_fp32(layer_id: int, grad_dec_output: Tensor,
                                output: Tensor, input: Tensor,
                                input_mask: Tensor, in_proj_weight: Tensor,
                                in_proj_bias: Tensor, out_proj_weight: Tensor,
                                out_proj_bias: Tensor, norm_weight: Tensor,
                                norm_bias: Tensor) -> List[Tensor]:
    ...

def multihead_attention_bw_fp16(layer_id: int, grad_dec_output: Tensor,
                                output: Tensor, input: Tensor,
                                input_mask: Tensor, in_proj_weight: Tensor,
                                in_proj_bias: Tensor, out_proj_weight: Tensor,
                                out_proj_bias: Tensor, norm_weight: Tensor,
                                norm_bias: Tensor) -> List[Tensor]:
    ...

def create_multihead_attention_fp32(layer_id: int, max_batch_tokens: int,
                                    max_seq_len: int, hidden_dim: int, num_heads: int,
                                    attn_prob_dropout_ratio: float,
                                    hidden_dropout_ratio: float,
                                    pre_or_postLayerNorm: bool,
                                    pg: ProcessGroup) -> int:
    ...

def create_multihead_attention_fp16(layer_id: int, max_batch_tokens: int,
                                    max_seq_len: int, hidden_dim: int, num_heads: int,
                                    attn_prob_dropout_ratio: float,
                                    hidden_dropout_ratio: float,
                                    pre_or_postLayerNorm: bool,
                                    pg: ProcessGroup) -> int:
    ...

View File

@ -0,0 +1,12 @@
from torch import Tensor
def forward(input: Tensor, mask: Tensor, scale: float) -> Tensor:
    ...

def backward(output_grads: Tensor, softmax_results: Tensor, scale: float) -> Tensor:
    ...

def get_batch_per_block(query_seq_len: int, key_seq_len: int, batches: int, attn_heads: int) -> int:
    ...
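An illustrative sketch of the masked-softmax binding (not part of the diff), modeled on ScaledMaskedSoftmax later in this commit. The (batches, heads, sq, sk) input layout and the boolean mask shape are assumptions; a built extension and a CUDA device are required:

import torch
import colossalai._C.scaled_masked_softmax as scaled_masked_softmax

inputs = torch.randn(4, 8, 128, 128, dtype=torch.half, device='cuda')
mask = torch.zeros(4, 1, 128, 128, dtype=torch.bool, device='cuda')    # shape/dtype assumed
scale_t = torch.tensor([1.0])

probs = scaled_masked_softmax.forward(inputs, mask, scale_t[0])
batch_per_block = scaled_masked_softmax.get_batch_per_block(128, 128, 4, 8)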

View File

@ -0,0 +1,8 @@
from torch import Tensor
def forward(input: Tensor, scale: float) -> Tensor:
    ...

def backward(output_grads: Tensor, softmax_results: Tensor, scale: float) -> Tensor:
    ...

View File

@@ -5,7 +5,7 @@ import torch
 import torch.distributed as dist
 
 try:
-    import colossal_C
+    import colossalai._C.fused_optim
 except:
     print('Colossalai should be built with cuda extension to use the FP16 optimizer')
@@ -35,7 +35,7 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
     if overflow_buf:
         overflow_buf.fill_(0)
         # Scaling with factor `1.0` is equivalent to copy.
-        multi_tensor_applier(colossal_C.multi_tensor_scale, overflow_buf, [this, that], 1.0)
+        multi_tensor_applier(colossalai._C.fused_optim.multi_tensor_scale, overflow_buf, [this, that], 1.0)
     else:
         for this_, that_ in zip(this, that):
             that_.copy_(this_)

View File

@@ -1,5 +1,6 @@
-import click
 import subprocess
 
+import click
+
 import torch
 from torch.utils.cpp_extension import CUDA_HOME
@@ -17,7 +18,7 @@ def check_installation():
 def _check_cuda_extension_installed():
     try:
-        import colossal_C
+        import colossalai._C.fused_optim
         is_cuda_extension_installed = u'\u2713'
     except ImportError:
         is_cuda_extension_installed = 'x'

View File

@@ -3,14 +3,11 @@
 with some changes. """
 
 import numbers
-import torch
-from torch.nn.parameter import Parameter
-from torch.nn import init
-from torch.cuda.amp import custom_fwd, custom_bwd
-import importlib
 
-global colossal_layer_norm_cuda
-colossal_layer_norm_cuda = None
+import torch
+from torch.cuda.amp import custom_bwd, custom_fwd
+from torch.nn import init
+from torch.nn.parameter import Parameter
 
 
 class FusedLayerNormAffineFunction(torch.autograd.Function):
@@ -18,13 +15,17 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
     def forward(ctx, input, weight, bias, normalized_shape, eps):
+        try:
+            import colossalai._C.layer_norm
+        except ImportError:
+            raise RuntimeError('FusedLayerNormAffineFunction requires cuda extensions')
         ctx.normalized_shape = normalized_shape
         ctx.eps = eps
         input_ = input.contiguous()
         weight_ = weight.contiguous()
         bias_ = bias.contiguous()
-        output, mean, invvar = colossal_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, bias_,
+        output, mean, invvar = colossalai._C.layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_,
                                                                        ctx.eps)
         ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
@@ -33,11 +34,15 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
     @staticmethod
     @custom_bwd
     def backward(ctx, grad_output):
+        try:
+            import colossalai._C.layer_norm
+        except ImportError:
+            raise RuntimeError('FusedLayerNormAffineFunction requires cuda extensions')
         input_, weight_, bias_, mean, invvar = ctx.saved_tensors
         grad_input = grad_weight = grad_bias = None
         grad_input, grad_weight, grad_bias \
-            = colossal_layer_norm_cuda.backward_affine(
+            = colossalai._C.layer_norm.backward_affine(
                 grad_output.contiguous(), mean, invvar,
                 input_, ctx.normalized_shape,
                 weight_, bias_, ctx.eps)
@@ -50,13 +55,6 @@ class MixedFusedLayerNorm(torch.nn.Module):
     def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None):
         super(MixedFusedLayerNorm, self).__init__()
 
-        global colossal_layer_norm_cuda
-        if colossal_layer_norm_cuda is None:
-            try:
-                colossal_layer_norm_cuda = importlib.import_module("colossal_layer_norm_cuda")
-            except ImportError:
-                raise RuntimeError('MixedFusedLayerNorm requires cuda extensions')
-
         if isinstance(normalized_shape, numbers.Integral):
             normalized_shape = (normalized_shape,)
         self.normalized_shape = torch.Size(normalized_shape)

View File

@@ -1,5 +1,4 @@
 import math
-import importlib
 from dataclasses import dataclass
 
 import torch
@@ -37,21 +36,21 @@ colossal_multihead_attention = None
 @dataclass
 class Config:
     max_batch_tokens: int    # max batch token numbers
     max_seq_len: int    # max sequence length
     hidden_size: int    # size of transformer hidden layers
     nhead: int    # number of heads in attention
     attn_prob_dropout_ratio: float    # attention score dropout ratio
     hidden_dropout_ratio: float    # dropout ration before residual
     norm_first: bool    # norm_first
     fp16: bool    # fp16 presion
 
 
 class MultiHeadAttention1DFunc(Function):
 
     @staticmethod
-    def forward(ctx, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight,
-                out_proj_bias, norm_weight, norm_bias, config):
+    def forward(ctx, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, out_proj_bias, norm_weight,
+                norm_bias, config):
         cuda_module = colossal_multihead_attention
         forward_func = (cuda_module.multihead_attention_fw_fp16
                         if config.fp16 else cuda_module.multihead_attention_fw_fp32)
@@ -59,13 +58,12 @@ class MultiHeadAttention1DFunc(Function):
             input = input.to(torch.half)
             input_mask = input_mask.to(torch.half)
 
-        (output,) = forward_func(config.layer_id, input, input_mask, in_proj_weight, in_proj_bias,
-                                 out_proj_weight, out_proj_bias, norm_weight, norm_bias,
-                                 config.training, config.norm_first)
+        (output,) = forward_func(config.layer_id, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight,
+                                 out_proj_bias, norm_weight, norm_bias, config.training, config.norm_first)
 
         if config.is_grad_enabled and config.training:
-            ctx.save_for_backward(output, input, input_mask, in_proj_weight, in_proj_bias,
-                                  out_proj_weight, out_proj_bias, norm_weight, norm_bias)
+            ctx.save_for_backward(output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight,
+                                  out_proj_bias, norm_weight, norm_bias)
             ctx.config = config
         return output
@@ -98,8 +96,8 @@ class MultiHeadAttention1DFunc(Function):
             ctx.config.layer_id, grad_output, output, input, input_mask, in_proj_weight,
             in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, norm_bias)
 
-        return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight,
-                grad_out_proj_bias, grad_norm_weight, grad_norm_bias, None)
+        return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias,
+                grad_norm_weight, grad_norm_bias, None)
 
 
 class MultiHeadAttention(nn.Module):
@@ -121,19 +119,11 @@ class MultiHeadAttention(nn.Module):
 
     layer_id = 0
 
-    def __init__(self,
-                 hidden_size,
-                 nhead,
-                 batch_size,
-                 max_seq_len,
-                 dropout=0.0,
-                 norm_first=False,
-                 fp16=True,
-                 pg=None):
+    def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, norm_first=False, fp16=True, pg=None):
         super(MultiHeadAttention, self).__init__()
 
-        self.config = Config(batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout,
-                             dropout, norm_first, fp16)
+        self.config = Config(batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first,
+                             fp16)
         check_config(self.config)
         self.pg = pg
         self.pg_size = 1
@@ -146,7 +136,8 @@ class MultiHeadAttention(nn.Module):
         global colossal_multihead_attention
         if colossal_multihead_attention is None:
             try:
-                colossal_multihead_attention = importlib.import_module("colossal_multihead_attention")
+                import colossalai._C.multihead_attention
+                colossal_multihead_attention = colossalai._C.multihead_attention
             except ImportError:
                 raise RuntimeError('MultiHeadAttention requires cuda extensions')
@@ -215,14 +206,13 @@ class MultiHeadAttention(nn.Module):
                 with torch.no_grad():
                     self.in_proj_weight.copy_(
-                        attn_qkvw_global.view(3, hs, hs)[
-                            :, int(hs * rank_in_pg / self.pg_size):
-                            int(hs * (rank_in_pg + 1) / self.pg_size),
-                            :])
+                        attn_qkvw_global.view(3, hs, hs)[:,
+                                                         int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) /
+                                                                                                 self.pg_size), :])
                     self.in_proj_bias.copy_(
-                        attn_qkvb_global.view(3, hs)[
-                            :, int(hs * rank_in_pg / self.pg_size):
-                            int(hs * (rank_in_pg + 1) / self.pg_size)])
+                        attn_qkvb_global.view(3, hs)[:,
+                                                     int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) /
+                                                                                             self.pg_size)])
 
                 attn_ow_global = torch.empty(hs, hs)
                 nn.init.xavier_uniform_(attn_ow_global, 1.0)
@@ -230,9 +220,9 @@ class MultiHeadAttention(nn.Module):
                 torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg)
                 attn_ow_global = attn_ow_global.cpu()
                 with torch.no_grad():
-                    self.out_proj_weight.copy_(attn_ow_global[
-                        :, int(hs * rank_in_pg / self.pg_size):
-                        int(hs * (rank_in_pg + 1) / self.pg_size)])
+                    self.out_proj_weight.copy_(attn_ow_global[:,
+                                                              int(hs * rank_in_pg /
+                                                                  self.pg_size):int(hs * (rank_in_pg + 1) / self.pg_size)])
 
             else:
                 attn_qkvw = self.in_proj_weight.view(-1, hs)
@@ -243,10 +233,7 @@ class MultiHeadAttention(nn.Module):
             nn.init.xavier_uniform_(self.out_proj_weight, 1.0)
 
     def state_dict(self, destination=None, prefix="", keep_vars=False):
-        destination = torch.nn.Module.state_dict(self,
-                                                 destination=destination,
-                                                 prefix=prefix,
-                                                 keep_vars=keep_vars)
+        destination = torch.nn.Module.state_dict(self, destination=destination, prefix=prefix, keep_vars=keep_vars)
         return destination
 
     def forward(self, hidden_states, encoder_padding_mask):
@@ -257,8 +244,7 @@ class MultiHeadAttention(nn.Module):
         bs, sl, dim = hidden_states.size()
         if bs * sl > self.config.max_batch_tokens:
-            raise ValueError(
-                f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.")
+            raise ValueError(f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.")
         if sl > self.config.max_seq_len:
             raise ValueError(f"Sequence length {sl} exceeds the limit {self.config.max_seq_len}.")
         if len(encoder_padding_mask.size()) == 1:
@@ -266,9 +252,8 @@ class MultiHeadAttention(nn.Module):
         else:
             assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1)
 
-        output = MultiHeadAttention1DFunc.apply(hidden_states, encoder_padding_mask,
-                                                self.in_proj_weight, self.in_proj_bias,
-                                                self.out_proj_weight, self.out_proj_bias,
+        output = MultiHeadAttention1DFunc.apply(hidden_states, encoder_padding_mask, self.in_proj_weight,
+                                                self.in_proj_bias, self.out_proj_weight, self.out_proj_bias,
                                                 self.norm_weight, self.norm_bias, self.config)
 
         return output.to(self.precision)

View File

@@ -1,9 +1,11 @@
 """This code from NVIDIA Megatron
    with some changes. """
 
+import enum
+
 import torch
 import torch.nn as nn
-import enum
 
 
 class AttnMaskType(enum.Enum):
@@ -23,12 +24,12 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
     @staticmethod
     def forward(ctx, inputs, scale):
         try:
-            import colossal_scaled_upper_triang_masked_softmax
+            import colossalai._C.scaled_upper_triang_masked_softmax
         except ImportError:
             raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
 
         scale_t = torch.tensor([scale])
-        softmax_results = colossal_scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
+        softmax_results = colossalai._C.scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
 
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results
@@ -36,12 +37,13 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
     @staticmethod
     def backward(ctx, output_grads):
         try:
-            import colossal_scaled_upper_triang_masked_softmax
+            import colossalai._C.scaled_upper_triang_masked_softmax
         except ImportError:
             raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
 
         softmax_results, scale_t = ctx.saved_tensors
-        input_grads = colossal_scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
+        input_grads = colossalai._C.scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results,
+                                                                                scale_t[0])
 
         return input_grads, None
@@ -58,26 +60,26 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
     @staticmethod
     def forward(ctx, inputs, mask, scale):
         try:
-            import colossal_scaled_masked_softmax
+            import colossalai._C.scaled_masked_softmax
         except ImportError:
             raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
 
         scale_t = torch.tensor([scale])
-        softmax_results = colossal_scaled_masked_softmax.forward(inputs, mask, scale_t[0])
+        softmax_results = colossalai._C.scaled_masked_softmax.forward(inputs, mask, scale_t[0])
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results
 
     @staticmethod
     def backward(ctx, output_grads):
         try:
-            import colossal_scaled_masked_softmax
+            import colossalai._C.scaled_masked_softmax
         except ImportError:
             raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
 
         softmax_results, scale_t = ctx.saved_tensors
-        input_grads = colossal_scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
+        input_grads = colossalai._C.scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
 
         return input_grads, None, None
@@ -184,8 +186,8 @@ class FusedScaleMaskSoftmax(nn.Module):
     @staticmethod
     def get_batch_per_block(sq, sk, b, np):
         try:
-            import colossal_scaled_masked_softmax
+            import colossalai._C.scaled_masked_softmax
         except ImportError:
             raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
 
-        return colossal_scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
+        return colossalai._C.scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)

View File

@@ -1,153 +1,154 @@
-import torch
-import torch.distributed as dist
-from torch import Tensor
-from typing import Any, Tuple, Optional
-from torch.distributed import ProcessGroup
+from typing import Any, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+from torch.distributed import ProcessGroup
 
 COL_MOE_KERNEL_FLAG = False
 try:
-    import colossal_moe_cuda
+    import colossalai._C.moe
 
     COL_MOE_KERNEL_FLAG = True
 except ImportError:
     print("If you want to activate cuda mode for MoE, please install with cuda_ext!")
 
 
 class AllGather(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
         if ctx is not None:
             ctx.comm_grp = group
 
         comm_size = dist.get_world_size(group)
         if comm_size == 1:
             return inputs.unsqueeze(0)
 
         buffer_shape = (comm_size,) + inputs.shape
         outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
         buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
         dist.all_gather(buffer_list, inputs, group=group)
         return outputs
 
     @staticmethod
     def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
         return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None
 
 
 class ReduceScatter(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
         if ctx is not None:
             ctx.comm_grp = group
 
         comm_size = dist.get_world_size(group)
         if comm_size == 1:
             return inputs.squeeze(0)
 
         if not inputs.is_contiguous():
             inputs = inputs.contiguous()
 
         output_shape = inputs.shape[1:]
         outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
         buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
         dist.reduce_scatter(outputs, buffer_list, group=group)
         return outputs
 
     @staticmethod
     def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
         return AllGather.forward(None, grad_outputs, ctx.comm_grp), None
 
 
 class AllToAll(torch.autograd.Function):
     """Dispatches input tensor [e, c, h] to all experts by all_to_all_single
     operation in torch.distributed.
     """
 
     @staticmethod
     def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
         if ctx is not None:
             ctx.comm_grp = group
         if not inputs.is_contiguous():
             inputs = inputs.contiguous()
         if dist.get_world_size(group) == 1:
             return inputs
         output = torch.empty_like(inputs)
         dist.all_to_all_single(output, inputs, group=group)
         return output
 
     @staticmethod
     def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
         return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None
 
 
 class MoeDispatch(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, tokens, mask, dest_idx, ec):
         s = tokens.size(0)
         h = tokens.size(1)
 
-        expert_input = colossal_moe_cuda.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
+        expert_input = colossalai._C.moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
 
         ctx.save_for_backward(mask, dest_idx)
         ctx.s = s
         ctx.h = h
         ctx.ec = ec
 
         return expert_input
 
     @staticmethod
     def backward(ctx, output_grad):
         mask, dest_idx = ctx.saved_tensors
-        d_tokens = colossal_moe_cuda.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
+        d_tokens = colossalai._C.moe.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
         return d_tokens, None, None, None
 
 
 class MoeCombine(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
         assert logits.dtype == torch.float32
 
         s = logits.size(0)
         e = logits.size(1)
         c = ec // e
         h = expert_tokens.size(-1)
 
         fp16_flag = (expert_tokens.dtype == torch.float16)
         cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
-        ctokens = colossal_moe_cuda.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
+        ctokens = colossalai._C.moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
         output = ctokens.to(torch.float16) if fp16_flag else ctokens
 
         ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
         ctx.s = s
         ctx.e = e
         ctx.c = c
         ctx.h = h
         ctx.fp16_flag = fp16_flag
 
         return output
 
     @staticmethod
     def backward(ctx, tokens_grad):
         expert_tokens, logits, mask, dest_idx = ctx.saved_tensors
 
         cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \
             else tokens_grad
         cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
-        d_expert, d_logits = colossal_moe_cuda.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits,
+        d_expert, d_logits = colossalai._C.moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits,
                                                                 mask, dest_idx)
         d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert
 
         return d_expert, d_logits, None, None, None
 
 
 def moe_cumsum(inputs: Tensor):
     dim0 = inputs.size(0)
     flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
     if flag and COL_MOE_KERNEL_FLAG:
-        return colossal_moe_cuda.cumsum_sub_one(inputs)
+        return colossalai._C.moe.cumsum_sub_one(inputs)
     else:
         return torch.cumsum(inputs, dim=0) - 1

View File

@@ -1,9 +1,11 @@
 import math
+from typing import Optional
 
 import torch
 
 from colossalai.registry import OPTIMIZERS
 
 from .nvme_optimizer import NVMeOptimizer
-from typing import Optional
 
 
 @OPTIMIZERS.register_module
@@ -11,7 +13,7 @@ class CPUAdam(NVMeOptimizer):
     """Implements Adam algorithm.
 
     Supports parameters updating on both GPU and CPU, depanding on the device of paramters.
    But the parameters and gradients should on the same device:
       * Parameters on CPU and gradients on CPU is allowed.
       * Parameters on GPU and gradients on GPU is allowed.
       * Parameters on GPU and gradients on CPU is **not** allowed.
@@ -44,7 +46,7 @@ class CPUAdam(NVMeOptimizer):
             (default: False) NOT SUPPORTED yet in CPUAdam!
         adamw_mode (boolean, optional): Apply L2 regularization or weight decay
             True for decoupled weight decay(also known as AdamW) (default: True)
         simd_log (boolean, optional): whether to show if you are using SIMD to
             accelerate. (default: False)
         nvme_offload_fraction (float, optional): Fraction of optimizer states to be offloaded to NVMe. Defaults to 0.0.
         nvme_offload_dir (Optional[str], optional): Directory to save NVMe offload files.
@@ -75,10 +77,11 @@ class CPUAdam(NVMeOptimizer):
         super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
         self.adamw_mode = adamw_mode
         try:
-            import cpu_adam
+            import colossalai._C.cpu_optim
         except ImportError:
             raise ImportError('Please install colossalai from source code to use CPUAdam')
-        self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
+        self.cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay,
+                                                                    adamw_mode)
 
     def torch_adam_update(self,
                           data,
View File

@@ -20,7 +20,7 @@ class FusedAdam(torch.optim.Optimizer):
     :class:`colossalai.nn.optimizer.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,
     or ``torch.optim.Adam`` with ``adamw_mode=False``
 
     :class:`colossalai.nn.optimizer.FusedAdam` may be used with or without Amp.
 
     Adam was been proposed in `Adam: A Method for Stochastic Optimization`_.
@@ -65,10 +65,11 @@ class FusedAdam(torch.optim.Optimizer):
         self.adamw_mode = 1 if adamw_mode else 0
         self.set_grad_none = set_grad_none
         if multi_tensor_applier.available:
-            import colossal_C
+            import colossalai._C.fused_optim
             # Skip buffer
             self._dummy_overflow_buf = torch.cuda.IntTensor([0])
-            self.multi_tensor_adam = colossal_C.multi_tensor_adam
+            self.multi_tensor_adam = colossalai._C.fused_optim.multi_tensor_adam
         else:
             raise RuntimeError('FusedAdam requires cuda extensions')

View File

@@ -76,13 +76,13 @@ class FusedLAMB(torch.optim.Optimizer):
                         max_grad_norm=max_grad_norm)
         super(FusedLAMB, self).__init__(params, defaults)
         if multi_tensor_applier.available:
-            import colossal_C
-            self.multi_tensor_l2norm = colossal_C.multi_tensor_l2norm
+            import colossalai._C.fused_optim
+            self.multi_tensor_l2norm = colossalai._C.fused_optim.multi_tensor_l2norm
             # Skip buffer
             self._dummy_overflow_buf = torch.tensor([0],
                                                     dtype=torch.int,
                                                     device=self.param_groups[0]["params"][0].device)
-            self.multi_tensor_lamb = colossal_C.multi_tensor_lamb
+            self.multi_tensor_lamb = colossalai._C.fused_optim.multi_tensor_lamb
         else:
             raise RuntimeError('FusedLAMB requires cuda extensions')

View File

@@ -20,7 +20,7 @@ class FusedSGD(Optimizer):
     :class:`colossalai.nn.optimizer.FusedSGD` may be used as a drop-in replacement for ``torch.optim.SGD``
 
     :class:`colossalai.nn.optimizer.FusedSGD` may be used with or without Amp.
 
     Nesterov momentum is based on the formula from
     `On the importance of initialization and momentum in deep learning`__.
@@ -80,12 +80,13 @@ class FusedSGD(Optimizer):
         self.wd_after_momentum = wd_after_momentum
 
         if multi_tensor_applier.available:
-            import colossal_C
+            import colossalai._C.fused_optim
             # Skip buffer
             self._dummy_overflow_buf = torch.tensor([0],
                                                     dtype=torch.int,
                                                     device=self.param_groups[0]["params"][0].device)
-            self.multi_tensor_sgd = colossal_C.multi_tensor_sgd
+            self.multi_tensor_sgd = colossalai._C.fused_optim.multi_tensor_sgd
         else:
             raise RuntimeError('FusedSGD requires cuda extensions')

View File

@@ -77,14 +77,15 @@ class HybridAdam(NVMeOptimizer):
         super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
         self.adamw_mode = adamw_mode
         try:
-            import colossal_C
-            import cpu_adam
+            import colossalai._C.cpu_optim
+            import colossalai._C.fused_optim
         except ImportError:
             raise ImportError('Please install colossalai from source code to use HybridAdam')
 
-        self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
-        self.gpu_adam_op = colossal_C.multi_tensor_adam
+        self.cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay,
+                                                                    adamw_mode)
+        self.gpu_adam_op = colossalai._C.fused_optim.multi_tensor_adam
         self._dummy_overflow_buf = torch.cuda.IntTensor([0])
 
     @torch.no_grad()

View File

@@ -1,32 +1,33 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
+import functools
 import os
 import random
 import socket
 from pathlib import Path
-from typing import Callable, List, Union, Dict, Optional
-import functools
+from typing import Callable, Dict, List, Optional, Union
 
 import torch
 from torch._six import inf
 from torch.nn.parameter import Parameter
 
 try:
-    import colossal_C
+    import colossalai._C.fused_optim
 except:
     pass
 
+from collections import defaultdict
 from contextlib import contextmanager
 
 import torch.distributed as dist
-from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES)
+
+from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import tensor_parallel_env as env
-from .multi_tensor_apply import multi_tensor_applier
 from colossalai.tensor import ColoParameter, ProcessGroup
-from collections import defaultdict
+
+from .multi_tensor_apply import multi_tensor_applier
 
 
 def print_rank_0(msg: str, logger=None):
@@ -132,7 +133,7 @@ def _calc_l2_norm(grads):
     if len(grads) > 0:
         dummy_overflow_buf = torch.cuda.IntTensor([0])
         norm, _ = multi_tensor_applier(
-            colossal_C.multi_tensor_l2norm,
+            colossalai._C.fused_optim.multi_tensor_l2norm,
             dummy_overflow_buf,
             [grads],
            False    # no per-parameter norm
@@ -269,7 +270,8 @@ def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None:
                 cpu_grads.append(p.grad.detach())
         if len(cuda_grads) > 0:
             dummy_overflow_buf = torch.cuda.IntTensor([0])
-            multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads], clip_coef)
+            multi_tensor_applier(colossalai._C.fused_optim.multi_tensor_scale, dummy_overflow_buf,
+                                 [cuda_grads, cuda_grads], clip_coef)
         for g in cpu_grads:
             g.mul_(clip_coef)
@@ -395,7 +397,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
         if enable_cuda_kernels:
             grads = [p.grad.detach() for p in params]
             dummy_overflow_buf = torch.cuda.IntTensor([0])
-            multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff)
+            multi_tensor_applier(colossalai._C.fused_optim.multi_tensor_scale, dummy_overflow_buf, [grads, grads],
+                                 clip_coeff)
         else:
             for p in params:
                 p.grad.detach().mul_(clip_coeff)

View File

@@ -14,7 +14,7 @@ class MultiTensorApply(object):
     def __init__(self, chunk_size):
         try:
-            import colossal_C
+            import colossalai._C.fused_optim
             MultiTensorApply.available = True
             self.chunk_size = chunk_size
         except ImportError as err:

View File

@@ -1,7 +1,8 @@
 import os
-import subprocess
 import re
-from setuptools import find_packages, setup, Extension
+import subprocess
+
+from setuptools import Extension, find_packages, setup
 
 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
@@ -104,7 +105,7 @@ def get_version():
 if build_cuda_ext:
     try:
         import torch
-        from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CUDAExtension)
+        from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension
         print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
         TORCH_MAJOR = int(torch.__version__.split('.')[0])
         TORCH_MINOR = int(torch.__version__.split('.')[1])
@@ -148,7 +149,7 @@ if build_cuda_ext:
     extra_cuda_flags = ['-lineinfo']
 
     ext_modules.append(
-        cuda_ext_helper('colossal_C', [
+        cuda_ext_helper('colossalai._C.fused_optim', [
             'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.cu', 'multi_tensor_scale_kernel.cu',
             'multi_tensor_adam.cu', 'multi_tensor_l2norm_kernel.cu', 'multi_tensor_lamb.cu'
         ], extra_cuda_flags + cc_flag))
@@ -159,21 +160,21 @@ if build_cuda_ext:
     ]
 
     ext_modules.append(
-        cuda_ext_helper('colossal_scaled_upper_triang_masked_softmax',
+        cuda_ext_helper('colossalai._C.scaled_upper_triang_masked_softmax',
                         ['scaled_upper_triang_masked_softmax.cpp', 'scaled_upper_triang_masked_softmax_cuda.cu'],
                         extra_cuda_flags + cc_flag))
 
     ext_modules.append(
-        cuda_ext_helper('colossal_scaled_masked_softmax',
+        cuda_ext_helper('colossalai._C.scaled_masked_softmax',
                         ['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'], extra_cuda_flags + cc_flag))
 
     ext_modules.append(
-        cuda_ext_helper('colossal_moe_cuda', ['moe_cuda.cpp', 'moe_cuda_kernel.cu'], extra_cuda_flags + cc_flag))
+        cuda_ext_helper('colossalai._C.moe', ['moe_cuda.cpp', 'moe_cuda_kernel.cu'], extra_cuda_flags + cc_flag))
 
     extra_cuda_flags = ['-maxrregcount=50']
 
     ext_modules.append(
-        cuda_ext_helper('colossal_layer_norm_cuda', ['layer_norm_cuda.cpp', 'layer_norm_cuda_kernel.cu'],
+        cuda_ext_helper('colossalai._C.layer_norm', ['layer_norm_cuda.cpp', 'layer_norm_cuda_kernel.cu'],
                         extra_cuda_flags + cc_flag))
 
     extra_cuda_flags = [
@@ -182,54 +183,53 @@ if build_cuda_ext:
     ]
 
     ext_modules.append(
-        cuda_ext_helper('colossal_multihead_attention', [
+        cuda_ext_helper('colossalai._C.multihead_attention', [
             'multihead_attention_1d.cpp', 'kernels/cublas_wrappers.cu', 'kernels/transform_kernels.cu',
             'kernels/dropout_kernels.cu', 'kernels/normalize_kernels.cu', 'kernels/softmax_kernels.cu',
             'kernels/general_kernels.cu', 'kernels/cuda_util.cu'
         ], extra_cuda_flags + cc_flag))
 
     extra_cxx_flags = ['-std=c++14', '-lcudart', '-lcublas', '-g', '-Wno-reorder', '-fopenmp', '-march=native']
-    ext_modules.append(cuda_ext_helper('cpu_adam', ['cpu_adam.cpp'], extra_cuda_flags, extra_cxx_flags))
+    ext_modules.append(cuda_ext_helper('colossalai._C.cpu_optim', ['cpu_adam.cpp'], extra_cuda_flags, extra_cxx_flags))
 
-setup(
-    name='colossalai',
+setup(name='colossalai',
       version=get_version(),
       packages=find_packages(exclude=(
          'benchmark',
          'docker',
          'tests',
          'docs',
          'examples',
          'tests',
          'scripts',
          'requirements',
          '*.egg-info',
      )),
      description='An integrated large-scale model training system with efficient parallelization techniques',
      long_description=fetch_readme(),
      long_description_content_type='text/markdown',
      license='Apache Software License 2.0',
      url='https://www.colossalai.org',
      project_urls={
          'Forum': 'https://github.com/hpcaitech/ColossalAI/discussions',
          'Bug Tracker': 'https://github.com/hpcaitech/ColossalAI/issues',
          'Examples': 'https://github.com/hpcaitech/ColossalAI-Examples',
          'Documentation': 'http://colossalai.readthedocs.io',
          'Github': 'https://github.com/hpcaitech/ColossalAI',
      },
      ext_modules=ext_modules,
      cmdclass={'build_ext': BuildExtension} if ext_modules else {},
      install_requires=fetch_requirements('requirements/requirements.txt'),
      entry_points='''
        [console_scripts]
        colossalai=colossalai.cli:cli
        ''',
      python_requires='>=3.6',
      classifiers=[
          'Programming Language :: Python :: 3',
          'License :: OSI Approved :: Apache Software License',
          'Environment :: GPU :: NVIDIA CUDA',
          'Topic :: Scientific/Engineering :: Artificial Intelligence',
          'Topic :: System :: Distributed Computing',
      ],
-)
+      package_data={'colossalai': ['_C/*.pyi']})

View File

@@ -1,4 +1,5 @@
 import math
 
 import torch
 
 from colossalai.testing import parameterize
@@ -66,8 +67,8 @@ def test_cpu_adam(adamw, step, p_dtype, g_dtype):
     exp_avg_sq_copy = exp_avg_sq.clone()
 
     try:
-        import cpu_adam
-        cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)
+        import colossalai._C.cpu_optim
+        cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)
     except:
         raise ImportError("Import cpu adam error, please install colossal from source code")

View File

@@ -1,8 +1,8 @@
-from numpy import dtype
+import math
 import torch
 import torch.nn as nn
-import math
+from numpy import dtype
 
 from colossalai.testing import parameterize
 from colossalai.utils import multi_tensor_applier
@@ -47,11 +47,11 @@ def torch_adam_update(
 @parameterize('g_dtype', [torch.float, torch.half])
 def test_adam(adamw, step, p_dtype, g_dtype):
     try:
-        import colossal_C
-        fused_adam = colossal_C.multi_tensor_adam
+        import colossalai._C.fused_optim
+        fused_adam = colossalai._C.fused_optim.multi_tensor_adam
         dummy_overflow_buf = torch.cuda.IntTensor([0])
     except:
-        raise ImportError("No colossal_C kernel installed.")
+        raise ImportError("No colossalai._C.fused_optim kernel installed.")
 
     count = 0