Merge branch 'main' into responses_object

This commit is contained in:
Emilio Garcia 2025-08-25 16:01:59 -04:00 committed by GitHub
commit 708b2c1b05
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
166 changed files with 6944 additions and 809 deletions

View file

@ -9,6 +9,7 @@ updates:
day: "saturday"
commit-message:
prefix: chore(github-deps)
- package-ecosystem: "uv"
directory: "/"
schedule:
@ -19,3 +20,14 @@ updates:
- python
commit-message:
prefix: chore(python-deps)
- package-ecosystem: npm
directory: "/llama_stack/ui"
schedule:
interval: "weekly"
day: "saturday"
labels:
- type/dependencies
- javascript
commit-message:
prefix: chore(ui-deps)

View file

@ -17,7 +17,7 @@ jobs:
pull-requests: write # for peter-evans/create-pull-request to create a PR
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
ref: main
fetch-depth: 0

View file

@ -16,14 +16,14 @@ jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # 5.0.0
- name: Run ShellCheck on install.sh
run: shellcheck scripts/install.sh
smoke-test-on-dev:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner

View file

@ -18,7 +18,7 @@ on:
- '.github/workflows/integration-auth-tests.yml' # This workflow
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:
@ -31,7 +31,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner

View file

@ -16,7 +16,7 @@ on:
- '.github/workflows/integration-sql-store-tests.yml' # This workflow
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:
@ -44,7 +44,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner

View file

@ -65,7 +65,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Setup test environment
uses: ./.github/actions/setup-test-environment

View file

@ -33,7 +33,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner

View file

@ -8,7 +8,7 @@ on:
branches: [main]
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:
@ -20,7 +20,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
# For dependabot PRs, we need to checkout with a token that can push changes
token: ${{ github.actor == 'dependabot[bot]' && secrets.GITHUB_TOKEN || github.token }}
@ -36,20 +36,16 @@ jobs:
**/requirements*.txt
.pre-commit-config.yaml
# npm ci may fail -
# npm error `npm ci` can only install packages when your package.json and package-lock.json or npm-shrinkwrap.json are in sync. Please update your lock file with `npm install` before continuing.
# npm error Invalid: lock file's llama-stack-client@0.2.17 does not satisfy llama-stack-client@0.2.18
- name: Set up Node.js
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
with:
node-version: '20'
cache: 'npm'
cache-dependency-path: 'llama_stack/ui/'
# - name: Set up Node.js
# uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0
# with:
# node-version: '20'
# cache: 'npm'
# cache-dependency-path: 'llama_stack/ui/'
# - name: Install npm dependencies
# run: npm ci
# working-directory: llama_stack/ui
- name: Install npm dependencies
run: npm ci
working-directory: llama_stack/ui
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
continue-on-error: true

View file

@ -26,7 +26,7 @@ on:
- 'pyproject.toml'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:
@ -36,7 +36,7 @@ jobs:
distros: ${{ steps.set-matrix.outputs.distros }}
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Generate Distribution List
id: set-matrix
@ -55,7 +55,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner
@ -79,7 +79,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner
@ -92,7 +92,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner
@ -106,6 +106,10 @@ jobs:
- name: Inspect the container image entrypoint
run: |
IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
if [ -z "$IMAGE_ID" ]; then
echo "No image found"
exit 1
fi
entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
echo "Entrypoint: $entrypoint"
if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then
@ -117,7 +121,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner
@ -140,6 +144,10 @@ jobs:
- name: Inspect UBI9 image
run: |
IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
if [ -z "$IMAGE_ID" ]; then
echo "No image found"
exit 1
fi
entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
echo "Entrypoint: $entrypoint"
if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then

View file

@ -21,10 +21,10 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install uv
uses: astral-sh/setup-uv@e92bafb6253dcd438e0484186d7669ea7a8ca1cc # v6.4.3
uses: astral-sh/setup-uv@4959332f0f014c5280e7eac8b70c90cb574c9f9b # v6.6.0
with:
python-version: ${{ matrix.python-version }}
activate-environment: true

View file

@ -46,7 +46,7 @@ jobs:
echo "::endgroup::"
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
fetch-depth: 0

View file

@ -22,6 +22,6 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Check PR Title's semantic conformance
uses: amannn/action-semantic-pull-request@0723387faaf9b38adef4775cd42cfd5155ed6017 # v5.5.3
uses: amannn/action-semantic-pull-request@48f256284bd46cdaab1048c3721360e808335d50 # v6.1.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View file

@ -27,7 +27,7 @@ jobs:
# container and point 'uv pip install' to the correct path...
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner

View file

@ -27,7 +27,7 @@ jobs:
# container and point 'uv pip install' to the correct path...
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner

View file

@ -13,7 +13,7 @@ on:
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:
@ -26,10 +26,10 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Setup Node.js
uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
with:
node-version: ${{ matrix.node-version }}
cache: 'npm'

View file

@ -18,7 +18,7 @@ on:
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:
@ -32,7 +32,7 @@ jobs:
- "3.13"
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner

View file

@ -27,7 +27,7 @@ on:
- '.github/workflows/update-readthedocs.yml'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:
@ -37,7 +37,7 @@ jobs:
TOKEN: ${{ secrets.READTHEDOCS_TOKEN }}
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner

View file

@ -146,31 +146,13 @@ repos:
pass_filenames: false
require_serial: true
files: ^.github/workflows/.*$
# ui-prettier and ui-eslint are disabled until we can avoid `npm ci`, which is slow and may fail -
# npm error `npm ci` can only install packages when your package.json and package-lock.json or npm-shrinkwrap.json are in sync. Please update your lock file with `npm install` before continuing.
# npm error Invalid: lock file's llama-stack-client@0.2.17 does not satisfy llama-stack-client@0.2.18
# and until we have infra for installing prettier and next via npm -
# Lint UI code with ESLint.....................................................Failed
# - hook id: ui-eslint
# - exit code: 127
# > ui@0.1.0 lint
# > next lint --fix --quiet
# sh: line 1: next: command not found
#
# - id: ui-prettier
# name: Format UI code with Prettier
# entry: bash -c 'cd llama_stack/ui && npm ci && npm run format'
# language: system
# files: ^llama_stack/ui/.*\.(ts|tsx)$
# pass_filenames: false
# require_serial: true
# - id: ui-eslint
# name: Lint UI code with ESLint
# entry: bash -c 'cd llama_stack/ui && npm run lint -- --fix --quiet'
# language: system
# files: ^llama_stack/ui/.*\.(ts|tsx)$
# pass_filenames: false
# require_serial: true
- id: ui-linter
name: Format & Lint UI
entry: bash ./scripts/run-ui-linter.sh
language: system
files: ^llama_stack/ui/.*\.(ts|tsx)$
pass_filenames: false
require_serial: true
- id: check-log-usage
name: Ensure 'llama_stack.log' usage for logging

View file

