[Tensor] fix optimizer for CPU parallel (#1069)

pull/1072/head
Ziyue Jiang 2022-06-06 17:36:11 +08:00 committed by GitHub
parent 49832b2344
commit 4fc748f69b
2 changed files with 16 additions and 12 deletions


@@ -42,14 +42,15 @@ class ColoDDP(torch.nn.Module):
         loss.backward()
         torch.cuda.current_stream().wait_stream(self.comm_stream)
         for p in self.module.parameters():
-            p.grad = p._saved_grad
+            if p.grad.device.type != "cpu":
+                p.grad = p._saved_grad
 
     def grad_handle(self, p, grad):
-        empty_grad = torch.empty_like(grad)
-        free_storage(empty_grad)
-        if self.dp_world_size > 1:
-            grad = grad / self.dp_world_size
-            if grad.device.type != "cpu":
+        if grad.device.type != "cpu":
+            empty_grad = torch.empty_like(grad)
+            free_storage(empty_grad)
+            if self.dp_world_size > 1:
+                grad = grad / self.dp_world_size
                 self.comm_stream.wait_stream(torch.cuda.current_stream())
                 with torch.cuda.stream(self.comm_stream):
                     group = gpc.get_group(ParallelMode.DATA)
@@ -57,12 +58,13 @@ class ColoDDP(torch.nn.Module):
                     ColoDDP._save_grad(p, grad)
                 grad.record_stream(self.comm_stream)
             else:
-                group = gpc.get_cpu_group(ParallelMode.DATA)
-                dist.all_reduce(grad, group=group)
                 ColoDDP._save_grad(p, grad)
             return empty_grad
         else:
-            ColoDDP._save_grad(p, grad)
-            return empty_grad
+            group = gpc.get_cpu_group(ParallelMode.DATA)
+            dist.all_reduce(grad, group=group)
+            return grad
 
     @staticmethod
     def _save_grad(p, grad):
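With this change, grad_handle first branches on the gradient's device: CUDA gradients keep the stream-based all-reduce over gpc.get_group(ParallelMode.DATA), while CPU gradients are all-reduced over the CPU group from gpc.get_cpu_group(ParallelMode.DATA) and returned directly, and backward() now restores p._saved_grad only for non-CPU parameters. The standalone sketch below illustrates just the CPU path using plain torch.distributed with a gloo backend; cpu_grad_handle and the single-process group setup are illustrative assumptions, not ColossalAI API.

# Minimal sketch of the new CPU branch of grad_handle, assuming plain
# torch.distributed with a gloo process group in place of
# gpc.get_cpu_group(ParallelMode.DATA); cpu_grad_handle is a hypothetical name.
import os
import torch
import torch.distributed as dist

def cpu_grad_handle(grad: torch.Tensor, group=None) -> torch.Tensor:
    # gloo all-reduce operates on CPU tensors; the reduced gradient is
    # returned directly instead of being stashed via ColoDDP._save_grad.
    dist.all_reduce(grad, group=group)
    return grad

if __name__ == "__main__":
    # Single-process gloo group, only so the sketch runs on its own.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    grad = torch.ones(4)              # stand-in for a CPU-resident gradient
    print(cpu_grad_handle(grad))      # tensor([1., 1., 1., 1.])
    dist.destroy_process_group()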


@@ -9,6 +9,7 @@ from colossalai.context import ParallelMode
 from colossalai.nn.parallel.layers import init_colo_module
 from colossalai.nn.parallel.data_parallel import ColoDDP
+from colossalai.nn.optimizer import ColoOptimizer
 import colossalai
 import torch
@@ -56,10 +57,11 @@ def run_hybrid_device(use_ddp):
     print(f'embedding weight size: {real_model.embed.weight.size()} | new device: {real_model.embed.weight.device}')
     #print(f'linear weight size: {real_model.proj.weight.size()} | new device: {real_model.proj.weight.device}')
+    optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
     data = torch.randint(low=0, high=20, size=(16,), device=get_current_device())
     out = model(data)
     out.sum().backward()
+    optimizer.step()
 
 def run_dist(rank, world_size, port, use_ddp):
     if use_ddp and world_size == 1:
@@ -81,4 +83,4 @@ def _test_hybrid_device(world_size, use_ddp):
 
 if __name__ == '__main__':
-    _test_hybrid_device(1, False)
+    _test_hybrid_device(4, True)