|
|
|
@ -6,6 +6,7 @@ from colossalai.logging import get_dist_logger
|
|
|
|
|
from colossalai.nn.layer.utils import CheckpointModule |
|
|
|
|
from typing import List |
|
|
|
|
|
|
|
|
|
from collections import OrderedDict |
|
|
|
|
|
|
|
|
|
def _binary_partition(weights: List, start: int, end: int): |
|
|
|
|
"""Returns the binary partition position of `weights`, given the start |
|
|
|
@ -159,8 +160,10 @@ def build_kwargs_for_module(function, input_tensor, kw_dict):
|
|
|
|
|
kwargs_offset = 0 |
|
|
|
|
elif isinstance(input_tensor, torch.Tensor): |
|
|
|
|
kwargs_offset = 1 |
|
|
|
|
else: |
|
|
|
|
assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.' |
|
|
|
|
elif isinstance(input_tensor, (tuple, OrderedDict)): |
|
|
|
|
#assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.' |
|
|
|
|
# Huggingface will take their own structures based on OrderedDict as the output |
|
|
|
|
# between layers so we've to close this check. |
|
|
|
|
kwargs_offset = len(input_tensor) |
|
|
|
|
args_name_list = list(sig.parameters.keys()) |
|
|
|
|
kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[kwargs_offset:]} |
|
|
|
|