diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index ac9a46e85..4d7e66d78 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -7,7 +7,7 @@ import json
 import logging
 from typing import AsyncGenerator, List, Optional, Union
 
-from openai import OpenAI
+from openai import AsyncOpenAI
 from openai.types.chat.chat_completion_chunk import (
     ChatCompletionChunk as OpenAIChatCompletionChunk,
 )
@@ -229,7 +229,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
 
     async def initialize(self) -> None:
         log.info(f"Initializing VLLM client with base_url={self.config.url}")
-        self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
+        self.client = AsyncOpenAI(base_url=self.config.url, api_key=self.config.api_token)
 
     async def shutdown(self) -> None:
         pass
@@ -300,10 +300,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             return await self._nonstream_chat_completion(request, self.client)
 
     async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest, client: OpenAI
+        self, request: ChatCompletionRequest, client: AsyncOpenAI
     ) -> ChatCompletionResponse:
         params = await self._get_params(request)
-        r = client.chat.completions.create(**params)
+        r = await client.chat.completions.create(**params)
         choice = r.choices[0]
         result = ChatCompletionResponse(
             completion_message=CompletionMessage(
@@ -315,17 +315,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         )
         return result
 
-    async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
+    async def _stream_chat_completion(self, request: ChatCompletionRequest, client: AsyncOpenAI) -> AsyncGenerator:
         params = await self._get_params(request)
 
-        # TODO: Can we use client.completions.acreate() or maybe there is another way to directly create an async
-        # generator so this wrapper is not necessary?
-        async def _to_async_generator():
-            s = client.chat.completions.create(**params)
-            for chunk in s:
-                yield chunk
-
-        stream = _to_async_generator()
+        stream = await client.chat.completions.create(**params)
         if len(request.tools) > 0:
             res = _process_vllm_chat_completion_stream_response(stream)
         else:
@@ -335,26 +328,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
 
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
-        r = self.client.completions.create(**params)
+        r = await self.client.completions.create(**params)
         return process_completion_response(r)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
-        # Wrapper for async generator similar
-        async def _to_async_generator():
-            stream = self.client.completions.create(**params)
-            for chunk in stream:
-                yield chunk
-
-        stream = _to_async_generator()
+        stream = await self.client.completions.create(**params)
         async for chunk in process_completion_stream_response(stream):
             yield chunk
 
     async def register_model(self, model: Model) -> Model:
         model = await self.register_helper.register_model(model)
-        res = self.client.models.list()
-        available_models = [m.id for m in res]
+        res = await self.client.models.list()
+        available_models = [m.id async for m in res]
         if model.provider_resource_id not in available_models:
             raise ValueError(
                 f"Model {model.provider_resource_id} is not being served by vLLM. "
@@ -410,7 +397,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         assert model.metadata.get("embedding_dimension")
         kwargs["dimensions"] = model.metadata.get("embedding_dimension")
         assert all(not content_has_media(content) for content in contents), "VLLM does not support media for embeddings"
-        response = self.client.embeddings.create(
+        response = await self.client.embeddings.create(
             model=model.provider_resource_id,
             input=[interleaved_content_as_str(content) for content in contents],
             **kwargs,
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index 11b1ba123..3afe1389e 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -4,6 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import asyncio
+import json
+import logging
+import threading
+import time
+from http.server import BaseHTTPRequestHandler, HTTPServer
+from typing import Any, Dict
 from unittest.mock import AsyncMock, patch
 
 import pytest
@@ -39,9 +46,41 @@ from llama_stack.providers.remote.inference.vllm.vllm import (
 # -v -s --tb=short --disable-warnings
 
 
+class MockInferenceAdapterWithSleep:
+    def __init__(self, sleep_time: int, response: Dict[str, Any]):
+        self.httpd = None
+
+        class DelayedRequestHandler(BaseHTTPRequestHandler):
+            # ruff: noqa: N802
+            def do_POST(self):
+                time.sleep(sleep_time)
+                self.send_response(code=200)
+                self.end_headers()
+                self.wfile.write(json.dumps(response).encode("utf-8"))
+
+        self.request_handler = DelayedRequestHandler
+
+    def __enter__(self):
+        httpd = HTTPServer(("", 0), self.request_handler)
+        self.httpd = httpd
+        host, port = httpd.server_address
+        httpd_thread = threading.Thread(target=httpd.serve_forever)
+        httpd_thread.daemon = True  # stop server if this thread terminates
+        httpd_thread.start()
+
+        config = VLLMInferenceAdapterConfig(url=f"http://{host}:{port}")
+        inference_adapter = VLLMInferenceAdapter(config)
+        return inference_adapter
+
+    def __exit__(self, _exc_type, _exc_value, _traceback):
+        if self.httpd:
+            self.httpd.shutdown()
+            self.httpd.server_close()
+
+
 @pytest.fixture(scope="module")
 def mock_openai_models_list():
-    with patch("openai.resources.models.Models.list") as mock_list:
+    with patch("openai.resources.models.AsyncModels.list", new_callable=AsyncMock) as mock_list:
         yield mock_list
 
 
@@ -56,10 +95,10 @@ async def vllm_inference_adapter():
 
 @pytest.mark.asyncio
 async def test_register_model_checks_vllm(mock_openai_models_list, vllm_inference_adapter):
-    mock_openai_models = [
-        OpenAIModel(id="foo", created=1, object="model", owned_by="test"),
-    ]
-    mock_openai_models_list.return_value = mock_openai_models
+    async def mock_openai_models():
+        yield OpenAIModel(id="foo", created=1, object="model", owned_by="test")
+
+    mock_openai_models_list.return_value = mock_openai_models()
 
     foo_model = Model(identifier="foo", provider_resource_id="foo", provider_id="vllm-inference")
 
@@ -141,3 +180,55 @@ async def test_process_vllm_chat_completion_stream_response_no_choices():
 
     chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
     assert len(chunks) == 0
+
+
+def test_chat_completion_doesnt_block_event_loop(caplog):
+    loop = asyncio.new_event_loop()
+    loop.set_debug(True)
+    caplog.set_level(logging.WARNING)
+
+    # Log when event loop is blocked for more than 100ms
+    loop.slow_callback_duration = 0.1
+    # Sleep for 500ms in our delayed http response
+    sleep_time = 0.5
+
+    mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
+    mock_response = {
+        "id": "chatcmpl-abc123",
+        "object": "chat.completion",
+        "created": 1,
+        "model": "mock-model",
+        "choices": [
+            {
+                "message": {"content": ""},
+                "logprobs": None,
+                "finish_reason": "stop",
+                "index": 0,
+            }
+        ],
+    }
+
+    async def do_chat_completion():
+        await inference_adapter.chat_completion(
+            "mock-model",
+            [],
+            stream=False,
+            tools=None,
+            tool_config=ToolConfig(tool_choice=ToolChoice.auto),
+        )
+
+    with MockInferenceAdapterWithSleep(sleep_time, mock_response) as inference_adapter:
+        inference_adapter.model_store = AsyncMock()
+        inference_adapter.model_store.get_model.return_value = mock_model
+        loop.run_until_complete(inference_adapter.initialize())
+
+        # Clear the logs so far and run the actual chat completion we care about
+        caplog.clear()
+        loop.run_until_complete(do_chat_completion())
+
+        # Ensure we don't have any asyncio warnings in the captured log
+        # records from our chat completion call. A message gets logged
+        # here any time we exceed the slow_callback_duration configured
+        # above.
+        asyncio_warnings = [record.message for record in caplog.records if record.name == "asyncio"]
+        assert not asyncio_warnings