update qwen model

pull/5844/head
wangbluo 2024-06-20 09:04:57 +00:00
parent dba59354d7
commit d4ff644ef3
2 changed files with 3 additions and 3 deletions

View File

@@ -51,7 +51,7 @@ class Qwen2PipelineForwards:
 output_hidden_states = (
 output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 )
-use_cache = use_cache if use_cache is not None else self.config.use_cache
+use_cache = False#use_cache if use_cache is not None else self.config.use_cache
 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -592,4 +592,4 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
 attentions=outputs.attentions,
 )
 return forward

View File

@@ -333,4 +333,4 @@ class Qwen2ForSequenceClassificationPolicy(Qwen2Policy):
 def get_shared_params(self) -> List[Dict[int, Tensor]]:
 """No shared params in Qwen2 for sequence classification model"""
 return []