mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] add bert test for gemini fwd bwd (#2035)
parent
0dbcd4a6f5
commit
96134e7be3
|
@ -34,17 +34,15 @@ def register_elementwise_op(op):
|
|||
dist_attr=input_tensor.dist_spec))
|
||||
|
||||
|
||||
@colo_op_impl(torch.relu_)
|
||||
def elementwise_op(input_tensor):
|
||||
torch.relu_(input_tensor.data)
|
||||
return input_tensor
|
||||
|
||||
|
||||
@colo_op_impl(Tensor.add_)
|
||||
def elementwise_op(input_tensor: ColoTensor, *args, **kwargs):
|
||||
input_tensor = input_tensor.data.add_(*args, **kwargs)
|
||||
return input_tensor
|
||||
# @colo_op_impl(torch.relu_)
|
||||
# def elementwise_op(input_tensor):
|
||||
# torch.relu_(input_tensor.data)
|
||||
# return input_tensor
|
||||
|
||||
# @colo_op_impl(Tensor.add_)
|
||||
# def elementwise_op(input_tensor: ColoTensor, *args, **kwargs):
|
||||
# input_tensor = input_tensor.data.add_(*args, **kwargs)
|
||||
# return input_tensor
|
||||
|
||||
# Tensor op
|
||||
register_elementwise_op(Tensor.abs)
|
||||
|
|
|
@ -272,7 +272,7 @@ class ZeroDDP(ColoDDP):
|
|||
p.grad = None
|
||||
|
||||
def _post_backward(self):
|
||||
# assert self.chunk_manager.accessed_mem == 0
|
||||
assert self.chunk_manager.accessed_mem == 0
|
||||
self._setup_grads_ptr()
|
||||
self._logger.debug(
|
||||
f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
|
||||
|
|
|
@ -33,7 +33,7 @@ def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
|
|||
|
||||
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
|
||||
@parameterize('keep_gather', [False, True])
|
||||
@parameterize('model_name', ['gpt2', 'bert', 'resnet18'])
|
||||
@parameterize('model_name', ['gpt2', 'bert'])
|
||||
@parameterize('use_grad_checkpoint', [False, True])
|
||||
def exam_gpt_fwd_bwd(placement_policy, keep_gather, model_name: str, use_grad_checkpoint: bool = False):
|
||||
set_seed(42)
|
||||
|
@ -78,7 +78,7 @@ def exam_gpt_fwd_bwd(placement_policy, keep_gather, model_name: str, use_grad_ch
|
|||
torch.max(torch.abs(loss - torch_loss)).item(), loss, torch_loss)
|
||||
|
||||
# FIXME(1SAA) bert and resnet18 can not pass the check_grad
|
||||
# check_grad(model, torch_model)
|
||||
check_grad(model, torch_model)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
|
|
Loading…
Reference in New Issue