mirror of https://github.com/hpcaitech/ColossalAI
28 lines
811 B
Python
28 lines
811 B
Python
from contextlib import contextmanager
|
|
|
|
import torch
|
|
|
|
|
|
def _noop(*args, **kwargs):
|
|
pass
|
|
|
|
|
|
@contextmanager
def low_resource_init(*, dtype: torch.dtype = torch.half):
    """Temporarily skip common weight initializers and change the default float dtype.

    While the context is active, ``torch.nn.init.kaiming_uniform_``,
    ``torch.nn.init.uniform_`` and ``torch.nn.init.normal_`` are replaced by
    ``_noop``. The default floating point dtype is set to ``dtype``, so
    floating point tensors created without an explicit dtype use it.
    Other ``torch.nn.init`` functions are not patched and still run.

    The patched attributes and the default dtype are global state, so the
    effect is process-wide while the context is active. The original
    functions and dtype are always restored on exit, including when the
    body raises.

    Args:
        dtype: Default floating point dtype to use inside the context.
            Defaults to ``torch.half`` (the original behavior); e.g.
            ``torch.bfloat16`` may be passed instead.

    Yields:
        None.
    """
    # torch.nn.init functions replaced with a no-op for the duration of the context.
    patched_names = ("kaiming_uniform_", "uniform_", "normal_")
    # Capture the originals before modifying anything so ``finally`` can always restore them.
    saved_inits = {name: getattr(torch.nn.init, name) for name in patched_names}
    saved_dtype = torch.get_default_dtype()
    try:
        for name in patched_names:
            setattr(torch.nn.init, name, _noop)
        torch.set_default_dtype(dtype)
        yield
    finally:
        for name, original in saved_inits.items():
            setattr(torch.nn.init, name, original)
        torch.set_default_dtype(saved_dtype)
|