@@ -145,6 +145,12 @@ class GeminiDDP(ModelWrapper):
        self.extra_dp_group = extra_dp_group
        self.master_weights = master_weights
        self.enable_async_reduce = enable_async_reduce

        if enable_async_reduce:
            self.async_reduce_stream = torch.cuda.Stream()
        else:
            self.async_reduce_stream = None

        self._logger = get_dist_logger()
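
A minimal side-stream sketch of what this constructor sets up: work launched on self.async_reduce_stream can overlap with the backward pass, which keeps producing gradients on the default stream, as long as the two streams are ordered explicitly. The tensor and the div_ call below are illustrative stand-ins, not ColossalAI APIs, and the snippet assumes a CUDA device is available.

import torch

side_stream = torch.cuda.Stream()                      # plays the role of self.async_reduce_stream
grad = torch.randn(1024, 1024, device="cuda")

side_stream.wait_stream(torch.cuda.current_stream())   # do not start before `grad` has been produced
with torch.cuda.stream(side_stream):
    grad.div_(4)                                       # stand-in for the reduce/rescale work

torch.cuda.current_stream().wait_stream(side_stream)   # consumers must wait before reading `grad`
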
@@ -174,6 +180,7 @@ class GeminiDDP(ModelWrapper):
        super().__init__(module)
        self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
        self._cast_buffers()

        # register grad hook
        for p in module.parameters():
            if is_ddp_ignored(p):
@@ -189,7 +196,7 @@ class GeminiDDP(ModelWrapper):
                        master_weights=self.master_weights,
                        enable_gradient_accumulation=self.enable_gradient_accumulation,
                        p=p,
                        async_reduce=enable_async_reduce,
                        async_reduce_stream=self.async_reduce_stream,
                    )
                )
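
A sketch of the hook-registration pattern used above: the static grad handler is bound to its per-parameter arguments with functools.partial and installed via register_hook. The handler body and its arguments here are simplified stand-ins for the real chunk bookkeeping.

from functools import partial

import torch
import torch.nn as nn


def grad_handle(grad: torch.Tensor, scale: float, p: nn.Parameter) -> torch.Tensor:
    # stand-in for the real handler (chunk bookkeeping, reduction, ...): just rescale
    return grad * scale


model = nn.Linear(4, 4)
for p in model.parameters():
    if p.requires_grad:
        p.register_hook(partial(grad_handle, scale=0.5, p=p))

model(torch.randn(2, 4)).sum().backward()  # hooks fire as each parameter's grad is computed
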
@@ -337,10 +344,8 @@ class GeminiDDP(ModelWrapper):
                setattr(param, "_gemini_reduced", False)

    def _post_backward(self):
        for param in self.param2name:
            if hasattr(param, "_release_grad_chunk_cb"):
                param._release_grad_chunk_cb()
                delattr(param, "_release_grad_chunk_cb")
        if self.enable_async_reduce:
            self.async_reduce_stream.synchronize()

        if self.chunk_manager.accessed_mem != 0:
            error_params = ["Reduction failed at followed parameters:"]
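
A rough sketch of the deferred-callback idea that _post_backward drains here: the backward hook only records a cleanup closure on the parameter, and the post-backward step executes and removes it. Apart from the attribute name, everything below is an illustrative stand-in.

import torch
import torch.nn as nn

model = nn.Linear(4, 4)


def make_hook(p: nn.Parameter):
    def hook(grad: torch.Tensor) -> torch.Tensor:
        # record a cleanup closure instead of doing the heavy work inside the hook
        p._release_grad_chunk_cb = lambda: None  # stand-in for the real release logic
        return grad

    return hook


for p in model.parameters():
    p.register_hook(make_hook(p))

model(torch.randn(2, 4)).sum().backward()

for p in model.parameters():  # the drain step, mirroring _post_backward above
    if hasattr(p, "_release_grad_chunk_cb"):
        p._release_grad_chunk_cb()
        delattr(p, "_release_grad_chunk_cb")
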
@@ -379,7 +384,7 @@ class GeminiDDP(ModelWrapper):
        master_weights: bool,
        enable_gradient_accumulation: bool,
        p: nn.Parameter,
        async_reduce: bool,
        async_reduce_stream: Optional[torch.cuda.Stream] = None,
    ):
        setattr(p, "_gemini_reduced", True)
        empty_grad = torch.empty_like(grad)
@@ -415,56 +420,31 @@ class GeminiDDP(ModelWrapper):
            grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=chunk_manager.reuse_fp16_chunk)
        else:
            grad_chunk.add_tensor_to_chunk_slice(p, grad)
        reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce)
        if reduced:  # if not async, we can release immediately; otherwise release when the async work finishes
            if async_reduce:
                # dirty fix by installing callback
                assert not hasattr(p, "_release_grad_chunk_cb")

                def _release_grad_chunk_cb():
                    grad_chunk.wait_async_reduce()
                    GeminiDDP.release_grad_chunk_handle(
                        chunk_manager,
                        grads_device,
                        master_weights,
                        enable_gradient_accumulation,
                        p,
                        chunk,
                        grad_chunk,
                    )

                p._release_grad_chunk_cb = _release_grad_chunk_cb
            else:
                GeminiDDP.release_grad_chunk_handle(
                    chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
                )
        return empty_grad

    @staticmethod
    def release_grad_chunk_handle(
        chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
    ):
        if not chunk_manager.reuse_fp16_chunk:
            if chunk.keep_gathered:
                chunk_manager.fake_release_chunk(chunk)
            else:
                chunk_manager.release_chunk(chunk)
        if grad_chunk.is_gathered:
            grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
            if chunk.extra_dp_group is not None:
                grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
        else:
            grad_chunk.cuda_shard.div_(chunk.pg_size)
            if chunk.extra_dp_group is not None:
                grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
        # check overflow elements
        chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
        # record l2 norm for gradient clipping. flag is bound to fp16 chunk
        if chunk.l2_norm_flag:
            grad_chunk.set_l2_norm()
        chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
        if not (master_weights) or (enable_gradient_accumulation):
            chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
        with torch.cuda.stream(async_reduce_stream):
            chunk_manager.reduce_chunk(grad_chunk)

            if not chunk_manager.reuse_fp16_chunk:
                if chunk.keep_gathered:
                    chunk_manager.fake_release_chunk(chunk)
                else:
                    chunk_manager.release_chunk(chunk)
            if grad_chunk.is_gathered:
                grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
                if chunk.extra_dp_group is not None:
                    grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
            else:
                grad_chunk.cuda_shard.div_(chunk.pg_size)
                if chunk.extra_dp_group is not None:
                    grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
            # check overflow elements
            chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
            # record l2 norm for gradient clipping. flag is bound to fp16 chunk
            if chunk.l2_norm_flag:
                grad_chunk.set_l2_norm()
            chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
            if not (master_weights) or (enable_gradient_accumulation):
                chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)

    def zero_grad(self, set_to_none: bool = False) -> None:
        self.module.zero_grad(set_to_none=True)
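
One detail worth calling out from the release logic above: the collective sums gradients across ranks, so the reduced chunk is divided by chunk.pg_size (and again by chunk.extra_dp_size when an extra data-parallel group exists) to recover the mean. A toy sketch with made-up group sizes:

# hypothetical group sizes; the real values come from the chunk's process groups
pg_size, extra_dp_size = 4, 2

# a gradient of 1.0 on every one of the pg_size * extra_dp_size ranks sums to 8.0
summed_grad = 1.0 * pg_size * extra_dp_size

mean_grad = summed_grad / pg_size / extra_dp_size
assert mean_grad == 1.0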