mirror of https://github.com/hpcaitech/ColossalAI
[Tensor] fix optimizer for CPU parallel (#1069)
parent
49832b2344
commit
4fc748f69b
|
@ -42,14 +42,15 @@ class ColoDDP(torch.nn.Module):
|
||||||
loss.backward()
|
loss.backward()
|
||||||
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
||||||
for p in self.module.parameters():
|
for p in self.module.parameters():
|
||||||
p.grad = p._saved_grad
|
if p.grad.device.type != "cpu":
|
||||||
|
p.grad = p._saved_grad
|
||||||
|
|
||||||
def grad_handle(self, p, grad):
|
def grad_handle(self, p, grad):
|
||||||
empty_grad = torch.empty_like(grad)
|
if grad.device.type != "cpu":
|
||||||
free_storage(empty_grad)
|
empty_grad = torch.empty_like(grad)
|
||||||
if self.dp_world_size > 1:
|
free_storage(empty_grad)
|
||||||
grad = grad / self.dp_world_size
|
if self.dp_world_size > 1:
|
||||||
if grad.device.type != "cpu":
|
grad = grad / self.dp_world_size
|
||||||
self.comm_stream.wait_stream(torch.cuda.current_stream())
|
self.comm_stream.wait_stream(torch.cuda.current_stream())
|
||||||
with torch.cuda.stream(self.comm_stream):
|
with torch.cuda.stream(self.comm_stream):
|
||||||
group = gpc.get_group(ParallelMode.DATA)
|
group = gpc.get_group(ParallelMode.DATA)
|
||||||
|
@ -57,12 +58,13 @@ class ColoDDP(torch.nn.Module):
|
||||||
ColoDDP._save_grad(p, grad)
|
ColoDDP._save_grad(p, grad)
|
||||||
grad.record_stream(self.comm_stream)
|
grad.record_stream(self.comm_stream)
|
||||||
else:
|
else:
|
||||||
group = gpc.get_cpu_group(ParallelMode.DATA)
|
|
||||||
dist.all_reduce(grad, group=group)
|
|
||||||
ColoDDP._save_grad(p, grad)
|
ColoDDP._save_grad(p, grad)
|
||||||
|
return empty_grad
|
||||||
|
|
||||||
else:
|
else:
|
||||||
ColoDDP._save_grad(p, grad)
|
group = gpc.get_cpu_group(ParallelMode.DATA)
|
||||||
return empty_grad
|
dist.all_reduce(grad, group=group)
|
||||||
|
return grad
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _save_grad(p, grad):
|
def _save_grad(p, grad):
|
||||||
|
|
|
@ -9,6 +9,7 @@ from colossalai.context import ParallelMode
|
||||||
|
|
||||||
from colossalai.nn.parallel.layers import init_colo_module
|
from colossalai.nn.parallel.layers import init_colo_module
|
||||||
from colossalai.nn.parallel.data_parallel import ColoDDP
|
from colossalai.nn.parallel.data_parallel import ColoDDP
|
||||||
|
from colossalai.nn.optimizer import ColoOptimizer
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
import torch
|
import torch
|
||||||
|
@ -56,10 +57,11 @@ def run_hybrid_device(use_ddp):
|
||||||
print(f'embedding weight size: {real_model.embed.weight.size()} | new device: {real_model.embed.weight.device}')
|
print(f'embedding weight size: {real_model.embed.weight.size()} | new device: {real_model.embed.weight.device}')
|
||||||
#print(f'linear weight size: {real_model.proj.weight.size()} | new device: {real_model.proj.weight.device}')
|
#print(f'linear weight size: {real_model.proj.weight.size()} | new device: {real_model.proj.weight.device}')
|
||||||
|
|
||||||
|
optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
|
||||||
data = torch.randint(low=0, high=20, size=(16,), device=get_current_device())
|
data = torch.randint(low=0, high=20, size=(16,), device=get_current_device())
|
||||||
out = model(data)
|
out = model(data)
|
||||||
out.sum().backward()
|
out.sum().backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
def run_dist(rank, world_size, port, use_ddp):
|
def run_dist(rank, world_size, port, use_ddp):
|
||||||
if use_ddp and world_size == 1:
|
if use_ddp and world_size == 1:
|
||||||
|
@ -81,4 +83,4 @@ def _test_hybrid_device(world_size, use_ddp):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
_test_hybrid_device(1, False)
|
_test_hybrid_device(4, True)
|
||||||
|
|
Loading…
Reference in New Issue