diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index a3d6f1e74..b818209a6 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -992,6 +992,7 @@ class HybridParallelPlugin(PipelinePluginBase): make_vocab_size_divisible_by: int = 64, dp_outside: bool = True, overlap_p2p: bool = True, + fp8_communication: bool = False, ) -> None: super().__init__() assert ( @@ -1082,6 +1083,7 @@ class HybridParallelPlugin(PipelinePluginBase): microbatch_size=microbatch_size, enable_metadata_cache=enable_metadata_cache, overlap_p2p=overlap_p2p, + fp8_communication=fp8_communication, ) elif pp_style == "1f1b": self.schedule = OneForwardOneBackwardSchedule( @@ -1089,6 +1091,7 @@ class HybridParallelPlugin(PipelinePluginBase): num_microbatches=num_microbatches, microbatch_size=microbatch_size, enable_metadata_cache=enable_metadata_cache, + fp8_communication=fp8_communication, ) else: raise NotImplementedError() diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index a21b45c44..a7571c731 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -11,6 +11,7 @@ from colossalai.accelerator import get_accelerator from colossalai.interface import OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline from colossalai.utils import get_current_device from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device @@ -32,6 +33,7 @@ class InterleavedSchedule(PipelineSchedule): microbatch_size: Optional[int] = None, enable_metadata_cache: bool = True, overlap_p2p: bool = True, + fp8_communication: bool = False, ) -> None: super().__init__(stage_manager) assert ( @@ -56,6 +58,8 @@ class InterleavedSchedule(PipelineSchedule): self.tensor_metadata_recv = None self.grad_metadata_recv = None + self.fp8_communication = fp8_communication + def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. @@ -191,8 +195,12 @@ class InterleavedSchedule(PipelineSchedule): """ with self.stage_manager.switch_model_chunk_id(model_chunk_id): if not self.stage_manager.is_last_stage(): + if self.fp8_communication: + cast_to_fp8_pipeline(output_tensor) send_handles = self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata) self.send_tensor_metadata = not self.enable_metadata_cache + if self.fp8_communication: + cast_from_fp8_pipeline(output_tensor) return send_handles return [] @@ -210,10 +218,14 @@ class InterleavedSchedule(PipelineSchedule): """ with self.stage_manager.switch_model_chunk_id(model_chunk_id): if not self.stage_manager.is_first_stage(): + if self.fp8_communication: + cast_to_fp8_pipeline(input_tensor_grad) send_handles = self.comm.send_backward( input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata ) self.send_grad_metadata = not self.enable_metadata_cache + if self.fp8_communication: + cast_from_fp8_pipeline(input_tensor_grad) return send_handles return [] @@ -224,6 +236,8 @@ class InterleavedSchedule(PipelineSchedule): is_send = not self.stage_manager.is_last_stage() with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv): is_recv = not self.stage_manager.is_first_stage() + if self.fp8_communication: + cast_to_fp8_pipeline(output_tensor) input_tensor, wait_handles = self.comm.send_forward_recv_forward( output_tensor, is_send, @@ -237,6 +251,8 @@ class InterleavedSchedule(PipelineSchedule): if is_recv and self.enable_metadata_cache and self.tensor_metadata_recv is None: self.tensor_metadata_recv = create_send_metadata(input_tensor) + if self.fp8_communication: + cast_from_fp8_pipeline(output_tensor) return input_tensor, wait_handles def send_backward_recv_backward( @@ -246,6 +262,8 @@ class InterleavedSchedule(PipelineSchedule): is_send = not self.stage_manager.is_first_stage() with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv): is_recv = not self.stage_manager.is_last_stage() + if self.fp8_communication: + cast_to_fp8_pipeline(input_tensor_grad) output_tensor_grad, wait_handles = self.comm.send_backward_recv_backward( input_tensor_grad, is_send, @@ -258,6 +276,8 @@ class InterleavedSchedule(PipelineSchedule): self.send_grad_metadata = not self.enable_metadata_cache and is_send if is_recv and self.enable_metadata_cache and self.grad_metadata_recv is None: self.grad_metadata_recv = create_send_metadata(output_tensor_grad) + if self.fp8_communication: + cast_from_fp8_pipeline(input_tensor_grad) return output_tensor_grad, wait_handles def forward_step( @@ -379,6 +399,8 @@ class InterleavedSchedule(PipelineSchedule): # Wait until current input is received _wait_p2p(fwd_wait_handles) + if self.fp8_communication and input_obj is not None: + cast_from_fp8_pipeline(input_obj) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) if not last_batch: @@ -441,6 +463,8 @@ class InterleavedSchedule(PipelineSchedule): # Wait for input _wait_p2p(fwd_wait_handles) + if self.fp8_communication and input_obj is not None: + cast_from_fp8_pipeline(input_obj) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) input_objs[model_chunk_id].append(input_obj) output_objs[model_chunk_id].append(output_obj) @@ -467,6 +491,8 @@ class InterleavedSchedule(PipelineSchedule): # Wait for input. _wait_p2p(fwd_wait_handles) + if self.fp8_communication and input_obj is not None: + cast_from_fp8_pipeline(input_obj) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) # Add input_obj and output_obj to end of list. input_objs[model_chunk_id].append(input_obj) @@ -511,6 +537,8 @@ class InterleavedSchedule(PipelineSchedule): input_obj, fwd_wait_handles = send_forward_recv_forward() # Wait for upstream grad _wait_p2p(bwd_wait_handles) + if self.fp8_communication and output_obj_grad is not None: + cast_from_fp8_pipeline(output_obj_grad) input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad) # NOTE: It's documented by NCCL that running two concurrent communicators (batch_isend_irecv) # risks deadlock (https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2134/user-guide/docs/usage/communicators.html) @@ -532,6 +560,8 @@ class InterleavedSchedule(PipelineSchedule): # Wait for upstream grad _wait_p2p(bwd_wait_handles) + if self.fp8_communication and output_obj_grad is not None: + cast_from_fp8_pipeline(output_obj_grad) # backward local grads input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad) if not last_batch: diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 7f0d0e349..3269d67ba 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -10,6 +10,7 @@ from colossalai.accelerator import get_accelerator from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline from colossalai.utils import get_current_device from ._utils import ( @@ -32,6 +33,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): num_microbatches: Optional[int] = None, microbatch_size: Optional[int] = None, enable_metadata_cache: bool = True, + fp8_communication: bool = False, ) -> None: """1F1B pipeline schedule. @@ -61,6 +63,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): self.tensor_metadata_recv = None self.grad_metadata_recv = None + self.fp8_communication = fp8_communication + def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. @@ -129,6 +133,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): if self.enable_metadata_cache and self.tensor_metadata_recv is None: self.tensor_metadata_recv = create_send_metadata(input_tensor) + if self.fp8_communication: + cast_from_fp8_pipeline(input_tensor) return input_tensor def recv_backward(self, next_rank: int = None) -> Any: @@ -143,6 +149,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): """ if not self.stage_manager.is_last_stage(): output_tensor_grad, _ = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv) + if self.fp8_communication: + cast_from_fp8_pipeline(output_tensor_grad) if self.enable_metadata_cache and self.grad_metadata_recv is None: self.grad_metadata_recv = create_send_metadata(output_tensor_grad) @@ -157,9 +165,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): next_rank (int, optional): The rank of the recipient of the tensor. """ if not self.stage_manager.is_last_stage(): + if self.fp8_communication: + cast_to_fp8_pipeline(output_tensor) self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata) self.send_tensor_metadata = not self.enable_metadata_cache + if self.fp8_communication: + cast_from_fp8_pipeline(output_tensor, del_metadata=False) + def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None: """Sends the gradient tensor to the previous stage in pipeline. For 1F1B. @@ -169,8 +182,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): prev_rank (int, optional): The rank of the recipient of the tensor """ if not self.stage_manager.is_first_stage(): + if self.fp8_communication: + cast_to_fp8_pipeline(input_tensor_grad) self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata) self.send_grad_metadata = not self.enable_metadata_cache + if self.fp8_communication: + cast_from_fp8_pipeline(input_tensor_grad, del_metadata=False) def send_forward_recv_backward(self, output_tensor: Any, send_first: Optional[bool] = None) -> Any: """Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline. @@ -183,6 +200,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): if not self.stage_manager.is_last_stage(): if not self.send_tensor_metadata and self.grad_metadata_recv is not None: send_first = None + if self.fp8_communication: + cast_to_fp8_pipeline(output_tensor) output_tensor_grad, _ = self.comm.send_forward_recv_backward( output_tensor, send_metadata=self.send_tensor_metadata, @@ -192,6 +211,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): self.send_tensor_metadata = not self.enable_metadata_cache if self.enable_metadata_cache and self.grad_metadata_recv is None: self.grad_metadata_recv = create_send_metadata(output_tensor_grad) + if self.fp8_communication: + cast_from_fp8_pipeline(output_tensor, del_metadata=False) + cast_from_fp8_pipeline(output_tensor_grad) return output_tensor_grad @@ -206,6 +228,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): if not self.stage_manager.is_first_stage(): if not self.send_grad_metadata and self.tensor_metadata_recv is not None: send_first = None # must not fallback + if self.fp8_communication: + cast_to_fp8_pipeline(input_tensor_grad) input_tensor, _ = self.comm.send_backward_recv_forward( input_tensor_grad, send_metadata=self.send_grad_metadata, @@ -215,6 +239,9 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): self.send_grad_metadata = not self.enable_metadata_cache if self.enable_metadata_cache and self.tensor_metadata_recv is None: self.tensor_metadata_recv = create_send_metadata(input_tensor) + if self.fp8_communication: + cast_from_fp8_pipeline(input_tensor) + cast_from_fp8_pipeline(input_tensor_grad, del_metadata=False) return input_tensor diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py new file mode 100644 index 000000000..e514f435e --- /dev/null +++ b/colossalai/quantization/fp8.py @@ -0,0 +1,172 @@ +from typing import Any + +import torch +import torch.distributed as dist + + +def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Tensor): + r""" + casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling. + Args: + inp: input torch Tensor, should be in torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor. + scale: scaling factor for fp8 casting. If it is None, then it is computed automatically. Per-channel scaling + is applied if input tensor is 2 dimension, otherwise, per-tensor scaling is applied. + fp8_format: e4m3 or e5m2 + + Returns: + Tuples: A tuple (fp8_tensor, scale) + """ + + if inp.dtype not in [torch.float32, torch.float16, torch.bfloat16]: + raise TypeError("Only float16, bfloat16, and float32 are allowed.") + + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + fp8_max = torch.finfo(fp8_type).max + + if inp.dim() == 2: + per_channel_max = inp.abs().max(dim=-1).values.float() + per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0) + scale = fp8_max / per_channel_max[:, None] + else: + per_tensor_max = inp.abs().max().float() + per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0) + scale = fp8_max / per_tensor_max + + scale_inv = 1.0 / scale + ret = (scale * inp.float()).to(fp8_type) + return ret, scale_inv + + +def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor: + r""" + + Args: + inp: should be a fp8 torch tensor in one of the types: [torch.float8_e4m3fn, torch.float8_e5m2]. + scale: scaling factor returned by cast_to_fp8 function. + ret_type: the datatype of the returned tensor. + + Returns: + torch.Tensor + """ + if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: + raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.") + + if inp.dim() == 2: + ret = scale_inv[:, None] * inp.float() + else: + ret = scale_inv * inp.float() + return ret.to(ret_type) + + +def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None: + r""" + This is an in-place operation for compressed all_reduce using fp8. + It works like dist.all_reduce but during communication the data is cast to fp8 format. + + Args: + tensor: torch.Tensor in fp32, fp16, bf16 datatype. + fp8_format: e4m3 or e5m2 + + Returns: + None + """ + + world_size = dist.get_world_size() + input_type = tensor.dtype + input_shape = tensor.shape + input_device = tensor.device + input_size = tensor.numel() + tensor = tensor.flatten() + + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + + ret, scale = cast_to_fp8(tensor, fp8_format=fp8_format) + + inp = ret.view(torch.uint8) + input_chunks = list(torch.chunk(inp, world_size, dim=0)) + if dist.get_rank() == world_size - 1: + output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)] + else: + output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)] + dist.all_to_all(output_chunks, input_chunks) + scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)] + dist.all_gather(scale_list, scale) + summed_out = torch.zeros_like(output_chunks[0]).to(input_type) + for scale, out in zip(scale_list, output_chunks): + out = out.view(fp8_type) + summed_out += cast_from_fp8(out, scale, input_type) + + summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format) + dist.all_gather(scale_list, scale) + + tensor_list = list(torch.chunk(torch.empty(input_size, device=input_device, dtype=torch.uint8), world_size, dim=0)) + dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8)) + for i in range(world_size): + tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i] + tensor_out = torch.cat(tensor_list, dim=0) + tensor.data = tensor_out.view(input_shape).to(input_type) + + +def cast_to_fp8_pipeline(inp: Any) -> None: + """ + Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline. + The activations tensor is indexed by 'hidden_states' in the inp dict. + After FP8 casting, the resulting tensor is saved as float16 or bfloat16 format but the size becomes halved. + Metadata such as fp8_scale is saved into inp dict for communication. + """ + if inp is None: + return + # In pipeline parallelism, when inp is torch.Tensor, it only contains one element, thus can be omitted. + if type(inp) == torch.Tensor: + return + + assert "hidden_states" in inp, "required by pipeline parallelism." + inp_tensor = inp["hidden_states"] + + min_val, max_val = inp_tensor.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()) + + finfo = torch.finfo(torch.float8_e4m3fn) + if amax > finfo.max: + fp8_type = torch.float8_e5m2 + fp8_view_type = torch.float16 + else: + fp8_type = torch.float8_e4m3fn + fp8_view_type = torch.bfloat16 + + finfo = torch.finfo(fp8_type) + scale = torch.tensor(1.0).to(inp_tensor.device) if amax == 0.0 else finfo.max / amax.float() + q_tensor = inp_tensor.data.float() * scale + # Todo: Currently we use fp8_view_type to indicate which fp8 format is used. This is a temporary workaround due to 'Only support tensor for fast send'. + # inp_tensor needs to be a float datatype to avoid error during gradient placement. + inp_tensor.data = q_tensor.to(fp8_type).view(fp8_view_type) + + inp["fp8_scale"] = scale.float().reciprocal() + + +def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None: + """ + Cast the FP8 encoded hidden_states tensor back to original dtype after p2p communication in pipeline. + del_metadata = False is useful when this function is called before p2p communication. + """ + if inp is None: + return + if type(inp) == torch.Tensor: + return + + assert "hidden_states" in inp, "required by pipeline parallelism." + inp_tensor = inp["hidden_states"] + scale = inp["fp8_scale"] + + fp8_view_type = inp_tensor.dtype + if fp8_view_type == torch.float16: + fp8_type = torch.float8_e5m2 + elif fp8_view_type == torch.bfloat16: + fp8_type = torch.float8_e4m3fn + else: + raise TypeError("Only float16, bfloat16 are implemented.") + + inp_tensor.data = inp_tensor.data.view(fp8_type).to(torch.float16) * scale + + if del_metadata: + del inp["fp8_scale"] diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index 7e8c07fdc..8a59ab683 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -190,6 +190,7 @@ def main(): ) parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached") parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context") + parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication") args = parser.parse_args() if args.model_type == "bert": @@ -232,6 +233,7 @@ def main(): zero_stage=1, precision="fp16", initial_scale=1, + fp8_communication=args.use_fp8_comm, ) booster = Booster(plugin=plugin, **booster_kwargs) diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py index 777d16cb9..9b3a10160 100644 --- a/examples/language/gpt/hybridparallelism/finetune.py +++ b/examples/language/gpt/hybridparallelism/finetune.py @@ -187,6 +187,7 @@ def main(): ) parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached") parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context") + parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication") args = parser.parse_args() if args.model_type == "gpt2": @@ -225,6 +226,7 @@ def main(): zero_stage=1, precision="fp16", initial_scale=1, + fp8_communication=args.use_fp8_comm, ) booster = Booster(plugin=plugin, **booster_kwargs)