@@ -291,13 +291,17 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
             GPT2LMHeadModel: ModulePolicyDescription(
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
-                        suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": not self.shard_config.parallel_output}
+                        suffix="lm_head",
+                        target_module=col_nn.Linear1D_Col,
+                        kwargs={"gather_output": not self.shard_config.parallel_output},
                     )
                 ],
             )
         }
         if self.shard_config.parallel_output:
-            addon_module[GPT2LMHeadModel].method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
+            addon_module[GPT2LMHeadModel].method_replacement = {
+                "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
+            }
         module_policy.update(addon_module)

         if self.pipeline_stage_manager is not None:
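A minimal, self-contained sketch of the control flow this hunk encodes. The names ShardConfig and plan_lm_head below are hypothetical, not ColossalAI's actual API; only Linear1D_Col, gather_output, and get_lm_forward_with_dist_cross_entropy come from the diff itself. When parallel_output is enabled, the column-parallel lm_head keeps its logits sharded along the vocab dimension (gather_output=False), so the forward must be swapped for one that computes a distributed cross-entropy; otherwise the logits are gathered and the stock forward works unchanged.

from dataclasses import dataclass

@dataclass
class ShardConfig:
    # Mirrors the single flag the diff branches on; hypothetical type.
    parallel_output: bool = True

def plan_lm_head(shard_config: ShardConfig) -> dict:
    """Hypothetical helper: describe how the lm_head would be sharded."""
    plan = {
        "lm_head": {
            "target_module": "Linear1D_Col",
            # Negated on purpose: sharded logits <-> no all-gather.
            "gather_output": not shard_config.parallel_output,
        }
    }
    if shard_config.parallel_output:
        # Vocab-sharded logits need a loss that reduces across
        # tensor-parallel ranks, hence the forward replacement.
        plan["forward"] = "get_lm_forward_with_dist_cross_entropy"
    return plan

print(plan_lm_head(ShardConfig(parallel_output=True)))
print(plan_lm_head(ShardConfig(parallel_output=False)))

The negation is the point of the hunk: keeping gather_output and parallel_output as logical opposites ties the two branches together, since un-gathered (vocab-sharded) logits are only valid when paired with the distributed cross-entropy forward.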