[autoparallel] gpt2lp runtime test (#2113)

YuliangLiu0306 2022-12-12 18:06:40 +08:00 committed by GitHub
parent 9214d1fe28
commit cd0af9f7f6
3 changed files with 261 additions and 25 deletions

View File: colossalai/auto_parallel/passes/runtime_preparation_pass.py

@@ -11,6 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationDataType,
ShardingStrategy,
)
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.comm_spec import _all_reduce
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
@@ -19,13 +20,23 @@ from colossalai.tensor.sharding_spec import ShardingSpec
shape_consistency_manager = ShapeConsistencyManager()
def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
def _solution_annotatation(gm: torch.fx.GraphModule,
solution: List[int],
strategies_constructor: StrategiesConstructor = None):
"""
This method attaches the chosen strategy from the solution to each node and adds the
information required at runtime into the graph as placeholder nodes.
"""
mod_graph = gm.graph
nodes = tuple(mod_graph.nodes)
# TODO: In a future PR, strategies_constructor should be a required argument
# instead of an optional one, because nodes with no strategy do not need to be
# considered in the runtime preparation pass.
if strategies_constructor is not None:
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
no_strategy_nodes = strategies_constructor.no_strategy_nodes
else:
nodes = tuple(mod_graph.nodes)
no_strategy_nodes = []
# the dict used to look up the origin sharding spec of each node
origin_node_sharding_spec_dict = {}
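
For orientation, a minimal standalone sketch of how the annotation step consumes the solution vector (hypothetical simplified form; the real pass also records sharding specs and communication actions):

for index, node in enumerate(nodes):
    # solution[index] is the index of the winning strategy for this node
    setattr(node, 'best_strategy', node.strategies_vector[solution[index]])
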
@@ -44,7 +55,10 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
for index, node in enumerate(nodes):
target_sharding_specs = []
for user_node in node.strategies_vector.successor_nodes:
target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
if user_node in no_strategy_nodes:
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(str(node.name))
else:
target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
target_sharding_specs.append(target_sharding_spec)
sharding_spec_convert_dict[index] = target_sharding_specs
setattr(node, 'target_sharding_specs', target_sharding_specs)
@@ -136,13 +150,17 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
new_args.append(arg)
for dim, shard_dims in output_dim_partition_dict.items():
# skip any dim whose shape argument is -1
if new_args[dim + 1] == -1:
continue
total_shard_size = 1
for shard_dim in shard_dims:
total_shard_size *= device_mesh.shape[shard_dim]
new_args[dim + 1] //= total_shard_size
# There are two calling conventions for torch.Tensor.view:
# 1. tensor.view(*shape)
# 2. tensor.view(shape)
if isinstance(new_args[1], int):
new_args[dim + 1] //= total_shard_size
else:
new_args[1] = list(new_args[1])
new_args[1][dim] //= total_shard_size
node.args = tuple(new_args)
elif node.op == 'call_function':
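
The two view calling conventions handled above can be reproduced in isolation; a small sketch with assumed shapes:

import torch

x = torch.rand(4, 32, 768)
y1 = x.view(4, 32, 768)      # view(*shape): the traced node's args are roughly (x, 4, 32, 768)
y2 = x.view((4, 32, 768))    # view(shape):  the traced node's args are roughly (x, (4, 32, 768))
# If the last dim is sharded over a mesh axis of size 2, the recorded 768
# must become 768 // 2 == 384, while any -1 entry is left untouched.
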
@@ -193,12 +211,12 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
setattr(param, 'sharding_spec', origin_sharding_spec)
# TODO: build a ColoParameter class to manage the distributed parameters
param_sharded = torch.nn.Parameter(
shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
target_sharding_spec).detach().clone())
else:
param_sharded = param
setattr(target_module, name, param_sharded)
# we can use .data here because all of these operations happen before the real training
# loop, so they do not need to be tracked in the autograd graph.
param.data = shape_consistency_manager.apply_for_autoparallel_runtime(
param.data, param.sharding_spec, target_sharding_spec).detach().clone()
setattr(target_module, name, param)
comm_actions = node.best_strategy.communication_actions
for operation_data, comm_action in comm_actions.items():
comm_spec_to_use = comm_action.comm_spec
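
A minimal standalone sketch of the .data rebinding used above (assumed tensors, not part of this diff):

import torch

param = torch.nn.Parameter(torch.rand(4, 8))
shard = param.data.narrow(1, 0, 4).detach().clone()
param.data = shard  # the Parameter object registered on the module survives,
                    # and the assignment is invisible to the autograd graph
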
@@ -212,7 +230,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
param.register_hook(hook_fn)
wrapper(param_sharded, comm_spec_to_use)
wrapper(param, comm_spec_to_use)
sharded_buffer_dict = {}
# apply the sharding spec of buffers
@@ -242,12 +260,13 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
origin_sharding_spec = ShardingSpec(device_mesh, target.shape, {})
setattr(target, 'sharding_spec', origin_sharding_spec)
# TODO: build a ColoParameter class to manage the distributed parameters
target_sharded = torch.nn.Parameter(
shape_consistency_manager.apply_for_autoparallel_runtime(target.data, target.sharding_spec,
target_sharding_spec).detach().clone())
else:
target_sharded = target
setattr(target_module, atoms[-1], target_sharded)
# we can use .data here because all of these operations happen before the real training
# loop, so they do not need to be tracked in the autograd graph.
target.data = shape_consistency_manager.apply_for_autoparallel_runtime(
target.data, target.sharding_spec, target_sharding_spec).detach().clone()
assert hasattr(target_module, atoms[-1])
setattr(target_module, atoms[-1], target)
comm_actions = node.best_strategy.communication_actions
for operation_data, comm_action in comm_actions.items():
@@ -262,7 +281,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
param.register_hook(hook_fn)
wrapper(target_sharded, comm_spec_to_use)
wrapper(target, comm_spec_to_use)
return gm
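
The wrapper/hook pair above binds a communication action to the backward pass; a minimal sketch of the pattern, using the _all_reduce imported at the top of this file:

def wrapper(param, comm_spec):

    def hook_fn(grad):
        # fires when the backward pass produces this parameter's gradient
        _all_reduce(grad, comm_spec)

    param.register_hook(hook_fn)
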
@@ -273,9 +292,12 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
pass
def runtime_preparation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh: DeviceMesh):
def runtime_preparation_pass(gm: torch.fx.GraphModule,
solution: List[int],
device_mesh: DeviceMesh,
strategies_constructor: StrategiesConstructor = None):
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
gm, solution)
gm, solution, strategies_constructor)
gm = _node_args_converting(gm, device_mesh)
# TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass is completed.
# gm = implicit_comm_action_apply(gm)
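
Taken together, the intended call sequence for these passes (mirrored by the new test below) is:

gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
    gm, solution, device_mesh, strategies_constructor)
gm = runtime_apply_pass(gm)
gm.recompile()
output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
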

View File: colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py

@@ -41,6 +41,7 @@ class StrategiesConstructor:
self.leaf_strategies = []
self.strategy_map = {}
self.solver_options = solver_options
self.no_strategy_nodes = []
def remove_duplicated_strategy(self, strategies_vector):
'''
@@ -78,12 +79,11 @@
return _check_no_strategy_for_data(node._meta_data)
no_strategy_node = []
for node in self.nodes:
strategies_vector = StrategiesVector(node)
if _check_no_strategy_for_node(node):
no_strategy_node.append(node)
self.no_strategy_nodes.append(node)
pass
# placeholder node
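
With this change, callers can read the skippable nodes directly off the constructor; a minimal sketch:

solver_options = SolverOptions()
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
# nodes that received no strategy and can be ignored by the runtime passes
skip_nodes = set(strategies_constructor.no_strategy_nodes)
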

View File

@@ -0,0 +1,214 @@
import copy
import random
from functools import partial
from typing import Optional, Tuple, Union
import numpy as np
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import transformers
from torch.fx import GraphModule
from transformers.activations import ACT2FN
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
from transformers.pytorch_utils import Conv1D
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
from colossalai.auto_parallel.tensor_shard.solver import (
CostGraph,
GraphAnalyser,
Solver,
SolverOptions,
StrategiesConstructor,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_global
from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
BATCH_SIZE = 1
SEQ_LENGTH = 32
HIDDEN_DIM = 768
seed = 128
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
class GPT2MLP(nn.Module):
def __init__(self, intermediate_size, config):
super().__init__()
embed_dim = config.hidden_size
self.c_fc = Conv1D(intermediate_size, embed_dim)
self.c_proj = Conv1D(embed_dim, intermediate_size)
self.act = ACT2FN[config.activation_function]
# We temporarily remove the Dropout layer because the RNG state needs special
# handling to produce the correct result.
# self.dropout = nn.Dropout(config.resid_pdrop)
def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
# TODO: the RNG state needs to be fixed for the distributed runtime
# hidden_states = self.dropout(hidden_states)
return hidden_states
def check_mlp_layer(rank, model_cls, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
config = transformers.GPT2Config(n_positions=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda')
input = torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('cuda')
test_model = copy.deepcopy(model)
test_input = copy.deepcopy(input)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
input_sample = {
'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
}
graph = tracer.trace(root=model, meta_args=input_sample)
print(graph)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
print(gm)
graph_analyser = GraphAnalyser(gm)
liveness_list = graph_analyser.liveness_analysis()
solver_options = SolverOptions()
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
ret = solver.call_solver_serialized_args()
solution = list(ret[0])
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
gm, solution, device_mesh, strategies_constructor)
gm = runtime_apply_pass(gm)
gm.recompile()
cuda_rng_state = torch.cuda.get_rng_state()
cpu_rng_state = torch.get_rng_state()
origin_output = test_model(test_input)
torch.cuda.set_rng_state(cuda_rng_state)
torch.set_rng_state(cpu_rng_state)
output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
assert_close(output, origin_output, rtol=1e-03, atol=1e-04)
#*******************backward starting*******************
cuda_rng_state = torch.cuda.get_rng_state()
output.sum().backward()
torch.cuda.set_rng_state(cuda_rng_state)
origin_output.sum().backward()
origin_param_dict = dict(test_model.named_parameters())
if rank == 0:
print("*******************backward starting*******************")
for name, param in model.named_parameters():
param_grad = param.grad
origin_param_grad = origin_param_dict[name].grad
origin_param_size = origin_param_grad.shape[-1]
print(name, param_grad, origin_param_grad)
if name == 'c_fc.bias':
assert_close_loose(param_grad,
origin_param_grad.narrow(0, 0, origin_param_size // 2),
rtol=1e-03,
atol=1e-03)
else:
assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
print("*******************backward finished*******************")
if rank == 1:
for name, param in model.named_parameters():
param_grad = param.grad
origin_param_grad = origin_param_dict[name].grad
origin_param_size = origin_param_grad.shape[-1]
if name == 'c_fc.bias':
assert_close_loose(param_grad,
origin_param_grad.narrow(0, origin_param_size // 2, origin_param_size // 2),
rtol=1e-03,
atol=1e-03)
else:
assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
if rank == 2:
for name, param in model.named_parameters():
param_grad = param.grad
origin_param_grad = origin_param_dict[name].grad
origin_param_size = origin_param_grad.shape[-1]
if name == 'c_fc.bias':
assert_close_loose(param_grad,
origin_param_grad.narrow(0, 0, origin_param_size // 2),
rtol=1e-03,
atol=1e-03)
else:
assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
if rank == 3:
for name, param in model.named_parameters():
param_grad = param.grad
origin_param_grad = origin_param_dict[name].grad
origin_param_size = origin_param_grad.shape[-1]
if name == 'c_fc.bias':
assert_close_loose(param_grad,
origin_param_grad.narrow(0, origin_param_size // 2, origin_param_size // 2),
rtol=1e-03,
atol=1e-03)
else:
assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
#*******************backward finished*******************
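# The four rank blocks above encode the assumed split of c_fc.bias over the
# 2x2 mesh: ranks 0/2 compare against the first half, ranks 1/3 against the
# second. In isolation (hypothetical length 8):
#   bias = torch.arange(8.)
#   half = bias.shape[-1] // 2
#   bias.narrow(0, 0, half)     # slice checked on ranks 0 and 2
#   bias.narrow(0, half, half)  # slice checked on ranks 1 and 3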
#*******************strategy selected*******************
if rank == 0:
print("*******************strategy selected*******************")
strategies_list = solver.last_s_val
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
computation_cost = 0
communication_cost = 0
memory_cost = 0
for index, node in enumerate(nodes):
print(node.name, node.strategies_vector[strategies_list[index]].name)
computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total
communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total
node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total
if isinstance(node_memory_cost, tuple):
node_memory_cost = node_memory_cost[0]
memory_cost += node_memory_cost.activation + node_memory_cost.parameter
print(f'computation cost is {computation_cost}')
print(f'communication cost is {communication_cost}')
print(f'memory cost is {memory_cost}')
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@parameterize('model_cls', [GPT2MLP])
@rerun_if_address_is_in_use()
def test_mlp_layer(model_cls):
world_size = 4
run_func = partial(check_mlp_layer, model_cls=model_cls, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_mlp_layer()
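# The run_on_environment_flag decorator gates this test behind the
# AUTO_PARALLEL environment variable; an assumed local invocation:
#   AUTO_PARALLEL=1 python <this test file>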