mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-04 02:03:44 +00:00)
chore(package): migrate to src/ layout (#3920)
Migrates package structure to src/ layout following Python packaging best practices. All code moved from `llama_stack/` to `src/llama_stack/`. Public API unchanged - imports remain `import llama_stack.*`. Updated build configs, pre-commit hooks, scripts, and GitHub workflows accordingly. All hooks pass, package builds cleanly. **Developer note**: Reinstall after pulling: `pip install -e .`
This commit is contained in:
parent 98a5047f9d
commit 471b1b248b
791 changed files with 2983 additions and 456 deletions
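As context for the per-file diff below, here is a minimal sketch of the kind of build-config change a src/ layout requires. Hatchling is assumed purely for illustration; the repository's actual `pyproject.toml` settings may differ.

```toml
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
# Code now lives under src/, but the import name stays `llama_stack`.
packages = ["src/llama_stack"]
```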
@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
@@ -1,74 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import boto3
from botocore.client import BaseClient
from botocore.config import Config

from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
from llama_stack.providers.utils.bedrock.refreshable_boto_session import (
    RefreshableBotoSession,
)


def create_bedrock_client(config: BedrockBaseConfig, service_name: str = "bedrock-runtime") -> BaseClient:
    """Creates a boto3 client for Bedrock services with the given configuration.

    Args:
        config: The Bedrock configuration containing AWS credentials and settings
        service_name: The AWS service name to create client for (default: "bedrock-runtime")

    Returns:
        A configured boto3 client
    """
    if config.aws_access_key_id and config.aws_secret_access_key:
        retries_config = {
            k: v
            for k, v in dict(
                total_max_attempts=config.total_max_attempts,
                mode=config.retry_mode,
            ).items()
            if v is not None
        }

        config_args = {
            k: v
            for k, v in dict(
                region_name=config.region_name,
                retries=retries_config if retries_config else None,
                connect_timeout=config.connect_timeout,
                read_timeout=config.read_timeout,
            ).items()
            if v is not None
        }

        boto3_config = Config(**config_args)

        session_args = {
            "aws_access_key_id": config.aws_access_key_id,
            "aws_secret_access_key": config.aws_secret_access_key,
            "aws_session_token": config.aws_session_token,
            "region_name": config.region_name,
            "profile_name": config.profile_name,
            "session_ttl": config.session_ttl,
        }

        # Remove None values
        session_args = {k: v for k, v in session_args.items() if v is not None}

        boto3_session = boto3.session.Session(**session_args)
        return boto3_session.client(service_name, config=boto3_config)
    else:
        return (
            RefreshableBotoSession(
                region_name=config.region_name,
                profile_name=config.profile_name,
                session_ttl=config.session_ttl,
            )
            .refreshable_session()
            .client(service_name)
        )
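A hedged usage sketch for the helper above. The `client` module path is assumed from this file's sibling imports; credentials come from the standard AWS environment variables via `BedrockBaseConfig` (defined in the next hunk).

```python
from llama_stack.providers.utils.bedrock.client import create_bedrock_client  # assumed module path
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig

# Field defaults read AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_DEFAULT_REGION, etc.
config = BedrockBaseConfig()
client = create_bedrock_client(config)  # service_name defaults to "bedrock-runtime"
```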
@@ -1,64 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os

from pydantic import Field

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig


class BedrockBaseConfig(RemoteInferenceProviderConfig):
    auth_credential: None = Field(default=None, exclude=True)
    aws_access_key_id: str | None = Field(
        default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
        description="The AWS access key to use. Defaults to environment variable: AWS_ACCESS_KEY_ID",
    )
    aws_secret_access_key: str | None = Field(
        default_factory=lambda: os.getenv("AWS_SECRET_ACCESS_KEY"),
        description="The AWS secret access key to use. Defaults to environment variable: AWS_SECRET_ACCESS_KEY",
    )
    aws_session_token: str | None = Field(
        default_factory=lambda: os.getenv("AWS_SESSION_TOKEN"),
        description="The AWS session token to use. Defaults to environment variable: AWS_SESSION_TOKEN",
    )
    region_name: str | None = Field(
        default_factory=lambda: os.getenv("AWS_DEFAULT_REGION"),
        description="The default AWS Region to use, for example, us-west-1 or us-west-2. "
        "Defaults to environment variable: AWS_DEFAULT_REGION",
    )
    profile_name: str | None = Field(
        default_factory=lambda: os.getenv("AWS_PROFILE"),
        description="The profile name that contains credentials to use. Defaults to environment variable: AWS_PROFILE",
    )
    total_max_attempts: int | None = Field(
        default_factory=lambda: int(val) if (val := os.getenv("AWS_MAX_ATTEMPTS")) else None,
        description="An integer representing the maximum number of attempts that will be made for a single request, "
        "including the initial attempt. Defaults to environment variable: AWS_MAX_ATTEMPTS",
    )
    retry_mode: str | None = Field(
        default_factory=lambda: os.getenv("AWS_RETRY_MODE"),
        description="A string representing the type of retries Boto3 will perform. "
        "Defaults to environment variable: AWS_RETRY_MODE",
    )
    connect_timeout: float | None = Field(
        default_factory=lambda: float(os.getenv("AWS_CONNECT_TIMEOUT", "60")),
        description="The time in seconds till a timeout exception is thrown when attempting to make a connection. "
        "The default is 60 seconds.",
    )
    read_timeout: float | None = Field(
        default_factory=lambda: float(os.getenv("AWS_READ_TIMEOUT", "60")),
        description="The time in seconds till a timeout exception is thrown when attempting to read from a connection. "
        "The default is 60 seconds.",
    )
    session_ttl: int | None = Field(
        default_factory=lambda: int(os.getenv("AWS_SESSION_TTL", "3600")),
        description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).",
    )

    @classmethod
    def sample_run_config(cls, **kwargs):
        return {}
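A small sketch of how the environment-driven defaults above behave; the environment values are illustrative.

```python
import os

os.environ["AWS_DEFAULT_REGION"] = "us-west-2"
os.environ["AWS_MAX_ATTEMPTS"] = "5"

config = BedrockBaseConfig()  # each default_factory runs at instantiation time
assert config.region_name == "us-west-2"
assert config.total_max_attempts == 5
assert config.connect_timeout == 60.0  # falls back to the documented default
```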
@@ -1,112 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import datetime
from time import time
from uuid import uuid4

from boto3 import Session
from botocore.credentials import RefreshableCredentials
from botocore.session import get_session


class RefreshableBotoSession:
    """
    Boto helper class which lets us create a refreshable session so that we can cache the client or resource.

    Usage
    -----
    session = RefreshableBotoSession().refreshable_session()

    client = session.client("s3")  # we can now cache this client object without worrying about expiring credentials
    """

    def __init__(
        self,
        region_name: str = None,
        profile_name: str = None,
        sts_arn: str = None,
        session_name: str = None,
        session_ttl: int = 30000,
    ):
        """
        Initialize `RefreshableBotoSession`

        Parameters
        ----------
        region_name : str (optional)
            Default region when creating a new connection.

        profile_name : str (optional)
            The name of a profile to use.

        sts_arn : str (optional)
            The role arn to sts before creating a session.

        session_name : str (optional)
            An identifier for the assumed role session. (required when `sts_arn` is given)

        session_ttl : int (optional)
            An integer number to set the TTL for each session. Beyond this TTL, it will renew the token.
            50 minutes by default, which is before the default role expiration of 1 hour.
        """

        self.region_name = region_name
        self.profile_name = profile_name
        self.sts_arn = sts_arn
        self.session_name = session_name or uuid4().hex
        self.session_ttl = session_ttl

    def __get_session_credentials(self):
        """
        Get session credentials
        """
        session = Session(region_name=self.region_name, profile_name=self.profile_name)

        # if sts_arn is given, get credentials by assuming the given role
        if self.sts_arn:
            sts_client = session.client(service_name="sts", region_name=self.region_name)
            response = sts_client.assume_role(
                RoleArn=self.sts_arn,
                RoleSessionName=self.session_name,
                DurationSeconds=self.session_ttl,
            ).get("Credentials")

            credentials = {
                "access_key": response.get("AccessKeyId"),
                "secret_key": response.get("SecretAccessKey"),
                "token": response.get("SessionToken"),
                "expiry_time": response.get("Expiration").isoformat(),
            }
        else:
            session_credentials = session.get_credentials().get_frozen_credentials()
            credentials = {
                "access_key": session_credentials.access_key,
                "secret_key": session_credentials.secret_key,
                "token": session_credentials.token,
                "expiry_time": datetime.datetime.fromtimestamp(time() + self.session_ttl, datetime.UTC).isoformat(),
            }

        return credentials

    def refreshable_session(self) -> Session:
        """
        Get refreshable boto3 session.
        """
        # Get refreshable credentials
        refreshable_credentials = RefreshableCredentials.create_from_metadata(
            metadata=self.__get_session_credentials(),
            refresh_using=self.__get_session_credentials,
            method="sts-assume-role",
        )

        # attach refreshable credentials to the current session
        session = get_session()
        session._credentials = refreshable_credentials
        session.set_config_variable("region", self.region_name)
        autorefresh_session = Session(botocore_session=session)

        return autorefresh_session
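Beyond the Usage docstring above, a hedged sketch of the assume-role path; the role ARN is hypothetical.

```python
session = RefreshableBotoSession(
    region_name="us-east-1",
    sts_arn="arn:aws:iam::123456789012:role/bedrock-invoke",  # hypothetical role
    session_ttl=3600,
).refreshable_session()

bedrock = session.client("bedrock-runtime")  # safe to cache; credentials auto-refresh
```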
@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
@@ -1,103 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum
from typing import Any

from llama_stack.apis.common.type_system import (
    ChatCompletionInputType,
    CompletionInputType,
    StringType,
)
from llama_stack.core.datatypes import Api


class ColumnName(Enum):
    input_query = "input_query"
    expected_answer = "expected_answer"
    chat_completion_input = "chat_completion_input"
    completion_input = "completion_input"
    generated_answer = "generated_answer"
    context = "context"
    dialog = "dialog"
    function = "function"
    language = "language"
    id = "id"
    ground_truth = "ground_truth"


VALID_SCHEMAS_FOR_SCORING = [
    {
        ColumnName.input_query.value: StringType(),
        ColumnName.expected_answer.value: StringType(),
        ColumnName.generated_answer.value: StringType(),
    },
    {
        ColumnName.input_query.value: StringType(),
        ColumnName.expected_answer.value: StringType(),
        ColumnName.generated_answer.value: StringType(),
        ColumnName.context.value: StringType(),
    },
    {
        ColumnName.input_query.value: StringType(),
        ColumnName.expected_answer.value: StringType(),
        ColumnName.generated_answer.value: StringType(),
        ColumnName.function.value: StringType(),
        ColumnName.language.value: StringType(),
        ColumnName.id.value: StringType(),
        ColumnName.ground_truth.value: StringType(),
    },
]

VALID_SCHEMAS_FOR_EVAL = [
    {
        ColumnName.input_query.value: StringType(),
        ColumnName.expected_answer.value: StringType(),
        ColumnName.chat_completion_input.value: ChatCompletionInputType(),
    },
    {
        ColumnName.input_query.value: StringType(),
        ColumnName.expected_answer.value: StringType(),
        ColumnName.completion_input.value: CompletionInputType(),
    },
    {
        ColumnName.input_query.value: StringType(),
        ColumnName.expected_answer.value: StringType(),
        ColumnName.generated_answer.value: StringType(),
        ColumnName.function.value: StringType(),
        ColumnName.language.value: StringType(),
        ColumnName.id.value: StringType(),
        ColumnName.ground_truth.value: StringType(),
    },
]


def get_valid_schemas(api_str: str):
    if api_str == Api.scoring.value:
        return VALID_SCHEMAS_FOR_SCORING
    elif api_str == Api.eval.value:
        return VALID_SCHEMAS_FOR_EVAL
    else:
        raise ValueError(f"Invalid API string: {api_str}")


def validate_dataset_schema(
    dataset_schema: dict[str, Any],
    expected_schemas: list[dict[str, Any]],
):
    if dataset_schema not in expected_schemas:
        raise ValueError(f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}")


def validate_row_schema(
    input_row: dict[str, Any],
    expected_schemas: list[dict[str, Any]],
):
    for schema in expected_schemas:
        if all(key in input_row for key in schema):
            return

    raise ValueError(f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}")
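A sketch of how the validators above behave on a scoring row; all names come from the module in this hunk.

```python
row = {
    "input_query": "What is 2 + 2?",
    "expected_answer": "4",
    "generated_answer": "4",
}
validate_row_schema(row, get_valid_schemas(Api.scoring.value))  # matches the first scoring schema

# A row missing generated_answer matches no scoring schema and raises ValueError:
# validate_row_schema({"input_query": "..."}, get_valid_schemas(Api.scoring.value))
```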
@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
@@ -1,47 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import base64
import io
from urllib.parse import unquote

from llama_stack.providers.utils.memory.vector_store import parse_data_url


async def get_dataframe_from_uri(uri: str):
    import pandas

    df = None
    if uri.endswith(".csv"):
        # Moved to its own thread to avoid I/O blocking the event loop.
        # This isn't ideal, as it moves more than just the I/O to a new thread,
        # but it is as close as we can easily get.
        df = await asyncio.to_thread(pandas.read_csv, uri)
    elif uri.endswith(".xlsx"):
        df = await asyncio.to_thread(pandas.read_excel, uri)
    elif uri.startswith("data:"):
        parts = parse_data_url(uri)
        data = parts["data"]
        if parts["is_base64"]:
            data = base64.b64decode(data)
        else:
            data = unquote(data)
            encoding = parts["encoding"] or "utf-8"
            data = data.encode(encoding)

        mime_type = parts["mimetype"]
        mime_category = mime_type.split("/")[0]
        data_bytes = io.BytesIO(data)

        if mime_category == "text":
            df = pandas.read_csv(data_bytes)
        else:
            df = pandas.read_excel(data_bytes)
    else:
        raise ValueError(f"Unsupported file type: {uri}")

    return df
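A sketch of the data-URL branch above; the payload decodes to a two-column CSV.

```python
import asyncio

uri = "data:text/csv;base64,YSxiCjEsMg=="  # base64 of "a,b\n1,2"
df = asyncio.run(get_dataframe_from_uri(uri))
# df now holds one row with columns "a" and "b"
```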
@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
@@ -1,69 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json

from fastapi import Request
from pydantic import BaseModel, ValidationError

from llama_stack.apis.files import ExpiresAfter


async def parse_pydantic_from_form[T: BaseModel](request: Request, field_name: str, model_class: type[T]) -> T | None:
    """
    Generic parser to extract a Pydantic model from multipart form data.
    Handles both bracket notation (field[attr1], field[attr2]) and JSON string format.

    Args:
        request: The FastAPI request object
        field_name: The name of the field in the form data (e.g., "expires_after")
        model_class: The Pydantic model class to parse into

    Returns:
        An instance of model_class if parsing succeeds, None otherwise

    Example:
        expires_after = await parse_pydantic_from_form(
            request, "expires_after", ExpiresAfter
        )
    """
    form = await request.form()

    # Check for bracket notation first (e.g., expires_after[anchor], expires_after[seconds])
    bracket_data = {}
    prefix = f"{field_name}["
    for key in form.keys():
        if key.startswith(prefix) and key.endswith("]"):
            # Extract the attribute name from field_name[attr]
            attr = key[len(prefix) : -1]
            bracket_data[attr] = form[key]

    if bracket_data:
        try:
            return model_class(**bracket_data)
        except (ValidationError, TypeError):
            pass

    # Check for JSON string format
    if field_name in form:
        value = form[field_name]
        if isinstance(value, str):
            try:
                data = json.loads(value)
                return model_class(**data)
            except (json.JSONDecodeError, TypeError, ValidationError):
                pass

    return None


async def parse_expires_after(request: Request) -> ExpiresAfter | None:
    """
    Dependency to parse expires_after from multipart form data.
    Handles both bracket notation (expires_after[anchor], expires_after[seconds])
    and JSON string format.
    """
    return await parse_pydantic_from_form(request, "expires_after", ExpiresAfter)
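A sketch of wiring `parse_expires_after` into an endpoint via FastAPI dependency injection; the app and route are hypothetical.

```python
from fastapi import Depends, FastAPI, UploadFile

app = FastAPI()

@app.post("/v1/files")  # hypothetical route
async def upload_file(
    file: UploadFile,
    expires_after: ExpiresAfter | None = Depends(parse_expires_after),
):
    # Clients may send expires_after[anchor] / expires_after[seconds] fields,
    # or a single expires_after JSON string, in the multipart form.
    ...
```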
@@ -1,34 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.models.llama.sku_types import *  # noqa: F403


def is_supported_safety_model(model: Model) -> bool:
    if model.quantization_format != CheckpointQuantizationFormat.bf16:
        return False

    model_id = model.core_model_id
    return model_id in [
        CoreModelId.llama_guard_3_8b,
        CoreModelId.llama_guard_3_1b,
        CoreModelId.llama_guard_3_11b_vision,
    ]


def supported_inference_models() -> list[Model]:
    return [
        m
        for m in all_registered_models()
        if (
            m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2, ModelFamily.llama3_3, ModelFamily.llama4}
            or is_supported_safety_model(m)
        )
    ]


ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR = {m.huggingface_repo: m.descriptor() for m in all_registered_models()}
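For illustration, the lookup table above maps Hugging Face repos to Llama descriptors, roughly like this (exact keys depend on the registered SKU list):

```python
# Illustrative only; the real keys come from all_registered_models().
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR["meta-llama/Llama-3.1-8B-Instruct"]
# -> "Llama3.1-8B-Instruct"
```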
@@ -1,102 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import base64
import platform
import struct
from typing import TYPE_CHECKING

import torch

from llama_stack.log import get_logger

if TYPE_CHECKING:
    from sentence_transformers import SentenceTransformer

from llama_stack.apis.inference import (
    ModelStore,
    OpenAIEmbeddingData,
    OpenAIEmbeddingsRequestWithExtraBody,
    OpenAIEmbeddingsResponse,
    OpenAIEmbeddingUsage,
)

EMBEDDING_MODELS = {}

DARWIN = "Darwin"


log = get_logger(name=__name__, category="providers::utils")


class SentenceTransformerEmbeddingMixin:
    model_store: ModelStore

    async def openai_embeddings(
        self,
        params: OpenAIEmbeddingsRequestWithExtraBody,
    ) -> OpenAIEmbeddingsResponse:
        # Convert input to list format if it's a single string
        input_list = [params.input] if isinstance(params.input, str) else params.input
        if not input_list:
            raise ValueError("Empty list not supported")

        # Get the model and generate embeddings
        model_obj = await self.model_store.get_model(params.model)
        embedding_model = await self._load_sentence_transformer_model(model_obj.provider_resource_id)
        embeddings = await asyncio.to_thread(embedding_model.encode, input_list, show_progress_bar=False)

        # Convert embeddings to the requested format
        data = []
        for i, embedding in enumerate(embeddings):
            if params.encoding_format == "base64":
                # Convert float array to base64 string
                float_bytes = struct.pack(f"{len(embedding)}f", *embedding)
                embedding_value = base64.b64encode(float_bytes).decode("ascii")
            else:
                # Default to float format
                embedding_value = embedding.tolist()

            data.append(
                OpenAIEmbeddingData(
                    embedding=embedding_value,
                    index=i,
                )
            )

        # Not returning actual token usage
        usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
        return OpenAIEmbeddingsResponse(
            data=data,
            model=params.model,
            usage=usage,
        )

    async def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
        global EMBEDDING_MODELS

        loaded_model = EMBEDDING_MODELS.get(model)
        if loaded_model is not None:
            return loaded_model

        log.info(f"Loading sentence transformer for {model}...")

        def _load_model():
            from sentence_transformers import SentenceTransformer

            platform_name = platform.system()
            if platform_name == DARWIN:
                # PyTorch's OpenMP kernels can segfault on macOS when spawned from background
                # threads with the default parallel settings, so force a single-threaded CPU run.
                log.debug(f"Constraining torch threads on {platform_name} to a single worker")
                torch.set_num_threads(1)

            return SentenceTransformer(model, trust_remote_code=True)

        loaded_model = await asyncio.to_thread(_load_model)
        EMBEDDING_MODELS[model] = loaded_model
        return loaded_model
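Since the mixin above packs base64 embeddings as raw float32 values (native byte order), a client can invert the encoding; a minimal sketch:

```python
import base64
import struct

def decode_embedding(b64_value: str) -> list[float]:
    """Invert the base64 float32 packing used in openai_embeddings above."""
    raw = base64.b64decode(b64_value)
    return list(struct.unpack(f"{len(raw) // 4}f", raw))
```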
@@ -1,244 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import Any

from sqlalchemy.exc import IntegrityError

from llama_stack.apis.inference import (
    ListOpenAIChatCompletionResponse,
    OpenAIChatCompletion,
    OpenAICompletionWithInputMessages,
    OpenAIMessageParam,
    Order,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.storage.datatypes import InferenceStoreReference, StorageBackendType
from llama_stack.log import get_logger

from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
from ..sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl

logger = get_logger(name=__name__, category="inference")


class InferenceStore:
    def __init__(
        self,
        reference: InferenceStoreReference,
        policy: list[AccessRule],
    ):
        self.reference = reference
        self.sql_store = None
        self.policy = policy

        # Async write queue and worker control
        self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
        self._worker_tasks: list[asyncio.Task[Any]] = []
        self._max_write_queue_size: int = reference.max_write_queue_size
        self._num_writers: int = max(1, reference.num_writers)

    async def initialize(self):
        """Create the necessary tables if they don't exist."""
        base_store = sqlstore_impl(self.reference)
        self.sql_store = AuthorizedSqlStore(base_store, self.policy)

        # Disable write queue for SQLite to avoid concurrency issues
        backend_name = self.reference.backend
        backend_config = _SQLSTORE_BACKENDS.get(backend_name)
        if backend_config is None:
            raise ValueError(
                f"Unregistered SQL backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
            )
        self.enable_write_queue = backend_config.type != StorageBackendType.SQL_SQLITE
        await self.sql_store.create_table(
            "chat_completions",
            {
                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
                "created": ColumnType.INTEGER,
                "model": ColumnType.STRING,
                "choices": ColumnType.JSON,
                "input_messages": ColumnType.JSON,
            },
        )

        if self.enable_write_queue:
            self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
            for _ in range(self._num_writers):
                self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
        else:
            logger.info("Write queue disabled for SQLite to avoid concurrency issues")

    async def shutdown(self) -> None:
        if not self._worker_tasks:
            return
        if self._queue is not None:
            await self._queue.join()
        for t in self._worker_tasks:
            if not t.done():
                t.cancel()
        for t in self._worker_tasks:
            try:
                await t
            except asyncio.CancelledError:
                pass
        self._worker_tasks.clear()

    async def flush(self) -> None:
        """Wait for all queued writes to complete. Useful for testing."""
        if self.enable_write_queue and self._queue is not None:
            await self._queue.join()

    async def store_chat_completion(
        self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
    ) -> None:
        if self.enable_write_queue:
            if self._queue is None:
                raise ValueError("Inference store is not initialized")
            try:
                self._queue.put_nowait((chat_completion, input_messages))
            except asyncio.QueueFull:
                logger.warning(
                    f"Write queue full; adding chat completion id={getattr(chat_completion, 'id', '<unknown>')}"
                )
                await self._queue.put((chat_completion, input_messages))
        else:
            await self._write_chat_completion(chat_completion, input_messages)

    async def _worker_loop(self) -> None:
        assert self._queue is not None
        while True:
            try:
                item = await self._queue.get()
            except asyncio.CancelledError:
                break
            chat_completion, input_messages = item
            try:
                await self._write_chat_completion(chat_completion, input_messages)
            except Exception as e:  # noqa: BLE001
                logger.error(f"Error writing chat completion: {e}")
            finally:
                self._queue.task_done()

    async def _write_chat_completion(
        self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
    ) -> None:
        if self.sql_store is None:
            raise ValueError("Inference store is not initialized")

        data = chat_completion.model_dump()
        record_data = {
            "id": data["id"],
            "created": data["created"],
            "model": data["model"],
            "choices": data["choices"],
            "input_messages": [message.model_dump() for message in input_messages],
        }

        try:
            await self.sql_store.insert(
                table="chat_completions",
                data=record_data,
            )
        except IntegrityError as e:
            # Duplicate chat completion IDs can be generated during tests especially if they are replaying
            # recorded responses across different tests. No need to warn or error under those circumstances.
            # In the wild, this is not likely to happen at all (no evidence) so we aren't really hiding any problem.

            # Check if it's a unique constraint violation
            error_message = str(e.orig) if e.orig else str(e)
            if self._is_unique_constraint_error(error_message):
                # Update the existing record instead
                await self.sql_store.update(table="chat_completions", data=record_data, where={"id": data["id"]})
            else:
                # Re-raise if it's not a unique constraint error
                raise

    def _is_unique_constraint_error(self, error_message: str) -> bool:
        """Check if the error is specifically a unique constraint violation."""
        error_lower = error_message.lower()
        return any(
            indicator in error_lower
            for indicator in [
                "unique constraint failed",  # SQLite
                "duplicate key",  # PostgreSQL
                "unique violation",  # PostgreSQL alternative
                "duplicate entry",  # MySQL
            ]
        )

    async def list_chat_completions(
        self,
        after: str | None = None,
        limit: int | None = 50,
        model: str | None = None,
        order: Order | None = Order.desc,
    ) -> ListOpenAIChatCompletionResponse:
        """
        List chat completions from the database.

        :param after: The ID of the last chat completion to return.
        :param limit: The maximum number of chat completions to return.
        :param model: The model to filter by.
        :param order: The order to sort the chat completions by.
        """
        if not self.sql_store:
            raise ValueError("Inference store is not initialized")

        if not order:
            order = Order.desc

        where_conditions = {}
        if model:
            where_conditions["model"] = model

        paginated_result = await self.sql_store.fetch_all(
            table="chat_completions",
            where=where_conditions if where_conditions else None,
            order_by=[("created", order.value)],
            cursor=("id", after) if after else None,
            limit=limit,
        )

        data = [
            OpenAICompletionWithInputMessages(
                id=row["id"],
                created=row["created"],
                model=row["model"],
                choices=row["choices"],
                input_messages=row["input_messages"],
            )
            for row in paginated_result.data
        ]
        return ListOpenAIChatCompletionResponse(
            data=data,
            has_more=paginated_result.has_more,
            first_id=data[0].id if data else "",
            last_id=data[-1].id if data else "",
        )

    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
        if not self.sql_store:
            raise ValueError("Inference store is not initialized")

        row = await self.sql_store.fetch_one(
            table="chat_completions",
            where={"id": completion_id},
        )

        if not row:
            # SecureSqlStore will return None if record doesn't exist OR access is denied
            # This provides security by not revealing whether the record exists
            raise ValueError(f"Chat completion with id {completion_id} not found") from None

        return OpenAICompletionWithInputMessages(
            id=row["id"],
            created=row["created"],
            model=row["model"],
            choices=row["choices"],
            input_messages=row["input_messages"],
        )
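Because writes go through a background queue on non-SQLite backends, readers typically need to flush first; a sketch (here `store` is assumed to be an initialized `InferenceStore`, with `chat_completion` and `input_messages` already in hand):

```python
await store.store_chat_completion(chat_completion, input_messages)
await store.flush()  # drain queued writes before reading back
stored = await store.get_chat_completion(chat_completion.id)
```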
@@ -1,336 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import base64
import struct
from collections.abc import AsyncIterator

import litellm

from llama_stack.apis.inference import (
    ChatCompletionRequest,
    InferenceProvider,
    JsonSchemaResponseFormat,
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAIChatCompletionRequestWithExtraBody,
    OpenAICompletion,
    OpenAICompletionRequestWithExtraBody,
    OpenAIEmbeddingData,
    OpenAIEmbeddingsRequestWithExtraBody,
    OpenAIEmbeddingsResponse,
    OpenAIEmbeddingUsage,
    ToolChoice,
)
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
from llama_stack.providers.utils.inference.openai_compat import (
    convert_message_to_openai_dict_new,
    convert_tooldef_to_openai_tool,
    get_sampling_options,
    prepare_openai_completion_params,
)

logger = get_logger(name=__name__, category="providers::utils")


class LiteLLMOpenAIMixin(
    ModelRegistryHelper,
    InferenceProvider,
    NeedsRequestProviderData,
):
    # TODO: avoid exposing the litellm specific model names to the user.
    # potential change: add a prefix param that gets added to the model name
    # when calling litellm.
    def __init__(
        self,
        litellm_provider_name: str,
        api_key_from_config: str | None,
        provider_data_api_key_field: str | None = None,
        model_entries: list[ProviderModelEntry] | None = None,
        openai_compat_api_base: str | None = None,
        download_images: bool = False,
        json_schema_strict: bool = True,
    ):
        """
        Initialize the LiteLLMOpenAIMixin.

        :param model_entries: The model entries to register.
        :param api_key_from_config: The API key to use from the config.
        :param provider_data_api_key_field: The field in the provider data that contains the API key (optional).
        :param litellm_provider_name: The name of the provider, used for model lookups.
        :param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility.
        :param download_images: Whether to download images and convert to base64 for message conversion.
        :param json_schema_strict: Whether to use strict mode for JSON schema validation.
        """
        ModelRegistryHelper.__init__(self, model_entries=model_entries)

        self.litellm_provider_name = litellm_provider_name
        self.api_key_from_config = api_key_from_config
        self.provider_data_api_key_field = provider_data_api_key_field
        self.api_base = openai_compat_api_base
        self.download_images = download_images
        self.json_schema_strict = json_schema_strict

        if openai_compat_api_base:
            self.is_openai_compat = True
        else:
            self.is_openai_compat = False

    async def initialize(self):
        pass

    async def shutdown(self):
        pass

    def get_litellm_model_name(self, model_id: str) -> str:
        # users may be using openai/ prefix in their model names. the openai/models.py did this by default.
        # model_id.startswith("openai/") is for backwards compatibility.
        return (
            f"{self.litellm_provider_name}/{model_id}"
            if self.is_openai_compat and not model_id.startswith(self.litellm_provider_name)
            else model_id
        )

    def _add_additional_properties_recursive(self, schema):
        """
        Recursively add additionalProperties: False to all object schemas
        """
        if isinstance(schema, dict):
            if schema.get("type") == "object":
                schema["additionalProperties"] = False

                # Add required field with all property keys if properties exist
                if "properties" in schema and schema["properties"]:
                    schema["required"] = list(schema["properties"].keys())

            if "properties" in schema:
                for prop_schema in schema["properties"].values():
                    self._add_additional_properties_recursive(prop_schema)

            for key in ["anyOf", "allOf", "oneOf"]:
                if key in schema:
                    for sub_schema in schema[key]:
                        self._add_additional_properties_recursive(sub_schema)

            if "not" in schema:
                self._add_additional_properties_recursive(schema["not"])

            # Handle $defs/$ref
            if "$defs" in schema:
                for def_schema in schema["$defs"].values():
                    self._add_additional_properties_recursive(def_schema)

        return schema

    async def _get_params(self, request: ChatCompletionRequest) -> dict:
        input_dict = {}

        input_dict["messages"] = [
            await convert_message_to_openai_dict_new(m, download_images=self.download_images) for m in request.messages
        ]
        if fmt := request.response_format:
            if not isinstance(fmt, JsonSchemaResponseFormat):
                raise ValueError(
                    f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
                )

            fmt = fmt.json_schema
            name = fmt["title"]
            del fmt["title"]
            fmt["additionalProperties"] = False

            # Apply additionalProperties: False recursively to all objects
            fmt = self._add_additional_properties_recursive(fmt)

            input_dict["response_format"] = {
                "type": "json_schema",
                "json_schema": {
                    "name": name,
                    "schema": fmt,
                    "strict": self.json_schema_strict,
                },
            }
        if request.tools:
            input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
            if request.tool_config.tool_choice:
                input_dict["tool_choice"] = (
                    request.tool_config.tool_choice.value
                    if isinstance(request.tool_config.tool_choice, ToolChoice)
                    else request.tool_config.tool_choice
                )

        return {
            "model": request.model,
            "api_key": self.get_api_key(),
            "api_base": self.api_base,
            **input_dict,
            "stream": request.stream,
            **get_sampling_options(request.sampling_params),
        }

    def get_api_key(self) -> str:
        provider_data = self.get_request_provider_data()
        key_field = self.provider_data_api_key_field
        if provider_data and getattr(provider_data, key_field, None):
            api_key = getattr(provider_data, key_field)
        else:
            api_key = self.api_key_from_config
        if not api_key:
            raise ValueError(
                "API key is not set. Please provide a valid API key in the "
                "provider data header, e.g. x-llamastack-provider-data: "
                f'{{"{key_field}": "<API_KEY>"}}, or in the provider config.'
            )
        return api_key

    async def openai_embeddings(
        self,
        params: OpenAIEmbeddingsRequestWithExtraBody,
    ) -> OpenAIEmbeddingsResponse:
        model_obj = await self.model_store.get_model(params.model)

        # Convert input to list if it's a string
        input_list = [params.input] if isinstance(params.input, str) else params.input

        # Call litellm embedding function
        # litellm.drop_params = True
        response = litellm.embedding(
            model=self.get_litellm_model_name(model_obj.provider_resource_id),
            input=input_list,
            api_key=self.get_api_key(),
            api_base=self.api_base,
            dimensions=params.dimensions,
        )

        # Convert response to OpenAI format
        data = b64_encode_openai_embeddings_response(response.data, params.encoding_format)

        usage = OpenAIEmbeddingUsage(
            prompt_tokens=response["usage"]["prompt_tokens"],
            total_tokens=response["usage"]["total_tokens"],
        )

        return OpenAIEmbeddingsResponse(
            data=data,
            model=model_obj.provider_resource_id,
            usage=usage,
        )

    async def openai_completion(
        self,
        params: OpenAICompletionRequestWithExtraBody,
    ) -> OpenAICompletion:
        model_obj = await self.model_store.get_model(params.model)

        request_params = await prepare_openai_completion_params(
            model=self.get_litellm_model_name(model_obj.provider_resource_id),
            prompt=params.prompt,
            best_of=params.best_of,
            echo=params.echo,
            frequency_penalty=params.frequency_penalty,
            logit_bias=params.logit_bias,
            logprobs=params.logprobs,
            max_tokens=params.max_tokens,
            n=params.n,
            presence_penalty=params.presence_penalty,
            seed=params.seed,
            stop=params.stop,
            stream=params.stream,
            stream_options=params.stream_options,
            temperature=params.temperature,
            top_p=params.top_p,
            user=params.user,
            suffix=params.suffix,
            api_key=self.get_api_key(),
            api_base=self.api_base,
        )
        return await litellm.atext_completion(**request_params)

    async def openai_chat_completion(
        self,
        params: OpenAIChatCompletionRequestWithExtraBody,
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
        # Add usage tracking for streaming when telemetry is active
        from llama_stack.core.telemetry.tracing import get_current_span

        stream_options = params.stream_options
        if params.stream and get_current_span() is not None:
            if stream_options is None:
                stream_options = {"include_usage": True}
            elif "include_usage" not in stream_options:
                stream_options = {**stream_options, "include_usage": True}

        model_obj = await self.model_store.get_model(params.model)

        request_params = await prepare_openai_completion_params(
            model=self.get_litellm_model_name(model_obj.provider_resource_id),
            messages=params.messages,
            frequency_penalty=params.frequency_penalty,
            function_call=params.function_call,
            functions=params.functions,
            logit_bias=params.logit_bias,
            logprobs=params.logprobs,
            max_completion_tokens=params.max_completion_tokens,
            max_tokens=params.max_tokens,
            n=params.n,
            parallel_tool_calls=params.parallel_tool_calls,
            presence_penalty=params.presence_penalty,
            response_format=params.response_format,
            seed=params.seed,
            stop=params.stop,
            stream=params.stream,
            stream_options=stream_options,
            temperature=params.temperature,
            tool_choice=params.tool_choice,
            tools=params.tools,
            top_logprobs=params.top_logprobs,
            top_p=params.top_p,
            user=params.user,
            api_key=self.get_api_key(),
            api_base=self.api_base,
        )
        return await litellm.acompletion(**request_params)

    async def check_model_availability(self, model: str) -> bool:
        """
        Check if a specific model is available via LiteLLM for the current
        provider (self.litellm_provider_name).

        :param model: The model identifier to check.
        :return: True if the model is available dynamically, False otherwise.
        """
        if self.litellm_provider_name not in litellm.models_by_provider:
            logger.error(f"Provider {self.litellm_provider_name} is not registered in litellm.")
            return False

        return model in litellm.models_by_provider[self.litellm_provider_name]


def b64_encode_openai_embeddings_response(
    response_data: list[dict], encoding_format: str | None = "float"
) -> list[OpenAIEmbeddingData]:
    """
    Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
    """
    data = []
    for i, embedding_data in enumerate(response_data):
        if encoding_format == "base64":
            byte_array = bytearray()
            for embedding_value in embedding_data["embedding"]:
                byte_array.extend(struct.pack("f", float(embedding_value)))

            response_embedding = base64.b64encode(byte_array).decode("utf-8")
        else:
            response_embedding = embedding_data["embedding"]
        data.append(
            OpenAIEmbeddingData(
                embedding=response_embedding,
                index=i,
            )
        )
    return data
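A sketch of the prefixing behavior of `get_litellm_model_name` above; the provider name and model ids are illustrative.

```python
# Assume a mixin constructed with litellm_provider_name="groq" and an
# openai_compat_api_base set, so is_openai_compat is True.
mixin.get_litellm_model_name("llama-3.3-70b-versatile")
# -> "groq/llama-3.3-70b-versatile"
mixin.get_litellm_model_name("groq/llama-3.3-70b-versatile")
# -> unchanged, since it already starts with the provider name
```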
@@ -1,191 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import BaseModel, Field, SecretStr

from llama_stack.apis.common.errors import UnsupportedModelError
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference import (
    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
)

logger = get_logger(name=__name__, category="providers::utils")


class RemoteInferenceProviderConfig(BaseModel):
    allowed_models: list[str] | None = Field(  # TODO: make this non-optional and give a list() default
        default=None,
        description="List of models that should be registered with the model registry. If None, all models are allowed.",
    )
    refresh_models: bool = Field(
        default=False,
        description="Whether to refresh models periodically from the provider",
    )
    auth_credential: SecretStr | None = Field(
        default=None,
        description="Authentication credential for the provider",
        alias="api_key",
    )


# TODO: this class is more confusing than useful right now. We need to make it
# closer to the Model class.
class ProviderModelEntry(BaseModel):
    provider_model_id: str
    aliases: list[str] = Field(default_factory=list)
    llama_model: str | None = None
    model_type: ModelType = ModelType.llm
    metadata: dict[str, Any] = Field(default_factory=dict)


def build_hf_repo_model_entry(
    provider_model_id: str,
    model_descriptor: str,
    additional_aliases: list[str] | None = None,
) -> ProviderModelEntry:
    aliases = [
        # NOTE: avoid HF aliases because they _cannot_ be unique across providers
        # get_huggingface_repo(model_descriptor),
    ]
    if additional_aliases:
        aliases.extend(additional_aliases)
    aliases = [alias for alias in aliases if alias is not None]
    return ProviderModelEntry(
        provider_model_id=provider_model_id,
        aliases=aliases,
        llama_model=model_descriptor,
    )


class ModelRegistryHelper(ModelsProtocolPrivate):
    __provider_id__: str

    def __init__(
        self,
        model_entries: list[ProviderModelEntry] | None = None,
        allowed_models: list[str] | None = None,
    ):
        self.allowed_models = allowed_models if allowed_models else []

        self.alias_to_provider_id_map = {}
        self.provider_id_to_llama_model_map = {}
        self.model_entries = model_entries or []
        for entry in self.model_entries:
            for alias in entry.aliases:
                self.alias_to_provider_id_map[alias] = entry.provider_model_id

            # also add a mapping from provider model id to itself for easy lookup
            self.alias_to_provider_id_map[entry.provider_model_id] = entry.provider_model_id

            if entry.llama_model:
                self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
                self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model

    async def list_models(self) -> list[Model] | None:
        models = []
        for entry in self.model_entries:
            ids = [entry.provider_model_id] + entry.aliases
            for id in ids:
                if self.allowed_models and id not in self.allowed_models:
                    continue
                models.append(
                    Model(
                        identifier=id,
                        provider_resource_id=entry.provider_model_id,
                        model_type=entry.model_type,
                        metadata=entry.metadata,
                        provider_id=self.__provider_id__,
                    )
                )
        return models

    async def should_refresh_models(self) -> bool:
        return False

    def get_provider_model_id(self, identifier: str) -> str | None:
        return self.alias_to_provider_id_map.get(identifier, None)

    # TODO: why keep a separate llama model mapping?
    def get_llama_model(self, provider_model_id: str) -> str | None:
        return self.provider_id_to_llama_model_map.get(provider_model_id, None)

    async def check_model_availability(self, model: str) -> bool:
        """
        Check if a specific model is available from the provider (non-static check).

        This is for subclassing purposes, so providers can check if a specific
        model is currently available for use through dynamic means (e.g., API calls).

        This method should NOT check statically configured model entries in
        `self.alias_to_provider_id_map` - that is handled separately in register_model.

        Default implementation returns False (no dynamic models available).

        :param model: The model identifier to check.
        :return: True if the model is available dynamically, False otherwise.
        """
        logger.info(
            f"check_model_availability is not implemented for {self.__class__.__name__}. Returning False by default."
        )
        return False

    async def register_model(self, model: Model) -> Model:
        # Check if model is supported in static configuration
        supported_model_id = self.get_provider_model_id(model.provider_resource_id)

        # If not found in static config, check if it's available dynamically from provider
        if not supported_model_id:
            if await self.check_model_availability(model.provider_resource_id):
                supported_model_id = model.provider_resource_id
            else:
                # note: we cannot provide a complete list of supported models without
                # getting a complete list from the provider, so we return "..."
                all_supported_models = [*self.alias_to_provider_id_map.keys(), "..."]
                raise UnsupportedModelError(model.provider_resource_id, all_supported_models)

        provider_resource_id = self.get_provider_model_id(model.model_id)
        if model.model_type == ModelType.embedding:
            # embedding models are always registered by their provider model id and do not need to be mapped to a llama model
            provider_resource_id = model.provider_resource_id
        if provider_resource_id:
            if provider_resource_id != supported_model_id:  # be idempotent, only reject differences
                raise ValueError(
                    f"Model id '{model.model_id}' is already registered. Please use a different id or unregister it first."
                )
        else:
            llama_model = model.metadata.get("llama_model")
            if llama_model:
                existing_llama_model = self.get_llama_model(model.provider_resource_id)
                if existing_llama_model:
                    if existing_llama_model != llama_model:
                        raise ValueError(
                            f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
                        )
                else:
                    if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
                        raise ValueError(
                            f"Invalid llama_model '{llama_model}' specified in metadata. "
                            f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
                        )
                    self.provider_id_to_llama_model_map[model.provider_resource_id] = (
                        ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
                    )

            # Register the model alias, ensuring it maps to the correct provider model id
            self.alias_to_provider_id_map[model.model_id] = supported_model_id

        return model

    async def unregister_model(self, model_id: str) -> None:
        # model_id is the identifier, not the provider_resource_id
        # unfortunately, this ID can be of the form provider_id/model_id which
        # we never registered. TODO: fix this by significantly rewriting
        # registration and registry helper
        pass
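A sketch of how an entry built by `build_hf_repo_model_entry` resolves through `ModelRegistryHelper`; the provider model id is illustrative.

```python
entry = build_hf_repo_model_entry(
    provider_model_id="accounts/example/models/llama-v3p1-8b-instruct",  # illustrative id
    model_descriptor="Llama3.1-8B-Instruct",
    additional_aliases=["llama-3.1-8b"],
)
helper = ModelRegistryHelper(model_entries=[entry])
assert helper.get_provider_model_id("llama-3.1-8b") == "accounts/example/models/llama-v3p1-8b-instruct"
assert helper.get_llama_model("accounts/example/models/llama-v3p1-8b-instruct") == "Llama3.1-8B-Instruct"
```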
File diff suppressed because it is too large
@ -1,494 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


import base64
import uuid
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterable
from typing import Any

from openai import NOT_GIVEN, AsyncOpenAI
from pydantic import BaseModel, ConfigDict

from llama_stack.apis.inference import (
    Model,
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAIChatCompletionRequestWithExtraBody,
    OpenAICompletion,
    OpenAICompletionRequestWithExtraBody,
    OpenAIEmbeddingData,
    OpenAIEmbeddingsRequestWithExtraBody,
    OpenAIEmbeddingsResponse,
    OpenAIEmbeddingUsage,
    OpenAIMessageParam,
)
from llama_stack.apis.models import ModelType
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content

logger = get_logger(name=__name__, category="providers::utils")


class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
    """
    Mixin class that provides OpenAI-specific functionality for inference providers.
    This class handles direct OpenAI API calls using the AsyncOpenAI client.

    This is an abstract base class that requires child classes to implement:
    - get_base_url(): Method to retrieve the OpenAI-compatible API base URL

    The behavior of this class can be customized by child classes in the following ways:
    - overwrite_completion_id: If True, overwrites the 'id' field in OpenAI responses
    - download_images: If True, downloads images and converts to base64 for providers that require it
    - embedding_model_metadata: A dictionary mapping model IDs to their embedding metadata
    - construct_model_from_identifier: Method to construct a Model instance corresponding to the given identifier
    - provider_data_api_key_field: Optional field name in provider data to look for API key
    - list_provider_model_ids: Method to list available models from the provider
    - get_extra_client_params: Method to provide extra parameters to the AsyncOpenAI client

    Expected Dependencies:
    - self.model_store: Injected by the Llama Stack distribution system at runtime.
      This provides model registry functionality for looking up registered models.
      The model_store is set in routing_tables/common.py during provider initialization.
    """

    # Allow extra fields so the routing infra can inject model_store, __provider_id__, etc.
    model_config = ConfigDict(extra="allow")

    config: RemoteInferenceProviderConfig

    # Allow subclasses to control whether the 'id' field in OpenAI responses
    # is overwritten with a client-side generated id.
    #
    # This is useful for providers that do not return a unique id in the response.
    overwrite_completion_id: bool = False

    # Allow subclasses to control whether to download images and convert to base64
    # for providers that require base64 encoded images instead of URLs.
    download_images: bool = False

    # Embedding model metadata for this provider
    # Can be set by subclasses or instances to provide embedding models
    # Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}
    embedding_model_metadata: dict[str, dict[str, int]] = {}

    # Cache of available models keyed by model ID
    # This is set in list_models() and used in check_model_availability()
    _model_cache: dict[str, Model] = {}

    # List of allowed models for this provider, if empty all models allowed
    allowed_models: list[str] = []

    # Optional field name in provider data to look for API key, which takes precedence
    provider_data_api_key_field: str | None = None

    def get_api_key(self) -> str | None:
        """
        Get the API key.

        :return: The API key as a string, or None if not set
        """
        if self.config.auth_credential is None:
            return None
        return self.config.auth_credential.get_secret_value()

    @abstractmethod
    def get_base_url(self) -> str:
        """
        Get the OpenAI-compatible API base URL.

        This method must be implemented by child classes to provide the base URL
        for the OpenAI API or compatible endpoints (e.g., "https://api.openai.com/v1").

        :return: The base URL as a string
        """
        pass

    def get_extra_client_params(self) -> dict[str, Any]:
        """
        Get any extra parameters to pass to the AsyncOpenAI client.

        Child classes can override this method to provide additional parameters
        such as timeout settings, proxies, etc.

        :return: A dictionary of extra parameters
        """
        return {}

    def construct_model_from_identifier(self, identifier: str) -> Model:
        """
        Construct a Model instance corresponding to the given identifier.

        Child classes can override this to customize model typing/metadata.

        :param identifier: The provider's model identifier
        :return: A Model instance
        """
        if metadata := self.embedding_model_metadata.get(identifier):
            return Model(
                provider_id=self.__provider_id__,  # type: ignore[attr-defined]
                provider_resource_id=identifier,
                identifier=identifier,
                model_type=ModelType.embedding,
                metadata=metadata,
            )
        return Model(
            provider_id=self.__provider_id__,  # type: ignore[attr-defined]
            provider_resource_id=identifier,
            identifier=identifier,
            model_type=ModelType.llm,
        )

    async def list_provider_model_ids(self) -> Iterable[str]:
        """
        List available models from the provider.

        Child classes can override this method to provide a custom implementation
        for listing models. The default implementation uses the AsyncOpenAI client
        to list models from the OpenAI-compatible endpoint.

        :return: An iterable of model IDs or None if not implemented
        """
        client = self.client
        async with client:
            model_ids = [m.id async for m in client.models.list()]
        return model_ids

    async def initialize(self) -> None:
        """
        Initialize the OpenAI mixin.

        This method provides a default implementation that does nothing.
        Subclasses can override this method to perform initialization tasks
        such as setting up clients, validating configurations, etc.
        """
        pass

    async def shutdown(self) -> None:
        """
        Shutdown the OpenAI mixin.

        This method provides a default implementation that does nothing.
        Subclasses can override this method to perform cleanup tasks
        such as closing connections, releasing resources, etc.
        """
        pass

    @property
    def client(self) -> AsyncOpenAI:
        """
        Get an AsyncOpenAI client instance.

        Uses the abstract methods get_api_key() and get_base_url() which must be
        implemented by child classes.

        Users can also provide the API key via the provider data header, which
        is used instead of any config API key.
        """
        api_key = self._get_api_key_from_config_or_provider_data()
        if not api_key:
            message = "API key not provided."
            if self.provider_data_api_key_field:
                message += f' Please provide a valid API key in the provider data header, e.g. x-llamastack-provider-data: {{"{self.provider_data_api_key_field}": "<API_KEY>"}}.'
            raise ValueError(message)

        return AsyncOpenAI(
            api_key=api_key,
            base_url=self.get_base_url(),
            **self.get_extra_client_params(),
        )

    def _get_api_key_from_config_or_provider_data(self) -> str | None:
        api_key = self.get_api_key()

        if self.provider_data_api_key_field:
            provider_data = self.get_request_provider_data()
            if provider_data and getattr(provider_data, self.provider_data_api_key_field, None):
                api_key = getattr(provider_data, self.provider_data_api_key_field)

        return api_key

    async def _get_provider_model_id(self, model: str) -> str:
        """
        Get the provider-specific model ID from the model store.

        This is a utility method that looks up the registered model and returns
        the provider_resource_id that should be used for actual API calls.

        :param model: The registered model name/identifier
        :return: The provider-specific model ID (e.g., "gpt-4")
        """
        # Look up the registered model to get the provider-specific model ID
        # self.model_store is injected by the distribution system at runtime
        model_obj: Model = await self.model_store.get_model(model)  # type: ignore[attr-defined]
        # provider_resource_id is str | None, but we expect it to be str for OpenAI calls
        if model_obj.provider_resource_id is None:
            raise ValueError(f"Model {model} has no provider_resource_id")
        return model_obj.provider_resource_id

    async def _maybe_overwrite_id(self, resp: Any, stream: bool | None) -> Any:
        if not self.overwrite_completion_id:
            return resp

        new_id = f"cltsd-{uuid.uuid4()}"
        if stream:

            async def _gen():
                async for chunk in resp:
                    chunk.id = new_id
                    yield chunk

            return _gen()
        else:
            resp.id = new_id
            return resp

    async def openai_completion(
        self,
        params: OpenAICompletionRequestWithExtraBody,
    ) -> OpenAICompletion:
        """
        Direct OpenAI completion API call.
        """
        # TODO: fix openai_completion to return type compatible with OpenAI's API response
        completion_kwargs = await prepare_openai_completion_params(
            model=await self._get_provider_model_id(params.model),
            prompt=params.prompt,
            best_of=params.best_of,
            echo=params.echo,
            frequency_penalty=params.frequency_penalty,
            logit_bias=params.logit_bias,
            logprobs=params.logprobs,
            max_tokens=params.max_tokens,
            n=params.n,
            presence_penalty=params.presence_penalty,
            seed=params.seed,
            stop=params.stop,
            stream=params.stream,
            stream_options=params.stream_options,
            temperature=params.temperature,
            top_p=params.top_p,
            user=params.user,
            suffix=params.suffix,
        )
        if extra_body := params.model_extra:
            completion_kwargs["extra_body"] = extra_body
        resp = await self.client.completions.create(**completion_kwargs)

        return await self._maybe_overwrite_id(resp, params.stream)  # type: ignore[no-any-return]

    async def openai_chat_completion(
        self,
        params: OpenAIChatCompletionRequestWithExtraBody,
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
        """
        Direct OpenAI chat completion API call.
        """
        messages = params.messages

        if self.download_images:

            async def _localize_image_url(m: OpenAIMessageParam) -> OpenAIMessageParam:
                if isinstance(m.content, list):
                    for c in m.content:
                        if c.type == "image_url" and c.image_url and c.image_url.url and "http" in c.image_url.url:
                            localize_result = await localize_image_content(c.image_url.url)
                            if localize_result is None:
                                raise ValueError(
                                    f"Failed to localize image content from {c.image_url.url[:42]}{'...' if len(c.image_url.url) > 42 else ''}"
                                )
                            content, format = localize_result
                            c.image_url.url = f"data:image/{format};base64,{base64.b64encode(content).decode('utf-8')}"
                # else it's a string and we don't need to modify it
                return m

            messages = [await _localize_image_url(m) for m in messages]

        request_params = await prepare_openai_completion_params(
            model=await self._get_provider_model_id(params.model),
            messages=messages,
            frequency_penalty=params.frequency_penalty,
            function_call=params.function_call,
            functions=params.functions,
            logit_bias=params.logit_bias,
            logprobs=params.logprobs,
            max_completion_tokens=params.max_completion_tokens,
            max_tokens=params.max_tokens,
            n=params.n,
            parallel_tool_calls=params.parallel_tool_calls,
            presence_penalty=params.presence_penalty,
            response_format=params.response_format,
            seed=params.seed,
            stop=params.stop,
            stream=params.stream,
            stream_options=params.stream_options,
            temperature=params.temperature,
            tool_choice=params.tool_choice,
            tools=params.tools,
            top_logprobs=params.top_logprobs,
            top_p=params.top_p,
            user=params.user,
        )

        if extra_body := params.model_extra:
            request_params["extra_body"] = extra_body
        resp = await self.client.chat.completions.create(**request_params)

        return await self._maybe_overwrite_id(resp, params.stream)  # type: ignore[no-any-return]

    async def openai_embeddings(
        self,
        params: OpenAIEmbeddingsRequestWithExtraBody,
    ) -> OpenAIEmbeddingsResponse:
        """
        Direct OpenAI embeddings API call.
        """
        # Prepare request parameters
        request_params = {
            "model": await self._get_provider_model_id(params.model),
            "input": params.input,
            "encoding_format": params.encoding_format if params.encoding_format is not None else NOT_GIVEN,
            "dimensions": params.dimensions if params.dimensions is not None else NOT_GIVEN,
            "user": params.user if params.user is not None else NOT_GIVEN,
        }

        # Add extra_body if present
        extra_body = params.model_extra
        if extra_body:
            request_params["extra_body"] = extra_body

        # Call OpenAI embeddings API with properly typed parameters
        response = await self.client.embeddings.create(**request_params)

        data = []
        for i, embedding_data in enumerate(response.data):
            data.append(
                OpenAIEmbeddingData(
                    embedding=embedding_data.embedding,
                    index=i,
                )
            )

        usage = OpenAIEmbeddingUsage(
            prompt_tokens=response.usage.prompt_tokens,
            total_tokens=response.usage.total_tokens,
        )

        return OpenAIEmbeddingsResponse(
            data=data,
            model=params.model,
            usage=usage,
        )

    ###
    # ModelsProtocolPrivate implementation - provide model management functionality
    #
    #  async def register_model(self, model: Model) -> Model: ...
    #  async def unregister_model(self, model_id: str) -> None: ...
    #
    #  async def list_models(self) -> list[Model] | None: ...
    #  async def should_refresh_models(self) -> bool: ...
    ##

    async def register_model(self, model: Model) -> Model:
        if not await self.check_model_availability(model.provider_model_id):
            raise ValueError(f"Model {model.provider_model_id} is not available from provider {self.__provider_id__}")  # type: ignore[attr-defined]
        return model

    async def unregister_model(self, model_id: str) -> None:
        return None

    async def list_models(self) -> list[Model] | None:
        """
        List available models from the provider's /v1/models endpoint augmented with static embedding model metadata.

        Also, caches the models in self._model_cache for use in check_model_availability().

        :return: A list of Model instances representing available models.
        """
        self._model_cache = {}

        api_key = self._get_api_key_from_config_or_provider_data()
        if not api_key:
            logger.debug(f"{self.__class__.__name__}.list_provider_model_ids() disabled because API key not provided")
            return None

        try:
            iterable = await self.list_provider_model_ids()
        except Exception as e:
            logger.error(f"{self.__class__.__name__}.list_provider_model_ids() failed with: {e}")
            raise
        if not hasattr(iterable, "__iter__"):
            raise TypeError(
                f"Failed to list models: {self.__class__.__name__}.list_provider_model_ids() must return an iterable of "
                f"strings, but returned {type(iterable).__name__}"
            )

        provider_models_ids = list(iterable)
        logger.info(f"{self.__class__.__name__}.list_provider_model_ids() returned {len(provider_models_ids)} models")

        for provider_model_id in provider_models_ids:
            if not isinstance(provider_model_id, str):
                raise ValueError(f"Model ID {provider_model_id} from list_provider_model_ids() is not a string")
            if self.allowed_models and provider_model_id not in self.allowed_models:
                logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
                continue
            model = self.construct_model_from_identifier(provider_model_id)
            self._model_cache[provider_model_id] = model

        return list(self._model_cache.values())

    async def check_model_availability(self, model: str) -> bool:
        """
        Check if a specific model is available from the provider's /v1/models or pre-registered.

        :param model: The model identifier to check.
        :return: True if the model is available dynamically or pre-registered, False otherwise.
        """
        # First check if the model is pre-registered in the model store
        if hasattr(self, "model_store") and self.model_store:
            qualified_model = f"{self.__provider_id__}/{model}"  # type: ignore[attr-defined]
            if await self.model_store.has_model(qualified_model):
                return True

        # Then check the provider's dynamic model cache
        if not self._model_cache:
            await self.list_models()
        return model in self._model_cache

    async def should_refresh_models(self) -> bool:
        return self.config.refresh_models

    #
    # The model_dump implementations are to avoid serializing the extra fields,
    # e.g. model_store, which are not pydantic.
    #

    def _filter_fields(self, **kwargs):
        """Helper to exclude extra fields from serialization."""
        # Exclude any extra fields stored in __pydantic_extra__
        if hasattr(self, "__pydantic_extra__") and self.__pydantic_extra__:
            exclude = kwargs.get("exclude", set())
            if not isinstance(exclude, set):
                exclude = set(exclude) if exclude else set()
            exclude.update(self.__pydantic_extra__.keys())
            kwargs["exclude"] = exclude
        return kwargs

    def model_dump(self, **kwargs):
        """Override to exclude extra fields from serialization."""
        kwargs = self._filter_fields(**kwargs)
        return super().model_dump(**kwargs)

    def model_dump_json(self, **kwargs):
        """Override to exclude extra fields from JSON serialization."""
        kwargs = self._filter_fields(**kwargs)
        return super().model_dump_json(**kwargs)

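For orientation, here is a minimal sketch of a provider built on this mixin. The adapter class, endpoint URL, and provider-data field name are hypothetical; the config and model_store wiring is injected by the stack at runtime, so the class body below is essentially all a subclass has to supply.

# Hypothetical adapter: only get_base_url() is mandatory; the other fields
# customize optional behavior described in the mixin's docstring.
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class ExampleOpenAICompatAdapter(OpenAIMixin):
    # Inline remote image URLs as base64 before each chat completion call.
    download_images: bool = True

    # Hypothetical provider-data field that can carry a per-request API key.
    provider_data_api_key_field: str | None = "example_api_key"

    def get_base_url(self) -> str:
        return "https://api.example.com/v1"  # hypothetical endpoint
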
@ -1,495 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import base64
import io
import json
import re
from typing import Any

import httpx
from PIL import Image as PIL_Image

from llama_stack.apis.common.content_types import (
    ImageContentItem,
    InterleavedContent,
    InterleavedContentItem,
    TextContentItem,
)
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    CompletionRequest,
    Message,
    OpenAIChatCompletionContentPartImageParam,
    OpenAIChatCompletionContentPartTextParam,
    OpenAIFile,
    ResponseFormat,
    ResponseFormatType,
    SystemMessage,
    SystemMessageBehavior,
    ToolChoice,
    ToolDefinition,
    UserMessage,
)
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
    RawContent,
    RawContentItem,
    RawMediaItem,
    RawMessage,
    RawTextItem,
    Role,
    StopReason,
    ToolPromptFormat,
)
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.prompt_templates import (
    BuiltinToolGenerator,
    FunctionTagCustomToolGenerator,
    JsonCustomToolGenerator,
    PythonListCustomToolGenerator,
    SystemDefaultGenerator,
)
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
    PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
)
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
from llama_stack.providers.utils.inference import supported_inference_models

log = get_logger(name=__name__, category="providers::utils")


class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
    messages: list[RawMessage]


class CompletionRequestWithRawContent(CompletionRequest):
    content: RawContent


def decode_assistant_message(content: str, stop_reason: StopReason) -> RawMessage:
    formatter = ChatFormat(Tokenizer.get_instance())
    return formatter.decode_assistant_message_from_content(content, stop_reason)


def interleaved_content_as_str(
    content: Any,
    sep: str = " ",
) -> str:
    if content is None:
        return ""

    def _process(c) -> str:
        if isinstance(c, str):
            return c
        elif isinstance(c, TextContentItem) or isinstance(c, OpenAIChatCompletionContentPartTextParam):
            return c.text
        elif isinstance(c, ImageContentItem) or isinstance(c, OpenAIChatCompletionContentPartImageParam):
            return "<image>"
        elif isinstance(c, OpenAIFile):
            return "<file>"
        else:
            raise ValueError(f"Unsupported content type: {type(c)}")

    if isinstance(content, list):
        return sep.join(_process(c) for c in content)
    else:
        return _process(content)


async def convert_request_to_raw(
    request: ChatCompletionRequest | CompletionRequest,
) -> ChatCompletionRequestWithRawContent | CompletionRequestWithRawContent:
    if isinstance(request, ChatCompletionRequest):
        messages = []
        for m in request.messages:
            content = await interleaved_content_convert_to_raw(m.content)
            d = m.model_dump()
            d["content"] = content
            messages.append(RawMessage(**d))

        d = request.model_dump()
        d["messages"] = messages
        request = ChatCompletionRequestWithRawContent(**d)
    else:
        d = request.model_dump()
        d["content"] = await interleaved_content_convert_to_raw(request.content)
        request = CompletionRequestWithRawContent(**d)

    return request


async def interleaved_content_convert_to_raw(
    content: InterleavedContent,
) -> RawContent:
    """Download content from URLs / files etc. so plain bytes can be sent to the model"""

    async def _localize_single(c: str | InterleavedContentItem) -> str | RawContentItem:
        if isinstance(c, str):
            return RawTextItem(text=c)
        elif isinstance(c, TextContentItem):
            return RawTextItem(text=c.text)
        elif isinstance(c, ImageContentItem):
            image = c.image
            if image.url:
                # Load image bytes from URL
                if image.url.uri.startswith("data"):
                    match = re.match(r"data:image/(\w+);base64,(.+)", image.url.uri)
                    if not match:
                        raise ValueError(f"Invalid data URL format, {image.url.uri[:40]}...")
                    _, image_data = match.groups()
                    data = base64.b64decode(image_data)
                elif image.url.uri.startswith("file://"):
                    path = image.url.uri[len("file://") :]
                    with open(path, "rb") as f:
                        data = f.read()  # type: ignore
                elif image.url.uri.startswith("http"):
                    async with httpx.AsyncClient() as client:
                        response = await client.get(image.url.uri)
                        data = response.content
                else:
                    raise ValueError("Unsupported URL type")
            elif image.data:
                # data is a base64 encoded string, decode it to bytes for RawMediaItem
                data = base64.b64decode(image.data)
            else:
                raise ValueError("No data or URL provided")

            return RawMediaItem(data=data)
        else:
            raise ValueError(f"Unsupported content type: {type(c)}")

    if isinstance(content, list):
        return await asyncio.gather(*(_localize_single(c) for c in content))
    else:
        return await _localize_single(content)


def content_has_media(content: InterleavedContent):
    def _has_media_content(c):
        return isinstance(c, ImageContentItem)

    if isinstance(content, list):
        return any(_has_media_content(c) for c in content)
    else:
        return _has_media_content(content)


def messages_have_media(messages: list[Message]):
    return any(content_has_media(m.content) for m in messages)


def request_has_media(request: ChatCompletionRequest | CompletionRequest):
    if isinstance(request, ChatCompletionRequest):
        return messages_have_media(request.messages)
    else:
        return content_has_media(request.content)


async def localize_image_content(uri: str) -> tuple[bytes, str] | None:
    if uri.startswith("http"):
        async with httpx.AsyncClient() as client:
            r = await client.get(uri)
            content = r.content
            content_type = r.headers.get("content-type")
            if content_type:
                format = content_type.split("/")[-1]
            else:
                format = "png"

        return content, format
    elif uri.startswith("data"):
        # data:image/{format};base64,{data}
        match = re.match(r"data:image/(\w+);base64,(.+)", uri)
        if not match:
            raise ValueError(f"Invalid data URL format, {uri[:40]}...")
        fmt, image_data = match.groups()
        content = base64.b64decode(image_data)
        return content, fmt
    else:
        return None


async def convert_image_content_to_url(
    media: ImageContentItem, download: bool = False, include_format: bool = True
) -> str:
    image = media.image
    if image.url and (not download or image.url.uri.startswith("data")):
        return image.url.uri

    if image.data:
        # data is a base64 encoded string, decode it to bytes first
        # TODO(mf): do this more efficiently, decode less
        content = base64.b64decode(image.data)
        pil_image = PIL_Image.open(io.BytesIO(content))
        format = pil_image.format
    else:
        localize_result = await localize_image_content(image.url.uri)
        if localize_result is None:
            raise ValueError(f"Failed to localize image content from {image.url.uri}")
        content, format = localize_result

    if include_format:
        return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8")
    else:
        return base64.b64encode(content).decode("utf-8")


def augment_content_with_response_format_prompt(response_format, content):
    if fmt_prompt := response_format_prompt(response_format):
        if isinstance(content, list):
            return content + [TextContentItem(text=fmt_prompt)]
        elif isinstance(content, str):
            return [TextContentItem(text=content), TextContentItem(text=fmt_prompt)]
        else:
            return [content, TextContentItem(text=fmt_prompt)]

    return content


async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
    messages = chat_completion_request_to_messages(request, llama_model)
    request.messages = messages
    request = await convert_request_to_raw(request)

    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
    model_input = formatter.encode_dialog_prompt(
        request.messages,
        tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
    )
    return formatter.tokenizer.decode(model_input.tokens)


async def chat_completion_request_to_model_input_info(
    request: ChatCompletionRequest, llama_model: str
) -> tuple[str, int]:
    messages = chat_completion_request_to_messages(request, llama_model)
    request.messages = messages
    request = await convert_request_to_raw(request)

    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
    model_input = formatter.encode_dialog_prompt(
        request.messages,
        tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
    )
    return (
        formatter.tokenizer.decode(model_input.tokens),
        len(model_input.tokens),
    )


def chat_completion_request_to_messages(
    request: ChatCompletionRequest,
    llama_model: str,
) -> list[Message]:
    """Reads chat completion request and augments the messages to handle tools.
    For example, for llama_3_1, add a system message with the appropriate tools or
    add a user message for custom tools, etc.
    """
    assert llama_model is not None, "llama_model is required"
    model = resolve_model(llama_model)
    if model is None:
        log.error(f"Could not resolve model {llama_model}")
        return request.messages

    allowed_models = supported_inference_models()
    descriptors = [m.descriptor() for m in allowed_models]
    if model.descriptor() not in descriptors:
        log.error(f"Unsupported inference model? {model.descriptor()}")
        return request.messages

    if model.model_family == ModelFamily.llama3_1 or (
        model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
    ):
        # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
        messages = augment_messages_for_tools_llama_3_1(request)
    elif model.model_family in (
        ModelFamily.llama3_2,
        ModelFamily.llama3_3,
    ):
        # llama3.2, llama3.3 follow the same tool prompt format
        messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator)
    elif model.model_family == ModelFamily.llama4:
        messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4)
    else:
        messages = request.messages

    if fmt_prompt := response_format_prompt(request.response_format):
        messages.append(UserMessage(content=fmt_prompt))

    return messages


def response_format_prompt(fmt: ResponseFormat | None):
    if not fmt:
        return None

    if fmt.type == ResponseFormatType.json_schema.value:
        return f"Please respond in JSON format with the schema: {json.dumps(fmt.json_schema)}"
    elif fmt.type == ResponseFormatType.grammar.value:
        raise NotImplementedError("Grammar response format not supported yet")
    else:
        raise ValueError(f"Unknown response format {fmt.type}")


def augment_messages_for_tools_llama_3_1(
    request: ChatCompletionRequest,
) -> list[Message]:
    existing_messages = request.messages
    existing_system_message = None
    if existing_messages[0].role == Role.system.value:
        existing_system_message = existing_messages.pop(0)

    assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"

    messages = []

    default_gen = SystemDefaultGenerator()
    default_template = default_gen.gen()

    sys_content = ""

    tool_template = None
    if request.tools:
        tool_gen = BuiltinToolGenerator()
        tool_template = tool_gen.gen(request.tools)

        sys_content += tool_template.render()
        sys_content += "\n"

    sys_content += default_template.render()

    if existing_system_message:
        # TODO: this fn is needed in many places
        def _process(c):
            if isinstance(c, str):
                return c
            else:
                return "<media>"

        sys_content += "\n"

        if isinstance(existing_system_message.content, str):
            sys_content += _process(existing_system_message.content)
        elif isinstance(existing_system_message.content, list):
            sys_content += "\n".join([_process(c) for c in existing_system_message.content])

    tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
    if tool_choice_prompt:
        sys_content += "\n" + tool_choice_prompt

    messages.append(SystemMessage(content=sys_content))

    has_custom_tools = request.tools is not None and any(isinstance(dfn.tool_name, str) for dfn in request.tools)
    if has_custom_tools:
        fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json
        if fmt == ToolPromptFormat.json:
            tool_gen = JsonCustomToolGenerator()
        elif fmt == ToolPromptFormat.function_tag:
            tool_gen = FunctionTagCustomToolGenerator()
        else:
            raise ValueError(f"Non supported ToolPromptFormat {fmt}")

        custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
        custom_template = tool_gen.gen(custom_tools)
        messages.append(UserMessage(content=custom_template.render()))

    # Add back existing messages from the request
    messages += existing_messages

    return messages


def augment_messages_for_tools_llama(
    request: ChatCompletionRequest,
    custom_tool_prompt_generator,
) -> list[Message]:
    existing_messages = request.messages
    existing_system_message = None
    if existing_messages[0].role == Role.system.value:
        existing_system_message = existing_messages.pop(0)

    assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"

    sys_content = ""
    custom_tools, builtin_tools = [], []
    for t in request.tools:
        if isinstance(t.tool_name, str):
            custom_tools.append(t)
        else:
            builtin_tools.append(t)

    if builtin_tools:
        tool_gen = BuiltinToolGenerator()
        tool_template = tool_gen.gen(builtin_tools)

        sys_content += tool_template.render()
        sys_content += "\n"

    custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
    if custom_tools:
        fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
        if fmt != ToolPromptFormat.python_list:
            raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}")

        system_prompt = None
        if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
            system_prompt = existing_system_message.content

        tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt)

        sys_content += tool_template.render()
        sys_content += "\n"

    if existing_system_message and (
        request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools
    ):
        sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")

    tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
    if tool_choice_prompt:
        sys_content += "\n" + tool_choice_prompt

    messages = [SystemMessage(content=sys_content.strip("\n")), *existing_messages]
    return messages


def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: list[ToolDefinition]) -> str:
    if tool_choice == ToolChoice.auto:
        return ""
    elif tool_choice == ToolChoice.required:
        return "You MUST use one of the provided functions/tools to answer the user query."
    elif tool_choice == ToolChoice.none:
        # tools are already not passed in
        return ""
    else:
        # specific tool
        return f"You MUST use the tool `{tool_choice}` to answer the user query."


def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
    llama_model = resolve_model(model)
    if llama_model is None:
        log.warning(f"Could not resolve model {model}, defaulting to json tool prompt format")
        return ToolPromptFormat.json

    if llama_model.model_family == ModelFamily.llama3_1 or (
        llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
    ):
        # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
        return ToolPromptFormat.json
    elif llama_model.model_family in (
        ModelFamily.llama3_2,
        ModelFamily.llama3_3,
        ModelFamily.llama4,
    ):
        # llama3.2 and llama3.3 models follow the same tool prompt format
        return ToolPromptFormat.python_list
    else:
        return ToolPromptFormat.json

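As a quick, network-free check of the data-URL branch of localize_image_content(), the sketch below round-trips a synthetic payload; the bytes are illustrative, not a real PNG, since only the base64 decode path is exercised.

import asyncio
import base64

from llama_stack.providers.utils.inference.prompt_adapter import localize_image_content

# Illustrative payload; any valid base64 works for the decode path.
payload = base64.b64encode(b"\x89PNG illustrative bytes").decode("utf-8")


async def main() -> None:
    result = await localize_image_content(f"data:image/png;base64,{payload}")
    assert result is not None
    content, fmt = result
    print(fmt, len(content))  # "png" and the decoded byte count


asyncio.run(main())
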
@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .kvstore import *  # noqa: F401, F403

@ -1,21 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from datetime import datetime
from typing import Protocol


class KVStore(Protocol):
    # TODO: make the value type bytes instead of str
    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None: ...

    async def get(self, key: str) -> str | None: ...

    async def delete(self, key: str) -> None: ...

    async def values_in_range(self, start_key: str, end_key: str) -> list[str]: ...

    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: ...

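Because KVStore is a typing.Protocol, implementations conform structurally rather than by subclassing. Here is a dict-backed sketch of the contract (expirations ignored); the range is half-open, matching the in-memory and MongoDB stores later in this diff, while the SQLite and Postgres stores treat end_key as inclusive.

import asyncio
from datetime import datetime


class DictKVStore:
    """Minimal structural match for the KVStore Protocol."""

    def __init__(self) -> None:
        self._data: dict[str, str] = {}

    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
        self._data[key] = value  # expiration deliberately ignored in this sketch

    async def get(self, key: str) -> str | None:
        return self._data.get(key)

    async def delete(self, key: str) -> None:
        self._data.pop(key, None)

    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
        return [v for k, v in sorted(self._data.items()) if start_key <= k < end_key]

    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
        return [k for k in sorted(self._data) if start_key <= k < end_key]


async def demo() -> None:
    store = DictKVStore()
    await store.set("sessions:a", "1")
    await store.set("sessions:b", "2")
    print(await store.keys_in_range("sessions:", "sessions:\xff"))  # ['sessions:a', 'sessions:b']


asyncio.run(demo())
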
@ -1,39 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Annotated

from pydantic import Field

from llama_stack.core.storage.datatypes import (
    MongoDBKVStoreConfig,
    PostgresKVStoreConfig,
    RedisKVStoreConfig,
    SqliteKVStoreConfig,
    StorageBackendType,
)

KVStoreConfig = Annotated[
    RedisKVStoreConfig | SqliteKVStoreConfig | PostgresKVStoreConfig | MongoDBKVStoreConfig, Field(discriminator="type")
]


def get_pip_packages(store_config: dict | KVStoreConfig) -> list[str]:
    """Get pip packages for KV store config, handling both dict and object cases."""
    if isinstance(store_config, dict):
        store_type = store_config.get("type")
        if store_type == StorageBackendType.KV_SQLITE.value:
            return SqliteKVStoreConfig.pip_packages()
        elif store_type == StorageBackendType.KV_POSTGRES.value:
            return PostgresKVStoreConfig.pip_packages()
        elif store_type == StorageBackendType.KV_REDIS.value:
            return RedisKVStoreConfig.pip_packages()
        elif store_type == StorageBackendType.KV_MONGODB.value:
            return MongoDBKVStoreConfig.pip_packages()
        else:
            raise ValueError(f"Unknown KV store type: {store_type}")
    else:
        return store_config.pip_packages()

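get_pip_packages() accepts either a parsed config object or the raw dict form that appears in run configs. A sketch of the dict path follows; the extra db_path key and the returned package list are assumptions to verify against llama_stack.core.storage.datatypes.

from llama_stack.core.storage.datatypes import StorageBackendType
from llama_stack.providers.utils.kvstore.config import get_pip_packages

# Raw dict configs are matched on their "type" discriminator.
raw_cfg = {"type": StorageBackendType.KV_SQLITE.value, "db_path": "/tmp/kv.db"}
print(get_pip_packages(raw_cfg))  # e.g. ["aiosqlite"]
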
@ -1,97 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from __future__ import annotations

from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig, StorageBackendType

from .api import KVStore
from .config import KVStoreConfig


def kvstore_dependencies():
    """
    Returns all possible kvstore dependencies for registry/provider specifications.

    NOTE: For specific kvstore implementations, use config.pip_packages instead.
    This function returns the union of all dependencies for cases where the specific
    kvstore type is not known at declaration time (e.g., provider registries).
    """
    return ["aiosqlite", "psycopg2-binary", "redis", "pymongo"]


class InmemoryKVStoreImpl(KVStore):
    def __init__(self):
        self._store = {}

    async def initialize(self) -> None:
        pass

    async def get(self, key: str) -> str | None:
        return self._store.get(key)

    async def set(self, key: str, value: str) -> None:
        self._store[key] = value

    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
        return [self._store[key] for key in self._store.keys() if key >= start_key and key < end_key]

    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
        """Get all keys in the given range."""
        return [key for key in self._store.keys() if key >= start_key and key < end_key]

    async def delete(self, key: str) -> None:
        del self._store[key]


_KVSTORE_BACKENDS: dict[str, KVStoreConfig] = {}


def register_kvstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
    """Register the set of available KV store backends for reference resolution."""
    global _KVSTORE_BACKENDS

    _KVSTORE_BACKENDS.clear()
    for name, cfg in backends.items():
        _KVSTORE_BACKENDS[name] = cfg


async def kvstore_impl(reference: KVStoreReference) -> KVStore:
    backend_name = reference.backend

    backend_config = _KVSTORE_BACKENDS.get(backend_name)
    if backend_config is None:
        raise ValueError(f"Unknown KVStore backend '{backend_name}'. Registered backends: {sorted(_KVSTORE_BACKENDS)}")

    config = backend_config.model_copy()
    config.namespace = reference.namespace

    if config.type == StorageBackendType.KV_REDIS.value:
        from .redis import RedisKVStoreImpl

        impl = RedisKVStoreImpl(config)
    elif config.type == StorageBackendType.KV_SQLITE.value:
        from .sqlite import SqliteKVStoreImpl

        impl = SqliteKVStoreImpl(config)
    elif config.type == StorageBackendType.KV_POSTGRES.value:
        from .postgres import PostgresKVStoreImpl

        impl = PostgresKVStoreImpl(config)
    elif config.type == StorageBackendType.KV_MONGODB.value:
        from .mongodb import MongoDBKVStoreImpl

        impl = MongoDBKVStoreImpl(config)
    else:
        raise ValueError(f"Unknown kvstore type {config.type}")

    await impl.initialize()
    return impl

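Putting the pieces together, a hedged end-to-end sketch: register the backend map once at startup, then resolve namespaced references into live stores. The constructor shapes of SqliteKVStoreConfig and KVStoreReference are inferred from their usage above and may not match the real signatures exactly.

import asyncio

from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl, register_kvstore_backends


async def main() -> None:
    # Backend names are arbitrary keys; references select a backend + namespace.
    register_kvstore_backends({"kv_default": SqliteKVStoreConfig(db_path="/tmp/llama_kv.db")})
    store = await kvstore_impl(KVStoreReference(backend="kv_default", namespace="demo"))
    await store.set("greeting", "hello")
    print(await store.get("greeting"))  # hello


asyncio.run(main())
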
@ -1,9 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .mongodb import MongoDBKVStoreImpl

__all__ = ["MongoDBKVStoreImpl"]

@ -1,82 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from datetime import datetime

from pymongo import AsyncMongoClient
from pymongo.asynchronous.collection import AsyncCollection

from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore

from ..config import MongoDBKVStoreConfig

log = get_logger(name=__name__, category="providers::utils")


class MongoDBKVStoreImpl(KVStore):
    def __init__(self, config: MongoDBKVStoreConfig):
        self.config = config
        self.conn: AsyncMongoClient | None = None

    @property
    def collection(self) -> AsyncCollection:
        if self.conn is None:
            raise RuntimeError("MongoDB connection is not initialized")
        return self.conn[self.config.db][self.config.collection_name]

    async def initialize(self) -> None:
        try:
            conn_creds = {
                "host": self.config.host,
                "port": self.config.port,
                "username": self.config.user,
                "password": self.config.password,
            }
            # Remove None values
            conn_creds = {k: v for k, v in conn_creds.items() if v is not None}
            self.conn = AsyncMongoClient(**conn_creds)
        except Exception as e:
            log.exception("Could not connect to MongoDB database server")
            raise RuntimeError("Could not connect to MongoDB database server") from e

    def _namespaced_key(self, key: str) -> str:
        if not self.config.namespace:
            return key
        return f"{self.config.namespace}:{key}"

    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
        key = self._namespaced_key(key)
        update_query = {"$set": {"value": value, "expiration": expiration}}
        await self.collection.update_one({"key": key}, update_query, upsert=True)

    async def get(self, key: str) -> str | None:
        key = self._namespaced_key(key)
        query = {"key": key}
        result = await self.collection.find_one(query, {"value": 1, "_id": 0})
        return result["value"] if result else None

    async def delete(self, key: str) -> None:
        key = self._namespaced_key(key)
        await self.collection.delete_one({"key": key})

    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
        start_key = self._namespaced_key(start_key)
        end_key = self._namespaced_key(end_key)
        query = {
            "key": {"$gte": start_key, "$lt": end_key},
        }
        cursor = self.collection.find(query, {"value": 1, "_id": 0}).sort("key", 1)
        result = []
        async for doc in cursor:
            result.append(doc["value"])
        return result

    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
        start_key = self._namespaced_key(start_key)
        end_key = self._namespaced_key(end_key)
        query = {"key": {"$gte": start_key, "$lt": end_key}}
        cursor = self.collection.find(query, {"key": 1, "_id": 0}).sort("key", 1)
        # find() returns an async cursor, so iterate asynchronously
        return [doc["key"] async for doc in cursor]

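The namespacing used by the MongoDB, Postgres, and Redis stores is a plain string prefix. A standalone sketch of the mapping, mirroring _namespaced_key (no database needed):

def namespaced_key(namespace: str | None, key: str) -> str:
    # Mirrors MongoDBKVStoreImpl._namespaced_key and the Postgres/Redis twins.
    return key if not namespace else f"{namespace}:{key}"


assert namespaced_key(None, "k") == "k"
assert namespaced_key("agents", "k") == "agents:k"
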
@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .postgres import PostgresKVStoreImpl  # noqa: F401 F403

@ -1,114 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from datetime import datetime

import psycopg2
from psycopg2.extras import DictCursor

from llama_stack.log import get_logger

from ..api import KVStore
from ..config import PostgresKVStoreConfig

log = get_logger(name=__name__, category="providers::utils")


class PostgresKVStoreImpl(KVStore):
    def __init__(self, config: PostgresKVStoreConfig):
        self.config = config
        self.conn = None
        self.cursor = None

    async def initialize(self) -> None:
        try:
            self.conn = psycopg2.connect(
                host=self.config.host,
                port=self.config.port,
                database=self.config.db,
                user=self.config.user,
                password=self.config.password,
                sslmode=self.config.ssl_mode,
                sslrootcert=self.config.ca_cert_path,
            )
            self.conn.autocommit = True
            self.cursor = self.conn.cursor(cursor_factory=DictCursor)

            # Create table if it doesn't exist
            self.cursor.execute(
                f"""
                CREATE TABLE IF NOT EXISTS {self.config.table_name} (
                    key TEXT PRIMARY KEY,
                    value TEXT,
                    expiration TIMESTAMP
                )
                """
            )
        except Exception as e:
            log.exception("Could not connect to PostgreSQL database server")
            raise RuntimeError("Could not connect to PostgreSQL database server") from e

    def _namespaced_key(self, key: str) -> str:
        if not self.config.namespace:
            return key
        return f"{self.config.namespace}:{key}"

    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
        key = self._namespaced_key(key)
        self.cursor.execute(
            f"""
            INSERT INTO {self.config.table_name} (key, value, expiration)
            VALUES (%s, %s, %s)
            ON CONFLICT (key) DO UPDATE
            SET value = EXCLUDED.value, expiration = EXCLUDED.expiration
            """,
            (key, value, expiration),
        )

    async def get(self, key: str) -> str | None:
        key = self._namespaced_key(key)
        self.cursor.execute(
            f"""
            SELECT value FROM {self.config.table_name}
            WHERE key = %s
            AND (expiration IS NULL OR expiration > NOW())
            """,
            (key,),
        )
        result = self.cursor.fetchone()
        return result[0] if result else None

    async def delete(self, key: str) -> None:
        key = self._namespaced_key(key)
        self.cursor.execute(
            f"DELETE FROM {self.config.table_name} WHERE key = %s",
            (key,),
        )

    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
        start_key = self._namespaced_key(start_key)
        end_key = self._namespaced_key(end_key)

        self.cursor.execute(
            f"""
            SELECT value FROM {self.config.table_name}
            WHERE key >= %s AND key < %s
            AND (expiration IS NULL OR expiration > NOW())
            ORDER BY key
            """,
            (start_key, end_key),
        )
        return [row[0] for row in self.cursor.fetchall()]

    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
        start_key = self._namespaced_key(start_key)
        end_key = self._namespaced_key(end_key)

        self.cursor.execute(
            f"SELECT key FROM {self.config.table_name} WHERE key >= %s AND key < %s",
            (start_key, end_key),
        )
        return [row[0] for row in self.cursor.fetchall()]

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .redis import RedisKVStoreImpl  # noqa: F401

@ -1,76 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from datetime import datetime

from redis.asyncio import Redis

from ..api import KVStore
from ..config import RedisKVStoreConfig


class RedisKVStoreImpl(KVStore):
    def __init__(self, config: RedisKVStoreConfig):
        self.config = config

    async def initialize(self) -> None:
        self.redis = Redis.from_url(self.config.url)

    def _namespaced_key(self, key: str) -> str:
        if not self.config.namespace:
            return key
        return f"{self.config.namespace}:{key}"

    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
        key = self._namespaced_key(key)
        await self.redis.set(key, value)
        if expiration:
            await self.redis.expireat(key, expiration)

    async def get(self, key: str) -> str | None:
        key = self._namespaced_key(key)
        value = await self.redis.get(key)
        if value is None:
            return None
        await self.redis.ttl(key)  # result unused; TTL lookup only
        return value

    async def delete(self, key: str) -> None:
        key = self._namespaced_key(key)
        await self.redis.delete(key)

    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
        start_key = self._namespaced_key(start_key)
        end_key = self._namespaced_key(end_key)
        cursor = 0
        pattern = start_key + "*"  # Match all keys starting with start_key prefix
        matching_keys = []
        while True:
            cursor, keys = await self.redis.scan(cursor, match=pattern, count=1000)

            for key in keys:
                key_str = key.decode("utf-8") if isinstance(key, bytes) else key
                if start_key <= key_str <= end_key:
                    matching_keys.append(key)

            if cursor == 0:
                break

        # Then fetch all values in a single MGET call
        if matching_keys:
            values = await self.redis.mget(matching_keys)
            return [
                value.decode("utf-8") if isinstance(value, bytes) else value for value in values if value is not None
            ]

        return []

    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
        """Get all keys in the given range."""
        # SCAN for namespaced keys and filter lexicographically, mirroring values_in_range
        start_key = self._namespaced_key(start_key)
        end_key = self._namespaced_key(end_key)
        matching_keys = []
        cursor = 0
        while True:
            cursor, keys = await self.redis.scan(cursor, match=start_key + "*", count=1000)
            matching_keys += [k.decode("utf-8") if isinstance(k, bytes) else k for k in keys]
            if cursor == 0:
                break
        return [k for k in matching_keys if start_key <= k <= end_key]

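A sketch of the expiration path in the Redis store: set() writes the value and expireat() arms an absolute-time TTL. It assumes a reachable local Redis and that RedisKVStoreConfig takes a url field, as its use in initialize() suggests; note the raw client returns bytes.

import asyncio
from datetime import datetime, timedelta, timezone

from llama_stack.core.storage.datatypes import RedisKVStoreConfig
from llama_stack.providers.utils.kvstore.redis import RedisKVStoreImpl


async def main() -> None:
    store = RedisKVStoreImpl(RedisKVStoreConfig(url="redis://localhost:6379"))
    await store.initialize()
    # Absolute expiry, five minutes from now; Redis drops the key when it fires.
    await store.set("token", "abc123", expiration=datetime.now(timezone.utc) + timedelta(minutes=5))
    print(await store.get("token"))  # b"abc123" until the TTL fires


asyncio.run(main())
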
@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .sqlite import SqliteKVStoreImpl  # noqa: F401

@ -1,20 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type


@json_schema_type
class SqliteControlPlaneConfig(BaseModel):
    db_path: str = Field(
        description="File path for the sqlite database",
    )
    table_name: str = Field(
        default="llamastack_control_plane",
        description="Table into which all the keys will be placed",
    )

@@ -1,174 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from datetime import datetime

import aiosqlite

from llama_stack.log import get_logger

from ..api import KVStore
from ..config import SqliteKVStoreConfig

logger = get_logger(name=__name__, category="providers::utils")


class SqliteKVStoreImpl(KVStore):
    def __init__(self, config: SqliteKVStoreConfig):
        self.db_path = config.db_path
        self.table_name = "kvstore"
        self._conn: aiosqlite.Connection | None = None

    def __str__(self):
        return f"SqliteKVStoreImpl(db_path={self.db_path}, table_name={self.table_name})"

    def _is_memory_db(self) -> bool:
        """Check if this is an in-memory database."""
        return self.db_path == ":memory:" or "mode=memory" in self.db_path

    async def initialize(self):
        # Skip directory creation for in-memory databases and file: URIs
        if not self._is_memory_db() and not self.db_path.startswith("file:"):
            db_dir = os.path.dirname(self.db_path)
            if db_dir:  # Only create if there's a directory component
                os.makedirs(db_dir, exist_ok=True)

        # Only use persistent connection for in-memory databases
        # File-based databases use connection-per-operation to avoid hangs
        if self._is_memory_db():
            self._conn = await aiosqlite.connect(self.db_path)
            await self._conn.execute(
                f"""
                CREATE TABLE IF NOT EXISTS {self.table_name} (
                    key TEXT PRIMARY KEY,
                    value TEXT,
                    expiration TIMESTAMP
                )
                """
            )
            await self._conn.commit()
        else:
            # For file-based databases, just create the table
            async with aiosqlite.connect(self.db_path) as db:
                await db.execute(
                    f"""
                    CREATE TABLE IF NOT EXISTS {self.table_name} (
                        key TEXT PRIMARY KEY,
                        value TEXT,
                        expiration TIMESTAMP
                    )
                    """
                )
                await db.commit()

    async def shutdown(self):
        """Close the persistent connection (only for in-memory databases)."""
        if self._conn:
            await self._conn.close()
            self._conn = None

    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
        if self._conn:
            # In-memory database with persistent connection
            await self._conn.execute(
                f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
                (key, value, expiration),
            )
            await self._conn.commit()
        else:
            # File-based database with connection per operation
            async with aiosqlite.connect(self.db_path) as db:
                await db.execute(
                    f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
                    (key, value, expiration),
                )
                await db.commit()

    async def get(self, key: str) -> str | None:
        if self._conn:
            # In-memory database with persistent connection
            async with self._conn.execute(
                f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
            ) as cursor:
                row = await cursor.fetchone()
                if row is None:
                    return None
                value, expiration = row
                if not isinstance(value, str):
                    logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None")
                    return None
                return value
        else:
            # File-based database with connection per operation
            async with aiosqlite.connect(self.db_path) as db:
                async with db.execute(
                    f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
                ) as cursor:
                    row = await cursor.fetchone()
                    if row is None:
                        return None
                    value, expiration = row
                    if not isinstance(value, str):
                        logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None")
                        return None
                    return value

    async def delete(self, key: str) -> None:
        if self._conn:
            # In-memory database with persistent connection
            await self._conn.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
            await self._conn.commit()
        else:
            # File-based database with connection per operation
            async with aiosqlite.connect(self.db_path) as db:
                await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
                await db.commit()

    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
        if self._conn:
            # In-memory database with persistent connection
            async with self._conn.execute(
                f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
                (start_key, end_key),
            ) as cursor:
                result = []
                async for row in cursor:
                    _, value, _ = row
                    result.append(value)
                return result
        else:
            # File-based database with connection per operation
            async with aiosqlite.connect(self.db_path) as db:
                async with db.execute(
                    f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
                    (start_key, end_key),
                ) as cursor:
                    result = []
                    async for row in cursor:
                        _, value, _ = row
                        result.append(value)
                    return result

    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
        """Get all keys in the given range."""
        if self._conn:
            # In-memory database with persistent connection
            cursor = await self._conn.execute(
                f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
                (start_key, end_key),
            )
            rows = await cursor.fetchall()
            return [row[0] for row in rows]
        else:
            # File-based database with connection per operation
            async with aiosqlite.connect(self.db_path) as db:
                cursor = await db.execute(
                    f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
                    (start_key, end_key),
                )
                rows = await cursor.fetchall()
                return [row[0] for row in rows]
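A minimal sketch of exercising this KV store end to end; it assumes SqliteKVStoreConfig accepts a db_path field, and uses ":memory:" so the persistent-connection branch is taken.

import asyncio

async def demo() -> None:
    # Hypothetical setup; ":memory:" triggers the persistent-connection path.
    kvstore = SqliteKVStoreImpl(SqliteKVStoreConfig(db_path=":memory:"))
    await kvstore.initialize()
    await kvstore.set("alpha:1", "hello")
    assert await kvstore.get("alpha:1") == "hello"
    print(await kvstore.keys_in_range("alpha:0", "alpha:9"))
    await kvstore.shutdown()

asyncio.run(demo())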
@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
@@ -1,26 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import base64
import mimetypes
import os

from llama_stack.apis.common.content_types import URL


def data_url_from_file(file_path: str) -> URL:
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    with open(file_path, "rb") as file:
        file_content = file.read()

    base64_content = base64.b64encode(file_content).decode("utf-8")
    mime_type, _ = mimetypes.guess_type(file_path)

    data_url = f"data:{mime_type};base64,{base64_content}"

    return URL(uri=data_url)
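A quick usage sketch (the file name is hypothetical); the resulting URI follows the standard data: scheme, with the MIME type guessed from the extension.

# Any small local file works; the path here is made up.
url = data_url_from_file("./example.png")
print(url.uri[:40])  # e.g. "data:image/png;base64,iVBORw0KGgo..."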
File diff suppressed because it is too large
@@ -1,332 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import io
import re
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any
from urllib.parse import unquote

import httpx
import numpy as np
from numpy.typing import NDArray
from pydantic import BaseModel

from llama_stack.apis.common.content_types import (
    URL,
    InterleavedContent,
)
from llama_stack.apis.inference import OpenAIEmbeddingsRequestWithExtraBody
from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import Api
from llama_stack.providers.utils.inference.prompt_adapter import (
    interleaved_content_as_str,
)
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id

log = get_logger(name=__name__, category="providers::utils")


class ChunkForDeletion(BaseModel):
    """Information needed to delete a chunk from a vector store.

    :param chunk_id: The ID of the chunk to delete
    :param document_id: The ID of the document this chunk belongs to
    """

    chunk_id: str
    document_id: str


# Constants for reranker types
RERANKER_TYPE_RRF = "rrf"
RERANKER_TYPE_WEIGHTED = "weighted"
RERANKER_TYPE_NORMALIZED = "normalized"


def parse_pdf(data: bytes) -> str:
    # For PDF and DOC/DOCX files, we can't reliably convert to string
    pdf_bytes = io.BytesIO(data)
    from pypdf import PdfReader

    pdf_reader = PdfReader(pdf_bytes)
    return "\n".join([page.extract_text() for page in pdf_reader.pages])


def parse_data_url(data_url: str):
    data_url_pattern = re.compile(
        r"^"
        r"data:"
        r"(?P<mimetype>[\w/\-+.]+)"
        r"(?P<charset>;charset=(?P<encoding>[\w-]+))?"
        r"(?P<base64>;base64)?"
        r",(?P<data>.*)"
        r"$",
        re.DOTALL,
    )
    match = data_url_pattern.match(data_url)
    if not match:
        raise ValueError("Invalid Data URL format")

    parts = match.groupdict()
    parts["is_base64"] = bool(parts["base64"])
    return parts


def content_from_data(data_url: str) -> str:
    parts = parse_data_url(data_url)
    data = parts["data"]

    if parts["is_base64"]:
        data = base64.b64decode(data)
    else:
        data = unquote(data)
        encoding = parts["encoding"] or "utf-8"
        data = data.encode(encoding)
    return content_from_data_and_mime_type(data, parts["mimetype"], parts.get("encoding", None))


def content_from_data_and_mime_type(data: bytes | str, mime_type: str | None, encoding: str | None = None) -> str:
    if isinstance(data, bytes):
        if not encoding:
            import chardet

            detected = chardet.detect(data)
            encoding = detected["encoding"]

    mime_category = mime_type.split("/")[0] if mime_type else None
    if mime_category == "text":
        # For text-based files (including CSV, MD)
        encodings_to_try = [encoding]
        if encoding != "utf-8":
            encodings_to_try.append("utf-8")
        first_exception = None
        for encoding in encodings_to_try:
            try:
                return data.decode(encoding)
            except UnicodeDecodeError as e:
                if first_exception is None:
                    first_exception = e
                log.warning(f"Decoding failed with {encoding}: {e}")
        # raise the original exception; if we got here there was at least one exception
        log.error(f"Could not decode data as any of {encodings_to_try}")
        raise first_exception

    elif mime_type == "application/pdf":
        return parse_pdf(data)

    else:
        log.error("Could not extract content from data_url properly.")
        return ""


async def content_from_doc(doc: RAGDocument) -> str:
    if isinstance(doc.content, URL):
        if doc.content.uri.startswith("data:"):
            return content_from_data(doc.content.uri)
        async with httpx.AsyncClient() as client:
            r = await client.get(doc.content.uri)
        if doc.mime_type == "application/pdf":
            return parse_pdf(r.content)
        return r.text
    elif isinstance(doc.content, str):
        pattern = re.compile("^(https?://|file://|data:)")
        if pattern.match(doc.content):
            if doc.content.startswith("data:"):
                return content_from_data(doc.content)
            async with httpx.AsyncClient() as client:
                r = await client.get(doc.content)
            if doc.mime_type == "application/pdf":
                return parse_pdf(r.content)
            return r.text
        return doc.content
    else:
        # will raise ValueError if the content is not List[InterleavedContent] or InterleavedContent
        return interleaved_content_as_str(doc.content)


def make_overlapped_chunks(
    document_id: str, text: str, window_len: int, overlap_len: int, metadata: dict[str, Any]
) -> list[Chunk]:
    default_tokenizer = "DEFAULT_TIKTOKEN_TOKENIZER"
    tokenizer = Tokenizer.get_instance()
    tokens = tokenizer.encode(text, bos=False, eos=False)
    try:
        metadata_string = str(metadata)
    except Exception as e:
        raise ValueError("Failed to serialize metadata to string") from e

    metadata_tokens = tokenizer.encode(metadata_string, bos=False, eos=False)

    chunks = []
    for i in range(0, len(tokens), window_len - overlap_len):
        toks = tokens[i : i + window_len]
        chunk = tokenizer.decode(toks)
        chunk_window = f"{i}-{i + len(toks)}"
        chunk_id = generate_chunk_id(chunk, text, chunk_window)
        chunk_metadata = metadata.copy()
        chunk_metadata["chunk_id"] = chunk_id
        chunk_metadata["document_id"] = document_id
        chunk_metadata["token_count"] = len(toks)
        chunk_metadata["metadata_token_count"] = len(metadata_tokens)

        backend_chunk_metadata = ChunkMetadata(
            chunk_id=chunk_id,
            document_id=document_id,
            source=metadata.get("source", None),
            created_timestamp=metadata.get("created_timestamp", int(time.time())),
            updated_timestamp=int(time.time()),
            chunk_window=chunk_window,
            chunk_tokenizer=default_tokenizer,
            chunk_embedding_model=None,  # This will be set in `VectorStoreWithIndex.insert_chunks`
            content_token_count=len(toks),
            metadata_token_count=len(metadata_tokens),
        )

        # chunk is a string
        chunks.append(
            Chunk(
                content=chunk,
                metadata=chunk_metadata,
                chunk_metadata=backend_chunk_metadata,
            )
        )

    return chunks
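A hedged sketch of how make_overlapped_chunks slices a document: consecutive windows advance by window_len - overlap_len tokens, so adjacent chunks share overlap_len tokens. The parameter values and input text below are illustrative, and the tokenizer singleton is assumed to be initialized.

# Illustrative parameters: 512-token windows with 64 tokens of overlap.
chunks = make_overlapped_chunks(
    document_id="doc-1",
    text=long_text,  # hypothetical input string
    window_len=512,
    overlap_len=64,
    metadata={"source": "example.txt"},
)
# Window starts advance by 512 - 64 = 448 tokens per chunk.
print(len(chunks), chunks[0].chunk_metadata.chunk_window)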
def _validate_embedding(embedding: NDArray, index: int, expected_dimension: int):
    """Helper method to validate embedding format and dimensions"""
    if not isinstance(embedding, (list | np.ndarray)):
        raise ValueError(f"Embedding at index {index} must be a list or numpy array, got {type(embedding)}")

    if isinstance(embedding, np.ndarray):
        if not np.issubdtype(embedding.dtype, np.number):
            raise ValueError(f"Embedding at index {index} contains non-numeric values")
    else:
        if not all(isinstance(e, (float | int | np.number)) for e in embedding):
            raise ValueError(f"Embedding at index {index} contains non-numeric values")

    if len(embedding) != expected_dimension:
        raise ValueError(f"Embedding at index {index} has dimension {len(embedding)}, expected {expected_dimension}")


class EmbeddingIndex(ABC):
    @abstractmethod
    async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
        raise NotImplementedError()

    @abstractmethod
    async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]):
        raise NotImplementedError()

    @abstractmethod
    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
        raise NotImplementedError()

    @abstractmethod
    async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
        raise NotImplementedError()

    @abstractmethod
    async def query_hybrid(
        self,
        embedding: NDArray,
        query_string: str,
        k: int,
        score_threshold: float,
        reranker_type: str,
        reranker_params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        raise NotImplementedError()

    @abstractmethod
    async def delete(self):
        raise NotImplementedError()


@dataclass
class VectorStoreWithIndex:
    vector_store: VectorStore
    index: EmbeddingIndex
    inference_api: Api.inference

    async def insert_chunks(
        self,
        chunks: list[Chunk],
    ) -> None:
        chunks_to_embed = []
        for i, c in enumerate(chunks):
            if c.embedding is None:
                chunks_to_embed.append(c)
                if c.chunk_metadata:
                    c.chunk_metadata.chunk_embedding_model = self.vector_store.embedding_model
                    c.chunk_metadata.chunk_embedding_dimension = self.vector_store.embedding_dimension
            else:
                _validate_embedding(c.embedding, i, self.vector_store.embedding_dimension)

        if chunks_to_embed:
            params = OpenAIEmbeddingsRequestWithExtraBody(
                model=self.vector_store.embedding_model,
                input=[c.content for c in chunks_to_embed],
            )
            resp = await self.inference_api.openai_embeddings(params)
            for c, data in zip(chunks_to_embed, resp.data, strict=False):
                c.embedding = data.embedding

        embeddings = np.array([c.embedding for c in chunks], dtype=np.float32)
        await self.index.add_chunks(chunks, embeddings)

    async def query_chunks(
        self,
        query: InterleavedContent,
        params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        if params is None:
            params = {}
        k = params.get("max_chunks", 3)
        mode = params.get("mode")
        score_threshold = params.get("score_threshold", 0.0)

        ranker = params.get("ranker")
        if ranker is None:
            reranker_type = RERANKER_TYPE_RRF
            reranker_params = {"impact_factor": 60.0}
        else:
            strategy = ranker.get("strategy", "rrf")
            if strategy == "weighted":
                weights = ranker.get("params", {}).get("weights", [0.5, 0.5])
                reranker_type = RERANKER_TYPE_WEIGHTED
                reranker_params = {"alpha": weights[0] if len(weights) > 0 else 0.5}
            elif strategy == "normalized":
                reranker_type = RERANKER_TYPE_NORMALIZED
                reranker_params = {}  # fix: keep reranker_params bound for the hybrid path below
            else:
                reranker_type = RERANKER_TYPE_RRF
                k_value = ranker.get("params", {}).get("k", 60.0)
                reranker_params = {"impact_factor": k_value}

        query_string = interleaved_content_as_str(query)
        if mode == "keyword":
            return await self.index.query_keyword(query_string, k, score_threshold)

        params = OpenAIEmbeddingsRequestWithExtraBody(
            model=self.vector_store.embedding_model,
            input=[query_string],
        )
        embeddings_response = await self.inference_api.openai_embeddings(params)
        query_vector = np.array(embeddings_response.data[0].embedding, dtype=np.float32)
        if mode == "hybrid":
            return await self.index.query_hybrid(
                query_vector, query_string, k, score_threshold, reranker_type, reranker_params
            )
        else:
            return await self.index.query_vector(query_vector, k, score_threshold)
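A sketch of how query_chunks routes between the three retrieval modes, to be run inside an async context. The parameter names follow the dict keys read above; the store object itself and the response's chunks/scores fields are assumptions about the surrounding API.

# Hypothetical store: a VectorStoreWithIndex wired to a real index + inference API.
response = await store.query_chunks(
    "what is llama stack?",
    params={
        "mode": "hybrid",  # "keyword" | "hybrid" | default vector search
        "max_chunks": 5,
        "score_threshold": 0.0,
        "ranker": {"strategy": "rrf", "params": {"k": 60.0}},
    },
)
for chunk, score in zip(response.chunks, response.scores, strict=False):
    print(score, chunk.chunk_metadata.chunk_id)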
@@ -1,43 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.common.responses import PaginatedResponse


def paginate_records(
    records: list[dict[str, Any]],
    start_index: int | None = None,
    limit: int | None = None,
) -> PaginatedResponse:
    """Helper function to handle pagination of records consistently across implementations.
    Inspired by stripe's pagination: https://docs.stripe.com/api/pagination

    :param records: List of records to paginate
    :param start_index: The starting index (0-based). If None, starts from beginning.
    :param limit: Number of items to return. If None or -1, returns all items.
    :return: PaginatedResponse with the paginated data
    """
    # Handle special case for fetching all rows
    if limit is None or limit == -1:
        return PaginatedResponse(
            data=records,
            has_more=False,
        )

    # Use offset-based pagination
    start_index = start_index or 0
    end_index = min(start_index + limit, len(records))
    page_data = records[start_index:end_index]

    # Calculate if there are more records
    has_more = end_index < len(records)

    return PaginatedResponse(
        data=page_data,
        has_more=has_more,
    )
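A quick worked example of the boundary behavior, using made-up records:

records = [{"id": i} for i in range(5)]  # hypothetical rows
page = paginate_records(records, start_index=3, limit=3)
print([r["id"] for r in page.data], page.has_more)  # [3, 4] False -- clipped at the end
print(paginate_records(records, limit=-1).has_more)  # False: -1 fetches everything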
@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
@@ -1,354 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import Any

from llama_stack.apis.agents import (
    Order,
)
from llama_stack.apis.agents.openai_responses import (
    ListOpenAIResponseInputItem,
    ListOpenAIResponseObject,
    OpenAIDeleteResponseObject,
    OpenAIResponseInput,
    OpenAIResponseObject,
    OpenAIResponseObjectWithInput,
)
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqlStoreReference, StorageBackendType
from llama_stack.log import get_logger

from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
from ..sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl

logger = get_logger(name=__name__, category="openai_responses")


class _OpenAIResponseObjectWithInputAndMessages(OpenAIResponseObjectWithInput):
    """Internal class for storing responses with chat completion messages.

    This extends the public OpenAIResponseObjectWithInput with a messages field
    for internal storage. The messages field is not exposed in the public API.

    The messages field is optional for backward compatibility with responses
    stored before this feature was added.
    """

    messages: list[OpenAIMessageParam] | None = None


class ResponsesStore:
    def __init__(
        self,
        reference: ResponsesStoreReference | SqlStoreReference,
        policy: list[AccessRule],
    ):
        if isinstance(reference, ResponsesStoreReference):
            self.reference = reference
        else:
            self.reference = ResponsesStoreReference(**reference.model_dump())

        self.policy = policy
        self.sql_store = None
        self.enable_write_queue = True

        # Async write queue and worker control
        self._queue: (
            asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput], list[OpenAIMessageParam]]] | None
        ) = None
        self._worker_tasks: list[asyncio.Task[Any]] = []
        self._max_write_queue_size: int = self.reference.max_write_queue_size
        self._num_writers: int = max(1, self.reference.num_writers)

    async def initialize(self):
        """Create the necessary tables if they don't exist."""
        base_store = sqlstore_impl(self.reference)
        self.sql_store = AuthorizedSqlStore(base_store, self.policy)

        backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
        if backend_config is None:
            raise ValueError(
                f"Unregistered SQL backend '{self.reference.backend}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
            )
        if backend_config.type == StorageBackendType.SQL_SQLITE:
            self.enable_write_queue = False
        await self.sql_store.create_table(
            "openai_responses",
            {
                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
                "created_at": ColumnType.INTEGER,
                "response_object": ColumnType.JSON,
                "model": ColumnType.STRING,
            },
        )

        await self.sql_store.create_table(
            "conversation_messages",
            {
                "conversation_id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
                "messages": ColumnType.JSON,
            },
        )

        if self.enable_write_queue:
            self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
            for _ in range(self._num_writers):
                self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
        else:
            logger.debug("Write queue disabled for SQLite to avoid concurrency issues")

    async def shutdown(self) -> None:
        if not self._worker_tasks:
            return
        if self._queue is not None:
            await self._queue.join()
        for t in self._worker_tasks:
            if not t.done():
                t.cancel()
        for t in self._worker_tasks:
            try:
                await t
            except asyncio.CancelledError:
                pass
        self._worker_tasks.clear()

    async def flush(self) -> None:
        """Wait for all queued writes to complete. Useful for testing."""
        if self.enable_write_queue and self._queue is not None:
            await self._queue.join()

    async def store_response_object(
        self,
        response_object: OpenAIResponseObject,
        input: list[OpenAIResponseInput],
        messages: list[OpenAIMessageParam],
    ) -> None:
        if self.enable_write_queue:
            if self._queue is None:
                raise ValueError("Responses store is not initialized")
            try:
                self._queue.put_nowait((response_object, input, messages))
            except asyncio.QueueFull:
                logger.warning(f"Write queue full; adding response id={getattr(response_object, 'id', '<unknown>')}")
                await self._queue.put((response_object, input, messages))
        else:
            await self._write_response_object(response_object, input, messages)

    async def _worker_loop(self) -> None:
        assert self._queue is not None
        while True:
            try:
                item = await self._queue.get()
            except asyncio.CancelledError:
                break
            response_object, input, messages = item
            try:
                await self._write_response_object(response_object, input, messages)
            except Exception as e:  # noqa: BLE001
                logger.error(f"Error writing response object: {e}")
            finally:
                self._queue.task_done()

    async def _write_response_object(
        self,
        response_object: OpenAIResponseObject,
        input: list[OpenAIResponseInput],
        messages: list[OpenAIMessageParam],
    ) -> None:
        if self.sql_store is None:
            raise ValueError("Responses store is not initialized")

        data = response_object.model_dump()
        data["input"] = [input_item.model_dump() for input_item in input]
        data["messages"] = [msg.model_dump() for msg in messages]

        await self.sql_store.insert(
            "openai_responses",
            {
                "id": data["id"],
                "created_at": data["created_at"],
                "model": data["model"],
                "response_object": data,
            },
        )

    async def list_responses(
        self,
        after: str | None = None,
        limit: int | None = 50,
        model: str | None = None,
        order: Order | None = Order.desc,
    ) -> ListOpenAIResponseObject:
        """
        List responses from the database.

        :param after: The ID of the last response to return.
        :param limit: The maximum number of responses to return.
        :param model: The model to filter by.
        :param order: The order to sort the responses by.
        """
        if not self.sql_store:
            raise ValueError("Responses store is not initialized")

        if not order:
            order = Order.desc

        where_conditions = {}
        if model:
            where_conditions["model"] = model

        paginated_result = await self.sql_store.fetch_all(
            table="openai_responses",
            where=where_conditions if where_conditions else None,
            order_by=[("created_at", order.value)],
            cursor=("id", after) if after else None,
            limit=limit,
        )

        data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
        return ListOpenAIResponseObject(
            data=data,
            has_more=paginated_result.has_more,
            first_id=data[0].id if data else "",
            last_id=data[-1].id if data else "",
        )

    async def get_response_object(self, response_id: str) -> _OpenAIResponseObjectWithInputAndMessages:
        """
        Get a response object with automatic access control checking.
        """
        if not self.sql_store:
            raise ValueError("Responses store is not initialized")

        row = await self.sql_store.fetch_one(
            "openai_responses",
            where={"id": response_id},
        )

        if not row:
            # SecureSqlStore will return None if the record doesn't exist OR access is denied.
            # This provides security by not revealing whether the record exists.
            raise ValueError(f"Response with id {response_id} not found") from None

        return _OpenAIResponseObjectWithInputAndMessages(**row["response_object"])

    async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
        if not self.sql_store:
            raise ValueError("Responses store is not initialized")

        row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id})
        if not row:
            raise ValueError(f"Response with id {response_id} not found")
        await self.sql_store.delete("openai_responses", where={"id": response_id})
        return OpenAIDeleteResponseObject(id=response_id)

    async def list_response_input_items(
        self,
        response_id: str,
        after: str | None = None,
        before: str | None = None,
        include: list[str] | None = None,
        limit: int | None = 20,
        order: Order | None = Order.desc,
    ) -> ListOpenAIResponseInputItem:
        """
        List input items for a given response.

        :param response_id: The ID of the response to retrieve input items for.
        :param after: An item ID to list items after, used for pagination.
        :param before: An item ID to list items before, used for pagination.
        :param include: Additional fields to include in the response.
        :param limit: A limit on the number of objects to be returned.
        :param order: The order to return the input items in.
        """
        if include:
            raise NotImplementedError("Include is not supported yet")
        if before and after:
            raise ValueError("Cannot specify both 'before' and 'after' parameters")

        response_with_input_and_messages = await self.get_response_object(response_id)
        items = response_with_input_and_messages.input

        if order == Order.desc:
            items = list(reversed(items))

        start_index = 0
        end_index = len(items)

        if after or before:
            for i, item in enumerate(items):
                item_id = getattr(item, "id", None)
                if after and item_id == after:
                    start_index = i + 1
                if before and item_id == before:
                    end_index = i
                    break

            if after and start_index == 0:
                raise ValueError(f"Input item with id '{after}' not found for response '{response_id}'")
            if before and end_index == len(items):
                raise ValueError(f"Input item with id '{before}' not found for response '{response_id}'")

        items = items[start_index:end_index]

        # Apply limit
        if limit is not None:
            items = items[:limit]

        return ListOpenAIResponseInputItem(data=items)

    async def store_conversation_messages(self, conversation_id: str, messages: list[OpenAIMessageParam]) -> None:
        """Store messages for a conversation.

        :param conversation_id: The conversation identifier.
        :param messages: List of OpenAI message parameters to store.
        """
        if not self.sql_store:
            raise ValueError("Responses store is not initialized")

        # Serialize messages to dict format for JSON storage
        messages_data = [msg.model_dump() for msg in messages]

        # Upsert: try insert first, update if the row already exists
        try:
            await self.sql_store.insert(
                table="conversation_messages",
                data={"conversation_id": conversation_id, "messages": messages_data},
            )
        except Exception:
            # If insert fails due to ID conflict, update the existing record
            await self.sql_store.update(
                table="conversation_messages",
                data={"messages": messages_data},
                where={"conversation_id": conversation_id},
            )

        logger.debug(f"Stored {len(messages)} messages for conversation {conversation_id}")

    async def get_conversation_messages(self, conversation_id: str) -> list[OpenAIMessageParam] | None:
        """Get stored messages for a conversation.

        :param conversation_id: The conversation identifier.
        :returns: List of OpenAI message parameters, or None if no messages stored.
        """
        if not self.sql_store:
            raise ValueError("Responses store is not initialized")

        record = await self.sql_store.fetch_one(
            table="conversation_messages",
            where={"conversation_id": conversation_id},
        )

        if record is None:
            return None

        # Deserialize messages from JSON storage
        from pydantic import TypeAdapter

        adapter = TypeAdapter(list[OpenAIMessageParam])
        return adapter.validate_python(record["messages"])
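The store's write path is a standard bounded-queue/worker pattern. A self-contained sketch of that pattern outside llama-stack (all names illustrative):

import asyncio

async def main() -> None:
    queue: asyncio.Queue[str] = asyncio.Queue(maxsize=8)  # bounded, like _max_write_queue_size

    async def worker() -> None:
        while True:
            item = await queue.get()
            try:
                print("writing", item)  # stand-in for _write_response_object
            finally:
                queue.task_done()       # lets queue.join() (flush/shutdown) complete

    workers = [asyncio.create_task(worker()) for _ in range(2)]
    for i in range(5):
        queue.put_nowait(f"resp-{i}")
    await queue.join()                  # same mechanism as ResponsesStore.flush()
    for w in workers:
        w.cancel()
    await asyncio.gather(*workers, return_exceptions=True)

asyncio.run(main())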
@@ -1,270 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import abc
import asyncio
import functools
import threading
from collections.abc import Callable, Coroutine, Iterable
from datetime import UTC, datetime
from enum import Enum
from typing import Any

from pydantic import BaseModel

from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="providers::utils")


# TODO: revisit the list of possible statuses when defining a more coherent
# Jobs API for all API flows; e.g. do we need new vs scheduled?
class JobStatus(Enum):
    new = "new"
    scheduled = "scheduled"
    running = "running"
    failed = "failed"
    completed = "completed"


type JobID = str
type JobType = str


class JobArtifact(BaseModel):
    type: JobType
    name: str
    # TODO: uri should be a reference to /files API; revisit when /files is implemented
    uri: str | None = None
    metadata: dict[str, Any]


JobHandler = Callable[
    [Callable[[str], None], Callable[[JobStatus], None], Callable[[JobArtifact], None]], Coroutine[Any, Any, None]
]


type LogMessage = tuple[datetime, str]


_COMPLETED_STATUSES = {JobStatus.completed, JobStatus.failed}


class Job:
    def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler):
        super().__init__()
        self.id = job_id
        self._type = job_type
        self._handler = handler
        self._artifacts: list[JobArtifact] = []
        self._logs: list[LogMessage] = []
        self._state_transitions: list[tuple[datetime, JobStatus]] = [(datetime.now(UTC), JobStatus.new)]

    @property
    def handler(self) -> JobHandler:
        return self._handler

    @property
    def status(self) -> JobStatus:
        return self._state_transitions[-1][1]

    @status.setter
    def status(self, status: JobStatus):
        if status in _COMPLETED_STATUSES and self.status in _COMPLETED_STATUSES:
            raise ValueError(f"Job is already in a completed state ({self.status})")
        if self.status == status:
            return
        self._state_transitions.append((datetime.now(UTC), status))

    @property
    def artifacts(self) -> list[JobArtifact]:
        return self._artifacts

    def register_artifact(self, artifact: JobArtifact) -> None:
        self._artifacts.append(artifact)

    def _find_state_transition_date(self, status: Iterable[JobStatus]) -> datetime | None:
        for date, s in reversed(self._state_transitions):
            if s in status:
                return date
        return None

    @property
    def scheduled_at(self) -> datetime | None:
        return self._find_state_transition_date([JobStatus.scheduled])

    @property
    def started_at(self) -> datetime | None:
        return self._find_state_transition_date([JobStatus.running])

    @property
    def completed_at(self) -> datetime | None:
        return self._find_state_transition_date(_COMPLETED_STATUSES)

    @property
    def logs(self) -> list[LogMessage]:
        return self._logs[:]

    def append_log(self, message: LogMessage) -> None:
        self._logs.append(message)

    # TODO: implement
    def cancel(self) -> None:
        raise NotImplementedError


class _SchedulerBackend(abc.ABC):
    @abc.abstractmethod
    def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    async def shutdown(self) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def schedule(
        self,
        job: Job,
        on_log_message_cb: Callable[[str], None],
        on_status_change_cb: Callable[[JobStatus], None],
        on_artifact_collected_cb: Callable[[JobArtifact], None],
    ) -> None:
        raise NotImplementedError


class _NaiveSchedulerBackend(_SchedulerBackend):
    def __init__(self, timeout: int = 5):
        self._timeout = timeout
        self._loop = asyncio.new_event_loop()
        # There may be performance implications of using threads due to the Python
        # GIL; may need to measure if it's a real problem though
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()

    def _run_loop(self) -> None:
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()

        # TODO: When stopping the loop, give tasks a chance to finish
        # TODO: should we explicitly inform jobs of pending stoppage?

        # cancel all tasks
        for task in asyncio.all_tasks(self._loop):
            if not task.done():
                task.cancel()

        self._loop.close()

    async def shutdown(self) -> None:
        self._loop.call_soon_threadsafe(self._loop.stop)
        self._thread.join()

    # TODO: decouple scheduling and running the job
    def schedule(
        self,
        job: Job,
        on_log_message_cb: Callable[[str], None],
        on_status_change_cb: Callable[[JobStatus], None],
        on_artifact_collected_cb: Callable[[JobArtifact], None],
    ) -> None:
        async def do():
            try:
                job.status = JobStatus.running
                await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
            except Exception as e:
                on_log_message_cb(str(e))
                job.status = JobStatus.failed
                logger.exception(f"Job {job.id} failed.")

        asyncio.run_coroutine_threadsafe(do(), self._loop)

    def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
        pass

    def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
        pass

    def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
        pass


_BACKENDS = {
    "naive": _NaiveSchedulerBackend,
}


def _get_backend_impl(backend: str) -> _SchedulerBackend:
    try:
        return _BACKENDS[backend]()
    except KeyError as e:
        raise ValueError(f"Unknown backend {backend}") from e


class Scheduler:
    def __init__(self, backend: str = "naive"):
        # TODO: if the server crashes, job states are lost; we need to persist jobs on disk
        self._jobs: dict[JobID, Job] = {}
        self._backend = _get_backend_impl(backend)

    def _on_log_message_cb(self, job: Job, message: str) -> None:
        msg = (datetime.now(UTC), message)
        # At least for the time being, until there's a better way to expose
        # logs to users, log messages on the console
        logger.info(f"Job {job.id}: {message}")
        job.append_log(msg)
        self._backend.on_log_message_cb(job, msg)

    def _on_status_change_cb(self, job: Job, status: JobStatus) -> None:
        job.status = status
        self._backend.on_status_change_cb(job, status)

    def _on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
        job.register_artifact(artifact)
        self._backend.on_artifact_collected_cb(job, artifact)

    def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler) -> JobID:
        job = Job(type_, job_id, handler)
        if job.id in self._jobs:
            raise ValueError(f"Job {job.id} already exists")

        self._jobs[job.id] = job
        job.status = JobStatus.scheduled
        self._backend.schedule(
            job,
            functools.partial(self._on_log_message_cb, job),
            functools.partial(self._on_status_change_cb, job),
            functools.partial(self._on_artifact_collected_cb, job),
        )

        return job.id

    def cancel(self, job_id: JobID) -> None:
        self.get_job(job_id).cancel()

    def get_job(self, job_id: JobID) -> Job:
        try:
            return self._jobs[job_id]
        except KeyError as e:
            raise ValueError(f"Job {job_id} not found") from e

    def get_jobs(self, type_: JobType | None = None) -> list[Job]:
        jobs = list(self._jobs.values())
        if type_:
            jobs = [job for job in jobs if job._type == type_]
        return jobs

    async def shutdown(self):
        # TODO: also cancel jobs once implemented
        await self._backend.shutdown()
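A hedged sketch of scheduling a job with the naive backend; the handler signature (log, status, artifact callbacks) follows the JobHandler alias above, and the sleeps are only there because the backend runs on its own thread.

import asyncio
import time

async def train(on_log, on_status, on_artifact) -> None:
    on_log("starting hypothetical work")
    await asyncio.sleep(0.1)
    on_status(JobStatus.completed)

scheduler = Scheduler()                    # "naive" thread + event-loop backend
scheduler.schedule("finetune", "job-1", train)
time.sleep(0.5)                            # give the background loop time to run
print(scheduler.get_job("job-1").status)   # JobStatus.completed (if done by now)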
@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
@@ -1,75 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import statistics
from typing import Any

from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import AggregationFunctionType


def aggregate_accuracy(scoring_results: list[ScoringResultRow]) -> dict[str, Any]:
    num_correct = sum(result["score"] for result in scoring_results)
    avg_score = num_correct / len(scoring_results)

    return {
        "accuracy": avg_score,
        "num_correct": num_correct,
        "num_total": len(scoring_results),
    }


def aggregate_average(scoring_results: list[ScoringResultRow]) -> dict[str, Any]:
    return {
        "average": sum(result["score"] for result in scoring_results if result["score"] is not None)
        / len([_ for _ in scoring_results if _["score"] is not None]),
    }


def aggregate_weighted_average(scoring_results: list[ScoringResultRow]) -> dict[str, Any]:
    return {
        "weighted_average": sum(
            result["score"] * result["weight"]
            for result in scoring_results
            if result["score"] is not None and result["weight"] is not None
        )
        / sum(result["weight"] for result in scoring_results if result["weight"] is not None),
    }


def aggregate_categorical_count(
    scoring_results: list[ScoringResultRow],
) -> dict[str, Any]:
    scores = [str(r["score"]) for r in scoring_results]
    unique_scores = sorted(set(scores))
    return {"categorical_count": {s: scores.count(s) for s in unique_scores}}


def aggregate_median(scoring_results: list[ScoringResultRow]) -> dict[str, Any]:
    scores = [r["score"] for r in scoring_results if r["score"] is not None]
    median = statistics.median(scores) if scores else None
    return {"median": median}


# TODO: decide whether we want to make aggregation functions a registerable resource
AGGREGATION_FUNCTIONS = {
    AggregationFunctionType.accuracy: aggregate_accuracy,
    AggregationFunctionType.average: aggregate_average,
    AggregationFunctionType.weighted_average: aggregate_weighted_average,
    AggregationFunctionType.categorical_count: aggregate_categorical_count,
    AggregationFunctionType.median: aggregate_median,
}


def aggregate_metrics(
    scoring_results: list[ScoringResultRow], metrics: list[AggregationFunctionType]
) -> dict[str, Any]:
    agg_results = {}
    for metric in metrics:
        if metric not in AGGREGATION_FUNCTIONS:
            raise ValueError(f"Aggregation function {metric} not found")
        agg_fn = AGGREGATION_FUNCTIONS[metric]
        agg_results[metric] = agg_fn(scoring_results)
    return agg_results
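A small worked example (made-up rows) showing how the accuracy and categorical-count aggregators behave:

rows = [{"score": 1}, {"score": 0}, {"score": 1}]  # hypothetical scoring results
print(aggregate_accuracy(rows))
# {'accuracy': 0.666..., 'num_correct': 2, 'num_total': 3}
print(aggregate_categorical_count(rows))
# {'categorical_count': {'0': 1, '1': 2}}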
@@ -1,114 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import Any

from llama_stack.apis.scoring import ScoringFnParams, ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics


class BaseScoringFn(ABC):
    """
    Base interface class for Scoring Functions.
    Each scoring function needs to implement the following methods:
    - score_row(self, row)
    - aggregate(self, scoring_fn_results)
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def __str__(self) -> str:
        return self.__class__.__name__

    @abstractmethod
    async def score_row(
        self,
        input_row: dict[str, Any],
        scoring_fn_identifier: str | None = None,
        scoring_params: ScoringFnParams | None = None,
    ) -> ScoringResultRow:
        raise NotImplementedError()

    @abstractmethod
    async def aggregate(
        self,
        scoring_results: list[ScoringResultRow],
        scoring_fn_identifier: str | None = None,
        scoring_params: ScoringFnParams | None = None,
    ) -> dict[str, Any]:
        raise NotImplementedError()

    @abstractmethod
    async def score(
        self,
        input_rows: list[dict[str, Any]],
        scoring_fn_identifier: str | None = None,
        scoring_params: ScoringFnParams | None = None,
    ) -> list[ScoringResultRow]:
        raise NotImplementedError()


class RegisteredBaseScoringFn(BaseScoringFn):
    """
    Interface for native scoring functions that are registered in LlamaStack.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.supported_fn_defs_registry = {}

    def __str__(self) -> str:
        return self.__class__.__name__

    def get_supported_scoring_fn_defs(self) -> list[ScoringFn]:
        return list(self.supported_fn_defs_registry.values())

    def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None:
        if scoring_fn.identifier in self.supported_fn_defs_registry:
            raise ValueError(f"Scoring function def with identifier {scoring_fn.identifier} already exists.")
        self.supported_fn_defs_registry[scoring_fn.identifier] = scoring_fn

    def unregister_scoring_fn_def(self, scoring_fn_id: str) -> None:
        if scoring_fn_id not in self.supported_fn_defs_registry:
            raise ValueError(f"Scoring function def with identifier {scoring_fn_id} does not exist.")
        del self.supported_fn_defs_registry[scoring_fn_id]

    @abstractmethod
    async def score_row(
        self,
        input_row: dict[str, Any],
        scoring_fn_identifier: str | None = None,
        scoring_params: ScoringFnParams | None = None,
    ) -> ScoringResultRow:
        raise NotImplementedError()

    async def aggregate(
        self,
        scoring_results: list[ScoringResultRow],
        scoring_fn_identifier: str | None = None,
        scoring_params: ScoringFnParams | None = None,
    ) -> dict[str, Any]:
        params = self.supported_fn_defs_registry[scoring_fn_identifier].params
        if scoring_params is not None:
            if params is None:
                params = scoring_params
            else:
                params.aggregation_functions = scoring_params.aggregation_functions

        aggregation_functions = []
        if params and hasattr(params, "aggregation_functions") and params.aggregation_functions:
            aggregation_functions.extend(params.aggregation_functions)
        return aggregate_metrics(scoring_results, aggregation_functions)

    async def score(
        self,
        input_rows: list[dict[str, Any]],
        scoring_fn_identifier: str | None = None,
        scoring_params: ScoringFnParams | None = None,
    ) -> list[ScoringResultRow]:
        return [await self.score_row(input_row, scoring_fn_identifier, scoring_params) for input_row in input_rows]
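A hedged sketch of a concrete subclass: only score_row must be implemented, since score and aggregate are inherited. The exact-match logic and row keys below are illustrative, not from the diff.

class ExactMatchScoringFn(RegisteredBaseScoringFn):
    # Hypothetical scorer: 1.0 when the generated answer equals the expected one.
    async def score_row(
        self,
        input_row: dict[str, Any],
        scoring_fn_identifier: str | None = None,
        scoring_params: ScoringFnParams | None = None,
    ) -> ScoringResultRow:
        expected = input_row.get("expected_answer", "")
        generated = input_row.get("generated_answer", "")
        return {"score": 1.0 if expected == generated else 0.0}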
@@ -1,26 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import contextlib
import signal
from collections.abc import Iterator
from types import FrameType


class TimeoutError(Exception):
    pass


@contextlib.contextmanager
def time_limit(seconds: float) -> Iterator[None]:
    def signal_handler(signum: int, frame: FrameType | None) -> None:
        raise TimeoutError("Timed out!")

    signal.setitimer(signal.ITIMER_REAL, seconds)
    signal.signal(signal.SIGALRM, signal_handler)
    try:
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)
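Usage sketch (main thread only, since SIGALRM handlers can only be installed there; the sleep is illustrative):

import time

try:
    with time_limit(1.0):
        time.sleep(5)  # deliberately exceeds the 1-second budget
except TimeoutError:
    print("interrupted after ~1s")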
@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
@@ -1,128 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Any, Literal, Protocol

from pydantic import BaseModel

from llama_stack.apis.common.responses import PaginatedResponse


class ColumnType(Enum):
    INTEGER = "INTEGER"
    STRING = "STRING"
    TEXT = "TEXT"
    FLOAT = "FLOAT"
    BOOLEAN = "BOOLEAN"
    JSON = "JSON"
    DATETIME = "DATETIME"


class ColumnDefinition(BaseModel):
    type: ColumnType
    primary_key: bool = False
    nullable: bool = True
    default: Any = None


class SqlStore(Protocol):
    """
    A protocol for a SQL store.
    """

    async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
        """
        Create a table.
        """
        pass

    async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
        """
        Insert a row or batch of rows into a table.
        """
        pass

    async def fetch_all(
        self,
        table: str,
        where: Mapping[str, Any] | None = None,
        where_sql: str | None = None,
        limit: int | None = None,
        order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
        cursor: tuple[str, str] | None = None,
    ) -> PaginatedResponse:
        """
        Fetch all rows from a table with optional cursor-based pagination.

        :param table: The table name
        :param where: Simple key-value WHERE conditions
        :param where_sql: Raw SQL WHERE clause for complex queries
        :param limit: Maximum number of records to return
        :param order_by: List of (column, order) tuples for sorting
        :param cursor: Tuple of (key_column, cursor_id) for pagination (None for first page);
            requires order_by with exactly one column when used
        :return: PaginatedResponse with data and has_more flag

        Note: Cursor pagination only supports single-column ordering for simplicity.
        Multi-column ordering is allowed without a cursor but will raise an error with one.
        """
        pass

    async def fetch_one(
        self,
        table: str,
        where: Mapping[str, Any] | None = None,
        where_sql: str | None = None,
        order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
    ) -> dict[str, Any] | None:
        """
        Fetch one row from a table.
        """
        pass

    async def update(
        self,
        table: str,
        data: Mapping[str, Any],
        where: Mapping[str, Any],
    ) -> None:
        """
        Update a row in a table.
        """
        pass

    async def delete(
        self,
        table: str,
        where: Mapping[str, Any],
    ) -> None:
        """
        Delete a row from a table.
        """
        pass

    async def add_column_if_not_exists(
        self,
        table: str,
        column_name: str,
        column_type: ColumnType,
        nullable: bool = True,
    ) -> None:
        """
        Add a column to an existing table if the column doesn't already exist.

        This is useful for table migrations when adding new functionality.
        If the table doesn't exist, this method should do nothing.
        If the column already exists, this method should do nothing.

        :param table: Table name
        :param column_name: Name of the column to add
        :param column_type: Type of the column to add
        :param nullable: Whether the column should be nullable (default: True)
        """
        pass
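A sketch of a schema an implementation of this protocol would accept; the table and column names are made up.

# Hypothetical table schema mixing bare ColumnType values with full definitions.
schema: dict[str, ColumnType | ColumnDefinition] = {
    "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
    "created_at": ColumnType.INTEGER,
    "payload": ColumnType.JSON,
}
# Assuming `store` implements SqlStore:
# await store.create_table("example_records", schema)
# page = await store.fetch_all("example_records", order_by=[("created_at", "desc")], limit=10)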
@ -1,303 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import Mapping, Sequence
from typing import Any, Literal

from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
from llama_stack.core.access_control.conditions import ProtectedResource
from llama_stack.core.access_control.datatypes import AccessRule, Action, Scope
from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.core.storage.datatypes import StorageBackendType
from llama_stack.log import get_logger

from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore

logger = get_logger(name=__name__, category="providers::utils")

# Hardcoded copy of the default policy that our SQL filtering implements
# WARNING: If default_policy() changes, this constant must be updated accordingly
# or SQL filtering will fall back to conservative mode (safe but less performant)
#
# This policy represents: "Permit all actions when user is in owners list for ALL attribute categories"
# The corresponding SQL logic is implemented in _build_default_policy_where_clause():
# - Public records (no access_attributes) are always accessible
# - Records with access_attributes require user to match ALL categories that exist in the resource
# - Missing categories in the resource are treated as "no restriction" (allow)
# - Within each category, user needs ANY matching value (OR logic)
# - Between categories, user needs ALL categories to match (AND logic)
SQL_OPTIMIZED_POLICY = [
    AccessRule(
        permit=Scope(actions=list(Action)),
        when=["user in owners roles", "user in owners teams", "user in owners projects", "user in owners namespaces"],
    ),
]


def _enhance_item_with_access_control(item: Mapping[str, Any], current_user: User | None) -> Mapping[str, Any]:
    """Add access control attributes to a data item."""
    enhanced = dict(item)
    if current_user:
        enhanced["owner_principal"] = current_user.principal
        enhanced["access_attributes"] = current_user.attributes
    else:
        enhanced["owner_principal"] = None
        enhanced["access_attributes"] = None
    return enhanced


class SqlRecord(ProtectedResource):
    def __init__(self, record_id: str, table_name: str, owner: User):
        self.type = f"sql_record::{table_name}"
        self.identifier = record_id
        self.owner = owner


class AuthorizedSqlStore:
    """
    Authorization layer for SqlStore that provides access control functionality.

    This class composes a base SqlStore and adds authorization methods that handle
    access control policies, user attribute capture, and SQL filtering optimization.
    """

    def __init__(self, sql_store: SqlStore, policy: list[AccessRule]):
        """
        Initialize the authorization layer.

        :param sql_store: Base SqlStore implementation to wrap
        :param policy: Access control policy to use for authorization
        """
        self.sql_store = sql_store
        self.policy = policy
        self._detect_database_type()
        self._validate_sql_optimized_policy()

    def _detect_database_type(self) -> None:
        """Detect the database type from the underlying SQL store."""
        if not hasattr(self.sql_store, "config"):
            raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore")

        self.database_type = self.sql_store.config.type.value
        if self.database_type not in [StorageBackendType.SQL_POSTGRES.value, StorageBackendType.SQL_SQLITE.value]:
            raise ValueError(f"Unsupported database type: {self.database_type}")

    def _validate_sql_optimized_policy(self) -> None:
        """Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().

        This ensures that if default_policy() changes, we detect the mismatch and
        can update our SQL filtering logic accordingly.
        """
        actual_default = default_policy()

        if SQL_OPTIMIZED_POLICY != actual_default:
            logger.warning(
                f"SQL_OPTIMIZED_POLICY does not match default_policy(). "
                f"SQL filtering will use conservative mode. "
                f"Expected: {SQL_OPTIMIZED_POLICY}, Got: {actual_default}",
            )

    async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
        """Create a table with built-in access control support."""
        enhanced_schema = dict(schema)
        if "access_attributes" not in enhanced_schema:
            enhanced_schema["access_attributes"] = ColumnType.JSON
        if "owner_principal" not in enhanced_schema:
            enhanced_schema["owner_principal"] = ColumnType.STRING

        await self.sql_store.create_table(table, enhanced_schema)
        await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
        await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)

    async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
        """Insert a row or batch of rows with automatic access control attribute capture."""
        current_user = get_authenticated_user()
        enhanced_data: Mapping[str, Any] | Sequence[Mapping[str, Any]]
        if isinstance(data, Mapping):
            enhanced_data = _enhance_item_with_access_control(data, current_user)
        else:
            enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data]
        await self.sql_store.insert(table, enhanced_data)

    async def fetch_all(
        self,
        table: str,
        where: Mapping[str, Any] | None = None,
        limit: int | None = None,
        order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
        cursor: tuple[str, str] | None = None,
    ) -> PaginatedResponse:
        """Fetch all rows with automatic access control filtering."""
        access_where = self._build_access_control_where_clause(self.policy)
        rows = await self.sql_store.fetch_all(
            table=table,
            where=where,
            where_sql=access_where,
            limit=limit,
            order_by=order_by,
            cursor=cursor,
        )

        current_user = get_authenticated_user()
        filtered_rows = []

        for row in rows.data:
            stored_access_attrs = row.get("access_attributes")
            stored_owner_principal = row.get("owner_principal") or ""

            record_id = row.get("id", "unknown")
            sql_record = SqlRecord(
                str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
            )

            if is_action_allowed(self.policy, Action.READ, sql_record, current_user):
                filtered_rows.append(row)

        return PaginatedResponse(
            data=filtered_rows,
            has_more=rows.has_more,
        )

    async def fetch_one(
        self,
        table: str,
        where: Mapping[str, Any] | None = None,
        order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
    ) -> dict[str, Any] | None:
        """Fetch one row with automatic access control checking."""
        results = await self.fetch_all(
            table=table,
            where=where,
            limit=1,
            order_by=order_by,
        )

        return results.data[0] if results.data else None

    async def update(self, table: str, data: Mapping[str, Any], where: Mapping[str, Any]) -> None:
        """Update rows with automatic access control attribute capture."""
        enhanced_data = dict(data)

        current_user = get_authenticated_user()
        if current_user:
            enhanced_data["owner_principal"] = current_user.principal
            enhanced_data["access_attributes"] = current_user.attributes
        else:
            enhanced_data["owner_principal"] = None
            enhanced_data["access_attributes"] = None

        await self.sql_store.update(table, enhanced_data, where)

    async def delete(self, table: str, where: Mapping[str, Any]) -> None:
        """Delete rows with automatic access control filtering."""
        await self.sql_store.delete(table, where)

    def _build_access_control_where_clause(self, policy: list[AccessRule]) -> str:
        """Build SQL WHERE clause for access control filtering.

        Only applies SQL filtering for the default policy to ensure correctness.
        For custom policies, uses conservative filtering to avoid blocking legitimate access.
        """
        current_user = get_authenticated_user()

        if not policy or policy == SQL_OPTIMIZED_POLICY:
            return self._build_default_policy_where_clause(current_user)
        else:
            return self._build_conservative_where_clause()

    def _json_extract(self, column: str, path: str) -> str:
        """Extract JSON value (keeping JSON type).

        Args:
            column: The JSON column name
            path: The JSON path (e.g., 'roles', 'teams')

        Returns:
            SQL expression to extract JSON value
        """
        if self.database_type == StorageBackendType.SQL_POSTGRES.value:
            return f"{column}->'{path}'"
        elif self.database_type == StorageBackendType.SQL_SQLITE.value:
            return f"JSON_EXTRACT({column}, '$.{path}')"
        else:
            raise ValueError(f"Unsupported database type: {self.database_type}")

    def _json_extract_text(self, column: str, path: str) -> str:
        """Extract JSON value as text.

        Args:
            column: The JSON column name
            path: The JSON path (e.g., 'roles', 'teams')

        Returns:
            SQL expression to extract JSON value as text
        """
        if self.database_type == StorageBackendType.SQL_POSTGRES.value:
            return f"{column}->>'{path}'"
        elif self.database_type == StorageBackendType.SQL_SQLITE.value:
            return f"JSON_EXTRACT({column}, '$.{path}')"
        else:
            raise ValueError(f"Unsupported database type: {self.database_type}")

    def _get_public_access_conditions(self) -> list[str]:
        """Get the SQL conditions for public access."""
        # Public records are records that have no owner_principal or access_attributes
        conditions = ["owner_principal = ''"]
        if self.database_type == StorageBackendType.SQL_POSTGRES.value:
            # Postgres stores JSON null as 'null'
            conditions.append("access_attributes::text = 'null'")
        elif self.database_type == StorageBackendType.SQL_SQLITE.value:
            conditions.append("access_attributes = 'null'")
        else:
            raise ValueError(f"Unsupported database type: {self.database_type}")
        return conditions

    def _build_default_policy_where_clause(self, current_user: User | None) -> str:
        """Build SQL WHERE clause for the default policy.

        Default policy: permit all actions when user in owners [roles, teams, projects, namespaces]
        This means the user must match ALL attribute categories that exist in the resource.
        """
        base_conditions = self._get_public_access_conditions()
        user_attr_conditions = []

        if current_user and current_user.attributes:
            for attr_key, user_values in current_user.attributes.items():
                if user_values:
                    value_conditions = []
                    for value in user_values:
                        # Check if the JSON array contains the value
                        escaped_value = value.replace("'", "''")
                        json_text = self._json_extract_text("access_attributes", attr_key)
                        value_conditions.append(f"({json_text} LIKE '%\"{escaped_value}\"%')")

                    if value_conditions:
                        # Check if the category is missing (NULL)
                        category_missing = f"{self._json_extract('access_attributes', attr_key)} IS NULL"
                        user_matches_category = f"({' OR '.join(value_conditions)})"
                        user_attr_conditions.append(f"({category_missing} OR {user_matches_category})")

        if user_attr_conditions:
            all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
            base_conditions.append(all_requirements_met)

        return f"({' OR '.join(base_conditions)})"
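
    # Example (sketch): for a SQLite store and a user with attributes
    # {"roles": ["admin"], "teams": ["ml"]}, the default-policy clause is roughly:
    #   (owner_principal = '' OR access_attributes = 'null'
    #    OR ((JSON_EXTRACT(access_attributes, '$.roles') IS NULL
    #         OR JSON_EXTRACT(access_attributes, '$.roles') LIKE '%"admin"%')
    #        AND (JSON_EXTRACT(access_attributes, '$.teams') IS NULL
    #             OR JSON_EXTRACT(access_attributes, '$.teams') LIKE '%"ml"%')))
    # i.e. public rows always match, and owned rows match only when the user
    # satisfies every attribute category present on the row.
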
    def _build_conservative_where_clause(self) -> str:
        """Conservative SQL filtering for custom policies.

        Only filters records we're 100% certain would be denied by any reasonable policy.
        """
        current_user = get_authenticated_user()

        if not current_user:
            # Only allow public records
            base_conditions = self._get_public_access_conditions()
            return f"({' OR '.join(base_conditions)})"

        return "1=1"
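
A short end-to-end sketch of the wrapper above. The backing store, table name, and columns are illustrative; the caller's identity is whatever get_authenticated_user() returns for the current request:

    # Sketch: rows inserted here are stamped with the caller's identity, and
    # fetch_all() only returns rows the caller is allowed to read.
    store = AuthorizedSqlStore(base_store, default_policy())  # base_store: any SqlStore impl (hypothetical)
    await store.create_table("notes", {"id": ColumnType.STRING, "body": ColumnType.TEXT})
    await store.insert("notes", {"id": "n1", "body": "hello"})
    page = await store.fetch_all("notes", where={"id": "n1"})
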
@ -1,313 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Mapping, Sequence
from typing import Any, Literal

from sqlalchemy import (
    JSON,
    Boolean,
    Column,
    DateTime,
    Float,
    Integer,
    MetaData,
    String,
    Table,
    Text,
    inspect,
    select,
    text,
)
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.ext.asyncio.engine import AsyncEngine
from sqlalchemy.sql.elements import ColumnElement

from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.core.storage.datatypes import SqlAlchemySqlStoreConfig
from llama_stack.log import get_logger

from .api import ColumnDefinition, ColumnType, SqlStore

logger = get_logger(name=__name__, category="providers::utils")

TYPE_MAPPING: dict[ColumnType, Any] = {
    ColumnType.INTEGER: Integer,
    ColumnType.STRING: String,
    ColumnType.FLOAT: Float,
    ColumnType.BOOLEAN: Boolean,
    ColumnType.DATETIME: DateTime,
    ColumnType.TEXT: Text,
    ColumnType.JSON: JSON,
}


def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
    """Return a SQLAlchemy expression for a where condition.

    `value` may be a simple scalar (equality) or a mapping like {">": 123}.
    The returned expression is a SQLAlchemy ColumnElement usable in query.where(...).
    """
    if isinstance(value, Mapping):
        if len(value) != 1:
            raise ValueError(f"Operator mapping must have a single operator, got: {value}")
        op, operand = next(iter(value.items()))
        if op == "==" or op == "=":
            return column == operand
        if op == ">":
            return column > operand
        if op == "<":
            return column < operand
        if op == ">=":
            return column >= operand
        if op == "<=":
            return column <= operand
        raise ValueError(f"Unsupported operator '{op}' in where mapping")
    return column == value
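
# Example (sketch): the two condition shapes accepted by _build_where_expr:
#   _build_where_expr(table.c.status, "active")      -> status = 'active'
#   _build_where_expr(table.c.created, {">=": 123})  -> created >= 123
# A mapping with more than one operator raises ValueError.
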
class SqlAlchemySqlStoreImpl(SqlStore):
    def __init__(self, config: SqlAlchemySqlStoreConfig):
        self.config = config
        self.async_session = async_sessionmaker(self.create_engine())
        self.metadata = MetaData()

    def create_engine(self) -> AsyncEngine:
        return create_async_engine(self.config.engine_str, pool_pre_ping=True)

    async def create_table(
        self,
        table: str,
        schema: Mapping[str, ColumnType | ColumnDefinition],
    ) -> None:
        if not schema:
            raise ValueError(f"No columns defined for table '{table}'.")

        sqlalchemy_columns: list[Column] = []

        for col_name, col_props in schema.items():
            col_type = None
            is_primary_key = False
            is_nullable = True

            if isinstance(col_props, ColumnType):
                col_type = col_props
            elif isinstance(col_props, ColumnDefinition):
                col_type = col_props.type
                is_primary_key = col_props.primary_key
                is_nullable = col_props.nullable

            sqlalchemy_type = TYPE_MAPPING.get(col_type)
            if not sqlalchemy_type:
                raise ValueError(f"Unsupported column type '{col_type}' for column '{col_name}'.")

            sqlalchemy_columns.append(
                Column(col_name, sqlalchemy_type, primary_key=is_primary_key, nullable=is_nullable)
            )

        if table not in self.metadata.tables:
            sqlalchemy_table = Table(table, self.metadata, *sqlalchemy_columns)
        else:
            sqlalchemy_table = self.metadata.tables[table]

        engine = self.create_engine()
        async with engine.begin() as conn:
            await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)

    async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
        async with self.async_session() as session:
            await session.execute(self.metadata.tables[table].insert(), data)
            await session.commit()

    async def fetch_all(
        self,
        table: str,
        where: Mapping[str, Any] | None = None,
        where_sql: str | None = None,
        limit: int | None = None,
        order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
        cursor: tuple[str, str] | None = None,
    ) -> PaginatedResponse:
        async with self.async_session() as session:
            table_obj = self.metadata.tables[table]
            query = select(table_obj)

            if where:
                for key, value in where.items():
                    query = query.where(_build_where_expr(table_obj.c[key], value))

            if where_sql:
                query = query.where(text(where_sql))

            # Handle cursor-based pagination
            if cursor:
                # Validate cursor tuple format
                if not isinstance(cursor, tuple) or len(cursor) != 2:
                    raise ValueError(f"Cursor must be a tuple of (key_column, cursor_id), got: {cursor}")

                # Require order_by for cursor pagination
                if not order_by:
                    raise ValueError("order_by is required when using cursor pagination")

                # Only support single-column ordering for cursor pagination
                if len(order_by) != 1:
                    raise ValueError(
                        f"Cursor pagination only supports single-column ordering, got {len(order_by)} columns"
                    )

                cursor_key_column, cursor_id = cursor
                order_column, order_direction = order_by[0]

                # Verify cursor_key_column exists
                if cursor_key_column not in table_obj.c:
                    raise ValueError(f"Cursor key column '{cursor_key_column}' not found in table '{table}'")

                # Get the cursor row's value for the order column
                cursor_query = select(table_obj.c[order_column]).where(table_obj.c[cursor_key_column] == cursor_id)
                cursor_result = await session.execute(cursor_query)
                cursor_row = cursor_result.fetchone()

                if not cursor_row:
                    raise ValueError(f"Record with {cursor_key_column}='{cursor_id}' not found in table '{table}'")

                cursor_value = cursor_row[0]

                # Apply the cursor condition based on sort direction
                if order_direction == "desc":
                    query = query.where(table_obj.c[order_column] < cursor_value)
                else:
                    query = query.where(table_obj.c[order_column] > cursor_value)

            # Apply ordering
            if order_by:
                if not isinstance(order_by, list):
                    raise ValueError(
                        f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}"
                    )
                for order in order_by:
                    if not isinstance(order, tuple):
                        raise ValueError(
                            f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}"
                        )
                    name, order_type = order
                    if name not in table_obj.c:
                        raise ValueError(f"Column '{name}' not found in table '{table}'")
                    if order_type == "asc":
                        query = query.order_by(table_obj.c[name].asc())
                    elif order_type == "desc":
                        query = query.order_by(table_obj.c[name].desc())
                    else:
                        raise ValueError(f"Invalid order '{order_type}' for column '{name}'")

            # Fetch limit + 1 to determine has_more
            fetch_limit = limit
            if limit:
                fetch_limit = limit + 1

            if fetch_limit:
                query = query.limit(fetch_limit)

            result = await session.execute(query)
            if result.rowcount == 0:
                rows = []
            else:
                rows = [dict(row._mapping) for row in result]

            # Always return a pagination result
            has_more = False
            if limit and len(rows) > limit:
                has_more = True
                rows = rows[:limit]

            return PaginatedResponse(data=rows, has_more=has_more)
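
    # Example (sketch): cursor-based pagination over a hypothetical "events" table
    # ordered by "created_at"; the cursor is (key_column, last_seen_id):
    #   page1 = await store.fetch_all("events", limit=10, order_by=[("created_at", "desc")])
    #   last_id = page1.data[-1]["id"]
    #   page2 = await store.fetch_all("events", limit=10, order_by=[("created_at", "desc")],
    #                                 cursor=("id", last_id))
    # has_more is computed by over-fetching one extra row, so no COUNT(*) is needed.
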
    async def fetch_one(
        self,
        table: str,
        where: Mapping[str, Any] | None = None,
        where_sql: str | None = None,
        order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
    ) -> dict[str, Any] | None:
        result = await self.fetch_all(table, where, where_sql, limit=1, order_by=order_by)
        if not result.data:
            return None
        return result.data[0]

    async def update(
        self,
        table: str,
        data: Mapping[str, Any],
        where: Mapping[str, Any],
    ) -> None:
        if not where:
            raise ValueError("where is required for update")

        async with self.async_session() as session:
            stmt = self.metadata.tables[table].update()
            for key, value in where.items():
                stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value))
            await session.execute(stmt, data)
            await session.commit()

    async def delete(self, table: str, where: Mapping[str, Any]) -> None:
        if not where:
            raise ValueError("where is required for delete")

        async with self.async_session() as session:
            stmt = self.metadata.tables[table].delete()
            for key, value in where.items():
                stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value))
            await session.execute(stmt)
            await session.commit()

    async def add_column_if_not_exists(
        self,
        table: str,
        column_name: str,
        column_type: ColumnType,
        nullable: bool = True,
    ) -> None:
        """Add a column to an existing table if the column doesn't already exist."""
        engine = self.create_engine()

        try:
            async with engine.begin() as conn:

                def check_column_exists(sync_conn):
                    inspector = inspect(sync_conn)

                    table_names = inspector.get_table_names()
                    if table not in table_names:
                        return False, False  # table doesn't exist, column doesn't exist

                    existing_columns = inspector.get_columns(table)
                    column_names = [col["name"] for col in existing_columns]

                    return True, column_name in column_names  # table exists; column exists or not

                table_exists, column_exists = await conn.run_sync(check_column_exists)
                if not table_exists or column_exists:
                    return

                sqlalchemy_type = TYPE_MAPPING.get(column_type)
                if not sqlalchemy_type:
                    raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.")

                # Create the ALTER TABLE statement
                # Note: We need to get the dialect-specific type name
                dialect = engine.dialect
                type_impl = sqlalchemy_type()
                compiled_type = type_impl.compile(dialect=dialect)

                nullable_clause = "" if nullable else " NOT NULL"
                add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")

                await conn.execute(add_column_sql)

        except Exception as e:
            # If any error occurs during migration, log it but don't fail.
            # The table creation path will handle adding the column.
            logger.error(f"Error adding column {column_name} to table {table}: {e}")
@ -1,70 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Annotated, cast

from pydantic import Field

from llama_stack.core.storage.datatypes import (
    PostgresSqlStoreConfig,
    SqliteSqlStoreConfig,
    SqlStoreReference,
    StorageBackendConfig,
    StorageBackendType,
)

from .api import SqlStore

sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]

_SQLSTORE_BACKENDS: dict[str, StorageBackendConfig] = {}


SqlStoreConfig = Annotated[
    SqliteSqlStoreConfig | PostgresSqlStoreConfig,
    Field(discriminator="type"),
]


def get_pip_packages(store_config: dict | SqlStoreConfig) -> list[str]:
    """Get pip packages for a SQL store config, handling both dict and object cases."""
    if isinstance(store_config, dict):
        store_type = store_config.get("type")
        if store_type == StorageBackendType.SQL_SQLITE.value:
            return SqliteSqlStoreConfig.pip_packages()
        elif store_type == StorageBackendType.SQL_POSTGRES.value:
            return PostgresSqlStoreConfig.pip_packages()
        else:
            raise ValueError(f"Unknown SQL store type: {store_type}")
    else:
        return store_config.pip_packages()


def sqlstore_impl(reference: SqlStoreReference) -> SqlStore:
    backend_name = reference.backend

    backend_config = _SQLSTORE_BACKENDS.get(backend_name)
    if backend_config is None:
        raise ValueError(
            f"Unknown SQL store backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
        )

    if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
        from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl

        config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
        return SqlAlchemySqlStoreImpl(config)
    else:
        raise ValueError(f"Unknown sqlstore type {backend_config.type}")


def register_sqlstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
    """Register the set of available SQL store backends for reference resolution."""
    global _SQLSTORE_BACKENDS

    _SQLSTORE_BACKENDS.clear()
    for name, cfg in backends.items():
        _SQLSTORE_BACKENDS[name] = cfg
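
A minimal sketch of the intended flow: register named backends once at startup, then resolve references by name. The config field shown is an assumption and not verified against the datatypes module:

    # Sketch: registry population followed by reference resolution.
    register_sqlstore_backends({"default": SqliteSqlStoreConfig(db_path="/tmp/store.db")})  # db_path: assumed field
    store = sqlstore_impl(SqlStoreReference(backend="default"))  # plus whatever other fields the reference requires
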
@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
@ -1,148 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from enum import Enum
from typing import Any, cast

import httpx
from mcp import ClientSession, McpError
from mcp import types as mcp_types
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client

from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem
from llama_stack.apis.tools import (
    ListToolDefsResponse,
    ToolDef,
    ToolInvocationResult,
)
from llama_stack.core.datatypes import AuthenticationRequiredError
from llama_stack.log import get_logger
from llama_stack.providers.utils.tools.ttl_dict import TTLDict

logger = get_logger(__name__, category="tools")

protocol_cache = TTLDict(ttl_seconds=3600)


class MCPProtocol(Enum):
    UNKNOWN = 0
    STREAMABLE_HTTP = 1
    SSE = 2


@asynccontextmanager
async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
    # we use a ttl'd dict to cache the happy-path protocol for each endpoint,
    # but we always fall back to trying the other protocol if we cannot initialize the session
    connection_strategies = [MCPProtocol.STREAMABLE_HTTP, MCPProtocol.SSE]
    mcp_protocol = protocol_cache.get(endpoint, default=MCPProtocol.UNKNOWN)
    if mcp_protocol == MCPProtocol.SSE:
        connection_strategies = [MCPProtocol.SSE, MCPProtocol.STREAMABLE_HTTP]

    for i, strategy in enumerate(connection_strategies):
        try:
            client = streamablehttp_client
            if strategy == MCPProtocol.SSE:
                client = sse_client
            async with client(endpoint, headers=headers) as client_streams:
                async with ClientSession(read_stream=client_streams[0], write_stream=client_streams[1]) as session:
                    await session.initialize()
                    protocol_cache[endpoint] = strategy
                    yield session
                    return
        except* httpx.HTTPStatusError as eg:
            for exc in eg.exceptions:
                # mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
                # so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
                # `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
                err = cast(httpx.HTTPStatusError, exc)
                if err.response.status_code == 401:
                    raise AuthenticationRequiredError(exc) from exc
            if i == len(connection_strategies) - 1:
                raise
        except* httpx.ConnectError as eg:
            # Connection refused, server down, network unreachable
            if i == len(connection_strategies) - 1:
                error_msg = f"Failed to connect to MCP server at {endpoint}: Connection refused"
                logger.error(f"MCP connection error: {error_msg}")
                raise ConnectionError(error_msg) from eg
            else:
                logger.warning(
                    f"failed to connect to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
                )
        except* httpx.TimeoutException as eg:
            # Request timeout, server too slow
            if i == len(connection_strategies) - 1:
                error_msg = f"MCP server at {endpoint} timed out"
                logger.error(f"MCP timeout error: {error_msg}")
                raise TimeoutError(error_msg) from eg
            else:
                logger.warning(
                    f"MCP server at {endpoint} timed out via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
                )
        except* httpx.RequestError as eg:
            # DNS resolution failures, network errors, invalid URLs
            if i == len(connection_strategies) - 1:
                # Use the first exception's message for the error string
                exc_msg = str(eg.exceptions[0]) if eg.exceptions else "Unknown error"
                error_msg = f"Network error connecting to MCP server at {endpoint}: {exc_msg}"
                logger.error(f"MCP network error: {error_msg}")
                raise ConnectionError(error_msg) from eg
            else:
                logger.warning(
                    f"network error connecting to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
                )
        except* McpError:
            if i < len(connection_strategies) - 1:
                logger.warning(
                    f"failed to connect via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
                )
            else:
                raise


async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
    tools = []
    async with client_wrapper(endpoint, headers) as session:
        tools_result = await session.list_tools()
        for tool in tools_result.tools:
            tools.append(
                ToolDef(
                    name=tool.name,
                    description=tool.description,
                    input_schema=tool.inputSchema,
                    output_schema=getattr(tool, "outputSchema", None),
                    metadata={
                        "endpoint": endpoint,
                    },
                )
            )
    return ListToolDefsResponse(data=tools)


async def invoke_mcp_tool(
    endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
) -> ToolInvocationResult:
    async with client_wrapper(endpoint, headers) as session:
        result = await session.call_tool(tool_name, kwargs)

        content: list[InterleavedContentItem] = []
        for item in result.content:
            if isinstance(item, mcp_types.TextContent):
                content.append(TextContentItem(text=item.text))
            elif isinstance(item, mcp_types.ImageContent):
                content.append(ImageContentItem(image=item.data))
            elif isinstance(item, mcp_types.EmbeddedResource):
                logger.warning(f"EmbeddedResource is not supported: {item}")
            else:
                raise ValueError(f"Unknown content type: {type(item)}")
        return ToolInvocationResult(
            content=content,
            error_code=1 if result.isError else 0,
        )
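
A short usage sketch for the two helpers above; the endpoint URL, headers, and tool arguments are placeholders:

    # Sketch: discover tools, then invoke one by name.
    defs = await list_mcp_tools("http://localhost:8000/mcp", headers={})
    result = await invoke_mcp_tool(
        "http://localhost:8000/mcp", headers={}, tool_name=defs.data[0].name, kwargs={"query": "weather"}
    )
    # result.error_code is 0 on success, 1 if the server reported a tool error.
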
@ -1,70 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import time
from threading import RLock
from typing import Any


class TTLDict(dict):
    """
    A dictionary with a TTL for each item.
    """

    def __init__(self, ttl_seconds: float, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ttl_seconds = ttl_seconds
        self._expires: dict[Any, Any] = {}  # _expires holds the time at which each item will expire
        self._lock = RLock()

        if args or kwargs:
            for k, v in self.items():
                self.__setitem__(k, v)

    def __delitem__(self, key):
        with self._lock:
            del self._expires[key]
            super().__delitem__(key)

    def __setitem__(self, key, value):
        with self._lock:
            self._expires[key] = time.monotonic() + self.ttl_seconds
            super().__setitem__(key, value)

    def _is_expired(self, key):
        if key not in self._expires:
            return False
        return time.monotonic() > self._expires[key]

    def __getitem__(self, key):
        with self._lock:
            if self._is_expired(key):
                del self._expires[key]
                super().__delitem__(key)
                raise KeyError(f"{key} has expired and was removed")

            return super().__getitem__(key)

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default

    def __contains__(self, key):
        try:
            _ = self[key]
            return True
        except KeyError:
            return False

    def __repr__(self):
        with self._lock:
            # iterate over a snapshot of the keys, since expired entries are
            # deleted from the dict as we go
            for key in list(self.keys()):
                if self._is_expired(key):
                    del self._expires[key]
                    super().__delitem__(key)
            return f"TTLDict({self.ttl_seconds}, {super().__repr__()})"
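
A quick behavioral sketch of the cache above (timings illustrative):

    # Sketch: entries silently disappear after ttl_seconds.
    cache = TTLDict(ttl_seconds=1.0)
    cache["proto"] = "streamable_http"
    cache.get("proto")             # -> "streamable_http"
    time.sleep(1.1)
    cache.get("proto", "unknown")  # -> "unknown"; the expired entry was evicted on access
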
@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
@ -1,156 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import hashlib
import re
import uuid


def generate_chunk_id(document_id: str, chunk_text: str, chunk_window: str | None = None) -> str:
    """
    Generate a unique chunk ID using a hash of the document ID and chunk text.
    Then use the first 32 characters of the hash to create a UUID.
    """
    hash_input = f"{document_id}:{chunk_text}".encode()
    if chunk_window:
        hash_input += f":{chunk_window}".encode()
    return str(uuid.UUID(hashlib.sha256(hash_input).hexdigest()[:32]))
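
# Example (sketch): chunk IDs are deterministic, so re-ingesting the same document
# yields the same IDs and upserts cleanly:
#   generate_chunk_id("doc-1", "hello") == generate_chunk_id("doc-1", "hello")  # True
# Adding a chunk_window (e.g. "0-5") produces a different but equally stable ID.
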
def proper_case(s: str) -> str:
    """Convert a string to proper case (first letter uppercase, rest lowercase)."""
    return s[0].upper() + s[1:].lower() if s else s


def sanitize_collection_name(name: str, weaviate_format=False) -> str:
    """
    Sanitize a collection name to ensure it only contains numbers, letters, and underscores.
    Any other characters are replaced with underscores.
    """
    if not weaviate_format:
        s = re.sub(r"[^a-zA-Z0-9_]", "_", name)
    else:
        s = proper_case(re.sub(r"[^a-zA-Z0-9]", "", name))
    return s
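
# Example (sketch):
#   sanitize_collection_name("my-collection!")                        -> "my_collection_"
#   sanitize_collection_name("my-collection!", weaviate_format=True)  -> "Mycollection"
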
class WeightedInMemoryAggregator:
    @staticmethod
    def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
        """
        Normalize scores to the 0-1 range using min-max normalization.

        Args:
            scores: dictionary of scores with document IDs as keys and scores as values

        Returns:
            Normalized scores with document IDs as keys and normalized scores as values
        """
        if not scores:
            return {}
        min_score, max_score = min(scores.values()), max(scores.values())
        score_range = max_score - min_score
        if score_range > 0:
            return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()}
        return dict.fromkeys(scores, 1.0)
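
    # Example (sketch): min-max normalization maps the span of observed scores onto [0, 1]:
    #   _normalize_scores({"d1": 2.0, "d2": 4.0, "d3": 6.0}) -> {"d1": 0.0, "d2": 0.5, "d3": 1.0}
    # When all scores are equal, every document gets 1.0.
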
    @staticmethod
    def weighted_rerank(
        vector_scores: dict[str, float],
        keyword_scores: dict[str, float],
        alpha: float = 0.5,
    ) -> dict[str, float]:
        """
        Rerank via a weighted average of scores.

        Args:
            vector_scores: scores from vector search
            keyword_scores: scores from keyword search
            alpha: weight factor between 0 and 1 (default: 0.5)
                0 = keyword only, 1 = vector only, 0.5 = equal weight

        Returns:
            All unique document IDs with weighted combined scores
        """
        all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
        normalized_vector_scores = WeightedInMemoryAggregator._normalize_scores(vector_scores)
        normalized_keyword_scores = WeightedInMemoryAggregator._normalize_scores(keyword_scores)

        # Weighted formula: score = (1-alpha) * keyword_score + alpha * vector_score
        # alpha=0 means keyword only, alpha=1 means vector only
        return {
            doc_id: ((1 - alpha) * normalized_keyword_scores.get(doc_id, 0.0))
            + (alpha * normalized_vector_scores.get(doc_id, 0.0))
            for doc_id in all_ids
        }
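
    # Example (sketch): with alpha=0.7 (vector-leaning) and normalized scores
    #   vector:  {"d1": 1.0, "d2": 0.0}
    #   keyword: {"d1": 0.0, "d2": 1.0}
    # the combined scores are d1 = 0.7*1.0 + 0.3*0.0 = 0.7 and d2 = 0.7*0.0 + 0.3*1.0 = 0.3.
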
    @staticmethod
    def rrf_rerank(
        vector_scores: dict[str, float],
        keyword_scores: dict[str, float],
        impact_factor: float = 60.0,
    ) -> dict[str, float]:
        """
        Rerank via Reciprocal Rank Fusion.

        Args:
            vector_scores: scores from vector search
            keyword_scores: scores from keyword search
            impact_factor: impact factor for RRF (default: 60.0)

        Returns:
            All unique document IDs with RRF combined scores
        """
        # Convert scores to ranks
        vector_ranks = {
            doc_id: i + 1
            for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True))
        }
        keyword_ranks = {
            doc_id: i + 1
            for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True))
        }

        all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
        rrf_scores = {}
        for doc_id in all_ids:
            vector_rank = vector_ranks.get(doc_id, float("inf"))
            keyword_rank = keyword_ranks.get(doc_id, float("inf"))

            # RRF formula: score = 1/(k + r) where k is impact_factor (default: 60.0) and r is the rank
            rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank))
        return rrf_scores
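
    # Example (sketch): with impact_factor k=60, a document ranked 1st by vector
    # search and 2nd by keyword search scores 1/(60+1) + 1/(60+2) ≈ 0.0325.
    # A document absent from one ranking gets rank infinity there, so that term
    # contributes 0 and the document is ranked purely by the other modality.
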
    @staticmethod
    def combine_search_results(
        vector_scores: dict[str, float],
        keyword_scores: dict[str, float],
        reranker_type: str = "rrf",
        reranker_params: dict[str, float] | None = None,
    ) -> dict[str, float]:
        """
        Combine vector and keyword search results using the specified reranking strategy.

        Args:
            vector_scores: scores from vector search
            keyword_scores: scores from keyword search
            reranker_type: type of reranker to use (default: "rrf")
            reranker_params: parameters for the reranker

        Returns:
            All unique document IDs with combined scores
        """
        if reranker_params is None:
            reranker_params = {}

        if reranker_type == "weighted":
            alpha = reranker_params.get("alpha", 0.5)
            return WeightedInMemoryAggregator.weighted_rerank(vector_scores, keyword_scores, alpha)
        else:
            # Default to RRF for None, "rrf", or any unknown types
            impact_factor = reranker_params.get("impact_factor", 60.0)
            return WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor)
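
    # Example (sketch): dispatching on reranker_type; unknown types fall back to RRF.
    #   combine_search_results(v_scores, k_scores, "weighted", {"alpha": 0.7})
    #   combine_search_results(v_scores, k_scores)  # defaults to RRF with impact_factor=60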