mirror of https://github.com/hpcaitech/ColossalAI
[example] add vit missing functions (#2154)
parent a7d95b7024
commit 2cfe685b9f
@@ -1,23 +1,55 @@
+import os
+import random
 from functools import partial
 
+import numpy as np
 import pytest
 import torch
 import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
-from utils.util import set_seed, tensor_equal, tensor_shard_equal
 from vit import get_training_components
 
 import colossalai
+from colossalai.context import ParallelMode
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.nn.parallel.data_parallel import ColoDDP
-from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
+from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 
 
+def set_seed(seed):
+    random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
+
+
+def tensor_equal(A, B):
+    return torch.allclose(A, B, rtol=1e-3, atol=1e-1)
+
+
+def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
+    assert tensor.ndim == shard.ndim
+    if tensor.shape == shard.shape:
+        return tensor_equal(tensor, shard)
+    else:
+        dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
+        if dims_not_eq.numel() == 1:
+            # 1D shard
+            dim = dims_not_eq.item()
+            world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
+            rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+            return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
+        else:
+            raise
+
+
 # Only for all Linear, it's 1d_row split because Linear will be transposed when calculating.
 # But for other layers, it's 1d_col split.
 # Layernorm is not supported for now.
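The tensor_shard_equal helper added above compares a full tensor against its 1D shard by chunking along the single dimension whose sizes differ and checking the local chunk for the current rank. The trailing comments describe the sharding convention the example relies on: Linear weights get a 1d_row split (the weight is transposed when calculating), other layers get a 1d_col split, and LayerNorm is not sharded. The sketch below shows how such a 1d_row spec is typically applied to Linear weights with the ColossalAI tensor APIs imported in this file (ShardSpec, ComputeSpec, DistSpecManager, ProcessGroup); it is a minimal sketch, not part of this commit, and the helper name, the shard dimension, and the 'weight'/'ln' filtering rule are illustrative assumptions.

# Illustrative sketch (not part of this diff): apply a 1d_row ShardSpec to Linear
# weights, following the convention described in the comments above. Assumes the
# ColossalAI tensor APIs imported in the test file; the function name
# init_1d_row_for_linear_weight_spec and the name-based filter are hypothetical.
def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
    # shard the last dim of the (out_features, in_features) weight across the
    # tensor-parallel group; dim -1 is assumed for the 1d_row convention here
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for name, param in model.named_parameters():
            # only Linear weights are sharded; LayerNorm is not supported for now
            if 'weight' in name and 'ln' not in name:
                param.set_process_group(pg)
                param.set_tensor_spec(*spec)

With a spec like this in place, tensor_shard_equal can compare each parameter of an unsharded reference model against the local shard held by every tensor-parallel rank.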
@@ -1,9 +1,34 @@
+from abc import ABC, abstractmethod
+
 import torch
 import torch.nn as nn
-from utils.dummy_data_generator import DummyDataGenerator
+from transformers import ViTConfig, ViTForImageClassification
 
 from colossalai.utils.cuda import get_current_device
-from transformers import ViTConfig, ViTForImageClassification
 
 
+class DummyDataGenerator(ABC):
+
+    def __init__(self, length=10):
+        self.length = length
+
+    @abstractmethod
+    def generate(self):
+        pass
+
+    def __iter__(self):
+        self.step = 0
+        return self
+
+    def __next__(self):
+        if self.step < self.length:
+            self.step += 1
+            return self.generate()
+        else:
+            raise StopIteration
+
+    def __len__(self):
+        return self.length
+
+
 class DummyDataLoader(DummyDataGenerator):
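DummyDataGenerator wraps an abstract generate() method in the Python iterator protocol, so a subclass only has to produce one batch and gets a finite, length-aware loader for free. The DummyDataLoader defined right after it (truncated in this diff) fills in generate() for the ViT example; the sketch below shows what such a subclass looks like, with the batch size, image resolution, and number of classes as illustrative assumptions rather than the values used in the example repo.

# Illustrative sketch (not part of this diff): a concrete generate() for a
# ViT-style image classifier. Batch size 4, 224x224 RGB inputs and 10 classes
# are assumed values.
class RandomImageDataLoader(DummyDataGenerator):

    def generate(self):
        pixel_values = torch.rand(4, 3, 224, 224, device=get_current_device())
        labels = torch.randint(0, 10, (4,), device=get_current_device())
        return pixel_values, labels

Iterating over RandomImageDataLoader(length=10) yields ten such batches and then raises StopIteration, which is the behaviour the training loop in the test consumes.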