diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 9e9a80ca5..77f5d82af 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import json
-from collections.abc import AsyncGenerator, AsyncIterator
+from collections.abc import AsyncGenerator
 from typing import Any

 import httpx
@@ -38,13 +38,6 @@ from llama_stack.apis.inference import (
     LogProbConfig,
     Message,
     ModelStore,
-    OpenAIChatCompletion,
-    OpenAICompletion,
-    OpenAIEmbeddingData,
-    OpenAIEmbeddingsResponse,
-    OpenAIEmbeddingUsage,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
@@ -71,11 +64,11 @@ from llama_stack.providers.utils.inference.openai_compat import (
     convert_message_to_openai_dict,
     convert_tool_call,
     get_sampling_options,
-    prepare_openai_completion_params,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack.providers.utils.inference.prompt_adapter import (
     completion_request_to_prompt,
     content_has_media,
@@ -288,7 +281,7 @@ async def _process_vllm_chat_completion_stream_response(
         yield c


-class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
+class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
     # automatically set by the resolver when instantiating the provider
     __provider_id__: str
     model_store: ModelStore | None = None
@@ -296,7 +289,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
         self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
         self.config = config
-        self.client = None

     async def initialize(self) -> None:
         if not self.config.url:
@@ -308,8 +300,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         return self.config.refresh_models

     async def list_models(self) -> list[Model] | None:
-        self._lazy_initialize_client()
-        assert self.client is not None  # mypy
         models = []
         async for m in self.client.models.list():
             model_type = ModelType.llm  # unclear how to determine embedding vs. llm models
@@ -340,8 +330,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             HealthResponse: A dictionary containing the health status.
""" try: - client = self._create_client() if self.client is None else self.client - _ = [m async for m in client.models.list()] # Ensure the client is initialized + _ = [m async for m in self.client.models.list()] # Ensure the client is initialized return HealthResponse(status=HealthStatus.OK) except Exception as e: return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}") @@ -351,19 +340,14 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): raise ValueError("Model store not set") return await self.model_store.get_model(model_id) - def _lazy_initialize_client(self): - if self.client is not None: - return + def get_api_key(self): + return self.config.api_token - log.info(f"Initializing vLLM client with base_url={self.config.url}") - self.client = self._create_client() + def get_base_url(self): + return self.config.url - def _create_client(self): - return AsyncOpenAI( - base_url=self.config.url, - api_key=self.config.api_token, - http_client=httpx.AsyncClient(verify=self.config.tls_verify), - ) + def get_extra_client_params(self): + return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)} async def completion( self, @@ -374,7 +358,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): stream: bool | None = False, logprobs: LogProbConfig | None = None, ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]: - self._lazy_initialize_client() if sampling_params is None: sampling_params = SamplingParams() model = await self._get_model(model_id) @@ -406,7 +389,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): logprobs: LogProbConfig | None = None, tool_config: ToolConfig | None = None, ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]: - self._lazy_initialize_client() if sampling_params is None: sampling_params = SamplingParams() model = await self._get_model(model_id) @@ -479,16 +461,12 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): yield chunk async def register_model(self, model: Model) -> Model: - # register_model is called during Llama Stack initialization, hence we cannot init self.client if not initialized yet. - # self.client should only be created after the initialization is complete to avoid asyncio cross-context errors. - # Changing this may lead to unpredictable behavior. - client = self._create_client() if self.client is None else self.client try: model = await self.register_helper.register_model(model) except ValueError: pass # Ignore statically unknown model, will check live listing try: - res = await client.models.list() + res = await self.client.models.list() except APIConnectionError as e: raise ValueError( f"Failed to connect to vLLM at {self.config.url}. Please check if vLLM is running and accessible at that URL." 
@@ -543,8 +521,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         output_dimension: int | None = None,
         task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
-        self._lazy_initialize_client()
-        assert self.client is not None
         model = await self._get_model(model_id)

         kwargs = {}
@@ -560,154 +536,3 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         embeddings = [data.embedding for data in response.data]

         return EmbeddingsResponse(embeddings=embeddings)
-
-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        self._lazy_initialize_client()
-        assert self.client is not None
-        model_obj = await self._get_model(model)
-        assert model_obj.model_type == ModelType.embedding
-
-        # Convert input to list if it's a string
-        input_list = [input] if isinstance(input, str) else input
-
-        # Call vLLM embeddings endpoint with encoding_format
-        response = await self.client.embeddings.create(
-            model=model_obj.provider_resource_id,
-            input=input_list,
-            dimensions=dimensions,
-            encoding_format=encoding_format,
-        )
-
-        # Convert response to OpenAI format
-        data = [
-            OpenAIEmbeddingData(
-                embedding=embedding_data.embedding,
-                index=i,
-            )
-            for i, embedding_data in enumerate(response.data)
-        ]
-
-        # Not returning actual token usage since vLLM doesn't provide it
-        usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
-
-        return OpenAIEmbeddingsResponse(
-            data=data,
-            model=model_obj.provider_resource_id,
-            usage=usage,
-        )
-
-    async def openai_completion(
-        self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
-    ) -> OpenAICompletion:
-        self._lazy_initialize_client()
-        model_obj = await self._get_model(model)
-
-        extra_body: dict[str, Any] = {}
-        if prompt_logprobs is not None and prompt_logprobs >= 0:
-            extra_body["prompt_logprobs"] = prompt_logprobs
-        if guided_choice:
-            extra_body["guided_choice"] = guided_choice
-
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-            extra_body=extra_body,
-        )
-        return await self.client.completions.create(**params)  # type: ignore
-
-    async def openai_chat_completion(
-        self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        self._lazy_initialize_client()
-        model_obj = await self._get_model(model)
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
-        return await self.client.chat.completions.create(**params)  # type: ignore
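With the adapter's openai_embeddings/openai_completion/openai_chat_completion overrides deleted above, those entry points are now inherited from OpenAIMixin (changes below). A hypothetical call sketch, with an assumed adapter instance and an illustrative model id; the vLLM-only parameters reach the server via extra_body:

```python
async def demo(adapter) -> None:
    # `adapter` is assumed to be an initialized VLLMInferenceAdapter whose
    # model has been registered; the model id is illustrative.
    response = await adapter.openai_completion(
        model="meta-llama/Llama-3.1-8B-Instruct",
        prompt="The capital of France is",
        guided_choice=["Paris", "London"],  # vLLM extension, forwarded via extra_body
        prompt_logprobs=0,                  # vLLM extension, forwarded via extra_body
    )
    print(response.choices[0].text)
```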
diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py
index f60deee6e..a3c0ffadc 100644
--- a/llama_stack/providers/utils/inference/openai_mixin.py
+++ b/llama_stack/providers/utils/inference/openai_mixin.py
@@ -67,6 +67,17 @@ class OpenAIMixin(ABC):
         """
         pass

+    def get_extra_client_params(self) -> dict[str, Any]:
+        """
+        Get any extra parameters to pass to the AsyncOpenAI client.
+
+        Child classes can override this method to provide additional parameters
+        such as timeout settings, proxies, etc.
+
+        :return: A dictionary of extra parameters
+        """
+        return {}
+
     @property
     def client(self) -> AsyncOpenAI:
         """
@@ -78,6 +89,7 @@ class OpenAIMixin(ABC):
         return AsyncOpenAI(
             api_key=self.get_api_key(),
             base_url=self.get_base_url(),
+            **self.get_extra_client_params(),
         )

     async def _get_provider_model_id(self, model: str) -> str:
@@ -124,10 +136,15 @@ class OpenAIMixin(ABC):
         """
         Direct OpenAI completion API call.
         """
-        if guided_choice is not None:
-            logger.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
-        if prompt_logprobs is not None:
-            logger.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
+        # Handle parameters that are not supported by OpenAI API, but may be by the provider
+        # prompt_logprobs is supported by vLLM
+        # guided_choice is supported by vLLM
+        # TODO: test coverage
+        extra_body: dict[str, Any] = {}
+        if prompt_logprobs is not None and prompt_logprobs >= 0:
+            extra_body["prompt_logprobs"] = prompt_logprobs
+        if guided_choice:
+            extra_body["guided_choice"] = guided_choice

         # TODO: fix openai_completion to return type compatible with OpenAI's API response
         return await self.client.completions.create(  # type: ignore[no-any-return]
@@ -150,7 +167,8 @@ class OpenAIMixin(ABC):
                 top_p=top_p,
                 user=user,
                 suffix=suffix,
-            )
+            ),
+            extra_body=extra_body,
         )

     async def openai_chat_completion(
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index ce0e930b1..a48af2a1d 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -11,7 +11,7 @@ import threading
 import time
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from typing import Any
-from unittest.mock import AsyncMock, MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch

 import pytest
 from openai.types.chat.chat_completion_chunk import (
@@ -150,10 +150,12 @@ async def test_tool_call_response(vllm_inference_adapter):
     """Verify that tool call arguments from a CompletionMessage are correctly converted
     into the expected JSON format."""

-    # Patch the call to vllm so we can inspect the arguments sent were correct
-    with patch.object(
-        vllm_inference_adapter.client.chat.completions, "create", new_callable=AsyncMock
-    ) as mock_nonstream_completion:
+    # Patch the client property to avoid instantiating a real AsyncOpenAI client
+    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock()
+        mock_create_client.return_value = mock_client
+
         messages = [
             SystemMessage(content="You are a helpful assistant"),
             UserMessage(content="How many?"),
@@ -179,7 +181,7 @@ async def test_tool_call_response(vllm_inference_adapter):
             tool_config=ToolConfig(tool_choice=ToolChoice.auto),
         )

-        assert mock_nonstream_completion.call_args.kwargs["messages"][2]["tool_calls"] == [
+        assert mock_client.chat.completions.create.call_args.kwargs["messages"][2]["tool_calls"] == [
             {
                 "id": "foo",
                 "type": "function",
@@ -641,9 +643,7 @@ async def test_health_status_success(vllm_inference_adapter):
     This test verifies that the health method returns a HealthResponse with status OK, only
     when the connection to the vLLM server is successful.
     """
-    # Set vllm_inference_adapter.client to None to ensure _create_client is called
-    vllm_inference_adapter.client = None
-    with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client:
+    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
         # Create mock client and models
         mock_client = MagicMock()
         mock_models = MagicMock()
@@ -674,8 +674,7 @@ async def test_health_status_failure(vllm_inference_adapter):
     This test verifies that the health method returns a HealthResponse with status ERROR
     and an appropriate error message when the connection to the vLLM server fails.
""" - vllm_inference_adapter.client = None - with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client: + with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client: # Create mock client and models mock_client = MagicMock() mock_models = MagicMock()