mirror of https://github.com/hpcaitech/ColossalAI
[Hot Fix] CI, import, requirements-test (pull/5894/head^2)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

parent cba20525a8
commit 66abf1c6e8
@@ -57,11 +57,11 @@ class LLMEngine(BaseEngine):
     def __init__(
         self,
-        model_or_path: nn.Module | str,
-        tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast = None,
+        model_or_path: Union[nn.Module, str],
+        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
         inference_config: InferenceConfig = None,
         verbose: bool = False,
-        model_policy: Policy | type[Policy] = None,
+        model_policy: Union[Policy, type[Policy]] = None,
     ) -> None:
         self.inference_config = inference_config
         self.dtype = inference_config.dtype
 
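The annotation change above is a Python-version compatibility fix: PEP 604 unions written as X | Y in annotations are evaluated at definition time and need Python 3.10+, whereas typing.Union works on older interpreters. A minimal sketch (not from the patch, assuming a pre-3.10 interpreter without "from __future__ import annotations") of the difference:

# Minimal sketch, not from the patch: why Union[...] is the safer spelling
# when the module must import cleanly on Python < 3.10.
from typing import Union

import torch.nn as nn


def load_a(model_or_path: Union[nn.Module, str]) -> None:  # fine on Python 3.8+
    ...


# The PEP 604 form is evaluated eagerly at definition time and raises
# "TypeError: unsupported operand type(s) for |" on Python < 3.10,
# unless "from __future__ import annotations" defers evaluation:
#
# def load_b(model_or_path: nn.Module | str) -> None:
#     ...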
@@ -186,8 +186,6 @@ def get_model_type(model_or_path: Union[nn.Module, str, DiffusionPipeline]):
             """
 
             try:
-                from diffusers import DiffusionPipeline
-
                 DiffusionPipeline.load_config(model_or_path)
                 return ModelType.DIFFUSION_MODEL
             except:
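The hunk above removes the redundant in-function import of DiffusionPipeline (the name must already be available at module scope, since the new get_model_type signature annotates it) and keeps the try/except probe around DiffusionPipeline.load_config; together with the requirements-test change below, the diffusers dependency stays optional for CI. A hypothetical sketch of that probe pattern (detect_model_type and the ModelType members here are illustrative, not the project's exact code):

# Hypothetical sketch of a guarded model-type probe; simplified control flow.
from enum import Enum
from typing import Union

import torch.nn as nn


class ModelType(Enum):
    DIFFUSION_MODEL = "diffusion_model"
    LLM = "llm"
    UNKNOWN = "unknown"


def detect_model_type(model_or_path: Union[nn.Module, str]) -> ModelType:
    # An already-instantiated nn.Module is treated as a plain LLM.
    if isinstance(model_or_path, nn.Module):
        return ModelType.LLM
    try:
        # Lazy import keeps diffusers optional: if it is missing, or the
        # path is not a diffusers pipeline, fall through to UNKNOWN.
        from diffusers import DiffusionPipeline

        DiffusionPipeline.load_config(model_or_path)
        return ModelType.DIFFUSION_MODEL
    except Exception:
        return ModelType.UNKNOWN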
@@ -0,0 +1,2 @@
+#!/bin/bash
+echo "Skip the test (this test is slow)"
@@ -1,4 +1,3 @@
-diffusers
 pytest
 coverage==7.2.3
 git+https://github.com/hpcaitech/pytest-testmon