mirror of https://github.com/hpcaitech/ColossalAI
[pipeline]support List of Dict data (#1125)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4.
* [pipeline]support List of Dict data
* polish
pull/1128/head
parent 91a5999825
commit 3175bcb4d8
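The change teaches the pipeline schedules to accept a batch expressed as a list (or tuple) whose elements are dicts of tensors, in addition to plain tensors, lists/tuples of tensors, and a single dict. A minimal sketch of the batch layout this refers to (keys and shapes are illustrative, not taken from the diff):

```python
import torch

# A batch laid out as a list containing a dict of tensors;
# the first dimension of each tensor is the batch dimension.
batch = [{
    "input_ids": torch.randint(0, 1000, (16, 128)),
    "attention_mask": torch.ones(16, 128, dtype=torch.long),
}]
```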
@@ -48,9 +48,11 @@ class BaseSchedule(ABC):
         if isinstance(data, torch.Tensor):
             return data.size(0)
         elif isinstance(data, (list, tuple)):
+            if isinstance(data[0], dict):
+                return data[0][list(data[0].keys())[0]].size(0)
             return data[0].size(0)
         elif isinstance(data, dict):
-            return data[next(data.keys())].size(0)
+            return data[list(data.keys())[0]].size(0)
 
     def load_batch(self, data_iter, to_gpu=True):
         """Loads a batch from data iterator. It returns the data and labels which are
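The previous `dict` branch called `next(data.keys())`, which raises `TypeError` because `dict_keys` is not an iterator; the hunk replaces it with `list(data.keys())[0]` and adds a branch for list-of-dict batches. A standalone sketch of the updated size check (the helper name is illustrative):

```python
import torch

def get_batch_size(data):
    # Mirrors the updated BaseSchedule logic above.
    if isinstance(data, torch.Tensor):
        return data.size(0)
    elif isinstance(data, (list, tuple)):
        if isinstance(data[0], dict):
            # batch size of any tensor inside the first dict element
            return data[0][list(data[0].keys())[0]].size(0)
        return data[0].size(0)
    elif isinstance(data, dict):
        return data[list(data.keys())[0]].size(0)

batch = [{"input_ids": torch.zeros(16, 128)}]
assert get_batch_size(batch) == 16
```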
@@ -137,6 +137,12 @@ class PipelineSchedule(BaseSchedule):
         if isinstance(data, torch.Tensor):
             return data[offset:offset + self.microbatch_size]
         elif isinstance(data, (list, tuple)):
+            data_dict = {}
+            for element in data:
+                if isinstance(element, dict):
+                    data_dict.update({k: v[offset:offset + self.microbatch_size] for k, v in element.items()})
+            if data_dict:
+                return data_dict
             return [val[offset:offset + self.microbatch_size] for val in data]
         elif isinstance(data, dict):
             return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()}
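For a list/tuple batch, the new branch merges every dict element into one dict of sliced tensors, so a micro-batch of a list-of-dict batch comes back as a plain dict; non-dict elements fall through to the original per-element slicing. A hedged, standalone sketch of that slicing (function name and `microbatch_size` argument are illustrative):

```python
import torch

def slice_microbatch(data, offset, microbatch_size):
    # Mirrors the new list/tuple branch in PipelineSchedule above.
    if isinstance(data, (list, tuple)):
        data_dict = {}
        for element in data:
            if isinstance(element, dict):
                data_dict.update({k: v[offset:offset + microbatch_size]
                                  for k, v in element.items()})
        if data_dict:
            return data_dict
        return [val[offset:offset + microbatch_size] for val in data]

batch = [{"input_ids": torch.zeros(16, 128), "attention_mask": torch.ones(16, 128)}]
micro = slice_microbatch(batch, offset=0, microbatch_size=4)
assert micro["input_ids"].shape == (4, 128)
```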
@@ -216,7 +222,7 @@ class PipelineSchedule(BaseSchedule):
 
         # get all parameter names for the forward function of the model
         fwd_sig = self._get_actual_forward_func(model)
-        fwd_sig_param_name = [p.name for p in fwd_sig.values()]
+        fwd_sig_param_name = [p.name for p in fwd_sig.parameters.values()]
 
         # build the kwargs for the forward function
         for idx, param_name in enumerate(fwd_sig_param_name):
@@ -228,7 +234,7 @@ class PipelineSchedule(BaseSchedule):
 
         # get the tensors for loss
         loss_sig = inspect.signature(criterion)
-        loss_sig_param_name = [p.name for p in loss_sig.values()]
+        loss_sig_param_name = [p.name for p in loss_sig.parameters.values()]
 
         for param_name in loss_sig_param_name:
             if param_name in micro_batch_data:
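Both hunks above fix the same bug: `inspect.signature(...)` returns a `Signature` object, which has no `.values()` method, so the original lines would raise `AttributeError` at runtime; the parameter objects live in the `.parameters` mapping. A small demonstration (the `criterion` below is a stand-in, not ColossalAI code):

```python
import inspect

def criterion(logits, labels):
    return (logits - labels).pow(2).mean()

sig = inspect.signature(criterion)
# sig.values() would raise AttributeError; parameters are exposed
# through the .parameters mapping, as in the corrected lines.
param_names = [p.name for p in sig.parameters.values()]
assert param_names == ["logits", "labels"]
```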
@@ -12,7 +12,7 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
     A context manager to split the model into pipeline stages.
     """
 
-    def __init__(self, policy: str="balanced"):
+    def __init__(self, policy: str = "balanced"):
         super().__init__()
         self._layer_spec_dict = {}
         self._root_children = None
@@ -61,11 +61,12 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
         """
         # iterate over the positional arguments
         # to check if an argument is a torch Module
        # if found any torch Module, replace it with its layer spec
         # for storage purpose
         modified_args = []
         for arg in args:
             if isinstance(arg, torch.nn.Module):
+                # (lyl)TODO: if nn.Module is an argument of the root module, then we should just record the module instance itself.
                 arg = self._layer_spec_dict[id(arg)]
             modified_args.append(arg)
 
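The loop above records constructor arguments for later re-instantiation, swapping any `torch.nn.Module` argument for the layer spec registered under its `id()`; the added TODO notes that a module passed directly to the root module should instead be recorded as-is. A standalone sketch of the substitution pattern, with a stand-in `LayerSpec` class (the real ColossalAI layer spec type is not shown in this diff):

```python
import torch

class LayerSpec:
    """Stand-in for a recorded layer specification (illustrative only)."""
    def __init__(self, module):
        self.typename = type(module).__name__

linear = torch.nn.Linear(4, 4)
layer_spec_dict = {id(linear): LayerSpec(linear)}  # filled when layers are built

def substitute_modules(args, layer_spec_dict):
    # Replace nn.Module positional arguments with their recorded layer specs,
    # mirroring the loop in PipelinableContext above.
    modified_args = []
    for arg in args:
        if isinstance(arg, torch.nn.Module):
            arg = layer_spec_dict[id(arg)]
        modified_args.append(arg)
    return modified_args

print(substitute_modules((linear, 0.5), layer_spec_dict))
```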
@@ -255,6 +256,3 @@ class PipelinableModel(torch.nn.Module):
             input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs)
 
         return input_tensor
-
-
-