[pipeline]support List of Dict data (#1125)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [pipeline]support List of Dict data

* polish
YuliangLiu0306 2022-06-16 11:19:48 +08:00 committed by GitHub
parent 91a5999825
commit 3175bcb4d8
3 changed files with 14 additions and 8 deletions
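
The change lets the pipeline schedule consume batches arranged as a list (or tuple) whose elements are dicts of tensors, in addition to the plain tensor, list/tuple-of-tensor and dict-of-tensor layouts it already handled. A rough sketch of the layouts involved (field names are made up for illustration):

import torch

batch_size = 8

# Batch layouts the schedule helpers below deal with (field names illustrative only):
plain_tensor = torch.randn(batch_size, 32)                                 # torch.Tensor
tensor_list = [torch.randn(batch_size, 32), torch.randn(batch_size, 16)]   # list/tuple of tensors
flat_dict = {"input_ids": torch.randint(0, 100, (batch_size, 16))}         # dict of tensors
list_of_dict = [                                                           # newly supported: list of dicts
    {"input_ids": torch.randint(0, 100, (batch_size, 16))},
    {"attention_mask": torch.ones(batch_size, 16)},
]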

@@ -48,9 +48,11 @@ class BaseSchedule(ABC):
         if isinstance(data, torch.Tensor):
             return data.size(0)
         elif isinstance(data, (list, tuple)):
+            if isinstance(data[0], dict):
+                return data[0][list(data[0].keys())[0]].size(0)
             return data[0].size(0)
         elif isinstance(data, dict):
-            return data[next(data.keys())].size(0)
+            return data[list(data.keys())[0]].size(0)
 
     def load_batch(self, data_iter, to_gpu=True):
         """Loads a batch from data iterator. It returns the data and labels which are

@@ -137,6 +137,12 @@ class PipelineSchedule(BaseSchedule):
         if isinstance(data, torch.Tensor):
             return data[offset:offset + self.microbatch_size]
         elif isinstance(data, (list, tuple)):
+            data_dict = {}
+            for element in data:
+                if isinstance(element, dict):
+                    data_dict.update({k: v[offset:offset + self.microbatch_size] for k, v in element.items()})
+            if data_dict:
+                return data_dict
             return [val[offset:offset + self.microbatch_size] for val in data]
         elif isinstance(data, dict):
             return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()}
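
The slicing path now merges a list of dicts into a single dict of micro-batch slices. A minimal sketch with offset and microbatch_size passed explicitly (on the real class they come from self):

import torch

def slice_microbatch(data, offset, microbatch_size):
    # Illustrative helper; cuts [offset, offset + microbatch_size) out of every tensor.
    if isinstance(data, torch.Tensor):
        return data[offset:offset + microbatch_size]
    elif isinstance(data, (list, tuple)):
        data_dict = {}
        for element in data:
            if isinstance(element, dict):
                data_dict.update({k: v[offset:offset + microbatch_size] for k, v in element.items()})
        if data_dict:
            # a list of dicts collapses into one dict keyed by field name;
            # non-dict elements are ignored in that case, matching the hunk above
            return data_dict
        return [val[offset:offset + microbatch_size] for val in data]
    elif isinstance(data, dict):
        return {k: v[offset:offset + microbatch_size] for k, v in data.items()}

batch = [{"input_ids": torch.arange(16).reshape(8, 2)}, {"labels": torch.arange(8)}]
print(slice_microbatch(batch, offset=2, microbatch_size=2))
# {'input_ids': tensor([[4, 5], [6, 7]]), 'labels': tensor([2, 3])}
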
@@ -216,7 +222,7 @@ class PipelineSchedule(BaseSchedule):
 
         # get all parameter names for the forward function of the model
         fwd_sig = self._get_actual_forward_func(model)
-        fwd_sig_param_name = [p.name for p in fwd_sig.values()]
+        fwd_sig_param_name = [p.name for p in fwd_sig.parameters.values()]
 
         # build the kwargs for the forward function
         for idx, param_name in enumerate(fwd_sig_param_name):
@@ -228,7 +234,7 @@ class PipelineSchedule(BaseSchedule):
 
         # get the tensors for loss
         loss_sig = inspect.signature(criterion)
-        loss_sig_param_name = [p.name for p in loss_sig.values()]
+        loss_sig_param_name = [p.name for p in loss_sig.parameters.values()]
 
         for param_name in loss_sig_param_name:
             if param_name in micro_batch_data:
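
Both signature hunks fix the same mistake: inspect.signature() returns a Signature object, which has no .values(); the parameters live under .parameters, an ordered mapping of name to Parameter. A quick check:

import inspect

def forward(self, input_ids, attention_mask=None):
    pass

sig = inspect.signature(forward)
# sig.values() would raise AttributeError: 'Signature' object has no attribute 'values'
param_names = [p.name for p in sig.parameters.values()]
print(param_names)  # ['self', 'input_ids', 'attention_mask']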

@@ -12,7 +12,7 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
     A context manager to split the model into pipeline stages.
     """
 
-    def __init__(self, policy: str="balanced"):
+    def __init__(self, policy: str = "balanced"):
         super().__init__()
         self._layer_spec_dict = {}
         self._root_children = None
@@ -61,11 +61,12 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
         """
         # iterate over the positional arguments
        # to check if an argument is a torch Module
-        # if found any torch Module, replace it with its layer spec
+        # if found any torch Module, replace it with its layer spec
        # for storage purpose
        modified_args = []
        for arg in args:
            if isinstance(arg, torch.nn.Module):
+                # (lyl)TODO: if nn.Module is an argument of the root module, then we should just record the module instance itself.
                arg = self._layer_spec_dict[id(arg)]
            modified_args.append(arg)
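
For context, the loop above swaps any nn.Module passed positionally for the layer spec recorded under its id() when the module was built inside the context. A simplified sketch (the spec value and recording step are stand-ins, not the real colossalai internals):

import torch.nn as nn

layer_spec_dict = {}  # id(module) -> recorded spec (stand-in for PipelinableContext._layer_spec_dict)

def record_spec(module):
    # Stand-in for what the context does when a child module is instantiated.
    layer_spec_dict[id(module)] = ("LayerSpec", type(module).__name__)

def replace_module_args(args):
    # Mirror of the hunk above: keep plain args, swap module instances for their specs.
    modified_args = []
    for arg in args:
        if isinstance(arg, nn.Module):
            arg = layer_spec_dict[id(arg)]
        modified_args.append(arg)
    return modified_args

child = nn.Linear(4, 4)
record_spec(child)
print(replace_module_args([child, 0.1, "relu"]))  # [('LayerSpec', 'Linear'), 0.1, 'relu']
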
@@ -255,6 +256,3 @@ class PipelinableModel(torch.nn.Module):
             input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs)
 
         return input_tensor
-
-
-