From f8a7148dec871b5a691044ee026503dbb8232eb9 Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 17 Nov 2022 13:42:33 +0800 Subject: [PATCH] [kernel] move all symlinks of kernel to `colossalai._C` (#1971) --- .github/workflows/build.yml | 1 - .github/workflows/build_gpu_8.yml | 1 - MANIFEST.in | 2 +- colossalai/_C/__init__.pyi | 9 + colossalai/_C/cpu_optim.pyi | 8 + colossalai/_C/fused_optim.pyi | 23 ++ colossalai/_C/layer_norm.pyi | 11 + colossalai/_C/moe.pyi | 20 ++ colossalai/_C/multihead_attention.pyi | 55 ++++ colossalai/_C/scaled_masked_softmax.pyi | 12 + .../_C/scaled_upper_triang_masked_softmax.pyi | 8 + colossalai/amp/naive_amp/_fp16_optimizer.py | 4 +- colossalai/cli/check/check_installation.py | 5 +- colossalai/kernel/cuda_native/layer_norm.py | 30 +- .../kernel/cuda_native/multihead_attention.py | 83 ++--- .../kernel/cuda_native/scaled_softmax.py | 24 +- colossalai/nn/layer/moe/_operation.py | 307 +++++++++--------- colossalai/nn/optimizer/cpu_adam.py | 13 +- colossalai/nn/optimizer/fused_adam.py | 7 +- colossalai/nn/optimizer/fused_lamb.py | 6 +- colossalai/nn/optimizer/fused_sgd.py | 7 +- colossalai/nn/optimizer/hybrid_adam.py | 9 +- colossalai/utils/common.py | 23 +- .../multi_tensor_apply/multi_tensor_apply.py | 2 +- setup.py | 98 +++--- tests/test_optimizer/test_cpu_adam.py | 5 +- .../test_optimizer/test_fused_adam_kernel.py | 12 +- 27 files changed, 463 insertions(+), 322 deletions(-) create mode 100644 colossalai/_C/__init__.pyi create mode 100644 colossalai/_C/cpu_optim.pyi create mode 100644 colossalai/_C/fused_optim.pyi create mode 100644 colossalai/_C/layer_norm.pyi create mode 100644 colossalai/_C/moe.pyi create mode 100644 colossalai/_C/multihead_attention.pyi create mode 100644 colossalai/_C/scaled_masked_softmax.pyi create mode 100644 colossalai/_C/scaled_upper_triang_masked_softmax.pyi diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 6ccd9a137..36e33b0ab 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -38,7 +38,6 @@ jobs: pip install -r requirements/requirements.txt pip install -v -e . cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/ - cp /__w/ColossalAI/ColossalAI/*.so /github/home/cuda_ext_cache/ pip install -r requirements/requirements-test.txt - name: Unit Testing run: | diff --git a/.github/workflows/build_gpu_8.yml b/.github/workflows/build_gpu_8.yml index f90085f5a..2a405d86f 100644 --- a/.github/workflows/build_gpu_8.yml +++ b/.github/workflows/build_gpu_8.yml @@ -36,7 +36,6 @@ jobs: pip install -r requirements/requirements.txt pip install -v -e . cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/ - cp /__w/ColossalAI/ColossalAI/*.so /github/home/cuda_ext_cache/ pip install -r requirements/requirements-test.txt - name: Unit Testing run: | diff --git a/MANIFEST.in b/MANIFEST.in index 0991e2737..baf289270 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,3 @@ include *.txt README.md recursive-include requirements *.txt -recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc +recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc *.pyi diff --git a/colossalai/_C/__init__.pyi b/colossalai/_C/__init__.pyi new file mode 100644 index 000000000..bfd86d0ee --- /dev/null +++ b/colossalai/_C/__init__.pyi @@ -0,0 +1,9 @@ +from . 
import ( + cpu_optim, + fused_optim, + layer_norm, + moe, + multihead_attention, + scaled_masked_softmax, + scaled_upper_triang_masked_softmax, +) diff --git a/colossalai/_C/cpu_optim.pyi b/colossalai/_C/cpu_optim.pyi new file mode 100644 index 000000000..0f7611790 --- /dev/null +++ b/colossalai/_C/cpu_optim.pyi @@ -0,0 +1,8 @@ +from torch import Tensor + +class CPUAdamOptimizer: + def __init__(self, lr: float, beta1: float, beta2: float, eps: float, + weight_decay: float, adamw_mode: float) -> None: ... + + def step(self, step: int, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, bias_correction: bool, + param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor, loss_scale: float) -> None: ... diff --git a/colossalai/_C/fused_optim.pyi b/colossalai/_C/fused_optim.pyi new file mode 100644 index 000000000..6d8e97dd9 --- /dev/null +++ b/colossalai/_C/fused_optim.pyi @@ -0,0 +1,23 @@ +from typing import List + +from torch import Tensor + +def multi_tensor_scale(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], scale: float) -> None: + ... + + +def multi_tensor_sgd(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], weight_decay: float, + momentum: float, dampening: float, lr: float, nesterov: bool, first_run: bool, weight_decay_after_momentum: bool, scale: float) -> None: + ... + + +def multi_tensor_adam(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float, beta2: float, epsilon: float, step: int, mode: int, bias_correction: int, weight_decay: float) -> None: + ... + + +def multi_tensor_lamb(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float, beta2: float, epsilon: float, step: int, bias_correction: int, weight_decay: float, grad_averaging: int, mode: int, global_grad_norm: Tensor, max_grad_norm: float, use_nvlamb_python: bool) -> None: + ... + + +def multi_tensor_l2norm(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], per_tensor_python: bool) -> None: + ... diff --git a/colossalai/_C/layer_norm.pyi b/colossalai/_C/layer_norm.pyi new file mode 100644 index 000000000..02d4587ff --- /dev/null +++ b/colossalai/_C/layer_norm.pyi @@ -0,0 +1,11 @@ +from typing import List + +from torch import Tensor + +def forward_affine(input: Tensor, normalized_shape: List[int], gamma: Tensor, beta: Tensor, epsilon: float) -> List[Tensor]: + ... + + +def backward_affine(dout: Tensor, mean: Tensor, invvar: Tensor, input: Tensor, + normalized_shape: List[int], gamma: Tensor, beta: Tensor, epsilon: float) -> List[Tensor]: + ... diff --git a/colossalai/_C/moe.pyi b/colossalai/_C/moe.pyi new file mode 100644 index 000000000..121aa7e41 --- /dev/null +++ b/colossalai/_C/moe.pyi @@ -0,0 +1,20 @@ +from torch import Tensor + +def cumsum_sub_one(mask: Tensor) -> Tensor: + ... + + +def dispatch_forward(s: int, ec: int, h: int, batch_tokens: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor: + ... + + +def dispatch_backward(s: int, ec: int, h: int, expert_grad: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor: + ... + + +def combine_forward(s: int, e: int, c: int, h: int, expert_tokens: Tensor, logits: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor: + ... + + +def combine_backward(s: int, e: int, c: int, h: int, tokens_grad: Tensor, expert_tokens: Tensor, logits: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor: + ... 
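The cpu_optim stub above mirrors the constructor and step() entry points of the C++ CPUAdamOptimizer. A minimal sketch of how those signatures are exercised, assuming the extension has been built from source; the tensor sizes and hyper-parameters are illustrative only, and the trailing loss_scale value of -1 is an assumption (taken to mean "no loss scaling"), not something this diff states:

import torch

try:
    import colossalai._C.cpu_optim as cpu_optim
except ImportError:
    raise ImportError('Please install colossalai from source code to use CPUAdam')

lr, beta1, beta2, eps, weight_decay, adamw_mode = 1e-3, 0.9, 0.999, 1e-8, 0.0, True
opt = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw_mode)

param = torch.rand(64)
grad = torch.rand(64)
exp_avg = torch.zeros(64)
exp_avg_sq = torch.zeros(64)
# Argument order follows the step() signature declared in cpu_optim.pyi;
# loss_scale = -1 is assumed to disable loss scaling.
opt.step(1, lr, beta1, beta2, eps, weight_decay, True, param, grad, exp_avg, exp_avg_sq, -1)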
diff --git a/colossalai/_C/multihead_attention.pyi b/colossalai/_C/multihead_attention.pyi new file mode 100644 index 000000000..7ad87ea9a --- /dev/null +++ b/colossalai/_C/multihead_attention.pyi @@ -0,0 +1,55 @@ +from typing import List + +from torch import Tensor +from torch.distributed import ProcessGroup + +def multihead_attention_fw_fp32(layer_id: int, input: Tensor, input_mask: Tensor, + in_proj_weight: Tensor, in_proj_bias: Tensor, + out_proj_weight: Tensor, out_proj_bias: Tensor, + norm_weight: Tensor, norm_bias: Tensor, + training_mode: bool, prelayernorm: bool) -> List[Tensor]: + ... + + +def multihead_attention_fw_fp16(layer_id: int, input: Tensor, input_mask: Tensor, + in_proj_weight: Tensor, in_proj_bias: Tensor, + out_proj_weight: Tensor, out_proj_bias: Tensor, + norm_weight: Tensor, norm_bias: Tensor, + training_mode: bool, prelayernorm: bool) -> List[Tensor]: + ... + + +def multihead_attention_bw_fp32(layer_id: int, grad_dec_output: Tensor, + output: Tensor, input: Tensor, + input_mask: Tensor, in_proj_weight: Tensor, + in_proj_bias: Tensor, out_proj_weight: Tensor, + out_proj_bias: Tensor, norm_weight: Tensor, + norm_bias: Tensor) -> List[Tensor]: + ... + + +def multihead_attention_bw_fp16(layer_id: int, grad_dec_output: Tensor, + output: Tensor, input: Tensor, + input_mask: Tensor, in_proj_weight: Tensor, + in_proj_bias: Tensor, out_proj_weight: Tensor, + out_proj_bias: Tensor, norm_weight: Tensor, + norm_bias: Tensor) -> List[Tensor]: + ... + + +def create_multihead_attention_fp32(layer_id: int, max_batch_tokens: int, + max_seq_len: int, hidden_dim: int, num_heads: int, + attn_prob_dropout_ratio: float, + hidden_dropout_ratio: float, + pre_or_postLayerNorm: bool, + pg: ProcessGroup) -> int: + ... + + +def create_multihead_attention_fp16(layer_id: int, max_batch_tokens: int, + max_seq_len: int, hidden_dim: int, num_heads: int, + attn_prob_dropout_ratio: float, + hidden_dropout_ratio: float, + pre_or_postLayerNorm: bool, + pg: ProcessGroup) -> int: + ... diff --git a/colossalai/_C/scaled_masked_softmax.pyi b/colossalai/_C/scaled_masked_softmax.pyi new file mode 100644 index 000000000..fdb88266e --- /dev/null +++ b/colossalai/_C/scaled_masked_softmax.pyi @@ -0,0 +1,12 @@ +from torch import Tensor + +def forward(input: Tensor, mask: Tensor, scale: float) -> Tensor: + ... + + +def backward(output_grads: Tensor, softmax_results: Tensor, scale: float) -> Tensor: + ... + + +def get_batch_per_block(query_seq_len: int, key_seq_len: int, batches: int, attn_heads: int) -> int: + ... diff --git a/colossalai/_C/scaled_upper_triang_masked_softmax.pyi b/colossalai/_C/scaled_upper_triang_masked_softmax.pyi new file mode 100644 index 000000000..39a3d6b22 --- /dev/null +++ b/colossalai/_C/scaled_upper_triang_masked_softmax.pyi @@ -0,0 +1,8 @@ +from torch import Tensor + +def forward(input: Tensor, scale: float) -> Tensor: + ... + + +def backward(output_grads: Tensor, softmax_results: Tensor, scale: float) -> Tensor: + ... 
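The softmax and attention stubs above share one convention with the rest of this patch: the compiled module is imported lazily at the call site, and a RuntimeError is raised when the extension was not built. A hedged, illustrative wrapper showing that pattern against the scaled_masked_softmax stub (the free-function name here is hypothetical; the real call sites are the autograd Functions in colossalai/kernel/cuda_native/scaled_softmax.py later in this patch):

import torch


def scaled_masked_softmax(inputs: torch.Tensor, mask: torch.Tensor, scale: float) -> torch.Tensor:
    # Lazy import so that CPU-only installs fail only when the kernel is actually used.
    try:
        import colossalai._C.scaled_masked_softmax
    except ImportError:
        raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
    # The kernel expects the scale as a scalar tensor element, matching the
    # usage in scaled_softmax.py below.
    scale_t = torch.tensor([scale])
    return colossalai._C.scaled_masked_softmax.forward(inputs, mask, scale_t[0])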
diff --git a/colossalai/amp/naive_amp/_fp16_optimizer.py b/colossalai/amp/naive_amp/_fp16_optimizer.py index b01a3cbf0..9a8be009b 100644 --- a/colossalai/amp/naive_amp/_fp16_optimizer.py +++ b/colossalai/amp/naive_amp/_fp16_optimizer.py @@ -5,7 +5,7 @@ import torch import torch.distributed as dist try: - import colossal_C + import colossalai._C.fused_optim except: print('Colossalai should be built with cuda extension to use the FP16 optimizer') @@ -35,7 +35,7 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None): if overflow_buf: overflow_buf.fill_(0) # Scaling with factor `1.0` is equivalent to copy. - multi_tensor_applier(colossal_C.multi_tensor_scale, overflow_buf, [this, that], 1.0) + multi_tensor_applier(colossalai._C.fused_optim.multi_tensor_scale, overflow_buf, [this, that], 1.0) else: for this_, that_ in zip(this, that): that_.copy_(this_) diff --git a/colossalai/cli/check/check_installation.py b/colossalai/cli/check/check_installation.py index eab0bc1ed..a299494fb 100644 --- a/colossalai/cli/check/check_installation.py +++ b/colossalai/cli/check/check_installation.py @@ -1,5 +1,6 @@ -import click import subprocess + +import click import torch from torch.utils.cpp_extension import CUDA_HOME @@ -17,7 +18,7 @@ def check_installation(): def _check_cuda_extension_installed(): try: - import colossal_C + import colossalai._C.fused_optim is_cuda_extension_installed = u'\u2713' except ImportError: is_cuda_extension_installed = 'x' diff --git a/colossalai/kernel/cuda_native/layer_norm.py b/colossalai/kernel/cuda_native/layer_norm.py index 38e95e2f8..f1b5efa4e 100644 --- a/colossalai/kernel/cuda_native/layer_norm.py +++ b/colossalai/kernel/cuda_native/layer_norm.py @@ -3,14 +3,11 @@ with some changes. """ import numbers -import torch -from torch.nn.parameter import Parameter -from torch.nn import init -from torch.cuda.amp import custom_fwd, custom_bwd -import importlib -global colossal_layer_norm_cuda -colossal_layer_norm_cuda = None +import torch +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.nn import init +from torch.nn.parameter import Parameter class FusedLayerNormAffineFunction(torch.autograd.Function): @@ -18,13 +15,17 @@ class FusedLayerNormAffineFunction(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, input, weight, bias, normalized_shape, eps): + try: + import colossalai._C.layer_norm + except ImportError: + raise RuntimeError('FusedLayerNormAffineFunction requires cuda extensions') ctx.normalized_shape = normalized_shape ctx.eps = eps input_ = input.contiguous() weight_ = weight.contiguous() bias_ = bias.contiguous() - output, mean, invvar = colossal_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, bias_, + output, mean, invvar = colossalai._C.layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps) ctx.save_for_backward(input_, weight_, bias_, mean, invvar) @@ -33,11 +34,15 @@ class FusedLayerNormAffineFunction(torch.autograd.Function): @staticmethod @custom_bwd def backward(ctx, grad_output): + try: + import colossalai._C.layer_norm + except ImportError: + raise RuntimeError('FusedLayerNormAffineFunction requires cuda extensions') input_, weight_, bias_, mean, invvar = ctx.saved_tensors grad_input = grad_weight = grad_bias = None grad_input, grad_weight, grad_bias \ - = colossal_layer_norm_cuda.backward_affine( + = colossalai._C.layer_norm.backward_affine( grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps) @@ 
-50,13 +55,6 @@ class MixedFusedLayerNorm(torch.nn.Module): def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None): super(MixedFusedLayerNorm, self).__init__() - global colossal_layer_norm_cuda - if colossal_layer_norm_cuda is None: - try: - colossal_layer_norm_cuda = importlib.import_module("colossal_layer_norm_cuda") - except ImportError: - raise RuntimeError('MixedFusedLayerNorm requires cuda extensions') - if isinstance(normalized_shape, numbers.Integral): normalized_shape = (normalized_shape,) self.normalized_shape = torch.Size(normalized_shape) diff --git a/colossalai/kernel/cuda_native/multihead_attention.py b/colossalai/kernel/cuda_native/multihead_attention.py index c93d1cf60..84cae529a 100644 --- a/colossalai/kernel/cuda_native/multihead_attention.py +++ b/colossalai/kernel/cuda_native/multihead_attention.py @@ -1,5 +1,4 @@ import math -import importlib from dataclasses import dataclass import torch @@ -37,21 +36,21 @@ colossal_multihead_attention = None @dataclass class Config: - max_batch_tokens: int # max batch token numbers - max_seq_len: int # max sequence length - hidden_size: int # size of transformer hidden layers - nhead: int # number of heads in attention - attn_prob_dropout_ratio: float # attention score dropout ratio - hidden_dropout_ratio: float # dropout ration before residual - norm_first: bool # norm_first - fp16: bool # fp16 presion + max_batch_tokens: int # max batch token numbers + max_seq_len: int # max sequence length + hidden_size: int # size of transformer hidden layers + nhead: int # number of heads in attention + attn_prob_dropout_ratio: float # attention score dropout ratio + hidden_dropout_ratio: float # dropout ration before residual + norm_first: bool # norm_first + fp16: bool # fp16 presion class MultiHeadAttention1DFunc(Function): @staticmethod - def forward(ctx, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, - out_proj_bias, norm_weight, norm_bias, config): + def forward(ctx, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, + norm_bias, config): cuda_module = colossal_multihead_attention forward_func = (cuda_module.multihead_attention_fw_fp16 if config.fp16 else cuda_module.multihead_attention_fw_fp32) @@ -59,13 +58,12 @@ class MultiHeadAttention1DFunc(Function): input = input.to(torch.half) input_mask = input_mask.to(torch.half) - (output,) = forward_func(config.layer_id, input, input_mask, in_proj_weight, in_proj_bias, - out_proj_weight, out_proj_bias, norm_weight, norm_bias, - config.training, config.norm_first) + (output,) = forward_func(config.layer_id, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, + out_proj_bias, norm_weight, norm_bias, config.training, config.norm_first) if config.is_grad_enabled and config.training: - ctx.save_for_backward(output, input, input_mask, in_proj_weight, in_proj_bias, - out_proj_weight, out_proj_bias, norm_weight, norm_bias) + ctx.save_for_backward(output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, + out_proj_bias, norm_weight, norm_bias) ctx.config = config return output @@ -98,8 +96,8 @@ class MultiHeadAttention1DFunc(Function): ctx.config.layer_id, grad_output, output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, norm_bias) - return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, - grad_out_proj_bias, grad_norm_weight, grad_norm_bias, None) + return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, 
grad_out_proj_weight, grad_out_proj_bias, + grad_norm_weight, grad_norm_bias, None) class MultiHeadAttention(nn.Module): @@ -121,19 +119,11 @@ class MultiHeadAttention(nn.Module): layer_id = 0 - def __init__(self, - hidden_size, - nhead, - batch_size, - max_seq_len, - dropout=0.0, - norm_first=False, - fp16=True, - pg=None): + def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, norm_first=False, fp16=True, pg=None): super(MultiHeadAttention, self).__init__() - self.config = Config(batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, - dropout, norm_first, fp16) + self.config = Config(batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first, + fp16) check_config(self.config) self.pg = pg self.pg_size = 1 @@ -146,7 +136,8 @@ class MultiHeadAttention(nn.Module): global colossal_multihead_attention if colossal_multihead_attention is None: try: - colossal_multihead_attention = importlib.import_module("colossal_multihead_attention") + import colossalai._C.multihead_attention + colossal_multihead_attention = colossalai._C.multihead_attention except ImportError: raise RuntimeError('MultiHeadAttention requires cuda extensions') @@ -215,14 +206,13 @@ class MultiHeadAttention(nn.Module): with torch.no_grad(): self.in_proj_weight.copy_( - attn_qkvw_global.view(3, hs, hs)[ - :, int(hs * rank_in_pg / self.pg_size): - int(hs * (rank_in_pg + 1) / self.pg_size), - :]) + attn_qkvw_global.view(3, hs, hs)[:, + int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) / + self.pg_size), :]) self.in_proj_bias.copy_( - attn_qkvb_global.view(3, hs)[ - :, int(hs * rank_in_pg / self.pg_size): - int(hs * (rank_in_pg + 1) / self.pg_size)]) + attn_qkvb_global.view(3, hs)[:, + int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) / + self.pg_size)]) attn_ow_global = torch.empty(hs, hs) nn.init.xavier_uniform_(attn_ow_global, 1.0) @@ -230,9 +220,9 @@ class MultiHeadAttention(nn.Module): torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg) attn_ow_global = attn_ow_global.cpu() with torch.no_grad(): - self.out_proj_weight.copy_(attn_ow_global[ - :, int(hs * rank_in_pg / self.pg_size): - int(hs * (rank_in_pg + 1) / self.pg_size)]) + self.out_proj_weight.copy_(attn_ow_global[:, + int(hs * rank_in_pg / + self.pg_size):int(hs * (rank_in_pg + 1) / self.pg_size)]) else: attn_qkvw = self.in_proj_weight.view(-1, hs) @@ -243,10 +233,7 @@ class MultiHeadAttention(nn.Module): nn.init.xavier_uniform_(self.out_proj_weight, 1.0) def state_dict(self, destination=None, prefix="", keep_vars=False): - destination = torch.nn.Module.state_dict(self, - destination=destination, - prefix=prefix, - keep_vars=keep_vars) + destination = torch.nn.Module.state_dict(self, destination=destination, prefix=prefix, keep_vars=keep_vars) return destination def forward(self, hidden_states, encoder_padding_mask): @@ -257,8 +244,7 @@ class MultiHeadAttention(nn.Module): bs, sl, dim = hidden_states.size() if bs * sl > self.config.max_batch_tokens: - raise ValueError( - f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.") + raise ValueError(f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.") if sl > self.config.max_seq_len: raise ValueError(f"Sequence length {sl} exceeds the limit {self.config.max_seq_len}.") if len(encoder_padding_mask.size()) == 1: @@ -266,9 +252,8 @@ class MultiHeadAttention(nn.Module): else: assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1) - output = 
MultiHeadAttention1DFunc.apply(hidden_states, encoder_padding_mask, - self.in_proj_weight, self.in_proj_bias, - self.out_proj_weight, self.out_proj_bias, + output = MultiHeadAttention1DFunc.apply(hidden_states, encoder_padding_mask, self.in_proj_weight, + self.in_proj_bias, self.out_proj_weight, self.out_proj_bias, self.norm_weight, self.norm_bias, self.config) return output.to(self.precision) diff --git a/colossalai/kernel/cuda_native/scaled_softmax.py b/colossalai/kernel/cuda_native/scaled_softmax.py index cb36da8a1..e02067d05 100644 --- a/colossalai/kernel/cuda_native/scaled_softmax.py +++ b/colossalai/kernel/cuda_native/scaled_softmax.py @@ -1,9 +1,10 @@ """This code from NVIDIA Megatron with some changes. """ +import enum + import torch import torch.nn as nn -import enum class AttnMaskType(enum.Enum): @@ -23,12 +24,12 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): @staticmethod def forward(ctx, inputs, scale): try: - import colossal_scaled_upper_triang_masked_softmax + import colossalai._C.scaled_upper_triang_masked_softmax except ImportError: raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions') scale_t = torch.tensor([scale]) - softmax_results = colossal_scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0]) + softmax_results = colossalai._C.scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0]) ctx.save_for_backward(softmax_results, scale_t) return softmax_results @@ -36,12 +37,13 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): @staticmethod def backward(ctx, output_grads): try: - import colossal_scaled_upper_triang_masked_softmax + import colossalai._C.scaled_upper_triang_masked_softmax except ImportError: raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions') softmax_results, scale_t = ctx.saved_tensors - input_grads = colossal_scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0]) + input_grads = colossalai._C.scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, + scale_t[0]) return input_grads, None @@ -58,26 +60,26 @@ class ScaledMaskedSoftmax(torch.autograd.Function): @staticmethod def forward(ctx, inputs, mask, scale): try: - import colossal_scaled_masked_softmax + import colossalai._C.scaled_masked_softmax except ImportError: raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions') scale_t = torch.tensor([scale]) - softmax_results = colossal_scaled_masked_softmax.forward(inputs, mask, scale_t[0]) + softmax_results = colossalai._C.scaled_masked_softmax.forward(inputs, mask, scale_t[0]) ctx.save_for_backward(softmax_results, scale_t) return softmax_results @staticmethod def backward(ctx, output_grads): try: - import colossal_scaled_masked_softmax + import colossalai._C.scaled_masked_softmax except ImportError: raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions') softmax_results, scale_t = ctx.saved_tensors - input_grads = colossal_scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0]) + input_grads = colossalai._C.scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0]) return input_grads, None, None @@ -184,8 +186,8 @@ class FusedScaleMaskSoftmax(nn.Module): @staticmethod def get_batch_per_block(sq, sk, b, np): try: - import colossal_scaled_masked_softmax + import colossalai._C.scaled_masked_softmax except ImportError: raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions') - return colossal_scaled_masked_softmax.get_batch_per_block(sq, sk, b, np) + 
return colossalai._C.scaled_masked_softmax.get_batch_per_block(sq, sk, b, np) diff --git a/colossalai/nn/layer/moe/_operation.py b/colossalai/nn/layer/moe/_operation.py index dbf264297..278cdfbb7 100644 --- a/colossalai/nn/layer/moe/_operation.py +++ b/colossalai/nn/layer/moe/_operation.py @@ -1,153 +1,154 @@ -import torch -import torch.distributed as dist -from torch import Tensor -from typing import Any, Tuple, Optional -from torch.distributed import ProcessGroup - -COL_MOE_KERNEL_FLAG = False -try: - import colossal_moe_cuda - - COL_MOE_KERNEL_FLAG = True -except ImportError: - print("If you want to activate cuda mode for MoE, please install with cuda_ext!") - - -class AllGather(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: - if ctx is not None: - ctx.comm_grp = group - - comm_size = dist.get_world_size(group) - if comm_size == 1: - return inputs.unsqueeze(0) - - buffer_shape = (comm_size,) + inputs.shape - outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device) - buffer_list = list(torch.chunk(outputs, comm_size, dim=0)) - dist.all_gather(buffer_list, inputs, group=group) - return outputs - - @staticmethod - def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]: - return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None - - -class ReduceScatter(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: - if ctx is not None: - ctx.comm_grp = group - - comm_size = dist.get_world_size(group) - if comm_size == 1: - return inputs.squeeze(0) - - if not inputs.is_contiguous(): - inputs = inputs.contiguous() - - output_shape = inputs.shape[1:] - outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device) - buffer_list = list(torch.chunk(inputs, comm_size, dim=0)) - dist.reduce_scatter(outputs, buffer_list, group=group) - return outputs - - @staticmethod - def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]: - return AllGather.forward(None, grad_outputs, ctx.comm_grp), None - - -class AllToAll(torch.autograd.Function): - """Dispatches input tensor [e, c, h] to all experts by all_to_all_single - operation in torch.distributed. 
- """ - - @staticmethod - def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: - if ctx is not None: - ctx.comm_grp = group - if not inputs.is_contiguous(): - inputs = inputs.contiguous() - if dist.get_world_size(group) == 1: - return inputs - output = torch.empty_like(inputs) - dist.all_to_all_single(output, inputs, group=group) - return output - - @staticmethod - def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]: - return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None - - -class MoeDispatch(torch.autograd.Function): - - @staticmethod - def forward(ctx, tokens, mask, dest_idx, ec): - s = tokens.size(0) - h = tokens.size(1) - - expert_input = colossal_moe_cuda.dispatch_forward(s, ec, h, tokens, mask, dest_idx) - - ctx.save_for_backward(mask, dest_idx) - ctx.s = s - ctx.h = h - ctx.ec = ec - - return expert_input - - @staticmethod - def backward(ctx, output_grad): - mask, dest_idx = ctx.saved_tensors - d_tokens = colossal_moe_cuda.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx) - return d_tokens, None, None, None - - -class MoeCombine(torch.autograd.Function): - - @staticmethod - def forward(ctx, expert_tokens, logits, mask, dest_idx, ec): - assert logits.dtype == torch.float32 - - s = logits.size(0) - e = logits.size(1) - c = ec // e - h = expert_tokens.size(-1) - - fp16_flag = (expert_tokens.dtype == torch.float16) - cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens - ctokens = colossal_moe_cuda.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx) - output = ctokens.to(torch.float16) if fp16_flag else ctokens - - ctx.save_for_backward(expert_tokens, logits, mask, dest_idx) - ctx.s = s - ctx.e = e - ctx.c = c - ctx.h = h - ctx.fp16_flag = fp16_flag - - return output - - @staticmethod - def backward(ctx, tokens_grad): - expert_tokens, logits, mask, dest_idx = ctx.saved_tensors - - cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \ - else tokens_grad - cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens - d_expert, d_logits = colossal_moe_cuda.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, - mask, dest_idx) - d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert - - return d_expert, d_logits, None, None, None - - -def moe_cumsum(inputs: Tensor): - dim0 = inputs.size(0) - flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0) - if flag and COL_MOE_KERNEL_FLAG: - return colossal_moe_cuda.cumsum_sub_one(inputs) - else: - return torch.cumsum(inputs, dim=0) - 1 +from typing import Any, Optional, Tuple + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.distributed import ProcessGroup + +COL_MOE_KERNEL_FLAG = False +try: + import colossalai._C.moe + + COL_MOE_KERNEL_FLAG = True +except ImportError: + print("If you want to activate cuda mode for MoE, please install with cuda_ext!") + + +class AllGather(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: + if ctx is not None: + ctx.comm_grp = group + + comm_size = dist.get_world_size(group) + if comm_size == 1: + return inputs.unsqueeze(0) + + buffer_shape = (comm_size,) + inputs.shape + outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device) + buffer_list = list(torch.chunk(outputs, comm_size, dim=0)) + dist.all_gather(buffer_list, inputs, group=group) + return outputs + + 
@staticmethod + def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]: + return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None + + +class ReduceScatter(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: + if ctx is not None: + ctx.comm_grp = group + + comm_size = dist.get_world_size(group) + if comm_size == 1: + return inputs.squeeze(0) + + if not inputs.is_contiguous(): + inputs = inputs.contiguous() + + output_shape = inputs.shape[1:] + outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device) + buffer_list = list(torch.chunk(inputs, comm_size, dim=0)) + dist.reduce_scatter(outputs, buffer_list, group=group) + return outputs + + @staticmethod + def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]: + return AllGather.forward(None, grad_outputs, ctx.comm_grp), None + + +class AllToAll(torch.autograd.Function): + """Dispatches input tensor [e, c, h] to all experts by all_to_all_single + operation in torch.distributed. + """ + + @staticmethod + def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: + if ctx is not None: + ctx.comm_grp = group + if not inputs.is_contiguous(): + inputs = inputs.contiguous() + if dist.get_world_size(group) == 1: + return inputs + output = torch.empty_like(inputs) + dist.all_to_all_single(output, inputs, group=group) + return output + + @staticmethod + def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]: + return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None + + +class MoeDispatch(torch.autograd.Function): + + @staticmethod + def forward(ctx, tokens, mask, dest_idx, ec): + s = tokens.size(0) + h = tokens.size(1) + + expert_input = colossalai._C.moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx) + + ctx.save_for_backward(mask, dest_idx) + ctx.s = s + ctx.h = h + ctx.ec = ec + + return expert_input + + @staticmethod + def backward(ctx, output_grad): + mask, dest_idx = ctx.saved_tensors + d_tokens = colossalai._C.moe.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx) + return d_tokens, None, None, None + + +class MoeCombine(torch.autograd.Function): + + @staticmethod + def forward(ctx, expert_tokens, logits, mask, dest_idx, ec): + assert logits.dtype == torch.float32 + + s = logits.size(0) + e = logits.size(1) + c = ec // e + h = expert_tokens.size(-1) + + fp16_flag = (expert_tokens.dtype == torch.float16) + cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens + ctokens = colossalai._C.moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx) + output = ctokens.to(torch.float16) if fp16_flag else ctokens + + ctx.save_for_backward(expert_tokens, logits, mask, dest_idx) + ctx.s = s + ctx.e = e + ctx.c = c + ctx.h = h + ctx.fp16_flag = fp16_flag + + return output + + @staticmethod + def backward(ctx, tokens_grad): + expert_tokens, logits, mask, dest_idx = ctx.saved_tensors + + cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \ + else tokens_grad + cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens + d_expert, d_logits = colossalai._C.moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, + mask, dest_idx) + d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert + + return d_expert, d_logits, None, None, None + + +def moe_cumsum(inputs: Tensor): + dim0 = inputs.size(0) + flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 
0) or (dim0 % 4 == 0) + if flag and COL_MOE_KERNEL_FLAG: + return colossalai._C.moe.cumsum_sub_one(inputs) + else: + return torch.cumsum(inputs, dim=0) - 1 diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index ea08ff723..745d8de22 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -1,9 +1,11 @@ import math +from typing import Optional + import torch from colossalai.registry import OPTIMIZERS + from .nvme_optimizer import NVMeOptimizer -from typing import Optional @OPTIMIZERS.register_module @@ -11,7 +13,7 @@ class CPUAdam(NVMeOptimizer): """Implements Adam algorithm. Supports parameters updating on both GPU and CPU, depanding on the device of paramters. - But the parameters and gradients should on the same device: + But the parameters and gradients should on the same device: * Parameters on CPU and gradients on CPU is allowed. * Parameters on GPU and gradients on GPU is allowed. * Parameters on GPU and gradients on CPU is **not** allowed. @@ -44,7 +46,7 @@ class CPUAdam(NVMeOptimizer): (default: False) NOT SUPPORTED yet in CPUAdam! adamw_mode (boolean, optional): Apply L2 regularization or weight decay True for decoupled weight decay(also known as AdamW) (default: True) - simd_log (boolean, optional): whether to show if you are using SIMD to + simd_log (boolean, optional): whether to show if you are using SIMD to accelerate. (default: False) nvme_offload_fraction (float, optional): Fraction of optimizer states to be offloaded to NVMe. Defaults to 0.0. nvme_offload_dir (Optional[str], optional): Directory to save NVMe offload files. @@ -75,10 +77,11 @@ class CPUAdam(NVMeOptimizer): super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) self.adamw_mode = adamw_mode try: - import cpu_adam + import colossalai._C.cpu_optim except ImportError: raise ImportError('Please install colossalai from source code to use CPUAdam') - self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) + self.cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, + adamw_mode) def torch_adam_update(self, data, diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py index 5814c28bd..4687e6f3b 100644 --- a/colossalai/nn/optimizer/fused_adam.py +++ b/colossalai/nn/optimizer/fused_adam.py @@ -20,7 +20,7 @@ class FusedAdam(torch.optim.Optimizer): :class:`colossalai.nn.optimizer.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``, or ``torch.optim.Adam`` with ``adamw_mode=False`` - :class:`colossalai.nn.optimizer.FusedAdam` may be used with or without Amp. + :class:`colossalai.nn.optimizer.FusedAdam` may be used with or without Amp. Adam was been proposed in `Adam: A Method for Stochastic Optimization`_. 
@@ -65,10 +65,11 @@ class FusedAdam(torch.optim.Optimizer): self.adamw_mode = 1 if adamw_mode else 0 self.set_grad_none = set_grad_none if multi_tensor_applier.available: - import colossal_C + import colossalai._C.fused_optim + # Skip buffer self._dummy_overflow_buf = torch.cuda.IntTensor([0]) - self.multi_tensor_adam = colossal_C.multi_tensor_adam + self.multi_tensor_adam = colossalai._C.fused_optim.multi_tensor_adam else: raise RuntimeError('FusedAdam requires cuda extensions') diff --git a/colossalai/nn/optimizer/fused_lamb.py b/colossalai/nn/optimizer/fused_lamb.py index be12e6c62..2e33d7032 100644 --- a/colossalai/nn/optimizer/fused_lamb.py +++ b/colossalai/nn/optimizer/fused_lamb.py @@ -76,13 +76,13 @@ class FusedLAMB(torch.optim.Optimizer): max_grad_norm=max_grad_norm) super(FusedLAMB, self).__init__(params, defaults) if multi_tensor_applier.available: - import colossal_C - self.multi_tensor_l2norm = colossal_C.multi_tensor_l2norm + import colossalai._C.fused_optim + self.multi_tensor_l2norm = colossalai._C.fused_optim.multi_tensor_l2norm # Skip buffer self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device) - self.multi_tensor_lamb = colossal_C.multi_tensor_lamb + self.multi_tensor_lamb = colossalai._C.fused_optim.multi_tensor_lamb else: raise RuntimeError('FusedLAMB requires cuda extensions') diff --git a/colossalai/nn/optimizer/fused_sgd.py b/colossalai/nn/optimizer/fused_sgd.py index 1185eef81..03c3da28d 100644 --- a/colossalai/nn/optimizer/fused_sgd.py +++ b/colossalai/nn/optimizer/fused_sgd.py @@ -20,7 +20,7 @@ class FusedSGD(Optimizer): :class:`colossalai.nn.optimizer.FusedSGD` may be used as a drop-in replacement for ``torch.optim.SGD`` - :class:`colossalai.nn.optimizer.FusedSGD` may be used with or without Amp. + :class:`colossalai.nn.optimizer.FusedSGD` may be used with or without Amp. Nesterov momentum is based on the formula from `On the importance of initialization and momentum in deep learning`__. 
@@ -80,12 +80,13 @@ class FusedSGD(Optimizer): self.wd_after_momentum = wd_after_momentum if multi_tensor_applier.available: - import colossal_C + import colossalai._C.fused_optim + # Skip buffer self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device) - self.multi_tensor_sgd = colossal_C.multi_tensor_sgd + self.multi_tensor_sgd = colossalai._C.fused_optim.multi_tensor_sgd else: raise RuntimeError('FusedSGD requires cuda extensions') diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py index 069b52af5..676dc71e4 100644 --- a/colossalai/nn/optimizer/hybrid_adam.py +++ b/colossalai/nn/optimizer/hybrid_adam.py @@ -77,14 +77,15 @@ class HybridAdam(NVMeOptimizer): super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) self.adamw_mode = adamw_mode try: - import colossal_C - import cpu_adam + import colossalai._C.cpu_optim + import colossalai._C.fused_optim except ImportError: raise ImportError('Please install colossalai from source code to use HybridAdam') - self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) + self.cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, + adamw_mode) - self.gpu_adam_op = colossal_C.multi_tensor_adam + self.gpu_adam_op = colossalai._C.fused_optim.multi_tensor_adam self._dummy_overflow_buf = torch.cuda.IntTensor([0]) @torch.no_grad() diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py index a52c25530..d8cd709b3 100644 --- a/colossalai/utils/common.py +++ b/colossalai/utils/common.py @@ -1,32 +1,33 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import functools import os import random import socket from pathlib import Path -from typing import Callable, List, Union, Dict, Optional -import functools +from typing import Callable, Dict, List, Optional, Union import torch from torch._six import inf from torch.nn.parameter import Parameter try: - import colossal_C + import colossalai._C.fused_optim except: pass +from collections import defaultdict from contextlib import contextmanager import torch.distributed as dist -from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES) + +from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.global_variables import tensor_parallel_env as env -from .multi_tensor_apply import multi_tensor_applier - from colossalai.tensor import ColoParameter, ProcessGroup -from collections import defaultdict + +from .multi_tensor_apply import multi_tensor_applier def print_rank_0(msg: str, logger=None): @@ -132,7 +133,7 @@ def _calc_l2_norm(grads): if len(grads) > 0: dummy_overflow_buf = torch.cuda.IntTensor([0]) norm, _ = multi_tensor_applier( - colossal_C.multi_tensor_l2norm, + colossalai._C.fused_optim.multi_tensor_l2norm, dummy_overflow_buf, [grads], False # no per-parameter norm @@ -269,7 +270,8 @@ def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None: cpu_grads.append(p.grad.detach()) if len(cuda_grads) > 0: dummy_overflow_buf = torch.cuda.IntTensor([0]) - multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads], clip_coef) + multi_tensor_applier(colossalai._C.fused_optim.multi_tensor_scale, dummy_overflow_buf, + [cuda_grads, 
cuda_grads], clip_coef) for g in cpu_grads: g.mul_(clip_coef) @@ -395,7 +397,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): if enable_cuda_kernels: grads = [p.grad.detach() for p in params] dummy_overflow_buf = torch.cuda.IntTensor([0]) - multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff) + multi_tensor_applier(colossalai._C.fused_optim.multi_tensor_scale, dummy_overflow_buf, [grads, grads], + clip_coeff) else: for p in params: p.grad.detach().mul_(clip_coeff) diff --git a/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py b/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py index 4e847f17b..6eda9834b 100644 --- a/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py +++ b/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py @@ -14,7 +14,7 @@ class MultiTensorApply(object): def __init__(self, chunk_size): try: - import colossal_C + import colossalai._C.fused_optim MultiTensorApply.available = True self.chunk_size = chunk_size except ImportError as err: diff --git a/setup.py b/setup.py index 8341a97b7..0a83e622e 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,8 @@ import os -import subprocess import re -from setuptools import find_packages, setup, Extension +import subprocess + +from setuptools import Extension, find_packages, setup # ninja build does not work unless include_dirs are abs path this_dir = os.path.dirname(os.path.abspath(__file__)) @@ -104,7 +105,7 @@ def get_version(): if build_cuda_ext: try: import torch - from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CUDAExtension) + from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MINOR = int(torch.__version__.split('.')[1]) @@ -148,7 +149,7 @@ if build_cuda_ext: extra_cuda_flags = ['-lineinfo'] ext_modules.append( - cuda_ext_helper('colossal_C', [ + cuda_ext_helper('colossalai._C.fused_optim', [ 'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.cu', 'multi_tensor_scale_kernel.cu', 'multi_tensor_adam.cu', 'multi_tensor_l2norm_kernel.cu', 'multi_tensor_lamb.cu' ], extra_cuda_flags + cc_flag)) @@ -159,21 +160,21 @@ if build_cuda_ext: ] ext_modules.append( - cuda_ext_helper('colossal_scaled_upper_triang_masked_softmax', + cuda_ext_helper('colossalai._C.scaled_upper_triang_masked_softmax', ['scaled_upper_triang_masked_softmax.cpp', 'scaled_upper_triang_masked_softmax_cuda.cu'], extra_cuda_flags + cc_flag)) ext_modules.append( - cuda_ext_helper('colossal_scaled_masked_softmax', + cuda_ext_helper('colossalai._C.scaled_masked_softmax', ['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'], extra_cuda_flags + cc_flag)) ext_modules.append( - cuda_ext_helper('colossal_moe_cuda', ['moe_cuda.cpp', 'moe_cuda_kernel.cu'], extra_cuda_flags + cc_flag)) + cuda_ext_helper('colossalai._C.moe', ['moe_cuda.cpp', 'moe_cuda_kernel.cu'], extra_cuda_flags + cc_flag)) extra_cuda_flags = ['-maxrregcount=50'] ext_modules.append( - cuda_ext_helper('colossal_layer_norm_cuda', ['layer_norm_cuda.cpp', 'layer_norm_cuda_kernel.cu'], + cuda_ext_helper('colossalai._C.layer_norm', ['layer_norm_cuda.cpp', 'layer_norm_cuda_kernel.cu'], extra_cuda_flags + cc_flag)) extra_cuda_flags = [ @@ -182,54 +183,53 @@ if build_cuda_ext: ] ext_modules.append( - cuda_ext_helper('colossal_multihead_attention', [ + cuda_ext_helper('colossalai._C.multihead_attention', [ 'multihead_attention_1d.cpp', 
'kernels/cublas_wrappers.cu', 'kernels/transform_kernels.cu', 'kernels/dropout_kernels.cu', 'kernels/normalize_kernels.cu', 'kernels/softmax_kernels.cu', 'kernels/general_kernels.cu', 'kernels/cuda_util.cu' ], extra_cuda_flags + cc_flag)) extra_cxx_flags = ['-std=c++14', '-lcudart', '-lcublas', '-g', '-Wno-reorder', '-fopenmp', '-march=native'] - ext_modules.append(cuda_ext_helper('cpu_adam', ['cpu_adam.cpp'], extra_cuda_flags, extra_cxx_flags)) + ext_modules.append(cuda_ext_helper('colossalai._C.cpu_optim', ['cpu_adam.cpp'], extra_cuda_flags, extra_cxx_flags)) -setup( - name='colossalai', - version=get_version(), - packages=find_packages(exclude=( - 'benchmark', - 'docker', - 'tests', - 'docs', - 'examples', - 'tests', - 'scripts', - 'requirements', - '*.egg-info', - )), - description='An integrated large-scale model training system with efficient parallelization techniques', - long_description=fetch_readme(), - long_description_content_type='text/markdown', - license='Apache Software License 2.0', - url='https://www.colossalai.org', - project_urls={ - 'Forum': 'https://github.com/hpcaitech/ColossalAI/discussions', - 'Bug Tracker': 'https://github.com/hpcaitech/ColossalAI/issues', - 'Examples': 'https://github.com/hpcaitech/ColossalAI-Examples', - 'Documentation': 'http://colossalai.readthedocs.io', - 'Github': 'https://github.com/hpcaitech/ColossalAI', - }, - ext_modules=ext_modules, - cmdclass={'build_ext': BuildExtension} if ext_modules else {}, - install_requires=fetch_requirements('requirements/requirements.txt'), - entry_points=''' +setup(name='colossalai', + version=get_version(), + packages=find_packages(exclude=( + 'benchmark', + 'docker', + 'tests', + 'docs', + 'examples', + 'tests', + 'scripts', + 'requirements', + '*.egg-info', + )), + description='An integrated large-scale model training system with efficient parallelization techniques', + long_description=fetch_readme(), + long_description_content_type='text/markdown', + license='Apache Software License 2.0', + url='https://www.colossalai.org', + project_urls={ + 'Forum': 'https://github.com/hpcaitech/ColossalAI/discussions', + 'Bug Tracker': 'https://github.com/hpcaitech/ColossalAI/issues', + 'Examples': 'https://github.com/hpcaitech/ColossalAI-Examples', + 'Documentation': 'http://colossalai.readthedocs.io', + 'Github': 'https://github.com/hpcaitech/ColossalAI', + }, + ext_modules=ext_modules, + cmdclass={'build_ext': BuildExtension} if ext_modules else {}, + install_requires=fetch_requirements('requirements/requirements.txt'), + entry_points=''' [console_scripts] colossalai=colossalai.cli:cli ''', - python_requires='>=3.6', - classifiers=[ - 'Programming Language :: Python :: 3', - 'License :: OSI Approved :: Apache Software License', - 'Environment :: GPU :: NVIDIA CUDA', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'Topic :: System :: Distributed Computing', - ], -) + python_requires='>=3.6', + classifiers=[ + 'Programming Language :: Python :: 3', + 'License :: OSI Approved :: Apache Software License', + 'Environment :: GPU :: NVIDIA CUDA', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: System :: Distributed Computing', + ], + package_data={'colossalai': ['_C/*.pyi']}) diff --git a/tests/test_optimizer/test_cpu_adam.py b/tests/test_optimizer/test_cpu_adam.py index 64149b5a4..dff14fbcc 100644 --- a/tests/test_optimizer/test_cpu_adam.py +++ b/tests/test_optimizer/test_cpu_adam.py @@ -1,4 +1,5 @@ import math + import torch from colossalai.testing import parameterize @@ 
-66,8 +67,8 @@ def test_cpu_adam(adamw, step, p_dtype, g_dtype): exp_avg_sq_copy = exp_avg_sq.clone() try: - import cpu_adam - cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw) + import colossalai._C.cpu_optim + cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw) except: raise ImportError("Import cpu adam error, please install colossal from source code") diff --git a/tests/test_optimizer/test_fused_adam_kernel.py b/tests/test_optimizer/test_fused_adam_kernel.py index 6e0aaf45f..2291b0ce6 100644 --- a/tests/test_optimizer/test_fused_adam_kernel.py +++ b/tests/test_optimizer/test_fused_adam_kernel.py @@ -1,8 +1,8 @@ -from numpy import dtype +import math + import torch import torch.nn as nn - -import math +from numpy import dtype from colossalai.testing import parameterize from colossalai.utils import multi_tensor_applier @@ -47,11 +47,11 @@ def torch_adam_update( @parameterize('g_dtype', [torch.float, torch.half]) def test_adam(adamw, step, p_dtype, g_dtype): try: - import colossal_C - fused_adam = colossal_C.multi_tensor_adam + import colossalai._C.fused_optim + fused_adam = colossalai._C.fused_optim.multi_tensor_adam dummy_overflow_buf = torch.cuda.IntTensor([0]) except: - raise ImportError("No colossal_C kernel installed.") + raise ImportError("No colossalai._C.fused_optim kernel installed.") count = 0
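Taken together, the patch maps the old flat extension names (colossal_C, cpu_adam, colossal_moe_cuda, colossal_layer_norm_cuda, colossal_scaled_masked_softmax, colossal_scaled_upper_triang_masked_softmax, colossal_multihead_attention) onto submodules of colossalai._C. A hedged sketch of driving the renamed fused Adam kernel through multi_tensor_applier, following the test above; the tensor-list order [grads, params, exp_avgs, exp_avg_sqs] is an assumption based on the optimizer wrappers and is not shown explicitly in this diff:

import torch

from colossalai.utils import multi_tensor_applier

try:
    import colossalai._C.fused_optim as fused_optim
except ImportError:
    raise ImportError("No colossalai._C.fused_optim kernel installed.")

if torch.cuda.is_available():
    p = torch.rand(64, device='cuda')
    g = torch.rand(64, device='cuda')
    m = torch.zeros(64, device='cuda')
    v = torch.zeros(64, device='cuda')
    dummy_overflow_buf = torch.cuda.IntTensor([0])
    # Scalar arguments follow the multi_tensor_adam signature in fused_optim.pyi:
    # lr, beta1, beta2, epsilon, step, mode (1 = AdamW, per FusedAdam.adamw_mode),
    # bias_correction, weight_decay. The tensor-list order is assumed, as noted above.
    multi_tensor_applier(fused_optim.multi_tensor_adam, dummy_overflow_buf, [[g], [p], [m], [v]],
                         1e-3, 0.9, 0.999, 1e-8, 1, 1, 1, 0.0)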