diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
index c659cd9ac..faeed9f29 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
@@ -1,124 +1,85 @@
-from typing import List, Tuple
+from typing import Callable, List, Tuple
 
 import torch
 
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
 from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import flop_mapping
+from colossalai.fx.profiler.opcount import elementwise_flop_counter
 
 from ..registry import meta_register
 
-__all__ = ["relu_meta_info"]
+__all__ = ["elementwise_meta_info"]
 
 
-@meta_register.register(torch.nn.ReLU)
-def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
-    """torch.nn.ReLU metainfo generator
-    The aten graph of torch.nn.ReLU is
-    graph():
-    %input_2 : [#users=1] = placeholder[target=placeholder](default=)
-    %relu_default : [#users=2] = call_function[target=torch.ops.aten.relu.default](args = (%input_2,), kwargs = {})
-    %zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%relu_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
-    %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%relu_default,), kwargs = {})
-    %threshold_backward_default : [#users=1] = call_function[target=torch.ops.aten.threshold_backward.default](args = (%zeros_like_default, %detach_default, None), kwargs = {})
-    %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%threshold_backward_default,), kwargs = {})
-    %detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
+def elementwise_meta_info(temp_mem_scale: float = 0, buffer_mem_scale: float = 0) -> Callable:
+    """This is a function to create the meta information generator for elementwise operations
+
+    Args:
+        temp_mem_scale (float, optional): temp memory scaling factor for backward. Defaults to 0.
+        buffer_mem_scale (float, optional): buffer memory scaling factor for forward. Defaults to 0.
 
     Returns:
-        Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
+        Callable: meta information generator
     """
-    input_tensor = args[0].data
-    output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
-    is_inplace = kwargs.get("inplace", False)
 
+    def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
+        input_tensor = next(
+            filter(
+                lambda x:
+                (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM) and x.name != 'softmax_dim',
+                args)).data
+        output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
+        is_inplace = 1 if kwargs.get('inplace', False) else 0
 
-    # construct input args for forward
-    fwd_in_args = [input_tensor]
+        flop_counter = elementwise_flop_counter(1, 0)
+        # calculate compute cost
+        fwd_compute_cost = flop_counter([input_tensor], [output_tensor])
+        bwd_compute_cost = flop_counter([output_tensor], [input_tensor])
 
-    # construct input args for backward
-    bwd_in_args = [output_tensor]
+        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
+                                      bwd=bwd_compute_cost,
+                                      total=fwd_compute_cost + bwd_compute_cost)
 
-    # calculate cost
-    # the fwd op with compute cost is relu.default
-    # the bwd op with compute cost is threshold_backward
+        # calculate memory cost
+        # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
+        # NOTE: if in_place is True, we will not create a new tensor in forward
+        fwd_memory_cost = MemoryCost(activation=activation_size(input_tensor) * (2 - is_inplace),
+                                     parameter=0,
+                                     temp=0,
+                                     buffer=activation_size(input_tensor) * buffer_mem_scale)
 
-    # calculate compute cost
-    fwd_compute_cost = flop_mapping[torch.ops.aten.relu.default](fwd_in_args, (output_tensor,))
-    bwd_compute_cost = flop_mapping[torch.ops.aten.threshold_backward.default](bwd_in_args, (input_tensor,))
-    compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
+        # temp_mem_scale is for situation like softmax backward
+        # the buffer will be removed during backward phase
+        bwd_memory_cost = MemoryCost(
+            activation=activation_size(input_tensor) - activation_size(input_tensor) * buffer_mem_scale,
+            parameter=0,
+            temp=activation_size(input_tensor) * temp_mem_scale + activation_size(input_tensor) * buffer_mem_scale,
+            buffer=0)
 
-    # calculate memory cost
-    # NOTE: the inplace ReLU don't have forward memory cost
-    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
-    fwd_memory_cost = MemoryCost(
-        activation=activation_size(input_tensor) if is_inplace else activation_size([output_tensor, input_tensor]),
-        parameter=0,
-        temp=0,
-        buffer=0)
+        # total cost is the sum of forward and backward cost
+        total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+                                parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+                                temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
+                                buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer)
 
-    bwd_memory_cost = MemoryCost(activation=activation_size(input_tensor), parameter=0, temp=0, buffer=0)
+        memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
 
-    # total cost is the sum of forward and backward cost
-    total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
-                            parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
+        # store fwd_in, fwd_buffer, fwd_out
+        fwd_in = []
+        fwd_buffer = [torch.zeros_like(output_tensor, device='meta')]
+        fwd_out = [torch.zeros_like(output_tensor, device='meta')]
 
-    memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
+        return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
 
-    # store fwd_in, fwd_buffer, fwd_out
-    # NOTE: It might seems a little bit weird here, we just want to align it with the older version
-    # of MetaInfoProp. In the future we might modify this part to make it clearer.
-    fwd_in = []
-    fwd_buffer = [torch.zeros_like(output_tensor, device='meta')]
-    fwd_out = [torch.zeros_like(output_tensor, device='meta')]
-
-    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
+    return meta_func
 
 
-@meta_register.register(torch.nn.Softmax)
-@meta_register.register(torch.nn.functional.softmax)
-def softmax_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
-    """torch.nn.Softmax metainfo generator
-    Returns:
-        Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
-    """
-    input_tensor = next(
-        filter(
-            lambda x:
-            (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM) and x.name != 'softmax_dim',
-            args)).data
-    output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
-    softmax_dim = next(filter(lambda x: x.name == 'softmax_dim', args)).data
+# register meta information
+# (0, 0)
+meta_register.register([torch.nn.ReLU, torch.nn.functional.relu, torch.tanh])(elementwise_meta_info(0, 0))
 
-    # calculate cost
+# (1, 0)
+meta_register.register([torch.nn.Softmax, torch.nn.functional.softmax])(elementwise_meta_info(1, 0))
 
-    # calculate compute cost
-    fwd_compute_cost = flop_mapping[torch.ops.aten._softmax.default]([input_tensor], [output_tensor])
-    bwd_compute_cost = flop_mapping[torch.ops.aten._softmax_backward_data.default]([output_tensor], [input_tensor])
-
-    compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
-
-    # calculate memory cost
-    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
-    fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
-                                 parameter=0,
-                                 temp=0,
-                                 buffer=0)
-    bwd_memory_cost = MemoryCost(activation=activation_size(input_tensor),
-                                 parameter=0,
-                                 temp=activation_size(input_tensor),
-                                 buffer=0)
-
-    # total cost is the sum of forward and backward cost
-    total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
-                            parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
-                            temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
-                            buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer)
-
-    memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
-
-    # store fwd_in, fwd_buffer, fwd_out
-    fwd_in = []
-    fwd_buffer = [torch.zeros_like(output_tensor, device='meta')]
-    fwd_out = [torch.zeros_like(output_tensor, device='meta')]
-
-    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
+# (0, 0.25) for dropout, the buffer is in bool type so that the buffer memory cost is 0.25 times of input tensor
+meta_register.register([torch.nn.Dropout, torch.nn.functional.dropout])(elementwise_meta_info(0, 0.25))
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py
index b9b42f8c1..e41ac4fa6 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py
@@ -17,51 +17,15 @@ from colossalai.utils import free_port
 from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results
 
 
-def _ReLU_module_mem_test(rank, world_size, port):
-    """This function is for ReLU memory test
-    Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
-
-    Args:
-    Args:
-        rank: device rank
-        bias: indicate whether conv module need bias
-        world_size: number of devices
-        port: port for initializing process group
-    """
-    disable_existing_loggers()
-    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    model = nn.Sequential(nn.ReLU()).cuda()
-    input = torch.rand(4, 128, 64, 64).cuda()
-    input.requires_grad = True
-    physical_mesh_id = torch.arange(0, 4)
-    mesh_shape = (2, 2)
-    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-
-    # index of target node in computation graph
-    node_index = 1
-    # total number of target node strategies
-    strategy_number = 1
-    mem_test_for_node_strategy(rank=rank,
-                               model=model,
-                               device_mesh=device_mesh,
-                               node_index=node_index,
-                               strategy_number=strategy_number,
-                               input_args=[input],
-                               meta_arg_names=['input'])
-
-
-@run_on_environment_flag(name='AUTO_PARALLEL')
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_ReLU_meta_concrete_info_match():
-    world_size = 4
-    run_func_module = partial(_ReLU_module_mem_test, world_size=world_size, port=free_port())
-    mp.spawn(run_func_module, nprocs=world_size)
-
-
 @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
-def test_sofmax_meta_info():
-    meta_func = meta_register.get(torch.nn.functional.softmax)
+@parameterize('func', [
+    torch.nn.functional.softmax,
+    torch.nn.functional.relu,
+    torch.tanh,
+    torch.nn.functional.dropout,
+])
+def test_activation_meta_info(func):
+    meta_func = meta_register.get(func)
 
     # construct meta tensors
     input_tensor = torch.rand(256, 1024, device="meta")
     output_tensor = torch.rand(256, 1024, device="meta")
@@ -87,7 +51,7 @@ def test_sofmax_meta_info():
     # fwd
     torch.cuda.reset_peak_memory_stats()
     mem_stamp0 = torch.cuda.memory_allocated()
-    output_real_tensor = torch.nn.functional.softmax(input_real_tensor, dim=softmax_dim)
+    output_real_tensor = func(input_real_tensor)
 
     fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
     fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
@@ -104,5 +68,4 @@
 
 
 if __name__ == '__main__':
-    # test_ReLU_meta_concrete_info_match()
-    test_sofmax_meta_info()
+    test_activation_meta_info()
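Reviewer note: as a rough sanity check of the new memory model, the sketch below re-derives by hand the byte counts that `elementwise_meta_info(temp_mem_scale=0, buffer_mem_scale=0.25)` (the dropout registration) would report for a float32 tensor of shape (256, 1024), the shape used in the updated test. The helper `estimate_elementwise_mem` is hypothetical and exists only for illustration; it simply mirrors the arithmetic in the patch and does not call into colossalai.

```python
# Hypothetical helper mirroring the fwd/bwd MemoryCost arithmetic of
# elementwise_meta_info; illustration only, no colossalai imports needed.
def estimate_elementwise_mem(numel: int, elem_size: int, temp_mem_scale: float,
                             buffer_mem_scale: float, inplace: bool = False) -> dict:
    size = numel * elem_size                # bytes held by the input tensor
    is_inplace = 1 if inplace else 0
    fwd = {
        "activation": size * (2 - is_inplace),         # input + output (input only when inplace)
        "buffer": size * buffer_mem_scale,             # e.g. dropout's bool mask (1/4 of fp32)
    }
    bwd = {
        "activation": size - size * buffer_mem_scale,  # the buffer is released during backward
        "temp": size * temp_mem_scale + size * buffer_mem_scale,
    }
    return {"fwd": fwd, "bwd": bwd}


# 256 * 1024 fp32 elements = 1 MiB, so dropout with (0, 0.25) gives:
# fwd activation 2 MiB, fwd buffer 0.25 MiB, bwd activation 0.75 MiB, bwd temp 0.25 MiB
print(estimate_elementwise_mem(256 * 1024, 4, temp_mem_scale=0, buffer_mem_scale=0.25))
```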