[autoparallel] add binary elementwise metainfo for auto parallel (#2058)

* [fx] metainfo class for auto parallel

* [fx] add unit test for linear metainfo

* [fx] fix bwd param for linear

* [fx] modify unit test

* [fx] modify unit test

* [fx] modify import

* [fx] modify import

* [fx] modify import

* [fx] move meta profiler to auto parallel

* [fx] add conv metainfo class

* [fx] restore profiler

* [fx] restore meta profiler

* [autoparallel] modify unit test

* [fx] modify unit test

* [autoparallel] add batchnorm metainfo class

* [autoparallel] fix batchnorm unit test function declaration

* [fx] restore profiler

* [fx] add relu metainfo class

* [fx] restore profiler

* [autoparallel] modify metainfo input

* [autoparallel] add pooling metainfo

* [autoparallel] add F.linear metainfo generator

* [autoparallel] add binary elementwise metainfo

* [fx] recover profiler

* [autoparallel] fix forward memory calculation

* [autoparallel] modify constants.py

* [autoparallel] remove redundant print
pull/2071/head
Boyuan Yao 2022-12-04 15:18:51 +08:00 committed by GitHub
parent 4b40fbd743
commit 616da17fab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 164 additions and 11 deletions

View File

@ -1,5 +1,12 @@
import operator
import torch
import torch.nn as nn
from ..tensor_shard.constants import *
# list of inplace operations
INPLACE_MODULE = [nn.ReLU]
# list of operations that do not save forward activations
NO_SAVE_ACTIVATION = [torch.add, torch.sub, operator.add, operator.sub]

View File

@ -1,4 +1,5 @@
from .activation import *
from .binary_elementwise_ops import *
from .conv import *
from .linear import *
from .norm import *

View File

@ -0,0 +1,65 @@
from typing import List, Tuple
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from ..constants import BCAST_FUNC_OP
from ..registry import meta_register
__all__ = ['binary_elementwise_meta_info']
@meta_register.register(BCAST_FUNC_OP)
def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""Meta information generator for binary elementwise operations
NOTE: Some of the binary elementwise operations will discard the input activation after computation, as they
don't need those tensors for back propagation, for example, if there are two tensors being sent for `torch.add`,
they will be discarded right after add operation is done. We create a simple API in `MetaInfo` class to identify
this behavior, it is critical for better memory estimation.
Returns:
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""
input_op_data, other_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
output_op_data = next(filter(lambda arg: arg.type == OperationDataType.OUTPUT, args))
# construct forward args for flop mapping
fwd_in_args = [input_op_data.data, other_op_data.data]
fwd_out_args = [output_op_data.data]
# calculate cost
# calculate compute cost
# NOTE: we set bwd_compute_cost two times of fwd_compute_cost in this case
fwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d.default](fwd_in_args, fwd_out_args)
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
param_mem_cost = activation_size(
[arg.data for arg in [input_op_data, other_op_data] if arg.type == OperationDataType.PARAM])
fwd_mem_cost = MemoryCost(
activation=activation_size([input_op_data.data, output_op_data.data]),
parameter=param_mem_cost,
)
bwd_mem_cost = MemoryCost(
activation=activation_size(fwd_in_args),
parameter=param_mem_cost,
)
# total cost
total_mem_cost = MemoryCost(
activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
# store fwd_in
fwd_in = fwd_in_args
return compute_cost, memory_cost, fwd_in

View File

@ -13,7 +13,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
)
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import INPLACE_MODULE
from .constants import INPLACE_MODULE, NO_SAVE_ACTIVATION
from .registry import meta_register
__all__ = ['MetaInfo']
@ -35,6 +35,9 @@ class MetaInfo:
# list of input tensors
self.fwd_in: list[OperationData]
# bool type to indicate whether the function will save forward activation
self.save_fwd_in: bool
# sharding strategy
self._strategy = strategy
@ -95,10 +98,16 @@ class MetaInfo:
try:
# module
meta_func = meta_register.get(self._target.__class__)
# check whether the target in the module list that we don't need to save activation
self.save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION
except:
# function
meta_func = meta_register.get(self._target)
# check whether the target in the module list that we don't need to save activation
self.save_fwd_in = self._target not in NO_SAVE_ACTIVATION
# construct args for meta_func
args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()]

View File

@ -35,9 +35,9 @@ def _ReLU_module_mem_test(rank, world_size, port):
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
# index of conv node in computation graph
# index of target node in computation graph
node_index = 1
# total number of conv strategies
# total number of target node strategies
strategy_number = 1
mem_test_for_node_strategy(rank=rank,
model=model,

View File

@ -34,9 +34,9 @@ def _batchnorm_module_mem_test(rank, world_size, port):
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
# index of conv node in computation graph
# index of target node in computation graph
node_index = 1
# total number of conv strategies
# total number of target node strategies
strategy_number = 4
mem_test_for_node_strategy(rank=rank,
model=model,

View File

@ -0,0 +1,71 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
class BinaryElementwiseOpModule(nn.Module):
def __init__(self, token=torch.add, shape=64) -> None:
super().__init__()
self.token = token
self.param = nn.Parameter(torch.rand(shape))
def forward(self, input):
return input + self.param
def _binary_elementwise_mem_test(rank, world_size, port):
"""This function is for binary elementwise ops memory test
Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
Args:
rank: device rank
bias: indicate whether conv module need bias
world_size: number of devices
port: port for initializing process group
"""
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = BinaryElementwiseOpModule(token=torch.add, shape=1024).cuda()
input = torch.rand(32, 1024).cuda()
input.requires_grad = True
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
# index of target node in computation graph
node_index = 2
# total number of target node strategies
strategy_number = 9
mem_test_for_node_strategy(rank=rank,
model=model,
device_mesh=device_mesh,
node_index=node_index,
strategy_number=strategy_number,
input_args=[input],
meta_arg_names=['input'])
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_binary_elementwise_meta_concrete_info_match():
world_size = 4
run_func_module = partial(_binary_elementwise_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
if __name__ == '__main__':
test_binary_elementwise_meta_concrete_info_match()

View File

@ -35,9 +35,9 @@ def _conv_module_mem_test(rank, bias, world_size, port):
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
# index of conv node in computation graph
# index of target node in computation graph
node_index = 1
# total number of conv strategies
# total number of target node strategies
strategy_number = 16
mem_test_for_node_strategy(rank=rank,
model=model,

View File

@ -34,9 +34,9 @@ def _adaptiveavgpool_module_mem_test(rank, world_size, port):
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
# index of conv node in computation graph
# index of target node in computation graph
node_index = 1
# total number of conv strategies
# total number of target strategies
strategy_number = 1
mem_test_for_node_strategy(rank=rank,
model=model,
@ -75,9 +75,9 @@ def _maxpool_module_mem_test(rank, world_size, port):
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
# index of conv node in computation graph
# index of target node in computation graph
node_index = 1
# total number of conv strategies
# total number of target node strategies
strategy_number = 9
mem_test_for_node_strategy(rank=rank,
model=model,