@@ -14,7 +14,7 @@ from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
-from colossalai.utils import get_current_device
+from colossalai.utils import get_current_device, is_ddp_ignored
 from colossalai.zero.utils.gemini_hook import GeminiZeROHook
 
 from .reducer import Reducer
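Every hunk in this diff makes the same mechanical substitution: the ad-hoc attribute probe getattr(p, '_ddp_to_ignore', False) becomes a call to the is_ddp_ignored helper imported above. Judging purely from the call sites, the helper is presumably a thin wrapper around that same attribute check; a minimal sketch, not the confirmed implementation:

    def is_ddp_ignored(p):
        # Parameters tagged with `_ddp_to_ignore` are skipped by
        # ColoDDP/ZeroDDP gradient and chunk handling; default to
        # False when the flag was never set on this parameter.
        return getattr(p, '_ddp_to_ignore', False)

Centralizing the check keeps the flag name in one place, so the call sites below read as intent rather than as an attribute lookup.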
@@ -81,7 +81,7 @@ class ColoDDP(torch.nn.Module):
         self.reducer = Reducer(bucket_cap_mb)
         self.rebuild_bucket = rebuild_bucket
         for p in module.parameters():
-            if getattr(p, '_ddp_to_ignore', False):
+            if is_ddp_ignored(p):
                 continue
             if p.requires_grad:
                 p.register_hook(partial(self.grad_handle, p))
@@ -116,7 +116,7 @@ class ColoDDP(torch.nn.Module):
         if self.rebuild_bucket:
             self.reducer.free()
         for p in self.module.parameters():
-            if getattr(p, '_ddp_to_ignore', False):
+            if is_ddp_ignored(p):
                 continue
             if p.grad.device.type != "cpu":
                 p.grad = p._saved_grad
@@ -232,7 +232,7 @@ class ZeroDDP(ColoDDP):
         for p in param_order.generate():
             assert isinstance(p, ColoParameter)
 
-            if getattr(p, '_ddp_to_ignore', False):
+            if is_ddp_ignored(p):
                 p.data = p.data.half()
                 continue
 
@@ -256,7 +256,7 @@ class ZeroDDP(ColoDDP):
         self.chunk_manager.close_all_groups()
         self._cast_buffers()
 
-        params_list = [p for p in param_order.generate() if not getattr(p, '_ddp_to_ignore', False)]
+        params_list = [p for p in param_order.generate() if not is_ddp_ignored(p)]
         for p, fp32_p in zip(params_list, self.fp32_params):
             chunk_16 = self.chunk_manager.get_chunk(p)
             chunk_32 = self.chunk_manager.get_chunk(fp32_p)
@@ -303,7 +303,7 @@ class ZeroDDP(ColoDDP):
 
     def _setup_grads_ptr(self):
         for p in self.module.parameters():
-            if getattr(p, '_ddp_to_ignore', False):
+            if is_ddp_ignored(p):
                 continue
             p.grad = None
 
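For context, the _ddp_to_ignore flag read by is_ddp_ignored is set on individual parameters before the module is wrapped. A usage sketch, assuming ColoDDP is exported from colossalai.nn.parallel in the tree this diff targets and that its set_params_to_ignore static method tags each given parameter with _ddp_to_ignore = True:

    import torch
    from colossalai.nn.parallel import ColoDDP

    model = torch.nn.Linear(16, 16)
    # Tag parameters that DDP should leave alone; after this,
    # is_ddp_ignored(model.bias) returns True, so the code paths above
    # skip the bias during gradient reduction and chunk registration.
    ColoDDP.set_params_to_ignore([model.bias])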