mirror of https://github.com/hpcaitech/ColossalAI
[setup] support more cuda architectures (#920)
* support more cuda archs * polish codepull/922/head
parent
5d8f1262fb
commit
1d625fcd36
17
setup.py
17
setup.py
|
@ -1,7 +1,6 @@
|
|||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import re
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
# ninja build does not work unless include_dirs are abs path
|
||||
|
@ -138,13 +137,13 @@ if build_cuda_ext:
|
|||
'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + version_dependent_macros + extra_cuda_flags)
|
||||
})
|
||||
|
||||
|
||||
|
||||
cc_flag = ['-gencode', 'arch=compute_70,code=sm_70']
|
||||
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
|
||||
if int(bare_metal_major) >= 11:
|
||||
cc_flag.append('-gencode')
|
||||
cc_flag.append('arch=compute_80,code=sm_80')
|
||||
cc_flag = []
|
||||
for arch in torch.cuda.get_arch_list():
|
||||
res = re.search(r'sm_(\d+)', arch)
|
||||
if res:
|
||||
arch_cap = res[1]
|
||||
if int(arch_cap) >= 60:
|
||||
cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}'])
|
||||
|
||||
extra_cuda_flags = ['-lineinfo']
|
||||
|
||||
|
|
Loading…
Reference in New Issue