mirror of https://github.com/hpcaitech/ColossalAI
[zero] memtracer to record cuda memory usage of model data and overall system (#395)
parent a37bf1bc42
commit 21dc54e019
@@ -4,6 +4,9 @@ from colossalai.utils import get_current_device
 from colossalai.zero.shard_utils import BaseShardStrategy
 
 from ._base_ophook import BaseOpHook
+from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
+from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+from typing import Optional
 
 
 @OPHOOKS.register_module
@@ -12,14 +15,17 @@ class ZeroHook(BaseOpHook):
     A hook to process sharded param for ZeRO method.
     """
 
-    def __init__(self, shard_strategy: BaseShardStrategy):
+    def __init__(self, shard_strategy: BaseShardStrategy, memstarts_collector: Optional[MemStatsCollector]):
         super().__init__()
         self.shard_strategy = shard_strategy
         # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
         self.computing_device = torch.device(f'cuda:{get_current_device()}')
 
+        self._memstarts_collector = memstarts_collector
+
     def pre_fwd_exec(self, module: torch.nn.Module, *args):
         tensor_list = []
+        global_model_data_tracer = ModelDataTracer()
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
             tensor_list.append(param.col_attr.data)
@@ -27,8 +33,12 @@ class ZeroHook(BaseOpHook):
         for param in module.parameters():
             if param.col_attr.data.device != self.computing_device:
                 param.col_attr.data.to(self.computing_device)
+                global_model_data_tracer.add_tensor(param.col_attr.data.payload)
             param.data = param.col_attr.data.payload
 
+        if self._memstarts_collector:
+            self._memstarts_collector.sample_memstats()
+
     def post_fwd_exec(self, module: torch.nn.Module, *args):
         tensor_list = []
         for param in module.parameters():
@@ -40,6 +50,7 @@ class ZeroHook(BaseOpHook):
 
     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
         tensor_list = []
+        global_model_data_tracer = ModelDataTracer()
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
             tensor_list.append(param.col_attr.data)
@@ -47,6 +58,7 @@ class ZeroHook(BaseOpHook):
         for param in module.parameters():
             if param.col_attr.data.device != self.computing_device:
                 param.col_attr.data.to(self.computing_device)
+                global_model_data_tracer.add_tensor(param.col_attr.data.payload)
             param.data = param.col_attr.data.payload
             # Store local accumulated grad shard
             if param.grad is not None:
@@ -60,6 +72,8 @@ class ZeroHook(BaseOpHook):
                 # The grad here must be locally computed full grad in this backward pass
                 assert param.grad.shape == param.col_attr.data.origin_shape
                 param.col_attr.bwd_count += 1
+        if self._memstarts_collector:
+            self._memstarts_collector.sample_memstats()
 
     def post_bwd_exec(self, module: torch.nn.Module, input):
         tensor_list = []
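Taken together, these hunks make every hooked submodule register newly moved CUDA payloads with the global ModelDataTracer and take one memory sample per pre-forward and pre-backward call. The same sampling pattern can be reproduced with a plain PyTorch forward pre-hook; a minimal, self-contained sketch (simplified, not the ZeroHook/sharding API, and assuming a CUDA device):

import torch
import torch.nn as nn

samples = []

def sample_memstats(module, inputs):
    # Record allocated CUDA bytes right before the module runs, mirroring the
    # sampling point that this commit adds to ZeroHook.pre_fwd_exec().
    samples.append(torch.cuda.memory_allocated())

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8)).cuda()
for submodule in model.modules():
    submodule.register_forward_pre_hook(sample_memstats)

model(torch.randn(4, 64, device='cuda'))
print(f'{len(samples)} samples, max sampled usage: {max(samples)} bytes')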
@@ -1,60 +1,19 @@
 import torch
-from colossalai.utils.commons.singleton_meta import SingletonMeta
-from colossalai.zero.sharded_param import ShardedTensor
-
-from typing import Union
+from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
 
 
-def col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
-    if isinstance(t, ShardedTensor):
-        target = t.payload
-    else:
-        target = t
-    return target.numel() * target.element_size()
-
-
 def col_move_to_cpu(t: torch.Tensor):
     assert isinstance(t, torch.Tensor)
     if t.device.type == 'cpu':
         return
 
+    ModelDataTracer().delete_tensor(t)
     t.data = t.data.cpu()
 
 
-class ModelDataTracer(metaclass=SingletonMeta):
-    """
-    A singleton to trace model data usage during runtime.
-    """
-
-    def __init__(self) -> None:
-        self._cpu_usage = 0
-        self._cuda_usage = 0
-
-    def trace_tensor(self, t: torch.Tensor):
-        mem_use = col_tensor_mem_usage(t)
-        if t.device.type == 'cpu':
-            self._cpu_usage += mem_use
-        elif t.device.type == 'cuda':
-            self._cuda_usage += mem_use
-        else:
-            raise RuntimeError
-
-    def detach_tensor(self, t: torch.Tensor):
-        mem_use = col_tensor_mem_usage(t)
-        if t.device.type == 'cpu':
-            self._cpu_usage -= mem_use
-        elif t.device.type == 'cuda':
-            self._cuda_usage -= mem_use
-        else:
-            raise RuntimeError
-
-    @property
-    def cpu_usage(self):
-        return self._cpu_usage
-
-    @property
-    def cuda_usage(self):
-        return self._cuda_usage
-
-
-GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
-
-
-def col_allocate_payload(device: torch.device) -> torch.Tensor:
+def col_modeldata_allocate(device: torch.device) -> torch.Tensor:
     pass
 
 
-def col_release_payload(t: torch.Tensor):
+def col_modeldata_release(t: torch.Tensor):
     pass
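With this change, col_move_to_cpu keeps the tracer's CUDA counter consistent: the payload is deregistered from ModelDataTracer before its storage is copied to host memory. A short usage sketch of the two calls as they appear in this diff, assuming a CUDA device and a fresh tracer:

import torch
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
from colossalai.utils.memory_tracer.allocator import col_move_to_cpu

payload = torch.randn(1024, device='cuda')   # fp32: 1024 * 4 = 4096 bytes
ModelDataTracer().add_tensor(payload)
print(ModelDataTracer().cuda_usage)          # 4096, if nothing else was registered

col_move_to_cpu(payload)                     # delete_tensor() first, then the copy to CPU
print(payload.device)                        # cpu
print(ModelDataTracer().cuda_usage)          # back to 0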
@@ -6,7 +6,7 @@ from colossalai.utils import get_current_device
 import torch
 
 
-def _get_cuda_memory_used(device: torch.device) -> int:
+def get_cuda_memory_used(device: torch.device) -> int:
     """
     Get the free memory info of device.
     :param device: device id
@@ -87,7 +87,7 @@ class AsyncMemoryMonitor:
         while self.keep_measuring:
             max_usage = max(
                 max_usage,
-                _get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')),
+                get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')),
             )
             sleep(self.interval)
         return max_usage
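Only the rename is visible here; the body of get_cuda_memory_used lies outside the hunk. For intuition, a hypothetical stand-in with the same shape could report torch.cuda.memory_allocated (the real helper may account for reserved or cached memory differently):

import torch

def cuda_memory_used(device: torch.device) -> int:
    # Hypothetical stand-in for the renamed helper: bytes of CUDA memory
    # currently allocated by tensors on `device`.
    return torch.cuda.memory_allocated(device)

print(cuda_memory_used(torch.device('cuda:0')))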
@@ -0,0 +1,11 @@
+from colossalai.zero.sharded_param import ShardedTensor
+from typing import Union
+import torch
+
+
+def col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
+    if isinstance(t, ShardedTensor):
+        target = t.payload
+    else:
+        target = t
+    return target.numel() * target.element_size()
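col_tensor_mem_usage reports the payload size in bytes, numel() * element_size(), and unwraps a ShardedTensor to its payload first. A quick check of the arithmetic:

import torch
from colossalai.utils.memory_tracer.commons import col_tensor_mem_usage

assert col_tensor_mem_usage(torch.ones(10)) == 40                        # fp32: 10 * 4 bytes
assert col_tensor_mem_usage(torch.ones(10, dtype=torch.float16)) == 20   # fp16: 10 * 2 bytes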
@@ -0,0 +1,81 @@
+from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+from .async_memtracer import get_cuda_memory_used
+from colossalai.utils import get_current_device
+
+import torch
+
+
+class SamplingCounter:
+
+    def __init__(self) -> None:
+        self._samplint_cnt = 0
+
+    def advance(self):
+        self._samplint_cnt += 1
+
+    @property
+    def sampling_cnt(self):
+        return self._samplint_cnt
+
+    def reset(self):
+        self._samplint_cnt = 0
+
+
+class MemStatsCollector:
+
+    def __init__(self) -> None:
+        """
+        Collecting Memory Statistics.
+        It has two phases.
+        1. Collection Phase: collect memory usage statistics
+        2. Runtime Phase: do not collect statistics.
+        """
+        self._sampling_cnter = SamplingCounter()
+        self._model_data_cuda = []
+        self._overall_cuda = []
+
+        # TODO(jiaruifang) Now no cpu mem stats collecting
+        self._model_data_cpu = []
+        self._overall_cpu = []
+
+        self._start_flag = False
+
+    def start_collection(self):
+        self._start_flag = True
+
+    def finish_collection(self):
+        self._start_flag = False
+
+    def sample_memstats(self) -> None:
+        """
+        Sampling memory statistics.
+        Record the current model data CUDA memory usage as well as system CUDA memory usage.
+        """
+        if self._start_flag:
+            sampling_cnt = self._sampling_cnter.sampling_cnt
+            assert sampling_cnt == len(self._overall_cuda)
+            self._model_data_cuda.append(ModelDataTracer().cuda_usage)
+            self._overall_cuda.append(get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')))
+        self._sampling_cnter.advance()
+
+    def fetch_memstats(self) -> (int, int):
+        """
+        Returns the CUDA memory usage of model data and the overall CUDA memory usage.
+        """
+        sampling_cnt = self._sampling_cnter.sampling_cnt
+        if len(self._model_data_cuda) < sampling_cnt:
+            raise RuntimeError
+        return (self._model_data_cuda[sampling_cnt], self._overall_cuda[sampling_cnt])
+
+    def reset_sampling_cnter(self) -> None:
+        self._sampling_cnter.reset()
+
+    def clear(self) -> None:
+        self._model_data_cuda = []
+        self._overall_cuda = []
+
+        self._model_data_cpu = []
+        self._overall_cpu = []
+
+        self._start_flag = False
+        self._sampling_cnter.reset()
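Usage follows a two-phase protocol: samples are only recorded between start_collection() and finish_collection(); afterwards sample_memstats() still advances the counter, so fetch_memstats() returns the values recorded at the matching position of an earlier iteration. A condensed sketch of that lifecycle, assuming a CUDA device (the test file added later in this commit exercises the same flow):

from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector

collector = MemStatsCollector()

# Iteration 0: record one (model data, overall) pair per sampling point.
collector.start_collection()
collector.sample_memstats()   # sample 0
collector.sample_memstats()   # sample 1
collector.finish_collection()
collector.reset_sampling_cnter()

# Iteration 1: nothing new is recorded, but the counter advances in step,
# so the fetched values correspond to the same point of iteration 0.
collector.sample_memstats()
model_data_cuda, overall_cuda = collector.fetch_memstats()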
@@ -0,0 +1,34 @@
+from colossalai.utils.commons.singleton_meta import SingletonMeta
+from colossalai.utils.memory_tracer.commons import col_tensor_mem_usage
+import torch
+
+
+class ModelDataTracer(metaclass=SingletonMeta):
+    """
+    A singleton to trace model data usage during runtime.
+    Its APIs (add_tensor, delete_tensor) have to be called whenever model data
+    is allocated, released or moved.
+
+    NOTE: currently the class only traces CUDA memory usage.
+    """
+
+    def __init__(self) -> None:
+        self._cuda_usage = 0
+
+    def add_tensor(self, t: torch.Tensor):
+        assert isinstance(t, torch.Tensor), "ModelDataTracer add_tensor() should accept a torch.Tensor"
+        mem_use = col_tensor_mem_usage(t)
+        self._cuda_usage += mem_use
+
+    def delete_tensor(self, t: torch.Tensor):
+        assert isinstance(t, torch.Tensor), "ModelDataTracer delete_tensor() should accept a torch.Tensor"
+        mem_use = col_tensor_mem_usage(t)
+        self._cuda_usage -= mem_use
+
+    @property
+    def cpu_usage(self):
+        return self._cpu_usage
+
+    @property
+    def cuda_usage(self):
+        return self._cuda_usage
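Because of SingletonMeta, every ModelDataTracer() call returns the same instance, so the ZeroHook, ZeroInitContext and the allocator above all update one shared CUDA byte counter. A small sketch of the accounting, assuming a CUDA device and that nothing else has been registered yet:

import torch
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer

assert ModelDataTracer() is ModelDataTracer()   # singleton: one shared counter

t = torch.empty(256, device='cuda')             # fp32: 256 * 4 = 1024 bytes
ModelDataTracer().add_tensor(t)
print(ModelDataTracer().cuda_usage)             # 1024

ModelDataTracer().delete_tensor(t)
print(ModelDataTracer().cuda_usage)             # 0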
@@ -0,0 +1,43 @@
+from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
+from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+import torch
+
+
+def test_mem_collector():
+    collector = MemStatsCollector()
+
+    collector.start_collection()
+
+    a = torch.randn(10).cuda()
+
+    # sampling at time 0
+    collector.sample_memstats()
+
+    m_a = torch.randn(10).cuda()
+    ModelDataTracer().add_tensor(m_a)
+    b = torch.randn(10).cuda()
+
+    # sampling at time 1
+    collector.sample_memstats()
+
+    a = b
+
+    # sampling at time 2
+    collector.sample_memstats()
+
+    collector.finish_collection()
+    collector.reset_sampling_cnter()
+
+    # do nothing after collection, just advance sampling cnter
+    collector.sample_memstats()
+    collector.sample_memstats()
+
+    cuda_use, overall_use = collector.fetch_memstats()
+    print(cuda_use, overall_use)
+
+    print(collector._model_data_cuda)
+    print(collector._overall_cuda)
+
+
+if __name__ == '__main__':
+    test_mem_collector()
@@ -3,10 +3,11 @@ import functools
 import torch
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param import ShardedParamV2
-from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
+
+from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
 
 # Inserts _post_init_method at the end of init method
 
 # for all sub classes of torch.nn.Module
 class InsertPostInitMethodToModuleSubClasses(object):
@@ -152,7 +153,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
         if self.shard_param:
             self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
-            GLOBAL_MODEL_DATA_TRACER.trace_tensor(param.col_attr._data_sharded_tensor.payload)
+            ModelDataTracer().add_tensor(param.col_attr._data_sharded_tensor.payload)
         if param.col_attr.grad and self.shard_grad:
             self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
-            GLOBAL_MODEL_DATA_TRACER.trace_tensor(param.col_attr._grad_sharded_tensor.payload)
+            ModelDataTracer().add_tensor(param.col_attr._grad_sharded_tensor.payload)
@@ -17,7 +17,8 @@ from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
-
+from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
+from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
 from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
                            get_gradient_predivide_factor)
 
@@ -33,7 +34,8 @@ class ShardedModelV2(nn.Module):
                  fp32_reduce_scatter: bool = False,
                  offload_config: Optional[dict] = None,
                  gradient_predivide_factor: Optional[float] = 1.0,
-                 shard_param: bool = True):
+                 shard_param: bool = True,
+                 use_memory_tracer: bool = False):
         r"""
         A demo to reconfigure zero1 shared_model.
         Currently do not consider the Optimizer States.
@@ -59,8 +61,16 @@ class ShardedModelV2(nn.Module):
             if self.shard_param:
                 self.shard_strategy.shard([param.col_attr.data])
 
+        # Init Memory Statistics Collector
+        self._use_memory_tracer = use_memory_tracer
+        if self._use_memory_tracer:
+            self._memstats_collector = MemStatsCollector()
+        else:
+            self._memstats_collector = None
+        self._iter_cnter = 0
+
         # Register hooks
-        register_ophooks_recursively(self.module, [ZeroHook(self.shard_strategy)])
+        register_ophooks_recursively(self.module, [ZeroHook(self.shard_strategy, self._memstats_collector)])
         self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
         self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
 
@@ -84,6 +94,9 @@ class ShardedModelV2(nn.Module):
         return self._cpu_offload
 
     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
+        if self._iter_cnter == 0 and self._memstats_collector:
+            # the operation will affect the flag in ZeroHook
+            self._memstats_collector.start_collection()
         args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
         outputs = self.module(*args, **kwargs)
         return outputs
@@ -98,6 +111,12 @@ class ShardedModelV2(nn.Module):
 
     @torch.no_grad()
     def _final_backward_hook(self) -> None:
+        if self._iter_cnter == 0 and self._memstats_collector:
+            self._memstats_collector.finish_collection()
+        if self._memstats_collector:
+            self._memstats_collector.reset_sampling_cnter()
+        self._iter_cnter += 1
+
         if self._require_backward_grad_sync:
             # Flush any unreduced buckets in the post_backward stream.
             with torch.cuda.stream(self.comm_stream):
@@ -185,8 +204,10 @@ class ShardedModelV2(nn.Module):
             reduced_grad.data = cast_tensor_to_fp32(reduced_grad.data)
 
         # Maybe offload
+        # TODO() optimize GPU->CPU bandwidth utilization
         if self._cpu_offload:
-            reduced_grad.data = reduced_grad.data.cpu()
+            col_move_to_cpu(reduced_grad)
+            # reduced_grad.data = reduced_grad.data.cpu()
 
         if param.col_attr.grad is None:
             param.col_attr.grad = reduced_grad.data
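On the user side, the whole tracer is switched on by one constructor flag: collection starts automatically in the first forward pass, is closed in the final backward hook, and the recorded series can then be read off the collector. A sketch of the call pattern; `model`, `shard_strategy` and `data` are placeholders assumed to be set up elsewhere (for example inside ZeroInitContext, as in the tests below):

# Placeholder setup: `model`, `shard_strategy` and `data` are assumed to exist.
zero_model = ShardedModelV2(model, shard_strategy, use_memory_tracer=True)

out = zero_model(data)   # iteration 0: _memstats_collector.start_collection() fires here
out.sum().backward()     # hooks sample before each op; _final_backward_hook() closes collection

print(zero_model._memstats_collector._model_data_cuda)  # model-data CUDA bytes per sample
print(zero_model._memstats_collector._overall_cuda)     # overall CUDA usage per sample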
@@ -143,7 +143,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 # We have to use `copy_payload` instead of `reset_payload`
                 # Since p.data is fp32 and p.col_attr.data is fp16
 
-                # TODO() optimize this line
+                # TODO() optimize this line CPU (fp32) -> GPU (fp16)
                 p.col_attr.data.copy_payload(p.data)
 
                 if not is_param_sharded:
@@ -56,6 +56,7 @@ def test_activation_checkpointing(cpu_offload):
 
     assert torch.all(data.grad == data_.grad), 'Gradient of the input does not match'
+    torch.cuda.empty_cache()
 
     # as seed manager is singleton
     # if we don't reset seeds here,
     # other tests will fail if running together with this test
@@ -9,12 +9,12 @@ import torch
 import torch.multiprocessing as mp
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
-from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from tests.components_to_test.registry import non_distributed_component_funcs
 
 from common import CONFIG
+from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
 
 
 def run_dist(rank, world_size, port, init_device, shard_strategy):
@@ -37,13 +37,10 @@ def run_dist(rank, world_size, port, init_device, shard_strategy):
         assert param.col_attr.data.payload.device.type == init_device.type, \
             f'{param.col_attr.data.payload.device.type} vs. {init_device.type}'
 
-    print(f'cpu usgae {GLOBAL_MODEL_DATA_TRACER.cpu_usage}')
-    print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
+    print(f'cuda usage {ModelDataTracer().cuda_usage}')
     print(f'numel {model_numel_tensor}')
     if init_device.type == 'cuda':
-        assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
-    elif init_device.type == 'cpu':
-        assert (GLOBAL_MODEL_DATA_TRACER.cpu_usage > 0)
+        assert (ModelDataTracer().cuda_usage > 0)
 
 
 @pytest.mark.dist
@@ -60,5 +57,5 @@ def test_zero_init_context(world_size, init_device, shard_strategy):
 
 
 if __name__ == '__main__':
-    test_zero_init_context(2, torch.device('cpu'), TensorShardStrategy)
-    test_zero_init_context(2, torch.device(f'cuda:{get_current_device()}'), TensorShardStrategy)
+    # test_zero_init_context(2, torch.device('cpu'), TensorShardStrategy)
+    test_zero_init_context(4, torch.device('cpu'), BucketTensorShardStrategy)
@@ -18,6 +18,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from common import CONFIG, check_grads_padding, run_fwd_bwd
+from colossalai.zero.sharded_model.utils import col_model_deepcopy
 
 
 def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy):
@@ -33,12 +34,12 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy):
 
     if use_zero_init_ctx:
         with ZeroInitContext(convert_fp16=True,
-                             target_device=torch.device('cpu'),
+                             target_device=torch.device(f'cpu:0'),
                              shard_strategy=shard_strategy,
                              shard_param=True,
                              rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly):
             zero_model = model_builder(checkpoint=True)
-        zero_model = ShardedModelV2(zero_model, shard_strategy)
+        zero_model = ShardedModelV2(zero_model, shard_strategy, use_memory_tracer=True)
 
         model = model_builder(checkpoint=True).half()
         col_model_deepcopy(zero_model, model)
@@ -59,6 +60,9 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy):
 
     check_grads_padding(model, zero_model, loose=True)
 
+    print('overall cuda ', zero_model._memstats_collector._overall_cuda)
+    print('model cuda ', zero_model._memstats_collector._model_data_cuda)
+
 
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
 import copy
 from functools import partial
 
@@ -82,4 +79,4 @@ def test_sharded_optim_v2(world_size, cpu_offload, shard_strategy):
 
 
 if __name__ == '__main__':
-    test_sharded_optim_v2(world_size=2, cpu_offload=True, shard_strategy=TensorShardStrategy)
+    test_sharded_optim_v2(world_size=2, cpu_offload=True, shard_strategy=TensorShardStrategy)