[refactor] refactor the memory utils (#715)

Jiarui Fang 2022-04-11 16:47:57 +08:00 committed by GitHub
parent dbd96fe90a
commit 193dc8dacb
20 changed files with 218 additions and 308 deletions


@@ -30,7 +30,7 @@ class ZeroHook(BaseOpHook):
self.process_group = process_group
# NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
self.computing_device = torch.device(f'cuda:{get_current_device()}')
self.computing_device = get_current_device()
self._memstarts_collector = memstarts_collector
self._stateful_tensor_mgr = stateful_tensor_mgr


@@ -8,7 +8,7 @@ from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_paral
sync_model_param, disposable)
from .data_sampler import DataParallelSampler, get_dataloader
from .gradient_accumulation import accumulate_gradient
from .memory_utils.memory_monitor import report_memory_usage
from .memory import report_memory_usage, colo_device_memory_used, colo_set_process_memory_fraction, colo_device_memory_capacity
from .timer import MultiTimer, Timer
from .tensor_detector import TensorDetector
@@ -17,7 +17,8 @@ __all__ = [
'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'is_using_sequence', 'conditional_context',
'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',
'report_memory_usage', 'colo_device_memory_capacity', 'colo_device_memory_used', 'colo_set_process_memory_fraction',
'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler', 'get_dataloader',
'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',
'ensure_path_exists', 'disposable'
]


@@ -20,13 +20,15 @@ def set_to_cuda(models):
return models.to(get_current_device())
def get_current_device():
"""Returns the index of a currently selected device (gpu/cpu).
def get_current_device() -> torch.device:
"""
Returns the currently selected device (GPU/CPU).
If CUDA is available, return the current GPU device; otherwise return the CPU device.
"""
if torch.cuda.is_available():
return torch.cuda.current_device()
return torch.device(f'cuda:{torch.cuda.current_device()}')
else:
return 'cpu'
return torch.device('cpu')
def synchronize():
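For context, a minimal usage sketch of the refactored helper (a sketch under the assumption that torch is importable; it works with or without CUDA): get_current_device() now returns a torch.device directly, so call sites no longer need to wrap the returned index in torch.device(f'cuda:{...}').

import torch
from colossalai.utils.cuda import get_current_device

device = get_current_device()        # torch.device('cuda:N') if CUDA is available, else torch.device('cpu')
x = torch.empty(4, device=device)    # the returned object can be passed anywhere a torch.device is expected
print(x.device)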

colossalai/utils/memory.py (new file, 147 lines)

@@ -0,0 +1,147 @@
import torch
import gc
import psutil
from collections import namedtuple
from colossalai.context.parallel_mode import ParallelMode
from colossalai.utils import get_current_device
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from colossalai.logging import get_dist_logger
_GLOBAL_CUDA_MEM_FRACTION = 1.0
def _bytes_to_MB(val, decimal=2):
"""A byte-to-Megabyte converter, default using binary notation.
:param val: X bytes to convert
:return: X' MB
"""
return round(val / (1024 * 1024), decimal)
# copy from PatrickStar
def _get_cpu_memory_info():
ps_mem_info = namedtuple("ps_mem_info", ["total", "free", "cached", "buffers", "used"])
try:
# psutil reads the memory info from /proc/meminfo,
# which reports the host memory instead of
# that of the container.
# Here we try to read the container memory with the method in:
# https://stackoverflow.com/a/46213331/5163915
mems = {}
with open("/sys/fs/cgroup/memory/memory.meminfo", "rb") as f:
for line in f:
fields = line.split()
mems[fields[0]] = int(fields[1]) * 1024
total = mems[b"MemTotal:"]
free = mems[b"MemFree:"]
cached = mems[b"Cached:"]
buffers = mems[b"Buffers:"]
used = total - free - cached - buffers
if used < 0:
used = total - free
mem_info = ps_mem_info(total=total, free=free, cached=cached, buffers=buffers, used=used)
except FileNotFoundError:
mems = psutil.virtual_memory()
mem_info = ps_mem_info(
total=mems.total,
free=mems.free,
cached=mems.cached,
buffers=mems.buffers,
used=mems.used,
)
return mem_info
def report_memory_usage(message, logger=None, report_cpu=False):
"""Calculate and print RAM usage (in GB)
Args:
message (str): A prefix message to add in the log.
logger (:class:`colossalai.logging.DistributedLogger`): The logger used to record memory information.
report_cpu (bool, optional): Whether to report CPU memory.
Raises:
EnvironmentError: Raise error if no distributed environment has been initialized.
"""
if not gpc.is_initialized(ParallelMode.GLOBAL):
raise EnvironmentError("No distributed environment is initialized")
gpu_allocated = _bytes_to_MB(torch.cuda.memory_allocated())
gpu_max_allocated = _bytes_to_MB(torch.cuda.max_memory_allocated())
gpu_cached = _bytes_to_MB(torch.cuda.memory_reserved())
gpu_max_cached = _bytes_to_MB(torch.cuda.max_memory_reserved())
full_log = f"{message}: GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, " \
+ f"cached: {gpu_cached} MB, max cached: {gpu_max_cached} MB"
if report_cpu:
# python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports
gc.collect()
vm_stats = psutil.virtual_memory()
vm_used = _bytes_to_MB(vm_stats.total - vm_stats.available)
full_log += f", CPU Virtual Memory: used = {vm_used} MB, percent = {vm_stats.percent}%"
if logger is None:
logger = get_dist_logger()
logger.info(full_log)
# get the peak memory to report correct data, so reset the counter for the next call
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
torch.cuda.reset_peak_memory_stats()
def colo_device_memory_capacity(device: torch.device) -> int:
"""
Get the memory capacity of the device.
Args:
device (torch.device): a device
Returns:
int: capacity in bytes
"""
assert isinstance(device, torch.device)
if device.type == 'cpu':
mem_info = _get_cpu_memory_info()
return mem_info.total / gpc.get_world_size(ParallelMode.DATA)
if device.type == 'cuda':
return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION
def colo_device_memory_used(device: torch.device) -> int:
"""
Get the amount of device memory used by the current process on the given device.
Args:
device (torch.device): a device
Returns:
int: memory size in bytes
"""
if device.type == 'cpu':
mem_info = _get_cpu_memory_info()
# FIXME(jiaruifang) we need to get how many processes are sharing the CPU memory.
ret = mem_info.used / gpc.get_world_size(ParallelMode.DATA)
return ret
elif device.type == 'cuda':
ret: int = torch.cuda.memory_allocated(device)
# get the peak memory to report correct data, so reset the counter for the next call
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
torch.cuda.reset_peak_memory_stats(device)
return ret
def colo_set_process_memory_fraction(ratio: float) -> None:
"""colo_set_process_memory_fraction
set how much cuda memory used on the gpu belonging to the current process.
Args:
ratio (float): a ratio between 0. ~ 1.
"""
global _GLOBAL_CUDA_MEM_FRACTION
_GLOBAL_CUDA_MEM_FRACTION = ratio
torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device())
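A short usage sketch of the new colossalai.utils.memory API defined above. It assumes colossalai.launch has already initialized the distributed context (report_memory_usage raises otherwise, and the CPU paths divide by the data-parallel world size) and that a CUDA device is present; the message string is illustrative.

import torch
from colossalai.utils import get_current_device
from colossalai.utils.memory import (colo_device_memory_capacity, colo_device_memory_used,
                                     colo_set_process_memory_fraction, report_memory_usage)

# Cap this process at half of its GPU memory; capacities reported afterwards shrink accordingly.
colo_set_process_memory_fraction(0.5)

cuda_dev = get_current_device()                           # torch.device('cuda:N')
capacity = colo_device_memory_capacity(cuda_dev)          # bytes, scaled by the fraction set above
used = colo_device_memory_used(cuda_dev)                  # bytes currently allocated by this process
cpu_used = colo_device_memory_used(torch.device('cpu'))   # this process' share of host memory

report_memory_usage("after forward pass", report_cpu=True)   # logs GPU and CPU stats in MB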


@@ -4,7 +4,7 @@ import pickle
import torch
from colossalai.utils.memory_utils.utils import colo_device_memory_used
from colossalai.utils.memory import colo_device_memory_used
from colossalai.utils import get_current_device


@@ -1,5 +1,5 @@
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.utils import colo_device_memory_used
from colossalai.utils.memory import colo_device_memory_used
from colossalai.utils.memory_tracer.async_memtracer import AsyncMemoryMonitor
import torch
import time


@@ -1,61 +0,0 @@
import torch
from colossalai.zero.sharded_param import ShardedParamV2
from colossalai.utils import get_current_device
from typing import List
class BucketizedTensorCopy(object):
def __init__(
self,
chunk_size: int,
):
r"""
torch.nn.Parameter CPU (fp32) -> ShardedParam GPU (fp16)
TODO(jiaruifang) The class is a little bit hardcoded
I will make it more general later.
"""
self.chunk_size = chunk_size
self._offset = 0
self._cpu_buffer = torch.empty(chunk_size, dtype=torch.float, device=torch.device("cpu:0"), pin_memory=True)
self._cuda_buffer = torch.empty(chunk_size,
dtype=torch.half,
device=torch.device(f"cuda:{get_current_device()}"))
self._buffered_param_list: List[ShardedParamV2] = []
self._numel_list = []
def copy(self, src_param: torch.nn.Parameter, target_param: ShardedParamV2):
assert isinstance(target_param, ShardedParamV2)
assert isinstance(src_param, torch.nn.Parameter)
numel = src_param.numel()
if self._offset + numel > self.chunk_size:
self.flush()
assert src_param.data.device.type == 'cpu'
self._cpu_buffer.narrow(0, self._offset, numel).copy_(src_param.data.view(-1))
self._buffered_param_list.append(target_param)
self._numel_list.append(numel)
self._offset += numel
def flush(self):
"""
flush to cuda memory
"""
self._cuda_buffer.copy_(self._cpu_buffer)
flush_offset = 0
for sparam, numel in zip(self._buffered_param_list, self._numel_list):
sparam.sharded_data_tensor.copy_payload(self._cpu_buffer.narrow(0, flush_offset, numel))
flush_offset += numel
self.reset()
def reset(self):
self._buffered_param_list = []
self._numel_list = []
self._offset = 0


@@ -1,67 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import gc
import psutil
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
def bytes_to_GB(val, decimal=2):
"""A byte-to-Gigabyte converter, default using binary notation.
:param val: X bytes to convert
:return: X' GB
"""
return round(val / (1024 * 1024 * 1024), decimal)
def bytes_to_MB(val, decimal=2):
"""A byte-to-Megabyte converter, default using binary notation.
:param val: X bytes to convert
:return: X' MB
"""
return round(val / (1024 * 1024), decimal)
def report_memory_usage(message, logger=None, report_cpu=False):
"""Calculate and print RAM usage (in GB)
Args:
message (str): A prefix message to add in the log.
logger (:class:`colossalai.logging.DistributedLogger`): The logger used to record memory information.
report_cpu (bool, optional): Whether to report CPU memory.
Raises:
EnvironmentError: Raise error if no distributed environment has been initialized.
"""
if not gpc.is_initialized(ParallelMode.GLOBAL):
raise EnvironmentError("No distributed environment is initialized")
gpu_allocated = bytes_to_MB(torch.cuda.memory_allocated())
gpu_max_allocated = bytes_to_MB(torch.cuda.max_memory_allocated())
gpu_cached = bytes_to_MB(torch.cuda.memory_reserved())
gpu_max_cached = bytes_to_MB(torch.cuda.max_memory_reserved())
full_log = f"{message}: GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, " \
+ f"cached: {gpu_cached} MB, max cached: {gpu_max_cached} MB"
if report_cpu:
# python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports
gc.collect()
vm_stats = psutil.virtual_memory()
vm_used = bytes_to_MB(vm_stats.total - vm_stats.available)
full_log += f", CPU Virtual Memory: used = {vm_used} MB, percent = {vm_stats.percent}%"
if logger is None:
logger = get_dist_logger()
logger.info(full_log)
# get the peak memory to report correct data, so reset the counter for the next call
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
torch.cuda.reset_peak_memory_stats()


@@ -1,82 +0,0 @@
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.utils import get_current_device
from collections import namedtuple
import psutil
from colossalai.core import global_context as gpc
_GLOBAL_CUDA_MEM_FRACTION = 1.0
# copy from PatrickStar
def _get_cpu_memory_info():
ps_mem_info = namedtuple("ps_mem_info", ["total", "free", "cached", "buffers", "used"])
try:
# psutil reads the memory info from /proc/memory_info,
# which results in returning the host memory instead of
# that of container.
# Here we try to read the container memory with method in:
# https://stackoverflow.com/a/46213331/5163915
mems = {}
with open("/sys/fs/cgroup/memory/memory.meminfo", "rb") as f:
for line in f:
fields = line.split()
mems[fields[0]] = int(fields[1]) * 1024
total = mems[b"MemTotal:"]
free = mems[b"MemFree:"]
cached = mems[b"Cached:"]
buffers = mems[b"Buffers:"]
used = total - free - cached - buffers
if used < 0:
used = total - free
mem_info = ps_mem_info(total=total, free=free, cached=cached, buffers=buffers, used=used)
except FileNotFoundError:
mems = psutil.virtual_memory()
mem_info = ps_mem_info(
total=mems.total,
free=mems.free,
cached=mems.cached,
buffers=mems.buffers,
used=mems.used,
)
return mem_info
def colo_device_memory_used(device) -> int:
if not isinstance(device, torch.device):
device = torch.device(f"cuda:{device}")
if device.type == 'cpu':
mem_info = _get_cpu_memory_info()
# FIXME(jiaruifang) only work for 1-CPU multi-GPU
# CPU memory is sharded with all processes
# Not support multi-GPU multi-CPU
# We need a local_world_size here
ret = mem_info.used / gpc.get_world_size(ParallelMode.DATA)
return ret
elif device.type == 'cuda':
ret: int = torch.cuda.memory_allocated(device)
# get the peak memory to report correct data, so reset the counter for the next call
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
torch.cuda.reset_peak_memory_stats(device)
return ret
def colo_set_process_memory_fraction(ratio: float) -> None:
"""colo_set_process_memory_fraction
set how much cuda memory used on the gpu belonging to the current process.
Args:
ratio (float): a ratio between 0. ~ 1.
"""
global _GLOBAL_CUDA_MEM_FRACTION
_GLOBAL_CUDA_MEM_FRACTION = ratio
torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device())
def colo_cuda_memory_capacity() -> float:
"""
Get cuda memory capacity of the current cuda.
"""
return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION


@@ -5,7 +5,7 @@ from colossalai.utils.cuda import get_current_device
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from typing import Dict, List
from colossalai.utils.memory_tracer import MemStatsCollector
@@ -64,7 +64,7 @@ class StatefulTensorMgr(object):
cuda_demand += colo_tensor_mem_usage(tensor.payload)[1]
else:
raise RuntimeError
cuda_capacity = colo_cuda_memory_capacity()
cuda_capacity = colo_device_memory_capacity(get_current_device())
if self._warmup:
# We designate a part of CUDA memory for model data in warmup iterations.


@@ -33,7 +33,7 @@ class TensorShardStrategy(BaseShardStrategy):
if t.is_sharded:
return
if t.payload.device.type == 'cuda':
assert t.payload.device.index == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
assert t.payload.device == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
f" but current cuda device is {get_current_device()}"
sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
t.reset_payload(sharded_payload)


@@ -16,7 +16,7 @@ from colossalai.utils import get_current_device, disposable
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_move_to_cpu
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
@@ -231,7 +231,7 @@ class ShardedModelV2(nn.Module):
# the way to calculate margin space is based on the assumption that
# model data is fixed in cuda during training.
# cuda margin space can be used to store OS.
self._cuda_margin_space = colo_cuda_memory_capacity() - max(
self._cuda_margin_space = colo_device_memory_capacity(get_current_device()) - max(
self._memstats_collector.overall_mem_stats('cuda'))
@torch.no_grad()
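To make the margin-space formula above concrete, a minimal sketch with illustrative numbers; the peak-usage samples stand in for what the memory-stats collector would return after a warmup iteration.

from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity

# Illustrative peak CUDA usage (bytes) sampled across warmup steps;
# in ShardedModelV2 these come from self._memstats_collector.overall_mem_stats('cuda').
peak_cuda_usage = [6 * 1024**3, 7 * 1024**3]

# Whatever is left below the device capacity after the worst observed step
# is the margin that can hold optimizer states (OS).
cuda_margin_space = colo_device_memory_capacity(get_current_device()) - max(peak_cuda_usage)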


@@ -41,7 +41,7 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
logger = get_dist_logger("test_moe_zero_init")
if init_device_type == 'cuda':
init_device = torch.device(f"cuda:{get_current_device()}")
init_device = get_current_device()
elif init_device_type == 'cpu':
init_device = torch.device("cpu")
else:


@@ -62,10 +62,9 @@ def _run_test_sharded_optim_v2(cpu_offload,
get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
_, train_dataloader, _, optimizer_class, criterion = get_components_func()
with ZeroInitContext(
target_device=torch.device('cpu') if cpu_offload else torch.device(f'cuda:{get_current_device()}'),
shard_strategy=shard_strategy,
shard_param=True):
with ZeroInitContext(target_device=torch.device('cpu') if cpu_offload else get_current_device(),
shard_strategy=shard_strategy,
shard_param=True):
zero_model = MoeModel()
zero_model = ShardedModelV2(zero_model,


@@ -1,39 +0,0 @@
from colossalai.utils.memory_utils.bucket_tensor_copy import BucketizedTensorCopy
from colossalai.zero.sharded_param import ShardedParamV2
from colossalai.utils import free_port
import torch
import colossalai
def test_bucket_copy():
# init dist env
colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
copyer = BucketizedTensorCopy(20)
shape_list = [(2, 3), (5), (8), (12)]
src_param_list = []
tgt_param_list = []
for shape in shape_list:
# on CPU
src_param = torch.nn.Parameter(torch.randn(shape, dtype=torch.float, device=torch.device('cpu')))
# on GPU
tgt_param = ShardedParamV2(torch.nn.Parameter(torch.ones(shape, dtype=torch.half, device=torch.device('cuda'))))
src_param_list.append(src_param)
tgt_param_list.append(tgt_param)
copyer.copy(src_param, tgt_param)
copyer.flush()
for src_param, tgt_param in zip(src_param_list, tgt_param_list):
diff = src_param.cpu().float() - tgt_param.sharded_data_tensor.payload.cpu().float()
assert torch.allclose(src_param.cpu().float(),
tgt_param.sharded_data_tensor.payload.cpu().float(),
rtol=1e-03,
atol=1e-03), f"diff {diff}"
if __name__ == '__main__':
test_bucket_copy()


@@ -2,7 +2,7 @@ import pytest
from colossalai.utils.cuda import get_current_device
from colossalai.zero.shard_utils.tensor_utils import colo_tensor_mem_usage, colo_model_data_tensor_move, colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu, colo_model_tensor_clone
from colossalai.utils.memory_utils.utils import colo_set_process_memory_fraction, colo_cuda_memory_capacity
from colossalai.utils.memory import colo_set_process_memory_fraction, colo_device_memory_capacity
from colossalai.utils import free_port
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
import colossalai
@@ -12,54 +12,63 @@ import torch
from functools import partial
import torch.multiprocessing as mp
def _run_colo_tensor_mem_usage():
for i in range(1):
if i == 1:
t1 = StatefulTensor(torch.randn(2,2))
t2 = StatefulTensor(torch.randn(4,4))
c1 , g1 = colo_tensor_mem_usage(t1)
c2 , g2 = colo_tensor_mem_usage(t2)
assert c1*4 == c2
assert g1*4 == g2
t1 = StatefulTensor(torch.randn(2, 2))
t2 = StatefulTensor(torch.randn(4, 4))
c1, g1 = colo_tensor_mem_usage(t1)
c2, g2 = colo_tensor_mem_usage(t2)
assert c1 * 4 == c2
assert g1 * 4 == g2
else:
t1 = torch.randn(2,2)
t2 = torch.randn(4,4)
c1 , g1 = colo_tensor_mem_usage(t1)
c2 , g2 = colo_tensor_mem_usage(t2)
assert c1*4 == c2
assert g1*4 == g2
t1 = torch.randn(2, 2)
t2 = torch.randn(4, 4)
c1, g1 = colo_tensor_mem_usage(t1)
c2, g2 = colo_tensor_mem_usage(t2)
assert c1 * 4 == c2
assert g1 * 4 == g2
def _run_colo_set_process_memory_fraction_and_colo_cuda_memory_capacity():
frac1 = colo_cuda_memory_capacity()
def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity():
frac1 = colo_device_memory_capacity(get_current_device())
colo_set_process_memory_fraction(0.5)
frac2 = colo_cuda_memory_capacity()
assert frac2*2 == frac1
frac2 = colo_device_memory_capacity(get_current_device())
assert frac2 * 2 == frac1
def _run_colo_model_data_tensor_move_inline():
for t in [StatefulTensor(torch.randn(2,3)), torch.randn(2,3)]:
colo_model_data_tensor_move_inline(t, torch.device(f"cuda:{get_current_device()}"))
assert t.device == torch.device(f"cuda:{get_current_device()}")
for t in [StatefulTensor(torch.randn(2, 3)), torch.randn(2, 3)]:
colo_model_data_tensor_move_inline(t, get_current_device())
assert t.device == get_current_device()
def _run_colo_model_data_tensor_move():
for t in [(StatefulTensor(torch.ones(2, 3)), StatefulTensor(torch.zeros(2, 3).cuda(get_current_device()))),
(torch.ones(2, 3), torch.zeros(2, 3).cuda(get_current_device()))]:
for t in [(StatefulTensor(torch.ones(2, 3)), StatefulTensor(torch.zeros(2, 3).to(get_current_device()))),
(torch.ones(2, 3), torch.zeros(2, 3).to(get_current_device()))]:
cpu_t, cuda_t = t
colo_model_data_tensor_move(cpu_t, cuda_t)
assert cuda_t.device == torch.device(f"cuda:{get_current_device()}")
assert cuda_t.device == get_current_device()
def _run_colo_model_data_move_to_cpu():
for t in [StatefulTensor(torch.randn(2,2)), torch.randn(4,4)]:
for t in [StatefulTensor(torch.randn(2, 2)), torch.randn(4, 4)]:
colo_model_data_move_to_cpu(t)
assert t.device == torch.device("cpu")
def _run_colo_model_tensor_clone():
for t in [StatefulTensor(torch.randn(2,2).cuda(torch.cuda.current_device())), torch.randn(4,4).cuda(torch.cuda.current_device())]:
for t in [
StatefulTensor(torch.randn(2, 2).cuda(torch.cuda.current_device())),
torch.randn(4, 4).cuda(torch.cuda.current_device())
]:
if issubclass(type(t), StatefulTensor):
assert t.payload.device == torch.device(f"cuda:{get_current_device()}")
assert t.payload.device == get_current_device()
else:
assert t.device == torch.device(f"cuda:{get_current_device()}")
p = colo_model_tensor_clone(t, torch.device(f"cuda:{get_current_device()}"))
assert p.device == torch.device(f"cuda:{get_current_device()}")
assert t.device == get_current_device()
p = colo_model_tensor_clone(t, get_current_device())
assert p.device == get_current_device()
for i in range(2):
for j in range(2):
if issubclass(type(t), StatefulTensor):
@@ -70,21 +79,22 @@ def _run_colo_model_tensor_clone():
assert t[i][j] == p[i][j]
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
_run_colo_set_process_memory_fraction_and_colo_cuda_memory_capacity()
_run_colo_set_process_memory_fraction_and_colo_device_memory_capacity()
_run_colo_model_data_tensor_move_inline()
_run_colo_model_data_tensor_move()
_run_colo_tensor_mem_usage()
_run_colo_model_data_move_to_cpu()
_run_colo_model_tensor_clone()
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4, 5])
def test_tensor_move(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_tensor_move(4)


@@ -13,7 +13,7 @@ from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory_tracer.model_data_memtracer import \
colo_model_mem_usage
from colossalai.utils.memory_utils.utils import colo_device_memory_used
from colossalai.utils.memory import colo_device_memory_used
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from tests.components_to_test.registry import non_distributed_component_funcs
@@ -29,7 +29,7 @@ def run_model_test(init_device_type, shard_strategy_class):
for get_components_func in non_distributed_component_funcs:
model_builder, _, _, _, _ = get_components_func()
if init_device_type == 'cuda':
init_device = torch.device(f"cuda:{get_current_device()}")
init_device = get_current_device()
elif init_device_type == 'cpu':
init_device = torch.device("cpu")
else:


@@ -57,10 +57,9 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
with ZeroInitContext(
target_device=torch.device(f'cpu:0') if cpu_offload else torch.device(f'cuda:{get_current_device()}'),
shard_strategy=shard_strategy,
shard_param=True):
with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(),
shard_strategy=shard_strategy,
shard_param=True):
zero_model = model_builder(checkpoint=True)
zero_model = ShardedModelV2(
zero_model,


@@ -2,9 +2,10 @@ import torch
import colossalai
import pytest
import torch.multiprocessing as mp
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory_tracer import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity, colo_set_process_memory_fraction
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
from colossalai.zero.shard_utils import StatefulTensorMgr
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import TensorState
@@ -26,7 +27,7 @@ class Net(torch.nn.Module):
def run_stm():
cuda_capacity = colo_cuda_memory_capacity()
cuda_capacity = colo_device_memory_capacity(get_current_device())
fraction = (1.4 * 1024**3) / cuda_capacity
# limit max memory to 1.4GB
# which means only 2 parameters can be on CUDA