mirror of https://github.com/hpcaitech/ColossalAI

[feat] support MixtralPipelineForwards --> mixtral_for_causal_lm_forward for zbv

parent 72b507a7be
commit e234dfa236
@@ -679,6 +679,108 @@ class MixtralPipelineForwards:
        )
        past_key_values = None

        #######
        # Attention: we handle the 1f1b, interleaved, and zbv schedules here
        #######
        if stage_manager.is_interleave:
            if stage_manager.use_zbv:
                # zbv
                if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
                    hidden_states = outputs[0]
                    logits = self.lm_head(hidden_states)
                    logits = logits.float()

                    loss = None
                    if labels is not None:
                        # Shift so that tokens < n predict n
                        shift_logits = logits[..., :-1, :].contiguous()
                        shift_labels = labels[..., 1:].contiguous()
                        # Flatten the tokens
                        loss_fct = CrossEntropyLoss()
                        shift_logits = shift_logits.view(-1, self.config.vocab_size)
                        shift_labels = shift_labels.view(-1)
                        # Enable model parallelism
                        shift_labels = shift_labels.to(shift_logits.device)
                        loss = loss_fct(shift_logits, shift_labels)

                    aux_loss = None
                    if output_router_logits:
                        aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
                        if labels is not None:
                            loss += self.router_aux_loss_coef * aux_loss

                    if not return_dict:
                        output = (logits,) + outputs[1:]
                        if output_router_logits:
                            output = (aux_loss,) + output
                        return (loss,) + output if loss is not None else output

                    return MoeCausalLMOutputWithPast(
                        loss=loss,
                        aux_loss=aux_loss,
                        logits=logits,
                        past_key_values=None,
                        hidden_states=outputs[0],
                        attentions=None,
                        router_logits=outputs[-1],
                    )
                else:
                    out = {}
                    hidden_states = outputs.get("hidden_states")
                    out["hidden_states"] = hidden_states
                    if output_router_logits:
                        out["past_router_logits"] = outputs["past_router_logits"]
                    return out
            else:
                # interleaved
                if stage_manager.is_last_stage(ignore_chunk=True):
                    hidden_states = outputs[0]
                    logits = self.lm_head(hidden_states)
                    logits = logits.float()

                    loss = None
                    if labels is not None:
                        # Shift so that tokens < n predict n
                        shift_logits = logits[..., :-1, :].contiguous()
                        shift_labels = labels[..., 1:].contiguous()
                        # Flatten the tokens
                        loss_fct = CrossEntropyLoss()
                        shift_logits = shift_logits.view(-1, self.config.vocab_size)
                        shift_labels = shift_labels.view(-1)
                        # Enable model parallelism
                        shift_labels = shift_labels.to(shift_logits.device)
                        loss = loss_fct(shift_logits, shift_labels)

                    aux_loss = None
                    if output_router_logits:
                        aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
                        if labels is not None:
                            loss += self.router_aux_loss_coef * aux_loss

                    if not return_dict:
                        output = (logits,) + outputs[1:]
                        if output_router_logits:
                            output = (aux_loss,) + output
                        return (loss,) + output if loss is not None else output

                    return MoeCausalLMOutputWithPast(
                        loss=loss,
                        aux_loss=aux_loss,
                        logits=logits,
                        past_key_values=None,
                        hidden_states=outputs[0],
                        attentions=None,
                        router_logits=outputs[-1],
                    )
                else:
                    out = {}
                    hidden_states = outputs.get("hidden_states")
                    out["hidden_states"] = hidden_states
                    if output_router_logits:
                        out["past_router_logits"] = outputs["past_router_logits"]
                    return out
        else:
            # 1f1b or other non-interleaved schedules
            if stage_manager.is_last_stage():
                hidden_states = outputs[0]
                logits = self.lm_head(hidden_states)
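A note on the zbv branch above: under the zero-bubble V (ZBV) schedule each pipeline rank holds two model chunks laid out in a V shape, so the final chunk, and with it the LM head and loss computation, sits on the first pipeline rank's second chunk rather than on the last rank. That is why the head-owning condition here is stage_manager.is_first_stage(ignore_chunk=True) together with stage_manager.model_chunk_id == 1, while the interleaved and 1f1b paths keep the usual last-stage check; every other (rank, chunk) pair simply forwards a dict carrying "hidden_states" (and "past_router_logits" when router logits are requested) to the next stage. Below is a minimal sketch of that dispatch decision only, with a hypothetical StageInfo stand-in for the few PipelineStageManager fields consulted in the diff; it is an illustration, not the actual manager API.

from dataclasses import dataclass

@dataclass
class StageInfo:
    # Hypothetical stand-in for the PipelineStageManager fields used above.
    is_interleave: bool
    use_zbv: bool
    is_first_stage: bool  # "first" ignoring the chunk dimension
    is_last_stage: bool
    model_chunk_id: int

def computes_lm_head(info: StageInfo) -> bool:
    # Mirrors the branch structure of mixtral_for_causal_lm_forward above
    # (the interleaved/1f1b distinction is collapsed here for brevity).
    if info.is_interleave and info.use_zbv:
        # ZBV: the head lives on the first rank's second chunk.
        return info.is_first_stage and info.model_chunk_id == 1
    # Interleaved and 1f1b: the head lives on the last stage.
    return info.is_last_stage

# Under ZBV, rank 0 / chunk 1 computes logits and loss; rank 0 / chunk 0 does not.
assert computes_lm_head(StageInfo(True, True, True, False, 1))
assert not computes_lm_head(StageInfo(True, True, True, False, 0))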
@@ -726,6 +828,53 @@ class MixtralPipelineForwards:
                    out["past_router_logits"] = outputs["past_router_logits"]
                return out

        # if stage_manager.is_last_stage():
        #     hidden_states = outputs[0]
        #     logits = self.lm_head(hidden_states)
        #     logits = logits.float()

        #     loss = None
        #     if labels is not None:
        #         # Shift so that tokens < n predict n
        #         shift_logits = logits[..., :-1, :].contiguous()
        #         shift_labels = labels[..., 1:].contiguous()
        #         # Flatten the tokens
        #         loss_fct = CrossEntropyLoss()
        #         shift_logits = shift_logits.view(-1, self.config.vocab_size)
        #         shift_labels = shift_labels.view(-1)
        #         # Enable model parallelism
        #         shift_labels = shift_labels.to(shift_logits.device)
        #         loss = loss_fct(shift_logits, shift_labels)

        #     aux_loss = None
        #     if output_router_logits:
        #         aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
        #         if labels is not None:
        #             loss += self.router_aux_loss_coef * aux_loss

        #     if not return_dict:
        #         output = (logits,) + outputs[1:]
        #         if output_router_logits:
        #             output = (aux_loss,) + output
        #         return (loss,) + output if loss is not None else output

        #     return MoeCausalLMOutputWithPast(
        #         loss=loss,
        #         aux_loss=aux_loss,
        #         logits=logits,
        #         past_key_values=None,
        #         hidden_states=outputs[0],
        #         attentions=None,
        #         router_logits=outputs[-1],
        #     )
        # else:
        #     out = {}
        #     hidden_states = outputs.get("hidden_states")
        #     out["hidden_states"] = hidden_states
        #     if output_router_logits:
        #         out["past_router_logits"] = outputs["past_router_logits"]
        #     return out


def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
    logger = logging.get_logger(__name__)
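Both head-owning paths above compute the same loss: next-token cross entropy on logits shifted by one position and, when output_router_logits is set, the MoE load-balancing auxiliary loss scaled by self.router_aux_loss_coef. The following is a small self-contained illustration of the shift-and-flatten step with toy tensor shapes, written in plain PyTorch and independent of the pipeline code.

import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab_size = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch, seq_len))

# Shift so that tokens < n predict n, exactly as in the forward above.
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)
loss = CrossEntropyLoss()(shift_logits, shift_labels)

# With router logits enabled, the forward folds the auxiliary term in as:
#   loss = loss + router_aux_loss_coef * aux_loss
print(loss.item())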
@@ -786,6 +786,8 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
    seed_all(10086)

    torch_model = MixtralModel(config).to(dtype).cuda()
    # TODO: Support MixtralForCausalLM
    # torch_model = MixtralForCausalLM(config).to(dtype).cuda()
    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
    # init schedule
    h, a, s = config.hidden_size, config.num_attention_heads, 1024
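The test hunk only adds a TODO: the reference model is still MixtralModel, and switching to MixtralForCausalLM is left for follow-up once the causal-LM pipeline forward is wired through the plugin. Below is a hedged sketch of what that swap could look like outside the pipeline, using a deliberately tiny, hypothetical MixtralConfig and reusing input_ids as labels; both choices are assumptions for illustration and are not part of the diff.

import torch
from transformers import MixtralConfig, MixtralForCausalLM

# Hypothetical smoke test for the TODO above: tiny config, CPU, no parallelism.
config = MixtralConfig(
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    num_local_experts=4,
    num_experts_per_tok=2,
    vocab_size=128,
)
model = MixtralForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (1, 16))
out = model(input_ids=input_ids, labels=input_ids)  # reusing input_ids as labels is an assumption
out.loss.backward()  # exercises the same shifted cross-entropy path the pipeline forward implements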