Merge remote-tracking branch 'upstream/main'

This commit is contained in:
Raghotham Murthy 2025-06-23 14:03:50 -07:00
commit ee96c4891b
181 changed files with 18069 additions and 1469 deletions

View file

@ -4,10 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import sys
from collections.abc import AsyncIterator
from datetime import datetime
from enum import Enum
from enum import StrEnum
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field
@ -40,14 +39,6 @@ from .openai_responses import (
OpenAIResponseText,
)
# TODO: use enum.StrEnum when we drop support for python 3.10
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
class StrEnum(str, Enum):
"""Backport of StrEnum for Python 3.10 and below."""
class Attachment(BaseModel):
"""An attachment to an agent turn.

View file

@ -9,6 +9,7 @@ from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
from llama_stack.apis.vector_io import SearchRankingOptions as FileSearchRankingOptions
from llama_stack.schema_utils import json_schema_type, register_schema
# NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably
@ -81,6 +82,15 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
type: Literal["web_search_call"] = "web_search_call"
@json_schema_type
class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
id: str
queries: list[str]
status: str
type: Literal["file_search_call"] = "file_search_call"
results: list[dict[str, Any]] | None = None
@json_schema_type
class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
call_id: str
@ -119,6 +129,7 @@ class OpenAIResponseOutputMessageMCPListTools(BaseModel):
OpenAIResponseOutput = Annotated[
OpenAIResponseMessage
| OpenAIResponseOutputMessageWebSearchToolCall
| OpenAIResponseOutputMessageFileSearchToolCall
| OpenAIResponseOutputMessageFunctionToolCall
| OpenAIResponseOutputMessageMCPCall
| OpenAIResponseOutputMessageMCPListTools,
@ -362,6 +373,7 @@ class OpenAIResponseInputFunctionToolCallOutput(BaseModel):
OpenAIResponseInput = Annotated[
# Responses API allows output messages to be passed in as input
OpenAIResponseOutputMessageWebSearchToolCall
| OpenAIResponseOutputMessageFileSearchToolCall
| OpenAIResponseOutputMessageFunctionToolCall
| OpenAIResponseInputFunctionToolCallOutput
|
@ -389,17 +401,13 @@ class OpenAIResponseInputToolFunction(BaseModel):
strict: bool | None = None
class FileSearchRankingOptions(BaseModel):
ranker: str | None = None
score_threshold: float | None = Field(default=0.0, ge=0.0, le=1.0)
@json_schema_type
class OpenAIResponseInputToolFileSearch(BaseModel):
type: Literal["file_search"] = "file_search"
vector_store_id: list[str]
vector_store_ids: list[str]
filters: dict[str, Any] | None = None
max_num_results: int | None = Field(default=10, ge=1, le=50)
ranking_options: FileSearchRankingOptions | None = None
# TODO: add filters
class ApprovalFilter(BaseModel):

View file

@ -23,7 +23,9 @@ class PaginatedResponse(BaseModel):
:param data: The list of items for the current page
:param has_more: Whether there are more items available after this set
:param url: The URL for accessing this list
"""
data: list[dict[str, Any]]
has_more: bool
url: str | None = None

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import sys
from collections.abc import AsyncIterator
from enum import Enum
from typing import (
@ -37,15 +36,7 @@ register_schema(ToolCall)
register_schema(ToolParamDefinition)
register_schema(ToolDefinition)
# TODO: use enum.StrEnum when we drop support for python 3.10
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
class StrEnum(str, Enum):
"""Backport of StrEnum for Python 3.10 and below."""
pass
from enum import StrEnum
@json_schema_type
@ -1038,6 +1029,8 @@ class InferenceProvider(Protocol):
# vLLM-specific parameters
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
# for fill-in-the-middle type completion
suffix: str | None = None,
) -> OpenAICompletion:
"""Generate an OpenAI-compatible completion for the given prompt using the specified model.
@ -1058,6 +1051,7 @@ class InferenceProvider(Protocol):
:param temperature: (Optional) The temperature to use.
:param top_p: (Optional) The top p to use.
:param user: (Optional) The user to use.
:param suffix: (Optional) The suffix that should be appended to the completion.
:returns: An OpenAICompletion.
"""
...

View file

@ -4,21 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import sys
from enum import Enum
from enum import StrEnum
from pydantic import BaseModel, Field
# TODO: use enum.StrEnum when we drop support for python 3.10
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
class StrEnum(str, Enum):
"""Backport of StrEnum for Python 3.10 and below."""
pass
class ResourceType(StrEnum):
model = "model"

View file

@ -5,8 +5,7 @@
# the root directory of this source tree.
# TODO: use enum.StrEnum when we drop support for python 3.10
import sys
from enum import Enum
from enum import StrEnum
from typing import (
Annotated,
Any,
@ -21,15 +20,6 @@ from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
class StrEnum(str, Enum):
"""Backport of StrEnum for Python 3.10 and below."""
pass
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
# with standard metrics so they can be rolled up?

View file

@ -15,6 +15,48 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type
class RRFRanker(BaseModel):
"""
Reciprocal Rank Fusion (RRF) ranker configuration.
:param type: The type of ranker, always "rrf"
:param impact_factor: The impact factor for RRF scoring. Higher values give more weight to higher-ranked results.
Must be greater than 0. Default of 60 is from the original RRF paper (Cormack et al., 2009).
"""
type: Literal["rrf"] = "rrf"
impact_factor: float = Field(default=60.0, gt=0.0) # default of 60 for optimal performance
@json_schema_type
class WeightedRanker(BaseModel):
"""
Weighted ranker configuration that combines vector and keyword scores.
:param type: The type of ranker, always "weighted"
:param alpha: Weight factor between 0 and 1.
0 means only use keyword scores,
1 means only use vector scores,
values in between blend both scores.
"""
type: Literal["weighted"] = "weighted"
alpha: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Weight factor between 0 and 1. 0 means only keyword scores, 1 means only vector scores.",
)
Ranker = Annotated[
RRFRanker | WeightedRanker,
Field(discriminator="type"),
]
register_schema(Ranker, name="Ranker")
@json_schema_type
class RAGDocument(BaseModel):
"""
@ -76,7 +118,8 @@ class RAGQueryConfig(BaseModel):
:param chunk_template: Template for formatting each retrieved chunk in the context.
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
:param mode: Search mode for retrievaleither "vector" or "keyword". Default "vector".
:param mode: Search mode for retrievaleither "vector", "keyword", or "hybrid". Default "vector".
:param ranker: Configuration for the ranker to use in hybrid search. Defaults to RRF ranker.
"""
# This config defines how a query is generated using the messages
@ -86,6 +129,7 @@ class RAGQueryConfig(BaseModel):
max_chunks: int = 5
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
mode: str | None = None
ranker: Ranker | None = Field(default=None) # Only used for hybrid mode
@field_validator("chunk_template")
def validate_chunk_template(cls, v: str) -> str:

View file

