mirror of https://github.com/hpcaitech/ColossalAI
[fix] fix fail case test_shard_llama
parent 2eca112c90
commit d0ec221b38
@@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 import torch.cuda
+import torch.distributed
 from torch.nn import Module, ModuleList
 from torch.utils._pytree import tree_flatten, tree_map
 
@@ -544,7 +545,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             ctx = optimizer.no_sync()
         except AttributeError:
             ctx = model_chunk.no_sync()
-
         with ctx:
             optimizer.backward_by_grad(
                 tensor=output_obj_,
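The try/except in this hunk prefers the optimizer's no_sync() and falls back to the module's own no_sync() when the optimizer does not provide one. A minimal standalone sketch of that fallback pattern, with an illustrative helper name (not ColossalAI's API):

from contextlib import nullcontext

def grad_sync_context(optimizer, model_chunk):
    # Prefer the optimizer's no_sync() (e.g. a ZeRO/hybrid wrapper exposes one).
    try:
        return optimizer.no_sync()
    except AttributeError:
        # Plain torch.optim optimizers have no no_sync(); fall back to the model
        # (DDP-style), or do nothing if the model is not wrapped either.
        return getattr(model_chunk, "no_sync", nullcontext)()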
@@ -228,5 +228,4 @@ class PipelineStageManager:
         start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
         for i in range(start_position, start_position + remainder):
             layers_per_stage[i] += 1
-        # print(f"layers_per_stage {layers_per_stage}")
         return layers_per_stage
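The lines around the removed debug print assign the leftover layers to the middle pipeline stages. A self-contained sketch of that distribution, assuming the surrounding method first divides num_layers evenly (the standalone wrapper function is illustrative):

def distribute_layers(num_layers: int, num_stages: int, num_model_chunks: int) -> list:
    quotient, remainder = divmod(num_layers, num_stages * num_model_chunks)
    layers_per_stage = [quotient] * (num_stages * num_model_chunks)
    # Hand the remainder to the middle virtual stages so the outer stages,
    # which typically also hold the embedding and LM head, stay lighter.
    start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
    for i in range(start_position, start_position + remainder):
        layers_per_stage[i] += 1
    return layers_per_stage

# e.g. 10 layers over 4 stages (1 model chunk) -> [2, 3, 3, 2]
assert distribute_layers(10, 4, 1) == [2, 3, 3, 2]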
@@ -32,7 +32,6 @@ from colossalai.shardformer.shard import ShardConfig
 from ..layer import ColoAttention, RingAttention, dist_cross_entropy
 
 _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
-_GLOBAL_ORDER_ = 0
 
 
 class LlamaPipelineForwards:
@@ -194,10 +193,6 @@ class LlamaPipelineForwards:
         assert num_ckpt_layers <= end_idx - start_idx
 
         for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
-            # global _GLOBAL_ORDER_
-            # if torch.distributed.get_rank() == 0:
-            #     print(f"rank {torch.distributed.get_rank()} {stage_manager.stage}; start:{start_idx}, end:{end_idx} hidden_states require grad{hidden_states.requires_grad}")
-            # #     _GLOBAL_ORDER_ += 1
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
             if idx - start_idx < num_ckpt_layers:
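The surviving condition at the end of this hunk, if idx - start_idx < num_ckpt_layers, checkpoints only the first num_ckpt_layers of the layers owned by this pipeline stage. A minimal standalone sketch of that selective-checkpointing pattern (the layer list and call signature are illustrative, not the model's real interface):

from torch.utils.checkpoint import checkpoint

def forward_slice(layers, hidden_states, start_idx, end_idx, num_ckpt_layers):
    for idx, layer in enumerate(layers[start_idx:end_idx], start=start_idx):
        if idx - start_idx < num_ckpt_layers:
            # Recompute this layer's activations during backward to save memory.
            hidden_states = checkpoint(layer, hidden_states, use_reentrant=False)
        else:
            hidden_states = layer(hidden_states)
    return hidden_states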
@@ -221,8 +216,6 @@ class LlamaPipelineForwards:
                 use_cache=use_cache,
                 cache_position=cache_position,
             )
-            # if torch.distributed.get_rank() == 0:
-            #     print(f"rank {torch.distributed.get_rank()} {stage_manager.stage}; start:{start_idx}, end:{end_idx} layer_outputs require grad {layer_outputs[0].requires_grad}")
             hidden_states = layer_outputs[0]
 
             if use_cache:
@@ -287,6 +287,11 @@ def main():
     # ==============================
     dp_size = getattr(plugin, "dp_size", coordinator.world_size)
 
+    if args.config in MODEL_CONFIGS:
+        config = MODEL_CONFIGS[args.config]
+    else:
+        config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
+
     torch.cuda.manual_seed(42)
     dataset = RandomDataset(
         num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
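The added branch lets the benchmark accept either a named preset or an arbitrary Hugging Face checkpoint. A standalone sketch of the same resolution pattern, assuming an illustrative MODEL_CONFIGS table and helper name:

from transformers import AutoConfig, LlamaConfig

MODEL_CONFIGS = {
    "7b": LlamaConfig(),  # placeholder preset; the real table defines the model sizes
}

def resolve_config(name_or_path: str):
    if name_or_path in MODEL_CONFIGS:
        return MODEL_CONFIGS[name_or_path]
    # Otherwise treat the argument as a model id or local path and read its config.
    return AutoConfig.from_pretrained(name_or_path, trust_remote_code=True)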
@@ -923,10 +923,11 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
 @parameterize(
     "config",
     [
-        # (0, 4, 1, 1),
-        (1, 2, 2, 1),
+        # (1, 2, 2, 1), # Pass
+        # TODO: only supports pp + tp acceleration; will support pp-only and no-tp hybrid in the future
+        (0, 4, 1, 1),
         # (1, 2, 1, 2),
-        # (1, 1, 2, 2), # TODO: no pp show gather result err
+        # (1, 1, 2, 2),
     ],
 )
 def run_with_booster_hybridplugin(config: Tuple[int, ...]):
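In the test above, each tuple in the config list drives one run of the decorated test body. A hedged sketch of how a parameterize-style decorator expands such a list (an illustration, not ColossalAI's own implementation):

from functools import wraps

def parameterize(arg_name, values):
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            # Call the wrapped test once per configuration value.
            for value in values:
                fn(*args, **{**kwargs, arg_name: value})
        return wrapper
    return decorator

@parameterize("config", [(0, 4, 1, 1)])
def run_case(config):
    print("running with config", config)

run_case()  # prints: running with config (0, 4, 1, 1)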