mirror of https://github.com/InternLM/InternLM
Test(pp): test pipeline parallel (#413)
* test: pp
* feat: add pp test
* test pp
* pp test
* pp test
* test pp
parent: aa5e34d815
commit: 30f610b1fa
@@ -0,0 +1,318 @@
import multiprocessing as mp
import random

import numpy as np
import pytest
import torch
from torch import nn
from torch.testing import assert_close

import internlm
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.context.parallel_context import Config
from internlm.core.engine import Engine
from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler
from internlm.core.scheduler import (
    InterleavedPipelineScheduler,
    PipelineScheduler,
    SchedulerMetricHook,
)
from internlm.solver.pipeline_utils import partition_uniform
from internlm.train import initialize_optimizer


class MlpModel(nn.Module):
    """
    Toy pipeline-stage model: a stack of Linear(8, 8) blocks covering layers [start, end).
    """

    def __init__(self, start, end, model_type=None):
        super().__init__()
        self.part = [start, end]
        self.blocks = nn.ModuleList([nn.Linear(8, 8, bias=False) for lid in range(end - start)])
        self.model_type = model_type

    def forward(self, hidden_states=None, input_ids=None):
        # Non-first pipeline stages receive activations as hidden_states rather than input_ids;
        # the single-device "torch" reference model always consumes input_ids directly.
        if self.model_type != "torch" and self.part[0] != 0:
            input_ids = hidden_states

        for i in range(self.part[1] - self.part[0]):
            input_ids = self.blocks[i](input_ids)
        return input_ids


class MyLoss(nn.Module):
    """
    Sum-reduced MSE loss shared by the pipeline engine and the torch reference run.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits, labels):
        loss = torch.nn.MSELoss(reduction="sum")
        return loss(logits, labels)


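# Baseline 8-rank pipeline-parallel config. The knobs that matter for this test:
#   parallel.pipeline.size = 8   -> eight pipeline stages, one per GPU rank
#   model.num_chunks = 2         -> model chunks per rank (interleaved scheduling when > 1)
#   data.micro_num = 16          -> micro-batches per training step
#   model.dtype = bfloat16       -> the pipelined run uses bf16; the reference run below uses fp32
# micro_num, num_chunks and interleaved_overlap are overridden per test case in exam_pipeline_parallel.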
config = Config(
    dict(
        gradient_handler=[dict(type="PipelineSharedModuleGradientHandler")],
        parallel=dict(zero1=1, pipeline=dict(size=8, interleaved_overlap=False), sequence_parallel=False, tensor=1),
        model_type="INTERNLM",
        data=dict(seq_len=8, micro_num=16, micro_bsz=1, pack_sample_into_one=False, min_length=0, total_steps=9999),
        model=dict(
            dtype=torch.bfloat16,
            num_chunks=2,
            use_flash_attn=True,
        ),
        resume_tb_folder="",
        tensorboard_folder="",
        alert_address=None,
        monitor=dict(alert=dict(enable_feishu_alert=False, feishu_alert_address=None, light_monitor_address=None)),
        grad_scaler=dict(
            fp16=dict(
                initial_scale=1,
                min_scale=1,
                growth_interval=1,
            ),
            growth_factor=1.1,
            backoff_factor=0.9,
            max_scale=1,
            hysteresis=1,
        ),
        adam=dict(
            lr=1e-4,
            adam_beta1=0.9,
            adam_beta2=0.95,
            adam_beta2_c=0,
            adam_eps=1e-8,
            weight_decay=0.01,
        ),
        hybrid_zero_optimizer=dict(
            overlap_sync_grad=False,
            overlap_sync_param=False,
            reduce_bucket_size=512 * 1024 * 1024,
            clip_grad_norm=1.0,
        ),
        beta2_scheduler=dict(
            init_beta2=0.95,
            c=0,
            cur_iter=-1,
        ),
        lr_scheduler=dict(
            total_steps=100,
            init_steps=0,
            warmup_ratio=0.01,
            eta_min=1e-5,
            last_epoch=-1,
        ),
    )
)


def build_environment(rank, world_size):
    import os

    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "33333"
    torch.cuda.empty_cache()
    # launcher="torch"
    internlm.launch_from_torch(config=config, seed=1024)


def loose_close(a, b, dtype: torch.dtype = torch.float32):
    # Only float32 and bfloat16 tolerances are defined; other dtypes are not expected in this test.
    if dtype is torch.float32:
        rtol = 1.3e-6
        atol = 1e-5
    elif dtype is torch.bfloat16:
        rtol = 2e-2
        atol = 2e-2

    if isinstance(a, torch.Tensor):
        a = a.detach().to(dtype)
        b = b.detach().to(dtype)

    assert_close(a, b, rtol=rtol, atol=atol)


def seed_all(seed, cuda_deterministic=False):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    if cuda_deterministic:  # slower, more reproducible
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True


def _build_generic_model_1d(num_layers, num_chunks):
    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

    all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
    parts = all_parts[pipeline_rank]
    if gpc.is_rank_for_log():
        print(f"The layer sharding is {all_parts}.", flush=True)

    models = []
    for start, end in parts:
        models.append(MlpModel(start, end).cuda())
    torch.distributed.barrier()
    if len(models) == 1:
        model = models[0]
    else:
        model = nn.ModuleList(models)

    return model


def exam_pipeline_parallel(args):
    # init
    rank, world_size, micro_num, num_chunks, interleaved_overlap = args
    config.data.micro_num = micro_num
    config.model.num_chunks = num_chunks
    config.parallel.pipeline.interleaved_overlap = interleaved_overlap

    build_environment(rank, world_size)

    device = torch.device(f"cuda:{rank}")
    dtype = config.model["dtype"]

    # set seed
    seed_all(1024)

    # pp model
    pp_model = _build_generic_model_1d(num_layers=32, num_chunks=num_chunks)
    pp_model = pp_model.to(dtype)
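    # With num_layers=32 split over pipeline_size=8, partition_uniform hands each rank num_chunks
    # (start, end) ranges; e.g. with num_chunks=2 one would expect rank 0 to own something like
    # [(0, 2), (16, 18)] (a sketch of the expected sharding; the logging rank prints the exact split).
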
    # pp scheduler
    scheduler_hooks = [
        SchedulerMetricHook(skip=True),
    ]

    seq_len = gpc.config.data.seq_len
    gpc.config.NUM_MICRO_BATCHES = micro_num
    communication_overlap = interleaved_overlap

    if num_chunks == 1:
        # non-interleaved pp
        scheduler = PipelineScheduler(
            data_process_func=None,
            num_microbatches=micro_num,
            dtype=dtype,
            tensor_shape=[1, 8],
            scatter_gather_tensors=False,
            scheduler_hooks=scheduler_hooks,
        )
    else:
        # interleaved pp
        if micro_num < gpc.get_world_size(ParallelMode.PIPELINE):
            try:
                scheduler = InterleavedPipelineScheduler(
                    num_microbatches=micro_num,
                    num_chunks=gpc.config.model.num_chunks,
                    dtype=dtype,
                    tensor_shape=[1, 8],
                    scatter_gather_tensors=False,
                    scheduler_hooks=scheduler_hooks,
                    communication_overlap=communication_overlap,
                )
            except AssertionError:
                return
            else:
                raise RuntimeError("Error: AssertionError should occur when micro_num < pipeline parallel world size")
        else:
            scheduler = InterleavedPipelineScheduler(
                num_microbatches=micro_num,
                num_chunks=gpc.config.model.num_chunks,
                dtype=dtype,
                tensor_shape=[1, 8],
                scatter_gather_tensors=False,
                scheduler_hooks=scheduler_hooks,
                communication_overlap=communication_overlap,
            )
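    # The interleaved scheduler requires num_microbatches >= pipeline world size (presumably so every
    # stage has at least one micro-batch in flight), so cases with micro_num < 8 and num_chunks > 1
    # are expected to hit the AssertionError branch above and return early.
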
    # pp optimizer and engine
    optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=pp_model)
    engine = Engine(
        model=pp_model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        beta2_scheduler=beta2_scheduler,
        criterion=MyLoss().to(dtype),
        gradient_handlers=[PipelineSharedModuleGradientHandler(model=pp_model, optimizer=optimizer)],
        clip_grad_norm=gpc.config.hybrid_zero_optimizer.get("clip_grad_norm", 0.0),
    )

    scheduler.pre_processing(engine)
    engine.train()

    # create input
    x_list = []
    y_list = []
    for _ in range(micro_num):
        x_list.append(list(range(seq_len)))
        y_list.append(list(range(seq_len)))
    xs = torch.tensor(x_list).to(device).to(dtype)
    yx = torch.tensor(y_list).to(device).to(dtype)

    input_list = [{"input_ids": xs}, yx]

    # pp forward and backward
    output, _, loss = scheduler.forward_backward_step(
        engine, input_list, forward_only=False, return_loss=True, return_output_label=True
    )

    engine.step()

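    # Only the last pipeline rank receives the final output and loss, so it rebuilds the full
    # 32-layer model on a single device in fp32 and checks one step against the pipelined bf16 run.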
    # torch related
    if gpc.is_last_rank(ParallelMode.PIPELINE):
        torch_xs = torch.tensor(x_list).to(device).to(torch.float32)
        torch_ys = torch.tensor(y_list).to(device).to(torch.float32)
        torch_model = MlpModel(0, 32, "torch").to(device)
        torch_optimizer = torch.optim.AdamW(
            params=[{"params": torch_model.parameters(), "weight_decay": config.adam.weight_decay}],
            lr=config.adam.lr,
            betas=(config.adam.adam_beta1, config.adam.adam_beta2),
            eps=config.adam.adam_eps,
        )

        # check output
        torch_output = torch_model(input_ids=torch_xs)  # pylint: disable=E1102
        loose_close(torch_output, output, dtype=dtype)

        torch_criterion = MyLoss().to(torch.float32)
        torch_loss = torch_criterion(torch_output, torch_ys) / micro_num  # pylint: disable=E1102
        torch_loss.backward()
        torch_optimizer.step()

        # check loss
        loose_close(torch_loss, loss[0], dtype=dtype)


@pytest.mark.parametrize("micro_num", [4, 8, 16])
|
||||
@pytest.mark.parametrize("num_chunks", [1, 2, 4])
|
||||
@pytest.mark.parametrize("interleaved_overlap", [True, False])
|
||||
def test_pipeline_parallel(micro_num, num_chunks, interleaved_overlap):
|
||||
ctx = mp.get_context("spawn")
|
||||
with ctx.Pool(processes=8) as pool:
|
||||
pool.map(
|
||||
exam_pipeline_parallel,
|
||||
[[rank, 8, micro_num, num_chunks, interleaved_overlap] for rank in range(8)],
|
||||
)
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
|
||||
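# Each parametrized case spawns 8 worker processes, one per rank, so running this test assumes a
# host with 8 visible CUDA devices.
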
if __name__ == "__main__":
|
||||
pytest.main(["-s", "-q", "test_pipeline.py"])