@@ -679,52 +679,201 @@ class MixtralPipelineForwards:
        )
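
        # The pipeline forward does not propagate a KV cache; discard it explicitly
        # (the returned output also carries past_key_values=None).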
        past_key_values = None

        #######
        # Attention: we support the 1f1b, interleaved, and zbv schedules here.
        #######
        if stage_manager.is_interleave:
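            # zbv folds the pipeline into a "V": the final model chunk (and the LM head)
            # lives on the *first* stage as model chunk 1, so the head/loss path below
            # keys off is_first_stage() rather than is_last_stage().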
            if stage_manager.use_zbv:
                # zbv
                if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
                    hidden_states = outputs[0]
                    logits = self.lm_head(hidden_states)
                    logits = logits.float()

                    loss = None
                    if labels is not None:
                        # Shift so that tokens < n predict n
                        shift_logits = logits[..., :-1, :].contiguous()
                        shift_labels = labels[..., 1:].contiguous()
                        # Flatten the tokens
                        loss_fct = CrossEntropyLoss()
                        shift_logits = shift_logits.view(-1, self.config.vocab_size)
                        shift_labels = shift_labels.view(-1)
                        # Enable model parallelism
                        shift_labels = shift_labels.to(shift_logits.device)
                        loss = loss_fct(shift_logits, shift_labels)
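
                    # Router load-balancing auxiliary loss over the logits collected in
                    # outputs[-1], folded into the LM loss scaled by router_aux_loss_coef
                    # (mirroring the upstream Mixtral forward).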
                    aux_loss = None
                    if output_router_logits:
                        aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
                        if labels is not None:
                            loss += self.router_aux_loss_coef * aux_loss

                    if not return_dict:
                        output = (logits,) + outputs[1:]
                        if output_router_logits:
                            output = (aux_loss,) + output
                        return (loss,) + output if loss is not None else output

                    return MoeCausalLMOutputWithPast(
                        loss=loss,
                        aux_loss=aux_loss,
                        logits=logits,
                        past_key_values=None,
                        hidden_states=outputs[0],
                        attentions=None,
                        router_logits=outputs[-1],
                    )
                else:
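                    # This chunk does not own the LM head: forward the hidden states
                    # (and the accumulated router logits) to the next stage as a dict.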
                    out = {}
                    hidden_states = outputs.get("hidden_states")
                    out["hidden_states"] = hidden_states
                    if output_router_logits:
                        out["past_router_logits"] = outputs["past_router_logits"]
                    return out
            else:
                # interleaved
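                # is_last_stage(ignore_chunk=True) asks whether this rank is the last
                # pipeline stage, independent of which local model chunk is executing.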
                if stage_manager.is_last_stage(ignore_chunk=True):
                    hidden_states = outputs[0]
                    logits = self.lm_head(hidden_states)
                    logits = logits.float()

                    loss = None
                    if labels is not None:
                        # Shift so that tokens < n predict n
                        shift_logits = logits[..., :-1, :].contiguous()
                        shift_labels = labels[..., 1:].contiguous()
                        # Flatten the tokens
                        loss_fct = CrossEntropyLoss()
                        shift_logits = shift_logits.view(-1, self.config.vocab_size)
                        shift_labels = shift_labels.view(-1)
                        # Enable model parallelism
                        shift_labels = shift_labels.to(shift_logits.device)
                        loss = loss_fct(shift_logits, shift_labels)

                    aux_loss = None
                    if output_router_logits:
                        aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
                        if labels is not None:
                            loss += self.router_aux_loss_coef * aux_loss

                    if not return_dict:
                        output = (logits,) + outputs[1:]
                        if output_router_logits:
                            output = (aux_loss,) + output
                        return (loss,) + output if loss is not None else output

                    return MoeCausalLMOutputWithPast(
                        loss=loss,
                        aux_loss=aux_loss,
                        logits=logits,
                        past_key_values=None,
                        hidden_states=outputs[0],
                        attentions=None,
                        router_logits=outputs[-1],
                    )
                else:
                    out = {}
                    hidden_states = outputs.get("hidden_states")
                    out["hidden_states"] = hidden_states
                    if output_router_logits:
                        out["past_router_logits"] = outputs["past_router_logits"]
                    return out
        else:
            # 1f1b or other non-interleaved schedules
            if stage_manager.is_last_stage():
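                # Plain 1f1b: a single model chunk per rank, so only the last stage
                # runs the LM head and computes the loss.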
                hidden_states = outputs[0]
                logits = self.lm_head(hidden_states)
                logits = logits.float()

                loss = None
                if labels is not None:
                    # Shift so that tokens < n predict n
                    shift_logits = logits[..., :-1, :].contiguous()
                    shift_labels = labels[..., 1:].contiguous()
                    # Flatten the tokens
                    loss_fct = CrossEntropyLoss()
                    shift_logits = shift_logits.view(-1, self.config.vocab_size)
                    shift_labels = shift_labels.view(-1)
                    # Enable model parallelism
                    shift_labels = shift_labels.to(shift_logits.device)
                    loss = loss_fct(shift_logits, shift_labels)

                aux_loss = None
                if output_router_logits:
                    aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
                    if labels is not None:
                        loss += self.router_aux_loss_coef * aux_loss

                if not return_dict:
                    output = (logits,) + outputs[1:]
                    if output_router_logits:
                        output = (aux_loss,) + output
                    return (loss,) + output if loss is not None else output

                return MoeCausalLMOutputWithPast(
                    loss=loss,
                    aux_loss=aux_loss,
                    logits=logits,
                    past_key_values=None,
                    hidden_states=outputs[0],
                    attentions=None,
                    router_logits=outputs[-1],
                )
            else:
                out = {}
                hidden_states = outputs.get("hidden_states")
                out["hidden_states"] = hidden_states
                if output_router_logits:
                    out["past_router_logits"] = outputs["past_router_logits"]
                return out


def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):