mirror of https://github.com/hpcaitech/ColossalAI
[pipeline/rank_recorder] fix bug when processing data before backward | add a tool for multi-rank debugging (#1681)
* [pipeline/tuning] improve dispatch performance in both time and space cost
* [pipeline/converge] add an interface for testing convergence
* [NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style
* Update PipelineBase.py
* [pipeline/chimera] reconstruct PipelineBase and Worker to support more flexible custom schedules | finish Chimera
* [pipeline/chimera] test Chimera | fix bug in initialization
* [pipeline/pytree] add pytree utilities to process args and kwargs | process args and kwargs after forward
parent 517b63939a
commit 3b2a59b0ba
```diff
@@ -5,6 +5,7 @@ from functools import partial
 from abc import ABC, abstractmethod
 import sys
 import os
+import time
 import inspect

 import torch
@@ -12,12 +13,13 @@ from torch import nn
 import torch.distributed.rpc as rpc
 from torch.futures import Future
 from torch._C._distributed_rpc import PyRRef

 from torch import autograd
 from torch import optim

 from colossalai.pipeline.pipeline_process_group import ppg
 from colossalai.pipeline.rpc.utils import (color_debug, tensor_shape_list, get_batch_lengths, split_batch, type_detail,
-                                           pytree_map, get_real_args_kwargs, use_color_debug)
+                                           pytree_map, pytree_filter, get_real_args_kwargs, use_color_debug)


 class Phase(Enum):
```
```diff
@@ -469,6 +471,7 @@ class WorkerBase(ABC):
         else:
             consume_result = self.module_partition(*args, **kwargs)

         # print(f'model{self.pp_rank + 1}(param_sum: {sum([p.sum().item() for p in self.module_partition.parameters()])}) input sum: {args[0].sum().item()} forward output sum: {consume_result.sum().item()}', )

         if is_last_stage and self.criterion:
@@ -495,7 +498,6 @@ class WorkerBase(ABC):
                                                    stage_input_kwargs,
                                                    stage_outputs,
                                                    checkpoint=use_checkpoint)

         # if not forward_only, do the backward
         if not forward_only:
             if is_last_stage:    # if it is the last stage, trigger backward automatic
```
```diff
@@ -521,19 +523,19 @@ class WorkerBase(ABC):
                 if use_checkpoint:
                     stage_outputs = [self.module_partition(*stage_input_args, **stage_input_kwargs)]

-                # take tensor only (for only tensor can do backward)
-                stage_outputs_tensors = []
-                pytree_map(stage_outputs, stage_outputs_tensors.append, process_types=torch.Tensor)
-
                 # overlap recompute and future.wait
-                grad_tensors = get_real_args_kwargs(args)
+                if not is_last_stage:
+                    grad_tensors = get_real_args_kwargs(args)
+                else:
+                    grad_tensors = None

-                # print('rank', self.pp_rank, tensor_shape_list(stage_outputs_tensors), tensor_shape_list(grad_tensors))
-                autograd.backward(stage_outputs_tensors, grad_tensors=grad_tensors)
+                # take tensor only (for only tensor can do backward)
+                stage_outputs = pytree_filter(lambda x: x.requires_grad, stage_outputs, process_types=torch.Tensor)
+                grad_tensors = pytree_filter(lambda x: x is not None, grad_tensors, process_types=torch.Tensor)
+
+                autograd.backward(stage_outputs, grad_tensors=grad_tensors)

                 # collect grad of input tensor
-                # there is a hypothesis that node in kwargs cann't be an non-leaf node in graph
-                # so we don't need to save the grad of node in kwargs.
                 consume_result = []
                 if not is_first_stage:
                     pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
```
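The hunk above does two things: `grad_tensors` is only fetched from the downstream stage when this is not the last stage (the last stage starts backward from the loss and has nothing to receive), and both the stage outputs and the gradients are filtered down to tensors that can actually participate in `autograd.backward`. Below is a minimal, self-contained sketch of the second point; the tensors are invented and plain list comprehensions stand in for `pytree_filter`.

```python
import torch
from torch import autograd

# Invented stage outputs: a mix of tensors that do and do not require grad.
stage_outputs = [
    torch.randn(4, requires_grad=True) * 2,    # produced by the model, has grad_fn
    torch.zeros(4),                            # e.g. a returned mask, no grad
]
grad_tensors = [torch.ones(4), None]           # upstream grads, possibly None

# Stand-in for the two pytree_filter calls in the diff.
outputs = [t for t in stage_outputs if isinstance(t, torch.Tensor) and t.requires_grad]
grads = [g for g in grad_tensors if g is not None]

# Without the filtering, autograd.backward would fail because element 1 of
# stage_outputs has no grad_fn and its corresponding grad tensor is None.
autograd.backward(outputs, grad_tensors=grads)
```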
```diff
@@ -110,7 +110,7 @@ class OneFOneBPipelineEngine(PipelineEngineBase):
         if chunk > 1:
             assert num_microbatches % stage_num == 0, \
                 "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
-        assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"
+        # assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"
         use_1F1B = True

         super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
```
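A quick numeric check of the two constraints touched in this hunk; the values are illustrative only and show why the second assertion was relaxed.

```python
stage_num = 4          # number of pipeline stages
chunk = 2              # model chunks per stage (interleaved 1F1B)
num_microbatches = 8

# The assertion kept by the diff: interleaving requires the number of
# microbatches to be a multiple of the stage count.
assert num_microbatches % stage_num == 0

# The assertion commented out by the diff would have rejected this setup,
# since num_microbatches (8) is not strictly greater than stage_num * chunk (8).
print(num_microbatches > stage_num * chunk)    # False
```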
```diff
@@ -20,7 +20,8 @@ def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] =
     Args:
         obj (:class:`Any`): object to process
         fn (:class:`Callable`): a function to process subobject in obj
-        process_types(:class: `type | tuple[type]`): types to determine the type to process
+        process_types (:class: `type | tuple[type]`): types to determine the type to process
+        map_all (:class: `bool`): if map_all is True, then any type of element will use fn

     Returns:
         :class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn`
```
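The docstring above describes `pytree_map` as a structure-preserving map over nested containers, applied either to leaves of `process_types` or, with `map_all=True`, to every leaf. A small illustrative call follows; the nested input is invented, and it assumes leaves of other types are passed through unchanged, as the structure-preserving contract suggests.

```python
import torch
from colossalai.pipeline.rpc.utils import pytree_map

nested = {"x": [torch.ones(2), 3], "y": (torch.zeros(2), "tag")}

# Double only the tensors; the surrounding dict/list/tuple structure is preserved.
doubled = pytree_map(nested, lambda t: t * 2, process_types=torch.Tensor)

# With map_all=True, fn is applied to every leaf regardless of its type;
# this is exactly how type_detail records the type of each element.
types = pytree_map(nested, lambda x: type(x), map_all=True)
```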
```diff
@@ -59,6 +60,20 @@ def type_detail(obj):
     return pytree_map(obj, lambda x: type(x), map_all=True)


+def pytree_filter(fn, obj, process_types):
+    if obj is None:
+        return None
+
+    filters = []
+
+    def condition_append(obj):
+        if fn(obj):
+            filters.append(obj)
+
+    pytree_map(obj, fn=condition_append, process_types=process_types)
+    return filters
+
+
 def get_real_args_kwargs(args_or_kwargs):
     args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
     # TODO : combine producer and consumer
```
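The new `pytree_filter` helper collects, into a flat list, every element of `process_types` inside `obj` that satisfies the predicate `fn`; this is what the backward fix above uses to keep only tensors that can take gradients. An illustrative call (the input list is invented):

```python
import torch
from colossalai.pipeline.rpc.utils import pytree_filter

outputs = [torch.randn(2, requires_grad=True), torch.zeros(2), "not a tensor"]

# Flatten out only the tensors that can take part in backward;
# the predicate is never called on leaves outside process_types.
trainable = pytree_filter(lambda x: x.requires_grad, outputs, process_types=torch.Tensor)
# -> a flat list containing just the first tensor
```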
@@ -0,0 +1,72 @@

# Rank Recorder

This is a simple tool for recording how long certain procedures take on each rank. The records of every rank are dumped into a JSON file when the multi-process program exits, and you can parse and visualise that JSON file easily.

Before using the tool, make sure `dist.is_initialized()` returns `True` before the program exits.

## Usage

It is very simple:

```python
from colossalai.utils.rank_recorder import recorder

...
...

with recorder(record_name, current_rank) as r:
    """procedure to record
    """

```

## Example

This is a demo that records CUDA kernel selection and visualises the cost of several procedures on each rank.

```python
import time
import os
import logging
logging.disable(logging.INFO)

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from colossalai.utils.rank_recorder import recorder


WORLD_SIZE = 4

# configure the exported image here
# if you want to dive into the details, the 'svg' format is recommended
recorder.export_format = 'png'
recorder.export_name = 'kernel_select'
recorder.dpi = 500


def calc(x, y):
    a = torch.randn(x, y).cuda()
    b = torch.randn(x, y).cuda()
    c = sum(a * b)
    return c


def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29020'
    dist.init_process_group(backend='nccl', world_size=WORLD_SIZE, rank=rank)
    print(dist.get_rank(), "enter")
    time.sleep(0.1 * rank)

    with recorder("calc_1(x100)", rank) as r:
        calc(100, 100)

    with recorder("calc_2(x400)", rank) as r:
        calc(400, 400)

    with recorder("calc_3(x200)", rank) as r:
        calc(200, 200)


if __name__ == "__main__":
    mp.spawn(worker, nprocs=WORLD_SIZE)
```

Run the script directly and you will get `kernel_select.json` and `kernel_select.png` in your current folder.
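For reference, the merged `kernel_select.json` maps each rank to a list of `{'start', 'end', 'name'}` entries measured relative to the recorder's base time (see `merge_recode` in the recorder source below). A minimal post-processing sketch, assuming the file produced by the example above:

```python
import json

# Load the merged record produced by the demo above.
with open('kernel_select.json', 'r', encoding='utf-8') as f:
    records = json.load(f)

# Print the duration of every recorded procedure per rank
# (JSON keys are strings, times are seconds relative to base_time).
for rank, events in records.items():
    for event in events:
        duration = event['end'] - event['start']
        print(f"rank {rank}: {event['name']} took {duration:.4f}s")
```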
@@ -0,0 +1,3 @@

```python
from colossalai.utils.rank_recorder.rank_recorder import recorder

__all__ = ["recorder"]
```
@@ -0,0 +1,178 @@

```python
import os
import time
import json
import shutil
import atexit
from typing import List, Dict

import torch
import torch.distributed as dist

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

cmap = list(mcolors.TABLEAU_COLORS.values())

LOG_FOLDER = "record.log"
MAX_WAIT_TIME = 20


class Event:
    """A single timed span recorded on one rank."""

    def __init__(self, start: float, end: float, name: str, rank: int) -> None:
        self.start = start
        self.end = end
        self.name = name
        self.rank = rank


class Recorder:
    """Records named time spans per rank, dumps them to JSON at exit and renders a timeline."""

    def __init__(self) -> None:
        self.rank_to_history: Dict[int, List[Event]] = {}
        self.base_time = time.time()
        self.temp_event = None

        self.export_format = 'png'
        self.export_name = 'test'
        self.dpi = 500
        self.theme = 'dark_background'
        self.figure_width = 30
        self.figure_height = 10
        self.legend_fontsize = 16
        self.device_fontsize = 20
        self.bar_height = 0.2

        if not os.path.exists(LOG_FOLDER):
            os.makedirs(LOG_FOLDER)

    def start(self, name: str, rank: int):
        # TODO: add a lock to prevent conflicts
        torch.cuda.synchronize()
        start_time = time.time()
        self.temp_event = Event(start_time, None, name, rank)

    def end(self):
        assert self.temp_event is not None, "`start` before `end`"
        torch.cuda.synchronize()
        end_time = time.time()
        self.temp_event.end = end_time
        rank = self.temp_event.rank
        if rank not in self.rank_to_history:
            self.rank_to_history[rank] = []
        self.rank_to_history[rank].append(self.temp_event)
        self.temp_event = None

    def get_history(self):
        return self.rank_to_history

    def __call__(self, name: str, rank: int):
        self.temp_name = name
        self.temp_rank = rank
        return self

    def __enter__(self):
        name = self.temp_name
        rank = self.temp_rank
        self.start(name, rank)

    def __exit__(self, *args):
        self.end()

    def dump_record(self):
        rank = dist.get_rank()
        rank_to_history = self.rank_to_history
        records = {'base_time': self.base_time, 'content': {}}
        for record_rank in rank_to_history:
            history = rank_to_history[record_rank]
            recs = []
            for event in history:
                rec = {'start': event.start, 'end': event.end, 'name': event.name}
                recs.append(rec)
            records['content'][record_rank] = recs

        dump_name = f'{rank}.json'
        dump_path = os.path.join(LOG_FOLDER, dump_name)
        with open(dump_path, 'w', encoding='utf-8') as f:
            json.dump(records, f, ensure_ascii=False)

    def merge_recode(self):
        base_time = self.base_time
        world_size = dist.get_world_size()

        # wait until every rank has dumped its record (or give up after MAX_WAIT_TIME)
        wait_time = 0
        while True:
            time.sleep(0.1)
            log_num = len(os.listdir(LOG_FOLDER))
            if log_num == world_size:
                break

            wait_time += 1
            if wait_time >= MAX_WAIT_TIME:
                break

        # merge the per-rank dumps into one file with times relative to base_time
        logs_path = [os.path.join(LOG_FOLDER, file) for file in os.listdir(LOG_FOLDER)]
        recoders = {}
        for path in logs_path:
            with open(path, 'r', encoding='utf-8') as f:
                recs = json.load(f)
            for record_rank in recs['content']:
                history = recs['content'][record_rank]
                recoders[record_rank] = []
                for rec in history:
                    recoders[record_rank].append({
                        'start': rec['start'] - base_time,
                        'end': rec['end'] - base_time,
                        'name': rec['name']
                    })

        shutil.rmtree(LOG_FOLDER)
        with open(self.export_name + '.json', 'w', encoding='utf-8') as f:
            json.dump(recoders, f, ensure_ascii=False)

    def visualise_record(self):
        with open(self.export_name + '.json', 'r', encoding='utf-8') as f:
            records = json.load(f)
        records = dict(records)
        ranks = list(sorted(records.keys()))

        name_list = {}
        plots = {}
        plt.figure(dpi=self.dpi, figsize=[self.figure_width, self.figure_height])
        plt.style.use(self.theme)

        for rank in ranks:
            rank_records = records[rank]
            for rec in rank_records:
                s = rec['start']
                e = rec['end']
                name = rec['name']
                if name not in name_list:
                    name_list[name] = len(name_list)
                bar = plt.barh(rank, width=e - s, height=self.bar_height, left=s, color=cmap[name_list[name]])
                if name not in plots:
                    plots[name] = bar

        plt.legend(list(plots.values()), list(plots.keys()), loc="upper left", fontsize=self.legend_fontsize)
        plt.yticks(ticks=ranks, labels=[f'Device:{rank}' for rank in ranks], fontsize=self.device_fontsize)
        plt.grid(axis='x')
        plt.savefig("{}.{}".format(self.export_name, self.export_format))

    def exit_worker(self):
        if len(self.rank_to_history) == 0:
            return
        self.dump_record()

        # a single rank (rank 1 here) waits for all dumps, merges them with its
        # own base time as the reference, and renders the plot
        rank = dist.get_rank()

        if rank == 1:
            self.merge_recode()
            self.visualise_record()


recorder = Recorder()
atexit.register(recorder.exit_worker)
```