mirror of https://github.com/hpcaitech/ColossalAI
[fx] add more meta_registry for MetaTensor execution. (#2000)
* [sc] add examples for auto checkpoint. * merge upstream * [fx] add more meta_registry for MetaTensor execution.pull/2005/head
parent
d00d905b86
commit
2edbef13cc
|
@ -3,7 +3,7 @@
|
||||||
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
|
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
|
||||||
# for more meta_registrations
|
# for more meta_registrations
|
||||||
|
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils._pytree import tree_map
|
from torch.utils._pytree import tree_map
|
||||||
|
@ -179,6 +179,42 @@ def meta_adaptive_avg_pool2d_backward(
|
||||||
return grad_input
|
return grad_input
|
||||||
|
|
||||||
|
|
||||||
|
# ================================ RNN =============================================
|
||||||
|
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
|
||||||
|
@register_meta(aten._cudnn_rnn.default)
|
||||||
|
def meta_cuda_rnn(
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_stride0: int,
|
||||||
|
weight_buf: torch.Tensor,
|
||||||
|
hx: torch.Tensor,
|
||||||
|
cx: Optional[torch.Tensor] = None,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if cx is not None:
|
||||||
|
return torch.empty_like(input), torch.empty_like(hx), torch.empty_like(cx)
|
||||||
|
else:
|
||||||
|
return torch.empty_like(input), torch.empty_like(hx), torch.empty((), device='meta')
|
||||||
|
|
||||||
|
|
||||||
|
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
|
||||||
|
@register_meta(aten._cudnn_rnn_backward.default)
|
||||||
|
def meta_cudnn_rnn_backward(input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_stride0: int,
|
||||||
|
hx: torch.Tensor,
|
||||||
|
cx: Optional[torch.Tensor] = None,
|
||||||
|
*args,
|
||||||
|
**kwargs):
|
||||||
|
print(input, weight, hx, cx)
|
||||||
|
grad_input = torch.empty_like(input)
|
||||||
|
grad_weight = torch.empty_like(weight)
|
||||||
|
grad_hx = torch.empty_like(hx)
|
||||||
|
grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device='meta')
|
||||||
|
return grad_input, grad_weight, grad_hx, grad_cx
|
||||||
|
|
||||||
|
|
||||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
|
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
|
||||||
# ============================== Activations =======================================
|
# ============================== Activations =======================================
|
||||||
@register_meta(aten.relu.default)
|
@register_meta(aten.relu.default)
|
||||||
|
@ -186,6 +222,11 @@ def meta_relu(input: torch.Tensor):
|
||||||
return torch.empty_like(input)
|
return torch.empty_like(input)
|
||||||
|
|
||||||
|
|
||||||
|
@register_meta(aten.prelu.default)
|
||||||
|
def meta_prelu(input: torch.Tensor, weight: torch.Tensor):
|
||||||
|
return torch.empty_like(input)
|
||||||
|
|
||||||
|
|
||||||
@register_meta(aten.hardswish.default)
|
@register_meta(aten.hardswish.default)
|
||||||
def meta_hardswish(input: torch.Tensor):
|
def meta_hardswish(input: torch.Tensor):
|
||||||
return torch.empty_like(input)
|
return torch.empty_like(input)
|
||||||
|
@ -278,12 +319,18 @@ def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, me
|
||||||
|
|
||||||
|
|
||||||
# ================================== Misc ==========================================
|
# ================================== Misc ==========================================
|
||||||
#https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
|
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
|
||||||
@register_meta(aten.roll.default)
|
@register_meta(aten.roll.default)
|
||||||
def meta_roll(input: torch.Tensor, shifts, dims):
|
def meta_roll(input: torch.Tensor, shifts, dims):
|
||||||
return input
|
return input
|
||||||
|
|
||||||
|
|
||||||
|
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Scalar.cpp
|
||||||
|
@register_meta(aten._local_scalar_dense.default)
|
||||||
|
def meta_local_scalar_dense(self: torch.Tensor):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
|
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
|
||||||
@register_meta(aten.where.self)
|
@register_meta(aten.where.self)
|
||||||
def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
|
def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
|
||||||
|
@ -317,7 +364,7 @@ def meta_index_Tensor(self, indices):
|
||||||
indices = result
|
indices = result
|
||||||
assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
|
assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
|
||||||
# expand_outplace
|
# expand_outplace
|
||||||
import torch._refs as refs # avoid import cycle in mypy
|
import torch._refs as refs
|
||||||
|
|
||||||
indices = list(refs._maybe_broadcast(*indices))
|
indices = list(refs._maybe_broadcast(*indices))
|
||||||
# add missing null tensors
|
# add missing null tensors
|
||||||
|
|
|
@ -128,3 +128,13 @@ class MetaTensor(torch.Tensor):
|
||||||
if device is not None:
|
if device is not None:
|
||||||
result = MetaTensor(result, fake_device=device)
|
result = MetaTensor(result, fake_device=device)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def cpu(self, *args, **kwargs):
|
||||||
|
if self.device.type == 'cpu':
|
||||||
|
return self.to(*args, **kwargs)
|
||||||
|
return self.to(*args, device='cpu', **kwargs)
|
||||||
|
|
||||||
|
def cuda(self, *args, **kwargs):
|
||||||
|
if self.device.type == 'cuda':
|
||||||
|
return self.to(*args, **kwargs)
|
||||||
|
return self.to(*args, device='cuda', **kwargs)
|
||||||
|
|
|
@ -20,28 +20,25 @@ def symbolic_trace(
|
||||||
Given an ``nn.Module`` or function instance ``root``, this function will return a ``ColoGraphModule``
|
Given an ``nn.Module`` or function instance ``root``, this function will return a ``ColoGraphModule``
|
||||||
constructed by recording operations seen while tracing through ``root``.
|
constructed by recording operations seen while tracing through ``root``.
|
||||||
|
|
||||||
With ``meta_args`` and ``concrete_args``, we can trace the model that are untraceable subject to control flow.
|
With ``meta_args``, we can trace the model that are untraceable subject to control flow. If specified using
|
||||||
If specified using ``meta_args`` only, the tracing can be done ahead of time.
|
``meta_args`` only, the tracing can be done ahead of time.
|
||||||
|
|
||||||
Note that both ``meta_args`` and ``concrete_args`` are kwargs, which contains the key of the argument's names
|
Note that ``meta_args`` are kwargs, which contains the key of the argument's names and the value of the
|
||||||
and the value of the argument's values.
|
argument's values.
|
||||||
|
|
||||||
Uses:
|
Uses:
|
||||||
>>> model = ...
|
>>> model = ...
|
||||||
|
|
||||||
# if this works
|
# if this works
|
||||||
>>> gm = symbolic_trace(model)
|
>>> gm = symbolic_trace(model, concrete_args=concrete_args)
|
||||||
|
|
||||||
# else try this
|
# else try this
|
||||||
>>> gm = symbolic_trace(model, meta_args={'x': torch.rand(1, 3, 224, 224, device='meta')})
|
>>> gm = symbolic_trace(model, concrete_args=concrete_args, meta_args={'x': torch.rand(1, 3, 224, 224, device='meta')})
|
||||||
|
|
||||||
# else try this
|
|
||||||
>>> gm = symbolic_trace(model, concrete_args={'x': torch.rand(1, 3, 224, 224)})
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
root (Union[torch.nn.Module, Callable[..., Any]]): Module or function to be traced and converted
|
root (Union[torch.nn.Module, Callable[..., Any]]): Module or function to be traced and converted
|
||||||
into a Graph representation.
|
into a Graph representation.
|
||||||
concrete_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized. Defaults to None.
|
concrete_args (Optional[Dict[str, Any]], optional): Concrete arguments to be used for tracing.
|
||||||
meta_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized, special for ``ColoTracer``.
|
meta_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized, special for ``ColoTracer``.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
|
|
||||||
|
@ -52,7 +49,6 @@ def symbolic_trace(
|
||||||
This API is still under development and can incur some bugs. Feel free to report any bugs to the Colossal-AI team.
|
This API is still under development and can incur some bugs. Feel free to report any bugs to the Colossal-AI team.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
tracer = ColoTracer()
|
graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
|
||||||
graph = tracer.trace(root, concrete_args, meta_args)
|
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
|
||||||
name = (root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__)
|
return ColoGraphModule(root, graph, name)
|
||||||
return ColoGraphModule(tracer.root, graph, name)
|
|
||||||
|
|
|
@ -18,13 +18,11 @@ def bench(gm: torch.fx.GraphModule,
|
||||||
data_gen: Callable,
|
data_gen: Callable,
|
||||||
num_steps: int = 5) -> Tuple[int, int]:
|
num_steps: int = 5) -> Tuple[int, int]:
|
||||||
"""Benchmarking a given graph module
|
"""Benchmarking a given graph module
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
gm (torch.fx.GraphModule): The graph module to benchmark.
|
gm (torch.fx.GraphModule): The graph module to benchmark.
|
||||||
criterion (torch.nn.Module): Loss function.
|
criterion (torch.nn.Module): Loss function.
|
||||||
data_gen (Callable): Data generator.
|
data_gen (Callable): Data generator.
|
||||||
num_steps (int, optional): Number of test steps. Defaults to 5.
|
num_steps (int, optional): Number of test steps. Defaults to 5.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[int, int]: peak memory in MB and step time in MS.
|
Tuple[int, int]: peak memory in MB and step time in MS.
|
||||||
"""
|
"""
|
||||||
|
@ -69,7 +67,6 @@ def bench_rotor(gm: torch.fx.GraphModule,
|
||||||
start_factor: int = 4) -> Tuple[np.array, list, list]:
|
start_factor: int = 4) -> Tuple[np.array, list, list]:
|
||||||
"""Auto Checkpoint Rotor Algorithm benchmarking
|
"""Auto Checkpoint Rotor Algorithm benchmarking
|
||||||
Benchmarks the Auto Checkpoint Rotor Algorithm for a given graph module and data.
|
Benchmarks the Auto Checkpoint Rotor Algorithm for a given graph module and data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
gm (torch.fx.GraphModule): The graph module to benchmark.
|
gm (torch.fx.GraphModule): The graph module to benchmark.
|
||||||
criterion (torch.nn.Module): Loss function.
|
criterion (torch.nn.Module): Loss function.
|
||||||
|
@ -79,7 +76,6 @@ def bench_rotor(gm: torch.fx.GraphModule,
|
||||||
free_memory (int, optional): Max memory budget in Byte. Defaults to torch.cuda.mem_get_info()[0].
|
free_memory (int, optional): Max memory budget in Byte. Defaults to torch.cuda.mem_get_info()[0].
|
||||||
start_factor (int, optional): Start memory budget factor for benchmark, the start memory budget
|
start_factor (int, optional): Start memory budget factor for benchmark, the start memory budget
|
||||||
will be free_memory / start_factor. Defaults to 4.
|
will be free_memory / start_factor. Defaults to 4.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[np.array, list, list]: return budgets vector (MB), peak memory vector (MB), step time vector (MS).
|
Tuple[np.array, list, list]: return budgets vector (MB), peak memory vector (MB), step time vector (MS).
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue