diff --git a/.github/workflows/scripts/build_colossalai_wheel.py b/.github/workflows/scripts/build_colossalai_wheel.py index 2d33238e2..5a2db0c87 100644 --- a/.github/workflows/scripts/build_colossalai_wheel.py +++ b/.github/workflows/scripts/build_colossalai_wheel.py @@ -7,7 +7,6 @@ import subprocess from packaging import version from functools import cmp_to_key - WHEEL_TEXT_ROOT_URL = 'https://github.com/hpcaitech/public_assets/tree/main/colossalai/torch_build/torch_wheels' RAW_TEXT_FILE_PREFIX = 'https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/torch_build/torch_wheels' CUDA_HOME = os.environ['CUDA_HOME'] @@ -16,10 +15,15 @@ CUDA_HOME = os.environ['CUDA_HOME'] def parse_args(): parser = argparse.ArgumentParser() parser.add_argument('--torch_version', type=str) - parser.add_argument('--nightly', action='store_true', - help='whether this build is for nightly release, if True, will only build on the latest PyTorch version and Python 3.8') + parser.add_argument( + '--nightly', + action='store_true', + help= + 'whether this build is for nightly release, if True, will only build on the latest PyTorch version and Python 3.8' + ) return parser.parse_args() + def get_cuda_bare_metal_version(): raw_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True) output = raw_output.split() @@ -30,6 +34,7 @@ def get_cuda_bare_metal_version(): return bare_metal_major, bare_metal_minor + def all_wheel_info(): page_text = requests.get(WHEEL_TEXT_ROOT_URL).text soup = BeautifulSoup(page_text) @@ -63,6 +68,7 @@ def all_wheel_info(): wheel_info[torch_version][cuda_version][python_version] = dict(method=method, url=url, flags=flags) return wheel_info + def build_colossalai(wheel_info): cuda_version_major, cuda_version_minor = get_cuda_bare_metal_version() cuda_version_on_host = f'{cuda_version_major}.{cuda_version_minor}' @@ -78,12 +84,14 @@ def build_colossalai(wheel_info): cmd = f'bash ./build_colossalai_wheel.sh {method} {url} {filename} {cuda_version} {python_version} {torch_version} {flags}' os.system(cmd) + def main(): args = parse_args() wheel_info = all_wheel_info() # filter wheels on condition all_torch_versions = list(wheel_info.keys()) + def _compare_version(a, b): if version.parse(a) > version.parse(b): return 1 @@ -105,11 +113,6 @@ def main(): build_colossalai(wheel_info) + if __name__ == '__main__': main() - - - - - -