Making large AI models cheaper, faster and more accessible
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 

42 lines
1.3 KiB

import re
import subprocess
from typing import List
def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
def get_cuda_cc_flag() -> List:
"""get_cuda_cc_flag
cc flag for your GPU arch
"""
# only import torch when needed
# this is to avoid importing torch when building on a machine without torch pre-installed
# one case is to build wheel for pypi release
import torch
cc_flag = []
for arch in torch.cuda.get_arch_list():
res = re.search(r'sm_(\d+)', arch)
if res:
arch_cap = res[1]
if int(arch_cap) >= 60:
cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}'])
return cc_flag
def append_nvcc_threads(nvcc_extra_args):
from torch.utils.cpp_extension import CUDA_HOME
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
return nvcc_extra_args + ["--threads", "4"]
return nvcc_extra_args