mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] fix parameters sharding bug (#2716)
parent 2045d45ab7
commit 5b24987fa7
@@ -426,8 +426,9 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
             # we could use .data here, because all the operations just happen before the real training
             # loop, so we don't need to track these operations in the autograd graph.
             param = torch.nn.Parameter(
                 shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
                                                                          target_sharding_spec).detach().clone())
+        return param
 
     for node in nodes:
         if node.op == 'call_module':
@@ -438,7 +439,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
             setattr(target_module, 'processed', True)
             for name, param in target_module.named_parameters():
                 target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
-                _shard_param(param, target_sharding_spec)
+                param = _shard_param(param, target_sharding_spec)
 
                 setattr(target_module, name, param)
                 _add_hook_for_grad_communication(node, param)
@@ -469,7 +470,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
             target = getattr(target_module, atoms[-1])
 
             target_sharding_spec = node.sharding_spec
-            _shard_param(target, target_sharding_spec)
+            target = _shard_param(target, target_sharding_spec)
 
             assert hasattr(target_module, atoms[-1])
             setattr(target_module, atoms[-1], target)
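
For context, the change works because rebinding a name inside a Python function does not affect the caller's reference: _shard_param wraps the sharded tensor in a new torch.nn.Parameter, so the caller has to take the returned object and re-register it on the module. Below is a minimal, self-contained sketch of that pattern; the shard_param helper and its "keep the first half of dim 0" shard are hypothetical stand-ins, not the ColossalAI implementation.

# Minimal sketch (not the ColossalAI code): why the caller must reassign the value
# returned by the sharding helper.
import torch
import torch.nn as nn


def shard_param(param: nn.Parameter) -> nn.Parameter:
    # Pretend "sharding": keep only the first half along dim 0.
    # Wrapping the result in a fresh Parameter creates a NEW object;
    # rebinding the local name 'param' is invisible to the caller, so we return it.
    sharded = param.data[:param.shape[0] // 2].detach().clone()
    param = nn.Parameter(sharded)
    return param


linear = nn.Linear(4, 4)
old_weight = linear.weight

# Buggy pattern (before the fix): the return value is dropped,
# so the module keeps its original, unsharded parameter.
shard_param(linear.weight)
assert linear.weight is old_weight

# Fixed pattern (what the patch does): reassign and re-register the parameter.
new_weight = shard_param(linear.weight)
setattr(linear, 'weight', new_weight)
assert linear.weight.shape[0] == 2

With the reassignment in place, the setattr(target_module, name, param) call in the patched pass registers the sharded parameter on the module instead of the stale, unsharded one.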