[NFC] Hotfix/format (#984)

* [NFC] Polish colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu code style. (#937)

* [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h code style (#939)

* [NFC] polish colossalai/kernel/cuda_native/csrc/cpu_adam.cpp code style (#936)

* [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h code style (#938)

* [NFC] polish moe_cuda_kernel.cu code style (#940)

Co-authored-by: Xiao Ye <xiaoye2@illinois.edu>

* [NFC] polish colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu code style (#943)

* [NFC] polish colossalai/kernel/cuda_native/csrc/moe_cuda.cpp code style (#942)

* [NFC] polish colossalai/kernel/cuda_native/csrc/cpu_adam.h code style (#945)

* [NFC] polish colossalai/kernel/jit/bias_gelu.py code style (#946)

Co-authored-by: jnbai <897086360@qq.com>

* [NFC] polish colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu code style (#949)

Co-authored-by: Jiatong <jiatong.han@u.nus.edu>

* [NFC] polish colossalai/builder/pipeline.py code style (#951)

* [NFC] polish colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp code style (#952)

* [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu code style (#953)

Co-authored-by: 何晓昕 <cautious@hexiaoxins-MacBook-Pro.local>

* [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu code style (#954)

* [NFC] polish colossalai/kernel/cuda_native/scaled_softmax.py code style (#955)

* [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/context.h code style (#956)

Co-authored-by: RichardoLuo <14049555596@qq.com>

* [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h code style (#957)

* [NFC] polish colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu code style (#958)

* [NFC] polish colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h code style (#962)

* [NFC] polish colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp code style (#959)

* [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu code style (#963)

Co-authored-by: “Arsmart123 <202476410arsmart@gmail.com>

* [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h code style (#964)

* [NFC] polish __init__.py code style (#965)

* [NFC] polish colossalai/nn/layer/parallel_3d/layers.py code style (#966)

* [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h code style (#968)

* [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h code style (#970)

* [NFC] polish colossalai/nn/layer/parallel_2p5d/layers.py code style (#972)

* [NFC] polish colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp code style (#973)

* [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu code style (#974)

* [NFC] polish colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu code style (#977)

* [NFC] polish colossalai/nn/layer/parallel_2d/layers.py code style (#976)

* [NFC] polish colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu code style (#978)

* [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu code style (#979)

* [NFC] polish colossalai/kernel/cuda_native/layer_norm.py code style (#980)

* [NFC] polish colossalai/nn/layer/utils/common.py code style (#983)

Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
Co-authored-by: yuxuan-lou <83441848+yuxuan-lou@users.noreply.github.com>
Co-authored-by: Geng Zhang <34452939+zxgx@users.noreply.github.com>
Co-authored-by: Maruyama_Aya <38985202+MaruyamaAya@users.noreply.github.com>
Co-authored-by: XYE <92607131+Itok2000u@users.noreply.github.com>
Co-authored-by: Xiao Ye <xiaoye2@illinois.edu>
Co-authored-by: HaoyuQin <79465534+coder-chin@users.noreply.github.com>
Co-authored-by: wky <64853922+wangkuangyi@users.noreply.github.com>
Co-authored-by: bajiaoyu517 <59548007+bajiaoyu517@users.noreply.github.com>
Co-authored-by: luoling-LC <105470086+luoling-LC@users.noreply.github.com>
Co-authored-by: jnbai <897086360@qq.com>
Co-authored-by: JT.Han <59948448+JThh@users.noreply.github.com>
Co-authored-by: Jiatong <jiatong.han@u.nus.edu>
Co-authored-by: xyupeng <99191637+xyupeng@users.noreply.github.com>
Co-authored-by: Sze-qq <68757353+Sze-qq@users.noreply.github.com>
Co-authored-by: Cautiousss <48676630+Cautiousss@users.noreply.github.com>
Co-authored-by: 何晓昕 <cautious@hexiaoxins-MacBook-Pro.local>
Co-authored-by: Luxios22 <67457897+Luxios22@users.noreply.github.com>
Co-authored-by: Wangbo Zhao(黑色枷锁) <56866854+wangbo-zhao@users.noreply.github.com>
Co-authored-by: RichardoLuo <50363844+RichardoLuo@users.noreply.github.com>
Co-authored-by: RichardoLuo <14049555596@qq.com>
Co-authored-by: doubleHU <98150031+huxin711@users.noreply.github.com>
Co-authored-by: runluo <68489000+run-qiao@users.noreply.github.com>
Co-authored-by: MaxT <854721132@qq.com>
Co-authored-by: superhao1995 <804673818@qq.com>
Co-authored-by: ziyu huang <huang0ziyu@gmail.com>
Co-authored-by: “Arsmart123 <202476410arsmart@gmail.com>
Co-authored-by: Yuer867 <62204893+Yuer867@users.noreply.github.com>
Co-authored-by: lucasliunju <lucasliunju@gmail.com>
Co-authored-by: LuGY <74758262+Gy-Lu@users.noreply.github.com>
Co-authored-by: ExtremeViscent <zhangyiqi55732@sina.com>
Co-authored-by: Xu Kai <xukai16@foxmail.com>
Co-authored-by: Zirui Zhu <zhuzr21@gmail.com>
Co-authored-by: Ofey Chan <ofey206@gmail.com>
Co-authored-by: DouJS <dujiangsu@163.com>
Co-authored-by: Jie Zhu <chore.08-protist@icloud.com>
Co-authored-by: shenggan <csg19971016@gmail.com>
Co-authored-by: Kai Wang (Victor Kai) <37533040+kaiwang960112@users.noreply.github.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: Ziheng Qin <37519855+henryqin1997@users.noreply.github.com>
binmakeswell 2022-05-17 09:54:49 +08:00 committed by GitHub
parent 5898ccf38b
commit 0772828fba
35 changed files with 666 additions and 784 deletions

colossalai/__init__.py

@@ -2,3 +2,4 @@ from .initialize import (initialize, launch, launch_from_openmpi,
                          launch_from_slurm, launch_from_torch, get_default_parser)

__version__ = '0.0.1'

colossalai/builder/pipeline.py

@@ -251,9 +251,9 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo
    partitions = partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
    module_list = []
    for start, end in partitions[pipeline_rank]:
-        module_list.append(nn.Sequential(*[nn.Identity() for _ in range(start)],
-                                         *layers[start:end],
-                                         *[nn.Identity() for _ in range(len(layers) - end)]))
+        module_list.append(
+            nn.Sequential(*[nn.Identity() for _ in range(start)], *layers[start:end],
+                          *[nn.Identity() for _ in range(len(layers) - end)]))
    if verbose:
        logger = get_dist_logger()
        logger.info(f'Total {len(layers)} layers', ranks=[0])

@@ -264,4 +264,3 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo
        log_str += '\n'.join([str(layer) for layer in layers[start:end]]) + '\n'
    logger.info(log_str, ranks=[0])
    return nn.ModuleList(module_list) if len(module_list) > 1 else module_list[0]

colossalai/kernel/cuda_native/csrc/cpu_adam.cpp

@@ -20,12 +20,14 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
*/
#include "cpu_adam.h"
+#include <iostream>
#include <math.h>
+#include <memory>
#include <omp.h>
#include <string.h>
#include <torch/extension.h>
-#include <iostream>
-#include <memory>
#include <type_traits>
#include <unordered_map>

@@ -82,8 +84,7 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
-    if ((t + TILE) > rounded_size)
-      copy_size = rounded_size - t;
+    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
    size_t offset = copy_size + t;
#pragma omp parallel for

@@ -145,8 +146,7 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
  if (_param_size > rounded_size) {
    for (size_t t = rounded_size; t < _param_size; t += TILE) {
      size_t copy_size = TILE;
-      if ((t + TILE) > _param_size)
-        copy_size = _param_size - t;
+      if ((t + TILE) > _param_size) copy_size = _param_size - t;
      size_t offset = copy_size + t;
#pragma omp parallel for

@@ -235,8 +235,7 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
-    if ((t + TILE) > rounded_size)
-      copy_size = rounded_size - t;
+    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
    size_t offset = copy_size + t;
#pragma omp parallel for

@@ -321,7 +320,6 @@ int create_adam_optimizer(int optimizer_id, float alpha = 1e-3,
  s_optimizers[optimizer_id] = opt;

  if (should_log) {
    std::string avx_type = "";
#if defined(__AVX512__)
    avx_type = "AVX512";

@@ -386,8 +384,7 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
-    if ((t + TILE) > rounded_size)
-      copy_size = rounded_size - t;
+    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
    size_t offset = copy_size + t;
#pragma omp parallel for

@@ -463,43 +460,29 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
              grad_half_precision, loss_scale);
}

-int adam_step(int optimizer_id,
-              size_t step,
-              float lr,
-              float beta1,
-              float beta2,
-              float epsilon,
-              float weight_decay,
-              bool bias_correction,
-              torch::Tensor& params,
-              torch::Tensor& grads,
-              torch::Tensor& exp_avg,
-              torch::Tensor& exp_avg_sq,
-              float loss_scale)
-{
-  auto params_c = params.contiguous();
-  auto grads_c = grads.contiguous();
-  auto exp_avg_c = exp_avg.contiguous();
-  auto exp_avg_sq_c = exp_avg_sq.contiguous();
+int adam_step(int optimizer_id, size_t step, float lr, float beta1, float beta2,
+              float epsilon, float weight_decay, bool bias_correction,
+              torch::Tensor &params, torch::Tensor &grads,
+              torch::Tensor &exp_avg, torch::Tensor &exp_avg_sq,
+              float loss_scale) {
+  auto params_c = params.contiguous();
+  auto grads_c = grads.contiguous();
+  auto exp_avg_c = exp_avg.contiguous();
+  auto exp_avg_sq_c = exp_avg_sq.contiguous();

-  float* params_ptr = (float*)params_c.data_ptr();
-  float* grads_ptr = (float*)grads_c.data_ptr();
-  float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
-  float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
+  float *params_ptr = (float *)params_c.data_ptr();
+  float *grads_ptr = (float *)grads_c.data_ptr();
+  float *exp_avg_ptr = (float *)exp_avg_c.data_ptr();
+  float *exp_avg_sq_ptr = (float *)exp_avg_sq_c.data_ptr();
  std::shared_ptr<Adam_Optimizer> opt =
      std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
  opt->IncrementStep(step, beta1, beta2);
  opt->update_state(lr, epsilon, weight_decay, bias_correction);
-  opt->Step_8(params_ptr,
-              grads_ptr,
-              exp_avg_ptr,
-              exp_avg_sq_ptr,
-              params_c.numel(),
-              (params.options().dtype() == at::kHalf),
-              (grads.options().dtype() == at::kHalf),
-              loss_scale);
+  opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
+              params_c.numel(), (params.options().dtype() == at::kHalf),
+              (grads.options().dtype() == at::kHalf), loss_scale);

  return 0;
}

int destroy_adam_optimizer(int optimizer_id) {

colossalai/kernel/cuda_native/csrc/cpu_adam.h

@@ -48,10 +48,10 @@ SOFTWARE
#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
#define SIMD_LOAD_HALF(x) \
  _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
#define SIMD_STORE_HALF(x, d) \
  _mm256_store_ps( \
      x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
#elif defined(__AVX256__) or defined(__AVX2__)

@@ -66,8 +66,8 @@ SOFTWARE
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
#define SIMD_STORE_HALF(x, d) \
  _mm_store_ps( \
      x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
#endif

@@ -83,19 +83,25 @@ union AVX_Data {
#endif

#define STEP(SPAN) \
  void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \
                   float *_exp_avg_sq, size_t _param_size, \
                   bool param_half_precision = false, \
                   bool grad_half_precision = false, float loss_scale = -1);

class Adam_Optimizer {
 public:
  Adam_Optimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
                 float eps = 1e-8, float weight_decay = 0,
                 bool adamw_mode = true)
-      : _alpha(alpha), _betta1(betta1), _betta2(betta2), _eps(eps),
-        _weight_decay(weight_decay), _betta1_t(1.0), _betta2_t(1.0), _step(0),
+      : _alpha(alpha),
+        _betta1(betta1),
+        _betta2(betta2),
+        _eps(eps),
+        _weight_decay(weight_decay),
+        _betta1_t(1.0),
+        _betta2_t(1.0),
+        _step(0),
        _adamw_mode(adamw_mode) {}
  ~Adam_Optimizer() {}

@@ -135,7 +141,7 @@ public:
    }
  }

 private:
  float _alpha;
  float _betta1;
  float _betta2;

colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu

@@ -16,7 +16,7 @@ __global__ void ls_cross_entropy_fw_kernel(
  const int left_idx = block_start + threadIdx.x;
  const int right_idx = (blockIdx.x + 1) * vocab_size;
  float max_input[1] = {REDUCE_FLOAT_INF_NEG};
  float sum_logits[2] = {0.f, 0.f};  // logit and logit exp
  int target_tid = targets[blockIdx.x];

  if (target_tid == padding_idx) {

colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu

@@ -1,10 +1,10 @@
+#include <cooperative_groups.h>
#include <chrono>
#include <ctime>

#include "kernels.h"
-#include <cooperative_groups.h>

namespace cg = cooperative_groups;
curandStatePhilox4_32_10_t *curandstate;

@@ -165,8 +165,7 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,
  const float scale = 1.f / (1.f - ratio);
  int i = blockIdx.x * blockDim.x + threadIdx.x;
-  if (i * 4 >= total_count)
-    return;
+  if (i * 4 >= total_count) return;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);

@@ -202,8 +201,7 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,
  int i = blockIdx.x * blockDim.x + threadIdx.x;
-  if (i * 8 >= total_count)
-    return;
+  if (i * 8 >= total_count) return;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);

@@ -261,8 +259,7 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
  const float scale = 1.f / (1.f - ratio);
  int i = blockIdx.x * blockDim.x + threadIdx.x;
-  if (i * 4 >= total_count)
-    return;
+  if (i * 4 >= total_count) return;

  uint8_t m[4];

@@ -289,8 +286,7 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
  int i = blockIdx.x * blockDim.x + threadIdx.x;
-  if (i * 8 >= total_count)
-    return;
+  if (i * 8 >= total_count) return;

  float4 *out4 = reinterpret_cast<float4 *>(out);
  const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);

@@ -380,8 +376,7 @@ __global__ void ls_dropout_res_bias_kernel(
  const float scale = 1.f / (1.f - ratio);
  int i = blockIdx.x * blockDim.x + threadIdx.x;
-  if (i * 4 >= total_count)
-    return;
+  if (i * 4 >= total_count) return;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);

@@ -424,8 +419,7 @@ __global__ void ls_dropout_res_bias_kernel(
  int i = blockIdx.x * blockDim.x + threadIdx.x;
-  if (i * 8 >= total_count)
-    return;
+  if (i * 8 >= total_count) return;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);

@@ -565,11 +559,9 @@ __global__ void ls_dropout_bias_bwd_kernel(
  }
  __syncthreads();

-  for (int i = 1; i < 32; i <<= 1)
-    sum += g.shfl_down(sum, i);
+  for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i);

-  if (y == 0)
-    tile[0][x] = sum;
+  if (y == 0) tile[0][x] = sum;
  __syncthreads();

  if (threadIdx.x < 8) {

@@ -621,11 +613,9 @@ __global__ void ls_dropout_bias_bwd_kernel(
  }
  __syncthreads();

-  for (int i = 1; i < WARP_SIZE; i <<= 1)
-    sum += g.shfl_down(sum, i);
+  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);

-  if (y == 0)
-    tile[0][x] = sum;
+  if (y == 0) tile[0][x] = sum;
  __syncthreads();

  if (threadIdx.x < 8) {

@@ -689,8 +679,7 @@ __global__ void ls_dropout_act_bias_kernel(
  const float scale = 1.f / (1.f - ratio);
  int i = blockIdx.x * blockDim.x + threadIdx.x;
-  if (i * 4 >= total_count)
-    return;
+  if (i * 4 >= total_count) return;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);

@@ -735,8 +724,7 @@ __global__ void ls_dropout_act_bias_kernel(
  int i = blockIdx.x * blockDim.x + threadIdx.x;
-  if (i * 8 >= total_count)
-    return;
+  if (i * 8 >= total_count) return;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);

@@ -897,11 +885,9 @@ __global__ void ls_dropout_act_bias_bwd_kernel(
  float sum = tile[threadIdx.y][threadIdx.x];
  __syncthreads();

-  for (int i = 1; i < WARP_SIZE; i <<= 1)
-    sum += g.shfl_down(sum, i);
+  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);

-  if (threadIdx.x == 0)
-    tile[0][threadIdx.y] = sum;
+  if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
  __syncthreads();

  if (threadIdx.y == 0) {

colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu

@@ -1,7 +1,7 @@
-#include "kernels.h"
#include <cooperative_groups.h>
+#include "kernels.h"

namespace cg = cooperative_groups;

/**

colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h

@@ -13,22 +13,23 @@ const float REDUCE_FLOAT_INF_NEG = -100000000.f;
const float REDUCE_FLOAT_INF_POS = 100000000.f;
const unsigned int WARP_REDUCE_SIZE = 32;

-template <typename T> __forceinline__ __device__ T warpReduceSum(T val) {
+template <typename T>
+__forceinline__ __device__ T warpReduceSum(T val) {
  for (int mask = (WARP_REDUCE_SIZE >> 1); mask > 0; mask >>= 1)
    val += __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, WARP_REDUCE_SIZE);
  return val;
}

/* Calculate the sum of all elements in a block */
-template <typename T> __forceinline__ __device__ T blockReduceSum(T val) {
+template <typename T>
+__forceinline__ __device__ T blockReduceSum(T val) {
  static __shared__ T shared[32];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  val = warpReduceSum<T>(val);

-  if (lane == 0)
-    shared[wid] = val;
+  if (lane == 0) shared[wid] = val;
  __syncthreads();

  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)0.0f;

@@ -56,10 +57,10 @@ __inline__ __device__ void warpReduce<ReduceType::kMax, 1>(float *pval) {
template <>
__inline__ __device__ void warpReduce<ReduceType::kMax, 2>(float *pval) {
  float val0_tmp, val1_tmp;
#define WarpReduceMaxOneStep(a, b) \
  val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval), a, b); \
  val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
  *(pval) = max(val0_tmp, *(pval)); \
  *(pval + 1) = max(val1_tmp, *(pval + 1));

  WarpReduceMaxOneStep(16, 32);

@@ -88,10 +89,10 @@ __inline__ __device__ void warpReduce<ReduceType::kSum, 1>(float *pval) {
template <>
__inline__ __device__ void warpReduce<ReduceType::kSum, 2>(float *pval) {
  float val0_tmp, val1_tmp;
#define WarpReduceSumOneStep(a, b) \
  val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \
  val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
  *(pval + 0) += val0_tmp; \
  *(pval + 1) += val1_tmp

  WarpReduceSumOneStep(16, 32);

@@ -106,14 +107,14 @@ __inline__ __device__ void warpReduce<ReduceType::kSum, 2>(float *pval) {
template <>
__inline__ __device__ void warpReduce<ReduceType::kSum, 4>(float *pval) {
  float val0_tmp, val1_tmp, val2_tmp, val3_tmp;
#define WarpReduceSumOneStep(a, b) \
  val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \
  val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
  val2_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 2), a, b); \
  val3_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 3), a, b); \
  *(pval + 0) += val0_tmp; \
  *(pval + 1) += val1_tmp; \
  *(pval + 2) += val2_tmp; \
  *(pval + 3) += val3_tmp

  WarpReduceSumOneStep(16, 32);

colossalai/kernel/cuda_native/csrc/kernels/include/context.h

@@ -9,7 +9,7 @@
#include "cuda_util.h"

class Context {
 public:
  Context() : _stream(nullptr) {
    CHECK_GPU_ERROR(cublasCreate(&_cublasHandle));
  }

@@ -30,7 +30,7 @@ public:
  cublasHandle_t get_cublashandle() { return _cublasHandle; }

 private:
  cudaStream_t _stream;
  cublasHandle_t _cublasHandle;
};

colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h

@@ -8,8 +8,9 @@
#include "cuda_util.h"

-template <typename T> class CrossEntropyLayer {
-public:
+template <typename T>
+class CrossEntropyLayer {
+ public:
  CrossEntropyLayer(float epsilon, int padding_idx, int max_batch_tokens);

  virtual ~CrossEntropyLayer();

@@ -22,7 +23,7 @@ public:
  void set_cur_batch_shape(int batch_size, int seq_len, int vocab_size);

 private:
  void allocate_mem_buffer() {
    // allocate local gpu memory
    _loss_buffer = cuda_malloc<float>(_max_batch_tokens * 2);

colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h

@@ -20,7 +20,8 @@ void check_gpu_error(T result, char const *const func, const char *const file,
template <typename T>
void print_vec(const T *outv, std::string outn, int num_output_ele);

-template <typename T> T *cuda_malloc(size_t ele_num);
+template <typename T>
+T *cuda_malloc(size_t ele_num);

void cuda_free(void *pdata);

@@ -28,6 +29,6 @@ template <typename T>
void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf,
                   std::string file, int line, cudaStream_t stream);

#define CHECK_NAN_INF(ptr, size, stream) \
  check_nan_inf((ptr), (size), true, __FILE__, __LINE__, (stream)); \
  check_nan_inf((ptr), (size), false, __FILE__, __LINE__, (stream))

colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h

@@ -3,12 +3,14 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <string>

#include "kernels.h"

-template <typename T> class Dropout {
-public:
+template <typename T>
+class Dropout {
+ public:
  struct Config {
    float ratio;
    bool training;

@@ -88,7 +90,7 @@ public:
  void SetTrainingMode(bool training) { _config.training = training; }

 private:
  uint8_t *_mask;
  Config _config;
};

colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h

@@ -13,14 +13,16 @@
#include "cublas_wrappers.h"
#include "kernels.h"

-template <typename T> class FeedForward {
-public:
+template <typename T>
+class FeedForward {
+ public:
  struct Config {
    int outputSize;
    int inputSize;
    std::array<int, 3> gemm_algos;
    Config(int outputs, int inputs)
-        : outputSize(outputs), inputSize(inputs),
+        : outputSize(outputs),
+          inputSize(inputs),
          gemm_algos(std::array<int, 3>({99, 99, 99})) {}
  };

@@ -61,6 +63,6 @@ public:
    config_.inputSize = inputSize;
  }

 private:
  Config config_;
};

colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h

@@ -10,8 +10,9 @@
using namespace std;

-template <typename T> class Softmax {
-public:
+template <typename T>
+class Softmax {
+ public:
  struct Config {
    size_t nhead;
    Config(size_t nhead) : nhead(nhead) {}

@@ -36,6 +37,6 @@ public:
  void reset_size(size_t nhead) { config_.nhead = nhead; }

 private:
  Config config_;
};

colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu

@@ -1,6 +1,5 @@
#include "block_reduce.h"
#include "kernels.h"
#include <cooperative_groups.h>

namespace cg = cooperative_groups;

colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu

@@ -1,3 +1,4 @@
+#include <cooperative_groups.h>
#include <math.h>
#include <cub/block/block_load.cuh>

@@ -6,8 +7,6 @@
#include "block_reduce.h"
#include "kernels.h"
-#include <cooperative_groups.h>

namespace cg = cooperative_groups;

const float EPSILON = 1e-8f;

@@ -120,7 +119,7 @@ __global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len,
      BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
                                 to_len);
    }
  }  // blockIdx.x
}

template <typename T, int block_dim, int ele_per_thread>

@@ -198,7 +197,7 @@ __global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len,
      BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
                                 to_len);
    }
  }  // blockIdx.x
}

/*

@@ -304,8 +303,7 @@ __global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) {
  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

-  for (int i = 1; i < WARP_SIZE; i <<= 1)
-    sum += g.shfl_xor(sum, i);
+  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);

#pragma unroll
  for (int i = 0; i < ITERATIONS; ++i) {

colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp

@@ -2,11 +2,13 @@
 * https://github.com/NVIDIA/apex
 * with minor changes. */

-#include "compat.h"
-#include <cassert>
#include <torch/extension.h>
+#include <cassert>
#include <vector>
+#include "compat.h"

namespace {

void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,

@@ -65,7 +67,7 @@ void check_args(at::Tensor input, at::IntArrayRef normalized_shape,
  check_args(input, normalized_shape, n1, n2);
  check_args(normalized_shape, gamma, beta);
}
}  // namespace

void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar,
                     at::Tensor *input, int n1, int n2,

@@ -73,17 +75,16 @@ void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar,
                     at::Tensor *beta, double epsilon);

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x); \
  CHECK_CONTIGUOUS(x)

std::vector<at::Tensor> layer_norm_affine(at::Tensor input,
                                          at::IntArrayRef normalized_shape,
                                          at::Tensor gamma, at::Tensor beta,
                                          double epsilon) {
  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);

@@ -109,11 +110,10 @@ void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean,
                              double epsilon, at::Tensor *grad_input,
                              at::Tensor *grad_gamma, at::Tensor *grad_beta);

-std::vector<at::Tensor>
-layer_norm_gradient_affine(at::Tensor dout, at::Tensor mean, at::Tensor invvar,
-                           at::Tensor input, at::IntArrayRef normalized_shape,
-                           at::Tensor gamma, at::Tensor beta, double epsilon) {
+std::vector<at::Tensor> layer_norm_gradient_affine(
+    at::Tensor dout, at::Tensor mean, at::Tensor invvar, at::Tensor input,
+    at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta,
+    double epsilon) {
  CHECK_INPUT(dout);
  CHECK_INPUT(mean);
  CHECK_INPUT(invvar);

colossalai/kernel/cuda_native/csrc/moe_cuda.cpp

@@ -15,25 +15,24 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
                                       torch::Tensor logits, torch::Tensor mask,
                                       torch::Tensor dest_idx);

-std::vector<torch::Tensor>
-moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
-                          torch::Tensor expert_tokens, torch::Tensor logits,
-                          torch::Tensor mask, torch::Tensor dest_idx);
+std::vector<torch::Tensor> moe_combine_cuda_backward(
+    int s, int e, int c, int h, torch::Tensor tokens_grad,
+    torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
+    torch::Tensor dest_idx);

torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask);

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x); \
  CHECK_CONTIGUOUS(x)

torch::Tensor moe_dispatch_forward(int s, int ec, int h,
                                   torch::Tensor batch_tokens,
                                   torch::Tensor mask, torch::Tensor dest_idx) {
  CHECK_INPUT(batch_tokens);
  CHECK_CUDA(mask);
  CHECK_CUDA(dest_idx);

@@ -45,7 +44,6 @@ torch::Tensor moe_dispatch_backward(int s, int ec, int h,
                                    torch::Tensor expert_grad,
                                    torch::Tensor mask,
                                    torch::Tensor dest_idx) {
  CHECK_INPUT(expert_grad);
  CHECK_CUDA(mask);
  CHECK_CUDA(dest_idx);

@@ -57,7 +55,6 @@ torch::Tensor moe_combine_forward(int s, int e, int c, int h,
                                  torch::Tensor expert_tokens,
                                  torch::Tensor logits, torch::Tensor mask,
                                  torch::Tensor dest_idx) {
  CHECK_INPUT(expert_tokens);
  CHECK_INPUT(logits);
  CHECK_CUDA(mask);

@@ -67,11 +64,12 @@ torch::Tensor moe_combine_forward(int s, int e, int c, int h,
                                  dest_idx);
}

-std::vector<torch::Tensor>
-moe_combine_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
-                     torch::Tensor expert_tokens, torch::Tensor logits,
-                     torch::Tensor mask, torch::Tensor dest_idx) {
+std::vector<torch::Tensor> moe_combine_backward(int s, int e, int c, int h,
+                                                torch::Tensor tokens_grad,
+                                                torch::Tensor expert_tokens,
+                                                torch::Tensor logits,
+                                                torch::Tensor mask,
+                                                torch::Tensor dest_idx) {
  CHECK_INPUT(tokens_grad);
  CHECK_INPUT(logits);
  CHECK_CUDA(mask);

colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu

@@ -1,12 +1,13 @@
-#include "block_reduce.h"
-#include <cub/cub.cuh>
#include <cuda.h>
#include <cuda_fp16.h>
#include <torch/extension.h>
+#include <cub/cub.cuh>
+#include "block_reduce.h"

template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;

@@ -28,7 +29,6 @@ __device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;

@@ -51,7+51,6 @@ __device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
                                 const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;

@@ -75,7 +74,6 @@ __device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
                                 const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;

@@ -105,7 +103,6 @@ __device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
                               const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;

@@ -134,7 +131,6 @@ __device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
                               T *weight_grad, const T weight, const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;

@@ -164,15 +160,13 @@ __device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
  blockReduce<ReduceType::kSum, 1>(&thread_sum);

-  if (threadIdx.x == 0)
-    *weight_grad = static_cast<T>(thread_sum);
+  if (threadIdx.x == 0) *weight_grad = static_cast<T>(thread_sum);
}

template <typename T, int block_size, int pack_size>
__device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row,
                               const T weight1, const T weight2,
                               const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;

@@ -204,7 +198,6 @@ __device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row,
                               T *tks_row1, T *tks_row2, T *weight_grad1,
                               T *weight_grad2, const T weight1,
                               const T weight2, const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;

@@ -251,7 +244,6 @@ template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2,
                                      const int cols, const int indicator1,
                                      const int indicator2) {
  if (indicator1 != 0 && indicator2 != 0)
    moe_dpch_two_fwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
                                               cols);

@@ -267,7 +259,6 @@ template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2,
                                      const int cols, const int indicator1,
                                      const int indicator2) {
  if (indicator1 != 0 && indicator2 != 0)
    moe_dpch_two_bwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
                                               cols);

@@ -283,7 +274,6 @@ template <typename T, int block_size, int pack_size>
__global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input,
                                    int *mask1, int *mask2, int *dest1,
                                    int *dest2, const int h) {
  int row = blockIdx.x;
  int indicator2 = mask2 == nullptr ? 0 : mask2[row];
  moe_dpch_fwd_selector<T, block_size, pack_size>(

@@ -295,7 +285,6 @@ template <typename T, int block_size, int pack_size>
__global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1,
                                    int *mask2, int *dest1, int *dest2,
                                    const int h) {
  int row = blockIdx.x;
  int indicator2 = mask2 == nullptr ? 0 : mask2[row];
  moe_dpch_bwd_selector<T, block_size, pack_size>(

@@ -310,7 +299,6 @@ __device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row,
                                    const int cols, const T weight1,
                                    const T weight2, const int indicator1,
                                    const int indicator2) {
  if (indicator1 != 0 && indicator2 != 0)
    moe_cb_two_fwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
                                             weight1, weight2, cols);

@@ -328,7 +316,6 @@ __device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row,
                                    T *wt_grad1, T *wt_grad2, const T weight1,
                                    const T weight2, const int indicator1,
                                    const int indicator2) {
  if (indicator1 != 0 && indicator2 != 0)
    moe_cb_two_bwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
                                             tks_row1, tks_row2, wt_grad1,

@@ -348,7 +335,6 @@ __global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens,
                                  T *logits, int *mask1, int *mask2, int *dest1,
                                  int *dest2, const int e, const int c,
                                  const int h) {
  int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
  int indicator2 = mask2 == nullptr ? 0 : mask2[row];
  T *row_log = logits + (row * e);

@@ -363,7 +349,6 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
                                  T *logits, T *logits_grad, int *mask1,
                                  int *mask2, int *dest1, int *dest2,
                                  const int e, const int c, const int h) {
  int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
  int indicator2 = mask2 == nullptr ? 0 : mask2[row];
  T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e);

@@ -379,7 +364,6 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
template <int block_size, int pack_size>
__global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
                              const int e) {
  assert(s % pack_size == 0);
  constexpr int bpack_size = block_size * pack_size;
  int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1;

@@ -426,8 +410,7 @@ __global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
    }
    __syncthreads();

-    if (tid == 0)
-      temp[0] = temp[block_size];
+    if (tid == 0) temp[0] = temp[block_size];
    __syncthreads();

    if (idx + tps < s) {

@@ -453,7 +436,6 @@ template <typename T>
void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1,
                         int *mask2, int *dest1, int *dest2, const int s,
                         const int h) {
  if (h < 256)
    moe_dpch_fwd_kernel<T, 32, 4>
        <<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);

@@ -474,7 +456,6 @@ void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1,
template <typename T>
void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2,
                         int *dest1, int *dest2, const int s, const int h) {
  if (h < 256)
    moe_dpch_bwd_kernel<T, 32, 4>
        <<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);

@@ -496,7 +477,6 @@ template <typename T>
void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits,
                       int *mask1, int *mask2, int *dest1, int *dest2,
                       const int s, const int e, const int c, const int h) {
  if (h < 256)
    moe_cb_fwd_kernel<T, 32, 4><<<s, 32>>>(expert_tokens, combine_tokens,
                                           logits, mask1, mask2, dest1, dest2,

@@ -524,12 +504,11 @@ void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
                       T *logits_grad, int *mask1, int *mask2, int *dest1,
                       int *dest2, const int s, const int e, const int c,
                       const int h) {
  if (h < 256)
    moe_cb_bwd_kernel<T, 32, 4><<<s, 32>>>(tokens_grad, expert_grad, tks,
                                           logits, logits_grad, mask1, mask2,
                                           dest1, dest2, e, c, h);
  else  // if (h < 512)
    moe_cb_bwd_kernel<T, 64, 4><<<s, 64>>>(tokens_grad, expert_grad, tks,
                                           logits, logits_grad, mask1, mask2,
                                           dest1, dest2, e, c, h);

@@ -544,7 +523,6 @@ void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
}

void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {
  if (s <= 256)
    cumsum_kernel<256, 1><<<e, 256>>>(inputs, outputs, s, e);
  else if (s <= 512)

@@ -559,27 +537,26 @@ void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {
// API FUNCTIONS --------------------------------

#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \
  switch (TYPE) { \
    case at::ScalarType::Float: { \
      using scalar_t = float; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::Half: { \
      using scalar_t = at::Half; \
      __VA_ARGS__; \
      break; \
    } \
    default: \
      AT_ERROR(#NAME, " not implemented yet for specific data type."); \
  }

torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
                                        torch::Tensor batch_tokens,
                                        torch::Tensor mask,
                                        torch::Tensor dest_idx) {
  assert(h % 16 == 0);
  auto res = torch::zeros(
      {ec, h},

@@ -601,7 +578,6 @@ torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
                                         torch::Tensor expert_grad,
                                         torch::Tensor mask,
                                         torch::Tensor dest_idx) {
  assert(h % 16 == 0);
  auto res = torch::zeros(
      {s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device()));

@@ -622,7 +598,6 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
                                       torch::Tensor expert_tokens,
                                       torch::Tensor logits, torch::Tensor mask,
                                       torch::Tensor dest_idx) {
  assert(h % 16 == 0);
  assert(expert_tokens.dtype() == logits.dtype());

@@ -643,11 +618,10 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
  return res;
}

-std::vector<torch::Tensor>
-moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
-                          torch::Tensor expert_tokens, torch::Tensor logits,
-                          torch::Tensor mask, torch::Tensor dest_idx) {
+std::vector<torch::Tensor> moe_combine_cuda_backward(
+    int s, int e, int c, int h, torch::Tensor tokens_grad,
+    torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
+    torch::Tensor dest_idx) {
  assert(h % 16 == 0);
  assert(tokens_grad.dtype() == expert_tokens.dtype());
  assert(expert_tokens.dtype() == logits.dtype());

@@ -673,7 +647,6 @@ moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
}

torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) {
  assert(mask.dim() == 2);
  assert(mask.dtype() == torch::kInt32);

View File

@@ -16,7 +16,8 @@
#define BLOCK_SIZE 512
#define ILP 4
template <typename T>
__device__ __forceinline__ bool is_aligned(T *p) {
  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}
@@ -28,11 +29,12 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
  ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
}
template <typename x_t>
struct L2NormFunctor {
  __device__ __forceinline__ void operator()(
      int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
      float *output, float *output_per_tensor, bool per_tensor,
      int max_chunks_per_tensor) {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;
@@ -48,8 +50,8 @@ template <typename x_t> struct L2NormFunctor {
    __shared__ float s_vals[512];
    float vals[ILP];  // = {0}; // this probably works too but I want to be
                      // sure...
    x_t r_x[ILP];
    for (int i = 0; i < ILP; i++) {
      vals[i] = 0.f;
@@ -84,15 +86,14 @@ template <typename x_t> struct L2NormFunctor {
    }
    float val = 0.f;
    for (int i = 0; i < ILP; i++) val += vals[i];
    float final = reduce_block_into_lanes(s_vals, val);
    if (threadIdx.x == 0) {
      if (!isfinite(final))
        *noop_gmem =
            1;  // Blindly fire off a write. These will race but that's ok.
      output[blockIdx.x] += final;
      if (per_tensor)
        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) *
@@ -104,11 +105,12 @@ template <typename x_t> struct L2NormFunctor {
// Probably better to template, but since we are not likely to support other
// norm
template <typename x_t>
struct MaxNormFunctor {
  __device__ __forceinline__ void operator()(
      int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
      float *output, float *output_per_tensor, bool per_tensor,
      int max_chunks_per_tensor) {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;
@@ -124,8 +126,8 @@ template <typename x_t> struct MaxNormFunctor {
    __shared__ float s_vals[512];
    float vals[ILP];  // = {0}; // this probably works too but I want to be
                      // sure...
    x_t r_x[ILP];
    for (int i = 0; i < ILP; i++) {
      vals[i] = 0.f;
@@ -160,15 +162,14 @@ template <typename x_t> struct MaxNormFunctor {
    }
    float val = 0.f;
    for (int i = 0; i < ILP; i++) val = fmaxf(fabsf(val), fabsf(vals[i]));
    float final = reduce_block_into_lanes_max_op(s_vals, val);
    if (threadIdx.x == 0) {
      if (!isfinite(final))
        *noop_gmem =
            1;  // Blindly fire off a write. These will race but that's ok.
      output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final));
      if (per_tensor)
        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) *
@@ -185,13 +186,11 @@ __global__ void cleanup(float *output, float *output_per_tensor, float *ret,
  if (blockIdx.x == 0) {
    float val = 0;
    if (threadIdx.x < 320) val = output[threadIdx.x];
    float final = reduce_block_into_lanes(vals, val);
    if (threadIdx.x == 0) *ret = sqrt(final);
  }
  if (per_tensor) {
@@ -204,8 +203,7 @@ __global__ void cleanup(float *output, float *output_per_tensor, float *ret,
    float final = reduce_block_into_lanes(vals, val);
    if (threadIdx.x == 0) ret_per_tensor[blockIdx.x] = sqrt(final);
  }
}
@@ -217,17 +215,14 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
  if (blockIdx.x == 0) {
    float val = 0;
    if (threadIdx.x < 320) val = output[threadIdx.x];
    if (norm_type == 0) {
      float final = reduce_block_into_lanes_max_op(vals, val);
      if (threadIdx.x == 0) *ret = alpha * (*ret) + beta * final;
    } else {
      float final = reduce_block_into_lanes(vals, val);
      if (threadIdx.x == 0) *ret = sqrt(alpha * (*ret) * (*ret) + beta * final);
    }
  }
@@ -260,10 +255,10 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
  }
}
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
    int chunk_size, at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::optional<bool> per_tensor_python) {
  bool per_tensor =
      per_tensor_python.has_value() ? per_tensor_python.value() : false;
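The L2NormFunctor / cleanup pair above is a two-stage reduction: each block accumulates a partial sum of squares into `output` (and optionally a per-tensor slot), and the `cleanup` kernel folds those partials and takes the square root. A minimal single-array sketch of the same idea in plain CUDA, without the multi_tensor_apply chunk metadata (all names here are hypothetical; assumes blockDim.x is a power of two, e.g. 512):

```cpp
// Stage 1: each block writes one partial sum of squares.
__global__ void partial_sumsq(const float *x, float *partials, int n) {
  __shared__ float smem[512];
  float acc = 0.f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x)
    acc += x[i] * x[i];
  smem[threadIdx.x] = acc;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {  // block tree reduction
    if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s];
    __syncthreads();
  }
  if (threadIdx.x == 0) partials[blockIdx.x] = smem[0];
}

// Stage 2: one block folds the partials and takes the square root, mirroring
// what cleanup() does with reduce_block_into_lanes.
__global__ void finish_l2norm(const float *partials, float *ret, int nblocks) {
  __shared__ float smem[512];
  float acc = 0.f;
  for (int i = threadIdx.x; i < nblocks; i += blockDim.x) acc += partials[i];
  smem[threadIdx.x] = acc;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s];
    __syncthreads();
  }
  if (threadIdx.x == 0) *ret = sqrtf(smem[0]);
}
```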

View File

@@ -15,7 +15,8 @@
#define BLOCK_SIZE 512
#define ILP 4
template <typename T>
__device__ __forceinline__ bool is_aligned(T *p) {
  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}
@@ -28,24 +29,25 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
}
typedef enum {
  MOMENT_MODE_0 = 0,  // L2 regularization mode
  MOMENT_MODE_1 = 1   // Decoupled weight decay mode
} adamMode_t;
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
    int chunk_size, at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::optional<bool> per_tensor_python);
using MATH_T = float;
template <typename T>
struct LAMBStage1Functor {
  __device__ __forceinline__ void operator()(
      int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
      const float beta1, const float beta2, const float beta3,
      const float beta1_correction, const float beta2_correction,
      const float epsilon, adamMode_t mode, const float decay,
      const float *global_grad_norm, const float max_global_grad_norm) {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;
@@ -89,8 +91,7 @@ template <typename T> struct LAMBStage1Functor {
           i_start += blockDim.x) {
        // load
        load_store(l_g, g, 0, i_start);
        if (decay != 0) load_store(l_p, p, 0, i_start);
        load_store(l_m, m, 0, i_start);
        load_store(l_v, v, 0, i_start);
        // unpack
@@ -204,12 +205,12 @@ template <typename T> struct LAMBStage1Functor {
// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
// It computes new parameter value.
template <typename T>
struct LAMBStage2Functor {
  __device__ __forceinline__ void operator()(
      int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl,
      const float *per_tensor_param_norm, const float *per_tensor_update_norm,
      const float learning_rate, const float decay, bool use_nvlamb) {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;
@@ -310,8 +311,7 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
  // Handle grad averaging mode
  float beta3 = 1.0f;
  if (grad_averaging == 1) beta3 = 1 - beta1;
  std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(),
                                                 tensor_lists.begin() + 1);
@@ -330,7 +330,7 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
      tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
      multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                            LAMBStage1Functor<scalar_t_0>(), beta1, beta2,
                            beta3,  // 1-beta1 or 1 depends on averaging mode
                            bias_correction1, bias_correction2, epsilon,
                            (adamMode_t)mode, weight_decay,
                            global_grad_norm.DATA_PTR<float>(), max_grad_norm);)
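LAMBStage1Functor produces the Adam-style `update` per element, with `adamMode_t` choosing between L2 regularization (decay folded into the gradient) and decoupled weight decay (decay added after the moment update), while LAMBStage2Functor rescales that update by the per-tensor trust ratio before the learning-rate step. A scalar sketch of the math the two stages implement, under those assumptions (hypothetical helper functions, not the kernels):

```cpp
#include <cmath>

struct LambScratch {
  float m = 0.f, v = 0.f;  // Adam first/second moments for one element
};

// Stage 1: Adam-style update direction. mode 0 = L2 regularization,
// mode 1 = decoupled weight decay, matching the adamMode_t comments above.
float lamb_stage1(float g, float p, LambScratch &s, float beta1, float beta2,
                  float beta3, float beta1_corr, float beta2_corr, float eps,
                  int mode, float decay, float clip_scale) {
  g *= clip_scale;  // global grad-norm clipping factor, as in the real kernel
  if (mode == 0) g += decay * p;           // L2: decay enters the gradient
  s.m = beta1 * s.m + beta3 * g;           // beta3 is 1-beta1 or 1
  s.v = beta2 * s.v + (1.f - beta2) * g * g;
  float update = (s.m / beta1_corr) / (std::sqrt(s.v / beta2_corr) + eps);
  if (mode == 1) update += decay * p;      // decoupled: decay enters the update
  return update;
}

// Stage 2: scale by the per-tensor trust ratio ||p|| / ||update|| and step.
float lamb_stage2(float p, float update, float param_norm, float update_norm,
                  float lr) {
  float ratio = (param_norm > 0.f && update_norm > 0.f)
                    ? param_norm / update_norm
                    : 1.f;
  return p - lr * ratio * update;
}
```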

View File

@@ -15,7 +15,8 @@
#define BLOCK_SIZE 512
#define ILP 4
template <typename T>
__device__ __forceinline__ bool is_aligned(T *p) {
  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}
@@ -27,7 +28,8 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
  ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
}
template <typename in_t, typename out_t>
struct ScaleFunctor {
  __device__ __forceinline__ void operator()(int chunk_size,
                                             volatile int *noop_gmem,
                                             TensorListMetadata<2> &tl,
@@ -76,8 +78,7 @@ template <typename in_t, typename out_t> struct ScaleFunctor {
        for (int ii = 0; ii < ILP; ii++) {
          r_in[ii] = 0;
          int i = i_start + threadIdx.x + ii * blockDim.x;
          if (i < n && i < chunk_size) r_in[ii] = in[i];
        }
        // note for clarification to future michael:
        // From a pure memory dependency perspective, there's likely no point
@@ -93,14 +94,13 @@ template <typename in_t, typename out_t> struct ScaleFunctor {
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          int i = i_start + threadIdx.x + ii * blockDim.x;
          if (i < n && i < chunk_size) out[i] = r_out[ii];
        }
      }
    }
    if (!finite)
      *noop_gmem =
          1;  // Blindly fire off a write. These will race but that's ok.
  }
};
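ScaleFunctor follows the load/compute/store pattern visible above: each thread handles ILP elements per iteration, scales them in registers, tracks whether every result is finite, and raises the noop flag with a racy write if not. A stripped-down sketch of that pattern for a flat array, without the chunked tensor-list metadata (hypothetical kernel):

```cpp
#define ILP 4

// Minimal sketch of the ILP load/scale/store pattern with a non-finite flag.
__global__ void scale_with_noop_flag(const float *in, float *out, float scale,
                                     int n, volatile int *noop_gmem) {
  bool finite = true;
  for (int i_start = blockIdx.x * blockDim.x * ILP; i_start < n;
       i_start += gridDim.x * blockDim.x * ILP) {
    float r_in[ILP];
#pragma unroll
    for (int ii = 0; ii < ILP; ii++) {  // load into registers
      int i = i_start + threadIdx.x + ii * blockDim.x;
      r_in[ii] = (i < n) ? in[i] : 0.f;
    }
#pragma unroll
    for (int ii = 0; ii < ILP; ii++) {  // compute, check, store
      float r_out = r_in[ii] * scale;
      finite = finite && isfinite(r_out);
      int i = i_start + threadIdx.x + ii * blockDim.x;
      if (i < n) out[i] = r_out;
    }
  }
  if (!finite) *noop_gmem = 1;  // racy write, same convention as the kernel
}
```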

View File

@@ -1,14 +1,15 @@
// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_sgd_kernel.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <assert.h>
#include <cuda_runtime.h>
#include "compat.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
@@ -28,69 +29,53 @@
 * wd_after_momentum : apply weight decay _after_ momentum instead of before
 **/
template <int N, typename T_grad, typename T_weight>
struct SGDFunctor {
  __device__ __forceinline__ void operator()(
      int chunk_size, volatile int *noop_gmem, TensorListMetadata<N> &tl,
      float wd, float momentum, float dampening, float lr, bool nesterov,
      bool first_run, bool wd_after_momentum, float scale) {
    // Early exit if we don't need to do anything
    if (*noop_gmem) return;
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];
    T_grad *grad_in = (T_grad *)tl.addresses[0][tensor_loc];
    grad_in += chunk_idx * chunk_size;
    T_weight *weight_in = (T_weight *)tl.addresses[1][tensor_loc];
    weight_in += chunk_idx * chunk_size;
    T_weight *mom_in = (T_weight *)tl.addresses[2][tensor_loc];
    mom_in += chunk_idx * chunk_size;
    at::Half *model_weights_out = nullptr;
    if (N == 4) {
      model_weights_out = (at::Half *)tl.addresses[3][tensor_loc];
      model_weights_out += chunk_idx * chunk_size;
    }
    n -= chunk_idx * chunk_size;
    // Non-divergent exit condition for the __syncthreads
    float incoming_grads[ILP];
    float incoming_weights[ILP];
    float incoming_moms[ILP];
    for (int i_start = 0; i_start < n && i_start < chunk_size;
         i_start += blockDim.x * ILP) {
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        incoming_grads[ii] = 0;
        incoming_weights[ii] = 0;
        incoming_moms[ii] = 0;
        int i = i_start + threadIdx.x + ii * blockDim.x;
        if (i < n && i < chunk_size) {
          incoming_grads[ii] = static_cast<float>(grad_in[i]) * scale;
          incoming_weights[ii] = static_cast<float>(weight_in[i]);
          incoming_moms[ii] = static_cast<float>(mom_in[i]);
        }
      }
      // note for clarification to future michael:
      // From a pure memory dependency perspective, there's likely no point unrolling
@@ -98,185 +83,128 @@ struct SGDFunctor
      // Put another way, the STGs are dependent on the LDGs, but not on each other.
      // There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        int i = i_start + threadIdx.x + ii * blockDim.x;
        if (i < n && i < chunk_size) {
          // apply weight decay before momentum if necessary
          if (wd != 0.f && !wd_after_momentum)
            incoming_grads[ii] += wd * incoming_weights[ii];
          if (momentum != 0.f) {
            if (!first_run)
              incoming_moms[ii] = incoming_moms[ii] * momentum +
                                  (1.f - dampening) * incoming_grads[ii];
            else  // initialize momentums to current incoming grads
              incoming_moms[ii] = incoming_grads[ii];
            if (nesterov)
              incoming_grads[ii] += momentum * incoming_moms[ii];
            else
              incoming_grads[ii] = incoming_moms[ii];
          }
          // Apply WD after momentum if desired
          if (wd != 0.f && wd_after_momentum)
            incoming_grads[ii] += wd * incoming_weights[ii];
          // adjust the weight and write out
          weight_in[i] += (-lr * incoming_grads[ii]);
          // if necessary, write out an fp16 copy of the weights
          if (N == 4)
            model_weights_out[i] = static_cast<at::Half>(weight_in[i]);
          // also write out the new momentum
          if (momentum != 0.f) mom_in[i] = incoming_moms[ii];
        }
      }
    }
  }
};
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
                           std::vector<std::vector<at::Tensor>> tensor_lists,
                           float wd, float momentum, float dampening, float lr,
                           bool nesterov, bool first_run,
                           bool wd_after_momentum, float scale) {
  auto num_tensors = tensor_lists.size();
  auto grad_type = tensor_lists[0][0].scalar_type();
  auto weight_type = tensor_lists[1][0].scalar_type();
  if (num_tensors == 4)
    for (int i = 0; i < tensor_lists[3].size(); i++)
      TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
                  "Additional output tensors should always be fp16.");
  TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(),
              "expected noop flag to be on the same device as tensors");
  // We have 3 possibilities to handle here, in terms of
  // grad_type, param_type, momentum_type, requires_fp16_copy
  // 1. fp16, fp16, fp16, No
  // 2. fp32, fp32, fp32, No
  // 3. fp16, fp32, fp32, Yes
  // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
  // It's easier to hardcode these possibilities than to use
  // switches etc. to handle the cross-product of cases where
  // we don't want the majority of them.
  // Case 1. fp16, fp16, fp16, No
  if (grad_type == at::ScalarType::Half &&
      weight_type == at::ScalarType::Half && num_tensors == 3) {
    multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                          SGDFunctor<3, at::Half, at::Half>(), wd, momentum,
                          dampening, lr, nesterov, first_run, wd_after_momentum,
                          scale);
  }
  // Case 2. fp16, fp32, fp32, No
  // else if (grad_type == at::ScalarType::Half &&
  //          weight_type == at::ScalarType::Float &&
  //          num_tensors == 3) {
  //   multi_tensor_apply<3>(
  //       BLOCK_SIZE,
  //       chunk_size,
  //       noop_flag,
  //       tensor_lists,
  //       SGDFunctor<3, at::Half, float>(),
  //       wd,
  //       momentum,
  //       dampening,
  //       lr,
  //       nesterov,
  //       first_run,
  //       wd_after_momentum);
  // }
  // Case 2. fp32, fp32, fp32, No
  else if (grad_type == at::ScalarType::Float &&
           weight_type == at::ScalarType::Float && num_tensors == 3) {
    multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                          SGDFunctor<3, float, float>(), wd, momentum,
                          dampening, lr, nesterov, first_run, wd_after_momentum,
                          scale);
  }
  // Case 3. fp16, fp32, fp32, Yes
  else if (grad_type == at::ScalarType::Half &&
           weight_type == at::ScalarType::Float && num_tensors == 4) {
    multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                          SGDFunctor<4, at::Half, float>(), wd, momentum,
                          dampening, lr, nesterov, first_run, wd_after_momentum,
                          scale);
  }
  // Case 4. fp32, fp32, fp32, Yes
  else if (grad_type == at::ScalarType::Float &&
           weight_type == at::ScalarType::Float && num_tensors == 4) {
    multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                          SGDFunctor<4, float, float>(), wd, momentum,
                          dampening, lr, nesterov, first_run, wd_after_momentum,
                          scale);
  } else {
    AT_ERROR(
        "multi_tensor_sgd only supports some combinations of gradient & weight "
        "types. Given: ",
        "gradient: ", grad_type, ", weight: ", weight_type,
        ", num_lists: ", num_tensors);
  }
  AT_CUDA_CHECK(cudaGetLastError());
}
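Per element, SGDFunctor applies optional weight decay either before or after the momentum update, refreshes the momentum buffer with dampening, optionally adds the Nesterov correction, and writes the new weight (plus an fp16 copy in the N == 4 case). A scalar sketch of that update rule (hypothetical helper; the real kernel vectorizes this over ILP elements and tensor-list chunks):

```cpp
// Scalar mirror of the per-element SGD update performed by SGDFunctor.
// Parameter names follow the kernel arguments; illustrative only.
void sgd_step(float &weight, float &mom, float grad, float wd, float momentum,
              float dampening, float lr, bool nesterov, bool first_run,
              bool wd_after_momentum, float scale) {
  float g = grad * scale;
  if (wd != 0.f && !wd_after_momentum) g += wd * weight;  // decay before momentum
  if (momentum != 0.f) {
    if (!first_run)
      mom = mom * momentum + (1.f - dampening) * g;
    else  // first step: initialize the buffer with the incoming gradient
      mom = g;
    g = nesterov ? g + momentum * mom : mom;
  }
  if (wd != 0.f && wd_after_momentum) g += wd * weight;   // decay after momentum
  weight += -lr * g;
}
```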

View File

@@ -10,8 +10,9 @@
#include "kernels.h"
template <typename T>
MultiHeadAttention<T>::MultiHeadAttention(int layer_id, int max_batch_tokens,
                                          int max_seq_len, int hidden_size,
                                          int num_heads,
                                          float attn_prob_dropout_ratio,
                                          float hidden_output_dropout_ratio,
                                          bool pre_or_postLayerNorm)
@@ -22,18 +23,22 @@ MultiHeadAttention<T>::MultiHeadAttention(int layer_id, int max_batch_tokens, in
      _heads(num_heads),
      _training(true),
      _pre_or_postLayerNorm(pre_or_postLayerNorm),
      _qkv_linear(
          typename FeedForward<T>::Config(3 * hidden_size, hidden_size)),
      _attn_out_linear(
          typename FeedForward<T>::Config(hidden_size, hidden_size)),
      _attn_ln(typename Normalize_Layer<T>::Config(hidden_size, false),
               _max_batch_tokens),
      _softmax(typename Softmax<T>::Config(num_heads)),
      _attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio),
                         _max_batch_tokens * _heads * _max_seq_len),
      _attn_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio),
                    _max_batch_tokens * _hidden_size),
      _attn_scores(typename StridedBatchGemm<T>::Config(
          (T(1.0) / T(sqrt(_hidden_size / _heads))), T(0.0), CUBLAS_OP_T,
          CUBLAS_OP_N)),
      _attn_context(typename StridedBatchGemm<T>::Config(
          T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N)) {
  assert(_hidden_size % _heads == 0);
}
@@ -43,43 +48,52 @@ MultiHeadAttention<T>::~MultiHeadAttention() {
}
template <typename T>
void MultiHeadAttention<T>::attn_layer_fw(const T *input_ptr,
                                          const T *input_mask_ptr,
                                          T *output_ptr, T *buffer) {
  T *q_tf_ptr = _qkv_ptr;
  T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size;
  T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size;
  if (_pre_or_postLayerNorm) {
    _attn_ln.Forward(_gemmQKV_inp_ptr, input_ptr, _attn_nw_ptr, _attn_nb_ptr,
                     _batch_tokens, _stream);
  }
  const T *gemmQKV_inp_ptr =
      _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
  _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size);
  _qkv_linear.Forward(_batch_tokens, gemmQKV_inp_ptr, _attn_qkvw_ptr, buffer,
                      _cublasHandle);
  launch_bias_add_transform_20314<T>(q_tf_ptr, buffer, _attn_qkvb_ptr,
                                     _batch_size, _seq_len, 3, _heads / pg_size,
                                     _hidden_size / _heads, _stream);
  // attention scores, q*k
  _attn_scores.Forward(_batch_heads, _soft_out_ptr, k_tf_ptr, q_tf_ptr,
                       _cublasHandle);
  // Softmax + Mask
  _softmax.reset_size(_heads / pg_size);
  _softmax.Forward(_soft_out_ptr, input_mask_ptr, _batch_size, _seq_len,
                   _seq_len, _stream, true);
  // attn prob dropout.
  _attn_prob_dropout.dropout(_ctx_bufB_ptr, _soft_out_ptr,
                             _batch_heads * _seq_len * _seq_len, _stream);
  // attention context, score * v
  _attn_context.Forward(_batch_heads, buffer, v_tf_ptr, _ctx_bufB_ptr,
                        _cublasHandle);
  // [b, nh, s, ad] -> [b, s, nh, ad]
  launch_transform4d_0213<T>(_attn_o_inp_ptr, buffer, _batch_size, _seq_len,
                             _hidden_size / pg_size, _heads / pg_size, 1,
                             _stream);
  _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size);
  _attn_out_linear.Forward(_batch_tokens, _attn_o_inp_ptr, _attn_ow_ptr,
                           output_ptr, _cublasHandle);
  // allreduce
  if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) {
@@ -88,24 +102,27 @@ void MultiHeadAttention<T>::attn_layer_fw(const T *input_ptr, const T *input_mas
    if (typeid(T) != typeid(float)) {
      data_type = torch::kHalf;
    }
    auto output_tensor = torch::from_blob(
        output_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)},
        torch::TensorOptions(torch::kCUDA).dtype(data_type));
    std::vector<torch::Tensor> allreduce_tensors = {output_tensor};
    auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions());
    work->wait();
  }
  _attn_dropout.bias_dropout_residual(output_ptr, output_ptr, input_ptr,
                                      _attn_ob_ptr, _batch_tokens, _hidden_size,
                                      _stream);
  if (!_pre_or_postLayerNorm) {
    // in-place ln since ln-input will not be used in post-ln mode
    _attn_ln.Forward(output_ptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr,
                     _batch_tokens, _stream);
  }
}
template <typename T>
void MultiHeadAttention<T>::Forward(const T *input_ptr, const T *input_mask_ptr,
                                    T *out_ptr) {
  _stream = Context::Instance().get_stream();
  _cublasHandle = Context::Instance().get_cublashandle();
  T *attn_buffer = _shared_mem_ptr;  // 3 * _batch_dim
@@ -114,8 +131,11 @@ void MultiHeadAttention<T>::Forward(const T *input_ptr, const T *input_mask_ptr,
}
template <typename T>
void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr,
                                          const T *input_mask_ptr,
                                          const T *output_ptr,
                                          const T *grad_output_ptr,
                                          T *grad_input_ptr, T *buffer) {
  cudaStream_t streams[2] = {_stream, _stream};
  const T *q_tf_ptr = _qkv_ptr;
@@ -137,45 +157,57 @@ void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr, const T *input_mas
  //                        batch_size * head_num * seq_len * seq_len);
  if (_pre_or_postLayerNorm) {
    _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr,
                                          grad_output_ptr, _batch_tokens,
                                          _hidden_size, _stream);
  } else {
    _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_residual_ptr,
                      grad_output_ptr, nullptr, output_ptr, _attn_nw_ptr,
                      _attn_nb_ptr, _batch_tokens, streams);
    _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr,
                                          grad_residual_ptr, _batch_tokens,
                                          _hidden_size, _stream);
  }
  // bw of output project
  _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size);
  _attn_out_linear.Backward(_batch_tokens, grad_input_ptr, _attn_o_inp_ptr,
                            _attn_ow_ptr, _grad_attn_ow_ptr, _grad_attn_ob_ptr,
                            _cublasHandle, _stream, grad_input_buf_ptr, nullptr,
                            false);
  launch_transform_0213<T>(grad_input_ptr, grad_input_buf_ptr, _batch_size,
                           _seq_len, _hidden_size / pg_size, _heads / pg_size,
                           _stream);
  // bw of score * v
  _attn_context.Backward(
      _batch_heads, grad_input_ptr, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle,
      grad_qkv_5d_ptr + 2 * _batch_dim / pg_size, grad_softmax_ptr);
  _attn_prob_dropout.d_dropout(grad_softmax_ptr,
                               _batch_heads * _seq_len * _seq_len, _stream);
  _softmax.reset_size(_heads / pg_size);
  _softmax.Backward(grad_softmax_ptr, _soft_out_ptr, _batch_size, _seq_len,
                    _seq_len, _stream);
  // bw of q * k
  _attn_scores.Backward(_batch_heads, grad_softmax_ptr, k_tf_ptr, q_tf_ptr,
                        _cublasHandle, grad_qkv_5d_ptr + _batch_dim / pg_size,
                        grad_qkv_5d_ptr);
  // [3, b, nh, s, ad] -> [b, s, 3, h]
  launch_transform4d_0213<T>(grad_qkv_4d_ptr, grad_qkv_5d_ptr, _batch_size,
                             _seq_len, _hidden_size / pg_size, _heads / pg_size,
                             3, _stream);
  const T *gemmQKV_inp_ptr =
      _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
  _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size);
  _qkv_linear.Backward(_batch_tokens, grad_qkv_4d_ptr, gemmQKV_inp_ptr,
                       _attn_qkvw_ptr, _grad_attn_qkvw_ptr, _grad_attn_qkvb_ptr,
                       _cublasHandle, _stream, grad_input_buf_ptr, nullptr,
                       true);
  // allreduce
  if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) {
@@ -185,7 +217,8 @@ void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr, const T *input_mas
      data_type = torch::kHalf;
    }
    auto grad_input_tensor =
        torch::from_blob(grad_input_buf_ptr,
                         {int(_batch_size), int(_seq_len), int(_hidden_size)},
                         torch::TensorOptions(torch::kCUDA).dtype(data_type));
    std::vector<torch::Tensor> allreduce_tensors = {grad_input_tensor};
    auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions());
@@ -193,19 +226,21 @@ void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr, const T *input_mas
  }
  if (_pre_or_postLayerNorm) {
    _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_input_ptr,
                      grad_input_buf_ptr, grad_output_ptr, gemmQKV_inp_ptr,
                      _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, streams);
  } else {
    // FIXME later
    launch_fused_add2<T>(grad_input_ptr, grad_input_buf_ptr, grad_residual_ptr,
                         _batch_size, _seq_len, _hidden_size, _stream);
  }
}
template <typename T>
void MultiHeadAttention<T>::Backward(const T *grad_output_ptr,
                                     const T *input_ptr, const T *output_ptr,
                                     const T *input_mask_ptr,
                                     T *grad_input_ptr) {
  _stream = Context::Instance().get_stream();
  _cublasHandle = Context::Instance().get_cublashandle();
  T *buffer = _shared_mem_ptr;
@@ -215,7 +250,8 @@ void MultiHeadAttention<T>::Backward(const T *grad_output_ptr, const T *input_pt
     4 * _batch_dim + max(3 * _batch_dim,
                          _batch_size * _head_num * _seq_len * _seq_len);
  */
  attn_layer_bw(input_ptr, input_mask_ptr, output_ptr, grad_output_ptr,
                grad_input_ptr, buffer);
}
template <typename T>
@@ -233,7 +269,8 @@ template class MultiHeadAttention<__half>;
// x is torch::Tensor
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)
@@ -241,15 +278,17 @@ template class MultiHeadAttention<__half>;
static std::unordered_map<int, std::shared_ptr<void>> s_multihead_attention;
template <typename T>
int create_multihead_attention(int layer_id, int max_batch_tokens,
                               int max_seq_len, int hidden_dim, int num_heads,
                               float attn_prob_dropout_ratio,
                               float hidden_dropout_ratio,
                               bool pre_or_postLayerNorm,
                               c10::intrusive_ptr<c10d::ProcessGroup> pg_) {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  Context::Instance().set_stream(stream);
  auto layer = std::make_shared<MultiHeadAttention<T>>(
      layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads,
      attn_prob_dropout_ratio, hidden_dropout_ratio, pre_or_postLayerNorm);
  layer->SetPG(pg_);
@@ -261,15 +300,12 @@ int create_multihead_attention(int layer_id, int max_batch_tokens, int max_seq_l
}
template <typename T>
std::vector<torch::Tensor> multihead_attention_fw(
    int layer_id, const torch::Tensor &input, const torch::Tensor &input_mask,
    const torch::Tensor &in_proj_weight, const torch::Tensor &in_proj_bias,
    const torch::Tensor &out_proj_weight, const torch::Tensor &out_proj_bias,
    const torch::Tensor &norm_weight, const torch::Tensor &norm_bias,
    bool training_mode, bool prelayernorm) {
  CHECK_INPUT(input);
  CHECK_INPUT(input_mask);
@@ -280,7 +316,8 @@ std::vector<torch::Tensor> multihead_attention_fw(int layer_id, const torch::Ten
  T *out_ptr = (T *)output.data_ptr();
  std::shared_ptr<MultiHeadAttention<T>> layer =
      std::static_pointer_cast<MultiHeadAttention<T>>(
          s_multihead_attention[layer_id]);
  layer->set_cur_batch_shape(input.size(0), input.size(1));
  layer->SetTrainingMode(training_mode);
@@ -297,17 +334,13 @@ std::vector<torch::Tensor> multihead_attention_fw(int layer_id, const torch::Ten
}
template <typename T>
std::vector<torch::Tensor> multihead_attention_bw(
    int layer_id, const torch::Tensor &grad_dec_output,
    const torch::Tensor &output, const torch::Tensor &input,
    const torch::Tensor &input_mask, const torch::Tensor &in_proj_weight,
    const torch::Tensor &in_proj_bias, const torch::Tensor &out_proj_weight,
    const torch::Tensor &out_proj_bias, const torch::Tensor &norm_weight,
    const torch::Tensor &norm_bias) {
  auto g_output = grad_dec_output.contiguous();
  CHECK_INPUT(g_output);
  CHECK_INPUT(output);
@@ -332,7 +365,8 @@ std::vector<torch::Tensor> multihead_attention_bw(int layer_id,
  T *grad_input_ptr = (T *)grad_input.data_ptr();
  std::shared_ptr<MultiHeadAttention<T>> layer =
      std::static_pointer_cast<MultiHeadAttention<T>>(
          s_multihead_attention[layer_id]);
  layer->set_cur_batch_shape(g_output.size(0), g_output.size(1));
  layer->_grad_attn_qkvw_ptr = (T *)grad_in_proj_weight.data_ptr();
@@ -342,10 +376,12 @@ std::vector<torch::Tensor> multihead_attention_bw(int layer_id,
  layer->_grad_attn_nw_ptr = (T *)grad_norm_weight.data_ptr();
  layer->_grad_attn_nb_ptr = (T *)grad_norm_bias.data_ptr();
  layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr,
                  grad_input_ptr);
  return {grad_input, grad_in_proj_weight, grad_in_proj_bias,
          grad_out_proj_weight, grad_out_proj_bias, grad_norm_weight,
          grad_norm_bias};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
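The reflowed calls in `attn_layer_fw` chain the usual attention steps: optional pre-LayerNorm, fused QKV projection, bias-add transform to [3, b, nh, s, ad], scaled q*k scores, masked softmax, dropout, score*v, transform back, output projection, optional allreduce, and bias-dropout-residual. A hypothetical eager-mode libtorch reference of what that path computes (pre-LN variant, dropout and tensor parallelism omitted; the additive-mask formulation is an assumption, the fused kernels apply the mask internally):

```cpp
#include <torch/torch.h>
#include <cmath>

// Rough reference only; not the fused implementation above.
// input [b, s, h], mask broadcastable to [b, nh, s, s], nh = number of heads.
torch::Tensor attention_reference(const torch::Tensor &input,
                                  const torch::Tensor &mask,
                                  const torch::Tensor &qkv_w,  // [3h, h]
                                  const torch::Tensor &qkv_b,  // [3h]
                                  const torch::Tensor &out_w,  // [h, h]
                                  const torch::Tensor &out_b,  // [h]
                                  const torch::Tensor &ln_w,
                                  const torch::Tensor &ln_b, int64_t nh) {
  auto b = input.size(0), s = input.size(1), h = input.size(2);
  auto x = torch::layer_norm(input, {h}, ln_w, ln_b);  // pre-LN
  auto qkv = torch::linear(x, qkv_w, qkv_b)            // [b, s, 3h]
                 .view({b, s, 3, nh, h / nh})
                 .permute({2, 0, 3, 1, 4});            // [3, b, nh, s, ad]
  auto q = qkv[0], k = qkv[1], v = qkv[2];
  auto scores = torch::matmul(q, k.transpose(-1, -2)) /
                std::sqrt(static_cast<double>(h / nh));  // scaled q*k
  auto probs = torch::softmax(scores + mask, -1);        // softmax + mask
  auto ctx = torch::matmul(probs, v)                     // [b, nh, s, ad]
                 .transpose(1, 2)                        // [b, s, nh, ad]
                 .reshape({b, s, h});
  return input + torch::linear(ctx, out_w, out_b);       // bias + residual
}
```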

View File

@@ -19,21 +19,25 @@
template <typename T>
class MultiHeadAttention {
 public:
  MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len,
                     int hidden_size, int num_heads, float attn_dropout_ratio,
                     float hidden_output_dropout_ratio,
                     bool pre_or_postLayerNorm);
  virtual ~MultiHeadAttention();
  void Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr);
  void Backward(const T *grad_output_ptr, const T *input_ptr,
                const T *output_ptr, const T *input_mask_ptr,
                T *grad_input_ptr);
  void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr,
                     T *buffer);
  void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr,
                     const T *output_ptr, const T *grad_output_ptr,
                     T *grad_input_attn_layer_bwptr, T *buffer);
  void set_cur_batch_shape(int batch_size, int seq_len) {
    _batch_size = batch_size;
@@ -83,14 +87,17 @@ class MultiHeadAttention {
    }
    _qkv_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size * 3);
    _soft_out_ptr =
        cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
    _ctx_bufB_ptr =
        cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
    _attn_o_inp_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size);
    // buffer size needed by attn bw
    size_t smem_size =
        4 * _max_batch_tokens * _hidden_size / pg_size +
        std::max(3 * _max_batch_tokens * _hidden_size / pg_size,
                 _max_batch_tokens * _heads / pg_size * _max_seq_len);
    if (!_shared_mem_ptr) {
      cuda_free(_shared_mem_ptr);
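To make the `smem_size` expression above concrete, here is a worked example with hypothetical shapes (max_batch_tokens = 4096, hidden = 1024, heads = 16, max_seq_len = 256, pg_size = 1); the result is counted in elements of T, not bytes:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  size_t max_batch_tokens = 4096, hidden = 1024, heads = 16, max_seq_len = 256,
         pg_size = 1;
  size_t smem_size =
      4 * max_batch_tokens * hidden / pg_size +
      std::max(3 * max_batch_tokens * hidden / pg_size,
               max_batch_tokens * heads / pg_size * max_seq_len);
  // 16777216 + max(12582912, 16777216) = 33554432 elements (64 MiB at fp16)
  std::printf("%zu elements\n", smem_size);
  return 0;
}
```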

View File

@@ -2,12 +2,13 @@
 * with minor changes. */
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
@@ -15,17 +16,15 @@ namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches,
                             int attn_heads) {
  return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}
torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
                       float scale_factor) {
  // input is a 4d tensor with dimensions [batches, attn_heads, seq_len,
  // seq_len]
  const int batches = input.size(0);
  const int pad_batches = mask.size(0);
  const int attn_heads = input.size(1);
@@ -38,10 +37,10 @@ torch::Tensor fwd_cuda(
  TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
  TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);
  // Output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results = torch::empty(
      {batches, attn_heads, query_seq_len, key_seq_len}, act_options);
  // Softmax Intermediate Result Ptr
  void* input_ptr = static_cast<void*>(input.data_ptr());
@@ -49,31 +48,23 @@ torch::Tensor fwd_cuda(
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
  DISPATCH_HALF_AND_BFLOAT(
      input.scalar_type(), "dispatch_scaled_masked_softmax_forward",
      dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr),
          reinterpret_cast<const uint8_t*>(mask_ptr), scale_factor,
          query_seq_len, key_seq_len, batches, attn_heads, pad_batches););
  return softmax_results;
}
torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
                       torch::Tensor const& softmax_results_,
                       float scale_factor) {
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();
  // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len,
  // seq_len]
  const int batches = output_grads.size(0);
  const int attn_heads = output_grads.size(1);
  const int query_seq_len = output_grads.size(2);
@@ -81,24 +72,18 @@ torch::Tensor bwd_cuda(
  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
  // Softmax Grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(), "dispatch_scaled_masked_softmax_backward",
      dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor, query_seq_len, key_seq_len, batches, attn_heads););
  // backward pass is completely in-place
  return output_grads;
}
}  // namespace scaled_masked_softmax
}  // namespace fused_softmax
}  // namespace multihead_attn
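For reference, the operation these wrappers dispatch is a row-wise softmax of the scaled, masked attention scores; the backward pass can run fully in place on the incoming gradient because it only needs the saved softmax output. Writing the uint8 mask as an equivalent additive mask m (0 for kept positions, a large negative value for masked ones) is an assumption of this formulation:

```latex
y_{ij} = \operatorname{softmax}_j\!\left(s\,x_{ij} + m_{ij}\right),
\qquad
\frac{\partial L}{\partial x_{ij}}
  = s\, y_{ij}\!\left(\frac{\partial L}{\partial y_{ij}}
      - \sum_k \frac{\partial L}{\partial y_{ik}}\, y_{ik}\right),
```

where s is the scale factor passed to `fwd_cuda`/`bwd_cuda` and the sum runs over the key dimension of each row.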

View File

@@ -3,57 +3,52 @@
#include <cuda_fp16.h>
#include <torch/extension.h>

#include <vector>

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor);

torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
                       torch::Tensor const& softmax_results,
                       float scale_factor);

torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
                 (input.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return fwd_cuda(input, scale_factor);
}

torch::Tensor bwd(torch::Tensor const& output_grads,
                  torch::Tensor const& softmax_results, float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");

  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
                 (output_grads.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
                 (softmax_results.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

}  // end namespace scaled_upper_triang_masked_softmax
}  // end namespace fused_softmax
}  // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
}
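
As a usage sketch only (the extension name and the source paths below are placeholders, not taken from this PR), the pybind11 bindings above could be compiled and called via torch.utils.cpp_extension.load:

# Hypothetical build-and-call example for the bindings above; paths are placeholders.
import torch
from torch.utils.cpp_extension import load

ext = load(
    name="scaled_upper_triang_masked_softmax",
    sources=[
        "scaled_upper_triang_masked_softmax.cpp",       # the pybind11 bindings shown above
        "scaled_upper_triang_masked_softmax_cuda.cu",   # the CUDA implementation
    ],
)

# The bindings expect a 3D fp16/bf16 tensor of shape [attn_batches, seq_len, seq_len].
x = torch.randn(8, 128, 128, device="cuda", dtype=torch.float16)
y = ext.forward(x, 1.0)
dy = torch.randn_like(y)
dx = ext.backward(dy, y, 1.0)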


@@ -2,12 +2,13 @@
 * with minor changes. */

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"
@@ -15,18 +16,15 @@ namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) {
  // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
  const int attn_batches = input.size(0);
  const int seq_len = input.size(1);
  TORCH_INTERNAL_ASSERT(seq_len <= 2048);

  // Output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, seq_len, seq_len}, act_options);

  // Softmax Intermediate Result Ptr
@@ -36,50 +34,42 @@ torch::Tensor fwd_cuda(
  DISPATCH_HALF_AND_BFLOAT(
      input.scalar_type(),
      "dispatch_scaled_upper_triang_masked_softmax_forward",
      dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t,
                                                          float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr), scale_factor, seq_len,
          seq_len, attn_batches););

  return softmax_results;
}

torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
                       torch::Tensor const& softmax_results_,
                       float scale_factor) {
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  // output grads is a 3d tensor with dimensions [attn_batches, seq_len,
  // seq_len]
  const int attn_batches = output_grads.size(0);
  const int seq_len = output_grads.size(1);
  TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));

  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  // Softmax Grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(),
      "dispatch_scaled_upper_triang_masked_softmax_backward",
      dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t,
                                                           float>(
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor, seq_len, seq_len, attn_batches););

  // backward pass is completely in-place
  return output_grads;
}
}  // namespace scaled_upper_triang_masked_softmax
}  // namespace fused_softmax
}  // namespace multihead_attn
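
The "upper triangular, time masked" softmax computed here is the usual causal-attention pattern: entries above the diagonal are excluded before the scaled softmax. A hedged PyTorch reference of that semantics (the helper name and the diagonal convention are assumptions, not taken from the kernel itself):

import torch

def upper_triang_masked_softmax_reference(x, scale):
    # x: [attn_batches, seq_len, seq_len]; mask out positions above the diagonal
    # (the usual causal/"time" mask) before applying the scaled softmax.
    seq_len = x.size(-1)
    causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1)
    return torch.softmax((x * scale).masked_fill(causal_mask, float("-inf")), dim=-1)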


@@ -24,8 +24,8 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()
        output, mean, invvar = colossal_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, bias_,
                                                                       ctx.eps)
        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
        return output
@@ -72,8 +72,7 @@ class MixedFusedLayerNorm(torch.nn.Module):

    def forward(self, input):
        return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias, self.normalized_shape, self.eps)

    def __repr__(self):
        return f'MixedFusedLayerNorm(normalized_shape={self.normalized_shape}, eps={self.eps})'
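
For orientation, a rough PyTorch sketch of the quantities the fused forward_affine call produces (this only illustrates the math; the exact shapes of the mean and invvar tensors returned by the CUDA kernel may differ):

import torch

def layer_norm_affine_reference(x, weight, bias, normalized_shape, eps):
    # Normalize over the trailing `normalized_shape` dimensions, then apply the affine
    # transform; mean and the inverse standard deviation are kept for the backward pass.
    dims = tuple(range(-len(normalized_shape), 0))
    mean = x.mean(dim=dims, keepdim=True)
    var = x.var(dim=dims, unbiased=False, keepdim=True)
    invvar = torch.rsqrt(var + eps)
    output = (x - mean) * invvar * weight + bias
    return output, mean, invvar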


@@ -28,9 +28,7 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
            raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')

        scale_t = torch.tensor([scale])
        softmax_results = colossal_scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])

        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results
@@ -43,9 +41,7 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
            raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')

        softmax_results, scale_t = ctx.saved_tensors
        input_grads = colossal_scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])

        return input_grads, None
@@ -81,9 +77,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
        softmax_results, scale_t = ctx.saved_tensors
        input_grads = colossal_scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])

        return input_grads, None, None
@@ -114,9 +108,8 @@ class FusedScaleMaskSoftmax(nn.Module):
        super(FusedScaleMaskSoftmax, self).__init__()
        self.input_in_fp16 = input_in_fp16
        self.input_in_bf16 = input_in_bf16
        assert not (self.input_in_fp16
                    and self.input_in_bf16), "both fp16 and bf16 flags cannot be active at the same time."
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
@@ -124,9 +117,7 @@ class FusedScaleMaskSoftmax(nn.Module):
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled"

    def forward(self, input, mask):
        # [b, np, sq, sk]
@@ -140,14 +131,13 @@ class FusedScaleMaskSoftmax(nn.Module):

    def is_kernel_available(self, mask, b, np, sq, sk):
        attn_batches = b * np

        if (self.scaled_masked_softmax_fusion    # user want to fuse
                and self.input_in_float16    # input must be fp16
                and mask is not None    # mask tensor must not be None
                and 16 < sk <= 2048    # sk must be 16 ~ 2048
                and sq % 4 == 0    # sq must be divisor of 4
                and attn_batches % 4 == 0    # np * b must be divisor of 4
           ):
            if 0 <= sk <= 2048:
                batch_per_block = self.get_batch_per_block(sq, sk, b, np)
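
Restating the outer eligibility gate of is_kernel_available as a standalone predicate can make it easier to see whether a given shape will hit the fused path; a minimal sketch (the helper name is mine, not part of the module):

def can_use_fused_softmax(fusion_enabled, input_in_float16, mask, b, np, sq, sk):
    # Same conditions as the outer check above: fusion requested, fp16/bf16 input,
    # an explicit mask, 16 < sk <= 2048, and both sq and b * np divisible by 4.
    return (fusion_enabled and input_in_float16 and mask is not None
            and 16 < sk <= 2048 and sq % 4 == 0 and (b * np) % 4 == 0)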


@@ -1,6 +1,5 @@
import torch

###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
@@ -9,10 +8,12 @@ import torch
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))


@torch.jit.script
def bias_gelu(bias, y):
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


# gradient of tanh approximation of gelu
# gradient of actual gelu is:
@@ -23,9 +24,11 @@ def bias_gelu_back(g, bias, y):
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    return ff * g


class GeLUFunction(torch.autograd.Function):

    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):
@@ -38,4 +41,5 @@ class GeLUFunction(torch.autograd.Function):
        tmp = bias_gelu_back(grad_output, bias, input)
        return tmp, tmp


bias_gelu_impl = GeLUFunction.apply
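
A quick numeric sanity check of the tanh approximation used above against the erf form quoted in the comments (the magnitude noted below is an expectation, not a guarantee):

import torch

x = torch.linspace(-6, 6, 10001, dtype=torch.float64)
tanh_gelu = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
erf_gelu = x * 0.5 * (1.0 + torch.erf(x * 0.70710678))    # the "actual gelu" from the comments
print((tanh_gelu - erf_gelu).abs().max())    # expected to be small, well under 1e-2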


@@ -182,7 +182,7 @@ class Linear2D(ParallelLayer):

    def forward(self, x: Tensor) -> Tensor:
        # input: [m/q, n/q, k/q]
        # output: [m/q, n/q, h/q]
        out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)

        output = Matmul_AB_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank,
                                    ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank,
@@ -337,16 +337,16 @@ class LayerNorm2D(ParallelLayer):

    def forward(self, x: Tensor) -> Tensor:
        with torch.no_grad():
            E_x = torch.sum(x, dim=-1, keepdim=True)    # [b/q, s, 1]
            torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
            E_x /= self.normalized_shape

            # Var_x in the block below is the sum of input^2
            Var_x = torch.sum(x * x, dim=-1, keepdim=True)    # [b/q, s, 1]
            torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
            Var_x /= self.normalized_shape

            Var_x = Var_x - E_x * E_x    # variance of x [b/q, s, 1]
            # this time 1/sqrt(Var_x + epsilon)
            Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
@@ -569,7 +569,7 @@ class PatchEmbedding2D(ParallelLayer):
        output = F.conv2d(input_, weight, bias, stride=self.patch_size)
        if self.flatten:
            output = output.flatten(2).transpose(1, 2)    # BCHW -> BNC

        cls_token = all_gather_tensor_2d(self.cls_token, -1, ParallelMode.PARALLEL_2D_COL)
        pos_embed = all_gather_tensor_2d(self.pos_embed, -1, ParallelMode.PARALLEL_2D_COL)
@@ -1012,7 +1012,7 @@ class Classifier2D(ParallelLayer):
        destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        out_shape = input_.shape[:-1] + (self.num_classes,)

        return classifier_2d(input_, self.weight, self.bias, self.summa_dim, out_shape, self.row_rank, self.col_rank,
                             ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank,
@@ -1186,7 +1186,7 @@ class VocabParallelClassifier2D(ParallelLayer):

    def forward(self, x: Tensor) -> Tensor:
        # input: [m/q, n/q, k/q]
        # output: [m/q, n/q, h/q]
        out_shape = x.shape[:-1] + (self.output_size_per_partition,)

        output = Matmul_ABT_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank,
                                     ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
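
The LayerNorm2D forward above computes statistics over a last dimension that is sharded across the row process group: each rank sums its slice, all-reduces the partial sums, and then uses Var[x] = E[x^2] - E[x]^2. A single-device sketch of that math (the helper name is mine, and the affine scale/bias step is omitted):

import torch

def row_parallel_layernorm_reference(x_full, normalized_shape, eps):
    # x_full is the unsharded activation; in the parallel layer each rank only holds a
    # slice of the last dimension and the two sums below are combined via all_reduce.
    E_x = x_full.sum(dim=-1, keepdim=True) / normalized_shape
    E_x2 = (x_full * x_full).sum(dim=-1, keepdim=True) / normalized_shape
    Var_x = E_x2 - E_x * E_x                  # Var[x] = E[x^2] - E[x]^2
    inv_std = 1.0 / torch.sqrt(Var_x + eps)
    return (x_full - E_x) * inv_std           # affine scale/bias omitted here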


@@ -189,7 +189,7 @@ class Linear2p5D(ParallelLayer):

    def forward(self, x: Tensor) -> Tensor:
        # input: [m/dq, n/q, k/q]
        # output: [m/dq, n/q, h/q]
        out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)

        output = Matmul_AB_2p5D.apply(
            x,
@@ -254,7 +254,7 @@ class LayerNorm2p5D(ParallelLayer):
        self.tesseract_dim, _ = get_tesseract_dim_dep_from_env()

        # partitioning dimension
        self.partitioned_partition = divide(normalized_shape, self.tesseract_dim)    # *

        # create parameters
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
@@ -357,16 +357,16 @@ class LayerNorm2p5D(ParallelLayer):

    def forward(self, x: Tensor) -> Tensor:
        with torch.no_grad():
            E_x = torch.sum(x, dim=-1, keepdim=True)    # [b/q, s, 1]
            torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
            E_x /= self.normalized_shape

            # Var_x in the block below is the sum of input^2
            Var_x = torch.sum(x * x, dim=-1, keepdim=True)    # [b/q, s, 1]
            torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
            Var_x /= self.normalized_shape

            Var_x = Var_x - E_x * E_x    # variance of x [b/q, s, 1]
            # this time 1/sqrt(Var_x + epsilon)
            Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
@@ -589,7 +589,7 @@ class PatchEmbedding2p5D(ParallelLayer):
        output = F.conv2d(input_, weight, bias, stride=self.patch_size)
        if self.flatten:
            output = output.flatten(2).transpose(1, 2)    # BCHW -> BNC

        cls_token = all_gather_tensor_2p5d(self.cls_token, -1, ParallelMode.PARALLEL_2P5D_COL)
        pos_embed = all_gather_tensor_2p5d(self.pos_embed, -1, ParallelMode.PARALLEL_2P5D_COL)
@@ -1038,7 +1038,7 @@ class Classifier2p5D(ParallelLayer):
        destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        out_shape = input_.shape[:-1] + (self.num_classes,)

        return classifier_2p5d(input_, self.weight, self.bias, self.tesseract_dim, out_shape, self.row_rank,
                               self.col_rank, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL,
@@ -1172,7 +1172,7 @@ class VocabParallelClassifier2p5D(ParallelLayer):

    def forward(self, x: Tensor) -> Tensor:
        # input: [m/dq, n/q, k/q]
        # output: [m/dq, n/q, h/q]
        out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)

        output = Matmul_ABT_2p5D.apply(
            x,


@@ -53,8 +53,8 @@ class LayerNorm3D(ParallelLayer):
        self.weight = Parameter(
            torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
        if bias:
            self.bias = Parameter(
                torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
        else:
            self.bias = None
        self.variance_epsilon = eps
@@ -854,7 +854,7 @@ class PatchEmbedding3D(ParallelLayer):
        input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
        output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
        if self.flatten:
            output = output.flatten(2).transpose(1, 2)    # BCHW -> BNC

        cls_token = self.cls_token.expand(output.shape[0], -1, -1)
        output = torch.cat((cls_token, output), dim=1)
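
The flatten step above is the usual patch-embedding reshape from conv feature maps to a token sequence, followed by prepending the class token. A small shape sketch (the sizes are illustrative only):

import torch

feat = torch.randn(2, 96, 14, 14)               # [B, C, H, W] after the patch-projection conv2d
tokens = feat.flatten(2).transpose(1, 2)        # BCHW -> BNC, i.e. [2, 196, 96]
cls_token = torch.zeros(1, 1, 96).expand(tokens.shape[0], -1, -1)
tokens = torch.cat((cls_token, tokens), dim=1)  # class token prepended -> [2, 197, 96]
print(tokens.shape)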


@@ -13,7 +13,8 @@ from torch import Tensor, nn

class CheckpointModule(nn.Module):

    def __init__(self, checkpoint: bool = True, offload: bool = False):
        super().__init__()
        self.checkpoint = checkpoint
        self._use_checkpoint = checkpoint
@@ -78,6 +79,7 @@ def get_tensor_parallel_mode():
def _ntuple(n):

    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x