diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
index ff67d0083..ee42807af 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
@@ -19,10 +19,13 @@ from ..registry import meta_register
 __all__ = ['linear_meta_info']
 
 
+@meta_register.register(torch.nn.functional.linear)
 @meta_register.register(torch.nn.Linear)
 def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
-    """torch.nn.Linear meta info generator
-    The atens graph of torch.nn.Linear with bias is
+    """torch.nn.Linear & torch.nn.functional.linear meta info generator
+    NOTE: currently we separate the bias part from the biased linear ops; its memory consumption is covered by the add metainfo generator,
+    but we keep the bias mechanism in the linear metainfo generator for future use.
+
     graph():
     %input_2 : [#users=2] = placeholder[target=placeholder](default=)
     %addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (None, %input_2, None), kwargs = {})
@@ -65,7 +68,7 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
     has_bias: bool = False
     input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
-    weight_tensor = next(filter(lambda x: x.name == 'weight', args)).data
+    weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM]
 
     # process the dimension of input and output
     if len(input_tensor.shape) > 2:
@@ -76,9 +79,14 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
         output_tensor: torch.Tensor
         output_tensor = output_tensor.view(-1, output_tensor.shape[-1])
 
-    if len(args) == 4:
-        bias_tensor = next(filter(lambda x: x.name == 'bias', args)).data
+    if len(weight_tensors) > 1:
         has_bias = True
+        if len(weight_tensors[0].shape) == 2:
+            weight_tensor, bias_tensor = weight_tensors
+        else:
+            bias_tensor, weight_tensor = weight_tensors
+    else:
+        weight_tensor = weight_tensors[0]
 
     if has_bias:
         # calculate cost with bias
diff --git a/colossalai/auto_parallel/meta_profiler/metainfo.py b/colossalai/auto_parallel/meta_profiler/metainfo.py
index 4ea427f49..bec21818f 100644
--- a/colossalai/auto_parallel/meta_profiler/metainfo.py
+++ b/colossalai/auto_parallel/meta_profiler/metainfo.py
@@ -92,8 +92,12 @@ class MetaInfo:
         """
         Compute meta info based on sharding strategy and the given target function.
""" - assert meta_register.has(self._target.__class__), f'{self._target.__class__} not found in the meta registry' - meta_func = meta_register.get(self._target.__class__) + try: + # module + meta_func = meta_register.get(self._target.__class__) + except: + # function + meta_func = meta_register.get(self._target) # construct args for meta_func args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()] diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py index bdd622c5f..f7fc88884 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py @@ -20,7 +20,17 @@ if torch.__version__ >= '1.12.0': from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register -def _linear_module_mem_test(rank, bias, world_size, port): +class MyModule(nn.Module): + + def __init__(self, in_features=64, out_features=128): + super().__init__() + self.fc_weight = nn.Parameter(torch.randn(out_features, in_features)) + + def forward(self, input): + return nn.functional.linear(input, self.fc_weight) + + +def _linear_module_mem_test(rank, world_size, port): """This function is for linear memory test Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL @@ -32,7 +42,7 @@ def _linear_module_mem_test(rank, bias, world_size, port): """ disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - model = nn.Sequential(nn.Linear(64, 128, bias=bias)).cuda() + model = nn.Sequential(nn.Linear(64, 128, bias=False)).cuda() input = torch.rand(8, 8, 16, 64).cuda() input.requires_grad = True physical_mesh_id = torch.arange(0, 4) @@ -52,11 +62,50 @@ def _linear_module_mem_test(rank, bias, world_size, port): @run_on_environment_flag(name='AUTO_PARALLEL') @pytest.mark.dist @rerun_if_address_is_in_use() -def test_linear_meta_concrete_info_match(bias=False): +def test_linear_module_meta_concrete_info_match(): + world_size = 4 + run_func_module = partial(_linear_module_mem_test, world_size=world_size, port=free_port()) + mp.spawn(run_func_module, nprocs=world_size) + + +def _linear_function_mem_test(rank, world_size, port): + """This function is for linear memory test + Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL + + Args: + rank: device rank + bias: indicate whether linear module need bias + world_size: number of devices + port: port for initializing process group + """ + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + model = MyModule().cuda() + input = torch.rand(8, 8, 16, 64).cuda() + input.requires_grad = True + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # memory test + mem_test_for_node_strategy(rank=rank, + model=model, + device_mesh=device_mesh, + node_index=2, + strategy_number=13, + input_args=[input], + meta_arg_names=["input"]) + + +@run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_linear_function_meta_concrete_info_match(): world_size = 4 - run_func_module = partial(_linear_module_mem_test, bias=bias, 
world_size=world_size, port=free_port()) + run_func_module = partial(_linear_function_mem_test, world_size=world_size, port=free_port()) mp.spawn(run_func_module, nprocs=world_size) if __name__ == '__main__': - test_linear_meta_concrete_info_match() + # test_linear_module_meta_concrete_info_match() + test_linear_function_meta_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py index 04d589ab3..7c06f2ee9 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py @@ -7,6 +7,7 @@ from torch.fx import GraphModule from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh from colossalai.fx.tracer.tracer import ColoTracer @@ -49,8 +50,9 @@ def mem_test_for_node_strategy(rank: int, # construct the strategy for the output node placeholder_strategy = list(graph.nodes)[-1].strategies_vector[0] + output_key = next(key for key in target_node.strategies_vector[strategy_index].sharding_specs.keys() - if key in placeholder_strategy.sharding_specs) + if key.type == OperationDataType.OUTPUT) placeholder_strategy.sharding_specs[output_key] = target_node.strategies_vector[strategy_index].sharding_specs[ output_key] @@ -104,8 +106,12 @@ def mem_test_for_node_strategy(rank: int, ) # estimated memory - metainfo = MetaInfo(target_node.strategies_vector[strategy_index], - target_node.graph.owning_module.get_submodule(target_node.target)) + if target_node.op == "call_module": + metainfo = MetaInfo(target_node.strategies_vector[strategy_index], + target_node.graph.owning_module.get_submodule(target_node.target)) + else: + metainfo = MetaInfo(target_node.strategies_vector[strategy_index], target_node.target) + print("estimated memory:") print( f"forward activation: {metainfo.memory_cost.fwd.activation / 1024} kb, forward param: {metainfo.memory_cost.fwd.parameter / 1024} kb"
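
For readers following the new PARAM handling in linear_meta_info, below is a minimal standalone sketch (plain PyTorch, not ColossalAI code; the helper name split_weight_and_bias is hypothetical) of the rank-based disambiguation the diff introduces: with the bias split out, weight and bias both arrive as PARAM operands and may come in either order, so the 2-D tensor is taken as the weight and the 1-D tensor, if present, as the bias.

import torch

def split_weight_and_bias(param_tensors):
    # Mirrors the branch added in linear_meta_info: more than one PARAM operand
    # means a bias is present; the 2-D tensor is the weight (out_features, in_features),
    # the 1-D tensor is the bias (out_features,).
    if len(param_tensors) > 1:
        if len(param_tensors[0].shape) == 2:
            weight, bias = param_tensors
        else:
            bias, weight = param_tensors
        return weight, bias
    return param_tensors[0], None

# Usage: order-independent, matching how the operands may be collected.
w, b = split_weight_and_bias([torch.empty(128), torch.empty(128, 64)])
assert w.shape == (128, 64) and b.shape == (128,)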