From 51f916b11d87ecdfa3763da7a6b396a030b32b13 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 12 Jul 2024 07:33:44 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 colossalai/pipeline/schedule/interleaved_pp.py |  3 ++-
 colossalai/pipeline/schedule/one_f_one_b.py    |  3 ++-
 colossalai/quantization/fp8.py                 | 12 +++++-------
 3 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py
index 86ce536d0..a7571c731 100644
--- a/colossalai/pipeline/schedule/interleaved_pp.py
+++ b/colossalai/pipeline/schedule/interleaved_pp.py
@@ -11,8 +11,8 @@ from colossalai.accelerator import get_accelerator
 from colossalai.interface import OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline
 from colossalai.utils import get_current_device
-from colossalai.quantization.fp8 import cast_to_fp8_pipeline, cast_from_fp8_pipeline
 
 from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
 from .base import PipelineSchedule
@@ -59,6 +59,7 @@ class InterleavedSchedule(PipelineSchedule):
         self.grad_metadata_recv = None
 
         self.fp8_communication = fp8_communication
+
     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.
 
diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py
index 90ebb0534..3269d67ba 100644
--- a/colossalai/pipeline/schedule/one_f_one_b.py
+++ b/colossalai/pipeline/schedule/one_f_one_b.py
@@ -10,8 +10,8 @@ from colossalai.accelerator import get_accelerator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline
 from colossalai.utils import get_current_device
-from colossalai.quantization.fp8 import cast_to_fp8_pipeline, cast_from_fp8_pipeline
 
 from ._utils import (
     detach,
@@ -172,6 +172,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
 
             if self.fp8_communication:
                 cast_from_fp8_pipeline(output_tensor, del_metadata=False)
+
     def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None:
         """Sends the gradient tensor to the previous stage in pipeline.
            For 1F1B.
diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index c02223331..e514f435e 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any
 
 import torch
 import torch.distributed as dist
@@ -107,7 +107,6 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
     tensor.data = tensor_out.view(input_shape).to(input_type)
 
 
-
 def cast_to_fp8_pipeline(inp: Any) -> None:
     """
     Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline.
@@ -121,7 +120,7 @@ def cast_to_fp8_pipeline(inp: Any) -> None:
     if type(inp) == torch.Tensor:
         return
 
-    assert 'hidden_states' in inp, 'required by pipeline parallelism.'
+    assert "hidden_states" in inp, "required by pipeline parallelism."
     inp_tensor = inp["hidden_states"]
 
     min_val, max_val = inp_tensor.aminmax()
@@ -137,7 +136,7 @@ def cast_to_fp8_pipeline(inp: Any) -> None:
 
     finfo = torch.finfo(fp8_type)
     scale = torch.tensor(1.0).to(inp_tensor.device) if amax == 0.0 else finfo.max / amax.float()
-    q_tensor = (inp_tensor.data.float() * scale)
+    q_tensor = inp_tensor.data.float() * scale
     # Todo: Currently we use fp8_view_type to indicate which fp8 format is used. This is a temporary workaround due to 'Only support tensor for fast send'.
     # inp_tensor needs to be a float datatype to avoid error during gradient placement.
     inp_tensor.data = q_tensor.to(fp8_type).view(fp8_view_type)
@@ -145,7 +144,6 @@ def cast_to_fp8_pipeline(inp: Any) -> None:
     inp["fp8_scale"] = scale.float().reciprocal()
 
 
-
 def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
     """
     Cast the FP8 encoded hidden_states tensor back to original dtype after p2p communication in pipeline.
@@ -156,7 +154,7 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
     if type(inp) == torch.Tensor:
         return
 
-    assert 'hidden_states' in inp, 'required by pipeline parallelism.'
+    assert "hidden_states" in inp, "required by pipeline parallelism."
     inp_tensor = inp["hidden_states"]
     scale = inp["fp8_scale"]
 
@@ -171,4 +169,4 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
     inp_tensor.data = inp_tensor.data.view(fp8_type).to(torch.float16) * scale
 
     if del_metadata:
-        del inp["fp8_scale"]
\ No newline at end of file
+        del inp["fp8_scale"]