Merge remote-tracking branch 'origin/main' into ai

This commit is contained in:
Ashwin Bharambe 2025-08-22 09:29:06 -07:00
commit 0b3218ed81
94 changed files with 4092 additions and 235 deletions

View file

@ -36,20 +36,16 @@ jobs:
**/requirements*.txt
.pre-commit-config.yaml
# npm ci may fail -
# npm error `npm ci` can only install packages when your package.json and package-lock.json or npm-shrinkwrap.json are in sync. Please update your lock file with `npm install` before continuing.
# npm error Invalid: lock file's llama-stack-client@0.2.17 does not satisfy llama-stack-client@0.2.18
- name: Set up Node.js
uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0
with:
node-version: '20'
cache: 'npm'
cache-dependency-path: 'llama_stack/ui/'
# - name: Set up Node.js
# uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0
# with:
# node-version: '20'
# cache: 'npm'
# cache-dependency-path: 'llama_stack/ui/'
# - name: Install npm dependencies
# run: npm ci
# working-directory: llama_stack/ui
- name: Install npm dependencies
run: npm ci
working-directory: llama_stack/ui
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
continue-on-error: true

View file

@ -146,31 +146,13 @@ repos:
pass_filenames: false
require_serial: true
files: ^.github/workflows/.*$
# ui-prettier and ui-eslint are disabled until we can avoid `npm ci`, which is slow and may fail -
# npm error `npm ci` can only install packages when your package.json and package-lock.json or npm-shrinkwrap.json are in sync. Please update your lock file with `npm install` before continuing.
# npm error Invalid: lock file's llama-stack-client@0.2.17 does not satisfy llama-stack-client@0.2.18
# and until we have infra for installing prettier and next via npm -
# Lint UI code with ESLint.....................................................Failed
# - hook id: ui-eslint
# - exit code: 127
# > ui@0.1.0 lint
# > next lint --fix --quiet
# sh: line 1: next: command not found
#
# - id: ui-prettier
# name: Format UI code with Prettier
# entry: bash -c 'cd llama_stack/ui && npm ci && npm run format'
# language: system
# files: ^llama_stack/ui/.*\.(ts|tsx)$
# pass_filenames: false
# require_serial: true
# - id: ui-eslint
# name: Lint UI code with ESLint
# entry: bash -c 'cd llama_stack/ui && npm run lint -- --fix --quiet'
# language: system
# files: ^llama_stack/ui/.*\.(ts|tsx)$
# pass_filenames: false
# require_serial: true
- id: ui-linter
name: Format & Lint UI
entry: bash ./scripts/run-ui-linter.sh
language: system
files: ^llama_stack/ui/.*\.(ts|tsx)$
pass_filenames: false
require_serial: true
- id: check-log-usage
name: Ensure 'llama_stack.log' usage for logging

View file

