diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 779ff42d7..b7900bc0f 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -1,4 +1,5 @@ import copy +from functools import reduce import logging import os from pathlib import Path @@ -313,9 +314,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): # Keep a record of loaded files so that file will not be repeatedly loaded. loaded_file = set() + missing_keys = [] + missing_file_keys = [] + def _load(name: str): if name not in weight_map: - raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!") + missing_file_keys.append(name) + return filename = weight_map[name] # If this param/buffer has been loaded before, directly return. @@ -324,7 +329,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): file_path = os.path.join(ckpt_root_path, filename) state_dict = load_shard_state_dict(Path(file_path), use_safetensors) - missing_keys = [] load_state_dict_into_model( model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True @@ -357,6 +361,27 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): if self.verbose and self.coordinator.is_master(): logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") + if len(missing_keys) == 0: + raise RuntimeError( + "No weigth is loaded into the model. Please check the checkpoint files and the model structure." + ) + + remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys)) + remain_keys = remain_keys.union(set(missing_file_keys)) + if len(remain_keys) > 0: + if strict: + error_msgs = "Missing key(s) in state_dict: {}. ".format( + ", ".join('"{}"'.format(k) for k in missing_keys) + ) + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format( + self.__class__.__name__, "\n\t".join(error_msgs) + ) + ) + else: + if self.coordinator.is_master(): + logging.info(f"The following keys are not loaded from checkpoint: {remain_keys}") + def save_sharded_optimizer( self, optimizer: OptimizerWrapper, diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index f822c1819..6e49fa36b 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -5,9 +5,12 @@ import io import pickle import re from typing import Any, List, Optional, Union +from collections import namedtuple import torch import torch.distributed as dist +from dataclasses import dataclass +from enum import Enum from packaging.version import Version from torch.distributed import ProcessGroup from torch.distributed import distributed_c10d as c10d @@ -45,6 +48,21 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) - return unpickle +def check_for_nccl_backend(group): + pg = group or c10d._get_default_group() + # Gate PG wrapper check on Gloo availability. + if c10d._GLOO_AVAILABLE: + # It is not expected for PG to be wrapped many times, but support it just + # in case + while isinstance(pg, c10d._ProcessGroupWrapper): + pg = pg.wrapped_pg + + return ( + c10d.is_nccl_available() and + pg.name() == c10d.Backend.NCCL + ) + + def _broadcast_object_list( object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None ): @@ -65,7 +83,7 @@ def _broadcast_object_list( c10d._warn_not_in_group("broadcast_object_list") return - is_nccl_backend = c10d._check_for_nccl_backend(group) + is_nccl_backend = check_for_nccl_backend(group) current_device = None if device is not None: @@ -113,7 +131,7 @@ def _broadcast_object_list( if my_rank != src: for i, obj_size in enumerate(object_sizes_tensor): - obj_view = object_tensor[offset : offset + obj_size] + obj_view = object_tensor[offset: offset + obj_size] obj_view = obj_view.type(torch.uint8) if obj_view.device != torch.device("cpu"): obj_view = obj_view.cpu() @@ -131,6 +149,258 @@ def _broadcast_object_list( object_list[i] = unpickle_object +def check_device(group): + is_nccl_backend = check_for_nccl_backend(group) + current_device = None + + current_device = torch.device("cpu") + if is_nccl_backend: + current_device = torch.device("cuda", torch.cuda.current_device()) + return current_device, is_nccl_backend + + +TensorMetadata = namedtuple('TensorMetadata', ['key', 'shape', 'dtype', 'requires_grad']) + + +class P2PDataType(Enum): + serialization = 0 + tensor = 1 + list = 2 + dict = 3 + + +@dataclass +class P2PMetadata: + data_type: P2PDataType + content: Union[List[TensorMetadata], TensorMetadata, Any] + + +def filling_ops_queue(obj, comm_op, comm_rank, ops_queue, group): + if isinstance(obj, torch.Tensor): + obj = obj.contiguous() + op_to_add = dist.P2POp(comm_op, obj, comm_rank, group) + ops_queue.append(op_to_add) + else: + for tensor_to_comm in obj: + tensor_to_comm = tensor_to_comm.contiguous() + op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank, group) + ops_queue.append(op_to_add) + + +def create_recv_buffer(p2p_metadata: P2PMetadata, current_device): + if p2p_metadata.data_type == P2PDataType.tensor: + metadata = p2p_metadata.content + tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype) + return tensor_recv + elif p2p_metadata.data_type in (P2PDataType.list, P2PDataType.dict): + buffer_recv = [] + for metadata in p2p_metadata.content: + tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype) + buffer_recv.append(tensor_recv) + return buffer_recv + else: + raise ValueError(f"Unknown data_type: {p2p_metadata.data_type}") + + +def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device): + buffer_recv = None + if recv_tensor_metadata is not None: + buffer_recv = create_recv_buffer(recv_tensor_metadata, current_device) + + ops = [] + + if send_dst is not None: + filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group) + + if recv_src is not None: + assert buffer_recv is not None + filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group) + + if len(ops) > 0: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + torch.cuda.synchronize() + + # Remove synchronization according to Pytorch's documentation + # However, the Megatron-LM does synchronization here + # https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112 + # In case there is potential error, uncomment the following `torch.cuda.synchronize()` + # torch.cuda.synchronize() + + return buffer_recv + + +def _send_recv_serialization_object( + object: Any, + send_dst: Optional[int], recv_src: Optional[int], + send_group: Optional[ProcessGroup], recv_group: Optional[ProcessGroup], + current_device, + is_nccl_backend): + ops = [] + send_object_tensor = None + if object is not None and send_dst is not None: + if Version(torch.__version__) >= Version("1.13.0"): + send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device) + else: + send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object) + + if is_nccl_backend: + send_object_size_tensor = send_object_size_tensor.to(current_device) + send_object_tensor = send_object_tensor.to(current_device) + + filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group) + + recv_object_size_tensor = None + if recv_src is not None: + recv_object_size_tensor = torch.empty(1, dtype=torch.long) + if is_nccl_backend: + recv_object_size_tensor = recv_object_size_tensor.to(current_device) + filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group) + + if len(ops) > 0: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + torch.cuda.synchronize() + + # See the comment in `_batch_send_recv_tensor` + # torch.cuda.synchronize() + + ops = [] + + if send_dst is not None and send_object_tensor is not None: + filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group) + + recv_object_tensor = None + if recv_src is not None and recv_object_size_tensor is not None: + recv_object_tensor = torch.empty(recv_object_size_tensor.item(), dtype=torch.uint8) + if is_nccl_backend: + recv_object_tensor = recv_object_tensor.to(current_device) + filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group) + + if len(ops) > 0: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + torch.cuda.synchronize() + + # See the comment in `_batch_send_recv_tensor` + # torch.cuda.synchronize() + + if recv_object_tensor is not None and recv_object_size_tensor is not None: + recv_object_tensor = recv_object_tensor.type(torch.uint8) + if recv_object_tensor.device != torch.device("cpu"): + recv_object_tensor = recv_object_tensor.cpu() + + unpickle_object = _cuda_safe_tensor_to_object( + recv_object_tensor, recv_object_size_tensor.item()) + + if ( + isinstance(unpickle_object, torch.Tensor) + and unpickle_object.device.index != torch.cuda.current_device() + ): + unpickle_object = unpickle_object.cuda() + + return unpickle_object + + +def _check_if_fast_send_available(object): + if type(object) is torch.Tensor: + return True + elif type(object) is list: + is_list_of_tensor = all([type(v) is torch.Tensor for v in object]) + return is_list_of_tensor + elif type(object) is dict: + is_dict_of_tensor = all([type(k) is str and type( + v) is torch.Tensor for k, v in object.items()]) + + return is_dict_of_tensor + return False + + +def _communicate( + object, + send_dst: Optional[int], + recv_src: Optional[int], + send_group: Optional[ProcessGroup] = None, + recv_group: Optional[ProcessGroup] = None, +) -> Any: + if c10d._rank_not_in_group(send_group) or c10d._rank_not_in_group(recv_group): + c10d._warn_not_in_group("_communicate") + return + + current_send_device, is_send_nccl_backend = check_device(send_group) + current_recv_device, is_recv_nccl_backend = check_device(recv_group) + + is_nccl_backend = is_send_nccl_backend and is_recv_nccl_backend + + assert current_send_device == current_recv_device + current_device = current_send_device + + assert (send_dst is not None) or (recv_src is not None) + + can_fast_send = False + send_metadata = None + if send_dst is not None: + can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend + if not can_fast_send: + send_metadata = P2PMetadata(P2PDataType.serialization, object) + else: + if type(object) is torch.Tensor: + data_type = P2PDataType.tensor + content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad) + elif type(object) is list: + data_type = P2PDataType.list + content = [] + for v in object: + content.append(TensorMetadata(None, v.shape, v.dtype, v.requires_grad)) + elif type(object) is dict: + data_type = P2PDataType.dict + content = [] + for k, v in object.items(): + content.append(TensorMetadata(k, v.shape, v.dtype, v.requires_grad)) + else: + raise ValueError('Cannot send object of type {}'.format(type(object))) + send_metadata = P2PMetadata(data_type, content) + + recv_metadata = _send_recv_serialization_object(send_metadata, send_dst, recv_src, send_group, recv_group, current_device, is_nccl_backend) + if recv_metadata is not None: + assert type(recv_metadata) is P2PMetadata + if recv_metadata.data_type == P2PDataType.serialization: + return recv_metadata.content + if not can_fast_send and send_dst is not None: + return + + send_tensor_list = None + if type(object) is torch.Tensor: + send_tensor_list = object + elif type(object) is list: + send_tensor_list = object + elif type(object) is dict: + send_tensor_list = list(object.values()) + + recv_buffer = _batch_send_recv_tensor(send_tensor_list, recv_metadata, send_dst, recv_src, send_group, recv_group, current_device) + + if recv_metadata is not None: + assert recv_buffer is not None + if recv_metadata.data_type in [P2PDataType.tensor, P2PDataType.list]: + return recv_buffer + elif recv_metadata.data_type == P2PDataType.dict: + return { + k: v + for k, v in zip( + [m.key for m in recv_metadata.content], + recv_buffer, + ) + } + else: + raise ValueError('Unknown data type {}'.format(recv_metadata.data_type)) + + def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None: """send anything to dst rank @@ -141,8 +411,7 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None: Returns: None """ - # then broadcast safely - _broadcast_object_list([object], src, group) + _communicate(object, send_dst=dst, recv_src=None, send_group=group) def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any: @@ -154,10 +423,7 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any: Returns: Any: Object received from src. """ - object_list = [None] - _broadcast_object_list(object_list, src, group) - - return object_list[0] + return _communicate(None, send_dst=None, recv_src=src, recv_group=group) def _p2p_comm( @@ -302,6 +568,64 @@ class PipelineP2PCommunication: cur_rank = self.stage_manager.get_rank() _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank)) + def send_forward_recv_backward(self, input_object: Any, next_rank: int = None) -> Any: + """Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline + + Args: + input_object (Any): Object to be sent. + next_rank (int, optional): The rank of the sender and recipient of the tensor + """ + if next_rank is None: + next_rank = self.stage_manager.get_next_rank() + + cur_rank = self.stage_manager.get_rank() + group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank) + return _communicate( + input_object, next_rank, next_rank, + send_group=group, recv_group=group, + ) + + def send_backward_recv_forward(self, input_object: Any, prev_rank: int = None) -> Any: + """Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline + + Args: + input_object (Any): Object to be sent. + prev_rank (int, optional): The rank of the sender and recipient of the tensor + """ + if prev_rank is None: + prev_rank = self.stage_manager.get_prev_rank() + + cur_rank = self.stage_manager.get_rank() + group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank) + return _communicate( + input_object, prev_rank, prev_rank, + send_group=group, recv_group=group, + ) + + def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any: + """Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline. + + Args: + input_object (Any): Object to be sent. + prev_rank (int, optional): The rank of the sender of the tensor + next_rank (int, optional): The rank of the recipient of the tensor + """ + if prev_rank is None: + prev_rank = self.stage_manager.get_prev_rank() + if next_rank is None: + next_rank = self.stage_manager.get_next_rank() + + cur_rank = self.stage_manager.get_rank() + recv_group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank) + send_group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank) + return _communicate( + input_object, + send_dst=next_rank, + recv_src=prev_rank, + send_group=send_group, + recv_group=recv_group, + ) + def p2p_communicate( self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16 ) -> None: diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 4eaf135fd..1f3b80857 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -127,6 +127,17 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): if not self.stage_manager.is_last_stage(): self.comm.send_forward(output_object, next_rank) + def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any: + """Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline. + For 1F1B. + + Args: + output_object (Any): Object to be sent. + next_rank (int, optional): The rank of the recipient of the tensor. + """ + if not self.stage_manager.is_last_stage(): + return self.comm.send_forward_recv_backward(output_object, next_rank) + def send_backward(self, input_object: Any, prev_rank: int = None) -> None: """Sends the gradient tensor to the previous stage in pipeline. For 1F1B. @@ -138,6 +149,33 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): if not self.stage_manager.is_first_stage(): self.comm.send_backward(input_object, prev_rank) + def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any: + """Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline. + For 1F1B. + + Args: + output_object (Any): Object to be sent. + prev_rank (int, optional): The rank of the recipient of the tensor. + """ + if not self.stage_manager.is_first_stage(): + return self.comm.send_backward_recv_forward(output_object, prev_rank) + + def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any: + """Sends the input tensor to the next stage and copy the input tensor from the previous stage in pipeline. + For 1F1B. + + Args: + input_object (Any): Object to be sent. + prev_rank (int, optional): The previous rank of the recipient of the tensor. + next_rank (int, optional): The next rank of the recipient of the tensor. + """ + if self.stage_manager.is_first_stage(): + return self.comm.send_forward(input_object, next_rank) + elif self.stage_manager.is_last_stage(): + return self.comm.recv_forward(prev_rank) + else: + return self.comm.send_forward_recv_forward(input_object, prev_rank, next_rank) + def forward_step( self, model: Module, @@ -291,7 +329,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): if not last_iteration: input_obj = self.recv_forward() - else: # TODO adjust here self.send_forward(output_obj) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 4b6c83425..0f911be48 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -413,6 +413,7 @@ def get_llama_flash_attention_forward(): past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py index ce13ebbf6..b38ddbb4a 100644 --- a/examples/language/llama2/benchmark.py +++ b/examples/language/llama2/benchmark.py @@ -183,7 +183,11 @@ def main(): model_numel = get_model_numel(model) coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") performance_evaluator = PerformanceEvaluator( - model_numel, args.grad_checkpoint, args.ignore_steps, dp_world_size=dp_size + model_numel, + model.config.num_hidden_layers, + model.config.hidden_size, + model.config.vocab_size, + args.grad_checkpoint, args.ignore_steps, dp_world_size=dp_size ) optimizer = HybridAdam(model.parameters()) diff --git a/examples/language/llama2/performance_evaluator.py b/examples/language/llama2/performance_evaluator.py index a57c1e0e9..05e71edf1 100644 --- a/examples/language/llama2/performance_evaluator.py +++ b/examples/language/llama2/performance_evaluator.py @@ -58,6 +58,9 @@ class PerformanceEvaluator: def __init__( self, model_numel: int, + num_layers: int, + hidden_size: int, + vocab_size: int, enable_grad_checkpoint: bool = False, ignore_steps: int = 0, dp_world_size: Optional[int] = None, @@ -65,12 +68,16 @@ class PerformanceEvaluator: self.model_numel = model_numel self.enable_grad_checkpoint = enable_grad_checkpoint self.ignore_steps = ignore_steps + self.num_layers = num_layers + self.hidden_size = hidden_size + self.vocab_size = vocab_size self.coordinator = DistCoordinator() self.dp_world_size = dp_world_size or self.coordinator.world_size self.disable: bool = False self.timer = Timer() self.num_samples: int = 0 + self.flop_megatron = 0 self.flop: int = 0 def on_step_start(self, step: int) -> None: @@ -89,17 +96,20 @@ class PerformanceEvaluator: batch_size, seq_len = input_ids.shape self.num_samples += batch_size + checkpoint_activations_factor = (3 + int(self.enable_grad_checkpoint)) + self.flop_megatron += (24 * checkpoint_activations_factor * batch_size * seq_len * self.num_layers * (self.hidden_size**2)) * (1. + (seq_len / (6. * self.hidden_size)) + (self.vocab_size / (16. * self.num_layers * self.hidden_size))) self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint)) def on_fit_end(self) -> None: avg_duration = all_reduce_mean(self.timer.duration, self.coordinator.world_size) avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12) mp_world_size = self.coordinator.world_size // self.dp_world_size + avg_tflops_per_gpu_megatron = self.flop_megatron / 1e12 / (avg_duration + 1e-12) / mp_world_size avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size self.coordinator.print_on_master( - f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, " + f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop_megatron: {self.flop_megatron}, flop: {self.flop}, avg_duration: {avg_duration}, " f"avg_throughput: {avg_throughput}" ) self.coordinator.print_on_master( - f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}" + f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU by Megatron: {avg_tflops_per_gpu_megatron:.2f}, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}" )