mirror of https://github.com/hpcaitech/ColossalAI

support qwen model

parent 32e642bf40
commit 4c69e2dc91
@@ -44,7 +44,7 @@ class Qwen2PipelineForwards:
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
-    ):
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
         logger = logging.get_logger(__name__)

         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -82,14 +82,18 @@ class Qwen2PipelineForwards:
         if output_hidden_states:
             logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
             output_hidden_states = False
-        if use_cache:
-            logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
-            use_cache = False

         assert past_key_values is None, "past_key_values is not supported for Qwen2 models at the moment."

         past_key_values_length = 0
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
         if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
             position_ids = torch.arange(
                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
             )
@@ -123,18 +127,11 @@ class Qwen2PipelineForwards:
             attention_mask = _prepare_4d_causal_attention_mask(
                 attention_mask,
                 (batch_size, seq_length),
-                inputs_embeds,
+                hidden_states,
                 past_key_values_length,
                 sliding_window=self.config.sliding_window,
             )

-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
@@ -148,20 +145,14 @@ class Qwen2PipelineForwards:
             past_key_value = past_key_values[idx] if past_key_values is not None else None

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, None)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     position_ids,
-                    None,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
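The hunk above replaces the hand-rolled create_custom_forward closure plus torch.utils.checkpoint.checkpoint with self._gradient_checkpointing_func, which recent transformers releases set up in gradient_checkpointing_enable() (roughly a functools.partial over torch.utils.checkpoint.checkpoint). A minimal, self-contained sketch with a toy layer, showing why the closure is no longer needed, since extra arguments such as output_attentions can now be passed straight through:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class ToyLayer(nn.Module):
    """Stand-in for a decoder layer: positional tensor inputs plus optional flags."""
    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        out = self.proj(hidden_states)
        if attention_mask is not None:
            out = out * attention_mask
        return out

layer = ToyLayer()
x = torch.randn(2, 4, 16, requires_grad=True)
mask = torch.ones(2, 4, 1)

# Old pattern: wrap the module so extra flags can be injected inside the closure.
def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs, output_attentions=False)
    return custom_forward

out_old = checkpoint(create_custom_forward(layer), x, mask, use_reentrant=True)

# New pattern: checkpoint the callable directly and pass every argument through,
# mirroring self._gradient_checkpointing_func(decoder_layer.__call__, ...) in the diff.
out_new = checkpoint(layer.__call__, x, mask, False, use_reentrant=False)
out_new.sum().backward()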
@@ -315,7 +306,7 @@ class Qwen2PipelineForwards:
         else:
             hidden_states = outputs.get("hidden_states")
             return {"hidden_states": hidden_states}

     @staticmethod
     def qwen2_for_sequence_classification_forward(
         self: Qwen2ForSequenceClassification,
@@ -151,10 +151,10 @@ class Qwen2Policy(Policy):
             module = self.model.model

             if stage_manager.is_interleave:
-                layers_per_stage = self.distribute_layers(
+                layers_per_stage = stage_manager.distribute_layers(
                     len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
                 )
-                stage_manager.stage_indices = Policy.get_stage_index(
+                stage_manager.stage_indices = stage_manager.get_stage_index(
                     layers_per_stage,
                     stage_manager.stage,
                     num_model_chunks=stage_manager.num_model_chunks,
@@ -165,8 +165,8 @@ class Qwen2Policy(Policy):
                 }

             else:
-                layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
-                stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+                layers_per_stage = stage_manager.distribute_layers(len(module.layers), stage_manager.num_stages)
+                stage_index = stage_manager.get_stage_index(layers_per_stage, stage_manager.stage)
                 method_replacement = {
                     "forward": partial(
                         new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
@@ -192,10 +192,10 @@ class Qwen2Policy(Policy):
         held_layers = []
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
-            layers_per_stage = self.distribute_layers(
+            layers_per_stage = stage_manager.distribute_layers(
                 len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
             )
-            stage_indices = Policy.get_stage_index(
+            stage_indices = stage_manager.get_stage_index(
                 layers_per_stage,
                 stage_manager.stage,
                 num_model_chunks=stage_manager.num_model_chunks,
@@ -209,10 +209,10 @@ class Qwen2Policy(Policy):
                 held_layers.append(module.norm)

         else:
-            layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers), stage_manager.num_stages)
             if stage_manager.is_first_stage():
                 held_layers.append(module.embed_tokens)
-            start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+            start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage, stage_manager.stage)
             held_layers.extend(module.layers[start_idx:end_idx])
             if stage_manager.is_last_stage():
                 held_layers.append(module.norm)
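The Qwen2Policy hunks above route layer partitioning through stage_manager.distribute_layers and stage_manager.get_stage_index instead of the Policy static helpers. A minimal, self-contained sketch of what these two helpers conceptually compute (not the ColossalAI implementation, which may balance layers differently):

from typing import List, Tuple

def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    # Split the decoder layers as evenly as possible across pipeline stages;
    # stages that receive an extra layer come first in this toy version.
    base, remainder = divmod(num_layers, num_stages)
    return [base + 1 if i < remainder else base for i in range(num_stages)]

def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    # Map a stage id to the [start, end) slice of layers it holds.
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]

layers_per_stage = distribute_layers(30, 4)                 # [8, 8, 7, 7]
start_idx, end_idx = get_stage_index(layers_per_stage, 2)   # (16, 23)
# held_layers for stage 2 would then be module.layers[start_idx:end_idx],
# mirroring the held_layers logic in the diff.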
@@ -182,12 +182,6 @@ def run_forward_backward_with_hybrid_plugin(
         data_iter = iter([data])
         sharded_output = booster.execute_pipeline(
-            data_iter,
-            sharded_model,
-            _criterion,
-            sharded_optimizer,
-            return_loss=True,
-            return_outputs=True,
+            data_iter, sharded_model, _criterion, sharded_optimizer, return_loss=True, return_outputs=True
         )
         sharded_loss = sharded_output["loss"]
-
     else: