chore(package): migrate to src/ layout (#3920)

Migrates package structure to src/ layout following Python packaging
best practices.

All code moved from `llama_stack/` to `src/llama_stack/`. Public API
unchanged - imports remain `import llama_stack.*`.

Updated build configs, pre-commit hooks, scripts, and GitHub workflows
accordingly. All hooks pass and the package builds cleanly.

**Developer note**: Reinstall after pulling: `pip install -e .`
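
To verify that the editable reinstall picked up the new layout, one quick check (a sketch, not part of this commit) is to import the package and inspect where it resolves from; the import path itself is unchanged:

    # After `pip install -e .`, the package should resolve from the new src/ tree.
    import llama_stack
    print(llama_stack.__file__)  # expected to end in src/llama_stack/__init__.py for an editable install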
Ashwin Bharambe 2025-10-27 12:02:21 -07:00 committed by GitHub
parent 98a5047f9d
commit 471b1b248b
791 changed files with 2983 additions and 456 deletions


@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@@ -1,74 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import boto3
from botocore.client import BaseClient
from botocore.config import Config
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
from llama_stack.providers.utils.bedrock.refreshable_boto_session import (
RefreshableBotoSession,
)
def create_bedrock_client(config: BedrockBaseConfig, service_name: str = "bedrock-runtime") -> BaseClient:
"""Creates a boto3 client for Bedrock services with the given configuration.
Args:
config: The Bedrock configuration containing AWS credentials and settings
service_name: The AWS service name to create client for (default: "bedrock-runtime")
Returns:
A configured boto3 client
"""
if config.aws_access_key_id and config.aws_secret_access_key:
retries_config = {
k: v
for k, v in dict(
total_max_attempts=config.total_max_attempts,
mode=config.retry_mode,
).items()
if v is not None
}
config_args = {
k: v
for k, v in dict(
region_name=config.region_name,
retries=retries_config if retries_config else None,
connect_timeout=config.connect_timeout,
read_timeout=config.read_timeout,
).items()
if v is not None
}
boto3_config = Config(**config_args)
session_args = {
"aws_access_key_id": config.aws_access_key_id,
"aws_secret_access_key": config.aws_secret_access_key,
"aws_session_token": config.aws_session_token,
"region_name": config.region_name,
"profile_name": config.profile_name,
"session_ttl": config.session_ttl,
}
# Remove None values
session_args = {k: v for k, v in session_args.items() if v is not None}
boto3_session = boto3.session.Session(**session_args)
return boto3_session.client(service_name, config=boto3_config)
else:
return (
RefreshableBotoSession(
region_name=config.region_name,
profile_name=config.profile_name,
session_ttl=config.session_ttl,
)
.refreshable_session()
.client(service_name)
)
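
For context, a minimal usage sketch of this helper, assuming create_bedrock_client is in scope (the diff omits file names, so the module path is not shown):

    from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig

    # BedrockBaseConfig pulls credentials and region from AWS_* environment variables
    # by default; with no static keys set, create_bedrock_client falls back to a
    # refreshable profile/role based session.
    config = BedrockBaseConfig()
    client = create_bedrock_client(config)  # service_name defaults to "bedrock-runtime"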


@@ -1,64 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from pydantic import Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
class BedrockBaseConfig(RemoteInferenceProviderConfig):
auth_credential: None = Field(default=None, exclude=True)
aws_access_key_id: str | None = Field(
default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
)
aws_secret_access_key: str | None = Field(
default_factory=lambda: os.getenv("AWS_SECRET_ACCESS_KEY"),
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
)
aws_session_token: str | None = Field(
default_factory=lambda: os.getenv("AWS_SESSION_TOKEN"),
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
)
region_name: str | None = Field(
default_factory=lambda: os.getenv("AWS_DEFAULT_REGION"),
description="The default AWS Region to use, for example, us-west-1 or us-west-2."
"Default use environment variable: AWS_DEFAULT_REGION",
)
profile_name: str | None = Field(
default_factory=lambda: os.getenv("AWS_PROFILE"),
description="The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE",
)
total_max_attempts: int | None = Field(
default_factory=lambda: int(val) if (val := os.getenv("AWS_MAX_ATTEMPTS")) else None,
description="An integer representing the maximum number of attempts that will be made for a single request, "
"including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS",
)
retry_mode: str | None = Field(
default_factory=lambda: os.getenv("AWS_RETRY_MODE"),
description="A string representing the type of retries Boto3 will perform."
"Default use environment variable: AWS_RETRY_MODE",
)
connect_timeout: float | None = Field(
default_factory=lambda: float(os.getenv("AWS_CONNECT_TIMEOUT", "60")),
description="The time in seconds till a timeout exception is thrown when attempting to make a connection. "
"The default is 60 seconds.",
)
read_timeout: float | None = Field(
default_factory=lambda: float(os.getenv("AWS_READ_TIMEOUT", "60")),
description="The time in seconds till a timeout exception is thrown when attempting to read from a connection."
"The default is 60 seconds.",
)
session_ttl: int | None = Field(
default_factory=lambda: int(os.getenv("AWS_SESSION_TTL", "3600")),
description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).",
)
@classmethod
def sample_run_config(cls, **kwargs):
return {}


@@ -1,112 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import datetime
from time import time
from uuid import uuid4
from boto3 import Session
from botocore.credentials import RefreshableCredentials
from botocore.session import get_session
class RefreshableBotoSession:
"""
Boto Helper class which lets us create a refreshable session so that we can cache the client or resource.
Usage
-----
session = RefreshableBotoSession().refreshable_session()
client = session.client("s3") # we now can cache this client object without worrying about expiring credentials
"""
def __init__(
self,
region_name: str | None = None,
profile_name: str | None = None,
sts_arn: str | None = None,
session_name: str | None = None,
session_ttl: int = 30000,
):
"""
Initialize `RefreshableBotoSession`
Parameters
----------
region_name : str (optional)
Default region when creating a new connection.
profile_name : str (optional)
The name of a profile to use.
sts_arn : str (optional)
The role arn to sts before creating a session.
session_name : str (optional)
An identifier for the assumed role session. (required when `sts_arn` is given)
session_ttl : int (optional)
An integer TTL, in seconds, for each session; once it elapses, the token is renewed.
50 minutes by default, which is before the default role expiration of 1 hour.
"""
self.region_name = region_name
self.profile_name = profile_name
self.sts_arn = sts_arn
self.session_name = session_name or uuid4().hex
self.session_ttl = session_ttl
def __get_session_credentials(self):
"""
Get session credentials
"""
session = Session(region_name=self.region_name, profile_name=self.profile_name)
# if sts_arn is given, get credential by assuming the given role
if self.sts_arn:
sts_client = session.client(service_name="sts", region_name=self.region_name)
response = sts_client.assume_role(
RoleArn=self.sts_arn,
RoleSessionName=self.session_name,
DurationSeconds=self.session_ttl,
).get("Credentials")
credentials = {
"access_key": response.get("AccessKeyId"),
"secret_key": response.get("SecretAccessKey"),
"token": response.get("SessionToken"),
"expiry_time": response.get("Expiration").isoformat(),
}
else:
session_credentials = session.get_credentials().get_frozen_credentials()
credentials = {
"access_key": session_credentials.access_key,
"secret_key": session_credentials.secret_key,
"token": session_credentials.token,
"expiry_time": datetime.datetime.fromtimestamp(time() + self.session_ttl, datetime.UTC).isoformat(),
}
return credentials
def refreshable_session(self) -> Session:
"""
Get refreshable boto3 session.
"""
# Get refreshable credentials
refreshable_credentials = RefreshableCredentials.create_from_metadata(
metadata=self.__get_session_credentials(),
refresh_using=self.__get_session_credentials,
method="sts-assume-role",
)
# attach refreshable credentials to the current session
session = get_session()
session._credentials = refreshable_credentials
session.set_config_variable("region", self.region_name)
autorefresh_session = Session(botocore_session=session)
return autorefresh_session


@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@@ -1,103 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any
from llama_stack.apis.common.type_system import (
ChatCompletionInputType,
CompletionInputType,
StringType,
)
from llama_stack.core.datatypes import Api
class ColumnName(Enum):
input_query = "input_query"
expected_answer = "expected_answer"
chat_completion_input = "chat_completion_input"
completion_input = "completion_input"
generated_answer = "generated_answer"
context = "context"
dialog = "dialog"
function = "function"
language = "language"
id = "id"
ground_truth = "ground_truth"
VALID_SCHEMAS_FOR_SCORING = [
{
ColumnName.input_query.value: StringType(),
ColumnName.expected_answer.value: StringType(),
ColumnName.generated_answer.value: StringType(),
},
{
ColumnName.input_query.value: StringType(),
ColumnName.expected_answer.value: StringType(),
ColumnName.generated_answer.value: StringType(),
ColumnName.context.value: StringType(),
},
{
ColumnName.input_query.value: StringType(),
ColumnName.expected_answer.value: StringType(),
ColumnName.generated_answer.value: StringType(),
ColumnName.function.value: StringType(),
ColumnName.language.value: StringType(),
ColumnName.id.value: StringType(),
ColumnName.ground_truth.value: StringType(),
},
]
VALID_SCHEMAS_FOR_EVAL = [
{
ColumnName.input_query.value: StringType(),
ColumnName.expected_answer.value: StringType(),
ColumnName.chat_completion_input.value: ChatCompletionInputType(),
},
{
ColumnName.input_query.value: StringType(),
ColumnName.expected_answer.value: StringType(),
ColumnName.completion_input.value: CompletionInputType(),
},
{
ColumnName.input_query.value: StringType(),
ColumnName.expected_answer.value: StringType(),
ColumnName.generated_answer.value: StringType(),
ColumnName.function.value: StringType(),
ColumnName.language.value: StringType(),
ColumnName.id.value: StringType(),
ColumnName.ground_truth.value: StringType(),
},
]
def get_valid_schemas(api_str: str):
if api_str == Api.scoring.value:
return VALID_SCHEMAS_FOR_SCORING
elif api_str == Api.eval.value:
return VALID_SCHEMAS_FOR_EVAL
else:
raise ValueError(f"Invalid API string: {api_str}")
def validate_dataset_schema(
dataset_schema: dict[str, Any],
expected_schemas: list[dict[str, Any]],
):
if dataset_schema not in expected_schemas:
raise ValueError(f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}")
def validate_row_schema(
input_row: dict[str, Any],
expected_schemas: list[dict[str, Any]],
):
for schema in expected_schemas:
if all(key in input_row for key in schema):
return
raise ValueError(f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}")


@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@@ -1,47 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import base64
import io
from urllib.parse import unquote
from llama_stack.providers.utils.memory.vector_store import parse_data_url
async def get_dataframe_from_uri(uri: str):
import pandas
df = None
if uri.endswith(".csv"):
# Moving to its own thread to avoid I/O from blocking the event loop.
# This isn't ideal as it moves more than just the I/O to a new thread,
# but it is as close as we can easily get.
df = await asyncio.to_thread(pandas.read_csv, uri)
elif uri.endswith(".xlsx"):
df = await asyncio.to_thread(pandas.read_excel, uri)
elif uri.startswith("data:"):
parts = parse_data_url(uri)
data = parts["data"]
if parts["is_base64"]:
data = base64.b64decode(data)
else:
data = unquote(data)
encoding = parts["encoding"] or "utf-8"
data = data.encode(encoding)
mime_type = parts["mimetype"]
mime_category = mime_type.split("/")[0]
data_bytes = io.BytesIO(data)
if mime_category == "text":
df = pandas.read_csv(data_bytes)
else:
df = pandas.read_excel(data_bytes)
else:
raise ValueError(f"Unsupported file type: {uri}")
return df
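
A short sketch of calling this coroutine from synchronous code (the CSV URL is illustrative):

    import asyncio

    async def main() -> None:
        # Accepts .csv / .xlsx paths or URLs, and base64 or percent-encoded data: URIs
        df = await get_dataframe_from_uri("https://example.com/eval_rows.csv")
        print(df.head())

    asyncio.run(main())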


@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@@ -1,69 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from fastapi import Request
from pydantic import BaseModel, ValidationError
from llama_stack.apis.files import ExpiresAfter
async def parse_pydantic_from_form[T: BaseModel](request: Request, field_name: str, model_class: type[T]) -> T | None:
"""
Generic parser to extract a Pydantic model from multipart form data.
Handles both bracket notation (field[attr1], field[attr2]) and JSON string format.
Args:
request: The FastAPI request object
field_name: The name of the field in the form data (e.g., "expires_after")
model_class: The Pydantic model class to parse into
Returns:
An instance of model_class if parsing succeeds, None otherwise
Example:
expires_after = await parse_pydantic_from_form(
request, "expires_after", ExpiresAfter
)
"""
form = await request.form()
# Check for bracket notation first (e.g., expires_after[anchor], expires_after[seconds])
bracket_data = {}
prefix = f"{field_name}["
for key in form.keys():
if key.startswith(prefix) and key.endswith("]"):
# Extract the attribute name from field_name[attr]
attr = key[len(prefix) : -1]
bracket_data[attr] = form[key]
if bracket_data:
try:
return model_class(**bracket_data)
except (ValidationError, TypeError):
pass
# Check for JSON string format
if field_name in form:
value = form[field_name]
if isinstance(value, str):
try:
data = json.loads(value)
return model_class(**data)
except (json.JSONDecodeError, TypeError, ValidationError):
pass
return None
async def parse_expires_after(request: Request) -> ExpiresAfter | None:
"""
Dependency to parse expires_after from multipart form data.
Handles both bracket notation (expires_after[anchor], expires_after[seconds])
and JSON string format.
"""
return await parse_pydantic_from_form(request, "expires_after", ExpiresAfter)
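
A minimal sketch of wiring the dependency into a FastAPI route (the app and route path are illustrative, not part of this file):

    from fastapi import Depends, FastAPI, UploadFile

    app = FastAPI()

    @app.post("/v1/files")
    async def upload_file(
        file: UploadFile,
        expires_after: ExpiresAfter | None = Depends(parse_expires_after),
    ):
        # expires_after comes from expires_after[anchor]/expires_after[seconds] form
        # fields or a JSON-encoded "expires_after" field; None if absent or invalid.
        return {"filename": file.filename, "expires_after": expires_after}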


@@ -1,34 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.models.llama.sku_types import * # noqa: F403
def is_supported_safety_model(model: Model) -> bool:
if model.quantization_format != CheckpointQuantizationFormat.bf16:
return False
model_id = model.core_model_id
return model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_3_1b,
CoreModelId.llama_guard_3_11b_vision,
]
def supported_inference_models() -> list[Model]:
return [
m
for m in all_registered_models()
if (
m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2, ModelFamily.llama3_3, ModelFamily.llama4}
or is_supported_safety_model(m)
)
]
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR = {m.huggingface_repo: m.descriptor() for m in all_registered_models()}


@@ -1,102 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import base64
import platform
import struct
from typing import TYPE_CHECKING
import torch
from llama_stack.log import get_logger
if TYPE_CHECKING:
from sentence_transformers import SentenceTransformer
from llama_stack.apis.inference import (
ModelStore,
OpenAIEmbeddingData,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
)
EMBEDDING_MODELS = {}
DARWIN = "Darwin"
log = get_logger(name=__name__, category="providers::utils")
class SentenceTransformerEmbeddingMixin:
model_store: ModelStore
async def openai_embeddings(
self,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
# Convert input to list format if it's a single string
input_list = [params.input] if isinstance(params.input, str) else params.input
if not input_list:
raise ValueError("Empty list not supported")
# Get the model and generate embeddings
model_obj = await self.model_store.get_model(params.model)
embedding_model = await self._load_sentence_transformer_model(model_obj.provider_resource_id)
embeddings = await asyncio.to_thread(embedding_model.encode, input_list, show_progress_bar=False)
# Convert embeddings to the requested format
data = []
for i, embedding in enumerate(embeddings):
if params.encoding_format == "base64":
# Convert float array to base64 string
float_bytes = struct.pack(f"{len(embedding)}f", *embedding)
embedding_value = base64.b64encode(float_bytes).decode("ascii")
else:
# Default to float format
embedding_value = embedding.tolist()
data.append(
OpenAIEmbeddingData(
embedding=embedding_value,
index=i,
)
)
# Not returning actual token usage
usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
return OpenAIEmbeddingsResponse(
data=data,
model=params.model,
usage=usage,
)
async def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
global EMBEDDING_MODELS
loaded_model = EMBEDDING_MODELS.get(model)
if loaded_model is not None:
return loaded_model
log.info(f"Loading sentence transformer for {model}...")
def _load_model():
from sentence_transformers import SentenceTransformer
platform_name = platform.system()
if platform_name == DARWIN:
# PyTorch's OpenMP kernels can segfault on macOS when spawned from background
# threads with the default parallel settings, so force a single-threaded CPU run.
log.debug(f"Constraining torch threads on {platform_name} to a single worker")
torch.set_num_threads(1)
return SentenceTransformer(model, trust_remote_code=True)
loaded_model = await asyncio.to_thread(_load_model)
EMBEDDING_MODELS[model] = loaded_model
return loaded_model


@@ -1,244 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import Any
from sqlalchemy.exc import IntegrityError
from llama_stack.apis.inference import (
ListOpenAIChatCompletionResponse,
OpenAIChatCompletion,
OpenAICompletionWithInputMessages,
OpenAIMessageParam,
Order,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.storage.datatypes import InferenceStoreReference, StorageBackendType
from llama_stack.log import get_logger
from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
from ..sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl
logger = get_logger(name=__name__, category="inference")
class InferenceStore:
def __init__(
self,
reference: InferenceStoreReference,
policy: list[AccessRule],
):
self.reference = reference
self.sql_store = None
self.policy = policy
# Async write queue and worker control
self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
self._worker_tasks: list[asyncio.Task[Any]] = []
self._max_write_queue_size: int = reference.max_write_queue_size
self._num_writers: int = max(1, reference.num_writers)
async def initialize(self):
"""Create the necessary tables if they don't exist."""
base_store = sqlstore_impl(self.reference)
self.sql_store = AuthorizedSqlStore(base_store, self.policy)
# Disable write queue for SQLite to avoid concurrency issues
backend_name = self.reference.backend
backend_config = _SQLSTORE_BACKENDS.get(backend_name)
if backend_config is None:
raise ValueError(
f"Unregistered SQL backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
)
self.enable_write_queue = backend_config.type != StorageBackendType.SQL_SQLITE
await self.sql_store.create_table(
"chat_completions",
{
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
"created": ColumnType.INTEGER,
"model": ColumnType.STRING,
"choices": ColumnType.JSON,
"input_messages": ColumnType.JSON,
},
)
if self.enable_write_queue:
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
for _ in range(self._num_writers):
self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
else:
logger.info("Write queue disabled for SQLite to avoid concurrency issues")
async def shutdown(self) -> None:
if not self._worker_tasks:
return
if self._queue is not None:
await self._queue.join()
for t in self._worker_tasks:
if not t.done():
t.cancel()
for t in self._worker_tasks:
try:
await t
except asyncio.CancelledError:
pass
self._worker_tasks.clear()
async def flush(self) -> None:
"""Wait for all queued writes to complete. Useful for testing."""
if self.enable_write_queue and self._queue is not None:
await self._queue.join()
async def store_chat_completion(
self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
) -> None:
if self.enable_write_queue:
if self._queue is None:
raise ValueError("Inference store is not initialized")
try:
self._queue.put_nowait((chat_completion, input_messages))
except asyncio.QueueFull:
logger.warning(
f"Write queue full; adding chat completion id={getattr(chat_completion, 'id', '<unknown>')}"
)
await self._queue.put((chat_completion, input_messages))
else:
await self._write_chat_completion(chat_completion, input_messages)
async def _worker_loop(self) -> None:
assert self._queue is not None
while True:
try:
item = await self._queue.get()
except asyncio.CancelledError:
break
chat_completion, input_messages = item
try:
await self._write_chat_completion(chat_completion, input_messages)
except Exception as e: # noqa: BLE001
logger.error(f"Error writing chat completion: {e}")
finally:
self._queue.task_done()
async def _write_chat_completion(
self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
) -> None:
if self.sql_store is None:
raise ValueError("Inference store is not initialized")
data = chat_completion.model_dump()
record_data = {
"id": data["id"],
"created": data["created"],
"model": data["model"],
"choices": data["choices"],
"input_messages": [message.model_dump() for message in input_messages],
}
try:
await self.sql_store.insert(
table="chat_completions",
data=record_data,
)
except IntegrityError as e:
# Duplicate chat completion IDs can be generated during tests especially if they are replaying
# recorded responses across different tests. No need to warn or error under those circumstances.
# In the wild, this is not likely to happen at all (no evidence) so we aren't really hiding any problem.
# Check if it's a unique constraint violation
error_message = str(e.orig) if e.orig else str(e)
if self._is_unique_constraint_error(error_message):
# Update the existing record instead
await self.sql_store.update(table="chat_completions", data=record_data, where={"id": data["id"]})
else:
# Re-raise if it's not a unique constraint error
raise
def _is_unique_constraint_error(self, error_message: str) -> bool:
"""Check if the error is specifically a unique constraint violation."""
error_lower = error_message.lower()
return any(
indicator in error_lower
for indicator in [
"unique constraint failed", # SQLite
"duplicate key", # PostgreSQL
"unique violation", # PostgreSQL alternative
"duplicate entry", # MySQL
]
)
async def list_chat_completions(
self,
after: str | None = None,
limit: int | None = 50,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIChatCompletionResponse:
"""
List chat completions from the database.
:param after: The ID of the last chat completion to return.
:param limit: The maximum number of chat completions to return.
:param model: The model to filter by.
:param order: The order to sort the chat completions by.
"""
if not self.sql_store:
raise ValueError("Inference store is not initialized")
if not order:
order = Order.desc
where_conditions = {}
if model:
where_conditions["model"] = model
paginated_result = await self.sql_store.fetch_all(
table="chat_completions",
where=where_conditions if where_conditions else None,
order_by=[("created", order.value)],
cursor=("id", after) if after else None,
limit=limit,
)
data = [
OpenAICompletionWithInputMessages(
id=row["id"],
created=row["created"],
model=row["model"],
choices=row["choices"],
input_messages=row["input_messages"],
)
for row in paginated_result.data
]
return ListOpenAIChatCompletionResponse(
data=data,
has_more=paginated_result.has_more,
first_id=data[0].id if data else "",
last_id=data[-1].id if data else "",
)
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
if not self.sql_store:
raise ValueError("Inference store is not initialized")
row = await self.sql_store.fetch_one(
table="chat_completions",
where={"id": completion_id},
)
if not row:
# SecureSqlStore will return None if record doesn't exist OR access is denied
# This provides security by not revealing whether the record exists
raise ValueError(f"Chat completion with id {completion_id} not found") from None
return OpenAICompletionWithInputMessages(
id=row["id"],
created=row["created"],
model=row["model"],
choices=row["choices"],
input_messages=row["input_messages"],
)


@@ -1,336 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import struct
from collections.abc import AsyncIterator
import litellm
from llama_stack.apis.inference import (
ChatCompletionRequest,
InferenceProvider,
JsonSchemaResponseFormat,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletion,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingData,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
ToolChoice,
)
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict_new,
convert_tooldef_to_openai_tool,
get_sampling_options,
prepare_openai_completion_params,
)
logger = get_logger(name=__name__, category="providers::utils")
class LiteLLMOpenAIMixin(
ModelRegistryHelper,
InferenceProvider,
NeedsRequestProviderData,
):
# TODO: avoid exposing the litellm specific model names to the user.
# potential change: add a prefix param that gets added to the model name
# when calling litellm.
def __init__(
self,
litellm_provider_name: str,
api_key_from_config: str | None,
provider_data_api_key_field: str | None = None,
model_entries: list[ProviderModelEntry] | None = None,
openai_compat_api_base: str | None = None,
download_images: bool = False,
json_schema_strict: bool = True,
):
"""
Initialize the LiteLLMOpenAIMixin.
:param model_entries: The model entries to register.
:param api_key_from_config: The API key to use from the config.
:param provider_data_api_key_field: The field in the provider data that contains the API key (optional).
:param litellm_provider_name: The name of the provider, used for model lookups.
:param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility.
:param download_images: Whether to download images and convert to base64 for message conversion.
:param json_schema_strict: Whether to use strict mode for JSON schema validation.
"""
ModelRegistryHelper.__init__(self, model_entries=model_entries)
self.litellm_provider_name = litellm_provider_name
self.api_key_from_config = api_key_from_config
self.provider_data_api_key_field = provider_data_api_key_field
self.api_base = openai_compat_api_base
self.download_images = download_images
self.json_schema_strict = json_schema_strict
if openai_compat_api_base:
self.is_openai_compat = True
else:
self.is_openai_compat = False
async def initialize(self):
pass
async def shutdown(self):
pass
def get_litellm_model_name(self, model_id: str) -> str:
# users may be using openai/ prefix in their model names. the openai/models.py did this by default.
# model_id.startswith("openai/") is for backwards compatibility.
return (
f"{self.litellm_provider_name}/{model_id}"
if self.is_openai_compat and not model_id.startswith(self.litellm_provider_name)
else model_id
)
def _add_additional_properties_recursive(self, schema):
"""
Recursively add additionalProperties: False to all object schemas
"""
if isinstance(schema, dict):
if schema.get("type") == "object":
schema["additionalProperties"] = False
# Add required field with all property keys if properties exist
if "properties" in schema and schema["properties"]:
schema["required"] = list(schema["properties"].keys())
if "properties" in schema:
for prop_schema in schema["properties"].values():
self._add_additional_properties_recursive(prop_schema)
for key in ["anyOf", "allOf", "oneOf"]:
if key in schema:
for sub_schema in schema[key]:
self._add_additional_properties_recursive(sub_schema)
if "not" in schema:
self._add_additional_properties_recursive(schema["not"])
# Handle $defs/$ref
if "$defs" in schema:
for def_schema in schema["$defs"].values():
self._add_additional_properties_recursive(def_schema)
return schema
async def _get_params(self, request: ChatCompletionRequest) -> dict:
input_dict = {}
input_dict["messages"] = [
await convert_message_to_openai_dict_new(m, download_images=self.download_images) for m in request.messages
]
if fmt := request.response_format:
if not isinstance(fmt, JsonSchemaResponseFormat):
raise ValueError(
f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
)
fmt = fmt.json_schema
name = fmt["title"]
del fmt["title"]
fmt["additionalProperties"] = False
# Apply additionalProperties: False recursively to all objects
fmt = self._add_additional_properties_recursive(fmt)
input_dict["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": name,
"schema": fmt,
"strict": self.json_schema_strict,
},
}
if request.tools:
input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
if request.tool_config.tool_choice:
input_dict["tool_choice"] = (
request.tool_config.tool_choice.value
if isinstance(request.tool_config.tool_choice, ToolChoice)
else request.tool_config.tool_choice
)
return {
"model": request.model,
"api_key": self.get_api_key(),
"api_base": self.api_base,
**input_dict,
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
def get_api_key(self) -> str:
provider_data = self.get_request_provider_data()
key_field = self.provider_data_api_key_field
if provider_data and getattr(provider_data, key_field, None):
api_key = getattr(provider_data, key_field)
else:
api_key = self.api_key_from_config
if not api_key:
raise ValueError(
"API key is not set. Please provide a valid API key in the "
"provider data header, e.g. x-llamastack-provider-data: "
f'{{"{key_field}": "<API_KEY>"}}, or in the provider config.'
)
return api_key
async def openai_embeddings(
self,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
model_obj = await self.model_store.get_model(params.model)
# Convert input to list if it's a string
input_list = [params.input] if isinstance(params.input, str) else params.input
# Call litellm embedding function
# litellm.drop_params = True
response = litellm.embedding(
model=self.get_litellm_model_name(model_obj.provider_resource_id),
input=input_list,
api_key=self.get_api_key(),
api_base=self.api_base,
dimensions=params.dimensions,
)
# Convert response to OpenAI format
data = b64_encode_openai_embeddings_response(response.data, params.encoding_format)
usage = OpenAIEmbeddingUsage(
prompt_tokens=response["usage"]["prompt_tokens"],
total_tokens=response["usage"]["total_tokens"],
)
return OpenAIEmbeddingsResponse(
data=data,
model=model_obj.provider_resource_id,
usage=usage,
)
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(params.model)
request_params = await prepare_openai_completion_params(
model=self.get_litellm_model_name(model_obj.provider_resource_id),
prompt=params.prompt,
best_of=params.best_of,
echo=params.echo,
frequency_penalty=params.frequency_penalty,
logit_bias=params.logit_bias,
logprobs=params.logprobs,
max_tokens=params.max_tokens,
n=params.n,
presence_penalty=params.presence_penalty,
seed=params.seed,
stop=params.stop,
stream=params.stream,
stream_options=params.stream_options,
temperature=params.temperature,
top_p=params.top_p,
user=params.user,
suffix=params.suffix,
api_key=self.get_api_key(),
api_base=self.api_base,
)
return await litellm.atext_completion(**request_params)
async def openai_chat_completion(
self,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
# Add usage tracking for streaming when telemetry is active
from llama_stack.core.telemetry.tracing import get_current_span
stream_options = params.stream_options
if params.stream and get_current_span() is not None:
if stream_options is None:
stream_options = {"include_usage": True}
elif "include_usage" not in stream_options:
stream_options = {**stream_options, "include_usage": True}
model_obj = await self.model_store.get_model(params.model)
request_params = await prepare_openai_completion_params(
model=self.get_litellm_model_name(model_obj.provider_resource_id),
messages=params.messages,
frequency_penalty=params.frequency_penalty,
function_call=params.function_call,
functions=params.functions,
logit_bias=params.logit_bias,
logprobs=params.logprobs,
max_completion_tokens=params.max_completion_tokens,
max_tokens=params.max_tokens,
n=params.n,
parallel_tool_calls=params.parallel_tool_calls,
presence_penalty=params.presence_penalty,
response_format=params.response_format,
seed=params.seed,
stop=params.stop,
stream=params.stream,
stream_options=stream_options,
temperature=params.temperature,
tool_choice=params.tool_choice,
tools=params.tools,
top_logprobs=params.top_logprobs,
top_p=params.top_p,
user=params.user,
api_key=self.get_api_key(),
api_base=self.api_base,
)
return await litellm.acompletion(**request_params)
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available via LiteLLM for the current
provider (self.litellm_provider_name).
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
if self.litellm_provider_name not in litellm.models_by_provider:
logger.error(f"Provider {self.litellm_provider_name} is not registered in litellm.")
return False
return model in litellm.models_by_provider[self.litellm_provider_name]
def b64_encode_openai_embeddings_response(
response_data: list[dict], encoding_format: str | None = "float"
) -> list[OpenAIEmbeddingData]:
"""
Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
"""
data = []
for i, embedding_data in enumerate(response_data):
if encoding_format == "base64":
byte_array = bytearray()
for embedding_value in embedding_data["embedding"]:
byte_array.extend(struct.pack("f", float(embedding_value)))
response_embedding = base64.b64encode(byte_array).decode("utf-8")
else:
response_embedding = embedding_data["embedding"]
data.append(
OpenAIEmbeddingData(
embedding=response_embedding,
index=i,
)
)
return data
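
A small sketch of the encoding helper's behavior on hand-built response data (values are illustrative):

    sample = [{"embedding": [0.1, 0.2, 0.3]}, {"embedding": [0.4, 0.5, 0.6]}]

    # "float" (the default) passes vectors through unchanged
    as_float = b64_encode_openai_embeddings_response(sample)

    # "base64" packs each vector as float32 bytes and base64-encodes the result
    as_b64 = b64_encode_openai_embeddings_response(sample, encoding_format="base64")
    print(as_float[0].embedding, as_b64[0].embedding)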


@@ -1,191 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from llama_stack.apis.common.errors import UnsupportedModelError
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference import (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
)
logger = get_logger(name=__name__, category="providers::utils")
class RemoteInferenceProviderConfig(BaseModel):
allowed_models: list[str] | None = Field( # TODO: make this non-optional and give a list() default
default=None,
description="List of models that should be registered with the model registry. If None, all models are allowed.",
)
refresh_models: bool = Field(
default=False,
description="Whether to refresh models periodically from the provider",
)
auth_credential: SecretStr | None = Field(
default=None,
description="Authentication credential for the provider",
alias="api_key",
)
# TODO: this class is more confusing than useful right now. We need to make it
# closer to the Model class.
class ProviderModelEntry(BaseModel):
provider_model_id: str
aliases: list[str] = Field(default_factory=list)
llama_model: str | None = None
model_type: ModelType = ModelType.llm
metadata: dict[str, Any] = Field(default_factory=dict)
def build_hf_repo_model_entry(
provider_model_id: str,
model_descriptor: str,
additional_aliases: list[str] | None = None,
) -> ProviderModelEntry:
aliases = [
# NOTE: avoid HF aliases because they _cannot_ be unique across providers
# get_huggingface_repo(model_descriptor),
]
if additional_aliases:
aliases.extend(additional_aliases)
aliases = [alias for alias in aliases if alias is not None]
return ProviderModelEntry(
provider_model_id=provider_model_id,
aliases=aliases,
llama_model=model_descriptor,
)
class ModelRegistryHelper(ModelsProtocolPrivate):
__provider_id__: str
def __init__(
self,
model_entries: list[ProviderModelEntry] | None = None,
allowed_models: list[str] | None = None,
):
self.allowed_models = allowed_models if allowed_models else []
self.alias_to_provider_id_map = {}
self.provider_id_to_llama_model_map = {}
self.model_entries = model_entries or []
for entry in self.model_entries:
for alias in entry.aliases:
self.alias_to_provider_id_map[alias] = entry.provider_model_id
# also add a mapping from provider model id to itself for easy lookup
self.alias_to_provider_id_map[entry.provider_model_id] = entry.provider_model_id
if entry.llama_model:
self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model
async def list_models(self) -> list[Model] | None:
models = []
for entry in self.model_entries:
ids = [entry.provider_model_id] + entry.aliases
for id in ids:
if self.allowed_models and id not in self.allowed_models:
continue
models.append(
Model(
identifier=id,
provider_resource_id=entry.provider_model_id,
model_type=entry.model_type,
metadata=entry.metadata,
provider_id=self.__provider_id__,
)
)
return models
async def should_refresh_models(self) -> bool:
return False
def get_provider_model_id(self, identifier: str) -> str | None:
return self.alias_to_provider_id_map.get(identifier, None)
# TODO: why keep a separate llama model mapping?
def get_llama_model(self, provider_model_id: str) -> str | None:
return self.provider_id_to_llama_model_map.get(provider_model_id, None)
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available from the provider (non-static check).
This is for subclassing purposes, so providers can check if a specific
model is currently available for use through dynamic means (e.g., API calls).
This method should NOT check statically configured model entries in
`self.alias_to_provider_id_map` - that is handled separately in register_model.
Default implementation returns False (no dynamic models available).
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
logger.info(
f"check_model_availability is not implemented for {self.__class__.__name__}. Returning False by default."
)
return False
async def register_model(self, model: Model) -> Model:
# Check if model is supported in static configuration
supported_model_id = self.get_provider_model_id(model.provider_resource_id)
# If not found in static config, check if it's available dynamically from provider
if not supported_model_id:
if await self.check_model_availability(model.provider_resource_id):
supported_model_id = model.provider_resource_id
else:
# note: we cannot provide a complete list of supported models without
# getting a complete list from the provider, so we return "..."
all_supported_models = [*self.alias_to_provider_id_map.keys(), "..."]
raise UnsupportedModelError(model.provider_resource_id, all_supported_models)
provider_resource_id = self.get_provider_model_id(model.model_id)
if model.model_type == ModelType.embedding:
# embedding models are always registered by their provider model id and do not need to be mapped to a llama model
provider_resource_id = model.provider_resource_id
if provider_resource_id:
if provider_resource_id != supported_model_id: # be idempotent, only reject differences
raise ValueError(
f"Model id '{model.model_id}' is already registered. Please use a different id or unregister it first."
)
else:
llama_model = model.metadata.get("llama_model")
if llama_model:
existing_llama_model = self.get_llama_model(model.provider_resource_id)
if existing_llama_model:
if existing_llama_model != llama_model:
raise ValueError(
f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
)
else:
if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
raise ValueError(
f"Invalid llama_model '{llama_model}' specified in metadata. "
f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
)
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
)
# Register the model alias, ensuring it maps to the correct provider model id
self.alias_to_provider_id_map[model.model_id] = supported_model_id
return model
async def unregister_model(self, model_id: str) -> None:
# model_id is the identifier, not the provider_resource_id
# unfortunately, this ID can be of the form provider_id/model_id which
# we never registered. TODO: fix this by significantly rewriting
# registration and registry helper
pass

File diff suppressed because it is too large


@@ -1,494 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import uuid
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterable
from typing import Any
from openai import NOT_GIVEN, AsyncOpenAI
from pydantic import BaseModel, ConfigDict
from llama_stack.apis.inference import (
Model,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletion,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingData,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
)
from llama_stack.apis.models import ModelType
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content
logger = get_logger(name=__name__, category="providers::utils")
class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
"""
Mixin class that provides OpenAI-specific functionality for inference providers.
This class handles direct OpenAI API calls using the AsyncOpenAI client.
This is an abstract base class that requires child classes to implement:
- get_base_url(): Method to retrieve the OpenAI-compatible API base URL
The behavior of this class can be customized by child classes in the following ways:
- overwrite_completion_id: If True, overwrites the 'id' field in OpenAI responses
- download_images: If True, downloads images and converts to base64 for providers that require it
- embedding_model_metadata: A dictionary mapping model IDs to their embedding metadata
- construct_model_from_identifier: Method to construct a Model instance corresponding to the given identifier
- provider_data_api_key_field: Optional field name in provider data to look for API key
- list_provider_model_ids: Method to list available models from the provider
- get_extra_client_params: Method to provide extra parameters to the AsyncOpenAI client
Expected Dependencies:
- self.model_store: Injected by the Llama Stack distribution system at runtime.
This provides model registry functionality for looking up registered models.
The model_store is set in routing_tables/common.py during provider initialization.
"""
# Allow extra fields so the routing infra can inject model_store, __provider_id__, etc.
model_config = ConfigDict(extra="allow")
config: RemoteInferenceProviderConfig
# Allow subclasses to control whether the 'id' field in OpenAI responses
# is overwritten with a client-side generated id.
#
# This is useful for providers that do not return a unique id in the response.
overwrite_completion_id: bool = False
# Allow subclasses to control whether to download images and convert to base64
# for providers that require base64 encoded images instead of URLs.
download_images: bool = False
# Embedding model metadata for this provider
# Can be set by subclasses or instances to provide embedding models
# Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}
embedding_model_metadata: dict[str, dict[str, int]] = {}
# Cache of available models keyed by model ID
# This is set in list_models() and used in check_model_availability()
_model_cache: dict[str, Model] = {}
# List of allowed models for this provider, if empty all models allowed
allowed_models: list[str] = []
# Optional field name in provider data to look for API key, which takes precedence
provider_data_api_key_field: str | None = None
def get_api_key(self) -> str | None:
"""
Get the API key.
:return: The API key as a string, or None if not set
"""
if self.config.auth_credential is None:
return None
return self.config.auth_credential.get_secret_value()
@abstractmethod
def get_base_url(self) -> str:
"""
Get the OpenAI-compatible API base URL.
This method must be implemented by child classes to provide the base URL
for the OpenAI API or compatible endpoints (e.g., "https://api.openai.com/v1").
:return: The base URL as a string
"""
pass
def get_extra_client_params(self) -> dict[str, Any]:
"""
Get any extra parameters to pass to the AsyncOpenAI client.
Child classes can override this method to provide additional parameters
such as timeout settings, proxies, etc.
:return: A dictionary of extra parameters
"""
return {}
def construct_model_from_identifier(self, identifier: str) -> Model:
"""
Construct a Model instance corresponding to the given identifier
Child classes can override this to customize model typing/metadata.
:param identifier: The provider's model identifier
:return: A Model instance
"""
if metadata := self.embedding_model_metadata.get(identifier):
return Model(
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=identifier,
identifier=identifier,
model_type=ModelType.embedding,
metadata=metadata,
)
return Model(
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=identifier,
identifier=identifier,
model_type=ModelType.llm,
)
async def list_provider_model_ids(self) -> Iterable[str]:
"""
List available models from the provider.
Child classes can override this method to provide a custom implementation
for listing models. The default implementation uses the AsyncOpenAI client
to list models from the OpenAI-compatible endpoint.
:return: An iterable of model IDs or None if not implemented
"""
client = self.client
async with client:
model_ids = [m.id async for m in client.models.list()]
return model_ids
async def initialize(self) -> None:
"""
Initialize the OpenAI mixin.
This method provides a default implementation that does nothing.
Subclasses can override this method to perform initialization tasks
such as setting up clients, validating configurations, etc.
"""
pass
async def shutdown(self) -> None:
"""
Shutdown the OpenAI mixin.
This method provides a default implementation that does nothing.
Subclasses can override this method to perform cleanup tasks
such as closing connections, releasing resources, etc.
"""
pass
@property
def client(self) -> AsyncOpenAI:
"""
Get an AsyncOpenAI client instance.
Uses the abstract methods get_api_key() and get_base_url() which must be
implemented by child classes.
Users can also provide the API key via the provider data header, which
is used instead of any config API key.
"""
api_key = self._get_api_key_from_config_or_provider_data()
if not api_key:
message = "API key not provided."
if self.provider_data_api_key_field:
message += f' Please provide a valid API key in the provider data header, e.g. x-llamastack-provider-data: {{"{self.provider_data_api_key_field}": "<API_KEY>"}}.'
raise ValueError(message)
return AsyncOpenAI(
api_key=api_key,
base_url=self.get_base_url(),
**self.get_extra_client_params(),
)
def _get_api_key_from_config_or_provider_data(self) -> str | None:
api_key = self.get_api_key()
if self.provider_data_api_key_field:
provider_data = self.get_request_provider_data()
if provider_data and getattr(provider_data, self.provider_data_api_key_field, None):
api_key = getattr(provider_data, self.provider_data_api_key_field)
return api_key
async def _get_provider_model_id(self, model: str) -> str:
"""
Get the provider-specific model ID from the model store.
This is a utility method that looks up the registered model and returns
the provider_resource_id that should be used for actual API calls.
:param model: The registered model name/identifier
:return: The provider-specific model ID (e.g., "gpt-4")
"""
# Look up the registered model to get the provider-specific model ID
# self.model_store is injected by the distribution system at runtime
model_obj: Model = await self.model_store.get_model(model) # type: ignore[attr-defined]
# provider_resource_id is str | None, but we expect it to be str for OpenAI calls
if model_obj.provider_resource_id is None:
raise ValueError(f"Model {model} has no provider_resource_id")
return model_obj.provider_resource_id
async def _maybe_overwrite_id(self, resp: Any, stream: bool | None) -> Any:
if not self.overwrite_completion_id:
return resp
new_id = f"cltsd-{uuid.uuid4()}"
if stream:
async def _gen():
async for chunk in resp:
chunk.id = new_id
yield chunk
return _gen()
else:
resp.id = new_id
return resp
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
"""
Direct OpenAI completion API call.
"""
# TODO: fix openai_completion to return type compatible with OpenAI's API response
completion_kwargs = await prepare_openai_completion_params(
model=await self._get_provider_model_id(params.model),
prompt=params.prompt,
best_of=params.best_of,
echo=params.echo,
frequency_penalty=params.frequency_penalty,
logit_bias=params.logit_bias,
logprobs=params.logprobs,
max_tokens=params.max_tokens,
n=params.n,
presence_penalty=params.presence_penalty,
seed=params.seed,
stop=params.stop,
stream=params.stream,
stream_options=params.stream_options,
temperature=params.temperature,
top_p=params.top_p,
user=params.user,
suffix=params.suffix,
)
if extra_body := params.model_extra:
completion_kwargs["extra_body"] = extra_body
resp = await self.client.completions.create(**completion_kwargs)
return await self._maybe_overwrite_id(resp, params.stream) # type: ignore[no-any-return]
async def openai_chat_completion(
self,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""
Direct OpenAI chat completion API call.
"""
messages = params.messages
if self.download_images:
async def _localize_image_url(m: OpenAIMessageParam) -> OpenAIMessageParam:
if isinstance(m.content, list):
for c in m.content:
if c.type == "image_url" and c.image_url and c.image_url.url and "http" in c.image_url.url:
localize_result = await localize_image_content(c.image_url.url)
if localize_result is None:
raise ValueError(
f"Failed to localize image content from {c.image_url.url[:42]}{'...' if len(c.image_url.url) > 42 else ''}"
)
content, format = localize_result
c.image_url.url = f"data:image/{format};base64,{base64.b64encode(content).decode('utf-8')}"
# else it's a string and we don't need to modify it
return m
messages = [await _localize_image_url(m) for m in messages]
request_params = await prepare_openai_completion_params(
model=await self._get_provider_model_id(params.model),
messages=messages,
frequency_penalty=params.frequency_penalty,
function_call=params.function_call,
functions=params.functions,
logit_bias=params.logit_bias,
logprobs=params.logprobs,
max_completion_tokens=params.max_completion_tokens,
max_tokens=params.max_tokens,
n=params.n,
parallel_tool_calls=params.parallel_tool_calls,
presence_penalty=params.presence_penalty,
response_format=params.response_format,
seed=params.seed,
stop=params.stop,
stream=params.stream,
stream_options=params.stream_options,
temperature=params.temperature,
tool_choice=params.tool_choice,
tools=params.tools,
top_logprobs=params.top_logprobs,
top_p=params.top_p,
user=params.user,
)
if extra_body := params.model_extra:
request_params["extra_body"] = extra_body
resp = await self.client.chat.completions.create(**request_params)
return await self._maybe_overwrite_id(resp, params.stream) # type: ignore[no-any-return]
async def openai_embeddings(
self,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
"""
Direct OpenAI embeddings API call.
"""
# Prepare request parameters
request_params = {
"model": await self._get_provider_model_id(params.model),
"input": params.input,
"encoding_format": params.encoding_format if params.encoding_format is not None else NOT_GIVEN,
"dimensions": params.dimensions if params.dimensions is not None else NOT_GIVEN,
"user": params.user if params.user is not None else NOT_GIVEN,
}
# Add extra_body if present
extra_body = params.model_extra
if extra_body:
request_params["extra_body"] = extra_body
# Call OpenAI embeddings API with properly typed parameters
response = await self.client.embeddings.create(**request_params)
data = []
for i, embedding_data in enumerate(response.data):
data.append(
OpenAIEmbeddingData(
embedding=embedding_data.embedding,
index=i,
)
)
usage = OpenAIEmbeddingUsage(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
)
return OpenAIEmbeddingsResponse(
data=data,
model=params.model,
usage=usage,
)
###
# ModelsProtocolPrivate implementation - provide model management functionality
#
# async def register_model(self, model: Model) -> Model: ...
# async def unregister_model(self, model_id: str) -> None: ...
#
# async def list_models(self) -> list[Model] | None: ...
# async def should_refresh_models(self) -> bool: ...
##
async def register_model(self, model: Model) -> Model:
if not await self.check_model_availability(model.provider_model_id):
raise ValueError(f"Model {model.provider_model_id} is not available from provider {self.__provider_id__}") # type: ignore[attr-defined]
return model
async def unregister_model(self, model_id: str) -> None:
return None
async def list_models(self) -> list[Model] | None:
"""
List available models from the provider's /v1/models endpoint, augmented with static embedding model metadata.
Also, caches the models in self._model_cache for use in check_model_availability().
:return: A list of Model instances representing available models.
"""
self._model_cache = {}
api_key = self._get_api_key_from_config_or_provider_data()
if not api_key:
logger.debug(f"{self.__class__.__name__}.list_provider_model_ids() disabled because API key not provided")
return None
try:
iterable = await self.list_provider_model_ids()
except Exception as e:
logger.error(f"{self.__class__.__name__}.list_provider_model_ids() failed with: {e}")
raise
if not hasattr(iterable, "__iter__"):
raise TypeError(
f"Failed to list models: {self.__class__.__name__}.list_provider_model_ids() must return an iterable of "
f"strings, but returned {type(iterable).__name__}"
)
provider_models_ids = list(iterable)
logger.info(f"{self.__class__.__name__}.list_provider_model_ids() returned {len(provider_models_ids)} models")
for provider_model_id in provider_models_ids:
if not isinstance(provider_model_id, str):
raise ValueError(f"Model ID {provider_model_id} from list_provider_model_ids() is not a string")
if self.allowed_models and provider_model_id not in self.allowed_models:
logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
continue
model = self.construct_model_from_identifier(provider_model_id)
self._model_cache[provider_model_id] = model
return list(self._model_cache.values())
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available from the provider's /v1/models endpoint or is pre-registered.
:param model: The model identifier to check.
:return: True if the model is available dynamically or pre-registered, False otherwise.
"""
# First check if the model is pre-registered in the model store
if hasattr(self, "model_store") and self.model_store:
qualified_model = f"{self.__provider_id__}/{model}" # type: ignore[attr-defined]
if await self.model_store.has_model(qualified_model):
return True
# Then check the provider's dynamic model cache
if not self._model_cache:
await self.list_models()
return model in self._model_cache
async def should_refresh_models(self) -> bool:
return self.config.refresh_models
#
# The model_dump implementations are to avoid serializing the extra fields,
# e.g. model_store, which are not pydantic.
#
def _filter_fields(self, **kwargs):
"""Helper to exclude extra fields from serialization."""
# Exclude any extra fields stored in __pydantic_extra__
if hasattr(self, "__pydantic_extra__") and self.__pydantic_extra__:
exclude = kwargs.get("exclude", set())
if not isinstance(exclude, set):
exclude = set(exclude) if exclude else set()
exclude.update(self.__pydantic_extra__.keys())
kwargs["exclude"] = exclude
return kwargs
def model_dump(self, **kwargs):
"""Override to exclude extra fields from serialization."""
kwargs = self._filter_fields(**kwargs)
return super().model_dump(**kwargs)
def model_dump_json(self, **kwargs):
"""Override to exclude extra fields from JSON serialization."""
kwargs = self._filter_fields(**kwargs)
return super().model_dump_json(**kwargs)
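
The `_filter_fields` override above exists to keep non-pydantic extras (such as an injected `model_store`) out of serialized output. A minimal standalone sketch of the same pattern, using a hypothetical `Example` model rather than the mixin itself:

import pydantic
from pydantic import BaseModel, ConfigDict

class Example(BaseModel):
    model_config = ConfigDict(extra="allow")
    name: str

    def model_dump(self, **kwargs):
        # Drop any extra attributes (e.g. runtime handles attached after construction).
        if self.__pydantic_extra__:
            exclude = set(kwargs.get("exclude") or set())
            exclude.update(self.__pydantic_extra__.keys())
            kwargs["exclude"] = exclude
        return super().model_dump(**kwargs)

e = Example(name="demo", runtime_handle="attached later")  # extra field is allowed
assert e.model_dump() == {"name": "demo"}                  # ...but never serialized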

View file

@ -1,495 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import base64
import io
import json
import re
from typing import Any
import httpx
from PIL import Image as PIL_Image
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
Message,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIFile,
ResponseFormat,
ResponseFormatType,
SystemMessage,
SystemMessageBehavior,
ToolChoice,
ToolDefinition,
UserMessage,
)
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
RawContent,
RawContentItem,
RawMediaItem,
RawMessage,
RawTextItem,
Role,
StopReason,
ToolPromptFormat,
)
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.prompt_templates import (
BuiltinToolGenerator,
FunctionTagCustomToolGenerator,
JsonCustomToolGenerator,
PythonListCustomToolGenerator,
SystemDefaultGenerator,
)
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
)
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
from llama_stack.providers.utils.inference import supported_inference_models
log = get_logger(name=__name__, category="providers::utils")
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
messages: list[RawMessage]
class CompletionRequestWithRawContent(CompletionRequest):
content: RawContent
def decode_assistant_message(content: str, stop_reason: StopReason) -> RawMessage:
formatter = ChatFormat(Tokenizer.get_instance())
return formatter.decode_assistant_message_from_content(content, stop_reason)
def interleaved_content_as_str(
content: Any,
sep: str = " ",
) -> str:
if content is None:
return ""
def _process(c) -> str:
if isinstance(c, str):
return c
elif isinstance(c, TextContentItem) or isinstance(c, OpenAIChatCompletionContentPartTextParam):
return c.text
elif isinstance(c, ImageContentItem) or isinstance(c, OpenAIChatCompletionContentPartImageParam):
return "<image>"
elif isinstance(c, OpenAIFile):
return "<file>"
else:
raise ValueError(f"Unsupported content type: {type(c)}")
if isinstance(content, list):
return sep.join(_process(c) for c in content)
else:
return _process(content)
async def convert_request_to_raw(
request: ChatCompletionRequest | CompletionRequest,
) -> ChatCompletionRequestWithRawContent | CompletionRequestWithRawContent:
if isinstance(request, ChatCompletionRequest):
messages = []
for m in request.messages:
content = await interleaved_content_convert_to_raw(m.content)
d = m.model_dump()
d["content"] = content
messages.append(RawMessage(**d))
d = request.model_dump()
d["messages"] = messages
request = ChatCompletionRequestWithRawContent(**d)
else:
d = request.model_dump()
d["content"] = await interleaved_content_convert_to_raw(request.content)
request = CompletionRequestWithRawContent(**d)
return request
async def interleaved_content_convert_to_raw(
content: InterleavedContent,
) -> RawContent:
"""Download content from URLs / files etc. so plain bytes can be sent to the model"""
async def _localize_single(c: str | InterleavedContentItem) -> str | RawContentItem:
if isinstance(c, str):
return RawTextItem(text=c)
elif isinstance(c, TextContentItem):
return RawTextItem(text=c.text)
elif isinstance(c, ImageContentItem):
image = c.image
if image.url:
# Load image bytes from URL
if image.url.uri.startswith("data"):
match = re.match(r"data:image/(\w+);base64,(.+)", image.url.uri)
if not match:
raise ValueError(f"Invalid data URL format, {image.url.uri[:40]}...")
_, image_data = match.groups()
data = base64.b64decode(image_data)
elif image.url.uri.startswith("file://"):
path = image.url.uri[len("file://") :]
with open(path, "rb") as f:
data = f.read() # type: ignore
elif image.url.uri.startswith("http"):
async with httpx.AsyncClient() as client:
response = await client.get(image.url.uri)
data = response.content
else:
raise ValueError("Unsupported URL type")
elif image.data:
# data is a base64 encoded string, decode it to bytes for RawMediaItem
data = base64.b64decode(image.data)
else:
raise ValueError("No data or URL provided")
return RawMediaItem(data=data)
else:
raise ValueError(f"Unsupported content type: {type(c)}")
if isinstance(content, list):
return await asyncio.gather(*(_localize_single(c) for c in content))
else:
return await _localize_single(content)
def content_has_media(content: InterleavedContent):
def _has_media_content(c):
return isinstance(c, ImageContentItem)
if isinstance(content, list):
return any(_has_media_content(c) for c in content)
else:
return _has_media_content(content)
def messages_have_media(messages: list[Message]):
return any(content_has_media(m.content) for m in messages)
def request_has_media(request: ChatCompletionRequest | CompletionRequest):
if isinstance(request, ChatCompletionRequest):
return messages_have_media(request.messages)
else:
return content_has_media(request.content)
async def localize_image_content(uri: str) -> tuple[bytes, str] | None:
if uri.startswith("http"):
async with httpx.AsyncClient() as client:
r = await client.get(uri)
content = r.content
content_type = r.headers.get("content-type")
if content_type:
format = content_type.split("/")[-1]
else:
format = "png"
return content, format
elif uri.startswith("data"):
# data:image/{format};base64,{data}
match = re.match(r"data:image/(\w+);base64,(.+)", uri)
if not match:
raise ValueError(f"Invalid data URL format, {uri[:40]}...")
fmt, image_data = match.groups()
content = base64.b64decode(image_data)
return content, fmt
else:
return None
async def convert_image_content_to_url(
media: ImageContentItem, download: bool = False, include_format: bool = True
) -> str:
image = media.image
if image.url and (not download or image.url.uri.startswith("data")):
return image.url.uri
if image.data:
# data is a base64 encoded string, decode it to bytes first
# TODO(mf): do this more efficiently, decode less
content = base64.b64decode(image.data)
pil_image = PIL_Image.open(io.BytesIO(content))
format = pil_image.format
else:
localize_result = await localize_image_content(image.url.uri)
if localize_result is None:
raise ValueError(f"Failed to localize image content from {image.url.uri}")
content, format = localize_result
if include_format:
return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8")
else:
return base64.b64encode(content).decode("utf-8")
def augment_content_with_response_format_prompt(response_format, content):
if fmt_prompt := response_format_prompt(response_format):
if isinstance(content, list):
return content + [TextContentItem(text=fmt_prompt)]
elif isinstance(content, str):
return [TextContentItem(text=content), TextContentItem(text=fmt_prompt)]
else:
return [content, TextContentItem(text=fmt_prompt)]
return content
async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
messages = chat_completion_request_to_messages(request, llama_model)
request.messages = messages
request = await convert_request_to_raw(request)
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_dialog_prompt(
request.messages,
tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
)
return formatter.tokenizer.decode(model_input.tokens)
async def chat_completion_request_to_model_input_info(
request: ChatCompletionRequest, llama_model: str
) -> tuple[str, int]:
messages = chat_completion_request_to_messages(request, llama_model)
request.messages = messages
request = await convert_request_to_raw(request)
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_dialog_prompt(
request.messages,
tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
)
return (
formatter.tokenizer.decode(model_input.tokens),
len(model_input.tokens),
)
def chat_completion_request_to_messages(
request: ChatCompletionRequest,
llama_model: str,
) -> list[Message]:
"""Reads chat completion request and augments the messages to handle tools.
For example, for llama_3_1, add a system message with the appropriate tools or
add a user message for custom tools, etc.
"""
assert llama_model is not None, "llama_model is required"
model = resolve_model(llama_model)
if model is None:
log.error(f"Could not resolve model {llama_model}")
return request.messages
allowed_models = supported_inference_models()
descriptors = [m.descriptor() for m in allowed_models]
if model.descriptor() not in descriptors:
log.error(f"Unsupported inference model? {model.descriptor()}")
return request.messages
if model.model_family == ModelFamily.llama3_1 or (
model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
):
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
messages = augment_messages_for_tools_llama_3_1(request)
elif model.model_family in (
ModelFamily.llama3_2,
ModelFamily.llama3_3,
):
# llama3.2, llama3.3 follow the same tool prompt format
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator)
elif model.model_family == ModelFamily.llama4:
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4)
else:
messages = request.messages
if fmt_prompt := response_format_prompt(request.response_format):
messages.append(UserMessage(content=fmt_prompt))
return messages
def response_format_prompt(fmt: ResponseFormat | None):
if not fmt:
return None
if fmt.type == ResponseFormatType.json_schema.value:
return f"Please respond in JSON format with the schema: {json.dumps(fmt.json_schema)}"
elif fmt.type == ResponseFormatType.grammar.value:
raise NotImplementedError("Grammar response format not supported yet")
else:
raise ValueError(f"Unknown response format {fmt.type}")
def augment_messages_for_tools_llama_3_1(
request: ChatCompletionRequest,
) -> list[Message]:
existing_messages = request.messages
existing_system_message = None
if existing_messages[0].role == Role.system.value:
existing_system_message = existing_messages.pop(0)
assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
messages = []
default_gen = SystemDefaultGenerator()
default_template = default_gen.gen()
sys_content = ""
tool_template = None
if request.tools:
tool_gen = BuiltinToolGenerator()
tool_template = tool_gen.gen(request.tools)
sys_content += tool_template.render()
sys_content += "\n"
sys_content += default_template.render()
if existing_system_message:
# TODO: this fn is needed in many places
def _process(c):
if isinstance(c, str):
return c
else:
return "<media>"
sys_content += "\n"
if isinstance(existing_system_message.content, str):
sys_content += _process(existing_system_message.content)
elif isinstance(existing_system_message.content, list):
sys_content += "\n".join([_process(c) for c in existing_system_message.content])
tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
if tool_choice_prompt:
sys_content += "\n" + tool_choice_prompt
messages.append(SystemMessage(content=sys_content))
has_custom_tools = request.tools is not None and any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools:
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json
if fmt == ToolPromptFormat.json:
tool_gen = JsonCustomToolGenerator()
elif fmt == ToolPromptFormat.function_tag:
tool_gen = FunctionTagCustomToolGenerator()
else:
raise ValueError(f"Non supported ToolPromptFormat {fmt}")
custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
custom_template = tool_gen.gen(custom_tools)
messages.append(UserMessage(content=custom_template.render()))
# Add back existing messages from the request
messages += existing_messages
return messages
def augment_messages_for_tools_llama(
request: ChatCompletionRequest,
custom_tool_prompt_generator,
) -> list[Message]:
existing_messages = request.messages
existing_system_message = None
if existing_messages[0].role == Role.system.value:
existing_system_message = existing_messages.pop(0)
assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
sys_content = ""
custom_tools, builtin_tools = [], []
for t in request.tools:
if isinstance(t.tool_name, str):
custom_tools.append(t)
else:
builtin_tools.append(t)
if builtin_tools:
tool_gen = BuiltinToolGenerator()
tool_template = tool_gen.gen(builtin_tools)
sys_content += tool_template.render()
sys_content += "\n"
custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
if custom_tools:
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
if fmt != ToolPromptFormat.python_list:
raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}")
system_prompt = None
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
system_prompt = existing_system_message.content
tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt)
sys_content += tool_template.render()
sys_content += "\n"
if existing_system_message and (
request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools
):
sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")
tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
if tool_choice_prompt:
sys_content += "\n" + tool_choice_prompt
messages = [SystemMessage(content=sys_content.strip("\n")), *existing_messages]
return messages
def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: list[ToolDefinition]) -> str:
if tool_choice == ToolChoice.auto:
return ""
elif tool_choice == ToolChoice.required:
return "You MUST use one of the provided functions/tools to answer the user query."
elif tool_choice == ToolChoice.none:
# tools are already not passed in
return ""
else:
# specific tool
return f"You MUST use the tool `{tool_choice}` to answer the user query."
def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
llama_model = resolve_model(model)
if llama_model is None:
log.warning(f"Could not resolve model {model}, defaulting to json tool prompt format")
return ToolPromptFormat.json
if llama_model.model_family == ModelFamily.llama3_1 or (
llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
):
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
return ToolPromptFormat.json
elif llama_model.model_family in (
ModelFamily.llama3_2,
ModelFamily.llama3_3,
ModelFamily.llama4,
):
# llama3.2 and llama3.3 models follow the same tool prompt format
return ToolPromptFormat.python_list
else:
return ToolPromptFormat.json
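
A hedged usage sketch for `localize_image_content` above, exercising only the `data:` branch so no network access is needed; the PNG bytes are placeholder values, and the function is assumed to be imported from this module:

import asyncio
import base64

async def _demo() -> None:
    png_bytes = b"\x89PNG\r\n\x1a\n"  # placeholder header bytes, not a real image
    uri = "data:image/png;base64," + base64.b64encode(png_bytes).decode("utf-8")
    content, fmt = await localize_image_content(uri)
    assert content == png_bytes and fmt == "png"

asyncio.run(_demo())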

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .kvstore import * # noqa: F401, F403

View file

@ -1,21 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from typing import Protocol
class KVStore(Protocol):
# TODO: make the value type bytes instead of str
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None: ...
async def get(self, key: str) -> str | None: ...
async def delete(self, key: str) -> None: ...
async def values_in_range(self, start_key: str, end_key: str) -> list[str]: ...
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: ...
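
A minimal sketch of code written against the `KVStore` protocol above; `store` can be any implementation (for example the in-memory one shown later in this diff), since the protocol is structural and requires no inheritance:

async def cache_greeting(store: KVStore) -> str | None:
    await store.set("greeting:en", "hello")
    return await store.get("greeting:en")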

View file

@ -1,39 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated
from pydantic import Field
from llama_stack.core.storage.datatypes import (
MongoDBKVStoreConfig,
PostgresKVStoreConfig,
RedisKVStoreConfig,
SqliteKVStoreConfig,
StorageBackendType,
)
KVStoreConfig = Annotated[
RedisKVStoreConfig | SqliteKVStoreConfig | PostgresKVStoreConfig | MongoDBKVStoreConfig, Field(discriminator="type")
]
def get_pip_packages(store_config: dict | KVStoreConfig) -> list[str]:
"""Get pip packages for KV store config, handling both dict and object cases."""
if isinstance(store_config, dict):
store_type = store_config.get("type")
if store_type == StorageBackendType.KV_SQLITE.value:
return SqliteKVStoreConfig.pip_packages()
elif store_type == StorageBackendType.KV_POSTGRES.value:
return PostgresKVStoreConfig.pip_packages()
elif store_type == StorageBackendType.KV_REDIS.value:
return RedisKVStoreConfig.pip_packages()
elif store_type == StorageBackendType.KV_MONGODB.value:
return MongoDBKVStoreConfig.pip_packages()
else:
raise ValueError(f"Unknown KV store type: {store_type}")
else:
return store_config.pip_packages()
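
A hedged usage sketch for `get_pip_packages` above; the dict form only needs a "type" key, so backend-specific fields are omitted:

from llama_stack.core.storage.datatypes import StorageBackendType

raw_config = {"type": StorageBackendType.KV_SQLITE.value}  # e.g. loaded from YAML before validation
packages = get_pip_packages(raw_config)                    # -> SqliteKVStoreConfig.pip_packages()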

View file

@ -1,97 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from __future__ import annotations
from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig, StorageBackendType
from .api import KVStore
from .config import KVStoreConfig
def kvstore_dependencies():
"""
Returns all possible kvstore dependencies for registry/provider specifications.
NOTE: For specific kvstore implementations, use config.pip_packages instead.
This function returns the union of all dependencies for cases where the specific
kvstore type is not known at declaration time (e.g., provider registries).
"""
return ["aiosqlite", "psycopg2-binary", "redis", "pymongo"]
class InmemoryKVStoreImpl(KVStore):
def __init__(self):
self._store = {}
async def initialize(self) -> None:
pass
async def get(self, key: str) -> str | None:
return self._store.get(key)
async def set(self, key: str, value: str) -> None:
self._store[key] = value
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
return [self._store[key] for key in self._store.keys() if key >= start_key and key < end_key]
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
"""Get all keys in the given range."""
return [key for key in self._store.keys() if key >= start_key and key < end_key]
async def delete(self, key: str) -> None:
del self._store[key]
_KVSTORE_BACKENDS: dict[str, KVStoreConfig] = {}
def register_kvstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
"""Register the set of available KV store backends for reference resolution."""
global _KVSTORE_BACKENDS
_KVSTORE_BACKENDS.clear()
for name, cfg in backends.items():
_KVSTORE_BACKENDS[name] = cfg
async def kvstore_impl(reference: KVStoreReference) -> KVStore:
backend_name = reference.backend
backend_config = _KVSTORE_BACKENDS.get(backend_name)
if backend_config is None:
raise ValueError(f"Unknown KVStore backend '{backend_name}'. Registered backends: {sorted(_KVSTORE_BACKENDS)}")
config = backend_config.model_copy()
config.namespace = reference.namespace
if config.type == StorageBackendType.KV_REDIS.value:
from .redis import RedisKVStoreImpl
impl = RedisKVStoreImpl(config)
elif config.type == StorageBackendType.KV_SQLITE.value:
from .sqlite import SqliteKVStoreImpl
impl = SqliteKVStoreImpl(config)
elif config.type == StorageBackendType.KV_POSTGRES.value:
from .postgres import PostgresKVStoreImpl
impl = PostgresKVStoreImpl(config)
elif config.type == StorageBackendType.KV_MONGODB.value:
from .mongodb import MongoDBKVStoreImpl
impl = MongoDBKVStoreImpl(config)
else:
raise ValueError(f"Unknown kvstore type {config.type}")
await impl.initialize()
return impl
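
A hedged end-to-end sketch of the registry above. Field names for `SqliteKVStoreConfig` and `KVStoreReference` are inferred from their usage in this diff (`db_path`, `backend`, `namespace`), not verified against the datatypes module:

import asyncio

from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig

async def _demo() -> None:
    # Register the backend by name so references can resolve it later.
    register_kvstore_backends({"default": SqliteKVStoreConfig(db_path="/tmp/demo_kv.db")})
    store = await kvstore_impl(KVStoreReference(backend="default", namespace="demo"))
    await store.set("k", "v")
    assert await store.get("k") == "v"

asyncio.run(_demo())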

View file

@ -1,9 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .mongodb import MongoDBKVStoreImpl
__all__ = ["MongoDBKVStoreImpl"]

View file

@ -1,82 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from pymongo import AsyncMongoClient
from pymongo.asynchronous.collection import AsyncCollection
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore
from ..config import MongoDBKVStoreConfig
log = get_logger(name=__name__, category="providers::utils")
class MongoDBKVStoreImpl(KVStore):
def __init__(self, config: MongoDBKVStoreConfig):
self.config = config
self.conn: AsyncMongoClient | None = None
@property
def collection(self) -> AsyncCollection:
if self.conn is None:
raise RuntimeError("MongoDB connection is not initialized")
return self.conn[self.config.db][self.config.collection_name]
async def initialize(self) -> None:
try:
conn_creds = {
"host": self.config.host,
"port": self.config.port,
"username": self.config.user,
"password": self.config.password,
}
conn_creds = {k: v for k, v in conn_creds.items() if v is not None}
self.conn = AsyncMongoClient(**conn_creds)
except Exception as e:
log.exception("Could not connect to MongoDB database server")
raise RuntimeError("Could not connect to MongoDB database server") from e
def _namespaced_key(self, key: str) -> str:
if not self.config.namespace:
return key
return f"{self.config.namespace}:{key}"
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
key = self._namespaced_key(key)
update_query = {"$set": {"value": value, "expiration": expiration}}
await self.collection.update_one({"key": key}, update_query, upsert=True)
async def get(self, key: str) -> str | None:
key = self._namespaced_key(key)
query = {"key": key}
result = await self.collection.find_one(query, {"value": 1, "_id": 0})
return result["value"] if result else None
async def delete(self, key: str) -> None:
key = self._namespaced_key(key)
await self.collection.delete_one({"key": key})
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
query = {
"key": {"$gte": start_key, "$lt": end_key},
}
cursor = self.collection.find(query, {"value": 1, "_id": 0}).sort("key", 1)
result = []
async for doc in cursor:
result.append(doc["value"])
return result
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
query = {"key": {"$gte": start_key, "$lt": end_key}}
cursor = self.collection.find(query, {"key": 1, "_id": 0}).sort("key", 1)
return [doc["key"] for doc in cursor]

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .postgres import PostgresKVStoreImpl # noqa: F401 F403

View file

@ -1,114 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
import psycopg2
from psycopg2.extras import DictCursor
from llama_stack.log import get_logger
from ..api import KVStore
from ..config import PostgresKVStoreConfig
log = get_logger(name=__name__, category="providers::utils")
class PostgresKVStoreImpl(KVStore):
def __init__(self, config: PostgresKVStoreConfig):
self.config = config
self.conn = None
self.cursor = None
async def initialize(self) -> None:
try:
self.conn = psycopg2.connect(
host=self.config.host,
port=self.config.port,
database=self.config.db,
user=self.config.user,
password=self.config.password,
sslmode=self.config.ssl_mode,
sslrootcert=self.config.ca_cert_path,
)
self.conn.autocommit = True
self.cursor = self.conn.cursor(cursor_factory=DictCursor)
# Create table if it doesn't exist
self.cursor.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.config.table_name} (
key TEXT PRIMARY KEY,
value TEXT,
expiration TIMESTAMP
)
"""
)
except Exception as e:
log.exception("Could not connect to PostgreSQL database server")
raise RuntimeError("Could not connect to PostgreSQL database server") from e
def _namespaced_key(self, key: str) -> str:
if not self.config.namespace:
return key
return f"{self.config.namespace}:{key}"
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
key = self._namespaced_key(key)
self.cursor.execute(
f"""
INSERT INTO {self.config.table_name} (key, value, expiration)
VALUES (%s, %s, %s)
ON CONFLICT (key) DO UPDATE
SET value = EXCLUDED.value, expiration = EXCLUDED.expiration
""",
(key, value, expiration),
)
async def get(self, key: str) -> str | None:
key = self._namespaced_key(key)
self.cursor.execute(
f"""
SELECT value FROM {self.config.table_name}
WHERE key = %s
AND (expiration IS NULL OR expiration > NOW())
""",
(key,),
)
result = self.cursor.fetchone()
return result[0] if result else None
async def delete(self, key: str) -> None:
key = self._namespaced_key(key)
self.cursor.execute(
f"DELETE FROM {self.config.table_name} WHERE key = %s",
(key,),
)
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
self.cursor.execute(
f"""
SELECT value FROM {self.config.table_name}
WHERE key >= %s AND key < %s
AND (expiration IS NULL OR expiration > NOW())
ORDER BY key
""",
(start_key, end_key),
)
return [row[0] for row in self.cursor.fetchall()]
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
self.cursor.execute(
f"SELECT key FROM {self.config.table_name} WHERE key >= %s AND key < %s",
(start_key, end_key),
)
return [row[0] for row in self.cursor.fetchall()]

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .redis import RedisKVStoreImpl # noqa: F401

View file

@ -1,76 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from redis.asyncio import Redis
from ..api import KVStore
from ..config import RedisKVStoreConfig
class RedisKVStoreImpl(KVStore):
def __init__(self, config: RedisKVStoreConfig):
self.config = config
async def initialize(self) -> None:
self.redis = Redis.from_url(self.config.url)
def _namespaced_key(self, key: str) -> str:
if not self.config.namespace:
return key
return f"{self.config.namespace}:{key}"
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
key = self._namespaced_key(key)
await self.redis.set(key, value)
if expiration:
await self.redis.expireat(key, expiration)
async def get(self, key: str) -> str | None:
key = self._namespaced_key(key)
value = await self.redis.get(key)
if value is None:
return None
await self.redis.ttl(key)
return value
async def delete(self, key: str) -> None:
key = self._namespaced_key(key)
await self.redis.delete(key)
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
cursor = 0
pattern = start_key + "*" # Match all keys starting with start_key prefix
matching_keys = []
while True:
cursor, keys = await self.redis.scan(cursor, match=pattern, count=1000)
for key in keys:
key_str = key.decode("utf-8") if isinstance(key, bytes) else key
if start_key <= key_str <= end_key:
matching_keys.append(key)
if cursor == 0:
break
# Then fetch all values in a single MGET call
if matching_keys:
values = await self.redis.mget(matching_keys)
return [
value.decode("utf-8") if isinstance(value, bytes) else value for value in values if value is not None
]
return []
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
"""Get all keys in the given range."""
matching_keys = await self.redis.zrangebylex(self.config.namespace, f"[{start_key}", f"[{end_key}")
if not matching_keys:
return []
return [k.decode("utf-8") for k in matching_keys]

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .sqlite import SqliteKVStoreImpl # noqa: F401

View file

@ -1,20 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class SqliteControlPlaneConfig(BaseModel):
db_path: str = Field(
description="File path for the sqlite database",
)
table_name: str = Field(
default="llamastack_control_plane",
description="Table into which all the keys will be placed",
)

View file

@ -1,174 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from datetime import datetime
import aiosqlite
from llama_stack.log import get_logger
from ..api import KVStore
from ..config import SqliteKVStoreConfig
logger = get_logger(name=__name__, category="providers::utils")
class SqliteKVStoreImpl(KVStore):
def __init__(self, config: SqliteKVStoreConfig):
self.db_path = config.db_path
self.table_name = "kvstore"
self._conn: aiosqlite.Connection | None = None
def __str__(self):
return f"SqliteKVStoreImpl(db_path={self.db_path}, table_name={self.table_name})"
def _is_memory_db(self) -> bool:
"""Check if this is an in-memory database."""
return self.db_path == ":memory:" or "mode=memory" in self.db_path
async def initialize(self):
# Skip directory creation for in-memory databases and file: URIs
if not self._is_memory_db() and not self.db_path.startswith("file:"):
db_dir = os.path.dirname(self.db_path)
if db_dir: # Only create if there's a directory component
os.makedirs(db_dir, exist_ok=True)
# Only use persistent connection for in-memory databases
# File-based databases use connection-per-operation to avoid hangs
if self._is_memory_db():
self._conn = await aiosqlite.connect(self.db_path)
await self._conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
key TEXT PRIMARY KEY,
value TEXT,
expiration TIMESTAMP
)
"""
)
await self._conn.commit()
else:
# For file-based databases, just create the table
async with aiosqlite.connect(self.db_path) as db:
await db.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
key TEXT PRIMARY KEY,
value TEXT,
expiration TIMESTAMP
)
"""
)
await db.commit()
async def shutdown(self):
"""Close the persistent connection (only for in-memory databases)."""
if self._conn:
await self._conn.close()
self._conn = None
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
if self._conn:
# In-memory database with persistent connection
await self._conn.execute(
f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
(key, value, expiration),
)
await self._conn.commit()
else:
# File-based database with connection per operation
async with aiosqlite.connect(self.db_path) as db:
await db.execute(
f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
(key, value, expiration),
)
await db.commit()
async def get(self, key: str) -> str | None:
if self._conn:
# In-memory database with persistent connection
async with self._conn.execute(
f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
) as cursor:
row = await cursor.fetchone()
if row is None:
return None
value, expiration = row
if not isinstance(value, str):
logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None")
return None
return value
else:
# File-based database with connection per operation
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
) as cursor:
row = await cursor.fetchone()
if row is None:
return None
value, expiration = row
if not isinstance(value, str):
logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None")
return None
return value
async def delete(self, key: str) -> None:
if self._conn:
# In-memory database with persistent connection
await self._conn.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
await self._conn.commit()
else:
# File-based database with connection per operation
async with aiosqlite.connect(self.db_path) as db:
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
await db.commit()
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
if self._conn:
# In-memory database with persistent connection
async with self._conn.execute(
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
(start_key, end_key),
) as cursor:
result = []
async for row in cursor:
_, value, _ = row
result.append(value)
return result
else:
# File-based database with connection per operation
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
(start_key, end_key),
) as cursor:
result = []
async for row in cursor:
_, value, _ = row
result.append(value)
return result
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
"""Get all keys in the given range."""
if self._conn:
# In-memory database with persistent connection
cursor = await self._conn.execute(
f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
(start_key, end_key),
)
rows = await cursor.fetchall()
return [row[0] for row in rows]
else:
# File-based database with connection per operation
async with aiosqlite.connect(self.db_path) as db:
cursor = await db.execute(
f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
(start_key, end_key),
)
rows = await cursor.fetchall()
return [row[0] for row in rows]
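
A short sketch of the two connection modes described above, assuming `SqliteKVStoreConfig(db_path=...)` is sufficient to construct the config:

import asyncio

async def _demo() -> None:
    # ":memory:" keeps a single persistent connection; a file path opens one per operation.
    store = SqliteKVStoreImpl(SqliteKVStoreConfig(db_path=":memory:"))
    await store.initialize()
    await store.set("alpha", "1")
    assert await store.get("alpha") == "1"
    await store.shutdown()

asyncio.run(_demo())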

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,26 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import mimetypes
import os
from llama_stack.apis.common.content_types import URL
def data_url_from_file(file_path: str) -> URL:
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, "rb") as file:
file_content = file.read()
base64_content = base64.b64encode(file_content).decode("utf-8")
mime_type, _ = mimetypes.guess_type(file_path)
data_url = f"data:{mime_type};base64,{base64_content}"
return URL(uri=data_url)
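
A quick usage sketch for `data_url_from_file` above (the path is hypothetical):

url = data_url_from_file("/tmp/report.pdf")
print(url.uri[:40])  # e.g. data:application/pdf;base64,JVBERi0x...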

File diff suppressed because it is too large

View file

@ -1,332 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import io
import re
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any
from urllib.parse import unquote
import httpx
import numpy as np
from numpy.typing import NDArray
from pydantic import BaseModel
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
)
from llama_stack.apis.inference import OpenAIEmbeddingsRequestWithExtraBody
from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import Api
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
log = get_logger(name=__name__, category="providers::utils")
class ChunkForDeletion(BaseModel):
"""Information needed to delete a chunk from a vector store.
:param chunk_id: The ID of the chunk to delete
:param document_id: The ID of the document this chunk belongs to
"""
chunk_id: str
document_id: str
# Constants for reranker types
RERANKER_TYPE_RRF = "rrf"
RERANKER_TYPE_WEIGHTED = "weighted"
RERANKER_TYPE_NORMALIZED = "normalized"
def parse_pdf(data: bytes) -> str:
# For PDF and DOC/DOCX files, we can't reliably convert to string
pdf_bytes = io.BytesIO(data)
from pypdf import PdfReader
pdf_reader = PdfReader(pdf_bytes)
return "\n".join([page.extract_text() for page in pdf_reader.pages])
def parse_data_url(data_url: str):
data_url_pattern = re.compile(
r"^"
r"data:"
r"(?P<mimetype>[\w/\-+.]+)"
r"(?P<charset>;charset=(?P<encoding>[\w-]+))?"
r"(?P<base64>;base64)?"
r",(?P<data>.*)"
r"$",
re.DOTALL,
)
match = data_url_pattern.match(data_url)
if not match:
raise ValueError("Invalid Data URL format")
parts = match.groupdict()
parts["is_base64"] = bool(parts["base64"])
return parts
def content_from_data(data_url: str) -> str:
parts = parse_data_url(data_url)
data = parts["data"]
if parts["is_base64"]:
data = base64.b64decode(data)
else:
data = unquote(data)
encoding = parts["encoding"] or "utf-8"
data = data.encode(encoding)
return content_from_data_and_mime_type(data, parts["mimetype"], parts.get("encoding", None))
def content_from_data_and_mime_type(data: bytes | str, mime_type: str | None, encoding: str | None = None) -> str:
if isinstance(data, bytes):
if not encoding:
import chardet
detected = chardet.detect(data)
encoding = detected["encoding"]
mime_category = mime_type.split("/")[0] if mime_type else None
if mime_category == "text":
# For text-based files (including CSV, MD)
encodings_to_try = [encoding]
if encoding != "utf-8":
encodings_to_try.append("utf-8")
first_exception = None
for encoding in encodings_to_try:
try:
return data.decode(encoding)
except UnicodeDecodeError as e:
if first_exception is None:
first_exception = e
log.warning(f"Decoding failed with {encoding}: {e}")
# raise the original exception; if we got here there was at least one exception
log.error(f"Could not decode data as any of {encodings_to_try}")
raise first_exception
elif mime_type == "application/pdf":
return parse_pdf(data)
else:
log.error("Could not extract content from data_url properly.")
return ""
async def content_from_doc(doc: RAGDocument) -> str:
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
return content_from_data(doc.content.uri)
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
if doc.mime_type == "application/pdf":
return parse_pdf(r.content)
return r.text
elif isinstance(doc.content, str):
pattern = re.compile("^(https?://|file://|data:)")
if pattern.match(doc.content):
if doc.content.startswith("data:"):
return content_from_data(doc.content)
async with httpx.AsyncClient() as client:
r = await client.get(doc.content)
if doc.mime_type == "application/pdf":
return parse_pdf(r.content)
return r.text
return doc.content
else:
# will raise ValueError if the content is not List[InterleavedContent] or InterleavedContent
return interleaved_content_as_str(doc.content)
def make_overlapped_chunks(
document_id: str, text: str, window_len: int, overlap_len: int, metadata: dict[str, Any]
) -> list[Chunk]:
default_tokenizer = "DEFAULT_TIKTOKEN_TOKENIZER"
tokenizer = Tokenizer.get_instance()
tokens = tokenizer.encode(text, bos=False, eos=False)
try:
metadata_string = str(metadata)
except Exception as e:
raise ValueError("Failed to serialize metadata to string") from e
metadata_tokens = tokenizer.encode(metadata_string, bos=False, eos=False)
chunks = []
for i in range(0, len(tokens), window_len - overlap_len):
toks = tokens[i : i + window_len]
chunk = tokenizer.decode(toks)
chunk_window = f"{i}-{i + len(toks)}"
chunk_id = generate_chunk_id(chunk, text, chunk_window)
chunk_metadata = metadata.copy()
chunk_metadata["chunk_id"] = chunk_id
chunk_metadata["document_id"] = document_id
chunk_metadata["token_count"] = len(toks)
chunk_metadata["metadata_token_count"] = len(metadata_tokens)
backend_chunk_metadata = ChunkMetadata(
chunk_id=chunk_id,
document_id=document_id,
source=metadata.get("source", None),
created_timestamp=metadata.get("created_timestamp", int(time.time())),
updated_timestamp=int(time.time()),
chunk_window=chunk_window,
chunk_tokenizer=default_tokenizer,
chunk_embedding_model=None, # This will be set in `VectorStoreWithIndex.insert_chunks`
content_token_count=len(toks),
metadata_token_count=len(metadata_tokens),
)
# chunk is a string
chunks.append(
Chunk(
content=chunk,
metadata=chunk_metadata,
chunk_metadata=backend_chunk_metadata,
)
)
return chunks
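# A hedged usage sketch for make_overlapped_chunks (illustrative values; assumes the
# Llama 3 tokenizer files are available locally):
def _example_overlapped_chunks() -> list[Chunk]:
    chunks = make_overlapped_chunks(
        document_id="doc-1",
        text="llama stack " * 200,
        window_len=64,   # tokens per chunk
        overlap_len=16,  # tokens shared between consecutive chunks
        metadata={"source": "unit-test"},
    )
    # Consecutive chunks start 48 (= 64 - 16) tokens apart.
    return chunks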
def _validate_embedding(embedding: NDArray, index: int, expected_dimension: int):
"""Helper method to validate embedding format and dimensions"""
if not isinstance(embedding, (list | np.ndarray)):
raise ValueError(f"Embedding at index {index} must be a list or numpy array, got {type(embedding)}")
if isinstance(embedding, np.ndarray):
if not np.issubdtype(embedding.dtype, np.number):
raise ValueError(f"Embedding at index {index} contains non-numeric values")
else:
if not all(isinstance(e, (float | int | np.number)) for e in embedding):
raise ValueError(f"Embedding at index {index} contains non-numeric values")
if len(embedding) != expected_dimension:
raise ValueError(f"Embedding at index {index} has dimension {len(embedding)}, expected {expected_dimension}")
class EmbeddingIndex(ABC):
@abstractmethod
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
raise NotImplementedError()
@abstractmethod
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]):
raise NotImplementedError()
@abstractmethod
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
raise NotImplementedError()
@abstractmethod
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
raise NotImplementedError()
@abstractmethod
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
raise NotImplementedError()
@abstractmethod
async def delete(self):
raise NotImplementedError()
@dataclass
class VectorStoreWithIndex:
vector_store: VectorStore
index: EmbeddingIndex
inference_api: Api.inference
async def insert_chunks(
self,
chunks: list[Chunk],
) -> None:
chunks_to_embed = []
for i, c in enumerate(chunks):
if c.embedding is None:
chunks_to_embed.append(c)
if c.chunk_metadata:
c.chunk_metadata.chunk_embedding_model = self.vector_store.embedding_model
c.chunk_metadata.chunk_embedding_dimension = self.vector_store.embedding_dimension
else:
_validate_embedding(c.embedding, i, self.vector_store.embedding_dimension)
if chunks_to_embed:
params = OpenAIEmbeddingsRequestWithExtraBody(
model=self.vector_store.embedding_model,
input=[c.content for c in chunks_to_embed],
)
resp = await self.inference_api.openai_embeddings(params)
for c, data in zip(chunks_to_embed, resp.data, strict=False):
c.embedding = data.embedding
embeddings = np.array([c.embedding for c in chunks], dtype=np.float32)
await self.index.add_chunks(chunks, embeddings)
async def query_chunks(
self,
query: InterleavedContent,
params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
if params is None:
params = {}
k = params.get("max_chunks", 3)
mode = params.get("mode")
score_threshold = params.get("score_threshold", 0.0)
ranker = params.get("ranker")
if ranker is None:
reranker_type = RERANKER_TYPE_RRF
reranker_params = {"impact_factor": 60.0}
else:
strategy = ranker.get("strategy", "rrf")
if strategy == "weighted":
weights = ranker.get("params", {}).get("weights", [0.5, 0.5])
reranker_type = RERANKER_TYPE_WEIGHTED
reranker_params = {"alpha": weights[0] if len(weights) > 0 else 0.5}
elif strategy == "normalized":
reranker_type = RERANKER_TYPE_NORMALIZED
else:
reranker_type = RERANKER_TYPE_RRF
k_value = ranker.get("params", {}).get("k", 60.0)
reranker_params = {"impact_factor": k_value}
query_string = interleaved_content_as_str(query)
if mode == "keyword":
return await self.index.query_keyword(query_string, k, score_threshold)
params = OpenAIEmbeddingsRequestWithExtraBody(
model=self.vector_store.embedding_model,
input=[query_string],
)
embeddings_response = await self.inference_api.openai_embeddings(params)
query_vector = np.array(embeddings_response.data[0].embedding, dtype=np.float32)
if mode == "hybrid":
return await self.index.query_hybrid(
query_vector, query_string, k, score_threshold, reranker_type, reranker_params
)
else:
return await self.index.query_vector(query_vector, k, score_threshold)
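
The `params` dict accepted by `query_chunks` above is loosely typed; a sketch of the keys it recognizes, inferred only from the parsing code in this file:

params = {
    "mode": "hybrid",           # "keyword", "hybrid", or anything else -> pure vector search
    "max_chunks": 5,            # top-k chunks to return (default 3)
    "score_threshold": 0.1,     # minimum score to keep a chunk (default 0.0)
    "ranker": {
        "strategy": "rrf",      # "rrf", "weighted", or "normalized"
        "params": {"k": 60.0},  # RRF impact factor
    },
}
# response = await vector_store_with_index.query_chunks("what is llama stack?", params)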

View file

@ -1,43 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.common.responses import PaginatedResponse
def paginate_records(
records: list[dict[str, Any]],
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
"""Helper function to handle pagination of records consistently across implementations.
Inspired by stripe's pagination: https://docs.stripe.com/api/pagination
:param records: List of records to paginate
:param start_index: The starting index (0-based). If None, starts from beginning.
:param limit: Number of items to return. If None or -1, returns all items.
:return: PaginatedResponse with the paginated data
"""
# Handle special case for fetching all rows
if limit is None or limit == -1:
return PaginatedResponse(
data=records,
has_more=False,
)
# Use offset-based pagination
start_index = start_index or 0
end_index = min(start_index + limit, len(records))
page_data = records[start_index:end_index]
# Calculate if there are more records
has_more = end_index < len(records)
return PaginatedResponse(
data=page_data,
has_more=has_more,
)
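
A quick usage sketch for `paginate_records` above:

records = [{"id": i} for i in range(5)]
page = paginate_records(records, start_index=2, limit=2)
assert page.data == [{"id": 2}, {"id": 3}]
assert page.has_more is True  # the record with id 4 is still left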

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,354 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import Any
from llama_stack.apis.agents import (
Order,
)
from llama_stack.apis.agents.openai_responses import (
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIResponseInput,
OpenAIResponseObject,
OpenAIResponseObjectWithInput,
)
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqlStoreReference, StorageBackendType
from llama_stack.log import get_logger
from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
from ..sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl
logger = get_logger(name=__name__, category="openai_responses")
class _OpenAIResponseObjectWithInputAndMessages(OpenAIResponseObjectWithInput):
"""Internal class for storing responses with chat completion messages.
This extends the public OpenAIResponseObjectWithInput with messages field
for internal storage. The messages field is not exposed in the public API.
The messages field is optional for backward compatibility with responses
stored before this feature was added.
"""
messages: list[OpenAIMessageParam] | None = None
class ResponsesStore:
def __init__(
self,
reference: ResponsesStoreReference | SqlStoreReference,
policy: list[AccessRule],
):
if isinstance(reference, ResponsesStoreReference):
self.reference = reference
else:
self.reference = ResponsesStoreReference(**reference.model_dump())
self.policy = policy
self.sql_store = None
self.enable_write_queue = True
# Async write queue and worker control
self._queue: (
asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput], list[OpenAIMessageParam]]] | None
) = None
self._worker_tasks: list[asyncio.Task[Any]] = []
self._max_write_queue_size: int = self.reference.max_write_queue_size
self._num_writers: int = max(1, self.reference.num_writers)
async def initialize(self):
"""Create the necessary tables if they don't exist."""
base_store = sqlstore_impl(self.reference)
self.sql_store = AuthorizedSqlStore(base_store, self.policy)
backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
if backend_config is None:
raise ValueError(
f"Unregistered SQL backend '{self.reference.backend}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
)
if backend_config.type == StorageBackendType.SQL_SQLITE:
self.enable_write_queue = False
await self.sql_store.create_table(
"openai_responses",
{
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
"created_at": ColumnType.INTEGER,
"response_object": ColumnType.JSON,
"model": ColumnType.STRING,
},
)
await self.sql_store.create_table(
"conversation_messages",
{
"conversation_id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
"messages": ColumnType.JSON,
},
)
if self.enable_write_queue:
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
for _ in range(self._num_writers):
self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
else:
logger.debug("Write queue disabled for SQLite to avoid concurrency issues")
async def shutdown(self) -> None:
if not self._worker_tasks:
return
if self._queue is not None:
await self._queue.join()
for t in self._worker_tasks:
if not t.done():
t.cancel()
for t in self._worker_tasks:
try:
await t
except asyncio.CancelledError:
pass
self._worker_tasks.clear()
async def flush(self) -> None:
"""Wait for all queued writes to complete. Useful for testing."""
if self.enable_write_queue and self._queue is not None:
await self._queue.join()
async def store_response_object(
self,
response_object: OpenAIResponseObject,
input: list[OpenAIResponseInput],
messages: list[OpenAIMessageParam],
) -> None:
if self.enable_write_queue:
if self._queue is None:
raise ValueError("Responses store is not initialized")
try:
self._queue.put_nowait((response_object, input, messages))
except asyncio.QueueFull:
logger.warning(f"Write queue full; adding response id={getattr(response_object, 'id', '<unknown>')}")
await self._queue.put((response_object, input, messages))
else:
await self._write_response_object(response_object, input, messages)
async def _worker_loop(self) -> None:
assert self._queue is not None
while True:
try:
item = await self._queue.get()
except asyncio.CancelledError:
break
response_object, input, messages = item
try:
await self._write_response_object(response_object, input, messages)
except Exception as e: # noqa: BLE001
logger.error(f"Error writing response object: {e}")
finally:
self._queue.task_done()
async def _write_response_object(
self,
response_object: OpenAIResponseObject,
input: list[OpenAIResponseInput],
messages: list[OpenAIMessageParam],
) -> None:
if self.sql_store is None:
raise ValueError("Responses store is not initialized")
data = response_object.model_dump()
data["input"] = [input_item.model_dump() for input_item in input]
data["messages"] = [msg.model_dump() for msg in messages]
await self.sql_store.insert(
"openai_responses",
{
"id": data["id"],
"created_at": data["created_at"],
"model": data["model"],
"response_object": data,
},
)
async def list_responses(
self,
after: str | None = None,
limit: int | None = 50,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIResponseObject:
"""
List responses from the database.
        :param after: An item ID used as a pagination cursor; results start after this response.
:param limit: The maximum number of responses to return.
:param model: The model to filter by.
:param order: The order to sort the responses by.
"""
if not self.sql_store:
raise ValueError("Responses store is not initialized")
if not order:
order = Order.desc
where_conditions = {}
if model:
where_conditions["model"] = model
paginated_result = await self.sql_store.fetch_all(
table="openai_responses",
where=where_conditions if where_conditions else None,
order_by=[("created_at", order.value)],
cursor=("id", after) if after else None,
limit=limit,
)
data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
return ListOpenAIResponseObject(
data=data,
has_more=paginated_result.has_more,
first_id=data[0].id if data else "",
last_id=data[-1].id if data else "",
)
async def get_response_object(self, response_id: str) -> _OpenAIResponseObjectWithInputAndMessages:
"""
Get a response object with automatic access control checking.
"""
if not self.sql_store:
raise ValueError("Responses store is not initialized")
row = await self.sql_store.fetch_one(
"openai_responses",
where={"id": response_id},
)
if not row:
# SecureSqlStore will return None if record doesn't exist OR access is denied
# This provides security by not revealing whether the record exists
raise ValueError(f"Response with id {response_id} not found") from None
return _OpenAIResponseObjectWithInputAndMessages(**row["response_object"])
async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
if not self.sql_store:
raise ValueError("Responses store is not initialized")
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id})
if not row:
raise ValueError(f"Response with id {response_id} not found")
await self.sql_store.delete("openai_responses", where={"id": response_id})
return OpenAIDeleteResponseObject(id=response_id)
async def list_response_input_items(
self,
response_id: str,
after: str | None = None,
before: str | None = None,
include: list[str] | None = None,
limit: int | None = 20,
order: Order | None = Order.desc,
) -> ListOpenAIResponseInputItem:
"""
List input items for a given response.
:param response_id: The ID of the response to retrieve input items for.
:param after: An item ID to list items after, used for pagination.
:param before: An item ID to list items before, used for pagination.
:param include: Additional fields to include in the response.
:param limit: A limit on the number of objects to be returned.
:param order: The order to return the input items in.
"""
if include:
raise NotImplementedError("Include is not supported yet")
if before and after:
raise ValueError("Cannot specify both 'before' and 'after' parameters")
response_with_input_and_messages = await self.get_response_object(response_id)
items = response_with_input_and_messages.input
if order == Order.desc:
items = list(reversed(items))
start_index = 0
end_index = len(items)
if after or before:
for i, item in enumerate(items):
item_id = getattr(item, "id", None)
if after and item_id == after:
start_index = i + 1
if before and item_id == before:
end_index = i
break
if after and start_index == 0:
raise ValueError(f"Input item with id '{after}' not found for response '{response_id}'")
if before and end_index == len(items):
raise ValueError(f"Input item with id '{before}' not found for response '{response_id}'")
items = items[start_index:end_index]
# Apply limit
if limit is not None:
items = items[:limit]
return ListOpenAIResponseInputItem(data=items)
async def store_conversation_messages(self, conversation_id: str, messages: list[OpenAIMessageParam]) -> None:
"""Store messages for a conversation.
:param conversation_id: The conversation identifier.
:param messages: List of OpenAI message parameters to store.
"""
if not self.sql_store:
raise ValueError("Responses store is not initialized")
# Serialize messages to dict format for JSON storage
messages_data = [msg.model_dump() for msg in messages]
# Upsert: try insert first, update if exists
try:
await self.sql_store.insert(
table="conversation_messages",
data={"conversation_id": conversation_id, "messages": messages_data},
)
except Exception:
# If insert fails due to ID conflict, update existing record
await self.sql_store.update(
table="conversation_messages",
data={"messages": messages_data},
where={"conversation_id": conversation_id},
)
logger.debug(f"Stored {len(messages)} messages for conversation {conversation_id}")
async def get_conversation_messages(self, conversation_id: str) -> list[OpenAIMessageParam] | None:
"""Get stored messages for a conversation.
:param conversation_id: The conversation identifier.
:returns: List of OpenAI message parameters, or None if no messages stored.
"""
if not self.sql_store:
raise ValueError("Responses store is not initialized")
record = await self.sql_store.fetch_one(
table="conversation_messages",
where={"conversation_id": conversation_id},
)
if record is None:
return None
# Deserialize messages from JSON storage
from pydantic import TypeAdapter
adapter = TypeAdapter(list[OpenAIMessageParam])
return adapter.validate_python(record["messages"])
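# A minimal, self-contained sketch of the bounded write-queue pattern used by
# store_response_object/_worker_loop above, assuming only the standard library;
# fake_write stands in for _write_response_object and the item strings are invented.
import asyncio


async def fake_write(item: str) -> None:
    await asyncio.sleep(0.01)  # simulate a database insert
    print(f"wrote {item}")


async def demo_write_queue() -> None:
    queue: asyncio.Queue[str] = asyncio.Queue(maxsize=2)

    async def worker() -> None:
        while True:
            item = await queue.get()
            try:
                await fake_write(item)
            finally:
                queue.task_done()

    workers = [asyncio.create_task(worker()) for _ in range(2)]
    for i in range(5):
        try:
            queue.put_nowait(f"response-{i}")
        except asyncio.QueueFull:
            # Same fallback as store_response_object: warn, then block until space frees up.
            await queue.put(f"response-{i}")
    await queue.join()   # flush(): wait for all queued writes to land
    for t in workers:    # shutdown(): cancel the now-idle workers
        t.cancel()


asyncio.run(demo_write_queue())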

View file

@ -1,270 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import abc
import asyncio
import functools
import threading
from collections.abc import Callable, Coroutine, Iterable
from datetime import UTC, datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="providers::utils")
# TODO: revisit the list of possible statuses when defining a more coherent
# Jobs API for all API flows; e.g. do we need new vs scheduled?
class JobStatus(Enum):
new = "new"
scheduled = "scheduled"
running = "running"
failed = "failed"
completed = "completed"
type JobID = str
type JobType = str
class JobArtifact(BaseModel):
type: JobType
name: str
# TODO: uri should be a reference to /files API; revisit when /files is implemented
uri: str | None = None
metadata: dict[str, Any]
JobHandler = Callable[
[Callable[[str], None], Callable[[JobStatus], None], Callable[[JobArtifact], None]], Coroutine[Any, Any, None]
]
type LogMessage = tuple[datetime, str]
_COMPLETED_STATUSES = {JobStatus.completed, JobStatus.failed}
class Job:
def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler):
super().__init__()
self.id = job_id
self._type = job_type
self._handler = handler
self._artifacts: list[JobArtifact] = []
self._logs: list[LogMessage] = []
self._state_transitions: list[tuple[datetime, JobStatus]] = [(datetime.now(UTC), JobStatus.new)]
@property
def handler(self) -> JobHandler:
return self._handler
@property
def status(self) -> JobStatus:
return self._state_transitions[-1][1]
@status.setter
def status(self, status: JobStatus):
if status in _COMPLETED_STATUSES and self.status in _COMPLETED_STATUSES:
raise ValueError(f"Job is already in a completed state ({self.status})")
if self.status == status:
return
self._state_transitions.append((datetime.now(UTC), status))
@property
def artifacts(self) -> list[JobArtifact]:
return self._artifacts
def register_artifact(self, artifact: JobArtifact) -> None:
self._artifacts.append(artifact)
def _find_state_transition_date(self, status: Iterable[JobStatus]) -> datetime | None:
for date, s in reversed(self._state_transitions):
if s in status:
return date
return None
@property
def scheduled_at(self) -> datetime | None:
return self._find_state_transition_date([JobStatus.scheduled])
@property
def started_at(self) -> datetime | None:
return self._find_state_transition_date([JobStatus.running])
@property
def completed_at(self) -> datetime | None:
return self._find_state_transition_date(_COMPLETED_STATUSES)
@property
def logs(self) -> list[LogMessage]:
return self._logs[:]
def append_log(self, message: LogMessage) -> None:
self._logs.append(message)
# TODO: implement
def cancel(self) -> None:
raise NotImplementedError
class _SchedulerBackend(abc.ABC):
@abc.abstractmethod
def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
raise NotImplementedError
@abc.abstractmethod
def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
raise NotImplementedError
@abc.abstractmethod
def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
raise NotImplementedError
@abc.abstractmethod
async def shutdown(self) -> None:
raise NotImplementedError
@abc.abstractmethod
def schedule(
self,
job: Job,
on_log_message_cb: Callable[[str], None],
on_status_change_cb: Callable[[JobStatus], None],
on_artifact_collected_cb: Callable[[JobArtifact], None],
) -> None:
raise NotImplementedError
class _NaiveSchedulerBackend(_SchedulerBackend):
def __init__(self, timeout: int = 5):
self._timeout = timeout
self._loop = asyncio.new_event_loop()
# There may be performance implications of using threads due to Python
# GIL; may need to measure if it's a real problem though
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
def _run_loop(self) -> None:
asyncio.set_event_loop(self._loop)
self._loop.run_forever()
# TODO: When stopping the loop, give tasks a chance to finish
# TODO: should we explicitly inform jobs of pending stoppage?
# cancel all tasks
for task in asyncio.all_tasks(self._loop):
if not task.done():
task.cancel()
self._loop.close()
async def shutdown(self) -> None:
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join()
# TODO: decouple scheduling and running the job
def schedule(
self,
job: Job,
on_log_message_cb: Callable[[str], None],
on_status_change_cb: Callable[[JobStatus], None],
on_artifact_collected_cb: Callable[[JobArtifact], None],
) -> None:
async def do():
try:
job.status = JobStatus.running
await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
except Exception as e:
on_log_message_cb(str(e))
job.status = JobStatus.failed
logger.exception(f"Job {job.id} failed.")
asyncio.run_coroutine_threadsafe(do(), self._loop)
def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
pass
def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
pass
def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
pass
_BACKENDS = {
"naive": _NaiveSchedulerBackend,
}
def _get_backend_impl(backend: str) -> _SchedulerBackend:
try:
return _BACKENDS[backend]()
except KeyError as e:
raise ValueError(f"Unknown backend {backend}") from e
class Scheduler:
def __init__(self, backend: str = "naive"):
# TODO: if server crashes, job states are lost; we need to persist jobs on disc
self._jobs: dict[JobID, Job] = {}
self._backend = _get_backend_impl(backend)
def _on_log_message_cb(self, job: Job, message: str) -> None:
msg = (datetime.now(UTC), message)
# At least for the time being, until there's a better way to expose
# logs to users, log messages on console
logger.info(f"Job {job.id}: {message}")
job.append_log(msg)
self._backend.on_log_message_cb(job, msg)
def _on_status_change_cb(self, job: Job, status: JobStatus) -> None:
job.status = status
self._backend.on_status_change_cb(job, status)
def _on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
job.register_artifact(artifact)
self._backend.on_artifact_collected_cb(job, artifact)
def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler) -> JobID:
job = Job(type_, job_id, handler)
if job.id in self._jobs:
raise ValueError(f"Job {job.id} already exists")
self._jobs[job.id] = job
job.status = JobStatus.scheduled
self._backend.schedule(
job,
functools.partial(self._on_log_message_cb, job),
functools.partial(self._on_status_change_cb, job),
functools.partial(self._on_artifact_collected_cb, job),
)
return job.id
def cancel(self, job_id: JobID) -> None:
self.get_job(job_id).cancel()
def get_job(self, job_id: JobID) -> Job:
try:
return self._jobs[job_id]
except KeyError as e:
raise ValueError(f"Job {job_id} not found") from e
def get_jobs(self, type_: JobType | None = None) -> list[Job]:
jobs = list(self._jobs.values())
if type_:
jobs = [job for job in jobs if job._type == type_]
return jobs
async def shutdown(self):
# TODO: also cancel jobs once implemented
await self._backend.shutdown()
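# A standalone sketch of the "event loop on a daemon thread" technique used by
# _NaiveSchedulerBackend, assuming only the standard library; the job coroutine is invented.
import asyncio
import threading

loop = asyncio.new_event_loop()


def run_loop() -> None:
    asyncio.set_event_loop(loop)
    loop.run_forever()


thread = threading.Thread(target=run_loop, daemon=True)
thread.start()


async def job() -> str:
    await asyncio.sleep(0.1)
    return "done"


# Submit work from the calling thread and wait for its result, as schedule() does.
future = asyncio.run_coroutine_threadsafe(job(), loop)
print(future.result(timeout=5))

# Shutdown mirrors _NaiveSchedulerBackend.shutdown(): stop the loop, then join the thread.
loop.call_soon_threadsafe(loop.stop)
thread.join()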

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,75 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import statistics
from typing import Any
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import AggregationFunctionType
def aggregate_accuracy(scoring_results: list[ScoringResultRow]) -> dict[str, Any]:
num_correct = sum(result["score"] for result in scoring_results)
avg_score = num_correct / len(scoring_results)
return {
"accuracy": avg_score,
"num_correct": num_correct,
"num_total": len(scoring_results),
}
def aggregate_average(scoring_results: list[ScoringResultRow]) -> dict[str, Any]:
return {
"average": sum(result["score"] for result in scoring_results if result["score"] is not None)
/ len([_ for _ in scoring_results if _["score"] is not None]),
}
def aggregate_weighted_average(scoring_results: list[ScoringResultRow]) -> dict[str, Any]:
return {
"weighted_average": sum(
result["score"] * result["weight"]
for result in scoring_results
if result["score"] is not None and result["weight"] is not None
)
/ sum(result["weight"] for result in scoring_results if result["weight"] is not None),
}
def aggregate_categorical_count(
scoring_results: list[ScoringResultRow],
) -> dict[str, Any]:
scores = [str(r["score"]) for r in scoring_results]
unique_scores = sorted(set(scores))
return {"categorical_count": {s: scores.count(s) for s in unique_scores}}
def aggregate_median(scoring_results: list[ScoringResultRow]) -> dict[str, Any]:
scores = [r["score"] for r in scoring_results if r["score"] is not None]
median = statistics.median(scores) if scores else None
return {"median": median}
# TODO: decide whether we want to make aggregation functions as a registerable resource
AGGREGATION_FUNCTIONS = {
AggregationFunctionType.accuracy: aggregate_accuracy,
AggregationFunctionType.average: aggregate_average,
AggregationFunctionType.weighted_average: aggregate_weighted_average,
AggregationFunctionType.categorical_count: aggregate_categorical_count,
AggregationFunctionType.median: aggregate_median,
}
def aggregate_metrics(
scoring_results: list[ScoringResultRow], metrics: list[AggregationFunctionType]
) -> dict[str, Any]:
agg_results = {}
for metric in metrics:
if metric not in AGGREGATION_FUNCTIONS:
raise ValueError(f"Aggregation function {metric} not found")
agg_fn = AGGREGATION_FUNCTIONS[metric]
agg_results[metric] = agg_fn(scoring_results)
return agg_results
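# A tiny worked example of the aggregations defined above, assuming rows shaped like
# {"score": ...}; the values are invented for illustration.
rows = [{"score": 1}, {"score": 0}, {"score": 1}, {"score": 1}]
print(aggregate_accuracy(rows))           # {'accuracy': 0.75, 'num_correct': 3, 'num_total': 4}
print(aggregate_average(rows))            # {'average': 0.75}
print(aggregate_categorical_count(rows))  # {'categorical_count': {'0': 1, '1': 3}}
print(aggregate_median(rows))             # {'median': 1}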

View file

@ -1,114 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import Any
from llama_stack.apis.scoring import ScoringFnParams, ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
class BaseScoringFn(ABC):
"""
Base interface class for Scoring Functions.
Each scoring function needs to implement the following methods:
    - score_row(self, row)
    - aggregate(self, scoring_fn_results)
    - score(self, input_rows)
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
def __str__(self) -> str:
return self.__class__.__name__
@abstractmethod
async def score_row(
self,
input_row: dict[str, Any],
scoring_fn_identifier: str | None = None,
scoring_params: ScoringFnParams | None = None,
) -> ScoringResultRow:
raise NotImplementedError()
@abstractmethod
async def aggregate(
self,
scoring_results: list[ScoringResultRow],
scoring_fn_identifier: str | None = None,
scoring_params: ScoringFnParams | None = None,
) -> dict[str, Any]:
raise NotImplementedError()
@abstractmethod
async def score(
self,
input_rows: list[dict[str, Any]],
scoring_fn_identifier: str | None = None,
scoring_params: ScoringFnParams | None = None,
) -> list[ScoringResultRow]:
raise NotImplementedError()
class RegisteredBaseScoringFn(BaseScoringFn):
"""
Interface for native scoring functions that are registered in LlamaStack.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {}
def __str__(self) -> str:
return self.__class__.__name__
def get_supported_scoring_fn_defs(self) -> list[ScoringFn]:
return list(self.supported_fn_defs_registry.values())
def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None:
if scoring_fn.identifier in self.supported_fn_defs_registry:
raise ValueError(f"Scoring function def with identifier {scoring_fn.identifier} already exists.")
self.supported_fn_defs_registry[scoring_fn.identifier] = scoring_fn
def unregister_scoring_fn_def(self, scoring_fn_id: str) -> None:
if scoring_fn_id not in self.supported_fn_defs_registry:
raise ValueError(f"Scoring function def with identifier {scoring_fn_id} does not exist.")
del self.supported_fn_defs_registry[scoring_fn_id]
@abstractmethod
async def score_row(
self,
input_row: dict[str, Any],
scoring_fn_identifier: str | None = None,
scoring_params: ScoringFnParams | None = None,
) -> ScoringResultRow:
raise NotImplementedError()
async def aggregate(
self,
scoring_results: list[ScoringResultRow],
scoring_fn_identifier: str | None = None,
scoring_params: ScoringFnParams | None = None,
) -> dict[str, Any]:
params = self.supported_fn_defs_registry[scoring_fn_identifier].params
if scoring_params is not None:
if params is None:
params = scoring_params
else:
params.aggregation_functions = scoring_params.aggregation_functions
aggregation_functions = []
if params and hasattr(params, "aggregation_functions") and params.aggregation_functions:
aggregation_functions.extend(params.aggregation_functions)
return aggregate_metrics(scoring_results, aggregation_functions)
async def score(
self,
input_rows: list[dict[str, Any]],
scoring_fn_identifier: str | None = None,
scoring_params: ScoringFnParams | None = None,
) -> list[ScoringResultRow]:
return [await self.score_row(input_row, scoring_fn_identifier, scoring_params) for input_row in input_rows]
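# A minimal sketch of a concrete scoring function built on the interface above; the
# exact-match rule and the "expected_answer"/"generated_answer" keys are illustrative only.
class ExactMatchScoringFn(RegisteredBaseScoringFn):
    async def score_row(
        self,
        input_row: dict[str, Any],
        scoring_fn_identifier: str | None = None,
        scoring_params: ScoringFnParams | None = None,
    ) -> ScoringResultRow:
        expected = input_row.get("expected_answer", "")
        generated = input_row.get("generated_answer", "")
        # The inherited score() fans this out over input_rows; aggregate() then applies
        # whichever aggregation functions the registered ScoringFn definition requests.
        return {"score": 1.0 if expected == generated else 0.0}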

View file

@ -1,26 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import contextlib
import signal
from collections.abc import Iterator
from types import FrameType
class TimeoutError(Exception):
pass
@contextlib.contextmanager
def time_limit(seconds: float) -> Iterator[None]:
def signal_handler(signum: int, frame: FrameType | None) -> None:
raise TimeoutError("Timed out!")
    # Install the handler before arming the timer so the alarm cannot fire while the
    # default SIGALRM action is still in place.
    signal.signal(signal.SIGALRM, signal_handler)
    signal.setitimer(signal.ITIMER_REAL, seconds)
try:
yield
finally:
signal.setitimer(signal.ITIMER_REAL, 0)
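# A short usage sketch; SIGALRM-based timers like this only work on the main thread of a
# Unix process, and TimeoutError here is the module-level class defined above.
import time

try:
    with time_limit(0.5):
        time.sleep(2)  # interrupted by the alarm after ~0.5s
except TimeoutError:
    print("timed out")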

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,128 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Any, Literal, Protocol
from pydantic import BaseModel
from llama_stack.apis.common.responses import PaginatedResponse
class ColumnType(Enum):
INTEGER = "INTEGER"
STRING = "STRING"
TEXT = "TEXT"
FLOAT = "FLOAT"
BOOLEAN = "BOOLEAN"
JSON = "JSON"
DATETIME = "DATETIME"
class ColumnDefinition(BaseModel):
type: ColumnType
primary_key: bool = False
nullable: bool = True
default: Any = None
class SqlStore(Protocol):
"""
A protocol for a SQL store.
"""
async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
"""
Create a table.
"""
pass
async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
"""
Insert a row or batch of rows into a table.
"""
pass
async def fetch_all(
self,
table: str,
where: Mapping[str, Any] | None = None,
where_sql: str | None = None,
limit: int | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
cursor: tuple[str, str] | None = None,
) -> PaginatedResponse:
"""
Fetch all rows from a table with optional cursor-based pagination.
:param table: The table name
:param where: Simple key-value WHERE conditions
:param where_sql: Raw SQL WHERE clause for complex queries
:param limit: Maximum number of records to return
:param order_by: List of (column, order) tuples for sorting
:param cursor: Tuple of (key_column, cursor_id) for pagination (None for first page)
Requires order_by with exactly one column when used
        :return: PaginatedResponse with data and has_more flag
Note: Cursor pagination only supports single-column ordering for simplicity.
Multi-column ordering is allowed without cursor but will raise an error with cursor.
"""
pass
async def fetch_one(
self,
table: str,
where: Mapping[str, Any] | None = None,
where_sql: str | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
) -> dict[str, Any] | None:
"""
Fetch one row from a table.
"""
pass
async def update(
self,
table: str,
data: Mapping[str, Any],
where: Mapping[str, Any],
) -> None:
"""
Update a row in a table.
"""
pass
async def delete(
self,
table: str,
where: Mapping[str, Any],
) -> None:
"""
Delete a row from a table.
"""
pass
async def add_column_if_not_exists(
self,
table: str,
column_name: str,
column_type: ColumnType,
nullable: bool = True,
) -> None:
"""
Add a column to an existing table if the column doesn't already exist.
This is useful for table migrations when adding new functionality.
If the table doesn't exist, this method should do nothing.
If the column already exists, this method should do nothing.
:param table: Table name
:param column_name: Name of the column to add
:param column_type: Type of the column to add
:param nullable: Whether the column should be nullable (default: True)
"""
pass
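# An illustrative sketch of driving an implementation through this protocol; `store` is
# assumed to be any concrete SqlStore, and the table/column names are invented.
async def demo(store: SqlStore) -> None:
    await store.create_table(
        "events",
        {
            "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
            "created_at": ColumnType.INTEGER,
            "payload": ColumnType.JSON,
        },
    )
    await store.insert("events", {"id": "evt_1", "created_at": 1, "payload": {"kind": "demo"}})
    # Cursor pagination requires order_by with exactly one column.
    page = await store.fetch_all(
        "events",
        order_by=[("created_at", "desc")],
        cursor=("id", "evt_1"),
        limit=10,
    )
    print(page.data, page.has_more)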

View file

@ -1,303 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
from llama_stack.core.access_control.conditions import ProtectedResource
from llama_stack.core.access_control.datatypes import AccessRule, Action, Scope
from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.core.storage.datatypes import StorageBackendType
from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
logger = get_logger(name=__name__, category="providers::utils")
# Hardcoded copy of the default policy that our SQL filtering implements
# WARNING: If default_policy() changes, this constant must be updated accordingly
# or SQL filtering will fall back to conservative mode (safe but less performant)
#
# This policy represents: "Permit all actions when user is in owners list for ALL attribute categories"
# The corresponding SQL logic is implemented in _build_default_policy_where_clause():
# - Public records (no access_attributes) are always accessible
# - Records with access_attributes require user to match ALL categories that exist in the resource
# - Missing categories in the resource are treated as "no restriction" (allow)
# - Within each category, user needs ANY matching value (OR logic)
# - Between categories, user needs ALL categories to match (AND logic)
SQL_OPTIMIZED_POLICY = [
AccessRule(
permit=Scope(actions=list(Action)),
when=["user in owners roles", "user in owners teams", "user in owners projects", "user in owners namespaces"],
),
]
def _enhance_item_with_access_control(item: Mapping[str, Any], current_user: User | None) -> Mapping[str, Any]:
"""Add access control attributes to a data item."""
enhanced = dict(item)
if current_user:
enhanced["owner_principal"] = current_user.principal
enhanced["access_attributes"] = current_user.attributes
else:
enhanced["owner_principal"] = None
enhanced["access_attributes"] = None
return enhanced
class SqlRecord(ProtectedResource):
def __init__(self, record_id: str, table_name: str, owner: User):
self.type = f"sql_record::{table_name}"
self.identifier = record_id
self.owner = owner
class AuthorizedSqlStore:
"""
Authorization layer for SqlStore that provides access control functionality.
This class composes a base SqlStore and adds authorization methods that handle
access control policies, user attribute capture, and SQL filtering optimization.
"""
def __init__(self, sql_store: SqlStore, policy: list[AccessRule]):
"""
Initialize the authorization layer.
:param sql_store: Base SqlStore implementation to wrap
:param policy: Access control policy to use for authorization
"""
self.sql_store = sql_store
self.policy = policy
self._detect_database_type()
self._validate_sql_optimized_policy()
def _detect_database_type(self) -> None:
"""Detect the database type from the underlying SQL store."""
if not hasattr(self.sql_store, "config"):
raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore")
self.database_type = self.sql_store.config.type.value
if self.database_type not in [StorageBackendType.SQL_POSTGRES.value, StorageBackendType.SQL_SQLITE.value]:
raise ValueError(f"Unsupported database type: {self.database_type}")
def _validate_sql_optimized_policy(self) -> None:
"""Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().
This ensures that if default_policy() changes, we detect the mismatch and
can update our SQL filtering logic accordingly.
"""
actual_default = default_policy()
if SQL_OPTIMIZED_POLICY != actual_default:
logger.warning(
f"SQL_OPTIMIZED_POLICY does not match default_policy(). "
f"SQL filtering will use conservative mode. "
f"Expected: {SQL_OPTIMIZED_POLICY}, Got: {actual_default}",
)
async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
"""Create a table with built-in access control support."""
enhanced_schema = dict(schema)
if "access_attributes" not in enhanced_schema:
enhanced_schema["access_attributes"] = ColumnType.JSON
if "owner_principal" not in enhanced_schema:
enhanced_schema["owner_principal"] = ColumnType.STRING
await self.sql_store.create_table(table, enhanced_schema)
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)
async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
"""Insert a row or batch of rows with automatic access control attribute capture."""
current_user = get_authenticated_user()
enhanced_data: Mapping[str, Any] | Sequence[Mapping[str, Any]]
if isinstance(data, Mapping):
enhanced_data = _enhance_item_with_access_control(data, current_user)
else:
enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data]
await self.sql_store.insert(table, enhanced_data)
async def fetch_all(
self,
table: str,
where: Mapping[str, Any] | None = None,
limit: int | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
cursor: tuple[str, str] | None = None,
) -> PaginatedResponse:
"""Fetch all rows with automatic access control filtering."""
access_where = self._build_access_control_where_clause(self.policy)
rows = await self.sql_store.fetch_all(
table=table,
where=where,
where_sql=access_where,
limit=limit,
order_by=order_by,
cursor=cursor,
)
current_user = get_authenticated_user()
filtered_rows = []
for row in rows.data:
stored_access_attrs = row.get("access_attributes")
stored_owner_principal = row.get("owner_principal") or ""
record_id = row.get("id", "unknown")
sql_record = SqlRecord(
str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
)
if is_action_allowed(self.policy, Action.READ, sql_record, current_user):
filtered_rows.append(row)
return PaginatedResponse(
data=filtered_rows,
has_more=rows.has_more,
)
async def fetch_one(
self,
table: str,
where: Mapping[str, Any] | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
) -> dict[str, Any] | None:
"""Fetch one row with automatic access control checking."""
results = await self.fetch_all(
table=table,
where=where,
limit=1,
order_by=order_by,
)
return results.data[0] if results.data else None
async def update(self, table: str, data: Mapping[str, Any], where: Mapping[str, Any]) -> None:
"""Update rows with automatic access control attribute capture."""
enhanced_data = dict(data)
current_user = get_authenticated_user()
if current_user:
enhanced_data["owner_principal"] = current_user.principal
enhanced_data["access_attributes"] = current_user.attributes
else:
enhanced_data["owner_principal"] = None
enhanced_data["access_attributes"] = None
await self.sql_store.update(table, enhanced_data, where)
async def delete(self, table: str, where: Mapping[str, Any]) -> None:
"""Delete rows with automatic access control filtering."""
await self.sql_store.delete(table, where)
def _build_access_control_where_clause(self, policy: list[AccessRule]) -> str:
"""Build SQL WHERE clause for access control filtering.
Only applies SQL filtering for the default policy to ensure correctness.
For custom policies, uses conservative filtering to avoid blocking legitimate access.
"""
current_user = get_authenticated_user()
if not policy or policy == SQL_OPTIMIZED_POLICY:
return self._build_default_policy_where_clause(current_user)
else:
return self._build_conservative_where_clause()
def _json_extract(self, column: str, path: str) -> str:
"""Extract JSON value (keeping JSON type).
Args:
column: The JSON column name
path: The JSON path (e.g., 'roles', 'teams')
Returns:
SQL expression to extract JSON value
"""
if self.database_type == StorageBackendType.SQL_POSTGRES.value:
return f"{column}->'{path}'"
elif self.database_type == StorageBackendType.SQL_SQLITE.value:
return f"JSON_EXTRACT({column}, '$.{path}')"
else:
raise ValueError(f"Unsupported database type: {self.database_type}")
def _json_extract_text(self, column: str, path: str) -> str:
"""Extract JSON value as text.
Args:
column: The JSON column name
path: The JSON path (e.g., 'roles', 'teams')
Returns:
SQL expression to extract JSON value as text
"""
if self.database_type == StorageBackendType.SQL_POSTGRES.value:
return f"{column}->>'{path}'"
elif self.database_type == StorageBackendType.SQL_SQLITE.value:
return f"JSON_EXTRACT({column}, '$.{path}')"
else:
raise ValueError(f"Unsupported database type: {self.database_type}")
def _get_public_access_conditions(self) -> list[str]:
"""Get the SQL conditions for public access."""
# Public records are records that have no owner_principal or access_attributes
conditions = ["owner_principal = ''"]
if self.database_type == StorageBackendType.SQL_POSTGRES.value:
# Postgres stores JSON null as 'null'
conditions.append("access_attributes::text = 'null'")
elif self.database_type == StorageBackendType.SQL_SQLITE.value:
conditions.append("access_attributes = 'null'")
else:
raise ValueError(f"Unsupported database type: {self.database_type}")
return conditions
def _build_default_policy_where_clause(self, current_user: User | None) -> str:
"""Build SQL WHERE clause for the default policy.
Default policy: permit all actions when user in owners [roles, teams, projects, namespaces]
This means user must match ALL attribute categories that exist in the resource.
"""
base_conditions = self._get_public_access_conditions()
user_attr_conditions = []
if current_user and current_user.attributes:
for attr_key, user_values in current_user.attributes.items():
if user_values:
value_conditions = []
for value in user_values:
# Check if JSON array contains the value
escaped_value = value.replace("'", "''")
json_text = self._json_extract_text("access_attributes", attr_key)
value_conditions.append(f"({json_text} LIKE '%\"{escaped_value}\"%')")
if value_conditions:
# Check if the category is missing (NULL)
category_missing = f"{self._json_extract('access_attributes', attr_key)} IS NULL"
user_matches_category = f"({' OR '.join(value_conditions)})"
user_attr_conditions.append(f"({category_missing} OR {user_matches_category})")
if user_attr_conditions:
all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
base_conditions.append(all_requirements_met)
return f"({' OR '.join(base_conditions)})"
def _build_conservative_where_clause(self) -> str:
"""Conservative SQL filtering for custom policies.
Only filters records we're 100% certain would be denied by any reasonable policy.
"""
current_user = get_authenticated_user()
if not current_user:
# Only allow public records
base_conditions = self._get_public_access_conditions()
return f"({' OR '.join(base_conditions)})"
return "1=1"

View file

@ -1,313 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from sqlalchemy import (
JSON,
Boolean,
Column,
DateTime,
Float,
Integer,
MetaData,
String,
Table,
Text,
inspect,
select,
text,
)
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.ext.asyncio.engine import AsyncEngine
from sqlalchemy.sql.elements import ColumnElement
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.core.storage.datatypes import SqlAlchemySqlStoreConfig
from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, SqlStore
logger = get_logger(name=__name__, category="providers::utils")
TYPE_MAPPING: dict[ColumnType, Any] = {
ColumnType.INTEGER: Integer,
ColumnType.STRING: String,
ColumnType.FLOAT: Float,
ColumnType.BOOLEAN: Boolean,
ColumnType.DATETIME: DateTime,
ColumnType.TEXT: Text,
ColumnType.JSON: JSON,
}
def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
"""Return a SQLAlchemy expression for a where condition.
`value` may be a simple scalar (equality) or a mapping like {">": 123}.
The returned expression is a SQLAlchemy ColumnElement usable in query.where(...).
"""
if isinstance(value, Mapping):
if len(value) != 1:
raise ValueError(f"Operator mapping must have a single operator, got: {value}")
op, operand = next(iter(value.items()))
if op == "==" or op == "=":
return column == operand
if op == ">":
return column > operand
if op == "<":
return column < operand
if op == ">=":
return column >= operand
if op == "<=":
return column <= operand
raise ValueError(f"Unsupported operator '{op}' in where mapping")
return column == value
class SqlAlchemySqlStoreImpl(SqlStore):
def __init__(self, config: SqlAlchemySqlStoreConfig):
self.config = config
self.async_session = async_sessionmaker(self.create_engine())
self.metadata = MetaData()
def create_engine(self) -> AsyncEngine:
return create_async_engine(self.config.engine_str, pool_pre_ping=True)
async def create_table(
self,
table: str,
schema: Mapping[str, ColumnType | ColumnDefinition],
) -> None:
if not schema:
raise ValueError(f"No columns defined for table '{table}'.")
sqlalchemy_columns: list[Column] = []
for col_name, col_props in schema.items():
col_type = None
is_primary_key = False
is_nullable = True
if isinstance(col_props, ColumnType):
col_type = col_props
elif isinstance(col_props, ColumnDefinition):
col_type = col_props.type
is_primary_key = col_props.primary_key
is_nullable = col_props.nullable
sqlalchemy_type = TYPE_MAPPING.get(col_type)
if not sqlalchemy_type:
raise ValueError(f"Unsupported column type '{col_type}' for column '{col_name}'.")
sqlalchemy_columns.append(
Column(col_name, sqlalchemy_type, primary_key=is_primary_key, nullable=is_nullable)
)
if table not in self.metadata.tables:
sqlalchemy_table = Table(table, self.metadata, *sqlalchemy_columns)
else:
sqlalchemy_table = self.metadata.tables[table]
engine = self.create_engine()
async with engine.begin() as conn:
await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
async with self.async_session() as session:
await session.execute(self.metadata.tables[table].insert(), data)
await session.commit()
async def fetch_all(
self,
table: str,
where: Mapping[str, Any] | None = None,
where_sql: str | None = None,
limit: int | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
cursor: tuple[str, str] | None = None,
) -> PaginatedResponse:
async with self.async_session() as session:
table_obj = self.metadata.tables[table]
query = select(table_obj)
if where:
for key, value in where.items():
query = query.where(_build_where_expr(table_obj.c[key], value))
if where_sql:
query = query.where(text(where_sql))
# Handle cursor-based pagination
if cursor:
# Validate cursor tuple format
if not isinstance(cursor, tuple) or len(cursor) != 2:
raise ValueError(f"Cursor must be a tuple of (key_column, cursor_id), got: {cursor}")
# Require order_by for cursor pagination
if not order_by:
raise ValueError("order_by is required when using cursor pagination")
# Only support single-column ordering for cursor pagination
if len(order_by) != 1:
raise ValueError(
f"Cursor pagination only supports single-column ordering, got {len(order_by)} columns"
)
cursor_key_column, cursor_id = cursor
order_column, order_direction = order_by[0]
# Verify cursor_key_column exists
if cursor_key_column not in table_obj.c:
raise ValueError(f"Cursor key column '{cursor_key_column}' not found in table '{table}'")
# Get cursor value for the order column
cursor_query = select(table_obj.c[order_column]).where(table_obj.c[cursor_key_column] == cursor_id)
cursor_result = await session.execute(cursor_query)
cursor_row = cursor_result.fetchone()
if not cursor_row:
raise ValueError(f"Record with {cursor_key_column}='{cursor_id}' not found in table '{table}'")
cursor_value = cursor_row[0]
# Apply cursor condition based on sort direction
if order_direction == "desc":
query = query.where(table_obj.c[order_column] < cursor_value)
else:
query = query.where(table_obj.c[order_column] > cursor_value)
# Apply ordering
if order_by:
if not isinstance(order_by, list):
raise ValueError(
f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}"
)
for order in order_by:
if not isinstance(order, tuple):
raise ValueError(
f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}"
)
name, order_type = order
if name not in table_obj.c:
raise ValueError(f"Column '{name}' not found in table '{table}'")
if order_type == "asc":
query = query.order_by(table_obj.c[name].asc())
elif order_type == "desc":
query = query.order_by(table_obj.c[name].desc())
else:
raise ValueError(f"Invalid order '{order_type}' for column '{name}'")
# Fetch limit + 1 to determine has_more
fetch_limit = limit
if limit:
fetch_limit = limit + 1
if fetch_limit:
query = query.limit(fetch_limit)
result = await session.execute(query)
if result.rowcount == 0:
rows = []
else:
rows = [dict(row._mapping) for row in result]
# Always return pagination result
has_more = False
if limit and len(rows) > limit:
has_more = True
rows = rows[:limit]
return PaginatedResponse(data=rows, has_more=has_more)
async def fetch_one(
self,
table: str,
where: Mapping[str, Any] | None = None,
where_sql: str | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
) -> dict[str, Any] | None:
result = await self.fetch_all(table, where, where_sql, limit=1, order_by=order_by)
if not result.data:
return None
return result.data[0]
async def update(
self,
table: str,
data: Mapping[str, Any],
where: Mapping[str, Any],
) -> None:
if not where:
raise ValueError("where is required for update")
async with self.async_session() as session:
stmt = self.metadata.tables[table].update()
for key, value in where.items():
stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value))
await session.execute(stmt, data)
await session.commit()
async def delete(self, table: str, where: Mapping[str, Any]) -> None:
if not where:
raise ValueError("where is required for delete")
async with self.async_session() as session:
stmt = self.metadata.tables[table].delete()
for key, value in where.items():
stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value))
await session.execute(stmt)
await session.commit()
async def add_column_if_not_exists(
self,
table: str,
column_name: str,
column_type: ColumnType,
nullable: bool = True,
) -> None:
"""Add a column to an existing table if the column doesn't already exist."""
engine = self.create_engine()
try:
async with engine.begin() as conn:
def check_column_exists(sync_conn):
inspector = inspect(sync_conn)
table_names = inspector.get_table_names()
if table not in table_names:
return False, False # table doesn't exist, column doesn't exist
existing_columns = inspector.get_columns(table)
column_names = [col["name"] for col in existing_columns]
return True, column_name in column_names # table exists, column exists or not
table_exists, column_exists = await conn.run_sync(check_column_exists)
if not table_exists or column_exists:
return
sqlalchemy_type = TYPE_MAPPING.get(column_type)
if not sqlalchemy_type:
raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.")
# Create the ALTER TABLE statement
# Note: We need to get the dialect-specific type name
dialect = engine.dialect
type_impl = sqlalchemy_type()
compiled_type = type_impl.compile(dialect=dialect)
nullable_clause = "" if nullable else " NOT NULL"
add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
await conn.execute(add_column_sql)
except Exception as e:
# If any error occurs during migration, log it but don't fail
# The table creation will handle adding the column
logger.error(f"Error adding column {column_name} to table {table}: {e}")
pass
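# An illustrative query using the operator-mapping form accepted by _build_where_expr:
# plain values mean equality, while a one-entry mapping such as {">=": 1_700_000_000}
# selects a comparison operator. The model name and timestamp are invented.
async def recent_responses(store: SqlAlchemySqlStoreImpl) -> None:
    page = await store.fetch_all(
        table="openai_responses",
        where={"model": "llama-3", "created_at": {">=": 1_700_000_000}},
        order_by=[("created_at", "desc")],
        limit=20,
    )
    print(page.data, page.has_more)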

View file

@ -1,70 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated, cast
from pydantic import Field
from llama_stack.core.storage.datatypes import (
PostgresSqlStoreConfig,
SqliteSqlStoreConfig,
SqlStoreReference,
StorageBackendConfig,
StorageBackendType,
)
from .api import SqlStore
sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
_SQLSTORE_BACKENDS: dict[str, StorageBackendConfig] = {}
SqlStoreConfig = Annotated[
SqliteSqlStoreConfig | PostgresSqlStoreConfig,
Field(discriminator="type"),
]
def get_pip_packages(store_config: dict | SqlStoreConfig) -> list[str]:
"""Get pip packages for SQL store config, handling both dict and object cases."""
if isinstance(store_config, dict):
store_type = store_config.get("type")
if store_type == StorageBackendType.SQL_SQLITE.value:
return SqliteSqlStoreConfig.pip_packages()
elif store_type == StorageBackendType.SQL_POSTGRES.value:
return PostgresSqlStoreConfig.pip_packages()
else:
raise ValueError(f"Unknown SQL store type: {store_type}")
else:
return store_config.pip_packages()
def sqlstore_impl(reference: SqlStoreReference) -> SqlStore:
backend_name = reference.backend
backend_config = _SQLSTORE_BACKENDS.get(backend_name)
if backend_config is None:
raise ValueError(
f"Unknown SQL store backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
)
if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
return SqlAlchemySqlStoreImpl(config)
else:
raise ValueError(f"Unknown sqlstore type {backend_config.type}")
def register_sqlstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
"""Register the set of available SQL store backends for reference resolution."""
global _SQLSTORE_BACKENDS
_SQLSTORE_BACKENDS.clear()
for name, cfg in backends.items():
_SQLSTORE_BACKENDS[name] = cfg
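# A hedged wiring sketch: `backend_configs` is assumed to be the backend mapping loaded
# from the stack's storage configuration, and `reference` an existing SqlStoreReference
# whose `backend` field names one of those entries; neither object is constructed here.
register_sqlstore_backends(backend_configs)  # typically done once at startup
store = sqlstore_impl(reference)             # resolves reference.backend against the registry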

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,148 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from enum import Enum
from typing import Any, cast
import httpx
from mcp import ClientSession, McpError
from mcp import types as mcp_types
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem
from llama_stack.apis.tools import (
ListToolDefsResponse,
ToolDef,
ToolInvocationResult,
)
from llama_stack.core.datatypes import AuthenticationRequiredError
from llama_stack.log import get_logger
from llama_stack.providers.utils.tools.ttl_dict import TTLDict
logger = get_logger(__name__, category="tools")
protocol_cache = TTLDict(ttl_seconds=3600)
class MCPProtol(Enum):
UNKNOWN = 0
STREAMABLE_HTTP = 1
SSE = 2
@asynccontextmanager
async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
# we use a ttl'd dict to cache the happy path protocol for each endpoint
# but, we always fall back to trying the other protocol if we cannot initialize the session
connection_strategies = [MCPProtol.STREAMABLE_HTTP, MCPProtol.SSE]
mcp_protocol = protocol_cache.get(endpoint, default=MCPProtol.UNKNOWN)
if mcp_protocol == MCPProtol.SSE:
connection_strategies = [MCPProtol.SSE, MCPProtol.STREAMABLE_HTTP]
for i, strategy in enumerate(connection_strategies):
try:
client = streamablehttp_client
if strategy == MCPProtol.SSE:
client = sse_client
async with client(endpoint, headers=headers) as client_streams:
async with ClientSession(read_stream=client_streams[0], write_stream=client_streams[1]) as session:
await session.initialize()
protocol_cache[endpoint] = strategy
yield session
return
except* httpx.HTTPStatusError as eg:
for exc in eg.exceptions:
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
err = cast(httpx.HTTPStatusError, exc)
if err.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc
if i == len(connection_strategies) - 1:
raise
except* httpx.ConnectError as eg:
# Connection refused, server down, network unreachable
if i == len(connection_strategies) - 1:
error_msg = f"Failed to connect to MCP server at {endpoint}: Connection refused"
logger.error(f"MCP connection error: {error_msg}")
raise ConnectionError(error_msg) from eg
else:
logger.warning(
f"failed to connect to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
)
except* httpx.TimeoutException as eg:
# Request timeout, server too slow
if i == len(connection_strategies) - 1:
error_msg = f"MCP server at {endpoint} timed out"
logger.error(f"MCP timeout error: {error_msg}")
raise TimeoutError(error_msg) from eg
else:
logger.warning(
f"MCP server at {endpoint} timed out via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
)
except* httpx.RequestError as eg:
# DNS resolution failures, network errors, invalid URLs
if i == len(connection_strategies) - 1:
# Get the first exception's message for the error string
exc_msg = str(eg.exceptions[0]) if eg.exceptions else "Unknown error"
error_msg = f"Network error connecting to MCP server at {endpoint}: {exc_msg}"
logger.error(f"MCP network error: {error_msg}")
raise ConnectionError(error_msg) from eg
else:
logger.warning(
f"network error connecting to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
)
except* McpError:
if i < len(connection_strategies) - 1:
logger.warning(
f"failed to connect via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
)
else:
raise
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
tools = []
async with client_wrapper(endpoint, headers) as session:
tools_result = await session.list_tools()
for tool in tools_result.tools:
tools.append(
ToolDef(
name=tool.name,
description=tool.description,
input_schema=tool.inputSchema,
output_schema=getattr(tool, "outputSchema", None),
metadata={
"endpoint": endpoint,
},
)
)
return ListToolDefsResponse(data=tools)
async def invoke_mcp_tool(
endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
) -> ToolInvocationResult:
async with client_wrapper(endpoint, headers) as session:
result = await session.call_tool(tool_name, kwargs)
content: list[InterleavedContentItem] = []
for item in result.content:
if isinstance(item, mcp_types.TextContent):
content.append(TextContentItem(text=item.text))
elif isinstance(item, mcp_types.ImageContent):
content.append(ImageContentItem(image=item.data))
elif isinstance(item, mcp_types.EmbeddedResource):
logger.warning(f"EmbeddedResource is not supported: {item}")
else:
raise ValueError(f"Unknown content type: {type(item)}")
return ToolInvocationResult(
content=content,
error_code=1 if result.isError else 0,
)
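# A brief usage sketch of the helpers above; the endpoint URL, bearer token, tool name,
# and arguments are illustrative, and the calls assume a reachable MCP server.
import asyncio


async def demo_mcp() -> None:
    endpoint = "http://localhost:8000/mcp"
    headers = {"Authorization": "Bearer <token>"}
    tool_defs = await list_mcp_tools(endpoint, headers)
    for tool in tool_defs.data:
        print(tool.name, tool.description)
    result = await invoke_mcp_tool(endpoint, headers, tool_name="search", kwargs={"query": "llama"})
    print(result.error_code, result.content)


asyncio.run(demo_mcp())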

View file

@ -1,70 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
from threading import RLock
from typing import Any
class TTLDict(dict):
"""
A dictionary with a ttl for each item
"""
def __init__(self, ttl_seconds: float, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ttl_seconds = ttl_seconds
self._expires: dict[Any, Any] = {} # expires holds when an item will expire
self._lock = RLock()
if args or kwargs:
for k, v in self.items():
self.__setitem__(k, v)
def __delitem__(self, key):
with self._lock:
del self._expires[key]
super().__delitem__(key)
def __setitem__(self, key, value):
with self._lock:
self._expires[key] = time.monotonic() + self.ttl_seconds
super().__setitem__(key, value)
def _is_expired(self, key):
if key not in self._expires:
return False
return time.monotonic() > self._expires[key]
def __getitem__(self, key):
with self._lock:
if self._is_expired(key):
del self._expires[key]
super().__delitem__(key)
raise KeyError(f"{key} has expired and was removed")
return super().__getitem__(key)
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def __contains__(self, key):
try:
_ = self[key]
return True
except KeyError:
return False
def __repr__(self):
with self._lock:
            for key in list(self.keys()):  # snapshot keys; entries are deleted during iteration
if self._is_expired(key):
del self._expires[key]
super().__delitem__(key)
return f"TTLDict({self.ttl_seconds}, {super().__repr__()})"

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,156 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import hashlib
import re
import uuid
def generate_chunk_id(document_id: str, chunk_text: str, chunk_window: str | None = None) -> str:
"""
Generate a unique chunk ID using a hash of the document ID and chunk text.
Then use the first 32 characters of the hash to create a UUID.
"""
hash_input = f"{document_id}:{chunk_text}".encode()
if chunk_window:
hash_input += f":{chunk_window}".encode()
return str(uuid.UUID(hashlib.sha256(hash_input).hexdigest()[:32]))
def proper_case(s: str) -> str:
"""Convert a string to proper case (first letter uppercase, rest lowercase)."""
return s[0].upper() + s[1:].lower() if s else s
def sanitize_collection_name(name: str, weaviate_format=False) -> str:
"""
Sanitize collection name to ensure it only contains numbers, letters, and underscores.
Any other characters are replaced with underscores.
"""
if not weaviate_format:
s = re.sub(r"[^a-zA-Z0-9_]", "_", name)
else:
s = proper_case(re.sub(r"[^a-zA-Z0-9]", "", name))
return s
class WeightedInMemoryAggregator:
@staticmethod
def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
"""
Normalize scores to 0-1 range using min-max normalization.
Args:
scores: dictionary of scores with document IDs as keys and scores as values
Returns:
Normalized scores with document IDs as keys and normalized scores as values
"""
if not scores:
return {}
min_score, max_score = min(scores.values()), max(scores.values())
score_range = max_score - min_score
if score_range > 0:
return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()}
return dict.fromkeys(scores, 1.0)
@staticmethod
def weighted_rerank(
vector_scores: dict[str, float],
keyword_scores: dict[str, float],
alpha: float = 0.5,
) -> dict[str, float]:
"""
Rerank via weighted average of scores.
Args:
vector_scores: scores from vector search
keyword_scores: scores from keyword search
alpha: weight factor between 0 and 1 (default: 0.5)
0 = keyword only, 1 = vector only, 0.5 = equal weight
Returns:
All unique document IDs with weighted combined scores
"""
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
normalized_vector_scores = WeightedInMemoryAggregator._normalize_scores(vector_scores)
normalized_keyword_scores = WeightedInMemoryAggregator._normalize_scores(keyword_scores)
# Weighted formula: score = (1-alpha) * keyword_score + alpha * vector_score
# alpha=0 means keyword only, alpha=1 means vector only
return {
doc_id: ((1 - alpha) * normalized_keyword_scores.get(doc_id, 0.0))
+ (alpha * normalized_vector_scores.get(doc_id, 0.0))
for doc_id in all_ids
}
@staticmethod
def rrf_rerank(
vector_scores: dict[str, float],
keyword_scores: dict[str, float],
impact_factor: float = 60.0,
) -> dict[str, float]:
"""
Rerank via Reciprocal Rank Fusion.
Args:
vector_scores: scores from vector search
keyword_scores: scores from keyword search
impact_factor: impact factor for RRF (default: 60.0)
Returns:
All unique document IDs with RRF combined scores
"""
# Convert scores to ranks
vector_ranks = {
doc_id: i + 1
for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True))
}
keyword_ranks = {
doc_id: i + 1
for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True))
}
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
rrf_scores = {}
for doc_id in all_ids:
vector_rank = vector_ranks.get(doc_id, float("inf"))
keyword_rank = keyword_ranks.get(doc_id, float("inf"))
# RRF formula: score = 1/(k + r) where k is impact_factor (default: 60.0) and r is the rank
rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank))
return rrf_scores
@staticmethod
def combine_search_results(
vector_scores: dict[str, float],
keyword_scores: dict[str, float],
reranker_type: str = "rrf",
reranker_params: dict[str, float] | None = None,
) -> dict[str, float]:
"""
Combine vector and keyword search results using specified reranking strategy.
Args:
vector_scores: scores from vector search
keyword_scores: scores from keyword search
reranker_type: type of reranker to use (default: RERANKER_TYPE_RRF)
reranker_params: parameters for the reranker
Returns:
All unique document IDs with combined scores
"""
if reranker_params is None:
reranker_params = {}
if reranker_type == "weighted":
alpha = reranker_params.get("alpha", 0.5)
return WeightedInMemoryAggregator.weighted_rerank(vector_scores, keyword_scores, alpha)
else:
# Default to RRF for None, RRF, or any unknown types
impact_factor = reranker_params.get("impact_factor", 60.0)
return WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor)
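# A small worked example of the two fusion strategies above; the scores are invented.
vector_scores = {"doc-a": 0.9, "doc-b": 0.4}
keyword_scores = {"doc-b": 12.0, "doc-c": 7.0}
print(WeightedInMemoryAggregator.weighted_rerank(vector_scores, keyword_scores, alpha=0.5))
# doc-a appears only in the vector ranking, doc-c only in the keyword ranking, doc-b in both.
print(WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor=60.0))
# RRF scores each document by 1/(60 + rank) per ranking, so doc-b (ranked in both lists) wins.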