[hotfix] fix opt pipeline (#4293)

* opt forward and test

* pause

* finish opt model pipeline

* finish opt pipeline

* fix opt

* set transformers version

* refactor the test pipeline

* fix bug
Jianghai 2023-07-20 17:21:28 +08:00 committed by Hongxin Liu
parent d8408d185c
commit 0a8f3c851a
1 changed file with 3 additions and 2 deletions


@@ -12,6 +12,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from .._utils import getattr_, setattr_
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
@@ -198,8 +199,8 @@ class OPTForCausalLMPolicy(OPTPolicy):
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
         opt_model = self.model
-        num_stages = self.pipeline_stage_manager.num_stages
-        if self.pipeline_stage_manager and num_stages > 1:
+        if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
+            num_stages = self.pipeline_stage_manager.num_stages
             if id(opt_model.model.decoder.embed_tokens.weight) == id(opt_model.lm_head.weight):
                 return [{0: opt_model.model.decoder.embed_tokens.weight, num_stages - 1: opt_model.lm_head.weight}]
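Why the two-line change matters: the old code read `num_stages` from `self.pipeline_stage_manager` before checking that a stage manager exists, so `get_shared_params()` raised `AttributeError: 'NoneType' object has no attribute 'num_stages'` whenever the policy ran without pipeline parallelism. The fix moves the attribute access behind the short-circuiting `and`. A minimal runnable sketch of the before/after guard ordering (`SimplePolicy` and `StageManager` are illustrative stand-ins, not the real ColossalAI classes):

```python
# Sketch of the guard-ordering fix; SimplePolicy and StageManager are
# hypothetical stand-ins for OPTForCausalLMPolicy and its stage manager.
from typing import Optional


class StageManager:
    def __init__(self, num_stages: int):
        self.num_stages = num_stages


class SimplePolicy:
    def __init__(self, stage_manager: Optional[StageManager]):
        self.pipeline_stage_manager = stage_manager

    def get_shared_params_buggy(self):
        # Pre-fix: dereferences the stage manager before the None check,
        # so this line raises AttributeError when no stage manager is set.
        num_stages = self.pipeline_stage_manager.num_stages
        if self.pipeline_stage_manager and num_stages > 1:
            return [{"first_stage": 0, "last_stage": num_stages - 1}]
        return []

    def get_shared_params_fixed(self):
        # Post-fix: the None check guards the attribute access; num_stages
        # is only read once the manager is known to exist.
        if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
            num_stages = self.pipeline_stage_manager.num_stages
            return [{"first_stage": 0, "last_stage": num_stages - 1}]
        return []


# Without pipeline parallelism the buggy version crashes; the fixed one is a no-op.
policy = SimplePolicy(stage_manager=None)
assert policy.get_shared_params_fixed() == []
try:
    policy.get_shared_params_buggy()
except AttributeError as e:
    print(f"pre-fix failure: {e}")
```

In the real policy, the returned dict maps stage 0 to `embed_tokens.weight` and the last stage to `lm_head.weight`, telling the pipeline runtime that the tied input embedding and output head live on different stages and must be kept synchronized.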