@ -4605,6 +4605,49 @@
}
}
},
"/v1/inference/rerank": {
"post": {
"responses": {
"200": {
"description": "RerankResponse with indices sorted by relevance score (descending).",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/RerankResponse"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Inference"
],
"description": "Rerank a list of documents based on their relevance to a query.",
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/RerankRequest"
}
}
},
"required": true
}
}
},
"/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": {
"post": {
"responses": {
@ -16587,6 +16630,95 @@
],
"title": "RegisterVectorDbRequest"
},
"RerankRequest": {
"type": "object",
"properties": {
"model": {
"type": "string",
"description": "The identifier of the reranking model to use."
},
"query": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
},
{
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
}
],
"description": "The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length."
},
"items": {
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
},
{
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
}
]
},
"description": "List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length."
},
"max_num_results": {
"type": "integer",
"description": "(Optional) Maximum number of results to return. Default: returns all."
}
},
"additionalProperties": false,
"required": [
"model",
"query",
"items"
],
"title": "RerankRequest"
},
"RerankData": {
"type": "object",
"properties": {
"index": {
"type": "integer",
"description": "The original index of the document in the input list"
},
"relevance_score": {
"type": "number",
"description": "The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance."
}
},
"additionalProperties": false,
"required": [
"index",
"relevance_score"
],
"title": "RerankData",
"description": "A single rerank result from a reranking response."
},
"RerankResponse": {
"type": "object",
"properties": {
"data": {
"type": "array",
"items": {
"$ref": "#/components/schemas/RerankData"
},
"description": "List of rerank result objects, sorted by relevance score (descending)"
}
},
"additionalProperties": false,
"required": [
"data"
],
"title": "RerankResponse",
"description": "Response from a reranking request."
},
"ResumeAgentTurnRequest": {
"type": "object",
"properties": {

View file

@ -3264,6 +3264,37 @@ paths:
schema:
$ref: '#/components/schemas/QueryTracesRequest'
required: true
/v1/inference/rerank:
post:
responses:
'200':
description: >-
RerankResponse with indices sorted by relevance score (descending).
content:
application/json:
schema:
$ref: '#/components/schemas/RerankResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Inference
description: >-
Rerank a list of documents based on their relevance to a query.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RerankRequest'
required: true
/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume:
post:
responses:
@ -12337,6 +12368,76 @@ components:
- vector_db_id
- embedding_model
title: RegisterVectorDbRequest
RerankRequest:
type: object
properties:
model:
type: string
description: >-
The identifier of the reranking model to use.
query:
oneOf:
- type: string
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
description: >-
The search query to rank items against. Can be a string, text content
part, or image content part. The input must not exceed the model's max
input token length.
items:
type: array
items:
oneOf:
- type: string
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
description: >-
List of items to rerank. Each item can be a string, text content part,
or image content part. Each input must not exceed the model's max input
token length.
max_num_results:
type: integer
description: >-
(Optional) Maximum number of results to return. Default: returns all.
additionalProperties: false
required:
- model
- query
- items
title: RerankRequest
RerankData:
type: object
properties:
index:
type: integer
description: >-
The original index of the document in the input list
relevance_score:
type: number
description: >-
The relevance score from the model output. Values are inverted when applicable
so that higher scores indicate greater relevance.
additionalProperties: false
required:
- index
- relevance_score
title: RerankData
description: >-
A single rerank result from a reranking response.
RerankResponse:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/RerankData'
description: >-
List of rerank result objects, sorted by relevance score (descending)
additionalProperties: false
required:
- data
title: RerankResponse
description: Response from a reranking request.
ResumeAgentTurnRequest:
type: object
properties:

View file

@ -10,4 +10,5 @@ This section contains documentation for all available providers for the **files*
:maxdepth: 1
inline_localfs
remote_s3
```

View file

@ -0,0 +1,33 @@
# remote::s3
## Description
AWS S3-based file storage provider for scalable cloud file management with metadata persistence.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `bucket_name` | `<class 'str'>` | No | | S3 bucket name to store files |
| `region` | `<class 'str'>` | No | us-east-1 | AWS region where the bucket is located |
| `aws_access_key_id` | `str \| None` | No | | AWS access key ID (optional if using IAM roles) |
| `aws_secret_access_key` | `str \| None` | No | | AWS secret access key (optional if using IAM roles) |
| `endpoint_url` | `str \| None` | No | | Custom S3 endpoint URL (for MinIO, LocalStack, etc.) |
| `auto_create_bucket` | `<class 'bool'>` | No | False | Automatically create the S3 bucket if it doesn't exist |
| `metadata_store` | `utils.sqlstore.sqlstore.SqliteSqlStoreConfig \| utils.sqlstore.sqlstore.PostgresSqlStoreConfig` | No | sqlite | SQL store configuration for file metadata |
## Sample Configuration
```yaml
bucket_name: ${env.S3_BUCKET_NAME}
region: ${env.AWS_REGION:=us-east-1}
aws_access_key_id: ${env.AWS_ACCESS_KEY_ID:=}
aws_secret_access_key: ${env.AWS_SECRET_ACCESS_KEY:=}
endpoint_url: ${env.S3_ENDPOINT_URL:=}
auto_create_bucket: ${env.S3_AUTO_CREATE_BUCKET:=false}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/s3_files_metadata.db
```

View file

@ -473,6 +473,28 @@ class EmbeddingsResponse(BaseModel):
embeddings: list[list[float]]
@json_schema_type
class RerankData(BaseModel):
"""A single rerank result from a reranking response.
:param index: The original index of the document in the input list
:param relevance_score: The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance.
"""
index: int
relevance_score: float
@json_schema_type
class RerankResponse(BaseModel):
"""Response from a reranking request.
:param data: List of rerank result objects, sorted by relevance score (descending)
"""
data: list[RerankData]
@json_schema_type
class OpenAIChatCompletionContentPartTextParam(BaseModel):
"""Text content part for OpenAI-compatible chat completion messages.
@ -1131,6 +1153,24 @@ class InferenceProvider(Protocol):
"""
...
@webmethod(route="/inference/rerank", method="POST", experimental=True)
async def rerank(
self,
model: str,
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
max_num_results: int | None = None,
) -> RerankResponse:
"""Rerank a list of documents based on their relevance to a query.
:param model: The identifier of the reranking model to use.
:param query: The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length.
:param items: List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length.
:param max_num_results: (Optional) Maximum number of results to return. Default: returns all.
:returns: RerankResponse with indices sorted by relevance score (descending).
"""
raise NotImplementedError("Reranking is not implemented")
@webmethod(route="/openai/v1/completions", method="POST")
async def openai_completion(
self,

View file

@ -15,7 +15,7 @@ from llama_stack.log import get_logger
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logger = get_logger(name=__name__, category="server")
logger = get_logger(name=__name__, category="cli")
class StackRun(Subcommand):

View file

@ -12,7 +12,7 @@ from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routers")
class DatasetIORouter(DatasetIO):

View file

@ -16,7 +16,7 @@ from llama_stack.apis.scoring import (
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routers")
class ScoringRouter(Scoring):

View file

@ -65,7 +65,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus, Routin
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.telemetry.tracing import get_current_span
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="core::routers")
class InferenceRouter(Inference):

View file

@ -13,7 +13,7 @@ from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routers")
class SafetyRouter(Safety):

View file

@ -22,7 +22,7 @@ from llama_stack.log import get_logger
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routers")
class ToolRuntimeRouter(ToolRuntime):

View file

@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import (
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routers")
class VectorIORouter(VectorIO):

View file

@ -14,7 +14,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):

View file

@ -23,7 +23,7 @@ from llama_stack.core.store import DistributionRegistry
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, RoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
def get_impl_api(p: Any) -> Api:

View file

@ -26,7 +26,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):

View file

@ -17,7 +17,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class ModelsRoutingTable(CommonRoutingTableImpl, Models):

View file

@ -19,7 +19,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):

View file

@ -15,7 +15,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):

View file

@ -14,7 +14,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:

View file

@ -30,7 +30,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):

View file

@ -15,7 +15,7 @@ from llama_stack.core.server.auth_providers import create_auth_provider
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
logger = get_logger(name=__name__, category="core::auth")
class AuthenticationMiddleware:

View file

@ -23,7 +23,7 @@ from llama_stack.core.datatypes import (
)
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
logger = get_logger(name=__name__, category="core::auth")
class AuthResponse(BaseModel):

View file

@ -15,7 +15,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
logger = get_logger(name=__name__, category="quota")
logger = get_logger(name=__name__, category="core::server")
class QuotaMiddleware:

View file

@ -84,7 +84,7 @@ from .quota import QuotaMiddleware
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logger = get_logger(name=__name__, category="server")
logger = get_logger(name=__name__, category="core::server")
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@ -415,7 +415,7 @@ def main(args: argparse.Namespace | None = None):
config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg)
logger = get_logger(name=__name__, category="server", config=logger_config)
logger = get_logger(name=__name__, category="core::server", config=logger_config)
if args.env:
for env_pair in args.env:
try:

View file

@ -16,7 +16,7 @@ from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
logger = get_logger(__name__, category="core")
logger = get_logger(__name__, category="core::registry")
class DistributionRegistry(Protocol):

View file

@ -10,7 +10,7 @@ from pathlib import Path
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="config_resolution")
logger = get_logger(name=__name__, category="core")
DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions"

View file

@ -36,7 +36,7 @@ from .utils import get_negative_inf_value, to_2tuple
MP_SCALE = 8
logger = get_logger(name=__name__, category="models")
logger = get_logger(name=__name__, category="models::llama")
def reduce_from_tensor_model_parallel_region(input_):

View file

@ -11,7 +11,7 @@ from llama_stack.log import get_logger
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="models::llama")
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")

View file

@ -18,7 +18,7 @@ from ...datatypes import QuantizationMode
from ..model import Transformer, TransformerBlock
from ..moe import MoE
log = get_logger(name=__name__, category="models")
log = get_logger(name=__name__, category="models::llama")
def swiglu_wrapper_no_reduce(

View file

@ -9,7 +9,7 @@ import collections
from llama_stack.log import get_logger
log = get_logger(name=__name__, category="llama")
log = get_logger(name=__name__, category="models::llama")
try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401

View file

@ -84,7 +84,7 @@ MEMORY_QUERY_TOOL = "knowledge_search"
WEB_SEARCH_TOOL = "web_search"
RAG_TOOL_GROUP = "builtin::rag"
logger = get_logger(name=__name__, category="agents")
logger = get_logger(name=__name__, category="agents::meta_reference")
class ChatAgent(ShieldRunnerMixin):

View file

@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig
from .persistence import AgentInfo
from .responses.openai_responses import OpenAIResponsesImpl
logger = get_logger(name=__name__, category="agents")
logger = get_logger(name=__name__, category="agents::meta_reference")
class MetaReferenceAgentsImpl(Agents):

View file

@ -17,7 +17,7 @@ from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore
log = get_logger(name=__name__, category="agents")
log = get_logger(name=__name__, category="agents::meta_reference")
class AgentSessionInfo(Session):

View file

@ -41,7 +41,7 @@ from .utils import (
convert_response_text_to_chat_response_format,
)
logger = get_logger(name=__name__, category="responses")
logger = get_logger(name=__name__, category="openai::responses")
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):

View file

@ -47,7 +47,7 @@ from llama_stack.log import get_logger
from .types import ChatCompletionContext, ChatCompletionResult
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
logger = get_logger(name=__name__, category="responses")
logger = get_logger(name=__name__, category="agents::meta_reference")
class StreamingResponseOrchestrator:

View file

@ -38,7 +38,7 @@ from llama_stack.log import get_logger
from .types import ChatCompletionContext, ToolExecutionResult
logger = get_logger(name=__name__, category="responses")
logger = get_logger(name=__name__, category="agents::meta_reference")
class ToolExecutor:

View file

@ -11,7 +11,7 @@ from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry import tracing
log = get_logger(name=__name__, category="agents")
log = get_logger(name=__name__, category="agents::meta_reference")
class SafetyException(Exception): # noqa: N818

View file

@ -33,6 +33,9 @@ from llama_stack.apis.inference import (
InterleavedContent,
LogProbConfig,
Message,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
RerankResponse,
ResponseFormat,
SamplingParams,
StopReason,
@ -442,6 +445,15 @@ class MetaReferenceInferenceImpl(
results = await self._nonstream_chat_completion(request_batch)
return BatchChatCompletionResponse(batch=results)
async def rerank(
self,
model: str,
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
max_num_results: int | None = None,
) -> RerankResponse:
raise NotImplementedError("Reranking is not supported for Meta Reference")
async def _nonstream_chat_completion(
self, request_batch: list[ChatCompletionRequest]
) -> list[ChatCompletionResponse]:

View file

@ -12,6 +12,9 @@ from llama_stack.apis.inference import (
InterleavedContent,
LogProbConfig,
Message,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
RerankResponse,
ResponseFormat,
SamplingParams,
ToolChoice,
@ -122,3 +125,12 @@ class SentenceTransformersInferenceImpl(
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")
async def rerank(
self,
model: str,
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
max_num_results: int | None = None,
) -> RerankResponse:
raise NotImplementedError("Reranking is not supported for Sentence Transformers")

View file

@ -5,9 +5,11 @@
# the root directory of this source tree.
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
)
from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
@ -23,4 +25,14 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
description="Local filesystem-based file storage provider for managing files and documents locally.",
),
remote_provider_spec(
api=Api.files,
adapter=AdapterSpec(
adapter_type="s3",
pip_packages=["boto3"] + sql_store_pip_packages,
module="llama_stack.providers.remote.files.s3",
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
),
),
]

View file

@ -0,0 +1,237 @@
# S3 Files Provider
A remote S3-based implementation of the Llama Stack Files API that provides scalable cloud file storage with metadata persistence.
## Features
- **AWS S3 Storage**: Store files in AWS S3 buckets for scalable, durable storage
- **Metadata Management**: Uses SQL database for efficient file metadata queries
- **OpenAI API Compatibility**: Full compatibility with OpenAI Files API endpoints
- **Flexible Authentication**: Support for IAM roles and access keys
- **Custom S3 Endpoints**: Support for MinIO and other S3-compatible services
## Configuration
### Basic Configuration
```yaml
api: files
provider_type: remote::s3
config:
bucket_name: my-llama-stack-files
region: us-east-1
metadata_store:
type: sqlite
db_path: ./s3_files_metadata.db
```
### Advanced Configuration
```yaml
api: files
provider_type: remote::s3
config:
bucket_name: my-llama-stack-files
region: us-east-1
aws_access_key_id: YOUR_ACCESS_KEY
aws_secret_access_key: YOUR_SECRET_KEY
endpoint_url: https://s3.amazonaws.com # Optional for custom endpoints
metadata_store:
type: sqlite
db_path: ./s3_files_metadata.db
```
### Environment Variables
The configuration supports environment variable substitution:
```yaml
config:
bucket_name: "${env.S3_BUCKET_NAME}"
region: "${env.AWS_REGION:=us-east-1}"
aws_access_key_id: "${env.AWS_ACCESS_KEY_ID:=}"
aws_secret_access_key: "${env.AWS_SECRET_ACCESS_KEY:=}"
endpoint_url: "${env.S3_ENDPOINT_URL:=}"
```
Note: `S3_BUCKET_NAME` has no default value since S3 bucket names must be globally unique.
## Authentication
### IAM Roles (Recommended)
For production deployments, use IAM roles:
```yaml
config:
bucket_name: my-bucket
region: us-east-1
# No credentials needed - will use IAM role
```
### Access Keys
For development or specific use cases:
```yaml
config:
bucket_name: my-bucket
region: us-east-1
aws_access_key_id: AKIAIOSFODNN7EXAMPLE
aws_secret_access_key: wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY
```
## S3 Bucket Setup
### Required Permissions
The S3 provider requires the following permissions:
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": [
"s3:GetObject",
"s3:PutObject",
"s3:DeleteObject",
"s3:ListBucket"
],
"Resource": [
"arn:aws:s3:::your-bucket-name",
"arn:aws:s3:::your-bucket-name/*"
]
}
]
}
```
### Automatic Bucket Creation
By default, the S3 provider expects the bucket to already exist. If you want the provider to automatically create the bucket when it doesn't exist, set `auto_create_bucket: true` in your configuration:
```yaml
config:
bucket_name: my-bucket
auto_create_bucket: true # Will create bucket if it doesn't exist
region: us-east-1
```
**Note**: When `auto_create_bucket` is enabled, the provider will need additional permissions:
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": [
"s3:GetObject",
"s3:PutObject",
"s3:DeleteObject",
"s3:ListBucket",
"s3:CreateBucket"
],
"Resource": [
"arn:aws:s3:::your-bucket-name",
"arn:aws:s3:::your-bucket-name/*"
]
}
]
}
```
### Bucket Policy (Optional)
For additional security, you can add a bucket policy:
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Sid": "LlamaStackAccess",
"Effect": "Allow",
"Principal": {
"AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
},
"Action": [
"s3:GetObject",
"s3:PutObject",
"s3:DeleteObject"
],
"Resource": "arn:aws:s3:::your-bucket-name/*"
},
{
"Sid": "LlamaStackBucketAccess",
"Effect": "Allow",
"Principal": {
"AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
},
"Action": [
"s3:ListBucket"
],
"Resource": "arn:aws:s3:::your-bucket-name"
}
]
}
```
## Features
### Metadata Persistence
File metadata is stored in a SQL database for fast queries and OpenAI API compatibility. The metadata includes:
- File ID
- Original filename
- Purpose (assistants, batch, etc.)
- File size in bytes
- Created and expiration timestamps
### TTL and Cleanup
Files currently have a fixed long expiration time (100 years).
## Development and Testing
### Using MinIO
For self-hosted S3-compatible storage:
```yaml
config:
bucket_name: test-bucket
region: us-east-1
endpoint_url: http://localhost:9000
aws_access_key_id: minioadmin
aws_secret_access_key: minioadmin
```
## Monitoring and Logging
The provider logs important operations and errors. For production deployments, consider:
- CloudWatch monitoring for S3 operations
- Custom metrics for file upload/download rates
- Error rate monitoring
- Performance metrics tracking
## Error Handling
The provider handles various error scenarios:
- S3 connectivity issues
- Bucket access permissions
- File not found errors
- Metadata consistency checks
## Known Limitations
- Fixed long TTL (100 years) instead of configurable expiration
- No server-side encryption enabled by default
- No support for AWS session tokens
- No S3 key prefix organization support
- No multipart upload support (all files uploaded as single objects)

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.core.datatypes import Api
from .config import S3FilesImplConfig
async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any]):
from .files import S3FilesImpl
# TODO: authorization policies and user separation
impl = S3FilesImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
class S3FilesImplConfig(BaseModel):
"""Configuration for S3-based files provider."""
bucket_name: str = Field(description="S3 bucket name to store files")
region: str = Field(default="us-east-1", description="AWS region where the bucket is located")
aws_access_key_id: str | None = Field(default=None, description="AWS access key ID (optional if using IAM roles)")
aws_secret_access_key: str | None = Field(
default=None, description="AWS secret access key (optional if using IAM roles)"
)
endpoint_url: str | None = Field(default=None, description="Custom S3 endpoint URL (for MinIO, LocalStack, etc.)")
auto_create_bucket: bool = Field(
default=False, description="Automatically create the S3 bucket if it doesn't exist"
)
metadata_store: SqlStoreConfig = Field(description="SQL store configuration for file metadata")
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"bucket_name": "${env.S3_BUCKET_NAME}", # no default, buckets must be globally unique
"region": "${env.AWS_REGION:=us-east-1}",
"aws_access_key_id": "${env.AWS_ACCESS_KEY_ID:=}",
"aws_secret_access_key": "${env.AWS_SECRET_ACCESS_KEY:=}",
"endpoint_url": "${env.S3_ENDPOINT_URL:=}",
"auto_create_bucket": "${env.S3_AUTO_CREATE_BUCKET:=false}",
"metadata_store": SqliteSqlStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="s3_files_metadata.db",
),
}

View file

