2022-12-11 13:41:13 +00:00
|
|
|
from abc import ABC
|
|
|
|
|
2022-12-09 08:13:03 +00:00
|
|
|
import torch
|
|
|
|
|
|
|
|
|
2022-12-11 13:41:13 +00:00
|
|
|
class ParamGenerator(ABC):
    """Base interface for objects that record parameters and later replay them.

    Every hook in this base class is a placeholder: ``append`` and ``clear``
    do nothing, and ``generate`` simply returns ``None``. Concrete subclasses
    are expected to override them.

    NOTE(review): the class derives from ``ABC`` but marks no method with
    ``@abstractmethod``, so it can still be instantiated directly. Confirm
    whether that is intentional before tightening the contract.
    """

    def append(self, param: torch.nn.Parameter):
        """Record ``param``. No-op in the base class."""

    def generate(self):
        """Produce the recorded parameters. Returns ``None`` in the base class."""

    def clear(self):
        """Forget every recorded parameter. No-op in the base class."""
|
|
|
|
|
|
|
|
|
|
|
|
class OrderedParamGenerator(ParamGenerator):
    """Keep track of the order in which parameters are visited at runtime.

    ``append`` logs every visit, duplicates included. ``generate`` replays
    that log lazily and yields each distinct parameter exactly once, at the
    position of its first visit.
    """

    def __init__(self) -> None:
        # Raw visit log; the same parameter may appear more than once.
        self.param_visited_order = []

    def append(self, param: torch.nn.Parameter):
        """Log a single visit of ``param`` at the end of the visit log."""
        self.param_visited_order.append(param)

    def generate(self):
        """Lazily yield each logged parameter once, in first-visit order.

        Duplicates are filtered through a ``set``, so this relies on the
        parameters' ``__hash__``/``__eq__``. For torch parameters this is
        presumably object identity; confirm if other object types are logged.
        """
        seen = set()
        for param in self.param_visited_order:
            if param in seen:
                continue
            yield param
            seen.add(param)

    def is_empty(self):
        """Return True when no visit has been logged since init or the last ``clear``."""
        return not self.param_visited_order

    def clear(self):
        """Discard the whole visit log."""
        # Rebind to a fresh list rather than emptying in place, so any list
        # object previously obtained from ``param_visited_order`` is untouched.
        self.param_visited_order = []
|