mirror of https://github.com/hpcaitech/ColossalAI
[zero] adapt zero for unsharded parameters (#561)
* support existing sharded and unsharded parameters in zero
* add unit test for moe-zero model init
* polish moe gradient handler
parent 13ed4b6441
commit e6d50ec107
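Taken together, the change lets a ZeRO-initialised model hold both sharded and unsharded parameters, with unsharded ones optionally marked as not replicated (the MoE expert weights). A minimal usage sketch, assuming colossalai.launch has already set up the distributed environment; the Expert and Net modules and their shapes are illustrative, only ZeroInitContext, ShardedModelV2, TensorShardStrategy and no_shard_zero_decrator come from this commit or the existing API:

    import torch
    import torch.nn as nn
    from colossalai.zero.init_ctx import ZeroInitContext, no_shard_zero_decrator
    from colossalai.zero.shard_utils import TensorShardStrategy
    from colossalai.zero.sharded_model import ShardedModelV2


    class Expert(nn.Module):
        # Illustrative module: its weights are neither sharded nor replicated across
        # the data-parallel group, so __init__ is wrapped with the decorator factory.
        @no_shard_zero_decrator(is_replicated=False)
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(16, 16)


    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 16)    # sharded as usual
            self.expert = Expert()          # kept unsharded


    # Assumes colossalai.launch(...) has already initialised the process groups.
    strategy = TensorShardStrategy()
    with ZeroInitContext(target_device=torch.cuda.current_device(),
                         shard_strategy=strategy,
                         shard_param=True):
        model = Net()
    sharded_model = ShardedModelV2(model, strategy)

ShardedModelV2 then checks that each submodule is internally consistent (all of its own parameters sharded or all unsharded) and routes gradients accordingly, as the hunks below show.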
@@ -16,6 +16,9 @@ class MoeGradientHandler(BaseGradientHandler):
         the same type to improve the efficiency of communication.
     """
 
+    def __init__(self, model, optimizer=None):
+        super().__init__(model, optimizer)
+
     def handle_gradient(self):
         """A method running an all-reduce operation in a data parallel group.
         Then running an all-reduce operation for all parameters in experts
@@ -24,13 +27,15 @@ class MoeGradientHandler(BaseGradientHandler):
         global_data = gpc.data_parallel_size
 
         if global_data > 1:
-            param_dict = get_moe_epsize_param_dict(self._model)
+            epsize_param_dict = get_moe_epsize_param_dict(self._model)
 
             # epsize is 1, indicating the params are replicated among processes in data parallelism
            # use the ParallelMode.DATA to get data parallel group
             # reduce gradients for all parameters in data parallelism
-            if 1 in param_dict:
-                bucket_allreduce(param_list=param_dict[1], group=gpc.get_group(ParallelMode.DATA))
+            if 1 in epsize_param_dict:
+                bucket_allreduce(param_list=epsize_param_dict[1], group=gpc.get_group(ParallelMode.DATA))
 
-            for ep_size in param_dict:
+            for ep_size in epsize_param_dict:
                 if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
-                    bucket_allreduce(param_list=param_dict[ep_size],
+                    bucket_allreduce(param_list=epsize_param_dict[ep_size],
                                      group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group)
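handle_gradient relies on get_moe_epsize_param_dict returning a mapping from expert-parallel size to the parameters that live at that size, with ep_size 1 meaning ordinary replicated data-parallel parameters. A rough sketch of that grouping, written as an assumption about the helper's contract rather than its actual implementation (the per-parameter moe_info attribute used here is hypothetical):

    from collections import defaultdict
    from typing import Dict, List

    import torch.nn as nn


    def group_params_by_ep_size(model: nn.Module) -> Dict[int, List[nn.Parameter]]:
        # NOTE: `moe_info` is a hypothetical attribute used only for this sketch;
        # the real get_moe_epsize_param_dict lives elsewhere in the repo.
        epsize_param_dict: Dict[int, List[nn.Parameter]] = defaultdict(list)
        for param in model.parameters():
            ep_size = param.moe_info.ep_size if hasattr(param, 'moe_info') else 1
            epsize_param_dict[ep_size].append(param)
        return dict(epsize_param_dict)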
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Callable, Optional
 
 import torch
 
@@ -61,18 +61,25 @@ class PostBackwardFunction(torch.autograd.Function):
         return (None, None) + args
 
 
-def register_ophooks_recursively(module: torch.nn.Module, ophook_list: List[BaseOpHook] = None, name: str = ""):
+def register_ophooks_recursively(module: torch.nn.Module,
+                                 ophook_list: List[BaseOpHook] = None,
+                                 name: str = "",
+                                 filter_fn: Optional[Callable] = None):
     r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
     assert isinstance(module, torch.nn.Module)
 
     # Add hooks for submodules
     for child_name, child in module.named_children():
-        register_ophooks_recursively(child, ophook_list, name + child_name)
+        register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn)
 
     # Early return on modules with no parameters.
    if len(list(module.parameters(recurse=False))) == 0:
         return
 
+    # return from flitered module
+    if filter_fn is not None and filter_fn(module):
+        return
+
     if ophook_list is not None:
         for hook in ophook_list:
             assert (isinstance(hook, BaseOpHook))
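The new filter_fn lets a caller skip hook registration on selected submodules. The consumer inside this commit is ShardedModelV2 further down, which excludes submodules whose parameters stay unsharded; a minimal sketch of that call pattern, where module and my_hook are placeholders for the wrapped model and any BaseOpHook instance:

    from colossalai.engine.ophooks import register_ophooks_recursively

    # Submodules flagged as unsharded (param_is_sharded is set per submodule by
    # ShardedModelV2 in a later hunk) are skipped entirely.
    register_ophooks_recursively(module,
                                 ophook_list=[my_hook],
                                 filter_fn=lambda m: not m.param_is_sharded)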
@@ -35,7 +35,7 @@ class Experts(MoeExperts):
         expert_args: Args used to initialize experts, the args could be found in corresponding expert class
     """
 
-    @no_shard_zero_decrator
+    @no_shard_zero_decrator(is_replicated=False)
     def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
         super().__init__("all_to_all", num_experts)
 
@@ -9,7 +9,7 @@ from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils import get_current_device
 from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
 from .experts import MoeExperts, Experts
-from .utils import ForceFP32Parameter, UniformNoiseGenerator, NormalNoiseGenerator
+from .utils import ForceFP32Parameter, UniformNoiseGenerator, NormalNoiseGenerator, autocast_softmax
 from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
 from typing import Callable, Optional, Type
 from torch.distributed import ProcessGroup
@@ -66,7 +66,7 @@ class Top1Router(nn.Module):
         if self.noisy_func is not None and self.training:
             inputs = self.noisy_func(inputs)
 
-        logits = F.softmax(inputs, dim=-1)
+        logits = autocast_softmax(inputs, dim=-1)
         num_experts = logits.size(-1)
         capacity = self.get_capacity(logits.shape)
 
@@ -152,7 +152,7 @@ class Top2Router(nn.Module):
         if self.noisy_func is not None and self.training:
             inputs = self.noisy_func(inputs)
 
-        logits = F.softmax(inputs, dim=-1)    # logits: [s, e]
+        logits = autocast_softmax(inputs, dim=-1)    # logits: [s, e]
         num_experts = logits.size(-1)
         capacity = self.get_capacity(logits.shape)
 
@@ -241,7 +241,7 @@ class MoeLayer(nn.Module):
         experts (:class:`torch.nn.Module`): Instance of experts generated by Expert.
     """
 
-    @no_shard_zero_decrator
+    @no_shard_zero_decrator(is_replicated=True)
     def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
         super().__init__()
         self.d_model = dim_model
@@ -1,4 +1,5 @@
 import torch
+import torch.nn.functional as F
 from colossalai.utils import get_current_device
 from colossalai.context.moe_context import MOE_CONTEXT
 from .experts import FFNExperts, TPExperts
@@ -51,6 +52,12 @@ class UniformNoiseGenerator:
         return inputs * noisy
 
 
+def autocast_softmax(logit: torch.Tensor, dim: int):
+    if logit.dtype != torch.float32:
+        logit = logit.float()
+    return F.softmax(logit, dim=dim)
+
+
 def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
     mep_size = MOE_CONTEXT.max_ep_size
     if num_experts % mep_size == 0 or mep_size % num_experts == 0:
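autocast_softmax upcasts half-precision routing logits before the softmax, so the routers behave the same whether the model runs in fp16 under ZeRO or in fp32. The numeric motivation is my reading of the change rather than something stated in the commit; a quick check of the gap it avoids:

    import torch
    import torch.nn.functional as F

    x = (torch.randn(4, 8) * 10).half()
    p_fp16 = F.softmax(x, dim=-1)             # softmax computed directly in fp16
    p_fp32 = F.softmax(x.float(), dim=-1)     # what autocast_softmax effectively does
    print((p_fp16.float() - p_fp32).abs().max())   # small but non-zero routing error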
@@ -39,7 +39,7 @@ def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
         if t.device.type == 'cpu':
             _cpu_mem_usage += t.numel() * t.element_size()
         elif t.device.type == 'cuda':
-            _cuda_mem_usages += t.numel() * t.element_size()
+            _cuda_mem_usage += t.numel() * t.element_size()
         return _cuda_mem_usage, _cpu_mem_usage
 
     cuda_mem_usage = 0
@@ -88,6 +88,8 @@ class ZeroContextConfig(object):
     """The configuration used to control zero context initialization.
 
     Args:
+        replicated (bool, optional): Whether the param is replicated across data parallel group.
+            Some parameters are not replicated, e.g. parameters in MOE experts.
         shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
         rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
             This will reduce memory usage when initializing model.
@@ -97,8 +99,9 @@ class ZeroContextConfig(object):
             See torchvision resnet18. Defaults to False.
     """
 
-    def __init__(self, shard_param: bool = False, rm_torch_payload_on_the_fly: bool = False):
+    def __init__(self, replicated: bool = True, shard_param: bool = False, rm_torch_payload_on_the_fly: bool = False):
         super().__init__()
+        self.is_replicated: bool = replicated
         self.shard_param: bool = shard_param
         self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly
 
@@ -139,10 +142,15 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
 
-        self.config = ZeroContextConfig(shard_param=shard_param,
+        self.config = ZeroContextConfig(replicated=True,
+                                        shard_param=shard_param,
                                         rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly)
         ZeroContextMgr().current_context = self
 
+    @property
+    def is_replicated(self):
+        return self.config.is_replicated
+
     @property
     def shard_param(self):
         return self.config.shard_param
@@ -183,6 +191,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
             self.model_numel_tensor += param.numel()
 
+            # mark whether the param is replicated
+            param.is_replicated = self.is_replicated
+
             # convert parameters to half
             param_half = half_fn(param)
             param.data = param_half
@@ -224,14 +235,20 @@ class ZeroContextMgr(metaclass=SingletonMeta):
         self.current_context.config = old_config
 
 
-def no_shard_zero_context():
-    return ZeroContextMgr().hijack_context_config(shard_param=False, rm_torch_payload_on_the_fly=False)
+def no_shard_zero_context(is_replicated: bool = True):
+    return ZeroContextMgr().hijack_context_config(replicated=is_replicated,
+                                                  shard_param=False,
+                                                  rm_torch_payload_on_the_fly=False)
 
 
-def no_shard_zero_decrator(init_func):
+def no_shard_zero_decrator(is_replicated: bool = True):
 
-    def _no_shard(*args, **kwargs):
-        with no_shard_zero_context():
-            init_func(*args, **kwargs)
+    def _wrapper(init_func):
 
-    return _no_shard
+        def _no_shard(*args, **kwargs):
+            with no_shard_zero_context(is_replicated):
+                init_func(*args, **kwargs)
+
+        return _no_shard
+
+    return _wrapper
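no_shard_zero_decrator is now a decorator factory, so call sites pass is_replicated explicitly, as the Experts and MoeLayer hunks above do. A hedged sketch of both ways to opt a block of parameters out of sharding; GateLayer and expert_fc are illustrative names, not code from the repo:

    import torch
    import torch.nn as nn
    from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator


    class GateLayer(nn.Module):

        # Decorator form: __init__ runs with sharding disabled and the params are
        # still treated as replicated across the data-parallel group (the default).
        @no_shard_zero_decrator(is_replicated=True)
        def __init__(self):
            super().__init__()
            self.weight = nn.Parameter(torch.empty(16, 16))


    # Context-manager form, here for weights that are neither sharded nor replicated,
    # which is how the MoE expert parameters are declared in this commit.
    with no_shard_zero_context(is_replicated=False):
        expert_fc = nn.Linear(16, 16)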
@@ -10,6 +10,7 @@ from colossalai.core import global_context as gpc
 from colossalai.engine.ophooks import register_ophooks_recursively
 from colossalai.engine.ophooks.zero_hook import ZeroHook
 from colossalai.engine.paramhooks import BaseParamHookMgr
+from colossalai.engine.gradient_handler.utils import bucket_allreduce
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
@@ -67,17 +68,27 @@ class ShardedModelV2(nn.Module):
         self.logger = get_dist_logger()
 
         # We force users to use ZeroInitContext
-        sharded = []
-        unsharded = []
-        for param in module.parameters():
-            assert hasattr(param, 'colo_attr'), 'You must use ZeroInitContext to init your module first.'
-            sharded.append(param.colo_attr.param_is_sharded)
-            unsharded.append(not param.colo_attr.param_is_sharded)
-        assert all(sharded) or all(
-            unsharded), 'Parameters must be all sharded or all unsharded! Parameters are partially sharded now.'
-        self.shard_param = all(sharded)
-        self.module = module
+        for submodule in module.modules():
+            sharded_cnt = 0
+            unshard_cnt = 0
+            for param in submodule.parameters(recurse=False):
+                assert hasattr(param, 'colo_attr'), 'You must use ZeroInitContext to init your module first.'
+                if param.colo_attr.param_is_sharded:
+                    sharded_cnt += 1
+                else:
+                    unshard_cnt += 1
+            assert (not sharded_cnt) or (not unshard_cnt), 'nn.Module can not both have shard param and unshard param'
+            submodule.param_is_sharded = (sharded_cnt > 0)
+
+        self.sharded_params = []
+        self.unshard_params = []
+        for param in module.parameters():
+            if param.colo_attr.param_is_sharded:
+                self.sharded_params.append(param)
+            else:
+                self.unshard_params.append(param)
+
+        self.module = module
         self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
         self.reduce_scatter_process_group = reduce_scatter_process_group or self.process_group
         self.world_size = dist.get_world_size(self.process_group)
@@ -95,8 +106,8 @@ class ShardedModelV2(nn.Module):
 
         # Register hooks
         self._ophook_list = [ZeroHook(self.shard_strategy, self._memstats_collector, self.process_group)]
-        register_ophooks_recursively(self.module, self._ophook_list)
-        self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
+        register_ophooks_recursively(self.module, self._ophook_list, filter_fn=lambda m: not m.param_is_sharded)
+        self.param_hook_mgr = BaseParamHookMgr(self.sharded_params)
         self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
 
         self.fp32_reduce_scatter = fp32_reduce_scatter
@@ -185,7 +196,6 @@ class ShardedModelV2(nn.Module):
 
     def backward_by_grad(self, tensor, grad):
         torch.autograd.backward(tensors=tensor, grad_tensors=grad)
-
         self._post_backward_operations()
         for ophook in self._ophook_list:
             ophook.post_iter()
@@ -224,17 +234,21 @@ class ShardedModelV2(nn.Module):
         # Wait for the non-blocking GPU -> CPU grad transfers to finish.
         torch.cuda.current_stream().synchronize()
         self.reducer.free()
-        # 3. shard tensors not dealed in the zero hook
-        if self.shard_param:
-            tensor_list = []
-            for p in self.module.parameters():
-                if not p.colo_attr.param_is_sharded:
-                    tensor_list.append(p.colo_attr.sharded_data_tensor)
-                    p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
-                    p.colo_attr.remove_torch_payload()
-            self.shard_strategy.shard(tensor_list, self.process_group)
 
-        # 4. move sharded param grad payload to param.grad
+        # all reduce gradients for unsharded parameters
+        reduce_list = [p for p in self.unshard_params if p.is_replicated]
+        bucket_allreduce(reduce_list, self.process_group)
+
+        # 3. shard tensors not dealed in the zero hook
+        tensor_list = []
+        for p in self.sharded_params:
+            if not p.colo_attr.param_is_sharded:
+                tensor_list.append(p.colo_attr.sharded_data_tensor)
+                p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
+                p.colo_attr.remove_torch_payload()
+        self.shard_strategy.shard(tensor_list, self.process_group)
+
+        # 4. set all parameters' grad to None
         for p in self.module.parameters():
             if not p.requires_grad:
                 continue
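With mixed parameters the backward path now splits: sharded params keep the reduce-scatter route, while unsharded params that are marked replicated get a plain bucketed all-reduce of their full gradients (the bucket_allreduce call above, imported from colossalai.engine.gradient_handler.utils). Roughly what such a bucketed all-reduce amounts to, written as an assumption about the helper's behaviour rather than a copy of its code:

    import torch
    import torch.distributed as dist
    from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


    def naive_bucket_allreduce(param_list, group=None):
        # Collect grads by dtype so each bucket holds tensors of the same type.
        buckets = {}
        for param in param_list:
            if param.requires_grad and param.grad is not None:
                buckets.setdefault(param.dtype, []).append(param)

        # Flatten, average across ranks with one all-reduce per bucket, copy back.
        for dtype, bucket in buckets.items():
            grads = [p.grad.data for p in bucket]
            coalesced = _flatten_dense_tensors(grads)
            coalesced /= dist.get_world_size(group=group)
            dist.all_reduce(coalesced, group=group)
            for grad, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                grad.copy_(synced)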
@@ -245,6 +259,16 @@ class ShardedModelV2(nn.Module):
             # We also allows to interleave no-sync pass with sync passes, if desired.
             if not self._require_backward_grad_sync:
                 continue
 
+            # move unsharded param grad to saved_grad
+            if not p.colo_attr.param_is_sharded:
+                if p.colo_attr.offload_grad:
+                    colo_model_data_move_to_cpu(p.grad)
+                if p.colo_attr.saved_grad.is_null():
+                    p.colo_attr.saved_grad.reset_payload(p.grad.data)
+                else:
+                    p.colo_attr.saved_grad.payload.add_(p.grad.data)
+
+            p.grad = None
 
     @torch.no_grad()
@@ -320,16 +344,14 @@ class ShardedModelV2(nn.Module):
             param.colo_attr.saved_grad.trans_state(TensorState.HOLD)
 
     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
-        self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
-                                   self.process_group)
+        self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
         prev_params = {}
-        for p in self.module.parameters():
+        for p in self.sharded_params:
             prev_params[p] = p.data
             p.data = p.colo_attr.sharded_data_tensor.payload
         gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
-        self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
-                                  self.process_group)
-        for p in self.module.parameters():
+        self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
+        for p in self.sharded_params:
             p.data = prev_params[p]
         return gathered_state_dict
 
@@ -22,19 +22,14 @@ class MoeModel(nn.Module):
 
     def __init__(self):
         super().__init__()
-        self.proj1 = nn.Linear(4, 8)
+        self.proj1 = nn.Linear(4, 16)
         expert_cls = nn.Linear
-        expert_args_dict = dict(in_features=8, out_features=8)
-        self.moe = MoeModule(dim_model=8,
-                             num_experts=8,
-                             noisy_policy='Jitter',
-                             use_residual=True,
-                             expert_cls=expert_cls,
-                             **expert_args_dict)
-        self.proj2 = nn.Linear(8, 4)
+        expert_args_dict = dict(in_features=16, out_features=16)
+        self.moe = MoeModule(dim_model=16, num_experts=8, use_residual=True, expert_cls=expert_cls, **expert_args_dict)
+        self.proj2 = nn.Linear(16, 4)
 
     def forward(self, x):
-        x = self.proj(x)
+        x = self.proj1(x)
         x = self.moe(x)
         x = self.proj2(x)
         return x
@@ -75,6 +70,12 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
         else:
             assert param.colo_attr.sharded_data_tensor.is_sharded
 
+        # the parameters in moe experts is not replicated
+        if 'experts' in name:
+            assert not param.is_replicated
+        else:
+            assert param.is_replicated
+
         assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
             f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
@@ -0,0 +1,78 @@
+from functools import partial
+
+import colossalai
+import pytest
+import torch
+import torch.multiprocessing as mp
+from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.utils import free_port
+from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
+from colossalai.zero.sharded_model import ShardedModelV2
+from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
+from colossalai.zero.sharded_model.utils import col_model_deepcopy
+from tests.components_to_test.registry import non_distributed_component_funcs
+from colossalai.engine.gradient_handler import MoeGradientHandler
+from colossalai.context import MOE_CONTEXT
+from colossalai.testing import assert_equal_in_group
+
+from tests.test_zero_data_parallel.common import CONFIG, check_grads_padding, run_fwd_bwd
+from tests.test_moe.test_moe_zero_init import MoeModel
+
+
+@parameterize("enable_autocast", [False])
+@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
+def run_model_test(enable_autocast, shard_strategy_class):
+    shard_strategy = shard_strategy_class()
+
+    get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
+    _, train_dataloader, _, _, criterion = get_components_func()
+
+    rm_torch_payload_on_the_fly = False
+
+    with ZeroInitContext(target_device=torch.cuda.current_device(),
+                         shard_strategy=shard_strategy,
+                         shard_param=True,
+                         rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly):
+        zero_model = MoeModel()
+    zero_model = ShardedModelV2(zero_model, shard_strategy, use_memory_tracer=True)
+
+    # check whether parameters are identical in ddp
+    for name, p in zero_model.named_parameters():
+        if not p.colo_attr.param_is_sharded and p.is_replicated:
+            assert_equal_in_group(p.data)
+
+    model = MoeModel().half()
+    col_model_deepcopy(zero_model, model)
+    model = model.cuda()
+    grad_handler = MoeGradientHandler(model)
+
+    for i, (data, label) in enumerate(train_dataloader):
+        if i > 5:
+            break
+
+        data, label = cast_tensor_to_fp16(data).cuda(), label.cuda()
+        run_fwd_bwd(model, data, label, criterion, enable_autocast)
+        run_fwd_bwd(zero_model, data, label, criterion, enable_autocast)
+        grad_handler.handle_gradient()
+
+        check_grads_padding(model, zero_model, loose=True)
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    MOE_CONTEXT.setup(seed=42)
+    MOE_CONTEXT.reset_loss()
+    run_model_test()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [1, 2])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+def test_moe_zero_model(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_moe_zero_model(world_size=2)
@@ -91,15 +91,19 @@ def check_params(model, zero_model, loose=False):
 
 def check_grads_padding(model, zero_model, loose=False):
     rank = dist.get_rank()
-    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
+    for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
         # zero_grad = zero_p.grad.clone().to(p.device)
-        zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
-        chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
-        if rank >= len(chunks):
-            continue
-        grad = chunks[rank].float()
-        if zero_grad.size(0) > grad.size(0):
-            zero_grad = zero_grad[:grad.size(0)]
+        if zero_p.colo_attr.param_is_sharded:
+            zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
+            chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
+            if rank >= len(chunks):
+                continue
+            grad = chunks[rank].float()
+            if zero_grad.size(0) > grad.size(0):
+                zero_grad = zero_grad[:grad.size(0)]
+        else:
+            grad = p.grad
+            zero_grad = zero_p.colo_attr.saved_grad.payload
         assert grad.dtype == zero_grad.dtype
         assert allclose(grad, zero_grad, loose=loose), f'diff: {grad - zero_grad}'