Merge branch 'main' into implement-search-for-PGVector

This commit is contained in:
Francisco Arceo 2025-08-28 10:20:25 -06:00 committed by GitHub
commit 4c03cddf6f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
176 changed files with 8344 additions and 734 deletions

View file

@ -29,12 +29,16 @@ class ListBatchesResponse(BaseModel):
@runtime_checkable
class Batches(Protocol):
"""Protocol for batch processing API operations.
"""
The Batches API enables efficient processing of multiple requests in a single operation,
particularly useful for processing large datasets, batch evaluation workflows, and
cost-effective inference at scale.
The API is designed to allow use of openai client libraries for seamless integration.
This API provides the following extensions:
- idempotent batch creation
Note: This API is currently under active development and may undergo changes.
"""
@ -45,6 +49,7 @@ class Batches(Protocol):
endpoint: str,
completion_window: Literal["24h"],
metadata: dict[str, str] | None = None,
idempotency_key: str | None = None,
) -> BatchObject:
"""Create a new batch for processing multiple API requests.
@ -52,6 +57,7 @@ class Batches(Protocol):
:param endpoint: The endpoint to be used for all requests in the batch.
:param completion_window: The time window within which the batch should be processed.
:param metadata: Optional metadata for the batch.
:param idempotency_key: Optional idempotency key. When provided, enables idempotent behavior.
:returns: The created batch object.
"""
...

View file

@ -473,6 +473,28 @@ class EmbeddingsResponse(BaseModel):
embeddings: list[list[float]]
@json_schema_type
class RerankData(BaseModel):
"""A single rerank result from a reranking response.
:param index: The original index of the document in the input list
:param relevance_score: The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance.
"""
index: int
relevance_score: float
@json_schema_type
class RerankResponse(BaseModel):
"""Response from a reranking request.
:param data: List of rerank result objects, sorted by relevance score (descending)
"""
data: list[RerankData]
@json_schema_type
class OpenAIChatCompletionContentPartTextParam(BaseModel):
"""Text content part for OpenAI-compatible chat completion messages.
@ -1046,6 +1068,7 @@ class InferenceProvider(Protocol):
:returns: A BatchCompletionResponse with the full completions.
"""
raise NotImplementedError("Batch completion is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
@webmethod(route="/inference/chat-completion", method="POST")
async def chat_completion(
@ -1110,6 +1133,7 @@ class InferenceProvider(Protocol):
:returns: A BatchChatCompletionResponse with the full completions.
"""
raise NotImplementedError("Batch chat completion is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
@webmethod(route="/inference/embeddings", method="POST")
async def embeddings(
@ -1131,6 +1155,25 @@ class InferenceProvider(Protocol):
"""
...
@webmethod(route="/inference/rerank", method="POST", experimental=True)
async def rerank(
self,
model: str,
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
max_num_results: int | None = None,
) -> RerankResponse:
"""Rerank a list of documents based on their relevance to a query.
:param model: The identifier of the reranking model to use.
:param query: The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length.
:param items: List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length.
:param max_num_results: (Optional) Maximum number of results to return. Default: returns all.
:returns: RerankResponse with indices sorted by relevance score (descending).
"""
raise NotImplementedError("Reranking is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
@webmethod(route="/openai/v1/completions", method="POST")
async def openai_completion(
self,

View file

@ -386,6 +386,7 @@ class MetricDataPoint(BaseModel):
timestamp: int
value: float
unit: str
@json_schema_type
@ -518,7 +519,7 @@ class Telemetry(Protocol):
metric_name: str,
start_time: int,
end_time: int | None = None,
granularity: str | None = "1d",
granularity: str | None = None,
query_type: MetricQueryType = MetricQueryType.RANGE,
label_matchers: list[MetricLabelMatcher] | None = None,
) -> QueryMetricsResponse:

View file

@ -15,7 +15,7 @@ from llama_stack.log import get_logger
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logger = get_logger(name=__name__, category="server")
logger = get_logger(name=__name__, category="cli")
class StackRun(Subcommand):

View file

@ -80,7 +80,7 @@ def get_provider_dependencies(
normal_deps = []
special_deps = []
for package in deps:
if "--no-deps" in package or "--index-url" in package:
if any(f in package for f in ["--no-deps", "--index-url", "--extra-index-url"]):
special_deps.append(package)
else:
normal_deps.append(package)

View file

@ -318,6 +318,41 @@ class QuotaConfig(BaseModel):
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
class CORSConfig(BaseModel):
allow_origins: list[str] = Field(default_factory=list)
allow_origin_regex: str | None = Field(default=None)
allow_methods: list[str] = Field(default=["OPTIONS"])
allow_headers: list[str] = Field(default_factory=list)
allow_credentials: bool = Field(default=False)
expose_headers: list[str] = Field(default_factory=list)
max_age: int = Field(default=600, ge=0)
@model_validator(mode="after")
def validate_credentials_config(self) -> Self:
if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins):
raise ValueError("Cannot use wildcard origins with credentials enabled")
return self
def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | None:
if cors_config is False or cors_config is None:
return None
if cors_config is True:
# dev mode: allow localhost on any port
return CORSConfig(
allow_origins=[],
allow_origin_regex=r"https?://localhost:\d+",
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
)
if isinstance(cors_config, CORSConfig):
return cors_config
raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}")
class ServerConfig(BaseModel):
port: int = Field(
default=8321,
@ -349,6 +384,12 @@ class ServerConfig(BaseModel):
default=None,
description="Per client quota request configuration",
)
cors: bool | CORSConfig | None = Field(
default=None,
description="CORS configuration for cross-origin requests. Can be:\n"
"- true: Enable localhost CORS for development\n"
"- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
)
class StackRunConfig(BaseModel):

View file

@ -146,39 +146,26 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
):
super().__init__()
self.async_client = AsyncLlamaStackAsLibraryClient(
config_path_or_distro_name, custom_provider_registry, provider_data
config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
)
self.pool_executor = ThreadPoolExecutor(max_workers=4)
self.skip_logger_removal = skip_logger_removal
self.provider_data = provider_data
self.loop = asyncio.new_event_loop()
def initialize(self):
if in_notebook():
import nest_asyncio
nest_asyncio.apply()
if not self.skip_logger_removal:
self._remove_root_logger_handlers()
# use a new event loop to avoid interfering with the main event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(self.async_client.initialize())
loop.run_until_complete(self.async_client.initialize())
finally:
asyncio.set_event_loop(None)
def _remove_root_logger_handlers(self):
def initialize(self):
"""
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
Deprecated method for backward compatibility.
"""
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
pass
def request(self, *args, **kwargs):
loop = self.loop
@ -216,6 +203,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
config_path_or_distro_name: str,
custom_provider_registry: ProviderRegistry | None = None,
provider_data: dict[str, Any] | None = None,
skip_logger_removal: bool = False,
):
super().__init__()
# when using the library client, we should not log to console since many
@ -223,6 +211,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
if in_notebook():
import nest_asyncio
nest_asyncio.apply()
if not skip_logger_removal:
self._remove_root_logger_handlers()
if config_path_or_distro_name.endswith(".yaml"):
config_path = Path(config_path_or_distro_name)
if not config_path.exists():
@ -239,7 +234,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
self.provider_data = provider_data
self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError
def _remove_root_logger_handlers(self):
"""
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
"""
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
async def initialize(self) -> bool:
"""
Initialize the async client.
Returns:
bool: True if initialization was successful
"""
try:
self.route_impls = None
self.impls = await construct_stack(self.config, self.custom_provider_registry)

View file

@ -12,7 +12,7 @@ from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routers")
class DatasetIORouter(DatasetIO):

View file

@ -16,7 +16,7 @@ from llama_stack.apis.scoring import (
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routers")
class ScoringRouter(Scoring):

View file

@ -65,7 +65,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus, Routin
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.telemetry.tracing import get_current_span
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="core::routers")
class InferenceRouter(Inference):

View file

@ -13,7 +13,7 @@ from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routers")
class SafetyRouter(Safety):

View file

@ -22,7 +22,7 @@ from llama_stack.log import get_logger
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routers")
class ToolRuntimeRouter(ToolRuntime):

View file

@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import (
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routers")
class VectorIORouter(VectorIO):

View file

@ -14,7 +14,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):

View file

@ -23,7 +23,7 @@ from llama_stack.core.store import DistributionRegistry
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, RoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
def get_impl_api(p: Any) -> Api:

View file

@ -26,7 +26,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):

View file

@ -17,7 +17,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class ModelsRoutingTable(CommonRoutingTableImpl, Models):

View file

@ -19,7 +19,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):

View file

@ -15,7 +15,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):

View file

@ -14,7 +14,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:

View file

@ -30,7 +30,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):

View file

@ -15,7 +15,7 @@ from llama_stack.core.server.auth_providers import create_auth_provider
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
logger = get_logger(name=__name__, category="core::auth")
class AuthenticationMiddleware:

View file

@ -23,7 +23,7 @@ from llama_stack.core.datatypes import (
)
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
logger = get_logger(name=__name__, category="core::auth")
class AuthResponse(BaseModel):

View file

@ -15,7 +15,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
logger = get_logger(name=__name__, category="quota")
logger = get_logger(name=__name__, category="core::server")
class QuotaMiddleware:

View file

@ -28,6 +28,7 @@ from aiohttp import hdrs
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError
@ -40,6 +41,7 @@ from llama_stack.core.datatypes import (
AuthenticationRequiredError,
LoggingConfig,
StackRunConfig,
process_cors_config,
)
from llama_stack.core.distribution import builtin_automatically_routed_apis
from llama_stack.core.external import ExternalApiSpec, load_external_apis
@ -82,7 +84,7 @@ from .quota import QuotaMiddleware
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logger = get_logger(name=__name__, category="server")
logger = get_logger(name=__name__, category="core::server")
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@ -413,7 +415,7 @@ def main(args: argparse.Namespace | None = None):
config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg)
logger = get_logger(name=__name__, category="server", config=logger_config)
logger = get_logger(name=__name__, category="core::server", config=logger_config)
if args.env:
for env_pair in args.env:
try:
@ -483,6 +485,12 @@ def main(args: argparse.Namespace | None = None):
window_seconds=window_seconds,
)
if config.server.cors:
logger.info("Enabling CORS")
cors_config = process_cors_config(config.server.cors)
if cors_config:
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
else:

View file

@ -225,7 +225,10 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
try:
result = re.sub(pattern, get_env_var, config)
return _convert_string_to_proper_type(result)
# Only apply type conversion if substitution actually happened
if result != config:
return _convert_string_to_proper_type(result)
return result
except EnvVarError as e:
raise EnvVarError(e.var_name, e.path) from None

View file

@ -16,7 +16,7 @@ from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
logger = get_logger(__name__, category="core")
logger = get_logger(__name__, category="core::registry")
class DistributionRegistry(Protocol):

View file

@ -10,7 +10,7 @@ from pathlib import Path
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="config_resolution")
logger = get_logger(name=__name__, category="core")
DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions"

View file

@ -34,7 +34,7 @@ distribution_spec:
telemetry:
- provider_type: inline::meta-reference
post_training:
- provider_type: inline::huggingface
- provider_type: inline::torchtune-cpu
eval:
- provider_type: inline::meta-reference
datasetio:

View file

@ -156,13 +156,10 @@ providers:
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
post_training:
- provider_id: huggingface
provider_type: inline::huggingface
- provider_id: torchtune-cpu
provider_type: inline::torchtune-cpu
config:
checkpoint_format: huggingface
distributed_backend: null
device: cpu
dpo_output_dir: ~/.llama/distributions/ci-tests/dpo_output
checkpoint_format: meta
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference

View file

@ -1,7 +1,7 @@
---
orphan: true
---
# Meta Reference Distribution
# Meta Reference GPU Distribution
```{toctree}
:maxdepth: 2
@ -29,7 +29,7 @@ The following environment variables can be configured:
## Prerequisite: Downloading Models
Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](../../references/llama_cli_reference/download_models.md) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
```
$ llama model list --downloaded

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .starter_gpu import get_distribution_template # noqa: F401

View file

@ -0,0 +1,59 @@
version: 2
distribution_spec:
description: Quick start template for running Llama Stack with several popular providers.
This distribution is intended for GPU-enabled environments.
providers:
inference:
- provider_type: remote::cerebras
- provider_type: remote::ollama
- provider_type: remote::vllm
- provider_type: remote::tgi
- provider_type: remote::fireworks
- provider_type: remote::together
- provider_type: remote::bedrock
- provider_type: remote::nvidia
- provider_type: remote::openai
- provider_type: remote::anthropic
- provider_type: remote::gemini
- provider_type: remote::vertexai
- provider_type: remote::groq
- provider_type: remote::sambanova
- provider_type: inline::sentence-transformers
vector_io:
- provider_type: inline::faiss
- provider_type: inline::sqlite-vec
- provider_type: inline::milvus
- provider_type: remote::chromadb
- provider_type: remote::pgvector
files:
- provider_type: inline::localfs
safety:
- provider_type: inline::llama-guard
- provider_type: inline::code-scanner
agents:
- provider_type: inline::meta-reference
telemetry:
- provider_type: inline::meta-reference
post_training:
- provider_type: inline::huggingface-gpu
eval:
- provider_type: inline::meta-reference
datasetio:
- provider_type: remote::huggingface
- provider_type: inline::localfs
scoring:
- provider_type: inline::basic
- provider_type: inline::llm-as-judge
- provider_type: inline::braintrust
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
batches:
- provider_type: inline::reference
image_type: venv
additional_pip_packages:
- aiosqlite
- asyncpg
- sqlalchemy[asyncio]

View file

@ -0,0 +1,241 @@
version: 2
image_name: starter-gpu
apis:
- agents
- batches
- datasetio
- eval
- files
- inference
- post_training
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
base_url: https://api.cerebras.ai
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: openai
provider_type: remote::openai
config:
api_key: ${env.OPENAI_API_KEY:=}
base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
- provider_id: anthropic
provider_type: remote::anthropic
config:
api_key: ${env.ANTHROPIC_API_KEY:=}
- provider_id: gemini
provider_type: remote::gemini
config:
api_key: ${env.GEMINI_API_KEY:=}
- provider_id: ${env.VERTEX_AI_PROJECT:+vertexai}
provider_type: remote::vertexai
config:
project: ${env.VERTEX_AI_PROJECT:=}
location: ${env.VERTEX_AI_LOCATION:=us-central1}
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db
- provider_id: sqlite-vec
provider_type: inline::sqlite-vec
config:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db
- provider_id: ${env.MILVUS_URL:+milvus}
provider_type: inline::milvus
config:
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db
- provider_id: ${env.CHROMADB_URL:+chromadb}
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db
- provider_id: ${env.PGVECTOR_DB:+pgvector}
provider_type: remote::pgvector
config:
host: ${env.PGVECTOR_HOST:=localhost}
port: ${env.PGVECTOR_PORT:=5432}
db: ${env.PGVECTOR_DB:=}
user: ${env.PGVECTOR_USER:=}
password: ${env.PGVECTOR_PASSWORD:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
- provider_id: code-scanner
provider_type: inline::code-scanner
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/agents_store.db
responses_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/responses_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
post_training:
- provider_id: huggingface-gpu
provider_type: inline::huggingface-gpu
config:
checkpoint_format: huggingface
distributed_backend: null
device: cpu
dpo_output_dir: ~/.llama/distributions/starter-gpu/dpo_output
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/meta_reference_eval.db
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/huggingface_datasetio.db
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/localfs_datasetio.db
scoring:
- provider_id: basic
provider_type: inline::basic
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:=}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
batches:
- provider_id: reference
provider_type: inline::reference
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/batches.db
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/registry.db
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/inference_store.db
models: []
shields:
- shield_id: llama-guard
provider_id: ${env.SAFETY_MODEL:+llama-guard}
provider_shield_id: ${env.SAFETY_MODEL:=}
- shield_id: code-scanner
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321

View file

@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.distributions.template import BuildProvider, DistributionTemplate
from ..starter.starter import get_distribution_template as get_starter_distribution_template
def get_distribution_template() -> DistributionTemplate:
template = get_starter_distribution_template()
name = "starter-gpu"
template.name = name
template.description = "Quick start template for running Llama Stack with several popular providers. This distribution is intended for GPU-enabled environments."
template.providers["post_training"] = [
BuildProvider(provider_type="inline::huggingface-gpu"),
]
return template

View file

@ -1,6 +1,7 @@
version: 2
distribution_spec:
description: Quick start template for running Llama Stack with several popular providers
description: Quick start template for running Llama Stack with several popular providers.
This distribution is intended for CPU-only environments.
providers:
inference:
- provider_type: remote::cerebras
@ -34,7 +35,7 @@ distribution_spec:
telemetry:
- provider_type: inline::meta-reference
post_training:
- provider_type: inline::huggingface
- provider_type: inline::torchtune-cpu
eval:
- provider_type: inline::meta-reference
datasetio:

View file

@ -156,13 +156,10 @@ providers:
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
post_training:
- provider_id: huggingface
provider_type: inline::huggingface
- provider_id: torchtune-cpu
provider_type: inline::torchtune-cpu
config:
checkpoint_format: huggingface
distributed_backend: null
device: cpu
dpo_output_dir: ~/.llama/distributions/starter/dpo_output
checkpoint_format: meta
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference

View file

@ -120,7 +120,7 @@ def get_distribution_template() -> DistributionTemplate:
],
"agents": [BuildProvider(provider_type="inline::meta-reference")],
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
"post_training": [BuildProvider(provider_type="inline::huggingface")],
"post_training": [BuildProvider(provider_type="inline::torchtune-cpu")],
"eval": [BuildProvider(provider_type="inline::meta-reference")],
"datasetio": [
BuildProvider(provider_type="remote::huggingface"),
@ -178,7 +178,7 @@ def get_distribution_template() -> DistributionTemplate:
return DistributionTemplate(
name=name,
distro_type="self_hosted",
description="Quick start template for running Llama Stack with several popular providers",
description="Quick start template for running Llama Stack with several popular providers. This distribution is intended for CPU-only environments.",
container_image=None,
template_path=None,
providers=providers,

View file

@ -36,7 +36,7 @@ from .utils import get_negative_inf_value, to_2tuple
MP_SCALE = 8
logger = get_logger(name=__name__, category="models")
logger = get_logger(name=__name__, category="models::llama")
def reduce_from_tensor_model_parallel_region(input_):

View file

@ -11,7 +11,7 @@ from llama_stack.log import get_logger
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="models::llama")
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")

View file

@ -18,7 +18,7 @@ from ...datatypes import QuantizationMode
from ..model import Transformer, TransformerBlock
from ..moe import MoE
log = get_logger(name=__name__, category="models")
log = get_logger(name=__name__, category="models::llama")
def swiglu_wrapper_no_reduce(

View file

@ -9,7 +9,7 @@ import collections
from llama_stack.log import get_logger
log = get_logger(name=__name__, category="llama")
log = get_logger(name=__name__, category="models::llama")
try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401

View file

@ -84,7 +84,7 @@ MEMORY_QUERY_TOOL = "knowledge_search"
WEB_SEARCH_TOOL = "web_search"
RAG_TOOL_GROUP = "builtin::rag"
logger = get_logger(name=__name__, category="agents")
logger = get_logger(name=__name__, category="agents::meta_reference")
class ChatAgent(ShieldRunnerMixin):

View file

@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig
from .persistence import AgentInfo
from .responses.openai_responses import OpenAIResponsesImpl
logger = get_logger(name=__name__, category="agents")
logger = get_logger(name=__name__, category="agents::meta_reference")
class MetaReferenceAgentsImpl(Agents):

View file

@ -17,7 +17,7 @@ from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore
log = get_logger(name=__name__, category="agents")
log = get_logger(name=__name__, category="agents::meta_reference")
class AgentSessionInfo(Session):

View file

@ -41,7 +41,7 @@ from .utils import (
convert_response_text_to_chat_response_format,
)
logger = get_logger(name=__name__, category="responses")
logger = get_logger(name=__name__, category="openai::responses")
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):

View file

@ -47,7 +47,7 @@ from llama_stack.log import get_logger
from .types import ChatCompletionContext, ChatCompletionResult
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
logger = get_logger(name=__name__, category="responses")
logger = get_logger(name=__name__, category="agents::meta_reference")
class StreamingResponseOrchestrator:

View file

@ -38,7 +38,7 @@ from llama_stack.log import get_logger
from .types import ChatCompletionContext, ToolExecutionResult
logger = get_logger(name=__name__, category="responses")
logger = get_logger(name=__name__, category="agents::meta_reference")
class ToolExecutor:

View file

@ -101,14 +101,22 @@ async def convert_response_input_to_chat_messages(
"""
messages: list[OpenAIMessageParam] = []
if isinstance(input, list):
# extract all OpenAIResponseInputFunctionToolCallOutput items
# so their corresponding OpenAIToolMessageParam instances can
# be added immediately following the corresponding
# OpenAIAssistantMessageParam
tool_call_results = {}
for input_item in input:
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
messages.append(
OpenAIToolMessageParam(
content=input_item.output,
tool_call_id=input_item.call_id,
)
tool_call_results[input_item.call_id] = OpenAIToolMessageParam(
content=input_item.output,
tool_call_id=input_item.call_id,
)
for input_item in input:
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
# skip as these have been extracted and inserted in order
pass
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
tool_call = OpenAIChatCompletionToolCall(
index=0,
@ -119,6 +127,9 @@ async def convert_response_input_to_chat_messages(
),
)
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
if input_item.call_id in tool_call_results:
messages.append(tool_call_results[input_item.call_id])
del tool_call_results[input_item.call_id]
elif isinstance(input_item, OpenAIResponseOutputMessageMCPCall):
tool_call = OpenAIChatCompletionToolCall(
index=0,
@ -146,6 +157,10 @@ async def convert_response_input_to_chat_messages(
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
)
messages.append(message_type(content=content))
if len(tool_call_results):
raise ValueError(
f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
)
else:
messages.append(OpenAIUserMessageParam(content=input))
return messages

View file

@ -11,7 +11,7 @@ from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry import tracing
log = get_logger(name=__name__, category="agents")
log = get_logger(name=__name__, category="agents::meta_reference")
class SafetyException(Exception): # noqa: N818

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import asyncio
import hashlib
import itertools
import json
import time
@ -136,28 +137,45 @@ class ReferenceBatchesImpl(Batches):
endpoint: str,
completion_window: Literal["24h"],
metadata: dict[str, str] | None = None,
idempotency_key: str | None = None,
) -> BatchObject:
"""
Create a new batch for processing multiple API requests.
Error handling by levels -
0. Input param handling, results in 40x errors before processing, e.g.
- Wrong completion_window
- Invalid metadata types
- Unknown endpoint
-> no batch created
1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
- input_file_id missing
- invalid json in file
- missing custom_id, method, url, body
- invalid model
- streaming
-> batch created, validation sends to failed status
2. Processing errors, result in error_file_id entries, e.g.
- Any error returned from inference endpoint
-> batch created, goes to completed status
This implementation provides optional idempotency: when an idempotency key
(idempotency_key) is provided, a deterministic ID is generated based on the input
parameters. If a batch with the same parameters already exists, it will be
returned instead of creating a duplicate. Without an idempotency key,
each request creates a new batch with a unique ID.
Args:
input_file_id: The ID of an uploaded file containing requests for the batch.
endpoint: The endpoint to be used for all requests in the batch.
completion_window: The time window within which the batch should be processed.
metadata: Optional metadata for the batch.
idempotency_key: Optional idempotency key for enabling idempotent behavior.
Returns:
The created or existing batch object.
"""
# Error handling by levels -
# 0. Input param handling, results in 40x errors before processing, e.g.
# - Wrong completion_window
# - Invalid metadata types
# - Unknown endpoint
# -> no batch created
# 1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
# - input_file_id missing
# - invalid json in file
# - missing custom_id, method, url, body
# - invalid model
# - streaming
# -> batch created, validation sends to failed status
# 2. Processing errors, result in error_file_id entries, e.g.
# - Any error returned from inference endpoint
# -> batch created, goes to completed status
# TODO: set expiration time for garbage collection
if endpoint not in ["/v1/chat/completions"]:
@ -171,6 +189,35 @@ class ReferenceBatchesImpl(Batches):
)
batch_id = f"batch_{uuid.uuid4().hex[:16]}"
# For idempotent requests, use the idempotency key for the batch ID
# This ensures the same key always maps to the same batch ID,
# allowing us to detect parameter conflicts
if idempotency_key is not None:
hash_input = idempotency_key.encode("utf-8")
hash_digest = hashlib.sha256(hash_input).hexdigest()[:24]
batch_id = f"batch_{hash_digest}"
try:
existing_batch = await self.retrieve_batch(batch_id)
if (
existing_batch.input_file_id != input_file_id
or existing_batch.endpoint != endpoint
or existing_batch.completion_window != completion_window
or existing_batch.metadata != metadata
):
raise ConflictError(
f"Idempotency key '{idempotency_key}' was previously used with different parameters. "
"Either use a new idempotency key or ensure all parameters match the original request."
)
logger.info(f"Returning existing batch with ID: {batch_id}")
return existing_batch
except ResourceNotFoundError:
# Batch doesn't exist, continue with creation
pass
current_time = int(time.time())
batch = BatchObject(
@ -185,6 +232,7 @@ class ReferenceBatchesImpl(Batches):
)
await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
logger.info(f"Created new batch with ID: {batch_id}")
if self.process_batches:
task = asyncio.create_task(self._process_batch(batch_id))

View file

@ -9,7 +9,6 @@ from collections.abc import AsyncGenerator
from llama_stack.apis.inference import (
CompletionResponse,
InferenceProvider,
InterleavedContent,
LogProbConfig,
Message,
ResponseFormat,
@ -100,25 +99,3 @@ class SentenceTransformersInferenceImpl(
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
raise ValueError("Sentence transformers don't support chat completion")
async def batch_completion(
self,
model_id: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch completion is not supported for Sentence Transformers")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import datetime
import threading
from typing import Any
@ -145,11 +146,41 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
metric_name: str,
start_time: int,
end_time: int | None = None,
granularity: str | None = "1d",
granularity: str | None = None,
query_type: MetricQueryType = MetricQueryType.RANGE,
label_matchers: list[MetricLabelMatcher] | None = None,
) -> QueryMetricsResponse:
raise NotImplementedError("Querying metrics is not implemented")
"""Query metrics from the telemetry store.
Args:
metric_name: The name of the metric to query (e.g., "prompt_tokens")
start_time: Start time as Unix timestamp
end_time: End time as Unix timestamp (defaults to now if None)
granularity: Time granularity for aggregation
query_type: Type of query (RANGE or INSTANT)
label_matchers: Label filters to apply
Returns:
QueryMetricsResponse with metric time series data
"""
# Convert timestamps to datetime objects
start_dt = datetime.datetime.fromtimestamp(start_time, datetime.UTC)
end_dt = datetime.datetime.fromtimestamp(end_time, datetime.UTC) if end_time else None
# Use SQLite trace store if available
if hasattr(self, "trace_store") and self.trace_store:
return await self.trace_store.query_metrics(
metric_name=metric_name,
start_time=start_dt,
end_time=end_dt,
granularity=granularity,
query_type=query_type,
label_matchers=label_matchers,
)
else:
raise ValueError(
f"In order to query_metrics, you must have {TelemetrySink.SQLITE} set in your telemetry sinks"
)
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:

View file

@ -5,9 +5,11 @@
# the root directory of this source tree.
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
)
from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
@ -23,4 +25,14 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
description="Local filesystem-based file storage provider for managing files and documents locally.",
),
remote_provider_spec(
api=Api.files,
adapter=AdapterSpec(
adapter_type="s3",
pip_packages=["boto3"] + sql_store_pip_packages,
module="llama_stack.providers.remote.files.s3",
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
),
),
]

View file

@ -40,8 +40,9 @@ def available_providers() -> list[ProviderSpec]:
InlineProviderSpec(
api=Api.inference,
provider_type="inline::sentence-transformers",
# CrossEncoder depends on torchao.quantization
pip_packages=[
"torch torchvision --index-url https://download.pytorch.org/whl/cpu",
"torch torchvision torchao>=0.12.0 --extra-index-url https://download.pytorch.org/whl/cpu",
"sentence-transformers --no-deps",
],
module="llama_stack.providers.inline.inference.sentence_transformers",

View file

@ -5,27 +5,50 @@
# the root directory of this source tree.
from typing import cast
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
# We provide two versions of these providers so that distributions can package the appropriate version of torch.
# The CPU version is used for distributions that don't have GPU support -- they result in smaller container images.
torchtune_def = dict(
api=Api.post_training,
pip_packages=["numpy"],
module="llama_stack.providers.inline.post_training.torchtune",
config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
],
description="TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.",
)
def available_providers() -> list[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.post_training,
provider_type="inline::torchtune",
pip_packages=["torch", "torchtune==0.5.0", "torchao==0.8.0", "numpy"],
module="llama_stack.providers.inline.post_training.torchtune",
config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
],
description="TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.",
**{ # type: ignore
**torchtune_def,
"provider_type": "inline::torchtune-cpu",
"pip_packages": (
cast(list[str], torchtune_def["pip_packages"])
+ ["torch torchtune>=0.5.0 torchao>=0.12.0 --extra-index-url https://download.pytorch.org/whl/cpu"]
),
},
),
InlineProviderSpec(
**{ # type: ignore
**torchtune_def,
"provider_type": "inline::torchtune-gpu",
"pip_packages": (
cast(list[str], torchtune_def["pip_packages"]) + ["torch torchtune>=0.5.0 torchao>=0.12.0"]
),
},
),
InlineProviderSpec(
api=Api.post_training,
provider_type="inline::huggingface",
pip_packages=["torch", "trl", "transformers", "peft", "datasets"],
provider_type="inline::huggingface-gpu",
pip_packages=["trl", "transformers", "peft", "datasets", "torch"],
module="llama_stack.providers.inline.post_training.huggingface",
config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig",
api_dependencies=[

View file

@ -0,0 +1,237 @@
# S3 Files Provider
A remote S3-based implementation of the Llama Stack Files API that provides scalable cloud file storage with metadata persistence.
## Features
- **AWS S3 Storage**: Store files in AWS S3 buckets for scalable, durable storage
- **Metadata Management**: Uses SQL database for efficient file metadata queries
- **OpenAI API Compatibility**: Full compatibility with OpenAI Files API endpoints
- **Flexible Authentication**: Support for IAM roles and access keys
- **Custom S3 Endpoints**: Support for MinIO and other S3-compatible services
## Configuration
### Basic Configuration
```yaml
api: files
provider_type: remote::s3
config:
bucket_name: my-llama-stack-files
region: us-east-1
metadata_store:
type: sqlite
db_path: ./s3_files_metadata.db
```
### Advanced Configuration
```yaml
api: files
provider_type: remote::s3
config:
bucket_name: my-llama-stack-files
region: us-east-1
aws_access_key_id: YOUR_ACCESS_KEY
aws_secret_access_key: YOUR_SECRET_KEY
endpoint_url: https://s3.amazonaws.com # Optional for custom endpoints
metadata_store:
type: sqlite
db_path: ./s3_files_metadata.db
```
### Environment Variables
The configuration supports environment variable substitution:
```yaml
config:
bucket_name: "${env.S3_BUCKET_NAME}"
region: "${env.AWS_REGION:=us-east-1}"
aws_access_key_id: "${env.AWS_ACCESS_KEY_ID:=}"
aws_secret_access_key: "${env.AWS_SECRET_ACCESS_KEY:=}"
endpoint_url: "${env.S3_ENDPOINT_URL:=}"
```
Note: `S3_BUCKET_NAME` has no default value since S3 bucket names must be globally unique.
## Authentication
### IAM Roles (Recommended)
For production deployments, use IAM roles:
```yaml
config:
bucket_name: my-bucket
region: us-east-1
# No credentials needed - will use IAM role
```
### Access Keys
For development or specific use cases:
```yaml
config:
bucket_name: my-bucket
region: us-east-1
aws_access_key_id: AKIAIOSFODNN7EXAMPLE
aws_secret_access_key: wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY
```
## S3 Bucket Setup
### Required Permissions
The S3 provider requires the following permissions:
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": [
"s3:GetObject",
"s3:PutObject",
"s3:DeleteObject",
"s3:ListBucket"
],
"Resource": [
"arn:aws:s3:::your-bucket-name",
"arn:aws:s3:::your-bucket-name/*"
]
}
]
}
```
### Automatic Bucket Creation
By default, the S3 provider expects the bucket to already exist. If you want the provider to automatically create the bucket when it doesn't exist, set `auto_create_bucket: true` in your configuration:
```yaml
config:
bucket_name: my-bucket
auto_create_bucket: true # Will create bucket if it doesn't exist
region: us-east-1
```
**Note**: When `auto_create_bucket` is enabled, the provider will need additional permissions:
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": [
"s3:GetObject",
"s3:PutObject",
"s3:DeleteObject",
"s3:ListBucket",
"s3:CreateBucket"
],
"Resource": [
"arn:aws:s3:::your-bucket-name",
"arn:aws:s3:::your-bucket-name/*"
]
}
]
}
```
### Bucket Policy (Optional)
For additional security, you can add a bucket policy:
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Sid": "LlamaStackAccess",
"Effect": "Allow",
"Principal": {
"AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
},
"Action": [
"s3:GetObject",
"s3:PutObject",
"s3:DeleteObject"
],
"Resource": "arn:aws:s3:::your-bucket-name/*"
},
{
"Sid": "LlamaStackBucketAccess",
"Effect": "Allow",
"Principal": {
"AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
},
"Action": [
"s3:ListBucket"
],
"Resource": "arn:aws:s3:::your-bucket-name"
}
]
}
```
## Features
### Metadata Persistence
File metadata is stored in a SQL database for fast queries and OpenAI API compatibility. The metadata includes:
- File ID
- Original filename
- Purpose (assistants, batch, etc.)
- File size in bytes
- Created and expiration timestamps
### TTL and Cleanup
Files currently have a fixed long expiration time (100 years).
## Development and Testing
### Using MinIO
For self-hosted S3-compatible storage:
```yaml
config:
bucket_name: test-bucket
region: us-east-1
endpoint_url: http://localhost:9000
aws_access_key_id: minioadmin
aws_secret_access_key: minioadmin
```
## Monitoring and Logging
The provider logs important operations and errors. For production deployments, consider:
- CloudWatch monitoring for S3 operations
- Custom metrics for file upload/download rates
- Error rate monitoring
- Performance metrics tracking
## Error Handling
The provider handles various error scenarios:
- S3 connectivity issues
- Bucket access permissions
- File not found errors
- Metadata consistency checks
## Known Limitations
- Fixed long TTL (100 years) instead of configurable expiration
- No server-side encryption enabled by default
- No support for AWS session tokens
- No S3 key prefix organization support
- No multipart upload support (all files uploaded as single objects)

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.core.datatypes import Api
from .config import S3FilesImplConfig
async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any]):
from .files import S3FilesImpl
# TODO: authorization policies and user separation
impl = S3FilesImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
class S3FilesImplConfig(BaseModel):
"""Configuration for S3-based files provider."""
bucket_name: str = Field(description="S3 bucket name to store files")
region: str = Field(default="us-east-1", description="AWS region where the bucket is located")
aws_access_key_id: str | None = Field(default=None, description="AWS access key ID (optional if using IAM roles)")
aws_secret_access_key: str | None = Field(
default=None, description="AWS secret access key (optional if using IAM roles)"
)
endpoint_url: str | None = Field(default=None, description="Custom S3 endpoint URL (for MinIO, LocalStack, etc.)")
auto_create_bucket: bool = Field(
default=False, description="Automatically create the S3 bucket if it doesn't exist"
)
metadata_store: SqlStoreConfig = Field(description="SQL store configuration for file metadata")
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"bucket_name": "${env.S3_BUCKET_NAME}", # no default, buckets must be globally unique
"region": "${env.AWS_REGION:=us-east-1}",
"aws_access_key_id": "${env.AWS_ACCESS_KEY_ID:=}",
"aws_secret_access_key": "${env.AWS_SECRET_ACCESS_KEY:=}",
"endpoint_url": "${env.S3_ENDPOINT_URL:=}",
"auto_create_bucket": "${env.S3_AUTO_CREATE_BUCKET:=false}",
"metadata_store": SqliteSqlStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="s3_files_metadata.db",
),
}

View file

@ -0,0 +1,272 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
import uuid
from typing import Annotated
import boto3
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
from fastapi import File, Form, Response, UploadFile
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import (
Files,
ListOpenAIFileResponse,
OpenAIFileDeleteResponse,
OpenAIFileObject,
OpenAIFilePurpose,
)
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl
from .config import S3FilesImplConfig
# TODO: provider data for S3 credentials
def _create_s3_client(config: S3FilesImplConfig) -> boto3.client:
try:
s3_config = {
"region_name": config.region,
}
# endpoint URL if specified (for MinIO, LocalStack, etc.)
if config.endpoint_url:
s3_config["endpoint_url"] = config.endpoint_url
if config.aws_access_key_id and config.aws_secret_access_key:
s3_config.update(
{
"aws_access_key_id": config.aws_access_key_id,
"aws_secret_access_key": config.aws_secret_access_key,
}
)
return boto3.client("s3", **s3_config)
except (BotoCoreError, NoCredentialsError) as e:
raise RuntimeError(f"Failed to initialize S3 client: {e}") from e
async def _create_bucket_if_not_exists(client: boto3.client, config: S3FilesImplConfig) -> None:
try:
client.head_bucket(Bucket=config.bucket_name)
except ClientError as e:
error_code = e.response["Error"]["Code"]
if error_code == "404":
if not config.auto_create_bucket:
raise RuntimeError(
f"S3 bucket '{config.bucket_name}' does not exist. "
f"Either create the bucket manually or set 'auto_create_bucket: true' in your configuration."
) from e
try:
# For us-east-1, we can't specify LocationConstraint
if config.region == "us-east-1":
client.create_bucket(Bucket=config.bucket_name)
else:
client.create_bucket(
Bucket=config.bucket_name,
CreateBucketConfiguration={"LocationConstraint": config.region},
)
except ClientError as create_error:
raise RuntimeError(
f"Failed to create S3 bucket '{config.bucket_name}': {create_error}"
) from create_error
elif error_code == "403":
raise RuntimeError(f"Access denied to S3 bucket '{config.bucket_name}'") from e
else:
raise RuntimeError(f"Failed to access S3 bucket '{config.bucket_name}': {e}") from e
class S3FilesImpl(Files):
"""S3-based implementation of the Files API."""
# TODO: implement expiration, for now a silly offset
_SILLY_EXPIRATION_OFFSET = 100 * 365 * 24 * 60 * 60
def __init__(self, config: S3FilesImplConfig) -> None:
self._config = config
self._client: boto3.client | None = None
self._sql_store: SqlStore | None = None
async def initialize(self) -> None:
self._client = _create_s3_client(self._config)
await _create_bucket_if_not_exists(self._client, self._config)
self._sql_store = sqlstore_impl(self._config.metadata_store)
await self._sql_store.create_table(
"openai_files",
{
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
"filename": ColumnType.STRING,
"purpose": ColumnType.STRING,
"bytes": ColumnType.INTEGER,
"created_at": ColumnType.INTEGER,
"expires_at": ColumnType.INTEGER,
# TODO: add s3_etag field for integrity checking
},
)
async def shutdown(self) -> None:
pass
@property
def client(self) -> boto3.client:
assert self._client is not None, "Provider not initialized"
return self._client
@property
def sql_store(self) -> SqlStore:
assert self._sql_store is not None, "Provider not initialized"
return self._sql_store
async def openai_upload_file(
self,
file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()],
) -> OpenAIFileObject:
file_id = f"file-{uuid.uuid4().hex}"
filename = getattr(file, "filename", None) or "uploaded_file"
created_at = int(time.time())
expires_at = created_at + self._SILLY_EXPIRATION_OFFSET
content = await file.read()
file_size = len(content)
await self.sql_store.insert(
"openai_files",
{
"id": file_id,
"filename": filename,
"purpose": purpose.value,
"bytes": file_size,
"created_at": created_at,
"expires_at": expires_at,
},
)
try:
self.client.put_object(
Bucket=self._config.bucket_name,
Key=file_id,
Body=content,
# TODO: enable server-side encryption
)
except ClientError as e:
await self.sql_store.delete("openai_files", where={"id": file_id})
raise RuntimeError(f"Failed to upload file to S3: {e}") from e
return OpenAIFileObject(
id=file_id,
filename=filename,
purpose=purpose,
bytes=file_size,
created_at=created_at,
expires_at=expires_at,
)
async def openai_list_files(
self,
after: str | None = None,
limit: int | None = 10000,
order: Order | None = Order.desc,
purpose: OpenAIFilePurpose | None = None,
) -> ListOpenAIFileResponse:
# this purely defensive. it should not happen because the router also default to Order.desc.
if not order:
order = Order.desc
where_conditions = {}
if purpose:
where_conditions["purpose"] = purpose.value
paginated_result = await self.sql_store.fetch_all(
table="openai_files",
where=where_conditions if where_conditions else None,
order_by=[("created_at", order.value)],
cursor=("id", after) if after else None,
limit=limit,
)
files = [
OpenAIFileObject(
id=row["id"],
filename=row["filename"],
purpose=OpenAIFilePurpose(row["purpose"]),
bytes=row["bytes"],
created_at=row["created_at"],
expires_at=row["expires_at"],
)
for row in paginated_result.data
]
return ListOpenAIFileResponse(
data=files,
has_more=paginated_result.has_more,
# empty string or None? spec says str, ref impl returns str | None, we go with spec
first_id=files[0].id if files else "",
last_id=files[-1].id if files else "",
)
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "files.list()")
return OpenAIFileObject(
id=row["id"],
filename=row["filename"],
purpose=OpenAIFilePurpose(row["purpose"]),
bytes=row["bytes"],
created_at=row["created_at"],
expires_at=row["expires_at"],
)
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "files.list()")
try:
self.client.delete_object(
Bucket=self._config.bucket_name,
Key=row["id"],
)
except ClientError as e:
if e.response["Error"]["Code"] != "NoSuchKey":
raise RuntimeError(f"Failed to delete file from S3: {e}") from e
await self.sql_store.delete("openai_files", where={"id": file_id})
return OpenAIFileDeleteResponse(id=file_id, deleted=True)
async def openai_retrieve_file_content(self, file_id: str) -> Response:
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "files.list()")
try:
response = self.client.get_object(
Bucket=self._config.bucket_name,
Key=row["id"],
)
# TODO: can we stream this instead of loading it into memory
content = response["Body"].read()
except ClientError as e:
if e.response["Error"]["Code"] == "NoSuchKey":
await self.sql_store.delete("openai_files", where={"id": file_id})
raise ResourceNotFoundError(file_id, "File", "files.list()") from e
raise RuntimeError(f"Failed to download file from S3: {e}") from e
return Response(
content=content,
media_type="application/octet-stream",
headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'},
)

View file

@ -65,7 +65,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import FireworksImplConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::fireworks")
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):

View file

@ -10,7 +10,7 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::llama_openai_compat")
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):

View file

@ -41,6 +41,11 @@ client.initialize()
### Create Completion
> Note on Completion API
>
> The hosted NVIDIA Llama NIMs (e.g., `meta-llama/Llama-3.1-8B-Instruct`) with ```NVIDIA_BASE_URL="https://integrate.api.nvidia.com"``` does not support the ```completion``` method, while the locally deployed NIM does.
```python
response = client.inference.completion(
model_id="meta-llama/Llama-3.1-8B-Instruct",
@ -76,6 +81,73 @@ response = client.inference.chat_completion(
print(f"Response: {response.completion_message.content}")
```
### Tool Calling Example ###
```python
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
tool_definition = ToolDefinition(
tool_name="get_weather",
description="Get current weather information for a location",
parameters={
"location": ToolParamDefinition(
param_type="string",
description="The city and state, e.g. San Francisco, CA",
required=True,
),
"unit": ToolParamDefinition(
param_type="string",
description="Temperature unit (celsius or fahrenheit)",
required=False,
default="celsius",
),
},
)
tool_response = client.inference.chat_completion(
model_id="meta-llama/Llama-3.1-8B-Instruct",
messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
tools=[tool_definition],
)
print(f"Tool Response: {tool_response.completion_message.content}")
if tool_response.completion_message.tool_calls:
for tool_call in tool_response.completion_message.tool_calls:
print(f"Tool Called: {tool_call.tool_name}")
print(f"Arguments: {tool_call.arguments}")
```
### Structured Output Example
```python
from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType
person_schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
"occupation": {"type": "string"},
},
"required": ["name", "age", "occupation"],
}
response_format = JsonSchemaResponseFormat(
type=ResponseFormatType.json_schema, json_schema=person_schema
)
structured_response = client.inference.chat_completion(
model_id="meta-llama/Llama-3.1-8B-Instruct",
messages=[
{
"role": "user",
"content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. ",
}
],
response_format=response_format,
)
print(f"Structured Response: {structured_response.completion_message.content}")
```
### Create Embeddings
> Note on OpenAI embeddings compatibility
>

View file

@ -7,7 +7,7 @@
import warnings
from collections.abc import AsyncIterator
from openai import NOT_GIVEN, APIConnectionError, BadRequestError
from openai import NOT_GIVEN, APIConnectionError
from llama_stack.apis.common.content_types import (
InterleavedContent,
@ -57,7 +57,7 @@ from .openai_utils import (
)
from .utils import _is_nvidia_hosted
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::nvidia")
class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
@ -197,15 +197,11 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
}
extra_body["input_type"] = task_type_options[task_type]
try:
response = await self.client.embeddings.create(
model=provider_model_id,
input=input,
extra_body=extra_body,
)
except BadRequestError as e:
raise ValueError(f"Failed to get embeddings: {e}") from e
response = await self.client.embeddings.create(
model=provider_model_id,
input=input,
extra_body=extra_body,
)
#
# OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...)
# ->

View file

@ -10,7 +10,7 @@ from llama_stack.log import get_logger
from . import NVIDIAConfig
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::nvidia")
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:

View file

@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::ollama")
class OllamaInferenceAdapter(
@ -619,28 +619,6 @@ class OllamaInferenceAdapter(
response.id = id
return response
async def batch_completion(
self,
model_id: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch completion is not supported for Ollama")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch chat completion is not supported for Ollama")
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
async def _convert_content(content) -> dict:

View file

@ -11,7 +11,7 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import OpenAIConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::openai")
#

View file

@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
log = get_logger(name=__name__, category="inference")
log = get_logger(name=__name__, category="inference::tgi")
def build_hf_repo_model_entries():

View file

@ -61,7 +61,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import TogetherImplConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::together")
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):

View file

@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import VLLMInferenceAdapterConfig
log = get_logger(name=__name__, category="inference")
log = get_logger(name=__name__, category="inference::vllm")
def build_hf_repo_model_entries():
@ -711,25 +711,3 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
user=user,
)
return await self.client.chat.completions.create(**params) # type: ignore
async def batch_completion(
self,
model_id: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch completion is not supported for Ollama")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch chat completion is not supported for Ollama")

View file

@ -15,7 +15,7 @@ from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefa
from .config import NvidiaPostTrainingConfig
logger = get_logger(name=__name__, category="integration")
logger = get_logger(name=__name__, category="post_training::nvidia")
def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:

View file

@ -21,7 +21,7 @@ from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from .config import BedrockSafetyConfig
logger = get_logger(name=__name__, category="safety")
logger = get_logger(name=__name__, category="safety::bedrock")
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):

View file

@ -9,7 +9,7 @@ from typing import Any
import requests
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.safety import ModerationObject, RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_
from .config import NVIDIASafetyConfig
logger = get_logger(name=__name__, category="safety")
logger = get_logger(name=__name__, category="safety::nvidia")
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
@ -67,6 +67,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
self.shield = NeMoGuardrails(self.config, shield.shield_id)
return await self.shield.run(messages)
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation")
class NeMoGuardrails:
"""

View file

@ -25,7 +25,7 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_
from .config import SambaNovaSafetyConfig
logger = get_logger(name=__name__, category="safety")
logger = get_logger(name=__name__, category="safety::sambanova")
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"

View file

@ -33,7 +33,7 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
log = get_logger(name=__name__, category="vector_io")
log = get_logger(name=__name__, category="vector_io::chroma")
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI

View file

@ -36,7 +36,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
logger = get_logger(name=__name__, category="vector_io")
logger = get_logger(name=__name__, category="vector_io::milvus")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"

View file

@ -39,7 +39,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryA
from .config import PGVectorVectorIOConfig
log = get_logger(name=__name__, category="vector_io")
log = get_logger(name=__name__, category="vector_io::pgvector")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"

View file

@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
log = get_logger(name=__name__, category="vector_io")
log = get_logger(name=__name__, category="vector_io::qdrant")
CHUNK_ID_KEY = "_chunk_id"
# KV store prefixes for vector databases

View file

@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
from .config import WeaviateVectorIOConfig
log = get_logger(name=__name__, category="vector_io")
log = get_logger(name=__name__, category="vector_io::weaviate")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"

View file

@ -28,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con
EMBEDDING_MODELS = {}
log = get_logger(name=__name__, category="inference")
log = get_logger(name=__name__, category="providers::utils")
class SentenceTransformerEmbeddingMixin:

View file

@ -54,7 +54,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="providers::utils")
class LiteLLMOpenAIMixin(
@ -429,28 +429,6 @@ class LiteLLMOpenAIMixin(
)
return await litellm.acompletion(**params)
async def batch_completion(
self,
model_id: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch completion is not supported for OpenAI Compat")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat")
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available via LiteLLM for the current

View file

@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference import (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
)
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="providers::utils")
class RemoteInferenceProviderConfig(BaseModel):

View file

@ -134,7 +134,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
decode_assistant_message,
)
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="providers::utils")
class OpenAICompatCompletionChoiceDelta(BaseModel):

View file

@ -25,7 +25,7 @@ from llama_stack.apis.inference import (
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="providers::utils")
class OpenAIMixin(ABC):

View file

@ -58,7 +58,7 @@ from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
from llama_stack.providers.utils.inference import supported_inference_models
log = get_logger(name=__name__, category="inference")
log = get_logger(name=__name__, category="providers::utils")
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):

View file

@ -13,7 +13,7 @@ from llama_stack.providers.utils.kvstore import KVStore
from ..config import MongoDBKVStoreConfig
log = get_logger(name=__name__, category="kvstore")
log = get_logger(name=__name__, category="providers::utils")
class MongoDBKVStoreImpl(KVStore):

View file

@ -14,7 +14,7 @@ from llama_stack.log import get_logger
from ..api import KVStore
from ..config import PostgresKVStoreConfig
log = get_logger(name=__name__, category="kvstore")
log = get_logger(name=__name__, category="providers::utils")
class PostgresKVStoreImpl(KVStore):

View file

@ -44,7 +44,7 @@ from llama_stack.providers.utils.memory.vector_store import (
make_overlapped_chunks,
)
logger = get_logger(name=__name__, category="memory")
logger = get_logger(name=__name__, category="providers::utils")
# Constants for OpenAI vector stores
CHUNK_MULTIPLIER = 5

View file

@ -33,7 +33,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
log = get_logger(name=__name__, category="memory")
log = get_logger(name=__name__, category="providers::utils")
class ChunkForDeletion(BaseModel):

View file

@ -17,7 +17,7 @@ from pydantic import BaseModel
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="scheduler")
logger = get_logger(name=__name__, category="providers::utils")
# TODO: revisit the list of possible statuses when defining a more coherent

View file

@ -17,7 +17,7 @@ from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
from .sqlstore import SqlStoreType
logger = get_logger(name=__name__, category="authorized_sqlstore")
logger = get_logger(name=__name__, category="providers::utils")
# Hardcoded copy of the default policy that our SQL filtering implements
# WARNING: If default_policy() changes, this constant must be updated accordingly

View file

@ -30,7 +30,7 @@ from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, SqlStore
from .sqlstore import SqlAlchemySqlStoreConfig
logger = get_logger(name=__name__, category="sqlstore")
logger = get_logger(name=__name__, category="providers::utils")
TYPE_MAPPING: dict[ColumnType, Any] = {
ColumnType.INTEGER: Integer,

View file

@ -5,12 +5,23 @@
# the root directory of this source tree.
import json
from datetime import datetime
from datetime import UTC, datetime
from typing import Protocol
import aiosqlite
from llama_stack.apis.telemetry import QueryCondition, Span, SpanWithStatus, Trace
from llama_stack.apis.telemetry import (
MetricDataPoint,
MetricLabel,
MetricLabelMatcher,
MetricQueryType,
MetricSeries,
QueryCondition,
QueryMetricsResponse,
Span,
SpanWithStatus,
Trace,
)
class TraceStore(Protocol):
@ -29,11 +40,192 @@ class TraceStore(Protocol):
max_depth: int | None = None,
) -> dict[str, SpanWithStatus]: ...
async def query_metrics(
self,
metric_name: str,
start_time: datetime,
end_time: datetime | None = None,
granularity: str | None = "1d",
query_type: MetricQueryType = MetricQueryType.RANGE,
label_matchers: list[MetricLabelMatcher] | None = None,
) -> QueryMetricsResponse: ...
class SQLiteTraceStore(TraceStore):
def __init__(self, conn_string: str):
self.conn_string = conn_string
async def query_metrics(
self,
metric_name: str,
start_time: datetime,
end_time: datetime | None = None,
granularity: str | None = None,
query_type: MetricQueryType = MetricQueryType.RANGE,
label_matchers: list[MetricLabelMatcher] | None = None,
) -> QueryMetricsResponse:
if end_time is None:
end_time = datetime.now(UTC)
# Build base query
if query_type == MetricQueryType.INSTANT:
query = """
SELECT
se.name,
SUM(CAST(json_extract(se.attributes, '$.value') AS REAL)) as value,
json_extract(se.attributes, '$.unit') as unit,
se.attributes
FROM span_events se
WHERE se.name = ?
AND se.timestamp BETWEEN ? AND ?
"""
else:
if granularity:
time_format = self._get_time_format_for_granularity(granularity)
query = f"""
SELECT
se.name,
SUM(CAST(json_extract(se.attributes, '$.value') AS REAL)) as value,
json_extract(se.attributes, '$.unit') as unit,
se.attributes,
strftime('{time_format}', se.timestamp) as bucket_start
FROM span_events se
WHERE se.name = ?
AND se.timestamp BETWEEN ? AND ?
"""
else:
query = """
SELECT
se.name,
json_extract(se.attributes, '$.value') as value,
json_extract(se.attributes, '$.unit') as unit,
se.attributes,
se.timestamp
FROM span_events se
WHERE se.name = ?
AND se.timestamp BETWEEN ? AND ?
"""
params = [f"metric.{metric_name}", start_time.isoformat(), end_time.isoformat()]
# Labels that will be attached to the MetricSeries (preserve matcher labels)
all_labels: list[MetricLabel] = []
matcher_label_names = set()
if label_matchers:
for matcher in label_matchers:
json_path = f"$.{matcher.name}"
if matcher.operator == "=":
query += f" AND json_extract(se.attributes, '{json_path}') = ?"
params.append(matcher.value)
elif matcher.operator == "!=":
query += f" AND json_extract(se.attributes, '{json_path}') != ?"
params.append(matcher.value)
elif matcher.operator == "=~":
query += f" AND json_extract(se.attributes, '{json_path}') LIKE ?"
params.append(f"%{matcher.value}%")
elif matcher.operator == "!~":
query += f" AND json_extract(se.attributes, '{json_path}') NOT LIKE ?"
params.append(f"%{matcher.value}%")
# Preserve filter context in output
all_labels.append(MetricLabel(name=matcher.name, value=str(matcher.value)))
matcher_label_names.add(matcher.name)
# GROUP BY / ORDER BY logic
if query_type == MetricQueryType.RANGE and granularity:
group_time_format = self._get_time_format_for_granularity(granularity)
query += f" GROUP BY strftime('{group_time_format}', se.timestamp), json_extract(se.attributes, '$.unit')"
query += " ORDER BY bucket_start"
elif query_type == MetricQueryType.INSTANT:
query += " GROUP BY json_extract(se.attributes, '$.unit')"
else:
query += " ORDER BY se.timestamp"
# Execute query
async with aiosqlite.connect(self.conn_string) as conn:
conn.row_factory = aiosqlite.Row
async with conn.execute(query, params) as cursor:
rows = await cursor.fetchall()
if not rows:
return QueryMetricsResponse(data=[])
data_points = []
# We want to add attribute labels, but only those not already present as matcher labels.
attr_label_names = set()
for row in rows:
# Parse JSON attributes safely, if there are no attributes (weird), just don't add the labels to the result.
try:
attributes = json.loads(row["attributes"] or "{}")
except (TypeError, json.JSONDecodeError):
attributes = {}
value = row["value"]
unit = row["unit"] or ""
# Add labels from attributes without duplicating matcher labels, if we don't do this, there will be a lot of duplicate label in the result.
for k, v in attributes.items():
if k not in ["value", "unit"] and k not in matcher_label_names and k not in attr_label_names:
all_labels.append(MetricLabel(name=k, value=str(v)))
attr_label_names.add(k)
# Determine timestamp
if query_type == MetricQueryType.RANGE and granularity:
try:
bucket_start_raw = row["bucket_start"]
except KeyError as e:
raise ValueError(
"DB did not have a bucket_start time in row when using granularity, this indicates improper formatting"
) from e
# this value could also be there, but be NULL, I think.
if bucket_start_raw is None:
raise ValueError("bucket_start is None check time format and data")
bucket_start = datetime.fromisoformat(bucket_start_raw)
timestamp = int(bucket_start.timestamp())
elif query_type == MetricQueryType.INSTANT:
timestamp = int(datetime.now(UTC).timestamp())
else:
try:
timestamp_raw = row["timestamp"]
except KeyError as e:
raise ValueError(
"DB did not have a timestamp in row, this indicates improper formatting"
) from e
# this value could also be there, but be NULL, I think.
if timestamp_raw is None:
raise ValueError("timestamp is None check time format and data")
timestamp_iso = datetime.fromisoformat(timestamp_raw)
timestamp = int(timestamp_iso.timestamp())
data_points.append(
MetricDataPoint(
timestamp=timestamp,
value=value,
unit=unit,
)
)
metric_series = [MetricSeries(metric=metric_name, labels=all_labels, values=data_points)]
return QueryMetricsResponse(data=metric_series)
def _get_time_format_for_granularity(self, granularity: str | None) -> str:
"""Get the SQLite strftime format string for a given granularity.
Args:
granularity: Granularity string (e.g., "1m", "5m", "1h", "1d")
Returns:
SQLite strftime format string for the granularity
"""
if granularity is None:
raise ValueError("granularity cannot be None for this method - use separate logic for no aggregation")
if granularity.endswith("d"):
return "%Y-%m-%d 00:00:00"
elif granularity.endswith("h"):
return "%Y-%m-%d %H:00:00"
elif granularity.endswith("m"):
return "%Y-%m-%d %H:%M:00"
else:
return "%Y-%m-%d %H:%M:00" # Default to most granular which will give us the most timestamps.
async def query_traces(
self,
attribute_filters: list[QueryCondition] | None = None,

View file

@ -9,7 +9,6 @@ from __future__ import annotations # for forward references
import hashlib
import json
import os
import sqlite3
from collections.abc import Generator
from contextlib import contextmanager
from enum import StrEnum
@ -125,28 +124,13 @@ class ResponseStorage:
def __init__(self, test_dir: Path):
self.test_dir = test_dir
self.responses_dir = self.test_dir / "responses"
self.db_path = self.test_dir / "index.sqlite"
self._ensure_directories()
self._init_database()
def _ensure_directories(self):
self.test_dir.mkdir(parents=True, exist_ok=True)
self.responses_dir.mkdir(exist_ok=True)
def _init_database(self):
with sqlite3.connect(self.db_path) as conn:
conn.execute("""
CREATE TABLE IF NOT EXISTS recordings (
request_hash TEXT PRIMARY KEY,
response_file TEXT,
endpoint TEXT,
model TEXT,
timestamp TEXT,
is_streaming BOOLEAN
)
""")
def store_recording(self, request_hash: str, request: dict[str, Any], response: dict[str, Any]):
"""Store a request/response pair."""
# Generate unique response filename
@ -169,34 +153,9 @@ class ResponseStorage:
f.write("\n")
f.flush()
# Update SQLite index
with sqlite3.connect(self.db_path) as conn:
conn.execute(
"""
INSERT OR REPLACE INTO recordings
(request_hash, response_file, endpoint, model, timestamp, is_streaming)
VALUES (?, ?, ?, ?, datetime('now'), ?)
""",
(
request_hash,
response_file,
request.get("endpoint", ""),
request.get("model", ""),
response.get("is_streaming", False),
),
)
def find_recording(self, request_hash: str) -> dict[str, Any] | None:
"""Find a recorded response by request hash."""
with sqlite3.connect(self.db_path) as conn:
result = conn.execute(
"SELECT response_file FROM recordings WHERE request_hash = ?", (request_hash,)
).fetchone()
if not result:
return None
response_file = result[0]
response_file = f"{request_hash[:12]}.json"
response_path = self.responses_dir / response_file
if not response_path.exists():

View file

@ -0,0 +1,610 @@
import { describe, test, expect } from "@jest/globals";
// Extract the exact processChunk function implementation for testing
function createProcessChunk() {
return (chunk: unknown): { text: string | null; isToolCall: boolean } => {
const chunkObj = chunk as Record<string, unknown>;
// Helper function to check if content contains function call JSON
const containsToolCall = (content: string): boolean => {
return (
content.includes('"type": "function"') ||
content.includes('"name": "knowledge_search"') ||
content.includes('"parameters":') ||
!!content.match(/\{"type":\s*"function".*?\}/)
);
};
// Check if this chunk contains a tool call (function call)
let isToolCall = false;
// Check direct chunk content if it's a string
if (typeof chunk === "string") {
isToolCall = containsToolCall(chunk);
}
// Check delta structures
if (
chunkObj?.delta &&
typeof chunkObj.delta === "object" &&
chunkObj.delta !== null
) {
const delta = chunkObj.delta as Record<string, unknown>;
if ("tool_calls" in delta) {
isToolCall = true;
}
if (typeof delta.text === "string") {
if (containsToolCall(delta.text)) {
isToolCall = true;
}
}
}
// Check event structures
if (
chunkObj?.event &&
typeof chunkObj.event === "object" &&
chunkObj.event !== null
) {
const event = chunkObj.event as Record<string, unknown>;
// Check event payload
if (
event?.payload &&
typeof event.payload === "object" &&
event.payload !== null
) {
const payload = event.payload as Record<string, unknown>;
if (typeof payload.content === "string") {
if (containsToolCall(payload.content)) {
isToolCall = true;
}
}
// Check payload delta
if (
payload?.delta &&
typeof payload.delta === "object" &&
payload.delta !== null
) {
const delta = payload.delta as Record<string, unknown>;
if (typeof delta.text === "string") {
if (containsToolCall(delta.text)) {
isToolCall = true;
}
}
}
}
// Check event delta
if (
event?.delta &&
typeof event.delta === "object" &&
event.delta !== null
) {
const delta = event.delta as Record<string, unknown>;
if (typeof delta.text === "string") {
if (containsToolCall(delta.text)) {
isToolCall = true;
}
}
if (typeof delta.content === "string") {
if (containsToolCall(delta.content)) {
isToolCall = true;
}
}
}
}
// if it's a tool call, skip it (don't display in chat)
if (isToolCall) {
return { text: null, isToolCall: true };
}
// Extract text content from various chunk formats
let text: string | null = null;
// Helper function to extract clean text content, filtering out function calls
const extractCleanText = (content: string): string | null => {
if (containsToolCall(content)) {
try {
// Try to parse and extract non-function call parts
const jsonMatch = content.match(
/\{"type":\s*"function"[^}]*\}[^}]*\}/
);
if (jsonMatch) {
const jsonPart = jsonMatch[0];
const parsedJson = JSON.parse(jsonPart);
// If it's a function call, extract text after JSON
if (parsedJson.type === "function") {
const textAfterJson = content
.substring(content.indexOf(jsonPart) + jsonPart.length)
.trim();
return textAfterJson || null;
}
}
// If we can't parse it properly, skip the whole thing
return null;
} catch {
return null;
}
}
return content;
};
// Try direct delta text
if (
chunkObj?.delta &&
typeof chunkObj.delta === "object" &&
chunkObj.delta !== null
) {
const delta = chunkObj.delta as Record<string, unknown>;
if (typeof delta.text === "string") {
text = extractCleanText(delta.text);
}
}
// Try event structures
if (
!text &&
chunkObj?.event &&
typeof chunkObj.event === "object" &&
chunkObj.event !== null
) {
const event = chunkObj.event as Record<string, unknown>;
// Try event payload content
if (
event?.payload &&
typeof event.payload === "object" &&
event.payload !== null
) {
const payload = event.payload as Record<string, unknown>;
// Try direct payload content
if (typeof payload.content === "string") {
text = extractCleanText(payload.content);
}
// Try turn_complete event structure: payload.turn.output_message.content
if (
!text &&
payload?.turn &&
typeof payload.turn === "object" &&
payload.turn !== null
) {
const turn = payload.turn as Record<string, unknown>;
if (
turn?.output_message &&
typeof turn.output_message === "object" &&
turn.output_message !== null
) {
const outputMessage = turn.output_message as Record<
string,
unknown
>;
if (typeof outputMessage.content === "string") {
text = extractCleanText(outputMessage.content);
}
}
// Fallback to model_response in steps if no output_message
if (
!text &&
turn?.steps &&
Array.isArray(turn.steps) &&
turn.steps.length > 0
) {
for (const step of turn.steps) {
if (step && typeof step === "object" && step !== null) {
const stepObj = step as Record<string, unknown>;
if (
stepObj?.model_response &&
typeof stepObj.model_response === "object" &&
stepObj.model_response !== null
) {
const modelResponse = stepObj.model_response as Record<
string,
unknown
>;
if (typeof modelResponse.content === "string") {
text = extractCleanText(modelResponse.content);
break;
}
}
}
}
}
}
// Try payload delta
if (
!text &&
payload?.delta &&
typeof payload.delta === "object" &&
payload.delta !== null
) {
const delta = payload.delta as Record<string, unknown>;
if (typeof delta.text === "string") {
text = extractCleanText(delta.text);
}
}
}
// Try event delta
if (
!text &&
event?.delta &&
typeof event.delta === "object" &&
event.delta !== null
) {
const delta = event.delta as Record<string, unknown>;
if (typeof delta.text === "string") {
text = extractCleanText(delta.text);
}
if (!text && typeof delta.content === "string") {
text = extractCleanText(delta.content);
}
}
}
// Try choices structure (ChatML format)
if (
!text &&
chunkObj?.choices &&
Array.isArray(chunkObj.choices) &&
chunkObj.choices.length > 0
) {
const choice = chunkObj.choices[0] as Record<string, unknown>;
if (
choice?.delta &&
typeof choice.delta === "object" &&
choice.delta !== null
) {
const delta = choice.delta as Record<string, unknown>;
if (typeof delta.content === "string") {
text = extractCleanText(delta.content);
}
}
}
// Try direct string content
if (!text && typeof chunk === "string") {
text = extractCleanText(chunk);
}
return { text, isToolCall: false };
};
}
describe("Chunk Processor", () => {
const processChunk = createProcessChunk();
describe("Real Event Structures", () => {
test("handles turn_complete event with cancellation policy response", () => {
const chunk = {
event: {
payload: {
event_type: "turn_complete",
turn: {
turn_id: "50a2d6b7-49ed-4d1e-b1c2-6d68b3f726db",
session_id: "e7f62b8e-518c-4450-82df-e65fe49f27a3",
input_messages: [
{
role: "user",
content: "nice, what's the cancellation policy?",
context: null,
},
],
steps: [
{
turn_id: "50a2d6b7-49ed-4d1e-b1c2-6d68b3f726db",
step_id: "54074310-af42-414c-9ffe-fba5b2ead0ad",
started_at: "2025-08-27T18:15:25.870703Z",
completed_at: "2025-08-27T18:15:51.288993Z",
step_type: "inference",
model_response: {
role: "assistant",
content:
"According to the search results, the cancellation policy for Red Hat Summit is as follows:\n\n* Cancellations must be received by 5 PM EDT on April 18, 2025 for a 50% refund of the registration fee.\n* No refunds will be given for cancellations received after 5 PM EDT on April 18, 2025.\n* Cancellation of travel reservations and hotel reservations are the responsibility of the registrant.",
stop_reason: "end_of_turn",
tool_calls: [],
},
},
],
output_message: {
role: "assistant",
content:
"According to the search results, the cancellation policy for Red Hat Summit is as follows:\n\n* Cancellations must be received by 5 PM EDT on April 18, 2025 for a 50% refund of the registration fee.\n* No refunds will be given for cancellations received after 5 PM EDT on April 18, 2025.\n* Cancellation of travel reservations and hotel reservations are the responsibility of the registrant.",
stop_reason: "end_of_turn",
tool_calls: [],
},
output_attachments: [],
started_at: "2025-08-27T18:15:25.868548Z",
completed_at: "2025-08-27T18:15:51.289262Z",
},
},
},
};
const result = processChunk(chunk);
expect(result.isToolCall).toBe(false);
expect(result.text).toContain(
"According to the search results, the cancellation policy for Red Hat Summit is as follows:"
);
expect(result.text).toContain("5 PM EDT on April 18, 2025");
});
test("handles turn_complete event with address response", () => {
const chunk = {
event: {
payload: {
event_type: "turn_complete",
turn: {
turn_id: "2f4a1520-8ecc-4cb7-bb7b-886939e042b0",
session_id: "e7f62b8e-518c-4450-82df-e65fe49f27a3",
input_messages: [
{
role: "user",
content: "what's francisco's address",
context: null,
},
],
steps: [
{
turn_id: "2f4a1520-8ecc-4cb7-bb7b-886939e042b0",
step_id: "c13dd277-1acb-4419-8fbf-d5e2f45392ea",
started_at: "2025-08-27T18:14:52.558761Z",
completed_at: "2025-08-27T18:15:11.306032Z",
step_type: "inference",
model_response: {
role: "assistant",
content:
"Francisco Arceo's address is:\n\nRed Hat\nUnited States\n17 Primrose Ln \nBasking Ridge New Jersey 07920",
stop_reason: "end_of_turn",
tool_calls: [],
},
},
],
output_message: {
role: "assistant",
content:
"Francisco Arceo's address is:\n\nRed Hat\nUnited States\n17 Primrose Ln \nBasking Ridge New Jersey 07920",
stop_reason: "end_of_turn",
tool_calls: [],
},
output_attachments: [],
started_at: "2025-08-27T18:14:52.553707Z",
completed_at: "2025-08-27T18:15:11.306729Z",
},
},
},
};
const result = processChunk(chunk);
expect(result.isToolCall).toBe(false);
expect(result.text).toContain("Francisco Arceo's address is:");
expect(result.text).toContain("17 Primrose Ln");
expect(result.text).toContain("Basking Ridge New Jersey 07920");
});
test("handles turn_complete event with ticket cost response", () => {
const chunk = {
event: {
payload: {
event_type: "turn_complete",
turn: {
turn_id: "7ef244a3-efee-42ca-a9c8-942865251002",
session_id: "e7f62b8e-518c-4450-82df-e65fe49f27a3",
input_messages: [
{
role: "user",
content: "what was the ticket cost for summit?",
context: null,
},
],
steps: [
{
turn_id: "7ef244a3-efee-42ca-a9c8-942865251002",
step_id: "7651dda0-315a-472d-b1c1-3c2725f55bc5",
started_at: "2025-08-27T18:14:21.710611Z",
completed_at: "2025-08-27T18:14:39.706452Z",
step_type: "inference",
model_response: {
role: "assistant",
content:
"The ticket cost for the Red Hat Summit was $999.00 for a conference pass.",
stop_reason: "end_of_turn",
tool_calls: [],
},
},
],
output_message: {
role: "assistant",
content:
"The ticket cost for the Red Hat Summit was $999.00 for a conference pass.",
stop_reason: "end_of_turn",
tool_calls: [],
},
output_attachments: [],
started_at: "2025-08-27T18:14:21.705289Z",
completed_at: "2025-08-27T18:14:39.706752Z",
},
},
},
};
const result = processChunk(chunk);
expect(result.isToolCall).toBe(false);
expect(result.text).toBe(
"The ticket cost for the Red Hat Summit was $999.00 for a conference pass."
);
});
});
describe("Function Call Detection", () => {
test("detects function calls in direct string chunks", () => {
const chunk =
'{"type": "function", "name": "knowledge_search", "parameters": {"query": "test"}}';
const result = processChunk(chunk);
expect(result.isToolCall).toBe(true);
expect(result.text).toBe(null);
});
test("detects function calls in event payload content", () => {
const chunk = {
event: {
payload: {
content:
'{"type": "function", "name": "knowledge_search", "parameters": {"query": "test"}}',
},
},
};
const result = processChunk(chunk);
expect(result.isToolCall).toBe(true);
expect(result.text).toBe(null);
});
test("detects tool_calls in delta structure", () => {
const chunk = {
delta: {
tool_calls: [{ function: { name: "knowledge_search" } }],
},
};
const result = processChunk(chunk);
expect(result.isToolCall).toBe(true);
expect(result.text).toBe(null);
});
test("detects function call in mixed content but skips it", () => {
const chunk =
'{"type": "function", "name": "knowledge_search", "parameters": {"query": "test"}} Based on the search results, here is your answer.';
const result = processChunk(chunk);
// This is detected as a tool call and skipped entirely - the implementation prioritizes safety
expect(result.isToolCall).toBe(true);
expect(result.text).toBe(null);
});
});
describe("Text Extraction", () => {
test("extracts text from direct string chunks", () => {
const chunk = "Hello, this is a normal response.";
const result = processChunk(chunk);
expect(result.isToolCall).toBe(false);
expect(result.text).toBe("Hello, this is a normal response.");
});
test("extracts text from delta structure", () => {
const chunk = {
delta: {
text: "Hello, this is a normal response.",
},
};
const result = processChunk(chunk);
expect(result.isToolCall).toBe(false);
expect(result.text).toBe("Hello, this is a normal response.");
});
test("extracts text from choices structure", () => {
const chunk = {
choices: [
{
delta: {
content: "Hello, this is a normal response.",
},
},
],
};
const result = processChunk(chunk);
expect(result.isToolCall).toBe(false);
expect(result.text).toBe("Hello, this is a normal response.");
});
test("prioritizes output_message over model_response in turn structure", () => {
const chunk = {
event: {
payload: {
turn: {
steps: [
{
model_response: {
content: "Model response content.",
},
},
],
output_message: {
content: "Final output message content.",
},
},
},
},
};
const result = processChunk(chunk);
expect(result.isToolCall).toBe(false);
expect(result.text).toBe("Final output message content.");
});
test("falls back to model_response when no output_message", () => {
const chunk = {
event: {
payload: {
turn: {
steps: [
{
model_response: {
content: "This is from the model response.",
},
},
],
},
},
},
};
const result = processChunk(chunk);
expect(result.isToolCall).toBe(false);
expect(result.text).toBe("This is from the model response.");
});
});
describe("Edge Cases", () => {
test("handles empty chunks", () => {
const result = processChunk("");
expect(result.isToolCall).toBe(false);
expect(result.text).toBe("");
});
test("handles null chunks", () => {
const result = processChunk(null);
expect(result.isToolCall).toBe(false);
expect(result.text).toBe(null);
});
test("handles undefined chunks", () => {
const result = processChunk(undefined);
expect(result.isToolCall).toBe(false);
expect(result.text).toBe(null);
});
test("handles chunks with no text content", () => {
const chunk = {
event: {
metadata: {
timestamp: "2024-01-01",
},
},
};
const result = processChunk(chunk);
expect(result.isToolCall).toBe(false);
expect(result.text).toBe(null);
});
test("handles malformed JSON in function calls gracefully", () => {
const chunk =
'{"type": "function", "name": "knowledge_search"} incomplete json';
const result = processChunk(chunk);
expect(result.isToolCall).toBe(true);
expect(result.text).toBe(null);
});
});
});

View file

@ -0,0 +1,790 @@
import React from "react";
import {
render,
screen,
fireEvent,
waitFor,
act,
} from "@testing-library/react";
import "@testing-library/jest-dom";
import ChatPlaygroundPage from "./page";
const mockClient = {
agents: {
list: jest.fn(),
create: jest.fn(),
retrieve: jest.fn(),
delete: jest.fn(),
session: {
list: jest.fn(),
create: jest.fn(),
delete: jest.fn(),
retrieve: jest.fn(),
},
turn: {
create: jest.fn(),
},
},
models: {
list: jest.fn(),
},
toolgroups: {
list: jest.fn(),
},
vectorDBs: {
list: jest.fn(),
},
};
jest.mock("@/hooks/use-auth-client", () => ({
useAuthClient: jest.fn(() => mockClient),
}));
jest.mock("@/components/chat-playground/chat", () => ({
Chat: jest.fn(
({
className,
messages,
handleSubmit,
input,
handleInputChange,
isGenerating,
append,
suggestions,
}) => (
<div data-testid="chat-component" className={className}>
<div data-testid="messages-count">{messages.length}</div>
<input
data-testid="chat-input"
value={input}
onChange={handleInputChange}
disabled={isGenerating}
/>
<button data-testid="submit-button" onClick={handleSubmit}>
Submit
</button>
{suggestions?.map((suggestion: string, index: number) => (
<button
key={index}
data-testid={`suggestion-${index}`}
onClick={() => append({ role: "user", content: suggestion })}
>
{suggestion}
</button>
))}
</div>
)
),
}));
jest.mock("@/components/chat-playground/conversations", () => ({
SessionManager: jest.fn(({ selectedAgentId, onNewSession }) => (
<div data-testid="session-manager">
{selectedAgentId && (
<>
<div data-testid="selected-agent">{selectedAgentId}</div>
<button data-testid="new-session-button" onClick={onNewSession}>
New Session
</button>
</>
)}
</div>
)),
SessionUtils: {
saveCurrentSessionId: jest.fn(),
loadCurrentSessionId: jest.fn(),
loadCurrentAgentId: jest.fn(),
saveCurrentAgentId: jest.fn(),
clearCurrentSession: jest.fn(),
saveSessionData: jest.fn(),
loadSessionData: jest.fn(),
saveAgentConfig: jest.fn(),
loadAgentConfig: jest.fn(),
clearAgentCache: jest.fn(),
createDefaultSession: jest.fn(() => ({
id: "test-session-123",
name: "Default Session",
messages: [],
selectedModel: "",
systemMessage: "You are a helpful assistant.",
agentId: "test-agent-123",
createdAt: Date.now(),
updatedAt: Date.now(),
})),
},
}));
const mockAgents = [
{
agent_id: "agent_123",
agent_config: {
name: "Test Agent",
instructions: "You are a test assistant.",
},
},
{
agent_id: "agent_456",
agent_config: {
agent_name: "Another Agent",
instructions: "You are another assistant.",
},
},
];
const mockModels = [
{
identifier: "test-model-1",
model_type: "llm",
},
{
identifier: "test-model-2",
model_type: "llm",
},
];
const mockToolgroups = [
{
identifier: "builtin::rag",
provider_id: "test-provider",
type: "tool_group",
provider_resource_id: "test-resource",
},
];
describe("ChatPlaygroundPage", () => {
beforeEach(() => {
jest.clearAllMocks();
Element.prototype.scrollIntoView = jest.fn();
mockClient.agents.list.mockResolvedValue({ data: mockAgents });
mockClient.models.list.mockResolvedValue(mockModels);
mockClient.toolgroups.list.mockResolvedValue(mockToolgroups);
mockClient.agents.session.create.mockResolvedValue({
session_id: "new-session-123",
});
mockClient.agents.session.list.mockResolvedValue({ data: [] });
mockClient.agents.session.retrieve.mockResolvedValue({
session_id: "test-session",
session_name: "Test Session",
started_at: new Date().toISOString(),
turns: [],
});
mockClient.agents.retrieve.mockResolvedValue({
agent_id: "test-agent",
agent_config: {
toolgroups: ["builtin::rag"],
instructions: "Test instructions",
model: "test-model",
},
});
mockClient.agents.delete.mockResolvedValue(undefined);
});
describe("Agent Selector Rendering", () => {
test("shows agent selector when agents are available", async () => {
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(screen.getByText("Agent Session:")).toBeInTheDocument();
expect(screen.getAllByRole("combobox")).toHaveLength(2);
expect(screen.getByText("+ New Agent")).toBeInTheDocument();
expect(screen.getByText("Clear Chat")).toBeInTheDocument();
});
});
test("does not show agent selector when no agents are available", async () => {
mockClient.agents.list.mockResolvedValue({ data: [] });
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(screen.queryByText("Agent Session:")).not.toBeInTheDocument();
expect(screen.getAllByRole("combobox")).toHaveLength(1);
expect(screen.getByText("+ New Agent")).toBeInTheDocument();
expect(screen.queryByText("Clear Chat")).not.toBeInTheDocument();
});
});
test("does not show agent selector while loading", async () => {
mockClient.agents.list.mockImplementation(() => new Promise(() => {}));
await act(async () => {
render(<ChatPlaygroundPage />);
});
expect(screen.queryByText("Agent Session:")).not.toBeInTheDocument();
expect(screen.getAllByRole("combobox")).toHaveLength(1);
expect(screen.getByText("+ New Agent")).toBeInTheDocument();
expect(screen.queryByText("Clear Chat")).not.toBeInTheDocument();
});
test("shows agent options in selector", async () => {
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
const agentCombobox = screen.getAllByRole("combobox").find(element => {
return (
element.textContent?.includes("Test Agent") ||
element.textContent?.includes("Select Agent")
);
});
expect(agentCombobox).toBeDefined();
fireEvent.click(agentCombobox!);
});
await waitFor(() => {
expect(screen.getAllByText("Test Agent")).toHaveLength(2);
expect(screen.getByText("Another Agent")).toBeInTheDocument();
});
});
test("displays agent ID when no name is available", async () => {
const agentWithoutName = {
agent_id: "agent_789",
agent_config: {
instructions: "You are an agent without a name.",
},
};
mockClient.agents.list.mockResolvedValue({ data: [agentWithoutName] });
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
const agentCombobox = screen.getAllByRole("combobox").find(element => {
return (
element.textContent?.includes("Agent agent_78") ||
element.textContent?.includes("Select Agent")
);
});
expect(agentCombobox).toBeDefined();
fireEvent.click(agentCombobox!);
});
await waitFor(() => {
expect(screen.getAllByText("Agent agent_78...")).toHaveLength(2);
});
});
});
describe("Agent Creation Modal", () => {
test("opens agent creation modal when + New Agent is clicked", async () => {
await act(async () => {
render(<ChatPlaygroundPage />);
});
const newAgentButton = screen.getByText("+ New Agent");
fireEvent.click(newAgentButton);
expect(screen.getByText("Create New Agent")).toBeInTheDocument();
expect(screen.getByText("Agent Name (optional)")).toBeInTheDocument();
expect(screen.getAllByText("Model")).toHaveLength(2);
expect(screen.getByText("System Instructions")).toBeInTheDocument();
expect(screen.getByText("Tools (optional)")).toBeInTheDocument();
});
test("closes modal when Cancel is clicked", async () => {
await act(async () => {
render(<ChatPlaygroundPage />);
});
const newAgentButton = screen.getByText("+ New Agent");
fireEvent.click(newAgentButton);
const cancelButton = screen.getByText("Cancel");
fireEvent.click(cancelButton);
expect(screen.queryByText("Create New Agent")).not.toBeInTheDocument();
});
test("creates agent when Create Agent is clicked", async () => {
mockClient.agents.create.mockResolvedValue({ agent_id: "new-agent-123" });
mockClient.agents.list
.mockResolvedValueOnce({ data: mockAgents })
.mockResolvedValueOnce({
data: [
...mockAgents,
{ agent_id: "new-agent-123", agent_config: { name: "New Agent" } },
],
});
await act(async () => {
render(<ChatPlaygroundPage />);
});
const newAgentButton = screen.getByText("+ New Agent");
await act(async () => {
fireEvent.click(newAgentButton);
});
await waitFor(() => {
expect(screen.getByText("Create New Agent")).toBeInTheDocument();
});
const nameInput = screen.getByPlaceholderText("My Custom Agent");
await act(async () => {
fireEvent.change(nameInput, { target: { value: "Test Agent Name" } });
});
const instructionsTextarea = screen.getByDisplayValue(
"You are a helpful assistant."
);
await act(async () => {
fireEvent.change(instructionsTextarea, {
target: { value: "Custom instructions" },
});
});
await waitFor(() => {
const modalModelSelectors = screen
.getAllByRole("combobox")
.filter(el => {
return (
el.textContent?.includes("Select Model") ||
el.closest('[class*="modal"]') ||
el.closest('[class*="card"]')
);
});
expect(modalModelSelectors.length).toBeGreaterThan(0);
});
const modalModelSelectors = screen.getAllByRole("combobox").filter(el => {
return (
el.textContent?.includes("Select Model") ||
el.closest('[class*="modal"]') ||
el.closest('[class*="card"]')
);
});
await act(async () => {
fireEvent.click(modalModelSelectors[0]);
});
await waitFor(() => {
const modelOptions = screen.getAllByText("test-model-1");
expect(modelOptions.length).toBeGreaterThan(0);
});
const modelOptions = screen.getAllByText("test-model-1");
const dropdownOption = modelOptions.find(
option =>
option.closest('[role="option"]') ||
option.id?.includes("radix") ||
option.getAttribute("aria-selected") !== null
);
await act(async () => {
fireEvent.click(
dropdownOption || modelOptions[modelOptions.length - 1]
);
});
await waitFor(() => {
const createButton = screen.getByText("Create Agent");
expect(createButton).not.toBeDisabled();
});
const createButton = screen.getByText("Create Agent");
await act(async () => {
fireEvent.click(createButton);
});
await waitFor(() => {
expect(mockClient.agents.create).toHaveBeenCalledWith({
agent_config: {
model: expect.any(String),
instructions: "Custom instructions",
name: "Test Agent Name",
enable_session_persistence: true,
},
});
});
await waitFor(() => {
expect(screen.queryByText("Create New Agent")).not.toBeInTheDocument();
});
});
});
describe("Agent Selection", () => {
test("creates default session when agent is selected", async () => {
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(mockClient.agents.session.create).toHaveBeenCalledWith(
"agent_123",
{ session_name: "Default Session" }
);
});
});
test("switches agent when different agent is selected", async () => {
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
const agentCombobox = screen.getAllByRole("combobox").find(element => {
return (
element.textContent?.includes("Test Agent") ||
element.textContent?.includes("Select Agent")
);
});
expect(agentCombobox).toBeDefined();
fireEvent.click(agentCombobox!);
});
await waitFor(() => {
const anotherAgentOption = screen.getByText("Another Agent");
fireEvent.click(anotherAgentOption);
});
expect(mockClient.agents.session.create).toHaveBeenCalledWith(
"agent_456",
{ session_name: "Default Session" }
);
});
});
describe("Agent Deletion", () => {
test("shows delete button when multiple agents exist", async () => {
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(screen.getByTitle("Delete current agent")).toBeInTheDocument();
});
});
test("shows delete button even when only one agent exists", async () => {
mockClient.agents.list.mockResolvedValue({
data: [mockAgents[0]],
});
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(screen.getByTitle("Delete current agent")).toBeInTheDocument();
});
});
test("deletes agent and switches to another when confirmed", async () => {
global.confirm = jest.fn(() => true);
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(screen.getByTitle("Delete current agent")).toBeInTheDocument();
});
mockClient.agents.delete.mockResolvedValue(undefined);
mockClient.agents.list.mockResolvedValueOnce({ data: mockAgents });
mockClient.agents.list.mockResolvedValueOnce({
data: [mockAgents[1]],
});
const deleteButton = screen.getByTitle("Delete current agent");
await act(async () => {
deleteButton.click();
});
await waitFor(() => {
expect(mockClient.agents.delete).toHaveBeenCalledWith("agent_123");
expect(global.confirm).toHaveBeenCalledWith(
"Are you sure you want to delete this agent? This action cannot be undone and will delete the agent and all its sessions."
);
});
(global.confirm as jest.Mock).mockRestore();
});
test("does not delete agent when cancelled", async () => {
global.confirm = jest.fn(() => false);
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(screen.getByTitle("Delete current agent")).toBeInTheDocument();
});
const deleteButton = screen.getByTitle("Delete current agent");
await act(async () => {
deleteButton.click();
});
await waitFor(() => {
expect(global.confirm).toHaveBeenCalled();
expect(mockClient.agents.delete).not.toHaveBeenCalled();
});
(global.confirm as jest.Mock).mockRestore();
});
});
describe("Error Handling", () => {
test("handles agent loading errors gracefully", async () => {
mockClient.agents.list.mockRejectedValue(
new Error("Failed to load agents")
);
const consoleSpy = jest
.spyOn(console, "error")
.mockImplementation(() => {});
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(consoleSpy).toHaveBeenCalledWith(
"Error fetching agents:",
expect.any(Error)
);
});
expect(screen.getByText("+ New Agent")).toBeInTheDocument();
consoleSpy.mockRestore();
});
test("handles model loading errors gracefully", async () => {
mockClient.models.list.mockRejectedValue(
new Error("Failed to load models")
);
const consoleSpy = jest
.spyOn(console, "error")
.mockImplementation(() => {});
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(consoleSpy).toHaveBeenCalledWith(
"Error fetching models:",
expect.any(Error)
);
});
consoleSpy.mockRestore();
});
});
describe("RAG File Upload", () => {
let mockFileReader: {
readAsDataURL: jest.Mock;
readAsText: jest.Mock;
result: string | null;
onload: (() => void) | null;
onerror: (() => void) | null;
};
let mockRAGTool: {
insert: jest.Mock;
};
beforeEach(() => {
mockFileReader = {
readAsDataURL: jest.fn(),
readAsText: jest.fn(),
result: null,
onload: null,
onerror: null,
};
global.FileReader = jest.fn(() => mockFileReader);
mockRAGTool = {
insert: jest.fn().mockResolvedValue({}),
};
mockClient.toolRuntime = {
ragTool: mockRAGTool,
};
});
afterEach(() => {
jest.clearAllMocks();
});
test("handles text file upload", async () => {
new File(["Hello, world!"], "test.txt", {
type: "text/plain",
});
mockClient.agents.retrieve.mockResolvedValue({
agent_id: "test-agent",
agent_config: {
toolgroups: [
{
name: "builtin::rag/knowledge_search",
args: { vector_db_ids: ["test-vector-db"] },
},
],
},
});
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(screen.getByTestId("chat-component")).toBeInTheDocument();
});
const chatComponent = screen.getByTestId("chat-component");
chatComponent.getAttribute("data-onragfileupload");
// this is a simplified test
expect(mockRAGTool.insert).not.toHaveBeenCalled();
});
test("handles PDF file upload with FileReader", async () => {
new File([new ArrayBuffer(1000)], "test.pdf", {
type: "application/pdf",
});
const mockDataURL = "data:application/pdf;base64,JVBERi0xLjQK";
mockFileReader.result = mockDataURL;
mockClient.agents.retrieve.mockResolvedValue({
agent_id: "test-agent",
agent_config: {
toolgroups: [
{
name: "builtin::rag/knowledge_search",
args: { vector_db_ids: ["test-vector-db"] },
},
],
},
});
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(screen.getByTestId("chat-component")).toBeInTheDocument();
});
expect(global.FileReader).toBeDefined();
});
test("handles different file types correctly", () => {
const getContentType = (filename: string): string => {
const ext = filename.toLowerCase().split(".").pop();
switch (ext) {
case "pdf":
return "application/pdf";
case "txt":
return "text/plain";
case "md":
return "text/markdown";
case "html":
return "text/html";
case "csv":
return "text/csv";
case "json":
return "application/json";
case "docx":
return "application/vnd.openxmlformats-officedocument.wordprocessingml.document";
case "doc":
return "application/msword";
default:
return "application/octet-stream";
}
};
expect(getContentType("test.pdf")).toBe("application/pdf");
expect(getContentType("test.txt")).toBe("text/plain");
expect(getContentType("test.md")).toBe("text/markdown");
expect(getContentType("test.html")).toBe("text/html");
expect(getContentType("test.csv")).toBe("text/csv");
expect(getContentType("test.json")).toBe("application/json");
expect(getContentType("test.docx")).toBe(
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
);
expect(getContentType("test.doc")).toBe("application/msword");
expect(getContentType("test.unknown")).toBe("application/octet-stream");
});
test("determines text vs binary file types correctly", () => {
const isTextFile = (mimeType: string): boolean => {
return (
mimeType.startsWith("text/") ||
mimeType === "application/json" ||
mimeType === "text/markdown" ||
mimeType === "text/html" ||
mimeType === "text/csv"
);
};
expect(isTextFile("text/plain")).toBe(true);
expect(isTextFile("text/markdown")).toBe(true);
expect(isTextFile("text/html")).toBe(true);
expect(isTextFile("text/csv")).toBe(true);
expect(isTextFile("application/json")).toBe(true);
expect(isTextFile("application/pdf")).toBe(false);
expect(isTextFile("application/msword")).toBe(false);
expect(
isTextFile(
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
)
).toBe(false);
expect(isTextFile("application/octet-stream")).toBe(false);
});
test("handles FileReader error gracefully", async () => {
const pdfFile = new File([new ArrayBuffer(1000)], "test.pdf", {
type: "application/pdf",
});
mockFileReader.onerror = jest.fn();
const mockError = new Error("FileReader failed");
const fileReaderPromise = new Promise<string>((resolve, reject) => {
const reader = new FileReader();
reader.onload = () => resolve(reader.result as string);
reader.onerror = () => reject(reader.error || mockError);
reader.readAsDataURL(pdfFile);
setTimeout(() => {
reader.onerror?.(new ProgressEvent("error"));
}, 0);
});
await expect(fileReaderPromise).rejects.toBeDefined();
});
test("handles large file upload with FileReader approach", () => {
// create a large file
const largeFile = new File(
[new ArrayBuffer(10 * 1024 * 1024)],
"large.pdf",
{
type: "application/pdf",
}
);
expect(largeFile.size).toBe(10 * 1024 * 1024); // 10MB
expect(global.FileReader).toBeDefined();
const reader = new FileReader();
expect(reader.readAsDataURL).toBeDefined();
});
});
});

File diff suppressed because it is too large Load diff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 25 KiB

View file

@ -120,3 +120,44 @@
@apply bg-background text-foreground;
}
}
@layer utilities {
.animate-typing-dot-1 {
animation: typing-dot-bounce-1 0.8s cubic-bezier(0.4, 0, 0.6, 1) infinite;
}
.animate-typing-dot-2 {
animation: typing-dot-bounce-2 0.8s cubic-bezier(0.4, 0, 0.6, 1) infinite;
}
.animate-typing-dot-3 {
animation: typing-dot-bounce-3 0.8s cubic-bezier(0.4, 0, 0.6, 1) infinite;
}
@keyframes typing-dot-bounce-1 {
0%, 15%, 85%, 100% {
transform: translateY(0);
}
7.5% {
transform: translateY(-6px);
}
}
@keyframes typing-dot-bounce-2 {
0%, 15%, 35%, 85%, 100% {
transform: translateY(0);
}
25% {
transform: translateY(-6px);
}
}
@keyframes typing-dot-bounce-3 {
0%, 35%, 55%, 85%, 100% {
transform: translateY(0);
}
45% {
transform: translateY(-6px);
}
}
}

View file

@ -18,6 +18,9 @@ const geistMono = Geist_Mono({
export const metadata: Metadata = {
title: "Llama Stack",
description: "Llama Stack UI",
icons: {
icon: "/favicon.ico",
},
};
import { SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";

Some files were not shown because too many files have changed in this diff Show more