support qwen model

pull/5842/head
Wang Binluo 2024-04-09 11:50:35 +08:00 committed by アマデウス
parent 32e642bf40
commit 4c69e2dc91
3 changed files with 23 additions and 38 deletions

View File

@ -44,7 +44,7 @@ class Qwen2PipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
)-> Union[Tuple, BaseModelOutputWithPast]:
logger = logging.get_logger(__name__)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@ -82,14 +82,18 @@ class Qwen2PipelineForwards:
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
if use_cache:
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False
assert past_key_values is None, "past_key_values is not supported for Qwen2 models at the moment."
past_key_values_length = 0
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
@ -123,18 +127,11 @@ class Qwen2PipelineForwards:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
hidden_states,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
@ -148,20 +145,14 @@ class Qwen2PipelineForwards:
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
None,
past_key_values,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
@ -315,7 +306,7 @@ class Qwen2PipelineForwards:
else:
hidden_states = outputs.get("hidden_states")
return {"hidden_states": hidden_states}
@staticmethod
def qwen2_for_sequence_classification_forward(
self: Qwen2ForSequenceClassification,

View File

@ -151,10 +151,10 @@ class Qwen2Policy(Policy):
module = self.model.model
if stage_manager.is_interleave:
layers_per_stage = self.distribute_layers(
layers_per_stage = stage_manager.distribute_layers(
len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
)
stage_manager.stage_indices = Policy.get_stage_index(
stage_manager.stage_indices = stage_manager.get_stage_index(
layers_per_stage,
stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks,
@ -165,8 +165,8 @@ class Qwen2Policy(Policy):
}
else:
layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
layers_per_stage = stage_manager.distribute_layers(len(module.layers), stage_manager.num_stages)
stage_index = stage_manager.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {
"forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
@ -192,10 +192,10 @@ class Qwen2Policy(Policy):
held_layers = []
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
layers_per_stage = self.distribute_layers(
layers_per_stage = stage_manager.distribute_layers(
len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
)
stage_indices = Policy.get_stage_index(
stage_indices = stage_manager.get_stage_index(
layers_per_stage,
stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks,
@ -209,10 +209,10 @@ class Qwen2Policy(Policy):
held_layers.append(module.norm)
else:
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
layers_per_stage = stage_manager.distribute_layers(len(module.layers), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)

View File

@ -182,12 +182,6 @@ def run_forward_backward_with_hybrid_plugin(
data_iter = iter([data])
sharded_output = booster.execute_pipeline(
data_iter, sharded_model, _criterion, sharded_optimizer, return_loss=True, return_outputs=True
data_iter,
sharded_model,
_criterion,
sharded_optimizer,
return_loss=True,
return_outputs=True,
)
sharded_loss = sharded_output["loss"]
else: