[shardformer] rewrite tests for opt/bloom/llama/vit/chatglm (#4395)

* rewrite opt tests

* rewrite llama tests

* rewrite bloom & vit tests

* rewrite chatglm tests

* fix LinearCol for classifiers

* add checks for other tp layers, fix lazy init in util
pull/4445/head
Baizhou Zhang 2023-08-11 15:43:23 +08:00 committed by Hongxin Liu
parent 21e0a42fd1
commit 7711bd524a
19 changed files with 1064 additions and 1273 deletions

View File

@@ -143,6 +143,14 @@ class Linear1D_Col(ParallelModule):
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
tp_size = dist.get_world_size(process_group)
if out_features < tp_size:
return module
if out_features % tp_size != 0:
raise ValueError(
f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!")
linear_1d = Linear1D_Col(in_features=in_features,
out_features=out_features,
bias=bias,
@@ -293,6 +301,14 @@ class Linear1D_Row(ParallelModule):
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
tp_size = dist.get_world_size(process_group)
if in_features < tp_size:
return module
if in_features % tp_size != 0:
raise ValueError(
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!")
linear_1d = Linear1D_Row(in_features=in_features,
out_features=out_features,
bias=bias,
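The same divisibility guard is added to Linear1D_Col, Linear1D_Row and the fused GPT2 conv variants below. A minimal standalone sketch of the logic (the helper name and bare arguments are illustrative only; in the commit the check lives inside each layer's from_native_module):

def can_split(features: int, tp_size: int) -> bool:
    """Illustrative only: mirrors the guard added to each tensor-parallel linear converter."""
    if features < tp_size:
        # too few output (or input) features to shard: the original module is kept as-is
        return False
    if features % tp_size != 0:
        raise ValueError(f"The size of features:{features} is not an integer multiple "
                         f"of tensor parallel size: {tp_size}!")
    return True

In the first case from_native_module simply returns the unmodified module, and in the second it raises, so callers can attempt sharding without pre-checking layer sizes themselves.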

View File

@@ -265,6 +265,14 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
tp_size = dist.get_world_size(process_group)
if out_features < tp_size:
return module
if out_features % tp_size != 0:
raise ValueError(
f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!")
linear_1d = GPT2FusedLinearConv1D_Col(in_features=in_features,
out_features=out_features,
bias=bias,
@@ -420,6 +428,14 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
tp_size = dist.get_world_size(process_group)
if in_features < tp_size:
return module
if in_features % tp_size != 0:
raise ValueError(
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!")
linear_1d = GPT2FusedLinearConv1D_Row(in_features=in_features,
out_features=out_features,
bias=bias,

View File

@@ -1,7 +1,500 @@
from typing import Optional, Tuple
import random
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
)
from transformers.models.opt.modeling_opt import (
OPTForCausalLM,
OPTForQuestionAnswering,
OPTForSequenceClassification,
OPTModel,
)
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
class OPTPipelineForwards:
'''
This class serves as a micro library for forward function substitution of OPT models
in a pipeline-parallel setting.
'''
@staticmethod
def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
from transformers.models.opt.modeling_opt import _make_causal_mask
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
_dtype,
device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype,
tgt_len=input_shape[-1]).to(device)
combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask +
combined_attention_mask)
return combined_attention_mask
@staticmethod
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
@staticmethod
def opt_model_forward(
self: OPTModel,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
'''
This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward
'''
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import logging
logger = logging.get_logger(__name__)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
decoder = self.decoder
if stage_manager.is_first_stage():
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
batch_size, seq_length = input_shape
if inputs_embeds is None:
inputs_embeds = decoder.embed_tokens(input_ids)
if decoder.project_in is not None:
inputs_embeds = decoder.project_in(inputs_embeds)
device = input_ids.device if input_ids is not None else inputs_embeds.device
_dtype = inputs_embeds.dtype
else:
if hidden_states is None:
raise ValueError("hidden_states shouln't be None for intermediate stages.")
input_shape = hidden_states.size()[:-1]
batch_size, seq_length = input_shape[0], input_shape[1]
device = hidden_states.device
_dtype = hidden_states.dtype
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values_length + seq_length
# embed positions
if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
elif attention_mask.shape[1] != mask_seq_length:
raise ValueError(
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
f"{mask_seq_length} (sum of the lengths of current and past inputs)")
causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask(attention_mask, input_shape, _dtype,
device, past_key_values_length)
if stage_manager.is_first_stage():
pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length)
hidden_states = inputs_embeds + pos_embeds
if decoder.gradient_checkpointing and decoder.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
use_cache = False
# TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
if past_key_values:
logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
past_key_values = None
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if use_cache:
logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
if attn_mask is not None:
if attn_mask.size()[0] != (len(decoder.layers)):
raise ValueError(
f"The `{mask_name}` should be specified for {len(decoder.layers)} layers, but it is for"
f" {head_mask.size()[0]}.")
start_idx, end_idx = stage_index[0], stage_index[1]
torch.cuda.set_device(device)
for idx in range(start_idx, end_idx):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
decoder_layer = decoder.layers[idx]
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1)
if decoder.training and (dropout_probability < decoder.layerdrop):
continue
past_key_value = past_key_values[idx] if past_key_values is not None else None
if decoder.gradient_checkpointing and decoder.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
causal_attention_mask,
head_mask[idx] if head_mask is not None else None,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if stage_manager.is_last_stage():
if decoder.final_layer_norm is not None:
hidden_states = decoder.final_layer_norm(hidden_states)
if decoder.project_out is not None:
hidden_states = decoder.project_out(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if stage_manager.is_last_stage():
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
else:
return {'hidden_states': hidden_states}
@staticmethod
def opt_for_causal_lm_forward(
self: OPTForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForCausalLM.forward.
Please refer to the original code of transformers for more details.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = OPTPipelineForwards.opt_model_forward(
self.model,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
)
if stage_manager.is_last_stage():
logits = self.lm_head(outputs[0]).contiguous()
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
else:
hidden_states = outputs.get('hidden_states')
return {'hidden_states': hidden_states}
@staticmethod
def opt_for_sequence_classification_forward(
self: OPTForSequenceClassification,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForSequenceClassification.forward.
Please refer to the original code of transformers for more details.
"""
logger = logging.get_logger(__name__)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
batch_size = input_ids.shape[0] if input_ids is not None else hidden_states.shape[0]
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
else:
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`")
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
else:
hidden_states = transformer_outputs.get('hidden_states')
return {'hidden_states': hidden_states}
@staticmethod
def opt_for_question_answering_forward(
self: OPTForQuestionAnswering,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForQuestionAnswering.forward.
Please refer to the original code of transformers for more details.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
logits = self.qa_outputs(hidden_states)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + transformer_outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
else:
hidden_states = transformer_outputs.get('hidden_states')
return {'hidden_states': hidden_states}
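Taken together, the staged forwards above run only the decoder layers named by stage_index and pass a plain {'hidden_states': ...} dict between stages; only the last stage builds the usual Hugging Face output object. A hedged usage sketch (opt_model, batch, stage_manager, prev_hidden, start and end are assumed to be provided by the surrounding pipeline schedule, not by this commit):

out = OPTPipelineForwards.opt_model_forward(
    opt_model,                                                     # the OPTModel replica on this rank
    input_ids=batch['input_ids'] if stage_manager.is_first_stage() else None,
    attention_mask=batch['attention_mask'],
    hidden_states=None if stage_manager.is_first_stage() else prev_hidden,
    stage_manager=stage_manager,
    stage_index=[start, end],                                      # decoder layers owned by this stage
)
# intermediate stages return {'hidden_states': ...} for the next stage; the last stage
# returns a BaseModelOutputWithPast (or a plain tuple when return_dict=False).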
def get_opt_flash_attention_forward():

View File

@@ -122,6 +122,12 @@ _POLICY_LIST = {
PolicyLocation(file_name="blip2", class_name="Blip2ModelPolicy"),
"transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGeneration":
PolicyLocation(file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"),
# ChatGLM
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel":
PolicyLocation(file_name="chatglm", class_name="ChatGLMModelPolicy"),
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration":
PolicyLocation(file_name="chatglm", class_name="ChatGLMForConditionalGenerationPolicy"),
}
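With these two entries registered, auto policy resolution can map a ChatGLM instance to its policy by the model's fully qualified class name. A rough sketch of the lookup, using get_autopolicy as imported in the (pre-rewrite) pipeline tests further down; the chatglm_model variable is assumed:

from colossalai.shardformer.policies.auto_policy import get_autopolicy

policy = get_autopolicy(chatglm_model)    # chatglm_model: a ChatGLMModel instance (assumed)
# resolves to ChatGLMModelPolicy via the "colossalai.shardformer.modeling.chatglm2_6b..." key above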

View File

@@ -1,32 +1,14 @@
import logging
import random
from functools import partial
from types import MethodType
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List
import torch
import torch.nn as nn
from torch import Tensor, nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
)
from transformers.models.opt.modeling_opt import (
OPTForCausalLM,
OPTForQuestionAnswering,
OPTForSequenceClassification,
OPTModel,
)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from .._utils import getattr_, setattr_
from .._utils import getattr_
from ..modeling.jit import get_jit_fused_dropout_add_func
from ..modeling.opt import get_jit_fused_opt_decoder_layer_forward, get_opt_flash_attention_forward
from ..modeling.opt import OPTPipelineForwards, get_jit_fused_opt_decoder_layer_forward, get_opt_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
@@ -228,6 +210,7 @@ class OPTForCausalLMPolicy(OPTPolicy):
num_stages = self.pipeline_stage_manager.num_stages
if id(opt_model.model.decoder.embed_tokens.weight) == id(opt_model.lm_head.weight):
return [{0: opt_model.model.decoder.embed_tokens.weight, num_stages - 1: opt_model.lm_head.weight}]
return []
def postprocess(self):
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
@@ -295,594 +278,3 @@ class OPTForQuestionAnsweringPolicy(OPTPolicy):
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"no shared params in OPTForSequenceClassification"
return []
class OPTPipelineForwards:
'''
This class serves as a micro library for forward function substitution of OPT models
under pipeline setting.
'''
@staticmethod
def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
from transformers.models.opt.modeling_opt import _make_causal_mask
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
_dtype,
device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype,
tgt_len=input_shape[-1]).to(device)
combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask +
combined_attention_mask)
return combined_attention_mask
@staticmethod
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
@staticmethod
def opt_model_forward(
self: OPTModel,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
'''
This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward
'''
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import logging
logger = logging.get_logger(__name__)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
decoder = self.decoder
if stage_manager.is_first_stage():
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
batch_size, seq_length = input_shape
if inputs_embeds is None:
inputs_embeds = decoder.embed_tokens(input_ids)
if decoder.project_in is not None:
inputs_embeds = decoder.project_in(inputs_embeds)
device = input_ids.device if input_ids is not None else inputs_embeds.device
_dtype = inputs_embeds.dtype
else:
if hidden_states is None:
raise ValueError("hidden_states shouln't be None for intermediate stages.")
input_shape = hidden_states.size()[:-1]
batch_size, seq_length = input_shape[0], input_shape[1]
device = hidden_states.device
_dtype = hidden_states.dtype
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values_length + seq_length
# embed positions
if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
elif attention_mask.shape[1] != mask_seq_length:
raise ValueError(
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
f"{mask_seq_length} (sum of the lengths of current and past inputs)")
causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask(attention_mask, input_shape, _dtype,
device, past_key_values_length)
if stage_manager.is_first_stage():
pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length)
hidden_states = inputs_embeds + pos_embeds
if decoder.gradient_checkpointing and decoder.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
use_cache = False
# TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
if past_key_values:
logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
past_key_values = None
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if use_cache:
logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
if attn_mask is not None:
if attn_mask.size()[0] != (len(decoder.layers)):
raise ValueError(
f"The `{mask_name}` should be specified for {len(decoder.layers)} layers, but it is for"
f" {head_mask.size()[0]}.")
start_idx, end_idx = stage_index[0], stage_index[1]
torch.cuda.set_device(device)
for idx in range(start_idx, end_idx):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
decoder_layer = decoder.layers[idx]
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1)
if decoder.training and (dropout_probability < decoder.layerdrop):
continue
past_key_value = past_key_values[idx] if past_key_values is not None else None
if decoder.gradient_checkpointing and decoder.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
causal_attention_mask,
head_mask[idx] if head_mask is not None else None,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if stage_manager.is_last_stage():
if decoder.final_layer_norm is not None:
hidden_states = decoder.final_layer_norm(hidden_states)
if decoder.project_out is not None:
hidden_states = decoder.project_out(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if stage_manager.is_last_stage():
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
else:
return {'hidden_states': hidden_states}
@staticmethod
def opt_for_causal_lm_forward(
self: OPTForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, OPTForCausalLM
>>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
>>> prompt = "Hey, are you consciours? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
```"""
from transformers.modeling_outputs import CausalLMOutputWithPast
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = OPTPipelineForwards.opt_model_forward(
self.model,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
)
if stage_manager.is_last_stage():
logits = self.lm_head(outputs[0]).contiguous()
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
else:
hidden_states = outputs.get('hidden_states')
return {'hidden_states': hidden_states}
@staticmethod
def opt_for_sequence_classification_forward(
self: OPTForSequenceClassification,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
from transformers.utils import logging
logger = logging.get_logger(__name__)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
batch_size = input_ids.shape[0] if input_ids is not None else hidden_states.shape[0]
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
else:
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`")
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
else:
hidden_states = transformer_outputs.get('hidden_states')
return {'hidden_states': hidden_states}
@staticmethod
def opt_for_question_answering_forward(
self: OPTForQuestionAnswering,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, OPTForQuestionAnswering
>>> import torch
>>> torch.manual_seed(4) # doctest: +IGNORE_RESULT
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
>>> # note: we are loading a OPTForQuestionAnswering from the hub here,
>>> # so the head will be randomly initialized, hence the predictions will be random
>>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m")
>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
>>> inputs = tokenizer(question, text, return_tensors="pt")
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> answer_start_index = outputs.start_logits.argmax()
>>> answer_end_index = outputs.end_logits.argmax()
>>> answer_offset = len(tokenizer(question)[0])
>>> predict_answer_tokens = inputs.input_ids[
... 0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1
... ]
>>> predicted = tokenizer.decode(predict_answer_tokens)
>>> predicted
' a nice puppet'
```"""
from transformers.modeling_outputs import QuestionAnsweringModelOutput
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
logits = self.qa_outputs(hidden_states)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + transformer_outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
else:
hidden_states = transformer_outputs.get('hidden_states')
return {'hidden_states': hidden_states}

View File

@@ -53,7 +53,8 @@ def data_gen_for_question_answering():
# inputs = tokenizer(question, text, return_tensors="pt")
input_ids = torch.tensor(
[[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]], dtype=torch.int64)
[[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]],
dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
start_positions = torch.tensor([1], dtype=torch.int64)
end_positions = torch.tensor([10], dtype=torch.int64)
@@ -73,12 +74,13 @@ loss_fn_for_causal_lm = lambda x: x.loss
loss_fn_for_classification = lambda x: x.loss
loss_fn_for_question_answering = lambda x: x.loss
config = transformers.BloomConfig(n_layer=1,
config = transformers.BloomConfig(n_layer=2,
n_head=4,
vocab_size=250880,
hidden_dropout=0,
attention_dropout=0,
hidden_size=64)
hidden_size=64,
pad_token_id=50256)
# register the following models
model_zoo.register(name='transformers_bloom',

View File

@@ -17,14 +17,24 @@ def data_gen():
return dict(input_ids=input_ids, attention_mask=attention_mask)
def data_gen_for_conditional_generation():
# token classification data gen
# `labels` is the type not the token id for token classification, 0 or 1
data = data_gen()
labels = data['input_ids'].clone()
data['labels'] = labels
return data
# define output transform function
output_transform_fn = lambda x: x
# define loss function
loss_fn_for_chatglm_model = lambda x: x.last_hidden_state.sum()
loss_fn = lambda x: x.logits.sum()
loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state,
torch.ones_like(x.last_hidden_state))
loss_fn = lambda x: x.loss
config = ChatGLMConfig(num_layers=1,
config = ChatGLMConfig(num_layers=2,
padded_vocab_size=65024,
hidden_size=64,
num_attention_heads=8,
@@ -33,7 +43,6 @@ config = ChatGLMConfig(num_layers=1,
use_cache=True,
torch_dtype=torch.float32)
model_zoo.register(name='transformers_chatglm',
model_fn=lambda: ChatGLMModel(config, empty_init=False),
data_gen_fn=data_gen,
@@ -43,7 +52,7 @@ model_zoo.register(name='transformers_chatglm',
model_zoo.register(name="transformers_chatglm_for_conditional_generation",
model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False),
data_gen_fn=data_gen,
data_gen_fn=data_gen_for_conditional_generation,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
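The rewritten shard tests consume these registry entries through model_zoo.get_sub_registry, roughly as sketched below (the loop body is a simplification of what the test files further down actually do):

from tests.kit.model_zoo import model_zoo

sub_zoo = model_zoo.get_sub_registry('transformers_chatglm')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_zoo.items():
    model = model_fn()                             # e.g. ChatGLMForConditionalGeneration(config, empty_init=False)
    output = model(**data_gen_fn())
    loss = loss_fn(output_transform_fn(output))    # x.loss for conditional generation, MSE on hidden states otherwise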

View File

@@ -7,11 +7,7 @@ from ..registry import ModelAttribute, model_zoo
# Register single-sentence VIT
# ===============================
config = transformers.ViTConfig(
num_hidden_layers=4,
# hidden_size=128,
# intermediate_size=256,
num_attention_heads=4)
config = transformers.ViTConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4)
# define data gen function

View File

@@ -104,27 +104,22 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c
if 'use_lazy_init' in test_config:
use_lazy_init = test_config.pop('use_lazy_init')
if use_lazy_init:
ctx = LazyInitContext()
else:
ctx = nullcontext()
plugin = HybridParallelPlugin(**test_config)
booster = Booster(plugin=plugin)
ctx = LazyInitContext() if use_lazy_init else nullcontext()
with ctx:
org_model = model_fn().cuda()
org_model = model_fn()
sharded_model = copy.deepcopy(org_model)
if use_lazy_init:
org_model = ctx.materialize(org_model)
ctx.materialize(org_model)
org_model = org_model.cuda()
org_optimizer = Adam(org_model.parameters(), lr=1e-3)
sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
criterion = loss_fn
sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
plugin = HybridParallelPlugin(**test_config)
booster = Booster(plugin=plugin)
sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster
@@ -142,11 +137,12 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
data = data_gen_fn()
sharded_model.train()
if booster.plugin.stage_manager is not None:
data = {
k: v.to('cuda').repeat(*([4] + [1] *
(v.dim() - 1))) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
for k, v in data.items()
}
for k, v in data.items():
if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = 4
data[k] = v.to('cuda').repeat(*new_shape)
data_iter = iter([data])
sharded_output = booster.execute_pipeline(data_iter,
sharded_model,
@@ -176,7 +172,8 @@ def check_output_hidden_state(org_output: Tensor,
sharded_output: Tensor,
stage_manager: Optional[PipelineStageManager] = None,
atol: float = 1e-5,
rtol: float = 1e-3):
rtol: float = 1e-3,
dim: int = 0):
org_hidden_state = org_output.last_hidden_state
@@ -184,7 +181,7 @@ def check_output_hidden_state(org_output: Tensor,
sharded_hidden_state = sharded_output.last_hidden_state
if stage_manager and stage_manager.is_last_stage():
sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=0)
sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=dim)
assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \
f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"

View File

@@ -3,57 +3,101 @@ import torch
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_loss,
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
)
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# check forward
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
# do backward
org_loss.backward()
shard_loss.backward()
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
assert torch.allclose(org_loss, shard_loss,
atol=1e-6), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
org_loss, org_output, sharded_loss, sharded_output = \
run_forward_backward_with_hybrid_plugin(
org_model,
sharded_model,
sharded_optimizer,
data_gen_fn,
output_transform_fn,
criterion,
booster)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if org_model.__class__.__name__ == 'BloomModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
# unwrap model
if org_model.__class__.__name__ == 'BloomModel':
bloom = org_model
sharded_bloom = sharded_model
sharded_bloom = sharded_model.unwrap()
else:
bloom = org_model.transformer
sharded_bloom = sharded_model.transformer
sharded_bloom = sharded_model.unwrap().transformer
# check grad
col_layer_for_check = ['h[0].self_attention.query_key_value']
row_layer_for_check = ['h[0].self_attention.dense']
check_grad(bloom, sharded_bloom, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
check_grad(bloom, sharded_bloom, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False)
row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings']
col_layer_for_check = ['h[0].self_attention.dense']
if stage_manager is None or stage_manager.is_first_stage():
check_grad(bloom, sharded_bloom, row_layer_for_check, tp_group, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
check_grad(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=1e-6, rtol=1e-5, dim=1, verbose=False)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
if stage_manager is None or stage_manager.is_first_stage():
check_weight(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False)
torch.cuda.empty_cache()
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
@parameterize('use_lazy_init', [False, True])
def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused,
use_lazy_init):
@parameterize('test_config', [{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_fused_normalization': True,
'use_lazy_init': True
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'enable_fused_normalization': False,
'use_lazy_init': False
}, {
'tp_size': 4,
'pp_size': 1,
'enable_fused_normalization': True,
'use_lazy_init': False
}])
def run_bloom_test(test_config):
# TODO: add test_config for TP+DP after supporting & debugging it
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
# TODO: add test_config for flash attention & jit operator after supporting
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
enable_flash_attention, enable_jit_fused, use_lazy_init)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
torch.cuda.empty_cache()
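Each test_config above is passed straight to HybridParallelPlugin inside build_model_from_hybrid_plugin (after 'use_lazy_init' is popped off), so the first entry builds a 2-way tensor-parallel by 2-way pipeline-parallel mesh over the 4 spawned ranks. A sketch, with import paths assumed:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(tp_size=2, pp_size=2, num_microbatches=4,
                              enable_fused_normalization=True, precision='float')
booster = Booster(plugin=plugin)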
@@ -67,7 +111,7 @@ def check_bloom(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bloom():
spawn(check_bloom, 2)
spawn(check_bloom, 4)
if __name__ == "__main__":

View File

@@ -1,90 +0,0 @@
import pytest
import torch
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.auto_policy import get_autopolicy
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.shard import ShardConfig
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
def check_bloom_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager):
policy = get_autopolicy(model)
policy.set_model(model)
model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
policy.set_shard_config(model_config)
layers = policy.get_held_layers()
if stage_manager.is_first_stage():
assert len(layers) == 0 + 2
else:
if name == 'transformers_bloom':
assert len(layers) == 1 + 1
elif name == 'transformers_bloom_for_token_classification':
assert len(layers) == 1 + 3
else:
assert len(layers) == 1 + 2
def check_bloom_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager):
if stage_manager.stage == 0:
x = torch.randint(0, 1000, (1, 3)).cuda()
attention_mask = torch.ones_like(x).cuda()
output = sharded_model(input_ids=x, attention_mask=attention_mask)
assert output['hidden_states'].shape == (1, 3, 64)
else:
attention_mask = torch.ones((1, 3)).cuda()
hidden_states = torch.randint(0, 1000, (1, 3, 64)).to(torch.float32).cuda()
output = sharded_model(
hidden_states=hidden_states,
attention_mask=attention_mask,
)
assert output[0].shape[0] == 1
@parameterize('enable_fused_normalization', [False])
@parameterize('enable_tensor_parallelism', [False])
@parameterize('use_lazy_init', [False])
#TODO: merge this into test_shard_bloom
def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
PP_DIM = 0
PP_SIZE = 2
pg_mesh = ProcessGroupMesh(PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)
check_bloom_model_policy(name, org_model, stage_manager)
check_bloom_model_pipeline_forward(name, sharded_model, stage_manager)
torch.cuda.empty_cache()
def check_bloom(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_bloom_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bloom():
spawn(check_bloom, 2)
if __name__ == "__main__":
test_bloom()

View File

@@ -1,99 +1,126 @@
import copy
import os
import pytest
import torch
from torch import distributed as dist
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.chatglm import ChatGLMForConditionalGenerationPolicy, ChatGLMModelPolicy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_loss,
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
)
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# check forward
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
# do backward
org_loss.backward()
shard_loss.backward()
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
org_loss, org_output, sharded_loss, sharded_output = \
run_forward_backward_with_hybrid_plugin(
org_model,
sharded_model,
sharded_optimizer,
data_gen_fn,
output_transform_fn,
criterion,
booster)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if org_model.__class__.__name__ == 'ChatGLMModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3, dim=1)
check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
# unwrap model
if org_model.__class__.__name__ == 'ChatGLMModel':
chatglm_model = org_model
shard_chatglm_model = sharded_model
shard_chatglm_model = sharded_model.unwrap()
else:
chatglm_model = org_model.transformer
shard_chatglm_model = sharded_model.transformer
shard_chatglm_model = sharded_model.unwrap().transformer
# check attention grad
org_grad = chatglm_model.encoder.layers[0].self_attention.query_key_value.weight.grad
shard_grad = shard_chatglm_model.encoder.layers[0].self_attention.query_key_value.weight.grad
shard_weight = shard_chatglm_model.encoder.layers[0].self_attention.query_key_value.weight
# check grad
row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings']
col_layer_for_check = ['encoder.layers[0].self_attention.dense']
if stage_manager is None or stage_manager.is_first_stage():
check_grad(chatglm_model,
shard_chatglm_model,
row_layer_for_check,
tp_group,
atol=1e-6,
rtol=1e-3,
dim=0,
verbose=False)
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=0)
else:
all_shard_grad = shard_grad
assert torch.allclose(org_grad, all_shard_grad,
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"
check_grad(chatglm_model,
shard_chatglm_model,
col_layer_for_check,
tp_group,
atol=1e-6,
rtol=1e-3,
dim=1,
verbose=False)
# check embedding weights
org_grad = chatglm_model.embedding.word_embeddings.weight.grad
shard_grad = shard_chatglm_model.embedding.word_embeddings.weight.grad
shard_weight = shard_chatglm_model.embedding.word_embeddings.weight
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
if stage_manager is None or stage_manager.is_first_stage():
check_weight(chatglm_model,
shard_chatglm_model,
col_layer_for_check,
tp_group,
atol=1e-4,
rtol=1e-3,
dim=1,
verbose=False)
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(2)]
torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=0)
else:
all_shard_grad = shard_grad
assert torch.allclose(org_grad, all_shard_grad,
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
torch.cuda.empty_cache()
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
@parameterize('test_config', [{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_fused_normalization': True,
'use_lazy_init': True
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'enable_fused_normalization': False,
'use_lazy_init': False
}, {
'tp_size': 4,
'pp_size': 1,
'enable_fused_normalization': True,
'use_lazy_init': False
}])
def run_chatglm_test(test_config):
# TODO: add test_config for TP+DP after supporting & debugging it
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
# TODO: add test_config for flash attention & jit operator after supporting
sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
# create new model
org_model = model_fn().cuda()
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
# shard model
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism,
enable_flash_attention=enable_flash_attention,
enable_jit_fused=enable_jit_fused)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
if name == "transformers_chatglm":
sharded_model, _ = shard_former.optimize(model_copy, ChatGLMModelPolicy())
else:
sharded_model, _ = shard_former.optimize(model_copy, ChatGLMForConditionalGenerationPolicy())
sharded_model = sharded_model.cuda()
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
clear_layout_converter()
torch.cuda.empty_cache()
@@ -107,7 +134,7 @@ def check_chatglm(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_chatglm():
spawn(check_chatglm, 2)
spawn(check_chatglm, 4)
if __name__ == "__main__":
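The `check_grad` calls above replace the inline gather-and-compare logic still visible in the removed lines of this file. A rough sketch of that per-parameter step, assuming the helper all-gathers the sharded gradient along `dim` over `tp_group` before comparing it with the unsharded gradient (the real helper also resolves the dotted layer names and honours `verbose`):

```python
# Rough sketch of the comparison step; not the actual helper in
# tests/test_shardformer/test_model/_utils.py.
import torch
import torch.distributed as dist

from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor


def compare_sharded_grad(org_param, shard_param, tp_group, dim, atol, rtol):
    org_grad, shard_grad = org_param.grad, shard_param.grad
    if is_distributed_tensor(shard_param) or is_customized_distributed_tensor(shard_param):
        # Collect every rank's shard and stitch the full gradient back together.
        world_size = dist.get_world_size(tp_group)
        gathered = [torch.zeros_like(shard_grad) for _ in range(world_size)]
        dist.all_gather(gathered, shard_grad, group=tp_group)
        shard_grad = torch.cat(gathered, dim=dim)
    assert torch.allclose(org_grad, shard_grad, atol=atol, rtol=rtol), \
        f"shard model grad is not equal to origin model grad\n{org_grad}\n{shard_grad}"
```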

View File

@@ -1,86 +0,0 @@
import copy
import os
import pytest
import torch
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.chatglm import ChatGLMForConditionalGenerationPolicy, ChatGLMModelPolicy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
@parameterize('enable_fused_normalization', [False])
@parameterize('enable_tensor_parallelism', [False])
@parameterize('use_lazy_init', [False])
def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
# create new model for test
inputs = data_gen_fn()
inputs = {k: v.cuda() for k, v in inputs.items()}
input_ids = inputs['input_ids']
hidden_size = 64
batch_size, seq_len = input_ids.shape
hidden_state_shape = (seq_len, batch_size, hidden_size)
if name == "transformers_chatglm":
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init, ChatGLMModelPolicy())
if stage_manager.is_last_stage():
hidden_states = torch.randn(*hidden_state_shape).cuda()
inputs['input_ids'] = None
inputs['hidden_states'] = hidden_states
outputs = sharded_model(**inputs)
if stage_manager.is_last_stage():
assert outputs[0].shape == hidden_state_shape
else:
assert outputs['hidden_states'].shape == hidden_state_shape
if name == "transformers_chatglm_for_conditional_generation":
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init,
ChatGLMForConditionalGenerationPolicy())
if stage_manager.is_last_stage():
hidden_states = torch.randn(*hidden_state_shape).cuda()
inputs['input_ids'] = None
inputs['hidden_states'] = hidden_states
outputs = sharded_model(**inputs)
if stage_manager.is_last_stage():
assert outputs[0].shape == (batch_size, seq_len, 65024)
else:
assert outputs['hidden_states'].shape == hidden_state_shape
torch.cuda.empty_cache()
def check_chatglm(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_chatglm_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_chatglm():
spawn(check_chatglm, 4)
if __name__ == "__main__":
test_chatglm()

View File

@@ -2,69 +2,139 @@ import os
import pytest
import torch
from torch import distributed as dist
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_loss,
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
)
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
# forward check
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-5)
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
# run backward
org_loss.backward()
shard_loss.backward()
org_loss, org_output, sharded_loss, sharded_output = \
run_forward_backward_with_hybrid_plugin(
org_model,
sharded_model,
sharded_optimizer,
data_gen_fn,
output_transform_fn,
criterion,
booster)
assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if org_model.__class__.__name__ == 'LlamaModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
# unwrap model
if hasattr(org_model, 'model'):
llama_model = org_model.model
shard_llama_model = sharded_model.model
else:
if org_model.__class__.__name__ == 'LlamaModel':
llama_model = org_model
shard_llama_model = sharded_model
shard_llama_model = sharded_model.unwrap()
else:
llama_model = org_model.model
shard_llama_model = sharded_model.unwrap().model
# check grad
col_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
row_layer_for_check = ['layers[0].self_attn.o_proj']
check_grad(llama_model, shard_llama_model, col_layer_for_check, atol=1e-6, rtol=1e-4, dim=0, verbose=False)
check_grad(llama_model, shard_llama_model, row_layer_for_check, atol=1e-6, rtol=1e-4, dim=1, verbose=False)
row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
col_layer_for_check = ['layers[0].self_attn.o_proj']
if stage_manager is None or stage_manager.is_first_stage():
check_grad(llama_model,
shard_llama_model,
row_layer_for_check,
tp_group,
atol=1e-6,
rtol=1e-4,
dim=0,
verbose=False)
check_grad(llama_model,
shard_llama_model,
col_layer_for_check,
tp_group,
atol=1e-6,
rtol=1e-4,
dim=1,
verbose=False)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
if stage_manager is None or stage_manager.is_first_stage():
check_weight(llama_model,
shard_llama_model,
col_layer_for_check,
tp_group,
atol=1e-4,
rtol=1e-3,
dim=1,
verbose=False)
torch.cuda.empty_cache()
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('enable_flash_attention', [True, False])
@parameterize('use_lazy_init', [False, True])
def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, use_lazy_init):
@parameterize('test_config', [{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 2,
'enable_fused_normalization': True,
'use_lazy_init': True
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'use_lazy_init': False
}, {
'tp_size': 4,
'pp_size': 1,
'enable_fused_normalization': True,
'use_lazy_init': False
}, {
'tp_size': 1,
'pp_size': 4,
'num_microbatches': 4,
'use_lazy_init': False
}])
def run_llama_test(test_config):
# TODO: add test_config for TP+DP after supporting & debugging it
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
# TODO: add test_config for flash attention & jit operator after supporting
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
enable_flash_attention, use_lazy_init)
check_state_dict(org_model, sharded_model, name=name)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
torch.cuda.empty_cache()
def check_llama(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_gpt2_llama()
run_llama_test()
@pytest.mark.dist

View File

@@ -1,89 +0,0 @@
import pytest
import torch
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.auto_policy import get_autopolicy
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.shard import ShardConfig
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
def check_llama_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager):
policy = get_autopolicy(model)
policy.set_model(model)
model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
policy.set_shard_config(model_config)
layers = policy.get_held_layers()
if stage_manager.is_first_stage():
assert len(layers) == 2 + 1
else:
if name == "transformers_llama":
assert len(layers) == 2 + 1
else:
assert len(layers) == 2 + 2
def check_llama_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager):
x = torch.randint(0, 1000, (2, 3)).cuda()
if stage_manager.stage == 0:
attention_mask = torch.ones_like(x).cuda()
output = sharded_model(input_ids=x, attention_mask=attention_mask)
assert output['hidden_states'].shape == (2, 3, 128)
else:
hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
attention_mask = torch.ones((2, 3)).cuda()
output = sharded_model(
hidden_states=hidden_states,
attention_mask=attention_mask,
)
assert output[0] is not None
@parameterize('enable_fused_normalization', [False])
@parameterize('enable_tensor_parallelism', [False])
@parameterize('use_lazy_init', [False])
#TODO: merge this into test_shard_llama
def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
PP_DIM = 0
PP_SIZE = 2
pg_mesh = ProcessGroupMesh(PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)
check_llama_model_policy(name, org_model, stage_manager)
check_llama_model_pipeline_forward(name, sharded_model, stage_manager)
torch.cuda.empty_cache()
def check_llama(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_llama_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama():
spawn(check_llama, 2)
if __name__ == "__main__":
test_llama()

View File

@@ -1,64 +1,129 @@
import copy
import os
import pytest
import torch
from torch import distributed as dist
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_loss,
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
)
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-5)
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
# run backward
org_loss.backward()
shard_loss.backward()
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
org_loss, org_output, sharded_loss, sharded_output = \
run_forward_backward_with_hybrid_plugin(
org_model,
sharded_model,
sharded_optimizer,
data_gen_fn,
output_transform_fn,
criterion,
booster)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if org_model.__class__.__name__ == 'OPTModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
# unwrap model
if hasattr(org_model, 'model'):
opt_model = org_model.model
shard_opt_model = sharded_model.model
else:
if org_model.__class__.__name__ == 'OPTModel':
opt_model = org_model
shard_opt_model = sharded_model
shard_opt_model = sharded_model.unwrap()
else:
opt_model = org_model.model
shard_opt_model = sharded_model.unwrap().model
# check grad
col_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens']
row_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
check_grad(opt_model, shard_opt_model, col_layer_for_check, atol=1e-6, rtol=1e-3, dim=0, verbose=False)
check_grad(opt_model, shard_opt_model, row_layer_for_check, atol=1e-6, rtol=1e-3, dim=1, verbose=False)
row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens']
col_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
if stage_manager is None or stage_manager.is_first_stage():
check_grad(opt_model,
shard_opt_model,
row_layer_for_check,
tp_group,
atol=1e-6,
rtol=1e-3,
dim=0,
verbose=False)
check_grad(opt_model,
shard_opt_model,
col_layer_for_check,
tp_group,
atol=1e-6,
rtol=1e-3,
dim=1,
verbose=False)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
if stage_manager is None or stage_manager.is_first_stage():
check_weight(opt_model,
shard_opt_model,
col_layer_for_check,
tp_group,
atol=1e-3,
rtol=1e-3,
dim=1,
verbose=False)
torch.cuda.empty_cache()
@parameterize('use_lazy_init', [False, True])
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
def run_opt_test(use_lazy_init, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention,
enable_jit_fused):
@parameterize('test_config', [{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_fused_normalization': True,
'use_lazy_init': True
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'enable_fused_normalization': False,
'use_lazy_init': False
}, {
'tp_size': 4,
'pp_size': 1,
'enable_fused_normalization': True,
'use_lazy_init': False
}])
def run_opt_test(test_config):
# TODO: add test_config for TP+DP after supporting & debugging it
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
# TODO: add test_config for flash attention & jit operator after supporting
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
enable_flash_attention, enable_jit_fused, use_lazy_init)
check_state_dict(org_model, sharded_model, name=name)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
torch.cuda.empty_cache()
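The shared `run_forward_backward_with_hybrid_plugin` helper is likewise defined in `_utils.py` and not shown here. One plausible shape for it, assuming the booster drives the microbatch schedule whenever a pipeline stage manager is present, is sketched below; the actual helper also applies `output_transform_fn` and may prepare data and collect outputs differently:

```python
# Illustrative sketch only; output_transform_fn handling is omitted for brevity.
def run_forward_backward_sketch(org_model, sharded_model, sharded_optimizer, data_gen_fn, criterion, booster):
    data = {k: v.cuda() for k, v in data_gen_fn().items()}

    # Unsharded reference run on a single device.
    org_output = org_model(**data)
    org_loss = criterion(org_output)
    org_loss.backward()

    if booster.plugin.stage_manager is None:
        # Tensor parallelism only: an ordinary forward/backward on the sharded model.
        sharded_output = sharded_model(**data)
        sharded_loss = criterion(sharded_output)
        sharded_optimizer.backward(sharded_loss)
    else:
        # Pipeline parallelism: let the booster run the microbatch schedule.
        outputs = booster.execute_pipeline(iter([data]),
                                           sharded_model,
                                           lambda out, inp: criterion(out),
                                           sharded_optimizer,
                                           return_loss=True,
                                           return_outputs=True)
        sharded_loss, sharded_output = outputs['loss'], outputs['outputs']

    return org_loss, org_output, sharded_loss, sharded_output
```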

View File

@@ -1,70 +0,0 @@
import pytest
import torch
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_pipeline_model
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# TODO: add tests for forward/backward later
pass
@parameterize('enable_tensor_parallelism', [False])
@parameterize('enable_fused_normalization', [False])
@parameterize('use_lazy_init', [False])
#TODO: merge this into test_shard_opt
def run_opt_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
inputs = data_gen_fn()
inputs = {k: v.cuda() for k, v in inputs.items()}
input_ids, _ = inputs['input_ids'], inputs['attention_mask']
batch_size, seq_len = input_ids.shape
hidden_size = 128
hidden_state_shape = (batch_size, seq_len, hidden_size)
if not stage_manager.is_first_stage():
# change inputs if not the first stage
hidden_states = torch.zeros(*hidden_state_shape).cuda()
inputs['input_ids'] = None
inputs['hidden_states'] = hidden_states
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)
sharded_model.train()
output = sharded_model(**inputs)
if stage_manager.is_last_stage():
assert output[0] is not None
else:
assert output['hidden_states'].shape == hidden_state_shape
torch.cuda.empty_cache()
def check_opt(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_opt_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_opt():
spawn(check_opt, 4)
if __name__ == "__main__":
test_opt()

View File

@@ -1,60 +1,127 @@
import os
import pytest
import torch
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, check_grad, run_forward
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_loss,
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
)
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# check forward
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
assert_hf_output_close(org_output, shard_output, atol=1e-3, rtol=1e-3)
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
# do backward
org_loss.backward()
shard_loss.backward()
org_loss, org_output, sharded_loss, sharded_output = \
run_forward_backward_with_hybrid_plugin(
org_model,
sharded_model,
sharded_optimizer,
data_gen_fn,
output_transform_fn,
criterion,
booster)
assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if org_model.__class__.__name__ == 'ViTModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
check_loss(org_loss, sharded_loss, atol=1e-6, rtol=1e-3)
# unwrap model
if org_model.__class__.__name__ == 'ViTModel':
vit_model = org_model
shard_vit_model = sharded_model
shard_vit_model = sharded_model.unwrap()
else:
vit_model = org_model.vit
shard_vit_model = sharded_model.vit
shard_vit_model = sharded_model.unwrap().vit
# check grad
col_layer_for_check = ['encoder.layer[0].attention.attention.query']
row_layer_for_check = ['encoder.layer[0].attention.output.dense']
check_grad(vit_model, shard_vit_model, col_layer_for_check, atol=1e-5, rtol=1e-3, dim=0, verbose=False)
check_grad(vit_model, shard_vit_model, row_layer_for_check, atol=1e-5, rtol=1e-3, dim=1, verbose=False)
row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection']
col_layer_for_check = ['encoder.layer[0].attention.output.dense']
if stage_manager is None or stage_manager.is_first_stage():
check_grad(vit_model,
shard_vit_model,
row_layer_for_check,
tp_group,
atol=1e-5,
rtol=1e-3,
dim=0,
verbose=False)
check_grad(vit_model,
shard_vit_model,
col_layer_for_check,
tp_group,
atol=1e-5,
rtol=1e-3,
dim=1,
verbose=False)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
if stage_manager is None or stage_manager.is_first_stage():
check_weight(vit_model,
shard_vit_model,
col_layer_for_check,
tp_group,
atol=5e-3,
rtol=1e-3,
dim=1,
verbose=False)
torch.cuda.empty_cache()
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
@parameterize('test_config', [{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_fused_normalization': True,
'use_lazy_init': False
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'enable_fused_normalization': False,
'use_lazy_init': False
}, {
'tp_size': 4,
'pp_size': 1,
'enable_fused_normalization': True,
'use_lazy_init': False
}])
def run_vit_test(test_config):
# TODO: add test_config for TP+DP after supporting & debugging it
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
# TODO: add test_config for flash attention & jit operator after supporting
# TODO: fix bug when setting lazy_init for Conv2D layers in ViT models
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
enable_flash_attention, enable_jit_fused)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
@@ -68,7 +135,7 @@ def check_vit(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_vit():
spawn(check_vit, 2)
spawn(check_vit, 4)
if __name__ == "__main__":

View File

@@ -1,74 +0,0 @@
import pytest
import torch
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_pipeline_model
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# TODO: add tests for forward/backward later
pass
@parameterize('enable_tensor_parallelism', [False])
@parameterize('enable_fused_normalization', [False])
@parameterize('use_lazy_init', [False])
#TODO: merge this into test_shard_vit
def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
inputs = data_gen_fn()
inputs = {k: v.cuda() for k, v in inputs.items()}
pixel_values = inputs['pixel_values']
batch_size = len(pixel_values)
hidden_size = 768
hidden_state_shape = (batch_size, 197, hidden_size)
if not stage_manager.is_first_stage():
# change inputs if not the first stage
hidden_states = torch.randn(*hidden_state_shape).cuda()
# inputs['pixel_values'] = None
inputs['hidden_states'] = hidden_states
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)
sharded_model.train()
output = sharded_model(**inputs)
if stage_manager.is_last_stage():
if name != 'transformers_vit':
assert output.loss is not None
else:
assert output['hidden_states'].shape == hidden_state_shape, \
f'hidden_states shape is not correct, output:{output["hidden_states"].shape} is not equal to hidden_state:{hidden_state_shape}'
torch.cuda.empty_cache()
def check_vit(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_vit_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_vit():
spawn(check_vit, 4)
if __name__ == "__main__":
test_vit()