mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
fix: Swap to AsyncOpenAI client in remote vllm provider (#1459)
# What does this PR do? This switches from an OpenAI client to the AsyncOpenAI client in the remote vllm provider. The main benefit of this is that instead of each client call being a blocking operation that was blocking our server event loop, the client calls are now async operations that do not block the event loop. The actual fix is quite simple and straightforward. Creating a reliable reproducer of this with a unit test that verifies we were blocking the event loop before and are not blocking it any longer was a bit harder. Some other inference providers have this same issue, so we may want to make that simple delayed http server a bit more generic and pull it into a common place as other inference providers get fixed. (Closes #1457) ## Test Plan I verified the unit tests and test_text_inference tests pass with this change like below: ``` python -m pytest -v tests/unit ``` ``` VLLM_URL="http://localhost:8000/v1" \ INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" \ LLAMA_STACK_CONFIG=remote-vllm \ python -m pytest -v -s \ tests/integration/inference/test_text_inference.py \ --text-model "meta-llama/Llama-3.2-3B-Instruct" ``` Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
parent
256448c14e
commit
d86a893ead
2 changed files with 107 additions and 29 deletions
|
@ -7,7 +7,7 @@ import json
|
|||
import logging
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from openai import OpenAI
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChatCompletionChunk as OpenAIChatCompletionChunk,
|
||||
)
|
||||
|
@ -229,7 +229,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
async def initialize(self) -> None:
|
||||
log.info(f"Initializing VLLM client with base_url={self.config.url}")
|
||||
self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
|
||||
self.client = AsyncOpenAI(base_url=self.config.url, api_key=self.config.api_token)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
@ -300,10 +300,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
return await self._nonstream_chat_completion(request, self.client)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: OpenAI
|
||||
self, request: ChatCompletionRequest, client: AsyncOpenAI
|
||||
) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = client.chat.completions.create(**params)
|
||||
r = await client.chat.completions.create(**params)
|
||||
choice = r.choices[0]
|
||||
result = ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
|
@ -315,17 +315,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
)
|
||||
return result
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: AsyncOpenAI) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
# TODO: Can we use client.completions.acreate() or maybe there is another way to directly create an async
|
||||
# generator so this wrapper is not necessary?
|
||||
async def _to_async_generator():
|
||||
s = client.chat.completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
stream = await client.chat.completions.create(**params)
|
||||
if len(request.tools) > 0:
|
||||
res = _process_vllm_chat_completion_stream_response(stream)
|
||||
else:
|
||||
|
@ -335,26 +328,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = self.client.completions.create(**params)
|
||||
r = await self.client.completions.create(**params)
|
||||
return process_completion_response(r)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
# Wrapper for async generator similar
|
||||
async def _to_async_generator():
|
||||
stream = self.client.completions.create(**params)
|
||||
for chunk in stream:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
stream = await self.client.completions.create(**params)
|
||||
async for chunk in process_completion_stream_response(stream):
|
||||
yield chunk
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
model = await self.register_helper.register_model(model)
|
||||
res = self.client.models.list()
|
||||
available_models = [m.id for m in res]
|
||||
res = await self.client.models.list()
|
||||
available_models = [m.id async for m in res]
|
||||
if model.provider_resource_id not in available_models:
|
||||
raise ValueError(
|
||||
f"Model {model.provider_resource_id} is not being served by vLLM. "
|
||||
|
@ -410,7 +397,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
assert model.metadata.get("embedding_dimension")
|
||||
kwargs["dimensions"] = model.metadata.get("embedding_dimension")
|
||||
assert all(not content_has_media(content) for content in contents), "VLLM does not support media for embeddings"
|
||||
response = self.client.embeddings.create(
|
||||
response = await self.client.embeddings.create(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
**kwargs,
|
||||
|
|
|
@ -4,6 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
@ -39,9 +46,41 @@ from llama_stack.providers.remote.inference.vllm.vllm import (
|
|||
# -v -s --tb=short --disable-warnings
|
||||
|
||||
|
||||
class MockInferenceAdapterWithSleep:
|
||||
def __init__(self, sleep_time: int, response: Dict[str, Any]):
|
||||
self.httpd = None
|
||||
|
||||
class DelayedRequestHandler(BaseHTTPRequestHandler):
|
||||
# ruff: noqa: N802
|
||||
def do_POST(self):
|
||||
time.sleep(sleep_time)
|
||||
self.send_response(code=200)
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps(response).encode("utf-8"))
|
||||
|
||||
self.request_handler = DelayedRequestHandler
|
||||
|
||||
def __enter__(self):
|
||||
httpd = HTTPServer(("", 0), self.request_handler)
|
||||
self.httpd = httpd
|
||||
host, port = httpd.server_address
|
||||
httpd_thread = threading.Thread(target=httpd.serve_forever)
|
||||
httpd_thread.daemon = True # stop server if this thread terminates
|
||||
httpd_thread.start()
|
||||
|
||||
config = VLLMInferenceAdapterConfig(url=f"http://{host}:{port}")
|
||||
inference_adapter = VLLMInferenceAdapter(config)
|
||||
return inference_adapter
|
||||
|
||||
def __exit__(self, _exc_type, _exc_value, _traceback):
|
||||
if self.httpd:
|
||||
self.httpd.shutdown()
|
||||
self.httpd.server_close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mock_openai_models_list():
|
||||
with patch("openai.resources.models.Models.list") as mock_list:
|
||||
with patch("openai.resources.models.AsyncModels.list", new_callable=AsyncMock) as mock_list:
|
||||
yield mock_list
|
||||
|
||||
|
||||
|
@ -56,10 +95,10 @@ async def vllm_inference_adapter():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_model_checks_vllm(mock_openai_models_list, vllm_inference_adapter):
|
||||
mock_openai_models = [
|
||||
OpenAIModel(id="foo", created=1, object="model", owned_by="test"),
|
||||
]
|
||||
mock_openai_models_list.return_value = mock_openai_models
|
||||
async def mock_openai_models():
|
||||
yield OpenAIModel(id="foo", created=1, object="model", owned_by="test")
|
||||
|
||||
mock_openai_models_list.return_value = mock_openai_models()
|
||||
|
||||
foo_model = Model(identifier="foo", provider_resource_id="foo", provider_id="vllm-inference")
|
||||
|
||||
|
@ -141,3 +180,55 @@ async def test_process_vllm_chat_completion_stream_response_no_choices():
|
|||
|
||||
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
|
||||
assert len(chunks) == 0
|
||||
|
||||
|
||||
def test_chat_completion_doesnt_block_event_loop(caplog):
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.set_debug(True)
|
||||
caplog.set_level(logging.WARNING)
|
||||
|
||||
# Log when event loop is blocked for more than 100ms
|
||||
loop.slow_callback_duration = 0.1
|
||||
# Sleep for 500ms in our delayed http response
|
||||
sleep_time = 0.5
|
||||
|
||||
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
|
||||
mock_response = {
|
||||
"id": "chatcmpl-abc123",
|
||||
"object": "chat.completion",
|
||||
"created": 1,
|
||||
"modle": "mock-model",
|
||||
"choices": [
|
||||
{
|
||||
"message": {"content": ""},
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
async def do_chat_completion():
|
||||
await inference_adapter.chat_completion(
|
||||
"mock-model",
|
||||
[],
|
||||
stream=False,
|
||||
tools=None,
|
||||
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
|
||||
)
|
||||
|
||||
with MockInferenceAdapterWithSleep(sleep_time, mock_response) as inference_adapter:
|
||||
inference_adapter.model_store = AsyncMock()
|
||||
inference_adapter.model_store.get_model.return_value = mock_model
|
||||
loop.run_until_complete(inference_adapter.initialize())
|
||||
|
||||
# Clear the logs so far and run the actual chat completion we care about
|
||||
caplog.clear()
|
||||
loop.run_until_complete(do_chat_completion())
|
||||
|
||||
# Ensure we don't have any asyncio warnings in the captured log
|
||||
# records from our chat completion call. A message gets logged
|
||||
# here any time we exceed the slow_callback_duration configured
|
||||
# above.
|
||||
asyncio_warnings = [record.message for record in caplog.records if record.name == "asyncio"]
|
||||
assert not asyncio_warnings
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue