Merge pull request #5885 from BurkeHulk/feature/fp8_comm

Feature/fp8 comm
Hanks 2024-07-16 11:37:05 +08:00 committed by GitHub
commit 9470701110
6 changed files with 236 additions and 0 deletions

View File

@@ -992,6 +992,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         make_vocab_size_divisible_by: int = 64,
         dp_outside: bool = True,
         overlap_p2p: bool = True,
+        fp8_communication: bool = False,
     ) -> None:
         super().__init__()
         assert (
@@ -1082,6 +1083,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                     microbatch_size=microbatch_size,
                     enable_metadata_cache=enable_metadata_cache,
                     overlap_p2p=overlap_p2p,
+                    fp8_communication=fp8_communication,
                 )
             elif pp_style == "1f1b":
                 self.schedule = OneForwardOneBackwardSchedule(
@@ -1089,6 +1091,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                     num_microbatches=num_microbatches,
                     microbatch_size=microbatch_size,
                     enable_metadata_cache=enable_metadata_cache,
+                    fp8_communication=fp8_communication,
                 )
             else:
                 raise NotImplementedError()
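
The new flag is threaded from the plugin constructor into whichever pipeline schedule is built, and it defaults to False. A minimal sketch of turning it on from user code (the parallelism settings shown here are illustrative placeholders, not taken from this diff):

# Sketch: enabling FP8 pipeline communication via the new plugin flag.
# tp_size / pp_size / num_microbatches are placeholder values for illustration.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    num_microbatches=4,
    precision="fp16",
    fp8_communication=True,  # cast p2p activations/gradients to FP8 in transit
)
booster = Booster(plugin=plugin)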

View File

@@ -11,6 +11,7 @@ from colossalai.accelerator import get_accelerator
 from colossalai.interface import OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline
 from colossalai.utils import get_current_device

 from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
@@ -32,6 +33,7 @@ class InterleavedSchedule(PipelineSchedule):
         microbatch_size: Optional[int] = None,
         enable_metadata_cache: bool = True,
         overlap_p2p: bool = True,
+        fp8_communication: bool = False,
     ) -> None:
         super().__init__(stage_manager)
         assert (
@@ -56,6 +58,8 @@ class InterleavedSchedule(PipelineSchedule):
         self.tensor_metadata_recv = None
         self.grad_metadata_recv = None

+        self.fp8_communication = fp8_communication
+
     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.
@@ -191,8 +195,12 @@
         """
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             if not self.stage_manager.is_last_stage():
+                if self.fp8_communication:
+                    cast_to_fp8_pipeline(output_tensor)
                 send_handles = self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
                 self.send_tensor_metadata = not self.enable_metadata_cache
+                if self.fp8_communication:
+                    cast_from_fp8_pipeline(output_tensor)
                 return send_handles
             return []
@@ -210,10 +218,14 @@
         """
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             if not self.stage_manager.is_first_stage():
+                if self.fp8_communication:
+                    cast_to_fp8_pipeline(input_tensor_grad)
                 send_handles = self.comm.send_backward(
                     input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata
                 )
                 self.send_grad_metadata = not self.enable_metadata_cache
+                if self.fp8_communication:
+                    cast_from_fp8_pipeline(input_tensor_grad)
                 return send_handles
             return []
@@ -224,6 +236,8 @@
             is_send = not self.stage_manager.is_last_stage()
         with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
             is_recv = not self.stage_manager.is_first_stage()
+        if self.fp8_communication:
+            cast_to_fp8_pipeline(output_tensor)
         input_tensor, wait_handles = self.comm.send_forward_recv_forward(
             output_tensor,
             is_send,
@@ -237,6 +251,8 @@
         if is_recv and self.enable_metadata_cache and self.tensor_metadata_recv is None:
             self.tensor_metadata_recv = create_send_metadata(input_tensor)

+        if self.fp8_communication:
+            cast_from_fp8_pipeline(output_tensor)
         return input_tensor, wait_handles

     def send_backward_recv_backward(
@@ -246,6 +262,8 @@
             is_send = not self.stage_manager.is_first_stage()
         with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
             is_recv = not self.stage_manager.is_last_stage()
+        if self.fp8_communication:
+            cast_to_fp8_pipeline(input_tensor_grad)
         output_tensor_grad, wait_handles = self.comm.send_backward_recv_backward(
             input_tensor_grad,
             is_send,
@@ -258,6 +276,8 @@
         self.send_grad_metadata = not self.enable_metadata_cache and is_send
         if is_recv and self.enable_metadata_cache and self.grad_metadata_recv is None:
             self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
+        if self.fp8_communication:
+            cast_from_fp8_pipeline(input_tensor_grad)
         return output_tensor_grad, wait_handles

     def forward_step(
@@ -379,6 +399,8 @@
             # Wait until current input is received
             _wait_p2p(fwd_wait_handles)
+            if self.fp8_communication and input_obj is not None:
+                cast_from_fp8_pipeline(input_obj)
             output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)

             if not last_batch:
@@ -441,6 +463,8 @@
             # Wait for input
             _wait_p2p(fwd_wait_handles)
+            if self.fp8_communication and input_obj is not None:
+                cast_from_fp8_pipeline(input_obj)
             output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
             input_objs[model_chunk_id].append(input_obj)
             output_objs[model_chunk_id].append(output_obj)
@@ -467,6 +491,8 @@
             # Wait for input.
             _wait_p2p(fwd_wait_handles)
+            if self.fp8_communication and input_obj is not None:
+                cast_from_fp8_pipeline(input_obj)
             output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
             # Add input_obj and output_obj to end of list.
             input_objs[model_chunk_id].append(input_obj)
@@ -511,6 +537,8 @@
                 input_obj, fwd_wait_handles = send_forward_recv_forward()
             # Wait for upstream grad
             _wait_p2p(bwd_wait_handles)
+            if self.fp8_communication and output_obj_grad is not None:
+                cast_from_fp8_pipeline(output_obj_grad)
             input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)

             # NOTE: It's documented by NCCL that running two concurrent communicators (batch_isend_irecv)
             # risks deadlock (https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2134/user-guide/docs/usage/communicators.html)
@@ -532,6 +560,8 @@
             # Wait for upstream grad
             _wait_p2p(bwd_wait_handles)
+            if self.fp8_communication and output_obj_grad is not None:
+                cast_from_fp8_pipeline(output_obj_grad)
             # backward local grads
             input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)

             if not last_batch:
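
Across both schedules the added hook follows one pattern: quantize the payload in place right before the p2p call, then immediately dequantize the local copy so that subsequent local computation (and the matching backward pass) still sees full-precision values. A condensed sketch of that send-side wrapper, assuming a dict payload carrying a 'hidden_states' entry as cast_to_fp8_pipeline requires; comm_send is a stand-in for the schedule's self.comm.send_forward / send_backward:

# Condensed illustration of the send-side hook added throughout both schedules.
# `comm_send` is a placeholder for self.comm.send_forward / self.comm.send_backward.
from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline

def send_with_optional_fp8(payload: dict, comm_send, fp8_communication: bool):
    if fp8_communication:
        cast_to_fp8_pipeline(payload)    # quantize 'hidden_states' in place, stash 'fp8_scale'
    handles = comm_send(payload)         # transmit the halved-size buffer
    if fp8_communication:
        cast_from_fp8_pipeline(payload)  # restore the local copy for later use on this rank
    return handles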

View File

@@ -10,6 +10,7 @@ from colossalai.accelerator import get_accelerator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline
 from colossalai.utils import get_current_device

 from ._utils import (
@@ -32,6 +33,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         num_microbatches: Optional[int] = None,
         microbatch_size: Optional[int] = None,
         enable_metadata_cache: bool = True,
+        fp8_communication: bool = False,
     ) -> None:
         """1F1B pipeline schedule.

@@ -61,6 +63,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         self.tensor_metadata_recv = None
         self.grad_metadata_recv = None

+        self.fp8_communication = fp8_communication
+
     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.
@@ -129,6 +133,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             if self.enable_metadata_cache and self.tensor_metadata_recv is None:
                 self.tensor_metadata_recv = create_send_metadata(input_tensor)

+            if self.fp8_communication:
+                cast_from_fp8_pipeline(input_tensor)
             return input_tensor

     def recv_backward(self, next_rank: int = None) -> Any:
@@ -143,6 +149,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         """
         if not self.stage_manager.is_last_stage():
             output_tensor_grad, _ = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
+            if self.fp8_communication:
+                cast_from_fp8_pipeline(output_tensor_grad)
             if self.enable_metadata_cache and self.grad_metadata_recv is None:
                 self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
@@ -157,9 +165,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             next_rank (int, optional): The rank of the recipient of the tensor.
         """
         if not self.stage_manager.is_last_stage():
+            if self.fp8_communication:
+                cast_to_fp8_pipeline(output_tensor)
             self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
             self.send_tensor_metadata = not self.enable_metadata_cache
+
+            if self.fp8_communication:
+                cast_from_fp8_pipeline(output_tensor, del_metadata=False)

     def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None:
         """Sends the gradient tensor to the previous stage in pipeline.
            For 1F1B.
@@ -169,8 +182,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             prev_rank (int, optional): The rank of the recipient of the tensor
         """
         if not self.stage_manager.is_first_stage():
+            if self.fp8_communication:
+                cast_to_fp8_pipeline(input_tensor_grad)
             self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
             self.send_grad_metadata = not self.enable_metadata_cache
+            if self.fp8_communication:
+                cast_from_fp8_pipeline(input_tensor_grad, del_metadata=False)

     def send_forward_recv_backward(self, output_tensor: Any, send_first: Optional[bool] = None) -> Any:
         """Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
@@ -183,6 +200,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         if not self.stage_manager.is_last_stage():
             if not self.send_tensor_metadata and self.grad_metadata_recv is not None:
                 send_first = None
+            if self.fp8_communication:
+                cast_to_fp8_pipeline(output_tensor)
             output_tensor_grad, _ = self.comm.send_forward_recv_backward(
                 output_tensor,
                 send_metadata=self.send_tensor_metadata,
@@ -192,6 +211,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             self.send_tensor_metadata = not self.enable_metadata_cache
             if self.enable_metadata_cache and self.grad_metadata_recv is None:
                 self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
+            if self.fp8_communication:
+                cast_from_fp8_pipeline(output_tensor, del_metadata=False)
+                cast_from_fp8_pipeline(output_tensor_grad)

         return output_tensor_grad
@@ -206,6 +228,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         if not self.stage_manager.is_first_stage():
             if not self.send_grad_metadata and self.tensor_metadata_recv is not None:
                 send_first = None  # must not fallback
+            if self.fp8_communication:
+                cast_to_fp8_pipeline(input_tensor_grad)
             input_tensor, _ = self.comm.send_backward_recv_forward(
                 input_tensor_grad,
                 send_metadata=self.send_grad_metadata,
@@ -215,6 +239,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             self.send_grad_metadata = not self.enable_metadata_cache
             if self.enable_metadata_cache and self.tensor_metadata_recv is None:
                 self.tensor_metadata_recv = create_send_metadata(input_tensor)
+            if self.fp8_communication:
+                cast_from_fp8_pipeline(input_tensor)
+                cast_from_fp8_pipeline(input_tensor_grad, del_metadata=False)

         return input_tensor

View File

@@ -0,0 +1,172 @@
+from typing import Any
+
+import torch
+import torch.distributed as dist
+
+
+def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Tensor):
+    r"""
+    casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
+    Args:
+        inp: input torch Tensor, should be in torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor.
+        scale: scaling factor for fp8 casting. If it is None, then it is computed automatically. Per-channel scaling
+        is applied if input tensor is 2 dimension, otherwise, per-tensor scaling is applied.
+        fp8_format: e4m3 or e5m2
+
+    Returns:
+        Tuples: A tuple (fp8_tensor, scale)
+    """
+    if inp.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
+        raise TypeError("Only float16, bfloat16, and float32 are allowed.")
+
+    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
+    fp8_max = torch.finfo(fp8_type).max
+
+    if inp.dim() == 2:
+        per_channel_max = inp.abs().max(dim=-1).values.float()
+        per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
+        scale = fp8_max / per_channel_max[:, None]
+    else:
+        per_tensor_max = inp.abs().max().float()
+        per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
+        scale = fp8_max / per_tensor_max
+
+    scale_inv = 1.0 / scale
+    ret = (scale * inp.float()).to(fp8_type)
+    return ret, scale_inv
+
+
+def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor:
+    r"""
+    Args:
+        inp: should be a fp8 torch tensor in one of the types: [torch.float8_e4m3fn, torch.float8_e5m2].
+        scale: scaling factor returned by cast_to_fp8 function.
+        ret_type: the datatype of the returned tensor.
+
+    Returns:
+        torch.Tensor
+    """
+    if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
+        raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.")
+
+    if inp.dim() == 2:
+        ret = scale_inv[:, None] * inp.float()
+    else:
+        ret = scale_inv * inp.float()
+    return ret.to(ret_type)
+
+
+def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
+    r"""
+    This is an in-place operation for compressed all_reduce using fp8.
+    It works like dist.all_reduce but during communication the data is cast to fp8 format.
+
+    Args:
+        tensor: torch.Tensor in fp32, fp16, bf16 datatype.
+        fp8_format: e4m3 or e5m2
+
+    Returns:
+        None
+    """
+
+    world_size = dist.get_world_size()
+    input_type = tensor.dtype
+    input_shape = tensor.shape
+    input_device = tensor.device
+    input_size = tensor.numel()
+    tensor = tensor.flatten()
+
+    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
+
+    ret, scale = cast_to_fp8(tensor, fp8_format=fp8_format)
+
+    inp = ret.view(torch.uint8)
+    input_chunks = list(torch.chunk(inp, world_size, dim=0))
+    if dist.get_rank() == world_size - 1:
+        output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)]
+    else:
+        output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]
+    dist.all_to_all(output_chunks, input_chunks)
+    scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
+    dist.all_gather(scale_list, scale)
+    summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
+    for scale, out in zip(scale_list, output_chunks):
+        out = out.view(fp8_type)
+        summed_out += cast_from_fp8(out, scale, input_type)
+
+    summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
+    dist.all_gather(scale_list, scale)
+
+    tensor_list = list(torch.chunk(torch.empty(input_size, device=input_device, dtype=torch.uint8), world_size, dim=0))
+    dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8))
+    for i in range(world_size):
+        tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
+    tensor_out = torch.cat(tensor_list, dim=0)
+    tensor.data = tensor_out.view(input_shape).to(input_type)
+
+
+def cast_to_fp8_pipeline(inp: Any) -> None:
+    """
+    Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline.
+    The activations tensor is indexed by 'hidden_states' in the inp dict.
+    After FP8 casting, the resulting tensor is saved as float16 or bfloat16 format but the size becomes halved.
+    Metadata such as fp8_scale is saved into inp dict for communication.
+    """
+    if inp is None:
+        return
+    # In pipeline parallelism, when inp is torch.Tensor, it only contains one element, thus can be omitted.
+    if type(inp) == torch.Tensor:
+        return
+
+    assert "hidden_states" in inp, "required by pipeline parallelism."
+    inp_tensor = inp["hidden_states"]
+
+    min_val, max_val = inp_tensor.aminmax()
+    amax = torch.maximum(min_val.abs(), max_val.abs())
+
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    if amax > finfo.max:
+        fp8_type = torch.float8_e5m2
+        fp8_view_type = torch.float16
+    else:
+        fp8_type = torch.float8_e4m3fn
+        fp8_view_type = torch.bfloat16
+
+    finfo = torch.finfo(fp8_type)
+    scale = torch.tensor(1.0).to(inp_tensor.device) if amax == 0.0 else finfo.max / amax.float()
+    q_tensor = inp_tensor.data.float() * scale
+    # Todo: Currently we use fp8_view_type <float16, bfloat16> to indicate which fp8 format is used. This is a temporary workaround due to 'Only support tensor for fast send'.
+    # inp_tensor needs to be a float datatype to avoid error during gradient placement.
+    inp_tensor.data = q_tensor.to(fp8_type).view(fp8_view_type)
+
+    inp["fp8_scale"] = scale.float().reciprocal()
+
+
+def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
+    """
+    Cast the FP8 encoded hidden_states tensor back to original dtype after p2p communication in pipeline.
+    del_metadata = False is useful when this function is called before p2p communication.
+    """
+    if inp is None:
+        return
+    if type(inp) == torch.Tensor:
+        return
+
+    assert "hidden_states" in inp, "required by pipeline parallelism."
+    inp_tensor = inp["hidden_states"]
+    scale = inp["fp8_scale"]
+
+    fp8_view_type = inp_tensor.dtype
+    if fp8_view_type == torch.float16:
+        fp8_type = torch.float8_e5m2
+    elif fp8_view_type == torch.bfloat16:
+        fp8_type = torch.float8_e4m3fn
+    else:
+        raise TypeError("Only float16, bfloat16 are implemented.")
+
+    inp_tensor.data = inp_tensor.data.view(fp8_type).to(torch.float16) * scale
+
+    if del_metadata:
+        del inp["fp8_scale"]
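
A small, self-contained round-trip sketch of the helpers above (single process, no collectives); it assumes a PyTorch build with float8 dtype support, and the printed error is illustrative rather than a guaranteed bound:

import torch

from colossalai.quantization.fp8 import (
    cast_from_fp8,
    cast_from_fp8_pipeline,
    cast_to_fp8,
    cast_to_fp8_pipeline,
)

# Plain tensor round trip: a 1-D input takes the per-tensor scaling path.
x = torch.randn(16, dtype=torch.bfloat16)
x_fp8, scale_inv = cast_to_fp8(x, fp8_format="e4m3")
x_back = cast_from_fp8(x_fp8, scale_inv, torch.bfloat16)
print((x - x_back).abs().max())  # small quantization error, not exact recovery

# Pipeline variant: works in place on a dict carrying 'hidden_states'.
batch = {"hidden_states": torch.randn(2, 16, dtype=torch.float16)}
cast_to_fp8_pipeline(batch)    # payload halved, scale stored under batch["fp8_scale"]
cast_from_fp8_pipeline(batch)  # dequantized in place, "fp8_scale" removed again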

View File

@@ -190,6 +190,7 @@ def main():
     )
     parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
     parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context")
+    parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication")
     args = parser.parse_args()

     if args.model_type == "bert":
@@ -232,6 +233,7 @@ def main():
             zero_stage=1,
             precision="fp16",
             initial_scale=1,
+            fp8_communication=args.use_fp8_comm,
         )

     booster = Booster(plugin=plugin, **booster_kwargs)
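
Both example scripts (this one and the following) gain the same --use_fp8_comm switch. One caveat worth noting, shown in a minimal standalone sketch: because the argument is declared with type=bool, argparse converts any non-empty string to True, so the flag should simply be omitted to keep FP8 communication disabled.

import argparse

# The flag as declared in the benchmark scripts; type=bool is truthiness-based.
parser = argparse.ArgumentParser()
parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication")

print(parser.parse_args([]).use_fp8_comm)                           # False (flag omitted)
print(parser.parse_args(["--use_fp8_comm", "True"]).use_fp8_comm)   # True
print(parser.parse_args(["--use_fp8_comm", "False"]).use_fp8_comm)  # also True, since bool("False") is truthy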

View File

@@ -187,6 +187,7 @@ def main():
     )
     parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
     parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context")
+    parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication")
     args = parser.parse_args()

     if args.model_type == "gpt2":
@@ -225,6 +226,7 @@ def main():
             zero_stage=1,
             precision="fp16",
             initial_scale=1,
+            fp8_communication=args.use_fp8_comm,
         )

     booster = Booster(plugin=plugin, **booster_kwargs)