diff --git a/colossalai/auto_parallel/passes/runtime_apply_pass.py b/colossalai/auto_parallel/passes/runtime_apply_pass.py index 9f95009d9..8a55829ea 100644 --- a/colossalai/auto_parallel/passes/runtime_apply_pass.py +++ b/colossalai/auto_parallel/passes/runtime_apply_pass.py @@ -81,6 +81,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule): continue for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes): + if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0: + continue with mod_graph.inserting_before(user_node): shape_consistency_node = mod_graph.create_node('call_function', runtime_apply, diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py index 614fb66f4..30b7be267 100644 --- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py +++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py @@ -47,6 +47,7 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]): target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name)) target_sharding_specs.append(target_sharding_spec) sharding_spec_convert_dict[index] = target_sharding_specs + setattr(node, 'target_sharding_specs', target_sharding_specs) # the get_attr node strategy is kind of pending strategy, which means we will change it # to the same strategy of the user node. if node.op == 'get_attr': @@ -95,7 +96,8 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh): """ mod_graph = gm.graph nodes = tuple(mod_graph.nodes) - + # This stream is created for overlaping the communication and computation. + reduction_stream = torch.cuda.Stream() for node in nodes: if node.op == 'call_module': target_module = node.graph.owning_module.get_submodule(node.target) @@ -122,7 +124,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh): def wrapper(param, comm_spec): def hook_fn(grad): - _all_reduce(grad, comm_spec) + _all_reduce(grad, comm_spec, async_op=False) param.register_hook(hook_fn) @@ -172,7 +174,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh): def wrapper(param, comm_spec): def hook_fn(grad): - _all_reduce(grad, comm_spec) + _all_reduce(grad, comm_spec, async_op=False) param.register_hook(hook_fn) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py index f576b4e4b..2d882fc09 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py @@ -74,11 +74,13 @@ class NodeHandler(ABC): if op_data.type == OperationDataType.PARAM: resharding_cost = TrainCycleItem(fwd=0, bwd=0, total=0) else: + dtype = op_data.data.dtype + size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() _, _, resharding_cost = shape_consistency_manager.shape_consistency( prev_sharding_spec, current_sharding_spec) - resharding_cost = TrainCycleItem(fwd=resharding_cost["forward"], - bwd=resharding_cost["backward"], - total=resharding_cost["total"]) + resharding_cost = TrainCycleItem(fwd=resharding_cost["forward"] * size_per_elem_bytes, + bwd=resharding_cost["backward"] * size_per_elem_bytes, + total=resharding_cost["total"] * size_per_elem_bytes) resharding_costs[node].append(resharding_cost) strategy.resharding_costs = resharding_costs return strategy diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py index 6a81a7eaa..86f332d84 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py @@ -218,7 +218,7 @@ class BatchNormStrategyGenerator(StrategyGenerator): sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.AFTER) + comm_type=CommType.IMPLICIT) communication_action_mapping = {"output": output_comm_action} @@ -254,7 +254,7 @@ class BatchNormStrategyGenerator(StrategyGenerator): sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.AFTER) + comm_type=CommType.IMPLICIT) communication_action_mapping = {"output": output_comm_action} @@ -300,7 +300,7 @@ class BatchNormStrategyGenerator(StrategyGenerator): sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=[mesh_dim_0], - comm_type=CommType.AFTER) + comm_type=CommType.IMPLICIT) communication_action_mapping = {"output": output_comm_action} @@ -331,14 +331,14 @@ class BatchNormStrategyGenerator(StrategyGenerator): # TODO: The strategies below should be uncommented after runtime # passes ready. # SR = SR x R WITH SYNC_BN - # strategy_list.append(self.split_input_batch(0)) - # strategy_list.append(self.split_input_batch(1)) + strategy_list.append(self.split_input_batch(0)) + strategy_list.append(self.split_input_batch(1)) # SS = SS x S WITH SYNC_BN - # strategy_list.append(self.split_input_both_dim(0, 1)) - # strategy_list.append(self.split_input_both_dim(1, 0)) + strategy_list.append(self.split_input_both_dim(0, 1)) + strategy_list.append(self.split_input_both_dim(1, 0)) # S01R = S01R x R WITH SYNC_BN - # strategy_list.append(self.split_input_batch_1d(0, 1)) + strategy_list.append(self.split_input_batch_1d(0, 1)) return strategy_list diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py index a0775d0bc..2910ea843 100644 --- a/colossalai/tensor/comm_spec.py +++ b/colossalai/tensor/comm_spec.py @@ -23,9 +23,7 @@ def _all_gather(tensor, comm_spec): torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]) ] - tensor = tensor - group = process_group - dist.all_gather(tensor_list, tensor, group=group) + dist.all_gather(tensor_list, tensor, group=process_group) output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() return output @@ -37,7 +35,6 @@ def _split(tensor, comm_spec): process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] for rank_list, _ in process_groups_list: if dist.get_rank() in rank_list: - tensor = tensor dim = comm_spec.shard_dim length = tensor.shape[comm_spec.shard_dim] // len(rank_list) start = length * rank_list.index(dist.get_rank()) @@ -69,7 +66,7 @@ def _all_to_all(tensor, comm_spec): return output -def _all_reduce(tensor, comm_spec): +def _all_reduce(tensor, comm_spec, async_op=False): ''' Implement all reduce operation on device mesh based on information provided by comm_spec. ''' @@ -78,7 +75,7 @@ def _all_reduce(tensor, comm_spec): if dist.get_rank() in rank_list: if not tensor.is_contiguous(): tensor = tensor.contiguous() - dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group) + dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) return tensor