mirror of https://github.com/hpcaitech/ColossalAI
parent
e27645376d
commit
dd0420909f
|
@ -79,8 +79,8 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
|
|||
check_colo_module(submodule, recursive=True)
|
||||
|
||||
|
||||
def init_colo_module(module: torch.nn.Module, parallel_action: ComputeSpec, recursive=True, mode='default'):
|
||||
compute_pattern = parallel_action.compute_pattern
|
||||
def init_colo_module(module: torch.nn.Module, compute_spec: ComputeSpec, recursive=True, mode='default'):
|
||||
compute_pattern = compute_spec.compute_pattern
|
||||
if is_colo_module(module):
|
||||
# for each param
|
||||
# set DistSpec and ComputeSpec
|
||||
|
@ -96,7 +96,7 @@ def init_colo_module(module: torch.nn.Module, parallel_action: ComputeSpec, recu
|
|||
continue
|
||||
param = module.get_parameter(param_name)
|
||||
if isinstance(param, ColoParameter):
|
||||
spec = TensorSpec(dist_spec, parallel_action)
|
||||
spec = TensorSpec(dist_spec, compute_spec)
|
||||
param.set_tensor_spec(spec)
|
||||
for mod in param.shared_param_modules:
|
||||
modules_update_param.add(mod)
|
||||
|
@ -104,4 +104,4 @@ def init_colo_module(module: torch.nn.Module, parallel_action: ComputeSpec, recu
|
|||
check_colo_module(mod, recursive=False)
|
||||
if recursive == True:
|
||||
for submodule in module.children():
|
||||
init_colo_module(submodule, parallel_action, recursive=True, mode=mode)
|
||||
init_colo_module(submodule, compute_spec, recursive=True, mode=mode)
|
||||
|
|
|
@ -9,7 +9,7 @@ class TensorSpec(object):
|
|||
The specification of the ColoTensor.
|
||||
Args:
|
||||
dist_spec (_DistSpec): describing the layout among processes.
|
||||
parallel_action (Optional[ComputeSpec], optional): actions conducted on the tensor after initialization if it's a model data tensor.
|
||||
compute_spec (Optional[ComputeSpec], optional): actions conducted on the tensor after initialization if it's a model data tensor.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
|
|
Loading…
Reference in New Issue