Merge branch 'main' into chroma

Bwook (Byoungwook) Kim 2025-09-11 22:59:26 +09:00 committed by GitHub
commit 729e0f3fcb
3 changed files with 44 additions and 202 deletions


@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import json
-from collections.abc import AsyncGenerator, AsyncIterator
+from collections.abc import AsyncGenerator
 from typing import Any

 import httpx
@@ -38,13 +38,6 @@ from llama_stack.apis.inference import (
     LogProbConfig,
     Message,
     ModelStore,
-    OpenAIChatCompletion,
-    OpenAICompletion,
-    OpenAIEmbeddingData,
-    OpenAIEmbeddingsResponse,
-    OpenAIEmbeddingUsage,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
@@ -71,11 +64,11 @@ from llama_stack.providers.utils.inference.openai_compat import (
     convert_message_to_openai_dict,
     convert_tool_call,
     get_sampling_options,
-    prepare_openai_completion_params,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack.providers.utils.inference.prompt_adapter import (
     completion_request_to_prompt,
     content_has_media,
@@ -288,7 +281,7 @@ async def _process_vllm_chat_completion_stream_response(
         yield c


-class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
+class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
     # automatically set by the resolver when instantiating the provider
     __provider_id__: str
     model_store: ModelStore | None = None
@@ -296,7 +289,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
         self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
         self.config = config
-        self.client = None

     async def initialize(self) -> None:
         if not self.config.url:
@@ -308,8 +300,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         return self.config.refresh_models

     async def list_models(self) -> list[Model] | None:
-        self._lazy_initialize_client()
-        assert self.client is not None  # mypy
         models = []
         async for m in self.client.models.list():
             model_type = ModelType.llm  # unclear how to determine embedding vs. llm models
@@ -340,8 +330,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             HealthResponse: A dictionary containing the health status.
         """
         try:
-            client = self._create_client() if self.client is None else self.client
-            _ = [m async for m in client.models.list()]  # Ensure the client is initialized
+            _ = [m async for m in self.client.models.list()]  # Ensure the client is initialized
             return HealthResponse(status=HealthStatus.OK)
         except Exception as e:
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
@@ -351,19 +340,14 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             raise ValueError("Model store not set")
         return await self.model_store.get_model(model_id)

-    def _lazy_initialize_client(self):
-        if self.client is not None:
-            return
-
-        log.info(f"Initializing vLLM client with base_url={self.config.url}")
-        self.client = self._create_client()
-
-    def _create_client(self):
-        return AsyncOpenAI(
-            base_url=self.config.url,
-            api_key=self.config.api_token,
-            http_client=httpx.AsyncClient(verify=self.config.tls_verify),
-        )
+    def get_api_key(self):
+        return self.config.api_token
+
+    def get_base_url(self):
+        return self.config.url
+
+    def get_extra_client_params(self):
+        return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}

     async def completion(
         self,
@@ -374,7 +358,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         stream: bool | None = False,
         logprobs: LogProbConfig | None = None,
     ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
-        self._lazy_initialize_client()
         if sampling_params is None:
             sampling_params = SamplingParams()
         model = await self._get_model(model_id)
@@ -406,7 +389,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         logprobs: LogProbConfig | None = None,
         tool_config: ToolConfig | None = None,
     ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
-        self._lazy_initialize_client()
         if sampling_params is None:
             sampling_params = SamplingParams()
         model = await self._get_model(model_id)
@@ -479,16 +461,12 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             yield chunk

     async def register_model(self, model: Model) -> Model:
-        # register_model is called during Llama Stack initialization, hence we cannot init self.client if not initialized yet.
-        # self.client should only be created after the initialization is complete to avoid asyncio cross-context errors.
-        # Changing this may lead to unpredictable behavior.
-        client = self._create_client() if self.client is None else self.client
         try:
             model = await self.register_helper.register_model(model)
         except ValueError:
             pass  # Ignore statically unknown model, will check live listing
         try:
-            res = await client.models.list()
+            res = await self.client.models.list()
         except APIConnectionError as e:
             raise ValueError(
                 f"Failed to connect to vLLM at {self.config.url}. Please check if vLLM is running and accessible at that URL."
@@ -543,8 +521,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         output_dimension: int | None = None,
         task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
-        self._lazy_initialize_client()
-        assert self.client is not None
         model = await self._get_model(model_id)

         kwargs = {}
@@ -560,154 +536,3 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         embeddings = [data.embedding for data in response.data]
         return EmbeddingsResponse(embeddings=embeddings)
-
-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        self._lazy_initialize_client()
-        assert self.client is not None
-        model_obj = await self._get_model(model)
-        assert model_obj.model_type == ModelType.embedding
-
-        # Convert input to list if it's a string
-        input_list = [input] if isinstance(input, str) else input
-
-        # Call vLLM embeddings endpoint with encoding_format
-        response = await self.client.embeddings.create(
-            model=model_obj.provider_resource_id,
-            input=input_list,
-            dimensions=dimensions,
-            encoding_format=encoding_format,
-        )
-
-        # Convert response to OpenAI format
-        data = [
-            OpenAIEmbeddingData(
-                embedding=embedding_data.embedding,
-                index=i,
-            )
-            for i, embedding_data in enumerate(response.data)
-        ]
-
-        # Not returning actual token usage since vLLM doesn't provide it
-        usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
-
-        return OpenAIEmbeddingsResponse(
-            data=data,
-            model=model_obj.provider_resource_id,
-            usage=usage,
-        )
-
-    async def openai_completion(
-        self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
-    ) -> OpenAICompletion:
-        self._lazy_initialize_client()
-        model_obj = await self._get_model(model)
-
-        extra_body: dict[str, Any] = {}
-        if prompt_logprobs is not None and prompt_logprobs >= 0:
-            extra_body["prompt_logprobs"] = prompt_logprobs
-        if guided_choice:
-            extra_body["guided_choice"] = guided_choice
-
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-            extra_body=extra_body,
-        )
-        return await self.client.completions.create(**params)  # type: ignore
-
-    async def openai_chat_completion(
-        self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        self._lazy_initialize_client()
-        model_obj = await self._get_model(model)
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
-        return await self.client.chat.completions.create(**params)  # type: ignore
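
Note: with this refactor the adapter no longer builds or caches its own AsyncOpenAI instance; the client property inherited from OpenAIMixin assembles one on demand from the three accessors added above (get_api_key, get_base_url, get_extra_client_params). A minimal illustrative sketch of that pattern, using a toy class rather than the real Llama Stack types:

import httpx
from openai import AsyncOpenAI


class ExampleAdapter:
    # Toy stand-in for an OpenAIMixin-based adapter; the config fields are assumptions.
    def __init__(self, url: str, api_token: str, tls_verify: bool = True) -> None:
        self.url = url
        self.api_token = api_token
        self.tls_verify = tls_verify

    def get_api_key(self) -> str:
        return self.api_token

    def get_base_url(self) -> str:
        return self.url

    def get_extra_client_params(self) -> dict:
        # TLS verification is threaded through via a custom httpx client.
        return {"http_client": httpx.AsyncClient(verify=self.tls_verify)}

    @property
    def client(self) -> AsyncOpenAI:
        # Mirrors what OpenAIMixin.client does: build the client from the hooks above.
        return AsyncOpenAI(
            api_key=self.get_api_key(),
            base_url=self.get_base_url(),
            **self.get_extra_client_params(),
        )

The mixin's client property (see the next file) simply constructs an AsyncOpenAI from these hooks, which is why the old _lazy_initialize_client/_create_client bookkeeping and the self.client attribute could be dropped.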


@@ -67,6 +67,17 @@ class OpenAIMixin(ABC):
         """
         pass

+    def get_extra_client_params(self) -> dict[str, Any]:
+        """
+        Get any extra parameters to pass to the AsyncOpenAI client.
+
+        Child classes can override this method to provide additional parameters
+        such as timeout settings, proxies, etc.
+
+        :return: A dictionary of extra parameters
+        """
+        return {}
+
     @property
     def client(self) -> AsyncOpenAI:
         """
@@ -78,6 +89,7 @@
         return AsyncOpenAI(
             api_key=self.get_api_key(),
             base_url=self.get_base_url(),
+            **self.get_extra_client_params(),
         )

     async def _get_provider_model_id(self, model: str) -> str:
@@ -124,10 +136,15 @@
         """
         Direct OpenAI completion API call.
         """
-        if guided_choice is not None:
-            logger.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
-        if prompt_logprobs is not None:
-            logger.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
+        # Handle parameters that are not supported by OpenAI API, but may be by the provider
+        #  prompt_logprobs is supported by vLLM
+        #  guided_choice is supported by vLLM
+        # TODO: test coverage
+        extra_body: dict[str, Any] = {}
+        if prompt_logprobs is not None and prompt_logprobs >= 0:
+            extra_body["prompt_logprobs"] = prompt_logprobs
+        if guided_choice:
+            extra_body["guided_choice"] = guided_choice

         # TODO: fix openai_completion to return type compatible with OpenAI's API response
         return await self.client.completions.create(  # type: ignore[no-any-return]
@@ -150,7 +167,8 @@
                 top_p=top_p,
                 user=user,
                 suffix=suffix,
-            )
+            ),
+            extra_body=extra_body,
         )

     async def openai_chat_completion(
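
The vLLM-only options are kept by routing them through extra_body, which the OpenAI Python client forwards verbatim in the request payload rather than rejecting them as unknown keyword arguments. A hedged, self-contained sketch of the same idea against any OpenAI-compatible vLLM endpoint (the URL and model name below are placeholders):

import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    # Placeholder endpoint and model; a local vLLM server usually serves /v1.
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
    completion = await client.completions.create(
        model="my-vllm-model",
        prompt="The capital of France is",
        max_tokens=5,
        # Fields the OpenAI API does not define are passed via extra_body;
        # vLLM interprets them as its own extensions.
        extra_body={"guided_choice": ["Paris", "London"], "prompt_logprobs": 0},
    )
    print(completion.choices[0].text)


asyncio.run(main())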


@@ -11,7 +11,7 @@ import threading
 import time
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from typing import Any
-from unittest.mock import AsyncMock, MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch

 import pytest
 from openai.types.chat.chat_completion_chunk import (
@@ -150,10 +150,12 @@ async def test_tool_call_response(vllm_inference_adapter):
     """Verify that tool call arguments from a CompletionMessage are correctly converted
     into the expected JSON format."""

-    # Patch the call to vllm so we can inspect the arguments sent were correct
-    with patch.object(
-        vllm_inference_adapter.client.chat.completions, "create", new_callable=AsyncMock
-    ) as mock_nonstream_completion:
+    # Patch the client property to avoid instantiating a real AsyncOpenAI client
+    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock()
+        mock_create_client.return_value = mock_client
+
         messages = [
             SystemMessage(content="You are a helpful assistant"),
             UserMessage(content="How many?"),
@@ -179,7 +181,7 @@
             tool_config=ToolConfig(tool_choice=ToolChoice.auto),
         )

-        assert mock_nonstream_completion.call_args.kwargs["messages"][2]["tool_calls"] == [
+        assert mock_client.chat.completions.create.call_args.kwargs["messages"][2]["tool_calls"] == [
             {
                 "id": "foo",
                 "type": "function",
@@ -641,9 +643,7 @@ async def test_health_status_success(vllm_inference_adapter):
     This test verifies that the health method returns a HealthResponse with status OK, only
     when the connection to the vLLM server is successful.
     """
-    # Set vllm_inference_adapter.client to None to ensure _create_client is called
-    vllm_inference_adapter.client = None
-    with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client:
+    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
         # Create mock client and models
         mock_client = MagicMock()
         mock_models = MagicMock()
@@ -674,8 +674,7 @@
     This test verifies that the health method returns a HealthResponse with status ERROR
     and an appropriate error message when the connection to the vLLM server fails.
     """
-    vllm_inference_adapter.client = None
-    with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client:
+    with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
        # Create mock client and models
         mock_client = MagicMock()
         mock_models = MagicMock()
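
Since client is now a property defined on the class, the tests patch it at the class level with PropertyMock instead of assigning to the instance. A condensed, self-contained sketch of the pattern outside the Llama Stack fixtures (the class and method names here are illustrative):

import asyncio
from unittest.mock import MagicMock, PropertyMock, patch


class Adapter:
    # Illustrative stand-in; the real adapter's `client` builds an AsyncOpenAI instance.
    @property
    def client(self):
        raise RuntimeError("would construct a real AsyncOpenAI client")

    async def ping(self) -> bool:
        # Same shape as the adapter's health check: force one pass over models.list().
        return bool([m async for m in self.client.models.list()])


async def _fake_models():
    # Async generator standing in for the paginated models.list() iterator.
    yield MagicMock(id="m1")


async def main() -> None:
    # Properties live on the class, so patch the class attribute, not the instance.
    with patch.object(Adapter, "client", new_callable=PropertyMock) as mock_client_prop:
        mock_client = MagicMock()
        mock_client.models.list = MagicMock(return_value=_fake_models())
        mock_client_prop.return_value = mock_client
        assert await Adapter().ping() is True


asyncio.run(main())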