restruct dir

pull/2364/head
oahzxl 2023-01-06 11:39:26 +08:00
parent 27ab524096
commit efb1c64c30
19 changed files with 31 additions and 31 deletions

View File

@ -3,13 +3,13 @@ import time
import torch
import torch.fx
from autochunk.chunk_codegen import ChunkCodeGen
from colossalai.autochunk.chunk_codegen import ChunkCodeGen
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import MetaTensor
from autochunk.evoformer.evoformer import evoformer_base
from autochunk.openfold.evoformer import EvoformerBlock
from tests.test_autochunk.evoformer.evoformer import evoformer_base
from tests.test_autochunk.openfold.evoformer import EvoformerBlock
def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None):
@ -94,7 +94,7 @@ def _build_openfold():
def benchmark_evoformer():
# init data and model
msa_len = 256
pair_len = 1024
pair_len = 256
node = torch.randn(1, msa_len, pair_len, 256).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
model = evoformer_base().cuda()
@ -106,11 +106,11 @@ def benchmark_evoformer():
# build openfold
chunk_size = 64
# openfold = _build_openfold()
openfold = _build_openfold()
# benchmark
# _benchmark_evoformer(model, node, pair, "base")
# _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size)
_benchmark_evoformer(model, node, pair, "base")
_benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size)
_benchmark_evoformer(autochunk, node, pair, "autochunk")

View File

@ -12,8 +12,8 @@ from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
from colossalai.fx.profiler import MetaTensor
from autochunk.evoformer.evoformer import evoformer_base
from autochunk.chunk_codegen import ChunkCodeGen
from tests.test_autochunk.evoformer.evoformer import evoformer_base
from ...colossalai.autochunk.chunk_codegen import ChunkCodeGen
with_codegen = True

View File

@ -19,25 +19,25 @@ import torch.nn as nn
from typing import Tuple, Optional
from functools import partial
from openfold.primitives import Linear, LayerNorm
from openfold.dropout import DropoutRowwise, DropoutColumnwise
from openfold.msa import (
from .primitives import Linear, LayerNorm
from .dropout import DropoutRowwise, DropoutColumnwise
from .msa import (
MSARowAttentionWithPairBias,
MSAColumnAttention,
MSAColumnGlobalAttention,
)
from openfold.outer_product_mean import OuterProductMean
from openfold.pair_transition import PairTransition
from openfold.triangular_attention import (
from .outer_product_mean import OuterProductMean
from .pair_transition import PairTransition
from .triangular_attention import (
TriangleAttentionStartingNode,
TriangleAttentionEndingNode,
)
from openfold.triangular_multiplicative_update import (
from .triangular_multiplicative_update import (
TriangleMultiplicationOutgoing,
TriangleMultiplicationIncoming,
)
from openfold.checkpointing import checkpoint_blocks, get_checkpoint_fn
from openfold.tensor_utils import chunk_layer
from .checkpointing import checkpoint_blocks, get_checkpoint_fn
from .tensor_utils import chunk_layer
class MSATransition(nn.Module):

View File

@ -18,15 +18,15 @@ import torch
import torch.nn as nn
from typing import Optional, List, Tuple
from openfold.primitives import (
from .primitives import (
Linear,
LayerNorm,
Attention,
GlobalAttention,
_attention_chunked_trainable,
)
from openfold.checkpointing import get_checkpoint_fn
from openfold.tensor_utils import (
from .checkpointing import get_checkpoint_fn
from .tensor_utils import (
chunk_layer,
permute_final_dims,
flatten_final_dims,

View File

@ -19,8 +19,8 @@ from typing import Optional
import torch
import torch.nn as nn
from openfold.primitives import Linear
from openfold.tensor_utils import chunk_layer
from .primitives import Linear
from .tensor_utils import chunk_layer
class OuterProductMean(nn.Module):

View File

@ -17,8 +17,8 @@ from typing import Optional
import torch
import torch.nn as nn
from openfold.primitives import Linear, LayerNorm
from openfold.tensor_utils import chunk_layer
from .primitives import Linear, LayerNorm
from .tensor_utils import chunk_layer
class PairTransition(nn.Module):

View File

@ -21,8 +21,8 @@ import numpy as np
import torch
import torch.nn as nn
from openfold.checkpointing import get_checkpoint_fn
from openfold.tensor_utils import (
from .checkpointing import get_checkpoint_fn
from .tensor_utils import (
permute_final_dims,
flatten_final_dims,
_chunk_slice,

View File

@ -20,8 +20,8 @@ from typing import Optional, List
import torch
import torch.nn as nn
from openfold.primitives import Linear, LayerNorm, Attention
from openfold.tensor_utils import (
from .primitives import Linear, LayerNorm, Attention
from .tensor_utils import (
chunk_layer,
permute_final_dims,
flatten_final_dims,

View File

@ -19,8 +19,8 @@ from typing import Optional
import torch
import torch.nn as nn
from openfold.primitives import Linear, LayerNorm
from openfold.tensor_utils import permute_final_dims
from .primitives import Linear, LayerNorm
from .tensor_utils import permute_final_dims
class TriangleMultiplicativeUpdate(nn.Module):