mirror of https://github.com/InternLM/InternLM
feat(moe):support zero for expert local dp (#404)
* support zero for expert local dp * fix above codes: *treat optim.zero_world_size and optim.zero_local_rank as list in model_checkpoint.py and test_model_checkpoint.py *add overlap and zero check for moe in args_sanity_check(.)pull/407/head
parent
916647c0a1
commit
582ee000bd
|
@ -319,6 +319,13 @@ def args_sanity_check():
|
|||
if "moe_loss_coeff" not in gpc.config.loss:
|
||||
gpc.config.loss._add_item("moe_loss_coeff", 1.0)
|
||||
|
||||
# moe not support overlap and zero1.5 for now
|
||||
if hasattr(gpc.config.model, "num_experts"):
|
||||
assert (
|
||||
not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param
|
||||
), "not support overlap and moe at the same time"
|
||||
assert gpc.config.parallel.zero1 == -1, "moe only support zero1, set zero1=-1 can fix this"
|
||||
|
||||
|
||||
def launch(
|
||||
config: Union[str, Path, Config, Dict],
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
|
||||
import math
|
||||
from functools import partial
|
||||
from itertools import product
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -11,7 +10,6 @@ from torch.optim import Optimizer
|
|||
|
||||
from internlm.core.context import Config, ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.model.utils import is_moe_param
|
||||
from internlm.monitor import send_alert_message
|
||||
from internlm.solver.optimizer.store import (
|
||||
BucketStore,
|
||||
|
@ -116,16 +114,15 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
super().__init__(optim=optimizer)
|
||||
|
||||
self._cpu_offload = cpu_offload
|
||||
self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
|
||||
self._zero_world_size = gpc.get_world_size(ParallelMode.ZERO1)
|
||||
self._broadcast_parallel_mode = ParallelMode.ZERO1
|
||||
self._zero_local_rank = []
|
||||
self._zero_world_size = []
|
||||
self._broadcast_parallel_mode = []
|
||||
|
||||
# ParameterStore will manage the tensor buffers used for zero
|
||||
# it will not manage the tensors used by mixed precision training
|
||||
self._param_store = ParameterStore(ParallelMode.ZERO1)
|
||||
self._grad_store = GradientStore(ParallelMode.DATA)
|
||||
self._non_moe_bucket_store = BucketStore(ParallelMode.DATA)
|
||||
self._moe_bucket_store = BucketStore(ParallelMode.EXPERT_DATA)
|
||||
self._bucket_store = []
|
||||
self._bucket_in_progress = []
|
||||
|
||||
# fp16 and fp32 params for mixed precision training
|
||||
|
@ -163,7 +160,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
f"gpus-{gpc.get_world_size(ParallelMode.GLOBAL)}_"
|
||||
+ f"pp-{gpc.get_local_rank(ParallelMode.PIPELINE)}_"
|
||||
+ f"tp-{gpc.get_local_rank(ParallelMode.TENSOR)}_"
|
||||
+ f"zo-{self._zero_local_rank}.pt"
|
||||
+ f"zo-{gpc.get_local_rank(ParallelMode.ZERO1)}.pt"
|
||||
)
|
||||
self.params_per_rank_id_dict = []
|
||||
self._param_bcast_sync_handler = param_bcast_sync_handler
|
||||
|
@ -182,10 +179,23 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
# add the fp16 params to fp16_param_groups for bookkeeping
|
||||
self._fp16_param_groups[group_id] = group_params
|
||||
|
||||
# to find real zero mode. if zero is not used, set all param group as ParallelMode.ZERO1
|
||||
# if zero is used, expert dp group will use ParallelMode.EXPERT_DATA as the real zero mode
|
||||
zero_mode = (
|
||||
ParallelMode.ZERO1
|
||||
if param_group["dp_mode"] == gpc.get_world_size(ParallelMode.ZERO1) == 1 or ParallelMode.DATA
|
||||
else ParallelMode.EXPERT_DATA
|
||||
)
|
||||
self._zero_local_rank.append(gpc.get_local_rank(zero_mode))
|
||||
self._zero_world_size.append(gpc.get_world_size(zero_mode))
|
||||
# TODO _broadcast_parallel_mode is not only used in broadcast, maybe can change its name
|
||||
self._broadcast_parallel_mode.append(zero_mode)
|
||||
self._bucket_store.append(BucketStore(group_id, param_group["dp_mode"]))
|
||||
|
||||
# assign parameters to ranks the params in the list are sorted
|
||||
params_per_rank, no_params_ranks = self._partition_param_list(param_group)
|
||||
params_per_rank, no_params_ranks = self._partition_param_list(group_id, param_group)
|
||||
self.param_group_no_params_ranks.append(no_params_ranks)
|
||||
self.param_group_has_params.append(self._zero_local_rank not in no_params_ranks)
|
||||
self.param_group_has_params.append(self._zero_local_rank[group_id] not in no_params_ranks)
|
||||
|
||||
# store the mapping between param to rank each param should belong to only one rank.
|
||||
# we can skip the moe param and do not keep them in _param_store to save memory
|
||||
|
@ -204,7 +214,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
param.data = param.data.cpu()
|
||||
|
||||
# flatten the reordered tensors
|
||||
for rank in range(self._zero_world_size):
|
||||
for rank in range(self._zero_world_size[group_id]):
|
||||
# No flat fp16 buffer is allocated if the process has no parameters.
|
||||
if rank not in self.param_group_no_params_ranks[group_id]:
|
||||
tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
|
||||
|
@ -218,7 +228,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
# No flat fp32 buffer is allocated if the process has no parameters.
|
||||
if self.param_group_has_params[group_id]:
|
||||
fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(
|
||||
self._zero_local_rank, group_id
|
||||
self._zero_local_rank[group_id], group_id
|
||||
)
|
||||
fp32_flat_current_rank = fp16_flat_current_rank.float()
|
||||
device = "cpu" if self._cpu_offload else get_current_device()
|
||||
|
@ -263,44 +273,35 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
def num_param_groups(self):
|
||||
return len(self._fp16_param_groups)
|
||||
|
||||
def _partition_param_list(self, param_group):
|
||||
def _partition_param_list(self, group_id, param_group):
|
||||
no_params_ranks = []
|
||||
params_per_rank = [[] for _ in range(self._zero_world_size)]
|
||||
numel_per_rank = [0 for _ in range(self._zero_world_size)]
|
||||
self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size)])
|
||||
params_per_rank = [[] for _ in range(self._zero_world_size[group_id])]
|
||||
numel_per_rank = [0 for _ in range(self._zero_world_size[group_id])]
|
||||
self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size[group_id])])
|
||||
param_list = param_group["params"]
|
||||
|
||||
if self._is_moe_group(param_group):
|
||||
# for moe group, we do not need to partition the params, just add current
|
||||
# params to params_per_rank[_zero_local_rank]
|
||||
params_per_rank[self._zero_local_rank] = list(param_list)
|
||||
self.params_per_rank_id_dict[-1][self._zero_local_rank].append(None)
|
||||
no_params_ranks = list(range(self._zero_world_size))
|
||||
no_params_ranks.pop(self._zero_local_rank)
|
||||
sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
|
||||
for i, param in enumerate(sorted_params):
|
||||
global_id = str(i)
|
||||
for j in range(len(param.size())):
|
||||
global_id = "_".join([global_id, str(param.size()[j])])
|
||||
if self._overlap_sync_param:
|
||||
rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param)
|
||||
else:
|
||||
rank_to_go = numel_per_rank.index(min(numel_per_rank))
|
||||
params_per_rank[rank_to_go].append(param)
|
||||
self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
|
||||
numel_per_rank[rank_to_go] += param.numel()
|
||||
|
||||
else:
|
||||
sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
|
||||
for i, param in enumerate(sorted_params):
|
||||
global_id = str(i)
|
||||
for j in range(len(param.size())):
|
||||
global_id = "_".join([global_id, str(param.size()[j])])
|
||||
if self._overlap_sync_param:
|
||||
rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param)
|
||||
else:
|
||||
rank_to_go = numel_per_rank.index(min(numel_per_rank))
|
||||
params_per_rank[rank_to_go].append(param)
|
||||
self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
|
||||
numel_per_rank[rank_to_go] += param.numel()
|
||||
# check whether any rank is not assigned to parameters.
|
||||
for rank, params in enumerate(params_per_rank):
|
||||
if len(params) == 0:
|
||||
no_params_ranks.append(rank)
|
||||
|
||||
# check whether any rank is not assigned to parameters.
|
||||
for rank, params in enumerate(params_per_rank):
|
||||
if len(params) == 0:
|
||||
no_params_ranks.append(rank)
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info( # pylint: disable=W1203
|
||||
f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}"
|
||||
)
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info( # pylint: disable=W1203
|
||||
f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}"
|
||||
)
|
||||
|
||||
return params_per_rank, set(no_params_ranks)
|
||||
|
||||
|
@ -313,6 +314,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
def _is_gate_group(self, param_group):
|
||||
return "gate" in param_group.keys() and param_group["gate"]
|
||||
|
||||
# TODO check expert dp is correct when enable moe and overlap both
|
||||
def _attach_reduction_hook(self):
|
||||
# we iterate over the fp16 params
|
||||
# on each param, we register a hook to its AccumulateGrad object
|
||||
|
@ -346,16 +348,28 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
|
||||
_define_and_attach(param, reduce_rank)
|
||||
|
||||
def belongs_to_current_rank(self, param) -> bool:
|
||||
"""
|
||||
Check whether a parameter is supposed to be updated by the process of the current rank
|
||||
|
||||
:param tensor: A :class:`torch.Tensor` object
|
||||
:type tensor: torch.Tensor
|
||||
|
||||
:return: True if the parameter should be updated by the current rank. Otherwise false.
|
||||
:rtype: bool
|
||||
"""
|
||||
tensor_rank = self._param_store.get_param_rank(param)
|
||||
group_id = getattr(param, "group_id")
|
||||
return tensor_rank == gpc.get_local_rank(self._broadcast_parallel_mode[group_id])
|
||||
|
||||
def _store_and_try_reduce_grads_by_bucket(self, param, reduce_rank=None):
|
||||
param_size = param.numel()
|
||||
|
||||
# check if the bucket is full
|
||||
# if full, will reduce the grads already in the bucket
|
||||
# after reduction, the bucket will be empty
|
||||
if is_moe_param(param):
|
||||
current_bucket = self._moe_bucket_store
|
||||
else:
|
||||
current_bucket = self._non_moe_bucket_store
|
||||
group_id = getattr(param, "group_id")
|
||||
current_bucket = self._bucket_store[group_id]
|
||||
|
||||
if current_bucket.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
|
||||
self._reduce_grads_stored_in_bucket(current_bucket, reduce_rank, last_bucket=False)
|
||||
|
@ -382,6 +396,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
reduce_rank=reduce_rank,
|
||||
grads=current_bucket.get_grad(reduce_rank=reduce_rank),
|
||||
bucket_size=current_bucket.num_elements_in_bucket(reduce_rank),
|
||||
group_id=current_bucket.get_param_group_id(),
|
||||
dp_parallel_mode=current_bucket.get_dp_parallel_mode(),
|
||||
)
|
||||
|
||||
|
@ -402,14 +417,14 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
# update the flag
|
||||
self._param_store.set_param_reduction_state(param, True)
|
||||
|
||||
if self._param_store.belongs_to_current_rank(param):
|
||||
if self.belongs_to_current_rank(param):
|
||||
self._param_store.add_reduced_param_for_compute_norm(param, last_bucket)
|
||||
else:
|
||||
self._param_store.add_previous_reduced_param(param)
|
||||
|
||||
current_bucket.reset_by_rank(reduce_rank)
|
||||
|
||||
def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size, dp_parallel_mode):
|
||||
def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size, group_id, dp_parallel_mode):
|
||||
grad_buckets_by_dtype = split_half_float_double(grads)
|
||||
next_bucket_list = []
|
||||
# add parameters into bucket for reduction
|
||||
|
@ -418,7 +433,9 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
for tensor in tensor_list:
|
||||
param_bucket.add_to_bucket(tensor, allow_oversize=True)
|
||||
if not param_bucket.is_empty():
|
||||
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank, dp_parallel_mode=dp_parallel_mode)
|
||||
self._reduce_and_copy(
|
||||
bucket=param_bucket, reduce_rank=reduce_rank, group_id=group_id, dp_parallel_mode=dp_parallel_mode
|
||||
)
|
||||
next_bucket_list.append(param_bucket)
|
||||
|
||||
# wait for the completion of previouce bucket list reduction, and do unflatten_and_copy()
|
||||
|
@ -433,7 +450,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
# after the completion of bucket list reduction, add new buckets into _bucket_in_progress
|
||||
self._bucket_in_progress = next_bucket_list.copy()
|
||||
|
||||
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank, dp_parallel_mode):
|
||||
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank, group_id, dp_parallel_mode):
|
||||
# flatten the tensors and do allreduce
|
||||
bucket.flatten()
|
||||
bucket.commu_handle = reduce_tensor(
|
||||
|
@ -444,7 +461,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
)
|
||||
|
||||
# update the reduced tensor
|
||||
if reduce_rank is None or reduce_rank == self._zero_local_rank:
|
||||
if reduce_rank is None or reduce_rank == self._zero_local_rank[group_id]:
|
||||
bucket.set_unflatten_and_copy_flag(flag=True)
|
||||
|
||||
def _has_inf_or_nan(self, tensor):
|
||||
|
@ -473,8 +490,8 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
avg_gradients = self._grad_store._averaged_gradients
|
||||
for group_id in range(self.num_param_groups):
|
||||
# the following operations are performed only on the rank to which parameters are assigned.
|
||||
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
|
||||
param_group = self._param_store.get_fp16_params_by_rank_group(self._zero_local_rank, group_id)
|
||||
if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]:
|
||||
param_group = self._param_store.get_fp16_params_by_rank_group(self._zero_local_rank[group_id], group_id)
|
||||
|
||||
if group_id not in avg_gradients:
|
||||
avg_gradients[group_id] = []
|
||||
|
@ -538,37 +555,11 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
parameters=params,
|
||||
last_stage=last_stage,
|
||||
previous_norm=previous_norm,
|
||||
zero_mode=self._broadcast_parallel_mode[group_id],
|
||||
)
|
||||
|
||||
return norm
|
||||
|
||||
def _compute_norm_with_moe_group(self, group_id):
|
||||
params = self._param_store.get_fp16_params_by_rank_group(group_id=group_id, rank=self._zero_local_rank)
|
||||
# we do not get the average grad for moe parameters, so we have to constuct the gradients list here.
|
||||
grads = [p.grad for p in params]
|
||||
|
||||
if len(params) == 0:
|
||||
grads = [self.padding_grad]
|
||||
params = [self.padding_tensor]
|
||||
|
||||
norm = 0
|
||||
if self._clip_grad_norm > 0:
|
||||
norm = compute_norm(
|
||||
gradients=grads,
|
||||
parameters=params,
|
||||
last_stage=True,
|
||||
is_moe_group=True,
|
||||
)
|
||||
|
||||
# Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce
|
||||
# model and zero have been reduced!!!
|
||||
pg = gpc.get_group(ParallelMode.DATA)
|
||||
scaled_norm = norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
|
||||
scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
|
||||
dist.all_reduce(scaled_norm_tensor, group=pg)
|
||||
all_groups_norm = scaled_norm_tensor.item()
|
||||
return all_groups_norm
|
||||
|
||||
@llm_timeout(func_name="optim_step")
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
|
@ -591,16 +582,13 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
self._store_and_try_reduce_grads_by_bucket(param)
|
||||
|
||||
# we need to reduce the gradients left in the communication bucket
|
||||
self._reduce_grads_stored_in_bucket(self._non_moe_bucket_store, reduce_rank=None, last_bucket=True)
|
||||
self._reduce_grads_stored_in_bucket(self._moe_bucket_store, reduce_rank=None, last_bucket=True)
|
||||
for group_id in range(self.num_param_groups):
|
||||
self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None, last_bucket=True)
|
||||
|
||||
# compute norm for gradients in the before bucket
|
||||
groups_norms = []
|
||||
for group_id in range(self.num_param_groups):
|
||||
if self._is_moe_group(self.optim.param_groups[group_id]):
|
||||
groups_norms.append(None)
|
||||
else:
|
||||
groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
|
||||
groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
|
||||
|
||||
# clear reduced grads
|
||||
# grads in the last bucket is reduced
|
||||
|
@ -616,15 +604,22 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
for group_id in range(self.num_param_groups):
|
||||
group_name = self.param_groups[group_id]["name"] if "name" in self.param_groups[group_id] else "default"
|
||||
group_name = f"{group_id}_{group_name}"
|
||||
total_norms[group_name] = self._compute_norm_with_stage(
|
||||
group_id=group_id,
|
||||
last_bucket=True,
|
||||
last_stage=True,
|
||||
previous_norm=groups_norms[group_id],
|
||||
)
|
||||
|
||||
# Need to allreduce(avg) the norms across different ranks because moe params will not be synced
|
||||
# during allreduce
|
||||
if self._is_moe_group(self.optim.param_groups[group_id]):
|
||||
total_norms[group_name] = self._compute_norm_with_moe_group(group_id=group_id)
|
||||
else:
|
||||
total_norms[group_name] = self._compute_norm_with_stage(
|
||||
group_id=group_id,
|
||||
last_bucket=True,
|
||||
last_stage=True,
|
||||
previous_norm=groups_norms[group_id],
|
||||
)
|
||||
# model and zero have been reduced!!!
|
||||
pg = gpc.get_group(ParallelMode.EXPERT)
|
||||
scaled_norm = total_norms[group_name] * 1.0 / float(gpc.get_world_size(ParallelMode.EXPERT))
|
||||
scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
|
||||
dist.all_reduce(scaled_norm_tensor, group=pg)
|
||||
total_norms[group_name] = scaled_norm_tensor.item()
|
||||
|
||||
timer("sync_grad").start()
|
||||
self._sync_grad()
|
||||
|
@ -746,7 +741,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
for group_id in range(len(self._fp16_param_groups)):
|
||||
if self.param_group_has_params[group_id]:
|
||||
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
|
||||
rank=self._zero_local_rank, group_id=group_id
|
||||
rank=self._zero_local_rank[group_id], group_id=group_id
|
||||
)
|
||||
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
|
||||
fp16_param.data.copy_(fp32_param)
|
||||
|
@ -766,27 +761,26 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
def broadcast_params(self):
|
||||
handles = []
|
||||
|
||||
for rank, group_id in product(range(self._zero_world_size), range(self.num_param_groups)):
|
||||
if self._is_moe_group(self.optim.param_groups[group_id]):
|
||||
continue
|
||||
# The following operations are performed only on the rank to which parameters are assigned.
|
||||
if rank in self.param_group_no_params_ranks[group_id]:
|
||||
continue
|
||||
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
|
||||
# grank = gpc.get_ranks_in_group(group_type)[rank] # need to convert to the global rank
|
||||
# assert grank == rank, f"{grank} == {rank}"
|
||||
g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
|
||||
handle = dist.broadcast(
|
||||
fp16_param,
|
||||
src=g_rank,
|
||||
group=gpc.get_group(ParallelMode.ZERO1),
|
||||
async_op=True,
|
||||
)
|
||||
for group_id in range(self.num_param_groups):
|
||||
for rank in range(self._zero_world_size[group_id]):
|
||||
# The following operations are performed only on the rank to which parameters are assigned.
|
||||
if rank in self.param_group_no_params_ranks[group_id]:
|
||||
continue
|
||||
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
|
||||
# grank = gpc.get_ranks_in_group(group_type)[rank] # need to convert to the global rank
|
||||
# assert grank == rank, f"{grank} == {rank}"
|
||||
g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode[group_id])[rank]
|
||||
handle = dist.broadcast(
|
||||
fp16_param,
|
||||
src=g_rank,
|
||||
group=gpc.get_group(self._broadcast_parallel_mode[group_id]),
|
||||
async_op=True,
|
||||
)
|
||||
|
||||
if self._overlap_sync_param:
|
||||
self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
|
||||
else:
|
||||
handles.append(handle)
|
||||
if self._overlap_sync_param:
|
||||
self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
|
||||
else:
|
||||
handles.append(handle)
|
||||
|
||||
for handle in handles:
|
||||
handle.wait()
|
||||
|
@ -802,7 +796,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
# check for overflow
|
||||
for group_id in range(len(self._fp16_param_groups)):
|
||||
# The following operations are performed only on the rank to which parameters are assigned.
|
||||
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
|
||||
if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]:
|
||||
for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id):
|
||||
if avg_grad is not None and has_inf_or_nan(avg_grad):
|
||||
self._found_overflow.fill_(1.0)
|
||||
|
@ -843,7 +837,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
|
||||
flat_fp32_weights = {}
|
||||
for group_id, param in self._fp32_flat_param_groups_of_current_rank.items():
|
||||
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
|
||||
if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]:
|
||||
assert param.grad is None
|
||||
flat_fp32_weights[group_id] = param
|
||||
states["flat_fp32_weights"] = flat_fp32_weights
|
||||
|
@ -863,7 +857,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
flat_fp32_weights = states["flat_fp32_weights"]
|
||||
assert set(flat_fp32_weights.keys()) == set(self._fp32_flat_param_groups_of_current_rank)
|
||||
for group_id, param in flat_fp32_weights.items():
|
||||
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
|
||||
if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]:
|
||||
self_param = self._fp32_flat_param_groups_of_current_rank[group_id]
|
||||
assert (
|
||||
self_param.shape == param.shape
|
||||
|
@ -872,9 +866,9 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
|
||||
# Load the fp16 model weights.
|
||||
for group_id in range(len(self._fp16_param_groups)):
|
||||
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
|
||||
if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]:
|
||||
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
|
||||
rank=self._zero_local_rank, group_id=group_id
|
||||
rank=self._zero_local_rank[group_id], group_id=group_id
|
||||
)
|
||||
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
|
||||
fp16_param.data.copy_(fp32_param)
|
||||
|
@ -891,7 +885,7 @@ def reload_zero_fp32_buff(optimizer):
|
|||
if optimizer.param_group_has_params[group_id]:
|
||||
# flatten fp16 params have already been updated by 'load_model_checkpoint'
|
||||
fp16_flat_current_rank = optimizer._param_store.get_flat_fp16_param_by_rank_group(
|
||||
optimizer._zero_local_rank, group_id
|
||||
optimizer._zero_local_rank[group_id], group_id
|
||||
)
|
||||
# param_group["params"] is fp32 flatten optimizer states of this zero rank.
|
||||
param_group["params"][0].data.copy_(fp16_flat_current_rank.float())
|
||||
|
|
|
@ -33,18 +33,22 @@ class BucketStore(BaseStore):
|
|||
Bucket Store
|
||||
"""
|
||||
|
||||
def __init__(self, dp_parallel_mode):
|
||||
def __init__(self, group_id, dp_parallel_mode):
|
||||
super().__init__(dp_parallel_mode)
|
||||
self._grads = dict()
|
||||
self._params = dict()
|
||||
self._num_elements_in_bucket = dict()
|
||||
self._dp_parallel_mode = dp_parallel_mode
|
||||
self._group_id = group_id
|
||||
|
||||
self.reset()
|
||||
|
||||
def num_elements_in_bucket(self, reduce_rank: int = None):
|
||||
return self._num_elements_in_bucket[reduce_rank]
|
||||
|
||||
def get_param_group_id(self):
|
||||
return self._group_id
|
||||
|
||||
def get_dp_parallel_mode(self):
|
||||
return self._dp_parallel_mode
|
||||
|
||||
|
@ -182,20 +186,6 @@ class ParameterStore(BaseStore):
|
|||
"""
|
||||
return self._fp16_param_to_rank[tensor]
|
||||
|
||||
def belongs_to_current_rank(self, tensor) -> bool:
|
||||
"""
|
||||
Check whether a parameter is supposed to be updated by the process of the current rank
|
||||
|
||||
:param tensor: A :class:`torch.Tensor` object
|
||||
:type tensor: torch.Tensor
|
||||
|
||||
:return: True if the parameter should be updated by the current rank. Otherwise false.
|
||||
:rtype: bool
|
||||
"""
|
||||
|
||||
tensor_rank = self._fp16_param_to_rank[tensor]
|
||||
return tensor_rank == self._local_rank
|
||||
|
||||
def add_fp16_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None:
|
||||
if rank not in self._rank_groupid_to_fp16_param_list:
|
||||
self._rank_groupid_to_fp16_param_list[rank] = dict()
|
||||
|
|
|
@ -209,7 +209,9 @@ def calc_lp(grads, norm_type):
|
|||
return norm
|
||||
|
||||
|
||||
def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, norm_type=2, is_moe_group=False):
|
||||
def compute_norm(
|
||||
gradients, parameters, last_stage=False, previous_norm=None, norm_type=2, zero_mode=ParallelMode.ZERO1
|
||||
):
|
||||
"""Get the norm
|
||||
Arguments:
|
||||
gradients (Iterable[Tensor]): The gradient value.
|
||||
|
@ -302,8 +304,7 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no
|
|||
|
||||
# This is because we use zero1, so we need to use this reduction.
|
||||
# TODO: Check zero group to be a subset of dp group.
|
||||
if not is_moe_group:
|
||||
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.ZERO1))
|
||||
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(zero_mode))
|
||||
|
||||
if torch.is_tensor(total_norm):
|
||||
total_norm = total_norm.item()
|
||||
|
|
|
@ -2,6 +2,7 @@ from typing import Dict, Tuple
|
|||
|
||||
import torch
|
||||
|
||||
from internlm.core.context.parallel_context import ParallelMode
|
||||
from internlm.core.context.parallel_context import global_context as gpc
|
||||
from internlm.model.utils import is_gate_param, is_moe_param, is_norm_param
|
||||
|
||||
|
@ -37,13 +38,13 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
|
|||
|
||||
# create new groups for fp32, norm, moe gate and moe expert
|
||||
new_groups = {}
|
||||
new_groups["fp32"] = {"name": "fp32", "params": []}
|
||||
new_groups["fp32"] = {"name": "fp32", "params": [], "dp_mode": ParallelMode.DATA}
|
||||
if gpc.config.model.get("num_experts", 0) > 1:
|
||||
# norm and gate are special group to force sync (when enable MoE).
|
||||
for key in ["gate", "norm"]:
|
||||
new_groups[key] = {"name": key, key: True, "params": []}
|
||||
new_groups[key] = {"name": key, key: True, "params": [], "dp_mode": ParallelMode.DATA}
|
||||
for key in gpc.expert_parallel_group_names:
|
||||
new_groups[key] = {"name": key, "moe": True, "params": []}
|
||||
new_groups[key] = {"name": key, "moe": True, "params": [], "dp_mode": ParallelMode.EXPERT_DATA}
|
||||
|
||||
for pgroup in param_groups:
|
||||
# copy attribute from origin group, we assume the input param_groups only
|
||||
|
@ -72,6 +73,7 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
|
|||
|
||||
# bf16 param group, which is the first group in the param groups
|
||||
pgroup["params"] = origin_params
|
||||
pgroup["dp_mode"] = ParallelMode.DATA
|
||||
|
||||
# param groups may contain empty groups, such as fp32
|
||||
param_groups.extend(new_groups.values())
|
||||
|
|
|
@ -392,13 +392,14 @@ def save_optimizer_checkpoint(optim, state_path):
|
|||
zero_rank = gpc.get_local_rank(ParallelMode.ZERO1)
|
||||
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
|
||||
pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
zero_size = gpc.get_world_size(ParallelMode.ZERO1)
|
||||
tp_size = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||
fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
|
||||
|
||||
states = optim.state_dict()
|
||||
if isinstance(optim, HybridZeroOptimizer):
|
||||
if gpc.get_global_rank() < optim.zero_world_size * tp_size * pp_size:
|
||||
if gpc.get_global_rank() < zero_size * tp_size * pp_size:
|
||||
llm_save(os.path.join(state_path, fp), states)
|
||||
if "zero_devide_optim_plan" in states:
|
||||
params_per_rank_id_dict = states.pop("zero_devide_optim_plan")
|
||||
|
|
|
@ -5,7 +5,6 @@ import torch.distributed as dist
|
|||
|
||||
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.model.utils import is_moe_param
|
||||
|
||||
|
||||
def is_model_parallel_parameter(p):
|
||||
|
@ -23,7 +22,7 @@ def sync_model_param(model):
|
|||
gpc.is_initialized(ParallelMode.EXPERT_DATA) and gpc.get_world_size(ParallelMode.EXPERT_DATA) > 1
|
||||
)
|
||||
for param in model.parameters():
|
||||
if sync_moe_param and is_moe_param(param):
|
||||
if sync_moe_param and getattr(param, "is_expert", False):
|
||||
ranks = gpc.get_ranks_in_group(ParallelMode.EXPERT_DATA)
|
||||
dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.EXPERT_DATA))
|
||||
else:
|
||||
|
|
|
@ -86,13 +86,13 @@ ckpt_config_list = [
|
|||
def overwrite_optim_state(optim, set_value):
|
||||
if isinstance(optim, HybridZeroOptimizer):
|
||||
for group_id, p in optim._fp32_flat_param_groups_of_current_rank.items():
|
||||
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
|
||||
if optim._zero_local_rank[group_id] not in optim.param_group_no_params_ranks[group_id]:
|
||||
# p.copy_(torch.full_like(p, set_value, dtype=p.dtype))
|
||||
p.data.fill_(set_value)
|
||||
for group_id in range(len(optim._fp16_param_groups)):
|
||||
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
|
||||
if optim._zero_local_rank[group_id] not in optim.param_group_no_params_ranks[group_id]:
|
||||
fp16_p = optim._param_store.get_flat_fp16_param_by_rank_group(
|
||||
rank=optim._zero_local_rank, group_id=group_id
|
||||
rank=optim._zero_local_rank[group_id], group_id=group_id
|
||||
)
|
||||
fp16_p.fill_(set_value)
|
||||
else:
|
||||
|
@ -109,7 +109,7 @@ def compare_optim_state(optim1, optim2):
|
|||
fp32_buff2 = optim2._fp32_flat_param_groups_of_current_rank
|
||||
for group_id_1, group_id_2 in zip(fp32_buff1, fp32_buff2):
|
||||
re &= group_id_1 == group_id_2
|
||||
if optim1.zero_local_rank not in optim1.param_group_no_params_ranks[group_id_1]:
|
||||
if optim1.zero_local_rank[group_id_1] not in optim1.param_group_no_params_ranks[group_id_1]:
|
||||
re &= torch.equal(fp32_buff1[group_id_1], fp32_buff1[group_id_2])
|
||||
else:
|
||||
for group1, group2 in zip(optim1.param_groups, optim2.param_groups):
|
||||
|
@ -122,12 +122,12 @@ def compare_optim_value(optim, value):
|
|||
re = True
|
||||
if isinstance(optim, HybridZeroOptimizer):
|
||||
for group_id, p in optim._fp32_flat_param_groups_of_current_rank.items():
|
||||
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
|
||||
if optim._zero_local_rank[group_id] not in optim.param_group_no_params_ranks[group_id]:
|
||||
re &= torch.equal(p, torch.full_like(p, value, dtype=p.dtype))
|
||||
for group_id in range(len(optim._fp16_param_groups)):
|
||||
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
|
||||
if optim._zero_local_rank[group_id] not in optim.param_group_no_params_ranks[group_id]:
|
||||
fp16_p = optim._param_store.get_flat_fp16_param_by_rank_group(
|
||||
rank=optim._zero_local_rank, group_id=group_id
|
||||
rank=optim._zero_local_rank[group_id], group_id=group_id
|
||||
)
|
||||
re &= torch.equal(fp16_p, torch.full_like(fp16_p, value, dtype=fp16_p.dtype))
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue