@@ -42,14 +42,15 @@ class ColoDDP(torch.nn.Module):
         loss.backward()
         torch.cuda.current_stream().wait_stream(self.comm_stream)
         for p in self.module.parameters():
-            p.grad = p._saved_grad
+            if p.grad.device.type != "cpu":
+                p.grad = p._saved_grad

     def grad_handle(self, p, grad):
-        empty_grad = torch.empty_like(grad)
-        free_storage(empty_grad)
-        if self.dp_world_size > 1:
-            grad = grad / self.dp_world_size
-            if grad.device.type != "cpu":
+        if grad.device.type != "cpu":
+            empty_grad = torch.empty_like(grad)
+            free_storage(empty_grad)
+            if self.dp_world_size > 1:
+                grad = grad / self.dp_world_size
                 self.comm_stream.wait_stream(torch.cuda.current_stream())
                 with torch.cuda.stream(self.comm_stream):
                     group = gpc.get_group(ParallelMode.DATA)
@@ -57,12 +58,13 @@ class ColoDDP(torch.nn.Module):
                     ColoDDP._save_grad(p, grad)
                 grad.record_stream(self.comm_stream)
             else:
-                group = gpc.get_cpu_group(ParallelMode.DATA)
-                dist.all_reduce(grad, group=group)
                 ColoDDP._save_grad(p, grad)
             return empty_grad
         else:
-            ColoDDP._save_grad(p, grad)
-            return empty_grad
+            group = gpc.get_cpu_group(ParallelMode.DATA)
+            dist.all_reduce(grad, group=group)
+            return grad

     @staticmethod
     def _save_grad(p, grad):
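The change above hoists the device check to the top of `grad_handle`: GPU gradients keep the asynchronous all-reduce on `self.comm_stream`, while CPU gradients (e.g. parameters offloaded to host memory) are now reduced synchronously on a CPU process group and returned as-is. Below is a minimal, self-contained sketch of that pattern, using plain `torch.distributed` in place of ColossalAI's `gpc`/`ParallelMode` groups; `SimpleDDP`, the `cpu_group` argument, and the body of `free_storage` are illustrative assumptions, not ColossalAI's actual API.

```python
import torch
import torch.distributed as dist


def free_storage(t: torch.Tensor) -> None:
    # Assumed behaviour of ColossalAI's free_storage helper: drop the
    # backing storage so the placeholder grad costs no memory.
    t.untyped_storage().resize_(0)


class SimpleDDP(torch.nn.Module):
    """Sketch of ColoDDP's hook-based gradient reduction (not the real class)."""

    def __init__(self, module: torch.nn.Module, cpu_group=None):
        super().__init__()
        self.module = module
        self.dp_world_size = dist.get_world_size()
        self.comm_stream = torch.cuda.Stream()
        self.cpu_group = cpu_group  # assumed Gloo group for offloaded grads
        for p in module.parameters():
            if p.requires_grad:
                p.register_hook(lambda grad, p=p: self.grad_handle(p, grad))

    def grad_handle(self, p, grad):
        if grad.device.type != "cpu":
            # Hand autograd a storage-freed placeholder; the real gradient
            # is stashed on the parameter once the all-reduce finishes.
            empty_grad = torch.empty_like(grad)
            free_storage(empty_grad)
            if self.dp_world_size > 1:
                grad = grad / self.dp_world_size
                # Reduce on a side stream so communication overlaps with
                # the remaining backward computation.
                self.comm_stream.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(self.comm_stream):
                    dist.all_reduce(grad)
                    p._saved_grad = grad
                # Tell the caching allocator the comm stream still uses grad.
                grad.record_stream(self.comm_stream)
            else:
                p._saved_grad = grad
            return empty_grad
        else:
            # CPU gradients are reduced synchronously on the CPU group.
            dist.all_reduce(grad, group=self.cpu_group)
            return grad

    def backward(self, loss: torch.Tensor):
        loss.backward()
        # Join the communication stream before exposing reduced grads.
        torch.cuda.current_stream().wait_stream(self.comm_stream)
        for p in self.module.parameters():
            if p.grad is not None and p.grad.device.type != "cpu":
                p.grad = p._saved_grad
```

Returning the freed `empty_grad` is what keeps the hook cheap: autograd still receives a tensor of the right shape and dtype, but the only live copy of the gradient is the one being reduced on the communication stream, which `backward()` re-attaches to `p.grad` after the two streams synchronize.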