[hotfix] fix a running error in test_colo_checkpoint.py (#1387)

pull/1386/head
HELSON 2022-07-29 15:58:06 +08:00 committed by GitHub
parent f792507ff3
commit 527758b2ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 13 additions and 4 deletions

View File

@ -208,7 +208,7 @@ class Chunk:
tensor (torch.Tensor): a torch Tensor object.
tensor_state (TensorState): the target state for transition.
"""
assert tensor != TensorState.FREE, 'Can only set a chunk of tensors to FREE'
# As the gradient hook can be triggered either before or after post-backward
# tensor's state can be compute -> hold_after_bwd -> ready_for_reduce
# or compute -> ready_for_reduce -> hold_after_bwd

View File

@ -89,6 +89,12 @@ def load_checkpoint(path: str,
torch_load_kwargs: (dict, optional): The kwargs of torch.load inside the function
load_state_dict_kwargs (dict, optional): The kwargs of load_state_dict inside the function
"""
# initialize the default paramters
if not torch_load_kwargs:
torch_load_kwargs = dict()
if not load_state_dict_kwargs:
load_state_dict_kwargs = dict()
rank = dist.get_rank()
mapping = dict()
for n, p in model.named_parameters():

View File

@ -24,7 +24,7 @@ def gather_tensor(colo_tensor: ColoTensor) -> None:
if not colo_tensor.is_replicate():
pg = colo_tensor.get_process_group()
# for the group which contains rank 0
if pg.tp_rank_list()[0] == 0:
if pg.dp_local_rank() == 0:
old_dist_spec = colo_tensor.dist_spec
colo_tensor.to_replicate_()
if dist.get_rank() != 0:

View File

@ -146,6 +146,9 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
data = data.to(get_current_device())
label = label.to(get_current_device())
dist.broadcast(data, pg.tp_rank_list()[0], pg.tp_process_group())
dist.broadcast(label, pg.tp_rank_list()[0], pg.tp_process_group())
# Bcast rank0 data to all processes
if criterion:
output = model(data)
@ -183,9 +186,9 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
pg = ProcessGroup(tp_degree=world_size)
# TODO(haichen) add BERT in the test
# the data loader of BERT is in DDP mode, causing the input data is not replicated in the TP context
for model_name in ['simple_net']:
for model_name in ['bert']:
_run_checkpoint(model_name,
init_1d_row_for_linear_weight_spec,
use_ddp,