[setup] support more cuda architectures (#920)

* support more cuda archs

* polish code
pull/922/head
ver217 2022-05-09 10:56:45 +08:00 committed by GitHub
parent 5d8f1262fb
commit 1d625fcd36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 8 additions and 9 deletions

View File

@ -1,7 +1,6 @@
import os
import subprocess
import sys
import re
from setuptools import find_packages, setup
# ninja build does not work unless include_dirs are abs path
@ -138,13 +137,13 @@ if build_cuda_ext:
'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + version_dependent_macros + extra_cuda_flags)
})
cc_flag = ['-gencode', 'arch=compute_70,code=sm_70']
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
cc_flag = []
for arch in torch.cuda.get_arch_list():
res = re.search(r'sm_(\d+)', arch)
if res:
arch_cap = res[1]
if int(arch_cap) >= 60:
cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}'])
extra_cuda_flags = ['-lineinfo']