Merge branch 'main' into chroma

This commit is contained in:
Bwook (Byoungwook) Kim 2025-09-11 20:46:53 +09:00 committed by GitHub
commit 11c71c958e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
308 changed files with 26415 additions and 11807 deletions

View file

@ -2,26 +2,28 @@ name: 'Run and Record Tests'
description: 'Run integration tests and handle recording/artifact upload' description: 'Run integration tests and handle recording/artifact upload'
inputs: inputs:
test-subdirs:
description: 'Comma-separated list of test subdirectories to run'
required: true
test-pattern:
description: 'Regex pattern to pass to pytest -k'
required: false
default: ''
stack-config: stack-config:
description: 'Stack configuration to use' description: 'Stack configuration to use'
required: true required: true
provider: setup:
description: 'Provider to use for tests' description: 'Setup to use for tests (e.g., ollama, gpt, vllm)'
required: true required: false
default: ''
inference-mode: inference-mode:
description: 'Inference mode (record or replay)' description: 'Inference mode (record or replay)'
required: true required: true
run-vision-tests: suite:
description: 'Whether to run vision tests' description: 'Test suite to use: base, responses, vision, etc.'
required: false required: false
default: 'false' default: ''
subdirs:
description: 'Comma-separated list of test subdirectories to run; overrides suite'
required: false
default: ''
pattern:
description: 'Regex pattern to pass to pytest -k'
required: false
default: ''
runs: runs:
using: 'composite' using: 'composite'
@ -36,14 +38,23 @@ runs:
- name: Run Integration Tests - name: Run Integration Tests
shell: bash shell: bash
run: | run: |
uv run --no-sync ./scripts/integration-tests.sh \ SCRIPT_ARGS="--stack-config ${{ inputs.stack-config }} --inference-mode ${{ inputs.inference-mode }}"
--stack-config '${{ inputs.stack-config }}' \
--provider '${{ inputs.provider }}' \ # Add optional arguments only if they are provided
--test-subdirs '${{ inputs.test-subdirs }}' \ if [ -n '${{ inputs.setup }}' ]; then
--test-pattern '${{ inputs.test-pattern }}' \ SCRIPT_ARGS="$SCRIPT_ARGS --setup ${{ inputs.setup }}"
--inference-mode '${{ inputs.inference-mode }}' \ fi
${{ inputs.run-vision-tests == 'true' && '--run-vision-tests' || '' }} \ if [ -n '${{ inputs.suite }}' ]; then
| tee pytest-${{ inputs.inference-mode }}.log SCRIPT_ARGS="$SCRIPT_ARGS --suite ${{ inputs.suite }}"
fi
if [ -n '${{ inputs.subdirs }}' ]; then
SCRIPT_ARGS="$SCRIPT_ARGS --subdirs ${{ inputs.subdirs }}"
fi
if [ -n '${{ inputs.pattern }}' ]; then
SCRIPT_ARGS="$SCRIPT_ARGS --pattern ${{ inputs.pattern }}"
fi
uv run --no-sync ./scripts/integration-tests.sh $SCRIPT_ARGS | tee pytest-${{ inputs.inference-mode }}.log
- name: Commit and push recordings - name: Commit and push recordings
@ -57,12 +68,7 @@ runs:
echo "New recordings detected, committing and pushing" echo "New recordings detected, committing and pushing"
git add tests/integration/recordings/ git add tests/integration/recordings/
if [ "${{ inputs.run-vision-tests }}" == "true" ]; then git commit -m "Recordings update from CI (suite: ${{ inputs.suite }})"
git commit -m "Recordings update from CI (vision)"
else
git commit -m "Recordings update from CI"
fi
git fetch origin ${{ github.ref_name }} git fetch origin ${{ github.ref_name }}
git rebase origin/${{ github.ref_name }} git rebase origin/${{ github.ref_name }}
echo "Rebased successfully" echo "Rebased successfully"

View file

@ -1,17 +1,17 @@
name: Setup Ollama name: Setup Ollama
description: Start Ollama description: Start Ollama
inputs: inputs:
run-vision-tests: suite:
description: 'Run vision tests: "true" or "false"' description: 'Test suite to use: base, responses, vision, etc.'
required: false required: false
default: 'false' default: ''
runs: runs:
using: "composite" using: "composite"
steps: steps:
- name: Start Ollama - name: Start Ollama
shell: bash shell: bash
run: | run: |
if [ "${{ inputs.run-vision-tests }}" == "true" ]; then if [ "${{ inputs.suite }}" == "vision" ]; then
image="ollama-with-vision-model" image="ollama-with-vision-model"
else else
image="ollama-with-models" image="ollama-with-models"

View file

@ -8,14 +8,14 @@ inputs:
client-version: client-version:
description: 'Client version (latest or published)' description: 'Client version (latest or published)'
required: true required: true
provider: setup:
description: 'Provider to setup (ollama or vllm)' description: 'Setup to configure (ollama, vllm, gpt, etc.)'
required: true
default: 'ollama'
run-vision-tests:
description: 'Whether to setup provider for vision tests'
required: false required: false
default: 'false' default: 'ollama'
suite:
description: 'Test suite to use: base, responses, vision, etc.'
required: false
default: ''
inference-mode: inference-mode:
description: 'Inference mode (record or replay)' description: 'Inference mode (record or replay)'
required: true required: true
@ -30,13 +30,13 @@ runs:
client-version: ${{ inputs.client-version }} client-version: ${{ inputs.client-version }}
- name: Setup ollama - name: Setup ollama
if: ${{ inputs.provider == 'ollama' && inputs.inference-mode == 'record' }} if: ${{ (inputs.setup == 'ollama' || inputs.setup == 'ollama-vision') && inputs.inference-mode == 'record' }}
uses: ./.github/actions/setup-ollama uses: ./.github/actions/setup-ollama
with: with:
run-vision-tests: ${{ inputs.run-vision-tests }} suite: ${{ inputs.suite }}
- name: Setup vllm - name: Setup vllm
if: ${{ inputs.provider == 'vllm' && inputs.inference-mode == 'record' }} if: ${{ inputs.setup == 'vllm' && inputs.inference-mode == 'record' }}
uses: ./.github/actions/setup-vllm uses: ./.github/actions/setup-vllm
- name: Build Llama Stack - name: Build Llama Stack

View file

@ -5,10 +5,11 @@ Llama Stack uses GitHub Actions for Continuous Integration (CI). Below is a tabl
| Name | File | Purpose | | Name | File | Purpose |
| ---- | ---- | ------- | | ---- | ---- | ------- |
| Update Changelog | [changelog.yml](changelog.yml) | Creates PR for updating the CHANGELOG.md | | Update Changelog | [changelog.yml](changelog.yml) | Creates PR for updating the CHANGELOG.md |
| API Conformance Tests | [conformance.yml](conformance.yml) | Run the API Conformance test suite on the changes. |
| Installer CI | [install-script-ci.yml](install-script-ci.yml) | Test the installation script | | Installer CI | [install-script-ci.yml](install-script-ci.yml) | Test the installation script |
| Integration Auth Tests | [integration-auth-tests.yml](integration-auth-tests.yml) | Run the integration test suite with Kubernetes authentication | | Integration Auth Tests | [integration-auth-tests.yml](integration-auth-tests.yml) | Run the integration test suite with Kubernetes authentication |
| SqlStore Integration Tests | [integration-sql-store-tests.yml](integration-sql-store-tests.yml) | Run the integration test suite with SqlStore | | SqlStore Integration Tests | [integration-sql-store-tests.yml](integration-sql-store-tests.yml) | Run the integration test suite with SqlStore |
| Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suite from tests/integration in replay mode | | Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suites from tests/integration in replay mode |
| Vector IO Integration Tests | [integration-vector-io-tests.yml](integration-vector-io-tests.yml) | Run the integration test suite with various VectorIO providers | | Vector IO Integration Tests | [integration-vector-io-tests.yml](integration-vector-io-tests.yml) | Run the integration test suite with various VectorIO providers |
| Pre-commit | [pre-commit.yml](pre-commit.yml) | Run pre-commit checks | | Pre-commit | [pre-commit.yml](pre-commit.yml) | Run pre-commit checks |
| Test Llama Stack Build | [providers-build.yml](providers-build.yml) | Test llama stack build | | Test Llama Stack Build | [providers-build.yml](providers-build.yml) | Test llama stack build |

57
.github/workflows/conformance.yml vendored Normal file
View file

@ -0,0 +1,57 @@
# API Conformance Tests
# This workflow ensures that API changes maintain backward compatibility and don't break existing integrations
# It runs schema validation and OpenAPI diff checks to catch breaking changes early
name: API Conformance Tests
run-name: Run the API Conformance test suite on the changes.
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
types: [opened, synchronize, reopened]
paths:
- 'llama_stack/**'
- '!llama_stack/ui/**'
- 'tests/**'
- 'uv.lock'
- 'pyproject.toml'
- '.github/workflows/conformance.yml' # This workflow itself
concurrency:
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
# Cancel in-progress runs when new commits are pushed to avoid wasting CI resources
cancel-in-progress: true
jobs:
# Job to check if API schema changes maintain backward compatibility
check-schema-compatibility:
runs-on: ubuntu-latest
steps:
# Using specific version 4.1.7 because 5.0.0 fails when trying to run this locally using `act`
# This ensures consistent behavior between local testing and CI
- name: Checkout PR Code
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
# Checkout the base branch to compare against (usually main)
# This allows us to diff the current changes against the previous state
- name: Checkout Base Branch
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
ref: ${{ github.event.pull_request.base.ref }}
path: 'base'
# Install oasdiff: https://github.com/oasdiff/oasdiff, a tool for detecting breaking changes in OpenAPI specs.
- name: Install oasdiff
run: |
curl -fsSL https://raw.githubusercontent.com/oasdiff/oasdiff/main/install.sh | sh
# Run oasdiff to detect breaking changes in the API specification
# This step will fail if incompatible changes are detected, preventing breaking changes from being merged
- name: Run OpenAPI Breaking Change Diff
run: |
oasdiff breaking --fail-on ERR base/docs/_static/llama-stack-spec.yaml docs/_static/llama-stack-spec.yaml --match-path '^/v1/openai/v1' \
--match-path '^/v1/vector-io' \
--match-path '^/v1/vector-dbs'

View file

@ -1,6 +1,6 @@
name: Integration Tests (Replay) name: Integration Tests (Replay)
run-name: Run the integration test suite from tests/integration in replay mode run-name: Run the integration test suites from tests/integration in replay mode
on: on:
push: push:
@ -28,18 +28,10 @@ on:
description: 'Test against both the latest and published versions' description: 'Test against both the latest and published versions'
type: boolean type: boolean
default: false default: false
test-provider: test-setup:
description: 'Test against a specific provider' description: 'Test against a specific setup'
type: string type: string
default: 'ollama' default: 'ollama'
test-subdirs:
description: 'Comma-separated list of test subdirectories to run'
type: string
default: ''
test-pattern:
description: 'Regex pattern to pass to pytest -k'
type: string
default: ''
concurrency: concurrency:
# Skip concurrency for pushes to main - each commit should be tested independently # Skip concurrency for pushes to main - each commit should be tested independently
@ -50,18 +42,18 @@ jobs:
run-replay-mode-tests: run-replay-mode-tests:
runs-on: ubuntu-latest runs-on: ubuntu-latest
name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, vision={4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.run-vision-tests) }} name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, {4})', matrix.client-type, matrix.setup, matrix.python-version, matrix.client-version, matrix.suite) }}
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
client-type: [library, server] client-type: [library, server]
# Use vllm on weekly schedule, otherwise use test-provider input (defaults to ollama) # Use vllm on weekly schedule, otherwise use test-setup input (defaults to ollama)
provider: ${{ (github.event.schedule == '1 0 * * 0') && fromJSON('["vllm"]') || fromJSON(format('["{0}"]', github.event.inputs.test-provider || 'ollama')) }} setup: ${{ (github.event.schedule == '1 0 * * 0') && fromJSON('["vllm"]') || fromJSON(format('["{0}"]', github.event.inputs.test-setup || 'ollama')) }}
# Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12 # Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12
python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }} python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }} client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}
run-vision-tests: [true, false] suite: [base, vision]
steps: steps:
- name: Checkout repository - name: Checkout repository
@ -72,16 +64,14 @@ jobs:
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
client-version: ${{ matrix.client-version }} client-version: ${{ matrix.client-version }}
provider: ${{ matrix.provider }} setup: ${{ matrix.setup }}
run-vision-tests: ${{ matrix.run-vision-tests }} suite: ${{ matrix.suite }}
inference-mode: 'replay' inference-mode: 'replay'
- name: Run tests - name: Run tests
uses: ./.github/actions/run-and-record-tests uses: ./.github/actions/run-and-record-tests
with: with:
test-subdirs: ${{ inputs.test-subdirs }}
test-pattern: ${{ inputs.test-pattern }}
stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }} stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }}
provider: ${{ matrix.provider }} setup: ${{ matrix.setup }}
inference-mode: 'replay' inference-mode: 'replay'
run-vision-tests: ${{ matrix.run-vision-tests }} suite: ${{ matrix.suite }}

View file

@ -28,7 +28,7 @@ jobs:
fetch-depth: ${{ github.actor == 'dependabot[bot]' && 0 || 1 }} fetch-depth: ${{ github.actor == 'dependabot[bot]' && 0 || 1 }}
- name: Set up Python - name: Set up Python
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
with: with:
python-version: '3.12' python-version: '3.12'
cache: pip cache: pip
@ -37,7 +37,7 @@ jobs:
.pre-commit-config.yaml .pre-commit-config.yaml
- name: Set up Node.js - name: Set up Node.js
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0 uses: actions/setup-node@a0853c24544627f65ddf259abe73b1d18a591444 # v5.0.0
with: with:
node-version: '20' node-version: '20'
cache: 'npm' cache: 'npm'
@ -48,7 +48,6 @@ jobs:
working-directory: llama_stack/ui working-directory: llama_stack/ui
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
continue-on-error: true
env: env:
SKIP: no-commit-to-branch SKIP: no-commit-to-branch
RUFF_OUTPUT_FORMAT: github RUFF_OUTPUT_FORMAT: github

View file

@ -24,7 +24,7 @@ jobs:
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install uv - name: Install uv
uses: astral-sh/setup-uv@4959332f0f014c5280e7eac8b70c90cb574c9f9b # v6.6.0 uses: astral-sh/setup-uv@557e51de59eb14aaaba2ed9621916900a91d50c6 # v6.6.1
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
activate-environment: true activate-environment: true

View file

@ -10,19 +10,19 @@ run-name: Run the integration test suite from tests/integration
on: on:
workflow_dispatch: workflow_dispatch:
inputs: inputs:
test-subdirs: test-setup:
description: 'Comma-separated list of test subdirectories to run' description: 'Test against a specific setup'
type: string
default: ''
test-provider:
description: 'Test against a specific provider'
type: string type: string
default: 'ollama' default: 'ollama'
run-vision-tests: suite:
description: 'Whether to run vision tests' description: 'Test suite to use: base, responses, vision, etc.'
type: boolean type: string
default: false default: ''
test-pattern: subdirs:
description: 'Comma-separated list of test subdirectories to run; overrides suite'
type: string
default: ''
pattern:
description: 'Regex pattern to pass to pytest -k' description: 'Regex pattern to pass to pytest -k'
type: string type: string
default: '' default: ''
@ -38,11 +38,11 @@ jobs:
- name: Echo workflow inputs - name: Echo workflow inputs
run: | run: |
echo "::group::Workflow Inputs" echo "::group::Workflow Inputs"
echo "test-subdirs: ${{ inputs.test-subdirs }}"
echo "test-provider: ${{ inputs.test-provider }}"
echo "run-vision-tests: ${{ inputs.run-vision-tests }}"
echo "test-pattern: ${{ inputs.test-pattern }}"
echo "branch: ${{ github.ref_name }}" echo "branch: ${{ github.ref_name }}"
echo "test-setup: ${{ inputs.test-setup }}"
echo "suite: ${{ inputs.suite }}"
echo "subdirs: ${{ inputs.subdirs }}"
echo "pattern: ${{ inputs.pattern }}"
echo "::endgroup::" echo "::endgroup::"
- name: Checkout repository - name: Checkout repository
@ -55,16 +55,16 @@ jobs:
with: with:
python-version: "3.12" # Use single Python version for recording python-version: "3.12" # Use single Python version for recording
client-version: "latest" client-version: "latest"
provider: ${{ inputs.test-provider || 'ollama' }} setup: ${{ inputs.test-setup || 'ollama' }}
run-vision-tests: ${{ inputs.run-vision-tests }} suite: ${{ inputs.suite }}
inference-mode: 'record' inference-mode: 'record'
- name: Run and record tests - name: Run and record tests
uses: ./.github/actions/run-and-record-tests uses: ./.github/actions/run-and-record-tests
with: with:
test-pattern: ${{ inputs.test-pattern }}
test-subdirs: ${{ inputs.test-subdirs }}
stack-config: 'server:ci-tests' # recording must be done with server since more tests are run stack-config: 'server:ci-tests' # recording must be done with server since more tests are run
provider: ${{ inputs.test-provider || 'ollama' }} setup: ${{ inputs.test-setup || 'ollama' }}
inference-mode: 'record' inference-mode: 'record'
run-vision-tests: ${{ inputs.run-vision-tests }} suite: ${{ inputs.suite }}
subdirs: ${{ inputs.subdirs }}
pattern: ${{ inputs.pattern }}

View file

@ -24,7 +24,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Stale Action - name: Stale Action
uses: actions/stale@5bef64f19d7facfb25b37b414482c7164d639639 # v9.1.0 uses: actions/stale@3a9db7e6a41a89f618792c92c0e97cc736e1b13f # v10.0.0
with: with:
stale-issue-label: 'stale' stale-issue-label: 'stale'
stale-issue-message: > stale-issue-message: >

View file

@ -29,7 +29,7 @@ jobs:
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Setup Node.js - name: Setup Node.js
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0 uses: actions/setup-node@a0853c24544627f65ddf259abe73b1d18a591444 # v5.0.0
with: with:
node-version: ${{ matrix.node-version }} node-version: ${{ matrix.node-version }}
cache: 'npm' cache: 'npm'

2
.gitignore vendored
View file

@ -26,5 +26,7 @@ venv/
pytest-report.xml pytest-report.xml
.coverage .coverage
.python-version .python-version
AGENTS.md
server.log
CLAUDE.md CLAUDE.md
.claude/ .claude/

View file

@ -86,7 +86,7 @@ repos:
language: python language: python
pass_filenames: false pass_filenames: false
require_serial: true require_serial: true
files: ^llama_stack/templates/.*$|^llama_stack/providers/.*/inference/.*/models\.py$ files: ^llama_stack/distributions/.*$|^llama_stack/providers/.*/inference/.*/models\.py$
- id: provider-codegen - id: provider-codegen
name: Provider Codegen name: Provider Codegen
additional_dependencies: additional_dependencies:

View file

@ -1,5 +1,103 @@
# Changelog # Changelog
# v0.2.20
Published on: 2025-08-29T22:25:32Z
Here are some key changes that are coming as part of this release.
### Build and Environment
- Environment improvements: fixed env var replacement to preserve types.
- Docker stability: fixed container startup failures for Fireworks AI provider.
- Removed absolute paths in build for better portability.
### Features
- UI Enhancements: Implemented file upload and VectorDB creation/configuration directly in UI.
- Vector Store Improvements: Added keyword, vector, and hybrid search inside vector store.
- Added S3 authorization support for file providers.
- SQL Store: Added inequality support to where clause.
### Documentation
- Fixed post-training docs.
- Added Contributor Guidelines for creating Internal vs. External providers.
### Fixes
- Removed unsupported bfcl scoring function.
- Multiple reliability and configuration fixes for providers and environment handling.
### Engineering / Chores
- Cleaner internal development setup with consistent paths.
- Incremental improvements to provider integration and vector store behavior.
### New Contributors
- @omertuc made their first contribution in #3270
- @r3v5 made their first contribution in vector store hybrid search
---
# v0.2.19
Published on: 2025-08-26T22:06:55Z
## Highlights
* feat: Add CORS configuration support for server by @skamenan7 in https://github.com/llamastack/llama-stack/pull/3201
* feat(api): introduce /rerank by @ehhuang in https://github.com/llamastack/llama-stack/pull/2940
* feat: Add S3 Files Provider by @mattf in https://github.com/llamastack/llama-stack/pull/3202
---
# v0.2.18
Published on: 2025-08-20T01:09:27Z
## Highlights
* Add moderations create API
* Hybrid search in Milvus
* Numerous Responses API improvements
* Documentation updates
---
# v0.2.17
Published on: 2025-08-05T01:51:14Z
## Highlights
* feat(tests): introduce inference record/replay to increase test reliability by @ashwinb in https://github.com/meta-llama/llama-stack/pull/2941
* fix(library_client): improve initialization error handling and prevent AttributeError by @mattf in https://github.com/meta-llama/llama-stack/pull/2944
* fix: use OLLAMA_URL to activate Ollama provider in starter by @ashwinb in https://github.com/meta-llama/llama-stack/pull/2963
* feat(UI): adding MVP playground UI by @franciscojavierarceo in https://github.com/meta-llama/llama-stack/pull/2828
* Standardization of errors (@nathan-weinberg)
* feat: Enable DPO training with HuggingFace inline provider by @Nehanth in https://github.com/meta-llama/llama-stack/pull/2825
* chore: rename templates to distributions by @ashwinb in https://github.com/meta-llama/llama-stack/pull/3035
---
# v0.2.16
Published on: 2025-07-28T23:35:23Z
## Highlights
* Automatic model registration for self-hosted providers (ollama and vllm currently). No need for `INFERENCE_MODEL` environment variables which need to be updated, etc.
* Much simplified starter distribution. Most `ENABLE_` env variables are now gone. When you set `VLLM_URL`, the `vllm` provider is auto-enabled. Similar for `MILVUS_URL`, `PGVECTOR_DB`, etc. Check the [run.yaml](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/templates/starter/run.yaml) for more details.
* All tests migrated to pytest now (thanks @Elbehery)
* DPO implementation in the post-training provider (thanks @Nehanth)
* (Huge!) Support for external APIs and providers thereof (thanks @leseb, @cdoern and others). This is a really big deal -- you can now add more APIs completely out of tree and experiment with them before (optionally) wanting to contribute back.
* `inline::vllm` provider is gone thank you very much
* several improvements to OpenAI inference implementations and LiteLLM backend (thanks @mattf)
* Chroma now supports Vector Store API (thanks @franciscojavierarceo).
* Authorization improvements: Vector Store/File APIs now supports access control (thanks @franciscojavierarceo); Telemetry read APIs are gated according to logged-in user's roles.
---
# v0.2.15 # v0.2.15
Published on: 2025-07-16T03:30:01Z Published on: 2025-07-16T03:30:01Z

View file

@ -34,13 +34,12 @@ This data enables data-driven architectural decisions and performance optimizati
**1. Deploy base k8s infrastructure:** **1. Deploy base k8s infrastructure:**
```bash ```bash
cd ../k8s cd ../../docs/source/distributions/k8s
./apply.sh ./apply.sh
``` ```
**2. Deploy benchmark components:** **2. Deploy benchmark components:**
```bash ```bash
cd ../k8s-benchmark
./apply.sh ./apply.sh
``` ```
@ -56,7 +55,6 @@ kubectl get pods
**Benchmark Llama Stack (default):** **Benchmark Llama Stack (default):**
```bash ```bash
cd docs/source/distributions/k8s-benchmark/
./run-benchmark.sh ./run-benchmark.sh
``` ```

View file

