mirror of https://github.com/hpcaitech/ColossalAI
moved ensure_path_exists to utils.common (#591)
parent
e956d93ac2
commit
54e688b623
|
@ -1,7 +1,9 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
import os
|
||||
import random
|
||||
import socket
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
|
@ -39,6 +41,13 @@ def print_rank_0(msg: str, logger=None):
|
|||
logger.info(msg)
|
||||
|
||||
|
||||
def ensure_path_exists(filename: str):
|
||||
# ensure the path exists
|
||||
dirpath = os.path.dirname(filename)
|
||||
if not os.path.exists(dirpath):
|
||||
Path(dirpath).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def free_port():
|
||||
while True:
|
||||
try:
|
||||
|
@ -103,7 +112,6 @@ def conditional_context(context_manager, enable=True):
|
|||
|
||||
|
||||
class model_branch_context(object):
|
||||
|
||||
def __enter__(self):
|
||||
self.env_status = env.save()
|
||||
|
||||
|
@ -123,7 +131,7 @@ def _calc_l2_norm(grads):
|
|||
colossal_C.multi_tensor_l2norm,
|
||||
dummy_overflow_buf,
|
||||
[grads],
|
||||
False # no per-parameter norm
|
||||
False # no per-parameter norm
|
||||
)
|
||||
return norm
|
||||
|
||||
|
|
Loading…
Reference in New Issue