Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-10-04 04:04:14 +00:00
chore: update the vLLM inference impl to use OpenAIMixin for openai-compat functions (#3404)
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 1s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 1s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3s
API Conformance Tests / check-schema-compatibility (push) Successful in 7s
Test Llama Stack Build / generate-matrix (push) Successful in 3s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 3s
Python Package Build Test / build (3.12) (push) Failing after 2s
Python Package Build Test / build (3.13) (push) Failing after 1s
Vector IO Integration Tests / test-matrix (push) Failing after 4s
Test Llama Stack Build / build-single-provider (push) Failing after 5s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 4s
Test External API and Providers / test-external (venv) (push) Failing after 4s
Test Llama Stack Build / build (push) Failing after 3s
Unit Tests / unit-tests (3.13) (push) Failing after 6s
Update ReadTheDocs / update-readthedocs (push) Failing after 3s
Unit Tests / unit-tests (3.12) (push) Failing after 4s
UI Tests / ui-tests (22) (push) Successful in 31s
Pre-commit / pre-commit (push) Successful in 1m18s
# What does this PR do?

Update the vLLM inference provider to use OpenAIMixin for the openai-compat functions.

Inference recordings are from Qwen3-0.6B on vLLM 0.8.3:

```
docker run --gpus all -v ~/.cache/huggingface:/root/.cache/huggingface -p 8000:8000 --ipc=host \
    vllm/vllm-openai:latest \
    --model Qwen/Qwen3-0.6B --enable-auto-tool-choice --tool-call-parser hermes
```

## Test Plan

```
./scripts/integration-tests.sh --stack-config server:ci-tests --setup vllm --subdirs inference
```
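For reference, the openai-compat surface being recorded here can be exercised directly with the stock `openai` client pointed at the vLLM server started above. A minimal sketch, not part of this PR; the base URL and placeholder API key are assumptions:

```python
# Minimal sketch: talk to the vLLM server started above through its OpenAI-compatible
# endpoint, the same surface the adapter now reaches via OpenAIMixin.
# Assumed: the server is reachable at http://localhost:8000/v1 and needs no real API key.
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
    resp = await client.chat.completions.create(
        model="Qwen/Qwen3-0.6B",
        messages=[{"role": "user", "content": "Say hello in one word."}],
    )
    print(resp.choices[0].message.content)


asyncio.run(main())
```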
This commit is contained in:
Parent: d15368a302
Commit: 8ef1189be7
3 changed files with 44 additions and 202 deletions
**vLLM inference adapter** (`VLLMInferenceAdapter`):

```diff
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import json
-from collections.abc import AsyncGenerator, AsyncIterator
+from collections.abc import AsyncGenerator
 from typing import Any
 
 import httpx
@@ -38,13 +38,6 @@ from llama_stack.apis.inference import (
     LogProbConfig,
     Message,
     ModelStore,
-    OpenAIChatCompletion,
-    OpenAICompletion,
-    OpenAIEmbeddingData,
-    OpenAIEmbeddingsResponse,
-    OpenAIEmbeddingUsage,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
@@ -71,11 +64,11 @@ from llama_stack.providers.utils.inference.openai_compat import (
     convert_message_to_openai_dict,
     convert_tool_call,
     get_sampling_options,
-    prepare_openai_completion_params,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack.providers.utils.inference.prompt_adapter import (
     completion_request_to_prompt,
     content_has_media,
@@ -288,7 +281,7 @@ async def _process_vllm_chat_completion_stream_response(
         yield c
 
 
-class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
+class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
     # automatically set by the resolver when instantiating the provider
     __provider_id__: str
     model_store: ModelStore | None = None
@@ -296,7 +289,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
         self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
         self.config = config
-        self.client = None
 
     async def initialize(self) -> None:
         if not self.config.url:
@@ -308,8 +300,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         return self.config.refresh_models
 
     async def list_models(self) -> list[Model] | None:
-        self._lazy_initialize_client()
-        assert self.client is not None  # mypy
         models = []
         async for m in self.client.models.list():
             model_type = ModelType.llm  # unclear how to determine embedding vs. llm models
@@ -340,8 +330,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             HealthResponse: A dictionary containing the health status.
         """
         try:
-            client = self._create_client() if self.client is None else self.client
-            _ = [m async for m in client.models.list()]  # Ensure the client is initialized
+            _ = [m async for m in self.client.models.list()]  # Ensure the client is initialized
             return HealthResponse(status=HealthStatus.OK)
         except Exception as e:
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
@@ -351,19 +340,14 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             raise ValueError("Model store not set")
         return await self.model_store.get_model(model_id)
 
-    def _lazy_initialize_client(self):
-        if self.client is not None:
-            return
-
-        log.info(f"Initializing vLLM client with base_url={self.config.url}")
-        self.client = self._create_client()
-
-    def _create_client(self):
-        return AsyncOpenAI(
-            base_url=self.config.url,
-            api_key=self.config.api_token,
-            http_client=httpx.AsyncClient(verify=self.config.tls_verify),
-        )
+    def get_api_key(self):
+        return self.config.api_token
+
+    def get_base_url(self):
+        return self.config.url
+
+    def get_extra_client_params(self):
+        return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
 
     async def completion(
         self,
@@ -374,7 +358,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         stream: bool | None = False,
         logprobs: LogProbConfig | None = None,
     ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
-        self._lazy_initialize_client()
         if sampling_params is None:
             sampling_params = SamplingParams()
         model = await self._get_model(model_id)
@@ -406,7 +389,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         logprobs: LogProbConfig | None = None,
         tool_config: ToolConfig | None = None,
     ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
-        self._lazy_initialize_client()
         if sampling_params is None:
             sampling_params = SamplingParams()
         model = await self._get_model(model_id)
@@ -479,16 +461,12 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             yield chunk
 
     async def register_model(self, model: Model) -> Model:
-        # register_model is called during Llama Stack initialization, hence we cannot init self.client if not initialized yet.
-        # self.client should only be created after the initialization is complete to avoid asyncio cross-context errors.
-        # Changing this may lead to unpredictable behavior.
-        client = self._create_client() if self.client is None else self.client
         try:
             model = await self.register_helper.register_model(model)
         except ValueError:
             pass  # Ignore statically unknown model, will check live listing
         try:
-            res = await client.models.list()
+            res = await self.client.models.list()
         except APIConnectionError as e:
             raise ValueError(
                 f"Failed to connect to vLLM at {self.config.url}. Please check if vLLM is running and accessible at that URL."
@@ -543,8 +521,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         output_dimension: int | None = None,
         task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
-        self._lazy_initialize_client()
-        assert self.client is not None
         model = await self._get_model(model_id)
 
         kwargs = {}
@@ -560,154 +536,3 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
 
         embeddings = [data.embedding for data in response.data]
         return EmbeddingsResponse(embeddings=embeddings)
-
-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        self._lazy_initialize_client()
-        assert self.client is not None
-        model_obj = await self._get_model(model)
-        assert model_obj.model_type == ModelType.embedding
-
-        # Convert input to list if it's a string
-        input_list = [input] if isinstance(input, str) else input
-
-        # Call vLLM embeddings endpoint with encoding_format
-        response = await self.client.embeddings.create(
-            model=model_obj.provider_resource_id,
-            input=input_list,
-            dimensions=dimensions,
-            encoding_format=encoding_format,
-        )
-
-        # Convert response to OpenAI format
-        data = [
-            OpenAIEmbeddingData(
-                embedding=embedding_data.embedding,
-                index=i,
-            )
-            for i, embedding_data in enumerate(response.data)
-        ]
-
-        # Not returning actual token usage since vLLM doesn't provide it
-        usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
-
-        return OpenAIEmbeddingsResponse(
-            data=data,
-            model=model_obj.provider_resource_id,
-            usage=usage,
-        )
-
-    async def openai_completion(
-        self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
-    ) -> OpenAICompletion:
-        self._lazy_initialize_client()
-        model_obj = await self._get_model(model)
-
-        extra_body: dict[str, Any] = {}
-        if prompt_logprobs is not None and prompt_logprobs >= 0:
-            extra_body["prompt_logprobs"] = prompt_logprobs
-        if guided_choice:
-            extra_body["guided_choice"] = guided_choice
-
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-            extra_body=extra_body,
-        )
-        return await self.client.completions.create(**params)  # type: ignore
-
-    async def openai_chat_completion(
-        self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        self._lazy_initialize_client()
-        model_obj = await self._get_model(model)
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
-        return await self.client.chat.completions.create(**params)  # type: ignore
```
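Together with the `OpenAIMixin` change below, the adapter now supplies only three small hooks and lets the mixin build the `AsyncOpenAI` client lazily. A condensed, illustrative sketch of that pattern; `SimpleConfig`, `OpenAIClientMixin`, and `VLLMStyleAdapter` are stand-in names, not llama-stack classes:

```python
# Condensed sketch of the hook pattern this PR adopts (illustrative only).
from dataclasses import dataclass
from typing import Any

import httpx
from openai import AsyncOpenAI


@dataclass
class SimpleConfig:  # stand-in for the adapter's config object
    url: str
    api_token: str
    tls_verify: bool = True


class OpenAIClientMixin:
    """Builds the AsyncOpenAI client on demand from three small hooks."""

    def get_api_key(self) -> str:
        raise NotImplementedError

    def get_base_url(self) -> str:
        raise NotImplementedError

    def get_extra_client_params(self) -> dict[str, Any]:
        return {}

    @property
    def client(self) -> AsyncOpenAI:
        return AsyncOpenAI(
            api_key=self.get_api_key(),
            base_url=self.get_base_url(),
            **self.get_extra_client_params(),
        )


class VLLMStyleAdapter(OpenAIClientMixin):
    def __init__(self, config: SimpleConfig) -> None:
        self.config = config

    def get_api_key(self) -> str:
        return self.config.api_token

    def get_base_url(self) -> str:
        return self.config.url

    def get_extra_client_params(self) -> dict[str, Any]:
        return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}


adapter = VLLMStyleAdapter(SimpleConfig(url="http://localhost:8000/v1", api_token="fake"))
print(adapter.client.base_url)  # the client is constructed on demand from the hooks
```

This removes the `_lazy_initialize_client` / `_create_client` bookkeeping from the adapter: client construction lives in one place (the mixin) and the adapter only declares what is provider-specific.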
`OpenAIMixin` (llama_stack/providers/utils/inference/openai_mixin.py):

```diff
@@ -67,6 +67,17 @@ class OpenAIMixin(ABC):
         """
         pass
 
+    def get_extra_client_params(self) -> dict[str, Any]:
+        """
+        Get any extra parameters to pass to the AsyncOpenAI client.
+
+        Child classes can override this method to provide additional parameters
+        such as timeout settings, proxies, etc.
+
+        :return: A dictionary of extra parameters
+        """
+        return {}
+
     @property
     def client(self) -> AsyncOpenAI:
         """
@@ -78,6 +89,7 @@ class OpenAIMixin(ABC):
         return AsyncOpenAI(
             api_key=self.get_api_key(),
             base_url=self.get_base_url(),
+            **self.get_extra_client_params(),
         )
 
     async def _get_provider_model_id(self, model: str) -> str:
@@ -124,10 +136,15 @@ class OpenAIMixin(ABC):
         """
         Direct OpenAI completion API call.
         """
-        if guided_choice is not None:
-            logger.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
-        if prompt_logprobs is not None:
-            logger.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
+        # Handle parameters that are not supported by OpenAI API, but may be by the provider
+        # prompt_logprobs is supported by vLLM
+        # guided_choice is supported by vLLM
+        # TODO: test coverage
+        extra_body: dict[str, Any] = {}
+        if prompt_logprobs is not None and prompt_logprobs >= 0:
+            extra_body["prompt_logprobs"] = prompt_logprobs
+        if guided_choice:
+            extra_body["guided_choice"] = guided_choice
 
         # TODO: fix openai_completion to return type compatible with OpenAI's API response
         return await self.client.completions.create(  # type: ignore[no-any-return]
@@ -150,7 +167,8 @@ class OpenAIMixin(ABC):
                 top_p=top_p,
                 user=user,
                 suffix=suffix,
-            )
+            ),
+            extra_body=extra_body,
         )
 
     async def openai_chat_completion(
```
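The `extra_body` passthrough above is how vLLM-only parameters (`guided_choice`, `prompt_logprobs`) now reach the server through the plain OpenAI client instead of being dropped with a warning. A rough usage sketch against a local vLLM server; the base URL, model, and prompt are assumptions:

```python
# Illustrative sketch of the extra_body passthrough: vLLM-specific sampling controls
# ride along on a standard completions.create() call.
import asyncio
from typing import Any

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

    prompt_logprobs = 1            # vLLM extension, not part of the OpenAI API
    guided_choice = ["yes", "no"]  # vLLM extension, not part of the OpenAI API

    extra_body: dict[str, Any] = {}
    if prompt_logprobs is not None and prompt_logprobs >= 0:
        extra_body["prompt_logprobs"] = prompt_logprobs
    if guided_choice:
        extra_body["guided_choice"] = guided_choice

    # The OpenAI client forwards extra_body keys verbatim in the request JSON, so a
    # provider like vLLM can honor them even though the SDK has no kwargs for them.
    completion = await client.completions.create(
        model="Qwen/Qwen3-0.6B",
        prompt="Is the sky blue? Answer yes or no:",
        max_tokens=1,
        extra_body=extra_body,
    )
    print(completion.choices[0].text)


asyncio.run(main())
```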
Unit tests for the vLLM adapter:

```diff
@@ -11,7 +11,7 @@ import threading
 import time
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from typing import Any
-from unittest.mock import AsyncMock, MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
 
 import pytest
 from openai.types.chat.chat_completion_chunk import (
@@ -150,10 +150,12 @@ async def test_tool_call_response(vllm_inference_adapter):
     """Verify that tool call arguments from a CompletionMessage are correctly converted
     into the expected JSON format."""
 
-    # Patch the call to vllm so we can inspect the arguments sent were correct
-    with patch.object(
-        vllm_inference_adapter.client.chat.completions, "create", new_callable=AsyncMock
-    ) as mock_nonstream_completion:
+    # Patch the client property to avoid instantiating a real AsyncOpenAI client
+    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock()
+        mock_create_client.return_value = mock_client
+
         messages = [
             SystemMessage(content="You are a helpful assistant"),
             UserMessage(content="How many?"),
@@ -179,7 +181,7 @@ async def test_tool_call_response(vllm_inference_adapter):
             tool_config=ToolConfig(tool_choice=ToolChoice.auto),
         )
 
-        assert mock_nonstream_completion.call_args.kwargs["messages"][2]["tool_calls"] == [
+        assert mock_client.chat.completions.create.call_args.kwargs["messages"][2]["tool_calls"] == [
             {
                 "id": "foo",
                 "type": "function",
@@ -641,9 +643,7 @@ async def test_health_status_success(vllm_inference_adapter):
     This test verifies that the health method returns a HealthResponse with status OK, only
     when the connection to the vLLM server is successful.
     """
-    # Set vllm_inference_adapter.client to None to ensure _create_client is called
-    vllm_inference_adapter.client = None
-    with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client:
+    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
         # Create mock client and models
         mock_client = MagicMock()
         mock_models = MagicMock()
@@ -674,8 +674,7 @@ async def test_health_status_failure(vllm_inference_adapter):
     This test verifies that the health method returns a HealthResponse with status ERROR
     and an appropriate error message when the connection to the vLLM server fails.
     """
-    vllm_inference_adapter.client = None
-    with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client:
+    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
         # Create mock client and models
         mock_client = MagicMock()
         mock_models = MagicMock()
```
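The tests now stub the mixin-provided `client` property instead of assigning to an instance attribute. A self-contained sketch of the `PropertyMock` pattern they rely on; the `Adapter` class below is a stand-in, not the real `VLLMInferenceAdapter`:

```python
# Self-contained sketch of patching a property with PropertyMock, as the updated
# tests do for the adapter's `client` property.
import asyncio
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch


class Adapter:  # stand-in for an OpenAIMixin-based adapter
    @property
    def client(self):
        raise RuntimeError("would construct a real AsyncOpenAI client")

    async def ping(self):
        return await self.client.chat.completions.create(model="m", messages=[])


async def main() -> None:
    # Patching the *class* attribute with a PropertyMock intercepts the property
    # lookup, so no real client is ever built.
    with patch.object(Adapter, "client", new_callable=PropertyMock) as mock_client_prop:
        mock_client = MagicMock()
        mock_client.chat.completions.create = AsyncMock(return_value="ok")
        mock_client_prop.return_value = mock_client

        assert await Adapter().ping() == "ok"
        mock_client.chat.completions.create.assert_awaited_once()


asyncio.run(main())
```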