mirror of https://github.com/hpcaitech/ColossalAI
[kernel] move all symlinks of kernel to `colossalai._C` (#1971)
parent
7e24b9b9ee
commit
f8a7148dec
|
@ -38,7 +38,6 @@ jobs:
|
||||||
pip install -r requirements/requirements.txt
|
pip install -r requirements/requirements.txt
|
||||||
pip install -v -e .
|
pip install -v -e .
|
||||||
cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
|
cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
|
||||||
cp /__w/ColossalAI/ColossalAI/*.so /github/home/cuda_ext_cache/
|
|
||||||
pip install -r requirements/requirements-test.txt
|
pip install -r requirements/requirements-test.txt
|
||||||
- name: Unit Testing
|
- name: Unit Testing
|
||||||
run: |
|
run: |
|
||||||
|
|
|
@ -36,7 +36,6 @@ jobs:
|
||||||
pip install -r requirements/requirements.txt
|
pip install -r requirements/requirements.txt
|
||||||
pip install -v -e .
|
pip install -v -e .
|
||||||
cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
|
cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
|
||||||
cp /__w/ColossalAI/ColossalAI/*.so /github/home/cuda_ext_cache/
|
|
||||||
pip install -r requirements/requirements-test.txt
|
pip install -r requirements/requirements-test.txt
|
||||||
- name: Unit Testing
|
- name: Unit Testing
|
||||||
run: |
|
run: |
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
include *.txt README.md
|
include *.txt README.md
|
||||||
recursive-include requirements *.txt
|
recursive-include requirements *.txt
|
||||||
recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc
|
recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc *.pyi
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
from . import (
|
||||||
|
cpu_optim,
|
||||||
|
fused_optim,
|
||||||
|
layer_norm,
|
||||||
|
moe,
|
||||||
|
multihead_attention,
|
||||||
|
scaled_masked_softmax,
|
||||||
|
scaled_upper_triang_masked_softmax,
|
||||||
|
)
|
|
@ -0,0 +1,8 @@
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
class CPUAdamOptimizer:
    """Signature-only declaration of the ``CPUAdamOptimizer`` extension class.

    No behaviour is implemented here; both methods have ``...`` bodies and
    exist so that type checkers and IDEs know the call signatures.
    """

    def __init__(
        self,
        lr: float,
        beta1: float,
        beta2: float,
        eps: float,
        weight_decay: float,
        # Annotated ``bool`` (previously ``float``): the name denotes an
        # on/off mode switch, not a real-valued hyper-parameter.
        # NOTE(review): confirm against the extension's constructor binding.
        adamw_mode: bool,
    ) -> None: ...

    def step(
        self,
        step: int,
        lr: float,
        beta1: float,
        beta2: float,
        eps: float,
        weight_decay: float,
        bias_correction: bool,
        param: Tensor,
        grad: Tensor,
        exp_avg: Tensor,
        exp_avg_sq: Tensor,
        loss_scale: float,
    ) -> None: ...
|
|
@ -0,0 +1,23 @@
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
def multi_tensor_scale(
    chunk_size: int,
    noop_flag: Tensor,
    tensor_lists: List[List[Tensor]],
    scale: float,
) -> None:
    """Signature-only declaration; the body is intentionally empty.

    NOTE(review): presumably the stub for the fused-optimizer extension
    function of the same name -- confirm against the extension bindings.
    """
    ...
|
||||||
|
|
||||||
|
|
||||||
|
def multi_tensor_sgd(
    chunk_size: int,
    noop_flag: Tensor,
    tensor_lists: List[List[Tensor]],
    weight_decay: float,
    momentum: float,
    dampening: float,
    lr: float,
    nesterov: bool,
    first_run: bool,
    weight_decay_after_momentum: bool,
    scale: float,
) -> None:
    """Signature-only declaration of the multi-tensor SGD entry point.

    The body is empty; only the parameter list and return annotation matter.
    """
    ...
|
||||||
|
|
||||||
|
|
||||||
|
def multi_tensor_adam(
    chunk_size: int,
    noop_flag: Tensor,
    tensor_lists: List[List[Tensor]],
    lr: float,
    beta1: float,
    beta2: float,
    epsilon: float,
    step: int,
    mode: int,
    bias_correction: int,
    weight_decay: float,
) -> None:
    """Signature-only declaration of the multi-tensor Adam entry point.

    The body is empty; only the parameter list and return annotation matter.
    """
    ...
|
||||||
|
|
||||||
|
|
||||||
|
def multi_tensor_lamb(
    chunk_size: int,
    noop_flag: Tensor,
    tensor_lists: List[List[Tensor]],
    lr: float,
    beta1: float,
    beta2: float,
    epsilon: float,
    step: int,
    bias_correction: int,
    weight_decay: float,
    grad_averaging: int,
    mode: int,
    global_grad_norm: Tensor,
    max_grad_norm: float,
    use_nvlamb_python: bool,
) -> None:
    """Signature-only declaration of the multi-tensor LAMB entry point.

    The body is empty; only the parameter list and return annotation matter.
    """
    ...
|
||||||
|
|
||||||
|
|
||||||
|
def multi_tensor_l2norm(
    chunk_size: int,
    noop_flag: Tensor,
    tensor_lists: List[List[Tensor]],
    per_tensor_python: bool,
) -> tuple[Tensor, Tensor]:
    """Signature-only declaration of the multi-tensor L2-norm entry point.

    The return annotation was ``None``; a norm computation has to hand its
    result back to the caller, so it is declared as a pair of tensors.
    NOTE(review): presumably (overall norm, per-tensor norms), as in the
    NVIDIA Apex kernel of the same name -- confirm against the binding.
    The builtin generic ``tuple[...]`` is used so that no extra ``typing``
    import is needed in this stub.
    """
    ...
|
|
@ -0,0 +1,11 @@
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
def forward_affine(
    input: Tensor,
    normalized_shape: List[int],
    gamma: Tensor,
    beta: Tensor,
    epsilon: float,
) -> List[Tensor]:
    """Signature-only declaration of the affine layer-norm forward call.

    NOTE(review): callers appear to unpack the result into three tensors
    (output, mean, inverse variance) -- confirm against the extension.
    """
    ...
|
||||||
|
|
||||||
|
|
||||||
|
def backward_affine(
    dout: Tensor,
    mean: Tensor,
    invvar: Tensor,
    input: Tensor,
    normalized_shape: List[int],
    gamma: Tensor,
    beta: Tensor,
    epsilon: float,
) -> List[Tensor]:
    """Signature-only declaration of the affine layer-norm backward call.

    NOTE(review): callers appear to unpack the result into three gradients
    (input, weight, bias) -- confirm against the extension.
    """
    ...
|
|
@ -0,0 +1,20 @@
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
def cumsum_sub_one(
    mask: Tensor,
) -> Tensor:
    """Signature-only declaration: takes one tensor and is typed to return one.

    The body is empty; the real behaviour is not defined here.
    """
    ...
|
||||||
|
|
||||||
|
|
||||||
|
def dispatch_forward(
    s: int,
    ec: int,
    h: int,
    batch_tokens: Tensor,
    mask: Tensor,
    dest_idx: Tensor,
) -> Tensor:
    """Signature-only declaration of the MoE dispatch forward call.

    The body is empty; only the parameter list and return annotation matter.
    """
    ...
|
||||||
|
|
||||||
|
|
||||||
|
def dispatch_backward(
    s: int,
    ec: int,
    h: int,
    expert_grad: Tensor,
    mask: Tensor,
    dest_idx: Tensor,
) -> Tensor:
    """Signature-only declaration of the MoE dispatch backward call.

    The body is empty; only the parameter list and return annotation matter.
    """
    ...
|
||||||
|
|
||||||
|
|
||||||
|
def combine_forward(
    s: int,
    e: int,
    c: int,
    h: int,
    expert_tokens: Tensor,
    logits: Tensor,
    mask: Tensor,
    dest_idx: Tensor,
) -> Tensor:
    """Signature-only declaration of the MoE combine forward call.

    The body is empty; only the parameter list and return annotation matter.
    """
    ...
|
||||||
|
|
||||||
|
|
||||||
|
def combine_backward(
    s: int,
    e: int,
    c: int,
    h: int,
    tokens_grad: Tensor,
    expert_tokens: Tensor,
    logits: Tensor,
    mask: Tensor,
    dest_idx: Tensor,
) -> Tensor:
    """Signature-only declaration of the MoE combine backward call.

    The body is empty; only the parameter list and return annotation matter.
    """
    ...
|
|
@ -0,0 +1,55 @@
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
|
def multihead_attention_fw_fp32(
    layer_id: int,
    input: Tensor,
    input_mask: Tensor,
    in_proj_weight: Tensor,
    in_proj_bias: Tensor,
    out_proj_weight: Tensor,
    out_proj_bias: Tensor,
    norm_weight: Tensor,
    norm_bias: Tensor,
    training_mode: bool,
    prelayernorm: bool,
) -> List[Tensor]:
    """Signature-only declaration of the fp32 multi-head attention forward.

    The body is empty; only the parameter list and return annotation matter.
    """
    ...
|
||||||
|
|
||||||
|
|
||||||
|
def multihead_attention_fw_fp16(
    layer_id: int,
    input: Tensor,
    input_mask: Tensor,
    in_proj_weight: Tensor,
    in_proj_bias: Tensor,
    out_proj_weight: Tensor,
    out_proj_bias: Tensor,
    norm_weight: Tensor,
    norm_bias: Tensor,
    training_mode: bool,
    prelayernorm: bool,
) -> List[Tensor]:
    """Signature-only declaration of the fp16 multi-head attention forward.

    The body is empty; only the parameter list and return annotation matter.
    """
    ...
|
||||||
|
|
||||||
|
|
||||||
|
def multihead_attention_bw_fp32(
    layer_id: int,
    grad_dec_output: Tensor,
    output: Tensor,
    input: Tensor,
    input_mask: Tensor,
    in_proj_weight: Tensor,
    in_proj_bias: Tensor,
    out_proj_weight: Tensor,
    out_proj_bias: Tensor,
    norm_weight: Tensor,
    norm_bias: Tensor,
) -> List[Tensor]:
    """Signature-only declaration of the fp32 multi-head attention backward.

    The body is empty; only the parameter list and return annotation matter.
    """
    ...
|
||||||
|
|
||||||
|
|
||||||
|
def multihead_attention_bw_fp16(
    layer_id: int,
    grad_dec_output: Tensor,
    output: Tensor,
    input: Tensor,
    input_mask: Tensor,
    in_proj_weight: Tensor,
    in_proj_bias: Tensor,
    out_proj_weight: Tensor,
    out_proj_bias: Tensor,
    norm_weight: Tensor,
    norm_bias: Tensor,
) -> List[Tensor]:
    """Signature-only declaration of the fp16 multi-head attention backward.

    The body is empty; only the parameter list and return annotation matter.
    """
    ...
|
||||||
|
|
||||||
|
|
||||||
|
def create_multihead_attention_fp32(
    layer_id: int,
    max_batch_tokens: int,
    max_seq_len: int,
    hidden_dim: int,
    num_heads: int,
    attn_prob_dropout_ratio: float,
    hidden_dropout_ratio: float,
    pre_or_postLayerNorm: bool,
    pg: ProcessGroup,
) -> int:
    """Signature-only declaration of the fp32 attention-layer constructor call.

    The body is empty; only the parameter list and return annotation matter.
    """
    ...
|
||||||
|
|
||||||
|
|
||||||
|
def create_multihead_attention_fp16(
    layer_id: int,
    max_batch_tokens: int,
    max_seq_len: int,
    hidden_dim: int,
    num_heads: int,
    attn_prob_dropout_ratio: float,
    hidden_dropout_ratio: float,
    pre_or_postLayerNorm: bool,
    pg: ProcessGroup,
) -> int:
    """Signature-only declaration of the fp16 attention-layer constructor call.

    The body is empty; only the parameter list and return annotation matter.
    """
    ...
|
|
@ -0,0 +1,12 @@
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
def forward(
    input: Tensor,
    mask: Tensor,
    scale: float,
) -> Tensor:
    """Signature-only declaration of the masked-softmax forward call.

    The body is empty; only the parameter list and return annotation matter.
    """
    ...
|
||||||
|
|
||||||
|
|
||||||
|
def backward(
    output_grads: Tensor,
    softmax_results: Tensor,
    scale: float,
) -> Tensor:
    """Signature-only declaration of the masked-softmax backward call.

    The body is empty; only the parameter list and return annotation matter.
    """
    ...
|
||||||
|
|
||||||
|
|
||||||
|
def get_batch_per_block(
    query_seq_len: int,
    key_seq_len: int,
    batches: int,
    attn_heads: int,
) -> int:
    """Signature-only declaration: four integers in, typed to return an integer.

    The body is empty; the real computation is not defined here.
    """
    ...
|
|
@ -0,0 +1,8 @@
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
def forward(
    input: Tensor,
    scale: float,
) -> Tensor:
    """Signature-only declaration of the mask-free softmax forward call.

    The body is empty; only the parameter list and return annotation matter.
    """
    ...
|
||||||
|
|
||||||
|
|
||||||
|
def backward(
    output_grads: Tensor,
    softmax_results: Tensor,
    scale: float,
) -> Tensor:
    """Signature-only declaration of the mask-free softmax backward call.

    The body is empty; only the parameter list and return annotation matter.
    """
    ...
|
|
@ -5,7 +5,7 @@ import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import colossal_C
|
import colossalai._C.fused_optim
|
||||||
except:
|
except:
|
||||||
print('Colossalai should be built with cuda extension to use the FP16 optimizer')
|
print('Colossalai should be built with cuda extension to use the FP16 optimizer')
|
||||||
|
|
||||||
|
@ -35,7 +35,7 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
|
||||||
if overflow_buf:
|
if overflow_buf:
|
||||||
overflow_buf.fill_(0)
|
overflow_buf.fill_(0)
|
||||||
# Scaling with factor `1.0` is equivalent to copy.
|
# Scaling with factor `1.0` is equivalent to copy.
|
||||||
multi_tensor_applier(colossal_C.multi_tensor_scale, overflow_buf, [this, that], 1.0)
|
multi_tensor_applier(colossalai._C.fused_optim.multi_tensor_scale, overflow_buf, [this, that], 1.0)
|
||||||
else:
|
else:
|
||||||
for this_, that_ in zip(this, that):
|
for this_, that_ in zip(this, that):
|
||||||
that_.copy_(this_)
|
that_.copy_(this_)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import click
|
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
||||||
|
import click
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.cpp_extension import CUDA_HOME
|
from torch.utils.cpp_extension import CUDA_HOME
|
||||||
|
|
||||||
|
@ -17,7 +18,7 @@ def check_installation():
|
||||||
|
|
||||||
def _check_cuda_extension_installed():
|
def _check_cuda_extension_installed():
|
||||||
try:
|
try:
|
||||||
import colossal_C
|
import colossalai._C.fused_optim
|
||||||
is_cuda_extension_installed = u'\u2713'
|
is_cuda_extension_installed = u'\u2713'
|
||||||
except ImportError:
|
except ImportError:
|
||||||
is_cuda_extension_installed = 'x'
|
is_cuda_extension_installed = 'x'
|
||||||
|
|
|
@ -3,14 +3,11 @@
|
||||||
with some changes. """
|
with some changes. """
|
||||||
|
|
||||||
import numbers
|
import numbers
|
||||||
import torch
|
|
||||||
from torch.nn.parameter import Parameter
|
|
||||||
from torch.nn import init
|
|
||||||
from torch.cuda.amp import custom_fwd, custom_bwd
|
|
||||||
import importlib
|
|
||||||
|
|
||||||
global colossal_layer_norm_cuda
|
import torch
|
||||||
colossal_layer_norm_cuda = None
|
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||||
|
from torch.nn import init
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
|
||||||
class FusedLayerNormAffineFunction(torch.autograd.Function):
|
class FusedLayerNormAffineFunction(torch.autograd.Function):
|
||||||
|
@ -18,13 +15,17 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@custom_fwd(cast_inputs=torch.float32)
|
@custom_fwd(cast_inputs=torch.float32)
|
||||||
def forward(ctx, input, weight, bias, normalized_shape, eps):
|
def forward(ctx, input, weight, bias, normalized_shape, eps):
|
||||||
|
try:
|
||||||
|
import colossalai._C.layer_norm
|
||||||
|
except ImportError:
|
||||||
|
raise RuntimeError('FusedLayerNormAffineFunction requires cuda extensions')
|
||||||
|
|
||||||
ctx.normalized_shape = normalized_shape
|
ctx.normalized_shape = normalized_shape
|
||||||
ctx.eps = eps
|
ctx.eps = eps
|
||||||
input_ = input.contiguous()
|
input_ = input.contiguous()
|
||||||
weight_ = weight.contiguous()
|
weight_ = weight.contiguous()
|
||||||
bias_ = bias.contiguous()
|
bias_ = bias.contiguous()
|
||||||
output, mean, invvar = colossal_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, bias_,
|
output, mean, invvar = colossalai._C.layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_,
|
||||||
ctx.eps)
|
ctx.eps)
|
||||||
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
|
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
|
||||||
|
|
||||||
|
@ -33,11 +34,15 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@custom_bwd
|
@custom_bwd
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
|
try:
|
||||||
|
import colossalai._C.layer_norm
|
||||||
|
except ImportError:
|
||||||
|
raise RuntimeError('FusedLayerNormAffineFunction requires cuda extensions')
|
||||||
|
|
||||||
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
|
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
|
||||||
grad_input = grad_weight = grad_bias = None
|
grad_input = grad_weight = grad_bias = None
|
||||||
grad_input, grad_weight, grad_bias \
|
grad_input, grad_weight, grad_bias \
|
||||||
= colossal_layer_norm_cuda.backward_affine(
|
= colossalai._C.layer_norm.backward_affine(
|
||||||
grad_output.contiguous(), mean, invvar,
|
grad_output.contiguous(), mean, invvar,
|
||||||
input_, ctx.normalized_shape,
|
input_, ctx.normalized_shape,
|
||||||
weight_, bias_, ctx.eps)
|
weight_, bias_, ctx.eps)
|
||||||
|
@ -50,13 +55,6 @@ class MixedFusedLayerNorm(torch.nn.Module):
|
||||||
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None):
|
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None):
|
||||||
super(MixedFusedLayerNorm, self).__init__()
|
super(MixedFusedLayerNorm, self).__init__()
|
||||||
|
|
||||||
global colossal_layer_norm_cuda
|
|
||||||
if colossal_layer_norm_cuda is None:
|
|
||||||
try:
|
|
||||||
colossal_layer_norm_cuda = importlib.import_module("colossal_layer_norm_cuda")
|
|
||||||
except ImportError:
|
|
||||||
raise RuntimeError('MixedFusedLayerNorm requires cuda extensions')
|
|
||||||
|
|
||||||
if isinstance(normalized_shape, numbers.Integral):
|
if isinstance(normalized_shape, numbers.Integral):
|
||||||
normalized_shape = (normalized_shape,)
|
normalized_shape = (normalized_shape,)
|
||||||
self.normalized_shape = torch.Size(normalized_shape)
|
self.normalized_shape = torch.Size(normalized_shape)
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import math
|
import math
|
||||||
import importlib
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -37,21 +36,21 @@ colossal_multihead_attention = None
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Config:
|
class Config:
|
||||||
max_batch_tokens: int # max batch token numbers
|
max_batch_tokens: int # max batch token numbers
|
||||||
max_seq_len: int # max sequence length
|
max_seq_len: int # max sequence length
|
||||||
hidden_size: int # size of transformer hidden layers
|
hidden_size: int # size of transformer hidden layers
|
||||||
nhead: int # number of heads in attention
|
nhead: int # number of heads in attention
|
||||||
attn_prob_dropout_ratio: float # attention score dropout ratio
|
attn_prob_dropout_ratio: float # attention score dropout ratio
|
||||||
hidden_dropout_ratio: float # dropout ration before residual
|
hidden_dropout_ratio: float # dropout ration before residual
|
||||||
norm_first: bool # norm_first
|
norm_first: bool # norm_first
|
||||||
fp16: bool # fp16 presion
|
fp16: bool # fp16 presion
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadAttention1DFunc(Function):
|
class MultiHeadAttention1DFunc(Function):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight,
|
def forward(ctx, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, out_proj_bias, norm_weight,
|
||||||
out_proj_bias, norm_weight, norm_bias, config):
|
norm_bias, config):
|
||||||
cuda_module = colossal_multihead_attention
|
cuda_module = colossal_multihead_attention
|
||||||
forward_func = (cuda_module.multihead_attention_fw_fp16
|
forward_func = (cuda_module.multihead_attention_fw_fp16
|
||||||
if config.fp16 else cuda_module.multihead_attention_fw_fp32)
|
if config.fp16 else cuda_module.multihead_attention_fw_fp32)
|
||||||
|
@ -59,13 +58,12 @@ class MultiHeadAttention1DFunc(Function):
|
||||||
input = input.to(torch.half)
|
input = input.to(torch.half)
|
||||||
input_mask = input_mask.to(torch.half)
|
input_mask = input_mask.to(torch.half)
|
||||||
|
|
||||||
(output,) = forward_func(config.layer_id, input, input_mask, in_proj_weight, in_proj_bias,
|
(output,) = forward_func(config.layer_id, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight,
|
||||||
out_proj_weight, out_proj_bias, norm_weight, norm_bias,
|
out_proj_bias, norm_weight, norm_bias, config.training, config.norm_first)
|
||||||
config.training, config.norm_first)
|
|
||||||
|
|
||||||
if config.is_grad_enabled and config.training:
|
if config.is_grad_enabled and config.training:
|
||||||
ctx.save_for_backward(output, input, input_mask, in_proj_weight, in_proj_bias,
|
ctx.save_for_backward(output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight,
|
||||||
out_proj_weight, out_proj_bias, norm_weight, norm_bias)
|
out_proj_bias, norm_weight, norm_bias)
|
||||||
ctx.config = config
|
ctx.config = config
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
@ -98,8 +96,8 @@ class MultiHeadAttention1DFunc(Function):
|
||||||
ctx.config.layer_id, grad_output, output, input, input_mask, in_proj_weight,
|
ctx.config.layer_id, grad_output, output, input, input_mask, in_proj_weight,
|
||||||
in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, norm_bias)
|
in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, norm_bias)
|
||||||
|
|
||||||
return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight,
|
return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias,
|
||||||
grad_out_proj_bias, grad_norm_weight, grad_norm_bias, None)
|
grad_norm_weight, grad_norm_bias, None)
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadAttention(nn.Module):
|
class MultiHeadAttention(nn.Module):
|
||||||
|
@ -121,19 +119,11 @@ class MultiHeadAttention(nn.Module):
|
||||||
|
|
||||||
layer_id = 0
|
layer_id = 0
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, norm_first=False, fp16=True, pg=None):
|
||||||
hidden_size,
|
|
||||||
nhead,
|
|
||||||
batch_size,
|
|
||||||
max_seq_len,
|
|
||||||
dropout=0.0,
|
|
||||||
norm_first=False,
|
|
||||||
fp16=True,
|
|
||||||
pg=None):
|
|
||||||
super(MultiHeadAttention, self).__init__()
|
super(MultiHeadAttention, self).__init__()
|
||||||
|
|
||||||
self.config = Config(batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout,
|
self.config = Config(batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first,
|
||||||
dropout, norm_first, fp16)
|
fp16)
|
||||||
check_config(self.config)
|
check_config(self.config)
|
||||||
self.pg = pg
|
self.pg = pg
|
||||||
self.pg_size = 1
|
self.pg_size = 1
|
||||||
|
@ -146,7 +136,8 @@ class MultiHeadAttention(nn.Module):
|
||||||
global colossal_multihead_attention
|
global colossal_multihead_attention
|
||||||
if colossal_multihead_attention is None:
|
if colossal_multihead_attention is None:
|
||||||
try:
|
try:
|
||||||
colossal_multihead_attention = importlib.import_module("colossal_multihead_attention")
|
import colossalai._C.multihead_attention
|
||||||
|
colossal_multihead_attention = colossalai._C.multihead_attention
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise RuntimeError('MultiHeadAttention requires cuda extensions')
|
raise RuntimeError('MultiHeadAttention requires cuda extensions')
|
||||||
|
|
||||||
|
@ -215,14 +206,13 @@ class MultiHeadAttention(nn.Module):
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.in_proj_weight.copy_(
|
self.in_proj_weight.copy_(
|
||||||
attn_qkvw_global.view(3, hs, hs)[
|
attn_qkvw_global.view(3, hs, hs)[:,
|
||||||
:, int(hs * rank_in_pg / self.pg_size):
|
int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) /
|
||||||
int(hs * (rank_in_pg + 1) / self.pg_size),
|
self.pg_size), :])
|
||||||
:])
|
|
||||||
self.in_proj_bias.copy_(
|
self.in_proj_bias.copy_(
|
||||||
attn_qkvb_global.view(3, hs)[
|
attn_qkvb_global.view(3, hs)[:,
|
||||||
:, int(hs * rank_in_pg / self.pg_size):
|
int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) /
|
||||||
int(hs * (rank_in_pg + 1) / self.pg_size)])
|
self.pg_size)])
|
||||||
|
|
||||||
attn_ow_global = torch.empty(hs, hs)
|
attn_ow_global = torch.empty(hs, hs)
|
||||||
nn.init.xavier_uniform_(attn_ow_global, 1.0)
|
nn.init.xavier_uniform_(attn_ow_global, 1.0)
|
||||||
|
@ -230,9 +220,9 @@ class MultiHeadAttention(nn.Module):
|
||||||
torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg)
|
torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg)
|
||||||
attn_ow_global = attn_ow_global.cpu()
|
attn_ow_global = attn_ow_global.cpu()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.out_proj_weight.copy_(attn_ow_global[
|
self.out_proj_weight.copy_(attn_ow_global[:,
|
||||||
:, int(hs * rank_in_pg / self.pg_size):
|
int(hs * rank_in_pg /
|
||||||
int(hs * (rank_in_pg + 1) / self.pg_size)])
|
self.pg_size):int(hs * (rank_in_pg + 1) / self.pg_size)])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
attn_qkvw = self.in_proj_weight.view(-1, hs)
|
attn_qkvw = self.in_proj_weight.view(-1, hs)
|
||||||
|
@ -243,10 +233,7 @@ class MultiHeadAttention(nn.Module):
|
||||||
nn.init.xavier_uniform_(self.out_proj_weight, 1.0)
|
nn.init.xavier_uniform_(self.out_proj_weight, 1.0)
|
||||||
|
|
||||||
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
||||||
destination = torch.nn.Module.state_dict(self,
|
destination = torch.nn.Module.state_dict(self, destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||||
destination=destination,
|
|
||||||
prefix=prefix,
|
|
||||||
keep_vars=keep_vars)
|
|
||||||
return destination
|
return destination
|
||||||
|
|
||||||
def forward(self, hidden_states, encoder_padding_mask):
|
def forward(self, hidden_states, encoder_padding_mask):
|
||||||
|
@ -257,8 +244,7 @@ class MultiHeadAttention(nn.Module):
|
||||||
|
|
||||||
bs, sl, dim = hidden_states.size()
|
bs, sl, dim = hidden_states.size()
|
||||||
if bs * sl > self.config.max_batch_tokens:
|
if bs * sl > self.config.max_batch_tokens:
|
||||||
raise ValueError(
|
raise ValueError(f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.")
|
||||||
f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.")
|
|
||||||
if sl > self.config.max_seq_len:
|
if sl > self.config.max_seq_len:
|
||||||
raise ValueError(f"Sequence length {sl} exceeds the limit {self.config.max_seq_len}.")
|
raise ValueError(f"Sequence length {sl} exceeds the limit {self.config.max_seq_len}.")
|
||||||
if len(encoder_padding_mask.size()) == 1:
|
if len(encoder_padding_mask.size()) == 1:
|
||||||
|
@ -266,9 +252,8 @@ class MultiHeadAttention(nn.Module):
|
||||||
else:
|
else:
|
||||||
assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1)
|
assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1)
|
||||||
|
|
||||||
output = MultiHeadAttention1DFunc.apply(hidden_states, encoder_padding_mask,
|
output = MultiHeadAttention1DFunc.apply(hidden_states, encoder_padding_mask, self.in_proj_weight,
|
||||||
self.in_proj_weight, self.in_proj_bias,
|
self.in_proj_bias, self.out_proj_weight, self.out_proj_bias,
|
||||||
self.out_proj_weight, self.out_proj_bias,
|
|
||||||
self.norm_weight, self.norm_bias, self.config)
|
self.norm_weight, self.norm_bias, self.config)
|
||||||
|
|
||||||
return output.to(self.precision)
|
return output.to(self.precision)
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
"""This code from NVIDIA Megatron
|
"""This code from NVIDIA Megatron
|
||||||
with some changes. """
|
with some changes. """
|
||||||
|
|
||||||
|
import enum
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import enum
|
|
||||||
|
|
||||||
|
|
||||||
class AttnMaskType(enum.Enum):
|
class AttnMaskType(enum.Enum):
|
||||||
|
@ -23,12 +24,12 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, inputs, scale):
|
def forward(ctx, inputs, scale):
|
||||||
try:
|
try:
|
||||||
import colossal_scaled_upper_triang_masked_softmax
|
import colossalai._C.scaled_upper_triang_masked_softmax
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
|
raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
|
||||||
|
|
||||||
scale_t = torch.tensor([scale])
|
scale_t = torch.tensor([scale])
|
||||||
softmax_results = colossal_scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
|
softmax_results = colossalai._C.scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
|
||||||
|
|
||||||
ctx.save_for_backward(softmax_results, scale_t)
|
ctx.save_for_backward(softmax_results, scale_t)
|
||||||
return softmax_results
|
return softmax_results
|
||||||
|
@ -36,12 +37,13 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, output_grads):
|
def backward(ctx, output_grads):
|
||||||
try:
|
try:
|
||||||
import colossal_scaled_upper_triang_masked_softmax
|
import colossalai._C.scaled_upper_triang_masked_softmax
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
|
raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
|
||||||
|
|
||||||
softmax_results, scale_t = ctx.saved_tensors
|
softmax_results, scale_t = ctx.saved_tensors
|
||||||
input_grads = colossal_scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
|
input_grads = colossalai._C.scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results,
|
||||||
|
scale_t[0])
|
||||||
|
|
||||||
return input_grads, None
|
return input_grads, None
|
||||||
|
|
||||||
|
@ -58,26 +60,26 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, inputs, mask, scale):
|
def forward(ctx, inputs, mask, scale):
|
||||||
try:
|
try:
|
||||||
import colossal_scaled_masked_softmax
|
import colossalai._C.scaled_masked_softmax
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
|
raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
|
||||||
|
|
||||||
scale_t = torch.tensor([scale])
|
scale_t = torch.tensor([scale])
|
||||||
|
|
||||||
softmax_results = colossal_scaled_masked_softmax.forward(inputs, mask, scale_t[0])
|
softmax_results = colossalai._C.scaled_masked_softmax.forward(inputs, mask, scale_t[0])
|
||||||
ctx.save_for_backward(softmax_results, scale_t)
|
ctx.save_for_backward(softmax_results, scale_t)
|
||||||
return softmax_results
|
return softmax_results
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, output_grads):
|
def backward(ctx, output_grads):
|
||||||
try:
|
try:
|
||||||
import colossal_scaled_masked_softmax
|
import colossalai._C.scaled_masked_softmax
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
|
raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
|
||||||
|
|
||||||
softmax_results, scale_t = ctx.saved_tensors
|
softmax_results, scale_t = ctx.saved_tensors
|
||||||
|
|
||||||
input_grads = colossal_scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
|
input_grads = colossalai._C.scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
|
||||||
return input_grads, None, None
|
return input_grads, None, None
|
||||||
|
|
||||||
|
|
||||||
|
@ -184,8 +186,8 @@ class FusedScaleMaskSoftmax(nn.Module):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_batch_per_block(sq, sk, b, np):
|
def get_batch_per_block(sq, sk, b, np):
|
||||||
try:
|
try:
|
||||||
import colossal_scaled_masked_softmax
|
import colossalai._C.scaled_masked_softmax
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
|
raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
|
||||||
|
|
||||||
return colossal_scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
|
return colossalai._C.scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
|
||||||
|
|
|
@ -1,153 +1,154 @@
|
||||||
import torch
|
from typing import Any, Optional, Tuple
|
||||||
import torch.distributed as dist
|
|
||||||
from torch import Tensor
|
import torch
|
||||||
from typing import Any, Tuple, Optional
|
import torch.distributed as dist
|
||||||
from torch.distributed import ProcessGroup
|
from torch import Tensor
|
||||||
|
from torch.distributed import ProcessGroup
|
||||||
COL_MOE_KERNEL_FLAG = False
|
|
||||||
try:
|
COL_MOE_KERNEL_FLAG = False
|
||||||
import colossal_moe_cuda
|
try:
|
||||||
|
import colossalai._C.moe
|
||||||
COL_MOE_KERNEL_FLAG = True
|
|
||||||
except ImportError:
|
COL_MOE_KERNEL_FLAG = True
|
||||||
print("If you want to activate cuda mode for MoE, please install with cuda_ext!")
|
except ImportError:
|
||||||
|
print("If you want to activate cuda mode for MoE, please install with cuda_ext!")
|
||||||
|
|
||||||
class AllGather(torch.autograd.Function):
|
|
||||||
|
class AllGather(torch.autograd.Function):
|
||||||
@staticmethod
|
|
||||||
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
|
@staticmethod
|
||||||
if ctx is not None:
|
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
|
||||||
ctx.comm_grp = group
|
if ctx is not None:
|
||||||
|
ctx.comm_grp = group
|
||||||
comm_size = dist.get_world_size(group)
|
|
||||||
if comm_size == 1:
|
comm_size = dist.get_world_size(group)
|
||||||
return inputs.unsqueeze(0)
|
if comm_size == 1:
|
||||||
|
return inputs.unsqueeze(0)
|
||||||
buffer_shape = (comm_size,) + inputs.shape
|
|
||||||
outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
|
buffer_shape = (comm_size,) + inputs.shape
|
||||||
buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
|
outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
|
||||||
dist.all_gather(buffer_list, inputs, group=group)
|
buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
|
||||||
return outputs
|
dist.all_gather(buffer_list, inputs, group=group)
|
||||||
|
return outputs
|
||||||
@staticmethod
|
|
||||||
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
|
@staticmethod
|
||||||
return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None
|
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
|
||||||
|
return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None
|
||||||
|
|
||||||
class ReduceScatter(torch.autograd.Function):
|
|
||||||
|
class ReduceScatter(torch.autograd.Function):
|
||||||
@staticmethod
|
|
||||||
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
|
@staticmethod
|
||||||
if ctx is not None:
|
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
|
||||||
ctx.comm_grp = group
|
if ctx is not None:
|
||||||
|
ctx.comm_grp = group
|
||||||
comm_size = dist.get_world_size(group)
|
|
||||||
if comm_size == 1:
|
comm_size = dist.get_world_size(group)
|
||||||
return inputs.squeeze(0)
|
if comm_size == 1:
|
||||||
|
return inputs.squeeze(0)
|
||||||
if not inputs.is_contiguous():
|
|
||||||
inputs = inputs.contiguous()
|
if not inputs.is_contiguous():
|
||||||
|
inputs = inputs.contiguous()
|
||||||
output_shape = inputs.shape[1:]
|
|
||||||
outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
|
output_shape = inputs.shape[1:]
|
||||||
buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
|
outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
|
||||||
dist.reduce_scatter(outputs, buffer_list, group=group)
|
buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
|
||||||
return outputs
|
dist.reduce_scatter(outputs, buffer_list, group=group)
|
||||||
|
return outputs
|
||||||
@staticmethod
|
|
||||||
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
|
@staticmethod
|
||||||
return AllGather.forward(None, grad_outputs, ctx.comm_grp), None
|
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
|
||||||
|
return AllGather.forward(None, grad_outputs, ctx.comm_grp), None
|
||||||
|
|
||||||
class AllToAll(torch.autograd.Function):
|
|
||||||
"""Dispatches input tensor [e, c, h] to all experts by all_to_all_single
|
class AllToAll(torch.autograd.Function):
|
||||||
operation in torch.distributed.
|
"""Dispatches input tensor [e, c, h] to all experts by all_to_all_single
|
||||||
"""
|
operation in torch.distributed.
|
||||||
|
"""
|
||||||
@staticmethod
|
|
||||||
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
|
@staticmethod
|
||||||
if ctx is not None:
|
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
|
||||||
ctx.comm_grp = group
|
if ctx is not None:
|
||||||
if not inputs.is_contiguous():
|
ctx.comm_grp = group
|
||||||
inputs = inputs.contiguous()
|
if not inputs.is_contiguous():
|
||||||
if dist.get_world_size(group) == 1:
|
inputs = inputs.contiguous()
|
||||||
return inputs
|
if dist.get_world_size(group) == 1:
|
||||||
output = torch.empty_like(inputs)
|
return inputs
|
||||||
dist.all_to_all_single(output, inputs, group=group)
|
output = torch.empty_like(inputs)
|
||||||
return output
|
dist.all_to_all_single(output, inputs, group=group)
|
||||||
|
return output
|
||||||
@staticmethod
|
|
||||||
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
|
@staticmethod
|
||||||
return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None
|
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
|
||||||
|
return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None
|
||||||
|
|
||||||
class MoeDispatch(torch.autograd.Function):
|
|
||||||
|
class MoeDispatch(torch.autograd.Function):
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, tokens, mask, dest_idx, ec):
|
@staticmethod
|
||||||
s = tokens.size(0)
|
def forward(ctx, tokens, mask, dest_idx, ec):
|
||||||
h = tokens.size(1)
|
s = tokens.size(0)
|
||||||
|
h = tokens.size(1)
|
||||||
expert_input = colossal_moe_cuda.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
|
|
||||||
|
expert_input = colossalai._C.moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
|
||||||
ctx.save_for_backward(mask, dest_idx)
|
|
||||||
ctx.s = s
|
ctx.save_for_backward(mask, dest_idx)
|
||||||
ctx.h = h
|
ctx.s = s
|
||||||
ctx.ec = ec
|
ctx.h = h
|
||||||
|
ctx.ec = ec
|
||||||
return expert_input
|
|
||||||
|
return expert_input
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, output_grad):
|
@staticmethod
|
||||||
mask, dest_idx = ctx.saved_tensors
|
def backward(ctx, output_grad):
|
||||||
d_tokens = colossal_moe_cuda.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
|
mask, dest_idx = ctx.saved_tensors
|
||||||
return d_tokens, None, None, None
|
d_tokens = colossalai._C.moe.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
|
||||||
|
return d_tokens, None, None, None
|
||||||
|
|
||||||
class MoeCombine(torch.autograd.Function):
|
|
||||||
|
class MoeCombine(torch.autograd.Function):
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
|
@staticmethod
|
||||||
assert logits.dtype == torch.float32
|
def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
|
||||||
|
assert logits.dtype == torch.float32
|
||||||
s = logits.size(0)
|
|
||||||
e = logits.size(1)
|
s = logits.size(0)
|
||||||
c = ec // e
|
e = logits.size(1)
|
||||||
h = expert_tokens.size(-1)
|
c = ec // e
|
||||||
|
h = expert_tokens.size(-1)
|
||||||
fp16_flag = (expert_tokens.dtype == torch.float16)
|
|
||||||
cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
|
fp16_flag = (expert_tokens.dtype == torch.float16)
|
||||||
ctokens = colossal_moe_cuda.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
|
cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
|
||||||
output = ctokens.to(torch.float16) if fp16_flag else ctokens
|
ctokens = colossalai._C.moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
|
||||||
|
output = ctokens.to(torch.float16) if fp16_flag else ctokens
|
||||||
ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
|
|
||||||
ctx.s = s
|
ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
|
||||||
ctx.e = e
|
ctx.s = s
|
||||||
ctx.c = c
|
ctx.e = e
|
||||||
ctx.h = h
|
ctx.c = c
|
||||||
ctx.fp16_flag = fp16_flag
|
ctx.h = h
|
||||||
|
ctx.fp16_flag = fp16_flag
|
||||||
return output
|
|
||||||
|
return output
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, tokens_grad):
|
@staticmethod
|
||||||
expert_tokens, logits, mask, dest_idx = ctx.saved_tensors
|
def backward(ctx, tokens_grad):
|
||||||
|
expert_tokens, logits, mask, dest_idx = ctx.saved_tensors
|
||||||
cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \
|
|
||||||
else tokens_grad
|
cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \
|
||||||
cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
|
else tokens_grad
|
||||||
d_expert, d_logits = colossal_moe_cuda.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits,
|
cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
|
||||||
mask, dest_idx)
|
d_expert, d_logits = colossalai._C.moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits,
|
||||||
d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert
|
mask, dest_idx)
|
||||||
|
d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert
|
||||||
return d_expert, d_logits, None, None, None
|
|
||||||
|
return d_expert, d_logits, None, None, None
|
||||||
|
|
||||||
def moe_cumsum(inputs: Tensor):
|
|
||||||
dim0 = inputs.size(0)
|
def moe_cumsum(inputs: Tensor):
|
||||||
flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
|
dim0 = inputs.size(0)
|
||||||
if flag and COL_MOE_KERNEL_FLAG:
|
flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
|
||||||
return colossal_moe_cuda.cumsum_sub_one(inputs)
|
if flag and COL_MOE_KERNEL_FLAG:
|
||||||
else:
|
return colossalai._C.moe.cumsum_sub_one(inputs)
|
||||||
return torch.cumsum(inputs, dim=0) - 1
|
else:
|
||||||
|
return torch.cumsum(inputs, dim=0) - 1
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
import math
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from colossalai.registry import OPTIMIZERS
|
from colossalai.registry import OPTIMIZERS
|
||||||
|
|
||||||
from .nvme_optimizer import NVMeOptimizer
|
from .nvme_optimizer import NVMeOptimizer
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
|
|
||||||
@OPTIMIZERS.register_module
|
@OPTIMIZERS.register_module
|
||||||
|
@ -11,7 +13,7 @@ class CPUAdam(NVMeOptimizer):
|
||||||
"""Implements Adam algorithm.
|
"""Implements Adam algorithm.
|
||||||
|
|
||||||
Supports parameters updating on both GPU and CPU, depanding on the device of paramters.
|
Supports parameters updating on both GPU and CPU, depanding on the device of paramters.
|
||||||
But the parameters and gradients should on the same device:
|
But the parameters and gradients should on the same device:
|
||||||
* Parameters on CPU and gradients on CPU is allowed.
|
* Parameters on CPU and gradients on CPU is allowed.
|
||||||
* Parameters on GPU and gradients on GPU is allowed.
|
* Parameters on GPU and gradients on GPU is allowed.
|
||||||
* Parameters on GPU and gradients on CPU is **not** allowed.
|
* Parameters on GPU and gradients on CPU is **not** allowed.
|
||||||
|
@ -44,7 +46,7 @@ class CPUAdam(NVMeOptimizer):
|
||||||
(default: False) NOT SUPPORTED yet in CPUAdam!
|
(default: False) NOT SUPPORTED yet in CPUAdam!
|
||||||
adamw_mode (boolean, optional): Apply L2 regularization or weight decay
|
adamw_mode (boolean, optional): Apply L2 regularization or weight decay
|
||||||
True for decoupled weight decay(also known as AdamW) (default: True)
|
True for decoupled weight decay(also known as AdamW) (default: True)
|
||||||
simd_log (boolean, optional): whether to show if you are using SIMD to
|
simd_log (boolean, optional): whether to show if you are using SIMD to
|
||||||
accelerate. (default: False)
|
accelerate. (default: False)
|
||||||
nvme_offload_fraction (float, optional): Fraction of optimizer states to be offloaded to NVMe. Defaults to 0.0.
|
nvme_offload_fraction (float, optional): Fraction of optimizer states to be offloaded to NVMe. Defaults to 0.0.
|
||||||
nvme_offload_dir (Optional[str], optional): Directory to save NVMe offload files.
|
nvme_offload_dir (Optional[str], optional): Directory to save NVMe offload files.
|
||||||
|
@ -75,10 +77,11 @@ class CPUAdam(NVMeOptimizer):
|
||||||
super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
|
super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
|
||||||
self.adamw_mode = adamw_mode
|
self.adamw_mode = adamw_mode
|
||||||
try:
|
try:
|
||||||
import cpu_adam
|
import colossalai._C.cpu_optim
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError('Please install colossalai from source code to use CPUAdam')
|
raise ImportError('Please install colossalai from source code to use CPUAdam')
|
||||||
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
|
self.cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay,
|
||||||
|
adamw_mode)
|
||||||
|
|
||||||
def torch_adam_update(self,
|
def torch_adam_update(self,
|
||||||
data,
|
data,
|
||||||
|
|
|
@ -20,7 +20,7 @@ class FusedAdam(torch.optim.Optimizer):
|
||||||
:class:`colossalai.nn.optimizer.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,
|
:class:`colossalai.nn.optimizer.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,
|
||||||
or ``torch.optim.Adam`` with ``adamw_mode=False``
|
or ``torch.optim.Adam`` with ``adamw_mode=False``
|
||||||
|
|
||||||
:class:`colossalai.nn.optimizer.FusedAdam` may be used with or without Amp.
|
:class:`colossalai.nn.optimizer.FusedAdam` may be used with or without Amp.
|
||||||
|
|
||||||
Adam was been proposed in `Adam: A Method for Stochastic Optimization`_.
|
Adam was been proposed in `Adam: A Method for Stochastic Optimization`_.
|
||||||
|
|
||||||
|
@ -65,10 +65,11 @@ class FusedAdam(torch.optim.Optimizer):
|
||||||
self.adamw_mode = 1 if adamw_mode else 0
|
self.adamw_mode = 1 if adamw_mode else 0
|
||||||
self.set_grad_none = set_grad_none
|
self.set_grad_none = set_grad_none
|
||||||
if multi_tensor_applier.available:
|
if multi_tensor_applier.available:
|
||||||
import colossal_C
|
import colossalai._C.fused_optim
|
||||||
|
|
||||||
# Skip buffer
|
# Skip buffer
|
||||||
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
|
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||||
self.multi_tensor_adam = colossal_C.multi_tensor_adam
|
self.multi_tensor_adam = colossalai._C.fused_optim.multi_tensor_adam
|
||||||
else:
|
else:
|
||||||
raise RuntimeError('FusedAdam requires cuda extensions')
|
raise RuntimeError('FusedAdam requires cuda extensions')
|
||||||
|
|
||||||
|
|
|
@ -76,13 +76,13 @@ class FusedLAMB(torch.optim.Optimizer):
|
||||||
max_grad_norm=max_grad_norm)
|
max_grad_norm=max_grad_norm)
|
||||||
super(FusedLAMB, self).__init__(params, defaults)
|
super(FusedLAMB, self).__init__(params, defaults)
|
||||||
if multi_tensor_applier.available:
|
if multi_tensor_applier.available:
|
||||||
import colossal_C
|
import colossalai._C.fused_optim
|
||||||
self.multi_tensor_l2norm = colossal_C.multi_tensor_l2norm
|
self.multi_tensor_l2norm = colossalai._C.fused_optim.multi_tensor_l2norm
|
||||||
# Skip buffer
|
# Skip buffer
|
||||||
self._dummy_overflow_buf = torch.tensor([0],
|
self._dummy_overflow_buf = torch.tensor([0],
|
||||||
dtype=torch.int,
|
dtype=torch.int,
|
||||||
device=self.param_groups[0]["params"][0].device)
|
device=self.param_groups[0]["params"][0].device)
|
||||||
self.multi_tensor_lamb = colossal_C.multi_tensor_lamb
|
self.multi_tensor_lamb = colossalai._C.fused_optim.multi_tensor_lamb
|
||||||
else:
|
else:
|
||||||
raise RuntimeError('FusedLAMB requires cuda extensions')
|
raise RuntimeError('FusedLAMB requires cuda extensions')
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ class FusedSGD(Optimizer):
|
||||||
|
|
||||||
:class:`colossalai.nn.optimizer.FusedSGD` may be used as a drop-in replacement for ``torch.optim.SGD``
|
:class:`colossalai.nn.optimizer.FusedSGD` may be used as a drop-in replacement for ``torch.optim.SGD``
|
||||||
|
|
||||||
:class:`colossalai.nn.optimizer.FusedSGD` may be used with or without Amp.
|
:class:`colossalai.nn.optimizer.FusedSGD` may be used with or without Amp.
|
||||||
|
|
||||||
Nesterov momentum is based on the formula from
|
Nesterov momentum is based on the formula from
|
||||||
`On the importance of initialization and momentum in deep learning`__.
|
`On the importance of initialization and momentum in deep learning`__.
|
||||||
|
@ -80,12 +80,13 @@ class FusedSGD(Optimizer):
|
||||||
self.wd_after_momentum = wd_after_momentum
|
self.wd_after_momentum = wd_after_momentum
|
||||||
|
|
||||||
if multi_tensor_applier.available:
|
if multi_tensor_applier.available:
|
||||||
import colossal_C
|
import colossalai._C.fused_optim
|
||||||
|
|
||||||
# Skip buffer
|
# Skip buffer
|
||||||
self._dummy_overflow_buf = torch.tensor([0],
|
self._dummy_overflow_buf = torch.tensor([0],
|
||||||
dtype=torch.int,
|
dtype=torch.int,
|
||||||
device=self.param_groups[0]["params"][0].device)
|
device=self.param_groups[0]["params"][0].device)
|
||||||
self.multi_tensor_sgd = colossal_C.multi_tensor_sgd
|
self.multi_tensor_sgd = colossalai._C.fused_optim.multi_tensor_sgd
|
||||||
else:
|
else:
|
||||||
raise RuntimeError('FusedSGD requires cuda extensions')
|
raise RuntimeError('FusedSGD requires cuda extensions')
|
||||||
|
|
||||||
|
|
|
@ -77,14 +77,15 @@ class HybridAdam(NVMeOptimizer):
|
||||||
super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
|
super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
|
||||||
self.adamw_mode = adamw_mode
|
self.adamw_mode = adamw_mode
|
||||||
try:
|
try:
|
||||||
import colossal_C
|
import colossalai._C.cpu_optim
|
||||||
import cpu_adam
|
import colossalai._C.fused_optim
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError('Please install colossalai from source code to use HybridAdam')
|
raise ImportError('Please install colossalai from source code to use HybridAdam')
|
||||||
|
|
||||||
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
|
self.cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay,
|
||||||
|
adamw_mode)
|
||||||
|
|
||||||
self.gpu_adam_op = colossal_C.multi_tensor_adam
|
self.gpu_adam_op = colossalai._C.fused_optim.multi_tensor_adam
|
||||||
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
|
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|
|
@ -1,32 +1,33 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- encoding: utf-8 -*-
|
# -*- encoding: utf-8 -*-
|
||||||
|
import functools
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import socket
|
import socket
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, List, Union, Dict, Optional
|
from typing import Callable, Dict, List, Optional, Union
|
||||||
import functools
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch._six import inf
|
from torch._six import inf
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import colossal_C
|
import colossalai._C.fused_optim
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES)
|
|
||||||
|
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES
|
||||||
from colossalai.context.parallel_mode import ParallelMode
|
from colossalai.context.parallel_mode import ParallelMode
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.global_variables import tensor_parallel_env as env
|
from colossalai.global_variables import tensor_parallel_env as env
|
||||||
from .multi_tensor_apply import multi_tensor_applier
|
|
||||||
|
|
||||||
from colossalai.tensor import ColoParameter, ProcessGroup
|
from colossalai.tensor import ColoParameter, ProcessGroup
|
||||||
from collections import defaultdict
|
|
||||||
|
from .multi_tensor_apply import multi_tensor_applier
|
||||||
|
|
||||||
|
|
||||||
def print_rank_0(msg: str, logger=None):
|
def print_rank_0(msg: str, logger=None):
|
||||||
|
@ -132,7 +133,7 @@ def _calc_l2_norm(grads):
|
||||||
if len(grads) > 0:
|
if len(grads) > 0:
|
||||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||||
norm, _ = multi_tensor_applier(
|
norm, _ = multi_tensor_applier(
|
||||||
colossal_C.multi_tensor_l2norm,
|
colossalai._C.fused_optim.multi_tensor_l2norm,
|
||||||
dummy_overflow_buf,
|
dummy_overflow_buf,
|
||||||
[grads],
|
[grads],
|
||||||
False # no per-parameter norm
|
False # no per-parameter norm
|
||||||
|
@ -269,7 +270,8 @@ def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None:
|
||||||
cpu_grads.append(p.grad.detach())
|
cpu_grads.append(p.grad.detach())
|
||||||
if len(cuda_grads) > 0:
|
if len(cuda_grads) > 0:
|
||||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||||
multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads], clip_coef)
|
multi_tensor_applier(colossalai._C.fused_optim.multi_tensor_scale, dummy_overflow_buf,
|
||||||
|
[cuda_grads, cuda_grads], clip_coef)
|
||||||
for g in cpu_grads:
|
for g in cpu_grads:
|
||||||
g.mul_(clip_coef)
|
g.mul_(clip_coef)
|
||||||
|
|
||||||
|
@ -395,7 +397,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
|
||||||
if enable_cuda_kernels:
|
if enable_cuda_kernels:
|
||||||
grads = [p.grad.detach() for p in params]
|
grads = [p.grad.detach() for p in params]
|
||||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||||
multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff)
|
multi_tensor_applier(colossalai._C.fused_optim.multi_tensor_scale, dummy_overflow_buf, [grads, grads],
|
||||||
|
clip_coeff)
|
||||||
else:
|
else:
|
||||||
for p in params:
|
for p in params:
|
||||||
p.grad.detach().mul_(clip_coeff)
|
p.grad.detach().mul_(clip_coeff)
|
||||||
|
|
|
@ -14,7 +14,7 @@ class MultiTensorApply(object):
|
||||||
|
|
||||||
def __init__(self, chunk_size):
|
def __init__(self, chunk_size):
|
||||||
try:
|
try:
|
||||||
import colossal_C
|
import colossalai._C.fused_optim
|
||||||
MultiTensorApply.available = True
|
MultiTensorApply.available = True
|
||||||
self.chunk_size = chunk_size
|
self.chunk_size = chunk_size
|
||||||
except ImportError as err:
|
except ImportError as err:
|
||||||
|
|
98
setup.py
98
setup.py
|
@ -1,7 +1,8 @@
|
||||||
import os
|
import os
|
||||||
import subprocess
|
|
||||||
import re
|
import re
|
||||||
from setuptools import find_packages, setup, Extension
|
import subprocess
|
||||||
|
|
||||||
|
from setuptools import Extension, find_packages, setup
|
||||||
|
|
||||||
# ninja build does not work unless include_dirs are abs path
|
# ninja build does not work unless include_dirs are abs path
|
||||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
@ -104,7 +105,7 @@ def get_version():
|
||||||
if build_cuda_ext:
|
if build_cuda_ext:
|
||||||
try:
|
try:
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CUDAExtension)
|
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension
|
||||||
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
|
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
|
||||||
TORCH_MAJOR = int(torch.__version__.split('.')[0])
|
TORCH_MAJOR = int(torch.__version__.split('.')[0])
|
||||||
TORCH_MINOR = int(torch.__version__.split('.')[1])
|
TORCH_MINOR = int(torch.__version__.split('.')[1])
|
||||||
|
@ -148,7 +149,7 @@ if build_cuda_ext:
|
||||||
extra_cuda_flags = ['-lineinfo']
|
extra_cuda_flags = ['-lineinfo']
|
||||||
|
|
||||||
ext_modules.append(
|
ext_modules.append(
|
||||||
cuda_ext_helper('colossal_C', [
|
cuda_ext_helper('colossalai._C.fused_optim', [
|
||||||
'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.cu', 'multi_tensor_scale_kernel.cu',
|
'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.cu', 'multi_tensor_scale_kernel.cu',
|
||||||
'multi_tensor_adam.cu', 'multi_tensor_l2norm_kernel.cu', 'multi_tensor_lamb.cu'
|
'multi_tensor_adam.cu', 'multi_tensor_l2norm_kernel.cu', 'multi_tensor_lamb.cu'
|
||||||
], extra_cuda_flags + cc_flag))
|
], extra_cuda_flags + cc_flag))
|
||||||
|
@ -159,21 +160,21 @@ if build_cuda_ext:
|
||||||
]
|
]
|
||||||
|
|
||||||
ext_modules.append(
|
ext_modules.append(
|
||||||
cuda_ext_helper('colossal_scaled_upper_triang_masked_softmax',
|
cuda_ext_helper('colossalai._C.scaled_upper_triang_masked_softmax',
|
||||||
['scaled_upper_triang_masked_softmax.cpp', 'scaled_upper_triang_masked_softmax_cuda.cu'],
|
['scaled_upper_triang_masked_softmax.cpp', 'scaled_upper_triang_masked_softmax_cuda.cu'],
|
||||||
extra_cuda_flags + cc_flag))
|
extra_cuda_flags + cc_flag))
|
||||||
|
|
||||||
ext_modules.append(
|
ext_modules.append(
|
||||||
cuda_ext_helper('colossal_scaled_masked_softmax',
|
cuda_ext_helper('colossalai._C.scaled_masked_softmax',
|
||||||
['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'], extra_cuda_flags + cc_flag))
|
['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'], extra_cuda_flags + cc_flag))
|
||||||
|
|
||||||
ext_modules.append(
|
ext_modules.append(
|
||||||
cuda_ext_helper('colossal_moe_cuda', ['moe_cuda.cpp', 'moe_cuda_kernel.cu'], extra_cuda_flags + cc_flag))
|
cuda_ext_helper('colossalai._C.moe', ['moe_cuda.cpp', 'moe_cuda_kernel.cu'], extra_cuda_flags + cc_flag))
|
||||||
|
|
||||||
extra_cuda_flags = ['-maxrregcount=50']
|
extra_cuda_flags = ['-maxrregcount=50']
|
||||||
|
|
||||||
ext_modules.append(
|
ext_modules.append(
|
||||||
cuda_ext_helper('colossal_layer_norm_cuda', ['layer_norm_cuda.cpp', 'layer_norm_cuda_kernel.cu'],
|
cuda_ext_helper('colossalai._C.layer_norm', ['layer_norm_cuda.cpp', 'layer_norm_cuda_kernel.cu'],
|
||||||
extra_cuda_flags + cc_flag))
|
extra_cuda_flags + cc_flag))
|
||||||
|
|
||||||
extra_cuda_flags = [
|
extra_cuda_flags = [
|
||||||
|
@ -182,54 +183,53 @@ if build_cuda_ext:
|
||||||
]
|
]
|
||||||
|
|
||||||
ext_modules.append(
|
ext_modules.append(
|
||||||
cuda_ext_helper('colossal_multihead_attention', [
|
cuda_ext_helper('colossalai._C.multihead_attention', [
|
||||||
'multihead_attention_1d.cpp', 'kernels/cublas_wrappers.cu', 'kernels/transform_kernels.cu',
|
'multihead_attention_1d.cpp', 'kernels/cublas_wrappers.cu', 'kernels/transform_kernels.cu',
|
||||||
'kernels/dropout_kernels.cu', 'kernels/normalize_kernels.cu', 'kernels/softmax_kernels.cu',
|
'kernels/dropout_kernels.cu', 'kernels/normalize_kernels.cu', 'kernels/softmax_kernels.cu',
|
||||||
'kernels/general_kernels.cu', 'kernels/cuda_util.cu'
|
'kernels/general_kernels.cu', 'kernels/cuda_util.cu'
|
||||||
], extra_cuda_flags + cc_flag))
|
], extra_cuda_flags + cc_flag))
|
||||||
|
|
||||||
extra_cxx_flags = ['-std=c++14', '-lcudart', '-lcublas', '-g', '-Wno-reorder', '-fopenmp', '-march=native']
|
extra_cxx_flags = ['-std=c++14', '-lcudart', '-lcublas', '-g', '-Wno-reorder', '-fopenmp', '-march=native']
|
||||||
ext_modules.append(cuda_ext_helper('cpu_adam', ['cpu_adam.cpp'], extra_cuda_flags, extra_cxx_flags))
|
ext_modules.append(cuda_ext_helper('colossalai._C.cpu_optim', ['cpu_adam.cpp'], extra_cuda_flags, extra_cxx_flags))
|
||||||
|
|
||||||
setup(
|
setup(name='colossalai',
|
||||||
name='colossalai',
|
version=get_version(),
|
||||||
version=get_version(),
|
packages=find_packages(exclude=(
|
||||||
packages=find_packages(exclude=(
|
'benchmark',
|
||||||
'benchmark',
|
'docker',
|
||||||
'docker',
|
'tests',
|
||||||
'tests',
|
'docs',
|
||||||
'docs',
|
'examples',
|
||||||
'examples',
|
'tests',
|
||||||
'tests',
|
'scripts',
|
||||||
'scripts',
|
'requirements',
|
||||||
'requirements',
|
'*.egg-info',
|
||||||
'*.egg-info',
|
)),
|
||||||
)),
|
description='An integrated large-scale model training system with efficient parallelization techniques',
|
||||||
description='An integrated large-scale model training system with efficient parallelization techniques',
|
long_description=fetch_readme(),
|
||||||
long_description=fetch_readme(),
|
long_description_content_type='text/markdown',
|
||||||
long_description_content_type='text/markdown',
|
license='Apache Software License 2.0',
|
||||||
license='Apache Software License 2.0',
|
url='https://www.colossalai.org',
|
||||||
url='https://www.colossalai.org',
|
project_urls={
|
||||||
project_urls={
|
'Forum': 'https://github.com/hpcaitech/ColossalAI/discussions',
|
||||||
'Forum': 'https://github.com/hpcaitech/ColossalAI/discussions',
|
'Bug Tracker': 'https://github.com/hpcaitech/ColossalAI/issues',
|
||||||
'Bug Tracker': 'https://github.com/hpcaitech/ColossalAI/issues',
|
'Examples': 'https://github.com/hpcaitech/ColossalAI-Examples',
|
||||||
'Examples': 'https://github.com/hpcaitech/ColossalAI-Examples',
|
'Documentation': 'http://colossalai.readthedocs.io',
|
||||||
'Documentation': 'http://colossalai.readthedocs.io',
|
'Github': 'https://github.com/hpcaitech/ColossalAI',
|
||||||
'Github': 'https://github.com/hpcaitech/ColossalAI',
|
},
|
||||||
},
|
ext_modules=ext_modules,
|
||||||
ext_modules=ext_modules,
|
cmdclass={'build_ext': BuildExtension} if ext_modules else {},
|
||||||
cmdclass={'build_ext': BuildExtension} if ext_modules else {},
|
install_requires=fetch_requirements('requirements/requirements.txt'),
|
||||||
install_requires=fetch_requirements('requirements/requirements.txt'),
|
entry_points='''
|
||||||
entry_points='''
|
|
||||||
[console_scripts]
|
[console_scripts]
|
||||||
colossalai=colossalai.cli:cli
|
colossalai=colossalai.cli:cli
|
||||||
''',
|
''',
|
||||||
python_requires='>=3.6',
|
python_requires='>=3.6',
|
||||||
classifiers=[
|
classifiers=[
|
||||||
'Programming Language :: Python :: 3',
|
'Programming Language :: Python :: 3',
|
||||||
'License :: OSI Approved :: Apache Software License',
|
'License :: OSI Approved :: Apache Software License',
|
||||||
'Environment :: GPU :: NVIDIA CUDA',
|
'Environment :: GPU :: NVIDIA CUDA',
|
||||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||||
'Topic :: System :: Distributed Computing',
|
'Topic :: System :: Distributed Computing',
|
||||||
],
|
],
|
||||||
)
|
package_data={'colossalai': ['_C/*.pyi']})
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from colossalai.testing import parameterize
|
from colossalai.testing import parameterize
|
||||||
|
@ -66,8 +67,8 @@ def test_cpu_adam(adamw, step, p_dtype, g_dtype):
|
||||||
exp_avg_sq_copy = exp_avg_sq.clone()
|
exp_avg_sq_copy = exp_avg_sq.clone()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import cpu_adam
|
import colossalai._C.cpu_optim
|
||||||
cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)
|
cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)
|
||||||
except:
|
except:
|
||||||
raise ImportError("Import cpu adam error, please install colossal from source code")
|
raise ImportError("Import cpu adam error, please install colossal from source code")
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
from numpy import dtype
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from numpy import dtype
|
||||||
import math
|
|
||||||
|
|
||||||
from colossalai.testing import parameterize
|
from colossalai.testing import parameterize
|
||||||
from colossalai.utils import multi_tensor_applier
|
from colossalai.utils import multi_tensor_applier
|
||||||
|
@ -47,11 +47,11 @@ def torch_adam_update(
|
||||||
@parameterize('g_dtype', [torch.float, torch.half])
|
@parameterize('g_dtype', [torch.float, torch.half])
|
||||||
def test_adam(adamw, step, p_dtype, g_dtype):
|
def test_adam(adamw, step, p_dtype, g_dtype):
|
||||||
try:
|
try:
|
||||||
import colossal_C
|
import colossalai._C.fused_optim
|
||||||
fused_adam = colossal_C.multi_tensor_adam
|
fused_adam = colossalai._C.fused_optim.multi_tensor_adam
|
||||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||||
except:
|
except:
|
||||||
raise ImportError("No colossal_C kernel installed.")
|
raise ImportError("No colossalai._C.fused_optim kernel installed.")
|
||||||
|
|
||||||
count = 0
|
count = 0
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue