diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py index 2bb852dfa..659edf548 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py @@ -64,20 +64,14 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha last_physical_output_dims = output_op_data.data.dim() - 1 if last_logical_input_dims in input_sharding_spec.dim_partition_dict: - update_partition_dim( - sharding_spec=input_sharding_spec, - dim_mapping={last_logical_input_dims: last_physical_input_dims}, - physical_shape=input_op_data.data.shape, - inplace=True, - ) + input_last_dim_mapping = {last_logical_input_dims: last_physical_input_dims} + else: + input_last_dim_mapping = {} if last_logical_output_dims in output_sharding_spec.dim_partition_dict: - update_partition_dim( - sharding_spec=output_sharding_spec, - dim_mapping={last_logical_output_dims: last_physical_output_dims}, - physical_shape=output_op_data.data.shape, - inplace=True, - ) + output_last_dim_mapping = {last_logical_output_dims: last_physical_output_dims} + else: + output_last_dim_mapping = {} # get logger for debug message logger = get_dist_logger() @@ -97,12 +91,18 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name) try: # replace the 0th dimension in the logical sharding with ith dimension in the physical sharding + input_dim_mapping = {0: i} + input_dim_mapping.update(input_last_dim_mapping) + update_partition_dim(sharding_spec=input_sharding_spec, - dim_mapping={0: i}, + dim_mapping=input_dim_mapping, physical_shape=input_op_data.data.shape, inplace=True) + output_dim_mapping = {0: i} + output_dim_mapping.update(output_last_dim_mapping) + update_partition_dim(sharding_spec=output_sharding_spec, - dim_mapping={0: i}, + dim_mapping=output_dim_mapping, physical_shape=output_op_data.data.shape, inplace=True) strategy_copy.name = f'{strategy.name}_{i}' @@ -120,12 +120,17 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name) # after updating, the logical shape will be replaced by the physical shape + input_dim_mapping = {} + input_dim_mapping.update(input_last_dim_mapping) update_partition_dim(sharding_spec=input_sharding_spec, - dim_mapping={}, + dim_mapping=input_dim_mapping, physical_shape=input_op_data.data.shape, inplace=True) + + output_dim_mapping = {} + output_dim_mapping.update(output_last_dim_mapping) update_partition_dim(sharding_spec=output_sharding_spec, - dim_mapping={}, + dim_mapping=output_dim_mapping, physical_shape=output_op_data.data.shape, inplace=True) sharding_strategies.append(strategy_copy) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py index e0130936d..fb8821fae 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py @@ -26,18 +26,21 @@ from colossalai.utils import free_port from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy -def check_linear_module_handler(rank, bias, world_size, port): +def check_linear_module_handler(rank, bias, input_shape, world_size, port): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - input = torch.rand(4, 4, 4, 16).cuda() + input = torch.rand(input_shape).cuda() # the index of linear node in computation graph node_index = 1 # strategy number of linear node - strategy_number = 24 + if input_shape == (1, 4, 4, 16): + strategy_number = 19 + else: + strategy_number = 24 # construct input args input_args = [input] # construct meta arg names @@ -50,7 +53,7 @@ def check_linear_module_handler(rank, bias, world_size, port): meta_arg_names=meta_arg_names) tracer = ColoTracer() - graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 4, 16).to('meta')}) + graph = tracer.trace(model, meta_args={"input": torch.rand(input_shape).to('meta')}) gm = ColoGraphModule(model, graph) linear_mod_node = list(graph.nodes)[1] @@ -69,9 +72,10 @@ def check_linear_module_handler(rank, bias, world_size, port): assert op_data.data is not None assert mapping['input'].name == "input_1" - assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16]) + assert mapping['input'].data.shape == torch.Size(input_shape) assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([64, 16]) + input_logical_shape = mapping['input'].data.view(-1, 16).shape + assert mapping['input'].logical_shape == input_logical_shape assert mapping['other'].name == "weight" assert mapping['other'].data.shape == torch.Size([32, 16]) @@ -85,28 +89,32 @@ def check_linear_module_handler(rank, bias, world_size, port): assert mapping['bias'].logical_shape == torch.Size([32]) assert mapping['output'].name == "_0" - assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32]) + output_shape = input_shape[:-1] + (32,) + assert mapping['output'].data.shape == torch.Size(output_shape) assert mapping['output'].type == OperationDataType.OUTPUT - assert mapping['output'].logical_shape == torch.Size([64, 32]) + output_logical_shape = mapping['output'].data.view(-1, 32).shape + assert mapping['output'].logical_shape == torch.Size(output_logical_shape) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] - # one strategy will be converted to different physical sharding spec - assert len(strategy_name_list) > 8 + + # First dimension cannot be shard if input shape is (1, 4, 4, 16) + if input_shape != (1, 4, 4, 16): + assert 'S1S0 = S1R x RS0_0' in strategy_name_list + assert 'S0S1 = S0R x RS1_0' in strategy_name_list + assert 'S1R = S1S0 x S0R_0' in strategy_name_list + assert 'S0R = S0S1 x S1R_0' in strategy_name_list + assert 'S01R = S01R x RR_0' in strategy_name_list # SS = SR x RS - assert 'S0S1 = S0R x RS1_0' in strategy_name_list assert 'S0S1 = S0R x RS1_1' in strategy_name_list assert 'S0S1 = S0R x RS1_2' in strategy_name_list - assert 'S1S0 = S1R x RS0_0' in strategy_name_list assert 'S1S0 = S1R x RS0_1' in strategy_name_list assert 'S1S0 = S1R x RS0_2' in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R_0' in strategy_name_list assert 'S0R = S0S1 x S1R_1' in strategy_name_list assert 'S0R = S0S1 x S1R_2' in strategy_name_list - assert 'S1R = S1S0 x S0R_0' in strategy_name_list assert 'S1R = S1S0 x S0R_1' in strategy_name_list assert 'S1R = S1S0 x S0R_2' in strategy_name_list @@ -123,7 +131,6 @@ def check_linear_module_handler(rank, bias, world_size, port): assert 'RS1 = RR x RS1' in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR_0' in strategy_name_list assert 'S01R = S01R x RR_1' in strategy_name_list assert 'S01R = S01R x RR_2' in strategy_name_list @@ -164,7 +171,7 @@ class LinearModel(nn.Module): return x -def check_linear_function_handler(rank, bias, world_size, port): +def check_linear_function_handler(rank, bias, input_shape, world_size, port): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') model = LinearModel().cuda() @@ -172,12 +179,15 @@ def check_linear_function_handler(rank, bias, world_size, port): mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - input = torch.rand(4, 4, 4, 16).cuda() + input = torch.rand(input_shape).cuda() other = torch.rand(32, 16).cuda() # the index of linear node in computation graph node_index = 2 # strategy number of linear node - strategy_number = 24 + if input_shape == (1, 4, 4, 16): + strategy_number = 19 + else: + strategy_number = 24 # construct input args input_args = [input, other] # construct meta arg names @@ -192,7 +202,7 @@ def check_linear_function_handler(rank, bias, world_size, port): tracer = ColoTracer() graph = tracer.trace(model, meta_args={ - "input": torch.rand(4, 4, 4, 16).to('meta'), + "input": torch.rand(input_shape).to('meta'), 'others': torch.rand(32, 16).to('meta') }) gm = ColoGraphModule(model, graph) @@ -209,9 +219,10 @@ def check_linear_function_handler(rank, bias, world_size, port): mapping = handler.get_operation_data_mapping() assert mapping['input'].name == "input_1" - assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16]) + assert mapping['input'].data.shape == torch.Size(input_shape) assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([64, 16]) + input_logical_shape = mapping['input'].data.view(-1, 16).shape + assert mapping['input'].logical_shape == torch.Size(input_logical_shape) assert mapping['other'].name == "others" assert mapping['other'].data.shape == torch.Size([32, 16]) @@ -225,27 +236,32 @@ def check_linear_function_handler(rank, bias, world_size, port): assert mapping['other'].logical_shape == torch.Size([16, 32]) assert mapping['output'].name == "linear" - assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32]) + output_shape = input_shape[:-1] + (32,) + assert mapping['output'].data.shape == torch.Size(output_shape) assert mapping['output'].type == OperationDataType.OUTPUT + output_logical_shape = mapping['output'].data.view(-1, 32).shape + assert mapping['output'].logical_shape == torch.Size(output_logical_shape) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] - # one strategy will be converted to different physical sharding spec - assert len(strategy_name_list) > 8 + + # First dimension cannot be shard if input shape is (1, 4, 4, 16) + if input_shape != (1, 4, 4, 16): + assert 'S1S0 = S1R x RS0_0' in strategy_name_list + assert 'S0S1 = S0R x RS1_0' in strategy_name_list + assert 'S1R = S1S0 x S0R_0' in strategy_name_list + assert 'S0R = S0S1 x S1R_0' in strategy_name_list + assert 'S01R = S01R x RR_0' in strategy_name_list # SS = SR x RS - assert 'S0S1 = S0R x RS1_0' in strategy_name_list assert 'S0S1 = S0R x RS1_1' in strategy_name_list assert 'S0S1 = S0R x RS1_2' in strategy_name_list - assert 'S1S0 = S1R x RS0_0' in strategy_name_list assert 'S1S0 = S1R x RS0_1' in strategy_name_list assert 'S1S0 = S1R x RS0_2' in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R_0' in strategy_name_list assert 'S0R = S0S1 x S1R_1' in strategy_name_list assert 'S0R = S0S1 x S1R_2' in strategy_name_list - assert 'S1R = S1S0 x S0R_0' in strategy_name_list assert 'S1R = S1S0 x S0R_1' in strategy_name_list assert 'S1R = S1S0 x S0R_2' in strategy_name_list @@ -262,7 +278,6 @@ def check_linear_function_handler(rank, bias, world_size, port): assert 'RS1 = RR x RS1' in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR_0' in strategy_name_list assert 'S01R = S01R x RR_1' in strategy_name_list assert 'S01R = S01R x RR_2' in strategy_name_list @@ -293,15 +308,23 @@ def check_linear_function_handler(rank, bias, world_size, port): assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] -# @parameterize('bias', [True, False]) +@parameterize('input_shape', [(1, 4, 4, 16), (4, 4, 4, 16)]) @run_on_environment_flag(name='AUTO_PARALLEL') @pytest.mark.dist @rerun_if_address_is_in_use() -def test_linear_handler(bias=False): +def test_linear_handler(input_shape, bias=False): world_size = 4 - run_func_module = partial(check_linear_module_handler, bias=bias, world_size=world_size, port=free_port()) + run_func_module = partial(check_linear_module_handler, + bias=bias, + input_shape=input_shape, + world_size=world_size, + port=free_port()) mp.spawn(run_func_module, nprocs=world_size) - run_func_function = partial(check_linear_function_handler, bias=bias, world_size=world_size, port=free_port()) + run_func_function = partial(check_linear_function_handler, + bias=bias, + input_shape=input_shape, + world_size=world_size, + port=free_port()) mp.spawn(run_func_function, nprocs=world_size) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_block.py b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_related_module.py similarity index 71% rename from tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_block.py rename to tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_related_module.py index f88d907c6..82accebdb 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_block.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_related_module.py @@ -1,11 +1,14 @@ from typing import Optional, Tuple, Union import torch -# from transformers.models.gpt2.modeling_gpt2 import GPT2Attention import torch.nn as nn import transformers from torch.fx import GraphModule -from transformers.models.gpt2.modeling_gpt2 import GPT2MLP +from transformers.models.gpt2.modeling_gpt2 import ( + GPT2MLP, + BaseModelOutputWithPastAndCrossAttentions, + GPT2PreTrainedModel, +) from transformers.pytorch_utils import Conv1D from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP @@ -173,8 +176,91 @@ class GPT2Block(nn.Module): return outputs # hidden_states, present, (attentions, cross_attentions) +class GPT2Model(GPT2PreTrainedModel): + _keys_to_ignore_on_load_missing = ["attn.masked_bias"] + + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + + device = input_ids.device + + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # GPT2Attention mask. + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = attention_mask[:, None, None, :] + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + # add_2 + hidden_states = inputs_embeds + position_embeds + + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + # transformer_drop + hidden_states = self.drop(hidden_states) + # comment to run pipeline + # add_3 + output_shape = input_shape + (hidden_states.size(-1),) + + presents = None + all_self_attentions = None + all_cross_attentions = None + all_hidden_states = None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + outputs = block(hidden_states, attention_mask=attention_mask, head_mask=head_mask[i]) + hidden_states = outputs[0] + + hidden_states = self.ln_f(hidden_states) + # comment to run pipeline + hidden_states = hidden_states.view(output_shape) + + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None) + + @run_on_environment_flag(name='AUTO_PARALLEL') -@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP]) +@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model]) def test_self_attention_block(model_cls): config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM) if model_cls == GPT2MLP: @@ -193,11 +279,17 @@ def test_self_attention_block(model_cls): input_sample = { 'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'), } - else: + elif model_cls in (GPT2Attention, GPT2Block): input_sample = { 'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'), 'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'), } + else: + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + input_sample = {k: v.to('meta') for k, v in kwargs.items()} graph = tracer.trace(root=model, meta_args=input_sample)