@ -14,7 +14,7 @@ import os
import random import random
import statistics import statistics
import time import time
from typing import Tuple
import aiohttp import aiohttp
@ -55,10 +55,50 @@ class BenchmarkStats:
total_time = self.end_time - self.start_time total_time = self.end_time - self.start_time
success_rate = (self.success_count / self.total_requests) * 100 success_rate = (self.success_count / self.total_requests) * 100
print(f"\n{'='*60}") print(f"\n{'=' * 60}")
print(f"BENCHMARK RESULTS") print("BENCHMARK RESULTS")
print(f"{'='*60}")
print("\nResponse Time Statistics:")
print(f" Mean: {statistics.mean(self.response_times):.3f}s")
print(f" Median: {statistics.median(self.response_times):.3f}s")
print(f" Min: {min(self.response_times):.3f}s")
print(f" Max: {max(self.response_times):.3f}s")
if len(self.response_times) > 1:
print(f" Std Dev: {statistics.stdev(self.response_times):.3f}s")
percentiles = [50, 90, 95, 99]
sorted_times = sorted(self.response_times)
print("\nPercentiles:")
for p in percentiles:
idx = int(len(sorted_times) * p / 100) - 1
idx = max(0, min(idx, len(sorted_times) - 1))
print(f" P{p}: {sorted_times[idx]:.3f}s")
if self.ttft_times:
print("\nTime to First Token (TTFT) Statistics:")
print(f" Mean: {statistics.mean(self.ttft_times):.3f}s")
print(f" Median: {statistics.median(self.ttft_times):.3f}s")
print(f" Min: {min(self.ttft_times):.3f}s")
print(f" Max: {max(self.ttft_times):.3f}s")
if len(self.ttft_times) > 1:
print(f" Std Dev: {statistics.stdev(self.ttft_times):.3f}s")
sorted_ttft = sorted(self.ttft_times)
print("\nTTFT Percentiles:")
for p in percentiles:
idx = int(len(sorted_ttft) * p / 100) - 1
idx = max(0, min(idx, len(sorted_ttft) - 1))
print(f" P{p}: {sorted_ttft[idx]:.3f}s")
if self.chunks_received:
print("\nStreaming Statistics:")
print(f" Mean chunks per response: {statistics.mean(self.chunks_received):.1f}")
print(f" Total chunks received: {sum(self.chunks_received)}")
print(f"{'=' * 60}")
print(f"Total time: {total_time:.2f}s") print(f"Total time: {total_time:.2f}s")
print(f"Concurrent users: {self.concurrent_users}") print(f"Concurrent users: {self.concurrent_users}")
print(f"Total requests: {self.total_requests}") print(f"Total requests: {self.total_requests}")
@ -66,55 +106,16 @@ class BenchmarkStats:
print(f"Failed requests: {len(self.errors)}") print(f"Failed requests: {len(self.errors)}")
print(f"Success rate: {success_rate:.1f}%") print(f"Success rate: {success_rate:.1f}%")
print(f"Requests per second: {self.success_count / total_time:.2f}") print(f"Requests per second: {self.success_count / total_time:.2f}")
print(f"\nResponse Time Statistics:")
print(f" Mean: {statistics.mean(self.response_times):.3f}s")
print(f" Median: {statistics.median(self.response_times):.3f}s")
print(f" Min: {min(self.response_times):.3f}s")
print(f" Max: {max(self.response_times):.3f}s")
if len(self.response_times) > 1:
print(f" Std Dev: {statistics.stdev(self.response_times):.3f}s")
percentiles = [50, 90, 95, 99]
sorted_times = sorted(self.response_times)
print(f"\nPercentiles:")
for p in percentiles:
idx = int(len(sorted_times) * p / 100) - 1
idx = max(0, min(idx, len(sorted_times) - 1))
print(f" P{p}: {sorted_times[idx]:.3f}s")
if self.ttft_times:
print(f"\nTime to First Token (TTFT) Statistics:")
print(f" Mean: {statistics.mean(self.ttft_times):.3f}s")
print(f" Median: {statistics.median(self.ttft_times):.3f}s")
print(f" Min: {min(self.ttft_times):.3f}s")
print(f" Max: {max(self.ttft_times):.3f}s")
if len(self.ttft_times) > 1:
print(f" Std Dev: {statistics.stdev(self.ttft_times):.3f}s")
sorted_ttft = sorted(self.ttft_times)
print(f"\nTTFT Percentiles:")
for p in percentiles:
idx = int(len(sorted_ttft) * p / 100) - 1
idx = max(0, min(idx, len(sorted_ttft) - 1))
print(f" P{p}: {sorted_ttft[idx]:.3f}s")
if self.chunks_received:
print(f"\nStreaming Statistics:")
print(f" Mean chunks per response: {statistics.mean(self.chunks_received):.1f}")
print(f" Total chunks received: {sum(self.chunks_received)}")
if self.errors: if self.errors:
print(f"\nErrors (showing first 5):") print("\nErrors (showing first 5):")
for error in self.errors[:5]: for error in self.errors[:5]:
print(f" {error}") print(f" {error}")
class LlamaStackBenchmark: class LlamaStackBenchmark:
def __init__(self, base_url: str, model_id: str): def __init__(self, base_url: str, model_id: str):
self.base_url = base_url.rstrip('/') self.base_url = base_url.rstrip("/")
self.model_id = model_id self.model_id = model_id
self.headers = {"Content-Type": "application/json"} self.headers = {"Content-Type": "application/json"}
self.test_messages = [ self.test_messages = [
@ -125,74 +126,67 @@ class LlamaStackBenchmark:
[ [
{"role": "user", "content": "What is machine learning?"}, {"role": "user", "content": "What is machine learning?"},
{"role": "assistant", "content": "Machine learning is a subset of AI..."}, {"role": "assistant", "content": "Machine learning is a subset of AI..."},
{"role": "user", "content": "Can you give me a practical example?"} {"role": "user", "content": "Can you give me a practical example?"},
] ],
] ]
async def make_async_streaming_request(self) -> tuple[float, int, float | None, str | None]:
async def make_async_streaming_request(self) -> Tuple[float, int, float | None, str | None]:
"""Make a single async streaming chat completion request.""" """Make a single async streaming chat completion request."""
messages = random.choice(self.test_messages) messages = random.choice(self.test_messages)
payload = { payload = {"model": self.model_id, "messages": messages, "stream": True, "max_tokens": 100}
"model": self.model_id,
"messages": messages,
"stream": True,
"max_tokens": 100
}
start_time = time.time() start_time = time.time()
chunks_received = 0 chunks_received = 0
ttft = None ttft = None
error = None error = None
session = aiohttp.ClientSession() session = aiohttp.ClientSession()
try: try:
async with session.post( async with session.post(
f"{self.base_url}/chat/completions", f"{self.base_url}/chat/completions",
headers=self.headers, headers=self.headers,
json=payload, json=payload,
timeout=aiohttp.ClientTimeout(total=30) timeout=aiohttp.ClientTimeout(total=30),
) as response: ) as response:
if response.status == 200: if response.status == 200:
async for line in response.content: async for line in response.content:
if line: if line:
line_str = line.decode('utf-8').strip() line_str = line.decode("utf-8").strip()
if line_str.startswith('data: '): if line_str.startswith("data: "):
chunks_received += 1 chunks_received += 1
if ttft is None: if ttft is None:
ttft = time.time() - start_time ttft = time.time() - start_time
if line_str == 'data: [DONE]': if line_str == "data: [DONE]":
break break
if chunks_received == 0: if chunks_received == 0:
error = "No streaming chunks received" error = "No streaming chunks received"
else: else:
text = await response.text() text = await response.text()
error = f"HTTP {response.status}: {text[:100]}" error = f"HTTP {response.status}: {text[:100]}"
except Exception as e: except Exception as e:
error = f"Request error: {str(e)}" error = f"Request error: {str(e)}"
finally: finally:
await session.close() await session.close()
response_time = time.time() - start_time response_time = time.time() - start_time
return response_time, chunks_received, ttft, error return response_time, chunks_received, ttft, error
async def run_benchmark(self, duration: int, concurrent_users: int) -> BenchmarkStats: async def run_benchmark(self, duration: int, concurrent_users: int) -> BenchmarkStats:
"""Run benchmark using async requests for specified duration.""" """Run benchmark using async requests for specified duration."""
stats = BenchmarkStats() stats = BenchmarkStats()
stats.concurrent_users = concurrent_users stats.concurrent_users = concurrent_users
stats.start_time = time.time() stats.start_time = time.time()
print(f"Starting benchmark: {duration}s duration, {concurrent_users} concurrent users") print(f"Starting benchmark: {duration}s duration, {concurrent_users} concurrent users")
print(f"Target URL: {self.base_url}/chat/completions") print(f"Target URL: {self.base_url}/chat/completions")
print(f"Model: {self.model_id}") print(f"Model: {self.model_id}")
connector = aiohttp.TCPConnector(limit=concurrent_users) connector = aiohttp.TCPConnector(limit=concurrent_users)
async with aiohttp.ClientSession(connector=connector) as session: async with aiohttp.ClientSession(connector=connector):
async def worker(worker_id: int): async def worker(worker_id: int):
"""Worker that sends requests sequentially until canceled.""" """Worker that sends requests sequentially until canceled."""
request_count = 0 request_count = 0
@ -201,12 +195,12 @@ class LlamaStackBenchmark:
response_time, chunks, ttft, error = await self.make_async_streaming_request() response_time, chunks, ttft, error = await self.make_async_streaming_request()
await stats.add_result(response_time, chunks, ttft, error) await stats.add_result(response_time, chunks, ttft, error)
request_count += 1 request_count += 1
except asyncio.CancelledError: except asyncio.CancelledError:
break break
except Exception as e: except Exception as e:
await stats.add_result(0, 0, None, f"Worker {worker_id} error: {str(e)}") await stats.add_result(0, 0, None, f"Worker {worker_id} error: {str(e)}")
# Progress reporting task # Progress reporting task
async def progress_reporter(): async def progress_reporter():
last_report_time = time.time() last_report_time = time.time()
@ -215,48 +209,52 @@ class LlamaStackBenchmark:
await asyncio.sleep(1) # Report every second await asyncio.sleep(1) # Report every second
if time.time() >= last_report_time + 10: # Report every 10 seconds if time.time() >= last_report_time + 10: # Report every 10 seconds
elapsed = time.time() - stats.start_time elapsed = time.time() - stats.start_time
print(f"Completed: {stats.total_requests} requests in {elapsed:.1f}s") print(
f"Completed: {stats.total_requests} requests in {elapsed:.1f}s, RPS: {stats.total_requests / elapsed:.1f}"
)
last_report_time = time.time() last_report_time = time.time()
except asyncio.CancelledError: except asyncio.CancelledError:
break break
# Spawn concurrent workers # Spawn concurrent workers
tasks = [asyncio.create_task(worker(i)) for i in range(concurrent_users)] tasks = [asyncio.create_task(worker(i)) for i in range(concurrent_users)]
progress_task = asyncio.create_task(progress_reporter()) progress_task = asyncio.create_task(progress_reporter())
tasks.append(progress_task) tasks.append(progress_task)
# Wait for duration then cancel all tasks # Wait for duration then cancel all tasks
await asyncio.sleep(duration) await asyncio.sleep(duration)
for task in tasks: for task in tasks:
task.cancel() task.cancel()
# Wait for all tasks to complete # Wait for all tasks to complete
await asyncio.gather(*tasks, return_exceptions=True) await asyncio.gather(*tasks, return_exceptions=True)
stats.end_time = time.time() stats.end_time = time.time()
return stats return stats
def main(): def main():
parser = argparse.ArgumentParser(description="Llama Stack Benchmark Tool") parser = argparse.ArgumentParser(description="Llama Stack Benchmark Tool")
parser.add_argument("--base-url", default=os.getenv("BENCHMARK_BASE_URL", "http://localhost:8000/v1/openai/v1"), parser.add_argument(
help="Base URL for the API (default: http://localhost:8000/v1/openai/v1)") "--base-url",
parser.add_argument("--model", default=os.getenv("INFERENCE_MODEL", "test-model"), default=os.getenv("BENCHMARK_BASE_URL", "http://localhost:8000/v1/openai/v1"),
help="Model ID to use for requests") help="Base URL for the API (default: http://localhost:8000/v1/openai/v1)",
parser.add_argument("--duration", type=int, default=60, )
help="Duration in seconds to run benchmark (default: 60)") parser.add_argument(
parser.add_argument("--concurrent", type=int, default=10, "--model", default=os.getenv("INFERENCE_MODEL", "test-model"), help="Model ID to use for requests"
help="Number of concurrent users (default: 10)") )
parser.add_argument("--duration", type=int, default=60, help="Duration in seconds to run benchmark (default: 60)")
parser.add_argument("--concurrent", type=int, default=10, help="Number of concurrent users (default: 10)")
args = parser.parse_args() args = parser.parse_args()
benchmark = LlamaStackBenchmark(args.base_url, args.model) benchmark = LlamaStackBenchmark(args.base_url, args.model)
try: try:
stats = asyncio.run(benchmark.run_benchmark(args.duration, args.concurrent)) stats = asyncio.run(benchmark.run_benchmark(args.duration, args.concurrent))
stats.print_summary() stats.print_summary()
except KeyboardInterrupt: except KeyboardInterrupt:
print("\nBenchmark interrupted by user") print("\nBenchmark interrupted by user")
except Exception as e: except Exception as e:

View file

@ -11,180 +11,192 @@ OpenAI-compatible mock server that returns:
- Valid OpenAI-formatted chat completion responses with dynamic content - Valid OpenAI-formatted chat completion responses with dynamic content
""" """
from flask import Flask, request, jsonify, Response
import time
import random
import uuid
import json
import argparse import argparse
import json
import os import os
import random
import time
import uuid
from flask import Flask, Response, jsonify, request
app = Flask(__name__) app = Flask(__name__)
# Models from environment variables # Models from environment variables
def get_models(): def get_models():
models_str = os.getenv("MOCK_MODELS", "meta-llama/Llama-3.2-3B-Instruct") models_str = os.getenv("MOCK_MODELS", "meta-llama/Llama-3.2-3B-Instruct")
model_ids = [m.strip() for m in models_str.split(",") if m.strip()] model_ids = [m.strip() for m in models_str.split(",") if m.strip()]
return { return {
"object": "list", "object": "list",
"data": [ "data": [
{ {"id": model_id, "object": "model", "created": 1234567890, "owned_by": "vllm"} for model_id in model_ids
"id": model_id, ],
"object": "model",
"created": 1234567890,
"owned_by": "vllm"
}
for model_id in model_ids
]
} }
def generate_random_text(length=50): def generate_random_text(length=50):
"""Generate random but coherent text for responses.""" """Generate random but coherent text for responses."""
words = [ words = [
"Hello", "there", "I'm", "an", "AI", "assistant", "ready", "to", "help", "you", "Hello",
"with", "your", "questions", "and", "tasks", "today", "Let", "me","know", "what", "there",
"you'd", "like", "to", "discuss", "or", "explore", "together", "I", "can", "assist", "I'm",
"with", "various", "topics", "including", "coding", "writing", "analysis", "and", "more" "an",
"AI",
"assistant",
"ready",
"to",
"help",
"you",
"with",
"your",
"questions",
"and",
"tasks",
"today",
"Let",
"me",
"know",
"what",
"you'd",
"like",
"to",
"discuss",
"or",
"explore",
"together",
"I",
"can",
"assist",
"with",
"various",
"topics",
"including",
"coding",
"writing",
"analysis",
"and",
"more",
] ]
return " ".join(random.choices(words, k=length)) return " ".join(random.choices(words, k=length))
@app.route('/v1/models', methods=['GET'])
@app.route("/v1/models", methods=["GET"])
def list_models(): def list_models():
models = get_models() models = get_models()
print(f"[MOCK] Returning models: {[m['id'] for m in models['data']]}") print(f"[MOCK] Returning models: {[m['id'] for m in models['data']]}")
return jsonify(models) return jsonify(models)
@app.route('/v1/chat/completions', methods=['POST'])
@app.route("/v1/chat/completions", methods=["POST"])
def chat_completions(): def chat_completions():
"""Return OpenAI-formatted chat completion responses.""" """Return OpenAI-formatted chat completion responses."""
data = request.get_json() data = request.get_json()
default_model = get_models()['data'][0]['id'] default_model = get_models()["data"][0]["id"]
model = data.get('model', default_model) model = data.get("model", default_model)
messages = data.get('messages', []) messages = data.get("messages", [])
stream = data.get('stream', False) stream = data.get("stream", False)
print(f"[MOCK] Chat completion request - model: {model}, stream: {stream}") print(f"[MOCK] Chat completion request - model: {model}, stream: {stream}")
if stream: if stream:
return handle_streaming_completion(model, messages) return handle_streaming_completion(model, messages)
else: else:
return handle_non_streaming_completion(model, messages) return handle_non_streaming_completion(model, messages)
def handle_non_streaming_completion(model, messages): def handle_non_streaming_completion(model, messages):
response_text = generate_random_text(random.randint(20, 80)) response_text = generate_random_text(random.randint(20, 80))
# Calculate realistic token counts # Calculate realistic token counts
prompt_tokens = sum(len(str(msg.get('content', '')).split()) for msg in messages) prompt_tokens = sum(len(str(msg.get("content", "")).split()) for msg in messages)
completion_tokens = len(response_text.split()) completion_tokens = len(response_text.split())
response = { response = {
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}", "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
"object": "chat.completion", "object": "chat.completion",
"created": int(time.time()), "created": int(time.time()),
"model": model, "model": model,
"choices": [ "choices": [{"index": 0, "message": {"role": "assistant", "content": response_text}, "finish_reason": "stop"}],
{
"index": 0,
"message": {
"role": "assistant",
"content": response_text
},
"finish_reason": "stop"
}
],
"usage": { "usage": {
"prompt_tokens": prompt_tokens, "prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens, "completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens "total_tokens": prompt_tokens + completion_tokens,
} },
} }
return jsonify(response) return jsonify(response)
def handle_streaming_completion(model, messages): def handle_streaming_completion(model, messages):
def generate_stream(): def generate_stream():
# Generate response text # Generate response text
full_response = generate_random_text(random.randint(30, 100)) full_response = generate_random_text(random.randint(30, 100))
words = full_response.split() words = full_response.split()
# Send initial chunk # Send initial chunk
initial_chunk = { initial_chunk = {
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}", "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
"created": int(time.time()), "created": int(time.time()),
"model": model, "model": model,
"choices": [ "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}}],
{
"index": 0,
"delta": {"role": "assistant", "content": ""}
}
]
} }
yield f"data: {json.dumps(initial_chunk)}\n\n" yield f"data: {json.dumps(initial_chunk)}\n\n"
# Send word by word # Send word by word
for i, word in enumerate(words): for i, word in enumerate(words):
chunk = { chunk = {
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}", "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
"created": int(time.time()), "created": int(time.time()),
"model": model, "model": model,
"choices": [ "choices": [{"index": 0, "delta": {"content": f"{word} " if i < len(words) - 1 else word}}],
{
"index": 0,
"delta": {"content": f"{word} " if i < len(words) - 1 else word}
}
]
} }
yield f"data: {json.dumps(chunk)}\n\n" yield f"data: {json.dumps(chunk)}\n\n"
# Configurable delay to simulate realistic streaming # Configurable delay to simulate realistic streaming
stream_delay = float(os.getenv("STREAM_DELAY_SECONDS", "0.005")) stream_delay = float(os.getenv("STREAM_DELAY_SECONDS", "0.005"))
time.sleep(stream_delay) time.sleep(stream_delay)
# Send final chunk # Send final chunk
final_chunk = { final_chunk = {
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}", "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
"created": int(time.time()), "created": int(time.time()),
"model": model, "model": model,
"choices": [ "choices": [{"index": 0, "delta": {"content": ""}, "finish_reason": "stop"}],
{
"index": 0,
"delta": {"content": ""},
"finish_reason": "stop"
}
]
} }
yield f"data: {json.dumps(final_chunk)}\n\n" yield f"data: {json.dumps(final_chunk)}\n\n"
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
return Response( return Response(
generate_stream(), generate_stream(),
mimetype='text/event-stream', mimetype="text/event-stream",
headers={ headers={
'Cache-Control': 'no-cache', "Cache-Control": "no-cache",
'Connection': 'keep-alive', "Connection": "keep-alive",
'Access-Control-Allow-Origin': '*', "Access-Control-Allow-Origin": "*",
} },
) )
@app.route('/health', methods=['GET'])
@app.route("/health", methods=["GET"])
def health(): def health():
return jsonify({"status": "healthy", "type": "openai-mock"}) return jsonify({"status": "healthy", "type": "openai-mock"})
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='OpenAI-compatible mock server') if __name__ == "__main__":
parser.add_argument('--port', type=int, default=8081, parser = argparse.ArgumentParser(description="OpenAI-compatible mock server")
help='Port to run the server on (default: 8081)') parser.add_argument("--port", type=int, default=8081, help="Port to run the server on (default: 8081)")
args = parser.parse_args() args = parser.parse_args()
port = args.port port = args.port
models = get_models() models = get_models()
print("Starting OpenAI-compatible mock server...") print("Starting OpenAI-compatible mock server...")
print(f"- /models endpoint with: {[m['id'] for m in models['data']]}") print(f"- /models endpoint with: {[m['id'] for m in models['data']]}")
print("- OpenAI-formatted chat/completion responses with dynamic content") print("- OpenAI-formatted chat/completion responses with dynamic content")
print("- Streaming support with valid SSE format") print("- Streaming support with valid SSE format")
print(f"- Listening on: http://0.0.0.0:{port}") print(f"- Listening on: http://0.0.0.0:{port}")
app.run(host='0.0.0.0', port=port, debug=False) app.run(host="0.0.0.0", port=port, debug=False)

View file

@ -6,6 +6,7 @@ data:
apis: apis:
- agents - agents
- inference - inference
- files
- safety - safety
- telemetry - telemetry
- tool_runtime - tool_runtime
@ -19,13 +20,6 @@ data:
max_tokens: ${env.VLLM_MAX_TOKENS:=4096} max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake} api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true} tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: vllm-safety
provider_type: remote::vllm
config:
url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: sentence-transformers - provider_id: sentence-transformers
provider_type: inline::sentence-transformers provider_type: inline::sentence-transformers
config: {} config: {}
@ -41,6 +35,14 @@ data:
db: ${env.POSTGRES_DB:=llamastack} db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack} user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack} password: ${env.POSTGRES_PASSWORD:=llamastack}
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
safety: safety:
- provider_id: llama-guard - provider_id: llama-guard
provider_type: inline::llama-guard provider_type: inline::llama-guard
@ -111,9 +113,6 @@ data:
- model_id: ${env.INFERENCE_MODEL} - model_id: ${env.INFERENCE_MODEL}
provider_id: vllm-inference provider_id: vllm-inference
model_type: llm model_type: llm
- model_id: ${env.SAFETY_MODEL}
provider_id: vllm-safety
model_type: llm
shields: shields:
- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B} - shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
vector_dbs: [] vector_dbs: []

View file

@ -2,7 +2,10 @@ version: '2'
image_name: kubernetes-benchmark-demo image_name: kubernetes-benchmark-demo
apis: apis:
- agents - agents
- files
- inference - inference
- files
- safety
- telemetry - telemetry
- tool_runtime - tool_runtime
- vector_io - vector_io
@ -18,6 +21,14 @@ providers:
- provider_id: sentence-transformers - provider_id: sentence-transformers
provider_type: inline::sentence-transformers provider_type: inline::sentence-transformers
config: {} config: {}
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
vector_io: vector_io:
- provider_id: ${env.ENABLE_CHROMADB:+chromadb} - provider_id: ${env.ENABLE_CHROMADB:+chromadb}
provider_type: remote::chromadb provider_type: remote::chromadb
@ -30,6 +41,19 @@ providers:
db: ${env.POSTGRES_DB:=llamastack} db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack} user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack} password: ${env.POSTGRES_PASSWORD:=llamastack}
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
agents: agents:
- provider_id: meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference provider_type: inline::meta-reference
@ -95,6 +119,8 @@ models:
- model_id: ${env.INFERENCE_MODEL} - model_id: ${env.INFERENCE_MODEL}
provider_id: vllm-inference provider_id: vllm-inference
model_type: llm model_type: llm
shields:
- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
vector_dbs: [] vector_dbs: []
datasets: [] datasets: []
scoring_fns: [] scoring_fns: []

View file

