mirror of https://github.com/hpcaitech/ColossalAI
feat: add `sub_dp_group`
parent
1aaa453706
commit
9291f07964
|
@ -5,6 +5,7 @@ from contextlib import contextmanager
|
|||
from functools import partial
|
||||
from typing import Dict, Iterator, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
@ -80,6 +81,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
partition_grad: bool = False, # stage 2 flag
|
||||
cpu_offload: bool = False, # cpu offload
|
||||
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
|
||||
sub_dp_size: int = 1, # further divide zero into sub-dp groups and zero groups
|
||||
forced_dtype: Optional[torch.dtype] = None,
|
||||
master_weights: bool = True, # master weights
|
||||
):
|
||||
|
@ -102,10 +104,37 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
self.require_grad_sync = True
|
||||
|
||||
# if process_group is none, will use the default one
|
||||
self.dp_pg = dp_process_group
|
||||
if dp_process_group is None:
|
||||
dp_process_group = dist.group.WORLD
|
||||
assert dist.get_world_size(group=dp_process_group) % sub_dp_size == 0
|
||||
dp_ranks = dist.get_process_group_ranks(group=dp_process_group)
|
||||
dp_ranks = np.array(dp_ranks).reshape(sub_dp_size, -1)
|
||||
sub_dp_rank = dist.get_rank(group=dp_process_group) % dp_ranks.shape[1]
|
||||
zero_rank = dist.get_rank(group=dp_process_group) // dp_ranks.shape[1]
|
||||
|
||||
if sub_dp_size == 1:
|
||||
self.dp_pg = dp_process_group
|
||||
else:
|
||||
self.dp_pg = None
|
||||
for i in range(dp_ranks.shape[0]):
|
||||
group = dist.new_group(dp_ranks[i])
|
||||
if i == zero_rank:
|
||||
assert self.dp_pg is None
|
||||
self.dp_pg = group
|
||||
self._local_rank = dist.get_rank(group=self.dp_pg)
|
||||
self._world_size = dist.get_world_size(group=self.dp_pg)
|
||||
|
||||
self.sub_dp_pg = None
|
||||
if sub_dp_size > 1:
|
||||
for i in range(dp_ranks.shape[1]):
|
||||
group = dist.new_group(dp_ranks[:, i])
|
||||
if i == sub_dp_rank:
|
||||
assert self.sub_dp_pg is None
|
||||
self.sub_dp_pg = group
|
||||
if self.sub_dp_pg is not None:
|
||||
self._sub_dp_rank = dist.get_rank(group=self.sub_dp_pg)
|
||||
self._sub_dp_world_size = dist.get_world_size(group=self.sub_dp_pg)
|
||||
|
||||
# working and master params for mixed precision training
|
||||
self._working_param_groups = dict()
|
||||
self._master_param_groups_of_current_rank = dict()
|
||||
|
@ -285,6 +314,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
|
||||
if not self._partition_grads:
|
||||
dist.all_reduce(flat_grads, group=self.dp_pg)
|
||||
if self.sub_dp_pg is not None:
|
||||
dist.all_reduce(flat_grads, op=dist.ReduceOp.AVG, group=self.sub_dp_pg)
|
||||
|
||||
if flat_grads.dtype != grad_dtype:
|
||||
flat_grads = flat_grads.to(grad_dtype)
|
||||
|
||||
|
@ -296,6 +328,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
|
||||
recieved_grad = torch.zeros_like(flat_grads_list[0])
|
||||
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
|
||||
if self.sub_dp_pg is not None:
|
||||
dist.all_reduce(recieved_grad, op=dist.ReduceOp.AVG, group=self.sub_dp_pg)
|
||||
|
||||
if recieved_grad.dtype != grad_dtype:
|
||||
recieved_grad = recieved_grad.to(grad_dtype)
|
||||
|
@ -498,6 +532,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
# HACK: torch optim would skip tensor whose grad is None
|
||||
self.optim.step()
|
||||
real_master_params[group_id][idx].grad = None
|
||||
torch.cuda.current_stream().synchronize()
|
||||
|
||||
if not is_first_step:
|
||||
# update working partition updated by the current rank
|
||||
|
@ -516,6 +551,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
working_param.data.copy_(
|
||||
flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)
|
||||
)
|
||||
torch.cuda.current_stream().synchronize()
|
||||
|
||||
# release the grad
|
||||
release_param_grad(self._master_param_groups_of_current_rank[group_id])
|
||||
|
@ -544,7 +580,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
total_norm_cuda = torch.tensor(
|
||||
[float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float
|
||||
)
|
||||
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
|
||||
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.dp_pg)
|
||||
if self.sub_dp_pg is not None:
|
||||
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.sub_dp_pg)
|
||||
total_norm = total_norm_cuda.item()
|
||||
|
||||
else:
|
||||
|
@ -557,9 +595,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
total_norm_exponentiated_cuda = torch.tensor(
|
||||
[float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float
|
||||
)
|
||||
torch.distributed.all_reduce(
|
||||
total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
|
||||
)
|
||||
dist.all_reduce(total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
|
||||
if self.sub_dp_pg is not None:
|
||||
dist.all_reduce(total_norm_exponentiated_cuda, op=dist.ReduceOp.AVG, group=self.sub_dp_pg)
|
||||
total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
|
||||
|
||||
return total_norm
|
||||
|
|
Loading…
Reference in New Issue