Merge branch 'main' into HuggingfacePostTrainingConfig-branch

Sarthak Deshpande 2025-08-25 11:59:15 +05:30 committed by GitHub
commit d0d737680f
193 changed files with 7108 additions and 881 deletions

View file

@@ -9,6 +9,7 @@ updates:
day: "saturday"
commit-message:
prefix: chore(github-deps)
- package-ecosystem: "uv"
directory: "/"
schedule:
@@ -19,3 +20,14 @@ updates:
- python
commit-message:
prefix: chore(python-deps)
- package-ecosystem: npm
directory: "/llama_stack/ui"
schedule:
interval: "weekly"
day: "saturday"
labels:
- type/dependencies
- javascript
commit-message:
prefix: chore(ui-deps)

View file

@@ -17,7 +17,7 @@ jobs:
pull-requests: write # for peter-evans/create-pull-request to create a PR
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
ref: main
fetch-depth: 0

View file

@@ -16,14 +16,14 @@ jobs:
lint:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2
+ - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # 5.0.0
- name: Run ShellCheck on install.sh
run: shellcheck scripts/install.sh
smoke-test-on-dev:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner

View file

@@ -18,7 +18,7 @@ on:
- '.github/workflows/integration-auth-tests.yml' # This workflow
concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
+ group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:
@@ -31,7 +31,7 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner

View file

@@ -16,7 +16,7 @@ on:
- '.github/workflows/integration-sql-store-tests.yml' # This workflow
concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
+ group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:
@@ -44,7 +44,7 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner

View file

@@ -65,7 +65,7 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Setup test environment
uses: ./.github/actions/setup-test-environment

View file

@@ -33,7 +33,7 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner

View file

@@ -8,7 +8,7 @@ on:
branches: [main]
concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
+ group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:
@@ -20,7 +20,7 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
# For dependabot PRs, we need to checkout with a token that can push changes
token: ${{ github.actor == 'dependabot[bot]' && secrets.GITHUB_TOKEN || github.token }}
@@ -36,6 +36,17 @@ jobs:
**/requirements*.txt
.pre-commit-config.yaml
- name: Set up Node.js
uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0
with:
node-version: '20'
cache: 'npm'
cache-dependency-path: 'llama_stack/ui/'
- name: Install npm dependencies
run: npm ci
working-directory: llama_stack/ui
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
continue-on-error: true
env:

View file

@@ -26,7 +26,7 @@ on:
- 'pyproject.toml'
concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
+ group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:
@@ -36,7 +36,7 @@ jobs:
distros: ${{ steps.set-matrix.outputs.distros }}
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Generate Distribution List
id: set-matrix
@@ -55,7 +55,7 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner
@@ -79,7 +79,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner
@@ -92,7 +92,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner
@@ -106,6 +106,10 @@ jobs:
- name: Inspect the container image entrypoint
run: |
IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
if [ -z "$IMAGE_ID" ]; then
echo "No image found"
exit 1
fi
entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
echo "Entrypoint: $entrypoint"
if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then
@@ -117,7 +121,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner
@@ -140,6 +144,10 @@ jobs:
- name: Inspect UBI9 image
run: |
IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
if [ -z "$IMAGE_ID" ]; then
echo "No image found"
exit 1
fi
entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
echo "Entrypoint: $entrypoint"
if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then

View file

@@ -21,10 +21,10 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install uv
- uses: astral-sh/setup-uv@e92bafb6253dcd438e0484186d7669ea7a8ca1cc # v6.4.3
+ uses: astral-sh/setup-uv@d9e0f98d3fc6adb07d1e3d37f3043649ddad06a1 # v6.5.0
with:
python-version: ${{ matrix.python-version }}
activate-environment: true

View file

@@ -46,7 +46,7 @@ jobs:
echo "::endgroup::"
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
fetch-depth: 0

View file

@@ -22,6 +22,6 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Check PR Title's semantic conformance
- uses: amannn/action-semantic-pull-request@0723387faaf9b38adef4775cd42cfd5155ed6017 # v5.5.3
+ uses: amannn/action-semantic-pull-request@7f33ba792281b034f64e96f4c0b5496782dd3b37 # v6.1.0
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View file

@@ -27,7 +27,7 @@ jobs:
# container and point 'uv pip install' to the correct path...
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner

View file

@@ -27,7 +27,7 @@ jobs:
# container and point 'uv pip install' to the correct path...
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner

View file

@@ -13,7 +13,7 @@ on:
workflow_dispatch:
concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
+ group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:
@@ -26,10 +26,10 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Setup Node.js
- uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0
+ uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
with:
node-version: ${{ matrix.node-version }}
cache: 'npm'

View file

@@ -18,7 +18,7 @@ on:
workflow_dispatch:
concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
+ group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:
@@ -32,7 +32,7 @@ jobs:
- "3.13"
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner

View file

@@ -27,7 +27,7 @@ on:
- '.github/workflows/update-readthedocs.yml'
concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
+ group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:
@@ -37,7 +37,7 @@ jobs:
TOKEN: ${{ secrets.READTHEDOCS_TOKEN }}
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install dependencies
uses: ./.github/actions/setup-runner

View file

@@ -146,20 +146,32 @@ repos:
pass_filenames: false
require_serial: true
files: ^.github/workflows/.*$
- - id: ui-prettier
-   name: Format UI code with Prettier
-   entry: bash -c 'cd llama_stack/ui && npm run format'
-   language: system
-   files: ^llama_stack/ui/.*\.(ts|tsx)$
-   pass_filenames: false
-   require_serial: true
- - id: ui-eslint
-   name: Lint UI code with ESLint
-   entry: bash -c 'cd llama_stack/ui && npm run lint -- --fix --quiet'
-   language: system
-   files: ^llama_stack/ui/.*\.(ts|tsx)$
-   pass_filenames: false
-   require_serial: true
+ - id: ui-linter
+   name: Format & Lint UI
+   entry: bash ./scripts/run-ui-linter.sh
+   language: system
+   files: ^llama_stack/ui/.*\.(ts|tsx)$
+   pass_filenames: false
+   require_serial: true
+ - id: check-log-usage
+   name: Ensure 'llama_stack.log' usage for logging
+   entry: bash
+   language: system
+   types: [python]
+   pass_filenames: true
+   args:
- -c
- |
matches=$(grep -EnH '^[^#]*\b(import\s+logging|from\s+logging\b)' "$@" | grep -v -e '#\s*allow-direct-logging' || true)
if [ -n "$matches" ]; then
# GitHub Actions annotation format
while IFS=: read -r file line_num rest; do
echo "::error file=$file,line=$line_num::Do not use 'import logging' or 'from logging import' in $file. Use the custom log instead: from llama_stack.log import get_logger; logger = get_logger(). If direct logging is truly needed, add: # allow-direct-logging"
done <<< "$matches"
exit 1
fi
exit 0
ci:
autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
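
For reference, the convention the new `check-log-usage` hook enforces looks like the following; a minimal sketch using the `get_logger` helper shown elsewhere in this diff (the category value below is only an example):

```python
# Minimal sketch of the logging convention enforced by the check-log-usage hook.
# get_logger(name=..., category=...) is the helper used throughout this diff;
# the category string below is only an example.
from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="core")


def do_work() -> None:
    # Routed through the project logger instead of the stdlib root logger,
    # so no direct `import logging` is needed and the hook stays quiet.
    logger.info("starting work")
```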

View file

@@ -4605,6 +4605,49 @@
}
}
},
"/v1/inference/rerank": {
"post": {
"responses": {
"200": {
"description": "RerankResponse with indices sorted by relevance score (descending).",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/RerankResponse"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Inference"
],
"description": "Rerank a list of documents based on their relevance to a query.",
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/RerankRequest"
}
}
},
"required": true
}
}
},
"/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": { "/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": {
"post": { "post": {
"responses": { "responses": {
@ -16024,12 +16067,16 @@
"value": { "value": {
"type": "number", "type": "number",
"description": "The numeric value of the metric at this timestamp" "description": "The numeric value of the metric at this timestamp"
},
"unit": {
"type": "string"
} }
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"timestamp", "timestamp",
"value" "value",
"unit"
], ],
"title": "MetricDataPoint", "title": "MetricDataPoint",
"description": "A single data point in a metric time series." "description": "A single data point in a metric time series."
@ -16587,6 +16634,95 @@
], ],
"title": "RegisterVectorDbRequest" "title": "RegisterVectorDbRequest"
}, },
"RerankRequest": {
"type": "object",
"properties": {
"model": {
"type": "string",
"description": "The identifier of the reranking model to use."
},
"query": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
},
{
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
}
],
"description": "The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length."
},
"items": {
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
},
{
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
}
]
},
"description": "List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length."
},
"max_num_results": {
"type": "integer",
"description": "(Optional) Maximum number of results to return. Default: returns all."
}
},
"additionalProperties": false,
"required": [
"model",
"query",
"items"
],
"title": "RerankRequest"
},
"RerankData": {
"type": "object",
"properties": {
"index": {
"type": "integer",
"description": "The original index of the document in the input list"
},
"relevance_score": {
"type": "number",
"description": "The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance."
}
},
"additionalProperties": false,
"required": [
"index",
"relevance_score"
],
"title": "RerankData",
"description": "A single rerank result from a reranking response."
},
"RerankResponse": {
"type": "object",
"properties": {
"data": {
"type": "array",
"items": {
"$ref": "#/components/schemas/RerankData"
},
"description": "List of rerank result objects, sorted by relevance score (descending)"
}
},
"additionalProperties": false,
"required": [
"data"
],
"title": "RerankResponse",
"description": "Response from a reranking request."
},
"ResumeAgentTurnRequest": { "ResumeAgentTurnRequest": {
"type": "object", "type": "object",
"properties": { "properties": {

View file

@@ -3264,6 +3264,37 @@ paths:
schema:
$ref: '#/components/schemas/QueryTracesRequest'
required: true
/v1/inference/rerank:
post:
responses:
'200':
description: >-
RerankResponse with indices sorted by relevance score (descending).
content:
application/json:
schema:
$ref: '#/components/schemas/RerankResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Inference
description: >-
Rerank a list of documents based on their relevance to a query.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RerankRequest'
required: true
/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume:
post:
responses:
@@ -11923,10 +11954,13 @@ components:
type: number
description: >-
The numeric value of the metric at this timestamp
+ unit:
+   type: string
additionalProperties: false
required:
- timestamp
- value
+ - unit
title: MetricDataPoint
description: >-
A single data point in a metric time series.
@@ -12337,6 +12371,76 @@ components:
- vector_db_id
- embedding_model
title: RegisterVectorDbRequest
RerankRequest:
type: object
properties:
model:
type: string
description: >-
The identifier of the reranking model to use.
query:
oneOf:
- type: string
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
description: >-
The search query to rank items against. Can be a string, text content
part, or image content part. The input must not exceed the model's max
input token length.
items:
type: array
items:
oneOf:
- type: string
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
description: >-
List of items to rerank. Each item can be a string, text content part,
or image content part. Each input must not exceed the model's max input
token length.
max_num_results:
type: integer
description: >-
(Optional) Maximum number of results to return. Default: returns all.
additionalProperties: false
required:
- model
- query
- items
title: RerankRequest
RerankData:
type: object
properties:
index:
type: integer
description: >-
The original index of the document in the input list
relevance_score:
type: number
description: >-
The relevance score from the model output. Values are inverted when applicable
so that higher scores indicate greater relevance.
additionalProperties: false
required:
- index
- relevance_score
title: RerankData
description: >-
A single rerank result from a reranking response.
RerankResponse:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/RerankData'
description: >-
List of rerank result objects, sorted by relevance score (descending)
additionalProperties: false
required:
- data
title: RerankResponse
description: Response from a reranking request.
ResumeAgentTurnRequest:
type: object
properties:

View file

@@ -225,8 +225,32 @@ server:
port: 8321 # Port to listen on (default: 8321)
tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS
tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS
cors: true # Optional: Enable CORS (dev mode) or full config object
```
### CORS Configuration
CORS (Cross-Origin Resource Sharing) can be configured in two ways:
**Local development** (allows localhost origins only):
```yaml
server:
cors: true
```
**Explicit configuration** (custom origins and settings):
```yaml
server:
cors:
allow_origins: ["https://myapp.com", "https://app.example.com"]
allow_methods: ["GET", "POST", "PUT", "DELETE"]
allow_headers: ["Content-Type", "Authorization"]
allow_credentials: true
max_age: 3600
```
When `cors: true`, the server enables secure localhost-only access for local development. For production, specify exact origins to maintain security.
### Authentication Configuration

> **Breaking Change (v0.2.14)**: The authentication configuration structure has changed. The previous format with `provider_type` and `config` fields has been replaced with a unified `provider_config` field that includes the `type` field. Update your configuration files accordingly.

@@ -618,6 +642,54 @@ Content-Type: application/json
}
```
### CORS Configuration
Configure CORS to allow web browsers to make requests from different domains. Disabled by default.
#### Quick Setup
For development, use the simple boolean flag:
```yaml
server:
cors: true # Auto-enables localhost with any port
```
This automatically allows `http://localhost:*` and `https://localhost:*` with secure defaults.
#### Custom Configuration
For specific origins and full control:
```yaml
server:
cors:
allow_origins: ["https://myapp.com", "https://staging.myapp.com"]
allow_credentials: true
allow_methods: ["GET", "POST", "PUT", "DELETE"]
allow_headers: ["Content-Type", "Authorization"]
allow_origin_regex: "https://.*\\.example\\.com" # Optional regex pattern
expose_headers: ["X-Total-Count"]
max_age: 86400
```
#### Configuration Options
| Field | Description | Default |
| -------------------- | ---------------------------------------------- | ------- |
| `allow_origins` | List of allowed origins. Use `["*"]` for any. | `["*"]` |
| `allow_origin_regex` | Regex pattern for allowed origins (optional). | `None` |
| `allow_methods` | Allowed HTTP methods. | `["*"]` |
| `allow_headers` | Allowed headers. | `["*"]` |
| `allow_credentials` | Allow credentials (cookies, auth headers). | `false` |
| `expose_headers` | Headers exposed to browser. | `[]` |
| `max_age` | Preflight cache time (seconds). | `600` |
**Security Notes**:
- `allow_credentials: true` requires explicit origins (no wildcards)
- `cors: true` enables localhost access only (secure for development)
- For public APIs, always specify exact allowed origins
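
To sanity-check a CORS setup, send a preflight request and inspect the headers the server grants. A minimal sketch, assuming a Llama Stack server on `localhost:8321` configured with the explicit origins above; `/v1/models` is used only as a representative route:

```python
# Send a CORS preflight request and print the headers the server grants back.
# Assumes a server on localhost:8321 configured with the explicit origins above.
import requests

resp = requests.options(
    "http://localhost:8321/v1/models",
    headers={
        "Origin": "https://myapp.com",
        "Access-Control-Request-Method": "GET",
    },
)
print(resp.status_code)
print(resp.headers.get("Access-Control-Allow-Origin"))
print(resp.headers.get("Access-Control-Allow-Credentials"))
```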
## Extending to handle Safety

Configuring Safety can be a little involved so it is instructive to go through an example.

View file

@@ -17,7 +17,6 @@ client = LlamaStackAsLibraryClient(
# provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
)
- client.initialize()
```

This will parse your config and set up any inline implementations and remote clients needed for your implementation.
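
Once constructed, the client is ready to use directly; a small sketch, using `models.list()` only as a representative call:

```python
# The client can be used immediately after construction; no initialize() call
# is needed anymore. models.list() is just an example request.
models = client.models.list()
print(models)
```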
@@ -32,5 +31,4 @@ If you've created a [custom distribution](https://llama-stack.readthedocs.io/en/

```python
client = LlamaStackAsLibraryClient(config_path)
- client.initialize()
```

View file

@@ -1,4 +1,3 @@
- #!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#

View file

@@ -2,12 +2,15 @@

## Overview

- Protocol for batch processing API operations.
-
The Batches API enables efficient processing of multiple requests in a single operation,
particularly useful for processing large datasets, batch evaluation workflows, and
cost-effective inference at scale.

+ The API is designed to allow use of openai client libraries for seamless integration.
+
+ This API provides the following extensions:
+  - idempotent batch creation

Note: This API is currently under active development and may undergo changes.

This section contains documentation for all available providers for the **batches** API.
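
As a rough illustration of the OpenAI-client compatibility and the idempotency extension mentioned above, a hedged sketch follows; the base URL, the input file id, and passing `idempotency_key` through `extra_body` are assumptions for illustration, not confirmed API details:

```python
# Hedged sketch: create a batch through the OpenAI Python client pointed at a
# Llama Stack server. base_url, the file id, and the extra_body transport for
# the idempotency extension are assumptions; only the idempotency_key field
# name comes from the API change in this PR.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

batch = client.batches.create(
    input_file_id="file-abc123",  # hypothetical uploaded batch input file
    endpoint="/v1/chat/completions",
    completion_window="24h",
    metadata={"job": "nightly-eval"},
    extra_body={"idempotency_key": "nightly-eval-2025-08-25"},
)
print(batch.id, batch.status)
```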

View file

@@ -10,4 +10,5 @@ This section contains documentation for all available providers for the **files*
:maxdepth: 1

inline_localfs
+ remote_s3
```

View file

@ -0,0 +1,33 @@
# remote::s3
## Description
AWS S3-based file storage provider for scalable cloud file management with metadata persistence.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `bucket_name` | `<class 'str'>` | No | | S3 bucket name to store files |
| `region` | `<class 'str'>` | No | us-east-1 | AWS region where the bucket is located |
| `aws_access_key_id` | `str \| None` | No | | AWS access key ID (optional if using IAM roles) |
| `aws_secret_access_key` | `str \| None` | No | | AWS secret access key (optional if using IAM roles) |
| `endpoint_url` | `str \| None` | No | | Custom S3 endpoint URL (for MinIO, LocalStack, etc.) |
| `auto_create_bucket` | `<class 'bool'>` | No | False | Automatically create the S3 bucket if it doesn't exist |
| `metadata_store` | `utils.sqlstore.sqlstore.SqliteSqlStoreConfig \| utils.sqlstore.sqlstore.PostgresSqlStoreConfig` | No | sqlite | SQL store configuration for file metadata |
## Sample Configuration
```yaml
bucket_name: ${env.S3_BUCKET_NAME}
region: ${env.AWS_REGION:=us-east-1}
aws_access_key_id: ${env.AWS_ACCESS_KEY_ID:=}
aws_secret_access_key: ${env.AWS_SECRET_ACCESS_KEY:=}
endpoint_url: ${env.S3_ENDPOINT_URL:=}
auto_create_bucket: ${env.S3_AUTO_CREATE_BUCKET:=false}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/s3_files_metadata.db
```

View file

@@ -9,7 +9,9 @@ This section contains documentation for all available providers for the **post_t
```{toctree}
:maxdepth: 1

- inline_huggingface
- inline_torchtune
+ inline_huggingface-cpu
+ inline_huggingface-gpu
+ inline_torchtune-cpu
+ inline_torchtune-gpu
remote_nvidia
```

View file

@ -0,0 +1,41 @@
# inline::huggingface-cpu
## Description
HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `device` | `<class 'str'>` | No | cuda | |
| `distributed_backend` | `Literal['fsdp', 'deepspeed']` | No | | |
| `checkpoint_format` | `Literal['full_state', 'huggingface']` | No | huggingface | |
| `chat_template` | `<class 'str'>` | No | `<\|user\|>\n{input}\n<\|assistant\|>\n{output}` | |
| `model_specific_config` | `<class 'dict'>` | No | {'trust_remote_code': True, 'attn_implementation': 'sdpa'} | |
| `max_seq_length` | `<class 'int'>` | No | 2048 | |
| `gradient_checkpointing` | `<class 'bool'>` | No | False | |
| `save_total_limit` | `<class 'int'>` | No | 3 | |
| `logging_steps` | `<class 'int'>` | No | 10 | |
| `warmup_ratio` | `<class 'float'>` | No | 0.1 | |
| `weight_decay` | `<class 'float'>` | No | 0.01 | |
| `dataloader_num_workers` | `<class 'int'>` | No | 4 | |
| `dataloader_pin_memory` | `<class 'bool'>` | No | True | |
| `dpo_beta` | `<class 'float'>` | No | 0.1 | |
| `use_reference_model` | `<class 'bool'>` | No | True | |
| `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair']` | No | sigmoid | |
| `dpo_output_dir` | `<class 'str'>` | No | | |
## Sample Configuration
```yaml
checkpoint_format: huggingface
distributed_backend: null
device: cpu
dpo_output_dir: ~/.llama/dummy/dpo_output
```

View file

@ -0,0 +1,41 @@
# inline::huggingface-gpu
## Description
HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `device` | `<class 'str'>` | No | cuda | |
| `distributed_backend` | `Literal['fsdp', 'deepspeed']` | No | | |
| `checkpoint_format` | `Literal['full_state', 'huggingface']` | No | huggingface | |
| `chat_template` | `<class 'str'>` | No | `<\|user\|>\n{input}\n<\|assistant\|>\n{output}` | |
| `model_specific_config` | `<class 'dict'>` | No | {'trust_remote_code': True, 'attn_implementation': 'sdpa'} | |
| `max_seq_length` | `<class 'int'>` | No | 2048 | |
| `gradient_checkpointing` | `<class 'bool'>` | No | False | |
| `save_total_limit` | `<class 'int'>` | No | 3 | |
| `logging_steps` | `<class 'int'>` | No | 10 | |
| `warmup_ratio` | `<class 'float'>` | No | 0.1 | |
| `weight_decay` | `<class 'float'>` | No | 0.01 | |
| `dataloader_num_workers` | `<class 'int'>` | No | 4 | |
| `dataloader_pin_memory` | `<class 'bool'>` | No | True | |
| `dpo_beta` | `<class 'float'>` | No | 0.1 | |
| `use_reference_model` | `<class 'bool'>` | No | True | |
| `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair']` | No | sigmoid | |
| `dpo_output_dir` | `<class 'str'>` | No | | |
## Sample Configuration
```yaml
checkpoint_format: huggingface
distributed_backend: null
device: cpu
dpo_output_dir: ~/.llama/dummy/dpo_output
```

View file

@ -0,0 +1,20 @@
# inline::torchtune-cpu
## Description
TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `torch_seed` | `int \| None` | No | | |
| `checkpoint_format` | `Literal['meta', 'huggingface']` | No | meta | |
## Sample Configuration
```yaml
checkpoint_format: meta
```

View file

@ -0,0 +1,20 @@
# inline::torchtune-gpu
## Description
TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `torch_seed` | `int \| None` | No | | |
| `checkpoint_format` | `Literal['meta', 'huggingface']` | No | meta | |
## Sample Configuration
```yaml
checkpoint_format: meta
```

View file

@@ -29,12 +29,16 @@ class ListBatchesResponse(BaseModel):
@runtime_checkable
class Batches(Protocol):
- """Protocol for batch processing API operations.
+ """
The Batches API enables efficient processing of multiple requests in a single operation,
particularly useful for processing large datasets, batch evaluation workflows, and
cost-effective inference at scale.
+ The API is designed to allow use of openai client libraries for seamless integration.
+ This API provides the following extensions:
+  - idempotent batch creation
Note: This API is currently under active development and may undergo changes.
"""
@@ -45,6 +49,7 @@ class Batches(Protocol):
endpoint: str,
completion_window: Literal["24h"],
metadata: dict[str, str] | None = None,
+ idempotency_key: str | None = None,
) -> BatchObject:
"""Create a new batch for processing multiple API requests.
@@ -52,6 +57,7 @@ class Batches(Protocol):
:param endpoint: The endpoint to be used for all requests in the batch.
:param completion_window: The time window within which the batch should be processed.
:param metadata: Optional metadata for the batch.
+ :param idempotency_key: Optional idempotency key. When provided, enables idempotent behavior.
:returns: The created batch object.
"""
...

View file

@@ -473,6 +473,28 @@ class EmbeddingsResponse(BaseModel):
embeddings: list[list[float]]
@json_schema_type
class RerankData(BaseModel):
"""A single rerank result from a reranking response.
:param index: The original index of the document in the input list
:param relevance_score: The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance.
"""
index: int
relevance_score: float
@json_schema_type
class RerankResponse(BaseModel):
"""Response from a reranking request.
:param data: List of rerank result objects, sorted by relevance score (descending)
"""
data: list[RerankData]
@json_schema_type
class OpenAIChatCompletionContentPartTextParam(BaseModel):
"""Text content part for OpenAI-compatible chat completion messages.
@@ -1046,6 +1068,7 @@ class InferenceProvider(Protocol):
:returns: A BatchCompletionResponse with the full completions.
"""
raise NotImplementedError("Batch completion is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
@webmethod(route="/inference/chat-completion", method="POST") @webmethod(route="/inference/chat-completion", method="POST")
async def chat_completion( async def chat_completion(
@ -1110,6 +1133,7 @@ class InferenceProvider(Protocol):
:returns: A BatchChatCompletionResponse with the full completions. :returns: A BatchChatCompletionResponse with the full completions.
""" """
raise NotImplementedError("Batch chat completion is not implemented") raise NotImplementedError("Batch chat completion is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
@webmethod(route="/inference/embeddings", method="POST") @webmethod(route="/inference/embeddings", method="POST")
async def embeddings( async def embeddings(
@ -1131,6 +1155,25 @@ class InferenceProvider(Protocol):
""" """
... ...
@webmethod(route="/inference/rerank", method="POST", experimental=True)
async def rerank(
self,
model: str,
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
max_num_results: int | None = None,
) -> RerankResponse:
"""Rerank a list of documents based on their relevance to a query.
:param model: The identifier of the reranking model to use.
:param query: The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length.
:param items: List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length.
:param max_num_results: (Optional) Maximum number of results to return. Default: returns all.
:returns: RerankResponse with indices sorted by relevance score (descending).
"""
raise NotImplementedError("Reranking is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
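
A small sketch of calling the new experimental endpoint over HTTP, based on the request and response schemas added in this change; the server address and model identifier are placeholders:

```python
# Sketch of a request to the experimental rerank endpoint added above.
# The server address and model identifier are placeholders.
import requests

resp = requests.post(
    "http://localhost:8321/v1/inference/rerank",
    json={
        "model": "example-reranker",
        "query": "What is the capital of France?",
        "items": [
            "Paris is the capital of France.",
            "Berlin is the capital of Germany.",
        ],
        "max_num_results": 1,
    },
)
for result in resp.json()["data"]:
    print(result["index"], result["relevance_score"])
```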
@webmethod(route="/openai/v1/completions", method="POST") @webmethod(route="/openai/v1/completions", method="POST")
async def openai_completion( async def openai_completion(
self, self,

View file

@@ -386,6 +386,7 @@ class MetricDataPoint(BaseModel):
timestamp: int
value: float
+ unit: str

@json_schema_type
@@ -518,7 +519,7 @@ class Telemetry(Protocol):
metric_name: str,
start_time: int,
end_time: int | None = None,
- granularity: str | None = "1d",
+ granularity: str | None = None,
query_type: MetricQueryType = MetricQueryType.RANGE,
label_matchers: list[MetricLabelMatcher] | None = None,
) -> QueryMetricsResponse:

View file

@@ -15,7 +15,7 @@ from llama_stack.log import get_logger
REPO_ROOT = Path(__file__).parent.parent.parent.parent

- logger = get_logger(name=__name__, category="server")
+ logger = get_logger(name=__name__, category="cli")

class StackRun(Subcommand):

View file

@@ -5,7 +5,6 @@
# the root directory of this source tree.
import importlib.resources
- import logging
import sys

from pydantic import BaseModel
@@ -17,9 +16,10 @@ from llama_stack.core.external import load_external_apis
from llama_stack.core.utils.exec import run_command
from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.distributions.template import DistributionTemplate
+ from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api

- log = logging.getLogger(__name__)
+ log = get_logger(name=__name__, category="core")

# These are the dependencies needed by the distribution server.
# `llama-stack` is automatically installed by the installation script.

View file

@@ -3,7 +3,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
- import logging
import textwrap

from typing import Any
@@ -21,9 +20,10 @@ from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.prompt_for_config import prompt_for_config
+ from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, ProviderSpec

- logger = logging.getLogger(__name__)
+ logger = get_logger(name=__name__, category="core")

def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider:

View file

@@ -318,6 +318,41 @@ class QuotaConfig(BaseModel):
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
class CORSConfig(BaseModel):
allow_origins: list[str] = Field(default_factory=list)
allow_origin_regex: str | None = Field(default=None)
allow_methods: list[str] = Field(default=["OPTIONS"])
allow_headers: list[str] = Field(default_factory=list)
allow_credentials: bool = Field(default=False)
expose_headers: list[str] = Field(default_factory=list)
max_age: int = Field(default=600, ge=0)
@model_validator(mode="after")
def validate_credentials_config(self) -> Self:
if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins):
raise ValueError("Cannot use wildcard origins with credentials enabled")
return self
def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | None:
if cors_config is False or cors_config is None:
return None
if cors_config is True:
# dev mode: allow localhost on any port
return CORSConfig(
allow_origins=[],
allow_origin_regex=r"https?://localhost:\d+",
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
)
if isinstance(cors_config, CORSConfig):
return cors_config
raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}")
class ServerConfig(BaseModel):
port: int = Field(
default=8321,
@@ -349,6 +384,12 @@ class ServerConfig(BaseModel):
default=None,
description="Per client quota request configuration",
)
cors: bool | CORSConfig | None = Field(
default=None,
description="CORS configuration for cross-origin requests. Can be:\n"
"- true: Enable localhost CORS for development\n"
"- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
)
class StackRunConfig(BaseModel):

View file

@@ -7,7 +7,7 @@
import asyncio
import inspect
import json
- import logging
+ import logging  # allow-direct-logging
import os
import sys
from concurrent.futures import ThreadPoolExecutor
@@ -48,6 +48,7 @@ from llama_stack.core.stack import (
from llama_stack.core.utils.config import redact_sensitive_fields
from llama_stack.core.utils.context import preserve_contexts_async_generator
from llama_stack.core.utils.exec import in_notebook
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry.tracing import (
CURRENT_TRACE_CONTEXT,
end_trace,
@@ -55,7 +56,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
start_trace,
)

- logger = logging.getLogger(__name__)
+ logger = get_logger(name=__name__, category="core")

T = TypeVar("T")
@@ -145,39 +146,26 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
):
super().__init__()
self.async_client = AsyncLlamaStackAsLibraryClient(
- config_path_or_distro_name, custom_provider_registry, provider_data
+ config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
)
self.pool_executor = ThreadPoolExecutor(max_workers=4)
- self.skip_logger_removal = skip_logger_removal
self.provider_data = provider_data
self.loop = asyncio.new_event_loop()

- def initialize(self):
-     if in_notebook():
-         import nest_asyncio
-         nest_asyncio.apply()
-     if not self.skip_logger_removal:
-         self._remove_root_logger_handlers()

# use a new event loop to avoid interfering with the main event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
- return loop.run_until_complete(self.async_client.initialize())
+ loop.run_until_complete(self.async_client.initialize())
finally:
asyncio.set_event_loop(None)

- def _remove_root_logger_handlers(self):
+ def initialize(self):
"""
- Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
+ Deprecated method for backward compatibility.
"""
- root_logger = logging.getLogger()
- for handler in root_logger.handlers[:]:
-     root_logger.removeHandler(handler)
-     logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
+ pass

def request(self, *args, **kwargs):
loop = self.loop
@@ -215,6 +203,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
config_path_or_distro_name: str,
custom_provider_registry: ProviderRegistry | None = None,
provider_data: dict[str, Any] | None = None,
skip_logger_removal: bool = False,
):
super().__init__()
# when using the library client, we should not log to console since many
@@ -222,6 +211,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
if in_notebook():
import nest_asyncio
nest_asyncio.apply()
if not skip_logger_removal:
self._remove_root_logger_handlers()
if config_path_or_distro_name.endswith(".yaml"):
config_path = Path(config_path_or_distro_name)
if not config_path.exists():
@@ -238,7 +234,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
self.provider_data = provider_data
self.route_impls: RouteImpls | None = None  # Initialize to None to prevent AttributeError
def _remove_root_logger_handlers(self):
"""
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
"""
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
async def initialize(self) -> bool:
"""
Initialize the async client.
Returns:
bool: True if initialization was successful
"""
try:
self.route_impls = None
self.impls = await construct_stack(self.config, self.custom_provider_registry)
View file

@@ -6,15 +6,15 @@
import contextvars
import json
- import logging
from contextlib import AbstractContextManager
from typing import Any

from llama_stack.core.datatypes import User
+ from llama_stack.log import get_logger

from .utils.dynamic import instantiate_class_type

- log = logging.getLogger(__name__)
+ log = get_logger(name=__name__, category="core")

# Context variable for request provider data and auth attributes
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)

View file

@@ -12,7 +12,7 @@ from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable

- logger = get_logger(name=__name__, category="core")
+ logger = get_logger(name=__name__, category="core::routers")

class DatasetIORouter(DatasetIO):

View file

@@ -16,7 +16,7 @@ from llama_stack.apis.scoring import (
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable

- logger = get_logger(name=__name__, category="core")
+ logger = get_logger(name=__name__, category="core::routers")

class ScoringRouter(Scoring):

View file

@@ -65,7 +65,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus, Routin
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.telemetry.tracing import get_current_span

- logger = get_logger(name=__name__, category="inference")
+ logger = get_logger(name=__name__, category="core::routers")

class InferenceRouter(Inference):

View file

@@ -13,7 +13,7 @@ from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable

- logger = get_logger(name=__name__, category="core")
+ logger = get_logger(name=__name__, category="core::routers")

class SafetyRouter(Safety):

View file

@@ -22,7 +22,7 @@ from llama_stack.log import get_logger
from ..routing_tables.toolgroups import ToolGroupsRoutingTable

- logger = get_logger(name=__name__, category="core")
+ logger = get_logger(name=__name__, category="core::routers")

class ToolRuntimeRouter(ToolRuntime):

View file

@@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import (
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable

- logger = get_logger(name=__name__, category="core")
+ logger = get_logger(name=__name__, category="core::routers")

class VectorIORouter(VectorIO):

View file

@@ -14,7 +14,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl

- logger = get_logger(name=__name__, category="core")
+ logger = get_logger(name=__name__, category="core::routing_tables")

class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):

View file

@@ -23,7 +23,7 @@ from llama_stack.core.store import DistributionRegistry
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, RoutingTable

- logger = get_logger(name=__name__, category="core")
+ logger = get_logger(name=__name__, category="core::routing_tables")

def get_impl_api(p: Any) -> Api:

View file

@@ -26,7 +26,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl

- logger = get_logger(name=__name__, category="core")
+ logger = get_logger(name=__name__, category="core::routing_tables")

class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):

View file

@@ -17,7 +17,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model

- logger = get_logger(name=__name__, category="core")
+ logger = get_logger(name=__name__, category="core::routing_tables")

class ModelsRoutingTable(CommonRoutingTableImpl, Models):

View file

@ -19,7 +19,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core") logger = get_logger(name=__name__, category="core::routing_tables")
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):

View file

@ -15,7 +15,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core") logger = get_logger(name=__name__, category="core::routing_tables")
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):

View file

@ -14,7 +14,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core") logger = get_logger(name=__name__, category="core::routing_tables")
def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None: def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:

View file

@ -30,7 +30,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core") logger = get_logger(name=__name__, category="core::routing_tables")
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):

View file

@ -15,7 +15,7 @@ from llama_stack.core.server.auth_providers import create_auth_provider
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
from llama_stack.log import get_logger from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth") logger = get_logger(name=__name__, category="core::auth")
class AuthenticationMiddleware: class AuthenticationMiddleware:

View file

@ -23,7 +23,7 @@ from llama_stack.core.datatypes import (
) )
from llama_stack.log import get_logger from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth") logger = get_logger(name=__name__, category="core::auth")
class AuthResponse(BaseModel): class AuthResponse(BaseModel):

View file

@ -15,7 +15,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
logger = get_logger(name=__name__, category="quota") logger = get_logger(name=__name__, category="core::server")
class QuotaMiddleware: class QuotaMiddleware:

View file

@ -9,7 +9,7 @@ import asyncio
import functools import functools
import inspect import inspect
import json import json
import logging import logging # allow-direct-logging
import os import os
import ssl import ssl
import sys import sys
@ -28,6 +28,7 @@ from aiohttp import hdrs
from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi import Path as FastapiPath from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError from openai import BadRequestError
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
@ -40,6 +41,7 @@ from llama_stack.core.datatypes import (
AuthenticationRequiredError, AuthenticationRequiredError,
LoggingConfig, LoggingConfig,
StackRunConfig, StackRunConfig,
process_cors_config,
) )
from llama_stack.core.distribution import builtin_automatically_routed_apis from llama_stack.core.distribution import builtin_automatically_routed_apis
from llama_stack.core.external import ExternalApiSpec, load_external_apis from llama_stack.core.external import ExternalApiSpec, load_external_apis
@ -82,7 +84,7 @@ from .quota import QuotaMiddleware
REPO_ROOT = Path(__file__).parent.parent.parent.parent REPO_ROOT = Path(__file__).parent.parent.parent.parent
logger = get_logger(name=__name__, category="server") logger = get_logger(name=__name__, category="core::server")
def warn_with_traceback(message, category, filename, lineno, file=None, line=None): def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@ -413,7 +415,7 @@ def main(args: argparse.Namespace | None = None):
config_contents = yaml.safe_load(fp) config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")): if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg) logger_config = LoggingConfig(**cfg)
logger = get_logger(name=__name__, category="server", config=logger_config) logger = get_logger(name=__name__, category="core::server", config=logger_config)
if args.env: if args.env:
for env_pair in args.env: for env_pair in args.env:
try: try:
@ -483,6 +485,12 @@ def main(args: argparse.Namespace | None = None):
window_seconds=window_seconds, window_seconds=window_seconds,
) )
if config.server.cors:
logger.info("Enabling CORS")
cors_config = process_cors_config(config.server.cors)
if cors_config:
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
if Api.telemetry in impls: if Api.telemetry in impls:
setup_logger(impls[Api.telemetry]) setup_logger(impls[Api.telemetry])
else: else:
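Editor's note: the server now wires FastAPI's CORSMiddleware from the `server.cors` section of the run config via `process_cors_config`. A hedged sketch of what the resulting call amounts to in plain FastAPI terms; the allow_* values below are illustrative assumptions, not the actual output of `process_cors_config`:

```python
# Sketch only: equivalent plain-FastAPI wiring, with assumed example values.
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],  # assumed example origin
    allow_methods=["*"],
    allow_headers=["*"],
)
```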

View file

@ -16,7 +16,7 @@ from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
logger = get_logger(__name__, category="core") logger = get_logger(__name__, category="core::registry")
class DistributionRegistry(Protocol): class DistributionRegistry(Protocol):

View file

@ -10,7 +10,7 @@ from pathlib import Path
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="config_resolution") logger = get_logger(name=__name__, category="core")
DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions" DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions"

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging import importlib
import os import os
import signal import signal
import subprocess import subprocess
@ -12,9 +12,9 @@ import sys
from termcolor import cprint from termcolor import cprint
log = logging.getLogger(__name__) from llama_stack.log import get_logger
import importlib log = get_logger(name=__name__, category="core")
def formulate_run_args(image_type: str, image_name: str) -> list: def formulate_run_args(image_type: str, image_name: str) -> list:

View file

@ -6,7 +6,6 @@
import inspect import inspect
import json import json
import logging
from enum import Enum from enum import Enum
from typing import Annotated, Any, Literal, Union, get_args, get_origin from typing import Annotated, Any, Literal, Union, get_args, get_origin
@ -14,7 +13,9 @@ from pydantic import BaseModel
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefinedType from pydantic_core import PydanticUndefinedType
log = logging.getLogger(__name__) from llama_stack.log import get_logger
log = get_logger(name=__name__, category="core")
def is_list_of_primitives(field_type): def is_list_of_primitives(field_type):

View file

@ -34,7 +34,7 @@ distribution_spec:
telemetry: telemetry:
- provider_type: inline::meta-reference - provider_type: inline::meta-reference
post_training: post_training:
- provider_type: inline::huggingface - provider_type: inline::huggingface-cpu
eval: eval:
- provider_type: inline::meta-reference - provider_type: inline::meta-reference
datasetio: datasetio:

View file

@ -156,8 +156,8 @@ providers:
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/trace_store.db sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=} otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
post_training: post_training:
- provider_id: huggingface - provider_id: huggingface-cpu
provider_type: inline::huggingface provider_type: inline::huggingface-cpu
config: config:
checkpoint_format: huggingface checkpoint_format: huggingface
distributed_backend: null distributed_backend: null

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .starter_gpu import get_distribution_template # noqa: F401

View file

@ -0,0 +1,59 @@
version: 2
distribution_spec:
description: Quick start template for running Llama Stack with several popular providers.
This distribution is intended for GPU-enabled environments.
providers:
inference:
- provider_type: remote::cerebras
- provider_type: remote::ollama
- provider_type: remote::vllm
- provider_type: remote::tgi
- provider_type: remote::fireworks
- provider_type: remote::together
- provider_type: remote::bedrock
- provider_type: remote::nvidia
- provider_type: remote::openai
- provider_type: remote::anthropic
- provider_type: remote::gemini
- provider_type: remote::vertexai
- provider_type: remote::groq
- provider_type: remote::sambanova
- provider_type: inline::sentence-transformers
vector_io:
- provider_type: inline::faiss
- provider_type: inline::sqlite-vec
- provider_type: inline::milvus
- provider_type: remote::chromadb
- provider_type: remote::pgvector
files:
- provider_type: inline::localfs
safety:
- provider_type: inline::llama-guard
- provider_type: inline::code-scanner
agents:
- provider_type: inline::meta-reference
telemetry:
- provider_type: inline::meta-reference
post_training:
- provider_type: inline::torchtune-gpu
eval:
- provider_type: inline::meta-reference
datasetio:
- provider_type: remote::huggingface
- provider_type: inline::localfs
scoring:
- provider_type: inline::basic
- provider_type: inline::llm-as-judge
- provider_type: inline::braintrust
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
batches:
- provider_type: inline::reference
image_type: venv
additional_pip_packages:
- aiosqlite
- asyncpg
- sqlalchemy[asyncio]

View file

@ -0,0 +1,238 @@
version: 2
image_name: starter-gpu
apis:
- agents
- batches
- datasetio
- eval
- files
- inference
- post_training
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
base_url: https://api.cerebras.ai
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: openai
provider_type: remote::openai
config:
api_key: ${env.OPENAI_API_KEY:=}
base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
- provider_id: anthropic
provider_type: remote::anthropic
config:
api_key: ${env.ANTHROPIC_API_KEY:=}
- provider_id: gemini
provider_type: remote::gemini
config:
api_key: ${env.GEMINI_API_KEY:=}
- provider_id: ${env.VERTEX_AI_PROJECT:+vertexai}
provider_type: remote::vertexai
config:
project: ${env.VERTEX_AI_PROJECT:=}
location: ${env.VERTEX_AI_LOCATION:=us-central1}
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db
- provider_id: sqlite-vec
provider_type: inline::sqlite-vec
config:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db
- provider_id: ${env.MILVUS_URL:+milvus}
provider_type: inline::milvus
config:
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db
- provider_id: ${env.CHROMADB_URL:+chromadb}
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db
- provider_id: ${env.PGVECTOR_DB:+pgvector}
provider_type: remote::pgvector
config:
host: ${env.PGVECTOR_HOST:=localhost}
port: ${env.PGVECTOR_PORT:=5432}
db: ${env.PGVECTOR_DB:=}
user: ${env.PGVECTOR_USER:=}
password: ${env.PGVECTOR_PASSWORD:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
- provider_id: code-scanner
provider_type: inline::code-scanner
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/agents_store.db
responses_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/responses_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
post_training:
- provider_id: torchtune-gpu
provider_type: inline::torchtune-gpu
config:
checkpoint_format: meta
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/meta_reference_eval.db
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/huggingface_datasetio.db
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/localfs_datasetio.db
scoring:
- provider_id: basic
provider_type: inline::basic
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:=}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
batches:
- provider_id: reference
provider_type: inline::reference
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/batches.db
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/registry.db
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/inference_store.db
models: []
shields:
- shield_id: llama-guard
provider_id: ${env.SAFETY_MODEL:+llama-guard}
provider_shield_id: ${env.SAFETY_MODEL:=}
- shield_id: code-scanner
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
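Editor's note: this run.yaml relies on env-var substitution in two forms, `${env.VAR:=default}` and `${env.VAR:+word}`. The semantics appear to mirror bash parameter expansion (this reading is an assumption, not stated in the commit); a small sketch:

```python
# Hedged sketch of the assumed substitution semantics:
#   ${env.VAR:=default} -> value of VAR, or "default" when VAR is unset/empty
#   ${env.VAR:+word}    -> "word" when VAR is set, otherwise empty string
import os

def expand(var: str, op: str, word: str) -> str:
    value = os.environ.get(var, "")
    if op == ":=":
        return value or word
    if op == ":+":
        return word if value else ""
    raise ValueError(f"unsupported operator: {op}")

# e.g. provider_id: ${env.OLLAMA_URL:+ollama} resolves to "ollama" only when
# OLLAMA_URL is set; otherwise the provider id stays empty (assumed to mean
# the provider is left unconfigured).
print(expand("OLLAMA_URL", ":+", "ollama"))
```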

View file

@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.distributions.template import BuildProvider, DistributionTemplate
from ..starter.starter import get_distribution_template as get_starter_distribution_template
def get_distribution_template() -> DistributionTemplate:
template = get_starter_distribution_template()
name = "starter-gpu"
template.name = name
template.description = "Quick start template for running Llama Stack with several popular providers. This distribution is intended for GPU-enabled environments."
template.providers["post_training"] = [
BuildProvider(provider_type="inline::torchtune-gpu"),
]
return template
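Editor's note: a quick illustrative check of the new template (an assumption for demonstration, not part of the commit; the import path assumes the package lives at `llama_stack.distributions.starter_gpu`). It shows that starter-gpu is the starter template with `post_training` swapped to the torchtune GPU provider:

```python
# Hypothetical verification of the starter-gpu distribution template.
from llama_stack.distributions.starter_gpu import get_distribution_template

template = get_distribution_template()
assert template.name == "starter-gpu"
assert [p.provider_type for p in template.providers["post_training"]] == [
    "inline::torchtune-gpu"
]
```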

View file

@ -1,6 +1,7 @@
version: 2 version: 2
distribution_spec: distribution_spec:
description: Quick start template for running Llama Stack with several popular providers description: Quick start template for running Llama Stack with several popular providers.
This distribution is intended for CPU-only environments.
providers: providers:
inference: inference:
- provider_type: remote::cerebras - provider_type: remote::cerebras
@ -34,7 +35,7 @@ distribution_spec:
telemetry: telemetry:
- provider_type: inline::meta-reference - provider_type: inline::meta-reference
post_training: post_training:
- provider_type: inline::huggingface - provider_type: inline::huggingface-cpu
eval: eval:
- provider_type: inline::meta-reference - provider_type: inline::meta-reference
datasetio: datasetio:

View file

@ -156,8 +156,8 @@ providers:
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/trace_store.db sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=} otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
post_training: post_training:
- provider_id: huggingface - provider_id: huggingface-cpu
provider_type: inline::huggingface provider_type: inline::huggingface-cpu
config: config:
checkpoint_format: huggingface checkpoint_format: huggingface
distributed_backend: null distributed_backend: null

View file

@ -120,7 +120,7 @@ def get_distribution_template() -> DistributionTemplate:
], ],
"agents": [BuildProvider(provider_type="inline::meta-reference")], "agents": [BuildProvider(provider_type="inline::meta-reference")],
"telemetry": [BuildProvider(provider_type="inline::meta-reference")], "telemetry": [BuildProvider(provider_type="inline::meta-reference")],
"post_training": [BuildProvider(provider_type="inline::huggingface")], "post_training": [BuildProvider(provider_type="inline::huggingface-cpu")],
"eval": [BuildProvider(provider_type="inline::meta-reference")], "eval": [BuildProvider(provider_type="inline::meta-reference")],
"datasetio": [ "datasetio": [
BuildProvider(provider_type="remote::huggingface"), BuildProvider(provider_type="remote::huggingface"),
@ -178,7 +178,7 @@ def get_distribution_template() -> DistributionTemplate:
return DistributionTemplate( return DistributionTemplate(
name=name, name=name,
distro_type="self_hosted", distro_type="self_hosted",
description="Quick start template for running Llama Stack with several popular providers", description="Quick start template for running Llama Stack with several popular providers. This distribution is intended for CPU-only environments.",
container_image=None, container_image=None,
template_path=None, template_path=None,
providers=providers, providers=providers,

View file

@ -4,10 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging import logging # allow-direct-logging
import os import os
import re import re
from logging.config import dictConfig from logging.config import dictConfig # allow-direct-logging
from rich.console import Console from rich.console import Console
from rich.errors import MarkupError from rich.errors import MarkupError

View file

@ -13,14 +13,15 @@
# Copyright (c) Meta Platforms, Inc. and its affiliates. # Copyright (c) Meta Platforms, Inc. and its affiliates.
import math import math
from logging import getLogger
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from llama_stack.log import get_logger
from .utils import get_negative_inf_value, to_2tuple from .utils import get_negative_inf_value, to_2tuple
logger = getLogger() logger = get_logger(name=__name__, category="models::llama")
def resize_local_position_embedding(orig_pos_embed, grid_size): def resize_local_position_embedding(orig_pos_embed, grid_size):

View file

@ -13,7 +13,6 @@
import math import math
from collections import defaultdict from collections import defaultdict
from logging import getLogger
from typing import Any from typing import Any
import torch import torch
@ -21,9 +20,11 @@ import torchvision.transforms as tv
from PIL import Image from PIL import Image
from torchvision.transforms import functional as F from torchvision.transforms import functional as F
from llama_stack.log import get_logger
IMAGE_RES = 224 IMAGE_RES = 224
logger = getLogger() logger = get_logger(name=__name__, category="models::llama")
class VariableSizeImageTransform: class VariableSizeImageTransform:

View file

@ -3,8 +3,6 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
import math import math
from collections.abc import Callable from collections.abc import Callable
from functools import partial from functools import partial
@ -22,6 +20,8 @@ from PIL import Image as PIL_Image
from torch import Tensor, nn from torch import Tensor, nn
from torch.distributed import _functional_collectives as funcol from torch.distributed import _functional_collectives as funcol
from llama_stack.log import get_logger
from ..model import ModelArgs, RMSNorm, apply_rotary_emb, precompute_freqs_cis from ..model import ModelArgs, RMSNorm, apply_rotary_emb, precompute_freqs_cis
from .encoder_utils import ( from .encoder_utils import (
build_encoder_attention_mask, build_encoder_attention_mask,
@ -34,9 +34,10 @@ from .encoder_utils import (
from .image_transform import VariableSizeImageTransform from .image_transform import VariableSizeImageTransform
from .utils import get_negative_inf_value, to_2tuple from .utils import get_negative_inf_value, to_2tuple
logger = logging.getLogger(__name__)
MP_SCALE = 8 MP_SCALE = 8
logger = get_logger(name=__name__, category="models::llama")
def reduce_from_tensor_model_parallel_region(input_): def reduce_from_tensor_model_parallel_region(input_):
"""All-reduce the input tensor across model parallel group.""" """All-reduce the input tensor across model parallel group."""
@ -771,7 +772,7 @@ class TilePositionEmbedding(nn.Module):
if embed is not None: if embed is not None:
# reshape the weights to the correct shape # reshape the weights to the correct shape
nt_old, nt_old, _, w = embed.shape nt_old, nt_old, _, w = embed.shape
logging.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}") logger.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}")
embed_new = TilePositionEmbedding._dynamic_resize(embed, self.num_tiles) embed_new = TilePositionEmbedding._dynamic_resize(embed, self.num_tiles)
# assign the weights to the module # assign the weights to the module
state_dict[prefix + "embedding"] = embed_new state_dict[prefix + "embedding"] = embed_new

View file

@ -4,8 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import Collection, Iterator, Sequence, Set from collections.abc import Collection, Iterator, Sequence, Set
from logging import getLogger
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
Literal, Literal,
@ -14,11 +14,9 @@ from typing import (
import tiktoken import tiktoken
from llama_stack.log import get_logger
from llama_stack.models.llama.tokenizer_utils import load_bpe_file from llama_stack.models.llama.tokenizer_utils import load_bpe_file
logger = getLogger(__name__)
# The tiktoken tokenizer can handle <=400k chars without # The tiktoken tokenizer can handle <=400k chars without
# pyo3_runtime.PanicException. # pyo3_runtime.PanicException.
TIKTOKEN_MAX_ENCODE_CHARS = 400_000 TIKTOKEN_MAX_ENCODE_CHARS = 400_000
@ -31,6 +29,8 @@ MAX_NO_WHITESPACES_CHARS = 25_000
_INSTANCE = None _INSTANCE = None
logger = get_logger(name=__name__, category="models::llama")
class Tokenizer: class Tokenizer:
""" """

View file

@ -11,7 +11,7 @@ from llama_stack.log import get_logger
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
logger = get_logger(name=__name__, category="inference") logger = get_logger(name=__name__, category="models::llama")
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)' BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})") CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
import os import os
from collections.abc import Callable from collections.abc import Callable
@ -13,11 +12,13 @@ from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
from torch import Tensor, nn from torch import Tensor, nn
from torch.nn import functional as F from torch.nn import functional as F
from llama_stack.log import get_logger
from ...datatypes import QuantizationMode from ...datatypes import QuantizationMode
from ..model import Transformer, TransformerBlock from ..model import Transformer, TransformerBlock
from ..moe import MoE from ..moe import MoE
log = logging.getLogger(__name__) log = get_logger(name=__name__, category="models::llama")
def swiglu_wrapper_no_reduce( def swiglu_wrapper_no_reduce(

View file

@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import Collection, Iterator, Sequence, Set from collections.abc import Collection, Iterator, Sequence, Set
from logging import getLogger
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
Literal, Literal,
@ -14,11 +13,9 @@ from typing import (
import tiktoken import tiktoken
from llama_stack.log import get_logger
from llama_stack.models.llama.tokenizer_utils import load_bpe_file from llama_stack.models.llama.tokenizer_utils import load_bpe_file
logger = getLogger(__name__)
# The tiktoken tokenizer can handle <=400k chars without # The tiktoken tokenizer can handle <=400k chars without
# pyo3_runtime.PanicException. # pyo3_runtime.PanicException.
TIKTOKEN_MAX_ENCODE_CHARS = 400_000 TIKTOKEN_MAX_ENCODE_CHARS = 400_000
@ -101,6 +98,8 @@ BASIC_SPECIAL_TOKENS = [
"<|fim_suffix|>", "<|fim_suffix|>",
] ]
logger = get_logger(name=__name__, category="models::llama")
class Tokenizer: class Tokenizer:
""" """

View file

@ -6,9 +6,10 @@
# type: ignore # type: ignore
import collections import collections
import logging
log = logging.getLogger(__name__) from llama_stack.log import get_logger
log = get_logger(name=__name__, category="models::llama")
try: try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401 import fbgemm_gpu.experimental.gen_ai # noqa: F401

View file

@ -84,7 +84,7 @@ MEMORY_QUERY_TOOL = "knowledge_search"
WEB_SEARCH_TOOL = "web_search" WEB_SEARCH_TOOL = "web_search"
RAG_TOOL_GROUP = "builtin::rag" RAG_TOOL_GROUP = "builtin::rag"
logger = get_logger(name=__name__, category="agents") logger = get_logger(name=__name__, category="agents::meta_reference")
class ChatAgent(ShieldRunnerMixin): class ChatAgent(ShieldRunnerMixin):

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
import uuid import uuid
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from datetime import UTC, datetime from datetime import UTC, datetime
@ -42,6 +41,7 @@ from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.datatypes import AccessRule from llama_stack.core.datatypes import AccessRule
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records from llama_stack.providers.utils.pagination import paginate_records
from llama_stack.providers.utils.responses.responses_store import ResponsesStore from llama_stack.providers.utils.responses.responses_store import ResponsesStore
@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig
from .persistence import AgentInfo from .persistence import AgentInfo
from .responses.openai_responses import OpenAIResponsesImpl from .responses.openai_responses import OpenAIResponsesImpl
logger = logging.getLogger() logger = get_logger(name=__name__, category="agents::meta_reference")
class MetaReferenceAgentsImpl(Agents): class MetaReferenceAgentsImpl(Agents):

View file

@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import json import json
import logging
import uuid import uuid
from datetime import UTC, datetime from datetime import UTC, datetime
@ -15,9 +14,10 @@ from llama_stack.core.access_control.access_control import AccessDeniedError, is
from llama_stack.core.access_control.datatypes import AccessRule from llama_stack.core.access_control.datatypes import AccessRule
from llama_stack.core.datatypes import User from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__) log = get_logger(name=__name__, category="agents::meta_reference")
class AgentSessionInfo(Session): class AgentSessionInfo(Session):

View file

@ -41,7 +41,7 @@ from .utils import (
convert_response_text_to_chat_response_format, convert_response_text_to_chat_response_format,
) )
logger = get_logger(name=__name__, category="responses") logger = get_logger(name=__name__, category="openai::responses")
class OpenAIResponsePreviousResponseWithInputItems(BaseModel): class OpenAIResponsePreviousResponseWithInputItems(BaseModel):

View file

@ -47,7 +47,7 @@ from llama_stack.log import get_logger
from .types import ChatCompletionContext, ChatCompletionResult from .types import ChatCompletionContext, ChatCompletionResult
from .utils import convert_chat_choice_to_response_message, is_function_tool_call from .utils import convert_chat_choice_to_response_message, is_function_tool_call
logger = get_logger(name=__name__, category="responses") logger = get_logger(name=__name__, category="agents::meta_reference")
class StreamingResponseOrchestrator: class StreamingResponseOrchestrator:

View file

@ -38,7 +38,7 @@ from llama_stack.log import get_logger
from .types import ChatCompletionContext, ToolExecutionResult from .types import ChatCompletionContext, ToolExecutionResult
logger = get_logger(name=__name__, category="responses") logger = get_logger(name=__name__, category="agents::meta_reference")
class ToolExecutor: class ToolExecutor:

View file

@ -17,6 +17,8 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseOutputMessageContent, OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText, OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFunctionToolCall, OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseText, OpenAIResponseText,
) )
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
@ -99,14 +101,22 @@ async def convert_response_input_to_chat_messages(
""" """
messages: list[OpenAIMessageParam] = [] messages: list[OpenAIMessageParam] = []
if isinstance(input, list): if isinstance(input, list):
# extract all OpenAIResponseInputFunctionToolCallOutput items
# so their corresponding OpenAIToolMessageParam instances can
# be added immediately following the corresponding
# OpenAIAssistantMessageParam
tool_call_results = {}
for input_item in input: for input_item in input:
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput): if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
messages.append( tool_call_results[input_item.call_id] = OpenAIToolMessageParam(
OpenAIToolMessageParam( content=input_item.output,
content=input_item.output, tool_call_id=input_item.call_id,
tool_call_id=input_item.call_id,
)
) )
for input_item in input:
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
# skip as these have been extracted and inserted in order
pass
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall): elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
tool_call = OpenAIChatCompletionToolCall( tool_call = OpenAIChatCompletionToolCall(
index=0, index=0,
@ -117,6 +127,28 @@ async def convert_response_input_to_chat_messages(
), ),
) )
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call])) messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
if input_item.call_id in tool_call_results:
messages.append(tool_call_results[input_item.call_id])
del tool_call_results[input_item.call_id]
elif isinstance(input_item, OpenAIResponseOutputMessageMCPCall):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id=input_item.id,
function=OpenAIChatCompletionToolCallFunction(
name=input_item.name,
arguments=input_item.arguments,
),
)
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
messages.append(
OpenAIToolMessageParam(
content=input_item.output,
tool_call_id=input_item.id,
)
)
elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools):
# the tool list will be handled separately
pass
else: else:
content = await convert_response_content_to_chat_content(input_item.content) content = await convert_response_content_to_chat_content(input_item.content)
message_type = await get_message_type_by_role(input_item.role) message_type = await get_message_type_by_role(input_item.role)
@ -125,6 +157,10 @@ async def convert_response_input_to_chat_messages(
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context" f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
) )
messages.append(message_type(content=content)) messages.append(message_type(content=content))
if len(tool_call_results):
raise ValueError(
f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
)
else: else:
messages.append(OpenAIUserMessageParam(content=input)) messages.append(OpenAIUserMessageParam(content=input))
return messages return messages
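Editor's note: the rewrite above buffers `function_call_output` items by `call_id` so each tool message is emitted immediately after the assistant message carrying the matching tool call, and it raises if an output has no matching call. A stripped-down sketch of that two-pass ordering (dict-shaped items and field names are simplifying assumptions):

```python
# Simplified illustration of the buffering/ordering logic, not the actual types.
def order_messages(items: list[dict]) -> list[dict]:
    outputs = {i["call_id"]: i for i in items if i["type"] == "function_call_output"}
    messages = []
    for item in items:
        if item["type"] == "function_call_output":
            continue  # already buffered; inserted right after its matching call below
        messages.append(item)
        if item["type"] == "function_call" and item["call_id"] in outputs:
            messages.append(outputs.pop(item["call_id"]))
    if outputs:
        raise ValueError(f"function_call_output without a matching function_call: {list(outputs)}")
    return messages
```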

View file

@ -5,13 +5,13 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import logging
from llama_stack.apis.inference import Message from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry import tracing from llama_stack.providers.utils.telemetry import tracing
log = logging.getLogger(__name__) log = get_logger(name=__name__, category="agents::meta_reference")
class SafetyException(Exception): # noqa: N818 class SafetyException(Exception): # noqa: N818

View file

@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import hashlib
import itertools import itertools
import json import json
import time import time
@ -136,28 +137,45 @@ class ReferenceBatchesImpl(Batches):
endpoint: str, endpoint: str,
completion_window: Literal["24h"], completion_window: Literal["24h"],
metadata: dict[str, str] | None = None, metadata: dict[str, str] | None = None,
idempotency_key: str | None = None,
) -> BatchObject: ) -> BatchObject:
""" """
Create a new batch for processing multiple API requests. Create a new batch for processing multiple API requests.
Error handling by levels - This implementation provides optional idempotency: when an idempotency key
0. Input param handling, results in 40x errors before processing, e.g. (idempotency_key) is provided, a deterministic ID is generated based on the input
- Wrong completion_window parameters. If a batch with the same parameters already exists, it will be
- Invalid metadata types returned instead of creating a duplicate. Without an idempotency key,
- Unknown endpoint each request creates a new batch with a unique ID.
-> no batch created
1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g. Args:
- input_file_id missing input_file_id: The ID of an uploaded file containing requests for the batch.
- invalid json in file endpoint: The endpoint to be used for all requests in the batch.
- missing custom_id, method, url, body completion_window: The time window within which the batch should be processed.
- invalid model metadata: Optional metadata for the batch.
- streaming idempotency_key: Optional idempotency key for enabling idempotent behavior.
-> batch created, validation sends to failed status
2. Processing errors, result in error_file_id entries, e.g. Returns:
- Any error returned from inference endpoint The created or existing batch object.
-> batch created, goes to completed status
""" """
# Error handling by levels -
# 0. Input param handling, results in 40x errors before processing, e.g.
# - Wrong completion_window
# - Invalid metadata types
# - Unknown endpoint
# -> no batch created
# 1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
# - input_file_id missing
# - invalid json in file
# - missing custom_id, method, url, body
# - invalid model
# - streaming
# -> batch created, validation sends to failed status
# 2. Processing errors, result in error_file_id entries, e.g.
# - Any error returned from inference endpoint
# -> batch created, goes to completed status
# TODO: set expiration time for garbage collection # TODO: set expiration time for garbage collection
if endpoint not in ["/v1/chat/completions"]: if endpoint not in ["/v1/chat/completions"]:
@ -171,6 +189,35 @@ class ReferenceBatchesImpl(Batches):
) )
batch_id = f"batch_{uuid.uuid4().hex[:16]}" batch_id = f"batch_{uuid.uuid4().hex[:16]}"
# For idempotent requests, use the idempotency key for the batch ID
# This ensures the same key always maps to the same batch ID,
# allowing us to detect parameter conflicts
if idempotency_key is not None:
hash_input = idempotency_key.encode("utf-8")
hash_digest = hashlib.sha256(hash_input).hexdigest()[:24]
batch_id = f"batch_{hash_digest}"
try:
existing_batch = await self.retrieve_batch(batch_id)
if (
existing_batch.input_file_id != input_file_id
or existing_batch.endpoint != endpoint
or existing_batch.completion_window != completion_window
or existing_batch.metadata != metadata
):
raise ConflictError(
f"Idempotency key '{idempotency_key}' was previously used with different parameters. "
"Either use a new idempotency key or ensure all parameters match the original request."
)
logger.info(f"Returning existing batch with ID: {batch_id}")
return existing_batch
except ResourceNotFoundError:
# Batch doesn't exist, continue with creation
pass
current_time = int(time.time()) current_time = int(time.time())
batch = BatchObject( batch = BatchObject(
@ -185,6 +232,7 @@ class ReferenceBatchesImpl(Batches):
) )
await self.kvstore.set(f"batch:{batch_id}", batch.to_json()) await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
logger.info(f"Created new batch with ID: {batch_id}")
if self.process_batches: if self.process_batches:
task = asyncio.create_task(self._process_batch(batch_id)) task = asyncio.create_task(self._process_batch(batch_id))
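Editor's note: the idempotency support above derives the batch ID deterministically from the idempotency key, so a retried request maps to the same batch and parameter conflicts can be detected. A small sketch of that derivation, following the hashing shown in the hunk:

```python
# Same key -> same batch ID, which is how duplicate submissions are recognized.
import hashlib

def batch_id_for(idempotency_key: str) -> str:
    digest = hashlib.sha256(idempotency_key.encode("utf-8")).hexdigest()[:24]
    return f"batch_{digest}"

assert batch_id_for("nightly-eval-run") == batch_id_for("nightly-eval-run")
```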

View file

@ -11,6 +11,7 @@ from typing import Annotated
from fastapi import File, Form, Response, UploadFile from fastapi import File, Form, Response, UploadFile
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import ( from llama_stack.apis.files import (
Files, Files,
@ -20,12 +21,15 @@ from llama_stack.apis.files import (
OpenAIFilePurpose, OpenAIFilePurpose,
) )
from llama_stack.core.datatypes import AccessRule from llama_stack.core.datatypes import AccessRule
from llama_stack.log import get_logger
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
from .config import LocalfsFilesImplConfig from .config import LocalfsFilesImplConfig
logger = get_logger(name=__name__, category="files")
class LocalfsFilesImpl(Files): class LocalfsFilesImpl(Files):
def __init__(self, config: LocalfsFilesImplConfig, policy: list[AccessRule]) -> None: def __init__(self, config: LocalfsFilesImplConfig, policy: list[AccessRule]) -> None:
@ -65,6 +69,18 @@ class LocalfsFilesImpl(Files):
"""Get the filesystem path for a file ID.""" """Get the filesystem path for a file ID."""
return Path(self.config.storage_dir) / file_id return Path(self.config.storage_dir) / file_id
async def _lookup_file_id(self, file_id: str) -> tuple[OpenAIFileObject, Path]:
"""Look up a OpenAIFileObject and filesystem path from its ID."""
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
file_path = Path(row.pop("file_path"))
return OpenAIFileObject(**row), file_path
# OpenAI Files API Implementation # OpenAI Files API Implementation
async def openai_upload_file( async def openai_upload_file(
self, self,
@ -157,37 +173,19 @@ class LocalfsFilesImpl(Files):
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject: async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
"""Returns information about a specific file.""" """Returns information about a specific file."""
if not self.sql_store: file_obj, _ = await self._lookup_file_id(file_id)
raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id}) return file_obj
if not row:
raise ValueError(f"File with id {file_id} not found")
return OpenAIFileObject(
id=row["id"],
filename=row["filename"],
purpose=OpenAIFilePurpose(row["purpose"]),
bytes=row["bytes"],
created_at=row["created_at"],
expires_at=row["expires_at"],
)
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse: async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
"""Delete a file.""" """Delete a file."""
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
if not row:
raise ValueError(f"File with id {file_id} not found")
# Delete physical file # Delete physical file
file_path = Path(row["file_path"]) _, file_path = await self._lookup_file_id(file_id)
if file_path.exists(): if file_path.exists():
file_path.unlink() file_path.unlink()
# Delete metadata from database # Delete metadata from database
assert self.sql_store is not None, "Files provider not initialized"
await self.sql_store.delete("openai_files", where={"id": file_id}) await self.sql_store.delete("openai_files", where={"id": file_id})
return OpenAIFileDeleteResponse( return OpenAIFileDeleteResponse(
@ -197,25 +195,17 @@ class LocalfsFilesImpl(Files):
async def openai_retrieve_file_content(self, file_id: str) -> Response: async def openai_retrieve_file_content(self, file_id: str) -> Response:
"""Returns the contents of the specified file.""" """Returns the contents of the specified file."""
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
# Get file metadata
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
if not row:
raise ValueError(f"File with id {file_id} not found")
# Read file content # Read file content
file_path = Path(row["file_path"]) file_obj, file_path = await self._lookup_file_id(file_id)
if not file_path.exists():
raise ValueError(f"File content not found on disk: {file_path}")
with open(file_path, "rb") as f: if not file_path.exists():
content = f.read() logger.warning(f"File '{file_id}'s underlying '{file_path}' is missing, deleting metadata.")
await self.openai_delete_file(file_id)
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
# Return as binary response with appropriate content type # Return as binary response with appropriate content type
return Response( return Response(
content=content, content=file_path.read_bytes(),
media_type="application/octet-stream", media_type="application/octet-stream",
headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'}, headers={"Content-Disposition": f'attachment; filename="{file_obj.filename}"'},
) )
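Editor's note: with `_lookup_file_id`, missing files now surface as `ResourceNotFoundError` instead of bare `ValueError`s, and stale metadata is purged when the bytes are gone from disk. A hedged usage sketch, assuming direct access to the provider implementation (e.g. in a test); the helper name is illustrative:

```python
# Hypothetical helper built on the provider API shown above.
from llama_stack.apis.common.errors import ResourceNotFoundError

async def read_file_or_none(files_impl, file_id: str) -> bytes | None:
    try:
        response = await files_impl.openai_retrieve_file_content(file_id)
        return response.body  # binary Response content
    except ResourceNotFoundError:
        return None
```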

View file

@ -12,7 +12,6 @@
import copy import copy
import json import json
import logging
import multiprocessing import multiprocessing
import os import os
import tempfile import tempfile
@ -32,13 +31,14 @@ from fairscale.nn.model_parallel.initialize import (
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from torch.distributed.launcher.api import LaunchConfig, elastic_launch from torch.distributed.launcher.api import LaunchConfig, elastic_launch
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import GenerationResult from llama_stack.models.llama.datatypes import GenerationResult
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent, ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent, CompletionRequestWithRawContent,
) )
log = logging.getLogger(__name__) log = get_logger(name=__name__, category="inference")
class ProcessingMessageName(str, Enum): class ProcessingMessageName(str, Enum):

View file

@ -4,13 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
CompletionResponse, CompletionResponse,
InferenceProvider, InferenceProvider,
InterleavedContent,
LogProbConfig, LogProbConfig,
Message, Message,
ResponseFormat, ResponseFormat,
@ -21,6 +19,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import ( from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin, SentenceTransformerEmbeddingMixin,
@ -32,7 +31,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
from .config import SentenceTransformersInferenceConfig from .config import SentenceTransformersInferenceConfig
log = logging.getLogger(__name__) log = get_logger(name=__name__, category="inference")
class SentenceTransformersInferenceImpl( class SentenceTransformersInferenceImpl(
@ -100,25 +99,3 @@ class SentenceTransformersInferenceImpl(
tool_config: ToolConfig | None = None, tool_config: ToolConfig | None = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
raise ValueError("Sentence transformers don't support chat completion") raise ValueError("Sentence transformers don't support chat completion")
async def batch_completion(
self,
model_id: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch completion is not supported for Sentence Transformers")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")

View file

@ -6,7 +6,6 @@
import gc import gc
import json import json
import logging
import multiprocessing import multiprocessing
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@ -28,6 +27,7 @@ from llama_stack.apis.post_training import (
LoraFinetuningConfig, LoraFinetuningConfig,
TrainingConfig, TrainingConfig,
) )
from llama_stack.log import get_logger
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from ..config import HuggingFacePostTrainingConfig from ..config import HuggingFacePostTrainingConfig
@ -44,7 +44,7 @@ from ..utils import (
split_dataset, split_dataset,
) )
logger = logging.getLogger(__name__) logger = get_logger(name=__name__, category="post_training")
class HFFinetuningSingleDevice: class HFFinetuningSingleDevice:

View file

@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import gc import gc
import logging
import multiprocessing import multiprocessing
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@ -24,6 +23,7 @@ from llama_stack.apis.post_training import (
DPOAlignmentConfig, DPOAlignmentConfig,
TrainingConfig, TrainingConfig,
) )
from llama_stack.log import get_logger
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from ..config import HuggingFacePostTrainingConfig from ..config import HuggingFacePostTrainingConfig
@ -40,7 +40,7 @@ from ..utils import (
split_dataset, split_dataset,
) )
logger = logging.getLogger(__name__) logger = get_logger(name=__name__, category="post_training")
class HFDPOAlignmentSingleDevice: class HFDPOAlignmentSingleDevice:

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
import os import os
import signal import signal
import sys import sys
@ -19,10 +18,11 @@ from transformers import AutoConfig, AutoModelForCausalLM
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.post_training import Checkpoint, TrainingConfig from llama_stack.apis.post_training import Checkpoint, TrainingConfig
from llama_stack.log import get_logger
from .config import HuggingFacePostTrainingConfig from .config import HuggingFacePostTrainingConfig
logger = logging.getLogger(__name__) logger = get_logger(name=__name__, category="post_training")
def setup_environment(): def setup_environment():

View file

@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import logging
 import os
 import time
 from datetime import UTC, datetime
@@ -19,6 +18,7 @@ from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import modules, training
 from torchtune import utils as torchtune_utils
 from torchtune.data import padded_collate_sft
+from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 from torchtune.modules.loss import CEWithChunkedOutputLoss
 from torchtune.modules.peft import (
     get_adapter_params,
@@ -45,6 +45,7 @@ from llama_stack.apis.post_training import (
 )
 from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
 from llama_stack.core.utils.model_utils import model_local_dir
+from llama_stack.log import get_logger
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
 from llama_stack.providers.inline.post_training.torchtune.common import utils
@@ -56,9 +57,7 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
 )
 from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset

-log = logging.getLogger(__name__)
-
-from torchtune.models.llama3._tokenizer import Llama3Tokenizer
+log = get_logger(name=__name__, category="post_training")


 class LoraFinetuningSingleDevice:

View file

@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import logging
 import uuid
 from typing import TYPE_CHECKING, Any
@@ -20,13 +19,14 @@ from llama_stack.apis.safety import (
 )
 from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
 from llama_stack.apis.shields import Shield
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )

 from .config import CodeScannerConfig

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="safety")

 ALLOWED_CODE_SCANNER_MODEL_IDS = [
     "code-scanner",

View file

@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import logging
 import re
 import uuid
 from string import Template
@@ -21,6 +20,7 @@ from llama_stack.apis.safety import (
 from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
 from llama_stack.apis.shields import Shield
 from llama_stack.core.datatypes import Api
+from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import Role
 from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
@@ -132,6 +132,8 @@ Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation:
 PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}")

+logger = get_logger(name=__name__, category="safety")
+

 class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
     def __init__(self, config: LlamaGuardConfig, deps) -> None:
@@ -407,7 +409,7 @@ class LlamaGuardShield:
         unsafe_code_list = [code.strip() for code in unsafe_code.split(",")]
         invalid_codes = [code for code in unsafe_code_list if code not in SAFETY_CODE_TO_CATEGORIES_MAP]
         if invalid_codes:
-            logging.warning(f"Invalid safety codes returned: {invalid_codes}")
+            logger.warning(f"Invalid safety codes returned: {invalid_codes}")
             # just returning safe object, as we don't know what the invalid codes can map to
             return ModerationObject(
                 id=f"modr-{uuid.uuid4()}",

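In the Llama Guard hunk above, the call inside the invalid-codes branch changes from module-level logging.warning(...) to logger.warning(...) on the new categorized logger. The practical difference is that the bare call always goes through the root logger, so it ignores any level set on the module's own logger. A stdlib-only sketch of that behavior (the logger name below is hypothetical, chosen for illustration):

import logging

logging.basicConfig(level=logging.ERROR)             # root logger only emits ERROR and above
demo_logger = logging.getLogger("llama_guard_demo")  # hypothetical name
demo_logger.setLevel(logging.WARNING)                # this logger emits WARNING and above

logging.warning("dropped: filtered by the root logger's ERROR level")
demo_logger.warning("shown: passes this logger's level and propagates to the root handler")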
View file

@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import logging
 from typing import Any

 import torch
@@ -21,6 +20,7 @@ from llama_stack.apis.safety import (
 from llama_stack.apis.safety.safety import ModerationObject
 from llama_stack.apis.shields import Shield
 from llama_stack.core.utils.model_utils import model_local_dir
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
@@ -28,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .config import PromptGuardConfig, PromptGuardType

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="safety")

 PROMPT_GUARD_MODEL = "Prompt-Guard-86M"

View file

@@ -7,7 +7,6 @@
 import collections
 import functools
 import json
-import logging
 import random
 import re
 import string
@@ -20,7 +19,9 @@ import nltk
 from pythainlp.tokenize import sent_tokenize as sent_tokenize_thai
 from pythainlp.tokenize import word_tokenize as word_tokenize_thai

-logger = logging.getLogger()
+from llama_stack.log import get_logger
+
+logger = get_logger(name=__name__, category="scoring")

 WORD_LIST = [
     "western",

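The scoring hunk above also changes which logger object is used: logging.getLogger() with no argument returns the root logger itself, whereas the replacement is a logger named after the module (assuming get_logger forwards the name to the stdlib logging machinery). A small sketch of the distinction:

import logging

root = logging.getLogger()                 # old pattern: the root logger
named = logging.getLogger("scoring_demo")  # illustrative name; the commit uses __name__

print(root.name)    # -> "root"
print(named.name)   # -> "scoring_demo"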
Some files were not shown because too many files have changed in this diff.