Mirror of https://github.com/InternLM/InternLM

Commit 5ab0dc8dc2 ("pp test"), parent 69ff9f2f5c
Changed file: new_test.py (99 lines changed)

--- a/new_test.py
+++ b/new_test.py
@@ -45,35 +45,46 @@ import torch.distributed as dist


 class MlpModel(nn.Module):

-    def __init__(self, start, end):
+    def __init__(self, start, end, type=None):
         super().__init__()
         self.part = [start, end]
         self.blocks = nn.ModuleList([nn.Linear(8, 8, bias=False) for lid in range(end - start)])
+        self.type = type
+        if gpc.is_first_rank(ParallelMode.PIPELINE):
+            print(f'{gpc.get_global_rank()}: self.part={self.part}', flush=True)

     def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
-        print(gpc.get_global_rank(), 'hidden_states:', hidden_states, flush=True)
-        if self.part[0] != 0:
+        # print(gpc.get_global_rank(), 'hidden_states:', hidden_states, flush=True)
+        if self.type != 'torch' and not gpc.is_first_rank(ParallelMode.PIPELINE):
             input_ids = hidden_states

-        print(f'pp stage: {gpc.get_local_rank(ParallelMode.PIPELINE)} MLP {self.part} fwd:', input_ids.shape, flush=True)
-        print(gpc.get_global_rank(), 'len_blocsk:', len(self.blocks), flush=True)
-        current_device = torch.cuda.current_device()
-        print(gpc.get_global_rank(), 'current_device:', current_device, flush=True)
-        input_ids = input_ids.to(current_device)
-        print(gpc.get_global_rank(), 'mlp_input_data:', input_ids, input_ids.shape, type(input_ids), flush=True)
-        x = self.blocks[0](input_ids) + self.blocks[1](input_ids)
-        print(gpc.get_global_rank(), 'mlp_output_data:', x, x.shape, flush=True)
-        return x
+        # print(f'pp stage: {gpc.get_local_rank(ParallelMode.PIPELINE)} MLP {self.part} fwd:', input_ids.shape, flush=True)
+        # print(gpc.get_global_rank(), 'len_blocsk:', len(self.blocks), flush=True)
+        # current_device = torch.cuda.current_device()
+        # print(gpc.get_global_rank(), 'current_device:', current_device, flush=True)
+        # input_ids = input_ids.to(current_device)
+        # print(gpc.get_global_rank(), 'mlp_input_data:', input_ids, input_ids.shape, type(input_ids), flush=True)
+        for i in range(self.part[1] - self.part[0]):
+            input_ids = self.blocks[i](input_ids)
+        return input_ids
+        # x = self.blocks[0](input_ids)
+        # x = self.blocks[0](x)
+        # print(gpc.get_global_rank(), 'mlp_output_data:', x, x.shape, flush=True)
+        # return x


 config = Config(
     dict(
+        HIDDEN_SIZE=8,
+        SEQ_LEN=8,
         gradient_handler=[dict(type="PipelineSharedModuleGradientHandler")],
-        HIDDEN_SIZE=4,
-        parallel=dict(zero1=1, pipeline=dict(size=8, interleaved_overlap=False), sequence_parallel=False, tensor=1),
+        parallel=dict(zero1=1, pipeline=dict(size=8, interleaved_overlap=True), sequence_parallel=False, tensor=1),
         model_type="INTERNLM",
         data=dict(seq_len=8, micro_num=16, micro_bsz=1, pack_sample_into_one=False, min_length=0, total_steps=9999),
         model=dict(
             dtype=torch.bfloat16,
+            num_chunks=2,
+            hidden_size=8,
+            use_flash_attn=True,
         ),
         resume_tb_folder="",
         tensorboard_folder="",
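A note on the numbers above: with num_layers=16, pipeline size 8, and num_chunks=2, an interleaved schedule gives every pipeline rank two non-contiguous chunks of one layer each, which is what the (start, end) pairs in MlpModel.part describe. The helper below is a minimal sketch of such a uniform partition; partition_layers is a hypothetical name, not InternLM's _build_generic_model_1d, whose exact splitting logic is not shown in this diff.

    # Hypothetical sketch of a uniform interleaved layer partition; assumes
    # num_layers divides evenly by pipeline_size * num_chunks.
    def partition_layers(num_layers: int, pipeline_size: int, num_chunks: int):
        parts_per_rank = [[] for _ in range(pipeline_size)]
        chunk_size = num_layers // (pipeline_size * num_chunks)
        for chunk in range(num_chunks):
            for rank in range(pipeline_size):
                start = (chunk * pipeline_size + rank) * chunk_size
                parts_per_rank[rank].append((start, start + chunk_size))
        return parts_per_rank

    print(partition_layers(16, 8, 2)[0])  # pipeline rank 0 owns [(0, 1), (8, 9)]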
@@ -220,11 +231,12 @@ def exam_pipeline_parallel(args):
     seed_all(1024)
     dtype = gpc.config.model["dtype"]

-    # torch_model = MlpModel().to(device)
     # pp_model = copy.deepcopy(torch_model).to(dtype)
-    pp_model = _build_generic_model_1d(num_layers=16, num_chunks=1)
+    pp_model = _build_generic_model_1d(num_layers=16, num_chunks=gpc.config.model.num_chunks)
     pp_model = pp_model.to(dtype)
-    print(pp_model, flush=True)
+    print(gpc.get_global_rank(), 'pp_model', pp_model)


     scheduler_hooks = [
         SchedulerMetricHook(
@@ -235,14 +247,26 @@ def exam_pipeline_parallel(args):
     micro_num = gpc.config.data.micro_num
     seq_len = gpc.config.data.seq_len
     gpc.config.NUM_MICRO_BATCHES = micro_num
-    scheduler = PipelineScheduler(
-        data_process_func=None,
+    communication_overlap = gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
+    print(f'communication_overlap={communication_overlap}')
+    scheduler = InterleavedPipelineScheduler(
         num_microbatches=micro_num,
+        num_chunks=gpc.config.model.num_chunks,
         dtype=gpc.config.model["dtype"],
-        tensor_shape=None,
+        tensor_shape=get_tensor_shape(),
         scatter_gather_tensors=False,
         scheduler_hooks=scheduler_hooks,
+        communication_overlap=communication_overlap,
     )
+    # scheduler = PipelineScheduler(
+    #     data_process_func=None,
+    #     num_microbatches=micro_num,
+    #     dtype=dtype,
+    #     tensor_shape=None,
+    #     scatter_gather_tensors=False,
+    #     scheduler_hooks=scheduler_hooks,
+    # )

     print(f"gpc.config.hybrid_zero_optimizer: {gpc.config.hybrid_zero_optimizer}", flush=True)
     # optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=pp_model)
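The hunk above swaps the non-interleaved PipelineScheduler for InterleavedPipelineScheduler, which additionally needs num_chunks and communication_overlap. Below is a sketch of how the two could be dispatched on the chunk count, reusing the names defined earlier in new_test.py; both constructor calls are copied from this diff, and only the if/else dispatch is an assumption.

    # Sketch: choose the scheduler from the configured number of model chunks.
    if gpc.config.model.num_chunks > 1:
        scheduler = InterleavedPipelineScheduler(
            num_microbatches=micro_num,
            num_chunks=gpc.config.model.num_chunks,
            dtype=gpc.config.model["dtype"],
            tensor_shape=get_tensor_shape(),
            scatter_gather_tensors=False,
            scheduler_hooks=scheduler_hooks,
            communication_overlap=communication_overlap,
        )
    else:
        scheduler = PipelineScheduler(
            data_process_func=None,
            num_microbatches=micro_num,
            dtype=gpc.config.model["dtype"],
            tensor_shape=None,
            scatter_gather_tensors=False,
            scheduler_hooks=scheduler_hooks,
        )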
@@ -256,7 +280,6 @@ def exam_pipeline_parallel(args):
     # eps=1e-8,
     # ))
     optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=pp_model)
-    criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=0)

     engine = Engine(
         model=pp_model,
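With FlashGPTLMLoss removed here, the torch reference path later in the diff constructs MyLoss() instead. MyLoss is defined elsewhere in new_test.py and is not visible in this diff; a plausible stand-in, assuming a simple MSE-style criterion, would be:

    import torch
    import torch.nn as nn

    # Hypothetical stand-in for MyLoss; the real definition lives elsewhere in
    # new_test.py and may differ from this assumed MSE criterion.
    class MyLoss(nn.Module):
        def forward(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
            return nn.functional.mse_loss(output.float(), target.float())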
@@ -277,10 +300,12 @@ def exam_pipeline_parallel(args):
     for _ in range(micro_num):
         x_list.append([i for i in range(seq_len)])
         y_list.append([i for i in range(seq_len)])
+    torch_xs = torch.tensor(x_list).to(device).to(torch.float32)
+    torch_ys = torch.tensor(y_list).to(device).to(torch.float32)
     xs = torch.tensor(x_list).to(device).to(dtype)
     yx = torch.tensor(y_list).to(device).to(dtype)
-    xs.requires_grad_()
-    yx.requires_grad_()
+    # xs.requires_grad_()
+    # yx.requires_grad_()
     print(xs.shape, yx.shape, flush=True)
     input_list = [{'input_ids': xs}, yx]
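The added torch_xs/torch_ys keep a float32 copy of the same data for the single-process reference model, while xs/yx feed the bfloat16 pipeline engine. A standalone sketch of that batch layout, assuming seq_len=8 and micro_num=16 as in the config above:

    import torch

    seq_len, micro_num = 8, 16
    x_list = [list(range(seq_len)) for _ in range(micro_num)]
    xs = torch.tensor(x_list, dtype=torch.bfloat16)       # pipeline engine input
    torch_xs = torch.tensor(x_list, dtype=torch.float32)  # torch reference input
    print(xs.shape)  # torch.Size([16, 8])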
@@ -293,15 +318,35 @@ def exam_pipeline_parallel(args):
     # output = torch_model(input)
     # print(output)
     print('local_rank:', gpc.get_local_rank(ParallelMode.PIPELINE), 'start schedule', flush=True)
-    _, _, loss = scheduler.forward_backward_step(engine, input_list, forward_only=False, return_loss=True, return_output_label=False)
+    output, label, loss = scheduler.forward_backward_step(engine, input_list, forward_only=False, return_loss=True, return_output_label=True)
     print('local_rank:', gpc.get_local_rank(ParallelMode.PIPELINE), 'end schedule', flush=True)

-    dist.barrier()
+    # dist.barrier()
     torch.cuda.synchronize()
     engine.step()
     torch.cuda.synchronize()
-    # torch_output = torch_model(input_ids=torch_input)
-    # torch_loss = criterion(torch_output, torch_label).unsqueeze(0)
+    if gpc.is_last_rank(ParallelMode.PIPELINE):
+        print('torch begin')
+        torch_model = MlpModel(0, 16, 'torch').to(device)
+        # torch_model = DDP(torch_model, static_graph=True)
+        print(gpc.get_global_rank(), 'torch_model', torch_model)
+        torch_optimizer = torch.optim.AdamW(
+            params=[{"params": torch_model.parameters(), "weight_decay": config.adam.weight_decay}],
+            lr=config.adam.lr,
+            betas=(config.adam.adam_beta1, config.adam.adam_beta2),
+            eps=config.adam.adam_eps,
+        )
+        torch_output = torch_model(input_ids=torch_xs)
+        criterion = MyLoss().to(torch.float32)
+        torch_loss = criterion(torch_output, torch_ys) / micro_num
+        torch_loss.backward()
+        torch_optimizer.step()
+        print(gpc.get_global_rank(), 'test_torch:', 'torch_output:', torch_output, 'torch_loss:', torch_loss)
+        print(gpc.get_global_rank(), 'test_pp:', 'output:', output, 'label:', label, 'loss:', loss)
+        loose_close(torch_output, output, dtype=dtype)
+        loose_close(torch_loss, loss[0], dtype=dtype)
+        print(gpc.get_global_rank(), 'assert_ok')

     # if rank == 0:
     #     print('loss:', loss)
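The final comparison relies on loose_close, whose implementation is not part of this diff. A hypothetical stand-in that captures the intent, with dtype-aware tolerances that are assumptions rather than the repository's actual values:

    import torch

    # Hypothetical loose_close: torch.allclose with tolerances widened for
    # low-precision dtypes; the real helper in the test utilities may differ.
    def loose_close(a, b, dtype=torch.float32):
        tolerances = {torch.bfloat16: (2e-2, 2e-2), torch.float16: (1e-3, 1e-3)}
        rtol, atol = tolerances.get(dtype, (1e-5, 1e-8))
        a = a if torch.is_tensor(a) else torch.tensor(a)
        b = b if torch.is_tensor(b) else torch.tensor(b)
        assert torch.allclose(a.detach().float(), b.detach().float(), rtol=rtol, atol=atol), (
            f"tensors differ beyond rtol={rtol}, atol={atol}"
        )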