diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py
index 6f2fc104f..67b0bef50 100644
--- a/tests/test_booster/test_plugin/test_3d_plugin.py
+++ b/tests/test_booster/test_plugin/test_3d_plugin.py
@@ -8,13 +8,14 @@ from torch.testing import assert_close
 from torch.utils.data import Dataset
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.fx import is_compatible_with_meta
 from colossalai.lazy.lazy_init import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device, set_seed
+from colossalai.utils import set_seed
 from tests.kit.model_zoo import model_zoo
 
 
@@ -23,7 +24,9 @@ class RandomDataset(Dataset):
         self.num_samples = num_samples
         self.max_length = max_length
         set_seed(42)
-        self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
+        self.input_ids = torch.randint(
+            0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device()
+        )
         self.attention_mask = torch.ones_like(self.input_ids)
 
     def __len__(self):