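# Mirror of https://github.com/hpcaitech/ColossalAI
# Traces the forward *and* backward passes of an nn.Module into a single torch.fx
# Graph by running the computation on meta tensors (no real compute or memory).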
import torch
from torch.fx import Graph, Node
from torch.utils._pytree import tree_map


def normalize_tuple(x):
    """Wrap ``x`` in a 1-tuple if it is not already a tuple."""
    if not isinstance(x, tuple):
        return (x,)
    return x


def is_autogradable(x):
    """Check whether ``x`` is a floating-point tensor, i.e. eligible for autograd."""
    return isinstance(x, torch.Tensor) and x.is_floating_point()

def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Graph:
    """Trace the forward and backward graph with MetaTensor.

    Args:
        module (torch.nn.Module): The target module for tracing.
        fake_device (torch.device, optional): The device the traced tensors pretend to be on.
        *args, **kwargs: Sample inputs passed to ``module`` during tracing.

    Returns:
        graph (torch.fx.Graph): The computation graph.

    Usage:
        >>> import torchvision.models as tm
        >>> model = tm.alexnet()
        >>> graph = meta_trace(model, None, torch.rand(1000, 3, 224, 224))
        >>> graph.print_tabular()
    """
    graph = Graph()
    namespace = graph._graph_namespace

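    # MetaProxy is a `__torch_dispatch__` wrapper subclass: every aten op executed on a
    # proxy is recorded as a node in `graph`, while the real computation is dispatched
    # on meta tensors, so no actual memory is allocated and no real kernels run.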
    class MetaProxy(torch.Tensor):
        """
        A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
        """

        _tensor: torch.Tensor
        _node: Node

        __slots__ = ["_tensor", "_node"]

        @staticmethod
        def __new__(cls, tensor, fake_device=None, placeholder=False, name=None):
            r = torch.Tensor._make_wrapper_subclass(
                cls,
                tensor.size(),
                strides=tensor.stride(),
                storage_offset=tensor.storage_offset(),
                dtype=tensor.dtype,
                layout=tensor.layout,
                device=fake_device if fake_device is not None else tensor.device,
                requires_grad=tensor.requires_grad,
            )  # deceive the frontend for aten selections
            r._tensor = tensor
            if placeholder:
                if name is None:
                    name = "input"
                r._node = graph.create_node(
                    "placeholder", "placeholder", (graph._root,), name=namespace.create_name(name, tensor)
                )
            # ...the real tensor is held as an element on the tensor.
            if not r._tensor.is_meta:
                r._tensor = r._tensor.to(torch.device("meta"))
            return r

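        # Intercept every aten call made on MetaProxy arguments: record the call as a
        # graph node, run it on the underlying meta tensors, then re-wrap the outputs.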
        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            def unwrap(x):
                nonlocal fake_device
                if isinstance(x, MetaProxy):
                    fake_device = x.device
                    x = x._tensor
                    # assert not isinstance(x, MetaProxy)
                elif isinstance(x, torch.Tensor):
                    fake_device = x.device
                    x = x.to(torch.device("meta"))
                return x

            def get_node(x):
                if isinstance(x, torch.Tensor) and not hasattr(x, "_node"):
                    x = MetaProxy(x, placeholder=True, name="weight")
                return x if not hasattr(x, "_node") else x._node

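            # Build the `call_function` node for this op; plain tensors (e.g. parameters)
            # seen here for the first time are registered as "weight" placeholders.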
            args_node = tree_map(get_node, args)
            kwargs_node = tree_map(get_node, kwargs)
            node = graph.create_node("call_function", func, args_node, kwargs_node)

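            # Factory-style ops pass an explicit `device` kwarg: remember it as the fake
            # device and redirect the actual allocation to the meta device.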
if "device" in kwargs: |
|
fake_device = kwargs["device"] |
|
kwargs["device"] = torch.device("meta") |
|
|
|
args = tree_map(unwrap, args) |
|
kwargs = tree_map(unwrap, kwargs) |
|
|
|
            # run aten for backend=CPU but actually on backend=Meta
            out = func(*args, **kwargs)

            # Now, we want to continue propagating this tensor, so we rewrap Tensors in
            # our custom tensor subclass
            def wrap(x):
                if isinstance(x, torch.Tensor):
                    nonlocal fake_device
                    if not x.is_meta:
                        x = x.to(torch.device("meta"))
                return (
                    MetaProxy(x, fake_device=fake_device)
                    if isinstance(x, torch.Tensor) and not hasattr(x, "_tensor")
                    else x
                )

            def set_node(x):
                x._node = node

            out = tree_map(wrap, out)
            tree_map(set_node, out)

            return out

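    # Wrap the sample inputs as placeholder proxies so the forward pass below is recorded.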
    def wrap(x):
        return MetaProxy(x, fake_device=fake_device, placeholder=True) if isinstance(x, torch.Tensor) else x

    args = tree_map(wrap, args)
    kwargs = tree_map(wrap, kwargs)

    out = module(*args, **kwargs)

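    # Feed a meta-gradient into each autogradable output so the backward ops are traced
    # into the same graph as the forward pass.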
    for tensor in normalize_tuple(out):
        if is_autogradable(tensor) and tensor.requires_grad:
            grad = (
                torch.empty_like(tensor._tensor, device=torch.device("meta"))
                if isinstance(tensor, MetaProxy)
                else torch.empty_like(tensor, device=torch.device("meta"))
            )
            torch.autograd.backward(
                tensor, MetaProxy(grad, fake_device=tensor.device, placeholder=True), retain_graph=True
            )
    return graph
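

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: trace a small hypothetical
    # `_TinyNet` and print the recorded forward + backward aten graph. Whether it runs
    # exactly as written depends on the installed PyTorch version's support for
    # `__torch_dispatch__` wrapper subclasses and meta tensors.
    class _TinyNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(8, 4)

        def forward(self, x):
            return self.fc(x).relu()

    demo_graph = meta_trace(_TinyNet(), None, torch.rand(2, 8))
    demo_graph.print_tabular()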