# Mirror of https://github.com/hpcaitech/ColossalAI
# Source listing: 14 lines, 391 B, Python
import torch
|
|
from ..registry import meta_profiler_module
|
|
from typing import Optional, Tuple
|
|
|
|
|
|
# TODO: calculate rnn FLOPs
@meta_profiler_module.register(torch.nn.GRU)
@meta_profiler_module.register(torch.nn.RNN)
def torch_nn_rnn(
    self: torch.nn.Module,
    input: torch.Tensor,
    hx: Optional[torch.Tensor] = None,
) -> Tuple[int, int]:
    """Meta-profile a ``torch.nn.RNN`` / ``torch.nn.GRU`` forward call.

    Registered with ``meta_profiler_module`` for ``torch.nn.GRU`` and
    ``torch.nn.RNN``. FLOP/MAC estimation for these modules is not
    implemented yet, so this profiler always raises.

    Args:
        self: The recurrent module being profiled.
        input: The input passed to the module's forward call.
        hx: Hidden-state argument passed to the module's forward call, if any.
            It defaults to ``None`` so that a call without a hidden state
            reaches the ``NotImplementedError`` below instead of failing
            with a ``TypeError``.

    Returns:
        Declared as a ``(flops, macs)`` tuple for consistency with the other
        registered profilers; currently never returns.

    Raises:
        NotImplementedError: Always. FLOP counting for RNN/GRU is still a TODO.
    """
    # Raise rather than return (0, 0): a zero estimate would silently
    # under-count the cost of any model containing these modules.
    raise NotImplementedError(
        f"FLOP/MAC profiling is not implemented for {type(self).__name__} modules"
    )
|