diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index 668964592..e5b687248 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -42,14 +42,15 @@ class ColoDDP(torch.nn.Module): loss.backward() torch.cuda.current_stream().wait_stream(self.comm_stream) for p in self.module.parameters(): - p.grad = p._saved_grad + if p.grad.device.type != "cpu": + p.grad = p._saved_grad def grad_handle(self, p, grad): - empty_grad = torch.empty_like(grad) - free_storage(empty_grad) - if self.dp_world_size > 1: - grad = grad / self.dp_world_size - if grad.device.type != "cpu": + if grad.device.type != "cpu": + empty_grad = torch.empty_like(grad) + free_storage(empty_grad) + if self.dp_world_size > 1: + grad = grad / self.dp_world_size self.comm_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(self.comm_stream): group = gpc.get_group(ParallelMode.DATA) @@ -57,12 +58,13 @@ class ColoDDP(torch.nn.Module): ColoDDP._save_grad(p, grad) grad.record_stream(self.comm_stream) else: - group = gpc.get_cpu_group(ParallelMode.DATA) - dist.all_reduce(grad, group=group) ColoDDP._save_grad(p, grad) + return empty_grad + else: - ColoDDP._save_grad(p, grad) - return empty_grad + group = gpc.get_cpu_group(ParallelMode.DATA) + dist.all_reduce(grad, group=group) + return grad @staticmethod def _save_grad(p, grad): diff --git a/tests/test_tensor/test_hybrid_device.py b/tests/test_tensor/test_hybrid_device.py index a6d5cf14f..cb63b2152 100644 --- a/tests/test_tensor/test_hybrid_device.py +++ b/tests/test_tensor/test_hybrid_device.py @@ -9,6 +9,7 @@ from colossalai.context import ParallelMode from colossalai.nn.parallel.layers import init_colo_module from colossalai.nn.parallel.data_parallel import ColoDDP +from colossalai.nn.optimizer import ColoOptimizer import colossalai import torch @@ -56,10 +57,11 @@ def run_hybrid_device(use_ddp): print(f'embedding weight size: {real_model.embed.weight.size()} | new device: {real_model.embed.weight.device}') #print(f'linear weight size: {real_model.proj.weight.size()} | new device: {real_model.proj.weight.device}') + optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1) data = torch.randint(low=0, high=20, size=(16,), device=get_current_device()) out = model(data) out.sum().backward() - + optimizer.step() def run_dist(rank, world_size, port, use_ddp): if use_ddp and world_size == 1: @@ -81,4 +83,4 @@ def _test_hybrid_device(world_size, use_ddp): if __name__ == '__main__': - _test_hybrid_device(1, False) + _test_hybrid_device(4, True)