mirror of https://github.com/hpcaitech/ColossalAI
[feat] update MixtralPipelineForwards --> mixtral_model_forward; support zbv;
parent 9ee80fc828
commit 72b507a7be
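
The hunks below make the same stage-ownership decision three times, once per supported schedule: 1f1b, interleaved, and zbv (the zero-bubble V schedule). A minimal sketch of that decision, not code from this commit: the helper names owns_embedding and owns_final_norm are invented here for illustration, while the stage_manager attributes (is_interleave, use_zbv, is_first_stage, is_last_stage, model_chunk_id) are the ones the diff itself uses.

def owns_embedding(stage_manager) -> bool:
    # First model chunk: reads input_ids / inputs_embeds and runs self.embed_tokens.
    # Under zbv, chunk 0 of the first pipeline rank sits at the front of the model.
    if stage_manager.is_interleave:
        if stage_manager.use_zbv:
            return stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 0
        return stage_manager.is_first_stage(ignore_chunk=True)
    return stage_manager.is_first_stage()  # 1f1b / non-interleaved


def owns_final_norm(stage_manager) -> bool:
    # Last model chunk: applies self.norm and builds MoeModelOutputWithPast. Under zbv the
    # pipeline folds back on itself, so the last chunk is chunk 1 of the first pipeline rank.
    if stage_manager.is_interleave:
        if stage_manager.use_zbv:
            return stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1
        return stage_manager.is_last_stage(ignore_chunk=True)
    return stage_manager.is_last_stage()  # 1f1b / non-interleaved
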
@@ -267,26 +267,98 @@ class MixtralPipelineForwards:
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        print(f"model_chunk_id {stage_manager.model_chunk_id} stage_manager {stage_manager.stage}")
        if stage_manager.is_first_stage():
            # retrieve input_ids and inputs_embeds
            if input_ids is not None and inputs_embeds is not None:
                raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
            elif input_ids is not None:
                batch_size, seq_length = input_ids.shape
            elif inputs_embeds is not None:
                batch_size, seq_length, _ = inputs_embeds.shape
        if stage_manager.is_interleave:
            if stage_manager.use_zbv:
                # zbv
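                # zbv: only chunk 0 of the first pipeline rank is the first model chunk,
                # so only that chunk reads input_ids / inputs_embeds and runs the embedding.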
                if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 0:
                    # retrieve input_ids and inputs_embeds
                    if input_ids is not None and inputs_embeds is not None:
                        raise ValueError(
                            "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
                        )
                    elif input_ids is not None:
                        batch_size, seq_length = input_ids.shape
                    elif inputs_embeds is not None:
                        batch_size, seq_length, _ = inputs_embeds.shape
                    else:
                        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
                    device = input_ids.device if input_ids is not None else inputs_embeds.device
                    if inputs_embeds is None:
                        inputs_embeds = self.embed_tokens(input_ids)
                    hidden_states = inputs_embeds
                else:
                    input_shape = hidden_states.shape[:-1]
                    batch_size, seq_length = input_shape
                    device = hidden_states.device
            else:
                raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            if inputs_embeds is None:
                inputs_embeds = self.embed_tokens(input_ids)
            hidden_states = inputs_embeds
                # interleaved
                if stage_manager.is_first_stage(ignore_chunk=True):
                    # retrieve input_ids and inputs_embeds
                    if input_ids is not None and inputs_embeds is not None:
                        raise ValueError(
                            "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
                        )
                    elif input_ids is not None:
                        batch_size, seq_length = input_ids.shape
                    elif inputs_embeds is not None:
                        batch_size, seq_length, _ = inputs_embeds.shape
                    else:
                        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
                    device = input_ids.device if input_ids is not None else inputs_embeds.device
                    if inputs_embeds is None:
                        inputs_embeds = self.embed_tokens(input_ids)
                    hidden_states = inputs_embeds
                else:
                    input_shape = hidden_states.shape[:-1]
                    batch_size, seq_length = input_shape
                    device = hidden_states.device
        else:
            input_shape = hidden_states.shape[:-1]
            batch_size, seq_length = input_shape
            device = hidden_states.device
            # 1f1b or None
            if stage_manager.is_first_stage():  # No ignore_chunk=True for 1f1b
                # retrieve input_ids and inputs_embeds
                if input_ids is not None and inputs_embeds is not None:
                    raise ValueError(
                        "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
                    )
                elif input_ids is not None:
                    batch_size, seq_length = input_ids.shape
                elif inputs_embeds is not None:
                    batch_size, seq_length, _ = inputs_embeds.shape
                else:
                    raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
                device = input_ids.device if input_ids is not None else inputs_embeds.device
                if inputs_embeds is None:
                    inputs_embeds = self.embed_tokens(input_ids)
                hidden_states = inputs_embeds
            else:
                input_shape = hidden_states.shape[:-1]
                batch_size, seq_length = input_shape
                device = hidden_states.device
        #######
        # Note: we need to consider the 1f1b, interleaved, and zbv schedules here
        #######

        # # retrieve input_ids and inputs_embeds
        # print(f"model_chunk_id {stage_manager.model_chunk_id} stage_manager {stage_manager.stage}")
        # if stage_manager.is_first_stage():
        #     # retrieve input_ids and inputs_embeds
        #     if input_ids is not None and inputs_embeds is not None:
        #         raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        #     elif input_ids is not None:
        #         batch_size, seq_length = input_ids.shape
        #     elif inputs_embeds is not None:
        #         batch_size, seq_length, _ = inputs_embeds.shape
        #     else:
        #         raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
        #     device = input_ids.device if input_ids is not None else inputs_embeds.device
        #     if inputs_embeds is None:
        #         inputs_embeds = self.embed_tokens(input_ids)
        #     hidden_states = inputs_embeds
        # else:
        #     input_shape = hidden_states.shape[:-1]
        #     batch_size, seq_length = input_shape
        #     device = hidden_states.device

        seq_length_with_past = seq_length
        past_key_values_length = 0
@@ -390,8 +462,22 @@ class MixtralPipelineForwards:
            if output_router_logits:
                all_router_logits += (layer_outputs[-1],)

        if stage_manager.is_last_stage():
            hidden_states = self.norm(hidden_states)
        #######
        # Note: we need to consider the 1f1b, interleaved, and zbv schedules here
        #######
        if stage_manager.is_interleave:
            if stage_manager.use_zbv:
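                # zbv: the schedule folds the pipeline into a V, so the last model chunk is
                # chunk 1 of the first pipeline rank; the final norm is applied there rather
                # than on the last pipeline rank.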
                if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
                    hidden_states = self.norm(hidden_states)
            else:
                if stage_manager.is_last_stage(ignore_chunk=True):
                    hidden_states = self.norm(hidden_states)
        else:
            if stage_manager.is_last_stage():  # No ignore_chunk=True for 1f1b
                hidden_states = self.norm(hidden_states)

        # if stage_manager.is_last_stage():
        #     hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
@@ -400,30 +486,114 @@ class MixtralPipelineForwards:
        if output_router_logits and past_router_logits is not None:
            all_router_logits = past_router_logits + all_router_logits
        if stage_manager.is_last_stage():
            if not return_dict:
                return tuple(
                    v
                    for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
                    if v is not None
                )
            return MoeModelOutputWithPast(
                last_hidden_state=hidden_states,
                past_key_values=next_cache,
                hidden_states=all_hidden_states,
                attentions=all_self_attns,
                router_logits=all_router_logits,
            )
        else:
            if output_router_logits:
                return {
                    "hidden_states": hidden_states,
                    "past_router_logits": all_router_logits,
                }

        #######
        # Note: we need to consider the 1f1b, interleaved, and zbv schedules here
        #######
        if stage_manager.is_interleave:
            if stage_manager.use_zbv:
                # zbv
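                # zbv: as above, chunk 1 of the first pipeline rank is the last model chunk,
                # so only it builds the full MoeModelOutputWithPast; every other chunk hands a
                # plain dict of intermediate tensors to the next stage.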
                if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
                    if not return_dict:
                        return tuple(
                            v
                            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
                            if v is not None
                        )
                    return MoeModelOutputWithPast(
                        last_hidden_state=hidden_states,
                        past_key_values=next_cache,
                        hidden_states=all_hidden_states,
                        attentions=all_self_attns,
                        router_logits=all_router_logits,
                    )
                else:
                    if output_router_logits:
                        return {
                            "hidden_states": hidden_states,
                            "past_router_logits": all_router_logits,
                        }
                    else:
                        return {
                            "hidden_states": hidden_states,
                        }
            else:
                return {
                    "hidden_states": hidden_states,
                }
                # interleaved
                if stage_manager.is_last_stage(ignore_chunk=True):
                    if not return_dict:
                        return tuple(
                            v
                            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
                            if v is not None
                        )
                    return MoeModelOutputWithPast(
                        last_hidden_state=hidden_states,
                        past_key_values=next_cache,
                        hidden_states=all_hidden_states,
                        attentions=all_self_attns,
                        router_logits=all_router_logits,
                    )
                else:
                    if output_router_logits:
                        return {
                            "hidden_states": hidden_states,
                            "past_router_logits": all_router_logits,
                        }
                    else:
                        return {
                            "hidden_states": hidden_states,
                        }
        else:
            # 1f1b or other
            if stage_manager.is_last_stage():
                if not return_dict:
                    return tuple(
                        v
                        for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
                        if v is not None
                    )
                return MoeModelOutputWithPast(
                    last_hidden_state=hidden_states,
                    past_key_values=next_cache,
                    hidden_states=all_hidden_states,
                    attentions=all_self_attns,
                    router_logits=all_router_logits,
                )
            else:
                if output_router_logits:
                    return {
                        "hidden_states": hidden_states,
                        "past_router_logits": all_router_logits,
                    }
                else:
                    return {
                        "hidden_states": hidden_states,
                    }

        # if stage_manager.is_last_stage():
        #     if not return_dict:
        #         return tuple(
        #             v
        #             for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
        #             if v is not None
        #         )
        #     return MoeModelOutputWithPast(
        #         last_hidden_state=hidden_states,
        #         past_key_values=next_cache,
        #         hidden_states=all_hidden_states,
        #         attentions=all_self_attns,
        #         router_logits=all_router_logits,
        #     )
        # else:
        #     if output_router_logits:
        #         return {
        #             "hidden_states": hidden_states,
        #             "past_router_logits": all_router_logits,
        #         }
        #     else:
        #         return {
        #             "hidden_states": hidden_states,
        #         }

    @staticmethod
    def mixtral_for_causal_lm_forward(