@ -1,5 +1,106 @@
@import url("theme.css"); @import url("theme.css");
/* Horizontal Navigation Bar */
.horizontal-nav {
background-color: #ffffff;
border-bottom: 1px solid #e5e5e5;
padding: 0;
position: fixed;
top: 0;
left: 0;
right: 0;
z-index: 1050;
height: 50px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
[data-theme="dark"] .horizontal-nav {
background-color: #1a1a1a;
border-bottom: 1px solid #333;
}
.horizontal-nav .nav-container {
max-width: 1200px;
margin: 0 auto;
display: flex;
align-items: center;
justify-content: space-between;
padding: 0 20px;
height: 100%;
}
.horizontal-nav .nav-brand {
font-size: 18px;
font-weight: 600;
color: #333;
text-decoration: none;
}
[data-theme="dark"] .horizontal-nav .nav-brand {
color: #fff;
}
.horizontal-nav .nav-links {
display: flex;
align-items: center;
gap: 30px;
list-style: none;
margin: 0;
padding: 0;
}
.horizontal-nav .nav-links a {
color: #666;
text-decoration: none;
font-size: 14px;
font-weight: 500;
padding: 8px 12px;
border-radius: 6px;
transition: all 0.2s ease;
}
.horizontal-nav .nav-links a:hover,
.horizontal-nav .nav-links a.active {
color: #333;
background-color: #f5f5f5;
}
.horizontal-nav .nav-links a.active {
font-weight: 600;
}
[data-theme="dark"] .horizontal-nav .nav-links a {
color: #ccc;
}
[data-theme="dark"] .horizontal-nav .nav-links a:hover,
[data-theme="dark"] .horizontal-nav .nav-links a.active {
color: #fff;
background-color: #333;
}
.horizontal-nav .nav-links .github-link {
display: flex;
align-items: center;
gap: 6px;
}
.horizontal-nav .nav-links .github-icon {
width: 16px;
height: 16px;
fill: currentColor;
}
/* Adjust main content to account for fixed nav */
.wy-nav-side {
top: 50px;
height: calc(100vh - 50px);
}
.wy-nav-content-wrap {
margin-top: 50px;
}
.wy-nav-content { .wy-nav-content {
max-width: 90%; max-width: 90%;
} }

44
docs/_static/js/horizontal_nav.js vendored Normal file
View file

@ -0,0 +1,44 @@
// Horizontal Navigation Bar for Llama Stack Documentation
document.addEventListener('DOMContentLoaded', function() {
// Create the horizontal navigation HTML
const navHTML = `
<nav class="horizontal-nav">
<div class="nav-container">
<a href="/" class="nav-brand">Llama Stack</a>
<ul class="nav-links">
<li><a href="/">Docs</a></li>
<li><a href="/references/api_reference/">API Reference</a></li>
<li><a href="https://github.com/meta-llama/llama-stack" target="_blank" class="github-link">
<svg class="github-icon" viewBox="0 0 16 16" aria-hidden="true">
<path d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0016 8c0-4.42-3.58-8-8-8z"/>
</svg>
GitHub
</a></li>
</ul>
</div>
</nav>
`;
// Insert the navigation at the beginning of the body
document.body.insertAdjacentHTML('afterbegin', navHTML);
// Update navigation links based on current page
updateActiveNav();
});
function updateActiveNav() {
const currentPath = window.location.pathname;
const navLinks = document.querySelectorAll('.horizontal-nav .nav-links a');
navLinks.forEach(link => {
// Remove any existing active classes
link.classList.remove('active');
// Add active class based on current path
if (currentPath === '/' && link.getAttribute('href') === '/') {
link.classList.add('active');
} else if (currentPath.includes('/references/api_reference/') && link.getAttribute('href').includes('api_reference')) {
link.classList.add('active');
}
});
}

View file

@ -633,6 +633,80 @@
} }
} }
}, },
"/v1/prompts": {
"get": {
"responses": {
"200": {
"description": "A ListPromptsResponse containing all prompts.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ListPromptsResponse"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Prompts"
],
"description": "List all prompts.",
"parameters": []
},
"post": {
"responses": {
"200": {
"description": "The created Prompt resource.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Prompt"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Prompts"
],
"description": "Create a new prompt.",
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/CreatePromptRequest"
}
}
},
"required": true
}
}
},
"/v1/agents/{agent_id}": { "/v1/agents/{agent_id}": {
"get": { "get": {
"responses": { "responses": {
@ -901,6 +975,143 @@
] ]
} }
}, },
"/v1/prompts/{prompt_id}": {
"get": {
"responses": {
"200": {
"description": "A Prompt resource.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Prompt"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Prompts"
],
"description": "Get a prompt by its identifier and optional version.",
"parameters": [
{
"name": "prompt_id",
"in": "path",
"description": "The identifier of the prompt to get.",
"required": true,
"schema": {
"type": "string"
}
},
{
"name": "version",
"in": "query",
"description": "The version of the prompt to get (defaults to latest).",
"required": false,
"schema": {
"type": "integer"
}
}
]
},
"post": {
"responses": {
"200": {
"description": "The updated Prompt resource with incremented version.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Prompt"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Prompts"
],
"description": "Update an existing prompt (increments version).",
"parameters": [
{
"name": "prompt_id",
"in": "path",
"description": "The identifier of the prompt to update.",
"required": true,
"schema": {
"type": "string"
}
}
],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/UpdatePromptRequest"
}
}
},
"required": true
}
},
"delete": {
"responses": {
"200": {
"description": "OK"
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Prompts"
],
"description": "Delete a prompt.",
"parameters": [
{
"name": "prompt_id",
"in": "path",
"description": "The identifier of the prompt to delete.",
"required": true,
"schema": {
"type": "string"
}
}
]
}
},
"/v1/inference/embeddings": { "/v1/inference/embeddings": {
"post": { "post": {
"responses": { "responses": {
@ -2836,6 +3047,49 @@
] ]
} }
}, },
"/v1/prompts/{prompt_id}/versions": {
"get": {
"responses": {
"200": {
"description": "A ListPromptsResponse containing all versions of the prompt.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ListPromptsResponse"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Prompts"
],
"description": "List all versions of a specific prompt.",
"parameters": [
{
"name": "prompt_id",
"in": "path",
"description": "The identifier of the prompt to list versions for.",
"required": true,
"schema": {
"type": "string"
}
}
]
}
},
"/v1/providers": { "/v1/providers": {
"get": { "get": {
"responses": { "responses": {
@ -4129,7 +4383,7 @@
"tags": [ "tags": [
"Files" "Files"
], ],
"description": "Upload a file that can be used across various endpoints.\nThe file upload should be a multipart form request with:\n- file: The File object (not file name) to be uploaded.\n- purpose: The intended purpose of the uploaded file.", "description": "Upload a file that can be used across various endpoints.\nThe file upload should be a multipart form request with:\n- file: The File object (not file name) to be uploaded.\n- purpose: The intended purpose of the uploaded file.\n- expires_after: Optional form values describing expiration for the file. Expected expires_after[anchor] = \"created_at\", expires_after[seconds] = <int>. Seconds must be between 3600 and 2592000 (1 hour to 30 days).",
"parameters": [], "parameters": [],
"requestBody": { "requestBody": {
"content": { "content": {
@ -4143,11 +4397,33 @@
}, },
"purpose": { "purpose": {
"$ref": "#/components/schemas/OpenAIFilePurpose" "$ref": "#/components/schemas/OpenAIFilePurpose"
},
"expires_after_anchor": {
"oneOf": [
{
"type": "string"
},
{
"type": "null"
}
]
},
"expires_after_seconds": {
"oneOf": [
{
"type": "integer"
},
{
"type": "null"
}
]
} }
}, },
"required": [ "required": [
"file", "file",
"purpose" "purpose",
"expires_after_anchor",
"expires_after_seconds"
] ]
} }
} }
@ -4985,6 +5261,59 @@
} }
} }
}, },
"/v1/prompts/{prompt_id}/set-default-version": {
"post": {
"responses": {
"200": {
"description": "The prompt with the specified version now set as default.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Prompt"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Prompts"
],
"description": "Set which version of a prompt should be the default in get_prompt (latest).",
"parameters": [
{
"name": "prompt_id",
"in": "path",
"description": "The identifier of the prompt.",
"required": true,
"schema": {
"type": "string"
}
}
],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SetDefaultVersionRequest"
}
}
},
"required": true
}
}
},
"/v1/post-training/supervised-fine-tune": { "/v1/post-training/supervised-fine-tune": {
"post": { "post": {
"responses": { "responses": {
@ -9648,6 +9977,65 @@
], ],
"title": "OpenAIResponseObjectStreamResponseWebSearchCallSearching" "title": "OpenAIResponseObjectStreamResponseWebSearchCallSearching"
}, },
"CreatePromptRequest": {
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "The prompt text content with variable placeholders."
},
"variables": {
"type": "array",
"items": {
"type": "string"
},
"description": "List of variable names that can be used in the prompt template."
}
},
"additionalProperties": false,
"required": [
"prompt"
],
"title": "CreatePromptRequest"
},
"Prompt": {
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "The system prompt text with variable placeholders. Variables are only supported when using the Responses API."
},
"version": {
"type": "integer",
"description": "Version (integer starting at 1, incremented on save)"
},
"prompt_id": {
"type": "string",
"description": "Unique identifier formatted as 'pmpt_<48-digit-hash>'"
},
"variables": {
"type": "array",
"items": {
"type": "string"
},
"description": "List of prompt variable names that can be used in the prompt template"
},
"is_default": {
"type": "boolean",
"default": false,
"description": "Boolean indicating whether this version is the default version for this prompt"
}
},
"additionalProperties": false,
"required": [
"version",
"prompt_id",
"variables",
"is_default"
],
"title": "Prompt",
"description": "A prompt resource representing a stored OpenAI Compatible prompt template in Llama Stack."
},
"OpenAIDeleteResponseObject": { "OpenAIDeleteResponseObject": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -10274,7 +10662,8 @@
"scoring_function", "scoring_function",
"benchmark", "benchmark",
"tool", "tool",
"tool_group" "tool_group",
"prompt"
], ],
"const": "benchmark", "const": "benchmark",
"default": "benchmark", "default": "benchmark",
@ -10901,7 +11290,8 @@
"scoring_function", "scoring_function",
"benchmark", "benchmark",
"tool", "tool",
"tool_group" "tool_group",
"prompt"
], ],
"const": "dataset", "const": "dataset",
"default": "dataset", "default": "dataset",
@ -11051,7 +11441,8 @@
"scoring_function", "scoring_function",
"benchmark", "benchmark",
"tool", "tool",
"tool_group" "tool_group",
"prompt"
], ],
"const": "model", "const": "model",
"default": "model", "default": "model",
@ -11316,7 +11707,8 @@
"scoring_function", "scoring_function",
"benchmark", "benchmark",
"tool", "tool",
"tool_group" "tool_group",
"prompt"
], ],
"const": "scoring_function", "const": "scoring_function",
"default": "scoring_function", "default": "scoring_function",
@ -11424,7 +11816,8 @@
"scoring_function", "scoring_function",
"benchmark", "benchmark",
"tool", "tool",
"tool_group" "tool_group",
"prompt"
], ],
"const": "shield", "const": "shield",
"default": "shield", "default": "shield",
@ -11669,7 +12062,8 @@
"scoring_function", "scoring_function",
"benchmark", "benchmark",
"tool", "tool",
"tool_group" "tool_group",
"prompt"
], ],
"const": "tool", "const": "tool",
"default": "tool", "default": "tool",
@ -11751,7 +12145,8 @@
"scoring_function", "scoring_function",
"benchmark", "benchmark",
"tool", "tool",
"tool_group" "tool_group",
"prompt"
], ],
"const": "tool_group", "const": "tool_group",
"default": "tool_group", "default": "tool_group",
@ -12045,7 +12440,8 @@
"scoring_function", "scoring_function",
"benchmark", "benchmark",
"tool", "tool",
"tool_group" "tool_group",
"prompt"
], ],
"const": "vector_db", "const": "vector_db",
"default": "vector_db", "default": "vector_db",
@ -12860,6 +13256,23 @@
"title": "OpenAIResponseObjectWithInput", "title": "OpenAIResponseObjectWithInput",
"description": "OpenAI response object extended with input context information." "description": "OpenAI response object extended with input context information."
}, },
"ListPromptsResponse": {
"type": "object",
"properties": {
"data": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Prompt"
}
}
},
"additionalProperties": false,
"required": [
"data"
],
"title": "ListPromptsResponse",
"description": "Response model to list prompts."
},
"ListProvidersResponse": { "ListProvidersResponse": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -17106,6 +17519,20 @@
"title": "ScoreBatchResponse", "title": "ScoreBatchResponse",
"description": "Response from batch scoring operations on datasets." "description": "Response from batch scoring operations on datasets."
}, },
"SetDefaultVersionRequest": {
"type": "object",
"properties": {
"version": {
"type": "integer",
"description": "The version to set as default."
}
},
"additionalProperties": false,
"required": [
"version"
],
"title": "SetDefaultVersionRequest"
},
"AlgorithmConfig": { "AlgorithmConfig": {
"oneOf": [ "oneOf": [
{ {
@ -17390,6 +17817,37 @@
"title": "SyntheticDataGenerationResponse", "title": "SyntheticDataGenerationResponse",
"description": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold." "description": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."
}, },
"UpdatePromptRequest": {
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "The updated prompt text content."
},
"version": {
"type": "integer",
"description": "The current version of the prompt being updated."
},
"variables": {
"type": "array",
"items": {
"type": "string"
},
"description": "Updated list of variable names that can be used in the prompt template."
},
"set_as_default": {
"type": "boolean",
"description": "Set the new version as the default (default=True)."
}
},
"additionalProperties": false,
"required": [
"prompt",
"version",
"set_as_default"
],
"title": "UpdatePromptRequest"
},
"VersionInfo": { "VersionInfo": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -17515,6 +17973,10 @@
{ {
"name": "PostTraining (Coming Soon)" "name": "PostTraining (Coming Soon)"
}, },
{
"name": "Prompts",
"x-displayName": "Protocol for prompt management operations."
},
{ {
"name": "Providers", "name": "Providers",
"x-displayName": "Providers API for inspecting, listing, and modifying providers and their configurations." "x-displayName": "Providers API for inspecting, listing, and modifying providers and their configurations."
@ -17565,6 +18027,7 @@
"Inspect", "Inspect",
"Models", "Models",
"PostTraining (Coming Soon)", "PostTraining (Coming Soon)",
"Prompts",
"Providers", "Providers",
"Safety", "Safety",
"Scoring", "Scoring",

View file

@ -427,6 +427,58 @@ paths:
schema: schema:
$ref: '#/components/schemas/CreateOpenaiResponseRequest' $ref: '#/components/schemas/CreateOpenaiResponseRequest'
required: true required: true
/v1/prompts:
get:
responses:
'200':
description: >-
A ListPromptsResponse containing all prompts.
content:
application/json:
schema:
$ref: '#/components/schemas/ListPromptsResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Prompts
description: List all prompts.
parameters: []
post:
responses:
'200':
description: The created Prompt resource.
content:
application/json:
schema:
$ref: '#/components/schemas/Prompt'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Prompts
description: Create a new prompt.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/CreatePromptRequest'
required: true
/v1/agents/{agent_id}: /v1/agents/{agent_id}:
get: get:
responses: responses:
@ -616,6 +668,103 @@ paths:
required: true required: true
schema: schema:
type: string type: string
/v1/prompts/{prompt_id}:
get:
responses:
'200':
description: A Prompt resource.
content:
application/json:
schema:
$ref: '#/components/schemas/Prompt'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Prompts
description: >-
Get a prompt by its identifier and optional version.
parameters:
- name: prompt_id
in: path
description: The identifier of the prompt to get.
required: true
schema:
type: string
- name: version
in: query
description: >-
The version of the prompt to get (defaults to latest).
required: false
schema:
type: integer
post:
responses:
'200':
description: >-
The updated Prompt resource with incremented version.
content:
application/json:
schema:
$ref: '#/components/schemas/Prompt'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Prompts
description: >-
Update an existing prompt (increments version).
parameters:
- name: prompt_id
in: path
description: The identifier of the prompt to update.
required: true
schema:
type: string
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/UpdatePromptRequest'
required: true
delete:
responses:
'200':
description: OK
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Prompts
description: Delete a prompt.
parameters:
- name: prompt_id
in: path
description: The identifier of the prompt to delete.
required: true
schema:
type: string
/v1/inference/embeddings: /v1/inference/embeddings:
post: post:
responses: responses:
@ -1983,6 +2132,37 @@ paths:
required: false required: false
schema: schema:
$ref: '#/components/schemas/Order' $ref: '#/components/schemas/Order'
/v1/prompts/{prompt_id}/versions:
get:
responses:
'200':
description: >-
A ListPromptsResponse containing all versions of the prompt.
content:
application/json:
schema:
$ref: '#/components/schemas/ListPromptsResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Prompts
description: List all versions of a specific prompt.
parameters:
- name: prompt_id
in: path
description: >-
The identifier of the prompt to list versions for.
required: true
schema:
type: string
/v1/providers: /v1/providers:
get: get:
responses: responses:
@ -2933,6 +3113,10 @@ paths:
- file: The File object (not file name) to be uploaded. - file: The File object (not file name) to be uploaded.
- purpose: The intended purpose of the uploaded file. - purpose: The intended purpose of the uploaded file.
- expires_after: Optional form values describing expiration for the file.
Expected expires_after[anchor] = "created_at", expires_after[seconds] = <int>.
Seconds must be between 3600 and 2592000 (1 hour to 30 days).
parameters: [] parameters: []
requestBody: requestBody:
content: content:
@ -2945,9 +3129,19 @@ paths:
format: binary format: binary
purpose: purpose:
$ref: '#/components/schemas/OpenAIFilePurpose' $ref: '#/components/schemas/OpenAIFilePurpose'
expires_after_anchor:
oneOf:
- type: string
- type: 'null'
expires_after_seconds:
oneOf:
- type: integer
- type: 'null'
required: required:
- file - file
- purpose - purpose
- expires_after_anchor
- expires_after_seconds
required: true required: true
/v1/openai/v1/models: /v1/openai/v1/models:
get: get:
@ -3532,6 +3726,43 @@ paths:
schema: schema:
$ref: '#/components/schemas/ScoreBatchRequest' $ref: '#/components/schemas/ScoreBatchRequest'
required: true required: true
/v1/prompts/{prompt_id}/set-default-version:
post:
responses:
'200':
description: >-
The prompt with the specified version now set as default.
content:
application/json:
schema:
$ref: '#/components/schemas/Prompt'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Prompts
description: >-
Set which version of a prompt should be the default in get_prompt (latest).
parameters:
- name: prompt_id
in: path
description: The identifier of the prompt.
required: true
schema:
type: string
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/SetDefaultVersionRequest'
required: true
/v1/post-training/supervised-fine-tune: /v1/post-training/supervised-fine-tune:
post: post:
responses: responses:
@ -7134,6 +7365,61 @@ components:
- type - type
title: >- title: >-
OpenAIResponseObjectStreamResponseWebSearchCallSearching OpenAIResponseObjectStreamResponseWebSearchCallSearching
CreatePromptRequest:
type: object
properties:
prompt:
type: string
description: >-
The prompt text content with variable placeholders.
variables:
type: array
items:
type: string
description: >-
List of variable names that can be used in the prompt template.
additionalProperties: false
required:
- prompt
title: CreatePromptRequest
Prompt:
type: object
properties:
prompt:
type: string
description: >-
The system prompt text with variable placeholders. Variables are only
supported when using the Responses API.
version:
type: integer
description: >-
Version (integer starting at 1, incremented on save)
prompt_id:
type: string
description: >-
Unique identifier formatted as 'pmpt_<48-digit-hash>'
variables:
type: array
items:
type: string
description: >-
List of prompt variable names that can be used in the prompt template
is_default:
type: boolean
default: false
description: >-
Boolean indicating whether this version is the default version for this
prompt
additionalProperties: false
required:
- version
- prompt_id
- variables
- is_default
title: Prompt
description: >-
A prompt resource representing a stored OpenAI Compatible prompt template
in Llama Stack.
OpenAIDeleteResponseObject: OpenAIDeleteResponseObject:
type: object type: object
properties: properties:
@ -7607,6 +7893,7 @@ components:
- benchmark - benchmark
- tool - tool
- tool_group - tool_group
- prompt
const: benchmark const: benchmark
default: benchmark default: benchmark
description: The resource type, always benchmark description: The resource type, always benchmark
@ -8093,6 +8380,7 @@ components:
- benchmark - benchmark
- tool - tool
- tool_group - tool_group
- prompt
const: dataset const: dataset
default: dataset default: dataset
description: >- description: >-
@ -8205,6 +8493,7 @@ components:
- benchmark - benchmark
- tool - tool
- tool_group - tool_group
- prompt
const: model const: model
default: model default: model
description: >- description: >-
@ -8396,6 +8685,7 @@ components:
- benchmark - benchmark
- tool - tool
- tool_group - tool_group
- prompt
const: scoring_function const: scoring_function
default: scoring_function default: scoring_function
description: >- description: >-
@ -8472,6 +8762,7 @@ components:
- benchmark - benchmark
- tool - tool
- tool_group - tool_group
- prompt
const: shield const: shield
default: shield default: shield
description: The resource type, always shield description: The resource type, always shield
@ -8651,6 +8942,7 @@ components:
- benchmark - benchmark
- tool - tool
- tool_group - tool_group
- prompt
const: tool const: tool
default: tool default: tool
description: Type of resource, always 'tool' description: Type of resource, always 'tool'
@ -8709,6 +9001,7 @@ components:
- benchmark - benchmark
- tool - tool
- tool_group - tool_group
- prompt
const: tool_group const: tool_group
default: tool_group default: tool_group
description: Type of resource, always 'tool_group' description: Type of resource, always 'tool_group'
@ -8937,6 +9230,7 @@ components:
- benchmark - benchmark
- tool - tool
- tool_group - tool_group
- prompt
const: vector_db const: vector_db
default: vector_db default: vector_db
description: >- description: >-
@ -9563,6 +9857,18 @@ components:
title: OpenAIResponseObjectWithInput title: OpenAIResponseObjectWithInput
description: >- description: >-
OpenAI response object extended with input context information. OpenAI response object extended with input context information.
ListPromptsResponse:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/Prompt'
additionalProperties: false
required:
- data
title: ListPromptsResponse
description: Response model to list prompts.
ListProvidersResponse: ListProvidersResponse:
type: object type: object
properties: properties:
@ -12708,6 +13014,16 @@ components:
title: ScoreBatchResponse title: ScoreBatchResponse
description: >- description: >-
Response from batch scoring operations on datasets. Response from batch scoring operations on datasets.
SetDefaultVersionRequest:
type: object
properties:
version:
type: integer
description: The version to set as default.
additionalProperties: false
required:
- version
title: SetDefaultVersionRequest
AlgorithmConfig: AlgorithmConfig:
oneOf: oneOf:
- $ref: '#/components/schemas/LoraFinetuningConfig' - $ref: '#/components/schemas/LoraFinetuningConfig'
@ -12904,6 +13220,32 @@ components:
description: >- description: >-
Response from the synthetic data generation. Batch of (prompt, response, score) Response from the synthetic data generation. Batch of (prompt, response, score)
tuples that pass the threshold. tuples that pass the threshold.
UpdatePromptRequest:
type: object
properties:
prompt:
type: string
description: The updated prompt text content.
version:
type: integer
description: >-
The current version of the prompt being updated.
variables:
type: array
items:
type: string
description: >-
Updated list of variable names that can be used in the prompt template.
set_as_default:
type: boolean
description: >-
Set the new version as the default (default=True).
additionalProperties: false
required:
- prompt
- version
- set_as_default
title: UpdatePromptRequest
VersionInfo: VersionInfo:
type: object type: object
properties: properties:
@ -13015,6 +13357,9 @@ tags:
- name: Inspect - name: Inspect
- name: Models - name: Models
- name: PostTraining (Coming Soon) - name: PostTraining (Coming Soon)
- name: Prompts
x-displayName: >-
Protocol for prompt management operations.
- name: Providers - name: Providers
x-displayName: >- x-displayName: >-
Providers API for inspecting, listing, and modifying providers and their configurations. Providers API for inspecting, listing, and modifying providers and their configurations.
@ -13042,6 +13387,7 @@ x-tagGroups:
- Inspect - Inspect
- Models - Models
- PostTraining (Coming Soon) - PostTraining (Coming Soon)
- Prompts
- Providers - Providers
- Safety - Safety
- Scoring - Scoring

View file

@ -33,7 +33,7 @@ The list of open-benchmarks we currently support:
- [MMMU](https://arxiv.org/abs/2311.16502) (A Massive Multi-discipline Multimodal Understanding and Reasoning Benchmark for Expert AGI)]: Benchmark designed to evaluate multimodal models. - [MMMU](https://arxiv.org/abs/2311.16502) (A Massive Multi-discipline Multimodal Understanding and Reasoning Benchmark for Expert AGI)]: Benchmark designed to evaluate multimodal models.
You can follow this [contributing guide](https://llama-stack.readthedocs.io/en/latest/references/evals_reference/index.html#open-benchmark-contributing-guide) to add more open-benchmarks to Llama Stack You can follow this [contributing guide](../references/evals_reference/index.md#open-benchmark-contributing-guide) to add more open-benchmarks to Llama Stack
#### Run evaluation on open-benchmarks via CLI #### Run evaluation on open-benchmarks via CLI

View file

@ -35,3 +35,6 @@ device: cpu
``` ```
[Find more detailed information here!](huggingface.md)

View file

@ -22,3 +22,4 @@ checkpoint_format: meta
``` ```
[Find more detailed information here!](torchtune.md)

View file

@ -88,7 +88,7 @@ Interactive pages for users to play with and explore Llama Stack API capabilitie
- **API Resources**: Inspect Llama Stack API resources - **API Resources**: Inspect Llama Stack API resources
- This page allows you to inspect Llama Stack API resources (`models`, `datasets`, `memory_banks`, `benchmarks`, `shields`). - This page allows you to inspect Llama Stack API resources (`models`, `datasets`, `memory_banks`, `benchmarks`, `shields`).
- Under the hood, it uses Llama Stack's `/<resources>/list` API to get information about each resources. - Under the hood, it uses Llama Stack's `/<resources>/list` API to get information about each resources.
- Please visit [Core Concepts](https://llama-stack.readthedocs.io/en/latest/concepts/index.html) for more details about the resources. - Please visit [Core Concepts](../../concepts/index.md) for more details about the resources.
### Starting the Llama Stack Playground ### Starting the Llama Stack Playground

View file

@ -3,7 +3,7 @@
Llama Stack (LLS) provides two different APIs for building AI applications with tool calling capabilities: the **Agents API** and the **OpenAI Responses API**. While both enable AI systems to use tools, and maintain full conversation history, they serve different use cases and have distinct characteristics. Llama Stack (LLS) provides two different APIs for building AI applications with tool calling capabilities: the **Agents API** and the **OpenAI Responses API**. While both enable AI systems to use tools, and maintain full conversation history, they serve different use cases and have distinct characteristics.
```{note} ```{note}
For simple and basic inferencing, you may want to use the [Chat Completions API](https://llama-stack.readthedocs.io/en/latest/providers/index.html#chat-completions) directly, before progressing to Agents or Responses API. **Note:** For simple and basic inferencing, you may want to use the [Chat Completions API](../providers/openai.md#chat-completions) directly, before progressing to Agents or Responses API.
``` ```
## Overview ## Overview
@ -173,7 +173,7 @@ Both APIs demonstrate distinct strengths that make them valuable on their own fo
## For More Information ## For More Information
- **LLS Agents API**: For detailed information on creating and managing agents, see the [Agents documentation](https://llama-stack.readthedocs.io/en/latest/building_applications/agent.html) - **LLS Agents API**: For detailed information on creating and managing agents, see the [Agents documentation](agent.md)
- **OpenAI Responses API**: For information on using the OpenAI-compatible responses API, see the [OpenAI API documentation](https://platform.openai.com/docs/api-reference/responses) - **OpenAI Responses API**: For information on using the OpenAI-compatible responses API, see the [OpenAI API documentation](https://platform.openai.com/docs/api-reference/responses)
- **Chat Completions API**: For the default backend API used by Agents, see the [Chat Completions providers documentation](https://llama-stack.readthedocs.io/en/latest/providers/index.html#chat-completions) - **Chat Completions API**: For the default backend API used by Agents, see the [Chat Completions providers documentation](../providers/openai.md#chat-completions)
- **Agent Execution Loop**: For understanding how agents process turns and steps in their execution, see the [Agent Execution Loop documentation](https://llama-stack.readthedocs.io/en/latest/building_applications/agent_execution_loop.html) - **Agent Execution Loop**: For understanding how agents process turns and steps in their execution, see the [Agent Execution Loop documentation](agent_execution_loop.md)

View file

@ -6,4 +6,4 @@ While there is a lot of flexibility to mix-and-match providers, often users will
**Locally Hosted Distro**: You may want to run Llama Stack on your own hardware. Typically though, you still need to use Inference via an external service. You can use providers like HuggingFace TGI, Fireworks, Together, etc. for this purpose. Or you may have access to GPUs and can run a [vLLM](https://github.com/vllm-project/vllm) or [NVIDIA NIM](https://build.nvidia.com/nim?filters=nimType%3Anim_type_run_anywhere&q=llama) instance. If you "just" have a regular desktop machine, you can use [Ollama](https://ollama.com/) for inference. To provide convenient quick access to these options, we provide a number of such pre-configured locally-hosted Distros. **Locally Hosted Distro**: You may want to run Llama Stack on your own hardware. Typically though, you still need to use Inference via an external service. You can use providers like HuggingFace TGI, Fireworks, Together, etc. for this purpose. Or you may have access to GPUs and can run a [vLLM](https://github.com/vllm-project/vllm) or [NVIDIA NIM](https://build.nvidia.com/nim?filters=nimType%3Anim_type_run_anywhere&q=llama) instance. If you "just" have a regular desktop machine, you can use [Ollama](https://ollama.com/) for inference. To provide convenient quick access to these options, we provide a number of such pre-configured locally-hosted Distros.
**On-device Distro**: To run Llama Stack directly on an edge device (mobile phone or a tablet), we provide Distros for [iOS](https://llama-stack.readthedocs.io/en/latest/distributions/ondevice_distro/ios_sdk.html) and [Android](https://llama-stack.readthedocs.io/en/latest/distributions/ondevice_distro/android_sdk.html) **On-device Distro**: To run Llama Stack directly on an edge device (mobile phone or a tablet), we provide Distros for [iOS](../distributions/ondevice_distro/ios_sdk.md) and [Android](../distributions/ondevice_distro/android_sdk.md)

View file

@ -131,6 +131,7 @@ html_static_path = ["../_static"]
def setup(app): def setup(app):
app.add_css_file("css/my_theme.css") app.add_css_file("css/my_theme.css")
app.add_js_file("js/detect_theme.js") app.add_js_file("js/detect_theme.js")
app.add_js_file("js/horizontal_nav.js")
app.add_js_file("js/keyboard_shortcuts.js") app.add_js_file("js/keyboard_shortcuts.js")
def dockerhub_role(name, rawtext, text, lineno, inliner, options={}, content=[]): def dockerhub_role(name, rawtext, text, lineno, inliner, options={}, content=[]):

View file

@ -35,5 +35,5 @@ testing/record-replay
### Benchmarking ### Benchmarking
```{include} ../../../docs/source/distributions/k8s-benchmark/README.md ```{include} ../../../benchmarking/k8s-benchmark/README.md
``` ```

View file

@ -14,6 +14,13 @@ Here are some example PRs to help you get started:
- [Nvidia Inference Implementation](https://github.com/meta-llama/llama-stack/pull/355) - [Nvidia Inference Implementation](https://github.com/meta-llama/llama-stack/pull/355)
- [Model context protocol Tool Runtime](https://github.com/meta-llama/llama-stack/pull/665) - [Model context protocol Tool Runtime](https://github.com/meta-llama/llama-stack/pull/665)
## Guidelines for creating Internal or External Providers
|**Type** |Internal (In-tree) |External (out-of-tree)
|---------|-------------------|---------------------|
|**Description** |A provider that is directly in the Llama Stack code|A provider that is outside of the Llama stack core codebase but is still accessible and usable by Llama Stack.
|**Benefits** |Ability to interact with the provider with minimal additional configurations or installations| Contributors do not have to add directly to the code to create providers accessible on Llama Stack. Keep provider-specific code separate from the core Llama Stack code.
## Inference Provider Patterns ## Inference Provider Patterns
When implementing Inference providers for OpenAI-compatible APIs, Llama Stack provides several mixin classes to simplify development and ensure consistent behavior across providers. When implementing Inference providers for OpenAI-compatible APIs, Llama Stack provides several mixin classes to simplify development and ensure consistent behavior across providers.

View file

@ -40,18 +40,15 @@ The system patches OpenAI and Ollama client methods to intercept calls before th
### Storage Architecture ### Storage Architecture
Recordings use a two-tier storage system optimized for both speed and debuggability: Recordings are stored as JSON files in the recording directory. They are looked up by their request hash.
``` ```
recordings/ recordings/
├── index.sqlite # Fast lookup by request hash
└── responses/ └── responses/
├── abc123def456.json # Individual response files ├── abc123def456.json # Individual response files
└── def789ghi012.json └── def789ghi012.json
``` ```
**SQLite index** enables O(log n) hash lookups and metadata queries without loading response bodies.
**JSON files** store complete request/response pairs in human-readable format for debugging. **JSON files** store complete request/response pairs in human-readable format for debugging.
## Recording Modes ## Recording Modes
@ -166,8 +163,8 @@ This preserves type safety - when replayed, you get the same Pydantic objects wi
Control recording behavior globally: Control recording behavior globally:
```bash ```bash
export LLAMA_STACK_TEST_INFERENCE_MODE=replay export LLAMA_STACK_TEST_INFERENCE_MODE=replay # this is the default
export LLAMA_STACK_TEST_RECORDING_DIR=/path/to/recordings export LLAMA_STACK_TEST_RECORDING_DIR=/path/to/recordings # default is tests/integration/recordings
pytest tests/integration/ pytest tests/integration/
``` ```

View file

@ -354,6 +354,47 @@ You can easily validate a request by running:
curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers
``` ```
#### Kubernetes Authentication Provider
The server can be configured to use Kubernetes SelfSubjectReview API to validate tokens directly against the Kubernetes API server:
```yaml
server:
auth:
provider_config:
type: "kubernetes"
api_server_url: "https://kubernetes.default.svc"
claims_mapping:
username: "roles"
groups: "roles"
uid: "uid_attr"
verify_tls: true
tls_cafile: "/path/to/ca.crt"
```
Configuration options:
- `api_server_url`: The Kubernetes API server URL (e.g., https://kubernetes.default.svc:6443)
- `verify_tls`: Whether to verify TLS certificates (default: true)
- `tls_cafile`: Path to CA certificate file for TLS verification
- `claims_mapping`: Mapping of Kubernetes user claims to access attributes
The provider validates tokens by sending a SelfSubjectReview request to the Kubernetes API server at `/apis/authentication.k8s.io/v1/selfsubjectreviews`. The provider extracts user information from the response:
- Username from the `userInfo.username` field
- Groups from the `userInfo.groups` field
- UID from the `userInfo.uid` field
To obtain a token for testing:
```bash
kubectl create namespace llama-stack
kubectl create serviceaccount llama-stack-auth -n llama-stack
kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
```
You can validate a request by running:
```bash
curl -s -L -H "Authorization: Bearer $(cat llama-stack-auth-token)" http://127.0.0.1:8321/v1/providers
```
#### GitHub Token Provider #### GitHub Token Provider
Validates GitHub personal access tokens or OAuth tokens directly: Validates GitHub personal access tokens or OAuth tokens directly:
```yaml ```yaml

View file

@ -27,7 +27,7 @@ Then, you can access the APIs like `models` and `inference` on the client and ca
response = client.models.list() response = client.models.list()
``` ```
If you've created a [custom distribution](https://llama-stack.readthedocs.io/en/latest/distributions/building_distro.html), you can also use the run.yaml configuration file directly: If you've created a [custom distribution](building_distro.md), you can also use the run.yaml configuration file directly:
```python ```python
client = LlamaStackAsLibraryClient(config_path) client = LlamaStackAsLibraryClient(config_path)

View file

@ -22,17 +22,17 @@ else
fi fi
if [ -z "${GITHUB_CLIENT_ID:-}" ]; then if [ -z "${GITHUB_CLIENT_ID:-}" ]; then
echo "ERROR: GITHUB_CLIENT_ID not set. You need it for Github login to work. Refer to https://llama-stack.readthedocs.io/en/latest/deploying/index.html#kubernetes-deployment-guide" echo "ERROR: GITHUB_CLIENT_ID not set. You need it for Github login to work. See the Kubernetes Deployment Guide in the Llama Stack documentation."
exit 1 exit 1
fi fi
if [ -z "${GITHUB_CLIENT_SECRET:-}" ]; then if [ -z "${GITHUB_CLIENT_SECRET:-}" ]; then
echo "ERROR: GITHUB_CLIENT_SECRET not set. You need it for Github login to work. Refer to https://llama-stack.readthedocs.io/en/latest/deploying/index.html#kubernetes-deployment-guide" echo "ERROR: GITHUB_CLIENT_SECRET not set. You need it for Github login to work. See the Kubernetes Deployment Guide in the Llama Stack documentation."
exit 1 exit 1
fi fi
if [ -z "${LLAMA_STACK_UI_URL:-}" ]; then if [ -z "${LLAMA_STACK_UI_URL:-}" ]; then
echo "ERROR: LLAMA_STACK_UI_URL not set. Should be set to the external URL of the UI (excluding port). You need it for Github login to work. Refer to https://llama-stack.readthedocs.io/en/latest/deploying/index.html#kubernetes-deployment-guide" echo "ERROR: LLAMA_STACK_UI_URL not set. Should be set to the external URL of the UI (excluding port). You need it for Github login to work. See the Kubernetes Deployment Guide in the Llama Stack documentation."
exit 1 exit 1
fi fi

View file

@ -1,137 +1,55 @@
apiVersion: v1 apiVersion: v1
data: data:
stack_run_config.yaml: | stack_run_config.yaml: "version: '2'\nimage_name: kubernetes-demo\napis:\n- agents\n-
version: '2' inference\n- files\n- safety\n- telemetry\n- tool_runtime\n- vector_io\nproviders:\n
image_name: kubernetes-demo \ inference:\n - provider_id: vllm-inference\n provider_type: remote::vllm\n
apis: \ config:\n url: ${env.VLLM_URL:=http://localhost:8000/v1}\n max_tokens:
- agents ${env.VLLM_MAX_TOKENS:=4096}\n api_token: ${env.VLLM_API_TOKEN:=fake}\n tls_verify:
- inference ${env.VLLM_TLS_VERIFY:=true}\n - provider_id: vllm-safety\n provider_type:
- safety remote::vllm\n config:\n url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1}\n
- telemetry \ max_tokens: ${env.VLLM_MAX_TOKENS:=4096}\n api_token: ${env.VLLM_API_TOKEN:=fake}\n
- tool_runtime \ tls_verify: ${env.VLLM_TLS_VERIFY:=true}\n - provider_id: sentence-transformers\n
- vector_io \ provider_type: inline::sentence-transformers\n config: {}\n vector_io:\n
providers: \ - provider_id: ${env.ENABLE_CHROMADB:+chromadb}\n provider_type: remote::chromadb\n
inference: \ config:\n url: ${env.CHROMADB_URL:=}\n kvstore:\n type: postgres\n
- provider_id: vllm-inference \ host: ${env.POSTGRES_HOST:=localhost}\n port: ${env.POSTGRES_PORT:=5432}\n
provider_type: remote::vllm \ db: ${env.POSTGRES_DB:=llamastack}\n user: ${env.POSTGRES_USER:=llamastack}\n
config: \ password: ${env.POSTGRES_PASSWORD:=llamastack}\n files:\n - provider_id:
url: ${env.VLLM_URL:=http://localhost:8000/v1} meta-reference-files\n provider_type: inline::localfs\n config:\n storage_dir:
max_tokens: ${env.VLLM_MAX_TOKENS:=4096} ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}\n metadata_store:\n
api_token: ${env.VLLM_API_TOKEN:=fake} \ type: sqlite\n db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
tls_verify: ${env.VLLM_TLS_VERIFY:=true} \ \n safety:\n - provider_id: llama-guard\n provider_type: inline::llama-guard\n
- provider_id: vllm-safety \ config:\n excluded_categories: []\n agents:\n - provider_id: meta-reference\n
provider_type: remote::vllm \ provider_type: inline::meta-reference\n config:\n persistence_store:\n
config: \ type: postgres\n host: ${env.POSTGRES_HOST:=localhost}\n port:
url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1} ${env.POSTGRES_PORT:=5432}\n db: ${env.POSTGRES_DB:=llamastack}\n user:
max_tokens: ${env.VLLM_MAX_TOKENS:=4096} ${env.POSTGRES_USER:=llamastack}\n password: ${env.POSTGRES_PASSWORD:=llamastack}\n
api_token: ${env.VLLM_API_TOKEN:=fake} \ responses_store:\n type: postgres\n host: ${env.POSTGRES_HOST:=localhost}\n
tls_verify: ${env.VLLM_TLS_VERIFY:=true} \ port: ${env.POSTGRES_PORT:=5432}\n db: ${env.POSTGRES_DB:=llamastack}\n
- provider_id: sentence-transformers \ user: ${env.POSTGRES_USER:=llamastack}\n password: ${env.POSTGRES_PASSWORD:=llamastack}\n
provider_type: inline::sentence-transformers \ telemetry:\n - provider_id: meta-reference\n provider_type: inline::meta-reference\n
config: {} \ config:\n service_name: \"${env.OTEL_SERVICE_NAME:=\\u200B}\"\n sinks:
vector_io: ${env.TELEMETRY_SINKS:=console}\n tool_runtime:\n - provider_id: brave-search\n
- provider_id: ${env.ENABLE_CHROMADB:+chromadb} \ provider_type: remote::brave-search\n config:\n api_key: ${env.BRAVE_SEARCH_API_KEY:+}\n
provider_type: remote::chromadb \ max_results: 3\n - provider_id: tavily-search\n provider_type: remote::tavily-search\n
config: \ config:\n api_key: ${env.TAVILY_SEARCH_API_KEY:+}\n max_results:
url: ${env.CHROMADB_URL:=} 3\n - provider_id: rag-runtime\n provider_type: inline::rag-runtime\n config:
kvstore: {}\n - provider_id: model-context-protocol\n provider_type: remote::model-context-protocol\n
type: postgres \ config: {}\nmetadata_store:\n type: postgres\n host: ${env.POSTGRES_HOST:=localhost}\n
host: ${env.POSTGRES_HOST:=localhost} \ port: ${env.POSTGRES_PORT:=5432}\n db: ${env.POSTGRES_DB:=llamastack}\n user:
port: ${env.POSTGRES_PORT:=5432} ${env.POSTGRES_USER:=llamastack}\n password: ${env.POSTGRES_PASSWORD:=llamastack}\n
db: ${env.POSTGRES_DB:=llamastack} \ table_name: llamastack_kvstore\ninference_store:\n type: postgres\n host:
user: ${env.POSTGRES_USER:=llamastack} ${env.POSTGRES_HOST:=localhost}\n port: ${env.POSTGRES_PORT:=5432}\n db: ${env.POSTGRES_DB:=llamastack}\n
password: ${env.POSTGRES_PASSWORD:=llamastack} \ user: ${env.POSTGRES_USER:=llamastack}\n password: ${env.POSTGRES_PASSWORD:=llamastack}\nmodels:\n-
safety: metadata:\n embedding_dimension: 384\n model_id: all-MiniLM-L6-v2\n provider_id:
- provider_id: llama-guard sentence-transformers\n model_type: embedding\n- metadata: {}\n model_id: ${env.INFERENCE_MODEL}\n
provider_type: inline::llama-guard \ provider_id: vllm-inference\n model_type: llm\n- metadata: {}\n model_id:
config: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}\n provider_id: vllm-safety\n
excluded_categories: [] \ model_type: llm\nshields:\n- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}\nvector_dbs:
agents: []\ndatasets: []\nscoring_fns: []\nbenchmarks: []\ntool_groups:\n- toolgroup_id:
- provider_id: meta-reference builtin::websearch\n provider_id: tavily-search\n- toolgroup_id: builtin::rag\n
provider_type: inline::meta-reference \ provider_id: rag-runtime\nserver:\n port: 8321\n auth:\n provider_config:\n
config: \ type: github_token\n"
persistence_store:
type: postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
responses_store:
type: postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:+}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:+}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
metadata_store:
type: postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
table_name: llamastack_kvstore
inference_store:
type: postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
models:
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
model_type: embedding
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: vllm-inference
model_type: llm
- metadata: {}
model_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
provider_id: vllm-safety
model_type: llm
shields:
- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
auth:
provider_config:
type: github_token
kind: ConfigMap kind: ConfigMap
metadata: metadata:
creationTimestamp: null creationTimestamp: null

View file

@ -3,6 +3,7 @@ image_name: kubernetes-demo
apis: apis:
- agents - agents
- inference - inference
- files
- safety - safety
- telemetry - telemetry
- tool_runtime - tool_runtime
@ -38,6 +39,14 @@ providers:
db: ${env.POSTGRES_DB:=llamastack} db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack} user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack} password: ${env.POSTGRES_PASSWORD:=llamastack}
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
safety: safety:
- provider_id: llama-guard - provider_id: llama-guard
provider_type: inline::llama-guard provider_type: inline::llama-guard

View file

@ -66,7 +66,7 @@ llama stack run starter --port 5050
Ensure the Llama Stack server version is the same as the Kotlin SDK Library for maximum compatibility. Ensure the Llama Stack server version is the same as the Kotlin SDK Library for maximum compatibility.
Other inference providers: [Table](https://llama-stack.readthedocs.io/en/latest/index.html#supported-llama-stack-implementations) Other inference providers: [Table](../../index.md#supported-llama-stack-implementations)
How to set remote localhost in Demo App: [Settings](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release/examples/android_app#settings) How to set remote localhost in Demo App: [Settings](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release/examples/android_app#settings)

View file

@ -2,7 +2,7 @@
orphan: true orphan: true
--- ---
<!-- This file was auto-generated by distro_codegen.py, please edit source --> <!-- This file was auto-generated by distro_codegen.py, please edit source -->
# Meta Reference Distribution # Meta Reference GPU Distribution
```{toctree} ```{toctree}
:maxdepth: 2 :maxdepth: 2
@ -41,7 +41,7 @@ The following environment variables can be configured:
## Prerequisite: Downloading Models ## Prerequisite: Downloading Models
Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints. Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](../../references/llama_cli_reference/download_models.md) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
``` ```
$ llama model list --downloaded $ llama model list --downloaded

View file

@ -50,6 +50,7 @@ The following models are available by default:
- `meta/llama-3.2-11b-vision-instruct ` - `meta/llama-3.2-11b-vision-instruct `
- `meta/llama-3.2-90b-vision-instruct ` - `meta/llama-3.2-90b-vision-instruct `
- `meta/llama-3.3-70b-instruct ` - `meta/llama-3.3-70b-instruct `
- `nvidia/vila `
- `nvidia/llama-3.2-nv-embedqa-1b-v2 ` - `nvidia/llama-3.2-nv-embedqa-1b-v2 `
- `nvidia/nv-embedqa-e5-v5 ` - `nvidia/nv-embedqa-e5-v5 `
- `nvidia/nv-embedqa-mistral-7b-v2 ` - `nvidia/nv-embedqa-mistral-7b-v2 `

View file

@ -18,12 +18,13 @@ embedding_model_id = (
).identifier ).identifier
embedding_dimension = em.metadata["embedding_dimension"] embedding_dimension = em.metadata["embedding_dimension"]
_ = client.vector_dbs.register( vector_db = client.vector_dbs.register(
vector_db_id=vector_db_id, vector_db_id=vector_db_id,
embedding_model=embedding_model_id, embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension, embedding_dimension=embedding_dimension,
provider_id="faiss", provider_id="faiss",
) )
vector_db_id = vector_db.identifier
source = "https://www.paulgraham.com/greatwork.html" source = "https://www.paulgraham.com/greatwork.html"
print("rag_tool> Ingesting document:", source) print("rag_tool> Ingesting document:", source)
document = RAGDocument( document = RAGDocument(
@ -35,7 +36,7 @@ document = RAGDocument(
client.tool_runtime.rag_tool.insert( client.tool_runtime.rag_tool.insert(
documents=[document], documents=[document],
vector_db_id=vector_db_id, vector_db_id=vector_db_id,
chunk_size_in_tokens=50, chunk_size_in_tokens=100,
) )
agent = Agent( agent = Agent(
client, client,

View file

@ -7,4 +7,5 @@ Here's a list of known external providers that you can use with Llama Stack:
| KubeFlow Training | Train models with KubeFlow | Post Training | Remote | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) | | KubeFlow Training | Train models with KubeFlow | Post Training | Remote | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) |
| KubeFlow Pipelines | Train models with KubeFlow Pipelines | Post Training | Inline **and** Remote | [llama-stack-provider-kfp-trainer](https://github.com/opendatahub-io/llama-stack-provider-kfp-trainer) | | KubeFlow Pipelines | Train models with KubeFlow Pipelines | Post Training | Inline **and** Remote | [llama-stack-provider-kfp-trainer](https://github.com/opendatahub-io/llama-stack-provider-kfp-trainer) |
| RamaLama | Inference models with RamaLama | Inference | Remote | [ramalama-stack](https://github.com/containers/ramalama-stack) | | RamaLama | Inference models with RamaLama | Inference | Remote | [ramalama-stack](https://github.com/containers/ramalama-stack) |
| TrustyAI LM-Eval | Evaluate models with TrustyAI LM-Eval | Eval | Remote | [llama-stack-provider-lmeval](https://github.com/trustyai-explainability/llama-stack-provider-lmeval) | | TrustyAI LM-Eval | Evaluate models with TrustyAI LM-Eval | Eval | Remote | [llama-stack-provider-lmeval](https://github.com/trustyai-explainability/llama-stack-provider-lmeval) |
| MongoDB | VectorIO with MongoDB | Vector_IO | Remote | [mongodb-llama-stack](https://github.com/mongodb-partners/mongodb-llama-stack) |

View file

@ -15,8 +15,8 @@ AWS Bedrock inference provider for accessing various AI models through AWS's man
| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE | | `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE |
| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS | | `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
| `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE | | `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE |
| `connect_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. | | `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
| `read_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. | | `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. |
| `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). | | `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). |
## Sample Configuration ## Sample Configuration

View file

@ -9,7 +9,6 @@ This section contains documentation for all available providers for the **post_t
```{toctree} ```{toctree}
:maxdepth: 1 :maxdepth: 1
inline_huggingface-cpu
inline_huggingface-gpu inline_huggingface-gpu
inline_torchtune-cpu inline_torchtune-cpu
inline_torchtune-gpu inline_torchtune-gpu

View file

@ -15,8 +15,8 @@ AWS Bedrock safety provider for content moderation using AWS's safety services.
| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE | | `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE |
| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS | | `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
| `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE | | `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE |
| `connect_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. | | `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
| `read_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. | | `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. |
| `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). | | `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). |
## Sample Configuration ## Sample Configuration

View file

@ -12,6 +12,60 @@ That means you'll get fast and efficient vector retrieval.
- Easy to use - Easy to use
- Fully integrated with Llama Stack - Fully integrated with Llama Stack
There are three implementations of search for PGVectoIndex available:
1. Vector Search:
- How it works:
- Uses PostgreSQL's vector extension (pgvector) to perform similarity search
- Compares query embeddings against stored embeddings using Cosine distance or other distance metrics
- Eg. SQL query: SELECT document, embedding <=> %s::vector AS distance FROM table ORDER BY distance
-Characteristics:
- Semantic understanding - finds documents similar in meaning even if they don't share keywords
- Works with high-dimensional vector embeddings (typically 768, 1024, or higher dimensions)
- Best for: Finding conceptually related content, handling synonyms, cross-language search
2. Keyword Search
- How it works:
- Uses PostgreSQL's full-text search capabilities with tsvector and ts_rank
- Converts text to searchable tokens using to_tsvector('english', text). Default language is English.
- Eg. SQL query: SELECT document, ts_rank(tokenized_content, plainto_tsquery('english', %s)) AS score
- Characteristics:
- Lexical matching - finds exact keyword matches and variations
- Uses GIN (Generalized Inverted Index) for fast text search performance
- Scoring: Uses PostgreSQL's ts_rank function for relevance scoring
- Best for: Exact term matching, proper names, technical terms, Boolean-style queries
3. Hybrid Search
- How it works:
- Combines both vector and keyword search results
- Runs both searches independently, then merges results using configurable reranking
- Two reranking strategies available:
- Reciprocal Rank Fusion (RRF) - (default: 60.0)
- Weighted Average - (default: 0.5)
- Characteristics:
- Best of both worlds: semantic understanding + exact matching
- Documents appearing in both searches get boosted scores
- Configurable balance between semantic and lexical matching
- Best for: General-purpose search where you want both precision and recall
4. Database Schema
The PGVector implementation stores data optimized for all three search types:
CREATE TABLE vector_store_xxx (
id TEXT PRIMARY KEY,
document JSONB, -- Original document
embedding vector(dimension), -- For vector search
content_text TEXT, -- Raw text content
tokenized_content TSVECTOR -- For keyword search
);
-- Indexes for performance
CREATE INDEX content_gin_idx ON table USING GIN(tokenized_content); -- Keyword search
-- Vector index created automatically by pgvector
## Usage ## Usage
To use PGVector in your Llama Stack project, follow these steps: To use PGVector in your Llama Stack project, follow these steps:
@ -20,6 +74,25 @@ To use PGVector in your Llama Stack project, follow these steps:
2. Configure your Llama Stack project to use pgvector. (e.g. remote::pgvector). 2. Configure your Llama Stack project to use pgvector. (e.g. remote::pgvector).
3. Start storing and querying vectors. 3. Start storing and querying vectors.
## This is an example how you can set up your environment for using PGVector
1. Export env vars:
```bash
export ENABLE_PGVECTOR=true
export PGVECTOR_HOST=localhost
export PGVECTOR_PORT=5432
export PGVECTOR_DB=llamastack
export PGVECTOR_USER=llamastack
export PGVECTOR_PASSWORD=llamastack
```
2. Create DB:
```bash
psql -h localhost -U postgres -c "CREATE ROLE llamastack LOGIN PASSWORD 'llamastack';"
psql -h localhost -U postgres -c "CREATE DATABASE llamastack OWNER llamastack;"
psql -h localhost -U llamastack -d llamastack -c "CREATE EXTENSION IF NOT EXISTS vector;"
```
## Installation ## Installation
You can install PGVector using docker: You can install PGVector using docker:

View file

@ -17,6 +17,7 @@ Weaviate supports:
- Metadata filtering - Metadata filtering
- Multi-modal retrieval - Multi-modal retrieval
## Usage ## Usage
To use Weaviate in your Llama Stack project, follow these steps: To use Weaviate in your Llama Stack project, follow these steps:

View file

@ -202,7 +202,7 @@ pprint(response)
Llama Stack offers a library of scoring functions and the `/scoring` API, allowing you to run evaluations on your pre-annotated AI application datasets. Llama Stack offers a library of scoring functions and the `/scoring` API, allowing you to run evaluations on your pre-annotated AI application datasets.
In this example, we will work with an example RAG dataset you have built previously, label with an annotation, and use LLM-As-Judge with custom judge prompt for scoring. Please checkout our [Llama Stack Playground](https://llama-stack.readthedocs.io/en/latest/playground/index.html) for an interactive interface to upload datasets and run scorings. In this example, we will work with an example RAG dataset you have built previously, label with an annotation, and use LLM-As-Judge with custom judge prompt for scoring. Please checkout our [Llama Stack Playground](../../building_applications/playground/index.md) for an interactive interface to upload datasets and run scorings.
```python ```python
judge_model_id = "meta-llama/Llama-3.1-405B-Instruct-FP8" judge_model_id = "meta-llama/Llama-3.1-405B-Instruct-FP8"

View file

@ -478,7 +478,6 @@ llama-stack-client scoring_functions list
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓ ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓
┃ identifier ┃ provider_id ┃ description ┃ type ┃ ┃ identifier ┃ provider_id ┃ description ┃ type ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩
│ basic::bfcl │ basic │ BFCL complex scoring │ scoring_function │
│ basic::docvqa │ basic │ DocVQA Visual Question & Answer scoring function │ scoring_function │ │ basic::docvqa │ basic │ DocVQA Visual Question & Answer scoring function │ scoring_function │
│ basic::equality │ basic │ Returns 1.0 if the input is equal to the target, 0.0 │ scoring_function │ │ basic::equality │ basic │ Returns 1.0 if the input is equal to the target, 0.0 │ scoring_function │
│ │ │ otherwise. │ │ │ │ │ otherwise. │ │

View file

@ -79,3 +79,10 @@ class ConflictError(ValueError):
def __init__(self, message: str) -> None: def __init__(self, message: str) -> None:
super().__init__(message) super().__init__(message)
class TokenValidationError(ValueError):
"""raised when token validation fails during authentication"""
def __init__(self, message: str) -> None:
super().__init__(message)

View file

@ -102,6 +102,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
:cvar benchmarks: Benchmark suite management :cvar benchmarks: Benchmark suite management
:cvar tool_groups: Tool group organization :cvar tool_groups: Tool group organization
:cvar files: File storage and management :cvar files: File storage and management
:cvar prompts: Prompt versions and management
:cvar inspect: Built-in system inspection and introspection :cvar inspect: Built-in system inspection and introspection
""" """
@ -127,6 +128,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
benchmarks = "benchmarks" benchmarks = "benchmarks"
tool_groups = "tool_groups" tool_groups = "tool_groups"
files = "files" files = "files"
prompts = "prompts"
# built-in API # built-in API
inspect = "inspect" inspect = "inspect"

View file

@ -5,10 +5,10 @@
# the root directory of this source tree. # the root directory of this source tree.
from enum import StrEnum from enum import StrEnum
from typing import Annotated, Literal, Protocol, runtime_checkable from typing import Annotated, ClassVar, Literal, Protocol, runtime_checkable
from fastapi import File, Form, Response, UploadFile from fastapi import File, Form, Response, UploadFile
from pydantic import BaseModel from pydantic import BaseModel, Field
from llama_stack.apis.common.responses import Order from llama_stack.apis.common.responses import Order
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@ -49,6 +49,23 @@ class OpenAIFileObject(BaseModel):
purpose: OpenAIFilePurpose purpose: OpenAIFilePurpose
@json_schema_type
class ExpiresAfter(BaseModel):
"""
Control expiration of uploaded files.
Params:
- anchor, must be "created_at"
- seconds, must be int between 3600 and 2592000 (1 hour to 30 days)
"""
MIN: ClassVar[int] = 3600 # 1 hour
MAX: ClassVar[int] = 2592000 # 30 days
anchor: Literal["created_at"]
seconds: int = Field(..., ge=3600, le=2592000)
@json_schema_type @json_schema_type
class ListOpenAIFileResponse(BaseModel): class ListOpenAIFileResponse(BaseModel):
""" """
@ -92,6 +109,9 @@ class Files(Protocol):
self, self,
file: Annotated[UploadFile, File()], file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()], purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after_anchor: Annotated[str | None, Form(alias="expires_after[anchor]")] = None,
expires_after_seconds: Annotated[int | None, Form(alias="expires_after[seconds]")] = None,
# TODO: expires_after is producing strange openapi spec, params are showing up as a required w/ oneOf being null
) -> OpenAIFileObject: ) -> OpenAIFileObject:
""" """
Upload a file that can be used across various endpoints. Upload a file that can be used across various endpoints.
@ -99,6 +119,7 @@ class Files(Protocol):
The file upload should be a multipart form request with: The file upload should be a multipart form request with:
- file: The File object (not file name) to be uploaded. - file: The File object (not file name) to be uploaded.
- purpose: The intended purpose of the uploaded file. - purpose: The intended purpose of the uploaded file.
- expires_after: Optional form values describing expiration for the file. Expected expires_after[anchor] = "created_at", expires_after[seconds] = <int>. Seconds must be between 3600 and 2592000 (1 hour to 30 days).
:param file: The uploaded file object containing content and metadata (filename, content_type, etc.). :param file: The uploaded file object containing content and metadata (filename, content_type, etc.).
:param purpose: The intended purpose of the uploaded file (e.g., "assistants", "fine-tune"). :param purpose: The intended purpose of the uploaded file (e.g., "assistants", "fine-tune").

View file

@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .prompts import ListPromptsResponse, Prompt, Prompts
__all__ = ["Prompt", "Prompts", "ListPromptsResponse"]

View file

@ -0,0 +1,189 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
import secrets
from typing import Protocol, runtime_checkable
from pydantic import BaseModel, Field, field_validator, model_validator
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class Prompt(BaseModel):
"""A prompt resource representing a stored OpenAI Compatible prompt template in Llama Stack.
:param prompt: The system prompt text with variable placeholders. Variables are only supported when using the Responses API.
:param version: Version (integer starting at 1, incremented on save)
:param prompt_id: Unique identifier formatted as 'pmpt_<48-digit-hash>'
:param variables: List of prompt variable names that can be used in the prompt template
:param is_default: Boolean indicating whether this version is the default version for this prompt
"""
prompt: str | None = Field(default=None, description="The system prompt with variable placeholders")
version: int = Field(description="Version (integer starting at 1, incremented on save)", ge=1)
prompt_id: str = Field(description="Unique identifier in format 'pmpt_<48-digit-hash>'")
variables: list[str] = Field(
default_factory=list, description="List of variable names that can be used in the prompt template"
)
is_default: bool = Field(
default=False, description="Boolean indicating whether this version is the default version"
)
@field_validator("prompt_id")
@classmethod
def validate_prompt_id(cls, prompt_id: str) -> str:
if not isinstance(prompt_id, str):
raise TypeError("prompt_id must be a string in format 'pmpt_<48-digit-hash>'")
if not prompt_id.startswith("pmpt_"):
raise ValueError("prompt_id must start with 'pmpt_' prefix")
hex_part = prompt_id[5:]
if len(hex_part) != 48:
raise ValueError("prompt_id must be in format 'pmpt_<48-digit-hash>' (48 lowercase hex chars)")
for char in hex_part:
if char not in "0123456789abcdef":
raise ValueError("prompt_id hex part must contain only lowercase hex characters [0-9a-f]")
return prompt_id
@field_validator("version")
@classmethod
def validate_version(cls, prompt_version: int) -> int:
if prompt_version < 1:
raise ValueError("version must be >= 1")
return prompt_version
@model_validator(mode="after")
def validate_prompt_variables(self):
"""Validate that all variables used in the prompt are declared in the variables list."""
if not self.prompt:
return self
prompt_variables = set(re.findall(r"{{\s*(\w+)\s*}}", self.prompt))
declared_variables = set(self.variables)
undeclared = prompt_variables - declared_variables
if undeclared:
raise ValueError(f"Prompt contains undeclared variables: {sorted(undeclared)}")
return self
@classmethod
def generate_prompt_id(cls) -> str:
# Generate 48 hex characters (24 bytes)
random_bytes = secrets.token_bytes(24)
hex_string = random_bytes.hex()
return f"pmpt_{hex_string}"
class ListPromptsResponse(BaseModel):
"""Response model to list prompts."""
data: list[Prompt]
@runtime_checkable
@trace_protocol
class Prompts(Protocol):
"""Protocol for prompt management operations."""
@webmethod(route="/prompts", method="GET")
async def list_prompts(self) -> ListPromptsResponse:
"""List all prompts.
:returns: A ListPromptsResponse containing all prompts.
"""
...
@webmethod(route="/prompts/{prompt_id}/versions", method="GET")
async def list_prompt_versions(
self,
prompt_id: str,
) -> ListPromptsResponse:
"""List all versions of a specific prompt.
:param prompt_id: The identifier of the prompt to list versions for.
:returns: A ListPromptsResponse containing all versions of the prompt.
"""
...
@webmethod(route="/prompts/{prompt_id}", method="GET")
async def get_prompt(
self,
prompt_id: str,
version: int | None = None,
) -> Prompt:
"""Get a prompt by its identifier and optional version.
:param prompt_id: The identifier of the prompt to get.
:param version: The version of the prompt to get (defaults to latest).
:returns: A Prompt resource.
"""
...
@webmethod(route="/prompts", method="POST")
async def create_prompt(
self,
prompt: str,
variables: list[str] | None = None,
) -> Prompt:
"""Create a new prompt.
:param prompt: The prompt text content with variable placeholders.
:param variables: List of variable names that can be used in the prompt template.
:returns: The created Prompt resource.
"""
...
@webmethod(route="/prompts/{prompt_id}", method="PUT")
async def update_prompt(
self,
prompt_id: str,
prompt: str,
version: int,
variables: list[str] | None = None,
set_as_default: bool = True,
) -> Prompt:
"""Update an existing prompt (increments version).
:param prompt_id: The identifier of the prompt to update.
:param prompt: The updated prompt text content.
:param version: The current version of the prompt being updated.
:param variables: Updated list of variable names that can be used in the prompt template.
:param set_as_default: Set the new version as the default (default=True).
:returns: The updated Prompt resource with incremented version.
"""
...
@webmethod(route="/prompts/{prompt_id}", method="DELETE")
async def delete_prompt(
self,
prompt_id: str,
) -> None:
"""Delete a prompt.
:param prompt_id: The identifier of the prompt to delete.
"""
...
@webmethod(route="/prompts/{prompt_id}/set-default-version", method="PUT")
async def set_default_version(
self,
prompt_id: str,
version: int,
) -> Prompt:
"""Set which version of a prompt should be the default in get_prompt (latest).
:param prompt_id: The identifier of the prompt.
:param version: The version to set as default.
:returns: The prompt with the specified version now set as default.
"""
...

View file

@ -19,6 +19,7 @@ class ResourceType(StrEnum):
benchmark = "benchmark" benchmark = "benchmark"
tool = "tool" tool = "tool"
tool_group = "tool_group" tool_group = "tool_group"
prompt = "prompt"
class Resource(BaseModel): class Resource(BaseModel):

View file

@ -45,6 +45,7 @@ from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.exec import formulate_run_args, run_command from llama_stack.core.utils.exec import formulate_run_args, run_command
from llama_stack.core.utils.image_types import LlamaStackImageType from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
DISTRIBS_PATH = Path(__file__).parent.parent.parent / "distributions" DISTRIBS_PATH = Path(__file__).parent.parent.parent / "distributions"
@ -294,6 +295,12 @@ def _generate_run_config(
if build_config.external_providers_dir if build_config.external_providers_dir
else EXTERNAL_PROVIDERS_DIR, else EXTERNAL_PROVIDERS_DIR,
) )
if not run_config.inference_store:
run_config.inference_store = SqliteSqlStoreConfig(
**SqliteSqlStoreConfig.sample_run_config(
__distro_dir__=(DISTRIBS_BASE_DIR / image_name).as_posix(), db_name="inference_store.db"
)
)
# build providers dict # build providers dict
provider_registry = get_provider_registry(build_config) provider_registry = get_provider_registry(build_config)
for api in apis: for api in apis:

View file

@ -80,7 +80,7 @@ def get_provider_dependencies(
normal_deps = [] normal_deps = []
special_deps = [] special_deps = []
for package in deps: for package in deps:
if "--no-deps" in package or "--index-url" in package: if any(f in package for f in ["--no-deps", "--index-url", "--extra-index-url"]):
special_deps.append(package) special_deps.append(package)
else: else:
normal_deps.append(package) normal_deps.append(package)

View file

@ -7,6 +7,7 @@
from enum import StrEnum from enum import StrEnum
from pathlib import Path from pathlib import Path
from typing import Annotated, Any, Literal, Self from typing import Annotated, Any, Literal, Self
from urllib.parse import urlparse
from pydantic import BaseModel, Field, field_validator, model_validator from pydantic import BaseModel, Field, field_validator, model_validator
@ -212,6 +213,7 @@ class AuthProviderType(StrEnum):
OAUTH2_TOKEN = "oauth2_token" OAUTH2_TOKEN = "oauth2_token"
GITHUB_TOKEN = "github_token" GITHUB_TOKEN = "github_token"
CUSTOM = "custom" CUSTOM = "custom"
KUBERNETES = "kubernetes"
class OAuth2TokenAuthConfig(BaseModel): class OAuth2TokenAuthConfig(BaseModel):
@ -282,8 +284,45 @@ class GitHubTokenAuthConfig(BaseModel):
) )
class KubernetesAuthProviderConfig(BaseModel):
"""Configuration for Kubernetes authentication provider."""
type: Literal[AuthProviderType.KUBERNETES] = AuthProviderType.KUBERNETES
api_server_url: str = Field(
default="https://kubernetes.default.svc",
description="Kubernetes API server URL (e.g., https://api.cluster.domain:6443)",
)
verify_tls: bool = Field(default=True, description="Whether to verify TLS certificates")
tls_cafile: Path | None = Field(default=None, description="Path to CA certificate file for TLS verification")
claims_mapping: dict[str, str] = Field(
default_factory=lambda: {
"username": "roles",
"groups": "roles",
},
description="Mapping of Kubernetes user claims to access attributes",
)
@field_validator("api_server_url")
@classmethod
def validate_api_server_url(cls, v):
parsed = urlparse(v)
if not parsed.scheme or not parsed.netloc:
raise ValueError(f"api_server_url must be a valid URL with scheme and host: {v}")
if parsed.scheme not in ["http", "https"]:
raise ValueError(f"api_server_url scheme must be http or https: {v}")
return v
@field_validator("claims_mapping")
@classmethod
def validate_claims_mapping(cls, v):
for key, value in v.items():
if not value:
raise ValueError(f"claims_mapping value cannot be empty: {key}")
return v
AuthProviderConfig = Annotated[ AuthProviderConfig = Annotated[
OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig, OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig | KubernetesAuthProviderConfig,
Field(discriminator="type"), Field(discriminator="type"),
] ]
@ -392,6 +431,12 @@ class ServerConfig(BaseModel):
) )
class InferenceStoreConfig(BaseModel):
sql_store_config: SqlStoreConfig
max_write_queue_size: int = Field(default=10000, description="Max queued writes for inference store")
num_writers: int = Field(default=4, description="Number of concurrent background writers")
class StackRunConfig(BaseModel): class StackRunConfig(BaseModel):
version: int = LLAMA_STACK_RUN_CONFIG_VERSION version: int = LLAMA_STACK_RUN_CONFIG_VERSION
@ -425,11 +470,12 @@ Configuration for the persistence store used by the distribution registry. If no
a default SQLite store will be used.""", a default SQLite store will be used.""",
) )
inference_store: SqlStoreConfig | None = Field( inference_store: InferenceStoreConfig | SqlStoreConfig | None = Field(
default=None, default=None,
description=""" description="""
Configuration for the persistence store used by the inference API. If not specified, Configuration for the persistence store used by the inference API. Can be either a
a default SQLite store will be used.""", InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (deprecated).
If not specified, a default SQLite store will be used.""",
) )
# registry of "resources" in the distribution # registry of "resources" in the distribution

View file

@ -10,7 +10,6 @@ import json
import logging # allow-direct-logging import logging # allow-direct-logging
import os import os
import sys import sys
from concurrent.futures import ThreadPoolExecutor
from enum import Enum from enum import Enum
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
@ -148,7 +147,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
self.async_client = AsyncLlamaStackAsLibraryClient( self.async_client = AsyncLlamaStackAsLibraryClient(
config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
) )
self.pool_executor = ThreadPoolExecutor(max_workers=4)
self.provider_data = provider_data self.provider_data = provider_data
self.loop = asyncio.new_event_loop() self.loop = asyncio.new_event_loop()

View file

@ -0,0 +1,233 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any
from pydantic import BaseModel
from llama_stack.apis.prompts import ListPromptsResponse, Prompt, Prompts
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class PromptServiceConfig(BaseModel):
"""Configuration for the built-in prompt service.
:param run_config: Stack run configuration containing distribution info
"""
run_config: StackRunConfig
async def get_provider_impl(config: PromptServiceConfig, deps: dict[Any, Any]):
"""Get the prompt service implementation."""
impl = PromptServiceImpl(config, deps)
await impl.initialize()
return impl
class PromptServiceImpl(Prompts):
"""Built-in prompt service implementation using KVStore."""
def __init__(self, config: PromptServiceConfig, deps: dict[Any, Any]):
self.config = config
self.deps = deps
self.kvstore: KVStore
async def initialize(self) -> None:
kvstore_config = SqliteKVStoreConfig(
db_path=(DISTRIBS_BASE_DIR / self.config.run_config.image_name / "prompts.db").as_posix()
)
self.kvstore = await kvstore_impl(kvstore_config)
def _get_default_key(self, prompt_id: str) -> str:
"""Get the KVStore key that stores the default version number."""
return f"prompts:v1:{prompt_id}:default"
async def _get_prompt_key(self, prompt_id: str, version: int | None = None) -> str:
"""Get the KVStore key for prompt data, returning default version if applicable."""
if version:
return self._get_version_key(prompt_id, str(version))
default_key = self._get_default_key(prompt_id)
resolved_version = await self.kvstore.get(default_key)
if resolved_version is None:
raise ValueError(f"Prompt {prompt_id}:default not found")
return self._get_version_key(prompt_id, resolved_version)
def _get_version_key(self, prompt_id: str, version: str) -> str:
"""Get the KVStore key for a specific prompt version."""
return f"prompts:v1:{prompt_id}:{version}"
def _get_list_key_prefix(self) -> str:
"""Get the key prefix for listing prompts."""
return "prompts:v1:"
def _serialize_prompt(self, prompt: Prompt) -> str:
"""Serialize a prompt to JSON string for storage."""
return json.dumps(
{
"prompt_id": prompt.prompt_id,
"prompt": prompt.prompt,
"version": prompt.version,
"variables": prompt.variables or [],
"is_default": prompt.is_default,
}
)
def _deserialize_prompt(self, data: str) -> Prompt:
"""Deserialize a prompt from JSON string."""
obj = json.loads(data)
return Prompt(
prompt_id=obj["prompt_id"],
prompt=obj["prompt"],
version=obj["version"],
variables=obj.get("variables", []),
is_default=obj.get("is_default", False),
)
async def list_prompts(self) -> ListPromptsResponse:
"""List all prompts (default versions only)."""
prefix = self._get_list_key_prefix()
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
prompts = []
for key in keys:
if key.endswith(":default"):
try:
default_version = await self.kvstore.get(key)
if default_version:
prompt_id = key.replace(prefix, "").replace(":default", "")
version_key = self._get_version_key(prompt_id, default_version)
data = await self.kvstore.get(version_key)
if data:
prompt = self._deserialize_prompt(data)
prompts.append(prompt)
except (json.JSONDecodeError, KeyError):
continue
prompts.sort(key=lambda p: p.prompt_id or "", reverse=True)
return ListPromptsResponse(data=prompts)
async def get_prompt(self, prompt_id: str, version: int | None = None) -> Prompt:
"""Get a prompt by its identifier and optional version."""
key = await self._get_prompt_key(prompt_id, version)
data = await self.kvstore.get(key)
if data is None:
raise ValueError(f"Prompt {prompt_id}:{version if version else 'default'} not found")
return self._deserialize_prompt(data)
async def create_prompt(
self,
prompt: str,
variables: list[str] | None = None,
) -> Prompt:
"""Create a new prompt."""
if variables is None:
variables = []
prompt_obj = Prompt(
prompt_id=Prompt.generate_prompt_id(),
prompt=prompt,
version=1,
variables=variables,
)
version_key = self._get_version_key(prompt_obj.prompt_id, str(prompt_obj.version))
data = self._serialize_prompt(prompt_obj)
await self.kvstore.set(version_key, data)
default_key = self._get_default_key(prompt_obj.prompt_id)
await self.kvstore.set(default_key, str(prompt_obj.version))
return prompt_obj
async def update_prompt(
self,
prompt_id: str,
prompt: str,
version: int,
variables: list[str] | None = None,
set_as_default: bool = True,
) -> Prompt:
"""Update an existing prompt (increments version)."""
if version < 1:
raise ValueError("Version must be >= 1")
if variables is None:
variables = []
prompt_versions = await self.list_prompt_versions(prompt_id)
latest_prompt = max(prompt_versions.data, key=lambda x: int(x.version))
if version and latest_prompt.version != version:
raise ValueError(
f"'{version}' is not the latest prompt version for prompt_id='{prompt_id}'. Use the latest version '{latest_prompt.version}' in request."
)
current_version = latest_prompt.version if version is None else version
new_version = current_version + 1
updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables)
version_key = self._get_version_key(prompt_id, str(new_version))
data = self._serialize_prompt(updated_prompt)
await self.kvstore.set(version_key, data)
if set_as_default:
await self.set_default_version(prompt_id, new_version)
return updated_prompt
async def delete_prompt(self, prompt_id: str) -> None:
"""Delete a prompt and all its versions."""
await self.get_prompt(prompt_id)
prefix = f"prompts:v1:{prompt_id}:"
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
for key in keys:
await self.kvstore.delete(key)
async def list_prompt_versions(self, prompt_id: str) -> ListPromptsResponse:
"""List all versions of a specific prompt."""
prefix = f"prompts:v1:{prompt_id}:"
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
default_version = None
prompts = []
for key in keys:
data = await self.kvstore.get(key)
if key.endswith(":default"):
default_version = data
else:
if data:
prompt_obj = self._deserialize_prompt(data)
prompts.append(prompt_obj)
if not prompts:
raise ValueError(f"Prompt {prompt_id} not found")
for prompt in prompts:
prompt.is_default = str(prompt.version) == default_version
prompts.sort(key=lambda x: x.version)
return ListPromptsResponse(data=prompts)
async def set_default_version(self, prompt_id: str, version: int) -> Prompt:
"""Set which version of a prompt should be the default, If not set. the default is the latest."""
version_key = self._get_version_key(prompt_id, str(version))
data = await self.kvstore.get(version_key)
if data is None:
raise ValueError(f"Prompt {prompt_id} version {version} not found")
default_key = self._get_default_key(prompt_id)
await self.kvstore.set(default_key, str(version))
return self._deserialize_prompt(data)

View file

@ -19,6 +19,7 @@ from llama_stack.apis.inference import Inference, InferenceProvider
from llama_stack.apis.inspect import Inspect from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.prompts import Prompts
from llama_stack.apis.providers import Providers as ProvidersAPI from llama_stack.apis.providers import Providers as ProvidersAPI
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring from llama_stack.apis.scoring import Scoring
@ -93,6 +94,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
Api.tool_groups: ToolGroups, Api.tool_groups: ToolGroups,
Api.tool_runtime: ToolRuntime, Api.tool_runtime: ToolRuntime,
Api.files: Files, Api.files: Files,
Api.prompts: Prompts,
} }
if external_apis: if external_apis:
@ -284,7 +286,15 @@ async def instantiate_providers(
if provider.provider_id is None: if provider.provider_id is None:
continue continue
deps = {a: impls[a] for a in provider.spec.api_dependencies} try:
deps = {a: impls[a] for a in provider.spec.api_dependencies}
except KeyError as e:
missing_api = e.args[0]
raise RuntimeError(
f"Failed to resolve '{provider.spec.api.value}' provider '{provider.provider_id}' of type '{provider.spec.provider_type}': "
f"required dependency '{missing_api.value}' is not available. "
f"Please add a '{missing_api.value}' provider to your configuration or check if the provider is properly configured."
) from e
for a in provider.spec.optional_api_dependencies: for a in provider.spec.optional_api_dependencies:
if a in impls: if a in impls:
deps[a] = impls[a] deps[a] = impls[a]

View file

@ -78,7 +78,10 @@ async def get_auto_router_impl(
# TODO: move pass configs to routers instead # TODO: move pass configs to routers instead
if api == Api.inference and run_config.inference_store: if api == Api.inference and run_config.inference_store:
inference_store = InferenceStore(run_config.inference_store, policy) inference_store = InferenceStore(
config=run_config.inference_store,
policy=policy,
)
await inference_store.initialize() await inference_store.initialize()
api_to_dep_impl["store"] = inference_store api_to_dep_impl["store"] = inference_store

View file

@ -63,7 +63,7 @@ from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.telemetry.tracing import get_current_span from llama_stack.providers.utils.telemetry.tracing import enqueue_event, get_current_span
logger = get_logger(name=__name__, category="core::routers") logger = get_logger(name=__name__, category="core::routers")
@ -90,6 +90,11 @@ class InferenceRouter(Inference):
async def shutdown(self) -> None: async def shutdown(self) -> None:
logger.debug("InferenceRouter.shutdown") logger.debug("InferenceRouter.shutdown")
if self.store:
try:
await self.store.shutdown()
except Exception as e:
logger.warning(f"Error during InferenceStore shutdown: {e}")
async def register_model( async def register_model(
self, self,
@ -160,7 +165,7 @@ class InferenceRouter(Inference):
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model) metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
if self.telemetry: if self.telemetry:
for metric in metrics: for metric in metrics:
await self.telemetry.log_event(metric) enqueue_event(metric)
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics] return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
async def _count_tokens( async def _count_tokens(
@ -431,7 +436,7 @@ class InferenceRouter(Inference):
model=model_obj, model=model_obj,
) )
for metric in metrics: for metric in metrics:
await self.telemetry.log_event(metric) enqueue_event(metric)
# these metrics will show up in the client response. # these metrics will show up in the client response.
response.metrics = ( response.metrics = (
@ -527,7 +532,7 @@ class InferenceRouter(Inference):
# Store the response with the ID that will be returned to the client # Store the response with the ID that will be returned to the client
if self.store: if self.store:
await self.store.store_chat_completion(response, messages) asyncio.create_task(self.store.store_chat_completion(response, messages))
if self.telemetry: if self.telemetry:
metrics = self._construct_metrics( metrics = self._construct_metrics(
@ -537,7 +542,7 @@ class InferenceRouter(Inference):
model=model_obj, model=model_obj,
) )
for metric in metrics: for metric in metrics:
await self.telemetry.log_event(metric) enqueue_event(metric)
# these metrics will show up in the client response. # these metrics will show up in the client response.
response.metrics = ( response.metrics = (
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
@ -664,7 +669,7 @@ class InferenceRouter(Inference):
"completion_tokens", "completion_tokens",
"total_tokens", "total_tokens",
]: # Only log completion and total tokens ]: # Only log completion and total tokens
await self.telemetry.log_event(metric) enqueue_event(metric)
# Return metrics in response # Return metrics in response
async_metrics = [ async_metrics = [
@ -710,7 +715,7 @@ class InferenceRouter(Inference):
) )
for metric in completion_metrics: for metric in completion_metrics:
if metric.metric in ["completion_tokens", "total_tokens"]: # Only log completion and total tokens if metric.metric in ["completion_tokens", "total_tokens"]: # Only log completion and total tokens
await self.telemetry.log_event(metric) enqueue_event(metric)
# Return metrics in response # Return metrics in response
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics] return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
@ -755,7 +760,7 @@ class InferenceRouter(Inference):
choices_data[idx] = { choices_data[idx] = {
"content_parts": [], "content_parts": [],
"tool_calls_builder": {}, "tool_calls_builder": {},
"finish_reason": None, "finish_reason": "stop",
"logprobs_content_parts": [], "logprobs_content_parts": [],
} }
current_choice_data = choices_data[idx] current_choice_data = choices_data[idx]
@ -806,7 +811,7 @@ class InferenceRouter(Inference):
model=model, model=model,
) )
for metric in metrics: for metric in metrics:
await self.telemetry.log_event(metric) enqueue_event(metric)
yield chunk yield chunk
finally: finally:
@ -855,4 +860,4 @@ class InferenceRouter(Inference):
object="chat.completion", object="chat.completion",
) )
logger.debug(f"InferenceRouter.completion_response: {final_response}") logger.debug(f"InferenceRouter.completion_response: {final_response}")
await self.store.store_chat_completion(final_response, messages) asyncio.create_task(self.store.store_chat_completion(final_response, messages))

View file

@ -52,7 +52,6 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
provider_vector_db_id: str | None = None, provider_vector_db_id: str | None = None,
vector_db_name: str | None = None, vector_db_name: str | None = None,
) -> VectorDB: ) -> VectorDB:
provider_vector_db_id = provider_vector_db_id or vector_db_id
if provider_id is None: if provider_id is None:
if len(self.impls_by_provider_id) > 0: if len(self.impls_by_provider_id) > 0:
provider_id = list(self.impls_by_provider_id.keys())[0] provider_id = list(self.impls_by_provider_id.keys())[0]
@ -69,14 +68,33 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding) raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
if "embedding_dimension" not in model.metadata: if "embedding_dimension" not in model.metadata:
raise ValueError(f"Model {embedding_model} does not have an embedding dimension") raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
provider = self.impls_by_provider_id[provider_id]
logger.warning(
"VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly."
)
vector_store = await provider.openai_create_vector_store(
name=vector_db_name or vector_db_id,
embedding_model=embedding_model,
embedding_dimension=model.metadata["embedding_dimension"],
provider_id=provider_id,
provider_vector_db_id=provider_vector_db_id,
)
vector_store_id = vector_store.id
actual_provider_vector_db_id = provider_vector_db_id or vector_store_id
logger.warning(
f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name"
)
vector_db_data = { vector_db_data = {
"identifier": vector_db_id, "identifier": vector_store_id,
"type": ResourceType.vector_db.value, "type": ResourceType.vector_db.value,
"provider_id": provider_id, "provider_id": provider_id,
"provider_resource_id": provider_vector_db_id, "provider_resource_id": actual_provider_vector_db_id,
"embedding_model": embedding_model, "embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"], "embedding_dimension": model.metadata["embedding_dimension"],
"vector_db_name": vector_db_name, "vector_db_name": vector_store.name,
} }
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data) vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
await self.register_object(vector_db) await self.register_object(vector_db)

View file

@ -8,16 +8,18 @@ import ssl
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from asyncio import Lock from asyncio import Lock
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, urljoin, urlparse
import httpx import httpx
from jose import jwt from jose import jwt
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.apis.common.errors import TokenValidationError
from llama_stack.core.datatypes import ( from llama_stack.core.datatypes import (
AuthenticationConfig, AuthenticationConfig,
CustomAuthConfig, CustomAuthConfig,
GitHubTokenAuthConfig, GitHubTokenAuthConfig,
KubernetesAuthProviderConfig,
OAuth2TokenAuthConfig, OAuth2TokenAuthConfig,
User, User,
) )
@ -162,7 +164,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
auth=auth, auth=auth,
timeout=10.0, # Add a reasonable timeout timeout=10.0, # Add a reasonable timeout
) )
if response.status_code != 200: if response.status_code != httpx.codes.OK:
logger.warning(f"Token introspection failed with status code: {response.status_code}") logger.warning(f"Token introspection failed with status code: {response.status_code}")
raise ValueError(f"Token introspection failed: {response.status_code}") raise ValueError(f"Token introspection failed: {response.status_code}")
@ -272,7 +274,7 @@ class CustomAuthProvider(AuthProvider):
json=auth_request.model_dump(), json=auth_request.model_dump(),
timeout=10.0, # Add a reasonable timeout timeout=10.0, # Add a reasonable timeout
) )
if response.status_code != 200: if response.status_code != httpx.codes.OK:
logger.warning(f"Authentication failed with status code: {response.status_code}") logger.warning(f"Authentication failed with status code: {response.status_code}")
raise ValueError(f"Authentication failed: {response.status_code}") raise ValueError(f"Authentication failed: {response.status_code}")
@ -374,6 +376,89 @@ async def _get_github_user_info(access_token: str, github_api_base_url: str) ->
} }
class KubernetesAuthProvider(AuthProvider):
"""
Kubernetes authentication provider that validates tokens using the Kubernetes SelfSubjectReview API.
This provider integrates with Kubernetes API server by using the
/apis/authentication.k8s.io/v1/selfsubjectreviews endpoint to validate tokens and extract user information.
"""
def __init__(self, config: KubernetesAuthProviderConfig):
self.config = config
def _httpx_verify_value(self) -> bool | str:
"""
Build the value for httpx's `verify` parameter.
- False disables verification.
- Path string points to a CA bundle.
- True uses system defaults.
"""
if not self.config.verify_tls:
return False
if self.config.tls_cafile:
return self.config.tls_cafile.as_posix()
return True
async def validate_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a token using Kubernetes SelfSubjectReview API endpoint."""
# Build the Kubernetes SelfSubjectReview API endpoint URL
review_api_url = urljoin(self.config.api_server_url, "/apis/authentication.k8s.io/v1/selfsubjectreviews")
# Create SelfSubjectReview request body
review_request = {"apiVersion": "authentication.k8s.io/v1", "kind": "SelfSubjectReview"}
verify = self._httpx_verify_value()
try:
async with httpx.AsyncClient(verify=verify, timeout=10.0) as client:
response = await client.post(
review_api_url,
json=review_request,
headers={
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
},
)
if response.status_code == httpx.codes.UNAUTHORIZED:
raise TokenValidationError("Invalid token")
if response.status_code != httpx.codes.CREATED:
logger.warning(f"Kubernetes SelfSubjectReview API failed with status code: {response.status_code}")
raise TokenValidationError(f"Token validation failed: {response.status_code}")
review_response = response.json()
# Extract user information from SelfSubjectReview response
status = review_response.get("status", {})
if not status:
raise ValueError("No status found in SelfSubjectReview response")
user_info = status.get("userInfo", {})
if not user_info:
raise ValueError("No userInfo found in SelfSubjectReview response")
username = user_info.get("username")
if not username:
raise ValueError("No username found in SelfSubjectReview response")
# Build user attributes from Kubernetes user info
user_attributes = get_attributes_from_claims(user_info, self.config.claims_mapping)
return User(
principal=username,
attributes=user_attributes,
)
except httpx.TimeoutException:
logger.warning("Kubernetes SelfSubjectReview API request timed out")
raise ValueError("Token validation timeout") from None
except Exception as e:
logger.warning(f"Error during token validation: {str(e)}")
raise ValueError(f"Token validation error: {str(e)}") from e
async def close(self):
"""Close any resources."""
pass
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider: def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
"""Factory function to create the appropriate auth provider.""" """Factory function to create the appropriate auth provider."""
provider_config = config.provider_config provider_config = config.provider_config
@ -384,5 +469,7 @@ def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
return OAuth2TokenAuthProvider(provider_config) return OAuth2TokenAuthProvider(provider_config)
elif isinstance(provider_config, GitHubTokenAuthConfig): elif isinstance(provider_config, GitHubTokenAuthConfig):
return GitHubTokenAuthProvider(provider_config) return GitHubTokenAuthProvider(provider_config)
elif isinstance(provider_config, KubernetesAuthProviderConfig):
return KubernetesAuthProvider(provider_config)
else: else:
raise ValueError(f"Unknown authentication provider config type: {type(provider_config)}") raise ValueError(f"Unknown authentication provider config type: {type(provider_config)}")

View file

@ -132,15 +132,17 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
}, },
) )
elif isinstance(exc, ConflictError): elif isinstance(exc, ConflictError):
return HTTPException(status_code=409, detail=str(exc)) return HTTPException(status_code=httpx.codes.CONFLICT, detail=str(exc))
elif isinstance(exc, ResourceNotFoundError): elif isinstance(exc, ResourceNotFoundError):
return HTTPException(status_code=404, detail=str(exc)) return HTTPException(status_code=httpx.codes.NOT_FOUND, detail=str(exc))
elif isinstance(exc, ValueError): elif isinstance(exc, ValueError):
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}") return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}")
elif isinstance(exc, BadRequestError): elif isinstance(exc, BadRequestError):
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=str(exc)) return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=str(exc))
elif isinstance(exc, PermissionError | AccessDeniedError): elif isinstance(exc, PermissionError | AccessDeniedError):
return HTTPException(status_code=httpx.codes.FORBIDDEN, detail=f"Permission denied: {str(exc)}") return HTTPException(status_code=httpx.codes.FORBIDDEN, detail=f"Permission denied: {str(exc)}")
elif isinstance(exc, ConnectionError | httpx.ConnectError):
return HTTPException(status_code=httpx.codes.BAD_GATEWAY, detail=str(exc))
elif isinstance(exc, asyncio.TimeoutError | TimeoutError): elif isinstance(exc, asyncio.TimeoutError | TimeoutError):
return HTTPException(status_code=httpx.codes.GATEWAY_TIMEOUT, detail=f"Operation timed out: {str(exc)}") return HTTPException(status_code=httpx.codes.GATEWAY_TIMEOUT, detail=f"Operation timed out: {str(exc)}")
elif isinstance(exc, NotImplementedError): elif isinstance(exc, NotImplementedError):
@ -513,6 +515,7 @@ def main(args: argparse.Namespace | None = None):
apis_to_serve.add("inspect") apis_to_serve.add("inspect")
apis_to_serve.add("providers") apis_to_serve.add("providers")
apis_to_serve.add("prompts")
for api_str in apis_to_serve: for api_str in apis_to_serve:
api = Api(api_str) api = Api(api_str)

View file

@ -24,6 +24,7 @@ from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.prompts import Prompts
from llama_stack.apis.providers import Providers from llama_stack.apis.providers import Providers
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring from llama_stack.apis.scoring import Scoring
@ -37,6 +38,7 @@ from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.datatypes import Provider, StackRunConfig from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_provider_registry from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
from llama_stack.core.providers import ProviderImpl, ProviderImplConfig from llama_stack.core.providers import ProviderImpl, ProviderImplConfig
from llama_stack.core.resolver import ProviderRegistry, resolve_impls from llama_stack.core.resolver import ProviderRegistry, resolve_impls
from llama_stack.core.routing_tables.common import CommonRoutingTableImpl from llama_stack.core.routing_tables.common import CommonRoutingTableImpl
@ -72,6 +74,7 @@ class LlamaStack(
ToolRuntime, ToolRuntime,
RAGToolRuntime, RAGToolRuntime,
Files, Files,
Prompts,
): ):
pass pass
@ -105,12 +108,12 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
method = getattr(impls[api], register_method) method = getattr(impls[api], register_method)
for obj in objects: for obj in objects:
logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}") if hasattr(obj, "provider_id"):
# Do not register models on disabled providers
# Do not register models on disabled providers if not obj.provider_id or obj.provider_id == "__disabled__":
if hasattr(obj, "provider_id") and (not obj.provider_id or obj.provider_id == "__disabled__"): logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.") continue
continue logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
# we want to maintain the type information in arguments to method. # we want to maintain the type information in arguments to method.
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict, # instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
@ -225,7 +228,10 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
try: try:
result = re.sub(pattern, get_env_var, config) result = re.sub(pattern, get_env_var, config)
return _convert_string_to_proper_type(result) # Only apply type conversion if substitution actually happened
if result != config:
return _convert_string_to_proper_type(result)
return result
except EnvVarError as e: except EnvVarError as e:
raise EnvVarError(e.var_name, e.path) from None raise EnvVarError(e.var_name, e.path) from None
@ -302,6 +308,12 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
) )
impls[Api.providers] = providers_impl impls[Api.providers] = providers_impl
prompts_impl = PromptServiceImpl(
PromptServiceConfig(run_config=run_config),
deps=impls,
)
impls[Api.prompts] = prompts_impl
# Produces a stack of providers for the given run config. Not all APIs may be # Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config. # asked for in the run config.
@ -326,6 +338,9 @@ async def construct_stack(
# Add internal implementations after all other providers are resolved # Add internal implementations after all other providers are resolved
add_internal_implementations(impls, run_config) add_internal_implementations(impls, run_config)
if Api.prompts in impls:
await impls[Api.prompts].initialize()
await register_resources(run_config, impls) await register_resources(run_config, impls)
await refresh_registry_once(impls) await refresh_registry_once(impls)

View file

@ -34,7 +34,7 @@ distribution_spec:
telemetry: telemetry:
- provider_type: inline::meta-reference - provider_type: inline::meta-reference
post_training: post_training:
- provider_type: inline::huggingface-cpu - provider_type: inline::torchtune-cpu
eval: eval:
- provider_type: inline::meta-reference - provider_type: inline::meta-reference
datasetio: datasetio:

View file

@ -11,9 +11,7 @@ from ..starter.starter import get_distribution_template as get_starter_distribut
def get_distribution_template() -> DistributionTemplate: def get_distribution_template() -> DistributionTemplate:
template = get_starter_distribution_template() template = get_starter_distribution_template(name="ci-tests")
name = "ci-tests"
template.name = name
template.description = "CI tests for Llama Stack" template.description = "CI tests for Llama Stack"
return template return template

View file

@ -89,28 +89,28 @@ providers:
config: config:
kvstore: kvstore:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/faiss_store.db
- provider_id: sqlite-vec - provider_id: sqlite-vec
provider_type: inline::sqlite-vec provider_type: inline::sqlite-vec
config: config:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sqlite_vec.db
kvstore: kvstore:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sqlite_vec_registry.db
- provider_id: ${env.MILVUS_URL:+milvus} - provider_id: ${env.MILVUS_URL:+milvus}
provider_type: inline::milvus provider_type: inline::milvus
config: config:
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/ci-tests}/milvus.db
kvstore: kvstore:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/milvus_registry.db
- provider_id: ${env.CHROMADB_URL:+chromadb} - provider_id: ${env.CHROMADB_URL:+chromadb}
provider_type: remote::chromadb provider_type: remote::chromadb
config: config:
url: ${env.CHROMADB_URL:=} url: ${env.CHROMADB_URL:=}
kvstore: kvstore:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests/}/chroma_remote_registry.db
- provider_id: ${env.PGVECTOR_DB:+pgvector} - provider_id: ${env.PGVECTOR_DB:+pgvector}
provider_type: remote::pgvector provider_type: remote::pgvector
config: config:
@ -121,15 +121,15 @@ providers:
password: ${env.PGVECTOR_PASSWORD:=} password: ${env.PGVECTOR_PASSWORD:=}
kvstore: kvstore:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/pgvector_registry.db
files: files:
- provider_id: meta-reference-files - provider_id: meta-reference-files
provider_type: inline::localfs provider_type: inline::localfs
config: config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files} storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/ci-tests/files}
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/files_metadata.db
safety: safety:
- provider_id: llama-guard - provider_id: llama-guard
provider_type: inline::llama-guard provider_type: inline::llama-guard
@ -156,13 +156,10 @@ providers:
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/trace_store.db sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=} otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
post_training: post_training:
- provider_id: huggingface-cpu - provider_id: torchtune-cpu
provider_type: inline::huggingface-cpu provider_type: inline::torchtune-cpu
config: config:
checkpoint_format: huggingface checkpoint_format: meta
distributed_backend: null
device: cpu
dpo_output_dir: ~/.llama/distributions/ci-tests/dpo_output
eval: eval:
- provider_id: meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference provider_type: inline::meta-reference

View file

@ -1,7 +1,7 @@
--- ---
orphan: true orphan: true
--- ---
# Meta Reference Distribution # Meta Reference GPU Distribution
```{toctree} ```{toctree}
:maxdepth: 2 :maxdepth: 2
@ -29,7 +29,7 @@ The following environment variables can be configured:
## Prerequisite: Downloading Models ## Prerequisite: Downloading Models
Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints. Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](../../references/llama_cli_reference/download_models.md) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
``` ```
$ llama model list --downloaded $ llama model list --downloaded

View file

@ -134,6 +134,11 @@ models:
provider_id: nvidia provider_id: nvidia
provider_model_id: meta/llama-3.3-70b-instruct provider_model_id: meta/llama-3.3-70b-instruct
model_type: llm model_type: llm
- metadata: {}
model_id: nvidia/vila
provider_id: nvidia
provider_model_id: nvidia/vila
model_type: llm
- metadata: - metadata:
embedding_dimension: 2048 embedding_dimension: 2048
context_length: 8192 context_length: 8192

View file

@ -43,7 +43,7 @@ def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderMo
"openai", "openai",
[ [
ProviderModelEntry( ProviderModelEntry(
provider_model_id="openai/gpt-4o", provider_model_id="gpt-4o",
model_type=ModelType.llm, model_type=ModelType.llm,
) )
], ],
@ -53,7 +53,7 @@ def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderMo
"anthropic", "anthropic",
[ [
ProviderModelEntry( ProviderModelEntry(
provider_model_id="anthropic/claude-3-5-sonnet-latest", provider_model_id="claude-3-5-sonnet-latest",
model_type=ModelType.llm, model_type=ModelType.llm,
) )
], ],
@ -206,13 +206,6 @@ def get_distribution_template() -> DistributionTemplate:
uri="huggingface://datasets/llamastack/math_500?split=test", uri="huggingface://datasets/llamastack/math_500?split=test",
), ),
), ),
DatasetInput(
dataset_id="bfcl",
purpose=DatasetPurpose.eval_messages_answer,
source=URIDataSource(
uri="huggingface://datasets/llamastack/bfcl_v3?split=train",
),
),
DatasetInput( DatasetInput(
dataset_id="ifeval", dataset_id="ifeval",
purpose=DatasetPurpose.eval_messages_answer, purpose=DatasetPurpose.eval_messages_answer,
@ -250,11 +243,6 @@ def get_distribution_template() -> DistributionTemplate:
dataset_id="math_500", dataset_id="math_500",
scoring_functions=["basic::regex_parser_math_response"], scoring_functions=["basic::regex_parser_math_response"],
), ),
BenchmarkInput(
benchmark_id="meta-reference-bfcl",
dataset_id="bfcl",
scoring_functions=["basic::bfcl"],
),
BenchmarkInput( BenchmarkInput(
benchmark_id="meta-reference-ifeval", benchmark_id="meta-reference-ifeval",
dataset_id="ifeval", dataset_id="ifeval",

View file

@ -136,14 +136,14 @@ inference_store:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/inference_store.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/inference_store.db
models: models:
- metadata: {} - metadata: {}
model_id: openai/gpt-4o model_id: gpt-4o
provider_id: openai provider_id: openai
provider_model_id: openai/gpt-4o provider_model_id: gpt-4o
model_type: llm model_type: llm
- metadata: {} - metadata: {}
model_id: anthropic/claude-3-5-sonnet-latest model_id: claude-3-5-sonnet-latest
provider_id: anthropic provider_id: anthropic
provider_model_id: anthropic/claude-3-5-sonnet-latest provider_model_id: claude-3-5-sonnet-latest
model_type: llm model_type: llm
- metadata: {} - metadata: {}
model_id: gemini/gemini-1.5-flash model_id: gemini/gemini-1.5-flash
@ -188,12 +188,6 @@ datasets:
uri: huggingface://datasets/llamastack/math_500?split=test uri: huggingface://datasets/llamastack/math_500?split=test
metadata: {} metadata: {}
dataset_id: math_500 dataset_id: math_500
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/bfcl_v3?split=train
metadata: {}
dataset_id: bfcl
- purpose: eval/messages-answer - purpose: eval/messages-answer
source: source:
type: uri type: uri
@ -228,11 +222,6 @@ benchmarks:
- basic::regex_parser_math_response - basic::regex_parser_math_response
metadata: {} metadata: {}
benchmark_id: meta-reference-math-500 benchmark_id: meta-reference-math-500
- dataset_id: bfcl
scoring_functions:
- basic::bfcl
metadata: {}
benchmark_id: meta-reference-bfcl
- dataset_id: ifeval - dataset_id: ifeval
scoring_functions: scoring_functions:
- basic::ifeval - basic::ifeval

View file

@ -35,7 +35,7 @@ distribution_spec:
telemetry: telemetry:
- provider_type: inline::meta-reference - provider_type: inline::meta-reference
post_training: post_training:
- provider_type: inline::torchtune-gpu - provider_type: inline::huggingface-gpu
eval: eval:
- provider_type: inline::meta-reference - provider_type: inline::meta-reference
datasetio: datasetio:

View file

@ -89,28 +89,28 @@ providers:
config: config:
kvstore: kvstore:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/faiss_store.db
- provider_id: sqlite-vec - provider_id: sqlite-vec
provider_type: inline::sqlite-vec provider_type: inline::sqlite-vec
config: config:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/sqlite_vec.db
kvstore: kvstore:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/sqlite_vec_registry.db
- provider_id: ${env.MILVUS_URL:+milvus} - provider_id: ${env.MILVUS_URL:+milvus}
provider_type: inline::milvus provider_type: inline::milvus
config: config:
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter-gpu}/milvus.db
kvstore: kvstore:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/milvus_registry.db
- provider_id: ${env.CHROMADB_URL:+chromadb} - provider_id: ${env.CHROMADB_URL:+chromadb}
provider_type: remote::chromadb provider_type: remote::chromadb
config: config:
url: ${env.CHROMADB_URL:=} url: ${env.CHROMADB_URL:=}
kvstore: kvstore:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu/}/chroma_remote_registry.db
- provider_id: ${env.PGVECTOR_DB:+pgvector} - provider_id: ${env.PGVECTOR_DB:+pgvector}
provider_type: remote::pgvector provider_type: remote::pgvector
config: config:
@ -121,15 +121,15 @@ providers:
password: ${env.PGVECTOR_PASSWORD:=} password: ${env.PGVECTOR_PASSWORD:=}
kvstore: kvstore:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/pgvector_registry.db
files: files:
- provider_id: meta-reference-files - provider_id: meta-reference-files
provider_type: inline::localfs provider_type: inline::localfs
config: config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files} storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter-gpu/files}
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/files_metadata.db
safety: safety:
- provider_id: llama-guard - provider_id: llama-guard
provider_type: inline::llama-guard provider_type: inline::llama-guard
@ -156,10 +156,13 @@ providers:
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/trace_store.db sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=} otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
post_training: post_training:
- provider_id: torchtune-gpu - provider_id: huggingface-gpu
provider_type: inline::torchtune-gpu provider_type: inline::huggingface-gpu
config: config:
checkpoint_format: meta checkpoint_format: huggingface
distributed_backend: null
device: cpu
dpo_output_dir: ~/.llama/distributions/starter-gpu/dpo_output
eval: eval:
- provider_id: meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference provider_type: inline::meta-reference

View file

@ -11,12 +11,10 @@ from ..starter.starter import get_distribution_template as get_starter_distribut
def get_distribution_template() -> DistributionTemplate: def get_distribution_template() -> DistributionTemplate:
template = get_starter_distribution_template() template = get_starter_distribution_template(name="starter-gpu")
name = "starter-gpu"
template.name = name
template.description = "Quick start template for running Llama Stack with several popular providers. This distribution is intended for GPU-enabled environments." template.description = "Quick start template for running Llama Stack with several popular providers. This distribution is intended for GPU-enabled environments."
template.providers["post_training"] = [ template.providers["post_training"] = [
BuildProvider(provider_type="inline::torchtune-gpu"), BuildProvider(provider_type="inline::huggingface-gpu"),
] ]
return template return template

View file

@ -35,7 +35,7 @@ distribution_spec:
telemetry: telemetry:
- provider_type: inline::meta-reference - provider_type: inline::meta-reference
post_training: post_training:
- provider_type: inline::huggingface-cpu - provider_type: inline::torchtune-cpu
eval: eval:
- provider_type: inline::meta-reference - provider_type: inline::meta-reference
datasetio: datasetio:

View file

@ -156,13 +156,10 @@ providers:
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/trace_store.db sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=} otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
post_training: post_training:
- provider_id: huggingface-cpu - provider_id: torchtune-cpu
provider_type: inline::huggingface-cpu provider_type: inline::torchtune-cpu
config: config:
checkpoint_format: huggingface checkpoint_format: meta
distributed_backend: null
device: cpu
dpo_output_dir: ~/.llama/distributions/starter/dpo_output
eval: eval:
- provider_id: meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference provider_type: inline::meta-reference

View file

@ -99,9 +99,8 @@ def get_remote_inference_providers() -> list[Provider]:
return inference_providers return inference_providers
def get_distribution_template() -> DistributionTemplate: def get_distribution_template(name: str = "starter") -> DistributionTemplate:
remote_inference_providers = get_remote_inference_providers() remote_inference_providers = get_remote_inference_providers()
name = "starter"
providers = { providers = {
"inference": [BuildProvider(provider_type=p.provider_type, module=p.module) for p in remote_inference_providers] "inference": [BuildProvider(provider_type=p.provider_type, module=p.module) for p in remote_inference_providers]
@ -120,7 +119,7 @@ def get_distribution_template() -> DistributionTemplate:
], ],
"agents": [BuildProvider(provider_type="inline::meta-reference")], "agents": [BuildProvider(provider_type="inline::meta-reference")],
"telemetry": [BuildProvider(provider_type="inline::meta-reference")], "telemetry": [BuildProvider(provider_type="inline::meta-reference")],
"post_training": [BuildProvider(provider_type="inline::huggingface-cpu")], "post_training": [BuildProvider(provider_type="inline::torchtune-cpu")],
"eval": [BuildProvider(provider_type="inline::meta-reference")], "eval": [BuildProvider(provider_type="inline::meta-reference")],
"datasetio": [ "datasetio": [
BuildProvider(provider_type="remote::huggingface"), BuildProvider(provider_type="remote::huggingface"),

View file

@ -178,9 +178,9 @@ class ReferenceBatchesImpl(Batches):
# TODO: set expiration time for garbage collection # TODO: set expiration time for garbage collection
if endpoint not in ["/v1/chat/completions"]: if endpoint not in ["/v1/chat/completions", "/v1/completions"]:
raise ValueError( raise ValueError(
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint", f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint",
) )
if completion_window != "24h": if completion_window != "24h":
@ -424,13 +424,21 @@ class ReferenceBatchesImpl(Batches):
) )
valid = False valid = False
for param, expected_type, type_string in [ if batch.endpoint == "/v1/chat/completions":
("model", str, "a string"), required_params = [
# messages is specific to /v1/chat/completions ("model", str, "a string"),
# we could skip validating messages here and let inference fail. however, # messages is specific to /v1/chat/completions
# that would be a very expensive way to find out messages is wrong. # we could skip validating messages here and let inference fail. however,
("messages", list, "an array"), # TODO: allow messages to be a string? # that would be a very expensive way to find out messages is wrong.
]: ("messages", list, "an array"), # TODO: allow messages to be a string?
]
else: # /v1/completions
required_params = [
("model", str, "a string"),
("prompt", str, "a string"), # TODO: allow prompt to be a list of strings??
]
for param, expected_type, type_string in required_params:
if param not in body: if param not in body:
errors.append( errors.append(
BatchError( BatchError(
@ -591,20 +599,37 @@ class ReferenceBatchesImpl(Batches):
try: try:
# TODO(SECURITY): review body for security issues # TODO(SECURITY): review body for security issues
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]] if request.url == "/v1/chat/completions":
chat_response = await self.inference_api.openai_chat_completion(**request.body) request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
chat_response = await self.inference_api.openai_chat_completion(**request.body)
# this is for mypy, we don't allow streaming so we'll get the right type # this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method" assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
return { return {
"id": request_id, "id": request_id,
"custom_id": request.custom_id, "custom_id": request.custom_id,
"response": { "response": {
"status_code": 200, "status_code": 200,
"request_id": request_id, # TODO: should this be different? "request_id": request_id, # TODO: should this be different?
"body": chat_response.model_dump_json(), "body": chat_response.model_dump_json(),
}, },
} }
else: # /v1/completions
completion_response = await self.inference_api.openai_completion(**request.body)
# this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(completion_response, "model_dump_json"), (
"Completion response must have model_dump_json method"
)
return {
"id": request_id,
"custom_id": request.custom_id,
"response": {
"status_code": 200,
"request_id": request_id,
"body": completion_response.model_dump_json(),
},
}
except Exception as e: except Exception as e:
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}") logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
return { return {

View file

@ -86,11 +86,16 @@ class LocalfsFilesImpl(Files):
self, self,
file: Annotated[UploadFile, File()], file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()], purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after_anchor: Annotated[str | None, Form(alias="expires_after[anchor]")] = None,
expires_after_seconds: Annotated[int | None, Form(alias="expires_after[seconds]")] = None,
) -> OpenAIFileObject: ) -> OpenAIFileObject:
"""Upload a file that can be used across various endpoints.""" """Upload a file that can be used across various endpoints."""
if not self.sql_store: if not self.sql_store:
raise RuntimeError("Files provider not initialized") raise RuntimeError("Files provider not initialized")
if expires_after_anchor is not None or expires_after_seconds is not None:
raise NotImplementedError("File expiration is not supported by this provider")
file_id = self._generate_file_id() file_id = self._generate_file_id()
file_path = self._get_file_path(file_id) file_path = self._get_file_path(file_id)

View file

@ -22,7 +22,6 @@ from llama_stack.providers.utils.common.data_schema_validator import (
) )
from .config import BasicScoringConfig from .config import BasicScoringConfig
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn
from .scoring_fn.equality_scoring_fn import EqualityScoringFn from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.ifeval_scoring_fn import IfEvalScoringFn from .scoring_fn.ifeval_scoring_fn import IfEvalScoringFn
@ -37,7 +36,6 @@ FIXED_FNS = [
SubsetOfScoringFn, SubsetOfScoringFn,
RegexParserScoringFn, RegexParserScoringFn,
RegexParserMathResponseScoringFn, RegexParserMathResponseScoringFn,
BFCLScoringFn,
IfEvalScoringFn, IfEvalScoringFn,
DocVQAScoringFn, DocVQAScoringFn,
] ]

View file

@ -1,93 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import re
from typing import Any
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from ..utils.bfcl.ast_parser import decode_ast
from ..utils.bfcl.checker import ast_checker, is_empty_output
from .fn_defs.bfcl import bfcl
def postprocess(x: dict[str, Any], test_category: str) -> dict[str, Any]:
contain_func_call = False
error = None
error_type = None
checker_result = {}
try:
prediction = decode_ast(x["generated_answer"], x["language"]) or ""
contain_func_call = True
# if not is_function_calling_format_output(prediction):
if is_empty_output(prediction):
contain_func_call = False
error = "Did not output in the specified format. Note: the model_result is wrapped in a string to ensure json serializability."
error_type = "ast_decoder:decoder_wrong_output_format"
else:
checker_result = ast_checker(
json.loads(x["function"]),
prediction,
json.loads(x["ground_truth"]),
x["language"],
test_category=test_category,
model_name="",
)
except Exception as e:
prediction = ""
error = f"Invalid syntax. Failed to decode AST. {str(e)}"
error_type = "ast_decoder:decoder_failed"
return {
"prediction": prediction,
"contain_func_call": contain_func_call,
"valid": checker_result.get("valid", False),
"error": error or checker_result.get("error", ""),
"error_type": error_type or checker_result.get("error_type", ""),
}
def gen_valid(x: dict[str, Any]) -> dict[str, float]:
return {"valid": x["valid"]}
def gen_relevance_acc(x: dict[str, Any]) -> dict[str, float]:
# This function serves for both relevance and irrelevance tests, which share the exact opposite logic.
# If `test_category` is "irrelevance", the model is expected to output no function call.
# No function call means either the AST decoding fails (a error message is generated) or the decoded AST does not contain any function call (such as a empty list, `[]`).
# If `test_category` is "relevance", the model is expected to output to a function call, and empty list doesn't count as a function call.
acc = not x["contain_func_call"] if "irrelevance" in x["id"] else x["contain_func_call"]
return {"valid": float(acc)}
class BFCLScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn for BFCL
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
bfcl.identifier: bfcl,
}
async def score_row(
self,
input_row: dict[str, Any],
scoring_fn_identifier: str | None = "bfcl",
scoring_params: ScoringFnParams | None = None,
) -> ScoringResultRow:
test_category = re.sub(r"_[0-9_-]+$", "", input_row["id"])
score_result = postprocess(input_row, test_category)
if test_category in {"irrelevance", "live_relevance", "live_irrelevance"}:
score = gen_relevance_acc(score_result)["valid"]
else:
score = gen_valid(score_result)["valid"]
return {
"score": float(score),
}

View file

@ -1,21 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
bfcl = ScoringFn(
identifier="basic::bfcl",
description="BFCL complex scoring",
return_type=NumberType(),
provider_id="basic",
provider_resource_id="bfcl",
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
)

View file

@ -1,296 +0,0 @@
# ruff: noqa
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import ast
from .tree_sitter import get_parser
def parse_java_function_call(source_code):
if not source_code.endswith(";"):
source_code += ";" # Necessary for the parser not to register an error
parser = get_parser("java")
tree = parser.parse(bytes(source_code, "utf8"))
root_node = tree.root_node
if root_node.has_error:
raise Exception("Error parsing java the source code.")
def get_text(node):
"""Returns the text represented by the node."""
return source_code[node.start_byte : node.end_byte]
def traverse_node(node, nested=False):
if node.type == "string_literal":
if nested:
return get_text(node)
# Strip surrounding quotes from string literals
return get_text(node)[1:-1]
elif node.type == "character_literal":
if nested:
return get_text(node)
# Strip surrounding single quotes from character literals
return get_text(node)[1:-1]
"""Traverse the node to collect texts for complex structures."""
if node.type in [
"identifier",
"class_literal",
"type_identifier",
"method_invocation",
]:
return get_text(node)
elif node.type == "array_creation_expression":
# Handle array creation expression specifically
type_node = node.child_by_field_name("type")
value_node = node.child_by_field_name("value")
type_text = traverse_node(type_node, True)
value_text = traverse_node(value_node, True)
return f"new {type_text}[]{value_text}"
elif node.type == "object_creation_expression":
# Handle object creation expression specifically
type_node = node.child_by_field_name("type")
arguments_node = node.child_by_field_name("arguments")
type_text = traverse_node(type_node, True)
if arguments_node:
# Process each argument carefully, avoiding unnecessary punctuation
argument_texts = []
for child in arguments_node.children:
if child.type not in [
",",
"(",
")",
]: # Exclude commas and parentheses
argument_text = traverse_node(child, True)
argument_texts.append(argument_text)
arguments_text = ", ".join(argument_texts)
return f"new {type_text}({arguments_text})"
else:
return f"new {type_text}()"
elif node.type == "set":
# Handling sets specifically
items = [traverse_node(n, True) for n in node.children if n.type not in [",", "set"]]
return "{" + ", ".join(items) + "}"
elif node.child_count > 0:
return "".join(traverse_node(child, True) for child in node.children)
else:
return get_text(node)
def extract_arguments(args_node):
arguments = {}
for child in args_node.children:
if child.type == "assignment_expression":
# For named parameters
name_node, value_node = child.children[0], child.children[2]
name = get_text(name_node)
value = traverse_node(value_node)
if name in arguments:
if not isinstance(arguments[name], list):
arguments[name] = [arguments[name]]
arguments[name].append(value)
else:
arguments[name] = value
# arguments.append({'name': name, 'value': value})
elif child.type in ["identifier", "class_literal", "set"]:
# For unnamed parameters and handling sets
value = traverse_node(child)
if None in arguments:
if not isinstance(arguments[None], list):
arguments[None] = [arguments[None]]
arguments[None].append(value)
else:
arguments[None] = value
return arguments
def traverse(node):
if node.type == "method_invocation":
# Extract the function name and its arguments
method_name = get_text(node.child_by_field_name("name"))
class_name_node = node.child_by_field_name("object")
if class_name_node:
class_name = get_text(class_name_node)
function_name = f"{class_name}.{method_name}"
else:
function_name = method_name
arguments_node = node.child_by_field_name("arguments")
if arguments_node:
arguments = extract_arguments(arguments_node)
for key, value in arguments.items():
if isinstance(value, list):
raise Exception("Error: Multiple arguments with the same name are not supported.")
return [{function_name: arguments}]
else:
for child in node.children:
result = traverse(child)
if result:
return result
result = traverse(root_node)
return result if result else {}
def parse_javascript_function_call(source_code):
if not source_code.endswith(";"):
source_code += ";" # Necessary for the parser not to register an error
parser = get_parser("javascript")
# Parse the source code
tree = parser.parse(bytes(source_code, "utf8"))
root_node = tree.root_node
if root_node.has_error:
raise Exception("Error js parsing the source code.")
# Function to recursively extract argument details
def extract_arguments(node):
args = {}
for child in node.children:
if child.type == "assignment_expression":
# Extract left (name) and right (value) parts of the assignment
name = child.children[0].text.decode("utf-8")
value = child.children[2].text.decode("utf-8")
if (value.startswith('"') and value.endswith('"')) or (value.startswith("'") and value.endswith("'")):
value = value[1:-1] # Trim the quotation marks
if name in args:
if not isinstance(args[name], list):
args[name] = [args[name]]
args[name].append(value)
else:
args[name] = value
elif child.type == "identifier" or child.type == "true":
# Handle non-named arguments and boolean values
value = child.text.decode("utf-8")
if None in args:
if not isinstance(args[None], list):
args[None] = [args[None]]
args[None].append(value)
else:
args[None] = value
return args
# Find the function call and extract its name and arguments
if root_node.type == "program":
for child in root_node.children:
if child.type == "expression_statement":
for sub_child in child.children:
if sub_child.type == "call_expression":
function_name = sub_child.children[0].text.decode("utf8")
arguments_node = sub_child.children[1]
parameters = extract_arguments(arguments_node)
for key, value in parameters.items():
if isinstance(value, list):
raise Exception("Error: Multiple arguments with the same name are not supported.")
result = [{function_name: parameters}]
return result
def ast_parse(input_str, language="Python"):
if language == "Python":
cleaned_input = input_str.strip("[]'")
parsed = ast.parse(cleaned_input, mode="eval")
extracted = []
if isinstance(parsed.body, ast.Call):
extracted.append(resolve_ast_call(parsed.body))
else:
for elem in parsed.body.elts:
extracted.append(resolve_ast_call(elem))
return extracted
elif language == "Java":
return parse_java_function_call(input_str[1:-1]) # Remove the [ and ] from the string
elif language == "JavaScript":
return parse_javascript_function_call(input_str[1:-1])
else:
raise NotImplementedError(f"Unsupported language: {language}")
def resolve_ast_call(elem):
# Handle nested attributes for deeply nested module paths
func_parts = []
func_part = elem.func
while isinstance(func_part, ast.Attribute):
func_parts.append(func_part.attr)
func_part = func_part.value
if isinstance(func_part, ast.Name):
func_parts.append(func_part.id)
func_name = ".".join(reversed(func_parts))
args_dict = {}
# Parse when args are simply passed as an unnamed dictionary arg
for arg in elem.args:
if isinstance(arg, ast.Dict):
for key, value in zip(arg.keys, arg.values):
if isinstance(key, ast.Constant):
arg_name = key.value
output = resolve_ast_by_type(value)
args_dict[arg_name] = output
for arg in elem.keywords:
output = resolve_ast_by_type(arg.value)
args_dict[arg.arg] = output
return {func_name: args_dict}
def resolve_ast_by_type(value):
if isinstance(value, ast.Constant):
if value.value is Ellipsis:
output = "..."
else:
output = value.value
elif isinstance(value, ast.UnaryOp):
output = -value.operand.value
elif isinstance(value, ast.List):
output = [resolve_ast_by_type(v) for v in value.elts]
elif isinstance(value, ast.Dict):
output = {resolve_ast_by_type(k): resolve_ast_by_type(v) for k, v in zip(value.keys, value.values)}
elif isinstance(value, ast.NameConstant): # Added this condition to handle boolean values
output = value.value
elif isinstance(value, ast.BinOp): # Added this condition to handle function calls as arguments
output = eval(ast.unparse(value))
elif isinstance(value, ast.Name):
output = value.id
elif isinstance(value, ast.Call):
if len(value.keywords) == 0:
output = ast.unparse(value)
else:
output = resolve_ast_call(value)
elif isinstance(value, ast.Tuple):
output = tuple(resolve_ast_by_type(v) for v in value.elts)
elif isinstance(value, ast.Lambda):
output = eval(ast.unparse(value.body[0].value))
elif isinstance(value, ast.Ellipsis):
output = "..."
elif isinstance(value, ast.Subscript):
try:
output = ast.unparse(value.body[0].value)
except:
output = ast.unparse(value.value) + "[" + ast.unparse(value.slice) + "]"
else:
raise Exception(f"Unsupported AST type: {type(value)}")
return output
def decode_ast(result, language="Python"):
func = result
func = func.replace("\n", "") # remove new line characters
if not func.startswith("["):
func = "[" + func
if not func.endswith("]"):
func = func + "]"
decoded_output = ast_parse(func, language)
return decoded_output
def decode_execute(result):
func = result
func = func.replace("\n", "") # remove new line characters
if not func.startswith("["):
func = "[" + func
if not func.endswith("]"):
func = func + "]"
decode_output = ast_parse(func)
execution_list = []
for function_call in decode_output:
for key, value in function_call.items():
execution_list.append(f"{key}({','.join([f'{k}={repr(v)}' for k, v in value.items()])})")
return execution_list

View file

@ -1,989 +0,0 @@
# ruff: noqa
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import re
import time
from typing import Any
# Comment out for now until we actually use the rest checker in evals
# import requests # Do not remove this import even though it seems to be unused. It's used in the executable_checker_rest function.
class NoAPIKeyError(Exception):
def __init__(self):
self.message = "Please fill in the API keys in the function_credential_config.json file. If you do not provide the API keys, the executable test category results will be inaccurate."
super().__init__(self.message)
REAL_TIME_MATCH_ALLOWED_DIFFERENCE = 0.2
JAVA_TYPE_CONVERSION = {
"byte": int,
"short": int,
"integer": int,
"float": float,
"double": float,
"long": int,
"boolean": bool,
"char": str,
"Array": list,
"ArrayList": list,
"Set": set,
"HashMap": dict,
"Hashtable": dict,
"Queue": list, # this can be `queue.Queue` as well, for simplicity we check with list
"Stack": list,
"String": str,
"any": str,
}
JS_TYPE_CONVERSION = {
"String": str,
"integer": int,
"float": float,
"Bigint": int,
"Boolean": bool,
"dict": dict,
"array": list,
"any": str,
}
# We switch to conditional import for the following two imports to avoid unnecessary installations.
# User doesn't need to setup the tree-sitter packages if they are not running the test for that language.
# from js_type_converter import js_type_converter
# from java_type_converter import java_type_converter
PYTHON_TYPE_MAPPING = {
"string": str,
"integer": int,
"float": float,
"boolean": bool,
"array": list,
"tuple": list,
"dict": dict,
"any": str,
}
# This is the list of types that we need to recursively check its values
PYTHON_NESTED_TYPE_CHECK_LIST = ["array", "tuple"]
NESTED_CONVERSION_TYPE_LIST = ["Array", "ArrayList", "array"]
#### Helper functions for AST ####
def find_description(func_descriptions, name):
if type(func_descriptions) == list:
for func_description in func_descriptions:
if func_description["name"] == name:
return func_description
return None
else:
# it is a dict, there is only one function
return func_descriptions
def get_possible_answer_type(possible_answer: list):
for answer in possible_answer:
if answer != "": # Optional parameter
return type(answer)
return None
def type_checker(
param: str,
value,
possible_answer: list,
expected_type_description: str,
expected_type_converted,
nested_type_converted,
):
# NOTE: This type checker only supports nested type checking for one level deep.
# We didn't implement recursive type checking for nested types, as it's not needed for the current use case and it's very complex.
result: Any = {
"valid": True,
"error": [],
"is_variable": False,
"error_type": "type_error:simple",
}
is_variable = False
# check for the case where a variable is used instead of a actual value.
# use the type in possible_answer as the expected type
possible_answer_type = get_possible_answer_type(possible_answer)
# if possible_answer only contains optional parameters, we can't determine the type
if possible_answer_type != None:
# we are being precise here.
# in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer
if possible_answer_type != expected_type_converted:
is_variable = True
# value is the same type as in function description
if type(value) == expected_type_converted:
# We don't need to do recursive check for simple types
if nested_type_converted == None:
result["is_variable"] = is_variable
return result
else:
for possible_answer_item in possible_answer:
flag = True # Each parameter should match to at least one possible answer type.
# Here, we assume that each item should be the same type. We could also relax it.
if type(possible_answer_item) == list:
for value_item in value:
checker_result = type_checker(
param,
value_item,
possible_answer_item,
str(nested_type_converted),
nested_type_converted,
None,
)
if not checker_result["valid"]:
flag = False
break
if flag:
return {"valid": True, "error": [], "is_variable": is_variable}
result["valid"] = False
result["error"] = [
f"Nested type checking failed for parameter {repr(param)}. Expected outer type {expected_type_description} with inner type {str(nested_type_converted)}. Parameter value: {repr(value)}."
]
result["error_type"] = "type_error:nested"
# value is not as expected, check for the case where a variable is used instead of a actual value
# use the type in possible_answer as the expected type
possible_answer_type = get_possible_answer_type(possible_answer)
# if possible_answer only contains optional parameters, we can't determine the type
if possible_answer_type != None:
# we are being precise here.
# in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer
if type(value) == possible_answer_type:
result["is_variable"] = True
return result
result["valid"] = False
result["error"].append(
f"Incorrect type for parameter {repr(param)}. Expected type {expected_type_description}, got {type(value).__name__}. Parameter value: {repr(value)}."
)
result["error_type"] = "type_error:simple"
return result
def standardize_string(input_string: str):
# This function standardizes the string by removing all the spaces, ",./-_*^" punctuation, and converting it to lowercase
# It will also convert all the single quotes to double quotes
# This is used to compare the model output with the possible answers
# We don't want to punish model for answer like April 1, 2024 vs April 1,2024, vs April 1 2024
regex_string = r"[ \,\.\/\-\_\*\^]"
return re.sub(regex_string, "", input_string).lower().replace("'", '"')
def string_checker(param: str, model_output: str, possible_answer: list):
standardize_possible_answer = []
standardize_model_output = standardize_string(model_output)
for i in range(len(possible_answer)):
if type(possible_answer[i]) == str:
standardize_possible_answer.append(standardize_string(possible_answer[i]))
if standardize_model_output not in standardize_possible_answer:
return {
"valid": False,
"error": [
f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}. Case insensitive."
],
"error_type": "value_error:string",
}
return {"valid": True, "error": []}
def list_checker(param: str, model_output: list, possible_answer: list):
# Convert the tuple to a list
standardize_model_output = list(model_output)
# If the element in the list is a string, we need to standardize it
for i in range(len(standardize_model_output)):
if type(standardize_model_output[i]) == str:
standardize_model_output[i] = standardize_string(model_output[i])
standardize_possible_answer: Any = []
# We also need to standardize the possible answers
for i in range(len(possible_answer)):
standardize_possible_answer.append([])
for j in range(len(possible_answer[i])):
if type(possible_answer[i][j]) == str:
standardize_possible_answer[i].append(standardize_string(possible_answer[i][j]))
else:
standardize_possible_answer[i].append(possible_answer[i][j])
if standardize_model_output not in standardize_possible_answer:
return {
"valid": False,
"error": [
f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}."
],
"error_type": "value_error:list/tuple",
}
return {"valid": True, "error": []}
def dict_checker(param: str, model_output: dict, possible_answers: list):
# This function works for simple dictionaries, but not dictionaries with nested dictionaries.
# The current dataset only contains simple dictionaries, so this is sufficient.
result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"}
for i in range(len(possible_answers)):
if possible_answers[i] == "":
continue
result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"}
flag = True
possible_answer = possible_answers[i]
# possible_anwer is a single dictionary
for key, value in model_output.items():
if key not in possible_answer:
result["valid"] = False
result["error"].append(f"Unexpected dict key parameter: '{key}'.") # type: ignore[attr-defined]
result["error_type"] = "value_error:dict_key"
flag = False
break
standardize_value = value
# If the value is a string, we need to standardize it
if type(value) == str:
standardize_value = standardize_string(value)
# We also need to standardize the possible answers if they are string
standardize_possible_answer = []
for i in range(len(possible_answer[key])):
if type(possible_answer[key][i]) == str:
standardize_possible_answer.append(standardize_string(possible_answer[key][i]))
else:
standardize_possible_answer.append(possible_answer[key][i])
if standardize_value not in standardize_possible_answer:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Invalid value for parameter {repr(key)}: {repr(value)}. Expected one of {standardize_possible_answer}."
)
result["error_type"] = "value_error:dict_value"
flag = False
break
for key, value in possible_answer.items():
if key not in model_output and "" not in value:
result["valid"] = False
result["error"].append(f"Missing dict key parameter: '{key}'.") # type: ignore[attr-defined]
result["error_type"] = "value_error:dict_key"
flag = False
break
if flag:
return {"valid": True, "error": []}
return result
def list_dict_checker(param: str, model_output: list, possible_answers: list):
# This function takes in a list of dictionaries and checks if each dictionary is valid
# The order of the dictionaries in the list must match the order of the possible answers
result = {"valid": False, "error": [], "error_type": "list_dict_checker:unclear"}
for answer_index in range(len(possible_answers)):
flag = True # True means so far, all dictionaries are valid
# Only proceed if the number of dictionaries in the list matches the number of dictionaries in the possible answers
if len(model_output) != len(possible_answers[answer_index]):
result["valid"] = False
result["error"] = ["Wrong number of dictionaries in the list."]
result["error_type"] = "value_error:list_dict_count"
flag = False
continue
for dict_index in range(len(model_output)):
result = dict_checker(
param,
model_output[dict_index],
[possible_answers[answer_index][dict_index]],
)
if not result["valid"]:
flag = False
break
if flag:
return {"valid": True, "error": []}
return result
def simple_function_checker(
func_description: dict,
model_output: dict,
possible_answer: dict,
language: str,
model_name: str,
):
possible_answer = list(possible_answer.values())[0]
# Extract function name and parameters details
func_name = func_description["name"]
param_details = func_description["parameters"]["properties"]
required_params = func_description["parameters"]["required"]
# Initialize a result dictionary
result = {
"valid": True,
"error": [],
"error_type": "simple_function_checker:unclear",
}
# Check if function name matches
if func_name not in model_output:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Function name {repr(func_name)} not found in model output."
)
result["error_type"] = "simple_function_checker:wrong_func_name"
return result
model_params = model_output[func_name]
# Check for required parameters in model output
for param in required_params:
if param not in model_params:
result["valid"] = False
result["error"].append(f"Missing required parameter: {repr(param)}.") # type: ignore[attr-defined]
result["error_type"] = "simple_function_checker:missing_required"
return result
# Validate types and values for each parameter in model output
for param, value in model_params.items():
if param not in param_details or param not in possible_answer:
result["valid"] = False
result["error"].append(f"Unexpected parameter: {repr(param)}.") # type: ignore[attr-defined]
result["error_type"] = "simple_function_checker:unexpected_param"
return result
full_param_details = param_details[param]
expected_type_description = full_param_details["type"] # This is a string
is_variable = False
nested_type_converted = None
if language == "Java":
from evals.utils.bfcl.java_type_converter import java_type_converter
expected_type_converted = JAVA_TYPE_CONVERSION[expected_type_description]
if expected_type_description in JAVA_TYPE_CONVERSION:
if type(value) != str:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}."
)
result["error_type"] = "type_error:java"
return result
if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
nested_type = param_details[param]["items"]["type"]
nested_type_converted = JAVA_TYPE_CONVERSION[nested_type]
value = java_type_converter(value, expected_type_description, nested_type)
else:
value = java_type_converter(value, expected_type_description)
elif language == "JavaScript":
from evals.utils.bfcl.js_type_converter import js_type_converter
expected_type_converted = JS_TYPE_CONVERSION[expected_type_description]
if expected_type_description in JS_TYPE_CONVERSION:
if type(value) != str:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}."
)
result["error_type"] = "type_error:js"
return result
if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
nested_type = param_details[param]["items"]["type"]
nested_type_converted = JS_TYPE_CONVERSION[nested_type]
value = js_type_converter(value, expected_type_description, nested_type)
else:
value = js_type_converter(value, expected_type_description)
elif language == "Python":
expected_type_converted = PYTHON_TYPE_MAPPING[expected_type_description]
if expected_type_description in PYTHON_NESTED_TYPE_CHECK_LIST:
nested_type = param_details[param]["items"]["type"]
nested_type_converted = PYTHON_TYPE_MAPPING[nested_type]
# We convert all tuple value to list when the expected type is tuple.
# The conversion is necessary because any tuple in the possible answer would become a list after being processed through json.dump() and json.load().
# This does introduce some false positive (eg, when the model provides a list value instead of tuple). We hope to find a better solution in the future.
if expected_type_description == "tuple" and type(value) == tuple:
value = list(value)
# Allow python auto conversion from int to float
if language == "Python" and expected_type_description == "float" and type(value) == int:
value = float(value)
# Type checking
# In fact, we only check for Python here.
# Type check for other languages are handled by the type converter, and so their value (after conversion) is always correct.
type_check_result = type_checker(
param,
value,
possible_answer[param],
expected_type_description,
expected_type_converted,
nested_type_converted,
)
is_variable = type_check_result["is_variable"]
if not type_check_result["valid"]:
return type_check_result
# It doesn't make sense to special handle dictionaries and list of dictionaries if the value is a variable.
# We can just treat the variable as a string and use the normal flow.
if not is_variable:
# Special handle for dictionaries
if expected_type_converted == dict:
result = dict_checker(param, value, possible_answer[param])
if not result["valid"]:
return result
continue
# Special handle for list of dictionaries
elif expected_type_converted == list and nested_type_converted == dict:
result = list_dict_checker(param, value, possible_answer[param])
if not result["valid"]:
return result
continue
# Special handle for strings
elif expected_type_converted == str:
# We don't check for case sensitivity for string, as long as it's not a variable
result = string_checker(param, value, possible_answer[param])
if not result["valid"]:
return result
continue
elif expected_type_converted == list:
result = list_checker(param, value, possible_answer[param])
if not result["valid"]:
return result
continue
# Check if the value is within the possible answers
if value not in possible_answer[param]:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Invalid value for parameter {repr(param)}: {repr(value)}. Expected one of {possible_answer[param]}."
)
result["error_type"] = "value_error:others"
return result
# Check for optional parameters not provided but allowed
for param in possible_answer:
if param not in model_params and "" not in possible_answer[param]:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Optional parameter {repr(param)} not provided and not marked as optional."
)
result["error_type"] = "simple_function_checker:missing_optional"
return result
return result
def parallel_function_checker_enforce_order(
func_descriptions: list,
model_output: list,
possible_answers: dict,
language: str,
model_name: str,
):
if len(model_output) != len(possible_answers):
return {
"valid": False,
"error": ["Wrong number of functions."],
"error_type": "parallel_function_checker_enforce_order:wrong_count",
}
func_name_list = list(possible_answers.keys())
possible_answers_list = []
for key, value in possible_answers.items():
possible_answers_list.append({key: value})
for i in range(len(possible_answers_list)):
func_description = find_description(func_descriptions, func_name_list[i])
result = simple_function_checker(
func_description,
model_output[i],
possible_answers_list[i],
language,
model_name,
)
if not result["valid"]:
return result
return {"valid": True, "error": []}
def parallel_function_checker_no_order(
func_descriptions: list,
model_output: list,
possible_answers: list,
language: str,
model_name: str,
):
if len(model_output) != len(possible_answers):
return {
"valid": False,
"error": ["Wrong number of functions."],
"error_type": "parallel_function_checker_no_order:wrong_count",
}
matched_indices = []
# We go throught the possible answers one by one, and eliminate the model output that matches the possible answer
# It must be this way because we need ground truth to fetch the correct function description
for i in range(len(possible_answers)):
# possible_answers[i] is a dictionary with only one key
func_name_expected = list(possible_answers[i].keys())[0]
func_description = find_description(func_descriptions, func_name_expected)
all_errors = []
for index in range(len(model_output)):
if index in matched_indices:
continue
result = simple_function_checker(
func_description,
model_output[index],
possible_answers[i],
language,
model_name,
)
if result["valid"]:
matched_indices.append(index)
break
else:
all_errors.append(
{
f"Model Result Index {index}": {
"sub_error": result["error"],
"sub_error_type": result["error_type"],
"model_output_item": model_output[index],
"possible_answer_item": possible_answers[i],
}
}
)
if not result["valid"]:
considered_indices = [i for i in range(len(model_output)) if i not in matched_indices]
all_errors.insert(
0,
f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type]
)
return {
"valid": False,
"error": all_errors,
"error_type": "parallel_function_checker_no_order:cannot_find_match",
}
return {"valid": True, "error": []}
def multiple_function_checker(
func_descriptions: list,
model_output: list,
possible_answers: list,
language: str,
model_name: str,
):
if len(model_output) != len(possible_answers):
return {
"valid": False,
"error": ["Wrong number of functions."],
"error_type": "multiple_function_checker:wrong_count",
}
# possible_answers is a list of only one dictionary with only one key
func_name_expected = list(possible_answers[0].keys())[0]
func_description = find_description(func_descriptions, func_name_expected)
return simple_function_checker(
func_description,
model_output[0],
possible_answers[0],
language,
model_name,
)
def patten_matcher(exec_output, expected_result, function_call, is_sanity_check):
result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
if type(exec_output) != type(expected_result):
return {
"valid": False,
"error": [
f"Wrong execution result type for {repr(function_call)}. Expected type: {type(expected_result)}, but got: {type(exec_output)}."
],
"error_type": "executable_checker:wrong_result_type",
"model_executed_output": exec_output,
}
if type(exec_output) == dict:
# We loose the requirement for the sanity check as the expected result used in the sanity check might not be the most up-to-date one.
# This happens when the key is a timestamp or a random number.
if is_sanity_check:
if len(exec_output) != len(expected_result):
return {
"valid": False,
"error": [
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}."
],
"error_type": "executable_checker:wrong_result_type:dict_length",
"model_executed_output": exec_output,
}
else:
return result
for key, value in expected_result.items():
if key not in exec_output:
return {
"valid": False,
"error": [
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not found in the model output."
],
"error_type": "executable_checker:wrong_result_type:dict_key_not_found",
"model_executed_output": exec_output,
}
for key, value in exec_output.items():
if key not in expected_result:
return {
"valid": False,
"error": [
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not expected in the model output."
],
"error_type": "executable_checker:wrong_result_type:dict_extra_key",
"model_executed_output": exec_output,
}
if type(exec_output) == list:
if len(exec_output) != len(expected_result):
return {
"valid": False,
"error": [
f"Wrong execution result pattern for {repr(function_call)}. Expect type list, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}."
],
"error_type": "executable_checker:wrong_result_type:list_length",
"model_executed_output": exec_output,
}
return result
#### Helper functions for Exec ####
def executable_checker_simple(
function_call: str,
expected_result,
expected_result_type: str,
is_sanity_check=False,
):
result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
exec_dict: Any = {}
try:
exec(
"from executable_python_function import *" + "\nresult=" + function_call,
exec_dict,
)
exec_output = exec_dict["result"]
except NoAPIKeyError as e:
raise e
except Exception as e:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Error in execution: {repr(function_call)}. Error: {str(e)}"
)
result["error_type"] = "executable_checker:execution_error"
return result
# We need to special handle the case where the execution result is a tuple and convert it to a list
# Because when json is stored, the tuple is converted to a list, and so the expected result is a list when loaded from json
if isinstance(exec_output, tuple):
exec_output = list(exec_output)
if expected_result_type == "exact_match":
if exec_output != expected_result:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}."
)
result["error_type"] = "executable_checker:wrong_result"
result["model_executed_output"] = exec_output
return result
elif expected_result_type == "real_time_match":
# Allow for 5% difference
if (type(expected_result) == float or type(expected_result) == int) and (
type(exec_output) == float or type(exec_output) == int
):
if not (
expected_result * (1 - REAL_TIME_MATCH_ALLOWED_DIFFERENCE)
<= exec_output
<= expected_result * (1 + REAL_TIME_MATCH_ALLOWED_DIFFERENCE)
):
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. {REAL_TIME_MATCH_ALLOWED_DIFFERENCE * 100}% difference allowed."
)
result["error_type"] = "executable_checker:wrong_result_real_time"
result["model_executed_output"] = exec_output
return result
else:
result["valid"] = False
result["error"].append( # type: ignore[attr-defined]
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. Type needs to be float or int for real time match criteria."
)
result["error_type"] = "executable_checker:wrong_result_real_time"
result["model_executed_output"] = exec_output
return result
else:
# structural match
pattern_match_result = patten_matcher(exec_output, expected_result, function_call, is_sanity_check)
if not pattern_match_result["valid"]:
return pattern_match_result
return result
def executable_checker_parallel_no_order(
decoded_result: list, expected_exec_result: list, expected_exec_result_type: list
):
if len(decoded_result) != len(expected_exec_result):
return {
"valid": False,
"error": [
f"Wrong number of functions provided. Expected {len(expected_exec_result)}, but got {len(decoded_result)}."
],
"error_type": "value_error:exec_result_count",
}
matched_indices = []
for i in range(len(expected_exec_result)):
all_errors = []
for index in range(len(decoded_result)):
if index in matched_indices:
continue
result = executable_checker_simple(
decoded_result[index],
expected_exec_result[i],
expected_exec_result_type[i],
False,
)
if result["valid"]:
matched_indices.append(index)
break
else:
all_errors.append(
{
f"Model Result Index {index}": {
"sub_error": result["error"],
"sub_error_type": result["error_type"],
"model_executed_output": (
result["model_executed_output"] if "model_executed_output" in result else None
),
}
}
)
if not result["valid"]:
considered_indices = [i for i in range(len(decoded_result)) if i not in matched_indices]
all_errors.insert(
0,
f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type]
)
return {
"valid": False,
"error": all_errors,
"error_type": "executable_checker:cannot_find_match",
}
return {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
#### Main function ####
def executable_checker_rest(func_call, idx):
# Move this here for now to avoid needing to read this file / fix paths to be relative to dataset_dir. Fix when it's actually needed / used.
EVAL_GROUND_TRUTH_PATH = "/mnt/wsfuse/fair_llm_v2/datasets/eval/bfcl/rest-eval-response_v5.jsonl" # Ground truth file for v5 for rest execution
with open(EVAL_GROUND_TRUTH_PATH, "r") as f:
EVAL_GROUND_TRUTH = f.readlines()
if "https://geocode.maps.co" in func_call:
time.sleep(2)
if "requests_get" in func_call:
func_call = func_call.replace("requests_get", "requests.get")
try:
response = eval(func_call)
except Exception as e:
return {
"valid": False,
"error": [f"Execution failed. {str(e)}"],
"error_type": "executable_checker_rest:execution_error",
}
try:
if response.status_code == 200:
eval_GT_json = json.loads(EVAL_GROUND_TRUTH[idx])
try:
if isinstance(eval_GT_json, dict):
if isinstance(response.json(), dict):
if set(eval_GT_json.keys()) == set(response.json().keys()):
return {"valid": True, "error": [], "error_type": ""}
return {
"valid": False,
"error": ["Key inconsistency"],
"error_type": "executable_checker_rest:wrong_key",
}
return {
"valid": False,
"error": [f"Expected dictionary, but got {type(response.json())}"],
"error_type": "executable_checker_rest:wrong_type",
}
elif isinstance(eval_GT_json, list):
if isinstance(response.json(), list):
if len(eval_GT_json) != len(response.json()):
return {
"valid": False,
"error": [f"Response list length inconsistency."],
"error_type": "value_error:exec_result_rest_count",
}
else:
for i in range(len(eval_GT_json)):
if set(eval_GT_json[i].keys()) != set(response.json()[i].keys()):
return {
"valid": False,
"error": [f"Key inconsistency"],
"error_type": "executable_checker_rest:wrong_key",
}
return {"valid": True, "error": []}
else:
return {
"valid": False,
"error": [f"Expected list, but got {type(response.json())}"],
"error_type": "executable_checker_rest:wrong_type",
}
return {
"valid": False,
"error": [f"Expected dict or list, but got {type(response.json())}"],
"error_type": "executable_checker_rest:wrong_type",
}
except Exception as e:
return {
"valid": False,
"error": [
f"Error in execution and type checking. Status code: {response.status_code}. Error: {str(e)}"
],
"error_type": "executable_checker_rest:response_format_error",
}
else:
return {
"valid": False,
"error": [f"Execution result status code is not 200, got {response.status_code}"],
"error_type": "executable_checker_rest:wrong_status_code",
}
except Exception as e:
return {
"valid": False,
"error": [f"Cannot get status code of the response. Error: {str(e)}"],
"error_type": "executable_checker_rest:cannot_get_status_code",
}
def ast_checker(func_description, model_output, possible_answer, language, test_category, model_name):
if "parallel" in test_category:
return parallel_function_checker_no_order(func_description, model_output, possible_answer, language, model_name)
elif "multiple" in test_category:
return multiple_function_checker(func_description, model_output, possible_answer, language, model_name)
else:
if len(model_output) != 1:
return {
"valid": False,
"error": ["Wrong number of functions."],
"error_type": "simple_function_checker:wrong_count",
}
return simple_function_checker(
func_description[0],
model_output[0],
possible_answer[0],
language,
model_name,
)
def exec_checker(decoded_result: list, func_description: dict, test_category: str):
if "multiple" in test_category or "parallel" in test_category:
return executable_checker_parallel_no_order(
decoded_result,
func_description["execution_result"],
func_description["execution_result_type"],
)
else:
if len(decoded_result) != 1:
return {
"valid": False,
"error": ["Wrong number of functions."],
"error_type": "simple_exec_checker:wrong_count",
}
return executable_checker_simple(
decoded_result[0],
func_description["execution_result"][0],
func_description["execution_result_type"][0],
False,
)
def is_empty_output(decoded_output):
# This function is a patch to the ast decoder for relevance detection
# Sometimes the ast decoder will parse successfully, but the input doens't really have a function call
# [], [{}], and anything that is not in function calling format is considered empty (and thus should be marked as correct)
if not is_function_calling_format_output(decoded_output):
return True
if len(decoded_output) == 0:
return True
if len(decoded_output) == 1 and len(decoded_output[0]) == 0:
return True
def is_function_calling_format_output(decoded_output):
# Ensure the output is a list of dictionaries
if type(decoded_output) == list:
for item in decoded_output:
if type(item) != dict:
return False
return True
return False

View file

@ -1,40 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Tree-sitter changes its API with unfortunate frequency. Modules that need it should
import it from here so that we can centrally manage things as necessary.
"""
# These currently work with tree-sitter 0.23.0
# NOTE: Don't import tree-sitter or any of the language modules in the main module
# because not all environments have them. Import lazily inside functions where needed.
import importlib
import typing
if typing.TYPE_CHECKING:
import tree_sitter
def get_language(language: str) -> "tree_sitter.Language":
import tree_sitter
language_module_name = f"tree_sitter_{language}"
try:
language_module = importlib.import_module(language_module_name)
except ModuleNotFoundError as exc:
raise ValueError(
f"Language {language} is not found. Please install the tree-sitter-{language} package."
) from exc
return tree_sitter.Language(language_module.language())
def get_parser(language: str, **kwargs) -> "tree_sitter.Parser":
import tree_sitter
lang = get_language(language)
return tree_sitter.Parser(lang, **kwargs)

