Merge branch 'main' into responses_object

Emilio Garcia 2025-08-25 16:01:59 -04:00 committed by GitHub
commit 708b2c1b05
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
166 changed files with 6944 additions and 809 deletions

View file

@ -9,6 +9,7 @@ updates:
day: "saturday"
commit-message:
prefix: chore(github-deps)
- package-ecosystem: "uv"
directory: "/"
schedule:
@ -19,3 +20,14 @@ updates:
- python
commit-message:
prefix: chore(python-deps)
- package-ecosystem: npm
directory: "/llama_stack/ui"
schedule:
interval: "weekly"
day: "saturday"
labels:
- type/dependencies
- javascript
commit-message:
prefix: chore(ui-deps)

View file

@ -17,7 +17,7 @@ jobs:
pull-requests: write # for peter-evans/create-pull-request to create a PR
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
ref: main
fetch-depth: 0

View file

@ -16,14 +16,14 @@ jobs:
lint:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2
+ - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # 5.0.0
- name: Run ShellCheck on install.sh
run: shellcheck scripts/install.sh
smoke-test-on-dev:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner

View file

@ -18,7 +18,7 @@ on:
- '.github/workflows/integration-auth-tests.yml' # This workflow
concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
+ group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:
@ -31,7 +31,7 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner

View file

@ -16,7 +16,7 @@ on:
- '.github/workflows/integration-sql-store-tests.yml' # This workflow
concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
+ group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:
@ -44,7 +44,7 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner

View file

@ -65,7 +65,7 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Setup test environment
uses: ./.github/actions/setup-test-environment

View file

@ -33,7 +33,7 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner

View file

@ -8,7 +8,7 @@ on:
branches: [main]
concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
+ group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:
@ -20,7 +20,7 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
# For dependabot PRs, we need to checkout with a token that can push changes
token: ${{ github.actor == 'dependabot[bot]' && secrets.GITHUB_TOKEN || github.token }}
@ -36,20 +36,16 @@ jobs:
**/requirements*.txt
.pre-commit-config.yaml
- # npm ci may fail
- # npm error `npm ci` can only install packages when your package.json and package-lock.json or npm-shrinkwrap.json are in sync. Please update your lock file with `npm install` before continuing.
- # npm error Invalid: lock file's llama-stack-client@0.2.17 does not satisfy llama-stack-client@0.2.18
- # - name: Set up Node.js
- # uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0
- # with:
- # node-version: '20'
- # cache: 'npm'
- # cache-dependency-path: 'llama_stack/ui/'
- # - name: Install npm dependencies
- # run: npm ci
- # working-directory: llama_stack/ui
+ - name: Set up Node.js
+ uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
+ with:
+ node-version: '20'
+ cache: 'npm'
+ cache-dependency-path: 'llama_stack/ui/'
+ - name: Install npm dependencies
+ run: npm ci
+ working-directory: llama_stack/ui
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
continue-on-error: true

View file

@ -26,7 +26,7 @@ on:
- 'pyproject.toml'
concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
+ group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:
@ -36,7 +36,7 @@ jobs:
distros: ${{ steps.set-matrix.outputs.distros }}
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Generate Distribution List
id: set-matrix
@ -55,7 +55,7 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner
@ -79,7 +79,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner
@ -92,7 +92,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner
@ -106,6 +106,10 @@ jobs:
- name: Inspect the container image entrypoint
run: |
IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
if [ -z "$IMAGE_ID" ]; then
echo "No image found"
exit 1
fi
entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
echo "Entrypoint: $entrypoint"
if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then
@ -117,7 +121,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner
@ -140,6 +144,10 @@ jobs:
- name: Inspect UBI9 image
run: |
IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
if [ -z "$IMAGE_ID" ]; then
echo "No image found"
exit 1
fi
entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
echo "Entrypoint: $entrypoint"
if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then

View file

@ -21,10 +21,10 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install uv
- uses: astral-sh/setup-uv@e92bafb6253dcd438e0484186d7669ea7a8ca1cc # v6.4.3
+ uses: astral-sh/setup-uv@4959332f0f014c5280e7eac8b70c90cb574c9f9b # v6.6.0
with:
python-version: ${{ matrix.python-version }}
activate-environment: true

View file

@ -46,7 +46,7 @@ jobs:
echo "::endgroup::"
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
fetch-depth: 0

View file

@ -22,6 +22,6 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Check PR Title's semantic conformance
- uses: amannn/action-semantic-pull-request@0723387faaf9b38adef4775cd42cfd5155ed6017 # v5.5.3
+ uses: amannn/action-semantic-pull-request@48f256284bd46cdaab1048c3721360e808335d50 # v6.1.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View file

@ -27,7 +27,7 @@ jobs:
# container and point 'uv pip install' to the correct path...
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner

View file

@ -27,7 +27,7 @@ jobs:
# container and point 'uv pip install' to the correct path...
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner

View file

@ -13,7 +13,7 @@ on:
workflow_dispatch:
concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
+ group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:
@ -26,10 +26,10 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Setup Node.js
- uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0
+ uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
with:
node-version: ${{ matrix.node-version }}
cache: 'npm'

View file

@ -18,7 +18,7 @@ on:
workflow_dispatch:
concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
+ group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:
@ -32,7 +32,7 @@ jobs:
- "3.13" - "3.13"
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies - name: Install dependencies
uses: ./.github/actions/setup-runner uses: ./.github/actions/setup-runner

View file

@ -27,7 +27,7 @@ on:
- '.github/workflows/update-readthedocs.yml'
concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
+ group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:
@ -37,7 +37,7 @@ jobs:
TOKEN: ${{ secrets.READTHEDOCS_TOKEN }}
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner

View file

@ -146,31 +146,13 @@ repos:
pass_filenames: false
require_serial: true
files: ^.github/workflows/.*$
- # ui-prettier and ui-eslint are disabled until we can avoid `npm ci`, which is slow and may fail
- # npm error `npm ci` can only install packages when your package.json and package-lock.json or npm-shrinkwrap.json are in sync. Please update your lock file with `npm install` before continuing.
- # npm error Invalid: lock file's llama-stack-client@0.2.17 does not satisfy llama-stack-client@0.2.18
- # and until we have infra for installing prettier and next via npm
- # Lint UI code with ESLint.....................................................Failed
- # - hook id: ui-eslint
- # - exit code: 127
- # > ui@0.1.0 lint
- # > next lint --fix --quiet
- # sh: line 1: next: command not found
- #
- # - id: ui-prettier
- # name: Format UI code with Prettier
- # entry: bash -c 'cd llama_stack/ui && npm ci && npm run format'
- # language: system
- # files: ^llama_stack/ui/.*\.(ts|tsx)$
- # pass_filenames: false
- # require_serial: true
- # - id: ui-eslint
- # name: Lint UI code with ESLint
- # entry: bash -c 'cd llama_stack/ui && npm run lint -- --fix --quiet'
- # language: system
- # files: ^llama_stack/ui/.*\.(ts|tsx)$
- # pass_filenames: false
- # require_serial: true
+ - id: ui-linter
+ name: Format & Lint UI
+ entry: bash ./scripts/run-ui-linter.sh
+ language: system
+ files: ^llama_stack/ui/.*\.(ts|tsx)$
+ pass_filenames: false
+ require_serial: true
- id: check-log-usage
name: Ensure 'llama_stack.log' usage for logging

View file

@ -4605,6 +4605,49 @@
}
}
},
"/v1/inference/rerank": {
"post": {
"responses": {
"200": {
"description": "RerankResponse with indices sorted by relevance score (descending).",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/RerankResponse"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Inference"
],
"description": "Rerank a list of documents based on their relevance to a query.",
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/RerankRequest"
}
}
},
"required": true
}
}
},
"/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": {
"post": {
"responses": {
@ -16418,12 +16461,16 @@
"value": {
"type": "number",
"description": "The numeric value of the metric at this timestamp"
},
"unit": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"timestamp",
- "value"
+ "value",
"unit"
],
"title": "MetricDataPoint",
"description": "A single data point in a metric time series."
@ -16981,6 +17028,95 @@
],
"title": "RegisterVectorDbRequest"
},
"RerankRequest": {
"type": "object",
"properties": {
"model": {
"type": "string",
"description": "The identifier of the reranking model to use."
},
"query": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
},
{
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
}
],
"description": "The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length."
},
"items": {
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
},
{
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
}
]
},
"description": "List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length."
},
"max_num_results": {
"type": "integer",
"description": "(Optional) Maximum number of results to return. Default: returns all."
}
},
"additionalProperties": false,
"required": [
"model",
"query",
"items"
],
"title": "RerankRequest"
},
"RerankData": {
"type": "object",
"properties": {
"index": {
"type": "integer",
"description": "The original index of the document in the input list"
},
"relevance_score": {
"type": "number",
"description": "The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance."
}
},
"additionalProperties": false,
"required": [
"index",
"relevance_score"
],
"title": "RerankData",
"description": "A single rerank result from a reranking response."
},
"RerankResponse": {
"type": "object",
"properties": {
"data": {
"type": "array",
"items": {
"$ref": "#/components/schemas/RerankData"
},
"description": "List of rerank result objects, sorted by relevance score (descending)"
}
},
"additionalProperties": false,
"required": [
"data"
],
"title": "RerankResponse",
"description": "Response from a reranking request."
},
"ResumeAgentTurnRequest": {
"type": "object",
"properties": {

View file

@ -3264,6 +3264,37 @@ paths:
schema:
$ref: '#/components/schemas/QueryTracesRequest'
required: true
/v1/inference/rerank:
post:
responses:
'200':
description: >-
RerankResponse with indices sorted by relevance score (descending).
content:
application/json:
schema:
$ref: '#/components/schemas/RerankResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Inference
description: >-
Rerank a list of documents based on their relevance to a query.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RerankRequest'
required: true
/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume:
post:
responses:
@ -12242,10 +12273,13 @@ components:
type: number
description: >-
The numeric value of the metric at this timestamp
unit:
type: string
additionalProperties: false
required:
- timestamp
- value
- unit
title: MetricDataPoint
description: >-
A single data point in a metric time series.
@ -12656,6 +12690,76 @@ components:
- vector_db_id
- embedding_model
title: RegisterVectorDbRequest
RerankRequest:
type: object
properties:
model:
type: string
description: >-
The identifier of the reranking model to use.
query:
oneOf:
- type: string
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
description: >-
The search query to rank items against. Can be a string, text content
part, or image content part. The input must not exceed the model's max
input token length.
items:
type: array
items:
oneOf:
- type: string
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
description: >-
List of items to rerank. Each item can be a string, text content part,
or image content part. Each input must not exceed the model's max input
token length.
max_num_results:
type: integer
description: >-
(Optional) Maximum number of results to return. Default: returns all.
additionalProperties: false
required:
- model
- query
- items
title: RerankRequest
RerankData:
type: object
properties:
index:
type: integer
description: >-
The original index of the document in the input list
relevance_score:
type: number
description: >-
The relevance score from the model output. Values are inverted when applicable
so that higher scores indicate greater relevance.
additionalProperties: false
required:
- index
- relevance_score
title: RerankData
description: >-
A single rerank result from a reranking response.
RerankResponse:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/RerankData'
description: >-
List of rerank result objects, sorted by relevance score (descending)
additionalProperties: false
required:
- data
title: RerankResponse
description: Response from a reranking request.
ResumeAgentTurnRequest:
type: object
properties:

View file

@ -225,8 +225,32 @@ server:
port: 8321 # Port to listen on (default: 8321)
tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS
tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS
cors: true # Optional: Enable CORS (dev mode) or full config object
```
### CORS Configuration
CORS (Cross-Origin Resource Sharing) can be configured in two ways:
**Local development** (allows localhost origins only):
```yaml
server:
cors: true
```
**Explicit configuration** (custom origins and settings):
```yaml
server:
cors:
allow_origins: ["https://myapp.com", "https://app.example.com"]
allow_methods: ["GET", "POST", "PUT", "DELETE"]
allow_headers: ["Content-Type", "Authorization"]
allow_credentials: true
max_age: 3600
```
When `cors: true`, the server enables secure localhost-only access for local development. For production, specify exact origins to maintain security.
### Authentication Configuration
> **Breaking Change (v0.2.14)**: The authentication configuration structure has changed. The previous format with `provider_type` and `config` fields has been replaced with a unified `provider_config` field that includes the `type` field. Update your configuration files accordingly.
@ -618,6 +642,54 @@ Content-Type: application/json
}
```
### CORS Configuration
Configure CORS to allow web browsers to make requests from different domains. Disabled by default.
#### Quick Setup
For development, use the simple boolean flag:
```yaml
server:
cors: true # Auto-enables localhost with any port
```
This automatically allows `http://localhost:*` and `https://localhost:*` with secure defaults.
#### Custom Configuration
For specific origins and full control:
```yaml
server:
cors:
allow_origins: ["https://myapp.com", "https://staging.myapp.com"]
allow_credentials: true
allow_methods: ["GET", "POST", "PUT", "DELETE"]
allow_headers: ["Content-Type", "Authorization"]
allow_origin_regex: "https://.*\\.example\\.com" # Optional regex pattern
expose_headers: ["X-Total-Count"]
max_age: 86400
```
#### Configuration Options
| Field | Description | Default |
| -------------------- | ---------------------------------------------- | ------- |
| `allow_origins` | List of allowed origins. Use `["*"]` for any. | `["*"]` |
| `allow_origin_regex` | Regex pattern for allowed origins (optional). | `None` |
| `allow_methods` | Allowed HTTP methods. | `["*"]` |
| `allow_headers` | Allowed headers. | `["*"]` |
| `allow_credentials` | Allow credentials (cookies, auth headers). | `false` |
| `expose_headers` | Headers exposed to browser. | `[]` |
| `max_age` | Preflight cache time (seconds). | `600` |
**Security Notes**:
- `allow_credentials: true` requires explicit origins (no wildcards)
- `cors: true` enables localhost access only (secure for development)
- For public APIs, always specify exact allowed origins
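The preflight behavior can be sanity-checked from a script. The snippet below is a hedged sketch, assuming a server started with `cors: true` on the default port 8321 and the third-party `requests` package; it is not part of the configuration itself.

```python
# Sketch: exercise CORS preflight against a locally running server.
# Assumes `cors: true` and the default port 8321; adjust as needed.
import requests

resp = requests.options(
    "http://localhost:8321/v1/models",
    headers={
        "Origin": "http://localhost:3000",
        "Access-Control-Request-Method": "GET",
    },
)
# With the dev-mode config, localhost origins should be echoed back.
print(resp.status_code, resp.headers.get("access-control-allow-origin"))
```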
## Extending to handle Safety
Configuring Safety can be a little involved so it is instructive to go through an example.

View file

@ -17,7 +17,6 @@ client = LlamaStackAsLibraryClient(
# provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
)
- client.initialize()
```
This will parse your config and set up any inline implementations and remote clients needed for your implementation.
@ -32,5 +31,4 @@ If you've created a [custom distribution](https://llama-stack.readthedocs.io/en/
```python
client = LlamaStackAsLibraryClient(config_path)
- client.initialize()
```

View file

@ -2,12 +2,15 @@
## Overview
- Protocol for batch processing API operations.
The Batches API enables efficient processing of multiple requests in a single operation,
particularly useful for processing large datasets, batch evaluation workflows, and
cost-effective inference at scale.
The API is designed to allow use of openai client libraries for seamless integration.
This API provides the following extensions:
- idempotent batch creation
Note: This API is currently under active development and may undergo changes.
This section contains documentation for all available providers for the **batches** API.
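Because the surface mirrors the OpenAI Batches API, an OpenAI-compatible client can drive it directly. The following is a hedged sketch; the base URL, API key, and input file id are illustrative assumptions rather than values documented here.

```python
# Sketch: creating a batch through an OpenAI-compatible client.
# base_url, api_key, and the input file id are illustrative assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

batch = client.batches.create(
    input_file_id="file-abc123",        # previously uploaded JSONL of requests
    endpoint="/v1/chat/completions",    # endpoint applied to every request in the batch
    completion_window="24h",
    metadata={"purpose": "nightly-eval"},
)
print(batch.id, batch.status)
```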

View file

@ -10,4 +10,5 @@ This section contains documentation for all available providers for the **files*
:maxdepth: 1
inline_localfs
remote_s3
```

View file

@ -0,0 +1,33 @@
# remote::s3
## Description
AWS S3-based file storage provider for scalable cloud file management with metadata persistence.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `bucket_name` | `<class 'str'>` | No | | S3 bucket name to store files |
| `region` | `<class 'str'>` | No | us-east-1 | AWS region where the bucket is located |
| `aws_access_key_id` | `str \| None` | No | | AWS access key ID (optional if using IAM roles) |
| `aws_secret_access_key` | `str \| None` | No | | AWS secret access key (optional if using IAM roles) |
| `endpoint_url` | `str \| None` | No | | Custom S3 endpoint URL (for MinIO, LocalStack, etc.) |
| `auto_create_bucket` | `<class 'bool'>` | No | False | Automatically create the S3 bucket if it doesn't exist |
| `metadata_store` | `utils.sqlstore.sqlstore.SqliteSqlStoreConfig \| utils.sqlstore.sqlstore.PostgresSqlStoreConfig` | No | sqlite | SQL store configuration for file metadata |
## Sample Configuration
```yaml
bucket_name: ${env.S3_BUCKET_NAME}
region: ${env.AWS_REGION:=us-east-1}
aws_access_key_id: ${env.AWS_ACCESS_KEY_ID:=}
aws_secret_access_key: ${env.AWS_SECRET_ACCESS_KEY:=}
endpoint_url: ${env.S3_ENDPOINT_URL:=}
auto_create_bucket: ${env.S3_AUTO_CREATE_BUCKET:=false}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/s3_files_metadata.db
```

View file

@ -9,7 +9,9 @@ This section contains documentation for all available providers for the **post_t
```{toctree}
:maxdepth: 1
- inline_huggingface
- inline_torchtune
+ inline_huggingface-cpu
+ inline_huggingface-gpu
+ inline_torchtune-cpu
+ inline_torchtune-gpu
remote_nvidia
```

View file

@ -0,0 +1,41 @@
# inline::huggingface-cpu
## Description
HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `device` | `<class 'str'>` | No | cuda | |
| `distributed_backend` | `Literal['fsdp', 'deepspeed']` | No | | |
| `checkpoint_format` | `Literal['full_state', 'huggingface']` | No | huggingface | |
| `chat_template` | `<class 'str'>` | No | `<\|user\|>\n{input}\n<\|assistant\|>\n{output}` | |
| `model_specific_config` | `<class 'dict'>` | No | {'trust_remote_code': True, 'attn_implementation': 'sdpa'} | |
| `max_seq_length` | `<class 'int'>` | No | 2048 | |
| `gradient_checkpointing` | `<class 'bool'>` | No | False | |
| `save_total_limit` | `<class 'int'>` | No | 3 | |
| `logging_steps` | `<class 'int'>` | No | 10 | |
| `warmup_ratio` | `<class 'float'>` | No | 0.1 | |
| `weight_decay` | `<class 'float'>` | No | 0.01 | |
| `dataloader_num_workers` | `<class 'int'>` | No | 4 | |
| `dataloader_pin_memory` | `<class 'bool'>` | No | True | |
| `dpo_beta` | `<class 'float'>` | No | 0.1 | |
| `use_reference_model` | `<class 'bool'>` | No | True | |
| `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair']` | No | sigmoid | |
| `dpo_output_dir` | `<class 'str'>` | No | | |
## Sample Configuration
```yaml
checkpoint_format: huggingface
distributed_backend: null
device: cpu
dpo_output_dir: ~/.llama/dummy/dpo_output
```

View file

@ -0,0 +1,41 @@
# inline::huggingface-gpu
## Description
HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `device` | `<class 'str'>` | No | cuda | |
| `distributed_backend` | `Literal['fsdp', 'deepspeed']` | No | | |
| `checkpoint_format` | `Literal['full_state', 'huggingface']` | No | huggingface | |
| `chat_template` | `<class 'str'>` | No | `<\|user\|>\n{input}\n<\|assistant\|>\n{output}` | |
| `model_specific_config` | `<class 'dict'>` | No | {'trust_remote_code': True, 'attn_implementation': 'sdpa'} | |
| `max_seq_length` | `<class 'int'>` | No | 2048 | |
| `gradient_checkpointing` | `<class 'bool'>` | No | False | |
| `save_total_limit` | `<class 'int'>` | No | 3 | |
| `logging_steps` | `<class 'int'>` | No | 10 | |
| `warmup_ratio` | `<class 'float'>` | No | 0.1 | |
| `weight_decay` | `<class 'float'>` | No | 0.01 | |
| `dataloader_num_workers` | `<class 'int'>` | No | 4 | |
| `dataloader_pin_memory` | `<class 'bool'>` | No | True | |
| `dpo_beta` | `<class 'float'>` | No | 0.1 | |
| `use_reference_model` | `<class 'bool'>` | No | True | |
| `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair']` | No | sigmoid | |
| `dpo_output_dir` | `<class 'str'>` | No | | |
## Sample Configuration
```yaml
checkpoint_format: huggingface
distributed_backend: null
device: cpu
dpo_output_dir: ~/.llama/dummy/dpo_output
```

View file

@ -0,0 +1,20 @@
# inline::torchtune-cpu
## Description
TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `torch_seed` | `int \| None` | No | | |
| `checkpoint_format` | `Literal['meta', 'huggingface']` | No | meta | |
## Sample Configuration
```yaml
checkpoint_format: meta
```

View file

@ -0,0 +1,20 @@
# inline::torchtune-gpu
## Description
TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `torch_seed` | `int \| None` | No | | |
| `checkpoint_format` | `Literal['meta', 'huggingface']` | No | meta | |
## Sample Configuration
```yaml
checkpoint_format: meta
```

View file

@ -29,12 +29,16 @@ class ListBatchesResponse(BaseModel):
@runtime_checkable
class Batches(Protocol):
- """Protocol for batch processing API operations.
+ """
The Batches API enables efficient processing of multiple requests in a single operation,
particularly useful for processing large datasets, batch evaluation workflows, and
cost-effective inference at scale.
The API is designed to allow use of openai client libraries for seamless integration.
This API provides the following extensions:
- idempotent batch creation
Note: This API is currently under active development and may undergo changes.
"""
@ -45,6 +49,7 @@ class Batches(Protocol):
endpoint: str,
completion_window: Literal["24h"],
metadata: dict[str, str] | None = None,
idempotency_key: str | None = None,
) -> BatchObject:
"""Create a new batch for processing multiple API requests.
@ -52,6 +57,7 @@ class Batches(Protocol):
:param endpoint: The endpoint to be used for all requests in the batch.
:param completion_window: The time window within which the batch should be processed.
:param metadata: Optional metadata for the batch.
:param idempotency_key: Optional idempotency key. When provided, enables idempotent behavior.
:returns: The created batch object.
"""
...
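As an illustration of the new parameter, a provider-agnostic caller might rely on it as in the hedged sketch below; `batches_impl` and the file id are hypothetical placeholders, and resolving to the same batch on retry is the behavior the idempotency extension is expected to provide.

```python
# Sketch: idempotent batch creation through any Batches implementation.
# `batches_impl` and the file id are hypothetical placeholders.
async def create_nightly_batch(batches_impl):
    kwargs = dict(
        input_file_id="file-abc123",
        endpoint="/v1/chat/completions",
        completion_window="24h",
        metadata={"job": "nightly-eval"},
        idempotency_key="nightly-eval-2025-08-25",
    )
    first = await batches_impl.create_batch(**kwargs)
    retry = await batches_impl.create_batch(**kwargs)
    # With the same idempotency_key, the retry should resolve to the same batch.
    assert retry.id == first.id
    return first
```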

View file

@ -473,6 +473,28 @@ class EmbeddingsResponse(BaseModel):
embeddings: list[list[float]]
@json_schema_type
class RerankData(BaseModel):
"""A single rerank result from a reranking response.
:param index: The original index of the document in the input list
:param relevance_score: The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance.
"""
index: int
relevance_score: float
@json_schema_type
class RerankResponse(BaseModel):
"""Response from a reranking request.
:param data: List of rerank result objects, sorted by relevance score (descending)
"""
data: list[RerankData]
@json_schema_type
class OpenAIChatCompletionContentPartTextParam(BaseModel):
"""Text content part for OpenAI-compatible chat completion messages.
@ -1046,6 +1068,7 @@ class InferenceProvider(Protocol):
:returns: A BatchCompletionResponse with the full completions.
"""
raise NotImplementedError("Batch completion is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
@webmethod(route="/inference/chat-completion", method="POST")
async def chat_completion(
@ -1110,6 +1133,7 @@ class InferenceProvider(Protocol):
:returns: A BatchChatCompletionResponse with the full completions.
"""
raise NotImplementedError("Batch chat completion is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
@webmethod(route="/inference/embeddings", method="POST")
async def embeddings(
@ -1131,6 +1155,25 @@ class InferenceProvider(Protocol):
""" """
... ...
@webmethod(route="/inference/rerank", method="POST", experimental=True)
async def rerank(
self,
model: str,
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
max_num_results: int | None = None,
) -> RerankResponse:
"""Rerank a list of documents based on their relevance to a query.
:param model: The identifier of the reranking model to use.
:param query: The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length.
:param items: List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length.
:param max_num_results: (Optional) Maximum number of results to return. Default: returns all.
:returns: RerankResponse with indices sorted by relevance score (descending).
"""
raise NotImplementedError("Reranking is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
@webmethod(route="/openai/v1/completions", method="POST")
async def openai_completion(
self,
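For orientation, here is a hedged usage sketch of the new experimental method; the inference handle, model id, and documents are assumptions, not values from this change.

```python
# Sketch: reranking a few documents against a query.
# `inference_impl` and the model id are hypothetical placeholders.
async def rerank_docs(inference_impl):
    response = await inference_impl.rerank(
        model="example/reranker-model",
        query="How do I enable CORS for local development?",
        items=[
            "Set `cors: true` in the server config for localhost access.",
            "The Batches API supports idempotent batch creation.",
            "List exact origins under allow_origins for production.",
        ],
        max_num_results=2,
    )
    # Results come back sorted by relevance score, highest first.
    for result in response.data:
        print(result.index, result.relevance_score)
```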

View file

@ -386,6 +386,7 @@ class MetricDataPoint(BaseModel):
timestamp: int
value: float
unit: str
@json_schema_type
@ -518,7 +519,7 @@ class Telemetry(Protocol):
metric_name: str,
start_time: int,
end_time: int | None = None,
- granularity: str | None = "1d",
+ granularity: str | None = None,
query_type: MetricQueryType = MetricQueryType.RANGE,
label_matchers: list[MetricLabelMatcher] | None = None,
) -> QueryMetricsResponse:
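A hedged sketch of calling the revised signature; the telemetry handle and metric name are hypothetical, and only the parameters shown above are used.

```python
# Sketch: querying raw metric points with the new default (granularity=None).
# `telemetry_impl` and the metric name are hypothetical placeholders.
import time

async def last_hour_of(telemetry_impl, metric_name="total_requests"):
    now = int(time.time())
    response = await telemetry_impl.query_metrics(
        metric_name=metric_name,
        start_time=now - 3600,
        end_time=now,
        # granularity defaults to None now, i.e. no forced "1d" bucketing
    )
    # The response wraps MetricDataPoint entries, each carrying
    # timestamp, value, and the newly required unit.
    return response
```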

View file

@ -15,7 +15,7 @@ from llama_stack.log import get_logger
REPO_ROOT = Path(__file__).parent.parent.parent.parent
- logger = get_logger(name=__name__, category="server")
+ logger = get_logger(name=__name__, category="cli")
class StackRun(Subcommand):

View file

@ -318,6 +318,41 @@ class QuotaConfig(BaseModel):
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
class CORSConfig(BaseModel):
allow_origins: list[str] = Field(default_factory=list)
allow_origin_regex: str | None = Field(default=None)
allow_methods: list[str] = Field(default=["OPTIONS"])
allow_headers: list[str] = Field(default_factory=list)
allow_credentials: bool = Field(default=False)
expose_headers: list[str] = Field(default_factory=list)
max_age: int = Field(default=600, ge=0)
@model_validator(mode="after")
def validate_credentials_config(self) -> Self:
if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins):
raise ValueError("Cannot use wildcard origins with credentials enabled")
return self
def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | None:
if cors_config is False or cors_config is None:
return None
if cors_config is True:
# dev mode: allow localhost on any port
return CORSConfig(
allow_origins=[],
allow_origin_regex=r"https?://localhost:\d+",
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
)
if isinstance(cors_config, CORSConfig):
return cors_config
raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}")
class ServerConfig(BaseModel):
port: int = Field(
default=8321,
@ -349,6 +384,12 @@ class ServerConfig(BaseModel):
default=None,
description="Per client quota request configuration",
)
cors: bool | CORSConfig | None = Field(
default=None,
description="CORS configuration for cross-origin requests. Can be:\n"
"- true: Enable localhost CORS for development\n"
"- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
)
class StackRunConfig(BaseModel):
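To make the two accepted shapes concrete, here is a small hedged sketch using only the helpers added above, imported from `llama_stack.core.datatypes` as the server does:

```python
# Sketch: how the boolean and object forms of `cors` resolve.
from llama_stack.core.datatypes import CORSConfig, process_cors_config

dev = process_cors_config(True)
# Dev mode: no fixed origin list, a localhost-on-any-port regex instead.
assert dev is not None and dev.allow_origin_regex == r"https?://localhost:\d+"

explicit = process_cors_config(
    CORSConfig(
        allow_origins=["https://myapp.com"],
        allow_credentials=True,  # fine: explicit origin, not a wildcard
        max_age=3600,
    )
)
assert explicit is not None and explicit.allow_origins == ["https://myapp.com"]

assert process_cors_config(None) is None  # CORS stays disabled unless configured
```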

View file

@ -146,39 +146,26 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
):
super().__init__()
self.async_client = AsyncLlamaStackAsLibraryClient(
- config_path_or_distro_name, custom_provider_registry, provider_data
+ config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
)
self.pool_executor = ThreadPoolExecutor(max_workers=4)
- self.skip_logger_removal = skip_logger_removal
self.provider_data = provider_data
self.loop = asyncio.new_event_loop()
- def initialize(self):
- if in_notebook():
- import nest_asyncio
- nest_asyncio.apply()
- if not self.skip_logger_removal:
- self._remove_root_logger_handlers()
# use a new event loop to avoid interfering with the main event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
- return loop.run_until_complete(self.async_client.initialize())
+ loop.run_until_complete(self.async_client.initialize())
finally:
asyncio.set_event_loop(None)
- def _remove_root_logger_handlers(self):
+ def initialize(self):
"""
- Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
+ Deprecated method for backward compatibility.
"""
- root_logger = logging.getLogger()
+ pass
- for handler in root_logger.handlers[:]:
- root_logger.removeHandler(handler)
- logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
def request(self, *args, **kwargs):
loop = self.loop
@ -216,6 +203,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
config_path_or_distro_name: str,
custom_provider_registry: ProviderRegistry | None = None,
provider_data: dict[str, Any] | None = None,
skip_logger_removal: bool = False,
):
super().__init__()
# when using the library client, we should not log to console since many
@ -223,6 +211,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
if in_notebook():
import nest_asyncio
nest_asyncio.apply()
if not skip_logger_removal:
self._remove_root_logger_handlers()
if config_path_or_distro_name.endswith(".yaml"):
config_path = Path(config_path_or_distro_name)
if not config_path.exists():
@ -239,7 +234,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
self.provider_data = provider_data
self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError
def _remove_root_logger_handlers(self):
"""
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
"""
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
async def initialize(self) -> bool:
"""
Initialize the async client.
Returns:
bool: True if initialization was successful
"""
try:
self.route_impls = None
self.impls = await construct_stack(self.config, self.custom_provider_registry)
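A hedged usage sketch of the reworked construction path; the distribution name and import path are illustrative assumptions. Setup now happens in the constructor, and `initialize()` is kept only as a deprecated backward-compatibility no-op.

```python
# Sketch: the constructor now performs setup; initialize() is a deprecated no-op.
# The distribution name and import path below are illustrative assumptions.
from llama_stack.core.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient(
    "starter",                  # illustrative distro name
    skip_logger_removal=True,   # keep existing root-logger handlers, e.g. in tests
)
client.initialize()  # still accepted for backward compatibility; does nothing
```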

View file

@ -12,7 +12,7 @@ from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
- logger = get_logger(name=__name__, category="core")
+ logger = get_logger(name=__name__, category="core::routers")
class DatasetIORouter(DatasetIO):

View file

@ -16,7 +16,7 @@ from llama_stack.apis.scoring import (
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
- logger = get_logger(name=__name__, category="core")
+ logger = get_logger(name=__name__, category="core::routers")
class ScoringRouter(Scoring):

View file

@ -65,7 +65,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus, Routin
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.telemetry.tracing import get_current_span
- logger = get_logger(name=__name__, category="inference")
+ logger = get_logger(name=__name__, category="core::routers")
class InferenceRouter(Inference):

View file

@ -13,7 +13,7 @@ from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
- logger = get_logger(name=__name__, category="core")
+ logger = get_logger(name=__name__, category="core::routers")
class SafetyRouter(Safety):

View file

@ -22,7 +22,7 @@ from llama_stack.log import get_logger
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
- logger = get_logger(name=__name__, category="core")
+ logger = get_logger(name=__name__, category="core::routers")
class ToolRuntimeRouter(ToolRuntime):

View file

@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import (
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
- logger = get_logger(name=__name__, category="core")
+ logger = get_logger(name=__name__, category="core::routers")
class VectorIORouter(VectorIO):

View file

@ -14,7 +14,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
- logger = get_logger(name=__name__, category="core")
+ logger = get_logger(name=__name__, category="core::routing_tables")
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):

View file

@ -23,7 +23,7 @@ from llama_stack.core.store import DistributionRegistry
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, RoutingTable
- logger = get_logger(name=__name__, category="core")
+ logger = get_logger(name=__name__, category="core::routing_tables")
def get_impl_api(p: Any) -> Api:

View file

@ -26,7 +26,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
- logger = get_logger(name=__name__, category="core")
+ logger = get_logger(name=__name__, category="core::routing_tables")
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):

View file

@ -17,7 +17,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
- logger = get_logger(name=__name__, category="core")
+ logger = get_logger(name=__name__, category="core::routing_tables")
class ModelsRoutingTable(CommonRoutingTableImpl, Models):

View file

@ -19,7 +19,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
- logger = get_logger(name=__name__, category="core")
+ logger = get_logger(name=__name__, category="core::routing_tables")
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):

View file

@ -15,7 +15,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
- logger = get_logger(name=__name__, category="core")
+ logger = get_logger(name=__name__, category="core::routing_tables")
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):

View file

@ -14,7 +14,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
- logger = get_logger(name=__name__, category="core")
+ logger = get_logger(name=__name__, category="core::routing_tables")
def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:

View file

@ -30,7 +30,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
- logger = get_logger(name=__name__, category="core")
+ logger = get_logger(name=__name__, category="core::routing_tables")
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):

View file

@ -15,7 +15,7 @@ from llama_stack.core.server.auth_providers import create_auth_provider
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
from llama_stack.log import get_logger
- logger = get_logger(name=__name__, category="auth")
+ logger = get_logger(name=__name__, category="core::auth")
class AuthenticationMiddleware:

View file

@ -23,7 +23,7 @@ from llama_stack.core.datatypes import (
)
from llama_stack.log import get_logger
- logger = get_logger(name=__name__, category="auth")
+ logger = get_logger(name=__name__, category="core::auth")
class AuthResponse(BaseModel):

View file

@ -15,7 +15,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
- logger = get_logger(name=__name__, category="quota")
+ logger = get_logger(name=__name__, category="core::server")
class QuotaMiddleware:

View file

@ -28,6 +28,7 @@ from aiohttp import hdrs
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError
@ -40,6 +41,7 @@ from llama_stack.core.datatypes import (
AuthenticationRequiredError, AuthenticationRequiredError,
LoggingConfig, LoggingConfig,
StackRunConfig, StackRunConfig,
process_cors_config,
) )
from llama_stack.core.distribution import builtin_automatically_routed_apis from llama_stack.core.distribution import builtin_automatically_routed_apis
from llama_stack.core.external import ExternalApiSpec, load_external_apis from llama_stack.core.external import ExternalApiSpec, load_external_apis
@ -82,7 +84,7 @@ from .quota import QuotaMiddleware
REPO_ROOT = Path(__file__).parent.parent.parent.parent REPO_ROOT = Path(__file__).parent.parent.parent.parent
logger = get_logger(name=__name__, category="server") logger = get_logger(name=__name__, category="core::server")
def warn_with_traceback(message, category, filename, lineno, file=None, line=None): def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@ -413,7 +415,7 @@ def main(args: argparse.Namespace | None = None):
config_contents = yaml.safe_load(fp) config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")): if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg) logger_config = LoggingConfig(**cfg)
logger = get_logger(name=__name__, category="server", config=logger_config) logger = get_logger(name=__name__, category="core::server", config=logger_config)
if args.env: if args.env:
for env_pair in args.env: for env_pair in args.env:
try: try:
@ -483,6 +485,12 @@ def main(args: argparse.Namespace | None = None):
window_seconds=window_seconds, window_seconds=window_seconds,
) )
if config.server.cors:
logger.info("Enabling CORS")
cors_config = process_cors_config(config.server.cors)
if cors_config:
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
if Api.telemetry in impls: if Api.telemetry in impls:
setup_logger(impls[Api.telemetry]) setup_logger(impls[Api.telemetry])
else: else:
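For reference, the wiring added above reads `config.server.cors`, passes it through `process_cors_config`, and splats the result into FastAPI's `CORSMiddleware`. A minimal run.yaml sketch is below; the field names under `cors` are assumptions that mirror the `CORSMiddleware` keyword arguments, and the authoritative schema is whatever `process_cors_config` accepts.

```yaml
# Hypothetical run.yaml snippet -- the keys under `cors` are assumed to mirror
# CORSMiddleware kwargs (allow_origins, allow_methods, allow_headers, ...).
server:
  port: 8321
  cors:
    allow_origins:
      - "http://localhost:3000"
    allow_methods: ["*"]
    allow_headers: ["*"]
```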

View file

@ -16,7 +16,7 @@ from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
logger = get_logger(__name__, category="core") logger = get_logger(__name__, category="core::registry")
class DistributionRegistry(Protocol): class DistributionRegistry(Protocol):

View file

@ -10,7 +10,7 @@ from pathlib import Path
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="config_resolution") logger = get_logger(name=__name__, category="core")
DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions" DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions"

View file

@ -34,7 +34,7 @@ distribution_spec:
telemetry: telemetry:
- provider_type: inline::meta-reference - provider_type: inline::meta-reference
post_training: post_training:
- provider_type: inline::huggingface - provider_type: inline::huggingface-cpu
eval: eval:
- provider_type: inline::meta-reference - provider_type: inline::meta-reference
datasetio: datasetio:

View file

@ -156,8 +156,8 @@ providers:
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/trace_store.db sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=} otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
post_training: post_training:
- provider_id: huggingface - provider_id: huggingface-cpu
provider_type: inline::huggingface provider_type: inline::huggingface-cpu
config: config:
checkpoint_format: huggingface checkpoint_format: huggingface
distributed_backend: null distributed_backend: null

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .starter_gpu import get_distribution_template # noqa: F401

View file

@ -0,0 +1,59 @@
version: 2
distribution_spec:
description: Quick start template for running Llama Stack with several popular providers.
This distribution is intended for GPU-enabled environments.
providers:
inference:
- provider_type: remote::cerebras
- provider_type: remote::ollama
- provider_type: remote::vllm
- provider_type: remote::tgi
- provider_type: remote::fireworks
- provider_type: remote::together
- provider_type: remote::bedrock
- provider_type: remote::nvidia
- provider_type: remote::openai
- provider_type: remote::anthropic
- provider_type: remote::gemini
- provider_type: remote::vertexai
- provider_type: remote::groq
- provider_type: remote::sambanova
- provider_type: inline::sentence-transformers
vector_io:
- provider_type: inline::faiss
- provider_type: inline::sqlite-vec
- provider_type: inline::milvus
- provider_type: remote::chromadb
- provider_type: remote::pgvector
files:
- provider_type: inline::localfs
safety:
- provider_type: inline::llama-guard
- provider_type: inline::code-scanner
agents:
- provider_type: inline::meta-reference
telemetry:
- provider_type: inline::meta-reference
post_training:
- provider_type: inline::torchtune-gpu
eval:
- provider_type: inline::meta-reference
datasetio:
- provider_type: remote::huggingface
- provider_type: inline::localfs
scoring:
- provider_type: inline::basic
- provider_type: inline::llm-as-judge
- provider_type: inline::braintrust
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
batches:
- provider_type: inline::reference
image_type: venv
additional_pip_packages:
- aiosqlite
- asyncpg
- sqlalchemy[asyncio]

View file

@ -0,0 +1,238 @@
version: 2
image_name: starter-gpu
apis:
- agents
- batches
- datasetio
- eval
- files
- inference
- post_training
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
base_url: https://api.cerebras.ai
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: openai
provider_type: remote::openai
config:
api_key: ${env.OPENAI_API_KEY:=}
base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
- provider_id: anthropic
provider_type: remote::anthropic
config:
api_key: ${env.ANTHROPIC_API_KEY:=}
- provider_id: gemini
provider_type: remote::gemini
config:
api_key: ${env.GEMINI_API_KEY:=}
- provider_id: ${env.VERTEX_AI_PROJECT:+vertexai}
provider_type: remote::vertexai
config:
project: ${env.VERTEX_AI_PROJECT:=}
location: ${env.VERTEX_AI_LOCATION:=us-central1}
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db
- provider_id: sqlite-vec
provider_type: inline::sqlite-vec
config:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db
- provider_id: ${env.MILVUS_URL:+milvus}
provider_type: inline::milvus
config:
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db
- provider_id: ${env.CHROMADB_URL:+chromadb}
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db
- provider_id: ${env.PGVECTOR_DB:+pgvector}
provider_type: remote::pgvector
config:
host: ${env.PGVECTOR_HOST:=localhost}
port: ${env.PGVECTOR_PORT:=5432}
db: ${env.PGVECTOR_DB:=}
user: ${env.PGVECTOR_USER:=}
password: ${env.PGVECTOR_PASSWORD:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
- provider_id: code-scanner
provider_type: inline::code-scanner
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/agents_store.db
responses_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/responses_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
post_training:
- provider_id: torchtune-gpu
provider_type: inline::torchtune-gpu
config:
checkpoint_format: meta
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/meta_reference_eval.db
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/huggingface_datasetio.db
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/localfs_datasetio.db
scoring:
- provider_id: basic
provider_type: inline::basic
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:=}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
batches:
- provider_id: reference
provider_type: inline::reference
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/batches.db
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/registry.db
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/inference_store.db
models: []
shields:
- shield_id: llama-guard
provider_id: ${env.SAFETY_MODEL:+llama-guard}
provider_shield_id: ${env.SAFETY_MODEL:=}
- shield_id: code-scanner
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321

View file

@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.distributions.template import BuildProvider, DistributionTemplate
from ..starter.starter import get_distribution_template as get_starter_distribution_template
def get_distribution_template() -> DistributionTemplate:
template = get_starter_distribution_template()
name = "starter-gpu"
template.name = name
template.description = "Quick start template for running Llama Stack with several popular providers. This distribution is intended for GPU-enabled environments."
template.providers["post_training"] = [
BuildProvider(provider_type="inline::torchtune-gpu"),
]
return template

View file

@ -1,6 +1,7 @@
version: 2 version: 2
distribution_spec: distribution_spec:
description: Quick start template for running Llama Stack with several popular providers description: Quick start template for running Llama Stack with several popular providers.
This distribution is intended for CPU-only environments.
providers: providers:
inference: inference:
- provider_type: remote::cerebras - provider_type: remote::cerebras
@ -34,7 +35,7 @@ distribution_spec:
telemetry: telemetry:
- provider_type: inline::meta-reference - provider_type: inline::meta-reference
post_training: post_training:
- provider_type: inline::huggingface - provider_type: inline::huggingface-cpu
eval: eval:
- provider_type: inline::meta-reference - provider_type: inline::meta-reference
datasetio: datasetio:

View file

@ -156,8 +156,8 @@ providers:
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/trace_store.db sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=} otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
post_training: post_training:
- provider_id: huggingface - provider_id: huggingface-cpu
provider_type: inline::huggingface provider_type: inline::huggingface-cpu
config: config:
checkpoint_format: huggingface checkpoint_format: huggingface
distributed_backend: null distributed_backend: null

View file

@ -120,7 +120,7 @@ def get_distribution_template() -> DistributionTemplate:
], ],
"agents": [BuildProvider(provider_type="inline::meta-reference")], "agents": [BuildProvider(provider_type="inline::meta-reference")],
"telemetry": [BuildProvider(provider_type="inline::meta-reference")], "telemetry": [BuildProvider(provider_type="inline::meta-reference")],
"post_training": [BuildProvider(provider_type="inline::huggingface")], "post_training": [BuildProvider(provider_type="inline::huggingface-cpu")],
"eval": [BuildProvider(provider_type="inline::meta-reference")], "eval": [BuildProvider(provider_type="inline::meta-reference")],
"datasetio": [ "datasetio": [
BuildProvider(provider_type="remote::huggingface"), BuildProvider(provider_type="remote::huggingface"),
@ -178,7 +178,7 @@ def get_distribution_template() -> DistributionTemplate:
return DistributionTemplate( return DistributionTemplate(
name=name, name=name,
distro_type="self_hosted", distro_type="self_hosted",
description="Quick start template for running Llama Stack with several popular providers", description="Quick start template for running Llama Stack with several popular providers. This distribution is intended for CPU-only environments.",
container_image=None, container_image=None,
template_path=None, template_path=None,
providers=providers, providers=providers,

View file

@ -36,7 +36,7 @@ from .utils import get_negative_inf_value, to_2tuple
MP_SCALE = 8 MP_SCALE = 8
logger = get_logger(name=__name__, category="models") logger = get_logger(name=__name__, category="models::llama")
def reduce_from_tensor_model_parallel_region(input_): def reduce_from_tensor_model_parallel_region(input_):

View file

@ -11,7 +11,7 @@ from llama_stack.log import get_logger
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
logger = get_logger(name=__name__, category="inference") logger = get_logger(name=__name__, category="models::llama")
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)' BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})") CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")

View file

@ -18,7 +18,7 @@ from ...datatypes import QuantizationMode
from ..model import Transformer, TransformerBlock from ..model import Transformer, TransformerBlock
from ..moe import MoE from ..moe import MoE
log = get_logger(name=__name__, category="models") log = get_logger(name=__name__, category="models::llama")
def swiglu_wrapper_no_reduce( def swiglu_wrapper_no_reduce(

View file

@ -9,7 +9,7 @@ import collections
from llama_stack.log import get_logger from llama_stack.log import get_logger
log = get_logger(name=__name__, category="llama") log = get_logger(name=__name__, category="models::llama")
try: try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401 import fbgemm_gpu.experimental.gen_ai # noqa: F401

View file

@ -84,7 +84,7 @@ MEMORY_QUERY_TOOL = "knowledge_search"
WEB_SEARCH_TOOL = "web_search" WEB_SEARCH_TOOL = "web_search"
RAG_TOOL_GROUP = "builtin::rag" RAG_TOOL_GROUP = "builtin::rag"
logger = get_logger(name=__name__, category="agents") logger = get_logger(name=__name__, category="agents::meta_reference")
class ChatAgent(ShieldRunnerMixin): class ChatAgent(ShieldRunnerMixin):

View file

@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig
from .persistence import AgentInfo from .persistence import AgentInfo
from .responses.openai_responses import OpenAIResponsesImpl from .responses.openai_responses import OpenAIResponsesImpl
logger = get_logger(name=__name__, category="agents") logger = get_logger(name=__name__, category="agents::meta_reference")
class MetaReferenceAgentsImpl(Agents): class MetaReferenceAgentsImpl(Agents):

View file

@ -17,7 +17,7 @@ from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.kvstore import KVStore
log = get_logger(name=__name__, category="agents") log = get_logger(name=__name__, category="agents::meta_reference")
class AgentSessionInfo(Session): class AgentSessionInfo(Session):

View file

@ -41,7 +41,7 @@ from .utils import (
convert_response_text_to_chat_response_format, convert_response_text_to_chat_response_format,
) )
logger = get_logger(name=__name__, category="responses") logger = get_logger(name=__name__, category="openai::responses")
class OpenAIResponsePreviousResponseWithInputItems(BaseModel): class OpenAIResponsePreviousResponseWithInputItems(BaseModel):

View file

@ -47,7 +47,7 @@ from llama_stack.log import get_logger
from .types import ChatCompletionContext, ChatCompletionResult from .types import ChatCompletionContext, ChatCompletionResult
from .utils import convert_chat_choice_to_response_message, is_function_tool_call from .utils import convert_chat_choice_to_response_message, is_function_tool_call
logger = get_logger(name=__name__, category="responses") logger = get_logger(name=__name__, category="agents::meta_reference")
class StreamingResponseOrchestrator: class StreamingResponseOrchestrator:

View file

@ -38,7 +38,7 @@ from llama_stack.log import get_logger
from .types import ChatCompletionContext, ToolExecutionResult from .types import ChatCompletionContext, ToolExecutionResult
logger = get_logger(name=__name__, category="responses") logger = get_logger(name=__name__, category="agents::meta_reference")
class ToolExecutor: class ToolExecutor:

View file

@ -17,6 +17,8 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseOutputMessageContent, OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText, OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFunctionToolCall, OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseText, OpenAIResponseText,
) )
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
@ -99,14 +101,22 @@ async def convert_response_input_to_chat_messages(
""" """
messages: list[OpenAIMessageParam] = [] messages: list[OpenAIMessageParam] = []
if isinstance(input, list): if isinstance(input, list):
# extract all OpenAIResponseInputFunctionToolCallOutput items
# so their corresponding OpenAIToolMessageParam instances can
# be added immediately following the corresponding
# OpenAIAssistantMessageParam
tool_call_results = {}
for input_item in input: for input_item in input:
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput): if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
messages.append( tool_call_results[input_item.call_id] = OpenAIToolMessageParam(
OpenAIToolMessageParam( content=input_item.output,
content=input_item.output, tool_call_id=input_item.call_id,
tool_call_id=input_item.call_id,
)
) )
for input_item in input:
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
# skip as these have been extracted and inserted in order
pass
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall): elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
tool_call = OpenAIChatCompletionToolCall( tool_call = OpenAIChatCompletionToolCall(
index=0, index=0,
@ -117,6 +127,28 @@ async def convert_response_input_to_chat_messages(
), ),
) )
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call])) messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
if input_item.call_id in tool_call_results:
messages.append(tool_call_results[input_item.call_id])
del tool_call_results[input_item.call_id]
elif isinstance(input_item, OpenAIResponseOutputMessageMCPCall):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id=input_item.id,
function=OpenAIChatCompletionToolCallFunction(
name=input_item.name,
arguments=input_item.arguments,
),
)
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
messages.append(
OpenAIToolMessageParam(
content=input_item.output,
tool_call_id=input_item.id,
)
)
elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools):
# the tool list will be handled separately
pass
else: else:
content = await convert_response_content_to_chat_content(input_item.content) content = await convert_response_content_to_chat_content(input_item.content)
message_type = await get_message_type_by_role(input_item.role) message_type = await get_message_type_by_role(input_item.role)
@ -125,6 +157,10 @@ async def convert_response_input_to_chat_messages(
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context" f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
) )
messages.append(message_type(content=content)) messages.append(message_type(content=content))
if len(tool_call_results):
raise ValueError(
f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
)
else: else:
messages.append(OpenAIUserMessageParam(content=input)) messages.append(OpenAIUserMessageParam(content=input))
return messages return messages
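The reordering above is easier to see in isolation: `function_call_output` items are collected by `call_id` first, each tool message is then emitted immediately after the assistant message carrying the matching `function_call`, and any leftover outputs raise. A simplified sketch of that pairing, using plain dicts rather than the real Pydantic types, so it is illustrative only:

```python
# Simplified sketch of the pairing logic above: plain dicts stand in for the
# OpenAIResponse* input items and OpenAI*MessageParam messages.
def pair_tool_outputs(items: list[dict]) -> list[dict]:
    # First pass: collect every function_call_output keyed by call_id.
    results = {i["call_id"]: i for i in items if i["type"] == "function_call_output"}

    messages: list[dict] = []
    for item in items:
        if item["type"] == "function_call_output":
            continue  # emitted right after its matching function_call instead
        if item["type"] == "function_call":
            messages.append({"role": "assistant", "tool_calls": [item]})
            if item["call_id"] in results:
                out = results.pop(item["call_id"])
                messages.append(
                    {"role": "tool", "tool_call_id": item["call_id"], "content": out["output"]}
                )
        else:
            messages.append({"role": item["role"], "content": item["content"]})
    if results:
        raise ValueError(f"function_call_output without a matching function_call: {list(results)}")
    return messages
```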

View file

@ -11,7 +11,7 @@ from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry import tracing from llama_stack.providers.utils.telemetry import tracing
log = get_logger(name=__name__, category="agents") log = get_logger(name=__name__, category="agents::meta_reference")
class SafetyException(Exception): # noqa: N818 class SafetyException(Exception): # noqa: N818

View file

@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import hashlib
import itertools import itertools
import json import json
import time import time
@ -136,28 +137,45 @@ class ReferenceBatchesImpl(Batches):
endpoint: str, endpoint: str,
completion_window: Literal["24h"], completion_window: Literal["24h"],
metadata: dict[str, str] | None = None, metadata: dict[str, str] | None = None,
idempotency_key: str | None = None,
) -> BatchObject: ) -> BatchObject:
""" """
Create a new batch for processing multiple API requests. Create a new batch for processing multiple API requests.
Error handling by levels - This implementation provides optional idempotency: when an idempotency key
0. Input param handling, results in 40x errors before processing, e.g. (idempotency_key) is provided, a deterministic ID is generated based on the input
- Wrong completion_window parameters. If a batch with the same parameters already exists, it will be
- Invalid metadata types returned instead of creating a duplicate. Without an idempotency key,
- Unknown endpoint each request creates a new batch with a unique ID.
-> no batch created
1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g. Args:
- input_file_id missing input_file_id: The ID of an uploaded file containing requests for the batch.
- invalid json in file endpoint: The endpoint to be used for all requests in the batch.
- missing custom_id, method, url, body completion_window: The time window within which the batch should be processed.
- invalid model metadata: Optional metadata for the batch.
- streaming idempotency_key: Optional idempotency key for enabling idempotent behavior.
-> batch created, validation sends to failed status
2. Processing errors, result in error_file_id entries, e.g. Returns:
- Any error returned from inference endpoint The created or existing batch object.
-> batch created, goes to completed status
""" """
# Error handling by levels -
# 0. Input param handling, results in 40x errors before processing, e.g.
# - Wrong completion_window
# - Invalid metadata types
# - Unknown endpoint
# -> no batch created
# 1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
# - input_file_id missing
# - invalid json in file
# - missing custom_id, method, url, body
# - invalid model
# - streaming
# -> batch created, validation sends to failed status
# 2. Processing errors, result in error_file_id entries, e.g.
# - Any error returned from inference endpoint
# -> batch created, goes to completed status
# TODO: set expiration time for garbage collection # TODO: set expiration time for garbage collection
if endpoint not in ["/v1/chat/completions"]: if endpoint not in ["/v1/chat/completions"]:
@ -171,6 +189,35 @@ class ReferenceBatchesImpl(Batches):
) )
batch_id = f"batch_{uuid.uuid4().hex[:16]}" batch_id = f"batch_{uuid.uuid4().hex[:16]}"
# For idempotent requests, use the idempotency key for the batch ID
# This ensures the same key always maps to the same batch ID,
# allowing us to detect parameter conflicts
if idempotency_key is not None:
hash_input = idempotency_key.encode("utf-8")
hash_digest = hashlib.sha256(hash_input).hexdigest()[:24]
batch_id = f"batch_{hash_digest}"
try:
existing_batch = await self.retrieve_batch(batch_id)
if (
existing_batch.input_file_id != input_file_id
or existing_batch.endpoint != endpoint
or existing_batch.completion_window != completion_window
or existing_batch.metadata != metadata
):
raise ConflictError(
f"Idempotency key '{idempotency_key}' was previously used with different parameters. "
"Either use a new idempotency key or ensure all parameters match the original request."
)
logger.info(f"Returning existing batch with ID: {batch_id}")
return existing_batch
except ResourceNotFoundError:
# Batch doesn't exist, continue with creation
pass
current_time = int(time.time()) current_time = int(time.time())
batch = BatchObject( batch = BatchObject(
@ -185,6 +232,7 @@ class ReferenceBatchesImpl(Batches):
) )
await self.kvstore.set(f"batch:{batch_id}", batch.to_json()) await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
logger.info(f"Created new batch with ID: {batch_id}")
if self.process_batches: if self.process_batches:
task = asyncio.create_task(self._process_batch(batch_id)) task = asyncio.create_task(self._process_batch(batch_id))
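The idempotency behaviour above reduces to deriving the batch ID from the key, so a replay with the same key resolves to the same stored batch (or a `ConflictError` if the other parameters differ). A small sketch of the derivation, mirroring the hashing in `create_batch`:

```python
import hashlib

def batch_id_for(idempotency_key: str) -> str:
    # Same derivation as create_batch above: sha256 of the key, truncated to 24 hex chars.
    return f"batch_{hashlib.sha256(idempotency_key.encode('utf-8')).hexdigest()[:24]}"

# Replaying a request with an identical key maps to the same batch ID, which is
# how the provider distinguishes "already created" from "conflicting parameters".
assert batch_id_for("nightly-eval-2025-08-25") == batch_id_for("nightly-eval-2025-08-25")
```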

View file

@ -11,6 +11,7 @@ from typing import Annotated
from fastapi import File, Form, Response, UploadFile from fastapi import File, Form, Response, UploadFile
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import ( from llama_stack.apis.files import (
Files, Files,
@ -20,12 +21,15 @@ from llama_stack.apis.files import (
OpenAIFilePurpose, OpenAIFilePurpose,
) )
from llama_stack.core.datatypes import AccessRule from llama_stack.core.datatypes import AccessRule
from llama_stack.log import get_logger
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
from .config import LocalfsFilesImplConfig from .config import LocalfsFilesImplConfig
logger = get_logger(name=__name__, category="files")
class LocalfsFilesImpl(Files): class LocalfsFilesImpl(Files):
def __init__(self, config: LocalfsFilesImplConfig, policy: list[AccessRule]) -> None: def __init__(self, config: LocalfsFilesImplConfig, policy: list[AccessRule]) -> None:
@ -65,6 +69,18 @@ class LocalfsFilesImpl(Files):
"""Get the filesystem path for a file ID.""" """Get the filesystem path for a file ID."""
return Path(self.config.storage_dir) / file_id return Path(self.config.storage_dir) / file_id
async def _lookup_file_id(self, file_id: str) -> tuple[OpenAIFileObject, Path]:
"""Look up a OpenAIFileObject and filesystem path from its ID."""
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
file_path = Path(row.pop("file_path"))
return OpenAIFileObject(**row), file_path
# OpenAI Files API Implementation # OpenAI Files API Implementation
async def openai_upload_file( async def openai_upload_file(
self, self,
@ -157,37 +173,19 @@ class LocalfsFilesImpl(Files):
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject: async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
"""Returns information about a specific file.""" """Returns information about a specific file."""
if not self.sql_store: file_obj, _ = await self._lookup_file_id(file_id)
raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id}) return file_obj
if not row:
raise ValueError(f"File with id {file_id} not found")
return OpenAIFileObject(
id=row["id"],
filename=row["filename"],
purpose=OpenAIFilePurpose(row["purpose"]),
bytes=row["bytes"],
created_at=row["created_at"],
expires_at=row["expires_at"],
)
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse: async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
"""Delete a file.""" """Delete a file."""
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
if not row:
raise ValueError(f"File with id {file_id} not found")
# Delete physical file # Delete physical file
file_path = Path(row["file_path"]) _, file_path = await self._lookup_file_id(file_id)
if file_path.exists(): if file_path.exists():
file_path.unlink() file_path.unlink()
# Delete metadata from database # Delete metadata from database
assert self.sql_store is not None, "Files provider not initialized"
await self.sql_store.delete("openai_files", where={"id": file_id}) await self.sql_store.delete("openai_files", where={"id": file_id})
return OpenAIFileDeleteResponse( return OpenAIFileDeleteResponse(
@ -197,25 +195,17 @@ class LocalfsFilesImpl(Files):
async def openai_retrieve_file_content(self, file_id: str) -> Response: async def openai_retrieve_file_content(self, file_id: str) -> Response:
"""Returns the contents of the specified file.""" """Returns the contents of the specified file."""
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
# Get file metadata
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
if not row:
raise ValueError(f"File with id {file_id} not found")
# Read file content # Read file content
file_path = Path(row["file_path"]) file_obj, file_path = await self._lookup_file_id(file_id)
if not file_path.exists():
raise ValueError(f"File content not found on disk: {file_path}")
with open(file_path, "rb") as f: if not file_path.exists():
content = f.read() logger.warning(f"File '{file_id}'s underlying '{file_path}' is missing, deleting metadata.")
await self.openai_delete_file(file_id)
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
# Return as binary response with appropriate content type # Return as binary response with appropriate content type
return Response( return Response(
content=content, content=file_path.read_bytes(),
media_type="application/octet-stream", media_type="application/octet-stream",
headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'}, headers={"Content-Disposition": f'attachment; filename="{file_obj.filename}"'},
) )

View file

@ -9,7 +9,6 @@ from collections.abc import AsyncGenerator
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
CompletionResponse, CompletionResponse,
InferenceProvider, InferenceProvider,
InterleavedContent,
LogProbConfig, LogProbConfig,
Message, Message,
ResponseFormat, ResponseFormat,
@ -100,25 +99,3 @@ class SentenceTransformersInferenceImpl(
tool_config: ToolConfig | None = None, tool_config: ToolConfig | None = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
raise ValueError("Sentence transformers don't support chat completion") raise ValueError("Sentence transformers don't support chat completion")
async def batch_completion(
self,
model_id: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch completion is not supported for Sentence Transformers")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import datetime
import threading import threading
from typing import Any from typing import Any
@ -145,11 +146,41 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
metric_name: str, metric_name: str,
start_time: int, start_time: int,
end_time: int | None = None, end_time: int | None = None,
granularity: str | None = "1d", granularity: str | None = None,
query_type: MetricQueryType = MetricQueryType.RANGE, query_type: MetricQueryType = MetricQueryType.RANGE,
label_matchers: list[MetricLabelMatcher] | None = None, label_matchers: list[MetricLabelMatcher] | None = None,
) -> QueryMetricsResponse: ) -> QueryMetricsResponse:
raise NotImplementedError("Querying metrics is not implemented") """Query metrics from the telemetry store.
Args:
metric_name: The name of the metric to query (e.g., "prompt_tokens")
start_time: Start time as Unix timestamp
end_time: End time as Unix timestamp (defaults to now if None)
granularity: Time granularity for aggregation
query_type: Type of query (RANGE or INSTANT)
label_matchers: Label filters to apply
Returns:
QueryMetricsResponse with metric time series data
"""
# Convert timestamps to datetime objects
start_dt = datetime.datetime.fromtimestamp(start_time, datetime.UTC)
end_dt = datetime.datetime.fromtimestamp(end_time, datetime.UTC) if end_time else None
# Use SQLite trace store if available
if hasattr(self, "trace_store") and self.trace_store:
return await self.trace_store.query_metrics(
metric_name=metric_name,
start_time=start_dt,
end_time=end_dt,
granularity=granularity,
query_type=query_type,
label_matchers=label_matchers,
)
else:
raise ValueError(
f"In order to query_metrics, you must have {TelemetrySink.SQLITE} set in your telemetry sinks"
)
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None: def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
with self._lock: with self._lock:
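In short, the adapter now only normalizes the Unix timestamps and delegates to the SQLite trace store, so `sqlite` must be present in the telemetry sinks for `query_metrics` to work. A hedged sketch of the timestamp handling, plus a call shaped like the signature above (the awaitable client surface is an assumption):

```python
import datetime
import time

# Callers pass Unix timestamps; the adapter converts them to timezone-aware
# datetimes (Python 3.11+ datetime.UTC) before querying the SQLite trace store.
end_time = int(time.time())
start_time = end_time - 3600  # last hour
start_dt = datetime.datetime.fromtimestamp(start_time, datetime.UTC)
end_dt = datetime.datetime.fromtimestamp(end_time, datetime.UTC)
print(start_dt.isoformat(), end_dt.isoformat())

# Hypothetical call shaped like query_metrics above; requires a run config
# with `sqlite` among the telemetry sinks.
# response = await telemetry.query_metrics(
#     metric_name="prompt_tokens",
#     start_time=start_time,
#     end_time=end_time,
# )
```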

View file

@ -5,9 +5,11 @@
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.providers.datatypes import ( from llama_stack.providers.datatypes import (
AdapterSpec,
Api, Api,
InlineProviderSpec, InlineProviderSpec,
ProviderSpec, ProviderSpec,
remote_provider_spec,
) )
from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
@ -23,4 +25,14 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig", config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
description="Local filesystem-based file storage provider for managing files and documents locally.", description="Local filesystem-based file storage provider for managing files and documents locally.",
), ),
remote_provider_spec(
api=Api.files,
adapter=AdapterSpec(
adapter_type="s3",
pip_packages=["boto3"] + sql_store_pip_packages,
module="llama_stack.providers.remote.files.s3",
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
),
),
] ]

View file

@ -5,34 +5,74 @@
# the root directory of this source tree. # the root directory of this source tree.
from typing import cast
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
# We provide two versions of these providers so that distributions can package the appropriate version of torch.
# The CPU version is used for distributions that don't have GPU support -- they result in smaller container images.
torchtune_def = dict(
api=Api.post_training,
pip_packages=["torchtune==0.5.0", "torchao==0.8.0", "numpy"],
module="llama_stack.providers.inline.post_training.torchtune",
config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
],
description="TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.",
)
huggingface_def = dict(
api=Api.post_training,
pip_packages=["trl", "transformers", "peft", "datasets"],
module="llama_stack.providers.inline.post_training.huggingface",
config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
],
description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.",
)
def available_providers() -> list[ProviderSpec]: def available_providers() -> list[ProviderSpec]:
return [ return [
InlineProviderSpec( InlineProviderSpec(
api=Api.post_training, **{
provider_type="inline::torchtune", **torchtune_def,
pip_packages=["torch", "torchtune==0.5.0", "torchao==0.8.0", "numpy"], "provider_type": "inline::torchtune-cpu",
module="llama_stack.providers.inline.post_training.torchtune", "pip_packages": (
config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig", cast(list[str], torchtune_def["pip_packages"])
api_dependencies=[ + ["torch torchtune==0.5.0 torchao==0.8.0 --index-url https://download.pytorch.org/whl/cpu"]
Api.datasetio, ),
Api.datasets, },
],
description="TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.",
), ),
InlineProviderSpec( InlineProviderSpec(
api=Api.post_training, **{
provider_type="inline::huggingface", **huggingface_def,
pip_packages=["torch", "trl", "transformers", "peft", "datasets"], "provider_type": "inline::huggingface-cpu",
module="llama_stack.providers.inline.post_training.huggingface", "pip_packages": (
config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig", cast(list[str], huggingface_def["pip_packages"])
api_dependencies=[ + ["torch --index-url https://download.pytorch.org/whl/cpu"]
Api.datasetio, ),
Api.datasets, },
], ),
description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.", InlineProviderSpec(
**{
**torchtune_def,
"provider_type": "inline::torchtune-gpu",
"pip_packages": (
cast(list[str], torchtune_def["pip_packages"]) + ["torch torchtune==0.5.0 torchao==0.8.0"]
),
},
),
InlineProviderSpec(
**{
**huggingface_def,
"provider_type": "inline::huggingface-gpu",
"pip_packages": (cast(list[str], huggingface_def["pip_packages"]) + ["torch"]),
},
), ),
remote_provider_spec( remote_provider_spec(
api=Api.post_training, api=Api.post_training,
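The practical effect of the split shows up in the build specs earlier in this diff: CPU-only distributions now select the `-cpu` variants (which pin the PyTorch CPU wheel index), while GPU distributions select the `-gpu` variants. A condensed sketch of the two choices:

```yaml
# CPU-only images (starter, ci-tests) -- torch wheels come from the CPU index:
post_training:
  - provider_type: inline::huggingface-cpu

# GPU-enabled images (starter-gpu) use the default index instead:
# post_training:
#   - provider_type: inline::torchtune-gpu
```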

View file

@ -0,0 +1,237 @@
# S3 Files Provider
A remote S3-based implementation of the Llama Stack Files API that provides scalable cloud file storage with metadata persistence.
## Features
- **AWS S3 Storage**: Store files in AWS S3 buckets for scalable, durable storage
- **Metadata Management**: Uses SQL database for efficient file metadata queries
- **OpenAI API Compatibility**: Full compatibility with OpenAI Files API endpoints
- **Flexible Authentication**: Support for IAM roles and access keys
- **Custom S3 Endpoints**: Support for MinIO and other S3-compatible services
## Configuration
### Basic Configuration
```yaml
api: files
provider_type: remote::s3
config:
bucket_name: my-llama-stack-files
region: us-east-1
metadata_store:
type: sqlite
db_path: ./s3_files_metadata.db
```
### Advanced Configuration
```yaml
api: files
provider_type: remote::s3
config:
bucket_name: my-llama-stack-files
region: us-east-1
aws_access_key_id: YOUR_ACCESS_KEY
aws_secret_access_key: YOUR_SECRET_KEY
endpoint_url: https://s3.amazonaws.com # Optional for custom endpoints
metadata_store:
type: sqlite
db_path: ./s3_files_metadata.db
```
### Environment Variables
The configuration supports environment variable substitution:
```yaml
config:
bucket_name: "${env.S3_BUCKET_NAME}"
region: "${env.AWS_REGION:=us-east-1}"
aws_access_key_id: "${env.AWS_ACCESS_KEY_ID:=}"
aws_secret_access_key: "${env.AWS_SECRET_ACCESS_KEY:=}"
endpoint_url: "${env.S3_ENDPOINT_URL:=}"
```
Note: `S3_BUCKET_NAME` has no default value since S3 bucket names must be globally unique.
## Authentication
### IAM Roles (Recommended)
For production deployments, use IAM roles:
```yaml
config:
bucket_name: my-bucket
region: us-east-1
# No credentials needed - will use IAM role
```
### Access Keys
For development or specific use cases:
```yaml
config:
bucket_name: my-bucket
region: us-east-1
aws_access_key_id: AKIAIOSFODNN7EXAMPLE
aws_secret_access_key: wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY
```
## S3 Bucket Setup
### Required Permissions
The S3 provider requires the following permissions:
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": [
"s3:GetObject",
"s3:PutObject",
"s3:DeleteObject",
"s3:ListBucket"
],
"Resource": [
"arn:aws:s3:::your-bucket-name",
"arn:aws:s3:::your-bucket-name/*"
]
}
]
}
```
### Automatic Bucket Creation
By default, the S3 provider expects the bucket to already exist. If you want the provider to automatically create the bucket when it doesn't exist, set `auto_create_bucket: true` in your configuration:
```yaml
config:
bucket_name: my-bucket
auto_create_bucket: true # Will create bucket if it doesn't exist
region: us-east-1
```
**Note**: When `auto_create_bucket` is enabled, the provider will need additional permissions:
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": [
"s3:GetObject",
"s3:PutObject",
"s3:DeleteObject",
"s3:ListBucket",
"s3:CreateBucket"
],
"Resource": [
"arn:aws:s3:::your-bucket-name",
"arn:aws:s3:::your-bucket-name/*"
]
}
]
}
```
### Bucket Policy (Optional)
For additional security, you can add a bucket policy:
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Sid": "LlamaStackAccess",
"Effect": "Allow",
"Principal": {
"AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
},
"Action": [
"s3:GetObject",
"s3:PutObject",
"s3:DeleteObject"
],
"Resource": "arn:aws:s3:::your-bucket-name/*"
},
{
"Sid": "LlamaStackBucketAccess",
"Effect": "Allow",
"Principal": {
"AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
},
"Action": [
"s3:ListBucket"
],
"Resource": "arn:aws:s3:::your-bucket-name"
}
]
}
```
## Features
### Metadata Persistence
File metadata is stored in a SQL database for fast queries and OpenAI API compatibility. The metadata includes:
- File ID
- Original filename
- Purpose (assistants, batch, etc.)
- File size in bytes
- Created and expiration timestamps
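Because the provider exposes the OpenAI Files API surface, any OpenAI-compatible client pointed at the Llama Stack server can exercise it. A hedged sketch (the base URL, port, and `purpose` value are illustrative assumptions):

```python
from openai import OpenAI

# Assumes a local Llama Stack server with the remote::s3 files provider
# configured; URL, port, and purpose are illustrative, not prescriptive.
client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

uploaded = client.files.create(file=open("report.pdf", "rb"), purpose="assistants")
print(uploaded.id, uploaded.bytes)

# Metadata is served from the SQL store; content is fetched from S3.
meta = client.files.retrieve(uploaded.id)
content = client.files.content(uploaded.id)
```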
### TTL and Cleanup
Files currently have a fixed long expiration time (100 years).
## Development and Testing
### Using MinIO
For self-hosted S3-compatible storage:
```yaml
config:
bucket_name: test-bucket
region: us-east-1
endpoint_url: http://localhost:9000
aws_access_key_id: minioadmin
aws_secret_access_key: minioadmin
```
## Monitoring and Logging
The provider logs important operations and errors. For production deployments, consider:
- CloudWatch monitoring for S3 operations
- Custom metrics for file upload/download rates
- Error rate monitoring
- Performance metrics tracking
## Error Handling
The provider handles various error scenarios:
- S3 connectivity issues
- Bucket access permissions
- File not found errors
- Metadata consistency checks
## Known Limitations
- Fixed long TTL (100 years) instead of configurable expiration
- No server-side encryption enabled by default
- No support for AWS session tokens
- No S3 key prefix organization support
- No multipart upload support (all files uploaded as single objects)

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.core.datatypes import Api
from .config import S3FilesImplConfig
async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any]):
from .files import S3FilesImpl
# TODO: authorization policies and user separation
impl = S3FilesImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
class S3FilesImplConfig(BaseModel):
"""Configuration for S3-based files provider."""
bucket_name: str = Field(description="S3 bucket name to store files")
region: str = Field(default="us-east-1", description="AWS region where the bucket is located")
aws_access_key_id: str | None = Field(default=None, description="AWS access key ID (optional if using IAM roles)")
aws_secret_access_key: str | None = Field(
default=None, description="AWS secret access key (optional if using IAM roles)"
)
endpoint_url: str | None = Field(default=None, description="Custom S3 endpoint URL (for MinIO, LocalStack, etc.)")
auto_create_bucket: bool = Field(
default=False, description="Automatically create the S3 bucket if it doesn't exist"
)
metadata_store: SqlStoreConfig = Field(description="SQL store configuration for file metadata")
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"bucket_name": "${env.S3_BUCKET_NAME}", # no default, buckets must be globally unique
"region": "${env.AWS_REGION:=us-east-1}",
"aws_access_key_id": "${env.AWS_ACCESS_KEY_ID:=}",
"aws_secret_access_key": "${env.AWS_SECRET_ACCESS_KEY:=}",
"endpoint_url": "${env.S3_ENDPOINT_URL:=}",
"auto_create_bucket": "${env.S3_AUTO_CREATE_BUCKET:=false}",
"metadata_store": SqliteSqlStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="s3_files_metadata.db",
),
}

View file

@ -0,0 +1,272 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
import uuid
from typing import Annotated
import boto3
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
from fastapi import File, Form, Response, UploadFile
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import (
Files,
ListOpenAIFileResponse,
OpenAIFileDeleteResponse,
OpenAIFileObject,
OpenAIFilePurpose,
)
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl
from .config import S3FilesImplConfig
# TODO: provider data for S3 credentials
def _create_s3_client(config: S3FilesImplConfig) -> boto3.client:
try:
s3_config = {
"region_name": config.region,
}
# endpoint URL if specified (for MinIO, LocalStack, etc.)
if config.endpoint_url:
s3_config["endpoint_url"] = config.endpoint_url
if config.aws_access_key_id and config.aws_secret_access_key:
s3_config.update(
{
"aws_access_key_id": config.aws_access_key_id,
"aws_secret_access_key": config.aws_secret_access_key,
}
)
return boto3.client("s3", **s3_config)
except (BotoCoreError, NoCredentialsError) as e:
raise RuntimeError(f"Failed to initialize S3 client: {e}") from e
async def _create_bucket_if_not_exists(client: boto3.client, config: S3FilesImplConfig) -> None:
try:
client.head_bucket(Bucket=config.bucket_name)
except ClientError as e:
error_code = e.response["Error"]["Code"]
if error_code == "404":
if not config.auto_create_bucket:
raise RuntimeError(
f"S3 bucket '{config.bucket_name}' does not exist. "
f"Either create the bucket manually or set 'auto_create_bucket: true' in your configuration."
) from e
try:
# For us-east-1, we can't specify LocationConstraint
if config.region == "us-east-1":
client.create_bucket(Bucket=config.bucket_name)
else:
client.create_bucket(
Bucket=config.bucket_name,
CreateBucketConfiguration={"LocationConstraint": config.region},
)
except ClientError as create_error:
raise RuntimeError(
f"Failed to create S3 bucket '{config.bucket_name}': {create_error}"
) from create_error
elif error_code == "403":
raise RuntimeError(f"Access denied to S3 bucket '{config.bucket_name}'") from e
else:
raise RuntimeError(f"Failed to access S3 bucket '{config.bucket_name}': {e}") from e
class S3FilesImpl(Files):
"""S3-based implementation of the Files API."""
# TODO: implement expiration, for now a silly offset
_SILLY_EXPIRATION_OFFSET = 100 * 365 * 24 * 60 * 60
def __init__(self, config: S3FilesImplConfig) -> None:
self._config = config
self._client: boto3.client | None = None
self._sql_store: SqlStore | None = None
async def initialize(self) -> None:
self._client = _create_s3_client(self._config)
await _create_bucket_if_not_exists(self._client, self._config)
self._sql_store = sqlstore_impl(self._config.metadata_store)
await self._sql_store.create_table(
"openai_files",
{
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
"filename": ColumnType.STRING,
"purpose": ColumnType.STRING,
"bytes": ColumnType.INTEGER,
"created_at": ColumnType.INTEGER,
"expires_at": ColumnType.INTEGER,
# TODO: add s3_etag field for integrity checking
},
)
async def shutdown(self) -> None:
pass
@property
def client(self) -> boto3.client:
assert self._client is not None, "Provider not initialized"
return self._client
@property
def sql_store(self) -> SqlStore:
assert self._sql_store is not None, "Provider not initialized"
return self._sql_store
async def openai_upload_file(
self,
file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()],
) -> OpenAIFileObject:
file_id = f"file-{uuid.uuid4().hex}"
filename = getattr(file, "filename", None) or "uploaded_file"
created_at = int(time.time())
expires_at = created_at + self._SILLY_EXPIRATION_OFFSET
content = await file.read()
file_size = len(content)
await self.sql_store.insert(
"openai_files",
{
"id": file_id,
"filename": filename,
"purpose": purpose.value,
"bytes": file_size,
"created_at": created_at,
"expires_at": expires_at,
},
)
try:
self.client.put_object(
Bucket=self._config.bucket_name,
Key=file_id,
Body=content,
# TODO: enable server-side encryption
)
except ClientError as e:
await self.sql_store.delete("openai_files", where={"id": file_id})
raise RuntimeError(f"Failed to upload file to S3: {e}") from e
return OpenAIFileObject(
id=file_id,
filename=filename,
purpose=purpose,
bytes=file_size,
created_at=created_at,
expires_at=expires_at,
)
async def openai_list_files(
self,
after: str | None = None,
limit: int | None = 10000,
order: Order | None = Order.desc,
purpose: OpenAIFilePurpose | None = None,
) -> ListOpenAIFileResponse:
# this is purely defensive. it should not happen because the router also defaults to Order.desc.
if not order:
order = Order.desc
where_conditions = {}
if purpose:
where_conditions["purpose"] = purpose.value
paginated_result = await self.sql_store.fetch_all(
table="openai_files",
where=where_conditions if where_conditions else None,
order_by=[("created_at", order.value)],
cursor=("id", after) if after else None,
limit=limit,
)
files = [
OpenAIFileObject(
id=row["id"],
filename=row["filename"],
purpose=OpenAIFilePurpose(row["purpose"]),
bytes=row["bytes"],
created_at=row["created_at"],
expires_at=row["expires_at"],
)
for row in paginated_result.data
]
return ListOpenAIFileResponse(
data=files,
has_more=paginated_result.has_more,
# empty string or None? spec says str, ref impl returns str | None, we go with spec
first_id=files[0].id if files else "",
last_id=files[-1].id if files else "",
)
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "files.list()")
return OpenAIFileObject(
id=row["id"],
filename=row["filename"],
purpose=OpenAIFilePurpose(row["purpose"]),
bytes=row["bytes"],
created_at=row["created_at"],
expires_at=row["expires_at"],
)
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "files.list()")
try:
self.client.delete_object(
Bucket=self._config.bucket_name,
Key=row["id"],
)
except ClientError as e:
if e.response["Error"]["Code"] != "NoSuchKey":
raise RuntimeError(f"Failed to delete file from S3: {e}") from e
await self.sql_store.delete("openai_files", where={"id": file_id})
return OpenAIFileDeleteResponse(id=file_id, deleted=True)
async def openai_retrieve_file_content(self, file_id: str) -> Response:
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "files.list()")
try:
response = self.client.get_object(
Bucket=self._config.bucket_name,
Key=row["id"],
)
# TODO: can we stream this instead of loading it into memory
content = response["Body"].read()
except ClientError as e:
if e.response["Error"]["Code"] == "NoSuchKey":
await self.sql_store.delete("openai_files", where={"id": file_id})
raise ResourceNotFoundError(file_id, "File", "files.list()") from e
raise RuntimeError(f"Failed to download file from S3: {e}") from e
return Response(
content=content,
media_type="application/octet-stream",
headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'},
)
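
For reference, a minimal usage sketch, assuming a Llama Stack deployment configured with this S3 files provider and exposing the OpenAI-compatible Files routes; the base URL, API key, and file name below are placeholders:

```python
# Minimal sketch using the official `openai` client; base_url/api_key are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="not-needed")

# Upload: the provider writes the bytes to S3 and records metadata in the SQL store.
uploaded = client.files.create(file=open("notes.txt", "rb"), purpose="assistants")
print(uploaded.id, uploaded.bytes)

# List, inspect metadata, download content, then delete.
print([f.id for f in client.files.list().data])
print(client.files.retrieve(uploaded.id).filename)
print(client.files.content(uploaded.id).read())
print(client.files.delete(uploaded.id).deleted)
```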

View file

@@ -65,7 +65,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import FireworksImplConfig from .config import FireworksImplConfig
from .models import MODEL_ENTRIES from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference") logger = get_logger(name=__name__, category="inference::fireworks")
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):

View file

@@ -10,7 +10,7 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .models import MODEL_ENTRIES from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference") logger = get_logger(name=__name__, category="inference::llama_openai_compat")
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):

View file

@@ -41,6 +41,11 @@ client.initialize()
### Create Completion ### Create Completion
> Note on Completion API
>
> The hosted NVIDIA Llama NIMs (e.g., `meta-llama/Llama-3.1-8B-Instruct`) with `NVIDIA_BASE_URL="https://integrate.api.nvidia.com"` do not support the `completion` method, while a locally deployed NIM does.
```python ```python
response = client.inference.completion( response = client.inference.completion(
model_id="meta-llama/Llama-3.1-8B-Instruct", model_id="meta-llama/Llama-3.1-8B-Instruct",
@@ -76,6 +81,73 @@ response = client.inference.chat_completion(
print(f"Response: {response.completion_message.content}") print(f"Response: {response.completion_message.content}")
``` ```
### Tool Calling Example
```python
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
tool_definition = ToolDefinition(
tool_name="get_weather",
description="Get current weather information for a location",
parameters={
"location": ToolParamDefinition(
param_type="string",
description="The city and state, e.g. San Francisco, CA",
required=True,
),
"unit": ToolParamDefinition(
param_type="string",
description="Temperature unit (celsius or fahrenheit)",
required=False,
default="celsius",
),
},
)
tool_response = client.inference.chat_completion(
model_id="meta-llama/Llama-3.1-8B-Instruct",
messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
tools=[tool_definition],
)
print(f"Tool Response: {tool_response.completion_message.content}")
if tool_response.completion_message.tool_calls:
for tool_call in tool_response.completion_message.tool_calls:
print(f"Tool Called: {tool_call.tool_name}")
print(f"Arguments: {tool_call.arguments}")
```
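
When the model does request a call, the arguments come back as structured data; the sketch below shows one way to dispatch them to a local implementation (`get_weather` here is illustrative, and `arguments` is normalized in case it arrives as a JSON string rather than a dict):

```python
import json


def get_weather(location: str, unit: str = "celsius") -> str:
    # Placeholder implementation; a real tool would query a weather service.
    return f"22 degrees {unit} and sunny in {location}"


local_tools = {"get_weather": get_weather}

for tool_call in tool_response.completion_message.tool_calls or []:
    # Depending on the version, arguments may already be a dict or a JSON-encoded string.
    args = tool_call.arguments if isinstance(tool_call.arguments, dict) else json.loads(tool_call.arguments)
    handler = local_tools.get(tool_call.tool_name)
    if handler:
        print(f"{tool_call.tool_name} -> {handler(**args)}")
```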
### Structured Output Example
```python
from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType
person_schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
"occupation": {"type": "string"},
},
"required": ["name", "age", "occupation"],
}
response_format = JsonSchemaResponseFormat(
type=ResponseFormatType.json_schema, json_schema=person_schema
)
structured_response = client.inference.chat_completion(
model_id="meta-llama/Llama-3.1-8B-Instruct",
messages=[
{
"role": "user",
"content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. ",
}
],
response_format=response_format,
)
print(f"Structured Response: {structured_response.completion_message.content}")
```
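
Because the output is constrained to the schema above, the returned content can be parsed directly; a minimal sketch:

```python
import json

profile = json.loads(structured_response.completion_message.content)
print(profile["name"], profile["age"], profile["occupation"])
```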
### Create Embeddings ### Create Embeddings
> Note on OpenAI embeddings compatibility > Note on OpenAI embeddings compatibility
> >

View file

@@ -7,7 +7,7 @@
import warnings import warnings
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from openai import NOT_GIVEN, APIConnectionError, BadRequestError from openai import NOT_GIVEN, APIConnectionError
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
InterleavedContent, InterleavedContent,
@@ -57,7 +57,7 @@ from .openai_utils import (
) )
from .utils import _is_nvidia_hosted from .utils import _is_nvidia_hosted
logger = get_logger(name=__name__, category="inference") logger = get_logger(name=__name__, category="inference::nvidia")
class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper): class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
@@ -197,15 +197,11 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
} }
extra_body["input_type"] = task_type_options[task_type] extra_body["input_type"] = task_type_options[task_type]
try:
    response = await self.client.embeddings.create(
        model=provider_model_id,
        input=input,
        extra_body=extra_body,
    )
except BadRequestError as e:
    raise ValueError(f"Failed to get embeddings: {e}") from e
response = await self.client.embeddings.create(
    model=provider_model_id,
    input=input,
    extra_body=extra_body,
)
# #
# OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...) # OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...)
# -> # ->

View file

@@ -10,7 +10,7 @@ from llama_stack.log import get_logger
from . import NVIDIAConfig from . import NVIDIAConfig
logger = get_logger(name=__name__, category="inference") logger = get_logger(name=__name__, category="inference::nvidia")
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool: def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:

View file

@@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .models import MODEL_ENTRIES from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference") logger = get_logger(name=__name__, category="inference::ollama")
class OllamaInferenceAdapter( class OllamaInferenceAdapter(
@@ -619,28 +619,6 @@ class OllamaInferenceAdapter(
response.id = id response.id = id
return response return response
async def batch_completion(
self,
model_id: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch completion is not supported for Ollama")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch chat completion is not supported for Ollama")
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]: async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
async def _convert_content(content) -> dict: async def _convert_content(content) -> dict:

View file

@@ -11,7 +11,7 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import OpenAIConfig from .config import OpenAIConfig
from .models import MODEL_ENTRIES from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference") logger = get_logger(name=__name__, category="inference::openai")
# #

View file

@@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
log = get_logger(name=__name__, category="inference") log = get_logger(name=__name__, category="inference::tgi")
def build_hf_repo_model_entries(): def build_hf_repo_model_entries():

View file

@@ -61,7 +61,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import TogetherImplConfig from .config import TogetherImplConfig
from .models import MODEL_ENTRIES from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference") logger = get_logger(name=__name__, category="inference::together")
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):

View file

@@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import VLLMInferenceAdapterConfig from .config import VLLMInferenceAdapterConfig
log = get_logger(name=__name__, category="inference") log = get_logger(name=__name__, category="inference::vllm")
def build_hf_repo_model_entries(): def build_hf_repo_model_entries():
@@ -711,25 +711,3 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
user=user, user=user,
) )
return await self.client.chat.completions.create(**params) # type: ignore return await self.client.chat.completions.create(**params) # type: ignore
async def batch_completion(
self,
model_id: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch completion is not supported for Ollama")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch chat completion is not supported for Ollama")

View file

@@ -15,7 +15,7 @@ from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefa
from .config import NvidiaPostTrainingConfig from .config import NvidiaPostTrainingConfig
logger = get_logger(name=__name__, category="integration") logger = get_logger(name=__name__, category="post_training::nvidia")
def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None: def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:

View file

@@ -21,7 +21,7 @@ from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from .config import BedrockSafetyConfig from .config import BedrockSafetyConfig
logger = get_logger(name=__name__, category="safety") logger = get_logger(name=__name__, category="safety::bedrock")
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):

View file

@@ -9,7 +9,7 @@ from typing import Any
import requests import requests
from llama_stack.apis.inference import Message from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel from llama_stack.apis.safety import ModerationObject, RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.datatypes import ShieldsProtocolPrivate
@@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_
from .config import NVIDIASafetyConfig from .config import NVIDIASafetyConfig
logger = get_logger(name=__name__, category="safety") logger = get_logger(name=__name__, category="safety::nvidia")
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
@@ -67,6 +67,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
self.shield = NeMoGuardrails(self.config, shield.shield_id) self.shield = NeMoGuardrails(self.config, shield.shield_id)
return await self.shield.run(messages) return await self.shield.run(messages)
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation")
class NeMoGuardrails: class NeMoGuardrails:
""" """

Some files were not shown because too many files have changed in this diff.