[example] add vit missing functions (#2154)

pull/2155/head^2
Jiarui Fang 2022-12-20 15:03:26 +08:00 committed by GitHub
parent a7d95b7024
commit 2cfe685b9f
2 changed files with 61 additions and 4 deletions


@@ -1,23 +1,55 @@
+import os
+import random
 from functools import partial
+import numpy as np
 import pytest
 import torch
 import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
-from utils.util import set_seed, tensor_equal, tensor_shard_equal
 from vit import get_training_components
 import colossalai
+from colossalai.context import ParallelMode
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.nn.parallel.data_parallel import ColoDDP
-from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
+from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
+
+
+def set_seed(seed):
+    random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
+
+
+def tensor_equal(A, B):
+    return torch.allclose(A, B, rtol=1e-3, atol=1e-1)
+
+
+def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
+    assert tensor.ndim == shard.ndim
+    if tensor.shape == shard.shape:
+        return tensor_equal(tensor, shard)
+    else:
+        dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
+        if dims_not_eq.numel() == 1:
+            # 1D shard
+            dim = dims_not_eq.item()
+            world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
+            rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+            return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
+        else:
+            raise


 # Only for all Linear, it's 1d_row split because Linear will be transposed when calculating.
 # But for other layers, it's 1d_col split.
 # Layernorm is not supported for now.
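The code these comments introduce is not shown in this hunk. As a plain-PyTorch illustration of why a Linear weight ends up "1d_row" split: nn.Linear stores its weight as (out_features, in_features) and computes x @ weight.T, so chunking the stored weight along its last dimension is a row split of the matrix actually used in the matmul; each rank then consumes a slice of the input features and the partial outputs are summed (an all-reduce in a real distributed run). A minimal single-process sketch with arbitrary shapes and world size:

import torch
import torch.nn.functional as F

world_size = 2
x = torch.randn(4, 6)          # (batch, in_features)
weight = torch.randn(8, 6)     # nn.Linear stores (out_features, in_features)

full = F.linear(x, weight)                      # reference: x @ weight.T
x_shards = x.chunk(world_size, dim=-1)          # each rank sees a slice of the input features
w_shards = weight.chunk(world_size, dim=-1)     # "1d_row" split of the transposed matmul
partial_sum = sum(F.linear(xs, ws) for xs, ws in zip(x_shards, w_shards))

assert torch.allclose(full, partial_sum, atol=1e-5)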

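tensor_equal uses loose tolerances (rtol=1e-3, atol=1e-1) because the sharded and the reference runs accumulate floating-point error differently, and tensor_shard_equal compares a full tensor against one rank's shard by chunking along the single dimension whose size differs. A minimal sketch of that chunk-and-compare check in plain torch, with hard-coded values standing in for the gpc.get_world_size / gpc.get_local_rank lookups:

import torch

world_size, rank = 2, 1        # stand-ins for the ParallelMode.PARALLEL_1D lookups
full = torch.randn(8, 4)       # unsharded parameter from the reference model
dim = 0                        # the one dimension whose size differs from the shard
shard = full.chunk(world_size, dim)[rank]   # this rank's 1D shard

# the comparison tensor_shard_equal performs for a 1D shard
assert torch.allclose(full.chunk(world_size, dim)[rank], shard, rtol=1e-3, atol=1e-1)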

@@ -1,9 +1,34 @@
+from abc import ABC, abstractmethod
+
 import torch
 import torch.nn as nn
-from utils.dummy_data_generator import DummyDataGenerator
+from transformers import ViTConfig, ViTForImageClassification
 from colossalai.utils.cuda import get_current_device
-from transformers import ViTConfig, ViTForImageClassification
+
+
+class DummyDataGenerator(ABC):
+
+    def __init__(self, length=10):
+        self.length = length
+
+    @abstractmethod
+    def generate(self):
+        pass
+
+    def __iter__(self):
+        self.step = 0
+        return self
+
+    def __next__(self):
+        if self.step < self.length:
+            self.step += 1
+            return self.generate()
+        else:
+            raise StopIteration
+
+    def __len__(self):
+        return self.length
+

 class DummyDataLoader(DummyDataGenerator):
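The hunk is truncated at the DummyDataLoader header, so its body is not shown here. To illustrate how the DummyDataGenerator ABC above is used, here is a hypothetical subclass (the class name, batch size, image shape, and label range are made up for illustration and are not the actual DummyDataLoader implementation); it assumes the DummyDataGenerator definition from the hunk above is in scope:

import torch

# Hypothetical subclass: only generate() has to be implemented,
# iteration and length come from the DummyDataGenerator base class.
class RandomViTBatch(DummyDataGenerator):

    def generate(self):
        pixel_values = torch.rand(4, 3, 224, 224)   # (batch, channels, height, width)
        labels = torch.randint(0, 10, (4,))         # one class index per image
        return pixel_values, labels

for pixel_values, labels in RandomViTBatch(length=3):
    print(pixel_values.shape, labels.shape)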