|
|
|
@ -193,8 +193,9 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|
|
|
|
if isinstance(val, torch.Tensor): |
|
|
|
|
self.chunk_manager.add_extern_static_tensor(val) |
|
|
|
|
|
|
|
|
|
def state_dict(self): |
|
|
|
|
r"""Returns the state of the optimizer as a :class:`dict`. For DP rank != 0, this function returns None. |
|
|
|
|
def state_dict(self, only_rank_0: bool = True): |
|
|
|
|
r"""Returns the state of the optimizer as a :class:`dict`. If only_rank_0 is True, for DP rank != 0, this function returns None. |
|
|
|
|
This saves memory usage. |
|
|
|
|
|
|
|
|
|
It contains two entries: |
|
|
|
|
|
|
|
|
@ -204,7 +205,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|
|
|
|
parameter group is a dict |
|
|
|
|
""" |
|
|
|
|
is_rank_0 = self.chunk_manager.process_group.dp_local_rank() == 0 |
|
|
|
|
if not self.chunk_manager.enable_distributed_storage and not is_rank_0: |
|
|
|
|
if not self.chunk_manager.enable_distributed_storage and only_rank_0 and not is_rank_0: |
|
|
|
|
return |
|
|
|
|
optim_state_dict = super().state_dict() |
|
|
|
|
scaler_state_dict = self.grad_scaler.state_dict() |
|
|
|
@ -214,14 +215,17 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|
|
|
|
local_state = {k: convert_state_dict_to_cpu(v) for k, v in optim_state_dict['state'].items() if len(v) > 0} |
|
|
|
|
if not self.chunk_manager.process_group.has_cpu_groups: |
|
|
|
|
self.chunk_manager.process_group.set_cpu_groups() |
|
|
|
|
dst_rank = self.chunk_manager.process_group.dp_rank_list()[0] |
|
|
|
|
output = [None for _ in range(self.chunk_manager.process_group.dp_world_size())] |
|
|
|
|
dist.gather_object(local_state, |
|
|
|
|
output if self.chunk_manager.process_group.dp_local_rank() == 0 else None, |
|
|
|
|
dst=dst_rank, |
|
|
|
|
group=self.chunk_manager.process_group.cpu_dp_process_group()) |
|
|
|
|
if not is_rank_0: |
|
|
|
|
return |
|
|
|
|
if only_rank_0: |
|
|
|
|
dst_rank = self.chunk_manager.process_group.dp_rank_list()[0] |
|
|
|
|
dist.gather_object(local_state, |
|
|
|
|
output if self.chunk_manager.process_group.dp_local_rank() == 0 else None, |
|
|
|
|
dst=dst_rank, |
|
|
|
|
group=self.chunk_manager.process_group.cpu_dp_process_group()) |
|
|
|
|
if not is_rank_0: |
|
|
|
|
return |
|
|
|
|
else: |
|
|
|
|
dist.all_gather_object(output, local_state, group=self.chunk_manager.process_group.cpu_dp_process_group()) |
|
|
|
|
for state in output: |
|
|
|
|
optim_state_dict['state'].update(state) |
|
|
|
|
return optim_state_dict |
|
|
|
|