mirror of https://github.com/hpcaitech/ColossalAI

commit 820ea4d056: "align evoformer"
parent 86f2a31474

changed file: chunk_codegen.py (143 lines changed)
@@ -1,5 +1,6 @@
 import colossalai
 import torch
+import copy
 from typing import List, Callable, Any, Tuple, Dict, Iterable

 try:
@@ -17,74 +18,18 @@ else:
 __all__ = ['python_code_with_activation_checkpoint']


-def _gen_saved_tensors_hooks():
-    """
-    Generate saved tensors hooks
-    """
-
-    pack_hook = """def pack_hook_input(self, x):
-    if getattr(x, "offload", False):
-        return (x.device, x.cpu())
-    else:
-        return x
-
-def pack_hook_no_input(self, x):
-    if getattr(x, "offload", True):
-        return (x.device, x.cpu())
-    else:
-        return x
-"""
-
-    unpack_hook = """def unpack_hook(self, packed):
-    if isinstance(packed, tuple):
-        device, tensor = packed
-        return tensor.to(device)
-    else:
-        return packed
-"""
-
-    return pack_hook, unpack_hook
-
-
-def _gen_loop_5(to_keep):
-    context = "chunk_result = []\nfor gen_loop_idx in range(4):\n"
-    context += "    chunk_tensor = " + to_keep + "[gen_loop_idx, :]\n"
+def _gen_loop_start(to_keep, chunk_size=2):
+    context = "chunk_result = []; chunk_size = %d\nfor gen_loop_idx in range(0, %s.shape[0], chunk_size):\n" % (chunk_size, to_keep[0])
+    context += "    chunk_tensor = " + to_keep + "[gen_loop_idx:gen_loop_idx + chunk_size, :]\n"
     return context


-def _gen_loop_5_final(final_name, to_keep):
+def _gen_loop_end(final_name, to_keep):
     context = "    chunk_result.append(" + final_name + ")\n"
     context += "chunk_result = torch.cat(chunk_result, dim=0); " + to_keep[0] + " = None\n"
     context += final_name + " = chunk_result; chunk_result = None\n"
     return context


-def _gen_save_tensors_hooks_context(offload_input=True) -> str:
-    """Generate customized saved_tensors_hooks
-
-    Args:
-        offload_input (bool, optional): whether we need offload input, if offload_input=False,
-        we will use self.pack_hook_no_input instead. Defaults to True.
-
-    Returns:
-        str: generated context
-    """
-
-    if offload_input:
-        context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):\n"
-    else:
-        context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):\n"
-    return context
-
-
-def _gen_save_on_cpu_context():
-    """
-    Generate save on cpu context
-    """
-
-    context = "with torch.autograd.graph.save_on_cpu(pin_memory=True):\n"
-    return context
-
-
 def _find_input_and_output_nodes(nodes: List[Node]):
     """
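For orientation, the two new helpers emit source text rather than running anything themselves. A minimal sketch of the code they are intended to generate, with illustrative names (node1 for the chunked input and relu for the last node of the region; neither name comes from this commit):

    import torch

    node1 = torch.randn(6, 4)

    # roughly what _gen_loop_start("node1") emits
    chunk_result = []; chunk_size = 2
    for gen_loop_idx in range(0, node1.shape[0], chunk_size):
        chunk_tensor = node1[gen_loop_idx:gen_loop_idx + chunk_size, :]
        # the emitted region body goes here, indented one level and fed chunk_tensor
        relu = torch.relu(chunk_tensor)
        # roughly what _gen_loop_end("relu", ["node1"]) emits
        chunk_result.append(relu)
    chunk_result = torch.cat(chunk_result, dim=0); node1 = None
    relu = chunk_result; chunk_result = None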
@@ -112,49 +57,6 @@ def _find_input_and_output_nodes(nodes: List[Node]):
     return input_nodes, output_nodes


-def _find_ckpt_regions(nodes: List[Node]):
-    """
-    Find the checkpoint regions given a list of consecutive nodes. The outputs will be list
-    of tuples, each tuple is in the form of (start_index, end_index).
-    """
-    ckpt_nodes = []
-    ckpt_regions = []
-    start = -1
-    end = -1
-    current_region = None
-
-    for idx, node in enumerate(nodes):
-        if hasattr(node, 'activation_checkpoint'):
-            act_ckpt_label = node.activation_checkpoint
-
-            # this activation checkpoint label is not set yet
-            # meaning this is the first node of the activation ckpt region
-            if current_region is None:
-                current_region = act_ckpt_label
-                start = idx
-
-            # if activation checkpoint has changed
-            # we restart the tracking
-            # e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2]
-            if act_ckpt_label != current_region:
-                assert start != -1
-                ckpt_regions.append((start, idx - 1))
-                current_region = act_ckpt_label
-                start = idx
-                end = -1
-        elif current_region is not None and not hasattr(node, 'activation_checkpoint'):
-            # used to check the case below
-            # node ckpt states = [ckpt, ckpt, non-ckpt]
-            end = idx - 1
-            assert start != -1 and end != -1
-            ckpt_regions.append((start, end))
-            start = end = -1
-            current_region = None
-        else:
-            pass
-    return ckpt_regions
-
-
 def _find_offload_regions(nodes: List[Node]):
     """This function is to find the offload regions
     In pofo algorithm, during annotation, we will annotate the offload region with the
@@ -400,12 +302,9 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
         emit_node_func: function to emit node
         delete_unused_value_func: function to remove the unused value
     """
-    ckpt_regions = _find_nested_ckpt_regions(nodes, 0)
-    start_idx = [item[0] for item in ckpt_regions]
-    end_idx = [item[1] for item in ckpt_regions]

     # find the offload regions
-    chunk_regions, chunk_labels = _find_offload_regions(nodes)
+    chunk_regions = [(1, 4)]
     chunk_starts = [item[0] for item in chunk_regions]
     chunk_ends = [item[1] for item in chunk_regions]
     chunk_inputs = []
@@ -424,7 +323,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
     # this flag is to prevent repeated insert of save tensors
     # hooks definition in ckpt_func
     node_idx = 0
-    to_keep = []
+    chunk_var = []
     while node_idx < len(node_list):
         # break if we finish the processing all the nodes
         if node_idx >= len(node_list):
@@ -435,28 +334,30 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
         node = node_list[node_idx]

         if node_idx in chunk_starts:
-            # save chunk input var, dont delete it
-            to_keep.extend(node.args[0].name)
             within_chunk_region = True
-            # add for loop
-            body.append(_gen_loop_5(to_keep[0]))
-            # change first node's input to new chunked var
-            node_args = list(node.args)
-            node_args[0] = 'chunk_tensor'
+
+            # save chunk input var, dont delete it
+            chunk_var.append(node.args[0].name)
+
+            # add for loop
+            body.append(_gen_loop_start(chunk_var[0]))

         if within_chunk_region:
             emit_node_func(node, body)
+            # replace input var with chunk var
+            if node_idx in chunk_starts:
+                body[-1] = body[-1].replace("("+ chunk_var[0] +")", '(chunk_tensor)')
             body[-1] = '    ' + body[-1]
-            delete_unused_value_func(node, body, to_keep)
+            delete_unused_value_func(node, body, chunk_var)

         else:
             emit_node_func(node, body)
             if node_idx not in chunk_inputs:
-                delete_unused_value_func(node, body, to_keep)
+                delete_unused_value_func(node, body, chunk_var)

         if node_idx in chunk_ends:
-            body.append(_gen_loop_5_final(node.name, to_keep))
-            to_keep = []
+            body.append(_gen_loop_end(node.name, chunk_var))
+            chunk_var = []
             within_chunk_region = False

         node_idx += 1
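Read together, the logic above opens the generated loop at the first node of the chunk region, indents every node emitted inside the region, rewrites the first node's call to read from chunk_tensor, and closes the loop at the last node. A self-contained sketch of that flow using plain strings instead of FX nodes (the helper name and the sample lines are illustrative, not part of the commit):

    def emit_with_chunk_sketch(lines, chunk_start, chunk_end, input_var, output_var):
        body = []
        for idx, line in enumerate(lines):
            if idx == chunk_start:
                # corresponds to body.append(_gen_loop_start(chunk_var[0]))
                body.append("chunk_result = []; chunk_size = 2\n"
                            f"for gen_loop_idx in range(0, {input_var}.shape[0], chunk_size):\n"
                            f"    chunk_tensor = {input_var}[gen_loop_idx:gen_loop_idx + chunk_size, :]\n")
                # corresponds to the body[-1].replace(...) on the region's first node
                line = line.replace(f"({input_var})", "(chunk_tensor)")
            if chunk_start <= idx <= chunk_end:
                body.append("    " + line)    # indent emitted nodes into the generated loop
            else:
                body.append(line)
            if idx == chunk_end:
                # corresponds to body.append(_gen_loop_end(node.name, chunk_var))
                body.append(f"    chunk_result.append({output_var})\n"
                            f"{output_var} = torch.cat(chunk_result, dim=0); chunk_result = None\n")
        return "".join(body)

    print(emit_with_chunk_sketch(["x = linear0(node1)\n", "y = relu(x)\n", "z = linear1(y)\n"],
                                 chunk_start=0, chunk_end=1, input_var="node1", output_var="y"))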
@@ -580,9 +481,7 @@ if CODEGEN_AVAILABLE:
                 body.append('\n')
                 return
             nodes_to_delete = user_to_last_uses.get(user, [])
-            for n in nodes_to_delete:
-                if n.name in to_keep:
-                    nodes_to_delete.remove(n)
+            nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep]
             if len(nodes_to_delete):
                 to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
                 body.append(f'; {to_delete_str}\n')
@@ -9,60 +9,39 @@ import colossalai
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
 from colossalai.fx.graph_module import ColoGraphModule
-try:
-    from chunk_codegen import ChunkCodeGen
-    with_codegen = True
-except:
-    # fall back to older pytorch version
-    from chunk_codegen import python_code_with_activation_checkpoint
-    with_codegen = False
+from evoformer.evoformer import evoformer_base
+from chunk_codegen import ChunkCodeGen
+with_codegen = True


-class MyNet(torch.nn.Module):
-
-    def __init__(self) -> None:
-        super().__init__()
-        self.linear0 = torch.nn.Linear(4, 4)
-        self.linear1 = torch.nn.Linear(4, 4)
-        self.linear2 = torch.nn.Linear(4, 4)
-        self.linear3 = torch.nn.Linear(4, 4)
-        self.linear4 = torch.nn.Linear(4, 4)
-        self.linear5 = torch.nn.Linear(4, 4)
-        self.linear6 = torch.nn.Linear(4, 4)
-
-    def forward(self, x):
-        x = self.linear0(x)
-        x = self.linear1(x)
-        x = self.linear2(x)
-        x = self.linear3(x)
-        x = self.linear4(x)
-        x = self.linear5(x)
-        x = self.linear6(x)
-        return x
-
-
 def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule) -> bool:
     for m_p, gm_p in zip(m.parameters(), gm.parameters()):
-        if not torch.allclose(m_p.grad, gm_p.grad):
+        if m_p.grad is not None and not torch.allclose(m_p.grad, gm_p.grad):
             return False
     return True


-def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.Tensor):
+def _is_all_param_close(m: torch.nn.Module, gm: GraphModule) -> bool:
+    for m_p, gm_p in zip(m.parameters(), gm.parameters()):
+        if m_p.grad is not None and not torch.allclose(m_p.data, gm_p.data):
+            return False
+    return True
+
+
+def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
     # test forward
-    non_fx_out = model(data)
-    fx_out = gm(data)
-    print(non_fx_out.shape, fx_out.shape)
-    assert torch.equal(non_fx_out, fx_out), "fx_out doesn't comply with original output"
+    non_fx_out = model(node.clone(), pair.clone())
+    fx_out = gm(node.clone(), pair.clone())
+    assert torch.equal(non_fx_out[0], fx_out[0]), "fx_out doesn't comply with original output"
+    assert torch.equal(non_fx_out[1], fx_out[1]), "fx_out doesn't comply with original output"

     # test barckward
-    loss0 = non_fx_out.sum()
-    loss0.backward()
-    loss1 = fx_out.sum()
-    loss1.backward()
-    assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one"
+    # loss0 = non_fx_out[0].sum() + non_fx_out[1].sum()
+    # loss0.backward()
+    # loss1 = fx_out[0].sum() + fx_out[1].sum()
+    # loss1.backward()
+    # assert _is_all_param_close(model, gm)
+    # assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one"


 def _run_offload_codegen(rank):
@@ -70,30 +49,22 @@ def _run_offload_codegen(rank):
     colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

     # build model and input
-    model = MyNet().cuda()
-    data = torch.rand(4, 4).cuda()
+    model = evoformer_base().cuda()
+    node = torch.randn(1, 16, 32, 256).cuda()
+    pair = torch.randn(1, 32, 32, 128).cuda()

     # trace the module and replace codegen
     tracer = ColoTracer(trace_act_ckpt=True)
     graph = tracer.trace(model)
-    codegen = ChunkCodeGen()
-    graph.set_codegen(codegen)
+    # codegen = ChunkCodeGen()
+    # graph.set_codegen(codegen)

-    # annotate the activation offload part
-    # also annotate the activation_checkpoint so we could test both types
-    # of input offload
-    for node in graph.nodes:
-        if node.name == "linear0":
-            setattr(node, "activation_offload", [0, True, False])
-        if node.name == "linear1":
-            setattr(node, "activation_offload", [0, True, False])
-        # if node.name == "linear2":
-        #     setattr(node, "activation_offload", [1, True, True])
-        # if node.name == "linear4":
-        #     setattr(node, "activation_offload", [2, False, True])
-        # if node.name == "linear5":
-        #     setattr(node, "activation_checkpoint", [0])
-        #     setattr(node, "activation_offload", True)
+    # annotate the chunk part
+    # for node in graph.nodes:
+    #     if node.name == "linear0":
+    #         setattr(node, "activation_offload", [0, True, False])
+    #     if node.name == "linear1":
+    #         setattr(node, "activation_offload", [0, True, False])

     gm = ColoGraphModule(copy.deepcopy(model), graph)
     gm.recompile()
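The ChunkCodeGen wiring is commented out in this revision; once the codegen is ready it would presumably be re-enabled along the lines of the commented calls, roughly as below (a sketch only; the ColoTracer import path is assumed from the test's surrounding imports):

    import copy
    import torch
    from colossalai.fx import ColoTracer
    from colossalai.fx.graph_module import ColoGraphModule
    from chunk_codegen import ChunkCodeGen
    from evoformer.evoformer import evoformer_base

    model = evoformer_base().cuda()
    tracer = ColoTracer(trace_act_ckpt=True)
    graph = tracer.trace(model)
    graph.set_codegen(ChunkCodeGen())    # currently disabled in the test
    gm = ColoGraphModule(copy.deepcopy(model), graph)
    gm.recompile()
    print(graph.python_code("self").src)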
@@ -102,7 +73,7 @@ def _run_offload_codegen(rank):
     code = graph.python_code("self").src
     print(code)

-    _test_fwd_and_bwd(model, gm, data)
+    _test_fwd_and_bwd(model, gm, node, pair)
     gpc.destroy()
@@ -28,7 +28,7 @@ class Evoformer(nn.Module):
         super(Evoformer, self).__init__()

         self.blocks = nn.ModuleList()
-        for _ in range(3):
+        for _ in range(1):
             self.blocks.append(EvoformerBlock(d_node, d_pair))

     def forward(self, node, pair):
@@ -36,6 +36,11 @@ class Evoformer(nn.Module):
             node, pair = b(node, pair)
         return node, pair


+def evoformer_tiny():
+    return Evoformer(d_node=64, d_pair=32)
+
+
 def evoformer_base():
     return Evoformer(d_node=256, d_pair=128)
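A quick smoke test of the new evoformer_tiny() factory might look like the following; the input layout mirrors the shapes the test uses for evoformer_base, scaled down to d_node=64 and d_pair=32 (the shapes here are an assumption, not part of the commit):

    import torch
    from evoformer.evoformer import evoformer_tiny

    model = evoformer_tiny().cuda()
    node = torch.randn(1, 16, 32, 64).cuda()    # last dim = d_node
    pair = torch.randn(1, 32, 32, 32).cuda()    # last dim = d_pair
    out_node, out_pair = model(node, pair)
    print(out_node.shape, out_pair.shape)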
@@ -8,7 +8,7 @@ def bias_sigmod_ele(y, bias, z):

 def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
                      residual: torch.Tensor, prob: float) -> torch.Tensor:
-    out = (x + bias) * F.dropout(dropmask, p=prob, training=True)
+    out = (x + bias) * F.dropout(dropmask, p=prob, training=False)
     out = residual + out
     return out
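With training=False, F.dropout is a no-op, so bias_dropout_add becomes deterministic: the all-ones dropout_mask built by the callers passes through unchanged and no 1/(1-p) rescaling is applied. A minimal check:

    import torch
    import torch.nn.functional as F

    mask = torch.ones(2, 3)
    # identity when training=False; would randomly zero and rescale when training=True
    assert torch.equal(F.dropout(mask, p=0.25, training=False), mask)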
@@ -45,7 +45,7 @@ class MSARowAttentionWithPairBias(nn.Module):
         # b = rearrange(b, 'b q k h -> b h q k')

         M = self.attention(M, b)
-        dropout_mask = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)
+        dropout_mask = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype)

         return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop)
@@ -51,7 +51,7 @@ class TriangleMultiplicationOutgoing(nn.Module):

         ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act)
         ab = self.output_projection(self.layernorm2(ab))
-        dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
+        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
         return bias_ele_dropout_residual(ab,
                                          self.output_bias,
                                          g,
@@ -97,7 +97,7 @@ class TriangleMultiplicationIncoming(nn.Module):

         ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)
         ab = self.output_projection(self.layernorm2(ab))
-        dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
+        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
         return bias_ele_dropout_residual(ab,
                                          self.output_bias,
                                          g,
@@ -134,7 +134,7 @@ class TriangleAttentionStartingNode(nn.Module):

         Z = self.attention(Z, b)

-        dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
+        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
         return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
@@ -168,7 +168,7 @@ class TriangleAttentionEndingNode(nn.Module):
         Z = self.attention(Z, b)

         Z = Z.transpose(-2, -3)
-        dropout_mask = torch.ones_like(Z[:, :, 0:1, :], device=Z.device, dtype=Z.dtype)
+        dropout_mask = torch.ones_like(Z[:, :, 0:1, :]).to(Z.device).to(Z.dtype)
         return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
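The dropout_mask rewrite in the hunks above is value-preserving: torch.ones_like already inherits the input's device and dtype, so chaining .to(device).to(dtype) produces the same tensor as the keyword form, at worst adding a no-op cast; presumably the chained form is simply friendlier to the tracer. A minimal equivalence check:

    import torch

    M = torch.randn(1, 4, 8, 16)
    a = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)
    b = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype)
    assert torch.equal(a, b)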