mirror of https://github.com/hpcaitech/ColossalAI
sync before creating empty grad
parent ea6905a898
commit fce9432f08
@@ -218,6 +218,7 @@ class ShardedModelV2(nn.Module):
            else:
                self._reduce_scatter_callback(param, new_grad)
            orig_grad_data.record_stream(self.comm_stream)
        torch.cuda.current_stream().wait_stream(self.comm_stream)
        empty_grad = torch.empty_like(grad)
        free_storage(empty_grad)
        return empty_grad
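According to the hunk header (6 lines before, 7 after) and the commit title, the single added line is torch.cuda.current_stream().wait_stream(self.comm_stream): the reduce-scatter and the record_stream call are issued on the side communication stream, and record_stream only tells the caching allocator that the gradient buffer is still in use there. Making the default stream wait on comm_stream before empty_grad is allocated and the original gradient storage is released keeps later work from racing ahead of the in-flight reduction. A minimal sketch of this producer/consumer stream pattern, with hypothetical tensor names and a plain multiply standing in for the reduce-scatter (not ColossalAI code):

import torch


def side_stream_reduce_sketch():
    if not torch.cuda.is_available():
        return None
    comm_stream = torch.cuda.Stream()
    grad = torch.randn(1024, 1024, device='cuda')

    # Let the side stream see a fully produced grad before reading it.
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        reduced = grad * 2    # stand-in for the asynchronous reduce-scatter
    # Tell the caching allocator that grad is still in use on comm_stream,
    # so its memory is not recycled for a new allocation too early.
    grad.record_stream(comm_stream)
    # Order the default stream after the communication before any later
    # allocation or free is enqueued; this is the synchronization the commit adds.
    torch.cuda.current_stream().wait_stream(comm_stream)
    return reduced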
@@ -2,12 +2,14 @@
# -*- encoding: utf-8 -*-

import copy
from asyncio.log import logger
from functools import partial

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.logging import get_dist_logger
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -18,12 +20,12 @@ from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP

from common import CONFIG, check_grads_padding, run_fwd_bwd
from colossalai.zero.sharded_model.utils import col_model_deepcopy


def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    logger = get_dist_logger()
    logger.set_level('DEBUG')
    test_models = ['repeated_computed_layers', 'resnet18', 'bert']
    shard_strategy = shard_strategy()
    for model_name in test_models:
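The driver that spawns run_dist lies outside this hunk. A typical launcher for a worker with this signature, sketched under the assumption that it reuses the run_dist, TensorShardStrategy, and free_port names from this file (the file's actual test function may differ):

from functools import partial

import pytest
import torch.multiprocessing as mp
from colossalai.utils import free_port


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
def test_sharded_model_sketch(world_size):
    # mp.spawn supplies the rank as the first positional argument;
    # everything else is bound ahead of time with functools.partial.
    run_func = partial(run_dist,
                       world_size=world_size,
                       port=free_port(),
                       use_zero_init_ctx=True,
                       enable_autocast=True,
                       shard_strategy=TensorShardStrategy)
    mp.spawn(run_func, nprocs=world_size)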
@@ -60,8 +62,8 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy):
        check_grads_padding(model, zero_model, loose=True)

        print('overall cuda ', zero_model._memstats_collector._overall_cuda)
        print('model cuda ', zero_model._memstats_collector._model_data_cuda)
        # logger.debug('overall cuda ', zero_model._memstats_collector._overall_cuda)
        # logger.debug('model cuda ', zero_model._memstats_collector._model_data_cuda)


@pytest.mark.dist
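check_grads_padding compares the gradients of the reference DDP model with those of the sharded zero_model, and loose=True points to relaxed tolerances. A hypothetical, simplified stand-in for such a check (the real helper in common.py also has to handle the padding and sharding of gradients, which this sketch ignores):

import torch


def check_grads_loose_sketch(ref_model, test_model, rtol=1e-3, atol=1e-1):
    # Compare parameter gradients of two models with relaxed tolerances.
    # Illustration only; not the actual check_grads_padding helper.
    for (name, p_ref), (_, p_test) in zip(ref_model.named_parameters(), test_model.named_parameters()):
        if p_ref.grad is None or p_test.grad is None:
            continue
        grad_ref = p_ref.grad.detach().float()
        grad_test = p_test.grad.detach().float().to(grad_ref.device)
        assert torch.allclose(grad_ref, grad_test, rtol=rtol, atol=atol), f'gradient mismatch in {name}'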