# mirror of https://github.com/InternLM/InternLM
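# Shared test helpers for InternLM pipeline-parallel schedulers: a small MLP
# stand-in model, an MSE loss, and utilities to build the engine/scheduler,
# set up the distributed environment, seed RNGs, and compare tensors loosely.
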
import random

import numpy as np
import torch
from torch import nn
from torch.testing import assert_close

import internlm
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.engine import Engine
from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler
from internlm.core.scheduler import (
    InterleavedPipelineScheduler,
    NonPipelineScheduler,
    PipelineScheduler,
    SchedulerMetricHook,
)
from internlm.solver.pipeline_utils import partition_uniform
from internlm.train import initialize_optimizer


class MlpModel(nn.Module):
    """
    Custom MLP model: a stack of 8x8 linear blocks standing in for the layers
    [start, end) assigned to this pipeline stage.
    """

    def __init__(self, start, end, model_type=None, embedding=False):
        super().__init__()
        self.part = [start, end]
        self.blocks = nn.ModuleList([nn.Linear(8, 8, bias=False) for lid in range(end - start)])
        self.model_type = model_type
        self.embedding = embedding

    def forward(
        self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None
    ):  # pylint: disable=W0613
        # Non-first pipeline stages receive the previous stage's output as hidden_states.
        if self.model_type != "torch" and self.part[0] != 0:
            input_ids = hidden_states

        # Simulate Embedding.
        if self.embedding:
            if len(input_ids.shape) == 2:
                input_ids = input_ids.view(-1, 8)
            elif len(input_ids.shape) == 3:
                input_ids = input_ids.view(input_ids.shape[0], -1, 8)

        for i in range(self.part[1] - self.part[0]):
            input_ids = self.blocks[i](input_ids)

        return input_ids


class MyLoss(nn.Module):
    """
    Custom loss
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits, labels):
        loss = torch.nn.MSELoss(reduction="sum")
        return loss(logits, labels)


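# Build the pipeline-parallel model and select a scheduler based on the
# pipeline world size and num_chunks (non-interleaved vs. interleaved), then
# wrap everything in an Engine. When micro_num is smaller than the pipeline
# world size, the interleaved scheduler is expected to assert; that case is
# caught and signalled by returning (None, None).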
def init_model_and_optim(
    num_layers, num_chunks, dtype, micro_num, interleaved_overlap, tensor_shape, init_optim=True, embedding=False
):
    # pp model
    pp_model = _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, embedding=embedding)
    pp_model = pp_model.to(dtype)

    # pp scheduler
    scheduler_hooks = [
        SchedulerMetricHook(skip=True),
    ]

    if gpc.get_world_size(ParallelMode.PIPELINE) > 1:
        if num_chunks == 1:
            # noninterleaved pp
            scheduler = PipelineScheduler(
                data_process_func=None,
                num_microbatches=micro_num,
                dtype=dtype,
                tensor_shape=tensor_shape,
                scatter_gather_tensors=False,
                scheduler_hooks=scheduler_hooks,
            )
        else:
            # interleaved pp
            if micro_num < gpc.get_world_size(ParallelMode.PIPELINE):
                try:
                    scheduler = InterleavedPipelineScheduler(
                        num_microbatches=micro_num,
                        num_chunks=gpc.config.model.num_chunks,
                        dtype=dtype,
                        tensor_shape=tensor_shape,
                        scatter_gather_tensors=False,
                        scheduler_hooks=scheduler_hooks,
                        communication_overlap=interleaved_overlap,
                    )
                except AssertionError as e:
                    print(f"AssertionError: {e}", flush=True)
                    return None, None
                else:
                    raise RuntimeError(
                        "Error: AssertionError should occur when micro_num < pipeline parallel world size"
                    )
            else:
                scheduler = InterleavedPipelineScheduler(
                    num_microbatches=micro_num,
                    num_chunks=gpc.config.model.num_chunks,
                    dtype=dtype,
                    tensor_shape=tensor_shape,
                    scatter_gather_tensors=False,
                    scheduler_hooks=scheduler_hooks,
                    communication_overlap=interleaved_overlap,
                )
    else:
        scheduler = NonPipelineScheduler(
            data_process_func=None,
            gradient_accumulation_size=gpc.config.data.gradient_accumulation,
            scheduler_hooks=scheduler_hooks,
        )

    # pp optimizer and engine
    if init_optim:
        optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=pp_model)
    else:
        optimizer, beta2_scheduler, lr_scheduler = None, None, None

    engine = Engine(
        model=pp_model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        beta2_scheduler=beta2_scheduler,
        criterion=MyLoss().to(dtype),
        gradient_handlers=[PipelineSharedModuleGradientHandler(model=pp_model, optimizer=optimizer)],
        clip_grad_norm=0.0,
    )

    return engine, scheduler


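# Set up a single-node torch.distributed environment for this rank and launch
# InternLM with the torch launcher.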
def build_environment(rank, world_size, config):
    import os

    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "33333"
    torch.cuda.empty_cache()
    # launcher="torch"
    internlm.launch_from_torch(config=config, seed=1024)


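# Assert that two values are close under dtype-dependent tolerances
# (float32 or bfloat16 are expected here).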
def loose_close(a, b, dtype: torch.dtype = torch.float32):
    if dtype is torch.float32:
        rtol = 1.3e-6
        atol = 1e-5
    elif dtype is torch.bfloat16:
        rtol = 2e-2
        atol = 2e-2

    if isinstance(a, torch.Tensor):
        a = a.detach().to(dtype)
        b = b.detach().to(dtype)

    assert_close(a, b, rtol=rtol, atol=atol)


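# Seed Python, NumPy, and PyTorch RNGs; optionally force deterministic cuDNN
# at the cost of speed.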
def seed_all(seed, cuda_deterministic=False):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    if cuda_deterministic:  # slower, more reproducible
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True


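# Uniformly partition num_layers across pipeline ranks and instantiate the
# local MlpModel chunk(s). Returns a single module, or an nn.ModuleList when
# this rank holds multiple chunks (interleaved pipeline).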
def _build_generic_model_1d(num_layers, num_chunks, embedding=False):
    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

    all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
    parts = all_parts[pipeline_rank]
    if gpc.is_rank_for_log():
        print(f"The layer sharding is {all_parts}.", flush=True)

    models = []
    for start, end in parts:
        models.append(MlpModel(start, end, embedding=embedding).cuda())
    torch.distributed.barrier()
    if len(models) == 1:
        model = models[0]
    else:
        model = nn.ModuleList(models)

    return model