[NFC] polish .github/workflows/scripts/build_colossalai_wheel.py code style (#1721)

pull/1743/head
Arsmart1 2022-10-18 09:37:57 +08:00 committed by Frank Lee
parent 730f88f8e1
commit 8860d37846
1 changed files with 12 additions and 9 deletions

View File

@ -7,7 +7,6 @@ import subprocess
from packaging import version from packaging import version
from functools import cmp_to_key from functools import cmp_to_key
WHEEL_TEXT_ROOT_URL = 'https://github.com/hpcaitech/public_assets/tree/main/colossalai/torch_build/torch_wheels' WHEEL_TEXT_ROOT_URL = 'https://github.com/hpcaitech/public_assets/tree/main/colossalai/torch_build/torch_wheels'
RAW_TEXT_FILE_PREFIX = 'https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/torch_build/torch_wheels' RAW_TEXT_FILE_PREFIX = 'https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/torch_build/torch_wheels'
CUDA_HOME = os.environ['CUDA_HOME'] CUDA_HOME = os.environ['CUDA_HOME']
@ -16,10 +15,15 @@ CUDA_HOME = os.environ['CUDA_HOME']
def parse_args(): def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--torch_version', type=str) parser.add_argument('--torch_version', type=str)
parser.add_argument('--nightly', action='store_true', parser.add_argument(
help='whether this build is for nightly release, if True, will only build on the latest PyTorch version and Python 3.8') '--nightly',
action='store_true',
help=
'whether this build is for nightly release, if True, will only build on the latest PyTorch version and Python 3.8'
)
return parser.parse_args() return parser.parse_args()
def get_cuda_bare_metal_version(): def get_cuda_bare_metal_version():
raw_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True) raw_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split() output = raw_output.split()
@ -30,6 +34,7 @@ def get_cuda_bare_metal_version():
return bare_metal_major, bare_metal_minor return bare_metal_major, bare_metal_minor
def all_wheel_info(): def all_wheel_info():
page_text = requests.get(WHEEL_TEXT_ROOT_URL).text page_text = requests.get(WHEEL_TEXT_ROOT_URL).text
soup = BeautifulSoup(page_text) soup = BeautifulSoup(page_text)
@ -63,6 +68,7 @@ def all_wheel_info():
wheel_info[torch_version][cuda_version][python_version] = dict(method=method, url=url, flags=flags) wheel_info[torch_version][cuda_version][python_version] = dict(method=method, url=url, flags=flags)
return wheel_info return wheel_info
def build_colossalai(wheel_info): def build_colossalai(wheel_info):
cuda_version_major, cuda_version_minor = get_cuda_bare_metal_version() cuda_version_major, cuda_version_minor = get_cuda_bare_metal_version()
cuda_version_on_host = f'{cuda_version_major}.{cuda_version_minor}' cuda_version_on_host = f'{cuda_version_major}.{cuda_version_minor}'
@ -78,12 +84,14 @@ def build_colossalai(wheel_info):
cmd = f'bash ./build_colossalai_wheel.sh {method} {url} {filename} {cuda_version} {python_version} {torch_version} {flags}' cmd = f'bash ./build_colossalai_wheel.sh {method} {url} {filename} {cuda_version} {python_version} {torch_version} {flags}'
os.system(cmd) os.system(cmd)
def main(): def main():
args = parse_args() args = parse_args()
wheel_info = all_wheel_info() wheel_info = all_wheel_info()
# filter wheels on condition # filter wheels on condition
all_torch_versions = list(wheel_info.keys()) all_torch_versions = list(wheel_info.keys())
def _compare_version(a, b): def _compare_version(a, b):
if version.parse(a) > version.parse(b): if version.parse(a) > version.parse(b):
return 1 return 1
@ -105,11 +113,6 @@ def main():
build_colossalai(wheel_info) build_colossalai(wheel_info)
if __name__ == '__main__': if __name__ == '__main__':
main() main()