mirror of https://github.com/hpcaitech/ColossalAI
[profile] added example for ProfilerContext (#349)
parent 532ae79cb0
commit c57e089824
@@ -35,15 +35,25 @@ class ProfilerContext(object):
     """
     Profiler context manager
     Usage:
-    from colossalai.utils.profiler import CommProf, ProfilerContext
-    from torch.utils.tensorboard import SummaryWriter
-    cc_prof = CommProf()
+
+    ```python
+    world_size = 4
+    inputs = torch.randn(10, 10, dtype=torch.float32, device=get_current_device())
+    outputs = torch.empty(world_size, 10, 10, dtype=torch.float32, device=get_current_device())
+    outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0))
+
+    cc_prof = CommProfiler()
+
     with ProfilerContext([cc_prof]) as prof:
-        train()
-    writer = SummaryWriter('tb/path')
-    prof.to_tensorboard(writer)
-    prof.to_file('./prof_logs/')
+        op = dist.all_reduce(inputs, async_op=True)
+        dist.all_gather(outputs_list, inputs)
+        op.wait()
+        dist.reduce_scatter(inputs, outputs_list)
+        dist.broadcast(inputs, 0)
+        dist.reduce(inputs, 0)
+
     prof.show()
+    ```
     """
 
     def __init__(self, profilers: List[BaseProfiler] = None, enable: bool = True):
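For reference, below is a minimal sketch of how the new docstring example might be run as a standalone script. It only assumes what the diff itself shows: ProfilerContext, CommProfiler, prof.show(), the to_tensorboard/to_file calls from the old example, and the colossalai.utils.profiler import path from the old docstring (which names CommProf, while the new example instantiates CommProfiler; the sketch follows the new example). The process-group setup, LOCAL_RANK handling, output paths, and the main() wrapper are illustrative assumptions, not part of the commit.

```python
# Sketch only: profiler names are taken from the docstring in this diff;
# the distributed launch/setup details are assumptions, not part of the commit.
import os

import torch
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter

from colossalai.utils import get_current_device  # assumed location of get_current_device
from colossalai.utils.profiler import CommProfiler, ProfilerContext  # path from the old docstring


def main():
    # One process per GPU; LOCAL_RANK is set by launchers such as torchrun.
    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', 0)))
    world_size = dist.get_world_size()

    # Same tensors as the docstring example.
    inputs = torch.randn(10, 10, dtype=torch.float32, device=get_current_device())
    outputs = torch.empty(world_size, 10, 10, dtype=torch.float32, device=get_current_device())
    outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0))

    cc_prof = CommProfiler()

    # Issue a handful of collectives inside the context; these are what the
    # attached CommProfiler is meant to record.
    with ProfilerContext([cc_prof]) as prof:
        op = dist.all_reduce(inputs, async_op=True)
        dist.all_gather(outputs_list, inputs)
        op.wait()
        dist.reduce_scatter(inputs, outputs_list)
        dist.broadcast(inputs, 0)
        dist.reduce(inputs, 0)

    # Report the results: print a summary, then persist it using the
    # to_tensorboard/to_file methods shown in the old docstring example.
    prof.show()
    writer = SummaryWriter('tb/path')
    prof.to_tensorboard(writer)
    prof.to_file('./prof_logs/')

    dist.destroy_process_group()


if __name__ == '__main__':
    main()
```

A typical launch would be `torchrun --nproc_per_node=4 profile_example.py` (the file name is hypothetical), giving one process per GPU so that world_size matches the 4 used in the docstring.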