|
|
|
@ -1,7 +1,9 @@
|
|
|
|
|
#!/usr/bin/env python |
|
|
|
|
# -*- encoding: utf-8 -*- |
|
|
|
|
import os |
|
|
|
|
import random |
|
|
|
|
import socket |
|
|
|
|
from pathlib import Path |
|
|
|
|
from typing import List, Union |
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
@ -39,6 +41,13 @@ def print_rank_0(msg: str, logger=None):
|
|
|
|
|
logger.info(msg) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ensure_path_exists(filename: str): |
|
|
|
|
# ensure the path exists |
|
|
|
|
dirpath = os.path.dirname(filename) |
|
|
|
|
if not os.path.exists(dirpath): |
|
|
|
|
Path(dirpath).mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def free_port(): |
|
|
|
|
while True: |
|
|
|
|
try: |
|
|
|
@ -103,7 +112,6 @@ def conditional_context(context_manager, enable=True):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class model_branch_context(object): |
|
|
|
|
|
|
|
|
|
def __enter__(self): |
|
|
|
|
self.env_status = env.save() |
|
|
|
|
|
|
|
|
@ -123,7 +131,7 @@ def _calc_l2_norm(grads):
|
|
|
|
|
colossal_C.multi_tensor_l2norm, |
|
|
|
|
dummy_overflow_buf, |
|
|
|
|
[grads], |
|
|
|
|
False # no per-parameter norm |
|
|
|
|
False # no per-parameter norm |
|
|
|
|
) |
|
|
|
|
return norm |
|
|
|
|
|
|
|
|
|