From 5b24987fa75adee654aac0b02c8805fb8042cc05 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Wed, 15 Feb 2023 12:25:50 +0800
Subject: [PATCH] [autoparallel] fix parameters sharding bug (#2716)

---
 .../auto_parallel/passes/runtime_preparation_pass.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
index bb419be35..e63bfdfe7 100644
--- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -426,8 +426,9 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
             # we could use .data here, because all the operations just happen before the real training
             # loop, so we don't need to track these operations in the autograd graph.
             param = torch.nn.Parameter(
-                shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
-                                                                         target_sharding_spec).detach().clone())
+                shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
+                                                                         target_sharding_spec).detach().clone())
+        return param
 
     for node in nodes:
         if node.op == 'call_module':
@@ -438,7 +439,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
                 setattr(target_module, 'processed', True)
             for name, param in target_module.named_parameters():
                 target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
-                _shard_param(param, target_sharding_spec)
+                param = _shard_param(param, target_sharding_spec)
                 setattr(target_module, name, param)
                 _add_hook_for_grad_communication(node, param)
 
@@ -469,7 +470,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
                 target = getattr(target_module, atoms[-1])
                 target_sharding_spec = node.sharding_spec
 
-                _shard_param(target, target_sharding_spec)
+                target = _shard_param(target, target_sharding_spec)
 
                 assert hasattr(target_module, atoms[-1])
                 setattr(target_module, atoms[-1], target)
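
Why the fix works: the helper constructs a fresh torch.nn.Parameter and rebinds
the local name param, so without a return statement the caller's reference (and
the module attribute later passed to setattr) still points at the original,
unsharded parameter. The patch returns the new Parameter and rebinds it at both
call sites before setattr. Below is a minimal sketch of this rebinding pitfall;
it is not the Colossal-AI code, and it substitutes a toy chunk-based shard for
ShapeConsistencyManager, with hypothetical helper names.

import torch

# Toy stand-in for the real sharding call; names are illustrative only.
def _shard_param_no_return(param):
    # Rebinds only the local name 'param'; the caller's reference is untouched.
    param = torch.nn.Parameter(param.data.chunk(2, dim=0)[0].detach().clone())

def _shard_param_fixed(param):
    # Returns the new Parameter so the caller can rebind it and call setattr.
    return torch.nn.Parameter(param.data.chunk(2, dim=0)[0].detach().clone())

linear = torch.nn.Linear(4, 4)

_shard_param_no_return(linear.weight)
print(linear.weight.shape)   # torch.Size([4, 4]): still the full, unsharded weight

weight = _shard_param_fixed(linear.weight)
setattr(linear, 'weight', weight)   # mirrors the patched call sites
print(linear.weight.shape)   # torch.Size([2, 4]): sharded along dim 0

Returning the value is the only option here: torch.nn.Parameter(...) creates a
brand-new tensor object of a different shape, so no in-place mutation of the old
parameter could make existing references see the sharded storage.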