[hotfix] fix typo change _descrption to _description (#5331)

pull/5335/head^2
digger yu 2024-03-05 21:47:48 +08:00 committed by GitHub
parent 70cce5cbed
commit 16c96d4d8c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 60 additions and 60 deletions

View File

@ -17,8 +17,8 @@ class Status(Enum):
class MicroBatchDescription: class MicroBatchDescription:
""" """
This is the class to record the infomation of each microbatch, and also do some update operation. This is the class to record the information of each microbatch, and also do some update operation.
This clase is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more This class is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more
details, please refer to the doc of these two classes blow. details, please refer to the doc of these two classes blow.
Args: Args:
@ -61,15 +61,15 @@ class MicroBatchDescription:
@property @property
def cur_length(self): def cur_length(self):
""" """
Return the current sequnence length of micro batch Return the current sequence length of micro batch
""" """
class HeadMicroBatchDescription(MicroBatchDescription): class HeadMicroBatchDescription(MicroBatchDescription):
""" """
This class is used to record the infomation of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask` This class is used to record the information of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask`
and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schdule of pipeline, the operation to update the and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schedule of pipeline, the operation to update the
information and the condition to determine the state is different from other stages. information and the condition to determine the state is different from other stages.
Args: Args:
@ -123,7 +123,7 @@ class HeadMicroBatchDescription(MicroBatchDescription):
class BodyMicroBatchDescription(MicroBatchDescription): class BodyMicroBatchDescription(MicroBatchDescription):
""" """
This class is used to record the infomation of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`, This class is used to record the information of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`,
Args: Args:
inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hiddenstates from previous stage. inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hiddenstates from previous stage.
@ -173,76 +173,76 @@ class MicroBatchManager:
self.max_input_len = max_input_len self.max_input_len = max_input_len
self.max_output_len = max_output_len self.max_output_len = max_output_len
self.cache_manager_list = cache_manager_list self.cache_manager_list = cache_manager_list
self.mb_descrption_buffer = {} self.mb_description_buffer = {}
self.new_tokens_buffer = {} self.new_tokens_buffer = {}
self.idx = 0 self.idx = 0
def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]): def add_description(self, inputs_dict: Dict[str, torch.Tensor]):
if self.stage == 0: if self.stage == 0:
self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription( self.mb_description_buffer[self.idx] = HeadMicroBatchDescription(
inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx] inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
) )
else: else:
self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription( self.mb_description_buffer[self.idx] = BodyMicroBatchDescription(
inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx] inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
) )
def step(self, new_token: torch.Tensor = None): def step(self, new_token: torch.Tensor = None):
""" """
Update the state if microbatch manager, 2 conditions. Update the state if microbatch manager, 2 conditions.
1. For first stage in PREFILL, receive inputs and outputs, `_add_descrption` will save its inputs. 1. For first stage in PREFILL, receive inputs and outputs, `_add_description` will save its inputs.
2. For other conditon, only receive the output of previous stage, and update the descrption. 2. For other condition, only receive the output of previous stage, and update the description.
Args: Args:
inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`. inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`.
output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
new_token (torch.Tensor): the new token generated by current stage. new_token (torch.Tensor): the new token generated by current stage.
""" """
# Add descrption first if the descrption is None # Add description first if the description is None
self.cur_descrption.update(new_token) self.cur_description.update(new_token)
return self.cur_state return self.cur_state
def export_new_tokens(self): def export_new_tokens(self):
new_tokens_list = [] new_tokens_list = []
for i in self.mb_descrption_buffer.values(): for i in self.mb_description_buffer.values():
new_tokens_list.extend(i.new_tokens.tolist()) new_tokens_list.extend(i.new_tokens.tolist())
return new_tokens_list return new_tokens_list
def is_micro_batch_done(self): def is_micro_batch_done(self):
if len(self.mb_descrption_buffer) == 0: if len(self.mb_description_buffer) == 0:
return False return False
for mb in self.mb_descrption_buffer.values(): for mb in self.mb_description_buffer.values():
if mb.state != Status.DONE: if mb.state != Status.DONE:
return False return False
return True return True
def clear(self): def clear(self):
self.mb_descrption_buffer.clear() self.mb_description_buffer.clear()
for cache in self.cache_manager_list: for cache in self.cache_manager_list:
cache.free_all() cache.free_all()
def next(self): def next(self):
self.idx = (self.idx + 1) % self.buffer_size self.idx = (self.idx + 1) % self.buffer_size
def _remove_descrption(self): def _remove_description(self):
self.mb_descrption_buffer.pop(self.idx) self.mb_description_buffer.pop(self.idx)
@property @property
def cur_descrption(self) -> MicroBatchDescription: def cur_description(self) -> MicroBatchDescription:
return self.mb_descrption_buffer.get(self.idx) return self.mb_description_buffer.get(self.idx)
@property @property
def cur_infer_state(self): def cur_infer_state(self):
if self.cur_descrption is None: if self.cur_description is None:
return None return None
return self.cur_descrption.infer_state return self.cur_description.infer_state
@property @property
def cur_state(self): def cur_state(self):
""" """
Return the state of current micro batch, when current descrption is None, the state is PREFILL Return the state of current micro batch, when current description is None, the state is PREFILL
""" """
if self.cur_descrption is None: if self.cur_description is None:
return Status.PREFILL return Status.PREFILL
return self.cur_descrption.state return self.cur_description.state

