From fcda5e976c50eec8dc37e649f1868a913f79445a Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Thu, 11 Sep 2025 10:32:04 -0400
Subject: [PATCH] chore(unit tests): remove network use, update async test

---
 .../providers/inference/test_remote_vllm.py | 160 +++++++-----------
 1 file changed, 60 insertions(+), 100 deletions(-)

diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index a48af2a1d..61b16b5d1 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -6,11 +6,7 @@
 
 import asyncio
 import json
-import logging  # allow-direct-logging
-import threading
 import time
-from http.server import BaseHTTPRequestHandler, HTTPServer
-from typing import Any
 from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
 
 import pytest
@@ -18,7 +14,7 @@ from openai.types.chat.chat_completion_chunk import (
     ChatCompletionChunk as OpenAIChatCompletionChunk,
 )
 from openai.types.chat.chat_completion_chunk import (
-    Choice as OpenAIChoice,
+    Choice as OpenAIChoiceChunk,
 )
 from openai.types.chat.chat_completion_chunk import (
     ChoiceDelta as OpenAIChoiceDelta,
 )
@@ -35,6 +31,9 @@ from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponseEventType,
     CompletionMessage,
+    OpenAIAssistantMessageParam,
+    OpenAIChatCompletion,
+    OpenAIChoice,
     SystemMessage,
     ToolChoice,
     ToolConfig,
@@ -61,41 +60,6 @@ from llama_stack.providers.remote.inference.vllm.vllm import (
 # -v -s --tb=short --disable-warnings
 
 
-class MockInferenceAdapterWithSleep:
-    def __init__(self, sleep_time: int, response: dict[str, Any]):
-        self.httpd = None
-
-        class DelayedRequestHandler(BaseHTTPRequestHandler):
-            # ruff: noqa: N802
-            def do_POST(self):
-                time.sleep(sleep_time)
-                response_body = json.dumps(response).encode("utf-8")
-                self.send_response(code=200)
-                self.send_header("Content-Type", "application/json")
-                self.send_header("Content-Length", len(response_body))
-                self.end_headers()
-                self.wfile.write(response_body)
-
-        self.request_handler = DelayedRequestHandler
-
-    def __enter__(self):
-        httpd = HTTPServer(("", 0), self.request_handler)
-        self.httpd = httpd
-        host, port = httpd.server_address
-        httpd_thread = threading.Thread(target=httpd.serve_forever)
-        httpd_thread.daemon = True  # stop server if this thread terminates
-        httpd_thread.start()
-
-        config = VLLMInferenceAdapterConfig(url=f"http://{host}:{port}")
-        inference_adapter = VLLMInferenceAdapter(config)
-        return inference_adapter
-
-    def __exit__(self, _exc_type, _exc_value, _traceback):
-        if self.httpd:
-            self.httpd.shutdown()
-            self.httpd.server_close()
-
-
 @pytest.fixture(scope="module")
 def mock_openai_models_list():
     with patch("openai.resources.models.AsyncModels.list", new_callable=AsyncMock) as mock_list:
@@ -201,7 +165,7 @@ async def test_tool_call_delta_empty_tool_call_buf():
 
     async def mock_stream():
         delta = OpenAIChoiceDelta(content="", tool_calls=None)
-        choices = [OpenAIChoice(delta=delta, finish_reason="stop", index=0)]
+        choices = [OpenAIChoiceChunk(delta=delta, finish_reason="stop", index=0)]
         mock_chunk = OpenAIChatCompletionChunk(
             id="chunk-1",
             created=1,
@@ -227,7 +191,7 @@ async def test_tool_call_delta_streaming_arguments_dict():
             model="foo",
             object="chat.completion.chunk",
             choices=[
-                OpenAIChoice(
+                OpenAIChoiceChunk(
                     delta=OpenAIChoiceDelta(
                         content="",
                         tool_calls=[
@@ -252,7 +216,7 @@ async def test_tool_call_delta_streaming_arguments_dict():
             model="foo",
             object="chat.completion.chunk",
             choices=[
-                OpenAIChoice(
+                OpenAIChoiceChunk(
                     delta=OpenAIChoiceDelta(
                         content="",
                         tool_calls=[
@@ -277,7 +241,9 @@ async def test_tool_call_delta_streaming_arguments_dict():
             model="foo",
             object="chat.completion.chunk",
             choices=[
-                OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
+                OpenAIChoiceChunk(
+                    delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
+                )
             ],
         )
         for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
@@ -301,7 +267,7 @@ async def test_multiple_tool_calls():
             model="foo",
             object="chat.completion.chunk",
             choices=[
-                OpenAIChoice(
+                OpenAIChoiceChunk(
                     delta=OpenAIChoiceDelta(
                         content="",
                         tool_calls=[
@@ -326,7 +292,7 @@ async def test_multiple_tool_calls():
             model="foo",
            object="chat.completion.chunk",
             choices=[
-                OpenAIChoice(
+                OpenAIChoiceChunk(
                     delta=OpenAIChoiceDelta(
                         content="",
                         tool_calls=[
@@ -351,7 +317,9 @@ async def test_multiple_tool_calls():
             model="foo",
             object="chat.completion.chunk",
             choices=[
-                OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
+                OpenAIChoiceChunk(
+                    delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
+                )
             ],
         )
         for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
@@ -395,59 +363,6 @@ async def test_process_vllm_chat_completion_stream_response_no_choices():
     assert chunks[0].event.event_type.value == "start"
 
 
-@pytest.mark.allow_network
-def test_chat_completion_doesnt_block_event_loop(caplog):
-    loop = asyncio.new_event_loop()
-    loop.set_debug(True)
-    caplog.set_level(logging.WARNING)
-
-    # Log when event loop is blocked for more than 200ms
-    loop.slow_callback_duration = 0.5
-    # Sleep for 500ms in our delayed http response
-    sleep_time = 0.5
-
-    mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
-    mock_response = {
-        "id": "chatcmpl-abc123",
-        "object": "chat.completion",
-        "created": 1,
-        "modle": "mock-model",
-        "choices": [
-            {
-                "message": {"content": ""},
-                "logprobs": None,
-                "finish_reason": "stop",
-                "index": 0,
-            }
-        ],
-    }
-
-    async def do_chat_completion():
-        await inference_adapter.chat_completion(
-            "mock-model",
-            [],
-            stream=False,
-            tools=None,
-            tool_config=ToolConfig(tool_choice=ToolChoice.auto),
-        )
-
-    with MockInferenceAdapterWithSleep(sleep_time, mock_response) as inference_adapter:
-        inference_adapter.model_store = AsyncMock()
-        inference_adapter.model_store.get_model.return_value = mock_model
-        loop.run_until_complete(inference_adapter.initialize())
-
-        # Clear the logs so far and run the actual chat completion we care about
-        caplog.clear()
-        loop.run_until_complete(do_chat_completion())
-
-        # Ensure we don't have any asyncio warnings in the captured log
-        # records from our chat completion call. A message gets logged
-        # here any time we exceed the slow_callback_duration configured
-        # above.
-        asyncio_warnings = [record.message for record in caplog.records if record.name == "asyncio"]
-        assert not asyncio_warnings
-
-
 async def test_get_params_empty_tools(vllm_inference_adapter):
     request = ChatCompletionRequest(
         tools=[],
@@ -696,3 +611,48 @@ async def test_health_status_failure(vllm_inference_adapter):
         assert "Health check failed: Connection failed" in health_response["message"]
 
         mock_models.list.assert_called_once()
+
+
+async def test_openai_chat_completion_is_async(vllm_inference_adapter):
+    """
+    Verify that openai_chat_completion is async and doesn't block the event loop.
+
+    To do this we mock the underlying inference with a sleep, start multiple
+    inference calls in parallel, and ensure the total time taken is less
+    than the sum of the individual sleep times.
+    """
+    sleep_time = 0.5
+
+    async def mock_create(*args, **kwargs):
+        await asyncio.sleep(sleep_time)
+        return OpenAIChatCompletion(
+            id="chatcmpl-abc123",
+            created=1,
+            model="mock-model",
+            choices=[
+                OpenAIChoice(
+                    message=OpenAIAssistantMessageParam(
+                        content="nothing interesting",
+                    ),
+                    finish_reason="stop",
+                    index=0,
+                )
+            ],
+        )
+
+    async def do_inference():
+        await vllm_inference_adapter.openai_chat_completion(
+            "mock-model", messages=["one fish", "two fish"], stream=False
+        )
+
+    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(side_effect=mock_create)
+        mock_create_client.return_value = mock_client
+
+        start_time = time.time()
+        await asyncio.gather(do_inference(), do_inference(), do_inference(), do_inference())
+        total_time = time.time() - start_time
+
+        assert mock_create_client.call_count == 4  # no cheating
+        assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"
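The new test relies on a general asyncio timing check: if N coroutines that each await a t-second sleep finish together in well under N * t seconds, the code under test is yielding to the event loop rather than blocking it. A minimal standalone sketch of that check, using only the standard library and illustrative names (slow_call and main are not part of the patch), might look like this:

import asyncio
import time


async def slow_call() -> str:
    # Stand-in for a mocked backend call; asyncio.sleep yields control
    # to the event loop instead of blocking the thread.
    await asyncio.sleep(0.5)
    return "ok"


async def main() -> None:
    start = time.time()
    # Four concurrent calls: if slow_call blocked the loop, this would
    # take roughly 4 * 0.5s; overlapping calls finish in about 0.5s.
    results = await asyncio.gather(*(slow_call() for _ in range(4)))
    elapsed = time.time() - start

    assert results == ["ok"] * 4
    assert elapsed < 1.0, f"calls did not overlap: {elapsed:.2f}s"


if __name__ == "__main__":
    asyncio.run(main())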