@ -0,0 +1,272 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
import uuid
from typing import Annotated
import boto3
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
from fastapi import File, Form, Response, UploadFile
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import (
Files,
ListOpenAIFileResponse,
OpenAIFileDeleteResponse,
OpenAIFileObject,
OpenAIFilePurpose,
)
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl
from .config import S3FilesImplConfig
# TODO: provider data for S3 credentials
def _create_s3_client(config: S3FilesImplConfig) -> boto3.client:
try:
s3_config = {
"region_name": config.region,
}
# endpoint URL if specified (for MinIO, LocalStack, etc.)
if config.endpoint_url:
s3_config["endpoint_url"] = config.endpoint_url
if config.aws_access_key_id and config.aws_secret_access_key:
s3_config.update(
{
"aws_access_key_id": config.aws_access_key_id,
"aws_secret_access_key": config.aws_secret_access_key,
}
)
return boto3.client("s3", **s3_config)
except (BotoCoreError, NoCredentialsError) as e:
raise RuntimeError(f"Failed to initialize S3 client: {e}") from e
async def _create_bucket_if_not_exists(client: boto3.client, config: S3FilesImplConfig) -> None:
try:
client.head_bucket(Bucket=config.bucket_name)
except ClientError as e:
error_code = e.response["Error"]["Code"]
if error_code == "404":
if not config.auto_create_bucket:
raise RuntimeError(
f"S3 bucket '{config.bucket_name}' does not exist. "
f"Either create the bucket manually or set 'auto_create_bucket: true' in your configuration."
) from e
try:
# For us-east-1, we can't specify LocationConstraint
if config.region == "us-east-1":
client.create_bucket(Bucket=config.bucket_name)
else:
client.create_bucket(
Bucket=config.bucket_name,
CreateBucketConfiguration={"LocationConstraint": config.region},
)
except ClientError as create_error:
raise RuntimeError(
f"Failed to create S3 bucket '{config.bucket_name}': {create_error}"
) from create_error
elif error_code == "403":
raise RuntimeError(f"Access denied to S3 bucket '{config.bucket_name}'") from e
else:
raise RuntimeError(f"Failed to access S3 bucket '{config.bucket_name}': {e}") from e
class S3FilesImpl(Files):
"""S3-based implementation of the Files API."""
# TODO: implement expiration, for now a silly offset
_SILLY_EXPIRATION_OFFSET = 100 * 365 * 24 * 60 * 60
def __init__(self, config: S3FilesImplConfig) -> None:
self._config = config
self._client: boto3.client | None = None
self._sql_store: SqlStore | None = None
async def initialize(self) -> None:
self._client = _create_s3_client(self._config)
await _create_bucket_if_not_exists(self._client, self._config)
self._sql_store = sqlstore_impl(self._config.metadata_store)
await self._sql_store.create_table(
"openai_files",
{
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
"filename": ColumnType.STRING,
"purpose": ColumnType.STRING,
"bytes": ColumnType.INTEGER,
"created_at": ColumnType.INTEGER,
"expires_at": ColumnType.INTEGER,
# TODO: add s3_etag field for integrity checking
},
)
async def shutdown(self) -> None:
pass
@property
def client(self) -> boto3.client:
assert self._client is not None, "Provider not initialized"
return self._client
@property
def sql_store(self) -> SqlStore:
assert self._sql_store is not None, "Provider not initialized"
return self._sql_store
async def openai_upload_file(
self,
file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()],
) -> OpenAIFileObject:
file_id = f"file-{uuid.uuid4().hex}"
filename = getattr(file, "filename", None) or "uploaded_file"
created_at = int(time.time())
expires_at = created_at + self._SILLY_EXPIRATION_OFFSET
content = await file.read()
file_size = len(content)
await self.sql_store.insert(
"openai_files",
{
"id": file_id,
"filename": filename,
"purpose": purpose.value,
"bytes": file_size,
"created_at": created_at,
"expires_at": expires_at,
},
)
try:
self.client.put_object(
Bucket=self._config.bucket_name,
Key=file_id,
Body=content,
# TODO: enable server-side encryption
)
except ClientError as e:
await self.sql_store.delete("openai_files", where={"id": file_id})
raise RuntimeError(f"Failed to upload file to S3: {e}") from e
return OpenAIFileObject(
id=file_id,
filename=filename,
purpose=purpose,
bytes=file_size,
created_at=created_at,
expires_at=expires_at,
)
async def openai_list_files(
self,
after: str | None = None,
limit: int | None = 10000,
order: Order | None = Order.desc,
purpose: OpenAIFilePurpose | None = None,
) -> ListOpenAIFileResponse:
# this purely defensive. it should not happen because the router also default to Order.desc.
if not order:
order = Order.desc
where_conditions = {}
if purpose:
where_conditions["purpose"] = purpose.value
paginated_result = await self.sql_store.fetch_all(
table="openai_files",
where=where_conditions if where_conditions else None,
order_by=[("created_at", order.value)],
cursor=("id", after) if after else None,
limit=limit,
)
files = [
OpenAIFileObject(
id=row["id"],
filename=row["filename"],
purpose=OpenAIFilePurpose(row["purpose"]),
bytes=row["bytes"],
created_at=row["created_at"],
expires_at=row["expires_at"],
)
for row in paginated_result.data
]
return ListOpenAIFileResponse(
data=files,
has_more=paginated_result.has_more,
# empty string or None? spec says str, ref impl returns str | None, we go with spec
first_id=files[0].id if files else "",
last_id=files[-1].id if files else "",
)
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "files.list()")
return OpenAIFileObject(
id=row["id"],
filename=row["filename"],
purpose=OpenAIFilePurpose(row["purpose"]),
bytes=row["bytes"],
created_at=row["created_at"],
expires_at=row["expires_at"],
)
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "files.list()")
try:
self.client.delete_object(
Bucket=self._config.bucket_name,
Key=row["id"],
)
except ClientError as e:
if e.response["Error"]["Code"] != "NoSuchKey":
raise RuntimeError(f"Failed to delete file from S3: {e}") from e
await self.sql_store.delete("openai_files", where={"id": file_id})
return OpenAIFileDeleteResponse(id=file_id, deleted=True)
async def openai_retrieve_file_content(self, file_id: str) -> Response:
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "files.list()")
try:
response = self.client.get_object(
Bucket=self._config.bucket_name,
Key=row["id"],
)
# TODO: can we stream this instead of loading it into memory
content = response["Body"].read()
except ClientError as e:
if e.response["Error"]["Code"] == "NoSuchKey":
await self.sql_store.delete("openai_files", where={"id": file_id})
raise ResourceNotFoundError(file_id, "File", "files.list()") from e
raise RuntimeError(f"Failed to download file from S3: {e}") from e
return Response(
content=content,
media_type="application/octet-stream",
headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'},
)

View file

@ -65,7 +65,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import FireworksImplConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::fireworks")
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):

View file

@ -3,6 +3,11 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import (
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
RerankResponse,
)
from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
@ -10,7 +15,7 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::llama_openai_compat")
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
@ -54,3 +59,12 @@ class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
async def shutdown(self):
await super().shutdown()
async def rerank(
self,
model: str,
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
max_num_results: int | None = None,
) -> RerankResponse:
raise NotImplementedError("Reranking is not supported for Llama OpenAI Compat")

View file

@ -57,7 +57,7 @@ from .openai_utils import (
)
from .utils import _is_nvidia_hosted
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::nvidia")
class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):

View file

@ -10,7 +10,7 @@ from llama_stack.log import get_logger
from . import NVIDIAConfig
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::nvidia")
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:

View file

@ -37,11 +37,14 @@ from llama_stack.apis.inference import (
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
RerankResponse,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -85,7 +88,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::ollama")
class OllamaInferenceAdapter(
@ -641,6 +644,15 @@ class OllamaInferenceAdapter(
):
raise NotImplementedError("Batch chat completion is not supported for Ollama")
async def rerank(
self,
model: str,
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
max_num_results: int | None = None,
) -> RerankResponse:
raise NotImplementedError("Reranking is not supported for Ollama")
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
async def _convert_content(content) -> dict:

View file

@ -11,7 +11,7 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import OpenAIConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::openai")
#

View file

@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
log = get_logger(name=__name__, category="inference")
log = get_logger(name=__name__, category="inference::tgi")
def build_hf_repo_model_entries():

View file

@ -61,7 +61,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import TogetherImplConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::together")
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):

View file

@ -39,12 +39,15 @@ from llama_stack.apis.inference import (
Message,
ModelStore,
OpenAIChatCompletion,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAICompletion,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
RerankResponse,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -85,7 +88,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import VLLMInferenceAdapterConfig
log = get_logger(name=__name__, category="inference")
log = get_logger(name=__name__, category="inference::vllm")
def build_hf_repo_model_entries():
@ -732,4 +735,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch chat completion is not supported for Ollama")
raise NotImplementedError("Batch chat completion is not supported for vLLM")
async def rerank(
self,
model: str,
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
max_num_results: int | None = None,
) -> RerankResponse:
raise NotImplementedError("Reranking is not supported for vLLM")

View file

@ -15,7 +15,7 @@ from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefa
from .config import NvidiaPostTrainingConfig
logger = get_logger(name=__name__, category="integration")
logger = get_logger(name=__name__, category="post_training::nvidia")
def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:

View file

@ -21,7 +21,7 @@ from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from .config import BedrockSafetyConfig
logger = get_logger(name=__name__, category="safety")
logger = get_logger(name=__name__, category="safety::bedrock")
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):

View file

@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_
from .config import NVIDIASafetyConfig
logger = get_logger(name=__name__, category="safety")
logger = get_logger(name=__name__, category="safety::nvidia")
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):

View file

@ -25,7 +25,7 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_
from .config import SambaNovaSafetyConfig
logger = get_logger(name=__name__, category="safety")
logger = get_logger(name=__name__, category="safety::sambanova")
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"

View file

@ -33,7 +33,7 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
log = get_logger(name=__name__, category="vector_io")
log = get_logger(name=__name__, category="vector_io::chroma")
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI

View file

@ -36,7 +36,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
logger = get_logger(name=__name__, category="vector_io")
logger = get_logger(name=__name__, category="vector_io::milvus")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"

View file

@ -34,7 +34,7 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import PGVectorVectorIOConfig
log = get_logger(name=__name__, category="vector_io")
log = get_logger(name=__name__, category="vector_io::pgvector")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"

View file

@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
log = get_logger(name=__name__, category="vector_io")
log = get_logger(name=__name__, category="vector_io::qdrant")
CHUNK_ID_KEY = "_chunk_id"
# KV store prefixes for vector databases

View file

@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
from .config import WeaviateVectorIOConfig
log = get_logger(name=__name__, category="vector_io")
log = get_logger(name=__name__, category="vector_io::weaviate")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"

View file

@ -28,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con
EMBEDDING_MODELS = {}
log = get_logger(name=__name__, category="inference")
log = get_logger(name=__name__, category="providers::utils")
class SentenceTransformerEmbeddingMixin:

View file

@ -54,7 +54,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="providers::utils")
class LiteLLMOpenAIMixin(

View file

@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference import (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
)
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="providers::utils")
class RemoteInferenceProviderConfig(BaseModel):

View file

@ -134,7 +134,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
decode_assistant_message,
)
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="providers::utils")
class OpenAICompatCompletionChoiceDelta(BaseModel):

View file

@ -25,7 +25,7 @@ from llama_stack.apis.inference import (
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="providers::utils")
class OpenAIMixin(ABC):

View file

@ -58,7 +58,7 @@ from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
from llama_stack.providers.utils.inference import supported_inference_models
log = get_logger(name=__name__, category="inference")
log = get_logger(name=__name__, category="providers::utils")
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):

View file

@ -13,7 +13,7 @@ from llama_stack.providers.utils.kvstore import KVStore
from ..config import MongoDBKVStoreConfig
log = get_logger(name=__name__, category="kvstore")
log = get_logger(name=__name__, category="providers::utils")
class MongoDBKVStoreImpl(KVStore):

View file

@ -14,7 +14,7 @@ from llama_stack.log import get_logger
from ..api import KVStore
from ..config import PostgresKVStoreConfig
log = get_logger(name=__name__, category="kvstore")
log = get_logger(name=__name__, category="providers::utils")
class PostgresKVStoreImpl(KVStore):

View file

@ -44,7 +44,7 @@ from llama_stack.providers.utils.memory.vector_store import (
make_overlapped_chunks,
)
logger = get_logger(name=__name__, category="memory")
logger = get_logger(name=__name__, category="providers::utils")
# Constants for OpenAI vector stores
CHUNK_MULTIPLIER = 5

View file

@ -33,7 +33,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
log = get_logger(name=__name__, category="memory")
log = get_logger(name=__name__, category="providers::utils")
class ChunkForDeletion(BaseModel):

View file

