mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] accelerate gpt2 training (#2495)
parent a360b9bc44
commit aa0f6686f9
@@ -387,14 +387,15 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                 # register hook to the parameters
                 if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:

-                    def wrapper(param, comm_spec):
+                    def wrapper(param, comm_spec, stream):

                         def hook_fn(grad):
-                            _all_reduce(grad, comm_spec, async_op=False)
+                            with torch.cuda.stream(stream):
+                                _all_reduce(grad, comm_spec, async_op=True)

                         param.register_hook(hook_fn)

-                    wrapper(param, comm_spec_to_use)
+                    wrapper(param, comm_spec_to_use, reduction_stream)

     sharded_buffer_dict = {}
     # apply the sharding spec of buffers
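The change above moves the gradient all-reduce off the default stream: each parameter hook now issues the collective asynchronously on a dedicated reduction stream so it can overlap with the rest of the backward pass. Below is a minimal, self-contained sketch of that pattern, using torch.distributed.all_reduce directly rather than ColossalAI's _all_reduce helper; the helper name register_async_grad_reduce and the setup around it are illustrative, not part of this commit.

import torch
import torch.distributed as dist

# Sketch only: assumes torch.distributed is already initialized and the model
# lives on the current CUDA device. `reduction_stream` mirrors the name in the
# diff; everything else here is illustrative.
reduction_stream = torch.cuda.Stream()

def register_async_grad_reduce(param: torch.nn.Parameter, stream: torch.cuda.Stream) -> None:

    def hook_fn(grad: torch.Tensor):
        # The side stream must wait for the stream that produced `grad`
        # before it can safely read the tensor.
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream):
            # async_op=True returns immediately, so the reduction overlaps with
            # the remaining backward computation on the default stream.
            dist.all_reduce(grad, op=dist.ReduceOp.SUM, async_op=True)
        # Keep grad's memory alive until the side stream has finished with it.
        grad.record_stream(stream)

    param.register_hook(hook_fn)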
@@ -440,14 +441,15 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
             # register hook to the parameters
             if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:

-                def wrapper(param, comm_spec):
+                def wrapper(param, comm_spec, stream):

                     def hook_fn(grad):
-                        _all_reduce(grad, comm_spec, async_op=False)
+                        with torch.cuda.stream(stream):
+                            _all_reduce(grad, comm_spec, async_op=True)

                     param.register_hook(hook_fn)

-                wrapper(target, comm_spec_to_use)
+                wrapper(target, comm_spec_to_use, reduction_stream)
     return gm


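The same rewrite is applied to parameters reached through get_attr nodes. One general caveat with this overlap pattern (not shown in the diff): whatever consumes the gradients afterwards has to synchronize with the reduction stream, otherwise it may read partially reduced values. A hedged sketch, with step_with_stream_sync as a hypothetical helper name:

import torch

def step_with_stream_sync(optimizer: torch.optim.Optimizer,
                          reduction_stream: torch.cuda.Stream) -> None:
    # Make the default stream wait for every collective queued on the side
    # stream before the optimizer reads the gradients.
    torch.cuda.current_stream().wait_stream(reduction_stream)
    optimizer.step()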
@@ -483,4 +483,6 @@ class MatMulHandler(NodeHandler):
                 raise TypeError(
                     f"Found unexpected output type {type(output)} from the recover method of BmmTransform")
         strategies = recovered_stragies
+        for index, strategies in enumerate(strategies):
+            strategies.name = f"{strategies.name}_{index}"
         return strategies
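The two added lines give every recovered strategy a unique name by suffixing its position in the list. Note that the loop variable reuses the name strategies, shadowing the list it iterates over; a sketch of an equivalent rename, with a hypothetical stand-in for the real strategy objects:

from dataclasses import dataclass
from typing import List

@dataclass
class _Strategy:
    # Hypothetical stand-in for the real strategy objects; only `name` matters here.
    name: str

def suffix_strategy_names(strategies: List[_Strategy]) -> List[_Strategy]:
    # "bmm", "bmm" -> "bmm_0", "bmm_1": duplicate names become distinguishable.
    for index, strategy in enumerate(strategies):
        strategy.name = f"{strategy.name}_{index}"
    return strategies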
@@ -247,12 +247,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         strategies.append(self.split_rhs_space_both_contract(1, 0))

         # RR= RS x SR
-        strategies.append(self.recompute_split_both_contract(0))
-        strategies.append(self.recompute_split_both_contract(1))
+        # strategies.append(self.recompute_split_both_contract(0))
+        # strategies.append(self.recompute_split_both_contract(1))

-        # RS = RR x RS
-        strategies.append(self.split_rhs_space_only(0))
-        strategies.append(self.split_rhs_space_only(1))
+        # # RS = RR x RS
+        # strategies.append(self.split_rhs_space_only(0))
+        # strategies.append(self.split_rhs_space_only(1))

         # S01R = S01R x RR
         strategies.append(self.split_lhs_1st_dim_1d(0, 1))
@@ -263,8 +263,8 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         # RS01 = RR x RS01
         strategies.append(self.split_rhs_2nd_dim_1d(0, 1))

-        # RR = RR x RR
-        strategies.append(self.non_split())
+        # # RR = RR x RR
+        # strategies.append(self.non_split())

         return strategies

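These two hunks prune several candidate sharding strategies from LinearProjectionStrategyGenerator instead of handing them all to the solver. In the comments, each letter describes how one tensor dimension is laid out over the device mesh: R is replicated, S is sharded along one mesh axis, S01 along both axes. So "RR = RS x SR" is the strategy where both matmul operands are sharded along the contraction dimension and the replicated output is produced by an all-reduce. A hedged, self-contained illustration of that strategy (plain torch.distributed, not the ColossalAI API, assuming an initialized process group):

import torch
import torch.distributed as dist

def rr_from_rs_x_sr(a_local: torch.Tensor, b_local: torch.Tensor) -> torch.Tensor:
    # With p ranks, a_local is (M, K/p) and b_local is (K/p, N): each rank holds
    # one slice of the contraction dimension.
    partial = a_local @ b_local                       # partial (M, N) sum
    dist.all_reduce(partial, op=dist.ReduceOp.SUM)    # sum partials across ranks
    return partial                                    # "RR": replicated on all ranks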
@@ -98,7 +98,7 @@ class DeviceMesh:
         return DeviceMesh(self.physical_mesh_id,
                           tuple(flatten_mesh_shape),
                           mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
-                          mesh_beta=[min(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
+                          mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
                           init_process_group=self.init_process_group,
                           need_flatten=False)

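When the mesh is flattened into a 1D mesh, the per-byte term now takes max(self.mesh_beta) instead of min(self.mesh_beta). mesh_alpha and mesh_beta feed an alpha-beta style communication model (a latency term plus a bandwidth term), so this prices collectives on the flattened mesh at the slowest link rather than the fastest one, which should make the estimate more conservative. A minimal sketch of that cost shape, with illustrative names only:

def alpha_beta_cost(alpha: float, beta: float, num_bytes: float) -> float:
    # alpha: fixed per-message latency; beta: per-byte cost (inverse bandwidth).
    return alpha + beta * num_bytes

# Using max(mesh_beta) for the flattened mesh means a collective is charged as if
# it crossed the slowest link (e.g. the inter-node axis rather than NVLink).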
@@ -463,7 +463,7 @@ class CommSpec:
         if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
             forward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
             # give a tiny cost to shard
-            backward_communication_cost = 10
+            backward_communication_cost = 100

         if self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
             forward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
@@ -481,13 +481,13 @@ class CommSpec:

         if self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
             # give a tiny cost to shard
-            forward_communication_cost = 10
+            forward_communication_cost = 100
             backward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)

         if self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD:
             # no need for axis because all devices are used in mix_gather
             forward_communication_cost = self.device_mesh.mix_gather_cost(comm_size)
-            backward_communication_cost = 10
+            backward_communication_cost = 100

         if self.forward_only:
             cost_dict["forward"] = forward_communication_cost
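Both CommSpec hunks raise the constant "tiny cost to shard" from 10 to 100. The measured side of each pattern (the all-gather or mix-gather) still comes from the device-mesh cost model, while the split direction is charged only this constant, so raising it presumably makes the solver less eager to insert shard/gather round-trips that previously looked almost free. A hedged sketch of the resulting cost shape for the GATHER_FWD_SPLIT_BWD pattern; the helper and the "backward" key are illustrative, only cost_dict["forward"] appears in the diff itself:

SHARD_PENALTY = 100   # the constant raised from 10 in this commit

def gather_fwd_split_bwd_cost(all_gather_cost: float) -> dict:
    # Forward pays the measured all-gather cost; the backward split is charged
    # only the small constant penalty.
    return {"forward": all_gather_cost, "backward": SHARD_PENALTY}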