fix interface for dense pipeline

pull/182/head
Wenwen Qu 2023-09-15 12:12:45 +08:00
parent b46d1c17af
commit 462b849942
2 changed files with 3 additions and 2 deletions

View File

@ -335,7 +335,7 @@ class PipelineScheduler(BaseScheduler):
self._call_hooks("before_backward", output_obj, output_obj_grad)
with switch_optimizer_grad_sync_skip_mode(engine.optimizer, skip_grad_sync):
if moe_loss is None:
if moe_loss is None or moe_loss.item() == 0.0:
if output_obj_grad is None:
engine.backward(output_obj)
else:

View File

@ -437,7 +437,8 @@ class PackedFlashInternLm1D(nn.Module):
def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
# attention_mask: compute attention on the places where the value is 1
if hasattr(self, "embedding"):
# old condition may fail when use shared embedding
if gpc.is_pipeline_first_stage():
hidden_states = self.embedding(input_ids)
if self.embed_grad_scale != 1:
hidden_states = (