
[hotfix] fix typo change _descrption to _description (#5331)

digger yu, 9 months ago (committed by GitHub)
Parent commit: 16c96d4d8c
3 changed files:
1. colossalai/inference/engine/microbatch_manager.py (54 changed lines)
2. colossalai/legacy/inference/pipeline/microbatch_manager.py (54 changed lines)
3. colossalai/pipeline/schedule/generate.py (12 changed lines)

colossalai/inference/engine/microbatch_manager.py (54 changed lines)

@@ -17,8 +17,8 @@ class Status(Enum):
 class MicroBatchDescription:
     """
-    This is the class to record the infomation of each microbatch, and also do some update operation.
-    This clase is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more
+    This is the class to record the information of each microbatch, and also do some update operation.
+    This class is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more
     details, please refer to the doc of these two classes blow.

     Args:
@@ -61,15 +61,15 @@ class MicroBatchDescription:
     @property
     def cur_length(self):
         """
-        Return the current sequnence length of micro batch
+        Return the current sequence length of micro batch
         """


 class HeadMicroBatchDescription(MicroBatchDescription):
     """
-    This class is used to record the infomation of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask`
-    and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schdule of pipeline, the operation to update the
+    This class is used to record the information of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask`
+    and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schedule of pipeline, the operation to update the
     information and the condition to determine the state is different from other stages.

     Args:
@@ -123,7 +123,7 @@ class HeadMicroBatchDescription(MicroBatchDescription):
 class BodyMicroBatchDescription(MicroBatchDescription):
     """
-    This class is used to record the infomation of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`,
+    This class is used to record the information of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`,

     Args:
         inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hiddenstates from previous stage.
@@ -173,76 +173,76 @@ class MicroBatchManager:
         self.max_input_len = max_input_len
         self.max_output_len = max_output_len
         self.cache_manager_list = cache_manager_list
-        self.mb_descrption_buffer = {}
+        self.mb_description_buffer = {}
         self.new_tokens_buffer = {}
         self.idx = 0

-    def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]):
+    def add_description(self, inputs_dict: Dict[str, torch.Tensor]):
         if self.stage == 0:
-            self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(
+            self.mb_description_buffer[self.idx] = HeadMicroBatchDescription(
                 inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
             )
         else:
-            self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(
+            self.mb_description_buffer[self.idx] = BodyMicroBatchDescription(
                 inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
             )

     def step(self, new_token: torch.Tensor = None):
         """
         Update the state if microbatch manager, 2 conditions.
-        1. For first stage in PREFILL, receive inputs and outputs, `_add_descrption` will save its inputs.
-        2. For other conditon, only receive the output of previous stage, and update the descrption.
+        1. For first stage in PREFILL, receive inputs and outputs, `_add_description` will save its inputs.
+        2. For other condition, only receive the output of previous stage, and update the description.

         Args:
             inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`.
             output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
             new_token (torch.Tensor): the new token generated by current stage.
         """
-        # Add descrption first if the descrption is None
-        self.cur_descrption.update(new_token)
+        # Add description first if the description is None
+        self.cur_description.update(new_token)
         return self.cur_state

     def export_new_tokens(self):
         new_tokens_list = []
-        for i in self.mb_descrption_buffer.values():
+        for i in self.mb_description_buffer.values():
             new_tokens_list.extend(i.new_tokens.tolist())
         return new_tokens_list

     def is_micro_batch_done(self):
-        if len(self.mb_descrption_buffer) == 0:
+        if len(self.mb_description_buffer) == 0:
             return False
-        for mb in self.mb_descrption_buffer.values():
+        for mb in self.mb_description_buffer.values():
             if mb.state != Status.DONE:
                 return False
         return True

     def clear(self):
-        self.mb_descrption_buffer.clear()
+        self.mb_description_buffer.clear()
         for cache in self.cache_manager_list:
             cache.free_all()

     def next(self):
         self.idx = (self.idx + 1) % self.buffer_size

-    def _remove_descrption(self):
-        self.mb_descrption_buffer.pop(self.idx)
+    def _remove_description(self):
+        self.mb_description_buffer.pop(self.idx)

     @property
-    def cur_descrption(self) -> MicroBatchDescription:
-        return self.mb_descrption_buffer.get(self.idx)
+    def cur_description(self) -> MicroBatchDescription:
+        return self.mb_description_buffer.get(self.idx)

     @property
     def cur_infer_state(self):
-        if self.cur_descrption is None:
+        if self.cur_description is None:
             return None
-        return self.cur_descrption.infer_state
+        return self.cur_description.infer_state

     @property
     def cur_state(self):
         """
-        Return the state of current micro batch, when current descrption is None, the state is PREFILL
+        Return the state of current micro batch, when current description is None, the state is PREFILL
         """
-        if self.cur_descrption is None:
+        if self.cur_description is None:
             return Status.PREFILL
-        return self.cur_descrption.state
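
For orientation, a minimal sketch of how the renamed members are meant to be driven. The loop structure and the `model_step` callable are hypothetical; only `MicroBatchManager`, `Status`, and the renamed methods come from the file above.

```python
from typing import Callable, Dict

import torch

from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status


def run_microbatches(
    manager: MicroBatchManager,
    inputs_dict: Dict[str, torch.Tensor],
    model_step: Callable,  # hypothetical: produces the next token from an infer state
) -> list:
    # Drive the buffered micro batches until every description reports DONE.
    while not manager.is_micro_batch_done():
        if manager.cur_state is Status.PREFILL:
            manager.add_description(inputs_dict)  # formerly add_descrption
        new_token = model_step(manager.cur_infer_state)
        manager.step(new_token)  # updates cur_description and returns its state
        manager.next()  # rotate to the next slot in mb_description_buffer
    tokens = manager.export_new_tokens()
    manager.clear()  # drops all descriptions and frees the KV caches
    return tokens
```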

colossalai/legacy/inference/pipeline/microbatch_manager.py (54 changed lines)

@@ -18,8 +18,8 @@ class Status(Enum):
 class MicroBatchDescription:
     """
-    This is the class to record the infomation of each microbatch, and also do some update operation.
-    This clase is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more
+    This is the class to record the information of each microbatch, and also do some update operation.
+    This class is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more
     details, please refer to the doc of these two classes blow.

     Args:
@@ -62,15 +62,15 @@ class MicroBatchDescription:
     @property
     def cur_length(self):
         """
-        Return the current sequnence length of micro batch
+        Return the current sequence length of micro batch
         """


 class HeadMicroBatchDescription(MicroBatchDescription):
     """
-    This class is used to record the infomation of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask`
-    and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schdule of pipeline, the operation to update the
+    This class is used to record the information of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask`
+    and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schedule of pipeline, the operation to update the
     information and the condition to determine the state is different from other stages.

     Args:
@@ -124,7 +124,7 @@ class HeadMicroBatchDescription(MicroBatchDescription):
 class BodyMicroBatchDescription(MicroBatchDescription):
     """
-    This class is used to record the infomation of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`,
+    This class is used to record the information of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`,

     Args:
         inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hiddenstates from previous stage.
@@ -174,76 +174,76 @@ class MicroBatchManager:
         self.max_input_len = max_input_len
         self.max_output_len = max_output_len
         self.cache_manager_list = cache_manager_list
-        self.mb_descrption_buffer = {}
+        self.mb_description_buffer = {}
         self.new_tokens_buffer = {}
         self.idx = 0

-    def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]):
+    def add_description(self, inputs_dict: Dict[str, torch.Tensor]):
         if self.stage == 0:
-            self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(
+            self.mb_description_buffer[self.idx] = HeadMicroBatchDescription(
                 inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
             )
         else:
-            self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(
+            self.mb_description_buffer[self.idx] = BodyMicroBatchDescription(
                 inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
            )

     def step(self, new_token: torch.Tensor = None):
         """
         Update the state if microbatch manager, 2 conditions.
-        1. For first stage in PREFILL, receive inputs and outputs, `_add_descrption` will save its inputs.
-        2. For other conditon, only receive the output of previous stage, and update the descrption.
+        1. For first stage in PREFILL, receive inputs and outputs, `_add_description` will save its inputs.
+        2. For other condition, only receive the output of previous stage, and update the description.

         Args:
             inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`.
             output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
             new_token (torch.Tensor): the new token generated by current stage.
         """
-        # Add descrption first if the descrption is None
-        self.cur_descrption.update(new_token)
+        # Add description first if the description is None
+        self.cur_description.update(new_token)
         return self.cur_state

     def export_new_tokens(self):
         new_tokens_list = []
-        for i in self.mb_descrption_buffer.values():
+        for i in self.mb_description_buffer.values():
             new_tokens_list.extend(i.new_tokens.tolist())
         return new_tokens_list

     def is_micro_batch_done(self):
-        if len(self.mb_descrption_buffer) == 0:
+        if len(self.mb_description_buffer) == 0:
             return False
-        for mb in self.mb_descrption_buffer.values():
+        for mb in self.mb_description_buffer.values():
             if mb.state != Status.DONE:
                 return False
         return True

     def clear(self):
-        self.mb_descrption_buffer.clear()
+        self.mb_description_buffer.clear()
         for cache in self.cache_manager_list:
             cache.free_all()

     def next(self):
         self.idx = (self.idx + 1) % self.buffer_size

-    def _remove_descrption(self):
-        self.mb_descrption_buffer.pop(self.idx)
+    def _remove_description(self):
+        self.mb_description_buffer.pop(self.idx)

     @property
-    def cur_descrption(self) -> MicroBatchDescription:
-        return self.mb_descrption_buffer.get(self.idx)
+    def cur_description(self) -> MicroBatchDescription:
+        return self.mb_description_buffer.get(self.idx)

     @property
     def cur_infer_state(self):
-        if self.cur_descrption is None:
+        if self.cur_description is None:
             return None
-        return self.cur_descrption.infer_state
+        return self.cur_description.infer_state

     @property
     def cur_state(self):
         """
-        Return the state of current micro batch, when current descrption is None, the state is PREFILL
+        Return the state of current micro batch, when current description is None, the state is PREFILL
         """
-        if self.cur_descrption is None:
+        if self.cur_description is None:
             return Status.PREFILL
-        return self.cur_descrption.state

colossalai/pipeline/schedule/generate.py (12 changed lines)

@@ -95,7 +95,7 @@ class GenerateSchedule(PipelineSchedule):
         Returns:
             dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None`
         """
-        model_inputs = {"infer_state": self.mb_manager.cur_descrption.infer_state}
+        model_inputs = {"infer_state": self.mb_manager.cur_description.infer_state}
         return model_inputs

     def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
@@ -107,7 +107,7 @@ class GenerateSchedule(PipelineSchedule):
         Returns:
             dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}`
         """
-        new_mask = self.mb_manager.cur_descrption.attn_mask
+        new_mask = self.mb_manager.cur_description.attn_mask

         return dict(input_ids=new_token, attention_mask=new_mask)
@@ -133,7 +133,7 @@ class GenerateSchedule(PipelineSchedule):
         1.Load micro_batch 2.Use the current micro_batch to init the current infer_state
         """
         inputs_dict = self.load_micro_batch()
-        self.mb_manager.add_descrption(inputs_dict)
+        self.mb_manager.add_description(inputs_dict)

     def _load_stage_action(self, model: Module) -> None:
         """
@@ -141,7 +141,7 @@ class GenerateSchedule(PipelineSchedule):
         1.load micro_batch 2.do the forward 3.step to update
         """
         inputs_dict = self.load_micro_batch()
-        self.mb_manager.add_descrption(inputs_dict)
+        self.mb_manager.add_description(inputs_dict)
         if self.verbose and self.stage_manager.is_first_stage():
             torch.cuda.synchronize()
             self.timestamps[self.mb_manager.idx].append(time.time())
@@ -379,7 +379,7 @@ class GenerateSchedule(PipelineSchedule):
             if self.verbose and self.stage_manager.is_first_stage():
                 torch.cuda.synchronize()
                 self.timestamps[self.mb_manager.idx].append(time.time())
-            self.mb_manager.add_descrption(inputs_dict)
+            self.mb_manager.add_description(inputs_dict)
             interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
             output_dict = model_forward(model, inputs_dict, interval_inputs)
         # In GENERATE phase
@@ -415,7 +415,7 @@ class GenerateSchedule(PipelineSchedule):
             inputs_dict = None
             if self.mb_manager.cur_state is Status.PREFILL:
                 inputs_dict = self.load_micro_batch()
-                self.mb_manager.add_descrption(inputs_dict)
+                self.mb_manager.add_description(inputs_dict)
             interval_inputs = {
                 "hidden_states": hidden_states["hidden_states"],
                 "infer_state": self.mb_manager.cur_infer_state,
