[autoparallel] accelerate gpt2 training (#2495)

YuliangLiu0306 2023-01-29 11:13:15 +08:00 committed by GitHub
parent a360b9bc44
commit aa0f6686f9
5 changed files with 21 additions and 17 deletions


@@ -387,14 +387,15 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
             # register hook to the parameters
             if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
 
-                def wrapper(param, comm_spec):
+                def wrapper(param, comm_spec, stream):
 
                     def hook_fn(grad):
-                        _all_reduce(grad, comm_spec, async_op=False)
+                        with torch.cuda.stream(stream):
+                            _all_reduce(grad, comm_spec, async_op=True)
 
                     param.register_hook(hook_fn)
 
-                wrapper(param, comm_spec_to_use)
+                wrapper(param, comm_spec_to_use, reduction_stream)
 
     sharded_buffer_dict = {}
     # apply the sharding spec of buffers
@@ -440,14 +441,15 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
             # register hook to the parameters
             if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
 
-                def wrapper(param, comm_spec):
+                def wrapper(param, comm_spec, stream):
 
                     def hook_fn(grad):
-                        _all_reduce(grad, comm_spec, async_op=False)
+                        with torch.cuda.stream(stream):
+                            _all_reduce(grad, comm_spec, async_op=True)
 
                     param.register_hook(hook_fn)
 
-                wrapper(target, comm_spec_to_use)
+                wrapper(target, comm_spec_to_use, reduction_stream)
 
     return gm


@@ -483,4 +483,6 @@ class MatMulHandler(NodeHandler):
                 raise TypeError(
                     f"Found unexpected output type {type(output)} from the recover method of BmmTransform")
         strategies = recovered_stragies
+        for index, strategies in enumerate(strategies):
+            strategies.name = f"{strategies.name}_{index}"
         return strategies
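
The added loop gives every strategy coming out of the BmmTransform recover step a unique, index-suffixed name, so strategies that end up with identical names after recovery can still be told apart later. A tiny illustration of the renaming effect (the names are made up):

names = ["bmm_strategy", "bmm_strategy"]
renamed = [f"{name}_{index}" for index, name in enumerate(names)]
print(renamed)    # ['bmm_strategy_0', 'bmm_strategy_1']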


@@ -247,12 +247,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         strategies.append(self.split_rhs_space_both_contract(1, 0))
 
         # RR= RS x SR
-        strategies.append(self.recompute_split_both_contract(0))
-        strategies.append(self.recompute_split_both_contract(1))
+        # strategies.append(self.recompute_split_both_contract(0))
+        # strategies.append(self.recompute_split_both_contract(1))
 
-        # RS = RR x RS
-        strategies.append(self.split_rhs_space_only(0))
-        strategies.append(self.split_rhs_space_only(1))
+        # # RS = RR x RS
+        # strategies.append(self.split_rhs_space_only(0))
+        # strategies.append(self.split_rhs_space_only(1))
 
         # S01R = S01R x RR
         strategies.append(self.split_lhs_1st_dim_1d(0, 1))
@@ -263,8 +263,8 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         # RS01 = RR x RS01
         strategies.append(self.split_rhs_2nd_dim_1d(0, 1))
 
-        # RR = RR x RR
-        strategies.append(self.non_split())
+        # # RR = RR x RR
+        # strategies.append(self.non_split())
 
         return strategies


@@ -98,7 +98,7 @@ class DeviceMesh:
         return DeviceMesh(self.physical_mesh_id,
                           tuple(flatten_mesh_shape),
                           mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
-                          mesh_beta=[min(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
+                          mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
                           init_process_group=self.init_process_group,
                           need_flatten=False)
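
When the mesh is flattened into a single axis, the per-axis latency and bandwidth coefficients have to be collapsed to one value each; taking max(self.mesh_beta), like the existing max(self.mesh_alpha), makes the flattened mesh assume the slowest link rather than the fastest. That is the conservative choice under the usual alpha-beta model, where transferring n bytes along an axis costs roughly alpha + n * beta. A toy illustration with made-up coefficients:

mesh_alpha = [1e-5, 1e-4]    # per-axis latency in seconds (made up)
mesh_beta = [1e-10, 5e-9]    # per-axis seconds per byte (made up)

flat_alpha = max(mesh_alpha)
flat_beta = max(mesh_beta)   # was min(mesh_beta) before this change

def transfer_cost(nbytes, alpha, beta):
    # simplified alpha-beta estimate: one latency term plus a bandwidth term
    return alpha + nbytes * beta

print(transfer_cost(64 * 2**20, flat_alpha, flat_beta))   # 64 MiB over the flattened mesh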


@@ -463,7 +463,7 @@ class CommSpec:
         if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
             forward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
             # give a tiny cost to shard
-            backward_communication_cost = 10
+            backward_communication_cost = 100
 
         if self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
             forward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
@@ -481,13 +481,13 @@ class CommSpec:
         if self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
             # give a tiny cost to shard
-            forward_communication_cost = 10
+            forward_communication_cost = 100
             backward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
 
         if self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD:
             # no need for axis because all devices are used in mix_gather
             forward_communication_cost = self.device_mesh.mix_gather_cost(comm_size)
-            backward_communication_cost = 10
+            backward_communication_cost = 100
 
         if self.forward_only:
             cost_dict["forward"] = forward_communication_cost
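
Raising the placeholder shard cost from 10 to 100 keeps split communication from looking essentially free next to the modelled all-gather, all-to-all, and mix-gather costs, which nudges the solver away from strategies that bounce tensors between sharded and replicated layouts. A rough sketch of how these per-direction numbers could feed a cost dict; the structure below is assumed for illustration and is not the exact CommSpec API:

def comm_cost(forward_cost: float, backward_cost: float, forward_only: bool = False) -> dict:
    # combine per-direction estimates into the dict consumed by the solver
    cost_dict = {"forward": forward_cost}
    cost_dict["backward"] = 0.0 if forward_only else backward_cost
    cost_dict["total"] = cost_dict["forward"] + cost_dict["backward"]
    return cost_dict

# a gather-forward / split-backward edge with the new shard placeholder
print(comm_cost(forward_cost=2400.0, backward_cost=100.0))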