@ -4605,6 +4605,49 @@
}
}
},
"/v1/inference/rerank": {
"post": {
"responses": {
"200": {
"description": "RerankResponse with indices sorted by relevance score (descending).",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/RerankResponse"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Inference"
],
"description": "Rerank a list of documents based on their relevance to a query.",
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/RerankRequest"
}
}
},
"required": true
}
}
},
"/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": {
"post": {
"responses": {
@ -16418,12 +16461,16 @@
"value": {
"type": "number",
"description": "The numeric value of the metric at this timestamp"
},
"unit": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"timestamp",
"value"
"value",
"unit"
],
"title": "MetricDataPoint",
"description": "A single data point in a metric time series."
@ -16981,6 +17028,95 @@
],
"title": "RegisterVectorDbRequest"
},
"RerankRequest": {
"type": "object",
"properties": {
"model": {
"type": "string",
"description": "The identifier of the reranking model to use."
},
"query": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
},
{
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
}
],
"description": "The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length."
},
"items": {
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
},
{
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
}
]
},
"description": "List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length."
},
"max_num_results": {
"type": "integer",
"description": "(Optional) Maximum number of results to return. Default: returns all."
}
},
"additionalProperties": false,
"required": [
"model",
"query",
"items"
],
"title": "RerankRequest"
},
"RerankData": {
"type": "object",
"properties": {
"index": {
"type": "integer",
"description": "The original index of the document in the input list"
},
"relevance_score": {
"type": "number",
"description": "The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance."
}
},
"additionalProperties": false,
"required": [
"index",
"relevance_score"
],
"title": "RerankData",
"description": "A single rerank result from a reranking response."
},
"RerankResponse": {
"type": "object",
"properties": {
"data": {
"type": "array",
"items": {
"$ref": "#/components/schemas/RerankData"
},
"description": "List of rerank result objects, sorted by relevance score (descending)"
}
},
"additionalProperties": false,
"required": [
"data"
],
"title": "RerankResponse",
"description": "Response from a reranking request."
},
"ResumeAgentTurnRequest": {
"type": "object",
"properties": {

View file

@ -3264,6 +3264,37 @@ paths:
schema:
$ref: '#/components/schemas/QueryTracesRequest'
required: true
/v1/inference/rerank:
post:
responses:
'200':
description: >-
RerankResponse with indices sorted by relevance score (descending).
content:
application/json:
schema:
$ref: '#/components/schemas/RerankResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Inference
description: >-
Rerank a list of documents based on their relevance to a query.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RerankRequest'
required: true
/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume:
post:
responses:
@ -12242,10 +12273,13 @@ components:
type: number
description: >-
The numeric value of the metric at this timestamp
unit:
type: string
additionalProperties: false
required:
- timestamp
- value
- unit
title: MetricDataPoint
description: >-
A single data point in a metric time series.
@ -12656,6 +12690,76 @@ components:
- vector_db_id
- embedding_model
title: RegisterVectorDbRequest
RerankRequest:
type: object
properties:
model:
type: string
description: >-
The identifier of the reranking model to use.
query:
oneOf:
- type: string
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
description: >-
The search query to rank items against. Can be a string, text content
part, or image content part. The input must not exceed the model's max
input token length.
items:
type: array
items:
oneOf:
- type: string
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
description: >-
List of items to rerank. Each item can be a string, text content part,
or image content part. Each input must not exceed the model's max input
token length.
max_num_results:
type: integer
description: >-
(Optional) Maximum number of results to return. Default: returns all.
additionalProperties: false
required:
- model
- query
- items
title: RerankRequest
RerankData:
type: object
properties:
index:
type: integer
description: >-
The original index of the document in the input list
relevance_score:
type: number
description: >-
The relevance score from the model output. Values are inverted when applicable
so that higher scores indicate greater relevance.
additionalProperties: false
required:
- index
- relevance_score
title: RerankData
description: >-
A single rerank result from a reranking response.
RerankResponse:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/RerankData'
description: >-
List of rerank result objects, sorted by relevance score (descending)
additionalProperties: false
required:
- data
title: RerankResponse
description: Response from a reranking request.
ResumeAgentTurnRequest:
type: object
properties:

View file

@ -225,8 +225,32 @@ server:
port: 8321 # Port to listen on (default: 8321)
tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS
tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS
cors: true # Optional: Enable CORS (dev mode) or full config object
```
### CORS Configuration
CORS (Cross-Origin Resource Sharing) can be configured in two ways:
**Local development** (allows localhost origins only):
```yaml
server:
cors: true
```
**Explicit configuration** (custom origins and settings):
```yaml
server:
cors:
allow_origins: ["https://myapp.com", "https://app.example.com"]
allow_methods: ["GET", "POST", "PUT", "DELETE"]
allow_headers: ["Content-Type", "Authorization"]
allow_credentials: true
max_age: 3600
```
When `cors: true`, the server enables secure localhost-only access for local development. For production, specify exact origins to maintain security.
### Authentication Configuration
> **Breaking Change (v0.2.14)**: The authentication configuration structure has changed. The previous format with `provider_type` and `config` fields has been replaced with a unified `provider_config` field that includes the `type` field. Update your configuration files accordingly.
@ -618,6 +642,54 @@ Content-Type: application/json
}
```
### CORS Configuration
Configure CORS to allow web browsers to make requests from different domains. Disabled by default.
#### Quick Setup
For development, use the simple boolean flag:
```yaml
server:
cors: true # Auto-enables localhost with any port
```
This automatically allows `http://localhost:*` and `https://localhost:*` with secure defaults.
#### Custom Configuration
For specific origins and full control:
```yaml
server:
cors:
allow_origins: ["https://myapp.com", "https://staging.myapp.com"]
allow_credentials: true
allow_methods: ["GET", "POST", "PUT", "DELETE"]
allow_headers: ["Content-Type", "Authorization"]
allow_origin_regex: "https://.*\\.example\\.com" # Optional regex pattern
expose_headers: ["X-Total-Count"]
max_age: 86400
```
#### Configuration Options
| Field | Description | Default |
| -------------------- | ---------------------------------------------- | ------- |
| `allow_origins` | List of allowed origins. Use `["*"]` for any. | `["*"]` |
| `allow_origin_regex` | Regex pattern for allowed origins (optional). | `None` |
| `allow_methods` | Allowed HTTP methods. | `["*"]` |
| `allow_headers` | Allowed headers. | `["*"]` |
| `allow_credentials` | Allow credentials (cookies, auth headers). | `false` |
| `expose_headers` | Headers exposed to browser. | `[]` |
| `max_age` | Preflight cache time (seconds). | `600` |
**Security Notes**:
- `allow_credentials: true` requires explicit origins (no wildcards)
- `cors: true` enables localhost access only (secure for development)
- For public APIs, always specify exact allowed origins
## Extending to handle Safety
Configuring Safety can be a little involved so it is instructive to go through an example.

View file

@ -17,7 +17,6 @@ client = LlamaStackAsLibraryClient(
# provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
)
client.initialize()
```
This will parse your config and set up any inline implementations and remote clients needed for your implementation.
@ -32,5 +31,4 @@ If you've created a [custom distribution](https://llama-stack.readthedocs.io/en/
```python
client = LlamaStackAsLibraryClient(config_path)
client.initialize()
```

View file

@ -2,12 +2,15 @@
## Overview
Protocol for batch processing API operations.
The Batches API enables efficient processing of multiple requests in a single operation,
particularly useful for processing large datasets, batch evaluation workflows, and
cost-effective inference at scale.
The API is designed to allow use of openai client libraries for seamless integration.
This API provides the following extensions:
- idempotent batch creation
Note: This API is currently under active development and may undergo changes.
This section contains documentation for all available providers for the **batches** API.

View file

@ -10,4 +10,5 @@ This section contains documentation for all available providers for the **files*
:maxdepth: 1
inline_localfs
remote_s3
```

View file

@ -0,0 +1,33 @@
# remote::s3
## Description
AWS S3-based file storage provider for scalable cloud file management with metadata persistence.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `bucket_name` | `<class 'str'>` | No | | S3 bucket name to store files |
| `region` | `<class 'str'>` | No | us-east-1 | AWS region where the bucket is located |
| `aws_access_key_id` | `str \| None` | No | | AWS access key ID (optional if using IAM roles) |
| `aws_secret_access_key` | `str \| None` | No | | AWS secret access key (optional if using IAM roles) |
| `endpoint_url` | `str \| None` | No | | Custom S3 endpoint URL (for MinIO, LocalStack, etc.) |
| `auto_create_bucket` | `<class 'bool'>` | No | False | Automatically create the S3 bucket if it doesn't exist |
| `metadata_store` | `utils.sqlstore.sqlstore.SqliteSqlStoreConfig \| utils.sqlstore.sqlstore.PostgresSqlStoreConfig` | No | sqlite | SQL store configuration for file metadata |
## Sample Configuration
```yaml
bucket_name: ${env.S3_BUCKET_NAME}
region: ${env.AWS_REGION:=us-east-1}
aws_access_key_id: ${env.AWS_ACCESS_KEY_ID:=}
aws_secret_access_key: ${env.AWS_SECRET_ACCESS_KEY:=}
endpoint_url: ${env.S3_ENDPOINT_URL:=}
auto_create_bucket: ${env.S3_AUTO_CREATE_BUCKET:=false}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/s3_files_metadata.db
```

View file

@ -9,7 +9,9 @@ This section contains documentation for all available providers for the **post_t
```{toctree}
:maxdepth: 1
inline_huggingface
inline_torchtune
inline_huggingface-cpu
inline_huggingface-gpu
inline_torchtune-cpu
inline_torchtune-gpu
remote_nvidia
```

View file

@ -0,0 +1,41 @@
# inline::huggingface-cpu
## Description
HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `device` | `<class 'str'>` | No | cuda | |
| `distributed_backend` | `Literal['fsdp', 'deepspeed'` | No | | |
| `checkpoint_format` | `Literal['full_state', 'huggingface'` | No | huggingface | |
| `chat_template` | `<class 'str'>` | No | <|user|>
{input}
<|assistant|>
{output} | |
| `model_specific_config` | `<class 'dict'>` | No | {'trust_remote_code': True, 'attn_implementation': 'sdpa'} | |
| `max_seq_length` | `<class 'int'>` | No | 2048 | |
| `gradient_checkpointing` | `<class 'bool'>` | No | False | |
| `save_total_limit` | `<class 'int'>` | No | 3 | |
| `logging_steps` | `<class 'int'>` | No | 10 | |
| `warmup_ratio` | `<class 'float'>` | No | 0.1 | |
| `weight_decay` | `<class 'float'>` | No | 0.01 | |
| `dataloader_num_workers` | `<class 'int'>` | No | 4 | |
| `dataloader_pin_memory` | `<class 'bool'>` | No | True | |
| `dpo_beta` | `<class 'float'>` | No | 0.1 | |
| `use_reference_model` | `<class 'bool'>` | No | True | |
| `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair'` | No | sigmoid | |
| `dpo_output_dir` | `<class 'str'>` | No | | |
## Sample Configuration
```yaml
checkpoint_format: huggingface
distributed_backend: null
device: cpu
dpo_output_dir: ~/.llama/dummy/dpo_output
```

View file

@ -0,0 +1,41 @@
# inline::huggingface-gpu
## Description
HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `device` | `<class 'str'>` | No | cuda | |
| `distributed_backend` | `Literal['fsdp', 'deepspeed'` | No | | |
| `checkpoint_format` | `Literal['full_state', 'huggingface'` | No | huggingface | |
| `chat_template` | `<class 'str'>` | No | <|user|>
{input}
<|assistant|>
{output} | |
| `model_specific_config` | `<class 'dict'>` | No | {'trust_remote_code': True, 'attn_implementation': 'sdpa'} | |
| `max_seq_length` | `<class 'int'>` | No | 2048 | |
| `gradient_checkpointing` | `<class 'bool'>` | No | False | |
| `save_total_limit` | `<class 'int'>` | No | 3 | |
| `logging_steps` | `<class 'int'>` | No | 10 | |
| `warmup_ratio` | `<class 'float'>` | No | 0.1 | |
| `weight_decay` | `<class 'float'>` | No | 0.01 | |
| `dataloader_num_workers` | `<class 'int'>` | No | 4 | |
| `dataloader_pin_memory` | `<class 'bool'>` | No | True | |
| `dpo_beta` | `<class 'float'>` | No | 0.1 | |
| `use_reference_model` | `<class 'bool'>` | No | True | |
| `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair'` | No | sigmoid | |
| `dpo_output_dir` | `<class 'str'>` | No | | |
## Sample Configuration
```yaml
checkpoint_format: huggingface
distributed_backend: null
device: cpu
dpo_output_dir: ~/.llama/dummy/dpo_output
```

View file

@ -0,0 +1,20 @@
# inline::torchtune-cpu
## Description
TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `torch_seed` | `int \| None` | No | | |
| `checkpoint_format` | `Literal['meta', 'huggingface'` | No | meta | |
## Sample Configuration
```yaml
checkpoint_format: meta
```

View file

@ -0,0 +1,20 @@
# inline::torchtune-gpu
## Description
TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `torch_seed` | `int \| None` | No | | |
| `checkpoint_format` | `Literal['meta', 'huggingface'` | No | meta | |
## Sample Configuration
```yaml
checkpoint_format: meta
```

View file

@ -29,12 +29,16 @@ class ListBatchesResponse(BaseModel):
@runtime_checkable
class Batches(Protocol):
"""Protocol for batch processing API operations.
"""
The Batches API enables efficient processing of multiple requests in a single operation,
particularly useful for processing large datasets, batch evaluation workflows, and
cost-effective inference at scale.
The API is designed to allow use of openai client libraries for seamless integration.
This API provides the following extensions:
- idempotent batch creation
Note: This API is currently under active development and may undergo changes.
"""
@ -45,6 +49,7 @@ class Batches(Protocol):
endpoint: str,
completion_window: Literal["24h"],
metadata: dict[str, str] | None = None,
idempotency_key: str | None = None,
) -> BatchObject:
"""Create a new batch for processing multiple API requests.
@ -52,6 +57,7 @@ class Batches(Protocol):
:param endpoint: The endpoint to be used for all requests in the batch.
:param completion_window: The time window within which the batch should be processed.
:param metadata: Optional metadata for the batch.
:param idempotency_key: Optional idempotency key. When provided, enables idempotent behavior.
:returns: The created batch object.
"""
...

View file

@ -473,6 +473,28 @@ class EmbeddingsResponse(BaseModel):
embeddings: list[list[float]]
@json_schema_type
class RerankData(BaseModel):
"""A single rerank result from a reranking response.
:param index: The original index of the document in the input list
:param relevance_score: The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance.
"""
index: int
relevance_score: float
@json_schema_type
class RerankResponse(BaseModel):
"""Response from a reranking request.
:param data: List of rerank result objects, sorted by relevance score (descending)
"""
data: list[RerankData]
@json_schema_type
class OpenAIChatCompletionContentPartTextParam(BaseModel):
"""Text content part for OpenAI-compatible chat completion messages.
@ -1046,6 +1068,7 @@ class InferenceProvider(Protocol):
:returns: A BatchCompletionResponse with the full completions.
"""
raise NotImplementedError("Batch completion is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
@webmethod(route="/inference/chat-completion", method="POST")
async def chat_completion(
@ -1110,6 +1133,7 @@ class InferenceProvider(Protocol):
:returns: A BatchChatCompletionResponse with the full completions.
"""
raise NotImplementedError("Batch chat completion is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
@webmethod(route="/inference/embeddings", method="POST")
async def embeddings(
@ -1131,6 +1155,25 @@ class InferenceProvider(Protocol):
"""
...
@webmethod(route="/inference/rerank", method="POST", experimental=True)
async def rerank(
self,
model: str,
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
max_num_results: int | None = None,
) -> RerankResponse:
"""Rerank a list of documents based on their relevance to a query.
:param model: The identifier of the reranking model to use.
:param query: The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length.
:param items: List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length.
:param max_num_results: (Optional) Maximum number of results to return. Default: returns all.
:returns: RerankResponse with indices sorted by relevance score (descending).
"""
raise NotImplementedError("Reranking is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
@webmethod(route="/openai/v1/completions", method="POST")
async def openai_completion(
self,

View file

@ -386,6 +386,7 @@ class MetricDataPoint(BaseModel):
timestamp: int
value: float
unit: str
@json_schema_type
@ -518,7 +519,7 @@ class Telemetry(Protocol):
metric_name: str,
start_time: int,
end_time: int | None = None,
granularity: str | None = "1d",
granularity: str | None = None,
query_type: MetricQueryType = MetricQueryType.RANGE,
label_matchers: list[MetricLabelMatcher] | None = None,
) -> QueryMetricsResponse:

View file

@ -15,7 +15,7 @@ from llama_stack.log import get_logger
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logger = get_logger(name=__name__, category="server")
logger = get_logger(name=__name__, category="cli")
class StackRun(Subcommand):

View file

@ -318,6 +318,41 @@ class QuotaConfig(BaseModel):
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
class CORSConfig(BaseModel):
allow_origins: list[str] = Field(default_factory=list)
allow_origin_regex: str | None = Field(default=None)
allow_methods: list[str] = Field(default=["OPTIONS"])
allow_headers: list[str] = Field(default_factory=list)
allow_credentials: bool = Field(default=False)
expose_headers: list[str] = Field(default_factory=list)
max_age: int = Field(default=600, ge=0)
@model_validator(mode="after")
def validate_credentials_config(self) -> Self:
if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins):
raise ValueError("Cannot use wildcard origins with credentials enabled")
return self
def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | None:
if cors_config is False or cors_config is None:
return None
if cors_config is True:
# dev mode: allow localhost on any port
return CORSConfig(
allow_origins=[],
allow_origin_regex=r"https?://localhost:\d+",
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
)
if isinstance(cors_config, CORSConfig):
return cors_config
raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}")
class ServerConfig(BaseModel):
port: int = Field(
default=8321,
@ -349,6 +384,12 @@ class ServerConfig(BaseModel):
default=None,
description="Per client quota request configuration",
)
cors: bool | CORSConfig | None = Field(
default=None,
description="CORS configuration for cross-origin requests. Can be:\n"
"- true: Enable localhost CORS for development\n"
"- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
)
class StackRunConfig(BaseModel):

View file

@ -146,39 +146,26 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
):
super().__init__()
self.async_client = AsyncLlamaStackAsLibraryClient(
config_path_or_distro_name, custom_provider_registry, provider_data
config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
)
self.pool_executor = ThreadPoolExecutor(max_workers=4)
self.skip_logger_removal = skip_logger_removal
self.provider_data = provider_data
self.loop = asyncio.new_event_loop()
def initialize(self):
if in_notebook():
import nest_asyncio
nest_asyncio.apply()
if not self.skip_logger_removal:
self._remove_root_logger_handlers()
# use a new event loop to avoid interfering with the main event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(self.async_client.initialize())
loop.run_until_complete(self.async_client.initialize())
finally:
asyncio.set_event_loop(None)
def _remove_root_logger_handlers(self):
def initialize(self):
"""
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
Deprecated method for backward compatibility.
"""
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
pass
def request(self, *args, **kwargs):
loop = self.loop
@ -216,6 +203,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
config_path_or_distro_name: str,
custom_provider_registry: ProviderRegistry | None = None,
provider_data: dict[str, Any] | None = None,
skip_logger_removal: bool = False,
):
super().__init__()
# when using the library client, we should not log to console since many
@ -223,6 +211,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
if in_notebook():
import nest_asyncio
nest_asyncio.apply()
if not skip_logger_removal:
self._remove_root_logger_handlers()
if config_path_or_distro_name.endswith(".yaml"):
config_path = Path(config_path_or_distro_name)
if not config_path.exists():
@ -239,7 +234,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
self.provider_data = provider_data
self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError
def _remove_root_logger_handlers(self):
"""
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
"""
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
async def initialize(self) -> bool:
"""
Initialize the async client.
Returns:
bool: True if initialization was successful
"""
try:
self.route_impls = None
self.impls = await construct_stack(self.config, self.custom_provider_registry)

View file

@ -12,7 +12,7 @@ from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routers")
class DatasetIORouter(DatasetIO):

View file

@ -16,7 +16,7 @@ from llama_stack.apis.scoring import (
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routers")
class ScoringRouter(Scoring):

View file

@ -65,7 +65,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus, Routin
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.telemetry.tracing import get_current_span
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="core::routers")
class InferenceRouter(Inference):

View file

@ -13,7 +13,7 @@ from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routers")
class SafetyRouter(Safety):

View file

@ -22,7 +22,7 @@ from llama_stack.log import get_logger
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routers")
class ToolRuntimeRouter(ToolRuntime):

View file

@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import (
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routers")
class VectorIORouter(VectorIO):

View file

@ -14,7 +14,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):

View file

@ -23,7 +23,7 @@ from llama_stack.core.store import DistributionRegistry
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, RoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
def get_impl_api(p: Any) -> Api:

View file

@ -26,7 +26,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):

View file

@ -17,7 +17,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class ModelsRoutingTable(CommonRoutingTableImpl, Models):

View file

@ -19,7 +19,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):

View file

@ -15,7 +15,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):

View file

@ -14,7 +14,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:

View file

@ -30,7 +30,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):

View file

@ -15,7 +15,7 @@ from llama_stack.core.server.auth_providers import create_auth_provider
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
logger = get_logger(name=__name__, category="core::auth")
class AuthenticationMiddleware:

View file

@ -23,7 +23,7 @@ from llama_stack.core.datatypes import (
)
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
logger = get_logger(name=__name__, category="core::auth")
class AuthResponse(BaseModel):

View file

@ -15,7 +15,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
logger = get_logger(name=__name__, category="quota")
logger = get_logger(name=__name__, category="core::server")
class QuotaMiddleware:

View file

@ -28,6 +28,7 @@ from aiohttp import hdrs
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError
@ -40,6 +41,7 @@ from llama_stack.core.datatypes import (
AuthenticationRequiredError,
LoggingConfig,
StackRunConfig,
process_cors_config,
)
from llama_stack.core.distribution import builtin_automatically_routed_apis
from llama_stack.core.external import ExternalApiSpec, load_external_apis
@ -82,7 +84,7 @@ from .quota import QuotaMiddleware
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logger = get_logger(name=__name__, category="server")
logger = get_logger(name=__name__, category="core::server")
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@ -413,7 +415,7 @@ def main(args: argparse.Namespace | None = None):
config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg)
logger = get_logger(name=__name__, category="server", config=logger_config)
logger = get_logger(name=__name__, category="core::server", config=logger_config)
if args.env:
for env_pair in args.env:
try:
@ -483,6 +485,12 @@ def main(args: argparse.Namespace | None = None):
window_seconds=window_seconds,
)
if config.server.cors:
logger.info("Enabling CORS")
cors_config = process_cors_config(config.server.cors)
if cors_config:
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
else:

View file

@ -16,7 +16,7 @@ from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
logger = get_logger(__name__, category="core")
logger = get_logger(__name__, category="core::registry")
class DistributionRegistry(Protocol):

View file

@ -10,7 +10,7 @@ from pathlib import Path
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="config_resolution")
logger = get_logger(name=__name__, category="core")
DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions"

View file

@ -34,7 +34,7 @@ distribution_spec:
telemetry:
- provider_type: inline::meta-reference
post_training:
- provider_type: inline::huggingface
- provider_type: inline::huggingface-cpu
eval:
- provider_type: inline::meta-reference
datasetio:

View file

@ -156,8 +156,8 @@ providers:
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
post_training:
- provider_id: huggingface
provider_type: inline::huggingface
- provider_id: huggingface-cpu
provider_type: inline::huggingface-cpu
config:
checkpoint_format: huggingface
distributed_backend: null

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .starter_gpu import get_distribution_template # noqa: F401

View file

@ -0,0 +1,59 @@
version: 2
distribution_spec:
description: Quick start template for running Llama Stack with several popular providers.
This distribution is intended for GPU-enabled environments.
providers:
inference:
- provider_type: remote::cerebras
- provider_type: remote::ollama
- provider_type: remote::vllm
- provider_type: remote::tgi
- provider_type: remote::fireworks
- provider_type: remote::together
- provider_type: remote::bedrock
- provider_type: remote::nvidia
- provider_type: remote::openai
- provider_type: remote::anthropic
- provider_type: remote::gemini
- provider_type: remote::vertexai
- provider_type: remote::groq
- provider_type: remote::sambanova
- provider_type: inline::sentence-transformers
vector_io:
- provider_type: inline::faiss
- provider_type: inline::sqlite-vec
- provider_type: inline::milvus
- provider_type: remote::chromadb
- provider_type: remote::pgvector
files:
- provider_type: inline::localfs
safety:
- provider_type: inline::llama-guard
- provider_type: inline::code-scanner
agents:
- provider_type: inline::meta-reference
telemetry:
- provider_type: inline::meta-reference
post_training:
- provider_type: inline::torchtune-gpu
eval:
- provider_type: inline::meta-reference
datasetio:
- provider_type: remote::huggingface
- provider_type: inline::localfs
scoring:
- provider_type: inline::basic
- provider_type: inline::llm-as-judge
- provider_type: inline::braintrust
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
batches:
- provider_type: inline::reference
image_type: venv
additional_pip_packages:
- aiosqlite
- asyncpg
- sqlalchemy[asyncio]

View file

@ -0,0 +1,238 @@
version: 2
image_name: starter-gpu
apis:
- agents
- batches
- datasetio
- eval
- files
- inference
- post_training
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
base_url: https://api.cerebras.ai
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: openai
provider_type: remote::openai
config:
api_key: ${env.OPENAI_API_KEY:=}
base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
- provider_id: anthropic
provider_type: remote::anthropic
config:
api_key: ${env.ANTHROPIC_API_KEY:=}
- provider_id: gemini
provider_type: remote::gemini
config:
api_key: ${env.GEMINI_API_KEY:=}
- provider_id: ${env.VERTEX_AI_PROJECT:+vertexai}
provider_type: remote::vertexai
config:
project: ${env.VERTEX_AI_PROJECT:=}
location: ${env.VERTEX_AI_LOCATION:=us-central1}
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db
- provider_id: sqlite-vec
provider_type: inline::sqlite-vec
config:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db
- provider_id: ${env.MILVUS_URL:+milvus}
provider_type: inline::milvus
config:
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db
- provider_id: ${env.CHROMADB_URL:+chromadb}
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db
- provider_id: ${env.PGVECTOR_DB:+pgvector}
provider_type: remote::pgvector
config:
host: ${env.PGVECTOR_HOST:=localhost}
port: ${env.PGVECTOR_PORT:=5432}
db: ${env.PGVECTOR_DB:=}
user: ${env.PGVECTOR_USER:=}
password: ${env.PGVECTOR_PASSWORD:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
- provider_id: code-scanner
provider_type: inline::code-scanner
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/agents_store.db
responses_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/responses_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
post_training:
- provider_id: torchtune-gpu
provider_type: inline::torchtune-gpu
config:
checkpoint_format: meta
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/meta_reference_eval.db
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/huggingface_datasetio.db
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/localfs_datasetio.db
scoring:
- provider_id: basic
provider_type: inline::basic
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:=}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
batches:
- provider_id: reference
provider_type: inline::reference
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/batches.db
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/registry.db
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/inference_store.db
models: []
shields:
- shield_id: llama-guard
provider_id: ${env.SAFETY_MODEL:+llama-guard}
provider_shield_id: ${env.SAFETY_MODEL:=}
- shield_id: code-scanner
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321

View file

@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.distributions.template import BuildProvider, DistributionTemplate
from ..starter.starter import get_distribution_template as get_starter_distribution_template
def get_distribution_template() -> DistributionTemplate:
template = get_starter_distribution_template()
name = "starter-gpu"
template.name = name
template.description = "Quick start template for running Llama Stack with several popular providers. This distribution is intended for GPU-enabled environments."
template.providers["post_training"] = [
BuildProvider(provider_type="inline::torchtune-gpu"),
]
return template

View file

@ -1,6 +1,7 @@
version: 2
distribution_spec:
description: Quick start template for running Llama Stack with several popular providers
description: Quick start template for running Llama Stack with several popular providers.
This distribution is intended for CPU-only environments.
providers:
inference:
- provider_type: remote::cerebras
@ -34,7 +35,7 @@ distribution_spec:
telemetry:
- provider_type: inline::meta-reference
post_training:
- provider_type: inline::huggingface
- provider_type: inline::huggingface-cpu
eval:
- provider_type: inline::meta-reference
datasetio:

View file

@ -156,8 +156,8 @@ providers:
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
post_training:
- provider_id: huggingface
provider_type: inline::huggingface
- provider_id: huggingface-cpu
provider_type: inline::huggingface-cpu
config:
checkpoint_format: huggingface
distributed_backend: null

View file

@ -120,7 +120,7 @@ def get_distribution_template() -> DistributionTemplate:
],
"agents": [BuildProvider(provider_type="inline::meta-reference")],
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
"post_training": [BuildProvider(provider_type="inline::huggingface")],
"post_training": [BuildProvider(provider_type="inline::huggingface-cpu")],
"eval": [BuildProvider(provider_type="inline::meta-reference")],
"datasetio": [
BuildProvider(provider_type="remote::huggingface"),
@ -178,7 +178,7 @@ def get_distribution_template() -> DistributionTemplate:
return DistributionTemplate(
name=name,
distro_type="self_hosted",
description="Quick start template for running Llama Stack with several popular providers",
description="Quick start template for running Llama Stack with several popular providers. This distribution is intended for CPU-only environments.",
container_image=None,
template_path=None,
providers=providers,

View file

@ -36,7 +36,7 @@ from .utils import get_negative_inf_value, to_2tuple
MP_SCALE = 8
logger = get_logger(name=__name__, category="models")
logger = get_logger(name=__name__, category="models::llama")
def reduce_from_tensor_model_parallel_region(input_):

View file

@ -11,7 +11,7 @@ from llama_stack.log import get_logger
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="models::llama")
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")

View file

@ -18,7 +18,7 @@ from ...datatypes import QuantizationMode
from ..model import Transformer, TransformerBlock
from ..moe import MoE
log = get_logger(name=__name__, category="models")
log = get_logger(name=__name__, category="models::llama")
def swiglu_wrapper_no_reduce(

View file

@ -9,7 +9,7 @@ import collections
from llama_stack.log import get_logger
log = get_logger(name=__name__, category="llama")
log = get_logger(name=__name__, category="models::llama")
try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401

View file

@ -84,7 +84,7 @@ MEMORY_QUERY_TOOL = "knowledge_search"
WEB_SEARCH_TOOL = "web_search"
RAG_TOOL_GROUP = "builtin::rag"
logger = get_logger(name=__name__, category="agents")
logger = get_logger(name=__name__, category="agents::meta_reference")
class ChatAgent(ShieldRunnerMixin):

View file

@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig
from .persistence import AgentInfo
from .responses.openai_responses import OpenAIResponsesImpl
logger = get_logger(name=__name__, category="agents")
logger = get_logger(name=__name__, category="agents::meta_reference")
class MetaReferenceAgentsImpl(Agents):

View file

@ -17,7 +17,7 @@ from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore
log = get_logger(name=__name__, category="agents")
log = get_logger(name=__name__, category="agents::meta_reference")
class AgentSessionInfo(Session):

View file

@ -41,7 +41,7 @@ from .utils import (
convert_response_text_to_chat_response_format,
)
logger = get_logger(name=__name__, category="responses")
logger = get_logger(name=__name__, category="openai::responses")
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):

View file

@ -47,7 +47,7 @@ from llama_stack.log import get_logger
from .types import ChatCompletionContext, ChatCompletionResult
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
logger = get_logger(name=__name__, category="responses")
logger = get_logger(name=__name__, category="agents::meta_reference")
class StreamingResponseOrchestrator:

View file

@ -38,7 +38,7 @@ from llama_stack.log import get_logger
from .types import ChatCompletionContext, ToolExecutionResult
logger = get_logger(name=__name__, category="responses")
logger = get_logger(name=__name__, category="agents::meta_reference")
class ToolExecutor:

View file

@ -17,6 +17,8 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseText,
)
from llama_stack.apis.inference import (
@ -99,14 +101,22 @@ async def convert_response_input_to_chat_messages(
"""
messages: list[OpenAIMessageParam] = []
if isinstance(input, list):
# extract all OpenAIResponseInputFunctionToolCallOutput items
# so their corresponding OpenAIToolMessageParam instances can
# be added immediately following the corresponding
# OpenAIAssistantMessageParam
tool_call_results = {}
for input_item in input:
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
messages.append(
OpenAIToolMessageParam(
tool_call_results[input_item.call_id] = OpenAIToolMessageParam(
content=input_item.output,
tool_call_id=input_item.call_id,
)
)
for input_item in input:
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
# skip as these have been extracted and inserted in order
pass
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
tool_call = OpenAIChatCompletionToolCall(
index=0,
@ -117,6 +127,28 @@ async def convert_response_input_to_chat_messages(
),
)
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
if input_item.call_id in tool_call_results:
messages.append(tool_call_results[input_item.call_id])
del tool_call_results[input_item.call_id]
elif isinstance(input_item, OpenAIResponseOutputMessageMCPCall):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id=input_item.id,
function=OpenAIChatCompletionToolCallFunction(
name=input_item.name,
arguments=input_item.arguments,
),
)
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
messages.append(
OpenAIToolMessageParam(
content=input_item.output,
tool_call_id=input_item.id,
)
)
elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools):
# the tool list will be handled separately
pass
else:
content = await convert_response_content_to_chat_content(input_item.content)
message_type = await get_message_type_by_role(input_item.role)
@ -125,6 +157,10 @@ async def convert_response_input_to_chat_messages(
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
)
messages.append(message_type(content=content))
if len(tool_call_results):
raise ValueError(
f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
)
else:
messages.append(OpenAIUserMessageParam(content=input))
return messages

View file

@ -11,7 +11,7 @@ from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry import tracing
log = get_logger(name=__name__, category="agents")
log = get_logger(name=__name__, category="agents::meta_reference")
class SafetyException(Exception): # noqa: N818

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import asyncio
import hashlib
import itertools
import json
import time
@ -136,28 +137,45 @@ class ReferenceBatchesImpl(Batches):
endpoint: str,
completion_window: Literal["24h"],
metadata: dict[str, str] | None = None,
idempotency_key: str | None = None,
) -> BatchObject:
"""
Create a new batch for processing multiple API requests.
Error handling by levels -
0. Input param handling, results in 40x errors before processing, e.g.
- Wrong completion_window
- Invalid metadata types
- Unknown endpoint
-> no batch created
1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
- input_file_id missing
- invalid json in file
- missing custom_id, method, url, body
- invalid model
- streaming
-> batch created, validation sends to failed status
2. Processing errors, result in error_file_id entries, e.g.
- Any error returned from inference endpoint
-> batch created, goes to completed status
This implementation provides optional idempotency: when an idempotency key
(idempotency_key) is provided, a deterministic ID is generated based on the input
parameters. If a batch with the same parameters already exists, it will be
returned instead of creating a duplicate. Without an idempotency key,
each request creates a new batch with a unique ID.
Args:
input_file_id: The ID of an uploaded file containing requests for the batch.
endpoint: The endpoint to be used for all requests in the batch.
completion_window: The time window within which the batch should be processed.
metadata: Optional metadata for the batch.
idempotency_key: Optional idempotency key for enabling idempotent behavior.
Returns:
The created or existing batch object.
"""
# Error handling by levels -
# 0. Input param handling, results in 40x errors before processing, e.g.
# - Wrong completion_window
# - Invalid metadata types
# - Unknown endpoint
# -> no batch created
# 1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
# - input_file_id missing
# - invalid json in file
# - missing custom_id, method, url, body
# - invalid model
# - streaming
# -> batch created, validation sends to failed status
# 2. Processing errors, result in error_file_id entries, e.g.
# - Any error returned from inference endpoint
# -> batch created, goes to completed status
# TODO: set expiration time for garbage collection
if endpoint not in ["/v1/chat/completions"]:
@ -171,6 +189,35 @@ class ReferenceBatchesImpl(Batches):
)
batch_id = f"batch_{uuid.uuid4().hex[:16]}"
# For idempotent requests, use the idempotency key for the batch ID
# This ensures the same key always maps to the same batch ID,
# allowing us to detect parameter conflicts
if idempotency_key is not None:
hash_input = idempotency_key.encode("utf-8")
hash_digest = hashlib.sha256(hash_input).hexdigest()[:24]
batch_id = f"batch_{hash_digest}"
try:
existing_batch = await self.retrieve_batch(batch_id)
if (
existing_batch.input_file_id != input_file_id
or existing_batch.endpoint != endpoint
or existing_batch.completion_window != completion_window
or existing_batch.metadata != metadata
):
raise ConflictError(
f"Idempotency key '{idempotency_key}' was previously used with different parameters. "
"Either use a new idempotency key or ensure all parameters match the original request."
)
logger.info(f"Returning existing batch with ID: {batch_id}")
return existing_batch
except ResourceNotFoundError:
# Batch doesn't exist, continue with creation
pass
current_time = int(time.time())
batch = BatchObject(
@ -185,6 +232,7 @@ class ReferenceBatchesImpl(Batches):
)
await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
logger.info(f"Created new batch with ID: {batch_id}")
if self.process_batches:
task = asyncio.create_task(self._process_batch(batch_id))

View file

@ -11,6 +11,7 @@ from typing import Annotated
from fastapi import File, Form, Response, UploadFile
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import (
Files,
@ -20,12 +21,15 @@ from llama_stack.apis.files import (
OpenAIFilePurpose,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.log import get_logger
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
from .config import LocalfsFilesImplConfig
logger = get_logger(name=__name__, category="files")
class LocalfsFilesImpl(Files):
def __init__(self, config: LocalfsFilesImplConfig, policy: list[AccessRule]) -> None:
@ -65,6 +69,18 @@ class LocalfsFilesImpl(Files):
"""Get the filesystem path for a file ID."""
return Path(self.config.storage_dir) / file_id
async def _lookup_file_id(self, file_id: str) -> tuple[OpenAIFileObject, Path]:
"""Look up a OpenAIFileObject and filesystem path from its ID."""
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
file_path = Path(row.pop("file_path"))
return OpenAIFileObject(**row), file_path
# OpenAI Files API Implementation
async def openai_upload_file(
self,
@ -157,37 +173,19 @@ class LocalfsFilesImpl(Files):
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
"""Returns information about a specific file."""
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
file_obj, _ = await self._lookup_file_id(file_id)
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
if not row:
raise ValueError(f"File with id {file_id} not found")
return OpenAIFileObject(
id=row["id"],
filename=row["filename"],
purpose=OpenAIFilePurpose(row["purpose"]),
bytes=row["bytes"],
created_at=row["created_at"],
expires_at=row["expires_at"],
)
return file_obj
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
"""Delete a file."""
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
if not row:
raise ValueError(f"File with id {file_id} not found")
# Delete physical file
file_path = Path(row["file_path"])
_, file_path = await self._lookup_file_id(file_id)
if file_path.exists():
file_path.unlink()
# Delete metadata from database
assert self.sql_store is not None, "Files provider not initialized"
await self.sql_store.delete("openai_files", where={"id": file_id})
return OpenAIFileDeleteResponse(
@ -197,25 +195,17 @@ class LocalfsFilesImpl(Files):
async def openai_retrieve_file_content(self, file_id: str) -> Response:
"""Returns the contents of the specified file."""
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
# Get file metadata
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
if not row:
raise ValueError(f"File with id {file_id} not found")
# Read file content
file_path = Path(row["file_path"])
if not file_path.exists():
raise ValueError(f"File content not found on disk: {file_path}")
file_obj, file_path = await self._lookup_file_id(file_id)
with open(file_path, "rb") as f:
content = f.read()
if not file_path.exists():
logger.warning(f"File '{file_id}'s underlying '{file_path}' is missing, deleting metadata.")
await self.openai_delete_file(file_id)
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
# Return as binary response with appropriate content type
return Response(
content=content,
content=file_path.read_bytes(),
media_type="application/octet-stream",
headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'},
headers={"Content-Disposition": f'attachment; filename="{file_obj.filename}"'},
)

View file

@ -9,7 +9,6 @@ from collections.abc import AsyncGenerator
from llama_stack.apis.inference import (
CompletionResponse,
InferenceProvider,
InterleavedContent,
LogProbConfig,
Message,
ResponseFormat,
@ -100,25 +99,3 @@ class SentenceTransformersInferenceImpl(
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
raise ValueError("Sentence transformers don't support chat completion")
async def batch_completion(
self,
model_id: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch completion is not supported for Sentence Transformers")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import datetime
import threading
from typing import Any
@ -145,11 +146,41 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
metric_name: str,
start_time: int,
end_time: int | None = None,
granularity: str | None = "1d",
granularity: str | None = None,
query_type: MetricQueryType = MetricQueryType.RANGE,
label_matchers: list[MetricLabelMatcher] | None = None,
) -> QueryMetricsResponse:
raise NotImplementedError("Querying metrics is not implemented")
"""Query metrics from the telemetry store.
Args:
metric_name: The name of the metric to query (e.g., "prompt_tokens")
start_time: Start time as Unix timestamp
end_time: End time as Unix timestamp (defaults to now if None)
granularity: Time granularity for aggregation
query_type: Type of query (RANGE or INSTANT)
label_matchers: Label filters to apply
Returns:
QueryMetricsResponse with metric time series data
"""
# Convert timestamps to datetime objects
start_dt = datetime.datetime.fromtimestamp(start_time, datetime.UTC)
end_dt = datetime.datetime.fromtimestamp(end_time, datetime.UTC) if end_time else None
# Use SQLite trace store if available
if hasattr(self, "trace_store") and self.trace_store:
return await self.trace_store.query_metrics(
metric_name=metric_name,
start_time=start_dt,
end_time=end_dt,
granularity=granularity,
query_type=query_type,
label_matchers=label_matchers,
)
else:
raise ValueError(
f"In order to query_metrics, you must have {TelemetrySink.SQLITE} set in your telemetry sinks"
)
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:

View file

@ -5,9 +5,11 @@
# the root directory of this source tree.
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
)
from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
@ -23,4 +25,14 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
description="Local filesystem-based file storage provider for managing files and documents locally.",
),
remote_provider_spec(
api=Api.files,
adapter=AdapterSpec(
adapter_type="s3",
pip_packages=["boto3"] + sql_store_pip_packages,
module="llama_stack.providers.remote.files.s3",
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
),
),
]

View file

@ -5,15 +5,15 @@
# the root directory of this source tree.
from typing import cast
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
def available_providers() -> list[ProviderSpec]:
return [
InlineProviderSpec(
# We provide two versions of these providers so that distributions can package the appropriate version of torch.
# The CPU version is used for distributions that don't have GPU support -- they result in smaller container images.
torchtune_def = dict(
api=Api.post_training,
provider_type="inline::torchtune",
pip_packages=["torch", "torchtune==0.5.0", "torchao==0.8.0", "numpy"],
pip_packages=["torchtune==0.5.0", "torchao==0.8.0", "numpy"],
module="llama_stack.providers.inline.post_training.torchtune",
config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
api_dependencies=[
@ -21,11 +21,11 @@ def available_providers() -> list[ProviderSpec]:
Api.datasets,
],
description="TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.",
),
InlineProviderSpec(
)
huggingface_def = dict(
api=Api.post_training,
provider_type="inline::huggingface",
pip_packages=["torch", "trl", "transformers", "peft", "datasets"],
pip_packages=["trl", "transformers", "peft", "datasets"],
module="llama_stack.providers.inline.post_training.huggingface",
config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig",
api_dependencies=[
@ -33,6 +33,46 @@ def available_providers() -> list[ProviderSpec]:
Api.datasets,
],
description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.",
)
def available_providers() -> list[ProviderSpec]:
return [
InlineProviderSpec(
**{
**torchtune_def,
"provider_type": "inline::torchtune-cpu",
"pip_packages": (
cast(list[str], torchtune_def["pip_packages"])
+ ["torch torchtune==0.5.0 torchao==0.8.0 --index-url https://download.pytorch.org/whl/cpu"]
),
},
),
InlineProviderSpec(
**{
**huggingface_def,
"provider_type": "inline::huggingface-cpu",
"pip_packages": (
cast(list[str], huggingface_def["pip_packages"])
+ ["torch --index-url https://download.pytorch.org/whl/cpu"]
),
},
),
InlineProviderSpec(
**{
**torchtune_def,
"provider_type": "inline::torchtune-gpu",
"pip_packages": (
cast(list[str], torchtune_def["pip_packages"]) + ["torch torchtune==0.5.0 torchao==0.8.0"]
),
},
),
InlineProviderSpec(
**{
**huggingface_def,
"provider_type": "inline::huggingface-gpu",
"pip_packages": (cast(list[str], huggingface_def["pip_packages"]) + ["torch"]),
},
),
remote_provider_spec(
api=Api.post_training,

View file

@ -0,0 +1,237 @@
# S3 Files Provider
A remote S3-based implementation of the Llama Stack Files API that provides scalable cloud file storage with metadata persistence.
## Features
- **AWS S3 Storage**: Store files in AWS S3 buckets for scalable, durable storage
- **Metadata Management**: Uses SQL database for efficient file metadata queries
- **OpenAI API Compatibility**: Full compatibility with OpenAI Files API endpoints
- **Flexible Authentication**: Support for IAM roles and access keys
- **Custom S3 Endpoints**: Support for MinIO and other S3-compatible services
## Configuration
### Basic Configuration
```yaml
api: files
provider_type: remote::s3
config:
bucket_name: my-llama-stack-files
region: us-east-1
metadata_store:
type: sqlite
db_path: ./s3_files_metadata.db
```
### Advanced Configuration
```yaml
api: files
provider_type: remote::s3
config:
bucket_name: my-llama-stack-files
region: us-east-1
aws_access_key_id: YOUR_ACCESS_KEY
aws_secret_access_key: YOUR_SECRET_KEY
endpoint_url: https://s3.amazonaws.com # Optional for custom endpoints
metadata_store:
type: sqlite
db_path: ./s3_files_metadata.db
```
### Environment Variables
The configuration supports environment variable substitution:
```yaml
config:
bucket_name: "${env.S3_BUCKET_NAME}"
region: "${env.AWS_REGION:=us-east-1}"
aws_access_key_id: "${env.AWS_ACCESS_KEY_ID:=}"
aws_secret_access_key: "${env.AWS_SECRET_ACCESS_KEY:=}"
endpoint_url: "${env.S3_ENDPOINT_URL:=}"
```
Note: `S3_BUCKET_NAME` has no default value since S3 bucket names must be globally unique.
## Authentication
### IAM Roles (Recommended)
For production deployments, use IAM roles:
```yaml
config:
bucket_name: my-bucket
region: us-east-1
# No credentials needed - will use IAM role
```
### Access Keys
For development or specific use cases:
```yaml
config:
bucket_name: my-bucket
region: us-east-1
aws_access_key_id: AKIAIOSFODNN7EXAMPLE
aws_secret_access_key: wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY
```
## S3 Bucket Setup
### Required Permissions
The S3 provider requires the following permissions:
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": [
"s3:GetObject",
"s3:PutObject",
"s3:DeleteObject",
"s3:ListBucket"
],
"Resource": [
"arn:aws:s3:::your-bucket-name",
"arn:aws:s3:::your-bucket-name/*"
]
}
]
}
```
### Automatic Bucket Creation
By default, the S3 provider expects the bucket to already exist. If you want the provider to automatically create the bucket when it doesn't exist, set `auto_create_bucket: true` in your configuration:
```yaml
config:
bucket_name: my-bucket
auto_create_bucket: true # Will create bucket if it doesn't exist
region: us-east-1
```
**Note**: When `auto_create_bucket` is enabled, the provider will need additional permissions:
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": [
"s3:GetObject",
"s3:PutObject",
"s3:DeleteObject",
"s3:ListBucket",
"s3:CreateBucket"
],
"Resource": [
"arn:aws:s3:::your-bucket-name",
"arn:aws:s3:::your-bucket-name/*"
]
}
]
}
```
### Bucket Policy (Optional)
For additional security, you can add a bucket policy:
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Sid": "LlamaStackAccess",
"Effect": "Allow",
"Principal": {
"AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
},
"Action": [
"s3:GetObject",
"s3:PutObject",
"s3:DeleteObject"
],
"Resource": "arn:aws:s3:::your-bucket-name/*"
},
{
"Sid": "LlamaStackBucketAccess",
"Effect": "Allow",
"Principal": {
"AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
},
"Action": [
"s3:ListBucket"
],
"Resource": "arn:aws:s3:::your-bucket-name"
}
]
}
```
## Features
### Metadata Persistence
File metadata is stored in a SQL database for fast queries and OpenAI API compatibility. The metadata includes:
- File ID
- Original filename
- Purpose (assistants, batch, etc.)
- File size in bytes
- Created and expiration timestamps
### TTL and Cleanup
Files currently have a fixed long expiration time (100 years).
## Development and Testing
### Using MinIO
For self-hosted S3-compatible storage:
```yaml
config:
bucket_name: test-bucket
region: us-east-1
endpoint_url: http://localhost:9000
aws_access_key_id: minioadmin
aws_secret_access_key: minioadmin
```
## Monitoring and Logging
The provider logs important operations and errors. For production deployments, consider:
- CloudWatch monitoring for S3 operations
- Custom metrics for file upload/download rates
- Error rate monitoring
- Performance metrics tracking
## Error Handling
The provider handles various error scenarios:
- S3 connectivity issues
- Bucket access permissions
- File not found errors
- Metadata consistency checks
## Known Limitations
- Fixed long TTL (100 years) instead of configurable expiration
- No server-side encryption enabled by default
- No support for AWS session tokens
- No S3 key prefix organization support
- No multipart upload support (all files uploaded as single objects)

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.core.datatypes import Api
from .config import S3FilesImplConfig
async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any]):
from .files import S3FilesImpl
# TODO: authorization policies and user separation
impl = S3FilesImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
class S3FilesImplConfig(BaseModel):
"""Configuration for S3-based files provider."""
bucket_name: str = Field(description="S3 bucket name to store files")
region: str = Field(default="us-east-1", description="AWS region where the bucket is located")
aws_access_key_id: str | None = Field(default=None, description="AWS access key ID (optional if using IAM roles)")
aws_secret_access_key: str | None = Field(
default=None, description="AWS secret access key (optional if using IAM roles)"
)
endpoint_url: str | None = Field(default=None, description="Custom S3 endpoint URL (for MinIO, LocalStack, etc.)")
auto_create_bucket: bool = Field(
default=False, description="Automatically create the S3 bucket if it doesn't exist"
)
metadata_store: SqlStoreConfig = Field(description="SQL store configuration for file metadata")
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"bucket_name": "${env.S3_BUCKET_NAME}", # no default, buckets must be globally unique
"region": "${env.AWS_REGION:=us-east-1}",
"aws_access_key_id": "${env.AWS_ACCESS_KEY_ID:=}",
"aws_secret_access_key": "${env.AWS_SECRET_ACCESS_KEY:=}",
"endpoint_url": "${env.S3_ENDPOINT_URL:=}",
"auto_create_bucket": "${env.S3_AUTO_CREATE_BUCKET:=false}",
"metadata_store": SqliteSqlStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="s3_files_metadata.db",
),
}

View file

@ -0,0 +1,272 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
import uuid
from typing import Annotated
import boto3
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
from fastapi import File, Form, Response, UploadFile
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import (
Files,
ListOpenAIFileResponse,
OpenAIFileDeleteResponse,
OpenAIFileObject,
OpenAIFilePurpose,
)
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl
from .config import S3FilesImplConfig
# TODO: provider data for S3 credentials
def _create_s3_client(config: S3FilesImplConfig) -> boto3.client:
try:
s3_config = {
"region_name": config.region,
}
# endpoint URL if specified (for MinIO, LocalStack, etc.)
if config.endpoint_url:
s3_config["endpoint_url"] = config.endpoint_url
if config.aws_access_key_id and config.aws_secret_access_key:
s3_config.update(
{
"aws_access_key_id": config.aws_access_key_id,
"aws_secret_access_key": config.aws_secret_access_key,
}
)
return boto3.client("s3", **s3_config)
except (BotoCoreError, NoCredentialsError) as e:
raise RuntimeError(f"Failed to initialize S3 client: {e}") from e
async def _create_bucket_if_not_exists(client: boto3.client, config: S3FilesImplConfig) -> None:
try:
client.head_bucket(Bucket=config.bucket_name)
except ClientError as e:
error_code = e.response["Error"]["Code"]
if error_code == "404":
if not config.auto_create_bucket:
raise RuntimeError(
f"S3 bucket '{config.bucket_name}' does not exist. "
f"Either create the bucket manually or set 'auto_create_bucket: true' in your configuration."
) from e
try:
# For us-east-1, we can't specify LocationConstraint
if config.region == "us-east-1":
client.create_bucket(Bucket=config.bucket_name)
else:
client.create_bucket(
Bucket=config.bucket_name,
CreateBucketConfiguration={"LocationConstraint": config.region},
)
except ClientError as create_error:
raise RuntimeError(
f"Failed to create S3 bucket '{config.bucket_name}': {create_error}"
) from create_error
elif error_code == "403":
raise RuntimeError(f"Access denied to S3 bucket '{config.bucket_name}'") from e
else:
raise RuntimeError(f"Failed to access S3 bucket '{config.bucket_name}': {e}") from e
class S3FilesImpl(Files):
"""S3-based implementation of the Files API."""
# TODO: implement expiration, for now a silly offset
_SILLY_EXPIRATION_OFFSET = 100 * 365 * 24 * 60 * 60
def __init__(self, config: S3FilesImplConfig) -> None:
self._config = config
self._client: boto3.client | None = None
self._sql_store: SqlStore | None = None
async def initialize(self) -> None:
self._client = _create_s3_client(self._config)
await _create_bucket_if_not_exists(self._client, self._config)
self._sql_store = sqlstore_impl(self._config.metadata_store)
await self._sql_store.create_table(
"openai_files",
{
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
"filename": ColumnType.STRING,
"purpose": ColumnType.STRING,
"bytes": ColumnType.INTEGER,
"created_at": ColumnType.INTEGER,
"expires_at": ColumnType.INTEGER,
# TODO: add s3_etag field for integrity checking
},
)
async def shutdown(self) -> None:
pass
@property
def client(self) -> boto3.client:
assert self._client is not None, "Provider not initialized"
return self._client
@property
def sql_store(self) -> SqlStore:
assert self._sql_store is not None, "Provider not initialized"
return self._sql_store
async def openai_upload_file(
self,
file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()],
) -> OpenAIFileObject:
file_id = f"file-{uuid.uuid4().hex}"
filename = getattr(file, "filename", None) or "uploaded_file"
created_at = int(time.time())
expires_at = created_at + self._SILLY_EXPIRATION_OFFSET
content = await file.read()
file_size = len(content)
await self.sql_store.insert(
"openai_files",
{
"id": file_id,
"filename": filename,
"purpose": purpose.value,
"bytes": file_size,
"created_at": created_at,
"expires_at": expires_at,
},
)
try:
self.client.put_object(
Bucket=self._config.bucket_name,
Key=file_id,
Body=content,
# TODO: enable server-side encryption
)
except ClientError as e:
await self.sql_store.delete("openai_files", where={"id": file_id})
raise RuntimeError(f"Failed to upload file to S3: {e}") from e
return OpenAIFileObject(
id=file_id,
filename=filename,
purpose=purpose,
bytes=file_size,
created_at=created_at,
expires_at=expires_at,
)
async def openai_list_files(
self,
after: str | None = None,
limit: int | None = 10000,
order: Order | None = Order.desc,
purpose: OpenAIFilePurpose | None = None,
) -> ListOpenAIFileResponse:
# this purely defensive. it should not happen because the router also default to Order.desc.
if not order:
order = Order.desc
where_conditions = {}
if purpose:
where_conditions["purpose"] = purpose.value
paginated_result = await self.sql_store.fetch_all(
table="openai_files",
where=where_conditions if where_conditions else None,
order_by=[("created_at", order.value)],
cursor=("id", after) if after else None,
limit=limit,
)
files = [
OpenAIFileObject(
id=row["id"],
filename=row["filename"],
purpose=OpenAIFilePurpose(row["purpose"]),
bytes=row["bytes"],
created_at=row["created_at"],
expires_at=row["expires_at"],
)
for row in paginated_result.data
]
return ListOpenAIFileResponse(
data=files,
has_more=paginated_result.has_more,
# empty string or None? spec says str, ref impl returns str | None, we go with spec
first_id=files[0].id if files else "",
last_id=files[-1].id if files else "",
)
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "files.list()")
return OpenAIFileObject(
id=row["id"],
filename=row["filename"],
purpose=OpenAIFilePurpose(row["purpose"]),
bytes=row["bytes"],
created_at=row["created_at"],
expires_at=row["expires_at"],
)
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "files.list()")
try:
self.client.delete_object(
Bucket=self._config.bucket_name,
Key=row["id"],
)
except ClientError as e:
if e.response["Error"]["Code"] != "NoSuchKey":
raise RuntimeError(f"Failed to delete file from S3: {e}") from e
await self.sql_store.delete("openai_files", where={"id": file_id})
return OpenAIFileDeleteResponse(id=file_id, deleted=True)
async def openai_retrieve_file_content(self, file_id: str) -> Response:
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "files.list()")
try:
response = self.client.get_object(
Bucket=self._config.bucket_name,
Key=row["id"],
)
# TODO: can we stream this instead of loading it into memory
content = response["Body"].read()
except ClientError as e:
if e.response["Error"]["Code"] == "NoSuchKey":
await self.sql_store.delete("openai_files", where={"id": file_id})
raise ResourceNotFoundError(file_id, "File", "files.list()") from e
raise RuntimeError(f"Failed to download file from S3: {e}") from e
return Response(
content=content,
media_type="application/octet-stream",
headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'},
)

View file

@ -65,7 +65,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import FireworksImplConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::fireworks")
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):

View file

@ -10,7 +10,7 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::llama_openai_compat")
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):

View file

@ -41,6 +41,11 @@ client.initialize()
### Create Completion
> Note on Completion API
>
> The hosted NVIDIA Llama NIMs (e.g., `meta-llama/Llama-3.1-8B-Instruct`) with ```NVIDIA_BASE_URL="https://integrate.api.nvidia.com"``` does not support the ```completion``` method, while the locally deployed NIM does.
```python
response = client.inference.completion(
model_id="meta-llama/Llama-3.1-8B-Instruct",
@ -76,6 +81,73 @@ response = client.inference.chat_completion(
print(f"Response: {response.completion_message.content}")
```
### Tool Calling Example ###
```python
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
tool_definition = ToolDefinition(
tool_name="get_weather",
description="Get current weather information for a location",
parameters={
"location": ToolParamDefinition(
param_type="string",
description="The city and state, e.g. San Francisco, CA",
required=True,
),
"unit": ToolParamDefinition(
param_type="string",
description="Temperature unit (celsius or fahrenheit)",
required=False,
default="celsius",
),
},
)
tool_response = client.inference.chat_completion(
model_id="meta-llama/Llama-3.1-8B-Instruct",
messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
tools=[tool_definition],
)
print(f"Tool Response: {tool_response.completion_message.content}")
if tool_response.completion_message.tool_calls:
for tool_call in tool_response.completion_message.tool_calls:
print(f"Tool Called: {tool_call.tool_name}")
print(f"Arguments: {tool_call.arguments}")
```
### Structured Output Example
```python
from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType
person_schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
"occupation": {"type": "string"},
},
"required": ["name", "age", "occupation"],
}
response_format = JsonSchemaResponseFormat(
type=ResponseFormatType.json_schema, json_schema=person_schema
)
structured_response = client.inference.chat_completion(
model_id="meta-llama/Llama-3.1-8B-Instruct",
messages=[
{
"role": "user",
"content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. ",
}
],
response_format=response_format,
)
print(f"Structured Response: {structured_response.completion_message.content}")
```
### Create Embeddings
> Note on OpenAI embeddings compatibility
>

View file

@ -7,7 +7,7 @@
import warnings
from collections.abc import AsyncIterator
from openai import NOT_GIVEN, APIConnectionError, BadRequestError
from openai import NOT_GIVEN, APIConnectionError
from llama_stack.apis.common.content_types import (
InterleavedContent,
@ -57,7 +57,7 @@ from .openai_utils import (
)
from .utils import _is_nvidia_hosted
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::nvidia")
class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
@ -197,15 +197,11 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
}
extra_body["input_type"] = task_type_options[task_type]
try:
response = await self.client.embeddings.create(
model=provider_model_id,
input=input,
extra_body=extra_body,
)
except BadRequestError as e:
raise ValueError(f"Failed to get embeddings: {e}") from e
#
# OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...)
# ->

View file

@ -10,7 +10,7 @@ from llama_stack.log import get_logger
from . import NVIDIAConfig
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::nvidia")
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:

View file

@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::ollama")
class OllamaInferenceAdapter(
@ -619,28 +619,6 @@ class OllamaInferenceAdapter(
response.id = id
return response
async def batch_completion(
self,
model_id: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch completion is not supported for Ollama")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch chat completion is not supported for Ollama")
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
async def _convert_content(content) -> dict:

View file

@ -11,7 +11,7 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import OpenAIConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::openai")
#

View file

@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
log = get_logger(name=__name__, category="inference")
log = get_logger(name=__name__, category="inference::tgi")
def build_hf_repo_model_entries():

View file

@ -61,7 +61,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import TogetherImplConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::together")
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):

View file

@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import VLLMInferenceAdapterConfig
log = get_logger(name=__name__, category="inference")
log = get_logger(name=__name__, category="inference::vllm")
def build_hf_repo_model_entries():
@ -711,25 +711,3 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
user=user,
)
return await self.client.chat.completions.create(**params) # type: ignore
async def batch_completion(
self,
model_id: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch completion is not supported for Ollama")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch chat completion is not supported for Ollama")

View file

@ -15,7 +15,7 @@ from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefa
from .config import NvidiaPostTrainingConfig
logger = get_logger(name=__name__, category="integration")
logger = get_logger(name=__name__, category="post_training::nvidia")
def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:

View file

@ -21,7 +21,7 @@ from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from .config import BedrockSafetyConfig
logger = get_logger(name=__name__, category="safety")
logger = get_logger(name=__name__, category="safety::bedrock")
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):

View file

@ -9,7 +9,7 @@ from typing import Any
import requests
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.safety import ModerationObject, RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_
from .config import NVIDIASafetyConfig
logger = get_logger(name=__name__, category="safety")
logger = get_logger(name=__name__, category="safety::nvidia")
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
@ -67,6 +67,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
self.shield = NeMoGuardrails(self.config, shield.shield_id)
return await self.shield.run(messages)
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation")
class NeMoGuardrails:
"""

Some files were not shown because too many files have changed in this diff Show more