diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 9decf6ffd..da860721c 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -5,6 +5,7 @@
 from contextlib import contextmanager
 from functools import partial
 from typing import Dict, Iterator, List, Optional, Tuple
+import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -80,6 +81,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         partition_grad: bool = False,  # stage 2 flag
         cpu_offload: bool = False,  # cpu offload
         dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
+        sub_dp_size: int = 1,  # further divide zero into sub-dp groups and zero groups
         forced_dtype: Optional[torch.dtype] = None,
         master_weights: bool = True,  # master weights
     ):
@@ -102,10 +104,37 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         self.require_grad_sync = True
 
         # if process_group is none, will use the default one
-        self.dp_pg = dp_process_group
+        if dp_process_group is None:
+            dp_process_group = dist.group.WORLD
+        assert dist.get_world_size(group=dp_process_group) % sub_dp_size == 0
+        dp_ranks = dist.get_process_group_ranks(group=dp_process_group)
+        dp_ranks = np.array(dp_ranks).reshape(sub_dp_size, -1)
+        sub_dp_rank = dist.get_rank(group=dp_process_group) % dp_ranks.shape[1]
+        zero_rank = dist.get_rank(group=dp_process_group) // dp_ranks.shape[1]
+
+        if sub_dp_size == 1:
+            self.dp_pg = dp_process_group
+        else:
+            self.dp_pg = None
+            for i in range(dp_ranks.shape[0]):
+                group = dist.new_group(dp_ranks[i])
+                if i == zero_rank:
+                    assert self.dp_pg is None
+                    self.dp_pg = group
         self._local_rank = dist.get_rank(group=self.dp_pg)
         self._world_size = dist.get_world_size(group=self.dp_pg)
 
+        self.sub_dp_pg = None
+        if sub_dp_size > 1:
+            for i in range(dp_ranks.shape[1]):
+                group = dist.new_group(dp_ranks[:, i])
+                if i == sub_dp_rank:
+                    assert self.sub_dp_pg is None
+                    self.sub_dp_pg = group
+        if self.sub_dp_pg is not None:
+            self._sub_dp_rank = dist.get_rank(group=self.sub_dp_pg)
+            self._sub_dp_world_size = dist.get_world_size(group=self.sub_dp_pg)
+
         # working and master params for mixed precision training
         self._working_param_groups = dict()
         self._master_param_groups_of_current_rank = dict()
@@ -285,6 +314,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         if not self._partition_grads:
             dist.all_reduce(flat_grads, group=self.dp_pg)
 
+            if self.sub_dp_pg is not None:
+                dist.all_reduce(flat_grads, op=dist.ReduceOp.AVG, group=self.sub_dp_pg)
+
             if flat_grads.dtype != grad_dtype:
                 flat_grads = flat_grads.to(grad_dtype)
 
@@ -296,6 +328,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
             recieved_grad = torch.zeros_like(flat_grads_list[0])
             dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
+            if self.sub_dp_pg is not None:
+                dist.all_reduce(recieved_grad, op=dist.ReduceOp.AVG, group=self.sub_dp_pg)
 
             if recieved_grad.dtype != grad_dtype:
                 recieved_grad = recieved_grad.to(grad_dtype)
@@ -498,6 +532,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                     # HACK: torch optim would skip tensor whose grad is None
                     self.optim.step()
                     real_master_params[group_id][idx].grad = None
+                    torch.cuda.current_stream().synchronize()
 
             if not is_first_step:
                 # update working partition updated by the current rank
@@ -516,6 +551,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                     working_param.data.copy_(
                         flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)
                     )
+            torch.cuda.current_stream().synchronize()
 
         # release the grad
         release_param_grad(self._master_param_groups_of_current_rank[group_id])
@@ -544,7 +580,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             total_norm_cuda = torch.tensor(
                 [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float
             )
-            dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
+            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.dp_pg)
+            if self.sub_dp_pg is not None:
+                dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.sub_dp_pg)
             total_norm = total_norm_cuda.item()
 
         else:
@@ -557,9 +595,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         total_norm_exponentiated_cuda = torch.tensor(
             [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float
         )
-        torch.distributed.all_reduce(
-            total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
-        )
+        dist.all_reduce(total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
+        if self.sub_dp_pg is not None:
+            dist.all_reduce(total_norm_exponentiated_cuda, op=dist.ReduceOp.AVG, group=self.sub_dp_pg)
         total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
 
         return total_norm
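
Note on the grouping scheme: with `sub_dp_size > 1`, the flat DP ranks are reshaped into a `(sub_dp_size, world_size // sub_dp_size)` grid. Each row becomes one ZeRO group (`self.dp_pg`, where gradients are reduced and optimizer states sharded), and each column becomes one sub-DP group (`self.sub_dp_pg`, across which the reduced gradients are averaged). The standalone sketch below mirrors the reshape/indexing logic from `__init__` without launching a distributed job; `world_size` and the printed layout are illustrative values of mine, not part of the patch.

```python
import numpy as np

# Mirrors the grouping logic added in __init__: flat DP ranks are
# reshaped into a (sub_dp_size, zero_size) grid. Rows are ZeRO groups,
# columns are sub-DP groups.
world_size, sub_dp_size = 8, 2  # illustrative values
assert world_size % sub_dp_size == 0

dp_ranks = np.arange(world_size).reshape(sub_dp_size, -1)
for rank in range(world_size):
    sub_dp_rank = rank % dp_ranks.shape[1]  # column index -> sub-DP group
    zero_rank = rank // dp_ranks.shape[1]   # row index    -> ZeRO group
    zero_group = dp_ranks[zero_rank].tolist()         # would back self.dp_pg
    sub_dp_group = dp_ranks[:, sub_dp_rank].tolist()  # would back self.sub_dp_pg
    print(f"rank {rank}: zero group {zero_group}, sub-dp group {sub_dp_group}")
```

As I read the reduction paths, each replica first all-reduces (stage 1) or reduce-scatters (stage 2) within its ZeRO group, then averages the result across its sub-DP group with `ReduceOp.AVG`, so the final gradient matches what a single flat DP group would produce.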
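
A hedged usage sketch follows; the process-group setup, model, and hyperparameters are assumptions of mine, and only the keyword arguments visible in this diff are the point of the example.

```python
import torch
import torch.distributed as dist
from colossalai.zero.low_level.low_level_optim import LowLevelZeroOptimizer

# e.g. launched via torchrun with 8 ranks
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(1024, 1024).cuda()
inner_optim = torch.optim.AdamW(model.parameters(), lr=1e-3)

# With 8 ranks and sub_dp_size=2: optimizer states are sharded across two
# ZeRO groups of 4 ranks each, and gradients are averaged between them.
zero_optim = LowLevelZeroOptimizer(
    inner_optim,
    partition_grad=False,   # ZeRO stage 1
    dp_process_group=None,  # falls back to dist.group.WORLD per this patch
    sub_dp_size=2,
)
```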