mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
79 lines
1.9 KiB
79 lines
1.9 KiB
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
import functools
|
|
import os
|
|
import random
|
|
from contextlib import contextmanager
|
|
from pathlib import Path
|
|
from typing import Callable
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from colossalai.accelerator import get_accelerator
|
|
|
|
|
|
def get_current_device():
    """
    Backward-compatibility shim: forward to the accelerator's
    ``get_current_device`` API and return whatever it reports.
    """
    accelerator = get_accelerator()
    return accelerator.get_current_device()
|
|
|
|
|
|
def ensure_path_exists(filename: str):
    """Create the parent directory of ``filename`` if nothing exists there yet.

    Only the directory part of ``filename`` is created (including missing
    intermediate directories); the file itself is never touched.
    """
    parent = os.path.dirname(filename)
    if os.path.exists(parent):
        # Something is already present at that location; leave it alone.
        return
    Path(parent).mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
@contextmanager
def conditional_context(context_manager, enable=True):
    """Run the ``with`` body inside ``context_manager`` only when ``enable`` is truthy.

    When ``enable`` is falsy the body runs bare and ``context_manager`` is
    never entered.
    """
    if not enable:
        yield
        return

    with context_manager:
        yield
|
|
|
|
|
|
def is_ddp_ignored(p):
    """Return ``p._ddp_to_ignore``, or ``False`` when ``p`` has no such attribute."""
    try:
        return p._ddp_to_ignore
    except AttributeError:
        return False
|
|
|
|
|
|
def disposable(func: Callable) -> Callable:
    """Decorate ``func`` so that only the first call actually invokes it.

    The first call forwards its arguments to ``func`` and returns its result.
    Every later call does nothing and returns ``None``.
    """
    already_called = False

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal already_called
        if already_called:
            return None
        # Flip the flag before calling, so a raising ``func`` is not retried.
        already_called = True
        return func(*args, **kwargs)

    return wrapper
|
|
|
|
|
|
def free_storage(data: torch.Tensor) -> None:
    """Release the memory backing ``data`` by shrinking its storage to zero.

    Does nothing when the storage is already empty.
    """
    backing = data.storage()
    if not backing.size() > 0:
        return

    # The Storage is resized directly, so the Tensor must be the sole occupant
    # of that Storage, i.e. start at offset 0.
    assert data.storage_offset() == 0
    backing.resize_(0)
|
|
|
|
|
|
def _cast_float(args, dtype: torch.dtype):
|
|
if isinstance(args, torch.Tensor) and torch.is_floating_point(args):
|
|
args = args.to(dtype)
|
|
elif isinstance(args, (list, tuple)):
|
|
args = type(args)(_cast_float(t, dtype) for t in args)
|
|
elif isinstance(args, dict):
|
|
args = {k: _cast_float(v, dtype) for k, v in args.items()}
|
|
return args
|
|
|
|
|
|
def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed``."""
    # Same order as before: stdlib first, then NumPy, then torch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
|