mirror of https://github.com/hpcaitech/ColossalAI
support qwen model
parent
32e642bf40
commit
4c69e2dc91
|
@ -44,7 +44,7 @@ class Qwen2PipelineForwards:
|
|||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
):
|
||||
)-> Union[Tuple, BaseModelOutputWithPast]:
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
|
@ -82,14 +82,18 @@ class Qwen2PipelineForwards:
|
|||
if output_hidden_states:
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
if use_cache:
|
||||
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
|
||||
use_cache = False
|
||||
|
||||
assert past_key_values is None, "past_key_values is not supported for Qwen2 models at the moment."
|
||||
|
||||
past_key_values_length = 0
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
)
|
||||
|
@ -123,18 +127,11 @@ class Qwen2PipelineForwards:
|
|||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
hidden_states,
|
||||
past_key_values_length,
|
||||
sliding_window=self.config.sliding_window,
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
@ -148,20 +145,14 @@ class Qwen2PipelineForwards:
|
|||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, None)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
None,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
|
@ -315,7 +306,7 @@ class Qwen2PipelineForwards:
|
|||
else:
|
||||
hidden_states = outputs.get("hidden_states")
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
|
||||
@staticmethod
|
||||
def qwen2_for_sequence_classification_forward(
|
||||
self: Qwen2ForSequenceClassification,
|
||||
|
|
|
@ -151,10 +151,10 @@ class Qwen2Policy(Policy):
|
|||
module = self.model.model
|
||||
|
||||
if stage_manager.is_interleave:
|
||||
layers_per_stage = self.distribute_layers(
|
||||
layers_per_stage = stage_manager.distribute_layers(
|
||||
len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
|
||||
)
|
||||
stage_manager.stage_indices = Policy.get_stage_index(
|
||||
stage_manager.stage_indices = stage_manager.get_stage_index(
|
||||
layers_per_stage,
|
||||
stage_manager.stage,
|
||||
num_model_chunks=stage_manager.num_model_chunks,
|
||||
|
@ -165,8 +165,8 @@ class Qwen2Policy(Policy):
|
|||
}
|
||||
|
||||
else:
|
||||
layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||
stage_index = stage_manager.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
method_replacement = {
|
||||
"forward": partial(
|
||||
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
|
||||
|
@ -192,10 +192,10 @@ class Qwen2Policy(Policy):
|
|||
held_layers = []
|
||||
if stage_manager.is_interleave:
|
||||
assert stage_manager.num_model_chunks is not None
|
||||
layers_per_stage = self.distribute_layers(
|
||||
layers_per_stage = stage_manager.distribute_layers(
|
||||
len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
|
||||
)
|
||||
stage_indices = Policy.get_stage_index(
|
||||
stage_indices = stage_manager.get_stage_index(
|
||||
layers_per_stage,
|
||||
stage_manager.stage,
|
||||
num_model_chunks=stage_manager.num_model_chunks,
|
||||
|
@ -209,10 +209,10 @@ class Qwen2Policy(Policy):
|
|||
held_layers.append(module.norm)
|
||||
|
||||
else:
|
||||
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embed_tokens)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.norm)
|
||||
|
|
|
@ -182,12 +182,6 @@ def run_forward_backward_with_hybrid_plugin(
|
|||
data_iter = iter([data])
|
||||
sharded_output = booster.execute_pipeline(
|
||||
data_iter, sharded_model, _criterion, sharded_optimizer, return_loss=True, return_outputs=True
|
||||
data_iter,
|
||||
sharded_model,
|
||||
_criterion,
|
||||
sharded_optimizer,
|
||||
return_loss=True,
|
||||
return_outputs=True,
|
||||
)
|
||||
sharded_loss = sharded_output["loss"]
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue