Merge branch 'main' into hide-non-openai-inference-apis

This commit is contained in:
Matthew Farrellee 2025-09-26 17:48:30 -04:00
commit 0e78cd5383
33 changed files with 2394 additions and 1723 deletions

View file

@ -139,18 +139,7 @@ Methods:
- <code title="post /v1/agents/{agent_id}/session/{session_id}/turn">client.agents.turn.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/agents/turn.py">create</a>(session_id, \*, agent_id, \*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/turn_create_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/turn_create_response.py">TurnCreateResponse</a></code> - <code title="post /v1/agents/{agent_id}/session/{session_id}/turn">client.agents.turn.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/agents/turn.py">create</a>(session_id, \*, agent_id, \*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/turn_create_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/turn_create_response.py">TurnCreateResponse</a></code>
- <code title="get /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}">client.agents.turn.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/agents/turn.py">retrieve</a>(turn_id, \*, agent_id, session_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/turn.py">Turn</a></code> - <code title="get /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}">client.agents.turn.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/agents/turn.py">retrieve</a>(turn_id, \*, agent_id, session_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/turn.py">Turn</a></code>
## BatchInference
Types:
```python
from llama_stack_client.types import BatchInferenceChatCompletionResponse
```
Methods:
- <code title="post /v1/batch-inference/chat-completion">client.batch_inference.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/batch_inference.py">chat_completion</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/batch_inference_chat_completion_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/batch_inference_chat_completion_response.py">BatchInferenceChatCompletionResponse</a></code>
- <code title="post /v1/batch-inference/completion">client.batch_inference.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/batch_inference.py">completion</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/batch_inference_completion_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/shared/batch_completion.py">BatchCompletion</a></code>
## Datasets ## Datasets

View file

@ -548,7 +548,6 @@ class Generator:
if op.defining_class.__name__ in [ if op.defining_class.__name__ in [
"SyntheticDataGeneration", "SyntheticDataGeneration",
"PostTraining", "PostTraining",
"BatchInference",
]: ]:
op.defining_class.__name__ = f"{op.defining_class.__name__} (Coming Soon)" op.defining_class.__name__ = f"{op.defining_class.__name__} (Coming Soon)"
print(op.defining_class.__name__) print(op.defining_class.__name__)

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .batch_inference import *

View file

