mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
Merge branch 'main' into hide-non-openai-inference-apis
This commit is contained in:
commit
0e78cd5383
33 changed files with 2394 additions and 1723 deletions
|
@ -139,18 +139,7 @@ Methods:
|
||||||
- <code title="post /v1/agents/{agent_id}/session/{session_id}/turn">client.agents.turn.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/agents/turn.py">create</a>(session_id, \*, agent_id, \*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/turn_create_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/turn_create_response.py">TurnCreateResponse</a></code>
|
- <code title="post /v1/agents/{agent_id}/session/{session_id}/turn">client.agents.turn.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/agents/turn.py">create</a>(session_id, \*, agent_id, \*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/turn_create_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/turn_create_response.py">TurnCreateResponse</a></code>
|
||||||
- <code title="get /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}">client.agents.turn.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/agents/turn.py">retrieve</a>(turn_id, \*, agent_id, session_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/turn.py">Turn</a></code>
|
- <code title="get /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}">client.agents.turn.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/agents/turn.py">retrieve</a>(turn_id, \*, agent_id, session_id) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/agents/turn.py">Turn</a></code>
|
||||||
|
|
||||||
## BatchInference
|
|
||||||
|
|
||||||
Types:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from llama_stack_client.types import BatchInferenceChatCompletionResponse
|
|
||||||
```
|
|
||||||
|
|
||||||
Methods:
|
|
||||||
|
|
||||||
- <code title="post /v1/batch-inference/chat-completion">client.batch_inference.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/batch_inference.py">chat_completion</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/batch_inference_chat_completion_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/batch_inference_chat_completion_response.py">BatchInferenceChatCompletionResponse</a></code>
|
|
||||||
- <code title="post /v1/batch-inference/completion">client.batch_inference.<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/resources/batch_inference.py">completion</a>(\*\*<a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/batch_inference_completion_params.py">params</a>) -> <a href="https://github.com/meta-llama/llama-stack-client-python/tree/main/src/llama_stack_client/types/shared/batch_completion.py">BatchCompletion</a></code>
|
|
||||||
|
|
||||||
## Datasets
|
## Datasets
|
||||||
|
|
||||||
|
|
|
@ -548,7 +548,6 @@ class Generator:
|
||||||
if op.defining_class.__name__ in [
|
if op.defining_class.__name__ in [
|
||||||
"SyntheticDataGeneration",
|
"SyntheticDataGeneration",
|
||||||
"PostTraining",
|
"PostTraining",
|
||||||
"BatchInference",
|
|
||||||
]:
|
]:
|
||||||
op.defining_class.__name__ = f"{op.defining_class.__name__} (Coming Soon)"
|
op.defining_class.__name__ = f"{op.defining_class.__name__} (Coming Soon)"
|
||||||
print(op.defining_class.__name__)
|
print(op.defining_class.__name__)
|
||||||
|
|
1927
docs/static/llama-stack-spec.html
vendored
1927
docs/static/llama-stack-spec.html
vendored
File diff suppressed because it is too large
Load diff
1480
docs/static/llama-stack-spec.yaml
vendored
1480
docs/static/llama-stack-spec.yaml
vendored
File diff suppressed because it is too large
Load diff
|
@ -1,7 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from .batch_inference import *
|
|
|
@ -1,79 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from typing import Protocol, runtime_checkable
|
|
||||||
|
|
||||||
from llama_stack.apis.common.job_types import Job
|
|
||||||
from llama_stack.apis.inference import (
|
|
||||||
InterleavedContent,
|
|
||||||
LogProbConfig,
|
|
||||||
Message,
|
|
||||||
ResponseFormat,
|
|
||||||
SamplingParams,
|
|
||||||
ToolChoice,
|
|
||||||
ToolDefinition,
|
|
||||||
ToolPromptFormat,
|
|
||||||
)
|
|
||||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
|
||||||
from llama_stack.schema_utils import webmethod
|
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
|
||||||
class BatchInference(Protocol):
|
|
||||||
"""Batch inference API for generating completions and chat completions.
|
|
||||||
|
|
||||||
This is an asynchronous API. If the request is successful, the response will be a job which can be polled for completion.
|
|
||||||
|
|
||||||
NOTE: This API is not yet implemented and is subject to change in concert with other asynchronous APIs
|
|
||||||
including (post-training, evals, etc).
|
|
||||||
"""
|
|
||||||
|
|
||||||
@webmethod(route="/batch-inference/completion", method="POST", level=LLAMA_STACK_API_V1)
|
|
||||||
async def completion(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
content_batch: list[InterleavedContent],
|
|
||||||
sampling_params: SamplingParams | None = None,
|
|
||||||
response_format: ResponseFormat | None = None,
|
|
||||||
logprobs: LogProbConfig | None = None,
|
|
||||||
) -> Job:
|
|
||||||
"""Generate completions for a batch of content.
|
|
||||||
|
|
||||||
:param model: The model to use for the completion.
|
|
||||||
:param content_batch: The content to complete.
|
|
||||||
:param sampling_params: The sampling parameters to use for the completion.
|
|
||||||
:param response_format: The response format to use for the completion.
|
|
||||||
:param logprobs: The logprobs to use for the completion.
|
|
||||||
:returns: A job for the completion.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
@webmethod(route="/batch-inference/chat-completion", method="POST", level=LLAMA_STACK_API_V1)
|
|
||||||
async def chat_completion(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
messages_batch: list[list[Message]],
|
|
||||||
sampling_params: SamplingParams | None = None,
|
|
||||||
# zero-shot tool definitions as input to the model
|
|
||||||
tools: list[ToolDefinition] | None = None,
|
|
||||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
|
||||||
tool_prompt_format: ToolPromptFormat | None = None,
|
|
||||||
response_format: ResponseFormat | None = None,
|
|
||||||
logprobs: LogProbConfig | None = None,
|
|
||||||
) -> Job:
|
|
||||||
"""Generate chat completions for a batch of messages.
|
|
||||||
|
|
||||||
:param model: The model to use for the chat completion.
|
|
||||||
:param messages_batch: The messages to complete.
|
|
||||||
:param sampling_params: The sampling parameters to use for the completion.
|
|
||||||
:param tools: The tools to use for the chat completion.
|
|
||||||
:param tool_choice: The tool choice to use for the chat completion.
|
|
||||||
:param tool_prompt_format: The tool prompt format to use for the chat completion.
|
|
||||||
:param response_format: The response format to use for the chat completion.
|
|
||||||
:param logprobs: The logprobs to use for the chat completion.
|
|
||||||
:returns: A job for the chat completion.
|
|
||||||
"""
|
|
||||||
...
|
|
|
@ -914,6 +914,7 @@ class OpenAIEmbeddingData(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
object: Literal["embedding"] = "embedding"
|
object: Literal["embedding"] = "embedding"
|
||||||
|
# TODO: consider dropping str and using openai.types.embeddings.Embedding instead of OpenAIEmbeddingData
|
||||||
embedding: list[float] | str
|
embedding: list[float] | str
|
||||||
index: int
|
index: int
|
||||||
|
|
||||||
|
@ -974,26 +975,6 @@ class EmbeddingTaskType(Enum):
|
||||||
document = "document"
|
document = "document"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class BatchCompletionResponse(BaseModel):
|
|
||||||
"""Response from a batch completion request.
|
|
||||||
|
|
||||||
:param batch: List of completion responses, one for each input in the batch
|
|
||||||
"""
|
|
||||||
|
|
||||||
batch: list[CompletionResponse]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class BatchChatCompletionResponse(BaseModel):
|
|
||||||
"""Response from a batch chat completion request.
|
|
||||||
|
|
||||||
:param batch: List of chat completion responses, one for each conversation in the batch
|
|
||||||
"""
|
|
||||||
|
|
||||||
batch: list[ChatCompletionResponse]
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAICompletionWithInputMessages(OpenAIChatCompletion):
|
class OpenAICompletionWithInputMessages(OpenAIChatCompletion):
|
||||||
input_messages: list[OpenAIMessageParam]
|
input_messages: list[OpenAIMessageParam]
|
||||||
|
|
||||||
|
@ -1049,26 +1030,7 @@ class InferenceProvider(Protocol):
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
async def batch_completion(
|
@webmethod(route="/inference/chat-completion", method="POST", level=LLAMA_STACK_API_V1)
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
content_batch: list[InterleavedContent],
|
|
||||||
sampling_params: SamplingParams | None = None,
|
|
||||||
response_format: ResponseFormat | None = None,
|
|
||||||
logprobs: LogProbConfig | None = None,
|
|
||||||
) -> BatchCompletionResponse:
|
|
||||||
"""Generate completions for a batch of content using the specified model.
|
|
||||||
|
|
||||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
|
||||||
:param content_batch: The content to generate completions for.
|
|
||||||
:param sampling_params: (Optional) Parameters to control the sampling strategy.
|
|
||||||
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
|
|
||||||
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
|
|
||||||
:returns: A BatchCompletionResponse with the full completions.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError("Batch completion is not implemented")
|
|
||||||
return # this is so mypy's safe-super rule will consider the method concrete
|
|
||||||
|
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
@ -1108,30 +1070,7 @@ class InferenceProvider(Protocol):
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
async def batch_chat_completion(
|
@webmethod(route="/inference/embeddings", method="POST", level=LLAMA_STACK_API_V1)
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
messages_batch: list[list[Message]],
|
|
||||||
sampling_params: SamplingParams | None = None,
|
|
||||||
tools: list[ToolDefinition] | None = None,
|
|
||||||
tool_config: ToolConfig | None = None,
|
|
||||||
response_format: ResponseFormat | None = None,
|
|
||||||
logprobs: LogProbConfig | None = None,
|
|
||||||
) -> BatchChatCompletionResponse:
|
|
||||||
"""Generate chat completions for a batch of messages using the specified model.
|
|
||||||
|
|
||||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
|
||||||
:param messages_batch: The messages to generate completions for.
|
|
||||||
:param sampling_params: (Optional) Parameters to control the sampling strategy.
|
|
||||||
:param tools: (Optional) List of tool definitions available to the model.
|
|
||||||
:param tool_config: (Optional) Configuration for tool use.
|
|
||||||
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
|
|
||||||
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
|
|
||||||
:returns: A BatchChatCompletionResponse with the full completions.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError("Batch chat completion is not implemented")
|
|
||||||
return # this is so mypy's safe-super rule will consider the method concrete
|
|
||||||
|
|
||||||
async def embeddings(
|
async def embeddings(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
|
|
@ -20,8 +20,6 @@ from llama_stack.apis.common.content_types import (
|
||||||
)
|
)
|
||||||
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
|
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
BatchChatCompletionResponse,
|
|
||||||
BatchCompletionResponse,
|
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
ChatCompletionResponseEventType,
|
ChatCompletionResponseEventType,
|
||||||
ChatCompletionResponseStreamChunk,
|
ChatCompletionResponseStreamChunk,
|
||||||
|
@ -273,30 +271,6 @@ class InferenceRouter(Inference):
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def batch_chat_completion(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
messages_batch: list[list[Message]],
|
|
||||||
tools: list[ToolDefinition] | None = None,
|
|
||||||
tool_config: ToolConfig | None = None,
|
|
||||||
sampling_params: SamplingParams | None = None,
|
|
||||||
response_format: ResponseFormat | None = None,
|
|
||||||
logprobs: LogProbConfig | None = None,
|
|
||||||
) -> BatchChatCompletionResponse:
|
|
||||||
logger.debug(
|
|
||||||
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
|
|
||||||
)
|
|
||||||
provider = await self.routing_table.get_provider_impl(model_id)
|
|
||||||
return await provider.batch_chat_completion(
|
|
||||||
model_id=model_id,
|
|
||||||
messages_batch=messages_batch,
|
|
||||||
tools=tools,
|
|
||||||
tool_config=tool_config,
|
|
||||||
sampling_params=sampling_params,
|
|
||||||
response_format=response_format,
|
|
||||||
logprobs=logprobs,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
@ -338,20 +312,6 @@ class InferenceRouter(Inference):
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def batch_completion(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
content_batch: list[InterleavedContent],
|
|
||||||
sampling_params: SamplingParams | None = None,
|
|
||||||
response_format: ResponseFormat | None = None,
|
|
||||||
logprobs: LogProbConfig | None = None,
|
|
||||||
) -> BatchCompletionResponse:
|
|
||||||
logger.debug(
|
|
||||||
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
|
|
||||||
)
|
|
||||||
provider = await self.routing_table.get_provider_impl(model_id)
|
|
||||||
return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
|
|
||||||
|
|
||||||
async def embeddings(
|
async def embeddings(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
|
|
@ -9,7 +9,7 @@ from typing import Any
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.common.errors import ToolGroupNotFoundError
|
from llama_stack.apis.common.errors import ToolGroupNotFoundError
|
||||||
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
|
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
|
||||||
from llama_stack.core.datatypes import ToolGroupWithOwner
|
from llama_stack.core.datatypes import AuthenticationRequiredError, ToolGroupWithOwner
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
from .common import CommonRoutingTableImpl
|
from .common import CommonRoutingTableImpl
|
||||||
|
@ -54,7 +54,18 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
all_tools = []
|
all_tools = []
|
||||||
for toolgroup in toolgroups:
|
for toolgroup in toolgroups:
|
||||||
if toolgroup.identifier not in self.toolgroups_to_tools:
|
if toolgroup.identifier not in self.toolgroups_to_tools:
|
||||||
await self._index_tools(toolgroup)
|
try:
|
||||||
|
await self._index_tools(toolgroup)
|
||||||
|
except AuthenticationRequiredError:
|
||||||
|
# Send authentication errors back to the client so it knows
|
||||||
|
# that it needs to supply credentials for remote MCP servers.
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
# Other errors that the client cannot fix are logged and
|
||||||
|
# those specific toolgroups are skipped.
|
||||||
|
logger.warning(f"Error listing tools for toolgroup {toolgroup.identifier}: {e}")
|
||||||
|
logger.debug(e, exc_info=True)
|
||||||
|
continue
|
||||||
all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])
|
all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])
|
||||||
|
|
||||||
return ListToolsResponse(data=all_tools)
|
return ListToolsResponse(data=all_tools)
|
||||||
|
|
|
@ -14,7 +14,6 @@ from typing import Any
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from llama_stack.apis.agents import Agents
|
from llama_stack.apis.agents import Agents
|
||||||
from llama_stack.apis.batch_inference import BatchInference
|
|
||||||
from llama_stack.apis.benchmarks import Benchmarks
|
from llama_stack.apis.benchmarks import Benchmarks
|
||||||
from llama_stack.apis.datasetio import DatasetIO
|
from llama_stack.apis.datasetio import DatasetIO
|
||||||
from llama_stack.apis.datasets import Datasets
|
from llama_stack.apis.datasets import Datasets
|
||||||
|
@ -54,7 +53,6 @@ class LlamaStack(
|
||||||
Providers,
|
Providers,
|
||||||
VectorDBs,
|
VectorDBs,
|
||||||
Inference,
|
Inference,
|
||||||
BatchInference,
|
|
||||||
Agents,
|
Agents,
|
||||||
Safety,
|
Safety,
|
||||||
SyntheticDataGeneration,
|
SyntheticDataGeneration,
|
||||||
|
|
|
@ -96,11 +96,9 @@ class DiskDistributionRegistry(DistributionRegistry):
|
||||||
|
|
||||||
async def register(self, obj: RoutableObjectWithProvider) -> bool:
|
async def register(self, obj: RoutableObjectWithProvider) -> bool:
|
||||||
existing_obj = await self.get(obj.type, obj.identifier)
|
existing_obj = await self.get(obj.type, obj.identifier)
|
||||||
# warn if the object's providerid is different but proceed with registration
|
# dont register if the object's providerid already exists
|
||||||
if existing_obj and existing_obj.provider_id != obj.provider_id:
|
if existing_obj and existing_obj.provider_id == obj.provider_id:
|
||||||
logger.warning(
|
return False
|
||||||
f"Object {existing_obj.type}:{existing_obj.identifier}'s {existing_obj.provider_id} provider is being replaced with {obj.provider_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.kvstore.set(
|
await self.kvstore.set(
|
||||||
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
|
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
|
||||||
|
|
|
@ -18,8 +18,6 @@ from llama_stack.apis.common.content_types import (
|
||||||
ToolCallParseStatus,
|
ToolCallParseStatus,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
BatchChatCompletionResponse,
|
|
||||||
BatchCompletionResponse,
|
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
ChatCompletionResponseEvent,
|
ChatCompletionResponseEvent,
|
||||||
|
@ -219,41 +217,6 @@ class MetaReferenceInferenceImpl(
|
||||||
results = await self._nonstream_completion([request])
|
results = await self._nonstream_completion([request])
|
||||||
return results[0]
|
return results[0]
|
||||||
|
|
||||||
async def batch_completion(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
content_batch: list[InterleavedContent],
|
|
||||||
sampling_params: SamplingParams | None = None,
|
|
||||||
response_format: ResponseFormat | None = None,
|
|
||||||
stream: bool | None = False,
|
|
||||||
logprobs: LogProbConfig | None = None,
|
|
||||||
) -> BatchCompletionResponse:
|
|
||||||
if sampling_params is None:
|
|
||||||
sampling_params = SamplingParams()
|
|
||||||
if logprobs:
|
|
||||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
|
||||||
|
|
||||||
content_batch = [
|
|
||||||
augment_content_with_response_format_prompt(response_format, content) for content in content_batch
|
|
||||||
]
|
|
||||||
|
|
||||||
request_batch = []
|
|
||||||
for content in content_batch:
|
|
||||||
request = CompletionRequest(
|
|
||||||
model=model_id,
|
|
||||||
content=content,
|
|
||||||
sampling_params=sampling_params,
|
|
||||||
response_format=response_format,
|
|
||||||
stream=stream,
|
|
||||||
logprobs=logprobs,
|
|
||||||
)
|
|
||||||
self.check_model(request)
|
|
||||||
request = await convert_request_to_raw(request)
|
|
||||||
request_batch.append(request)
|
|
||||||
|
|
||||||
results = await self._nonstream_completion(request_batch)
|
|
||||||
return BatchCompletionResponse(batch=results)
|
|
||||||
|
|
||||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||||
tokenizer = self.generator.formatter.tokenizer
|
tokenizer = self.generator.formatter.tokenizer
|
||||||
|
|
||||||
|
@ -399,49 +362,6 @@ class MetaReferenceInferenceImpl(
|
||||||
results = await self._nonstream_chat_completion([request])
|
results = await self._nonstream_chat_completion([request])
|
||||||
return results[0]
|
return results[0]
|
||||||
|
|
||||||
async def batch_chat_completion(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
messages_batch: list[list[Message]],
|
|
||||||
sampling_params: SamplingParams | None = None,
|
|
||||||
response_format: ResponseFormat | None = None,
|
|
||||||
tools: list[ToolDefinition] | None = None,
|
|
||||||
stream: bool | None = False,
|
|
||||||
logprobs: LogProbConfig | None = None,
|
|
||||||
tool_config: ToolConfig | None = None,
|
|
||||||
) -> BatchChatCompletionResponse:
|
|
||||||
if sampling_params is None:
|
|
||||||
sampling_params = SamplingParams()
|
|
||||||
if logprobs:
|
|
||||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
|
||||||
|
|
||||||
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
|
||||||
request_batch = []
|
|
||||||
for messages in messages_batch:
|
|
||||||
request = ChatCompletionRequest(
|
|
||||||
model=model_id,
|
|
||||||
messages=messages,
|
|
||||||
sampling_params=sampling_params,
|
|
||||||
tools=tools or [],
|
|
||||||
response_format=response_format,
|
|
||||||
logprobs=logprobs,
|
|
||||||
tool_config=tool_config or ToolConfig(),
|
|
||||||
)
|
|
||||||
self.check_model(request)
|
|
||||||
|
|
||||||
# augment and rewrite messages depending on the model
|
|
||||||
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
|
|
||||||
# download media and convert to raw content so we can send it to the model
|
|
||||||
request = await convert_request_to_raw(request)
|
|
||||||
request_batch.append(request)
|
|
||||||
|
|
||||||
if self.config.create_distributed_process_group:
|
|
||||||
if SEMAPHORE.locked():
|
|
||||||
raise RuntimeError("Only one concurrent request is supported")
|
|
||||||
|
|
||||||
results = await self._nonstream_chat_completion(request_batch)
|
|
||||||
return BatchChatCompletionResponse(batch=results)
|
|
||||||
|
|
||||||
async def _nonstream_chat_completion(
|
async def _nonstream_chat_completion(
|
||||||
self, request_batch: list[ChatCompletionRequest]
|
self, request_batch: list[ChatCompletionRequest]
|
||||||
) -> list[ChatCompletionResponse]:
|
) -> list[ChatCompletionResponse]:
|
||||||
|
|
|
@ -24,7 +24,6 @@ from llama_stack.apis.inference import (
|
||||||
LogProbConfig,
|
LogProbConfig,
|
||||||
Message,
|
Message,
|
||||||
Model,
|
Model,
|
||||||
ModelType,
|
|
||||||
OpenAICompletion,
|
OpenAICompletion,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
|
@ -34,6 +33,7 @@ from llama_stack.apis.inference import (
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
)
|
)
|
||||||
|
from llama_stack.apis.models import ModelType
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||||
|
|
||||||
|
|
|
@ -64,6 +64,7 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, config: FireworksImplConfig) -> None:
|
def __init__(self, config: FireworksImplConfig) -> None:
|
||||||
|
ModelRegistryHelper.__init__(self)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.allowed_models = config.allowed_models
|
self.allowed_models = config.allowed_models
|
||||||
|
|
||||||
|
|
|
@ -4,12 +4,10 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from llama_stack.apis.inference import Inference
|
|
||||||
|
|
||||||
from .config import GroqConfig
|
from .config import GroqConfig
|
||||||
|
|
||||||
|
|
||||||
async def get_adapter_impl(config: GroqConfig, _deps) -> Inference:
|
async def get_adapter_impl(config: GroqConfig, _deps):
|
||||||
# import dynamically so the import is used only when it is needed
|
# import dynamically so the import is used only when it is needed
|
||||||
from .groq import GroqInferenceAdapter
|
from .groq import GroqInferenceAdapter
|
||||||
|
|
||||||
|
|
|
@ -6,8 +6,7 @@
|
||||||
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
from collections.abc import AsyncGenerator
|
||||||
from collections.abc import AsyncGenerator, AsyncIterator
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from ollama import AsyncClient as AsyncOllamaClient
|
from ollama import AsyncClient as AsyncOllamaClient
|
||||||
|
@ -33,10 +32,6 @@ from llama_stack.apis.inference import (
|
||||||
JsonSchemaResponseFormat,
|
JsonSchemaResponseFormat,
|
||||||
LogProbConfig,
|
LogProbConfig,
|
||||||
Message,
|
Message,
|
||||||
OpenAIChatCompletion,
|
|
||||||
OpenAIChatCompletionChunk,
|
|
||||||
OpenAIMessageParam,
|
|
||||||
OpenAIResponseFormatParam,
|
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
TextTruncation,
|
TextTruncation,
|
||||||
|
@ -62,7 +57,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
||||||
OpenAICompatCompletionChoice,
|
OpenAICompatCompletionChoice,
|
||||||
OpenAICompatCompletionResponse,
|
OpenAICompatCompletionResponse,
|
||||||
get_sampling_options,
|
get_sampling_options,
|
||||||
prepare_openai_completion_params,
|
|
||||||
process_chat_completion_response,
|
process_chat_completion_response,
|
||||||
process_chat_completion_stream_response,
|
process_chat_completion_stream_response,
|
||||||
process_completion_response,
|
process_completion_response,
|
||||||
|
@ -75,7 +69,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
content_has_media,
|
content_has_media,
|
||||||
convert_image_content_to_url,
|
convert_image_content_to_url,
|
||||||
interleaved_content_as_str,
|
interleaved_content_as_str,
|
||||||
localize_image_content,
|
|
||||||
request_has_media,
|
request_has_media,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -84,6 +77,7 @@ logger = get_logger(name=__name__, category="inference::ollama")
|
||||||
|
|
||||||
class OllamaInferenceAdapter(
|
class OllamaInferenceAdapter(
|
||||||
OpenAIMixin,
|
OpenAIMixin,
|
||||||
|
ModelRegistryHelper,
|
||||||
InferenceProvider,
|
InferenceProvider,
|
||||||
ModelsProtocolPrivate,
|
ModelsProtocolPrivate,
|
||||||
):
|
):
|
||||||
|
@ -129,6 +123,8 @@ class OllamaInferenceAdapter(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
# Ollama does not support image urls, so we need to download the image and convert it to base64
|
||||||
|
self.download_images = True
|
||||||
self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
|
self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -173,9 +169,6 @@ class OllamaInferenceAdapter(
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
self._clients.clear()
|
self._clients.clear()
|
||||||
|
|
||||||
async def unregister_model(self, model_id: str) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def _get_model(self, model_id: str) -> Model:
|
async def _get_model(self, model_id: str) -> Model:
|
||||||
if not self.model_store:
|
if not self.model_store:
|
||||||
raise ValueError("Model store not set")
|
raise ValueError("Model store not set")
|
||||||
|
@ -403,75 +396,6 @@ class OllamaInferenceAdapter(
|
||||||
|
|
||||||
raise UnsupportedModelError(model.provider_model_id, list(self._model_cache.keys()))
|
raise UnsupportedModelError(model.provider_model_id, list(self._model_cache.keys()))
|
||||||
|
|
||||||
async def openai_chat_completion(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
messages: list[OpenAIMessageParam],
|
|
||||||
frequency_penalty: float | None = None,
|
|
||||||
function_call: str | dict[str, Any] | None = None,
|
|
||||||
functions: list[dict[str, Any]] | None = None,
|
|
||||||
logit_bias: dict[str, float] | None = None,
|
|
||||||
logprobs: bool | None = None,
|
|
||||||
max_completion_tokens: int | None = None,
|
|
||||||
max_tokens: int | None = None,
|
|
||||||
n: int | None = None,
|
|
||||||
parallel_tool_calls: bool | None = None,
|
|
||||||
presence_penalty: float | None = None,
|
|
||||||
response_format: OpenAIResponseFormatParam | None = None,
|
|
||||||
seed: int | None = None,
|
|
||||||
stop: str | list[str] | None = None,
|
|
||||||
stream: bool | None = None,
|
|
||||||
stream_options: dict[str, Any] | None = None,
|
|
||||||
temperature: float | None = None,
|
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
|
||||||
tools: list[dict[str, Any]] | None = None,
|
|
||||||
top_logprobs: int | None = None,
|
|
||||||
top_p: float | None = None,
|
|
||||||
user: str | None = None,
|
|
||||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
|
||||||
model_obj = await self._get_model(model)
|
|
||||||
|
|
||||||
# Ollama does not support image urls, so we need to download the image and convert it to base64
|
|
||||||
async def _convert_message(m: OpenAIMessageParam) -> OpenAIMessageParam:
|
|
||||||
if isinstance(m.content, list):
|
|
||||||
for c in m.content:
|
|
||||||
if c.type == "image_url" and c.image_url and c.image_url.url:
|
|
||||||
localize_result = await localize_image_content(c.image_url.url)
|
|
||||||
if localize_result is None:
|
|
||||||
raise ValueError(f"Failed to localize image content from {c.image_url.url}")
|
|
||||||
|
|
||||||
content, format = localize_result
|
|
||||||
c.image_url.url = f"data:image/{format};base64,{base64.b64encode(content).decode('utf-8')}"
|
|
||||||
return m
|
|
||||||
|
|
||||||
messages = [await _convert_message(m) for m in messages]
|
|
||||||
params = await prepare_openai_completion_params(
|
|
||||||
model=model_obj.provider_resource_id,
|
|
||||||
messages=messages,
|
|
||||||
frequency_penalty=frequency_penalty,
|
|
||||||
function_call=function_call,
|
|
||||||
functions=functions,
|
|
||||||
logit_bias=logit_bias,
|
|
||||||
logprobs=logprobs,
|
|
||||||
max_completion_tokens=max_completion_tokens,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
n=n,
|
|
||||||
parallel_tool_calls=parallel_tool_calls,
|
|
||||||
presence_penalty=presence_penalty,
|
|
||||||
response_format=response_format,
|
|
||||||
seed=seed,
|
|
||||||
stop=stop,
|
|
||||||
stream=stream,
|
|
||||||
stream_options=stream_options,
|
|
||||||
temperature=temperature,
|
|
||||||
tool_choice=tool_choice,
|
|
||||||
tools=tools,
|
|
||||||
top_logprobs=top_logprobs,
|
|
||||||
top_p=top_p,
|
|
||||||
user=user,
|
|
||||||
)
|
|
||||||
return await OpenAIMixin.openai_chat_completion(self, **params)
|
|
||||||
|
|
||||||
|
|
||||||
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
|
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
|
||||||
async def _convert_content(content) -> dict:
|
async def _convert_content(content) -> dict:
|
||||||
|
|
|
@ -21,8 +21,6 @@ logger = get_logger(name=__name__, category="inference::openai")
|
||||||
# | completion | LiteLLMOpenAIMixin |
|
# | completion | LiteLLMOpenAIMixin |
|
||||||
# | chat_completion | LiteLLMOpenAIMixin |
|
# | chat_completion | LiteLLMOpenAIMixin |
|
||||||
# | embedding | LiteLLMOpenAIMixin |
|
# | embedding | LiteLLMOpenAIMixin |
|
||||||
# | batch_completion | LiteLLMOpenAIMixin |
|
|
||||||
# | batch_chat_completion | LiteLLMOpenAIMixin |
|
|
||||||
# | openai_completion | OpenAIMixin |
|
# | openai_completion | OpenAIMixin |
|
||||||
# | openai_chat_completion | OpenAIMixin |
|
# | openai_chat_completion | OpenAIMixin |
|
||||||
# | openai_embeddings | OpenAIMixin |
|
# | openai_embeddings | OpenAIMixin |
|
||||||
|
|
|
@ -4,12 +4,10 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from llama_stack.apis.inference import Inference
|
|
||||||
|
|
||||||
from .config import SambaNovaImplConfig
|
from .config import SambaNovaImplConfig
|
||||||
|
|
||||||
|
|
||||||
async def get_adapter_impl(config: SambaNovaImplConfig, _deps) -> Inference:
|
async def get_adapter_impl(config: SambaNovaImplConfig, _deps):
|
||||||
from .sambanova import SambaNovaInferenceAdapter
|
from .sambanova import SambaNovaInferenceAdapter
|
||||||
|
|
||||||
assert isinstance(config, SambaNovaImplConfig), f"Unexpected config type: {type(config)}"
|
assert isinstance(config, SambaNovaImplConfig), f"Unexpected config type: {type(config)}"
|
||||||
|
|
|
@ -25,7 +25,7 @@ class SambaNovaInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||||
|
|
||||||
def __init__(self, config: SambaNovaImplConfig):
|
def __init__(self, config: SambaNovaImplConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.environment_available_models = []
|
self.environment_available_models: list[str] = []
|
||||||
LiteLLMOpenAIMixin.__init__(
|
LiteLLMOpenAIMixin.__init__(
|
||||||
self,
|
self,
|
||||||
litellm_provider_name="sambanova",
|
litellm_provider_name="sambanova",
|
||||||
|
|
|
@ -70,6 +70,7 @@ class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Need
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, config: TogetherImplConfig) -> None:
|
def __init__(self, config: TogetherImplConfig) -> None:
|
||||||
|
ModelRegistryHelper.__init__(self)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.allowed_models = config.allowed_models
|
self.allowed_models = config.allowed_models
|
||||||
self._model_cache: dict[str, Model] = {}
|
self._model_cache: dict[str, Model] = {}
|
||||||
|
|
|
@ -20,7 +20,7 @@ logger = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
|
||||||
class RemoteInferenceProviderConfig(BaseModel):
|
class RemoteInferenceProviderConfig(BaseModel):
|
||||||
allowed_models: list[str] | None = Field(
|
allowed_models: list[str] | None = Field( # TODO: make this non-optional and give a list() default
|
||||||
default=None,
|
default=None,
|
||||||
description="List of models that should be registered with the model registry. If None, all models are allowed.",
|
description="List of models that should be registered with the model registry. If None, all models are allowed.",
|
||||||
)
|
)
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import base64
|
||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
|
@ -26,6 +27,7 @@ from llama_stack.apis.models import ModelType
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||||
|
from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="providers::utils")
|
logger = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
@ -51,6 +53,10 @@ class OpenAIMixin(ModelRegistryHelper, ABC):
|
||||||
# This is useful for providers that do not return a unique id in the response.
|
# This is useful for providers that do not return a unique id in the response.
|
||||||
overwrite_completion_id: bool = False
|
overwrite_completion_id: bool = False
|
||||||
|
|
||||||
|
# Allow subclasses to control whether to download images and convert to base64
|
||||||
|
# for providers that require base64 encoded images instead of URLs.
|
||||||
|
download_images: bool = False
|
||||||
|
|
||||||
# Embedding model metadata for this provider
|
# Embedding model metadata for this provider
|
||||||
# Can be set by subclasses or instances to provide embedding models
|
# Can be set by subclasses or instances to provide embedding models
|
||||||
# Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}
|
# Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}
|
||||||
|
@ -239,6 +245,24 @@ class OpenAIMixin(ModelRegistryHelper, ABC):
|
||||||
"""
|
"""
|
||||||
Direct OpenAI chat completion API call.
|
Direct OpenAI chat completion API call.
|
||||||
"""
|
"""
|
||||||
|
if self.download_images:
|
||||||
|
|
||||||
|
async def _localize_image_url(m: OpenAIMessageParam) -> OpenAIMessageParam:
|
||||||
|
if isinstance(m.content, list):
|
||||||
|
for c in m.content:
|
||||||
|
if c.type == "image_url" and c.image_url and c.image_url.url and "http" in c.image_url.url:
|
||||||
|
localize_result = await localize_image_content(c.image_url.url)
|
||||||
|
if localize_result is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Failed to localize image content from {c.image_url.url[:42]}{'...' if len(c.image_url.url) > 42 else ''}"
|
||||||
|
)
|
||||||
|
content, format = localize_result
|
||||||
|
c.image_url.url = f"data:image/{format};base64,{base64.b64encode(content).decode('utf-8')}"
|
||||||
|
# else it's a string and we don't need to modify it
|
||||||
|
return m
|
||||||
|
|
||||||
|
messages = [await _localize_image_url(m) for m in messages]
|
||||||
|
|
||||||
resp = await self.client.chat.completions.create(
|
resp = await self.client.chat.completions.create(
|
||||||
**await prepare_openai_completion_params(
|
**await prepare_openai_completion_params(
|
||||||
model=await self._get_provider_model_id(model),
|
model=await self._get_provider_model_id(model),
|
||||||
|
|
|
@ -28,7 +28,7 @@ class CommonConfig(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class RedisKVStoreConfig(CommonConfig):
|
class RedisKVStoreConfig(CommonConfig):
|
||||||
type: Literal[KVStoreType.redis.value] = KVStoreType.redis.value
|
type: Literal["redis"] = KVStoreType.redis.value
|
||||||
host: str = "localhost"
|
host: str = "localhost"
|
||||||
port: int = 6379
|
port: int = 6379
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ class RedisKVStoreConfig(CommonConfig):
|
||||||
|
|
||||||
|
|
||||||
class SqliteKVStoreConfig(CommonConfig):
|
class SqliteKVStoreConfig(CommonConfig):
|
||||||
type: Literal[KVStoreType.sqlite.value] = KVStoreType.sqlite.value
|
type: Literal["sqlite"] = KVStoreType.sqlite.value
|
||||||
db_path: str = Field(
|
db_path: str = Field(
|
||||||
default=(RUNTIME_BASE_DIR / "kvstore.db").as_posix(),
|
default=(RUNTIME_BASE_DIR / "kvstore.db").as_posix(),
|
||||||
description="File path for the sqlite database",
|
description="File path for the sqlite database",
|
||||||
|
@ -69,7 +69,7 @@ class SqliteKVStoreConfig(CommonConfig):
|
||||||
|
|
||||||
|
|
||||||
class PostgresKVStoreConfig(CommonConfig):
|
class PostgresKVStoreConfig(CommonConfig):
|
||||||
type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value
|
type: Literal["postgres"] = KVStoreType.postgres.value
|
||||||
host: str = "localhost"
|
host: str = "localhost"
|
||||||
port: int = 5432
|
port: int = 5432
|
||||||
db: str = "llamastack"
|
db: str = "llamastack"
|
||||||
|
@ -113,11 +113,11 @@ class PostgresKVStoreConfig(CommonConfig):
|
||||||
|
|
||||||
|
|
||||||
class MongoDBKVStoreConfig(CommonConfig):
|
class MongoDBKVStoreConfig(CommonConfig):
|
||||||
type: Literal[KVStoreType.mongodb.value] = KVStoreType.mongodb.value
|
type: Literal["mongodb"] = KVStoreType.mongodb.value
|
||||||
host: str = "localhost"
|
host: str = "localhost"
|
||||||
port: int = 27017
|
port: int = 27017
|
||||||
db: str = "llamastack"
|
db: str = "llamastack"
|
||||||
user: str = None
|
user: str | None = None
|
||||||
password: str | None = None
|
password: str | None = None
|
||||||
collection_name: str = "llamastack_kvstore"
|
collection_name: str = "llamastack_kvstore"
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from pymongo import AsyncMongoClient
|
from pymongo import AsyncMongoClient
|
||||||
|
from pymongo.asynchronous.collection import AsyncCollection
|
||||||
|
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.utils.kvstore import KVStore
|
from llama_stack.providers.utils.kvstore import KVStore
|
||||||
|
@ -19,8 +20,13 @@ log = get_logger(name=__name__, category="providers::utils")
|
||||||
class MongoDBKVStoreImpl(KVStore):
|
class MongoDBKVStoreImpl(KVStore):
|
||||||
def __init__(self, config: MongoDBKVStoreConfig):
|
def __init__(self, config: MongoDBKVStoreConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.conn = None
|
self.conn: AsyncMongoClient | None = None
|
||||||
self.collection = None
|
|
||||||
|
@property
|
||||||
|
def collection(self) -> AsyncCollection:
|
||||||
|
if self.conn is None:
|
||||||
|
raise RuntimeError("MongoDB connection is not initialized")
|
||||||
|
return self.conn[self.config.db][self.config.collection_name]
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
try:
|
try:
|
||||||
|
@ -32,7 +38,6 @@ class MongoDBKVStoreImpl(KVStore):
|
||||||
}
|
}
|
||||||
conn_creds = {k: v for k, v in conn_creds.items() if v is not None}
|
conn_creds = {k: v for k, v in conn_creds.items() if v is not None}
|
||||||
self.conn = AsyncMongoClient(**conn_creds)
|
self.conn = AsyncMongoClient(**conn_creds)
|
||||||
self.collection = self.conn[self.config.db][self.config.collection_name]
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception("Could not connect to MongoDB database server")
|
log.exception("Could not connect to MongoDB database server")
|
||||||
raise RuntimeError("Could not connect to MongoDB database server") from e
|
raise RuntimeError("Could not connect to MongoDB database server") from e
|
||||||
|
|
|
@ -9,9 +9,13 @@ from datetime import datetime
|
||||||
|
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
|
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
from ..api import KVStore
|
from ..api import KVStore
|
||||||
from ..config import SqliteKVStoreConfig
|
from ..config import SqliteKVStoreConfig
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
|
||||||
class SqliteKVStoreImpl(KVStore):
|
class SqliteKVStoreImpl(KVStore):
|
||||||
def __init__(self, config: SqliteKVStoreConfig):
|
def __init__(self, config: SqliteKVStoreConfig):
|
||||||
|
@ -50,6 +54,9 @@ class SqliteKVStoreImpl(KVStore):
|
||||||
if row is None:
|
if row is None:
|
||||||
return None
|
return None
|
||||||
value, expiration = row
|
value, expiration = row
|
||||||
|
if not isinstance(value, str):
|
||||||
|
logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None")
|
||||||
|
return None
|
||||||
return value
|
return value
|
||||||
|
|
||||||
async def delete(self, key: str) -> None:
|
async def delete(self, key: str) -> None:
|
||||||
|
|
8
llama_stack/ui/package-lock.json
generated
8
llama_stack/ui/package-lock.json
generated
|
@ -18,7 +18,7 @@
|
||||||
"class-variance-authority": "^0.7.1",
|
"class-variance-authority": "^0.7.1",
|
||||||
"clsx": "^2.1.1",
|
"clsx": "^2.1.1",
|
||||||
"framer-motion": "^12.23.12",
|
"framer-motion": "^12.23.12",
|
||||||
"llama-stack-client": "^0.2.22",
|
"llama-stack-client": "^0.2.23",
|
||||||
"lucide-react": "^0.542.0",
|
"lucide-react": "^0.542.0",
|
||||||
"next": "15.5.3",
|
"next": "15.5.3",
|
||||||
"next-auth": "^4.24.11",
|
"next-auth": "^4.24.11",
|
||||||
|
@ -10172,9 +10172,9 @@
|
||||||
"license": "MIT"
|
"license": "MIT"
|
||||||
},
|
},
|
||||||
"node_modules/llama-stack-client": {
|
"node_modules/llama-stack-client": {
|
||||||
"version": "0.2.22",
|
"version": "0.2.23",
|
||||||
"resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.22.tgz",
|
"resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.23.tgz",
|
||||||
"integrity": "sha512-7aW3UQj5MwjV73Brd+yQ1e4W1W33nhozyeHM5tzOgbsVZ88tL78JNiNvyFqDR5w6V9XO4/uSGGiQVG6v83yR4w==",
|
"integrity": "sha512-J3YFH1HW2K70capejQxGlCyTgKdfx+sQf8Ab+HFi1j2Q00KtpHXB79RxejvBxjWC3X2E++P9iU57KdU2Tp/rIQ==",
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@types/node": "^18.11.18",
|
"@types/node": "^18.11.18",
|
||||||
|
|
|
@ -23,7 +23,7 @@
|
||||||
"class-variance-authority": "^0.7.1",
|
"class-variance-authority": "^0.7.1",
|
||||||
"clsx": "^2.1.1",
|
"clsx": "^2.1.1",
|
||||||
"framer-motion": "^12.23.12",
|
"framer-motion": "^12.23.12",
|
||||||
"llama-stack-client": "^0.2.22",
|
"llama-stack-client": "^0.2.23",
|
||||||
"lucide-react": "^0.542.0",
|
"lucide-react": "^0.542.0",
|
||||||
"next": "15.5.3",
|
"next": "15.5.3",
|
||||||
"next-auth": "^4.24.11",
|
"next-auth": "^4.24.11",
|
||||||
|
|
|
@ -7,7 +7,7 @@ required-version = ">=0.7.0"
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "llama_stack"
|
name = "llama_stack"
|
||||||
version = "0.2.22"
|
version = "0.2.23"
|
||||||
authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
|
authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
|
||||||
description = "Llama Stack"
|
description = "Llama Stack"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
@ -31,7 +31,7 @@ dependencies = [
|
||||||
"huggingface-hub>=0.34.0,<1.0",
|
"huggingface-hub>=0.34.0,<1.0",
|
||||||
"jinja2>=3.1.6",
|
"jinja2>=3.1.6",
|
||||||
"jsonschema",
|
"jsonschema",
|
||||||
"llama-stack-client>=0.2.22",
|
"llama-stack-client>=0.2.23",
|
||||||
"openai>=1.100.0", # for expires_after support
|
"openai>=1.100.0", # for expires_after support
|
||||||
"prompt-toolkit",
|
"prompt-toolkit",
|
||||||
"python-dotenv",
|
"python-dotenv",
|
||||||
|
@ -55,7 +55,7 @@ dependencies = [
|
||||||
ui = [
|
ui = [
|
||||||
"streamlit",
|
"streamlit",
|
||||||
"pandas",
|
"pandas",
|
||||||
"llama-stack-client>=0.2.22",
|
"llama-stack-client>=0.2.23",
|
||||||
"streamlit-option-menu",
|
"streamlit-option-menu",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -259,15 +259,12 @@ exclude = [
|
||||||
"^llama_stack/models/llama/llama3/tokenizer\\.py$",
|
"^llama_stack/models/llama/llama3/tokenizer\\.py$",
|
||||||
"^llama_stack/models/llama/llama3/tool_utils\\.py$",
|
"^llama_stack/models/llama/llama3/tool_utils\\.py$",
|
||||||
"^llama_stack/providers/inline/agents/meta_reference/",
|
"^llama_stack/providers/inline/agents/meta_reference/",
|
||||||
"^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
|
|
||||||
"^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
|
|
||||||
"^llama_stack/providers/inline/datasetio/localfs/",
|
"^llama_stack/providers/inline/datasetio/localfs/",
|
||||||
"^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
|
"^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
|
||||||
"^llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
|
"^llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
|
||||||
"^llama_stack/models/llama/llama3/generation\\.py$",
|
"^llama_stack/models/llama/llama3/generation\\.py$",
|
||||||
"^llama_stack/models/llama/llama3/multimodal/model\\.py$",
|
"^llama_stack/models/llama/llama3/multimodal/model\\.py$",
|
||||||
"^llama_stack/models/llama/llama4/",
|
"^llama_stack/models/llama/llama4/",
|
||||||
"^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
|
|
||||||
"^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
|
"^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
|
||||||
"^llama_stack/providers/inline/post_training/common/validator\\.py$",
|
"^llama_stack/providers/inline/post_training/common/validator\\.py$",
|
||||||
"^llama_stack/providers/inline/safety/code_scanner/",
|
"^llama_stack/providers/inline/safety/code_scanner/",
|
||||||
|
@ -278,19 +275,13 @@ exclude = [
|
||||||
"^llama_stack/providers/remote/agents/sample/",
|
"^llama_stack/providers/remote/agents/sample/",
|
||||||
"^llama_stack/providers/remote/datasetio/huggingface/",
|
"^llama_stack/providers/remote/datasetio/huggingface/",
|
||||||
"^llama_stack/providers/remote/datasetio/nvidia/",
|
"^llama_stack/providers/remote/datasetio/nvidia/",
|
||||||
"^llama_stack/providers/remote/inference/anthropic/",
|
|
||||||
"^llama_stack/providers/remote/inference/bedrock/",
|
"^llama_stack/providers/remote/inference/bedrock/",
|
||||||
"^llama_stack/providers/remote/inference/cerebras/",
|
"^llama_stack/providers/remote/inference/cerebras/",
|
||||||
"^llama_stack/providers/remote/inference/databricks/",
|
"^llama_stack/providers/remote/inference/databricks/",
|
||||||
"^llama_stack/providers/remote/inference/fireworks/",
|
"^llama_stack/providers/remote/inference/fireworks/",
|
||||||
"^llama_stack/providers/remote/inference/gemini/",
|
|
||||||
"^llama_stack/providers/remote/inference/groq/",
|
|
||||||
"^llama_stack/providers/remote/inference/nvidia/",
|
"^llama_stack/providers/remote/inference/nvidia/",
|
||||||
"^llama_stack/providers/remote/inference/openai/",
|
|
||||||
"^llama_stack/providers/remote/inference/passthrough/",
|
"^llama_stack/providers/remote/inference/passthrough/",
|
||||||
"^llama_stack/providers/remote/inference/runpod/",
|
"^llama_stack/providers/remote/inference/runpod/",
|
||||||
"^llama_stack/providers/remote/inference/sambanova/",
|
|
||||||
"^llama_stack/providers/remote/inference/sample/",
|
|
||||||
"^llama_stack/providers/remote/inference/tgi/",
|
"^llama_stack/providers/remote/inference/tgi/",
|
||||||
"^llama_stack/providers/remote/inference/together/",
|
"^llama_stack/providers/remote/inference/together/",
|
||||||
"^llama_stack/providers/remote/inference/watsonx/",
|
"^llama_stack/providers/remote/inference/watsonx/",
|
||||||
|
@ -310,7 +301,6 @@ exclude = [
|
||||||
"^llama_stack/providers/remote/vector_io/qdrant/",
|
"^llama_stack/providers/remote/vector_io/qdrant/",
|
||||||
"^llama_stack/providers/remote/vector_io/sample/",
|
"^llama_stack/providers/remote/vector_io/sample/",
|
||||||
"^llama_stack/providers/remote/vector_io/weaviate/",
|
"^llama_stack/providers/remote/vector_io/weaviate/",
|
||||||
"^llama_stack/providers/tests/conftest\\.py$",
|
|
||||||
"^llama_stack/providers/utils/bedrock/client\\.py$",
|
"^llama_stack/providers/utils/bedrock/client\\.py$",
|
||||||
"^llama_stack/providers/utils/bedrock/refreshable_boto_session\\.py$",
|
"^llama_stack/providers/utils/bedrock/refreshable_boto_session\\.py$",
|
||||||
"^llama_stack/providers/utils/inference/embedding_mixin\\.py$",
|
"^llama_stack/providers/utils/inference/embedding_mixin\\.py$",
|
||||||
|
@ -318,12 +308,9 @@ exclude = [
|
||||||
"^llama_stack/providers/utils/inference/model_registry\\.py$",
|
"^llama_stack/providers/utils/inference/model_registry\\.py$",
|
||||||
"^llama_stack/providers/utils/inference/openai_compat\\.py$",
|
"^llama_stack/providers/utils/inference/openai_compat\\.py$",
|
||||||
"^llama_stack/providers/utils/inference/prompt_adapter\\.py$",
|
"^llama_stack/providers/utils/inference/prompt_adapter\\.py$",
|
||||||
"^llama_stack/providers/utils/kvstore/config\\.py$",
|
|
||||||
"^llama_stack/providers/utils/kvstore/kvstore\\.py$",
|
"^llama_stack/providers/utils/kvstore/kvstore\\.py$",
|
||||||
"^llama_stack/providers/utils/kvstore/mongodb/mongodb\\.py$",
|
|
||||||
"^llama_stack/providers/utils/kvstore/postgres/postgres\\.py$",
|
"^llama_stack/providers/utils/kvstore/postgres/postgres\\.py$",
|
||||||
"^llama_stack/providers/utils/kvstore/redis/redis\\.py$",
|
"^llama_stack/providers/utils/kvstore/redis/redis\\.py$",
|
||||||
"^llama_stack/providers/utils/kvstore/sqlite/sqlite\\.py$",
|
|
||||||
"^llama_stack/providers/utils/memory/vector_store\\.py$",
|
"^llama_stack/providers/utils/memory/vector_store\\.py$",
|
||||||
"^llama_stack/providers/utils/scoring/aggregation_utils\\.py$",
|
"^llama_stack/providers/utils/scoring/aggregation_utils\\.py$",
|
||||||
"^llama_stack/providers/utils/scoring/base_scoring_fn\\.py$",
|
"^llama_stack/providers/utils/scoring/base_scoring_fn\\.py$",
|
||||||
|
@ -331,13 +318,6 @@ exclude = [
|
||||||
"^llama_stack/providers/utils/telemetry/trace_protocol\\.py$",
|
"^llama_stack/providers/utils/telemetry/trace_protocol\\.py$",
|
||||||
"^llama_stack/providers/utils/telemetry/tracing\\.py$",
|
"^llama_stack/providers/utils/telemetry/tracing\\.py$",
|
||||||
"^llama_stack/strong_typing/auxiliary\\.py$",
|
"^llama_stack/strong_typing/auxiliary\\.py$",
|
||||||
"^llama_stack/strong_typing/deserializer\\.py$",
|
|
||||||
"^llama_stack/strong_typing/inspection\\.py$",
|
|
||||||
"^llama_stack/strong_typing/schema\\.py$",
|
|
||||||
"^llama_stack/strong_typing/serializer\\.py$",
|
|
||||||
"^llama_stack/distributions/groq/groq\\.py$",
|
|
||||||
"^llama_stack/distributions/llama_api/llama_api\\.py$",
|
|
||||||
"^llama_stack/distributions/sambanova/sambanova\\.py$",
|
|
||||||
"^llama_stack/distributions/template\\.py$",
|
"^llama_stack/distributions/template\\.py$",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
77
tests/integration/inference/test_openai_vision_inference.py
Normal file
77
tests/integration/inference/test_openai_vision_inference.py
Normal file
|
@ -0,0 +1,77 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def image_path():
|
||||||
|
return pathlib.Path(__file__).parent / "dog.png"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def base64_image_data(image_path):
|
||||||
|
return base64.b64encode(image_path.read_bytes()).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_openai_chat_completion_image_url(openai_client, vision_model_id):
|
||||||
|
message = {
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/integration/inference/dog.png"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Describe what is in this image.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
response = openai_client.chat.completions.create(
|
||||||
|
model=vision_model_id,
|
||||||
|
messages=[message],
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
message_content = response.choices[0].message.content.lower().strip()
|
||||||
|
assert len(message_content) > 0
|
||||||
|
assert any(expected in message_content for expected in {"dog", "puppy", "pup"})
|
||||||
|
|
||||||
|
|
||||||
|
async def test_openai_chat_completion_image_data(openai_client, vision_model_id, base64_image_data):
|
||||||
|
message = {
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": f"data:image/png;base64,{base64_image_data}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Describe what is in this image.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
response = openai_client.chat.completions.create(
|
||||||
|
model=vision_model_id,
|
||||||
|
messages=[message],
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
message_content = response.choices[0].message.content.lower().strip()
|
||||||
|
assert len(message_content) > 0
|
||||||
|
assert any(expected in message_content for expected in {"dog", "puppy", "pup"})
|
|
@ -10,6 +10,7 @@ from unittest.mock import AsyncMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.common.type_system import NumberType
|
from llama_stack.apis.common.type_system import NumberType
|
||||||
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
|
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
|
||||||
from llama_stack.apis.datatypes import Api
|
from llama_stack.apis.datatypes import Api
|
||||||
|
@ -645,3 +646,25 @@ async def test_models_source_interaction_cleanup_provider_models(cached_disk_dis
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
await table.shutdown()
|
await table.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_tool_groups_routing_table_exception_handling(cached_disk_dist_registry):
|
||||||
|
"""Test that the tool group routing table handles exceptions when listing tools, like if an MCP server is unreachable."""
|
||||||
|
|
||||||
|
exception_throwing_tool_groups_impl = ToolGroupsImpl()
|
||||||
|
exception_throwing_tool_groups_impl.list_runtime_tools = AsyncMock(side_effect=Exception("Test exception"))
|
||||||
|
|
||||||
|
table = ToolGroupsRoutingTable(
|
||||||
|
{"test_provider": exception_throwing_tool_groups_impl}, cached_disk_dist_registry, {}
|
||||||
|
)
|
||||||
|
await table.initialize()
|
||||||
|
|
||||||
|
await table.register_tool_group(
|
||||||
|
toolgroup_id="test-toolgroup-exceptions",
|
||||||
|
provider_id="test_provider",
|
||||||
|
mcp_endpoint=URL(uri="http://localhost:8479/foo/bar"),
|
||||||
|
)
|
||||||
|
|
||||||
|
tools = await table.list_tools(toolgroup_id="test-toolgroup-exceptions")
|
||||||
|
|
||||||
|
assert len(tools.data) == 0
|
||||||
|
|
|
@ -4,11 +4,11 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from unittest.mock import MagicMock, PropertyMock, patch
|
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from llama_stack.apis.inference import Model
|
from llama_stack.apis.inference import Model, OpenAIUserMessageParam
|
||||||
from llama_stack.apis.models import ModelType
|
from llama_stack.apis.models import ModelType
|
||||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||||
|
|
||||||
|
@ -43,8 +43,17 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixin):
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mixin():
|
def mixin():
|
||||||
"""Create a test instance of OpenAIMixin"""
|
"""Create a test instance of OpenAIMixin with mocked model_store"""
|
||||||
return OpenAIMixinImpl()
|
mixin_instance = OpenAIMixinImpl()
|
||||||
|
|
||||||
|
# just enough to satisfy _get_provider_model_id calls
|
||||||
|
mock_model_store = MagicMock()
|
||||||
|
mock_model = MagicMock()
|
||||||
|
mock_model.provider_resource_id = "test-provider-resource-id"
|
||||||
|
mock_model_store.get_model = AsyncMock(return_value=mock_model)
|
||||||
|
mixin_instance.model_store = mock_model_store
|
||||||
|
|
||||||
|
return mixin_instance
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -205,6 +214,74 @@ class TestOpenAIMixinCacheBehavior:
|
||||||
assert "final-mock-model-id" in mixin._model_cache
|
assert "final-mock-model-id" in mixin._model_cache
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenAIMixinImagePreprocessing:
|
||||||
|
"""Test cases for image preprocessing functionality"""
|
||||||
|
|
||||||
|
async def test_openai_chat_completion_with_image_preprocessing_enabled(self, mixin):
|
||||||
|
"""Test that image URLs are converted to base64 when download_images is True"""
|
||||||
|
mixin.download_images = True
|
||||||
|
|
||||||
|
message = OpenAIUserMessageParam(
|
||||||
|
role="user",
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": "What's in this image?"},
|
||||||
|
{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
|
||||||
|
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
|
||||||
|
mock_localize.return_value = (b"fake_image_data", "jpeg")
|
||||||
|
|
||||||
|
await mixin.openai_chat_completion(model="test-model", messages=[message])
|
||||||
|
|
||||||
|
mock_localize.assert_called_once_with("http://example.com/image.jpg")
|
||||||
|
|
||||||
|
mock_client.chat.completions.create.assert_called_once()
|
||||||
|
call_args = mock_client.chat.completions.create.call_args
|
||||||
|
processed_messages = call_args[1]["messages"]
|
||||||
|
assert len(processed_messages) == 1
|
||||||
|
content = processed_messages[0]["content"]
|
||||||
|
assert len(content) == 2
|
||||||
|
assert content[0]["type"] == "text"
|
||||||
|
assert content[1]["type"] == "image_url"
|
||||||
|
assert content[1]["image_url"]["url"] == ""
|
||||||
|
|
||||||
|
async def test_openai_chat_completion_with_image_preprocessing_disabled(self, mixin):
|
||||||
|
"""Test that image URLs are not modified when download_images is False"""
|
||||||
|
mixin.download_images = False # explicitly set to False
|
||||||
|
|
||||||
|
message = OpenAIUserMessageParam(
|
||||||
|
role="user",
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": "What's in this image?"},
|
||||||
|
{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
|
||||||
|
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
|
||||||
|
await mixin.openai_chat_completion(model="test-model", messages=[message])
|
||||||
|
|
||||||
|
mock_localize.assert_not_called()
|
||||||
|
|
||||||
|
mock_client.chat.completions.create.assert_called_once()
|
||||||
|
call_args = mock_client.chat.completions.create.call_args
|
||||||
|
processed_messages = call_args[1]["messages"]
|
||||||
|
assert len(processed_messages) == 1
|
||||||
|
content = processed_messages[0]["content"]
|
||||||
|
assert len(content) == 2
|
||||||
|
assert content[1]["image_url"]["url"] == "http://example.com/image.jpg"
|
||||||
|
|
||||||
|
|
||||||
class TestOpenAIMixinEmbeddingModelMetadata:
|
class TestOpenAIMixinEmbeddingModelMetadata:
|
||||||
"""Test cases for embedding_model_metadata attribute functionality"""
|
"""Test cases for embedding_model_metadata attribute functionality"""
|
||||||
|
|
||||||
|
|
|
@ -129,7 +129,7 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
|
||||||
|
|
||||||
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
|
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.embedding_model == duplicate_vector_db.embedding_model # Original values preserved
|
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
|
||||||
|
|
||||||
|
|
||||||
async def test_get_all_objects(cached_disk_dist_registry):
|
async def test_get_all_objects(cached_disk_dist_registry):
|
||||||
|
@ -174,14 +174,10 @@ async def test_parse_registry_values_error_handling(sqlite_kvstore):
|
||||||
)
|
)
|
||||||
|
|
||||||
await sqlite_kvstore.set(
|
await sqlite_kvstore.set(
|
||||||
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"),
|
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json()
|
||||||
valid_db.model_dump_json(),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
await sqlite_kvstore.set(
|
await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
|
||||||
KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"),
|
|
||||||
"{not valid json",
|
|
||||||
)
|
|
||||||
|
|
||||||
await sqlite_kvstore.set(
|
await sqlite_kvstore.set(
|
||||||
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
|
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
|
||||||
|
@ -216,8 +212,7 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
|
||||||
)
|
)
|
||||||
|
|
||||||
await sqlite_kvstore.set(
|
await sqlite_kvstore.set(
|
||||||
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"),
|
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json()
|
||||||
valid_db.model_dump_json(),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
await sqlite_kvstore.set(
|
await sqlite_kvstore.set(
|
||||||
|
|
12
uv.lock
generated
12
uv.lock
generated
|
@ -1749,7 +1749,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "llama-stack"
|
name = "llama-stack"
|
||||||
version = "0.2.22"
|
version = "0.2.23"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "aiohttp" },
|
{ name = "aiohttp" },
|
||||||
|
@ -1885,8 +1885,8 @@ requires-dist = [
|
||||||
{ name = "huggingface-hub", specifier = ">=0.34.0,<1.0" },
|
{ name = "huggingface-hub", specifier = ">=0.34.0,<1.0" },
|
||||||
{ name = "jinja2", specifier = ">=3.1.6" },
|
{ name = "jinja2", specifier = ">=3.1.6" },
|
||||||
{ name = "jsonschema" },
|
{ name = "jsonschema" },
|
||||||
{ name = "llama-stack-client", specifier = ">=0.2.22" },
|
{ name = "llama-stack-client", specifier = ">=0.2.23" },
|
||||||
{ name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.22" },
|
{ name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.23" },
|
||||||
{ name = "openai", specifier = ">=1.100.0" },
|
{ name = "openai", specifier = ">=1.100.0" },
|
||||||
{ name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" },
|
{ name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" },
|
||||||
{ name = "opentelemetry-sdk", specifier = ">=1.30.0" },
|
{ name = "opentelemetry-sdk", specifier = ">=1.30.0" },
|
||||||
|
@ -1993,7 +1993,7 @@ unit = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "llama-stack-client"
|
name = "llama-stack-client"
|
||||||
version = "0.2.22"
|
version = "0.2.23"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "anyio" },
|
{ name = "anyio" },
|
||||||
|
@ -2012,9 +2012,9 @@ dependencies = [
|
||||||
{ name = "tqdm" },
|
{ name = "tqdm" },
|
||||||
{ name = "typing-extensions" },
|
{ name = "typing-extensions" },
|
||||||
]
|
]
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/60/80/4260816bfaaa889d515206c9df4906d08d405bf94c9b4d1be399b1923e46/llama_stack_client-0.2.22.tar.gz", hash = "sha256:9a0bc756b91ebd539858eeaf1f231c5e5c6900e1ea4fcced726c6717f3d27ca7", size = 318309, upload-time = "2025-09-16T19:43:33.212Z" }
|
sdist = { url = "https://files.pythonhosted.org/packages/9f/8f/306d5fcf2f97b3a6251219b03c194836a2ff4e0fcc8146c9970e50a72cd3/llama_stack_client-0.2.23.tar.gz", hash = "sha256:68f34e8ac8eea6a73ed9d4977d849992b2d8bd835804d770a11843431cd5bf74", size = 322288, upload-time = "2025-09-26T21:11:08.342Z" }
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/d1/8e/1ebf6ac0dbb62b81038e856ed00768e283d927b14fcd614e3018a227092b/llama_stack_client-0.2.22-py3-none-any.whl", hash = "sha256:b260d73aec56fcfd8fa601b3b34c2f83c4fbcfb7261a246b02bbdf6c2da184fe", size = 369901, upload-time = "2025-09-16T19:43:32.089Z" },
|
{ url = "https://files.pythonhosted.org/packages/fa/75/3eb58e092a681804013dbec7b7f549d18f55acf6fd6e6b27de7e249766d8/llama_stack_client-0.2.23-py3-none-any.whl", hash = "sha256:eee42c74eee8f218f9455e5a06d5d4be43f8a8c82a7937ef51ce367f916df847", size = 379809, upload-time = "2025-09-26T21:11:06.856Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue