@@ -15,6 +15,7 @@ from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_
 from colossalai.interface import ModelWrapper
 from colossalai.lazy import LazyTensor
 from colossalai.logging import get_dist_logger
+from colossalai.quantization.fp8_hook import FP8Hook
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.d_tensor import (
     distribute_tensor,
@@ -99,6 +100,7 @@ class GeminiDDP(ModelWrapper):
         verbose: bool = False,
         enable_async_reduce: bool = True,
         fp8_communication: bool = False,
+        use_fp8: bool = False,
     ) -> None:
         assert mixed_precision in (torch.float16, torch.bfloat16)
         reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
@@ -138,6 +140,9 @@ class GeminiDDP(ModelWrapper):
         )
         self.force_outputs_fp32 = force_outputs_fp32
         self.param_op_hook = GeminiZeROHook(self.gemini_manager)
+        self.hooks = [self.param_op_hook]
+        if use_fp8:
+            self.hooks.append(FP8Hook())
         self.fp32_params: List[torch.Tensor] = list()
         self.fp16_params: List[ColoParameter] = list()
         self.grads_device: Dict[torch.Tensor, torch.device] = dict()
@@ -310,7 +315,7 @@ class GeminiDDP(ModelWrapper):
             outputs = self._inference_forward(*args, **kwargs)
         else:
             self.gemini_manager.pre_iter(*args)
-            with ColoParamOpHookManager.use_hooks(self.param_op_hook):
+            with ColoParamOpHookManager.use_hooks(*self.hooks):
                 outputs = self.module(*args, **kwargs)
 
         if self.force_outputs_fp32:
@@ -319,7 +324,7 @@ class GeminiDDP(ModelWrapper):
 
     def _inference_forward(self, *args, **kwargs):
         """This function is only triggered for inference."""
-        fwd_ctx = ColoParamOpHookManager.use_hooks(self.param_op_hook)
+        fwd_ctx = ColoParamOpHookManager.use_hooks(*self.hooks)
         if not self.scatter_after_inference:
             # gather all chunks
             for chunk in self.chunk_manager.get_chunks(self.fp16_params):
@@ -372,7 +377,7 @@ class GeminiDDP(ModelWrapper):
 
     def backward(self, loss: torch.Tensor):
         self._pre_backward()
-        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(*self.hooks):
             loss.backward()
         self._post_backward()
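
A minimal usage sketch of the new flag, for reviewers (not part of the patch): it assumes a distributed environment has already been launched via torchrun (the exact launch call may differ by ColossalAI version) and that every other GeminiDDP argument keeps its default. Only use_fp8=True is exercised, which in this diff appends an FP8Hook to self.hooks so that forward, _inference_forward, and backward enter ColoParamOpHookManager.use_hooks(*self.hooks) with both the Gemini ZeRO hook and the FP8 hook.

# Hedged sketch, not part of the patch: exercising GeminiDDP(use_fp8=True).
import torch
import colossalai
from colossalai.zero import GeminiDDP

# Assumes the process group is set up by torchrun; launch arguments may vary by version.
colossalai.launch_from_torch()

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.GELU(),
    torch.nn.Linear(1024, 1024),
)

# use_fp8=True makes __init__ append FP8Hook() to self.hooks, so parameter ops in
# forward/backward run under the FP8 hook in addition to GeminiZeROHook.
ddp_model = GeminiDDP(model, use_fp8=True)

x = torch.randn(8, 1024, device=torch.cuda.current_device())
loss = ddp_model(x).float().mean()
ddp_model.backward(loss)  # enters ColoParamOpHookManager.use_hooks(*self.hooks)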