fix: Swap to AsyncOpenAI client in remote vllm provider (#1459)

# What does this PR do?

This switches from an OpenAI client to the AsyncOpenAI client in the
remote vllm provider. The main benefit of this is that instead of each
client call being a blocking operation that was blocking our server
event loop, the client calls are now async operations that do not block
the event loop.

The actual fix is quite simple and straightforward. Creating a reliable
reproducer of this with a unit test that verifies we were blocking the
event loop before and are not blocking it any longer was a bit harder.
Some other inference providers have this same issue, so we may want to
make that simple delayed http server a bit more generic and pull it into
a common place as other inference providers get fixed.

(Closes #1457)

## Test Plan

I verified the unit tests and test_text_inference tests pass with this
change like below:

```
python -m pytest -v tests/unit
```

```
VLLM_URL="http://localhost:8000/v1" \
INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" \
LLAMA_STACK_CONFIG=remote-vllm \
python -m pytest -v -s \
tests/integration/inference/test_text_inference.py \
--text-model "meta-llama/Llama-3.2-3B-Instruct"
```

Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
Ben Browning 2025-03-07 14:48:00 -05:00 committed by GitHub
parent 256448c14e
commit d86a893ead
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 107 additions and 29 deletions

View file

@ -7,7 +7,7 @@ import json
import logging
from typing import AsyncGenerator, List, Optional, Union
from openai import OpenAI
from openai import AsyncOpenAI
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
@ -229,7 +229,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def initialize(self) -> None:
log.info(f"Initializing VLLM client with base_url={self.config.url}")
self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
self.client = AsyncOpenAI(base_url=self.config.url, api_key=self.config.api_token)
async def shutdown(self) -> None:
pass
@ -300,10 +300,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
return await self._nonstream_chat_completion(request, self.client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
self, request: ChatCompletionRequest, client: AsyncOpenAI
) -> ChatCompletionResponse:
params = await self._get_params(request)
r = client.chat.completions.create(**params)
r = await client.chat.completions.create(**params)
choice = r.choices[0]
result = ChatCompletionResponse(
completion_message=CompletionMessage(
@ -315,17 +315,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
)
return result
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: AsyncOpenAI) -> AsyncGenerator:
params = await self._get_params(request)
# TODO: Can we use client.completions.acreate() or maybe there is another way to directly create an async
# generator so this wrapper is not necessary?
async def _to_async_generator():
s = client.chat.completions.create(**params)
for chunk in s:
yield chunk
stream = _to_async_generator()
stream = await client.chat.completions.create(**params)
if len(request.tools) > 0:
res = _process_vllm_chat_completion_stream_response(stream)
else:
@ -335,26 +328,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = self.client.completions.create(**params)
r = await self.client.completions.create(**params)
return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
# Wrapper for async generator similar
async def _to_async_generator():
stream = self.client.completions.create(**params)
for chunk in stream:
yield chunk
stream = _to_async_generator()
stream = await self.client.completions.create(**params)
async for chunk in process_completion_stream_response(stream):
yield chunk
async def register_model(self, model: Model) -> Model:
model = await self.register_helper.register_model(model)
res = self.client.models.list()
available_models = [m.id for m in res]
res = await self.client.models.list()
available_models = [m.id async for m in res]
if model.provider_resource_id not in available_models:
raise ValueError(
f"Model {model.provider_resource_id} is not being served by vLLM. "
@ -410,7 +397,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
assert model.metadata.get("embedding_dimension")
kwargs["dimensions"] = model.metadata.get("embedding_dimension")
assert all(not content_has_media(content) for content in contents), "VLLM does not support media for embeddings"
response = self.client.embeddings.create(
response = await self.client.embeddings.create(
model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents],
**kwargs,