feat: add `sub_dp_group`

pull/5817/head
Wenhao Chen 2024-04-01 14:51:06 +08:00 committed by アマデウス
parent 1aaa453706
commit 9291f07964
1 changed files with 43 additions and 5 deletions

View File

@ -5,6 +5,7 @@ from contextlib import contextmanager
from functools import partial
from typing import Dict, Iterator, List, Optional, Tuple
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
@ -80,6 +81,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
sub_dp_size: int = 1, # further divide zero into sub-dp groups and zero groups
forced_dtype: Optional[torch.dtype] = None,
master_weights: bool = True, # master weights
):
@ -102,10 +104,37 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self.require_grad_sync = True
# if process_group is none, will use the default one
self.dp_pg = dp_process_group
if dp_process_group is None:
dp_process_group = dist.group.WORLD
assert dist.get_world_size(group=dp_process_group) % sub_dp_size == 0
dp_ranks = dist.get_process_group_ranks(group=dp_process_group)
dp_ranks = np.array(dp_ranks).reshape(sub_dp_size, -1)
sub_dp_rank = dist.get_rank(group=dp_process_group) % dp_ranks.shape[1]
zero_rank = dist.get_rank(group=dp_process_group) // dp_ranks.shape[1]
if sub_dp_size == 1:
self.dp_pg = dp_process_group
else:
self.dp_pg = None
for i in range(dp_ranks.shape[0]):
group = dist.new_group(dp_ranks[i])
if i == zero_rank:
assert self.dp_pg is None
self.dp_pg = group
self._local_rank = dist.get_rank(group=self.dp_pg)
self._world_size = dist.get_world_size(group=self.dp_pg)
self.sub_dp_pg = None
if sub_dp_size > 1:
for i in range(dp_ranks.shape[1]):
group = dist.new_group(dp_ranks[:, i])
if i == sub_dp_rank:
assert self.sub_dp_pg is None
self.sub_dp_pg = group
if self.sub_dp_pg is not None:
self._sub_dp_rank = dist.get_rank(group=self.sub_dp_pg)
self._sub_dp_world_size = dist.get_world_size(group=self.sub_dp_pg)
# working and master params for mixed precision training
self._working_param_groups = dict()
self._master_param_groups_of_current_rank = dict()
@ -285,6 +314,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if not self._partition_grads:
dist.all_reduce(flat_grads, group=self.dp_pg)
if self.sub_dp_pg is not None:
dist.all_reduce(flat_grads, op=dist.ReduceOp.AVG, group=self.sub_dp_pg)
if flat_grads.dtype != grad_dtype:
flat_grads = flat_grads.to(grad_dtype)
@ -296,6 +328,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
recieved_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
if self.sub_dp_pg is not None:
dist.all_reduce(recieved_grad, op=dist.ReduceOp.AVG, group=self.sub_dp_pg)
if recieved_grad.dtype != grad_dtype:
recieved_grad = recieved_grad.to(grad_dtype)
@ -498,6 +532,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# HACK: torch optim would skip tensor whose grad is None
self.optim.step()
real_master_params[group_id][idx].grad = None
torch.cuda.current_stream().synchronize()
if not is_first_step:
# update working partition updated by the current rank
@ -516,6 +551,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
working_param.data.copy_(
flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)
)
torch.cuda.current_stream().synchronize()
# release the grad
release_param_grad(self._master_param_groups_of_current_rank[group_id])
@ -544,7 +580,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
total_norm_cuda = torch.tensor(
[float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float
)
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.dp_pg)
if self.sub_dp_pg is not None:
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.sub_dp_pg)
total_norm = total_norm_cuda.item()
else:
@ -557,9 +595,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
total_norm_exponentiated_cuda = torch.tensor(
[float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float
)
torch.distributed.all_reduce(
total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
)
dist.all_reduce(total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
if self.sub_dp_pg is not None:
dist.all_reduce(total_norm_exponentiated_cuda, op=dist.ReduceOp.AVG, group=self.sub_dp_pg)
total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
return total_norm