[hotfix] prevent nested ZeRO (#1140)

pull/1143/head v0.1.7
ver217 2022-06-21 11:33:53 +08:00 committed by GitHub
parent 15aab1476e
commit 6690a61b4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 0 deletions

View File

@@ -77,6 +77,7 @@ class ShardedModelV2(nn.Module):
tensor_placement_policy: str = 'cuda',
gradient_predivide_factor: Optional[float] = 1.0,
reuse_fp16_shard: bool = False):
assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
super().__init__()
self.logger = get_dist_logger()

View File

@@ -87,6 +87,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
mp_process_group: Optional[ProcessGroup] = None,
verbose: bool = False) -> None:
assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
assert not isinstance(optimizer, ShardedOptimizerV2), 'Nested ShardedOptimizerV2 is not supported.'
super().__init__(optimizer)
self.shard_strategy = sharded_model.shard_strategy