View File

@ -18,8 +18,8 @@ class Status(Enum):
class MicroBatchDescription: class MicroBatchDescription:
""" """
This is the class to record the infomation of each microbatch, and also do some update operation. This is the class to record the information of each microbatch, and also do some update operation.
This clase is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more This class is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more
details, please refer to the doc of these two classes blow. details, please refer to the doc of these two classes blow.
Args: Args:
@ -62,15 +62,15 @@ class MicroBatchDescription:
@property @property
def cur_length(self): def cur_length(self):
""" """
Return the current sequnence length of micro batch Return the current sequence length of micro batch
""" """
class HeadMicroBatchDescription(MicroBatchDescription): class HeadMicroBatchDescription(MicroBatchDescription):
""" """
This class is used to record the infomation of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask` This class is used to record the information of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask`
and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schdule of pipeline, the operation to update the and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schedule of pipeline, the operation to update the
information and the condition to determine the state is different from other stages. information and the condition to determine the state is different from other stages.
Args: Args:
@ -124,7 +124,7 @@ class HeadMicroBatchDescription(MicroBatchDescription):
class BodyMicroBatchDescription(MicroBatchDescription): class BodyMicroBatchDescription(MicroBatchDescription):
""" """
This class is used to record the infomation of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`, This class is used to record the information of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`,
Args: Args:
inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hiddenstates from previous stage. inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hiddenstates from previous stage.
@ -174,76 +174,76 @@ class MicroBatchManager:
self.max_input_len = max_input_len self.max_input_len = max_input_len
self.max_output_len = max_output_len self.max_output_len = max_output_len
self.cache_manager_list = cache_manager_list self.cache_manager_list = cache_manager_list
self.mb_descrption_buffer = {} self.mb_description_buffer = {}
self.new_tokens_buffer = {} self.new_tokens_buffer = {}
self.idx = 0 self.idx = 0
def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]): def add_description(self, inputs_dict: Dict[str, torch.Tensor]):
if self.stage == 0: if self.stage == 0:
self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription( self.mb_description_buffer[self.idx] = HeadMicroBatchDescription(
inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx] inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
) )
else: else:
self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription( self.mb_description_buffer[self.idx] = BodyMicroBatchDescription(
inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx] inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
) )
def step(self, new_token: torch.Tensor = None): def step(self, new_token: torch.Tensor = None):
""" """
Update the state if microbatch manager, 2 conditions. Update the state if microbatch manager, 2 conditions.
1. For first stage in PREFILL, receive inputs and outputs, `_add_descrption` will save its inputs. 1. For first stage in PREFILL, receive inputs and outputs, `_add_description` will save its inputs.
2. For other conditon, only receive the output of previous stage, and update the descrption. 2. For other condition, only receive the output of previous stage, and update the description.
Args: Args:
inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`. inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`.
output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
new_token (torch.Tensor): the new token generated by current stage. new_token (torch.Tensor): the new token generated by current stage.
""" """
# Add descrption first if the descrption is None # Add description first if the description is None
self.cur_descrption.update(new_token) self.cur_description.update(new_token)
return self.cur_state return self.cur_state
def export_new_tokens(self): def export_new_tokens(self):
new_tokens_list = [] new_tokens_list = []
for i in self.mb_descrption_buffer.values(): for i in self.mb_description_buffer.values():
new_tokens_list.extend(i.new_tokens.tolist()) new_tokens_list.extend(i.new_tokens.tolist())
return new_tokens_list return new_tokens_list
def is_micro_batch_done(self): def is_micro_batch_done(self):
if len(self.mb_descrption_buffer) == 0: if len(self.mb_description_buffer) == 0:
return False return False
for mb in self.mb_descrption_buffer.values(): for mb in self.mb_description_buffer.values():
if mb.state != Status.DONE: if mb.state != Status.DONE:
return False return False
return True return True
def clear(self): def clear(self):
self.mb_descrption_buffer.clear() self.mb_description_buffer.clear()
for cache in self.cache_manager_list: for cache in self.cache_manager_list:
cache.free_all() cache.free_all()
def next(self): def next(self):
self.idx = (self.idx + 1) % self.buffer_size self.idx = (self.idx + 1) % self.buffer_size
def _remove_descrption(self): def _remove_description(self):
self.mb_descrption_buffer.pop(self.idx) self.mb_description_buffer.pop(self.idx)
@property @property
def cur_descrption(self) -> MicroBatchDescription: def cur_description(self) -> MicroBatchDescription:
return self.mb_descrption_buffer.get(self.idx) return self.mb_description_buffer.get(self.idx)
@property @property
def cur_infer_state(self): def cur_infer_state(self):
if self.cur_descrption is None: if self.cur_description is None:
return None return None
return self.cur_descrption.infer_state return self.cur_description.infer_state
@property @property
def cur_state(self): def cur_state(self):
""" """
Return the state of current micro batch, when current descrption is None, the state is PREFILL Return the state of current micro batch, when current description is None, the state is PREFILL
""" """
if self.cur_descrption is None: if self.cur_description is None:
return Status.PREFILL return Status.PREFILL
return self.cur_descrption.state return self.cur_description.state

