Browse Source

update qwen model

pull/5844/head
wangbluo 5 months ago
parent
commit
d4ff644ef3
  1. 4
      colossalai/shardformer/modeling/qwen2.py
  2. 2
      colossalai/shardformer/policies/qwen2.py

4
colossalai/shardformer/modeling/qwen2.py

@@ -51,7 +51,7 @@ class Qwen2PipelineForwards:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
-use_cache = use_cache if use_cache is not None else self.config.use_cache
+use_cache = False#use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -592,4 +592,4 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
attentions=outputs.attentions,
)
-return forward
+return forward

2
colossalai/shardformer/policies/qwen2.py

@@ -333,4 +333,4 @@ class Qwen2ForSequenceClassificationPolicy(Qwen2Policy):
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in Qwen2 for sequence classification model"""
-return []
+return []
Loading…
Cancel
Save