diff --git a/colossalai/gemini/chunk.py b/colossalai/gemini/chunk.py
index c39a06502..b454fc988 100644
--- a/colossalai/gemini/chunk.py
+++ b/colossalai/gemini/chunk.py
@@ -208,7 +208,7 @@ class Chunk:
             tensor (torch.Tensor): a torch Tensor object.
             tensor_state (TensorState): the target state for transition.
         """
-        assert tensor != TensorState.FREE, 'Can only set a chunk of tensors to FREE'
+
         # As the gradient hook can be triggered either before or after post-backward
         # tensor's state can be compute -> hold_after_bwd -> ready_for_reduce
         # or compute -> ready_for_reduce -> hold_after_bwd
diff --git a/colossalai/utils/checkpoint/module_checkpoint.py b/colossalai/utils/checkpoint/module_checkpoint.py
index 78dcfb681..a109b3702 100644
--- a/colossalai/utils/checkpoint/module_checkpoint.py
+++ b/colossalai/utils/checkpoint/module_checkpoint.py
@@ -89,6 +89,12 @@ def load_checkpoint(path: str,
         torch_load_kwargs: (dict, optional): The kwargs of torch.load inside the function
         load_state_dict_kwargs (dict, optional): The kwargs of load_state_dict inside the function
     """
+    # initialize the default parameters
+    if not torch_load_kwargs:
+        torch_load_kwargs = dict()
+    if not load_state_dict_kwargs:
+        load_state_dict_kwargs = dict()
+
     rank = dist.get_rank()
     mapping = dict()
     for n, p in model.named_parameters():
diff --git a/colossalai/utils/checkpoint/utils.py b/colossalai/utils/checkpoint/utils.py
index cd6f85175..5652600ff 100644
--- a/colossalai/utils/checkpoint/utils.py
+++ b/colossalai/utils/checkpoint/utils.py
@@ -24,7 +24,7 @@ def gather_tensor(colo_tensor: ColoTensor) -> None:
     if not colo_tensor.is_replicate():
         pg = colo_tensor.get_process_group()
         # for the group which contains rank 0
-        if pg.tp_rank_list()[0] == 0:
+        if pg.dp_local_rank() == 0:
             old_dist_spec = colo_tensor.dist_spec
             colo_tensor.to_replicate_()
             if dist.get_rank() != 0:
diff --git a/tests/test_utils/test_colo_checkpoint.py b/tests/test_utils/test_colo_checkpoint.py
index d25b17f10..a5ea75fff 100644
--- a/tests/test_utils/test_colo_checkpoint.py
+++ b/tests/test_utils/test_colo_checkpoint.py
@@ -146,6 +146,9 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
         data = data.to(get_current_device())
         label = label.to(get_current_device())

+        dist.broadcast(data, pg.tp_rank_list()[0], pg.tp_process_group())
+        dist.broadcast(label, pg.tp_rank_list()[0], pg.tp_process_group())
+
         # Bcast rank0 data to all processes
         if criterion:
             output = model(data)
@@ -183,9 +186,9 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
 def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(tp_degree=world_size)
-    # TODO(haichen) add BERT in the test
-    for model_name in ['simple_net']:
+    # the data loader of BERT runs in DDP mode, so the input data is not replicated in the TP context
+    for model_name in ['bert']:
         _run_checkpoint(model_name, init_1d_row_for_linear_weight_spec, use_ddp,
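
Note on the load_checkpoint change above: the added guard builds the kwargs dicts inside the function when the caller passes nothing, so torch.load and load_state_dict always receive a real dict. A minimal, self-contained sketch of the same pattern (the function name load_state is hypothetical, not ColossalAI's API):

    import torch

    def load_state(path, model, torch_load_kwargs=None, load_state_dict_kwargs=None):
        # Initialize the default parameters when the caller passes None or nothing,
        # so the ** expansions below never see a None value.
        if not torch_load_kwargs:
            torch_load_kwargs = dict()
        if not load_state_dict_kwargs:
            load_state_dict_kwargs = dict()

        state_dict = torch.load(path, **torch_load_kwargs)
        model.load_state_dict(state_dict, **load_state_dict_kwargs)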
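
Similarly, the broadcasts added to the test replicate the rank-0 batch across the tensor-parallel group, since the BERT data loader runs in DDP mode and each rank would otherwise receive different inputs. A minimal sketch of that pattern with plain torch.distributed (the helper name and the group argument are assumptions, not part of the test):

    import torch
    import torch.distributed as dist

    def broadcast_batch(data: torch.Tensor, label: torch.Tensor, src_rank: int, group=None):
        # dist.broadcast copies each tensor from src_rank and overwrites it
        # in-place on every other rank in the group.
        dist.broadcast(data, src=src_rank, group=group)
        dist.broadcast(label, src=src_rank, group=group)
        return data, label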