/* Copyright 2021 The LightSeq Team Copyright Microsoft DeepSpeed This file is adapted from Microsoft DeepSpeed */ #pragma once #include #include #include #include #include #include #include int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const float *alpha, const float *beta, const float *A, const float *B, float *C, cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT); int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const float *alpha, const float *beta, const __half *A, const __half *B, __half *C, cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP); int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k, const float *alpha, const float *beta, const float *A, const float *B, float *C, cublasOperation_t op_A, cublasOperation_t op_B, int stride_A, int stride_B, int stride_C, int batch, cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT); int cublas_strided_batched_gemm( cublasHandle_t handle, int m, int n, int k, const float *alpha, const float *beta, const __half *A, const __half *B, __half *C, cublasOperation_t op_A, cublasOperation_t op_B, int stride_A, int stride_B, int stride_C, int batch, cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);