mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] accelerate gpt2 training (#2495)
parent a360b9bc44
commit aa0f6686f9
@@ -387,14 +387,15 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                     # register hook to the parameters
                     if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:

-                        def wrapper(param, comm_spec):
+                        def wrapper(param, comm_spec, stream):

                             def hook_fn(grad):
-                                _all_reduce(grad, comm_spec, async_op=False)
+                                with torch.cuda.stream(stream):
+                                    _all_reduce(grad, comm_spec, async_op=True)

                             param.register_hook(hook_fn)

-                        wrapper(param, comm_spec_to_use)
+                        wrapper(param, comm_spec_to_use, reduction_stream)

             sharded_buffer_dict = {}
             # apply the sharding spec of buffers
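Both this hunk and the next apply the same change: the gradient hook registered on each sharded parameter now launches its all-reduce asynchronously on a dedicated CUDA stream instead of blocking the default stream, so gradient communication can overlap with the rest of the backward pass. A minimal standalone sketch of the pattern, using torch.distributed.all_reduce in place of ColossalAI's internal _all_reduce helper (the function name and the usage lines are illustrative, not part of the commit):

import torch
import torch.distributed as dist

def register_async_grad_allreduce(param, process_group, stream):
    """Register a backward hook that all-reduces the gradient on a side stream."""

    def hook_fn(grad):
        # Make the gradient produced on the default stream visible to the
        # reduction stream before the collective reads it.
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream):
            dist.all_reduce(grad, group=process_group, async_op=True)
        return grad

    param.register_hook(hook_fn)

# Usage (assumes torch.distributed is already initialized with a CUDA backend):
# reduction_stream = torch.cuda.Stream()
# for p in model.parameters():
#     register_async_grad_allreduce(p, dist.group.WORLD, reduction_stream)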
@@ -440,14 +441,15 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                 # register hook to the parameters
                 if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:

-                    def wrapper(param, comm_spec):
+                    def wrapper(param, comm_spec, stream):

                         def hook_fn(grad):
-                            _all_reduce(grad, comm_spec, async_op=False)
+                            with torch.cuda.stream(stream):
+                                _all_reduce(grad, comm_spec, async_op=True)

                         param.register_hook(hook_fn)

-                    wrapper(target, comm_spec_to_use)
+                    wrapper(target, comm_spec_to_use, reduction_stream)
     return gm
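Because the hook's all-reduce is now asynchronous and runs on a side stream, the gradients are only guaranteed to be fully reduced once that stream has been joined back to the stream the optimizer runs on. A hedged sketch of the synchronization a training loop would need before stepping the optimizer (model, optimizer, loss, and reduction_stream are placeholders for the caller's objects):

loss.backward()                                             # hooks enqueue all-reduces on reduction_stream
torch.cuda.current_stream().wait_stream(reduction_stream)   # join the side stream before touching .grad
optimizer.step()                                            # gradients are now fully reduced
optimizer.zero_grad()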
@@ -483,4 +483,6 @@ class MatMulHandler(NodeHandler):
                 raise TypeError(
                     f"Found unexpected output type {type(output)} from the recover method of BmmTransform")
         strategies = recovered_stragies
+        for index, strategies in enumerate(strategies):
+            strategies.name = f"{strategies.name}_{index}"
         return strategies
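The two added lines give every strategy in the recovered list a unique, position-based name, so candidates recovered from the same template stay distinguishable downstream. A tiny illustration of the suffixing scheme with a stand-in Strategy class (the class is hypothetical; only the renaming pattern mirrors the diff):

from dataclasses import dataclass

@dataclass
class Strategy:
    name: str

strategies = [Strategy("RR = RS x SR"), Strategy("RR = RS x SR"), Strategy("RS = RR x RS")]
for index, strategy in enumerate(strategies):
    strategy.name = f"{strategy.name}_{index}"

print([s.name for s in strategies])
# ['RR = RS x SR_0', 'RR = RS x SR_1', 'RS = RR x RS_2']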
@@ -247,12 +247,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         strategies.append(self.split_rhs_space_both_contract(1, 0))

         # RR= RS x SR
-        strategies.append(self.recompute_split_both_contract(0))
-        strategies.append(self.recompute_split_both_contract(1))
+        # strategies.append(self.recompute_split_both_contract(0))
+        # strategies.append(self.recompute_split_both_contract(1))

-        # RS = RR x RS
-        strategies.append(self.split_rhs_space_only(0))
-        strategies.append(self.split_rhs_space_only(1))
+        # # RS = RR x RS
+        # strategies.append(self.split_rhs_space_only(0))
+        # strategies.append(self.split_rhs_space_only(1))

         # S01R = S01R x RR
         strategies.append(self.split_lhs_1st_dim_1d(0, 1))
@@ -263,8 +263,8 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         # RS01 = RR x RS01
         strategies.append(self.split_rhs_2nd_dim_1d(0, 1))

-        # RR = RR x RR
-        strategies.append(self.non_split())
+        # # RR = RR x RR
+        # strategies.append(self.non_split())

         return strategies
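The two hunks above prune the strategy search space for linear projections: the recompute-split (RR = RS x SR), rhs-space-only (RS = RR x RS), and fully replicated (RR = RR x RR) candidates are commented out, leaving fewer strategies for the solver to cost. In this notation each letter describes how one tensor dimension is laid out on the device mesh: R is replicated, S is sharded along a mesh axis, so "RS = RR x RS" is a matmul with a replicated lhs, an rhs sharded along its column dimension, and an output sharded the same way. A small self-contained sketch of that layout, with plain torch.chunk standing in for a real sharding spec (world size and shapes are illustrative):

import torch

world_size = 2            # pretend two devices along one mesh axis
lhs = torch.randn(4, 8)   # RR: replicated on every device
rhs = torch.randn(8, 6)   # RS: column dimension split across devices

rhs_shards = torch.chunk(rhs, world_size, dim=1)      # each device holds an (8, 3) shard
out_shards = [lhs @ shard for shard in rhs_shards]    # each device computes a (4, 3) block

# The output is itself RS: concatenating the shards recovers the full result.
assert torch.allclose(torch.cat(out_shards, dim=1), lhs @ rhs)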
@@ -98,7 +98,7 @@ class DeviceMesh:
         return DeviceMesh(self.physical_mesh_id,
                           tuple(flatten_mesh_shape),
                           mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
-                          mesh_beta=[min(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
+                          mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
                           init_process_group=self.init_process_group,
                           need_flatten=False)
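Flattening the device mesh collapses all mesh axes into a single logical axis, and this hunk makes the flattened mesh take the largest per-axis beta (per-byte cost) instead of the smallest, matching how alpha already uses max in the line above. Under the usual alpha-beta model a collective on the flattened axis then gets a conservative cost rather than an underestimate. A hedged sketch of that cost model with made-up constants (the real values live in the DeviceMesh instance):

def alpha_beta_cost(num_bytes, alpha, beta):
    """Classic alpha-beta estimate: fixed latency plus per-byte transfer cost."""
    return alpha + beta * num_bytes

# Two mesh axes with different link qualities (illustrative numbers).
mesh_alpha = [1e-4, 1e-5]    # latency term per collective
mesh_beta = [1e-9, 1e-10]    # per-byte term

# Flattened 1D mesh: assume the worst link so the solver never underestimates.
flat_alpha = max(mesh_alpha)
flat_beta = max(mesh_beta)

print(alpha_beta_cost(4 * 1024 ** 2, flat_alpha, flat_beta))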
@@ -463,7 +463,7 @@ class CommSpec:
         if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
             forward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
             # give a tiny cost to shard
-            backward_communication_cost = 10
+            backward_communication_cost = 100

         if self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
             forward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
@@ -481,13 +481,13 @@ class CommSpec:

         if self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
             # give a tiny cost to shard
-            forward_communication_cost = 10
+            forward_communication_cost = 100
             backward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)

         if self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD:
             # no need for axis because all devices are used in mix_gather
             forward_communication_cost = self.device_mesh.mix_gather_cost(comm_size)
-            backward_communication_cost = 10
+            backward_communication_cost = 100

         if self.forward_only:
             cost_dict["forward"] = forward_communication_cost
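The last two hunks raise the nominal cost attached to the split (shard) side of the gather/split and mix-gather patterns from 10 to 100. The split itself moves no data, but charging it a noticeably non-zero amount biases the strategy solver away from plans that repeatedly re-shard and re-gather tensors. A hedged sketch of how such a penalty could enter a per-pattern forward/backward cost pair (the helper and the numbers are illustrative; only the forward/backward bookkeeping echoes the cost_dict visible above):

SHARD_PENALTY = 100  # small fixed cost charged for a split that moves no data

def gather_fwd_split_bwd_cost(all_gather_cost, forward_only):
    """Forward all-gathers the tensor; backward only splits it locally."""
    cost_dict = {"forward": all_gather_cost}
    if not forward_only:
        # The split is nearly free, but a non-zero penalty keeps the solver
        # from gratuitously bouncing between sharded and gathered layouts.
        cost_dict["backward"] = SHARD_PENALTY
    return cost_dict

print(gather_fwd_split_bwd_cost(all_gather_cost=2.5e3, forward_only=False))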