feat(moe): support zero for expert local dp (#404)

* support zero for expert local dp

* fix the above code:
    * treat optim.zero_world_size and optim.zero_local_rank as lists in model_checkpoint.py and test_model_checkpoint.py
    * add an overlap and zero check for moe in args_sanity_check()
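A minimal, self-contained sketch of the bookkeeping idea behind these changes: each optimizer param group now picks its own "real" zero parallel mode, so expert groups shard over the expert-local data-parallel group while everything else keeps the regular ZeRO group. ParallelMode and the world size below are stand-ins for the gpc API, not the actual InternLM classes.

    # Sketch only: mirrors the new list-typed _zero_local_rank / _zero_world_size /
    # _broadcast_parallel_mode fields with hypothetical stand-ins.
    from enum import Enum, auto

    class ParallelMode(Enum):
        ZERO1 = auto()        # regular ZeRO process group
        DATA = auto()         # regular data-parallel group
        EXPERT_DATA = auto()  # expert-local data-parallel group

    def pick_zero_mode(dp_mode, zero1_world_size):
        # If ZeRO is effectively off, or the group is a plain DP group, keep ZERO1;
        # expert groups shard their optimizer states over EXPERT_DATA instead.
        if zero1_world_size == 1 or dp_mode == ParallelMode.DATA:
            return ParallelMode.ZERO1
        return ParallelMode.EXPERT_DATA

    group_dp_modes = [ParallelMode.DATA, ParallelMode.EXPERT_DATA]  # one entry per param group
    zero_modes = [pick_zero_mode(m, zero1_world_size=4) for m in group_dp_modes]
    assert zero_modes == [ParallelMode.ZERO1, ParallelMode.EXPERT_DATA]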
Wenwen Qu 2023-10-09 17:45:26 +08:00 committed by GitHub
parent 916647c0a1
commit 582ee000bd
8 changed files with 149 additions and 155 deletions


@@ -319,6 +319,13 @@ def args_sanity_check():
if "moe_loss_coeff" not in gpc.config.loss:
gpc.config.loss._add_item("moe_loss_coeff", 1.0)
# moe not support overlap and zero1.5 for now
if hasattr(gpc.config.model, "num_experts"):
assert (
not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param
), "not support overlap and moe at the same time"
assert gpc.config.parallel.zero1 == -1, "moe only support zero1, set zero1=-1 can fix this"
def launch(
config: Union[str, Path, Config, Dict],
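For reference, a hypothetical config fragment that passes the new MoE checks. Only the keys that appear in this hunk (num_experts, moe_loss_coeff, zero1, overlap_sync_grad, overlap_sync_param) are taken from the diff; the section name hybrid_zero_optimizer is an assumption about the surrounding config layout.

    # Hypothetical InternLM-style config fragment (a sketch, not a verified config file).
    model = dict(
        num_experts=4,            # triggers the MoE branch of args_sanity_check()
    )
    loss = dict(
        moe_loss_coeff=1.0,       # default that args_sanity_check() injects if missing
    )
    parallel = dict(
        zero1=-1,                 # "moe only support zero1, set zero1=-1 can fix this"
    )
    hybrid_zero_optimizer = dict( # section name assumed; holds the overlap switches
        overlap_sync_grad=False,  # enabling both overlaps together with MoE is rejected
        overlap_sync_param=False,
    )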


@@ -3,7 +3,6 @@
import math
from functools import partial
from itertools import product
import torch
import torch.distributed as dist
@@ -11,7 +10,6 @@ from torch.optim import Optimizer
from internlm.core.context import Config, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.utils import is_moe_param
from internlm.monitor import send_alert_message
from internlm.solver.optimizer.store import (
BucketStore,
@@ -116,16 +114,15 @@ class HybridZeroOptimizer(BaseOptimizer):
super().__init__(optim=optimizer)
self._cpu_offload = cpu_offload
self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
self._zero_world_size = gpc.get_world_size(ParallelMode.ZERO1)
self._broadcast_parallel_mode = ParallelMode.ZERO1
self._zero_local_rank = []
self._zero_world_size = []
self._broadcast_parallel_mode = []
# ParameterStore will manage the tensor buffers used for zero
# it will not manage the tensors used by mixed precision training
self._param_store = ParameterStore(ParallelMode.ZERO1)
self._grad_store = GradientStore(ParallelMode.DATA)
self._non_moe_bucket_store = BucketStore(ParallelMode.DATA)
self._moe_bucket_store = BucketStore(ParallelMode.EXPERT_DATA)
self._bucket_store = []
self._bucket_in_progress = []
# fp16 and fp32 params for mixed precision training
@@ -163,7 +160,7 @@ class HybridZeroOptimizer(BaseOptimizer):
f"gpus-{gpc.get_world_size(ParallelMode.GLOBAL)}_"
+ f"pp-{gpc.get_local_rank(ParallelMode.PIPELINE)}_"
+ f"tp-{gpc.get_local_rank(ParallelMode.TENSOR)}_"
+ f"zo-{self._zero_local_rank}.pt"
+ f"zo-{gpc.get_local_rank(ParallelMode.ZERO1)}.pt"
)
self.params_per_rank_id_dict = []
self._param_bcast_sync_handler = param_bcast_sync_handler
@@ -182,10 +179,23 @@ class HybridZeroOptimizer(BaseOptimizer):
# add the fp16 params to fp16_param_groups for bookkeeping
self._fp16_param_groups[group_id] = group_params
# to find real zero mode. if zero is not used, set all param group as ParallelMode.ZERO1
# if zero is used, expert dp group will use ParallelMode.EXPERT_DATA as the real zero mode
zero_mode = (
ParallelMode.ZERO1
if gpc.get_world_size(ParallelMode.ZERO1) == 1 or param_group["dp_mode"] == ParallelMode.DATA
else ParallelMode.EXPERT_DATA
)
self._zero_local_rank.append(gpc.get_local_rank(zero_mode))
self._zero_world_size.append(gpc.get_world_size(zero_mode))
# TODO _broadcast_parallel_mode is not only used in broadcast, maybe can change its name
self._broadcast_parallel_mode.append(zero_mode)
self._bucket_store.append(BucketStore(group_id, param_group["dp_mode"]))
# assign parameters to ranks the params in the list are sorted
params_per_rank, no_params_ranks = self._partition_param_list(param_group)
params_per_rank, no_params_ranks = self._partition_param_list(group_id, param_group)
self.param_group_no_params_ranks.append(no_params_ranks)
self.param_group_has_params.append(self._zero_local_rank not in no_params_ranks)
self.param_group_has_params.append(self._zero_local_rank[group_id] not in no_params_ranks)
# store the mapping between param to rank each param should belong to only one rank.
# we can skip the moe param and do not keep them in _param_store to save memory
@@ -204,7 +214,7 @@ class HybridZeroOptimizer(BaseOptimizer):
param.data = param.data.cpu()
# flatten the reordered tensors
for rank in range(self._zero_world_size):
for rank in range(self._zero_world_size[group_id]):
# No flat fp16 buffer is allocated if the process has no parameters.
if rank not in self.param_group_no_params_ranks[group_id]:
tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
@@ -218,7 +228,7 @@ class HybridZeroOptimizer(BaseOptimizer):
# No flat fp32 buffer is allocated if the process has no parameters.
if self.param_group_has_params[group_id]:
fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(
self._zero_local_rank, group_id
self._zero_local_rank[group_id], group_id
)
fp32_flat_current_rank = fp16_flat_current_rank.float()
device = "cpu" if self._cpu_offload else get_current_device()
@@ -263,44 +273,35 @@ class HybridZeroOptimizer(BaseOptimizer):
def num_param_groups(self):
return len(self._fp16_param_groups)
def _partition_param_list(self, param_group):
def _partition_param_list(self, group_id, param_group):
no_params_ranks = []
params_per_rank = [[] for _ in range(self._zero_world_size)]
numel_per_rank = [0 for _ in range(self._zero_world_size)]
self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size)])
params_per_rank = [[] for _ in range(self._zero_world_size[group_id])]
numel_per_rank = [0 for _ in range(self._zero_world_size[group_id])]
self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size[group_id])])
param_list = param_group["params"]
if self._is_moe_group(param_group):
# for moe group, we do not need to partition the params, just add current
# params to params_per_rank[_zero_local_rank]
params_per_rank[self._zero_local_rank] = list(param_list)
self.params_per_rank_id_dict[-1][self._zero_local_rank].append(None)
no_params_ranks = list(range(self._zero_world_size))
no_params_ranks.pop(self._zero_local_rank)
sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
for i, param in enumerate(sorted_params):
global_id = str(i)
for j in range(len(param.size())):
global_id = "_".join([global_id, str(param.size()[j])])
if self._overlap_sync_param:
rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param)
else:
rank_to_go = numel_per_rank.index(min(numel_per_rank))
params_per_rank[rank_to_go].append(param)
self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
numel_per_rank[rank_to_go] += param.numel()
else:
sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
for i, param in enumerate(sorted_params):
global_id = str(i)
for j in range(len(param.size())):
global_id = "_".join([global_id, str(param.size()[j])])
if self._overlap_sync_param:
rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param)
else:
rank_to_go = numel_per_rank.index(min(numel_per_rank))
params_per_rank[rank_to_go].append(param)
self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
numel_per_rank[rank_to_go] += param.numel()
# check whether any rank is not assigned to parameters.
for rank, params in enumerate(params_per_rank):
if len(params) == 0:
no_params_ranks.append(rank)
# check whether any rank is not assigned to parameters.
for rank, params in enumerate(params_per_rank):
if len(params) == 0:
no_params_ranks.append(rank)
if gpc.is_rank_for_log():
logger.info( # pylint: disable=W1203
f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}"
)
if gpc.is_rank_for_log():
logger.info( # pylint: disable=W1203
f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}"
)
return params_per_rank, set(no_params_ranks)
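As a side note on the branch kept above (the non-overlap path), the partition is a simple greedy balance by element count: sort the params by numel, largest first, and always hand the next one to the least-loaded zero rank. A self-contained sketch with made-up tensor sizes:

    import torch

    def partition_by_numel(params, zero_world_size):
        # Mirrors the greedy loop above: largest params first, each assigned to the
        # rank that currently holds the fewest elements.
        params_per_rank = [[] for _ in range(zero_world_size)]
        numel_per_rank = [0] * zero_world_size
        for p in sorted(params, key=lambda x: x.numel(), reverse=True):
            rank_to_go = numel_per_rank.index(min(numel_per_rank))
            params_per_rank[rank_to_go].append(p)
            numel_per_rank[rank_to_go] += p.numel()
        no_params_ranks = {r for r, ps in enumerate(params_per_rank) if not ps}
        return params_per_rank, numel_per_rank, no_params_ranks

    params = [torch.nn.Parameter(torch.empty(n)) for n in (1024, 512, 512, 64)]
    _, numel_per_rank, empty_ranks = partition_by_numel(params, zero_world_size=2)
    print(numel_per_rank, empty_ranks)  # [1088, 1024] set()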
@@ -313,6 +314,7 @@ class HybridZeroOptimizer(BaseOptimizer):
def _is_gate_group(self, param_group):
return "gate" in param_group.keys() and param_group["gate"]
# TODO check expert dp is correct when enable moe and overlap both
def _attach_reduction_hook(self):
# we iterate over the fp16 params
# on each param, we register a hook to its AccumulateGrad object
@@ -346,16 +348,28 @@ class HybridZeroOptimizer(BaseOptimizer):
_define_and_attach(param, reduce_rank)
def belongs_to_current_rank(self, param) -> bool:
"""
Check whether a parameter is supposed to be updated by the process of the current rank
:param tensor: A :class:`torch.Tensor` object
:type tensor: torch.Tensor
:return: True if the parameter should be updated by the current rank. Otherwise false.
:rtype: bool
"""
tensor_rank = self._param_store.get_param_rank(param)
group_id = getattr(param, "group_id")
return tensor_rank == gpc.get_local_rank(self._broadcast_parallel_mode[group_id])
def _store_and_try_reduce_grads_by_bucket(self, param, reduce_rank=None):
param_size = param.numel()
# check if the bucket is full
# if full, will reduce the grads already in the bucket
# after reduction, the bucket will be empty
if is_moe_param(param):
current_bucket = self._moe_bucket_store
else:
current_bucket = self._non_moe_bucket_store
group_id = getattr(param, "group_id")
current_bucket = self._bucket_store[group_id]
if current_bucket.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
self._reduce_grads_stored_in_bucket(current_bucket, reduce_rank, last_bucket=False)
@@ -382,6 +396,7 @@ class HybridZeroOptimizer(BaseOptimizer):
reduce_rank=reduce_rank,
grads=current_bucket.get_grad(reduce_rank=reduce_rank),
bucket_size=current_bucket.num_elements_in_bucket(reduce_rank),
group_id=current_bucket.get_param_group_id(),
dp_parallel_mode=current_bucket.get_dp_parallel_mode(),
)
@@ -402,14 +417,14 @@ class HybridZeroOptimizer(BaseOptimizer):
# update the flag
self._param_store.set_param_reduction_state(param, True)
if self._param_store.belongs_to_current_rank(param):
if self.belongs_to_current_rank(param):
self._param_store.add_reduced_param_for_compute_norm(param, last_bucket)
else:
self._param_store.add_previous_reduced_param(param)
current_bucket.reset_by_rank(reduce_rank)
def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size, dp_parallel_mode):
def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size, group_id, dp_parallel_mode):
grad_buckets_by_dtype = split_half_float_double(grads)
next_bucket_list = []
# add parameters into bucket for reduction
@@ -418,7 +433,9 @@ class HybridZeroOptimizer(BaseOptimizer):
for tensor in tensor_list:
param_bucket.add_to_bucket(tensor, allow_oversize=True)
if not param_bucket.is_empty():
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank, dp_parallel_mode=dp_parallel_mode)
self._reduce_and_copy(
bucket=param_bucket, reduce_rank=reduce_rank, group_id=group_id, dp_parallel_mode=dp_parallel_mode
)
next_bucket_list.append(param_bucket)
# wait for the completion of previous bucket list reduction, and do unflatten_and_copy()
@@ -433,7 +450,7 @@ class HybridZeroOptimizer(BaseOptimizer):
# after the completion of bucket list reduction, add new buckets into _bucket_in_progress
self._bucket_in_progress = next_bucket_list.copy()
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank, dp_parallel_mode):
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank, group_id, dp_parallel_mode):
# flatten the tensors and do allreduce
bucket.flatten()
bucket.commu_handle = reduce_tensor(
@@ -444,7 +461,7 @@ class HybridZeroOptimizer(BaseOptimizer):
)
# update the reduced tensor
if reduce_rank is None or reduce_rank == self._zero_local_rank:
if reduce_rank is None or reduce_rank == self._zero_local_rank[group_id]:
bucket.set_unflatten_and_copy_flag(flag=True)
def _has_inf_or_nan(self, tensor):
@@ -473,8 +490,8 @@ class HybridZeroOptimizer(BaseOptimizer):
avg_gradients = self._grad_store._averaged_gradients
for group_id in range(self.num_param_groups):
# the following operations are performed only on the rank to which parameters are assigned.
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
param_group = self._param_store.get_fp16_params_by_rank_group(self._zero_local_rank, group_id)
if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]:
param_group = self._param_store.get_fp16_params_by_rank_group(self._zero_local_rank[group_id], group_id)
if group_id not in avg_gradients:
avg_gradients[group_id] = []
@@ -538,37 +555,11 @@ class HybridZeroOptimizer(BaseOptimizer):
parameters=params,
last_stage=last_stage,
previous_norm=previous_norm,
zero_mode=self._broadcast_parallel_mode[group_id],
)
return norm
def _compute_norm_with_moe_group(self, group_id):
params = self._param_store.get_fp16_params_by_rank_group(group_id=group_id, rank=self._zero_local_rank)
# we do not get the average grad for moe parameters, so we have to construct the gradients list here.
grads = [p.grad for p in params]
if len(params) == 0:
grads = [self.padding_grad]
params = [self.padding_tensor]
norm = 0
if self._clip_grad_norm > 0:
norm = compute_norm(
gradients=grads,
parameters=params,
last_stage=True,
is_moe_group=True,
)
# Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce
# model and zero have been reduced!!!
pg = gpc.get_group(ParallelMode.DATA)
scaled_norm = norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
dist.all_reduce(scaled_norm_tensor, group=pg)
all_groups_norm = scaled_norm_tensor.item()
return all_groups_norm
@llm_timeout(func_name="optim_step")
def step(self, closure=None):
"""Performs a single optimization step.
@@ -591,16 +582,13 @@ class HybridZeroOptimizer(BaseOptimizer):
self._store_and_try_reduce_grads_by_bucket(param)
# we need to reduce the gradients left in the communication bucket
self._reduce_grads_stored_in_bucket(self._non_moe_bucket_store, reduce_rank=None, last_bucket=True)
self._reduce_grads_stored_in_bucket(self._moe_bucket_store, reduce_rank=None, last_bucket=True)
for group_id in range(self.num_param_groups):
self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None, last_bucket=True)
# compute norm for gradients in the before bucket
groups_norms = []
for group_id in range(self.num_param_groups):
if self._is_moe_group(self.optim.param_groups[group_id]):
groups_norms.append(None)
else:
groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
# clear reduced grads
# grads in the last bucket is reduced
@@ -616,15 +604,22 @@ class HybridZeroOptimizer(BaseOptimizer):
for group_id in range(self.num_param_groups):
group_name = self.param_groups[group_id]["name"] if "name" in self.param_groups[group_id] else "default"
group_name = f"{group_id}_{group_name}"
total_norms[group_name] = self._compute_norm_with_stage(
group_id=group_id,
last_bucket=True,
last_stage=True,
previous_norm=groups_norms[group_id],
)
# Need to allreduce(avg) the norms across different ranks because moe params will not be synced
# during allreduce
if self._is_moe_group(self.optim.param_groups[group_id]):
total_norms[group_name] = self._compute_norm_with_moe_group(group_id=group_id)
else:
total_norms[group_name] = self._compute_norm_with_stage(
group_id=group_id,
last_bucket=True,
last_stage=True,
previous_norm=groups_norms[group_id],
)
# model and zero have been reduced!!!
pg = gpc.get_group(ParallelMode.EXPERT)
scaled_norm = total_norms[group_name] * 1.0 / float(gpc.get_world_size(ParallelMode.EXPERT))
scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
dist.all_reduce(scaled_norm_tensor, group=pg)
total_norms[group_name] = scaled_norm_tensor.item()
timer("sync_grad").start()
self._sync_grad()
@@ -746,7 +741,7 @@ class HybridZeroOptimizer(BaseOptimizer):
for group_id in range(len(self._fp16_param_groups)):
if self.param_group_has_params[group_id]:
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
rank=self._zero_local_rank, group_id=group_id
rank=self._zero_local_rank[group_id], group_id=group_id
)
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
fp16_param.data.copy_(fp32_param)
@@ -766,27 +761,26 @@ class HybridZeroOptimizer(BaseOptimizer):
def broadcast_params(self):
handles = []
for rank, group_id in product(range(self._zero_world_size), range(self.num_param_groups)):
if self._is_moe_group(self.optim.param_groups[group_id]):
continue
# The following operations are performed only on the rank to which parameters are assigned.
if rank in self.param_group_no_params_ranks[group_id]:
continue
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
# grank = gpc.get_ranks_in_group(group_type)[rank] # need to convert to the global rank
# assert grank == rank, f"{grank} == {rank}"
g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
handle = dist.broadcast(
fp16_param,
src=g_rank,
group=gpc.get_group(ParallelMode.ZERO1),
async_op=True,
)
for group_id in range(self.num_param_groups):
for rank in range(self._zero_world_size[group_id]):
# The following operations are performed only on the rank to which parameters are assigned.
if rank in self.param_group_no_params_ranks[group_id]:
continue
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
# grank = gpc.get_ranks_in_group(group_type)[rank] # need to convert to the global rank
# assert grank == rank, f"{grank} == {rank}"
g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode[group_id])[rank]
handle = dist.broadcast(
fp16_param,
src=g_rank,
group=gpc.get_group(self._broadcast_parallel_mode[group_id]),
async_op=True,
)
if self._overlap_sync_param:
self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
else:
handles.append(handle)
if self._overlap_sync_param:
self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
else:
handles.append(handle)
for handle in handles:
handle.wait()
@@ -802,7 +796,7 @@ class HybridZeroOptimizer(BaseOptimizer):
# check for overflow
for group_id in range(len(self._fp16_param_groups)):
# The following operations are performed only on the rank to which parameters are assigned.
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]:
for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id):
if avg_grad is not None and has_inf_or_nan(avg_grad):
self._found_overflow.fill_(1.0)
@@ -843,7 +837,7 @@ class HybridZeroOptimizer(BaseOptimizer):
flat_fp32_weights = {}
for group_id, param in self._fp32_flat_param_groups_of_current_rank.items():
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]:
assert param.grad is None
flat_fp32_weights[group_id] = param
states["flat_fp32_weights"] = flat_fp32_weights
@@ -863,7 +857,7 @@ class HybridZeroOptimizer(BaseOptimizer):
flat_fp32_weights = states["flat_fp32_weights"]
assert set(flat_fp32_weights.keys()) == set(self._fp32_flat_param_groups_of_current_rank)
for group_id, param in flat_fp32_weights.items():
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]:
self_param = self._fp32_flat_param_groups_of_current_rank[group_id]
assert (
self_param.shape == param.shape
@@ -872,9 +866,9 @@ class HybridZeroOptimizer(BaseOptimizer):
# Load the fp16 model weights.
for group_id in range(len(self._fp16_param_groups)):
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]:
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
rank=self._zero_local_rank, group_id=group_id
rank=self._zero_local_rank[group_id], group_id=group_id
)
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
fp16_param.data.copy_(fp32_param)
@@ -891,7 +885,7 @@ def reload_zero_fp32_buff(optimizer):
if optimizer.param_group_has_params[group_id]:
# flatten fp16 params have already been updated by 'load_model_checkpoint'
fp16_flat_current_rank = optimizer._param_store.get_flat_fp16_param_by_rank_group(
optimizer._zero_local_rank, group_id
optimizer._zero_local_rank[group_id], group_id
)
# param_group["params"] is fp32 flatten optimizer states of this zero rank.
param_group["params"][0].data.copy_(fp16_flat_current_rank.float())
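A tiny numeric sketch of the MoE norm handling added in step() above: each expert-parallel rank computes its own norm for an MoE group, and the final value is the mean across the EXPERT group, obtained as all_reduce(SUM) of norm / world_size. The per-rank norms are made-up numbers and the collective is only simulated.

    expert_world_size = 4
    per_rank_moe_norms = [2.0, 3.0, 2.5, 2.5]     # hypothetical local group norms
    scaled = [n / expert_world_size for n in per_rank_moe_norms]
    all_reduced = sum(scaled)                     # what dist.all_reduce(SUM) would leave on every rank
    assert abs(all_reduced - 2.5) < 1e-9          # i.e. the average of the per-rank norms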


@@ -33,18 +33,22 @@ class BucketStore(BaseStore):
Bucket Store
"""
def __init__(self, dp_parallel_mode):
def __init__(self, group_id, dp_parallel_mode):
super().__init__(dp_parallel_mode)
self._grads = dict()
self._params = dict()
self._num_elements_in_bucket = dict()
self._dp_parallel_mode = dp_parallel_mode
self._group_id = group_id
self.reset()
def num_elements_in_bucket(self, reduce_rank: int = None):
return self._num_elements_in_bucket[reduce_rank]
def get_param_group_id(self):
return self._group_id
def get_dp_parallel_mode(self):
return self._dp_parallel_mode
@@ -182,20 +186,6 @@ class ParameterStore(BaseStore):
"""
return self._fp16_param_to_rank[tensor]
def belongs_to_current_rank(self, tensor) -> bool:
"""
Check whether a parameter is supposed to be updated by the process of the current rank
:param tensor: A :class:`torch.Tensor` object
:type tensor: torch.Tensor
:return: True if the parameter should be updated by the current rank. Otherwise false.
:rtype: bool
"""
tensor_rank = self._fp16_param_to_rank[tensor]
return tensor_rank == self._local_rank
def add_fp16_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None:
if rank not in self._rank_groupid_to_fp16_param_list:
self._rank_groupid_to_fp16_param_list[rank] = dict()
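The point of the two new accessors is that a bucket can now tell the reduction path which param group and which data-parallel group it serves. An illustrative stand-in (not the real internlm.solver.optimizer.store.BucketStore, which also inherits from BaseStore):

    class TinyBucket:
        # Keeps only the two new pieces of metadata from this hunk.
        def __init__(self, group_id, dp_parallel_mode):
            self._group_id = group_id
            self._dp_parallel_mode = dp_parallel_mode

        def get_param_group_id(self):
            return self._group_id

        def get_dp_parallel_mode(self):
            return self._dp_parallel_mode

    buckets = [TinyBucket(0, "DATA"), TinyBucket(1, "EXPERT_DATA")]  # one per param group
    assert buckets[1].get_dp_parallel_mode() == "EXPERT_DATA"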


@@ -209,7 +209,9 @@ def calc_lp(grads, norm_type):
return norm
def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, norm_type=2, is_moe_group=False):
def compute_norm(
gradients, parameters, last_stage=False, previous_norm=None, norm_type=2, zero_mode=ParallelMode.ZERO1
):
"""Get the norm
Arguments:
gradients (Iterable[Tensor]): The gradient value.
@@ -302,8 +304,7 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no
# This is because we use zero1, so we need to use this reduction.
# TODO: Check zero group to be a subset of dp group.
if not is_moe_group:
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.ZERO1))
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(zero_mode))
if torch.is_tensor(total_norm):
total_norm = total_norm.item()
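To make the zero_mode change concrete: compute_norm() now combines shard-local values across whichever process group zero_mode names. Assuming the usual pattern for norm_type=2 (each shard contributes a sum of squares, the partial sums are all-reduced with SUM, and the root is taken at the end), a simulated sketch without a real process group:

    import torch

    shard_grads = [torch.tensor([3.0, 0.0]), torch.tensor([0.0, 4.0])]  # two zero shards
    partial_sq = [g.pow(2).sum() for g in shard_grads]                  # per-shard sum(|g|^2)
    total_sq = torch.stack(partial_sq).sum()                            # stands in for dist.all_reduce(SUM) over zero_mode
    total_norm = total_sq.sqrt().item()
    assert abs(total_norm - 5.0) < 1e-6                                 # ||(3, 0, 0, 4)||_2 == 5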


@@ -2,6 +2,7 @@ from typing import Dict, Tuple
import torch
from internlm.core.context.parallel_context import ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.model.utils import is_gate_param, is_moe_param, is_norm_param
@@ -37,13 +38,13 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
# create new groups for fp32, norm, moe gate and moe expert
new_groups = {}
new_groups["fp32"] = {"name": "fp32", "params": []}
new_groups["fp32"] = {"name": "fp32", "params": [], "dp_mode": ParallelMode.DATA}
if gpc.config.model.get("num_experts", 0) > 1:
# norm and gate are special group to force sync (when enable MoE).
for key in ["gate", "norm"]:
new_groups[key] = {"name": key, key: True, "params": []}
new_groups[key] = {"name": key, key: True, "params": [], "dp_mode": ParallelMode.DATA}
for key in gpc.expert_parallel_group_names:
new_groups[key] = {"name": key, "moe": True, "params": []}
new_groups[key] = {"name": key, "moe": True, "params": [], "dp_mode": ParallelMode.EXPERT_DATA}
for pgroup in param_groups:
# copy attribute from origin group, we assume the input param_groups only
@@ -72,6 +73,7 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
# bf16 param group, which is the first group in the param groups
pgroup["params"] = origin_params
pgroup["dp_mode"] = ParallelMode.DATA
# param groups may contain empty groups, such as fp32
param_groups.extend(new_groups.values())
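After this change every group carries the data-parallel mode the optimizer should treat as its ZeRO domain. An illustrative, hypothetical result, with strings standing in for ParallelMode members and names standing in for real parameter tensors:

    param_groups = [
        {"name": "default", "params": ["w_qkv", "w_out"], "dp_mode": "DATA"},
        {"name": "fp32", "params": [], "dp_mode": "DATA"},
        {"name": "gate", "gate": True, "params": ["moe_gate"], "dp_mode": "DATA"},
        {"name": "norm", "norm": True, "params": ["layernorm_w"], "dp_mode": "DATA"},
        # one group per entry in gpc.expert_parallel_group_names (the name here is made up):
        {"name": "moe_ep_group_0", "moe": True, "params": ["expert_w1"], "dp_mode": "EXPERT_DATA"},
    ]
    assert all("dp_mode" in group for group in param_groups)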


@@ -392,13 +392,14 @@ def save_optimizer_checkpoint(optim, state_path):
zero_rank = gpc.get_local_rank(ParallelMode.ZERO1)
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
zero_size = gpc.get_world_size(ParallelMode.ZERO1)
tp_size = gpc.get_world_size(ParallelMode.TENSOR)
pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
states = optim.state_dict()
if isinstance(optim, HybridZeroOptimizer):
if gpc.get_global_rank() < optim.zero_world_size * tp_size * pp_size:
if gpc.get_global_rank() < zero_size * tp_size * pp_size:
llm_save(os.path.join(state_path, fp), states)
if "zero_devide_optim_plan" in states:
params_per_rank_id_dict = states.pop("zero_devide_optim_plan")
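The rank gate now uses the plain ZERO1 world size instead of optim.zero_world_size, which is a list after this change. With hypothetical world sizes, only the first zero_size * tp_size * pp_size global ranks write an optimizer shard:

    zero_size, tp_size, pp_size = 4, 2, 1       # made-up world sizes
    world_size = 16
    saving_ranks = [r for r in range(world_size) if r < zero_size * tp_size * pp_size]
    print(saving_ranks)                         # [0, 1, 2, 3, 4, 5, 6, 7]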


@@ -5,7 +5,6 @@ import torch.distributed as dist
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.utils import is_moe_param
def is_model_parallel_parameter(p):
@@ -23,7 +22,7 @@ def sync_model_param(model):
gpc.is_initialized(ParallelMode.EXPERT_DATA) and gpc.get_world_size(ParallelMode.EXPERT_DATA) > 1
)
for param in model.parameters():
if sync_moe_param and is_moe_param(param):
if sync_moe_param and getattr(param, "is_expert", False):
ranks = gpc.get_ranks_in_group(ParallelMode.EXPERT_DATA)
dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.EXPERT_DATA))
else:
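The switch from is_moe_param(param) to a lightweight attribute check can be sketched without any process groups; the tagging and group names below are illustrative only:

    import torch

    w_dense = torch.nn.Parameter(torch.zeros(2))
    w_expert = torch.nn.Parameter(torch.zeros(2))
    w_expert.is_expert = True                   # expert params are assumed to carry this tag

    def sync_group_for(param):
        # Expert params sync over the expert-local DP group, everything else over DATA.
        return "EXPERT_DATA" if getattr(param, "is_expert", False) else "DATA"

    assert sync_group_for(w_dense) == "DATA"
    assert sync_group_for(w_expert) == "EXPERT_DATA"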


@@ -86,13 +86,13 @@ ckpt_config_list = [
def overwrite_optim_state(optim, set_value):
if isinstance(optim, HybridZeroOptimizer):
for group_id, p in optim._fp32_flat_param_groups_of_current_rank.items():
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
if optim._zero_local_rank[group_id] not in optim.param_group_no_params_ranks[group_id]:
# p.copy_(torch.full_like(p, set_value, dtype=p.dtype))
p.data.fill_(set_value)
for group_id in range(len(optim._fp16_param_groups)):
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
if optim._zero_local_rank[group_id] not in optim.param_group_no_params_ranks[group_id]:
fp16_p = optim._param_store.get_flat_fp16_param_by_rank_group(
rank=optim._zero_local_rank, group_id=group_id
rank=optim._zero_local_rank[group_id], group_id=group_id
)
fp16_p.fill_(set_value)
else:
@@ -109,7 +109,7 @@ def compare_optim_state(optim1, optim2):
fp32_buff2 = optim2._fp32_flat_param_groups_of_current_rank
for group_id_1, group_id_2 in zip(fp32_buff1, fp32_buff2):
re &= group_id_1 == group_id_2
if optim1.zero_local_rank not in optim1.param_group_no_params_ranks[group_id_1]:
if optim1.zero_local_rank[group_id_1] not in optim1.param_group_no_params_ranks[group_id_1]:
re &= torch.equal(fp32_buff1[group_id_1], fp32_buff1[group_id_2])
else:
for group1, group2 in zip(optim1.param_groups, optim2.param_groups):
@@ -122,12 +122,12 @@ def compare_optim_value(optim, value):
re = True
if isinstance(optim, HybridZeroOptimizer):
for group_id, p in optim._fp32_flat_param_groups_of_current_rank.items():
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
if optim._zero_local_rank[group_id] not in optim.param_group_no_params_ranks[group_id]:
re &= torch.equal(p, torch.full_like(p, value, dtype=p.dtype))
for group_id in range(len(optim._fp16_param_groups)):
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
if optim._zero_local_rank[group_id] not in optim.param_group_no_params_ranks[group_id]:
fp16_p = optim._param_store.get_flat_fp16_param_by_rank_group(
rank=optim._zero_local_rank, group_id=group_id
rank=optim._zero_local_rank[group_id], group_id=group_id
)
re &= torch.equal(fp16_p, torch.full_like(fp16_p, value, dtype=fp16_p.dtype))
else: