[hotfix] add bert test for gemini fwd bwd (#2035)

Jiarui Fang 2022-11-29 11:19:52 +08:00 committed by GitHub
parent 0dbcd4a6f5
commit 96134e7be3
3 changed files with 11 additions and 13 deletions


@@ -34,17 +34,15 @@ def register_elementwise_op(op):
                                             dist_attr=input_tensor.dist_spec))


-@colo_op_impl(torch.relu_)
-def elementwise_op(input_tensor):
-    torch.relu_(input_tensor.data)
-    return input_tensor
-
-
-@colo_op_impl(Tensor.add_)
-def elementwise_op(input_tensor: ColoTensor, *args, **kwargs):
-    input_tensor = input_tensor.data.add_(*args, **kwargs)
-    return input_tensor
+# @colo_op_impl(torch.relu_)
+# def elementwise_op(input_tensor):
+#     torch.relu_(input_tensor.data)
+#     return input_tensor
+# @colo_op_impl(Tensor.add_)
+# def elementwise_op(input_tensor: ColoTensor, *args, **kwargs):
+#     input_tensor = input_tensor.data.add_(*args, **kwargs)
+#     return input_tensor


 # Tensor op
 register_elementwise_op(Tensor.abs)
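The handlers disabled above mutate `input_tensor.data` in place, which writes past autograd tracking and past the tensor subclass, so dropping them falls back to the default dispatch path. For readers unfamiliar with the pattern, here is a minimal sketch of the decorator-driven op registry that `colo_op_impl` represents; this toy registry is an illustration of the mechanism, not ColossalAI's actual dispatch code:

```python
from typing import Callable, Dict

import torch

_OVERRIDES: Dict[Callable, Callable] = {}    # torch op -> replacement handler


def op_impl(op: Callable):
    """Toy stand-in for colo_op_impl: register a handler for `op`."""

    def decorator(func: Callable) -> Callable:
        _OVERRIDES[op] = func
        return func

    return decorator


@op_impl(torch.relu_)
def relu_inplace(input_tensor: torch.Tensor) -> torch.Tensor:
    # Same shape as the removed handler: apply the op to the raw payload,
    # bypassing autograd -- exactly what makes handlers like this fragile.
    torch.relu_(input_tensor.data)
    return input_tensor


def dispatch(op: Callable, *args, **kwargs):
    """Route a call through the override table, else call the op directly."""
    return _OVERRIDES.get(op, op)(*args, **kwargs)


x = torch.tensor([-1.0, 2.0])
print(dispatch(torch.relu_, x))    # tensor([0., 2.])
```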


@@ -272,7 +272,7 @@ class ZeroDDP(ColoDDP):
             p.grad = None

     def _post_backward(self):
-        # assert self.chunk_manager.accessed_mem == 0
+        assert self.chunk_manager.accessed_mem == 0
         self._setup_grads_ptr()
         self._logger.debug(
             f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}B'

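Re-enabling this assert turns a silent chunk leak into a hard failure: by the time `_post_backward` runs, every chunk gathered during the forward/backward pass should have been released again, so the chunk manager's accessed memory must be back to zero. A toy model of that invariant (a sketch for illustration, not ColossalAI's actual chunk manager):

```python
class ToyChunkManager:
    """Minimal stand-in tracking how many bytes of chunks are materialized."""

    def __init__(self) -> None:
        self.accessed_mem = 0    # bytes currently gathered for compute

    def access_chunk(self, nbytes: int) -> None:
        self.accessed_mem += nbytes    # chunk gathered before an op runs

    def release_chunk(self, nbytes: int) -> None:
        self.accessed_mem -= nbytes    # chunk scattered/offloaded afterwards


mgr = ToyChunkManager()
mgr.access_chunk(1024)     # forward gathers a chunk
mgr.release_chunk(1024)    # backward hook releases it
# Mirrors the assert in ZeroDDP._post_backward: a non-zero value means some
# chunk stayed resident after backward, i.e. the gather/release schedule leaked.
assert mgr.accessed_mem == 0
```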

@@ -33,7 +33,7 @@ def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
 @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
 @parameterize('keep_gather', [False, True])
-@parameterize('model_name', ['gpt2', 'bert', 'resnet18'])
+@parameterize('model_name', ['gpt2', 'bert'])
 @parameterize('use_grad_checkpoint', [False, True])
 def exam_gpt_fwd_bwd(placement_policy, keep_gather, model_name: str, use_grad_checkpoint: bool = False):
     set_seed(42)
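Stacked `@parameterize` decorators expand the test into the cartesian product of all argument values, so after this change the function covers 4 placement policies x 2 gather modes x 2 models x 2 checkpoint settings = 32 configurations. A minimal sketch of how such a decorator can work (an illustration of the pattern, not `colossalai.testing.parameterize` itself):

```python
import functools


def parameterize(arg_name, values):
    """Re-run the wrapped function once per value; stacking multiplies runs."""

    def decorator(func):

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for value in values:
                kwargs[arg_name] = value
                func(*args, **kwargs)

        return wrapper

    return decorator


@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('model_name', ['gpt2', 'bert'])
def fake_test(placement_policy, model_name):
    print(placement_policy, model_name)


fake_test()    # 2 x 2 = 4 invocations
```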
@@ -78,7 +78,7 @@ def exam_gpt_fwd_bwd(placement_policy, keep_gather, model_name: str, use_grad_ch
                        torch.max(torch.abs(loss - torch_loss)).item(), loss, torch_loss)

     # FIXME(1SAA) bert and resnet18 can not pass the check_grad
-    # check_grad(model, torch_model)
+    check_grad(model, torch_model)


 def run_dist(rank, world_size, port):
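`check_grad` itself is outside this diff; below is a hedged sketch of the parameter-by-parameter comparison a helper with this signature presumably performs. The tolerances and the zero-grad fallback are assumptions, and the real helper must also account for Gemini's chunked gradient storage:

```python
import torch


def check_grad_sketch(model: torch.nn.Module,
                      torch_model: torch.nn.Module,
                      rtol: float = 1e-3,
                      atol: float = 1e-3) -> None:
    """Assert every gradient of `model` matches the plain-torch baseline."""
    for (name, p), (_, torch_p) in zip(model.named_parameters(),
                                       torch_model.named_parameters()):
        assert torch_p.grad is not None, f'baseline grad missing for {name}'
        # Assumed fallback: treat a missing grad as all zeros.
        grad = p.grad if p.grad is not None else torch.zeros_like(torch_p.grad)
        baseline = torch_p.grad.float().to(grad.device)
        assert torch.allclose(grad.float(), baseline,
                              rtol=rtol, atol=atol), f'grad mismatch at {name}'
```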