mirror of https://github.com/hpcaitech/ColossalAI
support fp8 communication in pipeline parallelism
parent
1e1959467e
commit
e88190184a
|
@ -992,6 +992,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
make_vocab_size_divisible_by: int = 64,
|
||||
dp_outside: bool = True,
|
||||
overlap_p2p: bool = True,
|
||||
fp8_communication: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert (
|
||||
|
@ -1082,6 +1083,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
microbatch_size=microbatch_size,
|
||||
enable_metadata_cache=enable_metadata_cache,
|
||||
overlap_p2p=overlap_p2p,
|
||||
fp8_communication=fp8_communication,
|
||||
)
|
||||
elif pp_style == "1f1b":
|
||||
self.schedule = OneForwardOneBackwardSchedule(
|
||||
|
@ -1089,6 +1091,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
num_microbatches=num_microbatches,
|
||||
microbatch_size=microbatch_size,
|
||||
enable_metadata_cache=enable_metadata_cache,
|
||||
fp8_communication=fp8_communication,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -12,6 +12,7 @@ from colossalai.interface import OptimizerWrapper
|
|||
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.quantization.fp8 import cast_to_fp8_pipeline, cast_from_fp8_pipeline
|
||||
|
||||
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
|
||||
from .base import PipelineSchedule
|
||||
|
@ -32,6 +33,7 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
microbatch_size: Optional[int] = None,
|
||||
enable_metadata_cache: bool = True,
|
||||
overlap_p2p: bool = True,
|
||||
fp8_communication: bool = False,
|
||||
) -> None:
|
||||
super().__init__(stage_manager)
|
||||
assert (
|
||||
|
@ -56,6 +58,7 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
self.tensor_metadata_recv = None
|
||||
self.grad_metadata_recv = None
|
||||
|
||||
self.fp8_communication = fp8_communication
|
||||
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
||||
"""Load a batch from data iterator.
|
||||
|
||||
|
@ -191,8 +194,12 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
"""
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_last_stage():
|
||||
if self.fp8_communication:
|
||||
cast_to_fp8_pipeline(output_tensor)
|
||||
send_handles = self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(output_tensor)
|
||||
return send_handles
|
||||
return []
|
||||
|
||||
|
@ -210,10 +217,14 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
"""
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
if not self.stage_manager.is_first_stage():
|
||||
if self.fp8_communication:
|
||||
cast_to_fp8_pipeline(input_tensor_grad)
|
||||
send_handles = self.comm.send_backward(
|
||||
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata
|
||||
)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(input_tensor_grad)
|
||||
return send_handles
|
||||
return []
|
||||
|
||||
|
@ -224,6 +235,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
is_send = not self.stage_manager.is_last_stage()
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
|
||||
is_recv = not self.stage_manager.is_first_stage()
|
||||
if self.fp8_communication:
|
||||
cast_to_fp8_pipeline(output_tensor)
|
||||
input_tensor, wait_handles = self.comm.send_forward_recv_forward(
|
||||
output_tensor,
|
||||
is_send,
|
||||
|
@ -237,6 +250,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
if is_recv and self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(output_tensor)
|
||||
return input_tensor, wait_handles
|
||||
|
||||
def send_backward_recv_backward(
|
||||
|
@ -246,6 +261,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
is_send = not self.stage_manager.is_first_stage()
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
|
||||
is_recv = not self.stage_manager.is_last_stage()
|
||||
if self.fp8_communication:
|
||||
cast_to_fp8_pipeline(input_tensor_grad)
|
||||
output_tensor_grad, wait_handles = self.comm.send_backward_recv_backward(
|
||||
input_tensor_grad,
|
||||
is_send,
|
||||
|
@ -258,6 +275,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
self.send_grad_metadata = not self.enable_metadata_cache and is_send
|
||||
if is_recv and self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(input_tensor_grad)
|
||||
return output_tensor_grad, wait_handles
|
||||
|
||||
def forward_step(
|
||||
|
@ -379,6 +398,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
|
||||
# Wait until current input is received
|
||||
_wait_p2p(fwd_wait_handles)
|
||||
if self.fp8_communication and input_obj is not None:
|
||||
cast_from_fp8_pipeline(input_obj)
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
|
||||
if not last_batch:
|
||||
|
@ -441,6 +462,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
|
||||
# Wait for input
|
||||
_wait_p2p(fwd_wait_handles)
|
||||
if self.fp8_communication and input_obj is not None:
|
||||
cast_from_fp8_pipeline(input_obj)
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
input_objs[model_chunk_id].append(input_obj)
|
||||
output_objs[model_chunk_id].append(output_obj)
|
||||
|
@ -467,6 +490,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
|
||||
# Wait for input.
|
||||
_wait_p2p(fwd_wait_handles)
|
||||
if self.fp8_communication and input_obj is not None:
|
||||
cast_from_fp8_pipeline(input_obj)
|
||||
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
|
||||
# Add input_obj and output_obj to end of list.
|
||||
input_objs[model_chunk_id].append(input_obj)
|
||||
|
@ -511,6 +536,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
input_obj, fwd_wait_handles = send_forward_recv_forward()
|
||||
# Wait for upstream grad
|
||||
_wait_p2p(bwd_wait_handles)
|
||||
if self.fp8_communication and output_obj_grad is not None:
|
||||
cast_from_fp8_pipeline(output_obj_grad)
|
||||
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
|
||||
# NOTE: It's documented by NCCL that running two concurrent communicators (batch_isend_irecv)
|
||||
# risks deadlock (https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2134/user-guide/docs/usage/communicators.html)
|
||||
|
@ -532,6 +559,8 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
|
||||
# Wait for upstream grad
|
||||
_wait_p2p(bwd_wait_handles)
|
||||
if self.fp8_communication and output_obj_grad is not None:
|
||||
cast_from_fp8_pipeline(output_obj_grad)
|
||||
# backward local grads
|
||||
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
|
||||
if not last_batch:
|
||||
|
|
|
@ -11,6 +11,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
|
|||
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.quantization.fp8 import cast_to_fp8_pipeline, cast_from_fp8_pipeline
|
||||
|
||||
from ._utils import (
|
||||
detach,
|
||||
|
@ -32,6 +33,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
num_microbatches: Optional[int] = None,
|
||||
microbatch_size: Optional[int] = None,
|
||||
enable_metadata_cache: bool = True,
|
||||
fp8_communication: bool = False,
|
||||
) -> None:
|
||||
"""1F1B pipeline schedule.
|
||||
|
||||
|
@ -61,6 +63,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
self.tensor_metadata_recv = None
|
||||
self.grad_metadata_recv = None
|
||||
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
||||
"""Load a batch from data iterator.
|
||||
|
||||
|
@ -129,6 +133,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(input_tensor)
|
||||
return input_tensor
|
||||
|
||||
def recv_backward(self, next_rank: int = None) -> Any:
|
||||
|
@ -143,6 +149,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
"""
|
||||
if not self.stage_manager.is_last_stage():
|
||||
output_tensor_grad, _ = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(output_tensor_grad)
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
|
||||
|
@ -157,9 +165,13 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
next_rank (int, optional): The rank of the recipient of the tensor.
|
||||
"""
|
||||
if not self.stage_manager.is_last_stage():
|
||||
if self.fp8_communication:
|
||||
cast_to_fp8_pipeline(output_tensor)
|
||||
self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
|
||||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(output_tensor, del_metadata=False)
|
||||
def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None:
|
||||
"""Sends the gradient tensor to the previous stage in pipeline.
|
||||
For 1F1B.
|
||||
|
@ -169,8 +181,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
prev_rank (int, optional): The rank of the recipient of the tensor
|
||||
"""
|
||||
if not self.stage_manager.is_first_stage():
|
||||
if self.fp8_communication:
|
||||
cast_to_fp8_pipeline(input_tensor_grad)
|
||||
self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
|
||||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(input_tensor_grad, del_metadata=False)
|
||||
|
||||
def send_forward_recv_backward(self, output_tensor: Any, send_first: Optional[bool] = None) -> Any:
|
||||
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
|
||||
|
@ -183,6 +199,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
if not self.stage_manager.is_last_stage():
|
||||
if not self.send_tensor_metadata and self.grad_metadata_recv is not None:
|
||||
send_first = None
|
||||
if self.fp8_communication:
|
||||
cast_to_fp8_pipeline(output_tensor)
|
||||
output_tensor_grad, _ = self.comm.send_forward_recv_backward(
|
||||
output_tensor,
|
||||
send_metadata=self.send_tensor_metadata,
|
||||
|
@ -192,6 +210,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(output_tensor, del_metadata=False)
|
||||
cast_from_fp8_pipeline(output_tensor_grad)
|
||||
|
||||
return output_tensor_grad
|
||||
|
||||
|
@ -206,6 +227,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
if not self.stage_manager.is_first_stage():
|
||||
if not self.send_grad_metadata and self.tensor_metadata_recv is not None:
|
||||
send_first = None # must not fallback
|
||||
if self.fp8_communication:
|
||||
cast_to_fp8_pipeline(input_tensor_grad)
|
||||
input_tensor, _ = self.comm.send_backward_recv_forward(
|
||||
input_tensor_grad,
|
||||
send_metadata=self.send_grad_metadata,
|
||||
|
@ -215,6 +238,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
self.send_grad_metadata = not self.enable_metadata_cache
|
||||
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||
if self.fp8_communication:
|
||||
cast_from_fp8_pipeline(input_tensor)
|
||||
cast_from_fp8_pipeline(input_tensor_grad, del_metadata=False)
|
||||
|
||||
return input_tensor
|
||||
|
||||
|
|
|
@ -105,3 +105,70 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
|
|||
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
|
||||
tensor_out = torch.cat(tensor_list, dim=0)
|
||||
tensor.data = tensor_out.view(input_shape).to(input_type)
|
||||
|
||||
|
||||
|
||||
def cast_to_fp8_pipeline(inp: Any) -> None:
|
||||
"""
|
||||
Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline.
|
||||
The activations tensor is indexed by 'hidden_states' in the inp dict.
|
||||
After FP8 casting, the resulting tensor is saved as float16 or bfloat16 format but the size becomes halved.
|
||||
Metadata such as fp8_scale is saved into inp dict for communication.
|
||||
"""
|
||||
if inp is None:
|
||||
return
|
||||
# In pipeline parallelism, when inp is torch.Tensor, it only contains one element, thus can be omitted.
|
||||
if type(inp) == torch.Tensor:
|
||||
return
|
||||
|
||||
assert 'hidden_states' in inp, 'required by pipeline parallelism.'
|
||||
inp_tensor = inp["hidden_states"]
|
||||
|
||||
min_val, max_val = inp_tensor.aminmax()
|
||||
amax = torch.maximum(min_val.abs(), max_val.abs())
|
||||
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
if amax > finfo.max:
|
||||
fp8_type = torch.float8_e5m2
|
||||
fp8_view_type = torch.float16
|
||||
else:
|
||||
fp8_type = torch.float8_e4m3fn
|
||||
fp8_view_type = torch.bfloat16
|
||||
|
||||
finfo = torch.finfo(fp8_type)
|
||||
scale = torch.tensor(1.0).to(inp_tensor.device) if amax == 0.0 else finfo.max / amax.float()
|
||||
q_tensor = (inp_tensor.data.float() * scale)
|
||||
# Todo: Currently we use fp8_view_type <float16, bfloat16> to indicate which fp8 format is used. This is a temporary workaround due to 'Only support tensor for fast send'.
|
||||
# inp_tensor needs to be a float datatype to avoid error during gradient placement.
|
||||
inp_tensor.data = q_tensor.to(fp8_type).view(fp8_view_type)
|
||||
|
||||
inp["fp8_scale"] = scale.float().reciprocal()
|
||||
|
||||
|
||||
|
||||
def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
|
||||
"""
|
||||
Cast the FP8 encoded hidden_states tensor back to original dtype after p2p communication in pipeline.
|
||||
del_metadata = False is useful when this function is called before p2p communication.
|
||||
"""
|
||||
if inp is None:
|
||||
return
|
||||
if type(inp) == torch.Tensor:
|
||||
return
|
||||
|
||||
assert 'hidden_states' in inp, 'required by pipeline parallelism.'
|
||||
inp_tensor = inp["hidden_states"]
|
||||
scale = inp["fp8_scale"]
|
||||
|
||||
fp8_view_type = inp_tensor.dtype
|
||||
if fp8_view_type == torch.float16:
|
||||
fp8_type = torch.float8_e5m2
|
||||
elif fp8_view_type == torch.bfloat16:
|
||||
fp8_type = torch.float8_e4m3fn
|
||||
else:
|
||||
raise TypeError("Only float16, bfloat16 are implemented.")
|
||||
|
||||
inp_tensor.data = inp_tensor.data.view(fp8_type).to(torch.float16) * scale
|
||||
|
||||
if del_metadata:
|
||||
del inp["fp8_scale"]
|
Loading…
Reference in New Issue