mirror of https://github.com/hpcaitech/ColossalAI
[checkpoint] use gather_tensor in checkpoint and update its unit test (#1339)
parent
f3ce7b8336
commit
f92c100ddd
|
@ -262,7 +262,7 @@ class ColoTensor(torch.Tensor):
|
||||||
replicated_t = self.redistribute(dist_spec=ReplicaSpec())
|
replicated_t = self.redistribute(dist_spec=ReplicaSpec())
|
||||||
return replicated_t.view(*args)
|
return replicated_t.view(*args)
|
||||||
|
|
||||||
def size_global(self, args: Optional[int] = None):
|
def size_global(self, args: Optional[int] = None) -> torch.Size:
|
||||||
"""override the torch buildin size()
|
"""override the torch buildin size()
|
||||||
the shape passed in must be in a replicate placement.
|
the shape passed in must be in a replicate placement.
|
||||||
Returns:
|
Returns:
|
||||||
|
|
|
@ -141,9 +141,18 @@ class ProcessGroup:
|
||||||
def rank(self):
|
def rank(self):
|
||||||
return self._rank
|
return self._rank
|
||||||
|
|
||||||
|
def ranks_in_group(self):
|
||||||
|
return self._rank_list
|
||||||
|
|
||||||
def world_size(self):
|
def world_size(self):
|
||||||
return self._world_size
|
return self._world_size
|
||||||
|
|
||||||
|
def tp_rank_list(self):
|
||||||
|
return self._tp_rank_list
|
||||||
|
|
||||||
|
def dp_rank_list(self):
|
||||||
|
return self._dp_rank_list
|
||||||
|
|
||||||
def tp_local_rank(self):
|
def tp_local_rank(self):
|
||||||
return self._rank % self._tp_degree
|
return self._rank % self._tp_degree
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from colossalai.tensor import ColoTensor, DistSpecManager
|
from colossalai.tensor import ColoTensor
|
||||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||||
from copy import copy
|
from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
@ -22,37 +22,52 @@ def save_checkpoint(dire: str,
|
||||||
optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None.
|
optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None.
|
||||||
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
|
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
rank = dist.get_rank()
|
||||||
|
model_state = model.state_dict()
|
||||||
|
# save the dist context about the tensors in a new dict, while still maintain the original dict.
|
||||||
|
for k, v in model_state.items():
|
||||||
|
if isinstance(v, ColoTensor):
|
||||||
|
gather_tensor(v) # gather shared tensors to rank0
|
||||||
|
# don't recover tensors in rank0, since the dict is only a copy of model
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
# sanity check
|
||||||
|
for k, v in model_state.items():
|
||||||
|
if isinstance(v, ColoTensor):
|
||||||
|
assert v.save_ready
|
||||||
|
assert v.is_replicate()
|
||||||
|
delattr(v, 'save_ready')
|
||||||
|
# model saving
|
||||||
|
save_state = {'epoch': epoch, 'model': model_state}
|
||||||
|
torch.save(save_state, dire + '/epoch_{}_model.pth'.format(epoch))
|
||||||
|
|
||||||
|
# delete old dicts
|
||||||
|
del model_state
|
||||||
|
# synchronize all the processes
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
mapping = dict()
|
mapping = dict()
|
||||||
new_dict = dict()
|
optim_state = optimizer.state_dict()
|
||||||
|
for k, v in optim_state['state'].items():
|
||||||
# save the dist context about the tensors in a new dict, while still maintain the original dict.
|
|
||||||
for k, v in model.state_dict().items():
|
|
||||||
if isinstance(v, ColoTensor):
|
|
||||||
mapping[k] = (v.dist_spec, v.compute_spec)
|
|
||||||
new_dict[k] = v.to_replicate().detach()
|
|
||||||
else:
|
|
||||||
new_dict[k] = v
|
|
||||||
if dist.get_rank() == 0:
|
|
||||||
for k, v in new_dict.items():
|
|
||||||
if isinstance(v, ColoTensor):
|
|
||||||
assert v.is_replicate()
|
|
||||||
|
|
||||||
model_state = {'epoch': epoch, 'model': new_dict}
|
|
||||||
torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))
|
|
||||||
|
|
||||||
# delete the new dict
|
|
||||||
del new_dict
|
|
||||||
|
|
||||||
optim_state_copy = copy(optimizer.state_dict())
|
|
||||||
for k, v in optim_state_copy['state'].items():
|
|
||||||
for n, t in v.items():
|
for n, t in v.items():
|
||||||
if isinstance(t, ColoTensor):
|
if isinstance(t, ColoTensor):
|
||||||
t.to_replicate_()
|
mapping[(k, n)] = t.dist_spec
|
||||||
if dist.get_rank() == 0:
|
gather_tensor(t)
|
||||||
model_state = {'epoch': epoch, 'optim': optim_state_copy}
|
|
||||||
torch.save(model_state, dire + '/epoch_{}_optim.pth'.format(epoch))
|
if rank == 0:
|
||||||
del optim_state_copy
|
save_state = {'epoch': epoch, 'optim': optim_state}
|
||||||
|
torch.save(save_state, dire + '/epoch_{}_optim.pth'.format(epoch))
|
||||||
|
# recover colo tensors in rank0
|
||||||
|
for k, v in optimizer.state_dict()['state'].items():
|
||||||
|
for n, t in v.items():
|
||||||
|
if isinstance(t, ColoTensor):
|
||||||
|
assert hasattr(t, 'save_ready')
|
||||||
|
t.set_dist_spec(mapping[(k, n)])
|
||||||
|
delattr(t, 'save_ready')
|
||||||
|
|
||||||
|
del optim_state
|
||||||
|
del mapping
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint(dire,
|
def load_checkpoint(dire,
|
||||||
|
@ -72,39 +87,42 @@ def load_checkpoint(dire,
|
||||||
optimizer (ColossalaiOptimizer, optional): _description_. Defaults to None.
|
optimizer (ColossalaiOptimizer, optional): _description_. Defaults to None.
|
||||||
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
|
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
rank = dist.get_rank()
|
||||||
mapping = dict()
|
mapping = dict()
|
||||||
for k, v in model.state_dict().items():
|
for n, p in model.named_parameters():
|
||||||
if isinstance(v, ColoTensor):
|
if isinstance(p, ColoTensor):
|
||||||
mapping[k] = (v.dist_spec, v.compute_spec)
|
mapping[n] = p.dist_spec
|
||||||
v.to_replicate_()
|
gather_tensor(p)
|
||||||
|
|
||||||
model_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch))
|
if rank == 0:
|
||||||
model.load_state_dict(model_state['model'])
|
load_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch))
|
||||||
|
model.load_state_dict(load_state['model'])
|
||||||
# reset tensors to original dist spec.
|
dist.barrier()
|
||||||
with DistSpecManager.no_grad():
|
|
||||||
for k, v in model.state_dict().items():
|
|
||||||
if isinstance(v, ColoTensor):
|
|
||||||
v.set_tensor_spec(*mapping[k])
|
|
||||||
|
|
||||||
|
# scatter loaded parameters
|
||||||
|
for n, p in model.named_parameters():
|
||||||
|
if isinstance(p, ColoTensor):
|
||||||
|
scatter_tensor(p, mapping[n])
|
||||||
|
if rank == 0:
|
||||||
|
assert hasattr(p, 'save_ready')
|
||||||
|
delattr(p, 'save_ready')
|
||||||
del mapping
|
del mapping
|
||||||
mapping = dict()
|
|
||||||
|
|
||||||
|
mapping = dict()
|
||||||
for k, v in optimizer.state_dict()['state'].items():
|
for k, v in optimizer.state_dict()['state'].items():
|
||||||
for n, t in v.items():
|
for n, t in v.items():
|
||||||
if isinstance(t, ColoTensor):
|
if isinstance(t, ColoTensor):
|
||||||
mapping[(k, n)] = (t.dist_spec, t.compute_spec)
|
mapping[(k, n)] = t.dist_spec
|
||||||
t.to_replicate_()
|
gather_tensor(t)
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
|
colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch))
|
||||||
optimizer.load_state_dict(colo_checkpoint['optim'])
|
optimizer.load_state_dict(colo_checkpoint['optim'])
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
for k, v in optimizer.state_dict()['state'].items():
|
for k, v in optimizer.state_dict()['state'].items():
|
||||||
for n, t in v.items():
|
for n, t in v.items():
|
||||||
if isinstance(t, ColoTensor):
|
if isinstance(t, ColoTensor):
|
||||||
# skip key not in mapping.
|
scatter_tensor(t, mapping[(k, n)])
|
||||||
# For Adam, if it dose not execute step() once, there will be not exp_avg and exp_avg_sq in optimizer
|
|
||||||
if (k, n) not in mapping:
|
del mapping
|
||||||
continue
|
|
||||||
t.set_tensor_spec(*mapping[(k, n)])
|
|
||||||
|
|
|
@ -0,0 +1,50 @@
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from colossalai.tensor import ColoTensor, ColoTensorSpec
|
||||||
|
from colossalai.tensor.distspec import _DistSpec
|
||||||
|
|
||||||
|
|
||||||
|
def gather_tensor(colo_tensor: ColoTensor) -> None:
|
||||||
|
"""Make colo_tensor replicated when the rank is 0
|
||||||
|
"""
|
||||||
|
if not colo_tensor.is_replicate():
|
||||||
|
pg = colo_tensor.get_process_group()
|
||||||
|
# for the group which contains rank 0
|
||||||
|
if pg.tp_rank_list()[0] == 0:
|
||||||
|
old_dist_spec = colo_tensor.dist_spec
|
||||||
|
colo_tensor.to_replicate_()
|
||||||
|
if dist.get_rank() != 0:
|
||||||
|
colo_tensor.set_dist_spec(old_dist_spec)
|
||||||
|
|
||||||
|
# synchronize all processes for unexpected problems
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
setattr(colo_tensor, 'save_ready', True) # set saving signitrue
|
||||||
|
|
||||||
|
|
||||||
|
def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
|
||||||
|
"""Reversal operation of `gather_tensor`.
|
||||||
|
"""
|
||||||
|
if dist_spec.placement == 'r':
|
||||||
|
dist.broadcast(colo_tensor.data, 0)
|
||||||
|
else:
|
||||||
|
global_size = colo_tensor.size_global()
|
||||||
|
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
entire_data = colo_tensor.data
|
||||||
|
else:
|
||||||
|
entire_data = torch.empty(global_size, device=colo_tensor.device)
|
||||||
|
dist.broadcast(entire_data, 0)
|
||||||
|
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
colo_tensor.set_dist_spec(dist_spec)
|
||||||
|
else:
|
||||||
|
rep_tensor = ColoTensor(entire_data, ColoTensorSpec(
|
||||||
|
pg=colo_tensor.get_process_group(),
|
||||||
|
compute_attr=colo_tensor.compute_spec))
|
||||||
|
rep_tensor.set_dist_spec(dist_spec)
|
||||||
|
with torch.no_grad():
|
||||||
|
colo_tensor.data.copy_(rep_tensor.data)
|
||||||
|
# synchronize all processes for unexpected problems
|
||||||
|
dist.barrier()
|
|
@ -1,6 +1,7 @@
|
||||||
import os, shutil
|
import os, shutil
|
||||||
import torch
|
import torch
|
||||||
import pytest
|
import pytest
|
||||||
|
from copy import deepcopy
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
|
@ -15,8 +16,7 @@ from colossalai.testing import rerun_if_address_is_in_use
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.cuda import get_current_device
|
||||||
from colossalai.utils import free_port
|
from colossalai.utils import free_port
|
||||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||||
from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, DistSpecManager, ReplicaSpec
|
from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup
|
||||||
from colossalai.nn.parallel.data_parallel import ColoDDP
|
|
||||||
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
|
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
|
||||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||||
|
|
||||||
|
@ -63,8 +63,8 @@ def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
|
||||||
|
|
||||||
|
|
||||||
def check_param_equal(model, torch_model):
|
def check_param_equal(model, torch_model):
|
||||||
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
|
||||||
assert torch.allclose(torch_p, p, rtol=1e-3, atol=1e-1)
|
assert torch.all(p.data == tp.data), "{} went wrong.\n {} vs {}\n{}".format(n, p, tp, p.shape)
|
||||||
|
|
||||||
|
|
||||||
def remove(path):
|
def remove(path):
|
||||||
|
@ -84,9 +84,13 @@ def compare_optims(optim1, optim2):
|
||||||
if k not in state2:
|
if k not in state2:
|
||||||
continue
|
continue
|
||||||
p2 = state2[k]
|
p2 = state2[k]
|
||||||
if isinstance(p1, ColoTensor):
|
for n, t1 in p1.items():
|
||||||
assert isinstance(p2, ColoTensor)
|
if n not in p2:
|
||||||
assert torch.allclose(p1.to_replicate_(), p2.to_replicate_(), rtol=1e-3, atol=1e-1)
|
continue
|
||||||
|
t2 = p2[n]
|
||||||
|
if isinstance(t1, ColoTensor):
|
||||||
|
assert isinstance(t2, ColoTensor)
|
||||||
|
assert torch.allclose(t1, t2, rtol=0, atol=0)
|
||||||
|
|
||||||
|
|
||||||
def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
|
def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
|
||||||
|
@ -99,7 +103,6 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
|
||||||
# set_seed(1)
|
# set_seed(1)
|
||||||
with ColoInitContext(device=get_current_device()):
|
with ColoInitContext(device=get_current_device()):
|
||||||
model = model_builder(checkpoint=True)
|
model = model_builder(checkpoint=True)
|
||||||
model_reload = model_builder(checkpoint=True)
|
|
||||||
|
|
||||||
if use_mp_reload:
|
if use_mp_reload:
|
||||||
if 'bert' == model_name:
|
if 'bert' == model_name:
|
||||||
|
@ -119,25 +122,26 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
|
||||||
elif 'token_type_embeddings' in name and 'weight' in name:
|
elif 'token_type_embeddings' in name and 'weight' in name:
|
||||||
init_1d_col_embedding(p, pg)
|
init_1d_col_embedding(p, pg)
|
||||||
elif p.process_group.tp_world_size() == 1:
|
elif p.process_group.tp_world_size() == 1:
|
||||||
p.redistribute(ReplicaSpec(), pg)
|
p.set_process_group(pg)
|
||||||
elif "simple_net" == model_name:
|
elif "simple_net" == model_name:
|
||||||
init_spec_func(model, pg)
|
init_spec_func(model, pg)
|
||||||
|
|
||||||
|
model_reload = deepcopy(model)
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
model.train()
|
model.eval()
|
||||||
|
|
||||||
model_reload = model_reload.cuda()
|
model_reload = model_reload.cuda()
|
||||||
model_reload.train()
|
model_reload.eval()
|
||||||
|
|
||||||
opt_class = torch.optim.Adam
|
opt_class = torch.optim.Adam
|
||||||
colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1))
|
colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1))
|
||||||
colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1))
|
colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1))
|
||||||
run_reload = False
|
|
||||||
|
|
||||||
for i, (data, label) in enumerate(train_dataloader):
|
for i, (data, label) in enumerate(train_dataloader):
|
||||||
|
|
||||||
# Zero grad
|
# Zero grad
|
||||||
colo_optimizer.zero_grad()
|
colo_optimizer.zero_grad()
|
||||||
|
colo_optimizer_reload.zero_grad()
|
||||||
|
|
||||||
data = data.to(get_current_device())
|
data = data.to(get_current_device())
|
||||||
label = label.to(get_current_device())
|
label = label.to(get_current_device())
|
||||||
|
@ -155,14 +159,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
|
||||||
loss.backward()
|
loss.backward()
|
||||||
loss_reload.backward()
|
loss_reload.backward()
|
||||||
|
|
||||||
if run_reload:
|
colo_optimizer.step()
|
||||||
colo_optimizer_reload.zero_grad()
|
|
||||||
if criterion:
|
|
||||||
output_reload = model_reload(data)
|
|
||||||
loss_reload = criterion(output_reload, label)
|
|
||||||
else:
|
|
||||||
loss_reload = model_reload(data, label)
|
|
||||||
loss_reload.backward()
|
|
||||||
colo_optimizer_reload.step()
|
colo_optimizer_reload.step()
|
||||||
|
|
||||||
if i > 2:
|
if i > 2:
|
||||||
|
@ -170,28 +167,25 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
|
||||||
|
|
||||||
if not os.path.isdir('./checkpoint') and rank == 0:
|
if not os.path.isdir('./checkpoint') and rank == 0:
|
||||||
os.mkdir('./checkpoint')
|
os.mkdir('./checkpoint')
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
|
save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
|
||||||
dist.barrier()
|
|
||||||
load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)
|
load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)
|
||||||
dist.barrier()
|
|
||||||
|
|
||||||
# Since model is sharded, we merge them before param checking.
|
|
||||||
for p in model.parameters():
|
|
||||||
p.to_replicate_()
|
|
||||||
|
|
||||||
for p in model_reload.parameters():
|
|
||||||
p.to_replicate_()
|
|
||||||
|
|
||||||
check_param_equal(model, model_reload)
|
check_param_equal(model, model_reload)
|
||||||
compare_optims(colo_optimizer, colo_optimizer_reload)
|
compare_optims(colo_optimizer, colo_optimizer_reload)
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
remove('./checkpoint')
|
remove('./checkpoint')
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
|
def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
|
||||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
pg = ProcessGroup(tp_degree=world_size)
|
pg = ProcessGroup(tp_degree=world_size)
|
||||||
for model_name in ['simple_net', 'bert']:
|
# TODO(haichen) add BERT in the test
|
||||||
|
# the data loader of BERT is in DDP mode, causing the input data is not replicated in the TP context
|
||||||
|
for model_name in ['simple_net']:
|
||||||
_run_checkpoint(model_name,
|
_run_checkpoint(model_name,
|
||||||
init_1d_row_for_linear_weight_spec,
|
init_1d_row_for_linear_weight_spec,
|
||||||
use_ddp,
|
use_ddp,
|
||||||
|
|
|
@ -0,0 +1,47 @@
|
||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.testing import rerun_if_address_is_in_use
|
||||||
|
from colossalai.utils.cuda import get_current_device
|
||||||
|
from colossalai.utils import free_port
|
||||||
|
from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, ColoTensorSpec
|
||||||
|
from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor
|
||||||
|
from tests.test_tensor._utils import tensor_shard_equal
|
||||||
|
|
||||||
|
|
||||||
|
def run_dist(rank, world_size, port, dp_degree, tp_degree):
|
||||||
|
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
|
pg = ProcessGroup(dp_degree=dp_degree, tp_degree=tp_degree)
|
||||||
|
x = torch.randn(4, 4, device=get_current_device())
|
||||||
|
param = ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg))
|
||||||
|
spec = ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)
|
||||||
|
param.set_tensor_spec(*spec)
|
||||||
|
|
||||||
|
gather_tensor(param)
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
assert torch.allclose(x, param.data, rtol=0, atol=0)
|
||||||
|
else:
|
||||||
|
assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
scatter_tensor(param, spec[0])
|
||||||
|
assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
|
||||||
|
assert param.requires_grad is True
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dist
|
||||||
|
@pytest.mark.parametrize('world_size', [4])
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
def test_checkpoint(world_size):
|
||||||
|
run_func = partial(run_dist, world_size=world_size, port=free_port(), dp_degree=2, tp_degree=world_size // 2)
|
||||||
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_checkpoint(world_size=4)
|
Loading…
Reference in New Issue