diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py
index 63d6cdc39..f7d55529f 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py
@@ -22,6 +22,9 @@ __all__ = ['convnd_meta_info']
 @meta_register.register(torch.nn.Conv1d)
 @meta_register.register(torch.nn.Conv2d)
 @meta_register.register(torch.nn.Conv3d)
+@meta_register.register(torch.nn.functional.conv1d)
+@meta_register.register(torch.nn.functional.conv2d)
+@meta_register.register(torch.nn.functional.conv3d)
 def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
     """torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d meta info generator
     The atens graph of torch.nn.Convnd with bias is
@@ -57,12 +60,19 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
     has_bias: bool = False
     input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
-    weight_tensor = next(filter(lambda x: x.name == 'weight', args)).data
+    weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM]
 
     # check if conv has bias
-    if len(args) == 4:
-        bias_tensor = next(filter(lambda x: x.name == 'bias', args)).data
+    if len(weight_tensors) > 1:
         has_bias = True
+        # bias tensor's shape only has one dimension
+        if len(weight_tensors[0].shape) == 1:
+            bias_tensor, weight_tensor = weight_tensors
+        else:
+            weight_tensor, bias_tensor = weight_tensors
+
+    else:
+        weight_tensor = weight_tensors[0]
 
     # construct input args for forward
     fwd_args = [None] * 9
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
index 76ed48674..b48748fa9 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
@@ -143,7 +143,7 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
     # NOTE: Linear don't have buffer and temp in forward and backward phase
-    # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor
+    # the forward activation cost is the size of input_tensor and output_tensor, parameter cost is the size of weight_tensor
     # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
-    fwd_memory_cost = MemoryCost(activation=activation_size(output_tensor),
+    fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
                                  parameter=activation_size(weight_tensor),
                                  temp=0,
                                  buffer=0)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py
index 303c40fdf..a973a8182 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py
@@ -15,6 +15,16 @@ from colossalai.utils import free_port
 from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
 
 
+class ConvFunctionModule(nn.Module):
+
+    def __init__(self, in_channels=4, out_channels=64, kernel_size=3):
+        super().__init__()
+        self.conv_weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
+
+    def forward(self, input):
+        return nn.functional.conv2d(input, self.conv_weight)
+
+
 def _conv_module_mem_test(rank, bias, world_size, port):
     """This function is for conv memory test
     Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
@@ -57,5 +67,46 @@ def test_conv_meta_concrete_info_match(bias=False):
     mp.spawn(run_func_module, nprocs=world_size)
 
 
+def _conv_function_mem_test(rank, world_size, port):
+    """This function is for conv function memory test
+    Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
+
+    Args:
+        rank: device rank
+        world_size: number of devices
+        port: port for initializing process group
+    """
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    model = ConvFunctionModule().cuda()
+    input = torch.rand(4, 4, 64, 64).cuda()
+    input.requires_grad = True
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+
+    # index of target node in computation graph
+    node_index = 2
+    # total number of target node strategies
+    strategy_number = 16
+    mem_test_for_node_strategy(rank=rank,
+                               model=model,
+                               device_mesh=device_mesh,
+                               node_index=node_index,
+                               strategy_number=strategy_number,
+                               input_args=[input],
+                               meta_arg_names=['input'])
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_conv_function_concrete_info_match():
+    world_size = 4
+    run_func_module = partial(_conv_function_mem_test, world_size=world_size, port=free_port())
+    mp.spawn(run_func_module, nprocs=world_size)
+
+
 if __name__ == '__main__':
-    test_conv_meta_concrete_info_match()
+    # test_conv_meta_concrete_info_match()
+    test_conv_function_concrete_info_match()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py
index f7fc88884..62fe11e22 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py
@@ -92,7 +92,7 @@ def _linear_function_mem_test(rank, world_size, port):
                                model=model,
                                device_mesh=device_mesh,
                                node_index=2,
-                               strategy_number=13,
+                               strategy_number=23,
                                input_args=[input],
                                meta_arg_names=["input"])
 
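
Note: the subtle part of the conv.py change is the weight/bias disambiguation. A functional conv call contributes its parameters as PARAM operands without a guaranteed order, so the patch identifies the bias as the tensor whose shape has exactly one dimension, (out_channels,). Below is a minimal standalone sketch of that rule for reference; split_weight_and_bias is a hypothetical helper written for illustration, not ColossalAI API, and it mirrors rather than reuses the patch's logic.

import torch

def split_weight_and_bias(param_tensors):
    """Return (weight, bias); bias is None when the conv has no bias."""
    if len(param_tensors) > 1:
        # the bias tensor is the one whose shape has exactly one
        # dimension, i.e. (out_channels,); operand order is not guaranteed
        if len(param_tensors[0].shape) == 1:
            bias, weight = param_tensors
        else:
            weight, bias = param_tensors
        return weight, bias
    return param_tensors[0], None

weight = torch.randn(64, 4, 3, 3)    # conv2d weight: (out_channels, in_channels, kH, kW)
bias = torch.randn(64)               # conv2d bias: (out_channels,)

# the rule holds regardless of the order in which the PARAM operands arrive
w, b = split_weight_and_bias([bias, weight])
assert w is weight and b is bias
w, b = split_weight_and_bias([weight, bias])
assert w is weight and b is bias

Matching on dimensionality rather than on the operand's name is what lets the same meta-info generator serve both the module registrations (torch.nn.ConvNd) and the new functional ones (torch.nn.functional.convNd).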