Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-17 18:02:40 +00:00)

commit 0b3218ed81
Merge remote-tracking branch 'origin/main' into ai

94 changed files with 4092 additions and 235 deletions
22 .github/workflows/pre-commit.yml (vendored)

@@ -36,20 +36,16 @@ jobs:
             **/requirements*.txt
           .pre-commit-config.yaml

-      # npm ci may fail -
-      # npm error `npm ci` can only install packages when your package.json and package-lock.json or npm-shrinkwrap.json are in sync. Please update your lock file with `npm install` before continuing.
-      # npm error Invalid: lock file's llama-stack-client@0.2.17 does not satisfy llama-stack-client@0.2.18
-      # - name: Set up Node.js
-      #   uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0
-      #   with:
-      #     node-version: '20'
-      #     cache: 'npm'
-      #     cache-dependency-path: 'llama_stack/ui/'
-      # - name: Install npm dependencies
-      #   run: npm ci
-      #   working-directory: llama_stack/ui
+      - name: Set up Node.js
+        uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0
+        with:
+          node-version: '20'
+          cache: 'npm'
+          cache-dependency-path: 'llama_stack/ui/'
+      - name: Install npm dependencies
+        run: npm ci
+        working-directory: llama_stack/ui

       - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
         continue-on-error: true
.pre-commit-config.yaml

@@ -146,31 +146,13 @@ repos:
         pass_filenames: false
         require_serial: true
         files: ^.github/workflows/.*$

-      # ui-prettier and ui-eslint are disabled until we can avoid `npm ci`, which is slow and may fail -
-      # npm error `npm ci` can only install packages when your package.json and package-lock.json or npm-shrinkwrap.json are in sync. Please update your lock file with `npm install` before continuing.
-      # npm error Invalid: lock file's llama-stack-client@0.2.17 does not satisfy llama-stack-client@0.2.18
-      # and until we have infra for installing prettier and next via npm -
-      # Lint UI code with ESLint.....................................................Failed
-      # - hook id: ui-eslint
-      # - exit code: 127
-      # > ui@0.1.0 lint
-      # > next lint --fix --quiet
-      # sh: line 1: next: command not found
-      #
-      # - id: ui-prettier
-      #   name: Format UI code with Prettier
-      #   entry: bash -c 'cd llama_stack/ui && npm ci && npm run format'
-      #   language: system
-      #   files: ^llama_stack/ui/.*\.(ts|tsx)$
-      #   pass_filenames: false
-      #   require_serial: true
-      # - id: ui-eslint
-      #   name: Lint UI code with ESLint
-      #   entry: bash -c 'cd llama_stack/ui && npm run lint -- --fix --quiet'
-      #   language: system
-      #   files: ^llama_stack/ui/.*\.(ts|tsx)$
-      #   pass_filenames: false
-      #   require_serial: true
+      - id: ui-linter
+        name: Format & Lint UI
+        entry: bash ./scripts/run-ui-linter.sh
+        language: system
+        files: ^llama_stack/ui/.*\.(ts|tsx)$
+        pass_filenames: false
+        require_serial: true

       - id: check-log-usage
         name: Ensure 'llama_stack.log' usage for logging
132 docs/_static/llama-stack-spec.html (vendored)

@@ -4605,6 +4605,49 @@
                 }
             }
         },
+        "/v1/inference/rerank": {
+            "post": {
+                "responses": {
+                    "200": {
+                        "description": "RerankResponse with indices sorted by relevance score (descending).",
+                        "content": {
+                            "application/json": {
+                                "schema": {
+                                    "$ref": "#/components/schemas/RerankResponse"
+                                }
+                            }
+                        }
+                    },
+                    "400": {
+                        "$ref": "#/components/responses/BadRequest400"
+                    },
+                    "429": {
+                        "$ref": "#/components/responses/TooManyRequests429"
+                    },
+                    "500": {
+                        "$ref": "#/components/responses/InternalServerError500"
+                    },
+                    "default": {
+                        "$ref": "#/components/responses/DefaultError"
+                    }
+                },
+                "tags": [
+                    "Inference"
+                ],
+                "description": "Rerank a list of documents based on their relevance to a query.",
+                "parameters": [],
+                "requestBody": {
+                    "content": {
+                        "application/json": {
+                            "schema": {
+                                "$ref": "#/components/schemas/RerankRequest"
+                            }
+                        }
+                    },
+                    "required": true
+                }
+            }
+        },
         "/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": {
             "post": {
                 "responses": {
@@ -16587,6 +16630,95 @@
             ],
             "title": "RegisterVectorDbRequest"
         },
+        "RerankRequest": {
+            "type": "object",
+            "properties": {
+                "model": {
+                    "type": "string",
+                    "description": "The identifier of the reranking model to use."
+                },
+                "query": {
+                    "oneOf": [
+                        {
+                            "type": "string"
+                        },
+                        {
+                            "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
+                        },
+                        {
+                            "$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
+                        }
+                    ],
+                    "description": "The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length."
+                },
+                "items": {
+                    "type": "array",
+                    "items": {
+                        "oneOf": [
+                            {
+                                "type": "string"
+                            },
+                            {
+                                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
+                            },
+                            {
+                                "$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
+                            }
+                        ]
+                    },
+                    "description": "List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length."
+                },
+                "max_num_results": {
+                    "type": "integer",
+                    "description": "(Optional) Maximum number of results to return. Default: returns all."
+                }
+            },
+            "additionalProperties": false,
+            "required": [
+                "model",
+                "query",
+                "items"
+            ],
+            "title": "RerankRequest"
+        },
+        "RerankData": {
+            "type": "object",
+            "properties": {
+                "index": {
+                    "type": "integer",
+                    "description": "The original index of the document in the input list"
+                },
+                "relevance_score": {
+                    "type": "number",
+                    "description": "The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance."
+                }
+            },
+            "additionalProperties": false,
+            "required": [
+                "index",
+                "relevance_score"
+            ],
+            "title": "RerankData",
+            "description": "A single rerank result from a reranking response."
+        },
+        "RerankResponse": {
+            "type": "object",
+            "properties": {
+                "data": {
+                    "type": "array",
+                    "items": {
+                        "$ref": "#/components/schemas/RerankData"
+                    },
+                    "description": "List of rerank result objects, sorted by relevance score (descending)"
+                }
+            },
+            "additionalProperties": false,
+            "required": [
+                "data"
+            ],
+            "title": "RerankResponse",
+            "description": "Response from a reranking request."
+        },
         "ResumeAgentTurnRequest": {
             "type": "object",
             "properties": {
101 docs/_static/llama-stack-spec.yaml (vendored)

@@ -3264,6 +3264,37 @@ paths:
           schema:
             $ref: '#/components/schemas/QueryTracesRequest'
       required: true
+  /v1/inference/rerank:
+    post:
+      responses:
+        '200':
+          description: >-
+            RerankResponse with indices sorted by relevance score (descending).
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/RerankResponse'
+        '400':
+          $ref: '#/components/responses/BadRequest400'
+        '429':
+          $ref: >-
+            #/components/responses/TooManyRequests429
+        '500':
+          $ref: >-
+            #/components/responses/InternalServerError500
+        default:
+          $ref: '#/components/responses/DefaultError'
+      tags:
+        - Inference
+      description: >-
+        Rerank a list of documents based on their relevance to a query.
+      parameters: []
+      requestBody:
+        content:
+          application/json:
+            schema:
+              $ref: '#/components/schemas/RerankRequest'
+        required: true
   /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume:
     post:
       responses:

@@ -12337,6 +12368,76 @@ components:
       - vector_db_id
       - embedding_model
     title: RegisterVectorDbRequest
+  RerankRequest:
+    type: object
+    properties:
+      model:
+        type: string
+        description: >-
+          The identifier of the reranking model to use.
+      query:
+        oneOf:
+          - type: string
+          - $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
+          - $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
+        description: >-
+          The search query to rank items against. Can be a string, text content
+          part, or image content part. The input must not exceed the model's max
+          input token length.
+      items:
+        type: array
+        items:
+          oneOf:
+            - type: string
+            - $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
+            - $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
+        description: >-
+          List of items to rerank. Each item can be a string, text content part,
+          or image content part. Each input must not exceed the model's max input
+          token length.
+      max_num_results:
+        type: integer
+        description: >-
+          (Optional) Maximum number of results to return. Default: returns all.
+    additionalProperties: false
+    required:
+      - model
+      - query
+      - items
+    title: RerankRequest
+  RerankData:
+    type: object
+    properties:
+      index:
+        type: integer
+        description: >-
+          The original index of the document in the input list
+      relevance_score:
+        type: number
+        description: >-
+          The relevance score from the model output. Values are inverted when applicable
+          so that higher scores indicate greater relevance.
+    additionalProperties: false
+    required:
+      - index
+      - relevance_score
+    title: RerankData
+    description: >-
+      A single rerank result from a reranking response.
+  RerankResponse:
+    type: object
+    properties:
+      data:
+        type: array
+        items:
+          $ref: '#/components/schemas/RerankData'
+        description: >-
+          List of rerank result objects, sorted by relevance score (descending)
+    additionalProperties: false
+    required:
+      - data
+    title: RerankResponse
+    description: Response from a reranking request.
   ResumeAgentTurnRequest:
     type: object
    properties:
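For reference, a minimal sketch of exercising the new route over HTTP, assuming a stack server on localhost:8321 (the common default port) and a hypothetical reranking model registered as "my-reranker":

```python
# Sketch only: the server URL and model id are assumptions, not part of this change.
import requests

payload = {
    "model": "my-reranker",                     # identifier of the reranking model
    "query": "What is the capital of France?",  # plain-string form of the query
    "items": [
        "Paris is the capital of France.",
        "Berlin is the capital of Germany.",
    ],
    "max_num_results": 2,  # optional; omit to return all results
}

resp = requests.post("http://localhost:8321/v1/inference/rerank", json=payload)
resp.raise_for_status()

# RerankResponse.data is sorted by relevance_score, descending
for result in resp.json()["data"]:
    print(result["index"], result["relevance_score"])
```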
docs/source/providers/files/index.md

@@ -10,4 +10,5 @@ This section contains documentation for all available providers for the **files**
 :maxdepth: 1

 inline_localfs
+remote_s3
 ```
33 docs/source/providers/files/remote_s3.md (new file)

@@ -0,0 +1,33 @@
# remote::s3

## Description

AWS S3-based file storage provider for scalable cloud file management with metadata persistence.

## Configuration

| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `bucket_name` | `<class 'str'>` | No | | S3 bucket name to store files |
| `region` | `<class 'str'>` | No | us-east-1 | AWS region where the bucket is located |
| `aws_access_key_id` | `str \| None` | No | | AWS access key ID (optional if using IAM roles) |
| `aws_secret_access_key` | `str \| None` | No | | AWS secret access key (optional if using IAM roles) |
| `endpoint_url` | `str \| None` | No | | Custom S3 endpoint URL (for MinIO, LocalStack, etc.) |
| `auto_create_bucket` | `<class 'bool'>` | No | False | Automatically create the S3 bucket if it doesn't exist |
| `metadata_store` | `utils.sqlstore.sqlstore.SqliteSqlStoreConfig \| utils.sqlstore.sqlstore.PostgresSqlStoreConfig` | No | sqlite | SQL store configuration for file metadata |

## Sample Configuration

```yaml
bucket_name: ${env.S3_BUCKET_NAME}
region: ${env.AWS_REGION:=us-east-1}
aws_access_key_id: ${env.AWS_ACCESS_KEY_ID:=}
aws_secret_access_key: ${env.AWS_SECRET_ACCESS_KEY:=}
endpoint_url: ${env.S3_ENDPOINT_URL:=}
auto_create_bucket: ${env.S3_AUTO_CREATE_BUCKET:=false}
metadata_store:
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/s3_files_metadata.db
```
llama_stack/apis/inference/inference.py

@@ -473,6 +473,28 @@ class EmbeddingsResponse(BaseModel):
     embeddings: list[list[float]]


+@json_schema_type
+class RerankData(BaseModel):
+    """A single rerank result from a reranking response.
+
+    :param index: The original index of the document in the input list
+    :param relevance_score: The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance.
+    """
+
+    index: int
+    relevance_score: float
+
+
+@json_schema_type
+class RerankResponse(BaseModel):
+    """Response from a reranking request.
+
+    :param data: List of rerank result objects, sorted by relevance score (descending)
+    """
+
+    data: list[RerankData]
+
+
 @json_schema_type
 class OpenAIChatCompletionContentPartTextParam(BaseModel):
     """Text content part for OpenAI-compatible chat completion messages.

@@ -1131,6 +1153,24 @@ class InferenceProvider(Protocol):
         """
         ...

+    @webmethod(route="/inference/rerank", method="POST", experimental=True)
+    async def rerank(
+        self,
+        model: str,
+        query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
+        items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
+        max_num_results: int | None = None,
+    ) -> RerankResponse:
+        """Rerank a list of documents based on their relevance to a query.
+
+        :param model: The identifier of the reranking model to use.
+        :param query: The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length.
+        :param items: List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length.
+        :param max_num_results: (Optional) Maximum number of results to return. Default: returns all.
+        :returns: RerankResponse with indices sorted by relevance score (descending).
+        """
+        raise NotImplementedError("Reranking is not implemented")
+
     @webmethod(route="/openai/v1/completions", method="POST")
     async def openai_completion(
         self,
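To make the new protocol method concrete, here is a minimal provider-side sketch — a toy lexical-overlap scorer, not any shipped provider's logic (the scoring heuristic and the plain-string handling of query/items are assumptions for illustration):

```python
# Toy rerank sketch: score each item by word overlap with the query and
# return RerankData entries sorted by relevance_score, descending.
async def rerank(
    self,
    model: str,
    query,  # str | content part; only the plain-string case is handled here
    items,
    max_num_results: int | None = None,
) -> RerankResponse:
    query_words = set(str(query).lower().split())
    scored = [
        (i, float(len(query_words & set(str(item).lower().split()))))
        for i, item in enumerate(items)
    ]
    scored.sort(key=lambda pair: pair[1], reverse=True)  # descending by score
    if max_num_results is not None:
        scored = scored[:max_num_results]
    return RerankResponse(data=[RerankData(index=i, relevance_score=s) for i, s in scored])
```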
@@ -15,7 +15,7 @@ from llama_stack.log import get_logger

 REPO_ROOT = Path(__file__).parent.parent.parent.parent

-logger = get_logger(name=__name__, category="server")
+logger = get_logger(name=__name__, category="cli")


 class StackRun(Subcommand):

@@ -12,7 +12,7 @@ from llama_stack.apis.datasets import DatasetPurpose, DataSource
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routers")


 class DatasetIORouter(DatasetIO):

@@ -16,7 +16,7 @@ from llama_stack.apis.scoring import (
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routers")


 class ScoringRouter(Scoring):

@@ -65,7 +65,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
 from llama_stack.providers.utils.inference.inference_store import InferenceStore
 from llama_stack.providers.utils.telemetry.tracing import get_current_span

-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="core::routers")


 class InferenceRouter(Inference):

@@ -13,7 +13,7 @@ from llama_stack.apis.shields import Shield
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routers")


 class SafetyRouter(Safety):

@@ -22,7 +22,7 @@ from llama_stack.log import get_logger

 from ..routing_tables.toolgroups import ToolGroupsRoutingTable

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routers")


 class ToolRuntimeRouter(ToolRuntime):

@@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import (
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routers")


 class VectorIORouter(VectorIO):

@@ -14,7 +14,7 @@ from llama_stack.log import get_logger

 from .common import CommonRoutingTableImpl

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")


 class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):

@@ -23,7 +23,7 @@ from llama_stack.core.store import DistributionRegistry
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api, RoutingTable

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")


 def get_impl_api(p: Any) -> Api:

@@ -26,7 +26,7 @@ from llama_stack.log import get_logger

 from .common import CommonRoutingTableImpl

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")


 class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):

@@ -17,7 +17,7 @@ from llama_stack.log import get_logger

 from .common import CommonRoutingTableImpl, lookup_model

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")


 class ModelsRoutingTable(CommonRoutingTableImpl, Models):

@@ -19,7 +19,7 @@ from llama_stack.log import get_logger

 from .common import CommonRoutingTableImpl

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")


 class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):

@@ -15,7 +15,7 @@ from llama_stack.log import get_logger

 from .common import CommonRoutingTableImpl

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")


 class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):

@@ -14,7 +14,7 @@ from llama_stack.log import get_logger

 from .common import CommonRoutingTableImpl

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")


 def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:

@@ -30,7 +30,7 @@ from llama_stack.log import get_logger

 from .common import CommonRoutingTableImpl, lookup_model

-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")


 class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):

@@ -15,7 +15,7 @@ from llama_stack.core.server.auth_providers import create_auth_provider
 from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
 from llama_stack.log import get_logger

-logger = get_logger(name=__name__, category="auth")
+logger = get_logger(name=__name__, category="core::auth")


 class AuthenticationMiddleware:

@@ -23,7 +23,7 @@ from llama_stack.core.datatypes import (
 )
 from llama_stack.log import get_logger

-logger = get_logger(name=__name__, category="auth")
+logger = get_logger(name=__name__, category="core::auth")


 class AuthResponse(BaseModel):

@@ -15,7 +15,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
 from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl

-logger = get_logger(name=__name__, category="quota")
+logger = get_logger(name=__name__, category="core::server")


 class QuotaMiddleware:

@@ -84,7 +84,7 @@ from .quota import QuotaMiddleware

 REPO_ROOT = Path(__file__).parent.parent.parent.parent

-logger = get_logger(name=__name__, category="server")
+logger = get_logger(name=__name__, category="core::server")


 def warn_with_traceback(message, category, filename, lineno, file=None, line=None):

@@ -415,7 +415,7 @@ def main(args: argparse.Namespace | None = None):
         config_contents = yaml.safe_load(fp)
         if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
             logger_config = LoggingConfig(**cfg)
-    logger = get_logger(name=__name__, category="server", config=logger_config)
+    logger = get_logger(name=__name__, category="core::server", config=logger_config)
     if args.env:
         for env_pair in args.env:
             try:

@@ -16,7 +16,7 @@ from llama_stack.log import get_logger
 from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
 from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig

-logger = get_logger(__name__, category="core")
+logger = get_logger(__name__, category="core::registry")


 class DistributionRegistry(Protocol):

@@ -10,7 +10,7 @@ from pathlib import Path
 from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
 from llama_stack.log import get_logger

-logger = get_logger(name=__name__, category="config_resolution")
+logger = get_logger(name=__name__, category="core")


 DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions"

@@ -36,7 +36,7 @@ from .utils import get_negative_inf_value, to_2tuple

 MP_SCALE = 8

-logger = get_logger(name=__name__, category="models")
+logger = get_logger(name=__name__, category="models::llama")


 def reduce_from_tensor_model_parallel_region(input_):

@@ -11,7 +11,7 @@ from llama_stack.log import get_logger

 from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat

-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="models::llama")

 BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
 CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")

@@ -18,7 +18,7 @@ from ...datatypes import QuantizationMode
 from ..model import Transformer, TransformerBlock
 from ..moe import MoE

-log = get_logger(name=__name__, category="models")
+log = get_logger(name=__name__, category="models::llama")


 def swiglu_wrapper_no_reduce(

@@ -9,7 +9,7 @@ import collections

 from llama_stack.log import get_logger

-log = get_logger(name=__name__, category="llama")
+log = get_logger(name=__name__, category="models::llama")

 try:
     import fbgemm_gpu.experimental.gen_ai  # noqa: F401

@@ -84,7 +84,7 @@ MEMORY_QUERY_TOOL = "knowledge_search"
 WEB_SEARCH_TOOL = "web_search"
 RAG_TOOL_GROUP = "builtin::rag"

-logger = get_logger(name=__name__, category="agents")
+logger = get_logger(name=__name__, category="agents::meta_reference")


 class ChatAgent(ShieldRunnerMixin):

@@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig
 from .persistence import AgentInfo
 from .responses.openai_responses import OpenAIResponsesImpl

-logger = get_logger(name=__name__, category="agents")
+logger = get_logger(name=__name__, category="agents::meta_reference")


 class MetaReferenceAgentsImpl(Agents):

@@ -17,7 +17,7 @@ from llama_stack.core.request_headers import get_authenticated_user
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.kvstore import KVStore

-log = get_logger(name=__name__, category="agents")
+log = get_logger(name=__name__, category="agents::meta_reference")


 class AgentSessionInfo(Session):

@@ -41,7 +41,7 @@ from .utils import (
     convert_response_text_to_chat_response_format,
 )

-logger = get_logger(name=__name__, category="responses")
+logger = get_logger(name=__name__, category="openai::responses")


 class OpenAIResponsePreviousResponseWithInputItems(BaseModel):

@@ -47,7 +47,7 @@ from llama_stack.log import get_logger
 from .types import ChatCompletionContext, ChatCompletionResult
 from .utils import convert_chat_choice_to_response_message, is_function_tool_call

-logger = get_logger(name=__name__, category="responses")
+logger = get_logger(name=__name__, category="agents::meta_reference")


 class StreamingResponseOrchestrator:

@@ -38,7 +38,7 @@ from llama_stack.log import get_logger

 from .types import ChatCompletionContext, ToolExecutionResult

-logger = get_logger(name=__name__, category="responses")
+logger = get_logger(name=__name__, category="agents::meta_reference")


 class ToolExecutor:

@@ -11,7 +11,7 @@ from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.telemetry import tracing

-log = get_logger(name=__name__, category="agents")
+log = get_logger(name=__name__, category="agents::meta_reference")


 class SafetyException(Exception):  # noqa: N818
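The get_logger call shape is unchanged by these renames; only the category strings become hierarchical, e.g.:

```python
from llama_stack.log import get_logger

# "core" -> "core::routers"; the message below is a hypothetical example.
logger = get_logger(name=__name__, category="core::routers")
logger.info("routing request to provider")
```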
@@ -33,6 +33,9 @@ from llama_stack.apis.inference import (
     InterleavedContent,
     LogProbConfig,
     Message,
+    OpenAIChatCompletionContentPartImageParam,
+    OpenAIChatCompletionContentPartTextParam,
+    RerankResponse,
     ResponseFormat,
     SamplingParams,
     StopReason,

@@ -442,6 +445,15 @@ class MetaReferenceInferenceImpl(
         results = await self._nonstream_chat_completion(request_batch)
         return BatchChatCompletionResponse(batch=results)

+    async def rerank(
+        self,
+        model: str,
+        query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
+        items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
+        max_num_results: int | None = None,
+    ) -> RerankResponse:
+        raise NotImplementedError("Reranking is not supported for Meta Reference")
+
     async def _nonstream_chat_completion(
         self, request_batch: list[ChatCompletionRequest]
     ) -> list[ChatCompletionResponse]:
@@ -12,6 +12,9 @@ from llama_stack.apis.inference import (
     InterleavedContent,
     LogProbConfig,
     Message,
+    OpenAIChatCompletionContentPartImageParam,
+    OpenAIChatCompletionContentPartTextParam,
+    RerankResponse,
     ResponseFormat,
     SamplingParams,
     ToolChoice,

@@ -122,3 +125,12 @@ class SentenceTransformersInferenceImpl(
         logprobs: LogProbConfig | None = None,
     ):
         raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")
+
+    async def rerank(
+        self,
+        model: str,
+        query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
+        items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
+        max_num_results: int | None = None,
+    ) -> RerankResponse:
+        raise NotImplementedError("Reranking is not supported for Sentence Transformers")
llama_stack/providers/registry/files.py

@@ -5,9 +5,11 @@
 # the root directory of this source tree.

 from llama_stack.providers.datatypes import (
+    AdapterSpec,
     Api,
     InlineProviderSpec,
     ProviderSpec,
+    remote_provider_spec,
 )
 from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages

@@ -23,4 +25,14 @@ def available_providers() -> list[ProviderSpec]:
             config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
             description="Local filesystem-based file storage provider for managing files and documents locally.",
         ),
+        remote_provider_spec(
+            api=Api.files,
+            adapter=AdapterSpec(
+                adapter_type="s3",
+                pip_packages=["boto3"] + sql_store_pip_packages,
+                module="llama_stack.providers.remote.files.s3",
+                config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
+                description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
+            ),
+        ),
     ]
237 llama_stack/providers/remote/files/s3/README.md (new file)

@@ -0,0 +1,237 @@
# S3 Files Provider

A remote S3-based implementation of the Llama Stack Files API that provides scalable cloud file storage with metadata persistence.

## Features

- **AWS S3 Storage**: Store files in AWS S3 buckets for scalable, durable storage
- **Metadata Management**: Uses a SQL database for efficient file metadata queries
- **OpenAI API Compatibility**: Full compatibility with OpenAI Files API endpoints
- **Flexible Authentication**: Support for IAM roles and access keys
- **Custom S3 Endpoints**: Support for MinIO and other S3-compatible services

## Configuration

### Basic Configuration

```yaml
api: files
provider_type: remote::s3
config:
  bucket_name: my-llama-stack-files
  region: us-east-1
  metadata_store:
    type: sqlite
    db_path: ./s3_files_metadata.db
```

### Advanced Configuration

```yaml
api: files
provider_type: remote::s3
config:
  bucket_name: my-llama-stack-files
  region: us-east-1
  aws_access_key_id: YOUR_ACCESS_KEY
  aws_secret_access_key: YOUR_SECRET_KEY
  endpoint_url: https://s3.amazonaws.com  # Optional for custom endpoints
  metadata_store:
    type: sqlite
    db_path: ./s3_files_metadata.db
```

### Environment Variables

The configuration supports environment variable substitution:

```yaml
config:
  bucket_name: "${env.S3_BUCKET_NAME}"
  region: "${env.AWS_REGION:=us-east-1}"
  aws_access_key_id: "${env.AWS_ACCESS_KEY_ID:=}"
  aws_secret_access_key: "${env.AWS_SECRET_ACCESS_KEY:=}"
  endpoint_url: "${env.S3_ENDPOINT_URL:=}"
```

Note: `S3_BUCKET_NAME` has no default value since S3 bucket names must be globally unique.

## Authentication

### IAM Roles (Recommended)

For production deployments, use IAM roles:

```yaml
config:
  bucket_name: my-bucket
  region: us-east-1
  # No credentials needed - will use the IAM role
```

### Access Keys

For development or specific use cases:

```yaml
config:
  bucket_name: my-bucket
  region: us-east-1
  aws_access_key_id: AKIAIOSFODNN7EXAMPLE
  aws_secret_access_key: wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY
```

## S3 Bucket Setup

### Required Permissions

The S3 provider requires the following permissions:

```json
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": [
        "s3:GetObject",
        "s3:PutObject",
        "s3:DeleteObject",
        "s3:ListBucket"
      ],
      "Resource": [
        "arn:aws:s3:::your-bucket-name",
        "arn:aws:s3:::your-bucket-name/*"
      ]
    }
  ]
}
```

### Automatic Bucket Creation

By default, the S3 provider expects the bucket to already exist. If you want the provider to automatically create the bucket when it doesn't exist, set `auto_create_bucket: true` in your configuration:

```yaml
config:
  bucket_name: my-bucket
  auto_create_bucket: true  # Will create the bucket if it doesn't exist
  region: us-east-1
```

**Note**: When `auto_create_bucket` is enabled, the provider needs additional permissions:

```json
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": [
        "s3:GetObject",
        "s3:PutObject",
        "s3:DeleteObject",
        "s3:ListBucket",
        "s3:CreateBucket"
      ],
      "Resource": [
        "arn:aws:s3:::your-bucket-name",
        "arn:aws:s3:::your-bucket-name/*"
      ]
    }
  ]
}
```

### Bucket Policy (Optional)

For additional security, you can add a bucket policy:

```json
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Sid": "LlamaStackAccess",
      "Effect": "Allow",
      "Principal": {
        "AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
      },
      "Action": [
        "s3:GetObject",
        "s3:PutObject",
        "s3:DeleteObject"
      ],
      "Resource": "arn:aws:s3:::your-bucket-name/*"
    },
    {
      "Sid": "LlamaStackBucketAccess",
      "Effect": "Allow",
      "Principal": {
        "AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
      },
      "Action": [
        "s3:ListBucket"
      ],
      "Resource": "arn:aws:s3:::your-bucket-name"
    }
  ]
}
```

## Metadata Persistence

File metadata is stored in a SQL database for fast queries and OpenAI API compatibility. The metadata includes:

- File ID
- Original filename
- Purpose (assistants, batch, etc.)
- File size in bytes
- Created and expiration timestamps

### TTL and Cleanup

Files currently have a fixed long expiration time (100 years).

## Development and Testing

### Using MinIO

For self-hosted S3-compatible storage:

```yaml
config:
  bucket_name: test-bucket
  region: us-east-1
  endpoint_url: http://localhost:9000
  aws_access_key_id: minioadmin
  aws_secret_access_key: minioadmin
```

## Monitoring and Logging

The provider logs important operations and errors. For production deployments, consider:

- CloudWatch monitoring for S3 operations
- Custom metrics for file upload/download rates
- Error rate monitoring
- Performance metrics tracking

## Error Handling

The provider handles various error scenarios:

- S3 connectivity issues
- Bucket access permissions
- File-not-found errors
- Metadata consistency checks

## Known Limitations

- Fixed long TTL (100 years) instead of configurable expiration
- No server-side encryption enabled by default
- No support for AWS session tokens
- No S3 key prefix organization support
- No multipart upload support (all files are uploaded as single objects)
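Before pointing a stack at MinIO, it can help to smoke-test the endpoint with boto3 directly; a sketch using the sample values from the MinIO section above (the bucket name and minioadmin credentials are the README's placeholders):

```python
import boto3

# Assumes a local MinIO with the default credentials shown above.
client = boto3.client(
    "s3",
    endpoint_url="http://localhost:9000",
    region_name="us-east-1",
    aws_access_key_id="minioadmin",
    aws_secret_access_key="minioadmin",
)

client.create_bucket(Bucket="test-bucket")  # raises if the bucket already exists
client.put_object(Bucket="test-bucket", Key="hello.txt", Body=b"hello")
print(client.get_object(Bucket="test-bucket", Key="hello.txt")["Body"].read())
```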
20 llama_stack/providers/remote/files/s3/__init__.py (new file)

@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.core.datatypes import Api

from .config import S3FilesImplConfig


async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any]):
    from .files import S3FilesImpl

    # TODO: authorization policies and user separation
    impl = S3FilesImpl(config)
    await impl.initialize()
    return impl
42 llama_stack/providers/remote/files/s3/config.py (new file)

@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import BaseModel, Field

from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig


class S3FilesImplConfig(BaseModel):
    """Configuration for S3-based files provider."""

    bucket_name: str = Field(description="S3 bucket name to store files")
    region: str = Field(default="us-east-1", description="AWS region where the bucket is located")
    aws_access_key_id: str | None = Field(default=None, description="AWS access key ID (optional if using IAM roles)")
    aws_secret_access_key: str | None = Field(
        default=None, description="AWS secret access key (optional if using IAM roles)"
    )
    endpoint_url: str | None = Field(default=None, description="Custom S3 endpoint URL (for MinIO, LocalStack, etc.)")
    auto_create_bucket: bool = Field(
        default=False, description="Automatically create the S3 bucket if it doesn't exist"
    )
    metadata_store: SqlStoreConfig = Field(description="SQL store configuration for file metadata")

    @classmethod
    def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
        return {
            "bucket_name": "${env.S3_BUCKET_NAME}",  # no default, buckets must be globally unique
            "region": "${env.AWS_REGION:=us-east-1}",
            "aws_access_key_id": "${env.AWS_ACCESS_KEY_ID:=}",
            "aws_secret_access_key": "${env.AWS_SECRET_ACCESS_KEY:=}",
            "endpoint_url": "${env.S3_ENDPOINT_URL:=}",
            "auto_create_bucket": "${env.S3_AUTO_CREATE_BUCKET:=false}",
            "metadata_store": SqliteSqlStoreConfig.sample_run_config(
                __distro_dir__=__distro_dir__,
                db_name="s3_files_metadata.db",
            ),
        }
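For tests, the config can also be constructed directly; a sketch with illustrative values (the MinIO endpoint and db_path are assumptions, not defaults of this change):

```python
from llama_stack.providers.remote.files.s3.config import S3FilesImplConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig

# All values below are placeholders for a local MinIO-backed test setup.
config = S3FilesImplConfig(
    bucket_name="test-bucket",
    region="us-east-1",
    endpoint_url="http://localhost:9000",
    aws_access_key_id="minioadmin",
    aws_secret_access_key="minioadmin",
    metadata_store=SqliteSqlStoreConfig(db_path="/tmp/s3_files_metadata.db"),
)
```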
272 llama_stack/providers/remote/files/s3/files.py (new file)

@@ -0,0 +1,272 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import time
import uuid
from typing import Annotated

import boto3
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
from fastapi import File, Form, Response, UploadFile

from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import (
    Files,
    ListOpenAIFileResponse,
    OpenAIFileDeleteResponse,
    OpenAIFileObject,
    OpenAIFilePurpose,
)
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl

from .config import S3FilesImplConfig

# TODO: provider data for S3 credentials


def _create_s3_client(config: S3FilesImplConfig) -> boto3.client:
    try:
        s3_config = {
            "region_name": config.region,
        }

        # endpoint URL if specified (for MinIO, LocalStack, etc.)
        if config.endpoint_url:
            s3_config["endpoint_url"] = config.endpoint_url

        if config.aws_access_key_id and config.aws_secret_access_key:
            s3_config.update(
                {
                    "aws_access_key_id": config.aws_access_key_id,
                    "aws_secret_access_key": config.aws_secret_access_key,
                }
            )

        return boto3.client("s3", **s3_config)

    except (BotoCoreError, NoCredentialsError) as e:
        raise RuntimeError(f"Failed to initialize S3 client: {e}") from e


async def _create_bucket_if_not_exists(client: boto3.client, config: S3FilesImplConfig) -> None:
    try:
        client.head_bucket(Bucket=config.bucket_name)
    except ClientError as e:
        error_code = e.response["Error"]["Code"]
        if error_code == "404":
            if not config.auto_create_bucket:
                raise RuntimeError(
                    f"S3 bucket '{config.bucket_name}' does not exist. "
                    f"Either create the bucket manually or set 'auto_create_bucket: true' in your configuration."
                ) from e
            try:
                # For us-east-1, we can't specify LocationConstraint
                if config.region == "us-east-1":
                    client.create_bucket(Bucket=config.bucket_name)
                else:
                    client.create_bucket(
                        Bucket=config.bucket_name,
                        CreateBucketConfiguration={"LocationConstraint": config.region},
                    )
            except ClientError as create_error:
                raise RuntimeError(
                    f"Failed to create S3 bucket '{config.bucket_name}': {create_error}"
                ) from create_error
        elif error_code == "403":
            raise RuntimeError(f"Access denied to S3 bucket '{config.bucket_name}'") from e
        else:
            raise RuntimeError(f"Failed to access S3 bucket '{config.bucket_name}': {e}") from e


class S3FilesImpl(Files):
    """S3-based implementation of the Files API."""

    # TODO: implement expiration, for now a silly offset
    _SILLY_EXPIRATION_OFFSET = 100 * 365 * 24 * 60 * 60

    def __init__(self, config: S3FilesImplConfig) -> None:
        self._config = config
        self._client: boto3.client | None = None
        self._sql_store: SqlStore | None = None

    async def initialize(self) -> None:
        self._client = _create_s3_client(self._config)
        await _create_bucket_if_not_exists(self._client, self._config)

        self._sql_store = sqlstore_impl(self._config.metadata_store)
        await self._sql_store.create_table(
            "openai_files",
            {
                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
                "filename": ColumnType.STRING,
                "purpose": ColumnType.STRING,
                "bytes": ColumnType.INTEGER,
                "created_at": ColumnType.INTEGER,
                "expires_at": ColumnType.INTEGER,
                # TODO: add s3_etag field for integrity checking
            },
        )

    async def shutdown(self) -> None:
        pass

    @property
    def client(self) -> boto3.client:
        assert self._client is not None, "Provider not initialized"
        return self._client

    @property
    def sql_store(self) -> SqlStore:
        assert self._sql_store is not None, "Provider not initialized"
        return self._sql_store

    async def openai_upload_file(
        self,
        file: Annotated[UploadFile, File()],
        purpose: Annotated[OpenAIFilePurpose, Form()],
    ) -> OpenAIFileObject:
        file_id = f"file-{uuid.uuid4().hex}"

        filename = getattr(file, "filename", None) or "uploaded_file"

        created_at = int(time.time())
        expires_at = created_at + self._SILLY_EXPIRATION_OFFSET
        content = await file.read()
        file_size = len(content)

        await self.sql_store.insert(
            "openai_files",
            {
                "id": file_id,
                "filename": filename,
                "purpose": purpose.value,
                "bytes": file_size,
                "created_at": created_at,
                "expires_at": expires_at,
            },
        )

        try:
            self.client.put_object(
                Bucket=self._config.bucket_name,
                Key=file_id,
                Body=content,
                # TODO: enable server-side encryption
            )
        except ClientError as e:
            await self.sql_store.delete("openai_files", where={"id": file_id})
            raise RuntimeError(f"Failed to upload file to S3: {e}") from e

        return OpenAIFileObject(
            id=file_id,
            filename=filename,
            purpose=purpose,
            bytes=file_size,
            created_at=created_at,
            expires_at=expires_at,
        )

    async def openai_list_files(
        self,
        after: str | None = None,
        limit: int | None = 10000,
        order: Order | None = Order.desc,
        purpose: OpenAIFilePurpose | None = None,
    ) -> ListOpenAIFileResponse:
        # This is purely defensive; it should not happen because the router also defaults to Order.desc.
        if not order:
            order = Order.desc

        where_conditions = {}
        if purpose:
            where_conditions["purpose"] = purpose.value

        paginated_result = await self.sql_store.fetch_all(
            table="openai_files",
            where=where_conditions if where_conditions else None,
            order_by=[("created_at", order.value)],
            cursor=("id", after) if after else None,
            limit=limit,
        )

        files = [
            OpenAIFileObject(
                id=row["id"],
                filename=row["filename"],
                purpose=OpenAIFilePurpose(row["purpose"]),
                bytes=row["bytes"],
                created_at=row["created_at"],
                expires_at=row["expires_at"],
            )
            for row in paginated_result.data
        ]

        return ListOpenAIFileResponse(
            data=files,
            has_more=paginated_result.has_more,
            # empty string or None? spec says str, ref impl returns str | None; we go with the spec
            first_id=files[0].id if files else "",
            last_id=files[-1].id if files else "",
        )

    async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
        row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
        if not row:
            raise ResourceNotFoundError(file_id, "File", "files.list()")

        return OpenAIFileObject(
            id=row["id"],
            filename=row["filename"],
|
||||||
|
purpose=OpenAIFilePurpose(row["purpose"]),
|
||||||
|
bytes=row["bytes"],
|
||||||
|
created_at=row["created_at"],
|
||||||
|
expires_at=row["expires_at"],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
|
||||||
|
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||||
|
if not row:
|
||||||
|
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.client.delete_object(
|
||||||
|
Bucket=self._config.bucket_name,
|
||||||
|
Key=row["id"],
|
||||||
|
)
|
||||||
|
except ClientError as e:
|
||||||
|
if e.response["Error"]["Code"] != "NoSuchKey":
|
||||||
|
raise RuntimeError(f"Failed to delete file from S3: {e}") from e
|
||||||
|
|
||||||
|
await self.sql_store.delete("openai_files", where={"id": file_id})
|
||||||
|
|
||||||
|
return OpenAIFileDeleteResponse(id=file_id, deleted=True)
|
||||||
|
|
||||||
|
async def openai_retrieve_file_content(self, file_id: str) -> Response:
|
||||||
|
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||||
|
if not row:
|
||||||
|
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self.client.get_object(
|
||||||
|
Bucket=self._config.bucket_name,
|
||||||
|
Key=row["id"],
|
||||||
|
)
|
||||||
|
# TODO: can we stream this instead of loading it into memory
|
||||||
|
content = response["Body"].read()
|
||||||
|
except ClientError as e:
|
||||||
|
if e.response["Error"]["Code"] == "NoSuchKey":
|
||||||
|
await self.sql_store.delete("openai_files", where={"id": file_id})
|
||||||
|
raise ResourceNotFoundError(file_id, "File", "files.list()") from e
|
||||||
|
raise RuntimeError(f"Failed to download file from S3: {e}") from e
|
||||||
|
|
||||||
|
return Response(
|
||||||
|
content=content,
|
||||||
|
media_type="application/octet-stream",
|
||||||
|
headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'},
|
||||||
|
)
|
||||||
|
|
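A minimal usage sketch of the provider above against a local MinIO or LocalStack endpoint. This assumes S3FilesImplConfig accepts the fields referenced in the code as keyword arguments and that a metadata_store sqlstore config object is already at hand; the exact config model may differ, and the credentials shown are hypothetical local defaults.

    import asyncio

    async def demo(metadata_store_config) -> None:
        # endpoint_url points the boto3 client at local MinIO instead of AWS
        config = S3FilesImplConfig(
            bucket_name="llama-files",
            region="us-east-1",  # us-east-1 skips LocationConstraint on bucket creation
            endpoint_url="http://localhost:9000",
            auto_create_bucket=True,  # otherwise initialize() raises if the bucket is missing
            aws_access_key_id="minioadmin",      # hypothetical local credentials
            aws_secret_access_key="minioadmin",
            metadata_store=metadata_store_config,
        )
        files = S3FilesImpl(config)
        # builds the S3 client, ensures the bucket, and creates the metadata table
        await files.initialize()

    # asyncio.run(demo(my_metadata_store_config))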
@@ -65,7 +65,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import FireworksImplConfig
 from .models import MODEL_ENTRIES
 
-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="inference::fireworks")
 
 
 class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
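The same pattern repeats across the hunks that follow: each provider's logger moves from a broad category ("inference", "safety", "vector_io", and so on) to a hierarchical "<api>::<provider>" category, so log output can be attributed and filtered per provider. A short sketch of the convention, assuming only get_logger's signature as it is used above:

    from llama_stack.log import get_logger

    # hierarchical category: API name first, then the provider name
    logger = get_logger(name=__name__, category="inference::fireworks")
    logger.debug("sending chat completion request")  # illustrative call only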
@@ -3,6 +3,11 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from llama_stack.apis.inference import (
+    OpenAIChatCompletionContentPartImageParam,
+    OpenAIChatCompletionContentPartTextParam,
+    RerankResponse,
+)
 from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
@@ -10,7 +15,7 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 
 from .models import MODEL_ENTRIES
 
-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="inference::llama_openai_compat")
 
 
 class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
@@ -54,3 +59,12 @@ class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
 
     async def shutdown(self):
         await super().shutdown()
+
+    async def rerank(
+        self,
+        model: str,
+        query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
+        items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
+        max_num_results: int | None = None,
+    ) -> RerankResponse:
+        raise NotImplementedError("Reranking is not supported for Llama OpenAI Compat")
@@ -57,7 +57,7 @@ from .openai_utils import (
 )
 from .utils import _is_nvidia_hosted
 
-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="inference::nvidia")
 
 
 class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
@@ -10,7 +10,7 @@ from llama_stack.log import get_logger
 
 from . import NVIDIAConfig
 
-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="inference::nvidia")
 
 
 def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
@@ -37,11 +37,14 @@ from llama_stack.apis.inference import (
     Message,
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
+    OpenAIChatCompletionContentPartImageParam,
+    OpenAIChatCompletionContentPartTextParam,
     OpenAICompletion,
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
+    RerankResponse,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
@@ -85,7 +88,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 
 from .models import MODEL_ENTRIES
 
-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="inference::ollama")
 
 
 class OllamaInferenceAdapter(
@@ -641,6 +644,15 @@ class OllamaInferenceAdapter(
     ):
         raise NotImplementedError("Batch chat completion is not supported for Ollama")
 
+    async def rerank(
+        self,
+        model: str,
+        query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
+        items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
+        max_num_results: int | None = None,
+    ) -> RerankResponse:
+        raise NotImplementedError("Reranking is not supported for Ollama")
+
 
 async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
     async def _convert_content(content) -> dict:
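These rerank stubs raise NotImplementedError rather than failing with an attribute error, so callers can probe support explicitly. A hedged sketch of that calling pattern; rerank_or_none is a hypothetical helper, and adapter stands for any inference adapter exposing the rerank signature shown above:

    async def rerank_or_none(adapter, model: str, query: str, items: list[str]):
        # Providers that cannot rerank raise NotImplementedError from stubs like the one above.
        try:
            return await adapter.rerank(model=model, query=query, items=items, max_num_results=1)
        except NotImplementedError:
            return None  # caller falls back to its own ordering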
@@ -11,7 +11,7 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from .config import OpenAIConfig
 from .models import MODEL_ENTRIES
 
-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="inference::openai")
 
 
 #
@@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 
 from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
 
-log = get_logger(name=__name__, category="inference")
+log = get_logger(name=__name__, category="inference::tgi")
 
 
 def build_hf_repo_model_entries():
@@ -61,7 +61,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import TogetherImplConfig
 from .models import MODEL_ENTRIES
 
-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="inference::together")
 
 
 class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
@@ -39,12 +39,15 @@ from llama_stack.apis.inference import (
     Message,
     ModelStore,
     OpenAIChatCompletion,
+    OpenAIChatCompletionContentPartImageParam,
+    OpenAIChatCompletionContentPartTextParam,
     OpenAICompletion,
     OpenAIEmbeddingData,
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
+    RerankResponse,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
@@ -85,7 +88,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 
 from .config import VLLMInferenceAdapterConfig
 
-log = get_logger(name=__name__, category="inference")
+log = get_logger(name=__name__, category="inference::vllm")
 
 
 def build_hf_repo_model_entries():
@@ -732,4 +735,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         response_format: ResponseFormat | None = None,
         logprobs: LogProbConfig | None = None,
     ):
-        raise NotImplementedError("Batch chat completion is not supported for Ollama")
+        raise NotImplementedError("Batch chat completion is not supported for vLLM")
+
+    async def rerank(
+        self,
+        model: str,
+        query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
+        items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
+        max_num_results: int | None = None,
+    ) -> RerankResponse:
+        raise NotImplementedError("Reranking is not supported for vLLM")
|
||||||
|
|
||||||
from .config import NvidiaPostTrainingConfig
|
from .config import NvidiaPostTrainingConfig
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="integration")
|
logger = get_logger(name=__name__, category="post_training::nvidia")
|
||||||
|
|
||||||
|
|
||||||
def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:
|
def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
||||||
|
|
||||||
from .config import BedrockSafetyConfig
|
from .config import BedrockSafetyConfig
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="safety")
|
logger = get_logger(name=__name__, category="safety::bedrock")
|
||||||
|
|
||||||
|
|
||||||
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_
|
||||||
|
|
||||||
from .config import NVIDIASafetyConfig
|
from .config import NVIDIASafetyConfig
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="safety")
|
logger = get_logger(name=__name__, category="safety::nvidia")
|
||||||
|
|
||||||
|
|
||||||
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_
|
||||||
|
|
||||||
from .config import SambaNovaSafetyConfig
|
from .config import SambaNovaSafetyConfig
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="safety")
|
logger = get_logger(name=__name__, category="safety::sambanova")
|
||||||
|
|
||||||
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
|
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
||||||
|
|
||||||
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
|
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="vector_io")
|
log = get_logger(name=__name__, category="vector_io::chroma")
|
||||||
|
|
||||||
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
|
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
|
||||||
|
|
||||||
from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
|
from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="vector_io")
|
logger = get_logger(name=__name__, category="vector_io::milvus")
|
||||||
|
|
||||||
VERSION = "v3"
|
VERSION = "v3"
|
||||||
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"
|
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
||||||
|
|
||||||
from .config import PGVectorVectorIOConfig
|
from .config import PGVectorVectorIOConfig
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="vector_io")
|
log = get_logger(name=__name__, category="vector_io::pgvector")
|
||||||
|
|
||||||
VERSION = "v3"
|
VERSION = "v3"
|
||||||
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
|
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
||||||
|
|
||||||
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
|
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="vector_io")
|
log = get_logger(name=__name__, category="vector_io::qdrant")
|
||||||
CHUNK_ID_KEY = "_chunk_id"
|
CHUNK_ID_KEY = "_chunk_id"
|
||||||
|
|
||||||
# KV store prefixes for vector databases
|
# KV store prefixes for vector databases
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
|
||||||
|
|
||||||
from .config import WeaviateVectorIOConfig
|
from .config import WeaviateVectorIOConfig
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="vector_io")
|
log = get_logger(name=__name__, category="vector_io::weaviate")
|
||||||
|
|
||||||
VERSION = "v3"
|
VERSION = "v3"
|
||||||
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
|
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con
|
||||||
EMBEDDING_MODELS = {}
|
EMBEDDING_MODELS = {}
|
||||||
|
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="inference")
|
log = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
|
||||||
class SentenceTransformerEmbeddingMixin:
|
class SentenceTransformerEmbeddingMixin:
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
interleaved_content_as_str,
|
interleaved_content_as_str,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="inference")
|
logger = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMOpenAIMixin(
|
class LiteLLMOpenAIMixin(
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference import (
|
||||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
|
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
|
||||||
class RemoteInferenceProviderConfig(BaseModel):
|
class RemoteInferenceProviderConfig(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -134,7 +134,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
decode_assistant_message,
|
decode_assistant_message,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="inference")
|
logger = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
|
||||||
class OpenAICompatCompletionChoiceDelta(BaseModel):
|
class OpenAICompatCompletionChoiceDelta(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ from llama_stack.apis.inference import (
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
|
||||||
class OpenAIMixin(ABC):
|
class OpenAIMixin(ABC):
|
||||||
|
|
|
||||||
|
|
@ -58,7 +58,7 @@ from llama_stack.models.llama.sku_list import resolve_model
|
||||||
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
|
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
|
||||||
from llama_stack.providers.utils.inference import supported_inference_models
|
from llama_stack.providers.utils.inference import supported_inference_models
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="inference")
|
log = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
|
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ from llama_stack.providers.utils.kvstore import KVStore
|
||||||
|
|
||||||
from ..config import MongoDBKVStoreConfig
|
from ..config import MongoDBKVStoreConfig
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="kvstore")
|
log = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
|
||||||
class MongoDBKVStoreImpl(KVStore):
|
class MongoDBKVStoreImpl(KVStore):
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ from llama_stack.log import get_logger
|
||||||
from ..api import KVStore
|
from ..api import KVStore
|
||||||
from ..config import PostgresKVStoreConfig
|
from ..config import PostgresKVStoreConfig
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="kvstore")
|
log = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
|
||||||
class PostgresKVStoreImpl(KVStore):
|
class PostgresKVStoreImpl(KVStore):
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
||||||
make_overlapped_chunks,
|
make_overlapped_chunks,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="memory")
|
logger = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
# Constants for OpenAI vector stores
|
# Constants for OpenAI vector stores
|
||||||
CHUNK_MULTIPLIER = 5
|
CHUNK_MULTIPLIER = 5
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
)
|
)
|
||||||
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
|
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="memory")
|
log = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
|
||||||
class ChunkForDeletion(BaseModel):
|
class ChunkForDeletion(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="scheduler")
|
logger = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
|
||||||
# TODO: revisit the list of possible statuses when defining a more coherent
|
# TODO: revisit the list of possible statuses when defining a more coherent
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ from llama_stack.log import get_logger
|
||||||
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
|
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
|
||||||
from .sqlstore import SqlStoreType
|
from .sqlstore import SqlStoreType
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="authorized_sqlstore")
|
logger = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
# Hardcoded copy of the default policy that our SQL filtering implements
|
# Hardcoded copy of the default policy that our SQL filtering implements
|
||||||
# WARNING: If default_policy() changes, this constant must be updated accordingly
|
# WARNING: If default_policy() changes, this constant must be updated accordingly
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ from llama_stack.log import get_logger
|
||||||
from .api import ColumnDefinition, ColumnType, SqlStore
|
from .api import ColumnDefinition, ColumnType, SqlStore
|
||||||
from .sqlstore import SqlAlchemySqlStoreConfig
|
from .sqlstore import SqlAlchemySqlStoreConfig
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="sqlstore")
|
logger = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
TYPE_MAPPING: dict[ColumnType, Any] = {
|
TYPE_MAPPING: dict[ColumnType, Any] = {
|
||||||
ColumnType.INTEGER: Integer,
|
ColumnType.INTEGER: Integer,
|
||||||
|
|
|
||||||
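With the shared utility modules consolidated under a single "providers::utils" category, per-module log tuning collapses to one knob. Assuming the LLAMA_STACK_LOGGING category=level environment-variable syntax (an assumption; check llama_stack/log.py for the exact grammar), something like:

    import os

    # quiet the shared utils, keep one provider chatty (assumed syntax)
    os.environ["LLAMA_STACK_LOGGING"] = "providers::utils=warning;inference::ollama=debug"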
587 llama_stack/ui/app/chat-playground/page.test.tsx Normal file
@@ -0,0 +1,587 @@
import React from "react";
import {
  render,
  screen,
  fireEvent,
  waitFor,
  act,
} from "@testing-library/react";
import "@testing-library/jest-dom";
import ChatPlaygroundPage from "./page";

const mockClient = {
  agents: {
    list: jest.fn(),
    create: jest.fn(),
    retrieve: jest.fn(),
    delete: jest.fn(),
    session: {
      list: jest.fn(),
      create: jest.fn(),
      delete: jest.fn(),
      retrieve: jest.fn(),
    },
    turn: {
      create: jest.fn(),
    },
  },
  models: {
    list: jest.fn(),
  },
  toolgroups: {
    list: jest.fn(),
  },
};

jest.mock("@/hooks/use-auth-client", () => ({
  useAuthClient: jest.fn(() => mockClient),
}));

jest.mock("@/components/chat-playground/chat", () => ({
  Chat: jest.fn(
    ({
      className,
      messages,
      handleSubmit,
      input,
      handleInputChange,
      isGenerating,
      append,
      suggestions,
    }) => (
      <div data-testid="chat-component" className={className}>
        <div data-testid="messages-count">{messages.length}</div>
        <input
          data-testid="chat-input"
          value={input}
          onChange={handleInputChange}
          disabled={isGenerating}
        />
        <button data-testid="submit-button" onClick={handleSubmit}>
          Submit
        </button>
        {suggestions?.map((suggestion: string, index: number) => (
          <button
            key={index}
            data-testid={`suggestion-${index}`}
            onClick={() => append({ role: "user", content: suggestion })}
          >
            {suggestion}
          </button>
        ))}
      </div>
    )
  ),
}));

jest.mock("@/components/chat-playground/conversations", () => ({
  SessionManager: jest.fn(({ selectedAgentId, onNewSession }) => (
    <div data-testid="session-manager">
      {selectedAgentId && (
        <>
          <div data-testid="selected-agent">{selectedAgentId}</div>
          <button data-testid="new-session-button" onClick={onNewSession}>
            New Session
          </button>
        </>
      )}
    </div>
  )),
  SessionUtils: {
    saveCurrentSessionId: jest.fn(),
    loadCurrentSessionId: jest.fn(),
    loadCurrentAgentId: jest.fn(),
    saveCurrentAgentId: jest.fn(),
    clearCurrentSession: jest.fn(),
    saveSessionData: jest.fn(),
    loadSessionData: jest.fn(),
    saveAgentConfig: jest.fn(),
    loadAgentConfig: jest.fn(),
    clearAgentCache: jest.fn(),
    createDefaultSession: jest.fn(() => ({
      id: "test-session-123",
      name: "Default Session",
      messages: [],
      selectedModel: "",
      systemMessage: "You are a helpful assistant.",
      agentId: "test-agent-123",
      createdAt: Date.now(),
      updatedAt: Date.now(),
    })),
  },
}));

const mockAgents = [
  {
    agent_id: "agent_123",
    agent_config: {
      name: "Test Agent",
      instructions: "You are a test assistant.",
    },
  },
  {
    agent_id: "agent_456",
    agent_config: {
      agent_name: "Another Agent",
      instructions: "You are another assistant.",
    },
  },
];

const mockModels = [
  {
    identifier: "test-model-1",
    model_type: "llm",
  },
  {
    identifier: "test-model-2",
    model_type: "llm",
  },
];

const mockToolgroups = [
  {
    identifier: "builtin::rag",
    provider_id: "test-provider",
    type: "tool_group",
    provider_resource_id: "test-resource",
  },
];

describe("ChatPlaygroundPage", () => {
  beforeEach(() => {
    jest.clearAllMocks();
    Element.prototype.scrollIntoView = jest.fn();
    mockClient.agents.list.mockResolvedValue({ data: mockAgents });
    mockClient.models.list.mockResolvedValue(mockModels);
    mockClient.toolgroups.list.mockResolvedValue(mockToolgroups);
    mockClient.agents.session.create.mockResolvedValue({
      session_id: "new-session-123",
    });
    mockClient.agents.session.list.mockResolvedValue({ data: [] });
    mockClient.agents.session.retrieve.mockResolvedValue({
      session_id: "test-session",
      session_name: "Test Session",
      started_at: new Date().toISOString(),
      turns: [],
    }); // No turns by default
    mockClient.agents.retrieve.mockResolvedValue({
      agent_id: "test-agent",
      agent_config: {
        toolgroups: ["builtin::rag"],
        instructions: "Test instructions",
        model: "test-model",
      },
    });
    mockClient.agents.delete.mockResolvedValue(undefined);
  });

  describe("Agent Selector Rendering", () => {
    test("shows agent selector when agents are available", async () => {
      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      await waitFor(() => {
        expect(screen.getByText("Agent Session:")).toBeInTheDocument();
        expect(screen.getAllByRole("combobox")).toHaveLength(2);
        expect(screen.getByText("+ New Agent")).toBeInTheDocument();
        expect(screen.getByText("Clear Chat")).toBeInTheDocument();
      });
    });

    test("does not show agent selector when no agents are available", async () => {
      mockClient.agents.list.mockResolvedValue({ data: [] });

      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      await waitFor(() => {
        expect(screen.queryByText("Agent Session:")).not.toBeInTheDocument();
        expect(screen.getAllByRole("combobox")).toHaveLength(1);
        expect(screen.getByText("+ New Agent")).toBeInTheDocument();
        expect(screen.queryByText("Clear Chat")).not.toBeInTheDocument();
      });
    });

    test("does not show agent selector while loading", async () => {
      mockClient.agents.list.mockImplementation(() => new Promise(() => {}));

      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      expect(screen.queryByText("Agent Session:")).not.toBeInTheDocument();
      expect(screen.getAllByRole("combobox")).toHaveLength(1);
      expect(screen.getByText("+ New Agent")).toBeInTheDocument();
      expect(screen.queryByText("Clear Chat")).not.toBeInTheDocument();
    });

    test("shows agent options in selector", async () => {
      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      await waitFor(() => {
        const agentCombobox = screen.getAllByRole("combobox").find(element => {
          return (
            element.textContent?.includes("Test Agent") ||
            element.textContent?.includes("Select Agent")
          );
        });
        expect(agentCombobox).toBeDefined();
        fireEvent.click(agentCombobox!);
      });

      await waitFor(() => {
        expect(screen.getAllByText("Test Agent")).toHaveLength(2);
        expect(screen.getByText("Another Agent")).toBeInTheDocument();
      });
    });

    test("displays agent ID when no name is available", async () => {
      const agentWithoutName = {
        agent_id: "agent_789",
        agent_config: {
          instructions: "You are an agent without a name.",
        },
      };

      mockClient.agents.list.mockResolvedValue({ data: [agentWithoutName] });

      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      await waitFor(() => {
        const agentCombobox = screen.getAllByRole("combobox").find(element => {
          return (
            element.textContent?.includes("Agent agent_78") ||
            element.textContent?.includes("Select Agent")
          );
        });
        expect(agentCombobox).toBeDefined();
        fireEvent.click(agentCombobox!);
      });

      await waitFor(() => {
        expect(screen.getAllByText("Agent agent_78...")).toHaveLength(2);
      });
    });
  });

  describe("Agent Creation Modal", () => {
    test("opens agent creation modal when + New Agent is clicked", async () => {
      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      const newAgentButton = screen.getByText("+ New Agent");
      fireEvent.click(newAgentButton);

      expect(screen.getByText("Create New Agent")).toBeInTheDocument();
      expect(screen.getByText("Agent Name (optional)")).toBeInTheDocument();
      expect(screen.getAllByText("Model")).toHaveLength(2);
      expect(screen.getByText("System Instructions")).toBeInTheDocument();
      expect(screen.getByText("Tools (optional)")).toBeInTheDocument();
    });

    test("closes modal when Cancel is clicked", async () => {
      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      const newAgentButton = screen.getByText("+ New Agent");
      fireEvent.click(newAgentButton);

      const cancelButton = screen.getByText("Cancel");
      fireEvent.click(cancelButton);

      expect(screen.queryByText("Create New Agent")).not.toBeInTheDocument();
    });

    test("creates agent when Create Agent is clicked", async () => {
      mockClient.agents.create.mockResolvedValue({ agent_id: "new-agent-123" });
      mockClient.agents.list
        .mockResolvedValueOnce({ data: mockAgents })
        .mockResolvedValueOnce({
          data: [
            ...mockAgents,
            { agent_id: "new-agent-123", agent_config: { name: "New Agent" } },
          ],
        });

      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      const newAgentButton = screen.getByText("+ New Agent");
      await act(async () => {
        fireEvent.click(newAgentButton);
      });

      await waitFor(() => {
        expect(screen.getByText("Create New Agent")).toBeInTheDocument();
      });

      const nameInput = screen.getByPlaceholderText("My Custom Agent");
      await act(async () => {
        fireEvent.change(nameInput, { target: { value: "Test Agent Name" } });
      });

      const instructionsTextarea = screen.getByDisplayValue(
        "You are a helpful assistant."
      );
      await act(async () => {
        fireEvent.change(instructionsTextarea, {
          target: { value: "Custom instructions" },
        });
      });

      await waitFor(() => {
        const modalModelSelectors = screen
          .getAllByRole("combobox")
          .filter(el => {
            return (
              el.textContent?.includes("Select Model") ||
              el.closest('[class*="modal"]') ||
              el.closest('[class*="card"]')
            );
          });
        expect(modalModelSelectors.length).toBeGreaterThan(0);
      });

      const modalModelSelectors = screen.getAllByRole("combobox").filter(el => {
        return (
          el.textContent?.includes("Select Model") ||
          el.closest('[class*="modal"]') ||
          el.closest('[class*="card"]')
        );
      });

      await act(async () => {
        fireEvent.click(modalModelSelectors[0]);
      });

      await waitFor(() => {
        const modelOptions = screen.getAllByText("test-model-1");
        expect(modelOptions.length).toBeGreaterThan(0);
      });

      const modelOptions = screen.getAllByText("test-model-1");
      const dropdownOption = modelOptions.find(
        option =>
          option.closest('[role="option"]') ||
          option.id?.includes("radix") ||
          option.getAttribute("aria-selected") !== null
      );

      await act(async () => {
        fireEvent.click(
          dropdownOption || modelOptions[modelOptions.length - 1]
        );
      });

      await waitFor(() => {
        const createButton = screen.getByText("Create Agent");
        expect(createButton).not.toBeDisabled();
      });

      const createButton = screen.getByText("Create Agent");
      await act(async () => {
        fireEvent.click(createButton);
      });

      await waitFor(() => {
        expect(mockClient.agents.create).toHaveBeenCalledWith({
          agent_config: {
            model: expect.any(String),
            instructions: "Custom instructions",
            name: "Test Agent Name",
            enable_session_persistence: true,
          },
        });
      });

      await waitFor(() => {
        expect(screen.queryByText("Create New Agent")).not.toBeInTheDocument();
      });
    });
  });

  describe("Agent Selection", () => {
    test("creates default session when agent is selected", async () => {
      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      await waitFor(() => {
        // first agent should be auto-selected
        expect(mockClient.agents.session.create).toHaveBeenCalledWith(
          "agent_123",
          { session_name: "Default Session" }
        );
      });
    });

    test("switches agent when different agent is selected", async () => {
      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      await waitFor(() => {
        const agentCombobox = screen.getAllByRole("combobox").find(element => {
          return (
            element.textContent?.includes("Test Agent") ||
            element.textContent?.includes("Select Agent")
          );
        });
        expect(agentCombobox).toBeDefined();
        fireEvent.click(agentCombobox!);
      });

      await waitFor(() => {
        const anotherAgentOption = screen.getByText("Another Agent");
        fireEvent.click(anotherAgentOption);
      });

      expect(mockClient.agents.session.create).toHaveBeenCalledWith(
        "agent_456",
        { session_name: "Default Session" }
      );
    });
  });

  describe("Agent Deletion", () => {
    test("shows delete button when multiple agents exist", async () => {
      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      await waitFor(() => {
        expect(screen.getByTitle("Delete current agent")).toBeInTheDocument();
      });
    });

    test("hides delete button when only one agent exists", async () => {
      mockClient.agents.list.mockResolvedValue({
        data: [mockAgents[0]],
      });

      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      await waitFor(() => {
        expect(
          screen.queryByTitle("Delete current agent")
        ).not.toBeInTheDocument();
      });
    });

    test("deletes agent and switches to another when confirmed", async () => {
      global.confirm = jest.fn(() => true);

      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      await waitFor(() => {
        expect(screen.getByTitle("Delete current agent")).toBeInTheDocument();
      });

      mockClient.agents.delete.mockResolvedValue(undefined);
      mockClient.agents.list.mockResolvedValueOnce({ data: mockAgents });
      mockClient.agents.list.mockResolvedValueOnce({
        data: [mockAgents[1]],
      });

      const deleteButton = screen.getByTitle("Delete current agent");
      await act(async () => {
        deleteButton.click();
      });

      await waitFor(() => {
        expect(mockClient.agents.delete).toHaveBeenCalledWith("agent_123");
        expect(global.confirm).toHaveBeenCalledWith(
          "Are you sure you want to delete this agent? This action cannot be undone and will delete all associated sessions."
        );
      });

      (global.confirm as jest.Mock).mockRestore();
    });

    test("does not delete agent when cancelled", async () => {
      global.confirm = jest.fn(() => false);

      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      await waitFor(() => {
        expect(screen.getByTitle("Delete current agent")).toBeInTheDocument();
      });

      const deleteButton = screen.getByTitle("Delete current agent");
      await act(async () => {
        deleteButton.click();
      });

      await waitFor(() => {
        expect(global.confirm).toHaveBeenCalled();
        expect(mockClient.agents.delete).not.toHaveBeenCalled();
      });

      (global.confirm as jest.Mock).mockRestore();
    });
  });

  describe("Error Handling", () => {
    test("handles agent loading errors gracefully", async () => {
      mockClient.agents.list.mockRejectedValue(
        new Error("Failed to load agents")
      );
      const consoleSpy = jest
        .spyOn(console, "error")
        .mockImplementation(() => {});

      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      await waitFor(() => {
        expect(consoleSpy).toHaveBeenCalledWith(
          "Error fetching agents:",
          expect.any(Error)
        );
      });

      expect(screen.getByText("+ New Agent")).toBeInTheDocument();

      consoleSpy.mockRestore();
    });

    test("handles model loading errors gracefully", async () => {
      mockClient.models.list.mockRejectedValue(
        new Error("Failed to load models")
      );
      const consoleSpy = jest
        .spyOn(console, "error")
        .mockImplementation(() => {});

      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      await waitFor(() => {
        expect(consoleSpy).toHaveBeenCalledWith(
          "Error fetching models:",
          expect.any(Error)
        );
      });

      consoleSpy.mockRestore();
    });
  });
});
File diff suppressed because it is too large.
Binary file not shown. (Before: 25 KiB)
@@ -120,3 +120,44 @@
     @apply bg-background text-foreground;
   }
 }
+
+@layer utilities {
+  .animate-typing-dot-1 {
+    animation: typing-dot-bounce-1 0.8s cubic-bezier(0.4, 0, 0.6, 1) infinite;
+  }
+
+  .animate-typing-dot-2 {
+    animation: typing-dot-bounce-2 0.8s cubic-bezier(0.4, 0, 0.6, 1) infinite;
+  }
+
+  .animate-typing-dot-3 {
+    animation: typing-dot-bounce-3 0.8s cubic-bezier(0.4, 0, 0.6, 1) infinite;
+  }
+
+  @keyframes typing-dot-bounce-1 {
+    0%, 15%, 85%, 100% {
+      transform: translateY(0);
+    }
+    7.5% {
+      transform: translateY(-6px);
+    }
+  }
+
+  @keyframes typing-dot-bounce-2 {
+    0%, 15%, 35%, 85%, 100% {
+      transform: translateY(0);
+    }
+    25% {
+      transform: translateY(-6px);
+    }
+  }
+
+  @keyframes typing-dot-bounce-3 {
+    0%, 35%, 55%, 85%, 100% {
+      transform: translateY(0);
+    }
+    45% {
+      transform: translateY(-6px);
+    }
+  }
+}
@@ -18,6 +18,9 @@ const geistMono = Geist_Mono({
 export const metadata: Metadata = {
   title: "Llama Stack",
   description: "Llama Stack UI",
+  icons: {
+    icon: "/favicon.ico",
+  },
 };
 
 import { SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";
@@ -161,10 +161,12 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
 
   const isUser = role === "user";
 
-  const formattedTime = createdAt?.toLocaleTimeString("en-US", {
-    hour: "2-digit",
-    minute: "2-digit",
-  });
+  const formattedTime = createdAt
+    ? new Date(createdAt).toLocaleTimeString("en-US", {
+        hour: "2-digit",
+        minute: "2-digit",
+      })
+    : undefined;
 
   if (isUser) {
     return (
@@ -185,7 +187,7 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
 
       {showTimeStamp && createdAt ? (
         <time
-          dateTime={createdAt.toISOString()}
+          dateTime={new Date(createdAt).toISOString()}
           className={cn(
             "mt-1 block px-1 text-xs opacity-50",
             animation !== "none" && "duration-500 animate-in fade-in-0"
@@ -220,7 +222,7 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
 
       {showTimeStamp && createdAt ? (
         <time
-          dateTime={createdAt.toISOString()}
+          dateTime={new Date(createdAt).toISOString()}
           className={cn(
             "mt-1 block px-1 text-xs opacity-50",
             animation !== "none" && "duration-500 animate-in fade-in-0"
@@ -262,7 +264,7 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
 
       {showTimeStamp && createdAt ? (
         <time
-          dateTime={createdAt.toISOString()}
+          dateTime={new Date(createdAt).toISOString()}
           className={cn(
             "mt-1 block px-1 text-xs opacity-50",
             animation !== "none" && "duration-500 animate-in fade-in-0"
345 llama_stack/ui/components/chat-playground/conversations.test.tsx Normal file
@@ -0,0 +1,345 @@
import React from "react";
import { render, screen, waitFor, act } from "@testing-library/react";
import "@testing-library/jest-dom";
import { Conversations, SessionUtils } from "./conversations";
import type { Message } from "@/components/chat-playground/chat-message";

interface ChatSession {
  id: string;
  name: string;
  messages: Message[];
  selectedModel: string;
  systemMessage: string;
  agentId: string;
  createdAt: number;
  updatedAt: number;
}

const mockOnSessionChange = jest.fn();
const mockOnNewSession = jest.fn();

// Mock the auth client
const mockClient = {
  agents: {
    session: {
      list: jest.fn(),
      create: jest.fn(),
      delete: jest.fn(),
      retrieve: jest.fn(),
    },
  },
};

// Mock the useAuthClient hook
jest.mock("@/hooks/use-auth-client", () => ({
  useAuthClient: jest.fn(() => mockClient),
}));

// Mock additional SessionUtils methods that are now being used
jest.mock("./conversations", () => {
  const actual = jest.requireActual("./conversations");
  return {
    ...actual,
    SessionUtils: {
      ...actual.SessionUtils,
      saveSessionData: jest.fn(),
      loadSessionData: jest.fn(),
      saveAgentConfig: jest.fn(),
      loadAgentConfig: jest.fn(),
      clearAgentCache: jest.fn(),
    },
  };
});

const localStorageMock = {
  getItem: jest.fn(),
  setItem: jest.fn(),
  removeItem: jest.fn(),
  clear: jest.fn(),
};

Object.defineProperty(window, "localStorage", {
  value: localStorageMock,
  writable: true,
});

// Mock crypto.randomUUID for test environment
let uuidCounter = 0;
Object.defineProperty(globalThis, "crypto", {
  value: {
    randomUUID: jest.fn(() => `test-uuid-${++uuidCounter}`),
  },
  writable: true,
});

describe("SessionManager", () => {
  const mockSession: ChatSession = {
    id: "session_123",
    name: "Test Session",
    messages: [
      {
        id: "msg_1",
        role: "user",
        content: "Hello",
        createdAt: new Date(),
      },
    ],
    selectedModel: "test-model",
    systemMessage: "You are a helpful assistant.",
    agentId: "agent_123",
    createdAt: 1710000000,
    updatedAt: 1710001000,
  };

  const mockAgentSessions = [
    {
      session_id: "session_123",
      session_name: "Test Session",
      started_at: "2024-01-01T00:00:00Z",
      turns: [],
    },
    {
      session_id: "session_456",
      session_name: "Another Session",
      started_at: "2024-01-01T01:00:00Z",
      turns: [],
    },
  ];

  beforeEach(() => {
    jest.clearAllMocks();
    localStorageMock.getItem.mockReturnValue(null);
    localStorageMock.setItem.mockImplementation(() => {});
    mockClient.agents.session.list.mockResolvedValue({
      data: mockAgentSessions,
    });
    mockClient.agents.session.create.mockResolvedValue({
      session_id: "new_session_123",
    });
    mockClient.agents.session.delete.mockResolvedValue(undefined);
    mockClient.agents.session.retrieve.mockResolvedValue({
      session_id: "test-session",
      session_name: "Test Session",
      started_at: new Date().toISOString(),
      turns: [],
    });
    uuidCounter = 0; // Reset UUID counter for consistent test behavior
  });

  describe("Component Rendering", () => {
    test("does not render when no agent is selected", async () => {
      const { container } = await act(async () => {
        return render(
          <Conversations
            selectedAgentId=""
            currentSession={null}
            onSessionChange={mockOnSessionChange}
            onNewSession={mockOnNewSession}
          />
        );
      });

      expect(container.firstChild).toBeNull();
    });

    test("renders loading state initially", async () => {
      mockClient.agents.session.list.mockImplementation(
        () => new Promise(() => {}) // Never resolves to simulate loading
      );

      await act(async () => {
        render(
          <Conversations
            selectedAgentId="agent_123"
            currentSession={null}
            onSessionChange={mockOnSessionChange}
            onNewSession={mockOnNewSession}
          />
        );
      });

      expect(screen.getByText("Select Session")).toBeInTheDocument();
      // When loading, the "+ New" button should be disabled
      expect(screen.getByText("+ New")).toBeDisabled();
    });

    test("renders session selector when agent sessions are loaded", async () => {
      await act(async () => {
        render(
          <Conversations
            selectedAgentId="agent_123"
            currentSession={null}
            onSessionChange={mockOnSessionChange}
            onNewSession={mockOnNewSession}
          />
        );
      });

      await waitFor(() => {
        expect(screen.getByText("Select Session")).toBeInTheDocument();
      });
    });

    test("renders current session name when session is selected", async () => {
      await act(async () => {
        render(
          <Conversations
            selectedAgentId="agent_123"
            currentSession={mockSession}
            onSessionChange={mockOnSessionChange}
            onNewSession={mockOnNewSession}
          />
        );
      });

      await waitFor(() => {
        expect(screen.getByText("Test Session")).toBeInTheDocument();
      });
    });
  });

  describe("Agent API Integration", () => {
    test("loads sessions from agent API on mount", async () => {
      await act(async () => {
        render(
          <Conversations
            selectedAgentId="agent_123"
            currentSession={mockSession}
            onSessionChange={mockOnSessionChange}
            onNewSession={mockOnNewSession}
          />
        );
      });

      await waitFor(() => {
        expect(mockClient.agents.session.list).toHaveBeenCalledWith(
          "agent_123"
        );
      });
    });

    test("handles API errors gracefully", async () => {
      mockClient.agents.session.list.mockRejectedValue(new Error("API Error"));
      const consoleSpy = jest
        .spyOn(console, "error")
        .mockImplementation(() => {});

      await act(async () => {
        render(
          <Conversations
            selectedAgentId="agent_123"
            currentSession={mockSession}
            onSessionChange={mockOnSessionChange}
            onNewSession={mockOnNewSession}
          />
        );
      });

      await waitFor(() => {
        expect(consoleSpy).toHaveBeenCalledWith(
          "Error loading agent sessions:",
          expect.any(Error)
        );
      });

      consoleSpy.mockRestore();
    });
  });

  describe("Error Handling", () => {
    test("component renders without crashing when API is unavailable", async () => {
      mockClient.agents.session.list.mockRejectedValue(
|
||||||
|
new Error("Network Error")
|
||||||
|
);
|
||||||
|
const consoleSpy = jest
|
||||||
|
.spyOn(console, "error")
|
||||||
|
.mockImplementation(() => {});
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
render(
|
||||||
|
<Conversations
|
||||||
|
selectedAgentId="agent_123"
|
||||||
|
currentSession={mockSession}
|
||||||
|
onSessionChange={mockOnSessionChange}
|
||||||
|
onNewSession={mockOnNewSession}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Should still render the session manager with the select trigger
|
||||||
|
expect(screen.getByRole("combobox")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("+ New")).toBeInTheDocument();
|
||||||
|
consoleSpy.mockRestore();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("SessionUtils", () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
localStorageMock.getItem.mockReturnValue(null);
|
||||||
|
localStorageMock.setItem.mockImplementation(() => {});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("saveCurrentSessionId", () => {
|
||||||
|
test("saves session ID to localStorage", () => {
|
||||||
|
SessionUtils.saveCurrentSessionId("test-session-id");
|
||||||
|
|
||||||
|
expect(localStorageMock.setItem).toHaveBeenCalledWith(
|
||||||
|
"chat-playground-current-session",
|
||||||
|
"test-session-id"
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("createDefaultSession", () => {
|
||||||
|
test("creates default session with agent ID", () => {
|
||||||
|
const result = SessionUtils.createDefaultSession("agent_123");
|
||||||
|
|
||||||
|
expect(result).toEqual(
|
||||||
|
expect.objectContaining({
|
||||||
|
name: "Default Session",
|
||||||
|
messages: [],
|
||||||
|
selectedModel: "",
|
||||||
|
systemMessage: "You are a helpful assistant.",
|
||||||
|
agentId: "agent_123",
|
||||||
|
})
|
||||||
|
);
|
||||||
|
expect(result.id).toBeTruthy();
|
||||||
|
expect(result.createdAt).toBeTruthy();
|
||||||
|
expect(result.updatedAt).toBeTruthy();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("creates default session with inherited model", () => {
|
||||||
|
const result = SessionUtils.createDefaultSession(
|
||||||
|
"agent_123",
|
||||||
|
"inherited-model"
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(result.selectedModel).toBe("inherited-model");
|
||||||
|
expect(result.agentId).toBe("agent_123");
|
||||||
|
});
|
||||||
|
|
||||||
|
test("creates unique session IDs", () => {
|
||||||
|
const originalNow = Date.now;
|
||||||
|
let mockTime = 1710005000;
|
||||||
|
Date.now = jest.fn(() => ++mockTime);
|
||||||
|
|
||||||
|
const session1 = SessionUtils.createDefaultSession("agent_123");
|
||||||
|
const session2 = SessionUtils.createDefaultSession("agent_123");
|
||||||
|
|
||||||
|
expect(session1.id).not.toBe(session2.id);
|
||||||
|
|
||||||
|
Date.now = originalNow;
|
||||||
|
});
|
||||||
|
|
||||||
|
test("sets creation and update timestamps", () => {
|
||||||
|
const result = SessionUtils.createDefaultSession("agent_123");
|
||||||
|
|
||||||
|
expect(result.createdAt).toBeTruthy();
|
||||||
|
expect(result.updatedAt).toBeTruthy();
|
||||||
|
expect(typeof result.createdAt).toBe("number");
|
||||||
|
expect(typeof result.updatedAt).toBe("number");
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
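The tests above depend on a `mockClient` and two callback mocks defined earlier in this file, outside this excerpt. A minimal sketch of the shape they assume (hypothetical names matching the usage above; only the methods the component calls need to be jest.fn() mocks):

const mockClient = {
  agents: {
    session: {
      list: jest.fn(),
      create: jest.fn(),
      delete: jest.fn(),
      retrieve: jest.fn(),
    },
  },
};
const mockOnSessionChange = jest.fn();
const mockOnNewSession = jest.fn();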
568  llama_stack/ui/components/chat-playground/conversations.tsx (new file)

@@ -0,0 +1,568 @@
"use client";

import { useState, useEffect, useCallback } from "react";
import { Button } from "@/components/ui/button";
import {
  Select,
  SelectContent,
  SelectItem,
  SelectTrigger,
  SelectValue,
} from "@/components/ui/select";
import { Input } from "@/components/ui/input";
import { Card } from "@/components/ui/card";
import { Trash2 } from "lucide-react";
import type { Message } from "@/components/chat-playground/chat-message";
import { useAuthClient } from "@/hooks/use-auth-client";
import type {
  Session,
  SessionCreateParams,
} from "llama-stack-client/resources/agents";

export interface ChatSession {
  id: string;
  name: string;
  messages: Message[];
  selectedModel: string;
  systemMessage: string;
  agentId: string;
  session?: Session;
  createdAt: number;
  updatedAt: number;
}

interface SessionManagerProps {
  currentSession: ChatSession | null;
  onSessionChange: (session: ChatSession) => void;
  onNewSession: () => void;
  selectedAgentId: string;
}

const CURRENT_SESSION_KEY = "chat-playground-current-session";

// ensures this only happens client side
const safeLocalStorage = {
  getItem: (key: string): string | null => {
    if (typeof window === "undefined") return null;
    try {
      return localStorage.getItem(key);
    } catch (err) {
      console.error("Error accessing localStorage:", err);
      return null;
    }
  },
  setItem: (key: string, value: string): void => {
    if (typeof window === "undefined") return;
    try {
      localStorage.setItem(key, value);
    } catch (err) {
      console.error("Error writing to localStorage:", err);
    }
  },
  removeItem: (key: string): void => {
    if (typeof window === "undefined") return;
    try {
      localStorage.removeItem(key);
    } catch (err) {
      console.error("Error removing from localStorage:", err);
    }
  },
};

const generateSessionId = (): string => {
  return globalThis.crypto.randomUUID();
};

export function Conversations({
  currentSession,
  onSessionChange,
  onNewSession,
  selectedAgentId,
}: SessionManagerProps) {
  const [sessions, setSessions] = useState<ChatSession[]>([]);
  const [showCreateForm, setShowCreateForm] = useState(false);
  const [newSessionName, setNewSessionName] = useState("");
  const [loading, setLoading] = useState(false);
  const client = useAuthClient();

  const loadAgentSessions = useCallback(async () => {
    if (!selectedAgentId) return;

    setLoading(true);
    try {
      const response = await client.agents.session.list(selectedAgentId);
      console.log("Sessions response:", response);

      if (!response.data || !Array.isArray(response.data)) {
        console.warn("Invalid sessions response, starting fresh");
        setSessions([]);
        return;
      }

      const agentSessions: ChatSession[] = response.data
        .filter(sessionData => {
          const isValid =
            sessionData &&
            typeof sessionData === "object" &&
            sessionData.session_id &&
            sessionData.session_name;
          if (!isValid) {
            console.warn("Filtering out invalid session:", sessionData);
          }
          return isValid;
        })
        .map(sessionData => ({
          id: sessionData.session_id,
          name: sessionData.session_name,
          messages: [],
          selectedModel: currentSession?.selectedModel || "",
          systemMessage:
            currentSession?.systemMessage || "You are a helpful assistant.",
          agentId: selectedAgentId,
          session: sessionData,
          createdAt: sessionData.started_at
            ? new Date(sessionData.started_at).getTime()
            : Date.now(),
          updatedAt: sessionData.started_at
            ? new Date(sessionData.started_at).getTime()
            : Date.now(),
        }));
      setSessions(agentSessions);
    } catch (error) {
      console.error("Error loading agent sessions:", error);
      setSessions([]);
    } finally {
      setLoading(false);
    }
  }, [
    selectedAgentId,
    client,
    currentSession?.selectedModel,
    currentSession?.systemMessage,
  ]);

  useEffect(() => {
    if (selectedAgentId) {
      loadAgentSessions();
    }
  }, [selectedAgentId, loadAgentSessions]);

  const createNewSession = async () => {
    if (!selectedAgentId) return;

    const sessionName =
      newSessionName.trim() || `Session ${sessions.length + 1}`;
    setLoading(true);

    try {
      const response = await client.agents.session.create(selectedAgentId, {
        session_name: sessionName,
      } as SessionCreateParams);

      const newSession: ChatSession = {
        id: response.session_id,
        name: sessionName,
        messages: [],
        selectedModel: currentSession?.selectedModel || "",
        systemMessage:
          currentSession?.systemMessage || "You are a helpful assistant.",
        agentId: selectedAgentId,
        createdAt: Date.now(),
        updatedAt: Date.now(),
      };

      setSessions(prev => [...prev, newSession]);
      SessionUtils.saveCurrentSessionId(newSession.id, selectedAgentId);
      onSessionChange(newSession);

      setNewSessionName("");
      setShowCreateForm(false);
    } catch (error) {
      console.error("Error creating session:", error);
    } finally {
      setLoading(false);
    }
  };

  const loadSessionMessages = useCallback(
    async (agentId: string, sessionId: string): Promise<Message[]> => {
      try {
        const session = await client.agents.session.retrieve(
          agentId,
          sessionId
        );

        if (!session || !session.turns || !Array.isArray(session.turns)) {
          return [];
        }

        const messages: Message[] = [];
        for (const turn of session.turns) {
          // Add user messages from input_messages
          if (turn.input_messages && Array.isArray(turn.input_messages)) {
            for (const input of turn.input_messages) {
              if (input.role === "user" && input.content) {
                messages.push({
                  id: `${turn.turn_id}-user-${messages.length}`,
                  role: "user",
                  content:
                    typeof input.content === "string"
                      ? input.content
                      : JSON.stringify(input.content),
                  createdAt: new Date(turn.started_at || Date.now()),
                });
              }
            }
          }

          // Add assistant message from output_message
          if (turn.output_message && turn.output_message.content) {
            messages.push({
              id: `${turn.turn_id}-assistant-${messages.length}`,
              role: "assistant",
              content:
                typeof turn.output_message.content === "string"
                  ? turn.output_message.content
                  : JSON.stringify(turn.output_message.content),
              createdAt: new Date(
                turn.completed_at || turn.started_at || Date.now()
              ),
            });
          }
        }

        return messages;
      } catch (error) {
        console.error("Error loading session messages:", error);
        return [];
      }
    },
    [client]
  );

  const switchToSession = useCallback(
    async (sessionId: string) => {
      const session = sessions.find(s => s.id === sessionId);
      if (session) {
        setLoading(true);
        try {
          // Load messages for this session
          const messages = await loadSessionMessages(
            selectedAgentId,
            sessionId
          );
          const sessionWithMessages = {
            ...session,
            messages,
          };

          SessionUtils.saveCurrentSessionId(sessionId, selectedAgentId);
          onSessionChange(sessionWithMessages);
        } catch (error) {
          console.error("Error switching to session:", error);
          // Fallback to session without messages
          SessionUtils.saveCurrentSessionId(sessionId, selectedAgentId);
          onSessionChange(session);
        } finally {
          setLoading(false);
        }
      }
    },
    [sessions, selectedAgentId, loadSessionMessages, onSessionChange]
  );

  const deleteSession = async (sessionId: string) => {
    if (sessions.length <= 1 || !selectedAgentId) {
      return;
    }

    if (
      confirm(
        "Are you sure you want to delete this session? This action cannot be undone."
      )
    ) {
      setLoading(true);
      try {
        await client.agents.session.delete(selectedAgentId, sessionId);

        const updatedSessions = sessions.filter(s => s.id !== sessionId);
        setSessions(updatedSessions);

        if (currentSession?.id === sessionId) {
          const newCurrentSession = updatedSessions[0] || null;
          if (newCurrentSession) {
            SessionUtils.saveCurrentSessionId(
              newCurrentSession.id,
              selectedAgentId
            );
            onSessionChange(newCurrentSession);
          } else {
            SessionUtils.clearCurrentSession(selectedAgentId);
            onNewSession();
          }
        }
      } catch (error) {
        console.error("Error deleting session:", error);
      } finally {
        setLoading(false);
      }
    }
  };

  useEffect(() => {
    if (currentSession) {
      setSessions(prevSessions => {
        const updatedSessions = prevSessions.map(session =>
          session.id === currentSession.id ? currentSession : session
        );

        if (!prevSessions.find(s => s.id === currentSession.id)) {
          updatedSessions.push(currentSession);
        }

        return updatedSessions;
      });
    }
  }, [currentSession]);

  // Don't render if no agent is selected
  if (!selectedAgentId) {
    return null;
  }

  return (
    <div className="relative">
      <div className="flex items-center gap-2">
        <Select
          value={currentSession?.id || ""}
          onValueChange={switchToSession}
        >
          <SelectTrigger className="w-[200px]">
            <SelectValue placeholder="Select Session" />
          </SelectTrigger>
          <SelectContent>
            {sessions.map(session => (
              <SelectItem key={session.id} value={session.id}>
                {session.name}
              </SelectItem>
            ))}
          </SelectContent>
        </Select>

        <Button
          onClick={() => setShowCreateForm(true)}
          variant="outline"
          size="sm"
          disabled={loading || !selectedAgentId}
        >
          + New
        </Button>

        {currentSession && sessions.length > 1 && (
          <Button
            onClick={() => deleteSession(currentSession.id)}
            variant="outline"
            size="sm"
            className="text-destructive hover:text-destructive hover:bg-destructive/10"
            title="Delete current session"
          >
            <Trash2 className="h-3 w-3" />
          </Button>
        )}
      </div>

      {showCreateForm && (
        <Card className="absolute top-full left-0 mt-2 p-4 space-y-3 w-80 z-50 bg-background border shadow-lg">
          <h3 className="text-md font-semibold">Create New Session</h3>

          <Input
            value={newSessionName}
            onChange={e => setNewSessionName(e.target.value)}
            placeholder="Session name (optional)"
            onKeyDown={e => {
              if (e.key === "Enter") {
                createNewSession();
              } else if (e.key === "Escape") {
                setShowCreateForm(false);
                setNewSessionName("");
              }
            }}
          />

          <div className="flex gap-2">
            <Button
              onClick={createNewSession}
              className="flex-1"
              disabled={loading}
            >
              {loading ? "Creating..." : "Create"}
            </Button>
            <Button
              variant="outline"
              onClick={() => {
                setShowCreateForm(false);
                setNewSessionName("");
              }}
              className="flex-1"
            >
              Cancel
            </Button>
          </div>
        </Card>
      )}

      {currentSession && sessions.length > 1 && (
        <div className="absolute top-full left-0 mt-1 text-xs text-gray-500 whitespace-nowrap">
          {sessions.length} sessions • Current: {currentSession.name}
          {currentSession.messages.length > 0 &&
            ` • ${currentSession.messages.length} messages`}
        </div>
      )}
    </div>
  );
}

export const SessionUtils = {
  loadCurrentSessionId: (agentId?: string): string | null => {
    const key = agentId
      ? `${CURRENT_SESSION_KEY}-${agentId}`
      : CURRENT_SESSION_KEY;
    return safeLocalStorage.getItem(key);
  },

  saveCurrentSessionId: (sessionId: string, agentId?: string) => {
    const key = agentId
      ? `${CURRENT_SESSION_KEY}-${agentId}`
      : CURRENT_SESSION_KEY;
    safeLocalStorage.setItem(key, sessionId);
  },

  createDefaultSession: (
    agentId: string,
    inheritModel?: string
  ): ChatSession => ({
    id: generateSessionId(),
    name: "Default Session",
    messages: [],
    selectedModel: inheritModel || "",
    systemMessage: "You are a helpful assistant.",
    agentId,
    createdAt: Date.now(),
    updatedAt: Date.now(),
  }),

  clearCurrentSession: (agentId?: string) => {
    const key = agentId
      ? `${CURRENT_SESSION_KEY}-${agentId}`
      : CURRENT_SESSION_KEY;
    safeLocalStorage.removeItem(key);
  },

  loadCurrentAgentId: (): string | null => {
    return safeLocalStorage.getItem("chat-playground-current-agent");
  },

  saveCurrentAgentId: (agentId: string) => {
    safeLocalStorage.setItem("chat-playground-current-agent", agentId);
  },

  // Comprehensive session caching
  saveSessionData: (agentId: string, sessionData: ChatSession) => {
    const key = `chat-playground-session-data-${agentId}-${sessionData.id}`;
    safeLocalStorage.setItem(
      key,
      JSON.stringify({
        ...sessionData,
        cachedAt: Date.now(),
      })
    );
  },

  loadSessionData: (agentId: string, sessionId: string): ChatSession | null => {
    const key = `chat-playground-session-data-${agentId}-${sessionId}`;
    const cached = safeLocalStorage.getItem(key);
    if (!cached) return null;

    try {
      const data = JSON.parse(cached);
      // Check if cache is fresh (less than 1 hour old)
      const cacheAge = Date.now() - (data.cachedAt || 0);
      if (cacheAge > 60 * 60 * 1000) {
        safeLocalStorage.removeItem(key);
        return null;
      }

      // Convert date strings back to Date objects
      return {
        ...data,
        messages: data.messages.map(
          (msg: { createdAt: string; [key: string]: unknown }) => ({
            ...msg,
            createdAt: new Date(msg.createdAt),
          })
        ),
      };
    } catch (error) {
      console.error("Error parsing cached session data:", error);
      safeLocalStorage.removeItem(key);
      return null;
    }
  },

  // Agent config caching
  saveAgentConfig: (
    agentId: string,
    config: {
      toolgroups?: Array<
        string | { name: string; args: Record<string, unknown> }
      >;
      [key: string]: unknown;
    }
  ) => {
    const key = `chat-playground-agent-config-${agentId}`;
    safeLocalStorage.setItem(
      key,
      JSON.stringify({
        config,
        cachedAt: Date.now(),
      })
    );
  },

  loadAgentConfig: (
    agentId: string
  ): {
    toolgroups?: Array<
      string | { name: string; args: Record<string, unknown> }
    >;
    [key: string]: unknown;
  } | null => {
    const key = `chat-playground-agent-config-${agentId}`;
    const cached = safeLocalStorage.getItem(key);
    if (!cached) return null;

    try {
      const data = JSON.parse(cached);
      // Check if cache is fresh (less than 30 minutes old)
      const cacheAge = Date.now() - (data.cachedAt || 0);
      if (cacheAge > 30 * 60 * 1000) {
        safeLocalStorage.removeItem(key);
        return null;
      }
      return data.config;
    } catch (error) {
      console.error("Error parsing cached agent config:", error);
      safeLocalStorage.removeItem(key);
      return null;
    }
  },

  // Clear all cached data for an agent
  clearAgentCache: (agentId: string) => {
    const keys = Object.keys(localStorage).filter(
      key =>
        key.includes(`chat-playground-session-data-${agentId}`) ||
        key.includes(`chat-playground-agent-config-${agentId}`)
    );
    keys.forEach(key => safeLocalStorage.removeItem(key));
  },
};
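For reference, a minimal sketch of how a parent page might consume `Conversations` and `SessionUtils`. `PlaygroundPage` and the hard-coded agent ID are hypothetical; the real chat playground page is not part of this diff:

"use client";

import { useState } from "react";
import {
  Conversations,
  SessionUtils,
  type ChatSession,
} from "@/components/chat-playground/conversations";

// Hypothetical parent: owns the current session and hands the component
// callbacks for replacing it when the user switches, creates, or deletes.
export function PlaygroundPage() {
  const agentId = "agent_123"; // placeholder agent ID
  const [session, setSession] = useState<ChatSession | null>(() =>
    SessionUtils.createDefaultSession(agentId)
  );

  return (
    <Conversations
      selectedAgentId={agentId}
      currentSession={session}
      onSessionChange={setSession}
      onNewSession={() =>
        setSession(SessionUtils.createDefaultSession(agentId))
      }
    />
  );
}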
@@ -5,9 +5,9 @@ export function TypingIndicator() {
     <div className="justify-left flex space-x-1">
       <div className="rounded-lg bg-muted p-3">
         <div className="flex -space-x-2.5">
-          <Dot className="h-5 w-5 animate-typing-dot-bounce" />
-          <Dot className="h-5 w-5 animate-typing-dot-bounce [animation-delay:90ms]" />
-          <Dot className="h-5 w-5 animate-typing-dot-bounce [animation-delay:180ms]" />
+          <Dot className="h-5 w-5 animate-typing-dot-1" />
+          <Dot className="h-5 w-5 animate-typing-dot-2" />
+          <Dot className="h-5 w-5 animate-typing-dot-3" />
         </div>
       </div>
     </div>
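The new `animate-typing-dot-1/2/3` utilities replace one shared `animate-typing-dot-bounce` class plus inline `[animation-delay:...]` arbitrary values, which implies named animations in the Tailwind theme. A hypothetical sketch of such definitions; the project's actual Tailwind config, keyframes, and timings are not shown in this diff:

// tailwind.config.ts (hypothetical excerpt)
import type { Config } from "tailwindcss";

const config: Config = {
  content: ["./app/**/*.{ts,tsx}", "./components/**/*.{ts,tsx}"], // placeholder globs
  theme: {
    extend: {
      keyframes: {
        "typing-dot-bounce": {
          "0%, 80%, 100%": { transform: "translateY(0)" },
          "40%": { transform: "translateY(-0.25rem)" },
        },
      },
      animation: {
        // each utility reuses the same keyframes with a staggered delay
        "typing-dot-1": "typing-dot-bounce 1.2s ease-in-out infinite",
        "typing-dot-2": "typing-dot-bounce 1.2s ease-in-out 0.15s infinite",
        "typing-dot-3": "typing-dot-bounce 1.2s ease-in-out 0.3s infinite",
      },
    },
  },
};

export default config;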
@@ -11,6 +11,7 @@ import {
 } from "lucide-react";
 import Link from "next/link";
 import { usePathname } from "next/navigation";
+import Image from "next/image";
 import { cn } from "@/lib/utils";

 import {
@@ -110,7 +111,16 @@ export function AppSidebar() {
   return (
     <Sidebar>
       <SidebarHeader>
-        <Link href="/">Llama Stack</Link>
+        <Link href="/" className="flex items-center gap-2 p-2">
+          <Image
+            src="/logo.webp"
+            alt="Llama Stack"
+            width={32}
+            height={32}
+            className="h-8 w-8"
+          />
+          <span className="font-semibold text-lg">Llama Stack</span>
+        </Link>
       </SidebarHeader>
       <SidebarContent>
         <SidebarGroup>
BIN  llama_stack/ui/public/favicon.ico (new file, 4.2 KiB; binary file not shown)
BIN  llama_stack/ui/public/logo.webp (new file, 19 KiB; binary file not shown)
@@ -98,6 +98,7 @@ unit = [
     "together",
     "coverage",
     "chromadb>=1.0.15",
+    "moto[s3]>=5.1.10",
 ]
 # These are the core dependencies required for running integration tests. They are shared across all
 # providers. If a provider requires additional dependencies, please add them to your environment
@@ -157,12 +157,14 @@ def get_config_class_info(config_class_path: str) -> dict[str, Any]:
     }


-def generate_provider_docs(provider_spec: Any, api_name: str) -> str:
+def generate_provider_docs(progress, provider_spec: Any, api_name: str) -> str:
     """Generate markdown documentation for a provider."""
     provider_type = provider_spec.provider_type
     config_class = provider_spec.config_class

     config_info = get_config_class_info(config_class)
+    if "error" in config_info:
+        progress.print(config_info["error"])

     md_lines = []
     md_lines.append(f"# {provider_type}")
@@ -295,7 +297,7 @@ def process_provider_registry(progress, change_tracker: ChangedPathTracker) -> None:
     filename = provider_type.replace("::", "_").replace(":", "_")
     provider_doc_file = doc_output_dir / f"{filename}.md"

-    provider_docs = generate_provider_docs(provider, api_name)
+    provider_docs = generate_provider_docs(progress, provider, api_name)

     provider_doc_file.write_text(provider_docs)
     change_tracker.add_paths(provider_doc_file)
17  scripts/run-ui-linter.sh (new executable file)

@@ -0,0 +1,17 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

set -e
cd llama_stack/ui

if [ ! -d node_modules ] || [ ! -x node_modules/.bin/prettier ] || [ ! -x node_modules/.bin/eslint ]; then
  echo "UI dependencies not installed, skipping prettier/linter check"
  exit 0
fi

npm run format
npm run lint
251  tests/unit/providers/files/test_s3_files.py (new file)

@@ -0,0 +1,251 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from unittest.mock import patch

import boto3
import pytest
from botocore.exceptions import ClientError
from moto import mock_aws

from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.files import OpenAIFilePurpose
from llama_stack.providers.remote.files.s3 import (
    S3FilesImplConfig,
    get_adapter_impl,
)
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig


class MockUploadFile:
    def __init__(self, content: bytes, filename: str, content_type: str = "text/plain"):
        self.content = content
        self.filename = filename
        self.content_type = content_type

    async def read(self):
        return self.content


@pytest.fixture
def s3_config(tmp_path):
    db_path = tmp_path / "s3_files_metadata.db"

    return S3FilesImplConfig(
        bucket_name="test-bucket",
        region="not-a-region",
        auto_create_bucket=True,
        metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
    )


@pytest.fixture
def s3_client():
    """Create a mocked S3 client for testing."""
    # we use `with mock_aws()` because @mock_aws decorator does not support being a generator
    with mock_aws():
        # must yield or the mock will be reset before it is used
        yield boto3.client("s3")


@pytest.fixture
async def s3_provider(s3_config, s3_client):
    """Create an S3 files provider with mocked S3 for testing."""
    provider = await get_adapter_impl(s3_config, {})
    yield provider
    await provider.shutdown()


@pytest.fixture
def sample_text_file():
    content = b"Hello, this is a test file for the S3 Files API!"
    return MockUploadFile(content, "sample_text_file.txt")


class TestS3FilesImpl:
    """Test suite for S3 Files implementation."""

    async def test_upload_file(self, s3_provider, sample_text_file, s3_client, s3_config):
        """Test successful file upload."""
        sample_text_file.filename = "test_upload_file"
        result = await s3_provider.openai_upload_file(
            file=sample_text_file,
            purpose=OpenAIFilePurpose.ASSISTANTS,
        )

        assert result.filename == sample_text_file.filename
        assert result.purpose == OpenAIFilePurpose.ASSISTANTS
        assert result.bytes == len(sample_text_file.content)
        assert result.id.startswith("file-")

        # Verify file exists in S3 backend
        response = s3_client.head_object(Bucket=s3_config.bucket_name, Key=result.id)
        assert response["ResponseMetadata"]["HTTPStatusCode"] == 200

    async def test_list_files_empty(self, s3_provider):
        """Test listing files when no files exist."""
        result = await s3_provider.openai_list_files()

        assert len(result.data) == 0
        assert not result.has_more
        assert result.first_id == ""
        assert result.last_id == ""

    async def test_retrieve_file(self, s3_provider, sample_text_file):
        """Test retrieving file metadata."""
        sample_text_file.filename = "test_retrieve_file"
        uploaded = await s3_provider.openai_upload_file(
            file=sample_text_file,
            purpose=OpenAIFilePurpose.ASSISTANTS,
        )

        retrieved = await s3_provider.openai_retrieve_file(uploaded.id)

        assert retrieved.id == uploaded.id
        assert retrieved.filename == uploaded.filename
        assert retrieved.purpose == uploaded.purpose
        assert retrieved.bytes == uploaded.bytes

    async def test_retrieve_file_content(self, s3_provider, sample_text_file):
        """Test retrieving file content."""
        sample_text_file.filename = "test_retrieve_file_content"
        uploaded = await s3_provider.openai_upload_file(
            file=sample_text_file,
            purpose=OpenAIFilePurpose.ASSISTANTS,
        )

        response = await s3_provider.openai_retrieve_file_content(uploaded.id)

        assert response.body == sample_text_file.content
        assert response.headers["Content-Disposition"] == f'attachment; filename="{sample_text_file.filename}"'

    async def test_delete_file(self, s3_provider, sample_text_file, s3_config, s3_client):
        """Test deleting a file."""
        sample_text_file.filename = "test_delete_file"
        uploaded = await s3_provider.openai_upload_file(
            file=sample_text_file,
            purpose=OpenAIFilePurpose.ASSISTANTS,
        )

        delete_response = await s3_provider.openai_delete_file(uploaded.id)

        assert delete_response.id == uploaded.id
        assert delete_response.deleted is True

        with pytest.raises(ResourceNotFoundError, match="not found"):
            await s3_provider.openai_retrieve_file(uploaded.id)

        # Verify file is gone from S3 backend
        with pytest.raises(ClientError) as exc_info:
            s3_client.head_object(Bucket=s3_config.bucket_name, Key=uploaded.id)
        assert exc_info.value.response["Error"]["Code"] == "404"

    async def test_list_files(self, s3_provider, sample_text_file):
        """Test listing files after uploading some."""
        sample_text_file.filename = "test_list_files_with_content_file1"
        file1 = await s3_provider.openai_upload_file(
            file=sample_text_file,
            purpose=OpenAIFilePurpose.ASSISTANTS,
        )

        file2_content = MockUploadFile(b"Second file content", "test_list_files_with_content_file2")
        file2 = await s3_provider.openai_upload_file(
            file=file2_content,
            purpose=OpenAIFilePurpose.BATCH,
        )

        result = await s3_provider.openai_list_files()

        assert len(result.data) == 2
        file_ids = {f.id for f in result.data}
        assert file1.id in file_ids
        assert file2.id in file_ids

    async def test_list_files_with_purpose_filter(self, s3_provider, sample_text_file):
        """Test listing files with purpose filter."""
        sample_text_file.filename = "test_list_files_with_purpose_filter_file1"
        file1 = await s3_provider.openai_upload_file(
            file=sample_text_file,
            purpose=OpenAIFilePurpose.ASSISTANTS,
        )

        file2_content = MockUploadFile(b"Batch file content", "test_list_files_with_purpose_filter_file2")
        await s3_provider.openai_upload_file(
            file=file2_content,
            purpose=OpenAIFilePurpose.BATCH,
        )

        result = await s3_provider.openai_list_files(purpose=OpenAIFilePurpose.ASSISTANTS)

        assert len(result.data) == 1
        assert result.data[0].id == file1.id
        assert result.data[0].purpose == OpenAIFilePurpose.ASSISTANTS

    async def test_nonexistent_file_retrieval(self, s3_provider):
        """Test retrieving a non-existent file raises error."""
        with pytest.raises(ResourceNotFoundError, match="not found"):
            await s3_provider.openai_retrieve_file("file-nonexistent")

    async def test_nonexistent_file_content_retrieval(self, s3_provider):
        """Test retrieving content of a non-existent file raises error."""
        with pytest.raises(ResourceNotFoundError, match="not found"):
            await s3_provider.openai_retrieve_file_content("file-nonexistent")

    async def test_nonexistent_file_deletion(self, s3_provider):
        """Test deleting a non-existent file raises error."""
        with pytest.raises(ResourceNotFoundError, match="not found"):
            await s3_provider.openai_delete_file("file-nonexistent")

    async def test_upload_file_without_filename(self, s3_provider, sample_text_file):
        """Test uploading a file without a filename uses the fallback."""
        del sample_text_file.filename
        result = await s3_provider.openai_upload_file(
            file=sample_text_file,
            purpose=OpenAIFilePurpose.ASSISTANTS,
        )

        assert result.purpose == OpenAIFilePurpose.ASSISTANTS
        assert result.bytes == len(sample_text_file.content)

        retrieved = await s3_provider.openai_retrieve_file(result.id)
        assert retrieved.filename == result.filename

    async def test_file_operations_when_s3_object_deleted(self, s3_provider, sample_text_file, s3_config, s3_client):
        """Test file operations when S3 object is deleted but metadata exists (negative test)."""
        sample_text_file.filename = "test_orphaned_metadata"
        uploaded = await s3_provider.openai_upload_file(
            file=sample_text_file,
            purpose=OpenAIFilePurpose.ASSISTANTS,
        )

        # Directly delete the S3 object from the backend
        s3_client.delete_object(Bucket=s3_config.bucket_name, Key=uploaded.id)

        with pytest.raises(ResourceNotFoundError, match="not found") as exc_info:
            await s3_provider.openai_retrieve_file_content(uploaded.id)
        assert uploaded.id in str(exc_info).lower()

        listed_files = await s3_provider.openai_list_files()
        assert uploaded.id not in [file.id for file in listed_files.data]

    async def test_upload_file_s3_put_object_failure(self, s3_provider, sample_text_file, s3_config, s3_client):
        """Test that put_object failure results in exception and no orphaned metadata."""
        sample_text_file.filename = "test_s3_put_object_failure"

        def failing_put_object(*args, **kwargs):
            raise ClientError(
                error_response={"Error": {"Code": "SolarRadiation", "Message": "Bloop"}}, operation_name="PutObject"
            )

        with patch.object(s3_provider.client, "put_object", side_effect=failing_put_object):
            with pytest.raises(RuntimeError, match="Failed to upload file to S3"):
                await s3_provider.openai_upload_file(
                    file=sample_text_file,
                    purpose=OpenAIFilePurpose.ASSISTANTS,
                )

        files_list = await s3_provider.openai_list_files()
        assert len(files_list.data) == 0, "No file metadata should remain after failed upload"
109  uv.lock (generated)

@@ -347,6 +347,34 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/ed/4d/1392562369b1139e741b30d624f09fe7091d17dd5579fae5732f044b12bb/blobfile-3.0.0-py3-none-any.whl", hash = "sha256:48ecc3307e622804bd8fe13bf6f40e6463c4439eba7a1f9ad49fd78aa63cc658", size = 75413, upload-time = "2024-08-27T00:02:51.518Z" },
 ]
 
+[[package]]
+name = "boto3"
+version = "1.40.12"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "botocore" },
+    { name = "jmespath" },
+    { name = "s3transfer" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/41/19/2c4d140a7f99b5903b21b9ccd7253c71f147c346c3c632b2117444cf2d65/boto3-1.40.12.tar.gz", hash = "sha256:c6b32aee193fbd2eb84696d2b5b2410dcda9fb4a385e1926cff908377d222247", size = 111959, upload-time = "2025-08-18T19:30:23.827Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/16/6e/5a9dcf38ad87838fb99742c4a3ab1b7507ad3a02c8c27a9ccda7a0bb5709/boto3-1.40.12-py3-none-any.whl", hash = "sha256:3c3d6731390b5b11f5e489d5d9daa57f0c3e171efb63ac8f47203df9c71812b3", size = 140075, upload-time = "2025-08-18T19:30:22.494Z" },
+]
+
+[[package]]
+name = "botocore"
+version = "1.40.12"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "jmespath" },
+    { name = "python-dateutil" },
+    { name = "urllib3" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/7d/b2/7933590fc5bca1980801b71e09db1a95581afff177cbf3c8a031d922885c/botocore-1.40.12.tar.gz", hash = "sha256:c6560578e799b47b762b7e555bd9c5dd5c29c5d23bd778a8a72e98c979b3c727", size = 14349930, upload-time = "2025-08-18T19:30:13.794Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/1e/b6/65fd6e718c9538ba1462c9b71e9262bc723202ff203fe64ff66ff676d823/botocore-1.40.12-py3-none-any.whl", hash = "sha256:84e96004a8b426c5508f6b5600312d6271364269466a3a957dc377ad8effc438", size = 14018004, upload-time = "2025-08-18T19:30:09.054Z" },
+]
+
 [[package]]
 name = "braintrust-core"
 version = "0.0.59"
@@ -1580,6 +1608,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/b3/4a/4175a563579e884192ba6e81725fc0448b042024419be8d83aa8a80a3f44/jiter-0.10.0-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3aa96f2abba33dc77f79b4cf791840230375f9534e5fac927ccceb58c5e604a5", size = 354213, upload-time = "2025-05-18T19:04:41.894Z" },
 ]
 
+[[package]]
+name = "jmespath"
+version = "1.0.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/00/2a/e867e8531cf3e36b41201936b7fa7ba7b5702dbef42922193f05c8976cd6/jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe", size = 25843, upload-time = "2022-06-17T18:00:12.224Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/31/b4/b9b800c45527aadd64d5b442f9b932b00648617eb5d63d2c7a6587b7cafc/jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980", size = 20256, upload-time = "2022-06-17T18:00:10.251Z" },
+]
+
 [[package]]
 name = "jsonschema"
 version = "4.25.0"
@@ -1820,6 +1857,7 @@ unit = [
     { name = "litellm" },
     { name = "mcp" },
     { name = "milvus-lite" },
+    { name = "moto", extra = ["s3"] },
     { name = "ollama" },
     { name = "openai" },
     { name = "pymilvus" },
@@ -1937,6 +1975,7 @@ unit = [
     { name = "litellm" },
     { name = "mcp" },
     { name = "milvus-lite", specifier = ">=2.5.0" },
+    { name = "moto", extras = ["s3"], specifier = ">=5.1.10" },
     { name = "ollama" },
     { name = "openai" },
     { name = "pymilvus", specifier = ">=2.5.12" },
@@ -2224,6 +2263,32 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/16/71/4ad9a42f2772793a03cb698f0fc42499f04e6e8d2560ba2f7da0fb059a8e/mmh3-5.1.0-cp313-cp313-win_arm64.whl", hash = "sha256:b22fe2e54be81f6c07dcb36b96fa250fb72effe08aa52fbb83eade6e1e2d5fd7", size = 38890, upload-time = "2025-01-25T08:39:25.28Z" },
 ]
 
+[[package]]
+name = "moto"
+version = "5.1.10"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "boto3" },
+    { name = "botocore" },
+    { name = "cryptography" },
+    { name = "jinja2" },
+    { name = "python-dateutil" },
+    { name = "requests" },
+    { name = "responses" },
+    { name = "werkzeug" },
+    { name = "xmltodict" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/c4/72/9bc9b4917b816f5a82fc8f0fbd477c2a669d35a7d7941ae15a5411e266d6/moto-5.1.10.tar.gz", hash = "sha256:d6bdc8f82a1e503502927cc0a3da22014f836094d0bf399bb0f695754ae6c7a6", size = 7087004, upload-time = "2025-08-11T20:59:45.542Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/c4/37/9b9cb5597eecc2ebfde2f65a8265f3669f6724ebe82bf9b155a3421039f8/moto-5.1.10-py3-none-any.whl", hash = "sha256:9ec1a21a924f97470af225b2bfa854fe46c1ad30fb44655eba458206dedf28b5", size = 5246859, upload-time = "2025-08-11T20:59:43.22Z" },
+]
+
+[package.optional-dependencies]
+s3 = [
+    { name = "py-partiql-parser" },
+    { name = "pyyaml" },
+]
+
 [[package]]
 name = "mpmath"
 version = "1.3.0"
@@ -3068,6 +3133,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl", hash = "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0", size = 11842, upload-time = "2024-07-21T12:58:20.04Z" },
 ]
 
+[[package]]
+name = "py-partiql-parser"
+version = "0.6.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/58/a1/0a2867e48b232b4f82c4929ef7135f2a5d72c3886b957dccf63c70aa2fcb/py_partiql_parser-0.6.1.tar.gz", hash = "sha256:8583ff2a0e15560ef3bc3df109a7714d17f87d81d33e8c38b7fed4e58a63215d", size = 17120, upload-time = "2024-12-25T22:06:41.327Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/97/84/0e410c20bbe9a504fc56e97908f13261c2b313d16cbb3b738556166f044a/py_partiql_parser-0.6.1-py2.py3-none-any.whl", hash = "sha256:ff6a48067bff23c37e9044021bf1d949c83e195490c17e020715e927fe5b2456", size = 23520, upload-time = "2024-12-25T22:06:39.106Z" },
+]
+
 [[package]]
 name = "pyaml"
 version = "25.7.0"
@@ -3788,6 +3862,20 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/3b/5d/63d4ae3b9daea098d5d6f5da83984853c1bbacd5dc826764b249fe119d24/requests_oauthlib-2.0.0-py2.py3-none-any.whl", hash = "sha256:7dd8a5c40426b779b0868c404bdef9768deccf22749cde15852df527e6269b36", size = 24179, upload-time = "2024-03-22T20:32:28.055Z" },
 ]
 
+[[package]]
+name = "responses"
+version = "0.25.8"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "pyyaml" },
+    { name = "requests" },
+    { name = "urllib3" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/0e/95/89c054ad70bfef6da605338b009b2e283485835351a9935c7bfbfaca7ffc/responses-0.25.8.tar.gz", hash = "sha256:9374d047a575c8f781b94454db5cab590b6029505f488d12899ddb10a4af1cf4", size = 79320, upload-time = "2025-08-08T19:01:46.709Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/1c/4c/cc276ce57e572c102d9542d383b2cfd551276581dc60004cb94fe8774c11/responses-0.25.8-py3-none-any.whl", hash = "sha256:0c710af92def29c8352ceadff0c3fe340ace27cf5af1bbe46fb71275bcd2831c", size = 34769, upload-time = "2025-08-08T19:01:45.018Z" },
+]
+
 [[package]]
 name = "rich"
 version = "14.1.0"
@@ -3961,6 +4049,18 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/00/db/c376b0661c24cf770cb8815268190668ec1330eba8374a126ceef8c72d55/ruff-0.12.5-py3-none-win_arm64.whl", hash = "sha256:48cdbfc633de2c5c37d9f090ba3b352d1576b0015bfc3bc98eaf230275b7e805", size = 11951564, upload-time = "2025-07-24T13:26:34.994Z" },
 ]
 
+[[package]]
+name = "s3transfer"
+version = "0.13.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "botocore" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/6d/05/d52bf1e65044b4e5e27d4e63e8d1579dbdec54fce685908ae09bc3720030/s3transfer-0.13.1.tar.gz", hash = "sha256:c3fdba22ba1bd367922f27ec8032d6a1cf5f10c934fb5d68cf60fd5a23d936cf", size = 150589, upload-time = "2025-07-18T19:22:42.31Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/6d/4f/d073e09df851cfa251ef7840007d04db3293a0482ce607d2b993926089be/s3transfer-0.13.1-py3-none-any.whl", hash = "sha256:a981aa7429be23fe6dfc13e80e4020057cbab622b08c0315288758d67cabc724", size = 85308, upload-time = "2025-07-18T19:22:40.947Z" },
+]
+
 [[package]]
 name = "safetensors"
 version = "0.5.3"
@@ -5107,6 +5207,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/78/58/e860788190eba3bcce367f74d29c4675466ce8dddfba85f7827588416f01/wsproto-1.2.0-py3-none-any.whl", hash = "sha256:b9acddd652b585d75b20477888c56642fdade28bdfd3579aa24a4d2c037dd736", size = 24226, upload-time = "2022-08-23T19:58:19.96Z" },
 ]
 
+[[package]]
+name = "xmltodict"
+version = "0.14.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/50/05/51dcca9a9bf5e1bce52582683ce50980bcadbc4fa5143b9f2b19ab99958f/xmltodict-0.14.2.tar.gz", hash = "sha256:201e7c28bb210e374999d1dde6382923ab0ed1a8a5faeece48ab525b7810a553", size = 51942, upload-time = "2024-10-16T06:10:29.683Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/d6/45/fc303eb433e8a2a271739c98e953728422fa61a3c1f36077a49e395c972e/xmltodict-0.14.2-py2.py3-none-any.whl", hash = "sha256:20cc7d723ed729276e808f26fb6b3599f786cbc37e06c65e192ba77c40f20aac", size = 9981, upload-time = "2024-10-16T06:10:27.649Z" },
+]
+
 [[package]]
 name = "xxhash"
 version = "3.5.0"