[pipeline] Add Pipeline Forward for GPT2Model Shardformer (#4224)

* fix typehint & docstring in sharder.py

* update pipeline forward for GPT2Model

* add test for pipeline forward of GPT2Model

* add cache cleaning in gpt2 test

* change assert to raise command
pull/4445/head
Baizhou Zhang 2023-07-13 15:34:06 +08:00 committed by Hongxin Liu
parent 37d22f6878
commit 208ac8f2ba
5 changed files with 357 additions and 9 deletions
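For orientation before the file-by-file diff: only the first pipeline stage consumes `input_ids`, intermediate stages receive and emit a `hidden_states` dict, and only the last stage returns the usual `BaseModelOutputWithPastAndCrossAttentions`. A toy sketch of that hand-off (the stage count, shapes, and the `run_stage` helper are illustrative, not part of the commit):

```python
import torch

def run_stage(stage: int, num_stages: int, inputs: dict) -> dict:
    # Toy stand-in for gpt2_model_forward: each stage "runs" its slice of
    # transformer blocks and forwards hidden_states to the next stage.
    hidden = inputs.get('hidden_states')
    if stage == 0:
        # first stage: embed input_ids (random values stand in for wte/wpe)
        hidden = torch.randn(*inputs['input_ids'].shape, 768)
    hidden = hidden + 0.0    # placeholder for this stage's blocks
    if stage == num_stages - 1:
        return {'last_hidden_state': hidden}    # last stage: full output
    return {'hidden_states': hidden}            # intermediate hand-off

batch = {'input_ids': torch.randint(0, 50257, (2, 16))}
out = run_stage(0, 2, batch)
out = run_stage(1, 2, out)
print(out['last_hidden_state'].shape)    # torch.Size([2, 16, 768])
```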


@@ -129,7 +129,7 @@ class Linear1D_Col(ParallelModule):
                                        **kwargs)

         with torch.no_grad():
-            # the weigh to the linear layer is a transpose
+            # the weight to the linear layer is a transpose
             # thus shard on row is equal to shard on column
             sharded_weight = shard_rowwise(module.weight.data, process_group)
             linear_1d.weight.data.copy_(sharded_weight)
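A quick numerical check of the comment fixed above: `nn.Linear` stores its weight as `(out_features, in_features)`, the transpose of the matrix applied to the input, so sharding the stored weight row-wise equals sharding the logical weight column-wise. A minimal sketch with a toy `shard_rowwise` that simply splits along dim 0 (the real helper takes a process group instead of an explicit rank/world size):

```python
import torch

def shard_rowwise(tensor: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Toy stand-in for the real shard_rowwise: split dim 0 into equal chunks.
    return tensor.chunk(world_size, dim=0)[rank]

out_features, in_features = 4, 6
weight = torch.arange(out_features * in_features, dtype=torch.float32).reshape(out_features, in_features)

# Row-shard the stored (out, in) weight for rank 0 of 2...
row_shard = shard_rowwise(weight, rank=0, world_size=2)
# ...which matches column-sharding the logical (in, out) matrix W = weight.T
col_shard_of_logical = weight.T.chunk(2, dim=1)[0]
assert torch.equal(row_shard, col_shard_of_logical.T)
```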


@@ -1,6 +1,14 @@
-import torch.nn as nn
+import logging
+from functools import partial
+from types import MethodType
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor
+from torch.nn import Module
+
 import colossalai.shardformer.layer as col_nn
+from colossalai.pipeline.stage_manager import PipelineStageManager
+
 from .._utils import getattr_, setattr_
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -119,6 +127,46 @@ class GPT2ModelPolicy(GPT2Policy):

     def __init__(self) -> None:
         super().__init__()
+
+    def module_policy(self):
+        from transformers.models.gpt2.modeling_gpt2 import GPT2Model
+
+        policy = super().module_policy()
+
+        if self.pipeline_stage_manager:
+            # set None as default
+            stage_manager = self.pipeline_stage_manager
+            layers_per_stage = Policy.distribute_layers(len(self.model.h), stage_manager.num_stages)
+            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            method_replacement = {
+                'forward':
+                    partial(GPT2PipelineForwards.gpt2_model_forward,
+                            stage_manager=stage_manager,
+                            stage_index=stage_index)
+            }
+            self.append_or_create_method_replacement(description=method_replacement,
+                                                     policy=policy,
+                                                     target_key=GPT2Model)
+        return policy
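`Policy.distribute_layers` and `Policy.get_stage_index` above decide which transformer blocks each pipeline stage owns. The sketch below only illustrates the idea with an even split plus remainder handling; the exact balancing used by the real helpers is an assumption here:

```python
from typing import List, Tuple

def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    # Assumed behavior: split as evenly as possible, giving the remainder
    # to the earliest stages. The real helper may balance differently.
    base, rem = divmod(num_layers, num_stages)
    return [base + (1 if s < rem else 0) for s in range(num_stages)]

def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    # Convert per-stage counts into a [start, end) slice of model.h.
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]

# GPT-2 (small) has 12 blocks; with 4 pipeline stages each stage gets 3 blocks.
counts = distribute_layers(12, 4)    # [3, 3, 3, 3]
print(get_stage_index(counts, 2))    # (6, 9) -> stage 2 runs blocks h[6:9]
```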
+
+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        module = self.model
+        stage_manager = self.pipeline_stage_manager
+        held_layers = []
+        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+        if stage_manager.is_first_stage():
+            held_layers.append(module.wte)
+            held_layers.append(module.wpe)
+        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        held_layers.extend(module.h[start_idx:end_idx])
+        if stage_manager.is_last_stage():
+            held_layers.append(module.ln_f)
+        return held_layers
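To make `get_held_layers` concrete: with 12 blocks and 2 stages under the even split assumed above, the stages would hold roughly the following modules (a sketch, not output from the real policy):

```python
# Assumed 12-block GPT-2 split over 2 pipeline stages:
held = {
    0: ['wte', 'wpe'] + [f'h.{i}' for i in range(0, 6)],    # first stage: embeddings + h[0:6]
    1: [f'h.{i}' for i in range(6, 12)] + ['ln_f'],         # last stage: h[6:12] + final norm
}
print(held[0][-1], held[1][0])    # h.5 h.6 -> consecutive blocks on consecutive stages
```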
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """No shared params in gpt2 model."""
+        # TODO: check whether there is shared param in gpt2model
+        return []
+

 # GPT2LMHeadModel
 class GPT2LMHeadModelPolicy(GPT2Policy):
@@ -194,3 +242,223 @@ class GPT2ForSequenceClassificationPolicy(GPT2Policy):

     def __init__(self) -> None:
         super().__init__()
+
+
+class GPT2PipelineForwards:
+    '''
+    This class serves as a micro library for forward function substitution of GPT2 models
+    under pipeline setting.
+    '''
+
+    @staticmethod
+    def gpt2_model_forward(
+            self: 'GPT2Model',
+            input_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+            attention_mask: Optional[torch.FloatTensor] = None,
+            token_type_ids: Optional[torch.LongTensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            head_mask: Optional[torch.FloatTensor] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            encoder_hidden_states: Optional[torch.Tensor] = None,
+            encoder_attention_mask: Optional[torch.FloatTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            stage_manager: Optional[PipelineStageManager] = None,
+            hidden_states: Optional[torch.FloatTensor] = None,
+            stage_index: Optional[List[int]] = None) -> Union[Tuple, 'BaseModelOutputWithPastAndCrossAttentions']:
+
+        # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
+        # Please refer to original code of transformers for more details.
+
+        from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
+
+        # Preprocess passed in arguments
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (output_hidden_states
+                                if output_hidden_states is not None else self.config.output_hidden_states)
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if stage_manager.is_first_stage():
+            if input_ids is not None and inputs_embeds is not None:
+                raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+            elif input_ids is not None:
+                batch_size, seq_length = input_ids.shape
+                input_shape = input_ids.size()
+                input_ids = input_ids.view(-1, seq_length)
+            elif inputs_embeds is not None:
+                input_shape = inputs_embeds.size()[:-1]
+                batch_size = inputs_embeds.shape[0]
+            else:
+                raise ValueError("You have to specify either input_ids or inputs_embeds")
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            if token_type_ids is not None:
+                token_type_ids = token_type_ids.view(-1, seq_length)
+        else:
+            if hidden_states is None:
+                raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
+            input_shape = hidden_states.size()[:-1]
+            batch_size, seq_length = input_shape[0], input_shape[1]
+            device = hidden_states.device
+        # GPT2Attention mask.
+        if attention_mask is not None:
+            if batch_size <= 0:
+                raise ValueError("batch_size has to be defined and > 0")
+            attention_mask = attention_mask.view(batch_size, -1)
+            # We create a 3D attention mask from a 2D tensor mask.
+            # Sizes are [batch_size, 1, 1, to_seq_length]
+            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+            # this attention mask is more simple than the triangular masking of causal attention
+            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+            attention_mask = attention_mask[:, None, None, :]
+
+            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+            # masked positions, this operation will create a tensor which is 0.0 for
+            # positions we want to attend and the dtype's smallest value for masked positions.
+            # Since we are adding it to the raw scores before the softmax, this is
+            # effectively the same as removing these entirely.
+            attention_mask = attention_mask.to(dtype=self.dtype)    # fp16 compatibility
+            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if self.config.add_cross_attention and encoder_hidden_states is not None:
+            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # head_mask has shape n_layer x batch x n_heads x N x N
+        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+        if stage_manager.is_first_stage():
+            if position_ids is not None:
+                position_ids = position_ids.view(-1, seq_length)
+            else:
+                position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
+                position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+
+            if inputs_embeds is None:
+                inputs_embeds = self.wte(input_ids)
+            position_embeds = self.wpe(position_ids)
+            hidden_states = inputs_embeds + position_embeds
+            if token_type_ids is not None:
+                token_type_embeds = self.wte(token_type_ids)
+                hidden_states = hidden_states + token_type_embeds
+            hidden_states = self.drop(hidden_states)
+
+        output_shape = input_shape + (hidden_states.size(-1),)
+        # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+        if past_key_values:
+            logging.warning('Non-empty past_key_values is not supported for pipeline models at the moment.')
+            past_key_values = None
+        if output_attentions:
+            logging.warning('output_attentions=True is not supported for pipeline models at the moment.')
+            output_attentions = False
+        if output_hidden_states:
+            logging.warning('output_hidden_states=True is not supported for pipeline models at the moment.')
+            output_hidden_states = False
+        if use_cache:
+            logging.warning('use_cache=True is not supported for pipeline models at the moment.')
+            use_cache = False
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logging.warning(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+                use_cache = False
+
+        presents = () if use_cache else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+        all_hidden_states = () if output_hidden_states else None
+        # Going through held blocks.
+        start_idx, end_idx = stage_index[0], stage_index[1]
+        for i in range(start_idx, end_idx):
+            block = self.h[i]
+            torch.cuda.set_device(hidden_states.device)
+            # Ensure that attention_mask is always on the same device as hidden_states
+            if attention_mask is not None:
+                attention_mask = attention_mask.to(hidden_states.device)
+            if isinstance(head_mask, torch.Tensor):
+                head_mask = head_mask.to(hidden_states.device)
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, use_cache, output_attentions)
+
+                    return custom_forward
+
+                outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    None,
+                    attention_mask,
+                    head_mask[i],
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                )
+            else:
+                outputs = block(
+                    hidden_states,
+                    layer_past=None,
+                    attention_mask=attention_mask,
+                    head_mask=head_mask[i],
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                )
+
+            hidden_states = outputs[0]
+            if use_cache is True:
+                presents = presents + (outputs[1],)
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
+        if stage_manager.is_last_stage():
+            hidden_states = self.ln_f(hidden_states)
+
+        hidden_states = hidden_states.view(output_shape)
+
+        # Add last hidden state
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if stage_manager.is_last_stage():
+            if not return_dict:
+                return tuple(
+                    v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
+                    if v is not None)
+
+            return BaseModelOutputWithPastAndCrossAttentions(
+                last_hidden_state=hidden_states,
+                past_key_values=presents,
+                hidden_states=all_hidden_states,
+                attentions=all_self_attentions,
+                cross_attentions=all_cross_attentions,
+            )
+        else:
+            # always return dict for intermediate stage
+            return {'hidden_states': hidden_states}
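The mask preprocessing in the forward above turns a 0/1 padding mask into an additive bias before it reaches the attention scores. A standalone illustration of the `(1.0 - mask) * finfo.min` trick (the tensor values are only for demonstration):

```python
import torch

mask = torch.tensor([[1, 1, 1, 0]])                  # 1 = attend, 0 = padding
mask = mask[:, None, None, :].to(torch.float16)      # broadcastable [batch, 1, 1, seq]
additive = (1.0 - mask) * torch.finfo(torch.float16).min

print(additive)
# -> 0 for kept positions, -65504 (float16 min) for the padded one.
# Added to raw attention scores before softmax, the padded position's
# probability is driven to ~0 without an explicit masked_fill.
```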


@@ -72,17 +72,18 @@ class ModelSharder(object):
             attr_replacement: Dict[str, Any],
             param_replacement: List[Callable],
             method_replacement: Dict[str, Callable],
-            sub_module_replacement: List[Callable],
+            sub_module_replacement: List[SubModuleReplacementDescription],
     ) -> None:
         r"""
         Reverse the replace layer operation

         Args:
-            layer (torch.nn.Module): The object of layer to shard
-            origin_cls (Union[str, torch.nn.Module]): The origin layer class or a string of layer class name.
-            attr_replacement (Dict): The attribute dict to modify
+            module (torch.nn.Module): The object of layer to shard
+            origin_cls (Union[str, torch.nn.Module]): The origin layer class or a string of layer class name
+            attr_replacement (Dict[str, Any]): The attribute dict to modify
             param_replacement (List[Callable]): The function list to get parameter shard information in policy
-            sub_module_replacement (List[Callable]): The function list to get sub module shard information in policy
+            method_replacement (Dict[str, Callable]): Key is the method name, value is the method for replacement
+            sub_module_replacement (List[SubModuleReplacementDescription]): The function list to get sub module shard information in policy
         """
         if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \
            (module.__class__ == origin_cls):
@@ -111,7 +112,7 @@ class ModelSharder(object):
         Replace the attribute of the layer

         Args:
-            layer (:class:`torch.nn.Module`): The object of layer to shard
+            module (:class:`torch.nn.Module`): The object of layer to shard
             attr_replacement (Dict): The attribute dict to modify
         """
         for k, v in attr_replacement.items():
@@ -126,7 +127,7 @@ class ModelSharder(object):
         Replace the parameter of the layer

         Args:
-            layer (:class:`torch.nn.Module`): The object of layer to shard
+            module (:class:`torch.nn.Module`): The object of layer to shard
             param_replacement (List[Callable]): The function list to get parameter shard information in policy
         """
         for param_func in param_replacement:
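The corrected docstrings describe the replacement inputs a policy hands to the sharder. A hedged sketch of their rough shapes and of how a `method_replacement` entry could be rebound onto a live module; the placeholder names and the rebinding loop are illustrative, not the sharder's actual code path:

```python
from functools import partial
from types import MethodType
from typing import Any, Callable, Dict, List

import torch.nn as nn

def staged_forward(self, x, stage_index=None):
    # Placeholder standing in for a pipeline-aware forward such as
    # GPT2PipelineForwards.gpt2_model_forward.
    return x

# Rough shapes of the replacement inputs named in the docstring above;
# the concrete values are illustrative, not the real GPT2 policy contents.
attr_replacement: Dict[str, Any] = {'num_heads': 6}
param_replacement: List[Callable] = []
method_replacement: Dict[str, Callable] = {'forward': partial(staged_forward, stage_index=[0, 6])}

# Conceptually, applying a method replacement rebinds the callable onto a
# live module of the matching class (a sketch of the idea only):
module = nn.Linear(4, 4)
for name, fn in method_replacement.items():
    setattr(module, name, MethodType(fn, module))
```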


@@ -65,6 +65,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
         assert torch.allclose(
             org_grad, all_shard_grad,
             atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
+    torch.cuda.empty_cache()


 @parameterize('enable_fused_normalization', [True, False])
@@ -77,6 +78,7 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
         org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
                                                use_lazy_init)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+    torch.cuda.empty_cache()


@@ -0,0 +1,77 @@
+import pytest
+import torch
+
+import colossalai
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.logging import disable_existing_loggers
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.testing import (
+    assert_hf_output_close,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
+from tests.kit.model_zoo import model_zoo
+from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
+
+
+def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
+    # TODO: add tests for forward/backward later
+    pass
+
+
+@parameterize('enable_fused_normalization', [False])
+@parameterize('enable_tensor_parallelism', [False])
+@parameterize('use_lazy_init', [False])
+#TODO: merge this into test_shard_gpt2
+def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
+    DP_DIM, PP_DIM = 0, 1
+    DP_SIZE, PP_SIZE = 2, 2
+    pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
+    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
+
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        if name != "transformers_gpt":
+            continue
+        inputs = data_gen_fn()
+        inputs = {k: v.cuda() for k, v in inputs.items()}
+
+        org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
+                                                        enable_tensor_parallelism, use_lazy_init)
+        org_model.train()
+        org_output = org_model(**inputs)
+        hidden_state_shape = org_output['last_hidden_state'].shape
+
+        if stage_manager.is_first_stage():
+            output = sharded_model(**inputs)
+            assert output['hidden_states'].shape == hidden_state_shape
+        else:
+            attention_mask = inputs['attention_mask']
+            hidden_states = torch.zeros(*hidden_state_shape).cuda()
+            output = sharded_model(hidden_states=hidden_states, attention_mask=attention_mask)
+            if stage_manager.is_last_stage():
+                assert output['last_hidden_state'].shape == hidden_state_shape
+            else:
+                assert output['hidden_states'].shape == hidden_state_shape
+
+    torch.cuda.empty_cache()
+
+
+def check_gpt2(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_gpt2_test()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_gpt2():
+    spawn(check_gpt2, 4)
+
+
+if __name__ == "__main__":
+    test_gpt2()
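The test spawns 4 processes for a 2x2 data-parallel by pipeline-parallel mesh, so two ranks act as first pipeline stages and two as last. A sketch of one possible rank-to-coordinate mapping, assuming a row-major layout (the real `ProcessGroupMesh` may order axes differently):

```python
DP_SIZE, PP_SIZE = 2, 2

def mesh_coord(rank: int) -> tuple:
    # Row-major (dp, pp) coordinate for a flat rank; an assumption about
    # the layout, used only to visualize the 2x2 mesh in the test above.
    return divmod(rank, PP_SIZE)

for rank in range(DP_SIZE * PP_SIZE):
    dp, pp = mesh_coord(rank)
    role = 'first stage' if pp == 0 else 'last stage'
    print(f"rank {rank}: dp={dp}, pp={pp} -> {role}")
# rank 0: dp=0, pp=0 -> first stage
# rank 1: dp=0, pp=1 -> last stage
# rank 2: dp=1, pp=0 -> first stage
# rank 3: dp=1, pp=1 -> last stage
```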