mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] add torch.nn.ReLU metainfo (#1868)
* [fx] metainfo class for auto parallel * [fx] add unit test for linear metainfo * [fx] fix bwd param for linear * [fx] modify unit test * [fx] modify unit test * [fx] modify import * [fx] modify import * [fx] modify import * [fx] move meta profiler to auto parallel * [fx] add conv metainfo class * [fx] restore profiler * [fx] restore meta profiler * [autoparallel] modify unit test * [fx] modify unit test * [autoparallel] add batchnorm metainfo class * [autoparallel] fix batchnorm unit test function declaration * [fx] restore profiler * [fx] add relu metainfo class * [fx] restore profiler * [autoparallel] modify metainfo inputpull/1968/head^2
parent
8c66a1d0aa
commit
7c7921f71b
|
@ -0,0 +1,5 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# list of inplace operations
|
||||
INPLACE_MODULE = [nn.ReLU]
|
|
@ -1,3 +1,4 @@
|
|||
from .activation import *
|
||||
from .conv import *
|
||||
from .linear import *
|
||||
from .norm import *
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
|
||||
from colossalai.fx.profiler.memory_utils import activation_size
|
||||
from colossalai.fx.profiler.opcount import flop_mapping
|
||||
|
||||
from ..registry import meta_register
|
||||
|
||||
__all__ = ["relu_meta_info"]
|
||||
|
||||
|
||||
@meta_register.register(torch.nn.ReLU)
|
||||
def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
|
||||
"""torch.nn.ReLU metainfo generator
|
||||
The aten graph of torch.nn.ReLU is
|
||||
graph():
|
||||
%input_2 : [#users=1] = placeholder[target=placeholder](default=)
|
||||
%relu_default : [#users=2] = call_function[target=torch.ops.aten.relu.default](args = (%input_2,), kwargs = {})
|
||||
%zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%relu_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
|
||||
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%relu_default,), kwargs = {})
|
||||
%threshold_backward_default : [#users=1] = call_function[target=torch.ops.aten.threshold_backward.default](args = (%zeros_like_default, %detach_default, None), kwargs = {})
|
||||
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%threshold_backward_default,), kwargs = {})
|
||||
%detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
|
||||
|
||||
Returns:
|
||||
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
|
||||
"""
|
||||
|
||||
input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
|
||||
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
|
||||
inplace = kwargs.get("inplace", False)
|
||||
|
||||
# construct input args for forward
|
||||
fwd_in_args = [input_tensor]
|
||||
|
||||
# construct input args for backward
|
||||
bwd_in_args = [output_tensor]
|
||||
|
||||
# calculate cost
|
||||
# the fwd op with compute cost is relu.default
|
||||
# the bwd op with compute cost is threshold_backward
|
||||
|
||||
# calculate compute cost
|
||||
fwd_compute_cost = flop_mapping[torch.ops.aten.relu.default](fwd_in_args, (output_tensor,))
|
||||
bwd_compute_cost = flop_mapping[torch.ops.aten.threshold_backward.default](bwd_in_args, (input_tensor,))
|
||||
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
|
||||
|
||||
# calculate memory cost
|
||||
# NOTE: the inplace ReLU don't have forward memory cost
|
||||
fwd_memory_cost = MemoryCost(activation=0 if inplace else activation_size(output_tensor),
|
||||
parameter=0,
|
||||
temp=0,
|
||||
buffer=0)
|
||||
|
||||
bwd_memory_cost = MemoryCost(activation=activation_size(input_tensor), parameter=0, temp=0, buffer=0)
|
||||
|
||||
# total cost is the sum of forward and backward cost
|
||||
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
|
||||
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
|
||||
|
||||
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
|
||||
|
||||
# store fwd_in
|
||||
fwd_in = [input_tensor]
|
||||
|
||||
return compute_cost, memory_cost, fwd_in
|
|
@ -22,7 +22,7 @@ __all__ = ['convnd_meta_info']
|
|||
@meta_register.register(torch.nn.Conv1d)
|
||||
@meta_register.register(torch.nn.Conv2d)
|
||||
@meta_register.register(torch.nn.Conv3d)
|
||||
def convnd_meta_info(*args) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
|
||||
def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
|
||||
"""torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d meta info generator
|
||||
The atens graph of torch.nn.Convnd with bias is
|
||||
graph():
|
||||
|
|
|
@ -20,7 +20,7 @@ __all__ = ['linear_meta_info']
|
|||
|
||||
|
||||
@meta_register.register(torch.nn.Linear)
|
||||
def linear_meta_info(*args) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
|
||||
def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
|
||||
"""torch.nn.Linear meta info generator
|
||||
The atens graph of torch.nn.Linear with bias is
|
||||
graph():
|
||||
|
|
|
@ -22,7 +22,7 @@ __all__ = ['batchnormnd_meta_info']
|
|||
@meta_register.register(torch.nn.BatchNorm1d)
|
||||
@meta_register.register(torch.nn.BatchNorm2d)
|
||||
@meta_register.register(torch.nn.BatchNorm3d)
|
||||
def batchnormnd_meta_info(*args) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
|
||||
def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
|
||||
"""BatchNorm1d, BatchNorm2d, BatchNorm3d, meta info generator
|
||||
The aten graph of BatchNorm2d is like
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
|||
)
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
from .constants import INPLACE_MODULE
|
||||
from .registry import meta_register
|
||||
|
||||
__all__ = ['MetaInfo']
|
||||
|
@ -91,11 +92,17 @@ class MetaInfo:
|
|||
Compute meta info based on sharding strategy and the given target function.
|
||||
"""
|
||||
|
||||
assert meta_register.has(self._target), f'{self._target} not found in the meta registry'
|
||||
meta_func = meta_register.get(self._target)
|
||||
assert meta_register.has(self._target.__class__), f'{self._target.__class__} not found in the meta registry'
|
||||
meta_func = meta_register.get(self._target.__class__)
|
||||
|
||||
# construct args for meta_func
|
||||
args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()]
|
||||
|
||||
# construct kwargs
|
||||
if self.target in INPLACE_MODULE:
|
||||
kwargs = {'inplace': self.target.inplace}
|
||||
else:
|
||||
kwargs = {'inplace': False}
|
||||
|
||||
# compute metainfo with meta_func
|
||||
self.compute_cost, self.memory_cost, self.fwd_in = meta_func(*args)
|
||||
self.compute_cost, self.memory_cost, self.fwd_in = meta_func(*args, **kwargs)
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
|
||||
|
||||
|
||||
def _ReLU_module_mem_test(rank, world_size, port):
|
||||
"""This function is for conv memory test
|
||||
Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
|
||||
|
||||
Args:
|
||||
Args:
|
||||
rank: device rank
|
||||
bias: indicate whether conv module need bias
|
||||
world_size: number of devices
|
||||
port: port for initializing process group
|
||||
"""
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = nn.Sequential(nn.ReLU()).cuda()
|
||||
input = torch.rand(4, 128, 64, 64).cuda()
|
||||
input.requires_grad = True
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
|
||||
# index of conv node in computation graph
|
||||
node_index = 1
|
||||
# total number of conv strategies
|
||||
strategy_number = 1
|
||||
mem_test_for_node_strategy(rank=rank,
|
||||
model=model,
|
||||
device_mesh=device_mesh,
|
||||
node_index=node_index,
|
||||
strategy_number=strategy_number,
|
||||
input_args=[input],
|
||||
meta_arg_names=['input'])
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_ReLU_meta_concrete_info_match():
|
||||
world_size = 4
|
||||
run_func_module = partial(_ReLU_module_mem_test, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_ReLU_meta_concrete_info_match()
|
|
@ -60,9 +60,10 @@ def mem_test_for_node_strategy(rank: int,
|
|||
gm.recompile()
|
||||
gm: GraphModule
|
||||
|
||||
num_of_strategies = len(target_node.strategies_vector)
|
||||
if rank == 0:
|
||||
print("=======================")
|
||||
print(f"#strategy_index: {strategy_index}")
|
||||
print(f"#strategy_index: {strategy_index + 1}/{num_of_strategies}")
|
||||
pprint(target_node.strategies_vector[strategy_index])
|
||||
|
||||
# warmup
|
||||
|
@ -104,7 +105,7 @@ def mem_test_for_node_strategy(rank: int,
|
|||
|
||||
# estimated memory
|
||||
metainfo = MetaInfo(target_node.strategies_vector[strategy_index],
|
||||
target_node.graph.owning_module.get_submodule(target_node.target).__class__)
|
||||
target_node.graph.owning_module.get_submodule(target_node.target))
|
||||
print("estimated memory:")
|
||||
print(
|
||||
f"forward activation: {metainfo.memory_cost.fwd.activation / 1024} kb, forward param: {metainfo.memory_cost.fwd.parameter / 1024} kb"
|
||||
|
|
Loading…
Reference in New Issue