import time
from functools import partial
from typing import Any, Iterable, List, Optional, Tuple
import torch
import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_map
from colossalai.accelerator import get_accelerator
from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
from ._utils import get_batch_size, get_micro_batch, model_forward, to_device
from .base import PipelineSchedule
class ActionIntervalBuffer:
"""
    Buffer holding the intermediate hidden states and the newly generated token that actions within a stage pass to one another.
"""
    def __init__(self):
self.hidden_states = None
self.new_token = None
def clear(self):
self.hidden_states = None
self.new_token = None
class GenerateSchedule(PipelineSchedule):
"""
GenerateSchedule is a class that handles the pipeline parallel inference.
    In our schedule, we place the tied-weight layers (the embedding and the lm_head) on the same device to save
    memory, so the output of each decoding step ends up on rank 0.
Args:
stage_manager (`PipelineStageManager`): Pipeline stage manager.
mb_manager (`MicroBatchManager`): Micro batch manager.
        verbose (bool): Whether to record and report verbose information (e.g. per-step timestamps) of the pipeline.
"""
def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchManager, verbose: bool) -> None:
super().__init__(stage_manager)
self.comm = PipelineP2PCommunication(stage_manager)
self.mb_manager = mb_manager
self.microbatch_size = mb_manager.micro_batch_size
self.batch: Optional[Any] = None
self.batch_size: Optional[int] = None
self.microbatch_offset: Optional[int] = None
self.num_microbatches: Optional[int] = None
self.action_interval_buffer = ActionIntervalBuffer()
self.verbose = verbose
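        # Per-micro-batch wall-clock timestamps, recorded only when verbose timing is enabled.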
self.timestamps = None
self.comm_dtype = None
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.
Args:
data_iter (Iterable): Data iterator.
device (Optional[torch.device], optional): Target device. Defaults to None.
"""
batch = next(data_iter)
if device is not None:
batch = tree_map(partial(to_device, device=device), batch)
self.batch = batch
self.batch_size = get_batch_size(batch)
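        # With a single pipeline stage there is no pipelining, so the whole batch runs as one micro batch.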
if self.stage_manager.num_stages == 1:
self.microbatch_size = self.batch_size
self.microbatch_offset = 0
        assert (
            self.batch_size % self.microbatch_size == 0
        ), f"Batch size should be divisible by the micro batch size, got batch size {self.batch_size} and micro batch size {self.microbatch_size}"
self.num_microbatches = self.batch_size // self.microbatch_size
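        # A "round" pushes num_stages micro batches through the pipeline before tokens are exported.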
self.round = self.num_microbatches // self.stage_manager.num_stages
def load_micro_batch(self) -> Any:
"""Load a micro batch from the current batch.
Returns:
Any: Micro batch.
"""
micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
self.microbatch_offset += self.microbatch_size
return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
def _prepare_inputs_for_interval_stage(self):
"""
        Prepare inputs for an interval stage; for all interval stages the input is just the `infer_state`.
        Returns:
            dict: inputs for the interval stage, `{'infer_state': ...}`
"""
model_inputs = {"infer_state": self.mb_manager.cur_description.infer_state}
return model_inputs
def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
"""
        Prepare inputs for the new token: a dict with `input_ids` and `attention_mask`, where
        `input_ids` is the new token and `attention_mask` is the previous mask extended by `1` at the end
        (the KV cache is tracked by the micro batch manager's `infer_state`).
        Returns:
            dict: inputs for the new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor}`
"""
new_mask = self.mb_manager.cur_description.attn_mask
return dict(input_ids=new_token, attention_mask=new_mask)
def _get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor:
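        """Greedy decoding: pick the argmax token at the last position; the result has shape (batch_size, 1)."""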
last_hidden_state = hidden_state[:, -1]
input_ids = torch.argmax(last_hidden_state, dim=-1).unsqueeze(1)
return input_ids
def _recv_pre_stage(self) -> Any:
"""
        Receive the output from the previous stage.
        Returns:
            Any: The output from the previous stage.
"""
if self.stage_manager.num_stages == 2:
return self.comm.p2p_recv()
return self.comm.recv_forward()
def _init_infer_state_action(self) -> None:
"""
        This action is only for non-first stages: load a batch and initialize the infer_state.
        1. Load the micro batch. 2. Use the current micro batch to initialize the current infer_state.
"""
inputs_dict = self.load_micro_batch()
self.mb_manager.add_description(inputs_dict)
def _load_stage_action(self, model: Module) -> None:
"""
        This action is only for the first stage: load, initialize, and run the forward pass.
        1. Load the micro batch. 2. Run the forward pass. 3. Step to update.
"""
inputs_dict = self.load_micro_batch()
self.mb_manager.add_description(inputs_dict)
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs)
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
def _gen_token_action(self, model: Module):
"""
        This action is only for the first stage.
        1. Run the forward pass on the hidden states to generate the new token. 2. Step to update.
"""
hidden_states = self.action_interval_buffer.hidden_states
        assert hidden_states is not None, "When the first stage is in GENERATE phase, the hidden states should not be None"
interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
logits = model_forward(model, None, interval_inputs)
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
        assert (
            "logits" in logits
        ), f"When the first stage is in GENERATE phase, the output should contain `logits`, but got keys {logits.keys()}"
new_token = self._get_token_id(logits["logits"])
self.mb_manager.step(new_token)
self.action_interval_buffer.new_token = new_token
self.action_interval_buffer.hidden_states = None
def _head_encoding_action(self, model: Module):
"""
        First stage only. In this action: 1. prepare the inputs for encoding the new token, 2. run the forward pass to get the hidden states.
"""
new_token = self.action_interval_buffer.new_token
        assert new_token is not None, "When the first stage is in GENERATE phase, the new token should not be None"
inputs_dict = self._prepare_inputs_for_new_token(new_token)
interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs)
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
def _body_encoding_action(self, model: Module):
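        """
        Non-first stages: run this stage's layers on the hidden states received from the
        previous stage and buffer the output for the next communication action.
        """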
hidden_states = self.action_interval_buffer.hidden_states
        assert hidden_states is not None, "When not the first stage, the hidden states should not be None"
interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, None, interval_inputs)
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
    def _comm_action(self, recv_pre: bool) -> None:
"""
        In this action: 1. receive the hidden_states from the previous stage (when `recv_pre` is True), and 2. send the buffered hidden_states to the next stage.
"""
hidden_states = self.action_interval_buffer.hidden_states
ret = self.comm.p2p_communicate(hidden_states, recv_pre, comm_dtype=self.comm_dtype)
self.action_interval_buffer.hidden_states = ret
def _gen_action(self, model: Module):
"""
        The p2p step method uses the asynchronous `P2POp` communication method, so communication has to be issued
        at the beginning of each micro batch; expressing this as an action list keeps it clear. This function
        generates the sequence of actions for the current state; the caller executes them one by one.
Args:
model (Module): Model to be run.
Returns:
            List[Callable]: A list of actions; each action is a callable that will be invoked in order.
"""
actions = []
        if self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.PREFILL:
actions.append(partial(self._comm_action, False))
actions.append(partial(self._load_stage_action, model))
elif self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.GENERATE:
actions.append(partial(self._comm_action, True))
actions.append(partial(self._gen_token_action, model))
actions.append(partial(self._head_encoding_action, model))
elif self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.COOLDOWN:
actions.append(partial(self._comm_action, True))
actions.append(partial(self._gen_token_action, model))
# other stage
else:
if self.mb_manager.cur_state is Status.PREFILL:
actions.append(partial(self._init_infer_state_action))
actions.append(partial(self._comm_action, True))
actions.append(partial(self._body_encoding_action, model))
return actions
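
    # Action schedule per micro-batch state, as built by `_gen_action` above:
    #   PREFILL  -- first stage: comm(send only), load + forward | others: init infer_state, comm(recv), body forward
    #   GENERATE -- first stage: comm(recv), gen token, head encoding | others: comm(recv), body forward
    #   COOLDOWN -- first stage: comm(recv), gen token | others: comm(recv), body forward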
def _gen_one_stage_action(self, model: Module):
"""
        This function generates the sequence of actions for the current state; the caller executes them one by one.
Args:
model (Module): Model to be run.
Returns:
            List[Callable]: A list of actions; each action is a callable that will be invoked in order.
"""
actions = []
if self.mb_manager.cur_state is Status.PREFILL:
actions.append(partial(self._load_stage_action, model))
elif self.mb_manager.cur_state is Status.GENERATE:
actions.append(partial(self._gen_token_action, model))
actions.append(partial(self._head_encoding_action, model))
elif self.mb_manager.cur_state is Status.COOLDOWN:
actions.append(partial(self._gen_token_action, model))
return actions
    def generate_step(self, model: Module, data_iter: Iterable) -> Tuple[List[torch.Tensor], List[List[float]]]:
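        # Dispatch on pipeline size: one stage runs everything locally, two stages use the
        # asynchronous P2P schedule (the ring would block broadcast), and more stages broadcast.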
if self.stage_manager.num_stages == 1:
return self.generate_step_one_stage(model, data_iter)
elif self.stage_manager.num_stages == 2:
return self.generate_step_p2p(model, data_iter)
else:
return self.generate_step_broadcast(model, data_iter)
@torch.no_grad()
    def generate_step_one_stage(self, model: Module, data_iter: Iterable) -> Tuple[List[torch.Tensor], List[List[float]]]:
        """
        Forward one step of the pipeline when the pipeline size is 1.
Args:
model (Module): Model to be run.
data_iter (Iterable): Data iterator.
        Returns:
            Tuple[List[torch.Tensor], List[List[float]]]: The generated token sequences and, when `verbose` is enabled, the recorded per-step timestamps.
"""
output_sequence = []
self.load_batch(data_iter)
model.eval()
self.comm_dtype = model.dtype
whole_timestamp = []
# run by round
for _ in range(self.round):
self.timestamps = [[] for _ in range(self.stage_manager.num_stages)] if self.verbose else None
self.action_interval_buffer.clear()
            while not self.mb_manager.is_micro_batch_done():
actions = self._gen_one_stage_action(model)
for action in actions:
action()
self.mb_manager.next()
            # All micro batches in the current round are DONE
output_sequence.extend(self.mb_manager.export_new_tokens())
self.mb_manager.clear()
if self.verbose:
whole_timestamp.extend(self.timestamps)
return output_sequence, whole_timestamp
@torch.no_grad()
    def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Tuple[List[torch.Tensor], List[List[float]]]:
        """
        Forward one step of the pipeline when the pipeline size is 2. The schedule forms a ring, where blocking
        broadcast communication would stall, so the asynchronous `P2POp` communication method is used instead.
Args:
model (Module): Model to be run.
data_iter (Iterable): Data iterator.
        Returns:
            Tuple[List[torch.Tensor], List[List[float]]]: The generated token sequences (collected on the first stage; empty elsewhere) and, when `verbose` is enabled, the recorded per-step timestamps.
"""
output_sequence = []
self.load_batch(data_iter)
model.eval()
self.comm_dtype = model.dtype
whole_timestamp = []
# run by round
for _ in range(self.round):
self.timestamps = (
[[] for _ in range(self.stage_manager.num_stages)]
if self.verbose and self.stage_manager.is_first_stage()
else None
)
self.action_interval_buffer.clear()
            while not self.mb_manager.is_micro_batch_done():
actions = self._gen_action(model)
for action in actions:
action()
self.mb_manager.next()
            # All micro batches in the current round are DONE
if self.stage_manager.is_first_stage():
output_sequence.extend(self.mb_manager.export_new_tokens())
else:
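                # Non-first stages issue a final send-only communication to deliver the last buffered hidden states.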
self._comm_action(False)
self.mb_manager.clear()
if self.verbose and self.stage_manager.is_first_stage():
whole_timestamp.extend(self.timestamps)
return output_sequence, whole_timestamp
@torch.no_grad()
    def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Tuple[List[torch.Tensor], List[List[float]]]:
        """
        Forward one step of the pipeline using broadcast communication.
Args:
model (Module): Model to be run.
data_iter (Iterable): Data iterator.
        Returns:
            Tuple[List[torch.Tensor], List[List[float]]]: The generated token sequences (collected on the first stage; empty elsewhere) and, when `verbose` is enabled, the recorded per-step timestamps.
"""
output_sequence = []
self.load_batch(data_iter)
model.eval()
whole_timestamp = []
# run by round
for _ in range(self.round):
self.timestamps = (
[[] for _ in range(self.stage_manager.num_stages)]
if self.verbose and self.stage_manager.is_first_stage()
else None
)
            while not self.mb_manager.is_micro_batch_done():
inputs_dict = None
new_token = None
output_dict = None
                # First stage in PREFILL phase: just load the inputs
if self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.PREFILL:
inputs_dict = self.load_micro_batch()
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
self.mb_manager.add_description(inputs_dict)
interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs)
# In GENERATE phase
else:
# Get hidden_states from previous stage
hidden_states = self.comm.recv_forward()
if self.stage_manager.is_first_stage():
# First just generate a new token
                        assert (
                            hidden_states is not None
                        ), "When the first stage is in GENERATE phase, the hidden states should not be None"
interval_inputs = {
"hidden_states": hidden_states["hidden_states"],
"infer_state": self.mb_manager.cur_infer_state,
}
logits = model_forward(model, None, interval_inputs)
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
                        assert (
                            "logits" in logits
                        ), f"When the first stage is in GENERATE phase, the output should contain `logits`, but got keys {logits.keys()}"
new_token = self._get_token_id(logits["logits"])
self.mb_manager.step(new_token)
# If the current micro batch is not DONE, go through blocks
if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
inputs_dict = self._prepare_inputs_for_new_token(new_token)
interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs)
else:
                        assert hidden_states is not None, "When not the first stage, the hidden states should not be None"
# inputs_dict = self._prepare_inputs_for_interval_stage()
inputs_dict = None
if self.mb_manager.cur_state is Status.PREFILL:
inputs_dict = self.load_micro_batch()
self.mb_manager.add_description(inputs_dict)
interval_inputs = {
"hidden_states": hidden_states["hidden_states"],
"infer_state": self.mb_manager.cur_infer_state,
}
output_dict = model_forward(model, inputs_dict, interval_inputs)
                # Current micro batch is not DONE: send the hidden states to the next stage
if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in (
Status.GENERATE,
Status.COOLDOWN,
):
self.comm.send_forward({"hidden_states": output_dict["hidden_states"]})
self.mb_manager.next()
            # All micro batches in the current round are DONE
if self.stage_manager.is_first_stage():
output_sequence.extend(self.mb_manager.export_new_tokens())
self.mb_manager.clear()
if self.verbose and self.stage_manager.is_first_stage():
whole_timestamp.extend(self.timestamps)
return output_sequence, whole_timestamp
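
# Usage sketch (illustrative only): how a caller might drive GenerateSchedule for
# pipeline-parallel inference. The constructor arguments of PipelineStageManager and
# MicroBatchManager below are placeholders; their real signatures live in the
# respective ColossalAI modules and are assumed here, not prescribed.
#
#   stage_manager = PipelineStageManager(...)  # hypothetical setup, done by the launcher
#   mb_manager = MicroBatchManager(...)        # hypothetical setup, done by the launcher
#   schedule = GenerateSchedule(stage_manager, mb_manager, verbose=True)
#   output_tokens, timestamps = schedule.generate_step(model, iter(dataloader))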