support fp8 communication in pipeline parallelism

pull/5885/head
BurkeHulk 2024-07-12 15:25:25 +08:00
parent 1e1959467e
commit e88190184a
4 changed files with 126 additions and 1 deletion

View File

@@ -992,6 +992,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         make_vocab_size_divisible_by: int = 64,
         dp_outside: bool = True,
         overlap_p2p: bool = True,
+        fp8_communication: bool = False,
     ) -> None:
         super().__init__()
         assert (
@@ -1082,6 +1083,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                     microbatch_size=microbatch_size,
                     enable_metadata_cache=enable_metadata_cache,
                     overlap_p2p=overlap_p2p,
+                    fp8_communication=fp8_communication,
                 )
             elif pp_style == "1f1b":
                 self.schedule = OneForwardOneBackwardSchedule(
@@ -1089,6 +1091,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                     num_microbatches=num_microbatches,
                     microbatch_size=microbatch_size,
                     enable_metadata_cache=enable_metadata_cache,
+                    fp8_communication=fp8_communication,
                 )
             else:
                 raise NotImplementedError()
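For context, here is a minimal sketch of how the new flag might be enabled from user code through the usual booster workflow. The parallel sizes, microbatch count, and the boost call are illustrative assumptions, not values taken from this commit:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Hypothetical configuration: any pipeline-parallel setup (pp_size > 1) can now opt in
# to FP8 P2P communication via the new keyword argument; it defaults to False, so
# existing configurations keep their current behavior.
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    num_microbatches=4,
    fp8_communication=True,  # new argument, forwarded to the pipeline schedule
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)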

View File

@@ -12,6 +12,7 @@ from colossalai.interface import OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.utils import get_current_device
+from colossalai.quantization.fp8 import cast_to_fp8_pipeline, cast_from_fp8_pipeline

 from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
 from .base import PipelineSchedule
@@ -32,6 +33,7 @@ class InterleavedSchedule(PipelineSchedule):
         microbatch_size: Optional[int] = None,
         enable_metadata_cache: bool = True,
         overlap_p2p: bool = True,
+        fp8_communication: bool = False,
     ) -> None:
         super().__init__(stage_manager)
         assert (
@@ -56,6 +58,7 @@ class InterleavedSchedule(PipelineSchedule):
         self.tensor_metadata_recv = None
         self.grad_metadata_recv = None
+        self.fp8_communication = fp8_communication

     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.
@@ -191,8 +194,12 @@ class InterleavedSchedule(PipelineSchedule):
         """
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             if not self.stage_manager.is_last_stage():
+                if self.fp8_communication:
+                    cast_to_fp8_pipeline(output_tensor)
                 send_handles = self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
                 self.send_tensor_metadata = not self.enable_metadata_cache
+                if self.fp8_communication:
+                    cast_from_fp8_pipeline(output_tensor)
                 return send_handles

         return []
@@ -210,10 +217,14 @@ class InterleavedSchedule(PipelineSchedule):
         """
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             if not self.stage_manager.is_first_stage():
+                if self.fp8_communication:
+                    cast_to_fp8_pipeline(input_tensor_grad)
                 send_handles = self.comm.send_backward(
                     input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata
                 )
                 self.send_grad_metadata = not self.enable_metadata_cache
+                if self.fp8_communication:
+                    cast_from_fp8_pipeline(input_tensor_grad)
                 return send_handles

         return []
@@ -224,6 +235,8 @@ class InterleavedSchedule(PipelineSchedule):
             is_send = not self.stage_manager.is_last_stage()
         with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
             is_recv = not self.stage_manager.is_first_stage()
+        if self.fp8_communication:
+            cast_to_fp8_pipeline(output_tensor)
         input_tensor, wait_handles = self.comm.send_forward_recv_forward(
             output_tensor,
             is_send,
@@ -237,6 +250,8 @@ class InterleavedSchedule(PipelineSchedule):
         if is_recv and self.enable_metadata_cache and self.tensor_metadata_recv is None:
             self.tensor_metadata_recv = create_send_metadata(input_tensor)
+        if self.fp8_communication:
+            cast_from_fp8_pipeline(output_tensor)

         return input_tensor, wait_handles

     def send_backward_recv_backward(
@@ -246,6 +261,8 @@ class InterleavedSchedule(PipelineSchedule):
             is_send = not self.stage_manager.is_first_stage()
         with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
             is_recv = not self.stage_manager.is_last_stage()
+        if self.fp8_communication:
+            cast_to_fp8_pipeline(input_tensor_grad)
         output_tensor_grad, wait_handles = self.comm.send_backward_recv_backward(
             input_tensor_grad,
             is_send,
@@ -258,6 +275,8 @@ class InterleavedSchedule(PipelineSchedule):
         self.send_grad_metadata = not self.enable_metadata_cache and is_send
         if is_recv and self.enable_metadata_cache and self.grad_metadata_recv is None:
             self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
+        if self.fp8_communication:
+            cast_from_fp8_pipeline(input_tensor_grad)

         return output_tensor_grad, wait_handles

     def forward_step(
@@ -379,6 +398,8 @@ class InterleavedSchedule(PipelineSchedule):
             # Wait until current input is received
             _wait_p2p(fwd_wait_handles)
+            if self.fp8_communication and input_obj is not None:
+                cast_from_fp8_pipeline(input_obj)
             output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)

             if not last_batch:
@@ -441,6 +462,8 @@ class InterleavedSchedule(PipelineSchedule):
             # Wait for input
             _wait_p2p(fwd_wait_handles)
+            if self.fp8_communication and input_obj is not None:
+                cast_from_fp8_pipeline(input_obj)
             output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
             input_objs[model_chunk_id].append(input_obj)
             output_objs[model_chunk_id].append(output_obj)
@@ -467,6 +490,8 @@ class InterleavedSchedule(PipelineSchedule):
             # Wait for input.
             _wait_p2p(fwd_wait_handles)
+            if self.fp8_communication and input_obj is not None:
+                cast_from_fp8_pipeline(input_obj)
             output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
             # Add input_obj and output_obj to end of list.
             input_objs[model_chunk_id].append(input_obj)
@@ -511,6 +536,8 @@ class InterleavedSchedule(PipelineSchedule):
                 input_obj, fwd_wait_handles = send_forward_recv_forward()
             # Wait for upstream grad
             _wait_p2p(bwd_wait_handles)
+            if self.fp8_communication and output_obj_grad is not None:
+                cast_from_fp8_pipeline(output_obj_grad)
             input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)

             # NOTE: It's documented by NCCL that running two concurrent communicators (batch_isend_irecv)
             # risks deadlock (https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2134/user-guide/docs/usage/communicators.html)
@@ -532,6 +559,8 @@ class InterleavedSchedule(PipelineSchedule):
             # Wait for upstream grad
             _wait_p2p(bwd_wait_handles)
+            if self.fp8_communication and output_obj_grad is not None:
+                cast_from_fp8_pipeline(output_obj_grad)
             # backward local grads
             input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)

             if not last_batch:

View File

@@ -11,6 +11,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.utils import get_current_device
+from colossalai.quantization.fp8 import cast_to_fp8_pipeline, cast_from_fp8_pipeline

 from ._utils import (
     detach,
@@ -32,6 +33,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         num_microbatches: Optional[int] = None,
         microbatch_size: Optional[int] = None,
         enable_metadata_cache: bool = True,
+        fp8_communication: bool = False,
     ) -> None:
         """1F1B pipeline schedule.
@@ -61,6 +63,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         self.tensor_metadata_recv = None
         self.grad_metadata_recv = None
+        self.fp8_communication = fp8_communication
+
     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.
@@ -129,6 +133,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             if self.enable_metadata_cache and self.tensor_metadata_recv is None:
                 self.tensor_metadata_recv = create_send_metadata(input_tensor)
+            if self.fp8_communication:
+                cast_from_fp8_pipeline(input_tensor)

             return input_tensor

     def recv_backward(self, next_rank: int = None) -> Any:
@@ -143,6 +149,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         """
         if not self.stage_manager.is_last_stage():
             output_tensor_grad, _ = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
+            if self.fp8_communication:
+                cast_from_fp8_pipeline(output_tensor_grad)
             if self.enable_metadata_cache and self.grad_metadata_recv is None:
                 self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
@@ -157,9 +165,13 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             next_rank (int, optional): The rank of the recipient of the tensor.
         """
         if not self.stage_manager.is_last_stage():
+            if self.fp8_communication:
+                cast_to_fp8_pipeline(output_tensor)
             self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
             self.send_tensor_metadata = not self.enable_metadata_cache
+            if self.fp8_communication:
+                cast_from_fp8_pipeline(output_tensor, del_metadata=False)

     def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None:
         """Sends the gradient tensor to the previous stage in pipeline.
         For 1F1B.
@@ -169,8 +181,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             prev_rank (int, optional): The rank of the recipient of the tensor
         """
         if not self.stage_manager.is_first_stage():
+            if self.fp8_communication:
+                cast_to_fp8_pipeline(input_tensor_grad)
             self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
             self.send_grad_metadata = not self.enable_metadata_cache
+            if self.fp8_communication:
+                cast_from_fp8_pipeline(input_tensor_grad, del_metadata=False)

     def send_forward_recv_backward(self, output_tensor: Any, send_first: Optional[bool] = None) -> Any:
         """Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
@@ -183,6 +199,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         if not self.stage_manager.is_last_stage():
             if not self.send_tensor_metadata and self.grad_metadata_recv is not None:
                 send_first = None
+            if self.fp8_communication:
+                cast_to_fp8_pipeline(output_tensor)
             output_tensor_grad, _ = self.comm.send_forward_recv_backward(
                 output_tensor,
                 send_metadata=self.send_tensor_metadata,
@@ -192,6 +210,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             self.send_tensor_metadata = not self.enable_metadata_cache
             if self.enable_metadata_cache and self.grad_metadata_recv is None:
                 self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
+            if self.fp8_communication:
+                cast_from_fp8_pipeline(output_tensor, del_metadata=False)
+                cast_from_fp8_pipeline(output_tensor_grad)

             return output_tensor_grad
@@ -206,6 +227,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         if not self.stage_manager.is_first_stage():
             if not self.send_grad_metadata and self.tensor_metadata_recv is not None:
                 send_first = None  # must not fallback
+            if self.fp8_communication:
+                cast_to_fp8_pipeline(input_tensor_grad)
             input_tensor, _ = self.comm.send_backward_recv_forward(
                 input_tensor_grad,
                 send_metadata=self.send_grad_metadata,
@@ -215,6 +238,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             self.send_grad_metadata = not self.enable_metadata_cache
             if self.enable_metadata_cache and self.tensor_metadata_recv is None:
                 self.tensor_metadata_recv = create_send_metadata(input_tensor)
+            if self.fp8_communication:
+                cast_from_fp8_pipeline(input_tensor)
+                cast_from_fp8_pipeline(input_tensor_grad, del_metadata=False)

             return input_tensor

View File

@@ -105,3 +105,70 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
         tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
     tensor_out = torch.cat(tensor_list, dim=0)
     tensor.data = tensor_out.view(input_shape).to(input_type)
+
+
+def cast_to_fp8_pipeline(inp: Any) -> None:
+    """
+    Cast the hidden_states tensor of the inp object to FP8 format before p2p communication in pipeline parallelism.
+    The activation tensor is indexed by 'hidden_states' in the inp dict.
+    After FP8 casting, the result is stored as a float16 or bfloat16 view, so the tensor occupies half the bytes.
+    Metadata such as fp8_scale is saved into the inp dict for communication.
+    """
+    if inp is None:
+        return
+    # In pipeline parallelism, when inp is a torch.Tensor it contains only one element, so casting can be skipped.
+    if type(inp) == torch.Tensor:
+        return
+
+    assert "hidden_states" in inp, "required by pipeline parallelism."
+    inp_tensor = inp["hidden_states"]
+
+    min_val, max_val = inp_tensor.aminmax()
+    amax = torch.maximum(min_val.abs(), max_val.abs())
+
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    if amax > finfo.max:
+        fp8_type = torch.float8_e5m2
+        fp8_view_type = torch.float16
+    else:
+        fp8_type = torch.float8_e4m3fn
+        fp8_view_type = torch.bfloat16
+
+    finfo = torch.finfo(fp8_type)
+    scale = torch.tensor(1.0).to(inp_tensor.device) if amax == 0.0 else finfo.max / amax.float()
+    q_tensor = inp_tensor.data.float() * scale
+    # TODO: fp8_view_type (float16 vs. bfloat16) currently encodes which FP8 format was used. This is a
+    # temporary workaround due to 'Only support tensor for fast send'.
+    # inp_tensor must keep a floating-point dtype to avoid errors during gradient placement.
+    inp_tensor.data = q_tensor.to(fp8_type).view(fp8_view_type)
+
+    inp["fp8_scale"] = scale.float().reciprocal()
+
+
+def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
+    """
+    Cast the FP8-encoded hidden_states tensor back to the original dtype after p2p communication in pipeline parallelism.
+    del_metadata=False is useful when this function is called before p2p communication.
+    """
+    if inp is None:
+        return
+    if type(inp) == torch.Tensor:
+        return
+
+    assert "hidden_states" in inp, "required by pipeline parallelism."
+    inp_tensor = inp["hidden_states"]
+    scale = inp["fp8_scale"]
+
+    fp8_view_type = inp_tensor.dtype
+    if fp8_view_type == torch.float16:
+        fp8_type = torch.float8_e5m2
+    elif fp8_view_type == torch.bfloat16:
+        fp8_type = torch.float8_e4m3fn
+    else:
+        raise TypeError("Only float16, bfloat16 are implemented.")
+
+    inp_tensor.data = inp_tensor.data.view(fp8_type).to(torch.float16) * scale
+
+    if del_metadata:
+        del inp["fp8_scale"]
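To make the casting contract concrete, below is a small round-trip sketch of the two new helpers on a stand-in pipeline payload. The shape, dtype, and device are illustrative assumptions, and it presumes a PyTorch build recent enough to provide the float8 dtypes:

import torch

from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline

# Stand-in for what a pipeline stage sends: a dict keyed by 'hidden_states'.
inp = {"hidden_states": torch.randn(2, 128, 768, dtype=torch.bfloat16, device="cuda")}

cast_to_fp8_pipeline(inp)
# The activation is now FP8 data viewed as fp16/bf16 (half the bytes), and the
# dequantization scale travels in the same dict under 'fp8_scale'.
print(inp["hidden_states"].shape, inp["hidden_states"].dtype)  # e.g. torch.Size([2, 128, 384]) torch.bfloat16
print(inp["fp8_scale"])

# ... here the dict would be exchanged between stages via PipelineP2PCommunication ...

cast_from_fp8_pipeline(inp)  # dequantizes in place and deletes 'fp8_scale'
print(inp["hidden_states"].shape, inp["hidden_states"].dtype)  # torch.Size([2, 128, 768]) torch.float16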