mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish .github/workflows/scripts/build_colossalai_wheel.py code style (#1721)
parent
730f88f8e1
commit
8860d37846
|
@ -7,7 +7,6 @@ import subprocess
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from functools import cmp_to_key
|
from functools import cmp_to_key
|
||||||
|
|
||||||
|
|
||||||
WHEEL_TEXT_ROOT_URL = 'https://github.com/hpcaitech/public_assets/tree/main/colossalai/torch_build/torch_wheels'
|
WHEEL_TEXT_ROOT_URL = 'https://github.com/hpcaitech/public_assets/tree/main/colossalai/torch_build/torch_wheels'
|
||||||
RAW_TEXT_FILE_PREFIX = 'https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/torch_build/torch_wheels'
|
RAW_TEXT_FILE_PREFIX = 'https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/torch_build/torch_wheels'
|
||||||
CUDA_HOME = os.environ['CUDA_HOME']
|
CUDA_HOME = os.environ['CUDA_HOME']
|
||||||
|
@ -16,10 +15,15 @@ CUDA_HOME = os.environ['CUDA_HOME']
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--torch_version', type=str)
|
parser.add_argument('--torch_version', type=str)
|
||||||
parser.add_argument('--nightly', action='store_true',
|
parser.add_argument(
|
||||||
help='whether this build is for nightly release, if True, will only build on the latest PyTorch version and Python 3.8')
|
'--nightly',
|
||||||
|
action='store_true',
|
||||||
|
help=
|
||||||
|
'whether this build is for nightly release, if True, will only build on the latest PyTorch version and Python 3.8'
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def get_cuda_bare_metal_version():
|
def get_cuda_bare_metal_version():
|
||||||
raw_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True)
|
raw_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True)
|
||||||
output = raw_output.split()
|
output = raw_output.split()
|
||||||
|
@ -30,6 +34,7 @@ def get_cuda_bare_metal_version():
|
||||||
|
|
||||||
return bare_metal_major, bare_metal_minor
|
return bare_metal_major, bare_metal_minor
|
||||||
|
|
||||||
|
|
||||||
def all_wheel_info():
|
def all_wheel_info():
|
||||||
page_text = requests.get(WHEEL_TEXT_ROOT_URL).text
|
page_text = requests.get(WHEEL_TEXT_ROOT_URL).text
|
||||||
soup = BeautifulSoup(page_text)
|
soup = BeautifulSoup(page_text)
|
||||||
|
@ -63,6 +68,7 @@ def all_wheel_info():
|
||||||
wheel_info[torch_version][cuda_version][python_version] = dict(method=method, url=url, flags=flags)
|
wheel_info[torch_version][cuda_version][python_version] = dict(method=method, url=url, flags=flags)
|
||||||
return wheel_info
|
return wheel_info
|
||||||
|
|
||||||
|
|
||||||
def build_colossalai(wheel_info):
|
def build_colossalai(wheel_info):
|
||||||
cuda_version_major, cuda_version_minor = get_cuda_bare_metal_version()
|
cuda_version_major, cuda_version_minor = get_cuda_bare_metal_version()
|
||||||
cuda_version_on_host = f'{cuda_version_major}.{cuda_version_minor}'
|
cuda_version_on_host = f'{cuda_version_major}.{cuda_version_minor}'
|
||||||
|
@ -78,12 +84,14 @@ def build_colossalai(wheel_info):
|
||||||
cmd = f'bash ./build_colossalai_wheel.sh {method} {url} {filename} {cuda_version} {python_version} {torch_version} {flags}'
|
cmd = f'bash ./build_colossalai_wheel.sh {method} {url} {filename} {cuda_version} {python_version} {torch_version} {flags}'
|
||||||
os.system(cmd)
|
os.system(cmd)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
wheel_info = all_wheel_info()
|
wheel_info = all_wheel_info()
|
||||||
|
|
||||||
# filter wheels on condition
|
# filter wheels on condition
|
||||||
all_torch_versions = list(wheel_info.keys())
|
all_torch_versions = list(wheel_info.keys())
|
||||||
|
|
||||||
def _compare_version(a, b):
|
def _compare_version(a, b):
|
||||||
if version.parse(a) > version.parse(b):
|
if version.parse(a) > version.parse(b):
|
||||||
return 1
|
return 1
|
||||||
|
@ -105,11 +113,6 @@ def main():
|
||||||
|
|
||||||
build_colossalai(wheel_info)
|
build_colossalai(wheel_info)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue