mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] complete gpt related module search (#2097)
parent
85efb7ac2e
commit
3af7e65dea
|
@ -64,20 +64,14 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
|
|||
last_physical_output_dims = output_op_data.data.dim() - 1
|
||||
|
||||
if last_logical_input_dims in input_sharding_spec.dim_partition_dict:
|
||||
update_partition_dim(
|
||||
sharding_spec=input_sharding_spec,
|
||||
dim_mapping={last_logical_input_dims: last_physical_input_dims},
|
||||
physical_shape=input_op_data.data.shape,
|
||||
inplace=True,
|
||||
)
|
||||
input_last_dim_mapping = {last_logical_input_dims: last_physical_input_dims}
|
||||
else:
|
||||
input_last_dim_mapping = {}
|
||||
|
||||
if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
|
||||
update_partition_dim(
|
||||
sharding_spec=output_sharding_spec,
|
||||
dim_mapping={last_logical_output_dims: last_physical_output_dims},
|
||||
physical_shape=output_op_data.data.shape,
|
||||
inplace=True,
|
||||
)
|
||||
output_last_dim_mapping = {last_logical_output_dims: last_physical_output_dims}
|
||||
else:
|
||||
output_last_dim_mapping = {}
|
||||
|
||||
# get logger for debug message
|
||||
logger = get_dist_logger()
|
||||
|
@ -97,12 +91,18 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
|
|||
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
|
||||
try:
|
||||
# replace the 0th dimension in the logical sharding with ith dimension in the physical sharding
|
||||
input_dim_mapping = {0: i}
|
||||
input_dim_mapping.update(input_last_dim_mapping)
|
||||
|
||||
update_partition_dim(sharding_spec=input_sharding_spec,
|
||||
dim_mapping={0: i},
|
||||
dim_mapping=input_dim_mapping,
|
||||
physical_shape=input_op_data.data.shape,
|
||||
inplace=True)
|
||||
output_dim_mapping = {0: i}
|
||||
output_dim_mapping.update(output_last_dim_mapping)
|
||||
|
||||
update_partition_dim(sharding_spec=output_sharding_spec,
|
||||
dim_mapping={0: i},
|
||||
dim_mapping=output_dim_mapping,
|
||||
physical_shape=output_op_data.data.shape,
|
||||
inplace=True)
|
||||
strategy_copy.name = f'{strategy.name}_{i}'
|
||||
|
@ -120,12 +120,17 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
|
|||
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
|
||||
|
||||
# after updating, the logical shape will be replaced by the physical shape
|
||||
input_dim_mapping = {}
|
||||
input_dim_mapping.update(input_last_dim_mapping)
|
||||
update_partition_dim(sharding_spec=input_sharding_spec,
|
||||
dim_mapping={},
|
||||
dim_mapping=input_dim_mapping,
|
||||
physical_shape=input_op_data.data.shape,
|
||||
inplace=True)
|
||||
|
||||
output_dim_mapping = {}
|
||||
output_dim_mapping.update(output_last_dim_mapping)
|
||||
update_partition_dim(sharding_spec=output_sharding_spec,
|
||||
dim_mapping={},
|
||||
dim_mapping=output_dim_mapping,
|
||||
physical_shape=output_op_data.data.shape,
|
||||
inplace=True)
|
||||
sharding_strategies.append(strategy_copy)
|
||||
|
|
|
@ -26,18 +26,21 @@ from colossalai.utils import free_port
|
|||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
def check_linear_module_handler(rank, bias, world_size, port):
|
||||
def check_linear_module_handler(rank, bias, input_shape, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda()
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
input = torch.rand(4, 4, 4, 16).cuda()
|
||||
input = torch.rand(input_shape).cuda()
|
||||
# the index of linear node in computation graph
|
||||
node_index = 1
|
||||
# strategy number of linear node
|
||||
strategy_number = 24
|
||||
if input_shape == (1, 4, 4, 16):
|
||||
strategy_number = 19
|
||||
else:
|
||||
strategy_number = 24
|
||||
# construct input args
|
||||
input_args = [input]
|
||||
# construct meta arg names
|
||||
|
@ -50,7 +53,7 @@ def check_linear_module_handler(rank, bias, world_size, port):
|
|||
meta_arg_names=meta_arg_names)
|
||||
|
||||
tracer = ColoTracer()
|
||||
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 4, 16).to('meta')})
|
||||
graph = tracer.trace(model, meta_args={"input": torch.rand(input_shape).to('meta')})
|
||||
gm = ColoGraphModule(model, graph)
|
||||
|
||||
linear_mod_node = list(graph.nodes)[1]
|
||||
|
@ -69,9 +72,10 @@ def check_linear_module_handler(rank, bias, world_size, port):
|
|||
assert op_data.data is not None
|
||||
|
||||
assert mapping['input'].name == "input_1"
|
||||
assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16])
|
||||
assert mapping['input'].data.shape == torch.Size(input_shape)
|
||||
assert mapping['input'].type == OperationDataType.ARG
|
||||
assert mapping['input'].logical_shape == torch.Size([64, 16])
|
||||
input_logical_shape = mapping['input'].data.view(-1, 16).shape
|
||||
assert mapping['input'].logical_shape == input_logical_shape
|
||||
|
||||
assert mapping['other'].name == "weight"
|
||||
assert mapping['other'].data.shape == torch.Size([32, 16])
|
||||
|
@ -85,28 +89,32 @@ def check_linear_module_handler(rank, bias, world_size, port):
|
|||
assert mapping['bias'].logical_shape == torch.Size([32])
|
||||
|
||||
assert mapping['output'].name == "_0"
|
||||
assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32])
|
||||
output_shape = input_shape[:-1] + (32,)
|
||||
assert mapping['output'].data.shape == torch.Size(output_shape)
|
||||
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||
assert mapping['output'].logical_shape == torch.Size([64, 32])
|
||||
output_logical_shape = mapping['output'].data.view(-1, 32).shape
|
||||
assert mapping['output'].logical_shape == torch.Size(output_logical_shape)
|
||||
|
||||
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
|
||||
strategy_name_list = [val.name for val in strategies_vector]
|
||||
# one strategy will be converted to different physical sharding spec
|
||||
assert len(strategy_name_list) > 8
|
||||
|
||||
# First dimension cannot be shard if input shape is (1, 4, 4, 16)
|
||||
if input_shape != (1, 4, 4, 16):
|
||||
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
|
||||
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
|
||||
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
|
||||
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
|
||||
assert 'S01R = S01R x RR_0' in strategy_name_list
|
||||
|
||||
# SS = SR x RS
|
||||
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
|
||||
assert 'S0S1 = S0R x RS1_1' in strategy_name_list
|
||||
assert 'S0S1 = S0R x RS1_2' in strategy_name_list
|
||||
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
|
||||
assert 'S1S0 = S1R x RS0_1' in strategy_name_list
|
||||
assert 'S1S0 = S1R x RS0_2' in strategy_name_list
|
||||
|
||||
# SR = SS x SR
|
||||
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
|
||||
assert 'S0R = S0S1 x S1R_1' in strategy_name_list
|
||||
assert 'S0R = S0S1 x S1R_2' in strategy_name_list
|
||||
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
|
||||
assert 'S1R = S1S0 x S0R_1' in strategy_name_list
|
||||
assert 'S1R = S1S0 x S0R_2' in strategy_name_list
|
||||
|
||||
|
@ -123,7 +131,6 @@ def check_linear_module_handler(rank, bias, world_size, port):
|
|||
assert 'RS1 = RR x RS1' in strategy_name_list
|
||||
|
||||
# S01R = S01R x RR
|
||||
assert 'S01R = S01R x RR_0' in strategy_name_list
|
||||
assert 'S01R = S01R x RR_1' in strategy_name_list
|
||||
assert 'S01R = S01R x RR_2' in strategy_name_list
|
||||
|
||||
|
@ -164,7 +171,7 @@ class LinearModel(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
def check_linear_function_handler(rank, bias, world_size, port):
|
||||
def check_linear_function_handler(rank, bias, input_shape, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = LinearModel().cuda()
|
||||
|
@ -172,12 +179,15 @@ def check_linear_function_handler(rank, bias, world_size, port):
|
|||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
|
||||
input = torch.rand(4, 4, 4, 16).cuda()
|
||||
input = torch.rand(input_shape).cuda()
|
||||
other = torch.rand(32, 16).cuda()
|
||||
# the index of linear node in computation graph
|
||||
node_index = 2
|
||||
# strategy number of linear node
|
||||
strategy_number = 24
|
||||
if input_shape == (1, 4, 4, 16):
|
||||
strategy_number = 19
|
||||
else:
|
||||
strategy_number = 24
|
||||
# construct input args
|
||||
input_args = [input, other]
|
||||
# construct meta arg names
|
||||
|
@ -192,7 +202,7 @@ def check_linear_function_handler(rank, bias, world_size, port):
|
|||
tracer = ColoTracer()
|
||||
graph = tracer.trace(model,
|
||||
meta_args={
|
||||
"input": torch.rand(4, 4, 4, 16).to('meta'),
|
||||
"input": torch.rand(input_shape).to('meta'),
|
||||
'others': torch.rand(32, 16).to('meta')
|
||||
})
|
||||
gm = ColoGraphModule(model, graph)
|
||||
|
@ -209,9 +219,10 @@ def check_linear_function_handler(rank, bias, world_size, port):
|
|||
mapping = handler.get_operation_data_mapping()
|
||||
|
||||
assert mapping['input'].name == "input_1"
|
||||
assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16])
|
||||
assert mapping['input'].data.shape == torch.Size(input_shape)
|
||||
assert mapping['input'].type == OperationDataType.ARG
|
||||
assert mapping['input'].logical_shape == torch.Size([64, 16])
|
||||
input_logical_shape = mapping['input'].data.view(-1, 16).shape
|
||||
assert mapping['input'].logical_shape == torch.Size(input_logical_shape)
|
||||
|
||||
assert mapping['other'].name == "others"
|
||||
assert mapping['other'].data.shape == torch.Size([32, 16])
|
||||
|
@ -225,27 +236,32 @@ def check_linear_function_handler(rank, bias, world_size, port):
|
|||
assert mapping['other'].logical_shape == torch.Size([16, 32])
|
||||
|
||||
assert mapping['output'].name == "linear"
|
||||
assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32])
|
||||
output_shape = input_shape[:-1] + (32,)
|
||||
assert mapping['output'].data.shape == torch.Size(output_shape)
|
||||
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||
output_logical_shape = mapping['output'].data.view(-1, 32).shape
|
||||
assert mapping['output'].logical_shape == torch.Size(output_logical_shape)
|
||||
|
||||
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
|
||||
strategy_name_list = [val.name for val in strategies_vector]
|
||||
# one strategy will be converted to different physical sharding spec
|
||||
assert len(strategy_name_list) > 8
|
||||
|
||||
# First dimension cannot be shard if input shape is (1, 4, 4, 16)
|
||||
if input_shape != (1, 4, 4, 16):
|
||||
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
|
||||
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
|
||||
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
|
||||
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
|
||||
assert 'S01R = S01R x RR_0' in strategy_name_list
|
||||
|
||||
# SS = SR x RS
|
||||
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
|
||||
assert 'S0S1 = S0R x RS1_1' in strategy_name_list
|
||||
assert 'S0S1 = S0R x RS1_2' in strategy_name_list
|
||||
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
|
||||
assert 'S1S0 = S1R x RS0_1' in strategy_name_list
|
||||
assert 'S1S0 = S1R x RS0_2' in strategy_name_list
|
||||
|
||||
# SR = SS x SR
|
||||
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
|
||||
assert 'S0R = S0S1 x S1R_1' in strategy_name_list
|
||||
assert 'S0R = S0S1 x S1R_2' in strategy_name_list
|
||||
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
|
||||
assert 'S1R = S1S0 x S0R_1' in strategy_name_list
|
||||
assert 'S1R = S1S0 x S0R_2' in strategy_name_list
|
||||
|
||||
|
@ -262,7 +278,6 @@ def check_linear_function_handler(rank, bias, world_size, port):
|
|||
assert 'RS1 = RR x RS1' in strategy_name_list
|
||||
|
||||
# S01R = S01R x RR
|
||||
assert 'S01R = S01R x RR_0' in strategy_name_list
|
||||
assert 'S01R = S01R x RR_1' in strategy_name_list
|
||||
assert 'S01R = S01R x RR_2' in strategy_name_list
|
||||
|
||||
|
@ -293,15 +308,23 @@ def check_linear_function_handler(rank, bias, world_size, port):
|
|||
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
|
||||
|
||||
|
||||
# @parameterize('bias', [True, False])
|
||||
@parameterize('input_shape', [(1, 4, 4, 16), (4, 4, 4, 16)])
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_linear_handler(bias=False):
|
||||
def test_linear_handler(input_shape, bias=False):
|
||||
world_size = 4
|
||||
run_func_module = partial(check_linear_module_handler, bias=bias, world_size=world_size, port=free_port())
|
||||
run_func_module = partial(check_linear_module_handler,
|
||||
bias=bias,
|
||||
input_shape=input_shape,
|
||||
world_size=world_size,
|
||||
port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
run_func_function = partial(check_linear_function_handler, bias=bias, world_size=world_size, port=free_port())
|
||||
run_func_function = partial(check_linear_function_handler,
|
||||
bias=bias,
|
||||
input_shape=input_shape,
|
||||
world_size=world_size,
|
||||
port=free_port())
|
||||
mp.spawn(run_func_function, nprocs=world_size)
|
||||
|
||||
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
# from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
from torch.fx import GraphModule
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
|
||||
from transformers.models.gpt2.modeling_gpt2 import (
|
||||
GPT2MLP,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
GPT2PreTrainedModel,
|
||||
)
|
||||
from transformers.pytorch_utils import Conv1D
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
|
||||
|
@ -173,8 +176,91 @@ class GPT2Block(nn.Module):
|
|||
return outputs # hidden_states, present, (attentions, cross_attentions)
|
||||
|
||||
|
||||
class GPT2Model(GPT2PreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["attn.masked_bias"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||
|
||||
self.drop = nn.Dropout(config.embd_pdrop)
|
||||
self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
batch_size = input_ids.shape[0]
|
||||
|
||||
device = input_ids.device
|
||||
|
||||
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
||||
|
||||
past_length = 0
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
|
||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||
|
||||
# GPT2Attention mask.
|
||||
attention_mask = attention_mask.view(batch_size, -1)
|
||||
attention_mask = attention_mask[:, None, None, :]
|
||||
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
attention_mask = (1.0 - attention_mask) * -10000.0
|
||||
|
||||
encoder_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# head_mask has shape n_layer x batch x n_heads x N x N
|
||||
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
# add_2
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
|
||||
token_type_embeds = self.wte(token_type_ids)
|
||||
hidden_states = hidden_states + token_type_embeds
|
||||
|
||||
# transformer_drop
|
||||
hidden_states = self.drop(hidden_states)
|
||||
# comment to run pipeline
|
||||
# add_3
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
|
||||
presents = None
|
||||
all_self_attentions = None
|
||||
all_cross_attentions = None
|
||||
all_hidden_states = None
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
outputs = block(hidden_states, attention_mask=attention_mask, head_mask=head_mask[i])
|
||||
hidden_states = outputs[0]
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
# comment to run pipeline
|
||||
hidden_states = hidden_states.view(output_shape)
|
||||
|
||||
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
|
||||
if v is not None)
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP])
|
||||
@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
|
||||
def test_self_attention_block(model_cls):
|
||||
config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
|
||||
if model_cls == GPT2MLP:
|
||||
|
@ -193,11 +279,17 @@ def test_self_attention_block(model_cls):
|
|||
input_sample = {
|
||||
'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
|
||||
}
|
||||
else:
|
||||
elif model_cls in (GPT2Attention, GPT2Block):
|
||||
input_sample = {
|
||||
'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
|
||||
'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'),
|
||||
}
|
||||
else:
|
||||
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
|
||||
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
|
||||
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
|
||||
kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
|
||||
input_sample = {k: v.to('meta') for k, v in kwargs.items()}
|
||||
|
||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||
|
Loading…
Reference in New Issue