mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] fix base policy (#4229)
parent
208ac8f2ba
commit
7e4de520e1
|
@ -156,7 +156,10 @@ class Policy(ABC):
|
||||||
|
|
||||||
# append or create a new description
|
# append or create a new description
|
||||||
if target_key in policy:
|
if target_key in policy:
|
||||||
policy[target_key].sub_module_replacement.extend(description)
|
if policy[target_key].sub_module_replacement is None:
|
||||||
|
policy[target_key].sub_module_replacement = description
|
||||||
|
else:
|
||||||
|
policy[target_key].sub_module_replacement.extend(description)
|
||||||
else:
|
else:
|
||||||
policy[target_key] = ModulePolicyDescription(sub_module_replacement=description)
|
policy[target_key] = ModulePolicyDescription(sub_module_replacement=description)
|
||||||
|
|
||||||
|
@ -174,7 +177,10 @@ class Policy(ABC):
|
||||||
target_key (Union[str, nn.Module]): the key of the policy to be updated
|
target_key (Union[str, nn.Module]): the key of the policy to be updated
|
||||||
"""
|
"""
|
||||||
if target_key in policy:
|
if target_key in policy:
|
||||||
policy[target_key].method_replacement.update(description)
|
if policy[target_key].method_replacement is None:
|
||||||
|
policy[target_key].method_replacement = description
|
||||||
|
else:
|
||||||
|
policy[target_key].method_replacement.update(description)
|
||||||
else:
|
else:
|
||||||
policy[target_key] = ModulePolicyDescription(method_replacement=description)
|
policy[target_key] = ModulePolicyDescription(method_replacement=description)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue