#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import functools
import os
import random
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, Optional, Set

import numpy as np
import torch
import torch.nn as nn

from colossalai.accelerator import get_accelerator


def get_current_device():
    """
    A wrapper around the accelerator API, kept for backward compatibility.
    """
    return get_accelerator().get_current_device()


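# Usage sketch (illustrative, not part of the upstream file): ``get_current_device`` resolves
# the device of the active accelerator backend, e.g. a CUDA device on a GPU machine, so
# callers do not need to query ``torch.cuda`` directly:
#
#     x = torch.empty(16, device=get_current_device())

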
def ensure_path_exists(filename: str):
    """Create the parent directory of ``filename`` if it does not exist yet."""
    dirpath = os.path.dirname(filename)
    if not os.path.exists(dirpath):
        Path(dirpath).mkdir(parents=True, exist_ok=True)


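# Usage sketch (illustrative; the checkpoint path below is hypothetical): make sure the target
# directory exists before writing a file into it.
#
#     ensure_path_exists("./checkpoints/epoch_3/model.pt")
#     torch.save(state_dict, "./checkpoints/epoch_3/model.pt")

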
@contextmanager
def conditional_context(context_manager, enable=True):
    """Enter ``context_manager`` only when ``enable`` is True; otherwise act as a no-op context."""
    if enable:
        with context_manager:
            yield
    else:
        yield


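# Usage sketch (illustrative; ``model``/``inputs`` are hypothetical): toggle ``torch.no_grad()``
# with a flag instead of duplicating the forward pass in two branches:
#
#     with conditional_context(torch.no_grad(), enable=not requires_grad):
#         out = model(inputs)

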
def is_ddp_ignored(p):
    """Return True if parameter ``p`` has been marked to be ignored by DDP."""
    return getattr(p, "_ddp_to_ignore", False)


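# Usage sketch (illustrative): skip DDP-ignored parameters when collecting a model's parameters.
#
#     params = [p for p in model.parameters() if not is_ddp_ignored(p)]

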
def disposable(func: Callable) -> Callable:
    """Wrap ``func`` so that it executes at most once; subsequent calls are silently skipped."""
    executed = False

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal executed
        if not executed:
            executed = True
            return func(*args, **kwargs)

    return wrapper


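# Usage sketch (illustrative; ``warn_once`` is a hypothetical example function): emit a one-time
# message no matter how often the wrapped function is called.
#
#     @disposable
#     def warn_once(msg):
#         print(f"[warning] {msg}")
#
#     warn_once("fp16 gradients detected")  # prints
#     warn_once("fp16 gradients detected")  # no-op

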
def free_storage(data: torch.Tensor) -> None:
    """Free underlying storage of a Tensor."""
    if data.storage().size() > 0:
        # Since we're modifying the Tensor's Storage directly, make sure the Tensor
        # is the sole occupant of the Storage.
        assert data.storage_offset() == 0
        data.storage().resize_(0)


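# Usage sketch (illustrative): release the memory behind a tensor whose contents have been
# offloaded elsewhere, while keeping the tensor object itself alive.
#
#     free_storage(param.data)  # storage is resized to zero elements; shape metadata is kept

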
def _cast_float(args, dtype: torch.dtype):
    """Recursively cast floating-point tensors inside ``args`` (a tensor, list/tuple, or dict) to ``dtype``."""
    if isinstance(args, torch.Tensor) and torch.is_floating_point(args):
        args = args.to(dtype)
    elif isinstance(args, (list, tuple)):
        args = type(args)(_cast_float(t, dtype) for t in args)
    elif isinstance(args, dict):
        args = {k: _cast_float(v, dtype) for k, v in args.items()}
    return args


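# Usage sketch (illustrative): cast a nested batch of float tensors to half precision before a
# low-precision forward pass; integer tensors and non-tensor values pass through unchanged.
#
#     batch = {"pixels": torch.randn(2, 3), "labels": torch.tensor([0, 1])}
#     half_batch = _cast_float(batch, torch.float16)  # only "pixels" becomes fp16

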
def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


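# Usage sketch (illustrative): call once at program start, before building the model and
# dataloaders, so that runs are reproducible.
#
#     set_seed(42)

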
def get_non_persistent_buffers_set(
    module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
):
    r"""
    Recursively collect the fully qualified names of non-persistent buffers under ``module``.

    Args:
        module: the root module to traverse
        memo: a memo to store the set of modules already added to the result
        prefix: a prefix that will be added to the name of the module
        remove_duplicate: whether to remove the duplicated module instances in the result
            or not
    """
    if memo is None:
        memo = set()
    self_non_persistent_set = set()
    if module not in memo:
        if remove_duplicate:
            memo.add(module)
        self_non_persistent_set = set(
            map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set)
        )
    for name, sub_module in module._modules.items():
        if sub_module is None:
            continue
        submodule_prefix = prefix + ("." if prefix else "") + name
        child_non_persistent_set = get_non_persistent_buffers_set(
            sub_module, memo, submodule_prefix, remove_duplicate
        )
        self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
    return self_non_persistent_set


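# Usage sketch (illustrative): non-persistent buffers are excluded from ``state_dict``, so this
# helper is useful when deciding which buffers to skip while saving or sharding a model.
#
#     model = nn.BatchNorm1d(4)
#     model.register_buffer("scratch", torch.zeros(4), persistent=False)
#     get_non_persistent_buffers_set(model)  # {"scratch"}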