ColossalAI/colossalai/inference/pipeline/utils.py

36 lines
1.1 KiB
Python

from typing import Optional, Set

import torch.nn as nn

from colossalai.shardformer._utils import getattr_, setattr_
def set_tensors_to_none(model: nn.Module, include: Optional[Set[str]] = None) -> None:
    """
    Release the tensors of selected submodules of ``model``.

    For each entry of ``include``, the submodule it names is looked up with
    ``getattr_``. Every parameter and buffer reported by that submodule's
    ``named_parameters()`` / ``named_buffers()`` is set to None, and then
    the submodule attribute itself is set to None on ``model``.

    Args:
        model (nn.Module): The model whose submodules are cleared.
        include (Optional[Set[str]]): Attribute paths (as accepted by
            ``getattr_`` / ``setattr_``) of the submodules to clear.
            ``None`` (the default) or an empty set clears nothing.
    """
    # A None sentinel replaces the former mutable ``set()`` default.
    if not include:
        return
    for module_suffix in include:
        set_module = getattr_(model, module_suffix)
        # Snapshot the names first so the module is not mutated while the
        # named_parameters()/named_buffers() generators are still iterating.
        param_names = [name for name, _ in set_module.named_parameters()]
        for name in param_names:
            setattr_(set_module, name, None)
        buffer_names = [name for name, _ in set_module.named_buffers()]
        for name in buffer_names:
            setattr_(set_module, name, None)
        setattr_(model, module_suffix, None)
def get_suffix_name(suffix: str, name: str) -> str:
    """
    Build the qualified name of a child module from its parent's name.

    Returns ``suffix[name]`` when ``name`` consists only of digits, otherwise
    ``suffix.name``. When ``suffix`` is empty the ``.`` separator is omitted,
    so a non-digit ``name`` is returned unchanged. A digit ``name`` with an
    empty ``suffix`` yields ``[name]``.

    Args:
        suffix (str): The qualified name of the parent module ("" at the root).
        name (str): The name of the current (child) module.

    Returns:
        str: The combined qualified name.
    """
    if name.isdigit():
        return f"{suffix}[{name}]"
    # Test emptiness by value, not identity: ``suffix is ""`` depends on
    # string interning and emits a SyntaxWarning on Python 3.8+.
    separator = "." if suffix else ""
    return f"{suffix}{separator}{name}"