diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
index 15c3063b7..281a92c0d 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
@@ -24,26 +24,25 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
         Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
     """
-    input_op_data, other_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
+    input_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
     output_op_data = next(filter(lambda arg: arg.type == OperationDataType.OUTPUT, args))
 
     # construct forward args for flop mapping
-    fwd_in_args = [input_op_data.data, other_op_data.data]
+    fwd_in_args = [opdata.data for opdata in input_op_data]
     fwd_out_args = [output_op_data.data]
 
     # calculate cost
 
     # calculate compute cost
     # NOTE: we set bwd_compute_cost two times of fwd_compute_cost in this case
-    fwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d.default](fwd_in_args, fwd_out_args)
+    fwd_compute_cost = flop_mapping[torch.ops.aten.add.Tensor](fwd_in_args, fwd_out_args)
     bwd_compute_cost = fwd_compute_cost * 2
     compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
 
     # calculate memory cost
-    param_mem_cost = activation_size(
-        [arg.data for arg in [input_op_data, other_op_data] if arg.type == OperationDataType.PARAM])
+    param_mem_cost = activation_size([arg.data for arg in input_op_data if arg.type == OperationDataType.PARAM])
     fwd_mem_cost = MemoryCost(
-        activation=activation_size([input_op_data.data, output_op_data.data]),
+        activation=activation_size(output_op_data.data),
         parameter=param_mem_cost,
     )
     bwd_mem_cost = MemoryCost(
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
index af3cb5810..78dc58c90 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
@@ -4,7 +4,7 @@ from typing import Dict, List, Tuple, Union
 import torch
 from torch.fx.node import Node
 
-from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
+from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
     OperationDataType,
@@ -234,15 +234,19 @@ class MetaInfoNodeHandler(NodeHandler):
         """
         super().register_strategy(compute_resharding_cost=compute_resharding_cost)
         target = self.get_target_function()
-        metainfo_vector = []
-        for strategy in self.strategies_vector:
-            metainfo = MetaInfo(strategy, target)
-            strategy.compute_cost = metainfo.compute_cost
-            strategy.memory_cost = metainfo.memory_cost
-            metainfo_vector.append(metainfo)
-
-        # attach metainfos to the handler
-        setattr(self, "metainfo_vector", metainfo_vector)
+        # Currently we haven't patched all the torch functions and modules, so if the target
+        # is not patched, we will use the default cost model to compute the cost.
+        # TODO: patch all torch functions and modules to make it clean
+        if meta_register.has(target.__class__) or meta_register.has(target):
+            metainfo_vector = []
+            for strategy in self.strategies_vector:
+                metainfo = MetaInfo(strategy, target)
+                strategy.compute_cost = metainfo.compute_cost
+                strategy.memory_cost = metainfo.memory_cost
+                metainfo_vector.append(metainfo)
+
+            # attach metainfos to the handler
+            setattr(self, "metainfo_vector", metainfo_vector)
 
         return self.strategies_vector
 
@@ -281,14 +285,18 @@ class MetaInfoModuleHandler(ModuleHandler):
         """
         super().register_strategy(compute_resharding_cost=compute_resharding_cost)
         target = self.get_target_function()
-        metainfo_vector = []
-        for strategy in self.strategies_vector:
-            metainfo = MetaInfo(strategy, target)
-            strategy.compute_cost = metainfo.compute_cost
-            strategy.memory_cost = metainfo.memory_cost
-            metainfo_vector.append(metainfo)
-
-        # attach metainfos to the handler
-        setattr(self, "metainfo_vector", metainfo_vector)
+        # Currently we haven't patched all the torch functions and modules, so if the target
+        # is not patched, we will use the default cost model to compute the cost.
+        # TODO: patch all torch functions and modules to make it clean
+        if meta_register.has(target.__class__) or meta_register.has(target):
+            metainfo_vector = []
+            for strategy in self.strategies_vector:
+                metainfo = MetaInfo(strategy, target)
+                strategy.compute_cost = metainfo.compute_cost
+                strategy.memory_cost = metainfo.memory_cost
+                metainfo_vector.append(metainfo)
+
+            # attach metainfos to the handler
+            setattr(self, "metainfo_vector", metainfo_vector)
 
         return self.strategies_vector
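Why the first hunk drops the tuple unpacking: `input_op_data, other_op_data = [...]` assumes exactly two non-output operands, so any elementwise call traced with a different operand count raises a ValueError before any cost is computed; keeping the list handles every arity. A minimal sketch of the failure mode, using a hypothetical OperandStub in place of ColossalAI's OperationData:

    from dataclasses import dataclass

    @dataclass
    class OperandStub:
        """Hypothetical stand-in for an OperationData record."""
        name: str
        is_output: bool = False

    # Three non-output operands, e.g. an op traced with an extra operand.
    non_outputs = [OperandStub("a"), OperandStub("b"), OperandStub("c")]

    try:
        input_op_data, other_op_data = non_outputs    # old code: exactly 2 assumed
    except ValueError as exc:
        print(f"tuple unpacking fails: {exc}")        # too many values to unpack

    fwd_in_args = [op.name for op in non_outputs]     # new code: any operand count
    print(fwd_in_args)                                # ['a', 'b', 'c']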
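The flop-mapping key also changes: `aten._adaptive_avg_pool2d.default` looks like a leftover from a pooling entry and is simply the wrong key for a binary elementwise op; `aten.add.Tensor` is the representative elementwise entry. The activation change is separable: the forward pass of an elementwise op allocates only its output, and the inputs are presumably already booked as the activations of their producing nodes, so counting one of them here would double-charge memory. A back-of-envelope check with a toy byte counter (the real activation_size lives in ColossalAI's profiler):

    import torch

    def tensor_bytes(t: torch.Tensor) -> int:
        """Toy substitute for activation_size: raw payload bytes."""
        return t.numel() * t.element_size()

    a = torch.empty(32, 1024)    # fp32 operand: 32 * 1024 * 4 = 131072 bytes
    b = torch.empty(32, 1024)    # fp32 operand
    out = a + b                  # fp32 output, also 131072 bytes

    print(tensor_bytes(a) + tensor_bytes(out))   # old accounting: 262144
    print(tensor_bytes(out))                     # new accounting: 131072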
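In node_handler.py, the guard gates MetaInfo construction on registry membership: `target.__class__` presumably covers module targets (entries keyed by the nn.Module subclass) while `target` covers function targets (keyed by the callable itself), and anything unregistered silently keeps the default cost model from `super().register_strategy`. A minimal sketch of that double lookup, assuming a dict-backed registry with a has() method (the toy MetaRegistry below is illustrative, not ColossalAI's implementation):

    import torch
    import torch.nn as nn

    class MetaRegistry:
        """Toy registry keyed by either a callable or a module class."""

        def __init__(self):
            self._store = {}

        def register(self, source):
            def wrapper(func):
                self._store[source] = func
                return func
            return wrapper

        def has(self, source) -> bool:
            return source in self._store

    meta_register = MetaRegistry()

    @meta_register.register(nn.Linear)
    def linear_meta_info(*args, **kwargs):
        ...    # would return compute/memory costs in the real profiler

    module_target = nn.Linear(4, 4)
    print(meta_register.has(module_target.__class__))   # True  -> use MetaInfo costs
    print(meta_register.has(torch.sigmoid))             # False -> default cost model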