[pipeline/rank_recorder] fix bug when process data before backward | add a tool for multiple ranks debug (#1681)

* [pipeline/tuning] improve dispatch performance both time and space cost

* [pipeline/converge] add interface for testing convergence

* [NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style

* Update PipelineBase.py

* [pipeline/chimera] reconstruct PipelineBase and Worker to support more feasible custom schedule | finish Chimera

* [pipeline/chimera] test chimera | fix bug of initializing

* [pipeline/pytree] add pytree to process args and kwargs | provide  to process args and kwargs after forward
pull/1684/head
Kirigaya Kazuto 2 years ago committed by GitHub
parent 517b63939a
commit 3b2a59b0ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5,6 +5,7 @@ from functools import partial
from abc import ABC, abstractmethod
import sys
import os
import time
import inspect
import torch
@ -12,12 +13,13 @@ from torch import nn
import torch.distributed.rpc as rpc
from torch.futures import Future
from torch._C._distributed_rpc import PyRRef
from torch import autograd
from torch import optim
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc.utils import (color_debug, tensor_shape_list, get_batch_lengths, split_batch, type_detail,
pytree_map, get_real_args_kwargs, use_color_debug)
pytree_map, pytree_filter, get_real_args_kwargs, use_color_debug)
class Phase(Enum):
@ -469,6 +471,7 @@ class WorkerBase(ABC):
else:
consume_result = self.module_partition(*args, **kwargs)
# print(f'model{self.pp_rank + 1}(param_sum: {sum([p.sum().item() for p in self.module_partition.parameters()])}) input sum: {args[0].sum().item()} forward output sum: {consume_result.sum().item()}', )
if is_last_stage and self.criterion:
@ -495,7 +498,6 @@ class WorkerBase(ABC):
stage_input_kwargs,
stage_outputs,
checkpoint=use_checkpoint)
# if not forward_only, do the backward
if not forward_only:
if is_last_stage: # if it is the last stage, trigger backward automatic
@ -521,19 +523,19 @@ class WorkerBase(ABC):
if use_checkpoint:
stage_outputs = [self.module_partition(*stage_input_args, **stage_input_kwargs)]
# take tensor only (for only tensor can do backward)
stage_outputs_tensors = []
pytree_map(stage_outputs, stage_outputs_tensors.append, process_types=torch.Tensor)
# overlap recompute and future.wait
grad_tensors = get_real_args_kwargs(args)
if not is_last_stage:
grad_tensors = get_real_args_kwargs(args)
else:
grad_tensors = None
# take tensor only (for only tensor can do backward)
stage_outputs = pytree_filter(lambda x: x.requires_grad, stage_outputs, process_types=torch.Tensor)
grad_tensors = pytree_filter(lambda x: x is not None, grad_tensors, process_types=torch.Tensor)
# print('rank', self.pp_rank, tensor_shape_list(stage_outputs_tensors), tensor_shape_list(grad_tensors))
autograd.backward(stage_outputs_tensors, grad_tensors=grad_tensors)
autograd.backward(stage_outputs, grad_tensors=grad_tensors)
# collect grad of input tensor
# there is a hypothesis that node in kwargs cann't be an non-leaf node in graph
# so we don't need to save the grad of node in kwargs.
consume_result = []
if not is_first_stage:
pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)

@ -110,7 +110,7 @@ class OneFOneBPipelineEngine(PipelineEngineBase):
if chunk > 1:
assert num_microbatches % stage_num == 0, \
"if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"
# assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"
use_1F1B = True
super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,

