mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix 3d plugin test (#5292)
parent
d66e6988bc
commit
d7f8db8e21
|
@ -8,13 +8,14 @@ from torch.testing import assert_close
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
|
from colossalai.accelerator import get_accelerator
|
||||||
from colossalai.booster import Booster
|
from colossalai.booster import Booster
|
||||||
from colossalai.booster.plugin import HybridParallelPlugin
|
from colossalai.booster.plugin import HybridParallelPlugin
|
||||||
from colossalai.fx import is_compatible_with_meta
|
from colossalai.fx import is_compatible_with_meta
|
||||||
from colossalai.lazy.lazy_init import LazyInitContext
|
from colossalai.lazy.lazy_init import LazyInitContext
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||||
from colossalai.utils import get_current_device, set_seed
|
from colossalai.utils import set_seed
|
||||||
from tests.kit.model_zoo import model_zoo
|
from tests.kit.model_zoo import model_zoo
|
||||||
|
|
||||||
|
|
||||||
|
@ -23,7 +24,9 @@ class RandomDataset(Dataset):
|
||||||
self.num_samples = num_samples
|
self.num_samples = num_samples
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
|
self.input_ids = torch.randint(
|
||||||
|
0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device()
|
||||||
|
)
|
||||||
self.attention_mask = torch.ones_like(self.input_ids)
|
self.attention_mask = torch.ones_like(self.input_ids)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
|
Loading…
Reference in New Issue