mirror of https://github.com/hpcaitech/ColossalAI
parent 0772828fba
commit 58580b50fe
@@ -2,4 +2,3 @@ from .initialize import (initialize, launch, launch_from_openmpi,
                         launch_from_slurm, launch_from_torch, get_default_parser)

__version__ = '0.0.1'
@@ -251,9 +251,9 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo
     partitions = partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
     module_list = []
     for start, end in partitions[pipeline_rank]:
-        module_list.append(
-            nn.Sequential(*[nn.Identity() for _ in range(start)], *layers[start:end],
-                          *[nn.Identity() for _ in range(len(layers) - end)]))
+        module_list.append(nn.Sequential(*[nn.Identity() for _ in range(start)],
+                                         *layers[start:end],
+                                         *[nn.Identity() for _ in range(len(layers) - end)]))
     if verbose:
         logger = get_dist_logger()
         logger.info(f'Total {len(layers)} layers', ranks=[0])
@@ -264,3 +264,4 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo
            log_str += '\n'.join([str(layer) for layer in layers[start:end]]) + '\n'
        logger.info(log_str, ranks=[0])
    return nn.ModuleList(module_list) if len(module_list) > 1 else module_list[0]
@@ -20,14 +20,12 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
*/
#include "cpu_adam.h"

#include <iostream>
#include <math.h>
#include <memory>
#include <omp.h>
#include <string.h>
#include <torch/extension.h>

#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>
@@ -84,7 +82,8 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,

   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
-    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
+    if ((t + TILE) > rounded_size)
+      copy_size = rounded_size - t;
     size_t offset = copy_size + t;

 #pragma omp parallel for
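The hunks above and below only reflow this clamp onto two lines; the tiling logic itself is unchanged. For context, a minimal host-side sketch of the pattern used by Step_1/Step_4/Step_8 (the TILE value and function name are illustrative, not taken from the file):

    // Hedged sketch of the tiled traversal; not part of the patch.
    #include <cstddef>

    constexpr size_t TILE = 1024;  // assumed tile width, for illustration only

    void for_each_tile(size_t rounded_size) {
      for (size_t t = 0; t < rounded_size; t += TILE) {
        size_t copy_size = TILE;
        if ((t + TILE) > rounded_size)
          copy_size = rounded_size - t;   // clamp the final, partial tile
        size_t offset = copy_size + t;    // one past the last index of this tile
        // ... process elements in [t, offset), e.g. under an OpenMP parallel for ...
        (void)offset;
      }
    }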
@@ -146,7 +145,8 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
   if (_param_size > rounded_size) {
     for (size_t t = rounded_size; t < _param_size; t += TILE) {
       size_t copy_size = TILE;
-      if ((t + TILE) > _param_size) copy_size = _param_size - t;
+      if ((t + TILE) > _param_size)
+        copy_size = _param_size - t;
       size_t offset = copy_size + t;

 #pragma omp parallel for
@@ -235,7 +235,8 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,

   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
-    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
+    if ((t + TILE) > rounded_size)
+      copy_size = rounded_size - t;
     size_t offset = copy_size + t;

 #pragma omp parallel for
@@ -320,6 +321,7 @@ int create_adam_optimizer(int optimizer_id, float alpha = 1e-3,
  s_optimizers[optimizer_id] = opt;

  if (should_log) {

    std::string avx_type = "";
#if defined(__AVX512__)
    avx_type = "AVX512";
@@ -384,7 +386,8 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,

   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
-    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
+    if ((t + TILE) > rounded_size)
+      copy_size = rounded_size - t;
     size_t offset = copy_size + t;

 #pragma omp parallel for
@@ -460,29 +463,43 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
                        grad_half_precision, loss_scale);
 }

-int adam_step(int optimizer_id, size_t step, float lr, float beta1, float beta2,
-              float epsilon, float weight_decay, bool bias_correction,
-              torch::Tensor &params, torch::Tensor &grads,
-              torch::Tensor &exp_avg, torch::Tensor &exp_avg_sq,
-              float loss_scale) {
-  auto params_c = params.contiguous();
-  auto grads_c = grads.contiguous();
-  auto exp_avg_c = exp_avg.contiguous();
-  auto exp_avg_sq_c = exp_avg_sq.contiguous();
+int adam_step(int optimizer_id,
+              size_t step,
+              float lr,
+              float beta1,
+              float beta2,
+              float epsilon,
+              float weight_decay,
+              bool bias_correction,
+              torch::Tensor& params,
+              torch::Tensor& grads,
+              torch::Tensor& exp_avg,
+              torch::Tensor& exp_avg_sq,
+              float loss_scale)
+{
+    auto params_c = params.contiguous();
+    auto grads_c = grads.contiguous();
+    auto exp_avg_c = exp_avg.contiguous();
+    auto exp_avg_sq_c = exp_avg_sq.contiguous();

-  float *params_ptr = (float *)params_c.data_ptr();
-  float *grads_ptr = (float *)grads_c.data_ptr();
-  float *exp_avg_ptr = (float *)exp_avg_c.data_ptr();
-  float *exp_avg_sq_ptr = (float *)exp_avg_sq_c.data_ptr();
-  std::shared_ptr<Adam_Optimizer> opt =
-      std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
-  opt->IncrementStep(step, beta1, beta2);
-  opt->update_state(lr, epsilon, weight_decay, bias_correction);
-  opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
-              params_c.numel(), (params.options().dtype() == at::kHalf),
-              (grads.options().dtype() == at::kHalf), loss_scale);
+    float* params_ptr = (float*)params_c.data_ptr();
+    float* grads_ptr = (float*)grads_c.data_ptr();
+    float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
+    float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
+    std::shared_ptr<Adam_Optimizer> opt =
+        std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
+    opt->IncrementStep(step, beta1, beta2);
+    opt->update_state(lr, epsilon, weight_decay, bias_correction);
+    opt->Step_8(params_ptr,
+                grads_ptr,
+                exp_avg_ptr,
+                exp_avg_sq_ptr,
+                params_c.numel(),
+                (params.options().dtype() == at::kHalf),
+                (grads.options().dtype() == at::kHalf),
+                loss_scale);

-  return 0;
+    return 0;
 }

 int destroy_adam_optimizer(int optimizer_id) {
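The reformatted adam_step above keeps the same behaviour: make the four tensors contiguous, fetch the optimizer registered under optimizer_id, advance its step and state, then run Step_8 over the raw float buffers. How such a function is exposed to Python is not part of this hunk; a hypothetical registration sketch in the usual torch-extension style would look like this (module and method names are assumptions, not from the diff):

    // Hypothetical binding sketch only; the real PYBIND11_MODULE block is elsewhere.
    #include <torch/extension.h>

    PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
      m.def("adam_update", &adam_step,
            "CPU Adam/AdamW update step (hypothetical binding name)");
    }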
@@ -48,10 +48,10 @@ SOFTWARE
 #define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
 #define SIMD_SQRT(x) _mm512_sqrt_ps(x)
 #define SIMD_DIV(x, y) _mm512_div_ps(x, y)
-#define SIMD_LOAD_HALF(x) \
+#define SIMD_LOAD_HALF(x) \
   _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
-#define SIMD_STORE_HALF(x, d) \
-  _mm256_store_ps( \
+#define SIMD_STORE_HALF(x, d) \
+  _mm256_store_ps( \
   x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))

 #elif defined(__AVX256__) or defined(__AVX2__)
@@ -66,8 +66,8 @@ SOFTWARE
 #define SIMD_SQRT(x) _mm256_sqrt_ps(x)
 #define SIMD_DIV(x, y) _mm256_div_ps(x, y)
 #define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
-#define SIMD_STORE_HALF(x, d) \
-  _mm_store_ps( \
+#define SIMD_STORE_HALF(x, d) \
+  _mm_store_ps( \
   x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))

 #endif
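On both the AVX512 and AVX2 paths the SIMD_STORE_HALF lines change only in how the continuation backslashes are aligned; the F16C conversion sequence is untouched. A self-contained sketch of that half-precision load/convert/store round trip on the AVX2 path (function name illustrative, not from the header):

    // Hedged sketch: 8 fp16 values -> fp32 vector -> fp16 again, using the same
    // intrinsics that SIMD_LOAD_HALF / SIMD_STORE_HALF wrap on AVX2 + F16C.
    #include <immintrin.h>
    #include <cstdint>

    void fp16_roundtrip(const uint16_t *half_in, uint16_t *half_out) {
      __m256 v = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)half_in));
      // ... vector arithmetic on v would go here ...
      __m128i packed = _mm256_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT);
      _mm_storeu_si128((__m128i *)half_out, packed);
    }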
@@ -83,25 +83,19 @@ union AVX_Data {

 #endif

-#define STEP(SPAN) \
-  void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \
-                   float *_exp_avg_sq, size_t _param_size, \
-                   bool param_half_precision = false, \
+#define STEP(SPAN) \
+  void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \
+                   float *_exp_avg_sq, size_t _param_size, \
+                   bool param_half_precision = false, \
                    bool grad_half_precision = false, float loss_scale = -1);

 class Adam_Optimizer {
- public:
+public:
   Adam_Optimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
                  float eps = 1e-8, float weight_decay = 0,
                  bool adamw_mode = true)
-      : _alpha(alpha),
-        _betta1(betta1),
-        _betta2(betta2),
-        _eps(eps),
-        _weight_decay(weight_decay),
-        _betta1_t(1.0),
-        _betta2_t(1.0),
-        _step(0),
+      : _alpha(alpha), _betta1(betta1), _betta2(betta2), _eps(eps),
+        _weight_decay(weight_decay), _betta1_t(1.0), _betta2_t(1.0), _step(0),
         _adamw_mode(adamw_mode) {}
   ~Adam_Optimizer() {}

@@ -141,7 +135,7 @@ class Adam_Optimizer {
     }
   }

- private:
+private:
   float _alpha;
   float _betta1;
   float _betta2;
@@ -16,7 +16,7 @@ __global__ void ls_cross_entropy_fw_kernel(
   const int left_idx = block_start + threadIdx.x;
   const int right_idx = (blockIdx.x + 1) * vocab_size;
   float max_input[1] = {REDUCE_FLOAT_INF_NEG};
-  float sum_logits[2] = {0.f, 0.f}; // logit and logit exp
+  float sum_logits[2] = {0.f, 0.f};  // logit and logit exp
   int target_tid = targets[blockIdx.x];

   if (target_tid == padding_idx) {
@@ -1,10 +1,10 @@
-#include <cooperative_groups.h>
-
 #include <chrono>
 #include <ctime>

 #include "kernels.h"

+#include <cooperative_groups.h>
+
 namespace cg = cooperative_groups;

 curandStatePhilox4_32_10_t *curandstate;
@@ -165,7 +165,8 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,
   const float scale = 1.f / (1.f - ratio);
   int i = blockIdx.x * blockDim.x + threadIdx.x;

-  if (i * 4 >= total_count) return;
+  if (i * 4 >= total_count)
+    return;

   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
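This early-exit pattern repeats in every vectorized dropout hunk below: each thread owns a 4- or 8-element window, so a thread whose window starts at or past total_count returns before touching memory. A standalone CUDA sketch of the 4-wide case (kernel name is illustrative; it assumes total_count is a multiple of 4 so no scalar tail handling is needed):

    __global__ void scale_vec4_demo(float *data, const int total_count,
                                    const float scale) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i * 4 >= total_count)
        return;                                   // whole 4-wide window is out of range
      float4 *data4 = reinterpret_cast<float4 *>(data);
      float4 v = data4[i];                        // one 16-byte vectorized load
      v.x *= scale; v.y *= scale; v.z *= scale; v.w *= scale;
      data4[i] = v;                               // one vectorized store
    }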
@ -201,7 +202,8 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,
|
|||
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (i * 8 >= total_count) return;
|
||||
if (i * 8 >= total_count)
|
||||
return;
|
||||
|
||||
curandStatePhilox4_32_10_t state;
|
||||
curand_init(seed, i, 0, &state);
|
||||
|
@ -259,7 +261,8 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
|
|||
const float scale = 1.f / (1.f - ratio);
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (i * 4 >= total_count) return;
|
||||
if (i * 4 >= total_count)
|
||||
return;
|
||||
|
||||
uint8_t m[4];
|
||||
|
||||
|
@ -286,7 +289,8 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
|
|||
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (i * 8 >= total_count) return;
|
||||
if (i * 8 >= total_count)
|
||||
return;
|
||||
|
||||
float4 *out4 = reinterpret_cast<float4 *>(out);
|
||||
const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
|
||||
|
@ -376,7 +380,8 @@ __global__ void ls_dropout_res_bias_kernel(
|
|||
const float scale = 1.f / (1.f - ratio);
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (i * 4 >= total_count) return;
|
||||
if (i * 4 >= total_count)
|
||||
return;
|
||||
|
||||
curandStatePhilox4_32_10_t state;
|
||||
curand_init(seed, i, 0, &state);
|
||||
|
@ -419,7 +424,8 @@ __global__ void ls_dropout_res_bias_kernel(
|
|||
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (i * 8 >= total_count) return;
|
||||
if (i * 8 >= total_count)
|
||||
return;
|
||||
|
||||
curandStatePhilox4_32_10_t state;
|
||||
curand_init(seed, i, 0, &state);
|
||||
|
@ -559,9 +565,11 @@ __global__ void ls_dropout_bias_bwd_kernel(
|
|||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i);
|
||||
for (int i = 1; i < 32; i <<= 1)
|
||||
sum += g.shfl_down(sum, i);
|
||||
|
||||
if (y == 0) tile[0][x] = sum;
|
||||
if (y == 0)
|
||||
tile[0][x] = sum;
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < 8) {
|
||||
|
@ -613,9 +621,11 @@ __global__ void ls_dropout_bias_bwd_kernel(
|
|||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
|
||||
for (int i = 1; i < WARP_SIZE; i <<= 1)
|
||||
sum += g.shfl_down(sum, i);
|
||||
|
||||
if (y == 0) tile[0][x] = sum;
|
||||
if (y == 0)
|
||||
tile[0][x] = sum;
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < 8) {
|
||||
|
@ -679,7 +689,8 @@ __global__ void ls_dropout_act_bias_kernel(
|
|||
const float scale = 1.f / (1.f - ratio);
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (i * 4 >= total_count) return;
|
||||
if (i * 4 >= total_count)
|
||||
return;
|
||||
|
||||
curandStatePhilox4_32_10_t state;
|
||||
curand_init(seed, i, 0, &state);
|
||||
|
@ -724,7 +735,8 @@ __global__ void ls_dropout_act_bias_kernel(
|
|||
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (i * 8 >= total_count) return;
|
||||
if (i * 8 >= total_count)
|
||||
return;
|
||||
|
||||
curandStatePhilox4_32_10_t state;
|
||||
curand_init(seed, i, 0, &state);
|
||||
|
@ -885,9 +897,11 @@ __global__ void ls_dropout_act_bias_bwd_kernel(
|
|||
float sum = tile[threadIdx.y][threadIdx.x];
|
||||
__syncthreads();
|
||||
|
||||
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
|
||||
for (int i = 1; i < WARP_SIZE; i <<= 1)
|
||||
sum += g.shfl_down(sum, i);
|
||||
|
||||
if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
|
||||
if (threadIdx.x == 0)
|
||||
tile[0][threadIdx.y] = sum;
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.y == 0) {
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#include <cooperative_groups.h>
|
||||
|
||||
#include "kernels.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
/**
|
||||
|
|
|
@@ -13,23 +13,22 @@ const float REDUCE_FLOAT_INF_NEG = -100000000.f;
 const float REDUCE_FLOAT_INF_POS = 100000000.f;
 const unsigned int WARP_REDUCE_SIZE = 32;

-template <typename T>
-__forceinline__ __device__ T warpReduceSum(T val) {
+template <typename T> __forceinline__ __device__ T warpReduceSum(T val) {
   for (int mask = (WARP_REDUCE_SIZE >> 1); mask > 0; mask >>= 1)
     val += __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, WARP_REDUCE_SIZE);
   return val;
 }

 /* Calculate the sum of all elements in a block */
-template <typename T>
-__forceinline__ __device__ T blockReduceSum(T val) {
+template <typename T> __forceinline__ __device__ T blockReduceSum(T val) {
   static __shared__ T shared[32];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;

   val = warpReduceSum<T>(val);

-  if (lane == 0) shared[wid] = val;
+  if (lane == 0)
+    shared[wid] = val;
   __syncthreads();

   val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)0.0f;
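For readers new to these primitives: warpReduceSum folds a value across the 32 lanes of a warp with XOR shuffles, and blockReduceSum combines the per-warp results through shared memory. A minimal standalone kernel using the same building blocks (names and launch shape are illustrative, not from this header; it assumes blockDim.x is a multiple of 32):

    // Hedged sketch: every thread contributes one element; thread 0 of each block
    // writes the block-wide sum, mirroring the warp/block reduction above.
    __global__ void block_sum_demo(const float *in, float *out, int n) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      float val = (idx < n) ? in[idx] : 0.f;
      for (int mask = 16; mask > 0; mask >>= 1)          // warp-level butterfly
        val += __shfl_xor_sync(0xffffffff, val, mask, 32);
      __shared__ float warp_sums[32];
      int lane = threadIdx.x & 0x1f, wid = threadIdx.x >> 5;
      if (lane == 0)
        warp_sums[wid] = val;                            // one partial sum per warp
      __syncthreads();
      if (threadIdx.x == 0) {
        float total = 0.f;
        for (int w = 0; w < (blockDim.x >> 5); ++w)
          total += warp_sums[w];
        out[blockIdx.x] = total;                         // one partial sum per block
      }
    }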
@ -57,10 +56,10 @@ __inline__ __device__ void warpReduce<ReduceType::kMax, 1>(float *pval) {
|
|||
template <>
|
||||
__inline__ __device__ void warpReduce<ReduceType::kMax, 2>(float *pval) {
|
||||
float val0_tmp, val1_tmp;
|
||||
#define WarpReduceMaxOneStep(a, b) \
|
||||
val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval), a, b); \
|
||||
val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
|
||||
*(pval) = max(val0_tmp, *(pval)); \
|
||||
#define WarpReduceMaxOneStep(a, b) \
|
||||
val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval), a, b); \
|
||||
val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
|
||||
*(pval) = max(val0_tmp, *(pval)); \
|
||||
*(pval + 1) = max(val1_tmp, *(pval + 1));
|
||||
|
||||
WarpReduceMaxOneStep(16, 32);
|
||||
|
@ -89,10 +88,10 @@ __inline__ __device__ void warpReduce<ReduceType::kSum, 1>(float *pval) {
|
|||
template <>
|
||||
__inline__ __device__ void warpReduce<ReduceType::kSum, 2>(float *pval) {
|
||||
float val0_tmp, val1_tmp;
|
||||
#define WarpReduceSumOneStep(a, b) \
|
||||
val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \
|
||||
val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
|
||||
*(pval + 0) += val0_tmp; \
|
||||
#define WarpReduceSumOneStep(a, b) \
|
||||
val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \
|
||||
val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
|
||||
*(pval + 0) += val0_tmp; \
|
||||
*(pval + 1) += val1_tmp
|
||||
|
||||
WarpReduceSumOneStep(16, 32);
|
||||
|
@ -107,14 +106,14 @@ __inline__ __device__ void warpReduce<ReduceType::kSum, 2>(float *pval) {
|
|||
template <>
|
||||
__inline__ __device__ void warpReduce<ReduceType::kSum, 4>(float *pval) {
|
||||
float val0_tmp, val1_tmp, val2_tmp, val3_tmp;
|
||||
#define WarpReduceSumOneStep(a, b) \
|
||||
val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \
|
||||
val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
|
||||
val2_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 2), a, b); \
|
||||
val3_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 3), a, b); \
|
||||
*(pval + 0) += val0_tmp; \
|
||||
*(pval + 1) += val1_tmp; \
|
||||
*(pval + 2) += val2_tmp; \
|
||||
#define WarpReduceSumOneStep(a, b) \
|
||||
val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \
|
||||
val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
|
||||
val2_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 2), a, b); \
|
||||
val3_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 3), a, b); \
|
||||
*(pval + 0) += val0_tmp; \
|
||||
*(pval + 1) += val1_tmp; \
|
||||
*(pval + 2) += val2_tmp; \
|
||||
*(pval + 3) += val3_tmp
|
||||
|
||||
WarpReduceSumOneStep(16, 32);
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
#include "cuda_util.h"
|
||||
|
||||
class Context {
|
||||
public:
|
||||
public:
|
||||
Context() : _stream(nullptr) {
|
||||
CHECK_GPU_ERROR(cublasCreate(&_cublasHandle));
|
||||
}
|
||||
|
@ -30,7 +30,7 @@ class Context {
|
|||
|
||||
cublasHandle_t get_cublashandle() { return _cublasHandle; }
|
||||
|
||||
private:
|
||||
private:
|
||||
cudaStream_t _stream;
|
||||
cublasHandle_t _cublasHandle;
|
||||
};
|
||||
|
|
|
@ -8,9 +8,8 @@
|
|||
|
||||
#include "cuda_util.h"
|
||||
|
||||
template <typename T>
|
||||
class CrossEntropyLayer {
|
||||
public:
|
||||
template <typename T> class CrossEntropyLayer {
|
||||
public:
|
||||
CrossEntropyLayer(float epsilon, int padding_idx, int max_batch_tokens);
|
||||
|
||||
virtual ~CrossEntropyLayer();
|
||||
|
@ -23,7 +22,7 @@ class CrossEntropyLayer {
|
|||
|
||||
void set_cur_batch_shape(int batch_size, int seq_len, int vocab_size);
|
||||
|
||||
private:
|
||||
private:
|
||||
void allocate_mem_buffer() {
|
||||
// allocate local gpu memory
|
||||
_loss_buffer = cuda_malloc<float>(_max_batch_tokens * 2);
|
||||
|
|
|
@ -20,8 +20,7 @@ void check_gpu_error(T result, char const *const func, const char *const file,
|
|||
template <typename T>
|
||||
void print_vec(const T *outv, std::string outn, int num_output_ele);
|
||||
|
||||
template <typename T>
|
||||
T *cuda_malloc(size_t ele_num);
|
||||
template <typename T> T *cuda_malloc(size_t ele_num);
|
||||
|
||||
void cuda_free(void *pdata);
|
||||
|
||||
|
@ -29,6 +28,6 @@ template <typename T>
|
|||
void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf,
|
||||
std::string file, int line, cudaStream_t stream);
|
||||
|
||||
#define CHECK_NAN_INF(ptr, size, stream) \
|
||||
check_nan_inf((ptr), (size), true, __FILE__, __LINE__, (stream)); \
|
||||
#define CHECK_NAN_INF(ptr, size, stream) \
|
||||
check_nan_inf((ptr), (size), true, __FILE__, __LINE__, (stream)); \
|
||||
check_nan_inf((ptr), (size), false, __FILE__, __LINE__, (stream))
|
||||
|
|
|
@ -3,14 +3,12 @@
|
|||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "kernels.h"
|
||||
|
||||
template <typename T>
|
||||
class Dropout {
|
||||
public:
|
||||
template <typename T> class Dropout {
|
||||
public:
|
||||
struct Config {
|
||||
float ratio;
|
||||
bool training;
|
||||
|
@ -90,7 +88,7 @@ class Dropout {
|
|||
|
||||
void SetTrainingMode(bool training) { _config.training = training; }
|
||||
|
||||
private:
|
||||
private:
|
||||
uint8_t *_mask;
|
||||
Config _config;
|
||||
};
|
||||
|
|
|
@ -13,16 +13,14 @@
|
|||
#include "cublas_wrappers.h"
|
||||
#include "kernels.h"
|
||||
|
||||
template <typename T>
|
||||
class FeedForward {
|
||||
public:
|
||||
template <typename T> class FeedForward {
|
||||
public:
|
||||
struct Config {
|
||||
int outputSize;
|
||||
int inputSize;
|
||||
std::array<int, 3> gemm_algos;
|
||||
Config(int outputs, int inputs)
|
||||
: outputSize(outputs),
|
||||
inputSize(inputs),
|
||||
: outputSize(outputs), inputSize(inputs),
|
||||
gemm_algos(std::array<int, 3>({99, 99, 99})) {}
|
||||
};
|
||||
|
||||
|
@ -63,6 +61,6 @@ class FeedForward {
|
|||
config_.inputSize = inputSize;
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
Config config_;
|
||||
};
|
||||
|
|
|
@ -10,9 +10,8 @@
|
|||
|
||||
using namespace std;
|
||||
|
||||
template <typename T>
|
||||
class Softmax {
|
||||
public:
|
||||
template <typename T> class Softmax {
|
||||
public:
|
||||
struct Config {
|
||||
size_t nhead;
|
||||
Config(size_t nhead) : nhead(nhead) {}
|
||||
|
@ -37,6 +36,6 @@ class Softmax {
|
|||
|
||||
void reset_size(size_t nhead) { config_.nhead = nhead; }
|
||||
|
||||
private:
|
||||
private:
|
||||
Config config_;
|
||||
};
|
||||
|
|
|
@@ -1,5 +1,6 @@
#include "block_reduce.h"
#include "kernels.h"

#include <cooperative_groups.h>

namespace cg = cooperative_groups;
@ -1,4 +1,3 @@
|
|||
#include <cooperative_groups.h>
|
||||
#include <math.h>
|
||||
|
||||
#include <cub/block/block_load.cuh>
|
||||
|
@ -7,6 +6,8 @@
|
|||
#include "block_reduce.h"
|
||||
#include "kernels.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
const float EPSILON = 1e-8f;
|
||||
|
||||
|
@ -119,7 +120,7 @@ __global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len,
|
|||
BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
|
||||
to_len);
|
||||
}
|
||||
} // blockIdx.x
|
||||
} // blockIdx.x
|
||||
}
|
||||
|
||||
template <typename T, int block_dim, int ele_per_thread>
|
||||
|
@ -197,7 +198,7 @@ __global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len,
|
|||
BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
|
||||
to_len);
|
||||
}
|
||||
} // blockIdx.x
|
||||
} // blockIdx.x
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -303,7 +304,8 @@ __global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) {
|
|||
cg::thread_block b = cg::this_thread_block();
|
||||
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
|
||||
|
||||
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);
|
||||
for (int i = 1; i < WARP_SIZE; i <<= 1)
|
||||
sum += g.shfl_xor(sum, i);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ITERATIONS; ++i) {
|
||||
|
|
|
@ -2,12 +2,10 @@
|
|||
* https://github.com/NVIDIA/apex
|
||||
* with minor changes. */
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
|
||||
#include "compat.h"
|
||||
#include <cassert>
|
||||
#include <torch/extension.h>
|
||||
#include <vector>
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -67,7 +65,7 @@ void check_args(at::Tensor input, at::IntArrayRef normalized_shape,
|
|||
check_args(input, normalized_shape, n1, n2);
|
||||
check_args(normalized_shape, gamma, beta);
|
||||
}
|
||||
} // namespace
|
||||
} // namespace
|
||||
|
||||
void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar,
|
||||
at::Tensor *input, int n1, int n2,
|
||||
|
@ -75,16 +73,17 @@ void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar,
|
|||
at::Tensor *beta, double epsilon);
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
||||
std::vector<at::Tensor> layer_norm_affine(at::Tensor input,
|
||||
at::IntArrayRef normalized_shape,
|
||||
at::Tensor gamma, at::Tensor beta,
|
||||
double epsilon) {
|
||||
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(gamma);
|
||||
CHECK_INPUT(beta);
|
||||
|
@ -110,10 +109,11 @@ void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean,
|
|||
double epsilon, at::Tensor *grad_input,
|
||||
at::Tensor *grad_gamma, at::Tensor *grad_beta);
|
||||
|
||||
std::vector<at::Tensor> layer_norm_gradient_affine(
|
||||
at::Tensor dout, at::Tensor mean, at::Tensor invvar, at::Tensor input,
|
||||
at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta,
|
||||
double epsilon) {
|
||||
std::vector<at::Tensor>
|
||||
layer_norm_gradient_affine(at::Tensor dout, at::Tensor mean, at::Tensor invvar,
|
||||
at::Tensor input, at::IntArrayRef normalized_shape,
|
||||
at::Tensor gamma, at::Tensor beta, double epsilon) {
|
||||
|
||||
CHECK_INPUT(dout);
|
||||
CHECK_INPUT(mean);
|
||||
CHECK_INPUT(invvar);
|
||||
|
|
|
@ -15,24 +15,25 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
|
|||
torch::Tensor logits, torch::Tensor mask,
|
||||
torch::Tensor dest_idx);
|
||||
|
||||
std::vector<torch::Tensor> moe_combine_cuda_backward(
|
||||
int s, int e, int c, int h, torch::Tensor tokens_grad,
|
||||
torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
|
||||
torch::Tensor dest_idx);
|
||||
std::vector<torch::Tensor>
|
||||
moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
|
||||
torch::Tensor expert_tokens, torch::Tensor logits,
|
||||
torch::Tensor mask, torch::Tensor dest_idx);
|
||||
|
||||
torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask);
|
||||
|
||||
#define CHECK_CUDA(x) \
|
||||
#define CHECK_CUDA(x) \
|
||||
TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
||||
torch::Tensor moe_dispatch_forward(int s, int ec, int h,
|
||||
torch::Tensor batch_tokens,
|
||||
torch::Tensor mask, torch::Tensor dest_idx) {
|
||||
|
||||
CHECK_INPUT(batch_tokens);
|
||||
CHECK_CUDA(mask);
|
||||
CHECK_CUDA(dest_idx);
|
||||
|
@ -44,6 +45,7 @@ torch::Tensor moe_dispatch_backward(int s, int ec, int h,
|
|||
torch::Tensor expert_grad,
|
||||
torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
|
||||
CHECK_INPUT(expert_grad);
|
||||
CHECK_CUDA(mask);
|
||||
CHECK_CUDA(dest_idx);
|
||||
|
@ -55,6 +57,7 @@ torch::Tensor moe_combine_forward(int s, int e, int c, int h,
|
|||
torch::Tensor expert_tokens,
|
||||
torch::Tensor logits, torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
|
||||
CHECK_INPUT(expert_tokens);
|
||||
CHECK_INPUT(logits);
|
||||
CHECK_CUDA(mask);
|
||||
|
@ -64,12 +67,11 @@ torch::Tensor moe_combine_forward(int s, int e, int c, int h,
|
|||
dest_idx);
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> moe_combine_backward(int s, int e, int c, int h,
|
||||
torch::Tensor tokens_grad,
|
||||
torch::Tensor expert_tokens,
|
||||
torch::Tensor logits,
|
||||
torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
std::vector<torch::Tensor>
|
||||
moe_combine_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
|
||||
torch::Tensor expert_tokens, torch::Tensor logits,
|
||||
torch::Tensor mask, torch::Tensor dest_idx) {
|
||||
|
||||
CHECK_INPUT(tokens_grad);
|
||||
CHECK_INPUT(logits);
|
||||
CHECK_CUDA(mask);
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
#include "block_reduce.h"
|
||||
#include <cub/cub.cuh>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <cub/cub.cuh>
|
||||
|
||||
#include "block_reduce.h"
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
|
||||
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
|
@ -29,6 +28,7 @@ __device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
|
|||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
|
||||
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
|
@ -51,6 +51,7 @@ __device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
|
|||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
|
||||
const int cols) {
|
||||
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
|
@ -74,6 +75,7 @@ __device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
|
|||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
|
||||
const int cols) {
|
||||
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
|
@ -103,6 +105,7 @@ __device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
|
|||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
|
||||
const int cols) {
|
||||
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
|
@ -131,6 +134,7 @@ __device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
|
|||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
|
||||
T *weight_grad, const T weight, const int cols) {
|
||||
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
|
@ -160,13 +164,15 @@ __device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
|
|||
|
||||
blockReduce<ReduceType::kSum, 1>(&thread_sum);
|
||||
|
||||
if (threadIdx.x == 0) *weight_grad = static_cast<T>(thread_sum);
|
||||
if (threadIdx.x == 0)
|
||||
*weight_grad = static_cast<T>(thread_sum);
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row,
|
||||
const T weight1, const T weight2,
|
||||
const int cols) {
|
||||
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
|
@ -198,6 +204,7 @@ __device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row,
|
|||
T *tks_row1, T *tks_row2, T *weight_grad1,
|
||||
T *weight_grad2, const T weight1,
|
||||
const T weight2, const int cols) {
|
||||
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
|
@ -244,6 +251,7 @@ template <typename T, int block_size, int pack_size>
|
|||
__device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2,
|
||||
const int cols, const int indicator1,
|
||||
const int indicator2) {
|
||||
|
||||
if (indicator1 != 0 && indicator2 != 0)
|
||||
moe_dpch_two_fwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
|
||||
cols);
|
||||
|
@ -259,6 +267,7 @@ template <typename T, int block_size, int pack_size>
|
|||
__device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2,
|
||||
const int cols, const int indicator1,
|
||||
const int indicator2) {
|
||||
|
||||
if (indicator1 != 0 && indicator2 != 0)
|
||||
moe_dpch_two_bwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
|
||||
cols);
|
||||
|
@ -274,6 +283,7 @@ template <typename T, int block_size, int pack_size>
|
|||
__global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input,
|
||||
int *mask1, int *mask2, int *dest1,
|
||||
int *dest2, const int h) {
|
||||
|
||||
int row = blockIdx.x;
|
||||
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
|
||||
moe_dpch_fwd_selector<T, block_size, pack_size>(
|
||||
|
@ -285,6 +295,7 @@ template <typename T, int block_size, int pack_size>
|
|||
__global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1,
|
||||
int *mask2, int *dest1, int *dest2,
|
||||
const int h) {
|
||||
|
||||
int row = blockIdx.x;
|
||||
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
|
||||
moe_dpch_bwd_selector<T, block_size, pack_size>(
|
||||
|
@ -299,6 +310,7 @@ __device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row,
|
|||
const int cols, const T weight1,
|
||||
const T weight2, const int indicator1,
|
||||
const int indicator2) {
|
||||
|
||||
if (indicator1 != 0 && indicator2 != 0)
|
||||
moe_cb_two_fwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
|
||||
weight1, weight2, cols);
|
||||
|
@ -316,6 +328,7 @@ __device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row,
|
|||
T *wt_grad1, T *wt_grad2, const T weight1,
|
||||
const T weight2, const int indicator1,
|
||||
const int indicator2) {
|
||||
|
||||
if (indicator1 != 0 && indicator2 != 0)
|
||||
moe_cb_two_bwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
|
||||
tks_row1, tks_row2, wt_grad1,
|
||||
|
@ -335,6 +348,7 @@ __global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens,
|
|||
T *logits, int *mask1, int *mask2, int *dest1,
|
||||
int *dest2, const int e, const int c,
|
||||
const int h) {
|
||||
|
||||
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
|
||||
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
|
||||
T *row_log = logits + (row * e);
|
||||
|
@ -349,6 +363,7 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
|
|||
T *logits, T *logits_grad, int *mask1,
|
||||
int *mask2, int *dest1, int *dest2,
|
||||
const int e, const int c, const int h) {
|
||||
|
||||
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
|
||||
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
|
||||
T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e);
|
||||
|
@ -364,6 +379,7 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
|
|||
template <int block_size, int pack_size>
|
||||
__global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
|
||||
const int e) {
|
||||
|
||||
assert(s % pack_size == 0);
|
||||
constexpr int bpack_size = block_size * pack_size;
|
||||
int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1;
|
||||
|
@ -410,7 +426,8 @@ __global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
|
|||
}
|
||||
__syncthreads();
|
||||
|
||||
if (tid == 0) temp[0] = temp[block_size];
|
||||
if (tid == 0)
|
||||
temp[0] = temp[block_size];
|
||||
__syncthreads();
|
||||
|
||||
if (idx + tps < s) {
|
||||
|
@ -436,6 +453,7 @@ template <typename T>
|
|||
void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1,
|
||||
int *mask2, int *dest1, int *dest2, const int s,
|
||||
const int h) {
|
||||
|
||||
if (h < 256)
|
||||
moe_dpch_fwd_kernel<T, 32, 4>
|
||||
<<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
|
@ -456,6 +474,7 @@ void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1,
|
|||
template <typename T>
|
||||
void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2,
|
||||
int *dest1, int *dest2, const int s, const int h) {
|
||||
|
||||
if (h < 256)
|
||||
moe_dpch_bwd_kernel<T, 32, 4>
|
||||
<<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
|
@ -477,6 +496,7 @@ template <typename T>
|
|||
void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits,
|
||||
int *mask1, int *mask2, int *dest1, int *dest2,
|
||||
const int s, const int e, const int c, const int h) {
|
||||
|
||||
if (h < 256)
|
||||
moe_cb_fwd_kernel<T, 32, 4><<<s, 32>>>(expert_tokens, combine_tokens,
|
||||
logits, mask1, mask2, dest1, dest2,
|
||||
|
@ -504,11 +524,12 @@ void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
|
|||
T *logits_grad, int *mask1, int *mask2, int *dest1,
|
||||
int *dest2, const int s, const int e, const int c,
|
||||
const int h) {
|
||||
|
||||
if (h < 256)
|
||||
moe_cb_bwd_kernel<T, 32, 4><<<s, 32>>>(tokens_grad, expert_grad, tks,
|
||||
logits, logits_grad, mask1, mask2,
|
||||
dest1, dest2, e, c, h);
|
||||
else // if (h < 512)
|
||||
else // if (h < 512)
|
||||
moe_cb_bwd_kernel<T, 64, 4><<<s, 64>>>(tokens_grad, expert_grad, tks,
|
||||
logits, logits_grad, mask1, mask2,
|
||||
dest1, dest2, e, c, h);
|
||||
|
@ -523,6 +544,7 @@ void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
|
|||
}
|
||||
|
||||
void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {
|
||||
|
||||
if (s <= 256)
|
||||
cumsum_kernel<256, 1><<<e, 256>>>(inputs, outputs, s, e);
|
||||
else if (s <= 512)
|
||||
|
@ -537,26 +559,27 @@ void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {
|
|||
|
||||
// API FUNCTIONS --------------------------------
|
||||
|
||||
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \
|
||||
switch (TYPE) { \
|
||||
case at::ScalarType::Float: { \
|
||||
using scalar_t = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: { \
|
||||
using scalar_t = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented yet for specific data type."); \
|
||||
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \
|
||||
switch (TYPE) { \
|
||||
case at::ScalarType::Float: { \
|
||||
using scalar_t = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: { \
|
||||
using scalar_t = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented yet for specific data type."); \
|
||||
}
|
||||
|
||||
torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
|
||||
torch::Tensor batch_tokens,
|
||||
torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
|
||||
assert(h % 16 == 0);
|
||||
auto res = torch::zeros(
|
||||
{ec, h},
|
||||
|
@ -578,6 +601,7 @@ torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
|
|||
torch::Tensor expert_grad,
|
||||
torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
|
||||
assert(h % 16 == 0);
|
||||
auto res = torch::zeros(
|
||||
{s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device()));
|
||||
|
@ -598,6 +622,7 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
|
|||
torch::Tensor expert_tokens,
|
||||
torch::Tensor logits, torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
|
||||
assert(h % 16 == 0);
|
||||
assert(expert_tokens.dtype() == logits.dtype());
|
||||
|
||||
|
@ -618,10 +643,11 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
|
|||
return res;
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> moe_combine_cuda_backward(
|
||||
int s, int e, int c, int h, torch::Tensor tokens_grad,
|
||||
torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
std::vector<torch::Tensor>
|
||||
moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
|
||||
torch::Tensor expert_tokens, torch::Tensor logits,
|
||||
torch::Tensor mask, torch::Tensor dest_idx) {
|
||||
|
||||
assert(h % 16 == 0);
|
||||
assert(tokens_grad.dtype() == expert_tokens.dtype());
|
||||
assert(expert_tokens.dtype() == logits.dtype());
|
||||
|
@ -647,6 +673,7 @@ std::vector<torch::Tensor> moe_combine_cuda_backward(
|
|||
}
|
||||
|
||||
torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) {
|
||||
|
||||
assert(mask.dim() == 2);
|
||||
assert(mask.dtype() == torch::kInt32);
|
||||
|
||||
|
|
|
@ -16,8 +16,7 @@
|
|||
#define BLOCK_SIZE 512
|
||||
#define ILP 4
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ bool is_aligned(T *p) {
|
||||
template <typename T> __device__ __forceinline__ bool is_aligned(T *p) {
|
||||
return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
|
||||
}
|
||||
|
||||
|
@ -29,12 +28,11 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
|
|||
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
|
||||
}
|
||||
|
||||
template <typename x_t>
|
||||
struct L2NormFunctor {
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
|
||||
float *output, float *output_per_tensor, bool per_tensor,
|
||||
int max_chunks_per_tensor) {
|
||||
template <typename x_t> struct L2NormFunctor {
|
||||
__device__ __forceinline__ void
|
||||
operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
|
||||
float *output, float *output_per_tensor, bool per_tensor,
|
||||
int max_chunks_per_tensor) {
|
||||
// I'd like this kernel to propagate infs/nans.
|
||||
// if(*noop_gmem == 1)
|
||||
// return;
|
||||
|
@ -50,8 +48,8 @@ struct L2NormFunctor {
|
|||
|
||||
__shared__ float s_vals[512];
|
||||
|
||||
float vals[ILP]; // = {0}; // this probably works too but I want to be
|
||||
// sure...
|
||||
float
|
||||
vals[ILP]; // = {0}; // this probably works too but I want to be sure...
|
||||
x_t r_x[ILP];
|
||||
for (int i = 0; i < ILP; i++) {
|
||||
vals[i] = 0.f;
|
||||
|
@ -86,14 +84,15 @@ struct L2NormFunctor {
|
|||
}
|
||||
|
||||
float val = 0.f;
|
||||
for (int i = 0; i < ILP; i++) val += vals[i];
|
||||
for (int i = 0; i < ILP; i++)
|
||||
val += vals[i];
|
||||
|
||||
float final = reduce_block_into_lanes(s_vals, val);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
if (!isfinite(final))
|
||||
*noop_gmem =
|
||||
1; // Blindly fire off a write. These will race but that's ok.
|
||||
1; // Blindly fire off a write. These will race but that's ok.
|
||||
output[blockIdx.x] += final;
|
||||
if (per_tensor)
|
||||
output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) *
|
||||
|
@ -105,12 +104,11 @@ struct L2NormFunctor {
|
|||
|
||||
// Probably better to template, but since we are not likely to support other
|
||||
// norm
|
||||
template <typename x_t>
|
||||
struct MaxNormFunctor {
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
|
||||
float *output, float *output_per_tensor, bool per_tensor,
|
||||
int max_chunks_per_tensor) {
|
||||
template <typename x_t> struct MaxNormFunctor {
|
||||
__device__ __forceinline__ void
|
||||
operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
|
||||
float *output, float *output_per_tensor, bool per_tensor,
|
||||
int max_chunks_per_tensor) {
|
||||
// I'd like this kernel to propagate infs/nans.
|
||||
// if(*noop_gmem == 1)
|
||||
// return;
|
||||
|
@ -126,8 +124,8 @@ struct MaxNormFunctor {
|
|||
|
||||
__shared__ float s_vals[512];
|
||||
|
||||
float vals[ILP]; // = {0}; // this probably works too but I want to be
|
||||
// sure...
|
||||
float
|
||||
vals[ILP]; // = {0}; // this probably works too but I want to be sure...
|
||||
x_t r_x[ILP];
|
||||
for (int i = 0; i < ILP; i++) {
|
||||
vals[i] = 0.f;
|
||||
|
@ -162,14 +160,15 @@ struct MaxNormFunctor {
|
|||
}
|
||||
|
||||
float val = 0.f;
|
||||
for (int i = 0; i < ILP; i++) val = fmaxf(fabsf(val), fabsf(vals[i]));
|
||||
for (int i = 0; i < ILP; i++)
|
||||
val = fmaxf(fabsf(val), fabsf(vals[i]));
|
||||
|
||||
float final = reduce_block_into_lanes_max_op(s_vals, val);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
if (!isfinite(final))
|
||||
*noop_gmem =
|
||||
1; // Blindly fire off a write. These will race but that's ok.
|
||||
1; // Blindly fire off a write. These will race but that's ok.
|
||||
output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final));
|
||||
if (per_tensor)
|
||||
output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) *
|
||||
|
@ -186,11 +185,13 @@ __global__ void cleanup(float *output, float *output_per_tensor, float *ret,
|
|||
|
||||
if (blockIdx.x == 0) {
|
||||
float val = 0;
|
||||
if (threadIdx.x < 320) val = output[threadIdx.x];
|
||||
if (threadIdx.x < 320)
|
||||
val = output[threadIdx.x];
|
||||
|
||||
float final = reduce_block_into_lanes(vals, val);
|
||||
|
||||
if (threadIdx.x == 0) *ret = sqrt(final);
|
||||
if (threadIdx.x == 0)
|
||||
*ret = sqrt(final);
|
||||
}
|
||||
|
||||
if (per_tensor) {
|
||||
|
@ -203,7 +204,8 @@ __global__ void cleanup(float *output, float *output_per_tensor, float *ret,
|
|||
|
||||
float final = reduce_block_into_lanes(vals, val);
|
||||
|
||||
if (threadIdx.x == 0) ret_per_tensor[blockIdx.x] = sqrt(final);
|
||||
if (threadIdx.x == 0)
|
||||
ret_per_tensor[blockIdx.x] = sqrt(final);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -215,14 +217,17 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
|
|||
|
||||
if (blockIdx.x == 0) {
|
||||
float val = 0;
|
||||
if (threadIdx.x < 320) val = output[threadIdx.x];
|
||||
if (threadIdx.x < 320)
|
||||
val = output[threadIdx.x];
|
||||
|
||||
if (norm_type == 0) {
|
||||
float final = reduce_block_into_lanes_max_op(vals, val);
|
||||
if (threadIdx.x == 0) *ret = alpha * (*ret) + beta * final;
|
||||
if (threadIdx.x == 0)
|
||||
*ret = alpha * (*ret) + beta * final;
|
||||
} else {
|
||||
float final = reduce_block_into_lanes(vals, val);
|
||||
if (threadIdx.x == 0) *ret = sqrt(alpha * (*ret) * (*ret) + beta * final);
|
||||
if (threadIdx.x == 0)
|
||||
*ret = sqrt(alpha * (*ret) * (*ret) + beta * final);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -255,10 +260,10 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
|
|||
}
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
|
||||
int chunk_size, at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
at::optional<bool> per_tensor_python) {
|
||||
std::tuple<at::Tensor, at::Tensor>
|
||||
multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
at::optional<bool> per_tensor_python) {
|
||||
bool per_tensor =
|
||||
per_tensor_python.has_value() ? per_tensor_python.value() : false;
|
||||
|
||||
|
|
|
@ -15,8 +15,7 @@
|
|||
#define BLOCK_SIZE 512
|
||||
#define ILP 4
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ bool is_aligned(T *p) {
|
||||
template <typename T> __device__ __forceinline__ bool is_aligned(T *p) {
|
||||
return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
|
||||
}
|
||||
|
||||
|
@ -29,25 +28,24 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
|
|||
}
|
||||
|
||||
typedef enum {
|
||||
MOMENT_MODE_0 = 0, // L2 regularization mode
|
||||
MOMENT_MODE_1 = 1 // Decoupled weight decay mode
|
||||
MOMENT_MODE_0 = 0, // L2 regularization mode
|
||||
MOMENT_MODE_1 = 1 // Decoupled weight decay mode
|
||||
} adamMode_t;
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
|
||||
int chunk_size, at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
at::optional<bool> per_tensor_python);
|
||||
std::tuple<at::Tensor, at::Tensor>
|
||||
multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
at::optional<bool> per_tensor_python);
|
||||
|
||||
using MATH_T = float;
|
||||
|
||||
template <typename T>
|
||||
struct LAMBStage1Functor {
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
|
||||
const float beta1, const float beta2, const float beta3,
|
||||
const float beta1_correction, const float beta2_correction,
|
||||
const float epsilon, adamMode_t mode, const float decay,
|
||||
const float *global_grad_norm, const float max_global_grad_norm) {
|
||||
template <typename T> struct LAMBStage1Functor {
|
||||
__device__ __forceinline__ void
|
||||
operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
|
||||
const float beta1, const float beta2, const float beta3,
|
||||
const float beta1_correction, const float beta2_correction,
|
||||
const float epsilon, adamMode_t mode, const float decay,
|
||||
const float *global_grad_norm, const float max_global_grad_norm) {
|
||||
// I'd like this kernel to propagate infs/nans.
|
||||
// if(*noop_gmem == 1)
|
||||
// return;
|
||||
|
@ -91,7 +89,8 @@ struct LAMBStage1Functor {
|
|||
i_start += blockDim.x) {
|
||||
// load
|
||||
load_store(l_g, g, 0, i_start);
|
||||
if (decay != 0) load_store(l_p, p, 0, i_start);
|
||||
if (decay != 0)
|
||||
load_store(l_p, p, 0, i_start);
|
||||
load_store(l_m, m, 0, i_start);
|
||||
load_store(l_v, v, 0, i_start);
|
||||
// unpack
|
||||
|
@ -205,12 +204,12 @@ struct LAMBStage1Functor {
|
|||
|
||||
// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
|
||||
// It computes new parameter value.
|
||||
template <typename T>
|
||||
struct LAMBStage2Functor {
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl,
|
||||
const float *per_tensor_param_norm, const float *per_tensor_update_norm,
|
||||
const float learning_rate, const float decay, bool use_nvlamb) {
|
||||
template <typename T> struct LAMBStage2Functor {
|
||||
__device__ __forceinline__ void
|
||||
operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl,
|
||||
const float *per_tensor_param_norm,
|
||||
const float *per_tensor_update_norm, const float learning_rate,
|
||||
const float decay, bool use_nvlamb) {
|
||||
// I'd like this kernel to propagate infs/nans.
|
||||
// if(*noop_gmem == 1)
|
||||
// return;
|
||||
|
@ -311,7 +310,8 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
|
|||
|
||||
// Handle grad averaging mode
|
||||
float beta3 = 1.0f;
|
||||
if (grad_averaging == 1) beta3 = 1 - beta1;
|
||||
if (grad_averaging == 1)
|
||||
beta3 = 1 - beta1;
|
||||
|
||||
std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(),
|
||||
tensor_lists.begin() + 1);
|
||||
|
@ -330,7 +330,7 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
|
|||
tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
|
||||
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
|
||||
LAMBStage1Functor<scalar_t_0>(), beta1, beta2,
|
||||
beta3, // 1-beta1 or 1 depends on averaging mode
|
||||
beta3, // 1-beta1 or 1 depends on averaging mode
|
||||
bias_correction1, bias_correction2, epsilon,
|
||||
(adamMode_t)mode, weight_decay,
|
||||
global_grad_norm.DATA_PTR<float>(), max_grad_norm);)
|
||||
|
|
|
@ -15,8 +15,7 @@
|
|||
#define BLOCK_SIZE 512
|
||||
#define ILP 4
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ bool is_aligned(T *p) {
|
||||
template <typename T> __device__ __forceinline__ bool is_aligned(T *p) {
|
||||
return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
|
||||
}
|
||||
|
||||
|
@ -28,8 +27,7 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
|
|||
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
|
||||
}
|
||||
|
||||
template <typename in_t, typename out_t>
|
||||
struct ScaleFunctor {
|
||||
template <typename in_t, typename out_t> struct ScaleFunctor {
|
||||
__device__ __forceinline__ void operator()(int chunk_size,
|
||||
volatile int *noop_gmem,
|
||||
TensorListMetadata<2> &tl,
|
||||
|
@ -78,7 +76,8 @@ struct ScaleFunctor {
|
|||
for (int ii = 0; ii < ILP; ii++) {
|
||||
r_in[ii] = 0;
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) r_in[ii] = in[i];
|
||||
if (i < n && i < chunk_size)
|
||||
r_in[ii] = in[i];
|
||||
}
|
||||
// note for clarification to future michael:
|
||||
// From a pure memory dependency perspective, there's likely no point
|
||||
|
@ -94,13 +93,14 @@ struct ScaleFunctor {
|
|||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) out[i] = r_out[ii];
|
||||
if (i < n && i < chunk_size)
|
||||
out[i] = r_out[ii];
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!finite)
|
||||
*noop_gmem =
|
||||
1; // Blindly fire off a write. These will race but that's ok.
|
||||
1; // Blindly fire off a write. These will race but that's ok.
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@@ -1,15 +1,14 @@
-// modified from
-// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_sgd_kernel.cu
+// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_sgd_kernel.cu
 #include <ATen/ATen.h>
 #include <ATen/AccumulateType.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/Exceptions.h>
+#include "multi_tensor_apply.cuh"
+#include "compat.h"

 #include <assert.h>
 #include <cuda_runtime.h>

-#include "compat.h"
-#include "multi_tensor_apply.cuh"

 #define BLOCK_SIZE 512
 #define ILP 4

@ -29,53 +28,69 @@
 * wd_after_momentum : apply weight decay _after_ momentum instead of before
 **/
template <int N, typename T_grad, typename T_weight>
struct SGDFunctor {
  __device__ __forceinline__ void operator()(
      int chunk_size, volatile int *noop_gmem, TensorListMetadata<N> &tl,
      float wd, float momentum, float dampening, float lr, bool nesterov,
      bool first_run, bool wd_after_momentum, float scale) {
    // Early exit if we don't need to do anything
    if (*noop_gmem) return;
struct SGDFunctor
{
  __device__ __forceinline__ void operator()(
      int chunk_size,
      volatile int *noop_gmem,
      TensorListMetadata<N> &tl,
      float wd,
      float momentum,
      float dampening,
      float lr,
      bool nesterov,
      bool first_run,
      bool wd_after_momentum,
      float scale)
  {
    // Early exit if we don't need to do anything
    if (*noop_gmem)
      return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    T_grad *grad_in = (T_grad *)tl.addresses[0][tensor_loc];
    grad_in += chunk_idx * chunk_size;
    T_grad *grad_in = (T_grad *)tl.addresses[0][tensor_loc];
    grad_in += chunk_idx * chunk_size;

    T_weight *weight_in = (T_weight *)tl.addresses[1][tensor_loc];
    weight_in += chunk_idx * chunk_size;
    T_weight *weight_in = (T_weight *)tl.addresses[1][tensor_loc];
    weight_in += chunk_idx * chunk_size;

    T_weight *mom_in = (T_weight *)tl.addresses[2][tensor_loc];
    mom_in += chunk_idx * chunk_size;
    T_weight *mom_in = (T_weight *)tl.addresses[2][tensor_loc];
    mom_in += chunk_idx * chunk_size;

    at::Half *model_weights_out = nullptr;
    if (N == 4) {
      model_weights_out = (at::Half *)tl.addresses[3][tensor_loc];
      model_weights_out += chunk_idx * chunk_size;
    }

    n -= chunk_idx * chunk_size;

    // Non-divergent exit condition for the __syncthreads
    float incoming_grads[ILP];
    float incoming_weights[ILP];
    float incoming_moms[ILP];
    for (int i_start = 0; i_start < n && i_start < chunk_size;
         i_start += blockDim.x * ILP) {
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        incoming_grads[ii] = 0;
        incoming_weights[ii] = 0;
        incoming_moms[ii] = 0;
        int i = i_start + threadIdx.x + ii * blockDim.x;
        if (i < n && i < chunk_size) {
          incoming_grads[ii] = static_cast<float>(grad_in[i]) * scale;
          incoming_weights[ii] = static_cast<float>(weight_in[i]);
          incoming_moms[ii] = static_cast<float>(mom_in[i]);
    at::Half *model_weights_out = nullptr;
    if (N == 4)
    {
      model_weights_out = (at::Half *)tl.addresses[3][tensor_loc];
      model_weights_out += chunk_idx * chunk_size;
    }
      }

    n -= chunk_idx * chunk_size;

    // Non-divergent exit condition for the __syncthreads
    float incoming_grads[ILP];
    float incoming_weights[ILP];
    float incoming_moms[ILP];
    for (int i_start = 0;
         i_start < n && i_start < chunk_size;
         i_start += blockDim.x * ILP)
    {
#pragma unroll
      for (int ii = 0; ii < ILP; ii++)
      {
        incoming_grads[ii] = 0;
        incoming_weights[ii] = 0;
        incoming_moms[ii] = 0;
        int i = i_start + threadIdx.x + ii * blockDim.x;
        if (i < n && i < chunk_size)
        {
          incoming_grads[ii] = static_cast<float>(grad_in[i]) * scale;
          incoming_weights[ii] = static_cast<float>(weight_in[i]);
          incoming_moms[ii] = static_cast<float>(mom_in[i]);
        }
      }

      // note for clarification to future michael:
      // From a pure memory dependency perspective, there's likely no point unrolling

@ -83,128 +98,185 @@ struct SGDFunctor {
|
|||
// Put another way, the STGs are dependent on the LDGs, but not on each other.
|
||||
// There is still compute ILP benefit from unrolling the loop though.
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) {
|
||||
// apply weight decay before momentum if necessary
|
||||
if (wd != 0.f && !wd_after_momentum)
|
||||
incoming_grads[ii] += wd * incoming_weights[ii];
|
||||
for (int ii = 0; ii < ILP; ii++)
|
||||
{
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size)
|
||||
{
|
||||
// apply weight decay before momentum if necessary
|
||||
if (wd != 0.f && !wd_after_momentum)
|
||||
incoming_grads[ii] += wd * incoming_weights[ii];
|
||||
|
||||
if (momentum != 0.f) {
|
||||
if (!first_run)
|
||||
incoming_moms[ii] = incoming_moms[ii] * momentum +
|
||||
(1.f - dampening) * incoming_grads[ii];
|
||||
else // initialize momentums to current incoming grads
|
||||
incoming_moms[ii] = incoming_grads[ii];
|
||||
if (momentum != 0.f)
|
||||
{
|
||||
if (!first_run)
|
||||
incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii];
|
||||
else // initialize momentums to current incoming grads
|
||||
incoming_moms[ii] = incoming_grads[ii];
|
||||
|
||||
if (nesterov)
|
||||
incoming_grads[ii] += momentum * incoming_moms[ii];
|
||||
else
|
||||
incoming_grads[ii] = incoming_moms[ii];
|
||||
}
|
||||
if (nesterov)
|
||||
incoming_grads[ii] += momentum * incoming_moms[ii];
|
||||
else
|
||||
incoming_grads[ii] = incoming_moms[ii];
|
||||
}
|
||||
|
||||
// Apply WD after momentum if desired
|
||||
if (wd != 0.f && wd_after_momentum)
|
||||
incoming_grads[ii] += wd * incoming_weights[ii];
|
||||
// Apply WD after momentum if desired
|
||||
if (wd != 0.f && wd_after_momentum)
|
||||
incoming_grads[ii] += wd * incoming_weights[ii];
|
||||
|
||||
// adjust the weight and write out
|
||||
weight_in[i] += (-lr * incoming_grads[ii]);
|
||||
// adjust the weight and write out
|
||||
weight_in[i] += (-lr * incoming_grads[ii]);
|
||||
|
||||
// if necessary, write out an fp16 copy of the weights
|
||||
if (N == 4)
|
||||
model_weights_out[i] = static_cast<at::Half>(weight_in[i]);
|
||||
// if necessary, write out an fp16 copy of the weights
|
||||
if (N == 4)
|
||||
model_weights_out[i] = static_cast<at::Half>(weight_in[i]);
|
||||
|
||||
// also write out the new momentum
|
||||
if (momentum != 0.f) mom_in[i] = incoming_moms[ii];
|
||||
// also write out the new momentum
|
||||
if (momentum != 0.f)
|
||||
mom_in[i] = incoming_moms[ii];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
float wd, float momentum, float dampening, float lr,
|
||||
bool nesterov, bool first_run,
|
||||
bool wd_after_momentum, float scale) {
|
||||
auto num_tensors = tensor_lists.size();
|
||||
auto grad_type = tensor_lists[0][0].scalar_type();
|
||||
auto weight_type = tensor_lists[1][0].scalar_type();
|
||||
void multi_tensor_sgd_cuda(
|
||||
int chunk_size,
|
||||
at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
float wd,
|
||||
float momentum,
|
||||
float dampening,
|
||||
float lr,
|
||||
bool nesterov,
|
||||
bool first_run,
|
||||
bool wd_after_momentum,
|
||||
float scale)
|
||||
{
|
||||
auto num_tensors = tensor_lists.size();
|
||||
auto grad_type = tensor_lists[0][0].scalar_type();
|
||||
auto weight_type = tensor_lists[1][0].scalar_type();
|
||||
|
||||
if (num_tensors == 4)
|
||||
for (int i = 0; i < tensor_lists[3].size(); i++)
|
||||
TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
|
||||
"Additional output tensors should always be fp16.");
|
||||
if (num_tensors == 4)
|
||||
for (int i = 0; i < tensor_lists[3].size(); i++)
|
||||
TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
|
||||
"Additional output tensors should always be fp16.");
|
||||
|
||||
TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(),
|
||||
"expected noop flag to be on the same device as tensors");
|
||||
TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(), "expected noop flag to be on the same device as tensors");
|
||||
|
||||
// We have 3 possibilities to handle here, in terms of
|
||||
// grad_type, param_type, momentum_type, requires_fp16_copy
|
||||
// 1. fp16, fp16, fp16, No
|
||||
// 2. fp32, fp32, fp32, No
|
||||
// 3. fp16, fp32, fp32, Yes
|
||||
// 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
|
||||
// It's easier to hardcode these possibilities than to use
|
||||
// switches etc. to handle the cross-product of cases where
|
||||
// we don't want the majority of them.
|
||||
// We have 3 possibilities to handle here, in terms of
|
||||
// grad_type, param_type, momentum_type, requires_fp16_copy
|
||||
// 1. fp16, fp16, fp16, No
|
||||
// 2. fp32, fp32, fp32, No
|
||||
// 3. fp16, fp32, fp32, Yes
|
||||
// 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
|
||||
// It's easier to hardcode these possibilities than to use
|
||||
// switches etc. to handle the cross-product of cases where
|
||||
// we don't want the majority of them.
|
||||
|
||||
// Case 1. fp16, fp16, fp16, No
|
||||
if (grad_type == at::ScalarType::Half &&
|
||||
weight_type == at::ScalarType::Half && num_tensors == 3) {
|
||||
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
|
||||
SGDFunctor<3, at::Half, at::Half>(), wd, momentum,
|
||||
dampening, lr, nesterov, first_run, wd_after_momentum,
|
||||
scale);
|
||||
}
|
||||
// Case 2. fp16, fp32, fp32, No
|
||||
// else if (grad_type == at::ScalarType::Half &&
|
||||
// weight_type == at::ScalarType::Float &&
|
||||
// num_tensors == 3) {
|
||||
// multi_tensor_apply<3>(
|
||||
// BLOCK_SIZE,
|
||||
// chunk_size,
|
||||
// noop_flag,
|
||||
// tensor_lists,
|
||||
// SGDFunctor<3, at::Half, float>(),
|
||||
// wd,
|
||||
// momentum,
|
||||
// dampening,
|
||||
// lr,
|
||||
// nesterov,
|
||||
// first_run,
|
||||
// wd_after_momentum);
|
||||
// }
|
||||
// Case 2. fp32, fp32, fp32, No
|
||||
else if (grad_type == at::ScalarType::Float &&
|
||||
weight_type == at::ScalarType::Float && num_tensors == 3) {
|
||||
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
|
||||
SGDFunctor<3, float, float>(), wd, momentum,
|
||||
dampening, lr, nesterov, first_run, wd_after_momentum,
|
||||
scale);
|
||||
}
|
||||
// Case 3. fp16, fp32, fp32, Yes
|
||||
else if (grad_type == at::ScalarType::Half &&
|
||||
weight_type == at::ScalarType::Float && num_tensors == 4) {
|
||||
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
|
||||
SGDFunctor<4, at::Half, float>(), wd, momentum,
|
||||
dampening, lr, nesterov, first_run, wd_after_momentum,
|
||||
scale);
|
||||
}
|
||||
// Case 4. fp32, fp32, fp32, Yes
|
||||
else if (grad_type == at::ScalarType::Float &&
|
||||
weight_type == at::ScalarType::Float && num_tensors == 4) {
|
||||
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
|
||||
SGDFunctor<4, float, float>(), wd, momentum,
|
||||
dampening, lr, nesterov, first_run, wd_after_momentum,
|
||||
scale);
|
||||
} else {
|
||||
AT_ERROR(
|
||||
"multi_tensor_sgd only supports some combinations of gradient & weight "
|
||||
"types. Given: ",
|
||||
"gradient: ", grad_type, ", weight: ", weight_type,
|
||||
", num_lists: ", num_tensors);
|
||||
}
|
||||
// Case 1. fp16, fp16, fp16, No
|
||||
if (grad_type == at::ScalarType::Half &&
|
||||
weight_type == at::ScalarType::Half &&
|
||||
num_tensors == 3)
|
||||
{
|
||||
multi_tensor_apply<3>(
|
||||
BLOCK_SIZE,
|
||||
chunk_size,
|
||||
noop_flag,
|
||||
tensor_lists,
|
||||
SGDFunctor<3, at::Half, at::Half>(),
|
||||
wd,
|
||||
momentum,
|
||||
dampening,
|
||||
lr,
|
||||
nesterov,
|
||||
first_run,
|
||||
wd_after_momentum,
|
||||
scale);
|
||||
}
|
||||
// Case 2. fp16, fp32, fp32, No
|
||||
// else if (grad_type == at::ScalarType::Half &&
|
||||
// weight_type == at::ScalarType::Float &&
|
||||
// num_tensors == 3) {
|
||||
// multi_tensor_apply<3>(
|
||||
// BLOCK_SIZE,
|
||||
// chunk_size,
|
||||
// noop_flag,
|
||||
// tensor_lists,
|
||||
// SGDFunctor<3, at::Half, float>(),
|
||||
// wd,
|
||||
// momentum,
|
||||
// dampening,
|
||||
// lr,
|
||||
// nesterov,
|
||||
// first_run,
|
||||
// wd_after_momentum);
|
||||
// }
|
||||
// Case 2. fp32, fp32, fp32, No
|
||||
else if (grad_type == at::ScalarType::Float &&
|
||||
weight_type == at::ScalarType::Float &&
|
||||
num_tensors == 3)
|
||||
{
|
||||
multi_tensor_apply<3>(
|
||||
BLOCK_SIZE,
|
||||
chunk_size,
|
||||
noop_flag,
|
||||
tensor_lists,
|
||||
SGDFunctor<3, float, float>(),
|
||||
wd,
|
||||
momentum,
|
||||
dampening,
|
||||
lr,
|
||||
nesterov,
|
||||
first_run,
|
||||
wd_after_momentum,
|
||||
scale);
|
||||
}
|
||||
// Case 3. fp16, fp32, fp32, Yes
|
||||
else if (grad_type == at::ScalarType::Half &&
|
||||
weight_type == at::ScalarType::Float &&
|
||||
num_tensors == 4)
|
||||
{
|
||||
multi_tensor_apply<4>(
|
||||
BLOCK_SIZE,
|
||||
chunk_size,
|
||||
noop_flag,
|
||||
tensor_lists,
|
||||
SGDFunctor<4, at::Half, float>(),
|
||||
wd,
|
||||
momentum,
|
||||
dampening,
|
||||
lr,
|
||||
nesterov,
|
||||
first_run,
|
||||
wd_after_momentum,
|
||||
scale);
|
||||
}
|
||||
// Case 4. fp32, fp32, fp32, Yes
|
||||
else if (grad_type == at::ScalarType::Float &&
|
||||
weight_type == at::ScalarType::Float &&
|
||||
num_tensors == 4)
|
||||
{
|
||||
multi_tensor_apply<4>(
|
||||
BLOCK_SIZE,
|
||||
chunk_size,
|
||||
noop_flag,
|
||||
tensor_lists,
|
||||
SGDFunctor<4, float, float>(),
|
||||
wd,
|
||||
momentum,
|
||||
dampening,
|
||||
lr,
|
||||
nesterov,
|
||||
first_run,
|
||||
wd_after_momentum,
|
||||
scale);
|
||||
}
|
||||
else
|
||||
{
|
||||
AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
|
||||
"gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
|
||||
}
|
||||

  AT_CUDA_CHECK(cudaGetLastError());
  AT_CUDA_CHECK(cudaGetLastError());
}
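For readers following the dispatcher above: a minimal PyTorch sketch of the per-element update that SGDFunctor applies (optional weight decay before or after momentum, dampening, optional Nesterov, gradient scaling). The function name and signature below are illustrative only and are not part of the extension's API.

import torch

def sgd_step_reference(param, grad, momentum_buf, *, wd, momentum, dampening,
                       lr, nesterov, first_run, wd_after_momentum, scale):
    # Illustrative restatement of SGDFunctor's math in plain PyTorch.
    g = grad.float() * scale
    w = param.float()
    m = momentum_buf.float()
    if wd != 0.0 and not wd_after_momentum:  # weight decay before momentum
        g = g + wd * w
    if momentum != 0.0:
        # initialize momentum to the incoming grad on the first step
        m = g if first_run else m * momentum + (1.0 - dampening) * g
        if nesterov:
            g = g + momentum * m
        else:
            g = m
    if wd != 0.0 and wd_after_momentum:      # weight decay after momentum
        g = g + wd * w
    w = w - lr * g
    return w, m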
|
|
@ -10,9 +10,8 @@
|
|||
#include "kernels.h"
|
||||
|
||||
template <typename T>
|
||||
MultiHeadAttention<T>::MultiHeadAttention(int layer_id, int max_batch_tokens,
|
||||
int max_seq_len, int hidden_size,
|
||||
int num_heads,
|
||||
MultiHeadAttention<T>::MultiHeadAttention(int layer_id, int max_batch_tokens, int max_seq_len,
|
||||
int hidden_size, int num_heads,
|
||||
float attn_prob_dropout_ratio,
|
||||
float hidden_output_dropout_ratio,
|
||||
bool pre_or_postLayerNorm)
|
||||
|
@ -23,22 +22,18 @@ MultiHeadAttention<T>::MultiHeadAttention(int layer_id, int max_batch_tokens,
|
|||
_heads(num_heads),
|
||||
_training(true),
|
||||
_pre_or_postLayerNorm(pre_or_postLayerNorm),
|
||||
_qkv_linear(
|
||||
typename FeedForward<T>::Config(3 * hidden_size, hidden_size)),
|
||||
_attn_out_linear(
|
||||
typename FeedForward<T>::Config(hidden_size, hidden_size)),
|
||||
_attn_ln(typename Normalize_Layer<T>::Config(hidden_size, false),
|
||||
_max_batch_tokens),
|
||||
_qkv_linear(typename FeedForward<T>::Config(3 * hidden_size, hidden_size)),
|
||||
_attn_out_linear(typename FeedForward<T>::Config(hidden_size, hidden_size)),
|
||||
_attn_ln(typename Normalize_Layer<T>::Config(hidden_size, false), _max_batch_tokens),
|
||||
_softmax(typename Softmax<T>::Config(num_heads)),
|
||||
_attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio),
|
||||
_max_batch_tokens * _heads * _max_seq_len),
|
||||
_attn_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio),
|
||||
_max_batch_tokens * _hidden_size),
|
||||
_attn_scores(typename StridedBatchGemm<T>::Config(
|
||||
(T(1.0) / T(sqrt(_hidden_size / _heads))), T(0.0), CUBLAS_OP_T,
|
||||
CUBLAS_OP_N)),
|
||||
_attn_context(typename StridedBatchGemm<T>::Config(
|
||||
T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N)) {
|
||||
_attn_scores(typename StridedBatchGemm<T>::Config((T(1.0) / T(sqrt(_hidden_size / _heads))),
|
||||
T(0.0), CUBLAS_OP_T, CUBLAS_OP_N)),
|
||||
_attn_context(
|
||||
typename StridedBatchGemm<T>::Config(T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N)) {
|
||||
assert(_hidden_size % _heads == 0);
|
||||
}
|
||||
|
||||
|
@ -48,52 +43,43 @@ MultiHeadAttention<T>::~MultiHeadAttention() {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void MultiHeadAttention<T>::attn_layer_fw(const T *input_ptr,
|
||||
const T *input_mask_ptr,
|
||||
void MultiHeadAttention<T>::attn_layer_fw(const T *input_ptr, const T *input_mask_ptr,
|
||||
T *output_ptr, T *buffer) {
|
||||
T *q_tf_ptr = _qkv_ptr;
|
||||
T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size;
|
||||
T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size;
|
||||
|
||||
if (_pre_or_postLayerNorm) {
|
||||
_attn_ln.Forward(_gemmQKV_inp_ptr, input_ptr, _attn_nw_ptr, _attn_nb_ptr,
|
||||
_batch_tokens, _stream);
|
||||
_attn_ln.Forward(_gemmQKV_inp_ptr, input_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens,
|
||||
_stream);
|
||||
}
|
||||
const T *gemmQKV_inp_ptr =
|
||||
_pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
|
||||
const T *gemmQKV_inp_ptr = _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
|
||||
_qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size);
|
||||
_qkv_linear.Forward(_batch_tokens, gemmQKV_inp_ptr, _attn_qkvw_ptr, buffer,
|
||||
_cublasHandle);
|
||||
_qkv_linear.Forward(_batch_tokens, gemmQKV_inp_ptr, _attn_qkvw_ptr, buffer, _cublasHandle);
|
||||
|
||||
launch_bias_add_transform_20314<T>(q_tf_ptr, buffer, _attn_qkvb_ptr,
|
||||
_batch_size, _seq_len, 3, _heads / pg_size,
|
||||
_hidden_size / _heads, _stream);
|
||||
launch_bias_add_transform_20314<T>(q_tf_ptr, buffer, _attn_qkvb_ptr, _batch_size, _seq_len, 3,
|
||||
_heads / pg_size, _hidden_size / _heads, _stream);
|
||||
|
||||
// attention scores, q*k
|
||||
_attn_scores.Forward(_batch_heads, _soft_out_ptr, k_tf_ptr, q_tf_ptr,
|
||||
_cublasHandle);
|
||||
_attn_scores.Forward(_batch_heads, _soft_out_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle);
|
||||
|
||||
// Softmax + Mask
|
||||
_softmax.reset_size(_heads / pg_size);
|
||||
_softmax.Forward(_soft_out_ptr, input_mask_ptr, _batch_size, _seq_len,
|
||||
_seq_len, _stream, true);
|
||||
_softmax.Forward(_soft_out_ptr, input_mask_ptr, _batch_size, _seq_len, _seq_len, _stream, true);
|
||||
|
||||
// attn prob dropout.
|
||||
_attn_prob_dropout.dropout(_ctx_bufB_ptr, _soft_out_ptr,
|
||||
_batch_heads * _seq_len * _seq_len, _stream);
|
||||
|
||||
// attention context, score * v
|
||||
_attn_context.Forward(_batch_heads, buffer, v_tf_ptr, _ctx_bufB_ptr,
|
||||
_cublasHandle);
|
||||
|
||||
// [b, nh, s, ad] -> [b, s, nh, ad]
|
||||
launch_transform4d_0213<T>(_attn_o_inp_ptr, buffer, _batch_size, _seq_len,
|
||||
_hidden_size / pg_size, _heads / pg_size, 1,
|
||||
_attn_prob_dropout.dropout(_ctx_bufB_ptr, _soft_out_ptr, _batch_heads * _seq_len * _seq_len,
|
||||
_stream);
|
||||
|
||||
// attention context, score * v
|
||||
_attn_context.Forward(_batch_heads, buffer, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle);
|
||||
|
||||
// [b, nh, s, ad] -> [b, s, nh, ad]
|
||||
launch_transform4d_0213<T>(_attn_o_inp_ptr, buffer, _batch_size, _seq_len, _hidden_size / pg_size,
|
||||
_heads / pg_size, 1, _stream);
|
||||
|
||||
_attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size);
|
||||
_attn_out_linear.Forward(_batch_tokens, _attn_o_inp_ptr, _attn_ow_ptr,
|
||||
output_ptr, _cublasHandle);
|
||||
_attn_out_linear.Forward(_batch_tokens, _attn_o_inp_ptr, _attn_ow_ptr, output_ptr, _cublasHandle);
|
||||
|
||||
// allreduce
|
||||
if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) {
|
||||
|
@ -102,27 +88,24 @@ void MultiHeadAttention<T>::attn_layer_fw(const T *input_ptr,
|
|||
if (typeid(T) != typeid(float)) {
|
||||
data_type = torch::kHalf;
|
||||
}
|
||||
auto output_tensor = torch::from_blob(
|
||||
output_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)},
|
||||
torch::TensorOptions(torch::kCUDA).dtype(data_type));
|
||||
auto output_tensor =
|
||||
torch::from_blob(output_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)},
|
||||
torch::TensorOptions(torch::kCUDA).dtype(data_type));
|
||||
std::vector<torch::Tensor> allreduce_tensors = {output_tensor};
|
||||
auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions());
|
||||
work->wait();
|
||||
}
|
||||
|
||||
_attn_dropout.bias_dropout_residual(output_ptr, output_ptr, input_ptr,
|
||||
_attn_ob_ptr, _batch_tokens, _hidden_size,
|
||||
_stream);
|
||||
_attn_dropout.bias_dropout_residual(output_ptr, output_ptr, input_ptr, _attn_ob_ptr,
|
||||
_batch_tokens, _hidden_size, _stream);
|
||||
if (!_pre_or_postLayerNorm) {
|
||||
// in-place ln since ln-input will not be used in post-ln mode
|
||||
_attn_ln.Forward(output_ptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr,
|
||||
_batch_tokens, _stream);
|
||||
_attn_ln.Forward(output_ptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, _stream);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MultiHeadAttention<T>::Forward(const T *input_ptr, const T *input_mask_ptr,
|
||||
T *out_ptr) {
|
||||
void MultiHeadAttention<T>::Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr) {
|
||||
_stream = Context::Instance().get_stream();
|
||||
_cublasHandle = Context::Instance().get_cublashandle();
|
||||
T *attn_buffer = _shared_mem_ptr; // 3 * _batch_dim
|
||||
|
@ -131,11 +114,8 @@ void MultiHeadAttention<T>::Forward(const T *input_ptr, const T *input_mask_ptr,
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr,
|
||||
const T *input_mask_ptr,
|
||||
const T *output_ptr,
|
||||
const T *grad_output_ptr,
|
||||
T *grad_input_ptr, T *buffer) {
|
||||
void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr, const T *input_mask_ptr, const T *output_ptr,
|
||||
const T *grad_output_ptr, T *grad_input_ptr, T *buffer) {
|
||||
cudaStream_t streams[2] = {_stream, _stream};
|
||||
|
||||
const T *q_tf_ptr = _qkv_ptr;
|
||||
|
@ -157,57 +137,45 @@ void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr,
|
|||
// batch_size * head_num * seq_len * seq_len);
|
||||
|
||||
if (_pre_or_postLayerNorm) {
|
||||
_attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr,
|
||||
grad_output_ptr, _batch_tokens,
|
||||
_hidden_size, _stream);
|
||||
_attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, grad_output_ptr,
|
||||
_batch_tokens, _hidden_size, _stream);
|
||||
} else {
|
||||
_attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_residual_ptr,
|
||||
grad_output_ptr, nullptr, output_ptr, _attn_nw_ptr,
|
||||
_attn_nb_ptr, _batch_tokens, streams);
|
||||
_attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr,
|
||||
grad_residual_ptr, _batch_tokens,
|
||||
_hidden_size, _stream);
|
||||
_attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_residual_ptr, grad_output_ptr,
|
||||
nullptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, streams);
|
||||
_attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, grad_residual_ptr,
|
||||
_batch_tokens, _hidden_size, _stream);
|
||||
}
|
||||
|
||||
// bw of output project
|
||||
_attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size);
|
||||
_attn_out_linear.Backward(_batch_tokens, grad_input_ptr, _attn_o_inp_ptr,
|
||||
_attn_ow_ptr, _grad_attn_ow_ptr, _grad_attn_ob_ptr,
|
||||
_cublasHandle, _stream, grad_input_buf_ptr, nullptr,
|
||||
false);
|
||||
launch_transform_0213<T>(grad_input_ptr, grad_input_buf_ptr, _batch_size,
|
||||
_seq_len, _hidden_size / pg_size, _heads / pg_size,
|
||||
_stream);
|
||||
_attn_out_linear.Backward(_batch_tokens, grad_input_ptr, _attn_o_inp_ptr, _attn_ow_ptr,
|
||||
_grad_attn_ow_ptr, _grad_attn_ob_ptr, _cublasHandle, _stream,
|
||||
grad_input_buf_ptr, nullptr, false);
|
||||
launch_transform_0213<T>(grad_input_ptr, grad_input_buf_ptr, _batch_size, _seq_len,
|
||||
_hidden_size / pg_size, _heads / pg_size, _stream);
|
||||
|
||||
// bw of score * v
|
||||
_attn_context.Backward(
|
||||
_batch_heads, grad_input_ptr, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle,
|
||||
grad_qkv_5d_ptr + 2 * _batch_dim / pg_size, grad_softmax_ptr);
|
||||
_attn_context.Backward(_batch_heads, grad_input_ptr, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle,
|
||||
grad_qkv_5d_ptr + 2 * _batch_dim / pg_size, grad_softmax_ptr);
|
||||
|
||||
_attn_prob_dropout.d_dropout(grad_softmax_ptr,
|
||||
_batch_heads * _seq_len * _seq_len, _stream);
|
||||
_attn_prob_dropout.d_dropout(grad_softmax_ptr, _batch_heads * _seq_len * _seq_len, _stream);
|
||||
|
||||
_softmax.reset_size(_heads / pg_size);
|
||||
_softmax.Backward(grad_softmax_ptr, _soft_out_ptr, _batch_size, _seq_len,
|
||||
_seq_len, _stream);
|
||||
_softmax.Backward(grad_softmax_ptr, _soft_out_ptr, _batch_size, _seq_len, _seq_len, _stream);
|
||||
|
||||
// bw of q * k
|
||||
_attn_scores.Backward(_batch_heads, grad_softmax_ptr, k_tf_ptr, q_tf_ptr,
|
||||
_cublasHandle, grad_qkv_5d_ptr + _batch_dim / pg_size,
|
||||
grad_qkv_5d_ptr);
|
||||
_attn_scores.Backward(_batch_heads, grad_softmax_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle,
|
||||
grad_qkv_5d_ptr + _batch_dim / pg_size, grad_qkv_5d_ptr);
|
||||
|
||||
// [3, b, nh, s, ad] -> [b, s, 3, h]
|
||||
launch_transform4d_0213<T>(grad_qkv_4d_ptr, grad_qkv_5d_ptr, _batch_size,
|
||||
_seq_len, _hidden_size / pg_size, _heads / pg_size,
|
||||
3, _stream);
|
||||
launch_transform4d_0213<T>(grad_qkv_4d_ptr, grad_qkv_5d_ptr, _batch_size, _seq_len,
|
||||
_hidden_size / pg_size, _heads / pg_size, 3, _stream);
|
||||
|
||||
const T *gemmQKV_inp_ptr =
|
||||
_pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
|
||||
const T *gemmQKV_inp_ptr = _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
|
||||
_qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size);
|
||||
_qkv_linear.Backward(_batch_tokens, grad_qkv_4d_ptr, gemmQKV_inp_ptr,
|
||||
_attn_qkvw_ptr, _grad_attn_qkvw_ptr, _grad_attn_qkvb_ptr,
|
||||
_cublasHandle, _stream, grad_input_buf_ptr, nullptr,
|
||||
true);
|
||||
_qkv_linear.Backward(_batch_tokens, grad_qkv_4d_ptr, gemmQKV_inp_ptr, _attn_qkvw_ptr,
|
||||
_grad_attn_qkvw_ptr, _grad_attn_qkvb_ptr, _cublasHandle, _stream,
|
||||
grad_input_buf_ptr, nullptr, true);
|
||||
|
||||
// allreduce
|
||||
if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) {
|
||||
|
@ -217,8 +185,7 @@ void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr,
|
|||
data_type = torch::kHalf;
|
||||
}
|
||||
auto grad_input_tensor =
|
||||
torch::from_blob(grad_input_buf_ptr,
|
||||
{int(_batch_size), int(_seq_len), int(_hidden_size)},
|
||||
torch::from_blob(grad_input_buf_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)},
|
||||
torch::TensorOptions(torch::kCUDA).dtype(data_type));
|
||||
std::vector<torch::Tensor> allreduce_tensors = {grad_input_tensor};
|
||||
auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions());
|
||||
|
@ -226,21 +193,19 @@ void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr,
|
|||
}
|
||||
|
||||
if (_pre_or_postLayerNorm) {
|
||||
_attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_input_ptr,
|
||||
grad_input_buf_ptr, grad_output_ptr, gemmQKV_inp_ptr,
|
||||
_attn_nw_ptr, _attn_nb_ptr, _batch_tokens, streams);
|
||||
_attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_input_ptr, grad_input_buf_ptr,
|
||||
grad_output_ptr, gemmQKV_inp_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens,
|
||||
streams);
|
||||
} else {
|
||||
// FIXME later
|
||||
launch_fused_add2<T>(grad_input_ptr, grad_input_buf_ptr, grad_residual_ptr,
|
||||
_batch_size, _seq_len, _hidden_size, _stream);
|
||||
launch_fused_add2<T>(grad_input_ptr, grad_input_buf_ptr, grad_residual_ptr, _batch_size,
|
||||
_seq_len, _hidden_size, _stream);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MultiHeadAttention<T>::Backward(const T *grad_output_ptr,
|
||||
const T *input_ptr, const T *output_ptr,
|
||||
const T *input_mask_ptr,
|
||||
T *grad_input_ptr) {
|
||||
void MultiHeadAttention<T>::Backward(const T *grad_output_ptr, const T *input_ptr, const T *output_ptr,
|
||||
const T *input_mask_ptr, T *grad_input_ptr) {
|
||||
_stream = Context::Instance().get_stream();
|
||||
_cublasHandle = Context::Instance().get_cublashandle();
|
||||
T *buffer = _shared_mem_ptr;
|
||||
|
@ -250,8 +215,7 @@ void MultiHeadAttention<T>::Backward(const T *grad_output_ptr,
|
|||
4 * _batch_dim + max(3 * _batch_dim,
|
||||
_batch_size * _head_num * _seq_len * _seq_len);
|
||||
*/
|
||||
attn_layer_bw(input_ptr, input_mask_ptr, output_ptr, grad_output_ptr,
|
||||
grad_input_ptr, buffer);
|
||||
attn_layer_bw(input_ptr, input_mask_ptr, output_ptr, grad_output_ptr, grad_input_ptr, buffer);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -269,8 +233,7 @@ template class MultiHeadAttention<__half>;
|
|||
|
||||
// x is torch::Tensor
|
||||
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
@ -278,17 +241,15 @@ template class MultiHeadAttention<__half>;
|
|||
static std::unordered_map<int, std::shared_ptr<void>> s_multihead_attention;
|
||||
|
||||
template <typename T>
|
||||
int create_multihead_attention(int layer_id, int max_batch_tokens,
|
||||
int max_seq_len, int hidden_dim, int num_heads,
|
||||
float attn_prob_dropout_ratio,
|
||||
float hidden_dropout_ratio,
|
||||
bool pre_or_postLayerNorm,
|
||||
int create_multihead_attention(int layer_id, int max_batch_tokens, int max_seq_len, int hidden_dim,
|
||||
int num_heads, float attn_prob_dropout_ratio,
|
||||
float hidden_dropout_ratio, bool pre_or_postLayerNorm,
|
||||
c10::intrusive_ptr<c10d::ProcessGroup> pg_) {
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
Context::Instance().set_stream(stream);
|
||||
auto layer = std::make_shared<MultiHeadAttention<T>>(
|
||||
layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads,
|
||||
attn_prob_dropout_ratio, hidden_dropout_ratio, pre_or_postLayerNorm);
|
||||
layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads, attn_prob_dropout_ratio,
|
||||
hidden_dropout_ratio, pre_or_postLayerNorm);
|
||||
|
||||
layer->SetPG(pg_);
|
||||
|
||||
|
@ -300,12 +261,15 @@ int create_multihead_attention(int layer_id, int max_batch_tokens,
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<torch::Tensor> multihead_attention_fw(
|
||||
int layer_id, const torch::Tensor &input, const torch::Tensor &input_mask,
|
||||
const torch::Tensor &in_proj_weight, const torch::Tensor &in_proj_bias,
|
||||
const torch::Tensor &out_proj_weight, const torch::Tensor &out_proj_bias,
|
||||
const torch::Tensor &norm_weight, const torch::Tensor &norm_bias,
|
||||
bool training_mode, bool prelayernorm) {
|
||||
std::vector<torch::Tensor> multihead_attention_fw(int layer_id, const torch::Tensor &input,
|
||||
const torch::Tensor &input_mask,
|
||||
const torch::Tensor &in_proj_weight,
|
||||
const torch::Tensor &in_proj_bias,
|
||||
const torch::Tensor &out_proj_weight,
|
||||
const torch::Tensor &out_proj_bias,
|
||||
const torch::Tensor &norm_weight,
|
||||
const torch::Tensor &norm_bias,
|
||||
bool training_mode, bool prelayernorm) {
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(input_mask);
|
||||
|
||||
|
@ -316,8 +280,7 @@ std::vector<torch::Tensor> multihead_attention_fw(
|
|||
T *out_ptr = (T *)output.data_ptr();
|
||||
|
||||
std::shared_ptr<MultiHeadAttention<T>> layer =
|
||||
std::static_pointer_cast<MultiHeadAttention<T>>(
|
||||
s_multihead_attention[layer_id]);
|
||||
std::static_pointer_cast<MultiHeadAttention<T>>(s_multihead_attention[layer_id]);
|
||||
layer->set_cur_batch_shape(input.size(0), input.size(1));
|
||||
layer->SetTrainingMode(training_mode);
|
||||
|
||||
|
@ -334,13 +297,17 @@ std::vector<torch::Tensor> multihead_attention_fw(
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<torch::Tensor> multihead_attention_bw(
|
||||
int layer_id, const torch::Tensor &grad_dec_output,
|
||||
const torch::Tensor &output, const torch::Tensor &input,
|
||||
const torch::Tensor &input_mask, const torch::Tensor &in_proj_weight,
|
||||
const torch::Tensor &in_proj_bias, const torch::Tensor &out_proj_weight,
|
||||
const torch::Tensor &out_proj_bias, const torch::Tensor &norm_weight,
|
||||
const torch::Tensor &norm_bias) {
|
||||
std::vector<torch::Tensor> multihead_attention_bw(int layer_id,
|
||||
const torch::Tensor &grad_dec_output,
|
||||
const torch::Tensor &output,
|
||||
const torch::Tensor &input,
|
||||
const torch::Tensor &input_mask,
|
||||
const torch::Tensor &in_proj_weight,
|
||||
const torch::Tensor &in_proj_bias,
|
||||
const torch::Tensor &out_proj_weight,
|
||||
const torch::Tensor &out_proj_bias,
|
||||
const torch::Tensor &norm_weight,
|
||||
const torch::Tensor &norm_bias) {
|
||||
auto g_output = grad_dec_output.contiguous();
|
||||
CHECK_INPUT(g_output);
|
||||
CHECK_INPUT(output);
|
||||
|
@ -365,8 +332,7 @@ std::vector<torch::Tensor> multihead_attention_bw(
|
|||
T *grad_input_ptr = (T *)grad_input.data_ptr();
|
||||
|
||||
std::shared_ptr<MultiHeadAttention<T>> layer =
|
||||
std::static_pointer_cast<MultiHeadAttention<T>>(
|
||||
s_multihead_attention[layer_id]);
|
||||
std::static_pointer_cast<MultiHeadAttention<T>>(s_multihead_attention[layer_id]);
|
||||
layer->set_cur_batch_shape(g_output.size(0), g_output.size(1));
|
||||
|
||||
layer->_grad_attn_qkvw_ptr = (T *)grad_in_proj_weight.data_ptr();
|
||||
|
@ -376,12 +342,10 @@ std::vector<torch::Tensor> multihead_attention_bw(
|
|||
layer->_grad_attn_nw_ptr = (T *)grad_norm_weight.data_ptr();
|
||||
layer->_grad_attn_nb_ptr = (T *)grad_norm_bias.data_ptr();
|
||||
|
||||
layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr,
|
||||
grad_input_ptr);
|
||||
layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr, grad_input_ptr);
|
||||
|
||||
return {grad_input, grad_in_proj_weight, grad_in_proj_bias,
|
||||
grad_out_proj_weight, grad_out_proj_bias, grad_norm_weight,
|
||||
grad_norm_bias};
|
||||
return {grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight,
|
||||
grad_out_proj_bias, grad_norm_weight, grad_norm_bias};
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
|
|
|
@ -19,25 +19,21 @@
|
|||
template <typename T>
|
||||
class MultiHeadAttention {
|
||||
public:
|
||||
MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len,
|
||||
int hidden_size, int num_heads, float attn_dropout_ratio,
|
||||
float hidden_output_dropout_ratio,
|
||||
MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len, int hidden_size,
|
||||
int num_heads, float attn_dropout_ratio, float hidden_output_dropout_ratio,
|
||||
bool pre_or_postLayerNorm);
|
||||
|
||||
virtual ~MultiHeadAttention();
|
||||
|
||||
void Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr);
|
||||
|
||||
void Backward(const T *grad_output_ptr, const T *input_ptr,
|
||||
const T *output_ptr, const T *input_mask_ptr,
|
||||
T *grad_input_ptr);
|
||||
void Backward(const T *grad_output_ptr, const T *input_ptr, const T *output_ptr,
|
||||
const T *input_mask_ptr, T *grad_input_ptr);
|
||||
|
||||
void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr,
|
||||
T *buffer);
|
||||
void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr, T *buffer);
|
||||
|
||||
void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr,
|
||||
const T *output_ptr, const T *grad_output_ptr,
|
||||
T *grad_input_attn_layer_bwptr, T *buffer);
|
||||
void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr, const T *output_ptr,
|
||||
const T *grad_output_ptr, T *grad_input_attn_layer_bwptr, T *buffer);
|
||||
|
||||
void set_cur_batch_shape(int batch_size, int seq_len) {
|
||||
_batch_size = batch_size;
|
||||
|
@ -87,17 +83,14 @@ class MultiHeadAttention {
|
|||
}
|
||||
|
||||
_qkv_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size * 3);
|
||||
_soft_out_ptr =
|
||||
cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
|
||||
_ctx_bufB_ptr =
|
||||
cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
|
||||
_soft_out_ptr = cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
|
||||
_ctx_bufB_ptr = cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
|
||||
_attn_o_inp_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size);
|
||||
|
||||
// buffer size needed by attn bw
|
||||
size_t smem_size =
|
||||
4 * _max_batch_tokens * _hidden_size / pg_size +
|
||||
std::max(3 * _max_batch_tokens * _hidden_size / pg_size,
|
||||
_max_batch_tokens * _heads / pg_size * _max_seq_len);
|
||||
size_t smem_size = 4 * _max_batch_tokens * _hidden_size / pg_size +
|
||||
std::max(3 * _max_batch_tokens * _hidden_size / pg_size,
|
||||
_max_batch_tokens * _heads / pg_size * _max_seq_len);
|
||||
|
||||
if (!_shared_mem_ptr) {
|
||||
cuda_free(_shared_mem_ptr);
|
||||
|
|
|
@ -2,13 +2,12 @@
 * with minor changes. */

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "scaled_masked_softmax.h"
#include "type_shim.h"

@ -16,15 +15,17 @@ namespace multihead_attn {
|
|||
namespace fused_softmax {
|
||||
namespace scaled_masked_softmax {
|
||||
|
||||
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches,
|
||||
int attn_heads) {
|
||||
return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
|
||||
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
|
||||
return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
|
||||
}
|
||||
|
||||
torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
|
||||
float scale_factor) {
|
||||
// input is a 4d tensor with dimensions [batches, attn_heads, seq_len,
|
||||
// seq_len]
|
||||
|
||||
torch::Tensor fwd_cuda(
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor const& mask,
|
||||
float scale_factor)
|
||||
{
|
||||
// input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
|
||||
const int batches = input.size(0);
|
||||
const int pad_batches = mask.size(0);
|
||||
const int attn_heads = input.size(1);
|
||||
|
@ -37,10 +38,10 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
|
|||
TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
|
||||
TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);
|
||||
|
||||
// Output
|
||||
// Output
|
||||
auto act_options = input.options().requires_grad(false);
|
||||
torch::Tensor softmax_results = torch::empty(
|
||||
{batches, attn_heads, query_seq_len, key_seq_len}, act_options);
|
||||
torch::Tensor softmax_results =
|
||||
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
|
||||
|
||||
// Softmax Intermediate Result Ptr
|
||||
void* input_ptr = static_cast<void*>(input.data_ptr());
|
||||
|
@ -48,23 +49,31 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
|
|||
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
|
||||
|
||||
DISPATCH_HALF_AND_BFLOAT(
|
||||
input.scalar_type(), "dispatch_scaled_masked_softmax_forward",
|
||||
input.scalar_type(),
|
||||
"dispatch_scaled_masked_softmax_forward",
|
||||
dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
|
||||
reinterpret_cast<scalar_t*>(softmax_results_ptr),
|
||||
reinterpret_cast<const scalar_t*>(input_ptr),
|
||||
reinterpret_cast<const uint8_t*>(mask_ptr), scale_factor,
|
||||
query_seq_len, key_seq_len, batches, attn_heads, pad_batches););
|
||||
reinterpret_cast<const scalar_t*>(input_ptr),
|
||||
reinterpret_cast<const uint8_t*>(mask_ptr),
|
||||
scale_factor,
|
||||
query_seq_len,
|
||||
key_seq_len,
|
||||
batches,
|
||||
attn_heads,
|
||||
pad_batches);
|
||||
);
|
||||
return softmax_results;
|
||||
}
|
||||
|
||||
torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
|
||||
torch::Tensor const& softmax_results_,
|
||||
float scale_factor) {
|
||||
torch::Tensor bwd_cuda(
|
||||
torch::Tensor const& output_grads_,
|
||||
torch::Tensor const& softmax_results_,
|
||||
float scale_factor) {
|
||||
|
||||
auto output_grads = output_grads_.contiguous();
|
||||
auto softmax_results = softmax_results_.contiguous();
|
||||
|
||||
// output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len,
|
||||
// seq_len]
|
||||
//output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
|
||||
const int batches = output_grads.size(0);
|
||||
const int attn_heads = output_grads.size(1);
|
||||
const int query_seq_len = output_grads.size(2);
|
||||
|
@ -72,18 +81,24 @@ torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
|
|||
|
||||
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
|
||||
|
||||
// Softmax Grad
|
||||
//Softmax Grad
|
||||
DISPATCH_HALF_AND_BFLOAT(
|
||||
output_grads_.scalar_type(), "dispatch_scaled_masked_softmax_backward",
|
||||
output_grads_.scalar_type(),
|
||||
"dispatch_scaled_masked_softmax_backward",
|
||||
dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
|
||||
reinterpret_cast<scalar_t*>(output_grads_ptr),
|
||||
reinterpret_cast<scalar_t*>(output_grads_ptr),
|
||||
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
|
||||
scale_factor, query_seq_len, key_seq_len, batches, attn_heads););
|
||||
|
||||
// backward pass is completely in-place
|
||||
reinterpret_cast<scalar_t*>(output_grads_ptr),
|
||||
reinterpret_cast<scalar_t*>(output_grads_ptr),
|
||||
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
|
||||
scale_factor,
|
||||
query_seq_len,
|
||||
key_seq_len,
|
||||
batches,
|
||||
attn_heads);
|
||||
);
|
||||
|
||||
//backward pass is completely in-place
|
||||
return output_grads;
|
||||
}
|
||||
} // namespace scaled_masked_softmax
|
||||
} // namespace fused_softmax
|
||||
} // namespace multihead_attn
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,52 +3,57 @@
|
|||
|
||||
#include <cuda_fp16.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace multihead_attn {
|
||||
namespace fused_softmax {
|
||||
namespace scaled_upper_triang_masked_softmax {
|
||||
|
||||
torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor);
|
||||
torch::Tensor fwd_cuda(
|
||||
torch::Tensor const& input,
|
||||
float scale_factor);
|
||||
|
||||
torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
|
||||
torch::Tensor const& softmax_results,
|
||||
float scale_factor);
|
||||
torch::Tensor bwd_cuda(
|
||||
torch::Tensor const& output_grads,
|
||||
torch::Tensor const& softmax_results,
|
||||
float scale_factor);
|
||||
|
||||
torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
|
||||
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
|
||||
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
|
||||
(input.scalar_type() == at::ScalarType::BFloat16),
|
||||
"Only fp16 and bf16 are supported");
|
||||
(input.scalar_type() == at::ScalarType::BFloat16),
|
||||
"Only fp16 and bf16 are supported");
|
||||
|
||||
return fwd_cuda(input, scale_factor);
|
||||
}
|
||||
|
||||
torch::Tensor bwd(torch::Tensor const& output_grads,
|
||||
torch::Tensor const& softmax_results, float scale_factor) {
|
||||
torch::Tensor bwd(
|
||||
torch::Tensor const& output_grads,
|
||||
torch::Tensor const& softmax_results,
|
||||
float scale_factor) {
|
||||
|
||||
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
|
||||
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
|
||||
|
||||
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
|
||||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
|
||||
"Only fp16 and bf16 are supported");
|
||||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
|
||||
"Only fp16 and bf16 are supported");
|
||||
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
|
||||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
|
||||
"Only fp16 and bf16 are supported");
|
||||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
|
||||
"Only fp16 and bf16 are supported");
|
||||
|
||||
return bwd_cuda(output_grads, softmax_results, scale_factor);
|
||||
}
|
||||
|
||||
} // end namespace scaled_upper_triang_masked_softmax
|
||||
} // end namespace fused_softmax
|
||||
} // end namespace multihead_attn
|
||||
} // end namespace scaled_upper_triang_masked_softmax
|
||||
} // end namespace fused_softmax
|
||||
} // end namespace multihead_attn
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward",
|
||||
m.def("forward",
|
||||
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
|
||||
"Self Multihead Attention scaled, time masked softmax -- Forward.");
|
||||
m.def("backward",
|
||||
"Self Multihead Attention scaled, time masked softmax -- Forward.");
|
||||
m.def("backward",
|
||||
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
|
||||
"Self Multihead Attention scaled, time masked softmax -- Backward.");
|
||||
"Self Multihead Attention scaled, time masked softmax -- Backward.");
|
||||
}
|
||||
|
|
|
@ -2,13 +2,12 @@
 * with minor changes. */

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"

@ -16,15 +15,18 @@ namespace multihead_attn {
|
|||
namespace fused_softmax {
|
||||
namespace scaled_upper_triang_masked_softmax {
|
||||
|
||||
torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) {
|
||||
torch::Tensor fwd_cuda(
|
||||
torch::Tensor const& input,
|
||||
float scale_factor)
|
||||
{
|
||||
// input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
|
||||
const int attn_batches = input.size(0);
|
||||
const int seq_len = input.size(1);
|
||||
TORCH_INTERNAL_ASSERT(seq_len <= 2048);
|
||||
|
||||
// Output
|
||||
// Output
|
||||
auto act_options = input.options().requires_grad(false);
|
||||
torch::Tensor softmax_results =
|
||||
torch::Tensor softmax_results =
|
||||
torch::empty({attn_batches, seq_len, seq_len}, act_options);
|
||||
|
||||
// Softmax Intermediate Result Ptr
|
||||
|
@ -34,42 +36,50 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) {
|
|||
DISPATCH_HALF_AND_BFLOAT(
|
||||
input.scalar_type(),
|
||||
"dispatch_scaled_upper_triang_masked_softmax_forward",
|
||||
dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t,
|
||||
float>(
|
||||
reinterpret_cast<scalar_t*>(softmax_results_ptr),
|
||||
reinterpret_cast<const scalar_t*>(input_ptr), scale_factor, seq_len,
|
||||
seq_len, attn_batches););
|
||||
dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
|
||||
reinterpret_cast<scalar_t*>(softmax_results_ptr),
|
||||
reinterpret_cast<const scalar_t*>(input_ptr),
|
||||
scale_factor,
|
||||
seq_len,
|
||||
seq_len,
|
||||
attn_batches);
|
||||
);
|
||||
return softmax_results;
|
||||
}
|
||||
|
||||
|
||||
torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
|
||||
torch::Tensor const& softmax_results_,
|
||||
float scale_factor) {
|
||||
torch::Tensor bwd_cuda(
|
||||
torch::Tensor const& output_grads_,
|
||||
torch::Tensor const& softmax_results_,
|
||||
float scale_factor) {
|
||||
|
||||
auto output_grads = output_grads_.contiguous();
|
||||
auto softmax_results = softmax_results_.contiguous();
|
||||
|
||||
// output grads is a 3d tensor with dimensions [attn_batches, seq_len,
|
||||
// seq_len]
|
||||
//output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
|
||||
const int attn_batches = output_grads.size(0);
|
||||
const int seq_len = output_grads.size(1);
|
||||
TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));
|
||||
|
||||
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
|
||||
|
||||
// Softmax Grad
|
||||
//Softmax Grad
|
||||
DISPATCH_HALF_AND_BFLOAT(
|
||||
output_grads_.scalar_type(),
|
||||
"dispatch_scaled_upper_triang_masked_softmax_backward",
|
||||
dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t,
|
||||
float>(
|
||||
reinterpret_cast<scalar_t*>(output_grads_ptr),
|
||||
reinterpret_cast<scalar_t*>(output_grads_ptr),
|
||||
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
|
||||
scale_factor, seq_len, seq_len, attn_batches););
|
||||
|
||||
// backward pass is completely in-place
|
||||
dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
|
||||
reinterpret_cast<scalar_t*>(output_grads_ptr),
|
||||
reinterpret_cast<scalar_t*>(output_grads_ptr),
|
||||
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
|
||||
scale_factor,
|
||||
seq_len,
|
||||
seq_len,
|
||||
attn_batches);
|
||||
);
|
||||
|
||||
//backward pass is completely in-place
|
||||
return output_grads;
|
||||
}
|
||||
} // namespace scaled_upper_triang_masked_softmax
|
||||
} // namespace fused_softmax
|
||||
} // namespace multihead_attn
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,8 +24,8 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()
        output, mean, invvar = colossal_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, bias_,
                                                                       ctx.eps)
        output, mean, invvar = colossal_layer_norm_cuda.forward_affine(
            input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)

        return output

@ -72,7 +72,8 @@ class MixedFusedLayerNorm(torch.nn.Module):

    def forward(self, input):

        return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias, self.normalized_shape, self.eps)
        return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias,
                                                  self.normalized_shape, self.eps)

    def __repr__(self):
        return f'MixedFusedLayerNorm(normalized_shape={self.normalized_shape}, eps={self.eps})'

@ -28,7 +28,9 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
            raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')

        scale_t = torch.tensor([scale])
        softmax_results = colossal_scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
        softmax_results = colossal_scaled_upper_triang_masked_softmax.forward(
            inputs, scale_t[0]
        )

        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

@ -41,7 +43,9 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
            raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')

        softmax_results, scale_t = ctx.saved_tensors
        input_grads = colossal_scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
        input_grads = colossal_scaled_upper_triang_masked_softmax.backward(
            output_grads, softmax_results, scale_t[0]
        )

        return input_grads, None

@ -77,7 +81,9 @@ class ScaledMaskedSoftmax(torch.autograd.Function):

        softmax_results, scale_t = ctx.saved_tensors

        input_grads = colossal_scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
        input_grads = colossal_scaled_masked_softmax.backward(
            output_grads, softmax_results, scale_t[0]
        )
        return input_grads, None, None

@ -108,8 +114,9 @@ class FusedScaleMaskSoftmax(nn.Module):
        super(FusedScaleMaskSoftmax, self).__init__()
        self.input_in_fp16 = input_in_fp16
        self.input_in_bf16 = input_in_bf16
        assert not (self.input_in_fp16
                    and self.input_in_bf16), "both fp16 and bf16 flags cannot be active at the same time."
        assert not (
            self.input_in_fp16 and self.input_in_bf16
        ), "both fp16 and bf16 flags cannot be active at the same time."
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion

@ -117,7 +124,9 @@ class FusedScaleMaskSoftmax(nn.Module):
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled"
        assert (
            self.scale is None or softmax_in_fp32
        ), "softmax should be in fp32 when scaled"

    def forward(self, input, mask):
        # [b, np, sq, sk]

@ -131,13 +140,14 @@ class FusedScaleMaskSoftmax(nn.Module):
    def is_kernel_available(self, mask, b, np, sq, sk):
        attn_batches = b * np

        if (self.scaled_masked_softmax_fusion  # user want to fuse
                and self.input_in_float16  # input must be fp16
                and mask is not None  # mask tensor must not be None
                and 16 < sk <= 2048  # sk must be 16 ~ 2048
                and sq % 4 == 0  # sq must be divisor of 4
                and attn_batches % 4 == 0  # np * b must be divisor of 4
           ):
        if (
            self.scaled_masked_softmax_fusion  # user want to fuse
            and self.input_in_float16  # input must be fp16
            and mask is not None  # mask tensor must not be None
            and 16 < sk <= 2048  # sk must be 16 ~ 2048
            and sq % 4 == 0  # sq must be divisor of 4
            and attn_batches % 4 == 0  # np * b must be divisor of 4
        ):
            if 0 <= sk <= 2048:
                batch_per_block = self.get_batch_per_block(sq, sk, b, np)

@ -1,5 +1,6 @@
import torch


###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2)   -> 0.70710678

@ -8,12 +9,10 @@ import torch
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))


@torch.jit.script
def bias_gelu(bias, y):
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


# gradient of tanh approximation of gelu
# gradient of actual gelu is:

@ -24,11 +23,9 @@ def bias_gelu_back(g, bias, y):
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    return ff * g

    return ff*g


class GeLUFunction(torch.autograd.Function):

    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):

@ -41,5 +38,4 @@ class GeLUFunction(torch.autograd.Function):
        tmp = bias_gelu_back(grad_output, bias, input)
        return tmp, tmp


bias_gelu_impl = GeLUFunction.apply
bias_gelu_impl = GeLUFunction.apply
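bias_gelu above uses the tanh approximation of GeLU, while the comment records the exact erf form. A quick, illustrative check (plain PyTorch, not part of the patch) that the two stay numerically close:

import torch

# Compare the tanh-based approximation against the exact erf-based GeLU.
x = torch.linspace(-6, 6, steps=1001, dtype=torch.float64)
gelu_exact = x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
gelu_tanh = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
print((gelu_exact - gelu_tanh).abs().max())  # small value, around 1e-3 or less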
@ -182,7 +182,7 @@ class Linear2D(ParallelLayer):
    def forward(self, x: Tensor) -> Tensor:
        # input: [m/q, n/q, k/q]
        # output: [m/q, n/q, h/q]
        out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)
        out_shape = x.shape[:-1] + (self.hidden_size_per_partition, )

        output = Matmul_AB_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank,
                                    ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank,

@ -337,16 +337,16 @@ class LayerNorm2D(ParallelLayer):

    def forward(self, x: Tensor) -> Tensor:
        with torch.no_grad():
            E_x = torch.sum(x, dim=-1, keepdim=True)  # [b/q, s, 1]
            E_x = torch.sum(x, dim=-1, keepdim=True)  # [b/q, s, 1]
            torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
            E_x /= self.normalized_shape

            # Var_x in the block below is the sum of input^2
            Var_x = torch.sum(x * x, dim=-1, keepdim=True)  # [b/q, s, 1]
            Var_x = torch.sum(x * x, dim=-1, keepdim=True)  # [b/q, s, 1]
            torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
            Var_x /= self.normalized_shape

            Var_x = Var_x - E_x * E_x  # variance of x [b/q, s, 1]
            Var_x = Var_x - E_x * E_x  # variance of x [b/q, s, 1]
            # this time 1/sqrt(Var_x + epsilon)
            Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)

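LayerNorm2D (and LayerNorm2p5D later in this diff) builds its statistics from partial sums that are all-reduced over the row process group, then recovers the variance as E[x^2] - E[x]^2. A single-device sketch of that identity for reference; the sizes below are made up for illustration:

import torch

# mean = E[x], variance = E[x^2] - E[x]^2, then normalize by 1/sqrt(var + eps).
hidden = 1024                                   # stands in for normalized_shape
x = torch.randn(4, 16, hidden, dtype=torch.float64)
eps = 1e-5

E_x = torch.sum(x, dim=-1, keepdim=True) / hidden
Var_x = torch.sum(x * x, dim=-1, keepdim=True) / hidden - E_x * E_x
inv_std = 1.0 / torch.sqrt(Var_x + eps)

# Matches torch's own (non-affine) layer norm up to numerical noise.
ref = torch.nn.functional.layer_norm(x, (hidden,), eps=eps)
print(torch.allclose((x - E_x) * inv_std, ref))  # True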
@ -569,7 +569,7 @@ class PatchEmbedding2D(ParallelLayer):

        output = F.conv2d(input_, weight, bias, stride=self.patch_size)
        if self.flatten:
            output = output.flatten(2).transpose(1, 2)  # BCHW -> BNC
            output = output.flatten(2).transpose(1, 2)  # BCHW -> BNC

        cls_token = all_gather_tensor_2d(self.cls_token, -1, ParallelMode.PARALLEL_2D_COL)
        pos_embed = all_gather_tensor_2d(self.pos_embed, -1, ParallelMode.PARALLEL_2D_COL)

@ -1012,7 +1012,7 @@ class Classifier2D(ParallelLayer):
        destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        out_shape = input_.shape[:-1] + (self.num_classes,)
        out_shape = input_.shape[:-1] + (self.num_classes, )

        return classifier_2d(input_, self.weight, self.bias, self.summa_dim, out_shape, self.row_rank, self.col_rank,
                             ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank,

@ -1186,7 +1186,7 @@ class VocabParallelClassifier2D(ParallelLayer):
    def forward(self, x: Tensor) -> Tensor:
        # input: [m/q, n/q, k/q]
        # output: [m/q, n/q, h/q]
        out_shape = x.shape[:-1] + (self.output_size_per_partition,)
        out_shape = x.shape[:-1] + (self.output_size_per_partition, )

        output = Matmul_ABT_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank,
                                     ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,

@ -189,7 +189,7 @@ class Linear2p5D(ParallelLayer):
|
|||
def forward(self, x: Tensor) -> Tensor:
|
||||
# input: [m/dq, n/q, k/q]
|
||||
# output: [m/dq, n/q, h/q]
|
||||
out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)
|
||||
out_shape = x.shape[:-1] + (self.hidden_size_per_partition, )
|
||||
|
||||
output = Matmul_AB_2p5D.apply(
|
||||
x,
|
||||
|
@@ -254,7 +254,7 @@ class LayerNorm2p5D(ParallelLayer):
        self.tesseract_dim, _ = get_tesseract_dim_dep_from_env()

        # partitioning dimension
        self.partitioned_partition = divide(normalized_shape, self.tesseract_dim)  # *

        # create parameters
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
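divide(normalized_shape, self.tesseract_dim) splits the normalized dimension across the tesseract dimension and is presumably a checked integer division. A minimal sketch of such a helper follows; the name divide_evenly is ours, not the library's.

# Sketch of a checked even-split helper of the kind assumed above:
# fail loudly when the hidden size is not divisible by the parallel degree.
def divide_evenly(numerator: int, denominator: int) -> int:
    assert numerator % denominator == 0, \
        f'{numerator} is not divisible by {denominator}'
    return numerator // denominator

partitioned_partition = divide_evenly(1024, 4)   # -> 256 elements per partition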
@@ -357,16 +357,16 @@ class LayerNorm2p5D(ParallelLayer):

    def forward(self, x: Tensor) -> Tensor:
        with torch.no_grad():
            E_x = torch.sum(x, dim=-1, keepdim=True)  # [b/q, s, 1]
            torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
            E_x /= self.normalized_shape

            # Var_x in the block below is the sum of input^2
            Var_x = torch.sum(x * x, dim=-1, keepdim=True)  # [b/q, s, 1]
            torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
            Var_x /= self.normalized_shape

            Var_x = Var_x - E_x * E_x  # variance of x [b/q, s, 1]
            # this time 1/sqrt(Var_x + epsilon)
            Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
@@ -589,7 +589,7 @@ class PatchEmbedding2p5D(ParallelLayer):

        output = F.conv2d(input_, weight, bias, stride=self.patch_size)
        if self.flatten:
            output = output.flatten(2).transpose(1, 2)  # BCHW -> BNC

        cls_token = all_gather_tensor_2p5d(self.cls_token, -1, ParallelMode.PARALLEL_2P5D_COL)
        pos_embed = all_gather_tensor_2p5d(self.pos_embed, -1, ParallelMode.PARALLEL_2P5D_COL)
@@ -1038,7 +1038,7 @@ class Classifier2p5D(ParallelLayer):
        destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        out_shape = input_.shape[:-1] + (self.num_classes,)
        out_shape = input_.shape[:-1] + (self.num_classes, )

        return classifier_2p5d(input_, self.weight, self.bias, self.tesseract_dim, out_shape, self.row_rank,
                               self.col_rank, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL,
@@ -1172,7 +1172,7 @@ class VocabParallelClassifier2p5D(ParallelLayer):
    def forward(self, x: Tensor) -> Tensor:
        # input: [m/dq, n/q, k/q]
        # output: [m/dq, n/q, h/q]
        out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)
        out_shape = x.shape[:-1] + (self.hidden_size_per_partition, )

        output = Matmul_ABT_2p5D.apply(
            x,
@@ -53,8 +53,8 @@ class LayerNorm3D(ParallelLayer):
        self.weight = Parameter(
            torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
        if bias:
            self.bias = Parameter(
                torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
            self.bias = Parameter(torch.zeros(self.normalized_shape_per_partition,
                                              device=get_current_device(), dtype=dtype))
        else:
            self.bias = None
        self.variance_epsilon = eps
@@ -854,7 +854,7 @@ class PatchEmbedding3D(ParallelLayer):
        input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
        output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
        if self.flatten:
            output = output.flatten(2).transpose(1, 2)  # BCHW -> BNC

        cls_token = self.cls_token.expand(output.shape[0], -1, -1)
        output = torch.cat((cls_token, output), dim=1)
@@ -13,8 +13,7 @@ from torch import Tensor, nn


class CheckpointModule(nn.Module):

    def __init__(self, checkpoint: bool = True, offload: bool = False):
    def __init__(self, checkpoint: bool = True, offload : bool = False):
        super().__init__()
        self.checkpoint = checkpoint
        self._use_checkpoint = checkpoint
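CheckpointModule exposes a checkpoint flag; in PyTorch code such a flag typically gates activation checkpointing, trading memory for a recomputed forward pass during backward. An illustrative, self-contained sketch (not the ColossalAI implementation):

# Toy module showing how a checkpoint flag is commonly wired up with
# torch.utils.checkpoint: when enabled, activations of `body` are not stored
# and the forward is re-run during backward.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TinyCheckpointed(nn.Module):
    def __init__(self, use_checkpoint: bool = True):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.body = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))

    def forward(self, x):
        if self.use_checkpoint and x.requires_grad:
            # use_reentrant=False needs a reasonably recent PyTorch release
            return checkpoint(self.body, x, use_reentrant=False)
        return self.body(x)

x = torch.randn(4, 16, requires_grad=True)
TinyCheckpointed()(x).sum().backward()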
@@ -79,7 +78,6 @@ def get_tensor_parallel_mode():


def _ntuple(n):

    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
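The _ntuple helper is truncated by the hunk; its usual completion (as in timm and torchvision, and presumably here) repeats a scalar n times while letting iterables pass through unchanged:

# Likely completion of the truncated _ntuple helper above (an assumption,
# following the common timm/torchvision pattern): scalars are repeated n times,
# iterables are returned as-is.
import collections.abc
from itertools import repeat

def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse

to_2tuple = _ntuple(2)
assert to_2tuple(16) == (16, 16)          # e.g. a square patch size
assert to_2tuple((16, 32)) == (16, 32)    # an explicit (h, w) pair passes through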