View file

@ -14,6 +14,6 @@ from .config import RagToolRuntimeConfig
async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]): async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]):
from .memory import MemoryToolRuntimeImpl from .memory import MemoryToolRuntimeImpl
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference]) impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files])
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -5,10 +5,15 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import base64
import io
import mimetypes
import secrets import secrets
import string import string
from typing import Any from typing import Any
import httpx
from fastapi import UploadFile
from pydantic import TypeAdapter from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
@ -17,6 +22,7 @@ from llama_stack.apis.common.content_types import (
InterleavedContentItem, InterleavedContentItem,
TextContentItem, TextContentItem,
) )
from llama_stack.apis.files import Files, OpenAIFilePurpose
from llama_stack.apis.inference import Inference from llama_stack.apis.inference import Inference
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
ListToolDefsResponse, ListToolDefsResponse,
@ -30,13 +36,18 @@ from llama_stack.apis.tools import (
ToolParameter, ToolParameter,
ToolRuntime, ToolRuntime,
) )
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorIO,
VectorStoreChunkingStrategyStatic,
VectorStoreChunkingStrategyStaticConfig,
)
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
content_from_doc, content_from_doc,
make_overlapped_chunks, parse_data_url,
) )
from .config import RagToolRuntimeConfig from .config import RagToolRuntimeConfig
@ -55,10 +66,12 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
config: RagToolRuntimeConfig, config: RagToolRuntimeConfig,
vector_io_api: VectorIO, vector_io_api: VectorIO,
inference_api: Inference, inference_api: Inference,
files_api: Files,
): ):
self.config = config self.config = config
self.vector_io_api = vector_io_api self.vector_io_api = vector_io_api
self.inference_api = inference_api self.inference_api = inference_api
self.files_api = files_api
async def initialize(self): async def initialize(self):
pass pass
@ -78,27 +91,50 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
vector_db_id: str, vector_db_id: str,
chunk_size_in_tokens: int = 512, chunk_size_in_tokens: int = 512,
) -> None: ) -> None:
chunks = [] if not documents:
return
for doc in documents: for doc in documents:
content = await content_from_doc(doc) if isinstance(doc.content, URL):
# TODO: we should add enrichment here as URLs won't be added to the metadata by default if doc.content.uri.startswith("data:"):
chunks.extend( parts = parse_data_url(doc.content.uri)
make_overlapped_chunks( file_data = base64.b64decode(parts["data"]) if parts["is_base64"] else parts["data"].encode()
doc.document_id, mime_type = parts["mimetype"]
content, else:
chunk_size_in_tokens, async with httpx.AsyncClient() as client:
chunk_size_in_tokens // 4, response = await client.get(doc.content.uri)
doc.metadata, file_data = response.content
mime_type = doc.mime_type or response.headers.get("content-type", "application/octet-stream")
else:
content_str = await content_from_doc(doc)
file_data = content_str.encode("utf-8")
mime_type = doc.mime_type or "text/plain"
file_extension = mimetypes.guess_extension(mime_type) or ".txt"
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
file_obj = io.BytesIO(file_data)
file_obj.name = filename
upload_file = UploadFile(file=file_obj, filename=filename)
created_file = await self.files_api.openai_upload_file(
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
)
chunking_strategy = VectorStoreChunkingStrategyStatic(
static=VectorStoreChunkingStrategyStaticConfig(
max_chunk_size_tokens=chunk_size_in_tokens,
chunk_overlap_tokens=chunk_size_in_tokens // 4,
) )
) )
if not chunks: await self.vector_io_api.openai_attach_file_to_vector_store(
return vector_store_id=vector_db_id,
file_id=created_file.id,
await self.vector_io_api.insert_chunks( attributes=doc.metadata,
chunks=chunks, chunking_strategy=chunking_strategy,
vector_db_id=vector_db_id, )
)
async def query( async def query(
self, self,
@ -131,8 +167,18 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
for vector_db_id in vector_db_ids for vector_db_id in vector_db_ids
] ]
results: list[QueryChunksResponse] = await asyncio.gather(*tasks) results: list[QueryChunksResponse] = await asyncio.gather(*tasks)
chunks = [c for r in results for c in r.chunks]
scores = [s for r in results for s in r.scores] chunks = []
scores = []
for vector_db_id, result in zip(vector_db_ids, results, strict=False):
for chunk, score in zip(result.chunks, result.scores, strict=False):
if not hasattr(chunk, "metadata") or chunk.metadata is None:
chunk.metadata = {}
chunk.metadata["vector_db_id"] = vector_db_id
chunks.append(chunk)
scores.append(score)
if not chunks: if not chunks:
return RAGQueryResult(content=None) return RAGQueryResult(content=None)
@ -167,6 +213,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
metadata_keys_to_exclude_from_context = [ metadata_keys_to_exclude_from_context = [
"token_count", "token_count",
"metadata_token_count", "metadata_token_count",
"vector_db_id",
] ]
metadata_for_context = {} metadata_for_context = {}
for k in chunk_metadata_keys_to_include_from_context: for k in chunk_metadata_keys_to_include_from_context:
@ -191,6 +238,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
"document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]], "document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
"chunks": [c.content for c in chunks[: len(picked)]], "chunks": [c.content for c in chunks[: len(picked)]],
"scores": scores[: len(picked)], "scores": scores[: len(picked)],
"vector_db_ids": [c.metadata["vector_db_id"] for c in chunks[: len(picked)]],
}, },
) )

