mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] support origin activation ckpt on autoprallel system (#2468)
parent
3a21485ead
commit
67e1912b59
|
@ -128,6 +128,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
|
|||
runtime_apply,
|
||||
args=(node, origin_dict_node, input_dict_node,
|
||||
node_to_index_dict[node], user_node_index))
|
||||
if 'activation_checkpoint' in user_node.meta:
|
||||
shape_consistency_node.meta['activation_checkpoint'] = user_node.meta['activation_checkpoint']
|
||||
|
||||
new_args = list(user_node.args)
|
||||
new_kwargs = dict(user_node.kwargs)
|
||||
|
@ -208,6 +210,37 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
|
|||
# substitute the origin node with comm_spec_apply_node
|
||||
new_kwargs[str(node)] = comm_spec_apply_node
|
||||
user.kwargs = new_kwargs
|
||||
|
||||
if 'activation_checkpoint' in node.meta:
|
||||
comm_spec_apply_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
|
||||
|
||||
return gm
|
||||
|
||||
|
||||
def _act_annotataion_pass(gm: torch.fx.GraphModule):
|
||||
"""
|
||||
This pass is used to add the act annotation to the new inserted nodes.
|
||||
"""
|
||||
mod_graph = gm.graph
|
||||
nodes = tuple(mod_graph.nodes)
|
||||
|
||||
for node in nodes:
|
||||
if not hasattr(node.meta, 'activation_checkpoint'):
|
||||
from .runtime_preparation_pass import size_processing
|
||||
|
||||
user_act_annotation = -1
|
||||
input_act_annotation = -1
|
||||
for user_node in node.users.keys():
|
||||
if 'activation_checkpoint' in user_node.meta:
|
||||
user_act_annotation = user_node.meta['activation_checkpoint']
|
||||
break
|
||||
for input_node in node._input_nodes.keys():
|
||||
if 'activation_checkpoint' in input_node.meta:
|
||||
input_act_annotation = input_node.meta['activation_checkpoint']
|
||||
break
|
||||
if user_act_annotation == input_act_annotation and user_act_annotation != -1:
|
||||
node.meta['activation_checkpoint'] = user_act_annotation
|
||||
|
||||
return gm
|
||||
|
||||
|
||||
|
|
|
@ -179,6 +179,8 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
|
|||
# It will be used to replace the original node with processing node in slice object
|
||||
node_pairs[node] = size_processing_node
|
||||
size_processing_node._meta_data = node._meta_data
|
||||
if 'activation_checkpoint' in node.meta:
|
||||
size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
|
||||
|
||||
user_list = list(node.users.keys())
|
||||
for user in user_list:
|
||||
|
|
|
@ -18,6 +18,7 @@ from colossalai.auto_parallel.tensor_shard.solver import (
|
|||
)
|
||||
from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.tracer import ColoTracer
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
|
@ -28,7 +29,7 @@ class ModuleWrapper(nn.Module):
|
|||
into the forward function.
|
||||
'''
|
||||
|
||||
def __init__(self, module: GraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
|
||||
def __init__(self, module: ColoGraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
|
||||
origin_spec_dict: Dict[int, ShardingSpec], comm_actions_dict: Dict[int, Dict[str, CommAction]]):
|
||||
'''
|
||||
Args:
|
||||
|
@ -81,7 +82,7 @@ def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh):
|
|||
return strategies_constructor
|
||||
|
||||
|
||||
def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
|
||||
def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
|
||||
'''
|
||||
This method is used to solve the best solution for the given graph.
|
||||
The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
|
||||
|
@ -97,7 +98,7 @@ def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor,
|
|||
return solution
|
||||
|
||||
|
||||
def transform_to_sharded_model(gm: GraphModule, solution: List[int], device_mesh: DeviceMesh,
|
||||
def transform_to_sharded_model(gm: ColoGraphModule, solution: List[int], device_mesh: DeviceMesh,
|
||||
strategies_constructor: StrategiesConstructor):
|
||||
'''
|
||||
This method is used to transform the original graph to the sharded graph.
|
||||
|
@ -197,10 +198,10 @@ def initialize_model(model: nn.Module,
|
|||
solution will be used to debug or help to analyze the sharding result. Therefore, we will not just
|
||||
return a series of integers, but return the best strategies.
|
||||
'''
|
||||
tracer = ColoTracer()
|
||||
tracer = ColoTracer(trace_act_ckpt=True)
|
||||
|
||||
graph = tracer.trace(root=model, meta_args=meta_args)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm = ColoGraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
strategies_constructor = build_strategy_constructor(graph, device_mesh)
|
||||
if load_solver_solution:
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
from functools import partial
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from transformers.pytorch_utils import Conv1D
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.tracer import ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
|
||||
HIDDEN_SIZE = 16
|
||||
|
||||
|
||||
class GPT2MLPWithCkpt(nn.Module):
|
||||
|
||||
def __init__(self, intermediate_size, hidden_size):
|
||||
super().__init__()
|
||||
embed_dim = hidden_size
|
||||
self.c_fc = Conv1D(intermediate_size, embed_dim)
|
||||
self.c_proj = Conv1D(embed_dim, intermediate_size)
|
||||
self.act = torch.nn.ReLU()
|
||||
|
||||
def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
|
||||
hidden_states = self.c_fc(hidden_states)
|
||||
hidden_states = checkpoint(self.c_proj, hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def check_act_ckpt(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = GPT2MLPWithCkpt(intermediate_size=4 * HIDDEN_SIZE, hidden_size=HIDDEN_SIZE)
|
||||
input_sample = {
|
||||
'hidden_states': torch.rand(1, 64, HIDDEN_SIZE).to('meta'),
|
||||
}
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
gm = initialize_model(model, input_sample, device_mesh)
|
||||
code = gm.module.graph.python_code('self').src
|
||||
assert "runtime_comm_spec_apply_1 = colossalai_auto_parallel_passes_runtime_apply_pass_runtime_comm_spec_apply(linear_1, comm_actions_dict, 12, 'linear_1')" in code
|
||||
assert "view_3 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, view_1, comm_actions_dict, use_reentrant=True)" in code
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_mlp_layer():
|
||||
world_size = 4
|
||||
run_func = partial(check_act_ckpt, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_mlp_layer()
|
Loading…
Reference in New Issue