mirror of https://github.com/hpcaitech/ColossalAI
[ddp] add set_params_to_ignore for ColoDDP (#1122)

* add set_params_to_ignore for ColoDDP
* polish code
* fix zero hook v2
* add unit test
* polish docstr

parent 3175bcb4d8
commit f0a954f16d
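For context, the new API is meant to be called on parameters before the module is wrapped. The snippet below is a minimal usage sketch, not part of the commit itself: the `should_skip` predicate and the plain `torch.nn.Linear` module are illustrative assumptions, and it presumes a Colossal-AI distributed context is already available (here via `colossalai.launch_from_torch` under torchrun).

import torch
import colossalai
from colossalai.nn.parallel import ColoDDP

# Assumption: the process group is set up by torchrun; inside a spawned worker use colossalai.launch instead.
colossalai.launch_from_torch(config={})

module = torch.nn.Linear(3, 3).cuda()

def should_skip(name: str) -> bool:
    # Illustrative predicate only: skip DDP synchronization for the weight.
    return name == 'weight'

# Tag the parameters *before* wrapping the module; ColoDDP reads `_ddp_to_ignore` at init time.
params_to_ignore = [p for name, p in module.named_parameters() if should_skip(name)]
ColoDDP.set_params_to_ignore(params_to_ignore)
model = ColoDDP(module)  # ignored parameters get no backward hook and no gradient all-reduce

Gradients of tagged parameters are then left local to each rank, which is exactly what the new unit test below asserts.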
@@ -7,7 +7,7 @@ from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
 from colossalai.tensor.chunk import TensorState, Chunk
 from colossalai.tensor.param_op_hook import ParamOpHookManager
 from colossalai.gemini.gemini_mgr import GeminiManager
-from typing import Dict
+from typing import Dict, Iterable
 from colossalai.logging import get_dist_logger
@@ -38,6 +38,8 @@ class ColoDDP(torch.nn.Module):
         self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
         self.dp_world_size = gpc.get_world_size(ParallelMode.DATA)
         for p in module.parameters():
+            if getattr(p, '_ddp_to_ignore', False):
+                continue
             if p.requires_grad:
                 p.register_hook(partial(self.grad_handle, p))
@@ -55,6 +57,8 @@ class ColoDDP(torch.nn.Module):
         loss.backward()
         torch.cuda.current_stream().wait_stream(self.comm_stream)
         for p in self.module.parameters():
+            if getattr(p, '_ddp_to_ignore', False):
+                continue
             if p.grad.device.type != "cpu":
                 p.grad = p._saved_grad
@@ -99,6 +103,25 @@ class ColoDDP(torch.nn.Module):
                         p._saved_grad.requires_grad_(False)
                     p._saved_grad.zero_()
+
+    @staticmethod
+    def set_params_to_ignore(params_to_ignore: Iterable[torch.Tensor]) -> None:
+        """Sets parameters to be ignored by DDP.
+        This method must be called before initializing ColoDDP.
+
+        Example::
+            >>> params_to_ignore = []
+            >>> for p in module.parameters():
+            >>>     if should_ignore(p):
+            >>>         params_to_ignore.append(p)
+            >>> ColoDDP.set_params_to_ignore(params_to_ignore)
+            >>> module = ColoDDP(module)
+
+        Args:
+            params_to_ignore (Iterable[torch.Tensor]): A list of parameters to be ignored.
+        """
+        for p in params_to_ignore:
+            p._ddp_to_ignore = True
 
 
 class ColoDDPV2(ColoDDP):
@@ -114,6 +137,8 @@ class ColoDDPV2(ColoDDP):
         self.chunk_manager.create_group('fp32_param')
         # TODO: get param order and filter unused params
         for p in module.parameters():
+            if getattr(p, '_ddp_to_ignore', False):
+                continue
             assert p.dtype == torch.half
             fp32_p = p.float().detach()
             self.chunk_manager.append_tensor(p, 'fp16_param')
@@ -133,6 +158,8 @@ class ColoDDPV2(ColoDDP):
 
     def _setup_grads_ptr(self):
         for p in self.module.parameters():
+            if getattr(p, '_ddp_to_ignore', False):
+                continue
             if self.chunk_manager.get_chunk(p).is_empty or not p.requires_grad:
                 p.grad = None
             else:
@@ -22,6 +22,7 @@ class ZeROHookV2(ParamOpHook):
         self._training_phase = TrainingPhase.FORWARD
 
     def pre_op(self, params):
+        params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
         chunks = self._chunk_manager.get_chunks(params)
         for p in params:
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
@@ -33,6 +34,7 @@ class ZeROHookV2(ParamOpHook):
         self._gemini_manager.sample_model_data()
 
     def post_op(self, params):
+        params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
         for p in params:
             tensor_state = TensorState.HOLD if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad else TensorState.HOLD_AFTER_BWD
             self._chunk_manager.trans_tensor_state(p, tensor_state)
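Both hooks rely on the same convention as ColoDDP: a parameter is skipped if it carries a truthy `_ddp_to_ignore` attribute. Below is a minimal, self-contained sketch of that attribute-flag pattern; the helper name `filter_ignored` is illustrative only and not part of the ColossalAI API.

import torch

def filter_ignored(params):
    # Drop tensors tagged with `_ddp_to_ignore`, exactly as pre_op/post_op do above.
    return [p for p in params if not getattr(p, '_ddp_to_ignore', False)]

w = torch.nn.Parameter(torch.randn(2, 2))
b = torch.nn.Parameter(torch.randn(2))
b._ddp_to_ignore = True    # what ColoDDP.set_params_to_ignore([b]) would set

kept = filter_ignored([w, b])
assert len(kept) == 1 and kept[0] is w   # only the untagged parameter remains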
@@ -0,0 +1,87 @@
+import pytest
+import colossalai
+import torch
+import torch.multiprocessing as mp
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils import free_port
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.tensor import ChunkManager
+from functools import partial
+from colossalai.nn.parallel import ColoDDP, ColoDDPV2
+from colossalai.gemini.gemini_mgr import GeminiManager
+from typing import Callable
+import torch.distributed as dist
+import os
+import random
+import numpy as np
+
+
+def set_seed(seed):
+    random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
+
+
+def init_ddp(module: torch.nn.Module) -> ColoDDP:
+    return ColoDDP(module)
+
+
+def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ColoDDPV2:
+    chunk_size = ChunkManager.search_chunk_size(module, 64, 2) if use_chunk else None
+    chunk_manager = ChunkManager(chunk_size)
+    gemini_manager = GeminiManager('cuda', chunk_manager)
+    return ColoDDPV2(module, gemini_manager)
+
+
+class Net(torch.nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = torch.nn.Linear(3, 3, bias=False)
+        self.fc2 = torch.nn.Linear(3, 1, bias=False)
+
+    def forward(self, x):
+        return self.fc2(self.fc1(x))
+
+
+def run_fwd_bwd(ddp_cls: ColoDDP, init_ddp_func: Callable[[torch.nn.Module], ColoDDP]):
+    with ColoInitContext(device=get_current_device()):
+        model = Net().cuda()
+    w1 = model.fc1.weight
+    w2 = model.fc2.weight
+    ddp_cls.set_params_to_ignore([w2])
+    model = init_ddp_func(model)
+    x = torch.rand(2, 3, device=get_current_device())
+    logits = model(x)
+    loss = torch.sum(logits)
+    model.backward(loss)
+    w1_grads = [torch.empty_like(w1) for _ in range(dist.get_world_size())]
+    dist.all_gather(w1_grads, w1.grad)
+    assert torch.equal(w1_grads[0], w1_grads[1])
+    w2_grads = [torch.empty_like(w2) for _ in range(dist.get_world_size())]
+    dist.all_gather(w2_grads, w2.grad)
+    assert not torch.equal(w2_grads[0], w2_grads[1])
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    set_seed(dist.get_rank())
+    run_fwd_bwd(ColoDDP, init_ddp)
+    run_fwd_bwd(ColoDDPV2, partial(init_ddpv2, use_chunk=False))
+    run_fwd_bwd(ColoDDPV2, partial(init_ddpv2, use_chunk=True))
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [2])
+@rerun_if_address_is_in_use()
+def test_ddp_ignore_params(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_ddp_ignore_params(2)