View File

@ -95,7 +95,7 @@ class GenerateSchedule(PipelineSchedule):
Returns: Returns:
dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None` dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None`
""" """
model_inputs = {"infer_state": self.mb_manager.cur_descrption.infer_state} model_inputs = {"infer_state": self.mb_manager.cur_description.infer_state}
return model_inputs return model_inputs
def _prepare_inputs_for_new_token(self, new_token: torch.Tensor): def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
@ -107,7 +107,7 @@ class GenerateSchedule(PipelineSchedule):
Returns: Returns:
dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}` dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}`
""" """
new_mask = self.mb_manager.cur_descrption.attn_mask new_mask = self.mb_manager.cur_description.attn_mask
return dict(input_ids=new_token, attention_mask=new_mask) return dict(input_ids=new_token, attention_mask=new_mask)
@ -133,7 +133,7 @@ class GenerateSchedule(PipelineSchedule):
1.Load micro_batch 2.Use the current micro_batch to init the current infer_state 1.Load micro_batch 2.Use the current micro_batch to init the current infer_state
""" """
inputs_dict = self.load_micro_batch() inputs_dict = self.load_micro_batch()
self.mb_manager.add_descrption(inputs_dict) self.mb_manager.add_description(inputs_dict)
def _load_stage_action(self, model: Module) -> None: def _load_stage_action(self, model: Module) -> None:
""" """
@ -141,7 +141,7 @@ class GenerateSchedule(PipelineSchedule):
1.load micro_batch 2.do the forward 3.step to update 1.load micro_batch 2.do the forward 3.step to update
""" """
inputs_dict = self.load_micro_batch() inputs_dict = self.load_micro_batch()
self.mb_manager.add_descrption(inputs_dict) self.mb_manager.add_description(inputs_dict)
if self.verbose and self.stage_manager.is_first_stage(): if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize() torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time()) self.timestamps[self.mb_manager.idx].append(time.time())
@ -379,7 +379,7 @@ class GenerateSchedule(PipelineSchedule):
if self.verbose and self.stage_manager.is_first_stage(): if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize() torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time()) self.timestamps[self.mb_manager.idx].append(time.time())
self.mb_manager.add_descrption(inputs_dict) self.mb_manager.add_description(inputs_dict)
interval_inputs = {"infer_state": self.mb_manager.cur_infer_state} interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs) output_dict = model_forward(model, inputs_dict, interval_inputs)
# In GENERATE phase # In GENERATE phase
@ -415,7 +415,7 @@ class GenerateSchedule(PipelineSchedule):
inputs_dict = None inputs_dict = None
if self.mb_manager.cur_state is Status.PREFILL: if self.mb_manager.cur_state is Status.PREFILL:
inputs_dict = self.load_micro_batch() inputs_dict = self.load_micro_batch()
self.mb_manager.add_descrption(inputs_dict) self.mb_manager.add_description(inputs_dict)
interval_inputs = { interval_inputs = {
"hidden_states": hidden_states["hidden_states"], "hidden_states": hidden_states["hidden_states"],
"infer_state": self.mb_manager.cur_infer_state, "infer_state": self.mb_manager.cur_infer_state,