[fix] fix zbv loss-is-None bug in llama and mixtral benchmarks; update mixtral & llama policy and modeling

pull/6083/head
duanjunwen 2024-10-11 07:32:43 +00:00
parent e234dfa236
commit 0ca16d5cbe
5 changed files with 134 additions and 430 deletions


@@ -432,7 +432,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 internal_inputs = {} if input_obj is None else input_obj
                 internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
                 output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
         # last layer in model
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             loss = criterion(output_obj, micro_batch) / self.num_microbatch
@@ -500,12 +499,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
             output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]
-            optimizer.backward_by_grad(
-                tensor=output_obj_,
-                grad=output_obj_grad_,
-                inputs=input_obj_,
-                retain_graph=True,
-            )
+            try:
+                ctx = optimizer.no_sync()
+            except AttributeError:
+                ctx = model_chunk.no_sync()
+            with ctx:
+                optimizer.backward_by_grad(
+                    tensor=output_obj_,
+                    grad=output_obj_grad_,
+                    inputs=input_obj_,
+                    retain_graph=True,
+                )
 
         # Format output_obj_grad
         input_obj_grad = {}
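
The no_sync wrapper added above appears intended to defer gradient synchronization during the backward-for-input (B) pass of the zero-bubble schedule: optimizer.no_sync() is preferred, and the model chunk's no_sync() is the fallback. A minimal standalone sketch of that fallback idea, assuming only that one of the two objects exposes a no_sync() context manager (the helper name and the nullcontext default are illustrative, not part of the patch):

    from contextlib import nullcontext

    def grad_sync_context(optimizer, model_chunk):
        # Prefer the optimizer's no_sync() (e.g. a ZeRO-style wrapper),
        # fall back to the module's, and do nothing if neither exists.
        for obj in (optimizer, model_chunk):
            no_sync = getattr(obj, "no_sync", None)
            if callable(no_sync):
                return no_sync()
        return nullcontext()

    # usage (mirrors the hunk above):
    # with grad_sync_context(optimizer, model_chunk):
    #     optimizer.backward_by_grad(tensor=output_obj_, grad=output_obj_grad_,
    #                                inputs=input_obj_, retain_graph=True)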


@@ -267,98 +267,25 @@ class MixtralPipelineForwards:
             )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        if stage_manager.is_interleave:
-            if stage_manager.use_zbv:
-                # zbv
-                if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 0:
-                    # retrieve input_ids and inputs_embeds
-                    if input_ids is not None and inputs_embeds is not None:
-                        raise ValueError(
-                            "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
-                        )
-                    elif input_ids is not None:
-                        batch_size, seq_length = input_ids.shape
-                    elif inputs_embeds is not None:
-                        batch_size, seq_length, _ = inputs_embeds.shape
-                    else:
-                        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-                    device = input_ids.device if input_ids is not None else inputs_embeds.device
-                    if inputs_embeds is None:
-                        inputs_embeds = self.embed_tokens(input_ids)
-                    hidden_states = inputs_embeds
-                else:
-                    input_shape = hidden_states.shape[:-1]
-                    batch_size, seq_length = input_shape
-                    device = hidden_states.device
+        # retrieve input_ids and inputs_embeds
+        if stage_manager.is_first_stage():
+            # retrieve input_ids and inputs_embeds
+            if input_ids is not None and inputs_embeds is not None:
+                raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+            elif input_ids is not None:
+                batch_size, seq_length = input_ids.shape
+            elif inputs_embeds is not None:
+                batch_size, seq_length, _ = inputs_embeds.shape
             else:
-                # interleaved
-                if stage_manager.is_first_stage(ignore_chunk=True):
-                    # retrieve input_ids and inputs_embeds
-                    if input_ids is not None and inputs_embeds is not None:
-                        raise ValueError(
-                            "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
-                        )
-                    elif input_ids is not None:
-                        batch_size, seq_length = input_ids.shape
-                    elif inputs_embeds is not None:
-                        batch_size, seq_length, _ = inputs_embeds.shape
-                    else:
-                        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-                    device = input_ids.device if input_ids is not None else inputs_embeds.device
-                    if inputs_embeds is None:
-                        inputs_embeds = self.embed_tokens(input_ids)
-                    hidden_states = inputs_embeds
-                else:
-                    input_shape = hidden_states.shape[:-1]
-                    batch_size, seq_length = input_shape
-                    device = hidden_states.device
+                raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            if inputs_embeds is None:
+                inputs_embeds = self.embed_tokens(input_ids)
+            hidden_states = inputs_embeds
         else:
-            # 1f1b or None
-            if stage_manager.is_first_stage():  # No ignore_chunk=True for 1f1b
-                # retrieve input_ids and inputs_embeds
-                if input_ids is not None and inputs_embeds is not None:
-                    raise ValueError(
-                        "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
-                    )
-                elif input_ids is not None:
-                    batch_size, seq_length = input_ids.shape
-                elif inputs_embeds is not None:
-                    batch_size, seq_length, _ = inputs_embeds.shape
-                else:
-                    raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-                device = input_ids.device if input_ids is not None else inputs_embeds.device
-                if inputs_embeds is None:
-                    inputs_embeds = self.embed_tokens(input_ids)
-                hidden_states = inputs_embeds
-            else:
-                input_shape = hidden_states.shape[:-1]
-                batch_size, seq_length = input_shape
-                device = hidden_states.device
-        #######
-        # Attention, we support consider 1f1b, interleaved, zbv
-        #######
-        # # retrieve input_ids and inputs_embeds
-        # print(f"model_chunk_id {stage_manager.model_chunk_id} stage_manager {stage_manager.stage}")
-        # if stage_manager.is_first_stage():
-        #     # retrieve input_ids and inputs_embeds
-        #     if input_ids is not None and inputs_embeds is not None:
-        #         raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-        #     elif input_ids is not None:
-        #         batch_size, seq_length = input_ids.shape
-        #     elif inputs_embeds is not None:
-        #         batch_size, seq_length, _ = inputs_embeds.shape
-        #     else:
-        #         raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-        #     device = input_ids.device if input_ids is not None else inputs_embeds.device
-        #     if inputs_embeds is None:
-        #         inputs_embeds = self.embed_tokens(input_ids)
-        #     hidden_states = inputs_embeds
-        #     else:
-        #         input_shape = hidden_states.shape[:-1]
-        #         batch_size, seq_length = input_shape
-        #         device = hidden_states.device
+            input_shape = hidden_states.shape[:-1]
+            batch_size, seq_length = input_shape
+            device = hidden_states.device
 
         seq_length_with_past = seq_length
         past_key_values_length = 0
@@ -462,22 +389,8 @@ class MixtralPipelineForwards:
             if output_router_logits:
                 all_router_logits += (layer_outputs[-1],)
 
-        #######
-        # Attention, we support consider 1f1b, interleaved, zbv
-        #######
-        if stage_manager.is_interleave:
-            if stage_manager.use_zbv:
-                if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
-                    hidden_states = self.norm(hidden_states)
-            else:
-                if stage_manager.is_last_stage(ignore_chunk=True):
-                    hidden_states = self.norm(hidden_states)
-        else:
-            if stage_manager.is_last_stage():  # No ignore_chunk=True for 1f1b
-                hidden_states = self.norm(hidden_states)
-        # if stage_manager.is_last_stage():
-        #     hidden_states = self.norm(hidden_states)
+        if stage_manager.is_last_stage():
+            hidden_states = self.norm(hidden_states)
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
@@ -487,113 +400,30 @@ class MixtralPipelineForwards:
         if output_router_logits and past_router_logits is not None:
             all_router_logits = past_router_logits + all_router_logits
 
-        #######
-        # Attention, we support consider 1f1b, interleaved, zbv
-        #######
-        if stage_manager.is_interleave:
-            if stage_manager.use_zbv:
-                # zbv
-                if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
-                    if not return_dict:
-                        return tuple(
-                            v
-                            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
-                            if v is not None
-                        )
-                    return MoeModelOutputWithPast(
-                        last_hidden_state=hidden_states,
-                        past_key_values=next_cache,
-                        hidden_states=all_hidden_states,
-                        attentions=all_self_attns,
-                        router_logits=all_router_logits,
-                    )
-                else:
-                    if output_router_logits:
-                        return {
-                            "hidden_states": hidden_states,
-                            "past_router_logits": all_router_logits,
-                        }
-                    else:
-                        return {
-                            "hidden_states": hidden_states,
-                        }
-            else:
-                # interlearved
-                if stage_manager.is_last_stage(ignore_chunk=True):
-                    if not return_dict:
-                        return tuple(
-                            v
-                            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
-                            if v is not None
-                        )
-                    return MoeModelOutputWithPast(
-                        last_hidden_state=hidden_states,
-                        past_key_values=next_cache,
-                        hidden_states=all_hidden_states,
-                        attentions=all_self_attns,
-                        router_logits=all_router_logits,
-                    )
-                else:
-                    if output_router_logits:
-                        return {
-                            "hidden_states": hidden_states,
-                            "past_router_logits": all_router_logits,
-                        }
-                    else:
-                        return {
-                            "hidden_states": hidden_states,
-                        }
-        else:
-            # 1f1b or other
-            if stage_manager.is_last_stage():
-                if not return_dict:
-                    return tuple(
-                        v
-                        for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
-                        if v is not None
-                    )
-                return MoeModelOutputWithPast(
-                    last_hidden_state=hidden_states,
-                    past_key_values=next_cache,
-                    hidden_states=all_hidden_states,
-                    attentions=all_self_attns,
-                    router_logits=all_router_logits,
-                )
-            else:
-                if output_router_logits:
-                    return {
-                        "hidden_states": hidden_states,
-                        "past_router_logits": all_router_logits,
-                    }
-                else:
-                    return {
-                        "hidden_states": hidden_states,
-                    }
-        # if stage_manager.is_last_stage():
-        #     if not return_dict:
-        #         return tuple(
-        #             v
-        #             for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
-        #             if v is not None
-        #         )
-        #     return MoeModelOutputWithPast(
-        #         last_hidden_state=hidden_states,
-        #         past_key_values=next_cache,
-        #         hidden_states=all_hidden_states,
-        #         attentions=all_self_attns,
-        #         router_logits=all_router_logits,
-        #     )
-        # else:
-        #     if output_router_logits:
-        #         return {
-        #             "hidden_states": hidden_states,
-        #             "past_router_logits": all_router_logits,
-        #         }
-        #     else:
-        #         return {
-        #             "hidden_states": hidden_states,
-        #         }
+        if stage_manager.is_last_stage():
+            if not return_dict:
+                return tuple(
+                    v
+                    for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
+                    if v is not None
+                )
+            return MoeModelOutputWithPast(
+                last_hidden_state=hidden_states,
+                past_key_values=next_cache,
+                hidden_states=all_hidden_states,
+                attentions=all_self_attns,
+                router_logits=all_router_logits,
+            )
+        else:
+            if output_router_logits:
+                return {
+                    "hidden_states": hidden_states,
+                    "past_router_logits": all_router_logits,
+                }
+            else:
+                return {
+                    "hidden_states": hidden_states,
+                }
 
     @staticmethod
     def mixtral_for_causal_lm_forward(
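
The causal-LM loss computed in the next hunk follows the standard shift-by-one convention (tokens before position n predict token n). As a self-contained illustration of that pattern only, not Mixtral-specific code:

    import torch
    from torch.nn import CrossEntropyLoss

    def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int) -> torch.Tensor:
        # Drop the last logit and the first label so position n-1 predicts token n.
        shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
        shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
        return CrossEntropyLoss()(shift_logits, shift_labels)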
@@ -679,201 +509,51 @@ class MixtralPipelineForwards:
             )
         past_key_values = None
 
-        #######
-        # Attention, we support consider 1f1b, interleaved, zbv
-        #######
-        if stage_manager.is_interleave:
-            if stage_manager.use_zbv:
-                # zbv
-                if stage_manager.is_first_stage(ignore_chunk=True) and stage_manager.model_chunk_id == 1:
-                    hidden_states = outputs[0]
-                    logits = self.lm_head(hidden_states)
-                    logits = logits.float()
-                    loss = None
-                    if labels is not None:
-                        # Shift so that tokens < n predict n
-                        shift_logits = logits[..., :-1, :].contiguous()
-                        shift_labels = labels[..., 1:].contiguous()
-                        # Flatten the tokens
-                        loss_fct = CrossEntropyLoss()
-                        shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                        shift_labels = shift_labels.view(-1)
-                        # Enable model parallelism
-                        shift_labels = shift_labels.to(shift_logits.device)
-                        loss = loss_fct(shift_logits, shift_labels)
-                    aux_loss = None
-                    if output_router_logits:
-                        aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
-                        if labels is not None:
-                            loss += self.router_aux_loss_coef * aux_loss
-                    if not return_dict:
-                        output = (logits,) + outputs[1:]
-                        if output_router_logits:
-                            output = (aux_loss,) + output
-                        return (loss,) + output if loss is not None else output
-                    return MoeCausalLMOutputWithPast(
-                        loss=loss,
-                        aux_loss=aux_loss,
-                        logits=logits,
-                        past_key_values=None,
-                        hidden_states=outputs[0],
-                        attentions=None,
-                        router_logits=outputs[-1],
-                    )
-                else:
-                    out = {}
-                    hidden_states = outputs.get("hidden_states")
-                    out["hidden_states"] = hidden_states
-                    if output_router_logits:
-                        out["past_router_logits"] = outputs["past_router_logits"]
-                    return out
-            else:
-                # interleaved
-                if stage_manager.is_last_stage(ignore_chunk=True):
-                    hidden_states = outputs[0]
-                    logits = self.lm_head(hidden_states)
-                    logits = logits.float()
-                    loss = None
-                    if labels is not None:
-                        # Shift so that tokens < n predict n
-                        shift_logits = logits[..., :-1, :].contiguous()
-                        shift_labels = labels[..., 1:].contiguous()
-                        # Flatten the tokens
-                        loss_fct = CrossEntropyLoss()
-                        shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                        shift_labels = shift_labels.view(-1)
-                        # Enable model parallelism
-                        shift_labels = shift_labels.to(shift_logits.device)
-                        loss = loss_fct(shift_logits, shift_labels)
-                    aux_loss = None
-                    if output_router_logits:
-                        aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
-                        if labels is not None:
-                            loss += self.router_aux_loss_coef * aux_loss
-                    if not return_dict:
-                        output = (logits,) + outputs[1:]
-                        if output_router_logits:
-                            output = (aux_loss,) + output
-                        return (loss,) + output if loss is not None else output
-                    return MoeCausalLMOutputWithPast(
-                        loss=loss,
-                        aux_loss=aux_loss,
-                        logits=logits,
-                        past_key_values=None,
-                        hidden_states=outputs[0],
-                        attentions=None,
-                        router_logits=outputs[-1],
-                    )
-                else:
-                    out = {}
-                    hidden_states = outputs.get("hidden_states")
-                    out["hidden_states"] = hidden_states
-                    if output_router_logits:
-                        out["past_router_logits"] = outputs["past_router_logits"]
-                    return out
-        else:
-            # 1f1b or otherwise
-            if stage_manager.is_last_stage():
-                hidden_states = outputs[0]
-                logits = self.lm_head(hidden_states)
-                logits = logits.float()
-                loss = None
-                aux_loss = None
-                if output_router_logits:
-                    aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
-                if labels is not None:
-                    # Shift so that tokens < n predict n
-                    shift_logits = logits[..., :-1, :].contiguous()
-                    shift_labels = labels[..., 1:].contiguous()
-                    # Flatten the tokens
-                    loss_fct = CrossEntropyLoss()
-                    shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                    shift_labels = shift_labels.view(-1)
-                    # Enable model parallelism
-                    shift_labels = shift_labels.to(shift_logits.device)
-                    loss = loss_fct(shift_logits, shift_labels)
-                    loss += self.router_aux_loss_coef * aux_loss
-                if not return_dict:
-                    output = (logits,) + outputs[1:]
-                    if output_router_logits:
-                        output = (aux_loss,) + output
-                    return (loss,) + output if loss is not None else output
-                return MoeCausalLMOutputWithPast(
-                    loss=loss,
-                    aux_loss=aux_loss,
-                    logits=logits,
-                    past_key_values=None,
-                    hidden_states=outputs[0],
-                    attentions=None,
-                    router_logits=outputs[-1],
-                )
-            else:
-                out = {}
-                hidden_states = outputs.get("hidden_states")
-                out["hidden_states"] = hidden_states
-                if output_router_logits:
-                    out["past_router_logits"] = outputs["past_router_logits"]
-                return out
-        # if stage_manager.is_last_stage():
-        #     hidden_states = outputs[0]
-        #     logits = self.lm_head(hidden_states)
-        #     logits = logits.float()
-        #     loss = None
-        #     if labels is not None:
-        #         # Shift so that tokens < n predict n
-        #         shift_logits = logits[..., :-1, :].contiguous()
-        #         shift_labels = labels[..., 1:].contiguous()
-        #         # Flatten the tokens
-        #         loss_fct = CrossEntropyLoss()
-        #         shift_logits = shift_logits.view(-1, self.config.vocab_size)
-        #         shift_labels = shift_labels.view(-1)
-        #         # Enable model parallelism
-        #         shift_labels = shift_labels.to(shift_logits.device)
-        #         loss = loss_fct(shift_logits, shift_labels)
-        #     aux_loss = None
-        #     if output_router_logits:
-        #         aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
-        #         if labels is not None:
-        #             loss += self.router_aux_loss_coef * aux_loss
-        #     if not return_dict:
-        #         output = (logits,) + outputs[1:]
-        #         if output_router_logits:
-        #             output = (aux_loss,) + output
-        #         return (loss,) + output if loss is not None else output
-        #     return MoeCausalLMOutputWithPast(
-        #         loss=loss,
-        #         aux_loss=aux_loss,
-        #         logits=logits,
-        #         past_key_values=None,
-        #         hidden_states=outputs[0],
-        #         attentions=None,
-        #         router_logits=outputs[-1],
-        #     )
-        # else:
-        #     out = {}
-        #     hidden_states = outputs.get("hidden_states")
-        #     out["hidden_states"] = hidden_states
-        #     if output_router_logits:
-        #         out["past_router_logits"] = outputs["past_router_logits"]
-        #     return out
+        if stage_manager.is_last_stage():
+            hidden_states = outputs[0]
+            logits = self.lm_head(hidden_states)
+            logits = logits.float()
+            loss = None
+            if labels is not None:
+                # Shift so that tokens < n predict n
+                shift_logits = logits[..., :-1, :].contiguous()
+                shift_labels = labels[..., 1:].contiguous()
+                # Flatten the tokens
+                loss_fct = CrossEntropyLoss()
+                shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                shift_labels = shift_labels.view(-1)
+                # Enable model parallelism
+                shift_labels = shift_labels.to(shift_logits.device)
+                loss = loss_fct(shift_logits, shift_labels)
+            aux_loss = None
+            if output_router_logits:
+                aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
+                if labels is not None:
+                    loss += self.router_aux_loss_coef * aux_loss
+            if not return_dict:
+                output = (logits,) + outputs[1:]
+                if output_router_logits:
+                    output = (aux_loss,) + output
+                return (loss,) + output if loss is not None else output
+            return MoeCausalLMOutputWithPast(
+                loss=loss,
+                aux_loss=aux_loss,
+                logits=logits,
+                past_key_values=None,
+                hidden_states=outputs[0],
+                attentions=None,
+                router_logits=outputs[-1],
+            )
+        else:
+            out = {}
+            hidden_states = outputs.get("hidden_states")
+            out["hidden_states"] = hidden_states
+            if output_router_logits:
+                out["past_router_logits"] = outputs["past_router_logits"]
+            return out
def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
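
Why the forwards above can now branch on is_first_stage()/is_last_stage() alone: under the zero-bubble V (zbv) schedule each pipeline stage holds two model chunks laid out down the stages and back up, so the final layers (norm, lm_head, loss) land on chunk 1 of pipeline stage 0, and the chunk-aware stage checks are presumably expected to account for that placement internally. The toy function below only illustrates the V-shaped placement; it is not ColossalAI code:

    def v_schedule_stage_order(num_stages: int) -> list:
        # Chunk 0 of every stage is filled going "down" the pipeline,
        # chunk 1 going back "up", so the last chunk lands on stage 0.
        down = list(range(num_stages))
        up = list(reversed(range(num_stages)))
        return down + up

    print(v_schedule_stage_order(4))  # [0, 1, 2, 3, 3, 2, 1, 0]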


@@ -343,18 +343,10 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
         """Get pipeline layers for current stage."""
         stage_manager = self.pipeline_stage_manager
         held_layers = super().get_held_layers()
-        if stage_manager.is_interleave:
-            if stage_manager.use_zbv:
-                if stage_manager.is_first_stage(ignore_chunk=True):
-                    held_layers.append(self.model.lm_head)
-            else:
-                if stage_manager.is_last_stage(ignore_chunk=True):
-                    held_layers.append(self.model.lm_head)
-        else:
-            if stage_manager.is_last_stage():
-                held_layers.append(self.model.lm_head)
-        # if stage_manager.is_last_stage():
-        #     held_layers.append(self.model.lm_head)
+        if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
+            held_layers.append(self.model.lm_head)
+        elif stage_manager.is_last_stage(ignore_chunk=True):
+            held_layers.append(self.model.lm_head)
         return held_layers
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
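
The same placement rule, pulled out as a standalone helper for readability (a sketch assuming a stage manager that exposes use_zbv, is_first_stage and is_last_stage as used above; this is not the actual policy API):

    def append_output_head(held_layers, stage_manager, lm_head):
        # zbv: the last model chunk, and therefore lm_head, sits on pipeline stage 0.
        if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
            held_layers.append(lm_head)
        # any other pp style: the head stays on the (chunk-aware) last stage.
        elif stage_manager.is_last_stage(ignore_chunk=True):
            held_layers.append(lm_head)
        return held_layers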


@@ -21,6 +21,7 @@ from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchF
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
+from colossalai.pipeline.schedule.v_schedule import PipelineGraph
 from colossalai.shardformer import PipelineGradientCheckpointConfig
 
 warnings.filterwarnings("ignore")
@@ -91,7 +92,7 @@
     parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
     parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
-    parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
+    parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
     parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
     parser.add_argument("--profile", action="store_true", help="Profile the code")
     parser.add_argument(
@@ -137,6 +138,11 @@
     # ==============================
     # Initialize Booster
     # ==============================
+    if args.config in MODEL_CONFIGS:
+        config = MODEL_CONFIGS[args.config]
+    else:
+        config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
+
     use_empty_init = True
     if args.plugin == "gemini":
         plugin = GeminiPlugin(
@@ -210,6 +216,23 @@
             fp8_communication=args.use_fp8_comm,
         )
     elif args.plugin == "3d":
+        if args.pp_style == "zbv":
+            mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length
+            mem_w = -32 * config.hidden_size
+            mem_b = -mem_w - mem_f
+            scheduler_nodes = PipelineGraph(
+                n_stage=args.pp,
+                n_micro=args.batch_size // args.mbs,
+                f_cost=1000,
+                b_cost=1000,
+                w_cost=1000,
+                c_cost=1,
+                f_mem=mem_f,
+                b_mem=mem_b,
+                w_mem=mem_w,
+            ).get_v_schedule()
+        else:
+            scheduler_nodes = None
         plugin = HybridParallelPlugin(
             tp_size=args.tp,
             pp_size=args.pp,
@@ -227,6 +250,7 @@
             overlap_allgather=args.overlap_allgather,
             use_fp8=args.use_fp8,
             fp8_communication=args.use_fp8_comm,
+            scheduler_nodes=scheduler_nodes,
             **hybrid_kwargs,
         )
     elif args.plugin == "3d_cpu":
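
The mem_f / mem_b / mem_w values above are rough per-layer memory weights handed to the zero-bubble scheduler so it can order forward (F), backward-input (B) and backward-weight (W) passes under a memory budget: forward grows activation memory, and the two backward passes release it again. A quick worked example with assumed dimensions (the numbers below are illustrative, not taken from any benchmark config):

    # Assumed example dimensions, for illustration only.
    hidden_size = 4096
    num_attention_heads = 32
    max_length = 4096

    mem_f = 34 * hidden_size + 5 * num_attention_heads * max_length  # forward adds activation memory
    mem_w = -32 * hidden_size                                        # weight-grad pass frees most of it
    mem_b = -mem_w - mem_f                                           # input-grad pass frees the remainder
    print(mem_f, mem_b, mem_w)  # 794624 -663552 -131072 (f + b + w == 0 per layer)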
@@ -256,10 +280,6 @@
     # ==============================
     dp_size = getattr(plugin, "dp_size", coordinator.world_size)
 
-    if args.config in MODEL_CONFIGS:
-        config = MODEL_CONFIGS[args.config]
-    else:
-        config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
     torch.cuda.manual_seed(42)
 
     dataset = RandomDataset(
         num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
@@ -334,8 +354,12 @@
                     return_loss=True,
                 )
                 loss = outputs["loss"]
-                if dist.get_rank() == dist.get_world_size() - 1:
-                    print(f"Step {step} loss: {loss}")
+                if args.pp_style == "zbv":
+                    if dist.get_rank() == 0:
+                        print(f"Step {step} loss: {loss}")
+                else:
+                    if dist.get_rank() == dist.get_world_size() - 1:
+                        print(f"Step {step} loss: {loss}")
 
                 optimizer.step()
                 optimizer.zero_grad()
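
This loss-printing branch is the benchmark side of the zbv loss-is-None fix: with the V schedule the loss is produced on pipeline stage 0 rather than on the last rank, so printing from the last rank showed None. The same branch, factored into a helper purely for illustration (the helper itself is not part of the benchmark):

    import torch.distributed as dist

    def print_loss(step, loss, pp_style):
        # zbv: the final model chunk (and the loss) lives on rank 0;
        # 1f1b / interleaved: it lives on the last rank.
        owner = 0 if pp_style == "zbv" else dist.get_world_size() - 1
        if dist.get_rank() == owner:
            print(f"Step {step} loss: {loss}")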


@@ -227,7 +227,6 @@
     )
 
     optimizer = HybridAdam(model.parameters())
-    # optimizer = torch.optim.SGD(model.parameters(), lr=1)
     torch.set_default_dtype(torch.bfloat16)
 
     model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
@@ -258,8 +257,12 @@
                     return_loss=True,
                 )
                 loss = outputs["loss"]
-                if dist.get_rank() == dist.get_world_size() - 1:
-                    print(f"Step {step} loss: {loss}")
+                if args.pp_style == "zbv":
+                    if dist.get_rank() == 0:
+                        print(f"Step {step} loss: {loss}")
+                else:
+                    if dist.get_rank() == dist.get_world_size() - 1:
+                        print(f"Step {step} loss: {loss}")
 
                 optimizer.step()
                 optimizer.zero_grad()