Browse Source

[shardformer] fix base policy (#4229)

pull/4445/head
Hongxin Liu 1 year ago
parent
commit
7e4de520e1
  1. 10
      colossalai/shardformer/policies/base_policy.py

10
colossalai/shardformer/policies/base_policy.py

@ -156,7 +156,10 @@ class Policy(ABC):
# append or create a new description
if target_key in policy:
policy[target_key].sub_module_replacement.extend(description)
if policy[target_key].sub_module_replacement is None:
policy[target_key].sub_module_replacement = description
else:
policy[target_key].sub_module_replacement.extend(description)
else:
policy[target_key] = ModulePolicyDescription(sub_module_replacement=description)
@ -174,7 +177,10 @@ class Policy(ABC):
target_key (Union[str, nn.Module]): the key of the policy to be updated
"""
if target_key in policy:
policy[target_key].method_replacement.update(description)
if policy[target_key].method_replacement is None:
policy[target_key].method_replacement = description
else:
policy[target_key].method_replacement.update(description)
else:
policy[target_key] = ModulePolicyDescription(method_replacement=description)

Loading…
Cancel
Save