mirror of https://github.com/hpcaitech/ColossalAI
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pull/5885/head
parent 1f1b856354
commit 51f916b11d
@@ -11,8 +11,8 @@ from colossalai.accelerator import get_accelerator
 from colossalai.interface import OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline
 from colossalai.utils import get_current_device
-from colossalai.quantization.fp8 import cast_to_fp8_pipeline, cast_from_fp8_pipeline

 from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
 from .base import PipelineSchedule
@@ -59,6 +59,7 @@ class InterleavedSchedule(PipelineSchedule):
         self.grad_metadata_recv = None

         self.fp8_communication = fp8_communication
+
     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.

@@ -10,8 +10,8 @@ from colossalai.accelerator import get_accelerator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline
 from colossalai.utils import get_current_device
-from colossalai.quantization.fp8 import cast_to_fp8_pipeline, cast_from_fp8_pipeline

 from ._utils import (
     detach,
@@ -172,6 +172,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):

             if self.fp8_communication:
                 cast_from_fp8_pipeline(output_tensor, del_metadata=False)
+
     def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None:
         """Sends the gradient tensor to the previous stage in pipeline.
            For 1F1B.
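Note on the hunk above: cast_to_fp8_pipeline and cast_from_fp8_pipeline mutate the payload in place, so after the p2p send the stage casts output_tensor back to its original dtype for any later local use, and del_metadata=False keeps the recorded FP8 scale in the dict. Below is a self-contained sketch of that gate/compress/send/restore pattern with stand-in helpers; it is not the real schedule, comm object, or colossalai.quantization.fp8 code, and it assumes a PyTorch build with float8 dtypes.

    import torch

    # Illustrative stand-ins only, mimicking the in-place contract implied by the diff:
    # compress "hidden_states" and remember how to undo it.
    def compress_inplace(payload: dict) -> None:
        t = payload["hidden_states"]
        scale = torch.finfo(torch.float8_e4m3fn).max / t.abs().max().clamp(min=1e-12).float()
        payload["orig_dtype"] = t.dtype
        payload["fp8_scale"] = scale.reciprocal()
        payload["hidden_states"] = (t.float() * scale).to(torch.float8_e4m3fn)

    def restore_inplace(payload: dict, del_metadata: bool = True) -> None:
        t = payload["hidden_states"]
        payload["hidden_states"] = (t.float() * payload["fp8_scale"]).to(payload["orig_dtype"])
        if del_metadata:
            del payload["fp8_scale"], payload["orig_dtype"]

    def send_forward_sketch(payload: dict, fp8_communication: bool, send_fn) -> None:
        # Shape of the gated path in the diff: compress, p2p-send, then restore the local copy.
        if fp8_communication:
            compress_inplace(payload)
        send_fn(payload)
        if fp8_communication:
            restore_inplace(payload, del_metadata=False)

    payload = {"hidden_states": torch.randn(2, 4, dtype=torch.bfloat16)}
    send_forward_sketch(payload, fp8_communication=True, send_fn=lambda p: None)
    assert payload["hidden_states"].dtype == torch.bfloat16 and "fp8_scale" in payload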
@@ -1,4 +1,4 @@
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any

 import torch
 import torch.distributed as dist
@@ -107,7 +106,6 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
     tensor.data = tensor_out.view(input_shape).to(input_type)


-
 def cast_to_fp8_pipeline(inp: Any) -> None:
     """
     Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline.
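The surviving context line above is the tail of all_reduce_fp8: the reduced values, computed on a flattened working buffer, are written back through tensor.data so the caller's tensor keeps its original shape and dtype. A minimal sketch of that write-back idiom follows; the FP8 reduction itself is stubbed out and this is not the library's algorithm.

    import torch

    def allreduce_like(tensor: torch.Tensor) -> None:
        input_shape, input_type = tensor.shape, tensor.dtype
        tensor_out = tensor.flatten().float()  # stands in for the flattened, reduced result
        tensor.data = tensor_out.view(input_shape).to(input_type)  # in-place write-back, as in the diff

    x = torch.randn(3, 4, dtype=torch.bfloat16)
    allreduce_like(x)  # x keeps its shape and dtype; no reassignment needed at the call site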
@@ -121,7 +120,7 @@ def cast_to_fp8_pipeline(inp: Any) -> None:
     if type(inp) == torch.Tensor:
         return

-    assert 'hidden_states' in inp, 'required by pipeline parallelism.'
+    assert "hidden_states" in inp, "required by pipeline parallelism."
     inp_tensor = inp["hidden_states"]

     min_val, max_val = inp_tensor.aminmax()
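The guard above means bare tensors are sent uncompressed; only dict-style pipeline payloads carrying "hidden_states" take the FP8 path, and aminmax supplies the per-tensor range from which the scale is later derived. A small self-contained sketch of that contract, using a stand-in function rather than the library code:

    import torch
    from typing import Any

    def maybe_compress(inp: Any) -> None:
        # Mirror of the guard shown in the diff: plain tensors pass through untouched.
        if type(inp) == torch.Tensor:
            return
        assert "hidden_states" in inp, "required by pipeline parallelism."
        min_val, max_val = inp["hidden_states"].aminmax()
        print(f"would derive the FP8 scale from range [{min_val.item():.4f}, {max_val.item():.4f}]")

    maybe_compress(torch.randn(4))                     # no-op: plain tensors skip the FP8 path
    maybe_compress({"hidden_states": torch.randn(4)})  # prints the per-tensor range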
@@ -137,7 +136,7 @@ def cast_to_fp8_pipeline(inp: Any) -> None:

     finfo = torch.finfo(fp8_type)
     scale = torch.tensor(1.0).to(inp_tensor.device) if amax == 0.0 else finfo.max / amax.float()
-    q_tensor = (inp_tensor.data.float() * scale)
+    q_tensor = inp_tensor.data.float() * scale
     # Todo: Currently we use fp8_view_type <float16, bfloat16> to indicate which fp8 format is used. This is a temporary workaround due to 'Only support tensor for fast send'.
     # inp_tensor needs to be a float datatype to avoid error during gradient placement.
     inp_tensor.data = q_tensor.to(fp8_type).view(fp8_view_type)
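Per the Todo comment, .view(fp8_view_type) does not convert values: it reinterprets the raw FP8 bytes as a float16/bfloat16 tensor so that the p2p layer, which only fast-sends ordinary tensors, can ship them, and the chosen 16-bit dtype doubles as a tag for which FP8 format was used. A small sketch of that reinterpretation and its inverse, assuming e4m3 tagged as bfloat16; note the apparent element count halves, since two 1-byte FP8 values occupy one 2-byte slot. Requires a PyTorch build with float8 dtypes.

    import torch

    q = (torch.randn(2, 8) * 10).to(torch.float8_e4m3fn)  # quantized payload: 16 one-byte values
    wire = q.view(torch.bfloat16)                          # bit reinterpretation, shape becomes (2, 4)
    assert wire.shape == (2, 4) and wire.data_ptr() == q.data_ptr()

    recovered = wire.view(torch.float8_e4m3fn)             # receiver undoes the view before rescaling
    assert torch.equal(recovered.view(torch.uint8), q.view(torch.uint8))  # bit-identical round trip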
@@ -145,7 +144,6 @@ def cast_to_fp8_pipeline(inp: Any) -> None:
     inp["fp8_scale"] = scale.float().reciprocal()


-
 def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
     """
     Cast the FP8 encoded hidden_states tensor back to original dtype after p2p communication in pipeline.
@@ -156,7 +154,7 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
     if type(inp) == torch.Tensor:
         return

-    assert 'hidden_states' in inp, 'required by pipeline parallelism.'
+    assert "hidden_states" in inp, "required by pipeline parallelism."
     inp_tensor = inp["hidden_states"]
     scale = inp["fp8_scale"]

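Concretely, assuming e4m3 (finfo.max is 448) and a hidden_states tensor whose largest magnitude is 3.5: the sender's scale is 128, the stored inp["fp8_scale"] is its reciprocal, and cast_from_fp8_pipeline only needs a single multiply by that stored factor to return to the original magnitude:

    amax = 3.5
    scale = 448.0 / amax      # 128.0: what the sender multiplies by before casting to e4m3
    fp8_scale = 1.0 / scale   # 0.0078125: stored in inp["fp8_scale"] and shipped with the payload
    print(448.0 * fp8_scale)  # 3.5: the receiver recovers the original magnitude with one multiply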