diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index d61ea5373..5e06eb646 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -77,6 +77,7 @@ class ShardedModelV2(nn.Module):
                  tensor_placement_policy: str = 'cuda',
                  gradient_predivide_factor: Optional[float] = 1.0,
                  reuse_fp16_shard: bool = False):
+        assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
         super().__init__()
         self.logger = get_dist_logger()
 
diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index 610665422..95ab70708 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -87,6 +87,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  mp_process_group: Optional[ProcessGroup] = None,
                  verbose: bool = False) -> None:
         assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
+        assert not isinstance(optimizer, ShardedOptimizerV2), 'Nested ShardedOptimizerV2 is not supported.'
         super().__init__(optimizer)
         self.shard_strategy = sharded_model.shard_strategy
 
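
For context, a minimal sketch of the misuse these new guards reject. The toy `Net` module and the `launch_from_torch`/`ZeroInitContext` setup below are illustrative of the usual ColossalAI ZeRO flow at this version, not part of the patch; the nested-wrap check itself runs before any other constructor logic, so it fires immediately.

```python
import torch
import torch.nn as nn

import colossalai
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2

# Illustrative setup; assumes the script is launched with torchrun so
# colossalai can pick up the distributed environment.
colossalai.launch_from_torch(config={})


class Net(nn.Module):
    """Toy module used only to demonstrate the new assertion."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, x):
        return self.fc(x)


shard_strategy = TensorShardStrategy()
# ShardedModelV2 expects the module to be built under ZeroInitContext.
with ZeroInitContext(target_device=torch.device('cuda'),
                     shard_strategy=shard_strategy,
                     shard_param=True):
    net = Net()

model = ShardedModelV2(net, shard_strategy)

# Before this patch, double-wrapping silently produced a broken nested
# wrapper; with the new guard it fails fast with a clear message.
try:
    ShardedModelV2(model, shard_strategy)
except AssertionError as e:
    print(e)  # Nested ShardedModelV2 is not supported.
```

The `ShardedOptimizerV2` guard behaves the same way: passing an already-wrapped optimizer back into `ShardedOptimizerV2` now raises an `AssertionError` instead of producing a nested wrapper.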