mirror of https://github.com/hpcaitech/ColossalAI
[feat] support MixtralPipelineForwards --> mixtral_for_causal_lm_forward for zbv
parent 72b507a7be
commit e234dfa236
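This commit adds the causal-LM head and loss path to Mixtral's pipelined forward for three schedules: 1f1b, interleaved, and zbv (zero-bubble V schedule). The key difference between them is which pipeline rank/chunk owns the LM head. The sketch below is a minimal, illustrative reproduction of just that dispatch decision; ToyStageManager is a made-up stand-in for ColossalAI's pipeline stage manager and only models the flags the diff actually consults (is_interleave, use_zbv, model_chunk_id, is_first_stage, is_last_stage).

from dataclasses import dataclass

@dataclass
class ToyStageManager:
    # Hypothetical stand-in for the real pipeline stage manager; only the flags
    # consulted by mixtral_for_causal_lm_forward are modeled here.
    is_interleave: bool = False   # running an interleaved (multi-chunk) schedule?
    use_zbv: bool = False         # running the zero-bubble V schedule?
    model_chunk_id: int = 0       # which model chunk this rank is executing
    first_stage: bool = False     # physically first pipeline stage?
    last_stage: bool = False      # physically last pipeline stage?

    def is_first_stage(self, ignore_chunk: bool = False) -> bool:
        # ignore_chunk is accepted only to mirror the real call signature.
        return self.first_stage

    def is_last_stage(self, ignore_chunk: bool = False) -> bool:
        return self.last_stage

def owns_lm_head(sm: ToyStageManager) -> bool:
    # Mirrors the branch structure added in this commit: which rank/chunk
    # computes lm_head + loss (everyone else just forwards hidden states).
    if sm.is_interleave:
        if sm.use_zbv:
            # zbv: the V-shaped schedule finishes on chunk 1 of the first stage.
            return sm.is_first_stage(ignore_chunk=True) and sm.model_chunk_id == 1
        # plain interleaving: the LM head stays on the physically last stage.
        return sm.is_last_stage(ignore_chunk=True)
    # 1f1b (non-interleaved): the last stage finishes the model.
    return sm.is_last_stage()

# Example: under zbv, pipeline rank 0 running chunk 1 owns the loss.
assert owns_lm_head(ToyStageManager(is_interleave=True, use_zbv=True, first_stage=True, model_chunk_id=1))

Under zbv the model is laid out as a V across the ranks, so the last model chunk, and with it lm_head, sits on pipeline rank 0's second chunk; that is exactly the condition the new branch checks.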
@@ -679,6 +679,108 @@ class MixtralPipelineForwards:
        )
        past_key_values = None  # the pipelined forward does not return a KV cache

        #######
        # Attention: we handle the 1f1b, interleaved, and zbv schedules here
        #######
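        # zbv is the zero-bubble V schedule: with the model folded into two
        # chunks per rank, the final model chunk (and therefore the LM head and
        # loss) lives on the first pipeline stage's chunk 1. Plain interleaving
        # and 1f1b keep the LM head on the last stage. All other stages/chunks
        # only forward hidden states (and router logits) onward.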
        if stage_manager.is_interleave:
            if stage_manager.use_zbv:
                # zbv
                if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
                    hidden_states = outputs[0]
                    logits = self.lm_head(hidden_states)
                    logits = logits.float()

                    loss = None
                    if labels is not None:
                        # Shift so that tokens < n predict n
                        shift_logits = logits[..., :-1, :].contiguous()
                        shift_labels = labels[..., 1:].contiguous()
                        # Flatten the tokens
                        loss_fct = CrossEntropyLoss()
                        shift_logits = shift_logits.view(-1, self.config.vocab_size)
                        shift_labels = shift_labels.view(-1)
                        # Enable model parallelism
                        shift_labels = shift_labels.to(shift_logits.device)
                        loss = loss_fct(shift_logits, shift_labels)

                    aux_loss = None
                    if output_router_logits:
                        aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
                        if labels is not None:
                            loss += self.router_aux_loss_coef * aux_loss

                    if not return_dict:
                        output = (logits,) + outputs[1:]
                        if output_router_logits:
                            output = (aux_loss,) + output
                        return (loss,) + output if loss is not None else output

                    return MoeCausalLMOutputWithPast(
                        loss=loss,
                        aux_loss=aux_loss,
                        logits=logits,
                        past_key_values=None,
                        hidden_states=outputs[0],
                        attentions=None,
                        router_logits=outputs[-1],
                    )
                else:
                    out = {}
                    hidden_states = outputs.get("hidden_states")
                    out["hidden_states"] = hidden_states
                    if output_router_logits:
                        out["past_router_logits"] = outputs["past_router_logits"]
                    return out
            else:
                # interleaved
                if stage_manager.is_last_stage(ignore_chunk=True):
                    hidden_states = outputs[0]
                    logits = self.lm_head(hidden_states)
                    logits = logits.float()

                    loss = None
                    if labels is not None:
                        # Shift so that tokens < n predict n
                        shift_logits = logits[..., :-1, :].contiguous()
                        shift_labels = labels[..., 1:].contiguous()
                        # Flatten the tokens
                        loss_fct = CrossEntropyLoss()
                        shift_logits = shift_logits.view(-1, self.config.vocab_size)
                        shift_labels = shift_labels.view(-1)
                        # Enable model parallelism
                        shift_labels = shift_labels.to(shift_logits.device)
                        loss = loss_fct(shift_logits, shift_labels)

                    aux_loss = None
                    if output_router_logits:
                        aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
                        if labels is not None:
                            loss += self.router_aux_loss_coef * aux_loss

                    if not return_dict:
                        output = (logits,) + outputs[1:]
                        if output_router_logits:
                            output = (aux_loss,) + output
                        return (loss,) + output if loss is not None else output

                    return MoeCausalLMOutputWithPast(
                        loss=loss,
                        aux_loss=aux_loss,
                        logits=logits,
                        past_key_values=None,
                        hidden_states=outputs[0],
                        attentions=None,
                        router_logits=outputs[-1],
                    )
                else:
                    out = {}
                    hidden_states = outputs.get("hidden_states")
                    out["hidden_states"] = hidden_states
                    if output_router_logits:
                        out["past_router_logits"] = outputs["past_router_logits"]
                    return out
        else:
            # 1f1b or otherwise
            if stage_manager.is_last_stage():
                hidden_states = outputs[0]
                logits = self.lm_head(hidden_states)

@@ -726,6 +828,53 @@ class MixtralPipelineForwards:
                    out["past_router_logits"] = outputs["past_router_logits"]
                return out
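
        # Note: stages/chunks that do not own the LM head return a plain dict
        # carrying "hidden_states" (and "past_router_logits" when router logits
        # are requested) instead of a ModelOutput, so the next pipeline stage
        # can keep passing router logits along until the owning stage feeds the
        # full set into load_balancing_loss_func.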

        # if stage_manager.is_last_stage():
        #     hidden_states = outputs[0]
        #     logits = self.lm_head(hidden_states)
        #     logits = logits.float()

        #     loss = None
        #     if labels is not None:
        #         # Shift so that tokens < n predict n
        #         shift_logits = logits[..., :-1, :].contiguous()
        #         shift_labels = labels[..., 1:].contiguous()
        #         # Flatten the tokens
        #         loss_fct = CrossEntropyLoss()
        #         shift_logits = shift_logits.view(-1, self.config.vocab_size)
        #         shift_labels = shift_labels.view(-1)
        #         # Enable model parallelism
        #         shift_labels = shift_labels.to(shift_logits.device)
        #         loss = loss_fct(shift_logits, shift_labels)

        #     aux_loss = None
        #     if output_router_logits:
        #         aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
        #         if labels is not None:
        #             loss += self.router_aux_loss_coef * aux_loss

        #     if not return_dict:
        #         output = (logits,) + outputs[1:]
        #         if output_router_logits:
        #             output = (aux_loss,) + output
        #         return (loss,) + output if loss is not None else output

        #     return MoeCausalLMOutputWithPast(
        #         loss=loss,
        #         aux_loss=aux_loss,
        #         logits=logits,
        #         past_key_values=None,
        #         hidden_states=outputs[0],
        #         attentions=None,
        #         router_logits=outputs[-1],
        #     )
        # else:
        #     out = {}
        #     hidden_states = outputs.get("hidden_states")
        #     out["hidden_states"] = hidden_states
        #     if output_router_logits:
        #         out["past_router_logits"] = outputs["past_router_logits"]
        #     return out


def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
    logger = logging.get_logger(__name__)
@@ -786,6 +786,8 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
    seed_all(10086)

    torch_model = MixtralModel(config).to(dtype).cuda()
    # TODO: Support MixtralForCausalLM
    # torch_model = MixtralForCausalLM(config).to(dtype).cuda()
    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
    # init schedule
    h, a, s = config.hidden_size, config.num_attention_heads, 1024
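The TODO in the hunk above leaves the test's reference model as MixtralModel; a causal-LM reference would let the test compare loss and logits against the mixtral_for_causal_lm_forward path added in this commit. Below is a minimal, hedged sketch of what that single-process reference computation looks like using only the public Hugging Face API; the tiny config values are invented for illustration and are not the test's real settings.

import torch
from transformers import MixtralConfig, MixtralForCausalLM

# Illustrative (made-up) tiny config; the real test builds its config elsewhere.
cfg = MixtralConfig(
    vocab_size=1000,
    hidden_size=128,
    intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    num_local_experts=4,
    num_experts_per_tok=2,
)
ref_model = MixtralForCausalLM(cfg)
input_ids = torch.randint(0, cfg.vocab_size, (2, 16))

# With labels and output_router_logits=True, the Hugging Face forward already
# folds router_aux_loss_coef * aux_loss into .loss, which is what the pipeline
# branches added in this commit reproduce stage by stage.
out = ref_model(input_ids=input_ids, labels=input_ids, output_router_logits=True)
print(out.loss, out.aux_loss)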