mirror of https://github.com/hpcaitech/ColossalAI
[fix] fix llama and mixtral benchmark zbv loss-is-None bug; update mixtral & llama policy and modeling
parent e234dfa236
commit 0ca16d5cbe
@@ -432,7 +432,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
internal_inputs = {} if input_obj is None else input_obj
internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
output_obj = model_forward(model_chunk, micro_batch, internal_inputs)

# last layer in model
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
loss = criterion(output_obj, micro_batch) / self.num_microbatch
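Under the zbv (zero-bubble V) schedule the model's last layer lives on chunk 1 of the first pipeline stage, which is why the criterion runs only when model_chunk_id == 1 and is_first_stage(ignore_chunk=True) both hold, and each microbatch loss is divided by num_microbatch so the accumulated value equals the mean over the whole batch. A minimal sketch of that accumulation, with hypothetical names (accumulate_microbatch_loss, accum_loss) not taken from this commit:

# Sketch: per-microbatch loss scaling; summing these terms over all microbatches
# reproduces the mean loss of the full batch.
import torch

def accumulate_microbatch_loss(accum_loss: torch.Tensor, micro_loss: torch.Tensor, num_microbatch: int) -> torch.Tensor:
    return accum_loss + micro_loss / num_microbatch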
@@ -500,12 +499,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]

optimizer.backward_by_grad(
tensor=output_obj_,
grad=output_obj_grad_,
inputs=input_obj_,
retain_graph=True,
)
try:
ctx = optimizer.no_sync()
except AttributeError:
ctx = model_chunk.no_sync()

with ctx:
optimizer.backward_by_grad(
tensor=output_obj_,
grad=output_obj_grad_,
inputs=input_obj_,
retain_graph=True,
)

# Format output_obj_grad
input_obj_grad = {}
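The backward here is now wrapped in a no-sync context so that gradient all-reduce is deferred instead of firing on every chunk's backward; the scheduler asks the optimizer for no_sync() first and falls back to the wrapped module. A small sketch of that fallback pattern, assuming generic optimizer/module wrappers (no_sync_ctx is a hypothetical helper, not ColossalAI API):

# Sketch: prefer the optimizer's no_sync() context, fall back to the module's
# (e.g. torch DDP exposes no_sync() on the module), and do nothing for a plain module.
import contextlib

def no_sync_ctx(optimizer, module):
    try:
        return optimizer.no_sync()
    except AttributeError:
        if hasattr(module, "no_sync"):
            return module.no_sync()
        return contextlib.nullcontext()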
@@ -267,98 +267,25 @@ class MixtralPipelineForwards:
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if stage_manager.is_interleave:
if stage_manager.use_zbv:
# zbv
if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 0:
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
device = hidden_states.device
# retrieve input_ids and inputs_embeds
if stage_manager.is_first_stage():
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
# interleaved
if stage_manager.is_first_stage(ignore_chunk=True):
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
device = hidden_states.device
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
else:
# 1f1b or None
if stage_manager.is_first_stage():  # No ignore_chunk=True for 1f1b
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
device = hidden_states.device

#######
# Attention: we consider 1f1b, interleaved, and zbv here
#######

# # retrieve input_ids and inputs_embeds
# print(f"model_chunk_id {stage_manager.model_chunk_id} stage_manager {stage_manager.stage}")
# if stage_manager.is_first_stage():
# # retrieve input_ids and inputs_embeds
# if input_ids is not None and inputs_embeds is not None:
# raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
# elif input_ids is not None:
# batch_size, seq_length = input_ids.shape
# elif inputs_embeds is not None:
# batch_size, seq_length, _ = inputs_embeds.shape
# else:
# raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
# device = input_ids.device if input_ids is not None else inputs_embeds.device
# if inputs_embeds is None:
# inputs_embeds = self.embed_tokens(input_ids)
# hidden_states = inputs_embeds
# else:
# input_shape = hidden_states.shape[:-1]
# batch_size, seq_length = input_shape
# device = hidden_states.device
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
device = hidden_states.device

seq_length_with_past = seq_length
past_key_values_length = 0
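The branches above reduce to one rule: the stage that owns the embedding consumes input_ids or inputs_embeds, while every other stage consumes the hidden_states handed over by the previous stage. A compressed sketch of that rule (resolve_stage_inputs is an illustrative helper, not part of this commit):

# Sketch: the first stage embeds tokens, later stages reuse incoming hidden states.
def resolve_stage_inputs(is_first_stage, embed_tokens, input_ids=None, inputs_embeds=None, hidden_states=None):
    if is_first_stage:
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("Specify exactly one of input_ids or inputs_embeds")
        if inputs_embeds is None:
            inputs_embeds = embed_tokens(input_ids)
        hidden_states = inputs_embeds
    batch_size, seq_length = hidden_states.shape[:2]
    return hidden_states, batch_size, seq_length, hidden_states.device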
@@ -462,22 +389,8 @@ class MixtralPipelineForwards:
if output_router_logits:
all_router_logits += (layer_outputs[-1],)

#######
# Attention: we consider 1f1b, interleaved, and zbv here
#######
if stage_manager.is_interleave:
if stage_manager.use_zbv:
if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
hidden_states = self.norm(hidden_states)
else:
if stage_manager.is_last_stage(ignore_chunk=True):
hidden_states = self.norm(hidden_states)
else:
if stage_manager.is_last_stage():  # No ignore_chunk=True for 1f1b
hidden_states = self.norm(hidden_states)

# if stage_manager.is_last_stage():
# hidden_states = self.norm(hidden_states)
if stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states)

# add hidden states from the last decoder layer
if output_hidden_states:
@@ -487,113 +400,30 @@ class MixtralPipelineForwards:
if output_router_logits and past_router_logits is not None:
all_router_logits = past_router_logits + all_router_logits

#######
# Attention: we consider 1f1b, interleaved, and zbv here
#######
if stage_manager.is_interleave:
if stage_manager.use_zbv:
# zbv
if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
if v is not None
)
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
router_logits=all_router_logits,
)
else:
if output_router_logits:
return {
"hidden_states": hidden_states,
"past_router_logits": all_router_logits,
}
else:
return {
"hidden_states": hidden_states,
}
else:
# interleaved
if stage_manager.is_last_stage(ignore_chunk=True):
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
if v is not None
)
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
router_logits=all_router_logits,
)
else:
if output_router_logits:
return {
"hidden_states": hidden_states,
"past_router_logits": all_router_logits,
}
else:
return {
"hidden_states": hidden_states,
}
else:
# 1f1b or other
if stage_manager.is_last_stage():
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
if v is not None
)
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
router_logits=all_router_logits,
if stage_manager.is_last_stage():
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
if v is not None
)
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
router_logits=all_router_logits,
)
else:
if output_router_logits:
return {
"hidden_states": hidden_states,
"past_router_logits": all_router_logits,
}
else:
if output_router_logits:
return {
"hidden_states": hidden_states,
"past_router_logits": all_router_logits,
}
else:
return {
"hidden_states": hidden_states,
}

# if stage_manager.is_last_stage():
# if not return_dict:
# return tuple(
# v
# for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
# if v is not None
# )
# return MoeModelOutputWithPast(
# last_hidden_state=hidden_states,
# past_key_values=next_cache,
# hidden_states=all_hidden_states,
# attentions=all_self_attns,
# router_logits=all_router_logits,
# )
# else:
# if output_router_logits:
# return {
# "hidden_states": hidden_states,
# "past_router_logits": all_router_logits,
# }
# else:
# return {
# "hidden_states": hidden_states,
# }
return {
"hidden_states": hidden_states,
}

@staticmethod
def mixtral_for_causal_lm_forward(
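Across all three schedules the return contract is the same: only the stage that produces the final hidden states builds a MoeModelOutputWithPast (or the plain tuple when return_dict is False), while every intermediate stage returns a dict so the next stage can resume, carrying past_router_logits along when the auxiliary loss is requested. A minimal sketch of the intermediate case (stage_outputs is an illustrative name, not part of this commit):

# Sketch: what a non-final pipeline stage hands to its successor.
def stage_outputs(hidden_states, all_router_logits=None, output_router_logits=False):
    out = {"hidden_states": hidden_states}
    if output_router_logits:
        out["past_router_logits"] = all_router_logits
    return out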
@@ -679,201 +509,51 @@ class MixtralPipelineForwards:
)
past_key_values = None

#######
# Attention: we consider 1f1b, interleaved, and zbv here
#######
if stage_manager.is_interleave:
if stage_manager.use_zbv:
# zbv
if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if stage_manager.is_last_stage():
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)

loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)

aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss

if not return_dict:
output = (logits,) + outputs[1:]
if output_router_logits:
output = (aux_loss,) + output
return (loss,) + output if loss is not None else output

return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=None,
hidden_states=outputs[0],
attentions=None,
router_logits=outputs[-1],
)
else:
out = {}
hidden_states = outputs.get("hidden_states")
out["hidden_states"] = hidden_states
if output_router_logits:
out["past_router_logits"] = outputs["past_router_logits"]
return out
else:
# interleaved
if stage_manager.is_last_stage(ignore_chunk=True):
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()

loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)

aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss

if not return_dict:
output = (logits,) + outputs[1:]
if output_router_logits:
output = (aux_loss,) + output
return (loss,) + output if loss is not None else output

return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=None,
hidden_states=outputs[0],
attentions=None,
router_logits=outputs[-1],
)
else:
out = {}
hidden_states = outputs.get("hidden_states")
out["hidden_states"] = hidden_states
if output_router_logits:
out["past_router_logits"] = outputs["past_router_logits"]
return out
else:
# 1f1b or otherwise
if stage_manager.is_last_stage():
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()

loss = None
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
loss += self.router_aux_loss_coef * aux_loss

aux_loss = None
if not return_dict:
output = (logits,) + outputs[1:]
if output_router_logits:
aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss
output = (aux_loss,) + output
return (loss,) + output if loss is not None else output

if not return_dict:
output = (logits,) + outputs[1:]
if output_router_logits:
output = (aux_loss,) + output
return (loss,) + output if loss is not None else output

return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=None,
hidden_states=outputs[0],
attentions=None,
router_logits=outputs[-1],
)
else:
out = {}
hidden_states = outputs.get("hidden_states")
out["hidden_states"] = hidden_states
if output_router_logits:
out["past_router_logits"] = outputs["past_router_logits"]
return out

# if stage_manager.is_last_stage():
# hidden_states = outputs[0]
# logits = self.lm_head(hidden_states)
# logits = logits.float()

# loss = None
# if labels is not None:
# # Shift so that tokens < n predict n
# shift_logits = logits[..., :-1, :].contiguous()
# shift_labels = labels[..., 1:].contiguous()
# # Flatten the tokens
# loss_fct = CrossEntropyLoss()
# shift_logits = shift_logits.view(-1, self.config.vocab_size)
# shift_labels = shift_labels.view(-1)
# # Enable model parallelism
# shift_labels = shift_labels.to(shift_logits.device)
# loss = loss_fct(shift_logits, shift_labels)

# aux_loss = None
# if output_router_logits:
# aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
# if labels is not None:
# loss += self.router_aux_loss_coef * aux_loss

# if not return_dict:
# output = (logits,) + outputs[1:]
# if output_router_logits:
# output = (aux_loss,) + output
# return (loss,) + output if loss is not None else output

# return MoeCausalLMOutputWithPast(
# loss=loss,
# aux_loss=aux_loss,
# logits=logits,
# past_key_values=None,
# hidden_states=outputs[0],
# attentions=None,
# router_logits=outputs[-1],
# )
# else:
# out = {}
# hidden_states = outputs.get("hidden_states")
# out["hidden_states"] = hidden_states
# if output_router_logits:
# out["past_router_logits"] = outputs["past_router_logits"]
# return out
return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=None,
hidden_states=outputs[0],
attentions=None,
router_logits=outputs[-1],
)
else:
out = {}
hidden_states = outputs.get("hidden_states")
out["hidden_states"] = hidden_states
if output_router_logits:
out["past_router_logits"] = outputs["past_router_logits"]
return out


def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
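On the stage that holds the LM head, the loss is the usual shifted next-token cross entropy, optionally augmented with the MoE load-balancing term scaled by router_aux_loss_coef. A self-contained sketch of that combination, assuming HF-style logits/labels shapes (causal_lm_loss is a hypothetical helper, not part of this commit):

# Sketch: next-token loss plus the router load-balancing penalty.
from torch.nn import CrossEntropyLoss

def causal_lm_loss(logits, labels, vocab_size, aux_loss=None, router_aux_loss_coef=0.0):
    shift_logits = logits[..., :-1, :].contiguous()  # tokens < n predict token n
    shift_labels = labels[..., 1:].contiguous()
    loss = CrossEntropyLoss()(
        shift_logits.view(-1, vocab_size),
        shift_labels.view(-1).to(shift_logits.device),  # keep labels on the logits' device
    )
    if aux_loss is not None:
        loss = loss + router_aux_loss_coef * aux_loss
    return loss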
@@ -343,18 +343,10 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_interleave:
if stage_manager.use_zbv:
if stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
else:
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
else:
if stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
# if stage_manager.is_last_stage():
# held_layers.append(self.model.lm_head)
if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
return held_layers

def get_shared_params(self) -> List[Dict[int, Tensor]]:
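The policy change condenses the nested schedule checks into a single rule for where the LM head is held: under zbv it belongs to the first pipeline stage, because the V-schedule folds the last model chunk back onto stage 0; otherwise it stays on the last stage. A one-function sketch of that rule (stage_holds_lm_head is a hypothetical name, not part of this commit):

# Sketch: which stage should keep self.model.lm_head among its held layers.
def stage_holds_lm_head(is_first_stage: bool, is_last_stage: bool, use_zbv: bool) -> bool:
    if use_zbv:
        return is_first_stage  # V-schedule: last chunk sits on stage 0
    return is_last_stage       # 1f1b / interleaved: head stays on the last stage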
@@ -21,6 +21,7 @@ from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchF
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.pipeline.schedule.v_schedule import PipelineGraph
from colossalai.shardformer import PipelineGradientCheckpointConfig

warnings.filterwarnings("ignore")
@@ -91,7 +92,7 @@ def main():
parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)

parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
parser.add_argument("--profile", action="store_true", help="Profile the code")
parser.add_argument(
@@ -137,6 +138,11 @@ def main():
# ==============================
# Initialize Booster
# ==============================
if args.config in MODEL_CONFIGS:
config = MODEL_CONFIGS[args.config]
else:
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)

use_empty_init = True
if args.plugin == "gemini":
plugin = GeminiPlugin(
@@ -210,6 +216,23 @@ def main():
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "3d":
if args.pp_style == "zbv":
mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length
mem_w = -32 * config.hidden_size
mem_b = -mem_w - mem_f
scheduler_nodes = PipelineGraph(
n_stage=args.pp,
n_micro=args.batch_size // args.mbs,
f_cost=1000,
b_cost=1000,
w_cost=1000,
c_cost=1,
f_mem=mem_f,
b_mem=mem_b,
w_mem=mem_w,
).get_v_schedule()
else:
scheduler_nodes = None
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
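For pp_style == "zbv" the benchmark builds the V schedule up front: f_cost/b_cost/w_cost/c_cost are relative timings for the forward, backward, weight-gradient and communication phases, and f_mem/b_mem/w_mem are rough per-microbatch memory deltas derived from the model config. A standalone sketch using the same PipelineGraph API, with placeholder costs mirroring the values above (build_zbv_schedule is an illustrative wrapper):

# Sketch: build a zero-bubble V schedule; the resulting nodes are passed to
# HybridParallelPlugin(..., scheduler_nodes=...) as shown in the next hunk.
from colossalai.pipeline.schedule.v_schedule import PipelineGraph

def build_zbv_schedule(pp_size, num_microbatches, hidden_size, num_heads, seq_len):
    mem_f = 34 * hidden_size + 5 * num_heads * seq_len  # forward keeps activations
    mem_w = -32 * hidden_size                           # weight-grad pass frees memory
    mem_b = -mem_w - mem_f                              # backward frees the rest
    return PipelineGraph(
        n_stage=pp_size,
        n_micro=num_microbatches,
        f_cost=1000,
        b_cost=1000,
        w_cost=1000,
        c_cost=1,
        f_mem=mem_f,
        b_mem=mem_b,
        w_mem=mem_w,
    ).get_v_schedule()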
@@ -227,6 +250,7 @@ def main():
overlap_allgather=args.overlap_allgather,
use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
scheduler_nodes=scheduler_nodes,
**hybrid_kwargs,
)
elif args.plugin == "3d_cpu":
@@ -256,10 +280,6 @@ def main():
# ==============================
dp_size = getattr(plugin, "dp_size", coordinator.world_size)

if args.config in MODEL_CONFIGS:
config = MODEL_CONFIGS[args.config]
else:
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
torch.cuda.manual_seed(42)
dataset = RandomDataset(
num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
@@ -334,8 +354,12 @@ def main():
return_loss=True,
)
loss = outputs["loss"]
if dist.get_rank() == dist.get_world_size() - 1:
print(f"Step {step} loss: {loss}")
if args.pp_style == "zbv":
if dist.get_rank() == 0:
print(f"Step {step} loss: {loss}")
else:
if dist.get_rank() == dist.get_world_size() - 1:
print(f"Step {step} loss: {loss}")
optimizer.step()
optimizer.zero_grad()
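This is the benchmark half of the loss-is-None fix: with zbv the loss is produced on pipeline stage 0, so the script prints from global rank 0, whereas the other schedules keep it on the last rank. A tiny sketch of that rank selection (loss_owner_rank is an illustrative helper, not part of this commit):

# Sketch: pick the rank that actually holds the loss before logging it.
import torch.distributed as dist

def loss_owner_rank(pp_style: str) -> int:
    return 0 if pp_style == "zbv" else dist.get_world_size() - 1

# usage: if dist.get_rank() == loss_owner_rank(args.pp_style): print(loss)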
@@ -227,7 +227,6 @@ def main():
)

optimizer = HybridAdam(model.parameters())
# optimizer = torch.optim.SGD(model.parameters(), lr=1)
torch.set_default_dtype(torch.bfloat16)
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
@@ -258,8 +257,12 @@ def main():
return_loss=True,
)
loss = outputs["loss"]
if dist.get_rank() == dist.get_world_size() - 1:
print(f"Step {step} loss: {loss}")
if args.pp_style == "zbv":
if dist.get_rank() == 0:
print(f"Step {step} loss: {loss}")
else:
if dist.get_rank() == dist.get_world_size() - 1:
print(f"Step {step} loss: {loss}")
optimizer.step()
optimizer.zero_grad()