diff --git a/colossalai/communication/__init__.py b/colossalai/communication/__init__.py
index e7bb323e4..25e817f1f 100644
--- a/colossalai/communication/__init__.py
+++ b/colossalai/communication/__init__.py
@@ -13,5 +13,5 @@ __all__ = [
     'send_forward_backward_recv_forward_backward', 'send_backward',
     'send_backward_recv_backward', 'send_backward_recv_forward',
     'send_forward_recv_backward', 'recv_backward', 'recv_forward',
-    'ring_forward', 'send_tensor_meta', 'recv_tensor_meta'
-]
\ No newline at end of file
+    'ring_forward', 'send_tensor_meta', 'recv_tensor_meta',
+]
diff --git a/colossalai/communication/p2p.py b/colossalai/communication/p2p.py
index 99ccdf6eb..4aefe342d 100644
--- a/colossalai/communication/p2p.py
+++ b/colossalai/communication/p2p.py
@@ -1,12 +1,42 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
+from typing import List, Tuple, Union
 import torch
 import torch.distributed as dist
 
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device
+from functools import reduce
+import operator
+from .utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor
+
+
+TensorShape = Union[torch.Size, List[int], Tuple[int]]
+
+
+def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> Tuple[TensorShape, bool]:
+    """get the exact tensor shape when communicating and return whether the tensor is a chunk
+
+    :param tensor_shape: shape of tensor
+    :type tensor_shape: TensorShape
+    :param chunk_tensor: whether to chunk tensor, defaults to False
+    :type chunk_tensor: bool, optional
+    :return: exact tensor shape, whether to chunk tensor
+    :rtype: Tuple[Union[torch.Size, List[int], Tuple[int]], bool]
+    """
+    if chunk_tensor:
+        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
+        tensor_parallel_world_size = gpc.get_world_size(ParallelMode.TENSOR)
+        if tensor_chunk_shape % tensor_parallel_world_size == 0:
+            tensor_chunk_shape = tensor_chunk_shape // tensor_parallel_world_size
+        else:
+            tensor_chunk_shape = tensor_shape
+            chunk_tensor = False
+    else:
+        tensor_chunk_shape = tensor_shape
+    return tensor_chunk_shape, chunk_tensor
 
 
 def _communicate(tensor_send_next=None,
@@ -17,7 +47,8 @@ def _communicate(tensor_send_next=None,
                  recv_next_shape=None,
                  prev_rank=None,
                  next_rank=None,
-                 dtype=None):
+                 dtype=None,
+                 scatter_gather_tensors=False):
     """
     Adapted from megatron.p2p_communication.
     Communicate tensors between stages. Used as helper method in other
@@ -42,13 +73,15 @@ def _communicate(tensor_send_next=None,
 
     if recv_prev:
         assert recv_prev_shape is not None
-        tensor_recv_prev = torch.empty(recv_prev_shape,
+        recv_prev_chunk_shape, recv_prev_split = _get_tensor_shape(recv_prev_shape, scatter_gather_tensors)
+        tensor_recv_prev = torch.empty(recv_prev_chunk_shape,
                                        requires_grad=True,
                                        device=get_current_device(),
                                        dtype=dtype)
     if recv_next:
         assert recv_next_shape is not None
-        tensor_recv_next = torch.empty(recv_next_shape,
+        recv_next_chunk_shape, recv_next_split = _get_tensor_shape(recv_next_shape, scatter_gather_tensors)
+        tensor_recv_next = torch.empty(recv_next_chunk_shape,
                                        requires_grad=True,
                                        device=get_current_device(),
                                        dtype=dtype)
@@ -63,6 +96,16 @@ def _communicate(tensor_send_next=None,
         next_rank = gpc.get_next_global_rank(
             ParallelMode.PIPELINE)
 
+    if tensor_send_prev is not None:
+        send_prev_split = _get_tensor_shape(tensor_send_prev.shape, scatter_gather_tensors)[1]
+        if send_prev_split:
+            tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)
+
+    if tensor_send_next is not None:
+        send_next_split = _get_tensor_shape(tensor_send_next.shape, scatter_gather_tensors)[1]
+        if send_next_split:
+            tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)
+
     ops = []
     if tensor_send_prev is not None:
         send_prev_op = dist.P2POp(dist.isend, tensor_send_prev, prev_rank)
@@ -82,10 +125,15 @@ def _communicate(tensor_send_next=None,
             req.wait()
         # To protect against race condition when using batch_isend_irecv().
         torch.cuda.synchronize()
+
+    if recv_prev and recv_prev_split:
+        tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
+    if recv_next and recv_next_split:
+        tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
     return tensor_recv_prev, tensor_recv_next
 
 
-def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float):
+def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False):
     """Receives the input tensor from the previous member in pipeline.
 
     :param input_tensor_shape: The shape of the tensor to be recieved
@@ -101,11 +149,12 @@ def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float):
         input_tensor, _ = _communicate(recv_prev=True,
                                        recv_prev_shape=input_tensor_shape,
                                        prev_rank=prev_rank,
-                                       dtype=dtype)
+                                       dtype=dtype,
+                                       scatter_gather_tensors=scatter_gather_tensors)
     return input_tensor
 
 
-def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float):
+def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False):
     """Receives the grad tensor from the next member in pipeline.
 
     :param output_grad_shape: The shape of the tensor to be recieved
@@ -121,11 +170,12 @@ def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float):
         _, output_tensor_grad = _communicate(recv_next=True,
                                              recv_next_shape=output_grad_shape,
                                              next_rank=next_rank,
-                                             dtype=dtype)
+                                             dtype=dtype,
+                                             scatter_gather_tensors=scatter_gather_tensors)
     return output_tensor_grad
 
 
-def send_forward(output_tensor, next_rank=None):
+def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False):
     """Sends the input tensor to the next member in pipeline.
 
     :param output_tensor: Tensor to be sent
@@ -135,10 +185,11 @@ def send_forward(output_tensor, next_rank=None):
     """
     if not gpc.is_pipeline_last_stage():
         _communicate(tensor_send_next=output_tensor,
-                     next_rank=next_rank)
+                     next_rank=next_rank,
+                     scatter_gather_tensors=scatter_gather_tensors)
 
 
-def send_backward(input_tensor_grad, prev_rank=None):
+def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False):
     """Sends the grad tensor to the previous member in pipeline.
 
     :param input_tensor_grad: Tensor to be sent
@@ -148,14 +199,16 @@ def send_backward(input_tensor_grad, prev_rank=None):
     """
    if not gpc.is_pipeline_first_stage():
         _communicate(tensor_send_prev=input_tensor_grad,
-                     prev_rank=prev_rank)
+                     prev_rank=prev_rank,
+                     scatter_gather_tensors=scatter_gather_tensors)
 
 
 def send_forward_recv_backward(output_tensor,
                                output_grad_shape,
                                recv_next=True,
                                next_rank=None,
-                               dtype=torch.float):
+                               dtype=torch.float,
+                               scatter_gather_tensors=False):
     """Batched communication operation. Sends the input tensor to the
     next member in pipeline, while recieves the grad tensor from the
     next member in pipeline.
@@ -174,7 +227,8 @@ def send_forward_recv_backward(output_tensor,
                                           recv_next=recv_next,
                                           recv_next_shape=output_grad_shape,
                                           next_rank=next_rank,
-                                          dtype=dtype)
+                                          dtype=dtype,
+                                          scatter_gather_tensors=scatter_gather_tensors)
     return output_tensor_grad
 
 
@@ -182,7 +236,8 @@ def send_backward_recv_forward(input_tensor_grad,
                                input_tensor_shape,
                                recv_prev=True,
                                prev_rank=None,
-                               dtype=torch.float):
+                               dtype=torch.float,
+                               scatter_gather_tensors=False):
     """Batched communication operation. Sends the grad tensor to the
     previous member in pipeline, while recieves the input tensor from
     the previous member in pipeline.
@@ -201,7 +256,8 @@ def send_backward_recv_forward(input_tensor_grad,
                                          recv_prev=recv_prev,
                                          recv_prev_shape=input_tensor_shape,
                                          prev_rank=prev_rank,
-                                         dtype=dtype)
+                                         dtype=dtype,
+                                         scatter_gather_tensors=scatter_gather_tensors)
     return input_tensor
 
 
@@ -210,7 +266,8 @@ def send_forward_recv_forward(output_tensor,
                               recv_prev=True,
                               prev_rank=None,
                               next_rank=None,
-                              dtype=torch.float):
+                              dtype=torch.float,
+                              scatter_gather_tensors=False):
     """Batched communication operation. Sends the input tensor to the
     next member in pipeline, while recieves the input tensor from the
     previous member in pipeline.
@@ -227,7 +284,8 @@ def send_forward_recv_forward(output_tensor,
                                         recv_prev_shape=input_tensor_shape,
                                         prev_rank=prev_rank,
                                         next_rank=next_rank,
-                                        dtype=dtype)
+                                        dtype=dtype,
+                                        scatter_gather_tensors=scatter_gather_tensors)
     return input_tensor
 
 
@@ -236,7 +294,8 @@ def send_backward_recv_backward(input_tensor_grad,
                                 recv_next=True,
                                 prev_rank=None,
                                 next_rank=None,
-                                dtype=torch.float):
+                                dtype=torch.float,
+                                scatter_gather_tensors=False):
     """Batched communication operation. Sends the grad tensor to the
     previous member in pipeline, while recieves the grad tensor from the
     next member in pipeline.
@@ -253,7 +312,8 @@ def send_backward_recv_backward(input_tensor_grad,
                                           recv_next_shape=output_grad_shape,
                                           prev_rank=prev_rank,
                                           next_rank=next_rank,
-                                          dtype=dtype)
+                                          dtype=dtype,
+                                          scatter_gather_tensors=scatter_gather_tensors)
     return output_tensor_grad
 
 
@@ -265,7 +325,8 @@ def send_forward_backward_recv_forward_backward(output_tensor,
                                                 recv_next=True,
                                                 prev_rank=None,
                                                 next_rank=None,
-                                                dtype=torch.float):
+                                                dtype=torch.float,
+                                                scatter_gather_tensors=False):
     """Batched communication operation. Sends the input tensor to the next and
     the grad tensor to the previous, while recieves the grad tensor from the next
     and the input tensor from the previous.
@@ -290,5 +351,6 @@ def send_forward_backward_recv_forward_backward(output_tensor,
                                                           recv_next_shape=output_grad_shape,
                                                           prev_rank=prev_rank,
                                                           next_rank=next_rank,
-                                                          dtype=dtype)
+                                                          dtype=dtype,
+                                                          scatter_gather_tensors=scatter_gather_tensors)
     return input_tensor, output_tensor_grad
diff --git a/colossalai/communication/utils.py b/colossalai/communication/utils.py
index 1eeba7bda..908161587 100644
--- a/colossalai/communication/utils.py
+++ b/colossalai/communication/utils.py
@@ -62,3 +62,31 @@ def recv_tensor_meta(tensor_shape, prev_rank=None):
         tensor_shape = torch.Size(recv_shape)
 
     return tensor_shape
+
+
+def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
+    """Break a tensor into equal 1D chunks."""
+    partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.PARALLEL_1D)
+    start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+    end_index = start_index + partition_size
+    if new_buffer:
+        data = torch.empty(partition_size, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+        data.copy_(tensor.view(-1)[start_index:end_index])
+    else:
+        data = tensor.view(-1)[start_index:end_index]
+    return data
+
+
+def gather_split_1d_tensor(tensor):
+    """Opposite of above function, gather values from model parallel ranks."""
+    world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
+    numel = torch.numel(tensor)
+    numel_gathered = world_size * numel
+    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
+                           device=torch.cuda.current_device(),
+                           requires_grad=False)
+    chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)]
+    dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D))
+    return gathered
diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py
index 13a815592..71e39848f 100644
--- a/colossalai/engine/schedule/_pipeline_schedule.py
+++ b/colossalai/engine/schedule/_pipeline_schedule.py
@@ -6,7 +6,7 @@ import inspect
 import torch.cuda
 from torch import Tensor
 
-from colossalai.communication import *
+import colossalai.communication as comm
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.amp.naive_amp import NaiveAMPModel
@@ -33,16 +33,22 @@ class PipelineSchedule(BaseSchedule):
     :type num_microbatches: int
     :param batch_data_process_func: The preprocessing function which receives a batch of data, and it will be executed in `load_batch`
     :type batch_data_process_func: Callable
+    :param scatter_gather_tensors: If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization
+    :type scatter_gather_tensors: bool
     """
 
     def __init__(self,
                  num_microbatches,
                  batch_data_process_func: Callable = None,
-                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None):
+                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
+                 scatter_gather_tensors: bool = False):
         super().__init__(batch_data_process_func=batch_data_process_func)
         self.num_microbatches = num_microbatches
         self.dtype = torch.float
         self.tensor_shape = tensor_shape
+        self.scatter_gather_tensors = False
+        if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1:
+            self.scatter_gather_tensors = scatter_gather_tensors
 
     def load_batch(self, data_iter):
         # Pipeline schedule just puts data in memory
@@ -227,8 +233,9 @@ class PipelineSchedule(BaseSchedule):
         # Run warmup forward passes.
         for i in range(num_warmup_microbatches):
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
-                ft_shape = recv_tensor_meta(ft_shape)
-                input_tensor = recv_forward(ft_shape, dtype=self.dtype)
+                ft_shape = comm.recv_tensor_meta(ft_shape)
+                input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype,
+                                                 scatter_gather_tensors=self.scatter_gather_tensors)
             output_tensor = self.forward_step(
                 engine, input_tensor, return_tensors,
                 return_output_label=return_output_label,
@@ -236,8 +243,8 @@ class PipelineSchedule(BaseSchedule):
             )
             if not gpc.is_last_rank(ParallelMode.PIPELINE):
                 bt_shape = output_tensor.shape
-                fs_checker = send_tensor_meta(output_tensor, fs_checker)
-                send_forward(output_tensor)
+                fs_checker = comm.send_tensor_meta(output_tensor, fs_checker)
+                comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)
 
             if not forward_only:
                 input_tensors.append(input_tensor)
@@ -248,8 +255,9 @@ class PipelineSchedule(BaseSchedule):
         # receive this tensor here.
         if num_microbatches_remaining > 0:
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
-                ft_shape = recv_tensor_meta(ft_shape)
-                input_tensor = recv_forward(ft_shape, dtype=self.dtype)
+                ft_shape = comm.recv_tensor_meta(ft_shape)
+                input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype,
+                                                 scatter_gather_tensors=self.scatter_gather_tensors)
 
         # Run 1F1B in steady state.
         for i in range(num_microbatches_remaining):
@@ -261,14 +269,15 @@ class PipelineSchedule(BaseSchedule):
                 accum_loss=accum_loss
             )
             if forward_only:
-                send_forward(output_tensor)
+                comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)
 
                 if not last_iteration:
-                    input_tensor = recv_forward(ft_shape, dtype=self.dtype)
+                    input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype,
+                                                     scatter_gather_tensors=self.scatter_gather_tensors)
 
             else:
-                output_tensor_grad = send_forward_recv_backward(
-                    output_tensor, bt_shape, dtype=self.dtype)
+                output_tensor_grad = comm.send_forward_recv_backward(
+                    output_tensor, bt_shape, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors)
 
                 # Add input_tensor and output_tensor to end of list.
                 input_tensors.append(input_tensor)
@@ -287,10 +296,10 @@ class PipelineSchedule(BaseSchedule):
 
                 if last_iteration:
                     input_tensor = None
-                    send_backward(input_tensor_grad)
+                    comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
                 else:
-                    input_tensor = send_backward_recv_forward(
-                        input_tensor_grad, ft_shape, dtype=self.dtype)
+                    input_tensor = comm.send_backward_recv_forward(
+                        input_tensor_grad, ft_shape, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors)
 
         # Run cooldown backward passes.
         if not forward_only:
@@ -298,7 +307,8 @@ class PipelineSchedule(BaseSchedule):
                 input_tensor = input_tensors.pop(0)
                 output_tensor = output_tensors.pop(0)
 
-                output_tensor_grad = recv_backward(bt_shape, dtype=self.dtype)
+                output_tensor_grad = comm.recv_backward(bt_shape, dtype=self.dtype,
+                                                        scatter_gather_tensors=self.scatter_gather_tensors)
 
                 input_tensor_grad = self.backward_step(
                     engine,
@@ -306,7 +316,7 @@ class PipelineSchedule(BaseSchedule):
                     output_tensor_grad
                 )
 
-                send_backward(input_tensor_grad)
+                comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
 
         if len(return_tensors) > 0:
             output, label = tuple(map(list, zip(*return_tensors)))
@@ -322,7 +332,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                  num_microbatches,
                  num_model_chunks,
                  batch_data_process_func: Callable = None,
-                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None):
+                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
+                 scatter_gather_tensors: bool = False):
         """A helper schedule class for pipeline parallelism running environment.
         It uses interleaved 1F1B strategy. Other properties are similar as
         :class:`NonPipelineSchedule`.
@@ -333,10 +344,13 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         :type num_model_chunks: int
         :param batch_data_process_func: The preprocessing function which receives a batch of data, and it will be executed in `load_batch`
         :type batch_data_process_func: Callable
+        :param scatter_gather_tensors: If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization
+        :type scatter_gather_tensors: bool
         """
         assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
             'num_microbatches must be an integer multiple of pipeline parallel world size'
-        super().__init__(num_microbatches, batch_data_process_func=batch_data_process_func, tensor_shape=tensor_shape)
+        super().__init__(num_microbatches, batch_data_process_func=batch_data_process_func,
+                         tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather_tensors)
         gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
         gpc.set_virtual_pipeline_parallel_rank(0)
         self.num_model_chunks = num_model_chunks
@@ -494,15 +508,16 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         # Run warmup forward passes.
         gpc.set_virtual_pipeline_parallel_rank(0)
         if not gpc.is_pipeline_first_stage():
-            input_tensor_shapes[0] = recv_tensor_meta(input_tensor_shapes[0])
-            input_tensors[0].append(recv_forward(input_tensor_shapes[0], dtype=self.dtype))
+            input_tensor_shapes[0] = comm.recv_tensor_meta(input_tensor_shapes[0])
+            input_tensors[0].append(comm.recv_forward(input_tensor_shapes[0], dtype=self.dtype,
+                                                      scatter_gather_tensors=self.scatter_gather_tensors))
         for k in range(num_warmup_microbatches):
             model_chunk_id = get_model_chunk_id(k, forward=True)
             output_tensor = forward_step_helper(k)
             if not gpc.is_pipeline_last_stage():
                 output_tensor_shapes[model_chunk_id] = output_tensor.shape
-                send_tensor_shape_flags[model_chunk_id] = send_tensor_meta(
+                send_tensor_shape_flags[model_chunk_id] = comm.send_tensor_meta(
                     output_tensor, send_tensor_shape_flags[model_chunk_id])
 
             # Determine if tensor should be received from previous stage.
             next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
@@ -519,7 +534,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
 
             with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id):
                 if not gpc.is_pipeline_first_stage():
-                    input_tensor_shapes[next_forward_model_chunk_id] = recv_tensor_meta(
+                    input_tensor_shapes[next_forward_model_chunk_id] = comm.recv_tensor_meta(
                         input_tensor_shapes[next_forward_model_chunk_id])
             # Send and receive tensors as appropriate (send tensors computed
             # in this iteration; receive tensors for next iteration).
@@ -532,20 +547,22 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                     recv_next = False
                 output_shape = output_tensor_shapes[num_model_chunks-1] if recv_next else None
                 input_tensor, output_tensor_grad = \
-                    send_forward_backward_recv_forward_backward(
+                    comm.send_forward_backward_recv_forward_backward(
                         output_tensor, input_tensor_grad,
                         input_shape, output_shape,
                         recv_prev=recv_prev, recv_next=recv_next,
-                        dtype=self.dtype)
+                        dtype=self.dtype,
+                        scatter_gather_tensors=self.scatter_gather_tensors)
                 output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
             else:
                 input_tensor = \
-                    send_forward_recv_forward(
+                    comm.send_forward_recv_forward(
                         output_tensor,
                         input_shape,
                         recv_prev=recv_prev,
-                        dtype=self.dtype)
+                        dtype=self.dtype,
+                        scatter_gather_tensors=self.scatter_gather_tensors)
             input_tensors[next_forward_model_chunk_id].append(input_tensor)
 
         # Run 1F1B in steady state.
@@ -608,12 +625,13 @@ class InterleavedPipelineSchedule(PipelineSchedule):
             output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
             # Communicate tensors.
             input_tensor, output_tensor_grad = \
-                send_forward_backward_recv_forward_backward(
+                comm.send_forward_backward_recv_forward_backward(
                     output_tensor, input_tensor_grad,
                     input_shape, output_shape,
                     recv_prev=recv_prev, recv_next=recv_next,
-                    dtype=self.dtype)
+                    dtype=self.dtype,
+                    scatter_gather_tensors=self.scatter_gather_tensors)
 
             # Put input_tensor and output_tensor_grad in data structures in the
             # right location.
@@ -627,7 +645,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         if not forward_only:
             if all_warmup_microbatches:
                 output_tensor_grads[num_model_chunks-1].append(
-                    recv_backward(output_tensor_shapes[num_model_chunks-1]))
+                    comm.recv_backward(output_tensor_shapes[num_model_chunks-1], scatter_gather_tensors=self.scatter_gather_tensors))
             for k in range(num_microbatches_remaining, num_microbatches):
                 input_tensor_grad = backward_step_helper(k)
                 next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
@@ -639,11 +657,12 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                     recv_next = False
                 output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
                 output_tensor_grads[next_backward_model_chunk_id].append(
-                    send_backward_recv_backward(
+                    comm.send_backward_recv_backward(
                         input_tensor_grad,
                         output_shape,
                         recv_next=recv_next,
-                        dtype=self.dtype))
+                        dtype=self.dtype,
+                        scatter_gather_tensors=self.scatter_gather_tensors))
 
         if len(return_tensors) > 0:
             output, label = tuple(map(list, zip(*return_tensors)))
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index 6a767338b..39837e464 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -290,9 +290,10 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     # initialize amp
     amp_mode = None
     if fp16_cfg is not None and fp16_cfg.mode is not None:
-        # TODO: pipeline only support NAIVE AMP
         cfg_ = fp16_cfg.copy()
         amp_mode = cfg_.pop('mode')
+        if is_using_pp():
+            assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently'
         if amp_mode == AMP_TYPE.NAIVE:
             cfg_['clip_grad'] = clip_grad_norm
         model, optimizer, criterion = convert_to_amp(model=model,
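
The core trick in this patch is a scatter/gather round trip: before an isend, split_tensor_into_1d_equal_chunks keeps only this rank's 1/world_size slice of the flattened tensor, and after the matching irecv, gather_split_1d_tensor all-gathers the slices over the PARALLEL_1D group and view() restores the original shape, so each pipeline point-to-point message shrinks by the tensor-parallel world size. A minimal single-process sketch of that invariant (plain slicing stands in for the rank-local chunk and for dist.all_gather; no distributed setup is assumed):

import torch

world_size = 4  # stand-in for gpc.get_world_size(ParallelMode.PARALLEL_1D)
full = torch.arange(2 * 4 * 8, dtype=torch.float32).reshape(2, 4, 8)

# Sender side: rank r would transmit only its contiguous slice of the flat tensor.
partition_size = full.numel() // world_size
chunks = [full.view(-1)[r * partition_size:(r + 1) * partition_size]
          for r in range(world_size)]

# Receiver side: concatenating the gathered slices restores the flat buffer,
# and view() recovers the original shape, as gather_split_1d_tensor does.
restored = torch.cat(chunks).view(full.shape)
assert torch.equal(restored, full)
print("per-rank payload:", partition_size, "elements instead of", full.numel())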
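Note the divisibility guard in _get_tensor_shape: a tensor is only communicated as a flat chunk when its element count divides evenly by the tensor-parallel world size; otherwise the full shape is kept and chunking is silently disabled. A hedged re-implementation of just that decision (the hard-coded world size of 4 and the function name are stand-ins for gpc.get_world_size(ParallelMode.TENSOR) and the private helper above):

from functools import reduce
import operator

def get_tensor_shape_sketch(tensor_shape, chunk_tensor, tp_world_size=4):
    # Mirrors _get_tensor_shape: returns (shape to allocate, whether chunked).
    if chunk_tensor:
        numel = reduce(operator.mul, tensor_shape, 1)
        if numel % tp_world_size == 0:
            return numel // tp_world_size, True  # flat chunk of this many elements
        return tensor_shape, False               # not divisible: fall back to full tensor
    return tensor_shape, False

print(get_tensor_shape_sketch((2, 4, 8), True))  # (16, True): 64 elements over 4 ranks
print(get_tensor_shape_sketch((2, 3, 5), True))  # ((2, 3, 5), False): 30 % 4 != 0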
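On the consumer side the flag is surfaced through the schedule constructors, and the __init__ guard means it only takes effect when ParallelMode.PARALLEL_1D is initialized with a world size greater than 1. A sketch of how a training script might opt in (engine and model setup elided; the import path assumes PipelineSchedule is re-exported from colossalai.engine.schedule, as the file layout suggests):

from colossalai.engine.schedule import PipelineSchedule

# tensor_shape is the per-microbatch activation shape between stages; both
# keyword arguments below are the ones added by this patch.
schedule = PipelineSchedule(num_microbatches=4,
                            tensor_shape=(16, 1024, 1024),
                            scatter_gather_tensors=True)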