mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix a running error in test_colo_checkpoint.py (#1387)
parent
f792507ff3
commit
527758b2ae
|
@ -208,7 +208,7 @@ class Chunk:
|
|||
tensor (torch.Tensor): a torch Tensor object.
|
||||
tensor_state (TensorState): the target state for transition.
|
||||
"""
|
||||
assert tensor != TensorState.FREE, 'Can only set a chunk of tensors to FREE'
|
||||
|
||||
# As the gradient hook can be triggered either before or after post-backward
|
||||
# tensor's state can be compute -> hold_after_bwd -> ready_for_reduce
|
||||
# or compute -> ready_for_reduce -> hold_after_bwd
|
||||
|
|
|
@ -89,6 +89,12 @@ def load_checkpoint(path: str,
|
|||
torch_load_kwargs: (dict, optional): The kwargs of torch.load inside the function
|
||||
load_state_dict_kwargs (dict, optional): The kwargs of load_state_dict inside the function
|
||||
"""
|
||||
# initialize the default paramters
|
||||
if not torch_load_kwargs:
|
||||
torch_load_kwargs = dict()
|
||||
if not load_state_dict_kwargs:
|
||||
load_state_dict_kwargs = dict()
|
||||
|
||||
rank = dist.get_rank()
|
||||
mapping = dict()
|
||||
for n, p in model.named_parameters():
|
||||
|
|
|
@ -24,7 +24,7 @@ def gather_tensor(colo_tensor: ColoTensor) -> None:
|
|||
if not colo_tensor.is_replicate():
|
||||
pg = colo_tensor.get_process_group()
|
||||
# for the group which contains rank 0
|
||||
if pg.tp_rank_list()[0] == 0:
|
||||
if pg.dp_local_rank() == 0:
|
||||
old_dist_spec = colo_tensor.dist_spec
|
||||
colo_tensor.to_replicate_()
|
||||
if dist.get_rank() != 0:
|
||||
|
|
|
@ -146,6 +146,9 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
|
|||
data = data.to(get_current_device())
|
||||
label = label.to(get_current_device())
|
||||
|
||||
dist.broadcast(data, pg.tp_rank_list()[0], pg.tp_process_group())
|
||||
dist.broadcast(label, pg.tp_rank_list()[0], pg.tp_process_group())
|
||||
|
||||
# Bcast rank0 data to all processes
|
||||
if criterion:
|
||||
output = model(data)
|
||||
|
@ -183,9 +186,9 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
|
|||
def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
# TODO(haichen) add BERT in the test
|
||||
|
||||
# the data loader of BERT is in DDP mode, causing the input data is not replicated in the TP context
|
||||
for model_name in ['simple_net']:
|
||||
for model_name in ['bert']:
|
||||
_run_checkpoint(model_name,
|
||||
init_1d_row_for_linear_weight_spec,
|
||||
use_ddp,
|
||||
|
|
Loading…
Reference in New Issue