diff --git a/llama_stack/providers/impls/vllm/vllm.py b/llama_stack/providers/impls/vllm/vllm.py
index c977c738d..ad3ad8fb7 100644
--- a/llama_stack/providers/impls/vllm/vllm.py
+++ b/llama_stack/providers/impls/vllm/vllm.py
@@ -7,7 +7,7 @@
 import logging
 import os
 import uuid
-from typing import Any
+from typing import Any, AsyncGenerator
 
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import *  # noqa: F403
@@ -15,7 +15,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import SamplingParams as VLLMSamplingParams
 
 from llama_stack.apis.inference import *  # noqa: F403
 
@@ -40,10 +40,10 @@ def _random_uuid() -> str:
     return str(uuid.uuid4().hex)
 
 
-def _vllm_sampling_params(sampling_params: Any) -> SamplingParams:
+def _vllm_sampling_params(sampling_params: Any) -> VLLMSamplingParams:
     """Convert sampling params to vLLM sampling params."""
     if sampling_params is None:
-        return SamplingParams()
+        return VLLMSamplingParams()
 
     # TODO convert what I saw in my first test ... but surely there's more to do here
     kwargs = {
@@ -58,7 +58,7 @@ def _vllm_sampling_params(sampling_params: Any) -> SamplingParams:
     if sampling_params.repetition_penalty > 0:
         kwargs["repetition_penalty"] = sampling_params.repetition_penalty
 
-    return SamplingParams(**kwargs)
+    return VLLMSamplingParams(**kwargs)
 
 
 class VLLMInferenceImpl(ModelRegistryHelper, Inference):
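
Note on the rename: the star imports at the top of this module bring llama-stack's own SamplingParams into scope, so importing vLLM's class under the alias VLLMSamplingParams keeps the two same-named types from shadowing each other. A minimal usage sketch follows, not part of this change; the temperature field on llama-stack's SamplingParams is an assumption, not something this diff shows:

    # Hypothetical sketch showing the two types coexisting after this change.
    from llama_models.llama3.api.datatypes import SamplingParams  # llama-stack request type
    from vllm.sampling_params import SamplingParams as VLLMSamplingParams  # vLLM engine type

    request_params = SamplingParams(temperature=0.7)  # assumed field name, for illustration
    engine_params = _vllm_sampling_params(request_params)  # converts to vLLM's type
    assert isinstance(engine_params, VLLMSamplingParams)
    assert not isinstance(engine_params, SamplingParams)  # distinct classes, no shadowing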