@ -1,79 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Protocol, runtime_checkable
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import (
InterleavedContent,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import webmethod
@runtime_checkable
class BatchInference(Protocol):
"""Batch inference API for generating completions and chat completions.
This is an asynchronous API. If the request is successful, the response will be a job which can be polled for completion.
NOTE: This API is not yet implemented and is subject to change in concert with other asynchronous APIs
including (post-training, evals, etc).
"""
@webmethod(route="/batch-inference/completion", method="POST", level=LLAMA_STACK_API_V1)
async def completion(
self,
model: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> Job:
"""Generate completions for a batch of content.
:param model: The model to use for the completion.
:param content_batch: The content to complete.
:param sampling_params: The sampling parameters to use for the completion.
:param response_format: The response format to use for the completion.
:param logprobs: The logprobs to use for the completion.
:returns: A job for the completion.
"""
...
@webmethod(route="/batch-inference/chat-completion", method="POST", level=LLAMA_STACK_API_V1)
async def chat_completion(
self,
model: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
# zero-shot tool definitions as input to the model
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> Job:
"""Generate chat completions for a batch of messages.
:param model: The model to use for the chat completion.
:param messages_batch: The messages to complete.
:param sampling_params: The sampling parameters to use for the completion.
:param tools: The tools to use for the chat completion.
:param tool_choice: The tool choice to use for the chat completion.
:param tool_prompt_format: The tool prompt format to use for the chat completion.
:param response_format: The response format to use for the chat completion.
:param logprobs: The logprobs to use for the chat completion.
:returns: A job for the chat completion.
"""
...

View file

@ -914,6 +914,7 @@ class OpenAIEmbeddingData(BaseModel):
""" """
object: Literal["embedding"] = "embedding" object: Literal["embedding"] = "embedding"
# TODO: consider dropping str and using openai.types.embeddings.Embedding instead of OpenAIEmbeddingData
embedding: list[float] | str embedding: list[float] | str
index: int index: int
@ -974,26 +975,6 @@ class EmbeddingTaskType(Enum):
document = "document" document = "document"
@json_schema_type
class BatchCompletionResponse(BaseModel):
"""Response from a batch completion request.
:param batch: List of completion responses, one for each input in the batch
"""
batch: list[CompletionResponse]
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
"""Response from a batch chat completion request.
:param batch: List of chat completion responses, one for each conversation in the batch
"""
batch: list[ChatCompletionResponse]
class OpenAICompletionWithInputMessages(OpenAIChatCompletion): class OpenAICompletionWithInputMessages(OpenAIChatCompletion):
input_messages: list[OpenAIMessageParam] input_messages: list[OpenAIMessageParam]
@ -1049,26 +1030,7 @@ class InferenceProvider(Protocol):
""" """
... ...
async def batch_completion( @webmethod(route="/inference/chat-completion", method="POST", level=LLAMA_STACK_API_V1)
self,
model_id: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse:
"""Generate completions for a batch of content using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param content_batch: The content to generate completions for.
:param sampling_params: (Optional) Parameters to control the sampling strategy.
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:returns: A BatchCompletionResponse with the full completions.
"""
raise NotImplementedError("Batch completion is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
async def chat_completion( async def chat_completion(
self, self,
model_id: str, model_id: str,
@ -1108,30 +1070,7 @@ class InferenceProvider(Protocol):
""" """
... ...
async def batch_chat_completion( @webmethod(route="/inference/embeddings", method="POST", level=LLAMA_STACK_API_V1)
self,
model_id: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchChatCompletionResponse:
"""Generate chat completions for a batch of messages using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages_batch: The messages to generate completions for.
:param sampling_params: (Optional) Parameters to control the sampling strategy.
:param tools: (Optional) List of tool definitions available to the model.
:param tool_config: (Optional) Configuration for tool use.
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:returns: A BatchChatCompletionResponse with the full completions.
"""
raise NotImplementedError("Batch chat completion is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
async def embeddings( async def embeddings(
self, self,
model_id: str, model_id: str,

View file

@ -20,8 +20,6 @@ from llama_stack.apis.common.content_types import (
) )
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
BatchChatCompletionResponse,
BatchCompletionResponse,
ChatCompletionResponse, ChatCompletionResponse,
ChatCompletionResponseEventType, ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk, ChatCompletionResponseStreamChunk,
@ -273,30 +271,6 @@ class InferenceRouter(Inference):
) )
return response return response
async def batch_chat_completion(
self,
model_id: str,
messages_batch: list[list[Message]],
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchChatCompletionResponse:
logger.debug(
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = await self.routing_table.get_provider_impl(model_id)
return await provider.batch_chat_completion(
model_id=model_id,
messages_batch=messages_batch,
tools=tools,
tool_config=tool_config,
sampling_params=sampling_params,
response_format=response_format,
logprobs=logprobs,
)
async def completion( async def completion(
self, self,
model_id: str, model_id: str,
@ -338,20 +312,6 @@ class InferenceRouter(Inference):
return response return response
async def batch_completion(
self,
model_id: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse:
logger.debug(
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = await self.routing_table.get_provider_impl(model_id)
return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
async def embeddings( async def embeddings(
self, self,
model_id: str, model_id: str,

View file

@ -9,7 +9,7 @@ from typing import Any
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.errors import ToolGroupNotFoundError from llama_stack.apis.common.errors import ToolGroupNotFoundError
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
from llama_stack.core.datatypes import ToolGroupWithOwner from llama_stack.core.datatypes import AuthenticationRequiredError, ToolGroupWithOwner
from llama_stack.log import get_logger from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl from .common import CommonRoutingTableImpl
@ -54,7 +54,18 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
all_tools = [] all_tools = []
for toolgroup in toolgroups: for toolgroup in toolgroups:
if toolgroup.identifier not in self.toolgroups_to_tools: if toolgroup.identifier not in self.toolgroups_to_tools:
await self._index_tools(toolgroup) try:
await self._index_tools(toolgroup)
except AuthenticationRequiredError:
# Send authentication errors back to the client so it knows
# that it needs to supply credentials for remote MCP servers.
raise
except Exception as e:
# Other errors that the client cannot fix are logged and
# those specific toolgroups are skipped.
logger.warning(f"Error listing tools for toolgroup {toolgroup.identifier}: {e}")
logger.debug(e, exc_info=True)
continue
all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier]) all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])
return ListToolsResponse(data=all_tools) return ListToolsResponse(data=all_tools)

View file

@ -14,7 +14,6 @@ from typing import Any
import yaml import yaml
from llama_stack.apis.agents import Agents from llama_stack.apis.agents import Agents
from llama_stack.apis.batch_inference import BatchInference
from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets from llama_stack.apis.datasets import Datasets
@ -54,7 +53,6 @@ class LlamaStack(
Providers, Providers,
VectorDBs, VectorDBs,
Inference, Inference,
BatchInference,
Agents, Agents,
Safety, Safety,
SyntheticDataGeneration, SyntheticDataGeneration,

View file

@ -96,11 +96,9 @@ class DiskDistributionRegistry(DistributionRegistry):
async def register(self, obj: RoutableObjectWithProvider) -> bool: async def register(self, obj: RoutableObjectWithProvider) -> bool:
existing_obj = await self.get(obj.type, obj.identifier) existing_obj = await self.get(obj.type, obj.identifier)
# warn if the object's providerid is different but proceed with registration # dont register if the object's providerid already exists
if existing_obj and existing_obj.provider_id != obj.provider_id: if existing_obj and existing_obj.provider_id == obj.provider_id:
logger.warning( return False
f"Object {existing_obj.type}:{existing_obj.identifier}'s {existing_obj.provider_id} provider is being replaced with {obj.provider_id}"
)
await self.kvstore.set( await self.kvstore.set(
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier), KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),

View file

@ -18,8 +18,6 @@ from llama_stack.apis.common.content_types import (
ToolCallParseStatus, ToolCallParseStatus,
) )
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
BatchChatCompletionResponse,
BatchCompletionResponse,
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionResponse, ChatCompletionResponse,
ChatCompletionResponseEvent, ChatCompletionResponseEvent,
@ -219,41 +217,6 @@ class MetaReferenceInferenceImpl(
results = await self._nonstream_completion([request]) results = await self._nonstream_completion([request])
return results[0] return results[0]
async def batch_completion(
self,
model_id: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
content_batch = [
augment_content_with_response_format_prompt(response_format, content) for content in content_batch
]
request_batch = []
for content in content_batch:
request = CompletionRequest(
model=model_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
self.check_model(request)
request = await convert_request_to_raw(request)
request_batch.append(request)
results = await self._nonstream_completion(request_batch)
return BatchCompletionResponse(batch=results)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
tokenizer = self.generator.formatter.tokenizer tokenizer = self.generator.formatter.tokenizer
@ -399,49 +362,6 @@ class MetaReferenceInferenceImpl(
results = await self._nonstream_chat_completion([request]) results = await self._nonstream_chat_completion([request])
return results[0] return results[0]
async def batch_chat_completion(
self,
model_id: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> BatchChatCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request_batch = []
for messages in messages_batch:
request = ChatCompletionRequest(
model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
logprobs=logprobs,
tool_config=tool_config or ToolConfig(),
)
self.check_model(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
request_batch.append(request)
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
results = await self._nonstream_chat_completion(request_batch)
return BatchChatCompletionResponse(batch=results)
async def _nonstream_chat_completion( async def _nonstream_chat_completion(
self, request_batch: list[ChatCompletionRequest] self, request_batch: list[ChatCompletionRequest]
) -> list[ChatCompletionResponse]: ) -> list[ChatCompletionResponse]:

View file

@ -24,7 +24,6 @@ from llama_stack.apis.inference import (
LogProbConfig, LogProbConfig,
Message, Message,
Model, Model,
ModelType,
OpenAICompletion, OpenAICompletion,
ResponseFormat, ResponseFormat,
SamplingParams, SamplingParams,
@ -34,6 +33,7 @@ from llama_stack.apis.inference import (
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

View file

@ -64,6 +64,7 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
} }
def __init__(self, config: FireworksImplConfig) -> None: def __init__(self, config: FireworksImplConfig) -> None:
ModelRegistryHelper.__init__(self)
self.config = config self.config = config
self.allowed_models = config.allowed_models self.allowed_models = config.allowed_models

View file

@ -4,12 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.apis.inference import Inference
from .config import GroqConfig from .config import GroqConfig
async def get_adapter_impl(config: GroqConfig, _deps) -> Inference: async def get_adapter_impl(config: GroqConfig, _deps):
# import dynamically so the import is used only when it is needed # import dynamically so the import is used only when it is needed
from .groq import GroqInferenceAdapter from .groq import GroqInferenceAdapter

View file

@ -6,8 +6,7 @@
import asyncio import asyncio
import base64 from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any from typing import Any
from ollama import AsyncClient as AsyncOllamaClient from ollama import AsyncClient as AsyncOllamaClient
@ -33,10 +32,6 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat, JsonSchemaResponseFormat,
LogProbConfig, LogProbConfig,
Message, Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat, ResponseFormat,
SamplingParams, SamplingParams,
TextTruncation, TextTruncation,
@ -62,7 +57,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice, OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse, OpenAICompatCompletionResponse,
get_sampling_options, get_sampling_options,
prepare_openai_completion_params,
process_chat_completion_response, process_chat_completion_response,
process_chat_completion_stream_response, process_chat_completion_stream_response,
process_completion_response, process_completion_response,
@ -75,7 +69,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
content_has_media, content_has_media,
convert_image_content_to_url, convert_image_content_to_url,
interleaved_content_as_str, interleaved_content_as_str,
localize_image_content,
request_has_media, request_has_media,
) )
@ -84,6 +77,7 @@ logger = get_logger(name=__name__, category="inference::ollama")
class OllamaInferenceAdapter( class OllamaInferenceAdapter(
OpenAIMixin, OpenAIMixin,
ModelRegistryHelper,
InferenceProvider, InferenceProvider,
ModelsProtocolPrivate, ModelsProtocolPrivate,
): ):
@ -129,6 +123,8 @@ class OllamaInferenceAdapter(
], ],
) )
self.config = config self.config = config
# Ollama does not support image urls, so we need to download the image and convert it to base64
self.download_images = True
self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {} self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
@property @property
@ -173,9 +169,6 @@ class OllamaInferenceAdapter(
async def shutdown(self) -> None: async def shutdown(self) -> None:
self._clients.clear() self._clients.clear()
async def unregister_model(self, model_id: str) -> None:
pass
async def _get_model(self, model_id: str) -> Model: async def _get_model(self, model_id: str) -> Model:
if not self.model_store: if not self.model_store:
raise ValueError("Model store not set") raise ValueError("Model store not set")
@ -403,75 +396,6 @@ class OllamaInferenceAdapter(
raise UnsupportedModelError(model.provider_model_id, list(self._model_cache.keys())) raise UnsupportedModelError(model.provider_model_id, list(self._model_cache.keys()))
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self._get_model(model)
# Ollama does not support image urls, so we need to download the image and convert it to base64
async def _convert_message(m: OpenAIMessageParam) -> OpenAIMessageParam:
if isinstance(m.content, list):
for c in m.content:
if c.type == "image_url" and c.image_url and c.image_url.url:
localize_result = await localize_image_content(c.image_url.url)
if localize_result is None:
raise ValueError(f"Failed to localize image content from {c.image_url.url}")
content, format = localize_result
c.image_url.url = f"data:image/{format};base64,{base64.b64encode(content).decode('utf-8')}"
return m
messages = [await _convert_message(m) for m in messages]
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await OpenAIMixin.openai_chat_completion(self, **params)
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]: async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
async def _convert_content(content) -> dict: async def _convert_content(content) -> dict:

View file

@ -21,8 +21,6 @@ logger = get_logger(name=__name__, category="inference::openai")
# | completion | LiteLLMOpenAIMixin | # | completion | LiteLLMOpenAIMixin |
# | chat_completion | LiteLLMOpenAIMixin | # | chat_completion | LiteLLMOpenAIMixin |
# | embedding | LiteLLMOpenAIMixin | # | embedding | LiteLLMOpenAIMixin |
# | batch_completion | LiteLLMOpenAIMixin |
# | batch_chat_completion | LiteLLMOpenAIMixin |
# | openai_completion | OpenAIMixin | # | openai_completion | OpenAIMixin |
# | openai_chat_completion | OpenAIMixin | # | openai_chat_completion | OpenAIMixin |
# | openai_embeddings | OpenAIMixin | # | openai_embeddings | OpenAIMixin |

View file

@ -4,12 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.apis.inference import Inference
from .config import SambaNovaImplConfig from .config import SambaNovaImplConfig
async def get_adapter_impl(config: SambaNovaImplConfig, _deps) -> Inference: async def get_adapter_impl(config: SambaNovaImplConfig, _deps):
from .sambanova import SambaNovaInferenceAdapter from .sambanova import SambaNovaInferenceAdapter
assert isinstance(config, SambaNovaImplConfig), f"Unexpected config type: {type(config)}" assert isinstance(config, SambaNovaImplConfig), f"Unexpected config type: {type(config)}"

View file

@ -25,7 +25,7 @@ class SambaNovaInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
def __init__(self, config: SambaNovaImplConfig): def __init__(self, config: SambaNovaImplConfig):
self.config = config self.config = config
self.environment_available_models = [] self.environment_available_models: list[str] = []
LiteLLMOpenAIMixin.__init__( LiteLLMOpenAIMixin.__init__(
self, self,
litellm_provider_name="sambanova", litellm_provider_name="sambanova",

View file

@ -70,6 +70,7 @@ class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Need
} }
def __init__(self, config: TogetherImplConfig) -> None: def __init__(self, config: TogetherImplConfig) -> None:
ModelRegistryHelper.__init__(self)
self.config = config self.config = config
self.allowed_models = config.allowed_models self.allowed_models = config.allowed_models
self._model_cache: dict[str, Model] = {} self._model_cache: dict[str, Model] = {}

View file

@ -20,7 +20,7 @@ logger = get_logger(name=__name__, category="providers::utils")
class RemoteInferenceProviderConfig(BaseModel): class RemoteInferenceProviderConfig(BaseModel):
allowed_models: list[str] | None = Field( allowed_models: list[str] | None = Field( # TODO: make this non-optional and give a list() default
default=None, default=None,
description="List of models that should be registered with the model registry. If None, all models are allowed.", description="List of models that should be registered with the model registry. If None, all models are allowed.",
) )

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import base64
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
@ -26,6 +27,7 @@ from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content
logger = get_logger(name=__name__, category="providers::utils") logger = get_logger(name=__name__, category="providers::utils")
@ -51,6 +53,10 @@ class OpenAIMixin(ModelRegistryHelper, ABC):
# This is useful for providers that do not return a unique id in the response. # This is useful for providers that do not return a unique id in the response.
overwrite_completion_id: bool = False overwrite_completion_id: bool = False
# Allow subclasses to control whether to download images and convert to base64
# for providers that require base64 encoded images instead of URLs.
download_images: bool = False
# Embedding model metadata for this provider # Embedding model metadata for this provider
# Can be set by subclasses or instances to provide embedding models # Can be set by subclasses or instances to provide embedding models
# Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}} # Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}
@ -239,6 +245,24 @@ class OpenAIMixin(ModelRegistryHelper, ABC):
""" """
Direct OpenAI chat completion API call. Direct OpenAI chat completion API call.
""" """
if self.download_images:
async def _localize_image_url(m: OpenAIMessageParam) -> OpenAIMessageParam:
if isinstance(m.content, list):
for c in m.content:
if c.type == "image_url" and c.image_url and c.image_url.url and "http" in c.image_url.url:
localize_result = await localize_image_content(c.image_url.url)
if localize_result is None:
raise ValueError(
f"Failed to localize image content from {c.image_url.url[:42]}{'...' if len(c.image_url.url) > 42 else ''}"
)
content, format = localize_result
c.image_url.url = f"data:image/{format};base64,{base64.b64encode(content).decode('utf-8')}"
# else it's a string and we don't need to modify it
return m
messages = [await _localize_image_url(m) for m in messages]
resp = await self.client.chat.completions.create( resp = await self.client.chat.completions.create(
**await prepare_openai_completion_params( **await prepare_openai_completion_params(
model=await self._get_provider_model_id(model), model=await self._get_provider_model_id(model),

View file

@ -28,7 +28,7 @@ class CommonConfig(BaseModel):
class RedisKVStoreConfig(CommonConfig): class RedisKVStoreConfig(CommonConfig):
type: Literal[KVStoreType.redis.value] = KVStoreType.redis.value type: Literal["redis"] = KVStoreType.redis.value
host: str = "localhost" host: str = "localhost"
port: int = 6379 port: int = 6379
@ -50,7 +50,7 @@ class RedisKVStoreConfig(CommonConfig):
class SqliteKVStoreConfig(CommonConfig): class SqliteKVStoreConfig(CommonConfig):
type: Literal[KVStoreType.sqlite.value] = KVStoreType.sqlite.value type: Literal["sqlite"] = KVStoreType.sqlite.value
db_path: str = Field( db_path: str = Field(
default=(RUNTIME_BASE_DIR / "kvstore.db").as_posix(), default=(RUNTIME_BASE_DIR / "kvstore.db").as_posix(),
description="File path for the sqlite database", description="File path for the sqlite database",
@ -69,7 +69,7 @@ class SqliteKVStoreConfig(CommonConfig):
class PostgresKVStoreConfig(CommonConfig): class PostgresKVStoreConfig(CommonConfig):
type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value type: Literal["postgres"] = KVStoreType.postgres.value
host: str = "localhost" host: str = "localhost"
port: int = 5432 port: int = 5432
db: str = "llamastack" db: str = "llamastack"
@ -113,11 +113,11 @@ class PostgresKVStoreConfig(CommonConfig):
class MongoDBKVStoreConfig(CommonConfig): class MongoDBKVStoreConfig(CommonConfig):
type: Literal[KVStoreType.mongodb.value] = KVStoreType.mongodb.value type: Literal["mongodb"] = KVStoreType.mongodb.value
host: str = "localhost" host: str = "localhost"
port: int = 27017 port: int = 27017
db: str = "llamastack" db: str = "llamastack"
user: str = None user: str | None = None
password: str | None = None password: str | None = None
collection_name: str = "llamastack_kvstore" collection_name: str = "llamastack_kvstore"

View file

@ -7,6 +7,7 @@
from datetime import datetime from datetime import datetime
from pymongo import AsyncMongoClient from pymongo import AsyncMongoClient
from pymongo.asynchronous.collection import AsyncCollection
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.kvstore import KVStore
@ -19,8 +20,13 @@ log = get_logger(name=__name__, category="providers::utils")
class MongoDBKVStoreImpl(KVStore): class MongoDBKVStoreImpl(KVStore):
def __init__(self, config: MongoDBKVStoreConfig): def __init__(self, config: MongoDBKVStoreConfig):
self.config = config self.config = config
self.conn = None self.conn: AsyncMongoClient | None = None
self.collection = None
@property
def collection(self) -> AsyncCollection:
if self.conn is None:
raise RuntimeError("MongoDB connection is not initialized")
return self.conn[self.config.db][self.config.collection_name]
async def initialize(self) -> None: async def initialize(self) -> None:
try: try:
@ -32,7 +38,6 @@ class MongoDBKVStoreImpl(KVStore):
} }
conn_creds = {k: v for k, v in conn_creds.items() if v is not None} conn_creds = {k: v for k, v in conn_creds.items() if v is not None}
self.conn = AsyncMongoClient(**conn_creds) self.conn = AsyncMongoClient(**conn_creds)
self.collection = self.conn[self.config.db][self.config.collection_name]
except Exception as e: except Exception as e:
log.exception("Could not connect to MongoDB database server") log.exception("Could not connect to MongoDB database server")
raise RuntimeError("Could not connect to MongoDB database server") from e raise RuntimeError("Could not connect to MongoDB database server") from e

View file

@ -9,9 +9,13 @@ from datetime import datetime
import aiosqlite import aiosqlite
from llama_stack.log import get_logger
from ..api import KVStore from ..api import KVStore
from ..config import SqliteKVStoreConfig from ..config import SqliteKVStoreConfig
logger = get_logger(name=__name__, category="providers::utils")
class SqliteKVStoreImpl(KVStore): class SqliteKVStoreImpl(KVStore):
def __init__(self, config: SqliteKVStoreConfig): def __init__(self, config: SqliteKVStoreConfig):
@ -50,6 +54,9 @@ class SqliteKVStoreImpl(KVStore):
if row is None: if row is None:
return None return None
value, expiration = row value, expiration = row
if not isinstance(value, str):
logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None")
return None
return value return value
async def delete(self, key: str) -> None: async def delete(self, key: str) -> None:

View file

@ -18,7 +18,7 @@
"class-variance-authority": "^0.7.1", "class-variance-authority": "^0.7.1",
"clsx": "^2.1.1", "clsx": "^2.1.1",
"framer-motion": "^12.23.12", "framer-motion": "^12.23.12",
"llama-stack-client": "^0.2.22", "llama-stack-client": "^0.2.23",
"lucide-react": "^0.542.0", "lucide-react": "^0.542.0",
"next": "15.5.3", "next": "15.5.3",
"next-auth": "^4.24.11", "next-auth": "^4.24.11",
@ -10172,9 +10172,9 @@
"license": "MIT" "license": "MIT"
}, },
"node_modules/llama-stack-client": { "node_modules/llama-stack-client": {
"version": "0.2.22", "version": "0.2.23",
"resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.22.tgz", "resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.23.tgz",
"integrity": "sha512-7aW3UQj5MwjV73Brd+yQ1e4W1W33nhozyeHM5tzOgbsVZ88tL78JNiNvyFqDR5w6V9XO4/uSGGiQVG6v83yR4w==", "integrity": "sha512-J3YFH1HW2K70capejQxGlCyTgKdfx+sQf8Ab+HFi1j2Q00KtpHXB79RxejvBxjWC3X2E++P9iU57KdU2Tp/rIQ==",
"license": "MIT", "license": "MIT",
"dependencies": { "dependencies": {
"@types/node": "^18.11.18", "@types/node": "^18.11.18",

View file

@ -23,7 +23,7 @@
"class-variance-authority": "^0.7.1", "class-variance-authority": "^0.7.1",
"clsx": "^2.1.1", "clsx": "^2.1.1",
"framer-motion": "^12.23.12", "framer-motion": "^12.23.12",
"llama-stack-client": "^0.2.22", "llama-stack-client": "^0.2.23",
"lucide-react": "^0.542.0", "lucide-react": "^0.542.0",
"next": "15.5.3", "next": "15.5.3",
"next-auth": "^4.24.11", "next-auth": "^4.24.11",

View file

@ -7,7 +7,7 @@ required-version = ">=0.7.0"
[project] [project]
name = "llama_stack" name = "llama_stack"
version = "0.2.22" version = "0.2.23"
authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }] authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
description = "Llama Stack" description = "Llama Stack"
readme = "README.md" readme = "README.md"
@ -31,7 +31,7 @@ dependencies = [
"huggingface-hub>=0.34.0,<1.0", "huggingface-hub>=0.34.0,<1.0",
"jinja2>=3.1.6", "jinja2>=3.1.6",
"jsonschema", "jsonschema",
"llama-stack-client>=0.2.22", "llama-stack-client>=0.2.23",
"openai>=1.100.0", # for expires_after support "openai>=1.100.0", # for expires_after support
"prompt-toolkit", "prompt-toolkit",
"python-dotenv", "python-dotenv",
@ -55,7 +55,7 @@ dependencies = [
ui = [ ui = [
"streamlit", "streamlit",
"pandas", "pandas",
"llama-stack-client>=0.2.22", "llama-stack-client>=0.2.23",
"streamlit-option-menu", "streamlit-option-menu",
] ]
@ -259,15 +259,12 @@ exclude = [
"^llama_stack/models/llama/llama3/tokenizer\\.py$", "^llama_stack/models/llama/llama3/tokenizer\\.py$",
"^llama_stack/models/llama/llama3/tool_utils\\.py$", "^llama_stack/models/llama/llama3/tool_utils\\.py$",
"^llama_stack/providers/inline/agents/meta_reference/", "^llama_stack/providers/inline/agents/meta_reference/",
"^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
"^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
"^llama_stack/providers/inline/datasetio/localfs/", "^llama_stack/providers/inline/datasetio/localfs/",
"^llama_stack/providers/inline/eval/meta_reference/eval\\.py$", "^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/inference\\.py$", "^llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
"^llama_stack/models/llama/llama3/generation\\.py$", "^llama_stack/models/llama/llama3/generation\\.py$",
"^llama_stack/models/llama/llama3/multimodal/model\\.py$", "^llama_stack/models/llama/llama3/multimodal/model\\.py$",
"^llama_stack/models/llama/llama4/", "^llama_stack/models/llama/llama4/",
"^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
"^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$", "^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
"^llama_stack/providers/inline/post_training/common/validator\\.py$", "^llama_stack/providers/inline/post_training/common/validator\\.py$",
"^llama_stack/providers/inline/safety/code_scanner/", "^llama_stack/providers/inline/safety/code_scanner/",
@ -278,19 +275,13 @@ exclude = [
"^llama_stack/providers/remote/agents/sample/", "^llama_stack/providers/remote/agents/sample/",
"^llama_stack/providers/remote/datasetio/huggingface/", "^llama_stack/providers/remote/datasetio/huggingface/",
"^llama_stack/providers/remote/datasetio/nvidia/", "^llama_stack/providers/remote/datasetio/nvidia/",
"^llama_stack/providers/remote/inference/anthropic/",
"^llama_stack/providers/remote/inference/bedrock/", "^llama_stack/providers/remote/inference/bedrock/",
"^llama_stack/providers/remote/inference/cerebras/", "^llama_stack/providers/remote/inference/cerebras/",
"^llama_stack/providers/remote/inference/databricks/", "^llama_stack/providers/remote/inference/databricks/",
"^llama_stack/providers/remote/inference/fireworks/", "^llama_stack/providers/remote/inference/fireworks/",
"^llama_stack/providers/remote/inference/gemini/",
"^llama_stack/providers/remote/inference/groq/",
"^llama_stack/providers/remote/inference/nvidia/", "^llama_stack/providers/remote/inference/nvidia/",
"^llama_stack/providers/remote/inference/openai/",
"^llama_stack/providers/remote/inference/passthrough/", "^llama_stack/providers/remote/inference/passthrough/",
"^llama_stack/providers/remote/inference/runpod/", "^llama_stack/providers/remote/inference/runpod/",
"^llama_stack/providers/remote/inference/sambanova/",
"^llama_stack/providers/remote/inference/sample/",
"^llama_stack/providers/remote/inference/tgi/", "^llama_stack/providers/remote/inference/tgi/",
"^llama_stack/providers/remote/inference/together/", "^llama_stack/providers/remote/inference/together/",
"^llama_stack/providers/remote/inference/watsonx/", "^llama_stack/providers/remote/inference/watsonx/",
@ -310,7 +301,6 @@ exclude = [
"^llama_stack/providers/remote/vector_io/qdrant/", "^llama_stack/providers/remote/vector_io/qdrant/",
"^llama_stack/providers/remote/vector_io/sample/", "^llama_stack/providers/remote/vector_io/sample/",
"^llama_stack/providers/remote/vector_io/weaviate/", "^llama_stack/providers/remote/vector_io/weaviate/",
"^llama_stack/providers/tests/conftest\\.py$",
"^llama_stack/providers/utils/bedrock/client\\.py$", "^llama_stack/providers/utils/bedrock/client\\.py$",
"^llama_stack/providers/utils/bedrock/refreshable_boto_session\\.py$", "^llama_stack/providers/utils/bedrock/refreshable_boto_session\\.py$",
"^llama_stack/providers/utils/inference/embedding_mixin\\.py$", "^llama_stack/providers/utils/inference/embedding_mixin\\.py$",
@ -318,12 +308,9 @@ exclude = [
"^llama_stack/providers/utils/inference/model_registry\\.py$", "^llama_stack/providers/utils/inference/model_registry\\.py$",
"^llama_stack/providers/utils/inference/openai_compat\\.py$", "^llama_stack/providers/utils/inference/openai_compat\\.py$",
"^llama_stack/providers/utils/inference/prompt_adapter\\.py$", "^llama_stack/providers/utils/inference/prompt_adapter\\.py$",
"^llama_stack/providers/utils/kvstore/config\\.py$",
"^llama_stack/providers/utils/kvstore/kvstore\\.py$", "^llama_stack/providers/utils/kvstore/kvstore\\.py$",
"^llama_stack/providers/utils/kvstore/mongodb/mongodb\\.py$",
"^llama_stack/providers/utils/kvstore/postgres/postgres\\.py$", "^llama_stack/providers/utils/kvstore/postgres/postgres\\.py$",
"^llama_stack/providers/utils/kvstore/redis/redis\\.py$", "^llama_stack/providers/utils/kvstore/redis/redis\\.py$",
"^llama_stack/providers/utils/kvstore/sqlite/sqlite\\.py$",
"^llama_stack/providers/utils/memory/vector_store\\.py$", "^llama_stack/providers/utils/memory/vector_store\\.py$",
"^llama_stack/providers/utils/scoring/aggregation_utils\\.py$", "^llama_stack/providers/utils/scoring/aggregation_utils\\.py$",
"^llama_stack/providers/utils/scoring/base_scoring_fn\\.py$", "^llama_stack/providers/utils/scoring/base_scoring_fn\\.py$",
@ -331,13 +318,6 @@ exclude = [
"^llama_stack/providers/utils/telemetry/trace_protocol\\.py$", "^llama_stack/providers/utils/telemetry/trace_protocol\\.py$",
"^llama_stack/providers/utils/telemetry/tracing\\.py$", "^llama_stack/providers/utils/telemetry/tracing\\.py$",
"^llama_stack/strong_typing/auxiliary\\.py$", "^llama_stack/strong_typing/auxiliary\\.py$",
"^llama_stack/strong_typing/deserializer\\.py$",
"^llama_stack/strong_typing/inspection\\.py$",
"^llama_stack/strong_typing/schema\\.py$",
"^llama_stack/strong_typing/serializer\\.py$",
"^llama_stack/distributions/groq/groq\\.py$",
"^llama_stack/distributions/llama_api/llama_api\\.py$",
"^llama_stack/distributions/sambanova/sambanova\\.py$",
"^llama_stack/distributions/template\\.py$", "^llama_stack/distributions/template\\.py$",
] ]

View file

@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import pathlib
import pytest
@pytest.fixture
def image_path():
return pathlib.Path(__file__).parent / "dog.png"
@pytest.fixture
def base64_image_data(image_path):
return base64.b64encode(image_path.read_bytes()).decode("utf-8")
async def test_openai_chat_completion_image_url(openai_client, vision_model_id):
message = {
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/integration/inference/dog.png"
},
},
{
"type": "text",
"text": "Describe what is in this image.",
},
],
}
response = openai_client.chat.completions.create(
model=vision_model_id,
messages=[message],
stream=False,
)
message_content = response.choices[0].message.content.lower().strip()
assert len(message_content) > 0
assert any(expected in message_content for expected in {"dog", "puppy", "pup"})
async def test_openai_chat_completion_image_data(openai_client, vision_model_id, base64_image_data):
message = {
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{base64_image_data}",
},
},
{
"type": "text",
"text": "Describe what is in this image.",
},
],
}
response = openai_client.chat.completions.create(
model=vision_model_id,
messages=[message],
stream=False,
)
message_content = response.choices[0].message.content.lower().strip()
assert len(message_content) > 0
assert any(expected in message_content for expected in {"dog", "puppy", "pup"})

View file

@ -10,6 +10,7 @@ from unittest.mock import AsyncMock
import pytest import pytest
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.datatypes import Api from llama_stack.apis.datatypes import Api
@ -645,3 +646,25 @@ async def test_models_source_interaction_cleanup_provider_models(cached_disk_dis
# Cleanup # Cleanup
await table.shutdown() await table.shutdown()
async def test_tool_groups_routing_table_exception_handling(cached_disk_dist_registry):
"""Test that the tool group routing table handles exceptions when listing tools, like if an MCP server is unreachable."""
exception_throwing_tool_groups_impl = ToolGroupsImpl()
exception_throwing_tool_groups_impl.list_runtime_tools = AsyncMock(side_effect=Exception("Test exception"))
table = ToolGroupsRoutingTable(
{"test_provider": exception_throwing_tool_groups_impl}, cached_disk_dist_registry, {}
)
await table.initialize()
await table.register_tool_group(
toolgroup_id="test-toolgroup-exceptions",
provider_id="test_provider",
mcp_endpoint=URL(uri="http://localhost:8479/foo/bar"),
)
tools = await table.list_tools(toolgroup_id="test-toolgroup-exceptions")
assert len(tools.data) == 0

View file

@ -4,11 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from unittest.mock import MagicMock, PropertyMock, patch from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
import pytest import pytest
from llama_stack.apis.inference import Model from llama_stack.apis.inference import Model, OpenAIUserMessageParam
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -43,8 +43,17 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixin):
@pytest.fixture @pytest.fixture
def mixin(): def mixin():
"""Create a test instance of OpenAIMixin""" """Create a test instance of OpenAIMixin with mocked model_store"""
return OpenAIMixinImpl() mixin_instance = OpenAIMixinImpl()
# just enough to satisfy _get_provider_model_id calls
mock_model_store = MagicMock()
mock_model = MagicMock()
mock_model.provider_resource_id = "test-provider-resource-id"
mock_model_store.get_model = AsyncMock(return_value=mock_model)
mixin_instance.model_store = mock_model_store
return mixin_instance
@pytest.fixture @pytest.fixture
@ -205,6 +214,74 @@ class TestOpenAIMixinCacheBehavior:
assert "final-mock-model-id" in mixin._model_cache assert "final-mock-model-id" in mixin._model_cache
class TestOpenAIMixinImagePreprocessing:
"""Test cases for image preprocessing functionality"""
async def test_openai_chat_completion_with_image_preprocessing_enabled(self, mixin):
"""Test that image URLs are converted to base64 when download_images is True"""
mixin.download_images = True
message = OpenAIUserMessageParam(
role="user",
content=[
{"type": "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}},
],
)
mock_client = MagicMock()
mock_response = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
mock_localize.return_value = (b"fake_image_data", "jpeg")
await mixin.openai_chat_completion(model="test-model", messages=[message])
mock_localize.assert_called_once_with("http://example.com/image.jpg")
mock_client.chat.completions.create.assert_called_once()
call_args = mock_client.chat.completions.create.call_args
processed_messages = call_args[1]["messages"]
assert len(processed_messages) == 1
content = processed_messages[0]["content"]
assert len(content) == 2
assert content[0]["type"] == "text"
assert content[1]["type"] == "image_url"
assert content[1]["image_url"]["url"] == "data:image/jpeg;base64,ZmFrZV9pbWFnZV9kYXRh"
async def test_openai_chat_completion_with_image_preprocessing_disabled(self, mixin):
"""Test that image URLs are not modified when download_images is False"""
mixin.download_images = False # explicitly set to False
message = OpenAIUserMessageParam(
role="user",
content=[
{"type": "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}},
],
)
mock_client = MagicMock()
mock_response = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
await mixin.openai_chat_completion(model="test-model", messages=[message])
mock_localize.assert_not_called()
mock_client.chat.completions.create.assert_called_once()
call_args = mock_client.chat.completions.create.call_args
processed_messages = call_args[1]["messages"]
assert len(processed_messages) == 1
content = processed_messages[0]["content"]
assert len(content) == 2
assert content[1]["image_url"]["url"] == "http://example.com/image.jpg"
class TestOpenAIMixinEmbeddingModelMetadata: class TestOpenAIMixinEmbeddingModelMetadata:
"""Test cases for embedding_model_metadata attribute functionality""" """Test cases for embedding_model_metadata attribute functionality"""

View file

@ -129,7 +129,7 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2") result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
assert result is not None assert result is not None
assert result.embedding_model == duplicate_vector_db.embedding_model # Original values preserved assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
async def test_get_all_objects(cached_disk_dist_registry): async def test_get_all_objects(cached_disk_dist_registry):
@ -174,14 +174,10 @@ async def test_parse_registry_values_error_handling(sqlite_kvstore):
) )
await sqlite_kvstore.set( await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json()
valid_db.model_dump_json(),
) )
await sqlite_kvstore.set( await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"),
"{not valid json",
)
await sqlite_kvstore.set( await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"), KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
@ -216,8 +212,7 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
) )
await sqlite_kvstore.set( await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json()
valid_db.model_dump_json(),
) )
await sqlite_kvstore.set( await sqlite_kvstore.set(

12
uv.lock generated
View file

@ -1749,7 +1749,7 @@ wheels = [
[[package]] [[package]]
name = "llama-stack" name = "llama-stack"
version = "0.2.22" version = "0.2.23"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "aiohttp" }, { name = "aiohttp" },
@ -1885,8 +1885,8 @@ requires-dist = [
{ name = "huggingface-hub", specifier = ">=0.34.0,<1.0" }, { name = "huggingface-hub", specifier = ">=0.34.0,<1.0" },
{ name = "jinja2", specifier = ">=3.1.6" }, { name = "jinja2", specifier = ">=3.1.6" },
{ name = "jsonschema" }, { name = "jsonschema" },
{ name = "llama-stack-client", specifier = ">=0.2.22" }, { name = "llama-stack-client", specifier = ">=0.2.23" },
{ name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.22" }, { name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.23" },
{ name = "openai", specifier = ">=1.100.0" }, { name = "openai", specifier = ">=1.100.0" },
{ name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" }, { name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" },
{ name = "opentelemetry-sdk", specifier = ">=1.30.0" }, { name = "opentelemetry-sdk", specifier = ">=1.30.0" },
@ -1993,7 +1993,7 @@ unit = [
[[package]] [[package]]
name = "llama-stack-client" name = "llama-stack-client"
version = "0.2.22" version = "0.2.23"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "anyio" }, { name = "anyio" },
@ -2012,9 +2012,9 @@ dependencies = [
{ name = "tqdm" }, { name = "tqdm" },
{ name = "typing-extensions" }, { name = "typing-extensions" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/60/80/4260816bfaaa889d515206c9df4906d08d405bf94c9b4d1be399b1923e46/llama_stack_client-0.2.22.tar.gz", hash = "sha256:9a0bc756b91ebd539858eeaf1f231c5e5c6900e1ea4fcced726c6717f3d27ca7", size = 318309, upload-time = "2025-09-16T19:43:33.212Z" } sdist = { url = "https://files.pythonhosted.org/packages/9f/8f/306d5fcf2f97b3a6251219b03c194836a2ff4e0fcc8146c9970e50a72cd3/llama_stack_client-0.2.23.tar.gz", hash = "sha256:68f34e8ac8eea6a73ed9d4977d849992b2d8bd835804d770a11843431cd5bf74", size = 322288, upload-time = "2025-09-26T21:11:08.342Z" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/d1/8e/1ebf6ac0dbb62b81038e856ed00768e283d927b14fcd614e3018a227092b/llama_stack_client-0.2.22-py3-none-any.whl", hash = "sha256:b260d73aec56fcfd8fa601b3b34c2f83c4fbcfb7261a246b02bbdf6c2da184fe", size = 369901, upload-time = "2025-09-16T19:43:32.089Z" }, { url = "https://files.pythonhosted.org/packages/fa/75/3eb58e092a681804013dbec7b7f549d18f55acf6fd6e6b27de7e249766d8/llama_stack_client-0.2.23-py3-none-any.whl", hash = "sha256:eee42c74eee8f218f9455e5a06d5d4be43f8a8c82a7937ef51ce367f916df847", size = 379809, upload-time = "2025-09-26T21:11:06.856Z" },
] ]
[[package]] [[package]]