[autoparallel] integrate_gpt_related_tests (#2134)

* [autoparallel] integrate_gpt_related_tests

* polish code

* polish code

* add GPT2Model into runtime test
YuliangLiu0306 2022-12-23 12:36:59 +08:00 committed by GitHub
parent 59e343328d
commit 550f8f8905
5 changed files with 217 additions and 207 deletions


@@ -230,7 +230,12 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
new_slice_items = []
for slice_item in getitem_index:
if slice_item is None:
new_slice_items.append(None)
continue
new_start, new_stop, new_step = slice_item.start, slice_item.stop, slice_item.step
if slice_item.start in node_pairs:
new_start = node_pairs[slice_item.start]
elif slice_item.stop in node_pairs:
@@ -355,7 +360,10 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
for node in nodes:
if node.op == 'call_module':
target_module = node.graph.owning_module.get_submodule(node.target)
# TODO: more work is needed here to properly handle shared parameters.
if hasattr(target_module, 'processed') and target_module.processed:
continue
setattr(target_module, 'processed', True)
for name, param in target_module.named_parameters():
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
# apply the sharding spec of parameters
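
The processed flag introduced in the hunk above guards against sharding the parameters of a submodule twice when it is shared by several call_module nodes. Below is a minimal, self-contained sketch of that guard; the apply_sharding stand-in is hypothetical and not a ColossalAI API.

import torch.nn as nn


def apply_sharding(param: nn.Parameter) -> None:
    # Hypothetical placeholder for applying a node's sharding spec to the parameter.
    pass


def shard_module_params_once(call_module_targets):
    for target_module in call_module_targets:
        # A shared submodule can be reached from several call_module nodes;
        # the transient 'processed' flag makes sure it is sharded only once.
        if getattr(target_module, 'processed', False):
            continue
        target_module.processed = True
        for _, param in target_module.named_parameters():
            apply_sharding(param)
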
@@ -404,7 +412,9 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
target_module = root
target = getattr(root, atoms[0])
else:
target_module = root.get_submodule(atoms[-2])
target_module = root
for atom in atoms[:-1]:
target_module = getattr(target_module, atom)
target = getattr(target_module, atoms[-1])
target_sharding_spec = node.sharding_spec
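
The hunk directly above replaces root.get_submodule(atoms[-2]) with an explicit attribute walk, so parameter names nested at arbitrary depth (as in GPT2's stacked blocks) resolve to the right owning module. A standalone sketch of the pattern; the helper name resolve_target is illustrative only.

import torch.nn as nn


def resolve_target(root: nn.Module, qualified_name: str):
    # Walk the dotted name one atom at a time instead of assuming it is at
    # most two levels deep.
    atoms = qualified_name.split('.')
    target_module = root
    for atom in atoms[:-1]:
        target_module = getattr(target_module, atom)
    return target_module, getattr(target_module, atoms[-1])


if __name__ == '__main__':
    model = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 4)))
    owner, weight = resolve_target(model, '1.0.weight')
    assert weight is model[1][0].weight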


@@ -2,32 +2,30 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import transformers
from torch.fx import GraphModule
from transformers.models.gpt2.modeling_gpt2 import (
GPT2MLP,
BaseModelOutputWithPastAndCrossAttentions,
GPT2PreTrainedModel,
)
from transformers.activations import ACT2FN
from transformers.models.gpt2.modeling_gpt2 import BaseModelOutputWithPastAndCrossAttentions, GPT2PreTrainedModel
from transformers.pytorch_utils import Conv1D
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
from colossalai.auto_parallel.tensor_shard.solver import (
CostGraph,
GraphAnalyser,
Solver,
SolverOptions,
StrategiesConstructor,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing import parameterize
from colossalai.testing.pytest_wrapper import run_on_environment_flag
BATCH_SIZE = 1
SEQ_LENGTH = 32
HIDDEN_DIM = 768
class GPT2MLP(nn.Module):
def __init__(self, intermediate_size, config):
super().__init__()
embed_dim = config.hidden_size
self.c_fc = Conv1D(intermediate_size, embed_dim)
self.c_proj = Conv1D(embed_dim, intermediate_size)
self.act = ACT2FN[config.activation_function]
# We temporarily disable the Dropout layer because the rng state needs to be
# handled properly for the distributed runtime to produce matching results.
# self.dropout = nn.Dropout(config.resid_pdrop)
def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
# TODO: the rng state needs to be fixed for the distributed runtime
# hidden_states = self.dropout(hidden_states)
return hidden_states
# The reason why we don't import GPT2Attention from transformers directly is that:
@@ -89,7 +87,7 @@ class GPT2Attention(nn.Module):
# Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
attn_weights = attn_weights.type(value.dtype)
attn_weights = self.attn_dropout(attn_weights)
# attn_weights = self.attn_dropout(attn_weights)
# Mask heads if we want to
if head_mask is not None:
@@ -125,15 +123,10 @@ class GPT2Attention(nn.Module):
present = (key, value)
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present)
outputs += (attn_weights,)
return outputs # a, present, (attentions)
# attn_output = self.resid_dropout(attn_output)
return attn_output
class GPT2Block(nn.Module):
@@ -161,19 +154,15 @@ class GPT2Block(nn.Module):
attention_mask=attention_mask,
head_mask=head_mask,
)
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]
# residual connection
hidden_states = attn_output + residual
hidden_states = attn_outputs + residual
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
# residual connection
hidden_states = residual + feed_forward_hidden_states
outputs = (hidden_states,) + outputs[1:]
return outputs # hidden_states, present, (attentions, cross_attentions)
return hidden_states
class GPT2Model(GPT2PreTrainedModel):
@@ -228,103 +217,25 @@ class GPT2Model(GPT2PreTrainedModel):
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
# add_2
hidden_states = inputs_embeds + position_embeds
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
# transformer_drop
hidden_states = self.drop(hidden_states)
# comment to run pipeline
# add_3
output_shape = input_shape + (hidden_states.size(-1),)
presents = None
all_self_attentions = None
all_cross_attentions = None
all_hidden_states = None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
outputs = block(hidden_states, attention_mask=attention_mask, head_mask=head_mask[i])
hidden_states = outputs[0]
hidden_states = outputs
hidden_states = self.ln_f(hidden_states)
# comment to run pipeline
hidden_states = hidden_states.view(output_shape)
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
if v is not None)
@run_on_environment_flag(name='AUTO_PARALLEL')
@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
def test_self_attention_block(model_cls):
config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
if model_cls == GPT2MLP:
model = model_cls(intermediate_size=4 * config.hidden_size, config=config)
else:
model = model_cls(config=config)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
if model_cls == GPT2MLP:
input_sample = {
'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
}
elif model_cls in (GPT2Attention, GPT2Block):
input_sample = {
'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'),
}
else:
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
input_sample = {k: v.to('meta') for k, v in kwargs.items()}
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
print(gm.graph)
gm.recompile()
graph_analyser = GraphAnalyser(gm)
liveness_list = graph_analyser.liveness_analysis()
solver_options = SolverOptions()
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
ret = solver.call_solver_serialized_args()
strategies_list = solver.last_s_val
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
computation_cost = 0
communication_cost = 0
memory_cost = 0
for index, node in enumerate(nodes):
print(node.name, node.strategies_vector[strategies_list[index]].name)
computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total
communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total
node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total
if isinstance(node_memory_cost, tuple):
node_memory_cost = node_memory_cost[0]
memory_cost += node_memory_cost.activation + node_memory_cost.parameter
print(f'computation cost is {computation_cost}')
print(f'communication cost is {communication_cost}')
print(f'memory cost is {memory_cost}')
if __name__ == '__main__':
test_self_attention_block()
return hidden_states
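
The trimmed GPT2 modules above drop the Dropout layers and return a plain tensor instead of a tuple of optional outputs, which keeps the traced graph simpler and lets the runtime test compare outputs directly. The toy block below is invented for illustration and traced with stock torch.fx rather than ColoTracer; it shows the kind of single-tensor forward these passes expect.

import torch
import torch.nn as nn
from torch.fx import symbolic_trace


# Not part of the PR: a minimal residual block with a single-tensor output.
class TinyBlock(nn.Module):

    def __init__(self, dim: int = 8):
        super().__init__()
        self.ln = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # residual connection, mirroring the simplified GPT2Block above
        return hidden_states + self.proj(self.ln(hidden_states))


if __name__ == '__main__':
    gm = symbolic_trace(TinyBlock())
    print(gm.graph)    # every node yields one tensor; no tuple outputs to untangle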


@@ -1,7 +1,7 @@
import copy
import random
from functools import partial
from typing import Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union
import numpy as np
import pytest
@@ -10,13 +10,11 @@ import torch.multiprocessing as mp
import torch.nn as nn
import transformers
from torch.fx import GraphModule
from transformers.activations import ACT2FN
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
from transformers.pytorch_utils import Conv1D
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingSpec
from colossalai.auto_parallel.tensor_shard.solver import (
CostGraph,
GraphAnalyser,
@@ -32,6 +30,7 @@ from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_glob
from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
BATCH_SIZE = 1
SEQ_LENGTH = 32
@@ -46,36 +45,73 @@ torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
class GPT2MLP(nn.Module):
def _check_module_grad(module: torch.nn.Module, origin_param_dict: Dict[str, torch.Tensor],
best_sharding_spec_dict: Dict[str, ShardingSpec]):
for name, param in module.named_parameters():
param_grad = param.grad
origin_param_grad = origin_param_dict[name].grad
atoms = name.split('.')
new_name = '_'.join(atoms)
if new_name in best_sharding_spec_dict:
param_sharding_spec = best_sharding_spec_dict[new_name]
grad_to_compare = copy.deepcopy(param_grad)
param_grad_global = to_global(grad_to_compare, param_sharding_spec)
def __init__(self, intermediate_size, config):
super().__init__()
embed_dim = config.hidden_size
self.c_fc = Conv1D(intermediate_size, embed_dim)
self.c_proj = Conv1D(embed_dim, intermediate_size)
self.act = ACT2FN[config.activation_function]
# We temporarily banned the Dropout layer because the rng state need
# to process to get the correct result.
# self.dropout = nn.Dropout(config.resid_pdrop)
def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
# TODO: the rng state need to be fixed for distributed runtime
# hidden_states = self.dropout(hidden_states)
return hidden_states
try:
assert_close_loose(param_grad_global, origin_param_grad, rtol=1e-03, atol=1e-03)
except:
difference = param_grad_global - origin_param_grad
avg_diff = difference.abs().sum() / difference.numel()
assert avg_diff < 0.001
print(f'{name} param has {avg_diff} average difference')
def check_mlp_layer(rank, model_cls, world_size, port):
def check_attention_layer(rank, model_cls, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda')
input = torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('cuda')
config = transformers.GPT2Config(n_position=64, n_layer=1, n_head=16, n_embd=HIDDEN_DIM)
if model_cls == GPT2MLP:
model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda')
else:
model = model_cls(config=config).to('cuda')
test_model = copy.deepcopy(model)
test_input = copy.deepcopy(input)
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
hidden_states = torch.rand((BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM), dtype=torch.float32)
if model_cls == GPT2MLP:
input_sample = (hidden_states.to('cuda'),)
test_input_sample = copy.deepcopy(input_sample)
meta_input_sample = {
'hidden_states': hidden_states.to('meta'),
}
elif model_cls in (GPT2Attention, GPT2Block):
input_sample = (
hidden_states.to('cuda'),
attention_mask.to('cuda'),
)
test_input_sample = copy.deepcopy(input_sample)
meta_input_sample = {
'hidden_states': hidden_states.to('meta'),
'attention_mask': attention_mask.to('meta'),
}
else:
input_sample = (
input_ids.to('cuda'),
token_type_ids.to('cuda'),
attention_mask.to('cuda'),
)
test_input_sample = copy.deepcopy(input_sample)
meta_input_sample = {
'input_ids': input_ids.to('meta'),
'token_type_ids': token_type_ids.to('meta'),
'attention_mask': attention_mask.to('meta'),
}
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
@@ -85,15 +121,10 @@ def check_mlp_layer(rank, model_cls, world_size, port):
tracer = ColoTracer()
input_sample = {
'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
}
graph = tracer.trace(root=model, meta_args=input_sample)
print(graph)
graph = tracer.trace(root=model, meta_args=meta_input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
print(gm)
graph_analyser = GraphAnalyser(gm)
liveness_list = graph_analyser.liveness_analysis()
solver_options = SolverOptions()
@@ -110,71 +141,35 @@ def check_mlp_layer(rank, model_cls, world_size, port):
gm, solution, device_mesh, strategies_constructor)
gm = runtime_apply_pass(gm)
gm.recompile()
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
best_sharding_spec_dict = {}
for index, node in enumerate(nodes):
best_sharding_spec_dict[node.name] = node.sharding_spec
cuda_rng_state = torch.cuda.get_rng_state()
cpu_rng_state = torch.get_rng_state()
origin_output = test_model(test_input)
origin_output = test_model(*test_input_sample)
torch.cuda.set_rng_state(cuda_rng_state)
torch.set_rng_state(cpu_rng_state)
output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
assert_close(output, origin_output, rtol=1e-03, atol=1e-04)
output = gm(*input_sample, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
assert_close(output, origin_output, rtol=1e-03, atol=1e-03)
#*******************backward starting*******************
cuda_rng_state = torch.cuda.get_rng_state()
cpu_rng_state = torch.get_rng_state()
output.sum().backward()
torch.set_rng_state(cpu_rng_state)
torch.cuda.set_rng_state(cuda_rng_state)
origin_output.sum().backward()
origin_param_dict = dict(test_model.named_parameters())
if rank == 0:
print("*******************backward starting*******************")
for name, param in model.named_parameters():
param_grad = param.grad
origin_param_grad = origin_param_dict[name].grad
origin_param_size = origin_param_grad.shape[-1]
print(name, param_grad, origin_param_grad)
if name == 'c_fc.bias':
assert_close_loose(param_grad,
origin_param_grad.narrow(0, 0, origin_param_size // 2),
rtol=1e-03,
atol=1e-03)
else:
assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
_check_module_grad(gm, origin_param_dict, best_sharding_spec_dict)
if rank == 0:
print("*******************backward finished*******************")
if rank == 1:
for name, param in model.named_parameters():
param_grad = param.grad
origin_param_grad = origin_param_dict[name].grad
origin_param_size = origin_param_grad.shape[-1]
if name == 'c_fc.bias':
assert_close_loose(param_grad,
origin_param_grad.narrow(0, origin_param_size // 2, origin_param_size // 2),
rtol=1e-03,
atol=1e-03)
else:
assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
if rank == 2:
for name, param in model.named_parameters():
param_grad = param.grad
origin_param_grad = origin_param_dict[name].grad
origin_param_size = origin_param_grad.shape[-1]
if name == 'c_fc.bias':
assert_close_loose(param_grad,
origin_param_grad.narrow(0, 0, origin_param_size // 2),
rtol=1e-03,
atol=1e-03)
else:
assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
if rank == 3:
for name, param in model.named_parameters():
param_grad = param.grad
origin_param_grad = origin_param_dict[name].grad
origin_param_size = origin_param_grad.shape[-1]
if name == 'c_fc.bias':
assert_close_loose(param_grad,
origin_param_grad.narrow(0, origin_param_size // 2, origin_param_size // 2),
rtol=1e-03,
atol=1e-03)
else:
assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
#*******************backward finished*******************
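
The forward and backward sections above capture the CPU and CUDA RNG state before running the reference model and restore it before running the transformed GraphModule, so both passes draw identical random numbers. A standalone sketch of that bookkeeping; the helper name is illustrative.

import torch


def run_twice_with_same_rng(fn, *args):
    # Save the RNG state, run once, restore, run again: both calls see the same draws.
    cpu_state = torch.get_rng_state()
    cuda_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
    first = fn(*args)
    torch.set_rng_state(cpu_state)
    if cuda_state is not None:
        torch.cuda.set_rng_state(cuda_state)
    second = fn(*args)
    return first, second
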
@@ -202,11 +197,11 @@ def check_mlp_layer(rank, model_cls, world_size, port):
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@parameterize('model_cls', [GPT2MLP])
@parameterize('model_cls', [GPT2MLP, GPT2Block, GPT2Attention, GPT2Model])
@rerun_if_address_is_in_use()
def test_mlp_layer(model_cls):
world_size = 4
run_func = partial(check_mlp_layer, model_cls=model_cls, world_size=world_size, port=free_port())
run_func = partial(check_attention_layer, model_cls=model_cls, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
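
_check_module_grad maps each local gradient back to its global form with to_global and compares it against the single-device reference, falling back to a mean-absolute-difference check when elementwise closeness fails. A rough standalone version of that comparison, written against the public torch.testing.assert_close rather than ColossalAI's assert_close_loose helper:

import torch


def grads_match(sharded_grad_global: torch.Tensor, reference_grad: torch.Tensor) -> bool:
    try:
        torch.testing.assert_close(sharded_grad_global, reference_grad, rtol=1e-3, atol=1e-3)
        return True
    except AssertionError:
        # Tolerate small accumulated numerical error from the distributed run.
        avg_diff = (sharded_grad_global - reference_grad).abs().mean()
        return bool(avg_diff < 1e-3)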


@@ -0,0 +1,94 @@
import torch
import torch.nn as nn
import transformers
from torch.fx import GraphModule

from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
from colossalai.auto_parallel.tensor_shard.solver import (
    CostGraph,
    GraphAnalyser,
    Solver,
    SolverOptions,
    StrategiesConstructor,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing import parameterize
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model

BATCH_SIZE = 1
SEQ_LENGTH = 32
HIDDEN_DIM = 768


@run_on_environment_flag(name='AUTO_PARALLEL')
@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
def test_self_attention_block(model_cls):
    config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
    if model_cls == GPT2MLP:
        model = model_cls(intermediate_size=4 * config.hidden_size, config=config)
    else:
        model = model_cls(config=config)

    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    # [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    shape_consistency_manager = ShapeConsistencyManager()

    tracer = ColoTracer()
    if model_cls == GPT2MLP:
        input_sample = {
            'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
        }
    elif model_cls in (GPT2Attention, GPT2Block):
        input_sample = {
            'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
            'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'),
        }
    else:
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        input_sample = {k: v.to('meta') for k, v in kwargs.items()}

    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
    print(gm.graph)
    gm.recompile()

    graph_analyser = GraphAnalyser(gm)
    liveness_list = graph_analyser.liveness_analysis()
    solver_options = SolverOptions()
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()

    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
    cost_graph.simplify_graph()
    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
    ret = solver.call_solver_serialized_args()
    strategies_list = solver.last_s_val
    nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]

    computation_cost = 0
    communication_cost = 0
    memory_cost = 0
    for index, node in enumerate(nodes):
        print(node.name, node.strategies_vector[strategies_list[index]].name)
        computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total
        communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total
        node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total
        if isinstance(node_memory_cost, tuple):
            node_memory_cost = node_memory_cost[0]
        memory_cost += node_memory_cost.activation + node_memory_cost.parameter

    print(f'computation cost is {computation_cost}')
    print(f'communication cost is {communication_cost}')
    print(f'memory cost is {memory_cost}')


if __name__ == '__main__':
    test_self_attention_block()
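
Both the solver test above and the runtime test are gated through run_on_environment_flag, so they only execute when the AUTO_PARALLEL environment variable is set. The decorator below is an illustrative approximation of that gating, not ColossalAI's implementation.

import functools
import os


def run_on_environment_flag_sketch(name: str):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Skip unless the named environment variable is set, e.g. AUTO_PARALLEL=1.
            if os.environ.get(name) is None:
                print(f'skipping {fn.__name__}: {name} is not set')
                return None
            return fn(*args, **kwargs)
        return wrapper
    return decorator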