mirror of https://github.com/InternLM/InternLM

fix interface for dense pipeline

parent b46d1c17af
commit 462b849942
@@ -335,7 +335,7 @@ class PipelineScheduler(BaseScheduler):
         self._call_hooks("before_backward", output_obj, output_obj_grad)
         with switch_optimizer_grad_sync_skip_mode(engine.optimizer, skip_grad_sync):
-            if moe_loss is None:
+            if moe_loss is None or moe_loss.item() == 0.0:
                 if output_obj_grad is None:
                     engine.backward(output_obj)
                 else:
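The hunk above tightens the dense/MoE branching in the scheduler's backward step: a dense pipeline can hand the scheduler a zero-valued moe_loss tensor rather than None, and the old `is None` test would then wrongly take the MoE branch. The following is a minimal sketch of the resulting control flow under that assumption; `ToyEngine` and `backward_step` are hypothetical stand-ins, not InternLM's real engine or `_backward_step`, and the hooks and grad-sync context manager from the hunk are omitted.

import torch


class ToyEngine:
    """Hypothetical stand-in for the training engine used by the scheduler."""

    def backward(self, output_obj):
        # Backward from a scalar loss (last pipeline stage).
        output_obj.backward()

    def backward_by_grad(self, output_obj, output_obj_grad):
        # Backward from an upstream gradient (earlier pipeline stages).
        torch.autograd.backward(output_obj, grad_tensors=output_obj_grad)


def backward_step(engine, output_obj, output_obj_grad, moe_loss=None):
    # Dense pipeline: moe_loss is either absent or a zero tensor, so the
    # plain dense path is taken and the auxiliary loss is ignored.
    if moe_loss is None or moe_loss.item() == 0.0:
        if output_obj_grad is None:
            engine.backward(output_obj)
        else:
            engine.backward_by_grad(output_obj, output_obj_grad)
    else:
        # MoE pipeline: fold the auxiliary load-balancing loss into the
        # backward pass (shown here in the simplest possible form).
        if output_obj_grad is None:
            engine.backward(output_obj + moe_loss)
        else:
            engine.backward_by_grad([output_obj, moe_loss], [output_obj_grad, None])


# Dense case: a zero moe_loss now routes through the plain dense branch.
x = torch.randn(4, requires_grad=True)
loss = (x ** 2).sum()
backward_step(ToyEngine(), loss, None, moe_loss=torch.tensor(0.0))
print(x.grad is not None)  # True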
@@ -437,7 +437,8 @@ class PackedFlashInternLm1D(nn.Module):
     def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
         # attention_mask: compute attention on the places where the value is 1
-        if hasattr(self, "embedding"):
+        # old condition may fail when using shared embedding
+        if gpc.is_pipeline_first_stage():
             hidden_states = self.embedding(input_ids)
             if self.embed_grad_scale != 1:
                 hidden_states = (
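The second hunk replaces an attribute check with an explicit pipeline-rank check: with shared (tied) embeddings, a stage other than the first can also own an `embedding` module, so `hasattr(self, "embedding")` stops being a reliable first-stage test, while asking the parallel context directly is. The sketch below is a simplified illustration of that idea; `is_pipeline_first_stage` is a hypothetical stand-in for InternLM's `gpc.is_pipeline_first_stage()`, and `ToyStage` is not the real `PackedFlashInternLm1D`.

import torch
import torch.nn as nn


def is_pipeline_first_stage(rank: int) -> bool:
    """Hypothetical stand-in for gpc.is_pipeline_first_stage()."""
    return rank == 0


class ToyStage(nn.Module):
    """Illustrative pipeline stage, not the real PackedFlashInternLm1D."""

    def __init__(self, rank, vocab_size=100, hidden=16, embed_grad_scale=0.1):
        super().__init__()
        self.rank = rank
        self.embed_grad_scale = embed_grad_scale
        # With tied input/output embeddings, a non-first stage can also hold
        # this module, so hasattr(self, "embedding") is not a reliable
        # first-stage test.
        self.embedding = nn.Embedding(vocab_size, hidden)
        self.block = nn.Linear(hidden, hidden)

    def forward(self, hidden_states=None, input_ids=None):
        if is_pipeline_first_stage(self.rank):
            hidden_states = self.embedding(input_ids)
            if self.embed_grad_scale != 1:
                # Scale the gradient flowing into the embedding without
                # changing the forward value (mirrors the hunk above).
                hidden_states = (
                    self.embed_grad_scale * hidden_states
                    + (1 - self.embed_grad_scale) * hidden_states.detach()
                )
        return self.block(hidden_states)


# The first stage embeds token ids; later stages consume hidden states.
ids = torch.randint(0, 100, (2, 8))
h = ToyStage(rank=0)(input_ids=ids)
out = ToyStage(rank=1)(hidden_states=h)
print(out.shape)  # torch.Size([2, 8, 16])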