|
|
|
@ -79,8 +79,8 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
|
|
|
|
|
check_colo_module(submodule, recursive=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_colo_module(module: torch.nn.Module, parallel_action: ComputeSpec, recursive=True, mode='default'):
|
|
|
|
|
compute_pattern = parallel_action.compute_pattern
|
|
|
|
|
def init_colo_module(module: torch.nn.Module, compute_spec: ComputeSpec, recursive=True, mode='default'):
|
|
|
|
|
compute_pattern = compute_spec.compute_pattern
|
|
|
|
|
if is_colo_module(module):
|
|
|
|
|
# for each param
|
|
|
|
|
# set DistSpec and ComputeSpec
|
|
|
|
@ -96,7 +96,7 @@ def init_colo_module(module: torch.nn.Module, parallel_action: ComputeSpec, recu
|
|
|
|
|
continue
|
|
|
|
|
param = module.get_parameter(param_name)
|
|
|
|
|
if isinstance(param, ColoParameter):
|
|
|
|
|
spec = TensorSpec(dist_spec, parallel_action)
|
|
|
|
|
spec = TensorSpec(dist_spec, compute_spec)
|
|
|
|
|
param.set_tensor_spec(spec)
|
|
|
|
|
for mod in param.shared_param_modules:
|
|
|
|
|
modules_update_param.add(mod)
|
|
|
|
@ -104,4 +104,4 @@ def init_colo_module(module: torch.nn.Module, parallel_action: ComputeSpec, recu
|
|
|
|
|
check_colo_module(mod, recursive=False)
|
|
|
|
|
if recursive == True:
|
|
|
|
|
for submodule in module.children():
|
|
|
|
|
init_colo_module(submodule, parallel_action, recursive=True, mode=mode)
|
|
|
|
|
init_colo_module(submodule, compute_spec, recursive=True, mode=mode)
|
|
|
|
|