mirror of https://github.com/hpcaitech/ColossalAI
add overlap option (#2613)
parent cb3d1bef62
commit 28398f1c70
@@ -352,7 +352,7 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
     return gm


-def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
+def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, overlap=False):
     """
     Apply the sharding action to the module parameters and buffers following the
     instructions of solver solution.
@@ -387,15 +387,18 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
             # register hook to the parameters
             if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:

-                def wrapper(param, comm_spec, stream):
+                def wrapper(param, comm_spec, stream, overlap):

                     def hook_fn(grad):
-                        with torch.cuda.stream(stream):
-                            _all_reduce(grad, comm_spec, async_op=True)
+                        if overlap:
+                            with torch.cuda.stream(stream):
+                                _all_reduce(grad, comm_spec, async_op=True)
+                        else:
+                            _all_reduce(grad, comm_spec, async_op=False)

                     param.register_hook(hook_fn)

-                wrapper(param, comm_spec_to_use, reduction_stream)
+                wrapper(param, comm_spec_to_use, reduction_stream, overlap=overlap)

     sharded_buffer_dict = {}
     # apply the sharding spec of buffers
@@ -441,15 +444,18 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
             # register hook to the parameters
             if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:

-                def wrapper(param, comm_spec, stream):
+                def wrapper(param, comm_spec, stream, overlap):

                     def hook_fn(grad):
-                        with torch.cuda.stream(stream):
-                            _all_reduce(grad, comm_spec, async_op=True)
+                        if overlap:
+                            with torch.cuda.stream(stream):
+                                _all_reduce(grad, comm_spec, async_op=True)
+                        else:
+                            _all_reduce(grad, comm_spec, async_op=False)

                     param.register_hook(hook_fn)

-                wrapper(target, comm_spec_to_use, reduction_stream)
+                wrapper(target, comm_spec_to_use, reduction_stream, overlap=overlap)
     return gm
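Both hunks above change the gradient hook in the same way: with `overlap` enabled, the hook enqueues an asynchronous all-reduce on a dedicated reduction stream so gradient communication can run concurrently with the rest of the backward pass; with it disabled, the hook falls back to a blocking all-reduce on the current stream. Below is a minimal self-contained sketch of that pattern; it substitutes `torch.distributed.all_reduce` for ColossalAI's internal `_all_reduce` helper, and the function name `make_overlap_hook` plus the explicit `wait_stream` ordering call are illustrative additions, not part of the diff.

```python
# Minimal sketch of the overlap hook pattern. Assumes torch.distributed
# is initialized with a CUDA-aware backend such as NCCL;
# `make_overlap_hook` is a hypothetical helper, not ColossalAI API.
import torch
import torch.distributed as dist


def make_overlap_hook(param: torch.nn.Parameter, stream: torch.cuda.Stream, overlap: bool):

    def hook_fn(grad: torch.Tensor) -> None:
        if overlap:
            # Order the side stream after the producer of `grad`, then
            # enqueue the all-reduce there so it overlaps with the
            # remaining backward computation.
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                dist.all_reduce(grad, async_op=True)
        else:
            # Blocking all-reduce: backward stalls until the gradient
            # has been synchronized across ranks.
            dist.all_reduce(grad, async_op=False)

    param.register_hook(hook_fn)
```

A hook that returns `None` leaves `grad` in place, so the in-place all-reduce result is what the autograd engine stores.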
@@ -463,13 +469,14 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
 def runtime_preparation_pass(gm: torch.fx.GraphModule,
                              solution: List[int],
                              device_mesh: DeviceMesh,
-                             strategies_constructor: StrategiesConstructor = None):
+                             strategies_constructor: StrategiesConstructor = None,
+                             overlap=False):
     gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
         gm, solution, strategies_constructor)
     gm = _size_value_converting(gm, device_mesh)
     gm = _node_args_converting(gm, device_mesh)
     # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
     # gm = implicit_comm_action_apply(gm)
-    gm = _module_params_sharding(gm, device_mesh)
+    gm = _module_params_sharding(gm, device_mesh, overlap=overlap)

     return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
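Because the hooks only enqueue work, overlap shifts the synchronization burden to the consumer of the gradients: before the optimizer reads them, the default stream has to wait for any all-reduce still in flight on the reduction stream. The sketch below shows that join point; the helper and its training-loop names are hypothetical and not part of this commit, only the `wait_stream` ordering is the essential idea.

```python
# Illustrative join point for overlapped gradient reduction; the helper
# name and arguments are hypothetical.
import torch


def backward_with_overlap(loss: torch.Tensor,
                          optimizer: torch.optim.Optimizer,
                          reduction_stream: torch.cuda.Stream) -> None:
    # Grad hooks registered by the pass enqueue async all-reduces on
    # `reduction_stream` while backward keeps computing earlier layers.
    loss.backward()
    # Make the default stream wait for all enqueued all-reduces before
    # the optimizer consumes the gradients.
    torch.cuda.current_stream().wait_stream(reduction_stream)
    optimizer.step()
```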
@@ -98,16 +98,22 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor
     return solution


-def transform_to_sharded_model(gm: ColoGraphModule, solution: List[int], device_mesh: DeviceMesh,
-                               strategies_constructor: StrategiesConstructor):
+def transform_to_sharded_model(gm: ColoGraphModule,
+                               solution: List[int],
+                               device_mesh: DeviceMesh,
+                               strategies_constructor: StrategiesConstructor,
+                               overlap: bool = False):
     '''
     This method is used to transform the original graph to the sharded graph.
     The model parameters will be sharded according to the solution and the grad hooks
     will be added to the sharded graph using the runtime_preparation_pass.
     The communication node will be added into the graph using the runtime_apply_pass.
     '''
-    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
-        gm, solution, device_mesh, strategies_constructor)
+    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm,
+                                                                                           solution,
+                                                                                           device_mesh,
+                                                                                           strategies_constructor,
+                                                                                           overlap=overlap)
     gm = runtime_apply_pass(gm)
     gm.recompile()
     sharding_spec_dicts = (sharding_spec_dict, origin_spec_dict, comm_actions_dict)
@@ -176,6 +182,7 @@ def initialize_model(model: nn.Module,
                      meta_args: Dict[str, torch.Tensor],
                      device_mesh: DeviceMesh,
                      memory_budget: float = -1.0,
+                     overlap: bool = False,
                      save_solver_solution: bool = False,
                      load_solver_solution: bool = False,
                      solution_path: str = None,
@@ -189,6 +196,8 @@ def initialize_model(model: nn.Module,
         device_mesh: the device mesh to execute the model.
         memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0,
             the memory budget will be infinity.
+        overlap(optional): the overlap is used to specify whether to overlap gradient communication and
+            backward computing.
         save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
             to the solution_path.
         load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
@@ -211,7 +220,7 @@ def initialize_model(model: nn.Module,
     if save_solver_solution:
         torch.save(solution, solution_path)

-    gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor)
+    gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor, overlap)
     model_to_return = ModuleWrapper(gm, *sharding_spec_dicts)

     if return_solution:
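End to end, a caller opts in through `initialize_model`. The schematic sketch below is not taken from the commit: it assumes a distributed context already exists (e.g. four processes launched with torchrun), that `DeviceMesh` accepts this constructor form, and that the import paths match the repository layout at the time of this commit; the model and shapes are illustrative.

```python
# Hypothetical usage of the new overlap flag; adapt import paths and
# mesh shape to your environment.
import torch
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
from colossalai.device.device_mesh import DeviceMesh


class MLP(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1024, 1024)

    def forward(self, x):
        return self.linear(x)


# A 2x2 logical mesh over four devices.
physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2), init_process_group=True)

model = MLP().cuda()
meta_args = {'x': torch.rand(64, 1024, device='meta')}

# overlap=True flows through transform_to_sharded_model and
# runtime_preparation_pass into _module_params_sharding, which registers
# the async all-reduce grad hooks shown above.
model = initialize_model(model, meta_args, device_mesh, overlap=True)
```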