2022-12-11 13:41:13 +00:00
|
|
|
from abc import ABC
|
|
|
|
|
2022-12-09 08:13:03 +00:00
|
|
|
import torch
|
|
|
|
|
|
|
|
|
2022-12-11 13:41:13 +00:00
|
|
|
class ParamGenerator(ABC):
    """Base interface for objects that record parameters and later replay them.

    Every hook below is a no-op that returns ``None``; subclasses override the
    ones they need. No method is decorated with ``@abstractmethod``, so this
    base class can still be instantiated directly.
    """

    def append(self, param: torch.nn.Parameter):
        """Record ``param``. Does nothing in the base class."""

    def generate(self):
        """Produce the recorded parameters. Returns ``None`` in the base class."""

    def clear(self):
        """Forget all recorded parameters. Does nothing in the base class."""
|
|
|
|
|
|
|
|
|
|
|
|
class OrderedParamGenerator(ParamGenerator):
    """Records the order in which parameters are visited during runtime.

    ``append`` logs every visit, duplicates included. ``generate`` replays that
    log and yields each parameter only once, at the position of its first visit.
    """

    def __init__(self) -> None:
        # Raw visit log; the same parameter may appear more than once.
        self.param_visited_order = []

    def append(self, param: torch.nn.Parameter):
        """Log one visit of ``param`` at the end of the visit order."""
        self.param_visited_order.append(param)

    def generate(self):
        """Yield each logged parameter once, in first-visit order."""
        seen = set()
        for param in self.param_visited_order:
            # Skip repeat visits; only the first occurrence is reported.
            if param in seen:
                continue
            seen.add(param)
            yield param

    def clear(self):
        """Discard the visit log by rebinding it to a fresh empty list."""
        self.param_visited_order = []
|