@ -20,7 +20,8 @@ def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] =
Args:
obj (:class:`Any`): object to process
fn (:class:`Callable`): a function to process subobject in obj
process_types(:class: `type | tuple[type]`): types to determine the type to process
process_types (:class: `type | tuple[type]`): types to determine the type to process
map_all (:class: `bool`): if map_all is True, then any type of element will use fn
Returns:
:class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn`
@ -59,6 +60,20 @@ def type_detail(obj):
return pytree_map(obj, lambda x: type(x), map_all=True)
def pytree_filter(fn, obj, process_types):
if obj is None:
return None
filters = []
def condition_append(obj):
if fn(obj):
filters.append(obj)
pytree_map(obj, fn=condition_append, process_types=process_types)
return filters
def get_real_args_kwargs(args_or_kwargs):
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
# TODO : combine producer and consumer

@ -0,0 +1,72 @@
# Rank Recorder
This is a useful tool to get the records of certain functions in each rank. The records of each rank will dump into a json file after the end of multiple process program. You can parse and visualise the json file easily.
Before using the tool, you should ensure dist.is_initialized() return true before exit of program.
## Usage
Is very simple:
```python
from colossalai.utils.rank_recorder import recorder
...
...
with recorder(record_name, current_rank) as r:
"""procedure to record
"""
```
## Example
This is a demo to display kernel select in cuda and visualise the cost of several procedures in each rank.
```python
import time
import os
import logging
logging.disable(logging.INFO)
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.utils.rank_recorder import recorder
WORLD_SIZE = 4
# config the export image here
# If you want to dive into the detail, format 'svg' is recommended
recorder.export_format = 'png'
recorder.export_name = 'kernel_select'
recorder.dpi = 500
def calc(x, y):
a = torch.randn(x, y).cuda()
b = torch.randn(x, y).cuda()
c = sum(a * b)
return c
def worker(rank):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29020'
dist.init_process_group(backend='nccl', world_size=WORLD_SIZE, rank=rank)
print(dist.get_rank(), "enter")
time.sleep(0.1 * rank)
with recorder("calc_1(x100)", rank) as r:
calc(100, 100)
with recorder("calc_2(x400)", rank) as r:
calc(400, 400)
with recorder("calc_2(x200)", rank) as r:
calc(200, 200)
if __name__ == "__main__":
mp.spawn(worker, nprocs=WORLD_SIZE)
```
run the script directly and you will get `kernel_select.json` and `kernel_select.png` in your current folder.

@ -0,0 +1,3 @@
from colossalai.utils.rank_recorder.rank_recorder import recorder
__all__ = ["recorder"]

@ -0,0 +1,178 @@
import time
from typing import List, Dict
import json
import os
import time
import shutil
import atexit
import torch
import torch.distributed as dist
import json
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
cmap = list(mcolors.TABLEAU_COLORS.values())
LOG_FOLDER = "record.log"
MAX_WAIT_TIME = 20
class Event:
def __init__(self, start: int, end: int, name: str, rank: int) -> None:
self.start = start
self.end = end
self.name = name
self.rank = rank
class Recorder:
def __init__(self) -> None:
self.rank_to_history: Dict[int, List[Event]] = {}
self.base_time = time.time()
self.temp_event = None
self.export_format = 'png'
self.export_name = 'test'
self.dpi = 500
self.theme = 'dark_background'
self.figure_width = 30
self.figure_height = 10
self.legend_fontsize = 16
self.device_fontsize = 20
self.bar_height = 0.2
if not os.path.exists(LOG_FOLDER):
os.makedirs(LOG_FOLDER)
def start(self, name: str, rank: int):
# TODO : add lock to prevent conflict
torch.cuda.synchronize()
start_time = time.time()
self.temp_event = Event(start_time, None, name, rank)
def end(self):
assert self.temp_event is not None, "`start` before `end`"
torch.cuda.synchronize()
end_time = time.time()
self.temp_event.end = end_time
rank = self.temp_event.rank
if rank not in self.rank_to_history:
self.rank_to_history[rank] = []
self.rank_to_history[rank].append(self.temp_event)
self.temp_event = None
def get_history(self):
return self.history
def __call__(self, name: str, rank: str):
self.temp_name = name
self.temp_rank = rank
return self
def __enter__(self):
name = self.temp_name
rank = self.temp_rank
self.start(name, rank)
def __exit__(self, *args):
self.end()
def dump_record(self):
rank = dist.get_rank()
rank_to_history = self.rank_to_history
records = {'base_time': self.base_time, 'content': {}}
for record_rank in rank_to_history:
history = rank_to_history[record_rank]
recs = []
for event in history:
rec = {'start': event.start, 'end': event.end, 'name': event.name}
recs.append(rec)
records['content'][record_rank] = recs
dump_name = f'{rank}.json'
dump_path = os.path.join(LOG_FOLDER, dump_name)
with open(dump_path, 'w', encoding='utf-8') as f:
json.dump(records, f, ensure_ascii=False)
def merge_recode(self):
base_time = self.base_time
world_size = dist.get_world_size()
wait_time = 0
while True:
time.sleep(0.1)
log_num = len(os.listdir(LOG_FOLDER))
if log_num == world_size:
break
wait_time += 1
if wait_time >= MAX_WAIT_TIME:
break
# merge
logs_path = [os.path.join(LOG_FOLDER, file) for file in os.listdir(LOG_FOLDER)]
recoders = {}
for path in logs_path:
with open(path, 'r', encoding='utf-8') as f:
recs = json.load(f)
for record_rank in recs['content']:
history = recs['content'][record_rank]
recoders[record_rank] = []
for rec in history:
recoders[record_rank].append({
'start': rec['start'] - base_time,
'end': rec['end'] - base_time,
'name': rec['name']
})
shutil.rmtree(LOG_FOLDER)
with open(self.export_name + '.json', 'w', encoding='utf-8') as f:
json.dump(recoders, f, ensure_ascii=False)
def visualise_record(self):
with open(self.export_name + '.json', 'r', encoding='utf-8') as f:
records = json.load(f)
records = dict(records)
ranks = list(sorted(records.keys()))
name_list = {}
plots = {}
plt.figure(dpi=self.dpi, figsize=[self.figure_width, self.figure_height])
plt.style.use(self.theme)
for rank in ranks:
rank_records = records[rank]
for rec in rank_records:
s = rec['start']
e = rec['end']
name = rec['name']
if name not in name_list:
name_list[name] = len(name_list)
bar = plt.barh(rank, width=e - s, height=self.bar_height, left=s, color=cmap[name_list[name]])
if name not in plots:
plots[name] = bar
plt.legend(list(plots.values()), list(plots.keys()), loc="upper left", fontsize=self.legend_fontsize)
plt.yticks(ticks=ranks, labels=[f'Device:{rank}' for rank in ranks], fontsize=self.device_fontsize)
plt.grid(axis='x')
plt.savefig("{}.{}".format(self.export_name, self.export_format))
def exit_worker(self):
if len(self.rank_to_history) == 0:
return
self.dump_record()
# if this is rank 0, wait for merge
rank = dist.get_rank()
if rank == 1:
# take the base time of rank 0 as standard
self.merge_recode()
self.visualise_record()
recorder = Recorder()
atexit.register(recorder.exit_worker)
Loading…
Cancel
Save