@@ -387,6 +387,7 @@ class GeminiDDP(ModelWrapper):
         p: nn.Parameter,
         async_reduce_stream: Optional[torch.cuda.Stream] = None,
     ):
+        async_reduce_scatter = async_reduce_stream is not None
         setattr(p, "_gemini_reduced", True)
         empty_grad = torch.empty_like(grad)
         free_storage(empty_grad)
@@ -426,7 +427,7 @@ class GeminiDDP(ModelWrapper):
                 async_reduce_stream.wait_stream(torch.cuda.current_stream())
 
             with torch.cuda.stream(async_reduce_stream):
-                reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=(async_reduce_stream is not None))
+                reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce_scatter)
                 if reduced:
                     grad_chunk.wait_async_reduce()
                     if not chunk_manager.reuse_fp16_chunk:
@@ -447,9 +448,13 @@ class GeminiDDP(ModelWrapper):
                     # record l2 norm for gradient clipping. flag is bound to fp16 chunk
                     if chunk.l2_norm_flag:
                         grad_chunk.set_l2_norm()
-                    chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
+                    chunk_manager.move_chunk(
+                        grad_chunk, grads_device[p], force_copy=True, async_move=async_reduce_scatter
+                    )
                     if not (master_weights) or (enable_gradient_accumulation):
-                        chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
+                        chunk_manager.move_chunk(
+                            chunk, grads_device[p], force_copy=True, async_move=async_reduce_scatter
+                        )
         return empty_grad
 
     def zero_grad(self, set_to_none: bool = False) -> None:
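
Note (illustration, not part of the diff above): the change relies on the standard CUDA-stream overlap pattern — the side stream first waits on the default stream, the collective is then launched under torch.cuda.stream(...) with async_op=True, and the result is only consumed after waiting on the returned handle (in GeminiDDP that wait is grad_chunk.wait_async_reduce()). Below is a minimal sketch of that pattern only; overlap_reduce and reduce_stream are illustrative names, and it assumes torch.distributed is already initialized with an NCCL process group.

    import torch
    import torch.distributed as dist

    def overlap_reduce(grad: torch.Tensor, reduce_stream: torch.cuda.Stream):
        # Order the side stream after the default stream so the kernels that
        # produced `grad` have finished before the collective reads it.
        reduce_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(reduce_stream):
            # Tell the caching allocator that `grad` is in use on another
            # stream, so its memory is not recycled while the collective runs.
            grad.record_stream(reduce_stream)
            work = dist.all_reduce(grad, async_op=True)
        # Caller must call work.wait() (or synchronize reduce_stream) before
        # reading `grad` on the default stream again.
        return work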