align evoformer

pull/2364/head
oahzxl 2022-11-02 15:49:25 +08:00
parent 86f2a31474
commit 820ea4d056
6 changed files with 66 additions and 191 deletions

View File

@ -1,5 +1,6 @@
import colossalai
import torch
import copy
from typing import List, Callable, Any, Tuple, Dict, Iterable
try:
@ -17,74 +18,18 @@ else:
__all__ = ['python_code_with_activation_checkpoint']
def _gen_saved_tensors_hooks():
"""
Generate saved tensors hooks
"""
pack_hook = """def pack_hook_input(self, x):
if getattr(x, "offload", False):
return (x.device, x.cpu())
else:
return x
def pack_hook_no_input(self, x):
if getattr(x, "offload", True):
return (x.device, x.cpu())
else:
return x
"""
unpack_hook = """def unpack_hook(self, packed):
if isinstance(packed, tuple):
device, tensor = packed
return tensor.to(device)
else:
return packed
"""
return pack_hook, unpack_hook
def _gen_loop_5(to_keep):
context = "chunk_result = []\nfor gen_loop_idx in range(4):\n"
context += " chunk_tensor = " + to_keep + "[gen_loop_idx, :]\n"
def _gen_loop_start(to_keep, chunk_size=2):
context = "chunk_result = []; chunk_size = %d\nfor gen_loop_idx in range(0, %s.shape[0], chunk_size):\n" % (chunk_size, to_keep[0])
context += " chunk_tensor = " + to_keep + "[gen_loop_idx:gen_loop_idx + chunk_size, :]\n"
return context
def _gen_loop_5_final(final_name, to_keep):
def _gen_loop_end(final_name, to_keep):
context = " chunk_result.append(" + final_name + ")\n"
context += "chunk_result = torch.cat(chunk_result, dim=0); " + to_keep[0] + " = None\n"
context += final_name + " = chunk_result; chunk_result = None\n"
return context
def _gen_save_tensors_hooks_context(offload_input=True) -> str:
"""Generate customized saved_tensors_hooks
Args:
offload_input (bool, optional): whether we need offload input, if offload_input=False,
we will use self.pack_hook_no_input instead. Defaults to True.
Returns:
str: generated context
"""
if offload_input:
context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):\n"
else:
context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):\n"
return context
def _gen_save_on_cpu_context():
"""
Generate save on cpu context
"""
context = "with torch.autograd.graph.save_on_cpu(pin_memory=True):\n"
return context
def _find_input_and_output_nodes(nodes: List[Node]):
"""
@ -112,49 +57,6 @@ def _find_input_and_output_nodes(nodes: List[Node]):
return input_nodes, output_nodes
def _find_ckpt_regions(nodes: List[Node]):
"""
Find the checkpoint regions given a list of consecutive nodes. The outputs will be list
of tuples, each tuple is in the form of (start_index, end_index).
"""
ckpt_nodes = []
ckpt_regions = []
start = -1
end = -1
current_region = None
for idx, node in enumerate(nodes):
if hasattr(node, 'activation_checkpoint'):
act_ckpt_label = node.activation_checkpoint
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
if current_region is None:
current_region = act_ckpt_label
start = idx
# if activation checkpoint has changed
# we restart the tracking
# e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2]
if act_ckpt_label != current_region:
assert start != -1
ckpt_regions.append((start, idx - 1))
current_region = act_ckpt_label
start = idx
end = -1
elif current_region is not None and not hasattr(node, 'activation_checkpoint'):
# used to check the case below
# node ckpt states = [ckpt, ckpt, non-ckpt]
end = idx - 1
assert start != -1 and end != -1
ckpt_regions.append((start, end))
start = end = -1
current_region = None
else:
pass
return ckpt_regions
def _find_offload_regions(nodes: List[Node]):
"""This function is to find the offload regions
In pofo algorithm, during annotation, we will annotate the offload region with the
@ -400,12 +302,9 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
emit_node_func: function to emit node
delete_unused_value_func: function to remove the unused value
"""
ckpt_regions = _find_nested_ckpt_regions(nodes, 0)
start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions]
# find the offload regions
chunk_regions, chunk_labels = _find_offload_regions(nodes)
chunk_regions = [(1, 4)]
chunk_starts = [item[0] for item in chunk_regions]
chunk_ends = [item[1] for item in chunk_regions]
chunk_inputs = []
@ -424,7 +323,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
# this flag is to prevent repeated insert of save tensors
# hooks definition in ckpt_func
node_idx = 0
to_keep = []
chunk_var = []
while node_idx < len(node_list):
# break if we finish the processing all the nodes
if node_idx >= len(node_list):
@ -435,28 +334,30 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
node = node_list[node_idx]
if node_idx in chunk_starts:
# save chunk input var, dont delete it
to_keep.extend(node.args[0].name)
within_chunk_region = True
# add for loop
body.append(_gen_loop_5(to_keep[0]))
# change first node's input to new chunked var
node_args = list(node.args)
node_args[0] = 'chunk_tensor'
# save chunk input var, dont delete it
chunk_var.append(node.args[0].name)
# add for loop
body.append(_gen_loop_start(chunk_var[0]))
if within_chunk_region:
emit_node_func(node, body)
# replace input var with chunk var
if node_idx in chunk_starts:
body[-1] = body[-1].replace("("+ chunk_var[0] +")", '(chunk_tensor)')
body[-1] = ' ' + body[-1]
delete_unused_value_func(node, body, to_keep)
delete_unused_value_func(node, body, chunk_var)
else:
emit_node_func(node, body)
if node_idx not in chunk_inputs:
delete_unused_value_func(node, body, to_keep)
delete_unused_value_func(node, body, chunk_var)
if node_idx in chunk_ends:
body.append(_gen_loop_5_final(node.name, to_keep))
to_keep = []
body.append(_gen_loop_end(node.name, chunk_var))
chunk_var = []
within_chunk_region = False
node_idx += 1
@ -580,9 +481,7 @@ if CODEGEN_AVAILABLE:
body.append('\n')
return
nodes_to_delete = user_to_last_uses.get(user, [])
for n in nodes_to_delete:
if n.name in to_keep:
nodes_to_delete.remove(n)
nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep]
if len(nodes_to_delete):
to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
body.append(f'; {to_delete_str}\n')

View File

@ -9,60 +9,39 @@ import colossalai
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
try:
from chunk_codegen import ChunkCodeGen
with_codegen = True
except:
# fall back to older pytorch version
from chunk_codegen import python_code_with_activation_checkpoint
with_codegen = False
class MyNet(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear0 = torch.nn.Linear(4, 4)
self.linear1 = torch.nn.Linear(4, 4)
self.linear2 = torch.nn.Linear(4, 4)
self.linear3 = torch.nn.Linear(4, 4)
self.linear4 = torch.nn.Linear(4, 4)
self.linear5 = torch.nn.Linear(4, 4)
self.linear6 = torch.nn.Linear(4, 4)
def forward(self, x):
x = self.linear0(x)
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
x = self.linear4(x)
x = self.linear5(x)
x = self.linear6(x)
return x
from evoformer.evoformer import evoformer_base
from chunk_codegen import ChunkCodeGen
with_codegen = True
def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule) -> bool:
for m_p, gm_p in zip(m.parameters(), gm.parameters()):
if not torch.allclose(m_p.grad, gm_p.grad):
if m_p.grad is not None and not torch.allclose(m_p.grad, gm_p.grad):
return False
return True
def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.Tensor):
def _is_all_param_close(m: torch.nn.Module, gm: GraphModule) -> bool:
for m_p, gm_p in zip(m.parameters(), gm.parameters()):
if m_p.grad is not None and not torch.allclose(m_p.data, gm_p.data):
return False
return True
def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
# test forward
non_fx_out = model(data)
fx_out = gm(data)
print(non_fx_out.shape, fx_out.shape)
assert torch.equal(non_fx_out, fx_out), "fx_out doesn't comply with original output"
non_fx_out = model(node.clone(), pair.clone())
fx_out = gm(node.clone(), pair.clone())
assert torch.equal(non_fx_out[0], fx_out[0]), "fx_out doesn't comply with original output"
assert torch.equal(non_fx_out[1], fx_out[1]), "fx_out doesn't comply with original output"
# test barckward
loss0 = non_fx_out.sum()
loss0.backward()
loss1 = fx_out.sum()
loss1.backward()
assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one"
# loss0 = non_fx_out[0].sum() + non_fx_out[1].sum()
# loss0.backward()
# loss1 = fx_out[0].sum() + fx_out[1].sum()
# loss1.backward()
# assert _is_all_param_close(model, gm)
# assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one"
def _run_offload_codegen(rank):
@ -70,30 +49,22 @@ def _run_offload_codegen(rank):
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
# build model and input
model = MyNet().cuda()
data = torch.rand(4, 4).cuda()
model = evoformer_base().cuda()
node = torch.randn(1, 16, 32, 256).cuda()
pair = torch.randn(1, 32, 32, 128).cuda()
# trace the module and replace codegen
tracer = ColoTracer(trace_act_ckpt=True)
graph = tracer.trace(model)
codegen = ChunkCodeGen()
graph.set_codegen(codegen)
# codegen = ChunkCodeGen()
# graph.set_codegen(codegen)
# annotate the activation offload part
# also annotate the activation_checkpoint so we could test both types
# of input offload
for node in graph.nodes:
if node.name == "linear0":
setattr(node, "activation_offload", [0, True, False])
if node.name == "linear1":
setattr(node, "activation_offload", [0, True, False])
# if node.name == "linear2":
# setattr(node, "activation_offload", [1, True, True])
# if node.name == "linear4":
# setattr(node, "activation_offload", [2, False, True])
# if node.name == "linear5":
# setattr(node, "activation_checkpoint", [0])
# setattr(node, "activation_offload", True)
# annotate the chunk part
# for node in graph.nodes:
# if node.name == "linear0":
# setattr(node, "activation_offload", [0, True, False])
# if node.name == "linear1":
# setattr(node, "activation_offload", [0, True, False])
gm = ColoGraphModule(copy.deepcopy(model), graph)
gm.recompile()
@ -102,7 +73,7 @@ def _run_offload_codegen(rank):
code = graph.python_code("self").src
print(code)
_test_fwd_and_bwd(model, gm, data)
_test_fwd_and_bwd(model, gm, node, pair)
gpc.destroy()

View File

@ -28,7 +28,7 @@ class Evoformer(nn.Module):
super(Evoformer, self).__init__()
self.blocks = nn.ModuleList()
for _ in range(3):
for _ in range(1):
self.blocks.append(EvoformerBlock(d_node, d_pair))
def forward(self, node, pair):
@ -36,6 +36,11 @@ class Evoformer(nn.Module):
node, pair = b(node, pair)
return node, pair
def evoformer_tiny():
return Evoformer(d_node=64, d_pair=32)
def evoformer_base():
return Evoformer(d_node=256, d_pair=128)

View File

@ -8,7 +8,7 @@ def bias_sigmod_ele(y, bias, z):
def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
residual: torch.Tensor, prob: float) -> torch.Tensor:
out = (x + bias) * F.dropout(dropmask, p=prob, training=True)
out = (x + bias) * F.dropout(dropmask, p=prob, training=False)
out = residual + out
return out

View File

@ -45,7 +45,7 @@ class MSARowAttentionWithPairBias(nn.Module):
# b = rearrange(b, 'b q k h -> b h q k')
M = self.attention(M, b)
dropout_mask = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)
dropout_mask = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype)
return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop)

View File

@ -51,7 +51,7 @@ class TriangleMultiplicationOutgoing(nn.Module):
ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act)
ab = self.output_projection(self.layernorm2(ab))
dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
return bias_ele_dropout_residual(ab,
self.output_bias,
g,
@ -97,7 +97,7 @@ class TriangleMultiplicationIncoming(nn.Module):
ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)
ab = self.output_projection(self.layernorm2(ab))
dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
return bias_ele_dropout_residual(ab,
self.output_bias,
g,
@ -134,7 +134,7 @@ class TriangleAttentionStartingNode(nn.Module):
Z = self.attention(Z, b)
dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
@ -168,7 +168,7 @@ class TriangleAttentionEndingNode(nn.Module):
Z = self.attention(Z, b)
Z = Z.transpose(-2, -3)
dropout_mask = torch.ones_like(Z[:, :, 0:1, :], device=Z.device, dtype=Z.dtype)
dropout_mask = torch.ones_like(Z[:, :, 0:1, :]).to(Z.device).to(Z.dtype)
return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)