mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style (#1559)
parent b0f4c0bddf
commit 318fbf1145
@@ -778,4 +778,4 @@ class OneFOneBPipelineEngine(PipelineEngineBase):
                  criterion: Callable = None,
                  checkpoint: bool = False) -> None:
         use_1F1B = True
-        super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, checkpoint)
+        super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, checkpoint)
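For orientation, the constructor touched in this hunk only forwards its arguments to PipelineEngineBase, pinning use_1F1B to True. A minimal instantiation sketch follows; the keyword names come from the super().__init__ call above, while the import path, the toy two-stage partition, and all argument values are illustrative assumptions rather than part of this commit.

    import torch.nn as nn
    from colossalai.pipeline.rpc import OneFOneBPipelineEngine  # assumed import path

    # Hypothetical two-stage partition of a toy model; in practice module_partitions
    # is whatever list of per-stage modules the caller has prepared.
    module_partitions = [nn.Linear(32, 64), nn.Linear(64, 10)]

    engine = OneFOneBPipelineEngine(
        module_partitions=module_partitions,  # one module (or chunk of modules) per stage
        stage_num=2,                          # number of pipeline stages
        num_microbatches=4,                   # micro-batches scheduled in 1F1B order
        device='cuda',
        chunk=1,                              # model chunks held by each stage (assumed)
        criterion=nn.CrossEntropyLoss(),      # optional loss callable, default None
        checkpoint=False)                     # activation checkpointing, default False
    # Internally this forwards to PipelineEngineBase.__init__ with use_1F1B
    # hard-coded to True, exactly as the diff above shows.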
@@ -26,13 +26,9 @@ class MultiTensorApply(object):
             raise RuntimeError(
                 "Attempted to call MultiTensorApply method, but MultiTensorApply "
                 "is not available, possibly because Apex was installed without "
-                "--cpp_ext --cuda_ext. Original import error message:",
-                MultiTensorApply.import_err)
+                "--cpp_ext --cuda_ext. Original import error message:", MultiTensorApply.import_err)
 
     def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
         self.check_avail()
 
-        return op(self.chunk_size,
-                  noop_flag_buffer,
-                  tensor_lists,
-                  *args)
+        return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args)
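The __call__ shown above simply forwards to a fused Apex kernel with the configured chunk size. Here is a minimal usage sketch, assuming Apex was built with --cpp_ext --cuda_ext so that amp_C is importable; the chunk size, tensor shapes, and the [inputs, outputs] ordering expected by amp_C.multi_tensor_scale are illustrative assumptions, not taken from this commit.

    import torch
    import amp_C  # Apex fused CUDA kernels; requires the --cpp_ext --cuda_ext build
    from colossalai.utils.multi_tensor_apply.multi_tensor_apply import MultiTensorApply

    applier = MultiTensorApply(2048 * 32)  # chunk_size handled per fused kernel launch
    noop_flag = torch.zeros(1, dtype=torch.int, device='cuda')  # kernel skips work if nonzero

    grads = [torch.randn(1024, device='cuda') for _ in range(4)]
    scaled = [torch.empty_like(g) for g in grads]

    # One fused launch per chunk scales every tensor in grads by 0.5 into scaled,
    # instead of issuing a separate elementwise kernel per tensor.
    applier(amp_C.multi_tensor_scale, noop_flag, [grads, scaled], 0.5)

This mirrors the pattern of Apex's own multi_tensor_applier; the polished __call__ is just a thin wrapper that prepends self.chunk_size before dispatching to the op.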