sync before creating empty grad

Branch: pull/433/head
Author: ver217
Date: 2022-03-16 13:40:19 +08:00
Parent: ea6905a898
Commit: fce9432f08
2 changed files with 7 additions and 4 deletions


@@ -218,6 +218,7 @@ class ShardedModelV2(nn.Module):
         else:
             self._reduce_scatter_callback(param, new_grad)
         orig_grad_data.record_stream(self.comm_stream)
+        torch.cuda.current_stream().wait_stream(self.comm_stream)
         empty_grad = torch.empty_like(grad)
         free_storage(empty_grad)
         return empty_grad
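The added `wait_stream` call makes the default stream block until the reduce-scatter queued on `self.comm_stream` has completed before the storage-free gradient placeholder is created and returned. A minimal, self-contained sketch of that record_stream / wait_stream pattern follows; the helper name, the dummy in-stream op, and the `storage().resize_(0)` stand-in for `free_storage` are illustrative assumptions, not the actual ShardedModelV2 code.

import torch

def reduce_grad_on_side_stream(grad: torch.Tensor, comm_stream: torch.cuda.Stream) -> torch.Tensor:
    # Queue the (stand-in) reduction work on the dedicated communication stream.
    with torch.cuda.stream(comm_stream):
        grad.mul_(1.0)  # placeholder for the real reduce-scatter kernel
    # Mark `grad` as used by comm_stream so the caching allocator does not
    # hand its memory to another tensor before the queued work completes.
    grad.record_stream(comm_stream)
    # Make the current (default) stream wait for comm_stream, mirroring the
    # added line in the diff, so work queued below runs after the reduction.
    torch.cuda.current_stream().wait_stream(comm_stream)
    empty_grad = torch.empty_like(grad)
    empty_grad.storage().resize_(0)  # rough analogue of free_storage(empty_grad)
    return empty_grad

if __name__ == "__main__" and torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    g = torch.randn(1024, device="cuda")
    print(reduce_grad_on_side_stream(g, side_stream).shape)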


@@ -2,12 +2,14 @@
# -*- encoding: utf-8 -*-
import copy
from asyncio.log import logger
from functools import partial
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.logging import get_dist_logger
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -18,12 +20,12 @@ from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from common import CONFIG, check_grads_padding, run_fwd_bwd
from colossalai.zero.sharded_model.utils import col_model_deepcopy
def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
logger = get_dist_logger()
logger.set_level('DEBUG')
test_models = ['repeated_computed_layers', 'resnet18', 'bert']
shard_strategy = shard_strategy()
for model_name in test_models:
@@ -60,8 +62,8 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_s
check_grads_padding(model, zero_model, loose=True)
print('overall cuda ', zero_model._memstats_collector._overall_cuda)
print('model cuda ', zero_model._memstats_collector._model_data_cuda)
# logger.debug('overall cuda ', zero_model._memstats_collector._overall_cuda)
# logger.debug('model cuda ', zero_model._memstats_collector._model_data_cuda)
@pytest.mark.dist
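The hunk ends at the `@pytest.mark.dist` marker, so the test wrapper itself is cut off. For context, a minimal sketch of how a worker like `run_dist` is typically spawned from such a marked test follows; the test name, the parametrized world sizes, and the argument values are illustrative assumptions, not the file's actual code (only `run_dist`, `free_port`, and `TensorShardStrategy` come from the diff above).

from functools import partial

import pytest
import torch.multiprocessing as mp
from colossalai.utils import free_port


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
def test_shard_model_v2(world_size):
    # Bind everything except the rank, which mp.spawn supplies as the first
    # positional argument to each worker process.
    run_func = partial(run_dist,
                       world_size=world_size,
                       port=free_port(),
                       use_zero_init_ctx=True,
                       enable_autocast=True,
                       shard_strategy=TensorShardStrategy)
    mp.spawn(run_func, nprocs=world_size)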