[autoparallel] add conv metainfo class for auto parallel (#1796)

* [fx] metainfo class for auto parallel

* [fx] add unit test for linear metainfo

* [fx] fix bwd param for linear

* [fx] modify unit test

* [fx] modify unit test

* [fx] modify import

* [fx] modify import

* [fx] modify import

* [fx] move meta profiler to auto parallel

* [fx] add conv metainfo class

* [fx] restore profiler

* [fx] restore meta profiler

* [autoparallel] modify unit test

* [fx] modify unit test
pull/1765/head
Boyuan Yao 2022-11-07 16:15:35 +08:00 committed by GitHub
parent 501a9e9cd2
commit 327d07c44a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 192 additions and 43 deletions

View File

@ -1 +1,2 @@
from .conv import *
from .linear import *

View File

@ -0,0 +1,122 @@
from typing import Callable, Dict, List, Tuple, Union
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from colossalai.tensor.sharding_spec import ShardingSpec
from ..registry import meta_register
__all__ = ['convnd_meta_info']
@meta_register.register(torch.nn.Conv1d)
@meta_register.register(torch.nn.Conv2d)
@meta_register.register(torch.nn.Conv3d)
def convnd_meta_info(*args) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d meta info generator
The atens graph of torch.nn.Convnd with bias is
graph():
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
%convolution_default : [#users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%input_2, None, None, [None, None, None], [None, None, None], [None, None, None], None, [None, None, None], None), kwargs = {})
%zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%convolution_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
%convolution_backward_default : [#users=3] = call_function[target=torch.ops.aten.convolution_backward.default](args = (%zeros_like_default, %detach_default, None, [None], [None, None, None], [None, None, None], [None, None, None], None, [None, None, None], None, [None, None, None]), kwargs = {})
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {})
%detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
%detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {})
%detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {})
%detach_default_5 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {})
%detach_default_6 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_5,), kwargs = {})
The atens graph of torch.nn.Convnd without bias is
graph():
%input_2 : [#users=2] = placeholder[target=placeholder](default=)
%convolution_default : [#users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%input_2, None, None, [None, None], [None, None], [None, None], None, [None, None], None), kwargs = {})
%zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%convolution_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
%detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%input_2,), kwargs = {})
%convolution_backward_default : [#users=2] = call_function[target=torch.ops.aten.convolution_backward.default](args = (%zeros_like_default, %detach_default, None, [None], [None, None], [None, None], [None, None], None, [None, None], None, [None, None, None]), kwargs = {})
%detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {})
%detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
%detach_default_3 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%convolution_backward_default,), kwargs = {})
%detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {})
Returns:
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""
has_bias: bool = False
input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
weight_tensor = next(filter(lambda x: x.name == 'weight', args)).data
# check if conv has bias
if len(args) == 4:
bias_tensor = next(filter(lambda x: x.name == 'bias', args)).data
has_bias = True
# construct input args for forward
fwd_args = [None] * 9
# weight and input
fwd_args[0] = input_tensor
fwd_args[1] = weight_tensor
fwd_args[2] = bias_tensor if has_bias else None
# transpose indicator should be set to False
fwd_args[6] = False
# construct input args for backward
bwd_args = [None] * 11
# weight and input
bwd_args[0] = output_tensor
bwd_args[1] = input_tensor
bwd_args[2] = weight_tensor
bwd_args[-1] = [True, True, True] if has_bias else [True, True, False]
# calculate cost
# the fwd op with compute cost is convolution.default
# the bwd op with compute cost is convolution_backward.default
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.convolution.default](fwd_args, (output_tensor,))
bwd_compute_cost = flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor)) if has_bias else \
flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor))
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
# TODO: use profiler to check conv temp memory
fwd_memory_cost = MemoryCost(activation=activation_size(output_tensor),
parameter=activation_size(weight_tensor) +
activation_size(bias_tensor) if has_bias else activation_size(weight_tensor),
temp=0,
buffer=0)
bwd_memory_cost = MemoryCost(activation=activation_size(input_tensor) + activation_size(weight_tensor) +
activation_size(bias_tensor) if has_bias else activation_size(input_tensor) +
activation_size(weight_tensor),
parameter=activation_size(weight_tensor) +
activation_size(bias_tensor) if has_bias else activation_size(weight_tensor),
temp=0,
buffer=0)
# total cost is the sum of forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in
fwd_in = [input_tensor]
return compute_cost, memory_cost, fwd_in

View File

@ -59,7 +59,7 @@ def linear_meta_info(*args) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.
%detach_default_4 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_3,), kwargs = {})
Returns:
Tuple[TrainCycleItem, TrainCycleItem, bool]: compute cost, memory cost and save input flag
Tuple[TrainCycleItem, TrainCycleItem, bool]: compute cost, memory cost and forward inputs
"""
has_bias: bool = False

View File

@ -0,0 +1,61 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
def _conv_module_mem_test(rank, bias, world_size, port):
"""This function is for conv memory test
Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
Args:
Args:
rank: device rank
bias: indicate whether conv module need bias
world_size: number of devices
port: port for initializing process group
"""
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = nn.Sequential(nn.Conv2d(4, 64, 3, padding=1, bias=bias)).cuda()
input = torch.rand(4, 4, 64, 64).cuda()
input.requires_grad = True
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
# index of conv node in computation graph
node_index = 1
# total number of conv strategies
strategy_number = 16
mem_test_for_node_strategy(rank=rank,
model=model,
device_mesh=device_mesh,
node_index=node_index,
strategy_number=strategy_number,
input_args=[input],
meta_arg_names=['input'])
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_conv_meta_concrete_info_match(bias=False):
world_size = 4
run_func_module = partial(_conv_module_mem_test, bias=bias, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
if __name__ == '__main__':
test_conv_meta_concrete_info_match()

View File

@ -20,48 +20,15 @@ if torch.__version__ >= '1.12.0':
from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='PyTorch version is too low')
@parameterize('bias', [True, False])
def test_linear_metainfo(bias):
model = nn.Sequential(nn.Linear(16, 32, bias=bias).to('meta'))
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
gm = ColoGraphModule(model, graph)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
linear_mod_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(linear_mod_node)
# build handler
handler = LinearModuleHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
# build strategy
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
# assert module is registered
assert meta_register.has(linear_mod_node.graph.owning_module.get_submodule(linear_mod_node.target).__class__)
# check metainfo
for strategy in strategies_vector:
strategy: ShardingStrategy
try:
metainfo = MetaInfo(strategy,
linear_mod_node.graph.owning_module.get_submodule(linear_mod_node.target).__class__)
except:
raise RuntimeError(f"Failed to compute metainfo for {strategy}")
def _linear_mem_test(rank, bias, world_size, port):
def _linear_module_mem_test(rank, bias, world_size, port):
"""This function is for linear memory test
Test and print real memory cost and estimated, this test will not be executed
in unit test.
Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
Args:
bias (bool, optional): Indicate whether we need bias for Linear. Defaults to True.
rank: device rank
bias: indicate whether linear module need bias
world_size: number of devices
port: port for initializing process group
"""
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@ -87,11 +54,9 @@ def _linear_mem_test(rank, bias, world_size, port):
@rerun_if_address_is_in_use()
def test_linear_meta_concrete_info_match(bias=False):
world_size = 4
run_func_module = partial(_linear_mem_test, bias=bias, world_size=world_size, port=free_port())
run_func_module = partial(_linear_module_mem_test, bias=bias, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
if __name__ == '__main__':
# test_linear_metainfo()
# _linear_mem_test(bias=True)
test_linear_meta_concrete_info_match()