mirror of https://github.com/hpcaitech/ColossalAI
[hotfix]fix bugs caused by refactored pipeline (#1133)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4.
* [hotfix]fix bugs caused by refactored pipeline
pull/1130/head^2
parent 789cad301b
commit 946dbd629d
@@ -67,8 +67,8 @@ class NonPipelineSchedule(BaseSchedule):
             "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
         batch_data = self.load_batch(data_iter)
 
-        if self.batch_data_process_func:
-            data, label = self.batch_data_process_func(batch_data)
+        if self.data_process_func:
+            data, label = self.data_process_func(batch_data)
         else:
             # if no batch data process func is given,
             # then we regard the batch data as a simple tuple of (data, label)
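For callers that relied on the old name, the hook's contract is unchanged: it receives the raw batch loaded from the data iterator and returns a (data, label) pair. A minimal sketch of such a hook (the function name and batch layout are illustrative assumptions, not part of this commit):

    # Hypothetical data_process_func: takes the raw loaded batch and
    # returns the (data, label) pair the schedule consumes.
    def dict_batch_to_pair(batch_data):
        # assumes the dataloader yields a dict with a 'label' key
        label = batch_data.pop('label')
        return batch_data, label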
@@ -141,6 +141,8 @@ class PipelineSchedule(BaseSchedule):
         for element in data:
             if isinstance(element, dict):
                 data_dict.update({k: v[offset:offset + self.microbatch_size] for k, v in element.items()})
+            elif data_dict:
+                data_dict['label'] = element[offset:offset + self.microbatch_size]
         if data_dict:
             return data_dict
         return [val[offset:offset + self.microbatch_size] for val in data]
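The two added lines handle a batch given as (dict, label_tensor): once a dict element has been sliced per key, a following non-dict element is sliced and stored as the label. A standalone sketch of the slicing rule, with microbatch_size passed in explicitly for illustration:

    import torch

    def slice_microbatch(data, offset, microbatch_size):
        data_dict = {}
        for element in data:
            if isinstance(element, dict):
                # slice every tensor in the dict along the batch dimension
                data_dict.update({k: v[offset:offset + microbatch_size] for k, v in element.items()})
            elif data_dict:
                # a non-dict element following a dict element is the label
                data_dict['label'] = element[offset:offset + microbatch_size]
        if data_dict:
            return data_dict
        # no dict elements at all: slice each element positionally
        return [val[offset:offset + microbatch_size] for val in data]

    batch = ({'input_ids': torch.zeros(8, 16)}, torch.zeros(8))
    micro = slice_microbatch(batch, offset=0, microbatch_size=2)
    # micro == {'input_ids': 2x16 tensor, 'label': 2-element tensor}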
@@ -175,7 +177,10 @@ class PipelineSchedule(BaseSchedule):
         elif isinstance(data, (list, tuple)):
             return model(*data)
         elif isinstance(data, dict):
-            return model(**data)
+            stage_output = None
+            if 'stage_output' in data:
+                stage_output = data.pop('stage_output')
+            return model(stage_output, **data)
         else:
             raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
 
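With this change a dict micro-batch can carry the previous stage's output under the reserved 'stage_output' key; it is popped out of the kwargs and fed to the model positionally. A self-contained sketch of the dispatch (MidStage is a made-up stand-in for an intermediate pipeline stage):

    import torch
    import torch.nn as nn

    class MidStage(nn.Module):
        # intermediate stages take the upstream output as their first argument
        def forward(self, stage_output, scale=1.0):
            return stage_output * scale

    def call_model(model, data):
        if isinstance(data, torch.Tensor):
            return model(data)
        elif isinstance(data, (list, tuple)):
            return model(*data)
        elif isinstance(data, dict):
            stage_output = data.pop('stage_output', None)
            return model(stage_output, **data)
        else:
            raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")

    out = call_model(MidStage(), {'stage_output': torch.ones(2), 'scale': 3.0})
    # out == tensor([3., 3.])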
@@ -204,41 +209,14 @@ class PipelineSchedule(BaseSchedule):
             data = stage_output
             _, label = micro_batch_data
         elif isinstance(micro_batch_data, dict):
-            args = []
             data = {}
-            label = {}
-
-            # we feed the stage output to args first
-            # then map each arg in args to its param name
-            if stage_output is not None:
-                if isinstance(stage_output, torch.Tensor):
-                    args.append(stage_output)
-                elif isinstance(stage_output, (list, tuple)):
-                    args.extend(stage_output)
-                else:
-                    raise TypeError(
-                        f"Expected the values passed from previous pipeline stage to be torch.Tensor, list or tuple, but got {type(input_obj)}"
-                    )
-
-            # get all parameter names for the forward function of the model
-            fwd_sig = self._get_actual_forward_func(model)
-            fwd_sig_param_name = [p.name for p in fwd_sig.parameters.values()]
-
-            # build the kwargs for the forward function
-            for idx, param_name in enumerate(fwd_sig_param_name):
-                if idx < len(args):
-                    data[param_name] = args[idx]
-                else:
-                    if param_name in micro_batch_data:
-                        data[param_name] = micro_batch_data[param_name]
-
-            # get the tensors for loss
-            loss_sig = inspect.signature(criterion)
-            loss_sig_param_name = [p.name for p in loss_sig.parameters.values()]
-
-            for param_name in loss_sig_param_name:
-                if param_name in micro_batch_data:
-                    label[param_name] = micro_batch_data[param_name]
+            data['stage_output'] = stage_output
+            if 'label' in micro_batch_data:
+                label = micro_batch_data.pop('label')
+            else:
+                label = None
+            load_data = micro_batch_data
+            data.update(load_data)
         return data, label
 
     def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
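Where the old branch rebuilt the forward kwargs by inspecting the model's and the criterion's signatures, the new branch passes the micro-batch dict through wholesale: 'label' is split off for the loss, and the upstream output travels under 'stage_output', to be popped again by the dispatch shown earlier. Roughly, as a standalone function (the name is an assumption for illustration):

    def build_data_label(stage_output, micro_batch_data):
        data = {'stage_output': stage_output}
        # split off the label if present, None otherwise (mutates the dict,
        # as the committed code does)
        label = micro_batch_data.pop('label', None)
        data.update(micro_batch_data)
        return data, label

    data, label = build_data_label(None, {'input_ids': [1, 2], 'label': [0]})
    # data == {'stage_output': None, 'input_ids': [1, 2]}, label == [0]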
@@ -66,8 +66,11 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
         modified_args = []
         for arg in args:
-            if isinstance(arg, torch.nn.Module):
-                # (lyl)TODO: if nn.Module is an argument of the root module, then we should just record the module instance itself.
-                arg = self._layer_spec_dict[id(arg)]
+            # if nn.Module is an argument of a non-root module, then we should convert it to a layer spec, which makes sure the correct init method is used in the real build.
+            # if nn.Module is an argument of the root module, then we should just record the module instance itself, because those instances have been built outside of the context.
+            if id(arg) in self._layer_spec_dict:
+                arg = self._layer_spec_dict[id(arg)]
             modified_args.append(arg)
 
         # do the same for the keyword arguments
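The bug: the old code swapped an argument for its layer spec whenever it was an nn.Module, but a module passed to the root module is built outside the context and has no entry in _layer_spec_dict, so the unconditional lookup would fail. Keying on dict membership handles both cases. A toy illustration (LayerSpec here is a stand-in, not ColossalAI's actual class):

    import torch.nn as nn

    class LayerSpec:                      # stand-in spec that just wraps a module
        def __init__(self, module):
            self.module = module

    layer_spec_dict = {}                  # id(module) -> LayerSpec, filled in-context

    def record_args(args):
        modified_args = []
        for arg in args:
            if id(arg) in layer_spec_dict:        # registered => built inside the context
                arg = layer_spec_dict[id(arg)]    # swap the module for its spec
            modified_args.append(arg)             # anything else passes through unchanged
        return modified_args

    inner = nn.Linear(4, 4)               # pretend this was created inside the context
    layer_spec_dict[id(inner)] = LayerSpec(inner)
    outer = nn.Linear(4, 4)               # built outside the context, no spec registered
    specs = record_args([inner, outer])   # [LayerSpec wrapping inner, outer]; no KeyError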