mirror of https://github.com/hpcaitech/ColossalAI
[pipeline/rank_recorder] fix bug when process data before backward | add a tool for multiple ranks debug (#1681)
* [pipeline/tuning] improve dispatch performance both time and space cost * [pipeline/converge] add interface for testing convergence * [NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style * Update PipelineBase.py * [pipeline/chimera] reconstruct PipelineBase and Worker to support more feasible custom schedule | finish Chimera * [pipeline/chimera] test chimera | fix bug of initializing * [pipeline/pytree] add pytree to process args and kwargs | provide to process args and kwargs after forwardpull/1684/head
parent
517b63939a
commit
3b2a59b0ba
@ -0,0 +1,72 @@
|
||||
# Rank Recorder
|
||||
This is a useful tool to get the records of certain functions in each rank. The records of each rank will dump into a json file after the end of multiple process program. You can parse and visualise the json file easily.
|
||||
|
||||
Before using the tool, you should ensure dist.is_initialized() return true before exit of program.
|
||||
|
||||
## Usage
|
||||
|
||||
Is very simple:
|
||||
|
||||
```python
|
||||
from colossalai.utils.rank_recorder import recorder
|
||||
|
||||
...
|
||||
...
|
||||
|
||||
with recorder(record_name, current_rank) as r:
|
||||
"""procedure to record
|
||||
"""
|
||||
|
||||
```
|
||||
|
||||
## Example
|
||||
This is a demo to display kernel select in cuda and visualise the cost of several procedures in each rank.
|
||||
|
||||
```python
|
||||
import time
|
||||
import os
|
||||
import logging
|
||||
logging.disable(logging.INFO)
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from colossalai.utils.rank_recorder import recorder
|
||||
|
||||
|
||||
WORLD_SIZE = 4
|
||||
|
||||
# config the export image here
|
||||
# If you want to dive into the detail, format 'svg' is recommended
|
||||
recorder.export_format = 'png'
|
||||
recorder.export_name = 'kernel_select'
|
||||
recorder.dpi = 500
|
||||
|
||||
def calc(x, y):
|
||||
a = torch.randn(x, y).cuda()
|
||||
b = torch.randn(x, y).cuda()
|
||||
c = sum(a * b)
|
||||
return c
|
||||
|
||||
def worker(rank):
|
||||
os.environ['MASTER_ADDR'] = 'localhost'
|
||||
os.environ['MASTER_PORT'] = '29020'
|
||||
dist.init_process_group(backend='nccl', world_size=WORLD_SIZE, rank=rank)
|
||||
print(dist.get_rank(), "enter")
|
||||
time.sleep(0.1 * rank)
|
||||
|
||||
with recorder("calc_1(x100)", rank) as r:
|
||||
calc(100, 100)
|
||||
|
||||
with recorder("calc_2(x400)", rank) as r:
|
||||
calc(400, 400)
|
||||
|
||||
with recorder("calc_2(x200)", rank) as r:
|
||||
calc(200, 200)
|
||||
|
||||
if __name__ == "__main__":
|
||||
mp.spawn(worker, nprocs=WORLD_SIZE)
|
||||
```
|
||||
|
||||
run the script directly and you will get `kernel_select.json` and `kernel_select.png` in your current folder.
|
@ -0,0 +1,3 @@
|
||||
from colossalai.utils.rank_recorder.rank_recorder import recorder
|
||||
|
||||
__all__ = ["recorder"]
|
@ -0,0 +1,178 @@
|
||||
import time
|
||||
from typing import List, Dict
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import shutil
|
||||
import atexit
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.colors as mcolors
|
||||
|
||||
cmap = list(mcolors.TABLEAU_COLORS.values())
|
||||
|
||||
LOG_FOLDER = "record.log"
|
||||
MAX_WAIT_TIME = 20
|
||||
|
||||
|
||||
class Event:
|
||||
|
||||
def __init__(self, start: int, end: int, name: str, rank: int) -> None:
|
||||
self.start = start
|
||||
self.end = end
|
||||
self.name = name
|
||||
self.rank = rank
|
||||
|
||||
|
||||
class Recorder:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.rank_to_history: Dict[int, List[Event]] = {}
|
||||
self.base_time = time.time()
|
||||
self.temp_event = None
|
||||
|
||||
self.export_format = 'png'
|
||||
self.export_name = 'test'
|
||||
self.dpi = 500
|
||||
self.theme = 'dark_background'
|
||||
self.figure_width = 30
|
||||
self.figure_height = 10
|
||||
self.legend_fontsize = 16
|
||||
self.device_fontsize = 20
|
||||
self.bar_height = 0.2
|
||||
|
||||
if not os.path.exists(LOG_FOLDER):
|
||||
os.makedirs(LOG_FOLDER)
|
||||
|
||||
def start(self, name: str, rank: int):
|
||||
# TODO : add lock to prevent conflict
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
self.temp_event = Event(start_time, None, name, rank)
|
||||
|
||||
def end(self):
|
||||
assert self.temp_event is not None, "`start` before `end`"
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
self.temp_event.end = end_time
|
||||
rank = self.temp_event.rank
|
||||
if rank not in self.rank_to_history:
|
||||
self.rank_to_history[rank] = []
|
||||
self.rank_to_history[rank].append(self.temp_event)
|
||||
self.temp_event = None
|
||||
|
||||
def get_history(self):
|
||||
return self.history
|
||||
|
||||
def __call__(self, name: str, rank: str):
|
||||
self.temp_name = name
|
||||
self.temp_rank = rank
|
||||
return self
|
||||
|
||||
def __enter__(self):
|
||||
name = self.temp_name
|
||||
rank = self.temp_rank
|
||||
self.start(name, rank)
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.end()
|
||||
|
||||
def dump_record(self):
|
||||
rank = dist.get_rank()
|
||||
rank_to_history = self.rank_to_history
|
||||
records = {'base_time': self.base_time, 'content': {}}
|
||||
for record_rank in rank_to_history:
|
||||
history = rank_to_history[record_rank]
|
||||
recs = []
|
||||
for event in history:
|
||||
rec = {'start': event.start, 'end': event.end, 'name': event.name}
|
||||
recs.append(rec)
|
||||
records['content'][record_rank] = recs
|
||||
|
||||
dump_name = f'{rank}.json'
|
||||
dump_path = os.path.join(LOG_FOLDER, dump_name)
|
||||
with open(dump_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(records, f, ensure_ascii=False)
|
||||
|
||||
def merge_recode(self):
|
||||
base_time = self.base_time
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
wait_time = 0
|
||||
while True:
|
||||
time.sleep(0.1)
|
||||
log_num = len(os.listdir(LOG_FOLDER))
|
||||
if log_num == world_size:
|
||||
break
|
||||
|
||||
wait_time += 1
|
||||
if wait_time >= MAX_WAIT_TIME:
|
||||
break
|
||||
|
||||
# merge
|
||||
logs_path = [os.path.join(LOG_FOLDER, file) for file in os.listdir(LOG_FOLDER)]
|
||||
recoders = {}
|
||||
for path in logs_path:
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
recs = json.load(f)
|
||||
for record_rank in recs['content']:
|
||||
history = recs['content'][record_rank]
|
||||
recoders[record_rank] = []
|
||||
for rec in history:
|
||||
recoders[record_rank].append({
|
||||
'start': rec['start'] - base_time,
|
||||
'end': rec['end'] - base_time,
|
||||
'name': rec['name']
|
||||
})
|
||||
|
||||
shutil.rmtree(LOG_FOLDER)
|
||||
with open(self.export_name + '.json', 'w', encoding='utf-8') as f:
|
||||
json.dump(recoders, f, ensure_ascii=False)
|
||||
|
||||
def visualise_record(self):
|
||||
with open(self.export_name + '.json', 'r', encoding='utf-8') as f:
|
||||
records = json.load(f)
|
||||
records = dict(records)
|
||||
ranks = list(sorted(records.keys()))
|
||||
|
||||
name_list = {}
|
||||
plots = {}
|
||||
plt.figure(dpi=self.dpi, figsize=[self.figure_width, self.figure_height])
|
||||
plt.style.use(self.theme)
|
||||
|
||||
for rank in ranks:
|
||||
rank_records = records[rank]
|
||||
for rec in rank_records:
|
||||
s = rec['start']
|
||||
e = rec['end']
|
||||
name = rec['name']
|
||||
if name not in name_list:
|
||||
name_list[name] = len(name_list)
|
||||
bar = plt.barh(rank, width=e - s, height=self.bar_height, left=s, color=cmap[name_list[name]])
|
||||
if name not in plots:
|
||||
plots[name] = bar
|
||||
|
||||
plt.legend(list(plots.values()), list(plots.keys()), loc="upper left", fontsize=self.legend_fontsize)
|
||||
plt.yticks(ticks=ranks, labels=[f'Device:{rank}' for rank in ranks], fontsize=self.device_fontsize)
|
||||
plt.grid(axis='x')
|
||||
plt.savefig("{}.{}".format(self.export_name, self.export_format))
|
||||
|
||||
def exit_worker(self):
|
||||
if len(self.rank_to_history) == 0:
|
||||
return
|
||||
self.dump_record()
|
||||
# if this is rank 0, wait for merge
|
||||
rank = dist.get_rank()
|
||||
|
||||
if rank == 1:
|
||||
# take the base time of rank 0 as standard
|
||||
self.merge_recode()
|
||||
self.visualise_record()
|
||||
|
||||
|
||||
recorder = Recorder()
|
||||
atexit.register(recorder.exit_worker)
|
Loading…
Reference in new issue