diff --git a/colossalai/inference/engine/microbatch_manager.py b/colossalai/inference/engine/microbatch_manager.py index d698c89f9..7264b81e0 100644 --- a/colossalai/inference/engine/microbatch_manager.py +++ b/colossalai/inference/engine/microbatch_manager.py @@ -17,8 +17,8 @@ class Status(Enum): class MicroBatchDescription: """ - This is the class to record the infomation of each microbatch, and also do some update operation. - This clase is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more + This is the class to record the information of each microbatch, and also do some update operation. + This class is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more details, please refer to the doc of these two classes blow. Args: @@ -61,15 +61,15 @@ class MicroBatchDescription: @property def cur_length(self): """ - Return the current sequnence length of micro batch + Return the current sequence length of micro batch """ class HeadMicroBatchDescription(MicroBatchDescription): """ - This class is used to record the infomation of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask` - and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schdule of pipeline, the operation to update the + This class is used to record the information of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask` + and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schedule of pipeline, the operation to update the information and the condition to determine the state is different from other stages. Args: @@ -123,7 +123,7 @@ class HeadMicroBatchDescription(MicroBatchDescription): class BodyMicroBatchDescription(MicroBatchDescription): """ - This class is used to record the infomation of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`, + This class is used to record the information of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`, Args: inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hiddenstates from previous stage. @@ -173,76 +173,76 @@ class MicroBatchManager: self.max_input_len = max_input_len self.max_output_len = max_output_len self.cache_manager_list = cache_manager_list - self.mb_descrption_buffer = {} + self.mb_description_buffer = {} self.new_tokens_buffer = {} self.idx = 0 - def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]): + def add_description(self, inputs_dict: Dict[str, torch.Tensor]): if self.stage == 0: - self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription( + self.mb_description_buffer[self.idx] = HeadMicroBatchDescription( inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx] ) else: - self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription( + self.mb_description_buffer[self.idx] = BodyMicroBatchDescription( inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx] ) def step(self, new_token: torch.Tensor = None): """ Update the state if microbatch manager, 2 conditions. - 1. For first stage in PREFILL, receive inputs and outputs, `_add_descrption` will save its inputs. - 2. For other conditon, only receive the output of previous stage, and update the descrption. + 1. For first stage in PREFILL, receive inputs and outputs, `_add_description` will save its inputs. + 2. For other condition, only receive the output of previous stage, and update the description. Args: inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`. output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. new_token (torch.Tensor): the new token generated by current stage. """ - # Add descrption first if the descrption is None - self.cur_descrption.update(new_token) + # Add description first if the description is None + self.cur_description.update(new_token) return self.cur_state def export_new_tokens(self): new_tokens_list = [] - for i in self.mb_descrption_buffer.values(): + for i in self.mb_description_buffer.values(): new_tokens_list.extend(i.new_tokens.tolist()) return new_tokens_list def is_micro_batch_done(self): - if len(self.mb_descrption_buffer) == 0: + if len(self.mb_description_buffer) == 0: return False - for mb in self.mb_descrption_buffer.values(): + for mb in self.mb_description_buffer.values(): if mb.state != Status.DONE: return False return True def clear(self): - self.mb_descrption_buffer.clear() + self.mb_description_buffer.clear() for cache in self.cache_manager_list: cache.free_all() def next(self): self.idx = (self.idx + 1) % self.buffer_size - def _remove_descrption(self): - self.mb_descrption_buffer.pop(self.idx) + def _remove_description(self): + self.mb_description_buffer.pop(self.idx) @property - def cur_descrption(self) -> MicroBatchDescription: - return self.mb_descrption_buffer.get(self.idx) + def cur_description(self) -> MicroBatchDescription: + return self.mb_description_buffer.get(self.idx) @property def cur_infer_state(self): - if self.cur_descrption is None: + if self.cur_description is None: return None - return self.cur_descrption.infer_state + return self.cur_description.infer_state @property def cur_state(self): """ - Return the state of current micro batch, when current descrption is None, the state is PREFILL + Return the state of current micro batch, when current description is None, the state is PREFILL """ - if self.cur_descrption is None: + if self.cur_description is None: return Status.PREFILL - return self.cur_descrption.state + return self.cur_description.state diff --git a/colossalai/legacy/inference/pipeline/microbatch_manager.py b/colossalai/legacy/inference/pipeline/microbatch_manager.py index 441cf6039..cb0a8c1a9 100644 --- a/colossalai/legacy/inference/pipeline/microbatch_manager.py +++ b/colossalai/legacy/inference/pipeline/microbatch_manager.py @@ -18,8 +18,8 @@ class Status(Enum): class MicroBatchDescription: """ - This is the class to record the infomation of each microbatch, and also do some update operation. - This clase is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more + This is the class to record the information of each microbatch, and also do some update operation. + This class is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more details, please refer to the doc of these two classes blow. Args: @@ -62,15 +62,15 @@ class MicroBatchDescription: @property def cur_length(self): """ - Return the current sequnence length of micro batch + Return the current sequence length of micro batch """ class HeadMicroBatchDescription(MicroBatchDescription): """ - This class is used to record the infomation of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask` - and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schdule of pipeline, the operation to update the + This class is used to record the information of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask` + and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schedule of pipeline, the operation to update the information and the condition to determine the state is different from other stages. Args: @@ -124,7 +124,7 @@ class HeadMicroBatchDescription(MicroBatchDescription): class BodyMicroBatchDescription(MicroBatchDescription): """ - This class is used to record the infomation of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`, + This class is used to record the information of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`, Args: inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hiddenstates from previous stage. @@ -174,76 +174,76 @@ class MicroBatchManager: self.max_input_len = max_input_len self.max_output_len = max_output_len self.cache_manager_list = cache_manager_list - self.mb_descrption_buffer = {} + self.mb_description_buffer = {} self.new_tokens_buffer = {} self.idx = 0 - def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]): + def add_description(self, inputs_dict: Dict[str, torch.Tensor]): if self.stage == 0: - self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription( + self.mb_description_buffer[self.idx] = HeadMicroBatchDescription( inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx] ) else: - self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription( + self.mb_description_buffer[self.idx] = BodyMicroBatchDescription( inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx] ) def step(self, new_token: torch.Tensor = None): """ Update the state if microbatch manager, 2 conditions. - 1. For first stage in PREFILL, receive inputs and outputs, `_add_descrption` will save its inputs. - 2. For other conditon, only receive the output of previous stage, and update the descrption. + 1. For first stage in PREFILL, receive inputs and outputs, `_add_description` will save its inputs. + 2. For other condition, only receive the output of previous stage, and update the description. Args: inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`. output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. new_token (torch.Tensor): the new token generated by current stage. """ - # Add descrption first if the descrption is None - self.cur_descrption.update(new_token) + # Add description first if the description is None + self.cur_description.update(new_token) return self.cur_state def export_new_tokens(self): new_tokens_list = [] - for i in self.mb_descrption_buffer.values(): + for i in self.mb_description_buffer.values(): new_tokens_list.extend(i.new_tokens.tolist()) return new_tokens_list def is_micro_batch_done(self): - if len(self.mb_descrption_buffer) == 0: + if len(self.mb_description_buffer) == 0: return False - for mb in self.mb_descrption_buffer.values(): + for mb in self.mb_description_buffer.values(): if mb.state != Status.DONE: return False return True def clear(self): - self.mb_descrption_buffer.clear() + self.mb_description_buffer.clear() for cache in self.cache_manager_list: cache.free_all() def next(self): self.idx = (self.idx + 1) % self.buffer_size - def _remove_descrption(self): - self.mb_descrption_buffer.pop(self.idx) + def _remove_description(self): + self.mb_description_buffer.pop(self.idx) @property - def cur_descrption(self) -> MicroBatchDescription: - return self.mb_descrption_buffer.get(self.idx) + def cur_description(self) -> MicroBatchDescription: + return self.mb_description_buffer.get(self.idx) @property def cur_infer_state(self): - if self.cur_descrption is None: + if self.cur_description is None: return None - return self.cur_descrption.infer_state + return self.cur_description.infer_state @property def cur_state(self): """ - Return the state of current micro batch, when current descrption is None, the state is PREFILL + Return the state of current micro batch, when current description is None, the state is PREFILL """ - if self.cur_descrption is None: + if self.cur_description is None: return Status.PREFILL - return self.cur_descrption.state + return self.cur_description.state diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py index 20f316c2a..d6a6aec63 100644 --- a/colossalai/pipeline/schedule/generate.py +++ b/colossalai/pipeline/schedule/generate.py @@ -95,7 +95,7 @@ class GenerateSchedule(PipelineSchedule): Returns: dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None` """ - model_inputs = {"infer_state": self.mb_manager.cur_descrption.infer_state} + model_inputs = {"infer_state": self.mb_manager.cur_description.infer_state} return model_inputs def _prepare_inputs_for_new_token(self, new_token: torch.Tensor): @@ -107,7 +107,7 @@ class GenerateSchedule(PipelineSchedule): Returns: dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}` """ - new_mask = self.mb_manager.cur_descrption.attn_mask + new_mask = self.mb_manager.cur_description.attn_mask return dict(input_ids=new_token, attention_mask=new_mask) @@ -133,7 +133,7 @@ class GenerateSchedule(PipelineSchedule): 1.Load micro_batch 2.Use the current micro_batch to init the current infer_state """ inputs_dict = self.load_micro_batch() - self.mb_manager.add_descrption(inputs_dict) + self.mb_manager.add_description(inputs_dict) def _load_stage_action(self, model: Module) -> None: """ @@ -141,7 +141,7 @@ class GenerateSchedule(PipelineSchedule): 1.load micro_batch 2.do the forward 3.step to update """ inputs_dict = self.load_micro_batch() - self.mb_manager.add_descrption(inputs_dict) + self.mb_manager.add_description(inputs_dict) if self.verbose and self.stage_manager.is_first_stage(): torch.cuda.synchronize() self.timestamps[self.mb_manager.idx].append(time.time()) @@ -379,7 +379,7 @@ class GenerateSchedule(PipelineSchedule): if self.verbose and self.stage_manager.is_first_stage(): torch.cuda.synchronize() self.timestamps[self.mb_manager.idx].append(time.time()) - self.mb_manager.add_descrption(inputs_dict) + self.mb_manager.add_description(inputs_dict) interval_inputs = {"infer_state": self.mb_manager.cur_infer_state} output_dict = model_forward(model, inputs_dict, interval_inputs) # In GENERATE phase @@ -415,7 +415,7 @@ class GenerateSchedule(PipelineSchedule): inputs_dict = None if self.mb_manager.cur_state is Status.PREFILL: inputs_dict = self.load_micro_batch() - self.mb_manager.add_descrption(inputs_dict) + self.mb_manager.add_description(inputs_dict) interval_inputs = { "hidden_states": hidden_states["hidden_states"], "infer_state": self.mb_manager.cur_infer_state,