From 40c916b1924097b154d611ef4a7177f8c5ebac76 Mon Sep 17 00:00:00 2001
From: Boyuan Yao <70263930+Cypher30@users.noreply.github.com>
Date: Mon, 13 Feb 2023 16:09:22 +0800
Subject: [PATCH] [autoparallel] Patch meta information of `torch.nn.functional.softmax` and `torch.nn.Softmax` (#2674)

* [autoparallel] softmax metainfo

* [autoparallel] softmax metainfo
---
 .../meta_profiler/meta_registry/activation.py | 50 ++++++++++++++++++
 .../test_metainfo/test_activation_metainfo.py | 51 ++++++++++++++++++-
 2 files changed, 99 insertions(+), 2 deletions(-)

diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
index 774457f7d..c659cd9ac 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
@@ -72,3 +72,53 @@ def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Lis
     fwd_out = [torch.zeros_like(output_tensor, device='meta')]
 
     return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
+
+
+@meta_register.register(torch.nn.Softmax)
+@meta_register.register(torch.nn.functional.softmax)
+def softmax_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
+    """torch.nn.Softmax metainfo generator
+    Returns:
+        Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
+    """
+    input_tensor = next(
+        filter(
+            lambda x:
+            (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM) and x.name != 'softmax_dim',
+            args)).data
+    output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
+    softmax_dim = next(filter(lambda x: x.name == 'softmax_dim', args)).data
+
+    # calculate cost
+
+    # calculate compute cost
+    fwd_compute_cost = flop_mapping[torch.ops.aten._softmax.default]([input_tensor], [output_tensor])
+    bwd_compute_cost = flop_mapping[torch.ops.aten._softmax_backward_data.default]([output_tensor], [input_tensor])
+
+    compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
+
+    # calculate memory cost
+    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
+    fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
+                                 parameter=0,
+                                 temp=0,
+                                 buffer=0)
+    bwd_memory_cost = MemoryCost(activation=activation_size(input_tensor),
+                                 parameter=0,
+                                 temp=activation_size(input_tensor),
+                                 buffer=0)
+
+    # total cost is the sum of forward and backward cost
+    total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+                            parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+                            temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
+                            buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer)
+
+    memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
+
+    # store fwd_in, fwd_buffer, fwd_out
+    fwd_in = []
+    fwd_buffer = [torch.zeros_like(output_tensor, device='meta')]
+    fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+
+    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py
index f468b1ab2..b9b42f8c1 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py
@@ -5,6 +5,8 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 
+from colossalai.auto_parallel.meta_profiler import meta_register
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.initialize import launch
@@ -12,7 +14,7 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
 from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
+from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results
 
 
 def _ReLU_module_mem_test(rank, world_size, port):
@@ -57,5 +59,50 @@ def test_ReLU_meta_concrete_info_match():
     mp.spawn(run_func_module, nprocs=world_size)
 
 
+@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
+def test_sofmax_meta_info():
+    meta_func = meta_register.get(torch.nn.functional.softmax)
+    # construct meta tensors
+    input_tensor = torch.rand(256, 1024, device="meta")
+    output_tensor = torch.rand(256, 1024, device="meta")
+    softmax_dim = 0
+
+    # construct operation data
+    input_data = OperationData(name='input', type=OperationDataType.ARG, data=input_tensor)
+    output_data = OperationData(name='output', type=OperationDataType.OUTPUT, data=output_tensor)
+    softmax_dim_data = OperationData(name='softmax_dim', type=OperationDataType.ARG, data=softmax_dim)
+
+    # construct args and kwargs
+    args = [input_data, softmax_dim_data, output_data]
+    kwargs = {'inplace': False}
+
+    # estimated results
+    compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs)
+
+    # actual results
+    input_real_tensor = torch.rand(256, 1024, device="cuda")
+
+    input_real_tensor.requires_grad = True
+
+    # fwd
+    torch.cuda.reset_peak_memory_stats()
+    mem_stamp0 = torch.cuda.memory_allocated()
+    output_real_tensor = torch.nn.functional.softmax(input_real_tensor, dim=softmax_dim)
+    fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
+    fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
+
+    # bwd
+    upstream_grad = torch.rand_like(output_real_tensor)
+    torch.cuda.reset_peak_memory_stats()
+    mem_stamp0 = torch.cuda.memory_allocated()
+    torch.autograd.backward(output_real_tensor, upstream_grad)
+    bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
+    bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
+
+    print_results([input_real_tensor], [output_real_tensor], compute_cost, memory_cost, fwd_allocated, fwd_peak,
+                  bwd_allocated, bwd_peak)
+
+
 if __name__ == '__main__':
-    test_ReLU_meta_concrete_info_match()
+    # test_ReLU_meta_concrete_info_match()
+    test_sofmax_meta_info()
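
Usage sketch (not part of the patch): the snippet below shows how the metainfo generator registered by this patch could be queried and invoked directly, mirroring the call pattern of the added test. The 256 x 1024 shape, the `softmax_dim` value of 0 and the variable names are illustrative assumptions; it presumes a ColossalAI build with this patch applied.

import torch

from colossalai.auto_parallel.meta_profiler import meta_register
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType

# look up the generator that this patch registers for softmax
meta_func = meta_register.get(torch.nn.functional.softmax)

# meta tensors stand in for real activations; the shape here is only an example
input_data = OperationData(name='input', type=OperationDataType.ARG, data=torch.rand(256, 1024, device="meta"))
output_data = OperationData(name='output', type=OperationDataType.OUTPUT, data=torch.rand(256, 1024, device="meta"))
softmax_dim_data = OperationData(name='softmax_dim', type=OperationDataType.ARG, data=0)

# the generator returns compute cost, memory cost and the fwd_in / fwd_buffer / fwd_out lists
compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(input_data, softmax_dim_data, output_data,
                                                                   inplace=False)

print(f"fwd compute cost: {compute_cost.fwd}, bwd compute cost: {compute_cost.bwd}")
print(f"fwd activation memory: {memory_cost.fwd.activation}")

This is the same call pattern the added test uses before comparing the estimate against measured CUDA memory via print_results.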