diff --git a/colossalai/pipeline/rpc/_pipeline_base.py b/colossalai/pipeline/rpc/_pipeline_base.py
index 16c2c95dc..96357c476 100644
--- a/colossalai/pipeline/rpc/_pipeline_base.py
+++ b/colossalai/pipeline/rpc/_pipeline_base.py
@@ -5,6 +5,7 @@ from functools import partial
 from abc import ABC, abstractmethod
 import sys
 import os
+import time
 import inspect
 
 import torch
@@ -12,12 +13,13 @@ from torch import nn
 import torch.distributed.rpc as rpc
 from torch.futures import Future
 from torch._C._distributed_rpc import PyRRef
+
 from torch import autograd
 from torch import optim
 
 from colossalai.pipeline.pipeline_process_group import ppg
 from colossalai.pipeline.rpc.utils import (color_debug, tensor_shape_list, get_batch_lengths, split_batch, type_detail,
-                                           pytree_map, get_real_args_kwargs, use_color_debug)
+                                           pytree_map, pytree_filter, get_real_args_kwargs, use_color_debug)
 
 
 class Phase(Enum):
@@ -469,6 +471,7 @@ class WorkerBase(ABC):
         else:
             consume_result = self.module_partition(*args, **kwargs)
 
+            # print(f'model{self.pp_rank + 1}(param_sum: {sum([p.sum().item() for p in self.module_partition.parameters()])}) input sum: {args[0].sum().item()} forward output sum: {consume_result.sum().item()}', )
 
         if is_last_stage and self.criterion:
@@ -495,7 +498,6 @@
                                                  stage_input_kwargs,
                                                  stage_outputs,
                                                  checkpoint=use_checkpoint)
 
-        # if not forward_only, do the backward
         if not forward_only:
             if is_last_stage:    # if it is the last stage, trigger backward automatic
@@ -521,19 +523,19 @@
             if use_checkpoint:
                 stage_outputs = [self.module_partition(*stage_input_args, **stage_input_kwargs)]
 
-            # take tensor only (for only tensor can do backward)
-            stage_outputs_tensors = []
-            pytree_map(stage_outputs, stage_outputs_tensors.append, process_types=torch.Tensor)
-
             # overlap recompute and future.wait
-            grad_tensors = get_real_args_kwargs(args)
+            if not is_last_stage:
+                grad_tensors = get_real_args_kwargs(args)
+            else:
+                grad_tensors = None
+
+            # take tensors only (only tensors can do backward)
+            stage_outputs = pytree_filter(lambda x: x.requires_grad, stage_outputs, process_types=torch.Tensor)
+            grad_tensors = pytree_filter(lambda x: x is not None, grad_tensors, process_types=torch.Tensor)
 
-            # print('rank', self.pp_rank, tensor_shape_list(stage_outputs_tensors), tensor_shape_list(grad_tensors))
-            autograd.backward(stage_outputs_tensors, grad_tensors=grad_tensors)
+            autograd.backward(stage_outputs, grad_tensors=grad_tensors)
 
             # collect grad of input tensor
-            # there is a hypothesis that node in kwargs cann't be an non-leaf node in graph
-            # so we don't need to save the grad of node in kwargs.
             consume_result = []
             if not is_first_stage:
                 pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
diff --git a/colossalai/pipeline/rpc/_pipeline_schedule.py b/colossalai/pipeline/rpc/_pipeline_schedule.py
index 6c4c39a73..e534943e0 100644
--- a/colossalai/pipeline/rpc/_pipeline_schedule.py
+++ b/colossalai/pipeline/rpc/_pipeline_schedule.py
@@ -110,7 +110,7 @@
         if chunk > 1:
             assert num_microbatches % stage_num == 0, \
                 "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
-        assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"
+        # assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"
         use_1F1B = True
 
         super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
diff --git a/colossalai/pipeline/rpc/utils.py b/colossalai/pipeline/rpc/utils.py
index 5badecedb..c4d6897f6 100644
--- a/colossalai/pipeline/rpc/utils.py
+++ b/colossalai/pipeline/rpc/utils.py
@@ -20,7 +20,8 @@ def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] =
     Args:
         obj (:class:`Any`): object to process
         fn (:class:`Callable`): a function to process subobject in obj
-        process_types(:class: `type | tuple[type]`): types to determine the type to process
+        process_types (:class:`type | tuple[type]`): types of element to apply `fn` to
+        map_all (:class:`bool`): if True, apply `fn` to every element regardless of its type
 
     Returns:
         :class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn`
@@ -59,6 +60,30 @@ def type_detail(obj):
     return pytree_map(obj, lambda x: type(x), map_all=True)
 
 
+def pytree_filter(fn, obj, process_types):
+    """Filter the elements of `obj`, keeping those of the given types that satisfy `fn`.
+
+    Args:
+        fn (:class:`Callable`): a predicate; an element is kept only if `fn(element)` is True
+        obj (:class:`Any`): object to process
+        process_types (:class:`type | tuple[type]`): types of element to consider
+
+    Returns:
+        :class:`list`: a flat list of the matching elements (None if `obj` is None)
+    """
+    if obj is None:
+        return None
+
+    filters = []
+
+    def condition_append(obj):
+        if fn(obj):
+            filters.append(obj)
+
+    pytree_map(obj, fn=condition_append, process_types=process_types)
+    return filters
+
+
 def get_real_args_kwargs(args_or_kwargs):
     args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
     # TODO : combine producer and consumer
diff --git a/colossalai/utils/rank_recorder/README.md b/colossalai/utils/rank_recorder/README.md
new file mode 100644
index 000000000..e30a925d2
--- /dev/null
+++ b/colossalai/utils/rank_recorder/README.md
@@ -0,0 +1,88 @@
+# Rank Recorder
+This is a useful tool to record the execution time of certain procedures in each rank. The records of every rank are dumped into a JSON file when the multi-process program ends, and you can parse and visualise that file easily.
+
+Before using the tool, make sure `dist.is_initialized()` returns `True` before the program exits.
+
+## Usage
+
+It is very simple:
+
+```python
+from colossalai.utils.rank_recorder import recorder
+
+...
+...
+
+with recorder(record_name, current_rank) as r:
+    """procedure to record
+    """
+
+```
+
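+The exported JSON file maps each rank to the list of events it recorded; every event carries a `name`, a `start` time and an `end` time (in seconds, relative to a shared base time). Below is a minimal sketch of parsing the merged file; it assumes the default `export_name` (`'test'`), so the file is `test.json`:
+
+```python
+import json
+
+# load the merged record file written at program exit
+with open('test.json', 'r', encoding='utf-8') as f:
+    records = json.load(f)
+
+# records maps rank -> list of {'name', 'start', 'end'} events
+for rank, events in sorted(records.items()):
+    for event in events:
+        duration = event['end'] - event['start']
+        print(f"rank {rank}: {event['name']} took {duration:.4f}s")
+```
+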
+## Example
+This demo displays CUDA kernel selection and visualises the cost of several procedures on each rank.
+
+```python
+import time
+import os
+import logging
+logging.disable(logging.INFO)
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+from colossalai.utils.rank_recorder import recorder
+
+
+WORLD_SIZE = 4
+
+# config the export image here
+# If you want to dive into the detail, format 'svg' is recommended
+recorder.export_format = 'png'
+recorder.export_name = 'kernel_select'
+recorder.dpi = 500
+
+def calc(x, y):
+    a = torch.randn(x, y).cuda()
+    b = torch.randn(x, y).cuda()
+    c = sum(a * b)
+    return c
+
+def worker(rank):
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '29020'
+    dist.init_process_group(backend='nccl', world_size=WORLD_SIZE, rank=rank)
+    print(dist.get_rank(), "enter")
+    time.sleep(0.1 * rank)
+
+    with recorder("calc_1(x100)", rank) as r:
+        calc(100, 100)
+
+    with recorder("calc_2(x400)", rank) as r:
+        calc(400, 400)
+
+    with recorder("calc_3(x200)", rank) as r:
+        calc(200, 200)
+
+if __name__ == "__main__":
+    mp.spawn(worker, nprocs=WORLD_SIZE)
+```
+
+Run the script directly and you will get `kernel_select.json` and `kernel_select.png` in your current folder.
\ No newline at end of file
diff --git a/colossalai/utils/rank_recorder/__init__.py b/colossalai/utils/rank_recorder/__init__.py
new file mode 100644
index 000000000..1274d0e7d
--- /dev/null
+++ b/colossalai/utils/rank_recorder/__init__.py
@@ -0,0 +1,3 @@
+from colossalai.utils.rank_recorder.rank_recorder import recorder
+
+__all__ = ["recorder"]
\ No newline at end of file
diff --git a/colossalai/utils/rank_recorder/rank_recorder.py b/colossalai/utils/rank_recorder/rank_recorder.py
new file mode 100644
index 000000000..c088ceeb2
--- /dev/null
+++ b/colossalai/utils/rank_recorder/rank_recorder.py
@@ -0,0 +1,178 @@
+import time
+from typing import List, Dict
+import json
+import os
+import shutil
+import atexit
+
+import torch
+import torch.distributed as dist
+
+import matplotlib.pyplot as plt
+import matplotlib.colors as mcolors
+
+cmap = list(mcolors.TABLEAU_COLORS.values())
+
+LOG_FOLDER = "record.log"
+MAX_WAIT_TIME = 20
+
+
+class Event:
+
+    def __init__(self, start: int, end: int, name: str, rank: int) -> None:
+        self.start = start
+        self.end = end
+        self.name = name
+        self.rank = rank
+
+
+class Recorder:
+
+    def __init__(self) -> None:
+        self.rank_to_history: Dict[int, List[Event]] = {}
+        self.base_time = time.time()
+        self.temp_event = None
+
+        self.export_format = 'png'
+        self.export_name = 'test'
+        self.dpi = 500
+        self.theme = 'dark_background'
+        self.figure_width = 30
+        self.figure_height = 10
+        self.legend_fontsize = 16
+        self.device_fontsize = 20
+        self.bar_height = 0.2
+
+        if not os.path.exists(LOG_FOLDER):
+            os.makedirs(LOG_FOLDER)
+
+    def start(self, name: str, rank: int):
+        # TODO : add lock to prevent conflict
+        torch.cuda.synchronize()
+        start_time = time.time()
+        self.temp_event = Event(start_time, None, name, rank)
+
+    def end(self):
+        assert self.temp_event is not None, "`start` before `end`"
+        torch.cuda.synchronize()
+        end_time = time.time()
+        self.temp_event.end = end_time
+        rank = self.temp_event.rank
+        if rank not in self.rank_to_history:
+            self.rank_to_history[rank] = []
+        self.rank_to_history[rank].append(self.temp_event)
+        self.temp_event = None
+
+    def get_history(self):
+        return self.rank_to_history
+
+    def __call__(self, name: str, rank: int):
+        self.temp_name = name
+        self.temp_rank = rank
+        return self
+
+    def __enter__(self):
+        name = self.temp_name
+        rank = self.temp_rank
+        self.start(name, rank)
+        return self
+
+    def __exit__(self, *args):
+        self.end()
+
+    def dump_record(self):
+        rank = dist.get_rank()
+        rank_to_history = self.rank_to_history
+        records = {'base_time': self.base_time, 'content': {}}
+        for record_rank in rank_to_history:
+            history = rank_to_history[record_rank]
+            recs = []
+            for event in history:
+                rec = {'start': event.start, 'end': event.end, 'name': event.name}
+                recs.append(rec)
+            records['content'][record_rank] = recs
+
+        dump_name = f'{rank}.json'
+        dump_path = os.path.join(LOG_FOLDER, dump_name)
+        with open(dump_path, 'w', encoding='utf-8') as f:
+            json.dump(records, f, ensure_ascii=False)
+
+    def merge_record(self):
+        base_time = self.base_time
+        world_size = dist.get_world_size()
+
+        # wait until every rank has dumped its records; each retry sleeps 0.1s
+        wait_time = 0
+        while True:
+            time.sleep(0.1)
+            log_num = len(os.listdir(LOG_FOLDER))
+            if log_num == world_size:
+                break
+
+            wait_time += 1
+            if wait_time >= MAX_WAIT_TIME:
+                break
+
+        # merge
+        logs_path = [os.path.join(LOG_FOLDER, file) for file in os.listdir(LOG_FOLDER)]
+        recorders = {}
+        for path in logs_path:
+            with open(path, 'r', encoding='utf-8') as f:
+                recs = json.load(f)
+            for record_rank in recs['content']:
+                history = recs['content'][record_rank]
+                recorders[record_rank] = []
+                for rec in history:
+                    recorders[record_rank].append({
+                        'start': rec['start'] - base_time,
+                        'end': rec['end'] - base_time,
+                        'name': rec['name']
+                    })
+
+        shutil.rmtree(LOG_FOLDER)
+        with open(self.export_name + '.json', 'w', encoding='utf-8') as f:
+            json.dump(recorders, f, ensure_ascii=False)
+
+    def visualise_record(self):
+        with open(self.export_name + '.json', 'r', encoding='utf-8') as f:
+            records = json.load(f)
+        records = dict(records)
+        ranks = list(sorted(records.keys()))
+
+        name_list = {}
+        plots = {}
+        plt.figure(dpi=self.dpi, figsize=[self.figure_width, self.figure_height])
+        plt.style.use(self.theme)
+
+        for rank in ranks:
+            rank_records = records[rank]
+            for rec in rank_records:
+                s = rec['start']
+                e = rec['end']
+                name = rec['name']
+                if name not in name_list:
+                    name_list[name] = len(name_list)
+                bar = plt.barh(rank, width=e - s, height=self.bar_height, left=s, color=cmap[name_list[name]])
+                if name not in plots:
+                    plots[name] = bar
+
+        plt.legend(list(plots.values()), list(plots.keys()), loc="upper left", fontsize=self.legend_fontsize)
+        plt.yticks(ticks=ranks, labels=[f'Device:{rank}' for rank in ranks], fontsize=self.device_fontsize)
+        plt.grid(axis='x')
+        plt.savefig("{}.{}".format(self.export_name, self.export_format))
+
+    def exit_worker(self):
+        if len(self.rank_to_history) == 0:
+            return
+        self.dump_record()
+        # every rank dumps its own records; a single rank merges them afterwards
+        rank = dist.get_rank()
+
+        if rank == 1:
+            # rank 1 merges the records, taking its own base time as the reference
+            self.merge_record()
+            self.visualise_record()
+
+
+recorder = Recorder()
+atexit.register(recorder.exit_worker)