ColossalAI/colossalai/elixir/ctx/meta_ctx.py

35 lines
962 B
Python

import functools

import torch

from colossalai.elixir.ctx import tensor_creation_methods
class MetaContext(object):
    """A context manager that wraps all tensor creation methods in torch.

    While the context is active, every function listed in
    ``tensor_creation_methods`` is replaced on the ``torch`` module by a
    wrapper that forces the ``device`` keyword argument to ``device_type``
    (any caller-supplied ``device`` is overridden). By default, all tensors
    are therefore created on the ``meta`` device.

    Contexts may be nested, including re-entering the same instance. On exit,
    each context restores exactly the ``torch`` attributes that were in place
    when it was entered, so an enclosing context remains active afterwards.

    Args:
        device_type: The device type of the tensors to be created.
    """

    def __init__(self, device_type: str = 'meta') -> None:
        super().__init__()
        self.device_type = device_type
        # One snapshot per active ``__enter__`` of this instance: the torch
        # attributes that were replaced, so ``__exit__`` can put them back.
        self._saved_methods = []

    def __enter__(self) -> 'MetaContext':
        """Patch the torch creation functions and return this context."""

        def meta_wrap(func):

            @functools.wraps(func)
            def wrapped_func(*args, **kwargs):
                # Deliberately override any explicit ``device`` argument.
                kwargs['device'] = self.device_type
                return func(*args, **kwargs)

            return wrapped_func

        # Snapshot what is currently installed (possibly an enclosing
        # context's wrappers) rather than assuming the pristine originals.
        # Fall back to the table's method when torch has no such attribute,
        # which matches what the restore used to install.
        snapshot = {
            name: getattr(torch, name, method)
            for name, method in tensor_creation_methods.items()
        }
        self._saved_methods.append(snapshot)
        for name, method in tensor_creation_methods.items():
            setattr(torch, name, meta_wrap(method))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the torch attributes saved by the matching ``__enter__``.

        Returns ``None``, so any exception raised inside the block propagates.
        """
        snapshot = self._saved_methods.pop()
        for name, method in snapshot.items():
            setattr(torch, name, method)