pull/413/head
li126com 2023-09-26 16:11:56 +08:00
parent 033e646191
commit fbcd509ff9
1 changed file with 324 additions and 0 deletions

new_test.py Normal file

@@ -0,0 +1,324 @@
import copy
import multiprocessing as mp
import random
import numpy as np
import pytest
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import internlm
from internlm.core.context.parallel_context import Config
from internlm.core.trainer import Trainer
from internlm.core.scheduler import (
InterleavedPipelineScheduler,
NonPipelineScheduler,
PipelineScheduler,
SchedulerHook,
)
from internlm.data.utils import unpack_data
from internlm.core.scheduler.pipeline_scheduler import get_tensor_shape
from internlm.core.context import global_context as gpc
from internlm.core.context import ParallelMode
from internlm.core.scheduler import SchedulerMetricHook
from internlm.model.metrics import AccPerplex
from internlm.train import (
get_train_data_loader,
get_validation_data_loader,
initialize_llm_profile,
initialize_model,
initialize_optimizer,
load_new_batch,
record_current_batch_training_metrics,
)
from internlm.core.engine import Engine
from internlm.model.loss import FlashGPTLMLoss
from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler
from internlm.core.trainer import TrainState
from internlm.solver.pipeline_utils import partition_uniform
import torch.distributed as dist
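# MlpModel is the toy per-pipeline-stage module used by this test: each stage owns
# (end - start) small Linear layers and its forward simply adds the outputs of the
# first two blocks applied to input_ids, keeping the numerics easy to inspect.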
class MlpModel(nn.Module):
def __init__(self, start, end):
super().__init__()
        self.part = [start, end]
        self.blocks = nn.ModuleList([nn.Linear(8, 8, bias=False) for _ in range(end - start)])
def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
print(f'pp stage: {gpc.get_local_rank(ParallelMode.PIPELINE)} MLP {self.part} fwd:', input_ids.shape, flush=True)
print(gpc.get_global_rank(), 'len_blocsk:', len(self.blocks), flush=True)
current_device = torch.cuda.current_device()
print(gpc.get_global_rank(), 'current_device:', current_device, flush=True)
input_ids = input_ids.to(current_device)
print(gpc.get_global_rank(), 'mlp_input_data:', input_ids, input_ids.shape, type(input_ids), flush=True)
x = self.blocks[0](input_ids) + self.blocks[1](input_ids)
print(gpc.get_global_rank(), 'mlp_output_data:', x, x.shape, flush=True)
return x
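# Test configuration: an 8-stage pipeline (no tensor or sequence parallelism, zero1=1),
# bfloat16 parameters, and 16 micro-batches of batch size 1 with sequence length 8.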
config = Config(
dict(
HIDDEN_SIZE=4,
parallel=dict(zero1=1, pipeline=dict(size=8, interleaved_overlap=False), sequence_parallel=False, tensor=1),
model_type="INTERNLM",
data=dict(seq_len=8, micro_num=16, micro_bsz=1, pack_sample_into_one=False, min_length=0, total_steps=9999),
model=dict(
dtype=torch.bfloat16,
),
resume_tb_folder="",
tensorboard_folder="",
alert_address=None,
monitor=dict(alert=dict(enable_feishu_alert=False, feishu_alert_address=None, light_monitor_address=None)),
grad_scaler=dict(
fp16=dict(
initial_scale=1,
min_scale=1,
growth_interval=1,
),
growth_factor=1.1,
backoff_factor=0.9,
max_scale=1,
hysteresis=1,
),
adam=dict(
lr=1e-4,
adam_beta1=0.9,
adam_beta2=0.95,
adam_beta2_c=0,
adam_eps=1e-8,
weight_decay=0.01,
),
hybrid_zero_optimizer=dict(
overlap_sync_grad=False,
overlap_sync_param=False,
reduce_bucket_size=512 * 1024 * 1024,
clip_grad_norm=1.0,
),
        beta2_scheduler=dict(
init_beta2=0.95,
c=0,
cur_iter=-1,
),
        lr_scheduler=dict(
total_steps=100,
init_steps=0, # optimizer_warmup_step
warmup_ratio=0.01,
eta_min=1e-5,
last_epoch=-1,
)
)
)
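# Fake the torchrun environment variables for this process and bootstrap the InternLM
# runtime (process groups and parallel context) from the config above.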
def build_environment(rank, world_size):
import os
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "33333"
torch.cuda.empty_cache()
# launcher="torch"
internlm.launch_from_torch(config=config, seed=1024)
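# Tolerant tensor comparison: cast both sides to the target dtype and apply
# dtype-dependent rtol/atol before delegating to torch.testing.assert_close.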
def loose_close(a, b, dtype: torch.dtype = torch.float32):
if dtype is torch.float32:
rtol = 1.3e-6
atol = 1e-5
    elif dtype is torch.bfloat16:
        rtol = 2e-2
        atol = 2e-2
    else:
        raise ValueError(f"Unsupported dtype for loose_close: {dtype}")
if isinstance(a, torch.Tensor):
a = a.detach().to(dtype)
b = b.detach().to(dtype)
assert_close(a, b, rtol=rtol, atol=atol)
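# Seed Python, NumPy and torch (CPU and CUDA); optionally trade cuDNN benchmarking
# for deterministic kernels.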
def seed_all(seed, cuda_deterministic=False):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if cuda_deterministic: # slower, more reproducible
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
else:
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
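# Shard num_layers uniformly across the pipeline ranks with partition_uniform and
# instantiate one MlpModel per chunk owned by this rank.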
def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
"""
build generic model 1d
Args:
        num_layers (int): The number of layers.
        num_chunks (int): The number of partitions in pipeline parallel.
        device (Optional[Union[str, torch.device]]): The device to use. torch.device("cuda") by default.
"""
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
parts = all_parts[pipeline_rank]
if gpc.is_rank_for_log():
print(f"The layer sharding is {all_parts}.", flush=True)
models = []
for start, end in parts:
models.append(MlpModel(start, end).cuda())
torch.distributed.barrier()
if len(models) == 1:
model = models[0]
else:
model = nn.ModuleList(models)
return model
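# A sum-reduced MSE loss stands in for the language-model loss so the pipeline output
# can be checked against a plain regression target.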
class MyLoss(nn.Module):
def __init__(self):
super().__init__()
def forward(self, logits, labels):
loss = torch.nn.MSELoss(reduction='sum')
print(logits, flush=True)
print(labels, flush=True)
return loss(logits, labels)
def exam_pipeline_parallel(args):
import os
# rank, world_size = args
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
build_environment(rank, world_size)
local_rank = int(os.environ["LOCAL_RANK"])
print('rank_com:', rank, local_rank)
device = torch.device(f"cuda:{local_rank}")
# print('device_id:', device)
# torch.cuda.set_device(device)
seed_all(1024)
    dtype = gpc.config.model["dtype"]
# torch_model = MlpModel().to(device)
# pp_model = copy.deepcopy(torch_model).to(dtype)
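    # 16 layers across 8 pipeline stages with num_chunks=1: each rank owns 2 Linear blocks.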
pp_model = _build_generic_model_1d(num_layers=16, num_chunks=1)
pp_model = pp_model.to(dtype)
print(pp_model, flush=True)
scheduler_hooks = [
SchedulerMetricHook(
skip=True
),
]
micro_num = gpc.config.data.micro_num
seq_len = gpc.config.data.seq_len
gpc.config.NUM_MICRO_BATCHES = micro_num
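    # Non-interleaved pipeline schedule: run micro_num micro-batches through the stages in the model dtype.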
scheduler = PipelineScheduler(
data_process_func=None,
num_microbatches=micro_num,
dtype=gpc.config.model["dtype"],
tensor_shape=None,
scatter_gather_tensors=False,
scheduler_hooks=scheduler_hooks,
)
print(f"gpc.config.hybrid_zero_optimizer: {gpc.config.hybrid_zero_optimizer}", flush=True)
# optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=pp_model)
# criterion = FlashGPTLMLoss(parallel_output=False, label_smoothing=0)
from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer
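    # Keep the optimizer simple for this test: wrap a plain AdamW in InternLM's
    # BaseOptimizer instead of building the hybrid ZeRO optimizer.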
optimizer = BaseOptimizer(torch.optim.AdamW(
params=[{"params": pp_model.parameters()}],
lr=1e-4,
betas=(0.9, 0.95),
eps=1e-8,
))
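    # Assemble the Engine with the toy loss; the PipelineSharedModuleGradientHandler
    # keeps gradients of modules shared across pipeline stages in sync.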
engine = Engine(
model=pp_model,
optimizer=optimizer,
lr_scheduler=None,
beta2_scheduler=None,
criterion=MyLoss().to(dtype),
        gradient_handlers=[dict(type="PipelineSharedModuleGradientHandler")],
clip_grad_norm=gpc.config.hybrid_zero_optimizer.get("clip_grad_norm", 0.0),
)
scheduler.pre_processing(engine)
engine.train()
# engine.zero_grad()
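    # Synthetic data: every micro-batch is the sequence [0, 1, ..., seq_len - 1],
    # used both as input_ids and as the regression target.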
x_list = []
y_list = []
for _ in range(micro_num):
        x_list.append(list(range(seq_len)))
        y_list.append(list(range(seq_len)))
    xs = torch.tensor(x_list).to(device).to(dtype)
    ys = torch.tensor(y_list).to(device).to(dtype)
    xs.requires_grad_()
    ys.requires_grad_()
    print(xs.shape, ys.shape, flush=True)
    input_list = [{"input_ids": xs}, ys]
# torch_input = torch.tensor([[0,1,2,3]]).to(device).to(torch.float32)
# torch_label = torch.tensor([[1]]).to(device).to(torch.int64)
# print('label_shape:', input_list[1].shape)
# input_list = [{'input_ids':torch.rand(1, 4).cuda()}, torch.rand(1, 4).cuda()]
# input = input_list[0]
# print(input)
# output = torch_model(input)
# print(output)
print('local_rank:', gpc.get_local_rank(ParallelMode.PIPELINE), 'start schedule', flush=True)
_, _, loss = scheduler.forward_backward_step(engine, input_list, forward_only=False, return_loss=True, return_output_label=False)
print('local_rank:', gpc.get_local_rank(ParallelMode.PIPELINE), 'end schedule', flush=True)
dist.barrier()
torch.cuda.synchronize()
engine.step()
torch.cuda.synchronize()
# torch_output = torch_model(input_ids=torch_input)
# torch_loss = criterion(torch_output, torch_label).unsqueeze(0)
# if rank == 0:
# print('loss:', loss)
# print('torch_loss:', torch_loss)
#loose_close(loss, torch_loss, dtype=dtype)
# torch_loss.backward()
    print('local_rank:', gpc.get_local_rank(ParallelMode.PIPELINE), 'finished optimizer step', flush=True)
# def test_pipeline_parallel():
# ctx = mp.get_context("spawn")
# with ctx.Pool(processes=8) as pool:
# pool.map(
# exam_pipeline_parallel,
# [[rank, 8] for rank in range(8)],
# )
# pool.close()
# pool.join()
if __name__ == "__main__":
# pytest.main(["-s", "-q", "test_pipeline.py"])
exam_pipeline_parallel(None)
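# A minimal launch sketch (an assumption, not part of this diff): start one process
# per GPU on a single 8-GPU node, e.g.
#   torchrun --nproc_per_node=8 new_test.py
# so that RANK / LOCAL_RANK / WORLD_SIZE are set before the test reads them.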