mirror of https://github.com/hpcaitech/ColossalAI
[setup] remove torch dependency (#2333)
parent
89f26331e9
commit
8711310cda
|
@ -3,14 +3,18 @@ import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def get_cuda_cc_flag() -> List:
|
def get_cuda_cc_flag() -> List:
|
||||||
"""get_cuda_cc_flag
|
"""get_cuda_cc_flag
|
||||||
|
|
||||||
cc flag for your GPU arch
|
cc flag for your GPU arch
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# only import torch when needed
|
||||||
|
# this is to avoid importing torch when building on a machine without torch pre-installed
|
||||||
|
# one case is to build wheel for pypi release
|
||||||
|
import torch
|
||||||
|
|
||||||
cc_flag = []
|
cc_flag = []
|
||||||
for arch in torch.cuda.get_arch_list():
|
for arch in torch.cuda.get_arch_list():
|
||||||
res = re.search(r'sm_(\d+)', arch)
|
res = re.search(r'sm_(\d+)', arch)
|
||||||
|
|
5
setup.py
5
setup.py
|
@ -15,8 +15,9 @@ try:
|
||||||
if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 10):
|
if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 10):
|
||||||
raise RuntimeError("Colossal-AI requires Pytorch 1.10 or newer.\n"
|
raise RuntimeError("Colossal-AI requires Pytorch 1.10 or newer.\n"
|
||||||
"The latest stable release can be obtained from https://pytorch.org/")
|
"The latest stable release can be obtained from https://pytorch.org/")
|
||||||
|
TORCH_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ModuleNotFoundError('torch is not found. You need to install PyTorch before installing Colossal-AI.')
|
TORCH_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
# ninja build does not work unless include_dirs are abs path
|
# ninja build does not work unless include_dirs are abs path
|
||||||
|
@ -24,7 +25,7 @@ this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
build_cuda_ext = True
|
build_cuda_ext = True
|
||||||
ext_modules = []
|
ext_modules = []
|
||||||
|
|
||||||
if int(os.environ.get('NO_CUDA_EXT', '0')) == 1:
|
if int(os.environ.get('NO_CUDA_EXT', '0')) == 1 or not TORCH_AVAILABLE:
|
||||||
build_cuda_ext = False
|
build_cuda_ext = False
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue