
[feat] support MixtralPipelineForwards--> mixtral_for_causal_lm_forward for zbv

pull/6083/head
duanjunwen committed 1 month ago
parent commit e234dfa236
  1. colossalai/shardformer/modeling/mixtral.py (235)
  2. tests/test_pipeline/test_schedule/test_zerobubble_pp.py (2)
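
Reviewer note: the new forward body below dispatches on the pipeline schedule before deciding which stage computes logits and the loss. A minimal sketch of that control flow, assuming a stage_manager with the same attributes used in the diff (is_interleave, use_zbv, model_chunk_id, is_first_stage, is_last_stage); the helper itself is illustrative and not part of the commit:

def select_output_role(stage_manager) -> str:
    # Which stage owns lm_head + the loss under each schedule (mirrors the diff below).
    if stage_manager.is_interleave:
        if stage_manager.use_zbv:
            # zbv: in the V-shaped schedule the last model chunk (chunk 1) sits on the
            # first pipeline stage, so that chunk computes logits and loss.
            if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
                return "compute_logits_and_loss"
            return "forward_hidden_states"
        # interleaved 1F1B: the last stage (ignoring chunks) owns the loss.
        if stage_manager.is_last_stage(ignore_chunk=True):
            return "compute_logits_and_loss"
        return "forward_hidden_states"
    # plain 1F1B: the last pipeline stage owns the loss.
    if stage_manager.is_last_stage():
        return "compute_logits_and_loss"
    return "forward_hidden_states"

Under zbv the pipeline folds back on itself, so the first rank hosts both the embedding chunk (chunk 0) and the head chunk (chunk 1), which is why the first-stage check is paired with model_chunk_id == 1.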

colossalai/shardformer/modeling/mixtral.py (235)

@@ -679,52 +679,201 @@ class MixtralPipelineForwards:
)
past_key_values = None
if stage_manager.is_last_stage():
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
#######
# Attention: we support 1f1b, interleaved, and zbv schedules here
#######
if stage_manager.is_interleave:
if stage_manager.use_zbv:
# zbv
if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss
if not return_dict:
output = (logits,) + outputs[1:]
if output_router_logits:
output = (aux_loss,) + output
return (loss,) + output if loss is not None else output
return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=None,
hidden_states=outputs[0],
attentions=None,
router_logits=outputs[-1],
)
else:
out = {}
hidden_states = outputs.get("hidden_states")
out["hidden_states"] = hidden_states
if output_router_logits:
out["past_router_logits"] = outputs["past_router_logits"]
return out
else:
# interleaved
if stage_manager.is_last_stage(ignore_chunk=True):
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss
if not return_dict:
output = (logits,) + outputs[1:]
if output_router_logits:
output = (aux_loss,) + output
return (loss,) + output if loss is not None else output
return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=None,
hidden_states=outputs[0],
attentions=None,
router_logits=outputs[-1],
)
else:
out = {}
hidden_states = outputs.get("hidden_states")
out["hidden_states"] = hidden_states
if output_router_logits:
out["past_router_logits"] = outputs["past_router_logits"]
return out
else:
# 1f1b or otherwise
if stage_manager.is_last_stage():
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss
if not return_dict:
output = (logits,) + outputs[1:]
if output_router_logits:
output = (aux_loss,) + output
return (loss,) + output if loss is not None else output
return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=None,
hidden_states=outputs[0],
attentions=None,
router_logits=outputs[-1],
)
else:
out = {}
hidden_states = outputs.get("hidden_states")
out["hidden_states"] = hidden_states
if output_router_logits:
out["past_router_logits"] = outputs["past_router_logits"]
return out
# if stage_manager.is_last_stage():
# hidden_states = outputs[0]
# logits = self.lm_head(hidden_states)
# logits = logits.float()
# loss = None
# if labels is not None:
# # Shift so that tokens < n predict n
# shift_logits = logits[..., :-1, :].contiguous()
# shift_labels = labels[..., 1:].contiguous()
# # Flatten the tokens
# loss_fct = CrossEntropyLoss()
# shift_logits = shift_logits.view(-1, self.config.vocab_size)
# shift_labels = shift_labels.view(-1)
# # Enable model parallelism
# shift_labels = shift_labels.to(shift_logits.device)
# loss = loss_fct(shift_logits, shift_labels)
# aux_loss = None
# if output_router_logits:
# aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
# if labels is not None:
# loss += self.router_aux_loss_coef * aux_loss
# if not return_dict:
# output = (logits,) + outputs[1:]
# if output_router_logits:
# output = (aux_loss,) + output
# return (loss,) + output if loss is not None else output
# return MoeCausalLMOutputWithPast(
# loss=loss,
# aux_loss=aux_loss,
# logits=logits,
# past_key_values=None,
# hidden_states=outputs[0],
# attentions=None,
# router_logits=outputs[-1],
# )
# else:
# out = {}
# hidden_states = outputs.get("hidden_states")
# out["hidden_states"] = hidden_states
# if output_router_logits:
# out["past_router_logits"] = outputs["past_router_logits"]
# return out
def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
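
All three branches above share one return contract: the loss-owning stage emits a MoeCausalLMOutputWithPast (or a tuple when return_dict is False), while every other stage forwards a plain dict to the next stage. A condensed sketch of that contract; the helper name and signature are invented for illustration, and only the field names come from the diff:

from transformers.modeling_outputs import MoeCausalLMOutputWithPast

def package_stage_output(outputs, is_loss_stage, logits=None, loss=None, aux_loss=None,
                         output_router_logits=False, return_dict=True):
    # On the loss stage `outputs` is the tuple returned by the Mixtral backbone;
    # on intermediate stages it is a dict of tensors produced by the model forward.
    if is_loss_stage:
        if not return_dict:
            output = (logits,) + outputs[1:]
            if output_router_logits:
                output = (aux_loss,) + output
            return (loss,) + output if loss is not None else output
        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=None,
            hidden_states=outputs[0],
            attentions=None,
            router_logits=outputs[-1],
        )
    # Non-loss stages hand a plain dict to the next pipeline stage.
    out = {"hidden_states": outputs.get("hidden_states")}
    if output_router_logits:
        out["past_router_logits"] = outputs["past_router_logits"]
    return out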

tests/test_pipeline/test_schedule/test_zerobubble_pp.py (2)

@@ -786,6 +786,8 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
seed_all(10086)
torch_model = MixtralModel(config).to(dtype).cuda()
# TODO: Support MixtralForCausalLM
# torch_model = MixtralForCausalLM(config).to(dtype).cuda()
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
# init schedule
h, a, s = config.hidden_size, config.num_attention_heads, 1024
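
The TODO above keeps the reference model as a bare MixtralModel for now. A hedged sketch of what the follow-up swap to MixtralForCausalLM could look like once the new causal-LM forward is supported end to end; the helper function is hypothetical, while the class, dtype handling, and optimizer mirror the existing test lines:

import torch
from transformers import MixtralConfig, MixtralForCausalLM

def build_reference_causal_lm(config: MixtralConfig, dtype: torch.dtype):
    # Hypothetical follow-up to the TODO: use the causal-LM variant so the test also
    # exercises lm_head and the shifted cross-entropy loss added in this commit.
    torch_model = MixtralForCausalLM(config).to(dtype).cuda()
    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
    return torch_model, torch_optimizer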
