import time
from functools import partial
from typing import Any, Iterable, List, Optional, Tuple, Union

import torch
import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_map

from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.cuda import get_current_device

from ._utils import get_batch_size, get_micro_batch, model_forward, to_device
from .base import PipelineSchedule


class ActionIntervalBuffer:
    """
    Buffer holding the intermediate hidden states and the newly generated token
    that are handed from one action of a stage to the next.
    """

    def __init__(self):
        self.hidden_states = None
        self.new_token = None

    def clear(self):
        """Reset both buffered values to ``None``."""
        self.hidden_states = None
        self.new_token = None


class GenerateSchedule(PipelineSchedule):
    """
    GenerateSchedule handles pipeline-parallel inference (token generation).

    The tied-weight layers (embedding and lm_head) are placed on the same device to save memory.
    As a result, in this schedule the output of every encoding step, i.e. the generated token,
    is produced on rank 0 (the first stage).

    Args:
        stage_manager (`PipelineStageManager`): Pipeline stage manager.
        mb_manager (`MicroBatchManager`): Micro batch manager.
        verbose (bool): Whether to record per-micro-batch timestamps on the first stage.
    """

    def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchManager, verbose: bool) -> None:
        super().__init__(stage_manager)
        self.comm = PipelineP2PCommunication(stage_manager)
        self.mb_manager = mb_manager
        self.microbatch_size = mb_manager.micro_batch_size
        self.batch: Optional[Any] = None
        self.batch_size: Optional[int] = None
        self.microbatch_offset: Optional[int] = None
        self.num_microbatches: Optional[int] = None
        # number of rounds per batch; set by `load_batch`
        self.round: Optional[int] = None
        self.action_interval_buffer = ActionIntervalBuffer()
        self.verbose = verbose
        self.timestamps = None
        self.comm_dtype = None

    def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
        """Load a batch from the data iterator and compute the micro-batch/round counts.

        Args:
            data_iter (Iterable): Data iterator; ``next`` is called on it once.
            device (Optional[torch.device], optional): Target device. Defaults to None.
        """
        batch = next(data_iter)
        if device is not None:
            batch = tree_map(partial(to_device, device=device), batch)
        self.batch = batch
        self.batch_size = get_batch_size(batch)
        self.microbatch_offset = 0
        assert (
            self.batch_size % self.microbatch_size == 0
        ), f"Batch size should be divisible by the micro batch size, {self.batch_size}, {self.microbatch_size}"
        self.num_microbatches = self.batch_size // self.microbatch_size
        # Integer division: micro batches beyond `round * num_stages` are not processed.
        # NOTE(review): presumably mb_manager holds num_stages micro batches per round — confirm.
        self.round = self.num_microbatches // self.stage_manager.num_stages

    def load_micro_batch(self) -> Any:
        """Slice the next micro batch from the current batch and move it to the current device.

        Returns:
            Any: Micro batch.
        """
        micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
        self.microbatch_offset += self.microbatch_size
        return tree_map(partial(to_device, device=get_current_device()), micro_batch)

    def _prepare_inputs_for_interval_stage(self):
        """
        Prepare inputs for an interval (non-first) stage.

        Returns:
            dict: ``{'infer_state': ...}`` taken from the current micro batch description.
        """
        model_inputs = {"infer_state": self.mb_manager.cur_descrption.infer_state}
        return model_inputs

    def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
        """
        Prepare the model inputs for encoding a newly generated token.

        ``input_ids`` is the new token. ``attention_mask`` is the mask stored in the current
        micro batch description (presumably already extended for the new token — confirm in
        MicroBatchManager).

        Returns:
            dict: ``{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor}``
        """
        new_mask = self.mb_manager.cur_descrption.attn_mask

        return dict(input_ids=new_token, attention_mask=new_mask)

    def _get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Greedily pick the next token id from the last sequence position.

        Assumes ``hidden_state`` is logits shaped (batch, seq, vocab) — TODO confirm. With that
        shape, the result is (batch, 1).
        """
        last_hidden_state = hidden_state[:, -1]
        input_ids = torch.argmax(last_hidden_state, dim=-1).unsqueeze(1)
        return input_ids

    def _recv_pre_stage(self) -> Any:
        """
        Receive the output from the previous stage.

        With exactly two stages, the blocking ``p2p_recv`` is used; otherwise ``recv_forward``.

        Returns:
            Any: The output from previous stage
        """
        if self.stage_manager.num_stages == 2:
            return self.comm.p2p_recv()
        return self.comm.recv_forward()

    def _init_infer_state_action(self) -> None:
        """
        Action for non-first stages only: load a micro batch and register its description.

        1. Load the micro batch.
        2. Use the current micro batch to init the current infer_state.
        """
        inputs_dict = self.load_micro_batch()
        self.mb_manager.add_descrption(inputs_dict)

    def _load_stage_action(self, model: Module) -> None:
        """
        Action for the first stage in PREFILL.

        1. Load the micro batch and register its description.
        2. Run the forward pass and buffer the resulting hidden states.
        """
        inputs_dict = self.load_micro_batch()
        self.mb_manager.add_descrption(inputs_dict)
        if self.verbose and self.stage_manager.is_first_stage():
            # synchronize so the timestamp reflects completed GPU work
            torch.cuda.synchronize()
            self.timestamps[self.mb_manager.idx].append(time.time())
        interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
        output_dict = model_forward(model, inputs_dict, interval_inputs)

        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]

    def _gen_token_action(self, model: Module):
        """
        Action for the first stage only.

        1. Run the forward pass on the buffered hidden states to obtain logits and a new token.
        2. Step the micro batch manager with the new token.
        """
        hidden_states = self.action_interval_buffer.hidden_states
        assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
        interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
        logits = model_forward(model, None, interval_inputs)
        if self.verbose and self.stage_manager.is_first_stage():
            torch.cuda.synchronize()
            self.timestamps[self.mb_manager.idx].append(time.time())
        assert (
            "logits" in logits
        ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
        new_token = self._get_token_id(logits["logits"])

        self.mb_manager.step(new_token)
        self.action_interval_buffer.new_token = new_token
        self.action_interval_buffer.hidden_states = None

    def _head_encoding_action(self, model: Module):
        """
        Action for the first stage.

        1. Prepare the inputs for encoding the buffered new token.
        2. Run the forward pass and buffer the resulting hidden states.
        """
        new_token = self.action_interval_buffer.new_token
        assert new_token is not None, "When first stage in GENERATE phase, the new token should not be None"
        inputs_dict = self._prepare_inputs_for_new_token(new_token)
        interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
        output_dict = model_forward(model, inputs_dict, interval_inputs)

        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]

    def _body_encoding_action(self, model: Module):
        """Action for non-first stages: forward the buffered hidden states through this stage."""
        hidden_states = self.action_interval_buffer.hidden_states
        assert hidden_states is not None, "When not first stage, the hidden states should not be None"
        interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
        output_dict = model_forward(model, None, interval_inputs)

        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]

    def _comm_action(self, recv_pre: bool) -> None:
        """
        Exchange hidden states with neighbouring stages via ``p2p_communicate``.

        The buffered hidden states (possibly ``None``) are passed to ``p2p_communicate``.
        The buffer is then overwritten with its return value. ``recv_pre`` is forwarded as-is;
        presumably it controls whether to receive from the previous stage.
        """
        hidden_states = self.action_interval_buffer.hidden_states
        ret = self.comm.p2p_communicate(hidden_states, recv_pre, comm_dtype=self.comm_dtype)

        self.action_interval_buffer.hidden_states = ret

    def _gen_action(self, model: Module):
        """
        Build the ordered list of actions for the current micro batch (multi-stage p2p mode).

        The p2p step uses `P2POp` asynchronous communication, so the communication must happen
        at the beginning of each micro batch. An explicit action list makes that ordering clear.

        Args:
            model (Module): Model to be run.

        Returns:
            List[Callable]: A list of actions, each a callable invoked in order.
        """
        actions = []
        if self.stage_manager.is_first_stage():
            if self.mb_manager.cur_state is Status.PREFILL:
                actions.append(partial(self._comm_action, False))
                actions.append(partial(self._load_stage_action, model))
            elif self.mb_manager.cur_state is Status.GENERATE:
                actions.append(partial(self._comm_action, True))
                actions.append(partial(self._gen_token_action, model))
                actions.append(partial(self._head_encoding_action, model))
            elif self.mb_manager.cur_state is Status.COOLDOWN:
                actions.append(partial(self._comm_action, True))
                actions.append(partial(self._gen_token_action, model))
        # other stages
        else:
            if self.mb_manager.cur_state is Status.PREFILL:
                actions.append(partial(self._init_infer_state_action))
            actions.append(partial(self._comm_action, True))
            actions.append(partial(self._body_encoding_action, model))

        return actions

    def _gen_one_stage_action(self, model: Module):
        """
        Build the ordered list of actions for the current micro batch when there is a single stage.

        Args:
            model (Module): Model to be run.

        Returns:
            List[Callable]: A list of actions, each a callable invoked in order.
        """
        actions = []

        if self.mb_manager.cur_state is Status.PREFILL:
            actions.append(partial(self._load_stage_action, model))
        elif self.mb_manager.cur_state is Status.GENERATE:
            actions.append(partial(self._gen_token_action, model))
            actions.append(partial(self._head_encoding_action, model))
        elif self.mb_manager.cur_state is Status.COOLDOWN:
            actions.append(partial(self._gen_token_action, model))

        return actions

    def generate_step(self, model: Module, data_iter: Iterable) -> Tuple[List[Any], List[Any]]:
        """Run generation for one batch, dispatching on the number of pipeline stages.

        1 stage uses ``generate_step_one_stage``. 2 stages use ``generate_step_p2p``.
        More than 2 stages use ``generate_step_broadcast``.

        Returns:
            Tuple[list, list]: ``(output_sequence, whole_timestamp)`` from the chosen implementation.
        """
        if self.stage_manager.num_stages == 1:
            return self.generate_step_one_stage(model, data_iter)
        elif self.stage_manager.num_stages == 2:
            return self.generate_step_p2p(model, data_iter)
        else:
            return self.generate_step_broadcast(model, data_iter)

    @torch.no_grad()
    def generate_step_one_stage(self, model: Module, data_iter: Iterable) -> Tuple[List[Any], List[Any]]:
        """
        Run generation for one batch when the pipeline has a single stage.

        Args:
            model (Module): Model to be run.
            data_iter (Iterable): Data iterator.

        Returns:
            Tuple[list, list]: ``(output_sequence, whole_timestamp)``.
                ``output_sequence`` collects ``mb_manager.export_new_tokens()`` after every round.
                ``whole_timestamp`` collects the per-micro-batch timestamp lists; it is empty
                unless ``verbose``.
        """
        output_sequence = []
        self.load_batch(data_iter)
        model.eval()
        self.comm_dtype = model.dtype

        whole_timestamp = []

        # run by round
        for _ in range(self.round):
            self.timestamps = [[] for _ in range(self.stage_manager.num_stages)] if self.verbose else None
            self.action_interval_buffer.clear()
            while self.mb_manager.is_micro_batch_done() is False:
                actions = self._gen_one_stage_action(model)
                for action in actions:
                    action()
                self.mb_manager.next()
            # All micro batches in the current round are DONE
            output_sequence.extend(self.mb_manager.export_new_tokens())
            self.mb_manager.clear()
            if self.verbose:
                whole_timestamp.extend(self.timestamps)

        return output_sequence, whole_timestamp

    @torch.no_grad()
    def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Tuple[List[Any], List[Any]]:
        """
        Run generation for one batch when the pipeline has two stages.

        With two stages the schedule forms a cycle, so blocking broadcast communication would
        deadlock. The `P2POp` asynchronous communication method is used instead.

        Args:
            model (Module): Model to be run.
            data_iter (Iterable): Data iterator.

        Returns:
            Tuple[list, list]: ``(output_sequence, whole_timestamp)``. Both are populated on the
                first stage only; timestamps are collected only when ``verbose``. Other stages
                return two empty lists.
        """
        output_sequence = []
        self.load_batch(data_iter)
        model.eval()
        self.comm_dtype = model.dtype

        whole_timestamp = []

        # run by round
        for _ in range(self.round):
            self.timestamps = (
                [[] for _ in range(self.stage_manager.num_stages)]
                if self.verbose and self.stage_manager.is_first_stage()
                else None
            )
            self.action_interval_buffer.clear()
            while self.mb_manager.is_micro_batch_done() is False:
                actions = self._gen_action(model)
                for action in actions:
                    action()
                self.mb_manager.next()
            # All micro batches in the current round are DONE
            if self.stage_manager.is_first_stage():
                output_sequence.extend(self.mb_manager.export_new_tokens())
            else:
                # flush the last buffered hidden states to the next stage
                self._comm_action(False)
            self.mb_manager.clear()
            if self.verbose and self.stage_manager.is_first_stage():
                whole_timestamp.extend(self.timestamps)

        return output_sequence, whole_timestamp

    @torch.no_grad()
    def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Tuple[List[Any], List[Any]]:
        """
        Run generation for one batch with blocking ``recv_forward``/``send_forward`` communication
        (used for more than two stages).

        Args:
            model (Module): Model to be run.
            data_iter (Iterable): Data iterator.

        Returns:
            Tuple[list, list]: ``(output_sequence, whole_timestamp)``. Both are populated on the
                first stage only; timestamps are collected only when ``verbose``. Other stages
                return two empty lists.
        """
        output_sequence = []
        self.load_batch(data_iter)
        model.eval()

        whole_timestamp = []
        # run by round
        for _ in range(self.round):
            self.timestamps = (
                [[] for _ in range(self.stage_manager.num_stages)]
                if self.verbose and self.stage_manager.is_first_stage()
                else None
            )
            while self.mb_manager.is_micro_batch_done() is False:
                inputs_dict = None
                new_token = None
                output_dict = None

                # First stage and in PREFILL phase, just load the inputs
                if self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.PREFILL:
                    inputs_dict = self.load_micro_batch()
                    if self.verbose:
                        torch.cuda.synchronize()
                        self.timestamps[self.mb_manager.idx].append(time.time())
                    self.mb_manager.add_descrption(inputs_dict)
                    interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
                    output_dict = model_forward(model, inputs_dict, interval_inputs)
                # In GENERATE phase
                else:
                    # Get hidden_states from previous stage
                    hidden_states = self.comm.recv_forward()
                    if self.stage_manager.is_first_stage():
                        # First just generate a new token
                        assert (
                            hidden_states is not None
                        ), "When first stage in GENERATE phase, the hidden states should not be None"
                        interval_inputs = {
                            "hidden_states": hidden_states["hidden_states"],
                            "infer_state": self.mb_manager.cur_infer_state,
                        }
                        logits = model_forward(model, None, interval_inputs)
                        if self.verbose:
                            torch.cuda.synchronize()
                            self.timestamps[self.mb_manager.idx].append(time.time())
                        assert (
                            "logits" in logits
                        ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
                        new_token = self._get_token_id(logits["logits"])
                        self.mb_manager.step(new_token)
                        # If the current micro batch is not DONE, go through blocks
                        if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
                            inputs_dict = self._prepare_inputs_for_new_token(new_token)
                            interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
                            output_dict = model_forward(model, inputs_dict, interval_inputs)
                    else:
                        assert hidden_states is not None, "When not first stage, the hidden states should not be None"
                        inputs_dict = None
                        if self.mb_manager.cur_state is Status.PREFILL:
                            inputs_dict = self.load_micro_batch()
                            self.mb_manager.add_descrption(inputs_dict)
                        interval_inputs = {
                            "hidden_states": hidden_states["hidden_states"],
                            "infer_state": self.mb_manager.cur_infer_state,
                        }
                        output_dict = model_forward(model, inputs_dict, interval_inputs)

                # Current micro batch is not DONE, send hidden_state to next stage
                if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in (
                    Status.GENERATE,
                    Status.COOLDOWN,
                ):
                    self.comm.send_forward({"hidden_states": output_dict["hidden_states"]})

                self.mb_manager.next()

            # All micro batches in the current round are DONE
            if self.stage_manager.is_first_stage():
                output_sequence.extend(self.mb_manager.export_new_tokens())
            self.mb_manager.clear()
            if self.verbose and self.stage_manager.is_first_stage():
                whole_timestamp.extend(self.timestamps)

        return output_sequence, whole_timestamp