/* Copyright 2021 The LightSeq Team
   Copyright Microsoft DeepSpeed
   This file is adapted from Microsoft DeepSpeed
   Licensed under the MIT License.
*/
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
#include <assert.h>
|
|
|
|
#include <cublas_v2.h>
|
|
|
|
#include <cuda.h>
|
|
|
|
#include <cuda_fp16.h>
|
|
|
|
#include <cuda_runtime.h>
|
|
|
|
#include <mma.h>
|
|
|
|
#include <stdio.h>
|
|
|
|
|
|
|
|
/*
 * GEMM wrapper prototype, single-precision overload: A and B are taken as
 * `const float *`, C as a non-const `float *`.
 *
 * Only the declaration lives in this header, so the notes below are limited
 * to what the signature itself shows.
 * NOTE(review): presumably forwards to cublasGemmEx -- confirm against the
 * definition.
 *
 *   handle          cuBLAS handle passed by value.
 *   transa, transb  cublasOperation_t selectors for A and B (presumably the
 *                   transpose modes -- confirm).
 *   m, n, k         int problem sizes (presumably the usual cuBLAS GEMM
 *                   dimensions -- confirm).
 *   alpha, beta     pointers to float scalars.
 *   A, B            read-only float operands.
 *   C               non-const float pointer (presumably the result buffer).
 *   algo            cublasGemmAlgo_t; defaults to CUBLAS_GEMM_DEFAULT when
 *                   the caller omits it.
 *
 * Returns an int whose meaning is defined by the implementation, which is
 * not visible from this header.
 */
int cublas_gemm_ex(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    const float *alpha,
    const float *beta,
    const float *A,
    const float *B,
    float *C,
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
|
|
|
|
|
|
|
|
/*
 * GEMM wrapper prototype, half-precision overload: A and B are taken as
 * `const __half *`, C as a non-const `__half *`.
 *
 * The scaling factors alpha and beta remain `const float *` in this
 * overload; only the matrix operands are __half.
 * NOTE(review): presumably forwards to cublasGemmEx -- confirm against the
 * definition, which is not part of this header.
 *
 *   handle          cuBLAS handle passed by value.
 *   transa, transb  cublasOperation_t selectors for A and B (presumably the
 *                   transpose modes -- confirm).
 *   m, n, k         int problem sizes (presumably the usual cuBLAS GEMM
 *                   dimensions -- confirm).
 *   alpha, beta     pointers to float scalars.
 *   A, B            read-only __half operands.
 *   C               non-const __half pointer (presumably the result buffer).
 *   algo            cublasGemmAlgo_t; defaults to
 *                   CUBLAS_GEMM_DEFAULT_TENSOR_OP when the caller omits it.
 *
 * Returns an int whose meaning is defined by the implementation.
 */
int cublas_gemm_ex(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    const float *alpha,
    const float *beta,
    const __half *A,
    const __half *B,
    __half *C,
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
|
|
|
|
|
|
|
/*
 * Strided batched GEMM wrapper prototype, single-precision overload: A and B
 * are taken as `const float *`, C as a non-const `float *`.
 *
 * Argument order differs from cublas_gemm_ex: the operation selectors
 * (op_A, op_B) come after the matrix pointers rather than before the sizes.
 * NOTE(review): presumably forwards to cublasGemmStridedBatchedEx -- confirm
 * against the definition, which is not part of this header.
 * NOTE(review): the strides are plain int here; if the underlying cuBLAS call
 * takes 64-bit strides, very large per-batch offsets could overflow -- verify
 * against callers.
 *
 *   handle                        cuBLAS handle passed by value.
 *   m, n, k                       int problem sizes (presumably the per-batch
 *                                 GEMM dimensions -- confirm).
 *   alpha, beta                   pointers to float scalars.
 *   A, B                          read-only float operands.
 *   C                             non-const float pointer (presumably the
 *                                 result buffer).
 *   op_A, op_B                    cublasOperation_t selectors for A and B.
 *   stride_A, stride_B, stride_C  int strides (presumably the element offset
 *                                 between consecutive batch entries).
 *   batch                         int batch count.
 *   algo                          cublasGemmAlgo_t; defaults to
 *                                 CUBLAS_GEMM_DEFAULT when omitted.
 *
 * Returns an int whose meaning is defined by the implementation.
 */
int cublas_strided_batched_gemm(
    cublasHandle_t handle,
    int m,
    int n,
    int k,
    const float *alpha,
    const float *beta,
    const float *A,
    const float *B,
    float *C,
    cublasOperation_t op_A,
    cublasOperation_t op_B,
    int stride_A,
    int stride_B,
    int stride_C,
    int batch,
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
|
|
|
|
|
|
|
|
/*
 * Strided batched GEMM wrapper prototype, half-precision overload: A and B
 * are taken as `const __half *`, C as a non-const `__half *`.
 *
 * The scaling factors alpha and beta remain `const float *` in this
 * overload; only the matrix operands are __half.
 * NOTE(review): presumably forwards to cublasGemmStridedBatchedEx -- confirm
 * against the definition, which is not part of this header.
 * NOTE(review): the strides are plain int here; if the underlying cuBLAS call
 * takes 64-bit strides, very large per-batch offsets could overflow -- verify
 * against callers.
 *
 *   handle                        cuBLAS handle passed by value.
 *   m, n, k                       int problem sizes (presumably the per-batch
 *                                 GEMM dimensions -- confirm).
 *   alpha, beta                   pointers to float scalars.
 *   A, B                          read-only __half operands.
 *   C                             non-const __half pointer (presumably the
 *                                 result buffer).
 *   op_A, op_B                    cublasOperation_t selectors for A and B.
 *   stride_A, stride_B, stride_C  int strides (presumably the element offset
 *                                 between consecutive batch entries).
 *   batch                         int batch count.
 *   algo                          cublasGemmAlgo_t; defaults to
 *                                 CUBLAS_GEMM_DEFAULT_TENSOR_OP when omitted.
 *
 * Returns an int whose meaning is defined by the implementation.
 */
int cublas_strided_batched_gemm(
    cublasHandle_t handle,
    int m,
    int n,
    int k,
    const float *alpha,
    const float *beta,
    const __half *A,
    const __half *B,
    __half *C,
    cublasOperation_t op_A,
    cublasOperation_t op_B,
    int stride_A,
    int stride_B,
    int stride_C,
    int batch,
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|