diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index 7420da8f4..378f186a8 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -350,7 +350,7 @@ class ZeroDDP(ColoDDP): for tensor in chunk.get_tensors(): rec_p = torch.empty([0]) if record_flag: - rec_p = tensor.cpu() # move the whole tensor to CPU mem + rec_p = tensor.cpu() # move the whole tensor to CPU mem assert tensor not in param_to_save_data param_to_save_data[tensor] = rec_p # release the actual memory of the chunk @@ -406,7 +406,7 @@ class ZeroDDP(ColoDDP): state_dict = state_dict.copy() if metadata is not None: # mypy isn't aware that "_metadata" exists in state_dict - state_dict._metadata = metadata # type: ignore[attr-defined] + state_dict._metadata = metadata # type: ignore[attr-defined] prefix = '' local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})