mirror of https://github.com/hpcaitech/ColossalAI
to: remove MoE temporarily
parent 93aaa21d4a
commit a53c8c1ade
@@ -9,7 +9,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch import Tensor, inf
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from torch.distributed import ProcessGroup
 from torch.optim import Optimizer

@@ -21,7 +20,6 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
 )
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
-from colossalai.tensor.moe_tensor.api import is_moe_tensor

 from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
 from .bookkeeping import BucketStore, GradientStore, ParameterStore
@@ -76,7 +74,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         cpu_offload: bool = False,  # cpu offload
         dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
         forced_dtype: Optional[torch.dtype] = None,
-        moe_extra_dp_process_group: Optional[ProcessGroup] = None,
         master_weights: bool = True,  # master weights
     ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
@@ -102,16 +99,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         self._local_rank = dist.get_rank(group=self.dp_pg)
         self._world_size = dist.get_world_size(group=self.dp_pg)

-        # extra dp
-        # This group is used to sync moe param, dp_world_size = moe_duplicates * extra_dp_size.
-        # Non moe param will be sync by global dp pg, moe param will be sync by extra dp pg.
-        # Moe param grad is be split as non moe param by global dp pg, and grad will be merged in step.
-        # And moe working and master param are split by extra dp pg.
-        self.moe_extra_dp_pg = moe_extra_dp_process_group
-        if self.moe_extra_dp_pg is not None:
-            self.moe_extra_dp_pg_size = dist.get_world_size(group=self.moe_extra_dp_pg)
-            self.moe_extra_dp_pg_rank = dist.get_rank(group=self.moe_extra_dp_pg)
-
         # working and master params for mixed precision training
         self._working_param_groups = dict()
         self._master_param_groups_of_current_rank = dict()
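
The comments deleted above describe the topology this parameter enabled: the global data-parallel world is viewed as moe_duplicates * extra_dp_size, with non-MoE parameters synchronized over the global dp group and MoE parameters over the smaller extra dp group. A minimal sketch of that rank layout; the contiguous grouping and the helper name build_moe_extra_dp_groups are assumptions for illustration, not ColossalAI's API:

```python
def build_moe_extra_dp_groups(world_size: int, extra_dp_size: int):
    """Partition the global dp ranks into extra-dp subgroups of size extra_dp_size."""
    assert world_size % extra_dp_size == 0
    moe_duplicates = world_size // extra_dp_size
    # one possible (contiguous) layout; the real grouping may differ
    rank_groups = [list(range(i * extra_dp_size, (i + 1) * extra_dp_size)) for i in range(moe_duplicates)]
    # with an initialized backend, every rank would call torch.distributed.new_group(ranks)
    # for each rank list and keep the group that contains its own rank
    return rank_groups


# e.g. world_size=8, extra_dp_size=2 -> [[0, 1], [2, 3], [4, 5], [6, 7]]
```
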
@@ -143,12 +130,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         self._grad_store = GradientStore(self.dp_pg, partition_grad=partition_grad)
         self._bucket_store = BucketStore(self.dp_pg)

-        # moe param should not be stored in working_groups
-        # because they have different parallel strategy
-        # so we need to store them separately in param_groups
-        # instead of working_groups
-        self.working_moe_params = list()
-
         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
         # and add buffers to parameter store for future access
@@ -156,11 +137,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             group_params = list()
             for param in param_group["params"]:
                 if param.requires_grad:
-                    if self.moe_extra_dp_pg is None:
-                        # skip moe param
-                        if is_moe_tensor(param):
-                            self.working_moe_params.append(param)
-                            continue
                     group_params.append(param)

             # add the working params to working_param_groups for bookkeeping
@@ -174,25 +150,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             # managed by this data parallel rank
             param_group["params"] = master_param_current_rank

-        # if there are moe params, store in addtional group in optim
-        if len(self.working_moe_params) > 0:
-            self._sync_master_param = False
-            param_group = dict()
-            # create fp32 master param
-            for key, value in self.optim.param_groups[0].items():
-                if key != "params":
-                    param_group[key] = value
-            self.master_moe_params = []
-            for param in self.working_moe_params:
-                self.master_moe_params.append(param.clone().to(torch.float32).detach())
-            # create mapping from master to working for optimizer io
-            self.moe_master_to_working_map = {}
-            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
-                self.moe_master_to_working_map[id(master_moe_param)] = working_moe_param
-            # add to optim
-            param_group["params"] = self.master_moe_params
-            self.optim.param_groups.append(param_group)
-
         # initialize communication stream for
         # communication-computation overlapping
         if self._overlap_communication:
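
The block removed here duplicated, for MoE parameters, the master-weight pattern used elsewhere in this optimizer: keep an fp32 clone of each working (fp16/bf16) parameter, register the clones as an extra optimizer param group, and map master back to working by id() for checkpoint I/O. A small standalone sketch of that pattern (the names are illustrative, not the library's):

```python
import torch

working_params = [torch.randn(4, dtype=torch.float16, requires_grad=True)]

master_params, master_to_working = [], {}
for w in working_params:
    m = w.clone().to(torch.float32).detach()  # fp32 master copy
    master_params.append(m)
    master_to_working[id(m)] = w              # keyed by id(), as in the removed code

# the master copies would then be appended to the optimizer as an extra param group,
# e.g. a dict that reuses the existing hyperparameters plus {"params": master_params}
```
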
@@ -256,12 +213,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 else:
                     padding_param = param.data.view(-1)

-                if self.moe_extra_dp_pg is not None and is_moe_tensor(param):
-                    splited_params = padding_param.split(padding_param.numel() // self.moe_extra_dp_pg_size)
-                    splited_params = splited_params[self.moe_extra_dp_pg_rank]
-                else:
-                    splited_params = padding_param.split(padding_param.numel() // self._world_size)
-                    splited_params = splited_params[self._local_rank]
+                splited_params = padding_param.split(padding_param.numel() // self._world_size)
+                splited_params = splited_params[self._local_rank]

                 # use fp32 when master_weights is True
                 if self._master_weights is True:
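
The surviving branch pads each flattened parameter so it divides evenly by the dp world size and keeps only the local rank's shard. A standalone sketch of that arithmetic with made-up sizes:

```python
import torch

world_size, local_rank = 4, 1      # assumed values for illustration
param = torch.randn(10)            # numel=10 is not divisible by 4

padding_size = (world_size - param.numel() % world_size) % world_size  # 2
padded = torch.nn.functional.pad(param.view(-1), [0, padding_size])    # numel=12

shards = padded.split(padded.numel() // world_size)  # 4 shards of 3 elements
local_shard = shards[local_rank]                     # this rank's slice of the param
assert local_shard.numel() == 3
```
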
@@ -301,43 +254,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         if self._bucket_store.num_elements_in_bucket() > 0:
             self._bucket_store.build_grad_in_bucket()

-            if self.moe_extra_dp_pg is None:
-                flat_grads = self._bucket_store.get_flatten_grad()
-                flat_grads /= self._world_size
-            else:
-                # record moe and non moe param
-                moe_list = []
-                for param in self._bucket_store._param_list:
-                    moe_list.append(is_moe_tensor(param))
-
-                # divide them into different groups
-                moe_grad_list = []
-                non_moe_grad_list = []
-                for grad_list in self._bucket_store._grad_in_bucket.values():
-                    non_moe_cur_grad = []
-                    moe_cur_grad = []
-                    for i in range(len(grad_list)):
-                        if moe_list[i] == True:
-                            moe_cur_grad.append(grad_list[i])
-                        else:
-                            non_moe_cur_grad.append(grad_list[i])
-                    if len(moe_cur_grad) > 0:
-                        moe_grad_list.append(moe_cur_grad)
-                    if len(non_moe_cur_grad) > 0:
-                        non_moe_grad_list.append(non_moe_cur_grad)
-
-                if len(non_moe_grad_list) > 0:
-                    non_moe_flat_grads = []
-                    for grad_list in non_moe_grad_list:
-                        non_moe_flat_grads.append(_flatten_dense_tensors(grad_list))
-                    non_moe_flat_grads = _flatten_dense_tensors(non_moe_flat_grads)
-                    non_moe_flat_grads /= self._world_size
-
-                if len(moe_grad_list) > 0:
-                    moe_flat_grads = []
-                    for grad_list in moe_grad_list:
-                        moe_flat_grads.append(_flatten_dense_tensors(grad_list))
-                    moe_flat_grads = _flatten_dense_tensors(moe_flat_grads)
+            flat_grads = self._bucket_store.get_flatten_grad()
+            flat_grads /= self._world_size

             # ready to add other tensors to bucket
             self._bucket_store.reset_num_elements_in_bucket()
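
With the MoE branch gone, reduction always goes through one flattened buffer per bucket. A sketch of the flatten/average/unflatten round trip using the same torch._utils helpers the removed path imported; the bucket contents here are made-up tensors:

```python
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

world_size = 4                                         # assumed dp world size
bucket_grads = [torch.ones(3), torch.ones(5), torch.ones(2)]

flat = _flatten_dense_tensors(bucket_grads)            # one contiguous buffer, numel=10
flat /= world_size                                     # pre-divide before the all-reduce

# after communication, copy the flat buffer back into per-parameter gradients
for g, synced in zip(bucket_grads, _unflatten_dense_tensors(flat, bucket_grads)):
    g.copy_(synced)
```
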
@@ -345,13 +263,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             if self._overlap_communication:
                 stream = self._comm_stream
                 # in case of the memory being reused in the default stream
-                if self.moe_extra_dp_pg is None:
-                    flat_grads.record_stream(stream)
-                else:
-                    if len(non_moe_grad_list) > 0:
-                        non_moe_flat_grads.record_stream(stream)
-                    if len(moe_grad_list) > 0:
-                        moe_flat_grads.record_stream(stream)
+                flat_grads.record_stream(stream)
                 # waiting for ops in the default stream finishing
                 stream.wait_stream(get_accelerator().current_stream())
             else:
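
Only the single flat_grads buffer now has to be protected from being recycled by the default stream before the side-stream reduction reads it. A minimal sketch of the record_stream / wait_stream pattern, assuming a CUDA device; the division stands in for the collective call:

```python
import torch

if torch.cuda.is_available():
    comm_stream = torch.cuda.Stream()
    flat_grads = torch.randn(1024, device="cuda")

    # tell the caching allocator the tensor is used on comm_stream,
    # so its memory is not reused by the default stream too early
    flat_grads.record_stream(comm_stream)

    # make comm_stream wait for all work already queued on the default stream
    comm_stream.wait_stream(torch.cuda.current_stream())

    with torch.cuda.stream(comm_stream):
        flat_grads /= 4  # stand-in for the reduction that runs on the side stream
```
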
@@ -360,84 +272,29 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             with get_accelerator().stream(stream):
                 group_id = self._bucket_store.current_group_id

-                if self.moe_extra_dp_pg is None:
-                    grad_dtype = flat_grads.dtype
-                    if self._communication_dtype is not None:
-                        flat_grads = flat_grads.to(self._communication_dtype)
+                grad_dtype = flat_grads.dtype
+                if self._communication_dtype is not None:
+                    flat_grads = flat_grads.to(self._communication_dtype)

                 if not self._partition_grads:
-                    if self.moe_extra_dp_pg is None:
-                        dist.all_reduce(flat_grads, group=self.dp_pg)
-                        if flat_grads.dtype != grad_dtype:
-                            flat_grads = flat_grads.to(grad_dtype)
+                    dist.all_reduce(flat_grads, group=self.dp_pg)
+                    if flat_grads.dtype != grad_dtype:
+                        flat_grads = flat_grads.to(grad_dtype)

-                        flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size)
-                        grad_in_bucket = self._bucket_store.get_grad()
-                        self._update_unpartitoned_grad(grad_in_bucket.values(), flat_grads_per_rank, group_id)
-
-                    # sync extra zero group
-                    else:
-                        # sync non moe param in global dp group
-                        if len(non_moe_grad_list) > 0:
-                            dist.all_reduce(non_moe_flat_grads, group=self.dp_pg)
-                            flat_grads_per_rank = non_moe_flat_grads.split(
-                                non_moe_flat_grads.numel() // self._world_size
-                            )
-                            self._update_unpartitoned_grad(non_moe_grad_list, flat_grads_per_rank, group_id)
-
-                        # sync moe param only in zero group
-                        if len(moe_grad_list) > 0:
-                            dist.all_reduce(moe_flat_grads, group=self.moe_extra_dp_pg)
-                            flat_grads_per_rank = moe_flat_grads.split(moe_flat_grads.numel() // self._world_size)
-                            self._update_unpartitoned_grad(moe_grad_list, flat_grads_per_rank, group_id)
+                    flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size)
+                    grad_in_bucket = self._bucket_store.get_grad()
+                    self._update_unpartitoned_grad(grad_in_bucket.values(), flat_grads_per_rank, group_id)

                 else:
-                    if self.moe_extra_dp_pg is None:
-                        flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
-                        recieved_grad = torch.zeros_like(flat_grads_list[0])
-                        dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
+                    flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
+                    recieved_grad = torch.zeros_like(flat_grads_list[0])
+                    dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)

-                        if recieved_grad.dtype != grad_dtype:
-                            recieved_grad = recieved_grad.to(grad_dtype)
+                    if recieved_grad.dtype != grad_dtype:
+                        recieved_grad = recieved_grad.to(grad_dtype)

-                        grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
-                        self._update_partitoned_grad(grad_in_bucket_current_rank, recieved_grad, group_id, 1)
-                    else:
-                        # categorize moe and non moe param
-                        grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
-                        moe_grad_in_bucket_current_rank = []
-                        non_moe_grad_in_bucket_current_rank = []
-                        for idx, grad in enumerate(grad_in_bucket_current_rank):
-                            if moe_list[idx] == True:
-                                moe_grad_in_bucket_current_rank.append(grad)
-                            else:
-                                non_moe_grad_in_bucket_current_rank.append(grad)
-
-                        if len(non_moe_grad_list) > 0:
-                            flat_grads_list = list(
-                                non_moe_flat_grads.split(len(non_moe_flat_grads) // self._world_size)
-                            )
-                            recieved_grad = torch.zeros_like(flat_grads_list[0])
-                            dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
-                            self._update_partitoned_grad(
-                                non_moe_grad_in_bucket_current_rank, recieved_grad, group_id, 1
-                            )
-
-                        if len(moe_grad_list) > 0:
-                            flat_grads_list = list(
-                                moe_flat_grads.split(len(moe_flat_grads) // self.moe_extra_dp_pg_size)
-                            )
-                            recieved_grad = torch.zeros_like(flat_grads_list[0])
-                            dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.moe_extra_dp_pg)
-                            param_slice = self._world_size // self.moe_extra_dp_pg_size
-                            recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice))
-                            for split_recieved_grad in recieved_grad:
-                                split_recieved_grad = _unflatten_dense_tensors(
-                                    split_recieved_grad, moe_grad_in_bucket_current_rank
-                                )
-                                for real_grad, grad in zip(split_recieved_grad, moe_grad_in_bucket_current_rank):
-                                    param_id = self._bucket_store.get_param_id_of_grad(grad)
-                                    self._add_grad(real_grad, param_slice, group_id, param_id)
+                    grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
+                    self._update_partitoned_grad(grad_in_bucket_current_rank, recieved_grad, group_id, 1)

             self._bucket_store.reset()

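
The two reduction modes of the original code survive unchanged: without gradient partitioning the whole flat buffer is all-reduced over the dp group, so every rank can unpack every shard; with partitioning (roughly ZeRO-2), each rank only receives the sum of its own shard via reduce-scatter. A hedged sketch of the two calls, assuming torch.distributed is initialized with a backend that supports both collectives (e.g. NCCL); reduce_bucket is a made-up helper name:

```python
import torch
import torch.distributed as dist


def reduce_bucket(flat_grads: torch.Tensor, partition_grads: bool, group=None):
    # flat_grads is assumed padded so its numel divides evenly by the world size
    world_size = dist.get_world_size(group=group)
    shards = list(flat_grads.split(flat_grads.numel() // world_size))

    if not partition_grads:
        # every rank ends up with the full summed buffer (returned as per-rank views)
        dist.all_reduce(flat_grads, group=group)
        return shards

    # each rank only receives the sum of its own shard
    received = torch.zeros_like(shards[0])
    dist.reduce_scatter(received, shards, group=group)
    return received
```
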
@@ -578,20 +435,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 # else the splited grad should be attached to the splited param
                 grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
                 if len(grads) > 0:
-                    # moe hybrid zero
-                    if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
-                        real_working_params[group_id].append(working_param)
-                        if self._partition_grads:
-                            grad = grads
-                        else:
-                            param_slice = self._world_size // self.moe_extra_dp_pg_size
-                            grad = grads[
-                                self.moe_extra_dp_pg_rank * param_slice : (self.moe_extra_dp_pg_rank + 1) * param_slice
-                            ]
-                        grad = flatten(grad)
-                    else:
-                        real_working_params[group_id].append(working_param)
-                        grad = grads[grad_index]
+                    real_working_params[group_id].append(working_param)
+                    grad = grads[grad_index]
                     # no need to copy fp32 grad if master_weights is False
                     if self._master_weights:
                         grad = grad.to(splited_param.dtype).to(splited_param.device)
@@ -609,26 +454,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             # update the params in the optimizer
             self.optim.param_groups[group_id]["params"] = real_master_params[group_id]

-        # update param for moe ep
-        # move grad to master param and compute norm
-        if len(self.working_moe_params) > 0:
-            moe_grads = []
-            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
-                if master_moe_param.grad is not None:
-                    raise RuntimeError("Moe param should not have grad here")
-                grad = working_moe_param.grad
-                # no need to copy fp32 grad if master_weights is False
-                if self._master_weights:
-                    grad = grad.to(master_moe_param.dtype).to(master_moe_param.device)
-                master_moe_param.grad = grad
-                working_moe_param.grad = None
-                moe_grads.append(grad)
-                grad_partition_groups.append(grad)
-            norm_group = self._compute_grad_norm(gradients=moe_grads)
-            norm_groups.append(norm_group)
-            self.optim.param_groups[-1]["params"] = self.master_moe_params
-            del moe_grads
-
         # unscale and clip grads
         global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
         self._unscale_and_clip_grads(grad_partition_groups, global_norm)
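
calculate_global_norm_from_list combines the per-group gradient norms before unscaling and clipping; for the usual L2 case this is just the root of the summed squares. A tiny worked example (the helper below is illustrative, not the ColossalAI implementation):

```python
import math


def global_l2_norm(norm_groups):
    # e.g. per-group norms [3.0, 4.0] -> sqrt(9 + 16) = 5.0
    return math.sqrt(sum(n ** 2 for n in norm_groups))


assert global_l2_norm([3.0, 4.0]) == 5.0
```
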
@@ -636,14 +461,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # update the parameters
         self.optim.step()

-        # release moe grad
-        if len(self.working_moe_params) > 0:
-            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
-                master_moe_param.grad = None
-                working_moe_param.data = (
-                    master_moe_param.data.to(working_moe_param.device).to(working_moe_param.dtype).detach()
-                )
-
         # release the grad
         grad_partition_groups = []
         for group_id in range(self.num_param_groups):
@@ -655,20 +472,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             master_working_param = self.optim.param_groups[group_id]["params"]
             for idx, splited_param in enumerate(master_working_param):
                 working_param = real_working_params[group_id][idx]
-                if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
-                    all_splited_param = [
-                        torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
-                        for _ in range(self.moe_extra_dp_pg_size)
-                    ]
-                    dist.all_gather(
-                        all_splited_param, splited_param.to(device).to(self._dtype), group=self.moe_extra_dp_pg
-                    )
-                else:
-                    all_splited_param = [
-                        torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
-                        for _ in range(self._world_size)
-                    ]
-                    dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg)
+                all_splited_param = [
+                    torch.zeros(splited_param.shape, device=device, dtype=self._dtype) for _ in range(self._world_size)
+                ]
+                dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg)
                 working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]

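
After optim.step(), every rank all-gathers the updated shards and writes them back into the (padded) working tensor; only the reassembly is shown here, with fabricated shards standing in for the all_gather result and torch.cat standing in for the flatten helper:

```python
import torch

world_size = 4
working_param = torch.zeros(2, 5)                                   # 10 elements, padded to 12 across 4 ranks
gathered = [torch.full((3,), float(r)) for r in range(world_size)]  # one updated shard per rank

flat = torch.cat(gathered)                                          # 12 elements, includes the padding
working_param.data.copy_(flat[: working_param.numel()].reshape_as(working_param))
```
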
@@ -802,16 +609,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             for k, v in state.items():
                 if isinstance(v, torch.Tensor) and k != "step":
                     working_param = self._param_store.master_to_working_param[id(param)]
-                    if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
-                        gather_tensor = [
-                            torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
-                        ]
-                        dist.all_gather(gather_tensor, v.to(device), group=self.moe_extra_dp_pg)
-                    else:
-                        gather_tensor = [
-                            torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
-                        ]
-                        dist.all_gather(gather_tensor, v.to(device), group=self.dp_pg)
+                    gather_tensor = [
+                        torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
+                    ]
+                    dist.all_gather(gather_tensor, v.to(device), group=self.dp_pg)
                     param_state = (
                         torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
                     )
@@ -836,12 +637,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                         v = v.flatten()
                         if padding_size > 0:
                             v = torch.nn.functional.pad(v, [0, padding_size])
-                        if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
-                            v_list = v.split(v.numel() // self.moe_extra_dp_pg_size)
-                            zero_state_dict["state"][param_idx][k] = v_list[self.moe_extra_dp_pg_rank].detach().clone()
-                        else:
-                            v_list = v.split(v.numel() // self._world_size)
-                            zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone()
+                        v_list = v.split(v.numel() // self._world_size)
+                        zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone()

         self.optim.load_state_dict(zero_state_dict)

@@ -873,16 +670,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

             for k, v in states.items():
                 if isinstance(v, torch.Tensor) and k != "step":
-                    if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
-                        state_tensor = [
-                            torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
-                        ]
-                        dist.all_gather(state_tensor, v.to(device), group=self.moe_extra_dp_pg)
-                    else:
-                        state_tensor = [
-                            torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
-                        ]
-                        dist.all_gather(state_tensor, v.to(device), group=self.dp_pg)
+                    state_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)]
+                    dist.all_gather(state_tensor, v.to(device), group=self.dp_pg)
                     state_tensor = (
                         torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
                     )
@@ -913,18 +702,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 working_param = p.data.view(-1)
                 if padding_size > 0:
                     working_param = torch.nn.functional.pad(working_param, [0, padding_size])
-                if self.moe_extra_dp_pg is not None and is_moe_tensor(p):
-                    master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
-                else:
-                    master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
-        if hasattr(self, "master_moe_params"):
-            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
-                master_moe_param.copy_(working_moe_param)
+                master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])

     def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
         return self._param_store.working_to_master_param

     def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
-        if hasattr(self, "moe_master_to_working_map"):
-            return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
         return self._param_store.master_to_working_param