mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] support List of Dict data (#1125)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
  This reverts commit df7e6506d4.
* [pipeline] support List of Dict data
* polish

pull/1128/head
parent 91a5999825
commit 3175bcb4d8
@@ -48,9 +48,11 @@ class BaseSchedule(ABC):
         if isinstance(data, torch.Tensor):
             return data.size(0)
         elif isinstance(data, (list, tuple)):
+            if isinstance(data[0], dict):
+                return data[0][list(data[0].keys())[0]].size(0)
             return data[0].size(0)
         elif isinstance(data, dict):
-            return data[next(data.keys())].size(0)
+            return data[list(data.keys())[0]].size(0)

     def load_batch(self, data_iter, to_gpu=True):
         """Loads a batch from data iterator. It returns the data and labels which are
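For context, a minimal standalone sketch (not part of the commit) of how the patched get_batch_size logic behaves on the newly supported list-of-dict batches; the key names and tensor shapes here are made up for illustration:

    import torch

    # A batch expressed as a list of dicts, e.g. from a tokenizer-style
    # dataloader; "input_ids"/"attention_mask" and the shapes are hypothetical.
    batch = [{"input_ids": torch.zeros(8, 128), "attention_mask": torch.ones(8, 128)}]

    # Mirrors the patched branch: for a list whose first element is a dict,
    # the batch size is read from the first value of that dict.
    if isinstance(batch, (list, tuple)) and isinstance(batch[0], dict):
        batch_size = batch[0][list(batch[0].keys())[0]].size(0)
    print(batch_size)  # 8

Note that the removed line, data[next(data.keys())], was itself a latent bug: dict_keys is not an iterator, so next() on it raises a TypeError; the new list(data.keys())[0] form sidesteps that.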
@@ -137,6 +137,12 @@ class PipelineSchedule(BaseSchedule):
         if isinstance(data, torch.Tensor):
             return data[offset:offset + self.microbatch_size]
         elif isinstance(data, (list, tuple)):
+            data_dict = {}
+            for element in data:
+                if isinstance(element, dict):
+                    data_dict.update({k: v[offset:offset + self.microbatch_size] for k, v in element.items()})
+            if data_dict:
+                return data_dict
             return [val[offset:offset + self.microbatch_size] for val in data]
         elif isinstance(data, dict):
             return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()}
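A rough sketch of what this new branch does when slicing a microbatch out of a list-of-dict batch (offset, microbatch size, and tensor contents are invented for the example; every value is assumed to share the same batch dimension):

    import torch

    microbatch_size, offset = 2, 4
    data = [{"input_ids": torch.arange(8).unsqueeze(1), "labels": torch.arange(8)}]

    # Mirrors the patched branch: merge every dict element of the list into
    # one dict whose values are sliced to the current microbatch window.
    data_dict = {}
    for element in data:
        if isinstance(element, dict):
            data_dict.update({k: v[offset:offset + microbatch_size] for k, v in element.items()})
    print(data_dict["labels"])  # tensor([4, 5])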
@@ -216,7 +222,7 @@ class PipelineSchedule(BaseSchedule):

         # get all parameter names for the forward function of the model
         fwd_sig = self._get_actual_forward_func(model)
-        fwd_sig_param_name = [p.name for p in fwd_sig.values()]
+        fwd_sig_param_name = [p.name for p in fwd_sig.parameters.values()]

         # build the kwargs for the forward function
         for idx, param_name in enumerate(fwd_sig_param_name):
@@ -228,7 +234,7 @@ class PipelineSchedule(BaseSchedule):

         # get the tensors for loss
         loss_sig = inspect.signature(criterion)
-        loss_sig_param_name = [p.name for p in loss_sig.values()]
+        loss_sig_param_name = [p.name for p in loss_sig.parameters.values()]

         for param_name in loss_sig_param_name:
             if param_name in micro_batch_data:
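The two hunks above fix the same bug: an inspect.Signature object has no .values() method, so the old code would raise AttributeError; the Parameter objects live in the ordered mapping Signature.parameters. A standalone demonstration (the criterion function here is an arbitrary example):

    import inspect

    def criterion(logits, labels):
        return (logits - labels).abs().mean()

    sig = inspect.signature(criterion)
    # sig.values() would raise AttributeError; the parameters are in sig.parameters.
    param_names = [p.name for p in sig.parameters.values()]
    print(param_names)  # ['logits', 'labels']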
@@ -12,7 +12,7 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
     A context manager to split the model into pipeline stages.
     """

-    def __init__(self, policy: str="balanced"):
+    def __init__(self, policy: str = "balanced"):
         super().__init__()
         self._layer_spec_dict = {}
         self._root_children = None
@@ -61,11 +61,12 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
         """
         # iterate over the positional arguments
         # to check if an argument is a torch Module
         # if found any torch Module, replace it with its layer spec
         # for storage purpose
         modified_args = []
         for arg in args:
             if isinstance(arg, torch.nn.Module):
+                # (lyl)TODO: if nn.Module is an argument of the root module, then we should just record the module instance itself.
                 arg = self._layer_spec_dict[id(arg)]
             modified_args.append(arg)

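To illustrate the substitution this loop performs, a simplified sketch of the idea: child modules passed as positional arguments are swapped for their recorded layer specs, keyed by the module's id(). The LayerSpec stand-in below is hypothetical and only approximates ColossalAI's internal class:

    import torch

    # Hypothetical stand-in for a recorded layer spec.
    class LayerSpec:
        def __init__(self, module):
            self.typename = type(module).__name__

    layer_spec_dict = {}
    child = torch.nn.Linear(4, 4)
    layer_spec_dict[id(child)] = LayerSpec(child)

    # Mirrors the loop in the hunk: replace Module args with their specs for storage.
    args = (child, 0.5)
    modified_args = [layer_spec_dict[id(a)] if isinstance(a, torch.nn.Module) else a
                     for a in args]
    print(modified_args[0].typename)  # 'Linear'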
@@ -255,6 +256,3 @@ class PipelinableModel(torch.nn.Module):
         input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs)

         return input_tensor
-
-
-