[autoparallel] complete gpt related module search (#2097)

pull/2101/head
YuliangLiu0306 2022-12-08 10:04:09 +08:00 committed by GitHub
parent 85efb7ac2e
commit 3af7e65dea
3 changed files with 173 additions and 53 deletions

@@ -64,20 +64,14 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
last_physical_output_dims = output_op_data.data.dim() - 1
if last_logical_input_dims in input_sharding_spec.dim_partition_dict:
update_partition_dim(
sharding_spec=input_sharding_spec,
dim_mapping={last_logical_input_dims: last_physical_input_dims},
physical_shape=input_op_data.data.shape,
inplace=True,
)
input_last_dim_mapping = {last_logical_input_dims: last_physical_input_dims}
else:
input_last_dim_mapping = {}
if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
update_partition_dim(
sharding_spec=output_sharding_spec,
dim_mapping={last_logical_output_dims: last_physical_output_dims},
physical_shape=output_op_data.data.shape,
inplace=True,
)
output_last_dim_mapping = {last_logical_output_dims: last_physical_output_dims}
else:
output_last_dim_mapping = {}
# get logger for debug message
logger = get_dist_logger()
@@ -97,12 +91,18 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
try:
# replace the 0th dimension in the logical sharding with the i-th dimension in the physical sharding
input_dim_mapping = {0: i}
input_dim_mapping.update(input_last_dim_mapping)
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping={0: i},
dim_mapping=input_dim_mapping,
physical_shape=input_op_data.data.shape,
inplace=True)
output_dim_mapping = {0: i}
output_dim_mapping.update(output_last_dim_mapping)
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping={0: i},
dim_mapping=output_dim_mapping,
physical_shape=output_op_data.data.shape,
inplace=True)
strategy_copy.name = f'{strategy.name}_{i}'
@@ -120,12 +120,17 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
# after updating, the logical shape will be replaced by the physical shape
input_dim_mapping = {}
input_dim_mapping.update(input_last_dim_mapping)
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping={},
dim_mapping=input_dim_mapping,
physical_shape=input_op_data.data.shape,
inplace=True)
output_dim_mapping = {}
output_dim_mapping.update(output_last_dim_mapping)
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping={},
dim_mapping=output_dim_mapping,
physical_shape=output_op_data.data.shape,
inplace=True)
sharding_strategies.append(strategy_copy)
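Note on the refactor above: the last-dimension remapping is no longer applied eagerly; it is stored as a small dict and merged into each per-strategy dim mapping ({0: i} for the batch-sharded variants, {} otherwise) right before the single update_partition_dim call. A minimal standalone sketch of that merge, with plain dicts standing in for ShardingSpec.dim_partition_dict (the helper name and example values are illustrative, not Colossal-AI API):

```python
# Standalone sketch of the dim-mapping merge introduced above.
# Plain dicts stand in for ShardingSpec.dim_partition_dict; the helper
# name and example values are illustrative only.

def build_dim_mapping(mesh_dim: int,
                      last_logical_dim: int,
                      last_physical_dim: int,
                      dim_partition_dict: dict) -> dict:
    """Merge the batch-dim mapping {0: mesh_dim} with the optional last-dim
    mapping, mirroring input_dim_mapping.update(input_last_dim_mapping)."""
    last_dim_mapping = ({last_logical_dim: last_physical_dim}
                        if last_logical_dim in dim_partition_dict else {})
    dim_mapping = {0: mesh_dim}
    dim_mapping.update(last_dim_mapping)
    return dim_mapping

# Logical input [64, 16] sharded on both dims, physical input [4, 4, 4, 16]:
# logical dim 1 must be remapped to physical dim 3 before update_partition_dim runs.
print(build_dim_mapping(mesh_dim=0, last_logical_dim=1, last_physical_dim=3,
                        dim_partition_dict={0: [0], 1: [1]}))   # -> {0: 0, 1: 3}
```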

@@ -26,18 +26,21 @@ from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
def check_linear_module_handler(rank, bias, world_size, port):
def check_linear_module_handler(rank, bias, input_shape, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
input = torch.rand(4, 4, 4, 16).cuda()
input = torch.rand(input_shape).cuda()
# the index of linear node in computation graph
node_index = 1
# strategy number of linear node
strategy_number = 24
if input_shape == (1, 4, 4, 16):
strategy_number = 19
else:
strategy_number = 24
# construct input args
input_args = [input]
# construct meta arg names
@@ -50,7 +53,7 @@ def check_linear_module_handler(rank, bias, world_size, port):
meta_arg_names=meta_arg_names)
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 4, 16).to('meta')})
graph = tracer.trace(model, meta_args={"input": torch.rand(input_shape).to('meta')})
gm = ColoGraphModule(model, graph)
linear_mod_node = list(graph.nodes)[1]
@@ -69,9 +72,10 @@ def check_linear_module_handler(rank, bias, world_size, port):
assert op_data.data is not None
assert mapping['input'].name == "input_1"
assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16])
assert mapping['input'].data.shape == torch.Size(input_shape)
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([64, 16])
input_logical_shape = mapping['input'].data.view(-1, 16).shape
assert mapping['input'].logical_shape == input_logical_shape
assert mapping['other'].name == "weight"
assert mapping['other'].data.shape == torch.Size([32, 16])
@@ -85,28 +89,32 @@ def check_linear_module_handler(rank, bias, world_size, port):
assert mapping['bias'].logical_shape == torch.Size([32])
assert mapping['output'].name == "_0"
assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32])
output_shape = input_shape[:-1] + (32,)
assert mapping['output'].data.shape == torch.Size(output_shape)
assert mapping['output'].type == OperationDataType.OUTPUT
assert mapping['output'].logical_shape == torch.Size([64, 32])
output_logical_shape = mapping['output'].data.view(-1, 32).shape
assert mapping['output'].logical_shape == torch.Size(output_logical_shape)
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
strategy_name_list = [val.name for val in strategies_vector]
# one strategy will be converted to different physical sharding specs
assert len(strategy_name_list) > 8
# First dimension cannot be sharded if the input shape is (1, 4, 4, 16)
if input_shape != (1, 4, 4, 16):
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
assert 'S01R = S01R x RR_0' in strategy_name_list
# SS = SR x RS
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
assert 'S0S1 = S0R x RS1_1' in strategy_name_list
assert 'S0S1 = S0R x RS1_2' in strategy_name_list
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
assert 'S1S0 = S1R x RS0_1' in strategy_name_list
assert 'S1S0 = S1R x RS0_2' in strategy_name_list
# SR = SS x SR
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
assert 'S0R = S0S1 x S1R_1' in strategy_name_list
assert 'S0R = S0S1 x S1R_2' in strategy_name_list
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
assert 'S1R = S1S0 x S0R_1' in strategy_name_list
assert 'S1R = S1S0 x S0R_2' in strategy_name_list
@@ -123,7 +131,6 @@ def check_linear_module_handler(rank, bias, world_size, port):
assert 'RS1 = RR x RS1' in strategy_name_list
# S01R = S01R x RR
assert 'S01R = S01R x RR_0' in strategy_name_list
assert 'S01R = S01R x RR_1' in strategy_name_list
assert 'S01R = S01R x RR_2' in strategy_name_list
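The module-handler assertions above now derive the logical shapes from the parametrized input instead of hard-coding [64, 16] and [64, 32]. A quick standalone check of that arithmetic, assuming the nn.Linear(16, 32) module used in this test:

```python
# Sanity check of the logical-shape arithmetic used by the assertions above,
# assuming the nn.Linear(16, 32) module in this test.
import torch

for input_shape in [(1, 4, 4, 16), (4, 4, 4, 16)]:
    x = torch.rand(input_shape)
    out = torch.nn.Linear(16, 32)(x)

    input_logical_shape = x.view(-1, 16).shape     # collapse all leading dims
    output_logical_shape = out.view(-1, 32).shape
    output_shape = input_shape[:-1] + (32,)

    print(input_shape, tuple(input_logical_shape),
          tuple(output_logical_shape), output_shape)
# (1, 4, 4, 16) (16, 16) (16, 32) (1, 4, 4, 32)
# (4, 4, 4, 16) (64, 16) (64, 32) (4, 4, 4, 32)
```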
@@ -164,7 +171,7 @@ class LinearModel(nn.Module):
return x
def check_linear_function_handler(rank, bias, world_size, port):
def check_linear_function_handler(rank, bias, input_shape, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = LinearModel().cuda()
@@ -172,12 +179,15 @@ def check_linear_function_handler(rank, bias, world_size, port):
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
input = torch.rand(4, 4, 4, 16).cuda()
input = torch.rand(input_shape).cuda()
other = torch.rand(32, 16).cuda()
# the index of linear node in computation graph
node_index = 2
# strategy number of linear node
strategy_number = 24
if input_shape == (1, 4, 4, 16):
strategy_number = 19
else:
strategy_number = 24
# construct input args
input_args = [input, other]
# construct meta arg names
@@ -192,7 +202,7 @@ def check_linear_function_handler(rank, bias, world_size, port):
tracer = ColoTracer()
graph = tracer.trace(model,
meta_args={
"input": torch.rand(4, 4, 4, 16).to('meta'),
"input": torch.rand(input_shape).to('meta'),
'others': torch.rand(32, 16).to('meta')
})
gm = ColoGraphModule(model, graph)
@@ -209,9 +219,10 @@ def check_linear_function_handler(rank, bias, world_size, port):
mapping = handler.get_operation_data_mapping()
assert mapping['input'].name == "input_1"
assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16])
assert mapping['input'].data.shape == torch.Size(input_shape)
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([64, 16])
input_logical_shape = mapping['input'].data.view(-1, 16).shape
assert mapping['input'].logical_shape == torch.Size(input_logical_shape)
assert mapping['other'].name == "others"
assert mapping['other'].data.shape == torch.Size([32, 16])
@@ -225,27 +236,32 @@ def check_linear_function_handler(rank, bias, world_size, port):
assert mapping['other'].logical_shape == torch.Size([16, 32])
assert mapping['output'].name == "linear"
assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32])
output_shape = input_shape[:-1] + (32,)
assert mapping['output'].data.shape == torch.Size(output_shape)
assert mapping['output'].type == OperationDataType.OUTPUT
output_logical_shape = mapping['output'].data.view(-1, 32).shape
assert mapping['output'].logical_shape == torch.Size(output_logical_shape)
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
strategy_name_list = [val.name for val in strategies_vector]
# one strategy will be converted to different physical sharding specs
assert len(strategy_name_list) > 8
# First dimension cannot be sharded if the input shape is (1, 4, 4, 16)
if input_shape != (1, 4, 4, 16):
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
assert 'S01R = S01R x RR_0' in strategy_name_list
# SS = SR x RS
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
assert 'S0S1 = S0R x RS1_1' in strategy_name_list
assert 'S0S1 = S0R x RS1_2' in strategy_name_list
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
assert 'S1S0 = S1R x RS0_1' in strategy_name_list
assert 'S1S0 = S1R x RS0_2' in strategy_name_list
# SR = SS x SR
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
assert 'S0R = S0S1 x S1R_1' in strategy_name_list
assert 'S0R = S0S1 x S1R_2' in strategy_name_list
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
assert 'S1R = S1S0 x S0R_1' in strategy_name_list
assert 'S1R = S1S0 x S0R_2' in strategy_name_list
@@ -262,7 +278,6 @@ def check_linear_function_handler(rank, bias, world_size, port):
assert 'RS1 = RR x RS1' in strategy_name_list
# S01R = S01R x RR
assert 'S01R = S01R x RR_0' in strategy_name_list
assert 'S01R = S01R x RR_1' in strategy_name_list
assert 'S01R = S01R x RR_2' in strategy_name_list
@@ -293,15 +308,23 @@ def check_linear_function_handler(rank, bias, world_size, port):
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
# @parameterize('bias', [True, False])
@parameterize('input_shape', [(1, 4, 4, 16), (4, 4, 4, 16)])
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_handler(bias=False):
def test_linear_handler(input_shape, bias=False):
world_size = 4
run_func_module = partial(check_linear_module_handler, bias=bias, world_size=world_size, port=free_port())
run_func_module = partial(check_linear_module_handler,
bias=bias,
input_shape=input_shape,
world_size=world_size,
port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
run_func_function = partial(check_linear_function_handler, bias=bias, world_size=world_size, port=free_port())
run_func_function = partial(check_linear_function_handler,
bias=bias,
input_shape=input_shape,
world_size=world_size,
port=free_port())
mp.spawn(run_func_function, nprocs=world_size)
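Both handlers now expect 19 strategies for the (1, 4, 4, 16) input and 24 otherwise. My reading (an assumption, not stated in the diff) is that the five '*_0' variants, which shard logical dimension 0 onto the size-1 leading physical dimension, are pruned because a size-1 dimension cannot be split over a mesh axis of size 2:

```python
# Toy illustration (my reading of the test, not Colossal-AI code) of why the
# strategy count drops from 24 to 19 for input shape (1, 4, 4, 16): a size-1
# leading dimension cannot be split over a mesh axis of size 2, so the five
# '*_0' batch-sharding strategies asserted only in the != (1, 4, 4, 16)
# branch are pruned.
def dim_is_shardable(dim_size: int, mesh_axis_size: int = 2) -> bool:
    return dim_size >= mesh_axis_size and dim_size % mesh_axis_size == 0

pruned_when_batch_is_one = [
    'S1S0 = S1R x RS0_0', 'S0S1 = S0R x RS1_0',
    'S1R = S1S0 x S0R_0', 'S0R = S0S1 x S1R_0',
    'S01R = S01R x RR_0',
]

for shape in [(1, 4, 4, 16), (4, 4, 4, 16)]:
    pruned = 0 if dim_is_shardable(shape[0]) else len(pruned_when_batch_is_one)
    print(shape, '->', 24 - pruned, 'strategies')
# (1, 4, 4, 16) -> 19 strategies
# (4, 4, 4, 16) -> 24 strategies
```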

@@ -1,11 +1,14 @@
from typing import Optional, Tuple, Union
import torch
# from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
import torch.nn as nn
import transformers
from torch.fx import GraphModule
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
from transformers.models.gpt2.modeling_gpt2 import (
GPT2MLP,
BaseModelOutputWithPastAndCrossAttentions,
GPT2PreTrainedModel,
)
from transformers.pytorch_utils import Conv1D
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
@@ -173,8 +176,91 @@ class GPT2Block(nn.Module):
return outputs # hidden_states, present, (attentions, cross_attentions)
class GPT2Model(GPT2PreTrainedModel):
_keys_to_ignore_on_load_missing = ["attn.masked_bias"]
def __init__(self, config):
super().__init__(config)
self.embed_dim = config.hidden_size
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.drop = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0]
device = input_ids.device
token_type_ids = token_type_ids.view(-1, input_shape[-1])
past_length = 0
past_key_values = tuple([None] * len(self.h))
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
# GPT2Attention mask.
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = attention_mask[:, None, None, :]
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
encoder_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
# add_2
hidden_states = inputs_embeds + position_embeds
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
# transformer_drop
hidden_states = self.drop(hidden_states)
# comment to run pipeline
# add_3
output_shape = input_shape + (hidden_states.size(-1),)
presents = None
all_self_attentions = None
all_cross_attentions = None
all_hidden_states = None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
outputs = block(hidden_states, attention_mask=attention_mask, head_mask=head_mask[i])
hidden_states = outputs[0]
hidden_states = self.ln_f(hidden_states)
# comment to run pipeline
hidden_states = hidden_states.view(output_shape)
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
if v is not None)
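For reference, a standalone illustration of the attention-mask preparation performed in the trimmed GPT2Model.forward above (dummy sizes; the real forward casts to self.dtype):

```python
# Standalone illustration of the attention-mask preparation in the trimmed
# GPT2Model.forward above (dummy sizes; the real model casts to self.dtype).
import torch

batch_size, seq_length = 2, 8
attention_mask = torch.ones(batch_size, seq_length)   # 1 = attend, 0 = padding

mask = attention_mask.view(batch_size, -1)            # (B, S)
mask = mask[:, None, None, :]                         # (B, 1, 1, S): broadcasts over heads and query positions
mask = mask.to(dtype=torch.float32)                   # fp16/fp32 compatibility
mask = (1.0 - mask) * -10000.0                        # 0.0 where attended, -10000.0 where masked

print(mask.shape)                                     # torch.Size([2, 1, 1, 8])
```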
@run_on_environment_flag(name='AUTO_PARALLEL')
@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP])
@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
def test_self_attention_block(model_cls):
config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
if model_cls == GPT2MLP:
@@ -193,11 +279,17 @@ def test_self_attention_block(model_cls):
input_sample = {
'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
}
else:
elif model_cls in (GPT2Attention, GPT2Block):
input_sample = {
'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'),
}
else:
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
input_sample = {k: v.to('meta') for k, v in kwargs.items()}
graph = tracer.trace(root=model, meta_args=input_sample)
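The tracer inputs built above live on the 'meta' device: they keep shape and dtype but allocate no storage, so tracing GPT2Model never materializes real data. A tiny illustration:

```python
# Minimal illustration of the 'meta' tensors used as tracer inputs above:
# shape and dtype are preserved, but no storage is allocated.
import torch

input_ids = torch.zeros((2, 8), dtype=torch.int64).to('meta')
print(input_ids.shape, input_ids.dtype, input_ids.device)
# torch.Size([2, 8]) torch.int64 meta
```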