From 1d625fcd36783d866b3d1d1b208e0f2291b545d3 Mon Sep 17 00:00:00 2001 From: ver217 Date: Mon, 9 May 2022 10:56:45 +0800 Subject: [PATCH] [setup] support more cuda architectures (#920) * support more cuda archs * polish code --- setup.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index 748361154..10906b61c 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,6 @@ import os import subprocess -import sys - +import re from setuptools import find_packages, setup # ninja build does not work unless include_dirs are abs path @@ -138,13 +137,13 @@ if build_cuda_ext: 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + version_dependent_macros + extra_cuda_flags) }) - - - cc_flag = ['-gencode', 'arch=compute_70,code=sm_70'] - _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME) - if int(bare_metal_major) >= 11: - cc_flag.append('-gencode') - cc_flag.append('arch=compute_80,code=sm_80') + cc_flag = [] + for arch in torch.cuda.get_arch_list(): + res = re.search(r'sm_(\d+)', arch) + if res: + arch_cap = res[1] + if int(arch_cap) >= 60: + cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}']) extra_cuda_flags = ['-lineinfo']