[autoparallel] Add F.conv metainfo (#2069)

* [fx] metainfo class for auto parallel

* [fx] add unit test for linear metainfo

* [fx] fix bwd param for linear

* [fx] modify unit test

* [fx] modify unit test

* [fx] modify import

* [fx] modify import

* [fx] modify import

* [fx] move meta profiler to auto parallel

* [fx] add conv metainfo class

* [fx] restore profiler

* [fx] restore meta profiler

* [autoparallel] modify unit test

* [fx] modify unit test

* [autoparallel] add batchnorm metainfo class

* [autoparallel] fix batchnorm unit test function declaration

* [fx] restore profiler

* [fx] add relu metainfo class

* [fx] restore profiler

* [autoparallel] modify metainfo input

* [autoparallel] add pooling metainfo

* [autoparallel] add F.linear metainfo generator

* [autoparallel] add binary elementwise metainfo

* [fx] recover profiler

* [autoparallel] fix forward memory calculation

* [autoparallel] modify constants.py

* [autoparallel] remove redundant print

* [autoparallel] add F.conv metainfo

* [autoparallel] linear fix
Boyuan Yao 2022-12-06 10:17:57 +08:00 committed by GitHub
parent f123476666
commit cf0268da93
4 changed files with 68 additions and 6 deletions

View File

@@ -22,6 +22,9 @@ __all__ = ['convnd_meta_info']
 @meta_register.register(torch.nn.Conv1d)
 @meta_register.register(torch.nn.Conv2d)
 @meta_register.register(torch.nn.Conv3d)
+@meta_register.register(torch.nn.functional.conv1d)
+@meta_register.register(torch.nn.functional.conv2d)
+@meta_register.register(torch.nn.functional.conv3d)
 def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
     """torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d meta info generator
     The atens graph of torch.nn.Convnd with bias is
@@ -57,12 +60,19 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
     has_bias: bool = False
     input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
-    weight_tensor = next(filter(lambda x: x.name == 'weight', args)).data
+    weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM]

     # check if conv has bias
-    if len(args) == 4:
-        bias_tensor = next(filter(lambda x: x.name == 'bias', args)).data
+    if len(weight_tensors) > 1:
         has_bias = True
+
+        # bias tensor's shape only has one dimension
+        if len(weight_tensors[0].shape) == 1:
+            bias_tensor, weight_tensor = weight_tensors
+        else:
+            weight_tensor, bias_tensor = weight_tensors
+    else:
+        weight_tensor = weight_tensors[0]

     # construct input args for forward
     fwd_args = [None] * 9
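
Because the functional variants hand their weight and bias to the generator as plain PARAM tensors rather than as named module attributes, the hunk above tells the two apart by dimensionality: a conv bias is always 1-D, while the weight carries the full (out_channels, in_channels, *kernel_size) shape. A minimal standalone sketch of that check (split_conv_params is a hypothetical helper written for illustration, not part of the library):

import torch

# For a 2D convolution the weight is 4-D (out_channels, in_channels, kH, kW)
# while the bias is 1-D (out_channels), so two PARAM tensors can be told
# apart by their number of dimensions.
weight = torch.randn(64, 4, 3, 3)
bias = torch.randn(64)


def split_conv_params(param_tensors):
    """Hypothetical helper mirroring the ndim check in convnd_meta_info."""
    if len(param_tensors) > 1:
        first, second = param_tensors
        # the order of the PARAM entries is not guaranteed, so inspect ndim
        return (second, first) if first.dim() == 1 else (first, second)
    return param_tensors[0], None


w, b = split_conv_params([bias, weight])
assert w.shape == (64, 4, 3, 3) and b.shape == (64,)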

View File

@@ -143,7 +143,7 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
         # NOTE: Linear don't have buffer and temp in forward and backward phase
         # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor
         # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
-        fwd_memory_cost = MemoryCost(activation=activation_size(output_tensor),
+        fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
                                      parameter=activation_size(weight_tensor),
                                      temp=0,
                                      buffer=0)
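
The linear fix counts the input tensor as part of the forward activation in addition to the output. A rough illustration of the effect, assuming for this sketch that activation_size simply sums numel() * element_size() over the tensors it is given (an assumption for illustration, not a statement about the library internals):

import torch


def tensor_bytes(tensors):
    # rough stand-in for activation_size: numel * element size, summed over
    # a tensor or a list of tensors (an assumption made for this sketch)
    if isinstance(tensors, torch.Tensor):
        tensors = [tensors]
    return sum(t.numel() * t.element_size() for t in tensors)


input_tensor = torch.empty(4, 128, device='meta')
output_tensor = torch.empty(4, 256, device='meta')

# before the fix only the output was counted as forward activation;
# afterwards the input is accounted for as well
old_activation = tensor_bytes(output_tensor)
new_activation = tensor_bytes([input_tensor, output_tensor])
print(old_activation, new_activation)    # 4096 vs 6144 bytes for fp32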

View File

@@ -15,6 +15,16 @@ from colossalai.utils import free_port
 from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy


+class ConvFunctionModule(nn.Module):
+
+    def __init__(self, in_channels=4, out_channels=64, kernel_size=3):
+        super().__init__()
+        self.conv_weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
+
+    def forward(self, input):
+        return nn.functional.conv2d(input, self.conv_weight)
+
+
 def _conv_module_mem_test(rank, bias, world_size, port):
     """This function is for conv memory test
     Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
@@ -57,5 +67,47 @@ def test_conv_meta_concrete_info_match(bias=False):
     mp.spawn(run_func_module, nprocs=world_size)


+def _conv_function_mem_test(rank, world_size, port):
+    """This function is for conv function memory test
+    Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
+
+    Args:
+        rank: device rank
+        bias: indicate whether conv module need bias
+        world_size: number of devices
+        port: port for initializing process group
+    """
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    model = ConvFunctionModule().cuda()
+    input = torch.rand(4, 4, 64, 64).cuda()
+    input.requires_grad = True
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+
+    # index of target node in computation graph
+    node_index = 2
+    # total number of target node strategies
+    strategy_number = 16
+    mem_test_for_node_strategy(rank=rank,
+                               model=model,
+                               device_mesh=device_mesh,
+                               node_index=node_index,
+                               strategy_number=strategy_number,
+                               input_args=[input],
+                               meta_arg_names=['input'])
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_conv_function_concrete_info_match():
+    world_size = 4
+    run_func_module = partial(_conv_function_mem_test, world_size=world_size, port=free_port())
+    mp.spawn(run_func_module, nprocs=world_size)
+
+
 if __name__ == '__main__':
-    test_conv_meta_concrete_info_match()
+    # test_conv_meta_concrete_info_match()
+    test_conv_function_concrete_info_match()
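
The new test drives F.conv2d through a module that holds a raw nn.Parameter as its weight. Outside the distributed harness, a quick standalone sanity check of that pattern (written for this note, not part of the test file) is that the functional call matches an nn.Conv2d configured without bias:

import torch
import torch.nn as nn

# sanity check: a functional conv2d on a raw weight parameter matches an
# nn.Conv2d module configured without bias and the same weight
torch.manual_seed(0)
weight = nn.Parameter(torch.randn(64, 4, 3, 3))
module_conv = nn.Conv2d(4, 64, kernel_size=3, bias=False)
with torch.no_grad():
    module_conv.weight.copy_(weight)

x = torch.rand(4, 4, 64, 64)
functional_out = nn.functional.conv2d(x, weight)
module_out = module_conv(x)
assert torch.allclose(functional_out, module_out)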

View File

@@ -92,7 +92,7 @@ def _linear_function_mem_test(rank, world_size, port):
                                model=model,
                                device_mesh=device_mesh,
                                node_index=2,
-                               strategy_number=13,
+                               strategy_number=23,
                                input_args=[input],
                                meta_arg_names=["input"])
