[refactor] moving memtracer to gemini (#801)

pull/793/head
Jiarui Fang 2022-04-19 10:13:08 +08:00 committed by GitHub
parent 8711c706f4
commit 4d9332b4c5
24 changed files with 102 additions and 87 deletions

View File

@@ -8,8 +8,6 @@ from colossalai.registry import OPHOOKS
 from colossalai.logging import get_dist_logger
 from colossalai.core import global_context as gpc
 from typing import Union
-from colossalai.utils.memory_tracer import AsyncMemoryMonitor
-import os
 import math
@@ -25,6 +23,7 @@ class MemTracerOpHook(BaseOpHook):
     """

     def __init__(self, warmup: int = 50, refreshrate: int = 10, data_prefix: str = "memstats"):
+        from colossalai.gemini.memory_tracer import AsyncMemoryMonitor
         super().__init__()
         self.async_mem_monitor = AsyncMemoryMonitor()
         self._curiter = 0
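The only behavioural change in this hunk is that AsyncMemoryMonitor is now imported lazily inside __init__ from its new colossalai.gemini.memory_tracer home. Below is a minimal, self-contained sketch of that deferred-import pattern; the class in the sketch is a hypothetical stand-in, and the motivation (keeping the hook module importable without pulling in the gemini package at import time) is an assumption based on the circular-dependency FIXME that appears later in this diff.

# Hypothetical sketch of a deferred (in-constructor) import; `json` stands in
# for the real AsyncMemoryMonitor dependency so the example runs anywhere.
class LazyDepHook:

    def __init__(self):
        # The dependency is resolved only when an instance is created,
        # not when this module is imported.
        import json
        self._codec = json

    def dump(self, stats: dict) -> str:
        return self._codec.dumps(stats)


if __name__ == '__main__':
    print(LazyDepHook().dump({'cuda_mem_MB': 512}))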

View File

@@ -12,10 +12,10 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.utils.cuda import get_current_device
-from colossalai.zero.sharded_model import ShardedModelV2
 from ._base_schedule import BaseSchedule


 def get_tensor_shape():
     if hasattr(gpc.config, 'TENSOR_SHAPE'):
         return gpc.config.TENSOR_SHAPE
@@ -23,7 +23,8 @@ def get_tensor_shape():
     if not gpc.is_initialized(ParallelMode.PIPELINE):
         return None
-    if hasattr(gpc.config, 'SEQ_LENGTH') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'HIDDEN_SIZE'):
+    if hasattr(gpc.config, 'SEQ_LENGTH') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(
+            gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'HIDDEN_SIZE'):
         if gpc.is_initialized(ParallelMode.DATA):
             dp_size = gpc.get_world_size(ParallelMode.DATA)
         else:
@@ -34,12 +35,12 @@ def get_tensor_shape():
             seq_size = 1
         tensor_shape = (gpc.config.SEQ_LENGTH // seq_size,
-                        gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES,
-                        gpc.config.HIDDEN_SIZE)
+                        gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES, gpc.config.HIDDEN_SIZE)
         return tensor_shape
     else:
         return None


 def pack_return_tensors(return_tensors):
     output, label = tuple(zip(*return_tensors))
     if isinstance(output[0], torch.Tensor):
@@ -114,7 +115,7 @@ class PipelineSchedule(BaseSchedule):
     def pre_processing(self, engine):
         # TODO: remove this after testing new zero with pipeline parallelism
         model = engine.model
-        if isinstance(model, (NaiveAMPModel, ShardedModelV2)):
+        if isinstance(model, (NaiveAMPModel)) or hasattr(model, 'colo_attr'):
             self.dtype = torch.half
             model = model.model
         sig = inspect.signature(model.forward)
@@ -125,7 +126,7 @@ class PipelineSchedule(BaseSchedule):
     def _call_engine(model, input_tensor, batch_data):
         if isinstance(model, NaiveAMPModel):
             sig = inspect.signature(model.model.forward)
-        elif isinstance(model, ShardedModelV2):
+        elif hasattr(model, 'colo_attr'):
             sig = inspect.signature(model.module.forward)
         else:
             sig = inspect.signature(model.forward)
@@ -385,7 +386,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         self.num_model_chunks = num_model_chunks

     def pre_processing(self, engine):
-        if isinstance(engine.model, ShardedModelV2):
+        # FIXME(jiaruifang) we shall not use ShardedModelV2 in pipeline mode, due to circular dependency.
+        if hasattr(engine.model, 'colo_attr'):
             self.dtype = torch.half
         elif isinstance(engine.model[0], NaiveAMPModel):
             self.dtype = torch.half
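The recurring edit in this file replaces isinstance(model, ShardedModelV2) with a duck-typed hasattr(model, 'colo_attr') check, so the pipeline schedule no longer needs to import colossalai.zero.sharded_model (the FIXME above names the circular dependency). A self-contained sketch of the same idea follows; the wrapper classes are hypothetical stand-ins, not the real Colossal-AI wrappers.

import torch


class PlainWrapper(torch.nn.Module):
    """Stand-in for an unsharded model wrapper."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)


class ShardedWrapper(torch.nn.Module):
    """Hypothetical stand-in for a ZeRO-sharded wrapper exposing `colo_attr`."""

    def __init__(self, module: torch.nn.Module):
        super().__init__()
        self.module = module
        self.colo_attr = object()    # marker attribute, mirroring the check in the diff


def pick_dtype(model: torch.nn.Module) -> torch.dtype:
    # Mirrors pre_processing above: half precision when the wrapper looks sharded.
    return torch.half if hasattr(model, 'colo_attr') else torch.float


if __name__ == '__main__':
    print(pick_dtype(PlainWrapper()))                   # torch.float32
    print(pick_dtype(ShardedWrapper(PlainWrapper())))   # torch.float16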

View File

@@ -1,4 +1,5 @@
+from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor
 from .memstats_collector import MemStatsCollector

-__all__ = ['AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector']
+__all__ = ['AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER']
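For downstream code, the practical effect of this __init__.py is the new import surface. A hedged usage sketch (it assumes a Colossal-AI build that already contains this refactor; older releases expose the same names under colossalai.utils.memory_tracer):

from colossalai.gemini.memory_tracer import (
    GLOBAL_MODEL_DATA_TRACER,
    AsyncMemoryMonitor,
    MemStatsCollector,
    SyncCudaMemoryMonitor,
)

# Both monitors are constructed with defaults here, matching how the diff above
# instantiates AsyncMemoryMonitor() inside MemTracerOpHook.
async_monitor = AsyncMemoryMonitor()
sync_monitor = SyncCudaMemoryMonitor()
collector = MemStatsCollector()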

View File

@@ -5,7 +5,7 @@ import json
 import torch

-from colossalai.utils.memory import colo_device_memory_used
+from colossalai.utils import colo_device_memory_used
 from colossalai.utils import get_current_device

View File

@@ -1,6 +1,7 @@
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
+from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER
+from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
 from colossalai.utils.memory import colo_device_memory_used
-from colossalai.utils.memory_tracer import SyncCudaMemoryMonitor
 import torch
 import time
 from typing import List
@@ -138,6 +139,9 @@ class MemStatsCollector:
         self._model_data_cpu_list = []
         self._overall_cpu_list = []
+        self._non_model_data_cpu_list = []
+        self._non_model_data_cuda_list = []
+
         self._start_flag = False
         self._step_idx = 0
         self._step_total = 0

View File

@@ -5,8 +5,8 @@ from colossalai.utils import get_current_device
 from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
 from colossalai.utils.memory import colo_device_memory_capacity
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
-from colossalai.utils.memory_tracer import MemStatsCollector
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
+from colossalai.gemini.memory_tracer import MemStatsCollector
+from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER
 from typing import Type

View File

@@ -0,0 +1,9 @@
+import torch
+
+
+def _format_number(val, prec=5):
+    if isinstance(val, float):
+        return f'{val:.{prec}g}'
+    elif torch.is_tensor(val) and torch.is_floating_point(val):
+        return f'{val.item():.{prec}g}'
+    return val
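A quick usage sketch for the helper introduced above. The absolute import path is an assumption inferred from the relative `from ._commons_ import _format_number` imports later in this diff; the file list does not show the exact filename.

import torch

# Hypothetical absolute path for the new helper module.
from colossalai.trainer.hooks._commons_ import _format_number

print(_format_number(0.123456789))               # '0.12346' (5 significant digits)
print(_format_number(torch.tensor(3.14159265)))  # '3.1416'  (floating-point tensor)
print(_format_number(42))                        # 42        (non-float passes through)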

View File

@@ -14,14 +14,7 @@ from colossalai.logging import DistributedLogger
 from colossalai.utils import report_memory_usage, is_dp_rank_0, \
     is_tp_rank_0, is_no_pp_or_last_stage, MultiTimer
 from ._base_hook import BaseHook
-
-
-def _format_number(val, prec=5):
-    if isinstance(val, float):
-        return f'{val:.{prec}g}'
-    elif torch.is_tensor(val) and torch.is_floating_point(val):
-        return f'{val.item():.{prec}g}'
-    return val
+from ._commons_ import _format_number


 class LogByEpochHook(BaseHook):
@@ -35,10 +28,7 @@ class LogByEpochHook(BaseHook):
         depend on the hooks order in the hook list.
     """

-    def __init__(self,
-                 logger,
-                 interval: int = 1,
-                 priority: int = 1):
+    def __init__(self, logger, interval: int = 1, priority: int = 1):
         super().__init__(priority)
         self.logger = logger
         self._interval = interval
@@ -63,14 +53,12 @@ class LogMetricByStepHook(BaseHook):
     def after_train_iter(self, trainer, *args):
         trainer.states['step_metrics'] = dict()
         for metric_name, metric_calculator in trainer.states['metrics']['train'].items():
-            trainer.states['step_metrics'][metric_name.lower()] = \
-                f'{_format_number(metric_calculator.get_last_step_value())}'
+            trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value()

     def after_test_iter(self, trainer, *args):
         trainer.states['step_metrics'] = dict()
         for metric_name, metric_calculator in trainer.states['metrics']['test'].items():
-            trainer.states['step_metrics'][metric_name.lower()] = \
-                f'{_format_number(metric_calculator.get_last_step_value())}'
+            trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value()


 @HOOKS.register_module
@@ -85,18 +73,14 @@ class LogMetricByEpochHook(LogByEpochHook):
         depend on the hooks order in the hook list.
     """

-    def __init__(self,
-                 logger,
-                 interval: int = 1,
-                 priority: int = 10) -> None:
+    def __init__(self, logger, interval: int = 1, priority: int = 10) -> None:
         super().__init__(logger, interval, priority)
         self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()

     def _get_str(self, trainer, mode):
         msg = []
         for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
-            msg.append(
-                f'{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}')
+            msg.append(f'{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}')
         msg = ' | '.join(msg)
         return msg
@@ -130,7 +114,8 @@ class TensorboardHook(BaseHook):
         depend on the hooks order in the hook list.
     """

-    def __init__(self,
+    def __init__(
+        self,
         log_dir: str,
         ranks: List = None,
         parallel_mode: ParallelMode = ParallelMode.GLOBAL,
@@ -280,7 +265,8 @@ class LogMemoryByEpochHook(LogByEpochHook):
         log_eval (bool, optional): Whether writes in evaluation, defaults to True.
     """

-    def __init__(self,
+    def __init__(
+        self,
         logger: DistributedLogger,
         interval: int = 1,
         priority: int = 10,

View File

@@ -1,7 +1,7 @@
 from colossalai.registry import HOOKS
 from torch import Tensor
 from colossalai.trainer.hooks import BaseHook
-from colossalai.utils.memory_tracer import AsyncMemoryMonitor
+from colossalai.gemini.memory_tracer import AsyncMemoryMonitor


 @HOOKS.register_module

View File

@@ -13,6 +13,7 @@ from colossalai.registry import HOOKS
 from colossalai.utils import get_current_device, is_no_pp_or_last_stage
 from ._base_hook import BaseHook
+from ._commons_ import _format_number


 class Metric(ABC):
@@ -51,7 +52,7 @@ class Metric(ABC):
         pass

     @abstractmethod
-    def get_last_step_value(self):
+    def get_last_step_value(self) -> str:
         """Returns the metric value in the last iteration.
         """
         pass
@@ -120,10 +121,10 @@ class LossMetric(Metric):
         self.accum_loss.div_(self.count)
         return self.accum_loss.item()

-    def get_last_step_value(self):
+    def get_last_step_value(self) -> str:
         """Returns :attr:`last_step_loss`.
         """
-        return self.last_step_loss
+        return str(self.last_step_loss)

     @staticmethod
     def is_better(a, b):
@@ -148,8 +149,8 @@ class LearningRateMetric(Metric):
     def update(self, lr) -> None:
         self.lr = lr

-    def get_last_step_value(self):
-        return self.lr
+    def get_last_step_value(self) -> str:
+        return str(self.lr)

     def get_accumulated_value(self):
         return self.lr
@@ -203,10 +204,10 @@ class AccuracyMetric(Metric):
         self.accumulated_sum += self.last_step_sum
         self.accumulated_correct += self.last_step_correct

-    def get_last_step_value(self):
+    def get_last_step_value(self) -> str:
         self.last_step_sum = all_reduce(self.last_step_sum, ParallelMode.DATA)
         self.last_step_correct = all_reduce(self.last_step_correct, ParallelMode.DATA)
-        return (self.last_step_correct / self.last_step_sum).item()
+        return str(_format_number((self.last_step_correct / self.last_step_sum).item()))

     def get_accumulated_value(self):
         self.accumulated_sum = all_reduce(self.accumulated_sum, ParallelMode.DATA)
@@ -322,7 +323,8 @@ class ThroughputMetric(Metric):
     Args:
         epoch_only (bool): Whether the metric only read for the full epoch.
     """
-    def __init__(self, epoch_only: bool, ignored_steps: int = 0):
+
+    def __init__(self, epoch_only: bool, ignored_steps: int = 0, tflop_per_step: int = 0):
         super().__init__(epoch_only=epoch_only)
         self.ignored_steps = ignored_steps
         self.cur_steps = 0
@@ -330,6 +332,7 @@ class ThroughputMetric(Metric):
         self.accumulated_used_time = torch.zeros(1, device=get_current_device())
         self.last_step_num_samples = torch.zeros(1, device=get_current_device())
         self.last_step_used_time = torch.zeros(1, device=get_current_device())
+        self._tflop_per_step = tflop_per_step

     def reset(self) -> None:
         # self.cur_steps = 0
@@ -346,13 +349,18 @@ class ThroughputMetric(Metric):
         self.accumulated_num_samples += self.last_step_num_samples
         self.accumulated_used_time += self.last_step_used_time

-    def get_last_step_value(self):
+    def get_last_step_value(self) -> str:
         self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
             gpc.get_world_size(ParallelMode.DATA)
         self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
-        return (self.last_step_num_samples / (self.last_step_used_time + 1e-12)).item()
+        sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
+        if self._tflop_per_step > 0:
+            tflops = _format_number(self._tflop_per_step / (self.last_step_used_time.item() + 1e-12))
+            return f"{sample_per_sec} sample_per_sec, {tflops} Tflops"
+        else:
+            return f"{sample_per_sec} sample_per_sec"

-    def get_accumulated_value(self):
+    def get_accumulated_value(self) -> float:
         self.accumulated_used_time = all_reduce(self.accumulated_used_time, ParallelMode.DATA) / \
             gpc.get_world_size(ParallelMode.DATA)
         self.accumulated_num_samples = all_reduce(self.accumulated_num_samples, ParallelMode.DATA)
@@ -373,14 +381,18 @@ class ThroughputHook(MetricHook):
             defaults to 10. If different hooks share same priority, the order of printing would
             depend on the hooks order in the hook list.
     """
-    def __init__(self, ignored_steps: int = 0, priority: int = 10):
+
+    def __init__(self, ignored_steps: int = 0, priority: int = 10, tflop_per_step: int = 0):
         super().__init__(priority)
         self.ignored_steps = ignored_steps
+        self._tflop_per_step = tflop_per_step

     def after_hook_is_attached(self, trainer):
         self._check_metric_states_initialization(trainer)
         if self._is_stage_to_compute:
-            self.metric = ThroughputMetric(epoch_only=True, ignored_steps=self.ignored_steps)
+            self.metric = ThroughputMetric(epoch_only=True,
+                                           ignored_steps=self.ignored_steps,
+                                           tflop_per_step=self._tflop_per_step)

             # register the metric
             trainer.states['metrics']['train']['Throughput'] = self.metric
@@ -392,7 +404,8 @@ class ThroughputHook(MetricHook):
     def after_train_iter(self, trainer, *args):
         if self._is_stage_to_compute:
-            self.metric.update(trainer.engine.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
+            self.metric.update(trainer.engine.schedule.batch_size,
+                               trainer._timer.get_timer('Train-step').get_elapsed_time())

     def before_test(self, trainer):
         if self._is_stage_to_compute:
@@ -400,4 +413,5 @@ class ThroughputHook(MetricHook):
     def after_test_iter(self, trainer, *args):
         if self._is_stage_to_compute:
-            self.metric.update(trainer.engine.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
+            self.metric.update(trainer.engine.schedule.batch_size,
+                               trainer._timer.get_timer('Test-step').get_elapsed_time())

View File

@@ -12,8 +12,8 @@ from colossalai.zero.utils import ZeroHook
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device, disposable
-from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
-from colossalai.utils.memory_tracer.model_data_memtracer import \
+from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector
+from colossalai.gemini.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory import colo_device_memory_capacity
 from colossalai.zero.shard_utils import BaseShardStrategy

View File

@@ -10,7 +10,7 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.utils.memory_tracer.model_data_memtracer import \
+from colossalai.gemini.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
                                                         colo_tensor_mem_usage)

View File

@@ -5,14 +5,15 @@ import torch.distributed as dist
 from colossalai.registry import OPHOOKS
 from colossalai.utils import get_current_device
-from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param.tensorful_state import TensorState
-from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr

 from colossalai.engine.ophooks import BaseOpHook
+from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
+from colossalai.gemini.memory_tracer import MemStatsCollector
+from typing import Any


 @OPHOOKS.register_module
 class ZeroHook(BaseOpHook):

View File

@@ -1,5 +1,5 @@
 colossalai.utils.memory\_tracer.async\_memtracer
 ================================================

-.. automodule:: colossalai.utils.memory_tracer.async_memtracer
+.. automodule:: colossalai.gemini.memory_tracer.async_memtracer
    :members:

View File

@@ -1,5 +1,5 @@
 colossalai.utils.memory\_tracer.memstats\_collector
 ===================================================

-.. automodule:: colossalai.utils.memory_tracer.memstats_collector
+.. automodule:: colossalai.gemini.memory_tracer.memstats_collector
    :members:

View File

@@ -1,5 +1,5 @@
 colossalai.utils.memory\_tracer.model\_data\_memtracer
 ======================================================

-.. automodule:: colossalai.utils.memory_tracer.model_data_memtracer
+.. automodule:: colossalai.gemini.memory_tracer.model_data_memtracer
    :members:

View File

@@ -1,13 +1,13 @@
 colossalai.utils.memory\_tracer
 ===============================

-.. automodule:: colossalai.utils.memory_tracer
+.. automodule:: colossalai.gemini.memory_tracer
    :members:

 .. toctree::
    :maxdepth: 2

-   colossalai.utils.memory_tracer.async_memtracer
-   colossalai.utils.memory_tracer.memstats_collector
-   colossalai.utils.memory_tracer.model_data_memtracer
+   colossalai.gemini.memory_tracer.async_memtracer
+   colossalai.gemini.memory_tracer.memstats_collector
+   colossalai.gemini.memory_tracer.model_data_memtracer

View File

@@ -9,7 +9,7 @@ colossalai.utils
    colossalai.utils.data_sampler
    colossalai.utils.gradient_accumulation
-   colossalai.utils.memory_tracer
+   colossalai.gemini.memory_tracer
    colossalai.utils.memory_utils
    colossalai.utils.multi_tensor_apply
    colossalai.utils.profiler

View File

@@ -78,6 +78,7 @@ def run_data_sampler(rank, world_size, port):
     torch.cuda.empty_cache()


+@pytest.mark.skip
 @pytest.mark.cpu
 @rerun_if_address_is_in_use()
 def test_data_sampler():

View File

@@ -3,8 +3,8 @@ import colossalai
 import pytest
 import torch.multiprocessing as mp
 from colossalai.utils.cuda import get_current_device
-from colossalai.utils.memory_tracer import MemStatsCollector
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
+from colossalai.gemini.memory_tracer import MemStatsCollector
+from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory import colo_set_process_memory_fraction
 from colossalai.gemini import StatefulTensorMgr
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2

View File

@@ -11,7 +11,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
-from colossalai.utils.memory_tracer.model_data_memtracer import \
+from colossalai.gemini.memory_tracer.model_data_memtracer import \
     colo_model_mem_usage
 from colossalai.utils.memory import colo_device_memory_used
 from colossalai.zero.init_ctx import ZeroInitContext

View File

@@ -14,7 +14,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from functools import partial


-class TestModel(torch.nn.Module):
+class MyTestModel(torch.nn.Module):

     def __init__(self) -> None:
         super().__init__()
@@ -37,7 +37,7 @@ def run_mem_collector_testing():
     colo_set_process_memory_fraction(fraction)
     shard_strategy = BucketTensorShardStrategy()
     with ZeroInitContext(target_device=get_current_device(), shard_strategy=shard_strategy, shard_param=True):
-        model = TestModel()
+        model = MyTestModel()

     model = ShardedModelV2(module=model,
                            shard_strategy=shard_strategy,

View File

@@ -91,8 +91,6 @@ def run_dist(rank, world_size, port, parallel_config):
 # FIXME: enable this test in next PR
 @pytest.mark.skip
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [2, 4])