[moe] add checkpoint for moe zero test (#729)

HELSON 2022-04-12 12:11:54 +08:00 committed by GitHub
parent 6f7d1362c9
commit b9b469ea50
3 changed files with 10 additions and 9 deletions


@@ -5,6 +5,7 @@ import pytest
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
+from colossalai.nn import CheckpointModule
 from colossalai.logging import get_dist_logger
 from colossalai.testing import parameterize
 from colossalai.utils import free_port
@@ -18,10 +19,10 @@ from colossalai.utils import get_current_device
 from tests.test_zero_data_parallel.common import CONFIG
 
 
-class MoeModel(nn.Module):
+class MoeModel(CheckpointModule):
 
-    def __init__(self):
-        super().__init__()
+    def __init__(self, checkpoint: bool = False):
+        super().__init__(checkpoint)
         self.proj1 = nn.Linear(4, 16)
         expert_cls = nn.Linear
         expert_args_dict = dict(in_features=16, out_features=16)
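The core change: `MoeModel` now derives from ColossalAI's `CheckpointModule`, so each test can toggle activation checkpointing through a constructor flag. Below is a minimal sketch of that pattern using plain `torch.utils.checkpoint` as a stand-in; `CheckpointModule`'s real implementation lives in `colossalai.nn` and may differ in detail.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as checkpoint_fn


class CheckpointedBlock(nn.Module):
    """Stand-in for the CheckpointModule pattern: a constructor flag
    decides whether forward activations are recomputed in backward."""

    def __init__(self, checkpoint: bool = False):
        super().__init__()
        self.checkpoint = checkpoint
        self.proj = nn.Linear(4, 16)

    def _forward(self, x):
        return torch.relu(self.proj(x))

    def forward(self, x):
        if self.checkpoint and self.training:
            # Trade compute for memory: drop the activations now and
            # recompute them during the backward pass.
            return checkpoint_fn(self._forward, x)
        return self._forward(x)
```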
@@ -52,7 +53,7 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
                          shard_strategy=shard_strategy_class(),
                          shard_param=True,
                          model_numel_tensor=model_numel_tensor):
-        model = MoeModel()
+        model = MoeModel(checkpoint=True)
 
     for name, param in model.named_parameters():
         assert hasattr(param, 'colo_attr')
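For context, the hunk above builds the checkpointed model inside `ZeroInitContext`, which intercepts parameter creation, shards each tensor, and tags it with `colo_attr`. A hedged usage sketch follows, assuming the `MoeModel` defined above; the import paths match the ColossalAI version this commit targets and may have moved since, and `TensorShardStrategy` is one of the strategies the test parameterizes over.

```python
import torch

from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy

# Parameters are sharded as they are created; the context never
# materializes the full model on a single device.
with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True):
    model = MoeModel(checkpoint=True)

# Every parameter should now carry ZeRO sharding metadata.
for name, param in model.named_parameters():
    assert hasattr(param, 'colo_attr'), f'{name} lacks colo_attr'
```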


@@ -31,7 +31,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
     with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
                          shard_strategy=shard_strategy,
                          shard_param=True):
-        zero_model = MoeModel()
+        zero_model = MoeModel(checkpoint=True)
     zero_model = ShardedModelV2(zero_model, shard_strategy, use_memory_tracer=True)
 
     # check whether parameters are identical in ddp
@@ -39,7 +39,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
         if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
             assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload)
 
-    model = MoeModel().half()
+    model = MoeModel(checkpoint=True).half()
     col_model_deepcopy(zero_model, model)
     model = model.cuda()
     grad_handler = MoeGradientHandler(model)
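The hunks above rebuild an unsharded fp16 twin of the ZeRO model via `col_model_deepcopy` and attach `MoeGradientHandler`, which all-reduces the gradients of replicated experts after backward. A hypothetical smoke check of that setup is sketched below; the batch shape is an assumption (width 4 matches `MoeModel`'s first projection), as is calling the model on a bare tensor, and `handle_gradient()` is the entry point shared by ColossalAI's gradient handlers.

```python
import torch

# Assumed batch: width 4 matches MoeModel's proj1 in_features.
data = torch.randn(8, 4, device='cuda', dtype=torch.half)

out = model(data)             # fp16 reference model built above
out.float().sum().backward()  # cast up before reducing for stability

# Replicated expert gradients are not synchronized automatically, so
# the MoE-aware handler must all-reduce them explicitly.
grad_handler.handle_gradient()
```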


@@ -65,7 +65,7 @@ def _run_test_sharded_optim_v2(cpu_offload,
     with ZeroInitContext(target_device=torch.device('cpu') if cpu_offload else get_current_device(),
                          shard_strategy=shard_strategy,
                          shard_param=True):
-        zero_model = MoeModel()
+        zero_model = MoeModel(checkpoint=True)
     zero_model = ShardedModelV2(zero_model,
                                 shard_strategy,
@@ -78,7 +78,7 @@ def _run_test_sharded_optim_v2(cpu_offload,
         if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
             assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload.to(get_current_device()))
 
-    model = MoeModel().half()
+    model = MoeModel(checkpoint=True).half()
     col_model_deepcopy(zero_model, model)
     model = model.cuda().float()
@@ -129,4 +129,4 @@ def test_moe_zero_optim(world_size):
 
 
 if __name__ == '__main__':
-    test_moe_zero_optim(world_size=2)
+    test_moe_zero_optim(world_size=4)
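All three test files follow the repository's usual multi-process harness: `mp.spawn` starts `world_size` workers, and each one calls `colossalai.launch` on a shared free port before running the checks. A minimal sketch of that harness, with the worker body elided; the `run_dist` name and the empty config dict are assumptions, not shown in this diff.

```python
from functools import partial

import torch.multiprocessing as mp

import colossalai
from colossalai.utils import free_port


def run_dist(rank, world_size, port):
    # Each spawned worker joins the process group before running checks.
    colossalai.launch(config=dict(), rank=rank, world_size=world_size,
                      host='localhost', port=port, backend='nccl')
    # ... run the MoE ZeRO checks from this file ...


def test_moe_zero_optim(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
```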