mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] rewrite t5 tests & support multi-tensor transmitting in pipeline (#4388)
* fix remaining t5 bugs/rewrite t5 tests * fix multi-tensor communication in pipeline * rearrange test_config * fix keyerror in sync_shared_params * fix get_held_layers & Randomnizer, complete t5 tests * erase printing * fix get_held_layers through modifying _release_unheld_layers * fix _get_recursive_held_layers bugpull/4445/head
parent
906426cb44
commit
ed4c448488
|
@ -50,8 +50,10 @@ class HybridParallelModule(ModelWrapper):
|
||||||
|
|
||||||
def sync_shared_params(self):
|
def sync_shared_params(self):
|
||||||
for shared_param, group in zip(self.shared_params, self.shared_param_process_groups):
|
for shared_param, group in zip(self.shared_params, self.shared_param_process_groups):
|
||||||
|
if self.stage_manager.stage in shared_param:
|
||||||
param = shared_param[self.stage_manager.stage]
|
param = shared_param[self.stage_manager.stage]
|
||||||
dist.all_reduce(param.grad, group=group)
|
dist.all_reduce(param.grad, group=group)
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
def no_sync(self) -> Iterator[None]:
|
def no_sync(self) -> Iterator[None]:
|
||||||
# no sync grads across data parallel
|
# no sync grads across data parallel
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import pickle
|
import pickle
|
||||||
|
import re
|
||||||
from typing import Any, List, Optional, Union
|
from typing import Any, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -31,7 +32,10 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
|
||||||
if b'cuda' in buf:
|
if b'cuda' in buf:
|
||||||
buf_array = bytearray(buf)
|
buf_array = bytearray(buf)
|
||||||
device_index = torch.cuda.current_device()
|
device_index = torch.cuda.current_device()
|
||||||
buf_array[buf_array.find(b'cuda') + 5] = 48 + device_index
|
# There might be more than one output tensors during forward
|
||||||
|
for cuda_str in re.finditer(b'cuda', buf_array):
|
||||||
|
pos = cuda_str.start()
|
||||||
|
buf_array[pos + 5] = 48 + device_index
|
||||||
buf = bytes(buf_array)
|
buf = bytes(buf_array)
|
||||||
|
|
||||||
io_bytes = io.BytesIO(buf)
|
io_bytes = io.BytesIO(buf)
|
||||||
|
|
|
@ -86,7 +86,7 @@ def retain_grad(x: Any) -> None:
|
||||||
Args:
|
Args:
|
||||||
x (Any): Object to be called.
|
x (Any): Object to be called.
|
||||||
"""
|
"""
|
||||||
if isinstance(x, torch.Tensor):
|
if isinstance(x, torch.Tensor) and x.requires_grad:
|
||||||
x.retain_grad()
|
x.retain_grad()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -107,8 +107,15 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
||||||
if output_obj_grad is None:
|
if output_obj_grad is None:
|
||||||
optimizer.backward(output_obj)
|
optimizer.backward(output_obj)
|
||||||
else:
|
else:
|
||||||
|
if "backward_tensor_keys" not in output_obj:
|
||||||
for k, grad in output_obj_grad.items():
|
for k, grad in output_obj_grad.items():
|
||||||
optimizer.backward_by_grad(output_obj[k], grad)
|
optimizer.backward_by_grad(output_obj[k], grad)
|
||||||
|
else:
|
||||||
|
for k, grad in output_obj_grad.items():
|
||||||
|
output_obj[k].grad = grad
|
||||||
|
for k in output_obj["backward_tensor_keys"]:
|
||||||
|
tensor_to_backward = output_obj[k]
|
||||||
|
optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
|
||||||
|
|
||||||
# Collect the grad of the input_obj.
|
# Collect the grad of the input_obj.
|
||||||
input_obj_grad = None
|
input_obj_grad = None
|
||||||
|
|
|
@ -122,6 +122,13 @@ class Randomizer:
|
||||||
"""
|
"""
|
||||||
Randomizer._INDEX += 1
|
Randomizer._INDEX += 1
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def reset_index():
|
||||||
|
"""
|
||||||
|
Reset the index to zero.
|
||||||
|
"""
|
||||||
|
Randomizer._INDEX = 0
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_randomizer_index_synchronized(process_group: ProcessGroup = None):
|
def is_randomizer_index_synchronized(process_group: ProcessGroup = None):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -238,7 +238,8 @@ class T5PipelineForwards:
|
||||||
return {
|
return {
|
||||||
'hidden_states': hidden_states,
|
'hidden_states': hidden_states,
|
||||||
'position_bias': position_bias,
|
'position_bias': position_bias,
|
||||||
'encoder_decoder_position_bias': encoder_decoder_position_bias
|
'encoder_decoder_position_bias': encoder_decoder_position_bias,
|
||||||
|
'backward_tensor_keys': ['hidden_states']
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -261,8 +262,10 @@ class T5PipelineForwards:
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
stage_manager: Optional[PipelineStageManager] = None,
|
stage_manager: Optional[PipelineStageManager] = None,
|
||||||
hidden_states: Optional[torch.FloatTensor] = None,
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
position_bias: Optional[torch.Tensor] = None,
|
position_bias: Optional[torch.Tensor] = None,
|
||||||
encoder_decoder_position_bias: Optional[torch.Tensor] = None,
|
encoder_decoder_position_bias: Optional[torch.Tensor] = None,
|
||||||
|
backward_tensor_keys: Optional[List[str]] = None,
|
||||||
stage_index: Optional[List[int]] = None,
|
stage_index: Optional[List[int]] = None,
|
||||||
decoder_starting_stage: Optional[int] = None,
|
decoder_starting_stage: Optional[int] = None,
|
||||||
) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
|
) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
|
||||||
|
@ -303,7 +306,6 @@ class T5PipelineForwards:
|
||||||
decoder_head_mask = head_mask
|
decoder_head_mask = head_mask
|
||||||
|
|
||||||
in_decoder = stage_manager.stage >= decoder_starting_stage
|
in_decoder = stage_manager.stage >= decoder_starting_stage
|
||||||
|
|
||||||
# Stage is in encoder, directly return the output of t5_stack_forward
|
# Stage is in encoder, directly return the output of t5_stack_forward
|
||||||
if not in_decoder:
|
if not in_decoder:
|
||||||
encoder_outputs = T5PipelineForwards.t5_stack_forward(
|
encoder_outputs = T5PipelineForwards.t5_stack_forward(
|
||||||
|
@ -323,25 +325,18 @@ class T5PipelineForwards:
|
||||||
decoder_starting_stage=decoder_starting_stage)
|
decoder_starting_stage=decoder_starting_stage)
|
||||||
if stage_manager.stage == decoder_starting_stage - 1:
|
if stage_manager.stage == decoder_starting_stage - 1:
|
||||||
# last stage of encoder
|
# last stage of encoder
|
||||||
return {'encoder_outputs': encoder_outputs}
|
return {'encoder_hidden_states': encoder_outputs[0]}
|
||||||
else:
|
else:
|
||||||
return encoder_outputs
|
return encoder_outputs
|
||||||
|
|
||||||
at_last_decoder_stage = stage_manager.is_last_stage()
|
at_last_decoder_stage = stage_manager.is_last_stage()
|
||||||
at_first_decoder_stage = stage_manager.stage == decoder_starting_stage
|
at_first_decoder_stage = stage_manager.stage == decoder_starting_stage
|
||||||
|
|
||||||
if encoder_outputs is None:
|
if encoder_outputs is not None:
|
||||||
raise ValueError("Non-empty encoder_outputs should be passed in at decoder stages.")
|
|
||||||
|
|
||||||
encoder_hidden_states = encoder_outputs[0]
|
encoder_hidden_states = encoder_outputs[0]
|
||||||
if return_dict and not isinstance(encoder_outputs, BaseModelOutput):
|
elif encoder_hidden_states is None:
|
||||||
encoder_outputs = BaseModelOutput(
|
raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.")
|
||||||
last_hidden_state=encoder_outputs[0],
|
|
||||||
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
|
||||||
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Stage is in decoder, we assume that the outputs of last stage of encoder will be passed in.
|
|
||||||
if not at_first_decoder_stage and hidden_states is None:
|
if not at_first_decoder_stage and hidden_states is None:
|
||||||
raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.")
|
raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.")
|
||||||
|
|
||||||
|
@ -360,6 +355,7 @@ class T5PipelineForwards:
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
stage_manager=stage_manager,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
position_bias=position_bias,
|
position_bias=position_bias,
|
||||||
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
||||||
|
@ -368,22 +364,19 @@ class T5PipelineForwards:
|
||||||
|
|
||||||
# Directly return outputs of overloaded T5Stack forward if not at last stage.
|
# Directly return outputs of overloaded T5Stack forward if not at last stage.
|
||||||
if not at_last_decoder_stage:
|
if not at_last_decoder_stage:
|
||||||
decoder_outputs['encoder_outputs'] = encoder_outputs # encoder_outputs should be passed to the next stage
|
# encoder_hidden_states should be passed to the next stage
|
||||||
|
decoder_outputs['encoder_hidden_states'] = encoder_hidden_states
|
||||||
return decoder_outputs
|
return decoder_outputs
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return decoder_outputs + encoder_outputs
|
return decoder_outputs + encoder_hidden_states
|
||||||
|
else:
|
||||||
return Seq2SeqModelOutput(
|
return Seq2SeqModelOutput(last_hidden_state=decoder_outputs.last_hidden_state,
|
||||||
last_hidden_state=decoder_outputs.last_hidden_state,
|
|
||||||
past_key_values=decoder_outputs.past_key_values,
|
past_key_values=decoder_outputs.past_key_values,
|
||||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
decoder_attentions=decoder_outputs.attentions,
|
decoder_attentions=decoder_outputs.attentions,
|
||||||
cross_attentions=decoder_outputs.cross_attentions,
|
cross_attentions=decoder_outputs.cross_attentions,
|
||||||
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
encoder_last_hidden_state=encoder_hidden_states)
|
||||||
encoder_hidden_states=encoder_outputs.hidden_states,
|
|
||||||
encoder_attentions=encoder_outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def t5_for_conditional_generation_forward(
|
def t5_for_conditional_generation_forward(
|
||||||
|
@ -406,8 +399,10 @@ class T5PipelineForwards:
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
stage_manager: Optional[PipelineStageManager] = None,
|
stage_manager: Optional[PipelineStageManager] = None,
|
||||||
hidden_states: Optional[torch.FloatTensor] = None,
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
position_bias: Optional[torch.Tensor] = None,
|
position_bias: Optional[torch.Tensor] = None,
|
||||||
encoder_decoder_position_bias: Optional[torch.Tensor] = None,
|
encoder_decoder_position_bias: Optional[torch.Tensor] = None,
|
||||||
|
backward_tensor_keys: Optional[List[str]] = None,
|
||||||
stage_index: Optional[List[int]] = None,
|
stage_index: Optional[List[int]] = None,
|
||||||
decoder_starting_stage: Optional[int] = None,
|
decoder_starting_stage: Optional[int] = None,
|
||||||
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
|
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
|
||||||
|
@ -468,28 +463,25 @@ class T5PipelineForwards:
|
||||||
decoder_starting_stage=decoder_starting_stage)
|
decoder_starting_stage=decoder_starting_stage)
|
||||||
if stage_manager.stage == decoder_starting_stage - 1:
|
if stage_manager.stage == decoder_starting_stage - 1:
|
||||||
# last stage of encoder
|
# last stage of encoder
|
||||||
return {'encoder_outputs': encoder_outputs}
|
return {'encoder_hidden_states': encoder_outputs[0]}
|
||||||
else:
|
else:
|
||||||
return encoder_outputs
|
return encoder_outputs
|
||||||
|
|
||||||
at_last_decoder_stage = stage_manager.is_last_stage()
|
at_last_decoder_stage = stage_manager.is_last_stage()
|
||||||
at_first_decoder_stage = stage_manager.stage == decoder_starting_stage
|
at_first_decoder_stage = stage_manager.stage == decoder_starting_stage
|
||||||
|
|
||||||
if encoder_outputs is None:
|
if encoder_outputs is not None:
|
||||||
raise ValueError("Non-empty encoder_outputs should be passed in at decoder stages.")
|
|
||||||
|
|
||||||
encoder_hidden_states = encoder_outputs[0]
|
encoder_hidden_states = encoder_outputs[0]
|
||||||
if return_dict and not isinstance(encoder_outputs, BaseModelOutput):
|
elif encoder_hidden_states is None:
|
||||||
encoder_outputs = BaseModelOutput(
|
raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.")
|
||||||
last_hidden_state=encoder_outputs[0],
|
|
||||||
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
|
||||||
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Stage is in decoder, we assume that the outputs of last stage of encoder will be passed in.
|
|
||||||
if not at_first_decoder_stage and hidden_states is None:
|
if not at_first_decoder_stage and hidden_states is None:
|
||||||
raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.")
|
raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.")
|
||||||
|
|
||||||
|
if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||||
|
# get decoder inputs from shifting lm labels to the right
|
||||||
|
decoder_input_ids = self._shift_right(labels)
|
||||||
|
|
||||||
# Decode
|
# Decode
|
||||||
decoder_outputs = T5PipelineForwards.t5_stack_forward(
|
decoder_outputs = T5PipelineForwards.t5_stack_forward(
|
||||||
self.decoder,
|
self.decoder,
|
||||||
|
@ -505,6 +497,7 @@ class T5PipelineForwards:
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
stage_manager=stage_manager,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
position_bias=position_bias,
|
position_bias=position_bias,
|
||||||
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
||||||
|
@ -513,7 +506,8 @@ class T5PipelineForwards:
|
||||||
|
|
||||||
# Directly return outputs of overloaded T5Stack forward if not at last stage.
|
# Directly return outputs of overloaded T5Stack forward if not at last stage.
|
||||||
if not at_last_decoder_stage:
|
if not at_last_decoder_stage:
|
||||||
decoder_outputs['encoder_outputs'] = encoder_outputs # encoder_outputs should be passed to the next stage
|
# encoder_hidden_states should be passed to the next stage
|
||||||
|
decoder_outputs['encoder_hidden_states'] = encoder_hidden_states
|
||||||
return decoder_outputs
|
return decoder_outputs
|
||||||
|
|
||||||
sequence_output = decoder_outputs[0]
|
sequence_output = decoder_outputs[0]
|
||||||
|
@ -533,20 +527,16 @@ class T5PipelineForwards:
|
||||||
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
|
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
|
output = (lm_logits,) + decoder_outputs[1:] + encoder_hidden_states
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
return Seq2SeqLMOutput(
|
return Seq2SeqLMOutput(loss=loss,
|
||||||
loss=loss,
|
|
||||||
logits=lm_logits,
|
logits=lm_logits,
|
||||||
past_key_values=decoder_outputs.past_key_values,
|
past_key_values=decoder_outputs.past_key_values,
|
||||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
decoder_attentions=decoder_outputs.attentions,
|
decoder_attentions=decoder_outputs.attentions,
|
||||||
cross_attentions=decoder_outputs.cross_attentions,
|
cross_attentions=decoder_outputs.cross_attentions,
|
||||||
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
encoder_last_hidden_state=encoder_hidden_states)
|
||||||
encoder_hidden_states=encoder_outputs.hidden_states,
|
|
||||||
encoder_attentions=encoder_outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def t5_encoder_model_forward(
|
def t5_encoder_model_forward(
|
||||||
|
@ -562,6 +552,7 @@ class T5PipelineForwards:
|
||||||
hidden_states: Optional[torch.FloatTensor] = None,
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
position_bias: Optional[torch.Tensor] = None,
|
position_bias: Optional[torch.Tensor] = None,
|
||||||
encoder_decoder_position_bias: Optional[torch.Tensor] = None,
|
encoder_decoder_position_bias: Optional[torch.Tensor] = None,
|
||||||
|
backward_tensor_keys: Optional[List[str]] = None,
|
||||||
stage_index: Optional[List[int]] = None,
|
stage_index: Optional[List[int]] = None,
|
||||||
decoder_starting_stage: Optional[int] = None,
|
decoder_starting_stage: Optional[int] = None,
|
||||||
) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
|
) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
|
||||||
|
|
|
@ -260,7 +260,7 @@ class T5BasePolicy(Policy):
|
||||||
|
|
||||||
model = self.model
|
model = self.model
|
||||||
encoder = self.model.encoder
|
encoder = self.model.encoder
|
||||||
decoder = self.model.__dict__.get('decoder', None)
|
decoder = getattr(self.model, 'decoder', None)
|
||||||
|
|
||||||
num_encoder_layers = len(encoder.block)
|
num_encoder_layers = len(encoder.block)
|
||||||
num_decoder_layers = len(decoder.block) if decoder else 0
|
num_decoder_layers = len(decoder.block) if decoder else 0
|
||||||
|
@ -300,7 +300,7 @@ class T5BasePolicy(Policy):
|
||||||
stage_manager = self.pipeline_stage_manager
|
stage_manager = self.pipeline_stage_manager
|
||||||
|
|
||||||
encoder = self.model.encoder
|
encoder = self.model.encoder
|
||||||
decoder = self.model.__dict__.get('decoder', None)
|
decoder = getattr(self.model, 'decoder', None)
|
||||||
|
|
||||||
num_encoder_layers = len(encoder.block)
|
num_encoder_layers = len(encoder.block)
|
||||||
num_decoder_layers = len(decoder.block) if decoder else 0
|
num_decoder_layers = len(decoder.block) if decoder else 0
|
||||||
|
@ -355,15 +355,6 @@ class T5ModelPolicy(T5BasePolicy):
|
||||||
return [{0: module.shared.weight, decoder_starting_stage: module.decoder.embed_tokens.weight}]
|
return [{0: module.shared.weight, decoder_starting_stage: module.decoder.embed_tokens.weight}]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def postprocess(self):
|
|
||||||
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
|
|
||||||
binding_map = {"shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]}
|
|
||||||
for k, v in binding_map.items():
|
|
||||||
src = getattr_(self.model, k)
|
|
||||||
for dst in v:
|
|
||||||
setattr_(self.model, dst, src)
|
|
||||||
return self.model
|
|
||||||
|
|
||||||
|
|
||||||
class T5ForConditionalGenerationPolicy(T5BasePolicy):
|
class T5ForConditionalGenerationPolicy(T5BasePolicy):
|
||||||
|
|
||||||
|
@ -409,29 +400,22 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
|
||||||
stage_manager.num_stages)
|
stage_manager.num_stages)
|
||||||
|
|
||||||
shared_params = []
|
shared_params = []
|
||||||
|
shared_embedding = {}
|
||||||
if id(module.decoder.embed_tokens.weight) == id(module.shared.weight):
|
if id(module.decoder.embed_tokens.weight) == id(module.shared.weight):
|
||||||
shared_params.append({
|
shared_embedding[0] = module.shared.weight
|
||||||
0: module.shared.weight,
|
shared_embedding[decoder_starting_stage] = module.decoder.embed_tokens.weight
|
||||||
decoder_starting_stage: module.decoder.embed_tokens.weight
|
|
||||||
})
|
|
||||||
if id(module.lm_head.weight) == id(module.shared.weight):
|
if id(module.lm_head.weight) == id(module.shared.weight):
|
||||||
shared_params.append({0: module.shared.weight, stage_manager.num_stages - 1: module.lm_head.weight})
|
shared_embedding[0] = module.shared.weight
|
||||||
|
shared_embedding[stage_manager.num_stages - 1] = module.lm_head.weight
|
||||||
|
|
||||||
|
if len(shared_embedding) > 0:
|
||||||
|
shared_params.append(shared_embedding)
|
||||||
|
|
||||||
return shared_params
|
return shared_params
|
||||||
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def postprocess(self):
|
|
||||||
super().postprocess()
|
|
||||||
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
|
|
||||||
binding_map = {
|
|
||||||
"shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
|
|
||||||
}
|
|
||||||
for k, v in binding_map.items():
|
|
||||||
src = getattr_(self.model, k)
|
|
||||||
for dst in v:
|
|
||||||
setattr_(self.model, dst, src)
|
|
||||||
|
|
||||||
return self.model
|
|
||||||
|
|
||||||
|
|
||||||
class T5EncoderPolicy(T5BasePolicy):
|
class T5EncoderPolicy(T5BasePolicy):
|
||||||
|
|
||||||
|
@ -462,12 +446,3 @@ class T5EncoderPolicy(T5BasePolicy):
|
||||||
|
|
||||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def postprocess(self):
|
|
||||||
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
|
|
||||||
binding_map = {"shared.weight": ["encoder.embed_tokens.weight"]}
|
|
||||||
for k, v in binding_map.items():
|
|
||||||
src = getattr_(self.model, k)
|
|
||||||
for dst in v:
|
|
||||||
setattr_(self.model, dst, src)
|
|
||||||
return self.model
|
|
||||||
|
|
|
@ -198,6 +198,20 @@ class ModelSharder(object):
|
||||||
|
|
||||||
setattr_(org_layer, suffix, replace_layer)
|
setattr_(org_layer, suffix, replace_layer)
|
||||||
|
|
||||||
|
def _get_recursive_held_layers(self, held_layers: Optional[List[nn.Module]]) -> Optional[List[nn.Module]]:
|
||||||
|
|
||||||
|
def collect_sub_modules(module: nn.Module):
|
||||||
|
if module is None:
|
||||||
|
return
|
||||||
|
recursive_held_layers.append(module)
|
||||||
|
for name, child in module.named_children():
|
||||||
|
collect_sub_modules(child)
|
||||||
|
|
||||||
|
recursive_held_layers = []
|
||||||
|
for module in held_layers:
|
||||||
|
collect_sub_modules(module)
|
||||||
|
return recursive_held_layers
|
||||||
|
|
||||||
def _release_unheld_layers(self) -> Optional[Set[nn.Module]]:
|
def _release_unheld_layers(self) -> Optional[Set[nn.Module]]:
|
||||||
r"""
|
r"""
|
||||||
Release the unheld layers in the model
|
Release the unheld layers in the model
|
||||||
|
@ -205,7 +219,7 @@ class ModelSharder(object):
|
||||||
if self.shard_config and self.shard_config.pipeline_stage_manager:
|
if self.shard_config and self.shard_config.pipeline_stage_manager:
|
||||||
held_layers = self.policy.get_held_layers()
|
held_layers = self.policy.get_held_layers()
|
||||||
set_tensors_to_none(self.model, exclude=set(held_layers))
|
set_tensors_to_none(self.model, exclude=set(held_layers))
|
||||||
return set(held_layers)
|
return set(self._get_recursive_held_layers(held_layers))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _materialize(self) -> None:
|
def _materialize(self) -> None:
|
||||||
|
|
|
@ -68,16 +68,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
@parameterize('test_config', [{
|
@parameterize('test_config', [{
|
||||||
'tp_size': 1,
|
|
||||||
'pp_size': 2,
|
|
||||||
'num_microbatches': 4,
|
|
||||||
'use_lazy_init': True
|
|
||||||
}, {
|
|
||||||
'tp_size': 2,
|
'tp_size': 2,
|
||||||
'pp_size': 2,
|
'pp_size': 2,
|
||||||
'num_microbatches': 4,
|
'num_microbatches': 4,
|
||||||
'enable_fused_normalization': False,
|
'enable_fused_normalization': True,
|
||||||
|
'use_lazy_init': True
|
||||||
|
}, {
|
||||||
|
'tp_size': 1,
|
||||||
|
'pp_size': 2,
|
||||||
|
'num_microbatches': 4,
|
||||||
'use_lazy_init': False
|
'use_lazy_init': False
|
||||||
}, {
|
}, {
|
||||||
'tp_size': 4,
|
'tp_size': 4,
|
||||||
|
|
|
@ -1,60 +1,110 @@
|
||||||
import os
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.logging import disable_existing_loggers
|
from colossalai.logging import disable_existing_loggers
|
||||||
from colossalai.testing import (
|
from colossalai.shardformer.layer.utils import Randomizer
|
||||||
assert_hf_output_close,
|
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
||||||
clear_cache_before_run,
|
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||||
parameterize,
|
|
||||||
rerun_if_address_is_in_use,
|
|
||||||
spawn,
|
|
||||||
)
|
|
||||||
from tests.kit.model_zoo import model_zoo
|
from tests.kit.model_zoo import model_zoo
|
||||||
from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
|
from tests.test_shardformer.test_model._utils import (
|
||||||
|
build_model_from_hybrid_plugin,
|
||||||
|
check_grad,
|
||||||
|
check_loss,
|
||||||
|
check_output_hidden_state,
|
||||||
|
check_weight,
|
||||||
|
run_forward_backward_with_hybrid_plugin,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
|
||||||
# check forward
|
|
||||||
# the value "past_key_values" is sharded, so we ignore
|
|
||||||
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
|
|
||||||
output_transform_fn, loss_fn)
|
|
||||||
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], atol=1e-5)
|
|
||||||
|
|
||||||
# do backward
|
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
|
||||||
org_loss.backward()
|
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
|
||||||
shard_loss.backward()
|
|
||||||
|
|
||||||
assert torch.allclose(org_loss, shard_loss,
|
org_loss, org_output, sharded_loss, sharded_output = \
|
||||||
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
run_forward_backward_with_hybrid_plugin(
|
||||||
|
org_model,
|
||||||
|
sharded_model,
|
||||||
|
sharded_optimizer,
|
||||||
|
data_gen_fn,
|
||||||
|
output_transform_fn,
|
||||||
|
criterion,
|
||||||
|
booster)
|
||||||
|
|
||||||
# check grad
|
stage_manager = booster.plugin.stage_manager
|
||||||
col_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.q', 'shared']
|
tp_group = booster.plugin.tp_group
|
||||||
row_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.relative_attention_bias']
|
|
||||||
check_grad(org_model, sharded_model, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
|
|
||||||
check_grad(org_model, sharded_model, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False)
|
|
||||||
|
|
||||||
# check weights are tied
|
# check last hidden state & loss
|
||||||
if hasattr(org_model, 'lm_head'):
|
if stage_manager is None or stage_manager.is_last_stage():
|
||||||
assert org_model.shared.weight.data.data_ptr() == org_model.lm_head.weight.data.data_ptr()
|
|
||||||
assert sharded_model.shared.weight.data.data_ptr() == sharded_model.lm_head.weight.data.data_ptr()
|
if org_model.__class__.__name__ != 'T5ForConditionalGeneration':
|
||||||
|
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
|
||||||
|
|
||||||
|
check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)
|
||||||
|
|
||||||
|
# unwrap model
|
||||||
|
t5 = org_model
|
||||||
|
sharded_t5 = sharded_model.unwrap()
|
||||||
|
|
||||||
|
row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']
|
||||||
|
|
||||||
|
# check weights and gradients
|
||||||
|
if stage_manager is None or stage_manager.is_first_stage():
|
||||||
|
check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=1e-5, rtol=1e-3, dim=0)
|
||||||
|
|
||||||
|
# check weights after optimizer.step()
|
||||||
|
org_optimizer.step()
|
||||||
|
sharded_optimizer.step()
|
||||||
|
if stage_manager is None or stage_manager.is_first_stage():
|
||||||
|
check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False)
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
@parameterize('enable_fused_normalization', [True, False])
|
@parameterize('test_config', [{
|
||||||
@parameterize('enable_tensor_parallelism', [True, False])
|
'tp_size': 2,
|
||||||
@parameterize('use_lazy_init', [False, True])
|
'pp_size': 2,
|
||||||
@parameterize('enable_flash_attention', [True, False])
|
'num_microbatches': 2,
|
||||||
@parameterize('enable_jit_fused', [True, False])
|
'enable_fused_normalization': True,
|
||||||
def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init, enable_flash_attention,
|
'use_lazy_init': True
|
||||||
enable_jit_fused):
|
}, {
|
||||||
|
'tp_size': 1,
|
||||||
|
'pp_size': 2,
|
||||||
|
'num_microbatches': 4,
|
||||||
|
'use_lazy_init': False
|
||||||
|
}, {
|
||||||
|
'tp_size': 4,
|
||||||
|
'pp_size': 1,
|
||||||
|
'enable_fused_normalization': True,
|
||||||
|
'use_lazy_init': False
|
||||||
|
}, {
|
||||||
|
'tp_size': 1,
|
||||||
|
'pp_size': 4,
|
||||||
|
'num_microbatches': 4,
|
||||||
|
'use_lazy_init': False
|
||||||
|
}])
|
||||||
|
@clear_cache_before_run()
|
||||||
|
def run_t5_test(test_config):
|
||||||
|
|
||||||
|
# TODO: add plugin_config for TP+DP after supporting & debugging it
|
||||||
|
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
|
||||||
|
|
||||||
|
# TODO: add test_config for flash attention & jit operator after supporting
|
||||||
|
|
||||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
|
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
|
||||||
|
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing
|
||||||
|
|
||||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
|
||||||
enable_flash_attention, enable_jit_fused, use_lazy_init)
|
# skip 4-stage pp test for t5_encoder
|
||||||
check_state_dict(org_model, sharded_model, name=name)
|
if test_config['pp_size'] > 2 and name == 'transformers_t5_encoder_model':
|
||||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
continue
|
||||||
|
|
||||||
|
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||||
|
|
||||||
|
clear_layout_converter()
|
||||||
|
Randomizer.reset_index()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
@ -68,7 +118,7 @@ def check_t5(rank, world_size, port):
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
@clear_cache_before_run()
|
@clear_cache_before_run()
|
||||||
def test_t5():
|
def test_t5():
|
||||||
spawn(check_t5, 2)
|
spawn(check_t5, 4)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -1,101 +0,0 @@
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import colossalai
|
|
||||||
from colossalai.cluster import ProcessGroupMesh
|
|
||||||
from colossalai.logging import disable_existing_loggers
|
|
||||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
|
||||||
from colossalai.shardformer.policies.t5 import T5BasePolicy
|
|
||||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
|
||||||
from tests.kit.model_zoo import model_zoo
|
|
||||||
from tests.test_shardformer.test_model._utils import build_pipeline_model
|
|
||||||
|
|
||||||
|
|
||||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
|
||||||
# TODO: add tests for forward/backward later
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@parameterize('enable_tensor_parallelism', [False])
|
|
||||||
@parameterize('enable_fused_normalization', [False])
|
|
||||||
@parameterize('use_lazy_init', [False])
|
|
||||||
#TODO: merge this into test_shard_t5.py
|
|
||||||
def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
|
||||||
DP_DIM, PP_DIM = 0, 1
|
|
||||||
DP_SIZE, PP_SIZE = 2, 2
|
|
||||||
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
|
||||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
|
||||||
|
|
||||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
|
|
||||||
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
|
|
||||||
|
|
||||||
inputs = data_gen_fn()
|
|
||||||
inputs = {k: v.cuda() for k, v in inputs.items()}
|
|
||||||
input_ids = inputs['input_ids']
|
|
||||||
|
|
||||||
_, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
|
|
||||||
enable_tensor_parallelism, use_lazy_init)
|
|
||||||
|
|
||||||
batch_size, seq_len = input_ids.shape
|
|
||||||
hidden_size = sharded_model.config.d_model
|
|
||||||
num_heads = sharded_model.config.num_heads
|
|
||||||
hidden_state_shape = (batch_size, seq_len, hidden_size)
|
|
||||||
position_bias_shape = (batch_size, num_heads, seq_len, seq_len)
|
|
||||||
|
|
||||||
num_encoder_layers = len(sharded_model.encoder.block)
|
|
||||||
decoder = sharded_model.__dict__.get('decoder', None)
|
|
||||||
num_decoder_layers = len(decoder.block) if decoder else 0
|
|
||||||
|
|
||||||
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(num_encoder_layers, num_decoder_layers, PP_SIZE)
|
|
||||||
stage = stage_manager.stage
|
|
||||||
at_first_stage = (stage == 0) or (stage == decoder_starting_stage)
|
|
||||||
at_last_stage = (stage == decoder_starting_stage - 1) or (stage == stage_manager.num_stages - 1)
|
|
||||||
in_decoder = stage >= decoder_starting_stage
|
|
||||||
|
|
||||||
if not at_first_stage:
|
|
||||||
# change inputs if not the first stage
|
|
||||||
hidden_states = torch.zeros(*hidden_state_shape).cuda()
|
|
||||||
position_bias = torch.zeros(*position_bias_shape).cuda()
|
|
||||||
encoder_decoder_position_bias = torch.zeros(*position_bias_shape).cuda()
|
|
||||||
inputs['input_ids'] = None
|
|
||||||
inputs['hidden_states'] = hidden_states
|
|
||||||
inputs['position_bias'] = position_bias
|
|
||||||
inputs['encoder_decoder_position_bias'] = encoder_decoder_position_bias
|
|
||||||
if in_decoder:
|
|
||||||
encoder_output_states = torch.zeros(*hidden_state_shape).cuda()
|
|
||||||
inputs['encoder_outputs'] = (encoder_output_states,)
|
|
||||||
|
|
||||||
sharded_model.train()
|
|
||||||
output = sharded_model(**inputs)
|
|
||||||
if at_last_stage:
|
|
||||||
if name == 'transformers_t5_for_conditional_generation' and in_decoder:
|
|
||||||
assert output.loss is not None
|
|
||||||
else:
|
|
||||||
if name != 'transformers_t5_encoder_model' and not in_decoder:
|
|
||||||
output = output['encoder_outputs']
|
|
||||||
assert output[0].shape == hidden_state_shape
|
|
||||||
else:
|
|
||||||
assert output['hidden_states'].shape == hidden_state_shape
|
|
||||||
# position_bias information should be passed in T5
|
|
||||||
assert output['position_bias'].shape == position_bias_shape
|
|
||||||
if in_decoder:
|
|
||||||
assert output['encoder_decoder_position_bias'].shape == position_bias_shape
|
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
|
|
||||||
def check_t5(rank, world_size, port):
|
|
||||||
disable_existing_loggers()
|
|
||||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
|
||||||
run_t5_test()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
|
||||||
@rerun_if_address_is_in_use()
|
|
||||||
@clear_cache_before_run()
|
|
||||||
def test_t5():
|
|
||||||
spawn(check_t5, 4)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_t5()
|
|
Loading…
Reference in New Issue