@ -17,7 +17,7 @@ from pydantic import BaseModel
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="scheduler")
logger = get_logger(name=__name__, category="providers::utils")
# TODO: revisit the list of possible statuses when defining a more coherent

View file

@ -17,7 +17,7 @@ from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
from .sqlstore import SqlStoreType
logger = get_logger(name=__name__, category="authorized_sqlstore")
logger = get_logger(name=__name__, category="providers::utils")
# Hardcoded copy of the default policy that our SQL filtering implements
# WARNING: If default_policy() changes, this constant must be updated accordingly

View file

@ -30,7 +30,7 @@ from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, SqlStore
from .sqlstore import SqlAlchemySqlStoreConfig
logger = get_logger(name=__name__, category="sqlstore")
logger = get_logger(name=__name__, category="providers::utils")
TYPE_MAPPING: dict[ColumnType, Any] = {
ColumnType.INTEGER: Integer,

View file

@ -0,0 +1,587 @@
import React from "react";
import {
render,
screen,
fireEvent,
waitFor,
act,
} from "@testing-library/react";
import "@testing-library/jest-dom";
import ChatPlaygroundPage from "./page";
const mockClient = {
agents: {
list: jest.fn(),
create: jest.fn(),
retrieve: jest.fn(),
delete: jest.fn(),
session: {
list: jest.fn(),
create: jest.fn(),
delete: jest.fn(),
retrieve: jest.fn(),
},
turn: {
create: jest.fn(),
},
},
models: {
list: jest.fn(),
},
toolgroups: {
list: jest.fn(),
},
};
jest.mock("@/hooks/use-auth-client", () => ({
useAuthClient: jest.fn(() => mockClient),
}));
jest.mock("@/components/chat-playground/chat", () => ({
Chat: jest.fn(
({
className,
messages,
handleSubmit,
input,
handleInputChange,
isGenerating,
append,
suggestions,
}) => (
<div data-testid="chat-component" className={className}>
<div data-testid="messages-count">{messages.length}</div>
<input
data-testid="chat-input"
value={input}
onChange={handleInputChange}
disabled={isGenerating}
/>
<button data-testid="submit-button" onClick={handleSubmit}>
Submit
</button>
{suggestions?.map((suggestion: string, index: number) => (
<button
key={index}
data-testid={`suggestion-${index}`}
onClick={() => append({ role: "user", content: suggestion })}
>
{suggestion}
</button>
))}
</div>
)
),
}));
jest.mock("@/components/chat-playground/conversations", () => ({
SessionManager: jest.fn(({ selectedAgentId, onNewSession }) => (
<div data-testid="session-manager">
{selectedAgentId && (
<>
<div data-testid="selected-agent">{selectedAgentId}</div>
<button data-testid="new-session-button" onClick={onNewSession}>
New Session
</button>
</>
)}
</div>
)),
SessionUtils: {
saveCurrentSessionId: jest.fn(),
loadCurrentSessionId: jest.fn(),
loadCurrentAgentId: jest.fn(),
saveCurrentAgentId: jest.fn(),
clearCurrentSession: jest.fn(),
saveSessionData: jest.fn(),
loadSessionData: jest.fn(),
saveAgentConfig: jest.fn(),
loadAgentConfig: jest.fn(),
clearAgentCache: jest.fn(),
createDefaultSession: jest.fn(() => ({
id: "test-session-123",
name: "Default Session",
messages: [],
selectedModel: "",
systemMessage: "You are a helpful assistant.",
agentId: "test-agent-123",
createdAt: Date.now(),
updatedAt: Date.now(),
})),
},
}));
const mockAgents = [
{
agent_id: "agent_123",
agent_config: {
name: "Test Agent",
instructions: "You are a test assistant.",
},
},
{
agent_id: "agent_456",
agent_config: {
agent_name: "Another Agent",
instructions: "You are another assistant.",
},
},
];
const mockModels = [
{
identifier: "test-model-1",
model_type: "llm",
},
{
identifier: "test-model-2",
model_type: "llm",
},
];
const mockToolgroups = [
{
identifier: "builtin::rag",
provider_id: "test-provider",
type: "tool_group",
provider_resource_id: "test-resource",
},
];
describe("ChatPlaygroundPage", () => {
beforeEach(() => {
jest.clearAllMocks();
Element.prototype.scrollIntoView = jest.fn();
mockClient.agents.list.mockResolvedValue({ data: mockAgents });
mockClient.models.list.mockResolvedValue(mockModels);
mockClient.toolgroups.list.mockResolvedValue(mockToolgroups);
mockClient.agents.session.create.mockResolvedValue({
session_id: "new-session-123",
});
mockClient.agents.session.list.mockResolvedValue({ data: [] });
mockClient.agents.session.retrieve.mockResolvedValue({
session_id: "test-session",
session_name: "Test Session",
started_at: new Date().toISOString(),
turns: [],
}); // No turns by default
mockClient.agents.retrieve.mockResolvedValue({
agent_id: "test-agent",
agent_config: {
toolgroups: ["builtin::rag"],
instructions: "Test instructions",
model: "test-model",
},
});
mockClient.agents.delete.mockResolvedValue(undefined);
});
describe("Agent Selector Rendering", () => {
test("shows agent selector when agents are available", async () => {
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(screen.getByText("Agent Session:")).toBeInTheDocument();
expect(screen.getAllByRole("combobox")).toHaveLength(2);
expect(screen.getByText("+ New Agent")).toBeInTheDocument();
expect(screen.getByText("Clear Chat")).toBeInTheDocument();
});
});
test("does not show agent selector when no agents are available", async () => {
mockClient.agents.list.mockResolvedValue({ data: [] });
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(screen.queryByText("Agent Session:")).not.toBeInTheDocument();
expect(screen.getAllByRole("combobox")).toHaveLength(1);
expect(screen.getByText("+ New Agent")).toBeInTheDocument();
expect(screen.queryByText("Clear Chat")).not.toBeInTheDocument();
});
});
test("does not show agent selector while loading", async () => {
mockClient.agents.list.mockImplementation(() => new Promise(() => {}));
await act(async () => {
render(<ChatPlaygroundPage />);
});
expect(screen.queryByText("Agent Session:")).not.toBeInTheDocument();
expect(screen.getAllByRole("combobox")).toHaveLength(1);
expect(screen.getByText("+ New Agent")).toBeInTheDocument();
expect(screen.queryByText("Clear Chat")).not.toBeInTheDocument();
});
test("shows agent options in selector", async () => {
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
const agentCombobox = screen.getAllByRole("combobox").find(element => {
return (
element.textContent?.includes("Test Agent") ||
element.textContent?.includes("Select Agent")
);
});
expect(agentCombobox).toBeDefined();
fireEvent.click(agentCombobox!);
});
await waitFor(() => {
expect(screen.getAllByText("Test Agent")).toHaveLength(2);
expect(screen.getByText("Another Agent")).toBeInTheDocument();
});
});
test("displays agent ID when no name is available", async () => {
const agentWithoutName = {
agent_id: "agent_789",
agent_config: {
instructions: "You are an agent without a name.",
},
};
mockClient.agents.list.mockResolvedValue({ data: [agentWithoutName] });
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
const agentCombobox = screen.getAllByRole("combobox").find(element => {
return (
element.textContent?.includes("Agent agent_78") ||
element.textContent?.includes("Select Agent")
);
});
expect(agentCombobox).toBeDefined();
fireEvent.click(agentCombobox!);
});
await waitFor(() => {
expect(screen.getAllByText("Agent agent_78...")).toHaveLength(2);
});
});
});
describe("Agent Creation Modal", () => {
test("opens agent creation modal when + New Agent is clicked", async () => {
await act(async () => {
render(<ChatPlaygroundPage />);
});
const newAgentButton = screen.getByText("+ New Agent");
fireEvent.click(newAgentButton);
expect(screen.getByText("Create New Agent")).toBeInTheDocument();
expect(screen.getByText("Agent Name (optional)")).toBeInTheDocument();
expect(screen.getAllByText("Model")).toHaveLength(2);
expect(screen.getByText("System Instructions")).toBeInTheDocument();
expect(screen.getByText("Tools (optional)")).toBeInTheDocument();
});
test("closes modal when Cancel is clicked", async () => {
await act(async () => {
render(<ChatPlaygroundPage />);
});
const newAgentButton = screen.getByText("+ New Agent");
fireEvent.click(newAgentButton);
const cancelButton = screen.getByText("Cancel");
fireEvent.click(cancelButton);
expect(screen.queryByText("Create New Agent")).not.toBeInTheDocument();
});
test("creates agent when Create Agent is clicked", async () => {
mockClient.agents.create.mockResolvedValue({ agent_id: "new-agent-123" });
mockClient.agents.list
.mockResolvedValueOnce({ data: mockAgents })
.mockResolvedValueOnce({
data: [
...mockAgents,
{ agent_id: "new-agent-123", agent_config: { name: "New Agent" } },
],
});
await act(async () => {
render(<ChatPlaygroundPage />);
});
const newAgentButton = screen.getByText("+ New Agent");
await act(async () => {
fireEvent.click(newAgentButton);
});
await waitFor(() => {
expect(screen.getByText("Create New Agent")).toBeInTheDocument();
});
const nameInput = screen.getByPlaceholderText("My Custom Agent");
await act(async () => {
fireEvent.change(nameInput, { target: { value: "Test Agent Name" } });
});
const instructionsTextarea = screen.getByDisplayValue(
"You are a helpful assistant."
);
await act(async () => {
fireEvent.change(instructionsTextarea, {
target: { value: "Custom instructions" },
});
});
await waitFor(() => {
const modalModelSelectors = screen
.getAllByRole("combobox")
.filter(el => {
return (
el.textContent?.includes("Select Model") ||
el.closest('[class*="modal"]') ||
el.closest('[class*="card"]')
);
});
expect(modalModelSelectors.length).toBeGreaterThan(0);
});
const modalModelSelectors = screen.getAllByRole("combobox").filter(el => {
return (
el.textContent?.includes("Select Model") ||
el.closest('[class*="modal"]') ||
el.closest('[class*="card"]')
);
});
await act(async () => {
fireEvent.click(modalModelSelectors[0]);
});
await waitFor(() => {
const modelOptions = screen.getAllByText("test-model-1");
expect(modelOptions.length).toBeGreaterThan(0);
});
const modelOptions = screen.getAllByText("test-model-1");
const dropdownOption = modelOptions.find(
option =>
option.closest('[role="option"]') ||
option.id?.includes("radix") ||
option.getAttribute("aria-selected") !== null
);
await act(async () => {
fireEvent.click(
dropdownOption || modelOptions[modelOptions.length - 1]
);
});
await waitFor(() => {
const createButton = screen.getByText("Create Agent");
expect(createButton).not.toBeDisabled();
});
const createButton = screen.getByText("Create Agent");
await act(async () => {
fireEvent.click(createButton);
});
await waitFor(() => {
expect(mockClient.agents.create).toHaveBeenCalledWith({
agent_config: {
model: expect.any(String),
instructions: "Custom instructions",
name: "Test Agent Name",
enable_session_persistence: true,
},
});
});
await waitFor(() => {
expect(screen.queryByText("Create New Agent")).not.toBeInTheDocument();
});
});
});
describe("Agent Selection", () => {
test("creates default session when agent is selected", async () => {
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
// first agent should be auto-selected
expect(mockClient.agents.session.create).toHaveBeenCalledWith(
"agent_123",
{ session_name: "Default Session" }
);
});
});
test("switches agent when different agent is selected", async () => {
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
const agentCombobox = screen.getAllByRole("combobox").find(element => {
return (
element.textContent?.includes("Test Agent") ||
element.textContent?.includes("Select Agent")
);
});
expect(agentCombobox).toBeDefined();
fireEvent.click(agentCombobox!);
});
await waitFor(() => {
const anotherAgentOption = screen.getByText("Another Agent");
fireEvent.click(anotherAgentOption);
});
expect(mockClient.agents.session.create).toHaveBeenCalledWith(
"agent_456",
{ session_name: "Default Session" }
);
});
});
describe("Agent Deletion", () => {
test("shows delete button when multiple agents exist", async () => {
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(screen.getByTitle("Delete current agent")).toBeInTheDocument();
});
});
test("hides delete button when only one agent exists", async () => {
mockClient.agents.list.mockResolvedValue({
data: [mockAgents[0]],
});
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(
screen.queryByTitle("Delete current agent")
).not.toBeInTheDocument();
});
});
test("deletes agent and switches to another when confirmed", async () => {
global.confirm = jest.fn(() => true);
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(screen.getByTitle("Delete current agent")).toBeInTheDocument();
});
mockClient.agents.delete.mockResolvedValue(undefined);
mockClient.agents.list.mockResolvedValueOnce({ data: mockAgents });
mockClient.agents.list.mockResolvedValueOnce({
data: [mockAgents[1]],
});
const deleteButton = screen.getByTitle("Delete current agent");
await act(async () => {
deleteButton.click();
});
await waitFor(() => {
expect(mockClient.agents.delete).toHaveBeenCalledWith("agent_123");
expect(global.confirm).toHaveBeenCalledWith(
"Are you sure you want to delete this agent? This action cannot be undone and will delete all associated sessions."
);
});
(global.confirm as jest.Mock).mockRestore();
});
test("does not delete agent when cancelled", async () => {
global.confirm = jest.fn(() => false);
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(screen.getByTitle("Delete current agent")).toBeInTheDocument();
});
const deleteButton = screen.getByTitle("Delete current agent");
await act(async () => {
deleteButton.click();
});
await waitFor(() => {
expect(global.confirm).toHaveBeenCalled();
expect(mockClient.agents.delete).not.toHaveBeenCalled();
});
(global.confirm as jest.Mock).mockRestore();
});
});
describe("Error Handling", () => {
test("handles agent loading errors gracefully", async () => {
mockClient.agents.list.mockRejectedValue(
new Error("Failed to load agents")
);
const consoleSpy = jest
.spyOn(console, "error")
.mockImplementation(() => {});
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(consoleSpy).toHaveBeenCalledWith(
"Error fetching agents:",
expect.any(Error)
);
});
expect(screen.getByText("+ New Agent")).toBeInTheDocument();
consoleSpy.mockRestore();
});
test("handles model loading errors gracefully", async () => {
mockClient.models.list.mockRejectedValue(
new Error("Failed to load models")
);
const consoleSpy = jest
.spyOn(console, "error")
.mockImplementation(() => {});
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(consoleSpy).toHaveBeenCalledWith(
"Error fetching models:",
expect.any(Error)
);
});
consoleSpy.mockRestore();
});
});
});

