[NFC] polish .github/workflows/scripts/build_colossalai_wheel.py code style (#1721)

pull/1743/head
Arsmart1 2022-10-18 09:37:57 +08:00 committed by Frank Lee
parent 730f88f8e1
commit 8860d37846
1 changed files with 12 additions and 9 deletions

View File

@ -7,7 +7,6 @@ import subprocess
from packaging import version
from functools import cmp_to_key
WHEEL_TEXT_ROOT_URL = 'https://github.com/hpcaitech/public_assets/tree/main/colossalai/torch_build/torch_wheels'
RAW_TEXT_FILE_PREFIX = 'https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/torch_build/torch_wheels'
CUDA_HOME = os.environ['CUDA_HOME']
@ -16,10 +15,15 @@ CUDA_HOME = os.environ['CUDA_HOME']
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--torch_version', type=str)
parser.add_argument('--nightly', action='store_true',
help='whether this build is for nightly release, if True, will only build on the latest PyTorch version and Python 3.8')
parser.add_argument(
'--nightly',
action='store_true',
help=
'whether this build is for nightly release, if True, will only build on the latest PyTorch version and Python 3.8'
)
return parser.parse_args()
def get_cuda_bare_metal_version():
raw_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
@ -30,6 +34,7 @@ def get_cuda_bare_metal_version():
return bare_metal_major, bare_metal_minor
def all_wheel_info():
page_text = requests.get(WHEEL_TEXT_ROOT_URL).text
soup = BeautifulSoup(page_text)
@ -63,6 +68,7 @@ def all_wheel_info():
wheel_info[torch_version][cuda_version][python_version] = dict(method=method, url=url, flags=flags)
return wheel_info
def build_colossalai(wheel_info):
cuda_version_major, cuda_version_minor = get_cuda_bare_metal_version()
cuda_version_on_host = f'{cuda_version_major}.{cuda_version_minor}'
@ -78,12 +84,14 @@ def build_colossalai(wheel_info):
cmd = f'bash ./build_colossalai_wheel.sh {method} {url} {filename} {cuda_version} {python_version} {torch_version} {flags}'
os.system(cmd)
def main():
args = parse_args()
wheel_info = all_wheel_info()
# filter wheels on condition
all_torch_versions = list(wheel_info.keys())
def _compare_version(a, b):
if version.parse(a) > version.parse(b):
return 1
@ -105,11 +113,6 @@ def main():
build_colossalai(wheel_info)
if __name__ == '__main__':
main()