mirror of https://github.com/hpcaitech/ColossalAI
[fix] multi-node backward slowdown (#6134)
* remove redundant memcpy during backward
* get back record_stream

pull/6142/head
parent c2fe3137e2
commit cc40fe0e6f
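The slowdown traces to extra device-to-device copies on the backward path: `clone()` allocates fresh storage and memcpys the gradient, while `detach()` alone returns a view over the existing storage. A minimal sketch of that difference (illustrative toy code, not part of the commit):

import torch

# Toy illustration: detach() yields a view sharing the gradient's storage;
# clone() first materializes a full copy of it.
p = torch.nn.Parameter(torch.randn(8))
p.grad = torch.randn(8)

view = p.grad.detach().flatten()          # no copy: shares p.grad's storage
copy = p.grad.clone().detach().flatten()  # one extra memcpy per parameter

assert view.data_ptr() == p.grad.data_ptr()
assert copy.data_ptr() != p.grad.data_ptr()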
@@ -78,13 +78,13 @@ class BucketStore(BaseStore):
         }
         """
         for param, padding_size in zip(self._param_list, self._padding_size):
-            grad = param.grad.clone().detach().flatten()
+            grad = param.grad.detach().flatten()
             if padding_size > 0:
                 with torch.no_grad():
                     grad = torch.nn.functional.pad(grad.view(-1), [0, padding_size])
             grad_list = grad.split(grad.numel() // self._world_size)
             for rank in range(self._world_size):
-                grad_current_rank = grad_list[rank].clone().detach()
+                grad_current_rank = grad_list[rank].detach()
                 self.grad_to_param_mapping[id(grad_current_rank)] = id(param)
                 self._grad_in_bucket[rank].append(grad_current_rank)
             param.grad = None
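Because the bucket now holds views into autograd-owned gradient storage instead of private copies, cross-stream memory lifetime matters again, which is why the commit also brings back record_stream: when memory allocated on one CUDA stream is consumed on another, Tensor.record_stream keeps the caching allocator from recycling it while the consumer stream still has work in flight. A hedged sketch of the general pattern (the stream setup and the dummy op are assumptions, not the actual ColossalAI code):

import torch

# Hedged sketch; the stream name and the multiply standing in for an async
# reduce are assumptions, not the actual fix.
if torch.cuda.is_available():
    comm_stream = torch.cuda.Stream()
    grad = torch.randn(1 << 20, device="cuda")  # stand-in gradient bucket
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        reduced = grad * 2                      # stand-in for an async reduce
    # Without this, the caching allocator may hand grad's memory to a new
    # allocation on the default stream before comm_stream finishes reading it.
    grad.record_stream(comm_stream)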
@@ -110,7 +110,7 @@ class BucketStore(BaseStore):
 
         flat_grad = []
         for grad_list in self._grad_in_bucket.values():
-            flat_grad.append(_flatten_dense_tensors(grad_list))
+            flat_grad.extend(grad_list)
         flat_grad = _flatten_dense_tensors(flat_grad)
         return flat_grad
 
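_flatten_dense_tensors concatenates its inputs into one freshly allocated contiguous buffer, so the old code copied every gradient element twice: once per rank list, then once more when the per-rank flats were flattened together. Collecting all slices into a single list and flattening once halves the copy traffic. A toy before/after comparison (shapes and the bucket layout are invented for illustration):

import torch
from torch._utils import _flatten_dense_tensors

# Toy bucket: two ranks, two gradient slices each.
bucket = {0: [torch.ones(4), torch.ones(2)], 1: [torch.zeros(4), torch.zeros(2)]}

# Before: flatten per rank (copy #1), then flatten the flats (copy #2).
flat_twice = _flatten_dense_tensors([_flatten_dense_tensors(g) for g in bucket.values()])

# After: gather every slice into one list and flatten once (a single copy).
all_grads = []
for grad_list in bucket.values():
    all_grads.extend(grad_list)
flat_once = _flatten_dense_tensors(all_grads)

assert torch.equal(flat_twice, flat_once)  # same result, half the memcpy traffic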