mirror of https://github.com/hpcaitech/ColossalAI
restruct dir
parent
27ab524096
commit
efb1c64c30
|
@ -3,13 +3,13 @@ import time
|
|||
import torch
|
||||
import torch.fx
|
||||
|
||||
from autochunk.chunk_codegen import ChunkCodeGen
|
||||
from colossalai.autochunk.chunk_codegen import ChunkCodeGen
|
||||
from colossalai.fx import ColoTracer
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
from autochunk.evoformer.evoformer import evoformer_base
|
||||
from autochunk.openfold.evoformer import EvoformerBlock
|
||||
from tests.test_autochunk.evoformer.evoformer import evoformer_base
|
||||
from tests.test_autochunk.openfold.evoformer import EvoformerBlock
|
||||
|
||||
|
||||
def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None):
|
||||
|
@ -94,7 +94,7 @@ def _build_openfold():
|
|||
def benchmark_evoformer():
|
||||
# init data and model
|
||||
msa_len = 256
|
||||
pair_len = 1024
|
||||
pair_len = 256
|
||||
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||
model = evoformer_base().cuda()
|
||||
|
@ -106,11 +106,11 @@ def benchmark_evoformer():
|
|||
|
||||
# build openfold
|
||||
chunk_size = 64
|
||||
# openfold = _build_openfold()
|
||||
openfold = _build_openfold()
|
||||
|
||||
# benchmark
|
||||
# _benchmark_evoformer(model, node, pair, "base")
|
||||
# _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size)
|
||||
_benchmark_evoformer(model, node, pair, "base")
|
||||
_benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size)
|
||||
_benchmark_evoformer(autochunk, node, pair, "autochunk")
|
||||
|
||||
|
|
@ -12,8 +12,8 @@ from colossalai.core import global_context as gpc
|
|||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
from autochunk.evoformer.evoformer import evoformer_base
|
||||
from autochunk.chunk_codegen import ChunkCodeGen
|
||||
from tests.test_autochunk.evoformer.evoformer import evoformer_base
|
||||
from ...colossalai.autochunk.chunk_codegen import ChunkCodeGen
|
||||
with_codegen = True
|
||||
|
||||
|
|
@ -19,25 +19,25 @@ import torch.nn as nn
|
|||
from typing import Tuple, Optional
|
||||
from functools import partial
|
||||
|
||||
from openfold.primitives import Linear, LayerNorm
|
||||
from openfold.dropout import DropoutRowwise, DropoutColumnwise
|
||||
from openfold.msa import (
|
||||
from .primitives import Linear, LayerNorm
|
||||
from .dropout import DropoutRowwise, DropoutColumnwise
|
||||
from .msa import (
|
||||
MSARowAttentionWithPairBias,
|
||||
MSAColumnAttention,
|
||||
MSAColumnGlobalAttention,
|
||||
)
|
||||
from openfold.outer_product_mean import OuterProductMean
|
||||
from openfold.pair_transition import PairTransition
|
||||
from openfold.triangular_attention import (
|
||||
from .outer_product_mean import OuterProductMean
|
||||
from .pair_transition import PairTransition
|
||||
from .triangular_attention import (
|
||||
TriangleAttentionStartingNode,
|
||||
TriangleAttentionEndingNode,
|
||||
)
|
||||
from openfold.triangular_multiplicative_update import (
|
||||
from .triangular_multiplicative_update import (
|
||||
TriangleMultiplicationOutgoing,
|
||||
TriangleMultiplicationIncoming,
|
||||
)
|
||||
from openfold.checkpointing import checkpoint_blocks, get_checkpoint_fn
|
||||
from openfold.tensor_utils import chunk_layer
|
||||
from .checkpointing import checkpoint_blocks, get_checkpoint_fn
|
||||
from .tensor_utils import chunk_layer
|
||||
|
||||
|
||||
class MSATransition(nn.Module):
|
|
@ -18,15 +18,15 @@ import torch
|
|||
import torch.nn as nn
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from openfold.primitives import (
|
||||
from .primitives import (
|
||||
Linear,
|
||||
LayerNorm,
|
||||
Attention,
|
||||
GlobalAttention,
|
||||
_attention_chunked_trainable,
|
||||
)
|
||||
from openfold.checkpointing import get_checkpoint_fn
|
||||
from openfold.tensor_utils import (
|
||||
from .checkpointing import get_checkpoint_fn
|
||||
from .tensor_utils import (
|
||||
chunk_layer,
|
||||
permute_final_dims,
|
||||
flatten_final_dims,
|
|
@ -19,8 +19,8 @@ from typing import Optional
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from openfold.primitives import Linear
|
||||
from openfold.tensor_utils import chunk_layer
|
||||
from .primitives import Linear
|
||||
from .tensor_utils import chunk_layer
|
||||
|
||||
|
||||
class OuterProductMean(nn.Module):
|
|
@ -17,8 +17,8 @@ from typing import Optional
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from openfold.primitives import Linear, LayerNorm
|
||||
from openfold.tensor_utils import chunk_layer
|
||||
from .primitives import Linear, LayerNorm
|
||||
from .tensor_utils import chunk_layer
|
||||
|
||||
|
||||
class PairTransition(nn.Module):
|
|
@ -21,8 +21,8 @@ import numpy as np
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from openfold.checkpointing import get_checkpoint_fn
|
||||
from openfold.tensor_utils import (
|
||||
from .checkpointing import get_checkpoint_fn
|
||||
from .tensor_utils import (
|
||||
permute_final_dims,
|
||||
flatten_final_dims,
|
||||
_chunk_slice,
|
|
@ -20,8 +20,8 @@ from typing import Optional, List
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from openfold.primitives import Linear, LayerNorm, Attention
|
||||
from openfold.tensor_utils import (
|
||||
from .primitives import Linear, LayerNorm, Attention
|
||||
from .tensor_utils import (
|
||||
chunk_layer,
|
||||
permute_final_dims,
|
||||
flatten_final_dims,
|
|
@ -19,8 +19,8 @@ from typing import Optional
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from openfold.primitives import Linear, LayerNorm
|
||||
from openfold.tensor_utils import permute_final_dims
|
||||
from .primitives import Linear, LayerNorm
|
||||
from .tensor_utils import permute_final_dims
|
||||
|
||||
|
||||
class TriangleMultiplicativeUpdate(nn.Module):
|
Loading…
Reference in New Issue