Merge branch 'main' into chroma

Bwook (Byoungwook) Kim, 2025-09-11 22:59:26 +09:00, committed by GitHub
commit 729e0f3fcb
3 changed files with 44 additions and 202 deletions


@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import AsyncGenerator
from typing import Any
import httpx
@@ -38,13 +38,6 @@ from llama_stack.apis.inference import (
LogProbConfig,
Message,
ModelStore,
OpenAIChatCompletion,
OpenAICompletion,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
@@ -71,11 +64,11 @@ from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
convert_tool_call,
get_sampling_options,
prepare_openai_completion_params,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
completion_request_to_prompt,
content_has_media,
@@ -288,7 +281,7 @@ async def _process_vllm_chat_completion_stream_response(
yield c
class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
# automatically set by the resolver when instantiating the provider
__provider_id__: str
model_store: ModelStore | None = None
@@ -296,7 +289,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.config = config
self.client = None
async def initialize(self) -> None:
if not self.config.url:
@@ -308,8 +300,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
return self.config.refresh_models
async def list_models(self) -> list[Model] | None:
self._lazy_initialize_client()
assert self.client is not None # mypy
models = []
async for m in self.client.models.list():
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
@@ -340,8 +330,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
HealthResponse: A dictionary containing the health status.
"""
try:
client = self._create_client() if self.client is None else self.client
_ = [m async for m in client.models.list()] # Ensure the client is initialized
_ = [m async for m in self.client.models.list()] # Ensure the client is initialized
return HealthResponse(status=HealthStatus.OK)
except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
@@ -351,19 +340,14 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
raise ValueError("Model store not set")
return await self.model_store.get_model(model_id)
def _lazy_initialize_client(self):
if self.client is not None:
return
def get_api_key(self):
return self.config.api_token
log.info(f"Initializing vLLM client with base_url={self.config.url}")
self.client = self._create_client()
def get_base_url(self):
return self.config.url
def _create_client(self):
return AsyncOpenAI(
base_url=self.config.url,
api_key=self.config.api_token,
http_client=httpx.AsyncClient(verify=self.config.tls_verify),
)
def get_extra_client_params(self):
return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
async def completion(
self,
@@ -374,7 +358,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
self._lazy_initialize_client()
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
@@ -406,7 +389,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
self._lazy_initialize_client()
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
@@ -479,16 +461,12 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
yield chunk
async def register_model(self, model: Model) -> Model:
# register_model is called during Llama Stack initialization, hence we cannot init self.client if not initialized yet.
# self.client should only be created after the initialization is complete to avoid asyncio cross-context errors.
# Changing this may lead to unpredictable behavior.
client = self._create_client() if self.client is None else self.client
try:
model = await self.register_helper.register_model(model)
except ValueError:
pass # Ignore statically unknown model, will check live listing
try:
res = await client.models.list()
res = await self.client.models.list()
except APIConnectionError as e:
raise ValueError(
f"Failed to connect to vLLM at {self.config.url}. Please check if vLLM is running and accessible at that URL."
@@ -543,8 +521,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
self._lazy_initialize_client()
assert self.client is not None
model = await self._get_model(model_id)
kwargs = {}
@@ -560,154 +536,3 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
embeddings = [data.embedding for data in response.data]
return EmbeddingsResponse(embeddings=embeddings)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
self._lazy_initialize_client()
assert self.client is not None
model_obj = await self._get_model(model)
assert model_obj.model_type == ModelType.embedding
# Convert input to list if it's a string
input_list = [input] if isinstance(input, str) else input
# Call vLLM embeddings endpoint with encoding_format
response = await self.client.embeddings.create(
model=model_obj.provider_resource_id,
input=input_list,
dimensions=dimensions,
encoding_format=encoding_format,
)
# Convert response to OpenAI format
data = [
OpenAIEmbeddingData(
embedding=embedding_data.embedding,
index=i,
)
for i, embedding_data in enumerate(response.data)
]
# Not returning actual token usage since vLLM doesn't provide it
usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
return OpenAIEmbeddingsResponse(
data=data,
model=model_obj.provider_resource_id,
usage=usage,
)
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
self._lazy_initialize_client()
model_obj = await self._get_model(model)
extra_body: dict[str, Any] = {}
if prompt_logprobs is not None and prompt_logprobs >= 0:
extra_body["prompt_logprobs"] = prompt_logprobs
if guided_choice:
extra_body["guided_choice"] = guided_choice
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
extra_body=extra_body,
)
return await self.client.completions.create(**params) # type: ignore
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
self._lazy_initialize_client()
model_obj = await self._get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await self.client.chat.completions.create(**params) # type: ignore

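For context, the adapter changes above drop the lazily created AsyncOpenAI client (self.client, _lazy_initialize_client, _create_client) in favor of the client property inherited from OpenAIMixin; the adapter now only supplies get_api_key(), get_base_url(), and get_extra_client_params(). A minimal sketch of that hook pattern, with an illustrative class standing in for VLLMInferenceAdapter and its config:

import httpx
from openai import AsyncOpenAI


class ExampleAdapter:
    """Illustrative stand-in: shows the hooks OpenAIMixin expects."""

    def __init__(self, url: str, api_token: str, tls_verify: bool = True) -> None:
        self.url = url
        self.api_token = api_token
        self.tls_verify = tls_verify

    # The three hooks the mixin reads when building the client:
    def get_api_key(self) -> str:
        return self.api_token

    def get_base_url(self) -> str:
        return self.url

    def get_extra_client_params(self) -> dict:
        # e.g. custom TLS verification, as the vLLM adapter does
        return {"http_client": httpx.AsyncClient(verify=self.tls_verify)}

    @property
    def client(self) -> AsyncOpenAI:
        # What OpenAIMixin.client does with those hooks (see the mixin diff below)
        return AsyncOpenAI(
            api_key=self.get_api_key(),
            base_url=self.get_base_url(),
            **self.get_extra_client_params(),
        )

In the lines shown, the property builds the client on access, so the adapter no longer needs to cache self.client or guard against using it before initialization.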

@@ -67,6 +67,17 @@ class OpenAIMixin(ABC):
"""
pass
def get_extra_client_params(self) -> dict[str, Any]:
"""
Get any extra parameters to pass to the AsyncOpenAI client.
Child classes can override this method to provide additional parameters
such as timeout settings, proxies, etc.
:return: A dictionary of extra parameters
"""
return {}
@property
def client(self) -> AsyncOpenAI:
"""
@@ -78,6 +89,7 @@ class OpenAIMixin(ABC):
return AsyncOpenAI(
api_key=self.get_api_key(),
base_url=self.get_base_url(),
**self.get_extra_client_params(),
)
async def _get_provider_model_id(self, model: str) -> str:
@@ -124,10 +136,15 @@ class OpenAIMixin(ABC):
"""
Direct OpenAI completion API call.
"""
if guided_choice is not None:
logger.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
if prompt_logprobs is not None:
logger.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
# Handle parameters that are not supported by OpenAI API, but may be by the provider
# prompt_logprobs is supported by vLLM
# guided_choice is supported by vLLM
# TODO: test coverage
extra_body: dict[str, Any] = {}
if prompt_logprobs is not None and prompt_logprobs >= 0:
extra_body["prompt_logprobs"] = prompt_logprobs
if guided_choice:
extra_body["guided_choice"] = guided_choice
# TODO: fix openai_completion to return type compatible with OpenAI's API response
return await self.client.completions.create( # type: ignore[no-any-return]
@@ -150,7 +167,8 @@ class OpenAIMixin(ABC):
top_p=top_p,
user=user,
suffix=suffix,
)
),
extra_body=extra_body,
)
async def openai_chat_completion(

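The extra_body plumbing added above lets vLLM-only parameters (prompt_logprobs, guided_choice) ride through the standard OpenAI client instead of being dropped with a warning. A minimal sketch of the resulting request, assuming a reachable vLLM server at a hypothetical http://localhost:8000/v1 and a hypothetical model name:

import asyncio

from openai import AsyncOpenAI


async def demo_guided_choice() -> None:
    # extra_body is merged into the request payload by the OpenAI SDK, so
    # provider-specific fields the SDK does not model can still be sent.
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
    resp = await client.completions.create(
        model="my-vllm-model",
        prompt="Is the sky blue? Answer:",
        max_tokens=1,
        extra_body={"guided_choice": ["yes", "no"]},
    )
    print(resp.choices[0].text)


if __name__ == "__main__":
    asyncio.run(demo_guided_choice())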

@@ -11,7 +11,7 @@ import threading
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
import pytest
from openai.types.chat.chat_completion_chunk import (
@@ -150,10 +150,12 @@ async def test_tool_call_response(vllm_inference_adapter):
"""Verify that tool call arguments from a CompletionMessage are correctly converted
into the expected JSON format."""
# Patch the call to vllm so we can inspect the arguments sent were correct
with patch.object(
vllm_inference_adapter.client.chat.completions, "create", new_callable=AsyncMock
) as mock_nonstream_completion:
# Patch the client property to avoid instantiating a real AsyncOpenAI client
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock()
mock_create_client.return_value = mock_client
messages = [
SystemMessage(content="You are a helpful assistant"),
UserMessage(content="How many?"),
@@ -179,7 +181,7 @@ async def test_tool_call_response(vllm_inference_adapter):
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
)
assert mock_nonstream_completion.call_args.kwargs["messages"][2]["tool_calls"] == [
assert mock_client.chat.completions.create.call_args.kwargs["messages"][2]["tool_calls"] == [
{
"id": "foo",
"type": "function",
@@ -641,9 +643,7 @@ async def test_health_status_success(vllm_inference_adapter):
This test verifies that the health method returns a HealthResponse with status OK, only
when the connection to the vLLM server is successful.
"""
# Set vllm_inference_adapter.client to None to ensure _create_client is called
vllm_inference_adapter.client = None
with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client:
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
# Create mock client and models
mock_client = MagicMock()
mock_models = MagicMock()
@@ -674,8 +674,7 @@ async def test_health_status_failure(vllm_inference_adapter):
This test verifies that the health method returns a HealthResponse with status ERROR
and an appropriate error message when the connection to the vLLM server fails.
"""
vllm_inference_adapter.client = None
with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client:
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
# Create mock client and models
mock_client = MagicMock()
mock_models = MagicMock()
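The test updates above patch the client property on the adapter class with PropertyMock rather than stubbing an instance attribute, since self.client no longer exists. A standalone sketch of that pattern, using a hypothetical class:

from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch


class Adapter:
    """Hypothetical class with a client property, standing in for the adapter."""

    @property
    def client(self):
        raise RuntimeError("would construct a real AsyncOpenAI client")


def test_client_property_is_mocked() -> None:
    # Properties live on the class, so the patch must target Adapter itself,
    # not an instance; PropertyMock makes attribute access return the mock.
    with patch.object(Adapter, "client", new_callable=PropertyMock) as mock_client_prop:
        mock_client = MagicMock()
        mock_client.chat.completions.create = AsyncMock()
        mock_client_prop.return_value = mock_client
        assert Adapter().client is mock_client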