[fix] fix fail case test_shard_llama

pull/6083/head
duanjunwen 2024-10-25 02:28:55 +00:00
parent 2eca112c90
commit d0ec221b38
5 changed files with 10 additions and 12 deletions

View File

@@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 import torch
 import torch.cuda
+import torch.distributed
 from torch.nn import Module, ModuleList
 from torch.utils._pytree import tree_flatten, tree_map
@@ -544,7 +545,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             ctx = optimizer.no_sync()
         except AttributeError:
             ctx = model_chunk.no_sync()
-
         with ctx:
             optimizer.backward_by_grad(
                 tensor=output_obj_,
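The fallback in this hunk is worth spelling out: the scheduler first asks the optimizer wrapper for a no_sync() context and, if the wrapper does not expose one, falls back to the model chunk's no_sync() (as a DDP-wrapped module would provide) before running the backward pass from the received gradient. A minimal sketch of that pattern, assuming a hypothetical helper name and a grad keyword argument that is truncated out of the hunk above:

from contextlib import nullcontext

def backward_without_grad_sync(optimizer, model_chunk, output_obj, output_obj_grad):
    # Prefer the optimizer wrapper's no_sync(); fall back to the module's, or to a no-op.
    try:
        ctx = optimizer.no_sync()
    except AttributeError:
        ctx = getattr(model_chunk, "no_sync", nullcontext)()
    with ctx:
        # Backward from this stage's output using the gradient received from the next stage.
        optimizer.backward_by_grad(tensor=output_obj, grad=output_obj_grad)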

View File

@@ -228,5 +228,4 @@ class PipelineStageManager:
         start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
         for i in range(start_position, start_position + remainder):
             layers_per_stage[i] += 1
-        # print(f"layers_per_stage {layers_per_stage}")
         return layers_per_stage
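For context, the method losing this debug print distributes the decoder layers over num_stages * num_model_chunks virtual stages: every stage gets an even share, and the remainder is handed to the stages in the middle of the pipeline, which is what the start_position formula above computes. A standalone sketch of that policy (not the library method itself):

def distribute_layers(num_layers: int, num_stages: int, num_model_chunks: int) -> list:
    # Even base share per virtual pipeline stage.
    num_slots = num_stages * num_model_chunks
    layers_per_stage = [num_layers // num_slots] * num_slots
    # Hand the leftover layers to the middle stages, mirroring the hunk above.
    remainder = num_layers % num_slots
    start_position = num_slots // 2 - remainder // 2
    for i in range(start_position, start_position + remainder):
        layers_per_stage[i] += 1
    return layers_per_stage

print(distribute_layers(30, 4, 2))  # [3, 4, 4, 4, 4, 4, 4, 3]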

View File

@@ -32,7 +32,6 @@ from colossalai.shardformer.shard import ShardConfig
 from ..layer import ColoAttention, RingAttention, dist_cross_entropy

 _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
-_GLOBAL_ORDER_ = 0


 class LlamaPipelineForwards:
@@ -194,10 +193,6 @@ class LlamaPipelineForwards:
             assert num_ckpt_layers <= end_idx - start_idx

         for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
-            # global _GLOBAL_ORDER_
-            # if torch.distributed.get_rank() == 0:
-            # print(f"rank {torch.distributed.get_rank()} {stage_manager.stage}; start:{start_idx}, end:{end_idx} hidden_states require grad{hidden_states.requires_grad}")
-            # # _GLOBAL_ORDER_ += 1
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
             if idx - start_idx < num_ckpt_layers:
@@ -221,8 +216,6 @@ class LlamaPipelineForwards:
                     use_cache=use_cache,
                     cache_position=cache_position,
                 )
-            # if torch.distributed.get_rank() == 0:
-            # print(f"rank {torch.distributed.get_rank()} {stage_manager.stage}; start:{start_idx}, end:{end_idx} layer_outputs require grad {layer_outputs[0].requires_grad}")
             hidden_states = layer_outputs[0]

             if use_cache:
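The lines removed in these two hunks were rank-0 debug prints inside the per-stage decoder loop; the surviving logic checkpoints only the first num_ckpt_layers layers of this stage's slice and runs the rest normally. A self-contained sketch of that control flow using torch.utils.checkpoint (the real forward goes through transformers' gradient-checkpointing hook and passes the full attention arguments):

from torch.utils.checkpoint import checkpoint

def run_stage_layers(layers, hidden_states, start_idx, end_idx, num_ckpt_layers):
    # Recompute activations in backward for the first num_ckpt_layers layers
    # of this pipeline stage's slice; run the remaining layers normally.
    for idx, layer in enumerate(layers[start_idx:end_idx], start=start_idx):
        if idx - start_idx < num_ckpt_layers:
            hidden_states = checkpoint(layer, hidden_states, use_reentrant=False)
        else:
            hidden_states = layer(hidden_states)
    return hidden_states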

View File

@@ -287,6 +287,11 @@ def main():
     # ==============================
     dp_size = getattr(plugin, "dp_size", coordinator.world_size)

+    if args.config in MODEL_CONFIGS:
+        config = MODEL_CONFIGS[args.config]
+    else:
+        config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
+
     torch.cuda.manual_seed(42)
     dataset = RandomDataset(
         num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
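The added branch lets --config name either a preset key in MODEL_CONFIGS or an arbitrary Hugging Face model id / local path. A minimal sketch of the same lookup in isolation; the preset dict below is illustrative only, not the benchmark's actual MODEL_CONFIGS:

from transformers import AutoConfig, LlamaConfig

MODEL_CONFIGS = {
    "toy": LlamaConfig(num_hidden_layers=4, hidden_size=256, intermediate_size=512),  # placeholder preset
}

def resolve_config(name: str):
    if name in MODEL_CONFIGS:
        return MODEL_CONFIGS[name]
    # Otherwise treat the argument as a model id or a local checkpoint path.
    return AutoConfig.from_pretrained(name, trust_remote_code=True)

config = resolve_config("toy")
print(config.vocab_size)  # RandomDataset above draws token ids from this range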

View File

@@ -923,10 +923,11 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
 @parameterize(
     "config",
     [
-        # (0, 4, 1, 1),
-        (1, 2, 2, 1),
+        (1, 2, 2, 1), # Pass
+        # TODO: only support pp + tp accleration; Will support fully pp and None tp Hybrid in furture;
+        (0, 4, 1, 1),
         # (1, 2, 1, 2),
-        # (1, 1, 2, 2), # TODO: no pp show gather result err
+        # (1, 1, 2, 2),
     ],
 )
 def run_with_booster_hybridplugin(config: Tuple[int, ...]):
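For reference, colossalai.testing.parameterize calls the decorated function once per listed value, so re-enabling (0, 4, 1, 1) adds that parallel layout to the sweep next to (1, 2, 2, 1). A minimal sketch of the mechanism, with the meaning of the tuple fields left to the test body:

from colossalai.testing import parameterize

@parameterize("config", [(1, 2, 2, 1), (0, 4, 1, 1)])
def run_case(config):
    # Called once per tuple; the real test presumably builds a hybrid-parallel plugin from these sizes.
    print(f"running hybrid-plugin case {config}")

run_case()  # executes both cases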