View file

@ -30,11 +30,11 @@ from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_RRF, RERANKER_TYPE_RRF,
RERANKER_TYPE_WEIGHTED,
ChunkForDeletion, ChunkForDeletion,
EmbeddingIndex, EmbeddingIndex,
VectorDBWithIndex, VectorDBWithIndex,
) )
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator
logger = get_logger(name=__name__, category="vector_io") logger = get_logger(name=__name__, category="vector_io")
@ -66,59 +66,6 @@ def _create_sqlite_connection(db_path):
return connection return connection
def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
"""Normalize scores to [0,1] range using min-max normalization."""
if not scores:
return {}
min_score = min(scores.values())
max_score = max(scores.values())
score_range = max_score - min_score
if score_range > 0:
return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()}
return dict.fromkeys(scores, 1.0)
def _weighted_rerank(
vector_scores: dict[str, float],
keyword_scores: dict[str, float],
alpha: float = 0.5,
) -> dict[str, float]:
"""ReRanker that uses weighted average of scores."""
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
normalized_vector_scores = _normalize_scores(vector_scores)
normalized_keyword_scores = _normalize_scores(keyword_scores)
return {
doc_id: (alpha * normalized_keyword_scores.get(doc_id, 0.0))
+ ((1 - alpha) * normalized_vector_scores.get(doc_id, 0.0))
for doc_id in all_ids
}
def _rrf_rerank(
vector_scores: dict[str, float],
keyword_scores: dict[str, float],
impact_factor: float = 60.0,
) -> dict[str, float]:
"""ReRanker that uses Reciprocal Rank Fusion."""
# Convert scores to ranks
vector_ranks = {
doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True))
}
keyword_ranks = {
doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True))
}
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
rrf_scores = {}
for doc_id in all_ids:
vector_rank = vector_ranks.get(doc_id, float("inf"))
keyword_rank = keyword_ranks.get(doc_id, float("inf"))
# RRF formula: score = 1/(k + r) where k is impact_factor and r is the rank
rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank))
return rrf_scores
def _make_sql_identifier(name: str) -> str: def _make_sql_identifier(name: str) -> str:
return re.sub(r"[^a-zA-Z0-9_]", "_", name) return re.sub(r"[^a-zA-Z0-9_]", "_", name)
@ -398,14 +345,10 @@ class SQLiteVecIndex(EmbeddingIndex):
for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False) for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
} }
# Combine scores using the specified reranker # Combine scores using the reranking utility
if reranker_type == RERANKER_TYPE_WEIGHTED: combined_scores = WeightedInMemoryAggregator.combine_search_results(
alpha = reranker_params.get("alpha", 0.5) vector_scores, keyword_scores, reranker_type, reranker_params
combined_scores = _weighted_rerank(vector_scores, keyword_scores, alpha) )
else:
# Default to RRF for None, RRF, or any unknown types
impact_factor = reranker_params.get("impact_factor", 60.0)
combined_scores = _rrf_rerank(vector_scores, keyword_scores, impact_factor)
# Sort by combined score and get top k results # Sort by combined score and get top k results
sorted_items = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True) sorted_items = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)

