from typing import Optional


class TensorParallelEnv(object):
    """Process-wide singleton holding tensor-parallel configuration values.

    Every call to ``TensorParallelEnv(...)`` returns the same object and
    re-runs :meth:`load` with the supplied arguments, so any field that is
    not passed is reset to its default.
    """

    # The single shared instance; created lazily on first construction.
    _instance = None

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on first use.

        The constructor arguments are accepted here only so that
        ``TensorParallelEnv(mode=...)`` works; they are consumed by
        ``__init__``.  They must NOT be forwarded to ``object.__new__``,
        which raises ``TypeError`` when given extra arguments by a class
        that overrides ``__new__``.
        """
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, *args, **kwargs) -> None:
        """(Re)initialise the shared instance by delegating to :meth:`load`."""
        self.load(*args, **kwargs)

    def load(self,
             mode: Optional[str] = None,
             vocab_parallel: bool = False,
             parallel_input_1d: bool = False,
             summa_dim: Optional[int] = None,
             tesseract_dim: Optional[int] = None,
             tesseract_dep: Optional[int] = None,
             depth_3d: Optional[int] = None,
             input_group_3d=None,
             weight_group_3d=None,
             output_group_3d=None,
             input_x_weight_group_3d=None,
             output_x_weight_group_3d=None) -> None:
        """Overwrite every configuration field with the given values.

        All parameters are stored verbatim as attributes of the same name;
        no validation is performed.  Parameters that are omitted are reset
        to their defaults (``None`` / ``False``).

        NOTE(review): the ``*_group_3d`` values are presumably communication
        group handles and ``mode`` a tensor-parallel mode name -- confirm
        against the callers; nothing in this file constrains their types.
        """
        self.mode = mode
        self.vocab_parallel = vocab_parallel
        self.parallel_input_1d = parallel_input_1d

        self.summa_dim = summa_dim

        self.tesseract_dim = tesseract_dim
        self.tesseract_dep = tesseract_dep

        self.depth_3d = depth_3d
        self.input_group_3d = input_group_3d
        self.weight_group_3d = weight_group_3d
        self.output_group_3d = output_group_3d
        self.input_x_weight_group_3d = input_x_weight_group_3d
        self.output_x_weight_group_3d = output_x_weight_group_3d

    def save(self) -> dict:
        """Return the current configuration as a dict.

        The keys match the keyword parameters of :meth:`load`, so the result
        can be passed back as ``env.load(**saved)`` to restore this state.
        """
        return dict(
            mode=self.mode,
            vocab_parallel=self.vocab_parallel,
            parallel_input_1d=self.parallel_input_1d,
            summa_dim=self.summa_dim,
            tesseract_dim=self.tesseract_dim,
            tesseract_dep=self.tesseract_dep,
            depth_3d=self.depth_3d,
            input_group_3d=self.input_group_3d,
            weight_group_3d=self.weight_group_3d,
            output_group_3d=self.output_group_3d,
            input_x_weight_group_3d=self.input_x_weight_group_3d,
            output_x_weight_group_3d=self.output_x_weight_group_3d,
        )


# Module-level handle to the shared instance, initialised with defaults.
tensor_parallel_env = TensorParallelEnv()