From 550f8f89056e47ff3328faf3a3eec761b7da8b76 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Fri, 23 Dec 2022 12:36:59 +0800
Subject: [PATCH] [autoparallel] integrate_gpt_related_tests (#2134)

* [autoparallel] integrate_gpt_related_tests

* polish code

* polish code

* add GPT2Model into runtime test
---
 .../passes/runtime_preparation_pass.py        |  14 +-
 .../test_tensor_shard/test_gpt/__init__.py    |   0
 .../gpt_modules.py}                           | 147 +++------------
 .../test_runtime_with_gpt_modules.py}         | 169 +++++++++---------
 .../test_gpt/test_solver_with_gpt_module.py   |  94 ++++++++++
 5 files changed, 217 insertions(+), 207 deletions(-)
 create mode 100644 tests/test_auto_parallel/test_tensor_shard/test_gpt/__init__.py
 rename tests/test_auto_parallel/test_tensor_shard/{test_solver_with_gpt_related_module.py => test_gpt/gpt_modules.py} (64%)
 rename tests/test_auto_parallel/test_tensor_shard/{test_gptmlp_runtime.py => test_gpt/test_runtime_with_gpt_modules.py} (51%)
 create mode 100644 tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py

diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
index 92916118b..0b898a43e 100644
--- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -230,7 +230,12 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
 
         new_slice_items = []
 
         for slice_item in getitem_index:
+            if slice_item is None:
+                new_slice_items.append(None)
+                continue
+
             new_start, new_stop, new_step = slice_item.start, slice_item.stop, slice_item.step
+
             if slice_item.start in node_pairs:
                 new_start = node_pairs[slice_item.start]
             elif slice_item.stop in node_pairs:
@@ -355,7 +360,10 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
     for node in nodes:
         if node.op == 'call_module':
             target_module = node.graph.owning_module.get_submodule(node.target)
-
+            # TODO: we need to do more to handle shared parameters correctly.
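+            # A module that is shared between several call_module nodes (e.g.
+            # tied weights) should only be sharded once, so modules that have
+            # already been visited are flagged and skipped below.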
+            if hasattr(target_module, 'processed') and target_module.processed:
+                continue
+            setattr(target_module, 'processed', True)
             for name, param in target_module.named_parameters():
                 target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
                 # apply the sharding spec of parameters
@@ -404,7 +412,9 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                 target_module = root
                 target = getattr(root, atoms[0])
             else:
-                target_module = root.get_submodule(atoms[-2])
+                target_module = root
+                for atom in atoms[:-1]:
+                    target_module = getattr(target_module, atom)
                 target = getattr(target_module, atoms[-1])
 
             target_sharding_spec = node.sharding_spec
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/__init__.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_related_module.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/gpt_modules.py
similarity index 64%
rename from tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_related_module.py
rename to tests/test_auto_parallel/test_tensor_shard/test_gpt/gpt_modules.py
index 82accebdb..b66ad1949 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_related_module.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/gpt_modules.py
@@ -2,32 +2,30 @@ from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-import transformers
-from torch.fx import GraphModule
-from transformers.models.gpt2.modeling_gpt2 import (
-    GPT2MLP,
-    BaseModelOutputWithPastAndCrossAttentions,
-    GPT2PreTrainedModel,
-)
+from transformers.activations import ACT2FN
+from transformers.models.gpt2.modeling_gpt2 import BaseModelOutputWithPastAndCrossAttentions, GPT2PreTrainedModel
 from transformers.pytorch_utils import Conv1D
 
-from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
-from colossalai.auto_parallel.tensor_shard.solver import (
-    CostGraph,
-    GraphAnalyser,
-    Solver,
-    SolverOptions,
-    StrategiesConstructor,
-)
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.testing import parameterize
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
 
-BATCH_SIZE = 1
-SEQ_LENGTH = 32
-HIDDEN_DIM = 768
+class GPT2MLP(nn.Module):
+
+    def __init__(self, intermediate_size, config):
+        super().__init__()
+        embed_dim = config.hidden_size
+        self.c_fc = Conv1D(intermediate_size, embed_dim)
+        self.c_proj = Conv1D(embed_dim, intermediate_size)
+        self.act = ACT2FN[config.activation_function]
+        # We temporarily disable the Dropout layer because the RNG state needs
+        # special handling to produce the correct result.
+        # self.dropout = nn.Dropout(config.resid_pdrop)
+
+    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+        hidden_states = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.c_proj(hidden_states)
+        # TODO: the RNG state needs to be fixed for the distributed runtime
+        # hidden_states = self.dropout(hidden_states)
+        return hidden_states
 
 
 # The reason why we don't import GPT2Attention from transformers directly is that:
@@ -89,7 +87,7 @@ class GPT2Attention(nn.Module):
 
         # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
         attn_weights = attn_weights.type(value.dtype)
-        attn_weights = self.attn_dropout(attn_weights)
+        # attn_weights = self.attn_dropout(attn_weights)
 
         # Mask heads if we want to
         if head_mask is not None:
@@ -125,15 +123,10 @@ class GPT2Attention(nn.Module):
             present = (key, value)
 
         attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
-
         attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
         attn_output = self.c_proj(attn_output)
-        attn_output = self.resid_dropout(attn_output)
-
-        outputs = (attn_output, present)
-        outputs += (attn_weights,)
-
-        return outputs    # a, present, (attentions)
+        # attn_output = self.resid_dropout(attn_output)
+        return attn_output
 
 
 class GPT2Block(nn.Module):
@@ -161,19 +154,15 @@ class GPT2Block(nn.Module):
             attention_mask=attention_mask,
             head_mask=head_mask,
         )
-        attn_output = attn_outputs[0]    # output_attn: a, present, (attentions)
-        outputs = attn_outputs[1:]
         # residual connection
-        hidden_states = attn_output + residual
+        hidden_states = attn_outputs + residual
         residual = hidden_states
         hidden_states = self.ln_2(hidden_states)
         feed_forward_hidden_states = self.mlp(hidden_states)
         # residual connection
         hidden_states = residual + feed_forward_hidden_states
 
-        outputs = (hidden_states,) + outputs[1:]
-
-        return outputs    # hidden_states, present, (attentions, cross_attentions)
+        return hidden_states
 
 
 class GPT2Model(GPT2PreTrainedModel):
@@ -228,103 +217,25 @@ class GPT2Model(GPT2PreTrainedModel):
 
         # attention_probs has shape bsz x n_heads x N x N
         # head_mask has shape n_layer x batch x n_heads x N x N
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)
-
         inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
+        # add_2
         hidden_states = inputs_embeds + position_embeds
 
         token_type_embeds = self.wte(token_type_ids)
         hidden_states = hidden_states + token_type_embeds
-        # transformer_drop
-        hidden_states = self.drop(hidden_states)
         # comment to run pipeline
         # add_3
         output_shape = input_shape + (hidden_states.size(-1),)
 
-        presents = None
-        all_self_attentions = None
-        all_cross_attentions = None
-        all_hidden_states = None
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
             outputs = block(hidden_states, attention_mask=attention_mask, head_mask=head_mask[i])
-            hidden_states = outputs[0]
+            hidden_states = outputs
 
         hidden_states = self.ln_f(hidden_states)
        # comment to run pipeline
         hidden_states = hidden_states.view(output_shape)
 
-        return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
-                     if v is not None)
-
-
-@run_on_environment_flag(name='AUTO_PARALLEL')
-@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
-def test_self_attention_block(model_cls):
-    config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
-    if model_cls == GPT2MLP:
-        model = model_cls(intermediate_size=4 * config.hidden_size, config=config)
-    else:
-        model = model_cls(config=config)
-    physical_mesh_id = torch.arange(0, 4)
-    mesh_shape = (2, 2)
-    # [[0, 1]
-    #  [2, 3]]
-    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-    shape_consistency_manager = ShapeConsistencyManager()
-
-    tracer = ColoTracer()
-    if model_cls == GPT2MLP:
-        input_sample = {
-            'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
-        }
-    elif model_cls in (GPT2Attention, GPT2Block):
-        input_sample = {
-            'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
-            'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'),
-        }
-    else:
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
-        input_sample = {k: v.to('meta') for k, v in kwargs.items()}
-
-    graph = tracer.trace(root=model, meta_args=input_sample)
-
-    gm = GraphModule(model, graph, model.__class__.__name__)
-    print(gm.graph)
-    gm.recompile()
-    graph_analyser = GraphAnalyser(gm)
-    liveness_list = graph_analyser.liveness_analysis()
-    solver_options = SolverOptions()
-    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
-    strategies_constructor.build_strategies_and_cost()
-
-    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
-    cost_graph.simplify_graph()
-    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
-    ret = solver.call_solver_serialized_args()
-    strategies_list = solver.last_s_val
-    nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
-
-    computation_cost = 0
-    communication_cost = 0
-    memory_cost = 0
-    for index, node in enumerate(nodes):
-        print(node.name, node.strategies_vector[strategies_list[index]].name)
-        computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total
-        communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total
-        node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total
-        if isinstance(node_memory_cost, tuple):
-            node_memory_cost = node_memory_cost[0]
-        memory_cost += node_memory_cost.activation + node_memory_cost.parameter
-
-    print(f'computation cost is {computation_cost}')
-    print(f'communication cost is {communication_cost}')
-    print(f'memory cost is {memory_cost}')
-
-
-if __name__ == '__main__':
-    test_self_attention_block()
+        return hidden_states
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gptmlp_runtime.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py
similarity index 51%
rename from tests/test_auto_parallel/test_tensor_shard/test_gptmlp_runtime.py
rename to tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py
index d573c6590..361c22d26 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_gptmlp_runtime.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py
@@ -1,7 +1,7 @@
 import copy
 import random
 from functools import partial
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
 import numpy as np
 import pytest
@@ -10,13 +10,11 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 import transformers
 from torch.fx import GraphModule
-from transformers.activations import ACT2FN
-from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
-from transformers.pytorch_utils import Conv1D
 
 from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
 from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingSpec
 from colossalai.auto_parallel.tensor_shard.solver import (
     CostGraph,
     GraphAnalyser,
@@ -32,6 +30,7 @@ from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_global
 from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
 from colossalai.utils import free_port
+from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
 
 BATCH_SIZE = 1
 SEQ_LENGTH = 32
@@ -46,36 +45,73 @@ torch.backends.cudnn.deterministic = True
 torch.backends.cudnn.benchmark = False
 
 
-class GPT2MLP(nn.Module):
+def _check_module_grad(module: torch.nn.Module, origin_param_dict: Dict[str, torch.Tensor],
+                       best_sharding_spec_dict: Dict[str, ShardingSpec]):
+    for name, param in module.named_parameters():
+        param_grad = param.grad
+        origin_param_grad = origin_param_dict[name].grad
+        atoms = name.split('.')
+        new_name = '_'.join(atoms)
+        if new_name in best_sharding_spec_dict:
+            param_sharding_spec = best_sharding_spec_dict[new_name]
+            grad_to_compare = copy.deepcopy(param_grad)
+            param_grad_global = to_global(grad_to_compare, param_sharding_spec)
 
-    def __init__(self, intermediate_size, config):
-        super().__init__()
-        embed_dim = config.hidden_size
-        self.c_fc = Conv1D(intermediate_size, embed_dim)
-        self.c_proj = Conv1D(embed_dim, intermediate_size)
-        self.act = ACT2FN[config.activation_function]
-        # We temporarily banned the Dropout layer because the rng state need
-        # to process to get the correct result.
-        # self.dropout = nn.Dropout(config.resid_pdrop)
-
-    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
-        hidden_states = self.c_fc(hidden_states)
-        hidden_states = self.act(hidden_states)
-        hidden_states = self.c_proj(hidden_states)
-        # TODO: the rng state need to be fixed for distributed runtime
-        # hidden_states = self.dropout(hidden_states)
-        return hidden_states
+            try:
+                assert_close_loose(param_grad_global, origin_param_grad, rtol=1e-03, atol=1e-03)
+            except AssertionError:
+                difference = param_grad_global - origin_param_grad
+                avg_diff = difference.abs().sum() / difference.numel()
+                assert avg_diff < 0.001
+                print(f'{name} param has {avg_diff} average difference')
 
 
-def check_mlp_layer(rank, model_cls, world_size, port):
+def check_attention_layer(rank, model_cls, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
-    model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda')
-    input = torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('cuda')
+    config = transformers.GPT2Config(n_position=64, n_layer=1, n_head=16, n_embd=HIDDEN_DIM)
+
+    if model_cls == GPT2MLP:
+        model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda')
+    else:
+        model = model_cls(config=config).to('cuda')
     test_model = copy.deepcopy(model)
-    test_input = copy.deepcopy(input)
+
+    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+    token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+    attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+    hidden_states = torch.rand((BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM), dtype=torch.float32)
+
+    if model_cls == GPT2MLP:
+        input_sample = (hidden_states.to('cuda'),)
+        test_input_sample = copy.deepcopy(input_sample)
+        meta_input_sample = {
+            'hidden_states': hidden_states.to('meta'),
+        }
+    elif model_cls in (GPT2Attention, GPT2Block):
+        input_sample = (
+            hidden_states.to('cuda'),
+            attention_mask.to('cuda'),
+        )
+        test_input_sample = copy.deepcopy(input_sample)
+        meta_input_sample = {
+            'hidden_states': hidden_states.to('meta'),
+            'attention_mask': attention_mask.to('meta'),
+        }
+    else:
+        input_sample = (
+            input_ids.to('cuda'),
+            token_type_ids.to('cuda'),
+            attention_mask.to('cuda'),
+        )
+        test_input_sample = copy.deepcopy(input_sample)
+        meta_input_sample = {
+            'input_ids': input_ids.to('meta'),
+            'token_type_ids': token_type_ids.to('meta'),
+            'attention_mask': attention_mask.to('meta'),
+        }
+
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     # [[0, 1]
     #  [2, 3]]
@@ -85,15 +121,10 @@ def check_mlp_layer(rank, model_cls, world_size, port):
 
     tracer = ColoTracer()
 
-    input_sample = {
-        'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
-    }
-
-    graph = tracer.trace(root=model, meta_args=input_sample)
-    print(graph)
+    graph = tracer.trace(root=model, meta_args=meta_input_sample)
     gm = GraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
-    print(gm)
+
     graph_analyser = GraphAnalyser(gm)
     liveness_list = graph_analyser.liveness_analysis()
     solver_options = SolverOptions()
@@ -110,71 +141,35 @@
         gm, solution, device_mesh, strategies_constructor)
     gm = runtime_apply_pass(gm)
     gm.recompile()
+    nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
+    best_sharding_spec_dict = {}
+    for index, node in enumerate(nodes):
+        best_sharding_spec_dict[node.name] = node.sharding_spec
+
     cuda_rng_state = torch.cuda.get_rng_state()
     cpu_rng_state = torch.get_rng_state()
-    origin_output = test_model(test_input)
+    origin_output = test_model(*test_input_sample)
     torch.cuda.set_rng_state(cuda_rng_state)
     torch.set_rng_state(cpu_rng_state)
-    output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
-    assert_close(output, origin_output, rtol=1e-03, atol=1e-04)
+    output = gm(*input_sample, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+    assert_close(output, origin_output, rtol=1e-03, atol=1e-03)
 
     #*******************backward starting*******************
     cuda_rng_state = torch.cuda.get_rng_state()
+    cpu_rng_state = torch.get_rng_state()
     output.sum().backward()
+    torch.set_rng_state(cpu_rng_state)
     torch.cuda.set_rng_state(cuda_rng_state)
     origin_output.sum().backward()
     origin_param_dict = dict(test_model.named_parameters())
+
     if rank == 0:
         print("*******************backward starting*******************")
-        for name, param in model.named_parameters():
-            param_grad = param.grad
-            origin_param_grad = origin_param_dict[name].grad
-            origin_param_size = origin_param_grad.shape[-1]
-            print(name, param_grad, origin_param_grad)
-            if name == 'c_fc.bias':
-                assert_close_loose(param_grad,
-                                   origin_param_grad.narrow(0, 0, origin_param_size // 2),
-                                   rtol=1e-03,
-                                   atol=1e-03)
-            else:
-                assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
+
+    _check_module_grad(gm, origin_param_dict, best_sharding_spec_dict)
+
+    if rank == 0:
         print("*******************backward finished*******************")
-    if rank == 1:
-        for name, param in model.named_parameters():
-            param_grad = param.grad
-            origin_param_grad = origin_param_dict[name].grad
-            origin_param_size = origin_param_grad.shape[-1]
-            if name == 'c_fc.bias':
-                assert_close_loose(param_grad,
-                                   origin_param_grad.narrow(0, origin_param_size // 2, origin_param_size // 2),
-                                   rtol=1e-03,
-                                   atol=1e-03)
-            else:
-                assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
-    if rank == 2:
-        for name, param in model.named_parameters():
-            param_grad = param.grad
-            origin_param_grad = origin_param_dict[name].grad
-            origin_param_size = origin_param_grad.shape[-1]
-            if name == 'c_fc.bias':
-                assert_close_loose(param_grad,
-                                   origin_param_grad.narrow(0, 0, origin_param_size // 2),
-                                   rtol=1e-03,
-                                   atol=1e-03)
-            else:
-                assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
-    if rank == 3:
-        for name, param in model.named_parameters():
-            param_grad = param.grad
-            origin_param_grad = origin_param_dict[name].grad
-            origin_param_size = origin_param_grad.shape[-1]
-            if name == 'c_fc.bias':
-                assert_close_loose(param_grad,
-                                   origin_param_grad.narrow(0, origin_param_size // 2, origin_param_size // 2),
-                                   rtol=1e-03,
-                                   atol=1e-03)
-            else:
-                assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
 
     #*******************backward finished*******************
@@ -202,11 +197,11 @@
 
 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
-@parameterize('model_cls', [GPT2MLP])
+@parameterize('model_cls', [GPT2MLP, GPT2Block, GPT2Attention, GPT2Model])
 @rerun_if_address_is_in_use()
 def test_mlp_layer(model_cls):
     world_size = 4
-    run_func = partial(check_mlp_layer, model_cls=model_cls, world_size=world_size, port=free_port())
+    run_func = partial(check_attention_layer, model_cls=model_cls, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py
new file mode 100644
index 000000000..478b77e76
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py
@@ -0,0 +1,94 @@
+import torch
+import torch.nn as nn
+import transformers
+from torch.fx import GraphModule
+
+from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
+from colossalai.auto_parallel.tensor_shard.solver import (
+    CostGraph,
+    GraphAnalyser,
+    Solver,
+    SolverOptions,
+    StrategiesConstructor,
+)
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.testing import parameterize
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
+
+BATCH_SIZE = 1
+SEQ_LENGTH = 32
+HIDDEN_DIM = 768
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
+def test_self_attention_block(model_cls):
+    config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
+    if model_cls == GPT2MLP:
+        model = model_cls(intermediate_size=4 * config.hidden_size, config=config)
+    else:
+        model = model_cls(config=config)
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    # [[0, 1]
+    #  [2, 3]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    shape_consistency_manager = ShapeConsistencyManager()
+
+    tracer = ColoTracer()
+    if model_cls == GPT2MLP:
+        input_sample = {
+            'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
+        }
+    elif model_cls in (GPT2Attention, GPT2Block):
+        input_sample = {
+            'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
+            'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'),
+        }
+    else:
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+        input_sample = {k: v.to('meta') for k, v in kwargs.items()}
+
+    graph = tracer.trace(root=model, meta_args=input_sample)
+
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    print(gm.graph)
+    gm.recompile()
+    graph_analyser = GraphAnalyser(gm)
+    liveness_list = graph_analyser.liveness_analysis()
+    solver_options = SolverOptions()
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
+    strategies_constructor.build_strategies_and_cost()
+
+    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
+    cost_graph.simplify_graph()
+    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
+    ret = solver.call_solver_serialized_args()
+    strategies_list = solver.last_s_val
+    nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
+
+    computation_cost = 0
+    communication_cost = 0
+    memory_cost = 0
+    for index, node in enumerate(nodes):
+        print(node.name, node.strategies_vector[strategies_list[index]].name)
+        computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total
+        communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total
+        node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total
+        if isinstance(node_memory_cost, tuple):
+            node_memory_cost = node_memory_cost[0]
+        memory_cost += node_memory_cost.activation + node_memory_cost.parameter
+
+    print(f'computation cost is {computation_cost}')
+    print(f'communication cost is {communication_cost}')
+    print(f'memory cost is {memory_cost}')
+
+
+if __name__ == '__main__':
+    test_self_attention_block()
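
A note on the None guard added in _size_value_converting: a plausible trigger is
indexing such as x[:, None], which is traced as a getitem whose index tuple
mixes slice objects with None, so reading .start on a None item raises an
AttributeError. A minimal standalone sketch of that situation (variable names
and shapes are illustrative, not taken from the pass):

import torch

x = torch.rand(4, 8)
# x[:, None] is traced as getitem with the index (slice(None, None, None), None);
# the guard forwards the None unchanged instead of reading .start on it.
getitem_index = (slice(None, None, None), None)
new_slice_items = []
for slice_item in getitem_index:
    if slice_item is None:
        new_slice_items.append(None)
        continue
    new_slice_items.append(slice(slice_item.start, slice_item.stop, slice_item.step))
assert x[tuple(new_slice_items)].shape == (4, 1, 8)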
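Similarly, the getattr walk that replaces root.get_submodule(atoms[-2]) in
_module_params_sharding matters whenever a parameter sits more than two levels
below the root module, because atoms[-2] on its own is not a fully qualified
submodule path. A self-contained sketch of the difference, using a toy module
tree rather than anything from the patch:

import torch.nn as nn


def resolve_parent(root: nn.Module, qualified_name: str):
    # Mirrors the loop added in _module_params_sharding: walk every atom but
    # the last to reach the module that actually owns the final attribute.
    atoms = qualified_name.split('.')
    parent = root
    for atom in atoms[:-1]:
        parent = getattr(parent, atom)
    return parent, atoms[-1]


root = nn.Sequential(nn.Sequential(nn.Linear(4, 4)))
parent, leaf = resolve_parent(root, '0.0.weight')
assert parent is root[0][0] and leaf == 'weight'
# The old code would call root.get_submodule('0') here and get the inner
# Sequential back, not the Linear that owns 'weight'.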