Mirror of https://github.com/hpcaitech/ColossalAI

[zero] allow passing process group to zero12 (#4153)

* allow passing process group to zero12
* union tp-zero and normal-zero
* polish code

parent 79cf1b5f33
commit c668801d36
@@ -3,8 +3,9 @@ from typing import Optional

 import torch
 import torch.distributed as dist
-from torch import inf
+from torch import Tensor, inf
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from torch.distributed import ProcessGroup

 from colossalai.tensor import ColoParameter
 from colossalai.utils import is_model_parallel_parameter
@@ -194,25 +195,20 @@ def calculate_global_norm_from_list(norm_list):
     return math.sqrt(total_norm)


-def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
+def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGroup, norm_type: int = 2) -> int:
     """Clips gradient norm of an iterable of parameters.

     This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
-    added functionality to handle model parallel parameters. Note that
-    the gradients are modified in place.
-
-    Arguments:
-        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
-            single Tensor that will have gradients normalized
-        max_norm (float or int): max norm of the gradients
-        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
-            infinity norm.
-
-    Returns:
-        Total norm of the parameters (viewed as a single vector).
-    """
-
-    if mp_group is None:
-        mp_rank = 0
-    else:
-        mp_rank = dist.get_rank(mp_group)
+    added functionality to handle model parallel parameters.
+
+    Args:
+        gradients (Tensor): The gradients to compute norm
+        dp_group (ProcessGroup): The process group of ZeRO Data Parallelism
+        tp_group (ProcessGroup): The process group of Tensor Parallelism
+        norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2.
+
+    Returns:
+        int: The total norm of given gradients
+    """

     norm_type = float(norm_type)
     if norm_type == inf:
@@ -221,20 +217,12 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
         dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_group)

         # Take max across all GPUs.
-        if mp_group is not None:
+        if tp_group is not None:
             dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.MAX)
         total_norm = total_norm_cuda[0].item()
     else:
         total_norm = 0.0
-        # if dist.get_rank() == 0:
-        #    logger.info(f"Total Norm beginning {total_norm}")
-
-        for g, p in zip(gradients, params):
-            # Pipeline parallelism may replicate parameters. Avoid multi-counting.
-            tp_param_flag = False
-            if is_model_parallel_parameter(p) or (isinstance(p, ColoParameter) and not p.is_replicate()):
-                tp_param_flag = True
-            if tp_param_flag or mp_rank == 0:
-                param_norm = g.data.double().norm(2)
-                total_norm += param_norm.item()**2
+        for g in gradients:
+            param_norm = g.data.double().norm(2)
+            total_norm += param_norm.item()**2
@@ -242,8 +230,8 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group)

-       if mp_group is not None:
-           dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=mp_group)
+       if tp_group is not None:
+           dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=tp_group)

        total_norm = total_norm_cuda[0].item()**(1. / norm_type)
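For orientation: the reworked compute_norm above accumulates the squared 2-norms of the gradient shards held locally, SUM-reduces across the ZeRO data-parallel group, optionally SUM-reduces again across the tensor-parallel group, and only then takes the root. Below is a minimal standalone sketch of that aggregation for the L2 path, written against plain torch.distributed and assuming an already-initialized NCCL process group; the function name is illustrative and not part of the patch.

```python
from typing import List, Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


def global_grad_norm_sketch(grads: List[torch.Tensor],
                            dp_group: Optional[ProcessGroup],
                            tp_group: Optional[ProcessGroup],
                            norm_type: float = 2.0) -> float:
    # local contribution: sum of squared 2-norms of the gradient shards owned by this rank
    local_sq_norm = sum(g.data.double().norm(2).item() ** 2 for g in grads)
    total = torch.tensor([local_sq_norm], dtype=torch.float32, device=torch.cuda.current_device())

    # ZeRO shards gradients across the dp group, so squared norms are summed there first
    dist.all_reduce(total, op=dist.ReduceOp.SUM, group=dp_group)

    # with tensor parallelism, each tp rank holds a different slice of the model,
    # so its contribution must be summed across the tp group as well
    if tp_group is not None and dist.get_world_size(tp_group) > 1:
        dist.all_reduce(total, op=dist.ReduceOp.SUM, group=tp_group)

    return total.item() ** (1.0 / norm_type)
```

The remaining hunks move to the LowLevelZeroOptimizer module, which now receives these process groups through its constructor instead of discovering them itself.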
@@ -5,6 +5,7 @@ from typing import Optional

 import torch
 import torch.distributed as dist
+from torch.distributed import ProcessGroup
 from torch.optim import Optimizer

 from colossalai.amp.naive_amp.mixed_precision_mixin import (
@@ -12,12 +13,9 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
     FP16MixedPrecisionMixin,
     MixedPrecisionMixin,
 )
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.tensor import ColoParameter, ProcessGroup
+# from colossalai.tensor import ColoParameter, ProcessGroup
 from colossalai.utils.cuda import get_current_device

 from ._utils import (
@@ -77,11 +75,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                  overlap_communication: bool = False,
                  partition_grad: bool = False,    # stage 2 flag
                  cpu_offload: bool = False,    # cpu offload
+                 dp_process_group: Optional[ProcessGroup] = None,    # the dp pg for comm
+                 tp_process_group: Optional[ProcessGroup] = None,    # if using tp
                  forced_dtype: Optional[torch.dtype] = None):

         # TODO:
-        # 1. process group api
-        # 2. checkpoint IO
+        # 1. state_dict for checkpoint IO

         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
         self._dtype = self.optim.param_groups[0]['params'][0].dtype
@@ -96,30 +95,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # grad accumulation
         self.require_grad_sync = True

-        colo_pg = self._search_colo_process_group()
-        if isinstance(colo_pg, ProcessGroup):
-            self._local_rank = colo_pg.dp_local_rank()
-            self._world_size = colo_pg.dp_world_size()
-            self._dp_global_ranks = colo_pg.get_ranks_in_dp()
-            self._dp_torch_group = colo_pg.dp_process_group()
-            self._mp_torch_group = None
-            if colo_pg.tp_world_size() > 1:
-                self._mp_torch_group = colo_pg.tp_process_group()
-        elif colo_pg is None:
-            dp_parallel_mode = ParallelMode.DATA
-            mp_parallel_mode = ParallelMode.MODEL
-
-            self._dp_parallel_mode = dp_parallel_mode
-            self._mp_parallel_mode = mp_parallel_mode
-            self._local_rank = gpc.get_local_rank(dp_parallel_mode)
-            self._world_size = gpc.get_world_size(dp_parallel_mode)
-            self._dp_global_ranks = gpc.get_ranks_in_group(dp_parallel_mode)
-            self._dp_torch_group = gpc.get_group(dp_parallel_mode)
-            self._mp_torch_group = None
-            if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1:
-                self._mp_torch_group = gpc.get_group(mp_parallel_mode)
-        else:
-            raise NotImplementedError
+        # if process_group is none, will use the default one
+        self.dp_pg = dp_process_group
+        self._local_rank = dist.get_rank(group=self.dp_pg)
+        self._world_size = dist.get_world_size(group=self.dp_pg)
+
+        self.tp_pg = tp_process_group

         # working and master params for mixed precision training
         self._working_param_groups = dict()
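With the ColoParameter search and the legacy gpc branches gone, rank and world size are read straight from the user-supplied data-parallel group (or from the default group when dp_process_group is None). The sketch below shows one way a caller might wire the new constructor arguments up with plain torch.distributed groups; the 2x2 rank layout, the toy model, and the import path of LowLevelZeroOptimizer are assumptions for illustration, not part of the patch.

```python
import torch
import torch.distributed as dist

from colossalai.zero import LowLevelZeroOptimizer    # import path assumed

# launched with torchrun --nproc_per_node=4; ranks arranged as
# 2-way tensor parallelism x 2-way ZeRO data parallelism
dist.init_process_group(backend='nccl')
rank = dist.get_rank()
torch.cuda.set_device(rank)

# every rank must create every subgroup, in the same order
tp_groups = [dist.new_group([0, 1]), dist.new_group([2, 3])]
dp_groups = [dist.new_group([0, 2]), dist.new_group([1, 3])]
tp_group = tp_groups[rank // 2]
dp_group = dp_groups[rank % 2]

model = torch.nn.Linear(32, 32).cuda().half()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# dp_process_group drives gradient reduction and parameter partitioning;
# tp_process_group is only consulted when computing the global gradient norm
zero_optim = LowLevelZeroOptimizer(optim,
                                   partition_grad=False,    # ZeRO stage 1; True for stage 2
                                   dp_process_group=dp_group,
                                   tp_process_group=tp_group)
```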
@@ -145,9 +126,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

         # ParameterStore will manage the tensor buffers used for zero
         # it will not manage the tensors used by mixed precision training
-        self._param_store = ParameterStore(self._dp_torch_group)
-        self._grad_store = GradientStore(self._dp_torch_group, partition_grad=partition_grad)
-        self._bucket_store = BucketStore(self._dp_torch_group)
+        self._param_store = ParameterStore(self.dp_pg)
+        self._grad_store = GradientStore(self.dp_pg, partition_grad=partition_grad)
+        self._bucket_store = BucketStore(self.dp_pg)

         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
@@ -212,22 +193,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             assert param.dtype == self._dtype, \
                 f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"

-    def _search_colo_process_group(self):
-        colo_flag = False
-        colo_pg = None
-        for param_group in self.optim.param_groups:
-            group_params = param_group['params']
-            for param in group_params:
-                if isinstance(param, ColoParameter):
-                    colo_flag = True
-                    if colo_pg is None:
-                        colo_pg = param.get_process_group()
-                    else:
-                        assert colo_pg == param.get_process_group(), "All parameters should be in a same process group"
-                elif colo_flag:
-                    raise RuntimeError("All parameters should be ColoParameter if you use ColoParameter.")
-        return colo_pg
-
     def _create_master_param_current_rank(self, param_list):
         # split each param evenly by world size
         params_current_rank = []
@@ -291,7 +256,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             flat_grads = flat_grads.to(self._communication_dtype)

         if not self._partition_grads:
-            dist.all_reduce(flat_grads, group=self._dp_torch_group)
+            dist.all_reduce(flat_grads, group=self.dp_pg)
             if flat_grads.dtype != grad_dtype:
                 flat_grads = flat_grads.to(grad_dtype)
@@ -307,7 +272,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         else:
             flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
             recieved_grad = torch.zeros_like(flat_grads_list[0])
-            dist.reduce_scatter(recieved_grad, flat_grads_list, group=self._dp_torch_group)
+            dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)

             if recieved_grad.dtype != grad_dtype:
                 recieved_grad = recieved_grad.to(grad_dtype)
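Both gradient-synchronization paths now use the caller-provided dp group: stage 1 all-reduces the flattened bucket so every rank holds the full reduced gradient, while stage 2 reduce-scatters it so each rank keeps only its own shard. A compact sketch of the two patterns; the helper name is illustrative, and it assumes flat_grads.numel() is divisible by the dp world size, as the split in the hunk above does.

```python
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


def sync_flat_grads_sketch(flat_grads: torch.Tensor,
                           dp_group: ProcessGroup,
                           partition_grad: bool) -> torch.Tensor:
    world_size = dist.get_world_size(group=dp_group)
    if not partition_grad:
        # ZeRO stage 1: every dp rank ends up with the full reduced gradient
        dist.all_reduce(flat_grads, group=dp_group)
        return flat_grads

    # ZeRO stage 2: split the flat gradient into per-rank shards and
    # reduce-scatter so each rank keeps only its own reduced shard
    shards = list(flat_grads.split(flat_grads.numel() // world_size))
    own_shard = torch.zeros_like(shards[0])
    dist.reduce_scatter(own_shard, shards, group=dp_group)
    return own_shard
```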
@@ -425,10 +390,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

                 # compute norm
                 working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
-                norm_group = compute_norm(gradients=working_grads,
-                                          params=real_working_params[group_id],
-                                          dp_group=self._dp_torch_group,
-                                          mp_group=self._mp_torch_group)
+                norm_group = compute_norm(gradients=working_grads, dp_group=self.dp_pg, tp_group=self.tp_pg)
                 norm_groups.append(norm_group)

                 self._grad_store.reset_grads_by_group_id(group_id)
@@ -454,7 +416,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

             for idx, splited_param in enumerate(master_working_param):
                 full_master_param = [torch.zeros_like(splited_param).cuda() for _ in range(self._world_size)]
-                dist.all_gather(full_master_param, splited_param.cuda(), group=self._dp_torch_group)
+                dist.all_gather(full_master_param, splited_param.cuda(), group=self.dp_pg)
                 working_param = real_working_params[group_id][idx]
                 full_master_param = flatten(full_master_param)[:working_param.numel()].reshape_as(working_param)
                 working_param.data.copy_(full_master_param)
@@ -33,10 +33,9 @@ def exam_zero_init():

     assert optimizer1._local_rank == optimizer2._local_rank
     assert optimizer1._world_size == optimizer2._world_size
-    assert optimizer1._dp_global_ranks == optimizer2._dp_global_ranks

-    mp_group1 = optimizer1._mp_torch_group
-    mp_group2 = optimizer2._mp_torch_group
+    mp_group1 = optimizer1.tp_pg
+    mp_group2 = optimizer2.tp_pg
     assert dist.get_world_size(mp_group1) == dist.get_world_size(mp_group2)
     assert dist.get_rank(mp_group1) == dist.get_rank(mp_group2)
@@ -57,7 +57,9 @@ def exam_zero_with_tp(overlap_flag, partition_flag):
                                            initial_scale=2,
                                            clip_grad_norm=1.0,
                                            overlap_communication=overlap_flag,
-                                           partition_grad=partition_flag)
+                                           partition_grad=partition_flag,
+                                           dp_process_group=tp_pg.dp_process_group(),
+                                           tp_process_group=tp_pg.tp_process_group())

     dp_local_rank = tp_pg.dp_local_rank()
     set_seed(255 + dp_local_rank)
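The TP test above hands the optimizer the torch-level groups that its existing ColoTensor ProcessGroup (tp_pg) already manages, so a tensor-parallel setup does not need to build new groups by hand. A rough end-to-end sketch of that wiring is below; it assumes the ColoTensor-style ProcessGroup constructor with tp_degree and colossalai.launch_from_torch as used elsewhere in the test suite, plus a toy model and an assumed import path for LowLevelZeroOptimizer, none of which are part of the patch.

```python
import colossalai
import torch
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.zero import LowLevelZeroOptimizer    # import path assumed

# assumed setup: a 4-GPU torchrun job split into tp_degree=2 tensor-parallel groups,
# with the remaining dimension used as the ZeRO data-parallel axis (mirrors the test above)
colossalai.launch_from_torch(config={})
tp_pg = ColoProcessGroup(tp_degree=2)

model = torch.nn.Linear(32, 32).cuda().half()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# reuse the torch groups the ColoTensor ProcessGroup already holds
zero_optim = LowLevelZeroOptimizer(optim,
                                   clip_grad_norm=1.0,
                                   dp_process_group=tp_pg.dp_process_group(),
                                   tp_process_group=tp_pg.tp_process_group())
```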