# Mirror of https://github.com/hpcaitech/ColossalAI (Python, 35 lines, 962 B).
import torch
|
|
|
|
from colossalai.elixir.ctx import tensor_creation_methods
|
|
|
|
|
|
class MetaContext(object):
    """A context manager that redirects torch's tensor creation methods to one device.

    While the context is active, every function listed in
    ``tensor_creation_methods`` is replaced on the ``torch`` module by a wrapper
    that forces the ``device`` keyword argument to ``device_type``. By default,
    all tensors are therefore created on the ``meta`` device.

    On exit, the ``torch`` attributes that were installed right before
    ``__enter__`` are restored. Nested contexts therefore unwind correctly: leaving an
    inner context reinstates the outer context's wrappers rather than the raw
    originals.

    Args:
        device_type: The device type of the tensors to be created
            (``'meta'`` by default).
    """

    def __init__(self, device_type: str = 'meta') -> None:
        super().__init__()
        self.device_type = device_type
        # One snapshot per active __enter__ of this instance; a stack (rather
        # than a single dict) keeps re-entering the same instance safe.
        self._saved_stack = []

    def __enter__(self):
        """Install the device-forcing wrappers on ``torch`` and return ``self``."""
        import functools

        def meta_wrap(func):
            @functools.wraps(func)
            def wrapped_func(*args, **kwargs):
                # Any caller-supplied device is deliberately overridden.
                kwargs['device'] = self.device_type
                return func(*args, **kwargs)

            return wrapped_func

        # Snapshot what is currently installed on torch (possibly an enclosing
        # context's wrappers). Fall back to the listed method if torch lacks
        # the attribute, which matches what the old __exit__ would restore.
        snapshot = {name: getattr(torch, name, method) for name, method in tensor_creation_methods.items()}
        self._saved_stack.append(snapshot)

        for name, method in tensor_creation_methods.items():
            # Wrap the listed method itself, not the current torch attribute,
            # so wrappers never stack on top of each other.
            setattr(torch, name, meta_wrap(method))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the ``torch`` attributes saved by the matching ``__enter__``.

        Returns ``None``, so any exception raised inside the ``with`` body propagates.
        """
        if self._saved_stack:
            snapshot = self._saved_stack.pop()
        else:
            # Defensive: __exit__ without a matching __enter__ falls back to
            # the listed methods (the original behaviour).
            snapshot = tensor_creation_methods
        for name, method in snapshot.items():
            setattr(torch, name, method)
|