polish utils docstring (#620)

pull/621/head^2
ver217 2022-04-01 16:36:47 +08:00 committed by GitHub
parent e619a651fb
commit 369a288bf3
4 changed files with 42 additions and 48 deletions

@@ -11,19 +11,13 @@ from colossalai.utils import get_current_device
 class AsyncMemoryMonitor:
     """
     An Async Memory Monitor runing during computing. Sampling memory usage of the current GPU
-    at interval of 1/(10**power) sec.
+    at interval of `1/(10**power)` sec.
 
     The idea comes from Runtime Memory Tracer of PatrickStar
-    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
-    https://arxiv.org/abs/2108.05818
-
-    :param power: the power of time interval, defaults to 10
-    :type power: int
+    `PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_
 
-    Usage:
-    ::
-
-        ```python
+    Usage::
+
         async_mem_monitor = AsyncMemoryMonitor()
         input = torch.randn(2, 20).cuda()
         OP1 = torch.nn.Linear(20, 30).cuda()
@@ -36,7 +30,12 @@ class AsyncMemoryMonitor:
         output = OP2(output)
         async_mem_monitor.finish()
         async_mem_monitor.save('log.pkl')
-        ```
+
+    Args:
+        power (int, optional): the power of time interva. Defaults to 10.
+
+    .. _PatrickStar\: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
+        https://arxiv.org/abs/2108.05818
     """
 
     def __init__(self, power: int = 10):
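Read as a before/after, the polished docstring keeps the original usage example: `start()` begins sampling in a background thread, `finish()` stops it and records the peak, `save()` pickles the collected stats. Below is a minimal, self-contained sketch of that interval-sampling idea, assuming only `torch` and the standard library; `IntervalMemoryMonitor` is an invented name for illustration, not ColossalAI's actual implementation.

```python
# Illustrative sketch only -- NOT the ColossalAI implementation.
# Samples peak GPU memory from a worker thread every 1/(10**power) seconds.
from concurrent.futures import ThreadPoolExecutor
import pickle
import time

import torch


class IntervalMemoryMonitor:

    def __init__(self, power: int = 10):
        self.interval = 1 / (10**power)
        self.executor = ThreadPoolExecutor(max_workers=1)
        self.keep_measuring = False
        self.time_stamps = []
        self.mem_stats = []

    def start(self):
        # Reset CUDA's peak counter, then begin polling in the worker thread.
        self.keep_measuring = True
        torch.cuda.reset_peak_memory_stats()
        self.future = self.executor.submit(self._measure)

    def finish(self):
        # Stop the polling loop and record the peak observed in this window.
        self.keep_measuring = False
        peak = self.future.result()
        self.time_stamps.append(time.time())
        self.mem_stats.append(peak)
        return peak

    def _measure(self):
        peak = 0
        while self.keep_measuring:
            peak = max(peak, torch.cuda.max_memory_allocated())
            time.sleep(self.interval)
        return peak

    def save(self, filename):
        # Pickle the sampled stats, mirroring save('log.pkl') in the docstring.
        with open(filename, "wb") as f:
            pickle.dump({"time_stamps": self.time_stamps,
                         "mem_stats": self.mem_stats}, f)
```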

@@ -12,6 +12,8 @@ class MemProfiler(BaseProfiler):
     To use this profiler, you need to pass an `engine` instance. And the usage is same like
     CommProfiler.
 
+    Usage::
+
         mm_prof = MemProfiler(engine)
         with ProfilerContext([mm_prof]) as prof:
             writer = SummaryWriter("mem")
@@ -36,11 +38,7 @@ class MemProfiler(BaseProfiler):
     def to_tensorboard(self, writer: SummaryWriter) -> None:
         stats = self._mem_tracer.async_mem_monitor.state_dict['mem_stats']
         for info, i in enumerate(stats):
-            writer.add_scalar(
-                "memory_usage/GPU",
-                info,
-                i
-            )
+            writer.add_scalar("memory_usage/GPU", info, i)
 
     def to_file(self, data_file: Path) -> None:
         self._mem_tracer.save_results(data_file)
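The collapsed one-line `add_scalar` call above is easier to sanity-check in isolation. A standalone sketch follows, with made-up sample values: `SummaryWriter.add_scalar` takes `(tag, scalar_value, global_step)`, and `enumerate` yields `(index, value)` pairs, so the sample belongs in the value slot and its position in the step slot.

```python
# Standalone sketch of logging sampled memory stats to TensorBoard.
# The mem_stats values are invented for the example.
from torch.utils.tensorboard import SummaryWriter

mem_stats = [1.2e9, 1.5e9, 1.4e9]  # bytes sampled at successive steps

writer = SummaryWriter("mem")
for step, usage in enumerate(mem_stats):  # enumerate -> (index, value)
    writer.add_scalar("memory_usage/GPU", usage, step)
writer.close()
```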

@@ -70,12 +70,10 @@ class BaseProfiler(ABC):
 
 class ProfilerContext(object):
-    """
-    Profiler context manager
+    """Profiler context manager
 
-    Usage:
-    ::
-
-        ```python
+    Usage::
+
         world_size = 4
         inputs = torch.randn(10, 10, dtype=torch.float32, device=get_current_device())
         outputs = torch.empty(world_size, 10, 10, dtype=torch.float32, device=get_current_device())
@@ -92,7 +90,6 @@ class ProfilerContext(object):
 
         dist.reduce(inputs, 0)
         prof.show()
-        ```
     """
 
     def __init__(self, profilers: List[BaseProfiler] = None, enable: bool = True):
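To make the `with ProfilerContext([...]) as prof:` pattern in the docstring concrete, here is a minimal sketch of a profiler context manager that enables every registered profiler on entry and disables them on exit. `SimpleProfiler` and `SimpleProfilerContext` are hypothetical stand-ins; the real ProfilerContext in ColossalAI carries more features (e.g. `show()` and result export), and the real `BaseProfiler` interface may differ.

```python
# Hypothetical sketch of the context-manager pattern described above.
from typing import List


class SimpleProfiler:
    """Hypothetical profiler with the smallest possible interface."""

    def __init__(self, name: str):
        self.name = name
        self.active = False

    def enable(self) -> None:
        self.active = True

    def disable(self) -> None:
        self.active = False


class SimpleProfilerContext:
    """Enable every registered profiler on entry, disable all on exit."""

    def __init__(self, profilers: List[SimpleProfiler] = None, enable: bool = True):
        self.profilers = profilers or []
        self.enable_flag = enable

    def __enter__(self):
        if self.enable_flag:
            for prof in self.profilers:
                prof.enable()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.enable_flag:
            for prof in self.profilers:
                prof.disable()
        # Returning None lets any exception from the with-block propagate.


with SimpleProfilerContext([SimpleProfiler("comm")]) as prof:
    pass  # profiled work goes here
```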