mirror of https://github.com/InternLM/InternLM
feat: add pp test
parent 3034e73c42
commit 033e646191
@@ -47,7 +47,19 @@ class MlpModel(nn.Module):
         self.linear1 = nn.Linear(4, 8)
         self.linear2 = nn.Linear(8, 8)
         self.linear3 = nn.Linear(8, 8)
-        self.linear4 = nn.Linear(8, 4)
+        self.linear4 = nn.Linear(8, 8)
+        self.linear5 = nn.Linear(8, 8)
+        self.linear6 = nn.Linear(8, 8)
+        self.linear7 = nn.Linear(8, 8)
+        self.linear8 = nn.Linear(8, 8)
+        self.linear9 = nn.Linear(8, 8)
+        self.linear10 = nn.Linear(8, 8)
+        self.linear11 = nn.Linear(8, 8)
+        self.linear12 = nn.Linear(8, 8)
+        self.linear13 = nn.Linear(8, 8)
+        self.linear14 = nn.Linear(8, 8)
+        self.linear15 = nn.Linear(8, 8)
+        self.linear16 = nn.Linear(8, 4)

     def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
         print('MLP:', input_ids, input_ids.dtype, flush=True)
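Why the model grows to 16 layers: with the pipeline size raised to 8 in the config below, 16 uniform linear layers split evenly into two per stage. A minimal sketch of that even partitioning, assuming a hypothetical partition_uniform helper (not InternLM's actual partition code):

def partition_uniform(num_layers, num_stages):
    # Split num_layers into num_stages contiguous [start, end) ranges.
    assert num_layers % num_stages == 0, "uneven splits need extra handling"
    per_stage = num_layers // num_stages
    return [(i * per_stage, (i + 1) * per_stage) for i in range(num_stages)]

# 16 linear layers over an 8-stage pipeline -> 2 layers per stage:
print(partition_uniform(16, 8))  # [(0, 2), (2, 4), ..., (14, 16)]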
@@ -55,13 +67,26 @@ class MlpModel(nn.Module):
         input_ids = self.linear2(input_ids)
         input_ids = self.linear3(input_ids)
         input_ids = self.linear4(input_ids)
+        input_ids = self.linear5(input_ids)
+        input_ids = self.linear6(input_ids)
+        input_ids = self.linear7(input_ids)
+        input_ids = self.linear8(input_ids)
+        input_ids = self.linear9(input_ids)
+        input_ids = self.linear10(input_ids)
+        input_ids = self.linear11(input_ids)
+        input_ids = self.linear12(input_ids)
+        input_ids = self.linear13(input_ids)
+        input_ids = self.linear14(input_ids)
+        input_ids = self.linear15(input_ids)
+        input_ids = self.linear16(input_ids)
         return input_ids

 config = Config(
     dict(
-        parallel=dict(zero1=1, pipeline=dict(size=2, interleaved_overlap=False), sequence_parallel=False, tensor=1),
+        HIDDEN_SIZE=4,
+        parallel=dict(zero1=1, pipeline=dict(size=8, interleaved_overlap=False), sequence_parallel=False, tensor=1),
         model_type="INTERNLM",
-        data=dict(seq_len=2048, micro_num=1, micro_bsz=1, pack_sample_into_one=False, min_length=0, total_steps=9999),
+        data=dict(seq_len=4, micro_num=4, micro_bsz=1, pack_sample_into_one=False, min_length=0, total_steps=9999),
         model=dict(
             dtype=torch.bfloat16,
         ),
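A hedged reading of the new data settings, following the usual micro-batching convention rather than anything confirmed by InternLM docs: micro_num micro-batches of micro_bsz samples flow through the pipeline per optimizer step.

# Assumed batch accounting for the config above:
micro_num, micro_bsz, seq_len = 4, 1, 4
samples_per_step = micro_num * micro_bsz      # 4 samples per optimizer step
tokens_per_step = samples_per_step * seq_len  # 16 tokens per step
print(samples_per_step, tokens_per_step)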
@@ -117,7 +142,7 @@ def build_environment(rank, world_size):
     os.environ["LOCAL_RANK"] = str(rank)
     os.environ["WORLD_SIZE"] = str(world_size)
     os.environ["MASTER_ADDR"] = "127.0.0.1"
-    os.environ["MASTER_PORT"] = "12345"
+    os.environ["MASTER_PORT"] = "44444"
     torch.cuda.empty_cache()
     # launcher="torch"
     internlm.launch_from_torch(config=config, seed=1024)
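The hard-coded MASTER_PORT invites collisions when two test runs share a host. A sketch of one common workaround, letting the OS pick a free port (hypothetical usage, not part of this commit):

import socket

def find_free_port():
    # Bind to port 0 so the OS assigns an unused TCP port, then report it.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]

# Must be called once in the parent process and passed to every rank,
# since all workers have to agree on the same MASTER_PORT.
print(find_free_port())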
@@ -155,17 +180,28 @@ def exam_pipeline_parallel(args):


 def exam_pipeline_parallel(args):
+    import os
     rank, world_size = args
-    dtype = torch.bfloat16
+    dtype = torch.float32

     build_environment(rank, world_size)
+    local_rank = int(os.environ["LOCAL_RANK"])
+    print('rank_com:', rank, local_rank)
+    device = torch.device(f"cuda:{local_rank}")
+    # print('device_id:', device)
+    # torch.cuda.set_device(device)
     seed_all(1024)

-    torch_model = MlpModel().cuda()
+    torch_model = MlpModel().to(device)
     pp_model = copy.deepcopy(torch_model).to(dtype)


     tensor_shape = get_tensor_shape()
+    tensor_shape = (
+        4,
+        4,
+    )
+    # print('tensor_shape:', tensor_shape)

     scatter_gather = gpc.is_initialized(ParallelMode.TENSOR)

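On the tensor_shape override: pinning the shape of the activation exchanged between stages lets each rank pre-allocate its point-to-point receive buffer instead of negotiating shapes at runtime. A hedged illustration of the idea using plain torch.distributed primitives, not InternLM's p2p helpers:

import torch

tensor_shape = (4, 4)  # must match what the previous stage actually sends
device = "cuda" if torch.cuda.is_available() else "cpu"
recv_buffer = torch.empty(tensor_shape, dtype=torch.float32, device=device)
# dist.recv(recv_buffer, src=prev_stage_rank)  # a shape/dtype mismatch would corrupt data or hang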
@@ -179,7 +215,8 @@ def exam_pipeline_parallel(args):
             skip=False
         ),
     ]

+    gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num
     scheduler = PipelineScheduler(
         data_process_func=None,
         num_microbatches=gpc.config.data.micro_num,
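For context on num_microbatches: a non-interleaved pipeline scheduler typically runs a 1F1B-style schedule (an assumption about PipelineScheduler based on convention, not read from InternLM source), where each stage does some warmup forwards before alternating forward and backward. The standard warmup arithmetic:

# 1F1B warmup count per stage for this test's sizes:
num_microbatches, num_stages = 4, 8
for stage_id in range(num_stages):
    warmup = min(num_stages - stage_id - 1, num_microbatches)
    steady = num_microbatches - warmup
    print(f"stage {stage_id}: {warmup} warmup forwards, {steady} 1F1B pairs")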
@@ -198,7 +235,7 @@ def exam_pipeline_parallel(args):
         lr_scheduler=lr_scheduler,
         beta2_scheduler=beta2_scheduler,
         criterion=criterion,
-        gradient_handlers=[],
+        gradient_handlers=[dict(type="PipelineSharedModuleGradientHandler")],
         clip_grad_norm=gpc.config.hybrid_zero_optimizer.get("clip_grad_norm", 0.0),
     )

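PipelineSharedModuleGradientHandler is InternLM's handler for gradients of modules shared across pipeline stages (tied embeddings being the typical case). Roughly what such a handler does, as a hedged sketch rather than the library's implementation:

import torch.distributed as dist

def sync_shared_grads(shared_params, group=None):
    # All-reduce gradients of parameters replicated on more than one stage,
    # so every replica takes its optimizer step with the same summed gradient.
    for p in shared_params:
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=group)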
@@ -206,10 +243,10 @@ def exam_pipeline_parallel(args):
     engine.train()
     engine.zero_grad()

-    input_list = [{'input_ids':torch.tensor([[0,1,2,3]]).cuda().to(dtype)},
-                  torch.tensor([[1]]).cuda().to(torch.int64)]
-    torch_input = torch.tensor([[0,1,2,3]]).cuda().to(torch.float32)
-    torch_label = torch.tensor([[1]]).cuda().to(torch.int64)
+    input_list = [{'input_ids':torch.tensor([[0,1,2,3],[0,1,2,3],[0,1,2,3],[0,1,2,3]]).to(device).to(dtype)},
+                  torch.tensor([[1],[1],[1],[1]]).to(device).to(torch.int64)]
+    torch_input = torch.tensor([[0,1,2,3]]).to(device).to(torch.float32)
+    torch_label = torch.tensor([[1]]).to(device).to(torch.int64)
     # print('label_shape:', input_list[1].shape)
     # input_list = [{'input_ids':torch.rand(1, 4).cuda()}, torch.rand(1, 4).cuda()]
     # input = input_list[0]
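The pipeline input is now the single reference sample replicated four times, one copy per micro-batch, so the pipeline's per-micro-batch results should match a plain forward of the unsharded torch_model. A hedged sketch of the comparison this test presumably builds toward (pp_output and torch_output are illustrative names, not from the diff):

import torch

def outputs_match(pp_output, torch_output, rtol=1e-3, atol=1e-3):
    # Compare in float32 so lower-precision runs can still pass with loose tolerances.
    return torch.allclose(pp_output.float(), torch_output.float(), rtol=rtol, atol=atol)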