align evoformer

pull/2364/head
oahzxl 2022-11-02 15:49:25 +08:00
parent 86f2a31474
commit 820ea4d056
6 changed files with 66 additions and 191 deletions

File 1/6: chunk_codegen.py

@@ -1,5 +1,6 @@
 import colossalai
 import torch
+import copy
 from typing import List, Callable, Any, Tuple, Dict, Iterable

 try:
@@ -17,74 +18,18 @@ else:

 __all__ = ['python_code_with_activation_checkpoint']


-def _gen_saved_tensors_hooks():
-    """
-    Generate saved tensors hooks
-    """
-
-    pack_hook = """def pack_hook_input(self, x):
-    if getattr(x, "offload", False):
-        return (x.device, x.cpu())
-    else:
-        return x
-
-def pack_hook_no_input(self, x):
-    if getattr(x, "offload", True):
-        return (x.device, x.cpu())
-    else:
-        return x
-"""
-
-    unpack_hook = """def unpack_hook(self, packed):
-    if isinstance(packed, tuple):
-        device, tensor = packed
-        return tensor.to(device)
-    else:
-        return packed
-"""
-
-    return pack_hook, unpack_hook
-
-
-def _gen_loop_5(to_keep):
-    context = "chunk_result = []\nfor gen_loop_idx in range(4):\n"
-    context += "    chunk_tensor = " + to_keep + "[gen_loop_idx, :]\n"
+def _gen_loop_start(to_keep, chunk_size=2):
+    context = "chunk_result = []; chunk_size = %d\nfor gen_loop_idx in range(0, %s.shape[0], chunk_size):\n" % (chunk_size, to_keep[0])
+    context += "    chunk_tensor = " + to_keep + "[gen_loop_idx:gen_loop_idx + chunk_size, :]\n"
     return context


-def _gen_loop_5_final(final_name, to_keep):
+def _gen_loop_end(final_name, to_keep):
     context = "    chunk_result.append(" + final_name + ")\n"
     context += "chunk_result = torch.cat(chunk_result, dim=0); " + to_keep[0] + " = None\n"
     context += final_name + " = chunk_result; chunk_result = None\n"
     return context


-def _gen_save_tensors_hooks_context(offload_input=True) -> str:
-    """Generate customized saved_tensors_hooks
-
-    Args:
-        offload_input (bool, optional): whether we need to offload the input; if offload_input=False,
-        we use self.pack_hook_no_input instead. Defaults to True.
-
-    Returns:
-        str: generated context
-    """
-    if offload_input:
-        context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):\n"
-    else:
-        context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):\n"
-    return context
-
-
-def _gen_save_on_cpu_context():
-    """
-    Generate save on cpu context
-    """
-    context = "with torch.autograd.graph.save_on_cpu(pin_memory=True):\n"
-    return context
-
-
 def _find_input_and_output_nodes(nodes: List[Node]):
     """
@@ -112,49 +57,6 @@ def _find_input_and_output_nodes(nodes: List[Node]):
     return input_nodes, output_nodes


-def _find_ckpt_regions(nodes: List[Node]):
-    """
-    Find the checkpoint regions given a list of consecutive nodes. The output is a list
-    of tuples, each in the form (start_index, end_index).
-    """
-    ckpt_nodes = []
-    ckpt_regions = []
-    start = -1
-    end = -1
-    current_region = None
-
-    for idx, node in enumerate(nodes):
-        if hasattr(node, 'activation_checkpoint'):
-            act_ckpt_label = node.activation_checkpoint
-
-            # this activation checkpoint label is not set yet
-            # meaning this is the first node of the activation ckpt region
-            if current_region is None:
-                current_region = act_ckpt_label
-                start = idx
-
-            # if the activation checkpoint label has changed
-            # we restart the tracking
-            # e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2]
-            if act_ckpt_label != current_region:
-                assert start != -1
-                ckpt_regions.append((start, idx - 1))
-                current_region = act_ckpt_label
-                start = idx
-                end = -1
-        elif current_region is not None and not hasattr(node, 'activation_checkpoint'):
-            # used to check the case below
-            # node ckpt states = [ckpt, ckpt, non-ckpt]
-            end = idx - 1
-            assert start != -1 and end != -1
-            ckpt_regions.append((start, end))
-            start = end = -1
-            current_region = None
-        else:
-            pass
-    return ckpt_regions
-
-
 def _find_offload_regions(nodes: List[Node]):
     """This function is to find the offload regions.
     In the pofo algorithm, during annotation, we annotate the offload region with the
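The deleted _find_ckpt_regions above implemented a small state machine over node annotations, emitting one (start_index, end_index) tuple per contiguous run of equal activation_checkpoint labels. A minimal standalone sketch of the same idea over plain label lists (a hypothetical helper, not part of this diff; unlike the original it also closes a region that runs to the end of the list):

    def find_regions(labels):
        # labels: per-node checkpoint labels, None for unannotated nodes
        regions, start, current = [], -1, None
        for idx, label in enumerate(labels):
            if label is not None and current is None:
                start, current = idx, label          # first node of a region
            elif label is not None and label != current:
                regions.append((start, idx - 1))     # label changed: restart tracking
                start, current = idx, label
            elif label is None and current is not None:
                regions.append((start, idx - 1))     # region ended on a plain node
                start, current = -1, None
        if current is not None:
            regions.append((start, len(labels) - 1))
        return regions

    assert find_regions([None, 'ckpt1', 'ckpt2', 'ckpt2', None]) == [(1, 1), (2, 3)]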
@@ -400,12 +302,9 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
         emit_node_func: function to emit node
         delete_unused_value_func: function to remove the unused value
     """
-    ckpt_regions = _find_nested_ckpt_regions(nodes, 0)
-    start_idx = [item[0] for item in ckpt_regions]
-    end_idx = [item[1] for item in ckpt_regions]

     # find the offload regions
-    chunk_regions, chunk_labels = _find_offload_regions(nodes)
+    chunk_regions = [(1, 4)]
     chunk_starts = [item[0] for item in chunk_regions]
     chunk_ends = [item[1] for item in chunk_regions]
     chunk_inputs = []
@@ -424,7 +323,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
     # this flag is to prevent repeated insert of save tensors
     # hooks definition in ckpt_func
     node_idx = 0
-    to_keep = []
+    chunk_var = []
     while node_idx < len(node_list):
         # break if we finish processing all the nodes
         if node_idx >= len(node_list):
@@ -435,28 +334,30 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
         node = node_list[node_idx]

         if node_idx in chunk_starts:
-            # save chunk input var, don't delete it
-            to_keep.extend(node.args[0].name)
             within_chunk_region = True

-            # add for loop
-            body.append(_gen_loop_5(to_keep[0]))
-            # change first node's input to new chunked var
-            node_args = list(node.args)
-            node_args[0] = 'chunk_tensor'
+            # save chunk input var, don't delete it
+            chunk_var.append(node.args[0].name)
+            # add for loop
+            body.append(_gen_loop_start(chunk_var[0]))

         if within_chunk_region:
             emit_node_func(node, body)
+            # replace input var with chunk var
+            if node_idx in chunk_starts:
+                body[-1] = body[-1].replace("(" + chunk_var[0] + ")", '(chunk_tensor)')
             body[-1] = '    ' + body[-1]
-            delete_unused_value_func(node, body, to_keep)
+            delete_unused_value_func(node, body, chunk_var)
         else:
             emit_node_func(node, body)
             if node_idx not in chunk_inputs:
-                delete_unused_value_func(node, body, to_keep)
+                delete_unused_value_func(node, body, chunk_var)

         if node_idx in chunk_ends:
-            body.append(_gen_loop_5_final(node.name, to_keep))
-            to_keep = []
+            body.append(_gen_loop_end(node.name, chunk_var))
+            chunk_var = []
             within_chunk_region = False

         node_idx += 1
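Combined with the hard-coded chunk_regions = [(1, 4)] above, this emit loop wraps graph nodes 1 through 4 in the generated chunk loop, rewrites the first node's input to chunk_tensor, indents the region body, and keeps chunk_var alive until the final torch.cat. For a toy stack of linears, the source printed by graph.python_code("self").src would read roughly as follows (a sketch; the linear_* node names are hypothetical):

    def forward(self, x):
        chunk_result = []; chunk_size = 2
        for gen_loop_idx in range(0, x.shape[0], chunk_size):
            chunk_tensor = x[gen_loop_idx:gen_loop_idx + chunk_size, :]
            linear_1 = self.linear_1(chunk_tensor)
            linear_2 = self.linear_2(linear_1);  linear_1 = None
            linear_3 = self.linear_3(linear_2);  linear_2 = None
            linear_4 = self.linear_4(linear_3);  linear_3 = None
            chunk_result.append(linear_4)
        chunk_result = torch.cat(chunk_result, dim=0); x = None
        linear_4 = chunk_result; chunk_result = None
        return linear_4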
@@ -580,9 +481,7 @@ if CODEGEN_AVAILABLE:
                 body.append('\n')
                 return

             nodes_to_delete = user_to_last_uses.get(user, [])
-            for n in nodes_to_delete:
-                if n.name in to_keep:
-                    nodes_to_delete.remove(n)
+            nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep]
             if len(nodes_to_delete):
                 to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
                 body.append(f'; {to_delete_str}\n')
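The comprehension above is not just a style change: the old loop removed elements from nodes_to_delete while iterating over it, which silently skips the element that slides into the removed slot. A quick demonstration:

    # Removing from a list while iterating skips the element that
    # slides into the removed item's position:
    names = ['a', 'keep', 'keep', 'b']
    for n in names:
        if n == 'keep':
            names.remove(n)
    print(names)  # ['a', 'keep', 'b'] -- the second 'keep' survives

    # The filtered comprehension visits every element exactly once:
    names = ['a', 'keep', 'keep', 'b']
    names = [n for n in names if n != 'keep']
    print(names)  # ['a', 'b']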

File 2/6: chunk codegen test script

@@ -9,60 +9,39 @@ import colossalai
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
 from colossalai.fx.graph_module import ColoGraphModule
+from evoformer.evoformer import evoformer_base

-try:
-    from chunk_codegen import ChunkCodeGen
-    with_codegen = True
-except:
-    # fall back to older pytorch version
-    from chunk_codegen import python_code_with_activation_checkpoint
-    with_codegen = False
-
-
-class MyNet(torch.nn.Module):
-
-    def __init__(self) -> None:
-        super().__init__()
-        self.linear0 = torch.nn.Linear(4, 4)
-        self.linear1 = torch.nn.Linear(4, 4)
-        self.linear2 = torch.nn.Linear(4, 4)
-        self.linear3 = torch.nn.Linear(4, 4)
-        self.linear4 = torch.nn.Linear(4, 4)
-        self.linear5 = torch.nn.Linear(4, 4)
-        self.linear6 = torch.nn.Linear(4, 4)
-
-    def forward(self, x):
-        x = self.linear0(x)
-        x = self.linear1(x)
-        x = self.linear2(x)
-        x = self.linear3(x)
-        x = self.linear4(x)
-        x = self.linear5(x)
-        x = self.linear6(x)
-        return x
+from chunk_codegen import ChunkCodeGen
+with_codegen = True


 def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule) -> bool:
     for m_p, gm_p in zip(m.parameters(), gm.parameters()):
-        if not torch.allclose(m_p.grad, gm_p.grad):
+        if m_p.grad is not None and not torch.allclose(m_p.grad, gm_p.grad):
             return False
     return True


-def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.Tensor):
+def _is_all_param_close(m: torch.nn.Module, gm: GraphModule) -> bool:
+    for m_p, gm_p in zip(m.parameters(), gm.parameters()):
+        if m_p.grad is not None and not torch.allclose(m_p.data, gm_p.data):
+            return False
+    return True
+
+
+def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
     # test forward
-    non_fx_out = model(data)
-    fx_out = gm(data)
-    print(non_fx_out.shape, fx_out.shape)
-    assert torch.equal(non_fx_out, fx_out), "fx_out doesn't comply with original output"
+    non_fx_out = model(node.clone(), pair.clone())
+    fx_out = gm(node.clone(), pair.clone())
+    assert torch.equal(non_fx_out[0], fx_out[0]), "fx_out doesn't comply with original output"
+    assert torch.equal(non_fx_out[1], fx_out[1]), "fx_out doesn't comply with original output"

     # test backward
-    loss0 = non_fx_out.sum()
-    loss0.backward()
-    loss1 = fx_out.sum()
-    loss1.backward()
-    assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one"
+    # loss0 = non_fx_out[0].sum() + non_fx_out[1].sum()
+    # loss0.backward()
+    # loss1 = fx_out[0].sum() + fx_out[1].sum()
+    # loss1.backward()
+    # assert _is_all_param_close(model, gm)
+    # assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one"
 def _run_offload_codegen(rank):

@@ -70,30 +49,22 @@ def _run_offload_codegen(rank):
     colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

     # build model and input
-    model = MyNet().cuda()
-    data = torch.rand(4, 4).cuda()
+    model = evoformer_base().cuda()
+    node = torch.randn(1, 16, 32, 256).cuda()
+    pair = torch.randn(1, 32, 32, 128).cuda()

     # trace the module and replace codegen
     tracer = ColoTracer(trace_act_ckpt=True)
     graph = tracer.trace(model)
-    codegen = ChunkCodeGen()
-    graph.set_codegen(codegen)
+    # codegen = ChunkCodeGen()
+    # graph.set_codegen(codegen)

-    # annotate the activation offload part
-    # also annotate the activation_checkpoint so we could test both types
-    # of input offload
-    for node in graph.nodes:
-        if node.name == "linear0":
-            setattr(node, "activation_offload", [0, True, False])
-        if node.name == "linear1":
-            setattr(node, "activation_offload", [0, True, False])
-        # if node.name == "linear2":
-        #     setattr(node, "activation_offload", [1, True, True])
-        # if node.name == "linear4":
-        #     setattr(node, "activation_offload", [2, False, True])
-        # if node.name == "linear5":
-        #     setattr(node, "activation_checkpoint", [0])
-        #     setattr(node, "activation_offload", True)
+    # annotate the chunk part
+    # for node in graph.nodes:
+    #     if node.name == "linear0":
+    #         setattr(node, "activation_offload", [0, True, False])
+    #     if node.name == "linear1":
+    #         setattr(node, "activation_offload", [0, True, False])

     gm = ColoGraphModule(copy.deepcopy(model), graph)
     gm.recompile()

@@ -102,7 +73,7 @@ def _run_offload_codegen(rank):
     code = graph.python_code("self").src
     print(code)

-    _test_fwd_and_bwd(model, gm, data)
+    _test_fwd_and_bwd(model, gm, node, pair)
     gpc.destroy()

File 3/6: evoformer/evoformer.py

@@ -28,7 +28,7 @@ class Evoformer(nn.Module):
         super(Evoformer, self).__init__()

         self.blocks = nn.ModuleList()
-        for _ in range(3):
+        for _ in range(1):
             self.blocks.append(EvoformerBlock(d_node, d_pair))

     def forward(self, node, pair):

@@ -36,6 +36,11 @@ class Evoformer(nn.Module):
             node, pair = b(node, pair)
         return node, pair


+def evoformer_tiny():
+    return Evoformer(d_node=64, d_pair=32)
+
+
 def evoformer_base():
     return Evoformer(d_node=256, d_pair=128)
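The new evoformer_tiny variant shrinks the hidden sizes for quicker iteration (the block count is also cut to 1 above). A hypothetical smoke test, assuming, as the base-model inputs in the test script suggest, that the last dimensions of node and pair must equal d_node and d_pair:

    import torch
    from evoformer.evoformer import evoformer_tiny

    model = evoformer_tiny()
    node = torch.randn(1, 16, 32, 64)    # last dim = d_node = 64
    pair = torch.randn(1, 32, 32, 32)    # last dim = d_pair = 32
    out_node, out_pair = model(node, pair)
    # residual blocks should preserve both shapes
    assert out_node.shape == node.shape and out_pair.shape == pair.shape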

File 4/6: fused ops (bias_dropout_add)

@@ -8,7 +8,7 @@ def bias_sigmod_ele(y, bias, z):

 def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
                      residual: torch.Tensor, prob: float) -> torch.Tensor:
-    out = (x + bias) * F.dropout(dropmask, p=prob, training=True)
+    out = (x + bias) * F.dropout(dropmask, p=prob, training=False)
     out = residual + out
     return out
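Flipping training=True to training=False makes F.dropout an identity on dropmask, so the all-ones mask no longer injects randomness and the traced module's outputs can be compared against the eager model with torch.equal, as the test script above does. A small check of that property:

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 8)
    # training=False: dropout is the identity, output is reproducible
    assert torch.equal(F.dropout(x, p=0.25, training=False), x)
    # training=True: elements are randomly zeroed, the rest rescaled by 1/(1-p)
    y = F.dropout(x, p=0.25, training=True)
    print((y == 0).float().mean())  # roughly 0.25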

File 5/6: MSA row attention (MSARowAttentionWithPairBias)

@@ -45,7 +45,7 @@ class MSARowAttentionWithPairBias(nn.Module):
         # b = rearrange(b, 'b q k h -> b h q k')

         M = self.attention(M, b)
-        dropout_mask = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)
+        dropout_mask = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype)
         return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop)
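This dropout_mask change (and the four matching hunks in the triangle modules below) trades the device=/dtype= keyword arguments for chained .to() calls. In eager mode the two spellings are equivalent, since torch.ones_like already inherits device and dtype from its input; presumably the method-call form is friendlier to symbolic tracing, where M.device and M.dtype are attributes of an fx proxy rather than concrete values. A quick eager-mode equivalence check:

    import torch

    M = torch.randn(2, 4, 4, 8)
    a = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)  # old spelling
    b = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype)         # new spelling
    assert torch.equal(a, b) and a.dtype == b.dtype and a.device == b.device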

File 6/6: triangle multiplication and attention modules

@@ -51,7 +51,7 @@ class TriangleMultiplicationOutgoing(nn.Module):
         ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act)
         ab = self.output_projection(self.layernorm2(ab))
-        dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
+        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
         return bias_ele_dropout_residual(ab,
                                          self.output_bias,
                                          g,

@@ -97,7 +97,7 @@ class TriangleMultiplicationIncoming(nn.Module):
         ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)
         ab = self.output_projection(self.layernorm2(ab))
-        dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
+        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
         return bias_ele_dropout_residual(ab,
                                          self.output_bias,
                                          g,

@@ -134,7 +134,7 @@ class TriangleAttentionStartingNode(nn.Module):
         Z = self.attention(Z, b)
-        dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
+        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
         return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)

@@ -168,7 +168,7 @@ class TriangleAttentionEndingNode(nn.Module):
         Z = self.attention(Z, b)
         Z = Z.transpose(-2, -3)
-        dropout_mask = torch.ones_like(Z[:, :, 0:1, :], device=Z.device, dtype=Z.dtype)
+        dropout_mask = torch.ones_like(Z[:, :, 0:1, :]).to(Z.device).to(Z.dtype)
         return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)