@@ -15,6 +15,7 @@ from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_
 from colossalai.interface import ModelWrapper
 from colossalai.lazy import LazyTensor
 from colossalai.logging import get_dist_logger
+from colossalai.quantization.fp8_hook import FP8Hook
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.d_tensor import (
     distribute_tensor,
@@ -99,6 +100,7 @@ class GeminiDDP(ModelWrapper):
         verbose: bool = False,
         enable_async_reduce: bool = True,
         fp8_communication: bool = False,
+        use_fp8: bool = False,
     ) -> None:
         assert mixed_precision in (torch.float16, torch.bfloat16)
         reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
@@ -138,6 +140,9 @@ class GeminiDDP(ModelWrapper):
         )
         self.force_outputs_fp32 = force_outputs_fp32
         self.param_op_hook = GeminiZeROHook(self.gemini_manager)
+        self.hooks = [self.param_op_hook]
+        if use_fp8:
+            self.hooks.append(FP8Hook())
         self.fp32_params: List[torch.Tensor] = list()
         self.fp16_params: List[ColoParameter] = list()
         self.grads_device: Dict[torch.Tensor, torch.device] = dict()
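For context, the change above collects the Gemini parameter-op hook and the optional FP8 hook into a single list, which the later hunks unpack via `ColoParamOpHookManager.use_hooks(*self.hooks)`. The snippet below is a minimal, library-independent sketch of that "hook list, unpacked into a variadic context manager" pattern; `BaseHook`, `FP8LikeHook`, and `use_hooks` here are illustrative stand-ins, not ColossalAI's implementation.

```python
# Library-independent sketch of the pattern used in this diff.
from contextlib import ExitStack, contextmanager


class BaseHook:
    def __enter__(self):
        print("base hook enabled")
        return self

    def __exit__(self, *exc):
        print("base hook disabled")


class FP8LikeHook(BaseHook):
    def __enter__(self):
        print("fp8 hook enabled")
        return self

    def __exit__(self, *exc):
        print("fp8 hook disabled")


@contextmanager
def use_hooks(*hooks):
    # Enter every hook for the duration of the block, mirroring the
    # variadic ColoParamOpHookManager.use_hooks(*self.hooks) call.
    with ExitStack() as stack:
        for hook in hooks:
            stack.enter_context(hook)
        yield


use_fp8 = True
hooks = [BaseHook()]
if use_fp8:
    hooks.append(FP8LikeHook())

with use_hooks(*hooks):
    pass  # forward/backward would run here with all hooks active
```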
@@ -310,7 +315,7 @@ class GeminiDDP(ModelWrapper):
             outputs = self._inference_forward(*args, **kwargs)
         else:
             self.gemini_manager.pre_iter(*args)
-            with ColoParamOpHookManager.use_hooks(self.param_op_hook):
+            with ColoParamOpHookManager.use_hooks(*self.hooks):
                 outputs = self.module(*args, **kwargs)

         if self.force_outputs_fp32:
@@ -319,7 +324,7 @@ class GeminiDDP(ModelWrapper):

     def _inference_forward(self, *args, **kwargs):
         """This function is only triggered for inference."""
-        fwd_ctx = ColoParamOpHookManager.use_hooks(self.param_op_hook)
+        fwd_ctx = ColoParamOpHookManager.use_hooks(*self.hooks)
         if not self.scatter_after_inference:
             # gather all chunks
             for chunk in self.chunk_manager.get_chunks(self.fp16_params):
@@ -372,7 +377,7 @@ class GeminiDDP(ModelWrapper):

     def backward(self, loss: torch.Tensor):
         self._pre_backward()
-        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(*self.hooks):
             loss.backward()
         self._post_backward()

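As a rough usage sketch (not part of this diff), the new flag is simply passed to the wrapper at construction time. The snippet assumes a distributed environment has already been initialized (e.g., via `colossalai.launch_from_torch`) and leaves all other `GeminiDDP` arguments at their defaults.

```python
import torch
from colossalai.zero import GeminiDDP

# Assumes torch.distributed / colossalai has already been launched.
model = torch.nn.Linear(1024, 1024).cuda()

# use_fp8=True makes GeminiDDP append FP8Hook() to its parameter-op hooks,
# so the FP8 hook runs alongside GeminiZeROHook in forward(),
# _inference_forward(), and backward().
gemini_model = GeminiDDP(model, use_fp8=True)
```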