mirror of https://github.com/hpcaitech/ColossalAI
[moe] add checkpoint for moe zero test (#729)
parent 6f7d1362c9
commit b9b469ea50
@@ -5,6 +5,7 @@ import pytest
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
+from colossalai.nn import CheckpointModule
 from colossalai.logging import get_dist_logger
 from colossalai.testing import parameterize
 from colossalai.utils import free_port
@@ -18,10 +19,10 @@ from colossalai.utils import get_current_device
 from tests.test_zero_data_parallel.common import CONFIG


-class MoeModel(nn.Module):
+class MoeModel(CheckpointModule):

-    def __init__(self):
-        super().__init__()
+    def __init__(self, checkpoint: bool = False):
+        super().__init__(checkpoint)
         self.proj1 = nn.Linear(4, 16)
         expert_cls = nn.Linear
         expert_args_dict = dict(in_features=16, out_features=16)
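Context on the hunk above: giving MoeModel a `checkpoint` flag and deriving it from CheckpointModule lets the zero tests exercise activation checkpointing. As a rough, self-contained illustration of the idea only, using plain torch.utils.checkpoint rather than ColossalAI's CheckpointModule (whose internals are outside this diff); the ToyCheckpointedModel name and its layers are hypothetical:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyCheckpointedModel(nn.Module):
    # Hypothetical stand-in: shows what a `checkpoint` constructor flag
    # typically toggles, i.e. recomputing activations during backward
    # instead of keeping them in memory.

    def __init__(self, checkpoint_flag: bool = False):
        super().__init__()
        self.checkpoint_flag = checkpoint_flag
        self.proj1 = nn.Linear(4, 16)
        self.proj2 = nn.Linear(16, 4)

    def _inner_forward(self, x):
        return self.proj2(torch.relu(self.proj1(x)))

    def forward(self, x):
        if self.checkpoint_flag and self.training:
            # Activations of _inner_forward are recomputed in backward.
            return checkpoint(self._inner_forward, x)
        return self._inner_forward(x)


if __name__ == '__main__':
    model = ToyCheckpointedModel(checkpoint_flag=True).train()
    out = model(torch.randn(2, 4, requires_grad=True))
    out.sum().backward()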
@@ -52,7 +53,7 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
                          shard_strategy=shard_strategy_class(),
                          shard_param=True,
                          model_numel_tensor=model_numel_tensor):
-        model = MoeModel()
+        model = MoeModel(checkpoint=True)

     for name, param in model.named_parameters():
         assert hasattr(param, 'colo_attr')

@@ -31,7 +31,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
     with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
                          shard_strategy=shard_strategy,
                          shard_param=True):
-        zero_model = MoeModel()
+        zero_model = MoeModel(checkpoint=True)
     zero_model = ShardedModelV2(zero_model, shard_strategy, use_memory_tracer=True)

     # check whether parameters are identical in ddp
@@ -39,7 +39,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
         if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
             assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload)

-    model = MoeModel().half()
+    model = MoeModel(checkpoint=True).half()
     col_model_deepcopy(zero_model, model)
     model = model.cuda()
     grad_handler = MoeGradientHandler(model)

@@ -65,7 +65,7 @@ def _run_test_sharded_optim_v2(cpu_offload,
     with ZeroInitContext(target_device=torch.device('cpu') if cpu_offload else get_current_device(),
                          shard_strategy=shard_strategy,
                          shard_param=True):
-        zero_model = MoeModel()
+        zero_model = MoeModel(checkpoint=True)

     zero_model = ShardedModelV2(zero_model,
                                 shard_strategy,
@@ -78,7 +78,7 @@ def _run_test_sharded_optim_v2(cpu_offload,
         if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
             assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload.to(get_current_device()))

-    model = MoeModel().half()
+    model = MoeModel(checkpoint=True).half()
     col_model_deepcopy(zero_model, model)
     model = model.cuda().float()
@@ -129,4 +129,4 @@ def test_moe_zero_optim(world_size):


 if __name__ == '__main__':
-    test_moe_zero_optim(world_size=2)
+    test_moe_zero_optim(world_size=4)
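The final hunk bumps the standalone run of the optimizer test from world_size=2 to world_size=4. The test driver itself is not part of this diff, but tests like this are typically launched by spawning one process per rank with torch.multiprocessing (imported as mp in the first hunk) on a free TCP port (the free_port import suggests that pattern). A minimal sketch of that launch shape, with hypothetical worker and port values, not ColossalAI's actual harness:

import torch.multiprocessing as mp


def run_dist(rank: int, world_size: int, port: int):
    # A real worker would initialize the distributed context with this rank,
    # world_size and port, then build MoeModel(checkpoint=True) under
    # ZeroInitContext and run the optimizer checks.
    print(f'rank {rank} of {world_size}, rendezvous port {port}')


def launch_sketch(world_size: int = 4, port: int = 29500):
    # The port value here is a placeholder; the test imports free_port()
    # from colossalai.utils, presumably to pick one dynamically.
    mp.spawn(run_dist, args=(world_size, port), nprocs=world_size)


if __name__ == '__main__':
    launch_sketch(world_size=4)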