diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 9eb58df4d..bc99be4cc 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -78,9 +78,9 @@ class GPT2PipelineForwards: if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: - batch_size, seq_length = input_ids.shape input_shape = input_ids.size() - input_ids = input_ids.view(-1, seq_length) + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] batch_size = inputs_embeds.shape[0] @@ -89,13 +89,14 @@ class GPT2PipelineForwards: device = input_ids.device if input_ids is not None else inputs_embeds.device if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, seq_length) + token_type_ids = token_type_ids.view(-1, input_shape[-1]) else: if hidden_states is None: raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") input_shape = hidden_states.size()[:-1] - batch_size, seq_length = input_shape[0], input_shape[1] + batch_size = input_shape[0] device = hidden_states.device + hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:]) # GPT2Attention mask. if attention_mask is not None: @@ -136,9 +137,9 @@ class GPT2PipelineForwards: if stage_manager.is_first_stage(): if position_ids is not None: - position_ids = position_ids.view(-1, seq_length) + position_ids = position_ids.view(-1, input_shape[-1]) else: - position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device) + position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) if inputs_embeds is None: @@ -721,7 +722,6 @@ def get_gpt2_flash_attention_forward(): use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: - _, tgt_len, _ = hidden_states.size() if encoder_hidden_states is not None: if not hasattr(self, "q_attn"): diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index 744ca276e..0198e0468 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -58,9 +58,27 @@ def data_gen_for_sequence_classification(): def date_gen_for_double_heads(): - data = data_gen_for_lm() - data['mc_labels'] = torch.zeros(data['input_ids'].shape[0], dtype=torch.int64) - return data + num_choices = 2 + batch_size = 2 + input_ids = torch.tensor( + [[15496, 11, 616, 3290, 318, 13779, 318, 13779], [15496, 11, 616, 3290, 318, 13779, 318, 13779]], + dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + mc_labels = torch.zeros(input_ids.shape[0], dtype=torch.int64) + + mc_token_ids = torch.arange(0, num_choices, dtype=torch.int64) + mc_token_ids = mc_token_ids.expand((batch_size, num_choices)) + multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, num_choices, -1).contiguous() + multiple_choice_input_mask = attention_mask.unsqueeze(1).expand(-1, num_choices, -1).contiguous() + + inputs = { + "input_ids": multiple_choice_inputs_ids, + "mc_token_ids": mc_token_ids, + "attention_mask": multiple_choice_input_mask, + "labels": multiple_choice_inputs_ids, + "mc_labels": mc_labels, + } + return inputs # define output transform function @@ -98,14 +116,12 @@ model_zoo.register(name='transformers_gpt_lm', output_transform_fn=output_transform_fn, loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) - -# TODO The model training is failing, there is a bug in GPT2DoubleHeadsModel in transformers. -# model_zoo.register(name='transformers_gpt_double_heads', -# model_fn=lambda: transformers.GPT2DoubleHeadsModel(config), -# data_gen_fn=date_gen_for_double_heads, -# output_transform_fn=lambda x: dict(loss=x.loss + x.mc_loss), -# loss_fn=loss_fn, -# model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_gpt_double_heads', + model_fn=lambda: transformers.GPT2DoubleHeadsModel(config), + data_gen_fn=date_gen_for_double_heads, + output_transform_fn=output_transform_fn, + loss_fn=lambda x: x.loss + x.mc_loss, + model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_gpt_for_question_answering', model_fn=lambda: transformers.GPT2ForQuestionAnswering(config), data_gen_fn=data_gen_for_question_answering, diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index 23561f8ae..18be68bf6 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -86,7 +86,8 @@ def check_gemini_plugin(subset: str, init_method: str = 'none', early_stop: bool 'transformers_t5_encoder_model', # does not support apex rmsnorm 'transformers_chatglm', 'transformers_sam', - 'transformers_vit' + 'transformers_vit', + 'transformers_gpt_double_heads', # TODO check why does the model fail to run using Gemini ]: continue diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index f77bf7495..c9c6447a4 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -141,13 +141,13 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo data = data_gen_fn() if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0: - seq_len = data['input_ids'].shape[1] + seq_len = data['input_ids'].shape[-1] lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len) times = lcm // seq_len input_shape = data['input_ids'].shape for k, v in data.items(): if v.shape == input_shape: - data[k] = v.repeat(1, times) + data[k] = v.repeat(input_shape[:-1] + (input_shape[-1] * times,)) sharded_model.train() if booster.plugin.stage_manager is not None: diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 115a1bd79..c4cc3812d 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -136,14 +136,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'num_microbatches': 4, 'enable_all_optimization': True, 'use_lazy_init': True, - 'enable_sequence_parallelism': True, - 'precision': 'fp32', -}, { - 'tp_size': 4, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'enable_sequence_parallelism': True, 'precision': 'fp32', }, { 'tp_size': 2,