From 208ac8f2ba67d8f43ec6f9024c3a4d112f9b4586 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 13 Jul 2023 15:34:06 +0800 Subject: [PATCH] [pipeline] Add Pipeline Forward for GPT2Model Shardformer (#4224) * * fix typehint & docstring in sharder.py * * update pipeline forward for GPT2Model * * add test for pipeline forward of GPT2Model * * add cache cleaning in gpt2 test * * change assert to raise command --- colossalai/shardformer/layer/linear.py | 2 +- colossalai/shardformer/policies/gpt2.py | 270 +++++++++++++++++- colossalai/shardformer/shard/sharder.py | 15 +- .../test_model/test_shard_gpt2.py | 2 + .../test_model/test_shard_gpt2_pipeline.py | 77 +++++ 5 files changed, 357 insertions(+), 9 deletions(-) create mode 100644 tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index a8439f303..383d9b3f5 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -129,7 +129,7 @@ class Linear1D_Col(ParallelModule): **kwargs) with torch.no_grad(): - # the weigh to the linear layer is a transpose + # the weight to the linear layer is a transpose # thus shard on row is equal to shard on column sharded_weight = shard_rowwise(module.weight.data, process_group) linear_1d.weight.data.copy_(sharded_weight) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 8f9d90e67..ffba27a50 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -1,6 +1,14 @@ -import torch.nn as nn +import logging +from functools import partial +from types import MethodType +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.nn import Module import colossalai.shardformer.layer as col_nn +from colossalai.pipeline.stage_manager import PipelineStageManager from .._utils import getattr_, setattr_ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -119,6 +127,46 @@ class GPT2ModelPolicy(GPT2Policy): def __init__(self) -> None: super().__init__() + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2Model + + policy = super().module_policy() + if self.pipeline_stage_manager: + # set None as default + stage_manager = self.pipeline_stage_manager + layers_per_stage = Policy.distribute_layers(len(self.model.h), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + 'forward': + partial(GPT2PipelineForwards.gpt2_model_forward, + stage_manager=stage_manager, + stage_index=stage_index) + } + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=GPT2Model) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + module = self.model + stage_manager = self.pipeline_stage_manager + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.wte) + held_layers.append(module.wpe) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.h[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.ln_f) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + # TODO: check whether there is shared 
param in gpt2model
+        """No shared params in gpt2 model."""
+        return []
+

 # GPT2LMHeadModel
 class GPT2LMHeadModelPolicy(GPT2Policy):
@@ -194,3 +242,223 @@ class GPT2ForSequenceClassificationPolicy(GPT2Policy):

     def __init__(self) -> None:
         super().__init__()
+
+
+class GPT2PipelineForwards:
+    '''
+    This class serves as a micro library for forward function substitution of GPT2 models
+    under a pipeline setting.
+    '''
+
+    @staticmethod
+    def gpt2_model_forward(
+            self: 'GPT2Model',
+            input_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+            attention_mask: Optional[torch.FloatTensor] = None,
+            token_type_ids: Optional[torch.LongTensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            head_mask: Optional[torch.FloatTensor] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            encoder_hidden_states: Optional[torch.Tensor] = None,
+            encoder_attention_mask: Optional[torch.FloatTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            stage_manager: Optional[PipelineStageManager] = None,
+            hidden_states: Optional[torch.FloatTensor] = None,
+            stage_index: Optional[List[int]] = None) -> Union[Tuple, 'BaseModelOutputWithPastAndCrossAttentions']:
+
+        # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
+        # Please refer to the original code of transformers for more details.
+
+        from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
+
+        # Preprocess passed-in arguments
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (output_hidden_states
+                                if output_hidden_states is not None else self.config.output_hidden_states)
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if stage_manager.is_first_stage():
+            if input_ids is not None and inputs_embeds is not None:
+                raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+            elif input_ids is not None:
+                batch_size, seq_length = input_ids.shape
+                input_shape = input_ids.size()
+                input_ids = input_ids.view(-1, seq_length)
+            elif inputs_embeds is not None:
+                input_shape = inputs_embeds.size()[:-1]
+                batch_size = inputs_embeds.shape[0]
+            else:
+                raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            if token_type_ids is not None:
+                token_type_ids = token_type_ids.view(-1, seq_length)
+        else:
+            if hidden_states is None:
+                raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
+            input_shape = hidden_states.size()[:-1]
+            batch_size, seq_length = input_shape[0], input_shape[1]
+            device = hidden_states.device
+
+        # GPT2Attention mask.
+        if attention_mask is not None:
+            if batch_size <= 0:
+                raise ValueError("batch_size has to be defined and > 0")
+            attention_mask = attention_mask.view(batch_size, -1)
+            # We create a 3D attention mask from a 2D tensor mask. 
+ # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if stage_manager.is_first_stage(): + if position_ids is not None: + position_ids = position_ids.view(-1, seq_length) + else: + position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if past_key_values: + logging.warning('Non-empty past_key_values is not supported for pipeline models at the moment.') + past_key_values = None + if output_attentions: + logging.warning('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logging.warning('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logging.warning('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + if self.gradient_checkpointing and self.training: + if use_cache: + logging.warning( + "`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + + # Going through held blocks. + start_idx, end_idx = stage_index[0], stage_index[1] + for i in range(start_idx, end_idx): + block = self.h[i] + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=None, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + if stage_manager.is_last_stage(): + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if stage_manager.is_last_stage(): + if not return_dict: + return tuple( + v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + else: + # always return dict for intermediate stage + return {'hidden_states': hidden_states} diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 882f93c7a..5e0b572e2 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -72,17 +72,18 @@ class ModelSharder(object): attr_replacement: Dict[str, Any], param_replacement: List[Callable], method_replacement: Dict[str, Callable], - sub_module_replacement: List[Callable], + sub_module_replacement: List[SubModuleReplacementDescription], ) -> None: r""" Reverse the replace layer operation Args: - layer (torch.nn.Module): The object of layer to shard - origin_cls (Union[str, torch.nn.Module]): The origin layer class or a string of layer class name. 
-            attr_replacement (Dict): The attribute dict to modify
+            module (torch.nn.Module): The object of layer to shard
+            origin_cls (Union[str, torch.nn.Module]): The origin layer class or a string of layer class name
+            attr_replacement (Dict[str, Any]): The attribute dict to modify
             param_replacement (List[Callable]): The function list to get parameter shard information in policy
-            sub_module_replacement (List[Callable]): The function list to get sub module shard information in policy
+            method_replacement (Dict[str, Callable]): Key is the method name, value is the method for replacement
+            sub_module_replacement (List[SubModuleReplacementDescription]): The description list to get sub module shard information in policy
         """
         if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \
             (module.__class__ == origin_cls):
@@ -111,7 +112,7 @@ class ModelSharder(object):
         Replace the attribute of the layer

         Args:
-            layer (:class:`torch.nn.Module`): The object of layer to shard
+            module (:class:`torch.nn.Module`): The object of layer to shard
             attr_replacement (Dict): The attribute dict to modify
         """
         for k, v in attr_replacement.items():
@@ -126,7 +127,7 @@ class ModelSharder(object):
         Replace the parameter of the layer

         Args:
-            layer (:class:`torch.nn.Module`): The object of layer to shard
+            module (:class:`torch.nn.Module`): The object of layer to shard
             param_replacement (List[Callable]): The function list to get parameter shard information in policy
         """
         for param_func in param_replacement:
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 9e5608e7f..552c6e2f4 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -65,6 +65,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     assert torch.allclose(
         org_grad, all_shard_grad,
         atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
+    torch.cuda.empty_cache()


 @parameterize('enable_fused_normalization', [True, False])
@@ -77,6 +78,7 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
         org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
                                                use_lazy_init)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+    torch.cuda.empty_cache()


diff --git a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py
new file mode 100644
index 000000000..5f92f638f
--- /dev/null
+++ b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py
@@ -0,0 +1,77 @@
+import pytest
+import torch
+
+import colossalai
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.logging import disable_existing_loggers
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.testing import (
+    assert_hf_output_close,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
+from tests.kit.model_zoo import model_zoo
+from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
+
+
+def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
+    # TODO: add tests for forward/backward later
+    pass
+
+
+@parameterize('enable_fused_normalization', [False])
+@parameterize('enable_tensor_parallelism', [False])
+@parameterize('use_lazy_init', 
[False]) +#TODO: merge this into test_shard_gpt2 +def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + + sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + if name != "transformers_gpt": + continue + + inputs = data_gen_fn() + inputs = {k: v.cuda() for k, v in inputs.items()} + + org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, + enable_tensor_parallelism, use_lazy_init) + org_model.train() + org_output = org_model(**inputs) + hidden_state_shape = org_output['last_hidden_state'].shape + + if stage_manager.is_first_stage(): + output = sharded_model(**inputs) + assert output['hidden_states'].shape == hidden_state_shape + else: + attention_mask = inputs['attention_mask'] + hidden_states = torch.zeros(*hidden_state_shape).cuda() + output = sharded_model(hidden_states=hidden_states, attention_mask=attention_mask) + if stage_manager.is_last_stage(): + assert output['last_hidden_state'].shape == hidden_state_shape + else: + assert output['hidden_states'].shape == hidden_state_shape + + torch.cuda.empty_cache() + + +def check_gpt2(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_gpt2_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_gpt2(): + spawn(check_gpt2, 4) + + +if __name__ == "__main__": + test_gpt2()
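
Note: the hunks above rely on the pipeline policy splitting the GPT-2 block stack `module.h` into one contiguous slice per stage, with the first stage additionally holding the embeddings (`wte`/`wpe`) and the last stage holding the final LayerNorm (`ln_f`). The standalone sketch below illustrates that partitioning idea only; the helper names mirror `Policy.distribute_layers` and `Policy.get_stage_index`, but this is a simplified, hypothetical toy and not the actual ColossalAI implementation (which may balance leftover layers differently).

# Hypothetical sketch (not the ColossalAI API): even partitioning of GPT-2
# blocks across pipeline stages, in the spirit of the Policy.distribute_layers
# and Policy.get_stage_index helpers used by GPT2ModelPolicy above.
from typing import List, Tuple


def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    # Give each stage num_layers // num_stages blocks; in this sketch the
    # remainder goes to the earliest stages (the real helper may differ).
    base, remainder = divmod(num_layers, num_stages)
    return [base + 1 if stage < remainder else base for stage in range(num_stages)]


def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    # [start, end) slice of module.h owned by `stage`.
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]


if __name__ == "__main__":
    layers_per_stage = distribute_layers(num_layers=12, num_stages=4)    # [3, 3, 3, 3]
    for stage in range(4):
        start, end = get_stage_index(layers_per_stage, stage)
        held = [f"h[{start}:{end}]"]
        if stage == 0:
            held = ["wte", "wpe"] + held    # first stage also embeds the inputs
        if stage == 3:
            held = held + ["ln_f"]          # last stage applies the final LayerNorm
        print(f"stage {stage}: {held}")

In the pipeline forward added by this patch, every stage except the last returns {'hidden_states': hidden_states} so the next stage can resume from its slice of blocks, which is exactly what the shape assertions in test_shard_gpt2_pipeline.py exercise.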