diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py index 909232e61..774457f7d 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py @@ -30,7 +30,7 @@ def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Lis input_tensor = args[0].data output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data - inplace = kwargs.get("inplace", False) + is_inplace = kwargs.get("inplace", False) # construct input args for forward fwd_in_args = [input_tensor] @@ -51,7 +51,7 @@ def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Lis # NOTE: the inplace ReLU don't have forward memory cost # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward fwd_memory_cost = MemoryCost( - activation=activation_size(input_tensor) if inplace else activation_size([output_tensor, input_tensor]), + activation=activation_size(input_tensor) if is_inplace else activation_size([output_tensor, input_tensor]), parameter=0, temp=0, buffer=0) diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py index 3ecabb6dc..79780c92e 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py @@ -14,6 +14,7 @@ __all__ = ["avgpool_meta_info", "maxpool_meta_info"] @meta_register.register(torch.nn.AdaptiveAvgPool1d) @meta_register.register(torch.nn.AdaptiveAvgPool2d) @meta_register.register(torch.nn.AdaptiveAvgPool3d) +@meta_register.register(torch.flatten) def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: """Meta info for AdaptiveAvgPool The aten graph of AdaptiveAvgPool is @@ -32,6 +33,7 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, input_tensor = args[0].data output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data + is_inplace = kwargs.get("inplace", False) # construct forward args for flop mapping fwd_in_args = [input_tensor] @@ -51,8 +53,8 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) # calculate memory cost - fwd_mem_cost = MemoryCost(activation=activation_size(output_tensor)) - bwd_mem_cost = MemoryCost(activation=activation_size(input_tensor)) + fwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(output_tensor)) + bwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(input_tensor)) # total cost total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation)