mirror of https://github.com/hpcaitech/ColossalAI
[exmaple] add vit missing functions (#2154)
parent a7d95b7024
commit 2cfe685b9f
@@ -1,23 +1,55 @@
import os
import random
from functools import partial

import numpy as np
import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from utils.util import set_seed, tensor_equal, tensor_shard_equal
from vit import get_training_components

import colossalai
from colossalai.context import ParallelMode
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

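A minimal sketch (not part of this diff) of how the pieces imported above are typically wired together in ColossalAI multi-process tests; the worker body is a placeholder and the test name is illustrative, not the commit's actual test logic:

    def run_dist(rank, world_size, port):
        # each spawned process joins the distributed group before running the test body
        colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
        # ... build the ViT components under ColoInitContext, wrap with ColoDDP, train a few steps ...

    @pytest.mark.parametrize('world_size', [1, 4])
    @rerun_if_address_is_in_use()
    def test_vit(world_size):
        # mp.spawn passes the process index as the first positional argument (the rank)
        run_func = partial(run_dist, world_size=world_size, port=free_port())
        mp.spawn(run_func, nprocs=world_size)
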
def set_seed(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


def tensor_equal(A, B):
    return torch.allclose(A, B, rtol=1e-3, atol=1e-1)


def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
    assert tensor.ndim == shard.ndim
    if tensor.shape == shard.shape:
        return tensor_equal(tensor, shard)
    else:
        dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
        if dims_not_eq.numel() == 1:
            # 1D shard: only one dimension differs, so compare this rank's chunk of the full tensor
            dim = dims_not_eq.item()
            world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
            rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
            return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
        else:
            raise NotImplementedError("only 1-D sharding is supported")
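For illustration only, a single-process sketch of what tensor_shard_equal verifies; the world size, rank and shard dimension are hard-coded stand-ins for the values gpc provides under PARALLEL_1D:

    full = torch.randn(8, 4)                     # unsharded reference weight
    world_size, rank, dim = 2, 1, 0              # hypothetical 1-D tensor-parallel layout
    shard = full.chunk(world_size, dim)[rank]    # the slice this rank would hold locally
    # the helper detects that only dim 0 differs and compares just this rank's chunk
    assert tensor_equal(full.chunk(world_size, dim)[rank], shard)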

# Only Linear layers use a 1d_row split, because nn.Linear transposes its weight when computing.
# All other layers use a 1d_col split.
# LayerNorm is not supported for now.
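A rough sketch (not part of this commit) of how the splitting rule described in the comment above is typically applied with the tensor-spec API imported in this file; the helper name, the parameter-name filter and the choice of shard dimension are assumptions for illustration, not code from the example:

    def init_spec_func(model, world_size):
        pg = ProcessGroup(tp_degree=world_size)
        # 1d_row split for Linear weights: shard the last dim, computed with the TP1D pattern
        spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
        with DistSpecManager.no_grad():
            for name, param in model.named_parameters():
                # skip LayerNorm and bias parameters, which stay replicated
                if 'weight' in name and 'layernorm' not in name.lower():
                    param.set_process_group(pg)
                    param.set_tensor_spec(*spec)
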
@@ -1,9 +1,34 @@
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
from utils.dummy_data_generator import DummyDataGenerator
from transformers import ViTConfig, ViTForImageClassification

from colossalai.utils.cuda import get_current_device
from transformers import ViTConfig, ViTForImageClassification


class DummyDataGenerator(ABC):

    def __init__(self, length=10):
        self.length = length

    @abstractmethod
    def generate(self):
        pass

    def __iter__(self):
        self.step = 0
        return self

    def __next__(self):
        if self.step < self.length:
            self.step += 1
            return self.generate()
        else:
            raise StopIteration

    def __len__(self):
        return self.length


class DummyDataLoader(DummyDataGenerator):
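As a hypothetical illustration (not the example's actual implementation), a concrete subclass of DummyDataGenerator producing inputs ViTForImageClassification accepts (pixel_values, labels) could look like the sketch below; the batch size, channel count, image size and number of classes are arbitrary stand-ins:

    class DummyViTDataLoader(DummyDataGenerator):
        batch_size, channels, image_size, num_classes = 4, 3, 224, 8

        def generate(self):
            # random images in [0, 1) and random class labels, placed on the current CUDA device
            pixel_values = torch.rand(self.batch_size, self.channels, self.image_size, self.image_size,
                                      device=get_current_device())
            labels = torch.randint(self.num_classes, (self.batch_size,), device=get_current_device())
            return dict(pixel_values=pixel_values, labels=labels)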