View file

@ -13,7 +13,7 @@ def available_providers() -> list[ProviderSpec]:
InlineProviderSpec( InlineProviderSpec(
api=Api.batches, api=Api.batches,
provider_type="inline::reference", provider_type="inline::reference",
pip_packages=["openai"], pip_packages=[],
module="llama_stack.providers.inline.batches.reference", module="llama_stack.providers.inline.batches.reference",
config_class="llama_stack.providers.inline.batches.reference.config.ReferenceBatchesImplConfig", config_class="llama_stack.providers.inline.batches.reference.config.ReferenceBatchesImplConfig",
api_dependencies=[ api_dependencies=[

View file

@ -30,7 +30,7 @@ def available_providers() -> list[ProviderSpec]:
adapter=AdapterSpec( adapter=AdapterSpec(
adapter_type="huggingface", adapter_type="huggingface",
pip_packages=[ pip_packages=[
"datasets", "datasets>=4.0.0",
], ],
module="llama_stack.providers.remote.datasetio.huggingface", module="llama_stack.providers.remote.datasetio.huggingface",
config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig", config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
@ -42,7 +42,7 @@ def available_providers() -> list[ProviderSpec]:
adapter=AdapterSpec( adapter=AdapterSpec(
adapter_type="nvidia", adapter_type="nvidia",
pip_packages=[ pip_packages=[
"datasets", "datasets>=4.0.0",
], ],
module="llama_stack.providers.remote.datasetio.nvidia", module="llama_stack.providers.remote.datasetio.nvidia",
config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig", config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",

View file

@ -40,8 +40,9 @@ def available_providers() -> list[ProviderSpec]:
InlineProviderSpec( InlineProviderSpec(
api=Api.inference, api=Api.inference,
provider_type="inline::sentence-transformers", provider_type="inline::sentence-transformers",
# CrossEncoder depends on torchao.quantization
pip_packages=[ pip_packages=[
"torch torchvision --index-url https://download.pytorch.org/whl/cpu", "torch torchvision torchao>=0.12.0 --extra-index-url https://download.pytorch.org/whl/cpu",
"sentence-transformers --no-deps", "sentence-transformers --no-deps",
], ],
module="llama_stack.providers.inline.inference.sentence_transformers", module="llama_stack.providers.inline.inference.sentence_transformers",
@ -74,7 +75,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter=AdapterSpec(
adapter_type="vllm", adapter_type="vllm",
pip_packages=["openai"], pip_packages=[],
module="llama_stack.providers.remote.inference.vllm", module="llama_stack.providers.remote.inference.vllm",
config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig", config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
description="Remote vLLM inference provider for connecting to vLLM servers.", description="Remote vLLM inference provider for connecting to vLLM servers.",
@ -115,7 +116,7 @@ def available_providers() -> list[ProviderSpec]:
adapter=AdapterSpec( adapter=AdapterSpec(
adapter_type="fireworks", adapter_type="fireworks",
pip_packages=[ pip_packages=[
"fireworks-ai", "fireworks-ai<=0.17.16",
], ],
module="llama_stack.providers.remote.inference.fireworks", module="llama_stack.providers.remote.inference.fireworks",
config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig", config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig",
@ -150,9 +151,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter=AdapterSpec(
adapter_type="databricks", adapter_type="databricks",
pip_packages=[ pip_packages=[],
"openai",
],
module="llama_stack.providers.remote.inference.databricks", module="llama_stack.providers.remote.inference.databricks",
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig", config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
description="Databricks inference provider for running models on Databricks' unified analytics platform.", description="Databricks inference provider for running models on Databricks' unified analytics platform.",
@ -162,9 +161,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter=AdapterSpec(
adapter_type="nvidia", adapter_type="nvidia",
pip_packages=[ pip_packages=[],
"openai",
],
module="llama_stack.providers.remote.inference.nvidia", module="llama_stack.providers.remote.inference.nvidia",
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig", config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.", description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
@ -174,7 +171,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter=AdapterSpec(
adapter_type="runpod", adapter_type="runpod",
pip_packages=["openai"], pip_packages=[],
module="llama_stack.providers.remote.inference.runpod", module="llama_stack.providers.remote.inference.runpod",
config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig", config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
description="RunPod inference provider for running models on RunPod's cloud GPU platform.", description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
@ -291,7 +288,7 @@ Available Models:
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter=AdapterSpec(
adapter_type="watsonx", adapter_type="watsonx",
pip_packages=["ibm_watson_machine_learning"], pip_packages=["ibm_watsonx_ai"],
module="llama_stack.providers.remote.inference.watsonx", module="llama_stack.providers.remote.inference.watsonx",
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig", config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator", provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",

Some files were not shown because too many files have changed in this diff Show more