mirror of https://github.com/hpcaitech/ColossalAI
[pipeline/rank_recorder] fix bug when process data before backward | add a tool for multiple ranks debug (#1681)
* [pipeline/tuning] improve dispatch performance both time and space cost
* [pipeline/converge] add interface for testing convergence
* [NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style
* Update PipelineBase.py
* [pipeline/chimera] reconstruct PipelineBase and Worker to support more feasible custom schedule | finish Chimera
* [pipeline/chimera] test chimera | fix bug of initializing
* [pipeline/pytree] add pytree to process args and kwargs | provide to process args and kwargs after forward

parent 517b63939a
commit 3b2a59b0ba
@@ -5,6 +5,7 @@ from functools import partial
from abc import ABC, abstractmethod
import sys
import os
import time
import inspect

import torch
@@ -12,12 +13,13 @@ from torch import nn
import torch.distributed.rpc as rpc
from torch.futures import Future
from torch._C._distributed_rpc import PyRRef

from torch import autograd
from torch import optim

from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc.utils import (color_debug, tensor_shape_list, get_batch_lengths, split_batch, type_detail,
                                           pytree_map, get_real_args_kwargs, use_color_debug)
                                           pytree_map, pytree_filter, get_real_args_kwargs, use_color_debug)


class Phase(Enum):
@@ -469,6 +471,7 @@ class WorkerBase(ABC):

        else:
            consume_result = self.module_partition(*args, **kwargs)

        # print(f'model{self.pp_rank + 1}(param_sum: {sum([p.sum().item() for p in self.module_partition.parameters()])}) input sum: {args[0].sum().item()} forward output sum: {consume_result.sum().item()}', )

        if is_last_stage and self.criterion:
@@ -495,7 +498,6 @@ class WorkerBase(ABC):
                             stage_input_kwargs,
                             stage_outputs,
                             checkpoint=use_checkpoint)

        # if not forward_only, do the backward
        if not forward_only:
            if is_last_stage:    # if it is the last stage, trigger backward automatically
@@ -521,19 +523,19 @@ class WorkerBase(ABC):
        if use_checkpoint:
            stage_outputs = [self.module_partition(*stage_input_args, **stage_input_kwargs)]

        # take tensor only (only tensors can do backward)
        stage_outputs_tensors = []
        pytree_map(stage_outputs, stage_outputs_tensors.append, process_types=torch.Tensor)

        # overlap recompute and future.wait
        grad_tensors = get_real_args_kwargs(args)
        if not is_last_stage:
            grad_tensors = get_real_args_kwargs(args)
        else:
            grad_tensors = None

        # print('rank', self.pp_rank, tensor_shape_list(stage_outputs_tensors), tensor_shape_list(grad_tensors))
        autograd.backward(stage_outputs_tensors, grad_tensors=grad_tensors)
        # take tensor only (only tensors can do backward)
        stage_outputs = pytree_filter(lambda x: x.requires_grad, stage_outputs, process_types=torch.Tensor)
        grad_tensors = pytree_filter(lambda x: x is not None, grad_tensors, process_types=torch.Tensor)

        autograd.backward(stage_outputs, grad_tensors=grad_tensors)

        # collect grad of input tensor
        # there is a hypothesis that a node in kwargs can't be a non-leaf node in the graph,
        # so we don't need to save the grad of nodes in kwargs.
        consume_result = []
        if not is_first_stage:
            pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
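For context, the fix above hinges on `autograd.backward` being fed matching lists: only output tensors that require grad, paired with non-None incoming gradients. A minimal standalone sketch of that pattern, using made-up tensors instead of real pipeline stage outputs and a plain list comprehension in place of `pytree_filter`:

```python
import torch

# Pretend these are one stage's outputs: the first carries grad, the second does not.
w = torch.randn(4, requires_grad=True)
outputs = [w * 2, torch.randn(4)]        # second tensor has requires_grad=False
incoming_grads = [torch.ones(4), None]   # gradient received for each output

# Keep only the pairs that can actually participate in backward.
pairs = [(o, g) for o, g in zip(outputs, incoming_grads)
         if o.requires_grad and g is not None]
filtered_outputs = [o for o, _ in pairs]
filtered_grads = [g for _, g in pairs]

torch.autograd.backward(filtered_outputs, grad_tensors=filtered_grads)
print(w.grad)  # gradient flowed only through the tensor that required grad
```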
@@ -110,7 +110,7 @@ class OneFOneBPipelineEngine(PipelineEngineBase):
        if chunk > 1:
            assert num_microbatches % stage_num == 0, \
                "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
        assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"
        # assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"
        use_1F1B = True

        super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
@@ -20,7 +20,8 @@ def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] =
    Args:
        obj (:class:`Any`): object to process
        fn (:class:`Callable`): a function to process subobject in obj
        process_types(:class: `type | tuple[type]`): types to determine the type to process
        process_types (:class: `type | tuple[type]`): types to determine the type to process
        map_all (:class: `bool`): if map_all is True, then any type of element will use fn

    Returns:
        :class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn`
@@ -59,6 +60,20 @@ def type_detail(obj):
    return pytree_map(obj, lambda x: type(x), map_all=True)


def pytree_filter(fn, obj, process_types):
    if obj is None:
        return None

    filters = []

    def condition_append(obj):
        if fn(obj):
            filters.append(obj)

    pytree_map(obj, fn=condition_append, process_types=process_types)
    return filters


def get_real_args_kwargs(args_or_kwargs):
    args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
    # TODO : combine producer and consumer
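A small usage sketch of the new `pytree_filter` helper, assuming a ColossalAI build that contains this commit (the tensors are illustrative):

```python
import torch
from colossalai.pipeline.rpc.utils import pytree_filter

leaf = torch.randn(2, requires_grad=True)
outputs = [leaf * 3, torch.zeros(2)]    # only the first element requires grad

# Collect the tensors inside the nested structure that can take part in backward.
grad_ready = pytree_filter(lambda x: x.requires_grad, outputs, process_types=torch.Tensor)
print(grad_ready)    # [the first tensor]; the zero tensor is dropped

# A None input simply yields None, matching the early return above.
assert pytree_filter(lambda x: x is not None, None, process_types=torch.Tensor) is None
```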
@@ -0,0 +1,72 @@
# Rank Recorder

This is a useful tool to record the runtime of certain procedures in each rank. The records of each rank are dumped into a JSON file after the multi-process program ends, and you can easily parse and visualise that file.

Before using the tool, make sure `dist.is_initialized()` returns `True` before the program exits.

## Usage

It is very simple:

```python
from colossalai.utils.rank_recorder import recorder

...
...

with recorder(record_name, current_rank) as r:
    """procedure to record
    """
```

## Example
This is a demo that shows kernel selection in CUDA and visualises the cost of several procedures in each rank.

```python
import time
import os
import logging
logging.disable(logging.INFO)

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from colossalai.utils.rank_recorder import recorder


WORLD_SIZE = 4

# config the export image here
# If you want to dive into the detail, format 'svg' is recommended
recorder.export_format = 'png'
recorder.export_name = 'kernel_select'
recorder.dpi = 500

def calc(x, y):
    a = torch.randn(x, y).cuda()
    b = torch.randn(x, y).cuda()
    c = sum(a * b)
    return c

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29020'
    dist.init_process_group(backend='nccl', world_size=WORLD_SIZE, rank=rank)
    print(dist.get_rank(), "enter")
    time.sleep(0.1 * rank)

    with recorder("calc_1(x100)", rank) as r:
        calc(100, 100)

    with recorder("calc_2(x400)", rank) as r:
        calc(400, 400)

    with recorder("calc_2(x200)", rank) as r:
        calc(200, 200)

if __name__ == "__main__":
    mp.spawn(worker, nprocs=WORLD_SIZE)
```

Run the script directly and you will get `kernel_select.json` and `kernel_select.png` in your current folder.
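For instance, a minimal sketch of reading the exported JSON back in; the layout (one list of `start`/`end`/`name` records per rank, keyed by rank id) is the one produced by the recorder in this commit:

```python
import json

# Load the merged record produced by the demo above.
with open('kernel_select.json', 'r', encoding='utf-8') as f:
    records = json.load(f)

for rank, events in sorted(records.items()):
    for event in events:
        cost = event['end'] - event['start']
        print(f"rank {rank}: {event['name']} took {cost * 1000:.2f} ms")
```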
@@ -0,0 +1,3 @@
from colossalai.utils.rank_recorder.rank_recorder import recorder

__all__ = ["recorder"]
@@ -0,0 +1,178 @@
import time
from typing import List, Dict
import json
import os
import shutil
import atexit

import torch
import torch.distributed as dist

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

cmap = list(mcolors.TABLEAU_COLORS.values())

LOG_FOLDER = "record.log"
MAX_WAIT_TIME = 20


class Event:

    def __init__(self, start: int, end: int, name: str, rank: int) -> None:
        self.start = start
        self.end = end
        self.name = name
        self.rank = rank


class Recorder:

    def __init__(self) -> None:
        self.rank_to_history: Dict[int, List[Event]] = {}
        self.base_time = time.time()
        self.temp_event = None

        self.export_format = 'png'
        self.export_name = 'test'
        self.dpi = 500
        self.theme = 'dark_background'
        self.figure_width = 30
        self.figure_height = 10
        self.legend_fontsize = 16
        self.device_fontsize = 20
        self.bar_height = 0.2

        if not os.path.exists(LOG_FOLDER):
            os.makedirs(LOG_FOLDER)

    def start(self, name: str, rank: int):
        # TODO : add lock to prevent conflict
        torch.cuda.synchronize()
        start_time = time.time()
        self.temp_event = Event(start_time, None, name, rank)

    def end(self):
        assert self.temp_event is not None, "`start` before `end`"
        torch.cuda.synchronize()
        end_time = time.time()
        self.temp_event.end = end_time
        rank = self.temp_event.rank
        if rank not in self.rank_to_history:
            self.rank_to_history[rank] = []
        self.rank_to_history[rank].append(self.temp_event)
        self.temp_event = None

    def get_history(self):
        return self.rank_to_history

    def __call__(self, name: str, rank: int):
        self.temp_name = name
        self.temp_rank = rank
        return self

    def __enter__(self):
        name = self.temp_name
        rank = self.temp_rank
        self.start(name, rank)

    def __exit__(self, *args):
        self.end()

    def dump_record(self):
        rank = dist.get_rank()
        rank_to_history = self.rank_to_history
        records = {'base_time': self.base_time, 'content': {}}
        for record_rank in rank_to_history:
            history = rank_to_history[record_rank]
            recs = []
            for event in history:
                rec = {'start': event.start, 'end': event.end, 'name': event.name}
                recs.append(rec)
            records['content'][record_rank] = recs

        dump_name = f'{rank}.json'
        dump_path = os.path.join(LOG_FOLDER, dump_name)
        with open(dump_path, 'w', encoding='utf-8') as f:
            json.dump(records, f, ensure_ascii=False)

    def merge_record(self):
        base_time = self.base_time
        world_size = dist.get_world_size()

        wait_time = 0
        while True:
            time.sleep(0.1)
            log_num = len(os.listdir(LOG_FOLDER))
            if log_num == world_size:
                break

            wait_time += 1
            if wait_time >= MAX_WAIT_TIME:
                break

        # merge
        logs_path = [os.path.join(LOG_FOLDER, file) for file in os.listdir(LOG_FOLDER)]
        recorders = {}
        for path in logs_path:
            with open(path, 'r', encoding='utf-8') as f:
                recs = json.load(f)
            for record_rank in recs['content']:
                history = recs['content'][record_rank]
                recorders[record_rank] = []
                for rec in history:
                    recorders[record_rank].append({
                        'start': rec['start'] - base_time,
                        'end': rec['end'] - base_time,
                        'name': rec['name']
                    })

        shutil.rmtree(LOG_FOLDER)
        with open(self.export_name + '.json', 'w', encoding='utf-8') as f:
            json.dump(recorders, f, ensure_ascii=False)

    def visualise_record(self):
        with open(self.export_name + '.json', 'r', encoding='utf-8') as f:
            records = json.load(f)
            records = dict(records)
            ranks = list(sorted(records.keys()))

        name_list = {}
        plots = {}
        plt.figure(dpi=self.dpi, figsize=[self.figure_width, self.figure_height])
        plt.style.use(self.theme)

        for rank in ranks:
            rank_records = records[rank]
            for rec in rank_records:
                s = rec['start']
                e = rec['end']
                name = rec['name']
                if name not in name_list:
                    name_list[name] = len(name_list)
                bar = plt.barh(rank, width=e - s, height=self.bar_height, left=s, color=cmap[name_list[name]])
                if name not in plots:
                    plots[name] = bar

        plt.legend(list(plots.values()), list(plots.keys()), loc="upper left", fontsize=self.legend_fontsize)
        plt.yticks(ticks=ranks, labels=[f'Device:{rank}' for rank in ranks], fontsize=self.device_fontsize)
        plt.grid(axis='x')
        plt.savefig("{}.{}".format(self.export_name, self.export_format))

    def exit_worker(self):
        if len(self.rank_to_history) == 0:
            return
        self.dump_record()
        # one rank waits for all ranks to dump, then merges and visualises the records
        rank = dist.get_rank()

        if rank == 1:
            # take this rank's base time as the reference
            self.merge_record()
            self.visualise_record()


recorder = Recorder()
atexit.register(recorder.exit_worker)