mirror of https://github.com/hpcaitech/ColossalAI
fix module utils bug (#1066)
parent
a00644079e
commit
6754f1b77f
|
@ -14,7 +14,7 @@ def register_colo_module(module_type: type, colo_module: ColoModule):
|
||||||
def is_colo_module(module: torch.nn.Module):
|
def is_colo_module(module: torch.nn.Module):
|
||||||
global _COLOSSAL_MODULES
|
global _COLOSSAL_MODULES
|
||||||
for module_type in _COLOSSAL_MODULES.keys():
|
for module_type in _COLOSSAL_MODULES.keys():
|
||||||
if isinstance(type(module), module_type):
|
if isinstance(module, module_type):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ def get_colo_module(module: torch.nn.Module):
|
||||||
global _COLOSSAL_MODULES
|
global _COLOSSAL_MODULES
|
||||||
if is_colo_module(module):
|
if is_colo_module(module):
|
||||||
for module_type, colo_module in _COLOSSAL_MODULES.items():
|
for module_type, colo_module in _COLOSSAL_MODULES.items():
|
||||||
if isinstance(type(module), module_type):
|
if isinstance(module, module_type):
|
||||||
return colo_module
|
return colo_module
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
Loading…
Reference in New Issue