mirror of https://github.com/InternLM/InternLM

pp test

parent fbcd509ff9, commit 69ff9f2f5c

new_test.py | 29
@@ -51,18 +51,23 @@ class MlpModel(nn.Module):
         self.blocks = nn.ModuleList([nn.Linear(8, 8, bias=False) for lid in range(end -start)])
 
     def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
+        print(gpc.get_global_rank(), 'hidden_states:', hidden_states, flush=True)
+        if self.part[0] != 0:
+            input_ids = hidden_states
+
         print(f'pp stage: {gpc.get_local_rank(ParallelMode.PIPELINE)} MLP {self.part} fwd:', input_ids.shape, flush=True)
         print(gpc.get_global_rank(), 'len_blocsk:', len(self.blocks), flush=True)
         current_device = torch.cuda.current_device()
         print(gpc.get_global_rank(), 'current_device:', current_device, flush=True)
         input_ids = input_ids.to(current_device)
         print(gpc.get_global_rank(), 'mlp_input_data:', input_ids, input_ids.shape, type(input_ids), flush=True)
         x = self.blocks[0](input_ids) + self.blocks[1](input_ids)
         print(gpc.get_global_rank(), 'mlp_output_data:', x, x.shape, flush=True)
         return x
 
 config = Config(
     dict(
+        gradient_handler=[dict(type="PipelineSharedModuleGradientHandler")],
         HIDDEN_SIZE=4,
         parallel=dict(zero1=1, pipeline=dict(size=8, interleaved_overlap=False), sequence_parallel=False, tensor=1),
         model_type="INTERNLM",
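As a reading aid, here is a standalone sketch of the stage-aware input handling that this hunk adds to MlpModel.forward. The ToyStageModel name, the reading of part as the (start, end) slice of layers owned by a stage, and the single-process smoke test are illustrative assumptions, not code from the commit.

import torch
import torch.nn as nn

class ToyStageModel(nn.Module):
    """Minimal stand-in for one pipeline stage of the MLP test model."""

    def __init__(self, part):
        super().__init__()
        self.part = part  # (start, end): which layers this stage owns
        self.blocks = nn.ModuleList(
            [nn.Linear(8, 8, bias=False) for _ in range(part[1] - part[0])]
        )

    def forward(self, hidden_states=None, input_ids=None):
        # The first stage consumes the real input; every later stage receives the
        # previous stage's activations as hidden_states and uses those instead.
        if self.part[0] != 0:
            input_ids = hidden_states
        return self.blocks[0](input_ids) + self.blocks[1](input_ids)

# Single-process smoke test: stage 0 takes input_ids, stage 1 takes hidden_states.
stage0, stage1 = ToyStageModel((0, 2)), ToyStageModel((2, 4))
out = stage1(hidden_states=stage0(input_ids=torch.randn(4, 8)))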
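The parallel block in the test config, restated with field-by-field comments. The comments are my reading of the usual InternLM meanings of these fields, not text from the commit.

parallel = dict(
    zero1=1,                        # ZeRO-1 group size; 1 means optimizer states are not sharded across data-parallel ranks
    pipeline=dict(
        size=8,                     # eight pipeline stages, one per rank in this test
        interleaved_overlap=False,  # no communication/compute overlap for the interleaved scheduler
    ),
    sequence_parallel=False,        # sequence parallelism disabled
    tensor=1,                       # tensor-parallel size 1, i.e. disabled
)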
@@ -243,21 +248,23 @@ def exam_pipeline_parallel(args):
     # optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=pp_model)
     # criterion = FlashGPTLMLoss(parallel_output=False, label_smoothing=0)
 
-    from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer
-    optimizer = BaseOptimizer(torch.optim.AdamW(
-        params=[{"params": pp_model.parameters()}],
-        lr=1e-4,
-        betas=(0.9, 0.95),
-        eps=1e-8,
-    ))
+    # from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer
+    # optimizer = BaseOptimizer(torch.optim.AdamW(
+    #     params=[{"params": pp_model.parameters()}],
+    #     lr=1e-4,
+    #     betas=(0.9, 0.95),
+    #     eps=1e-8,
+    # ))
+    optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=pp_model)
+    criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=0)
 
     engine = Engine(
         model=pp_model,
         optimizer=optimizer,
-        lr_scheduler=None,
-        beta2_scheduler=None,
+        lr_scheduler=lr_scheduler,
+        beta2_scheduler=beta2_scheduler,
         criterion=MyLoss().to(dtype),
-        gradient_handlers= [dict(type="PipelineSharedModuleGradientHandler")],
+        gradient_handlers= [PipelineSharedModuleGradientHandler(model=pp_model, optimizer=optimizer)],
         clip_grad_norm=gpc.config.hybrid_zero_optimizer.get("clip_grad_norm", 0.0),
     )
 
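The optimizer block that this hunk comments out corresponds to plain PyTorch roughly as below; the toy pp_model stand-in is an assumption (in the test it is the pipeline-partitioned model), and InternLM's BaseOptimizer wrapper is dropped here. The commit replaces the hand-built optimizer with initialize_optimizer(model=pp_model), which also returns the beta2_scheduler and lr_scheduler that the Engine call now receives instead of None.

import torch
import torch.nn as nn

pp_model = nn.Linear(8, 8, bias=False)  # stand-in for the pipeline-partitioned test model

# Same hyperparameters as the removed block, minus the BaseOptimizer wrapper.
optimizer = torch.optim.AdamW(
    params=[{"params": pp_model.parameters()}],
    lr=1e-4,
    betas=(0.9, 0.95),
    eps=1e-8,
)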
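The other change in this hunk passes a constructed PipelineSharedModuleGradientHandler(model=pp_model, optimizer=optimizer) instead of a dict(type=...) config entry. Handlers of this kind typically keep the gradients of modules shared between pipeline stages (for example tied embeddings) in sync after backward(); the toy class below is my own simplification of that idea, not InternLM's implementation.

import torch.distributed as dist

class ToySharedModuleGradientHandler:
    """All-reduce gradients of parameters replicated on several pipeline ranks."""

    def __init__(self, shared_params, group=None):
        self.shared_params = list(shared_params)  # e.g. tied embedding weights
        self.group = group                        # process group spanning the ranks that share them

    def handle_gradient(self):
        # Run once per step, after backward() and before optimizer.step().
        for p in self.shared_params:
            if p.grad is not None:
                dist.all_reduce(p.grad, group=self.group)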