From 2cfe685b9ff4a30e23a28ac0ad04150ca3082e52 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Tue, 20 Dec 2022 15:03:26 +0800
Subject: [PATCH] [exmaple] add vit missing functions (#2154)

---
 examples/images/vit/test_vit.py | 36 +++++++++++++++++++++++++++++++--
 examples/images/vit/vit.py      | 29 ++++++++++++++++++++++++--
 2 files changed, 61 insertions(+), 4 deletions(-)

diff --git a/examples/images/vit/test_vit.py b/examples/images/vit/test_vit.py
index 7dbbe607e..90f2475b8 100644
--- a/examples/images/vit/test_vit.py
+++ b/examples/images/vit/test_vit.py
@@ -1,23 +1,55 @@
+import os
+import random
 from functools import partial
 
+import numpy as np
 import pytest
 import torch
 import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
-from utils.util import set_seed, tensor_equal, tensor_shard_equal
 from vit import get_training_components
 
 import colossalai
+from colossalai.context import ParallelMode
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.nn.parallel.data_parallel import ColoDDP
-from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
+from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 
 
+def set_seed(seed):
+    random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
+
+
+def tensor_equal(A, B):
+    return torch.allclose(A, B, rtol=1e-3, atol=1e-1)
+
+
+def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
+    assert tensor.ndim == shard.ndim
+    if tensor.shape == shard.shape:
+        return tensor_equal(tensor, shard)
+    else:
+        dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
+        if dims_not_eq.numel() == 1:
+            # 1D shard
+            dim = dims_not_eq.item()
+            world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
+            rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+            return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
+        else:
+            raise
+
+
 # Only for all Linear, it's 1d_row split because Linear will be transposed when calculating.
 # But for other layers, it's 1d_col split.
 # Layernorm is not supported for now.
diff --git a/examples/images/vit/vit.py b/examples/images/vit/vit.py
index 1116c7416..14c870b39 100644
--- a/examples/images/vit/vit.py
+++ b/examples/images/vit/vit.py
@@ -1,9 +1,34 @@
+from abc import ABC, abstractmethod
+
 import torch
 import torch.nn as nn
-from utils.dummy_data_generator import DummyDataGenerator
+from transformers import ViTConfig, ViTForImageClassification
 
 from colossalai.utils.cuda import get_current_device
-from transformers import ViTConfig, ViTForImageClassification
+
+
+class DummyDataGenerator(ABC):
+
+    def __init__(self, length=10):
+        self.length = length
+
+    @abstractmethod
+    def generate(self):
+        pass
+
+    def __iter__(self):
+        self.step = 0
+        return self
+
+    def __next__(self):
+        if self.step < self.length:
+            self.step += 1
+            return self.generate()
+        else:
+            raise StopIteration
+
+    def __len__(self):
+        return self.length
 
 
 class DummyDataLoader(DummyDataGenerator):
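
Note: below is a minimal, self-contained sketch of the comparison rule implemented by the tensor_shard_equal helper added above. It hard-codes world_size and rank instead of reading them from gpc under ParallelMode.PARALLEL_1D (which requires an initialized ColossalAI context), and the full/shard tensors are illustrative rather than taken from the test.

import torch

# Pretend we are rank 1 in a 2-way 1D tensor-parallel group.
world_size, rank = 2, 1
full = torch.randn(4, 8)                     # the unsharded reference parameter
shard = full.chunk(world_size, dim=1)[rank]  # this rank's 1D column shard

# tensor_shard_equal locates the single mismatched dim, then compares only
# this rank's chunk of the full tensor against the shard, with loose tolerances.
dim = torch.nonzero(torch.tensor(full.shape) != torch.tensor(shard.shape)).item()
assert torch.allclose(full.chunk(world_size, dim)[rank], shard, rtol=1e-3, atol=1e-1)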
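
Note: the DummyDataGenerator base class added to vit.py drives iteration through __iter__/__next__ and delegates batch creation to generate(). The snippet below is a hypothetical subclass for illustration only; the real DummyDataLoader (whose body is cut off above) yields ViT-shaped inputs, and the ToyDataLoader name, batch shapes, and import path here are assumptions.

import torch

from vit import DummyDataGenerator  # assumes the class added above is importable


class ToyDataLoader(DummyDataGenerator):

    def generate(self):
        data = torch.rand(4, 3, 224, 224)    # fake image batch
        label = torch.randint(0, 10, (4,))   # fake class labels
        return data, label


# Iteration stops after `length` batches because __next__ raises StopIteration.
for step, (data, label) in enumerate(ToyDataLoader(length=3)):
    print(step, data.shape, label.shape)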