From 7e4de520e16af2f555fee760f158ef9e55d80b12 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Fri, 14 Jul 2023 09:51:53 +0800 Subject: [PATCH] [shardformer] fix base policy (#4229) --- colossalai/shardformer/policies/base_policy.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 68fde0115..69493bfb6 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -156,7 +156,10 @@ class Policy(ABC): # append or create a new description if target_key in policy: - policy[target_key].sub_module_replacement.extend(description) + if policy[target_key].sub_module_replacement is None: + policy[target_key].sub_module_replacement = description + else: + policy[target_key].sub_module_replacement.extend(description) else: policy[target_key] = ModulePolicyDescription(sub_module_replacement=description) @@ -174,7 +177,10 @@ class Policy(ABC): target_key (Union[str, nn.Module]): the key of the policy to be updated """ if target_key in policy: - policy[target_key].method_replacement.update(description) + if policy[target_key].method_replacement is None: + policy[target_key].method_replacement = description + else: + policy[target_key].method_replacement.update(description) else: policy[target_key] = ModulePolicyDescription(method_replacement=description)