File diff suppressed because it is too large Load diff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 25 KiB

View file

@ -120,3 +120,44 @@
@apply bg-background text-foreground;
}
}
@layer utilities {
.animate-typing-dot-1 {
animation: typing-dot-bounce-1 0.8s cubic-bezier(0.4, 0, 0.6, 1) infinite;
}
.animate-typing-dot-2 {
animation: typing-dot-bounce-2 0.8s cubic-bezier(0.4, 0, 0.6, 1) infinite;
}
.animate-typing-dot-3 {
animation: typing-dot-bounce-3 0.8s cubic-bezier(0.4, 0, 0.6, 1) infinite;
}
@keyframes typing-dot-bounce-1 {
0%, 15%, 85%, 100% {
transform: translateY(0);
}
7.5% {
transform: translateY(-6px);
}
}
@keyframes typing-dot-bounce-2 {
0%, 15%, 35%, 85%, 100% {
transform: translateY(0);
}
25% {
transform: translateY(-6px);
}
}
@keyframes typing-dot-bounce-3 {
0%, 35%, 55%, 85%, 100% {
transform: translateY(0);
}
45% {
transform: translateY(-6px);
}
}
}

View file

@ -18,6 +18,9 @@ const geistMono = Geist_Mono({
export const metadata: Metadata = {
title: "Llama Stack",
description: "Llama Stack UI",
icons: {
icon: "/favicon.ico",
},
};
import { SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";

View file

@ -161,10 +161,12 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
const isUser = role === "user";
const formattedTime = createdAt?.toLocaleTimeString("en-US", {
hour: "2-digit",
minute: "2-digit",
});
const formattedTime = createdAt
? new Date(createdAt).toLocaleTimeString("en-US", {
hour: "2-digit",
minute: "2-digit",
})
: undefined;
if (isUser) {
return (
@ -185,7 +187,7 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
{showTimeStamp && createdAt ? (
<time
dateTime={createdAt.toISOString()}
dateTime={new Date(createdAt).toISOString()}
className={cn(
"mt-1 block px-1 text-xs opacity-50",
animation !== "none" && "duration-500 animate-in fade-in-0"
@ -220,7 +222,7 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
{showTimeStamp && createdAt ? (
<time
dateTime={createdAt.toISOString()}
dateTime={new Date(createdAt).toISOString()}
className={cn(
"mt-1 block px-1 text-xs opacity-50",
animation !== "none" && "duration-500 animate-in fade-in-0"
@ -262,7 +264,7 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
{showTimeStamp && createdAt ? (
<time
dateTime={createdAt.toISOString()}
dateTime={new Date(createdAt).toISOString()}
className={cn(
"mt-1 block px-1 text-xs opacity-50",
animation !== "none" && "duration-500 animate-in fade-in-0"

View file

@ -0,0 +1,345 @@
import React from "react";
import { render, screen, waitFor, act } from "@testing-library/react";
import "@testing-library/jest-dom";
import { Conversations, SessionUtils } from "./conversations";
import type { Message } from "@/components/chat-playground/chat-message";
interface ChatSession {
id: string;
name: string;
messages: Message[];
selectedModel: string;
systemMessage: string;
agentId: string;
createdAt: number;
updatedAt: number;
}
const mockOnSessionChange = jest.fn();
const mockOnNewSession = jest.fn();
// Mock the auth client
const mockClient = {
agents: {
session: {
list: jest.fn(),
create: jest.fn(),
delete: jest.fn(),
retrieve: jest.fn(),
},
},
};
// Mock the useAuthClient hook
jest.mock("@/hooks/use-auth-client", () => ({
useAuthClient: jest.fn(() => mockClient),
}));
// Mock additional SessionUtils methods that are now being used
jest.mock("./conversations", () => {
const actual = jest.requireActual("./conversations");
return {
...actual,
SessionUtils: {
...actual.SessionUtils,
saveSessionData: jest.fn(),
loadSessionData: jest.fn(),
saveAgentConfig: jest.fn(),
loadAgentConfig: jest.fn(),
clearAgentCache: jest.fn(),
},
};
});
const localStorageMock = {
getItem: jest.fn(),
setItem: jest.fn(),
removeItem: jest.fn(),
clear: jest.fn(),
};
Object.defineProperty(window, "localStorage", {
value: localStorageMock,
writable: true,
});
// Mock crypto.randomUUID for test environment
let uuidCounter = 0;
Object.defineProperty(globalThis, "crypto", {
value: {
randomUUID: jest.fn(() => `test-uuid-${++uuidCounter}`),
},
writable: true,
});
describe("SessionManager", () => {
const mockSession: ChatSession = {
id: "session_123",
name: "Test Session",
messages: [
{
id: "msg_1",
role: "user",
content: "Hello",
createdAt: new Date(),
},
],
selectedModel: "test-model",
systemMessage: "You are a helpful assistant.",
agentId: "agent_123",
createdAt: 1710000000,
updatedAt: 1710001000,
};
const mockAgentSessions = [
{
session_id: "session_123",
session_name: "Test Session",
started_at: "2024-01-01T00:00:00Z",
turns: [],
},
{
session_id: "session_456",
session_name: "Another Session",
started_at: "2024-01-01T01:00:00Z",
turns: [],
},
];
beforeEach(() => {
jest.clearAllMocks();
localStorageMock.getItem.mockReturnValue(null);
localStorageMock.setItem.mockImplementation(() => {});
mockClient.agents.session.list.mockResolvedValue({
data: mockAgentSessions,
});
mockClient.agents.session.create.mockResolvedValue({
session_id: "new_session_123",
});
mockClient.agents.session.delete.mockResolvedValue(undefined);
mockClient.agents.session.retrieve.mockResolvedValue({
session_id: "test-session",
session_name: "Test Session",
started_at: new Date().toISOString(),
turns: [],
});
uuidCounter = 0; // Reset UUID counter for consistent test behavior
});
describe("Component Rendering", () => {
test("does not render when no agent is selected", async () => {
const { container } = await act(async () => {
return render(
<Conversations
selectedAgentId=""
currentSession={null}
onSessionChange={mockOnSessionChange}
onNewSession={mockOnNewSession}
/>
);
});
expect(container.firstChild).toBeNull();
});
test("renders loading state initially", async () => {
mockClient.agents.session.list.mockImplementation(
() => new Promise(() => {}) // Never resolves to simulate loading
);
await act(async () => {
render(
<Conversations
selectedAgentId="agent_123"
currentSession={null}
onSessionChange={mockOnSessionChange}
onNewSession={mockOnNewSession}
/>
);
});
expect(screen.getByText("Select Session")).toBeInTheDocument();
// When loading, the "+ New" button should be disabled
expect(screen.getByText("+ New")).toBeDisabled();
});
test("renders session selector when agent sessions are loaded", async () => {
await act(async () => {
render(
<Conversations
selectedAgentId="agent_123"
currentSession={null}
onSessionChange={mockOnSessionChange}
onNewSession={mockOnNewSession}
/>
);
});
await waitFor(() => {
expect(screen.getByText("Select Session")).toBeInTheDocument();
});
});
test("renders current session name when session is selected", async () => {
await act(async () => {
render(
<Conversations
selectedAgentId="agent_123"
currentSession={mockSession}
onSessionChange={mockOnSessionChange}
onNewSession={mockOnNewSession}
/>
);
});
await waitFor(() => {
expect(screen.getByText("Test Session")).toBeInTheDocument();
});
});
});
describe("Agent API Integration", () => {
test("loads sessions from agent API on mount", async () => {
await act(async () => {
render(
<Conversations
selectedAgentId="agent_123"
currentSession={mockSession}
onSessionChange={mockOnSessionChange}
onNewSession={mockOnNewSession}
/>
);
});
await waitFor(() => {
expect(mockClient.agents.session.list).toHaveBeenCalledWith(
"agent_123"
);
});
});
test("handles API errors gracefully", async () => {
mockClient.agents.session.list.mockRejectedValue(new Error("API Error"));
const consoleSpy = jest
.spyOn(console, "error")
.mockImplementation(() => {});
await act(async () => {
render(
<Conversations
selectedAgentId="agent_123"
currentSession={mockSession}
onSessionChange={mockOnSessionChange}
onNewSession={mockOnNewSession}
/>
);
});
await waitFor(() => {
expect(consoleSpy).toHaveBeenCalledWith(
"Error loading agent sessions:",
expect.any(Error)
);
});
consoleSpy.mockRestore();
});
});
describe("Error Handling", () => {
test("component renders without crashing when API is unavailable", async () => {
mockClient.agents.session.list.mockRejectedValue(
new Error("Network Error")
);
const consoleSpy = jest
.spyOn(console, "error")
.mockImplementation(() => {});
await act(async () => {
render(
<Conversations
selectedAgentId="agent_123"
currentSession={mockSession}
onSessionChange={mockOnSessionChange}
onNewSession={mockOnNewSession}
/>
);
});
// Should still render the session manager with the select trigger
expect(screen.getByRole("combobox")).toBeInTheDocument();
expect(screen.getByText("+ New")).toBeInTheDocument();
consoleSpy.mockRestore();
});
});
});
describe("SessionUtils", () => {
beforeEach(() => {
jest.clearAllMocks();
localStorageMock.getItem.mockReturnValue(null);
localStorageMock.setItem.mockImplementation(() => {});
});
describe("saveCurrentSessionId", () => {
test("saves session ID to localStorage", () => {
SessionUtils.saveCurrentSessionId("test-session-id");
expect(localStorageMock.setItem).toHaveBeenCalledWith(
"chat-playground-current-session",
"test-session-id"
);
});
});
describe("createDefaultSession", () => {
test("creates default session with agent ID", () => {
const result = SessionUtils.createDefaultSession("agent_123");
expect(result).toEqual(
expect.objectContaining({
name: "Default Session",
messages: [],
selectedModel: "",
systemMessage: "You are a helpful assistant.",
agentId: "agent_123",
})
);
expect(result.id).toBeTruthy();
expect(result.createdAt).toBeTruthy();
expect(result.updatedAt).toBeTruthy();
});
test("creates default session with inherited model", () => {
const result = SessionUtils.createDefaultSession(
"agent_123",
"inherited-model"
);
expect(result.selectedModel).toBe("inherited-model");
expect(result.agentId).toBe("agent_123");
});
test("creates unique session IDs", () => {
const originalNow = Date.now;
let mockTime = 1710005000;
Date.now = jest.fn(() => ++mockTime);
const session1 = SessionUtils.createDefaultSession("agent_123");
const session2 = SessionUtils.createDefaultSession("agent_123");
expect(session1.id).not.toBe(session2.id);
Date.now = originalNow;
});
test("sets creation and update timestamps", () => {
const result = SessionUtils.createDefaultSession("agent_123");
expect(result.createdAt).toBeTruthy();
expect(result.updatedAt).toBeTruthy();
expect(typeof result.createdAt).toBe("number");
expect(typeof result.updatedAt).toBe("number");
});
});
});

View file

@ -0,0 +1,568 @@
"use client";
import { useState, useEffect, useCallback } from "react";
import { Button } from "@/components/ui/button";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { Input } from "@/components/ui/input";
import { Card } from "@/components/ui/card";
import { Trash2 } from "lucide-react";
import type { Message } from "@/components/chat-playground/chat-message";
import { useAuthClient } from "@/hooks/use-auth-client";
import type {
Session,
SessionCreateParams,
} from "llama-stack-client/resources/agents";
export interface ChatSession {
id: string;
name: string;
messages: Message[];
selectedModel: string;
systemMessage: string;
agentId: string;
session?: Session;
createdAt: number;
updatedAt: number;
}
interface SessionManagerProps {
currentSession: ChatSession | null;
onSessionChange: (session: ChatSession) => void;
onNewSession: () => void;
selectedAgentId: string;
}
const CURRENT_SESSION_KEY = "chat-playground-current-session";
// ensures this only happens client side
const safeLocalStorage = {
getItem: (key: string): string | null => {
if (typeof window === "undefined") return null;
try {
return localStorage.getItem(key);
} catch (err) {
console.error("Error accessing localStorage:", err);
return null;
}
},
setItem: (key: string, value: string): void => {
if (typeof window === "undefined") return;
try {
localStorage.setItem(key, value);
} catch (err) {
console.error("Error writing to localStorage:", err);
}
},
removeItem: (key: string): void => {
if (typeof window === "undefined") return;
try {
localStorage.removeItem(key);
} catch (err) {
console.error("Error removing from localStorage:", err);
}
},
};
const generateSessionId = (): string => {
return globalThis.crypto.randomUUID();
};
export function Conversations({
currentSession,
onSessionChange,
selectedAgentId,
}: SessionManagerProps) {
const [sessions, setSessions] = useState<ChatSession[]>([]);
const [showCreateForm, setShowCreateForm] = useState(false);
const [newSessionName, setNewSessionName] = useState("");
const [loading, setLoading] = useState(false);
const client = useAuthClient();
const loadAgentSessions = useCallback(async () => {
if (!selectedAgentId) return;
setLoading(true);
try {
const response = await client.agents.session.list(selectedAgentId);
console.log("Sessions response:", response);
if (!response.data || !Array.isArray(response.data)) {
console.warn("Invalid sessions response, starting fresh");
setSessions([]);
return;
}
const agentSessions: ChatSession[] = response.data
.filter(sessionData => {
const isValid =
sessionData &&
typeof sessionData === "object" &&
sessionData.session_id &&
sessionData.session_name;
if (!isValid) {
console.warn("Filtering out invalid session:", sessionData);
}
return isValid;
})
.map(sessionData => ({
id: sessionData.session_id,
name: sessionData.session_name,
messages: [],
selectedModel: currentSession?.selectedModel || "",
systemMessage:
currentSession?.systemMessage || "You are a helpful assistant.",
agentId: selectedAgentId,
session: sessionData,
createdAt: sessionData.started_at
? new Date(sessionData.started_at).getTime()
: Date.now(),
updatedAt: sessionData.started_at
? new Date(sessionData.started_at).getTime()
: Date.now(),
}));
setSessions(agentSessions);
} catch (error) {
console.error("Error loading agent sessions:", error);
setSessions([]);
} finally {
setLoading(false);
}
}, [
selectedAgentId,
client,
currentSession?.selectedModel,
currentSession?.systemMessage,
]);
useEffect(() => {
if (selectedAgentId) {
loadAgentSessions();
}
}, [selectedAgentId, loadAgentSessions]);
const createNewSession = async () => {
if (!selectedAgentId) return;
const sessionName =
newSessionName.trim() || `Session ${sessions.length + 1}`;
setLoading(true);
try {
const response = await client.agents.session.create(selectedAgentId, {
session_name: sessionName,
} as SessionCreateParams);
const newSession: ChatSession = {
id: response.session_id,
name: sessionName,
messages: [],
selectedModel: currentSession?.selectedModel || "",
systemMessage:
currentSession?.systemMessage || "You are a helpful assistant.",
agentId: selectedAgentId,
createdAt: Date.now(),
updatedAt: Date.now(),
};
setSessions(prev => [...prev, newSession]);
SessionUtils.saveCurrentSessionId(newSession.id, selectedAgentId);
onSessionChange(newSession);
setNewSessionName("");
setShowCreateForm(false);
} catch (error) {
console.error("Error creating session:", error);
} finally {
setLoading(false);
}
};
const loadSessionMessages = useCallback(
async (agentId: string, sessionId: string): Promise<Message[]> => {
try {
const session = await client.agents.session.retrieve(
agentId,
sessionId
);
if (!session || !session.turns || !Array.isArray(session.turns)) {
return [];
}
const messages: Message[] = [];
for (const turn of session.turns) {
// Add user messages from input_messages
if (turn.input_messages && Array.isArray(turn.input_messages)) {
for (const input of turn.input_messages) {
if (input.role === "user" && input.content) {
messages.push({
id: `${turn.turn_id}-user-${messages.length}`,
role: "user",
content:
typeof input.content === "string"
? input.content
: JSON.stringify(input.content),
createdAt: new Date(turn.started_at || Date.now()),
});
}
}
}
// Add assistant message from output_message
if (turn.output_message && turn.output_message.content) {
messages.push({
id: `${turn.turn_id}-assistant-${messages.length}`,
role: "assistant",
content:
typeof turn.output_message.content === "string"
? turn.output_message.content
: JSON.stringify(turn.output_message.content),
createdAt: new Date(
turn.completed_at || turn.started_at || Date.now()
),
});
}
}
return messages;
} catch (error) {
console.error("Error loading session messages:", error);
return [];
}
},
[client]
);
const switchToSession = useCallback(
async (sessionId: string) => {
const session = sessions.find(s => s.id === sessionId);
if (session) {
setLoading(true);
try {
// Load messages for this session
const messages = await loadSessionMessages(
selectedAgentId,
sessionId
);
const sessionWithMessages = {
...session,
messages,
};
SessionUtils.saveCurrentSessionId(sessionId, selectedAgentId);
onSessionChange(sessionWithMessages);
} catch (error) {
console.error("Error switching to session:", error);
// Fallback to session without messages
SessionUtils.saveCurrentSessionId(sessionId, selectedAgentId);
onSessionChange(session);
} finally {
setLoading(false);
}
}
},
[sessions, selectedAgentId, loadSessionMessages, onSessionChange]
);
const deleteSession = async (sessionId: string) => {
if (sessions.length <= 1 || !selectedAgentId) {
return;
}
if (
confirm(
"Are you sure you want to delete this session? This action cannot be undone."
)
) {
setLoading(true);
try {
await client.agents.session.delete(selectedAgentId, sessionId);
const updatedSessions = sessions.filter(s => s.id !== sessionId);
setSessions(updatedSessions);
if (currentSession?.id === sessionId) {
const newCurrentSession = updatedSessions[0] || null;
if (newCurrentSession) {
SessionUtils.saveCurrentSessionId(
newCurrentSession.id,
selectedAgentId
);
onSessionChange(newCurrentSession);
} else {
SessionUtils.clearCurrentSession(selectedAgentId);
onNewSession();
}
}
} catch (error) {
console.error("Error deleting session:", error);
} finally {
setLoading(false);
}
}
};
useEffect(() => {
if (currentSession) {
setSessions(prevSessions => {
const updatedSessions = prevSessions.map(session =>
session.id === currentSession.id ? currentSession : session
);
if (!prevSessions.find(s => s.id === currentSession.id)) {
updatedSessions.push(currentSession);
}
return updatedSessions;
});
}
}, [currentSession]);
// Don't render if no agent is selected
if (!selectedAgentId) {
return null;
}
return (
<div className="relative">
<div className="flex items-center gap-2">
<Select
value={currentSession?.id || ""}
onValueChange={switchToSession}
>
<SelectTrigger className="w-[200px]">
<SelectValue placeholder="Select Session" />
</SelectTrigger>
<SelectContent>
{sessions.map(session => (
<SelectItem key={session.id} value={session.id}>
{session.name}
</SelectItem>
))}
</SelectContent>
</Select>
<Button
onClick={() => setShowCreateForm(true)}
variant="outline"
size="sm"
disabled={loading || !selectedAgentId}
>
+ New
</Button>
{currentSession && sessions.length > 1 && (
<Button
onClick={() => deleteSession(currentSession.id)}
variant="outline"
size="sm"
className="text-destructive hover:text-destructive hover:bg-destructive/10"
title="Delete current session"
>
<Trash2 className="h-3 w-3" />
</Button>
)}
</div>
{showCreateForm && (
<Card className="absolute top-full left-0 mt-2 p-4 space-y-3 w-80 z-50 bg-background border shadow-lg">
<h3 className="text-md font-semibold">Create New Session</h3>
<Input
value={newSessionName}
onChange={e => setNewSessionName(e.target.value)}
placeholder="Session name (optional)"
onKeyDown={e => {
if (e.key === "Enter") {
createNewSession();
} else if (e.key === "Escape") {
setShowCreateForm(false);
setNewSessionName("");
}
}}
/>
<div className="flex gap-2">
<Button
onClick={createNewSession}
className="flex-1"
disabled={loading}
>
{loading ? "Creating..." : "Create"}
</Button>
<Button
variant="outline"
onClick={() => {
setShowCreateForm(false);
setNewSessionName("");
}}
className="flex-1"
>
Cancel
</Button>
</div>
</Card>
)}
{currentSession && sessions.length > 1 && (
<div className="absolute top-full left-0 mt-1 text-xs text-gray-500 whitespace-nowrap">
{sessions.length} sessions Current: {currentSession.name}
{currentSession.messages.length > 0 &&
`${currentSession.messages.length} messages`}
</div>
)}
</div>
);
}
export const SessionUtils = {
loadCurrentSessionId: (agentId?: string): string | null => {
const key = agentId
? `${CURRENT_SESSION_KEY}-${agentId}`
: CURRENT_SESSION_KEY;
return safeLocalStorage.getItem(key);
},
saveCurrentSessionId: (sessionId: string, agentId?: string) => {
const key = agentId
? `${CURRENT_SESSION_KEY}-${agentId}`
: CURRENT_SESSION_KEY;
safeLocalStorage.setItem(key, sessionId);
},
createDefaultSession: (
agentId: string,
inheritModel?: string
): ChatSession => ({
id: generateSessionId(),
name: "Default Session",
messages: [],
selectedModel: inheritModel || "",
systemMessage: "You are a helpful assistant.",
agentId,
createdAt: Date.now(),
updatedAt: Date.now(),
}),
clearCurrentSession: (agentId?: string) => {
const key = agentId
? `${CURRENT_SESSION_KEY}-${agentId}`
: CURRENT_SESSION_KEY;
safeLocalStorage.removeItem(key);
},
loadCurrentAgentId: (): string | null => {
return safeLocalStorage.getItem("chat-playground-current-agent");
},
saveCurrentAgentId: (agentId: string) => {
safeLocalStorage.setItem("chat-playground-current-agent", agentId);
},
// Comprehensive session caching
saveSessionData: (agentId: string, sessionData: ChatSession) => {
const key = `chat-playground-session-data-${agentId}-${sessionData.id}`;
safeLocalStorage.setItem(
key,
JSON.stringify({
...sessionData,
cachedAt: Date.now(),
})
);
},
loadSessionData: (agentId: string, sessionId: string): ChatSession | null => {
const key = `chat-playground-session-data-${agentId}-${sessionId}`;
const cached = safeLocalStorage.getItem(key);
if (!cached) return null;
try {
const data = JSON.parse(cached);
// Check if cache is fresh (less than 1 hour old)
const cacheAge = Date.now() - (data.cachedAt || 0);
if (cacheAge > 60 * 60 * 1000) {
safeLocalStorage.removeItem(key);
return null;
}
// Convert date strings back to Date objects
return {
...data,
messages: data.messages.map(
(msg: { createdAt: string; [key: string]: unknown }) => ({
...msg,
createdAt: new Date(msg.createdAt),
})
),
};
} catch (error) {
console.error("Error parsing cached session data:", error);
safeLocalStorage.removeItem(key);
return null;
}
},
// Agent config caching
saveAgentConfig: (
agentId: string,
config: {
toolgroups?: Array<
string | { name: string; args: Record<string, unknown> }
>;
[key: string]: unknown;
}
) => {
const key = `chat-playground-agent-config-${agentId}`;
safeLocalStorage.setItem(
key,
JSON.stringify({
config,
cachedAt: Date.now(),
})
);
},
loadAgentConfig: (
agentId: string
): {
toolgroups?: Array<
string | { name: string; args: Record<string, unknown> }
>;
[key: string]: unknown;
} | null => {
const key = `chat-playground-agent-config-${agentId}`;
const cached = safeLocalStorage.getItem(key);
if (!cached) return null;
try {
const data = JSON.parse(cached);
// Check if cache is fresh (less than 30 minutes old)
const cacheAge = Date.now() - (data.cachedAt || 0);
if (cacheAge > 30 * 60 * 1000) {
safeLocalStorage.removeItem(key);
return null;
}
return data.config;
} catch (error) {
console.error("Error parsing cached agent config:", error);
safeLocalStorage.removeItem(key);
return null;
}
},
// Clear all cached data for an agent
clearAgentCache: (agentId: string) => {
const keys = Object.keys(localStorage).filter(
key =>
key.includes(`chat-playground-session-data-${agentId}`) ||
key.includes(`chat-playground-agent-config-${agentId}`)
);
keys.forEach(key => safeLocalStorage.removeItem(key));
},
};

View file

@ -5,9 +5,9 @@ export function TypingIndicator() {
<div className="justify-left flex space-x-1">
<div className="rounded-lg bg-muted p-3">
<div className="flex -space-x-2.5">
<Dot className="h-5 w-5 animate-typing-dot-bounce" />
<Dot className="h-5 w-5 animate-typing-dot-bounce [animation-delay:90ms]" />
<Dot className="h-5 w-5 animate-typing-dot-bounce [animation-delay:180ms]" />
<Dot className="h-5 w-5 animate-typing-dot-1" />
<Dot className="h-5 w-5 animate-typing-dot-2" />
<Dot className="h-5 w-5 animate-typing-dot-3" />
</div>
</div>
</div>

View file

@ -11,6 +11,7 @@ import {
} from "lucide-react";
import Link from "next/link";
import { usePathname } from "next/navigation";
import Image from "next/image";
import { cn } from "@/lib/utils";
import {
@ -110,7 +111,16 @@ export function AppSidebar() {
return (
<Sidebar>
<SidebarHeader>
<Link href="/">Llama Stack</Link>
<Link href="/" className="flex items-center gap-2 p-2">
<Image
src="/logo.webp"
alt="Llama Stack"
width={32}
height={32}
className="h-8 w-8"
/>
<span className="font-semibold text-lg">Llama Stack</span>
</Link>
</SidebarHeader>
<SidebarContent>
<SidebarGroup>

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 19 KiB

View file

@ -98,6 +98,7 @@ unit = [
"together",
"coverage",
"chromadb>=1.0.15",
"moto[s3]>=5.1.10",
]
# These are the core dependencies required for running integration tests. They are shared across all
# providers. If a provider requires additional dependencies, please add them to your environment

View file

@ -157,12 +157,14 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
}
def generate_provider_docs(provider_spec: Any, api_name: str) -> str:
def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str:
"""Generate markdown documentation for a provider."""
provider_type = provider_spec.provider_type
config_class = provider_spec.config_class
config_info = get_config_class_info(config_class)
if "error" in config_info:
progress.print(config_info["error"])
md_lines = []
md_lines.append(f"# {provider_type}")
@ -295,7 +297,7 @@ def process_provider_registry(progress, change_tracker: ChangedPathTracker) -> N
filename = provider_type.replace("::", "_").replace(":", "_")
provider_doc_file = doc_output_dir / f"{filename}.md"
provider_docs = generate_provider_docs(provider, api_name)
provider_docs = generate_provider_docs(progress, provider, api_name)
provider_doc_file.write_text(provider_docs)
change_tracker.add_paths(provider_doc_file)

17
scripts/run-ui-linter.sh Executable file
View file

@ -0,0 +1,17 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
set -e
cd llama_stack/ui
if [ ! -d node_modules ] || [ ! -x node_modules/.bin/prettier ] || [ ! -x node_modules/.bin/eslint ]; then
echo "UI dependencies not installed, skipping prettier/linter check"
exit 0
fi
npm run format
npm run lint

View file

@ -0,0 +1,251 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import patch
import boto3
import pytest
from botocore.exceptions import ClientError
from moto import mock_aws
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.files import OpenAIFilePurpose
from llama_stack.providers.remote.files.s3 import (
S3FilesImplConfig,
get_adapter_impl,
)
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
class MockUploadFile:
def __init__(self, content: bytes, filename: str, content_type: str = "text/plain"):
self.content = content
self.filename = filename
self.content_type = content_type
async def read(self):
return self.content
@pytest.fixture
def s3_config(tmp_path):
db_path = tmp_path / "s3_files_metadata.db"
return S3FilesImplConfig(
bucket_name="test-bucket",
region="not-a-region",
auto_create_bucket=True,
metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
)
@pytest.fixture
def s3_client():
"""Create a mocked S3 client for testing."""
# we use `with mock_aws()` because @mock_aws decorator does not support being a generator
with mock_aws():
# must yield or the mock will be reset before it is used
yield boto3.client("s3")
@pytest.fixture
async def s3_provider(s3_config, s3_client):
"""Create an S3 files provider with mocked S3 for testing."""
provider = await get_adapter_impl(s3_config, {})
yield provider
await provider.shutdown()
@pytest.fixture
def sample_text_file():
content = b"Hello, this is a test file for the S3 Files API!"
return MockUploadFile(content, "sample_text_file.txt")
class TestS3FilesImpl:
"""Test suite for S3 Files implementation."""
async def test_upload_file(self, s3_provider, sample_text_file, s3_client, s3_config):
"""Test successful file upload."""
sample_text_file.filename = "test_upload_file"
result = await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
assert result.filename == sample_text_file.filename
assert result.purpose == OpenAIFilePurpose.ASSISTANTS
assert result.bytes == len(sample_text_file.content)
assert result.id.startswith("file-")
# Verify file exists in S3 backend
response = s3_client.head_object(Bucket=s3_config.bucket_name, Key=result.id)
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
async def test_list_files_empty(self, s3_provider):
"""Test listing files when no files exist."""
result = await s3_provider.openai_list_files()
assert len(result.data) == 0
assert not result.has_more
assert result.first_id == ""
assert result.last_id == ""
async def test_retrieve_file(self, s3_provider, sample_text_file):
"""Test retrieving file metadata."""
sample_text_file.filename = "test_retrieve_file"
uploaded = await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
retrieved = await s3_provider.openai_retrieve_file(uploaded.id)
assert retrieved.id == uploaded.id
assert retrieved.filename == uploaded.filename
assert retrieved.purpose == uploaded.purpose
assert retrieved.bytes == uploaded.bytes
async def test_retrieve_file_content(self, s3_provider, sample_text_file):
"""Test retrieving file content."""
sample_text_file.filename = "test_retrieve_file_content"
uploaded = await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
response = await s3_provider.openai_retrieve_file_content(uploaded.id)
assert response.body == sample_text_file.content
assert response.headers["Content-Disposition"] == f'attachment; filename="{sample_text_file.filename}"'
async def test_delete_file(self, s3_provider, sample_text_file, s3_config, s3_client):
"""Test deleting a file."""
sample_text_file.filename = "test_delete_file"
uploaded = await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
delete_response = await s3_provider.openai_delete_file(uploaded.id)
assert delete_response.id == uploaded.id
assert delete_response.deleted is True
with pytest.raises(ResourceNotFoundError, match="not found"):
await s3_provider.openai_retrieve_file(uploaded.id)
# Verify file is gone from S3 backend
with pytest.raises(ClientError) as exc_info:
s3_client.head_object(Bucket=s3_config.bucket_name, Key=uploaded.id)
assert exc_info.value.response["Error"]["Code"] == "404"
async def test_list_files(self, s3_provider, sample_text_file):
"""Test listing files after uploading some."""
sample_text_file.filename = "test_list_files_with_content_file1"
file1 = await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
file2_content = MockUploadFile(b"Second file content", "test_list_files_with_content_file2")
file2 = await s3_provider.openai_upload_file(
file=file2_content,
purpose=OpenAIFilePurpose.BATCH,
)
result = await s3_provider.openai_list_files()
assert len(result.data) == 2
file_ids = {f.id for f in result.data}
assert file1.id in file_ids
assert file2.id in file_ids
async def test_list_files_with_purpose_filter(self, s3_provider, sample_text_file):
"""Test listing files with purpose filter."""
sample_text_file.filename = "test_list_files_with_purpose_filter_file1"
file1 = await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
file2_content = MockUploadFile(b"Batch file content", "test_list_files_with_purpose_filter_file2")
await s3_provider.openai_upload_file(
file=file2_content,
purpose=OpenAIFilePurpose.BATCH,
)
result = await s3_provider.openai_list_files(purpose=OpenAIFilePurpose.ASSISTANTS)
assert len(result.data) == 1
assert result.data[0].id == file1.id
assert result.data[0].purpose == OpenAIFilePurpose.ASSISTANTS
async def test_nonexistent_file_retrieval(self, s3_provider):
"""Test retrieving a non-existent file raises error."""
with pytest.raises(ResourceNotFoundError, match="not found"):
await s3_provider.openai_retrieve_file("file-nonexistent")
async def test_nonexistent_file_content_retrieval(self, s3_provider):
"""Test retrieving content of a non-existent file raises error."""
with pytest.raises(ResourceNotFoundError, match="not found"):
await s3_provider.openai_retrieve_file_content("file-nonexistent")
async def test_nonexistent_file_deletion(self, s3_provider):
"""Test deleting a non-existent file raises error."""
with pytest.raises(ResourceNotFoundError, match="not found"):
await s3_provider.openai_delete_file("file-nonexistent")
async def test_upload_file_without_filename(self, s3_provider, sample_text_file):
"""Test uploading a file without a filename uses the fallback."""
del sample_text_file.filename
result = await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
assert result.purpose == OpenAIFilePurpose.ASSISTANTS
assert result.bytes == len(sample_text_file.content)
retrieved = await s3_provider.openai_retrieve_file(result.id)
assert retrieved.filename == result.filename
async def test_file_operations_when_s3_object_deleted(self, s3_provider, sample_text_file, s3_config, s3_client):
"""Test file operations when S3 object is deleted but metadata exists (negative test)."""
sample_text_file.filename = "test_orphaned_metadata"
uploaded = await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
# Directly delete the S3 object from the backend
s3_client.delete_object(Bucket=s3_config.bucket_name, Key=uploaded.id)
with pytest.raises(ResourceNotFoundError, match="not found") as exc_info:
await s3_provider.openai_retrieve_file_content(uploaded.id)
assert uploaded.id in str(exc_info).lower()
listed_files = await s3_provider.openai_list_files()
assert uploaded.id not in [file.id for file in listed_files.data]
async def test_upload_file_s3_put_object_failure(self, s3_provider, sample_text_file, s3_config, s3_client):
"""Test that put_object failure results in exception and no orphaned metadata."""
sample_text_file.filename = "test_s3_put_object_failure"
def failing_put_object(*args, **kwargs):
raise ClientError(
error_response={"Error": {"Code": "SolarRadiation", "Message": "Bloop"}}, operation_name="PutObject"
)
with patch.object(s3_provider.client, "put_object", side_effect=failing_put_object):
with pytest.raises(RuntimeError, match="Failed to upload file to S3"):
await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
files_list = await s3_provider.openai_list_files()
assert len(files_list.data) == 0, "No file metadata should remain after failed upload"

109
uv.lock generated
View file

@ -347,6 +347,34 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ed/4d/1392562369b1139e741b30d624f09fe7091d17dd5579fae5732f044b12bb/blobfile-3.0.0-py3-none-any.whl", hash = "sha256:48ecc3307e622804bd8fe13bf6f40e6463c4439eba7a1f9ad49fd78aa63cc658", size = 75413, upload-time = "2024-08-27T00:02:51.518Z" },
]
[[package]]
name = "boto3"
version = "1.40.12"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "botocore" },
{ name = "jmespath" },
{ name = "s3transfer" },
]
sdist = { url = "https://files.pythonhosted.org/packages/41/19/2c4d140a7f99b5903b21b9ccd7253c71f147c346c3c632b2117444cf2d65/boto3-1.40.12.tar.gz", hash = "sha256:c6b32aee193fbd2eb84696d2b5b2410dcda9fb4a385e1926cff908377d222247", size = 111959, upload-time = "2025-08-18T19:30:23.827Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/16/6e/5a9dcf38ad87838fb99742c4a3ab1b7507ad3a02c8c27a9ccda7a0bb5709/boto3-1.40.12-py3-none-any.whl", hash = "sha256:3c3d6731390b5b11f5e489d5d9daa57f0c3e171efb63ac8f47203df9c71812b3", size = 140075, upload-time = "2025-08-18T19:30:22.494Z" },
]
[[package]]
name = "botocore"
version = "1.40.12"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jmespath" },
{ name = "python-dateutil" },
{ name = "urllib3" },
]
sdist = { url = "https://files.pythonhosted.org/packages/7d/b2/7933590fc5bca1980801b71e09db1a95581afff177cbf3c8a031d922885c/botocore-1.40.12.tar.gz", hash = "sha256:c6560578e799b47b762b7e555bd9c5dd5c29c5d23bd778a8a72e98c979b3c727", size = 14349930, upload-time = "2025-08-18T19:30:13.794Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1e/b6/65fd6e718c9538ba1462c9b71e9262bc723202ff203fe64ff66ff676d823/botocore-1.40.12-py3-none-any.whl", hash = "sha256:84e96004a8b426c5508f6b5600312d6271364269466a3a957dc377ad8effc438", size = 14018004, upload-time = "2025-08-18T19:30:09.054Z" },
]
[[package]]
name = "braintrust-core"
version = "0.0.59"
@ -1580,6 +1608,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b3/4a/4175a563579e884192ba6e81725fc0448b042024419be8d83aa8a80a3f44/jiter-0.10.0-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3aa96f2abba33dc77f79b4cf791840230375f9534e5fac927ccceb58c5e604a5", size = 354213, upload-time = "2025-05-18T19:04:41.894Z" },
]
[[package]]
name = "jmespath"
version = "1.0.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/00/2a/e867e8531cf3e36b41201936b7fa7ba7b5702dbef42922193f05c8976cd6/jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe", size = 25843, upload-time = "2022-06-17T18:00:12.224Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/31/b4/b9b800c45527aadd64d5b442f9b932b00648617eb5d63d2c7a6587b7cafc/jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980", size = 20256, upload-time = "2022-06-17T18:00:10.251Z" },
]
[[package]]
name = "jsonschema"
version = "4.25.0"
@ -1820,6 +1857,7 @@ unit = [
{ name = "litellm" },
{ name = "mcp" },
{ name = "milvus-lite" },
{ name = "moto", extra = ["s3"] },
{ name = "ollama" },
{ name = "openai" },
{ name = "pymilvus" },
@ -1937,6 +1975,7 @@ unit = [
{ name = "litellm" },
{ name = "mcp" },
{ name = "milvus-lite", specifier = ">=2.5.0" },
{ name = "moto", extras = ["s3"], specifier = ">=5.1.10" },
{ name = "ollama" },
{ name = "openai" },
{ name = "pymilvus", specifier = ">=2.5.12" },
@ -2224,6 +2263,32 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/16/71/4ad9a42f2772793a03cb698f0fc42499f04e6e8d2560ba2f7da0fb059a8e/mmh3-5.1.0-cp313-cp313-win_arm64.whl", hash = "sha256:b22fe2e54be81f6c07dcb36b96fa250fb72effe08aa52fbb83eade6e1e2d5fd7", size = 38890, upload-time = "2025-01-25T08:39:25.28Z" },
]
[[package]]
name = "moto"
version = "5.1.10"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "boto3" },
{ name = "botocore" },
{ name = "cryptography" },
{ name = "jinja2" },
{ name = "python-dateutil" },
{ name = "requests" },
{ name = "responses" },
{ name = "werkzeug" },
{ name = "xmltodict" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c4/72/9bc9b4917b816f5a82fc8f0fbd477c2a669d35a7d7941ae15a5411e266d6/moto-5.1.10.tar.gz", hash = "sha256:d6bdc8f82a1e503502927cc0a3da22014f836094d0bf399bb0f695754ae6c7a6", size = 7087004, upload-time = "2025-08-11T20:59:45.542Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c4/37/9b9cb5597eecc2ebfde2f65a8265f3669f6724ebe82bf9b155a3421039f8/moto-5.1.10-py3-none-any.whl", hash = "sha256:9ec1a21a924f97470af225b2bfa854fe46c1ad30fb44655eba458206dedf28b5", size = 5246859, upload-time = "2025-08-11T20:59:43.22Z" },
]
[package.optional-dependencies]
s3 = [
{ name = "py-partiql-parser" },
{ name = "pyyaml" },
]
[[package]]
name = "mpmath"
version = "1.3.0"
@ -3068,6 +3133,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl", hash = "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0", size = 11842, upload-time = "2024-07-21T12:58:20.04Z" },
]
[[package]]
name = "py-partiql-parser"
version = "0.6.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/58/a1/0a2867e48b232b4f82c4929ef7135f2a5d72c3886b957dccf63c70aa2fcb/py_partiql_parser-0.6.1.tar.gz", hash = "sha256:8583ff2a0e15560ef3bc3df109a7714d17f87d81d33e8c38b7fed4e58a63215d", size = 17120, upload-time = "2024-12-25T22:06:41.327Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/97/84/0e410c20bbe9a504fc56e97908f13261c2b313d16cbb3b738556166f044a/py_partiql_parser-0.6.1-py2.py3-none-any.whl", hash = "sha256:ff6a48067bff23c37e9044021bf1d949c83e195490c17e020715e927fe5b2456", size = 23520, upload-time = "2024-12-25T22:06:39.106Z" },
]
[[package]]
name = "pyaml"
version = "25.7.0"
@ -3788,6 +3862,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/3b/5d/63d4ae3b9daea098d5d6f5da83984853c1bbacd5dc826764b249fe119d24/requests_oauthlib-2.0.0-py2.py3-none-any.whl", hash = "sha256:7dd8a5c40426b779b0868c404bdef9768deccf22749cde15852df527e6269b36", size = 24179, upload-time = "2024-03-22T20:32:28.055Z" },
]
[[package]]
name = "responses"
version = "0.25.8"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pyyaml" },
{ name = "requests" },
{ name = "urllib3" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0e/95/89c054ad70bfef6da605338b009b2e283485835351a9935c7bfbfaca7ffc/responses-0.25.8.tar.gz", hash = "sha256:9374d047a575c8f781b94454db5cab590b6029505f488d12899ddb10a4af1cf4", size = 79320, upload-time = "2025-08-08T19:01:46.709Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1c/4c/cc276ce57e572c102d9542d383b2cfd551276581dc60004cb94fe8774c11/responses-0.25.8-py3-none-any.whl", hash = "sha256:0c710af92def29c8352ceadff0c3fe340ace27cf5af1bbe46fb71275bcd2831c", size = 34769, upload-time = "2025-08-08T19:01:45.018Z" },
]
[[package]]
name = "rich"
version = "14.1.0"
@ -3961,6 +4049,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/00/db/c376b0661c24cf770cb8815268190668ec1330eba8374a126ceef8c72d55/ruff-0.12.5-py3-none-win_arm64.whl", hash = "sha256:48cdbfc633de2c5c37d9f090ba3b352d1576b0015bfc3bc98eaf230275b7e805", size = 11951564, upload-time = "2025-07-24T13:26:34.994Z" },
]
[[package]]
name = "s3transfer"
version = "0.13.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "botocore" },
]
sdist = { url = "https://files.pythonhosted.org/packages/6d/05/d52bf1e65044b4e5e27d4e63e8d1579dbdec54fce685908ae09bc3720030/s3transfer-0.13.1.tar.gz", hash = "sha256:c3fdba22ba1bd367922f27ec8032d6a1cf5f10c934fb5d68cf60fd5a23d936cf", size = 150589, upload-time = "2025-07-18T19:22:42.31Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/6d/4f/d073e09df851cfa251ef7840007d04db3293a0482ce607d2b993926089be/s3transfer-0.13.1-py3-none-any.whl", hash = "sha256:a981aa7429be23fe6dfc13e80e4020057cbab622b08c0315288758d67cabc724", size = 85308, upload-time = "2025-07-18T19:22:40.947Z" },
]
[[package]]
name = "safetensors"
version = "0.5.3"
@ -5107,6 +5207,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/78/58/e860788190eba3bcce367f74d29c4675466ce8dddfba85f7827588416f01/wsproto-1.2.0-py3-none-any.whl", hash = "sha256:b9acddd652b585d75b20477888c56642fdade28bdfd3579aa24a4d2c037dd736", size = 24226, upload-time = "2022-08-23T19:58:19.96Z" },
]
[[package]]
name = "xmltodict"
version = "0.14.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/50/05/51dcca9a9bf5e1bce52582683ce50980bcadbc4fa5143b9f2b19ab99958f/xmltodict-0.14.2.tar.gz", hash = "sha256:201e7c28bb210e374999d1dde6382923ab0ed1a8a5faeece48ab525b7810a553", size = 51942, upload-time = "2024-10-16T06:10:29.683Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d6/45/fc303eb433e8a2a271739c98e953728422fa61a3c1f36077a49e395c972e/xmltodict-0.14.2-py2.py3-none-any.whl", hash = "sha256:20cc7d723ed729276e808f26fb6b3599f786cbc37e06c65e192ba77c40f20aac", size = 9981, upload-time = "2024-10-16T06:10:27.649Z" },
]
[[package]]
name = "xxhash"
version = "3.5.0"