mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] remove redundancy comm node (#1893)
parent 9183e0dec5
commit 36c0f3ea5b
@@ -81,6 +81,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
             continue

         for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
+            if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
+                continue
             with mod_graph.inserting_before(user_node):
                 shape_consistency_node = mod_graph.create_node('call_function',
                                                                runtime_apply,

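The check added above is what removes the redundant communication node: when the producer's sharding layout already matches what a consumer expects, no shape-consistency (resharding) node is inserted. A minimal, self-contained sketch of that idea, using a hypothetical sharding_sequence_difference helper rather than ColossalAI's ShardingSpec API:

    # Toy sketch: skip the conversion step when source and target layouts already agree.
    def sharding_sequence_difference(src, dst):
        # hypothetical stand-in: count mismatched per-dimension annotations
        return sum(1 for a, b in zip(src, dst) if a != b)

    producer_spec = ["S0", "R"]                     # sharded on dim 0, replicated on dim 1
    consumer_specs = [["S0", "R"], ["R", "S1"]]     # one match, one mismatch

    for spec in consumer_specs:
        if sharding_sequence_difference(producer_spec, spec) == 0:
            continue                                # identical layouts: no comm node needed
        print("insert shape-consistency node for", spec)
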
@@ -47,6 +47,7 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
             target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
             target_sharding_specs.append(target_sharding_spec)
         sharding_spec_convert_dict[index] = target_sharding_specs
+        setattr(node, 'target_sharding_specs', target_sharding_specs)
         # the get_attr node strategy is kind of pending strategy, which means we will change it
         # to the same strategy of the user node.
         if node.op == 'get_attr':

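The new setattr caches each node's per-consumer target sharding specs directly on the torch.fx node, so the redundancy check in the earlier hunk can compare them against the node's own spec. A small sketch of attaching ad-hoc metadata to fx nodes; the attribute name is kept only for illustration:

    import torch
    from torch.fx import symbolic_trace

    class M(torch.nn.Module):
        def forward(self, x):
            return x + 1

    gm = symbolic_trace(M())
    for node in gm.graph.nodes:
        # fx.Node is a plain Python object, so passes can hang extra metadata on it
        setattr(node, 'target_sharding_specs', [])
        assert node.target_sharding_specs == []
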
@@ -95,7 +96,8 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
     """
     mod_graph = gm.graph
     nodes = tuple(mod_graph.nodes)
-
+    # This stream is created for overlaping the communication and computation.
+    reduction_stream = torch.cuda.Stream()
     for node in nodes:
         if node.op == 'call_module':
             target_module = node.graph.owning_module.get_submodule(node.target)

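The reduction_stream is only created here; the gradient hooks below still call _all_reduce with async_op=False. As a rough sketch of the overlap such a side stream can enable (assuming an initialized process group and a CUDA device; the function name is illustrative):

    import torch
    import torch.distributed as dist

    def async_grad_allreduce(grad: torch.Tensor, stream: torch.cuda.Stream):
        # Launch the collective on a side stream so the default stream can keep
        # running backward kernels while the all-reduce is in flight.
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream):
            work = dist.all_reduce(grad, op=dist.ReduceOp.SUM, async_op=True)
        return work  # caller must work.wait() (or sync the stream) before reading grad

    if dist.is_initialized() and torch.cuda.is_available():
        reduction_stream = torch.cuda.Stream()
        handle = async_grad_allreduce(torch.ones(4, device="cuda"), reduction_stream)
        handle.wait()
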
@@ -122,7 +124,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
                 def wrapper(param, comm_spec):

                     def hook_fn(grad):
-                        _all_reduce(grad, comm_spec)
+                        _all_reduce(grad, comm_spec, async_op=False)

                     param.register_hook(hook_fn)

@@ -172,7 +174,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
                 def wrapper(param, comm_spec):

                     def hook_fn(grad):
-                        _all_reduce(grad, comm_spec)
+                        _all_reduce(grad, comm_spec, async_op=False)

                     param.register_hook(hook_fn)

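Both hunks wire the same pattern: a hook registered on each sharded parameter all-reduces its gradient as soon as autograd produces it, now explicitly synchronous (async_op=False). A standalone sketch of that hook pattern in plain PyTorch, assuming an initialized process group:

    import torch
    import torch.distributed as dist

    def attach_grad_allreduce(param: torch.nn.Parameter):

        def hook_fn(grad: torch.Tensor) -> torch.Tensor:
            # synchronous all-reduce of the freshly computed gradient
            reduced = grad.contiguous()
            dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
            return reduced

        param.register_hook(hook_fn)
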
@@ -74,11 +74,13 @@ class NodeHandler(ABC):
                 if op_data.type == OperationDataType.PARAM:
                     resharding_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
                 else:
+                    dtype = op_data.data.dtype
+                    size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
                     _, _, resharding_cost = shape_consistency_manager.shape_consistency(
                         prev_sharding_spec, current_sharding_spec)
-                    resharding_cost = TrainCycleItem(fwd=resharding_cost["forward"],
-                                                     bwd=resharding_cost["backward"],
-                                                     total=resharding_cost["total"])
+                    resharding_cost = TrainCycleItem(fwd=resharding_cost["forward"] * size_per_elem_bytes,
+                                                     bwd=resharding_cost["backward"] * size_per_elem_bytes,
+                                                     total=resharding_cost["total"] * size_per_elem_bytes)
                 resharding_costs[node].append(resharding_cost)
             strategy.resharding_costs = resharding_costs
             return strategy

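The cost returned by shape_consistency appears to be counted per element, hence the multiplication by the dtype's element size to express it in bytes and make costs for different dtypes comparable. A quick worked example of just the scaling step (the cost numbers are made up):

    import torch

    resharding_cost = {"forward": 1024.0, "backward": 1024.0, "total": 2048.0}   # elements moved
    size_per_elem_bytes = torch.tensor([], dtype=torch.float32).element_size()   # 4 for fp32

    cost_in_bytes = {k: v * size_per_elem_bytes for k, v in resharding_cost.items()}
    print(cost_in_bytes)   # {'forward': 4096.0, 'backward': 4096.0, 'total': 8192.0}
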
@@ -218,7 +218,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
             sharding_spec=sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=mesh_dim_0,
-            comm_type=CommType.AFTER)
+            comm_type=CommType.IMPLICIT)

         communication_action_mapping = {"output": output_comm_action}

@@ -254,7 +254,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
             sharding_spec=sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=[mesh_dim_0, mesh_dim_1],
-            comm_type=CommType.AFTER)
+            comm_type=CommType.IMPLICIT)

         communication_action_mapping = {"output": output_comm_action}

@@ -300,7 +300,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
             sharding_spec=sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=[mesh_dim_0],
-            comm_type=CommType.AFTER)
+            comm_type=CommType.IMPLICIT)

         communication_action_mapping = {"output": output_comm_action}

@@ -331,14 +331,14 @@ class BatchNormStrategyGenerator(StrategyGenerator):
         # TODO: The strategies below should be uncommented after runtime
         # passes ready.
         # SR = SR x R WITH SYNC_BN
-        # strategy_list.append(self.split_input_batch(0))
-        # strategy_list.append(self.split_input_batch(1))
+        strategy_list.append(self.split_input_batch(0))
+        strategy_list.append(self.split_input_batch(1))

         # SS = SS x S WITH SYNC_BN
-        # strategy_list.append(self.split_input_both_dim(0, 1))
-        # strategy_list.append(self.split_input_both_dim(1, 0))
+        strategy_list.append(self.split_input_both_dim(0, 1))
+        strategy_list.append(self.split_input_both_dim(1, 0))

         # S01R = S01R x R WITH SYNC_BN
-        # strategy_list.append(self.split_input_batch_1d(0, 1))
+        strategy_list.append(self.split_input_batch_1d(0, 1))

         return strategy_list

@@ -23,9 +23,7 @@ def _all_gather(tensor, comm_spec):
                 torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
                 for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
             ]
-            tensor = tensor
-            group = process_group
-            dist.all_gather(tensor_list, tensor, group=group)
+            dist.all_gather(tensor_list, tensor, group=process_group)
             output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
             return output

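The rewrite drops the no-op assignments and gathers directly into the preallocated list, then concatenates along the gather dimension. A self-contained version of that pattern (assumes an initialized process group; gather_dim is a plain argument here rather than a comm_spec field):

    import torch
    import torch.distributed as dist

    def all_gather_concat(tensor: torch.Tensor, gather_dim: int, group=None) -> torch.Tensor:
        world_size = dist.get_world_size(group)
        tensor_list = [torch.zeros_like(tensor) for _ in range(world_size)]
        dist.all_gather(tensor_list, tensor, group=group)
        return torch.cat(tuple(tensor_list), dim=gather_dim).contiguous()
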
@@ -37,7 +35,6 @@ def _split(tensor, comm_spec):
     process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
     for rank_list, _ in process_groups_list:
         if dist.get_rank() in rank_list:
-            tensor = tensor
             dim = comm_spec.shard_dim
             length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
             start = length * rank_list.index(dist.get_rank())

@@ -69,7 +66,7 @@ def _all_to_all(tensor, comm_spec):
         return output


-def _all_reduce(tensor, comm_spec):
+def _all_reduce(tensor, comm_spec, async_op=False):
    '''
    Implement all reduce operation on device mesh based on information provided by comm_spec.
    '''

|
@ -78,7 +75,7 @@ def _all_reduce(tensor, comm_spec):
|
||||||
if dist.get_rank() in rank_list:
|
if dist.get_rank() in rank_list:
|
||||||
if not tensor.is_contiguous():
|
if not tensor.is_contiguous():
|
||||||
tensor = tensor.contiguous()
|
tensor = tensor.contiguous()
|
||||||
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group)
|
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
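With the async_op flag threaded through, dist.all_reduce returns a work handle when async_op=True that must be waited on before the tensor is read; with async_op=False the call blocks as before. A minimal sketch of that plumbing (process group setup omitted):

    import torch
    import torch.distributed as dist

    def all_reduce_sum(tensor: torch.Tensor, async_op: bool = False):
        if not tensor.is_contiguous():
            tensor = tensor.contiguous()
        work = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=async_op)
        return tensor, work   # work is None when async_op=False

    # usage (requires an initialized process group):
    #   t, work = all_reduce_sum(t, async_op=True)
    #   ... overlap other computation ...
    #   work.wait()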