mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 16:12:46 +00:00
Merge remote-tracking branch 'origin/main' into ai
This commit is contained in:
commit
0b3218ed81
94 changed files with 4092 additions and 235 deletions
22
.github/workflows/pre-commit.yml
vendored
22
.github/workflows/pre-commit.yml
vendored
|
|
@ -36,20 +36,16 @@ jobs:
|
|||
**/requirements*.txt
|
||||
.pre-commit-config.yaml
|
||||
|
||||
# npm ci may fail -
|
||||
# npm error `npm ci` can only install packages when your package.json and package-lock.json or npm-shrinkwrap.json are in sync. Please update your lock file with `npm install` before continuing.
|
||||
# npm error Invalid: lock file's llama-stack-client@0.2.17 does not satisfy llama-stack-client@0.2.18
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0
|
||||
with:
|
||||
node-version: '20'
|
||||
cache: 'npm'
|
||||
cache-dependency-path: 'llama_stack/ui/'
|
||||
|
||||
# - name: Set up Node.js
|
||||
# uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0
|
||||
# with:
|
||||
# node-version: '20'
|
||||
# cache: 'npm'
|
||||
# cache-dependency-path: 'llama_stack/ui/'
|
||||
|
||||
# - name: Install npm dependencies
|
||||
# run: npm ci
|
||||
# working-directory: llama_stack/ui
|
||||
- name: Install npm dependencies
|
||||
run: npm ci
|
||||
working-directory: llama_stack/ui
|
||||
|
||||
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
|
||||
continue-on-error: true
|
||||
|
|
|
|||
|
|
@ -146,31 +146,13 @@ repos:
|
|||
pass_filenames: false
|
||||
require_serial: true
|
||||
files: ^.github/workflows/.*$
|
||||
# ui-prettier and ui-eslint are disabled until we can avoid `npm ci`, which is slow and may fail -
|
||||
# npm error `npm ci` can only install packages when your package.json and package-lock.json or npm-shrinkwrap.json are in sync. Please update your lock file with `npm install` before continuing.
|
||||
# npm error Invalid: lock file's llama-stack-client@0.2.17 does not satisfy llama-stack-client@0.2.18
|
||||
# and until we have infra for installing prettier and next via npm -
|
||||
# Lint UI code with ESLint.....................................................Failed
|
||||
# - hook id: ui-eslint
|
||||
# - exit code: 127
|
||||
# > ui@0.1.0 lint
|
||||
# > next lint --fix --quiet
|
||||
# sh: line 1: next: command not found
|
||||
#
|
||||
# - id: ui-prettier
|
||||
# name: Format UI code with Prettier
|
||||
# entry: bash -c 'cd llama_stack/ui && npm ci && npm run format'
|
||||
# language: system
|
||||
# files: ^llama_stack/ui/.*\.(ts|tsx)$
|
||||
# pass_filenames: false
|
||||
# require_serial: true
|
||||
# - id: ui-eslint
|
||||
# name: Lint UI code with ESLint
|
||||
# entry: bash -c 'cd llama_stack/ui && npm run lint -- --fix --quiet'
|
||||
# language: system
|
||||
# files: ^llama_stack/ui/.*\.(ts|tsx)$
|
||||
# pass_filenames: false
|
||||
# require_serial: true
|
||||
- id: ui-linter
|
||||
name: Format & Lint UI
|
||||
entry: bash ./scripts/run-ui-linter.sh
|
||||
language: system
|
||||
files: ^llama_stack/ui/.*\.(ts|tsx)$
|
||||
pass_filenames: false
|
||||
require_serial: true
|
||||
|
||||
- id: check-log-usage
|
||||
name: Ensure 'llama_stack.log' usage for logging
|
||||
|
|
|
|||
132
docs/_static/llama-stack-spec.html
vendored
132
docs/_static/llama-stack-spec.html
vendored
|
|
@ -4605,6 +4605,49 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"/v1/inference/rerank": {
|
||||
"post": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "RerankResponse with indices sorted by relevance score (descending).",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/RerankResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"$ref": "#/components/responses/BadRequest400"
|
||||
},
|
||||
"429": {
|
||||
"$ref": "#/components/responses/TooManyRequests429"
|
||||
},
|
||||
"500": {
|
||||
"$ref": "#/components/responses/InternalServerError500"
|
||||
},
|
||||
"default": {
|
||||
"$ref": "#/components/responses/DefaultError"
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"Inference"
|
||||
],
|
||||
"description": "Rerank a list of documents based on their relevance to a query.",
|
||||
"parameters": [],
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/RerankRequest"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": {
|
||||
"post": {
|
||||
"responses": {
|
||||
|
|
@ -16587,6 +16630,95 @@
|
|||
],
|
||||
"title": "RegisterVectorDbRequest"
|
||||
},
|
||||
"RerankRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"model": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the reranking model to use."
|
||||
},
|
||||
"query": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
|
||||
}
|
||||
],
|
||||
"description": "The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length."
|
||||
},
|
||||
"items": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
|
||||
}
|
||||
]
|
||||
},
|
||||
"description": "List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length."
|
||||
},
|
||||
"max_num_results": {
|
||||
"type": "integer",
|
||||
"description": "(Optional) Maximum number of results to return. Default: returns all."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"model",
|
||||
"query",
|
||||
"items"
|
||||
],
|
||||
"title": "RerankRequest"
|
||||
},
|
||||
"RerankData": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"index": {
|
||||
"type": "integer",
|
||||
"description": "The original index of the document in the input list"
|
||||
},
|
||||
"relevance_score": {
|
||||
"type": "number",
|
||||
"description": "The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"index",
|
||||
"relevance_score"
|
||||
],
|
||||
"title": "RerankData",
|
||||
"description": "A single rerank result from a reranking response."
|
||||
},
|
||||
"RerankResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"data": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/RerankData"
|
||||
},
|
||||
"description": "List of rerank result objects, sorted by relevance score (descending)"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"data"
|
||||
],
|
||||
"title": "RerankResponse",
|
||||
"description": "Response from a reranking request."
|
||||
},
|
||||
"ResumeAgentTurnRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
|
|
|||
101
docs/_static/llama-stack-spec.yaml
vendored
101
docs/_static/llama-stack-spec.yaml
vendored
|
|
@ -3264,6 +3264,37 @@ paths:
|
|||
schema:
|
||||
$ref: '#/components/schemas/QueryTracesRequest'
|
||||
required: true
|
||||
/v1/inference/rerank:
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: >-
|
||||
RerankResponse with indices sorted by relevance score (descending).
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/RerankResponse'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Inference
|
||||
description: >-
|
||||
Rerank a list of documents based on their relevance to a query.
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/RerankRequest'
|
||||
required: true
|
||||
/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume:
|
||||
post:
|
||||
responses:
|
||||
|
|
@ -12337,6 +12368,76 @@ components:
|
|||
- vector_db_id
|
||||
- embedding_model
|
||||
title: RegisterVectorDbRequest
|
||||
RerankRequest:
|
||||
type: object
|
||||
properties:
|
||||
model:
|
||||
type: string
|
||||
description: >-
|
||||
The identifier of the reranking model to use.
|
||||
query:
|
||||
oneOf:
|
||||
- type: string
|
||||
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
|
||||
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
|
||||
description: >-
|
||||
The search query to rank items against. Can be a string, text content
|
||||
part, or image content part. The input must not exceed the model's max
|
||||
input token length.
|
||||
items:
|
||||
type: array
|
||||
items:
|
||||
oneOf:
|
||||
- type: string
|
||||
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
|
||||
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
|
||||
description: >-
|
||||
List of items to rerank. Each item can be a string, text content part,
|
||||
or image content part. Each input must not exceed the model's max input
|
||||
token length.
|
||||
max_num_results:
|
||||
type: integer
|
||||
description: >-
|
||||
(Optional) Maximum number of results to return. Default: returns all.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- model
|
||||
- query
|
||||
- items
|
||||
title: RerankRequest
|
||||
RerankData:
|
||||
type: object
|
||||
properties:
|
||||
index:
|
||||
type: integer
|
||||
description: >-
|
||||
The original index of the document in the input list
|
||||
relevance_score:
|
||||
type: number
|
||||
description: >-
|
||||
The relevance score from the model output. Values are inverted when applicable
|
||||
so that higher scores indicate greater relevance.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- index
|
||||
- relevance_score
|
||||
title: RerankData
|
||||
description: >-
|
||||
A single rerank result from a reranking response.
|
||||
RerankResponse:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/RerankData'
|
||||
description: >-
|
||||
List of rerank result objects, sorted by relevance score (descending)
|
||||
additionalProperties: false
|
||||
required:
|
||||
- data
|
||||
title: RerankResponse
|
||||
description: Response from a reranking request.
|
||||
ResumeAgentTurnRequest:
|
||||
type: object
|
||||
properties:
|
||||
|
|
|
|||
|
|
@ -10,4 +10,5 @@ This section contains documentation for all available providers for the **files*
|
|||
:maxdepth: 1
|
||||
|
||||
inline_localfs
|
||||
remote_s3
|
||||
```
|
||||
|
|
|
|||
33
docs/source/providers/files/remote_s3.md
Normal file
33
docs/source/providers/files/remote_s3.md
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
# remote::s3
|
||||
|
||||
## Description
|
||||
|
||||
AWS S3-based file storage provider for scalable cloud file management with metadata persistence.
|
||||
|
||||
## Configuration
|
||||
|
||||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `bucket_name` | `<class 'str'>` | No | | S3 bucket name to store files |
|
||||
| `region` | `<class 'str'>` | No | us-east-1 | AWS region where the bucket is located |
|
||||
| `aws_access_key_id` | `str \| None` | No | | AWS access key ID (optional if using IAM roles) |
|
||||
| `aws_secret_access_key` | `str \| None` | No | | AWS secret access key (optional if using IAM roles) |
|
||||
| `endpoint_url` | `str \| None` | No | | Custom S3 endpoint URL (for MinIO, LocalStack, etc.) |
|
||||
| `auto_create_bucket` | `<class 'bool'>` | No | False | Automatically create the S3 bucket if it doesn't exist |
|
||||
| `metadata_store` | `utils.sqlstore.sqlstore.SqliteSqlStoreConfig \| utils.sqlstore.sqlstore.PostgresSqlStoreConfig` | No | sqlite | SQL store configuration for file metadata |
|
||||
|
||||
## Sample Configuration
|
||||
|
||||
```yaml
|
||||
bucket_name: ${env.S3_BUCKET_NAME}
|
||||
region: ${env.AWS_REGION:=us-east-1}
|
||||
aws_access_key_id: ${env.AWS_ACCESS_KEY_ID:=}
|
||||
aws_secret_access_key: ${env.AWS_SECRET_ACCESS_KEY:=}
|
||||
endpoint_url: ${env.S3_ENDPOINT_URL:=}
|
||||
auto_create_bucket: ${env.S3_AUTO_CREATE_BUCKET:=false}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/s3_files_metadata.db
|
||||
|
||||
```
|
||||
|
||||
|
|
@ -473,6 +473,28 @@ class EmbeddingsResponse(BaseModel):
|
|||
embeddings: list[list[float]]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RerankData(BaseModel):
|
||||
"""A single rerank result from a reranking response.
|
||||
|
||||
:param index: The original index of the document in the input list
|
||||
:param relevance_score: The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance.
|
||||
"""
|
||||
|
||||
index: int
|
||||
relevance_score: float
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RerankResponse(BaseModel):
|
||||
"""Response from a reranking request.
|
||||
|
||||
:param data: List of rerank result objects, sorted by relevance score (descending)
|
||||
"""
|
||||
|
||||
data: list[RerankData]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionContentPartTextParam(BaseModel):
|
||||
"""Text content part for OpenAI-compatible chat completion messages.
|
||||
|
|
@ -1131,6 +1153,24 @@ class InferenceProvider(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/inference/rerank", method="POST", experimental=True)
|
||||
async def rerank(
|
||||
self,
|
||||
model: str,
|
||||
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
|
||||
max_num_results: int | None = None,
|
||||
) -> RerankResponse:
|
||||
"""Rerank a list of documents based on their relevance to a query.
|
||||
|
||||
:param model: The identifier of the reranking model to use.
|
||||
:param query: The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length.
|
||||
:param items: List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length.
|
||||
:param max_num_results: (Optional) Maximum number of results to return. Default: returns all.
|
||||
:returns: RerankResponse with indices sorted by relevance score (descending).
|
||||
"""
|
||||
raise NotImplementedError("Reranking is not implemented")
|
||||
|
||||
@webmethod(route="/openai/v1/completions", method="POST")
|
||||
async def openai_completion(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
logger = get_logger(name=__name__, category="server")
|
||||
logger = get_logger(name=__name__, category="cli")
|
||||
|
||||
|
||||
class StackRun(Subcommand):
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from llama_stack.apis.datasets import DatasetPurpose, DataSource
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routers")
|
||||
|
||||
|
||||
class DatasetIORouter(DatasetIO):
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from llama_stack.apis.scoring import (
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routers")
|
||||
|
||||
|
||||
class ScoringRouter(Scoring):
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus, Routin
|
|||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
logger = get_logger(name=__name__, category="core::routers")
|
||||
|
||||
|
||||
class InferenceRouter(Inference):
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from llama_stack.apis.shields import Shield
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routers")
|
||||
|
||||
|
||||
class SafetyRouter(Safety):
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routers")
|
||||
|
||||
|
||||
class ToolRuntimeRouter(ToolRuntime):
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import (
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routers")
|
||||
|
||||
|
||||
class VectorIORouter(VectorIO):
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||
|
||||
|
||||
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from llama_stack.core.store import DistributionRegistry
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||
|
||||
|
||||
def get_impl_api(p: Any) -> Api:
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||
|
||||
|
||||
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from .common import CommonRoutingTableImpl, lookup_model
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||
|
||||
|
||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||
|
||||
|
||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||
|
||||
|
||||
def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from .common import CommonRoutingTableImpl, lookup_model
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||
|
||||
|
||||
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from llama_stack.core.server.auth_providers import create_auth_provider
|
|||
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
logger = get_logger(name=__name__, category="core::auth")
|
||||
|
||||
|
||||
class AuthenticationMiddleware:
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from llama_stack.core.datatypes import (
|
|||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
logger = get_logger(name=__name__, category="core::auth")
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore
|
|||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
|
||||
|
||||
logger = get_logger(name=__name__, category="quota")
|
||||
logger = get_logger(name=__name__, category="core::server")
|
||||
|
||||
|
||||
class QuotaMiddleware:
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ from .quota import QuotaMiddleware
|
|||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
logger = get_logger(name=__name__, category="server")
|
||||
logger = get_logger(name=__name__, category="core::server")
|
||||
|
||||
|
||||
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
|
||||
|
|
@ -415,7 +415,7 @@ def main(args: argparse.Namespace | None = None):
|
|||
config_contents = yaml.safe_load(fp)
|
||||
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
||||
logger_config = LoggingConfig(**cfg)
|
||||
logger = get_logger(name=__name__, category="server", config=logger_config)
|
||||
logger = get_logger(name=__name__, category="core::server", config=logger_config)
|
||||
if args.env:
|
||||
for env_pair in args.env:
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from llama_stack.log import get_logger
|
|||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
|
||||
logger = get_logger(__name__, category="core")
|
||||
logger = get_logger(__name__, category="core::registry")
|
||||
|
||||
|
||||
class DistributionRegistry(Protocol):
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from pathlib import Path
|
|||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="config_resolution")
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions"
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ from .utils import get_negative_inf_value, to_2tuple
|
|||
|
||||
MP_SCALE = 8
|
||||
|
||||
logger = get_logger(name=__name__, category="models")
|
||||
logger = get_logger(name=__name__, category="models::llama")
|
||||
|
||||
|
||||
def reduce_from_tensor_model_parallel_region(input_):
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
logger = get_logger(name=__name__, category="models::llama")
|
||||
|
||||
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
|
||||
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from ...datatypes import QuantizationMode
|
|||
from ..model import Transformer, TransformerBlock
|
||||
from ..moe import MoE
|
||||
|
||||
log = get_logger(name=__name__, category="models")
|
||||
log = get_logger(name=__name__, category="models::llama")
|
||||
|
||||
|
||||
def swiglu_wrapper_no_reduce(
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import collections
|
|||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
log = get_logger(name=__name__, category="llama")
|
||||
log = get_logger(name=__name__, category="models::llama")
|
||||
|
||||
try:
|
||||
import fbgemm_gpu.experimental.gen_ai # noqa: F401
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ MEMORY_QUERY_TOOL = "knowledge_search"
|
|||
WEB_SEARCH_TOOL = "web_search"
|
||||
RAG_TOOL_GROUP = "builtin::rag"
|
||||
|
||||
logger = get_logger(name=__name__, category="agents")
|
||||
logger = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
||||
|
||||
class ChatAgent(ShieldRunnerMixin):
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig
|
|||
from .persistence import AgentInfo
|
||||
from .responses.openai_responses import OpenAIResponsesImpl
|
||||
|
||||
logger = get_logger(name=__name__, category="agents")
|
||||
logger = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
||||
|
||||
class MetaReferenceAgentsImpl(Agents):
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from llama_stack.core.request_headers import get_authenticated_user
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
log = get_logger(name=__name__, category="agents")
|
||||
log = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
||||
|
||||
class AgentSessionInfo(Session):
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ from .utils import (
|
|||
convert_response_text_to_chat_response_format,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="responses")
|
||||
logger = get_logger(name=__name__, category="openai::responses")
|
||||
|
||||
|
||||
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ from llama_stack.log import get_logger
|
|||
from .types import ChatCompletionContext, ChatCompletionResult
|
||||
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
|
||||
|
||||
logger = get_logger(name=__name__, category="responses")
|
||||
logger = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
||||
|
||||
class StreamingResponseOrchestrator:
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from .types import ChatCompletionContext, ToolExecutionResult
|
||||
|
||||
logger = get_logger(name=__name__, category="responses")
|
||||
logger = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
log = get_logger(name=__name__, category="agents")
|
||||
log = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
||||
|
||||
class SafetyException(Exception): # noqa: N818
|
||||
|
|
|
|||
|
|
@ -33,6 +33,9 @@ from llama_stack.apis.inference import (
|
|||
InterleavedContent,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
RerankResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
|
|
@ -442,6 +445,15 @@ class MetaReferenceInferenceImpl(
|
|||
results = await self._nonstream_chat_completion(request_batch)
|
||||
return BatchChatCompletionResponse(batch=results)
|
||||
|
||||
async def rerank(
|
||||
self,
|
||||
model: str,
|
||||
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
|
||||
max_num_results: int | None = None,
|
||||
) -> RerankResponse:
|
||||
raise NotImplementedError("Reranking is not supported for Meta Reference")
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request_batch: list[ChatCompletionRequest]
|
||||
) -> list[ChatCompletionResponse]:
|
||||
|
|
|
|||
|
|
@ -12,6 +12,9 @@ from llama_stack.apis.inference import (
|
|||
InterleavedContent,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
RerankResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
|
|
@ -122,3 +125,12 @@ class SentenceTransformersInferenceImpl(
|
|||
logprobs: LogProbConfig | None = None,
|
||||
):
|
||||
raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")
|
||||
|
||||
async def rerank(
|
||||
self,
|
||||
model: str,
|
||||
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
|
||||
max_num_results: int | None = None,
|
||||
) -> RerankResponse:
|
||||
raise NotImplementedError("Reranking is not supported for Sentence Transformers")
|
||||
|
|
|
|||
|
|
@ -5,9 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
|
||||
|
||||
|
|
@ -23,4 +25,14 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
|
||||
description="Local filesystem-based file storage provider for managing files and documents locally.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.files,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="s3",
|
||||
pip_packages=["boto3"] + sql_store_pip_packages,
|
||||
module="llama_stack.providers.remote.files.s3",
|
||||
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
|
||||
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
|||
237
llama_stack/providers/remote/files/s3/README.md
Normal file
237
llama_stack/providers/remote/files/s3/README.md
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
# S3 Files Provider
|
||||
|
||||
A remote S3-based implementation of the Llama Stack Files API that provides scalable cloud file storage with metadata persistence.
|
||||
|
||||
## Features
|
||||
|
||||
- **AWS S3 Storage**: Store files in AWS S3 buckets for scalable, durable storage
|
||||
- **Metadata Management**: Uses SQL database for efficient file metadata queries
|
||||
- **OpenAI API Compatibility**: Full compatibility with OpenAI Files API endpoints
|
||||
- **Flexible Authentication**: Support for IAM roles and access keys
|
||||
- **Custom S3 Endpoints**: Support for MinIO and other S3-compatible services
|
||||
|
||||
## Configuration
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
```yaml
|
||||
api: files
|
||||
provider_type: remote::s3
|
||||
config:
|
||||
bucket_name: my-llama-stack-files
|
||||
region: us-east-1
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ./s3_files_metadata.db
|
||||
```
|
||||
|
||||
### Advanced Configuration
|
||||
|
||||
```yaml
|
||||
api: files
|
||||
provider_type: remote::s3
|
||||
config:
|
||||
bucket_name: my-llama-stack-files
|
||||
region: us-east-1
|
||||
aws_access_key_id: YOUR_ACCESS_KEY
|
||||
aws_secret_access_key: YOUR_SECRET_KEY
|
||||
endpoint_url: https://s3.amazonaws.com # Optional for custom endpoints
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ./s3_files_metadata.db
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
|
||||
The configuration supports environment variable substitution:
|
||||
|
||||
```yaml
|
||||
config:
|
||||
bucket_name: "${env.S3_BUCKET_NAME}"
|
||||
region: "${env.AWS_REGION:=us-east-1}"
|
||||
aws_access_key_id: "${env.AWS_ACCESS_KEY_ID:=}"
|
||||
aws_secret_access_key: "${env.AWS_SECRET_ACCESS_KEY:=}"
|
||||
endpoint_url: "${env.S3_ENDPOINT_URL:=}"
|
||||
```
|
||||
|
||||
Note: `S3_BUCKET_NAME` has no default value since S3 bucket names must be globally unique.
|
||||
|
||||
## Authentication
|
||||
|
||||
### IAM Roles (Recommended)
|
||||
|
||||
For production deployments, use IAM roles:
|
||||
|
||||
```yaml
|
||||
config:
|
||||
bucket_name: my-bucket
|
||||
region: us-east-1
|
||||
# No credentials needed - will use IAM role
|
||||
```
|
||||
|
||||
### Access Keys
|
||||
|
||||
For development or specific use cases:
|
||||
|
||||
```yaml
|
||||
config:
|
||||
bucket_name: my-bucket
|
||||
region: us-east-1
|
||||
aws_access_key_id: AKIAIOSFODNN7EXAMPLE
|
||||
aws_secret_access_key: wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY
|
||||
```
|
||||
|
||||
## S3 Bucket Setup
|
||||
|
||||
### Required Permissions
|
||||
|
||||
The S3 provider requires the following permissions:
|
||||
|
||||
```json
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": [
|
||||
"s3:GetObject",
|
||||
"s3:PutObject",
|
||||
"s3:DeleteObject",
|
||||
"s3:ListBucket"
|
||||
],
|
||||
"Resource": [
|
||||
"arn:aws:s3:::your-bucket-name",
|
||||
"arn:aws:s3:::your-bucket-name/*"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Automatic Bucket Creation
|
||||
|
||||
By default, the S3 provider expects the bucket to already exist. If you want the provider to automatically create the bucket when it doesn't exist, set `auto_create_bucket: true` in your configuration:
|
||||
|
||||
```yaml
|
||||
config:
|
||||
bucket_name: my-bucket
|
||||
auto_create_bucket: true # Will create bucket if it doesn't exist
|
||||
region: us-east-1
|
||||
```
|
||||
|
||||
**Note**: When `auto_create_bucket` is enabled, the provider will need additional permissions:
|
||||
|
||||
```json
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": [
|
||||
"s3:GetObject",
|
||||
"s3:PutObject",
|
||||
"s3:DeleteObject",
|
||||
"s3:ListBucket",
|
||||
"s3:CreateBucket"
|
||||
],
|
||||
"Resource": [
|
||||
"arn:aws:s3:::your-bucket-name",
|
||||
"arn:aws:s3:::your-bucket-name/*"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Bucket Policy (Optional)
|
||||
|
||||
For additional security, you can add a bucket policy:
|
||||
|
||||
```json
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Sid": "LlamaStackAccess",
|
||||
"Effect": "Allow",
|
||||
"Principal": {
|
||||
"AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
|
||||
},
|
||||
"Action": [
|
||||
"s3:GetObject",
|
||||
"s3:PutObject",
|
||||
"s3:DeleteObject"
|
||||
],
|
||||
"Resource": "arn:aws:s3:::your-bucket-name/*"
|
||||
},
|
||||
{
|
||||
"Sid": "LlamaStackBucketAccess",
|
||||
"Effect": "Allow",
|
||||
"Principal": {
|
||||
"AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
|
||||
},
|
||||
"Action": [
|
||||
"s3:ListBucket"
|
||||
],
|
||||
"Resource": "arn:aws:s3:::your-bucket-name"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
### Metadata Persistence
|
||||
|
||||
File metadata is stored in a SQL database for fast queries and OpenAI API compatibility. The metadata includes:
|
||||
|
||||
- File ID
|
||||
- Original filename
|
||||
- Purpose (assistants, batch, etc.)
|
||||
- File size in bytes
|
||||
- Created and expiration timestamps
|
||||
|
||||
### TTL and Cleanup
|
||||
|
||||
Files currently have a fixed long expiration time (100 years).
|
||||
|
||||
## Development and Testing
|
||||
|
||||
### Using MinIO
|
||||
|
||||
For self-hosted S3-compatible storage:
|
||||
|
||||
```yaml
|
||||
config:
|
||||
bucket_name: test-bucket
|
||||
region: us-east-1
|
||||
endpoint_url: http://localhost:9000
|
||||
aws_access_key_id: minioadmin
|
||||
aws_secret_access_key: minioadmin
|
||||
```
|
||||
|
||||
## Monitoring and Logging
|
||||
|
||||
The provider logs important operations and errors. For production deployments, consider:
|
||||
|
||||
- CloudWatch monitoring for S3 operations
|
||||
- Custom metrics for file upload/download rates
|
||||
- Error rate monitoring
|
||||
- Performance metrics tracking
|
||||
|
||||
## Error Handling
|
||||
|
||||
The provider handles various error scenarios:
|
||||
|
||||
- S3 connectivity issues
|
||||
- Bucket access permissions
|
||||
- File not found errors
|
||||
- Metadata consistency checks
|
||||
|
||||
## Known Limitations
|
||||
|
||||
- Fixed long TTL (100 years) instead of configurable expiration
|
||||
- No server-side encryption enabled by default
|
||||
- No support for AWS session tokens
|
||||
- No S3 key prefix organization support
|
||||
- No multipart upload support (all files uploaded as single objects)
|
||||
20
llama_stack/providers/remote/files/s3/__init__.py
Normal file
20
llama_stack/providers/remote/files/s3/__init__.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import Api
|
||||
|
||||
from .config import S3FilesImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any]):
|
||||
from .files import S3FilesImpl
|
||||
|
||||
# TODO: authorization policies and user separation
|
||||
impl = S3FilesImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
42
llama_stack/providers/remote/files/s3/config.py
Normal file
42
llama_stack/providers/remote/files/s3/config.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
|
||||
|
||||
|
||||
class S3FilesImplConfig(BaseModel):
|
||||
"""Configuration for S3-based files provider."""
|
||||
|
||||
bucket_name: str = Field(description="S3 bucket name to store files")
|
||||
region: str = Field(default="us-east-1", description="AWS region where the bucket is located")
|
||||
aws_access_key_id: str | None = Field(default=None, description="AWS access key ID (optional if using IAM roles)")
|
||||
aws_secret_access_key: str | None = Field(
|
||||
default=None, description="AWS secret access key (optional if using IAM roles)"
|
||||
)
|
||||
endpoint_url: str | None = Field(default=None, description="Custom S3 endpoint URL (for MinIO, LocalStack, etc.)")
|
||||
auto_create_bucket: bool = Field(
|
||||
default=False, description="Automatically create the S3 bucket if it doesn't exist"
|
||||
)
|
||||
metadata_store: SqlStoreConfig = Field(description="SQL store configuration for file metadata")
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
return {
|
||||
"bucket_name": "${env.S3_BUCKET_NAME}", # no default, buckets must be globally unique
|
||||
"region": "${env.AWS_REGION:=us-east-1}",
|
||||
"aws_access_key_id": "${env.AWS_ACCESS_KEY_ID:=}",
|
||||
"aws_secret_access_key": "${env.AWS_SECRET_ACCESS_KEY:=}",
|
||||
"endpoint_url": "${env.S3_ENDPOINT_URL:=}",
|
||||
"auto_create_bucket": "${env.S3_AUTO_CREATE_BUCKET:=false}",
|
||||
"metadata_store": SqliteSqlStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="s3_files_metadata.db",
|
||||
),
|
||||
}
|
||||
272
llama_stack/providers/remote/files/s3/files.py
Normal file
272
llama_stack/providers/remote/files/s3/files.py
Normal file
|
|
@ -0,0 +1,272 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
|
||||
from fastapi import File, Form, Response, UploadFile
|
||||
|
||||
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.files import (
|
||||
Files,
|
||||
ListOpenAIFileResponse,
|
||||
OpenAIFileDeleteResponse,
|
||||
OpenAIFileObject,
|
||||
OpenAIFilePurpose,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl
|
||||
|
||||
from .config import S3FilesImplConfig
|
||||
|
||||
# TODO: provider data for S3 credentials
|
||||
|
||||
|
||||
def _create_s3_client(config: S3FilesImplConfig) -> boto3.client:
|
||||
try:
|
||||
s3_config = {
|
||||
"region_name": config.region,
|
||||
}
|
||||
|
||||
# endpoint URL if specified (for MinIO, LocalStack, etc.)
|
||||
if config.endpoint_url:
|
||||
s3_config["endpoint_url"] = config.endpoint_url
|
||||
|
||||
if config.aws_access_key_id and config.aws_secret_access_key:
|
||||
s3_config.update(
|
||||
{
|
||||
"aws_access_key_id": config.aws_access_key_id,
|
||||
"aws_secret_access_key": config.aws_secret_access_key,
|
||||
}
|
||||
)
|
||||
|
||||
return boto3.client("s3", **s3_config)
|
||||
|
||||
except (BotoCoreError, NoCredentialsError) as e:
|
||||
raise RuntimeError(f"Failed to initialize S3 client: {e}") from e
|
||||
|
||||
|
||||
async def _create_bucket_if_not_exists(client: boto3.client, config: S3FilesImplConfig) -> None:
|
||||
try:
|
||||
client.head_bucket(Bucket=config.bucket_name)
|
||||
except ClientError as e:
|
||||
error_code = e.response["Error"]["Code"]
|
||||
if error_code == "404":
|
||||
if not config.auto_create_bucket:
|
||||
raise RuntimeError(
|
||||
f"S3 bucket '{config.bucket_name}' does not exist. "
|
||||
f"Either create the bucket manually or set 'auto_create_bucket: true' in your configuration."
|
||||
) from e
|
||||
try:
|
||||
# For us-east-1, we can't specify LocationConstraint
|
||||
if config.region == "us-east-1":
|
||||
client.create_bucket(Bucket=config.bucket_name)
|
||||
else:
|
||||
client.create_bucket(
|
||||
Bucket=config.bucket_name,
|
||||
CreateBucketConfiguration={"LocationConstraint": config.region},
|
||||
)
|
||||
except ClientError as create_error:
|
||||
raise RuntimeError(
|
||||
f"Failed to create S3 bucket '{config.bucket_name}': {create_error}"
|
||||
) from create_error
|
||||
elif error_code == "403":
|
||||
raise RuntimeError(f"Access denied to S3 bucket '{config.bucket_name}'") from e
|
||||
else:
|
||||
raise RuntimeError(f"Failed to access S3 bucket '{config.bucket_name}': {e}") from e
|
||||
|
||||
|
||||
class S3FilesImpl(Files):
|
||||
"""S3-based implementation of the Files API."""
|
||||
|
||||
# TODO: implement expiration, for now a silly offset
|
||||
_SILLY_EXPIRATION_OFFSET = 100 * 365 * 24 * 60 * 60
|
||||
|
||||
def __init__(self, config: S3FilesImplConfig) -> None:
|
||||
self._config = config
|
||||
self._client: boto3.client | None = None
|
||||
self._sql_store: SqlStore | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self._client = _create_s3_client(self._config)
|
||||
await _create_bucket_if_not_exists(self._client, self._config)
|
||||
|
||||
self._sql_store = sqlstore_impl(self._config.metadata_store)
|
||||
await self._sql_store.create_table(
|
||||
"openai_files",
|
||||
{
|
||||
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
|
||||
"filename": ColumnType.STRING,
|
||||
"purpose": ColumnType.STRING,
|
||||
"bytes": ColumnType.INTEGER,
|
||||
"created_at": ColumnType.INTEGER,
|
||||
"expires_at": ColumnType.INTEGER,
|
||||
# TODO: add s3_etag field for integrity checking
|
||||
},
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def client(self) -> boto3.client:
|
||||
assert self._client is not None, "Provider not initialized"
|
||||
return self._client
|
||||
|
||||
@property
|
||||
def sql_store(self) -> SqlStore:
|
||||
assert self._sql_store is not None, "Provider not initialized"
|
||||
return self._sql_store
|
||||
|
||||
async def openai_upload_file(
|
||||
self,
|
||||
file: Annotated[UploadFile, File()],
|
||||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
) -> OpenAIFileObject:
|
||||
file_id = f"file-{uuid.uuid4().hex}"
|
||||
|
||||
filename = getattr(file, "filename", None) or "uploaded_file"
|
||||
|
||||
created_at = int(time.time())
|
||||
expires_at = created_at + self._SILLY_EXPIRATION_OFFSET
|
||||
content = await file.read()
|
||||
file_size = len(content)
|
||||
|
||||
await self.sql_store.insert(
|
||||
"openai_files",
|
||||
{
|
||||
"id": file_id,
|
||||
"filename": filename,
|
||||
"purpose": purpose.value,
|
||||
"bytes": file_size,
|
||||
"created_at": created_at,
|
||||
"expires_at": expires_at,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
self.client.put_object(
|
||||
Bucket=self._config.bucket_name,
|
||||
Key=file_id,
|
||||
Body=content,
|
||||
# TODO: enable server-side encryption
|
||||
)
|
||||
except ClientError as e:
|
||||
await self.sql_store.delete("openai_files", where={"id": file_id})
|
||||
|
||||
raise RuntimeError(f"Failed to upload file to S3: {e}") from e
|
||||
|
||||
return OpenAIFileObject(
|
||||
id=file_id,
|
||||
filename=filename,
|
||||
purpose=purpose,
|
||||
bytes=file_size,
|
||||
created_at=created_at,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
async def openai_list_files(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 10000,
|
||||
order: Order | None = Order.desc,
|
||||
purpose: OpenAIFilePurpose | None = None,
|
||||
) -> ListOpenAIFileResponse:
|
||||
# this purely defensive. it should not happen because the router also default to Order.desc.
|
||||
if not order:
|
||||
order = Order.desc
|
||||
|
||||
where_conditions = {}
|
||||
if purpose:
|
||||
where_conditions["purpose"] = purpose.value
|
||||
|
||||
paginated_result = await self.sql_store.fetch_all(
|
||||
table="openai_files",
|
||||
where=where_conditions if where_conditions else None,
|
||||
order_by=[("created_at", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
files = [
|
||||
OpenAIFileObject(
|
||||
id=row["id"],
|
||||
filename=row["filename"],
|
||||
purpose=OpenAIFilePurpose(row["purpose"]),
|
||||
bytes=row["bytes"],
|
||||
created_at=row["created_at"],
|
||||
expires_at=row["expires_at"],
|
||||
)
|
||||
for row in paginated_result.data
|
||||
]
|
||||
|
||||
return ListOpenAIFileResponse(
|
||||
data=files,
|
||||
has_more=paginated_result.has_more,
|
||||
# empty string or None? spec says str, ref impl returns str | None, we go with spec
|
||||
first_id=files[0].id if files else "",
|
||||
last_id=files[-1].id if files else "",
|
||||
)
|
||||
|
||||
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
if not row:
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||
|
||||
return OpenAIFileObject(
|
||||
id=row["id"],
|
||||
filename=row["filename"],
|
||||
purpose=OpenAIFilePurpose(row["purpose"]),
|
||||
bytes=row["bytes"],
|
||||
created_at=row["created_at"],
|
||||
expires_at=row["expires_at"],
|
||||
)
|
||||
|
||||
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
if not row:
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||
|
||||
try:
|
||||
self.client.delete_object(
|
||||
Bucket=self._config.bucket_name,
|
||||
Key=row["id"],
|
||||
)
|
||||
except ClientError as e:
|
||||
if e.response["Error"]["Code"] != "NoSuchKey":
|
||||
raise RuntimeError(f"Failed to delete file from S3: {e}") from e
|
||||
|
||||
await self.sql_store.delete("openai_files", where={"id": file_id})
|
||||
|
||||
return OpenAIFileDeleteResponse(id=file_id, deleted=True)
|
||||
|
||||
async def openai_retrieve_file_content(self, file_id: str) -> Response:
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
if not row:
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||
|
||||
try:
|
||||
response = self.client.get_object(
|
||||
Bucket=self._config.bucket_name,
|
||||
Key=row["id"],
|
||||
)
|
||||
# TODO: can we stream this instead of loading it into memory
|
||||
content = response["Body"].read()
|
||||
except ClientError as e:
|
||||
if e.response["Error"]["Code"] == "NoSuchKey":
|
||||
await self.sql_store.delete("openai_files", where={"id": file_id})
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()") from e
|
||||
raise RuntimeError(f"Failed to download file from S3: {e}") from e
|
||||
|
||||
return Response(
|
||||
content=content,
|
||||
media_type="application/octet-stream",
|
||||
headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'},
|
||||
)
|
||||
|
|
@ -65,7 +65,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
from .config import FireworksImplConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
logger = get_logger(name=__name__, category="inference::fireworks")
|
||||
|
||||
|
||||
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
|
|
|
|||
|
|
@ -3,6 +3,11 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
RerankResponse,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
|
@ -10,7 +15,7 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
|||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
logger = get_logger(name=__name__, category="inference::llama_openai_compat")
|
||||
|
||||
|
||||
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
|
|
@ -54,3 +59,12 @@ class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
||||
|
||||
async def rerank(
|
||||
self,
|
||||
model: str,
|
||||
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
|
||||
max_num_results: int | None = None,
|
||||
) -> RerankResponse:
|
||||
raise NotImplementedError("Reranking is not supported for Llama OpenAI Compat")
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ from .openai_utils import (
|
|||
)
|
||||
from .utils import _is_nvidia_hosted
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
logger = get_logger(name=__name__, category="inference::nvidia")
|
||||
|
||||
|
||||
class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from . import NVIDIAConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
logger = get_logger(name=__name__, category="inference::nvidia")
|
||||
|
||||
|
||||
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
|
||||
|
|
|
|||
|
|
@ -37,11 +37,14 @@ from llama_stack.apis.inference import (
|
|||
Message,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
RerankResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
|
|
@ -85,7 +88,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
logger = get_logger(name=__name__, category="inference::ollama")
|
||||
|
||||
|
||||
class OllamaInferenceAdapter(
|
||||
|
|
@ -641,6 +644,15 @@ class OllamaInferenceAdapter(
|
|||
):
|
||||
raise NotImplementedError("Batch chat completion is not supported for Ollama")
|
||||
|
||||
async def rerank(
|
||||
self,
|
||||
model: str,
|
||||
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
|
||||
max_num_results: int | None = None,
|
||||
) -> RerankResponse:
|
||||
raise NotImplementedError("Reranking is not supported for Ollama")
|
||||
|
||||
|
||||
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
|
||||
async def _convert_content(content) -> dict:
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
|||
from .config import OpenAIConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
logger = get_logger(name=__name__, category="inference::openai")
|
||||
|
||||
|
||||
#
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
||||
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
log = get_logger(name=__name__, category="inference::tgi")
|
||||
|
||||
|
||||
def build_hf_repo_model_entries():
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
from .config import TogetherImplConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
logger = get_logger(name=__name__, category="inference::together")
|
||||
|
||||
|
||||
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
|
|
|
|||
|
|
@ -39,12 +39,15 @@ from llama_stack.apis.inference import (
|
|||
Message,
|
||||
ModelStore,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
RerankResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
|
|
@ -85,7 +88,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .config import VLLMInferenceAdapterConfig
|
||||
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
log = get_logger(name=__name__, category="inference::vllm")
|
||||
|
||||
|
||||
def build_hf_repo_model_entries():
|
||||
|
|
@ -732,4 +735,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
):
|
||||
raise NotImplementedError("Batch chat completion is not supported for Ollama")
|
||||
raise NotImplementedError("Batch chat completion is not supported for vLLM")
|
||||
|
||||
async def rerank(
|
||||
self,
|
||||
model: str,
|
||||
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
|
||||
max_num_results: int | None = None,
|
||||
) -> RerankResponse:
|
||||
raise NotImplementedError("Reranking is not supported for vLLM")
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefa
|
|||
|
||||
from .config import NvidiaPostTrainingConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="integration")
|
||||
logger = get_logger(name=__name__, category="post_training::nvidia")
|
||||
|
||||
|
||||
def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
|||
|
||||
from .config import BedrockSafetyConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="safety")
|
||||
logger = get_logger(name=__name__, category="safety::bedrock")
|
||||
|
||||
|
||||
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_
|
|||
|
||||
from .config import NVIDIASafetyConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="safety")
|
||||
logger = get_logger(name=__name__, category="safety::nvidia")
|
||||
|
||||
|
||||
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_
|
|||
|
||||
from .config import SambaNovaSafetyConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="safety")
|
||||
logger = get_logger(name=__name__, category="safety::sambanova")
|
||||
|
||||
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
|
||||
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
|
||||
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
|
||||
|
||||
log = get_logger(name=__name__, category="vector_io")
|
||||
log = get_logger(name=__name__, category="vector_io::chroma")
|
||||
|
||||
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
|
||||
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
|
|||
|
||||
from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="vector_io")
|
||||
logger = get_logger(name=__name__, category="vector_io::milvus")
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
|
||||
from .config import PGVectorVectorIOConfig
|
||||
|
||||
log = get_logger(name=__name__, category="vector_io")
|
||||
log = get_logger(name=__name__, category="vector_io::pgvector")
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
|
||||
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
|
||||
|
||||
log = get_logger(name=__name__, category="vector_io")
|
||||
log = get_logger(name=__name__, category="vector_io::qdrant")
|
||||
CHUNK_ID_KEY = "_chunk_id"
|
||||
|
||||
# KV store prefixes for vector databases
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
|
|||
|
||||
from .config import WeaviateVectorIOConfig
|
||||
|
||||
log = get_logger(name=__name__, category="vector_io")
|
||||
log = get_logger(name=__name__, category="vector_io::weaviate")
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con
|
|||
EMBEDDING_MODELS = {}
|
||||
|
||||
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
log = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class SentenceTransformerEmbeddingMixin:
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class LiteLLMOpenAIMixin(
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference import (
|
|||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class RemoteInferenceProviderConfig(BaseModel):
|
||||
|
|
|
|||
|
|
@ -134,7 +134,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
decode_assistant_message,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class OpenAICompatCompletionChoiceDelta(BaseModel):
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ from llama_stack.apis.inference import (
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class OpenAIMixin(ABC):
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ from llama_stack.models.llama.sku_list import resolve_model
|
|||
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
|
||||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
log = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from llama_stack.providers.utils.kvstore import KVStore
|
|||
|
||||
from ..config import MongoDBKVStoreConfig
|
||||
|
||||
log = get_logger(name=__name__, category="kvstore")
|
||||
log = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class MongoDBKVStoreImpl(KVStore):
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from llama_stack.log import get_logger
|
|||
from ..api import KVStore
|
||||
from ..config import PostgresKVStoreConfig
|
||||
|
||||
log = get_logger(name=__name__, category="kvstore")
|
||||
log = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class PostgresKVStoreImpl(KVStore):
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
make_overlapped_chunks,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="memory")
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
# Constants for OpenAI vector stores
|
||||
CHUNK_MULTIPLIER = 5
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
)
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
|
||||
|
||||
log = get_logger(name=__name__, category="memory")
|
||||
log = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class ChunkForDeletion(BaseModel):
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from pydantic import BaseModel
|
|||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="scheduler")
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
# TODO: revisit the list of possible statuses when defining a more coherent
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from llama_stack.log import get_logger
|
|||
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
|
||||
from .sqlstore import SqlStoreType
|
||||
|
||||
logger = get_logger(name=__name__, category="authorized_sqlstore")
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
# Hardcoded copy of the default policy that our SQL filtering implements
|
||||
# WARNING: If default_policy() changes, this constant must be updated accordingly
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ from llama_stack.log import get_logger
|
|||
from .api import ColumnDefinition, ColumnType, SqlStore
|
||||
from .sqlstore import SqlAlchemySqlStoreConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="sqlstore")
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
TYPE_MAPPING: dict[ColumnType, Any] = {
|
||||
ColumnType.INTEGER: Integer,
|
||||
|
|
|
|||
587
llama_stack/ui/app/chat-playground/page.test.tsx
Normal file
587
llama_stack/ui/app/chat-playground/page.test.tsx
Normal file
|
|
@ -0,0 +1,587 @@
|
|||
import React from "react";
|
||||
import {
|
||||
render,
|
||||
screen,
|
||||
fireEvent,
|
||||
waitFor,
|
||||
act,
|
||||
} from "@testing-library/react";
|
||||
import "@testing-library/jest-dom";
|
||||
import ChatPlaygroundPage from "./page";
|
||||
|
||||
const mockClient = {
|
||||
agents: {
|
||||
list: jest.fn(),
|
||||
create: jest.fn(),
|
||||
retrieve: jest.fn(),
|
||||
delete: jest.fn(),
|
||||
session: {
|
||||
list: jest.fn(),
|
||||
create: jest.fn(),
|
||||
delete: jest.fn(),
|
||||
retrieve: jest.fn(),
|
||||
},
|
||||
turn: {
|
||||
create: jest.fn(),
|
||||
},
|
||||
},
|
||||
models: {
|
||||
list: jest.fn(),
|
||||
},
|
||||
toolgroups: {
|
||||
list: jest.fn(),
|
||||
},
|
||||
};
|
||||
|
||||
jest.mock("@/hooks/use-auth-client", () => ({
|
||||
useAuthClient: jest.fn(() => mockClient),
|
||||
}));
|
||||
|
||||
jest.mock("@/components/chat-playground/chat", () => ({
|
||||
Chat: jest.fn(
|
||||
({
|
||||
className,
|
||||
messages,
|
||||
handleSubmit,
|
||||
input,
|
||||
handleInputChange,
|
||||
isGenerating,
|
||||
append,
|
||||
suggestions,
|
||||
}) => (
|
||||
<div data-testid="chat-component" className={className}>
|
||||
<div data-testid="messages-count">{messages.length}</div>
|
||||
<input
|
||||
data-testid="chat-input"
|
||||
value={input}
|
||||
onChange={handleInputChange}
|
||||
disabled={isGenerating}
|
||||
/>
|
||||
<button data-testid="submit-button" onClick={handleSubmit}>
|
||||
Submit
|
||||
</button>
|
||||
{suggestions?.map((suggestion: string, index: number) => (
|
||||
<button
|
||||
key={index}
|
||||
data-testid={`suggestion-${index}`}
|
||||
onClick={() => append({ role: "user", content: suggestion })}
|
||||
>
|
||||
{suggestion}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
)
|
||||
),
|
||||
}));
|
||||
|
||||
jest.mock("@/components/chat-playground/conversations", () => ({
|
||||
SessionManager: jest.fn(({ selectedAgentId, onNewSession }) => (
|
||||
<div data-testid="session-manager">
|
||||
{selectedAgentId && (
|
||||
<>
|
||||
<div data-testid="selected-agent">{selectedAgentId}</div>
|
||||
<button data-testid="new-session-button" onClick={onNewSession}>
|
||||
New Session
|
||||
</button>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
)),
|
||||
SessionUtils: {
|
||||
saveCurrentSessionId: jest.fn(),
|
||||
loadCurrentSessionId: jest.fn(),
|
||||
loadCurrentAgentId: jest.fn(),
|
||||
saveCurrentAgentId: jest.fn(),
|
||||
clearCurrentSession: jest.fn(),
|
||||
saveSessionData: jest.fn(),
|
||||
loadSessionData: jest.fn(),
|
||||
saveAgentConfig: jest.fn(),
|
||||
loadAgentConfig: jest.fn(),
|
||||
clearAgentCache: jest.fn(),
|
||||
createDefaultSession: jest.fn(() => ({
|
||||
id: "test-session-123",
|
||||
name: "Default Session",
|
||||
messages: [],
|
||||
selectedModel: "",
|
||||
systemMessage: "You are a helpful assistant.",
|
||||
agentId: "test-agent-123",
|
||||
createdAt: Date.now(),
|
||||
updatedAt: Date.now(),
|
||||
})),
|
||||
},
|
||||
}));
|
||||
|
||||
const mockAgents = [
|
||||
{
|
||||
agent_id: "agent_123",
|
||||
agent_config: {
|
||||
name: "Test Agent",
|
||||
instructions: "You are a test assistant.",
|
||||
},
|
||||
},
|
||||
{
|
||||
agent_id: "agent_456",
|
||||
agent_config: {
|
||||
agent_name: "Another Agent",
|
||||
instructions: "You are another assistant.",
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const mockModels = [
|
||||
{
|
||||
identifier: "test-model-1",
|
||||
model_type: "llm",
|
||||
},
|
||||
{
|
||||
identifier: "test-model-2",
|
||||
model_type: "llm",
|
||||
},
|
||||
];
|
||||
|
||||
const mockToolgroups = [
|
||||
{
|
||||
identifier: "builtin::rag",
|
||||
provider_id: "test-provider",
|
||||
type: "tool_group",
|
||||
provider_resource_id: "test-resource",
|
||||
},
|
||||
];
|
||||
|
||||
describe("ChatPlaygroundPage", () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
Element.prototype.scrollIntoView = jest.fn();
|
||||
mockClient.agents.list.mockResolvedValue({ data: mockAgents });
|
||||
mockClient.models.list.mockResolvedValue(mockModels);
|
||||
mockClient.toolgroups.list.mockResolvedValue(mockToolgroups);
|
||||
mockClient.agents.session.create.mockResolvedValue({
|
||||
session_id: "new-session-123",
|
||||
});
|
||||
mockClient.agents.session.list.mockResolvedValue({ data: [] });
|
||||
mockClient.agents.session.retrieve.mockResolvedValue({
|
||||
session_id: "test-session",
|
||||
session_name: "Test Session",
|
||||
started_at: new Date().toISOString(),
|
||||
turns: [],
|
||||
}); // No turns by default
|
||||
mockClient.agents.retrieve.mockResolvedValue({
|
||||
agent_id: "test-agent",
|
||||
agent_config: {
|
||||
toolgroups: ["builtin::rag"],
|
||||
instructions: "Test instructions",
|
||||
model: "test-model",
|
||||
},
|
||||
});
|
||||
mockClient.agents.delete.mockResolvedValue(undefined);
|
||||
});
|
||||
|
||||
describe("Agent Selector Rendering", () => {
|
||||
test("shows agent selector when agents are available", async () => {
|
||||
await act(async () => {
|
||||
render(<ChatPlaygroundPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Agent Session:")).toBeInTheDocument();
|
||||
expect(screen.getAllByRole("combobox")).toHaveLength(2);
|
||||
expect(screen.getByText("+ New Agent")).toBeInTheDocument();
|
||||
expect(screen.getByText("Clear Chat")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("does not show agent selector when no agents are available", async () => {
|
||||
mockClient.agents.list.mockResolvedValue({ data: [] });
|
||||
|
||||
await act(async () => {
|
||||
render(<ChatPlaygroundPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByText("Agent Session:")).not.toBeInTheDocument();
|
||||
expect(screen.getAllByRole("combobox")).toHaveLength(1);
|
||||
expect(screen.getByText("+ New Agent")).toBeInTheDocument();
|
||||
expect(screen.queryByText("Clear Chat")).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("does not show agent selector while loading", async () => {
|
||||
mockClient.agents.list.mockImplementation(() => new Promise(() => {}));
|
||||
|
||||
await act(async () => {
|
||||
render(<ChatPlaygroundPage />);
|
||||
});
|
||||
|
||||
expect(screen.queryByText("Agent Session:")).not.toBeInTheDocument();
|
||||
expect(screen.getAllByRole("combobox")).toHaveLength(1);
|
||||
expect(screen.getByText("+ New Agent")).toBeInTheDocument();
|
||||
expect(screen.queryByText("Clear Chat")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("shows agent options in selector", async () => {
|
||||
await act(async () => {
|
||||
render(<ChatPlaygroundPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
const agentCombobox = screen.getAllByRole("combobox").find(element => {
|
||||
return (
|
||||
element.textContent?.includes("Test Agent") ||
|
||||
element.textContent?.includes("Select Agent")
|
||||
);
|
||||
});
|
||||
expect(agentCombobox).toBeDefined();
|
||||
fireEvent.click(agentCombobox!);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getAllByText("Test Agent")).toHaveLength(2);
|
||||
expect(screen.getByText("Another Agent")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("displays agent ID when no name is available", async () => {
|
||||
const agentWithoutName = {
|
||||
agent_id: "agent_789",
|
||||
agent_config: {
|
||||
instructions: "You are an agent without a name.",
|
||||
},
|
||||
};
|
||||
|
||||
mockClient.agents.list.mockResolvedValue({ data: [agentWithoutName] });
|
||||
|
||||
await act(async () => {
|
||||
render(<ChatPlaygroundPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
const agentCombobox = screen.getAllByRole("combobox").find(element => {
|
||||
return (
|
||||
element.textContent?.includes("Agent agent_78") ||
|
||||
element.textContent?.includes("Select Agent")
|
||||
);
|
||||
});
|
||||
expect(agentCombobox).toBeDefined();
|
||||
fireEvent.click(agentCombobox!);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getAllByText("Agent agent_78...")).toHaveLength(2);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Agent Creation Modal", () => {
|
||||
test("opens agent creation modal when + New Agent is clicked", async () => {
|
||||
await act(async () => {
|
||||
render(<ChatPlaygroundPage />);
|
||||
});
|
||||
|
||||
const newAgentButton = screen.getByText("+ New Agent");
|
||||
fireEvent.click(newAgentButton);
|
||||
|
||||
expect(screen.getByText("Create New Agent")).toBeInTheDocument();
|
||||
expect(screen.getByText("Agent Name (optional)")).toBeInTheDocument();
|
||||
expect(screen.getAllByText("Model")).toHaveLength(2);
|
||||
expect(screen.getByText("System Instructions")).toBeInTheDocument();
|
||||
expect(screen.getByText("Tools (optional)")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("closes modal when Cancel is clicked", async () => {
|
||||
await act(async () => {
|
||||
render(<ChatPlaygroundPage />);
|
||||
});
|
||||
|
||||
const newAgentButton = screen.getByText("+ New Agent");
|
||||
fireEvent.click(newAgentButton);
|
||||
|
||||
const cancelButton = screen.getByText("Cancel");
|
||||
fireEvent.click(cancelButton);
|
||||
|
||||
expect(screen.queryByText("Create New Agent")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("creates agent when Create Agent is clicked", async () => {
|
||||
mockClient.agents.create.mockResolvedValue({ agent_id: "new-agent-123" });
|
||||
mockClient.agents.list
|
||||
.mockResolvedValueOnce({ data: mockAgents })
|
||||
.mockResolvedValueOnce({
|
||||
data: [
|
||||
...mockAgents,
|
||||
{ agent_id: "new-agent-123", agent_config: { name: "New Agent" } },
|
||||
],
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
render(<ChatPlaygroundPage />);
|
||||
});
|
||||
|
||||
const newAgentButton = screen.getByText("+ New Agent");
|
||||
await act(async () => {
|
||||
fireEvent.click(newAgentButton);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Create New Agent")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const nameInput = screen.getByPlaceholderText("My Custom Agent");
|
||||
await act(async () => {
|
||||
fireEvent.change(nameInput, { target: { value: "Test Agent Name" } });
|
||||
});
|
||||
|
||||
const instructionsTextarea = screen.getByDisplayValue(
|
||||
"You are a helpful assistant."
|
||||
);
|
||||
await act(async () => {
|
||||
fireEvent.change(instructionsTextarea, {
|
||||
target: { value: "Custom instructions" },
|
||||
});
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
const modalModelSelectors = screen
|
||||
.getAllByRole("combobox")
|
||||
.filter(el => {
|
||||
return (
|
||||
el.textContent?.includes("Select Model") ||
|
||||
el.closest('[class*="modal"]') ||
|
||||
el.closest('[class*="card"]')
|
||||
);
|
||||
});
|
||||
expect(modalModelSelectors.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
const modalModelSelectors = screen.getAllByRole("combobox").filter(el => {
|
||||
return (
|
||||
el.textContent?.includes("Select Model") ||
|
||||
el.closest('[class*="modal"]') ||
|
||||
el.closest('[class*="card"]')
|
||||
);
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
fireEvent.click(modalModelSelectors[0]);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
const modelOptions = screen.getAllByText("test-model-1");
|
||||
expect(modelOptions.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
const modelOptions = screen.getAllByText("test-model-1");
|
||||
const dropdownOption = modelOptions.find(
|
||||
option =>
|
||||
option.closest('[role="option"]') ||
|
||||
option.id?.includes("radix") ||
|
||||
option.getAttribute("aria-selected") !== null
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
fireEvent.click(
|
||||
dropdownOption || modelOptions[modelOptions.length - 1]
|
||||
);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
const createButton = screen.getByText("Create Agent");
|
||||
expect(createButton).not.toBeDisabled();
|
||||
});
|
||||
|
||||
const createButton = screen.getByText("Create Agent");
|
||||
await act(async () => {
|
||||
fireEvent.click(createButton);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockClient.agents.create).toHaveBeenCalledWith({
|
||||
agent_config: {
|
||||
model: expect.any(String),
|
||||
instructions: "Custom instructions",
|
||||
name: "Test Agent Name",
|
||||
enable_session_persistence: true,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByText("Create New Agent")).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Agent Selection", () => {
|
||||
test("creates default session when agent is selected", async () => {
|
||||
await act(async () => {
|
||||
render(<ChatPlaygroundPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
// first agent should be auto-selected
|
||||
expect(mockClient.agents.session.create).toHaveBeenCalledWith(
|
||||
"agent_123",
|
||||
{ session_name: "Default Session" }
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
test("switches agent when different agent is selected", async () => {
|
||||
await act(async () => {
|
||||
render(<ChatPlaygroundPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
const agentCombobox = screen.getAllByRole("combobox").find(element => {
|
||||
return (
|
||||
element.textContent?.includes("Test Agent") ||
|
||||
element.textContent?.includes("Select Agent")
|
||||
);
|
||||
});
|
||||
expect(agentCombobox).toBeDefined();
|
||||
fireEvent.click(agentCombobox!);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
const anotherAgentOption = screen.getByText("Another Agent");
|
||||
fireEvent.click(anotherAgentOption);
|
||||
});
|
||||
|
||||
expect(mockClient.agents.session.create).toHaveBeenCalledWith(
|
||||
"agent_456",
|
||||
{ session_name: "Default Session" }
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Agent Deletion", () => {
|
||||
test("shows delete button when multiple agents exist", async () => {
|
||||
await act(async () => {
|
||||
render(<ChatPlaygroundPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTitle("Delete current agent")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("hides delete button when only one agent exists", async () => {
|
||||
mockClient.agents.list.mockResolvedValue({
|
||||
data: [mockAgents[0]],
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
render(<ChatPlaygroundPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.queryByTitle("Delete current agent")
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("deletes agent and switches to another when confirmed", async () => {
|
||||
global.confirm = jest.fn(() => true);
|
||||
|
||||
await act(async () => {
|
||||
render(<ChatPlaygroundPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTitle("Delete current agent")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
mockClient.agents.delete.mockResolvedValue(undefined);
|
||||
mockClient.agents.list.mockResolvedValueOnce({ data: mockAgents });
|
||||
mockClient.agents.list.mockResolvedValueOnce({
|
||||
data: [mockAgents[1]],
|
||||
});
|
||||
|
||||
const deleteButton = screen.getByTitle("Delete current agent");
|
||||
await act(async () => {
|
||||
deleteButton.click();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockClient.agents.delete).toHaveBeenCalledWith("agent_123");
|
||||
expect(global.confirm).toHaveBeenCalledWith(
|
||||
"Are you sure you want to delete this agent? This action cannot be undone and will delete all associated sessions."
|
||||
);
|
||||
});
|
||||
|
||||
(global.confirm as jest.Mock).mockRestore();
|
||||
});
|
||||
|
||||
test("does not delete agent when cancelled", async () => {
|
||||
global.confirm = jest.fn(() => false);
|
||||
|
||||
await act(async () => {
|
||||
render(<ChatPlaygroundPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTitle("Delete current agent")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const deleteButton = screen.getByTitle("Delete current agent");
|
||||
await act(async () => {
|
||||
deleteButton.click();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(global.confirm).toHaveBeenCalled();
|
||||
expect(mockClient.agents.delete).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
(global.confirm as jest.Mock).mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Error Handling", () => {
|
||||
test("handles agent loading errors gracefully", async () => {
|
||||
mockClient.agents.list.mockRejectedValue(
|
||||
new Error("Failed to load agents")
|
||||
);
|
||||
const consoleSpy = jest
|
||||
.spyOn(console, "error")
|
||||
.mockImplementation(() => {});
|
||||
|
||||
await act(async () => {
|
||||
render(<ChatPlaygroundPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
"Error fetching agents:",
|
||||
expect.any(Error)
|
||||
);
|
||||
});
|
||||
|
||||
expect(screen.getByText("+ New Agent")).toBeInTheDocument();
|
||||
|
||||
consoleSpy.mockRestore();
|
||||
});
|
||||
|
||||
test("handles model loading errors gracefully", async () => {
|
||||
mockClient.models.list.mockRejectedValue(
|
||||
new Error("Failed to load models")
|
||||
);
|
||||
const consoleSpy = jest
|
||||
.spyOn(console, "error")
|
||||
.mockImplementation(() => {});
|
||||
|
||||
await act(async () => {
|
||||
render(<ChatPlaygroundPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
"Error fetching models:",
|
||||
expect.any(Error)
|
||||
);
|
||||
});
|
||||
|
||||
consoleSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
});
|
||||
File diff suppressed because it is too large
Load diff
Binary file not shown.
|
Before Width: | Height: | Size: 25 KiB |
|
|
@ -120,3 +120,44 @@
|
|||
@apply bg-background text-foreground;
|
||||
}
|
||||
}
|
||||
|
||||
@layer utilities {
|
||||
.animate-typing-dot-1 {
|
||||
animation: typing-dot-bounce-1 0.8s cubic-bezier(0.4, 0, 0.6, 1) infinite;
|
||||
}
|
||||
|
||||
.animate-typing-dot-2 {
|
||||
animation: typing-dot-bounce-2 0.8s cubic-bezier(0.4, 0, 0.6, 1) infinite;
|
||||
}
|
||||
|
||||
.animate-typing-dot-3 {
|
||||
animation: typing-dot-bounce-3 0.8s cubic-bezier(0.4, 0, 0.6, 1) infinite;
|
||||
}
|
||||
|
||||
@keyframes typing-dot-bounce-1 {
|
||||
0%, 15%, 85%, 100% {
|
||||
transform: translateY(0);
|
||||
}
|
||||
7.5% {
|
||||
transform: translateY(-6px);
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes typing-dot-bounce-2 {
|
||||
0%, 15%, 35%, 85%, 100% {
|
||||
transform: translateY(0);
|
||||
}
|
||||
25% {
|
||||
transform: translateY(-6px);
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes typing-dot-bounce-3 {
|
||||
0%, 35%, 55%, 85%, 100% {
|
||||
transform: translateY(0);
|
||||
}
|
||||
45% {
|
||||
transform: translateY(-6px);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,9 @@ const geistMono = Geist_Mono({
|
|||
export const metadata: Metadata = {
|
||||
title: "Llama Stack",
|
||||
description: "Llama Stack UI",
|
||||
icons: {
|
||||
icon: "/favicon.ico",
|
||||
},
|
||||
};
|
||||
|
||||
import { SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";
|
||||
|
|
|
|||
|
|
@ -161,10 +161,12 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
|
|||
|
||||
const isUser = role === "user";
|
||||
|
||||
const formattedTime = createdAt?.toLocaleTimeString("en-US", {
|
||||
hour: "2-digit",
|
||||
minute: "2-digit",
|
||||
});
|
||||
const formattedTime = createdAt
|
||||
? new Date(createdAt).toLocaleTimeString("en-US", {
|
||||
hour: "2-digit",
|
||||
minute: "2-digit",
|
||||
})
|
||||
: undefined;
|
||||
|
||||
if (isUser) {
|
||||
return (
|
||||
|
|
@ -185,7 +187,7 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
|
|||
|
||||
{showTimeStamp && createdAt ? (
|
||||
<time
|
||||
dateTime={createdAt.toISOString()}
|
||||
dateTime={new Date(createdAt).toISOString()}
|
||||
className={cn(
|
||||
"mt-1 block px-1 text-xs opacity-50",
|
||||
animation !== "none" && "duration-500 animate-in fade-in-0"
|
||||
|
|
@ -220,7 +222,7 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
|
|||
|
||||
{showTimeStamp && createdAt ? (
|
||||
<time
|
||||
dateTime={createdAt.toISOString()}
|
||||
dateTime={new Date(createdAt).toISOString()}
|
||||
className={cn(
|
||||
"mt-1 block px-1 text-xs opacity-50",
|
||||
animation !== "none" && "duration-500 animate-in fade-in-0"
|
||||
|
|
@ -262,7 +264,7 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
|
|||
|
||||
{showTimeStamp && createdAt ? (
|
||||
<time
|
||||
dateTime={createdAt.toISOString()}
|
||||
dateTime={new Date(createdAt).toISOString()}
|
||||
className={cn(
|
||||
"mt-1 block px-1 text-xs opacity-50",
|
||||
animation !== "none" && "duration-500 animate-in fade-in-0"
|
||||
|
|
|
|||
345
llama_stack/ui/components/chat-playground/conversations.test.tsx
Normal file
345
llama_stack/ui/components/chat-playground/conversations.test.tsx
Normal file
|
|
@ -0,0 +1,345 @@
|
|||
import React from "react";
|
||||
import { render, screen, waitFor, act } from "@testing-library/react";
|
||||
import "@testing-library/jest-dom";
|
||||
import { Conversations, SessionUtils } from "./conversations";
|
||||
import type { Message } from "@/components/chat-playground/chat-message";
|
||||
|
||||
interface ChatSession {
|
||||
id: string;
|
||||
name: string;
|
||||
messages: Message[];
|
||||
selectedModel: string;
|
||||
systemMessage: string;
|
||||
agentId: string;
|
||||
createdAt: number;
|
||||
updatedAt: number;
|
||||
}
|
||||
|
||||
const mockOnSessionChange = jest.fn();
|
||||
const mockOnNewSession = jest.fn();
|
||||
|
||||
// Mock the auth client
|
||||
const mockClient = {
|
||||
agents: {
|
||||
session: {
|
||||
list: jest.fn(),
|
||||
create: jest.fn(),
|
||||
delete: jest.fn(),
|
||||
retrieve: jest.fn(),
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// Mock the useAuthClient hook
|
||||
jest.mock("@/hooks/use-auth-client", () => ({
|
||||
useAuthClient: jest.fn(() => mockClient),
|
||||
}));
|
||||
|
||||
// Mock additional SessionUtils methods that are now being used
|
||||
jest.mock("./conversations", () => {
|
||||
const actual = jest.requireActual("./conversations");
|
||||
return {
|
||||
...actual,
|
||||
SessionUtils: {
|
||||
...actual.SessionUtils,
|
||||
saveSessionData: jest.fn(),
|
||||
loadSessionData: jest.fn(),
|
||||
saveAgentConfig: jest.fn(),
|
||||
loadAgentConfig: jest.fn(),
|
||||
clearAgentCache: jest.fn(),
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
const localStorageMock = {
|
||||
getItem: jest.fn(),
|
||||
setItem: jest.fn(),
|
||||
removeItem: jest.fn(),
|
||||
clear: jest.fn(),
|
||||
};
|
||||
|
||||
Object.defineProperty(window, "localStorage", {
|
||||
value: localStorageMock,
|
||||
writable: true,
|
||||
});
|
||||
|
||||
// Mock crypto.randomUUID for test environment
|
||||
let uuidCounter = 0;
|
||||
Object.defineProperty(globalThis, "crypto", {
|
||||
value: {
|
||||
randomUUID: jest.fn(() => `test-uuid-${++uuidCounter}`),
|
||||
},
|
||||
writable: true,
|
||||
});
|
||||
|
||||
describe("SessionManager", () => {
|
||||
const mockSession: ChatSession = {
|
||||
id: "session_123",
|
||||
name: "Test Session",
|
||||
messages: [
|
||||
{
|
||||
id: "msg_1",
|
||||
role: "user",
|
||||
content: "Hello",
|
||||
createdAt: new Date(),
|
||||
},
|
||||
],
|
||||
selectedModel: "test-model",
|
||||
systemMessage: "You are a helpful assistant.",
|
||||
agentId: "agent_123",
|
||||
createdAt: 1710000000,
|
||||
updatedAt: 1710001000,
|
||||
};
|
||||
|
||||
const mockAgentSessions = [
|
||||
{
|
||||
session_id: "session_123",
|
||||
session_name: "Test Session",
|
||||
started_at: "2024-01-01T00:00:00Z",
|
||||
turns: [],
|
||||
},
|
||||
{
|
||||
session_id: "session_456",
|
||||
session_name: "Another Session",
|
||||
started_at: "2024-01-01T01:00:00Z",
|
||||
turns: [],
|
||||
},
|
||||
];
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
localStorageMock.getItem.mockReturnValue(null);
|
||||
localStorageMock.setItem.mockImplementation(() => {});
|
||||
mockClient.agents.session.list.mockResolvedValue({
|
||||
data: mockAgentSessions,
|
||||
});
|
||||
mockClient.agents.session.create.mockResolvedValue({
|
||||
session_id: "new_session_123",
|
||||
});
|
||||
mockClient.agents.session.delete.mockResolvedValue(undefined);
|
||||
mockClient.agents.session.retrieve.mockResolvedValue({
|
||||
session_id: "test-session",
|
||||
session_name: "Test Session",
|
||||
started_at: new Date().toISOString(),
|
||||
turns: [],
|
||||
});
|
||||
uuidCounter = 0; // Reset UUID counter for consistent test behavior
|
||||
});
|
||||
|
||||
describe("Component Rendering", () => {
|
||||
test("does not render when no agent is selected", async () => {
|
||||
const { container } = await act(async () => {
|
||||
return render(
|
||||
<Conversations
|
||||
selectedAgentId=""
|
||||
currentSession={null}
|
||||
onSessionChange={mockOnSessionChange}
|
||||
onNewSession={mockOnNewSession}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
expect(container.firstChild).toBeNull();
|
||||
});
|
||||
|
||||
test("renders loading state initially", async () => {
|
||||
mockClient.agents.session.list.mockImplementation(
|
||||
() => new Promise(() => {}) // Never resolves to simulate loading
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
render(
|
||||
<Conversations
|
||||
selectedAgentId="agent_123"
|
||||
currentSession={null}
|
||||
onSessionChange={mockOnSessionChange}
|
||||
onNewSession={mockOnNewSession}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
expect(screen.getByText("Select Session")).toBeInTheDocument();
|
||||
// When loading, the "+ New" button should be disabled
|
||||
expect(screen.getByText("+ New")).toBeDisabled();
|
||||
});
|
||||
|
||||
test("renders session selector when agent sessions are loaded", async () => {
|
||||
await act(async () => {
|
||||
render(
|
||||
<Conversations
|
||||
selectedAgentId="agent_123"
|
||||
currentSession={null}
|
||||
onSessionChange={mockOnSessionChange}
|
||||
onNewSession={mockOnNewSession}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Select Session")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("renders current session name when session is selected", async () => {
|
||||
await act(async () => {
|
||||
render(
|
||||
<Conversations
|
||||
selectedAgentId="agent_123"
|
||||
currentSession={mockSession}
|
||||
onSessionChange={mockOnSessionChange}
|
||||
onNewSession={mockOnNewSession}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Test Session")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Agent API Integration", () => {
|
||||
test("loads sessions from agent API on mount", async () => {
|
||||
await act(async () => {
|
||||
render(
|
||||
<Conversations
|
||||
selectedAgentId="agent_123"
|
||||
currentSession={mockSession}
|
||||
onSessionChange={mockOnSessionChange}
|
||||
onNewSession={mockOnNewSession}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockClient.agents.session.list).toHaveBeenCalledWith(
|
||||
"agent_123"
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
test("handles API errors gracefully", async () => {
|
||||
mockClient.agents.session.list.mockRejectedValue(new Error("API Error"));
|
||||
const consoleSpy = jest
|
||||
.spyOn(console, "error")
|
||||
.mockImplementation(() => {});
|
||||
|
||||
await act(async () => {
|
||||
render(
|
||||
<Conversations
|
||||
selectedAgentId="agent_123"
|
||||
currentSession={mockSession}
|
||||
onSessionChange={mockOnSessionChange}
|
||||
onNewSession={mockOnNewSession}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
"Error loading agent sessions:",
|
||||
expect.any(Error)
|
||||
);
|
||||
});
|
||||
|
||||
consoleSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Error Handling", () => {
|
||||
test("component renders without crashing when API is unavailable", async () => {
|
||||
mockClient.agents.session.list.mockRejectedValue(
|
||||
new Error("Network Error")
|
||||
);
|
||||
const consoleSpy = jest
|
||||
.spyOn(console, "error")
|
||||
.mockImplementation(() => {});
|
||||
|
||||
await act(async () => {
|
||||
render(
|
||||
<Conversations
|
||||
selectedAgentId="agent_123"
|
||||
currentSession={mockSession}
|
||||
onSessionChange={mockOnSessionChange}
|
||||
onNewSession={mockOnNewSession}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
// Should still render the session manager with the select trigger
|
||||
expect(screen.getByRole("combobox")).toBeInTheDocument();
|
||||
expect(screen.getByText("+ New")).toBeInTheDocument();
|
||||
consoleSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("SessionUtils", () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
localStorageMock.getItem.mockReturnValue(null);
|
||||
localStorageMock.setItem.mockImplementation(() => {});
|
||||
});
|
||||
|
||||
describe("saveCurrentSessionId", () => {
|
||||
test("saves session ID to localStorage", () => {
|
||||
SessionUtils.saveCurrentSessionId("test-session-id");
|
||||
|
||||
expect(localStorageMock.setItem).toHaveBeenCalledWith(
|
||||
"chat-playground-current-session",
|
||||
"test-session-id"
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("createDefaultSession", () => {
|
||||
test("creates default session with agent ID", () => {
|
||||
const result = SessionUtils.createDefaultSession("agent_123");
|
||||
|
||||
expect(result).toEqual(
|
||||
expect.objectContaining({
|
||||
name: "Default Session",
|
||||
messages: [],
|
||||
selectedModel: "",
|
||||
systemMessage: "You are a helpful assistant.",
|
||||
agentId: "agent_123",
|
||||
})
|
||||
);
|
||||
expect(result.id).toBeTruthy();
|
||||
expect(result.createdAt).toBeTruthy();
|
||||
expect(result.updatedAt).toBeTruthy();
|
||||
});
|
||||
|
||||
test("creates default session with inherited model", () => {
|
||||
const result = SessionUtils.createDefaultSession(
|
||||
"agent_123",
|
||||
"inherited-model"
|
||||
);
|
||||
|
||||
expect(result.selectedModel).toBe("inherited-model");
|
||||
expect(result.agentId).toBe("agent_123");
|
||||
});
|
||||
|
||||
test("creates unique session IDs", () => {
|
||||
const originalNow = Date.now;
|
||||
let mockTime = 1710005000;
|
||||
Date.now = jest.fn(() => ++mockTime);
|
||||
|
||||
const session1 = SessionUtils.createDefaultSession("agent_123");
|
||||
const session2 = SessionUtils.createDefaultSession("agent_123");
|
||||
|
||||
expect(session1.id).not.toBe(session2.id);
|
||||
|
||||
Date.now = originalNow;
|
||||
});
|
||||
|
||||
test("sets creation and update timestamps", () => {
|
||||
const result = SessionUtils.createDefaultSession("agent_123");
|
||||
|
||||
expect(result.createdAt).toBeTruthy();
|
||||
expect(result.updatedAt).toBeTruthy();
|
||||
expect(typeof result.createdAt).toBe("number");
|
||||
expect(typeof result.updatedAt).toBe("number");
|
||||
});
|
||||
});
|
||||
});
|
||||
568
llama_stack/ui/components/chat-playground/conversations.tsx
Normal file
568
llama_stack/ui/components/chat-playground/conversations.tsx
Normal file
|
|
@ -0,0 +1,568 @@
|
|||
"use client";
|
||||
|
||||
import { useState, useEffect, useCallback } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Card } from "@/components/ui/card";
|
||||
import { Trash2 } from "lucide-react";
|
||||
import type { Message } from "@/components/chat-playground/chat-message";
|
||||
import { useAuthClient } from "@/hooks/use-auth-client";
|
||||
import type {
|
||||
Session,
|
||||
SessionCreateParams,
|
||||
} from "llama-stack-client/resources/agents";
|
||||
|
||||
export interface ChatSession {
|
||||
id: string;
|
||||
name: string;
|
||||
messages: Message[];
|
||||
selectedModel: string;
|
||||
systemMessage: string;
|
||||
agentId: string;
|
||||
session?: Session;
|
||||
createdAt: number;
|
||||
updatedAt: number;
|
||||
}
|
||||
|
||||
interface SessionManagerProps {
|
||||
currentSession: ChatSession | null;
|
||||
onSessionChange: (session: ChatSession) => void;
|
||||
onNewSession: () => void;
|
||||
selectedAgentId: string;
|
||||
}
|
||||
|
||||
const CURRENT_SESSION_KEY = "chat-playground-current-session";
|
||||
|
||||
// ensures this only happens client side
|
||||
const safeLocalStorage = {
|
||||
getItem: (key: string): string | null => {
|
||||
if (typeof window === "undefined") return null;
|
||||
try {
|
||||
return localStorage.getItem(key);
|
||||
} catch (err) {
|
||||
console.error("Error accessing localStorage:", err);
|
||||
return null;
|
||||
}
|
||||
},
|
||||
setItem: (key: string, value: string): void => {
|
||||
if (typeof window === "undefined") return;
|
||||
try {
|
||||
localStorage.setItem(key, value);
|
||||
} catch (err) {
|
||||
console.error("Error writing to localStorage:", err);
|
||||
}
|
||||
},
|
||||
removeItem: (key: string): void => {
|
||||
if (typeof window === "undefined") return;
|
||||
try {
|
||||
localStorage.removeItem(key);
|
||||
} catch (err) {
|
||||
console.error("Error removing from localStorage:", err);
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
const generateSessionId = (): string => {
|
||||
return globalThis.crypto.randomUUID();
|
||||
};
|
||||
|
||||
export function Conversations({
|
||||
currentSession,
|
||||
onSessionChange,
|
||||
selectedAgentId,
|
||||
}: SessionManagerProps) {
|
||||
const [sessions, setSessions] = useState<ChatSession[]>([]);
|
||||
const [showCreateForm, setShowCreateForm] = useState(false);
|
||||
const [newSessionName, setNewSessionName] = useState("");
|
||||
const [loading, setLoading] = useState(false);
|
||||
const client = useAuthClient();
|
||||
|
||||
const loadAgentSessions = useCallback(async () => {
|
||||
if (!selectedAgentId) return;
|
||||
|
||||
setLoading(true);
|
||||
try {
|
||||
const response = await client.agents.session.list(selectedAgentId);
|
||||
console.log("Sessions response:", response);
|
||||
|
||||
if (!response.data || !Array.isArray(response.data)) {
|
||||
console.warn("Invalid sessions response, starting fresh");
|
||||
setSessions([]);
|
||||
return;
|
||||
}
|
||||
|
||||
const agentSessions: ChatSession[] = response.data
|
||||
.filter(sessionData => {
|
||||
const isValid =
|
||||
sessionData &&
|
||||
typeof sessionData === "object" &&
|
||||
sessionData.session_id &&
|
||||
sessionData.session_name;
|
||||
if (!isValid) {
|
||||
console.warn("Filtering out invalid session:", sessionData);
|
||||
}
|
||||
return isValid;
|
||||
})
|
||||
.map(sessionData => ({
|
||||
id: sessionData.session_id,
|
||||
name: sessionData.session_name,
|
||||
messages: [],
|
||||
selectedModel: currentSession?.selectedModel || "",
|
||||
systemMessage:
|
||||
currentSession?.systemMessage || "You are a helpful assistant.",
|
||||
agentId: selectedAgentId,
|
||||
session: sessionData,
|
||||
createdAt: sessionData.started_at
|
||||
? new Date(sessionData.started_at).getTime()
|
||||
: Date.now(),
|
||||
updatedAt: sessionData.started_at
|
||||
? new Date(sessionData.started_at).getTime()
|
||||
: Date.now(),
|
||||
}));
|
||||
setSessions(agentSessions);
|
||||
} catch (error) {
|
||||
console.error("Error loading agent sessions:", error);
|
||||
setSessions([]);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}, [
|
||||
selectedAgentId,
|
||||
client,
|
||||
currentSession?.selectedModel,
|
||||
currentSession?.systemMessage,
|
||||
]);
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedAgentId) {
|
||||
loadAgentSessions();
|
||||
}
|
||||
}, [selectedAgentId, loadAgentSessions]);
|
||||
|
||||
const createNewSession = async () => {
|
||||
if (!selectedAgentId) return;
|
||||
|
||||
const sessionName =
|
||||
newSessionName.trim() || `Session ${sessions.length + 1}`;
|
||||
setLoading(true);
|
||||
|
||||
try {
|
||||
const response = await client.agents.session.create(selectedAgentId, {
|
||||
session_name: sessionName,
|
||||
} as SessionCreateParams);
|
||||
|
||||
const newSession: ChatSession = {
|
||||
id: response.session_id,
|
||||
name: sessionName,
|
||||
messages: [],
|
||||
selectedModel: currentSession?.selectedModel || "",
|
||||
systemMessage:
|
||||
currentSession?.systemMessage || "You are a helpful assistant.",
|
||||
agentId: selectedAgentId,
|
||||
createdAt: Date.now(),
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
setSessions(prev => [...prev, newSession]);
|
||||
SessionUtils.saveCurrentSessionId(newSession.id, selectedAgentId);
|
||||
onSessionChange(newSession);
|
||||
|
||||
setNewSessionName("");
|
||||
setShowCreateForm(false);
|
||||
} catch (error) {
|
||||
console.error("Error creating session:", error);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const loadSessionMessages = useCallback(
|
||||
async (agentId: string, sessionId: string): Promise<Message[]> => {
|
||||
try {
|
||||
const session = await client.agents.session.retrieve(
|
||||
agentId,
|
||||
sessionId
|
||||
);
|
||||
|
||||
if (!session || !session.turns || !Array.isArray(session.turns)) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const messages: Message[] = [];
|
||||
for (const turn of session.turns) {
|
||||
// Add user messages from input_messages
|
||||
if (turn.input_messages && Array.isArray(turn.input_messages)) {
|
||||
for (const input of turn.input_messages) {
|
||||
if (input.role === "user" && input.content) {
|
||||
messages.push({
|
||||
id: `${turn.turn_id}-user-${messages.length}`,
|
||||
role: "user",
|
||||
content:
|
||||
typeof input.content === "string"
|
||||
? input.content
|
||||
: JSON.stringify(input.content),
|
||||
createdAt: new Date(turn.started_at || Date.now()),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add assistant message from output_message
|
||||
if (turn.output_message && turn.output_message.content) {
|
||||
messages.push({
|
||||
id: `${turn.turn_id}-assistant-${messages.length}`,
|
||||
role: "assistant",
|
||||
content:
|
||||
typeof turn.output_message.content === "string"
|
||||
? turn.output_message.content
|
||||
: JSON.stringify(turn.output_message.content),
|
||||
createdAt: new Date(
|
||||
turn.completed_at || turn.started_at || Date.now()
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return messages;
|
||||
} catch (error) {
|
||||
console.error("Error loading session messages:", error);
|
||||
return [];
|
||||
}
|
||||
},
|
||||
[client]
|
||||
);
|
||||
|
||||
const switchToSession = useCallback(
|
||||
async (sessionId: string) => {
|
||||
const session = sessions.find(s => s.id === sessionId);
|
||||
if (session) {
|
||||
setLoading(true);
|
||||
try {
|
||||
// Load messages for this session
|
||||
const messages = await loadSessionMessages(
|
||||
selectedAgentId,
|
||||
sessionId
|
||||
);
|
||||
const sessionWithMessages = {
|
||||
...session,
|
||||
messages,
|
||||
};
|
||||
|
||||
SessionUtils.saveCurrentSessionId(sessionId, selectedAgentId);
|
||||
onSessionChange(sessionWithMessages);
|
||||
} catch (error) {
|
||||
console.error("Error switching to session:", error);
|
||||
// Fallback to session without messages
|
||||
SessionUtils.saveCurrentSessionId(sessionId, selectedAgentId);
|
||||
onSessionChange(session);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}
|
||||
},
|
||||
[sessions, selectedAgentId, loadSessionMessages, onSessionChange]
|
||||
);
|
||||
|
||||
const deleteSession = async (sessionId: string) => {
|
||||
if (sessions.length <= 1 || !selectedAgentId) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (
|
||||
confirm(
|
||||
"Are you sure you want to delete this session? This action cannot be undone."
|
||||
)
|
||||
) {
|
||||
setLoading(true);
|
||||
try {
|
||||
await client.agents.session.delete(selectedAgentId, sessionId);
|
||||
|
||||
const updatedSessions = sessions.filter(s => s.id !== sessionId);
|
||||
setSessions(updatedSessions);
|
||||
|
||||
if (currentSession?.id === sessionId) {
|
||||
const newCurrentSession = updatedSessions[0] || null;
|
||||
if (newCurrentSession) {
|
||||
SessionUtils.saveCurrentSessionId(
|
||||
newCurrentSession.id,
|
||||
selectedAgentId
|
||||
);
|
||||
onSessionChange(newCurrentSession);
|
||||
} else {
|
||||
SessionUtils.clearCurrentSession(selectedAgentId);
|
||||
onNewSession();
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error deleting session:", error);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (currentSession) {
|
||||
setSessions(prevSessions => {
|
||||
const updatedSessions = prevSessions.map(session =>
|
||||
session.id === currentSession.id ? currentSession : session
|
||||
);
|
||||
|
||||
if (!prevSessions.find(s => s.id === currentSession.id)) {
|
||||
updatedSessions.push(currentSession);
|
||||
}
|
||||
|
||||
return updatedSessions;
|
||||
});
|
||||
}
|
||||
}, [currentSession]);
|
||||
|
||||
// Don't render if no agent is selected
|
||||
if (!selectedAgentId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="relative">
|
||||
<div className="flex items-center gap-2">
|
||||
<Select
|
||||
value={currentSession?.id || ""}
|
||||
onValueChange={switchToSession}
|
||||
>
|
||||
<SelectTrigger className="w-[200px]">
|
||||
<SelectValue placeholder="Select Session" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{sessions.map(session => (
|
||||
<SelectItem key={session.id} value={session.id}>
|
||||
{session.name}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
|
||||
<Button
|
||||
onClick={() => setShowCreateForm(true)}
|
||||
variant="outline"
|
||||
size="sm"
|
||||
disabled={loading || !selectedAgentId}
|
||||
>
|
||||
+ New
|
||||
</Button>
|
||||
|
||||
{currentSession && sessions.length > 1 && (
|
||||
<Button
|
||||
onClick={() => deleteSession(currentSession.id)}
|
||||
variant="outline"
|
||||
size="sm"
|
||||
className="text-destructive hover:text-destructive hover:bg-destructive/10"
|
||||
title="Delete current session"
|
||||
>
|
||||
<Trash2 className="h-3 w-3" />
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{showCreateForm && (
|
||||
<Card className="absolute top-full left-0 mt-2 p-4 space-y-3 w-80 z-50 bg-background border shadow-lg">
|
||||
<h3 className="text-md font-semibold">Create New Session</h3>
|
||||
|
||||
<Input
|
||||
value={newSessionName}
|
||||
onChange={e => setNewSessionName(e.target.value)}
|
||||
placeholder="Session name (optional)"
|
||||
onKeyDown={e => {
|
||||
if (e.key === "Enter") {
|
||||
createNewSession();
|
||||
} else if (e.key === "Escape") {
|
||||
setShowCreateForm(false);
|
||||
setNewSessionName("");
|
||||
}
|
||||
}}
|
||||
/>
|
||||
|
||||
<div className="flex gap-2">
|
||||
<Button
|
||||
onClick={createNewSession}
|
||||
className="flex-1"
|
||||
disabled={loading}
|
||||
>
|
||||
{loading ? "Creating..." : "Create"}
|
||||
</Button>
|
||||
<Button
|
||||
variant="outline"
|
||||
onClick={() => {
|
||||
setShowCreateForm(false);
|
||||
setNewSessionName("");
|
||||
}}
|
||||
className="flex-1"
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
</div>
|
||||
</Card>
|
||||
)}
|
||||
|
||||
{currentSession && sessions.length > 1 && (
|
||||
<div className="absolute top-full left-0 mt-1 text-xs text-gray-500 whitespace-nowrap">
|
||||
{sessions.length} sessions • Current: {currentSession.name}
|
||||
{currentSession.messages.length > 0 &&
|
||||
` • ${currentSession.messages.length} messages`}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export const SessionUtils = {
|
||||
loadCurrentSessionId: (agentId?: string): string | null => {
|
||||
const key = agentId
|
||||
? `${CURRENT_SESSION_KEY}-${agentId}`
|
||||
: CURRENT_SESSION_KEY;
|
||||
return safeLocalStorage.getItem(key);
|
||||
},
|
||||
|
||||
saveCurrentSessionId: (sessionId: string, agentId?: string) => {
|
||||
const key = agentId
|
||||
? `${CURRENT_SESSION_KEY}-${agentId}`
|
||||
: CURRENT_SESSION_KEY;
|
||||
safeLocalStorage.setItem(key, sessionId);
|
||||
},
|
||||
|
||||
createDefaultSession: (
|
||||
agentId: string,
|
||||
inheritModel?: string
|
||||
): ChatSession => ({
|
||||
id: generateSessionId(),
|
||||
name: "Default Session",
|
||||
messages: [],
|
||||
selectedModel: inheritModel || "",
|
||||
systemMessage: "You are a helpful assistant.",
|
||||
agentId,
|
||||
createdAt: Date.now(),
|
||||
updatedAt: Date.now(),
|
||||
}),
|
||||
|
||||
clearCurrentSession: (agentId?: string) => {
|
||||
const key = agentId
|
||||
? `${CURRENT_SESSION_KEY}-${agentId}`
|
||||
: CURRENT_SESSION_KEY;
|
||||
safeLocalStorage.removeItem(key);
|
||||
},
|
||||
|
||||
loadCurrentAgentId: (): string | null => {
|
||||
return safeLocalStorage.getItem("chat-playground-current-agent");
|
||||
},
|
||||
|
||||
saveCurrentAgentId: (agentId: string) => {
|
||||
safeLocalStorage.setItem("chat-playground-current-agent", agentId);
|
||||
},
|
||||
|
||||
// Comprehensive session caching
|
||||
saveSessionData: (agentId: string, sessionData: ChatSession) => {
|
||||
const key = `chat-playground-session-data-${agentId}-${sessionData.id}`;
|
||||
safeLocalStorage.setItem(
|
||||
key,
|
||||
JSON.stringify({
|
||||
...sessionData,
|
||||
cachedAt: Date.now(),
|
||||
})
|
||||
);
|
||||
},
|
||||
|
||||
loadSessionData: (agentId: string, sessionId: string): ChatSession | null => {
|
||||
const key = `chat-playground-session-data-${agentId}-${sessionId}`;
|
||||
const cached = safeLocalStorage.getItem(key);
|
||||
if (!cached) return null;
|
||||
|
||||
try {
|
||||
const data = JSON.parse(cached);
|
||||
// Check if cache is fresh (less than 1 hour old)
|
||||
const cacheAge = Date.now() - (data.cachedAt || 0);
|
||||
if (cacheAge > 60 * 60 * 1000) {
|
||||
safeLocalStorage.removeItem(key);
|
||||
return null;
|
||||
}
|
||||
|
||||
// Convert date strings back to Date objects
|
||||
return {
|
||||
...data,
|
||||
messages: data.messages.map(
|
||||
(msg: { createdAt: string; [key: string]: unknown }) => ({
|
||||
...msg,
|
||||
createdAt: new Date(msg.createdAt),
|
||||
})
|
||||
),
|
||||
};
|
||||
} catch (error) {
|
||||
console.error("Error parsing cached session data:", error);
|
||||
safeLocalStorage.removeItem(key);
|
||||
return null;
|
||||
}
|
||||
},
|
||||
|
||||
// Agent config caching
|
||||
saveAgentConfig: (
|
||||
agentId: string,
|
||||
config: {
|
||||
toolgroups?: Array<
|
||||
string | { name: string; args: Record<string, unknown> }
|
||||
>;
|
||||
[key: string]: unknown;
|
||||
}
|
||||
) => {
|
||||
const key = `chat-playground-agent-config-${agentId}`;
|
||||
safeLocalStorage.setItem(
|
||||
key,
|
||||
JSON.stringify({
|
||||
config,
|
||||
cachedAt: Date.now(),
|
||||
})
|
||||
);
|
||||
},
|
||||
|
||||
loadAgentConfig: (
|
||||
agentId: string
|
||||
): {
|
||||
toolgroups?: Array<
|
||||
string | { name: string; args: Record<string, unknown> }
|
||||
>;
|
||||
[key: string]: unknown;
|
||||
} | null => {
|
||||
const key = `chat-playground-agent-config-${agentId}`;
|
||||
const cached = safeLocalStorage.getItem(key);
|
||||
if (!cached) return null;
|
||||
|
||||
try {
|
||||
const data = JSON.parse(cached);
|
||||
// Check if cache is fresh (less than 30 minutes old)
|
||||
const cacheAge = Date.now() - (data.cachedAt || 0);
|
||||
if (cacheAge > 30 * 60 * 1000) {
|
||||
safeLocalStorage.removeItem(key);
|
||||
return null;
|
||||
}
|
||||
return data.config;
|
||||
} catch (error) {
|
||||
console.error("Error parsing cached agent config:", error);
|
||||
safeLocalStorage.removeItem(key);
|
||||
return null;
|
||||
}
|
||||
},
|
||||
|
||||
// Clear all cached data for an agent
|
||||
clearAgentCache: (agentId: string) => {
|
||||
const keys = Object.keys(localStorage).filter(
|
||||
key =>
|
||||
key.includes(`chat-playground-session-data-${agentId}`) ||
|
||||
key.includes(`chat-playground-agent-config-${agentId}`)
|
||||
);
|
||||
keys.forEach(key => safeLocalStorage.removeItem(key));
|
||||
},
|
||||
};
|
||||
|
|
@ -5,9 +5,9 @@ export function TypingIndicator() {
|
|||
<div className="justify-left flex space-x-1">
|
||||
<div className="rounded-lg bg-muted p-3">
|
||||
<div className="flex -space-x-2.5">
|
||||
<Dot className="h-5 w-5 animate-typing-dot-bounce" />
|
||||
<Dot className="h-5 w-5 animate-typing-dot-bounce [animation-delay:90ms]" />
|
||||
<Dot className="h-5 w-5 animate-typing-dot-bounce [animation-delay:180ms]" />
|
||||
<Dot className="h-5 w-5 animate-typing-dot-1" />
|
||||
<Dot className="h-5 w-5 animate-typing-dot-2" />
|
||||
<Dot className="h-5 w-5 animate-typing-dot-3" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import {
|
|||
} from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import { usePathname } from "next/navigation";
|
||||
import Image from "next/image";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
import {
|
||||
|
|
@ -110,7 +111,16 @@ export function AppSidebar() {
|
|||
return (
|
||||
<Sidebar>
|
||||
<SidebarHeader>
|
||||
<Link href="/">Llama Stack</Link>
|
||||
<Link href="/" className="flex items-center gap-2 p-2">
|
||||
<Image
|
||||
src="/logo.webp"
|
||||
alt="Llama Stack"
|
||||
width={32}
|
||||
height={32}
|
||||
className="h-8 w-8"
|
||||
/>
|
||||
<span className="font-semibold text-lg">Llama Stack</span>
|
||||
</Link>
|
||||
</SidebarHeader>
|
||||
<SidebarContent>
|
||||
<SidebarGroup>
|
||||
|
|
|
|||
BIN
llama_stack/ui/public/favicon.ico
Normal file
BIN
llama_stack/ui/public/favicon.ico
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 4.2 KiB |
BIN
llama_stack/ui/public/logo.webp
Normal file
BIN
llama_stack/ui/public/logo.webp
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 19 KiB |
|
|
@ -98,6 +98,7 @@ unit = [
|
|||
"together",
|
||||
"coverage",
|
||||
"chromadb>=1.0.15",
|
||||
"moto[s3]>=5.1.10",
|
||||
]
|
||||
# These are the core dependencies required for running integration tests. They are shared across all
|
||||
# providers. If a provider requires additional dependencies, please add them to your environment
|
||||
|
|
|
|||
|
|
@ -157,12 +157,14 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
|
|||
}
|
||||
|
||||
|
||||
def generate_provider_docs(provider_spec: Any, api_name: str) -> str:
|
||||
def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str:
|
||||
"""Generate markdown documentation for a provider."""
|
||||
provider_type = provider_spec.provider_type
|
||||
config_class = provider_spec.config_class
|
||||
|
||||
config_info = get_config_class_info(config_class)
|
||||
if "error" in config_info:
|
||||
progress.print(config_info["error"])
|
||||
|
||||
md_lines = []
|
||||
md_lines.append(f"# {provider_type}")
|
||||
|
|
@ -295,7 +297,7 @@ def process_provider_registry(progress, change_tracker: ChangedPathTracker) -> N
|
|||
filename = provider_type.replace("::", "_").replace(":", "_")
|
||||
provider_doc_file = doc_output_dir / f"{filename}.md"
|
||||
|
||||
provider_docs = generate_provider_docs(provider, api_name)
|
||||
provider_docs = generate_provider_docs(progress, provider, api_name)
|
||||
|
||||
provider_doc_file.write_text(provider_docs)
|
||||
change_tracker.add_paths(provider_doc_file)
|
||||
|
|
|
|||
17
scripts/run-ui-linter.sh
Executable file
17
scripts/run-ui-linter.sh
Executable file
|
|
@ -0,0 +1,17 @@
|
|||
#!/bin/bash
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
set -e
|
||||
cd llama_stack/ui
|
||||
|
||||
if [ ! -d node_modules ] || [ ! -x node_modules/.bin/prettier ] || [ ! -x node_modules/.bin/eslint ]; then
|
||||
echo "UI dependencies not installed, skipping prettier/linter check"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
npm run format
|
||||
npm run lint
|
||||
251
tests/unit/providers/files/test_s3_files.py
Normal file
251
tests/unit/providers/files/test_s3_files.py
Normal file
|
|
@ -0,0 +1,251 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import boto3
|
||||
import pytest
|
||||
from botocore.exceptions import ClientError
|
||||
from moto import mock_aws
|
||||
|
||||
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||
from llama_stack.apis.files import OpenAIFilePurpose
|
||||
from llama_stack.providers.remote.files.s3 import (
|
||||
S3FilesImplConfig,
|
||||
get_adapter_impl,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
|
||||
|
||||
class MockUploadFile:
|
||||
def __init__(self, content: bytes, filename: str, content_type: str = "text/plain"):
|
||||
self.content = content
|
||||
self.filename = filename
|
||||
self.content_type = content_type
|
||||
|
||||
async def read(self):
|
||||
return self.content
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def s3_config(tmp_path):
|
||||
db_path = tmp_path / "s3_files_metadata.db"
|
||||
|
||||
return S3FilesImplConfig(
|
||||
bucket_name="test-bucket",
|
||||
region="not-a-region",
|
||||
auto_create_bucket=True,
|
||||
metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def s3_client():
|
||||
"""Create a mocked S3 client for testing."""
|
||||
# we use `with mock_aws()` because @mock_aws decorator does not support being a generator
|
||||
with mock_aws():
|
||||
# must yield or the mock will be reset before it is used
|
||||
yield boto3.client("s3")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def s3_provider(s3_config, s3_client):
|
||||
"""Create an S3 files provider with mocked S3 for testing."""
|
||||
provider = await get_adapter_impl(s3_config, {})
|
||||
yield provider
|
||||
await provider.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_text_file():
|
||||
content = b"Hello, this is a test file for the S3 Files API!"
|
||||
return MockUploadFile(content, "sample_text_file.txt")
|
||||
|
||||
|
||||
class TestS3FilesImpl:
|
||||
"""Test suite for S3 Files implementation."""
|
||||
|
||||
async def test_upload_file(self, s3_provider, sample_text_file, s3_client, s3_config):
|
||||
"""Test successful file upload."""
|
||||
sample_text_file.filename = "test_upload_file"
|
||||
result = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
assert result.filename == sample_text_file.filename
|
||||
assert result.purpose == OpenAIFilePurpose.ASSISTANTS
|
||||
assert result.bytes == len(sample_text_file.content)
|
||||
assert result.id.startswith("file-")
|
||||
|
||||
# Verify file exists in S3 backend
|
||||
response = s3_client.head_object(Bucket=s3_config.bucket_name, Key=result.id)
|
||||
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
async def test_list_files_empty(self, s3_provider):
|
||||
"""Test listing files when no files exist."""
|
||||
result = await s3_provider.openai_list_files()
|
||||
|
||||
assert len(result.data) == 0
|
||||
assert not result.has_more
|
||||
assert result.first_id == ""
|
||||
assert result.last_id == ""
|
||||
|
||||
async def test_retrieve_file(self, s3_provider, sample_text_file):
|
||||
"""Test retrieving file metadata."""
|
||||
sample_text_file.filename = "test_retrieve_file"
|
||||
uploaded = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
retrieved = await s3_provider.openai_retrieve_file(uploaded.id)
|
||||
|
||||
assert retrieved.id == uploaded.id
|
||||
assert retrieved.filename == uploaded.filename
|
||||
assert retrieved.purpose == uploaded.purpose
|
||||
assert retrieved.bytes == uploaded.bytes
|
||||
|
||||
async def test_retrieve_file_content(self, s3_provider, sample_text_file):
|
||||
"""Test retrieving file content."""
|
||||
sample_text_file.filename = "test_retrieve_file_content"
|
||||
uploaded = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
response = await s3_provider.openai_retrieve_file_content(uploaded.id)
|
||||
|
||||
assert response.body == sample_text_file.content
|
||||
assert response.headers["Content-Disposition"] == f'attachment; filename="{sample_text_file.filename}"'
|
||||
|
||||
async def test_delete_file(self, s3_provider, sample_text_file, s3_config, s3_client):
|
||||
"""Test deleting a file."""
|
||||
sample_text_file.filename = "test_delete_file"
|
||||
uploaded = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
delete_response = await s3_provider.openai_delete_file(uploaded.id)
|
||||
|
||||
assert delete_response.id == uploaded.id
|
||||
assert delete_response.deleted is True
|
||||
|
||||
with pytest.raises(ResourceNotFoundError, match="not found"):
|
||||
await s3_provider.openai_retrieve_file(uploaded.id)
|
||||
|
||||
# Verify file is gone from S3 backend
|
||||
with pytest.raises(ClientError) as exc_info:
|
||||
s3_client.head_object(Bucket=s3_config.bucket_name, Key=uploaded.id)
|
||||
assert exc_info.value.response["Error"]["Code"] == "404"
|
||||
|
||||
async def test_list_files(self, s3_provider, sample_text_file):
|
||||
"""Test listing files after uploading some."""
|
||||
sample_text_file.filename = "test_list_files_with_content_file1"
|
||||
file1 = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
file2_content = MockUploadFile(b"Second file content", "test_list_files_with_content_file2")
|
||||
file2 = await s3_provider.openai_upload_file(
|
||||
file=file2_content,
|
||||
purpose=OpenAIFilePurpose.BATCH,
|
||||
)
|
||||
|
||||
result = await s3_provider.openai_list_files()
|
||||
|
||||
assert len(result.data) == 2
|
||||
file_ids = {f.id for f in result.data}
|
||||
assert file1.id in file_ids
|
||||
assert file2.id in file_ids
|
||||
|
||||
async def test_list_files_with_purpose_filter(self, s3_provider, sample_text_file):
|
||||
"""Test listing files with purpose filter."""
|
||||
sample_text_file.filename = "test_list_files_with_purpose_filter_file1"
|
||||
file1 = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
file2_content = MockUploadFile(b"Batch file content", "test_list_files_with_purpose_filter_file2")
|
||||
await s3_provider.openai_upload_file(
|
||||
file=file2_content,
|
||||
purpose=OpenAIFilePurpose.BATCH,
|
||||
)
|
||||
|
||||
result = await s3_provider.openai_list_files(purpose=OpenAIFilePurpose.ASSISTANTS)
|
||||
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].id == file1.id
|
||||
assert result.data[0].purpose == OpenAIFilePurpose.ASSISTANTS
|
||||
|
||||
async def test_nonexistent_file_retrieval(self, s3_provider):
|
||||
"""Test retrieving a non-existent file raises error."""
|
||||
with pytest.raises(ResourceNotFoundError, match="not found"):
|
||||
await s3_provider.openai_retrieve_file("file-nonexistent")
|
||||
|
||||
async def test_nonexistent_file_content_retrieval(self, s3_provider):
|
||||
"""Test retrieving content of a non-existent file raises error."""
|
||||
with pytest.raises(ResourceNotFoundError, match="not found"):
|
||||
await s3_provider.openai_retrieve_file_content("file-nonexistent")
|
||||
|
||||
async def test_nonexistent_file_deletion(self, s3_provider):
|
||||
"""Test deleting a non-existent file raises error."""
|
||||
with pytest.raises(ResourceNotFoundError, match="not found"):
|
||||
await s3_provider.openai_delete_file("file-nonexistent")
|
||||
|
||||
async def test_upload_file_without_filename(self, s3_provider, sample_text_file):
|
||||
"""Test uploading a file without a filename uses the fallback."""
|
||||
del sample_text_file.filename
|
||||
result = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
assert result.purpose == OpenAIFilePurpose.ASSISTANTS
|
||||
assert result.bytes == len(sample_text_file.content)
|
||||
|
||||
retrieved = await s3_provider.openai_retrieve_file(result.id)
|
||||
assert retrieved.filename == result.filename
|
||||
|
||||
async def test_file_operations_when_s3_object_deleted(self, s3_provider, sample_text_file, s3_config, s3_client):
|
||||
"""Test file operations when S3 object is deleted but metadata exists (negative test)."""
|
||||
sample_text_file.filename = "test_orphaned_metadata"
|
||||
uploaded = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
# Directly delete the S3 object from the backend
|
||||
s3_client.delete_object(Bucket=s3_config.bucket_name, Key=uploaded.id)
|
||||
|
||||
with pytest.raises(ResourceNotFoundError, match="not found") as exc_info:
|
||||
await s3_provider.openai_retrieve_file_content(uploaded.id)
|
||||
assert uploaded.id in str(exc_info).lower()
|
||||
|
||||
listed_files = await s3_provider.openai_list_files()
|
||||
assert uploaded.id not in [file.id for file in listed_files.data]
|
||||
|
||||
async def test_upload_file_s3_put_object_failure(self, s3_provider, sample_text_file, s3_config, s3_client):
|
||||
"""Test that put_object failure results in exception and no orphaned metadata."""
|
||||
sample_text_file.filename = "test_s3_put_object_failure"
|
||||
|
||||
def failing_put_object(*args, **kwargs):
|
||||
raise ClientError(
|
||||
error_response={"Error": {"Code": "SolarRadiation", "Message": "Bloop"}}, operation_name="PutObject"
|
||||
)
|
||||
|
||||
with patch.object(s3_provider.client, "put_object", side_effect=failing_put_object):
|
||||
with pytest.raises(RuntimeError, match="Failed to upload file to S3"):
|
||||
await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
files_list = await s3_provider.openai_list_files()
|
||||
assert len(files_list.data) == 0, "No file metadata should remain after failed upload"
|
||||
109
uv.lock
generated
109
uv.lock
generated
|
|
@ -347,6 +347,34 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/ed/4d/1392562369b1139e741b30d624f09fe7091d17dd5579fae5732f044b12bb/blobfile-3.0.0-py3-none-any.whl", hash = "sha256:48ecc3307e622804bd8fe13bf6f40e6463c4439eba7a1f9ad49fd78aa63cc658", size = 75413, upload-time = "2024-08-27T00:02:51.518Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "boto3"
|
||||
version = "1.40.12"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "botocore" },
|
||||
{ name = "jmespath" },
|
||||
{ name = "s3transfer" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/41/19/2c4d140a7f99b5903b21b9ccd7253c71f147c346c3c632b2117444cf2d65/boto3-1.40.12.tar.gz", hash = "sha256:c6b32aee193fbd2eb84696d2b5b2410dcda9fb4a385e1926cff908377d222247", size = 111959, upload-time = "2025-08-18T19:30:23.827Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/16/6e/5a9dcf38ad87838fb99742c4a3ab1b7507ad3a02c8c27a9ccda7a0bb5709/boto3-1.40.12-py3-none-any.whl", hash = "sha256:3c3d6731390b5b11f5e489d5d9daa57f0c3e171efb63ac8f47203df9c71812b3", size = 140075, upload-time = "2025-08-18T19:30:22.494Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "botocore"
|
||||
version = "1.40.12"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "jmespath" },
|
||||
{ name = "python-dateutil" },
|
||||
{ name = "urllib3" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/7d/b2/7933590fc5bca1980801b71e09db1a95581afff177cbf3c8a031d922885c/botocore-1.40.12.tar.gz", hash = "sha256:c6560578e799b47b762b7e555bd9c5dd5c29c5d23bd778a8a72e98c979b3c727", size = 14349930, upload-time = "2025-08-18T19:30:13.794Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/b6/65fd6e718c9538ba1462c9b71e9262bc723202ff203fe64ff66ff676d823/botocore-1.40.12-py3-none-any.whl", hash = "sha256:84e96004a8b426c5508f6b5600312d6271364269466a3a957dc377ad8effc438", size = 14018004, upload-time = "2025-08-18T19:30:09.054Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "braintrust-core"
|
||||
version = "0.0.59"
|
||||
|
|
@ -1580,6 +1608,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/b3/4a/4175a563579e884192ba6e81725fc0448b042024419be8d83aa8a80a3f44/jiter-0.10.0-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3aa96f2abba33dc77f79b4cf791840230375f9534e5fac927ccceb58c5e604a5", size = 354213, upload-time = "2025-05-18T19:04:41.894Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jmespath"
|
||||
version = "1.0.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/00/2a/e867e8531cf3e36b41201936b7fa7ba7b5702dbef42922193f05c8976cd6/jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe", size = 25843, upload-time = "2022-06-17T18:00:12.224Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/31/b4/b9b800c45527aadd64d5b442f9b932b00648617eb5d63d2c7a6587b7cafc/jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980", size = 20256, upload-time = "2022-06-17T18:00:10.251Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jsonschema"
|
||||
version = "4.25.0"
|
||||
|
|
@ -1820,6 +1857,7 @@ unit = [
|
|||
{ name = "litellm" },
|
||||
{ name = "mcp" },
|
||||
{ name = "milvus-lite" },
|
||||
{ name = "moto", extra = ["s3"] },
|
||||
{ name = "ollama" },
|
||||
{ name = "openai" },
|
||||
{ name = "pymilvus" },
|
||||
|
|
@ -1937,6 +1975,7 @@ unit = [
|
|||
{ name = "litellm" },
|
||||
{ name = "mcp" },
|
||||
{ name = "milvus-lite", specifier = ">=2.5.0" },
|
||||
{ name = "moto", extras = ["s3"], specifier = ">=5.1.10" },
|
||||
{ name = "ollama" },
|
||||
{ name = "openai" },
|
||||
{ name = "pymilvus", specifier = ">=2.5.12" },
|
||||
|
|
@ -2224,6 +2263,32 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/16/71/4ad9a42f2772793a03cb698f0fc42499f04e6e8d2560ba2f7da0fb059a8e/mmh3-5.1.0-cp313-cp313-win_arm64.whl", hash = "sha256:b22fe2e54be81f6c07dcb36b96fa250fb72effe08aa52fbb83eade6e1e2d5fd7", size = 38890, upload-time = "2025-01-25T08:39:25.28Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "moto"
|
||||
version = "5.1.10"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "boto3" },
|
||||
{ name = "botocore" },
|
||||
{ name = "cryptography" },
|
||||
{ name = "jinja2" },
|
||||
{ name = "python-dateutil" },
|
||||
{ name = "requests" },
|
||||
{ name = "responses" },
|
||||
{ name = "werkzeug" },
|
||||
{ name = "xmltodict" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c4/72/9bc9b4917b816f5a82fc8f0fbd477c2a669d35a7d7941ae15a5411e266d6/moto-5.1.10.tar.gz", hash = "sha256:d6bdc8f82a1e503502927cc0a3da22014f836094d0bf399bb0f695754ae6c7a6", size = 7087004, upload-time = "2025-08-11T20:59:45.542Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c4/37/9b9cb5597eecc2ebfde2f65a8265f3669f6724ebe82bf9b155a3421039f8/moto-5.1.10-py3-none-any.whl", hash = "sha256:9ec1a21a924f97470af225b2bfa854fe46c1ad30fb44655eba458206dedf28b5", size = 5246859, upload-time = "2025-08-11T20:59:43.22Z" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
s3 = [
|
||||
{ name = "py-partiql-parser" },
|
||||
{ name = "pyyaml" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mpmath"
|
||||
version = "1.3.0"
|
||||
|
|
@ -3068,6 +3133,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl", hash = "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0", size = 11842, upload-time = "2024-07-21T12:58:20.04Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "py-partiql-parser"
|
||||
version = "0.6.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/58/a1/0a2867e48b232b4f82c4929ef7135f2a5d72c3886b957dccf63c70aa2fcb/py_partiql_parser-0.6.1.tar.gz", hash = "sha256:8583ff2a0e15560ef3bc3df109a7714d17f87d81d33e8c38b7fed4e58a63215d", size = 17120, upload-time = "2024-12-25T22:06:41.327Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/97/84/0e410c20bbe9a504fc56e97908f13261c2b313d16cbb3b738556166f044a/py_partiql_parser-0.6.1-py2.py3-none-any.whl", hash = "sha256:ff6a48067bff23c37e9044021bf1d949c83e195490c17e020715e927fe5b2456", size = 23520, upload-time = "2024-12-25T22:06:39.106Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyaml"
|
||||
version = "25.7.0"
|
||||
|
|
@ -3788,6 +3862,20 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/3b/5d/63d4ae3b9daea098d5d6f5da83984853c1bbacd5dc826764b249fe119d24/requests_oauthlib-2.0.0-py2.py3-none-any.whl", hash = "sha256:7dd8a5c40426b779b0868c404bdef9768deccf22749cde15852df527e6269b36", size = 24179, upload-time = "2024-03-22T20:32:28.055Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "responses"
|
||||
version = "0.25.8"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pyyaml" },
|
||||
{ name = "requests" },
|
||||
{ name = "urllib3" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0e/95/89c054ad70bfef6da605338b009b2e283485835351a9935c7bfbfaca7ffc/responses-0.25.8.tar.gz", hash = "sha256:9374d047a575c8f781b94454db5cab590b6029505f488d12899ddb10a4af1cf4", size = 79320, upload-time = "2025-08-08T19:01:46.709Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/1c/4c/cc276ce57e572c102d9542d383b2cfd551276581dc60004cb94fe8774c11/responses-0.25.8-py3-none-any.whl", hash = "sha256:0c710af92def29c8352ceadff0c3fe340ace27cf5af1bbe46fb71275bcd2831c", size = 34769, upload-time = "2025-08-08T19:01:45.018Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rich"
|
||||
version = "14.1.0"
|
||||
|
|
@ -3961,6 +4049,18 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/00/db/c376b0661c24cf770cb8815268190668ec1330eba8374a126ceef8c72d55/ruff-0.12.5-py3-none-win_arm64.whl", hash = "sha256:48cdbfc633de2c5c37d9f090ba3b352d1576b0015bfc3bc98eaf230275b7e805", size = 11951564, upload-time = "2025-07-24T13:26:34.994Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "s3transfer"
|
||||
version = "0.13.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "botocore" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/6d/05/d52bf1e65044b4e5e27d4e63e8d1579dbdec54fce685908ae09bc3720030/s3transfer-0.13.1.tar.gz", hash = "sha256:c3fdba22ba1bd367922f27ec8032d6a1cf5f10c934fb5d68cf60fd5a23d936cf", size = 150589, upload-time = "2025-07-18T19:22:42.31Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/6d/4f/d073e09df851cfa251ef7840007d04db3293a0482ce607d2b993926089be/s3transfer-0.13.1-py3-none-any.whl", hash = "sha256:a981aa7429be23fe6dfc13e80e4020057cbab622b08c0315288758d67cabc724", size = 85308, upload-time = "2025-07-18T19:22:40.947Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "safetensors"
|
||||
version = "0.5.3"
|
||||
|
|
@ -5107,6 +5207,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/78/58/e860788190eba3bcce367f74d29c4675466ce8dddfba85f7827588416f01/wsproto-1.2.0-py3-none-any.whl", hash = "sha256:b9acddd652b585d75b20477888c56642fdade28bdfd3579aa24a4d2c037dd736", size = 24226, upload-time = "2022-08-23T19:58:19.96Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "xmltodict"
|
||||
version = "0.14.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/50/05/51dcca9a9bf5e1bce52582683ce50980bcadbc4fa5143b9f2b19ab99958f/xmltodict-0.14.2.tar.gz", hash = "sha256:201e7c28bb210e374999d1dde6382923ab0ed1a8a5faeece48ab525b7810a553", size = 51942, upload-time = "2024-10-16T06:10:29.683Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d6/45/fc303eb433e8a2a271739c98e953728422fa61a3c1f36077a49e395c972e/xmltodict-0.14.2-py2.py3-none-any.whl", hash = "sha256:20cc7d723ed729276e808f26fb6b3599f786cbc37e06c65e192ba77c40f20aac", size = 9981, upload-time = "2024-10-16T06:10:27.649Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "xxhash"
|
||||
version = "3.5.0"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue