mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix typo change _descrption to _description (#5331)
parent
70cce5cbed
commit
16c96d4d8c
|
@ -17,8 +17,8 @@ class Status(Enum):
|
|||
|
||||
class MicroBatchDescription:
|
||||
"""
|
||||
This is the class to record the infomation of each microbatch, and also do some update operation.
|
||||
This clase is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more
|
||||
This is the class to record the information of each microbatch, and also do some update operation.
|
||||
This class is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more
|
||||
details, please refer to the doc of these two classes blow.
|
||||
|
||||
Args:
|
||||
|
@ -61,15 +61,15 @@ class MicroBatchDescription:
|
|||
@property
|
||||
def cur_length(self):
|
||||
"""
|
||||
Return the current sequnence length of micro batch
|
||||
Return the current sequence length of micro batch
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class HeadMicroBatchDescription(MicroBatchDescription):
|
||||
"""
|
||||
This class is used to record the infomation of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask`
|
||||
and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schdule of pipeline, the operation to update the
|
||||
This class is used to record the information of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask`
|
||||
and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schedule of pipeline, the operation to update the
|
||||
information and the condition to determine the state is different from other stages.
|
||||
|
||||
Args:
|
||||
|
@ -123,7 +123,7 @@ class HeadMicroBatchDescription(MicroBatchDescription):
|
|||
|
||||
class BodyMicroBatchDescription(MicroBatchDescription):
|
||||
"""
|
||||
This class is used to record the infomation of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`,
|
||||
This class is used to record the information of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`,
|
||||
|
||||
Args:
|
||||
inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hiddenstates from previous stage.
|
||||
|
@ -173,76 +173,76 @@ class MicroBatchManager:
|
|||
self.max_input_len = max_input_len
|
||||
self.max_output_len = max_output_len
|
||||
self.cache_manager_list = cache_manager_list
|
||||
self.mb_descrption_buffer = {}
|
||||
self.mb_description_buffer = {}
|
||||
self.new_tokens_buffer = {}
|
||||
self.idx = 0
|
||||
|
||||
def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]):
|
||||
def add_description(self, inputs_dict: Dict[str, torch.Tensor]):
|
||||
if self.stage == 0:
|
||||
self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(
|
||||
self.mb_description_buffer[self.idx] = HeadMicroBatchDescription(
|
||||
inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
|
||||
)
|
||||
else:
|
||||
self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(
|
||||
self.mb_description_buffer[self.idx] = BodyMicroBatchDescription(
|
||||
inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
|
||||
)
|
||||
|
||||
def step(self, new_token: torch.Tensor = None):
|
||||
"""
|
||||
Update the state if microbatch manager, 2 conditions.
|
||||
1. For first stage in PREFILL, receive inputs and outputs, `_add_descrption` will save its inputs.
|
||||
2. For other conditon, only receive the output of previous stage, and update the descrption.
|
||||
1. For first stage in PREFILL, receive inputs and outputs, `_add_description` will save its inputs.
|
||||
2. For other condition, only receive the output of previous stage, and update the description.
|
||||
|
||||
Args:
|
||||
inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`.
|
||||
output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
|
||||
new_token (torch.Tensor): the new token generated by current stage.
|
||||
"""
|
||||
# Add descrption first if the descrption is None
|
||||
self.cur_descrption.update(new_token)
|
||||
# Add description first if the description is None
|
||||
self.cur_description.update(new_token)
|
||||
return self.cur_state
|
||||
|
||||
def export_new_tokens(self):
|
||||
new_tokens_list = []
|
||||
for i in self.mb_descrption_buffer.values():
|
||||
for i in self.mb_description_buffer.values():
|
||||
new_tokens_list.extend(i.new_tokens.tolist())
|
||||
return new_tokens_list
|
||||
|
||||
def is_micro_batch_done(self):
|
||||
if len(self.mb_descrption_buffer) == 0:
|
||||
if len(self.mb_description_buffer) == 0:
|
||||
return False
|
||||
for mb in self.mb_descrption_buffer.values():
|
||||
for mb in self.mb_description_buffer.values():
|
||||
if mb.state != Status.DONE:
|
||||
return False
|
||||
return True
|
||||
|
||||
def clear(self):
|
||||
self.mb_descrption_buffer.clear()
|
||||
self.mb_description_buffer.clear()
|
||||
for cache in self.cache_manager_list:
|
||||
cache.free_all()
|
||||
|
||||
def next(self):
|
||||
self.idx = (self.idx + 1) % self.buffer_size
|
||||
|
||||
def _remove_descrption(self):
|
||||
self.mb_descrption_buffer.pop(self.idx)
|
||||
def _remove_description(self):
|
||||
self.mb_description_buffer.pop(self.idx)
|
||||
|
||||
@property
|
||||
def cur_descrption(self) -> MicroBatchDescription:
|
||||
return self.mb_descrption_buffer.get(self.idx)
|
||||
def cur_description(self) -> MicroBatchDescription:
|
||||
return self.mb_description_buffer.get(self.idx)
|
||||
|
||||
@property
|
||||
def cur_infer_state(self):
|
||||
if self.cur_descrption is None:
|
||||
if self.cur_description is None:
|
||||
return None
|
||||
return self.cur_descrption.infer_state
|
||||
return self.cur_description.infer_state
|
||||
|
||||
@property
|
||||
def cur_state(self):
|
||||
"""
|
||||
Return the state of current micro batch, when current descrption is None, the state is PREFILL
|
||||
Return the state of current micro batch, when current description is None, the state is PREFILL
|
||||
|
||||
"""
|
||||
if self.cur_descrption is None:
|
||||
if self.cur_description is None:
|
||||
return Status.PREFILL
|
||||
return self.cur_descrption.state
|
||||
return self.cur_description.state
|
||||
|
|
|
@ -18,8 +18,8 @@ class Status(Enum):
|
|||
|
||||
class MicroBatchDescription:
|
||||
"""
|
||||
This is the class to record the infomation of each microbatch, and also do some update operation.
|
||||
This clase is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more
|
||||
This is the class to record the information of each microbatch, and also do some update operation.
|
||||
This class is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more
|
||||
details, please refer to the doc of these two classes blow.
|
||||
|
||||
Args:
|
||||
|
@ -62,15 +62,15 @@ class MicroBatchDescription:
|
|||
@property
|
||||
def cur_length(self):
|
||||
"""
|
||||
Return the current sequnence length of micro batch
|
||||
Return the current sequence length of micro batch
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class HeadMicroBatchDescription(MicroBatchDescription):
|
||||
"""
|
||||
This class is used to record the infomation of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask`
|
||||
and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schdule of pipeline, the operation to update the
|
||||
This class is used to record the information of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask`
|
||||
and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schedule of pipeline, the operation to update the
|
||||
information and the condition to determine the state is different from other stages.
|
||||
|
||||
Args:
|
||||
|
@ -124,7 +124,7 @@ class HeadMicroBatchDescription(MicroBatchDescription):
|
|||
|
||||
class BodyMicroBatchDescription(MicroBatchDescription):
|
||||
"""
|
||||
This class is used to record the infomation of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`,
|
||||
This class is used to record the information of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`,
|
||||
|
||||
Args:
|
||||
inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hiddenstates from previous stage.
|
||||
|
@ -174,76 +174,76 @@ class MicroBatchManager:
|
|||
self.max_input_len = max_input_len
|
||||
self.max_output_len = max_output_len
|
||||
self.cache_manager_list = cache_manager_list
|
||||
self.mb_descrption_buffer = {}
|
||||
self.mb_description_buffer = {}
|
||||
self.new_tokens_buffer = {}
|
||||
self.idx = 0
|
||||
|
||||
def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]):
|
||||
def add_description(self, inputs_dict: Dict[str, torch.Tensor]):
|
||||
if self.stage == 0:
|
||||
self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(
|
||||
self.mb_description_buffer[self.idx] = HeadMicroBatchDescription(
|
||||
inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
|
||||
)
|
||||
else:
|
||||
self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(
|
||||
self.mb_description_buffer[self.idx] = BodyMicroBatchDescription(
|
||||
inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
|
||||
)
|
||||
|
||||
def step(self, new_token: torch.Tensor = None):
|
||||
"""
|
||||
Update the state if microbatch manager, 2 conditions.
|
||||
1. For first stage in PREFILL, receive inputs and outputs, `_add_descrption` will save its inputs.
|
||||
2. For other conditon, only receive the output of previous stage, and update the descrption.
|
||||
1. For first stage in PREFILL, receive inputs and outputs, `_add_description` will save its inputs.
|
||||
2. For other condition, only receive the output of previous stage, and update the description.
|
||||
|
||||
Args:
|
||||
inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`.
|
||||
output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
|
||||
new_token (torch.Tensor): the new token generated by current stage.
|
||||
"""
|
||||
# Add descrption first if the descrption is None
|
||||
self.cur_descrption.update(new_token)
|
||||
# Add description first if the description is None
|
||||
self.cur_description.update(new_token)
|
||||
return self.cur_state
|
||||
|
||||
def export_new_tokens(self):
|
||||
new_tokens_list = []
|
||||
for i in self.mb_descrption_buffer.values():
|
||||
for i in self.mb_description_buffer.values():
|
||||
new_tokens_list.extend(i.new_tokens.tolist())
|
||||
return new_tokens_list
|
||||
|
||||
def is_micro_batch_done(self):
|
||||
if len(self.mb_descrption_buffer) == 0:
|
||||
if len(self.mb_description_buffer) == 0:
|
||||
return False
|
||||
for mb in self.mb_descrption_buffer.values():
|
||||
for mb in self.mb_description_buffer.values():
|
||||
if mb.state != Status.DONE:
|
||||
return False
|
||||
return True
|
||||
|
||||
def clear(self):
|
||||
self.mb_descrption_buffer.clear()
|
||||
self.mb_description_buffer.clear()
|
||||
for cache in self.cache_manager_list:
|
||||
cache.free_all()
|
||||
|
||||
def next(self):
|
||||
self.idx = (self.idx + 1) % self.buffer_size
|
||||
|
||||
def _remove_descrption(self):
|
||||
self.mb_descrption_buffer.pop(self.idx)
|
||||
def _remove_description(self):
|
||||
self.mb_description_buffer.pop(self.idx)
|
||||
|
||||
@property
|
||||
def cur_descrption(self) -> MicroBatchDescription:
|
||||
return self.mb_descrption_buffer.get(self.idx)
|
||||
def cur_description(self) -> MicroBatchDescription:
|
||||
return self.mb_description_buffer.get(self.idx)
|
||||
|
||||
@property
|
||||
def cur_infer_state(self):
|
||||
if self.cur_descrption is None:
|
||||
if self.cur_description is None:
|
||||
return None
|
||||
return self.cur_descrption.infer_state
|
||||
return self.cur_description.infer_state
|
||||
|
||||
@property
|
||||
def cur_state(self):
|
||||
"""
|
||||
Return the state of current micro batch, when current descrption is None, the state is PREFILL
|
||||
Return the state of current micro batch, when current description is None, the state is PREFILL
|
||||
|
||||
"""
|
||||
if self.cur_descrption is None:
|
||||
if self.cur_description is None:
|
||||
return Status.PREFILL
|
||||
return self.cur_descrption.state
|
||||
return self.cur_description.state
|
||||
|
|
|
@ -95,7 +95,7 @@ class GenerateSchedule(PipelineSchedule):
|
|||
Returns:
|
||||
dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None`
|
||||
"""
|
||||
model_inputs = {"infer_state": self.mb_manager.cur_descrption.infer_state}
|
||||
model_inputs = {"infer_state": self.mb_manager.cur_description.infer_state}
|
||||
return model_inputs
|
||||
|
||||
def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
|
||||
|
@ -107,7 +107,7 @@ class GenerateSchedule(PipelineSchedule):
|
|||
Returns:
|
||||
dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}`
|
||||
"""
|
||||
new_mask = self.mb_manager.cur_descrption.attn_mask
|
||||
new_mask = self.mb_manager.cur_description.attn_mask
|
||||
|
||||
return dict(input_ids=new_token, attention_mask=new_mask)
|
||||
|
||||
|
@ -133,7 +133,7 @@ class GenerateSchedule(PipelineSchedule):
|
|||
1.Load micro_batch 2.Use the current micro_batch to init the current infer_state
|
||||
"""
|
||||
inputs_dict = self.load_micro_batch()
|
||||
self.mb_manager.add_descrption(inputs_dict)
|
||||
self.mb_manager.add_description(inputs_dict)
|
||||
|
||||
def _load_stage_action(self, model: Module) -> None:
|
||||
"""
|
||||
|
@ -141,7 +141,7 @@ class GenerateSchedule(PipelineSchedule):
|
|||
1.load micro_batch 2.do the forward 3.step to update
|
||||
"""
|
||||
inputs_dict = self.load_micro_batch()
|
||||
self.mb_manager.add_descrption(inputs_dict)
|
||||
self.mb_manager.add_description(inputs_dict)
|
||||
if self.verbose and self.stage_manager.is_first_stage():
|
||||
torch.cuda.synchronize()
|
||||
self.timestamps[self.mb_manager.idx].append(time.time())
|
||||
|
@ -379,7 +379,7 @@ class GenerateSchedule(PipelineSchedule):
|
|||
if self.verbose and self.stage_manager.is_first_stage():
|
||||
torch.cuda.synchronize()
|
||||
self.timestamps[self.mb_manager.idx].append(time.time())
|
||||
self.mb_manager.add_descrption(inputs_dict)
|
||||
self.mb_manager.add_description(inputs_dict)
|
||||
interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
|
||||
output_dict = model_forward(model, inputs_dict, interval_inputs)
|
||||
# In GENERATE phase
|
||||
|
@ -415,7 +415,7 @@ class GenerateSchedule(PipelineSchedule):
|
|||
inputs_dict = None
|
||||
if self.mb_manager.cur_state is Status.PREFILL:
|
||||
inputs_dict = self.load_micro_batch()
|
||||
self.mb_manager.add_descrption(inputs_dict)
|
||||
self.mb_manager.add_description(inputs_dict)
|
||||
interval_inputs = {
|
||||
"hidden_states": hidden_states["hidden_states"],
|
||||
"infer_state": self.mb_manager.cur_infer_state,
|
||||
|
|
Loading…
Reference in New Issue