@ -8,7 +8,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Protocol, runtime_checkable
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, Field
@ -16,6 +16,7 @@ from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
from llama_stack.strong_typing.schema import register_schema
class Chunk(BaseModel):
@ -37,6 +38,197 @@ class QueryChunksResponse(BaseModel):
scores: list[float]
@json_schema_type
class VectorStoreFileCounts(BaseModel):
completed: int
cancelled: int
failed: int
in_progress: int
total: int
@json_schema_type
class VectorStoreObject(BaseModel):
"""OpenAI Vector Store object."""
id: str
object: str = "vector_store"
created_at: int
name: str | None = None
usage_bytes: int = 0
file_counts: VectorStoreFileCounts
status: str = "completed"
expires_after: dict[str, Any] | None = None
expires_at: int | None = None
last_active_at: int | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class VectorStoreCreateRequest(BaseModel):
"""Request to create a vector store."""
name: str | None = None
file_ids: list[str] = Field(default_factory=list)
expires_after: dict[str, Any] | None = None
chunking_strategy: dict[str, Any] | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class VectorStoreModifyRequest(BaseModel):
"""Request to modify a vector store."""
name: str | None = None
expires_after: dict[str, Any] | None = None
metadata: dict[str, Any] | None = None
@json_schema_type
class VectorStoreListResponse(BaseModel):
"""Response from listing vector stores."""
object: str = "list"
data: list[VectorStoreObject]
first_id: str | None = None
last_id: str | None = None
has_more: bool = False
@json_schema_type
class VectorStoreSearchRequest(BaseModel):
"""Request to search a vector store."""
query: str | list[str]
filters: dict[str, Any] | None = None
max_num_results: int = 10
ranking_options: dict[str, Any] | None = None
rewrite_query: bool = False
@json_schema_type
class VectorStoreContent(BaseModel):
type: Literal["text"]
text: str
@json_schema_type
class VectorStoreSearchResponse(BaseModel):
"""Response from searching a vector store."""
file_id: str
filename: str
score: float
attributes: dict[str, str | float | bool] | None = None
content: list[VectorStoreContent]
@json_schema_type
class VectorStoreSearchResponsePage(BaseModel):
"""Response from searching a vector store."""
object: str = "vector_store.search_results.page"
search_query: str
data: list[VectorStoreSearchResponse]
has_more: bool = False
next_page: str | None = None
@json_schema_type
class VectorStoreDeleteResponse(BaseModel):
"""Response from deleting a vector store."""
id: str
object: str = "vector_store.deleted"
deleted: bool = True
@json_schema_type
class VectorStoreChunkingStrategyAuto(BaseModel):
type: Literal["auto"] = "auto"
@json_schema_type
class VectorStoreChunkingStrategyStaticConfig(BaseModel):
chunk_overlap_tokens: int = 400
max_chunk_size_tokens: int = Field(800, ge=100, le=4096)
@json_schema_type
class VectorStoreChunkingStrategyStatic(BaseModel):
type: Literal["static"] = "static"
static: VectorStoreChunkingStrategyStaticConfig
VectorStoreChunkingStrategy = Annotated[
VectorStoreChunkingStrategyAuto | VectorStoreChunkingStrategyStatic, Field(discriminator="type")
]
register_schema(VectorStoreChunkingStrategy, name="VectorStoreChunkingStrategy")
class SearchRankingOptions(BaseModel):
ranker: str | None = None
# NOTE: OpenAI File Search Tool requires threshold to be between 0 and 1, however
# we don't guarantee that the score is between 0 and 1, so will leave this unconstrained
# and let the provider handle it
score_threshold: float | None = Field(default=0.0)
@json_schema_type
class VectorStoreFileLastError(BaseModel):
code: Literal["server_error"] | Literal["rate_limit_exceeded"]
message: str
VectorStoreFileStatus = Literal["completed"] | Literal["in_progress"] | Literal["cancelled"] | Literal["failed"]
register_schema(VectorStoreFileStatus, name="VectorStoreFileStatus")
@json_schema_type
class VectorStoreFileObject(BaseModel):
"""OpenAI Vector Store File object."""
id: str
object: str = "vector_store.file"
attributes: dict[str, Any] = Field(default_factory=dict)
chunking_strategy: VectorStoreChunkingStrategy
created_at: int
last_error: VectorStoreFileLastError | None = None
status: VectorStoreFileStatus
usage_bytes: int = 0
vector_store_id: str
@json_schema_type
class VectorStoreListFilesResponse(BaseModel):
"""Response from listing vector stores."""
object: str = "list"
data: list[VectorStoreFileObject]
first_id: str | None = None
last_id: str | None = None
has_more: bool = False
@json_schema_type
class VectorStoreFileContentsResponse(BaseModel):
"""Response from retrieving the contents of a vector store file."""
file_id: str
filename: str
attributes: dict[str, Any]
content: list[VectorStoreContent]
@json_schema_type
class VectorStoreFileDeleteResponse(BaseModel):
"""Response from deleting a vector store file."""
id: str
object: str = "vector_store.file.deleted"
deleted: bool = True
class VectorDBStore(Protocol):
def get_vector_db(self, vector_db_id: str) -> VectorDB | None: ...
@ -81,3 +273,209 @@ class VectorIO(Protocol):
:returns: A QueryChunksResponse.
"""
...
# OpenAI Vector Stores API endpoints
@webmethod(route="/openai/v1/vector_stores", method="POST")
async def openai_create_vector_store(
self,
name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
"""Creates a vector store.
:param name: A name for the vector store.
:param file_ids: A list of File IDs that the vector store should use. Useful for tools like `file_search` that can access files.
:param expires_after: The expiration policy for a vector store.
:param chunking_strategy: The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy.
:param metadata: Set of 16 key-value pairs that can be attached to an object.
:param embedding_model: The embedding model to use for this vector store.
:param embedding_dimension: The dimension of the embedding vectors (default: 384).
:param provider_id: The ID of the provider to use for this vector store.
:param provider_vector_db_id: The provider-specific vector database ID.
:returns: A VectorStoreObject representing the created vector store.
"""
...
@webmethod(route="/openai/v1/vector_stores", method="GET")
async def openai_list_vector_stores(
self,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
"""Returns a list of vector stores.
:param limit: A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
:param order: Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.
:param after: A cursor for use in pagination. `after` is an object ID that defines your place in the list.
:param before: A cursor for use in pagination. `before` is an object ID that defines your place in the list.
:returns: A VectorStoreListResponse containing the list of vector stores.
"""
...
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}", method="GET")
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
"""Retrieves a vector store.
:param vector_store_id: The ID of the vector store to retrieve.
:returns: A VectorStoreObject representing the vector store.
"""
...
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}", method="POST")
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
"""Updates a vector store.
:param vector_store_id: The ID of the vector store to update.
:param name: The name of the vector store.
:param expires_after: The expiration policy for a vector store.
:param metadata: Set of 16 key-value pairs that can be attached to an object.
:returns: A VectorStoreObject representing the updated vector store.
"""
...
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}", method="DELETE")
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
"""Delete a vector store.
:param vector_store_id: The ID of the vector store to delete.
:returns: A VectorStoreDeleteResponse indicating the deletion status.
"""
...
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/search", method="POST")
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
) -> VectorStoreSearchResponsePage:
"""Search for chunks in a vector store.
Searches a vector store for relevant chunks based on a query and optional file attribute filters.
:param vector_store_id: The ID of the vector store to search.
:param query: The query string or array for performing the search.
:param filters: Filters based on file attributes to narrow the search results.
:param max_num_results: Maximum number of results to return (1 to 50 inclusive, default 10).
:param ranking_options: Ranking options for fine-tuning the search results.
:param rewrite_query: Whether to rewrite the natural language query for vector search (default false)
:returns: A VectorStoreSearchResponse containing the search results.
"""
...
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/files", method="POST")
async def openai_attach_file_to_vector_store(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
"""Attach a file to a vector store.
:param vector_store_id: The ID of the vector store to attach the file to.
:param file_id: The ID of the file to attach to the vector store.
:param attributes: The key-value attributes stored with the file, which can be used for filtering.
:param chunking_strategy: The chunking strategy to use for the file.
:returns: A VectorStoreFileObject representing the attached file.
"""
...
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/files", method="GET")
async def openai_list_files_in_vector_store(
self,
vector_store_id: str,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
filter: VectorStoreFileStatus | None = None,
) -> VectorStoreListFilesResponse:
"""List files in a vector store.
:param vector_store_id: The ID of the vector store to list files from.
:returns: A VectorStoreListFilesResponse containing the list of files.
"""
...
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}", method="GET")
async def openai_retrieve_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
"""Retrieves a vector store file.
:param vector_store_id: The ID of the vector store containing the file to retrieve.
:param file_id: The ID of the file to retrieve.
:returns: A VectorStoreFileObject representing the file.
"""
...
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}/content", method="GET")
async def openai_retrieve_vector_store_file_contents(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
"""Retrieves the contents of a vector store file.
:param vector_store_id: The ID of the vector store containing the file to retrieve.
:param file_id: The ID of the file to retrieve.
:returns: A list of InterleavedContent representing the file contents.
"""
...
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}", method="POST")
async def openai_update_vector_store_file(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any],
) -> VectorStoreFileObject:
"""Updates a vector store file.
:param vector_store_id: The ID of the vector store containing the file to update.
:param file_id: The ID of the file to update.
:param attributes: The updated key-value attributes to store with the file.
:returns: A VectorStoreFileObject representing the updated file.
"""
...
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}", method="DELETE")
async def openai_delete_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileDeleteResponse:
"""Delete a vector store file.
:param vector_store_id: The ID of the vector store containing the file to delete.
:param file_id: The ID of the file to delete.
:returns: A VectorStoreFileDeleteResponse indicating the deletion status.
"""
...

View file

@ -11,7 +11,7 @@ import os
import shutil
import sys
from dataclasses import dataclass
from datetime import datetime, timezone
from datetime import UTC, datetime
from functools import partial
from pathlib import Path
@ -409,7 +409,7 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
d = json.load(f)
manifest = Manifest(**d)
if datetime.now(timezone.utc) > manifest.expires_on.astimezone(timezone.utc):
if datetime.now(UTC) > manifest.expires_on.astimezone(UTC):
raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
console = Console()

View file

@ -408,10 +408,10 @@ def _run_stack_build_command_from_build_config(
shutil.copy(path, run_config_file)
cprint("Build Successful!", color="green", file=sys.stderr)
cprint(f"You can find the newly-built template here: {template_path}", color="light_blue", file=sys.stderr)
cprint(f"You can find the newly-built template here: {template_path}", color="blue", file=sys.stderr)
cprint(
"You can run the new Llama Stack distro via: "
+ colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "light_blue"),
+ colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "blue"),
color="green",
file=sys.stderr,
)

View file

@ -5,9 +5,9 @@
# the root directory of this source tree.
from enum import Enum
from typing import Self
from pydantic import BaseModel, model_validator
from typing_extensions import Self
from .conditions import parse_conditions

View file

@ -43,23 +43,12 @@ def get_provider_dependencies(
config: BuildConfig | DistributionTemplate,
) -> tuple[list[str], list[str]]:
"""Get normal and special dependencies from provider configuration."""
# Extract providers based on config type
if isinstance(config, DistributionTemplate):
providers = config.providers
config = config.build_config()
providers = config.distribution_spec.providers
additional_pip_packages = config.additional_pip_packages
# TODO: This is a hack to get the dependencies for internal APIs into build
# We should have a better way to do this by formalizing the concept of "internal" APIs
# and providers, with a way to specify dependencies for them.
run_configs = config.run_configs
additional_pip_packages: list[str] = []
if run_configs:
for run_config in run_configs.values():
run_config_ = run_config.run_config(name="", providers={}, container_image=None)
if run_config_.inference_store:
additional_pip_packages.extend(run_config_.inference_store.pip_packages)
elif isinstance(config, BuildConfig):
providers = config.distribution_spec.providers
additional_pip_packages = config.additional_pip_packages
deps = []
registry = get_provider_registry(config)
for api_str, provider_or_providers in providers.items():
@ -87,8 +76,7 @@ def get_provider_dependencies(
else:
normal_deps.append(package)
if additional_pip_packages:
normal_deps.extend(additional_pip_packages)
normal_deps.extend(additional_pip_packages or [])
return list(set(normal_deps)), list(set(special_deps))
@ -113,7 +101,7 @@ def build_image(
template_or_config: str,
run_config: str | None = None,
):
container_base = build_config.distribution_spec.container_image or "python:3.10-slim"
container_base = build_config.distribution_spec.container_image or "python:3.11-slim"
normal_deps, special_deps = get_provider_dependencies(build_config)
normal_deps += SERVER_DEPENDENCIES

View file

@ -49,7 +49,7 @@ ensure_conda_env_python310() {
local env_name="$1"
local pip_dependencies="$2"
local special_pip_deps="$3"
local python_version="3.10"
local python_version="3.11"
# Check if conda command is available
if ! is_command_available conda; then

View file

@ -180,6 +180,7 @@ def get_provider_registry(
if provider_type_key in ret[api]:
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
ret[api][provider_type_key] = spec
logger.info(f"Successfully loaded external provider {provider_type_key}")
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
raise yaml_err

View file

@ -99,7 +99,7 @@ class ProviderImpl(Providers):
try:
health = await asyncio.wait_for(impl.health(), timeout=timeout)
return api_name, health
except (asyncio.TimeoutError, TimeoutError):
except TimeoutError:
return (
api_name,
HealthResponse(

View file

@ -335,7 +335,7 @@ async def instantiate_provider(
method = "get_auto_router_impl"
config = None
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps, run_config]
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps, run_config, policy]
elif isinstance(provider_spec, RoutingTableProviderSpec):
method = "get_routing_table_impl"
@ -394,9 +394,13 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
missing_methods.append((name, "signature_mismatch"))
else:
# Check if the method is actually implemented in the class
method_owner = next((cls for cls in mro if name in cls.__dict__), None)
if method_owner is None or method_owner.__name__ == protocol.__name__:
# Check if the method has a concrete implementation (not just a protocol stub)
# Find all classes in MRO that define this method
method_owners = [cls for cls in mro if name in cls.__dict__]
# Allow methods from mixins/parents, only reject if ONLY the protocol defines it
if len(method_owners) == 1 and method_owners[0].__name__ == protocol.__name__:
# Only reject if the method is ONLY defined in the protocol itself (abstract stub)
missing_methods.append((name, "not_actually_implemented"))
if missing_methods:

View file

@ -47,7 +47,7 @@ async def get_routing_table_impl(
async def get_auto_router_impl(
api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig
api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig, policy: list[AccessRule]
) -> Any:
from .datasets import DatasetIORouter
from .eval_scoring import EvalRouter, ScoringRouter
@ -78,7 +78,7 @@ async def get_auto_router_impl(
# TODO: move pass configs to routers instead
if api == Api.inference and run_config.inference_store:
inference_store = InferenceStore(run_config.inference_store)
inference_store = InferenceStore(run_config.inference_store, policy)
await inference_store.initialize()
api_to_dep_impl["store"] = inference_store

View file

@ -163,6 +163,9 @@ class InferenceRouter(Inference):
messages: list[Message] | InterleavedContent,
tool_prompt_format: ToolPromptFormat | None = None,
) -> int | None:
if not hasattr(self, "formatter") or self.formatter is None:
return None
if isinstance(messages, list):
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
else:
@ -423,6 +426,7 @@ class InferenceRouter(Inference):
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
logger.debug(
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
@ -453,6 +457,7 @@ class InferenceRouter(Inference):
user=user,
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
suffix=suffix,
)
provider = self.routing_table.get_provider_impl(model_obj.identifier)
@ -602,7 +607,7 @@ class InferenceRouter(Inference):
async def health(self) -> dict[str, HealthResponse]:
health_statuses = {}
timeout = 0.5
timeout = 1 # increasing the timeout to 1 second for health checks
for provider_id, impl in self.routing_table.impls_by_provider_id.items():
try:
# check if the provider has a health method
@ -610,7 +615,7 @@ class InferenceRouter(Inference):
continue
health = await asyncio.wait_for(impl.health(), timeout=timeout)
health_statuses[provider_id] = health
except (asyncio.TimeoutError, TimeoutError):
except TimeoutError:
health_statuses[provider_id] = HealthResponse(
status=HealthStatus.ERROR,
message=f"Health check timed out after {timeout} seconds",

View file

@ -4,14 +4,32 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import Any
from llama_stack.apis.common.content_types import (
InterleavedContent,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
SearchRankingOptions,
VectorIO,
VectorStoreDeleteResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.apis.vector_io.vector_io import (
VectorStoreChunkingStrategy,
VectorStoreFileContentsResponse,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreFileStatus,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
logger = get_logger(name=__name__, category="core")
@ -34,6 +52,31 @@ class VectorIORouter(VectorIO):
logger.debug("VectorIORouter.shutdown")
pass
async def _get_first_embedding_model(self) -> tuple[str, int] | None:
"""Get the first available embedding model identifier."""
try:
# Get all models from the routing table
all_models = await self.routing_table.get_all_with_type("model")
# Filter for embedding models
embedding_models = [
model
for model in all_models
if hasattr(model, "model_type") and model.model_type == ModelType.embedding
]
if embedding_models:
dimension = embedding_models[0].metadata.get("embedding_dimension", None)
if dimension is None:
raise ValueError(f"Embedding model {embedding_models[0].identifier} has no embedding dimension")
return embedding_models[0].identifier, dimension
else:
logger.warning("No embedding models found in the routing table")
return None
except Exception as e:
logger.error(f"Error getting embedding models: {e}")
return None
async def register_vector_db(
self,
vector_db_id: str,
@ -70,3 +113,272 @@ class VectorIORouter(VectorIO):
) -> QueryChunksResponse:
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
# OpenAI Vector Stores API endpoints
async def openai_create_vector_store(
self,
name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = None,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
# If no embedding model is provided, use the first available one
if embedding_model is None:
embedding_model_info = await self._get_first_embedding_model()
if embedding_model_info is None:
raise ValueError("No embedding model provided and no embedding models available in the system")
embedding_model, embedding_dimension = embedding_model_info
logger.info(f"No embedding model specified, using first available: {embedding_model}")
vector_db_id = name
registered_vector_db = await self.routing_table.register_vector_db(
vector_db_id,
embedding_model,
embedding_dimension,
provider_id,
provider_vector_db_id,
)
return await self.routing_table.get_provider_impl(registered_vector_db.identifier).openai_create_vector_store(
vector_db_id,
file_ids=file_ids,
expires_after=expires_after,
chunking_strategy=chunking_strategy,
metadata=metadata,
embedding_model=embedding_model,
embedding_dimension=embedding_dimension,
provider_id=registered_vector_db.provider_id,
provider_vector_db_id=registered_vector_db.provider_resource_id,
)
async def openai_list_vector_stores(
self,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
# Route to default provider for now - could aggregate from all providers in the future
# call retrieve on each vector dbs to get list of vector stores
vector_dbs = await self.routing_table.get_all_with_type("vector_db")
all_stores = []
for vector_db in vector_dbs:
try:
vector_store = await self.routing_table.get_provider_impl(
vector_db.identifier
).openai_retrieve_vector_store(vector_db.identifier)
all_stores.append(vector_store)
except Exception as e:
logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}")
continue
# Sort by created_at
reverse_order = order == "desc"
all_stores.sort(key=lambda x: x.created_at, reverse=reverse_order)
# Apply cursor-based pagination
if after:
after_index = next((i for i, store in enumerate(all_stores) if store.id == after), -1)
if after_index >= 0:
all_stores = all_stores[after_index + 1 :]
if before:
before_index = next((i for i, store in enumerate(all_stores) if store.id == before), len(all_stores))
all_stores = all_stores[:before_index]
# Apply limit
limited_stores = all_stores[:limit]
# Determine pagination info
has_more = len(all_stores) > limit
first_id = limited_stores[0].id if limited_stores else None
last_id = limited_stores[-1].id if limited_stores else None
return VectorStoreListResponse(
data=limited_stores,
has_more=has_more,
first_id=first_id,
last_id=last_id,
)
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store(vector_store_id)
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store(
vector_store_id=vector_store_id,
name=name,
expires_after=expires_after,
metadata=metadata,
)
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
result = await provider.openai_delete_vector_store(vector_store_id)
# drop from registry
await self.routing_table.unregister_vector_db(vector_store_id)
return result
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
) -> VectorStoreSearchResponsePage:
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=filters,
max_num_results=max_num_results,
ranking_options=ranking_options,
rewrite_query=rewrite_query,
)
async def openai_attach_file_to_vector_store(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_attach_file_to_vector_store(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
chunking_strategy=chunking_strategy,
)
async def openai_list_files_in_vector_store(
self,
vector_store_id: str,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]:
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store(
vector_store_id=vector_store_id,
limit=limit,
order=order,
after=after,
before=before,
filter=filter,
)
async def openai_retrieve_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_retrieve_vector_store_file_contents(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_update_vector_store_file(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any],
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
)
async def openai_delete_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileDeleteResponse:
logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_delete_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def health(self) -> dict[str, HealthResponse]:
health_statuses = {}
timeout = 1 # increasing the timeout to 1 second for health checks
for provider_id, impl in self.routing_table.impls_by_provider_id.items():
try:
# check if the provider has a health method
if not hasattr(impl, "health"):
continue
health = await asyncio.wait_for(impl.health(), timeout=timeout)
health_statuses[provider_id] = health
except TimeoutError:
health_statuses[provider_id] = HealthResponse(
status=HealthStatus.ERROR,
message=f"Health check timed out after {timeout} seconds",
)
except NotImplementedError:
health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
except Exception as e:
health_statuses[provider_id] = HealthResponse(
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
)
return health_statuses

View file

@ -9,12 +9,12 @@ import time
from abc import ABC, abstractmethod
from asyncio import Lock
from pathlib import Path
from typing import Self
from urllib.parse import parse_qs
import httpx
from jose import jwt
from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User
from llama_stack.log import get_logger
@ -84,6 +84,7 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
class OAuth2JWKSConfig(BaseModel):
# The JWKS URI for collecting public keys
uri: str
token: str | None = Field(default=None, description="token to authorise access to jwks")
key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
@ -246,9 +247,12 @@ class OAuth2TokenAuthProvider(AuthProvider):
if self.config.jwks is None:
raise ValueError("JWKS is not configured")
if time.time() - self._jwks_at > self.config.jwks.key_recheck_period:
headers = {}
if self.config.jwks.token:
headers["Authorization"] = f"Bearer {self.config.jwks.token}"
verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
async with httpx.AsyncClient(verify=verify) as client:
res = await client.get(self.config.jwks.uri, timeout=5)
res = await client.get(self.config.jwks.uri, timeout=5, headers=headers)
res.raise_for_status()
jwks_data = res.json()["keys"]
updated = {}

View file

@ -6,7 +6,7 @@
import json
import time
from datetime import datetime, timedelta, timezone
from datetime import UTC, datetime, timedelta
from starlette.types import ASGIApp, Receive, Scope, Send
@ -79,7 +79,7 @@ class QuotaMiddleware:
if int(prev) == 0:
# Set with expiration datetime when it is the first request in the window.
expiration = datetime.now(timezone.utc) + timedelta(seconds=self.window_seconds)
expiration = datetime.now(UTC) + timedelta(seconds=self.window_seconds)
await kv.set(key, str(count), expiration=expiration)
else:
await kv.set(key, str(count))

View file

@ -30,6 +30,7 @@ from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
@ -144,7 +145,7 @@ async def shutdown(app):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning("No shutdown method for %s", impl_name)
except (asyncio.TimeoutError, TimeoutError):
except TimeoutError:
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
except (Exception, asyncio.CancelledError) as e:
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
@ -230,7 +231,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
return StreamingResponse(gen, media_type="text/event-stream")
else:
value = func(**kwargs)
return await maybe_await(value)
result = await maybe_await(value)
if isinstance(result, PaginatedResponse) and result.url is None:
result.url = route
return result
except Exception as e:
logger.exception(f"Error executing endpoint {route=} {method=}")
raise translate_exception(e) from e

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -115,7 +115,7 @@ def parse_environment_config(env_config: str) -> dict[str, int]:
class CustomRichHandler(RichHandler):
def __init__(self, *args, **kwargs):
kwargs["console"] = Console(width=120)
kwargs["console"] = Console(width=150)
super().__init__(*args, **kwargs)
def emit(self, record):

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -11,7 +11,7 @@ import secrets
import string
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from datetime import UTC, datetime
import httpx
@ -242,7 +242,7 @@ class ChatAgent(ShieldRunnerMixin):
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
request.session_id, request.turn_id
)
now = datetime.now(timezone.utc).isoformat()
now = datetime.now(UTC).isoformat()
tool_execution_step = ToolExecutionStep(
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
turn_id=request.turn_id,
@ -267,7 +267,7 @@ class ChatAgent(ShieldRunnerMixin):
start_time = last_turn.started_at
else:
messages.extend(request.messages)
start_time = datetime.now(timezone.utc).isoformat()
start_time = datetime.now(UTC).isoformat()
input_messages = request.messages
output_message = None
@ -298,7 +298,7 @@ class ChatAgent(ShieldRunnerMixin):
input_messages=input_messages,
output_message=output_message,
started_at=start_time,
completed_at=datetime.now(timezone.utc).isoformat(),
completed_at=datetime.now(UTC).isoformat(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
@ -389,7 +389,7 @@ class ChatAgent(ShieldRunnerMixin):
return
step_id = str(uuid.uuid4())
shield_call_start_time = datetime.now(timezone.utc).isoformat()
shield_call_start_time = datetime.now(UTC).isoformat()
try:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -413,7 +413,7 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id,
violation=e.violation,
started_at=shield_call_start_time,
completed_at=datetime.now(timezone.utc).isoformat(),
completed_at=datetime.now(UTC).isoformat(),
),
)
)
@ -436,7 +436,7 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id,
violation=None,
started_at=shield_call_start_time,
completed_at=datetime.now(timezone.utc).isoformat(),
completed_at=datetime.now(UTC).isoformat(),
),
)
)
@ -491,7 +491,7 @@ class ChatAgent(ShieldRunnerMixin):
client_tools[tool.name] = tool
while True:
step_id = str(uuid.uuid4())
inference_start_time = datetime.now(timezone.utc).isoformat()
inference_start_time = datetime.now(UTC).isoformat()
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
@ -603,7 +603,7 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id,
model_response=copy.deepcopy(message),
started_at=inference_start_time,
completed_at=datetime.now(timezone.utc).isoformat(),
completed_at=datetime.now(UTC).isoformat(),
),
)
)
@ -681,7 +681,7 @@ class ChatAgent(ShieldRunnerMixin):
"input": message.model_dump_json(),
},
) as span:
tool_execution_start_time = datetime.now(timezone.utc).isoformat()
tool_execution_start_time = datetime.now(UTC).isoformat()
tool_result = await self.execute_tool_call_maybe(
session_id,
tool_call,
@ -710,7 +710,7 @@ class ChatAgent(ShieldRunnerMixin):
)
],
started_at=tool_execution_start_time,
completed_at=datetime.now(timezone.utc).isoformat(),
completed_at=datetime.now(UTC).isoformat(),
)
# Yield the step completion event
@ -747,7 +747,7 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id,
tool_calls=client_tool_calls,
tool_responses=[],
started_at=datetime.now(timezone.utc).isoformat(),
started_at=datetime.now(UTC).isoformat(),
),
)

View file

@ -7,7 +7,7 @@
import logging
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from datetime import UTC, datetime
from llama_stack.apis.agents import (
Agent,
@ -78,13 +78,14 @@ class MetaReferenceAgentsImpl(Agents):
async def initialize(self) -> None:
self.persistence_store = await kvstore_impl(self.config.persistence_store)
self.responses_store = ResponsesStore(self.config.responses_store)
self.responses_store = ResponsesStore(self.config.responses_store, self.policy)
await self.responses_store.initialize()
self.openai_responses_impl = OpenAIResponsesImpl(
inference_api=self.inference_api,
tool_groups_api=self.tool_groups_api,
tool_runtime_api=self.tool_runtime_api,
responses_store=self.responses_store,
vector_io_api=self.vector_io_api,
)
async def create_agent(
@ -92,7 +93,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_config: AgentConfig,
) -> AgentCreateResponse:
agent_id = str(uuid.uuid4())
created_at = datetime.now(timezone.utc)
created_at = datetime.now(UTC)
agent_info = AgentInfo(
**agent_config.model_dump(),

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
import time
import uuid
@ -24,6 +25,7 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseInputToolFileSearch,
OpenAIResponseInputToolMCP,
OpenAIResponseMessage,
OpenAIResponseObject,
@ -34,12 +36,14 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseOutput,
OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFileSearchToolCall,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseOutputMessageWebSearchToolCall,
OpenAIResponseText,
OpenAIResponseTextFormat,
)
from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference.inference import (
Inference,
OpenAIAssistantMessageParam,
@ -62,7 +66,8 @@ from llama_stack.apis.inference.inference import (
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.apis.tools.tools import ToolGroups, ToolRuntime
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
@ -198,7 +203,8 @@ class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
class ChatCompletionContext(BaseModel):
model: str
messages: list[OpenAIMessageParam]
tools: list[ChatCompletionToolParam] | None = None
response_tools: list[OpenAIResponseInputTool] | None = None
chat_tools: list[ChatCompletionToolParam] | None = None
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
temperature: float | None
response_format: OpenAIResponseFormatParam
@ -211,11 +217,13 @@ class OpenAIResponsesImpl:
tool_groups_api: ToolGroups,
tool_runtime_api: ToolRuntime,
responses_store: ResponsesStore,
vector_io_api: VectorIO, # VectorIO
):
self.inference_api = inference_api
self.tool_groups_api = tool_groups_api
self.tool_runtime_api = tool_runtime_api
self.responses_store = responses_store
self.vector_io_api = vector_io_api
async def _prepend_previous_response(
self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None
@ -388,7 +396,8 @@ class OpenAIResponsesImpl:
ctx = ChatCompletionContext(
model=model,
messages=messages,
tools=chat_tools,
response_tools=tools,
chat_tools=chat_tools,
mcp_tool_to_server=mcp_tool_to_server,
temperature=temperature,
response_format=response_format,
@ -417,7 +426,7 @@ class OpenAIResponsesImpl:
completion_result = await self.inference_api.openai_chat_completion(
model=ctx.model,
messages=messages,
tools=ctx.tools,
tools=ctx.chat_tools,
stream=True,
temperature=ctx.temperature,
response_format=ctx.response_format,
@ -606,6 +615,12 @@ class OpenAIResponsesImpl:
if not tool:
raise ValueError(f"Tool {tool_name} not found")
chat_tools.append(make_openai_tool(tool_name, tool))
elif input_tool.type == "file_search":
tool_name = "knowledge_search"
tool = await self.tool_groups_api.get_tool(tool_name)
if not tool:
raise ValueError(f"Tool {tool_name} not found")
chat_tools.append(make_openai_tool(tool_name, tool))
elif input_tool.type == "mcp":
always_allowed = None
never_allowed = None
@ -656,6 +671,71 @@ class OpenAIResponsesImpl:
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
return chat_tools, mcp_tool_to_server, mcp_list_message
async def _execute_knowledge_search_via_vector_store(
self,
query: str,
response_file_search_tool: OpenAIResponseInputToolFileSearch,
) -> ToolInvocationResult:
"""Execute knowledge search using vector_stores.search API with filters support."""
search_results = []
# Create search tasks for all vector stores
async def search_single_store(vector_store_id):
try:
search_response = await self.vector_io_api.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=response_file_search_tool.filters,
max_num_results=response_file_search_tool.max_num_results,
ranking_options=response_file_search_tool.ranking_options,
rewrite_query=False,
)
return search_response.data
except Exception as e:
logger.warning(f"Failed to search vector store {vector_store_id}: {e}")
return []
# Run all searches in parallel using gather
search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids]
all_results = await asyncio.gather(*search_tasks)
# Flatten results
for results in all_results:
search_results.extend(results)
# Convert search results to tool result format matching memory.py
# Format the results as interleaved content similar to memory.py
content_items = []
content_items.append(
TextContentItem(
text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n"
)
)
for i, result_item in enumerate(search_results):
chunk_text = result_item.content[0].text if result_item.content else ""
metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
if result_item.attributes:
metadata_text += f", attributes: {result_item.attributes}"
text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
content_items.append(TextContentItem(text=text_content))
content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
content_items.append(
TextContentItem(
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
)
)
return ToolInvocationResult(
content=content_items,
metadata={
"document_ids": [r.file_id for r in search_results],
"chunks": [r.content[0].text if r.content else "" for r in search_results],
"scores": [r.score for r in search_results],
},
)
async def _execute_tool_call(
self,
tool_call: OpenAIChatCompletionToolCall,
@ -667,6 +747,7 @@ class OpenAIResponsesImpl:
tool_call_id = tool_call.id
function = tool_call.function
tool_kwargs = json.loads(function.arguments) if function.arguments else {}
if not function or not tool_call_id or not function.name:
return None, None
@ -680,12 +761,24 @@ class OpenAIResponsesImpl:
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
tool_name=function.name,
kwargs=json.loads(function.arguments) if function.arguments else {},
kwargs=tool_kwargs,
)
elif function.name == "knowledge_search":
response_file_search_tool = next(
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), None
)
if response_file_search_tool:
# Use vector_stores.search API instead of knowledge_search tool
# to support filters and ranking_options
query = tool_kwargs.get("query", "")
result = await self._execute_knowledge_search_via_vector_store(
query=query,
response_file_search_tool=response_file_search_tool,
)
else:
result = await self.tool_runtime_api.invoke_tool(
tool_name=function.name,
kwargs=json.loads(function.arguments) if function.arguments else {},
kwargs=tool_kwargs,
)
except Exception as e:
error_exc = e
@ -713,6 +806,27 @@ class OpenAIResponsesImpl:
)
if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
message.status = "failed"
elif function.name == "knowledge_search":
message = OpenAIResponseOutputMessageFileSearchToolCall(
id=tool_call_id,
queries=[tool_kwargs.get("query", "")],
status="completed",
)
if "document_ids" in result.metadata:
message.results = []
for i, doc_id in enumerate(result.metadata["document_ids"]):
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
score = result.metadata["scores"][i] if "scores" in result.metadata else None
message.results.append(
{
"file_id": doc_id,
"filename": doc_id,
"text": text,
"score": score,
}
)
if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
message.status = "failed"
else:
raise ValueError(f"Unknown tool {function.name} called")

View file

@ -7,7 +7,7 @@
import json
import logging
import uuid
from datetime import datetime, timezone
from datetime import UTC, datetime
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
@ -47,7 +47,7 @@ class AgentPersistence:
session_info = AgentSessionInfo(
session_id=session_id,
session_name=name,
started_at=datetime.now(timezone.utc),
started_at=datetime.now(UTC),
owner=user,
turns=[],
identifier=name, # should this be qualified in any way?

View file

@ -114,18 +114,18 @@ class LocalfsFilesImpl(Files):
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
# TODO: Implement 'after' pagination properly
if after:
raise NotImplementedError("After pagination not yet implemented")
if not order:
order = Order.desc
where = None
where_conditions = {}
if purpose:
where = {"purpose": purpose.value}
where_conditions["purpose"] = purpose.value
rows = await self.sql_store.fetch_all(
"openai_files",
where=where,
order_by=[("created_at", order.value if order else Order.desc.value)],
paginated_result = await self.sql_store.fetch_all(
table="openai_files",
where=where_conditions if where_conditions else None,
order_by=[("created_at", order.value)],
cursor=("id", after) if after else None,
limit=limit,
)
@ -138,12 +138,12 @@ class LocalfsFilesImpl(Files):
created_at=row["created_at"],
expires_at=row["expires_at"],
)
for row in rows
for row in paginated_result.data
]
return ListOpenAIFileResponse(
data=files,
has_more=False, # TODO: Implement proper pagination
has_more=paginated_result.has_more,
first_id=files[0].id if files else "",
last_id=files[-1].id if files else "",
)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -11,7 +11,7 @@ import multiprocessing
import os
import signal
import sys
from datetime import datetime, timezone
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
@ -670,7 +670,7 @@ class HFFinetuningSingleDevice:
# Create checkpoint
checkpoint = Checkpoint(
identifier=f"{model}-sft-{config.n_epochs}",
created_at=datetime.now(timezone.utc),
created_at=datetime.now(UTC),
epoch=config.n_epochs,
post_training_job_id=job_uuid,
path=str(output_dir_path / "merged_model"),

View file

@ -7,7 +7,7 @@
import logging
import os
import time
from datetime import datetime, timezone
from datetime import UTC, datetime
from functools import partial
from pathlib import Path
from typing import Any
@ -537,7 +537,7 @@ class LoraFinetuningSingleDevice:
checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
checkpoint = Checkpoint(
identifier=f"{self.model_id}-sft-{curr_epoch}",
created_at=datetime.now(timezone.utc),
created_at=datetime.now(UTC),
epoch=curr_epoch,
post_training_job_id=self.job_uuid,
path=checkpoint_path,

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import json
from datetime import datetime, timezone
from datetime import UTC, datetime
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import SpanProcessor
@ -34,7 +34,7 @@ class ConsoleSpanProcessor(SpanProcessor):
if span.attributes and span.attributes.get("__autotraced__"):
return
timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=timezone.utc).strftime("%H:%M:%S.%f")[:-3]
timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
print(
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
@ -46,7 +46,7 @@ class ConsoleSpanProcessor(SpanProcessor):
if span.attributes and span.attributes.get("__autotraced__"):
return
timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=timezone.utc).strftime("%H:%M:%S.%f")[:-3]
timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
span_context = (
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
@ -74,7 +74,7 @@ class ConsoleSpanProcessor(SpanProcessor):
print(f" {COLORS['dim']}{key}: {str_value}{COLORS['reset']}")
for event in span.events:
event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=timezone.utc).strftime("%H:%M:%S.%f")[:-3]
event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
severity = event.attributes.get("severity", "info")
message = event.attributes.get("message", event.name)

View file

@ -8,7 +8,7 @@ import json
import os
import sqlite3
import threading
from datetime import datetime, timezone
from datetime import UTC, datetime
from opentelemetry.sdk.trace import SpanProcessor
from opentelemetry.trace import Span
@ -125,8 +125,8 @@ class SQLiteSpanProcessor(SpanProcessor):
trace_id,
service_name,
(span_id if span.attributes.get("__root_span__") == "true" else None),
datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(),
datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(),
datetime.fromtimestamp(span.start_time / 1e9, UTC).isoformat(),
datetime.fromtimestamp(span.end_time / 1e9, UTC).isoformat(),
),
)
@ -144,8 +144,8 @@ class SQLiteSpanProcessor(SpanProcessor):
trace_id,
parent_span_id,
span.name,
datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(),
datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(),
datetime.fromtimestamp(span.start_time / 1e9, UTC).isoformat(),
datetime.fromtimestamp(span.end_time / 1e9, UTC).isoformat(),
json.dumps(dict(span.attributes)),
span.status.status_code.name,
span.kind.name,
@ -162,7 +162,7 @@ class SQLiteSpanProcessor(SpanProcessor):
(
span_id,
event.name,
datetime.fromtimestamp(event.timestamp / 1e9, timezone.utc).isoformat(),
datetime.fromtimestamp(event.timestamp / 1e9, UTC).isoformat(),
json.dumps(dict(event.attributes)),
),
)

View file

@ -121,8 +121,10 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
vector_db_id=vector_db_id,
query=query,
params={
"max_chunks": query_config.max_chunks,
"mode": query_config.mode,
"max_chunks": query_config.max_chunks,
"score_threshold": 0.0,
"ranker": query_config.ranker,
},
)
for vector_db_id in vector_db_ids
@ -170,6 +172,8 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
content=picked,
metadata={
"document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
"chunks": [c.content for c in chunks[: len(picked)]],
"scores": scores[: len(picked)],
},
)

View file

@ -16,6 +16,6 @@ async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
impl = FaissVectorIOAdapter(config, deps[Api.inference])
impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
await impl.initialize()
return impl

View file

@ -15,13 +15,23 @@ import faiss
import numpy as np
from numpy.typing import NDArray
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.files import Files
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
)
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
VectorDBsProtocolPrivate,
)
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
@ -34,6 +44,9 @@ logger = logging.getLogger(__name__)
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:{VERSION}::"
class FaissIndex(EmbeddingIndex):
@ -112,7 +125,7 @@ class FaissIndex(EmbeddingIndex):
if i < 0:
continue
chunks.append(self.chunk_by_index[int(i)])
scores.append(1.0 / float(d))
scores.append(1.0 / float(d) if d != 0 else float("inf"))
return QueryChunksResponse(chunks=chunks, scores=scores)
@ -124,13 +137,26 @@ class FaissIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in FAISS")
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in FAISS")
class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
self.config = config
self.inference_api = inference_api
self.files_api = files_api
self.cache: dict[str, VectorDBWithIndex] = {}
self.kvstore: KVStore | None = None
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
@ -148,10 +174,29 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
)
self.cache[vector_db.identifier] = index
# Load existing OpenAI vector stores using the mixin method
self.openai_vector_stores = await self._load_openai_vector_stores()
async def shutdown(self) -> None:
# Cleanup if needed
pass
async def health(self) -> HealthResponse:
"""
Performs a health check by verifying connectivity to the inline faiss DB.
This method is used by the Provider API to verify
that the service is running correctly.
Returns:
HealthResponse: A dictionary containing the health status.
"""
try:
vector_dimension = 128 # sample dimension
faiss.IndexFlatL2(vector_dimension)
return HealthResponse(status=HealthStatus.OK)
except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
async def register_vector_db(
self,
vector_db: VectorDB,
@ -208,3 +253,71 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
raise ValueError(f"Vector DB {vector_db_id} not found")
return await index.query_chunks(query, params)
# OpenAI Vector Store Mixin abstract method implementations
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
"""Load all vector store metadata from kvstore."""
assert self.kvstore is not None
start_key = OPENAI_VECTOR_STORES_PREFIX
end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
stores = {}
for store_data in stored_openai_stores:
store_info = json.loads(store_data)
stores[store_info["id"]] = store_info
return stores
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Update vector store metadata in kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
"""Delete vector store metadata from kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.delete(key)
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
"""Save vector store file metadata to kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
await self.kvstore.set(key=key, value=json.dumps(file_info))
content_key = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}"
await self.kvstore.set(key=content_key, value=json.dumps(file_contents))
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
"""Load vector store file metadata from kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
stored_data = await self.kvstore.get(key)
return json.loads(stored_data) if stored_data else {}
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
"""Load vector store file contents from kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}"
stored_data = await self.kvstore.get(key)
return json.loads(stored_data) if stored_data else []
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
"""Update vector store file metadata in kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
await self.kvstore.set(key=key, value=json.dumps(file_info))
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
"""Delete vector store file metadata from kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
await self.kvstore.delete(key)

View file

@ -15,6 +15,6 @@ async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
from .sqlite_vec import SQLiteVecVectorIOAdapter
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference])
impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
await impl.initialize()
return impl

View file

@ -6,6 +6,7 @@
import asyncio
import hashlib
import json
import logging
import sqlite3
import struct
@ -16,18 +17,30 @@ import numpy as np
import sqlite_vec
from numpy.typing import NDArray
from llama_stack.apis.files.files import Files
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
)
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_RRF,
RERANKER_TYPE_WEIGHTED,
EmbeddingIndex,
VectorDBWithIndex,
)
logger = logging.getLogger(__name__)
# Specifying search mode is dependent on the VectorIO provider.
VECTOR_SEARCH = "vector"
KEYWORD_SEARCH = "keyword"
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH}
HYBRID_SEARCH = "hybrid"
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH, HYBRID_SEARCH}
def serialize_vector(vector: list[float]) -> bytes:
@ -44,6 +57,59 @@ def _create_sqlite_connection(db_path):
return connection
def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
"""Normalize scores to [0,1] range using min-max normalization."""
if not scores:
return {}
min_score = min(scores.values())
max_score = max(scores.values())
score_range = max_score - min_score
if score_range > 0:
return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()}
return {doc_id: 1.0 for doc_id in scores}
def _weighted_rerank(
vector_scores: dict[str, float],
keyword_scores: dict[str, float],
alpha: float = 0.5,
) -> dict[str, float]:
"""ReRanker that uses weighted average of scores."""
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
normalized_vector_scores = _normalize_scores(vector_scores)
normalized_keyword_scores = _normalize_scores(keyword_scores)
return {
doc_id: (alpha * normalized_keyword_scores.get(doc_id, 0.0))
+ ((1 - alpha) * normalized_vector_scores.get(doc_id, 0.0))
for doc_id in all_ids
}
def _rrf_rerank(
vector_scores: dict[str, float],
keyword_scores: dict[str, float],
impact_factor: float = 60.0,
) -> dict[str, float]:
"""ReRanker that uses Reciprocal Rank Fusion."""
# Convert scores to ranks
vector_ranks = {
doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True))
}
keyword_ranks = {
doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True))
}
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
rrf_scores = {}
for doc_id in all_ids:
vector_rank = vector_ranks.get(doc_id, float("inf"))
keyword_rank = keyword_ranks.get(doc_id, float("inf"))
# RRF formula: score = 1/(k + r) where k is impact_factor and r is the rank
rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank))
return rrf_scores
class SQLiteVecIndex(EmbeddingIndex):
"""
An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.
@ -248,8 +314,6 @@ class SQLiteVecIndex(EmbeddingIndex):
"""
Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
"""
if query_string is None:
raise ValueError("query_string is required for keyword search.")
def _execute_query():
connection = _create_sqlite_connection(self.db_path)
@ -287,18 +351,95 @@ class SQLiteVecIndex(EmbeddingIndex):
scores.append(score)
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
reranker_type: str = RERANKER_TYPE_RRF,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
"""
Hybrid search using a configurable re-ranking strategy.
class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
Args:
embedding: The query embedding vector
query_string: The text query for keyword search
k: Number of results to return
score_threshold: Minimum similarity score threshold
reranker_type: Type of reranker to use ("rrf" or "weighted")
reranker_params: Parameters for the reranker
Returns:
QueryChunksResponse with combined results
"""
if reranker_params is None:
reranker_params = {}
# Get results from both search methods
vector_response = await self.query_vector(embedding, k, score_threshold)
keyword_response = await self.query_keyword(query_string, k, score_threshold)
# Convert responses to score dictionaries using generate_chunk_id
vector_scores = {
generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
}
keyword_scores = {
generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
}
# Combine scores using the specified reranker
if reranker_type == RERANKER_TYPE_WEIGHTED:
alpha = reranker_params.get("alpha", 0.5)
combined_scores = _weighted_rerank(vector_scores, keyword_scores, alpha)
else:
# Default to RRF for None, RRF, or any unknown types
impact_factor = reranker_params.get("impact_factor", 60.0)
combined_scores = _rrf_rerank(vector_scores, keyword_scores, impact_factor)
# Sort by combined score and get top k results
sorted_items = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)
top_k_items = sorted_items[:k]
# Filter by score threshold
filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]
# Create a map of chunk_id to chunk for both responses
chunk_map = {}
for c in vector_response.chunks:
chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
chunk_map[chunk_id] = c
for c in keyword_response.chunks:
chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
chunk_map[chunk_id] = c
# Use the map to look up chunks by their IDs
chunks = []
scores = []
for doc_id, score in filtered_items:
if doc_id in chunk_map:
chunks.append(chunk_map[doc_id])
scores.append(score)
return QueryChunksResponse(chunks=chunks, scores=scores)
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
"""
A VectorIO implementation using SQLite + sqlite_vec.
This class handles vector database registration (with metadata stored in a table named `vector_dbs`)
and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
"""
def __init__(self, config, inference_api: Inference) -> None:
def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
self.config = config
self.inference_api = inference_api
self.files_api = files_api
self.cache: dict[str, VectorDBWithIndex] = {}
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
async def initialize(self) -> None:
def _setup_connection():
@ -313,24 +454,55 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
metadata TEXT
);
""")
# Create a table to persist OpenAI vector stores.
cur.execute("""
CREATE TABLE IF NOT EXISTS openai_vector_stores (
id TEXT PRIMARY KEY,
metadata TEXT
);
""")
# Create a table to persist OpenAI vector store files.
cur.execute("""
CREATE TABLE IF NOT EXISTS openai_vector_store_files (
store_id TEXT,
file_id TEXT,
metadata TEXT,
PRIMARY KEY (store_id, file_id)
);
""")
cur.execute("""
CREATE TABLE IF NOT EXISTS openai_vector_store_files_contents (
store_id TEXT,
file_id TEXT,
contents TEXT,
PRIMARY KEY (store_id, file_id)
);
""")
connection.commit()
# Load any existing vector DB registrations.
cur.execute("SELECT metadata FROM vector_dbs")
rows = cur.fetchall()
return rows
vector_db_rows = cur.fetchall()
return vector_db_rows
finally:
cur.close()
connection.close()
rows = await asyncio.to_thread(_setup_connection)
for row in rows:
vector_db_rows = await asyncio.to_thread(_setup_connection)
# Load existing vector DBs
for row in vector_db_rows:
vector_db_data = row[0]
vector_db = VectorDB.model_validate_json(vector_db_data)
index = await SQLiteVecIndex.create(
vector_db.embedding_dimension, self.config.db_path, vector_db.identifier
vector_db.embedding_dimension,
self.config.db_path,
vector_db.identifier,
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
# Load existing OpenAI vector stores using the mixin method
self.openai_vector_stores = await self._load_openai_vector_stores()
async def shutdown(self) -> None:
# nothing to do since we don't maintain a persistent connection
pass
@ -350,7 +522,11 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
connection.close()
await asyncio.to_thread(_register_db)
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier)
index = await SQLiteVecIndex.create(
vector_db.embedding_dimension,
self.config.db_path,
vector_db.identifier,
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
async def list_vector_dbs(self) -> list[VectorDB]:
@ -375,6 +551,199 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
await asyncio.to_thread(_delete_vector_db_from_registry)
# OpenAI Vector Store Mixin abstract method implementations
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to SQLite database."""
def _store():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"INSERT OR REPLACE INTO openai_vector_stores (id, metadata) VALUES (?, ?)",
(store_id, json.dumps(store_info)),
)
connection.commit()
except Exception as e:
logger.error(f"Error saving openai vector store {store_id}: {e}")
raise
finally:
cur.close()
connection.close()
try:
await asyncio.to_thread(_store)
except Exception as e:
logger.error(f"Error saving openai vector store {store_id}: {e}")
raise
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
"""Load all vector store metadata from SQLite database."""
def _load():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute("SELECT metadata FROM openai_vector_stores")
rows = cur.fetchall()
return rows
finally:
cur.close()
connection.close()
rows = await asyncio.to_thread(_load)
stores = {}
for row in rows:
store_data = row[0]
store_info = json.loads(store_data)
stores[store_info["id"]] = store_info
return stores
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Update vector store metadata in SQLite database."""
def _update():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"UPDATE openai_vector_stores SET metadata = ? WHERE id = ?",
(json.dumps(store_info), store_id),
)
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_update)
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
"""Delete vector store metadata from SQLite database."""
def _delete():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute("DELETE FROM openai_vector_stores WHERE id = ?", (store_id,))
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_delete)
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
"""Save vector store file metadata to SQLite database."""
def _store():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"INSERT OR REPLACE INTO openai_vector_store_files (store_id, file_id, metadata) VALUES (?, ?, ?)",
(store_id, file_id, json.dumps(file_info)),
)
cur.execute(
"INSERT OR REPLACE INTO openai_vector_store_files_contents (store_id, file_id, contents) VALUES (?, ?, ?)",
(store_id, file_id, json.dumps(file_contents)),
)
connection.commit()
except Exception as e:
logger.error(f"Error saving openai vector store file {store_id} {file_id}: {e}")
raise
finally:
cur.close()
connection.close()
try:
await asyncio.to_thread(_store)
except Exception as e:
logger.error(f"Error saving openai vector store file {store_id} {file_id}: {e}")
raise
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
"""Load vector store file metadata from SQLite database."""
def _load():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"SELECT metadata FROM openai_vector_store_files WHERE store_id = ? AND file_id = ?",
(store_id, file_id),
)
row = cur.fetchone()
if row is None:
return None
(metadata,) = row
return metadata
finally:
cur.close()
connection.close()
stored_data = await asyncio.to_thread(_load)
return json.loads(stored_data) if stored_data else {}
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
"""Load vector store file contents from SQLite database."""
def _load():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"SELECT contents FROM openai_vector_store_files_contents WHERE store_id = ? AND file_id = ?",
(store_id, file_id),
)
row = cur.fetchone()
if row is None:
return None
(contents,) = row
return contents
finally:
cur.close()
connection.close()
stored_contents = await asyncio.to_thread(_load)
return json.loads(stored_contents) if stored_contents else []
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
"""Update vector store file metadata in SQLite database."""
def _update():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"UPDATE openai_vector_store_files SET metadata = ? WHERE store_id = ? AND file_id = ?",
(json.dumps(file_info), store_id, file_id),
)
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_update)
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
"""Delete vector store file metadata from SQLite database."""
def _delete():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"DELETE FROM openai_vector_store_files WHERE store_id = ? AND file_id = ?", (store_id, file_id)
)
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_delete)
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
if vector_db_id not in self.cache:
raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")

View file

@ -24,7 +24,7 @@ def available_providers() -> list[ProviderSpec]:
"pandas",
"scikit-learn",
]
+ kvstore_dependencies(),
+ kvstore_dependencies(), # TODO make this dynamic based on the kvstore config
module="llama_stack.providers.inline.agents.meta_reference",
config_class="llama_stack.providers.inline.agents.meta_reference.MetaReferenceAgentsImplConfig",
api_dependencies=[

View file

@ -24,6 +24,7 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
deprecation_warning="Please use the `inline::faiss` provider instead.",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
InlineProviderSpec(
api=Api.vector_io,
@ -32,6 +33,7 @@ def available_providers() -> list[ProviderSpec]:
module="llama_stack.providers.inline.vector_io.faiss",
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
# NOTE: sqlite-vec cannot be bundled into the container image because it does not have a
# source distribution and the wheels are not available for all platforms.
@ -42,6 +44,7 @@ def available_providers() -> list[ProviderSpec]:
module="llama_stack.providers.inline.vector_io.sqlite_vec",
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
InlineProviderSpec(
api=Api.vector_io,
@ -51,6 +54,7 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
deprecation_warning="Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
remote_provider_spec(
Api.vector_io,

View file

@ -32,7 +32,6 @@ import os
os.environ["NVIDIA_API_KEY"] = "your-api-key"
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
os.environ["NVIDIA_USER_ID"] = "llama-stack-user"
os.environ["NVIDIA_DATASET_NAMESPACE"] = "default"
os.environ["NVIDIA_PROJECT_ID"] = "test-project"
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

View file

@ -36,6 +36,10 @@ class NvidiaDatasetIOAdapter:
url = f"{self.config.datasets_url}{path}"
request_headers = self.headers.copy()
# Set default Content-Type for JSON requests
if json is not None:
request_headers["Content-Type"] = "application/json"
if headers:
request_headers.update(headers)

View file

@ -318,6 +318,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)

View file

@ -12,6 +12,9 @@ from llama_stack.providers.utils.inference.model_registry import (
LLM_MODEL_IDS = [
"gemini/gemini-1.5-flash",
"gemini/gemini-1.5-pro",
"gemini/gemini-2.0-flash",
"gemini/gemini-2.5-flash",
"gemini/gemini-2.5-pro",
]

View file

@ -316,6 +316,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
provider_model_id = await self._get_provider_model_id(model)

View file

@ -33,7 +33,6 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -46,6 +45,8 @@ from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
@ -62,8 +63,10 @@ from llama_stack.providers.utils.inference.model_registry import (
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
b64_encode_openai_embeddings_response,
get_sampling_options,
prepare_openai_completion_params,
prepare_openai_embeddings_params,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
@ -386,7 +389,35 @@ class OllamaInferenceAdapter(
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
model_obj = await self._get_model(model)
if model_obj.model_type != ModelType.embedding:
raise ValueError(f"Model {model} is not an embedding model")
if model_obj.provider_resource_id is None:
raise ValueError(f"Model {model} has no provider_resource_id set")
# Note, at the moment Ollama does not support encoding_format, dimensions, and user parameters
params = prepare_openai_embeddings_params(
model=model_obj.provider_resource_id,
input=input,
encoding_format=encoding_format,
dimensions=dimensions,
user=user,
)
response = await self.openai_client.embeddings.create(**params)
data = b64_encode_openai_embeddings_response(response.data, encoding_format)
usage = OpenAIEmbeddingUsage(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
)
# TODO: Investigate why model_obj.identifier is used instead of response.model
return OpenAIEmbeddingsResponse(
data=data,
model=model_obj.identifier,
usage=usage,
)
async def openai_completion(
self,
@ -409,6 +440,7 @@ class OllamaInferenceAdapter(
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
if not isinstance(prompt, str):
raise ValueError("Ollama does not support non-string prompts for completion")
@ -432,6 +464,7 @@ class OllamaInferenceAdapter(
temperature=temperature,
top_p=top_p,
user=user,
suffix=suffix,
)
return await self.openai_client.completions.create(**params) # type: ignore

View file

@ -90,6 +90,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
if guided_choice is not None:
logging.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
@ -117,6 +118,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
temperature=temperature,
top_p=top_p,
user=user,
suffix=suffix,
)
return await self._openai_client.completions.create(**params)

View file

@ -242,6 +242,7 @@ class PassthroughInferenceAdapter(Inference):
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
client = self._get_client()
model_obj = await self.model_store.get_model(model)

View file

@ -299,6 +299,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(

View file

@ -38,7 +38,9 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
LogProbConfig,
Message,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -56,7 +58,11 @@ from llama_stack.apis.inference.inference import (
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
@ -298,6 +304,22 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def unregister_model(self, model_id: str) -> None:
pass
async def health(self) -> HealthResponse:
"""
Performs a health check by verifying connectivity to the remote vLLM server.
This method is used by the Provider API to verify
that the service is running correctly.
Returns:
HealthResponse: A dictionary containing the health status.
"""
try:
client = self._create_client() if self.client is None else self.client
_ = [m async for m in client.models.list()] # Ensure the client is initialized
return HealthResponse(status=HealthStatus.OK)
except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
async def _get_model(self, model_id: str) -> Model:
if not self.model_store:
raise ValueError("Model store not set")
@ -516,7 +538,39 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
self._lazy_initialize_client()
assert self.client is not None
model_obj = await self._get_model(model)
assert model_obj.model_type == ModelType.embedding
# Convert input to list if it's a string
input_list = [input] if isinstance(input, str) else input
# Call vLLM embeddings endpoint with encoding_format
response = await self.client.embeddings.create(
model=model_obj.provider_resource_id,
input=input_list,
dimensions=dimensions,
encoding_format=encoding_format,
)
# Convert response to OpenAI format
data = [
OpenAIEmbeddingData(
embedding=embedding_data.embedding,
index=i,
)
for i, embedding_data in enumerate(response.data)
]
# Not returning actual token usage since vLLM doesn't provide it
usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
return OpenAIEmbeddingsResponse(
data=data,
model=model_obj.provider_resource_id,
usage=usage,
)
async def openai_completion(
self,
@ -539,6 +593,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
self._lazy_initialize_client()
model_obj = await self._get_model(model)

View file

@ -292,6 +292,7 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(

View file

@ -25,13 +25,16 @@ class NVIDIASafetyConfig(BaseModel):
guardrails_service_url: str = Field(
default_factory=lambda: os.getenv("GUARDRAILS_SERVICE_URL", "http://0.0.0.0:7331"),
description="The url for accessing the guardrails service",
description="The url for accessing the Guardrails service",
)
config_id: str | None = Field(
default_factory=lambda: os.getenv("NVIDIA_GUARDRAILS_CONFIG_ID", "self-check"),
description="Guardrails configuration ID to use from the Guardrails configuration store",
)
config_id: str | None = Field(default="self-check", description="Config ID to use from the config store")
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}",
"config_id": "self-check",
"config_id": "${env.NVIDIA_GUARDRAILS_CONFIG_ID:self-check}",
}

View file

@ -14,7 +14,22 @@ from numpy.typing import NDArray
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
SearchRankingOptions,
VectorIO,
VectorStoreDeleteResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.apis.vector_io.vector_io import (
VectorStoreChunkingStrategy,
VectorStoreFileContentsResponse,
VectorStoreFileObject,
VectorStoreListFilesResponse,
)
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
from llama_stack.providers.utils.memory.vector_store import (
@ -55,7 +70,7 @@ class ChromaIndex(EmbeddingIndex):
)
)
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
results = await maybe_await(
self.collection.query(
query_embeddings=[embedding.tolist()],
@ -76,8 +91,12 @@ class ChromaIndex(EmbeddingIndex):
log.exception(f"Failed to parse document: {doc}")
continue
score = 1.0 / float(dist) if dist != 0 else float("inf")
if score < score_threshold:
continue
chunks.append(chunk)
scores.append(1.0 / float(dist))
scores.append(score)
return QueryChunksResponse(chunks=chunks, scores=scores)
@ -92,6 +111,17 @@ class ChromaIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Chroma")
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in Chroma")
class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(
@ -174,3 +204,102 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
index = VectorDBWithIndex(vector_db, ChromaIndex(self.client, collection), self.inference_api)
self.cache[vector_db_id] = index
return index
async def openai_create_vector_store(
self,
name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_list_vector_stores(
self,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
) -> VectorStoreSearchResponsePage:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_attach_file_to_vector_store(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_list_files_in_vector_store(
self,
vector_store_id: str,
) -> VectorStoreListFilesResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_retrieve_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_retrieve_vector_store_file_contents(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_update_vector_store_file(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_delete_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")

View file

@ -16,7 +16,22 @@ from pymilvus import MilvusClient
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
SearchRankingOptions,
VectorIO,
VectorStoreDeleteResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.apis.vector_io.vector_io import (
VectorStoreChunkingStrategy,
VectorStoreFileContentsResponse,
VectorStoreFileObject,
VectorStoreListFilesResponse,
)
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
from llama_stack.providers.utils.memory.vector_store import (
@ -94,6 +109,17 @@ class MilvusIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Milvus")
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in Milvus")
class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(
@ -177,6 +203,105 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
return await index.query_chunks(query, params)
async def openai_create_vector_store(
self,
name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_list_vector_stores(
self,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
) -> VectorStoreSearchResponsePage:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_attach_file_to_vector_store(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Milvus")
async def openai_list_files_in_vector_store(
self,
vector_store_id: str,
) -> VectorStoreListFilesResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Milvus")
async def openai_retrieve_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Milvus")
async def openai_retrieve_vector_store_file_contents(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Milvus")
async def openai_update_vector_store_file(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Milvus")
async def openai_delete_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Milvus")
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
"""Generate a unique chunk ID using a hash of document ID and chunk text."""

View file

@ -116,7 +116,7 @@ class PGVectorIndex(EmbeddingIndex):
scores = []
for doc, dist in results:
chunks.append(Chunk(**doc))
scores.append(1.0 / float(dist))
scores.append(1.0 / float(dist) if dist != 0 else float("inf"))
return QueryChunksResponse(chunks=chunks, scores=scores)
@ -128,6 +128,17 @@ class PGVectorIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in PGVector")
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in PGVector")
async def delete(self):
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

View file

@ -14,7 +14,22 @@ from qdrant_client.models import PointStruct
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
SearchRankingOptions,
VectorIO,
VectorStoreDeleteResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.apis.vector_io.vector_io import (
VectorStoreChunkingStrategy,
VectorStoreFileContentsResponse,
VectorStoreFileObject,
VectorStoreListFilesResponse,
)
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
from llama_stack.providers.utils.memory.vector_store import (
@ -103,6 +118,17 @@ class QdrantIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Qdrant")
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in Qdrant")
async def delete(self):
await self.client.delete_collection(collection_name=self.collection_name)
@ -178,3 +204,102 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
raise ValueError(f"Vector DB {vector_db_id} not found")
return await index.query_chunks(query, params)
async def openai_create_vector_store(
self,
name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_list_vector_stores(
self,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
) -> VectorStoreSearchResponsePage:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_attach_file_to_vector_store(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_list_files_in_vector_store(
self,
vector_store_id: str,
) -> VectorStoreListFilesResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_retrieve_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_retrieve_vector_store_file_contents(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_update_vector_store_file(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_delete_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

View file

@ -76,7 +76,7 @@ class WeaviateIndex(EmbeddingIndex):
continue
chunks.append(chunk)
scores.append(1.0 / doc.metadata.distance)
scores.append(1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf"))
return QueryChunksResponse(chunks=chunks, scores=scores)
@ -92,6 +92,17 @@ class WeaviateIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Weaviate")
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in Weaviate")
class WeaviateVectorIOAdapter(
VectorIO,

View file

@ -87,9 +87,7 @@ class RefreshableBotoSession:
"access_key": session_credentials.access_key,
"secret_key": session_credentials.secret_key,
"token": session_credentials.token,
"expiry_time": datetime.datetime.fromtimestamp(
time() + self.session_ttl, datetime.timezone.utc
).isoformat(),
"expiry_time": datetime.datetime.fromtimestamp(time() + self.session_ttl, datetime.UTC).isoformat(),
}
return credentials

View file

@ -10,24 +10,27 @@ from llama_stack.apis.inference import (
OpenAIMessageParam,
Order,
)
from llama_stack.distribution.datatypes import AccessRule
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
class InferenceStore:
def __init__(self, sql_store_config: SqlStoreConfig):
def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]):
if not sql_store_config:
sql_store_config = SqliteSqlStoreConfig(
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
)
self.sql_store_config = sql_store_config
self.sql_store = None
self.policy = policy
async def initialize(self):
"""Create the necessary tables if they don't exist."""
self.sql_store = sqlstore_impl(self.sql_store_config)
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
await self.sql_store.create_table(
"chat_completions",
{
@ -48,8 +51,8 @@ class InferenceStore:
data = chat_completion.model_dump()
await self.sql_store.insert(
"chat_completions",
{
table="chat_completions",
data={
"id": data["id"],
"created": data["created"],
"model": data["model"],
@ -76,17 +79,20 @@ class InferenceStore:
if not self.sql_store:
raise ValueError("Inference store is not initialized")
# TODO: support after
if after:
raise NotImplementedError("After is not supported for SQLite")
if not order:
order = Order.desc
rows = await self.sql_store.fetch_all(
"chat_completions",
where={"model": model} if model else None,
where_conditions = {}
if model:
where_conditions["model"] = model
paginated_result = await self.sql_store.fetch_all(
table="chat_completions",
where=where_conditions if where_conditions else None,
order_by=[("created", order.value)],
cursor=("id", after) if after else None,
limit=limit,
policy=self.policy,
)
data = [
@ -97,12 +103,11 @@ class InferenceStore:
choices=row["choices"],
input_messages=row["input_messages"],
)
for row in rows
for row in paginated_result.data
]
return ListOpenAIChatCompletionResponse(
data=data,
# TODO: implement has_more
has_more=False,
has_more=paginated_result.has_more,
first_id=data[0].id if data else "",
last_id=data[-1].id if data else "",
)
@ -111,9 +116,17 @@ class InferenceStore:
if not self.sql_store:
raise ValueError("Inference store is not initialized")
row = await self.sql_store.fetch_one("chat_completions", where={"id": completion_id})
row = await self.sql_store.fetch_one(
table="chat_completions",
where={"id": completion_id},
policy=self.policy,
)
if not row:
# SecureSqlStore will return None if record doesn't exist OR access is denied
# This provides security by not revealing whether the record exists
raise ValueError(f"Chat completion with id {completion_id} not found") from None
return OpenAICompletionWithInputMessages(
id=row["id"],
created=row["created"],

View file

@ -4,8 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import struct
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
@ -37,7 +35,6 @@ from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
@ -48,6 +45,7 @@ from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
b64_encode_openai_embeddings_response,
convert_message_to_openai_dict_new,
convert_openai_chat_completion_choice,
convert_openai_chat_completion_stream,
@ -293,16 +291,7 @@ class LiteLLMOpenAIMixin(
)
# Convert response to OpenAI format
data = []
for i, embedding_data in enumerate(response["data"]):
# we encode to base64 if the encoding format is base64 in the request
if encoding_format == "base64":
byte_data = b"".join(struct.pack("f", f) for f in embedding_data["embedding"])
embedding = base64.b64encode(byte_data).decode("utf-8")
else:
embedding = embedding_data["embedding"]
data.append(OpenAIEmbeddingData(embedding=embedding, index=i))
data = b64_encode_openai_embeddings_response(response.data, encoding_format)
usage = OpenAIEmbeddingUsage(
prompt_tokens=response["usage"]["prompt_tokens"],
@ -336,6 +325,7 @@ class LiteLLMOpenAIMixin(
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(

View file

@ -3,8 +3,10 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import json
import logging
import struct
import time
import uuid
import warnings
@ -108,6 +110,7 @@ from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAICompletion,
OpenAICompletionChoice,
OpenAIEmbeddingData,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ToolConfig,
@ -1287,6 +1290,7 @@ class OpenAICompletionToLlamaStackMixin:
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
if stream:
raise ValueError(f"{self.__class__.__name__} doesn't support streaming openai completions")
@ -1483,3 +1487,55 @@ class OpenAIChatCompletionToLlamaStackMixin:
model=model,
object="chat.completion",
)
def prepare_openai_embeddings_params(
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
):
if model is None:
raise ValueError("Model must be provided for embeddings")
input_list = [input] if isinstance(input, str) else input
params: dict[str, Any] = {
"model": model,
"input": input_list,
}
if encoding_format is not None:
params["encoding_format"] = encoding_format
if dimensions is not None:
params["dimensions"] = dimensions
if user is not None:
params["user"] = user
return params
def b64_encode_openai_embeddings_response(
response_data: dict, encoding_format: str | None = "float"
) -> list[OpenAIEmbeddingData]:
"""
Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
"""
data = []
for i, embedding_data in enumerate(response_data):
if encoding_format == "base64":
byte_array = bytearray()
for embedding_value in embedding_data.embedding:
byte_array.extend(struct.pack("f", float(embedding_value)))
response_embedding = base64.b64encode(byte_array).decode("utf-8")
else:
response_embedding = embedding_data.embedding
data.append(
OpenAIEmbeddingData(
embedding=response_embedding,
index=i,
)
)
return data

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from collections.abc import AsyncIterator
from datetime import datetime, timezone
from datetime import UTC, datetime
from typing import Any
from llama_stack.apis.inference import (
@ -122,7 +122,7 @@ async def stream_and_store_openai_completion(
final_response = OpenAIChatCompletion(
id=id,
choices=assembled_choices,
created=created or int(datetime.now(timezone.utc).timestamp()),
created=created or int(datetime.now(UTC).timestamp()),
model=model,
object="chat.completion",
)

View file

@ -36,6 +36,10 @@ class RedisKVStoreConfig(CommonConfig):
def url(self) -> str:
return f"redis://{self.host}:{self.port}"
@property
def pip_packages(self) -> list[str]:
return ["redis"]
@classmethod
def sample_run_config(cls):
return {
@ -53,6 +57,10 @@ class SqliteKVStoreConfig(CommonConfig):
description="File path for the sqlite database",
)
@property
def pip_packages(self) -> list[str]:
return ["aiosqlite"]
@classmethod
def sample_run_config(cls, __distro_dir__: str, db_name: str = "kvstore.db"):
return {
@ -100,6 +108,10 @@ class PostgresKVStoreConfig(CommonConfig):
raise ValueError("Table name must be less than 63 characters")
return v
@property
def pip_packages(self) -> list[str]:
return ["psycopg2-binary"]
class MongoDBKVStoreConfig(CommonConfig):
type: Literal[KVStoreType.mongodb.value] = KVStoreType.mongodb.value
@ -110,6 +122,10 @@ class MongoDBKVStoreConfig(CommonConfig):
password: str | None = None
collection_name: str = "llamastack_kvstore"
@property
def pip_packages(self) -> list[str]:
return ["pymongo"]
@classmethod
def sample_run_config(cls, collection_name: str = "llamastack_kvstore"):
return {

View file

@ -10,6 +10,13 @@ from .config import KVStoreConfig, KVStoreType
def kvstore_dependencies():
"""
Returns all possible kvstore dependencies for registry/provider specifications.
NOTE: For specific kvstore implementations, use config.pip_packages instead.
This function returns the union of all dependencies for cases where the specific
kvstore type is not known at declaration time (e.g., provider registries).
"""
return ["aiosqlite", "psycopg2-binary", "redis", "pymongo"]

View file

@ -0,0 +1,737 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import logging
import mimetypes
import time
import uuid
from abc import ABC, abstractmethod
from typing import Any
from llama_stack.apis.files import Files
from llama_stack.apis.files.files import OpenAIFileObject
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
QueryChunksResponse,
SearchRankingOptions,
VectorStoreContent,
VectorStoreDeleteResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponse,
VectorStoreSearchResponsePage,
)
from llama_stack.apis.vector_io.vector_io import (
Chunk,
VectorStoreChunkingStrategy,
VectorStoreChunkingStrategyAuto,
VectorStoreChunkingStrategyStatic,
VectorStoreFileContentsResponse,
VectorStoreFileCounts,
VectorStoreFileDeleteResponse,
VectorStoreFileLastError,
VectorStoreFileObject,
VectorStoreFileStatus,
VectorStoreListFilesResponse,
)
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks
logger = logging.getLogger(__name__)
# Constants for OpenAI vector stores
CHUNK_MULTIPLIER = 5
class OpenAIVectorStoreMixin(ABC):
"""
Mixin class that provides common OpenAI Vector Store API implementation.
Providers need to implement the abstract storage methods and maintain
an openai_vector_stores in-memory cache.
"""
# These should be provided by the implementing class
openai_vector_stores: dict[str, dict[str, Any]]
files_api: Files | None
@abstractmethod
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to persistent storage."""
pass
@abstractmethod
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
"""Load all vector store metadata from persistent storage."""
pass
@abstractmethod
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Update vector store metadata in persistent storage."""
pass
@abstractmethod
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
"""Delete vector store metadata from persistent storage."""
pass
@abstractmethod
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
"""Save vector store file metadata to persistent storage."""
pass
@abstractmethod
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
"""Load vector store file metadata from persistent storage."""
pass
@abstractmethod
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
"""Load vector store file contents from persistent storage."""
pass
@abstractmethod
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
"""Update vector store file metadata in persistent storage."""
pass
@abstractmethod
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
"""Delete vector store file metadata from persistent storage."""
pass
@abstractmethod
async def register_vector_db(self, vector_db: VectorDB) -> None:
"""Register a vector database (provider-specific implementation)."""
pass
@abstractmethod
async def unregister_vector_db(self, vector_db_id: str) -> None:
"""Unregister a vector database (provider-specific implementation)."""
pass
@abstractmethod
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
"""Insert chunks into a vector database (provider-specific implementation)."""
pass
@abstractmethod
async def query_chunks(
self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
"""Query chunks from a vector database (provider-specific implementation)."""
pass
async def openai_create_vector_store(
self,
name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
"""Creates a vector store."""
# store and vector_db have the same id
store_id = name or str(uuid.uuid4())
created_at = int(time.time())
if provider_id is None:
raise ValueError("Provider ID is required")
if embedding_model is None:
raise ValueError("Embedding model is required")
# Use provided embedding dimension or default to 384
if embedding_dimension is None:
raise ValueError("Embedding dimension is required")
provider_vector_db_id = provider_vector_db_id or store_id
vector_db = VectorDB(
identifier=store_id,
embedding_dimension=embedding_dimension,
embedding_model=embedding_model,
provider_id=provider_id,
provider_resource_id=provider_vector_db_id,
)
# Register the vector DB
await self.register_vector_db(vector_db)
# Create OpenAI vector store metadata
status = "completed"
# Start with no files attached and update later
file_counts = VectorStoreFileCounts(
cancelled=0,
completed=0,
failed=0,
in_progress=0,
total=0,
)
store_info = {
"id": store_id,
"object": "vector_store",
"created_at": created_at,
"name": store_id,
"usage_bytes": 0,
"file_counts": file_counts.model_dump(),
"status": status,
"expires_after": expires_after,
"expires_at": None,
"last_active_at": created_at,
"file_ids": [],
"chunking_strategy": chunking_strategy,
}
# Add provider information to metadata if provided
metadata = metadata or {}
if provider_id:
metadata["provider_id"] = provider_id
if provider_vector_db_id:
metadata["provider_vector_db_id"] = provider_vector_db_id
store_info["metadata"] = metadata
# Save to persistent storage (provider-specific)
await self._save_openai_vector_store(store_id, store_info)
# Store in memory cache
self.openai_vector_stores[store_id] = store_info
# Now that our vector store is created, attach any files that were provided
file_ids = file_ids or []
tasks = [self.openai_attach_file_to_vector_store(store_id, file_id) for file_id in file_ids]
await asyncio.gather(*tasks)
# Get the updated store info and return it
store_info = self.openai_vector_stores[store_id]
return VectorStoreObject.model_validate(store_info)
async def openai_list_vector_stores(
self,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
"""Returns a list of vector stores."""
limit = limit or 20
order = order or "desc"
# Get all vector stores
all_stores = list(self.openai_vector_stores.values())
# Sort by created_at
reverse_order = order == "desc"
all_stores.sort(key=lambda x: x["created_at"], reverse=reverse_order)
# Apply cursor-based pagination
if after:
after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1)
if after_index >= 0:
all_stores = all_stores[after_index + 1 :]
if before:
before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores))
all_stores = all_stores[:before_index]
# Apply limit
limited_stores = all_stores[:limit]
# Convert to VectorStoreObject instances
data = [VectorStoreObject(**store) for store in limited_stores]
# Determine pagination info
has_more = len(all_stores) > limit
first_id = data[0].id if data else None
last_id = data[-1].id if data else None
return VectorStoreListResponse(
data=data,
has_more=has_more,
first_id=first_id,
last_id=last_id,
)
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
"""Retrieves a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
store_info = self.openai_vector_stores[vector_store_id]
return VectorStoreObject(**store_info)
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
"""Modifies a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
store_info = self.openai_vector_stores[vector_store_id].copy()
# Update fields if provided
if name is not None:
store_info["name"] = name
if expires_after is not None:
store_info["expires_after"] = expires_after
if metadata is not None:
store_info["metadata"] = metadata
# Update last_active_at
store_info["last_active_at"] = int(time.time())
# Save to persistent storage (provider-specific)
await self._update_openai_vector_store(vector_store_id, store_info)
# Update in-memory cache
self.openai_vector_stores[vector_store_id] = store_info
return VectorStoreObject(**store_info)
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
"""Delete a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
# Delete from persistent storage (provider-specific)
await self._delete_openai_vector_store_from_storage(vector_store_id)
# Delete from in-memory cache
del self.openai_vector_stores[vector_store_id]
# Also delete the underlying vector DB
try:
await self.unregister_vector_db(vector_store_id)
except Exception as e:
logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
return VectorStoreDeleteResponse(
id=vector_store_id,
deleted=True,
)
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
# search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
) -> VectorStoreSearchResponsePage:
"""Search for chunks in a vector store."""
# TODO: Add support in the API for this
search_mode = "vector"
max_num_results = max_num_results or 10
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
if isinstance(query, list):
search_query = " ".join(query)
else:
search_query = query
try:
score_threshold = (
ranking_options.score_threshold
if ranking_options and ranking_options.score_threshold is not None
else 0.0
)
params = {
"max_chunks": max_num_results * CHUNK_MULTIPLIER,
"score_threshold": score_threshold,
"mode": search_mode,
}
# TODO: Add support for ranking_options.ranker
response = await self.query_chunks(
vector_db_id=vector_store_id,
query=search_query,
params=params,
)
# Convert response to OpenAI format
data = []
for chunk, score in zip(response.chunks, response.scores, strict=False):
# Apply score based filtering
if score < score_threshold:
continue
# Apply filters if provided
if filters:
# Simple metadata filtering
if not self._matches_filters(chunk.metadata, filters):
continue
content = self._chunk_to_vector_store_content(chunk)
response_data_item = VectorStoreSearchResponse(
file_id=chunk.metadata.get("file_id", ""),
filename=chunk.metadata.get("filename", ""),
score=score,
attributes=chunk.metadata,
content=content,
)
data.append(response_data_item)
if len(data) >= max_num_results:
break
return VectorStoreSearchResponsePage(
search_query=search_query,
data=data,
has_more=False, # For simplicity, we don't implement pagination here
next_page=None,
)
except Exception as e:
logger.error(f"Error searching vector store {vector_store_id}: {e}")
# Return empty results on error
return VectorStoreSearchResponsePage(
search_query=search_query,
data=[],
has_more=False,
next_page=None,
)
def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
"""Check if metadata matches the provided filters."""
if not filters:
return True
filter_type = filters.get("type")
if filter_type in ["eq", "ne", "gt", "gte", "lt", "lte"]:
# Comparison filter
key = filters.get("key")
value = filters.get("value")
if key not in metadata:
return False
metadata_value = metadata[key]
if filter_type == "eq":
return bool(metadata_value == value)
elif filter_type == "ne":
return bool(metadata_value != value)
elif filter_type == "gt":
return bool(metadata_value > value)
elif filter_type == "gte":
return bool(metadata_value >= value)
elif filter_type == "lt":
return bool(metadata_value < value)
elif filter_type == "lte":
return bool(metadata_value <= value)
else:
raise ValueError(f"Unsupported filter type: {filter_type}")
elif filter_type == "and":
# All filters must match
sub_filters = filters.get("filters", [])
return all(self._matches_filters(metadata, f) for f in sub_filters)
elif filter_type == "or":
# At least one filter must match
sub_filters = filters.get("filters", [])
return any(self._matches_filters(metadata, f) for f in sub_filters)
else:
# Unknown filter type, default to no match
raise ValueError(f"Unsupported filter type: {filter_type}")
def _chunk_to_vector_store_content(self, chunk: Chunk) -> list[VectorStoreContent]:
# content is InterleavedContent
if isinstance(chunk.content, str):
content = [
VectorStoreContent(
type="text",
text=chunk.content,
)
]
elif isinstance(chunk.content, list):
# TODO: Add support for other types of content
content = [
VectorStoreContent(
type="text",
text=item.text,
)
for item in chunk.content
if item.type == "text"
]
else:
if chunk.content.type != "text":
raise ValueError(f"Unsupported content type: {chunk.content.type}")
content = [
VectorStoreContent(
type="text",
text=chunk.content.text,
)
]
return content
async def openai_attach_file_to_vector_store(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
attributes = attributes or {}
chunking_strategy = chunking_strategy or VectorStoreChunkingStrategyAuto()
created_at = int(time.time())
chunks: list[Chunk] = []
file_response: OpenAIFileObject | None = None
vector_store_file_object = VectorStoreFileObject(
id=file_id,
attributes=attributes,
chunking_strategy=chunking_strategy,
created_at=created_at,
status="in_progress",
vector_store_id=vector_store_id,
)
if not hasattr(self, "files_api") or not self.files_api:
vector_store_file_object.status = "failed"
vector_store_file_object.last_error = VectorStoreFileLastError(
code="server_error",
message="Files API is not available",
)
return vector_store_file_object
if isinstance(chunking_strategy, VectorStoreChunkingStrategyStatic):
max_chunk_size_tokens = chunking_strategy.static.max_chunk_size_tokens
chunk_overlap_tokens = chunking_strategy.static.chunk_overlap_tokens
else:
# Default values from OpenAI API spec
max_chunk_size_tokens = 800
chunk_overlap_tokens = 400
try:
file_response = await self.files_api.openai_retrieve_file(file_id)
mime_type, _ = mimetypes.guess_type(file_response.filename)
content_response = await self.files_api.openai_retrieve_file_content(file_id)
content = content_from_data_and_mime_type(content_response.body, mime_type)
chunks = make_overlapped_chunks(
file_id,
content,
max_chunk_size_tokens,
chunk_overlap_tokens,
attributes,
)
if not chunks:
vector_store_file_object.status = "failed"
vector_store_file_object.last_error = VectorStoreFileLastError(
code="server_error",
message="No chunks were generated from the file",
)
else:
await self.insert_chunks(
vector_db_id=vector_store_id,
chunks=chunks,
)
vector_store_file_object.status = "completed"
except Exception as e:
logger.error(f"Error attaching file to vector store: {e}")
vector_store_file_object.status = "failed"
vector_store_file_object.last_error = VectorStoreFileLastError(
code="server_error",
message=str(e),
)
# Create OpenAI vector store file metadata
file_info = vector_store_file_object.model_dump(exclude={"last_error"})
file_info["filename"] = file_response.filename if file_response else ""
# Save vector store file to persistent storage (provider-specific)
dict_chunks = [c.model_dump() for c in chunks]
await self._save_openai_vector_store_file(vector_store_id, file_id, file_info, dict_chunks)
# Update file_ids and file_counts in vector store metadata
store_info = self.openai_vector_stores[vector_store_id].copy()
store_info["file_ids"].append(file_id)
store_info["file_counts"]["total"] += 1
store_info["file_counts"][vector_store_file_object.status] += 1
# Save updated vector store to persistent storage
await self._save_openai_vector_store(vector_store_id, store_info)
# Update vector store in-memory cache
self.openai_vector_stores[vector_store_id] = store_info
return vector_store_file_object
async def openai_list_files_in_vector_store(
self,
vector_store_id: str,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
filter: VectorStoreFileStatus | None = None,
) -> VectorStoreListFilesResponse:
"""List files in a vector store."""
limit = limit or 20
order = order or "desc"
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
store_info = self.openai_vector_stores[vector_store_id]
file_objects: list[VectorStoreFileObject] = []
for file_id in store_info["file_ids"]:
file_info = await self._load_openai_vector_store_file(vector_store_id, file_id)
file_object = VectorStoreFileObject(**file_info)
if filter and file_object.status != filter:
continue
file_objects.append(file_object)
# Sort by created_at
reverse_order = order == "desc"
file_objects.sort(key=lambda x: x.created_at, reverse=reverse_order)
# Apply cursor-based pagination
if after:
after_index = next((i for i, file in enumerate(file_objects) if file.id == after), -1)
if after_index >= 0:
file_objects = file_objects[after_index + 1 :]
if before:
before_index = next((i for i, file in enumerate(file_objects) if file.id == before), len(file_objects))
file_objects = file_objects[:before_index]
# Apply limit
limited_files = file_objects[:limit]
# Determine pagination info
has_more = len(file_objects) > limit
first_id = file_objects[0].id if file_objects else None
last_id = file_objects[-1].id if file_objects else None
return VectorStoreListFilesResponse(
data=limited_files,
has_more=has_more,
first_id=first_id,
last_id=last_id,
)
async def openai_retrieve_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
"""Retrieves a vector store file."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
store_info = self.openai_vector_stores[vector_store_id]
if file_id not in store_info["file_ids"]:
raise ValueError(f"File {file_id} not found in vector store {vector_store_id}")
file_info = await self._load_openai_vector_store_file(vector_store_id, file_id)
return VectorStoreFileObject(**file_info)
async def openai_retrieve_vector_store_file_contents(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
"""Retrieves the contents of a vector store file."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
file_info = await self._load_openai_vector_store_file(vector_store_id, file_id)
dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id)
chunks = [Chunk.model_validate(c) for c in dict_chunks]
content = []
for chunk in chunks:
content.extend(self._chunk_to_vector_store_content(chunk))
return VectorStoreFileContentsResponse(
file_id=file_id,
filename=file_info.get("filename", ""),
attributes=file_info.get("attributes", {}),
content=content,
)
async def openai_update_vector_store_file(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any],
) -> VectorStoreFileObject:
"""Updates a vector store file."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
store_info = self.openai_vector_stores[vector_store_id]
if file_id not in store_info["file_ids"]:
raise ValueError(f"File {file_id} not found in vector store {vector_store_id}")
file_info = await self._load_openai_vector_store_file(vector_store_id, file_id)
file_info["attributes"] = attributes
await self._update_openai_vector_store_file(vector_store_id, file_id, file_info)
return VectorStoreFileObject(**file_info)
async def openai_delete_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileDeleteResponse:
"""Deletes a vector store file."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
store_info = self.openai_vector_stores[vector_store_id].copy()
file = await self.openai_retrieve_vector_store_file(vector_store_id, file_id)
await self._delete_openai_vector_store_file_from_storage(vector_store_id, file_id)
# TODO: We need to actually delete the embeddings from the underlying vector store...
# Also uncomment the corresponding integration test marked as xfail
#
# test_openai_vector_store_delete_file_removes_from_vector_store in
# tests/integration/vector_io/test_openai_vector_stores.py
# Update in-memory cache
store_info["file_ids"].remove(file_id)
store_info["file_counts"][file.status] -= 1
store_info["file_counts"]["total"] -= 1
self.openai_vector_stores[vector_store_id] = store_info
# Save updated vector store to persistent storage
await self._save_openai_vector_store(vector_store_id, store_info)
return VectorStoreFileDeleteResponse(
id=file_id,
deleted=True,
)

View file

@ -32,6 +32,10 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
log = logging.getLogger(__name__)
# Constants for reranker types
RERANKER_TYPE_RRF = "rrf"
RERANKER_TYPE_WEIGHTED = "weighted"
def parse_pdf(data: bytes) -> str:
# For PDF and DOC/DOCX files, we can't reliably convert to string
@ -72,16 +76,18 @@ def content_from_data(data_url: str) -> str:
data = unquote(data)
encoding = parts["encoding"] or "utf-8"
data = data.encode(encoding)
return content_from_data_and_mime_type(data, parts["mimetype"], parts.get("encoding", None))
encoding = parts["encoding"]
if not encoding:
import chardet
detected = chardet.detect(data)
encoding = detected["encoding"]
def content_from_data_and_mime_type(data: bytes | str, mime_type: str | None, encoding: str | None = None) -> str:
if isinstance(data, bytes):
if not encoding:
import chardet
mime_type = parts["mimetype"]
mime_category = mime_type.split("/")[0]
detected = chardet.detect(data)
encoding = detected["encoding"]
mime_category = mime_type.split("/")[0] if mime_type else None
if mime_category == "text":
# For text-based files (including CSV, MD)
return data.decode(encoding)
@ -200,6 +206,18 @@ class EmbeddingIndex(ABC):
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
raise NotImplementedError()
@abstractmethod
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
raise NotImplementedError()
@abstractmethod
async def delete(self):
raise NotImplementedError()
@ -243,10 +261,29 @@ class VectorDBWithIndex:
k = params.get("max_chunks", 3)
mode = params.get("mode")
score_threshold = params.get("score_threshold", 0.0)
# Get ranker configuration
ranker = params.get("ranker")
if ranker is None:
# Default to RRF with impact_factor=60.0
reranker_type = RERANKER_TYPE_RRF
reranker_params = {"impact_factor": 60.0}
else:
reranker_type = ranker.type
reranker_params = (
{"impact_factor": ranker.impact_factor} if ranker.type == RERANKER_TYPE_RRF else {"alpha": ranker.alpha}
)
query_string = interleaved_content_as_str(query)
if mode == "keyword":
return await self.index.query_keyword(query_string, k, score_threshold)
# Calculate embeddings for both vector and hybrid modes
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
if mode == "hybrid":
return await self.index.query_hybrid(
query_vector, query_string, k, score_threshold, reranker_type, reranker_params
)
else:
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
return await self.index.query_vector(query_vector, k, score_threshold)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -13,19 +13,22 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObject,
OpenAIResponseObjectWithInput,
)
from llama_stack.distribution.datatypes import AccessRule
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
class ResponsesStore:
def __init__(self, sql_store_config: SqlStoreConfig):
def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]):
if not sql_store_config:
sql_store_config = SqliteSqlStoreConfig(
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
)
self.sql_store = sqlstore_impl(sql_store_config)
self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config))
self.policy = policy
async def initialize(self):
"""Create the necessary tables if they don't exist."""
@ -70,32 +73,45 @@ class ResponsesStore:
:param model: The model to filter by.
:param order: The order to sort the responses by.
"""
# TODO: support after
if after:
raise NotImplementedError("After is not supported for SQLite")
if not order:
order = Order.desc
rows = await self.sql_store.fetch_all(
"openai_responses",
where={"model": model} if model else None,
where_conditions = {}
if model:
where_conditions["model"] = model
paginated_result = await self.sql_store.fetch_all(
table="openai_responses",
where=where_conditions if where_conditions else None,
order_by=[("created_at", order.value)],
cursor=("id", after) if after else None,
limit=limit,
policy=self.policy,
)
data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in rows]
data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
return ListOpenAIResponseObject(
data=data,
# TODO: implement has_more
has_more=False,
has_more=paginated_result.has_more,
first_id=data[0].id if data else "",
last_id=data[-1].id if data else "",
)
async def get_response_object(self, response_id: str) -> OpenAIResponseObjectWithInput:
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id})
"""
Get a response object with automatic access control checking.
"""
row = await self.sql_store.fetch_one(
"openai_responses",
where={"id": response_id},
policy=self.policy,
)
if not row:
# SecureSqlStore will return None if record doesn't exist OR access is denied
# This provides security by not revealing whether the record exists
raise ValueError(f"Response with id {response_id} not found") from None
return OpenAIResponseObjectWithInput(**row["response_object"])
async def list_response_input_items(
@ -117,19 +133,38 @@ class ResponsesStore:
:param limit: A limit on the number of objects to be returned.
:param order: The order to return the input items in.
"""
# TODO: support after/before pagination
if after or before:
raise NotImplementedError("After/before pagination is not supported yet")
if include:
raise NotImplementedError("Include is not supported yet")
if before and after:
raise ValueError("Cannot specify both 'before' and 'after' parameters")
response_with_input = await self.get_response_object(response_id)
input_items = response_with_input.input
items = response_with_input.input
if order == Order.desc:
input_items = list(reversed(input_items))
items = list(reversed(items))
if limit is not None and len(input_items) > limit:
input_items = input_items[:limit]
start_index = 0
end_index = len(items)
return ListOpenAIResponseInputItem(data=input_items)
if after or before:
for i, item in enumerate(items):
item_id = getattr(item, "id", None)
if after and item_id == after:
start_index = i + 1
if before and item_id == before:
end_index = i
break
if after and start_index == 0:
raise ValueError(f"Input item with id '{after}' not found for response '{response_id}'")
if before and end_index == len(items):
raise ValueError(f"Input item with id '{before}' not found for response '{response_id}'")
items = items[start_index:end_index]
# Apply limit
if limit is not None:
items = items[:limit]
return ListOpenAIResponseInputItem(data=items)

View file

@ -9,7 +9,7 @@ import asyncio
import functools
import threading
from collections.abc import Callable, Coroutine, Iterable
from datetime import datetime, timezone
from datetime import UTC, datetime
from enum import Enum
from typing import Any, TypeAlias
@ -61,7 +61,7 @@ class Job:
self._handler = handler
self._artifacts: list[JobArtifact] = []
self._logs: list[LogMessage] = []
self._state_transitions: list[tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)]
self._state_transitions: list[tuple[datetime, JobStatus]] = [(datetime.now(UTC), JobStatus.new)]
@property
def handler(self) -> JobHandler:
@ -77,7 +77,7 @@ class Job:
raise ValueError(f"Job is already in a completed state ({self.status})")
if self.status == status:
return
self._state_transitions.append((datetime.now(timezone.utc), status))
self._state_transitions.append((datetime.now(UTC), status))
@property
def artifacts(self) -> list[JobArtifact]:
@ -157,10 +157,14 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
asyncio.set_event_loop(self._loop)
self._loop.run_forever()
# When stopping the loop, give tasks a chance to finish
# TODO: When stopping the loop, give tasks a chance to finish
# TODO: should we explicitly inform jobs of pending stoppage?
# cancel all tasks
for task in asyncio.all_tasks(self._loop):
self._loop.run_until_complete(task)
if not task.done():
task.cancel()
self._loop.close()
async def shutdown(self) -> None:
@ -215,7 +219,7 @@ class Scheduler:
self._backend = _get_backend_impl(backend)
def _on_log_message_cb(self, job: Job, message: str) -> None:
msg = (datetime.now(timezone.utc), message)
msg = (datetime.now(UTC), message)
# At least for the time being, until there's a better way to expose
# logs to users, log messages on console
logger.info(f"Job {job.id}: {message}")

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -10,6 +10,8 @@ from typing import Any, Literal, Protocol
from pydantic import BaseModel
from llama_stack.apis.common.responses import PaginatedResponse
class ColumnType(Enum):
INTEGER = "INTEGER"
@ -49,11 +51,25 @@ class SqlStore(Protocol):
self,
table: str,
where: Mapping[str, Any] | None = None,
where_sql: str | None = None,
limit: int | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
) -> list[dict[str, Any]]:
cursor: tuple[str, str] | None = None,
) -> PaginatedResponse:
"""
Fetch all rows from a table.
Fetch all rows from a table with optional cursor-based pagination.
:param table: The table name
:param where: Simple key-value WHERE conditions
:param where_sql: Raw SQL WHERE clause for complex queries
:param limit: Maximum number of records to return
:param order_by: List of (column, order) tuples for sorting
:param cursor: Tuple of (key_column, cursor_id) for pagination (None for first page)
Requires order_by with exactly one column when used
:return: PaginatedResult with data and has_more flag
Note: Cursor pagination only supports single-column ordering for simplicity.
Multi-column ordering is allowed without cursor but will raise an error with cursor.
"""
pass
@ -61,6 +77,7 @@ class SqlStore(Protocol):
self,
table: str,
where: Mapping[str, Any] | None = None,
where_sql: str | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
) -> dict[str, Any] | None:
"""
@ -88,3 +105,24 @@ class SqlStore(Protocol):
Delete a row from a table.
"""
pass
async def add_column_if_not_exists(
self,
table: str,
column_name: str,
column_type: ColumnType,
nullable: bool = True,
) -> None:
"""
Add a column to an existing table if the column doesn't already exist.
This is useful for table migrations when adding new functionality.
If the table doesn't exist, this method should do nothing.
If the column already exists, this method should do nothing.
:param table: Table name
:param column_name: Name of the column to add
:param column_type: Type of the column to add
:param nullable: Whether the column should be nullable (default: True)
"""
pass

View file

@ -0,0 +1,222 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Mapping
from typing import Any, Literal
from llama_stack.distribution.access_control.access_control import default_policy, is_action_allowed
from llama_stack.distribution.access_control.conditions import ProtectedResource
from llama_stack.distribution.access_control.datatypes import AccessRule, Action, Scope
from llama_stack.distribution.datatypes import User
from llama_stack.distribution.request_headers import get_authenticated_user
from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
logger = get_logger(name=__name__, category="authorized_sqlstore")
# Hardcoded copy of the default policy that our SQL filtering implements
# WARNING: If default_policy() changes, this constant must be updated accordingly
# or SQL filtering will fall back to conservative mode (safe but less performant)
#
# This policy represents: "Permit all actions when user is in owners list for ALL attribute categories"
# The corresponding SQL logic is implemented in _build_default_policy_where_clause():
# - Public records (no access_attributes) are always accessible
# - Records with access_attributes require user to match ALL categories that exist in the resource
# - Missing categories in the resource are treated as "no restriction" (allow)
# - Within each category, user needs ANY matching value (OR logic)
# - Between categories, user needs ALL categories to match (AND logic)
SQL_OPTIMIZED_POLICY = [
AccessRule(
permit=Scope(actions=list(Action)),
when=["user in owners roles", "user in owners teams", "user in owners projects", "user in owners namespaces"],
),
]
class SqlRecord(ProtectedResource):
"""Simple ProtectedResource implementation for SQL records."""
def __init__(self, record_id: str, table_name: str, access_attributes: dict[str, list[str]] | None = None):
self.type = f"sql_record::{table_name}"
self.identifier = record_id
if access_attributes:
self.owner = User(
principal="system",
attributes=access_attributes,
)
else:
self.owner = User(
principal="system_public",
attributes=None,
)
class AuthorizedSqlStore:
"""
Authorization layer for SqlStore that provides access control functionality.
This class composes a base SqlStore and adds authorization methods that handle
access control policies, user attribute capture, and SQL filtering optimization.
"""
def __init__(self, sql_store: SqlStore):
"""
Initialize the authorization layer.
:param sql_store: Base SqlStore implementation to wrap
"""
self.sql_store = sql_store
self._validate_sql_optimized_policy()
def _validate_sql_optimized_policy(self) -> None:
"""Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().
This ensures that if default_policy() changes, we detect the mismatch and
can update our SQL filtering logic accordingly.
"""
actual_default = default_policy()
if SQL_OPTIMIZED_POLICY != actual_default:
logger.warning(
f"SQL_OPTIMIZED_POLICY does not match default_policy(). "
f"SQL filtering will use conservative mode. "
f"Expected: {SQL_OPTIMIZED_POLICY}, Got: {actual_default}",
)
async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
"""Create a table with built-in access control support."""
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
enhanced_schema = dict(schema)
if "access_attributes" not in enhanced_schema:
enhanced_schema["access_attributes"] = ColumnType.JSON
await self.sql_store.create_table(table, enhanced_schema)
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
"""Insert a row with automatic access control attribute capture."""
enhanced_data = dict(data)
current_user = get_authenticated_user()
if current_user and current_user.attributes:
enhanced_data["access_attributes"] = current_user.attributes
else:
enhanced_data["access_attributes"] = None
await self.sql_store.insert(table, enhanced_data)
async def fetch_all(
self,
table: str,
policy: list[AccessRule],
where: Mapping[str, Any] | None = None,
limit: int | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
cursor: tuple[str, str] | None = None,
) -> PaginatedResponse:
"""Fetch all rows with automatic access control filtering."""
access_where = self._build_access_control_where_clause(policy)
rows = await self.sql_store.fetch_all(
table=table,
where=where,
where_sql=access_where,
limit=limit,
order_by=order_by,
cursor=cursor,
)
current_user = get_authenticated_user()
filtered_rows = []
for row in rows.data:
stored_access_attrs = row.get("access_attributes")
record_id = row.get("id", "unknown")
sql_record = SqlRecord(str(record_id), table, stored_access_attrs)
if is_action_allowed(policy, Action.READ, sql_record, current_user):
filtered_rows.append(row)
return PaginatedResponse(
data=filtered_rows,
has_more=rows.has_more,
)
async def fetch_one(
self,
table: str,
policy: list[AccessRule],
where: Mapping[str, Any] | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
) -> dict[str, Any] | None:
"""Fetch one row with automatic access control checking."""
results = await self.fetch_all(
table=table,
policy=policy,
where=where,
limit=1,
order_by=order_by,
)
return results.data[0] if results.data else None
def _build_access_control_where_clause(self, policy: list[AccessRule]) -> str:
"""Build SQL WHERE clause for access control filtering.
Only applies SQL filtering for the default policy to ensure correctness.
For custom policies, uses conservative filtering to avoid blocking legitimate access.
"""
if not policy or policy == SQL_OPTIMIZED_POLICY:
return self._build_default_policy_where_clause()
else:
return self._build_conservative_where_clause()
def _build_default_policy_where_clause(self) -> str:
"""Build SQL WHERE clause for the default policy.
Default policy: permit all actions when user in owners [roles, teams, projects, namespaces]
This means user must match ALL attribute categories that exist in the resource.
"""
current_user = get_authenticated_user()
if not current_user or not current_user.attributes:
return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
else:
base_conditions = ["access_attributes IS NULL", "access_attributes = 'null'", "access_attributes = '{}'"]
user_attr_conditions = []
for attr_key, user_values in current_user.attributes.items():
if user_values:
value_conditions = []
for value in user_values:
value_conditions.append(f"JSON_EXTRACT(access_attributes, '$.{attr_key}') LIKE '%\"{value}\"%'")
if value_conditions:
category_missing = f"JSON_EXTRACT(access_attributes, '$.{attr_key}') IS NULL"
user_matches_category = f"({' OR '.join(value_conditions)})"
user_attr_conditions.append(f"({category_missing} OR {user_matches_category})")
if user_attr_conditions:
all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
base_conditions.append(all_requirements_met)
return f"({' OR '.join(base_conditions)})"
else:
return f"({' OR '.join(base_conditions)})"
def _build_conservative_where_clause(self) -> str:
"""Conservative SQL filtering for custom policies.
Only filters records we're 100% certain would be denied by any reasonable policy.
"""
current_user = get_authenticated_user()
if not current_user:
return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
return "1=1"

View file

@ -17,13 +17,20 @@ from sqlalchemy import (
String,
Table,
Text,
inspect,
select,
text,
)
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, SqlStore
from .sqlstore import SqlAlchemySqlStoreConfig
logger = get_logger(name=__name__, category="sqlstore")
TYPE_MAPPING: dict[ColumnType, Any] = {
ColumnType.INTEGER: Integer,
ColumnType.STRING: String,
@ -54,7 +61,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
for col_name, col_props in schema.items():
col_type = None
is_primary_key = False
is_nullable = True # Default to nullable
is_nullable = True
if isinstance(col_props, ColumnType):
col_type = col_props
@ -71,14 +78,11 @@ class SqlAlchemySqlStoreImpl(SqlStore):
Column(col_name, sqlalchemy_type, primary_key=is_primary_key, nullable=is_nullable)
)
# Check if table already exists in metadata, otherwise define it
if table not in self.metadata.tables:
sqlalchemy_table = Table(table, self.metadata, *sqlalchemy_columns)
else:
sqlalchemy_table = self.metadata.tables[table]
# Create the table in the database if it doesn't exist
# checkfirst=True ensures it doesn't try to recreate if it's already there
engine = create_async_engine(self.config.engine_str)
async with engine.begin() as conn:
await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
@ -92,16 +96,62 @@ class SqlAlchemySqlStoreImpl(SqlStore):
self,
table: str,
where: Mapping[str, Any] | None = None,
where_sql: str | None = None,
limit: int | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
) -> list[dict[str, Any]]:
cursor: tuple[str, str] | None = None,
) -> PaginatedResponse:
async with self.async_session() as session:
query = select(self.metadata.tables[table])
table_obj = self.metadata.tables[table]
query = select(table_obj)
if where:
for key, value in where.items():
query = query.where(self.metadata.tables[table].c[key] == value)
if limit:
query = query.limit(limit)
query = query.where(table_obj.c[key] == value)
if where_sql:
query = query.where(text(where_sql))
# Handle cursor-based pagination
if cursor:
# Validate cursor tuple format
if not isinstance(cursor, tuple) or len(cursor) != 2:
raise ValueError(f"Cursor must be a tuple of (key_column, cursor_id), got: {cursor}")
# Require order_by for cursor pagination
if not order_by:
raise ValueError("order_by is required when using cursor pagination")
# Only support single-column ordering for cursor pagination
if len(order_by) != 1:
raise ValueError(
f"Cursor pagination only supports single-column ordering, got {len(order_by)} columns"
)
cursor_key_column, cursor_id = cursor
order_column, order_direction = order_by[0]
# Verify cursor_key_column exists
if cursor_key_column not in table_obj.c:
raise ValueError(f"Cursor key column '{cursor_key_column}' not found in table '{table}'")
# Get cursor value for the order column
cursor_query = select(table_obj.c[order_column]).where(table_obj.c[cursor_key_column] == cursor_id)
cursor_result = await session.execute(cursor_query)
cursor_row = cursor_result.fetchone()
if not cursor_row:
raise ValueError(f"Record with {cursor_key_column}='{cursor_id}' not found in table '{table}'")
cursor_value = cursor_row[0]
# Apply cursor condition based on sort direction
if order_direction == "desc":
query = query.where(table_obj.c[order_column] < cursor_value)
else:
query = query.where(table_obj.c[order_column] > cursor_value)
# Apply ordering
if order_by:
if not isinstance(order_by, list):
raise ValueError(
@ -113,27 +163,48 @@ class SqlAlchemySqlStoreImpl(SqlStore):
f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}"
)
name, order_type = order
if name not in table_obj.c:
raise ValueError(f"Column '{name}' not found in table '{table}'")
if order_type == "asc":
query = query.order_by(self.metadata.tables[table].c[name].asc())
query = query.order_by(table_obj.c[name].asc())
elif order_type == "desc":
query = query.order_by(self.metadata.tables[table].c[name].desc())
query = query.order_by(table_obj.c[name].desc())
else:
raise ValueError(f"Invalid order '{order_type}' for column '{name}'")
# Fetch limit + 1 to determine has_more
fetch_limit = limit
if limit:
fetch_limit = limit + 1
if fetch_limit:
query = query.limit(fetch_limit)
result = await session.execute(query)
if result.rowcount == 0:
return []
return [dict(row._mapping) for row in result]
rows = []
else:
rows = [dict(row._mapping) for row in result]
# Always return pagination result
has_more = False
if limit and len(rows) > limit:
has_more = True
rows = rows[:limit]
return PaginatedResponse(data=rows, has_more=has_more)
async def fetch_one(
self,
table: str,
where: Mapping[str, Any] | None = None,
where_sql: str | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
) -> dict[str, Any] | None:
rows = await self.fetch_all(table, where, limit=1, order_by=order_by)
if not rows:
result = await self.fetch_all(table, where, where_sql, limit=1, order_by=order_by)
if not result.data:
return None
return rows[0]
return result.data[0]
async def update(
self,
@ -161,3 +232,47 @@ class SqlAlchemySqlStoreImpl(SqlStore):
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
await session.execute(stmt)
await session.commit()
async def add_column_if_not_exists(
self,
table: str,
column_name: str,
column_type: ColumnType,
nullable: bool = True,
) -> None:
"""Add a column to an existing table if the column doesn't already exist."""
engine = create_async_engine(self.config.engine_str)
try:
inspector = inspect(engine)
table_names = inspector.get_table_names()
if table not in table_names:
return
existing_columns = inspector.get_columns(table)
column_names = [col["name"] for col in existing_columns]
if column_name in column_names:
return
sqlalchemy_type = TYPE_MAPPING.get(column_type)
if not sqlalchemy_type:
raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.")
# Create the ALTER TABLE statement
# Note: We need to get the dialect-specific type name
dialect = engine.dialect
type_impl = sqlalchemy_type()
compiled_type = type_impl.compile(dialect=dialect)
nullable_clause = "" if nullable else " NOT NULL"
add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
async with engine.begin() as conn:
await conn.execute(add_column_sql)
except Exception:
# If any error occurs during migration, log it but don't fail
# The table creation will handle adding the column
pass

View file

@ -11,7 +11,7 @@ import queue
import random
import threading
from collections.abc import Callable
from datetime import datetime, timezone
from datetime import UTC, datetime
from functools import wraps
from typing import Any
@ -121,7 +121,7 @@ class TraceContext:
span_id=generate_span_id(),
trace_id=self.trace_id,
name=name,
start_time=datetime.now(timezone.utc),
start_time=datetime.now(UTC),
parent_span_id=current_span.span_id if current_span else None,
attributes=attributes,
)
@ -239,7 +239,7 @@ class TelemetryHandler(logging.Handler):
UnstructuredLogEvent(
trace_id=span.trace_id,
span_id=span.span_id,
timestamp=datetime.now(timezone.utc),
timestamp=datetime.now(UTC),
message=self.format(record),
severity=severity(record.levelname),
)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -46,22 +46,22 @@ The deployed platform includes the NIM Proxy microservice, which is the service
### Datasetio API: NeMo Data Store
The NeMo Data Store microservice serves as the default file storage solution for the NeMo microservices platform. It exposts APIs compatible with the Hugging Face Hub client (`HfApi`), so you can use the client to interact with Data Store. The `NVIDIA_DATASETS_URL` environment variable should point to your NeMo Data Store endpoint.
See the [NVIDIA Datasetio docs](/llama_stack/providers/remote/datasetio/nvidia/README.md) for supported features and example usage.
See the {repopath}`NVIDIA Datasetio docs::llama_stack/providers/remote/datasetio/nvidia/README.md` for supported features and example usage.
### Eval API: NeMo Evaluator
The NeMo Evaluator microservice supports evaluation of LLMs. Launching an Evaluation job with NeMo Evaluator requires an Evaluation Config (an object that contains metadata needed by the job). A Llama Stack Benchmark maps to an Evaluation Config, so registering a Benchmark creates an Evaluation Config in NeMo Evaluator. The `NVIDIA_EVALUATOR_URL` environment variable should point to your NeMo Microservices endpoint.
See the [NVIDIA Eval docs](/llama_stack/providers/remote/eval/nvidia/README.md) for supported features and example usage.
See the {repopath}`NVIDIA Eval docs::llama_stack/providers/remote/eval/nvidia/README.md` for supported features and example usage.
### Post-Training API: NeMo Customizer
The NeMo Customizer microservice supports fine-tuning models. You can reference [this list of supported models](/llama_stack/providers/remote/post_training/nvidia/models.py) that can be fine-tuned using Llama Stack. The `NVIDIA_CUSTOMIZER_URL` environment variable should point to your NeMo Microservices endpoint.
The NeMo Customizer microservice supports fine-tuning models. You can reference {repopath}`this list of supported models::llama_stack/providers/remote/post_training/nvidia/models.py` that can be fine-tuned using Llama Stack. The `NVIDIA_CUSTOMIZER_URL` environment variable should point to your NeMo Microservices endpoint.
See the [NVIDIA Post-Training docs](/llama_stack/providers/remote/post_training/nvidia/README.md) for supported features and example usage.
See the {repopath}`NVIDIA Post-Training docs::llama_stack/providers/remote/post_training/nvidia/README.md` for supported features and example usage.
### Safety API: NeMo Guardrails
The NeMo Guardrails microservice sits between your application and the LLM, and adds checks and content moderation to a model. The `GUARDRAILS_SERVICE_URL` environment variable should point to your NeMo Microservices endpoint.
See the NVIDIA Safety docs for supported features and example usage.
See the {repopath}`NVIDIA Safety docs::llama_stack/providers/remote/safety/nvidia/README.md` for supported features and example usage.
## Deploying models
In order to use a registered model with the Llama Stack APIs, ensure the corresponding NIM is deployed to your environment. For example, you can use the NIM Proxy microservice to deploy `meta/llama-3.2-1b-instruct`.
@ -144,3 +144,6 @@ llama stack run ./run.yaml \
--env NVIDIA_API_KEY=$NVIDIA_API_KEY \
--env INFERENCE_MODEL=$INFERENCE_MODEL
```
## Example Notebooks
For examples of how to use the NVIDIA Distribution to run inference, fine-tune, evaluate, and run safety checks on your LLMs, you can reference the example notebooks in {repopath}`docs/notebooks/nvidia`.

View file

@ -130,6 +130,10 @@ def get_distribution_template() -> DistributionTemplate:
"http://0.0.0.0:7331",
"URL for the NeMo Guardrails Service",
),
"NVIDIA_GUARDRAILS_CONFIG_ID": (
"self-check",
"NVIDIA Guardrail Configuration ID",
),
"NVIDIA_EVALUATOR_URL": (
"http://0.0.0.0:7331",
"URL for the NeMo Evaluator Service",

View file

@ -23,7 +23,7 @@ providers:
provider_type: remote::nvidia
config:
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
config_id: self-check
config_id: ${env.NVIDIA_GUARDRAILS_CONFIG_ID:self-check}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
@ -37,7 +37,7 @@ providers:
provider_type: remote::nvidia
config:
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
config_id: self-check
config_id: ${env.NVIDIA_GUARDRAILS_CONFIG_ID:self-check}
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference

View file

@ -32,7 +32,7 @@ providers:
provider_type: remote::nvidia
config:
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
config_id: self-check
config_id: ${env.NVIDIA_GUARDRAILS_CONFIG_ID:self-check}
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference

View file

@ -23,6 +23,8 @@ distribution_spec:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
files:
- inline::localfs
post_training:
- inline::huggingface
tool_runtime:

View file

@ -13,6 +13,7 @@ from llama_stack.distribution.datatypes import (
ShieldInput,
ToolGroupInput,
)
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
@ -29,6 +30,7 @@ def get_distribution_template() -> DistributionTemplate:
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"files": ["inline::localfs"],
"post_training": ["inline::huggingface"],
"tool_runtime": [
"remote::brave-search",
@ -49,6 +51,11 @@ def get_distribution_template() -> DistributionTemplate:
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
files_provider = Provider(
provider_id="meta-reference-files",
provider_type="inline::localfs",
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
posttraining_provider = Provider(
provider_id="huggingface",
provider_type="inline::huggingface",
@ -98,6 +105,7 @@ def get_distribution_template() -> DistributionTemplate:
provider_overrides={
"inference": [inference_provider],
"vector_io": [vector_io_provider_faiss],
"files": [files_provider],
"post_training": [posttraining_provider],
},
default_models=[inference_model, embedding_model],
@ -107,6 +115,7 @@ def get_distribution_template() -> DistributionTemplate:
provider_overrides={
"inference": [inference_provider],
"vector_io": [vector_io_provider_faiss],
"files": [files_provider],
"post_training": [posttraining_provider],
"safety": [
Provider(

View file

@ -4,6 +4,7 @@ apis:
- agents
- datasetio
- eval
- files
- inference
- post_training
- safety
@ -84,6 +85,14 @@ providers:
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:~/.llama/distributions/ollama/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/files_metadata.db
post_training:
- provider_id: huggingface
provider_type: inline::huggingface

View file

@ -4,6 +4,7 @@ apis:
- agents
- datasetio
- eval
- files
- inference
- post_training
- safety
@ -82,6 +83,14 @@ providers:
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:~/.llama/distributions/ollama/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/files_metadata.db
post_training:
- provider_id: huggingface
provider_type: inline::huggingface

View file

@ -21,4 +21,5 @@ distribution_spec:
image_type: conda
additional_pip_packages:
- asyncpg
- psycopg2-binary
- sqlalchemy[asyncio]

View file

@ -17,6 +17,8 @@ distribution_spec:
- inline::sqlite-vec
- remote::chromadb
- remote::pgvector
files:
- inline::localfs
safety:
- inline::llama-guard
agents:

View file

@ -4,6 +4,7 @@ apis:
- agents
- datasetio
- eval
- files
- inference
- safety
- scoring
@ -75,6 +76,14 @@ providers:
db: ${env.PGVECTOR_DB:}
user: ${env.PGVECTOR_USER:}
password: ${env.PGVECTOR_PASSWORD:}
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:~/.llama/distributions/starter/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/files_metadata.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
@ -722,6 +731,21 @@ models:
provider_id: gemini
provider_model_id: gemini/gemini-1.5-pro
model_type: llm
- metadata: {}
model_id: gemini/gemini-2.0-flash
provider_id: gemini
provider_model_id: gemini/gemini-2.0-flash
model_type: llm
- metadata: {}
model_id: gemini/gemini-2.5-flash
provider_id: gemini
provider_model_id: gemini/gemini-2.5-flash
model_type: llm
- metadata: {}
model_id: gemini/gemini-2.5-pro
provider_id: gemini
provider_model_id: gemini/gemini-2.5-pro
model_type: llm
- metadata:
embedding_dimension: 768
context_length: 2048

View file

@ -12,6 +12,7 @@ from llama_stack.distribution.datatypes import (
ShieldInput,
ToolGroupInput,
)
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
@ -134,6 +135,7 @@ def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ([p.provider_type for p in inference_providers] + ["inline::sentence-transformers"]),
"vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
"files": ["inline::localfs"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
@ -170,6 +172,11 @@ def get_distribution_template() -> DistributionTemplate:
),
),
]
files_provider = Provider(
provider_id="meta-reference-files",
provider_type="inline::localfs",
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
embedding_provider = Provider(
provider_id="sentence-transformers",
provider_type="inline::sentence-transformers",
@ -212,6 +219,7 @@ def get_distribution_template() -> DistributionTemplate:
provider_overrides={
"inference": inference_providers + [embedding_provider],
"vector_io": vector_io_providers,
"files": [files_provider],
},
default_models=default_models + [embedding_model],
default_tool_groups=default_tool_groups,

View file

@ -186,8 +186,14 @@ class DistributionTemplate(BaseModel):
additional_pip_packages: list[str] = []
for run_config in self.run_configs.values():
run_config_ = run_config.run_config(self.name, self.providers, self.container_image)
# TODO: This is a hack to get the dependencies for internal APIs into build
# We should have a better way to do this by formalizing the concept of "internal" APIs
# and providers, with a way to specify dependencies for them.
if run_config_.inference_store:
additional_pip_packages.extend(run_config_.inference_store.pip_packages)
if run_config_.metadata_store:
additional_pip_packages.extend(run_config_.metadata_store.pip_packages)
if self.additional_pip_packages:
additional_pip_packages.extend(self.additional_pip_packages)

View file

@ -39,3 +39,6 @@ yarn-error.log*
# typescript
*.tsbuildinfo
next-env.d.ts
# playwright
.last-run.json

View file

@ -1,51 +1,7 @@
"use client";
import { useEffect, useState } from "react";
import { ChatCompletion } from "@/lib/types";
import { ChatCompletionsTable } from "@/components/chat-completions/chat-completions-table";
import { client } from "@/lib/client";
export default function ChatCompletionsPage() {
const [completions, setCompletions] = useState<ChatCompletion[]>([]);
const [isLoading, setIsLoading] = useState<boolean>(true);
const [error, setError] = useState<Error | null>(null);
useEffect(() => {
const fetchCompletions = async () => {
setIsLoading(true);
setError(null);
try {
const response = await client.chat.completions.list();
const data = Array.isArray(response)
? response
: (response as { data: ChatCompletion[] }).data;
if (Array.isArray(data)) {
setCompletions(data);
} else {
console.error("Unexpected response structure:", response);
setError(new Error("Unexpected response structure"));
setCompletions([]);
}
} catch (err) {
console.error("Error fetching chat completions:", err);
setError(
err instanceof Error ? err : new Error("Failed to fetch completions"),
);
setCompletions([]);
} finally {
setIsLoading(false);
}
};
fetchCompletions();
}, []);
return (
<ChatCompletionsTable
data={completions}
isLoading={isLoading}
error={error}
/>
);
return <ChatCompletionsTable paginationOptions={{ limit: 20 }} />;
}

View file

@ -1,66 +1,7 @@
"use client";
import { useEffect, useState } from "react";
import type { ResponseListResponse } from "llama-stack-client/resources/responses/responses";
import { OpenAIResponse } from "@/lib/types";
import { ResponsesTable } from "@/components/responses/responses-table";
import { client } from "@/lib/client";
export default function ResponsesPage() {
const [responses, setResponses] = useState<OpenAIResponse[]>([]);
const [isLoading, setIsLoading] = useState<boolean>(true);
const [error, setError] = useState<Error | null>(null);
// Helper function to convert ResponseListResponse.Data to OpenAIResponse
const convertResponseListData = (
responseData: ResponseListResponse.Data,
): OpenAIResponse => {
return {
id: responseData.id,
created_at: responseData.created_at,
model: responseData.model,
object: responseData.object,
status: responseData.status,
output: responseData.output as OpenAIResponse["output"],
input: responseData.input as OpenAIResponse["input"],
error: responseData.error,
parallel_tool_calls: responseData.parallel_tool_calls,
previous_response_id: responseData.previous_response_id,
temperature: responseData.temperature,
top_p: responseData.top_p,
truncation: responseData.truncation,
user: responseData.user,
};
};
useEffect(() => {
const fetchResponses = async () => {
setIsLoading(true);
setError(null);
try {
const response = await client.responses.list();
const responseListData = response as ResponseListResponse;
const convertedResponses: OpenAIResponse[] = responseListData.data.map(
convertResponseListData,
);
setResponses(convertedResponses);
} catch (err) {
console.error("Error fetching responses:", err);
setError(
err instanceof Error ? err : new Error("Failed to fetch responses"),
);
setResponses([]);
} finally {
setIsLoading(false);
}
};
fetchResponses();
}, []);
return (
<ResponsesTable data={responses} isLoading={isLoading} error={error} />
);
return <ResponsesTable paginationOptions={{ limit: 20 }} />;
}

View file

@ -16,6 +16,29 @@ jest.mock("next/navigation", () => ({
jest.mock("@/lib/truncate-text");
jest.mock("@/lib/format-message-content");
// Mock the client
jest.mock("@/lib/client", () => ({
client: {
chat: {
completions: {
list: jest.fn(),
},
},
},
}));
// Mock the usePagination hook
const mockLoadMore = jest.fn();
jest.mock("@/hooks/usePagination", () => ({
usePagination: jest.fn(() => ({
data: [],
status: "idle",
hasMore: false,
error: null,
loadMore: mockLoadMore,
})),
}));
// Import the mocked functions to set up default or specific implementations
import { truncateText as originalTruncateText } from "@/lib/truncate-text";
import {
@ -23,6 +46,12 @@ import {
extractDisplayableText as originalExtractDisplayableText,
} from "@/lib/format-message-content";
// Import the mocked hook
import { usePagination } from "@/hooks/usePagination";
const mockedUsePagination = usePagination as jest.MockedFunction<
typeof usePagination
>;
// Cast to jest.Mock for typings
const truncateText = originalTruncateText as jest.Mock;
const extractTextFromContentPart =
@ -30,11 +59,7 @@ const extractTextFromContentPart =
const extractDisplayableText = originalExtractDisplayableText as jest.Mock;
describe("ChatCompletionsTable", () => {
const defaultProps = {
data: [] as ChatCompletion[],
isLoading: false,
error: null,
};
const defaultProps = {};
beforeEach(() => {
// Reset all mocks before each test
@ -42,16 +67,27 @@ describe("ChatCompletionsTable", () => {
truncateText.mockClear();
extractTextFromContentPart.mockClear();
extractDisplayableText.mockClear();
mockLoadMore.mockClear();
jest.clearAllMocks();
// Default pass-through implementations
truncateText.mockImplementation((text: string | undefined) => text);
extractTextFromContentPart.mockImplementation((content: unknown) =>
typeof content === "string" ? content : "extracted text",
);
extractDisplayableText.mockImplementation(
(message: unknown) =>
(message as { content?: string })?.content || "extracted output",
);
extractDisplayableText.mockImplementation((message: unknown) => {
const msg = message as { content?: string };
return msg?.content || "extracted output";
});
// Default hook return value
mockedUsePagination.mockReturnValue({
data: [],
status: "idle",
hasMore: false,
error: null,
loadMore: mockLoadMore,
});
});
test("renders without crashing with default props", () => {
@ -60,41 +96,56 @@ describe("ChatCompletionsTable", () => {
});
test("click on a row navigates to the correct URL", () => {
const mockCompletion: ChatCompletion = {
id: "comp_123",
object: "chat.completion",
created: Math.floor(Date.now() / 1000),
model: "llama-test-model",
choices: [
{
index: 0,
message: { role: "assistant", content: "Test output" },
finish_reason: "stop",
},
],
input_messages: [{ role: "user", content: "Test input" }],
};
const mockData: ChatCompletion[] = [
{
id: "completion_123",
choices: [
{
message: { role: "assistant", content: "Test response" },
finish_reason: "stop",
index: 0,
},
],
object: "chat.completion",
created: 1234567890,
model: "test-model",
input_messages: [{ role: "user", content: "Test prompt" }],
},
];
// Set up mocks to return expected values
extractTextFromContentPart.mockReturnValue("Test input");
extractDisplayableText.mockReturnValue("Test output");
// Configure the mock to return our test data
mockedUsePagination.mockReturnValue({
data: mockData,
status: "idle",
hasMore: false,
error: null,
loadMore: mockLoadMore,
});
render(<ChatCompletionsTable {...defaultProps} data={[mockCompletion]} />);
render(<ChatCompletionsTable {...defaultProps} />);
const row = screen.getByText("Test input").closest("tr");
const row = screen.getByText("Test prompt").closest("tr");
if (row) {
fireEvent.click(row);
expect(mockPush).toHaveBeenCalledWith("/logs/chat-completions/comp_123");
expect(mockPush).toHaveBeenCalledWith(
"/logs/chat-completions/completion_123",
);
} else {
throw new Error('Row with "Test input" not found for router mock test.');
throw new Error('Row with "Test prompt" not found for router mock test.');
}
});
describe("Loading State", () => {
test("renders skeleton UI when isLoading is true", () => {
const { container } = render(
<ChatCompletionsTable {...defaultProps} isLoading={true} />,
);
mockedUsePagination.mockReturnValue({
data: [],
status: "loading",
hasMore: false,
error: null,
loadMore: mockLoadMore,
});
const { container } = render(<ChatCompletionsTable {...defaultProps} />);
// Check for skeleton in the table caption
const tableCaption = container.querySelector("caption");
@ -121,40 +172,48 @@ describe("ChatCompletionsTable", () => {
describe("Error State", () => {
test("renders error message when error prop is provided", () => {
const errorMessage = "Network Error";
render(
<ChatCompletionsTable
{...defaultProps}
error={{ name: "Error", message: errorMessage }}
/>,
);
mockedUsePagination.mockReturnValue({
data: [],
status: "error",
hasMore: false,
error: { name: "Error", message: errorMessage } as Error,
loadMore: mockLoadMore,
});
render(<ChatCompletionsTable {...defaultProps} />);
expect(
screen.getByText(`Error fetching data: ${errorMessage}`),
screen.getByText("Unable to load chat completions"),
).toBeInTheDocument();
expect(screen.getByText(errorMessage)).toBeInTheDocument();
});
test("renders default error message when error.message is not available", () => {
render(
<ChatCompletionsTable
{...defaultProps}
error={{ name: "Error", message: "" }}
/>,
);
expect(
screen.getByText("Error fetching data: An unknown error occurred"),
).toBeInTheDocument();
});
test.each([{ name: "Error", message: "" }, {}])(
"renders default error message when error has no message",
(errorObject) => {
mockedUsePagination.mockReturnValue({
data: [],
status: "error",
hasMore: false,
error: errorObject as Error,
loadMore: mockLoadMore,
});
test("renders default error message when error prop is an object without message", () => {
render(<ChatCompletionsTable {...defaultProps} error={{} as Error} />);
expect(
screen.getByText("Error fetching data: An unknown error occurred"),
).toBeInTheDocument();
});
render(<ChatCompletionsTable {...defaultProps} />);
expect(
screen.getByText("Unable to load chat completions"),
).toBeInTheDocument();
expect(
screen.getByText(
"An unexpected error occurred while loading the data.",
),
).toBeInTheDocument();
},
);
});
describe("Empty State", () => {
test('renders "No chat completions found." and no table when data array is empty', () => {
render(<ChatCompletionsTable data={[]} isLoading={false} error={null} />);
render(<ChatCompletionsTable {...defaultProps} />);
expect(
screen.getByText("No chat completions found."),
).toBeInTheDocument();
@ -167,7 +226,7 @@ describe("ChatCompletionsTable", () => {
describe("Data Rendering", () => {
test("renders table caption, headers, and completion data correctly", () => {
const mockCompletions = [
const mockCompletions: ChatCompletion[] = [
{
id: "comp_1",
object: "chat.completion",
@ -211,13 +270,15 @@ describe("ChatCompletionsTable", () => {
return "extracted output";
});
render(
<ChatCompletionsTable
data={mockCompletions}
isLoading={false}
error={null}
/>,
);
mockedUsePagination.mockReturnValue({
data: mockCompletions,
status: "idle",
hasMore: false,
error: null,
loadMore: mockLoadMore,
});
render(<ChatCompletionsTable {...defaultProps} />);
// Table caption
expect(
@ -268,7 +329,7 @@ describe("ChatCompletionsTable", () => {
extractTextFromContentPart.mockReturnValue(longInput);
extractDisplayableText.mockReturnValue(longOutput);
const mockCompletions = [
const mockCompletions: ChatCompletion[] = [
{
id: "comp_trunc",
object: "chat.completion",
@ -285,63 +346,72 @@ describe("ChatCompletionsTable", () => {
},
];
render(
<ChatCompletionsTable
data={mockCompletions}
isLoading={false}
error={null}
/>,
);
mockedUsePagination.mockReturnValue({
data: mockCompletions,
status: "idle",
hasMore: false,
error: null,
loadMore: mockLoadMore,
});
render(<ChatCompletionsTable {...defaultProps} />);
// The truncated text should be present for both input and output
const truncatedTexts = screen.getAllByText(
longInput.slice(0, 10) + "...",
);
expect(truncatedTexts.length).toBe(2); // one for input, one for output
truncatedTexts.forEach((textElement) =>
expect(textElement).toBeInTheDocument(),
);
});
test("uses content extraction functions correctly", () => {
const mockCompletion = {
id: "comp_extract",
object: "chat.completion",
created: 1710003000,
model: "llama-extract-model",
choices: [
{
index: 0,
message: { role: "assistant", content: "Extracted output" },
finish_reason: "stop",
},
],
input_messages: [{ role: "user", content: "Extracted input" }],
const complexMessage = [
{ type: "text", text: "Extracted input" },
{ type: "image", url: "http://example.com/image.png" },
];
const assistantMessage = {
role: "assistant",
content: "Extracted output from assistant",
};
const mockCompletions: ChatCompletion[] = [
{
id: "comp_extract",
object: "chat.completion",
created: 1710003000,
model: "llama-extract-model",
choices: [
{
index: 0,
message: assistantMessage,
finish_reason: "stop",
},
],
input_messages: [{ role: "user", content: complexMessage }],
},
];
extractTextFromContentPart.mockReturnValue("Extracted input");
extractDisplayableText.mockReturnValue("Extracted output");
extractDisplayableText.mockReturnValue("Extracted output from assistant");
render(
<ChatCompletionsTable
data={[mockCompletion]}
isLoading={false}
error={null}
/>,
);
// Verify the extraction functions were called
expect(extractTextFromContentPart).toHaveBeenCalledWith(
"Extracted input",
);
expect(extractDisplayableText).toHaveBeenCalledWith({
role: "assistant",
content: "Extracted output",
mockedUsePagination.mockReturnValue({
data: mockCompletions,
status: "idle",
hasMore: false,
error: null,
loadMore: mockLoadMore,
});
// Verify the extracted content is displayed
render(<ChatCompletionsTable {...defaultProps} />);
// Verify the extraction functions were called
expect(extractTextFromContentPart).toHaveBeenCalledWith(complexMessage);
expect(extractDisplayableText).toHaveBeenCalledWith(assistantMessage);
// Verify the extracted text appears in the table
expect(screen.getByText("Extracted input")).toBeInTheDocument();
expect(screen.getByText("Extracted output")).toBeInTheDocument();
expect(
screen.getByText("Extracted output from assistant"),
).toBeInTheDocument();
});
});
});

View file

@ -1,16 +1,21 @@
"use client";
import { ChatCompletion } from "@/lib/types";
import {
ChatCompletion,
UsePaginationOptions,
ListChatCompletionsResponse,
} from "@/lib/types";
import { LogsTable, LogTableRow } from "@/components/logs/logs-table";
import {
extractTextFromContentPart,
extractDisplayableText,
} from "@/lib/format-message-content";
import { usePagination } from "@/hooks/usePagination";
import { client } from "@/lib/client";
interface ChatCompletionsTableProps {
data: ChatCompletion[];
isLoading: boolean;
error: Error | null;
/** Optional pagination configuration */
paginationOptions?: UsePaginationOptions;
}
function formatChatCompletionToRow(completion: ChatCompletion): LogTableRow {
@ -25,17 +30,39 @@ function formatChatCompletionToRow(completion: ChatCompletion): LogTableRow {
}
export function ChatCompletionsTable({
data,
isLoading,
error,
paginationOptions,
}: ChatCompletionsTableProps) {
const fetchFunction = async (params: {
after?: string;
limit: number;
model?: string;
order?: string;
}) => {
const response = await client.chat.completions.list({
after: params.after,
limit: params.limit,
...(params.model && { model: params.model }),
...(params.order && { order: params.order }),
} as any);
return response as ListChatCompletionsResponse;
};
const { data, status, hasMore, error, loadMore } = usePagination({
...paginationOptions,
fetchFunction,
errorMessagePrefix: "chat completions",
});
const formattedData = data.map(formatChatCompletionToRow);
return (
<LogsTable
data={formattedData}
isLoading={isLoading}
status={status}
hasMore={hasMore}
error={error}
onLoadMore={loadMore}
caption="A list of your recent chat completions."
emptyMessage="No chat completions found."
/>

View file

@ -37,13 +37,11 @@ export default function LogsLayout({
}
return (
<div className="container mx-auto p-4">
<>
{segments.length > 0 && (
<PageBreadcrumb segments={segments} className="mb-4" />
)}
{children}
</>
<div className="container mx-auto p-4 h-[calc(100vh-64px)] flex flex-col">
{segments.length > 0 && (
<PageBreadcrumb segments={segments} className="mb-4" />
)}
<div className="flex-1 min-h-0 flex flex-col">{children}</div>
</div>
);
}

View file

@ -0,0 +1,142 @@
import React from "react";
import { render, waitFor } from "@testing-library/react";
import "@testing-library/jest-dom";
import { LogsTable, LogTableRow } from "./logs-table";
import { PaginationStatus } from "@/lib/types";
// Mock next/navigation
jest.mock("next/navigation", () => ({
useRouter: () => ({
push: jest.fn(),
}),
}));
// Mock the useInfiniteScroll hook
jest.mock("@/hooks/useInfiniteScroll", () => ({
useInfiniteScroll: jest.fn((onLoadMore, options) => {
const ref = React.useRef(null);
React.useEffect(() => {
// Simulate the observer behavior
if (options?.enabled && onLoadMore) {
// Trigger load after a delay to simulate intersection
const timeout = setTimeout(() => {
onLoadMore();
}, 100);
return () => clearTimeout(timeout);
}
}, [options?.enabled, onLoadMore]);
return ref;
}),
}));
// IntersectionObserver mock is already in jest.setup.ts
describe("LogsTable Viewport Loading", () => {
const mockData: LogTableRow[] = Array.from({ length: 10 }, (_, i) => ({
id: `row_${i}`,
input: `Input ${i}`,
output: `Output ${i}`,
model: "test-model",
createdTime: new Date().toISOString(),
detailPath: `/logs/test/${i}`,
}));
const defaultProps = {
data: mockData,
status: "idle" as PaginationStatus,
hasMore: true,
error: null,
caption: "Test table",
emptyMessage: "No data",
};
beforeEach(() => {
jest.clearAllMocks();
});
test("should trigger loadMore when sentinel is visible", async () => {
const mockLoadMore = jest.fn();
render(<LogsTable {...defaultProps} onLoadMore={mockLoadMore} />);
// Wait for the intersection observer to trigger
await waitFor(
() => {
expect(mockLoadMore).toHaveBeenCalled();
},
{ timeout: 300 },
);
expect(mockLoadMore).toHaveBeenCalledTimes(1);
});
test("should not trigger loadMore when already loading", async () => {
const mockLoadMore = jest.fn();
render(
<LogsTable
{...defaultProps}
status="loading-more"
onLoadMore={mockLoadMore}
/>,
);
// Wait for possible triggers
await new Promise((resolve) => setTimeout(resolve, 300));
expect(mockLoadMore).not.toHaveBeenCalled();
});
test("should not trigger loadMore when status is loading", async () => {
const mockLoadMore = jest.fn();
render(
<LogsTable
{...defaultProps}
status="loading"
onLoadMore={mockLoadMore}
/>,
);
// Wait for possible triggers
await new Promise((resolve) => setTimeout(resolve, 300));
expect(mockLoadMore).not.toHaveBeenCalled();
});
test("should not trigger loadMore when hasMore is false", async () => {
const mockLoadMore = jest.fn();
render(
<LogsTable {...defaultProps} hasMore={false} onLoadMore={mockLoadMore} />,
);
// Wait for possible triggers
await new Promise((resolve) => setTimeout(resolve, 300));
expect(mockLoadMore).not.toHaveBeenCalled();
});
test("sentinel element should not be rendered when loading", () => {
const { container } = render(
<LogsTable {...defaultProps} status="loading-more" />,
);
// Check that no sentinel row with height: 1 exists
const sentinelRow = container.querySelector('tr[style*="height: 1"]');
expect(sentinelRow).not.toBeInTheDocument();
});
test("sentinel element should be rendered when not loading and hasMore", () => {
const { container } = render(
<LogsTable {...defaultProps} hasMore={true} status="idle" />,
);
// Check that sentinel row exists
const sentinelRow = container.querySelector('tr[style*="height: 1"]');
expect(sentinelRow).toBeInTheDocument();
});
});

Some files were not shown because too many files have changed in this diff Show more