mirror of https://github.com/hpcaitech/ColossalAI
align evoformer
parent
86f2a31474
commit
820ea4d056
143
chunk_codegen.py
143
chunk_codegen.py
|
@ -1,5 +1,6 @@
|
|||
import colossalai
|
||||
import torch
|
||||
import copy
|
||||
from typing import List, Callable, Any, Tuple, Dict, Iterable
|
||||
|
||||
try:
|
||||
|
@ -17,74 +18,18 @@ else:
|
|||
__all__ = ['python_code_with_activation_checkpoint']
|
||||
|
||||
|
||||
def _gen_saved_tensors_hooks():
|
||||
"""
|
||||
Generate saved tensors hooks
|
||||
"""
|
||||
|
||||
pack_hook = """def pack_hook_input(self, x):
|
||||
if getattr(x, "offload", False):
|
||||
return (x.device, x.cpu())
|
||||
else:
|
||||
return x
|
||||
|
||||
def pack_hook_no_input(self, x):
|
||||
if getattr(x, "offload", True):
|
||||
return (x.device, x.cpu())
|
||||
else:
|
||||
return x
|
||||
"""
|
||||
|
||||
unpack_hook = """def unpack_hook(self, packed):
|
||||
if isinstance(packed, tuple):
|
||||
device, tensor = packed
|
||||
return tensor.to(device)
|
||||
else:
|
||||
return packed
|
||||
"""
|
||||
|
||||
return pack_hook, unpack_hook
|
||||
|
||||
|
||||
def _gen_loop_5(to_keep):
|
||||
context = "chunk_result = []\nfor gen_loop_idx in range(4):\n"
|
||||
context += " chunk_tensor = " + to_keep + "[gen_loop_idx, :]\n"
|
||||
def _gen_loop_start(to_keep, chunk_size=2):
|
||||
context = "chunk_result = []; chunk_size = %d\nfor gen_loop_idx in range(0, %s.shape[0], chunk_size):\n" % (chunk_size, to_keep[0])
|
||||
context += " chunk_tensor = " + to_keep + "[gen_loop_idx:gen_loop_idx + chunk_size, :]\n"
|
||||
return context
|
||||
|
||||
|
||||
def _gen_loop_5_final(final_name, to_keep):
|
||||
def _gen_loop_end(final_name, to_keep):
|
||||
context = " chunk_result.append(" + final_name + ")\n"
|
||||
context += "chunk_result = torch.cat(chunk_result, dim=0); " + to_keep[0] + " = None\n"
|
||||
context += final_name + " = chunk_result; chunk_result = None\n"
|
||||
return context
|
||||
|
||||
|
||||
def _gen_save_tensors_hooks_context(offload_input=True) -> str:
|
||||
"""Generate customized saved_tensors_hooks
|
||||
|
||||
Args:
|
||||
offload_input (bool, optional): whether we need offload input, if offload_input=False,
|
||||
we will use self.pack_hook_no_input instead. Defaults to True.
|
||||
|
||||
Returns:
|
||||
str: generated context
|
||||
"""
|
||||
|
||||
if offload_input:
|
||||
context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):\n"
|
||||
else:
|
||||
context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):\n"
|
||||
return context
|
||||
|
||||
|
||||
def _gen_save_on_cpu_context():
|
||||
"""
|
||||
Generate save on cpu context
|
||||
"""
|
||||
|
||||
context = "with torch.autograd.graph.save_on_cpu(pin_memory=True):\n"
|
||||
return context
|
||||
|
||||
|
||||
def _find_input_and_output_nodes(nodes: List[Node]):
|
||||
"""
|
||||
|
@ -112,49 +57,6 @@ def _find_input_and_output_nodes(nodes: List[Node]):
|
|||
return input_nodes, output_nodes
|
||||
|
||||
|
||||
def _find_ckpt_regions(nodes: List[Node]):
|
||||
"""
|
||||
Find the checkpoint regions given a list of consecutive nodes. The outputs will be list
|
||||
of tuples, each tuple is in the form of (start_index, end_index).
|
||||
"""
|
||||
ckpt_nodes = []
|
||||
ckpt_regions = []
|
||||
start = -1
|
||||
end = -1
|
||||
current_region = None
|
||||
|
||||
for idx, node in enumerate(nodes):
|
||||
if hasattr(node, 'activation_checkpoint'):
|
||||
act_ckpt_label = node.activation_checkpoint
|
||||
|
||||
# this activation checkpoint label is not set yet
|
||||
# meaning this is the first node of the activation ckpt region
|
||||
if current_region is None:
|
||||
current_region = act_ckpt_label
|
||||
start = idx
|
||||
|
||||
# if activation checkpoint has changed
|
||||
# we restart the tracking
|
||||
# e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2]
|
||||
if act_ckpt_label != current_region:
|
||||
assert start != -1
|
||||
ckpt_regions.append((start, idx - 1))
|
||||
current_region = act_ckpt_label
|
||||
start = idx
|
||||
end = -1
|
||||
elif current_region is not None and not hasattr(node, 'activation_checkpoint'):
|
||||
# used to check the case below
|
||||
# node ckpt states = [ckpt, ckpt, non-ckpt]
|
||||
end = idx - 1
|
||||
assert start != -1 and end != -1
|
||||
ckpt_regions.append((start, end))
|
||||
start = end = -1
|
||||
current_region = None
|
||||
else:
|
||||
pass
|
||||
return ckpt_regions
|
||||
|
||||
|
||||
def _find_offload_regions(nodes: List[Node]):
|
||||
"""This function is to find the offload regions
|
||||
In pofo algorithm, during annotation, we will annotate the offload region with the
|
||||
|
@ -400,12 +302,9 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
|
|||
emit_node_func: function to emit node
|
||||
delete_unused_value_func: function to remove the unused value
|
||||
"""
|
||||
ckpt_regions = _find_nested_ckpt_regions(nodes, 0)
|
||||
start_idx = [item[0] for item in ckpt_regions]
|
||||
end_idx = [item[1] for item in ckpt_regions]
|
||||
|
||||
# find the offload regions
|
||||
chunk_regions, chunk_labels = _find_offload_regions(nodes)
|
||||
chunk_regions = [(1, 4)]
|
||||
chunk_starts = [item[0] for item in chunk_regions]
|
||||
chunk_ends = [item[1] for item in chunk_regions]
|
||||
chunk_inputs = []
|
||||
|
@ -424,7 +323,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
|
|||
# this flag is to prevent repeated insert of save tensors
|
||||
# hooks definition in ckpt_func
|
||||
node_idx = 0
|
||||
to_keep = []
|
||||
chunk_var = []
|
||||
while node_idx < len(node_list):
|
||||
# break if we finish the processing all the nodes
|
||||
if node_idx >= len(node_list):
|
||||
|
@ -435,28 +334,30 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
|
|||
node = node_list[node_idx]
|
||||
|
||||
if node_idx in chunk_starts:
|
||||
# save chunk input var, dont delete it
|
||||
to_keep.extend(node.args[0].name)
|
||||
within_chunk_region = True
|
||||
# add for loop
|
||||
body.append(_gen_loop_5(to_keep[0]))
|
||||
# change first node's input to new chunked var
|
||||
node_args = list(node.args)
|
||||
node_args[0] = 'chunk_tensor'
|
||||
|
||||
# save chunk input var, dont delete it
|
||||
chunk_var.append(node.args[0].name)
|
||||
|
||||
# add for loop
|
||||
body.append(_gen_loop_start(chunk_var[0]))
|
||||
|
||||
if within_chunk_region:
|
||||
emit_node_func(node, body)
|
||||
# replace input var with chunk var
|
||||
if node_idx in chunk_starts:
|
||||
body[-1] = body[-1].replace("("+ chunk_var[0] +")", '(chunk_tensor)')
|
||||
body[-1] = ' ' + body[-1]
|
||||
delete_unused_value_func(node, body, to_keep)
|
||||
delete_unused_value_func(node, body, chunk_var)
|
||||
|
||||
else:
|
||||
emit_node_func(node, body)
|
||||
if node_idx not in chunk_inputs:
|
||||
delete_unused_value_func(node, body, to_keep)
|
||||
delete_unused_value_func(node, body, chunk_var)
|
||||
|
||||
if node_idx in chunk_ends:
|
||||
body.append(_gen_loop_5_final(node.name, to_keep))
|
||||
to_keep = []
|
||||
body.append(_gen_loop_end(node.name, chunk_var))
|
||||
chunk_var = []
|
||||
within_chunk_region = False
|
||||
|
||||
node_idx += 1
|
||||
|
@ -580,9 +481,7 @@ if CODEGEN_AVAILABLE:
|
|||
body.append('\n')
|
||||
return
|
||||
nodes_to_delete = user_to_last_uses.get(user, [])
|
||||
for n in nodes_to_delete:
|
||||
if n.name in to_keep:
|
||||
nodes_to_delete.remove(n)
|
||||
nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep]
|
||||
if len(nodes_to_delete):
|
||||
to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
|
||||
body.append(f'; {to_delete_str}\n')
|
||||
|
|
|
@ -9,60 +9,39 @@ import colossalai
|
|||
from colossalai.utils import free_port
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
|
||||
try:
|
||||
from chunk_codegen import ChunkCodeGen
|
||||
with_codegen = True
|
||||
except:
|
||||
# fall back to older pytorch version
|
||||
from chunk_codegen import python_code_with_activation_checkpoint
|
||||
with_codegen = False
|
||||
|
||||
|
||||
class MyNet(torch.nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.linear0 = torch.nn.Linear(4, 4)
|
||||
self.linear1 = torch.nn.Linear(4, 4)
|
||||
self.linear2 = torch.nn.Linear(4, 4)
|
||||
self.linear3 = torch.nn.Linear(4, 4)
|
||||
self.linear4 = torch.nn.Linear(4, 4)
|
||||
self.linear5 = torch.nn.Linear(4, 4)
|
||||
self.linear6 = torch.nn.Linear(4, 4)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear0(x)
|
||||
x = self.linear1(x)
|
||||
x = self.linear2(x)
|
||||
x = self.linear3(x)
|
||||
x = self.linear4(x)
|
||||
x = self.linear5(x)
|
||||
x = self.linear6(x)
|
||||
return x
|
||||
from evoformer.evoformer import evoformer_base
|
||||
from chunk_codegen import ChunkCodeGen
|
||||
with_codegen = True
|
||||
|
||||
|
||||
def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule) -> bool:
|
||||
for m_p, gm_p in zip(m.parameters(), gm.parameters()):
|
||||
if not torch.allclose(m_p.grad, gm_p.grad):
|
||||
if m_p.grad is not None and not torch.allclose(m_p.grad, gm_p.grad):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.Tensor):
|
||||
def _is_all_param_close(m: torch.nn.Module, gm: GraphModule) -> bool:
|
||||
for m_p, gm_p in zip(m.parameters(), gm.parameters()):
|
||||
if m_p.grad is not None and not torch.allclose(m_p.data, gm_p.data):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
|
||||
# test forward
|
||||
non_fx_out = model(data)
|
||||
fx_out = gm(data)
|
||||
print(non_fx_out.shape, fx_out.shape)
|
||||
assert torch.equal(non_fx_out, fx_out), "fx_out doesn't comply with original output"
|
||||
non_fx_out = model(node.clone(), pair.clone())
|
||||
fx_out = gm(node.clone(), pair.clone())
|
||||
assert torch.equal(non_fx_out[0], fx_out[0]), "fx_out doesn't comply with original output"
|
||||
assert torch.equal(non_fx_out[1], fx_out[1]), "fx_out doesn't comply with original output"
|
||||
|
||||
# test barckward
|
||||
loss0 = non_fx_out.sum()
|
||||
loss0.backward()
|
||||
loss1 = fx_out.sum()
|
||||
loss1.backward()
|
||||
assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one"
|
||||
# loss0 = non_fx_out[0].sum() + non_fx_out[1].sum()
|
||||
# loss0.backward()
|
||||
# loss1 = fx_out[0].sum() + fx_out[1].sum()
|
||||
# loss1.backward()
|
||||
# assert _is_all_param_close(model, gm)
|
||||
# assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one"
|
||||
|
||||
|
||||
def _run_offload_codegen(rank):
|
||||
|
@ -70,30 +49,22 @@ def _run_offload_codegen(rank):
|
|||
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
|
||||
|
||||
# build model and input
|
||||
model = MyNet().cuda()
|
||||
data = torch.rand(4, 4).cuda()
|
||||
model = evoformer_base().cuda()
|
||||
node = torch.randn(1, 16, 32, 256).cuda()
|
||||
pair = torch.randn(1, 32, 32, 128).cuda()
|
||||
|
||||
# trace the module and replace codegen
|
||||
tracer = ColoTracer(trace_act_ckpt=True)
|
||||
graph = tracer.trace(model)
|
||||
codegen = ChunkCodeGen()
|
||||
graph.set_codegen(codegen)
|
||||
# codegen = ChunkCodeGen()
|
||||
# graph.set_codegen(codegen)
|
||||
|
||||
# annotate the activation offload part
|
||||
# also annotate the activation_checkpoint so we could test both types
|
||||
# of input offload
|
||||
for node in graph.nodes:
|
||||
if node.name == "linear0":
|
||||
setattr(node, "activation_offload", [0, True, False])
|
||||
if node.name == "linear1":
|
||||
setattr(node, "activation_offload", [0, True, False])
|
||||
# if node.name == "linear2":
|
||||
# setattr(node, "activation_offload", [1, True, True])
|
||||
# if node.name == "linear4":
|
||||
# setattr(node, "activation_offload", [2, False, True])
|
||||
# if node.name == "linear5":
|
||||
# setattr(node, "activation_checkpoint", [0])
|
||||
# setattr(node, "activation_offload", True)
|
||||
# annotate the chunk part
|
||||
# for node in graph.nodes:
|
||||
# if node.name == "linear0":
|
||||
# setattr(node, "activation_offload", [0, True, False])
|
||||
# if node.name == "linear1":
|
||||
# setattr(node, "activation_offload", [0, True, False])
|
||||
|
||||
gm = ColoGraphModule(copy.deepcopy(model), graph)
|
||||
gm.recompile()
|
||||
|
@ -102,7 +73,7 @@ def _run_offload_codegen(rank):
|
|||
code = graph.python_code("self").src
|
||||
print(code)
|
||||
|
||||
_test_fwd_and_bwd(model, gm, data)
|
||||
_test_fwd_and_bwd(model, gm, node, pair)
|
||||
gpc.destroy()
|
||||
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ class Evoformer(nn.Module):
|
|||
super(Evoformer, self).__init__()
|
||||
|
||||
self.blocks = nn.ModuleList()
|
||||
for _ in range(3):
|
||||
for _ in range(1):
|
||||
self.blocks.append(EvoformerBlock(d_node, d_pair))
|
||||
|
||||
def forward(self, node, pair):
|
||||
|
@ -36,6 +36,11 @@ class Evoformer(nn.Module):
|
|||
node, pair = b(node, pair)
|
||||
return node, pair
|
||||
|
||||
|
||||
def evoformer_tiny():
|
||||
return Evoformer(d_node=64, d_pair=32)
|
||||
|
||||
|
||||
def evoformer_base():
|
||||
return Evoformer(d_node=256, d_pair=128)
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ def bias_sigmod_ele(y, bias, z):
|
|||
|
||||
def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
|
||||
residual: torch.Tensor, prob: float) -> torch.Tensor:
|
||||
out = (x + bias) * F.dropout(dropmask, p=prob, training=True)
|
||||
out = (x + bias) * F.dropout(dropmask, p=prob, training=False)
|
||||
out = residual + out
|
||||
return out
|
||||
|
||||
|
|
|
@ -45,7 +45,7 @@ class MSARowAttentionWithPairBias(nn.Module):
|
|||
# b = rearrange(b, 'b q k h -> b h q k')
|
||||
|
||||
M = self.attention(M, b)
|
||||
dropout_mask = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)
|
||||
dropout_mask = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype)
|
||||
|
||||
return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop)
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ class TriangleMultiplicationOutgoing(nn.Module):
|
|||
|
||||
ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act)
|
||||
ab = self.output_projection(self.layernorm2(ab))
|
||||
dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
|
||||
dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
|
||||
return bias_ele_dropout_residual(ab,
|
||||
self.output_bias,
|
||||
g,
|
||||
|
@ -97,7 +97,7 @@ class TriangleMultiplicationIncoming(nn.Module):
|
|||
|
||||
ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)
|
||||
ab = self.output_projection(self.layernorm2(ab))
|
||||
dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
|
||||
dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
|
||||
return bias_ele_dropout_residual(ab,
|
||||
self.output_bias,
|
||||
g,
|
||||
|
@ -134,7 +134,7 @@ class TriangleAttentionStartingNode(nn.Module):
|
|||
|
||||
Z = self.attention(Z, b)
|
||||
|
||||
dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
|
||||
dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
|
||||
return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
|
||||
|
||||
|
||||
|
@ -168,7 +168,7 @@ class TriangleAttentionEndingNode(nn.Module):
|
|||
Z = self.attention(Z, b)
|
||||
|
||||
Z = Z.transpose(-2, -3)
|
||||
dropout_mask = torch.ones_like(Z[:, :, 0:1, :], device=Z.device, dtype=Z.dtype)
|
||||
dropout_mask = torch.ones_like(Z[:, :, 0:1, :]).to(Z.device).to(Z.dtype)
|
||||
return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue