mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix a running error in test_colo_checkpoint.py (#1387)
parent f792507ff3
commit 527758b2ae
@@ -208,7 +208,7 @@ class Chunk:
             tensor (torch.Tensor): a torch Tensor object.
             tensor_state (TensorState): the target state for transition.
         """
-        assert tensor != TensorState.FREE, 'Can only set a chunk of tensors to FREE'
         # As the gradient hook can be triggered either before or after post-backward
         # tensor's state can be compute -> hold_after_bwd -> ready_for_reduce
         # or compute -> ready_for_reduce -> hold_after_bwd
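The comments in this hunk explain why the transition order is non-deterministic: the gradient hook may fire before or after post-backward, so a tensor can legally reach ready_for_reduce before hold_after_bwd. A minimal, self-contained sketch of a guard for that case (illustrative names and transition table, not ColossalAI's Chunk implementation):

from enum import Enum

class TensorState(Enum):
    COMPUTE = 0
    HOLD_AFTER_BWD = 1
    READY_FOR_REDUCE = 2
    FREE = 3

# illustrative table of legal transitions; the real one lives in the chunk module
LEGAL_TRANSITIONS = {
    (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
    (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE),
    (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
}

def next_state(current: TensorState, target: TensorState) -> TensorState:
    # the gradient hook may run after post-backward already advanced the tensor,
    # so silently skip the invalid ready_for_reduce -> hold_after_bwd request
    if current is TensorState.READY_FOR_REDUCE and target is TensorState.HOLD_AFTER_BWD:
        return current
    assert (current, target) in LEGAL_TRANSITIONS, f'illegal transition {current} -> {target}'
    return target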
@@ -89,6 +89,12 @@ def load_checkpoint(path: str,
         torch_load_kwargs: (dict, optional): The kwargs of torch.load inside the function
         load_state_dict_kwargs (dict, optional): The kwargs of load_state_dict inside the function
     """
+    # initialize the default paramters
+    if not torch_load_kwargs:
+        torch_load_kwargs = dict()
+    if not load_state_dict_kwargs:
+        load_state_dict_kwargs = dict()
+
     rank = dist.get_rank()
     mapping = dict()
     for n, p in model.named_parameters():
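The six added lines give torch_load_kwargs and load_state_dict_kwargs empty-dict defaults inside the function body rather than mutable defaults in the signature. A simplified, single-process sketch of the same idiom (illustrative signature; the real load_checkpoint also handles the distributed parameter mapping shown in the surrounding hunks):

import torch

def load_checkpoint(path, model, torch_load_kwargs=None, load_state_dict_kwargs=None):
    # initialize the defaults inside the body to avoid shared mutable default arguments
    if not torch_load_kwargs:
        torch_load_kwargs = dict()
    if not load_state_dict_kwargs:
        load_state_dict_kwargs = dict()
    state_dict = torch.load(path, **torch_load_kwargs)
    model.load_state_dict(state_dict, **load_state_dict_kwargs)

A caller can then forward options such as map_location through torch_load_kwargs without the function ever mutating a shared default dict.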
@@ -24,7 +24,7 @@ def gather_tensor(colo_tensor: ColoTensor) -> None:
     if not colo_tensor.is_replicate():
         pg = colo_tensor.get_process_group()
         # for the group which contains rank 0
-        if pg.tp_rank_list()[0] == 0:
+        if pg.dp_local_rank() == 0:
             old_dist_spec = colo_tensor.dist_spec
             colo_tensor.to_replicate_()
             if dist.get_rank() != 0:
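gather_tensor temporarily replicates a sharded ColoTensor so that only global rank 0 ends up holding the full payload before it is saved; the fix decides who participates by the data-parallel local rank instead of the first entry of the tensor-parallel rank list. A rough plain torch.distributed sketch of that save-on-rank-0 pattern (illustrative helper, assuming 1-D sharding along dim 0, not the ColoTensor API):

import torch
import torch.distributed as dist

def save_shard_on_rank0(shard: torch.Tensor, path: str) -> None:
    # gather all shards, materialize the full tensor only where it is needed,
    # and keep the other ranks in step with a barrier
    gathered = [torch.empty_like(shard) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, shard)
    if dist.get_rank() == 0:
        torch.save(torch.cat(gathered, dim=0), path)
    dist.barrier()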
@@ -146,6 +146,9 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
     data = data.to(get_current_device())
     label = label.to(get_current_device())
 
+    dist.broadcast(data, pg.tp_rank_list()[0], pg.tp_process_group())
+    dist.broadcast(label, pg.tp_rank_list()[0], pg.tp_process_group())
+
     # Bcast rank0 data to all processes
     if criterion:
         output = model(data)
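The two added dist.broadcast calls are what make every tensor-parallel rank see the same batch before the forward pass; pg.tp_rank_list()[0] supplies the source rank and pg.tp_process_group() the group. A minimal stand-alone equivalent with plain torch.distributed (illustrative wrapper, assuming data and label already live on the current device):

import torch
import torch.distributed as dist

def broadcast_batch(data: torch.Tensor, label: torch.Tensor, src: int, group=None) -> None:
    # in-place broadcast from the group's source rank so all ranks in the
    # (tensor-parallel) group train on an identical batch
    dist.broadcast(data, src=src, group=group)
    dist.broadcast(label, src=src, group=group)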
@@ -183,9 +186,9 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
 def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(tp_degree=world_size)
-    # TODO(haichen) add BERT in the test
     # the data loader of BERT is in DDP mode, causing the input data is not replicated in the TP context
-    for model_name in ['simple_net']:
+    for model_name in ['bert']:
         _run_checkpoint(model_name,
                         init_1d_row_for_linear_weight_spec,
                         use_ddp,
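run_dist above is the per-process entry point; tests written this way are typically fanned out with torch.multiprocessing, one process per rank. A minimal sketch of such a launcher (the import path and argument values are illustrative, not the test file's actual fixture):

from functools import partial

import torch.multiprocessing as mp

# illustrative import path for the test module shown above
from tests.test_utils.test_colo_checkpoint import run_dist

def spawn_run_dist(world_size: int = 4, port: int = 29500) -> None:
    # mp.spawn passes the process index as the first positional argument,
    # which matches run_dist(rank, world_size, port, ...)
    fn = partial(run_dist, world_size=world_size, port=port,
                 use_ddp=False, use_mp_reload=False, test_scheduler=None)
    mp.spawn(fn, nprocs=world_size)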