@@ -17,8 +17,8 @@ class Status(Enum):
 
 class MicroBatchDescription:
     """
-    This is the class to record the infomation of each microbatch, and also do some update operation.
-    This clase is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more
+    This is the class to record the information of each microbatch, and also do some update operation.
+    This class is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more
     details, please refer to the doc of these two classes blow.
 
     Args:
@@ -61,15 +61,15 @@ class MicroBatchDescription:
     @property
     def cur_length(self):
         """
-        Return the current sequnence length of micro batch
+        Return the current sequence length of micro batch
 
         """
 
 
 class HeadMicroBatchDescription(MicroBatchDescription):
     """
-    This class is used to record the infomation of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask`
-    and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schdule of pipeline, the operation to update the
+    This class is used to record the information of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask`
+    and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schedule of pipeline, the operation to update the
     information and the condition to determine the state is different from other stages.
 
     Args:
@@ -123,7 +123,7 @@ class HeadMicroBatchDescription(MicroBatchDescription):
 
 class BodyMicroBatchDescription(MicroBatchDescription):
     """
-    This class is used to record the infomation of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`,
+    This class is used to record the information of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`,
 
     Args:
         inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hiddenstates from previous stage.
@@ -173,76 +173,76 @@ class MicroBatchManager:
         self.max_input_len = max_input_len
         self.max_output_len = max_output_len
         self.cache_manager_list = cache_manager_list
-        self.mb_descrption_buffer = {}
+        self.mb_description_buffer = {}
         self.new_tokens_buffer = {}
         self.idx = 0
 
-    def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]):
+    def add_description(self, inputs_dict: Dict[str, torch.Tensor]):
         if self.stage == 0:
-            self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(
+            self.mb_description_buffer[self.idx] = HeadMicroBatchDescription(
                 inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
             )
         else:
-            self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(
+            self.mb_description_buffer[self.idx] = BodyMicroBatchDescription(
                 inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
             )
 
     def step(self, new_token: torch.Tensor = None):
         """
         Update the state if microbatch manager, 2 conditions.
-        1. For first stage in PREFILL, receive inputs and outputs, `_add_descrption` will save its inputs.
-        2. For other conditon, only receive the output of previous stage, and update the descrption.
+        1. For first stage in PREFILL, receive inputs and outputs, `_add_description` will save its inputs.
+        2. For other condition, only receive the output of previous stage, and update the description.
 
         Args:
             inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`.
            output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
             new_token (torch.Tensor): the new token generated by current stage.
         """
-        # Add descrption first if the descrption is None
-        self.cur_descrption.update(new_token)
+        # Add description first if the description is None
+        self.cur_description.update(new_token)
         return self.cur_state
 
     def export_new_tokens(self):
         new_tokens_list = []
-        for i in self.mb_descrption_buffer.values():
+        for i in self.mb_description_buffer.values():
             new_tokens_list.extend(i.new_tokens.tolist())
         return new_tokens_list
 
     def is_micro_batch_done(self):
-        if len(self.mb_descrption_buffer) == 0:
+        if len(self.mb_description_buffer) == 0:
             return False
-        for mb in self.mb_descrption_buffer.values():
+        for mb in self.mb_description_buffer.values():
             if mb.state != Status.DONE:
                 return False
         return True
 
     def clear(self):
-        self.mb_descrption_buffer.clear()
+        self.mb_description_buffer.clear()
         for cache in self.cache_manager_list:
             cache.free_all()
 
     def next(self):
         self.idx = (self.idx + 1) % self.buffer_size
 
-    def _remove_descrption(self):
-        self.mb_descrption_buffer.pop(self.idx)
+    def _remove_description(self):
+        self.mb_description_buffer.pop(self.idx)
 
     @property
-    def cur_descrption(self) -> MicroBatchDescription:
-        return self.mb_descrption_buffer.get(self.idx)
+    def cur_description(self) -> MicroBatchDescription:
+        return self.mb_description_buffer.get(self.idx)
 
     @property
     def cur_infer_state(self):
-        if self.cur_descrption is None:
+        if self.cur_description is None:
             return None
-        return self.cur_descrption.infer_state
+        return self.cur_description.infer_state
 
     @property
     def cur_state(self):
         """
-        Return the state of current micro batch, when current descrption is None, the state is PREFILL
+        Return the state of current micro batch, when current description is None, the state is PREFILL
 
         """
-        if self.cur_descrption is None:
+        if self.cur_description is None:
             return Status.PREFILL
-        return self.cur_descrption.state
+        return self.cur_description.state
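Note for reviewers (not part of the patch): the sketch below shows how the renamed methods might be driven from a first-stage pipeline loop. It is a minimal illustration under stated assumptions: the construction of `mb_manager` (its `__init__` is not shown in this diff), the `model_step` callable, and the assumption that the number of micro batches matches the manager's buffer size are all hypothetical; only the method and property names (`add_description`, `step`, `next`, `is_micro_batch_done`, `export_new_tokens`, `clear`) come from the code above.

from typing import Dict, List

import torch


def drive_first_stage(mb_manager, micro_batches: List[Dict[str, torch.Tensor]], model_step) -> List:
    # Register each micro batch in its slot of the manager's round-robin buffer
    # (assumes len(micro_batches) matches the manager's buffer_size).
    for inputs_dict in micro_batches:
        mb_manager.add_description(inputs_dict)
        mb_manager.next()

    # Keep stepping until every micro batch description reaches Status.DONE.
    while not mb_manager.is_micro_batch_done():
        new_token = model_step()    # assumed to return the token tensor for the current slot
        mb_manager.step(new_token)  # updates cur_description and returns its state
        mb_manager.next()

    tokens = mb_manager.export_new_tokens()  # flattened list of generated tokens
    mb_manager.clear()                       # releases the per-slot cache managers
    return tokens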