[Tensor] rename parallel_action (#1174)

* rename parallel_action

* polish
pull/1172/head^2
Ziyue Jiang 2 years ago committed by GitHub
parent e27645376d
commit dd0420909f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -79,8 +79,8 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
check_colo_module(submodule, recursive=True)
def init_colo_module(module: torch.nn.Module, parallel_action: ComputeSpec, recursive=True, mode='default'):
compute_pattern = parallel_action.compute_pattern
def init_colo_module(module: torch.nn.Module, compute_spec: ComputeSpec, recursive=True, mode='default'):
compute_pattern = compute_spec.compute_pattern
if is_colo_module(module):
# for each param
# set DistSpec and ComputeSpec
@ -96,7 +96,7 @@ def init_colo_module(module: torch.nn.Module, parallel_action: ComputeSpec, recu
continue
param = module.get_parameter(param_name)
if isinstance(param, ColoParameter):
spec = TensorSpec(dist_spec, parallel_action)
spec = TensorSpec(dist_spec, compute_spec)
param.set_tensor_spec(spec)
for mod in param.shared_param_modules:
modules_update_param.add(mod)
@ -104,4 +104,4 @@ def init_colo_module(module: torch.nn.Module, parallel_action: ComputeSpec, recu
check_colo_module(mod, recursive=False)
if recursive == True:
for submodule in module.children():
init_colo_module(submodule, parallel_action, recursive=True, mode=mode)
init_colo_module(submodule, compute_spec, recursive=True, mode=mode)

@ -9,7 +9,7 @@ class TensorSpec(object):
The specification of the ColoTensor.
Args:
dist_spec (_DistSpec): descriping the layout among processes.
parallel_action (Optional[ComputeSpec], optional): actions conducted on the tensor after initialization if it's a model data tensor.
compute_spec (Optional[ComputeSpec], optional): actions conducted on the tensor after initialization if it's a model data tensor.
Defaults to None.
"""

Loading…
Cancel
Save