diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py index 4a4b6c685..2eb96cdfd 100644 --- a/colossalai/utils/common.py +++ b/colossalai/utils/common.py @@ -1,7 +1,9 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import os import random import socket +from pathlib import Path from typing import List, Union import torch @@ -39,6 +41,13 @@ def print_rank_0(msg: str, logger=None): logger.info(msg) +def ensure_path_exists(filename: str): + # ensure the path exists + dirpath = os.path.dirname(filename) + if not os.path.exists(dirpath): + Path(dirpath).mkdir(parents=True, exist_ok=True) + + def free_port(): while True: try: @@ -103,7 +112,6 @@ def conditional_context(context_manager, enable=True): class model_branch_context(object): - def __enter__(self): self.env_status = env.save() @@ -123,7 +131,7 @@ def _calc_l2_norm(grads): colossal_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], - False # no per-parameter norm + False # no per-parameter norm ) return norm